普通文本  |  188行  |  5.6 KB

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "annotator/number/number.h"

#include <climits>
#include <cstdlib>

#include "annotator/collections.h"
#include "utils/base/logging.h"

namespace libtextclassifier3 {

// Classifies the selected span of `context` as a number.
//
// The codepoints in [selection_indices.first, selection_indices.second) are
// handed to ParseNumber(). On success `classification_result` receives the
// number collection, the score and priority score configured in the options,
// and the parsed numeric value, and true is returned. When the selection does
// not parse as a number, false is returned and `classification_result` is left
// untouched.
//
// NOTE(review): `annotation_usecase` is not consulted here, whereas FindAll()
// produces no annotations when the annotator or the usecase is disabled --
// confirm that this difference is intended.
bool NumberAnnotator::ClassifyText(
    const UnicodeText& context, CodepointSpan selection_indices,
    AnnotationUsecase annotation_usecase,
    ClassificationResult* classification_result) const {
  // Validate the output pointer before doing any work, mirroring the up-front
  // argument check in ParseNumber().
  TC3_CHECK(classification_result != nullptr);

  int64 parsed_value = 0;
  // ParseNumber() requires these outputs; classification does not use them.
  int num_prefix_codepoints = 0;
  int num_suffix_codepoints = 0;
  if (!ParseNumber(UnicodeText::Substring(context, selection_indices.first,
                                          selection_indices.second),
                   &parsed_value, &num_prefix_codepoints,
                   &num_suffix_codepoints)) {
    return false;
  }

  classification_result->collection = Collections::Number();
  classification_result->score = options_->score();
  classification_result->priority_score = options_->priority_score();
  classification_result->numeric_value = parsed_value;
  return true;
}

// Annotates every token of `context` that parses as a number.
//
// The context is tokenized with the feature processor and each token is run
// through ParseNumber(). For every token that parses, one AnnotatedSpan is
// appended to `result`; its span is the token span narrowed by the prefix and
// suffix codepoint counts reported by ParseNumber(), and it carries a single
// number classification with the configured score and priority score.
//
// Always returns true. When the annotator is disabled, or `annotation_usecase`
// is not part of the enabled usecase mask, `result` is left unchanged.
bool NumberAnnotator::FindAll(const UnicodeText& context,
                              AnnotationUsecase annotation_usecase,
                              std::vector<AnnotatedSpan>* result) const {
  // A disabled annotator is not an error: succeed without producing spans.
  if (!options_->enabled()) {
    return true;
  }
  // Same for a usecase whose bit is not set in the enabled mask.
  if (((1 << annotation_usecase) & options_->enabled_annotation_usecases()) ==
      0) {
    return true;
  }

  const std::vector<Token> all_tokens = feature_processor_->Tokenize(context);
  for (const Token& current : all_tokens) {
    // View the token's UTF-8 bytes as codepoints without copying them.
    const UnicodeText current_text =
        UTF8ToUnicodeText(current.value, /*do_copy=*/false);

    int64 number;
    int leading_codepoints;
    int trailing_codepoints;
    const bool is_number = ParseNumber(current_text, &number,
                                       &leading_codepoints,
                                       &trailing_codepoints);
    if (!is_number) {
      continue;
    }

    ClassificationResult number_classification{Collections::Number(),
                                               options_->score()};
    number_classification.numeric_value = number;
    number_classification.priority_score = options_->priority_score();

    // Narrow the token span by the prefix/suffix counts from ParseNumber().
    AnnotatedSpan number_span;
    number_span.span = {current.start + leading_codepoints,
                        current.end - trailing_codepoints};
    number_span.classification.push_back(number_classification);

    result->push_back(number_span);
  }

  return true;
}

// Copies the values of a flatbuffers int32 vector into an unordered set.
// A null `codepoints` pointer yields an empty set.
std::unordered_set<int> NumberAnnotator::FlatbuffersVectorToSet(
    const flatbuffers::Vector<int32_t>* codepoints) {
  std::unordered_set<int> codepoint_set;
  if (codepoints != nullptr) {
    for (const int value : *codepoints) {
      codepoint_set.insert(value);
    }
  }
  return codepoint_set;
}

namespace {
// Parses an optionally signed ('+' or '-') run of ASCII decimal digits that
// starts at `it_begin`, stopping at the first non-digit codepoint or at
// `it_end`, whichever comes first.
//
// On success the signed value is stored in `*result` and the iterator one past
// the last consumed digit is returned. On failure `*result` is set to 0 and
// `it_begin` is returned, so callers detect failure by comparing the returned
// iterator with the one they passed in. Parsing fails when:
//   - no digit follows the optional sign (e.g. "", "-", "+abc"), or
//   - the value does not fit into an int64.
UnicodeText::const_iterator ConsumeAndParseNumber(
    const UnicodeText::const_iterator& it_begin,
    const UnicodeText::const_iterator& it_end, int64* result) {
  *result = 0;

  // See if there's a sign in the beginning of the number.
  auto it = it_begin;
  bool negative = false;
  if (it != it_end && (*it == '-' || *it == '+')) {
    negative = (*it == '-');
    ++it;
  }

  // Accumulate directly in the domain of the sign so that the whole int64
  // range, including INT64_MIN, is representable, and check every step
  // *before* performing it: signed overflow is undefined behaviour.
  int64 value = 0;
  bool has_digits = false;
  for (; it != it_end && *it >= '0' && *it <= '9'; ++it) {
    const int digit = static_cast<int>(*it - '0');
    if (negative) {
      // value * 10 - digit >= INT64_MIN  <=>  value >= (INT64_MIN + digit) / 10
      // (integer division truncates towards zero, i.e. rounds up here).
      if (value < (INT64_MIN + digit) / 10) {
        return it_begin;
      }
      value = value * 10 - digit;
    } else {
      // value * 10 + digit <= INT64_MAX  <=>  value <= (INT64_MAX - digit) / 10
      if (value > (INT64_MAX - digit) / 10) {
        return it_begin;
      }
      value = value * 10 + digit;
    }
    has_digits = true;
  }

  // A bare sign (or nothing at all) is not a number; report "nothing consumed"
  // rather than a successfully parsed 0.
  if (!has_digits) {
    return it_begin;
  }

  *result = value;
  return it;
}
}  // namespace

// Tries to interpret `text` as a single number that may be surrounded by
// boundary codepoints and by allowed prefix / suffix codepoints.
//
// Steps:
//   1. Boundary codepoints are stripped from both ends via the feature
//      processor.
//   2. Codepoints contained in `allowed_prefix_codepoints_` are skipped.
//   3. An optionally signed integer is parsed with ConsumeAndParseNumber().
//   4. Everything that remains up to the stripped end must be contained in
//      `allowed_suffix_codepoints_`.
//
// Outputs (all pointers must be non-null):
//   result                 - the parsed value.
//   num_prefix_codepoints  - number of codepoints before the number (stripped
//                            leading boundary codepoints plus prefix
//                            codepoints).
//   num_suffix_codepoints  - number of codepoints after the number (suffix
//                            codepoints plus stripped trailing boundary
//                            codepoints). Only written once a number has been
//                            parsed.
//
// Returns true iff a number was parsed and it is followed only by allowed
// suffix codepoints.
bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* result,
                                  int* num_prefix_codepoints,
                                  int* num_suffix_codepoints) const {
  TC3_CHECK(result != nullptr && num_prefix_codepoints != nullptr &&
            num_suffix_codepoints != nullptr);
  auto it = text.begin();
  auto it_end = text.end();

  // Strip boundary codepoints from both ends.
  const CodepointSpan original_span{0, text.size_codepoints()};
  const CodepointSpan stripped_span =
      feature_processor_->StripBoundaryCodepoints(text, original_span);
  const int num_stripped_end = (original_span.second - stripped_span.second);
  std::advance(it, stripped_span.first);
  std::advance(it_end, -num_stripped_end);

  // From here on every scan is bounded by `it_end` (the stripped end), never
  // by text.end(): the stripped trailing codepoints must not be consumed as
  // prefix or number, and `it` must never move past `it_end`, because the
  // suffix loop below terminates on `it == it_end`.

  // Consume prefix codepoints.
  *num_prefix_codepoints = stripped_span.first;
  while (it != it_end && allowed_prefix_codepoints_.find(*it) !=
                             allowed_prefix_codepoints_.end()) {
    ++it;
    ++(*num_prefix_codepoints);
  }

  // ConsumeAndParseNumber() returns its start iterator when nothing could be
  // parsed.
  const auto it_start = it;
  it = ConsumeAndParseNumber(it, it_end, result);
  if (it == it_start) {
    return false;
  }

  // Consume suffix codepoints.
  bool valid_suffix = true;
  *num_suffix_codepoints = 0;
  while (it != it_end) {
    if (allowed_suffix_codepoints_.find(*it) ==
        allowed_suffix_codepoints_.end()) {
      valid_suffix = false;
      break;
    }

    ++it;
    ++(*num_suffix_codepoints);
  }
  // The stripped trailing boundary codepoints also follow the number.
  *num_suffix_codepoints += num_stripped_end;
  return valid_suffix;
}

}  // namespace libtextclassifier3