// relabel.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: johans@google.com (Johan Schalkwyk)
//
// \file
// Functions and classes to relabel the input and/or output labels of an FST.
//
#ifndef FST_LIB_RELABEL_H__
#define FST_LIB_RELABEL_H__

#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <string>
#include <utility>
using std::pair; using std::make_pair;
#include <vector>
using std::vector;

#include <fst/cache.h>
#include <fst/test-properties.h>


#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;

namespace fst {

//
// Relabels the input and/or output labels of a mutable FST in place. The
// old-to-new label mappings are given as vectors of (old, new) label pairs.
// Any label that is not the first element of some pair is left unchanged
// (identity mapping). If the same old label occurs in several pairs, the
// last pair wins.
//
// If a label that occurs on an arc is mapped to kNoLabel, an error is
// logged, relabeling stops and kError is set on the FST. Arcs visited before
// the error keep their new labels. The label-dependent properties are
// invalidated in that case as well, since the FST may already be modified.
//
// \param fst input fst, must be mutable
// \param ipairs vector of input label pairs indicating old to new mapping
// \param opairs vector of output label pairs indicating old to new mapping
//
template <class A>
void Relabel(
    MutableFst<A> *fst,
    const vector<pair<typename A::Label, typename A::Label> >& ipairs,
    const vector<pair<typename A::Label, typename A::Label> >& opairs) {
  typedef typename A::StateId StateId;
  typedef typename A::Label   Label;
  typedef typename unordered_map<Label, Label>::iterator MapIterator;

  uint64 props = fst->Properties(kFstProperties, false);

  // Build old -> new label hashes; later pairs overwrite earlier ones.
  unordered_map<Label, Label> input_map;
  for (size_t i = 0; i < ipairs.size(); ++i)
    input_map[ipairs[i].first] = ipairs[i].second;

  unordered_map<Label, Label> output_map;
  for (size_t i = 0; i < opairs.size(); ++i)
    output_map[opairs[i].first] = opairs[i].second;

  // Set when a label maps to kNoLabel; stops both loops.
  bool error = false;
  for (StateIterator<MutableFst<A> > siter(*fst);
       !error && !siter.Done(); siter.Next()) {
    StateId s = siter.Value();
    for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
         !aiter.Done(); aiter.Next()) {
      A arc = aiter.Value();

      // Relabel input, only if a relabel pair is defined for this label.
      MapIterator it = input_map.find(arc.ilabel);
      if (it != input_map.end()) {
        if (it->second == kNoLabel) {
          FSTERROR() << "Input symbol id " << arc.ilabel
                     << " missing from target vocabulary";
          error = true;
          break;
        }
        arc.ilabel = it->second;
      }

      // Relabel output, only if a relabel pair is defined for this label.
      it = output_map.find(arc.olabel);
      if (it != output_map.end()) {
        if (it->second == kNoLabel) {
          FSTERROR() << "Output symbol id " << arc.olabel
                     << " missing from target vocabulary";
          error = true;
          break;
        }
        arc.olabel = it->second;
      }

      aiter.SetValue(arc);
    }
  }

  // Even on error some arcs may have been rewritten, so the label-dependent
  // properties recorded before the loop can no longer be trusted.
  fst->SetProperties(RelabelProperties(props), kFstProperties);
  if (error)
    fst->SetProperties(kError, kError);
}

//
// Relabels a mutable FST in place so that its labels agree with the given
// symbol tables. The symbol tables currently attached to the FST act as the
// "old" tables. Each old label is mapped, via its symbol string, to the
// label of the same symbol in new_isymbols / new_osymbols. A side (input or
// output) is relabeled, and its new table attached, only when both the
// FST's table and the new table for that side are non-NULL.
//
// \param fst input fst, must be mutable
// \param new_isymbols symbol set indicating new mapping of input symbols
// \param new_osymbols symbol set indicating new mapping of output symbols
//
template<class A>
void Relabel(MutableFst<A> *fst,
             const SymbolTable* new_isymbols,
             const SymbolTable* new_osymbols) {
  // The FST's own tables are the source side of the mapping.
  const SymbolTable *current_isymbols = fst->InputSymbols();
  const SymbolTable *current_osymbols = fst->OutputSymbols();
  // Replace the FST's tables with the new ones after relabeling.
  const bool attach_new = true;
  Relabel(fst,
          current_isymbols, new_isymbols, attach_new,
          current_osymbols, new_osymbols, attach_new);
}

//
// Relabels a mutable FST in place by mapping labels between symbol tables.
// For every symbol in old_isymbols, the label it has in old_isymbols is
// mapped to the label the same symbol has in new_isymbols. Output labels
// are handled the same way with old_osymbols / new_osymbols. A side is
// relabeled only when both its old and new tables are non-NULL. In that
// case the new table is attached to the FST if the matching attach flag is
// set.
//
// Symbols missing from the new table are expected to map to kNoLabel. This
// assumes SymbolTable::Find() returns -1 for unknown symbols; TODO confirm.
// The pair-based Relabel() reports such a mapping as an error when the label
// actually occurs on an arc.
//
// \param fst input fst, must be mutable
// \param old_isymbols symbol table the current input labels refer to
// \param new_isymbols symbol table the input labels should refer to
// \param attach_new_isymbols whether to set new_isymbols on the fst
// \param old_osymbols symbol table the current output labels refer to
// \param new_osymbols symbol table the output labels should refer to
// \param attach_new_osymbols whether to set new_osymbols on the fst
//
template<class A>
void Relabel(MutableFst<A> *fst,
             const SymbolTable* old_isymbols,
             const SymbolTable* new_isymbols,
             bool attach_new_isymbols,
             const SymbolTable* old_osymbols,
             const SymbolTable* new_osymbols,
             bool attach_new_osymbols) {
  typedef typename A::Label Label;

  vector<pair<Label, Label> > ipairs;
  if (old_isymbols && new_isymbols) {
    for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
         syms_iter.Next()) {
      // Keep ids in the arc's Label type rather than int, so wide label
      // types are not truncated.
      Label old_label = syms_iter.Value();
      Label new_label = new_isymbols->Find(syms_iter.Symbol());
      ipairs.push_back(make_pair(old_label, new_label));
    }
    if (attach_new_isymbols)
      fst->SetInputSymbols(new_isymbols);
  }

  vector<pair<Label, Label> > opairs;
  if (old_osymbols && new_osymbols) {
    for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
         syms_iter.Next()) {
      Label old_label = syms_iter.Value();
      Label new_label = new_osymbols->Find(syms_iter.Symbol());
      opairs.push_back(make_pair(old_label, new_label));
    }
    if (attach_new_osymbols)
      fst->SetOutputSymbols(new_osymbols);
  }

  // Delegate to the label-pair version, which performs the actual rewrite.
  Relabel(fst, ipairs, opairs);
}


// RelabelFst accepts exactly the cache options (CacheOptions) of the cache
// layer its implementation derives from.
typedef CacheOptions RelabelFstOptions;

// Forward declaration: RelabelFstImpl below names RelabelFst<A> in a friend
// declaration before RelabelFst itself is defined.
template <class A> class RelabelFst;

//
// \class RelabelFstImpl
// \brief Implementation for delayed relabeling
//
// Relabels an FST from one label (or symbol) set to another, on the input
// side, the output side, or both. RelabelFst implements a delayed version of
// Relabel(). The arcs of a state are relabeled on demand, the first time
// that state's arcs are requested (Expand). They are then stored in the
// cache (CacheImpl), so later requests for the same state are served from
// the cache rather than recomputed.
//
// Unlike the eager Relabel(), an old label mapped to kNoLabel does not abort
// anything. The arc keeps the kNoLabel label, an error is logged, and kError
// is set on this FST.
//
template<class A>
class RelabelFstImpl : public CacheImpl<A> {
  // The state iterator walks fst_ directly.
  friend class StateIterator< RelabelFst<A> >;
 public:
  using FstImpl<A>::SetType;
  using FstImpl<A>::SetProperties;
  using FstImpl<A>::WriteHeader;
  using FstImpl<A>::SetInputSymbols;
  using FstImpl<A>::SetOutputSymbols;

  using CacheImpl<A>::PushArc;
  using CacheImpl<A>::HasArcs;
  using CacheImpl<A>::HasFinal;
  using CacheImpl<A>::HasStart;
  using CacheImpl<A>::SetArcs;
  using CacheImpl<A>::SetFinal;
  using CacheImpl<A>::SetStart;

  typedef A Arc;
  typedef typename A::Label   Label;
  typedef typename A::Weight  Weight;
  typedef typename A::StateId StateId;
  typedef CacheState<A> State;

  // Builds the relabel maps from explicit (old, new) label pairs. A side is
  // relabeled only when its pair vector is non-empty. When an old label is
  // repeated, the last pair wins. This constructor does not set any symbol
  // tables.
  RelabelFstImpl(const Fst<A>& fst,
                 const vector<pair<Label, Label> >& ipairs,
                 const vector<pair<Label, Label> >& opairs,
                 const RelabelFstOptions &opts)
      : CacheImpl<A>(opts), fst_(fst.Copy()),
        relabel_input_(false), relabel_output_(false) {
    uint64 props = fst.Properties(kCopyProperties, false);
    SetProperties(RelabelProperties(props));
    SetType("relabel");

    // Create the input label map. Assignment (not range insert) keeps the
    // "last pair wins" semantics.
    if (ipairs.size() > 0) {
      for (size_t i = 0; i < ipairs.size(); ++i) {
        input_map_[ipairs[i].first] = ipairs[i].second;
      }
      relabel_input_ = true;
    }

    // Create the output label map.
    if (opairs.size() > 0) {
      for (size_t i = 0; i < opairs.size(); ++i) {
        output_map_[opairs[i].first] = opairs[i].second;
      }
      relabel_output_ = true;
    }
  }

  // Builds the relabel maps from symbol tables. For each symbol of an old
  // table, its old label is mapped to that symbol's label in the new table.
  // A side is relabeled, and its new table attached, only when both tables
  // are given and their labeled checksums differ. Otherwise the old table
  // (possibly NULL) stays attached.
  RelabelFstImpl(const Fst<A>& fst,
                 const SymbolTable* old_isymbols,
                 const SymbolTable* new_isymbols,
                 const SymbolTable* old_osymbols,
                 const SymbolTable* new_osymbols,
                 const RelabelFstOptions &opts)
      : CacheImpl<A>(opts), fst_(fst.Copy()),
        relabel_input_(false), relabel_output_(false) {
    SetType("relabel");

    uint64 props = fst.Properties(kCopyProperties, false);
    SetProperties(RelabelProperties(props));
    SetInputSymbols(old_isymbols);
    SetOutputSymbols(old_osymbols);

    if (old_isymbols && new_isymbols &&
        old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
      for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
           syms_iter.Next()) {
        input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
      }
      SetInputSymbols(new_isymbols);
      relabel_input_ = true;
    }

    if (old_osymbols && new_osymbols &&
        old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
      for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
           syms_iter.Next()) {
        output_map_[syms_iter.Value()] =
          new_osymbols->Find(syms_iter.Symbol());
      }
      SetOutputSymbols(new_osymbols);
      relabel_output_ = true;
    }
  }

  // Copy constructor; used when copying a RelabelFst. The wrapped FST is
  // copied with Copy(true).
  RelabelFstImpl(const RelabelFstImpl<A>& impl)
      : CacheImpl<A>(impl),
        fst_(impl.fst_->Copy(true)),
        input_map_(impl.input_map_),
        output_map_(impl.output_map_),
        relabel_input_(impl.relabel_input_),
        relabel_output_(impl.relabel_output_) {
    SetType("relabel");
    SetProperties(impl.Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
  }

  ~RelabelFstImpl() { delete fst_; }

  // Relabeling does not change the start state; it is taken from fst_.
  StateId Start() {
    if (!HasStart()) {
      StateId s = fst_->Start();
      SetStart(s);
    }
    return CacheImpl<A>::Start();
  }

  // Relabeling does not change final weights; they are taken from fst_.
  Weight Final(StateId s) {
    if (!HasFinal(s)) {
      SetFinal(s, fst_->Final(s));
    }
    return CacheImpl<A>::Final(s);
  }

  size_t NumArcs(StateId s) {
    if (!HasArcs(s)) {
      Expand(s);
    }
    return CacheImpl<A>::NumArcs(s);
  }

  size_t NumInputEpsilons(StateId s) {
    if (!HasArcs(s)) {
      Expand(s);
    }
    return CacheImpl<A>::NumInputEpsilons(s);
  }

  size_t NumOutputEpsilons(StateId s) {
    if (!HasArcs(s)) {
      Expand(s);
    }
    return CacheImpl<A>::NumOutputEpsilons(s);
  }

  uint64 Properties() const { return Properties(kFstProperties); }

  // Set error if found; return FST impl properties.
  uint64 Properties(uint64 mask) const {
    if ((mask & kError) && fst_->Properties(kError, false))
      SetProperties(kError, kError);
    return FstImpl<Arc>::Properties(mask);
  }

  void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
    if (!HasArcs(s)) {
      Expand(s);
    }
    CacheImpl<A>::InitArcIterator(s, data);
  }

  // Computes the relabeled arcs of state s from fst_ and stores them in the
  // cache. Labels with no entry in the relevant map are copied unchanged.
  // A label mapped to kNoLabel is still written, but it is reported and
  // kError is set. This matches the eager Relabel(), which also treats it as
  // an error.
  void Expand(StateId s) {
    for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
      A arc = aiter.Value();

      // relabel input
      if (relabel_input_) {
        typename unordered_map<Label, Label>::iterator it =
          input_map_.find(arc.ilabel);
        if (it != input_map_.end()) {
          if (it->second == kNoLabel) {
            FSTERROR() << "RelabelFst: Input symbol id " << arc.ilabel
                       << " missing from target vocabulary";
            SetProperties(kError, kError);
          }
          arc.ilabel = it->second;
        }
      }

      // relabel output
      if (relabel_output_) {
        typename unordered_map<Label, Label>::iterator it =
          output_map_.find(arc.olabel);
        if (it != output_map_.end()) {
          if (it->second == kNoLabel) {
            FSTERROR() << "RelabelFst: Output symbol id " << arc.olabel
                       << " missing from target vocabulary";
            SetProperties(kError, kError);
          }
          arc.olabel = it->second;
        }
      }

      PushArc(s, arc);
    }
    SetArcs(s);
  }


 private:
  const Fst<A> *fst_;  // Wrapped FST; owned (deleted in the destructor).

  unordered_map<Label, Label> input_map_;   // old -> new input labels
  unordered_map<Label, Label> output_map_;  // old -> new output labels
  bool relabel_input_;   // whether input_map_ is applied
  bool relabel_output_;  // whether output_map_ is applied

  void operator=(const RelabelFstImpl<A> &);  // disallow
};


//
// \class RelabelFst
// \brief Delayed (on-the-fly) relabeling of arc input and/or output labels
//
// RelabelFst is a handle around a RelabelFstImpl; ImplToFst supplies the
// reference counting and forwards most Fst methods to the implementation.
// It can be built from explicit (old, new) label pair vectors or from symbol
// tables. The symbol-table constructors that take no old tables use the
// source FST's InputSymbols()/OutputSymbols() as the old tables.
template <class A>
class RelabelFst : public ImplToFst< RelabelFstImpl<A> > {
 public:
  friend class StateIterator< RelabelFst<A> >;
  friend class ArcIterator< RelabelFst<A> >;

  typedef A Arc;
  typedef RelabelFstImpl<A> Impl;
  typedef CacheState<A> State;
  typedef typename A::StateId StateId;
  typedef typename A::Weight  Weight;
  typedef typename A::Label   Label;

  // Relabel using explicit (old, new) label pairs, default cache options.
  RelabelFst(const Fst<A>& fst,
             const vector<pair<Label, Label> >& ipairs,
             const vector<pair<Label, Label> >& opairs)
      : ImplToFst<Impl>(
            new Impl(fst, ipairs, opairs, RelabelFstOptions())) {}

  // Relabel using explicit (old, new) label pairs, given cache options.
  RelabelFst(const Fst<A>& fst,
             const vector<pair<Label, Label> >& ipairs,
             const vector<pair<Label, Label> >& opairs,
             const RelabelFstOptions &opts)
      : ImplToFst<Impl>(
            new Impl(fst, ipairs, opairs, opts)) {}

  // Relabel from the FST's own symbol tables to new ones, default options.
  RelabelFst(const Fst<A>& fst,
             const SymbolTable* new_isymbols,
             const SymbolTable* new_osymbols)
      : ImplToFst<Impl>(
            new Impl(fst,
                     fst.InputSymbols(), new_isymbols,
                     fst.OutputSymbols(), new_osymbols,
                     RelabelFstOptions())) {}

  // Relabel from the FST's own symbol tables to new ones, given options.
  RelabelFst(const Fst<A>& fst,
             const SymbolTable* new_isymbols,
             const SymbolTable* new_osymbols,
             const RelabelFstOptions &opts)
      : ImplToFst<Impl>(
            new Impl(fst,
                     fst.InputSymbols(), new_isymbols,
                     fst.OutputSymbols(), new_osymbols,
                     opts)) {}

  // Relabel between explicitly given old and new tables, default options.
  RelabelFst(const Fst<A>& fst,
             const SymbolTable* old_isymbols,
             const SymbolTable* new_isymbols,
             const SymbolTable* old_osymbols,
             const SymbolTable* new_osymbols)
      : ImplToFst<Impl>(
            new Impl(fst,
                     old_isymbols, new_isymbols,
                     old_osymbols, new_osymbols,
                     RelabelFstOptions())) {}

  // Relabel between explicitly given old and new tables, given options.
  RelabelFst(const Fst<A>& fst,
             const SymbolTable* old_isymbols,
             const SymbolTable* new_isymbols,
             const SymbolTable* old_osymbols,
             const SymbolTable* new_osymbols,
             const RelabelFstOptions &opts)
      : ImplToFst<Impl>(
            new Impl(fst,
                     old_isymbols, new_isymbols,
                     old_osymbols, new_osymbols,
                     opts)) {}

  // Copy constructor; semantics of `safe` as in Fst<>::Copy().
  RelabelFst(const RelabelFst<A> &fst, bool safe = false)
      : ImplToFst<Impl>(fst, safe) {}

  // Returns a newly allocated copy; see Fst<>::Copy().
  virtual RelabelFst<A> *Copy(bool safe = false) const {
    return new RelabelFst<A>(*this, safe);
  }

  virtual void InitStateIterator(StateIteratorData<A> *data) const;

  // Forwards to the implementation, which expands state s if not cached.
  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    GetImpl()->InitArcIterator(s, data);
  }

 private:
  // Re-exposes ImplToFst::GetImpl so the friend iterators can reach it.
  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }

  void operator=(const RelabelFst<A> &fst);  // disallow
};

// StateIterator specialization for RelabelFst. Relabeling never alters the
// set of states, so this walks the wrapped FST's states directly and does
// not expand or cache anything. It reports ids 0, 1, 2, ..., one per state
// of the wrapped FST. NOTE(review): this assumes the wrapped FST's states
// are numbered densely from 0 in iteration order.
template<class A>
class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
 public:
  typedef typename A::StateId StateId;

  explicit StateIterator(const RelabelFst<A> &fst)
      : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}

  bool Done() const { return siter_.Done(); }

  StateId Value() const { return s_; }

  void Next() {
    if (siter_.Done())
      return;
    siter_.Next();
    ++s_;
  }

  void Reset() {
    siter_.Reset();
    s_ = 0;
  }

 private:
  // StateIteratorBase hooks, forwarding to the public methods above.
  bool Done_() const { return Done(); }
  StateId Value_() const { return Value(); }
  void Next_() { Next(); }
  void Reset_() { Reset(); }

  // Declaration order matters: siter_ is initialized from impl_.
  const RelabelFstImpl<A> *impl_;
  StateIterator< Fst<A> > siter_;
  StateId s_;  // id reported for the current state

  DISALLOW_COPY_AND_ASSIGN(StateIterator);
};


// ArcIterator specialization for RelabelFst. It iterates over the cached
// arcs of state s and first computes (expands) them if the cache does not
// hold them yet.
template <class A>
class ArcIterator< RelabelFst<A> >
    : public CacheArcIterator< RelabelFst<A> > {
 public:
  typedef typename A::StateId StateId;

  ArcIterator(const RelabelFst<A> &fst, StateId s)
      : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) {
    RelabelFstImpl<A> *impl = fst.GetImpl();
    if (!impl->HasArcs(s))
      impl->Expand(s);
  }

 private:
  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
};

// Installs a RelabelFst-specific state iterator into data->base.
template <class A> inline
void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
  StateIterator< RelabelFst<A> > *siter =
      new StateIterator< RelabelFst<A> >(*this);
  data->base = siter;
}

// Shorthand for a RelabelFst over StdArc.
typedef RelabelFst<StdArc> StdRelabelFst;

}  // namespace fst

#endif  // FST_LIB_RELABEL_H__