// expand.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Expand a PDT to an FST.

#ifndef FST_EXTENSIONS_PDT_EXPAND_H__
#define FST_EXTENSIONS_PDT_EXPAND_H__

#include <vector>
using std::vector;

#include <fst/extensions/pdt/pdt.h>
#include <fst/extensions/pdt/paren.h>
#include <fst/extensions/pdt/shortest-path.h>
#include <fst/extensions/pdt/reverse.h>
#include <fst/cache.h>
#include <fst/mutable-fst.h>
#include <fst/queue.h>
#include <fst/state-table.h>
#include <fst/test-properties.h>

namespace fst {

// Options for constructing an ExpandFst, on top of the cache options:
//   keep_parentheses  if true, arcs that push or pop the stack keep their
//                     parenthesis labels; otherwise they are relabeled 0:0.
//   stack             optional stack trie to use; when null the
//                     implementation allocates (and owns) its own.
//   state_table       optional (state, stack id) tuple table to use; when
//                     null the implementation allocates (and owns) its own.
template <class Arc>
struct ExpandFstOptions : public CacheOptions {
  bool keep_parentheses;
  PdtStack<typename Arc::StateId, typename Arc::Label> *stack;
  PdtStateTable<typename Arc::StateId, typename Arc::StateId> *state_table;

  ExpandFstOptions(
      const CacheOptions &opts = CacheOptions(),
      bool kp = false,
      PdtStack<typename Arc::StateId, typename Arc::Label> *s = 0,
      PdtStateTable<typename Arc::StateId, typename Arc::StateId> *st = 0)
      : CacheOptions(opts),
        keep_parentheses(kp),
        stack(s),
        state_table(st) {}
};

// Properties for an expanded PDT: of the input properties, only the
// acceptor, acyclicity (global and initial) and unweighted bits are kept.
inline uint64 ExpandProperties(uint64 inprops) {
  const uint64 kKeptProperties =
      kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted;
  return kKeptProperties & inprops;
}


// Implementation class for ExpandFst. Each state of the expanded FST
// corresponds to a tuple (state of the input PDT, stack id) stored in a
// PdtStateTable; arcs are computed on demand by ExpandState() and kept in
// the cache inherited from CacheImpl.
template <class A>
class ExpandFstImpl
    : public CacheImpl<A> {
 public:
  using FstImpl<A>::SetType;
  using FstImpl<A>::SetProperties;
  using FstImpl<A>::Properties;
  using FstImpl<A>::SetInputSymbols;
  using FstImpl<A>::SetOutputSymbols;

  using CacheBaseImpl< CacheState<A> >::PushArc;
  using CacheBaseImpl< CacheState<A> >::HasArcs;
  using CacheBaseImpl< CacheState<A> >::HasFinal;
  using CacheBaseImpl< CacheState<A> >::HasStart;
  using CacheBaseImpl< CacheState<A> >::SetArcs;
  using CacheBaseImpl< CacheState<A> >::SetFinal;
  using CacheBaseImpl< CacheState<A> >::SetStart;

  typedef A Arc;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef StateId StackId;
  typedef PdtStateTuple<StateId, StackId> StateTuple;

  // Builds the delayed expansion of 'fst' with respect to the parenthesis
  // label pairs 'parens'. The input FST is copied. If 'opts' supplies a
  // stack (in which case 'parens' is not used to build one) or a state
  // table, that object is used but not owned; otherwise a new one is
  // allocated here and deleted by the destructor.
  ExpandFstImpl(const Fst<A> &fst,
                const vector<pair<typename Arc::Label,
                                  typename Arc::Label> > &parens,
                const ExpandFstOptions<A> &opts)
      : CacheImpl<A>(opts), fst_(fst.Copy()),
        stack_(opts.stack ? opts.stack: new PdtStack<StateId, Label>(parens)),
        state_table_(opts.state_table ? opts.state_table :
                     new PdtStateTable<StateId, StackId>()),
        own_stack_(opts.stack == 0), own_state_table_(opts.state_table == 0),
        keep_parentheses_(opts.keep_parentheses) {
    SetType("expand");

    uint64 props = fst.Properties(kFstProperties, false);
    SetProperties(ExpandProperties(props), kCopyProperties);

    SetInputSymbols(fst.InputSymbols());
    SetOutputSymbols(fst.OutputSymbols());
  }

  // Copy constructor: copies the input FST (via Copy(true)) and the stack,
  // and starts from a new, empty, owned state table.
  // NOTE(review): an empty state table is only consistent if the CacheImpl
  // copy does not carry over cached states that refer to the old table's
  // ids -- confirm against CacheImpl's copy constructor.
  ExpandFstImpl(const ExpandFstImpl &impl)
      : CacheImpl<A>(impl),
        fst_(impl.fst_->Copy(true)),
        stack_(new PdtStack<StateId, Label>(*impl.stack_)),
        state_table_(new PdtStateTable<StateId, StackId>()),
        own_stack_(true), own_state_table_(true),
        keep_parentheses_(impl.keep_parentheses_) {
    SetType("expand");
    SetProperties(impl.Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
  }

  // Deletes the input copy, plus the stack and state table when owned.
  ~ExpandFstImpl() {
    delete fst_;
    if (own_stack_)
      delete stack_;
    if (own_state_table_)
      delete state_table_;
  }

  // The start state is the tuple (input start state, stack id 0), i.e. the
  // input start with an empty stack; kNoStateId if the input has no start.
  StateId Start() {
    if (!HasStart()) {
      StateId s = fst_->Start();
      if (s == kNoStateId)
        return kNoStateId;
      StateTuple tuple(s, 0);
      StateId start = state_table_->FindState(tuple);
      SetStart(start);
    }
    return CacheImpl<A>::Start();
  }

  // A state is final only if its input state is final and its stack is
  // empty (stack id 0); it then carries the input state's final weight.
  Weight Final(StateId s) {
    if (!HasFinal(s)) {
      const StateTuple &tuple = state_table_->Tuple(s);
      Weight w = fst_->Final(tuple.state_id);
      if (w != Weight::Zero() && tuple.stack_id == 0)
        SetFinal(s, w);
      else
        SetFinal(s, Weight::Zero());
    }
    return CacheImpl<A>::Final(s);
  }

  // The following accessors expand 's' first if its arcs are not cached.
  size_t NumArcs(StateId s) {
    if (!HasArcs(s)) {
      ExpandState(s);
    }
    return CacheImpl<A>::NumArcs(s);
  }

  size_t NumInputEpsilons(StateId s) {
    if (!HasArcs(s))
      ExpandState(s);
    return CacheImpl<A>::NumInputEpsilons(s);
  }

  size_t NumOutputEpsilons(StateId s) {
    if (!HasArcs(s))
      ExpandState(s);
    return CacheImpl<A>::NumOutputEpsilons(s);
  }

  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    if (!HasArcs(s))
      ExpandState(s);
    CacheImpl<A>::InitArcIterator(s, data);
  }

  // Computes the outgoing transitions from a state, creating new destination
  // states as needed.
  // For each input arc, the stack id reached by its input label is looked
  // up; arcs for which the stack reports -1 are dropped. Arcs that change
  // the stack id (push/pop) get 0:0 labels unless 'keep_parentheses_'.
  void ExpandState(StateId s) {
    // Copied by value: FindState() below may add tuples to the table.
    StateTuple tuple = state_table_->Tuple(s);
    for (ArcIterator< Fst<A> > aiter(*fst_, tuple.state_id);
         !aiter.Done(); aiter.Next()) {
      Arc arc = aiter.Value();
      StackId stack_id = stack_->Find(tuple.stack_id, arc.ilabel);
      if (stack_id == -1) {
        // Non-matching close parenthesis
        continue;
      } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) {
        // Stack push/pop
        arc.ilabel = arc.olabel = 0;
      }

      StateTuple ntuple(arc.nextstate, stack_id);
      arc.nextstate = state_table_->FindState(ntuple);
      PushArc(s, arc);
    }
    SetArcs(s);
  }

  // Stack trie used by the expansion.
  const PdtStack<StackId, Label> &GetStack() const { return *stack_; }

  // Table mapping expanded state ids to (input state, stack id) tuples.
  const PdtStateTable<StateId, StackId> &GetStateTable() const {
    return *state_table_;
  }

 private:
  const Fst<A> *fst_;  // Owned copy of the input PDT.

  PdtStack<StackId, Label> *stack_;                // Owned iff own_stack_.
  PdtStateTable<StateId, StackId> *state_table_;   // Owned iff own_state_table_.
  bool own_stack_;
  bool own_state_table_;
  bool keep_parentheses_;  // Keep paren labels on push/pop arcs?

  void operator=(const ExpandFstImpl<A> &);  // disallow
};

// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
// This version is a delayed Fst.  In the PDT, some transitions are
// labeled with open or close parentheses. To be interpreted as a PDT,
// the parens must balance on a path. The open-close parenthesis label
// pairs are passed in 'parens'. The expansion enforces the
// parenthesis constraints. The PDT must be expandable as an FST.
//
// This class attaches interface to implementation and handles
// reference counting, delegating most methods to ImplToFst.
template <class A>
class ExpandFst : public ImplToFst< ExpandFstImpl<A> > {
 public:
  friend class ArcIterator< ExpandFst<A> >;
  friend class StateIterator< ExpandFst<A> >;

  typedef A Arc;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef StateId StackId;
  typedef CacheState<A> State;
  typedef ExpandFstImpl<A> Impl;

  // Expands 'fst' with default ExpandFstOptions: default cache options,
  // parentheses replaced by epsilons, internally allocated stack and
  // state table.
  ExpandFst(const Fst<A> &fst,
            const vector<pair<typename Arc::Label,
                              typename Arc::Label> > &parens)
      : ImplToFst<Impl>(new Impl(fst, parens, ExpandFstOptions<A>())) {}

  // Expands 'fst' with the given options.
  ExpandFst(const Fst<A> &fst,
            const vector<pair<typename Arc::Label,
                              typename Arc::Label> > &parens,
            const ExpandFstOptions<A> &opts)
      : ImplToFst<Impl>(new Impl(fst, parens, opts)) {}

  // See Fst<>::Copy() for doc.
  ExpandFst(const ExpandFst<A> &fst, bool safe = false)
      : ImplToFst<Impl>(fst, safe) {}

  // Get a copy of this ExpandFst. See Fst<>::Copy() for further doc.
  virtual ExpandFst<A> *Copy(bool safe = false) const {
    return new ExpandFst<A>(*this, safe);
  }

  // Defined out of line, after the StateIterator specialization.
  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;

  // Delegates to the implementation, which expands 's' first if its arcs
  // are not yet cached.
  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    GetImpl()->InitArcIterator(s, data);
  }

  // Stack trie used by the expansion.
  const PdtStack<StackId, Label> &GetStack() const {
    return GetImpl()->GetStack();
  }

  // Table mapping expanded state ids to (input state, stack id) tuples.
  const PdtStateTable<StateId, StackId> &GetStateTable() const {
    return GetImpl()->GetStateTable();
  }

 private:
  // Makes visible to friends.
  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }

  void operator=(const ExpandFst<A> &fst);  // Disallow
};


// StateIterator specialization for ExpandFst: iterates over the states
// through the cache-based iterator of the underlying implementation.
template <class A>
class StateIterator< ExpandFst<A> >
    : public CacheStateIterator< ExpandFst<A> > {
  typedef CacheStateIterator< ExpandFst<A> > Base;

 public:
  explicit StateIterator(const ExpandFst<A> &fst)
      : Base(fst, fst.GetImpl()) {}
};


// ArcIterator specialization for ExpandFst: makes sure state 's' has been
// expanded by the implementation, then iterates over its cached arcs.
template <class A>
class ArcIterator< ExpandFst<A> >
    : public CacheArcIterator< ExpandFst<A> > {
 public:
  typedef typename A::StateId StateId;

  ArcIterator(const ExpandFst<A> &fst, StateId s)
      : CacheArcIterator< ExpandFst<A> >(fst.GetImpl(), s) {
    typename ExpandFst<A>::Impl *impl = fst.GetImpl();
    if (!impl->HasArcs(s)) impl->ExpandState(s);
  }

 private:
  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
};


// Installs a StateIterator<ExpandFst<A>> as the state iterator of 'data'.
template <class A>
inline void ExpandFst<A>::InitStateIterator(
    StateIteratorData<A> *data) const {
  data->base = new StateIterator< ExpandFst<A> >(*this);
}

//
// PrunedExpand Class
//

// Prunes the delayed expansion of a pushdown transducer (PDT) encoded
// as an FST into an FST.  In the PDT, some transitions are labeled
// with open or close parentheses. To be interpreted as a PDT, the
// parens must balance on a path. The open-close parenthesis label
// pairs are passed in 'parens'. The expansion enforces the
// parenthesis constraints.
//
// The algorithm works by visiting the delayed ExpandFst using a
// shortest-stack first queue discipline and relies on the
// shortest-distance information computed using a reverse
// shortest-path call to perform the pruning.
//
// The algorithm maintains the same state ordering between the ExpandFst
// being visited 'efst_' and the result of pruning written into the
// MutableFst 'ofst_' to improve readability of the code.
//
template <class A>
class PrunedExpand {
 public:
  typedef A Arc;
  typedef typename A::Label Label;
  typedef typename A::StateId StateId;
  typedef typename A::Weight Weight;
  typedef StateId StackId;
  typedef PdtStack<StackId, Label> Stack;
  typedef PdtStateTable<StateId, StackId> StateTable;
  typedef typename PdtBalanceData<Arc>::SetIterator SetIterator;

  // Constructor taking as input a PDT specified by 'ifst' and 'parens'.
  // 'keep_parentheses' specifies whether parentheses are replaced by
  // epsilons or not during the expansion. 'opts' is the cache options
  // used to instantiate the underlying ExpandFst.
  PrunedExpand(const Fst<A> &ifst,
               const vector<pair<Label, Label> > &parens,
               bool keep_parentheses = false,
               const CacheOptions &opts = CacheOptions())
      : ifst_(ifst.Copy()),
        keep_parentheses_(keep_parentheses),
        stack_(parens),
        efst_(ifst, parens,
              ExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)),
        queue_(state_table_, stack_, stack_length_, distance_, fdistance_) {
    Reverse(*ifst_, parens, &rfst_);
    VectorFst<Arc> path;
    reverse_shortest_path_ = new SP(
        rfst_, parens,
        PdtShortestPathOptions<A, FifoQueue<StateId> >(true, false));
    reverse_shortest_path_->ShortestPath(&path);
    balance_data_ = reverse_shortest_path_->GetBalanceData()->Reverse(
        rfst_.NumStates(), 10, -1);

    InitCloseParenMultimap(parens);
  }

  ~PrunedExpand() {
    delete ifst_;
    delete reverse_shortest_path_;
    delete balance_data_;
  }

  // Expands and prunes with weight threshold 'threshold' the input PDT.
  // Writes the result in 'ofst'.
  void Expand(MutableFst<A> *ofst, const Weight &threshold);

 private:
  static const uint8 kEnqueued;
  static const uint8 kExpanded;
  static const uint8 kSourceState;

  // Comparison functor used by the queue:
  // 1. states corresponding to shortest stack first,
  // 2. among stacks of the same length, reverse lexicographic order is used,
  // 3. among states with the same stack, shortest-first order is used.
  class StackCompare {
   public:
    StackCompare(const StateTable &st,
                 const Stack &s, const vector<StackId> &sl,
                 const vector<Weight> &d, const vector<Weight> &fd)
        : state_table_(st), stack_(s), stack_length_(sl),
          distance_(d), fdistance_(fd) {}

    bool operator()(StateId s1, StateId s2) const {
      StackId si1 = state_table_.Tuple(s1).stack_id;
      StackId si2 = state_table_.Tuple(s2).stack_id;
      if (stack_length_[si1] < stack_length_[si2])
        return true;
      if  (stack_length_[si1] > stack_length_[si2])
        return false;
      // If stack id equal, use A*
      if (si1 == si2) {
        Weight w1 = (s1 < distance_.size()) && (s1 < fdistance_.size()) ?
            Times(distance_[s1], fdistance_[s1]) : Weight::Zero();
        Weight w2 = (s2 < distance_.size()) && (s2 < fdistance_.size()) ?
            Times(distance_[s2], fdistance_[s2]) : Weight::Zero();
        return less_(w1, w2);
      }
      // If lenghts are equal, use reverse lexico.
      for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) {
        if (stack_.Top(si1) < stack_.Top(si2)) return true;
        if (stack_.Top(si1) > stack_.Top(si2)) return false;
      }
      return false;
    }

   private:
    const StateTable &state_table_;
    const Stack &stack_;
    const vector<StackId> &stack_length_;
    const vector<Weight> &distance_;
    const vector<Weight> &fdistance_;
    NaturalLess<Weight> less_;
  };

  class ShortestStackFirstQueue
      : public ShortestFirstQueue<StateId, StackCompare> {
   public:
    ShortestStackFirstQueue(
        const PdtStateTable<StateId, StackId> &st,
        const Stack &s,
        const vector<StackId> &sl,
        const vector<Weight> &d, const vector<Weight> &fd)
        : ShortestFirstQueue<StateId, StackCompare>(
            StackCompare(st, s, sl, d, fd)) {}
  };


  void InitCloseParenMultimap(const vector<pair<Label, Label> > &parens);
  Weight DistanceToDest(StateId state, StateId source) const;
  uint8 Flags(StateId s) const;
  void SetFlags(StateId s, uint8 flags, uint8 mask);
  Weight Distance(StateId s) const;
  void SetDistance(StateId s, Weight w);
  Weight FinalDistance(StateId s) const;
  void SetFinalDistance(StateId s, Weight w);
  StateId SourceState(StateId s) const;
  void SetSourceState(StateId s, StateId p);
  void AddStateAndEnqueue(StateId s);
  void Relax(StateId s, const A &arc, Weight w);
  bool PruneArc(StateId s, const A &arc);
  void ProcStart();
  void ProcFinal(StateId s);
  bool ProcNonParen(StateId s, const A &arc, bool add_arc);
  bool ProcOpenParen(StateId s, const A &arc, StackId si, StackId nsi);
  bool ProcCloseParen(StateId s, const A &arc);
  void ProcDestStates(StateId s, StackId si);

  Fst<A> *ifst_;                   // Input PDT
  VectorFst<Arc> rfst_;            // Reversed PDT
  bool keep_parentheses_;          // Keep parentheses in ofst?
  StateTable state_table_;         // State table for efst_
  Stack stack_;                    // Stack trie
  ExpandFst<Arc> efst_;            // Expanded PDT
  vector<StackId> stack_length_;   // Length of stack for given stack id
  vector<Weight> distance_;        // Distance from initial state in efst_/ofst
  vector<Weight> fdistance_;       // Distance to final states in efst_/ofst
  ShortestStackFirstQueue queue_;  // Queue used to visit efst_
  vector<uint8> flags_;            // Status flags for states in efst_/ofst
  vector<StateId> sources_;        // PDT source state for each expanded state

  typedef PdtShortestPath<Arc, FifoQueue<StateId> > SP;
  typedef typename SP::CloseParenMultimap ParenMultimap;
  SP *reverse_shortest_path_;  // Shortest path for rfst_
  PdtBalanceData<Arc> *balance_data_;   // Not owned by shortest_path_
  ParenMultimap close_paren_multimap_;  // Maps open paren arcs to
  // balancing close paren arcs.

  MutableFst<Arc> *ofst_;  // Output fst
  Weight limit_;           // Weight limit

  typedef unordered_map<StateId, Weight> DestMap;
  DestMap dest_map_;
  StackId current_stack_id_;
  // 'current_stack_id_' is the stack id of the states currently at the top
  // of queue, i.e., the states currently being popped and processed.
  // 'dest_map_' maps a state 's' in 'ifst_' that is the source
  // of a close parentheses matching the top of 'current_stack_id_; to
  // the shortest-distance from '(s, current_stack_id_)' to the final
  // states in 'efst_'.
  ssize_t current_paren_id_;  // Paren id at top of current stack
  ssize_t cached_stack_id_;
  StateId cached_source_;
  slist<pair<StateId, Weight> > cached_dest_list_;
  // 'cached_dest_list_' contains the set of pair of destination
  // states and weight to final states for source state
  // 'cached_source_' and paren id 'cached_paren_id': the set of
  // source state of a close parenthesis with paren id
  // 'cached_paren_id' balancing an incoming open parenthesis with
  // paren id 'cached_paren_id' in state 'cached_source_'.

  NaturalLess<Weight> less_;
};

// Distinct bits of the per-state status word kept in 'flags_'.
template <class A> const uint8 PrunedExpand<A>::kEnqueued = 1 << 0;
template <class A> const uint8 PrunedExpand<A>::kExpanded = 1 << 1;
template <class A> const uint8 PrunedExpand<A>::kSourceState = 1 << 2;


// Initializes the close paren multimap: each pair (s, paren_id) is mapped
// to every arc out of state 's' of 'ifst_' whose input label is the close
// parenthesis of the pair at index 'paren_id' in 'parens'.
template <class A>
void PrunedExpand<A>::InitCloseParenMultimap(
    const vector<pair<Label, Label> > &parens) {
  typedef unordered_map<Label, Label> ParenIdMap;
  // Both labels of each parenthesis pair map to the pair's index.
  ParenIdMap paren_id_map;
  for (size_t paren_id = 0; paren_id < parens.size(); ++paren_id) {
    paren_id_map[parens[paren_id].first] = paren_id;
    paren_id_map[parens[paren_id].second] = paren_id;
  }

  for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
    const StateId state = siter.Value();
    for (ArcIterator<Fst<Arc> > aiter(*ifst_, state);
         !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      typename ParenIdMap::const_iterator pit = paren_id_map.find(arc.ilabel);
      if (pit == paren_id_map.end()) continue;  // Not a parenthesis.
      const Label paren_id = pit->second;
      if (arc.ilabel != parens[paren_id].second) continue;  // Open paren.
      close_paren_multimap_.insert(
          make_pair(ParenState<Arc>(paren_id, state), arc));
    }
  }
}


// Returns the weight of the shortest balanced path from 'source' to 'dest'
// in 'ifst_'; 'dest' must be the source state of a close paren arc. The
// value is read from the shortest-path data computed on 'rfst_', whose
// state ids are offset by one from those of 'ifst_' (hence the '+ 1').
template <class A>
typename A::Weight PrunedExpand<A>::DistanceToDest(StateId source,
                                                   StateId dest) const {
  typename SP::SearchState search_state(source + 1, dest + 1);
  const Weight distance =
      reverse_shortest_path_->GetShortestPathData().Distance(search_state);
  VLOG(2) << "D(" << source << ", " << dest << ") =" << distance;
  return distance;
}

// Returns the status flags of state 's' in 'ofst_' (0 if never set).
template <class A>
uint8 PrunedExpand<A>::Flags(StateId s) const {
  if (s < flags_.size()) return flags_[s];
  return 0;
}

// Overwrites the bits of the flags of state 's' selected by 'mask' with the
// corresponding bits of 'flags', growing 'flags_' (with 0s) as needed.
template <class A>
void PrunedExpand<A>::SetFlags(StateId s, uint8 flags, uint8 mask) {
  while (s >= flags_.size()) flags_.push_back(0);
  uint8 &current = flags_[s];
  current = (current & ~mask) | (flags & mask);
}


// Current estimate of the shortest distance from the initial state to 's'
// in 'ofst_'; Weight::Zero() when no estimate has been recorded.
template <class A>
typename A::Weight PrunedExpand<A>::Distance(StateId s) const {
  if (s < distance_.size()) return distance_[s];
  return Weight::Zero();
}

// Records 'w' as the distance from the initial state to 's' in 'ofst_',
// padding 'distance_' with Weight::Zero() up to index 's'.
template <class A>
void PrunedExpand<A>::SetDistance(StateId s, Weight w) {
  while (s >= distance_.size()) distance_.push_back(Weight::Zero());
  distance_[s] = w;
}


// Current estimate of the shortest distance from 's' to the final states
// in 'ofst_'; Weight::Zero() when no estimate has been recorded.
template <class A>
typename A::Weight PrunedExpand<A>::FinalDistance(StateId s) const {
  if (s < fdistance_.size()) return fdistance_[s];
  return Weight::Zero();
}

// Records 'w' as the distance from 's' to the final states in 'ofst_',
// padding 'fdistance_' with Weight::Zero() up to index 's'.
template <class A>
void PrunedExpand<A>::SetFinalDistance(StateId s, Weight w) {
  while (s >= fdistance_.size()) fdistance_.push_back(Weight::Zero());
  fdistance_[s] = w;
}

// PDT "source" state (a state of 'ifst_') associated with state 's' of
// 'ofst_'; kNoStateId when none has been recorded.
template <class A>
typename A::StateId PrunedExpand<A>::SourceState(StateId s) const {
  if (s < sources_.size()) return sources_[s];
  return kNoStateId;
}

// Associates state 'p' of 'ifst_' as the PDT "source" state of state 's'
// of 'ofst_', padding 'sources_' with kNoStateId up to index 's'.
template <class A>
void PrunedExpand<A>::SetSourceState(StateId s, StateId p) {
  while (s >= sources_.size()) sources_.push_back(kNoStateId);
  sources_[s] = p;
}

// Ensures state 's' of 'efst_' exists in 'ofst_' and is scheduled:
// * already enqueued: its queue position is refreshed (Update);
// * neither enqueued nor expanded: 'ofst_' is grown to contain 's', which
//   is then enqueued and flagged kEnqueued;
// * expanded (and not enqueued): nothing is done.
template <class A>
void PrunedExpand<A>::AddStateAndEnqueue(StateId s) {
  const uint8 flags = Flags(s);
  if (flags & kEnqueued) {
    queue_.Update(s);
  } else if (!(flags & kExpanded)) {
    while (ofst_->NumStates() <= s) ofst_->AddState();
    queue_.Enqueue(s);
    SetFlags(s, kEnqueued, kEnqueued);
  }
  // TODO(allauzen): Check everything is fine when kExpanded?
}

// Relaxes arc 'arc' out of state 's' in 'ofst_':
// * when Distance(s) * arc.weight improves on the stored distance of
//   'arc.nextstate', that distance is updated and 'arc.nextstate' inherits
//   the source state of 's';
// * when 'fd' improves on the stored final distance of 'arc.nextstate',
//   that final distance is updated.
template <class A>
void PrunedExpand<A>::Relax(StateId s, const A &arc, Weight fd) {
  const StateId next = arc.nextstate;
  const Weight candidate = Times(Distance(s), arc.weight);
  if (less_(candidate, Distance(next))) {
    SetDistance(next, candidate);
    SetSourceState(next, SourceState(s));
  }
  if (less_(fd, FinalDistance(next))) SetFinalDistance(next, fd);
  VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to "
          << next << ", d[ns] = " << Distance(next)
          << ", nd = " << candidate;
}

// Returns 'true' if the arc 'arc' out of state 's' in 'efst_' needs to
// be pruned.
//
// An estimate 'fd' of the distance from 'arc.nextstate' to the final
// states is computed as the Plus, over every (dest, w) in the cached
// destination list, of DistanceToDest(input state of arc.nextstate, dest)
// times w. The list is rebuilt from 'dest_map_' only when the source state
// of 's' or the current stack id changed since the previous call. The arc
// is relaxed with 'fd', and pruned when 'limit_' is naturally less than
// Distance(s) * arc.weight * fd.
template <class A>
bool PrunedExpand<A>::PruneArc(StateId s, const A &arc) {
  VLOG(2) << "Prune ?";
  Weight fd = Weight::Zero();

  if ((cached_source_ != SourceState(s)) ||
      (cached_stack_id_ != current_stack_id_)) {
    cached_source_ = SourceState(s);
    cached_stack_id_ = current_stack_id_;
    cached_dest_list_.clear();
    if (cached_source_ != ifst_->Start()) {
      for (SetIterator set_iter =
               balance_data_->Find(current_paren_id_, cached_source_);
           !set_iter.Done(); set_iter.Next()) {
        StateId dest = set_iter.Element();
        // NOTE(review): 'iter' is dereferenced without comparing it to
        // dest_map_.end(); this assumes ProcDestStates() has already
        // recorded every such 'dest' for the current stack id -- confirm.
        typename DestMap::const_iterator iter = dest_map_.find(dest);
        cached_dest_list_.push_front(*iter);
      }
    } else {
      // TODO(allauzen): queue discipline should prevent this never
      // from happening; replace by a check.
      cached_dest_list_.push_front(
          make_pair(rfst_.Start() -1, Weight::One()));
    }
  }

  // Accumulate the estimated distance to the final states through each
  // cached destination.
  for (typename slist<pair<StateId, Weight> >::const_iterator iter =
           cached_dest_list_.begin();
       iter != cached_dest_list_.end();
       ++iter) {
    fd = Plus(fd,
              Times(DistanceToDest(state_table_.Tuple(arc.nextstate).state_id,
                                   iter->first),
                    iter->second));
  }
  Relax(s, arc, fd);
  Weight w = Times(Distance(s), Times(arc.weight, fd));
  return less_(limit_, w);
}

// Adds start state of 'efst_' to 'ofst_', enqueues it and initializes
// the distance data structures.
//
// The id 'rfst_.Start() - 1' is used as a pseudo destination for the start
// state at stack id 0 (in 'dest_map_', 'cached_dest_list_' and the state
// table); it presumably stands for the final states of the whole
// expansion -- confirm against the Reverse() state numbering.
template <class A>
void PrunedExpand<A>::ProcStart() {
  StateId s = efst_.Start();
  AddStateAndEnqueue(s);
  ofst_->SetStart(s);
  SetSourceState(s, ifst_->Start());

  // Processing starts with the empty stack (id 0, length 0); -1 is
  // presumably "no paren id" for the empty stack.
  current_stack_id_ = 0;
  current_paren_id_ = -1;
  stack_length_.push_back(0);
  dest_map_[rfst_.Start() - 1] = Weight::One(); // not needed

  // Prime the PruneArc() cache for the start state and stack id 0.
  cached_source_ = ifst_->Start();
  cached_stack_id_ = 0;
  cached_dest_list_.push_front(
          make_pair(rfst_.Start() -1, Weight::One()));

  // Register the pseudo destination with the empty stack, with final
  // distance One; the start state gets distance One and, as final
  // distance, its balanced distance to the pseudo destination.
  PdtStateTuple<StateId, StackId> tuple(rfst_.Start() - 1, 0);
  SetFinalDistance(state_table_.FindState(tuple), Weight::One());
  SetDistance(s, Weight::One());
  SetFinalDistance(s, DistanceToDest(ifst_->Start(), rfst_.Start() - 1));
  VLOG(2) << DistanceToDest(ifst_->Start(), rfst_.Start() - 1);
}

// Copies the final weight of 's' from 'efst_' to 'ofst_' when 's' is final
// and the best accepting path ending in 's' (Distance(s) times the final
// weight) is not above 'limit_'.
template <class A>
void PrunedExpand<A>::ProcFinal(StateId s) {
  const Weight final_weight = efst_.Final(s);
  const bool is_final = final_weight != Weight::Zero();
  if (is_final && !less_(limit_, Times(Distance(s), final_weight)))
    ofst_->SetFinal(s, final_weight);
}

// Processes arc (or meta-arc) 'arc' out of 's' in 'efst_'. If PruneArc()
// keeps it, 'arc' is added to 'ofst_' (only when 'add_arc') and its
// destination is scheduled. Returns whether the arc was kept.
template <class A>
bool PrunedExpand<A>::ProcNonParen(StateId s, const A &arc, bool add_arc) {
  VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate
          << ", " << arc.ilabel << ":" << arc.olabel << " / " << arc.weight
          << ", add_arc = " << (add_arc ? "true" : "false");
  const bool pruned = PruneArc(s, arc);
  if (!pruned) {
    if (add_arc) ofst_->AddArc(s, arc);
    AddStateAndEnqueue(arc.nextstate);
  }
  return !pruned;
}

// Processes an open paren arc 'arc' out of state 's' in 'ofst_'.
// When 'arc' is labeled with an open paren,
// 1. considers each (shortest) balanced path starting in 's' by
//    taking 'arc' and ending by a close paren balancing the open
//    paren of 'arc' as a meta-arc, processes and prunes each meta-arc
//    as a non-paren arc, inserting its destination to the queue;
// 2. if at least one of these meta-arcs has not been pruned,
//    adds the destination of 'arc' to 'ofst_' as a new source state
//    for the stack id 'nsi' and inserts it in the queue.
// 'si' is the stack id of 's' and 'nsi' the stack id after the push.
// Returns true if 'arc' was kept.
template <class A>
bool PrunedExpand<A>::ProcOpenParen(StateId s, const A &arc, StackId si,
                                    StackId nsi) {
  // Update the stack length when needed: |nsi| = |si| + 1.
  while (stack_length_.size() <= nsi) stack_length_.push_back(-1);
  if (stack_length_[nsi] == -1)
    stack_length_[nsi] = stack_length_[si] + 1;

  StateId ns = arc.nextstate;
  // Input-PDT state reached by 'arc'. Copied by value: FindState() below
  // may add tuples to the state table.
  const StateId ns_state_id = state_table_.Tuple(ns).state_id;
  VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id
            << ") to " << ns << "(" << ns_state_id << ")";
  bool proc_arc = false;
  Weight fd = Weight::Zero();
  ssize_t paren_id = stack_.ParenId(arc.ilabel);
  // The close paren sources are collected before any processing,
  // presumably because ProcNonParen() -> PruneArc() calls
  // balance_data_->Find() again while we would still be iterating.
  slist<StateId> sources;
  for (SetIterator set_iter = balance_data_->Find(paren_id, ns_state_id);
       !set_iter.Done(); set_iter.Next()) {
    sources.push_front(set_iter.Element());
  }
  for (typename slist<StateId>::const_iterator sources_iter = sources.begin();
       sources_iter != sources.end();
       ++sources_iter) {
    StateId source = *sources_iter;
    VLOG(2) << "Close paren source: " << source;
    // Balanced distance from 'ns' to 'source'; it does not depend on the
    // close paren arc, so it is computed once per source.
    const Weight source_distance = DistanceToDest(ns_state_id, source);
    ParenState<Arc> paren_state(paren_id, source);
    for (typename ParenMultimap::const_iterator iter =
             close_paren_multimap_.find(paren_state);
         iter != close_paren_multimap_.end() && paren_state == iter->first;
         ++iter) {
      Arc meta_arc = iter->second;
      PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si);
      meta_arc.nextstate = state_table_.FindState(tuple);
      VLOG(2) << ns_state_id << ", " << source;
      VLOG(2) << "Meta arc weight = " << arc.weight << " Times "
                << source_distance
                << " Times " << meta_arc.weight;
      // Meta-arc weight: open paren * balanced path * close paren.
      meta_arc.weight = Times(arc.weight,
                              Times(source_distance, meta_arc.weight));
      proc_arc |= ProcNonParen(s, meta_arc, false);
      fd = Plus(fd, Times(Times(source_distance, iter->second.weight),
                          FinalDistance(meta_arc.nextstate)));
    }
  }
  if (proc_arc) {
    VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate;
    ofst_->AddArc(
      s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
    AddStateAndEnqueue(arc.nextstate);
    Weight nd = Times(Distance(s), arc.weight);
    if (less_(nd, Distance(arc.nextstate)))
      SetDistance(arc.nextstate, nd);
    // FinalDistance not necessary for source state since pruning
    // decided using the meta-arcs above.  But this is a problem with
    // A*, hence:
    if (less_(fd, FinalDistance(arc.nextstate)))
      SetFinalDistance(arc.nextstate, fd);
    SetFlags(arc.nextstate, kSourceState, kSourceState);
  }
  return proc_arc;
}

// Keeps close paren arc 'arc' out of 's' when the best path through it
// (Distance(s) * arc.weight * FinalDistance(arc.nextstate)) is not above
// 'limit_'. A kept arc is written to 'ofst_', relabeled 0:0 unless
// 'keep_parentheses_'. Returns whether the arc was kept.
template <class A>
bool PrunedExpand<A>::ProcCloseParen(StateId s, const A &arc) {
  const Weight path_weight =
      Times(Distance(s), Times(arc.weight, FinalDistance(arc.nextstate)));
  const bool within_limit = !less_(limit_, path_weight);
  if (within_limit) {
    if (keep_parentheses_)
      ofst_->AddArc(s, arc);
    else
      ofst_->AddArc(s, Arc(0, 0, arc.weight, arc.nextstate));
  }
  return within_limit;
}

// When 's' in 'ofst_' is a source state for stack id 'si', identifies
// all the corresponding possible destination states, that is, all the
// states in 'ifst_' that have an outgoing close paren arc balancing
// the incoming open paren taken to get to 's', and for each such
// state 't', computes the shortest distance from (t, si) to the final
// states in 'ofst_'. Stores this information in 'dest_map_'.
//
// Does nothing unless 's' carries kSourceState. 'dest_map_' is cleared
// (and the current stack/paren ids updated) whenever 'si' differs from
// 'current_stack_id_'. The weight stored for 't' is the Plus, over the
// close paren arcs out of 't', of the arc weight times the final distance
// of the arc's destination at the popped stack id.
template <class A>
void PrunedExpand<A>::ProcDestStates(StateId s, StackId si) {
  if (!(Flags(s) & kSourceState)) return;
  if (si != current_stack_id_) {
    dest_map_.clear();
    current_stack_id_ = si;
    current_paren_id_ = stack_.Top(current_stack_id_);
    VLOG(2) << "StackID " << si << " dequeued for first time";
  }
  // TODO(allauzen): clean up source state business; rename current function to
  // ProcSourceState.
  SetSourceState(s, state_table_.Tuple(s).state_id);

  ssize_t paren_id = stack_.Top(si);
  for (SetIterator set_iter =
           balance_data_->Find(paren_id, state_table_.Tuple(s).state_id);
       !set_iter.Done(); set_iter.Next()) {
    StateId dest_state = set_iter.Element();
    // Already computed for this stack id.
    if (dest_map_.find(dest_state) != dest_map_.end())
      continue;
    Weight dest_weight = Weight::Zero();
    ParenState<Arc> paren_state(paren_id, dest_state);
    for (typename ParenMultimap::const_iterator iter =
             close_paren_multimap_.find(paren_state);
         iter != close_paren_multimap_.end() && paren_state == iter->first;
         ++iter) {
      const Arc &arc = iter->second;
      // Destination of the close paren arc with 'si' popped.
      PdtStateTuple<StateId, StackId> tuple(arc.nextstate, stack_.Pop(si));
      dest_weight = Plus(dest_weight,
                         Times(arc.weight,
                               FinalDistance(state_table_.FindState(tuple))));
    }
    dest_map_[dest_state] = dest_weight;
    VLOG(2) << "State " << dest_state << " is a dest state for stack id "
              << si << " with weight " << dest_weight;
  }
}

// Expands and prunes with weight threshold 'threshold' the input PDT.
// Writes the result in 'ofst'.
//
// 'ofst' is cleared and given the input's symbol tables. An input PDT with
// no start state yields an empty 'ofst'. All per-call bookkeeping (flags,
// distances, source states, stack lengths, destination caches) is reset
// first, so the same PrunedExpand object can be used for several calls.
// States are processed in the order given by 'queue_'; each dequeued state
// has its final weight and arcs examined, and every arc is dispatched on
// how it changes the stack id: unchanged (ProcNonParen), push
// (ProcOpenParen) or otherwise (ProcCloseParen).
template <class A>
void PrunedExpand<A>::Expand(
    MutableFst<A> *ofst, const typename A::Weight &threshold) {
  ofst_ = ofst;
  ofst_->DeleteStates();
  ofst_->SetInputSymbols(ifst_->InputSymbols());
  ofst_->SetOutputSymbols(ifst_->OutputSymbols());

  // Nothing to expand. Without this guard ProcStart() would receive
  // kNoStateId and use it as an index into the per-state vectors.
  if (ifst_->Start() == kNoStateId) return;

  // Reset state left over from a previous call (no-op on the first call).
  flags_.clear();
  distance_.clear();
  fdistance_.clear();
  sources_.clear();
  stack_length_.clear();
  dest_map_.clear();
  cached_dest_list_.clear();

  limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold);

  ProcStart();

  while (!queue_.Empty()) {
    StateId s = queue_.Head();
    queue_.Dequeue();
    SetFlags(s, kExpanded, kExpanded | kEnqueued);
    VLOG(2) << s << " dequeued!";

    ProcFinal(s);
    StackId stack_id = state_table_.Tuple(s).stack_id;
    ProcDestStates(s, stack_id);

    for (ArcIterator<ExpandFst<Arc> > aiter(efst_, s);
         !aiter.Done();
         aiter.Next()) {
      Arc arc = aiter.Value();
      StackId nextstack_id = state_table_.Tuple(arc.nextstate).stack_id;
      if (stack_id == nextstack_id)
        ProcNonParen(s, arc, true);
      else if (stack_id == stack_.Pop(nextstack_id))
        ProcOpenParen(s, arc, stack_id, nextstack_id);
      else
        ProcCloseParen(s, arc);
    }
    VLOG(2) << "d[" << s << "] = " << Distance(s)
            << ", fd[" << s << "] = " << FinalDistance(s);
  }
}

//
// Expand() Functions
//

// Options for the Expand() functions below.
template <class Arc>
struct ExpandOptions {
  bool connect;                           // Trim the result with Connect()?
  bool keep_parentheses;                  // Keep paren labels on the result?
  typename Arc::Weight weight_threshold;  // Pruning threshold; Zero() = none.

  ExpandOptions(bool c = true,
                bool k = false,
                typename Arc::Weight w = Arc::Weight::Zero())
      : connect(c),
        keep_parentheses(k),
        weight_threshold(w) {}
};

// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
// This version writes the expanded PDT result to a MutableFst.
// In the PDT, some transitions are labeled with open or close
// parentheses. To be interpreted as a PDT, the parens must balance on
// a path. The open-close parenthesis label pairs are passed in
// 'parens'. The expansion enforces the parenthesis constraints. The
// PDT must be expandable as an FST.
//
// If 'opts.weight_threshold' is Weight::Zero() the full expansion is
// copied from an ExpandFst; otherwise a PrunedExpand keeps only what lies
// within the threshold. If 'opts.connect' is true, Connect() is applied to
// the result.
template <class Arc>
void Expand(
    const Fst<Arc> &ifst,
    const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
    MutableFst<Arc> *ofst,
    const ExpandOptions<Arc> &opts) {
  typedef typename Arc::Weight Weight;

  if (opts.weight_threshold == Weight::Zero()) {
    ExpandFstOptions<Arc> eopts;
    // NOTE(review): gc_limit = 0 presumably keeps the cache of the delayed
    // FST small during the one-shot copy below; confirm against
    // CacheOptions.
    eopts.gc_limit = 0;
    eopts.keep_parentheses = opts.keep_parentheses;
    *ofst = ExpandFst<Arc>(ifst, parens, eopts);
  } else {
    PrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses);
    pruned_expand.Expand(ofst, opts.weight_threshold);
  }

  if (opts.connect)
    Connect(ofst);
}

// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
// This version writes the expanded PDT result to a MutableFst.
// In the PDT, some transitions are labeled with open or close
// parentheses. To be interpreted as a PDT, the parens must balance on
// a path. The open-close parenthesis label pairs are passed in
// 'parens'. The expansion enforces the parenthesis constraints. The
// PDT must be expandable as an FST.
template<class Arc>
void Expand(
    const Fst<Arc> &ifst,
    const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
    MutableFst<Arc> *ofst,
    bool connect = true, bool keep_parentheses = false) {
  Expand(ifst, parens, ofst, ExpandOptions<Arc>(connect, keep_parentheses));
}

}  // namespace fst

#endif  // FST_EXTENSIONS_PDT_EXPAND_H__