// replace.h // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // // \file // Functions and classes for the recursive replacement of Fsts. // #ifndef FST_LIB_REPLACE_H__ #define FST_LIB_REPLACE_H__ #include <ext/hash_map> using __gnu_cxx::hash_map; #include "fst/lib/fst.h" #include "fst/lib/cache.h" #include "fst/lib/test-properties.h" namespace fst { // By default ReplaceFst will copy the input label of the 'replace arc'. // For acceptors we do not want this behaviour. Instead we need to // create an epsilon arc when recursing into the appropriate Fst. // The epsilon_on_replace option can be used to toggle this behaviour. struct ReplaceFstOptions : CacheOptions { int64 root; // root rule for expansion bool epsilon_on_replace; ReplaceFstOptions(const CacheOptions &opts, int64 r) : CacheOptions(opts), root(r), epsilon_on_replace(false) {} explicit ReplaceFstOptions(int64 r) : root(r), epsilon_on_replace(false) {} ReplaceFstOptions(int64 r, bool epsilon_replace_arc) : root(r), epsilon_on_replace(epsilon_replace_arc) {} ReplaceFstOptions() : root(kNoLabel), epsilon_on_replace(false) {} }; // // \class ReplaceFstImpl // \brief Implementation class for replace class Fst // // The replace implementation class supports a dynamic // expansion of a recursive transition network represented as Fst // with dynamic replacable arcs. // template <class A> class ReplaceFstImpl : public CacheImpl<A> { public: using FstImpl<A>::SetType; using FstImpl<A>::SetProperties; using FstImpl<A>::Properties; using FstImpl<A>::SetInputSymbols; using FstImpl<A>::SetOutputSymbols; using FstImpl<A>::InputSymbols; using FstImpl<A>::OutputSymbols; using CacheImpl<A>::HasStart; using CacheImpl<A>::HasArcs; using CacheImpl<A>::SetStart; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef CacheState<A> State; typedef A Arc; typedef hash_map<Label, Label> NonTerminalHash; // \struct StateTuple // \brief Tuple of information that uniquely defines a state struct StateTuple { typedef int PrefixId; StateTuple() {} StateTuple(PrefixId p, StateId f, StateId s) : prefix_id(p), fst_id(f), fst_state(s) {} PrefixId prefix_id; // index in prefix table StateId fst_id; // current fst being walked StateId fst_state; // current state in fst being walked, not to be // confused with the state_id of the combined fst }; // constructor for replace class implementation. // \param fst_tuples array of label/fst tuples, one for each non-terminal ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples, const ReplaceFstOptions &opts) : CacheImpl<A>(opts), opts_(opts) { SetType("replace"); if (fst_tuples.size() > 0) { SetInputSymbols(fst_tuples[0].second->InputSymbols()); SetOutputSymbols(fst_tuples[0].second->OutputSymbols()); } fst_array_.push_back(0); for (size_t i = 0; i < fst_tuples.size(); ++i) AddFst(fst_tuples[i].first, fst_tuples[i].second); SetRoot(opts.root); } explicit ReplaceFstImpl(const ReplaceFstOptions &opts) : CacheImpl<A>(opts), opts_(opts), root_(kNoLabel) { fst_array_.push_back(0); } ReplaceFstImpl(const ReplaceFstImpl& impl) : opts_(impl.opts_), state_tuples_(impl.state_tuples_), state_hash_(impl.state_hash_), prefix_hash_(impl.prefix_hash_), stackprefix_array_(impl.stackprefix_array_), nonterminal_hash_(impl.nonterminal_hash_), root_(impl.root_) { SetType("replace"); SetProperties(impl.Properties(), kCopyProperties); SetInputSymbols(InputSymbols()); SetOutputSymbols(OutputSymbols()); fst_array_.reserve(impl.fst_array_.size()); fst_array_.push_back(0); for (size_t i = 1; i < impl.fst_array_.size(); ++i) fst_array_.push_back(impl.fst_array_[i]->Copy()); } ~ReplaceFstImpl() { for (size_t i = 1; i < fst_array_.size(); ++i) { delete fst_array_[i]; } } // Add to Fst array void AddFst(Label label, const Fst<A>* fst) { nonterminal_hash_[label] = fst_array_.size(); fst_array_.push_back(fst->Copy()); if (fst_array_.size() > 1) { vector<uint64> inprops(fst_array_.size()); for (size_t i = 1; i < fst_array_.size(); ++i) { inprops[i] = fst_array_[i]->Properties(kCopyProperties, false); } SetProperties(ReplaceProperties(inprops)); const SymbolTable* isymbols = fst_array_[1]->InputSymbols(); const SymbolTable* osymbols = fst_array_[1]->OutputSymbols(); for (size_t i = 2; i < fst_array_.size(); ++i) { if (!CompatSymbols(isymbols, fst_array_[i]->InputSymbols())) { LOG(FATAL) << "ReplaceFst::AddFst input symbols of Fst " << i-1 << " does not match input symbols of base Fst (0'th fst)"; } if (!CompatSymbols(osymbols, fst_array_[i]->OutputSymbols())) { LOG(FATAL) << "ReplaceFst::AddFst output symbols of Fst " << i-1 << " does not match output symbols of base Fst " << "(0'th fst)"; } } } } // Computes the dependency graph of the replace class and returns // true if the dependencies are cyclic. Cyclic dependencies will result // in an un-expandable replace fst. bool CyclicDependencies() const { StdVectorFst depfst; // one state for each fst for (size_t i = 1; i < fst_array_.size(); ++i) depfst.AddState(); // an arc from each state (representing the fst) to the // state representing the fst being replaced for (size_t i = 1; i < fst_array_.size(); ++i) { for (StateIterator<Fst<A> > siter(*(fst_array_[i])); !siter.Done(); siter.Next()) { for (ArcIterator<Fst<A> > aiter(*(fst_array_[i]), siter.Value()); !aiter.Done(); aiter.Next()) { const A& arc = aiter.Value(); typename NonTerminalHash::const_iterator it = nonterminal_hash_.find(arc.olabel); if (it != nonterminal_hash_.end()) { Label j = it->second - 1; depfst.AddArc(i - 1, A(arc.olabel, arc.olabel, Weight::One(), j)); } } } } depfst.SetStart(root_ - 1); depfst.SetFinal(root_ - 1, Weight::One()); return depfst.Properties(kCyclic, true); } // set root rule for expansion void SetRoot(Label root) { Label nonterminal = nonterminal_hash_[root]; root_ = (nonterminal > 0) ? nonterminal : 1; } // Change Fst array void SetFst(Label label, const Fst<A>* fst) { Label nonterminal = nonterminal_hash_[label]; delete fst_array_[nonterminal]; fst_array_[nonterminal] = fst->Copy(); } // Return or compute start state of replace fst StateId Start() { if (!HasStart()) { if (fst_array_.size() == 1) { // no fsts defined for replace SetStart(kNoStateId); return kNoStateId; } else { const Fst<A>* fst = fst_array_[root_]; StateId fst_start = fst->Start(); if (fst_start == kNoStateId) // root Fst is empty return kNoStateId; int prefix = PrefixId(StackPrefix()); StateId start = FindState(StateTuple(prefix, root_, fst_start)); SetStart(start); return start; } } else { return CacheImpl<A>::Start(); } } // return final weight of state (kInfWeight means state is not final) Weight Final(StateId s) { if (!HasFinal(s)) { const StateTuple& tuple = state_tuples_[s]; const StackPrefix& stack = stackprefix_array_[tuple.prefix_id]; const Fst<A>* fst = fst_array_[tuple.fst_id]; StateId fst_state = tuple.fst_state; if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0) SetFinal(s, fst->Final(fst_state)); else SetFinal(s, Weight::Zero()); } return CacheImpl<A>::Final(s); } size_t NumArcs(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl<A>::NumArcs(s); } size_t NumInputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl<A>::NumInputEpsilons(s); } size_t NumOutputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl<A>::NumOutputEpsilons(s); } // return the base arc iterator, if arcs have not been computed yet, // extend/recurse for new arcs. void InitArcIterator(StateId s, ArcIteratorData<A> *data) { if (!HasArcs(s)) Expand(s); CacheImpl<A>::InitArcIterator(s, data); } // Find/create an Fst state given a StateTuple. Only create a new // state if StateTuple is not found in the state hash. StateId FindState(const StateTuple& tuple) { typename StateTupleHash::iterator it = state_hash_.find(tuple); if (it == state_hash_.end()) { StateId new_state_id = state_tuples_.size(); state_tuples_.push_back(tuple); state_hash_[tuple] = new_state_id; return new_state_id; } else { return it->second; } } // extend current state (walk arcs one level deep) void Expand(StateId s) { StateTuple tuple = state_tuples_[s]; const Fst<A>* fst = fst_array_[tuple.fst_id]; StateId fst_state = tuple.fst_state; if (fst_state == kNoStateId) { SetArcs(s); return; } // if state is final, pop up stack const StackPrefix& stack = stackprefix_array_[tuple.prefix_id]; if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) { int prefix_id = PopPrefix(stack); const PrefixTuple& top = stack.Top(); StateId nextstate = FindState(StateTuple(prefix_id, top.fst_id, top.nextstate)); AddArc(s, A(0, 0, fst->Final(fst_state), nextstate)); } // extend arcs leaving the state for (ArcIterator< Fst<A> > aiter(*fst, fst_state); !aiter.Done(); aiter.Next()) { const Arc& arc = aiter.Value(); if (arc.olabel == 0) { // expand local fst StateId nextstate = FindState(StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate)); AddArc(s, A(arc.ilabel, arc.olabel, arc.weight, nextstate)); } else { // check for non terminal typename NonTerminalHash::const_iterator it = nonterminal_hash_.find(arc.olabel); if (it != nonterminal_hash_.end()) { // recurse into non terminal Label nonterminal = it->second; const Fst<A>* nt_fst = fst_array_[nonterminal]; int nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id], tuple.fst_id, arc.nextstate); // if start state is valid replace, else arc is implicitly // deleted StateId nt_start = nt_fst->Start(); if (nt_start != kNoStateId) { StateId nt_nextstate = FindState( StateTuple(nt_prefix, nonterminal, nt_start)); Label ilabel = (opts_.epsilon_on_replace) ? 0 : arc.ilabel; AddArc(s, A(ilabel, 0, arc.weight, nt_nextstate)); } } else { StateId nextstate = FindState( StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate)); AddArc(s, A(arc.ilabel, arc.olabel, arc.weight, nextstate)); } } } SetArcs(s); } // private helper classes private: static const int kPrime0 = 7853; static const int kPrime1 = 7867; // \class StateTupleEqual // \brief Compare two StateTuples for equality class StateTupleEqual { public: bool operator()(const StateTuple& x, const StateTuple& y) const { return ((x.prefix_id == y.prefix_id) && (x.fst_id == y.fst_id) && (x.fst_state == y.fst_state)); } }; // \class StateTupleKey // \brief Hash function for StateTuple to Fst states class StateTupleKey { public: size_t operator()(const StateTuple& x) const { return static_cast<size_t>(x.prefix_id + x.fst_id * kPrime0 + x.fst_state * kPrime1); } }; typedef hash_map<StateTuple, StateId, StateTupleKey, StateTupleEqual> StateTupleHash; // \class PrefixTuple // \brief Tuple of fst_id and destination state (entry in stack prefix) struct PrefixTuple { PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {} Label fst_id; StateId nextstate; }; // \class StackPrefix // \brief Container for stack prefix. class StackPrefix { public: StackPrefix() {} // copy constructor StackPrefix(const StackPrefix& x) : prefix_(x.prefix_) { } void Push(int fst_id, StateId nextstate) { prefix_.push_back(PrefixTuple(fst_id, nextstate)); } void Pop() { prefix_.pop_back(); } const PrefixTuple& Top() const { return prefix_[prefix_.size()-1]; } size_t Depth() const { return prefix_.size(); } public: vector<PrefixTuple> prefix_; }; // \class StackPrefixEqual // \brief Compare two stack prefix classes for equality class StackPrefixEqual { public: bool operator()(const StackPrefix& x, const StackPrefix& y) const { if (x.prefix_.size() != y.prefix_.size()) return false; for (size_t i = 0; i < x.prefix_.size(); ++i) { if (x.prefix_[i].fst_id != y.prefix_[i].fst_id || x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false; } return true; } }; // // \class StackPrefixKey // \brief Hash function for stack prefix to prefix id class StackPrefixKey { public: size_t operator()(const StackPrefix& x) const { int sum = 0; for (size_t i = 0; i < x.prefix_.size(); ++i) { sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0; } return (size_t) sum; } }; typedef hash_map<StackPrefix, int, StackPrefixKey, StackPrefixEqual> StackPrefixHash; // private methods private: // hash stack prefix (return unique index into stackprefix array) int PrefixId(const StackPrefix& prefix) { typename StackPrefixHash::iterator it = prefix_hash_.find(prefix); if (it == prefix_hash_.end()) { int prefix_id = stackprefix_array_.size(); stackprefix_array_.push_back(prefix); prefix_hash_[prefix] = prefix_id; return prefix_id; } else { return it->second; } } // prefix id after a stack pop int PopPrefix(StackPrefix prefix) { prefix.Pop(); return PrefixId(prefix); } // prefix id after a stack push int PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) { prefix.Push(fst_id, nextstate); return PrefixId(prefix); } // private data private: // runtime options ReplaceFstOptions opts_; // maps from StateId to StateTuple vector<StateTuple> state_tuples_; // hashes from StateTuple to StateId StateTupleHash state_hash_; // cross index of unique stack prefix // could potentially have one copy of prefix array StackPrefixHash prefix_hash_; vector<StackPrefix> stackprefix_array_; NonTerminalHash nonterminal_hash_; vector<const Fst<A>*> fst_array_; Label root_; void operator=(const ReplaceFstImpl<A> &); // disallow }; // // \class ReplaceFst // \brief Recursivively replaces arcs in the root Fst with other Fsts. // This version is a delayed Fst. // // ReplaceFst supports dynamic replacement of arcs in one Fst with // another Fst. This replacement is recursive. ReplaceFst can be used // to support a variety of delayed constructions such as recursive // transition networks, union, or closure. It is constructed with an // array of Fst(s). One Fst represents the root (or topology) // machine. The root Fst refers to other Fsts by recursively replacing // arcs labeled as non-terminals with the matching non-terminal // Fst. Currently the ReplaceFst uses the output symbols of the arcs // to determine whether the arc is a non-terminal arc or not. A // non-terminal can be any label that is not a non-zero terminal label // in the output alphabet. // // Note that the constructor uses a vector of pair<>. These correspond // to the tuple of non-terminal Label and corresponding Fst. For example // to implement the closure operation we need 2 Fsts. The first root // Fst is a single Arc on the start State that self loops, it references // the particular machine for which we are performing the closure operation. // template <class A> class ReplaceFst : public Fst<A> { public: friend class ArcIterator< ReplaceFst<A> >; friend class CacheStateIterator< ReplaceFst<A> >; friend class CacheArcIterator< ReplaceFst<A> >; typedef A Arc; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef CacheState<A> State; ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array, Label root) : impl_(new ReplaceFstImpl<A>(fst_array, ReplaceFstOptions(root))) {} ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array, const ReplaceFstOptions &opts) : impl_(new ReplaceFstImpl<A>(fst_array, opts)) {} ReplaceFst(const ReplaceFst<A>& fst) : impl_(new ReplaceFstImpl<A>(*(fst.impl_))) {} virtual ~ReplaceFst() { delete impl_; } virtual StateId Start() const { return impl_->Start(); } virtual Weight Final(StateId s) const { return impl_->Final(s); } virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } virtual size_t NumInputEpsilons(StateId s) const { return impl_->NumInputEpsilons(s); } virtual size_t NumOutputEpsilons(StateId s) const { return impl_->NumOutputEpsilons(s); } virtual uint64 Properties(uint64 mask, bool test) const { if (test) { uint64 known, test = TestProperties(*this, mask, &known); impl_->SetProperties(test, known); return test & mask; } else { return impl_->Properties(mask); } } virtual const string& Type() const { return impl_->Type(); } virtual ReplaceFst<A>* Copy() const { return new ReplaceFst<A>(*this); } virtual const SymbolTable* InputSymbols() const { return impl_->InputSymbols(); } virtual const SymbolTable* OutputSymbols() const { return impl_->OutputSymbols(); } virtual inline void InitStateIterator(StateIteratorData<A> *data) const; virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { impl_->InitArcIterator(s, data); } bool CyclicDependencies() const { return impl_->CyclicDependencies(); } private: ReplaceFstImpl<A>* impl_; }; // Specialization for ReplaceFst. template<class A> class StateIterator< ReplaceFst<A> > : public CacheStateIterator< ReplaceFst<A> > { public: explicit StateIterator(const ReplaceFst<A> &fst) : CacheStateIterator< ReplaceFst<A> >(fst) {} private: DISALLOW_EVIL_CONSTRUCTORS(StateIterator); }; // Specialization for ReplaceFst. template <class A> class ArcIterator< ReplaceFst<A> > : public CacheArcIterator< ReplaceFst<A> > { public: typedef typename A::StateId StateId; ArcIterator(const ReplaceFst<A> &fst, StateId s) : CacheArcIterator< ReplaceFst<A> >(fst, s) { if (!fst.impl_->HasArcs(s)) fst.impl_->Expand(s); } private: DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); }; template <class A> inline void ReplaceFst<A>::InitStateIterator(StateIteratorData<A> *data) const { data->base = new StateIterator< ReplaceFst<A> >(*this); } typedef ReplaceFst<StdArc> StdReplaceFst; // // Recursivively replaces arcs in the root Fst with other Fsts. // This version writes the result of replacement to an output MutableFst. // // Replace supports replacement of arcs in one Fst with another // Fst. This replacement is recursive. Replace takes an array of // Fst(s). One Fst represents the root (or topology) machine. The root // Fst refers to other Fsts by recursively replacing arcs labeled as // non-terminals with the matching non-terminal Fst. Currently Replace // uses the output symbols of the arcs to determine whether the arc is // a non-terminal arc or not. A non-terminal can be any label that is // not a non-zero terminal label in the output alphabet. Note that // input argument is a vector of pair<>. These correspond to the tuple // of non-terminal Label and corresponding Fst. template<class Arc> void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >& ifst_array, MutableFst<Arc> *ofst, typename Arc::Label root) { ReplaceFstOptions opts(root); opts.gc_limit = 0; // Cache only the last state for fastest copy. *ofst = ReplaceFst<Arc>(ifst_array, opts); } } #endif // FST_LIB_REPLACE_H__