// encode.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: johans@google.com (Johan Schalkwyk)
//
// \file
// Classes to encode and decode an Fst.

#ifndef FST_LIB_ENCODE_H__
#define FST_LIB_ENCODE_H__

#include <climits>
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <string>
#include <vector>
using std::vector;

#include <fst/arc-map.h>
#include <fst/rmfinalepsilon.h>


namespace fst {

// Public encoding flags: which parts of an arc are folded into the
// encoded label.
static const uint32 kEncodeLabels      = 0x0001;
static const uint32 kEncodeWeights     = 0x0002;
// All non-internal flags
static const uint32 kEncodeFlags       = kEncodeLabels | kEncodeWeights;

// Internal flags: EncodeTable sets/clears these as symbol tables are
// attached to or removed from it.
static const uint32 kEncodeHasISymbols = 0x0004;  // For internal use
static const uint32 kEncodeHasOSymbols = 0x0008;  // For internal use

// Direction in which an EncodeMapper operates.
enum EncodeType { ENCODE = 1, DECODE = 2 };

// Identifies stream data as an encode table (and its endianness)
static const int32 kEncodeMagicNumber = 2129983209;


// The following class encapsulates implementation details for the
// encoding and decoding of label/weight tuples used for encoding
// and decoding of Fsts. The EncodeTable is bidirectional, i.e., it
// stores both the mapping from a Tuple of labels and weight to a unique
// encoded label, and the reverse mapping.
template <class A>  class EncodeTable {
 public:
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;

  // Encoded data consists of arc input/output labels and arc weight
  struct Tuple {
    // Leaves the labels uninitialized; Read() fills every field in.
    Tuple() {}
    Tuple(Label ilabel_, Label olabel_, Weight weight_)
        : ilabel(ilabel_), olabel(olabel_), weight(weight_) {}
    Tuple(const Tuple& tuple)
        : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {}

    Label ilabel;
    Label olabel;
    Weight weight;
  };

  // Comparison object for hashing EncodeTable Tuple(s). Tuples are
  // compared by value (all three fields), not by pointer identity.
  class TupleEqual {
   public:
    bool operator()(const Tuple* x, const Tuple* y) const {
      return (x->ilabel == y->ilabel &&
              x->olabel == y->olabel &&
              x->weight == y->weight);
    }
  };

  // Hash function for EncodeTable Tuples. Based on the encode flags
  // we either hash the labels, weights or combination of them. The
  // input label always contributes to the hash.
  class TupleKey {
   public:
    TupleKey()
        : encode_flags_(kEncodeLabels | kEncodeWeights) {}

    TupleKey(const TupleKey& key)
        : encode_flags_(key.encode_flags_) {}

    explicit TupleKey(uint32 encode_flags)
        : encode_flags_(encode_flags) {}

    size_t operator()(const Tuple* x) const {
      size_t hash = x->ilabel;
      // Each additional component is mixed in after rotating the running
      // hash left by 5 bits.
      const int lshift = 5;
      const int rshift = CHAR_BIT * sizeof(size_t) - 5;
      if (encode_flags_ & kEncodeLabels)
        hash = hash << lshift ^ hash >> rshift ^ x->olabel;
      if (encode_flags_ & kEncodeWeights)
        hash = hash << lshift ^ hash >> rshift ^ x->weight.Hash();
      return hash;
    }

   private:
    // Unsigned, matching the type of the flag constants and of the
    // constructor argument.
    uint32 encode_flags_;
  };

  typedef unordered_map<const Tuple*,
                   Label,
                   TupleKey,
                   TupleEqual> EncodeHash;

  // Creates an empty table. encode_flags selects which arc components
  // take part in the encoding (and in the tuple hash).
  explicit EncodeTable(uint32 encode_flags)
      : flags_(encode_flags),
        encode_hash_(1024, TupleKey(encode_flags)),
        isymbols_(0), osymbols_(0) {}

  // Frees every stored tuple and the table's copies of the symbol tables.
  ~EncodeTable() {
    for (size_t i = 0; i < encode_tuples_.size(); ++i) {
      delete encode_tuples_[i];
    }
    delete isymbols_;
    delete osymbols_;
  }

  // Given an arc, encodes either input/output labels or input/costs or both.
  // Returns the label already assigned to the arc's tuple, or assigns and
  // returns the next free label (labels start at 1).
  Label Encode(const A &arc) {
    const Tuple tuple(arc.ilabel,
                      flags_ & kEncodeLabels ? arc.olabel : 0,
                      flags_ & kEncodeWeights ? arc.weight : Weight::One());
    typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
    if (it == encode_hash_.end()) {
      // The hash is keyed on the heap copy, which lives as long as the
      // table; the stack tuple above is only a lookup probe.
      encode_tuples_.push_back(new Tuple(tuple));
      encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
      return encode_tuples_.size();
    } else {
      return it->second;
    }
  }

  // Given an arc, look up its encoded label. Returns kNoLabel if not found.
  // Unlike Encode(), never modifies the table.
  Label GetLabel(const A &arc) const {
    const Tuple tuple(arc.ilabel,
                      flags_ & kEncodeLabels ? arc.olabel : 0,
                      flags_ & kEncodeWeights ? arc.weight : Weight::One());
    typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
    if (it == encode_hash_.end()) {
      return kNoLabel;
    } else {
      return it->second;
    }
  }

  // Given an encoded arc Label, decodes back to input/output labels and
  // costs. Returns 0 (after logging an error) if key is not a label
  // previously handed out by this table. The returned tuple is owned by
  // the table.
  const Tuple* Decode(Label key) const {
    // key < 1 is tested first, so the cast only ever sees positive values.
    if (key < 1 || static_cast<size_t>(key) > encode_tuples_.size()) {
      LOG(ERROR) << "EncodeTable::Decode: unknown decode key: " << key;
      return 0;
    }
    return encode_tuples_[key - 1];
  }

  // Number of distinct tuples (equivalently, the largest encoded label).
  size_t Size() const { return encode_tuples_.size(); }

  bool Write(ostream &strm, const string &source) const;

  static EncodeTable<A> *Read(istream &strm, const string &source);

  // Returns only the public encode flags; the internal symbol-table
  // bits are masked out.
  const uint32 flags() const { return flags_ & kEncodeFlags; }

  // Reference counting, delegated to RefCounter. EncodeMapper deletes the
  // table when DecrRefCount() returns 0.
  int RefCount() const { return ref_count_.count(); }
  int IncrRefCount() { return ref_count_.Incr(); }
  int DecrRefCount() { return ref_count_.Decr(); }


  // Symbol tables of the pre-encoded Fst, or 0 if none were set. The
  // returned pointers remain owned by this table.
  SymbolTable *InputSymbols() const { return isymbols_; }

  SymbolTable *OutputSymbols() const { return osymbols_; }

  // Replaces the stored input symbol table with a copy of syms, or clears
  // it when syms is 0, keeping kEncodeHasISymbols in sync.
  void SetInputSymbols(const SymbolTable* syms) {
    // Copy before deleting: syms may be the table currently held (e.g.
    // SetInputSymbols(InputSymbols())), and must not be read once freed.
    SymbolTable *copy = syms ? syms->Copy() : 0;
    delete isymbols_;
    isymbols_ = copy;
    if (copy) {
      flags_ |= kEncodeHasISymbols;
    } else {
      flags_ &= ~kEncodeHasISymbols;
    }
  }

  // Replaces the stored output symbol table with a copy of syms, or clears
  // it when syms is 0, keeping kEncodeHasOSymbols in sync.
  void SetOutputSymbols(const SymbolTable* syms) {
    // Copy before deleting, for the same aliasing reason as above.
    SymbolTable *copy = syms ? syms->Copy() : 0;
    delete osymbols_;
    osymbols_ = copy;
    if (copy) {
      flags_ |= kEncodeHasOSymbols;
    } else {
      flags_ &= ~kEncodeHasOSymbols;
    }
  }

 private:
  uint32 flags_;                  // Public encode flags plus internal bits
  vector<Tuple*> encode_tuples_;  // Owned; label k is stored at index k - 1
  EncodeHash encode_hash_;        // Tuple -> label; keys point into above
  RefCounter ref_count_;
  SymbolTable *isymbols_;       // Pre-encoded ilabel symbol table
  SymbolTable *osymbols_;       // Pre-encoded olabel symbol table

  DISALLOW_COPY_AND_ASSIGN(EncodeTable);
};

// Serializes the table to strm: magic number, raw flags (including the
// internal symbol-table bits), the tuple count as a 64-bit integer, each
// tuple's ilabel/olabel/weight in label order, and finally whichever
// symbol tables the flags announce. Returns false, after logging an error
// that names source, if the stream is in a failed state once flushed.
template <class A> inline
bool EncodeTable<A>::Write(ostream &strm, const string &source) const {
  WriteType(strm, kEncodeMagicNumber);
  WriteType(strm, flags_);

  int64 num_tuples = encode_tuples_.size();
  WriteType(strm, num_tuples);

  typedef typename vector<Tuple*>::const_iterator TupleIter;
  for (TupleIter it = encode_tuples_.begin();
       it != encode_tuples_.end(); ++it) {
    const Tuple &entry = **it;
    WriteType(strm, entry.ilabel);
    WriteType(strm, entry.olabel);
    entry.weight.Write(strm);
  }

  if (flags_ & kEncodeHasISymbols) {
    isymbols_->Write(strm);
  }
  if (flags_ & kEncodeHasOSymbols) {
    osymbols_->Write(strm);
  }

  // A single check after the flush covers every write above.
  strm.flush();
  if (strm) return true;

  LOG(ERROR) << "EncodeTable::Write: write failed: " << source;
  return false;
}

// Reads a table in the layout produced by Write(). Returns a newly
// allocated table that the caller owns, or 0 (after logging an error that
// names source) if the header is not an encode table, the stream fails,
// the tuple count is negative, or an announced symbol table cannot be read.
// Nothing is leaked on any failure path.
template <class A> inline
EncodeTable<A> *EncodeTable<A>::Read(istream &strm, const string &source) {
  int32 magic_number = 0;
  ReadType(strm, &magic_number);
  if (magic_number != kEncodeMagicNumber) {
    LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source;
    return 0;
  }

  // Read and validate the rest of the header before allocating, so a
  // truncated stream never builds a table from indeterminate values.
  uint32 flags = 0;
  ReadType(strm, &flags);
  int64 size = 0;
  ReadType(strm, &size);
  if (!strm || size < 0) {
    LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
    return 0;
  }

  EncodeTable<A> *table = new EncodeTable<A>(flags);

  for (int64 i = 0; i < size; ++i) {
    Tuple* tuple = new Tuple();
    ReadType(strm, &tuple->ilabel);
    ReadType(strm, &tuple->olabel);
    tuple->weight.Read(strm);
    if (!strm) {
      LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
      delete tuple;  // Not yet owned by the table
      delete table;  // Frees every tuple stored so far
      return 0;
    }
    // Tuples are stored in label order: the i-th tuple gets label i + 1.
    table->encode_tuples_.push_back(tuple);
    table->encode_hash_[tuple] = table->encode_tuples_.size();
  }

  // A set symbol flag with a null table would make a later Write()
  // dereference null, so a failed symbol read fails the whole Read.
  if (flags & kEncodeHasISymbols) {
    table->isymbols_ = SymbolTable::Read(strm, source);
    if (!table->isymbols_) {
      LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
      delete table;
      return 0;
    }
  }

  if (flags & kEncodeHasOSymbols) {
    table->osymbols_ = SymbolTable::Read(strm, source);
    if (!table->osymbols_) {
      LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
      delete table;
      return 0;
    }
  }

  return table;
}


// A mapper to encode/decode weighted transducers. Encoding of an
// Fst is useful for performing classical determinization or minimization
// on a weighted transducer by treating it as an unweighted acceptor over
// encoded labels.
//
// The Encode mapper stores the encoding in a local hash table (EncodeTable)
// This table is shared (and reference counted) between the encoder and
// decoder. A decoder has read only access to the EncodeTable.
//
// The EncodeMapper allows on-the-fly encoding of the machine. As the
// EncodeTable is generated, the same table may be used to decode the machine
// on the fly. For example, in the following sequence of operations
//
//  Encode -> Determinize -> Decode
//
// we will use the encoding table generated during the encode step in the
// decode, even though the encoding is not complete.
//
template <class A> class EncodeMapper {
  typedef typename A::Weight Weight;
  typedef typename A::Label  Label;
 public:
  EncodeMapper(uint32 flags, EncodeType type)
    : flags_(flags),
      type_(type),
      table_(new EncodeTable<A>(flags)),
      error_(false) {}

  EncodeMapper(const EncodeMapper& mapper)
      : flags_(mapper.flags_),
        type_(mapper.type_),
        table_(mapper.table_),
        error_(false) {
    table_->IncrRefCount();
  }

  // Copy constructor but setting the type, typically to DECODE
  EncodeMapper(const EncodeMapper& mapper, EncodeType type)
      : flags_(mapper.flags_),
        type_(type),
        table_(mapper.table_),
        error_(mapper.error_) {
    table_->IncrRefCount();
  }

  ~EncodeMapper() {
    if (!table_->DecrRefCount()) delete table_;
  }

  A operator()(const A &arc);

  MapFinalAction FinalAction() const {
    return (type_ == ENCODE && (flags_ & kEncodeWeights)) ?
                   MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL;
  }

  MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; }

  MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;}

  uint64 Properties(uint64 inprops) {
    uint64 outprops = inprops;
    if (error_) outprops |= kError;

    uint64 mask = kFstProperties;
    if (flags_ & kEncodeLabels)
      mask &= kILabelInvariantProperties & kOLabelInvariantProperties;
    if (flags_ & kEncodeWeights)
      mask &= kILabelInvariantProperties & kWeightInvariantProperties &
          (type_ == ENCODE ? kAddSuperFinalProperties :
           kRmSuperFinalProperties);

    return outprops & mask;
  }

  const uint32 flags() const { return flags_; }
  const EncodeType type() const { return type_; }
  const EncodeTable<A> &table() const { return *table_; }

  bool Write(ostream &strm, const string& source) {
    return table_->Write(strm, source);
  }

  bool Write(const string& filename) {
    ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
    if (!strm) {
      LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
      return false;
    }
    return Write(strm, filename);
  }

  static EncodeMapper<A> *Read(istream &strm,
                               const string& source,
                               EncodeType type = ENCODE) {
    EncodeTable<A> *table = EncodeTable<A>::Read(strm, source);
    return table ? new EncodeMapper(table->flags(), type, table) : 0;
  }

  static EncodeMapper<A> *Read(const string& filename,
                               EncodeType type = ENCODE) {
    ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
    if (!strm) {
      LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
      return NULL;
    }
    return Read(strm, filename, type);
  }

  SymbolTable *InputSymbols() const { return table_->InputSymbols(); }

  SymbolTable *OutputSymbols() const { return table_->OutputSymbols(); }

  void SetInputSymbols(const SymbolTable* syms) {
    table_->SetInputSymbols(syms);
  }

  void SetOutputSymbols(const SymbolTable* syms) {
    table_->SetOutputSymbols(syms);
  }

 private:
  uint32 flags_;
  EncodeType type_;
  EncodeTable<A>* table_;
  bool error_;

  explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table)
      : flags_(flags), type_(type), table_(table) {}
  void operator=(const EncodeMapper &);  // Disallow.
};

// Maps one arc. In ENCODE mode the arc's tuple is registered in the table
// and replaced by its encoded label; in DECODE mode the label is looked up
// and the original components restored. Arcs with nextstate == kNoStateId
// get special treatment (presumably these are the final-weight pseudo-arcs
// that ArcMap hands to mappers -- confirm against arc-map.h).
template <class A> inline
A EncodeMapper<A>::operator()(const A &arc) {
  const bool no_next_state = (arc.nextstate == kNoStateId);
  const bool encodes_labels = (flags_ & kEncodeLabels) != 0;
  const bool encodes_weights = (flags_ & kEncodeWeights) != 0;

  if (type_ == ENCODE) {  // labels and/or weights to single label
    // Passed through untouched unless weights are encoded and the
    // weight is something other than Zero.
    if (no_next_state &&
        (!encodes_weights || arc.weight == Weight::Zero())) {
      return arc;
    }
    const Label label = table_->Encode(arc);
    return A(label,
             encodes_labels ? label : arc.olabel,
             encodes_weights ? Weight::One() : arc.weight,
             arc.nextstate);
  }

  // type_ == DECODE
  if (no_next_state) return arc;
  if (arc.ilabel == 0) return arc;  // Label 0 is never decoded

  // Consistency checks only record an error; decoding still proceeds.
  if (encodes_labels && arc.ilabel != arc.olabel) {
    FSTERROR() << "EncodeMapper: Label-encoded arc has different "
        "input and output labels";
    error_ = true;
  }
  if (encodes_weights && arc.weight != Weight::One()) {
    FSTERROR() <<
        "EncodeMapper: Weight-encoded arc has non-trivial weight";
    error_ = true;
  }

  const typename EncodeTable<A>::Tuple* tuple = table_->Decode(arc.ilabel);
  if (tuple) {
    return A(tuple->ilabel,
             encodes_labels ? tuple->olabel : arc.olabel,
             encodes_weights ? tuple->weight : arc.weight,
             arc.nextstate);
  }

  FSTERROR() << "EncodeMapper: decode failed";
  error_ = true;
  return A(kNoLabel, kNoLabel, Weight::NoWeight(), arc.nextstate);
}


// Encodes fst in place through mapper. The Fst's current symbol tables are
// first copied into the mapper so they can be put back when decoding.
//
// Complexity: O(nstates + narcs)
template<class A> inline
void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) {
  // Capture the symbols before mapping: the mapper's symbol actions are
  // MAP_CLEAR_SYMBOLS.
  const SymbolTable *isyms = fst->InputSymbols();
  const SymbolTable *osyms = fst->OutputSymbols();
  mapper->SetInputSymbols(isyms);
  mapper->SetOutputSymbols(osyms);
  ArcMap(fst, mapper);
}

// Decodes fst in place with a DECODE-typed copy of mapper (mapper itself is
// not modified), removes final epsilons, and then attaches the symbol
// tables stored in the mapper to the Fst.
template<class A> inline
void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) {
  ArcMap(fst, EncodeMapper<A>(mapper, DECODE));
  RmFinalEpsilon(fst);

  SymbolTable *isyms = mapper.InputSymbols();
  SymbolTable *osyms = mapper.OutputSymbols();
  fst->SetInputSymbols(isyms);
  fst->SetOutputSymbols(osyms);
}


// On the fly label and/or weight encoding of input Fst
//
// Complexity:
// - Constructor: O(1)
// - Traversal: O(nstates_visited + narcs_visited), assuming constant
//   time to visit an input state or arc.
template <class A>
class EncodeFst : public ArcMapFst<A, A, EncodeMapper<A> > {
 public:
  typedef A Arc;
  typedef EncodeMapper<A> C;
  typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl;
  using ImplToFst<Impl>::GetImpl;

  // Encodes through a caller-supplied mapper passed by pointer. The input
  // Fst's symbol tables are copied into that mapper.
  EncodeFst(const Fst<A> &fst, C *encoder)
      : ArcMapFst<A, A, EncodeMapper<A> >(fst, encoder, ArcMapFstOptions()) {
    encoder->SetInputSymbols(fst.InputSymbols());
    encoder->SetOutputSymbols(fst.OutputSymbols());
  }

  // Encodes through a mapper passed by const reference; the mapper's
  // stored symbol tables are left untouched.
  EncodeFst(const Fst<A> &fst, const C &encoder)
      : ArcMapFst<A, A, EncodeMapper<A> >(fst, encoder, ArcMapFstOptions()) {}

  // See Fst<>::Copy() for doc.
  EncodeFst(const EncodeFst<A> &fst, bool copy = false)
      : ArcMapFst<A, A, EncodeMapper<A> >(fst, copy) {}

  // Get a copy of this EncodeFst. See Fst<>::Copy() for further doc.
  // A safe copy is not supported: requesting one logs an error and sets
  // kError on this Fst's implementation before the ordinary copy is made.
  virtual EncodeFst<A> *Copy(bool safe = false) const {
    if (safe) {
      FSTERROR() << "EncodeFst::Copy(true): not allowed.";
      GetImpl()->SetProperties(kError, kError);
    }
    return new EncodeFst<A>(*this);
  }
};


// On the fly label and/or weight decoding of input Fst
//
// Complexity:
// - Constructor: O(1)
// - Traversal: O(nstates_visited + narcs_visited), assuming constant
//   time to visit an input state or arc.
template <class A>
class DecodeFst : public ArcMapFst<A, A, EncodeMapper<A> > {
 public:
  typedef A Arc;
  typedef EncodeMapper<A> C;
  typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl;
  using ImplToFst<Impl>::GetImpl;

  // Decodes fst through a DECODE-typed copy of encoder, then attaches the
  // symbol tables stored in the encoder to the result.
  DecodeFst(const Fst<A> &fst, const C &encoder)
      : ArcMapFst<A, A, EncodeMapper<A> >(fst,
                                           C(encoder, DECODE),
                                           ArcMapFstOptions()) {
    GetImpl()->SetInputSymbols(encoder.InputSymbols());
    GetImpl()->SetOutputSymbols(encoder.OutputSymbols());
  }

  // See Fst<>::Copy() for doc.
  DecodeFst(const DecodeFst<A> &fst, bool safe = false)
      : ArcMapFst<A, A, EncodeMapper<A> >(fst, safe) {}

  // Get a copy of this DecodeFst. See Fst<>::Copy() for further doc.
  virtual DecodeFst<A> *Copy(bool safe = false) const {
    return new DecodeFst<A>(*this, safe);
  }
};


// StateIterator specialization for EncodeFst: forwards to the iterator of
// the underlying ArcMapFst.
template <class A>
class StateIterator< EncodeFst<A> >
    : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
  typedef StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > Base;

 public:
  explicit StateIterator(const EncodeFst<A> &fst) : Base(fst) {}
};


// ArcIterator specialization for EncodeFst: forwards to the iterator of
// the underlying ArcMapFst for state s.
template <class A>
class ArcIterator< EncodeFst<A> >
    : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
  typedef ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > Base;

 public:
  ArcIterator(const EncodeFst<A> &fst, typename A::StateId s)
      : Base(fst, s) {}
};


// StateIterator specialization for DecodeFst: forwards to the iterator of
// the underlying ArcMapFst.
template <class A>
class StateIterator< DecodeFst<A> >
    : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
  typedef StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > Base;

 public:
  explicit StateIterator(const DecodeFst<A> &fst) : Base(fst) {}
};


// ArcIterator specialization for DecodeFst: forwards to the iterator of
// the underlying ArcMapFst for state s.
template <class A>
class ArcIterator< DecodeFst<A> >
    : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
  typedef ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > Base;

 public:
  ArcIterator(const DecodeFst<A> &fst, typename A::StateId s)
      : Base(fst, s) {}
};


// Useful aliases when using StdArc.
// On-the-fly encoder instantiated for StdArc.
typedef EncodeFst<StdArc> StdEncodeFst;

// On-the-fly decoder instantiated for StdArc.
typedef DecodeFst<StdArc> StdDecodeFst;

}  // namespace fst

#endif  // FST_LIB_ENCODE_H__