/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lang_id/lang-id.h"

#include <stdio.h>

#include <algorithm>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "common/algorithm.h"
#include "common/embedding-network-params-from-proto.h"
#include "common/embedding-network.pb.h"
#include "common/embedding-network.h"
#include "common/feature-extractor.h"
#include "common/file-utils.h"
#include "common/list-of-strings.pb.h"
#include "common/memory_image/in-memory-model-data.h"
#include "common/mmap.h"
#include "common/softmax.h"
#include "common/task-context.h"
#include "lang_id/custom-tokenizer.h"
#include "lang_id/lang-id-brain-interface.h"
#include "lang_id/language-identifier-features.h"
#include "lang_id/light-sentence-features.h"
#include "lang_id/light-sentence.h"
#include "lang_id/relevant-script-feature.h"
#include "util/base/logging.h"
#include "util/base/macros.h"

using ::libtextclassifier::nlp_core::file_utils::ParseProtoFromMemory;

namespace libtextclassifier {
namespace nlp_core {
namespace lang_id {

namespace {
// Fallback probability threshold, used when the model does not specify a
// "reliability_thresh" parameter; see comments for
// LangId::SetProbabilityThreshold().
constexpr float kDefaultProbabilityThreshold = 0.50f;

// Fallback for the minimum amount of text (in bytes) the model needs in order
// to produce a meaningful prediction; used when the model does not specify a
// "min_text_size_in_bytes" parameter.
constexpr int kDefaultMinTextSizeInBytes = 20;

// Language reported by LangId::FindLanguage() when no prediction can be made,
// until a different one is installed (per LangId object) by calling
// LangId::SetDefaultLanguage().
constexpr char kInitialDefaultLanguage[] = "";

// Computes the number of bytes of "real text" in |sentence|: the sum of the
// sizes of all its words, excluding the two marker characters each word
// carries (a leading ^ for start-of-word and a trailing $ for end-of-word).
// Whitespace and punctuation from the original text are not part of the words
// and hence are not counted.
int GetRealTextSize(const LightSentence &sentence) {
  int num_bytes = 0;
  int index = 0;
  while (index < sentence.num_words()) {
    const auto &word = sentence.word(index);

    // Every word is expected to be wrapped in the ^ ... $ markers.
    TC_DCHECK(!word.empty());
    TC_DCHECK_EQ('^', word.front());
    TC_DCHECK_EQ('$', word.back());

    // Discount the two markers.
    num_bytes += word.size() - 2;
    ++index;
  }
  return num_bytes;
}

}  // namespace

// Class that performs all work behind LangId.
//
// An instance is built from the serialized model bytes (read from a file, from
// a file descriptor, or from a memory region).  If anything goes wrong while
// loading the model, the instance stays invalid (see is_valid()): in that
// state ScoreLanguages() returns no scores, so FindLanguage() returns the
// default language and FindLanguages() returns a single, low-probability entry
// for the default language.
class LangIdImpl {
 public:
  // Loads the model from the file |filename|.  If the file can't be mmapped,
  // logs an error and leaves this object invalid.
  explicit LangIdImpl(const std::string &filename) {
    // Using mmap as a fast way to read the model bytes.
    ScopedMmap scoped_mmap(filename);
    MmapHandle mmap_handle = scoped_mmap.handle();
    if (!mmap_handle.ok()) {
      TC_LOG(ERROR) << "Unable to read model bytes.";
      return;
    }

    Initialize(mmap_handle.to_stringpiece());
  }

  // Loads the model from the file descriptor |fd|.  If the descriptor can't be
  // mmapped, logs an error and leaves this object invalid.
  explicit LangIdImpl(int fd) {
    // Using mmap as a fast way to read the model bytes.
    ScopedMmap scoped_mmap(fd);
    MmapHandle mmap_handle = scoped_mmap.handle();
    if (!mmap_handle.ok()) {
      TC_LOG(ERROR) << "Unable to read model bytes.";
      return;
    }

    Initialize(mmap_handle.to_stringpiece());
  }

  // Loads the model from the |length| bytes that start at |ptr|.
  LangIdImpl(const char *ptr, size_t length) {
    Initialize(StringPiece(ptr, length));
  }

  // Initializes this object from the serialized model |model_bytes|: reads the
  // TaskSpec, the network parameters and the list of known languages, builds
  // the network, and picks up the model-provided settings (probability
  // threshold, min text size, version).  On success, is_valid() returns true
  // afterwards; on any failure, it returns false.
  void Initialize(StringPiece model_bytes) {
    // Will set valid_ to true only on successful initialization.
    valid_ = false;

    // Make sure all relevant features are registered:
    ContinuousBagOfNgramsFunction::RegisterClass();
    RelevantScriptFeature::RegisterClass();

    // NOTE(salcianu): code below relies on the fact that the current features
    // do not rely on data from a TaskInput.  Otherwise, one would have to use
    // the more complex model registration mechanism, which requires more code.
    InMemoryModelData model_data(model_bytes);
    TaskContext context;
    if (!model_data.GetTaskSpec(context.mutable_spec())) {
      TC_LOG(ERROR) << "Unable to get model TaskSpec";
      return;
    }

    if (!ParseNetworkParams(model_data, &context)) {
      return;
    }
    if (!ParseListOfKnownLanguages(model_data, &context)) {
      return;
    }

    network_.reset(new EmbeddingNetwork(network_params_.get()));
    if (!network_->is_valid()) {
      return;
    }

    // Model-provided settings; each falls back to a default if the model does
    // not specify it.
    probability_threshold_ =
        context.Get("reliability_thresh", kDefaultProbabilityThreshold);
    min_text_size_in_bytes_ =
        context.Get("min_text_size_in_bytes", kDefaultMinTextSizeInBytes);
    version_ = context.Get("version", 0);

    if (!lang_id_brain_interface_.Init(&context)) {
      return;
    }
    valid_ = true;
  }

  // Sets the minimum probability a prediction must have in order to be
  // reported by FindLanguage().
  void SetProbabilityThreshold(float threshold) {
    probability_threshold_ = threshold;
  }

  // Sets the language code reported when no prediction can be made.
  void SetDefaultLanguage(const std::string &lang) { default_language_ = lang; }

  // Returns the most likely language for |text|.  Returns the default language
  // if no scores are available (see ScoreLanguages()) or if the probability of
  // the best language is below the probability threshold.
  std::string FindLanguage(const std::string &text) const {
    std::vector<float> scores = ScoreLanguages(text);
    if (scores.empty()) {
      return default_language_;
    }

    // Softmax label with max score.
    int label = GetArgMax(scores);
    float probability = scores[label];
    if (probability < probability_threshold_) {
      return default_language_;
    }
    return GetLanguageForSoftmaxLabel(label);
  }

  // Returns one (language, probability) pair for each score produced by
  // ScoreLanguages(), in softmax label order.  The result is never empty: if
  // no scores are available, it contains a single entry for the default
  // language, with a tiny probability.
  std::vector<std::pair<std::string, float>> FindLanguages(
      const std::string &text) const {
    const std::vector<float> scores = ScoreLanguages(text);

    std::vector<std::pair<std::string, float>> result;
    result.reserve(scores.size());

    // Labels are ints (see GetLanguageForSoftmaxLabel()), so convert the size
    // once instead of comparing a signed index against an unsigned size.
    const int num_labels = static_cast<int>(scores.size());
    for (int label = 0; label < num_labels; ++label) {
      result.emplace_back(GetLanguageForSoftmaxLabel(label), scores[label]);
    }

    // To avoid crashing clients that always expect at least one predicted
    // language, we promised (see doc for this method) that the result always
    // contains at least one element.
    if (result.empty()) {
      // We use a tiny probability, such that any client that uses a meaningful
      // probability threshold ignores this prediction.  We don't use 0.0f, to
      // avoid crashing clients that normalize the probabilities we return here.
      result.push_back({default_language_, 0.001f});
    }
    return result;
  }

  // Returns the softmax of the network scores for |text|; element i
  // corresponds to softmax label i.  Returns an empty vector if this object is
  // not valid or if |text| contains less real text (see GetRealTextSize())
  // than the min text size.
  std::vector<float> ScoreLanguages(const std::string &text) const {
    if (!is_valid()) {
      return {};
    }

    // Create a Sentence storing the input text.
    LightSentence sentence;
    TokenizeTextForLangId(text, &sentence);

    if (GetRealTextSize(sentence) < min_text_size_in_bytes_) {
      return {};
    }

    // TODO(salcianu): reuse vector<FeatureVector>.
    std::vector<FeatureVector> features(
        lang_id_brain_interface_.NumEmbeddings());
    lang_id_brain_interface_.GetFeatures(&sentence, &features);

    // Predict language.
    EmbeddingNetwork::Vector scores;
    network_->ComputeFinalScores(features, &scores);

    return ComputeSoftmax(scores);
  }

  // Returns true if this object was successfully initialized.
  bool is_valid() const { return valid_; }

  // Returns the version of the model, or -1 if Initialize() did not get far
  // enough to read it.
  int version() const { return version_; }

 private:
  // Returns name of the (in-memory) file for the indicated TaskInput from
  // context.  That TaskInput is expected to have exactly one part; otherwise,
  // this method logs an error and returns the empty string.
  static std::string GetInMemoryFileNameForTaskInput(
      const std::string &input_name, TaskContext *context) {
    TaskInput *task_input = context->GetInput(input_name);
    if (task_input->part_size() != 1) {
      TC_LOG(ERROR) << "TaskInput " << input_name << " has "
                    << task_input->part_size() << " parts";
      return "";
    }
    return task_input->part(0).file_pattern();
  }

  // Parses the network parameters (i.e., field network_params_) from the
  // "language-identifier-network" TaskInput of context, which should point to
  // a serialized EmbeddingNetworkProto inside model_data.  Returns true on
  // success; on failure, logs an error and returns false.
  bool ParseNetworkParams(const InMemoryModelData &model_data,
                          TaskContext *context) {
    const std::string input_name = "language-identifier-network";
    const std::string input_file_name =
        GetInMemoryFileNameForTaskInput(input_name, context);
    if (input_file_name.empty()) {
      TC_LOG(ERROR) << "No input file name for TaskInput " << input_name;
      return false;
    }
    StringPiece bytes = model_data.GetBytesForInputFile(input_file_name);
    if (bytes.data() == nullptr) {
      TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name;
      return false;
    }
    std::unique_ptr<EmbeddingNetworkProto> proto(new EmbeddingNetworkProto());
    if (!ParseProtoFromMemory(bytes, proto.get())) {
      TC_LOG(ERROR) << "Unable to parse EmbeddingNetworkProto";
      return false;
    }
    network_params_.reset(
        new EmbeddingNetworkParamsFromProto(std::move(proto)));
    if (!network_params_->is_valid()) {
      TC_LOG(ERROR) << "EmbeddingNetworkParamsFromProto not valid";
      return false;
    }
    return true;
  }

  // Parses dictionary with known languages (i.e., field languages_) from a
  // TaskInput of context.  Note: that TaskInput should be a ListOfStrings proto
  // with a single element, the serialized form of a ListOfStrings.
  //
  // Returns true on success; on failure, logs an error and returns false.
  bool ParseListOfKnownLanguages(const InMemoryModelData &model_data,
                                 TaskContext *context) {
    const std::string input_name = "language-name-id-map";
    const std::string input_file_name =
        GetInMemoryFileNameForTaskInput(input_name, context);
    if (input_file_name.empty()) {
      TC_LOG(ERROR) << "No input file name for TaskInput " << input_name;
      return false;
    }
    StringPiece bytes = model_data.GetBytesForInputFile(input_file_name);
    if (bytes.data() == nullptr) {
      TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name;
      return false;
    }
    ListOfStrings records;
    if (!ParseProtoFromMemory(bytes, &records)) {
      TC_LOG(ERROR) << "Unable to parse ListOfStrings from TaskInput "
                    << input_name;
      return false;
    }
    if (records.element_size() != 1) {
      TC_LOG(ERROR) << "Wrong number of records in TaskInput " << input_name
                    << " : " << records.element_size();
      return false;
    }
    // The single record is itself a serialized ListOfStrings: the languages.
    if (!ParseProtoFromMemory(std::string(records.element(0)), &languages_)) {
      TC_LOG(ERROR) << "Unable to parse dictionary with known languages";
      return false;
    }
    return true;
  }

  // Returns language code for a softmax label.  See comments for languages_
  // field.  If label is out of range, returns default_language_.
  std::string GetLanguageForSoftmaxLabel(int label) const {
    if ((label >= 0) && (label < languages_.element_size())) {
      return languages_.element(label);
    } else {
      TC_LOG(ERROR) << "Softmax label " << label << " outside range [0, "
                    << languages_.element_size() << ")";
      return default_language_;
    }
  }

  // Feature extraction for the network; initialized by Initialize().
  LangIdBrainInterface lang_id_brain_interface_;

  // Parameters for the neural network network_ (see below).
  std::unique_ptr<EmbeddingNetworkParamsFromProto> network_params_;

  // Neural network to use for scoring.
  std::unique_ptr<EmbeddingNetwork> network_;

  // True if this object is ready to perform language predictions.
  //
  // Must start as false: the constructors that mmap the model return before
  // calling Initialize() if the mmap fails, and is_valid() is still called on
  // the resulting object.  Without this initializer, that call would read an
  // indeterminate value.
  bool valid_ = false;

  // Only predictions with a probability (confidence) above this threshold are
  // reported.  Otherwise, we report default_language_.
  float probability_threshold_ = kDefaultProbabilityThreshold;

  // Min size of the input text for our predictions to be meaningful.  Below
  // this threshold, the underlying model may report a wrong language and a high
  // confidence score.
  int min_text_size_in_bytes_ = kDefaultMinTextSizeInBytes;

  // Version of the model.
  int version_ = -1;

  // Known languages: softmax label i (an integer) means languages_.element(i)
  // (something like "en", "fr", "ru", etc).
  ListOfStrings languages_;

  // Language code to return in case of errors.
  std::string default_language_ = kInitialDefaultLanguage;

  TC_DISALLOW_COPY_AND_ASSIGN(LangIdImpl);
};

// Builds a LangId from the model stored in the file |filename|.  If the model
// can't be loaded, logs an error; the object can still be used, but it is not
// valid (see is_valid()).
LangId::LangId(const std::string &filename) : pimpl_(new LangIdImpl(filename)) {
  const bool loaded = pimpl_->is_valid();
  if (loaded) {
    return;
  }
  TC_LOG(ERROR) << "Unable to construct a valid LangId based "
                << "on the data from " << filename
                << "; nothing should crash, but "
                << "accuracy will be bad.";
}

// Builds a LangId from the model readable through the file descriptor |fd|.
// If the model can't be loaded, logs an error; the object can still be used,
// but it is not valid (see is_valid()).
LangId::LangId(int fd) : pimpl_(new LangIdImpl(fd)) {
  const bool loaded = pimpl_->is_valid();
  if (loaded) {
    return;
  }
  TC_LOG(ERROR) << "Unable to construct a valid LangId based "
                << "on the data from descriptor " << fd
                << "; nothing should crash, "
                << "but accuracy will be bad.";
}

// Builds a LangId from the serialized model stored in the |length| bytes that
// start at |ptr|.  If the model can't be loaded, logs an error; the object can
// still be used, but it is not valid (see is_valid()).
LangId::LangId(const char *ptr, size_t length)
    : pimpl_(new LangIdImpl(ptr, length)) {
  const bool loaded = pimpl_->is_valid();
  if (loaded) {
    return;
  }
  TC_LOG(ERROR) << "Unable to construct a valid LangId based "
                << "on the memory region; nothing should crash, "
                << "but accuracy will be bad.";
}

// Defaulted here, in the .cc file, i.e., after the definition of LangIdImpl.
// NOTE(review): presumably required because pimpl_ is a smart pointer to
// LangIdImpl, which is only forward-declared in the header -- confirm against
// lang-id.h.
LangId::~LangId() = default;

void LangId::SetProbabilityThreshold(float threshold) {
  pimpl_->SetProbabilityThreshold(threshold);
}

void LangId::SetDefaultLanguage(const std::string &lang) {
  pimpl_->SetDefaultLanguage(lang);
}

std::string LangId::FindLanguage(const std::string &text) const {
  return pimpl_->FindLanguage(text);
}

std::vector<std::pair<std::string, float>> LangId::FindLanguages(
    const std::string &text) const {
  return pimpl_->FindLanguages(text);
}

bool LangId::is_valid() const { return pimpl_->is_valid(); }

int LangId::version() const { return pimpl_->version(); }

}  // namespace lang_id
}  // namespace nlp_core
}  // namespace libtextclassifier