普通文本  |  103行  |  3.56 KB

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lang_id/fb_model/model-provider-from-fb.h"

#include "lang_id/common/file/file-utils.h"
#include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"
#include "lang_id/common/flatbuffers/model-utils.h"
#include "lang_id/common/lite_strings/str-split.h"

namespace libtextclassifier3 {
namespace mobile {
namespace lang_id {

// Builds a provider for the flatbuffer model stored in the file |filename|.
//
// The file is memory-mapped, which is a fast way to get at the model bytes.
// The mapping is owned by scoped_mmap_ and the file is unmapped only when that
// field is destroyed.  Therefore the bytes handed to Initialize() stay alive
// for the entire lifetime of this object.
ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(const string &filename)
    : scoped_mmap_(new ScopedMmap(filename)) {
  // Non-owning view over the mapped region; the mapping itself is kept alive
  // by scoped_mmap_.
  StringPiece mapped_bytes = scoped_mmap_->handle().to_stringpiece();
  Initialize(mapped_bytes);
}

// Builds a provider for the flatbuffer model read through the file
// descriptor |fd|.
//
// The file is memory-mapped, which is a fast way to get at the model bytes.
// The mapping is owned by scoped_mmap_ and the file is unmapped only when that
// field is destroyed.  Therefore the bytes handed to Initialize() stay alive
// for the entire lifetime of this object.
// NOTE(review): whether ScopedMmap takes ownership of (and closes) |fd| is not
// visible here; check ScopedMmap before relying on either behavior.
ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(int fd)
    : scoped_mmap_(new ScopedMmap(fd)) {
  // Non-owning view over the mapped region; the mapping itself is kept alive
  // by scoped_mmap_.
  StringPiece mapped_bytes = scoped_mmap_->handle().to_stringpiece();
  Initialize(mapped_bytes);
}

// Verifies |model_bytes| as a flatbuffer model, then populates model_,
// context_, languages_ and nn_params_ from it.
//
// valid_ was initialized to false and is flipped to true only after every step
// below succeeds.  Any failure returns early and leaves valid_ at false, so
// the provider stays marked invalid.
void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) {
  // Step 1: obtain a verified view of the flatbuffer model.
  model_ = saft_fbs::GetVerifiedModelFromBytes(model_bytes);
  if (model_ == nullptr) {
    SAFTM_LOG(ERROR) << "Unable to initialize ModelProviderFromFlatbuffer";
    return;
  }

  // Step 2: populate context_ from the model's parameters.  FillParameters
  // does its own error logging, so nothing more is logged here.
  const bool parameters_filled = saft_fbs::FillParameters(*model_, &context_);
  if (!parameters_filled) return;

  // Step 3: "supported_languages" holds a comma-separated list; keep one
  // languages_ entry per piece.  A missing parameter defaults to "".
  const string languages_csv = context_.Get("supported_languages", "");
  for (StringPiece language : LiteStrSplit(languages_csv, ',')) {
    languages_.emplace_back(language);
  }
  if (languages_.empty()) {
    SAFTM_LOG(ERROR) << "Unable to find list of supported_languages";
    return;
  }

  // Step 4: set up nn_params_.  InitNetworkParams logs its own errors.
  if (!InitNetworkParams()) return;

  // All steps succeeded.
  valid_ = true;
}

bool ModelProviderFromFlatbuffer::InitNetworkParams() {
  const string kInputName = "language-identifier-network";
  StringPiece bytes =
      saft_fbs::GetInputBytes(saft_fbs::GetInputByName(model_, kInputName));
  if ((bytes.data() == nullptr) || bytes.empty()) {
    SAFTM_LOG(ERROR) << "Unable to get bytes for model input " << kInputName;
    return false;
  }
  std::unique_ptr<EmbeddingNetworkParamsFromFlatbuffer> nn_params_from_fb(
      new EmbeddingNetworkParamsFromFlatbuffer(bytes));
  if (!nn_params_from_fb->is_valid()) {
    SAFTM_LOG(ERROR) << "EmbeddingNetworkParamsFromFlatbuffer not valid";
    return false;
  }
  nn_params_ = std::move(nn_params_from_fb);
  return true;
}

}  // namespace lang_id
}  // namespace mobile
}  // namespace libtextclassifier3