/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "LSTM.h"
#include <android-base/logging.h>
#include "NeuralNetworksWrapper.h"
#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"
#include <sstream>
#include <string>
#include <vector>
namespace android {
namespace nn {
namespace wrapper {
using ::testing::Each;
using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Matcher;
namespace {
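// Builds one FloatNear matcher per expected value so that
// EXPECT_THAT(actual, ElementsAreArray(ArrayFloatNear(expected))) compares
// float vectors elementwise with an absolute tolerance.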
std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
float max_abs_error = 1.e-6) {
std::vector<Matcher<float>> matchers;
matchers.reserve(values.size());
for (const float& v : values) {
matchers.emplace_back(FloatNear(v, max_abs_error));
}
return matchers;
}
} // anonymous namespace
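// X-macro lists over the LSTM operand names. Expanding them with a short
// ACTION generates the operand declarations, member buffers, setters, and
// setInput/setOutput calls below from a single source of truth. The order
// must match the LSTMCell::k*Tensor operand indices declared in LSTM.h.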
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
ACTION(Input) \
ACTION(InputToInputWeights) \
ACTION(InputToForgetWeights) \
ACTION(InputToCellWeights) \
ACTION(InputToOutputWeights) \
ACTION(RecurrentToInputWeights) \
ACTION(RecurrentToForgetWeights) \
ACTION(RecurrentToCellWeights) \
ACTION(RecurrentToOutputWeights) \
ACTION(CellToInputWeights) \
ACTION(CellToForgetWeights) \
ACTION(CellToOutputWeights) \
ACTION(InputGateBias) \
ACTION(ForgetGateBias) \
ACTION(CellGateBias) \
ACTION(OutputGateBias) \
ACTION(ProjectionWeights) \
ACTION(ProjectionBias) \
ACTION(OutputStateIn) \
ACTION(CellStateIn)
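// The four layer normalization weight vectors that the layer-norm variant of
// ANEURALNETWORKS_LSTM appends after the scalar parameters.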
#define FOR_ALL_LAYER_NORM_WEIGHTS(ACTION) \
ACTION(InputLayerNormWeights) \
ACTION(ForgetLayerNormWeights) \
ACTION(CellLayerNormWeights) \
ACTION(OutputLayerNormWeights)
// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
ACTION(ScratchBuffer) \
ACTION(OutputStateOut) \
ACTION(CellStateOut) \
ACTION(Output)
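// Test harness that builds a single layer-norm LSTM cell as an NNAPI model,
// owns a host-side buffer for every operand, and runs one time step per
// Invoke() while carrying the recurrent state between invocations.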
class LayerNormLSTMOpModel {
public:
LayerNormLSTMOpModel(uint32_t n_batch, uint32_t n_input, uint32_t n_cell, uint32_t n_output,
bool use_cifg, bool use_peephole, bool use_projection_weights,
bool use_projection_bias, float cell_clip, float proj_clip,
const std::vector<std::vector<uint32_t>>& input_shapes0)
: n_input_(n_input),
n_output_(n_output),
use_cifg_(use_cifg),
use_peephole_(use_peephole),
use_projection_weights_(use_projection_weights),
use_projection_bias_(use_projection_bias),
activation_(ActivationFn::kActivationTanh),
cell_clip_(cell_clip),
proj_clip_(proj_clip) {
std::vector<uint32_t> inputs;
std::vector<std::vector<uint32_t>> input_shapes(input_shapes0);
auto it = input_shapes.begin();
// Input and weights
#define AddInput(X) \
CHECK(it != input_shapes.end()); \
OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
inputs.push_back(model_.addOperand(&X##OpndTy));
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput);
// Parameters
OperandType ActivationOpndTy(Type::INT32, {});
inputs.push_back(model_.addOperand(&ActivationOpndTy));
OperandType CellClipOpndTy(Type::FLOAT32, {});
inputs.push_back(model_.addOperand(&CellClipOpndTy));
OperandType ProjClipOpndTy(Type::FLOAT32, {});
inputs.push_back(model_.addOperand(&ProjClipOpndTy));
FOR_ALL_LAYER_NORM_WEIGHTS(AddInput);
#undef AddInput
// Output and other intermediate state, in FOR_ALL_OUTPUT_TENSORS order
std::vector<std::vector<uint32_t>> output_shapes{
{n_batch, n_cell * (use_cifg ? 3 : 4)}, // scratch buffer
{n_batch, n_output}, // output_state_out
{n_batch, n_cell}, // cell_state_out
{n_batch, n_output}, // output
};
std::vector<uint32_t> outputs;
auto it2 = output_shapes.begin();
#define AddOutput(X) \
CHECK(it2 != output_shapes.end()); \
OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
outputs.push_back(model_.addOperand(&X##OpndTy));
FOR_ALL_OUTPUT_TENSORS(AddOutput);
#undef AddOutput
model_.addOperation(ANEURALNETWORKS_LSTM, inputs, outputs);
model_.identifyInputsAndOutputs(inputs, outputs);
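// Pre-size the buffers that are filled incrementally (the input and the
// recurrent state) so that setInput() can always be handed their full extent.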
Input_.insert(Input_.end(), n_batch * n_input, 0.f);
OutputStateIn_.insert(OutputStateIn_.end(), n_batch * n_output, 0.f);
CellStateIn_.insert(CellStateIn_.end(), n_batch * n_cell, 0.f);
auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
uint32_t sz = 1;
for (uint32_t d : dims) {
sz *= d;
}
return sz;
};
it2 = output_shapes.begin();
#define ReserveOutput(X) X##_.insert(X##_.end(), multiAll(*it2++), 0.f);
FOR_ALL_OUTPUT_TENSORS(ReserveOutput);
#undef ReserveOutput
model_.finish();
}
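// Generates one setter per input/weight tensor. The setters append to the
// backing buffer, so each should be called at most once per tensor.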
#define DefineSetter(X) \
void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
FOR_ALL_LAYER_NORM_WEIGHTS(DefineSetter);
#undef DefineSetter
void ResetOutputState() {
std::fill(OutputStateIn_.begin(), OutputStateIn_.end(), 0.f);
std::fill(OutputStateOut_.begin(), OutputStateOut_.end(), 0.f);
}
void ResetCellState() {
std::fill(CellStateIn_.begin(), CellStateIn_.end(), 0.f);
std::fill(CellStateOut_.begin(), CellStateOut_.end(), 0.f);
}
void SetInput(int offset, const float* begin, const float* end) {
for (; begin != end; begin++, offset++) {
Input_[offset] = *begin;
}
}
uint32_t num_inputs() const { return n_input_; }
uint32_t num_outputs() const { return n_output_; }
const std::vector<float>& GetOutput() const { return Output_; }
void Invoke() {
ASSERT_TRUE(model_.isValid());
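// Carry the recurrent state forward: the previous step's output and cell
// state become this step's output_state_in and cell_state_in.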
OutputStateIn_.swap(OutputStateOut_);
CellStateIn_.swap(CellStateOut_);
Compilation compilation(&model_);
ASSERT_EQ(compilation.finish(), Result::NO_ERROR);
Execution execution(&compilation);
#define SetInputOrWeight(X) \
ASSERT_EQ( \
execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
Result::NO_ERROR);
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
FOR_ALL_LAYER_NORM_WEIGHTS(SetInputOrWeight);
#undef SetInputOrWeight
#define SetOutput(X) \
ASSERT_EQ( \
execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
Result::NO_ERROR);
FOR_ALL_OUTPUT_TENSORS(SetOutput);
#undef SetOutput
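// NNAPI convention: an omitted optional operand is signalled by passing a
// null buffer of length zero.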
if (use_cifg_) {
execution.setInput(LSTMCell::kInputToInputWeightsTensor, nullptr, 0);
execution.setInput(LSTMCell::kRecurrentToInputWeightsTensor, nullptr, 0);
}
if (use_peephole_) {
if (use_cifg_) {
execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
}
} else {
execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
execution.setInput(LSTMCell::kCellToForgetWeightsTensor, nullptr, 0);
execution.setInput(LSTMCell::kCellToOutputWeightsTensor, nullptr, 0);
}
if (use_projection_weights_) {
if (!use_projection_bias_) {
execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
}
} else {
execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0);
execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
}
ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, &activation_, sizeof(activation_)),
Result::NO_ERROR);
ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, &cell_clip_, sizeof(cell_clip_)),
Result::NO_ERROR);
ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, &proj_clip_, sizeof(proj_clip_)),
Result::NO_ERROR);
ASSERT_EQ(execution.compute(), Result::NO_ERROR);
}
private:
Model model_;
const uint32_t n_input_;
const uint32_t n_output_;
const bool use_cifg_;
const bool use_peephole_;
const bool use_projection_weights_;
const bool use_projection_bias_;
const int activation_;
const float cell_clip_;
const float proj_clip_;
#define DefineTensor(X) std::vector<float> X##_;
FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
FOR_ALL_LAYER_NORM_WEIGHTS(DefineTensor);
FOR_ALL_OUTPUT_TENSORS(DefineTensor);
#undef DefineTensor
};
TEST(LSTMOpTest, LayerNormNoCifgPeepholeProjectionNoClipping) {
const int n_batch = 2;
const int n_input = 5;
// n_cell and n_output would have to match without a projection layer; the
// projection used in this test lets them differ.
const int n_cell = 4;
const int n_output = 3;
LayerNormLSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
/*use_cifg=*/false, /*use_peephole=*/true,
/*use_projection_weights=*/true,
/*use_projection_bias=*/false,
/*cell_clip=*/0.0, /*proj_clip=*/0.0,
{
{n_batch, n_input}, // input tensor
{n_cell, n_input}, // input_to_input_weight tensor
{n_cell, n_input}, // input_to_forget_weight tensor
{n_cell, n_input}, // input_to_cell_weight tensor
{n_cell, n_input}, // input_to_output_weight tensor
{n_cell, n_output}, // recurrent_to_input_weight tensor
{n_cell, n_output}, // recurrent_to_forget_weight tensor
{n_cell, n_output}, // recurrent_to_cell_weight tensor
{n_cell, n_output}, // recurrent_to_output_weight tensor
{n_cell}, // cell_to_input_weight tensor
{n_cell}, // cell_to_forget_weight tensor
{n_cell}, // cell_to_output_weight tensor
{n_cell}, // input_gate_bias tensor
{n_cell}, // forget_gate_bias tensor
{n_cell}, // cell_bias tensor
{n_cell}, // output_gate_bias tensor
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
{n_batch, n_output}, // output_state_in tensor
{n_batch, n_cell}, // cell_state_in tensor
{n_cell}, // input_layer_norm_weights tensor
{n_cell}, // forget_layer_norm_weights tensor
{n_cell}, // cell_layer_norm_weights tensor
{n_cell}, // output_layer_norm_weights tensor
});
lstm.SetInputToInputWeights({0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5,
-0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1});
lstm.SetInputToForgetWeights({-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8,
-0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5});
lstm.SetInputToCellWeights({-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6});
lstm.SetInputToOutputWeights({-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4});
lstm.SetInputGateBias({0.03, 0.15, 0.22, 0.38});
lstm.SetForgetGateBias({0.1, -0.3, -0.2, 0.1});
lstm.SetCellGateBias({-0.05, 0.72, 0.25, 0.08});
lstm.SetOutputGateBias({0.05, -0.01, 0.2, 0.1});
lstm.SetRecurrentToInputWeights(
{-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6});
lstm.SetRecurrentToCellWeights(
{-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2});
lstm.SetRecurrentToForgetWeights(
{-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2});
lstm.SetRecurrentToOutputWeights(
{0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2});
lstm.SetCellToInputWeights({0.05, 0.1, 0.25, 0.15});
lstm.SetCellToForgetWeights({-0.02, -0.15, -0.25, -0.03});
lstm.SetCellToOutputWeights({0.1, -0.1, -0.5, 0.05});
lstm.SetProjectionWeights({-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2});
lstm.SetInputLayerNormWeights({0.1, 0.2, 0.3, 0.5});
lstm.SetForgetLayerNormWeights({0.2, 0.2, 0.4, 0.3});
lstm.SetCellLayerNormWeights({0.7, 0.2, 0.3, 0.8});
lstm.SetOutputLayerNormWeights({0.6, 0.2, 0.2, 0.5});
const std::vector<std::vector<float>> lstm_input = {
{ // Batch0: 3 (input_sequence_size) * 5 (n_input)
0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2
{ // Batch1: 3 (input_sequence_size) * 5 (n_input)
0.3, 0.2, 0.9, 0.8, 0.1, // seq 0
0.1, 0.5, 0.2, 0.4, 0.2, // seq 1
0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2
};
const std::vector<std::vector<float>> lstm_golden_output = {
{
// Batch0: 3 (input_sequence_size) * 3 (n_output)
0.0244077, 0.128027, -0.00170918, // seq 0
0.0137642, 0.140751, 0.0395835, // seq 1
-0.00459231, 0.155278, 0.0837377, // seq 2
},
{
// Batch1: 3 (input_sequence_size) * 3 (n_output)
-0.00692428, 0.0848741, 0.063445, // seq 0
-0.00403912, 0.139963, 0.072681, // seq 1
0.00752706, 0.161903, 0.0561371, // seq 2
}};
// Resetting cell_state and output_state
lstm.ResetCellState();
lstm.ResetOutputState();
const int input_sequence_size = lstm_input[0].size() / n_input;
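// Feed the sequence one time step at a time; each Invoke() runs a single
// LSTM step over both batches and is checked against the golden output.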
for (int i = 0; i < input_sequence_size; i++) {
for (int b = 0; b < n_batch; ++b) {
const float* batch_start = lstm_input[b].data() + i * n_input;
const float* batch_end = batch_start + n_input;
lstm.SetInput(b * n_input, batch_start, batch_end);
}
lstm.Invoke();
std::vector<float> expected;
for (int b = 0; b < n_batch; ++b) {
const float* golden_start = lstm_golden_output[b].data() + i * n_output;
const float* golden_end = golden_start + n_output;
expected.insert(expected.end(), golden_start, golden_end);
}
EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
}
}
} // namespace wrapper
} // namespace nn
} // namespace android