// paren.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// Common classes for PDT parentheses

// \file

#ifndef FST_EXTENSIONS_PDT_PAREN_H_
#define FST_EXTENSIONS_PDT_PAREN_H_

#include <algorithm>
#include <tr1/unordered_map>
#include <unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <tr1/unordered_set>
using std::tr1::unordered_set;
using std::tr1::unordered_multiset;
#include <set>

#include <fst/extensions/pdt/pdt.h>
#include <fst/extensions/pdt/collection.h>
#include <fst/fst.h>
#include <fst/dfs-visit.h>


namespace fst {

//
// ParenState: Pair of an open (close) parenthesis and
// its destination (source) state.
//
// Used as a key type: it supplies a hash functor for the unordered
// containers in this file as well as equality and ordering operators.
//

template <class A>
class ParenState {
 public:
  typedef typename A::Label Label;
  typedef typename A::StateId StateId;

  // Hash functor: mixes the two fields as paren_id + state_id * kPrime.
  struct Hash {
    size_t operator()(const ParenState<A> &p) const {
      return p.paren_id + p.state_id * kPrime;
    }
  };

  Label paren_id;     // ID of open (close) paren
  StateId state_id;   // destination (source) state of open (close) paren

  // Default-constructs an "invalid" pair (kNoLabel, kNoStateId).
  ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {}

  ParenState(Label p, StateId s) : paren_id(p), state_id(s) {}

  // Two pairs are equal iff both fields are equal.
  bool operator==(const ParenState<A> &p) const {
    if (&p == this)
      return true;
    return p.paren_id == this->paren_id && p.state_id == this->state_id;
  }

  bool operator!=(const ParenState<A> &p) const { return !(p == *this); }

  // Lexicographic order on (paren_id, state_id): compares paren IDs first
  // and falls back to state IDs on a tie. This is a strict weak ordering,
  // so the type can key ordered containers and be sorted.
  bool operator<(const ParenState<A> &p) const {
    return this->paren_id < p.paren_id ||
        (this->paren_id == p.paren_id && this->state_id < p.state_id);
  }

 private:
  static const size_t kPrime;  // multiplier used by Hash
};

template <class A>
const size_t ParenState<A>::kPrime = 7853;


// Wraps an STL (multi)map and a starting position in an FST-style
// Done()/Value()/Next()/Reset() iterator. Iteration covers the run of
// consecutive entries whose key equals the key at the starting position;
// a starting position of end() yields an empty iteration.
// The map is held by reference and must outlive the iterator.
template <class M>
class MapIterator {
 public:
  typedef typename M::const_iterator StlIterator;
  typedef typename M::value_type PairType;
  typedef typename PairType::second_type ValueType;

  MapIterator(const M &m, StlIterator iter)
      : map_(m), begin_(iter), iter_(iter) {}

  // True once the cursor is at the end of the map or has moved on to an
  // entry with a different key than the starting entry.
  bool Done() const {
    if (iter_ == map_.end())
      return true;
    return iter_->first != begin_->first;
  }

  // Mapped value at the cursor (returned by value). Requires !Done().
  ValueType Value() const {
    return iter_->second;
  }

  // Advances the cursor by one entry.
  void Next() {
    ++iter_;
  }

  // Rewinds the cursor to the starting position.
  void Reset() {
    iter_ = begin_;
  }

 private:
  const M &map_;        // underlying container (not owned)
  StlIterator begin_;   // starting position; also defines the key to match
  StlIterator iter_;    // current position
};

//
// PdtParenReachable: Provides various parenthesis reachability information
// on a PDT.
//

template <class A>
class PdtParenReachable {
 public:
  typedef typename A::StateId StateId;
  typedef typename A::Label Label;
 public:
  // Maps from state ID to reachable paren IDs from (to) that state.
  typedef unordered_multimap<StateId, Label> ParenMultiMap;

  // Maps from paren ID and state ID to reachable state set ID
  typedef unordered_map<ParenState<A>, ssize_t,
                   typename ParenState<A>::Hash> StateSetMap;

  // Maps from paren ID and state ID to arcs exiting that state with that
  // Label.
  typedef unordered_multimap<ParenState<A>, A,
                        typename ParenState<A>::Hash> ParenArcMultiMap;

  typedef MapIterator<ParenMultiMap> ParenIterator;

  typedef MapIterator<ParenArcMultiMap> ParenArcIterator;

  typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;

  // Computes close (open) parenthesis reachabilty information for
  // a PDT with bounded stack.
  PdtParenReachable(const Fst<A> &fst,
                    const vector<pair<Label, Label> > &parens, bool close)
      : fst_(fst),
        parens_(parens),
        close_(close) {
    for (Label i = 0; i < parens.size(); ++i) {
      const pair<Label, Label>  &p = parens[i];
      paren_id_map_[p.first] = i;
      paren_id_map_[p.second] = i;
    }

    if (close_) {
      StateId start = fst.Start();
      if (start == kNoStateId)
        return;
      DFSearch(start, start);
    } else {
      FSTERROR() << "PdtParenReachable: open paren info not implemented";
    }
  }

  // Given a state ID, returns an iterator over paren IDs
  // for close (open) parens reachable from that state along balanced
  // paths.
  ParenIterator FindParens(StateId s) const {
    return ParenIterator(paren_multimap_, paren_multimap_.find(s));
  }

  // Given a paren ID and a state ID s, returns an iterator over
  // states that can be reached along balanced paths from (to) s that
  // have have close (open) parentheses matching the paren ID exiting
  // (entering) those states.
  SetIterator FindStates(Label paren_id, StateId s) const {
    ParenState<A> paren_state(paren_id, s);
    typename StateSetMap::const_iterator id_it = set_map_.find(paren_state);
    if (id_it == set_map_.end()) {
      return state_sets_.FindSet(-1);
    } else {
      return state_sets_.FindSet(id_it->second);
    }
  }

  // Given a paren Id and a state ID s, return an iterator over
  // arcs that exit (enter) s and are labeled with a close (open)
  // parenthesis matching the paren ID.
  ParenArcIterator FindParenArcs(Label paren_id, StateId s) const {
    ParenState<A> paren_state(paren_id, s);
    return ParenArcIterator(paren_arc_multimap_,
                            paren_arc_multimap_.find(paren_state));
  }

 private:
  // DFS that gathers paren and state set information.
  // Bool returns false when cycle detected.
  bool DFSearch(StateId s, StateId start);

  // Unions state sets together gathered by the DFS.
  void ComputeStateSet(StateId s);

  // Gather state set(s) from state 'nexts'.
  void UpdateStateSet(StateId nexts, set<Label> *paren_set,
                      vector< set<StateId> > *state_sets) const;

  const Fst<A> &fst_;
  const vector<pair<Label, Label> > &parens_;         // Paren ID -> Labels
  bool close_;                                        // Close/open paren info?
  unordered_map<Label, Label> paren_id_map_;               // Paren labels -> ID
  ParenMultiMap paren_multimap_;                      // Paren reachability
  ParenArcMultiMap paren_arc_multimap_;               // Paren Arcs
  vector<char> state_color_;                          // DFS state
  mutable Collection<ssize_t, StateId> state_sets_;   // Reachable states -> ID
  StateSetMap set_map_;                               // ID -> Reachable states
  DISALLOW_COPY_AND_ASSIGN(PdtParenReachable);
};

// DFS that gathers paren and state set information.
//
// Visits state 's' and everything reachable from it, then records the
// reachability data for 's' via ComputeStateSet() and colors it black.
// 'start' is only forwarded to the recursive calls; an open paren restarts
// the search with its destination state as the new 'start'.
//
// Returns true if 's' was already finished or has been finished by this
// call, and false if 's' is currently grey (i.e. on the active DFS path).
// A false result from a non-paren successor is reported as an error, after
// which this call returns true without finishing 's'. Results of the
// recursive calls made for paren arcs are not inspected.
template <class A>
bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) {
  // Grow the color table on demand; unseen states start out white.
  if (s >= state_color_.size())
    state_color_.resize(s + 1, kDfsWhite);

  const char color = state_color_[s];
  if (color == kDfsBlack)   // already finished
    return true;
  if (color == kDfsGrey)    // back edge: cycle
    return false;

  state_color_[s] = kDfsGrey;

  typedef typename unordered_map<Label, Label>::const_iterator ParenIdIter;

  for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next()) {
    const A &arc = aiter.Value();
    const ParenIdIter id_it = paren_id_map_.find(arc.ilabel);

    if (id_it == paren_id_map_.end()) {            // non-paren
      if (DFSearch(arc.nextstate, start))
        continue;
      FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
      return true;
    }

    const Label paren_id = id_it->second;
    if (arc.ilabel != parens_[paren_id].first)     // close paren: not followed
      continue;

    // Open paren: explore the sub-PDT entered by the paren first, then
    // continue from the destination of every matching close paren arc
    // that was found reachable from there.
    DFSearch(arc.nextstate, arc.nextstate);
    for (SetIterator close_srcs = FindStates(paren_id, arc.nextstate);
         !close_srcs.Done(); close_srcs.Next()) {
      for (ParenArcIterator close_arcs =
               FindParenArcs(paren_id, close_srcs.Element());
           !close_arcs.Done(); close_arcs.Next()) {
        const A &close_arc = close_arcs.Value();
        DFSearch(close_arc.nextstate, start);
      }
    }
  }

  ComputeStateSet(s);
  state_color_[s] = kDfsBlack;
  return true;
}

// Unions state sets together gathered by the DFS.
//
// For state 's', collects per paren ID the set of states whose close paren
// arcs are reachable from 's', by merging what is already recorded for the
// successors of 's'. Then stores the result: one paren_multimap_ entry per
// reachable paren ID and one set_map_ entry per (paren ID, s). Close paren
// arcs leaving 's' itself are also recorded in paren_arc_multimap_.
template <class A>
void PdtParenReachable<A>::ComputeStateSet(StateId s) {
  set<Label> reached_parens;                               // paren IDs seen
  vector< set<StateId> > sources_by_paren(parens_.size()); // indexed by ID

  typedef typename unordered_map<Label, Label>::const_iterator ParenIdIter;

  for (ArcIterator< Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next()) {
    const A &arc = aiter.Value();
    const ParenIdIter id_it = paren_id_map_.find(arc.ilabel);

    if (id_it == paren_id_map_.end()) {            // non-paren
      UpdateStateSet(arc.nextstate, &reached_parens, &sources_by_paren);
      continue;
    }

    const Label paren_id = id_it->second;
    if (arc.ilabel != parens_[paren_id].first) {   // close paren
      // 's' is itself a source of this close paren; remember the arc.
      reached_parens.insert(paren_id);
      sources_by_paren[paren_id].insert(s);
      const ParenState<A> arc_key(paren_id, s);
      paren_arc_multimap_.insert(make_pair(arc_key, arc));
      continue;
    }

    // Open paren: inherit from the destinations of the matching close
    // paren arcs reachable from the open paren's destination.
    for (SetIterator close_srcs = FindStates(paren_id, arc.nextstate);
         !close_srcs.Done(); close_srcs.Next()) {
      for (ParenArcIterator close_arcs =
               FindParenArcs(paren_id, close_srcs.Element());
           !close_arcs.Done(); close_arcs.Next()) {
        const A &close_arc = close_arcs.Value();
        UpdateStateSet(close_arc.nextstate, &reached_parens,
                       &sources_by_paren);
      }
    }
  }

  // Publish the gathered sets; each set is flattened in ascending order
  // before being interned in state_sets_.
  vector<StateId> flat_set;
  for (typename set<Label>::const_iterator id_iter = reached_parens.begin();
       id_iter != reached_parens.end(); ++id_iter) {
    const Label paren_id = *id_iter;
    paren_multimap_.insert(make_pair(s, paren_id));

    const set<StateId> &sources = sources_by_paren[paren_id];
    flat_set.assign(sources.begin(), sources.end());

    const ParenState<A> set_key(paren_id, s);
    set_map_[set_key] = state_sets_.FindId(flat_set);
  }
}

// Gather state set(s) from state 'nexts'.
//
// For every paren ID recorded as reachable from 'nexts', adds that ID to
// '*paren_set' and merges the states recorded for (paren ID, nexts) into
// '(*state_sets)[paren ID]'. Only the two output arguments are modified.
template <class A>
void PdtParenReachable<A>::UpdateStateSet(
    StateId nexts, set<Label> *paren_set,
    vector< set<StateId> > *state_sets) const {
  ParenIterator reachable = FindParens(nexts);
  while (!reachable.Done()) {
    const Label paren_id = reachable.Value();
    paren_set->insert(paren_id);

    SetIterator members = FindStates(paren_id, nexts);
    while (!members.Done()) {
      (*state_sets)[paren_id].insert(members.Element());
      members.Next();
    }

    reachable.Next();
  }
}


// Store balancing parenthesis data for a PDT. Allows on-the-fly
// construction (e.g. in PdtShortestPath) unlike PdtParenReachable above.
template <class A>
class PdtBalanceData {
 public:
  typedef typename A::StateId StateId;
  typedef typename A::Label Label;

  // Hash set for open parens
  typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet;

  // Maps from open paren destination state to parenthesis ID.
  typedef unordered_multimap<StateId, Label> OpenParenMap;

  // Maps from open paren state to source states of matching close parens
  typedef unordered_multimap<ParenState<A>, StateId,
                        typename ParenState<A>::Hash> CloseParenMap;

  // Maps from open paren state to close source set ID
  typedef unordered_map<ParenState<A>, ssize_t,
                   typename ParenState<A>::Hash> CloseSourceMap;

  typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;

  PdtBalanceData() {}

  void Clear() {
    open_paren_map_.clear();
    close_paren_map_.clear();
  }

  // Adds an open parenthesis with destination state 'open_dest'.
  void OpenInsert(Label paren_id, StateId open_dest) {
    ParenState<A> key(paren_id, open_dest);
    if (!open_paren_set_.count(key)) {
      open_paren_set_.insert(key);
      open_paren_map_.insert(make_pair(open_dest, paren_id));
    }
  }

  // Adds a matching closing parenthesis with source state
  // 'close_source' that balances an open_parenthesis with destination
  // state 'open_dest' if OpenInsert() previously called
  // (o.w. CloseInsert() does nothing).
  void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) {
    ParenState<A> key(paren_id, open_dest);
    if (open_paren_set_.count(key))
      close_paren_map_.insert(make_pair(key, close_source));
  }

  // Find close paren source states matching an open parenthesis.
  // Methods that follow, iterate through those matching states.
  // Should be called only after FinishInsert(open_dest).
  SetIterator Find(Label paren_id, StateId open_dest) {
    ParenState<A> close_key(paren_id, open_dest);
    typename CloseSourceMap::const_iterator id_it =
        close_source_map_.find(close_key);
    if (id_it == close_source_map_.end()) {
      return close_source_sets_.FindSet(-1);
    } else {
      return close_source_sets_.FindSet(id_it->second);
    }
  }

  // Call when all open and close parenthesis insertions wrt open
  // parentheses entering 'open_dest' are finished. Must be called
  // before Find(open_dest). Stores close paren source state sets
  // efficiently.
  void FinishInsert(StateId open_dest) {
    vector<StateId> close_sources;
    for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest);
         oit != open_paren_map_.end() && oit->first == open_dest;) {
      Label paren_id = oit->second;
      close_sources.clear();
      ParenState<A> okey(paren_id, open_dest);
      open_paren_set_.erase(open_paren_set_.find(okey));
      for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey);
           cit != close_paren_map_.end() && cit->first == okey;) {
        close_sources.push_back(cit->second);
        close_paren_map_.erase(cit++);
      }
      sort(close_sources.begin(), close_sources.end());
      typename vector<StateId>::iterator unique_end =
          unique(close_sources.begin(), close_sources.end());
      close_sources.resize(unique_end - close_sources.begin());

      if (!close_sources.empty())
        close_source_map_[okey] = close_source_sets_.FindId(close_sources);
      open_paren_map_.erase(oit++);
    }
  }

  // Return a new balance data object representing the reversed balance
  // information.
  PdtBalanceData<A> *Reverse(StateId num_states,
                               StateId num_split,
                               StateId state_id_shift) const;

 private:
  OpenParenSet open_paren_set_;                      // open par. at dest?

  OpenParenMap open_paren_map_;                      // open parens per state
  ParenState<A> open_dest_;                          // cur open dest. state
  typename OpenParenMap::const_iterator open_iter_;  // cur open parens/state

  CloseParenMap close_paren_map_;                    // close states/open
                                                     //  paren and state

  CloseSourceMap close_source_map_;                  // paren, state to set ID
  mutable Collection<ssize_t, StateId> close_source_sets_;
};

// Return a new balance data object representing the reversed balance
// information.
//
// Every stored (paren ID, open_dest) -> close_source relation is inserted
// into the result with the roles swapped: the close source becomes the open
// destination and vice versa. 'state_id_shift' is added to both state IDs.
//
//   num_states:     close sources in [0, num_states) are processed.
//   num_split:      number of windows the state range is divided into; the
//                   range is handled one window at a time and each window
//                   is finished (FinishInsert) before the next one starts.
//                   Values < 1 are treated as a single window.
//   state_id_shift: offset added to every state ID in the result.
//
// The result is allocated with 'new' and returned as a raw pointer;
// nothing in this function deletes it.
template <class A>
PdtBalanceData<A> *PdtBalanceData<A>::Reverse(
    StateId num_states,
    StateId num_split,
    StateId state_id_shift) const {
  PdtBalanceData<A> *bd = new PdtBalanceData<A>;
  unordered_set<StateId> close_sources;

  // The window size must be at least 1: a zero step (num_split greater
  // than num_states) would never advance the loop below, and a zero
  // num_split would divide by zero.
  StateId split_size = num_split > 0 ? num_states / num_split : num_states;
  if (split_size < 1)
    split_size = 1;

  for (StateId i = 0; i < num_states; i += split_size) {
    close_sources.clear();

    for (typename CloseSourceMap::const_iterator
             sit = close_source_map_.begin();
         sit != close_source_map_.end();
         ++sit) {
      ParenState<A> okey = sit->first;
      StateId open_dest = okey.state_id;
      Label paren_id = okey.paren_id;
      for (SetIterator set_iter = close_source_sets_.FindSet(sit->second);
           !set_iter.Done(); set_iter.Next()) {
        StateId close_source = set_iter.Element();
        // Only handle close sources that fall inside the current window.
        if ((close_source < i) || (close_source >= i + split_size))
          continue;
        close_sources.insert(close_source + state_id_shift);
        bd->OpenInsert(paren_id, close_source + state_id_shift);
        bd->CloseInsert(paren_id, close_source + state_id_shift,
                        open_dest + state_id_shift);
      }
    }

    // Compact everything inserted for this window before moving on.
    for (typename unordered_set<StateId>::const_iterator it
             = close_sources.begin();
         it != close_sources.end();
         ++it) {
      bd->FinishInsert(*it);
    }

  }
  return bd;
}


}  // namespace fst

#endif  // FST_EXTENSIONS_PDT_PAREN_H_