// string.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)
//
// \file
// Utilities to convert strings into FSTs and string FSTs back into strings.
//

#ifndef FST_LIB_STRING_H_
#define FST_LIB_STRING_H_

#include <fst/compact-fst.h>
#include <fst/icu.h>
#include <fst/mutable-fst.h>

// Declares the flag holding the field separator characters. In this file,
// StringCompiler splits symbol strings on any of these characters (plus
// '\n'), and StringPrinter joins printed symbols with the last of them.
DECLARE_string(fst_field_separator);

namespace fst {

// Functor compiling a string in an FST
template <class A>
class StringCompiler {
 public:
  typedef A Arc;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;

  enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };

  StringCompiler(TokenType type, const SymbolTable *syms = 0,
                 Label unknown_label = kNoLabel,
                 bool allow_negative = false)
      : token_type_(type), syms_(syms), unknown_label_(unknown_label),
        allow_negative_(allow_negative) {}

  // Compile string 's' into FST 'fst'.
  template <class F>
  bool operator()(const string &s, F *fst) const {
    vector<Label> labels;
    if (!ConvertStringToLabels(s, &labels))
      return false;
    Compile(labels, fst);
    return true;
  }

  template <class F>
  bool operator()(const string &s, F *fst, Weight w) const {
    vector<Label> labels;
    if (!ConvertStringToLabels(s, &labels))
      return false;
    Compile(labels, fst, w);
    return true;
  }

 private:
  bool ConvertStringToLabels(const string &str, vector<Label> *labels) const {
    labels->clear();
    if (token_type_ == BYTE) {
      for (size_t i = 0; i < str.size(); ++i)
        labels->push_back(static_cast<unsigned char>(str[i]));
    } else if (token_type_ == UTF8) {
      return UTF8StringToLabels(str, labels);
    } else {
      char *c_str = new char[str.size() + 1];
      str.copy(c_str, str.size());
      c_str[str.size()] = 0;
      vector<char *> vec;
      string separator = "\n" + FLAGS_fst_field_separator;
      SplitToVector(c_str, separator.c_str(), &vec, true);
      for (size_t i = 0; i < vec.size(); ++i) {
        Label label;
        if (!ConvertSymbolToLabel(vec[i], &label))
          return false;
        labels->push_back(label);
      }
      delete[] c_str;
    }
    return true;
  }

  void Compile(const vector<Label> &labels, MutableFst<A> *fst,
               const Weight &weight = Weight::One()) const {
    fst->DeleteStates();
    while (fst->NumStates() <= labels.size())
      fst->AddState();
    for (size_t i = 0; i < labels.size(); ++i)
      fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
    fst->SetStart(0);
    fst->SetFinal(labels.size(), weight);
  }

  template <class Unsigned>
  void Compile(const vector<Label> &labels,
               CompactFst<A, StringCompactor<A>, Unsigned> *fst) const {
    fst->SetCompactElements(labels.begin(), labels.end());
  }

  template <class Unsigned>
  void Compile(const vector<Label> &labels,
               CompactFst<A, WeightedStringCompactor<A>, Unsigned> *fst,
               const Weight &weight = Weight::One()) const {
    vector<pair<Label, Weight> > compacts;
    compacts.reserve(labels.size());
    for (size_t i = 0; i < labels.size(); ++i)
      compacts.push_back(make_pair(labels[i], Weight::One()));
    compacts.back().second = weight;
    fst->SetCompactElements(compacts.begin(), compacts.end());
  }

  bool ConvertSymbolToLabel(const char *s, Label* output) const {
    int64 n;
    if (syms_) {
      n = syms_->Find(s);
      if ((n == -1) && (unknown_label_ != kNoLabel))
        n = unknown_label_;
      if (n == -1 || (!allow_negative_ && n < 0)) {
        VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Symbol \"" << s
                << "\" is not mapped to any integer label, symbol table = "
                 << syms_->Name();
        return false;
      }
    } else {
      char *p;
      n = strtoll(s, &p, 10);
      if (p < s + strlen(s) || (!allow_negative_ && n < 0)) {
        VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Bad label integer "
                << "= \"" << s << "\"";
        return false;
      }
    }
    *output = n;
    return true;
  }

  TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
  const SymbolTable *syms_;  // Symbol table used when token type is symbol
  Label unknown_label_;      // Label for token missing from symbol table
  bool allow_negative_;      // Negative labels allowed?

  DISALLOW_COPY_AND_ASSIGN(StringCompiler);
};

// Functor to print a string FST as a string.
template <class A>
class StringPrinter {
 public:
  typedef A Arc;
  typedef typename A::Label Label;
  typedef typename A::StateId StateId;
  typedef typename A::Weight Weight;

  enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };

  StringPrinter(TokenType token_type,
                const SymbolTable *syms = 0)
      : token_type_(token_type), syms_(syms) {}

  // Convert the FST 'fst' into the string 'output'
  bool operator()(const Fst<A> &fst, string *output) {
    bool is_a_string = FstToLabels(fst);
    if (!is_a_string) {
      VLOG(1) << "StringPrinter::operator(): Fst is not a string.";
      return false;
    }

    output->clear();

    if (token_type_ == SYMBOL) {
      stringstream sstrm;
      for (size_t i = 0; i < labels_.size(); ++i) {
        if (i)
          sstrm << *(FLAGS_fst_field_separator.rbegin());
        if (!PrintLabel(labels_[i], sstrm))
          return false;
      }
      *output = sstrm.str();
    } else if (token_type_ == BYTE) {
      output->reserve(labels_.size());
      for (size_t i = 0; i < labels_.size(); ++i) {
        output->push_back(labels_[i]);
      }
    } else if (token_type_ == UTF8) {
      return LabelsToUTF8String(labels_, output);
    } else {
      VLOG(1) << "StringPrinter::operator(): Unknown token type: "
              << token_type_;
      return false;
    }
    return true;
  }

 private:
  bool FstToLabels(const Fst<A> &fst) {
    labels_.clear();

    StateId s = fst.Start();
    if (s == kNoStateId) {
      VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
              << "string fst.";
      return false;
    }

    while (fst.Final(s) == Weight::Zero()) {
      ArcIterator<Fst<A> > aiter(fst, s);
      if (aiter.Done()) {
        VLOG(2) << "StringPrinter::FstToLabels: String fst traversal does "
                << "not reach final state.";
        return false;
      }

      const A& arc = aiter.Value();
      labels_.push_back(arc.olabel);

      s = arc.nextstate;
      if (s == kNoStateId) {
        VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid "
                << "state.";
        return false;
      }

      aiter.Next();
      if (!aiter.Done()) {
        VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
                << "outgoing arcs found.";
        return false;
      }
    }

    return true;
  }

  bool PrintLabel(Label lab, ostream& ostrm) {
    if (syms_) {
      string symbol = syms_->Find(lab);
      if (symbol == "") {
        VLOG(2) << "StringPrinter::PrintLabel: Integer " << lab << " is not "
                << "mapped to any textual symbol, symbol table = "
                 << syms_->Name();
        return false;
      }
      ostrm << symbol;
    } else {
      ostrm << lab;
    }
    return true;
  }

  TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
  const SymbolTable *syms_;  // Symbol table used when token type is symbol
  vector<Label> labels_;     // Input FST labels.

  DISALLOW_COPY_AND_ASSIGN(StringPrinter);
};

}  // namespace fst

#endif // FST_LIB_STRING_H_