/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains classes that can execute different models/parts of a model.

#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_

#include <memory>

#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "utils/tflite-model-executor.h"

namespace libtextclassifier3 {

// Executor for the text selection prediction and classification models.
class ModelExecutor : public TfLiteModelExecutor {
 public:
  // Creates an executor from a TFLite model specification; returns nullptr
  // if the model cannot be loaded.
  static std::unique_ptr<ModelExecutor> FromModelSpec(
      const tflite::Model* model_spec) {
    auto model = TfLiteModelFromModelSpec(model_spec);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
  }

  // Creates an executor from a serialized model buffer; returns nullptr
  // if the model cannot be loaded.
  static std::unique_ptr<ModelExecutor> FromBuffer(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
    auto model = TfLiteModelFromBuffer(model_spec_buffer);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
  }

  // Runs the features through the model on the given interpreter and returns
  // the resulting logits tensor.
  TensorView<float> ComputeLogits(const TensorView<float>& features,
                                  tflite::Interpreter* interpreter) const;

 protected:
  explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)
      : TfLiteModelExecutor(std::move(model)) {}

  static const int kInputIndexFeatures = 0;
  static const int kOutputIndexLogits = 0;
};
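
// Example use of ModelExecutor (a minimal sketch; `model_buffer` and
// `features` are assumed to exist, and CreateInterpreter() is assumed to be
// inherited from TfLiteModelExecutor):
//
//   std::unique_ptr<ModelExecutor> executor =
//       ModelExecutor::FromBuffer(model_buffer);
//   if (executor == nullptr) {
//     // The buffer did not hold a valid model.
//   }
//   std::unique_ptr<tflite::Interpreter> interpreter =
//       executor->CreateInterpreter();
//   const TensorView<float> logits =
//       executor->ComputeLogits(features, interpreter.get());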

// Executor for embedding sparse features into a dense vector.
class EmbeddingExecutor {
 public:
  virtual ~EmbeddingExecutor() = default;

  // Embeds the sparse_features into a dense embedding and adds (+) it
  // element-wise to the dest vector.
  virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
                            int dest_size) const = 0;

  // Returns true when the model is ready to be used, false otherwise.
  virtual bool IsReady() const { return true; }
};
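
// A minimal sketch of a conforming EmbeddingExecutor (purely illustrative;
// `TableEmbeddingExecutor` is a hypothetical name, backed by a plain float
// table with one row per bucket id, and TensorView is assumed to expose the
// data()/size() accessors declared in utils/tensor-view.h):
//
//   class TableEmbeddingExecutor : public EmbeddingExecutor {
//    public:
//     explicit TableEmbeddingExecutor(std::vector<std::vector<float>> table)
//         : table_(std::move(table)) {}
//
//     bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
//                       int dest_size) const override {
//       for (int i = 0; i < sparse_features.size(); ++i) {
//         const std::vector<float>& row = table_[sparse_features.data()[i]];
//         if (row.size() != static_cast<size_t>(dest_size)) {
//           return false;
//         }
//         // Accumulate rather than overwrite, as the contract requires.
//         for (int j = 0; j < dest_size; ++j) {
//           dest[j] += row[j];
//         }
//       }
//       return true;
//     }
//
//    private:
//     std::vector<std::vector<float>> table_;
//   };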

// EmbeddingExecutor backed by a TFLite model, with optional pruning of the
// embedding table via a bucket mask.
class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
 public:
  // Creates an embedding executor from a serialized model buffer; returns
  // nullptr on error.
  static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
      int quantization_bits,
      const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);

  // Embeds the sparse_features into a dense embedding and adds (+) it
  // element-wise to the dest vector.
  bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
                    int dest_size) const override;

  // Auxiliary function that computes the prefix counts used by the
  // efficient mask-indexing data structure.
  void ComputePrefixCounts();

  // Maps a bucket id from the full (unpruned) bucket space to its row in the
  // pruned embedding table, using the mask-indexing data structure.
  int PruneBucketId(int bucket_id) const;
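
  // A minimal sketch of the mask-indexing scheme these two functions realize
  // (illustrative, not the definitive implementation; assumes pruning_mask_
  // packs 64 buckets per uint64 word and prefix_counts_ caches the number of
  // kept buckets before each word):
  //
  //   void ComputePrefixCounts() {
  //     int count = 0;
  //     for (const uint64 mask : pruning_mask_) {
  //       prefix_counts_.push_back(count);
  //       count += __builtin_popcountll(mask);
  //     }
  //   }
  //
  //   int PruneBucketId(int bucket_id) const {
  //     const uint64 bit = uint64{1} << (bucket_id % 64);
  //     if (!(pruning_mask_[bucket_id / 64] & bit)) {
  //       return pruned_row_bucket_id_;  // All pruned buckets share one row.
  //     }
  //     // Rank among kept buckets: kept buckets in earlier words, plus kept
  //     // lower bits within this word.
  //     return prefix_counts_[bucket_id / 64] +
  //            __builtin_popcountll(pruning_mask_[bucket_id / 64] & (bit - 1));
  //   }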

 protected:
  explicit TFLiteEmbeddingExecutor(
      std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits,
      int num_buckets, int bytes_per_embedding, int output_embedding_size,
      const TfLiteTensor* scales, const TfLiteTensor* embeddings,
      std::unique_ptr<tflite::Interpreter> interpreter,
      const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);

  std::unique_ptr<TfLiteModelExecutor> executor_;

  int quantization_bits_;
  int num_buckets_ = -1;
  int bytes_per_embedding_ = -1;
  int output_embedding_size_ = -1;
  const TfLiteTensor* scales_ = nullptr;
  const TfLiteTensor* embeddings_ = nullptr;

  // NOTE: This interpreter is used in a read-only way (as storage for the
  // model params), so sharing it across threads remains safe.
  std::unique_ptr<tflite::Interpreter> interpreter_;

  std::vector<uint64> pruning_mask_;
  std::vector<uint16> prefix_counts_;
  int full_num_buckets_ = -1;

  // Index of the embedding-table row shared by all pruned buckets.
  int pruned_row_bucket_id_ = -1;
};
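
// Example construction of a TFLiteEmbeddingExecutor (a sketch; the buffer,
// the sparse features, and all parameter values are illustrative only):
//
//   std::unique_ptr<TFLiteEmbeddingExecutor> embedder =
//       TFLiteEmbeddingExecutor::FromBuffer(
//           embedding_model_buffer, /*embedding_size=*/16,
//           /*quantization_bits=*/8);
//   std::vector<float> dense(16, 0.f);
//   if (embedder != nullptr &&
//       embedder->AddEmbedding(sparse_features, dense.data(),
//                              /*dest_size=*/16)) {
//     // dense now holds the accumulated embedding.
//   }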

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_