// replace-util.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: riley@google.com (Michael Riley) // // \file // Utility classes for the recursive replacement of Fsts (RTNs). #ifndef FST_LIB_REPLACE_UTIL_H__ #define FST_LIB_REPLACE_UTIL_H__ #include <vector> using std::vector; #include <unordered_map> using std::tr1::unordered_map; using std::tr1::unordered_multimap; #include <unordered_set> using std::tr1::unordered_set; using std::tr1::unordered_multiset; #include <map> #include <fst/connect.h> #include <fst/mutable-fst.h> #include <fst/topsort.h> namespace fst { template <class Arc> void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >&, MutableFst<Arc> *, typename Arc::Label, bool); // Utility class for the recursive replacement of Fsts (RTNs). The // user provides a set of Label, Fst pairs at construction. These are // used by methods for testing cyclic dependencies and connectedness // and doing RTN connection and specific Fst replacement by label or // for various optimization properties. The modified results can be // obtained with the GetFstPairs() or GetMutableFstPairs() methods. template <class Arc> class ReplaceUtil { public: typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; typedef typename Arc::StateId StateId; typedef pair<Label, const Fst<Arc>*> FstPair; typedef pair<Label, MutableFst<Arc>*> MutableFstPair; typedef unordered_map<Label, Label> NonTerminalHash; // Constructs from mutable Fsts; Fst ownership given to ReplaceUtil. ReplaceUtil(const vector<MutableFstPair> &fst_pairs, Label root_label, bool epsilon_on_replace = false); // Constructs from Fsts; Fst ownership retained by caller. ReplaceUtil(const vector<FstPair> &fst_pairs, Label root_label, bool epsilon_on_replace = false); // Constructs from ReplaceFst internals; ownership retained by caller. ReplaceUtil(const vector<const Fst<Arc> *> &fst_array, const NonTerminalHash &nonterminal_hash, Label root_fst, bool epsilon_on_replace = false); ~ReplaceUtil() { for (Label i = 0; i < fst_array_.size(); ++i) delete fst_array_[i]; } // True if the non-terminal dependencies are cyclic. Cyclic // dependencies will result in an unexpandable replace fst. bool CyclicDependencies() const { GetDependencies(false); return depprops_ & kCyclic; } // Returns true if no useless Fsts, states or transitions. bool Connected() const { GetDependencies(false); uint64 props = kAccessible | kCoAccessible; for (Label i = 0; i < fst_array_.size(); ++i) { if (!fst_array_[i]) continue; if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i]) return false; } return true; } // Removes useless Fsts, states and transitions. void Connect(); // Replaces Fsts specified by labels. // Does nothing if there are cyclic dependencies. void ReplaceLabels(const vector<Label> &labels); // Replaces Fsts that have at most 'nstates' states, 'narcs' arcs and // 'nnonterm' non-terminals (updating in reverse dependency order). // Does nothing if there are cyclic dependencies. void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms); // Replaces singleton Fsts. // Does nothing if there are cyclic dependencies. void ReplaceTrivial() { ReplaceBySize(2, 1, 1); } // Replaces non-terminals that have at most 'ninstances' instances // (updating in dependency order). // Does nothing if there are cyclic dependencies. void ReplaceByInstances(size_t ninstances); // Replaces non-terminals that have only one instance. // Does nothing if there are cyclic dependencies. void ReplaceUnique() { ReplaceByInstances(1); } // Returns Label, Fst pairs; Fst ownership retained by ReplaceUtil. void GetFstPairs(vector<FstPair> *fst_pairs); // Returns Label, MutableFst pairs; Fst ownership given to caller. void GetMutableFstPairs(vector<MutableFstPair> *mutable_fst_pairs); private: // Per Fst statistics struct ReplaceStats { StateId nstates; // # of states StateId nfinal; // # of final states size_t narcs; // # of arcs Label nnonterms; // # of non-terminals in Fst size_t nref; // # of non-terminal instances referring to this Fst // # of times that ith Fst references this Fst map<Label, size_t> inref; // # of times that this Fst references the ith Fst map<Label, size_t> outref; ReplaceStats() : nstates(0), nfinal(0), narcs(0), nnonterms(0), nref(0) {} }; // Check Mutable Fsts exist o.w. create them. void CheckMutableFsts(); // Computes the dependency graph of the replace Fsts. // If 'stats' is true, dependency statistics computed as well. void GetDependencies(bool stats) const; void ClearDependencies() const { depfst_.DeleteStates(); stats_.clear(); depprops_ = 0; have_stats_ = false; } // Get topological order of dependencies. Returns false with cyclic input. bool GetTopOrder(const Fst<Arc> &fst, vector<Label> *toporder) const; // Update statistics assuming that jth Fst will be replaced. void UpdateStats(Label j); Label root_label_; // root non-terminal Label root_fst_; // root Fst ID bool epsilon_on_replace_; // see Replace() vector<const Fst<Arc> *> fst_array_; // Fst per ID vector<MutableFst<Arc> *> mutable_fst_array_; // MutableFst per ID vector<Label> nonterminal_array_; // Fst ID to non-terminal NonTerminalHash nonterminal_hash_; // non-terminal to Fst ID mutable VectorFst<Arc> depfst_; // Fst ID dependencies mutable vector<bool> depaccess_; // Fst ID accessibility mutable uint64 depprops_; // dependency Fst props mutable bool have_stats_; // have dependency statistics mutable vector<ReplaceStats> stats_; // Per Fst statistics DISALLOW_COPY_AND_ASSIGN(ReplaceUtil); }; template <class Arc> ReplaceUtil<Arc>::ReplaceUtil( const vector<MutableFstPair> &fst_pairs, Label root_label, bool epsilon_on_replace) : root_label_(root_label), epsilon_on_replace_(epsilon_on_replace), depprops_(0), have_stats_(false) { fst_array_.push_back(0); mutable_fst_array_.push_back(0); nonterminal_array_.push_back(kNoLabel); for (Label i = 0; i < fst_pairs.size(); ++i) { Label label = fst_pairs[i].first; MutableFst<Arc> *fst = fst_pairs[i].second; nonterminal_hash_[label] = fst_array_.size(); nonterminal_array_.push_back(label); fst_array_.push_back(fst); mutable_fst_array_.push_back(fst); } root_fst_ = nonterminal_hash_[root_label_]; if (!root_fst_) FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_; } template <class Arc> ReplaceUtil<Arc>::ReplaceUtil( const vector<FstPair> &fst_pairs, Label root_label, bool epsilon_on_replace) : root_label_(root_label), epsilon_on_replace_(epsilon_on_replace), depprops_(0), have_stats_(false) { fst_array_.push_back(0); nonterminal_array_.push_back(kNoLabel); for (Label i = 0; i < fst_pairs.size(); ++i) { Label label = fst_pairs[i].first; const Fst<Arc> *fst = fst_pairs[i].second; nonterminal_hash_[label] = fst_array_.size(); nonterminal_array_.push_back(label); fst_array_.push_back(fst->Copy()); } root_fst_ = nonterminal_hash_[root_label]; if (!root_fst_) FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_; } template <class Arc> ReplaceUtil<Arc>::ReplaceUtil( const vector<const Fst<Arc> *> &fst_array, const NonTerminalHash &nonterminal_hash, Label root_fst, bool epsilon_on_replace) : root_fst_(root_fst), epsilon_on_replace_(epsilon_on_replace), nonterminal_array_(fst_array.size()), nonterminal_hash_(nonterminal_hash), depprops_(0), have_stats_(false) { fst_array_.push_back(0); for (Label i = 1; i < fst_array.size(); ++i) fst_array_.push_back(fst_array[i]->Copy()); for (typename NonTerminalHash::const_iterator it = nonterminal_hash.begin(); it != nonterminal_hash.end(); ++it) nonterminal_array_[it->second] = it->first; root_label_ = nonterminal_array_[root_fst_]; } template <class Arc> void ReplaceUtil<Arc>::GetDependencies(bool stats) const { if (depfst_.NumStates() > 0) { if (stats && !have_stats_) ClearDependencies(); else return; } have_stats_ = stats; if (have_stats_) stats_.reserve(fst_array_.size()); for (Label i = 0; i < fst_array_.size(); ++i) { depfst_.AddState(); depfst_.SetFinal(i, Weight::One()); if (have_stats_) stats_.push_back(ReplaceStats()); } depfst_.SetStart(root_fst_); // An arc from each state (representing the fst) to the // state representing the fst being replaced for (Label i = 0; i < fst_array_.size(); ++i) { const Fst<Arc> *ifst = fst_array_[i]; if (!ifst) continue; for (StateIterator<Fst<Arc> > siter(*ifst); !siter.Done(); siter.Next()) { StateId s = siter.Value(); if (have_stats_) { ++stats_[i].nstates; if (ifst->Final(s) != Weight::Zero()) ++stats_[i].nfinal; } for (ArcIterator<Fst<Arc> > aiter(*ifst, s); !aiter.Done(); aiter.Next()) { if (have_stats_) ++stats_[i].narcs; const Arc& arc = aiter.Value(); typename NonTerminalHash::const_iterator it = nonterminal_hash_.find(arc.olabel); if (it != nonterminal_hash_.end()) { Label j = it->second; depfst_.AddArc(i, Arc(arc.olabel, arc.olabel, Weight::One(), j)); if (have_stats_) { ++stats_[i].nnonterms; ++stats_[j].nref; ++stats_[j].inref[i]; ++stats_[i].outref[j]; } } } } } // Gets accessibility info SccVisitor<Arc> scc_visitor(0, &depaccess_, 0, &depprops_); DfsVisit(depfst_, &scc_visitor); } template <class Arc> void ReplaceUtil<Arc>::UpdateStats(Label j) { if (!have_stats_) { FSTERROR() << "ReplaceUtil::UpdateStats: stats not available"; return; } if (j == root_fst_) // can't replace root return; typedef typename map<Label, size_t>::iterator Iter; for (Iter in = stats_[j].inref.begin(); in != stats_[j].inref.end(); ++in) { Label i = in->first; size_t ni = in->second; stats_[i].nstates += stats_[j].nstates * ni; stats_[i].narcs += (stats_[j].narcs + 1) * ni; // narcs - 1 + 2 (eps) stats_[i].nnonterms += (stats_[j].nnonterms - 1) * ni; stats_[i].outref.erase(stats_[i].outref.find(j)); for (Iter out = stats_[j].outref.begin(); out != stats_[j].outref.end(); ++out) { Label k = out->first; size_t nk = out->second; stats_[i].outref[k] += ni * nk; } } for (Iter out = stats_[j].outref.begin(); out != stats_[j].outref.end(); ++out) { Label k = out->first; size_t nk = out->second; stats_[k].nref -= nk; stats_[k].inref.erase(stats_[k].inref.find(j)); for (Iter in = stats_[j].inref.begin(); in != stats_[j].inref.end(); ++in) { Label i = in->first; size_t ni = in->second; stats_[k].inref[i] += ni * nk; stats_[k].nref += ni * nk; } } } template <class Arc> void ReplaceUtil<Arc>::CheckMutableFsts() { if (mutable_fst_array_.size() == 0) { for (Label i = 0; i < fst_array_.size(); ++i) { if (!fst_array_[i]) { mutable_fst_array_.push_back(0); } else { mutable_fst_array_.push_back(new VectorFst<Arc>(*fst_array_[i])); delete fst_array_[i]; fst_array_[i] = mutable_fst_array_[i]; } } } } template <class Arc> void ReplaceUtil<Arc>::Connect() { CheckMutableFsts(); uint64 props = kAccessible | kCoAccessible; for (Label i = 0; i < mutable_fst_array_.size(); ++i) { if (!mutable_fst_array_[i]) continue; if (mutable_fst_array_[i]->Properties(props, false) != props) fst::Connect(mutable_fst_array_[i]); } GetDependencies(false); for (Label i = 0; i < mutable_fst_array_.size(); ++i) { MutableFst<Arc> *fst = mutable_fst_array_[i]; if (fst && !depaccess_[i]) { delete fst; fst_array_[i] = 0; mutable_fst_array_[i] = 0; } } ClearDependencies(); } template <class Arc> bool ReplaceUtil<Arc>::GetTopOrder(const Fst<Arc> &fst, vector<Label> *toporder) const { // Finds topological order of dependencies. vector<StateId> order; bool acyclic = false; TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic); DfsVisit(fst, &top_order_visitor); if (!acyclic) { LOG(WARNING) << "ReplaceUtil::GetTopOrder: Cyclical label dependencies"; return false; } toporder->resize(order.size()); for (Label i = 0; i < order.size(); ++i) (*toporder)[order[i]] = i; return true; } template <class Arc> void ReplaceUtil<Arc>::ReplaceLabels(const vector<Label> &labels) { CheckMutableFsts(); unordered_set<Label> label_set; for (Label i = 0; i < labels.size(); ++i) if (labels[i] != root_label_) // can't replace root label_set.insert(labels[i]); // Finds Fst dependencies restricted to the labels requested. GetDependencies(false); VectorFst<Arc> pfst(depfst_); for (StateId i = 0; i < pfst.NumStates(); ++i) { vector<Arc> arcs; for (ArcIterator< VectorFst<Arc> > aiter(pfst, i); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); Label label = nonterminal_array_[arc.nextstate]; if (label_set.count(label) > 0) arcs.push_back(arc); } pfst.DeleteArcs(i); for (size_t j = 0; j < arcs.size(); ++j) pfst.AddArc(i, arcs[j]); } vector<Label> toporder; if (!GetTopOrder(pfst, &toporder)) { ClearDependencies(); return; } // Visits Fsts in reverse topological order of dependencies and // performs replacements. for (Label o = toporder.size() - 1; o >= 0; --o) { vector<FstPair> fst_pairs; StateId s = toporder[o]; for (ArcIterator< VectorFst<Arc> > aiter(pfst, s); !aiter.Done(); aiter.Next()) { const Arc &arc = aiter.Value(); Label label = nonterminal_array_[arc.nextstate]; const Fst<Arc> *fst = fst_array_[arc.nextstate]; fst_pairs.push_back(make_pair(label, fst)); } if (fst_pairs.empty()) continue; Label label = nonterminal_array_[s]; const Fst<Arc> *fst = fst_array_[s]; fst_pairs.push_back(make_pair(label, fst)); Replace(fst_pairs, mutable_fst_array_[s], label, epsilon_on_replace_); } ClearDependencies(); } template <class Arc> void ReplaceUtil<Arc>::ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms) { vector<Label> labels; GetDependencies(true); vector<Label> toporder; if (!GetTopOrder(depfst_, &toporder)) { ClearDependencies(); return; } for (Label o = toporder.size() - 1; o >= 0; --o) { Label j = toporder[o]; if (stats_[j].nstates <= nstates && stats_[j].narcs <= narcs && stats_[j].nnonterms <= nnonterms) { labels.push_back(nonterminal_array_[j]); UpdateStats(j); } } ReplaceLabels(labels); } template <class Arc> void ReplaceUtil<Arc>::ReplaceByInstances(size_t ninstances) { vector<Label> labels; GetDependencies(true); vector<Label> toporder; if (!GetTopOrder(depfst_, &toporder)) { ClearDependencies(); return; } for (Label o = 0; o < toporder.size(); ++o) { Label j = toporder[o]; if (stats_[j].nref <= ninstances) { labels.push_back(nonterminal_array_[j]); UpdateStats(j); } } ReplaceLabels(labels); } template <class Arc> void ReplaceUtil<Arc>::GetFstPairs(vector<FstPair> *fst_pairs) { CheckMutableFsts(); fst_pairs->clear(); for (Label i = 0; i < fst_array_.size(); ++i) { Label label = nonterminal_array_[i]; const Fst<Arc> *fst = fst_array_[i]; if (!fst) continue; fst_pairs->push_back(make_pair(label, fst)); } } template <class Arc> void ReplaceUtil<Arc>::GetMutableFstPairs( vector<MutableFstPair> *mutable_fst_pairs) { CheckMutableFsts(); mutable_fst_pairs->clear(); for (Label i = 0; i < mutable_fst_array_.size(); ++i) { Label label = nonterminal_array_[i]; MutableFst<Arc> *fst = mutable_fst_array_[i]; if (!fst) continue; mutable_fst_pairs->push_back(make_pair(label, fst->Copy())); } } } // namespace fst #endif // FST_LIB_REPLACE_UTIL_H__