/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_ #define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_ #include <string> #include <vector> #include "common/embedding-feature-extractor.h" #include "common/feature-extractor.h" #include "common/task-context.h" #include "common/workspace.h" #include "lang_id/light-sentence-features.h" #include "lang_id/light-sentence.h" #include "util/base/macros.h" namespace libtextclassifier { namespace nlp_core { namespace lang_id { // Specialization of EmbeddingFeatureExtractor that extracts from LightSentence. class LangIdEmbeddingFeatureExtractor : public EmbeddingFeatureExtractor<LightSentenceExtractor, LightSentence> { public: LangIdEmbeddingFeatureExtractor() {} const std::string ArgPrefix() const override { return "language_identifier"; } TC_DISALLOW_COPY_AND_ASSIGN(LangIdEmbeddingFeatureExtractor); }; // Handles sentence -> numeric_features and numeric_prediction -> language // conversions. class LangIdBrainInterface { public: LangIdBrainInterface() {} // Initializes resources and parameters. bool Init(TaskContext *context) { if (!feature_extractor_.Init(context)) { return false; } feature_extractor_.RequestWorkspaces(&workspace_registry_); return true; } // Extract features from sentence. On return, FeatureVector features[i] // contains the features for the embedding space #i. void GetFeatures(LightSentence *sentence, std::vector<FeatureVector> *features) const { WorkspaceSet workspace; workspace.Reset(workspace_registry_); feature_extractor_.Preprocess(&workspace, sentence); return feature_extractor_.ExtractFeatures(workspace, *sentence, features); } int NumEmbeddings() const { return feature_extractor_.NumEmbeddings(); } private: // Typed feature extractor for embeddings. LangIdEmbeddingFeatureExtractor feature_extractor_; // The registry of shared workspaces in the feature extractor. WorkspaceRegistry workspace_registry_; TC_DISALLOW_COPY_AND_ASSIGN(LangIdBrainInterface); }; } // namespace lang_id } // namespace nlp_core } // namespace libtextclassifier #endif // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_