// replace.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: johans@google.com (Johan Schalkwyk) // // \file // Functions and classes for the recursive replacement of Fsts. // #ifndef FST_LIB_REPLACE_H__ #define FST_LIB_REPLACE_H__ #include <tr1/unordered_map> using std::tr1::unordered_map; using std::tr1::unordered_multimap; #include <set> #include <string> #include <utility> using std::pair; using std::make_pair; #include <vector> using std::vector; #include <fst/cache.h> #include <fst/expanded-fst.h> #include <fst/fst.h> #include <fst/matcher.h> #include <fst/replace-util.h> #include <fst/state-table.h> #include <fst/test-properties.h> namespace fst { // // REPLACE STATE TUPLES AND TABLES // // The replace state table has the form // // template <class A, class P> // class ReplaceStateTable { // public: // typedef A Arc; // typedef P PrefixId; // typedef typename A::StateId StateId; // typedef ReplaceStateTuple<StateId, PrefixId> StateTuple; // typedef typename A::Label Label; // // // Required constuctor // ReplaceStateTable(const vector<pair<Label, const Fst<A>*> > &fst_tuples, // Label root); // // // Required copy constructor that does not copy state // ReplaceStateTable(const ReplaceStateTable<A,P> &table); // // // Lookup state ID by tuple. If it doesn't exist, then add it. // StateId FindState(const StateTuple &tuple); // // // Lookup state tuple by ID. 
// const StateTuple &Tuple(StateId id) const; // }; // \struct ReplaceStateTuple // \brief Tuple of information that uniquely defines a state in replace template <class S, class P> struct ReplaceStateTuple { typedef S StateId; typedef P PrefixId; ReplaceStateTuple() : prefix_id(-1), fst_id(kNoStateId), fst_state(kNoStateId) {} ReplaceStateTuple(PrefixId p, StateId f, StateId s) : prefix_id(p), fst_id(f), fst_state(s) {} PrefixId prefix_id; // index in prefix table StateId fst_id; // current fst being walked StateId fst_state; // current state in fst being walked, not to be // confused with the state_id of the combined fst }; // Equality of replace state tuples. template <class S, class P> inline bool operator==(const ReplaceStateTuple<S, P>& x, const ReplaceStateTuple<S, P>& y) { return x.prefix_id == y.prefix_id && x.fst_id == y.fst_id && x.fst_state == y.fst_state; } // \class ReplaceRootSelector // Functor returning true for tuples corresponding to states in the root FST template <class S, class P> class ReplaceRootSelector { public: bool operator()(const ReplaceStateTuple<S, P> &tuple) const { return tuple.prefix_id == 0; } }; // \class ReplaceFingerprint // Fingerprint for general replace state tuples. template <class S, class P> class ReplaceFingerprint { public: ReplaceFingerprint(const vector<uint64> *size_array) : cumulative_size_array_(size_array) {} uint64 operator()(const ReplaceStateTuple<S, P> &tuple) const { return tuple.prefix_id * (cumulative_size_array_->back()) + cumulative_size_array_->at(tuple.fst_id - 1) + tuple.fst_state; } private: const vector<uint64> *cumulative_size_array_; }; // \class ReplaceFstStateFingerprint // Useful when the fst_state uniquely define the tuple. template <class S, class P> class ReplaceFstStateFingerprint { public: uint64 operator()(const ReplaceStateTuple<S, P>& tuple) const { return tuple.fst_state; } }; // \class ReplaceHash // A generic hash function for replace state tuples. 
template <typename S, typename P> class ReplaceHash { public: size_t operator()(const ReplaceStateTuple<S, P>& t) const { return t.prefix_id + t.fst_id * kPrime0 + t.fst_state * kPrime1; } private: static const size_t kPrime0; static const size_t kPrime1; }; template <typename S, typename P> const size_t ReplaceHash<S, P>::kPrime0 = 7853; template <typename S, typename P> const size_t ReplaceHash<S, P>::kPrime1 = 7867; template <class A, class T> class ReplaceFstMatcher; // \class VectorHashReplaceStateTable // A two-level state table for replace. // Warning: calls CountStates to compute the number of states of each // component Fst. template <class A, class P = ssize_t> class VectorHashReplaceStateTable { public: typedef A Arc; typedef typename A::StateId StateId; typedef typename A::Label Label; typedef P PrefixId; typedef ReplaceStateTuple<StateId, P> StateTuple; typedef VectorHashStateTable<ReplaceStateTuple<StateId, P>, ReplaceRootSelector<StateId, P>, ReplaceFstStateFingerprint<StateId, P>, ReplaceFingerprint<StateId, P> > StateTable; VectorHashReplaceStateTable( const vector<pair<Label, const Fst<A>*> > &fst_tuples, Label root) : root_size_(0) { cumulative_size_array_.push_back(0); for (size_t i = 0; i < fst_tuples.size(); ++i) { if (fst_tuples[i].first == root) { root_size_ = CountStates(*(fst_tuples[i].second)); cumulative_size_array_.push_back(cumulative_size_array_.back()); } else { cumulative_size_array_.push_back(cumulative_size_array_.back() + CountStates(*(fst_tuples[i].second))); } } state_table_ = new StateTable( new ReplaceRootSelector<StateId, P>, new ReplaceFstStateFingerprint<StateId, P>, new ReplaceFingerprint<StateId, P>(&cumulative_size_array_), root_size_, root_size_ + cumulative_size_array_.back()); } VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A, P> &table) : root_size_(table.root_size_), cumulative_size_array_(table.cumulative_size_array_) { state_table_ = new StateTable( new ReplaceRootSelector<StateId, P>, new 
ReplaceFstStateFingerprint<StateId, P>, new ReplaceFingerprint<StateId, P>(&cumulative_size_array_), root_size_, root_size_ + cumulative_size_array_.back()); } ~VectorHashReplaceStateTable() { delete state_table_; } StateId FindState(const StateTuple &tuple) { return state_table_->FindState(tuple); } const StateTuple &Tuple(StateId id) const { return state_table_->Tuple(id); } private: StateId root_size_; vector<uint64> cumulative_size_array_; StateTable *state_table_; }; // \class DefaultReplaceStateTable // Default replace state table template <class A, class P = ssize_t> class DefaultReplaceStateTable : public CompactHashStateTable< ReplaceStateTuple<typename A::StateId, P>, ReplaceHash<typename A::StateId, P> > { public: typedef A Arc; typedef typename A::StateId StateId; typedef typename A::Label Label; typedef P PrefixId; typedef ReplaceStateTuple<StateId, P> StateTuple; typedef CompactHashStateTable<StateTuple, ReplaceHash<StateId, PrefixId> > StateTable; using StateTable::FindState; using StateTable::Tuple; DefaultReplaceStateTable( const vector<pair<Label, const Fst<A>*> > &fst_tuples, Label root) {} DefaultReplaceStateTable(const DefaultReplaceStateTable<A, P> &table) : StateTable() {} }; // // REPLACE FST CLASS // // By default ReplaceFst will copy the input label of the 'replace arc'. // For acceptors we do not want this behaviour. Instead we need to // create an epsilon arc when recursing into the appropriate Fst. // The 'epsilon_on_replace' option can be used to toggle this behaviour. 
template <class A, class T = DefaultReplaceStateTable<A> > struct ReplaceFstOptions : CacheOptions { int64 root; // root rule for expansion bool epsilon_on_replace; bool take_ownership; // take ownership of input Fst(s) T* state_table; ReplaceFstOptions(const CacheOptions &opts, int64 r) : CacheOptions(opts), root(r), epsilon_on_replace(false), take_ownership(false), state_table(0) {} explicit ReplaceFstOptions(int64 r) : root(r), epsilon_on_replace(false), take_ownership(false), state_table(0) {} ReplaceFstOptions(int64 r, bool epsilon_replace_arc) : root(r), epsilon_on_replace(epsilon_replace_arc), take_ownership(false), state_table(0) {} ReplaceFstOptions() : root(kNoLabel), epsilon_on_replace(false), take_ownership(false), state_table(0) {} }; // \class ReplaceFstImpl // \brief Implementation class for replace class Fst // // The replace implementation class supports a dynamic // expansion of a recursive transition network represented as Fst // with dynamic replacable arcs. // template <class A, class T> class ReplaceFstImpl : public CacheImpl<A> { friend class ReplaceFstMatcher<A, T>; public: using FstImpl<A>::SetType; using FstImpl<A>::SetProperties; using FstImpl<A>::WriteHeader; using FstImpl<A>::SetInputSymbols; using FstImpl<A>::SetOutputSymbols; using FstImpl<A>::InputSymbols; using FstImpl<A>::OutputSymbols; using CacheImpl<A>::PushArc; using CacheImpl<A>::HasArcs; using CacheImpl<A>::HasFinal; using CacheImpl<A>::HasStart; using CacheImpl<A>::SetArcs; using CacheImpl<A>::SetFinal; using CacheImpl<A>::SetStart; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef CacheState<A> State; typedef A Arc; typedef unordered_map<Label, Label> NonTerminalHash; typedef T StateTable; typedef typename T::PrefixId PrefixId; typedef ReplaceStateTuple<StateId, PrefixId> StateTuple; // constructor for replace class implementation. 
// \param fst_tuples array of label/fst tuples, one for each non-terminal ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples, const ReplaceFstOptions<A, T> &opts) : CacheImpl<A>(opts), epsilon_on_replace_(opts.epsilon_on_replace), state_table_(opts.state_table ? opts.state_table : new StateTable(fst_tuples, opts.root)) { SetType("replace"); if (fst_tuples.size() > 0) { SetInputSymbols(fst_tuples[0].second->InputSymbols()); SetOutputSymbols(fst_tuples[0].second->OutputSymbols()); } bool all_negative = true; // all nonterminals are negative? bool dense_range = true; // all nonterminals are positive // and form a dense range containing 1? for (size_t i = 0; i < fst_tuples.size(); ++i) { Label nonterminal = fst_tuples[i].first; if (nonterminal >= 0) all_negative = false; if (nonterminal > fst_tuples.size() || nonterminal <= 0) dense_range = false; } vector<uint64> inprops; bool all_ilabel_sorted = true; bool all_olabel_sorted = true; bool all_non_empty = true; fst_array_.push_back(0); for (size_t i = 0; i < fst_tuples.size(); ++i) { Label label = fst_tuples[i].first; const Fst<A> *fst = fst_tuples[i].second; nonterminal_hash_[label] = fst_array_.size(); nonterminal_set_.insert(label); fst_array_.push_back(opts.take_ownership ? 
fst : fst->Copy()); if (fst->Start() == kNoStateId) all_non_empty = false; if(!fst->Properties(kILabelSorted, false)) all_ilabel_sorted = false; if(!fst->Properties(kOLabelSorted, false)) all_olabel_sorted = false; inprops.push_back(fst->Properties(kCopyProperties, false)); if (i) { if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) { FSTERROR() << "ReplaceFstImpl: input symbols of Fst " << i << " does not match input symbols of base Fst (0'th fst)"; SetProperties(kError, kError); } if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) { FSTERROR() << "ReplaceFstImpl: output symbols of Fst " << i << " does not match output symbols of base Fst " << "(0'th fst)"; SetProperties(kError, kError); } } } Label nonterminal = nonterminal_hash_[opts.root]; if ((nonterminal == 0) && (fst_array_.size() > 1)) { FSTERROR() << "ReplaceFstImpl: no Fst corresponding to root label '" << opts.root << "' in the input tuple vector"; SetProperties(kError, kError); } root_ = (nonterminal > 0) ? nonterminal : 1; SetProperties(ReplaceProperties(inprops, root_ - 1, epsilon_on_replace_, all_non_empty)); // We assume that all terminals are positive. The resulting // ReplaceFst is known to be kILabelSorted when all sub-FSTs are // kILabelSorted and one of the 3 following conditions is satisfied: // 1. 'epsilon_on_replace' is false, or // 2. all non-terminals are negative, or // 3. all non-terninals are positive and form a dense range containing 1. if (all_ilabel_sorted && (!epsilon_on_replace_ || all_negative || dense_range)) SetProperties(kILabelSorted, kILabelSorted); // Similarly, the resulting ReplaceFst is known to be // kOLabelSorted when all sub-FSTs are kOLabelSorted and one of // the 2 following conditions is satisfied: // 1. all non-terminals are negative, or // 2. all non-terninals are positive and form a dense range containing 1. 
if (all_olabel_sorted && (all_negative || dense_range)) SetProperties(kOLabelSorted, kOLabelSorted); // Enable optional caching as long as sorted and all non empty. if (Properties(kILabelSorted | kOLabelSorted) && all_non_empty) always_cache_ = false; else always_cache_ = true; VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = " << (always_cache_ ? "true" : "false"); } ReplaceFstImpl(const ReplaceFstImpl& impl) : CacheImpl<A>(impl), epsilon_on_replace_(impl.epsilon_on_replace_), always_cache_(impl.always_cache_), state_table_(new StateTable(*(impl.state_table_))), nonterminal_set_(impl.nonterminal_set_), nonterminal_hash_(impl.nonterminal_hash_), root_(impl.root_) { SetType("replace"); SetProperties(impl.Properties(), kCopyProperties); SetInputSymbols(impl.InputSymbols()); SetOutputSymbols(impl.OutputSymbols()); fst_array_.reserve(impl.fst_array_.size()); fst_array_.push_back(0); for (size_t i = 1; i < impl.fst_array_.size(); ++i) { fst_array_.push_back(impl.fst_array_[i]->Copy(true)); } } ~ReplaceFstImpl() { VLOG(2) << "~ReplaceFstImpl: gc = " << (CacheImpl<A>::GetCacheGc() ? "true" : "false") << ", gc_size = " << CacheImpl<A>::GetCacheSize() << ", gc_limit = " << CacheImpl<A>::GetCacheLimit(); delete state_table_; for (size_t i = 1; i < fst_array_.size(); ++i) { delete fst_array_[i]; } } // Computes the dependency graph of the replace class and returns // true if the dependencies are cyclic. Cyclic dependencies will result // in an un-expandable replace fst. 
bool CyclicDependencies() const { ReplaceUtil<A> replace_util(fst_array_, nonterminal_hash_, root_); return replace_util.CyclicDependencies(); } // Return or compute start state of replace fst StateId Start() { if (!HasStart()) { if (fst_array_.size() == 1) { // no fsts defined for replace SetStart(kNoStateId); return kNoStateId; } else { const Fst<A>* fst = fst_array_[root_]; StateId fst_start = fst->Start(); if (fst_start == kNoStateId) // root Fst is empty return kNoStateId; PrefixId prefix = GetPrefixId(StackPrefix()); StateId start = state_table_->FindState( StateTuple(prefix, root_, fst_start)); SetStart(start); return start; } } else { return CacheImpl<A>::Start(); } } // return final weight of state (kInfWeight means state is not final) Weight Final(StateId s) { if (!HasFinal(s)) { const StateTuple& tuple = state_table_->Tuple(s); const StackPrefix& stack = stackprefix_array_[tuple.prefix_id]; const Fst<A>* fst = fst_array_[tuple.fst_id]; StateId fst_state = tuple.fst_state; if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0) SetFinal(s, fst->Final(fst_state)); else SetFinal(s, Weight::Zero()); } return CacheImpl<A>::Final(s); } size_t NumArcs(StateId s) { if (HasArcs(s)) { // If state cached, use the cached value. return CacheImpl<A>::NumArcs(s); } else if (always_cache_) { // If always caching, expand and cache state. Expand(s); return CacheImpl<A>::NumArcs(s); } else { // Otherwise compute the number of arcs without expanding. StateTuple tuple = state_table_->Tuple(s); if (tuple.fst_state == kNoStateId) return 0; const Fst<A>* fst = fst_array_[tuple.fst_id]; size_t num_arcs = fst->NumArcs(tuple.fst_state); if (ComputeFinalArc(tuple, 0)) num_arcs++; return num_arcs; } } // Returns whether a given label is a non terminal bool IsNonTerminal(Label l) const { // TODO(allauzen): be smarter and take advantage of // all_dense or all_negative. 
// Use also in ComputeArc, this would require changes to replace // so that recursing into an empty fst lead to a non co-accessible // state instead of deleting the arc as done currently. // Current use correct, since i/olabel sorted iff all_non_empty. typename NonTerminalHash::const_iterator it = nonterminal_hash_.find(l); return it != nonterminal_hash_.end(); } size_t NumInputEpsilons(StateId s) { if (HasArcs(s)) { // If state cached, use the cached value. return CacheImpl<A>::NumInputEpsilons(s); } else if (always_cache_ || !Properties(kILabelSorted)) { // If always caching or if the number of input epsilons is too expensive // to compute without caching (i.e. not ilabel sorted), // then expand and cache state. Expand(s); return CacheImpl<A>::NumInputEpsilons(s); } else { // Otherwise, compute the number of input epsilons without caching. StateTuple tuple = state_table_->Tuple(s); if (tuple.fst_state == kNoStateId) return 0; const Fst<A>* fst = fst_array_[tuple.fst_id]; size_t num = 0; if (!epsilon_on_replace_) { // If epsilon_on_replace is false, all input epsilon arcs // are also input epsilons arcs in the underlying machine. fst->NumInputEpsilons(tuple.fst_state); } else { // Otherwise, one need to consider that all non-terminal arcs // in the underlying machine also become input epsilon arc. ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state); for (; !aiter.Done() && ((aiter.Value().ilabel == 0) || IsNonTerminal(aiter.Value().olabel)); aiter.Next()) ++num; } if (ComputeFinalArc(tuple, 0)) num++; return num; } } size_t NumOutputEpsilons(StateId s) { if (HasArcs(s)) { // If state cached, use the cached value. return CacheImpl<A>::NumOutputEpsilons(s); } else if(always_cache_ || !Properties(kOLabelSorted)) { // If always caching or if the number of output epsilons is too expensive // to compute without caching (i.e. not olabel sorted), // then expand and cache state. 
Expand(s); return CacheImpl<A>::NumOutputEpsilons(s); } else { // Otherwise, compute the number of output epsilons without caching. StateTuple tuple = state_table_->Tuple(s); if (tuple.fst_state == kNoStateId) return 0; const Fst<A>* fst = fst_array_[tuple.fst_id]; size_t num = 0; ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state); for (; !aiter.Done() && ((aiter.Value().olabel == 0) || IsNonTerminal(aiter.Value().olabel)); aiter.Next()) ++num; if (ComputeFinalArc(tuple, 0)) num++; return num; } } uint64 Properties() const { return Properties(kFstProperties); } // Set error if found; return FST impl properties. uint64 Properties(uint64 mask) const { if (mask & kError) { for (size_t i = 1; i < fst_array_.size(); ++i) { if (fst_array_[i]->Properties(kError, false)) SetProperties(kError, kError); } } return FstImpl<Arc>::Properties(mask); } // return the base arc iterator, if arcs have not been computed yet, // extend/recurse for new arcs. void InitArcIterator(StateId s, ArcIteratorData<A> *data) { if (!HasArcs(s)) Expand(s); CacheImpl<A>::InitArcIterator(s, data); // TODO(allauzen): Set behaviour of generic iterator // Warning: ArcIterator<ReplaceFst<A> >::InitCache() // relies on current behaviour. 
} // Extend current state (walk arcs one level deep) void Expand(StateId s) { StateTuple tuple = state_table_->Tuple(s); // If local fst is empty if (tuple.fst_state == kNoStateId) { SetArcs(s); return; } ArcIterator< Fst<A> > aiter( *(fst_array_[tuple.fst_id]), tuple.fst_state); Arc arc; // Create a final arc when needed if (ComputeFinalArc(tuple, &arc)) PushArc(s, arc); // Expand all arcs leaving the state for (;!aiter.Done(); aiter.Next()) { if (ComputeArc(tuple, aiter.Value(), &arc)) PushArc(s, arc); } SetArcs(s); } void Expand(StateId s, const StateTuple &tuple, const ArcIteratorData<A> &data) { // If local fst is empty if (tuple.fst_state == kNoStateId) { SetArcs(s); return; } ArcIterator< Fst<A> > aiter(data); Arc arc; // Create a final arc when needed if (ComputeFinalArc(tuple, &arc)) AddArc(s, arc); // Expand all arcs leaving the state for (; !aiter.Done(); aiter.Next()) { if (ComputeArc(tuple, aiter.Value(), &arc)) AddArc(s, arc); } SetArcs(s); } // If arcp == 0, only returns if a final arc is required, does not // actually compute it. bool ComputeFinalArc(const StateTuple &tuple, A* arcp, uint32 flags = kArcValueFlags) { const Fst<A>* fst = fst_array_[tuple.fst_id]; StateId fst_state = tuple.fst_state; if (fst_state == kNoStateId) return false; // if state is final, pop up stack const StackPrefix& stack = stackprefix_array_[tuple.prefix_id]; if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) { if (arcp) { arcp->ilabel = 0; arcp->olabel = 0; if (flags & kArcNextStateValue) { PrefixId prefix_id = PopPrefix(stack); const PrefixTuple& top = stack.Top(); arcp->nextstate = state_table_->FindState( StateTuple(prefix_id, top.fst_id, top.nextstate)); } if (flags & kArcWeightValue) arcp->weight = fst->Final(fst_state); } return true; } else { return false; } } // Compute the arc in the replace fst corresponding to a given // in the underlying machine. Returns false if the underlying arc // corresponds to no arc in the replace. 
bool ComputeArc(const StateTuple &tuple, const A &arc, A* arcp, uint32 flags = kArcValueFlags) { if (!epsilon_on_replace_ && (flags == (flags & (kArcILabelValue | kArcWeightValue)))) { *arcp = arc; return true; } if (arc.olabel == 0) { // expand local fst StateId nextstate = flags & kArcNextStateValue ? state_table_->FindState( StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate)) : kNoStateId; *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate); } else { // check for non terminal typename NonTerminalHash::const_iterator it = nonterminal_hash_.find(arc.olabel); if (it != nonterminal_hash_.end()) { // recurse into non terminal Label nonterminal = it->second; const Fst<A>* nt_fst = fst_array_[nonterminal]; PrefixId nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id], tuple.fst_id, arc.nextstate); // if start state is valid replace, else arc is implicitly // deleted StateId nt_start = nt_fst->Start(); if (nt_start != kNoStateId) { StateId nt_nextstate = flags & kArcNextStateValue ? state_table_->FindState( StateTuple(nt_prefix, nonterminal, nt_start)) : kNoStateId; Label ilabel = (epsilon_on_replace_) ? 0 : arc.ilabel; *arcp = A(ilabel, 0, arc.weight, nt_nextstate); } else { return false; } } else { StateId nextstate = flags & kArcNextStateValue ? state_table_->FindState( StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate)) : kNoStateId; *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate); } } return true; } // Returns the arc iterator flags supported by this Fst. 
uint32 ArcIteratorFlags() const { uint32 flags = kArcValueFlags; if (!always_cache_) flags |= kArcNoCache; return flags; } T* GetStateTable() const { return state_table_; } const Fst<A>* GetFst(Label fst_id) const { return fst_array_[fst_id]; } bool EpsilonOnReplace() const { return epsilon_on_replace_; } // private helper classes private: static const size_t kPrime0; // \class PrefixTuple // \brief Tuple of fst_id and destination state (entry in stack prefix) struct PrefixTuple { PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {} Label fst_id; StateId nextstate; }; // \class StackPrefix // \brief Container for stack prefix. class StackPrefix { public: StackPrefix() {} // copy constructor StackPrefix(const StackPrefix& x) : prefix_(x.prefix_) { } void Push(StateId fst_id, StateId nextstate) { prefix_.push_back(PrefixTuple(fst_id, nextstate)); } void Pop() { prefix_.pop_back(); } const PrefixTuple& Top() const { return prefix_[prefix_.size()-1]; } size_t Depth() const { return prefix_.size(); } public: vector<PrefixTuple> prefix_; }; // \class StackPrefixEqual // \brief Compare two stack prefix classes for equality class StackPrefixEqual { public: bool operator()(const StackPrefix& x, const StackPrefix& y) const { if (x.prefix_.size() != y.prefix_.size()) return false; for (size_t i = 0; i < x.prefix_.size(); ++i) { if (x.prefix_[i].fst_id != y.prefix_[i].fst_id || x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false; } return true; } }; // // \class StackPrefixKey // \brief Hash function for stack prefix to prefix id class StackPrefixKey { public: size_t operator()(const StackPrefix& x) const { size_t sum = 0; for (size_t i = 0; i < x.prefix_.size(); ++i) { sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0; } return sum; } }; typedef unordered_map<StackPrefix, PrefixId, StackPrefixKey, StackPrefixEqual> StackPrefixHash; // private methods private: // hash stack prefix (return unique index into stackprefix array) PrefixId 
GetPrefixId(const StackPrefix& prefix) { typename StackPrefixHash::iterator it = prefix_hash_.find(prefix); if (it == prefix_hash_.end()) { PrefixId prefix_id = stackprefix_array_.size(); stackprefix_array_.push_back(prefix); prefix_hash_[prefix] = prefix_id; return prefix_id; } else { return it->second; } } // prefix id after a stack pop PrefixId PopPrefix(StackPrefix prefix) { prefix.Pop(); return GetPrefixId(prefix); } // prefix id after a stack push PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) { prefix.Push(fst_id, nextstate); return GetPrefixId(prefix); } // private data private: // runtime options bool epsilon_on_replace_; bool always_cache_; // Optionally caching arc iterator disabled when true // state table StateTable *state_table_; // cross index of unique stack prefix // could potentially have one copy of prefix array StackPrefixHash prefix_hash_; vector<StackPrefix> stackprefix_array_; set<Label> nonterminal_set_; NonTerminalHash nonterminal_hash_; vector<const Fst<A>*> fst_array_; Label root_; void operator=(const ReplaceFstImpl<A, T> &); // disallow }; template <class A, class T> const size_t ReplaceFstImpl<A, T>::kPrime0 = 7853; // // \class ReplaceFst // \brief Recursivively replaces arcs in the root Fst with other Fsts. // This version is a delayed Fst. // // ReplaceFst supports dynamic replacement of arcs in one Fst with // another Fst. This replacement is recursive. ReplaceFst can be used // to support a variety of delayed constructions such as recursive // transition networks, union, or closure. It is constructed with an // array of Fst(s). One Fst represents the root (or topology) // machine. The root Fst refers to other Fsts by recursively replacing // arcs labeled as non-terminals with the matching non-terminal // Fst. Currently the ReplaceFst uses the output symbols of the arcs // to determine whether the arc is a non-terminal arc or not. 
// A non-terminal can be any label that is not a non-zero terminal label
// in the output alphabet.
//
// Note that the constructor uses a vector of pair<>. These correspond
// to the tuple of non-terminal Label and corresponding Fst. For example,
// to implement the closure operation we need 2 Fsts. The first root
// Fst is a single Arc on the start State that self loops; it references
// the particular machine for which we are performing the closure operation.
//
// The ReplaceFst class supports an optionally caching arc iterator:
//    ArcIterator< ReplaceFst<A> >
// The ReplaceFst needs to be built such that it is known to be ilabel
// or olabel sorted (see usage below).
//
// Observe that Matcher<Fst<A> > will use the optionally caching arc
// iterator when available (Fst is ilabel sorted and matching on the
// input, or Fst is olabel sorted and matching on the output).
// In order to obtain the most efficient behaviour, it is recommended
// to set 'epsilon_on_replace' to false (this means constructing acceptors
// as transducers with epsilons on the input side of nonterminal arcs)
// and matching on the input side.
//
// This class attaches interface to implementation and handles
// reference counting, delegating most methods to ImplToFst.
// Delayed Fst front end for ReplaceFstImpl<A, T>. All work is forwarded to
// the implementation object held through ImplToFst.
template <class A, class T = DefaultReplaceStateTable<A> >
class ReplaceFst : public ImplToFst< ReplaceFstImpl<A, T> > {
 public:
  friend class ArcIterator< ReplaceFst<A, T> >;
  friend class StateIterator< ReplaceFst<A, T> >;
  friend class ReplaceFstMatcher<A, T>;

  typedef A Arc;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef CacheState<A> State;
  typedef ReplaceFstImpl<A, T> Impl;

  using ImplToFst<Impl>::Properties;

  // Builds a ReplaceFst from label/Fst pairs using default options with
  // the given root label.
  ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
             Label root)
      : ImplToFst<Impl>(new Impl(fst_array, ReplaceFstOptions<A, T>(root))) {}

  // Builds a ReplaceFst from label/Fst pairs using the given options.
  ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
             const ReplaceFstOptions<A, T> &opts)
      : ImplToFst<Impl>(new Impl(fst_array, opts)) {}

  // See Fst<>::Copy() for doc.
  ReplaceFst(const ReplaceFst<A, T>& fst, bool safe = false)
      : ImplToFst<Impl>(fst, safe) {}

  // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
  virtual ReplaceFst<A, T> *Copy(bool safe = false) const {
    return new ReplaceFst<A, T>(*this, safe);
  }

  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;

  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    GetImpl()->InitArcIterator(s, data);
  }

  // Returns a ReplaceFstMatcher when the implementation supports
  // non-caching arc iteration and the Fst is known to be sorted on the
  // side being matched; otherwise returns 0.
  virtual MatcherBase<A> *InitMatcher(MatchType match_type) const {
    if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
        ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
         (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
      return new ReplaceFstMatcher<A, T>(*this, match_type);
    } else {
      VLOG(2) << "Not using replace matcher";
      return 0;
    }
  }

  // True if the non-terminal dependencies are cyclic (delegates to the
  // implementation).
  bool CyclicDependencies() const {
    return GetImpl()->CyclicDependencies();
  }

 private:
  // Makes visible to friends.
  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }

  // Declared for ReplaceFst<A, T> (not ReplaceFst<A>) so that this is the
  // copy-assignment operator for every state table type T.
  void operator=(const ReplaceFst<A, T> &fst);  // disallow
};

// Specialization for ReplaceFst.
template<class A, class T> class StateIterator< ReplaceFst<A, T> > : public CacheStateIterator< ReplaceFst<A, T> > { public: explicit StateIterator(const ReplaceFst<A, T> &fst) : CacheStateIterator< ReplaceFst<A, T> >(fst, fst.GetImpl()) {} private: DISALLOW_COPY_AND_ASSIGN(StateIterator); }; // Specialization for ReplaceFst. // Implements optional caching. It can be used as follows: // // ReplaceFst<A> replace; // ArcIterator< ReplaceFst<A> > aiter(replace, s); // // Note: ArcIterator< Fst<A> > is always a caching arc iterator. // aiter.SetFlags(kArcNoCache, kArcNoCache); // // Use the arc iterator, no arc will be cached, no state will be expanded. // // The varied 'kArcValueFlags' can be used to decide which part // // of arc values needs to be computed. // aiter.SetFlags(kArcILabelValue, kArcValueFlags); // // Only want the ilabel for this arc // aiter.Value(); // Does not compute the destination state. // aiter.Next(); // aiter.SetFlags(kArcNextStateValue, kArcNextStateValue); // // Want both ilabel and nextstate for that arc // aiter.Value(); // Does compute the destination state and inserts it // // in the replace state table. // // No Arc has been cached at that point. // template <class A, class T> class ArcIterator< ReplaceFst<A, T> > { public: typedef A Arc; typedef typename A::StateId StateId; ArcIterator(const ReplaceFst<A, T> &fst, StateId s) : fst_(fst), state_(s), pos_(0), offset_(0), flags_(0), arcs_(0), data_flags_(0), final_flags_(0) { cache_data_.ref_count = 0; local_data_.ref_count = 0; // If FST does not support optional caching, force caching. if(!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) && !(fst_.GetImpl()->HasArcs(state_))) fst_.GetImpl()->Expand(state_); // If state is already cached, use cached arcs array. if (fst_.GetImpl()->HasArcs(state_)) { (fst_.GetImpl())->template CacheImpl<A>::InitArcIterator(state_, &cache_data_); num_arcs_ = cache_data_.narcs; arcs_ = cache_data_.arcs; // 'arcs_' is a ptr to the cached arcs. 
data_flags_ = kArcValueFlags; // All the arc member values are valid. } else { // Otherwise delay decision until Value() is called. tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(state_); if (tuple_.fst_state == kNoStateId) { num_arcs_ = 0; } else { // The decision to cache or not to cache has been defered // until Value() or SetFlags() is called. However, the arc // iterator is set up now to be ready for non-caching in order // to keep the Value() method simple and efficient. const Fst<A>* fst = fst_.GetImpl()->GetFst(tuple_.fst_id); fst->InitArcIterator(tuple_.fst_state, &local_data_); // 'arcs_' is a pointer to the arcs in the underlying machine. arcs_ = local_data_.arcs; // Compute the final arc (but not its destination state) // if a final arc is required. bool has_final_arc = fst_.GetImpl()->ComputeFinalArc( tuple_, &final_arc_, kArcValueFlags & ~kArcNextStateValue); // Set the arc value flags that hold for 'final_arc_'. final_flags_ = kArcValueFlags & ~kArcNextStateValue; // Compute the number of arcs. num_arcs_ = local_data_.narcs; if (has_final_arc) ++num_arcs_; // Set the offset between the underlying arc positions and // the positions in the arc iterator. offset_ = num_arcs_ - local_data_.narcs; // Defers the decision to cache or not until Value() or // SetFlags() is called. data_flags_ = 0; } } } ~ArcIterator() { if (cache_data_.ref_count) --(*cache_data_.ref_count); if (local_data_.ref_count) --(*local_data_.ref_count); } void ExpandAndCache() const { // TODO(allauzen): revisit this // fst_.GetImpl()->Expand(state_, tuple_, local_data_); // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(state_, // &cache_data_); // fst_.InitArcIterator(state_, &cache_data_); // Expand and cache state. arcs_ = cache_data_.arcs; // 'arcs_' is a pointer to the cached arcs. data_flags_ = kArcValueFlags; // All the arc member values are valid. 
offset_ = 0; // No offset } void Init() { if (flags_ & kArcNoCache) { // If caching is disabled // 'arcs_' is a pointer to the arcs in the underlying machine. arcs_ = local_data_.arcs; // Set the arcs value flags that hold for 'arcs_'. data_flags_ = kArcWeightValue; if (!fst_.GetImpl()->EpsilonOnReplace()) data_flags_ |= kArcILabelValue; // Set the offset between the underlying arc positions and // the positions in the arc iterator. offset_ = num_arcs_ - local_data_.narcs; } else { // Otherwise, expand and cache ExpandAndCache(); } } bool Done() const { return pos_ >= num_arcs_; } const A& Value() const { // If 'data_flags_' was set to 0, non-caching was not requested if (!data_flags_) { // TODO(allauzen): revisit this. if (flags_ & kArcNoCache) { // Should never happen. FSTERROR() << "ReplaceFst: inconsistent arc iterator flags"; } ExpandAndCache(); // Expand and cache. } if (pos_ - offset_ >= 0) { // The requested arc is not the 'final' arc. const A& arc = arcs_[pos_ - offset_]; if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) { // If the value flags for 'arc' match the recquired value flags // then return 'arc'. return arc; } else { // Otherwise, compute the corresponding arc on-the-fly. fst_.GetImpl()->ComputeArc(tuple_, arc, &arc_, flags_ & kArcValueFlags); return arc_; } } else { // The requested arc is the 'final' arc. if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) { // If the arc value flags that hold for the final arc // do not match the requested value flags, then // 'final_arc_' needs to be updated. fst_.GetImpl()->ComputeFinalArc(tuple_, &final_arc_, flags_ & kArcValueFlags); final_flags_ = flags_ & kArcValueFlags; } return final_arc_; } } void Next() { ++pos_; } size_t Position() const { return pos_; } void Reset() { pos_ = 0; } void Seek(size_t pos) { pos_ = pos; } uint32 Flags() const { return flags_; } void SetFlags(uint32 f, uint32 mask) { // Update the flags taking into account what flags are supported // by the Fst. 
flags_ &= ~mask; flags_ |= (f & fst_.GetImpl()->ArcIteratorFlags()); // If non-caching is not requested (and caching has not already // been performed), then flush 'data_flags_' to request caching // during the next call to Value(). if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) { if (!fst_.GetImpl()->HasArcs(state_)) data_flags_ = 0; } // If 'data_flags_' has been flushed but non-caching is requested // before calling Value(), then set up the iterator for non-caching. if ((f & kArcNoCache) && (!data_flags_)) Init(); } private: const ReplaceFst<A, T> &fst_; // Reference to the FST StateId state_; // State in the FST mutable typename T::StateTuple tuple_; // Tuple corresponding to state_ ssize_t pos_; // Current position mutable ssize_t offset_; // Offset between position in iterator and in arcs_ ssize_t num_arcs_; // Number of arcs at state_ uint32 flags_; // Behavorial flags for the arc iterator mutable Arc arc_; // Memory to temporarily store computed arcs mutable ArcIteratorData<Arc> cache_data_; // Arc iterator data in cache mutable ArcIteratorData<Arc> local_data_; // Arc iterator data in local fst mutable const A* arcs_; // Array of arcs mutable uint32 data_flags_; // Arc value flags valid for data in arcs_ mutable Arc final_arc_; // Final arc (when required) mutable uint32 final_flags_; // Arc value flags valid for final_arc_ DISALLOW_COPY_AND_ASSIGN(ArcIterator); }; template <class A, class T> class ReplaceFstMatcher : public MatcherBase<A> { public: typedef A Arc; typedef typename A::StateId StateId; typedef typename A::Label Label; typedef MultiEpsMatcher<Matcher<Fst<A> > > LocalMatcher; ReplaceFstMatcher(const ReplaceFst<A, T> &fst, fst::MatchType match_type) : fst_(fst), impl_(fst_.GetImpl()), s_(fst::kNoStateId), match_type_(match_type), current_loop_(false), final_arc_(false), loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) { if (match_type_ == fst::MATCH_OUTPUT) swap(loop_.ilabel, loop_.olabel); InitMatchers(); } 
ReplaceFstMatcher(const ReplaceFstMatcher<A, T> &matcher, bool safe = false) : fst_(matcher.fst_), impl_(fst_.GetImpl()), s_(fst::kNoStateId), match_type_(matcher.match_type_), current_loop_(false), loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) { if (match_type_ == fst::MATCH_OUTPUT) swap(loop_.ilabel, loop_.olabel); InitMatchers(); } // Create a local matcher for each component Fst of replace. // LocalMatcher is a multi epsilon wrapper matcher. MultiEpsilonMatcher // is used to match each non-terminal arc, since these non-terminal // turn into epsilons on recursion. void InitMatchers() { const vector<const Fst<A>*>& fst_array = impl_->fst_array_; matcher_.resize(fst_array.size(), 0); for (size_t i = 0; i < fst_array.size(); ++i) { if (fst_array[i]) { matcher_[i] = new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList); typename set<Label>::iterator it = impl_->nonterminal_set_.begin(); for (; it != impl_->nonterminal_set_.end(); ++it) { matcher_[i]->AddMultiEpsLabel(*it); } } } } virtual ReplaceFstMatcher<A, T> *Copy(bool safe = false) const { return new ReplaceFstMatcher<A, T>(*this, safe); } virtual ~ReplaceFstMatcher() { for (size_t i = 0; i < matcher_.size(); ++i) delete matcher_[i]; } virtual MatchType Type(bool test) const { if (match_type_ == MATCH_NONE) return match_type_; uint64 true_prop = match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted; uint64 false_prop = match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted; uint64 props = fst_.Properties(true_prop | false_prop, test); if (props & true_prop) return match_type_; else if (props & false_prop) return MATCH_NONE; else return MATCH_UNKNOWN; } virtual const Fst<A> &GetFst() const { return fst_; } virtual uint64 Properties(uint64 props) const { return props; } private: // Set the sate from which our matching happens. 
virtual void SetState_(StateId s) { if (s_ == s) return; s_ = s; tuple_ = impl_->GetStateTable()->Tuple(s_); if (tuple_.fst_state == kNoStateId) { done_ = true; return; } // Get current matcher. Used for non epsilon matching current_matcher_ = matcher_[tuple_.fst_id]; current_matcher_->SetState(tuple_.fst_state); loop_.nextstate = s_; final_arc_ = false; } // Search for label, from previous set state. If label == 0, first // hallucinate and epsilon loop, else use the underlying matcher to // search for the label or epsilons. // - Note since the ReplaceFST recursion on non-terminal arcs causes // epsilon transitions to be created we use the MultiEpsilonMatcher // to search for possible matches of non terminals. // - If the component Fst reaches a final state we also need to add // the exiting final arc. virtual bool Find_(Label label) { bool found = false; label_ = label; if (label_ == 0 || label_ == kNoLabel) { // Compute loop directly, saving Replace::ComputeArc if (label_ == 0) { current_loop_ = true; found = true; } // Search for matching multi epsilons final_arc_ = impl_->ComputeFinalArc(tuple_, 0); found = current_matcher_->Find(kNoLabel) || final_arc_ || found; } else { // Search on sub machine directly using sub machine matcher. 
found = current_matcher_->Find(label_); } return found; } virtual bool Done_() const { return !current_loop_ && !final_arc_ && current_matcher_->Done(); } virtual const Arc& Value_() const { if (current_loop_) { return loop_; } if (final_arc_) { impl_->ComputeFinalArc(tuple_, &arc_); return arc_; } const Arc& component_arc = current_matcher_->Value(); impl_->ComputeArc(tuple_, component_arc, &arc_); return arc_; } virtual void Next_() { if (current_loop_) { current_loop_ = false; return; } if (final_arc_) { final_arc_ = false; return; } current_matcher_->Next(); } const ReplaceFst<A, T>& fst_; ReplaceFstImpl<A, T> *impl_; LocalMatcher* current_matcher_; vector<LocalMatcher*> matcher_; StateId s_; // Current state Label label_; // Current label MatchType match_type_; // Supplied by caller mutable bool done_; mutable bool current_loop_; // Current arc is the implicit loop mutable bool final_arc_; // Current arc for exiting recursion mutable typename T::StateTuple tuple_; // Tuple corresponding to state_ mutable Arc arc_; Arc loop_; }; template <class A, class T> inline void ReplaceFst<A, T>::InitStateIterator(StateIteratorData<A> *data) const { data->base = new StateIterator< ReplaceFst<A, T> >(*this); } typedef ReplaceFst<StdArc> StdReplaceFst; // // Recursivively replaces arcs in the root Fst with other Fsts. // This version writes the result of replacement to an output MutableFst. // // Replace supports replacement of arcs in one Fst with another // Fst. This replacement is recursive. Replace takes an array of // Fst(s). One Fst represents the root (or topology) machine. The root // Fst refers to other Fsts by recursively replacing arcs labeled as // non-terminals with the matching non-terminal Fst. Currently Replace // uses the output symbols of the arcs to determine whether the arc is // a non-terminal arc or not. A non-terminal can be any label that is // not a non-zero terminal label in the output alphabet. Note that // input argument is a vector of pair<>. 
These correspond to the tuple // of non-terminal Label and corresponding Fst. template<class Arc> void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >& ifst_array, MutableFst<Arc> *ofst, typename Arc::Label root, bool epsilon_on_replace) { ReplaceFstOptions<Arc> opts(root, epsilon_on_replace); opts.gc_limit = 0; // Cache only the last state for fastest copy. *ofst = ReplaceFst<Arc>(ifst_array, opts); } template<class Arc> void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >& ifst_array, MutableFst<Arc> *ofst, typename Arc::Label root) { Replace(ifst_array, ofst, root, false); } } // namespace fst #endif // FST_LIB_REPLACE_H__