/*
* Copyright (C) 2019 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define LOG_TAG "Operations"
#include "HalInterfaces.h"
#include "IndexedShapeWrapper.h"
#include "LSTM.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
namespace android {
namespace nn {
namespace unidirectional_sequence_lstm {
// Operand indices for UNIDIRECTIONAL_SEQUENCE_LSTM.
//
// Notation used below: n_cell = number of LSTM cells, n_input = input feature size,
// n_output = output/state size, n_batch = batch size. Operands marked "optional" may be
// supplied as omitted operands.

// ---------------------------------------------------------------------------------------------
// Inputs
// ---------------------------------------------------------------------------------------------
constexpr uint32_t kNumInputs = 28;

// Rank-3 input sequence. Layout depends on kTimeMajorParam:
//   time-major:  {max_time, n_batch, n_input}
//   batch-major: {n_batch, max_time, n_input}
constexpr uint32_t kInputTensor = 0;

// Input-to-gate weights, each {n_cell, n_input}. The input-gate one is optional (absent => CIFG).
constexpr uint32_t kInputToInputWeightsTensor = 1;   // optional
constexpr uint32_t kInputToForgetWeightsTensor = 2;
constexpr uint32_t kInputToCellWeightsTensor = 3;
constexpr uint32_t kInputToOutputWeightsTensor = 4;

// Recurrent (previous output)-to-gate weights, each {n_cell, n_output}.
constexpr uint32_t kRecurrentToInputWeightsTensor = 5;   // optional
constexpr uint32_t kRecurrentToForgetWeightsTensor = 6;
constexpr uint32_t kRecurrentToCellWeightsTensor = 7;
constexpr uint32_t kRecurrentToOutputWeightsTensor = 8;

// Peephole weights, each {n_cell}, i.e. the diagonal of a diagonal matrix.
constexpr uint32_t kCellToInputWeightsTensor = 9;    // optional
constexpr uint32_t kCellToForgetWeightsTensor = 10;  // optional
constexpr uint32_t kCellToOutputWeightsTensor = 11;  // optional

// Per-gate biases, each {n_cell}.
constexpr uint32_t kInputGateBiasTensor = 12;  // optional (must be absent under CIFG)
constexpr uint32_t kForgetGateBiasTensor = 13;
constexpr uint32_t kCellGateBiasTensor = 14;
constexpr uint32_t kOutputGateBiasTensor = 15;

// Projection layer: weights {n_output, n_cell} and bias {n_output}.
constexpr uint32_t kProjectionWeightsTensor = 16;  // optional
constexpr uint32_t kProjectionBiasTensor = 17;     // optional

// Recurrent state carried in from the previous step.
constexpr uint32_t kOutputStateInTensor = 18;  // {n_batch, n_output}
constexpr uint32_t kCellStateInTensor = 19;    // {n_batch, n_cell}

// Scalar parameters.
constexpr uint32_t kActivationParam = 20;  // INT32, read as TfLiteFusedActivation
constexpr uint32_t kCellClipParam = 21;    // float scalar of the input's precision
constexpr uint32_t kProjClipParam = 22;    // float scalar of the input's precision
constexpr uint32_t kTimeMajorParam = 23;   // BOOL, selects the kInputTensor layout

// Layer-normalization weights, each {n_cell}, i.e. the diagonal of a diagonal matrix.
constexpr uint32_t kInputLayerNormWeightsTensor = 24;   // optional
constexpr uint32_t kForgetLayerNormWeightsTensor = 25;  // optional
constexpr uint32_t kCellLayerNormWeightsTensor = 26;    // optional
constexpr uint32_t kOutputLayerNormWeightsTensor = 27;  // optional

// ---------------------------------------------------------------------------------------------
// Outputs
// ---------------------------------------------------------------------------------------------
constexpr uint32_t kNumOutputs = 1;
// Output sequence: same shape as kInputTensor with the last dimension replaced by n_output.
constexpr uint32_t kOutputTensor = 0;
namespace {

// Reports whether the input operand at index `tensor` was supplied, i.e. the execution
// context exposes a non-null buffer for it. Used to detect optional (omittable) operands.
inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
    const auto* const buffer = context->getInputBuffer(tensor);
    return nullptr != buffer;
}

// Reads the BOOL time-major flag (operand kTimeMajorParam).
inline bool isTimeMajor(IOperationExecutionContext* context) {
    const bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    return timeMajor;
}

// Gathers the scalar parameters and optional-feature flags into an LSTMParams.
//
// T is the scalar element type of the clip operands (float or _Float16). Both clip values
// are widened to float. Each optional feature is detected by the presence of one
// representative tensor. prepare() checks that the optional groups are all-or-none.
template <typename T>
inline LSTMParams getLSTMParams(IOperationExecutionContext* context) {
    LSTMParams lstm;

    // Scalar operands.
    const int32_t fusedActivation = context->getInputValue<int32_t>(kActivationParam);
    lstm.activation = static_cast<TfLiteFusedActivation>(fusedActivation);
    const T cellClip = context->getInputValue<T>(kCellClipParam);
    const T projClip = context->getInputValue<T>(kProjClipParam);
    lstm.cell_clip = static_cast<float>(cellClip);
    lstm.proj_clip = static_cast<float>(projClip);

    // Optional feature groups.
    const bool hasInputGateWeights = hasTensor(context, kInputToInputWeightsTensor);
    lstm.use_cifg = !hasInputGateWeights;
    lstm.use_peephole = hasTensor(context, kCellToOutputWeightsTensor);
    lstm.use_layer_norm = hasTensor(context, kOutputLayerNormWeightsTensor);
    lstm.use_projection_weight = hasTensor(context, kProjectionWeightsTensor);
    lstm.use_projection_bias = hasTensor(context, kProjectionBiasTensor);

    return lstm;
}

}  // namespace
bool validate(const IOperationValidationContext* context) {
NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
const OperandType inputType = context->getInputType(kInputTensor);
std::vector<OperandType> inExpectedTypes;
std::vector<OperandType> outExpectedTypes;
if (inputType == OperandType::TENSOR_FLOAT32) {
inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
OperandType::INT32, OperandType::FLOAT32,
OperandType::FLOAT32, OperandType::BOOL,
OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32};
outExpectedTypes = {OperandType::TENSOR_FLOAT32};
} else if (inputType == OperandType::TENSOR_FLOAT16) {
inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
OperandType::INT32, OperandType::FLOAT16,
OperandType::FLOAT16, OperandType::BOOL,
OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16};
outExpectedTypes = {OperandType::TENSOR_FLOAT16};
} else {
NN_RET_CHECK_FAIL()
<< "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_LSTM op: "
<< toString(inputType);
}
NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
return validateHalVersion(context, HalVersion::V1_2);
}
// Validates operand shapes for UNIDIRECTIONAL_SEQUENCE_LSTM and sets the output shape.
//
// Problem dimensions are derived from the operands themselves:
//   batchSize  - input axis 1 when time-major, axis 0 otherwise (the input must be rank 3)
//   inputSize  - last input axis
//   numCells   - dim 0 of the input-to-output weights
//   outputSize - dim 1 of the recurrent-to-output weights
// Every supplied weight, bias, state and layer-norm tensor is checked against these. The
// optional groups (CIFG input gate, peephole, layer norm) must be supplied all-or-none.
// Without projection weights, outputSize must equal numCells.
// The output tensor keeps the input's shape with its last dimension replaced by outputSize.
//
// Returns false (logging via NN_RET_CHECK*) on the first violated constraint.
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted.
    const std::vector<uint32_t> requiredInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kTimeMajorParam,
    };
    for (const uint32_t requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 3) << "Invalid input tensor rank: " << inputRank;
    // Only the batch axis moves with the layout flag. max_time (the other leading axis) does
    // not constrain any other operand, so it is not read here.
    const uint32_t batchSize = getSizeOfDimension(inputShape, isTimeMajor(context) ? 1 : 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, inputRank - 1);

    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numCells = getSizeOfDimension(inputToOutputShape, 0);

    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numCells);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numCells);
    }
    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numCells);
    }
    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numCells);
    }

    // Making sure the peephole weights are there all or none. Under CIFG there is no input
    // gate, so the cell-to-input peephole is not required for the "all" case.
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numCells);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numCells);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numCells);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numCells);

    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numCells);
    } else {
        // Without a projection layer nothing maps the numCells-wide gated cell output to
        // outputSize. The recurrent output state must therefore be exactly numCells wide.
        // Otherwise the kernel would be handed mismatched state sizes.
        NN_RET_CHECK_EQ(outputSize, numCells)
                << "Output size must equal the number of cells when projection weights are "
                   "absent";
    }
    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    const Shape outputStateShape = context->getInputShape(kOutputStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kCellStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numCells);

    if (hasTensor(context, kInputLayerNormWeightsTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numCells);
    }
    if (hasTensor(context, kForgetLayerNormWeightsTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numCells);
    }
    if (hasTensor(context, kCellLayerNormWeightsTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numCells);
    }
    if (hasTensor(context, kOutputLayerNormWeightsTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numCells);
    }

    // Layer-norm weights are all-or-none. Under CIFG the input-gate weights must be absent,
    // and the remaining three are all-or-none.
    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormWeightsTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg =
                (hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone =
                (hasTensor(context, kInputLayerNormWeightsTensor) &&
                 hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kInputLayerNormWeightsTensor) &&
                 !hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

    // Output: same layout as the input, last axis resized to outputSize.
    Shape outputShape = context->getInputShape(kInputTensor);
    outputShape.dimensions[2] = outputSize;
    return context->setOutputShape(kOutputTensor, outputShape);
}
// Runs UNIDIRECTIONAL_SEQUENCE_LSTM over the whole input sequence.
//
// The LSTMCell kernels write final output-state and cell-state buffers and use a scratch
// area. This operation only exposes the sequence output (kOutputTensor), so those buffers
// are local temporaries. They are sized from the corresponding input state tensors and
// discarded afterwards. The scratch area is 3x the cell-state element count under CIFG and
// 4x otherwise (presumably one slab per computed gate; the input gate is skipped under CIFG).
//
// Returns false for an unsupported input type or if the LSTM kernel reports failure.
bool execute(IOperationExecutionContext* context) {
    const auto outputStateSize = getNumberOfElements(context->getInputShape(kOutputStateInTensor));
    const auto cellStateSize = getNumberOfElements(context->getInputShape(kCellStateInTensor));
    const bool use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    const auto scratchSize = use_cifg ? 3 * cellStateSize : 4 * cellStateSize;
    const bool timeMajor = isTimeMajor(context);
    const OperandType inputType = context->getInputType(kInputTensor);
    switch (inputType) {
        case OperandType::TENSOR_FLOAT32: {
            std::vector<float> outputStateOut(outputStateSize);
            std::vector<float> cellStateOut(cellStateSize);
            std::vector<float> scratchBuffer(scratchSize);
            const bool ok = LSTMCell::LSTMEvalFloat32(
                    getLSTMParams<float>(context), context->getInputBuffer<float>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<float>(kInputToInputWeightsTensor),
                    context->getInputBuffer<float>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<float>(kInputToCellWeightsTensor),
                    context->getInputBuffer<float>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<float>(kCellToInputWeightsTensor),
                    context->getInputBuffer<float>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<float>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<float>(kInputGateBiasTensor),
                    context->getInputBuffer<float>(kForgetGateBiasTensor),
                    context->getInputBuffer<float>(kCellGateBiasTensor),
                    context->getInputBuffer<float>(kOutputGateBiasTensor),
                    context->getInputBuffer<float>(kProjectionWeightsTensor),
                    context->getInputBuffer<float>(kProjectionBiasTensor),
                    context->getInputBuffer<float>(kOutputStateInTensor),
                    context->getInputBuffer<float>(kCellStateInTensor),
                    context->getInputBuffer<float>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kOutputLayerNormWeightsTensor),
                    outputStateOut.data(), cellStateOut.data(),
                    context->getOutputBuffer<float>(kOutputTensor), scratchBuffer.data(),
                    timeMajor);
            // Propagate kernel failure instead of silently reporting success.
            if (!ok) {
                LOG(ERROR) << "UNIDIRECTIONAL_SEQUENCE_LSTM float32 evaluation failed";
                return false;
            }
        } break;
        case OperandType::TENSOR_FLOAT16: {
            std::vector<_Float16> outputStateOut(outputStateSize);
            std::vector<_Float16> cellStateOut(cellStateSize);
            std::vector<_Float16> scratchBuffer(scratchSize);
            const bool ok = LSTMCell::LSTMEvalFloat16(
                    getLSTMParams<_Float16>(context),
                    context->getInputBuffer<_Float16>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<_Float16>(kInputToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<_Float16>(kInputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kForgetGateBiasTensor),
                    context->getInputBuffer<_Float16>(kCellGateBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kProjectionWeightsTensor),
                    context->getInputBuffer<_Float16>(kProjectionBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputStateInTensor),
                    context->getInputBuffer<_Float16>(kCellStateInTensor),
                    context->getInputBuffer<_Float16>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kOutputLayerNormWeightsTensor),
                    outputStateOut.data(), cellStateOut.data(),
                    context->getOutputBuffer<_Float16>(kOutputTensor), scratchBuffer.data(),
                    timeMajor);
            // Propagate kernel failure instead of silently reporting success.
            if (!ok) {
                LOG(ERROR) << "UNIDIRECTIONAL_SEQUENCE_LSTM float16 evaluation failed";
                return false;
            }
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(inputType);
            return false;
        }
    }
    return true;
}
} // namespace unidirectional_sequence_lstm
// Registers UNIDIRECTIONAL_SEQUENCE_LSTM with its validate/prepare/execute entry points via
// NN_REGISTER_OPERATION (OperationResolver.h). allowOmittedOperand is enabled because several
// inputs are optional (see the "optional" operand indices) and may be passed as omitted operands.
NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_LSTM, "UNIDIRECTIONAL_SEQUENCE_LSTM",
unidirectional_sequence_lstm::validate, unidirectional_sequence_lstm::prepare,
unidirectional_sequence_lstm::execute, .allowOmittedOperand = true);
} // namespace nn
} // namespace android