// matcher.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: riley@google.com (Michael Riley) // // \file // Classes to allow matching labels leaving FST states. #ifndef FST_LIB_MATCHER_H__ #define FST_LIB_MATCHER_H__ #include <algorithm> #include <set> #include <fst/mutable-fst.h> // for all internal FST accessors namespace fst { // MATCHERS - these can find and iterate through requested labels at // FST states. In the simplest form, these are just some associative // map or search keyed on labels. More generally, they may // implement matching special labels that represent sets of labels // such as 'sigma' (all), 'rho' (rest), or 'phi' (fail). // The Matcher interface is: // // template <class F> // class Matcher { // public: // typedef F FST; // typedef F::Arc Arc; // typedef typename Arc::StateId StateId; // typedef typename Arc::Label Label; // typedef typename Arc::Weight Weight; // // // Required constructors. // Matcher(const F &fst, MatchType type); // // If safe=true, the copy is thread-safe. See Fst<>::Copy() // // for further doc. // Matcher(const Matcher &matcher, bool safe = false); // // // If safe=true, the copy is thread-safe. See Fst<>::Copy() // // for further doc. // Matcher<F> *Copy(bool safe = false) const; // // // Returns the match type that can be provided (depending on // // compatibility of the input FST). It is either // // the requested match type, MATCH_NONE, or MATCH_UNKNOWN. // // If 'test' is false, a constant time test is performed, but // // MATCH_UNKNOWN may be returned. If 'test' is true, // // a definite answer is returned, but may involve more costly // // computation (e.g., visiting the Fst). // MatchType Type(bool test) const; // // Specifies the current state. // void SetState(StateId s); // // // This finds matches to a label at the current state. // // Returns true if a match found. kNoLabel matches any // // 'non-consuming' transitions, e.g., epsilon transitions, // // which do not require a matching symbol. // bool Find(Label label); // // These iterate through any matches found: // bool Done() const; // No more matches. // const A& Value() const; // Current arc (when !Done) // void Next(); // Advance to next arc (when !Done) // // Initially and after SetState() the iterator methods // // have undefined behavior until Find() is called. // // // Return matcher FST. // const F& GetFst() const; // // This specifies the known Fst properties as viewed from this // // matcher. It takes as argument the input Fst's known properties. // uint64 Properties(uint64 props) const; // }; // // MATCHER FLAGS (see also kLookAheadFlags in lookahead-matcher.h) // // Matcher prefers being used as the matching side in composition. const uint32 kPreferMatch = 0x00000001; // Matcher needs to be used as the matching side in composition. const uint32 kRequireMatch = 0x00000002; // Flags used for basic matchers (see also lookahead.h). const uint32 kMatcherFlags = kPreferMatch | kRequireMatch; // Matcher interface, templated on the Arc definition; used // for matcher specializations that are returned by the // InitMatcher Fst method. template <class A> class MatcherBase { public: typedef A Arc; typedef typename A::StateId StateId; typedef typename A::Label Label; typedef typename A::Weight Weight; virtual ~MatcherBase() {} virtual MatcherBase<A> *Copy(bool safe = false) const = 0; virtual MatchType Type(bool test) const = 0; void SetState(StateId s) { SetState_(s); } bool Find(Label label) { return Find_(label); } bool Done() const { return Done_(); } const A& Value() const { return Value_(); } void Next() { Next_(); } virtual const Fst<A> &GetFst() const = 0; virtual uint64 Properties(uint64 props) const = 0; virtual uint32 Flags() const { return 0; } private: virtual void SetState_(StateId s) = 0; virtual bool Find_(Label label) = 0; virtual bool Done_() const = 0; virtual const A& Value_() const = 0; virtual void Next_() = 0; }; // A matcher that expects sorted labels on the side to be matched. // If match_type == MATCH_INPUT, epsilons match the implicit self loop // Arc(kNoLabel, 0, Weight::One(), current_state) as well as any // actual epsilon transitions. If match_type == MATCH_OUTPUT, then // Arc(0, kNoLabel, Weight::One(), current_state) is instead matched. template <class F> class SortedMatcher : public MatcherBase<typename F::Arc> { public: typedef F FST; typedef typename F::Arc Arc; typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; // Labels >= binary_label will be searched for by binary search, // o.w. linear search is used. SortedMatcher(const F &fst, MatchType match_type, Label binary_label = 1) : fst_(fst.Copy()), s_(kNoStateId), aiter_(0), match_type_(match_type), binary_label_(binary_label), match_label_(kNoLabel), narcs_(0), loop_(kNoLabel, 0, Weight::One(), kNoStateId), error_(false) { switch(match_type_) { case MATCH_INPUT: case MATCH_NONE: break; case MATCH_OUTPUT: swap(loop_.ilabel, loop_.olabel); break; default: FSTERROR() << "SortedMatcher: bad match type"; match_type_ = MATCH_NONE; error_ = true; } } SortedMatcher(const SortedMatcher<F> &matcher, bool safe = false) : fst_(matcher.fst_->Copy(safe)), s_(kNoStateId), aiter_(0), match_type_(matcher.match_type_), binary_label_(matcher.binary_label_), match_label_(kNoLabel), narcs_(0), loop_(matcher.loop_), error_(matcher.error_) {} virtual ~SortedMatcher() { if (aiter_) delete aiter_; delete fst_; } virtual SortedMatcher<F> *Copy(bool safe = false) const { return new SortedMatcher<F>(*this, safe); } virtual MatchType Type(bool test) const { if (match_type_ == MATCH_NONE) return match_type_; uint64 true_prop = match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted; uint64 false_prop = match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted; uint64 props = fst_->Properties(true_prop | false_prop, test); if (props & true_prop) return match_type_; else if (props & false_prop) return MATCH_NONE; else return MATCH_UNKNOWN; } void SetState(StateId s) { if (s_ == s) return; s_ = s; if (match_type_ == MATCH_NONE) { FSTERROR() << "SortedMatcher: bad match type"; error_ = true; } if (aiter_) delete aiter_; aiter_ = new ArcIterator<F>(*fst_, s); aiter_->SetFlags(kArcNoCache, kArcNoCache); narcs_ = internal::NumArcs(*fst_, s); loop_.nextstate = s; } bool Find(Label match_label) { exact_match_ = true; if (error_) { current_loop_ = false; match_label_ = kNoLabel; return false; } current_loop_ = match_label == 0; match_label_ = match_label == kNoLabel ? 0 : match_label; if (Search()) { return true; } else { return current_loop_; } } // Positions matcher to the first position where inserting // match_label would maintain the sort order. void LowerBound(Label match_label) { exact_match_ = false; current_loop_ = false; if (error_) { match_label_ = kNoLabel; return; } match_label_ = match_label; Search(); } // After Find(), returns false if no more exact matches. // After LowerBound(), returns false if no more arcs. bool Done() const { if (current_loop_) return false; if (aiter_->Done()) return true; if (!exact_match_) return false; aiter_->SetFlags( match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue, kArcValueFlags); Label label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel : aiter_->Value().olabel; return label != match_label_; } const Arc& Value() const { if (current_loop_) { return loop_; } aiter_->SetFlags(kArcValueFlags, kArcValueFlags); return aiter_->Value(); } void Next() { if (current_loop_) current_loop_ = false; else aiter_->Next(); } virtual const F &GetFst() const { return *fst_; } virtual uint64 Properties(uint64 inprops) const { uint64 outprops = inprops; if (error_) outprops |= kError; return outprops; } size_t Position() const { return aiter_ ? aiter_->Position() : 0; } private: virtual void SetState_(StateId s) { SetState(s); } virtual bool Find_(Label label) { return Find(label); } virtual bool Done_() const { return Done(); } virtual const Arc& Value_() const { return Value(); } virtual void Next_() { Next(); } bool Search(); const F *fst_; StateId s_; // Current state ArcIterator<F> *aiter_; // Iterator for current state MatchType match_type_; // Type of match to perform Label binary_label_; // Least label for binary search Label match_label_; // Current label to be matched size_t narcs_; // Current state arc count Arc loop_; // For non-consuming symbols bool current_loop_; // Current arc is the implicit loop bool exact_match_; // Exact match or lower bound? bool error_; // Error encountered void operator=(const SortedMatcher<F> &); // Disallow }; // Returns true iff match to match_label_. Positions arc iterator at // lower bound regardless. template <class F> inline bool SortedMatcher<F>::Search() { aiter_->SetFlags( match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue, kArcValueFlags); if (match_label_ >= binary_label_) { // Binary search for match. size_t low = 0; size_t high = narcs_; while (low < high) { size_t mid = (low + high) / 2; aiter_->Seek(mid); Label label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel : aiter_->Value().olabel; if (label > match_label_) { high = mid; } else if (label < match_label_) { low = mid + 1; } else { // find first matching label (when non-determinism) for (size_t i = mid; i > low; --i) { aiter_->Seek(i - 1); label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel : aiter_->Value().olabel; if (label != match_label_) { aiter_->Seek(i); return true; } } return true; } } aiter_->Seek(low); return false; } else { // Linear search for match. for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) { Label label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel : aiter_->Value().olabel; if (label == match_label_) { return true; } if (label > match_label_) break; } return false; } } // Specifies whether during matching we rewrite both the input and output sides. enum MatcherRewriteMode { MATCHER_REWRITE_AUTO = 0, // Rewrites both sides iff acceptor. MATCHER_REWRITE_ALWAYS, MATCHER_REWRITE_NEVER }; // For any requested label that doesn't match at a state, this matcher // considers all transitions that match the label 'rho_label' (rho = // 'rest'). Each such rho transition found is returned with the // rho_label rewritten as the requested label (both sides if an // acceptor, or if 'rewrite_both' is true and both input and output // labels of the found transition are 'rho_label'). If 'rho_label' is // kNoLabel, this special matching is not done. RhoMatcher is // templated itself on a matcher, which is used to perform the // underlying matching. By default, the underlying matcher is // constructed by RhoMatcher. The user can instead pass in this // object; in that case, RhoMatcher takes its ownership. template <class M> class RhoMatcher : public MatcherBase<typename M::Arc> { public: typedef typename M::FST FST; typedef typename M::Arc Arc; typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; RhoMatcher(const FST &fst, MatchType match_type, Label rho_label = kNoLabel, MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, M *matcher = 0) : matcher_(matcher ? matcher : new M(fst, match_type)), match_type_(match_type), rho_label_(rho_label), error_(false) { if (match_type == MATCH_BOTH) { FSTERROR() << "RhoMatcher: bad match type"; match_type_ = MATCH_NONE; error_ = true; } if (rho_label == 0) { FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label"; rho_label_ = kNoLabel; error_ = true; } if (rewrite_mode == MATCHER_REWRITE_AUTO) rewrite_both_ = fst.Properties(kAcceptor, true); else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) rewrite_both_ = true; else rewrite_both_ = false; } RhoMatcher(const RhoMatcher<M> &matcher, bool safe = false) : matcher_(new M(*matcher.matcher_, safe)), match_type_(matcher.match_type_), rho_label_(matcher.rho_label_), rewrite_both_(matcher.rewrite_both_), error_(matcher.error_) {} virtual ~RhoMatcher() { delete matcher_; } virtual RhoMatcher<M> *Copy(bool safe = false) const { return new RhoMatcher<M>(*this, safe); } virtual MatchType Type(bool test) const { return matcher_->Type(test); } void SetState(StateId s) { matcher_->SetState(s); has_rho_ = rho_label_ != kNoLabel; } bool Find(Label match_label) { if (match_label == rho_label_ && rho_label_ != kNoLabel) { FSTERROR() << "RhoMatcher::Find: bad label (rho)"; error_ = true; return false; } if (matcher_->Find(match_label)) { rho_match_ = kNoLabel; return true; } else if (has_rho_ && match_label != 0 && match_label != kNoLabel && (has_rho_ = matcher_->Find(rho_label_))) { rho_match_ = match_label; return true; } else { return false; } } bool Done() const { return matcher_->Done(); } const Arc& Value() const { if (rho_match_ == kNoLabel) { return matcher_->Value(); } else { rho_arc_ = matcher_->Value(); if (rewrite_both_) { if (rho_arc_.ilabel == rho_label_) rho_arc_.ilabel = rho_match_; if (rho_arc_.olabel == rho_label_) rho_arc_.olabel = rho_match_; } else if (match_type_ == MATCH_INPUT) { rho_arc_.ilabel = rho_match_; } else { rho_arc_.olabel = rho_match_; } return rho_arc_; } } void Next() { matcher_->Next(); } virtual const FST &GetFst() const { return matcher_->GetFst(); } virtual uint64 Properties(uint64 props) const; virtual uint32 Flags() const { if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE) return matcher_->Flags(); return matcher_->Flags() | kRequireMatch; } private: virtual void SetState_(StateId s) { SetState(s); } virtual bool Find_(Label label) { return Find(label); } virtual bool Done_() const { return Done(); } virtual const Arc& Value_() const { return Value(); } virtual void Next_() { Next(); } M *matcher_; MatchType match_type_; // Type of match requested Label rho_label_; // Label that represents the rho transition bool rewrite_both_; // Rewrite both sides when both are 'rho_label_' bool has_rho_; // Are there possibly rhos at the current state? Label rho_match_; // Current label that matches rho transition mutable Arc rho_arc_; // Arc to return when rho match bool error_; // Error encountered void operator=(const RhoMatcher<M> &); // Disallow }; template <class M> inline uint64 RhoMatcher<M>::Properties(uint64 inprops) const { uint64 outprops = matcher_->Properties(inprops); if (error_) outprops |= kError; if (match_type_ == MATCH_NONE) { return outprops; } else if (match_type_ == MATCH_INPUT) { if (rewrite_both_) { return outprops & ~(kODeterministic | kNonODeterministic | kString | kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); } else { return outprops & ~(kODeterministic | kAcceptor | kString | kILabelSorted | kNotILabelSorted); } } else if (match_type_ == MATCH_OUTPUT) { if (rewrite_both_) { return outprops & ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); } else { return outprops & ~(kIDeterministic | kAcceptor | kString | kOLabelSorted | kNotOLabelSorted); } } else { // Shouldn't ever get here. FSTERROR() << "RhoMatcher:: bad match type: " << match_type_; return 0; } } // For any requested label, this matcher considers all transitions // that match the label 'sigma_label' (sigma = "any"), and this in // additions to transitions with the requested label. Each such sigma // transition found is returned with the sigma_label rewritten as the // requested label (both sides if an acceptor, or if 'rewrite_both' is // true and both input and output labels of the found transition are // 'sigma_label'). If 'sigma_label' is kNoLabel, this special // matching is not done. SigmaMatcher is templated itself on a // matcher, which is used to perform the underlying matching. By // default, the underlying matcher is constructed by SigmaMatcher. // The user can instead pass in this object; in that case, // SigmaMatcher takes its ownership. template <class M> class SigmaMatcher : public MatcherBase<typename M::Arc> { public: typedef typename M::FST FST; typedef typename M::Arc Arc; typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; SigmaMatcher(const FST &fst, MatchType match_type, Label sigma_label = kNoLabel, MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, M *matcher = 0) : matcher_(matcher ? matcher : new M(fst, match_type)), match_type_(match_type), sigma_label_(sigma_label), error_(false) { if (match_type == MATCH_BOTH) { FSTERROR() << "SigmaMatcher: bad match type"; match_type_ = MATCH_NONE; error_ = true; } if (sigma_label == 0) { FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label"; sigma_label_ = kNoLabel; error_ = true; } if (rewrite_mode == MATCHER_REWRITE_AUTO) rewrite_both_ = fst.Properties(kAcceptor, true); else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) rewrite_both_ = true; else rewrite_both_ = false; } SigmaMatcher(const SigmaMatcher<M> &matcher, bool safe = false) : matcher_(new M(*matcher.matcher_, safe)), match_type_(matcher.match_type_), sigma_label_(matcher.sigma_label_), rewrite_both_(matcher.rewrite_both_), error_(matcher.error_) {} virtual ~SigmaMatcher() { delete matcher_; } virtual SigmaMatcher<M> *Copy(bool safe = false) const { return new SigmaMatcher<M>(*this, safe); } virtual MatchType Type(bool test) const { return matcher_->Type(test); } void SetState(StateId s) { matcher_->SetState(s); has_sigma_ = sigma_label_ != kNoLabel ? matcher_->Find(sigma_label_) : false; } bool Find(Label match_label) { match_label_ = match_label; if (match_label == sigma_label_ && sigma_label_ != kNoLabel) { FSTERROR() << "SigmaMatcher::Find: bad label (sigma)"; error_ = true; return false; } if (matcher_->Find(match_label)) { sigma_match_ = kNoLabel; return true; } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel && matcher_->Find(sigma_label_)) { sigma_match_ = match_label; return true; } else { return false; } } bool Done() const { return matcher_->Done(); } const Arc& Value() const { if (sigma_match_ == kNoLabel) { return matcher_->Value(); } else { sigma_arc_ = matcher_->Value(); if (rewrite_both_) { if (sigma_arc_.ilabel == sigma_label_) sigma_arc_.ilabel = sigma_match_; if (sigma_arc_.olabel == sigma_label_) sigma_arc_.olabel = sigma_match_; } else if (match_type_ == MATCH_INPUT) { sigma_arc_.ilabel = sigma_match_; } else { sigma_arc_.olabel = sigma_match_; } return sigma_arc_; } } void Next() { matcher_->Next(); if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) && (match_label_ > 0)) { matcher_->Find(sigma_label_); sigma_match_ = match_label_; } } virtual const FST &GetFst() const { return matcher_->GetFst(); } virtual uint64 Properties(uint64 props) const; virtual uint32 Flags() const { if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE) return matcher_->Flags(); // kRequireMatch temporarily disabled until issues // in //speech/gaudi/annotation/util/denorm are resolved. // return matcher_->Flags() | kRequireMatch; return matcher_->Flags(); } private: virtual void SetState_(StateId s) { SetState(s); } virtual bool Find_(Label label) { return Find(label); } virtual bool Done_() const { return Done(); } virtual const Arc& Value_() const { return Value(); } virtual void Next_() { Next(); } M *matcher_; MatchType match_type_; // Type of match requested Label sigma_label_; // Label that represents the sigma transition bool rewrite_both_; // Rewrite both sides when both are 'sigma_label_' bool has_sigma_; // Are there sigmas at the current state? Label sigma_match_; // Current label that matches sigma transition mutable Arc sigma_arc_; // Arc to return when sigma match Label match_label_; // Label being matched bool error_; // Error encountered void operator=(const SigmaMatcher<M> &); // disallow }; template <class M> inline uint64 SigmaMatcher<M>::Properties(uint64 inprops) const { uint64 outprops = matcher_->Properties(inprops); if (error_) outprops |= kError; if (match_type_ == MATCH_NONE) { return outprops; } else if (rewrite_both_) { return outprops & ~(kIDeterministic | kNonIDeterministic | kODeterministic | kNonODeterministic | kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | kString); } else if (match_type_ == MATCH_INPUT) { return outprops & ~(kIDeterministic | kNonIDeterministic | kODeterministic | kNonODeterministic | kILabelSorted | kNotILabelSorted | kString | kAcceptor); } else if (match_type_ == MATCH_OUTPUT) { return outprops & ~(kIDeterministic | kNonIDeterministic | kODeterministic | kNonODeterministic | kOLabelSorted | kNotOLabelSorted | kString | kAcceptor); } else { // Shouldn't ever get here. FSTERROR() << "SigmaMatcher:: bad match type: " << match_type_; return 0; } } // For any requested label that doesn't match at a state, this matcher // considers the *unique* transition that matches the label 'phi_label' // (phi = 'fail'), and recursively looks for a match at its // destination. When 'phi_loop' is true, if no match is found but a // phi self-loop is found, then the phi transition found is returned // with the phi_label rewritten as the requested label (both sides if // an acceptor, or if 'rewrite_both' is true and both input and output // labels of the found transition are 'phi_label'). If 'phi_label' is // kNoLabel, this special matching is not done. PhiMatcher is // templated itself on a matcher, which is used to perform the // underlying matching. By default, the underlying matcher is // constructed by PhiMatcher. The user can instead pass in this // object; in that case, PhiMatcher takes its ownership. // Warning: phi non-determinism not supported (for simplicity). template <class M> class PhiMatcher : public MatcherBase<typename M::Arc> { public: typedef typename M::FST FST; typedef typename M::Arc Arc; typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; PhiMatcher(const FST &fst, MatchType match_type, Label phi_label = kNoLabel, bool phi_loop = true, MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, M *matcher = 0) : matcher_(matcher ? matcher : new M(fst, match_type)), match_type_(match_type), phi_label_(phi_label), state_(kNoStateId), phi_loop_(phi_loop), error_(false) { if (match_type == MATCH_BOTH) { FSTERROR() << "PhiMatcher: bad match type"; match_type_ = MATCH_NONE; error_ = true; } if (rewrite_mode == MATCHER_REWRITE_AUTO) rewrite_both_ = fst.Properties(kAcceptor, true); else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) rewrite_both_ = true; else rewrite_both_ = false; } PhiMatcher(const PhiMatcher<M> &matcher, bool safe = false) : matcher_(new M(*matcher.matcher_, safe)), match_type_(matcher.match_type_), phi_label_(matcher.phi_label_), rewrite_both_(matcher.rewrite_both_), state_(kNoStateId), phi_loop_(matcher.phi_loop_), error_(matcher.error_) {} virtual ~PhiMatcher() { delete matcher_; } virtual PhiMatcher<M> *Copy(bool safe = false) const { return new PhiMatcher<M>(*this, safe); } virtual MatchType Type(bool test) const { return matcher_->Type(test); } void SetState(StateId s) { matcher_->SetState(s); state_ = s; has_phi_ = phi_label_ != kNoLabel; } bool Find(Label match_label); bool Done() const { return matcher_->Done(); } const Arc& Value() const { if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) { return matcher_->Value(); } else if (phi_match_ == 0) { // Virtual epsilon loop phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_); if (match_type_ == MATCH_OUTPUT) swap(phi_arc_.ilabel, phi_arc_.olabel); return phi_arc_; } else { phi_arc_ = matcher_->Value(); phi_arc_.weight = Times(phi_weight_, phi_arc_.weight); if (phi_match_ != kNoLabel) { // Phi loop match if (rewrite_both_) { if (phi_arc_.ilabel == phi_label_) phi_arc_.ilabel = phi_match_; if (phi_arc_.olabel == phi_label_) phi_arc_.olabel = phi_match_; } else if (match_type_ == MATCH_INPUT) { phi_arc_.ilabel = phi_match_; } else { phi_arc_.olabel = phi_match_; } } return phi_arc_; } } void Next() { matcher_->Next(); } virtual const FST &GetFst() const { return matcher_->GetFst(); } virtual uint64 Properties(uint64 props) const; virtual uint32 Flags() const { if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE) return matcher_->Flags(); return matcher_->Flags() | kRequireMatch; } private: virtual void SetState_(StateId s) { SetState(s); } virtual bool Find_(Label label) { return Find(label); } virtual bool Done_() const { return Done(); } virtual const Arc& Value_() const { return Value(); } virtual void Next_() { Next(); } M *matcher_; MatchType match_type_; // Type of match requested Label phi_label_; // Label that represents the phi transition bool rewrite_both_; // Rewrite both sides when both are 'phi_label_' bool has_phi_; // Are there possibly phis at the current state? Label phi_match_; // Current label that matches phi loop mutable Arc phi_arc_; // Arc to return StateId state_; // State where looking for matches Weight phi_weight_; // Product of the weights of phi transitions taken bool phi_loop_; // When true, phi self-loop are allowed and treated // as rho (required for Aho-Corasick) bool error_; // Error encountered void operator=(const PhiMatcher<M> &); // disallow }; template <class M> inline bool PhiMatcher<M>::Find(Label match_label) { if (match_label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) { FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_; error_ = true; return false; } matcher_->SetState(state_); phi_match_ = kNoLabel; phi_weight_ = Weight::One(); if (phi_label_ == 0) { // When 'phi_label_ == 0', if (match_label == kNoLabel) // there are no more true epsilon arcs, return false; if (match_label == 0) { // but virtual eps loop need to be returned if (!matcher_->Find(kNoLabel)) { return matcher_->Find(0); } else { phi_match_ = 0; return true; } } } if (!has_phi_ || match_label == 0 || match_label == kNoLabel) return matcher_->Find(match_label); StateId state = state_; while (!matcher_->Find(match_label)) { // Look for phi transition (if phi_label_ == 0, we need to look // for -1 to avoid getting the virtual self-loop) if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) return false; if (phi_loop_ && matcher_->Value().nextstate == state) { phi_match_ = match_label; return true; } phi_weight_ = Times(phi_weight_, matcher_->Value().weight); state = matcher_->Value().nextstate; matcher_->Next(); if (!matcher_->Done()) { FSTERROR() << "PhiMatcher: phi non-determinism not supported"; error_ = true; } matcher_->SetState(state); } return true; } template <class M> inline uint64 PhiMatcher<M>::Properties(uint64 inprops) const { uint64 outprops = matcher_->Properties(inprops); if (error_) outprops |= kError; if (match_type_ == MATCH_NONE) { return outprops; } else if (match_type_ == MATCH_INPUT) { if (phi_label_ == 0) { outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons; outprops |= kNoEpsilons | kNoIEpsilons; } if (rewrite_both_) { return outprops & ~(kODeterministic | kNonODeterministic | kString | kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); } else { return outprops & ~(kODeterministic | kAcceptor | kString | kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); } } else if (match_type_ == MATCH_OUTPUT) { if (phi_label_ == 0) { outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons; outprops |= kNoEpsilons | kNoOEpsilons; } if (rewrite_both_) { return outprops & ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); } else { return outprops & ~(kIDeterministic | kAcceptor | kString | kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); } } else { // Shouldn't ever get here. FSTERROR() << "PhiMatcher:: bad match type: " << match_type_; return 0; } } // // MULTI-EPS MATCHER FLAGS // // Return multi-epsilon arcs for Find(kNoLabel). const uint32 kMultiEpsList = 0x00000001; // Return a kNolabel loop for Find(multi_eps). const uint32 kMultiEpsLoop = 0x00000002; // MultiEpsMatcher: allows treating multiple non-0 labels as // non-consuming labels in addition to 0 that is always // non-consuming. Precise behavior controlled by 'flags' argument. By // default, the underlying matcher is constructed by // MultiEpsMatcher. The user can instead pass in this object; in that // case, MultiEpsMatcher takes its ownership iff 'own_matcher' is // true. template <class M> class MultiEpsMatcher { public: typedef typename M::FST FST; typedef typename M::Arc Arc; typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; MultiEpsMatcher(const FST &fst, MatchType match_type, uint32 flags = (kMultiEpsLoop | kMultiEpsList), M *matcher = 0, bool own_matcher = true) : matcher_(matcher ? matcher : new M(fst, match_type)), flags_(flags), own_matcher_(matcher ? own_matcher : true) { if (match_type == MATCH_INPUT) { loop_.ilabel = kNoLabel; loop_.olabel = 0; } else { loop_.ilabel = 0; loop_.olabel = kNoLabel; } loop_.weight = Weight::One(); loop_.nextstate = kNoStateId; } MultiEpsMatcher(const MultiEpsMatcher<M> &matcher, bool safe = false) : matcher_(new M(*matcher.matcher_, safe)), flags_(matcher.flags_), own_matcher_(true), multi_eps_labels_(matcher.multi_eps_labels_), loop_(matcher.loop_) { loop_.nextstate = kNoStateId; } ~MultiEpsMatcher() { if (own_matcher_) delete matcher_; } MultiEpsMatcher<M> *Copy(bool safe = false) const { return new MultiEpsMatcher<M>(*this, safe); } MatchType Type(bool test) const { return matcher_->Type(test); } void SetState(StateId s) { matcher_->SetState(s); loop_.nextstate = s; } bool Find(Label match_label); bool Done() const { return done_; } const Arc& Value() const { return current_loop_ ? loop_ : matcher_->Value(); } void Next() { if (!current_loop_) { matcher_->Next(); done_ = matcher_->Done(); if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) { ++multi_eps_iter_; while ((multi_eps_iter_ != multi_eps_labels_.End()) && !matcher_->Find(*multi_eps_iter_)) ++multi_eps_iter_; if (multi_eps_iter_ != multi_eps_labels_.End()) done_ = false; else done_ = !matcher_->Find(kNoLabel); } } else { done_ = true; } } const FST &GetFst() const { return matcher_->GetFst(); } uint64 Properties(uint64 props) const { return matcher_->Properties(props); } uint32 Flags() const { return matcher_->Flags(); } void AddMultiEpsLabel(Label label) { if (label == 0) { FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0"; } else { multi_eps_labels_.Insert(label); } } void RemoveMultiEpsLabel(Label label) { if (label == 0) { FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0"; } else { multi_eps_labels_.Erase(label); } } void ClearMultiEpsLabels() { multi_eps_labels_.Clear(); } private: M *matcher_; uint32 flags_; bool own_matcher_; // Does this class delete the matcher? // Multi-eps label set CompactSet<Label, kNoLabel> multi_eps_labels_; typename CompactSet<Label, kNoLabel>::const_iterator multi_eps_iter_; bool current_loop_; // Current arc is the implicit loop mutable Arc loop_; // For non-consuming symbols bool done_; // Matching done void operator=(const MultiEpsMatcher<M> &); // Disallow }; template <class M> inline bool MultiEpsMatcher<M>::Find(Label match_label) { multi_eps_iter_ = multi_eps_labels_.End(); current_loop_ = false; bool ret; if (match_label == 0) { ret = matcher_->Find(0); } else if (match_label == kNoLabel) { if (flags_ & kMultiEpsList) { // return all non-consuming arcs (incl. epsilon) multi_eps_iter_ = multi_eps_labels_.Begin(); while ((multi_eps_iter_ != multi_eps_labels_.End()) && !matcher_->Find(*multi_eps_iter_)) ++multi_eps_iter_; if (multi_eps_iter_ != multi_eps_labels_.End()) ret = true; else ret = matcher_->Find(kNoLabel); } else { // return all epsilon arcs ret = matcher_->Find(kNoLabel); } } else if ((flags_ & kMultiEpsLoop) && multi_eps_labels_.Find(match_label) != multi_eps_labels_.End()) { // return 'implicit' loop current_loop_ = true; ret = true; } else { ret = matcher_->Find(match_label); } done_ = !ret; return ret; } // Generic matcher, templated on the FST definition // - a wrapper around pointer to specific one. // Here is a typical use: \code // Matcher<StdFst> matcher(fst, MATCH_INPUT); // matcher.SetState(state); // if (matcher.Find(label)) // for (; !matcher.Done(); matcher.Next()) { // StdArc &arc = matcher.Value(); // ... // } \endcode template <class F> class Matcher { public: typedef F FST; typedef typename F::Arc Arc; typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; Matcher(const F &fst, MatchType match_type) { base_ = fst.InitMatcher(match_type); if (!base_) base_ = new SortedMatcher<F>(fst, match_type); } Matcher(const Matcher<F> &matcher, bool safe = false) { base_ = matcher.base_->Copy(safe); } // Takes ownership of the provided matcher Matcher(MatcherBase<Arc>* base_matcher) { base_ = base_matcher; } ~Matcher() { delete base_; } Matcher<F> *Copy(bool safe = false) const { return new Matcher<F>(*this, safe); } MatchType Type(bool test) const { return base_->Type(test); } void SetState(StateId s) { base_->SetState(s); } bool Find(Label label) { return base_->Find(label); } bool Done() const { return base_->Done(); } const Arc& Value() const { return base_->Value(); } void Next() { base_->Next(); } const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); } uint64 Properties(uint64 props) const { return base_->Properties(props); } uint32 Flags() const { return base_->Flags() & kMatcherFlags; } private: MatcherBase<Arc> *base_; void operator=(const Matcher<Arc> &); // disallow }; } // namespace fst #endif // FST_LIB_MATCHER_H__