// queue.h
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Author: allauzen@cs.nyu.edu (Cyril Allauzen)
//
// \file
// Functions and classes for various Fst state queues with
// a unified interface.

#ifndef FST_LIB_QUEUE_H__
#define FST_LIB_QUEUE_H__

#include <deque>
#include <vector>

#include "fst/lib/arcfilter.h"
#include "fst/lib/connect.h"
#include "fst/lib/heap.h"
#include "fst/lib/topsort.h"

namespace fst {

// template <class S>
// class Queue {
//  public:
//   typedef S StateId;
//
//   // Ctr: may need args (e.g., Fst, comparator) for some queues
//   Queue(...);
//   // Returns the head of the queue
//   StateId Head() const;
//   // Inserts a state
//   void Enqueue(StateId s);
//   // Removes the head of the queue
//   void Dequeue();
//   // Updates ordering of state s when weight changes, if necessary
//   void Update(StateId s);
//   // Does the queue contain no elements?
//   bool Empty() const;
//   // Remove all states from queue
//   void Clear();
// };

// Tags identifying the state queue disciplines declared in this header.
// Every enumerator is given an explicit numeric value.
enum QueueType {
  // Single state queue.
  TRIVIAL_QUEUE = 0,
  // First-in, first-out queue.
  FIFO_QUEUE = 1,
  // Last-in, first-out queue.
  LIFO_QUEUE = 2,
  // Shortest-first queue.
  SHORTEST_FIRST_QUEUE = 3,
  // Topologically-ordered queue.
  TOP_ORDER_QUEUE = 4,
  // State-ID ordered queue.
  STATE_ORDER_QUEUE = 5,
  // Component graph top-ordered meta-queue.
  SCC_QUEUE = 6,
  // Auto-selected queue.
  AUTO_QUEUE = 7,
  // Catch-all tag; no queue in this header uses it.
  OTHER_QUEUE = 8
};


// QueueBase, templated on the StateId, is the abstract base class shared by
// the queues considered by AutoQueue.  Its public methods are non-virtual
// and simply forward to the private pure-virtual hooks (the names with a
// trailing underscore) that every concrete queue must implement.
template <class S>
class QueueBase {
 public:
  typedef S StateId;

  // Records the discipline tag later reported by Type().  Declared explicit
  // so that a bare QueueType is never implicitly treated as a queue.
  explicit QueueBase(QueueType type) : queue_type_(type) {}
  virtual ~QueueBase() {}

  // Returns the head of the queue.
  StateId Head() const { return Head_(); }
  // Inserts state s.
  void Enqueue(StateId s) { Enqueue_(s); }
  // Removes the head of the queue.
  void Dequeue() { Dequeue_(); }
  // Gives the queue a chance to reorder state s.
  void Update(StateId s) { Update_(s); }
  // Returns true if the queue holds no state.
  bool Empty() const { return Empty_(); }
  // Removes all states from the queue.
  void Clear() { Clear_(); }
  // Returns the tag passed to the constructor.  Const so that it can be
  // queried through a const reference or pointer.
  QueueType Type() const { return queue_type_; }

 private:
  // Implementation hooks, one per public method above.
  virtual StateId Head_() const = 0;
  virtual void Enqueue_(StateId s) = 0;
  virtual void Dequeue_() = 0;
  virtual void Update_(StateId s) = 0;
  virtual bool Empty_() const = 0;
  virtual void Clear_() = 0;

  QueueType queue_type_;
};


// Trivial queue discipline, templated on the StateId.  The queue is a
// single slot: it holds at most one state at a time, and enqueueing simply
// overwrites the slot.  It is used for strongly connected components with
// only one state and no self loops.
template <class S>
class TrivialQueue : public QueueBase<S> {
 public:
  typedef S StateId;

  TrivialQueue() : QueueBase<S>(TRIVIAL_QUEUE), front_(kNoStateId) {}

  // Returns the stored state, or kNoStateId when the slot is vacant.
  StateId Head() const { return front_; }

  // Stores s in the slot, replacing whatever was there.
  void Enqueue(StateId s) { front_ = s; }

  // Vacates the slot.
  void Dequeue() { front_ = kNoStateId; }

  // A single slot has nothing to reorder.
  void Update(StateId) {}

  // The slot is vacant exactly when it holds kNoStateId.
  bool Empty() const { return kNoStateId == front_; }

  // Vacates the slot.
  void Clear() { front_ = kNoStateId; }

 private:
  // QueueBase hooks; each one forwards to the public method above.
  virtual StateId Head_() const { return Head(); }
  virtual void Enqueue_(StateId s) { Enqueue(s); }
  virtual void Dequeue_() { Dequeue(); }
  virtual void Update_(StateId s) { Update(s); }
  virtual bool Empty_() const { return Empty(); }
  virtual void Clear_() { Clear(); }

  StateId front_;  // The single slot.
};


// First-in, first-out queue discipline, templated on the StateId.
// The states live in the deque base class: new states are pushed at its
// front and the head is taken from its back, so states leave in the order
// in which they arrived.
template <class S>
class FifoQueue : public QueueBase<S>, public deque<S> {
 public:
  typedef S StateId;

  FifoQueue() : QueueBase<S>(FIFO_QUEUE) {}

  // Oldest state still in the queue.
  StateId Head() const { return this->back(); }

  // Adds s behind every state already queued.
  void Enqueue(StateId s) { this->push_front(s); }

  // Drops the oldest state.
  void Dequeue() { this->pop_back(); }

  // Arrival order never changes, so there is nothing to do.
  void Update(StateId) {}

  bool Empty() const { return this->empty(); }

  void Clear() { this->clear(); }

 private:
  // QueueBase hooks; each one forwards to the public method above.
  virtual StateId Head_() const { return Head(); }
  virtual void Enqueue_(StateId s) { Enqueue(s); }
  virtual void Dequeue_() { Dequeue(); }
  virtual void Update_(StateId s) { Update(s); }
  virtual bool Empty_() const { return Empty(); }
  virtual void Clear_() { Clear(); }
};


// Last-in, first-out queue discipline, templated on the StateId.
// The states live in the deque base class: states are both pushed at and
// taken from its front, so the most recently added state leaves first.
template <class S>
class LifoQueue : public QueueBase<S>, public deque<S> {
 public:
  typedef S StateId;

  LifoQueue() : QueueBase<S>(LIFO_QUEUE) {}

  // Most recently enqueued state.
  StateId Head() const { return this->front(); }

  // Puts s on top of every state already queued.
  void Enqueue(StateId s) { this->push_front(s); }

  // Drops the most recently enqueued state.
  void Dequeue() { this->pop_front(); }

  // Stack order never changes, so there is nothing to do.
  void Update(StateId) {}

  bool Empty() const { return this->empty(); }

  void Clear() { this->clear(); }

 private:
  // QueueBase hooks; each one forwards to the public method above.
  virtual StateId Head_() const { return Head(); }
  virtual void Enqueue_(StateId s) { Enqueue(s); }
  virtual void Dequeue_() { Dequeue(); }
  virtual void Update_(StateId s) { Update(s); }
  virtual bool Empty_() const { return Empty(); }
  virtual void Clear_() { Clear(); }
};


// Shortest-first queue discipline, templated on the StateId and on a
// comparison function object type C whose instances compare two StateIds.
// The states are kept in a heap ordered by that comparator.  If a (single)
// state's order changes, it can be repositioned with a call to Update().
template <typename S, typename C>
class ShortestFirstQueue : public QueueBase<S> {
 public:
  typedef S StateId;
  typedef C Compare;

  // comp is handed to the underlying heap.
  ShortestFirstQueue(C comp)
      : QueueBase<S>(SHORTEST_FIRST_QUEUE), heap_(comp) {}

  // State currently on top of the heap.
  StateId Head() const { return heap_.Top(); }

  // Inserts s into the heap and remembers the key the heap returned for it.
  void Enqueue(StateId s) {
    // Grow the key table so that it has a slot for s; fresh slots are
    // marked kNoKey (i.e. "not in the heap").
    if (static_cast<StateId>(key_.size()) <= s)
      key_.resize(s + 1, kNoKey);
    key_[s] = heap_.Insert(s);
  }

  // Pops the top of the heap and marks that state as no longer enqueued.
  void Dequeue() {
    key_[heap_.Pop()] = kNoKey;
  }

  // Repositions s inside the heap if it is currently enqueued; otherwise
  // enqueues it.
  void Update(StateId s) {
    const bool enqueued =
        (s < static_cast<StateId>(key_.size())) && (key_[s] != kNoKey);
    if (enqueued) {
      heap_.Update(key_[s], s);
    } else {
      Enqueue(s);
    }
  }

  bool Empty() const { return heap_.Empty(); }

  // Empties the heap and forgets every key.
  void Clear() {
    heap_.Clear();
    key_.clear();
  }

 private:
  Heap<S, C> heap_;
  // key_[s] is the value heap_.Insert(s) returned while s is enqueued,
  // and kNoKey otherwise.
  vector<ssize_t> key_;

  // QueueBase hooks; each one forwards to the public method above.
  virtual StateId Head_() const { return Head(); }
  virtual void Enqueue_(StateId s) { Enqueue(s); }
  virtual void Dequeue_() { Dequeue(); }
  virtual void Update_(StateId s) { Update(s); }
  virtual bool Empty_() const { return Empty(); }
  virtual void Clear_() { Clear(); }
};


// Comparison function object between states, built from a vector that maps
// states to weights and a Less comparison function object between weights:
// state x is ordered before state y when the weight of x is less than the
// weight of y.  Only a pointer to the weight vector is kept, so every
// comparison reads the vector's current contents.
template <typename S, typename L>
class StateWeightCompare {
 public:
  typedef L Less;
  typedef typename L::Weight Weight;
  typedef S StateId;

  // weights is indexed by state; less compares two weights.
  StateWeightCompare(const vector<Weight>* weights, const L &less)
      : weights_(weights), less_(less) {}

  // Returns less(weight of x, weight of y).
  bool operator()(const S x, const S y) const {
    const vector<Weight> &table = *weights_;
    return less_(table[x], table[y]);
  }

 private:
  const vector<Weight>* weights_;  // Not owned.
  L less_;                         // Private copy of the weight comparator.
};


// Shortest-first queue discipline, templated on the StateId and Weight, that
// is specialized to use the weight's natural order for the comparison
// function.
template <typename S, typename W>
class NaturalShortestFirstQueue :
      public ShortestFirstQueue<S, StateWeightCompare<S, NaturalLess<W> > > {
 public:
  typedef StateWeightCompare<S, NaturalLess<W> > C;

  // distance is the state-indexed weight vector the comparator reads.
  // StateWeightCompare keeps only the pointer, so the vector has to stay
  // alive for as long as the queue is used.
  //
  // The comparator is built from a temporary NaturalLess<W>.  The base
  // class is initialized before any data member of this class, so a
  // NaturalLess member could not be used here: it would be read before
  // its own construction.  StateWeightCompare stores its own copy of the
  // less object, so the temporary need not outlive this expression.
  NaturalShortestFirstQueue(vector<W> *distance) :
      ShortestFirstQueue<S, C>(C(distance, NaturalLess<W>())) {}
};


// Topological-order queue discipline, templated on the StateId.
// States leave the queue in topological order, whatever the order in which
// they were enqueued.  The FST must be acyclic.
template <class S>
class TopOrderQueue : public QueueBase<S> {
 public:
  typedef S StateId;

  // This constructor computes the topological order itself with a
  // depth-first visit of fst.  The arc filter limits the transitions
  // considered in that computation (e.g., only the epsilon graph).
  // Logs a fatal error if the visit reports that the FST is not acyclic.
  template <class Arc, class ArcFilter>
  TopOrderQueue(const Fst<Arc> &fst, ArcFilter filter)
      : QueueBase<S>(TOP_ORDER_QUEUE), front_(0), back_(kNoStateId),
        order_(), state_() {
    bool acyclic;
    TopOrderVisitor<Arc> visitor(&order_, &acyclic);
    DfsVisit(fst, &visitor, filter);
    if (!acyclic) {
      LOG(FATAL) << "TopOrderQueue: fst is not acyclic.";
    }
    // One (initially vacant) slot per position in the order.
    state_.resize(order_.size(), kNoStateId);
  }

  // This constructor is handed the topological order, which is useful when
  // it is known beforehand.
  TopOrderQueue(const vector<StateId> &order)
      : QueueBase<S>(TOP_ORDER_QUEUE), front_(0), back_(kNoStateId),
        order_(order), state_(order.size(), kNoStateId) {}

  // State sitting at the lowest occupied position.
  StateId Head() const { return state_[front_]; }

  // Places s in the slot given by its position in the order and widens
  // the [front_, back_] window to include that slot.
  void Enqueue(StateId s) {
    const StateId pos = order_[s];
    if (front_ > back_) {
      // Queue was empty: the window collapses onto this slot.
      front_ = pos;
      back_ = pos;
    } else if (pos > back_) {
      back_ = pos;
    } else if (pos < front_) {
      front_ = pos;
    }
    state_[pos] = s;
  }

  // Vacates the head slot and advances front_ to the next occupied slot,
  // or past back_ if there is none.
  void Dequeue() {
    state_[front_] = kNoStateId;
    for (; (front_ <= back_) && (state_[front_] == kNoStateId); ++front_) {
    }
  }

  // The order is fixed, so there is nothing to do.
  void Update(StateId) {}

  // The queue is empty exactly when the window is inverted.
  bool Empty() const { return front_ > back_; }

  // Vacates every slot inside the window and resets the window.
  void Clear() {
    StateId pos = front_;
    while (pos <= back_) {
      state_[pos] = kNoStateId;
      ++pos;
    }
    back_ = kNoStateId;
    front_ = 0;
  }

 private:
  StateId front_;          // Lowest position that may be occupied.
  StateId back_;           // Highest occupied position.
  vector<StateId> order_;  // order_[s]: position of state s in the order.
  vector<StateId> state_;  // state_[p]: state at position p or kNoStateId.

  // QueueBase hooks; each one forwards to the public method above.
  virtual StateId Head_() const { return Head(); }
  virtual void Enqueue_(StateId s) { Enqueue(s); }
  virtual void Dequeue_() { Dequeue(); }
  virtual void Update_(StateId s) { Update(s); }
  virtual bool Empty_() const { return Empty(); }
  virtual void Clear_() { Clear(); }
};


// State order queue discipline, templated on the StateId.
// States leave the queue in increasing state Id order, whatever the order
// in which they were enqueued.
template <class S>
class StateOrderQueue : public QueueBase<S> {
 public:
  typedef S StateId;

  StateOrderQueue()
      : QueueBase<S>(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {}

  // Smallest enqueued state Id.
  StateId Head() const { return front_; }

  // Marks s as enqueued and widens the [front_, back_] window to include it.
  void Enqueue(StateId s) {
    if (front_ > back_) {
      // Queue was empty: the window collapses onto s.
      front_ = s;
      back_ = s;
    } else if (s > back_) {
      back_ = s;
    } else if (s < front_) {
      front_ = s;
    }
    // Grow the flag table so that it has an entry for s.
    if (static_cast<StateId>(enqueued_.size()) <= s)
      enqueued_.resize(s + 1, false);
    enqueued_[s] = true;
  }

  // Unmarks the head and advances front_ to the next enqueued state, or
  // past back_ if there is none.
  void Dequeue() {
    enqueued_[front_] = false;
    while ((front_ <= back_) && !enqueued_[front_]) {
      ++front_;
    }
  }

  // The order is fixed, so there is nothing to do.
  void Update(StateId) {}

  // The queue is empty exactly when the window is inverted.
  bool Empty() const { return front_ > back_; }

  // Unmarks every state inside the window and resets the window.
  void Clear() {
    StateId s = front_;
    while (s <= back_) {
      enqueued_[s] = false;
      ++s;
    }
    front_ = 0;
    back_ = kNoStateId;
  }

 private:
  StateId front_;         // Smallest state Id that may be enqueued.
  StateId back_;          // Largest enqueued state Id.
  vector<bool> enqueued_; // enqueued_[s]: is state s in the queue?

  // QueueBase hooks; each one forwards to the public method above.
  virtual StateId Head_() const { return Head(); }
  virtual void Enqueue_(StateId s) { Enqueue(s); }
  virtual void Dequeue_() { Dequeue(); }
  virtual void Update_(StateId s) { Update(s); }
  virtual bool Empty_() const { return Empty(); }
  virtual void Clear_() { Clear(); }
};


// SCC topological-order meta-queue discipline, templated on the StateId S
// and a queue Q, which is used inside each SCC.  It visits the SCCs of an
// FST in increasing SCC number and, inside one SCC, follows the order of
// that SCC's own queue.  Its constructor is passed the queues to use
// within each SCC.
template <class S, class Q>
class SccQueue : public QueueBase<S> {
 public:
  typedef S StateId;
  typedef Q Queue;

  // Constructor takes a vector specifying the SCC number per state
  // and a vector giving the queue to use per SCC number.  A null entry in
  // *queue means that SCC has no queue of its own; its (single) state is
  // then kept in an internal one-slot-per-SCC table.  Both vectors are
  // referenced, not copied, and the queues are not owned.
  SccQueue(const vector<StateId> &scc, vector<Queue*> *queue)
    : QueueBase<S>(SCC_QUEUE), queue_(queue), scc_(scc), front_(0),
      back_(kNoStateId) {}

  // Returns the head of the first non-empty SCC.  Empty SCCs at the front
  // are skipped lazily here, which is why front_ is mutable.  Logs a fatal
  // error if every SCC is empty.
  StateId Head() const {
    while ((front_ <= back_) && SccEmpty(front_))
      ++front_;
    if (front_ > back_)
      LOG(FATAL) << "SccQueue: head of empty queue";
    if ((*queue_)[front_])
      return (*queue_)[front_]->Head();
    else
      return trivial_queue_[front_];
  }

  // Inserts s into the queue of its SCC (or into the trivial slot of that
  // SCC) and widens the [front_, back_] window to include that SCC.
  void Enqueue(StateId s) {
    const StateId c = scc_[s];
    if (front_ > back_) front_ = back_ = c;
    else if (c > back_) back_ = c;
    else if (c < front_) front_ = c;
    if ((*queue_)[c]) {
      (*queue_)[c]->Enqueue(s);
    } else {
      // Grow the trivial table so that it has a slot for SCC c.
      while (static_cast<StateId>(trivial_queue_.size()) <= c)
        trivial_queue_.push_back(kNoStateId);
      trivial_queue_[c] = s;
    }
  }

  // Removes the head of SCC front_.  Logs a fatal error if the window is
  // empty.
  void Dequeue() {
    if (front_ > back_)
      LOG(FATAL) << "SccQueue: dequeue of empty queue";
    if ((*queue_)[front_])
      (*queue_)[front_]->Dequeue();
    else if (front_ < static_cast<StateId>(trivial_queue_.size()))
      trivial_queue_[front_] = kNoStateId;
  }

  // Forwards to the queue of the SCC of s; SCCs without a queue have
  // nothing to reorder.
  void Update(StateId s) {
    if ((*queue_)[scc_[s]])
      (*queue_)[scc_[s]]->Update(s);
  }

  bool Empty() const {
    if (front_ < back_)   // Queue scc # back_ not empty unless back_==front_
      return false;
    if (front_ > back_)
      return true;
    return SccEmpty(front_);
  }

  // Clears every SCC inside the window and resets the window.
  void Clear() {
    for (StateId i = front_; i <= back_; ++i) {
      if ((*queue_)[i])
        (*queue_)[i]->Clear();
      else if (i < static_cast<StateId>(trivial_queue_.size()))
        trivial_queue_[i] = kNoStateId;
    }
    front_ = 0;
    back_ = kNoStateId;
  }

 private:
  // Returns true if SCC number i currently holds no state.  For an SCC
  // without a queue, an index at or beyond the end of trivial_queue_ means
  // nothing was ever stored for it.  The bound check must be ">=": index
  // size() is already one past the last valid slot.
  bool SccEmpty(StateId i) const {
    if ((*queue_)[i])
      return (*queue_)[i]->Empty();
    return (i >= static_cast<StateId>(trivial_queue_.size()))
        || (trivial_queue_[i] == kNoStateId);
  }

  vector<Queue*> *queue_;          // Per-SCC queues (entries may be null).
  const vector<StateId> &scc_;     // scc_[s]: SCC number of state s.
  mutable StateId front_;          // Lowest SCC that may be non-empty.
  StateId back_;                   // Highest SCC that received a state.
  vector<StateId> trivial_queue_;  // One slot per SCC that has no queue.

  // QueueBase hooks; each one forwards to the public method above.
  virtual StateId Head_() const { return Head(); }
  virtual void Enqueue_(StateId s) { Enqueue(s); }
  virtual void Dequeue_() { Dequeue(); }
  virtual void Update_(StateId s) { Update(s); }
  virtual bool Empty_() const { return Empty(); }
  virtual void Clear_() { Clear(); }
};


// Automatic queue discipline, templated on the StateId. It selects a
// queue discipline for a given FST based on its properties, creates that
// queue, and forwards every queue operation to it.  The selected queue
// (and any per-SCC queues) are owned by the AutoQueue and destroyed with
// it, which is why the class is not copyable.
template <class S>
class AutoQueue : public QueueBase<S> {
 public:
  typedef S StateId;

  // This constructor takes a state distance vector that, if non-null and if
  // the Weight type has the path property, will entertain the
  // shortest-first queue using the natural order w.r.t to the distance.
  // The arc filter limits the transitions considered when the FST has to
  // be examined.
  //
  // Selection, in order:
  //   1. known top-sorted, or no start state  -> state-order queue
  //   2. known acyclic                        -> top-order queue
  //   3. known unweighted, idempotent weights -> LIFO queue
  //   4. otherwise decompose into SCCs and pick LIFO, top-order, or the
  //      SCC meta-queue according to what SccQueueType() reports.
  template <class Arc, class ArcFilter>
  AutoQueue(const Fst<Arc> &fst, const vector<typename Arc::Weight> *distance,
            ArcFilter filter) : QueueBase<S>(AUTO_QUEUE), queue_(0) {
    typedef typename Arc::Weight Weight;
    typedef StateWeightCompare< StateId, NaturalLess<Weight> > Compare;

    //  First check if the FST is known to have these properties.
    uint64 props = fst.Properties(kAcyclic | kCyclic |
                                  kTopSorted | kUnweighted, false);
    if ((props & kTopSorted) || fst.Start() == kNoStateId) {
      queue_ = new StateOrderQueue<StateId>();
      VLOG(2) << "AutoQueue: using state-order discipline";
    } else if (props & kAcyclic) {
      queue_ = new TopOrderQueue<StateId>(fst, filter);
      VLOG(2) << "AutoQueue: using top-order discipline";
    } else if ((props & kUnweighted) && (Weight::Properties() & kIdempotent)) {
      queue_ = new LifoQueue<StateId>();
      VLOG(2) << "AutoQueue: using LIFO discipline";
    } else {
      // Required by the SccVisitor interface; the value is not used here.
      uint64 scc_props;
      // Decompose into strongly-connected components.
      SccVisitor<Arc> scc_visitor(&scc_, 0, 0, &scc_props);
      DfsVisit(fst, &scc_visitor, filter);
      StateId nscc = *max_element(scc_.begin(), scc_.end()) + 1;
      vector<QueueType> queue_types(nscc);
      // The natural order is only considered when a distance vector was
      // supplied and the semiring has the path property; otherwise both
      // pointers stay null.
      NaturalLess<Weight> *less = 0;
      Compare *comp = 0;
      if (distance && (Weight::Properties() & kPath)) {
        less = new NaturalLess<Weight>;
        comp = new Compare(distance, *less);
      }
      // Find the queue type to use per SCC.
      bool unweighted;
      bool all_trivial;
      SccQueueType(fst, scc_, &queue_types, filter, less, &all_trivial,
                   &unweighted);
      if (unweighted) {
        // If unweighted and semiring is idempotent, use lifo queue.
        queue_ = new LifoQueue<StateId>();
        VLOG(2) << "AutoQueue: using LIFO discipline";
      } else if (all_trivial) {
        // If all the scc are trivial, FST is acyclic and the scc# gives
        // the topological order.
        queue_ = new TopOrderQueue<StateId>(scc_);
        VLOG(2) << "AutoQueue: using top-order discipline";
      } else {
        VLOG(2) << "AutoQueue: using SCC meta-discipline";
        queues_.resize(nscc);
        for (StateId i = 0; i < nscc; ++i) {
          switch (queue_types[i]) {
            case TRIVIAL_QUEUE:
              // A null entry tells SccQueue to use its own trivial slot.
              queues_[i] = 0;
              VLOG(3) << "AutoQueue: SCC #" << i
                      << ": using trivial discipline";
              break;
            case SHORTEST_FIRST_QUEUE:
              CHECK(comp);
              // The queue keeps its own copy of *comp.
              queues_[i] = new ShortestFirstQueue<StateId, Compare>(*comp);
              VLOG(3) << "AutoQueue: SCC #" << i <<
                ": using shortest-first discipline";
              break;
            case LIFO_QUEUE:
              queues_[i] = new LifoQueue<StateId>();
              VLOG(3) << "AutoQueue: SCC #" << i
                      << ": using LIFO discipline";
              break;
            case FIFO_QUEUE:
            default:
              queues_[i] = new FifoQueue<StateId>();
              VLOG(3) << "AutoQueue: SCC #" << i
                      << ": using FIFO discipline";
              break;
          }
        }
        // The meta-queue refers to scc_ and queues_, both members of this
        // object.
        queue_ = new SccQueue< StateId, QueueBase<StateId> >(scc_, &queues_);
      }
      // The helpers were only needed while choosing and building queues.
      delete comp;
      delete less;
    }
  }

  // Destroys the per-SCC queues (if any) and the selected queue.
  ~AutoQueue() {
    for (StateId i = 0; i < (StateId)queues_.size(); ++i)
      delete queues_[i];
    delete queue_;
  }

  StateId Head() const { return queue_->Head(); }

  void Enqueue(StateId s) { queue_->Enqueue(s); }

  void Dequeue() { queue_->Dequeue(); }

  void Update(StateId s) { queue_->Update(s); }

  bool Empty() const { return queue_->Empty(); }

  void Clear() { queue_->Clear(); }


 private:
  // Copying is disabled (declared here, intentionally never defined).  A
  // member-wise copy would delete queue_ and queues_ twice, and the copied
  // SccQueue would still refer to the scc_ and queues_ of the original.
  AutoQueue(const AutoQueue &);
  AutoQueue &operator=(const AutoQueue &);

  QueueBase<StateId> *queue_;             // Selected queue; owned.
  vector< QueueBase<StateId>* > queues_;  // Per-SCC queues; owned.
  vector<StateId> scc_;                   // scc_[s]: SCC number of state s.

  template <class Arc, class ArcFilter, class Less>
  static void SccQueueType(const Fst<Arc> &fst,
                           const vector<StateId> &scc,
                           vector<QueueType> *queue_types,
                           ArcFilter filter, Less *less,
                           bool *all_trivial, bool *unweighted);

  // QueueBase hooks; each one forwards to the public method above.
  virtual StateId Head_() const { return Head(); }

  virtual void Enqueue_(StateId s) { Enqueue(s); }

  virtual void Dequeue_() { Dequeue(); }

  virtual void Update_(StateId s) { Update(s); }

  virtual bool Empty_() const { return Empty(); }

  virtual void Clear_() { Clear(); }
};


// Walks over every arc of an Fst that the arc filter accepts (e.g., only
// the epsilon graph) and decides which type of queue each strongly
// connected component should use.  The result is written to the vector
// QUEUE_TYPE, which is assumed to have length equal to the number of SCCs;
// every entry is first reset to the trivial queue and is only changed by
// arcs whose two end states lie in the same SCC.  ALL_TRIVIAL is set to
// true if every queue is the trivial queue.  UNWEIGHTED is set to true if
// the semiring is idempotent and all the arc weights are equal to Zero()
// or One().  LESS may be null, in which case every non-trivial SCC gets
// the FIFO queue.
template <class StateId>
template <class A, class ArcFilter, class Less>
void AutoQueue<StateId>::SccQueueType(const Fst<A> &fst,
                                      const vector<StateId> &scc,
                                      vector<QueueType> *queue_type,
                                      ArcFilter filter, Less *less,
                                      bool *all_trivial, bool *unweighted) {
  typedef A Arc;
  typedef typename A::StateId StateId;
  typedef typename A::Weight Weight;

  *all_trivial = true;
  *unweighted = true;
  queue_type->assign(queue_type->size(), TRIVIAL_QUEUE);

  const bool idempotent = (Weight::Properties() & kIdempotent) != 0;

  for (StateIterator< Fst<Arc> > states(fst);
       !states.Done();
       states.Next()) {
    const StateId source = states.Value();
    for (ArcIterator< Fst<Arc> > arcs(fst, source);
         !arcs.Done();
         arcs.Next()) {
      const Arc &arc = arcs.Value();
      if (!filter(arc)) continue;

      // An arc counts as weighted unless the semiring is idempotent and
      // the arc weight is exactly Zero() or One().
      const bool weighted = !idempotent ||
          (arc.weight != Weight::Zero() && arc.weight != Weight::One());

      // Only arcs that stay inside one SCC influence that SCC's queue.
      if (scc[source] == scc[arc.nextstate]) {
        QueueType &type = (*queue_type)[scc[source]];
        if (!less || (*less)(arc.weight, Weight::One())) {
          // No usable order, or a weight below One(): fall back to FIFO.
          type = FIFO_QUEUE;
        } else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) {
          // FIFO and shortest-first choices made earlier are kept.
          type = weighted ? SHORTEST_FIRST_QUEUE : LIFO_QUEUE;
        }
        if (type != TRIVIAL_QUEUE) *all_trivial = false;
      }

      if (weighted) *unweighted = false;
    }
  }
}

}  // namespace fst

#endif