#ifndef FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H #define FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H #include "HalOperation.h" #include "OperationsUtils.h" #include <vector> namespace android { namespace nn { struct RunTimeOperandInfo; class QuantizedLSTMCell { public: QuantizedLSTMCell(const android::hardware::neuralnetworks::V1_2::Operation& operation, std::vector<RunTimeOperandInfo>& operands); static bool prepare(const android::hardware::neuralnetworks::V1_2::Operation& operation, std::vector<RunTimeOperandInfo>& operands, Shape* cellStateShape, Shape* outputShape); bool eval(); // Inputs: static constexpr int kInputTensor = 0; // Input weight tensors of size: {n_cell, n_input} static constexpr int kInputToInputWeightsTensor = 1; static constexpr int kInputToForgetWeightsTensor = 2; static constexpr int kInputToCellWeightsTensor = 3; static constexpr int kInputToOutputWeightsTensor = 4; // Recurrent weight tensors of size {n_cell, n_output} static constexpr int kRecurrentToInputWeightsTensor = 5; static constexpr int kRecurrentToForgetWeightsTensor = 6; static constexpr int kRecurrentToCellWeightsTensor = 7; static constexpr int kRecurrentToOutputWeightsTensor = 8; // Gates bias tensors of size {n_cell} static constexpr int kInputGateBiasTensor = 9; static constexpr int kForgetGateBiasTensor = 10; static constexpr int kCellGateBiasTensor = 11; static constexpr int kOutputGateBiasTensor = 12; static constexpr int kPrevCellStateTensor = 13; static constexpr int kPrevOutputTensor = 14; // Outputs: static constexpr int kCellStateOutTensor = 0; static constexpr int kOutputTensor = 1; private: const RunTimeOperandInfo* input_; const RunTimeOperandInfo* inputToInputWeights_; const RunTimeOperandInfo* inputToForgetWeights_; const RunTimeOperandInfo* inputToCellWeights_; const RunTimeOperandInfo* inputToOutputWeights_; const RunTimeOperandInfo* recurrentToInputWeights_; const RunTimeOperandInfo* recurrentToForgetWeights_; const RunTimeOperandInfo* recurrentToCellWeights_; const RunTimeOperandInfo* recurrentToOutputWeights_; const RunTimeOperandInfo* inputGateBias_; const RunTimeOperandInfo* forgetGateBias_; const RunTimeOperandInfo* cellGateBias_; const RunTimeOperandInfo* outputGateBias_; const RunTimeOperandInfo* prevCellState_; const RunTimeOperandInfo* prevOutput_; RunTimeOperandInfo* cellStateOut_; RunTimeOperandInfo* output_; void concatenateWeights(const std::vector<uint32_t>& weightsDims, uint8_t* weights); void concatenateBiases(uint32_t outputSize, int32_t* bias); }; } // namespace nn } // namespace android #endif // FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H