// shortest-path.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: allauzen@google.com (Cyril Allauzen) // // \file // Functions to find shortest paths in an FST. #ifndef FST_LIB_SHORTEST_PATH_H__ #define FST_LIB_SHORTEST_PATH_H__ #include <functional> #include <utility> using std::pair; using std::make_pair; #include <vector> using std::vector; #include <fst/cache.h> #include <fst/determinize.h> #include <fst/queue.h> #include <fst/shortest-distance.h> #include <fst/test-properties.h> namespace fst { template <class Arc, class Queue, class ArcFilter> struct ShortestPathOptions : public ShortestDistanceOptions<Arc, Queue, ArcFilter> { typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; size_t nshortest; // return n-shortest paths bool unique; // only return paths with distinct input strings bool has_distance; // distance vector already contains the // shortest distance from the initial state bool first_path; // Single shortest path stops after finding the first // path to a final state. That path is the shortest path // only when using the ShortestFirstQueue and // only when all the weights in the FST are between // One() and Zero() according to NaturalLess. Weight weight_threshold; // pruning weight threshold. StateId state_threshold; // pruning state threshold. ShortestPathOptions(Queue *q, ArcFilter filt, size_t n = 1, bool u = false, bool hasdist = false, float d = kDelta, bool fp = false, Weight w = Weight::Zero(), StateId s = kNoStateId) : ShortestDistanceOptions<Arc, Queue, ArcFilter>(q, filt, kNoStateId, d), nshortest(n), unique(u), has_distance(hasdist), first_path(fp), weight_threshold(w), state_threshold(s) {} }; // Shortest-path algorithm: normally not called directly; prefer // 'ShortestPath' below with n=1. 'ofst' contains the shortest path in // 'ifst'. 'distance' returns the shortest distances from the source // state to each state in 'ifst'. 'opts' is used to specify options // such as the queue discipline, the arc filter and delta. // // The shortest path is the lowest weight path w.r.t. the natural // semiring order. // // The weights need to be right distributive and have the path (kPath) // property. template<class Arc, class Queue, class ArcFilter> void SingleShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, vector<typename Arc::Weight> *distance, ShortestPathOptions<Arc, Queue, ArcFilter> &opts) { typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; ofst->DeleteStates(); ofst->SetInputSymbols(ifst.InputSymbols()); ofst->SetOutputSymbols(ifst.OutputSymbols()); if (ifst.Start() == kNoStateId) { if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); return; } vector<bool> enqueued; vector<StateId> parent; vector<Arc> arc_parent; Queue *state_queue = opts.state_queue; StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source; Weight f_distance = Weight::Zero(); StateId f_parent = kNoStateId; distance->clear(); state_queue->Clear(); if (opts.nshortest != 1) { FSTERROR() << "SingleShortestPath: for nshortest > 1, use ShortestPath" << " instead"; ofst->SetProperties(kError, kError); return; } if (opts.weight_threshold != Weight::Zero() || opts.state_threshold != kNoStateId) { FSTERROR() << "SingleShortestPath: weight and state thresholds not applicable"; ofst->SetProperties(kError, kError); return; } if ((Weight::Properties() & (kPath | kRightSemiring)) != (kPath | kRightSemiring)) { FSTERROR() << "SingleShortestPath: Weight needs to have the path" << " property and be right distributive: " << Weight::Type(); ofst->SetProperties(kError, kError); return; } while (distance->size() < source) { distance->push_back(Weight::Zero()); enqueued.push_back(false); parent.push_back(kNoStateId); arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId)); } distance->push_back(Weight::One()); parent.push_back(kNoStateId); arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId)); state_queue->Enqueue(source); enqueued.push_back(true); while (!state_queue->Empty()) { StateId s = state_queue->Head(); state_queue->Dequeue(); enqueued[s] = false; Weight sd = (*distance)[s]; if (ifst.Final(s) != Weight::Zero()) { Weight w = Times(sd, ifst.Final(s)); if (f_distance != Plus(f_distance, w)) { f_distance = Plus(f_distance, w); f_parent = s; } if (!f_distance.Member()) { ofst->SetProperties(kError, kError); return; } if (opts.first_path) break; } for (ArcIterator< Fst<Arc> > aiter(ifst, s); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); while (distance->size() <= arc.nextstate) { distance->push_back(Weight::Zero()); enqueued.push_back(false); parent.push_back(kNoStateId); arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId)); } Weight &nd = (*distance)[arc.nextstate]; Weight w = Times(sd, arc.weight); if (nd != Plus(nd, w)) { nd = Plus(nd, w); if (!nd.Member()) { ofst->SetProperties(kError, kError); return; } parent[arc.nextstate] = s; arc_parent[arc.nextstate] = arc; if (!enqueued[arc.nextstate]) { state_queue->Enqueue(arc.nextstate); enqueued[arc.nextstate] = true; } else { state_queue->Update(arc.nextstate); } } } } StateId s_p = kNoStateId, d_p = kNoStateId; for (StateId s = f_parent, d = kNoStateId; s != kNoStateId; d = s, s = parent[s]) { d_p = s_p; s_p = ofst->AddState(); if (d == kNoStateId) { ofst->SetFinal(s_p, ifst.Final(f_parent)); } else { arc_parent[d].nextstate = d_p; ofst->AddArc(s_p, arc_parent[d]); } } ofst->SetStart(s_p); if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); ofst->SetProperties( ShortestPathProperties(ofst->Properties(kFstProperties, false)), kFstProperties); } template <class S, class W> class ShortestPathCompare { public: typedef S StateId; typedef W Weight; typedef pair<StateId, Weight> Pair; ShortestPathCompare(const vector<Pair>& pairs, const vector<Weight>& distance, StateId sfinal, float d) : pairs_(pairs), distance_(distance), superfinal_(sfinal), delta_(d) {} bool operator()(const StateId x, const StateId y) const { const Pair &px = pairs_[x]; const Pair &py = pairs_[y]; Weight dx = px.first == superfinal_ ? Weight::One() : px.first < distance_.size() ? distance_[px.first] : Weight::Zero(); Weight dy = py.first == superfinal_ ? Weight::One() : py.first < distance_.size() ? distance_[py.first] : Weight::Zero(); Weight wx = Times(dx, px.second); Weight wy = Times(dy, py.second); // Penalize complete paths to ensure correct results with inexact weights. // This forms a strict weak order so long as ApproxEqual(a, b) => // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b). if (px.first == superfinal_ && py.first != superfinal_) { return less_(wy, wx) || ApproxEqual(wx, wy, delta_); } else if (py.first == superfinal_ && px.first != superfinal_) { return less_(wy, wx) && !ApproxEqual(wx, wy, delta_); } else { return less_(wy, wx); } } private: const vector<Pair> &pairs_; const vector<Weight> &distance_; StateId superfinal_; float delta_; NaturalLess<Weight> less_; }; // N-Shortest-path algorithm: implements the core n-shortest path // algorithm. The output is built REVERSED. See below for versions with // more options and not reversed. // // 'ofst' contains the REVERSE of 'n'-shortest paths in 'ifst'. // 'distance' must contain the shortest distance from each state to a final // state in 'ifst'. 'delta' is the convergence delta. // // The n-shortest paths are the n-lowest weight paths w.r.t. the // natural semiring order. The single path that can be read from the // ith of at most n transitions leaving the initial state of 'ofst' is // the ith shortest path. Disregarding the initial state and initial // transitions, the n-shortest paths, in fact, form a tree rooted at // the single final state. // // The weights need to be left and right distributive (kSemiring) and // have the path (kPath) property. // // The algorithm is from Mohri and Riley, "An Efficient Algorithm for // the n-best-strings problem", ICSLP 2002. The algorithm relies on // the shortest-distance algorithm. There are some issues with the // pseudo-code as written in the paper (viz., line 11). // // IMPLEMENTATION NOTE: The input fst 'ifst' can be a delayed fst and // and at any state in its expansion the values of distance vector need only // be defined at that time for the states that are known to exist. template<class Arc, class RevArc> void NShortestPath(const Fst<RevArc> &ifst, MutableFst<Arc> *ofst, const vector<typename Arc::Weight> &distance, size_t n, float delta = kDelta, typename Arc::Weight weight_threshold = Arc::Weight::Zero(), typename Arc::StateId state_threshold = kNoStateId) { typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; typedef pair<StateId, Weight> Pair; typedef typename RevArc::Weight RevWeight; if (n <= 0) return; if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) { FSTERROR() << "NShortestPath: Weight needs to have the " << "path property and be distributive: " << Weight::Type(); ofst->SetProperties(kError, kError); return; } ofst->DeleteStates(); ofst->SetInputSymbols(ifst.InputSymbols()); ofst->SetOutputSymbols(ifst.OutputSymbols()); // Each state in 'ofst' corresponds to a path with weight w from the // initial state of 'ifst' to a state s in 'ifst', that can be // characterized by a pair (s,w). The vector 'pairs' maps each // state in 'ofst' to the corresponding pair maps states in OFST to // the corresponding pair (s,w). vector<Pair> pairs; // The supefinal state is denoted by -1, 'compare' knows that the // distance from 'superfinal' to the final state is 'Weight::One()', // hence 'distance[superfinal]' is not needed. StateId superfinal = -1; ShortestPathCompare<StateId, Weight> compare(pairs, distance, superfinal, delta); vector<StateId> heap; // 'r[s + 1]', 's' state in 'fst', is the number of states in 'ofst' // which corresponding pair contains 's' ,i.e. , it is number of // paths computed so far to 's'. Valid for 's == -1' (superfinal). vector<int> r; NaturalLess<Weight> less; if (ifst.Start() == kNoStateId || distance.size() <= ifst.Start() || distance[ifst.Start()] == Weight::Zero() || less(weight_threshold, Weight::One()) || state_threshold == 0) { if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); return; } ofst->SetStart(ofst->AddState()); StateId final = ofst->AddState(); ofst->SetFinal(final, Weight::One()); while (pairs.size() <= final) pairs.push_back(Pair(kNoStateId, Weight::Zero())); pairs[final] = Pair(ifst.Start(), Weight::One()); heap.push_back(final); Weight limit = Times(distance[ifst.Start()], weight_threshold); while (!heap.empty()) { pop_heap(heap.begin(), heap.end(), compare); StateId state = heap.back(); Pair p = pairs[state]; heap.pop_back(); Weight d = p.first == superfinal ? Weight::One() : p.first < distance.size() ? distance[p.first] : Weight::Zero(); if (less(limit, Times(d, p.second)) || (state_threshold != kNoStateId && ofst->NumStates() >= state_threshold)) continue; while (r.size() <= p.first + 1) r.push_back(0); ++r[p.first + 1]; if (p.first == superfinal) ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state)); if ((p.first == superfinal) && (r[p.first + 1] == n)) break; if (r[p.first + 1] > n) continue; if (p.first == superfinal) continue; for (ArcIterator< Fst<RevArc> > aiter(ifst, p.first); !aiter.Done(); aiter.Next()) { const RevArc &rarc = aiter.Value(); Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate); Weight w = Times(p.second, arc.weight); StateId next = ofst->AddState(); pairs.push_back(Pair(arc.nextstate, w)); arc.nextstate = state; ofst->AddArc(next, arc); heap.push_back(next); push_heap(heap.begin(), heap.end(), compare); } Weight finalw = ifst.Final(p.first).Reverse(); if (finalw != Weight::Zero()) { Weight w = Times(p.second, finalw); StateId next = ofst->AddState(); pairs.push_back(Pair(superfinal, w)); ofst->AddArc(next, Arc(0, 0, finalw, state)); heap.push_back(next); push_heap(heap.begin(), heap.end(), compare); } } Connect(ofst); if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); ofst->SetProperties( ShortestPathProperties(ofst->Properties(kFstProperties, false)), kFstProperties); } // N-Shortest-path algorithm: this version allow fine control // via the options argument. See below for a simpler interface. // // 'ofst' contains the n-shortest paths in 'ifst'. 'distance' returns // the shortest distances from the source state to each state in // 'ifst'. 'opts' is used to specify options such as the number of // paths to return, whether they need to have distinct input // strings, the queue discipline, the arc filter and the convergence // delta. // // The n-shortest paths are the n-lowest weight paths w.r.t. the // natural semiring order. The single path that can be read from the // ith of at most n transitions leaving the initial state of 'ofst' is // the ith shortest path. Disregarding the initial state and initial // transitions, The n-shortest paths, in fact, form a tree rooted at // the single final state. // The weights need to be right distributive and have the path (kPath) // property. They need to be left distributive as well for nshortest // > 1. // // The algorithm is from Mohri and Riley, "An Efficient Algorithm for // the n-best-strings problem", ICSLP 2002. The algorithm relies on // the shortest-distance algorithm. There are some issues with the // pseudo-code as written in the paper (viz., line 11). template<class Arc, class Queue, class ArcFilter> void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, vector<typename Arc::Weight> *distance, ShortestPathOptions<Arc, Queue, ArcFilter> &opts) { typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; typedef ReverseArc<Arc> ReverseArc; size_t n = opts.nshortest; if (n == 1) { SingleShortestPath(ifst, ofst, distance, opts); return; } if (n <= 0) return; if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) { FSTERROR() << "ShortestPath: n-shortest: Weight needs to have the " << "path property and be distributive: " << Weight::Type(); ofst->SetProperties(kError, kError); return; } if (!opts.has_distance) { ShortestDistance(ifst, distance, opts); if (distance->size() == 1 && !(*distance)[0].Member()) { ofst->SetProperties(kError, kError); return; } } // Algorithm works on the reverse of 'fst' : 'rfst', 'distance' is // the distance to the final state in 'rfst', 'ofst' is built as the // reverse of the tree of n-shortest path in 'rfst'. VectorFst<ReverseArc> rfst; Reverse(ifst, &rfst); Weight d = Weight::Zero(); for (ArcIterator< VectorFst<ReverseArc> > aiter(rfst, 0); !aiter.Done(); aiter.Next()) { const ReverseArc &arc = aiter.Value(); StateId s = arc.nextstate - 1; if (s < distance->size()) d = Plus(d, Times(arc.weight.Reverse(), (*distance)[s])); } distance->insert(distance->begin(), d); if (!opts.unique) { NShortestPath(rfst, ofst, *distance, n, opts.delta, opts.weight_threshold, opts.state_threshold); } else { vector<Weight> ddistance; DeterminizeFstOptions<ReverseArc> dopts(opts.delta); DeterminizeFst<ReverseArc> dfst(rfst, *distance, &ddistance, dopts); NShortestPath(dfst, ofst, ddistance, n, opts.delta, opts.weight_threshold, opts.state_threshold); } distance->erase(distance->begin()); } // Shortest-path algorithm: simplified interface. See above for a // version that allows finer control. // // 'ofst' contains the 'n'-shortest paths in 'ifst'. The queue // discipline is automatically selected. When 'unique' == true, only // paths with distinct input labels are returned. // // The n-shortest paths are the n-lowest weight paths w.r.t. the // natural semiring order. The single path that can be read from the // ith of at most n transitions leaving the initial state of 'ofst' is // the ith best path. // // The weights need to be right distributive and have the path // (kPath) property. template<class Arc> void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, size_t n = 1, bool unique = false, bool first_path = false, typename Arc::Weight weight_threshold = Arc::Weight::Zero(), typename Arc::StateId state_threshold = kNoStateId) { vector<typename Arc::Weight> distance; AnyArcFilter<Arc> arc_filter; AutoQueue<typename Arc::StateId> state_queue(ifst, &distance, arc_filter); ShortestPathOptions< Arc, AutoQueue<typename Arc::StateId>, AnyArcFilter<Arc> > opts(&state_queue, arc_filter, n, unique, false, kDelta, first_path, weight_threshold, state_threshold); ShortestPath(ifst, ofst, &distance, opts); } } // namespace fst #endif // FST_LIB_SHORTEST_PATH_H__