// C++ header (approx. 225 lines, 7.61 KB)

/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_
#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_

#include "common/embedding-network-package.pb.h"
#include "common/embedding-network-params.h"
#include "common/embedding-network.pb.h"
#include "common/memory_image/memory-image-reader.h"
#include "util/base/integral_types.h"

namespace libtextclassifier {
namespace nlp_core {

// EmbeddingNetworkParams backed by a memory image.
//
// In this context, a memory image is like an EmbeddingNetworkProto, but with
// all repeated weights (>99% of the size) directly usable (with no parsing
// required).
// EmbeddingNetworkParams backed by a memory image.
//
// In this context, a memory image is like an EmbeddingNetworkProto, but with
// all repeated weights (>99% of the size) directly usable (with no parsing
// required).
class EmbeddingNetworkParamsFromImage : public EmbeddingNetworkParams {
 public:
  // Constructs an EmbeddingNetworkParamsFromImage over the memory image that
  // starts at address |start| and spans |num_bytes| bytes.
  EmbeddingNetworkParamsFromImage(const void *start, uint64 num_bytes)
      : memory_reader_(start, num_bytes),
        trimmed_proto_(memory_reader_.trimmed_proto()) {
    embeddings_blob_offset_ = 0;

    // A quantized embedding matrix owns two consecutive data blobs (the
    // quantized values followed by the scales), while a non-quantized one owns
    // a single blob.  We assume all embedding matrices share the quantization
    // setting of the first one.
    int blobs_per_embedding = 1;
    if ((embeddings_size() > 0) &&
        trimmed_proto_.embeddings(0).is_quantized()) {
      blobs_per_embedding = 2;
    }
    hidden_blob_offset_ =
        embeddings_blob_offset_ + blobs_per_embedding * embeddings_size();

    hidden_bias_blob_offset_ = hidden_blob_offset_ + hidden_size();
    softmax_blob_offset_ = hidden_bias_blob_offset_ + hidden_bias_size();
    softmax_bias_blob_offset_ = softmax_blob_offset_ + softmax_size();
  }

  ~EmbeddingNetworkParamsFromImage() override {}

  // Returns the TaskSpec extension embedded in the trimmed proto, or nullptr
  // if the extension is not present.
  const TaskSpec *GetTaskSpec() override {
    const auto &ext_id = task_spec_in_embedding_network_proto;
    if (!trimmed_proto_.HasExtension(ext_id)) {
      return nullptr;
    }
    return &(trimmed_proto_.GetExtension(ext_id));
  }

 protected:
  int embeddings_size() const override {
    return trimmed_proto_.embeddings_size();
  }

  int embeddings_num_rows(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return trimmed_proto_.embeddings(i).rows();
  }

  int embeddings_num_cols(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return trimmed_proto_.embeddings(i).cols();
  }

  const void *embeddings_weights(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    // Quantized matrices occupy two blobs each (values + scales), so the i-th
    // matrix's values live at blob 2 * i; otherwise each matrix is one blob.
    const int stride = trimmed_proto_.embeddings(i).is_quantized() ? 2 : 1;
    return memory_reader_.data_blob_view(embeddings_blob_offset_ + stride * i)
        .data();
  }

  QuantizationType embeddings_quant_type(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return trimmed_proto_.embeddings(i).is_quantized()
               ? QuantizationType::UINT8
               : QuantizationType::NONE;
  }

  const float16 *embeddings_quant_scales(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    if (!trimmed_proto_.embeddings(i).is_quantized()) {
      return nullptr;
    }
    // Each quantized matrix contributes two blobs (hence the "2 * i"); the
    // scales blob immediately follows the quantized-values blob (the "+ 1").
    DataBlobView scales_blob =
        memory_reader_.data_blob_view(embeddings_blob_offset_ + 2 * i + 1);
    return reinterpret_cast<const float16 *>(scales_blob.data());
  }

  int hidden_size() const override { return trimmed_proto_.hidden_size(); }

  int hidden_num_rows(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return trimmed_proto_.hidden(i).rows();
  }

  int hidden_num_cols(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return trimmed_proto_.hidden(i).cols();
  }

  const void *hidden_weights(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return memory_reader_.data_blob_view(hidden_blob_offset_ + i).data();
  }

  int hidden_bias_size() const override {
    return trimmed_proto_.hidden_bias_size();
  }

  int hidden_bias_num_rows(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return trimmed_proto_.hidden_bias(i).rows();
  }

  int hidden_bias_num_cols(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return trimmed_proto_.hidden_bias(i).cols();
  }

  const void *hidden_bias_weights(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return memory_reader_.data_blob_view(hidden_bias_blob_offset_ + i).data();
  }

  // At most one softmax matrix: 1 if present, 0 otherwise.
  int softmax_size() const override {
    return trimmed_proto_.has_softmax() ? 1 : 0;
  }

  int softmax_num_rows(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return trimmed_proto_.softmax().rows();
  }

  int softmax_num_cols(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return trimmed_proto_.softmax().cols();
  }

  const void *softmax_weights(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return memory_reader_.data_blob_view(softmax_blob_offset_ + i).data();
  }

  // At most one softmax bias vector: 1 if present, 0 otherwise.
  int softmax_bias_size() const override {
    return trimmed_proto_.has_softmax_bias() ? 1 : 0;
  }

  int softmax_bias_num_rows(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return trimmed_proto_.softmax_bias().rows();
  }

  int softmax_bias_num_cols(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return trimmed_proto_.softmax_bias().cols();
  }

  const void *softmax_bias_weights(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return memory_reader_.data_blob_view(softmax_bias_blob_offset_ + i).data();
  }

  int embedding_num_features_size() const override {
    return trimmed_proto_.embedding_num_features_size();
  }

  int embedding_num_features(int i) const override {
    TC_DCHECK(InRange(i, embedding_num_features_size()));
    return trimmed_proto_.embedding_num_features(i);
  }

 private:
  // Must be declared before trimmed_proto_, which is initialized from it.
  MemoryImageReader<EmbeddingNetworkProto> memory_reader_;

  const EmbeddingNetworkProto &trimmed_proto_;

  // 0-based offsets in the list of data blobs for the different MatrixParams
  // fields.  E.g., the 1st hidden MatrixParams has its weights stored in the
  // data blob number hidden_blob_offset_, the 2nd one in hidden_blob_offset_ +
  // 1, and so on.
  int embeddings_blob_offset_;
  int hidden_blob_offset_;
  int hidden_bias_blob_offset_;
  int softmax_blob_offset_;
  int softmax_bias_blob_offset_;
};

}  // namespace nlp_core
}  // namespace libtextclassifier

#endif  // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_