// compose.h // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // // \file // Class to compute the composition of two FSTs #ifndef FST_LIB_COMPOSE_H__ #define FST_LIB_COMPOSE_H__ #include <algorithm> #include <ext/hash_map> using __gnu_cxx::hash_map; #include "fst/lib/cache.h" #include "fst/lib/test-properties.h" namespace fst { // Enumeration of uint64 bits used to represent the user-defined // properties of FST composition (in the template parameter to // ComposeFstOptions<T>). The bits stand for extensions of generic FST // composition. ComposeFstOptions<> (all the bits unset) is the "plain" // compose without any extra extensions. enum ComposeTypes { // RHO: flags dealing with a special "rest" symbol in the FSTs. // NB: at most one of the bits COMPOSE_FST1_RHO, COMPOSE_FST2_RHO // may be set. COMPOSE_FST1_RHO = 1ULL<<0, // "Rest" symbol on the output side of fst1. COMPOSE_FST2_RHO = 1ULL<<1, // "Rest" symbol on the input side of fst2. COMPOSE_FST1_PHI = 1ULL<<2, // "Failure" symbol on the output // side of fst1. COMPOSE_FST2_PHI = 1ULL<<3, // "Failure" symbol on the input side // of fst2. COMPOSE_FST1_SIGMA = 1ULL<<4, // "Any" symbol on the output side of // fst1. COMPOSE_FST2_SIGMA = 1ULL<<5, // "Any" symbol on the input side of // fst2. // Optimization related bits. COMPOSE_GENERIC = 1ULL<<32, // Disables optimizations, applies // the generic version of the // composition algorithm. This flag // is used for internal testing // only. 
// ----------------------------------------------------------------- // Auxiliary enum values denoting specific combinations of // bits. Internal use only. COMPOSE_RHO = COMPOSE_FST1_RHO | COMPOSE_FST2_RHO, COMPOSE_PHI = COMPOSE_FST1_PHI | COMPOSE_FST2_PHI, COMPOSE_SIGMA = COMPOSE_FST1_SIGMA | COMPOSE_FST2_SIGMA, COMPOSE_SPECIAL_SYMBOLS = COMPOSE_RHO | COMPOSE_PHI | COMPOSE_SIGMA, // ----------------------------------------------------------------- // The following bits, denoting specific optimizations, are // typically set *internally* by the composition algorithm. COMPOSE_FST1_STRING = 1ULL<<33, // fst1 is a string COMPOSE_FST2_STRING = 1ULL<<34, // fst2 is a string COMPOSE_FST1_DET = 1ULL<<35, // fst1 is deterministic COMPOSE_FST2_DET = 1ULL<<36, // fst2 is deterministic COMPOSE_INTERNAL_MASK = 0xffffffff00000000ULL }; template <uint64 T = 0ULL> struct ComposeFstOptions : public CacheOptions { explicit ComposeFstOptions(const CacheOptions &opts) : CacheOptions(opts) {} ComposeFstOptions() { } }; // Abstract base for the implementation of delayed ComposeFst. The // concrete specializations are templated on the (uint64-valued) // properties of the FSTs being composed. 
template <class A> class ComposeFstImplBase : public CacheImpl<A> { public: using FstImpl<A>::SetType; using FstImpl<A>::SetProperties; using FstImpl<A>::Properties; using FstImpl<A>::SetInputSymbols; using FstImpl<A>::SetOutputSymbols; using CacheBaseImpl< CacheState<A> >::HasStart; using CacheBaseImpl< CacheState<A> >::HasFinal; using CacheBaseImpl< CacheState<A> >::HasArcs; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef CacheState<A> State; ComposeFstImplBase(const Fst<A> &fst1, const Fst<A> &fst2, const CacheOptions &opts) :CacheImpl<A>(opts), fst1_(fst1.Copy()), fst2_(fst2.Copy()) { SetType("compose"); uint64 props1 = fst1.Properties(kFstProperties, false); uint64 props2 = fst2.Properties(kFstProperties, false); SetProperties(ComposeProperties(props1, props2), kCopyProperties); if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols())) LOG(FATAL) << "ComposeFst: output symbol table of 1st argument " << "does not match input symbol table of 2nd argument"; SetInputSymbols(fst1.InputSymbols()); SetOutputSymbols(fst2.OutputSymbols()); } virtual ~ComposeFstImplBase() { delete fst1_; delete fst2_; } StateId Start() { if (!HasStart()) { StateId start = ComputeStart(); if (start != kNoStateId) { SetStart(start); } } return CacheImpl<A>::Start(); } Weight Final(StateId s) { if (!HasFinal(s)) { Weight final = ComputeFinal(s); SetFinal(s, final); } return CacheImpl<A>::Final(s); } virtual void Expand(StateId s) = 0; size_t NumArcs(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl<A>::NumArcs(s); } size_t NumInputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl<A>::NumInputEpsilons(s); } size_t NumOutputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl<A>::NumOutputEpsilons(s); } void InitArcIterator(StateId s, ArcIteratorData<A> *data) { if (!HasArcs(s)) Expand(s); CacheImpl<A>::InitArcIterator(s, data); } // Access to flags encoding compose 
options/optimizations etc. (for // debugging). virtual uint64 ComposeFlags() const = 0; protected: virtual StateId ComputeStart() = 0; virtual Weight ComputeFinal(StateId s) = 0; const Fst<A> *fst1_; // first input Fst const Fst<A> *fst2_; // second input Fst }; // The following class encapsulates implementation-dependent details // of state tuple lookup, i.e. a bijective mapping from triples of two // FST states and an epsilon filter state to the corresponding state // IDs of the fst resulting from composition. The mapping must // implement the [] operator in the style of STL associative // containers (map, hash_map), i.e. table[x] must return a reference // to the value associated with x. If x is an unassigned tuple, the // operator must automatically associate x with value 0. // // NB: "table[x] == 0" for unassigned tuples x is required by the // following off-by-one device used in the implementation of // ComposeFstImpl. The value stored in the table is equal to tuple ID // plus one, i.e. it is always a strictly positive number. Therefore, // table[x] is equal to 0 if and only if x is an unassigned tuple (in // which the algorithm assigns a new ID to x, and sets table[x] - // stored in a reference - to "new ID + 1"). This form of lookup is // more efficient than calling "find(x)" and "insert(make_pair(x, new // ID))" if x is an unassigned tuple. // // The generic implementation is a wrapper around a hash_map. template <class A, uint64 T> class ComposeStateTable { public: typedef typename A::StateId StateId; struct StateTuple { StateTuple() {} StateTuple(StateId s1, StateId s2, int f) : state_id1(s1), state_id2(s2), filt(f) {} StateId state_id1; // state Id on fst1 StateId state_id2; // state Id on fst2 int filt; // epsilon filter state }; ComposeStateTable() { StateTuple empty_tuple(kNoStateId, kNoStateId, 0); } // NB: if 'tuple' is not in 'table_', the pair (tuple, StateId()) is // inserted into 'table_' (standard STL container semantics). 
Since // StateId is a built-in type, the explicit default constructor call // StateId() returns 0. StateId &operator[](const StateTuple &tuple) { return table_[tuple]; } private: // Comparison object for hashing StateTuple(s). class StateTupleEqual { public: bool operator()(const StateTuple& x, const StateTuple& y) const { return x.state_id1 == y.state_id1 && x.state_id2 == y.state_id2 && x.filt == y.filt; } }; static const int kPrime0 = 7853; static const int kPrime1 = 7867; // Hash function for StateTuple to Fst states. class StateTupleKey { public: size_t operator()(const StateTuple& x) const { return static_cast<size_t>(x.state_id1 + x.state_id2 * kPrime0 + x.filt * kPrime1); } }; // Lookup table mapping state tuples to state IDs. typedef hash_map<StateTuple, StateId, StateTupleKey, StateTupleEqual> StateTable; // Actual table data. StateTable table_; DISALLOW_EVIL_CONSTRUCTORS(ComposeStateTable); }; // State tuple lookup table for the composition of a string FST with a // deterministic FST. The class maps state tuples to their unique IDs // (i.e. states of the ComposeFst). Main optimization: due to the // 1-to-1 correspondence between the states of the input string FST // and those of the resulting (string) FST, a state tuple (s1, s2) is // simply mapped to StateId s1. Hence, we use an STL vector as a // lookup table. Template argument Fst1IsString specifies which FST is // a string (this determines whether or not we index the lookup table // by the first or by the second state). 
template <class A, bool Fst1IsString> class StringDetComposeStateTable { public: typedef typename A::StateId StateId; struct StateTuple { typedef typename A::StateId StateId; StateTuple() {} StateTuple(StateId s1, StateId s2, int /* f */) : state_id1(s1), state_id2(s2) {} StateId state_id1; // state Id on fst1 StateId state_id2; // state Id on fst2 static const int filt = 0; // 'fake' epsilon filter - only needed // for API compatibility }; StringDetComposeStateTable() {} // Subscript operator. Behaves in a way similar to its map/hash_map // counterpart, i.e. returns a reference to the value associated // with 'tuple', inserting a 0 value if 'tuple' is unassigned. StateId &operator[](const StateTuple &tuple) { StateId index = Fst1IsString ? tuple.state_id1 : tuple.state_id2; if (index >= (StateId)data_.size()) { // NB: all values in [old_size; index] are initialized to 0. data_.resize(index + 1); } return data_[index]; } private: vector<StateId> data_; DISALLOW_EVIL_CONSTRUCTORS(StringDetComposeStateTable); }; // Specializations of ComposeStateTable for the string/det case. // Both inherit from StringDetComposeStateTable. template <class A> class ComposeStateTable<A, COMPOSE_FST1_STRING | COMPOSE_FST2_DET> : public StringDetComposeStateTable<A, true> { }; template <class A> class ComposeStateTable<A, COMPOSE_FST2_STRING | COMPOSE_FST1_DET> : public StringDetComposeStateTable<A, false> { }; // Parameterized implementation of FST composition for a pair of FSTs // matching the property bit vector T. If possible, // instantiation-specific switches in the code are based on the values // of the bits in T, which are known at compile time, so unused code // should be optimized away by the compiler. 
template <class A, uint64 T>
class ComposeFstImpl : public ComposeFstImplBase<A> {
  typedef typename A::StateId StateId;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  using FstImpl<A>::SetType;
  using FstImpl<A>::SetProperties;

  enum FindType { FIND_INPUT = 1,     // find input label on fst2
                  FIND_OUTPUT = 2,    // find output label on fst1
                  FIND_BOTH = 3 };    // find choice state dependent

  // The state table variant is selected by the internal (high) bits
  // of T only.
  typedef ComposeStateTable<A, T & COMPOSE_INTERNAL_MASK> StateTupleTable;
  typedef typename StateTupleTable::StateTuple StateTuple;

 public:
  // Checks the label-sortedness required by the options in T (aborting
  // via LOG(FATAL) when it is not met) and chooses on which side
  // labels will be searched.
  ComposeFstImpl(const Fst<A> &fst1,
                 const Fst<A> &fst2,
                 const CacheOptions &opts)
      : ComposeFstImplBase<A>(fst1, fst2, opts) {

    bool osorted = fst1.Properties(kOLabelSorted, false);
    bool isorted = fst2.Properties(kILabelSorted, false);

    switch (T & COMPOSE_SPECIAL_SYMBOLS) {
      case COMPOSE_FST1_RHO:
      case COMPOSE_FST1_PHI:
      case COMPOSE_FST1_SIGMA:
        // Special symbols on fst1: fst1 must be output label sorted.
        if (!osorted || FLAGS_fst_verify_properties)
          osorted = fst1.Properties(kOLabelSorted, true);
        if (!osorted)
          LOG(FATAL) << "ComposeFst: 1st argument not output label "
                     << "sorted (special symbols present)";
        break;
      case COMPOSE_FST2_RHO:
      case COMPOSE_FST2_PHI:
      case COMPOSE_FST2_SIGMA:
        // Special symbols on fst2: fst2 must be input label sorted.
        if (!isorted || FLAGS_fst_verify_properties)
          isorted = fst2.Properties(kILabelSorted, true);
        if (!isorted)
          LOG(FATAL) << "ComposeFst: 2nd argument not input label "
                     << "sorted (special symbols present)";
        break;
      case 0:
        // No special symbols: force a property test only if neither
        // side is known to be sorted (or verification is requested).
        if ((!isorted && !osorted) || FLAGS_fst_verify_properties) {
          osorted = fst1.Properties(kOLabelSorted, true);
          if (!osorted)
            isorted = fst2.Properties(kILabelSorted, true);
        }
        break;
      default:
        LOG(FATAL)
          << "ComposeFst: More than one special symbol used in composition";
    }

    // Special symbols take precedence (sigma, then phi, then rho), then
    // the string optimizations, then plain sortedness.
    if (isorted && (T & COMPOSE_FST2_SIGMA)) {
      find_type_ = FIND_INPUT;
    } else if (osorted && (T & COMPOSE_FST1_SIGMA)) {
      find_type_ = FIND_OUTPUT;
    } else if (isorted && (T & COMPOSE_FST2_PHI)) {
      find_type_ = FIND_INPUT;
    } else if (osorted && (T & COMPOSE_FST1_PHI)) {
      find_type_ = FIND_OUTPUT;
    } else if (isorted && (T & COMPOSE_FST2_RHO)) {
      find_type_ = FIND_INPUT;
    } else if (osorted && (T & COMPOSE_FST1_RHO)) {
      find_type_ = FIND_OUTPUT;
    } else if (isorted && (T & COMPOSE_FST1_STRING)) {
      find_type_ = FIND_INPUT;
    } else if (osorted && (T & COMPOSE_FST2_STRING)) {
      find_type_ = FIND_OUTPUT;
    } else if (isorted && osorted) {
      find_type_ = FIND_BOTH;
    } else if (isorted) {
      find_type_ = FIND_INPUT;
    } else if (osorted) {
      find_type_ = FIND_OUTPUT;
    } else {
      LOG(FATAL) << "ComposeFst: 1st argument not output label sorted "
                 << "and 2nd argument is not input label sorted";
    }
  }

  // Finds/creates an Fst state given a StateTuple.  Only creates a new
  // state if StateTuple is not found in the state hash.
  //
  // The method exploits the following device: all pairs stored in the
  // associative container state_tuple_table_ are of the form (tuple,
  // id(tuple) + 1), i.e. state_tuple_table_[tuple] > 0 if tuple has
  // been stored previously. For unassigned tuples, the call to
  // state_tuple_table_[tuple] creates a new pair (tuple, 0). As a
  // result, state_tuple_table_[tuple] == 0 iff tuple is new.
  StateId FindState(const StateTuple& tuple) {
    StateId &assoc_value = state_tuple_table_[tuple];
    if (assoc_value == 0) {  // tuple wasn't present in lookup table:
                             // assign it a new ID.
      state_tuples_.push_back(tuple);
      assoc_value = state_tuples_.size();
    }
    return assoc_value - 1;  // NB: assoc_value = ID + 1
  }

  // Generates arc for composition state s from matched input Fst arcs.
  // When find_input is true, 'arca' is an arc of fst2 and 'arcb' an arc
  // of fst1 (see Expand()); otherwise 'arca' belongs to fst1 and 'arcb'
  // to fst2. 'f' is the epsilon filter state of the destination tuple.
  void AddArc(StateId s, const A &arca, const A &arcb, int f,
              bool find_input) {
    A arc;
    if (find_input) {
      arc.ilabel = arcb.ilabel;
      arc.olabel = arca.olabel;
      arc.weight = Times(arcb.weight, arca.weight);
      StateTuple tuple(arcb.nextstate, arca.nextstate, f);
      arc.nextstate = FindState(tuple);
    } else {
      arc.ilabel = arca.ilabel;
      arc.olabel = arcb.olabel;
      arc.weight = Times(arca.weight, arcb.weight);
      StateTuple tuple(arca.nextstate, arcb.nextstate, f);
      arc.nextstate = FindState(tuple);
    }
    CacheImpl<A>::AddArc(s, arc);
  }

  // Arranges it so that the first arg to OrderedExpand is the Fst
  // that will be passed to FindLabel.
  void Expand(StateId s) {
    // The tuple fields are copied out immediately: expanding may append
    // to state_tuples_ (via FindState) and invalidate this reference.
    StateTuple &tuple = state_tuples_[s];
    StateId s1 = tuple.state_id1;
    StateId s2 = tuple.state_id2;
    int f = tuple.filt;
    if (find_type_ == FIND_INPUT)
      OrderedExpand(s, ComposeFstImplBase<A>::fst2_, s2,
                    ComposeFstImplBase<A>::fst1_, s1, f, true);
    else
      OrderedExpand(s, ComposeFstImplBase<A>::fst1_, s1,
                    ComposeFstImplBase<A>::fst2_, s2, f, false);
  }

  // Access to flags encoding compose options/optimizations etc. (for
  // debugging).
  virtual uint64 ComposeFlags() const { return T; }

 private:
  // This does that actual matching of labels in the composition. The
  // arguments are ordered so FindLabel is called with state SA of
  // FSTA for each arc leaving state SB of FSTB. The FIND_INPUT arg
  // determines whether the input or output label of arcs at SB is
  // the one to match on.
  void OrderedExpand(StateId s, const Fst<A> *fsta, StateId sa,
                     const Fst<A> *fstb, StateId sb, int f, bool find_input) {

    size_t numarcsa = fsta->NumArcs(sa);
    size_t numepsa = find_input ? fsta->NumInputEpsilons(sa) :
                     fsta->NumOutputEpsilons(sa);
    bool finala = fsta->Final(sa) != Weight::Zero();
    ArcIterator< Fst<A> > aitera(*fsta, sa);
    // First handle special epsilons and sigmas on FSTA
    for (; !aitera.Done(); aitera.Next()) {
      const A &arca = aitera.Value();
      Label match_labela = find_input ? arca.ilabel : arca.olabel;
      if (match_labela > 0) {
        break;
      }
      if ((T & COMPOSE_SIGMA) != 0 && match_labela == kSigmaLabel) {
        // Found a sigma? Match it against all (non-special) symbols
        // on side b.
        for (ArcIterator< Fst<A> > aiterb(*fstb, sb);
             !aiterb.Done();
             aiterb.Next()) {
          const A &arcb = aiterb.Value();
          Label labelb = find_input ? arcb.olabel : arcb.ilabel;
          if (labelb <= 0) continue;
          AddArc(s, arca, arcb, 0, find_input);
        }
      } else if (f == 0 && match_labela == 0) {
        A earcb(0, 0, Weight::One(), sb);
        AddArc(s, arca, earcb, 0, find_input);  // move forward on epsilon
      }
    }
    // Next handle non-epsilon matches, rho labels, and epsilons on FSTB
    for (ArcIterator< Fst<A> > aiterb(*fstb, sb);
         !aiterb.Done();
         aiterb.Next()) {
      const A &arcb = aiterb.Value();
      Label match_labelb = find_input ? arcb.olabel : arcb.ilabel;
      if (match_labelb) {  // Consider non-epsilon match
        if (FindLabel(&aitera, numarcsa, match_labelb, find_input)) {
          for (; !aitera.Done(); aitera.Next()) {
            const A &arca = aitera.Value();
            Label match_labela = find_input ? arca.ilabel : arca.olabel;
            if (match_labela != match_labelb)
              break;
            AddArc(s, arca, arcb, 0, find_input);  // move forward on match
          }
        } else if ((T & COMPOSE_SPECIAL_SYMBOLS) != 0) {
          // If there is no transition labelled 'match_labelb' in
          // fsta, try matching 'match_labelb' against special symbols
          // (Phi, Rho,...).
          for (aitera.Reset(); !aitera.Done(); aitera.Next()) {
            A arca = aitera.Value();
            Label labela = find_input ? arca.ilabel : arca.olabel;
            if (labela >= 0) {
              break;
            } else if (((T & COMPOSE_PHI) != 0) && (labela == kPhiLabel)) {
              // Case 1: if a failure transition exists, follow its
              // transitive closure until a) a transition labelled
              // 'match_labelb' is found, or b) the initial state of
              // fsta is reached.

              StateId sf = sa;  // Start of current failure transition.
              while (labela == kPhiLabel && sf != arca.nextstate) {
                sf = arca.nextstate;

                size_t numarcsf = fsta->NumArcs(sf);
                ArcIterator< Fst<A> > aiterf(*fsta, sf);
                if (FindLabel(&aiterf, numarcsf, match_labelb, find_input)) {
                  // Sub-case 1a: there exists a transition starting
                  // in sf and consuming symbol 'match_labelb'.
                  AddArc(s, aiterf.Value(), arcb, 0, find_input);
                  break;
                } else {
                  // No transition labelled 'match_labelb' found: try
                  // next failure transition (starting at 'sf').
                  for (aiterf.Reset(); !aiterf.Done(); aiterf.Next()) {
                    arca = aiterf.Value();
                    labela = find_input ? arca.ilabel : arca.olabel;
                    if (labela >= kPhiLabel) break;
                  }
                }
              }
              if (labela == kPhiLabel && sf == arca.nextstate) {
                // Sub-case 1b: failure transitions lead to start
                // state without finding a matching
                // transition. Therefore, we generate a loop in start
                // state of fsta.
                A loop(match_labelb, match_labelb, Weight::One(), sf);
                AddArc(s, loop, arcb, 0, find_input);
              }
            } else if (((T & COMPOSE_RHO) != 0) && (labela == kRhoLabel)) {
              // Case 2: 'match_labelb' can be matched against a
              // "rest" (rho) label in fsta.
              if (find_input) {
                arca.ilabel = match_labelb;
                if (arca.olabel == kRhoLabel)
                  arca.olabel = match_labelb;
              } else {
                arca.olabel = match_labelb;
                if (arca.ilabel == kRhoLabel)
                  arca.ilabel = match_labelb;
              }
              AddArc(s, arca, arcb, 0, find_input);  // move fwd on match
            }
          }
        }
      } else if (numepsa != numarcsa || finala) {  // Handle FSTB epsilon
        A earca(0, 0, Weight::One(), sa);
        AddArc(s, earca, arcb, numepsa > 0, find_input);  // move on epsilon
      }
    }
    // SetArcs lives in a dependent base class, so the call must be made
    // through 'this->' to be found by two-phase name lookup.
    this->SetArcs(s);
  }


  // Finds matches to MATCH_LABEL in arcs given by AITER
  // using FIND_INPUT to determine whether to look on input or output.
  // Returns true and leaves AITER on the first matching arc if a match
  // exists; returns false otherwise (AITER's position is then wherever
  // the search stopped).
  bool FindLabel(ArcIterator< Fst<A> > *aiter, size_t numarcs,
                 Label match_label, bool find_input) {
    // binary search for match
    size_t low = 0;
    size_t high = numarcs;
    while (low < high) {
      size_t mid = (low + high) / 2;
      aiter->Seek(mid);
      Label label = find_input ?
                    aiter->Value().ilabel : aiter->Value().olabel;
      if (label > match_label) {
        high = mid;
      } else if (label < match_label) {
        low = mid + 1;
      } else {
        // find first matching label (when non-determinism)
        for (size_t i = mid; i > low; --i) {
          aiter->Seek(i - 1);
          label = find_input ? aiter->Value().ilabel : aiter->Value().olabel;
          if (label != match_label) {
            aiter->Seek(i);
            return true;
          }
        }
        return true;
      }
    }
    return false;
  }

  // Returns the state for the pair of start states (filter state 0),
  // or kNoStateId if either operand has no start state.
  StateId ComputeStart() {
    StateId s1 = ComposeFstImplBase<A>::fst1_->Start();
    StateId s2 = ComposeFstImplBase<A>::fst2_->Start();
    if (s1 == kNoStateId || s2 == kNoStateId)
      return kNoStateId;
    StateTuple tuple(s1, s2, 0);
    return FindState(tuple);
  }

  // The final weight of a composed state is the product of the final
  // weights of its two component states.
  Weight ComputeFinal(StateId s) {
    StateTuple &tuple = state_tuples_[s];
    Weight final = Times(ComposeFstImplBase<A>::fst1_->Final(tuple.state_id1),
                         ComposeFstImplBase<A>::fst2_->Final(tuple.state_id2));
    return final;
  }


  FindType find_type_;            // find label on which side?

  // Maps from StateId to StateTuple.
  vector<StateTuple> state_tuples_;

  // Maps from StateTuple to StateId.
  StateTupleTable state_tuple_table_;

  DISALLOW_EVIL_CONSTRUCTORS(ComposeFstImpl);
};


// Computes the composition of two transducers. This version is a
// delayed Fst. If FST1 transduces string x to y with weight a and FST2
// transduces y to z with weight b, then their composition transduces
// string x to z with weight Times(a, b).
//
// The output labels of the first transducer or the input labels of
// the second transducer must be sorted.  The weights need to form a
// commutative semiring (valid for TropicalWeight and LogWeight).
//
// Complexity:
// Assuming the first FST is unsorted and the second is sorted:
// - Time: O(v1 v2 d1 (log d2 + m2)),
// - Space: O(v1 v2)
// where vi = # of states visited, di = maximum out-degree, and mi the
// maximum multiplicity of the states visited for the ith
// FST. Constant time and space to visit an input state or arc is
// assumed and exclusive of caching.
// // Caveats: // - ComposeFst does not trim its output (since it is a delayed operation). // - The efficiency of composition can be strongly affected by several factors: // - the choice of which tnansducer is sorted - prefer sorting the FST // that has the greater average out-degree. // - the amount of non-determinism // - the presence and location of epsilon transitions - avoid epsilon // transitions on the output side of the first transducer or // the input side of the second transducer or prefer placing // them later in a path since they delay matching and can // introduce non-coaccessible states and transitions. template <class A> class ComposeFst : public Fst<A> { public: friend class ArcIterator< ComposeFst<A> >; friend class CacheStateIterator< ComposeFst<A> >; friend class CacheArcIterator< ComposeFst<A> >; typedef A Arc; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef CacheState<A> State; ComposeFst(const Fst<A> &fst1, const Fst<A> &fst2) : impl_(Init(fst1, fst2, ComposeFstOptions<>())) { } template <uint64 T> ComposeFst(const Fst<A> &fst1, const Fst<A> &fst2, const ComposeFstOptions<T> &opts) : impl_(Init(fst1, fst2, opts)) { } ComposeFst(const ComposeFst<A> &fst) : Fst<A>(fst), impl_(fst.impl_) { impl_->IncrRefCount(); } virtual ~ComposeFst() { if (!impl_->DecrRefCount()) delete impl_; } virtual StateId Start() const { return impl_->Start(); } virtual Weight Final(StateId s) const { return impl_->Final(s); } virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } virtual size_t NumInputEpsilons(StateId s) const { return impl_->NumInputEpsilons(s); } virtual size_t NumOutputEpsilons(StateId s) const { return impl_->NumOutputEpsilons(s); } virtual uint64 Properties(uint64 mask, bool test) const { if (test) { uint64 known, test = TestProperties(*this, mask, &known); impl_->SetProperties(test, known); return test & mask; } else { return impl_->Properties(mask); } } virtual const string& Type() const { return 
impl_->Type(); } virtual ComposeFst<A> *Copy() const { return new ComposeFst<A>(*this); } virtual const SymbolTable* InputSymbols() const { return impl_->InputSymbols(); } virtual const SymbolTable* OutputSymbols() const { return impl_->OutputSymbols(); } virtual inline void InitStateIterator(StateIteratorData<A> *data) const; virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { impl_->InitArcIterator(s, data); } // Access to flags encoding compose options/optimizations etc. (for // debugging). uint64 ComposeFlags() const { return impl_->ComposeFlags(); } protected: ComposeFstImplBase<A> *Impl() { return impl_; } private: ComposeFstImplBase<A> *impl_; // Auxiliary method encapsulating the creation of a ComposeFst // implementation that is appropriate for the properties of fst1 and // fst2. template <uint64 T> static ComposeFstImplBase<A> *Init( const Fst<A> &fst1, const Fst<A> &fst2, const ComposeFstOptions<T> &opts) { // Filter for sort properties (forces a property check). uint64 sort_props_mask = kILabelSorted | kOLabelSorted; // Filter for optimization-related properties (does not force a // property-check). uint64 opt_props_mask = kString | kIDeterministic | kODeterministic | kNoIEpsilons | kNoOEpsilons; uint64 props1 = fst1.Properties(sort_props_mask, true); uint64 props2 = fst2.Properties(sort_props_mask, true); props1 |= fst1.Properties(opt_props_mask, false); props2 |= fst2.Properties(opt_props_mask, false); if (!(Weight::Properties() & kCommutative)) { props1 |= fst1.Properties(kUnweighted, true); props2 |= fst2.Properties(kUnweighted, true); if (!(props1 & kUnweighted) && !(props2 & kUnweighted)) LOG(FATAL) << "ComposeFst: Weight needs to be a commutative semiring: " << Weight::Type(); } // Case 1: flag COMPOSE_GENERIC disables optimizations. 
if (T & COMPOSE_GENERIC) { return new ComposeFstImpl<A, T>(fst1, fst2, opts); } const uint64 kStringDetOptProps = kIDeterministic | kILabelSorted | kNoIEpsilons; const uint64 kDetStringOptProps = kODeterministic | kOLabelSorted | kNoOEpsilons; // Case 2: fst1 is a string, fst2 is deterministic and epsilon-free. if ((props1 & kString) && !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) && ((props2 & kStringDetOptProps) == kStringDetOptProps)) { return new ComposeFstImpl<A, T | COMPOSE_FST1_STRING | COMPOSE_FST2_DET>( fst1, fst2, opts); } // Case 3: fst1 is deterministic and epsilon-free, fst2 is string. if ((props2 & kString) && !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) && ((props1 & kDetStringOptProps) == kDetStringOptProps)) { return new ComposeFstImpl<A, T | COMPOSE_FST2_STRING | COMPOSE_FST1_DET>( fst1, fst2, opts); } // Default case: no optimizations. return new ComposeFstImpl<A, T>(fst1, fst2, opts); } void operator=(const ComposeFst<A> &fst); // disallow }; // Specialization for ComposeFst. template<class A> class StateIterator< ComposeFst<A> > : public CacheStateIterator< ComposeFst<A> > { public: explicit StateIterator(const ComposeFst<A> &fst) : CacheStateIterator< ComposeFst<A> >(fst) {} }; // Specialization for ComposeFst. template <class A> class ArcIterator< ComposeFst<A> > : public CacheArcIterator< ComposeFst<A> > { public: typedef typename A::StateId StateId; ArcIterator(const ComposeFst<A> &fst, StateId s) : CacheArcIterator< ComposeFst<A> >(fst, s) { if (!fst.impl_->HasArcs(s)) fst.impl_->Expand(s); } private: DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); }; template <class A> inline void ComposeFst<A>::InitStateIterator(StateIteratorData<A> *data) const { data->base = new StateIterator< ComposeFst<A> >(*this); } // Useful alias when using StdArc. 
typedef ComposeFst<StdArc> StdComposeFst; struct ComposeOptions { bool connect; // Connect output ComposeOptions(bool c) : connect(c) {} ComposeOptions() : connect(true) { } }; // Computes the composition of two transducers. This version writes // the composed FST into a MurableFst. If FST1 transduces string x to // y with weight a and FST2 transduces y to z with weight b, then // their composition transduces string x to z with weight // Times(x, z). // // The output labels of the first transducer or the input labels of // the second transducer must be sorted. The weights need to form a // commutative semiring (valid for TropicalWeight and LogWeight). // // Complexity: // Assuming the first FST is unsorted and the second is sorted: // - Time: O(V1 V2 D1 (log D2 + M2)), // - Space: O(V1 V2 D1 M2) // where Vi = # of states, Di = maximum out-degree, and Mi is // the maximum multiplicity for the ith FST. // // Caveats: // - Compose trims its output. // - The efficiency of composition can be strongly affected by several factors: // - the choice of which tnansducer is sorted - prefer sorting the FST // that has the greater average out-degree. // - the amount of non-determinism // - the presence and location of epsilon transitions - avoid epsilon // transitions on the output side of the first transducer or // the input side of the second transducer or prefer placing // them later in a path since they delay matching and can // introduce non-coaccessible states and transitions. template<class Arc> void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2, MutableFst<Arc> *ofst, const ComposeOptions &opts = ComposeOptions()) { ComposeFstOptions<> nopts; nopts.gc_limit = 0; // Cache only the last state for fastest copy. *ofst = ComposeFst<Arc>(ifst1, ifst2, nopts); if (opts.connect) Connect(ofst); } } // namespace fst #endif // FST_LIB_COMPOSE_H__