// compose.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Compose a PDT and an FST.

#ifndef FST_EXTENSIONS_PDT_COMPOSE_H__
#define FST_EXTENSIONS_PDT_COMPOSE_H__

#include <list>

#include <fst/extensions/pdt/pdt.h>
#include <fst/compose.h>

namespace fst {

// Return paren arcs for Find(kNoLabel).
const uint32 kParenList =  0x00000001;

// Return a kNolabel loop for Find(paren).
const uint32 kParenLoop =  0x00000002;

// This class is a matcher that treats parens as multi-epsilon labels.
// It is most efficient if the parens are in a range non-overlapping with
// the non-paren labels.
template <class F>
class ParenMatcher {
 public:
  typedef SortedMatcher<F> M;
  typedef typename M::FST FST;
  typedef typename M::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

  ParenMatcher(const FST &fst, MatchType match_type,
               uint32 flags = (kParenLoop | kParenList))
      : matcher_(fst, match_type),
        match_type_(match_type),
        flags_(flags) {
    if (match_type == MATCH_INPUT) {
      loop_.ilabel = kNoLabel;
      loop_.olabel = 0;
    } else {
      loop_.ilabel = 0;
      loop_.olabel = kNoLabel;
    }
    loop_.weight = Weight::One();
    loop_.nextstate = kNoStateId;
  }

  ParenMatcher(const ParenMatcher<F> &matcher, bool safe = false)
      : matcher_(matcher.matcher_, safe),
        match_type_(matcher.match_type_),
        flags_(matcher.flags_),
        open_parens_(matcher.open_parens_),
        close_parens_(matcher.close_parens_),
        loop_(matcher.loop_) {
    loop_.nextstate = kNoStateId;
  }

  ParenMatcher<F> *Copy(bool safe = false) const {
    return new ParenMatcher<F>(*this, safe);
  }

  MatchType Type(bool test) const { return matcher_.Type(test); }

  void SetState(StateId s) {
    matcher_.SetState(s);
    loop_.nextstate = s;
  }

  bool Find(Label match_label);

  bool Done() const {
    return done_;
  }

  const Arc& Value() const {
    return paren_loop_ ? loop_ : matcher_.Value();
  }

  void Next();

  const FST &GetFst() const { return matcher_.GetFst(); }

  uint64 Properties(uint64 props) const { return matcher_.Properties(props); }

  uint32 Flags() const { return matcher_.Flags(); }

  void AddOpenParen(Label label) {
    if (label == 0) {
      FSTERROR() << "ParenMatcher: Bad open paren label: 0";
    } else {
      open_parens_.Insert(label);
    }
  }

  void AddCloseParen(Label label) {
    if (label == 0) {
      FSTERROR() << "ParenMatcher: Bad close paren label: 0";
    } else {
      close_parens_.Insert(label);
    }
  }

  void RemoveOpenParen(Label label) {
    if (label == 0) {
      FSTERROR() << "ParenMatcher: Bad open paren label: 0";
    } else {
      open_parens_.Erase(label);
    }
  }

  void RemoveCloseParen(Label label) {
    if (label == 0) {
      FSTERROR() << "ParenMatcher: Bad close paren label: 0";
    } else {
      close_parens_.Erase(label);
    }
  }

  void ClearOpenParens() {
    open_parens_.Clear();
  }

  void ClearCloseParens() {
    close_parens_.Clear();
  }

  bool IsOpenParen(Label label) const {
    return open_parens_.Member(label);
  }

  bool IsCloseParen(Label label) const {
    return close_parens_.Member(label);
  }

 private:
  // Advances matcher to next open paren if it exists, returning true.
  // O.w. returns false.
  bool NextOpenParen();

  // Advances matcher to next open paren if it exists, returning true.
  // O.w. returns false.
  bool NextCloseParen();

  M matcher_;
  MatchType match_type_;          // Type of match to perform
  uint32 flags_;

  // open paren label set
  CompactSet<Label, kNoLabel> open_parens_;

  // close paren label set
  CompactSet<Label, kNoLabel> close_parens_;


  bool open_paren_list_;         // Matching open paren list
  bool close_paren_list_;        // Matching close paren list
  bool paren_loop_;              // Current arc is the implicit paren loop
  mutable Arc loop_;             // For non-consuming symbols
  bool done_;                    // Matching done

  void operator=(const ParenMatcher<F> &);  // Disallow
};

template <class M> inline
bool ParenMatcher<M>::Find(Label match_label) {
  open_paren_list_ = false;
  close_paren_list_ = false;
  paren_loop_ = false;
  done_ = false;

  // Returns all parenthesis arcs
  if (match_label == kNoLabel && (flags_ & kParenList)) {
    if (open_parens_.LowerBound() != kNoLabel) {
      matcher_.LowerBound(open_parens_.LowerBound());
      open_paren_list_ = NextOpenParen();
      if (open_paren_list_) return true;
    }
    if (close_parens_.LowerBound() != kNoLabel) {
      matcher_.LowerBound(close_parens_.LowerBound());
      close_paren_list_ = NextCloseParen();
      if (close_paren_list_) return true;
    }
  }

  // Returns 'implicit' paren loop
  if (match_label > 0 && (flags_ & kParenLoop) &&
      (IsOpenParen(match_label) || IsCloseParen(match_label))) {
    paren_loop_ = true;
    return true;
  }

  // Returns all other labels
  if (matcher_.Find(match_label))
    return true;

  done_ = true;
  return false;
}

template <class F> inline
void ParenMatcher<F>::Next() {
  if (paren_loop_) {
    paren_loop_ = false;
    done_ = true;
  } else if (open_paren_list_) {
    matcher_.Next();
    open_paren_list_ = NextOpenParen();
    if (open_paren_list_) return;

    if (close_parens_.LowerBound() != kNoLabel) {
      matcher_.LowerBound(close_parens_.LowerBound());
      close_paren_list_ = NextCloseParen();
      if (close_paren_list_) return;
    }
    done_ = !matcher_.Find(kNoLabel);
  } else if (close_paren_list_) {
    matcher_.Next();
    close_paren_list_ = NextCloseParen();
    if (close_paren_list_) return;
    done_ = !matcher_.Find(kNoLabel);
  } else {
    matcher_.Next();
    done_ = matcher_.Done();
  }
}

// Advances matcher to next open paren if it exists, returning true.
// O.w. returns false.
template <class F> inline
bool ParenMatcher<F>::NextOpenParen() {
  for (; !matcher_.Done(); matcher_.Next()) {
    Label label = match_type_ == MATCH_INPUT ?
        matcher_.Value().ilabel : matcher_.Value().olabel;
    if (label > open_parens_.UpperBound())
      return false;
    if (IsOpenParen(label))
      return true;
  }
  return false;
}

// Advances matcher to next close paren if it exists, returning true.
// O.w. returns false.
template <class F> inline
bool ParenMatcher<F>::NextCloseParen() {
  for (; !matcher_.Done(); matcher_.Next()) {
    Label label = match_type_ == MATCH_INPUT ?
        matcher_.Value().ilabel : matcher_.Value().olabel;
    if (label > close_parens_.UpperBound())
      return false;
    if (IsCloseParen(label))
      return true;
  }
  return false;
}


template <class F>
class ParenFilter {
 public:
  typedef typename F::FST1 FST1;
  typedef typename F::FST2 FST2;
  typedef typename F::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;
  typedef typename F::Matcher1 Matcher1;
  typedef typename F::Matcher2 Matcher2;
  typedef typename F::FilterState FilterState1;
  typedef StateId StackId;
  typedef PdtStack<StackId, Label> ParenStack;
  typedef IntegerFilterState<StackId> FilterState2;
  typedef PairFilterState<FilterState1, FilterState2> FilterState;
  typedef ParenFilter<F> Filter;

  ParenFilter(const FST1 &fst1, const FST2 &fst2,
              Matcher1 *matcher1 = 0,  Matcher2 *matcher2 = 0,
              const vector<pair<Label, Label> > *parens = 0,
              bool expand = false, bool keep_parens = true)
      : filter_(fst1, fst2, matcher1, matcher2),
        parens_(parens ? *parens : vector<pair<Label, Label> >()),
        expand_(expand),
        keep_parens_(keep_parens),
        f_(FilterState::NoState()),
        stack_(parens_),
        paren_id_(-1) {
    if (parens) {
      for (size_t i = 0; i < parens->size(); ++i) {
        const pair<Label, Label>  &p = (*parens)[i];
        parens_.push_back(p);
        GetMatcher1()->AddOpenParen(p.first);
        GetMatcher2()->AddOpenParen(p.first);
        if (!expand_) {
          GetMatcher1()->AddCloseParen(p.second);
          GetMatcher2()->AddCloseParen(p.second);
        }
      }
    }
  }

  ParenFilter(const Filter &filter, bool safe = false)
      : filter_(filter.filter_, safe),
        parens_(filter.parens_),
        expand_(filter.expand_),
        keep_parens_(filter.keep_parens_),
        f_(FilterState::NoState()),
        stack_(filter.parens_),
        paren_id_(-1) { }

  FilterState Start() const {
    return FilterState(filter_.Start(), FilterState2(0));
  }

  void SetState(StateId s1, StateId s2, const FilterState &f) {
    f_ = f;
    filter_.SetState(s1, s2, f_.GetState1());
    if (!expand_)
      return;

    ssize_t paren_id = stack_.Top(f.GetState2().GetState());
    if (paren_id != paren_id_) {
      if (paren_id_ != -1) {
        GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second);
        GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second);
      }
      paren_id_ = paren_id;
      if (paren_id_ != -1) {
        GetMatcher1()->AddCloseParen(parens_[paren_id_].second);
        GetMatcher2()->AddCloseParen(parens_[paren_id_].second);
      }
    }
  }

  FilterState FilterArc(Arc *arc1, Arc *arc2) const {
    FilterState1 f1 = filter_.FilterArc(arc1, arc2);
    const FilterState2 &f2 = f_.GetState2();
    if (f1 == FilterState1::NoState())
      return FilterState::NoState();

    if (arc1->olabel == kNoLabel && arc2->ilabel) {         // arc2 parentheses
      if (keep_parens_) {
        arc1->ilabel = arc2->ilabel;
      } else if (arc2->ilabel) {
        arc2->olabel = arc1->ilabel;
      }
      return FilterParen(arc2->ilabel, f1, f2);
    } else if (arc2->ilabel == kNoLabel && arc1->olabel) {  // arc1 parentheses
      if (keep_parens_) {
        arc2->olabel = arc1->olabel;
      } else {
        arc1->ilabel = arc2->olabel;
      }
      return FilterParen(arc1->olabel, f1, f2);
    } else {
      return FilterState(f1, f2);
    }
  }

  void FilterFinal(Weight *w1, Weight *w2) const {
    if (f_.GetState2().GetState() != 0)
      *w1 = Weight::Zero();
    filter_.FilterFinal(w1, w2);
  }

  // Return resp matchers. Ownership stays with filter.
  Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); }
  Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); }

  uint64 Properties(uint64 iprops) const {
    uint64 oprops = filter_.Properties(iprops);
    return oprops & kILabelInvariantProperties & kOLabelInvariantProperties;
  }

 private:
  const FilterState FilterParen(Label label, const FilterState1 &f1,
                                const FilterState2 &f2) const {
    if (!expand_)
      return FilterState(f1, f2);

    StackId stack_id = stack_.Find(f2.GetState(), label);
    if (stack_id < 0) {
      return FilterState::NoState();
    } else {
      return FilterState(f1, FilterState2(stack_id));
    }
  }

  F filter_;
  vector<pair<Label, Label> > parens_;
  bool expand_;                    // Expands to FST
  bool keep_parens_;               // Retains parentheses in output
  FilterState f_;                  // Current filter state
  mutable ParenStack stack_;
  ssize_t paren_id_;
};

// Class to setup composition options for PDT composition.
// Default is for the PDT as the first composition argument.
template <class Arc, bool left_pdt = true>
class PdtComposeFstOptions : public
ComposeFstOptions<Arc,
                  ParenMatcher< Fst<Arc> >,
                  ParenFilter<AltSequenceComposeFilter<
                                ParenMatcher< Fst<Arc> > > > > {
 public:
  typedef typename Arc::Label Label;
  typedef ParenMatcher< Fst<Arc> > PdtMatcher;
  typedef ParenFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter;
  typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
  using COptions::matcher1;
  using COptions::matcher2;
  using COptions::filter;

  PdtComposeFstOptions(const Fst<Arc> &ifst1,
                    const vector<pair<Label, Label> > &parens,
                       const Fst<Arc> &ifst2, bool expand = false,
                       bool keep_parens = true) {
    matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenList);
    matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenLoop);

    filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
                           expand, keep_parens);
  }
};

// Class to setup composition options for PDT with FST composition.
// Specialization is for the FST as the first composition argument.
template <class Arc>
class PdtComposeFstOptions<Arc, false> : public
ComposeFstOptions<Arc,
                  ParenMatcher< Fst<Arc> >,
                  ParenFilter<SequenceComposeFilter<
                                ParenMatcher< Fst<Arc> > > > > {
 public:
  typedef typename Arc::Label Label;
  typedef ParenMatcher< Fst<Arc> > PdtMatcher;
  typedef ParenFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter;
  typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
  using COptions::matcher1;
  using COptions::matcher2;
  using COptions::filter;

  PdtComposeFstOptions(const Fst<Arc> &ifst1,
                       const Fst<Arc> &ifst2,
                       const vector<pair<Label, Label> > &parens,
                       bool expand = false, bool keep_parens = true) {
    matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenLoop);
    matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenList);

    filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
                           expand, keep_parens);
  }
};

enum PdtComposeFilter {
  PAREN_FILTER,          // Bar-Hillel construction; keeps parentheses
  EXPAND_FILTER,         // Bar-Hillel + expansion; removes parentheses
  EXPAND_PAREN_FILTER,   // Bar-Hillel + expansion; keeps parentheses
};

struct PdtComposeOptions {
  bool connect;  // Connect output
  PdtComposeFilter filter_type;  // Which pre-defined filter to use

  explicit PdtComposeOptions(bool c, PdtComposeFilter ft = PAREN_FILTER)
      : connect(c), filter_type(ft) {}
  PdtComposeOptions() : connect(true), filter_type(PAREN_FILTER) {}
};

// Composes pushdown transducer (PDT) encoded as an FST (1st arg) and
// an FST (2nd arg) with the result also a PDT encoded as an Fst. (3rd arg).
// In the PDTs, some transitions are labeled with open or close
// parentheses. To be interpreted as a PDT, the parens must balance on
// a path (see PdtExpand()). The open-close parenthesis label pairs
// are passed in 'parens'.
template <class Arc>
void Compose(const Fst<Arc> &ifst1,
             const vector<pair<typename Arc::Label,
                               typename Arc::Label> > &parens,
             const Fst<Arc> &ifst2,
             MutableFst<Arc> *ofst,
             const PdtComposeOptions &opts = PdtComposeOptions()) {
  bool expand = opts.filter_type != PAREN_FILTER;
  bool keep_parens = opts.filter_type != EXPAND_FILTER;
  PdtComposeFstOptions<Arc, true> copts(ifst1, parens, ifst2,
                                        expand, keep_parens);
  copts.gc_limit = 0;
  *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
  if (opts.connect)
    Connect(ofst);
}

// Composes an FST (1st arg) and pushdown transducer (PDT) encoded as
// an FST (2nd arg) with the result also a PDT encoded as an Fst (3rd arg).
// In the PDTs, some transitions are labeled with open or close
// parentheses. To be interpreted as a PDT, the parens must balance on
// a path (see ExpandFst()). The open-close parenthesis label pairs
// are passed in 'parens'.
template <class Arc>
void Compose(const Fst<Arc> &ifst1,
             const Fst<Arc> &ifst2,
             const vector<pair<typename Arc::Label,
                               typename Arc::Label> > &parens,
             MutableFst<Arc> *ofst,
             const PdtComposeOptions &opts = PdtComposeOptions()) {
  bool expand = opts.filter_type != PAREN_FILTER;
  bool keep_parens = opts.filter_type != EXPAND_FILTER;
  PdtComposeFstOptions<Arc, false> copts(ifst1, ifst2, parens,
                                         expand, keep_parens);
  copts.gc_limit = 0;
  *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
  if (opts.connect)
    Connect(ofst);
}

}  // namespace fst

#endif  // FST_EXTENSIONS_PDT_COMPOSE_H__