// shortest-path.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Functions to find shortest paths in a PDT.

#ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
#define FST_EXTENSIONS_PDT_SHORTEST_PATH_H__

#include <fst/shortest-path.h>
#include <fst/extensions/pdt/paren.h>
#include <fst/extensions/pdt/pdt.h>

#include <unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <tr1/unordered_set>
using std::tr1::unordered_set;
using std::tr1::unordered_multiset;
#include <stack>
#include <vector>
using std::vector;

namespace fst {

// Options for the PDT shortest-path computation. The struct itself does not
// use its Arc and Queue parameters; they only tie the options to the
// PdtShortestPath<Arc, Queue> instantiation that consumes them.
template <class Arc, class Queue>
struct PdtShortestPathOptions {
  // When false, parenthesis arcs on the output path are relabeled with
  // epsilon (0) on both input and output.
  bool keep_parentheses;
  // When true, search data that is unreachable from the flagged final states
  // of a sub-graph is garbage collected after that sub-graph is searched.
  bool path_gc;

  PdtShortestPathOptions(bool keep = false, bool gc = true)
      : keep_parentheses(keep),
        path_gc(gc) {}
};


// Class to store PDT shortest path results. It stores shortest path tree
// information (Distance(), Parent() and ParenId()) keyed on two types:
// (1) By SearchState: This is a usual node in a shortest path tree, but:
//    (a) it is w.r.t. a PDT search state - a pair of a PDT state and
//        a 'start' state, which is either the PDT start state or
//        the destination state of an open parenthesis.
//    (b) the Distance() is from this 'start' state to the search state.
//    (c) Parent().state is kNoStateId (the default) for the 'start' state.
//
// (2) By ParenSpec: This connects shortest path trees depending on the
// parenthesis taken. Given the parenthesis spec:
//    (a) the Distance() is from the source sub-graph's 'start' state to the
//        parenthesis destination state.
//    (b) the Parent() is the search state that the open parenthesis leaves.
//
// Lookups are cached (last search state and last paren spec accessed). Before
// Finish() a lookup creates missing entries. After Finish() it is read-only
// and a miss yields a shared "null" SearchData.
template <class Arc>
class PdtShortestPathData {
 public:
  static const uint8 kFinal;

  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;
  typedef typename Arc::Label Label;

  struct SearchState {
    SearchState() : state(kNoStateId), start(kNoStateId) {}

    SearchState(StateId s, StateId t) : state(s), start(t) {}

    bool operator==(const SearchState &s) const {
      if (&s == this)
        return true;
      return s.state == this->state && s.start == this->start;
    }

    StateId state;  // PDT state
    StateId start;  // PDT paren 'source' state
  };


  // Specifies paren id, source and dest 'start' states of a paren.
  // These are the 'start' states of the respective sub-graphs.
  struct ParenSpec {
    ParenSpec()
        : paren_id(kNoLabel), src_start(kNoStateId), dest_start(kNoStateId) {}

    ParenSpec(Label id, StateId s, StateId d)
        : paren_id(id), src_start(s), dest_start(d) {}

    Label paren_id;        // Id of parenthesis
    StateId src_start;     // sub-graph 'start' state for paren source.
    StateId dest_start;    // sub-graph 'start' state for paren dest.

    bool operator==(const ParenSpec &x) const {
      if (&x == this)
        return true;
      return x.paren_id == this->paren_id &&
          x.src_start == this->src_start &&
          x.dest_start == this->dest_start;
    }
  };

  struct SearchData {
    SearchData() : distance(Weight::Zero()),
                   parent(kNoStateId, kNoStateId),
                   paren_id(kNoLabel),
                   flags(0) {}

    Weight distance;     // Distance to this state from PDT 'start' state
    SearchState parent;  // Parent state in shortest path tree
    int16 paren_id;      // If parent arc has paren, paren ID, o.w. kNoLabel
    uint8 flags;         // First byte reserved for PdtShortestPathData use
  };

  // 'gc' enables GC() of search data that no flagged final state reaches.
  // The lookup caches start out keyed on sentinel values and point at the
  // null data, so they never hold an uninitialized pointer.
  PdtShortestPathData(bool gc)
      : state_(kNoStateId, kNoStateId),
        state_data_(&null_search_data_),
        paren_(kNoLabel, kNoStateId, kNoStateId),
        paren_data_(&null_search_data_),
        gc_(gc),
        nstates_(0),
        ngc_(0),
        finished_(false) {}

  ~PdtShortestPathData() {
    VLOG(1) << "opm size: " << paren_map_.size();
    VLOG(1) << "# of search states: " << nstates_;
    if (gc_)
      VLOG(1) << "# of GC'd search states: " << ngc_;
  }

  // Discards all stored data and returns to the writable (not finished)
  // state, so the object can be reused for another search.
  void Clear() {
    search_map_.clear();
    search_multimap_.clear();
    paren_map_.clear();
    // Both cached pointers referred into the maps just cleared; re-point them
    // at the null data so they cannot dangle.
    null_search_data_ = SearchData();
    state_ = SearchState(kNoStateId, kNoStateId);
    state_data_ = &null_search_data_;
    paren_ = ParenSpec(kNoLabel, kNoStateId, kNoStateId);
    paren_data_ = &null_search_data_;
    nstates_ = 0;
    ngc_ = 0;
    // Without this a second search after Finish() would read and write only
    // the null data.
    finished_ = false;
  }

  Weight Distance(SearchState s) const {
    SearchData *data = GetSearchData(s);
    return data->distance;
  }

  Weight Distance(const ParenSpec &paren) const {
    SearchData *data = GetSearchData(paren);
    return data->distance;
  }

  SearchState Parent(SearchState s) const {
    SearchData *data = GetSearchData(s);
    return data->parent;
  }

  SearchState Parent(const ParenSpec &paren) const {
    SearchData *data = GetSearchData(paren);
    return data->parent;
  }

  Label ParenId(SearchState s) const {
    SearchData *data = GetSearchData(s);
    return data->paren_id;
  }

  uint8 Flags(SearchState s) const {
    SearchData *data = GetSearchData(s);
    return data->flags;
  }

  void SetDistance(SearchState s, Weight w) {
    SearchData *data = GetSearchData(s);
    data->distance = w;
  }

  void SetDistance(const ParenSpec &paren, Weight w) {
    SearchData *data = GetSearchData(paren);
    data->distance = w;
  }

  void SetParent(SearchState s, SearchState p) {
    SearchData *data = GetSearchData(s);
    data->parent = p;
  }

  void SetParent(const ParenSpec &paren, SearchState p) {
    SearchData *data = GetSearchData(paren);
    data->parent = p;
  }

  // Stores the paren id in an int16 field. Values outside [-32768, 32767]
  // are reported as an error (and stored truncated).
  void SetParenId(SearchState s, Label p) {
    if (p >= 32768 || p < -32768)
      FSTERROR() << "PdtShortestPathData: Paren ID does not fits in an int16";
    SearchData *data = GetSearchData(s);
    data->paren_id = p;
  }

  // Replaces the bits selected by 'mask' with the corresponding bits of 'f'.
  void SetFlags(SearchState s, uint8 f, uint8 mask) {
    SearchData *data = GetSearchData(s);
    data->flags &= ~mask;
    data->flags |= f & mask;
  }

  void GC(StateId s);

  // Switches to read-only access: later lookups no longer create entries.
  void Finish() { finished_ = true; }

 private:
  static const Arc kNoArc;
  static const size_t kPrime0;
  static const size_t kPrime1;
  static const uint8 kInited;
  static const uint8 kMarked;

  // Hash for search state
  struct SearchStateHash {
    size_t operator()(const SearchState &s) const {
      return s.state + s.start * kPrime0;
    }
  };

  // Hash for paren map
  struct ParenHash {
    size_t operator()(const ParenSpec &paren) const {
      return paren.paren_id + paren.src_start * kPrime0 +
          paren.dest_start * kPrime1;
    }
  };

  typedef unordered_map<SearchState, SearchData, SearchStateHash> SearchMap;

  typedef unordered_multimap<StateId, StateId> SearchMultimap;

  // Hash map from paren spec to open paren data
  typedef unordered_map<ParenSpec, SearchData, ParenHash> ParenMap;

  // Returns the data for 's', going through the one-entry cache. Before
  // Finish(), a missing entry is created. On first creation it is counted,
  // registered under its 'start' for GC() (if enabled), and flagged kInited.
  // After Finish(), a missing entry yields the shared null data.
  SearchData *GetSearchData(SearchState s) const {
    if (s == state_)
      return state_data_;
    if (finished_) {
      typename SearchMap::iterator it = search_map_.find(s);
      if (it == search_map_.end())
        return &null_search_data_;
      state_ = s;
      return state_data_ = &(it->second);
    } else {
      state_ = s;
      state_data_ = &search_map_[s];
      if (!(state_data_->flags & kInited)) {
        ++nstates_;
        if (gc_)
          search_multimap_.insert(make_pair(s.start, s.state));
        state_data_->flags = kInited;
      }
      return state_data_;
    }
  }

  // Same as above for paren specs, using the separate paren cache.
  SearchData *GetSearchData(ParenSpec paren) const {
    if (paren == paren_)
      return paren_data_;
    if (finished_) {
      typename ParenMap::iterator it = paren_map_.find(paren);
      if (it == paren_map_.end())
        return &null_search_data_;
      paren_ = paren;
      // Must update the paren cache: writing state_data_ here would both
      // corrupt the search-state cache and leave paren_data_ stale.
      return paren_data_ = &(it->second);
    } else {
      paren_ = paren;
      return paren_data_ = &paren_map_[paren];
    }
  }

  mutable SearchMap search_map_;            // Maps from search state to data
  mutable SearchMultimap search_multimap_;  // Maps from 'start' to subgraph
  mutable ParenMap paren_map_;              // Maps paren spec to search data
  mutable SearchState state_;               // Last state accessed
  mutable SearchData *state_data_;          // Last state data accessed
  mutable ParenSpec paren_;                 // Last paren spec accessed
  mutable SearchData *paren_data_;          // Last paren data accessed
  bool gc_;                                 // Allow GC?
  mutable size_t nstates_;                  // Total number of search states
  size_t ngc_;                              // Number of GC'd search states
  mutable SearchData null_search_data_;     // Null search data
  bool finished_;                           // Read-only access when true

  DISALLOW_COPY_AND_ASSIGN(PdtShortestPathData);
};

// Deletes the search data of the sub-graph rooted at 'start' (the PDT start
// state or an open paren destination) that cannot lie on a shortest path.
// Assumes 'final' states (close paren sources or PDT final states) have
// been flagged kFinal.
// Mark phase: from each kFinal state, follow parent pointers back to the
// sub-graph start, marking each state. When a parent lies in a nested
// sub-graph (reached through a close paren), jump via the ParenSpec to the
// state that opened that parenthesis.
// Sweep phase: erase every unmarked state of this sub-graph, and drop all of
// its GC registrations.
template<class Arc>
void  PdtShortestPathData<Arc>::GC(StateId start) {
  if (!gc_)
    return;
  vector<StateId> finals;
  for (typename SearchMultimap::iterator mmit = search_multimap_.find(start);
       mmit != search_multimap_.end() && mmit->first == start;
       ++mmit) {
    SearchState s(mmit->second, start);
    const SearchData &data = search_map_[s];
    if (data.flags & kFinal)
      finals.push_back(s.state);
  }

  // Mark phase
  for (size_t i = 0; i < finals.size(); ++i) {
    SearchState s(finals[i], start);
    while (s.state != kNoStateId) {
      SearchData *sdata = &search_map_[s];
      if (sdata->flags & kMarked)
        break;  // the rest of this chain is already marked
      sdata->flags |= kMarked;
      SearchState p = sdata->parent;
      if (p.start != start && p.start != kNoStateId) {  // entering sub-subgraph
        ParenSpec paren(sdata->paren_id, s.start, p.start);
        SearchData *pdata = &paren_map_[paren];
        s = pdata->parent;
      } else {
        s = p;
      }
    }
  }

  // Sweep phase
  typename SearchMultimap::iterator mmit = search_multimap_.find(start);
  while (mmit != search_multimap_.end() && mmit->first == start) {
    SearchState s(mmit->second, start);
    typename SearchMap::iterator mit = search_map_.find(s);
    if (mit != search_map_.end() && !(mit->second.flags & kMarked)) {
      // Erasing frees the entry; if the lookup cache points at it, reset the
      // cache to its sentinel so later lookups cannot use a dangling pointer.
      if (s == state_) {
        state_ = SearchState(kNoStateId, kNoStateId);
        state_data_ = &null_search_data_;
      }
      search_map_.erase(mit);
      ++ngc_;
    }
    search_multimap_.erase(mmit++);
  }
}

// Bits of SearchData::flags managed by PdtShortestPathData.
template<class Arc> const uint8 PdtShortestPathData<Arc>::kInited = 0x01;
template<class Arc> const uint8 PdtShortestPathData<Arc>::kFinal = 0x02;
template<class Arc> const uint8 PdtShortestPathData<Arc>::kMarked = 0x04;

// Multipliers mixed into SearchStateHash and ParenHash.
template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime0 = 7853;
template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime1 = 7867;

// Sentinel arc: no labels, Zero() weight, no destination state.
template<class Arc> const Arc PdtShortestPathData<Arc>::kNoArc
    = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId);


// This computes the single source shortest (balanced) path (SSSP)
// through a weighted PDT that has a bounded stack (i.e. is expandable
// as an FST). It is a generalization of the classic SSSP graph
// algorithm that removes a state s from a queue (defined by a
// user-provided queue type) and relaxes the destination states of
// transitions leaving s. In this PDT version, states that have
// entering open parentheses are treated as source states for a
// sub-graph SSSP problem with the shortest path up to the open
// parenthesis being first saved. When a close parenthesis is then
// encountered any balancing open parenthesis is examined for this
// saved information and multiplied back. In this way, each sub-graph
// is entered only once rather than repeatedly.  If every state in the
// input PDT has the property that there is a unique 'start' state for
// it with entering open parentheses, then this algorithm is quite
// straight-forward. In general, this will not be the case, so the
// algorithm (implicitly) creates a new graph where each state is a
// pair of an original state and a possible parenthesis 'start' state
// for that state.
template<class Arc, class Queue>
class PdtShortestPath {
 public:
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;
  typedef typename Arc::Label Label;

  typedef PdtShortestPathData<Arc> SpData;
  typedef typename SpData::SearchState SearchState;
  typedef typename SpData::ParenSpec ParenSpec;

  typedef typename PdtParenReachable<Arc>::SetIterator StateSetIterator;
  typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator;

  PdtShortestPath(const Fst<Arc> &ifst,
                  const vector<pair<Label, Label> > &parens,
                  const PdtShortestPathOptions<Arc, Queue> &opts)
      : kFinal(SpData::kFinal),
        ifst_(ifst.Copy()),
        parens_(parens),
        keep_parens_(opts.keep_parentheses),
        start_(ifst.Start()),
        sp_data_(opts.path_gc),
        error_(false) {

    if ((Weight::Properties() & (kPath | kRightSemiring))
        != (kPath | kRightSemiring)) {
      FSTERROR() << "SingleShortestPath: Weight needs to have the path"
                 << " property and be right distributive: " << Weight::Type();
      error_ = true;
    }

    for (Label i = 0; i < parens.size(); ++i) {
      const pair<Label, Label>  &p = parens[i];
      paren_id_map_[p.first] = i;
      paren_id_map_[p.second] = i;
    }
  };

  ~PdtShortestPath() {
    VLOG(1) << "# of input states: " << CountStates(*ifst_);
    VLOG(1) << "# of enqueued: " << nenqueued_;
    VLOG(1) << "cpmm size: " << close_paren_multimap_.size();
    delete ifst_;
  }

  void ShortestPath(MutableFst<Arc> *ofst) {
    Init(ofst);
    GetDistance(start_);
    GetPath();
    sp_data_.Finish();
    if (error_) ofst->SetProperties(kError, kError);
  }

  const PdtShortestPathData<Arc> &GetShortestPathData() const {
    return sp_data_;
  }

  PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; }

 private:
  static const Arc kNoArc;
  static const uint8 kEnqueued;
  static const uint8 kExpanded;
  const uint8 kFinal;

 public:
  // Hash multimap from close paren label to an paren arc.
  typedef unordered_multimap<ParenState<Arc>, Arc,
                        typename ParenState<Arc>::Hash> CloseParenMultimap;

  const CloseParenMultimap &GetCloseParenMultimap() const {
    return close_paren_multimap_;
  }

 private:
  void Init(MutableFst<Arc> *ofst);
  void GetDistance(StateId start);
  void ProcFinal(SearchState s);
  void ProcArcs(SearchState s);
  void ProcOpenParen(Label paren_id, SearchState s, Arc arc, Weight w);
  void ProcCloseParen(Label paren_id, SearchState s, const Arc &arc, Weight w);
  void ProcNonParen(SearchState s, const Arc &arc, Weight w);
  void Relax(SearchState s, SearchState t, Arc arc, Weight w, Label paren_id);
  void Enqueue(SearchState d);
  void GetPath();
  Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open);

  Fst<Arc> *ifst_;
  MutableFst<Arc> *ofst_;
  const vector<pair<Label, Label> > &parens_;
  bool keep_parens_;
  Queue *state_queue_;                   // current state queue
  StateId start_;
  Weight f_distance_;
  SearchState f_parent_;
  SpData sp_data_;
  unordered_map<Label, Label> paren_id_map_;
  CloseParenMultimap close_paren_multimap_;
  PdtBalanceData<Arc> balance_data_;
  ssize_t nenqueued_;
  bool error_;

  DISALLOW_COPY_AND_ASSIGN(PdtShortestPath);
};

// Prepares a run writing into 'ofst'. It empties 'ofst' and copies the
// input's symbol tables, then resets all per-run search state. If the input
// has a start state, it also indexes the parentheses: each open paren is
// recorded by (paren id, destination state) in balance_data_, and each close
// paren arc is recorded under (paren id, source state) in
// close_paren_multimap_.
template<class Arc, class Queue>
void PdtShortestPath<Arc, Queue>::Init(MutableFst<Arc> *ofst) {
  ofst_ = ofst;
  ofst->DeleteStates();
  ofst->SetInputSymbols(ifst_->InputSymbols());
  ofst->SetOutputSymbols(ifst_->OutputSymbols());

  // Reset before the early return so that an input without a start state
  // leaves no stale or uninitialized per-run values behind (for example the
  // nenqueued_ count logged by the destructor).
  f_distance_ = Weight::Zero();
  f_parent_ = SearchState(kNoStateId, kNoStateId);

  sp_data_.Clear();
  close_paren_multimap_.clear();
  balance_data_.Clear();
  nenqueued_ = 0;

  if (ifst_->Start() == kNoStateId)
    return;

  // Find open parens per destination state and close parens per source state.
  for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
    StateId s = siter.Value();
    for (ArcIterator<Fst<Arc> > aiter(*ifst_, s);
         !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      typename unordered_map<Label, Label>::const_iterator pit
          = paren_id_map_.find(arc.ilabel);
      if (pit != paren_id_map_.end()) {               // Is a paren?
        Label paren_id = pit->second;
        if (arc.ilabel == parens_[paren_id].first) {  // Open paren
          balance_data_.OpenInsert(paren_id, arc.nextstate);
        } else {                                      // Close paren
          ParenState<Arc> paren_state(paren_id, s);
          close_paren_multimap_.insert(make_pair(paren_state, arc));
        }
      }
    }
  }
}

// Runs the queue-driven shortest-distance search over the sub-graph rooted at
// 'start'. The search state (start, start) is seeded with Weight::One(). Each
// dequeued state is checked for finality and then has its arcs expanded.
// Nested sub-graphs entered through open parentheses are searched
// recursively (from ProcOpenParen()), each with its own local queue. Once the
// queue drains, balance_data_.FinishInsert(start) is called and the
// sub-graph's search data is passed to GC().
template<class Arc, class Queue>
void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) {
  if (start == kNoStateId) return;

  Queue queue;  // one queue per (possibly nested) sub-graph search
  state_queue_ = &queue;

  const SearchState root(start, start);
  Enqueue(root);
  sp_data_.SetDistance(root, Weight::One());

  while (!state_queue_->Empty()) {
    const SearchState current(state_queue_->Head(), start);
    state_queue_->Dequeue();
    sp_data_.SetFlags(current, 0, kEnqueued);  // off the queue now
    ProcFinal(current);
    ProcArcs(current);
    sp_data_.SetFlags(current, kExpanded, kExpanded);
  }

  balance_data_.FinishInsert(start);
  sp_data_.GC(start);
}

// Updates the best complete path. A search state qualifies only if it is
// final in the input and belongs to the top-level sub-graph (start ==
// start_). If its complete-path weight improves on the best so far, it
// becomes f_parent_ and the kFinal flag moves to it from the previous best.
template<class Arc, class Queue>
void PdtShortestPath<Arc, Queue>::ProcFinal(SearchState s) {
  const Weight final_weight = ifst_->Final(s.state);
  if (final_weight == Weight::Zero() || s.start != start_)
    return;

  const Weight w = Times(sp_data_.Distance(s), final_weight);
  const Weight total = Plus(f_distance_, w);
  if (f_distance_ == total)
    return;  // not an improvement

  if (f_parent_.state != kNoStateId)
    sp_data_.SetFlags(f_parent_, 0, kFinal);
  sp_data_.SetFlags(s, kFinal, kFinal);

  f_distance_ = total;
  f_parent_ = s;
}

// Expands search state 's'. For each arc leaving s.state, it extends the
// current distance of 's' by the arc weight and dispatches on the input
// label: open parenthesis, close parenthesis, or ordinary label.
template<class Arc, class Queue>
void PdtShortestPath<Arc, Queue>::ProcArcs(SearchState s) {
  typedef typename unordered_map<Label, Label>::const_iterator ParenIdIterator;
  for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state);
       !aiter.Done(); aiter.Next()) {
    // Held by value, presumably so it stays valid across the recursive
    // search that ProcOpenParen() may start.
    const Arc arc = aiter.Value();
    // Re-read for every arc: relaxing a self-loop can change Distance(s).
    const Weight w = Times(sp_data_.Distance(s), arc.weight);

    const ParenIdIterator pit = paren_id_map_.find(arc.ilabel);
    if (pit == paren_id_map_.end()) {
      ProcNonParen(s, arc, w);
      continue;
    }

    const Label paren_id = pit->second;
    const bool is_open = (arc.ilabel == parens_[paren_id].first);
    if (is_open)
      ProcOpenParen(paren_id, s, arc, w);
    else
      ProcCloseParen(paren_id, s, arc, w);
  }
}

// Handles an open parenthesis 'paren_id' on 'arc' leaving search state 's',
// where 'w' is the distance of 's' extended by the arc weight. If 'w'
// improves the distance stored for the ParenSpec (paren_id, s.start,
// arc.nextstate), it saves 'w' and the parent 's' for that spec.
// It then starts a new SSSP in the sub-graph rooted at arc.nextstate if
// that sub-graph has no recorded distance yet. Finally it relaxes, into the
// sub-graph of 's', the destinations of the matching close parentheses
// whose sources balance_data_ associates with this open parenthesis.
template<class Arc, class Queue> inline
void PdtShortestPath<Arc, Queue>::ProcOpenParen(
    Label paren_id, SearchState s, Arc arc, Weight w) {

  // 'd' is the start state of the sub-graph entered by this parenthesis.
  SearchState d(arc.nextstate, arc.nextstate);
  ParenSpec paren(paren_id, s.start, d.start);
  Weight pdist = sp_data_.Distance(paren);
  if (pdist != Plus(pdist, w)) {
    sp_data_.SetDistance(paren, w);
    sp_data_.SetParent(paren, s);
    Weight dist = sp_data_.Distance(d);
    // Zero() means no search has set a distance for the sub-graph start.
    if (dist == Weight::Zero()) {
      // GetDistance() repoints state_queue_ at its own local queue, so save
      // and restore the caller's queue around the nested search.
      Queue *state_queue = state_queue_;
      GetDistance(d.start);
      state_queue_ = state_queue;
    }
    // Each close paren source recorded for (paren_id, arc.nextstate) lives
    // in the entered sub-graph.
    for (CloseSourceIterator set_iter =
             balance_data_.Find(paren_id, arc.nextstate);
         !set_iter.Done(); set_iter.Next()) {
      SearchState cpstate(set_iter.Element(), d.start);
      ParenState<Arc> paren_state(paren_id, cpstate.state);
      // Every close paren arc with this paren id leaving that source. Relax()
      // places its destination in the sub-graph of 's' (t = s), with parent
      // 'cpstate' and weight w * Distance(cpstate) * close arc weight.
      for (typename CloseParenMultimap::const_iterator cpit =
               close_paren_multimap_.find(paren_state);
           cpit != close_paren_multimap_.end() && paren_state == cpit->first;
           ++cpit) {
        const Arc &cparc = cpit->second;
        Weight cpw = Times(w, Times(sp_data_.Distance(cpstate),
                                    cparc.weight));
        Relax(cpstate, s, cparc, cpw, paren_id);
      }
    }
  }
}

// Records 's' as a source of close parenthesis 'paren_id' within the
// sub-graph rooted at s.start, and flags 's' kFinal so that GC() keeps the
// path to it. This happens only while 's' has not yet been flagged
// kExpanded. The relaxation across the parenthesis itself is done later by
// ProcOpenParen(), using the close paren arcs indexed in Init(); 'arc' and
// 'w' are therefore unused here.
template<class Arc, class Queue> inline
void PdtShortestPath<Arc, Queue>::ProcCloseParen(
    Label paren_id, SearchState s, const Arc & /*arc*/, Weight /*w*/) {
  if (!(sp_data_.Flags(s) & kExpanded)) {
    balance_data_.CloseInsert(paren_id, s.start, s.state);
    sp_data_.SetFlags(s, kFinal, kFinal);
  }
}

// Ordinary (non-parenthesis) arcs stay inside the sub-graph of 's'. They
// get a classical relaxation with 's' as both the parent and the sub-graph
// reference, and with no paren id (kNoLabel).
template<class Arc, class Queue> inline
void PdtShortestPath<Arc, Queue>::ProcNonParen(
    SearchState s, const Arc &arc, Weight w) {
  const SearchState &same_subgraph = s;
  Relax(s, same_subgraph, arc, w, kNoLabel);
}

// Classical relaxation on the search graph for 'arc' taken from state 's'.
// The destination search state is arc.nextstate placed in the sub-graph of
// 't' (t.start). If adding candidate weight 'w' via Plus() changes the
// destination's distance, then 's' becomes its parent, 'paren_id' its paren
// id, the new sum its distance, and the destination is enqueued.
template<class Arc, class Queue> inline
void PdtShortestPath<Arc, Queue>::Relax(
    SearchState s, SearchState t, Arc arc, Weight w, Label paren_id) {
  const SearchState next(arc.nextstate, t.start);
  const Weight old_dist = sp_data_.Distance(next);
  const Weight new_dist = Plus(old_dist, w);
  if (old_dist == new_dist)
    return;  // no improvement

  sp_data_.SetParent(next, s);
  sp_data_.SetParenId(next, paren_id);
  sp_data_.SetDistance(next, new_dist);
  Enqueue(next);
}

// Puts the PDT state of 's' on the current sub-graph queue and flags it
// kEnqueued. If it is already flagged kEnqueued, the queue's Update() is
// called for it instead.
template<class Arc, class Queue> inline
void PdtShortestPath<Arc, Queue>::Enqueue(SearchState s) {
  const bool already_queued = (sp_data_.Flags(s) & kEnqueued) != 0;
  if (already_queued) {
    state_queue_->Update(s.state);
    return;
  }
  state_queue_->Enqueue(s.state);
  sp_data_.SetFlags(s, kEnqueued, kEnqueued);
  ++nenqueued_;
}

// Follows parent pointers to find the shortest path. Uses a stack
// since the shortest distance is stored recursively.
// The output is built backwards: it starts at the best final state
// (f_parent_) and adds one output state per search state visited, each with
// a single arc to the output state added just before it. On entering a
// nested sub-graph through a close paren, the ParenSpec is pushed. On
// reaching that sub-graph's start (no parent), the walk jumps to the state
// that opened the parenthesis and pops the spec. If no final state was
// reached, nothing is added and the start is set to kNoStateId.
template<class Arc, class Queue>
void PdtShortestPath<Arc, Queue>::GetPath() {
  SearchState s = f_parent_, d = SearchState(kNoStateId, kNoStateId);
  // s_p / d_p: output states for the current state s and for d (the state
  // visited previously, i.e. the next one along the path).
  StateId s_p = kNoStateId, d_p = kNoStateId;
  Arc arc(kNoArc);
  Label paren_id = kNoLabel;
  stack<ParenSpec> paren_stack;
  while (s.state != kNoStateId) {
    d_p = s_p;
    s_p = ofst_->AddState();
    if (d.state == kNoStateId) {
      // First iteration: this is the best final state.
      ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state));
    } else {
      if (paren_id != kNoLabel) {                     // paren?
        if (arc.ilabel == parens_[paren_id].first) {  // open paren
          paren_stack.pop();
        } else {                                      // close paren
          ParenSpec paren(paren_id, d.start, s.start);
          paren_stack.push(paren);
        }
        if (!keep_parens_)
          arc.ilabel = arc.olabel = 0;
      }
      arc.nextstate = d_p;
      ofst_->AddArc(s_p, arc);
    }
    d = s;
    s = sp_data_.Parent(d);
    paren_id = sp_data_.ParenId(d);
    if (s.state != kNoStateId) {
      // A tree parent: the arc is either ordinary or, when paren_id is set
      // (see Relax() from ProcOpenParen()), a close paren.
      arc = GetPathArc(s, d, paren_id, false);
    } else if (!paren_stack.empty()) {
      // Reached a nested sub-graph's start: continue from the state that
      // opened the parenthesis on top of the stack.
      ParenSpec paren = paren_stack.top();
      s = sp_data_.Parent(paren);
      paren_id = paren.paren_id;
      arc = GetPathArc(s, d, paren_id, true);
    }
  }
  ofst_->SetStart(s_p);
  ofst_->SetProperties(
      ShortestPathProperties(ofst_->Properties(kFstProperties, false)),
      kFstProperties);
}


// Finds the best (by Plus()) arc from s.state to d.state whose label
// matches 'paren_id':
//   - when paren_id is kNoLabel, only non-parenthesis arcs match;
//   - otherwise only parenthesis 'paren_id' arcs of the kind selected by
//     'open_paren' (open when true, close when false) match.
// If nothing matches, it reports an error, sets error_ and returns kNoArc.
template<class Arc, class Queue>
Arc PdtShortestPath<Arc, Queue>::GetPathArc(
    SearchState s, SearchState d, Label paren_id, bool open_paren) {
  typedef typename unordered_map<Label, Label>::const_iterator ParenIdIterator;
  Arc best = kNoArc;
  for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state);
       !aiter.Done(); aiter.Next()) {
    const Arc &candidate = aiter.Value();
    if (candidate.nextstate != d.state)
      continue;

    const ParenIdIterator pit = paren_id_map_.find(candidate.ilabel);
    const bool is_paren = (pit != paren_id_map_.end());
    const Label candidate_paren_id = is_paren ? pit->second : kNoLabel;
    if (is_paren &&
        (candidate.ilabel == parens_[candidate_paren_id].first) != open_paren)
      continue;  // wrong parenthesis kind
    if (candidate_paren_id != paren_id)
      continue;

    if (candidate.weight == Plus(candidate.weight, best.weight))
      best = candidate;
  }

  if (best.nextstate == kNoStateId) {
    FSTERROR() << "PdtShortestPath::GetPathArc failed to find arc";
    error_ = true;
  }
  return best;
}

// Sentinel arc (no labels, Zero() weight, no destination). GetPath() and
// GetPathArc() use it as their initial value.
template<class Arc, class Queue>
const Arc PdtShortestPath<Arc, Queue>::kNoArc
    = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId);

// Search-state flag bits owned by PdtShortestPath. They are distinct from
// the bits PdtShortestPathData manages (kInited, kFinal, kMarked).
template<class Arc, class Queue>
const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20;

template<class Arc, class Queue>
const uint8 PdtShortestPath<Arc, Queue>::kEnqueued = 0x10;

template<class Arc, class Queue>
void ShortestPath(const Fst<Arc> &ifst,
                  const vector<pair<typename Arc::Label,
                                    typename Arc::Label> > &parens,
                  MutableFst<Arc> *ofst,
                  const PdtShortestPathOptions<Arc, Queue> &opts) {
  PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
  psp.ShortestPath(ofst);
}

template<class Arc>
void ShortestPath(const Fst<Arc> &ifst,
                  const vector<pair<typename Arc::Label,
                                    typename Arc::Label> > &parens,
                  MutableFst<Arc> *ofst) {
  typedef FifoQueue<typename Arc::StateId> Queue;
  PdtShortestPathOptions<Arc, Queue> opts;
  PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
  psp.ShortestPath(ofst);
}

}  // namespace fst

#endif  // FST_EXTENSIONS_PDT_SHORTEST_PATH_H__