// matcher.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Classes to allow matching labels leaving FST states.

#ifndef FST_LIB_MATCHER_H__
#define FST_LIB_MATCHER_H__

#include <algorithm>
#include <set>

#include <fst/mutable-fst.h>  // for all internal FST accessors


namespace fst {

// MATCHERS - these can find and iterate through requested labels at
// FST states. In the simplest form, these are just some associative
// map or search keyed on labels. More generally, they may
// implement matching special labels that represent sets of labels
// such as 'sigma' (all), 'rho' (rest), or 'phi' (fail).
// The Matcher interface is:
//
// template <class F>
// class Matcher {
//  public:
//   typedef F FST;
//   typedef typename F::Arc Arc;
//   typedef typename Arc::StateId StateId;
//   typedef typename Arc::Label Label;
//   typedef typename Arc::Weight Weight;
//
//   // Required constructors.
//   Matcher(const F &fst, MatchType type);
//   // If safe=true, the copy is thread-safe. See Fst<>::Copy()
//   // for further doc.
//   Matcher(const Matcher &matcher, bool safe = false);
//
//   // If safe=true, the copy is thread-safe. See Fst<>::Copy()
//   // for further doc.
//   Matcher<F> *Copy(bool safe = false) const;
//
//   // Returns the match type that can be provided (depending on
//   // compatibility of the input FST). It is either
//   // the requested match type, MATCH_NONE, or MATCH_UNKNOWN.
//   // If 'test' is false, a constant time test is performed, but
//   // MATCH_UNKNOWN may be returned. If 'test' is true,
//   // a definite answer is returned, but may involve more costly
//   // computation (e.g., visiting the Fst).
//   MatchType Type(bool test) const;
//   // Specifies the current state.
//   void SetState(StateId s);
//
//   // This finds matches to a label at the current state.
//   // Returns true if a match found. kNoLabel matches any
//   // 'non-consuming' transitions, e.g., epsilon transitions,
//   // which do not require a matching symbol.
//   bool Find(Label label);
//   // These iterate through any matches found:
//   bool Done() const;         // No more matches.
//   const Arc& Value() const;  // Current arc (when !Done)
//   void Next();               // Advance to next arc (when !Done)
//   // Initially and after SetState() the iterator methods
//   // have undefined behavior until Find() is called.
//
//   // Return matcher FST.
//   const F& GetFst() const;
//   // This specifies the known Fst properties as viewed from this
//   // matcher. It takes as argument the input Fst's known properties.
//   uint64 Properties(uint64 props) const;
// };

//
// MATCHER FLAGS (see also kLookAheadFlags in lookahead-matcher.h)
//
// Each flag occupies a distinct bit so that flags can be OR-ed together.

// Set when the matcher would rather be the matching side in composition.
const uint32 kPreferMatch = 1U << 0;

// Set when the matcher must be the matching side in composition.
const uint32 kRequireMatch = 1U << 1;

// Mask of all flags used by basic matchers.
const uint32 kMatcherFlags = kRequireMatch | kPreferMatch;

// Abstract matcher interface, parameterized on the arc type A. It is
// the base of the matcher specializations that the InitMatcher Fst
// method returns.
//
// The state/iteration methods (SetState, Find, Done, Value, Next) are
// non-virtual and simply forward to the private virtual hooks of the
// same name with a trailing underscore, which derived classes implement.
template <class A>
class MatcherBase {
 public:
  typedef A Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

  virtual ~MatcherBase() {}

  // Returns a new copy of this matcher.
  virtual MatcherBase<A> *Copy(bool safe = false) const = 0;

  // Returns the match type this matcher provides.
  virtual MatchType Type(bool test) const = 0;

  // Non-virtual front ends; each delegates to its private virtual hook.
  void SetState(StateId state) { SetState_(state); }
  bool Find(Label match_label) { return Find_(match_label); }
  bool Done() const { return Done_(); }
  const A& Value() const { return Value_(); }
  void Next() { Next_(); }

  // Returns the FST being matched.
  virtual const Fst<A> &GetFst() const = 0;

  // Returns the FST properties as seen through this matcher, given the
  // input properties 'props'.
  virtual uint64 Properties(uint64 props) const = 0;

  // Returns the matcher flags; none are set by default.
  virtual uint32 Flags() const { return 0; }

 private:
  // Implementation hooks behind the public non-virtual methods above.
  virtual void SetState_(StateId state) = 0;
  virtual bool Find_(Label match_label) = 0;
  virtual bool Done_() const = 0;
  virtual const A& Value_() const = 0;
  virtual void Next_() = 0;
};


// A matcher that expects sorted labels on the side to be matched.
// If match_type == MATCH_INPUT, epsilons match the implicit self loop
// Arc(kNoLabel, 0, Weight::One(), current_state) as well as any
// actual epsilon transitions. If match_type == MATCH_OUTPUT, then
// Arc(0, kNoLabel, Weight::One(), current_state) is instead matched.
//
// The matcher holds its own copy of the FST (obtained with Copy()) and
// an arc iterator for the current state; both are deleted by the
// destructor.
template <class F>
class SortedMatcher : public MatcherBase<typename F::Arc> {
 public:
  typedef F FST;
  typedef typename F::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

  // Labels >= binary_label will be searched for by binary search,
  // o.w. linear search is used. A match type other than MATCH_INPUT,
  // MATCH_OUTPUT or MATCH_NONE is reported through FSTERROR, demoted
  // to MATCH_NONE, and puts the matcher into the error state.
  SortedMatcher(const F &fst, MatchType match_type,
                Label binary_label = 1)
      : fst_(fst.Copy()),
        s_(kNoStateId),
        aiter_(0),
        match_type_(match_type),
        binary_label_(binary_label),
        match_label_(kNoLabel),
        narcs_(0),
        loop_(kNoLabel, 0, Weight::One(), kNoStateId),
        current_loop_(false),   // Defined value until the first Find().
        exact_match_(true),     // Defined value until the first Find().
        error_(false) {
    switch(match_type_) {
      case MATCH_INPUT:
      case MATCH_NONE:
        break;
      case MATCH_OUTPUT:
        // The implicit loop is non-consuming on the output side instead.
        swap(loop_.ilabel, loop_.olabel);
        break;
      default:
        FSTERROR() << "SortedMatcher: bad match type";
        match_type_ = MATCH_NONE;
        error_ = true;
    }
  }

  // Copies the configuration of 'matcher' (match type, binary search
  // threshold, loop arc, error flag) but none of its iteration state:
  // the copy has no current state and no arc iterator.
  SortedMatcher(const SortedMatcher<F> &matcher, bool safe = false)
      : fst_(matcher.fst_->Copy(safe)),
        s_(kNoStateId),
        aiter_(0),
        match_type_(matcher.match_type_),
        binary_label_(matcher.binary_label_),
        match_label_(kNoLabel),
        narcs_(0),
        loop_(matcher.loop_),
        current_loop_(false),   // Defined value until the first Find().
        exact_match_(true),     // Defined value until the first Find().
        error_(matcher.error_) {}

  virtual ~SortedMatcher() {
    delete aiter_;  // Deleting a null pointer is a no-op.
    delete fst_;
  }

  virtual SortedMatcher<F> *Copy(bool safe = false) const {
    return new SortedMatcher<F>(*this, safe);
  }

  // Returns the requested match type if the FST is known to be sorted
  // on the matched side, MATCH_NONE if it is known not to be, and
  // MATCH_UNKNOWN when neither property bit is set.
  virtual MatchType Type(bool test) const {
    if (match_type_ == MATCH_NONE)
      return match_type_;

    uint64 true_prop =  match_type_ == MATCH_INPUT ?
        kILabelSorted : kOLabelSorted;
    uint64 false_prop = match_type_ == MATCH_INPUT ?
        kNotILabelSorted : kNotOLabelSorted;
    uint64 props = fst_->Properties(true_prop | false_prop, test);

    if (props & true_prop)
      return match_type_;
    else if (props & false_prop)
      return MATCH_NONE;
    else
      return MATCH_UNKNOWN;
  }

  // Makes 's' the current state: replaces the arc iterator, caches the
  // arc count and retargets the implicit loop. A no-op when 's' is
  // already the current state.
  void SetState(StateId s) {
    if (s_ == s)
      return;
    s_ = s;
    if (match_type_ == MATCH_NONE) {
      FSTERROR() << "SortedMatcher: bad match type";
      error_ = true;
    }
    delete aiter_;  // Deleting a null pointer is a no-op.
    aiter_ = new ArcIterator<F>(*fst_, s);
    aiter_->SetFlags(kArcNoCache, kArcNoCache);
    narcs_ = internal::NumArcs(*fst_, s);
    loop_.nextstate = s;
  }

  // Looks for arcs labeled 'match_label' at the current state. A
  // request for 0 also yields the implicit loop first; a request for
  // kNoLabel searches for label 0 without the implicit loop. Always
  // returns false in the error state.
  bool Find(Label match_label) {
    exact_match_ = true;
    if (error_) {
      current_loop_ = false;
      match_label_ = kNoLabel;
      return false;
    }
    current_loop_ = match_label == 0;
    match_label_ = match_label == kNoLabel ? 0 : match_label;
    if (Search()) {
      return true;
    } else {
      return current_loop_;
    }
  }

  // Positions matcher to the first position where inserting
  // match_label would maintain the sort order.
  void LowerBound(Label match_label) {
    exact_match_ = false;
    current_loop_ = false;
    if (error_) {
      match_label_ = kNoLabel;
      return;
    }
    match_label_ = match_label;
    Search();
  }

  // After Find(), returns false if no more exact matches.
  // After LowerBound(), returns false if no more arcs.
  // Dereferences the arc iterator, so a state must have been set.
  bool Done() const {
    if (current_loop_)
      return false;
    if (aiter_->Done())
      return true;
    if (!exact_match_)
      return false;
    // Only the label on the matched side needs to be computed here.
    aiter_->SetFlags(
      match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
      kArcValueFlags);
    Label label = match_type_ == MATCH_INPUT ?
        aiter_->Value().ilabel : aiter_->Value().olabel;
    return label != match_label_;
  }

  // Returns the implicit loop while it is pending, otherwise the arc
  // the iterator is positioned at (with all value flags requested).
  const Arc& Value() const {
    if (current_loop_) {
      return loop_;
    }
    aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
    return aiter_->Value();
  }

  // Consumes the implicit loop if it is pending, otherwise advances
  // the arc iterator.
  void Next() {
    if (current_loop_)
      current_loop_ = false;
    else
      aiter_->Next();
  }

  virtual const F &GetFst() const { return *fst_; }

  // Passes the input properties through, adding kError in the error
  // state.
  virtual uint64 Properties(uint64 inprops) const {
    uint64 outprops = inprops;
    if (error_) outprops |= kError;
    return outprops;
  }

  // Returns the arc iterator position, or 0 when no state has been set.
  size_t Position() const { return aiter_ ? aiter_->Position() : 0; }

 private:
  // MatcherBase hooks; forward to the non-virtual methods above.
  virtual void SetState_(StateId s) { SetState(s); }
  virtual bool Find_(Label label) { return Find(label); }
  virtual bool Done_() const { return Done(); }
  virtual const Arc& Value_() const { return Value(); }
  virtual void Next_() { Next(); }

  bool Search();

  const F *fst_;
  StateId s_;                     // Current state
  ArcIterator<F> *aiter_;         // Iterator for current state
  MatchType match_type_;          // Type of match to perform
  Label binary_label_;            // Least label for binary search
  Label match_label_;             // Current label to be matched
  size_t narcs_;                  // Current state arc count
  Arc loop_;                      // For non-consuming symbols
  bool current_loop_;             // Current arc is the implicit loop
  bool exact_match_;              // Exact match or lower bound?
  bool error_;                    // Error encountered

  void operator=(const SortedMatcher<F> &);  // Disallow
};

// Searches the current state's arcs for match_label_ on the matched
// side. Returns true iff an arc with that label exists, in which case
// the arc iterator is left on the first such arc; otherwise the
// iterator is left at the lower bound.
template <class F> inline
bool SortedMatcher<F>::Search() {
  const bool match_input = match_type_ == MATCH_INPUT;
  aiter_->SetFlags(match_input ? kArcILabelValue : kArcOLabelValue,
                   kArcValueFlags);

  if (match_label_ < binary_label_) {
    // Small labels: scan forward from the first arc and stop at the
    // first label that is not smaller than the target.
    aiter_->Reset();
    while (!aiter_->Done()) {
      const Label label =
          match_input ? aiter_->Value().ilabel : aiter_->Value().olabel;
      if (label == match_label_)
        return true;
      if (label > match_label_)
        return false;
      aiter_->Next();
    }
    return false;
  }

  // Remaining labels: bisect the half-open arc range [lo, hi).
  size_t lo = 0;
  size_t hi = narcs_;
  while (lo < hi) {
    const size_t mid = (lo + hi) / 2;
    aiter_->Seek(mid);
    Label label =
        match_input ? aiter_->Value().ilabel : aiter_->Value().olabel;
    if (label > match_label_) {
      hi = mid;
      continue;
    }
    if (label < match_label_) {
      lo = mid + 1;
      continue;
    }
    // Hit. With non-determinism several arcs may carry the label, so
    // walk backwards to the first of them.
    size_t first = mid;
    while (first > lo) {
      aiter_->Seek(first - 1);
      label = match_input ? aiter_->Value().ilabel : aiter_->Value().olabel;
      if (label != match_label_) {
        aiter_->Seek(first);
        return true;
      }
      --first;
    }
    return true;
  }
  aiter_->Seek(lo);
  return false;
}


// Controls whether a special-label matcher rewrites the label on both
// the input and the output side of a matched arc.
enum MatcherRewriteMode {
  // Rewrites both sides iff acceptor.
  MATCHER_REWRITE_AUTO = 0,
  // Always rewrites both sides.
  MATCHER_REWRITE_ALWAYS = 1,
  // Never rewrites both sides.
  MATCHER_REWRITE_NEVER = 2
};


// For any requested label that doesn't match at a state, this matcher
// considers all transitions that match the label 'rho_label' (rho =
// 'rest').  Each such rho transition found is returned with the
// rho_label rewritten as the requested label (both sides if an
// acceptor, or if 'rewrite_both' is true and both input and output
// labels of the found transition are 'rho_label').  If 'rho_label' is
// kNoLabel, this special matching is not done.  RhoMatcher is
// templated itself on a matcher, which is used to perform the
// underlying matching.  By default, the underlying matcher is
// constructed by RhoMatcher.  The user can instead pass in this
// object; in that case, RhoMatcher takes its ownership.
template <class M>
class RhoMatcher : public MatcherBase<typename M::Arc> {
 public:
  typedef typename M::FST FST;
  typedef typename M::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

  // Builds a rho matcher over 'fst'. MATCH_BOTH and a rho_label of 0
  // are rejected: each is reported through FSTERROR and puts the
  // matcher into the error state. 'matcher', when non-null, is used as
  // the underlying matcher and is deleted by the destructor.
  RhoMatcher(const FST &fst,
             MatchType match_type,
             Label rho_label = kNoLabel,
             MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
             M *matcher = 0)
      : matcher_(matcher ? matcher : new M(fst, match_type)),
        match_type_(match_type),
        rho_label_(rho_label),
        has_rho_(false),         // Defined value until SetState().
        rho_match_(kNoLabel),    // Defined value until Find().
        error_(false) {
    if (match_type == MATCH_BOTH) {
      FSTERROR() << "RhoMatcher: bad match type";
      match_type_ = MATCH_NONE;
      error_ = true;
    }
    if (rho_label == 0) {
      FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label";
      rho_label_ = kNoLabel;
      error_ = true;
    }

    if (rewrite_mode == MATCHER_REWRITE_AUTO)
      rewrite_both_ = fst.Properties(kAcceptor, true);
    else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
      rewrite_both_ = true;
    else
      rewrite_both_ = false;
  }

  // Copies the configuration of 'matcher' and a copy of its underlying
  // matcher; per-state and per-query state starts out cleared.
  RhoMatcher(const RhoMatcher<M> &matcher, bool safe = false)
      : matcher_(new M(*matcher.matcher_, safe)),
        match_type_(matcher.match_type_),
        rho_label_(matcher.rho_label_),
        rewrite_both_(matcher.rewrite_both_),
        has_rho_(false),         // Defined value until SetState().
        rho_match_(kNoLabel),    // Defined value until Find().
        error_(matcher.error_) {}

  virtual ~RhoMatcher() {
    delete matcher_;
  }

  virtual RhoMatcher<M> *Copy(bool safe = false) const {
    return new RhoMatcher<M>(*this, safe);
  }

  virtual MatchType Type(bool test) const { return matcher_->Type(test); }

  // Sets the current state. Rho arcs are assumed possible at the new
  // state whenever a rho label is configured; Find() refines this.
  void SetState(StateId s) {
    matcher_->SetState(s);
    has_rho_ = rho_label_ != kNoLabel;
  }

  // Finds arcs for 'match_label', falling back to rho arcs when the
  // label itself does not match and is neither 0 nor kNoLabel.
  // Asking for the rho label itself is an error.
  bool Find(Label match_label) {
    if (match_label == rho_label_ && rho_label_ != kNoLabel) {
      FSTERROR() << "RhoMatcher::Find: bad label (rho)";
      error_ = true;
      return false;
    }
    if (matcher_->Find(match_label)) {
      rho_match_ = kNoLabel;
      return true;
    } else if (has_rho_ && match_label != 0 && match_label != kNoLabel &&
               (has_rho_ = matcher_->Find(rho_label_))) {
      // The assignment records whether this state has rho arcs at all,
      // so later queries at the same state can skip the lookup.
      rho_match_ = match_label;
      return true;
    } else {
      return false;
    }
  }

  bool Done() const { return matcher_->Done(); }

  // Returns the current arc. For a rho match, returns a copy of the
  // arc with the rho label replaced by the requested label.
  const Arc& Value() const {
    if (rho_match_ == kNoLabel) {
      return matcher_->Value();
    } else {
      rho_arc_ = matcher_->Value();
      if (rewrite_both_) {
        if (rho_arc_.ilabel == rho_label_)
          rho_arc_.ilabel = rho_match_;
        if (rho_arc_.olabel == rho_label_)
          rho_arc_.olabel = rho_match_;
      } else if (match_type_ == MATCH_INPUT) {
        rho_arc_.ilabel = rho_match_;
      } else {
        rho_arc_.olabel = rho_match_;
      }
      return rho_arc_;
    }
  }

  void Next() { matcher_->Next(); }

  virtual const FST &GetFst() const { return matcher_->GetFst(); }

  virtual uint64 Properties(uint64 props) const;

  // Adds kRequireMatch to the underlying matcher's flags when rho
  // matching is active.
  virtual uint32 Flags() const {
    if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE)
      return matcher_->Flags();
    return matcher_->Flags() | kRequireMatch;
  }

 private:
  // MatcherBase hooks; forward to the non-virtual methods above.
  virtual void SetState_(StateId s) { SetState(s); }
  virtual bool Find_(Label label) { return Find(label); }
  virtual bool Done_() const { return Done(); }
  virtual const Arc& Value_() const { return Value(); }
  virtual void Next_() { Next(); }

  M *matcher_;
  MatchType match_type_;  // Type of match requested
  Label rho_label_;       // Label that represents the rho transition
  bool rewrite_both_;     // Rewrite both sides when both are 'rho_label_'
  bool has_rho_;          // Are there possibly rhos at the current state?
  Label rho_match_;       // Current label that matches rho transition
  mutable Arc rho_arc_;   // Arc to return when rho match
  bool error_;            // Error encountered

  void operator=(const RhoMatcher<M> &);  // Disallow
};

// Returns the underlying matcher's view of 'inprops' with the property
// bits that rho rewriting may invalidate cleared, and kError added in
// the error state. Which bits are cleared depends on the matched side
// and on whether both sides are rewritten.
template <class M> inline
uint64 RhoMatcher<M>::Properties(uint64 inprops) const {
  uint64 outprops = matcher_->Properties(inprops);
  if (error_) outprops |= kError;

  switch (match_type_) {
    case MATCH_NONE: {
      return outprops;
    }
    case MATCH_INPUT: {
      const uint64 cleared = rewrite_both_ ?
          (kODeterministic | kNonODeterministic | kString |
           kILabelSorted | kNotILabelSorted |
           kOLabelSorted | kNotOLabelSorted) :
          (kODeterministic | kAcceptor | kString |
           kILabelSorted | kNotILabelSorted);
      return outprops & ~cleared;
    }
    case MATCH_OUTPUT: {
      const uint64 cleared = rewrite_both_ ?
          (kIDeterministic | kNonIDeterministic | kString |
           kILabelSorted | kNotILabelSorted |
           kOLabelSorted | kNotOLabelSorted) :
          (kIDeterministic | kAcceptor | kString |
           kOLabelSorted | kNotOLabelSorted);
      return outprops & ~cleared;
    }
    default: {
      // Shouldn't ever get here.
      FSTERROR() << "RhoMatcher:: bad match type: " << match_type_;
      return 0;
    }
  }
}


// For any requested label, this matcher considers all transitions
// that match the label 'sigma_label' (sigma = "any"), in addition to
// the transitions with the requested label.  Each such sigma
// transition found is returned with the sigma_label rewritten as the
// requested label (both sides if an acceptor, or if 'rewrite_both' is
// true and both input and output labels of the found transition are
// 'sigma_label').  If 'sigma_label' is kNoLabel, this special
// matching is not done.  SigmaMatcher is templated itself on a
// matcher, which is used to perform the underlying matching.  By
// default, the underlying matcher is constructed by SigmaMatcher.
// The user can instead pass in this object; in that case,
// SigmaMatcher takes its ownership.
template <class M>
class SigmaMatcher : public MatcherBase<typename M::Arc> {
 public:
  typedef typename M::FST FST;
  typedef typename M::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

  // Builds a sigma matcher over 'fst'. MATCH_BOTH and a sigma_label of
  // 0 are rejected: each is reported through FSTERROR and puts the
  // matcher into the error state. 'matcher', when non-null, is used as
  // the underlying matcher and is deleted by the destructor.
  SigmaMatcher(const FST &fst,
               MatchType match_type,
               Label sigma_label = kNoLabel,
               MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
               M *matcher = 0)
      : matcher_(matcher ? matcher : new M(fst, match_type)),
        match_type_(match_type),
        sigma_label_(sigma_label),
        has_sigma_(false),        // Defined value until SetState().
        sigma_match_(kNoLabel),   // Defined value until Find().
        match_label_(kNoLabel),   // Defined value until Find().
        error_(false) {
    if (match_type == MATCH_BOTH) {
      FSTERROR() << "SigmaMatcher: bad match type";
      match_type_ = MATCH_NONE;
      error_ = true;
    }
    if (sigma_label == 0) {
      FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label";
      sigma_label_ = kNoLabel;
      error_ = true;
    }

    if (rewrite_mode == MATCHER_REWRITE_AUTO)
      rewrite_both_ = fst.Properties(kAcceptor, true);
    else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
      rewrite_both_ = true;
    else
      rewrite_both_ = false;
  }

  // Copies the configuration of 'matcher' and a copy of its underlying
  // matcher; per-state and per-query state starts out cleared.
  SigmaMatcher(const SigmaMatcher<M> &matcher, bool safe = false)
      : matcher_(new M(*matcher.matcher_, safe)),
        match_type_(matcher.match_type_),
        sigma_label_(matcher.sigma_label_),
        rewrite_both_(matcher.rewrite_both_),
        has_sigma_(false),        // Defined value until SetState().
        sigma_match_(kNoLabel),   // Defined value until Find().
        match_label_(kNoLabel),   // Defined value until Find().
        error_(matcher.error_) {}

  virtual ~SigmaMatcher() {
    delete matcher_;
  }

  virtual SigmaMatcher<M> *Copy(bool safe = false) const {
    return new SigmaMatcher<M>(*this, safe);
  }

  virtual MatchType Type(bool test) const { return matcher_->Type(test); }

  // Sets the current state and records whether it has any sigma arcs.
  void SetState(StateId s) {
    matcher_->SetState(s);
    has_sigma_ =
        sigma_label_ != kNoLabel ? matcher_->Find(sigma_label_) : false;
  }

  // Finds arcs for 'match_label'. When the label itself has no arcs
  // and is neither 0 nor kNoLabel, the state's sigma arcs are used
  // instead. Asking for the sigma label itself is an error.
  bool Find(Label match_label) {
    match_label_ = match_label;
    if (match_label == sigma_label_ && sigma_label_ != kNoLabel) {
      FSTERROR() << "SigmaMatcher::Find: bad label (sigma)";
      error_ = true;
      return false;
    }
    if (matcher_->Find(match_label)) {
      sigma_match_ = kNoLabel;
      return true;
    } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel &&
               matcher_->Find(sigma_label_)) {
      sigma_match_ = match_label;
      return true;
    } else {
      return false;
    }
  }

  bool Done() const {
    return matcher_->Done();
  }

  // Returns the current arc. For a sigma match, returns a copy of the
  // arc with the sigma label replaced by the requested label.
  const Arc& Value() const {
    if (sigma_match_ == kNoLabel) {
      return matcher_->Value();
    } else {
      sigma_arc_ = matcher_->Value();
      if (rewrite_both_) {
        if (sigma_arc_.ilabel == sigma_label_)
          sigma_arc_.ilabel = sigma_match_;
        if (sigma_arc_.olabel == sigma_label_)
          sigma_arc_.olabel = sigma_match_;
      } else if (match_type_ == MATCH_INPUT) {
        sigma_arc_.ilabel = sigma_match_;
      } else {
        sigma_arc_.olabel = sigma_match_;
      }
      return sigma_arc_;
    }
  }

  // Advances to the next arc. Once the arcs with the requested label
  // are exhausted, iteration continues over the state's sigma arcs
  // (only for requested labels greater than 0).
  void Next() {
    matcher_->Next();
    if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) &&
        (match_label_ > 0)) {
      matcher_->Find(sigma_label_);
      sigma_match_ = match_label_;
    }
  }

  virtual const FST &GetFst() const { return matcher_->GetFst(); }

  virtual uint64 Properties(uint64 props) const;

  virtual uint32 Flags() const {
    if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE)
      return matcher_->Flags();
    // kRequireMatch temporarily disabled until issues
    // in //speech/gaudi/annotation/util/denorm are resolved.
    // return matcher_->Flags() | kRequireMatch;
    return matcher_->Flags();
  }

 private:
  // MatcherBase hooks; forward to the non-virtual methods above.
  virtual void SetState_(StateId s) { SetState(s); }
  virtual bool Find_(Label label) { return Find(label); }
  virtual bool Done_() const { return Done(); }
  virtual const Arc& Value_() const { return Value(); }
  virtual void Next_() { Next(); }

  M *matcher_;
  MatchType match_type_;   // Type of match requested
  Label sigma_label_;      // Label that represents the sigma transition
  bool rewrite_both_;      // Rewrite both sides when both are 'sigma_label_'
  bool has_sigma_;         // Are there sigmas at the current state?
  Label sigma_match_;      // Current label that matches sigma transition
  mutable Arc sigma_arc_;  // Arc to return when sigma match
  Label match_label_;      // Label being matched
  bool error_;             // Error encountered

  void operator=(const SigmaMatcher<M> &);  // disallow
};

// Returns the underlying matcher's view of 'inprops' with the property
// bits that sigma rewriting may invalidate cleared, and kError added
// in the error state.
template <class M> inline
uint64 SigmaMatcher<M>::Properties(uint64 inprops) const {
  uint64 outprops = matcher_->Properties(inprops);
  if (error_) outprops |= kError;

  if (match_type_ == MATCH_NONE)
    return outprops;

  // Bits cleared in every remaining case, whichever side is matched.
  const uint64 always_cleared = kIDeterministic | kNonIDeterministic |
                                kODeterministic | kNonODeterministic |
                                kString;

  if (rewrite_both_) {
    // Both sides are rewritten, so sortedness is unknown on both.
    return outprops & ~(always_cleared |
                        kILabelSorted | kNotILabelSorted |
                        kOLabelSorted | kNotOLabelSorted);
  }
  if (match_type_ == MATCH_INPUT) {
    return outprops & ~(always_cleared |
                        kILabelSorted | kNotILabelSorted |
                        kAcceptor);
  }
  if (match_type_ == MATCH_OUTPUT) {
    return outprops & ~(always_cleared |
                        kOLabelSorted | kNotOLabelSorted |
                        kAcceptor);
  }

  // Shouldn't ever get here.
  FSTERROR() << "SigmaMatcher:: bad match type: " << match_type_;
  return 0;
}


// For any requested label that doesn't match at a state, this matcher
// considers the *unique* transition that matches the label 'phi_label'
// (phi = 'fail'), and recursively looks for a match at its
// destination.  When 'phi_loop' is true, if no match is found but a
// phi self-loop is found, then the phi transition found is returned
// with the phi_label rewritten as the requested label (both sides if
// an acceptor, or if 'rewrite_both' is true and both input and output
// labels of the found transition are 'phi_label').  If 'phi_label' is
// kNoLabel, this special matching is not done.  PhiMatcher is
// templated itself on a matcher, which is used to perform the
// underlying matching.  By default, the underlying matcher is
// constructed by PhiMatcher. The user can instead pass in this
// object; in that case, PhiMatcher takes its ownership.
// Warning: phi non-determinism not supported (for simplicity).
template <class M>
class PhiMatcher : public MatcherBase<typename M::Arc> {
 public:
  typedef typename M::FST FST;
  typedef typename M::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

  // Builds a phi matcher over 'fst'. MATCH_BOTH is rejected: it is
  // reported through FSTERROR and puts the matcher into the error
  // state. 'matcher', when non-null, is used as the underlying matcher
  // and is deleted by the destructor.
  PhiMatcher(const FST &fst,
             MatchType match_type,
             Label phi_label = kNoLabel,
             bool phi_loop = true,
             MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
             M *matcher = 0)
      : matcher_(matcher ? matcher : new M(fst, match_type)),
        match_type_(match_type),
        phi_label_(phi_label),
        has_phi_(false),             // Defined value until SetState().
        phi_match_(kNoLabel),        // Defined value until Find().
        state_(kNoStateId),
        phi_weight_(Weight::One()),  // Defined value until Find().
        phi_loop_(phi_loop),
        error_(false) {
    if (match_type == MATCH_BOTH) {
      FSTERROR() << "PhiMatcher: bad match type";
      match_type_ = MATCH_NONE;
      error_ = true;
    }

    if (rewrite_mode == MATCHER_REWRITE_AUTO)
      rewrite_both_ = fst.Properties(kAcceptor, true);
    else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
      rewrite_both_ = true;
    else
      rewrite_both_ = false;
  }

  // Copies the configuration of 'matcher' and a copy of its underlying
  // matcher; per-state and per-query state starts out cleared.
  PhiMatcher(const PhiMatcher<M> &matcher, bool safe = false)
      : matcher_(new M(*matcher.matcher_, safe)),
        match_type_(matcher.match_type_),
        phi_label_(matcher.phi_label_),
        rewrite_both_(matcher.rewrite_both_),
        has_phi_(false),             // Defined value until SetState().
        phi_match_(kNoLabel),        // Defined value until Find().
        state_(kNoStateId),
        phi_weight_(Weight::One()),  // Defined value until Find().
        phi_loop_(matcher.phi_loop_),
        error_(matcher.error_) {}

  virtual ~PhiMatcher() {
    delete matcher_;
  }

  virtual PhiMatcher<M> *Copy(bool safe = false) const {
    return new PhiMatcher<M>(*this, safe);
  }

  virtual MatchType Type(bool test) const { return matcher_->Type(test); }

  // Sets the state at which matches are looked for. Phi arcs are
  // assumed possible whenever a phi label is configured.
  void SetState(StateId s) {
    matcher_->SetState(s);
    state_ = s;
    has_phi_ = phi_label_ != kNoLabel;
  }

  bool Find(Label match_label);

  bool Done() const { return matcher_->Done(); }

  // Returns the current arc. Three cases:
  //  - plain match with no phi transition taken: the underlying arc;
  //  - virtual epsilon loop (phi_match_ == 0): a synthesized loop arc;
  //  - otherwise: a copy of the underlying arc whose weight is
  //    pre-multiplied by the accumulated phi weight and, for a phi
  //    loop match, whose phi label is rewritten as the requested label.
  const Arc& Value() const {
    if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) {
      return matcher_->Value();
    } else if (phi_match_ == 0) {  // Virtual epsilon loop
      phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_);
      if (match_type_ == MATCH_OUTPUT)
        swap(phi_arc_.ilabel, phi_arc_.olabel);
      return phi_arc_;
    } else {
      phi_arc_ = matcher_->Value();
      phi_arc_.weight = Times(phi_weight_, phi_arc_.weight);
      if (phi_match_ != kNoLabel) {  // Phi loop match
        if (rewrite_both_) {
          if (phi_arc_.ilabel == phi_label_)
            phi_arc_.ilabel = phi_match_;
          if (phi_arc_.olabel == phi_label_)
            phi_arc_.olabel = phi_match_;
        } else if (match_type_ == MATCH_INPUT) {
          phi_arc_.ilabel = phi_match_;
        } else {
          phi_arc_.olabel = phi_match_;
        }
      }
      return phi_arc_;
    }
  }

  void Next() { matcher_->Next(); }

  virtual const FST &GetFst() const { return matcher_->GetFst(); }

  virtual uint64 Properties(uint64 props) const;

  // Adds kRequireMatch to the underlying matcher's flags when phi
  // matching is active.
  virtual uint32 Flags() const {
    if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE)
      return matcher_->Flags();
    return matcher_->Flags() | kRequireMatch;
  }

 private:
  // MatcherBase hooks; forward to the non-virtual methods above.
  virtual void SetState_(StateId s) { SetState(s); }
  virtual bool Find_(Label label) { return Find(label); }
  virtual bool Done_() const { return Done(); }
  virtual const Arc& Value_() const { return Value(); }
  virtual void Next_() { Next(); }

  M *matcher_;
  MatchType match_type_;  // Type of match requested
  Label phi_label_;       // Label that represents the phi transition
  bool rewrite_both_;     // Rewrite both sides when both are 'phi_label_'
  bool has_phi_;          // Are there possibly phis at the current state?
  Label phi_match_;       // Current label that matches phi loop
  mutable Arc phi_arc_;   // Arc to return
  StateId state_;         // State where looking for matches
  Weight phi_weight_;     // Product of the weights of phi transitions taken
  bool phi_loop_;         // When true, phi self-loop are allowed and treated
                          // as rho (required for Aho-Corasick)
  bool error_;            // Error encountered

  void operator=(const PhiMatcher<M> &);  // disallow
};

// Looks for 'match_label' at the current state, following phi
// (failure) transitions while the label does not match. The weights
// of the phi transitions taken are accumulated in phi_weight_. Returns
// true iff a match (possibly through a phi self-loop) was found.
template <class M> inline
bool PhiMatcher<M>::Find(Label match_label) {
  // Asking for the phi label itself is an error, unless phi is
  // disabled (kNoLabel) or stands for epsilon (0).
  const bool phi_is_ordinary = phi_label_ != kNoLabel && phi_label_ != 0;
  if (phi_is_ordinary && match_label == phi_label_) {
    FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_;
    error_ = true;
    return false;
  }

  // A previous query may have moved the underlying matcher along phi
  // transitions; start again from the state being matched.
  matcher_->SetState(state_);
  phi_match_ = kNoLabel;
  phi_weight_ = Weight::One();

  const bool phi_is_epsilon = phi_label_ == 0;
  if (phi_is_epsilon) {
    // When 'phi_label_ == 0', there are no more true epsilon arcs,
    if (match_label == kNoLabel)
      return false;
    // but virtual eps loop need to be returned.
    if (match_label == 0) {
      if (matcher_->Find(kNoLabel)) {
        phi_match_ = 0;
        return true;
      }
      return matcher_->Find(0);
    }
  }

  // No failure semantics apply: defer to the underlying matcher.
  if (!has_phi_ || match_label == 0 || match_label == kNoLabel)
    return matcher_->Find(match_label);

  StateId current = state_;
  for (;;) {
    if (matcher_->Find(match_label))
      return true;

    // Look for phi transition (if phi_label_ == 0, we need to look
    // for -1 to avoid getting the virtual self-loop)
    if (!matcher_->Find(phi_is_epsilon ? -1 : phi_label_))
      return false;

    // A phi self-loop matches the requested label directly.
    if (phi_loop_ && matcher_->Value().nextstate == current) {
      phi_match_ = match_label;
      return true;
    }

    // Take the phi transition and retry at its destination.
    phi_weight_ = Times(phi_weight_, matcher_->Value().weight);
    current = matcher_->Value().nextstate;
    matcher_->Next();
    if (!matcher_->Done()) {
      FSTERROR() << "PhiMatcher: phi non-determinism not supported";
      error_ = true;
    }
    matcher_->SetState(current);
  }
}

// Returns the underlying matcher's view of 'inprops' with the property
// bits that phi matching may invalidate cleared, and kError added in
// the error state. When the phi label is 0, the epsilon-presence bits
// are replaced by the matching no-epsilon bits for the matched side.
template <class M> inline
uint64 PhiMatcher<M>::Properties(uint64 inprops) const {
  uint64 outprops = matcher_->Properties(inprops);
  if (error_) outprops |= kError;

  if (match_type_ == MATCH_NONE) {
    return outprops;
  } else if (match_type_ == MATCH_INPUT) {
    if (phi_label_ == 0) {
      // Clear the epsilon-presence bits before asserting their
      // absence. The mask must be the complement of the union:
      // OR-ing the individual complements would clear nothing, leaving
      // contradictory "has epsilons" / "no epsilons" bits set together.
      outprops &= ~(kEpsilons | kIEpsilons | kOEpsilons);
      outprops |= kNoEpsilons | kNoIEpsilons;
    }
    if (rewrite_both_) {
      return outprops & ~(kODeterministic | kNonODeterministic | kString |
                       kILabelSorted | kNotILabelSorted |
                       kOLabelSorted | kNotOLabelSorted);
    } else {
      return outprops & ~(kODeterministic | kAcceptor | kString |
                       kILabelSorted | kNotILabelSorted |
                       kOLabelSorted | kNotOLabelSorted);
    }
  } else if (match_type_ == MATCH_OUTPUT) {
    if (phi_label_ == 0) {
      // See the MATCH_INPUT case for why the mask is written this way.
      outprops &= ~(kEpsilons | kIEpsilons | kOEpsilons);
      outprops |= kNoEpsilons | kNoOEpsilons;
    }
    if (rewrite_both_) {
      return outprops & ~(kIDeterministic | kNonIDeterministic | kString |
                       kILabelSorted | kNotILabelSorted |
                       kOLabelSorted | kNotOLabelSorted);
    } else {
      return outprops & ~(kIDeterministic | kAcceptor | kString |
                       kILabelSorted | kNotILabelSorted |
                       kOLabelSorted | kNotOLabelSorted);
    }
  } else {
    // Shouldn't ever get here.
    FSTERROR() << "PhiMatcher:: bad match type: " << match_type_;
    return 0;
  }
}


//
// MULTI-EPS MATCHER FLAGS
//

// When set, Find(kNoLabel) on a MultiEpsMatcher also returns the arcs
// labeled with any of the registered multi-epsilon labels.
const uint32 kMultiEpsList =  0x00000001;

// When set, Find(multi_eps) on a MultiEpsMatcher returns an implicit
// kNoLabel loop instead of querying the underlying matcher.
const uint32 kMultiEpsLoop =  0x00000002;

// MultiEpsMatcher: allows treating multiple non-0 labels as
// non-consuming labels in addition to 0 that is always
// non-consuming. Precise behavior controlled by 'flags' argument. By
// default, the underlying matcher is constructed by
// MultiEpsMatcher. The user can instead pass in this object; in that
// case, MultiEpsMatcher takes its ownership iff 'own_matcher' is
// true.
template <class M>
class MultiEpsMatcher {
 public:
  typedef typename M::FST FST;
  typedef typename M::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

  MultiEpsMatcher(const FST &fst, MatchType match_type,
                  uint32 flags = (kMultiEpsLoop | kMultiEpsList),
                  M *matcher = 0, bool own_matcher = true)
      : matcher_(matcher ? matcher : new M(fst, match_type)),
        flags_(flags),
        own_matcher_(matcher ?  own_matcher : true) {
    if (match_type == MATCH_INPUT) {
      loop_.ilabel = kNoLabel;
      loop_.olabel = 0;
    } else {
      loop_.ilabel = 0;
      loop_.olabel = kNoLabel;
    }
    loop_.weight = Weight::One();
    loop_.nextstate = kNoStateId;
  }

  MultiEpsMatcher(const MultiEpsMatcher<M> &matcher, bool safe = false)
      : matcher_(new M(*matcher.matcher_, safe)),
        flags_(matcher.flags_),
        own_matcher_(true),
        multi_eps_labels_(matcher.multi_eps_labels_),
        loop_(matcher.loop_) {
    loop_.nextstate = kNoStateId;
  }

  ~MultiEpsMatcher() {
    if (own_matcher_)
      delete matcher_;
  }

  MultiEpsMatcher<M> *Copy(bool safe = false) const {
    return new MultiEpsMatcher<M>(*this, safe);
  }

  MatchType Type(bool test) const { return matcher_->Type(test); }

  void SetState(StateId s) {
    matcher_->SetState(s);
    loop_.nextstate = s;
  }

  bool Find(Label match_label);

  bool Done() const {
    return done_;
  }

  const Arc& Value() const {
    return current_loop_ ? loop_ : matcher_->Value();
  }

  void Next() {
    if (!current_loop_) {
      matcher_->Next();
      done_ = matcher_->Done();
      if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) {
        ++multi_eps_iter_;
        while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
               !matcher_->Find(*multi_eps_iter_))
          ++multi_eps_iter_;
        if (multi_eps_iter_ != multi_eps_labels_.End())
          done_ = false;
        else
          done_ = !matcher_->Find(kNoLabel);

      }
    } else {
      done_ = true;
    }
  }

  const FST &GetFst() const { return matcher_->GetFst(); }

  uint64 Properties(uint64 props) const { return matcher_->Properties(props); }

  uint32 Flags() const { return matcher_->Flags(); }

  void AddMultiEpsLabel(Label label) {
    if (label == 0) {
      FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
    } else {
      multi_eps_labels_.Insert(label);
    }
  }

  void RemoveMultiEpsLabel(Label label) {
    if (label == 0) {
      FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
    } else {
      multi_eps_labels_.Erase(label);
    }
  }

  void ClearMultiEpsLabels() {
    multi_eps_labels_.Clear();
  }

private:
  M *matcher_;
  uint32 flags_;
  bool own_matcher_;             // Does this class delete the matcher?

  // Multi-eps label set
  CompactSet<Label, kNoLabel> multi_eps_labels_;
  typename CompactSet<Label, kNoLabel>::const_iterator multi_eps_iter_;

  bool current_loop_;            // Current arc is the implicit loop
  mutable Arc loop_;             // For non-consuming symbols
  bool done_;                    // Matching done

  void operator=(const MultiEpsMatcher<M> &);  // Disallow
};

// Begins matching 'match_label' at the current state and returns true
// iff there is at least one matching arc. Done() is false afterwards
// exactly when this returns true.
//  - 0: plain epsilon match on the underlying matcher.
//  - kNoLabel: with kMultiEpsList set, the multi-epsilon labels are
//    tried in set order and the first one with a match becomes current;
//    if none matches (or the flag is clear), the underlying matcher's
//    Find(kNoLabel) decides.
//  - a registered multi-epsilon label with kMultiEpsLoop set: the
//    implicit loop is the only match.
//  - anything else: forwarded to the underlying matcher.
template <class M> inline
bool MultiEpsMatcher<M>::Find(Label match_label) {
  // Forget any match left over from a previous call.
  current_loop_ = false;
  multi_eps_iter_ = multi_eps_labels_.End();

  if (match_label == 0) {
    const bool found = matcher_->Find(0);
    done_ = !found;
    return found;
  }

  if (match_label == kNoLabel) {
    if (flags_ & kMultiEpsList) {
      // Leave multi_eps_iter_ on the first label that matches so that
      // Next() can continue with the remaining labels.
      for (multi_eps_iter_ = multi_eps_labels_.Begin();
           multi_eps_iter_ != multi_eps_labels_.End();
           ++multi_eps_iter_) {
        if (matcher_->Find(*multi_eps_iter_)) {
          done_ = false;
          return true;
        }
      }
    }
    // Either only the underlying kNoLabel match was requested, or no
    // multi-epsilon label matched.
    const bool found = matcher_->Find(kNoLabel);
    done_ = !found;
    return found;
  }

  if ((flags_ & kMultiEpsLoop) &&
      multi_eps_labels_.Find(match_label) != multi_eps_labels_.End()) {
    // Registered multi-epsilon label: answer with the implicit loop.
    current_loop_ = true;
    done_ = false;
    return true;
  }

  const bool found = matcher_->Find(match_label);
  done_ = !found;
  return found;
}


// Generic matcher, templated on the FST definition
// - a wrapper around pointer to specific one.
// Here is a typical use: \code
//   Matcher<StdFst> matcher(fst, MATCH_INPUT);
//   matcher.SetState(state);
//   if (matcher.Find(label))
//     for (; !matcher.Done(); matcher.Next()) {
//       const StdArc &arc = matcher.Value();
//       ...
//     } \endcode
// The wrapped matcher is owned by this object and deleted with it.
// Copy with Copy() or the copy constructor; assignment is disallowed.
template <class F>
class Matcher {
 public:
  typedef F FST;
  typedef typename F::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

  // Uses the matcher supplied by the FST's InitMatcher() when there is
  // one, otherwise falls back to a SortedMatcher over 'fst'.
  Matcher(const F &fst, MatchType match_type) {
    base_ = fst.InitMatcher(match_type);
    if (!base_)
      base_ = new SortedMatcher<F>(fst, match_type);
  }

  // Copies the wrapped matcher through its own Copy(safe).
  Matcher(const Matcher<F> &matcher, bool safe = false) {
    base_ = matcher.base_->Copy(safe);
  }

  // Takes ownership of the provided matcher
  Matcher(MatcherBase<Arc>* base_matcher) { base_ = base_matcher; }

  ~Matcher() { delete base_; }

  // Returns a heap-allocated copy; the caller owns the result.
  Matcher<F> *Copy(bool safe = false) const {
    return new Matcher<F>(*this, safe);
  }

  // The calls below forward to the wrapped matcher.
  MatchType Type(bool test) const { return base_->Type(test); }
  void SetState(StateId s) { base_->SetState(s); }
  bool Find(Label label) { return base_->Find(label); }
  bool Done() const { return base_->Done(); }
  const Arc& Value() const { return base_->Value(); }
  void Next() { base_->Next(); }
  const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); }
  uint64 Properties(uint64 props) const { return base_->Properties(props); }
  // Only the bits in kMatcherFlags are reported.
  uint32 Flags() const { return base_->Flags() & kMatcherFlags; }

 private:
  MatcherBase<Arc> *base_;

  // Declared with this class's own type so that it really is the copy
  // assignment operator; otherwise the compiler would still generate
  // one, and assigning would leave two objects deleting the same base_.
  void operator=(const Matcher<F> &);  // disallow
};

}  // namespace fst



#endif  // FST_LIB_MATCHER_H__