// rational.h
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
// \file
// An Fst implementation and base interface for delayed unions,
// concatenations and closures.

#ifndef FST_LIB_RATIONAL_H__
#define FST_LIB_RATIONAL_H__

#include "fst/lib/map.h"
#include "fst/lib/mutable-fst.h"
#include "fst/lib/replace.h"
#include "fst/lib/test-properties.h"

namespace fst {

// Options accepted by the RationalFst and RationalFstImpl constructors;
// currently just an alias for CacheOptions.
typedef CacheOptions RationalFstOptions;

// Selects the flavor of closure to build, i.e. whether the result also
// includes the empty string.
enum ClosureType {
  CLOSURE_STAR = 0,  // T*: the empty string is added
  CLOSURE_PLUS = 1   // T+: the empty string is not added
};

// Forward declarations of RationalFst and of the rational operations that
// take a RationalFst pointer; RationalFst names these functions as friends
// further down in this file, so they must be declared before the class.
template <class A> class RationalFst;
template <class A> void Union(RationalFst<A> *fst1, const Fst<A> &fst2);
template <class A> void Concat(RationalFst<A> *fst1, const Fst<A> &fst2);
template <class A> void Closure(RationalFst<A> *fst, ClosureType closure_type);


// Implementation class for delayed unions, concatenations and closures.
//
// A rational operation is described by a small "topology" machine, rfst_,
// whose arcs carry negative nonterminal labels (-1, -2, ...).  The topology
// machine is registered with the ReplaceFstImpl base under label 0 (which is
// also made the root), and every operand Fst is registered under the
// negative label that stands for it.
template<class A>
class RationalFstImpl : public ReplaceFstImpl<A> {
 public:
  using FstImpl<A>::SetType;
  using FstImpl<A>::SetProperties;
  using FstImpl<A>::Properties;
  using FstImpl<A>::SetInputSymbols;
  using FstImpl<A>::SetOutputSymbols;
  // Members of the dependent base are not found by unqualified lookup inside
  // a template, so each one used below is named explicitly.
  using ReplaceFstImpl<A>::SetRoot;
  using ReplaceFstImpl<A>::AddFst;
  using ReplaceFstImpl<A>::SetFst;

  typedef typename A::Weight Weight;
  typedef typename A::Label Label;

  // Creates an empty rational implementation; one of the Init* methods must
  // be called to give it a topology and operands.
  explicit RationalFstImpl(const RationalFstOptions &opts)
      : ReplaceFstImpl<A>(ReplaceFstOptions(opts, kNoLabel)),
        nonterminals_(0) {
    SetType("rational");
  }

  // Implementation of UnionFst(fst1,fst2).
  // Topology: two parallel arcs from the start state 0 to the final state 1,
  // labelled -1 (fst1) and -2 (fst2).  Symbol tables are taken from fst1.
  void InitUnion(const Fst<A> &fst1, const Fst<A> &fst2) {
    uint64 props1 = fst1.Properties(kFstProperties, false);
    uint64 props2 = fst2.Properties(kFstProperties, false);
    SetInputSymbols(fst1.InputSymbols());
    SetOutputSymbols(fst1.OutputSymbols());
    rfst_.AddState();
    rfst_.AddState();
    rfst_.SetStart(0);
    rfst_.SetFinal(1, Weight::One());
    rfst_.SetInputSymbols(fst1.InputSymbols());
    rfst_.SetOutputSymbols(fst1.OutputSymbols());
    nonterminals_ = 2;
    rfst_.AddArc(0, A(0, -1, Weight::One(), 1));
    rfst_.AddArc(0, A(0, -2, Weight::One(), 1));
    AddFst(0, &rfst_);
    AddFst(-1, &fst1);
    AddFst(-2, &fst2);
    SetRoot(0);
    SetProperties(UnionProperties(props1, props2, true), kCopyProperties);
  }

  // Implementation of ConcatFst(fst1,fst2).
  // Topology: a chain 0 -> 1 -> 2 with state 2 final; the first arc is
  // labelled -1 (fst1) and the second -2 (fst2).  Symbol tables are taken
  // from fst1.
  void InitConcat(const Fst<A> &fst1, const Fst<A> &fst2) {
    uint64 props1 = fst1.Properties(kFstProperties, false);
    uint64 props2 = fst2.Properties(kFstProperties, false);
    SetInputSymbols(fst1.InputSymbols());
    SetOutputSymbols(fst1.OutputSymbols());
    rfst_.AddState();
    rfst_.AddState();
    rfst_.AddState();
    rfst_.SetStart(0);
    rfst_.SetFinal(2, Weight::One());
    rfst_.SetInputSymbols(fst1.InputSymbols());
    rfst_.SetOutputSymbols(fst1.OutputSymbols());
    nonterminals_ = 2;
    rfst_.AddArc(0, A(0, -1, Weight::One(), 1));
    rfst_.AddArc(1, A(0, -2, Weight::One(), 2));
    AddFst(0, &rfst_);
    AddFst(-1, &fst1);
    AddFst(-2, &fst2);
    SetRoot(0);
    SetProperties(ConcatProperties(props1, props2, true), kCopyProperties);
  }

  // Implementation of ClosureFst(fst, closure_type).
  // CLOSURE_STAR: a single state that is both start and final, with a
  // self-loop labelled -1.  Otherwise: an arc labelled -1 from the start
  // state 0 to the final state 1, plus an arc with both labels 0 from 1 back
  // to 0.
  void InitClosure(const Fst<A> &fst, ClosureType closure_type) {
    uint64 props = fst.Properties(kFstProperties, false);
    SetInputSymbols(fst.InputSymbols());
    SetOutputSymbols(fst.OutputSymbols());
    if (closure_type == CLOSURE_STAR) {
      rfst_.AddState();
      rfst_.SetStart(0);
      rfst_.SetFinal(0, Weight::One());
      rfst_.AddArc(0, A(0, -1, Weight::One(), 0));
    } else {
      rfst_.AddState();
      rfst_.AddState();
      rfst_.SetStart(0);
      rfst_.SetFinal(1, Weight::One());
      rfst_.AddArc(0, A(0, -1, Weight::One(), 1));
      rfst_.AddArc(1, A(0, 0, Weight::One(), 0));
    }
    rfst_.SetInputSymbols(fst.InputSymbols());
    rfst_.SetOutputSymbols(fst.OutputSymbols());
    AddFst(0, &rfst_);
    AddFst(-1, &fst);
    SetRoot(0);
    nonterminals_ = 1;
    SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true),
                  kCopyProperties);
  }

  // Implementation of Union(Fst &, RationalFst *).
  // Unions the current topology with a two-state machine that calls a newly
  // allocated nonterminal, and registers 'fst' under that nonterminal.
  void AddUnion(const Fst<A> &fst) {
    uint64 props1 = Properties();
    uint64 props2 = fst.Properties(kFstProperties, false);
    // Labels -1 .. -nonterminals_ are already taken by earlier operands, so
    // advance the counter before using it as the new label.
    ++nonterminals_;
    VectorFst<A> afst;
    afst.AddState();
    afst.AddState();
    afst.SetStart(0);
    afst.SetFinal(1, Weight::One());
    afst.AddArc(0, A(0, -nonterminals_, Weight::One(), 1));
    Union(&rfst_, afst);
    // Re-register the modified topology under the root label.
    SetFst(0, &rfst_);
    // Bind the new nonterminal to the operand; without this the arc added
    // above would refer to a label with no Fst behind it.
    AddFst(-nonterminals_, &fst);
    SetProperties(UnionProperties(props1, props2, true), kCopyProperties);
  }

  // Implementation of Concat(Fst &, RationalFst *).
  // Concatenates the current topology with a two-state machine that calls a
  // newly allocated nonterminal, and registers 'fst' under that nonterminal.
  void AddConcat(const Fst<A> &fst) {
    uint64 props1 = Properties();
    uint64 props2 = fst.Properties(kFstProperties, false);
    // Labels -1 .. -nonterminals_ are already taken by earlier operands, so
    // advance the counter before using it as the new label.
    ++nonterminals_;
    VectorFst<A> afst;
    afst.AddState();
    afst.AddState();
    afst.SetStart(0);
    afst.SetFinal(1, Weight::One());
    afst.AddArc(0, A(0, -nonterminals_, Weight::One(), 1));
    Concat(&rfst_, afst);
    // Re-register the modified topology under the root label.
    SetFst(0, &rfst_);
    // Bind the new nonterminal to the operand; without this the arc added
    // above would refer to a label with no Fst behind it.
    AddFst(-nonterminals_, &fst);
    SetProperties(ConcatProperties(props1, props2, true), kCopyProperties);
  }

  // Implementation of Closure(RationalFst *, closure_type).
  // Applies the closure to the topology machine itself; no new operand (and
  // hence no new nonterminal) is introduced.
  void AddClosure(ClosureType closure_type) {
    uint64 props = Properties();
    Closure(&rfst_, closure_type);
    SetFst(0, &rfst_);
    SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true),
                  kCopyProperties);
  }

 private:
  VectorFst<A> rfst_;   // rational topology machine; uses neg. nonterminals
  Label nonterminals_;  // # of nonterminals used; labels are -1..-nonterminals_

  DISALLOW_EVIL_CONSTRUCTORS(RationalFstImpl);
};

// Common base of the delayed rational operations (delayed union,
// concatenation and closure).  It forwards the Fst interface to a shared
// RationalFstImpl and takes care of that implementation's reference count.
template <class A>
class RationalFst : public Fst<A> {
 public:
  friend class CacheStateIterator< RationalFst<A> >;
  friend class ArcIterator< RationalFst<A> >;
  friend class CacheArcIterator< RationalFst<A> >;
  friend void Union<>(RationalFst<A> *fst1, const Fst<A> &fst2);
  friend void Concat<>(RationalFst<A> *fst1, const Fst<A> &fst2);
  friend void Closure<>(RationalFst<A> *fst, ClosureType closure_type);

  typedef A Arc;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef CacheState<A> State;

  // The accessors below simply delegate to the implementation object.
  virtual StateId Start() const { return impl_->Start(); }

  virtual Weight Final(StateId s) const { return impl_->Final(s); }

  virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }

  virtual size_t NumInputEpsilons(StateId s) const {
    return impl_->NumInputEpsilons(s);
  }

  virtual size_t NumOutputEpsilons(StateId s) const {
    return impl_->NumOutputEpsilons(s);
  }

  // Returns the property bits selected by 'mask'.  When 'test' is false the
  // stored bits are returned as-is; otherwise TestProperties is run on this
  // Fst and its results are stored in the implementation first.
  virtual uint64 Properties(uint64 mask, bool test) const {
    if (!test)
      return impl_->Properties(mask);

    uint64 known;
    uint64 tested = TestProperties(*this, mask, &known);
    impl_->SetProperties(tested, known);
    return tested & mask;
  }

  virtual const string& Type() const { return impl_->Type(); }

  virtual const SymbolTable* InputSymbols() const {
    return impl_->InputSymbols();
  }

  virtual const SymbolTable* OutputSymbols() const {
    return impl_->OutputSymbols();
  }

  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;

  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    impl_->InitArcIterator(s, data);
  }

 protected:
  // Builds a fresh implementation with default options.
  RationalFst() : impl_(new RationalFstImpl<A>(RationalFstOptions())) {}

  // Builds a fresh implementation with the given options.
  explicit RationalFst(const RationalFstOptions &opts)
      : impl_(new RationalFstImpl<A>(opts)) {}

  // Copying shares the other object's implementation and bumps its
  // reference count.
  RationalFst(const RationalFst<A> &fst) : impl_(fst.impl_) {
    impl_->IncrRefCount();
  }

  // Drops one reference; the implementation is deleted when DecrRefCount()
  // returns zero.
  virtual ~RationalFst() {
    if (!impl_->DecrRefCount())
      delete impl_;
  }

  RationalFstImpl<A> *Impl() { return impl_; }

 private:
  RationalFstImpl<A> *impl_;

  void operator=(const RationalFst<A> &fst);  // disallow
};

// StateIterator specialization for RationalFst: all of the work is done by
// the CacheStateIterator base class.
template <class A>
class StateIterator< RationalFst<A> >
    : public CacheStateIterator< RationalFst<A> > {
  typedef CacheStateIterator< RationalFst<A> > Base;

 public:
  explicit StateIterator(const RationalFst<A> &fst) : Base(fst) {}
};

// ArcIterator specialization for RationalFst.  After the CacheArcIterator
// base has been constructed, state 's' is expanded if the implementation
// does not report having its arcs yet.
template <class A>
class ArcIterator< RationalFst<A> >
    : public CacheArcIterator< RationalFst<A> > {
 public:
  typedef typename A::StateId StateId;

  ArcIterator(const RationalFst<A> &fst, StateId s)
      : CacheArcIterator< RationalFst<A> >(fst, s) {
    RationalFstImpl<A> *impl = fst.impl_;
    if (impl->HasArcs(s))
      return;
    impl->Expand(s);
  }

 private:
  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
};

// Fills in 'data' with a newly allocated StateIterator over this Fst.
template <class A>
inline void RationalFst<A>::InitStateIterator(
    StateIteratorData<A> *data) const {
  typedef StateIterator< RationalFst<A> > Iterator;
  data->base = new Iterator(*this);
}

}  // namespace fst

#endif  // FST_LIB_RATIONAL_H__