// factor-weight.h // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Author: allauzen@cs.nyu.edu (Cyril Allauzen) // // \file // Classes to factor weights in an FST. #ifndef FST_LIB_FACTOR_WEIGHT_H__ #define FST_LIB_FACTOR_WEIGHT_H__ #include <algorithm> #include <ext/hash_map> using __gnu_cxx::hash_map; #include <ext/slist> using __gnu_cxx::slist; #include "fst/lib/cache.h" #include "fst/lib/test-properties.h" namespace fst { struct FactorWeightOptions : CacheOptions { float delta; bool final_only; // only factor final weights when true FactorWeightOptions(const CacheOptions &opts, float d, bool of) : CacheOptions(opts), delta(d), final_only(of) {} explicit FactorWeightOptions(float d, bool of = false) : delta(d), final_only(of) {} FactorWeightOptions(bool of = false) : delta(kDelta), final_only(of) {} }; // A factor iterator takes as argument a weight w and returns a // sequence of pairs of weights (xi,yi) such that the sum of the // products xi times yi is equal to w. If w is fully factored, // the iterator should return nothing. // // template <class W> // class FactorIterator { // public: // FactorIterator(W w); // bool Done() const; // void Next(); // pair<W, W> Value() const; // void Reset(); // } // Factor trivially. template <class W> class IdentityFactor { public: IdentityFactor(const W &w) {} bool Done() const { return true; } void Next() {} pair<W, W> Value() const { return make_pair(W::One(), W::One()); } // unused void Reset() {} }; // Factor a StringWeight w as 'ab' where 'a' is a label. template <typename L, StringType S = STRING_LEFT> class StringFactor { public: StringFactor(const StringWeight<L, S> &w) : weight_(w), done_(w.Size() <= 1) {} bool Done() const { return done_; } void Next() { done_ = true; } pair< StringWeight<L, S>, StringWeight<L, S> > Value() const { StringWeightIterator<L, S> iter(weight_); StringWeight<L, S> w1(iter.Value()); StringWeight<L, S> w2; for (iter.Next(); !iter.Done(); iter.Next()) w2.PushBack(iter.Value()); return make_pair(w1, w2); } void Reset() { done_ = weight_.Size() <= 1; } private: StringWeight<L, S> weight_; bool done_; }; // Factor a GallicWeight using StringFactor. template <class L, class W, StringType S = STRING_LEFT> class GallicFactor { public: GallicFactor(const GallicWeight<L, W, S> &w) : weight_(w), done_(w.Value1().Size() <= 1) {} bool Done() const { return done_; } void Next() { done_ = true; } pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const { StringFactor<L, S> iter(weight_.Value1()); GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2()); GallicWeight<L, W, S> w2(iter.Value().second, W::One()); return make_pair(w1, w2); } void Reset() { done_ = weight_.Value1().Size() <= 1; } private: GallicWeight<L, W, S> weight_; bool done_; }; // Implementation class for FactorWeight template <class A, class F> class FactorWeightFstImpl : public CacheImpl<A> { public: using FstImpl<A>::SetType; using FstImpl<A>::SetProperties; using FstImpl<A>::Properties; using FstImpl<A>::SetInputSymbols; using FstImpl<A>::SetOutputSymbols; using CacheBaseImpl< CacheState<A> >::HasStart; using CacheBaseImpl< CacheState<A> >::HasFinal; using CacheBaseImpl< CacheState<A> >::HasArcs; typedef A Arc; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef F FactorIterator; struct Element { Element() {} Element(StateId s, Weight w) : state(s), weight(w) {} StateId state; // Input state Id Weight weight; // Residual weight }; FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions &opts) : CacheImpl<A>(opts), fst_(fst.Copy()), delta_(opts.delta), final_only_(opts.final_only) { SetType("factor-weight"); uint64 props = fst.Properties(kFstProperties, false); SetProperties(FactorWeightProperties(props), kCopyProperties); SetInputSymbols(fst.InputSymbols()); SetOutputSymbols(fst.OutputSymbols()); } ~FactorWeightFstImpl() { delete fst_; } StateId Start() { if (!HasStart()) { StateId s = fst_->Start(); if (s == kNoStateId) return kNoStateId; StateId start = FindState(Element(fst_->Start(), Weight::One())); SetStart(start); } return CacheImpl<A>::Start(); } Weight Final(StateId s) { if (!HasFinal(s)) { const Element &e = elements_[s]; // TODO: fix so cast is unnecessary Weight w = e.state == kNoStateId ? e.weight : (Weight) Times(e.weight, fst_->Final(e.state)); FactorIterator f(w); if (w != Weight::Zero() && f.Done()) SetFinal(s, w); else SetFinal(s, Weight::Zero()); } return CacheImpl<A>::Final(s); } size_t NumArcs(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl<A>::NumArcs(s); } size_t NumInputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl<A>::NumInputEpsilons(s); } size_t NumOutputEpsilons(StateId s) { if (!HasArcs(s)) Expand(s); return CacheImpl<A>::NumOutputEpsilons(s); } void InitArcIterator(StateId s, ArcIteratorData<A> *data) { if (!HasArcs(s)) Expand(s); CacheImpl<A>::InitArcIterator(s, data); } // Find state corresponding to an element. Create new state // if element not found. StateId FindState(const Element &e) { if (final_only_ && e.weight == Weight::One()) { while (unfactored_.size() <= (unsigned int)e.state) unfactored_.push_back(kNoStateId); if (unfactored_[e.state] == kNoStateId) { unfactored_[e.state] = elements_.size(); elements_.push_back(e); } return unfactored_[e.state]; } else { typename ElementMap::iterator eit = element_map_.find(e); if (eit != element_map_.end()) { return (*eit).second; } else { StateId s = elements_.size(); elements_.push_back(e); element_map_.insert(pair<const Element, StateId>(e, s)); return s; } } } // Computes the outgoing transitions from a state, creating new destination // states as needed. void Expand(StateId s) { Element e = elements_[s]; if (e.state != kNoStateId) { for (ArcIterator< Fst<A> > ait(*fst_, e.state); !ait.Done(); ait.Next()) { const A &arc = ait.Value(); Weight w = Times(e.weight, arc.weight); FactorIterator fit(w); if (final_only_ || fit.Done()) { StateId d = FindState(Element(arc.nextstate, Weight::One())); AddArc(s, Arc(arc.ilabel, arc.olabel, w, d)); } else { for (; !fit.Done(); fit.Next()) { const pair<Weight, Weight> &p = fit.Value(); StateId d = FindState(Element(arc.nextstate, p.second.Quantize(delta_))); AddArc(s, Arc(arc.ilabel, arc.olabel, p.first, d)); } } } } if ((e.state == kNoStateId) || (fst_->Final(e.state) != Weight::Zero())) { Weight w = e.state == kNoStateId ? e.weight : Times(e.weight, fst_->Final(e.state)); for (FactorIterator fit(w); !fit.Done(); fit.Next()) { const pair<Weight, Weight> &p = fit.Value(); StateId d = FindState(Element(kNoStateId, p.second.Quantize(delta_))); AddArc(s, Arc(0, 0, p.first, d)); } } SetArcs(s); } private: // Equality function for Elements, assume weights have been quantized. class ElementEqual { public: bool operator()(const Element &x, const Element &y) const { return x.state == y.state && x.weight == y.weight; } }; // Hash function for Elements to Fst states. class ElementKey { public: size_t operator()(const Element &x) const { return static_cast<size_t>(x.state * kPrime + x.weight.Hash()); } private: static const int kPrime = 7853; }; typedef hash_map<Element, StateId, ElementKey, ElementEqual> ElementMap; const Fst<A> *fst_; float delta_; bool final_only_; vector<Element> elements_; // mapping Fst state to Elements ElementMap element_map_; // mapping Elements to Fst state // mapping between old/new 'StateId' for states that do not need to // be factored when 'final_only_' is true vector<StateId> unfactored_; DISALLOW_EVIL_CONSTRUCTORS(FactorWeightFstImpl); }; // FactorWeightFst takes as template parameter a FactorIterator as // defined above. The result of weight factoring is a transducer // equivalent to the input whose path weights have been factored // according to the FactorIterator. States and transitions will be // added as necessary. The algorithm is a generalization to arbitrary // weights of the second step of the input epsilon-normalization // algorithm due to Mohri, "Generic epsilon-removal and input // epsilon-normalization algorithms for weighted transducers", // International Journal of Computer Science 13(1): 129-143 (2002). template <class A, class F> class FactorWeightFst : public Fst<A> { public: friend class ArcIterator< FactorWeightFst<A, F> >; friend class CacheStateIterator< FactorWeightFst<A, F> >; friend class CacheArcIterator< FactorWeightFst<A, F> >; typedef A Arc; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef CacheState<A> State; FactorWeightFst(const Fst<A> &fst) : impl_(new FactorWeightFstImpl<A, F>(fst, FactorWeightOptions())) {} FactorWeightFst(const Fst<A> &fst, const FactorWeightOptions &opts) : impl_(new FactorWeightFstImpl<A, F>(fst, opts)) {} FactorWeightFst(const FactorWeightFst<A, F> &fst) : Fst<A>(fst), impl_(fst.impl_) { impl_->IncrRefCount(); } virtual ~FactorWeightFst() { if (!impl_->DecrRefCount()) delete impl_; } virtual StateId Start() const { return impl_->Start(); } virtual Weight Final(StateId s) const { return impl_->Final(s); } virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } virtual size_t NumInputEpsilons(StateId s) const { return impl_->NumInputEpsilons(s); } virtual size_t NumOutputEpsilons(StateId s) const { return impl_->NumOutputEpsilons(s); } virtual uint64 Properties(uint64 mask, bool test) const { if (test) { uint64 known, test = TestProperties(*this, mask, &known); impl_->SetProperties(test, known); return test & mask; } else { return impl_->Properties(mask); } } virtual const string& Type() const { return impl_->Type(); } virtual FactorWeightFst<A, F> *Copy() const { return new FactorWeightFst<A, F>(*this); } virtual const SymbolTable* InputSymbols() const { return impl_->InputSymbols(); } virtual const SymbolTable* OutputSymbols() const { return impl_->OutputSymbols(); } virtual inline void InitStateIterator(StateIteratorData<A> *data) const; virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { impl_->InitArcIterator(s, data); } private: FactorWeightFstImpl<A, F> *Impl() { return impl_; } FactorWeightFstImpl<A, F> *impl_; void operator=(const FactorWeightFst<A, F> &fst); // Disallow }; // Specialization for FactorWeightFst. template<class A, class F> class StateIterator< FactorWeightFst<A, F> > : public CacheStateIterator< FactorWeightFst<A, F> > { public: explicit StateIterator(const FactorWeightFst<A, F> &fst) : CacheStateIterator< FactorWeightFst<A, F> >(fst) {} }; // Specialization for FactorWeightFst. template <class A, class F> class ArcIterator< FactorWeightFst<A, F> > : public CacheArcIterator< FactorWeightFst<A, F> > { public: typedef typename A::StateId StateId; ArcIterator(const FactorWeightFst<A, F> &fst, StateId s) : CacheArcIterator< FactorWeightFst<A, F> >(fst, s) { if (!fst.impl_->HasArcs(s)) fst.impl_->Expand(s); } private: DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); }; template <class A, class F> inline void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const { data->base = new StateIterator< FactorWeightFst<A, F> >(*this); } } // namespace fst #endif // FST_LIB_FACTOR_WEIGHT_H__