// C++ source file  |  429 lines  |  16.43 KB

/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "LSTM.h"

#include <android-base/logging.h>

#include "NeuralNetworksWrapper.h"
#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"

#include <sstream>
#include <string>
#include <vector>

namespace android {
namespace nn {
namespace wrapper {

using ::testing::Each;
using ::testing::FloatNear;
using ::testing::Matcher;

namespace {

std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
                                           float max_abs_error = 1.e-6) {
    std::vector<Matcher<float>> matchers;
    matchers.reserve(values.size());
    for (const float& v : values) {
        matchers.emplace_back(FloatNear(v, max_abs_error));
    }
    return matchers;
}

}  // anonymous namespace

// X-macro listing the name of every input, weight, bias and incoming-state tensor used by
// LayerNormLSTMOpModel. ACTION is a function-like macro that is expanded once per name.
//
// In this file the list is used to add the model operands, generate the Set<Name>() setters,
// declare the <Name>_ float buffers, and bind each buffer as an execution input.
//
// The order of the entries is the order in which the LayerNormLSTMOpModel constructor consumes
// its list of input shapes and adds operands, so shapes are matched to entries positionally.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Input)                                \
    ACTION(InputToInputWeights)                  \
    ACTION(InputToCellWeights)                   \
    ACTION(InputToForgetWeights)                 \
    ACTION(InputToOutputWeights)                 \
    ACTION(RecurrentToInputWeights)              \
    ACTION(RecurrentToCellWeights)               \
    ACTION(RecurrentToForgetWeights)             \
    ACTION(RecurrentToOutputWeights)             \
    ACTION(CellToInputWeights)                   \
    ACTION(CellToForgetWeights)                  \
    ACTION(CellToOutputWeights)                  \
    ACTION(InputGateBias)                        \
    ACTION(CellGateBias)                         \
    ACTION(ForgetGateBias)                       \
    ACTION(OutputGateBias)                       \
    ACTION(ProjectionWeights)                    \
    ACTION(ProjectionBias)                       \
    ACTION(OutputStateIn)                        \
    ACTION(CellStateIn)

// X-macro listing the names of the four layer-normalization weight tensors. ACTION is expanded
// once per name. The LayerNormLSTMOpModel constructor adds these operands after the scalar
// activation / cell-clip / projection-clip parameters, consuming further entries of the same
// input-shape list in this order.
#define FOR_ALL_LAYER_NORM_WEIGHTS(ACTION) \
    ACTION(InputLayerNormWeights)          \
    ACTION(ForgetLayerNormWeights)         \
    ACTION(CellLayerNormWeights)           \
    ACTION(OutputLayerNormWeights)

// X-macro listing the names of all output tensors, including the scratch buffer and the
// outgoing output/cell states. ACTION is expanded once per name. The order matches the order of
// the output shapes built in the LayerNormLSTMOpModel constructor, which are consumed
// positionally when the output operands are added and when the output buffers are sized.
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(ScratchBuffer)              \
    ACTION(OutputStateOut)             \
    ACTION(CellStateOut)               \
    ACTION(Output)

// Test helper wrapping a model that contains a single ANEURALNETWORKS_LSTM operation with
// layer-normalization weights. It owns one float buffer per tensor named in
// FOR_ALL_INPUT_AND_WEIGHT_TENSORS, FOR_ALL_LAYER_NORM_WEIGHTS and FOR_ALL_OUTPUT_TENSORS, and
// binds those buffers to a freshly created Execution on every Invoke().
class LayerNormLSTMOpModel {
   public:
    // Builds the operands and the LSTM operation, sizes the buffers, and finishes the model.
    //
    // n_batch, n_input, n_cell, n_output: sizes used to build the output operand shapes and to
    //     pre-size the input, state and output buffers with zeros.
    // use_cifg, use_peephole, use_projection_weights, use_projection_bias: select which
    //     optional inputs Invoke() rebinds as omitted (nullptr with length 0).
    // cell_clip, proj_clip: scalar parameters bound as inputs on every Invoke().
    // input_shapes0: one dimension vector per tensor operand, consumed in the order of
    //     FOR_ALL_INPUT_AND_WEIGHT_TENSORS followed by FOR_ALL_LAYER_NORM_WEIGHTS. The
    //     constructor CHECK-fails if the list runs out before every operand has a shape.
    LayerNormLSTMOpModel(uint32_t n_batch, uint32_t n_input, uint32_t n_cell, uint32_t n_output,
                         bool use_cifg, bool use_peephole, bool use_projection_weights,
                         bool use_projection_bias, float cell_clip, float proj_clip,
                         const std::vector<std::vector<uint32_t>>& input_shapes0)
        : n_input_(n_input),
          n_output_(n_output),
          use_cifg_(use_cifg),
          use_peephole_(use_peephole),
          use_projection_weights_(use_projection_weights),
          use_projection_bias_(use_projection_bias),
          activation_(ActivationFn::kActivationTanh),
          cell_clip_(cell_clip),
          proj_clip_(proj_clip) {
        std::vector<uint32_t> inputs;
        std::vector<std::vector<uint32_t>> input_shapes(input_shapes0);

        auto it = input_shapes.begin();

        // Input and weights: each expansion takes the next shape from input_shapes.
#define AddInput(X)                                     \
    CHECK(it != input_shapes.end());                    \
    OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
    inputs.push_back(model_.addOperand(&X##OpndTy));

        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput);

        // Scalar parameters: activation, cell clip and projection clip.
        OperandType ActivationOpndTy(Type::INT32, {});
        inputs.push_back(model_.addOperand(&ActivationOpndTy));
        OperandType CellClipOpndTy(Type::FLOAT32, {});
        inputs.push_back(model_.addOperand(&CellClipOpndTy));
        OperandType ProjClipOpndTy(Type::FLOAT32, {});
        inputs.push_back(model_.addOperand(&ProjClipOpndTy));

        // Layer-norm weights come after the scalar parameters and keep consuming input_shapes.
        FOR_ALL_LAYER_NORM_WEIGHTS(AddInput);

        // Undefine the helper that was actually defined above so it does not leak past the
        // constructor.
#undef AddInput

        // Output and other intermediate state, in FOR_ALL_OUTPUT_TENSORS order.
        std::vector<std::vector<uint32_t>> output_shapes{
                {n_batch, n_cell * (use_cifg ? 3 : 4)},
                {n_batch, n_output},
                {n_batch, n_cell},
                {n_batch, n_output},
        };
        std::vector<uint32_t> outputs;

        auto it2 = output_shapes.begin();

#define AddOutput(X)                                     \
    CHECK(it2 != output_shapes.end());                   \
    OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
    outputs.push_back(model_.addOperand(&X##OpndTy));

        FOR_ALL_OUTPUT_TENSORS(AddOutput);

#undef AddOutput

        model_.addOperation(ANEURALNETWORKS_LSTM, inputs, outputs);
        model_.identifyInputsAndOutputs(inputs, outputs);

        // Pre-size the input and incoming-state buffers with zeros.
        Input_.insert(Input_.end(), n_batch * n_input, 0.f);
        OutputStateIn_.insert(OutputStateIn_.end(), n_batch * n_output, 0.f);
        CellStateIn_.insert(CellStateIn_.end(), n_batch * n_cell, 0.f);

        // Number of elements in a tensor with the given dimensions.
        auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
            uint32_t sz = 1;
            for (uint32_t d : dims) {
                sz *= d;
            }
            return sz;
        };

        // Walk the output shapes again to zero-fill each output buffer to its full size.
        it2 = output_shapes.begin();

#define ReserveOutput(X) X##_.insert(X##_.end(), multiAll(*it2++), 0.f);

        FOR_ALL_OUTPUT_TENSORS(ReserveOutput);

#undef ReserveOutput

        model_.finish();
    }

    // Generates Set<Name>(values) for every input, weight and layer-norm tensor. Each setter
    // appends `values` to the end of the corresponding buffer; it does not clear it first.
#define DefineSetter(X) \
    void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
    FOR_ALL_LAYER_NORM_WEIGHTS(DefineSetter);

#undef DefineSetter

    // Zeroes both the incoming and outgoing output-state buffers.
    void ResetOutputState() {
        std::fill(OutputStateIn_.begin(), OutputStateIn_.end(), 0.f);
        std::fill(OutputStateOut_.begin(), OutputStateOut_.end(), 0.f);
    }

    // Zeroes both the incoming and outgoing cell-state buffers.
    void ResetCellState() {
        std::fill(CellStateIn_.begin(), CellStateIn_.end(), 0.f);
        std::fill(CellStateOut_.begin(), CellStateOut_.end(), 0.f);
    }

    // Copies [begin, end) into the input buffer starting at element `offset`. No bounds
    // checking is done; the caller must keep the written range inside the buffer.
    void SetInput(int offset, const float* begin, const float* end) {
        for (; begin != end; begin++, offset++) {
            Input_[offset] = *begin;
        }
    }

    uint32_t num_inputs() const { return n_input_; }
    uint32_t num_outputs() const { return n_output_; }

    // Returns the buffer bound as the Output tensor by the most recent Invoke().
    const std::vector<float>& GetOutput() const { return Output_; }

    // Runs the operation once. The state buffers written by the previous call are swapped in
    // as this call's incoming state, then a new Compilation and Execution are created, every
    // buffer is bound, and the execution is computed. Failures are reported through gtest
    // ASSERT_* macros, which make this function return early.
    void Invoke() {
        ASSERT_TRUE(model_.isValid());

        // Carry state across steps: the previous outputs become the current inputs.
        OutputStateIn_.swap(OutputStateOut_);
        CellStateIn_.swap(CellStateOut_);

        Compilation compilation(&model_);
        ASSERT_EQ(compilation.finish(), Result::NO_ERROR);
        Execution execution(&compilation);
#define SetInputOrWeight(X)                                                                       \
    ASSERT_EQ(                                                                                    \
            execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
            Result::NO_ERROR);

        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
        FOR_ALL_LAYER_NORM_WEIGHTS(SetInputOrWeight);

#undef SetInputOrWeight

#define SetOutput(X)                                                                               \
    ASSERT_EQ(                                                                                     \
            execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
            Result::NO_ERROR);

        FOR_ALL_OUTPUT_TENSORS(SetOutput);

#undef SetOutput

        // The blocks below rebind optional inputs that the configuration leaves out as
        // nullptr with length 0, overriding the bindings made above. The results of these
        // calls are not checked.
        if (use_cifg_) {
            execution.setInput(LSTMCell::kInputToInputWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kRecurrentToInputWeightsTensor, nullptr, 0);
        }

        if (use_peephole_) {
            if (use_cifg_) {
                execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
            }
        } else {
            execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kCellToForgetWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kCellToOutputWeightsTensor, nullptr, 0);
        }

        if (use_projection_weights_) {
            if (!use_projection_bias_) {
                execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
            }
        } else {
            execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0);
            execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
        }

        ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, &activation_, sizeof(activation_)),
                  Result::NO_ERROR);
        ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, &cell_clip_, sizeof(cell_clip_)),
                  Result::NO_ERROR);
        ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, &proj_clip_, sizeof(proj_clip_)),
                  Result::NO_ERROR);

        ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    }

   private:
    Model model_;
    const uint32_t n_input_;
    const uint32_t n_output_;

    const bool use_cifg_;
    const bool use_peephole_;
    const bool use_projection_weights_;
    const bool use_projection_bias_;

    // Scalar operation parameters bound as inputs on every Invoke().
    const int activation_;
    const float cell_clip_;
    const float proj_clip_;

    // One float buffer named <Name>_ per tensor in the three lists below.
#define DefineTensor(X) std::vector<float> X##_;

    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
    FOR_ALL_LAYER_NORM_WEIGHTS(DefineTensor);
    FOR_ALL_OUTPUT_TENSORS(DefineTensor);

#undef DefineTensor
};

// Runs a layer-norm LSTM (no CIFG, with peephole, with projection weights but no projection
// bias, no clipping) over a 3-step sequence for 2 batches and compares the output of every
// step against golden values.
TEST(LSTMOpTest, LayerNormNoCifgPeepholeProjectionNoClipping) {
    const int n_batch = 2;
    const int n_input = 5;
    // Projection weights are enabled in this test, so n_output differs from n_cell.
    const int n_cell = 4;
    const int n_output = 3;

    // The shape list below is consumed positionally by the model constructor.
    LayerNormLSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
                              /*use_cifg=*/false, /*use_peephole=*/true,
                              /*use_projection_weights=*/true,
                              /*use_projection_bias=*/false,
                              /*cell_clip=*/0.0, /*proj_clip=*/0.0,
                              {
                                      {n_batch, n_input},  // input tensor

                                      {n_cell, n_input},  // input_to_input_weight tensor
                                      {n_cell, n_input},  // input_to_forget_weight tensor
                                      {n_cell, n_input},  // input_to_cell_weight tensor
                                      {n_cell, n_input},  // input_to_output_weight tensor

                                      {n_cell, n_output},  // recurrent_to_input_weight tensor
                                      {n_cell, n_output},  // recurrent_to_forget_weight tensor
                                      {n_cell, n_output},  // recurrent_to_cell_weight tensor
                                      {n_cell, n_output},  // recurrent_to_output_weight tensor

                                      {n_cell},  // cell_to_input_weight tensor
                                      {n_cell},  // cell_to_forget_weight tensor
                                      {n_cell},  // cell_to_output_weight tensor

                                      {n_cell},  // input_gate_bias tensor
                                      {n_cell},  // forget_gate_bias tensor
                                      {n_cell},  // cell_bias tensor
                                      {n_cell},  // output_gate_bias tensor

                                      {n_output, n_cell},  // projection_weight tensor
                                      {0},                 // projection_bias tensor

                                      {n_batch, n_output},  // output_state_in tensor
                                      {n_batch, n_cell},    // cell_state_in tensor

                                      {n_cell},  // input_layer_norm_weights tensor
                                      {n_cell},  // forget_layer_norm_weights tensor
                                      {n_cell},  // cell_layer_norm_weights tensor
                                      {n_cell},  // output_layer_norm_weights tensor
                              });

    lstm.SetInputToInputWeights({0.5,  0.6, 0.7,  -0.8, -0.9, 0.1,  0.2,  0.3,  -0.4, 0.5,
                                 -0.8, 0.7, -0.6, 0.5,  -0.4, -0.5, -0.4, -0.3, -0.2, -0.1});

    lstm.SetInputToForgetWeights({-0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2, -0.4, 0.3,  -0.8,
                                  -0.4, 0.3,  -0.5, -0.4, -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5});

    lstm.SetInputToCellWeights({-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
                                0.6,  -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8,  0.6});

    lstm.SetInputToOutputWeights({-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
                                  0.6,  -0.2, 0.4,  -0.7, -0.3, -0.5, 0.1, 0.5,  -0.6, -0.4});

    lstm.SetInputGateBias({0.03, 0.15, 0.22, 0.38});

    lstm.SetForgetGateBias({0.1, -0.3, -0.2, 0.1});

    lstm.SetCellGateBias({-0.05, 0.72, 0.25, 0.08});

    lstm.SetOutputGateBias({0.05, -0.01, 0.2, 0.1});

    lstm.SetRecurrentToInputWeights(
            {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6});

    lstm.SetRecurrentToCellWeights(
            {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2});

    lstm.SetRecurrentToForgetWeights(
            {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2});

    lstm.SetRecurrentToOutputWeights(
            {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2});

    lstm.SetCellToInputWeights({0.05, 0.1, 0.25, 0.15});
    lstm.SetCellToForgetWeights({-0.02, -0.15, -0.25, -0.03});
    lstm.SetCellToOutputWeights({0.1, -0.1, -0.5, 0.05});

    lstm.SetProjectionWeights({-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2});

    lstm.SetInputLayerNormWeights({0.1, 0.2, 0.3, 0.5});
    lstm.SetForgetLayerNormWeights({0.2, 0.2, 0.4, 0.3});
    lstm.SetCellLayerNormWeights({0.7, 0.2, 0.3, 0.8});
    lstm.SetOutputLayerNormWeights({0.6, 0.2, 0.2, 0.5});

    const std::vector<std::vector<float>> lstm_input = {
            {                           // Batch0: 3 (input_sequence_size) * 5 (n_input)
             0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
             0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
             0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2

            {                           // Batch1: 3 (input_sequence_size) * 5 (n_input)
             0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
             0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
             0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
    };

    const std::vector<std::vector<float>> lstm_golden_output = {
            {
                    // Batch0: 3 (input_sequence_size) * 3 (n_output)
                    0.0244077, 0.128027, -0.00170918,  // seq 0
                    0.0137642, 0.140751, 0.0395835,    // seq 1
                    -0.00459231, 0.155278, 0.0837377,  // seq 2
            },
            {
                    // Batch1: 3 (input_sequence_size) * 3 (n_output)
                    -0.00692428, 0.0848741, 0.063445,  // seq 0
                    -0.00403912, 0.139963, 0.072681,   // seq 1
                    0.00752706, 0.161903, 0.0561371,   // seq 2
            }};

    // Resetting cell_state and output_state
    lstm.ResetCellState();
    lstm.ResetOutputState();

    const int input_sequence_size = lstm_input[0].size() / n_input;
    for (int i = 0; i < input_sequence_size; i++) {
        // Load step i of every batch into the model's input buffer, batch-major.
        for (int b = 0; b < n_batch; ++b) {
            const float* batch_start = lstm_input[b].data() + i * n_input;
            const float* batch_end = batch_start + n_input;

            lstm.SetInput(b * n_input, batch_start, batch_end);
        }

        lstm.Invoke();

        // Gather the golden output of step i for every batch, in the same batch-major order.
        std::vector<float> expected;
        for (int b = 0; b < n_batch; ++b) {
            const float* golden_start = lstm_golden_output[b].data() + i * n_output;
            const float* golden_end = golden_start + n_output;
            expected.insert(expected.end(), golden_start, golden_end);
        }
        // Qualified explicitly: this file has no using-declaration for ElementsAreArray.
        EXPECT_THAT(lstm.GetOutput(), ::testing::ElementsAreArray(ArrayFloatNear(expected)));
    }
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android