// rational.h // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // // \file // An Fst implementation and base interface for delayed unions, // concatenations and closures. #ifndef FST_LIB_RATIONAL_H__ #define FST_LIB_RATIONAL_H__ #include "fst/lib/map.h" #include "fst/lib/mutable-fst.h" #include "fst/lib/replace.h" #include "fst/lib/test-properties.h" namespace fst { typedef CacheOptions RationalFstOptions; // This specifies whether to add the empty string. enum ClosureType { CLOSURE_STAR = 0, // T* -> add the empty string CLOSURE_PLUS = 1 }; // T+ -> don't add the empty string template <class A> class RationalFst; template <class A> void Union(RationalFst<A> *fst1, const Fst<A> &fst2); template <class A> void Concat(RationalFst<A> *fst1, const Fst<A> &fst2); template <class A> void Closure(RationalFst<A> *fst, ClosureType closure_type); // Implementation class for delayed unions, concatenations and closures. template<class A> class RationalFstImpl : public ReplaceFstImpl<A> { public: using FstImpl<A>::SetType; using FstImpl<A>::SetProperties; using FstImpl<A>::Properties; using FstImpl<A>::SetInputSymbols; using FstImpl<A>::SetOutputSymbols; using ReplaceFstImpl<A>::SetRoot; typedef typename A::Weight Weight; typedef typename A::Label Label; explicit RationalFstImpl(const RationalFstOptions &opts) : ReplaceFstImpl<A>(ReplaceFstOptions(opts, kNoLabel)), nonterminals_(0) { SetType("rational"); } // Implementation of UnionFst(fst1,fst2) void InitUnion(const Fst<A> &fst1, const Fst<A> &fst2) { uint64 props1 = fst1.Properties(kFstProperties, false); uint64 props2 = fst2.Properties(kFstProperties, false); SetInputSymbols(fst1.InputSymbols()); SetOutputSymbols(fst1.OutputSymbols()); rfst_.AddState(); rfst_.AddState(); rfst_.SetStart(0); rfst_.SetFinal(1, Weight::One()); rfst_.SetInputSymbols(fst1.InputSymbols()); rfst_.SetOutputSymbols(fst1.OutputSymbols()); nonterminals_ = 2; rfst_.AddArc(0, A(0, -1, Weight::One(), 1)); rfst_.AddArc(0, A(0, -2, Weight::One(), 1)); AddFst(0, &rfst_); AddFst(-1, &fst1); AddFst(-2, &fst2); SetRoot(0); SetProperties(UnionProperties(props1, props2, true), kCopyProperties); } // Implementation of ConcatFst(fst1,fst2) void InitConcat(const Fst<A> &fst1, const Fst<A> &fst2) { uint64 props1 = fst1.Properties(kFstProperties, false); uint64 props2 = fst2.Properties(kFstProperties, false); SetInputSymbols(fst1.InputSymbols()); SetOutputSymbols(fst1.OutputSymbols()); rfst_.AddState(); rfst_.AddState(); rfst_.AddState(); rfst_.SetStart(0); rfst_.SetFinal(2, Weight::One()); rfst_.SetInputSymbols(fst1.InputSymbols()); rfst_.SetOutputSymbols(fst1.OutputSymbols()); nonterminals_ = 2; rfst_.AddArc(0, A(0, -1, Weight::One(), 1)); rfst_.AddArc(1, A(0, -2, Weight::One(), 2)); AddFst(0, &rfst_); AddFst(-1, &fst1); AddFst(-2, &fst2); SetRoot(0); SetProperties(ConcatProperties(props1, props2, true), kCopyProperties); } // Implementation of ClosureFst(fst, closure_type) void InitClosure(const Fst<A> &fst, ClosureType closure_type) { uint64 props = fst.Properties(kFstProperties, false); SetInputSymbols(fst.InputSymbols()); SetOutputSymbols(fst.OutputSymbols()); if (closure_type == CLOSURE_STAR) { rfst_.AddState(); rfst_.SetStart(0); rfst_.SetFinal(0, Weight::One()); rfst_.AddArc(0, A(0, -1, Weight::One(), 0)); } else { rfst_.AddState(); rfst_.AddState(); rfst_.SetStart(0); rfst_.SetFinal(1, Weight::One()); rfst_.AddArc(0, A(0, -1, Weight::One(), 1)); rfst_.AddArc(1, A(0, 0, Weight::One(), 0)); } rfst_.SetInputSymbols(fst.InputSymbols()); rfst_.SetOutputSymbols(fst.OutputSymbols()); AddFst(0, &rfst_); AddFst(-1, &fst); SetRoot(0); nonterminals_ = 1; SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true), kCopyProperties); } // Implementation of Union(Fst &, RationalFst *) void AddUnion(const Fst<A> &fst) { uint64 props1 = Properties(); uint64 props2 = fst.Properties(kFstProperties, false); VectorFst<A> afst; afst.AddState(); afst.AddState(); afst.SetStart(0); afst.SetFinal(1, Weight::One()); afst.AddArc(0, A(0, -nonterminals_, Weight::One(), 1)); Union(&rfst_, afst); SetFst(0, &rfst_); ++nonterminals_; SetProperties(UnionProperties(props1, props2, true), kCopyProperties); } // Implementation of Concat(Fst &, RationalFst *) void AddConcat(const Fst<A> &fst) { uint64 props1 = Properties(); uint64 props2 = fst.Properties(kFstProperties, false); VectorFst<A> afst; afst.AddState(); afst.AddState(); afst.SetStart(0); afst.SetFinal(1, Weight::One()); afst.AddArc(0, A(0, -nonterminals_, Weight::One(), 1)); Concat(&rfst_, afst); SetFst(0, &rfst_); ++nonterminals_; SetProperties(ConcatProperties(props1, props2, true), kCopyProperties); } // Implementation of Closure(RationalFst *, closure_type) void AddClosure(ClosureType closure_type) { uint64 props = Properties(); Closure(&rfst_, closure_type); SetFst(0, &rfst_); SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true), kCopyProperties); } private: VectorFst<A> rfst_; // rational topology machine; uses neg. nonterminals Label nonterminals_; // # of nonterminals used DISALLOW_EVIL_CONSTRUCTORS(RationalFstImpl); }; // Parent class for the delayed rational operations - delayed union, // concatenation, and closure. This class attaches interface to // implementation and handles reference counting. template <class A> class RationalFst : public Fst<A> { public: friend class CacheStateIterator< RationalFst<A> >; friend class ArcIterator< RationalFst<A> >; friend class CacheArcIterator< RationalFst<A> >; friend void Union<>(RationalFst<A> *fst1, const Fst<A> &fst2); friend void Concat<>(RationalFst<A> *fst1, const Fst<A> &fst2); friend void Closure<>(RationalFst<A> *fst, ClosureType closure_type); typedef A Arc; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef CacheState<A> State; virtual StateId Start() const { return impl_->Start(); } virtual Weight Final(StateId s) const { return impl_->Final(s); } virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } virtual size_t NumInputEpsilons(StateId s) const { return impl_->NumInputEpsilons(s); } virtual size_t NumOutputEpsilons(StateId s) const { return impl_->NumOutputEpsilons(s); } virtual uint64 Properties(uint64 mask, bool test) const { if (test) { uint64 known, test = TestProperties(*this, mask, &known); impl_->SetProperties(test, known); return test & mask; } else { return impl_->Properties(mask); } } virtual const string& Type() const { return impl_->Type(); } virtual const SymbolTable* InputSymbols() const { return impl_->InputSymbols(); } virtual const SymbolTable* OutputSymbols() const { return impl_->OutputSymbols(); } virtual inline void InitStateIterator(StateIteratorData<A> *data) const; virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { impl_->InitArcIterator(s, data); } protected: RationalFst() : impl_(new RationalFstImpl<A>(RationalFstOptions())) {} explicit RationalFst(const RationalFstOptions &opts) : impl_(new RationalFstImpl<A>(opts)) {} RationalFst(const RationalFst<A> &fst) : impl_(fst.impl_) { impl_->IncrRefCount(); } virtual ~RationalFst() { if (!impl_->DecrRefCount()) delete impl_; } RationalFstImpl<A> *Impl() { return impl_; } private: RationalFstImpl<A> *impl_; void operator=(const RationalFst<A> &fst); // disallow }; // Specialization for RationalFst. template <class A> class StateIterator< RationalFst<A> > : public CacheStateIterator< RationalFst<A> > { public: explicit StateIterator(const RationalFst<A> &fst) : CacheStateIterator< RationalFst<A> >(fst) {} }; // Specialization for RationalFst. template <class A> class ArcIterator< RationalFst<A> > : public CacheArcIterator< RationalFst<A> > { public: typedef typename A::StateId StateId; ArcIterator(const RationalFst<A> &fst, StateId s) : CacheArcIterator< RationalFst<A> >(fst, s) { if (!fst.impl_->HasArcs(s)) fst.impl_->Expand(s); } private: DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); }; template <class A> inline void RationalFst<A>::InitStateIterator(StateIteratorData<A> *data) const { data->base = new StateIterator< RationalFst<A> >(*this); } } // namespace fst #endif // FST_LIB_RATIONAL_H__