// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// All Rights Reserved.
//
// Author : Johan Schalkwyk
//
// \file
// Classes to provide symbol-to-integer and integer-to-symbol mappings.

#ifndef FST_LIB_SYMBOL_TABLE_H__
#define FST_LIB_SYMBOL_TABLE_H__

#include <cstring>
#include <string>
#include <utility>
using std::pair; using std::make_pair;
#include <vector>
using std::vector;


#include <fst/compat.h>
#include <iostream>
#include <fstream>
#include <sstream>


#include <map>

DECLARE_bool(fst_compat_symbols);

namespace fst {

// WARNING: Reading via symbol table read options should
//          not be used. This is a temporary workaround for
//          reading symbol ranges of previously stored symbol sets.
struct SymbolTableReadOptions {
  SymbolTableReadOptions() { }

  SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_,
                         const string& source_)
      : string_hash_ranges(string_hash_ranges_),
        source(source_) { }

  vector<pair<int64, int64> > string_hash_ranges;
  string source;
};

// Options controlling the text (ASCII) representation of a symbol table,
// consumed by SymbolTableImpl::ReadText and SymbolTable::WriteText.
// The default constructor is defined out of line, so the default values of
// the fields are not visible from this header.
struct SymbolTableTextOptions {
  SymbolTableTextOptions();

  // Presumably permits negative keys when parsing text input -- TODO confirm
  // against the out-of-line ReadText implementation.
  bool allow_negative;
  // Presumably the set of characters separating the symbol and key fields on
  // each text line -- TODO confirm against ReadText/WriteText.
  string fst_field_separator;
};

// Shared implementation behind SymbolTable: stores the symbol <-> key
// mappings and cached checksums, and carries the reference count used by
// SymbolTable for sharing. AddSymbol(symbol, key), ReadText, Read, Write and
// MaybeRecomputeCheckSum are defined out of line.
class SymbolTableImpl {
 public:
  // Creates an empty table with the given name.
  SymbolTableImpl(const string &name)
      : name_(name),
        available_key_(0),
        dense_key_limit_(0),
        check_sum_finalized_(false) {}

  // Deep copy: re-adds every symbol of `impl`, in the order of its symbols_
  // vector, under the key it has in `impl`. The reference counter, checksum
  // caches and mutex are default-constructed rather than copied.
  explicit SymbolTableImpl(const SymbolTableImpl& impl)
      : name_(impl.name_),
        available_key_(0),
        dense_key_limit_(0),
        check_sum_finalized_(false) {
    for (size_t i = 0; i < impl.symbols_.size(); ++i) {
      AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i]));
    }
  }

  // The symbol strings in symbols_ are owned by this object and are released
  // with delete[].
  ~SymbolTableImpl() {
    for (size_t i = 0; i < symbols_.size(); ++i)
      delete[] symbols_[i];
  }

  // TODO(johans): Add flag to specify whether the symbol
  //               should be indexed as string or int or both.
  int64 AddSymbol(const string& symbol, int64 key);

  // Returns the existing key if `symbol` is already present; otherwise adds
  // it under the current available key and advances available_key_.
  int64 AddSymbol(const string& symbol) {
    int64 key = Find(symbol);
    return (key == -1) ? AddSymbol(symbol, available_key_++) : key;
  }

  static SymbolTableImpl* ReadText(
      istream &strm, const string &name,
      const SymbolTableTextOptions &opts = SymbolTableTextOptions());

  static SymbolTableImpl* Read(istream &strm,
                               const SymbolTableReadOptions& opts);

  bool Write(ostream &strm) const;

  //
  // Return the string associated with the key, or an empty string if no
  // symbol has that key. Keys in [0, dense_key_limit_) are looked up
  // directly in symbols_; all other keys go through key_map_.
  string Find(int64 key) const {
    if (key >=0 && key < dense_key_limit_)
      return string(symbols_[key]);

    map<int64, const char*>::const_iterator it =
        key_map_.find(key);
    if (it == key_map_.end()) {
      return "";
    }
    return string(it->second);
  }

  //
  // Return the key associated with the symbol. If the symbol
  // does not exist, return -1 (the value of SymbolTable::kNoSymbol).
  int64 Find(const string& symbol) const {
    return Find(symbol.c_str());
  }

  //
  // Return the key associated with the symbol. If the symbol
  // does not exist, return -1 (the value of SymbolTable::kNoSymbol).
  int64 Find(const char* symbol) const {
    map<const char *, int64, StrCmp>::const_iterator it =
        symbol_map_.find(symbol);
    if (it == symbol_map_.end()) {
      return -1;
    }
    return it->second;
  }

  // Return the key of the symbol stored at index `pos` of symbols_, or -1
  // if `pos` is out of range.
  int64 GetNthKey(ssize_t pos) const {
    // pos is known to be non-negative before the unsigned comparison, so the
    // cast is lossless and avoids a signed/unsigned comparison.
    if (pos < 0 || static_cast<size_t>(pos) >= symbols_.size()) return -1;
    return Find(symbols_[pos]);
  }

  const string& Name() const { return name_; }

  int IncrRefCount() const {
    return ref_count_.Incr();
  }
  int DecrRefCount() const {
    return ref_count_.Decr();
  }
  int RefCount() const {
    return ref_count_.count();
  }

  // Label-agnostic checksum; recomputed lazily if the table changed.
  string CheckSum() const {
    MaybeRecomputeCheckSum();
    return check_sum_string_;
  }

  // Label-dependent checksum; recomputed lazily if the table changed.
  string LabeledCheckSum() const {
    MaybeRecomputeCheckSum();
    return labeled_check_sum_string_;
  }

  int64 AvailableKey() const {
    return available_key_;
  }

  size_t NumSymbols() const {
    return symbols_.size();
  }

 private:
  // Recomputes the checksums (both of them) if we've had changes since the last
  // computation (i.e., if check_sum_finalized_ is false).
  // Takes ~2.5 microseconds (dbg) or ~230 nanoseconds (opt) on a 2.67GHz Xeon
  // if the checksum is up-to-date (requiring no recomputation).
  void MaybeRecomputeCheckSum() const;

  // Orders C strings by content rather than by pointer value.
  struct StrCmp {
    bool operator()(const char *s1, const char *s2) const {
      return strcmp(s1, s2) < 0;
    }
  };

  string name_;
  int64 available_key_;
  int64 dense_key_limit_;
  // Owned symbol strings (freed in the destructor).
  vector<const char *> symbols_;
  map<int64, const char*> key_map_;
  map<const char *, int64, StrCmp> symbol_map_;

  mutable RefCounter ref_count_;
  mutable bool check_sum_finalized_;
  mutable string check_sum_string_;
  mutable string labeled_check_sum_string_;
  mutable Mutex check_sum_mutex_;
};

//
// \class SymbolTable
// \brief Symbol (string) to int and reverse mapping
//
// The SymbolTable implements the mappings of labels to strings and reverse.
// SymbolTables are used to describe the alphabet of the input and output
// labels for arcs in a Finite State Transducer.
//
// SymbolTables are reference counted and can therefore be shared across
// multiple machines (FSTs). For example, a language model grammar G with a
// SymbolTable for the words in the language model can share this symbol
// table with the lexical representation L o G.
//
class SymbolTable {
 public:
  static const int64 kNoSymbol = -1;

  // Construct symbol table with an unspecified name.
  SymbolTable() : impl_(new SymbolTableImpl("<unspecified>")) {}

  // Construct symbol table with a unique name.
  SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}

  // Create a reference counted copy.
  SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
    impl_->IncrRefCount();
  }

  // Derefence implentation object. When reference count hits 0, delete
  // implementation.
  virtual ~SymbolTable() {
    if (!impl_->DecrRefCount()) delete impl_;
  }

  // Copys the implemenation from one symbol table to another.
  void operator=(const SymbolTable &st) {
    if (impl_ != st.impl_) {
      st.impl_->IncrRefCount();
      if (!impl_->DecrRefCount()) delete impl_;
      impl_ = st.impl_;
    }
  }

  // Read an ascii representation of the symbol table from an istream. Pass a
  // name to give the resulting SymbolTable.
  static SymbolTable* ReadText(
      istream &strm, const string& name,
      const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
    SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm, name, opts);
    if (!impl)
      return 0;
    else
      return new SymbolTable(impl);
  }

  // read an ascii representation of the symbol table
  static SymbolTable* ReadText(const string& filename,
      const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
    ifstream strm(filename.c_str(), ifstream::in);
    if (!strm) {
      LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename;
      return 0;
    }
    return ReadText(strm, filename, opts);
  }


  // WARNING: Reading via symbol table read options should
  //          not be used. This is a temporary work around.
  static SymbolTable* Read(istream &strm,
                           const SymbolTableReadOptions& opts) {
    SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts);
    if (!impl)
      return 0;
    else
      return new SymbolTable(impl);
  }

  // read a binary dump of the symbol table from a stream
  static SymbolTable* Read(istream &strm, const string& source) {
    SymbolTableReadOptions opts;
    opts.source = source;
    return Read(strm, opts);
  }

  // read a binary dump of the symbol table
  static SymbolTable* Read(const string& filename) {
    ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
    if (!strm) {
      LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
      return 0;
    }
    return Read(strm, filename);
  }

  //--------------------------------------------------------
  // Derivable Interface (final)
  //--------------------------------------------------------
  // create a reference counted copy
  virtual SymbolTable* Copy() const {
    return new SymbolTable(*this);
  }

  // Add a symbol with given key to table. A symbol table also
  // keeps track of the last available key (highest key value in
  // the symbol table).
  virtual int64 AddSymbol(const string& symbol, int64 key) {
    MutateCheck();
    return impl_->AddSymbol(symbol, key);
  }

  // Add a symbol to the table. The associated value key is automatically
  // assigned by the symbol table.
  virtual int64 AddSymbol(const string& symbol) {
    MutateCheck();
    return impl_->AddSymbol(symbol);
  }

  // Add another symbol table to this table. All key values will be offset
  // by the current available key (highest key value in the symbol table).
  // Note string symbols with the same key value with still have the same
  // key value after the symbol table has been merged, but a different
  // value. Adding symbol tables do not result in changes in the base table.
  virtual void AddTable(const SymbolTable& table);

  // return the name of the symbol table
  virtual const string& Name() const {
    return impl_->Name();
  }

  // Return the label-agnostic MD5 check-sum for this table.  All new symbols
  // added to the table will result in an updated checksum.
  // DEPRECATED.
  virtual string CheckSum() const {
    return impl_->CheckSum();
  }

  // Same as CheckSum(), but this returns an label-dependent version.
  virtual string LabeledCheckSum() const {
    return impl_->LabeledCheckSum();
  }

  virtual bool Write(ostream &strm) const {
    return impl_->Write(strm);
  }

  bool Write(const string& filename) const {
    ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
    if (!strm) {
      LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
      return false;
    }
    return Write(strm);
  }

  // Dump an ascii text representation of the symbol table via a stream
  virtual bool WriteText(
      ostream &strm,
      const SymbolTableTextOptions &opts = SymbolTableTextOptions()) const;

  // Dump an ascii text representation of the symbol table
  bool WriteText(const string& filename) const {
    ofstream strm(filename.c_str());
    if (!strm) {
      LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
      return false;
    }
    return WriteText(strm);
  }

  // Return the string associated with the key. If the key is out of
  // range (<0, >max), log error and return an empty string.
  virtual string Find(int64 key) const {
    return impl_->Find(key);
  }

  // Return the key associated with the symbol. If the symbol
  // does not exists, log error and  return SymbolTable::kNoSymbol
  virtual int64 Find(const string& symbol) const {
    return impl_->Find(symbol);
  }

  // Return the key associated with the symbol. If the symbol
  // does not exists, log error and  return SymbolTable::kNoSymbol
  virtual int64 Find(const char* symbol) const {
    return impl_->Find(symbol);
  }

  // Return the current available key (i.e highest key number+1) in
  // the symbol table
  virtual int64 AvailableKey(void) const {
    return impl_->AvailableKey();
  }

  // Return the current number of symbols in table (not necessarily
  // equal to AvailableKey())
  virtual size_t NumSymbols(void) const {
    return impl_->NumSymbols();
  }

  virtual int64 GetNthKey(ssize_t pos) const {
    return impl_->GetNthKey(pos);
  }

 private:
  explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}

  void MutateCheck() {
    // Copy on write
    if (impl_->RefCount() > 1) {
      impl_->DecrRefCount();
      impl_ = new SymbolTableImpl(*impl_);
    }
  }

  const SymbolTableImpl* Impl() const {
    return impl_;
  }

 private:
  SymbolTableImpl* impl_;
};


//
// \class SymbolTableIterator
// \brief Iterator class for symbols in a symbol table
class SymbolTableIterator {
 public:
  // Positions the iterator on the first stored symbol of `table`. Only a
  // reference to `table` is kept, so the table must outlive the iterator.
  // The number of symbols is sampled once, here.
  SymbolTableIterator(const SymbolTable& table)
      : table_(table),
        pos_(0),
        nsymbols_(table.NumSymbols()),
        key_(table.GetNthKey(0)) { }

  ~SymbolTableIterator() { }

  // Returns true once all symbols counted at construction have been visited.
  bool Done(void) const {
    // pos_ starts at 0 and is only incremented, so the cast is lossless and
    // avoids a signed/unsigned comparison.
    return static_cast<size_t>(pos_) == nsymbols_;
  }

  // Returns the key (int64) of the current symbol.
  int64 Value(void) const {
    return key_;
  }

  // Returns the string of the current symbol.
  string Symbol(void) const {
    return table_.Find(key_);
  }

  // Advances to the next stored symbol. The key is only refreshed while the
  // position is still in range.
  void Next(void) {
    ++pos_;
    if (static_cast<size_t>(pos_) < nsymbols_) key_ = table_.GetNthKey(pos_);
  }

  // Rewinds to the first stored symbol. The symbol count is not re-sampled.
  void Reset(void) {
    pos_ = 0;
    key_ = table_.GetNthKey(0);
  }

 private:
  const SymbolTable& table_;
  ssize_t pos_;
  size_t nsymbols_;
  int64 key_;
};


// Tests compatibility between two symbol tables (either of which may be null).
inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
                          bool warning = true) {
  if (!FLAGS_fst_compat_symbols) {
    return true;
  } else if (!syms1 && !syms2) {
    return true;
  } else if (syms1 && !syms2) {
    if (warning)
      LOG(WARNING) <<
          "CompatSymbols: first symbol table present but second missing";
    return false;
  } else if (!syms1 && syms2) {
    if (warning)
      LOG(WARNING) <<
          "CompatSymbols: second symbol table present but first missing";
    return false;
  } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) {
    if (warning)
      LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match";
    return false;
  } else {
    return true;
  }
}


// Relabels a symbol table as specified by the input vector of pairs
// (old label, new label). The new symbol table only retains symbols
// for which a relabeling is *explicitly* specified.
// TODO(allauzen): consider adding options to allow for some form
// of implicit identity relabeling.
template <class Label>
SymbolTable *RelabelSymbolTable(const SymbolTable *table,
                                const vector<pair<Label, Label> > &pairs) {
  // An unnamed source yields an unnamed result. Otherwise the result's name
  // is the source name prefixed with "relabeled_".
  const string &old_name = table->Name();
  string new_name;
  if (!old_name.empty()) new_name = string("relabeled_") + old_name;
  SymbolTable *relabeled = new SymbolTable(new_name);

  // Each (old, new) pair copies the symbol spelled by the old label into the
  // new table under the new label.
  typedef typename vector<pair<Label, Label> >::const_iterator PairIter;
  for (PairIter it = pairs.begin(); it != pairs.end(); ++it)
    relabeled->AddSymbol(table->Find(it->first), it->second);

  return relabeled;
}

// Symbol Table Serialization
inline void SymbolTableToString(const SymbolTable *table, string *result) {
  ostringstream ostrm;
  table->Write(ostrm);
  *result = ostrm.str();
}

inline SymbolTable *StringToSymbolTable(const string &s) {
  // Inverse of SymbolTableToString: parses the binary form held in `s`
  // using default read options. Returns whatever SymbolTable::Read returns
  // (0 on failure).
  SymbolTableReadOptions default_opts;
  istringstream input(s);
  return SymbolTable::Read(input, default_opts);
}



}  // namespace fst

#endif  // FST_LIB_SYMBOL_TABLE_H__