// prune.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)
//
// \file
// Functions implementing pruning.

#ifndef FST_LIB_PRUNE_H__
#define FST_LIB_PRUNE_H__

#include <vector>
using std::vector;

#include <fst/arcfilter.h>
#include <fst/heap.h>
#include <fst/shortest-distance.h>


namespace fst {

// Options for the Prune() algorithms.
//
// Template parameters:
//   A:         arc type; supplies the Weight and StateId types.
//   ArcFilter: predicate over arcs; arcs for which 'filter(arc)' is false
//              are skipped by Prune().
template <class A, class ArcFilter>
class PruneOptions {
 public:
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;

  // Pruning weight threshold: Prune() keeps paths whose weight is no more
  // than the shortest-path weight Times() this value.
  Weight weight_threshold;
  // Pruning state threshold: upper bound on the number of states kept;
  // Prune() treats kNoStateId as "no bound".
  StateId state_threshold;
  // Arc filter.
  ArcFilter filter;
  // If non-zero, passes in pre-computed shortest distance to final states.
  // The vector is only read and is not copied, so it has to stay alive for
  // as long as these options are in use.
  const vector<Weight> *distance;
  // Determines the degree of convergence required when computing shortest
  // distances.
  float delta;

  // w: weight threshold, s: state threshold, f: arc filter,
  // d: optional pre-computed shortest distances to final states (may be 0),
  // e: convergence delta for the shortest-distance computation.
  //
  // 'd' is taken as a pointer-to-const because the distances are never
  // modified; this lets callers pass a const vector as well as a mutable one.
  explicit PruneOptions(const Weight& w, StateId s, ArcFilter f,
                        const vector<Weight> *d = 0, float e = kDelta)
      : weight_threshold(w),
        state_threshold(s),
        filter(f),
        distance(d),
        delta(e) {}
 private:
  PruneOptions();  // disallow
};


// Strict ordering on states used as the heap comparator by Prune().
// A state's key is Times(idistance[s], fdistance[s]); a state lying outside
// either vector contributes Weight::Zero() for that factor. Keys are
// compared with NaturalLess. Both vectors are held by reference, so they
// must outlive the comparator and later changes to them are observed.
template <class S, class W>
class PruneCompare {
 public:
  typedef S StateId;
  typedef W Weight;

  PruneCompare(const vector<Weight> &idistance,
               const vector<Weight> &fdistance)
      : idistance_(idistance), fdistance_(fdistance) {}

  // Returns true iff the key of 'x' is naturally less than the key of 'y'.
  bool operator()(const StateId x, const StateId y) const {
    const Weight key_x = Key(x);
    const Weight key_y = Key(y);
    return less_(key_x, key_y);
  }

 private:
  // Bounds-checked read: distance[s] when in range, Weight::Zero() otherwise.
  static Weight Lookup(const vector<Weight> &distance, const StateId s) {
    if (s < distance.size()) return distance[s];
    return Weight::Zero();
  }

  // Product of the two distances recorded for state 's'.
  Weight Key(const StateId s) const {
    return Times(Lookup(idistance_, s), Lookup(fdistance_, s));
  }

  const vector<Weight> &idistance_;
  const vector<Weight> &fdistance_;
  NaturalLess<Weight> less_;
};



// Pruning algorithm: this version modifies its input and it takes an
// options class as an argument. Deletes states and arcs in 'fst' that
// do not belong to a successful path whose weight is no more than
// the weight of the shortest path Times() 'opts.weight_threshold'.
// When 'opts.state_threshold != kNoStateId', the resulting transducer
// will be restricted further to have at most 'opts.state_threshold'
// states. Weights need to be commutative and have the path
// property. The weight 'w' of any cycle needs to be bounded, i.e.,
// 'Plus(w, W::One()) = One()'.
//
// If 'opts.distance' is non-zero it is used as the shortest distance to
// the final states; otherwise that distance is computed here with
// 'opts.delta'. Arcs rejected by 'opts.filter' are neither followed nor
// modified. If the weight type lacks the required properties, an error is
// logged, kError is set on 'fst' and nothing else is changed.
template <class Arc, class ArcFilter>
void Prune(MutableFst<Arc> *fst,
           const PruneOptions<Arc, ArcFilter> &opts) {
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;

  if ((Weight::Properties() & (kPath | kCommutative))
      != (kPath | kCommutative)) {
    FSTERROR() << "Prune: Weight needs to have the path property and"
               << " be commutative: "
               << Weight::Type();
    fst->SetProperties(kError, kError);
    return;
  }
  StateId ns = fst->NumStates();
  if (ns == 0) return;

  // Without a start state, or with a state budget of zero, nothing can be
  // kept; decide that explicitly and before computing any distances.
  const StateId start = fst->Start();
  if ((start == kNoStateId) || (opts.state_threshold == 0)) {
    fst->DeleteStates();
    return;
  }

  // idistance[s]: best known weight from the start state to 's'.
  vector<Weight> idistance(ns, Weight::Zero());
  vector<Weight> tmp;
  if (!opts.distance) {
    tmp.reserve(ns);
    ShortestDistance(*fst, &tmp, true, opts.delta);
  }
  // fdistance[s]: shortest distance from 's' to the final states.
  const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp;

  // The start state cannot reach a final state: the result is empty.
  if ((fdistance->size() <= static_cast<size_t>(start)) ||
      ((*fdistance)[start] == Weight::Zero())) {
    fst->DeleteStates();
    return;
  }
  PruneCompare<StateId, Weight> compare(idistance, *fdistance);
  Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare);
  vector<bool> visited(ns, false);
  // Heap key of each queued state, kNoKey when the state is not queued.
  vector<size_t> enqueued(ns, kNoKey);
  // dead[0] is a fresh sink state: arcs that exceed the limit are pointed
  // at it and disappear when the dead states are deleted at the end.
  vector<StateId> dead;
  dead.push_back(fst->AddState());
  NaturalLess<Weight> less;
  const Weight limit = Times((*fdistance)[start], opts.weight_threshold);

  StateId num_visited = 0;
  StateId s = start;
  if (!less(limit, (*fdistance)[s])) {
    idistance[s] = Weight::One();
    enqueued[s] = heap.Insert(s);
    ++num_visited;
  }

  while (!heap.Empty()) {
    s = heap.Top();
    heap.Pop();
    enqueued[s] = kNoKey;
    visited[s] = true;
    // Ending a path here would exceed the limit: make 's' non-final.
    if (less(limit, Times(idistance[s], fst->Final(s))))
      fst->SetFinal(s, Weight::Zero());
    for (MutableArcIterator< MutableFst<Arc> > ait(fst, s);
         !ait.Done();
         ait.Next()) {
      Arc arc = ait.Value();
      if (!opts.filter(arc)) continue;
      // Weight of the best complete path known to use this arc.
      Weight weight = Times(Times(idistance[s], arc.weight),
                            arc.nextstate < fdistance->size()
                            ? (*fdistance)[arc.nextstate]
                            : Weight::Zero());
      if (less(limit, weight)) {
        arc.nextstate = dead[0];
        ait.SetValue(arc);
        continue;
      }
      if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate]))
        idistance[arc.nextstate] = Times(idistance[s], arc.weight);
      if (visited[arc.nextstate]) continue;
      if (enqueued[arc.nextstate] != kNoKey) {
        // Already queued: its key may just have improved, so the heap must
        // be re-ordered even when the state budget is exhausted. Skipping
        // this would leave the heap inconsistent with its comparator.
        heap.Update(enqueued[arc.nextstate], arc.nextstate);
        continue;
      }
      // The state budget only limits the admission of new states.
      if ((opts.state_threshold != kNoStateId) &&
          (num_visited >= opts.state_threshold))
        continue;
      enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
      ++num_visited;
    }
  }
  // Every original state that was never dequeued is removed together with
  // the sink state.
  for (size_t i = 0; i < visited.size(); ++i)
    if (!visited[i]) dead.push_back(i);
  fst->DeleteStates(dead);
}


// Pruning algorithm: this version modifies its input and simply takes
// the pruning threshold as an argument. Deletes states and arcs in
// 'fst' that do not belong to a successful path whose weight is no
// more than the weight of the shortest path Times()
// 'weight_threshold'.  When 'state_threshold != kNoStateId', the
// resulting transducer will be restricted further to have at most
// 'state_threshold' states. Weights need to be commutative and
// have the path property. The weight 'w' of any cycle needs to be
// bounded, i.e., 'Plus(w, W::One()) = One()'.
//
// Every arc is considered (AnyArcFilter) and the shortest distances are
// computed internally using 'delta'.
template <class Arc>
void Prune(MutableFst<Arc> *fst,
           typename Arc::Weight weight_threshold,
           typename Arc::StateId state_threshold = kNoStateId,
           double delta = kDelta) {
  typedef AnyArcFilter<Arc> Filter;
  typedef PruneOptions<Arc, Filter> Options;

  // PruneOptions stores delta as a float; spell out the narrowing that
  // would otherwise happen implicitly at the constructor call.
  const Options opts(weight_threshold, state_threshold, Filter(), 0,
                     static_cast<float>(delta));
  Prune(fst, opts);
}


// Pruning algorithm: this version writes the pruned input Fst to an
// output MutableFst and it takes an options class as an argument.
// 'ofst' contains states and arcs that belong to a successful path in
// 'ifst' whose weight is no more than the weight of the shortest path
// Times() 'opts.weight_threshold'. When 'opts.state_threshold !=
// kNoStateId', 'ofst' will be restricted further to have at most
// 'opts.state_threshold' states. Weights need to be commutative and
// have the path property. The weight 'w' of any cycle needs to be
// bounded, i.e., 'Plus(w, W::One()) = One()'.
template <class Arc, class ArcFilter>
void Prune(const Fst<Arc> &ifst,
           MutableFst<Arc> *ofst,
           const PruneOptions<Arc, ArcFilter> &opts) {
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;

  if ((Weight::Properties() & (kPath | kCommutative))
      != (kPath | kCommutative)) {
    FSTERROR() << "Prune: Weight needs to have the path property and"
               << " be commutative: "
               << Weight::Type();
    ofst->SetProperties(kError, kError);
    return;
  }
  ofst->DeleteStates();
  ofst->SetInputSymbols(ifst.InputSymbols());
  ofst->SetOutputSymbols(ifst.OutputSymbols());
  if (ifst.Start() == kNoStateId)
    return;
  NaturalLess<Weight> less;
  if (less(opts.weight_threshold, Weight::One()) ||
      (opts.state_threshold == 0))
    return;
  vector<Weight> idistance;
  vector<Weight> tmp;
  if (!opts.distance)
    ShortestDistance(ifst, &tmp, true, opts.delta);
  const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp;

  if ((fdistance->size() <= ifst.Start()) ||
      ((*fdistance)[ifst.Start()] == Weight::Zero())) {
    return;
  }
  PruneCompare<StateId, Weight> compare(idistance, *fdistance);
  Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare);
  vector<StateId> copy;
  vector<size_t> enqueued;
  vector<bool> visited;

  StateId s = ifst.Start();
  Weight limit = Times(s < fdistance->size() ? (*fdistance)[s] : Weight::Zero(),
                         opts.weight_threshold);
  while (copy.size() <= s)
    copy.push_back(kNoStateId);
  copy[s] = ofst->AddState();
  ofst->SetStart(copy[s]);
  while (idistance.size() <= s)
    idistance.push_back(Weight::Zero());
  idistance[s] = Weight::One();
  while (enqueued.size() <= s) {
    enqueued.push_back(kNoKey);
    visited.push_back(false);
  }
  enqueued[s] = heap.Insert(s);

  while (!heap.Empty()) {
    s = heap.Top();
    heap.Pop();
    enqueued[s] = kNoKey;
    visited[s] = true;
    if (!less(limit, Times(idistance[s], ifst.Final(s))))
      ofst->SetFinal(copy[s], ifst.Final(s));
    for (ArcIterator< Fst<Arc> > ait(ifst, s);
         !ait.Done();
         ait.Next()) {
      const Arc &arc = ait.Value();
      if (!opts.filter(arc)) continue;
      Weight weight = Times(Times(idistance[s], arc.weight),
                            arc.nextstate < fdistance->size()
                            ? (*fdistance)[arc.nextstate]
                            : Weight::Zero());
      if (less(limit, weight)) continue;
      if ((opts.state_threshold != kNoStateId) &&
          (ofst->NumStates() >= opts.state_threshold))
        continue;
      while (idistance.size() <= arc.nextstate)
        idistance.push_back(Weight::Zero());
      if (less(Times(idistance[s], arc.weight),
               idistance[arc.nextstate]))
        idistance[arc.nextstate] = Times(idistance[s], arc.weight);
      while (copy.size() <= arc.nextstate)
        copy.push_back(kNoStateId);
      if (copy[arc.nextstate] == kNoStateId)
        copy[arc.nextstate] = ofst->AddState();
      ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
                                copy[arc.nextstate]));
      while (enqueued.size() <= arc.nextstate) {
        enqueued.push_back(kNoKey);
        visited.push_back(false);
      }
      if (visited[arc.nextstate]) continue;
      if (enqueued[arc.nextstate] == kNoKey)
        enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
      else
        heap.Update(enqueued[arc.nextstate], arc.nextstate);
    }
  }
}


// Pruning algorithm: this version writes the pruned input Fst to an
// output MutableFst and simply takes the pruning threshold as an
// argument.  'ofst' contains states and arcs that belong to a
// successful path in 'ifst' whose weight is no more than
// the weight of the shortest path Times() 'weight_threshold'. When
// 'state_threshold != kNoStateId', 'ofst' will be restricted further
// to have at most 'opts.state_threshold' states. Weights need to be
// commutative and have the path property. The weight 'w' of any cycle
// needs to be bounded, i.e., 'Plus(w, W::One()) = W::One()'.
template <class Arc>
void Prune(const Fst<Arc> &ifst,
           MutableFst<Arc> *ofst,
           typename Arc::Weight weight_threshold,
           typename Arc::StateId state_threshold = kNoStateId,
           float delta = kDelta) {
  PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold,
                                             AnyArcFilter<Arc>(), 0, delta);
  Prune(ifst, ofst, opts);
}

}  // namespace fst

#endif  // FST_LIB_PRUNE_H__