// compose.h
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
// \file
// Class to compute the composition of two FSTs

#ifndef FST_LIB_COMPOSE_H__
#define FST_LIB_COMPOSE_H__

#include <algorithm>

#include <ext/hash_map>
using __gnu_cxx::hash_map;

#include "fst/lib/cache.h"
#include "fst/lib/test-properties.h"

namespace fst {

// Enumeration of uint64 bits used to represent the user-defined
// properties of FST composition (in the template parameter to
// ComposeFstOptions<T>). The bits stand for extensions of generic FST
// composition. ComposeFstOptions<> (all the bits unset) is the "plain"
// compose without any extra extensions.
enum ComposeTypes {
  // Bits 0-5 are user-settable "special symbol" extensions. The
  // constructor of ComposeFstImpl switches on these bits and aborts
  // (LOG(FATAL)) when more than one of them is set.

  // RHO: flags dealing with a special "rest" symbol in the FSTs.
  // NB: at most one of the bits COMPOSE_FST1_RHO, COMPOSE_FST2_RHO
  // may be set.
  COMPOSE_FST1_RHO    = 1ULL<<0,  // "Rest" symbol on the output side of fst1.
  COMPOSE_FST2_RHO    = 1ULL<<1,  // "Rest" symbol on the input side of fst2.
  // PHI: flags dealing with a special "failure" symbol in the FSTs.
  COMPOSE_FST1_PHI    = 1ULL<<2,  // "Failure" symbol on the output
                                  // side of fst1.
  COMPOSE_FST2_PHI    = 1ULL<<3,  // "Failure" symbol on the input side
                                  // of fst2.
  // SIGMA: flags dealing with a special "any" symbol in the FSTs.
  COMPOSE_FST1_SIGMA  = 1ULL<<4,  // "Any" symbol on the output side of
                                  // fst1.
  COMPOSE_FST2_SIGMA  = 1ULL<<5,  // "Any" symbol on the input side of
                                  // fst2.
  // Optimization related bits.
  COMPOSE_GENERIC     = 1ULL<<32,  // Disables optimizations, applies
                                   // the generic version of the
                                   // composition algorithm. This flag
                                   // is used for internal testing
                                   // only.

  // -----------------------------------------------------------------
  // Auxiliary enum values denoting specific combinations of
  // bits. Internal use only.
  COMPOSE_RHO         = COMPOSE_FST1_RHO | COMPOSE_FST2_RHO,
  COMPOSE_PHI         = COMPOSE_FST1_PHI | COMPOSE_FST2_PHI,
  COMPOSE_SIGMA       = COMPOSE_FST1_SIGMA | COMPOSE_FST2_SIGMA,
  // Union of all six special-symbol bits.
  COMPOSE_SPECIAL_SYMBOLS = COMPOSE_RHO | COMPOSE_PHI | COMPOSE_SIGMA,

  // -----------------------------------------------------------------
  // The following bits, denoting specific optimizations, are
  // typically set *internally* by the composition algorithm
  // (see ComposeFst::Init).
  COMPOSE_FST1_STRING = 1ULL<<33,  // fst1 is a string
  COMPOSE_FST2_STRING = 1ULL<<34,  // fst2 is a string
  COMPOSE_FST1_DET    = 1ULL<<35,  // fst1 is deterministic
  COMPOSE_FST2_DET    = 1ULL<<36,  // fst2 is deterministic
  // Selects the upper 32 bits (COMPOSE_GENERIC and the optimization
  // bits above). ComposeFstImpl masks its flags with this value to pick
  // the ComposeStateTable specialization.
  COMPOSE_INTERNAL_MASK    = 0xffffffff00000000ULL
};


// Options for ComposeFst. The template argument T is a bit vector of
// ComposeTypes flags selecting composition extensions at compile time;
// the run-time payload is just the inherited CacheOptions.
template <uint64 T = 0ULL>
struct ComposeFstOptions : public CacheOptions {
  // Leaves the inherited cache options at their defaults.
  ComposeFstOptions() { }

  // Adopts the given cache options unchanged.
  explicit ComposeFstOptions(const CacheOptions &opts) : CacheOptions(opts) {}
};


// Abstract base for the implementation of delayed ComposeFst. The
// concrete specializations are templated on the (uint64-valued)
// properties of the FSTs being composed.
// Owns private copies of the two input FSTs and implements the
// caching front end (Start, Final, NumArcs, ...) of a delayed
// composition. Each accessor computes its result on first use by
// calling one of the pure virtual hooks (ComputeStart, ComputeFinal,
// Expand) and afterwards answers from the cache.
template <class A>
class ComposeFstImplBase : public CacheImpl<A> {
 public:
  using FstImpl<A>::SetType;
  using FstImpl<A>::SetProperties;
  using FstImpl<A>::Properties;
  using FstImpl<A>::SetInputSymbols;
  using FstImpl<A>::SetOutputSymbols;

  using CacheBaseImpl< CacheState<A> >::HasStart;
  using CacheBaseImpl< CacheState<A> >::HasFinal;
  using CacheBaseImpl< CacheState<A> >::HasArcs;

  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef CacheState<A> State;

  // Copies both input FSTs (via Fst::Copy), derives the properties of
  // the result from the already-known properties of the inputs, and
  // takes the input symbols from fst1 and the output symbols from fst2.
  // Aborts if the output symbols of fst1 are not compatible with the
  // input symbols of fst2.
  ComposeFstImplBase(const Fst<A> &fst1,
                     const Fst<A> &fst2,
                     const CacheOptions &opts)
      :CacheImpl<A>(opts), fst1_(fst1.Copy()), fst2_(fst2.Copy()) {
    SetType("compose");
    // 'false': use only properties that are already known; do not
    // trigger a property computation here.
    uint64 props1 = fst1.Properties(kFstProperties, false);
    uint64 props2 = fst2.Properties(kFstProperties, false);
    SetProperties(ComposeProperties(props1, props2), kCopyProperties);

    if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols()))
      LOG(FATAL) << "ComposeFst: output symbol table of 1st argument "
                 << "does not match input symbol table of 2nd argument";

    SetInputSymbols(fst1.InputSymbols());
    SetOutputSymbols(fst2.OutputSymbols());
  }

  // Releases the copies made in the constructor.
  virtual ~ComposeFstImplBase() {
    delete fst1_;
    delete fst2_;
  }

  // Returns the start state, computing and caching it on first use.
  // Nothing is cached when ComputeStart() yields kNoStateId.
  StateId Start() {
    if (!HasStart()) {
      StateId start = ComputeStart();
      if (start != kNoStateId) {
        // 'this->' makes the call a dependent name, so it is looked up
        // in the dependent base class at instantiation time.
        this->SetStart(start);
      }
    }
    return CacheImpl<A>::Start();
  }

  // Returns the final weight of state s, computing and caching it on
  // first use.
  Weight Final(StateId s) {
    if (!HasFinal(s)) {
      Weight final = ComputeFinal(s);
      // See Start() for why the call is qualified with 'this->'.
      this->SetFinal(s, final);
    }
    return CacheImpl<A>::Final(s);
  }

  // Computes and caches the outgoing arcs of state s.
  virtual void Expand(StateId s) = 0;

  // The accessors below expand state s on demand before delegating to
  // the cache.
  size_t NumArcs(StateId s) {
    if (!HasArcs(s))
      Expand(s);
    return CacheImpl<A>::NumArcs(s);
  }

  size_t NumInputEpsilons(StateId s) {
    if (!HasArcs(s))
      Expand(s);
    return CacheImpl<A>::NumInputEpsilons(s);
  }

  size_t NumOutputEpsilons(StateId s) {
    if (!HasArcs(s))
      Expand(s);
    return CacheImpl<A>::NumOutputEpsilons(s);
  }

  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    if (!HasArcs(s))
      Expand(s);
    CacheImpl<A>::InitArcIterator(s, data);
  }

  // Access to flags encoding compose options/optimizations etc.  (for
  // debugging).
  virtual uint64 ComposeFlags() const = 0;

 protected:
  // Hooks implemented by the concrete composition algorithm.
  virtual StateId ComputeStart() = 0;
  virtual Weight ComputeFinal(StateId s) = 0;

  const Fst<A> *fst1_;            // first input Fst (owned copy)
  const Fst<A> *fst2_;            // second input Fst (owned copy)
};


// The following class encapsulates implementation-dependent details
// of state tuple lookup, i.e. a bijective mapping from triples of two
// FST states and an epsilon filter state to the corresponding state
// IDs of the fst resulting from composition. The mapping must
// implement the [] operator in the style of STL associative
// containers (map, hash_map), i.e. table[x] must return a reference
// to the value associated with x. If x is an unassigned tuple, the
// operator must automatically associate x with value 0.
//
// NB: "table[x] == 0" for unassigned tuples x is required by the
// following off-by-one device used in the implementation of
// ComposeFstImpl. The value stored in the table is equal to tuple ID
// plus one, i.e. it is always a strictly positive number. Therefore,
// table[x] is equal to 0 if and only if x is an unassigned tuple (in
// which case the algorithm assigns a new ID to x, and sets table[x] -
// stored in a reference - to "new ID + 1"). This form of lookup is
// more efficient than calling "find(x)" and "insert(make_pair(x, new
// ID))" if x is an unassigned tuple.
//
// The generic implementation is a wrapper around a hash_map.
// Generic state tuple lookup table: a hash_map from StateTuple to a
// StateId-typed value. The second template argument only selects
// between this primary template and its specializations.
template <class A, uint64 T>
class ComposeStateTable {
 public:
  typedef typename A::StateId StateId;

  // Key of the table: a pair of input states plus the epsilon filter
  // state. The default constructor leaves the members uninitialized.
  struct StateTuple {
    StateTuple() {}
    StateTuple(StateId s1, StateId s2, int f)
        : state_id1(s1), state_id2(s2), filt(f) {}
    StateId state_id1;  // state Id on fst1
    StateId state_id2;  // state Id on fst2
    int filt;           // epsilon filter state
  };

  ComposeStateTable() {}

  // NB: if 'tuple' is not in 'table_', the pair (tuple, StateId()) is
  // inserted into 'table_' (standard STL container semantics). Since
  // StateId is a built-in type, the explicit default constructor call
  // StateId() returns 0.
  StateId &operator[](const StateTuple &tuple) {
    return table_[tuple];
  }

 private:
  // Comparison object for hashing StateTuple(s): two tuples are equal
  // iff all three components are equal.
  class StateTupleEqual {
   public:
    bool operator()(const StateTuple& x, const StateTuple& y) const {
      return x.state_id1 == y.state_id1 &&
             x.state_id2 == y.state_id2 &&
             x.filt == y.filt;
    }
  };

  // Multipliers used to mix the tuple components in the hash.
  static const int kPrime0 = 7853;
  static const int kPrime1 = 7867;

  // Hash function for StateTuple to Fst states.
  class StateTupleKey {
   public:
    size_t operator()(const StateTuple& x) const {
      // Each component is widened to size_t *before* multiplying, so
      // the arithmetic wraps modulo 2^N instead of overflowing a signed
      // type (undefined behaviour) for large state IDs.
      return static_cast<size_t>(x.state_id1) +
             static_cast<size_t>(x.state_id2) * kPrime0 +
             static_cast<size_t>(x.filt) * kPrime1;
    }
  };

  // Lookup table mapping state tuples to state IDs.
  typedef hash_map<StateTuple,
                         StateId,
                         StateTupleKey,
                         StateTupleEqual> StateTable;
  // Actual table data.
  StateTable table_;

  DISALLOW_EVIL_CONSTRUCTORS(ComposeStateTable);
};


// State tuple lookup table for the composition of a string FST with a
// deterministic FST.  The class maps state tuples to their unique IDs
// (i.e. states of the ComposeFst). Main optimization: due to the
// 1-to-1 correspondence between the states of the input string FST
// and those of the resulting (string) FST, a state tuple (s1, s2) is
// simply mapped to StateId s1. Hence, we use an STL vector as a
// lookup table. Template argument Fst1IsString specifies which FST is
// a string (this determines whether or not we index the lookup table
// by the first or by the second state).
// Vector-backed tuple table. Only one component of the tuple is used
// as the key: state_id1 when Fst1IsString is true, state_id2 otherwise.
template <class A, bool Fst1IsString>
class StringDetComposeStateTable {
 public:
  typedef typename A::StateId StateId;

  // Same shape as ComposeStateTable::StateTuple, except that the
  // epsilon filter state is a compile-time constant 0 and the filter
  // argument of the constructor is ignored.
  struct StateTuple {
    typedef typename A::StateId StateId;
    StateTuple() {}
    StateTuple(StateId s1, StateId s2, int /* f */)
        : state_id1(s1), state_id2(s2) {}
    StateId state_id1;  // state Id on fst1
    StateId state_id2;  // state Id on fst2
    static const int filt = 0;  // placeholder epsilon filter state,
                                // present only so that the tuple API
                                // matches the generic one
  };

  StringDetComposeStateTable() {}

  // Map-like subscript: yields a reference to the slot belonging to
  // 'tuple'. The backing vector is grown on demand, and every slot
  // created by the growth starts out as 0 ("unassigned").
  StateId &operator[](const StateTuple &tuple) {
    StateId slot;
    if (Fst1IsString)
      slot = tuple.state_id1;
    else
      slot = tuple.state_id2;
    const StateId num_slots = static_cast<StateId>(data_.size());
    if (slot >= num_slots)
      data_.resize(slot + 1);
    return data_[slot];
  }

 private:
  vector<StateId> data_;  // slot i holds the value stored for key i

  DISALLOW_EVIL_CONSTRUCTORS(StringDetComposeStateTable);
};


// Specializations of ComposeStateTable for the string/det case.
// Both inherit from StringDetComposeStateTable.
// fst1 is a string and fst2 is deterministic: the table is indexed by
// the fst1 state (Fst1IsString == true).
template <class A>
class ComposeStateTable<A, COMPOSE_FST1_STRING | COMPOSE_FST2_DET>
    : public StringDetComposeStateTable<A, true> { };

// fst2 is a string and fst1 is deterministic: the table is indexed by
// the fst2 state (Fst1IsString == false).
template <class A>
class ComposeStateTable<A, COMPOSE_FST2_STRING | COMPOSE_FST1_DET>
    : public StringDetComposeStateTable<A, false> { };


// Parameterized implementation of FST composition for a pair of FSTs
// matching the property bit vector T. If possible,
// instantiation-specific switches in the code are based on the values
// of the bits in T, which are known at compile time, so unused code
// should be optimized away by the compiler.
// Concrete delayed-composition implementation for the compile-time
// flag set T (a bit vector of ComposeTypes values). States of the
// result are numbered in order of discovery; state_tuples_ maps an ID
// back to its (fst1 state, fst2 state, filter state) tuple and
// state_tuple_table_ maps a tuple to "ID + 1".
template <class A, uint64 T>
class ComposeFstImpl : public ComposeFstImplBase<A> {
  typedef typename A::StateId StateId;
  typedef typename A::Label   Label;
  typedef typename A::Weight  Weight;
  using FstImpl<A>::SetType;
  using FstImpl<A>::SetProperties;

  enum FindType { FIND_INPUT  = 1,          // find input label on fst2
                  FIND_OUTPUT = 2,          // find output label on fst1
                  FIND_BOTH   = 3 };        // find choice state dependent

  // Only the internal (upper 32) bits of T select the table flavour.
  typedef ComposeStateTable<A, T & COMPOSE_INTERNAL_MASK> StateTupleTable;
  typedef typename StateTupleTable::StateTuple StateTuple;

 public:
  // Verifies the sortedness requirements implied by T and decides on
  // which side labels are searched (find_type_). Aborts via LOG(FATAL)
  // if more than one special-symbol bit is set in T, if the side
  // carrying a special symbol is not sorted, or if neither side is
  // sorted.
  ComposeFstImpl(const Fst<A> &fst1,
                 const Fst<A> &fst2,
                 const CacheOptions &opts)
      :ComposeFstImplBase<A>(fst1, fst2, opts) {

    // Start from the already-known sort properties ('false': no test).
    bool osorted = fst1.Properties(kOLabelSorted, false);
    bool isorted = fst2.Properties(kILabelSorted, false);

    switch (T & COMPOSE_SPECIAL_SYMBOLS) {
      case COMPOSE_FST1_RHO:
      case COMPOSE_FST1_PHI:
      case COMPOSE_FST1_SIGMA:
        // Special symbol on fst1: its output side must be sorted.
        if (!osorted || FLAGS_fst_verify_properties)
          osorted = fst1.Properties(kOLabelSorted, true);
        if (!osorted)
          LOG(FATAL) << "ComposeFst: 1st argument not output label "
                     << "sorted (special symbols present)";
        break;
      case COMPOSE_FST2_RHO:
      case COMPOSE_FST2_PHI:
      case COMPOSE_FST2_SIGMA:
        // Special symbol on fst2: its input side must be sorted.
        if (!isorted || FLAGS_fst_verify_properties)
          isorted = fst2.Properties(kILabelSorted, true);
        if (!isorted)
          LOG(FATAL) << "ComposeFst: 2nd argument not input label "
                     << "sorted (special symbols present)";
        break;
      case 0:
        // No special symbols: test explicitly only if neither side is
        // known to be sorted (or if verification is requested); fst2 is
        // tested only when fst1 turns out to be unsorted.
        if ((!isorted && !osorted) || FLAGS_fst_verify_properties) {
          osorted = fst1.Properties(kOLabelSorted, true);
          if (!osorted)
            isorted = fst2.Properties(kILabelSorted, true);
        }
        break;
      default:
        LOG(FATAL)
          << "ComposeFst: More than one special symbol used in composition";
    }

    // Pick the search side. The side carrying a special symbol takes
    // precedence, then the string optimizations, then plain sortedness.
    if (isorted && (T & COMPOSE_FST2_SIGMA)) {
      find_type_ = FIND_INPUT;
    } else if (osorted && (T & COMPOSE_FST1_SIGMA)) {
      find_type_ = FIND_OUTPUT;
    } else if (isorted && (T & COMPOSE_FST2_PHI)) {
      find_type_ = FIND_INPUT;
    } else if (osorted && (T & COMPOSE_FST1_PHI)) {
      find_type_ = FIND_OUTPUT;
    } else if (isorted && (T & COMPOSE_FST2_RHO)) {
      find_type_ = FIND_INPUT;
    } else if (osorted && (T & COMPOSE_FST1_RHO)) {
      find_type_ = FIND_OUTPUT;
    } else if (isorted && (T & COMPOSE_FST1_STRING)) {
      find_type_ = FIND_INPUT;
    } else if (osorted && (T & COMPOSE_FST2_STRING)) {
      find_type_ = FIND_OUTPUT;
    } else if (isorted && osorted) {
      find_type_ = FIND_BOTH;
    } else if (isorted) {
      find_type_ = FIND_INPUT;
    } else if (osorted) {
      find_type_ = FIND_OUTPUT;
    } else {
      LOG(FATAL) << "ComposeFst: 1st argument not output label sorted "
                 << "and 2nd argument is not input label sorted";
    }
  }

  // Finds/creates an Fst state given a StateTuple.  Only creates a new
  // state if StateTuple is not found in the state hash.
  //
  // The method exploits the following device: all pairs stored in the
  // associative container state_tuple_table_ are of the form (tuple,
  // id(tuple) + 1), i.e. state_tuple_table_[tuple] > 0 if tuple has
  // been stored previously. For unassigned tuples, the call to
  // state_tuple_table_[tuple] creates a new pair (tuple, 0). As a
  // result, state_tuple_table_[tuple] == 0 iff tuple is new.
  StateId FindState(const StateTuple& tuple) {
    StateId &assoc_value = state_tuple_table_[tuple];
    if (assoc_value == 0) {  // tuple wasn't present in lookup table:
                             // assign it a new ID.
      state_tuples_.push_back(tuple);
      assoc_value = state_tuples_.size();
    }
    return assoc_value - 1;  // NB: assoc_value = ID + 1
  }

  // Generates arc for composition state s from matched input Fst arcs.
  // arca is the arc of the FST that was searched (FSTA in
  // OrderedExpand), arcb the arc of the other FST. When find_input is
  // true FSTA is fst2, so the roles of the two arcs are swapped when
  // building the result arc and the destination tuple. f is the
  // epsilon filter state of the destination.
  void AddArc(StateId s, const A &arca, const A &arcb, int f,
              bool find_input) {
    A arc;
    if (find_input) {
      arc.ilabel = arcb.ilabel;
      arc.olabel = arca.olabel;
      arc.weight = Times(arcb.weight, arca.weight);
      StateTuple tuple(arcb.nextstate, arca.nextstate, f);
      arc.nextstate = FindState(tuple);
    } else {
      arc.ilabel = arca.ilabel;
      arc.olabel = arcb.olabel;
      arc.weight = Times(arca.weight, arcb.weight);
      StateTuple tuple(arca.nextstate, arcb.nextstate, f);
      arc.nextstate = FindState(tuple);
    }
    CacheImpl<A>::AddArc(s, arc);
  }

  // Arranges it so that the first arg to OrderedExpand is the Fst
  // that will be passed to FindLabel. FIND_BOTH is currently handled
  // like FIND_OUTPUT.
  void Expand(StateId s) {
    // The components are copied out up front: OrderedExpand may append
    // to state_tuples_, which can invalidate the reference.
    StateTuple &tuple = state_tuples_[s];
    StateId s1 = tuple.state_id1;
    StateId s2 = tuple.state_id2;
    int f = tuple.filt;
    if (find_type_ == FIND_INPUT)
      OrderedExpand(s, ComposeFstImplBase<A>::fst2_, s2,
                    ComposeFstImplBase<A>::fst1_, s1, f, true);
    else
      OrderedExpand(s, ComposeFstImplBase<A>::fst1_, s1,
                    ComposeFstImplBase<A>::fst2_, s2, f, false);
  }

  // Access to flags encoding compose options/optimizations etc.  (for
  // debugging).
  virtual uint64 ComposeFlags() const { return T; }

 private:
  // This does the actual matching of labels in the composition. The
  // arguments are ordered so FindLabel is called with state SA of
  // FSTA for each arc leaving state SB of FSTB. The FIND_INPUT arg
  // determines whether the input or output label of arcs at SA is
  // the one to match on (arcs at SB are matched on the opposite
  // side). F is the epsilon filter state of S.
  void OrderedExpand(StateId s, const Fst<A> *fsta, StateId sa,
                     const Fst<A> *fstb, StateId sb, int f, bool find_input) {

    size_t numarcsa = fsta->NumArcs(sa);
    size_t numepsa = find_input ? fsta->NumInputEpsilons(sa) :
                     fsta->NumOutputEpsilons(sa);
    bool finala = fsta->Final(sa) != Weight::Zero();
    ArcIterator< Fst<A> > aitera(*fsta, sa);
    // First handle special epsilons and sigmas on FSTA. The scan stops
    // at the first positive label.
    for (; !aitera.Done(); aitera.Next()) {
      const A &arca = aitera.Value();
      Label match_labela = find_input ? arca.ilabel : arca.olabel;
      if (match_labela > 0) {
        break;
      }
      if ((T & COMPOSE_SIGMA) != 0 &&  match_labela == kSigmaLabel) {
        // Found a sigma? Match it against all (non-special) symbols
        // on side b.
        for (ArcIterator< Fst<A> > aiterb(*fstb, sb);
             !aiterb.Done();
             aiterb.Next()) {
          const A &arcb = aiterb.Value();
          Label labelb = find_input ? arcb.olabel : arcb.ilabel;
          if (labelb <= 0) continue;
          AddArc(s, arca, arcb, 0, find_input);
        }
      } else if (f == 0 && match_labela == 0) {
        // Epsilon on FSTA, allowed only in filter state 0: FSTB stays
        // in SB.
        A earcb(0, 0, Weight::One(), sb);
        AddArc(s, arca, earcb, 0, find_input);  // move forward on epsilon
      }
    }
    // Next handle non-epsilon matches, rho labels, and epsilons on FSTB
    for (ArcIterator< Fst<A> > aiterb(*fstb, sb);
         !aiterb.Done();
         aiterb.Next()) {
      const A &arcb = aiterb.Value();
      Label match_labelb = find_input ? arcb.olabel : arcb.ilabel;
      if (match_labelb) {  // Consider non-epsilon match
        if (FindLabel(&aitera, numarcsa, match_labelb, find_input)) {
          // aitera now points at the first arc carrying the label;
          // emit one result arc per matching arc.
          for (; !aitera.Done(); aitera.Next()) {
            const A &arca = aitera.Value();
            Label match_labela = find_input ? arca.ilabel : arca.olabel;
            if (match_labela != match_labelb)
              break;
            AddArc(s, arca, arcb, 0, find_input);  // move forward on match
          }
        } else if ((T & COMPOSE_SPECIAL_SYMBOLS) != 0) {
          // If there is no transition labelled 'match_labelb' in
          // fsta, try matching 'match_labelb' against special symbols
          // (Phi, Rho,...).
          for (aitera.Reset(); !aitera.Done(); aitera.Next()) {
            A arca = aitera.Value();
            Label labela = find_input ? arca.ilabel : arca.olabel;
            if (labela >= 0) {
              break;
            } else if (((T & COMPOSE_PHI) != 0) && (labela == kPhiLabel)) {
              // Case 1: if a failure transition exists, follow its
              // transitive closure until a) a transition labelled
              // 'match_labelb' is found, or b) the initial state of
              // fsta is reached.

              StateId sf = sa;  // Start of current failure transition.
              // Set once sub-case 1a has produced an arc, so that
              // sub-case 1b below is not applied on top of it.
              bool matched = false;
              while (labela == kPhiLabel && sf != arca.nextstate) {
                sf = arca.nextstate;

                size_t numarcsf = fsta->NumArcs(sf);
                ArcIterator< Fst<A> > aiterf(*fsta, sf);
                if (FindLabel(&aiterf, numarcsf, match_labelb, find_input)) {
                  // Sub-case 1a: there exists a transition starting
                  // in sf and consuming symbol 'match_labelb'.
                  AddArc(s, aiterf.Value(), arcb, 0, find_input);
                  matched = true;
                  break;
                } else {
                  // No transition labelled 'match_labelb' found: try
                  // next failure transition (starting at 'sf').
                  for (aiterf.Reset(); !aiterf.Done(); aiterf.Next()) {
                    arca = aiterf.Value();
                    labela = find_input ? arca.ilabel : arca.olabel;
                    if (labela >= kPhiLabel) break;
                  }
                }
              }
              if (!matched && labela == kPhiLabel && sf == arca.nextstate) {
                // Sub-case 1b: failure transitions lead to start
                // state without finding a matching
                // transition. Therefore, we generate a loop in start
                // state of fsta.
                A loop(match_labelb, match_labelb, Weight::One(), sf);
                AddArc(s, loop, arcb, 0, find_input);
              }
            } else if (((T & COMPOSE_RHO) != 0) && (labela == kRhoLabel)) {
              // Case 2: 'match_labelb' can be matched against a
              // "rest" (rho) label in fsta. The rho on the matched
              // side is rewritten to the actual label; a rho on the
              // opposite side of the same arc is rewritten too.
              if (find_input) {
                arca.ilabel = match_labelb;
                if (arca.olabel == kRhoLabel)
                  arca.olabel = match_labelb;
              } else {
                arca.olabel = match_labelb;
                if (arca.ilabel == kRhoLabel)
                  arca.ilabel = match_labelb;
              }
              AddArc(s, arca, arcb, 0, find_input);  // move fwd on match
            }
          }
        }
      } else if (numepsa != numarcsa || finala) {  // Handle FSTB epsilon
        // FSTA stays in SA. The destination filter state is 1 when SA
        // has epsilon arcs on the matched side, 0 otherwise.
        A earca(0, 0, Weight::One(), sa);
        AddArc(s, earca, arcb, numepsa > 0, find_input);  // move on epsilon
      }
    }
    // 'this->' makes the call a dependent name, so it is looked up in
    // the dependent base class at instantiation time.
    this->SetArcs(s);
  }


  // Finds matches to MATCH_LABEL in arcs given by AITER
  // using FIND_INPUT to determine whether to look on input or output.
  // The first NUMARCS arcs must be sorted on that side. On success
  // returns true with AITER positioned at the first arc carrying
  // MATCH_LABEL; otherwise returns false and leaves AITER at an
  // unspecified position.
  bool FindLabel(ArcIterator< Fst<A> > *aiter, size_t numarcs,
                 Label match_label, bool find_input) {
    // binary search for match
    size_t low = 0;
    size_t high = numarcs;
    while (low < high) {
      size_t mid = (low + high) / 2;
      aiter->Seek(mid);
      Label label = find_input ?
                    aiter->Value().ilabel : aiter->Value().olabel;
      if (label > match_label) {
        high = mid;
      } else if (label < match_label) {
        low = mid + 1;
      } else {
        // find first matching label (when non-determinism): walk
        // backwards from mid until the label changes or low is reached.
        for (size_t i = mid; i > low; --i) {
          aiter->Seek(i - 1);
          label = find_input ? aiter->Value().ilabel : aiter->Value().olabel;
          if (label != match_label) {
            aiter->Seek(i);
            return true;
          }
        }
        return true;
      }
    }
    return false;
  }

  // The start state is the tuple of both start states in filter state
  // 0; there is none if either input has no start state.
  StateId ComputeStart() {
    StateId s1 = ComposeFstImplBase<A>::fst1_->Start();
    StateId s2 = ComposeFstImplBase<A>::fst2_->Start();
    if (s1 == kNoStateId || s2 == kNoStateId)
      return kNoStateId;
    StateTuple tuple(s1, s2, 0);
    return FindState(tuple);
  }

  // The final weight of a composed state is the product of the final
  // weights of its two component states.
  Weight ComputeFinal(StateId s) {
    StateTuple &tuple = state_tuples_[s];
    Weight final = Times(ComposeFstImplBase<A>::fst1_->Final(tuple.state_id1),
                         ComposeFstImplBase<A>::fst2_->Final(tuple.state_id2));
    return final;
  }


  FindType find_type_;            // find label on which side?

  // Maps from StateId to StateTuple.
  vector<StateTuple> state_tuples_;

  // Maps from StateTuple to StateId.
  StateTupleTable state_tuple_table_;

  DISALLOW_EVIL_CONSTRUCTORS(ComposeFstImpl);
};


// Computes the composition of two transducers. This version is a
// delayed Fst. If FST1 transduces string x to y with weight a and FST2
// transduces y to z with weight b, then their composition transduces
// string x to z with weight Times(a, b).
//
// The output labels of the first transducer or the input labels of
// the second transducer must be sorted.  The weights need to form a
// commutative semiring (valid for TropicalWeight and LogWeight).
//
// Complexity:
// Assuming the first FST is unsorted and the second is sorted:
// - Time: O(v1 v2 d1 (log d2 + m2)),
// - Space: O(v1 v2)
// where vi = # of states visited, di = maximum out-degree, and mi the
// maximum multiplicity of the states visited for the ith
// FST. Constant time and space to visit an input state or arc is
// assumed and exclusive of caching.
//
// Caveats:
// - ComposeFst does not trim its output (since it is a delayed operation).
// - The efficiency of composition can be strongly affected by several factors:
//   - the choice of which transducer is sorted - prefer sorting the FST
//     that has the greater average out-degree.
//   - the amount of non-determinism
//   - the presence and location of epsilon transitions - avoid epsilon
//     transitions on the output side of the first transducer or
//     the input side of the second transducer or prefer placing
//     them later in a path since they delay matching and can
//     introduce non-coaccessible states and transitions.
// Delayed composition FST. The object is a thin handle on a
// reference-counted implementation: copies share the implementation,
// and the last handle to go away deletes it.
template <class A>
class ComposeFst : public Fst<A> {
 public:
  friend class ArcIterator< ComposeFst<A> >;
  friend class CacheStateIterator< ComposeFst<A> >;
  friend class CacheArcIterator< ComposeFst<A> >;

  typedef A Arc;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef CacheState<A> State;

  // Plain composition with default options.
  ComposeFst(const Fst<A> &fst1, const Fst<A> &fst2)
      : impl_(Init(fst1, fst2, ComposeFstOptions<>())) { }

  // Composition with user-supplied options; T carries the ComposeTypes
  // flags.
  template <uint64 T>
  ComposeFst(const Fst<A> &fst1,
             const Fst<A> &fst2,
             const ComposeFstOptions<T> &opts)
      : impl_(Init(fst1, fst2, opts)) { }

  // Shares the implementation of 'fst' and bumps its reference count.
  ComposeFst(const ComposeFst<A> &fst) : Fst<A>(fst), impl_(fst.impl_) {
    impl_->IncrRefCount();
  }

  // Deletes the implementation once DecrRefCount() reports zero.
  virtual ~ComposeFst() { if (!impl_->DecrRefCount()) delete impl_;  }

  virtual StateId Start() const { return impl_->Start(); }

  virtual Weight Final(StateId s) const { return impl_->Final(s); }

  virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }

  virtual size_t NumInputEpsilons(StateId s) const {
    return impl_->NumInputEpsilons(s);
  }

  virtual size_t NumOutputEpsilons(StateId s) const {
    return impl_->NumOutputEpsilons(s);
  }

  // With test == true the requested properties are computed via
  // TestProperties and stored in the implementation; otherwise only
  // what the implementation already knows is returned.
  virtual uint64 Properties(uint64 mask, bool test) const {
    if (test) {
      uint64 known;
      uint64 tested = TestProperties(*this, mask, &known);
      impl_->SetProperties(tested, known);
      return tested & mask;
    } else {
      return impl_->Properties(mask);
    }
  }

  virtual const string& Type() const { return impl_->Type(); }

  // Returns a new handle sharing this object's implementation; the
  // caller owns the returned object.
  virtual ComposeFst<A> *Copy() const {
    return new ComposeFst<A>(*this);
  }

  virtual const SymbolTable* InputSymbols() const {
    return impl_->InputSymbols();
  }

  virtual const SymbolTable* OutputSymbols() const {
    return impl_->OutputSymbols();
  }

  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;

  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    impl_->InitArcIterator(s, data);
  }

  // Access to flags encoding compose options/optimizations etc.  (for
  // debugging).
  uint64 ComposeFlags() const { return impl_->ComposeFlags(); }

 protected:
  ComposeFstImplBase<A> *Impl() { return impl_; }

 private:
  ComposeFstImplBase<A> *impl_;  // shared, reference-counted

  // Auxiliary method encapsulating the creation of a ComposeFst
  // implementation that is appropriate for the properties of fst1 and
  // fst2. Returns a newly allocated implementation. Aborts via
  // LOG(FATAL) if the weight type is not commutative and neither input
  // is unweighted.
  template <uint64 T>
  static ComposeFstImplBase<A> *Init(
      const Fst<A> &fst1,
      const Fst<A> &fst2,
      const ComposeFstOptions<T> &opts) {

    // Filter for sort properties (forces a property check).
    uint64 sort_props_mask = kILabelSorted | kOLabelSorted;
    // Filter for optimization-related properties (does not force a
    // property-check).
    uint64 opt_props_mask =
      kString | kIDeterministic | kODeterministic | kNoIEpsilons |
      kNoOEpsilons;

    uint64 props1 = fst1.Properties(sort_props_mask, true);
    uint64 props2 = fst2.Properties(sort_props_mask, true);

    props1 |= fst1.Properties(opt_props_mask, false);
    props2 |= fst2.Properties(opt_props_mask, false);

    // A non-commutative weight type is accepted only if at least one
    // of the inputs is unweighted.
    if (!(Weight::Properties() & kCommutative)) {
      props1 |= fst1.Properties(kUnweighted, true);
      props2 |= fst2.Properties(kUnweighted, true);
      if (!(props1 & kUnweighted) && !(props2 & kUnweighted))
        LOG(FATAL) << "ComposeFst: Weight needs to be a commutative semiring: "
                   << Weight::Type();
    }

    // Case 1: flag COMPOSE_GENERIC disables optimizations.
    if (T & COMPOSE_GENERIC) {
      return new ComposeFstImpl<A, T>(fst1, fst2, opts);
    }

    const uint64 kStringDetOptProps =
      kIDeterministic | kILabelSorted | kNoIEpsilons;
    const uint64 kDetStringOptProps =
      kODeterministic | kOLabelSorted | kNoOEpsilons;

    // Case 2: fst1 is a string, fst2 is deterministic and epsilon-free.
    // The string-indexed state table keys on the fst1 state alone, so
    // special symbols on the string side (fst1) rule the case out.
    if ((props1 & kString) &&
        !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) &&
        ((props2 & kStringDetOptProps) == kStringDetOptProps)) {
      return new ComposeFstImpl<A, T | COMPOSE_FST1_STRING | COMPOSE_FST2_DET>(
          fst1, fst2, opts);
    }
    // Case 3: fst1 is deterministic and epsilon-free, fst2 is string.
    // Here the table keys on the fst2 state alone, so special symbols
    // on the string side (fst2) must rule the case out as well. The
    // fst1 flags stay excluded as before, i.e. the optimization is
    // applied only when no special symbol is in use.
    if ((props2 & kString) &&
        !(T & COMPOSE_SPECIAL_SYMBOLS) &&
        ((props1 & kDetStringOptProps) == kDetStringOptProps)) {
      return new ComposeFstImpl<A, T | COMPOSE_FST2_STRING | COMPOSE_FST1_DET>(
          fst1, fst2, opts);
    }

    // Default case: no optimizations.
    return new ComposeFstImpl<A, T>(fst1, fst2, opts);
  }

  void operator=(const ComposeFst<A> &fst);  // disallow
};


// Specialization for ComposeFst.
// State iteration over a ComposeFst is delegated entirely to
// CacheStateIterator; this specialization only forwards the FST.
template<class A>
class StateIterator< ComposeFst<A> >
    : public CacheStateIterator< ComposeFst<A> > {
 public:
  explicit StateIterator(const ComposeFst<A> &fst)
      : CacheStateIterator< ComposeFst<A> >(fst) {}
};


// Specialization for ComposeFst.
// Arc iteration over a ComposeFst is delegated to CacheArcIterator.
template <class A>
class ArcIterator< ComposeFst<A> >
    : public CacheArcIterator< ComposeFst<A> > {
 public:
  typedef typename A::StateId StateId;

  // Expands state s in the shared implementation if its arcs have not
  // been computed yet.
  ArcIterator(const ComposeFst<A> &fst, StateId s)
      : CacheArcIterator< ComposeFst<A> >(fst, s) {
    ComposeFstImplBase<A> *impl = fst.impl_;
    if (impl->HasArcs(s))
      return;
    impl->Expand(s);
  }

 private:
  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
};

// Installs a freshly allocated state iterator for this FST in 'data'.
template <class A> inline
void ComposeFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
  typedef StateIterator< ComposeFst<A> > Iterator;
  data->base = new Iterator(*this);
}

// Useful alias when using StdArc: the delayed composition FST
// instantiated for StdArc.
typedef ComposeFst<StdArc> StdComposeFst;


// Options for the Compose() function.
struct ComposeOptions {
  bool connect;  // If true, Compose() calls Connect() on its output.

  // By default the output is connected.
  ComposeOptions() : connect(true) { }

  // Deliberately not 'explicit': a plain bool converts to the options.
  ComposeOptions(bool c) : connect(c) {}
};


// Computes the composition of two transducers. This version writes
// the composed FST into a MutableFst. If FST1 transduces string x to
// y with weight a and FST2 transduces y to z with weight b, then
// their composition transduces string x to z with weight
// Times(a, b).
//
// The output labels of the first transducer or the input labels of
// the second transducer must be sorted.  The weights need to form a
// commutative semiring (valid for TropicalWeight and LogWeight).
//
// Complexity:
// Assuming the first FST is unsorted and the second is sorted:
// - Time: O(V1 V2 D1 (log D2 + M2)),
// - Space: O(V1 V2 D1 M2)
// where Vi = # of states, Di = maximum out-degree, and Mi is
// the maximum multiplicity for the ith FST.
//
// Caveats:
// - Compose trims its output.
// - The efficiency of composition can be strongly affected by several factors:
//   - the choice of which transducer is sorted - prefer sorting the FST
//     that has the greater average out-degree.
//   - the amount of non-determinism
//   - the presence and location of epsilon transitions - avoid epsilon
//     transitions on the output side of the first transducer or
//     the input side of the second transducer or prefer placing
//     them later in a path since they delay matching and can
//     introduce non-coaccessible states and transitions.
template<class Arc>
void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
             MutableFst<Arc> *ofst,
             const ComposeOptions &opts = ComposeOptions()) {
  ComposeFstOptions<> nopts;
  nopts.gc_limit = 0;  // Cache only the last state for fastest copy.
  *ofst = ComposeFst<Arc>(ifst1, ifst2, nopts);
  if (opts.connect)
    Connect(ofst);
}

}  // namespace fst

#endif  // FST_LIB_COMPOSE_H__