// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Authors: allauzen@google.com (Cyril Allauzen)
//          ttai@google.com (Terry Tai)
//          jpr@google.com (Jake Ratkiewicz)


#ifndef FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
#define FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_

#include <libgen.h>
#include <string>
#include <vector>
using std::vector;

#include <fst/extensions/far/far.h>
#include <fst/string.h>

namespace fst {

// Reads strings from an input stream and compiles them into FSTs, treating
// either each line (LINE) or the whole remaining stream (FILE) as one entry.
// Note that the freshly constructed reader is already set to the first input.
//
// The stream passed to the constructor is held by reference and must outlive
// the reader.
//
// Sample Usage:
//   for (StringReader<Arc> reader(...); !reader.Done(); reader.Next()) {
//     VectorFst<Arc> *fst = reader.GetVectorFst();  // NULL on failure.
//     ...
//     delete fst;                                   // Caller owns the result.
//   }
template <class A>
class StringReader {
 public:
  typedef A Arc;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename StringCompiler<A>::TokenType TokenType;

  // LINE: every line of the stream is one entry.
  // FILE: all remaining lines of the stream form a single entry.
  enum EntryType { LINE = 1, FILE = 2 };

  // Args:
  //   istrm: stream the entries are read from.
  //   source: name of the input, used for logging only.
  //   entry_type: how the stream is split into entries (see EntryType).
  //   token_type, syms, unknown_label, allow_negative_labels: forwarded to
  //     the StringCompiler that turns each entry into an FST.
  StringReader(istream &istrm,
               const string &source,
               EntryType entry_type,
               TokenType token_type,
               bool allow_negative_labels,
               const SymbolTable *syms = 0,
               Label unknown_label = kNoLabel)
      : nline_(0), strm_(istrm), source_(source), entry_type_(entry_type),
        token_type_(token_type), done_(false),
        compiler_(token_type, syms, unknown_label, allow_negative_labels) {
    Next();  // Initialize the reader to the first input.
  }

  // Returns true once there is no current entry left to compile.
  bool Done() const {
    return done_;
  }

  // Advances to the next entry, or marks the reader as done if the stream is
  // exhausted.
  void Next() {
    VLOG(1) << "Processing source " << source_ << " at line " << nline_;
    if (!strm_) {                    // We're done if we have no more input.
      done_ = true;
      return;
    }
    // Always start from an empty buffer.  When the previous read consumed a
    // final line that had no trailing newline, only eofbit is set, so the
    // check above passes; the getline below then fails without touching its
    // string argument.  Without this clear the last line would be delivered
    // a second time.
    content_.clear();
    if (entry_type_ == LINE) {
      if (getline(strm_, content_))
        ++nline_;
    } else {
      string line;
      while (getline(strm_, line)) {
        ++nline_;
        content_.append(line);
        content_.append("\n");
      }
    }
    if (!strm_ && content_.empty())  // We're also done if we read off all the
      done_ = true;                  // whitespace at the end of a file.
  }

  // Compiles the current entry into a newly allocated VectorFst.  Returns
  // NULL if compilation fails; otherwise the caller owns the result.
  VectorFst<A> *GetVectorFst() {
    VectorFst<A> *fst = new VectorFst<A>;
    if (compiler_(content_, fst)) {
      return fst;
    } else {
      delete fst;
      return NULL;
    }
  }

  // Compiles the current entry into a newly allocated string-compact FST.
  // Returns NULL if compilation fails; otherwise the caller owns the result.
  CompactFst<A, StringCompactor<A> > *GetCompactFst() {
    CompactFst<A, StringCompactor<A> > *fst =
        new CompactFst<A, StringCompactor<A> >;
    if (compiler_(content_, fst)) {
      return fst;
    } else {
      delete fst;
      return NULL;
    }
  }

 private:
  size_t nline_;            // Number of lines successfully read so far.
  istream &strm_;           // Input stream; not owned.
  string source_;           // Name of the input, for logging.
  EntryType entry_type_;
  TokenType token_type_;
  bool done_;               // True when no current entry is available.
  StringCompiler<A> compiler_;
  string content_;  // The actual content of the input stream's next FST.

  DISALLOW_COPY_AND_ASSIGN(StringReader);
};

// Computes the minimal length required to encode each line number as a
// decimal number (presumably the line numbers of the file named by
// `filename`; the definition is not part of this header).
//
// FarCompileStrings uses the result as the zero-padded width of the numeric
// part of generated keys when no explicit key width is requested.
int KeySize(const char *filename);

template <class Arc>
void FarCompileStrings(const vector<string> &in_fnames,
                       const string &out_fname,
                       const string &fst_type,
                       const FarType &far_type,
                       int32 generate_keys,
                       FarEntryType fet,
                       FarTokenType tt,
                       const string &symbols_fname,
                       const string &unknown_symbol,
                       bool allow_negative_labels,
                       bool file_list_input,
                       const string &key_prefix,
                       const string &key_suffix) {
  typename StringReader<Arc>::EntryType entry_type;
  if (fet == FET_LINE) {
    entry_type = StringReader<Arc>::LINE;
  } else if (fet == FET_FILE) {
    entry_type = StringReader<Arc>::FILE;
  } else {
    FSTERROR() << "FarCompileStrings: unknown entry type";
    return;
  }

  typename StringCompiler<Arc>::TokenType token_type;
  if (tt == FTT_SYMBOL) {
    token_type = StringCompiler<Arc>::SYMBOL;
  } else if (tt == FTT_BYTE) {
    token_type = StringCompiler<Arc>::BYTE;
  } else if (tt == FTT_UTF8) {
    token_type = StringCompiler<Arc>::UTF8;
  } else {
    FSTERROR() << "FarCompileStrings: unknown token type";
    return;
  }

  bool compact;
  if (fst_type.empty() || (fst_type == "vector")) {
    compact = false;
  } else if (fst_type == "compact") {
    compact = true;
  } else {
    FSTERROR() << "FarCompileStrings: unknown fst type: "
               << fst_type;
    return;
  }

  const SymbolTable *syms = 0;
  typename Arc::Label unknown_label = kNoLabel;
  if (!symbols_fname.empty()) {
    syms = SymbolTable::ReadText(symbols_fname,
                                 allow_negative_labels);
    if (!syms) {
      FSTERROR() << "FarCompileStrings: error reading symbol table: "
                 << symbols_fname;
      return;
    }
    if (!unknown_symbol.empty()) {
      unknown_label = syms->Find(unknown_symbol);
      if (unknown_label == kNoLabel) {
        FSTERROR() << "FarCompileStrings: unknown label \"" << unknown_label
                   << "\" missing from symbol table: " << symbols_fname;
        return;
      }
    }
  }

  FarWriter<Arc> *far_writer =
      FarWriter<Arc>::Create(out_fname, far_type);
  if (!far_writer) return;

  vector<string> inputs;
  if (file_list_input) {
    for (int i = 1; i < in_fnames.size(); ++i) {
      ifstream istrm(in_fnames[i].c_str());
      string str;
      while (getline(istrm, str))
        inputs.push_back(str);
    }
  } else {
    inputs = in_fnames;
  }

  for (int i = 0, n = 0; i < inputs.size(); ++i) {
    int key_size = generate_keys ? generate_keys :
        (entry_type == StringReader<Arc>::FILE ? 1 :
         KeySize(inputs[i].c_str()));
    ifstream istrm(inputs[i].c_str());

    for (StringReader<Arc> reader(
             istrm, inputs[i], entry_type, token_type,
             allow_negative_labels, syms, unknown_label);
         !reader.Done();
         reader.Next()) {
      ++n;
      const Fst<Arc> *fst;
      if (compact)
        fst = reader.GetCompactFst();
      else
        fst = reader.GetVectorFst();
      if (!fst) {
        FSTERROR() << "FarCompileStrings: compiling string number " << n
                   << " in file " << inputs[i] << " failed with token_type = "
                   << (tt == FTT_BYTE ? "byte" :
                       (tt == FTT_UTF8 ? "utf8" :
                        (tt == FTT_SYMBOL ? "symbol" : "unknown")))
                   << " and entry_type = "
                   << (fet == FET_LINE ? "line" :
                       (fet == FET_FILE ? "file" : "unknown"));
        delete far_writer;
        delete syms;
        return;
      }
      ostringstream keybuf;
      keybuf.width(key_size);
      keybuf.fill('0');
      keybuf << n;
      string key;
      if (generate_keys > 0) {
        key = keybuf.str();
      } else {
        char* filename = new char[inputs[i].size() + 1];
        strcpy(filename, inputs[i].c_str());
        key = basename(filename);
        if (entry_type != StringReader<Arc>::FILE) {
          key += "-";
          key += keybuf.str();
        }
        delete[] filename;
      }
      far_writer->Add(key_prefix + key + key_suffix, *fst);
      delete fst;
    }
    if (generate_keys == 0)
      n = 0;
  }

  delete far_writer;
}

}  // namespace fst


#endif  // FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_