/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include "HalInterfaces.h"
#include "IndexedShapeWrapper.h"
#include "LSTM.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"

namespace android {
namespace nn {
namespace unidirectional_sequence_lstm {

// Operand indices of the UNIDIRECTIONAL_SEQUENCE_LSTM operation. Operands marked "Optional" may
// be omitted; prepare() enforces which combinations of omitted operands are accepted.

// Inputs
// Exact number of inputs validate() requires.
constexpr uint32_t kNumInputs = 28;

// Input tensor of size {max_time, n_batch, n_input} when kTimeMajorParam is true, or
// {n_batch, max_time, n_input} otherwise (see how prepare() picks the batch dimension).
constexpr uint32_t kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
// Omitting the input-to-input weights selects CIFG; prepare() requires them to be present or
// omitted together with the recurrent-to-input weights.
constexpr uint32_t kInputToInputWeightsTensor = 1;  // Optional
constexpr uint32_t kInputToForgetWeightsTensor = 2;
constexpr uint32_t kInputToCellWeightsTensor = 3;
constexpr uint32_t kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr uint32_t kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr uint32_t kRecurrentToForgetWeightsTensor = 6;
constexpr uint32_t kRecurrentToCellWeightsTensor = 7;
constexpr uint32_t kRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
// prepare() accepts them only as a group: all omitted, or forget and output present together
// with the input one (the input one may be omitted when CIFG is used).
constexpr uint32_t kCellToInputWeightsTensor = 9;    // Optional
constexpr uint32_t kCellToForgetWeightsTensor = 10;  // Optional
constexpr uint32_t kCellToOutputWeightsTensor = 11;  // Optional

// Gates bias tensors of size {n_cell}
// The input gate bias is required unless CIFG is used, in which case it must be omitted.
constexpr uint32_t kInputGateBiasTensor = 12;  // Optional
constexpr uint32_t kForgetGateBiasTensor = 13;
constexpr uint32_t kCellGateBiasTensor = 14;
constexpr uint32_t kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr uint32_t kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr uint32_t kProjectionBiasTensor = 17;  // Optional

// Input from the output of the previous step, tensor of size {batch_size, n_output}
constexpr uint32_t kOutputStateInTensor = 18;
// Input from the cell state of the previous step, tensor of size {batch_size, n_cell}
constexpr uint32_t kCellStateInTensor = 19;

// Scalar parameters. validate() expects INT32 for the activation, a FLOAT32/FLOAT16 scalar
// (matching the input tensor type) for both clip values, and BOOL for the time-major flag.
constexpr uint32_t kActivationParam = 20;
constexpr uint32_t kCellClipParam = 21;
constexpr uint32_t kProjClipParam = 22;
constexpr uint32_t kTimeMajorParam = 23;

// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
// prepare() requires them to be all present or all omitted; with CIFG the input one must be
// omitted and the rule applies to the remaining three.
constexpr uint32_t kInputLayerNormWeightsTensor = 24;   // Optional
constexpr uint32_t kForgetLayerNormWeightsTensor = 25;  // Optional
constexpr uint32_t kCellLayerNormWeightsTensor = 26;    // Optional
constexpr uint32_t kOutputLayerNormWeightsTensor = 27;  // Optional

// Output tensors.
constexpr uint32_t kNumOutputs = 1; constexpr uint32_t kOutputTensor = 0; namespace { inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) { return context->getInputBuffer(tensor) != nullptr; } inline bool isTimeMajor(IOperationExecutionContext* context) { return context->getInputValue<bool>(kTimeMajorParam); } template <typename T> inline LSTMParams getLSTMParams(IOperationExecutionContext* context) { LSTMParams params; params.activation = static_cast<TfLiteFusedActivation>(context->getInputValue<int32_t>(kActivationParam)); params.cell_clip = static_cast<float>(context->getInputValue<T>(kCellClipParam)); params.proj_clip = static_cast<float>(context->getInputValue<T>(kProjClipParam)); params.use_cifg = !hasTensor(context, kInputToInputWeightsTensor); params.use_peephole = hasTensor(context, kCellToOutputWeightsTensor); params.use_layer_norm = hasTensor(context, kOutputLayerNormWeightsTensor); params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor); params.use_projection_bias = hasTensor(context, kProjectionBiasTensor); return params; } } // namespace bool validate(const IOperationValidationContext* context) { NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs); const OperandType inputType = context->getInputType(kInputTensor); std::vector<OperandType> inExpectedTypes; std::vector<OperandType> outExpectedTypes; if (inputType == OperandType::TENSOR_FLOAT32) { inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, 
OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::INT32, OperandType::FLOAT32, OperandType::FLOAT32, OperandType::BOOL, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32}; outExpectedTypes = {OperandType::TENSOR_FLOAT32}; } else if (inputType == OperandType::TENSOR_FLOAT16) { inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::INT32, OperandType::FLOAT16, OperandType::FLOAT16, OperandType::BOOL, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16}; outExpectedTypes = {OperandType::TENSOR_FLOAT16}; } else { NN_RET_CHECK_FAIL() << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_LSTM op: " << toString(inputType); } NN_RET_CHECK(validateInputTypes(context, inExpectedTypes)); NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes)); return validateHalVersion(context, HalVersion::V1_2); } bool prepare(IOperationExecutionContext* context) { // Check that none of the required inputs are omitted const std::vector<int> requiredInputs = { kInputTensor, kInputToForgetWeightsTensor, kInputToCellWeightsTensor, kInputToOutputWeightsTensor, kRecurrentToForgetWeightsTensor, kRecurrentToCellWeightsTensor, kRecurrentToOutputWeightsTensor, kForgetGateBiasTensor, kCellGateBiasTensor, kOutputGateBiasTensor, kOutputStateInTensor, 
kCellStateInTensor, kActivationParam, kCellClipParam, kProjClipParam, kTimeMajorParam, }; for (const int requiredInput : requiredInputs) { NN_RET_CHECK(!context->isOmittedInput(requiredInput)) << "required input " << requiredInput << " is omitted"; } const Shape inputShape = context->getInputShape(kInputTensor); const uint32_t inputRank = getNumberOfDimensions(inputShape); NN_RET_CHECK_EQ(inputRank, 3) << "Invalid input tensor rank: " << inputRank; const uint32_t maxTime = getSizeOfDimension(inputShape, isTimeMajor(context) ? 0 : 1); const uint32_t batchSize = getSizeOfDimension(inputShape, isTimeMajor(context) ? 1 : 0); const uint32_t inputSize = getSizeOfDimension(inputShape, inputRank - 1); const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize); const uint32_t numCells = getSizeOfDimension(inputToOutputShape, 0); const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numCells); const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1); if (hasTensor(context, kInputToInputWeightsTensor)) { const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numCells); NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize); } const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numCells); NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize); const Shape inputToCellShape = 
context->getInputShape(kInputToCellWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numCells); NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize); if (hasTensor(context, kRecurrentToInputWeightsTensor)) { const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numCells); NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize); } const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numCells); NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize); const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numCells); NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize); // We make sure the input-gate's parameters are either both present (regular // LSTM) or not at all (CIFG-LSTM). 
const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) && hasTensor(context, kRecurrentToInputWeightsTensor)) || (!hasTensor(context, kInputToInputWeightsTensor) && !hasTensor(context, kRecurrentToInputWeightsTensor)); NN_RET_CHECK(cifgWeightsAllOrNone); if (hasTensor(context, kCellToInputWeightsTensor)) { const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numCells); } if (hasTensor(context, kCellToForgetWeightsTensor)) { const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numCells); } if (hasTensor(context, kCellToOutputWeightsTensor)) { const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numCells); } // Making sure the peephole weights are there all or none. 
const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor); const bool peepholeWeightsAllOrNone = ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) && hasTensor(context, kCellToForgetWeightsTensor) && hasTensor(context, kCellToOutputWeightsTensor)) || (!hasTensor(context, kCellToInputWeightsTensor) && !hasTensor(context, kCellToForgetWeightsTensor) && !hasTensor(context, kCellToOutputWeightsTensor)); NN_RET_CHECK(peepholeWeightsAllOrNone); if (!cifgUsed) { NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor)); const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numCells); } else { NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor)) << "Input gate bias tensor is present when CIFG is used"; } const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numCells); const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numCells); const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numCells); if (hasTensor(context, kProjectionWeightsTensor)) { const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize); NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numCells); } if (hasTensor(context, kProjectionBiasTensor)) { const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor); 
NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize); } const Shape outputStateShape = context->getInputShape(kOutputStateInTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize); NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize); const Shape cellStateShape = context->getInputShape(kCellStateInTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2); NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize); NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numCells); if (hasTensor(context, kInputLayerNormWeightsTensor)) { const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numCells); } if (hasTensor(context, kForgetLayerNormWeightsTensor)) { const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numCells); } if (hasTensor(context, kCellLayerNormWeightsTensor)) { const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numCells); } if (hasTensor(context, kOutputLayerNormWeightsTensor)) { const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormWeightsTensor); NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1); NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numCells); } if (cifgUsed) { NN_RET_CHECK(!hasTensor(context, kInputLayerNormWeightsTensor)) << "Input layer norm weights tensor is present when CIFG is used"; const bool layerNormWeightsAllOrNoneCifg = 
(hasTensor(context, kForgetLayerNormWeightsTensor) && hasTensor(context, kCellLayerNormWeightsTensor) && hasTensor(context, kOutputLayerNormWeightsTensor)) || (!hasTensor(context, kForgetLayerNormWeightsTensor) && !hasTensor(context, kCellLayerNormWeightsTensor) && !hasTensor(context, kOutputLayerNormWeightsTensor)); NN_RET_CHECK(layerNormWeightsAllOrNoneCifg); } else { const bool layerNormWeightsAllOrNone = (hasTensor(context, kInputLayerNormWeightsTensor) && hasTensor(context, kForgetLayerNormWeightsTensor) && hasTensor(context, kCellLayerNormWeightsTensor) && hasTensor(context, kOutputLayerNormWeightsTensor)) || (!hasTensor(context, kInputLayerNormWeightsTensor) && !hasTensor(context, kForgetLayerNormWeightsTensor) && !hasTensor(context, kCellLayerNormWeightsTensor) && !hasTensor(context, kOutputLayerNormWeightsTensor)); NN_RET_CHECK(layerNormWeightsAllOrNone); } Shape outputShape = context->getInputShape(kInputTensor); outputShape.dimensions[2] = outputSize; return context->setOutputShape(kOutputTensor, outputShape); } bool execute(IOperationExecutionContext* context) { const auto outputStateSize = getNumberOfElements(context->getInputShape(kOutputStateInTensor)); const auto cellStateSize = getNumberOfElements(context->getInputShape(kCellStateInTensor)); const bool use_cifg = !hasTensor(context, kInputToInputWeightsTensor); const auto scratchSize = use_cifg ? 
3 * cellStateSize : 4 * cellStateSize; const OperandType inputType = context->getInputType(kInputTensor); switch (inputType) { case OperandType::TENSOR_FLOAT32: { std::vector<float> outputStateOut(outputStateSize); std::vector<float> cellStateOut(cellStateSize); std::vector<float> scratchBuffer(scratchSize); LSTMCell::LSTMEvalFloat32( getLSTMParams<float>(context), context->getInputBuffer<float>(kInputTensor), context->getInputShape(kInputTensor), context->getInputBuffer<float>(kInputToInputWeightsTensor), context->getInputBuffer<float>(kInputToForgetWeightsTensor), context->getInputBuffer<float>(kInputToCellWeightsTensor), context->getInputBuffer<float>(kInputToOutputWeightsTensor), context->getInputShape(kInputToOutputWeightsTensor), context->getInputBuffer<float>(kRecurrentToInputWeightsTensor), context->getInputBuffer<float>(kRecurrentToForgetWeightsTensor), context->getInputBuffer<float>(kRecurrentToCellWeightsTensor), context->getInputBuffer<float>(kRecurrentToOutputWeightsTensor), context->getInputShape(kRecurrentToOutputWeightsTensor), context->getInputBuffer<float>(kCellToInputWeightsTensor), context->getInputBuffer<float>(kCellToForgetWeightsTensor), context->getInputBuffer<float>(kCellToOutputWeightsTensor), /*aux_input_buffer=*/nullptr, /*aux_input_to_input_weights_buffer=*/nullptr, /*aux_input_to_forget_weights_buffer=*/nullptr, /*aux_input_to_cell_weights_buffer=*/nullptr, /*aux_input_to_output_weights_buffer=*/nullptr, context->getInputBuffer<float>(kInputGateBiasTensor), context->getInputBuffer<float>(kForgetGateBiasTensor), context->getInputBuffer<float>(kCellGateBiasTensor), context->getInputBuffer<float>(kOutputGateBiasTensor), context->getInputBuffer<float>(kProjectionWeightsTensor), context->getInputBuffer<float>(kProjectionBiasTensor), context->getInputBuffer<float>(kOutputStateInTensor), context->getInputBuffer<float>(kCellStateInTensor), context->getInputBuffer<float>(kInputLayerNormWeightsTensor), 
context->getInputBuffer<float>(kForgetLayerNormWeightsTensor), context->getInputBuffer<float>(kCellLayerNormWeightsTensor), context->getInputBuffer<float>(kOutputLayerNormWeightsTensor), outputStateOut.data(), cellStateOut.data(), context->getOutputBuffer<float>(kOutputTensor), scratchBuffer.data(), isTimeMajor(context)); } break; case OperandType::TENSOR_FLOAT16: { std::vector<_Float16> outputStateOut(outputStateSize); std::vector<_Float16> cellStateOut(cellStateSize); std::vector<_Float16> scratchBuffer(scratchSize); LSTMCell::LSTMEvalFloat16( getLSTMParams<_Float16>(context), context->getInputBuffer<_Float16>(kInputTensor), context->getInputShape(kInputTensor), context->getInputBuffer<_Float16>(kInputToInputWeightsTensor), context->getInputBuffer<_Float16>(kInputToForgetWeightsTensor), context->getInputBuffer<_Float16>(kInputToCellWeightsTensor), context->getInputBuffer<_Float16>(kInputToOutputWeightsTensor), context->getInputShape(kInputToOutputWeightsTensor), context->getInputBuffer<_Float16>(kRecurrentToInputWeightsTensor), context->getInputBuffer<_Float16>(kRecurrentToForgetWeightsTensor), context->getInputBuffer<_Float16>(kRecurrentToCellWeightsTensor), context->getInputBuffer<_Float16>(kRecurrentToOutputWeightsTensor), context->getInputShape(kRecurrentToOutputWeightsTensor), context->getInputBuffer<_Float16>(kCellToInputWeightsTensor), context->getInputBuffer<_Float16>(kCellToForgetWeightsTensor), context->getInputBuffer<_Float16>(kCellToOutputWeightsTensor), /*aux_input_buffer=*/nullptr, /*aux_input_to_input_weights_buffer=*/nullptr, /*aux_input_to_forget_weights_buffer=*/nullptr, /*aux_input_to_cell_weights_buffer=*/nullptr, /*aux_input_to_output_weights_buffer=*/nullptr, context->getInputBuffer<_Float16>(kInputGateBiasTensor), context->getInputBuffer<_Float16>(kForgetGateBiasTensor), context->getInputBuffer<_Float16>(kCellGateBiasTensor), context->getInputBuffer<_Float16>(kOutputGateBiasTensor), 
context->getInputBuffer<_Float16>(kProjectionWeightsTensor), context->getInputBuffer<_Float16>(kProjectionBiasTensor), context->getInputBuffer<_Float16>(kOutputStateInTensor), context->getInputBuffer<_Float16>(kCellStateInTensor), context->getInputBuffer<_Float16>(kInputLayerNormWeightsTensor), context->getInputBuffer<_Float16>(kForgetLayerNormWeightsTensor), context->getInputBuffer<_Float16>(kCellLayerNormWeightsTensor), context->getInputBuffer<_Float16>(kOutputLayerNormWeightsTensor), outputStateOut.data(), cellStateOut.data(), context->getOutputBuffer<_Float16>(kOutputTensor), scratchBuffer.data(), isTimeMajor(context)); } break; default: { LOG(ERROR) << "Unsupported data type: " << static_cast<int>(inputType); return false; } } return true; } } // namespace unidirectional_sequence_lstm NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_LSTM, "UNIDIRECTIONAL_SEQUENCE_LSTM", unidirectional_sequence_lstm::validate, unidirectional_sequence_lstm::prepare, unidirectional_sequence_lstm::execute, .allowOmittedOperand = true); } // namespace nn } // namespace android