// cache.h
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
// \file
// An Fst implementation that caches FST elements of a delayed
// computation.

#ifndef FST_LIB_CACHE_H__
#define FST_LIB_CACHE_H__

#include <list>

#include "fst/lib/vector-fst.h"

DECLARE_bool(fst_default_cache_gc);
DECLARE_int64(fst_default_cache_gc_limit);

namespace fst {

// Settings that govern garbage collection of cached states.
struct CacheOptions {
  bool gc;          // Is garbage collection enabled?
  size_t gc_limit;  // Number of cached bytes tolerated before a GC pass

  // Picks up both settings from the fst_default_cache_gc and
  // fst_default_cache_gc_limit flags.
  CacheOptions()
      : gc(FLAGS_fst_default_cache_gc),
        gc_limit(FLAGS_fst_default_cache_gc_limit) {}

  // Uses the explicitly supplied settings.
  CacheOptions(bool enable_gc, size_t limit_bytes)
      : gc(enable_gc), gc_limit(limit_bytes) {}
};


// This is a VectorFstBaseImpl container that holds a State similar to
// VectorState but additionally has a flags data member (see
// CacheState below). This class is used to cache FST elements with
// the flags used to indicate what has been cached. Use HasStart(),
// HasFinal(), and HasArcs() to determine if cached and SetStart(),
// SetFinal(), AddArc(), and SetArcs() to cache. Note you must set the
// final weight even if the state is non-final to mark it as
// cached. If the 'gc' option is 'false', cached items have the extent
// of the FST - minimizing computation. If the 'gc' option is 'true',
// garbage collection of states (not in use in an arc iterator) is
// performed, in a rough approximation of LRU order, when 'gc_limit'
// bytes is reached - controlling memory use. When 'gc_limit' is 0,
// special optimizations apply - minimizing memory use.

// Template parameter S is the cached state type. The code below uses its
// members final, arcs, niepsilons, noepsilons, flags, ref_count and Reset(),
// and its nested type Arc.
template <class S>
class CacheBaseImpl : public VectorFstBaseImpl<S> {
 public:
  using FstImpl<typename S::Arc>::Type;
  using VectorFstBaseImpl<S>::NumStates;
  using VectorFstBaseImpl<S>::AddState;

  typedef S State;
  typedef typename S::Arc Arc;
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;

  // Builds an empty cache configured from the fst_default_cache_gc and
  // fst_default_cache_gc_limit flags. A non-zero limit that does not exceed
  // kMinCacheLimit is raised to kMinCacheLimit; a limit of 0 is kept as is.
  CacheBaseImpl()
      : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0),
        cache_first_state_id_(kNoStateId), cache_first_state_(0),
        cache_gc_(FLAGS_fst_default_cache_gc),  cache_size_(0),
        cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit ||
                     FLAGS_fst_default_cache_gc_limit == 0 ?
                     FLAGS_fst_default_cache_gc_limit : kMinCacheLimit) {}

  // Builds an empty cache configured from 'opts', applying the same
  // clamping of the GC limit as the default constructor.
  explicit CacheBaseImpl(const CacheOptions &opts)
      : cache_start_(false), nknown_states_(0),
        min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
        cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0),
        cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ?
                     opts.gc_limit : kMinCacheLimit) {}

  // Releases the separately held first cached state, if any. This class
  // does not itself delete the states stored through SetState().
  ~CacheBaseImpl() {
    delete cache_first_state_;
  }

  // Gets a state from its ID; state must exist.
  const S *GetState(StateId s) const {
    if (s == cache_first_state_id_)
      return cache_first_state_;
    else
      return VectorFstBaseImpl<S>::GetState(s);
  }

  // Gets a state from its ID; state must exist.
  S *GetState(StateId s) {
    if (s == cache_first_state_id_)
      return cache_first_state_;
    else
      return VectorFstBaseImpl<S>::GetState(s);
  }

  // Gets a state from its ID; return 0 if it doesn't exist.
  const S *CheckState(StateId s) const {
    if (s == cache_first_state_id_)
      return cache_first_state_;
    else if (s < NumStates())
      return VectorFstBaseImpl<S>::GetState(s);
    else
      return 0;
  }

  // Gets a state from its ID; add it if necessary. While the cache limit is
  // 0, a single "first" state is held outside the main state vector and is
  // recycled for each newly requested state as long as nothing references
  // it. Once a second state is needed while the first is still referenced,
  // the first state is moved into the main vector and this mode ends.
  S *ExtendState(StateId s) {
    if (s == cache_first_state_id_) {
      return cache_first_state_;                   // Return 1st cached state
    } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
      cache_first_state_id_ = s;                   // Remember 1st cached state
      cache_first_state_ = new S;
      return cache_first_state_;
    } else if (cache_first_state_id_ != kNoStateId &&
               cache_first_state_->ref_count == 0) {
      cache_first_state_id_ = s;                   // Reuse 1st cached state
      cache_first_state_->Reset();
      return cache_first_state_;                   // Return 1st cached state
    } else {
      while (NumStates() <= s)                     // Add state to main cache
        AddState(0);
      if (!VectorFstBaseImpl<S>::GetState(s)) {
        // SetState() lives in a dependent base class, so it has to be
        // reached through 'this->' to be found by two-phase name lookup.
        this->SetState(s, new S);
        if (cache_first_state_id_ != kNoStateId) {  // Forget 1st cached state
          while (NumStates() <= cache_first_state_id_)
            AddState(0);
          this->SetState(cache_first_state_id_, cache_first_state_);
          if (cache_gc_) {
            cache_states_.push_back(cache_first_state_id_);
            cache_size_ += sizeof(S) +
                           cache_first_state_->arcs.capacity() * sizeof(Arc);
          }
          // Leave first-state mode for good, whether or not GC is enabled.
          // If the limit stayed 0, a later request for a state that already
          // sits in the main vector would allocate a fresh first state that
          // shadows it, and the vector entry would leak once overwritten.
          cache_limit_ = kMinCacheLimit;
          cache_first_state_id_ = kNoStateId;
          cache_first_state_ = 0;
        }
        if (cache_gc_) {
          cache_states_.push_back(s);
          cache_size_ += sizeof(S);
          if (cache_size_ > cache_limit_)
            GC(s, false);
        }
      }
      return VectorFstBaseImpl<S>::GetState(s);
    }
  }

  // Records the start state, marks it as cached and makes sure it is
  // counted among the known states.
  void SetStart(StateId s) {
    VectorFstBaseImpl<S>::SetStart(s);
    cache_start_ = true;
    if (s >= nknown_states_)
      nknown_states_ = s + 1;
  }

  // Caches the final weight of state s (creating the state if needed) and
  // marks the state as recently used.
  void SetFinal(StateId s, Weight w) {
    S *state = ExtendState(s);
    state->final = w;
    state->flags |= kCacheFinal | kCacheRecent;
  }

  // Appends an arc to state s (creating the state if needed). The arcs are
  // not considered cached until SetArcs(s) is called.
  void AddArc(StateId s, const Arc &arc) {
    S *state = ExtendState(s);
    state->arcs.push_back(arc);
  }

  // Marks arcs of state s as cached. Recounts the state's input/output
  // epsilons, extends the known-state count to cover every arc destination,
  // records s as expanded and, when GC is enabled and s is not the first
  // cached state, charges the arc storage to the cache size.
  void SetArcs(StateId s) {
    S *state = ExtendState(s);
    vector<Arc> &arcs = state->arcs;
    state->niepsilons = state->noepsilons = 0;
    for (size_t a = 0; a < arcs.size(); ++a) {
      const Arc &arc = arcs[a];
      if (arc.nextstate >= nknown_states_)
        nknown_states_ = arc.nextstate + 1;
      if (arc.ilabel == 0)
        ++state->niepsilons;
      if (arc.olabel == 0)
        ++state->noepsilons;
    }
    ExpandedState(s);
    state->flags |= kCacheArcs | kCacheRecent;
    if (cache_gc_ && s != cache_first_state_id_) {
      cache_size_ += arcs.capacity() * sizeof(Arc);
      if (cache_size_ > cache_limit_)
        GC(s, false);
    }
  }

  // Reserves room for n arcs in state s (creating the state if needed).
  void ReserveArcs(StateId s, size_t n) {
    S *state = ExtendState(s);
    state->arcs.reserve(n);
  }

  // Is the start state cached?
  bool HasStart() const { return cache_start_; }

  // Is the final weight of state s cached? A hit also marks the state as
  // recently used.
  bool HasFinal(StateId s) const {
    const S *state = CheckState(s);
    if (state && state->flags & kCacheFinal) {
      state->flags |= kCacheRecent;
      return true;
    } else {
      return false;
    }
  }

  // Are arcs of state s cached? A hit also marks the state as recently used.
  bool HasArcs(StateId s) const {
    const S *state = CheckState(s);
    if (state && state->flags & kCacheArcs) {
      state->flags |= kCacheRecent;
      return true;
    } else {
      return false;
    }
  }

  // Final weight stored for state s; the state must exist.
  Weight Final(StateId s) const {
    const S *state = GetState(s);
    return state->final;
  }

  // Number of arcs stored for state s; the state must exist.
  size_t NumArcs(StateId s) const {
    const S *state = GetState(s);
    return state->arcs.size();
  }

  // Input-epsilon count computed by SetArcs(); the state must exist.
  size_t NumInputEpsilons(StateId s) const {
    const S *state = GetState(s);
    return state->niepsilons;
  }

  // Output-epsilon count computed by SetArcs(); the state must exist.
  size_t NumOutputEpsilons(StateId s) const {
    const S *state = GetState(s);
    return state->noepsilons;
  }

  // Provides information needed for generic arc iterator. Also increments
  // the state's reference count through data->ref_count; GC() and the
  // first-state reuse in ExtendState() leave referenced states alone.
  // NOTE(review): the matching decrement is not in this class; presumably
  // the consumer of 'data' performs it - confirm.
  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
    const S *state = GetState(s);
    data->base = 0;
    data->narcs = state->arcs.size();
    data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0;
    data->ref_count = &(state->ref_count);
    ++(*data->ref_count);
  }

  // Number of known states.
  StateId NumKnownStates() const { return nknown_states_; }

  // Finds the minimum never-expanded state Id. The scan position is kept
  // between calls, so each call resumes where the previous one stopped.
  StateId MinUnexpandedState() const {
    while (min_unexpanded_state_id_ <
               static_cast<StateId>(expanded_states_.size()) &&
           expanded_states_[min_unexpanded_state_id_])
      ++min_unexpanded_state_id_;
    return min_unexpanded_state_id_;
  }

  // Removes from cache_states_ and uncaches (not referenced-counted)
  // states that have not been accessed since the last GC until the cache
  // size drops to about two thirds of cache_limit_. If that fails to free
  // enough, recurs uncaching recently visited states as well. If still
  // unable to free enough memory, then widens cache_limit_. The state
  // 'current' is never uncached. Does nothing when GC is disabled.
  void GC(StateId current, bool free_recent) {
    if (!cache_gc_)
      return;
    VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
            << "), free recently cached = " << free_recent
            << ", cache size = " << cache_size_
            << ", cache limit = " << cache_limit_ << "\n";
    typename list<StateId>::iterator siter = cache_states_.begin();

    size_t cache_target = (2 * cache_limit_)/3 + 1;
    while (siter != cache_states_.end()) {
      StateId s = *siter;
      S* state = VectorFstBaseImpl<S>::GetState(s);
      if (cache_size_ > cache_target && state->ref_count == 0 &&
          (free_recent || !(state->flags & kCacheRecent)) && s != current) {
        cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
        delete state;
        // Dependent-base member: must be reached through 'this->'.
        this->SetState(s, 0);
        cache_states_.erase(siter++);
      } else {
        // Survivors lose the 'recent' mark so that only states touched
        // after this pass count as recent in the next one.
        state->flags &= ~kCacheRecent;
        ++siter;
      }
    }
    if (!free_recent && cache_size_ > cache_target) {
      GC(current, true);
    } else {
      while (cache_size_ > cache_target) {
        cache_limit_ *= 2;
        cache_target *= 2;
      }
    }
    VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
            << "), free recently cached = " << free_recent
            << ", cache size = " << cache_size_
            << ", cache limit = " << cache_limit_ << "\n";
  }

 private:
  static const uint32 kCacheFinal =  0x0001;  // Final weight has been cached
  static const uint32 kCacheArcs =   0x0002;  // Arcs have been cached
  static const uint32 kCacheRecent = 0x0004;  // Mark as visited since GC

  static const size_t kMinCacheLimit;         // Minimum (non-zero) cache limit

  // Records that state s has been expanded. Ids below the current
  // minimum-unexpanded position are skipped; MinUnexpandedState() has
  // already moved past them.
  void ExpandedState(StateId s) {
    if (s < min_unexpanded_state_id_)
      return;
    while (static_cast<StateId>(expanded_states_.size()) <= s)
      expanded_states_.push_back(false);
    expanded_states_[s] = true;
  }

  bool cache_start_;                         // Is the start state cached?
  StateId nknown_states_;                    // # of known states
  vector<bool> expanded_states_;             // states that have been expanded
  mutable StateId min_unexpanded_state_id_;  // minimum never-expanded state Id
  StateId cache_first_state_id_;             // First cached state id
  S *cache_first_state_;                     // First cached state
  list<StateId> cache_states_;               // list of currently cached states
  bool cache_gc_;                            // enable GC
  size_t cache_size_;                        // # of bytes cached
  size_t cache_limit_;                       // # of bytes allowed before GC

  void InitStateIterator(StateIteratorData<Arc> *);  // disallow
  DISALLOW_EVIL_CONSTRUCTORS(CacheBaseImpl);
};

// Definition of the smallest non-zero GC limit, in bytes. The
// CacheBaseImpl constructors raise any non-zero requested limit that
// does not exceed this value up to it.
template <class S>
const size_t CacheBaseImpl<S>::kMinCacheLimit = 8096;


// Arcs implemented by an STL vector per state. Similar to VectorState
// but adds flags and ref count to keep track of what has been cached.
template <class A>
struct CacheState {
  typedef A Arc;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;

  // Creates an empty, uncached, unreferenced state. The epsilon counts are
  // zeroed as well so that they hold defined values before being computed.
  CacheState()
      : final(Weight::Zero()), niepsilons(0), noepsilons(0),
        flags(0), ref_count(0) {}

  // Returns the state to its freshly constructed contents so the object can
  // be recycled for another state id. The arc vector is emptied with
  // resize(0) rather than replaced.
  void Reset() {
    final = Weight::Zero();
    arcs.resize(0);
    niepsilons = 0;
    noepsilons = 0;
    flags = 0;
    ref_count = 0;
  }

  Weight final;              // Final weight
  vector<A> arcs;            // Arcs representation
  size_t niepsilons;         // # of input epsilons
  size_t noepsilons;         // # of output epsilons
  mutable uint32 flags;      // Bit flags recording what has been cached
  mutable int ref_count;     // # of outstanding references to this state
};

// A CacheBaseImpl with a commonly used CacheState.
template <class A>
class CacheImpl : public CacheBaseImpl< CacheState<A> > {
 public:
  typedef CacheState<A> State;

  // Cache configured by the base class's default constructor.
  CacheImpl() : CacheBaseImpl<State>() {}

  // Cache configured from the supplied options.
  explicit CacheImpl(const CacheOptions &opts)
      : CacheBaseImpl<State>(opts) {}

 private:
  DISALLOW_EVIL_CONSTRUCTORS(CacheImpl);
};


// Use this to make a state iterator for a CacheBaseImpl-derived Fst.
// You'll need to make this class a friend of your derived Fst.
// Note this iterator only returns those states reachable from
// the initial state, so consider implementing a class-specific one.
template <class F>
class CacheStateIterator : public StateIteratorBase<typename F::Arc> {
 public:
  typedef typename F::Arc Arc;
  typedef typename Arc::StateId StateId;

  // Starts the iteration at state 0 of 'fst', which is held by reference.
  explicit CacheStateIterator(const F &fst) : fst_(fst), s_(0) {}

  // Returns true once the current position is past every known state.
  // Before concluding that, forces computation of the start state and then
  // expands unexpanded states one at a time, stopping as soon as the
  // current position becomes a known state.
  virtual bool Done() const {
    if (s_ < fst_.impl_->NumKnownStates())
      return false;
    fst_.Start();  // force start state
    if (s_ < fst_.impl_->NumKnownStates())
      return false;
    // The loop variable has type StateId so that ids are not narrowed when
    // StateId is wider than int.
    for (StateId u = fst_.impl_->MinUnexpandedState();
         u < fst_.impl_->NumKnownStates();
         u = fst_.impl_->MinUnexpandedState()) {
      ArcIterator<F>(fst_, u);  // force state expansion
      if (s_ < fst_.impl_->NumKnownStates())
        return false;
    }
    return true;
  }

  // Id of the current state.
  virtual StateId Value() const { return s_; }

  // Advances to the next state id.
  virtual void Next() { ++s_; }

  // Restarts the iteration at state 0.
  virtual void Reset() { s_ = 0; }

 private:
  const F &fst_;  // FST being iterated
  StateId s_;     // Current state id
};


// Use this to make an arc iterator for a CacheBaseImpl-derived Fst.
// You'll need to make this class a friend of your derived Fst and
// define types Arc and State.
template <class F>
class CacheArcIterator {
 public:
  typedef typename F::Arc Arc;
  typedef typename F::State State;
  typedef typename Arc::StateId StateId;

  // Obtains state s from the FST's implementation (ExtendState creates it
  // when absent) and takes a reference on it for the iterator's lifetime.
  CacheArcIterator(const F &fst, StateId s)
      : state_(fst.impl_->ExtendState(s)), i_(0) {
    ++state_->ref_count;
  }

  // Drops the reference taken by the constructor.
  ~CacheArcIterator() { --state_->ref_count;  }

  // True when the position is at or beyond the last arc.
  bool Done() const { return !(i_ < state_->arcs.size()); }

  // Arc at the current position.
  const Arc& Value() const { return state_->arcs[i_]; }

  // Moves to the next arc.
  void Next() { ++i_; }

  // Moves back to the first arc.
  void Reset() { i_ = 0; }

  // Jumps to arc position a.
  void Seek(size_t a) { i_ = a; }

 private:
  const State *state_;  // State being iterated
  size_t i_;            // Current arc position

  DISALLOW_EVIL_CONSTRUCTORS(CacheArcIterator);
};

}  // namespace fst

#endif  // FST_LIB_CACHE_H__