#ifndef FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H
#define FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H
#include <cstdint>
#include <vector>

#include "HalOperation.h"
#include "OperationsUtils.h"
namespace android {
namespace nn {

struct RunTimeOperandInfo;

// Kernel for one step of a quantized LSTM cell.
//
// The interface suggests this usage: call the static prepare() to validate the
// operation and compute the output shapes, then construct a cell from the same
// operation/operands and call eval().
// NOTE(review): confirm this call order against the caller.
//
// Operand positions are given by the k*Tensor constants below. The cell keeps raw
// pointers to RunTimeOperandInfo objects, presumably elements of the `operands`
// vector passed to the constructor (confirm in the .cpp). If so, that vector must
// outlive the cell and must not be resized while the cell is in use.
class QuantizedLSTMCell {
   public:
    // Resolves the operands referenced by `operation` and stores pointers to them
    // in the member fields below.
    QuantizedLSTMCell(const android::hardware::neuralnetworks::V1_2::Operation& operation,
                      std::vector<RunTimeOperandInfo>& operands);

    // Checks the operation's operands and reports the shapes of the two outputs
    // through `cellStateShape` and `outputShape`. Returns true on success.
    // NOTE(review): presumably returns false for invalid operand shapes/types —
    // confirm in the implementation.
    static bool prepare(const android::hardware::neuralnetworks::V1_2::Operation& operation,
                        std::vector<RunTimeOperandInfo>& operands, Shape* cellStateShape,
                        Shape* outputShape);

    // Runs the cell. Returns true on success. The only non-const operands held by
    // the cell are cellStateOut_ and output_, so those are presumably what eval()
    // writes.
    bool eval();

    // Inputs:
    static constexpr int kInputTensor = 0;
    // Input weight tensors of size: {n_cell, n_input}
    static constexpr int kInputToInputWeightsTensor = 1;
    static constexpr int kInputToForgetWeightsTensor = 2;
    static constexpr int kInputToCellWeightsTensor = 3;
    static constexpr int kInputToOutputWeightsTensor = 4;
    // Recurrent weight tensors of size {n_cell, n_output}
    static constexpr int kRecurrentToInputWeightsTensor = 5;
    static constexpr int kRecurrentToForgetWeightsTensor = 6;
    static constexpr int kRecurrentToCellWeightsTensor = 7;
    static constexpr int kRecurrentToOutputWeightsTensor = 8;
    // Gates bias tensors of size {n_cell}
    static constexpr int kInputGateBiasTensor = 9;
    static constexpr int kForgetGateBiasTensor = 10;
    static constexpr int kCellGateBiasTensor = 11;
    static constexpr int kOutputGateBiasTensor = 12;
    // State carried over from the previous step.
    static constexpr int kPrevCellStateTensor = 13;
    static constexpr int kPrevOutputTensor = 14;

    // Outputs:
    static constexpr int kCellStateOutTensor = 0;
    static constexpr int kOutputTensor = 1;

   private:
    // Non-owning pointers to the resolved operands. They default to nullptr so a
    // cell whose constructor missed a field holds a well-defined null pointer
    // instead of an indeterminate value.

    // Read-only inputs.
    const RunTimeOperandInfo* input_ = nullptr;
    const RunTimeOperandInfo* inputToInputWeights_ = nullptr;
    const RunTimeOperandInfo* inputToForgetWeights_ = nullptr;
    const RunTimeOperandInfo* inputToCellWeights_ = nullptr;
    const RunTimeOperandInfo* inputToOutputWeights_ = nullptr;
    const RunTimeOperandInfo* recurrentToInputWeights_ = nullptr;
    const RunTimeOperandInfo* recurrentToForgetWeights_ = nullptr;
    const RunTimeOperandInfo* recurrentToCellWeights_ = nullptr;
    const RunTimeOperandInfo* recurrentToOutputWeights_ = nullptr;
    const RunTimeOperandInfo* inputGateBias_ = nullptr;
    const RunTimeOperandInfo* forgetGateBias_ = nullptr;
    const RunTimeOperandInfo* cellGateBias_ = nullptr;
    const RunTimeOperandInfo* outputGateBias_ = nullptr;
    const RunTimeOperandInfo* prevCellState_ = nullptr;
    const RunTimeOperandInfo* prevOutput_ = nullptr;

    // Mutable outputs.
    RunTimeOperandInfo* cellStateOut_ = nullptr;
    RunTimeOperandInfo* output_ = nullptr;

    // Fills the caller-provided `weights` buffer, whose dimensions are given by
    // `weightsDims`. NOTE(review): presumably packs the per-gate input/recurrent
    // weight tensors into one matrix — confirm in the .cpp.
    void concatenateWeights(const std::vector<uint32_t>& weightsDims, uint8_t* weights);

    // Fills the caller-provided `bias` buffer. NOTE(review): presumably packs the
    // four per-gate bias tensors, with `outputSize` the per-gate length — confirm
    // in the .cpp.
    void concatenateBiases(uint32_t outputSize, int32_t* bias);
};

}  // namespace nn
}  // namespace android
#endif // FRAMEWORKS_ML_NN_QUANTIZEDLSTM_H