C++程序  |  107行  |  4.45 KB

/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
#define LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_

#include <functional>
#include <memory>
#include <vector>

#include "base.h"
#include "common/vector-span.h"
#include "smartselect/types.h"

namespace libtextclassifier {

// Holds state for extracting features across multiple calls and reusing them.
// Features are extracted once, in the constructor, and later served by Get().
// Assumes that the features of each Token are independent of other tokens.
class CachedFeatures {
 public:
  // Callback that writes the feature vector of one token into the given float
  // storage. It receives that token's sparse (int) and dense (float) features.
  // NOTE(review): the meaning of the bool result is decided by Extract()'s
  // implementation (not visible here); presumably false signals failure.
  using FeatureVectorFn = std::function<bool(
      const std::vector<int>&, const std::vector<float>&, float*)>;

  // Extracts the features for the given sequence of tokens.
  //  - tokens: Tokens to extract features for. Stored as-is in tokens_;
  //            presumably a non-owning view, so the underlying storage should
  //            outlive this object (TODO confirm against vector-span.h).
  //  - context_size: Number of tokens to the left, and number of tokens to
  //                  the right, of a position that form its context.
  //  - sparse_features, dense_features: Extracted features for each token.
  //  - feature_vector_fn: Writes the features of a given Token to the
  //                       specified storage.
  //                       NOTE: The function can assume that the underlying
  //                       storage is initialized to all zeros.
  //  - feature_vector_size: Size of the feature vector of one Token.
  CachedFeatures(VectorSpan<Token> tokens, int context_size,
                 const std::vector<std::vector<int>>& sparse_features,
                 const std::vector<std::vector<float>>& dense_features,
                 const FeatureVectorFn& feature_vector_fn,
                 int feature_vector_size)
      : tokens_(tokens),
        context_size_(context_size),
        feature_vector_size_(feature_vector_size) {
    // Extract() returns void, so the constructor cannot report whether
    // extraction succeeded. NOTE(review): confirm how failures surface.
    Extract(sparse_features, dense_features, feature_vector_fn);
  }

  // Gets a VectorSpan with the features for the given click position, and
  // writes the related tokens to *output_tokens.
  // When V0 feature mode is on (see SetV0FeatureMode), the returned span is
  // backed by internal storage sized for (2 * context_size + 1) token feature
  // vectors.
  // NOTE(review): the exact meaning of the bool result and of output_tokens
  // is defined in the implementation file. Presumably it returns false when
  // features cannot be produced for click_pos; confirm.
  bool Get(int click_pos, VectorSpan<float>* features,
           VectorSpan<Token>* output_tokens);

  // Turns on a compatibility mode that re-maps the extracted features to the
  // v0 feature format (where the dense features were at the end).
  //  - chargram_embedding_size: Size of the chargram-embedding part of each
  //                             token's feature vector.
  // WARNING: Internally v0_feature_storage_ is used as the backing buffer for
  // the VectorSpan<float> returned by Get(). That output is therefore valid
  // only until the next call to Get() or SetV0FeatureMode(), or until this
  // CachedFeatures object is destroyed.
  // TODO(zilka): Remove when we'll have retrained models.
  void SetV0FeatureMode(int chargram_embedding_size) {
    remap_v0_feature_vector_ = true;
    remap_v0_chargram_embedding_size_ = chargram_embedding_size;
    // One feature vector per token in a window of context_size_ tokens on
    // each side plus the token itself.
    v0_feature_storage_.resize(feature_vector_size_ * (context_size_ * 2 + 1));
  }

 protected:
  // Extracts features for all tokens and stores them for later retrieval.
  void Extract(const std::vector<std::vector<int>>& sparse_features,
               const std::vector<std::vector<float>>& dense_features,
               const FeatureVectorFn& feature_vector_fn);

  // Remaps extracted features to the V0 feature format, using
  // v0_feature_storage_ as the backing storage for the mapped features.
  // Each token's features consist of:
  //  - chargram embeddings
  //  - dense features
  // These are concatenated per token as [chargram embeddings; dense features].
  // The V0 format instead requires all tokens' chargram embeddings to be
  // concatenated first, followed by all tokens' dense features.
  void RemapV0FeatureVector(VectorSpan<float>* features);

 private:
  const VectorSpan<Token> tokens_;
  const int context_size_;
  const int feature_vector_size_;
  // V0 compatibility mode is off until SetV0FeatureMode() is called.
  bool remap_v0_feature_vector_ = false;
  int remap_v0_chargram_embedding_size_ = -1;

  std::vector<float> features_;
  std::vector<float> v0_feature_storage_;
};

}  // namespace libtextclassifier

#endif  // LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_