// Plain text  |  177 lines  |  6.03 KB

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <fstream>
#include <string>
#include <vector>

#include "utils/tflite/text_encoder.h"
#include "gtest/gtest.h"
#include "third_party/absl/flags/flag.h"
#include "flatbuffers/flexbuffers.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {
namespace {

// Returns the filesystem path of the encoder config that the test model
// loads. This implementation always yields an empty path.
std::string GetTestConfigPath() {
  return std::string();
}

class TextEncoderOpModel : public tflite::SingleOpModel {
 public:
  TextEncoderOpModel(std::initializer_list<int> input_strings_shape,
                     std::initializer_list<int> attribute_shape);
  void SetInputText(const std::initializer_list<string>& strings) {
    PopulateStringTensor(input_string_, strings);
    PopulateTensor(input_length_, {static_cast<int32_t>(strings.size())});
  }
  void SetMaxOutputLength(int length) {
    PopulateTensor(input_output_maxlength_, {length});
  }
  void SetInt32Attribute(const std::initializer_list<int>& attribute) {
    PopulateTensor(input_attributes_int32_, attribute);
  }
  void SetFloatAttribute(const std::initializer_list<float>& attribute) {
    PopulateTensor(input_attributes_float_, attribute);
  }

  std::vector<int> GetOutputEncoding() {
    return ExtractVector<int>(output_encoding_);
  }
  std::vector<int> GetOutputPositions() {
    return ExtractVector<int>(output_positions_);
  }
  std::vector<int> GetOutputAttributeInt32() {
    return ExtractVector<int>(output_attributes_int32_);
  }
  std::vector<float> GetOutputAttributeFloat() {
    return ExtractVector<float>(output_attributes_float_);
  }
  int GetEncodedLength() { return ExtractVector<int>(output_length_)[0]; }

 private:
  int input_string_;
  int input_length_;
  int input_output_maxlength_;
  int input_attributes_int32_;
  int input_attributes_float_;

  int output_encoding_;
  int output_positions_;
  int output_length_;
  int output_attributes_int32_;
  int output_attributes_float_;
};

// Registers the model's tensors, packs the encoder config into the custom op
// options and builds the interpreter.
//
// The member initializers run in field declaration order, so the tensors are
// registered as: string, length, max output length, int32 attributes, float
// attributes, followed by the five outputs.
TextEncoderOpModel::TextEncoderOpModel(
    std::initializer_list<int> input_strings_shape,
    std::initializer_list<int> attribute_shape)
    : input_string_(AddInput(tflite::TensorType_STRING)),
      input_length_(AddInput(tflite::TensorType_INT32)),
      input_output_maxlength_(AddInput(tflite::TensorType_INT32)),
      input_attributes_int32_(AddInput(tflite::TensorType_INT32)),
      input_attributes_float_(AddInput(tflite::TensorType_FLOAT32)),
      output_encoding_(AddOutput(tflite::TensorType_INT32)),
      output_positions_(AddOutput(tflite::TensorType_INT32)),
      output_length_(AddOutput(tflite::TensorType_INT32)),
      output_attributes_int32_(AddOutput(tflite::TensorType_INT32)),
      output_attributes_float_(AddOutput(tflite::TensorType_FLOAT32)) {
  // Read the whole config file into memory. If the file cannot be opened the
  // buffer simply stays empty.
  std::ifstream config_file(GetTestConfigPath());
  std::string serialized_config;
  serialized_config.assign(std::istreambuf_iterator<char>(config_file),
                           std::istreambuf_iterator<char>());

  // The op options are a flexbuffer map with one entry holding the config.
  flexbuffers::Builder options;
  options.Map([&options, &serialized_config]() {
    options.String("text_encoder_config", serialized_config);
  });
  options.Finish();
  SetCustomOp("TextEncoder", options.GetBuffer(),
              tflite::ops::custom::Register_TEXT_ENCODER);

  // One shape per input, in registration order; the length and max-length
  // inputs are single-element tensors.
  const std::vector<std::vector<int>> input_shapes = {
      input_strings_shape, {1}, {1}, attribute_shape, attribute_shape};
  BuildInterpreter(input_shapes);
}

// Tests
// Encodes one string with a maximum output length of 10. The test expects an
// encoded length of 5, with the remaining five slots holding id 2 in the
// encoding and 10 in the positions, and the single int32/float attribute
// repeated across all ten slots.
// NOTE(review): the trailing 2s / 10s look like padding values; confirm
// against the encoder implementation.
TEST(TextEncoderTest, SimpleEncoder) {
  using ::testing::ElementsAre;

  TextEncoderOpModel model({1, 1}, {1, 1});
  model.SetInputText({"Hello"});
  model.SetMaxOutputLength(10);
  model.SetInt32Attribute({7});
  model.SetFloatAttribute({3.f});

  model.Invoke();

  EXPECT_EQ(model.GetEncodedLength(), 5);

  const std::vector<int> encoding = model.GetOutputEncoding();
  EXPECT_THAT(encoding, ElementsAre(1, 90, 547, 58, 2, 2, 2, 2, 2, 2));

  const std::vector<int> positions = model.GetOutputPositions();
  EXPECT_THAT(positions, ElementsAre(0, 1, 2, 3, 4, 10, 10, 10, 10, 10));

  const std::vector<int> int_attributes = model.GetOutputAttributeInt32();
  EXPECT_THAT(int_attributes, ElementsAre(7, 7, 7, 7, 7, 7, 7, 7, 7, 7));

  const std::vector<float> float_attributes = model.GetOutputAttributeFloat();
  EXPECT_THAT(float_attributes,
              ElementsAre(3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f, 3.f));
}

// Encodes three strings, each with its own int32/float attribute, into a
// maximum output length of 10. The test expects all ten slots to be used, the
// expected positions to begin at 2 for the first string, and every slot to
// carry the attribute of the string it belongs to.
// NOTE(review): the positions starting at 2 suggest the leading ids of the
// first string are dropped to fit the limit; confirm against the encoder.
TEST(TextEncoderTest, ManyStrings) {
  using ::testing::ElementsAre;

  TextEncoderOpModel model({1, 3}, {1, 3});
  model.SetInt32Attribute({1, 2, 3});
  model.SetFloatAttribute({5.f, 4.f, 3.f});
  model.SetInputText({"Hello", "Hi", "Bye"});
  model.SetMaxOutputLength(10);

  model.Invoke();

  EXPECT_EQ(model.GetEncodedLength(), 10);

  const std::vector<int> encoding = model.GetOutputEncoding();
  EXPECT_THAT(encoding, ElementsAre(547, 58, 2, 1, 862, 2, 1, 1919, 19, 2));

  const std::vector<int> positions = model.GetOutputPositions();
  EXPECT_THAT(positions, ElementsAre(2, 3, 4, 0, 1, 2, 0, 1, 2, 3));

  const std::vector<int> int_attributes = model.GetOutputAttributeInt32();
  EXPECT_THAT(int_attributes, ElementsAre(1, 1, 1, 2, 2, 2, 3, 3, 3, 3));

  const std::vector<float> float_attributes = model.GetOutputAttributeFloat();
  EXPECT_THAT(float_attributes,
              ElementsAre(5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 3.f));
}

// Encodes four strings into a maximum output length of 9. The test expects
// nine output slots, none of which carries the first string's attributes
// (1 / 5.f): the expected attributes start at the second string's values.
// NOTE(review): this reads as the earliest input being dropped entirely when
// the total exceeds the limit; confirm against the encoder implementation.
TEST(TextEncoderTest, LongStrings) {
  using ::testing::ElementsAre;

  TextEncoderOpModel model({1, 4}, {1, 4});
  model.SetInt32Attribute({1, 2, 3, 4});
  model.SetFloatAttribute({5.f, 4.f, 3.f, 2.f});
  model.SetInputText({"Hello", "Hi", "Bye", "Hi"});
  model.SetMaxOutputLength(9);

  model.Invoke();

  EXPECT_EQ(model.GetEncodedLength(), 9);

  const std::vector<int> encoding = model.GetOutputEncoding();
  EXPECT_THAT(encoding, ElementsAre(862, 2, 1, 1919, 19, 2, 1, 862, 2));

  const std::vector<int> positions = model.GetOutputPositions();
  EXPECT_THAT(positions, ElementsAre(1, 2, 0, 1, 2, 3, 0, 1, 2));

  const std::vector<int> int_attributes = model.GetOutputAttributeInt32();
  EXPECT_THAT(int_attributes, ElementsAre(2, 2, 3, 3, 3, 3, 4, 4, 4));

  const std::vector<float> float_attributes = model.GetOutputAttributeFloat();
  EXPECT_THAT(float_attributes,
              ElementsAre(4.f, 4.f, 3.f, 3.f, 3.f, 3.f, 2.f, 2.f, 2.f));
}

}  // namespace
}  // namespace libtextclassifier3