/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FRAMEWORKS_ML_NN_LSTMCELL_H
#define FRAMEWORKS_ML_NN_LSTMCELL_H
#include "ActivationFunctor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>
// Forward declaration of android::hardware::neuralnetworks::V1_1::Operation.
// This header only refers to the type by const reference, so the full
// definition is not needed here.
namespace android { namespace hardware { namespace neuralnetworks { namespace V1_1 {
struct Operation;
}}}}  // namespace android::hardware::neuralnetworks::V1_1
namespace android {
namespace nn {
// Scalar settings of an LSTM cell; LSTMCell stores one instance in params_.
// NOTE(review): presumably populated from the kActivationParam,
// kCellClipParam and kProjClipParam operands of LSTMCell - confirm in the
// implementation file.
//
// Every member carries a default initializer so that a default-constructed
// LSTMParams never exposes indeterminate values.
struct LSTMParams {
  // Fused activation selector; value-initialized (zero) by default.
  TfLiteFusedActivation activation_{};
  // Cell clipping value; defaults to 0.0f.
  float cell_clip_ = 0.0f;
  // Projection clipping value; defaults to 0.0f.
  float proj_clip_ = 0.0f;
};
// Forward declarations. This header only uses these types through pointers,
// references, and as the element type of a std::vector reference parameter,
// so their full definitions are not required here.
struct RunTimeOperandInfo;
struct Shape;
// Declaration of the LSTM cell operation: the operand index constants, the
// static Prepare() step and the Eval() step. Only declarations live in this
// header; the member functions are defined elsewhere.
//
// All pointer members default to nullptr (and params_ is value-initialized)
// so that a member the constructor does not assign holds a well-defined
// value instead of an indeterminate one.
class LSTMCell {
 public:
  // Builds a cell from an operation and the operand table.
  // NOTE(review): presumably resolves the operation's inputs/outputs into the
  // pointer members below and reads the scalar parameters into params_ -
  // confirm in the implementation file.
  LSTMCell(const android::hardware::neuralnetworks::V1_1::Operation &operation,
           std::vector<RunTimeOperandInfo> &operands);

  // Static preparation step taking the operation, the operand table and one
  // Shape out-pointer per output (scratch, output state, cell state, output).
  // NOTE(review): presumably validates the operands and fills in the four
  // shapes, returning true on success - confirm in the implementation file.
  static bool Prepare(const android::hardware::neuralnetworks::V1_1::Operation &operation,
                      std::vector<RunTimeOperandInfo> &operands,
                      Shape *scratchShape,
                      Shape *outputStateShape,
                      Shape *cellStateShape,
                      Shape *outputShape);

  // Evaluation step; returns a bool.
  // NOTE(review): presumably runs the cell on the stored operands and returns
  // true on success - confirm in the implementation file.
  bool Eval();

  // ---- Input operand indices ----

  // Input Tensors of size {n_batch, n_input}
  static constexpr int kInputTensor = 0;

  // Input weight tensors of size: {n_cell, n_input}
  static constexpr int kInputToInputWeightsTensor = 1;  // Optional
  static constexpr int kInputToForgetWeightsTensor = 2;
  static constexpr int kInputToCellWeightsTensor = 3;
  static constexpr int kInputToOutputWeightsTensor = 4;

  // Recurrent weight tensors of size {n_cell, n_output}
  static constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
  static constexpr int kRecurrentToForgetWeightsTensor = 6;
  static constexpr int kRecurrentToCellWeightsTensor = 7;
  static constexpr int kRecurrentToOutputWeightsTensor = 8;

  // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
  static constexpr int kCellToInputWeightsTensor = 9;    // Optional
  static constexpr int kCellToForgetWeightsTensor = 10;  // Optional
  static constexpr int kCellToOutputWeightsTensor = 11;  // Optional

  // Gates bias tensors of size {n_cell}
  static constexpr int kInputGateBiasTensor = 12;  // Optional
  static constexpr int kForgetGateBiasTensor = 13;
  static constexpr int kCellGateBiasTensor = 14;
  static constexpr int kOutputGateBiasTensor = 15;

  // Projection weight tensor of size {n_output, n_cell}
  static constexpr int kProjectionWeightsTensor = 16;  // Optional
  // Projection bias tensor of size {n_output}
  static constexpr int kProjectionBiasTensor = 17;  // Optional

  // Incoming state tensors.
  static constexpr int kOutputStateInTensor = 18;
  static constexpr int kCellStateInTensor = 19;

  // Parameter operands (names mirror the fields of LSTMParams).
  static constexpr int kActivationParam = 20;
  static constexpr int kCellClipParam = 21;
  static constexpr int kProjClipParam = 22;

  // ---- Output operand indices (a separate index space, starting at 0) ----

  // Output tensors.
  static constexpr int kScratchBufferTensor = 0;
  static constexpr int kOutputStateOutTensor = 1;
  static constexpr int kCellStateOutTensor = 2;
  static constexpr int kOutputTensor = 3;

 private:
  // Helper taking the operation, the operand table and the three sizes
  // n_input, n_output and n_cell.
  // NOTE(review): presumably checks each input tensor's dimensions against
  // the sizes documented on the index constants above - confirm.
  static bool CheckInputTensorDimensions(
      const android::hardware::neuralnetworks::V1_1::Operation &operation,
      std::vector<RunTimeOperandInfo> &operands, uint32_t n_input,
      uint32_t n_output, uint32_t n_cell);

  // Scalar settings of the cell.
  LSTMParams params_{};

  // Read-only operands; names correspond to the input index constants above.
  // Declaration order is deliberately left unchanged because it determines
  // member initialization order.
  const RunTimeOperandInfo *input_ = nullptr;

  const RunTimeOperandInfo *input_to_input_weights_ = nullptr;
  const RunTimeOperandInfo *input_to_forget_weights_ = nullptr;
  const RunTimeOperandInfo *input_to_cell_weights_ = nullptr;
  const RunTimeOperandInfo *input_to_output_weights_ = nullptr;

  const RunTimeOperandInfo *recurrent_to_input_weights_ = nullptr;
  const RunTimeOperandInfo *recurrent_to_forget_weights_ = nullptr;
  const RunTimeOperandInfo *recurrent_to_cell_weights_ = nullptr;
  const RunTimeOperandInfo *recurrent_to_output_weights_ = nullptr;

  const RunTimeOperandInfo *cell_to_input_weights_ = nullptr;
  const RunTimeOperandInfo *cell_to_forget_weights_ = nullptr;
  const RunTimeOperandInfo *cell_to_output_weights_ = nullptr;

  const RunTimeOperandInfo *input_gate_bias_ = nullptr;
  const RunTimeOperandInfo *forget_gate_bias_ = nullptr;
  const RunTimeOperandInfo *cell_bias_ = nullptr;
  const RunTimeOperandInfo *output_gate_bias_ = nullptr;

  const RunTimeOperandInfo *projection_weights_ = nullptr;
  const RunTimeOperandInfo *projection_bias_ = nullptr;

  const RunTimeOperandInfo *output_state_in_ = nullptr;
  const RunTimeOperandInfo *cell_state_in_ = nullptr;

  // Mutable operands; names correspond to the output index constants above.
  RunTimeOperandInfo *output_state_out_ = nullptr;
  RunTimeOperandInfo *cell_state_out_ = nullptr;
  RunTimeOperandInfo *output_ = nullptr;

  RunTimeOperandInfo *scratch_buffer_ = nullptr;
};
} // namespace nn
} // namespace android
#endif // FRAMEWORKS_ML_NN_LSTMCELL_H