// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: jpr@google.com (Jake Ratkiewicz)
// Convenience file for including all PDT operations at once, and/or
// registering them for new arc types.

#ifndef FST_EXTENSIONS_PDT_PDTSCRIPT_H_
#define FST_EXTENSIONS_PDT_PDTSCRIPT_H_

#include <utility>
using std::pair; using std::make_pair;
#include <vector>
using std::vector;

#include <fst/compose.h>  // for ComposeOptions
#include <fst/util.h>

#include <fst/script/fst-class.h>
#include <fst/script/arg-packs.h>
#include <fst/script/shortest-path.h>

#include <fst/extensions/pdt/compose.h>
#include <fst/extensions/pdt/expand.h>
#include <fst/extensions/pdt/info.h>
#include <fst/extensions/pdt/replace.h>
#include <fst/extensions/pdt/reverse.h>
#include <fst/extensions/pdt/shortest-path.h>


namespace fst {
namespace script {

// PDT COMPOSE

// PdtComposeArgs packs, in order: the two input FSTs, the parenthesis label
// pairs, the output FST, the composition options, and the left_pdt flag that
// selects which Compose overload the templated PdtCompose calls.
typedef args::Package<const FstClass &,
                      const FstClass &,
                      const vector<pair<int64, int64> >&,
                      MutableFstClass *,
                      const ComposeOptions &,
                      bool> PdtComposeArgs;

// Arc-templated implementation of PdtCompose, invoked with an unpacked
// argument bundle. The int64 parenthesis pairs are converted to Arc::Label
// pairs. When arg6 (left_pdt) is true, the parentheses are passed right after
// the first input FST; otherwise they are passed right after the second.
template<class Arc>
void PdtCompose(PdtComposeArgs *args) {
  typedef typename Arc::Label Label;

  const Fst<Arc> &first_fst = *(args->arg1.GetFst<Arc>());
  const Fst<Arc> &second_fst = *(args->arg2.GetFst<Arc>());
  MutableFst<Arc> *out_fst = args->arg4->GetMutableFst<Arc>();

  // Element-wise int64 -> Label conversion via pair's converting constructor.
  vector<pair<Label, Label> > parens(args->arg3.begin(), args->arg3.end());

  const ComposeOptions &copts = args->arg5;
  const bool left_pdt = args->arg6;
  if (!left_pdt) {
    Compose(first_fst, second_fst, parens, out_fst, copts);
    return;
  }
  Compose(first_fst, parens, second_fst, out_fst, copts);
}

// FstClass-level entry point (defined out of line, not visible here);
// presumably packs its arguments into PdtComposeArgs and dispatches to the
// templated PdtCompose registered for the inputs' arc type.
void PdtCompose(const FstClass & ifst1,
                const FstClass & ifst2,
                const vector<pair<int64, int64> > &parens,
                MutableFstClass *ofst,
                const ComposeOptions &copts,
                bool left_pdt);

// PDT EXPAND

// Arc-independent options for PdtExpand. The templated PdtExpand forwards
// connect and keep_parentheses to ExpandOptions<Arc> and converts
// weight_threshold to Arc::Weight first.
struct PdtExpandOptions {
  bool connect;                  // Default: true.
  bool keep_parentheses;         // Default: false.
  WeightClass weight_threshold;  // Default: WeightClass::Zero().

  PdtExpandOptions(bool do_connect = true,
                   bool keep_parens = false,
                   WeightClass threshold = WeightClass::Zero())
      : connect(do_connect),
        keep_parentheses(keep_parens),
        weight_threshold(threshold) {}
};

// PdtExpandArgs packs, in order: the input FST, the parenthesis label pairs,
// the output FST, and the expansion options.
typedef args::Package<const FstClass &,
                      const vector<pair<int64, int64> >&,
                      MutableFstClass *, PdtExpandOptions> PdtExpandArgs;

// Arc-templated implementation of PdtExpand. Converts the int64 parenthesis
// pairs to Arc::Label pairs and the WeightClass threshold to Arc::Weight, then
// forwards everything to the PDT Expand via ExpandOptions<Arc>.
// If the threshold cannot be obtained as an Arc::Weight, the null pointer from
// GetWeight is not dereferenced. An error is reported, the output FST gets
// the kError property, and nothing is expanded.
template<class Arc>
void PdtExpand(PdtExpandArgs *args) {
  typedef typename Arc::Weight Weight;

  const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
  MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();

  vector<pair<typename Arc::Label, typename Arc::Label> > parens(
      args->arg2.size());
  for (size_t i = 0; i < parens.size(); ++i) {
    parens[i].first = args->arg2[i].first;
    parens[i].second = args->arg2[i].second;
  }

  // GetWeight<Weight>() returns a pointer. Presumably it is null when the
  // WeightClass does not hold an Arc::Weight (TODO confirm against
  // weight-class.h); dereferencing it unchecked would be undefined behavior.
  const Weight *threshold =
      args->arg4.weight_threshold.GetWeight<Weight>();
  if (!threshold) {
    FSTERROR() << "PdtExpand: Could not convert weight threshold to the "
               << "arc weight type";
    ofst->SetProperties(kError, kError);
    return;
  }

  Expand(fst, parens, ofst,
         ExpandOptions<Arc>(args->arg4.connect,
                            args->arg4.keep_parentheses,
                            *threshold));
}

// FstClass-level entry point taking the full option set (defined out of line,
// not visible here); presumably dispatches to the templated PdtExpand for the
// input's arc type.
void PdtExpand(const FstClass &ifst,
               const vector<pair<int64, int64> > &parens,
               MutableFstClass *ofst, const PdtExpandOptions &opts);

// Convenience overload taking only the connect flag. Presumably equivalent to
// passing PdtExpandOptions(connect) to the overload above (TODO confirm in the
// out-of-line definition).
void PdtExpand(const FstClass &ifst,
               const vector<pair<int64, int64> > &parens,
               MutableFstClass *ofst, bool connect);

// PDT REPLACE

// PdtReplaceArgs packs, in order: the (label, FST) pairs to combine, the
// output FST, the parenthesis label pairs (read by the templated PdtReplace
// and then overwritten with the pairs left by Replace), and the root label.
typedef args::Package<const vector<pair<int64, const FstClass*> > &,
                      MutableFstClass *,
                      vector<pair<int64, int64> > *,
                      const int64 &> PdtReplaceArgs;
// Arc-templated implementation of PdtReplace. Converts the script-level
// (label, FstClass) pairs and parenthesis pairs to their Arc-typed forms and
// calls the PDT Replace. It then writes the parenthesis pairs back into
// *args->arg3 as int64 pairs.
template<class Arc>
void PdtReplace(PdtReplaceArgs *args) {
  typedef typename Arc::Label Label;
  typedef pair<Label, const Fst<Arc> *> LabelFstPair;

  const vector<pair<int64, const FstClass*> > &script_tuples = args->arg1;
  vector<LabelFstPair> tuples;
  tuples.reserve(script_tuples.size());
  for (size_t t = 0; t != script_tuples.size(); ++t) {
    tuples.push_back(LabelFstPair(script_tuples[t].first,
                                  script_tuples[t].second->GetFst<Arc>()));
  }

  MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>();

  vector<pair<int64, int64> > &script_parens = *args->arg3;
  vector<pair<Label, Label> > parens(script_parens.begin(),
                                     script_parens.end());
  Replace(tuples, ofst, &parens, args->arg4);

  // Replace may have changed the pairs; hand them back to the caller.
  script_parens.assign(parens.begin(), parens.end());
}

// FstClass-level entry point (defined out of line, not visible here).
// parens is passed by pointer, presumably so that the pairs written back by
// the templated PdtReplace reach the caller (TODO confirm in the definition).
void PdtReplace(const vector<pair<int64, const FstClass*> > &fst_tuples,
                MutableFstClass *ofst,
                vector<pair<int64, int64> > *parens,
                const int64 &root);

// PDT REVERSE

// PdtReverseArgs packs, in order: the input FST, the parenthesis label pairs,
// and the output FST.
typedef args::Package<const FstClass &,
                      const vector<pair<int64, int64> >&,
                      MutableFstClass *> PdtReverseArgs;

// Arc-templated implementation of PdtReverse. Converts the int64 parenthesis
// pairs to Arc::Label pairs and forwards to the PDT Reverse.
template<class Arc>
void PdtReverse(PdtReverseArgs *args) {
  typedef typename Arc::Label Label;
  const vector<pair<int64, int64> > &script_parens = args->arg2;

  const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
  MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();

  vector<pair<Label, Label> > parens(script_parens.begin(),
                                     script_parens.end());
  Reverse(fst, parens, ofst);
}

// FstClass-level entry point (defined out of line, not visible here);
// presumably dispatches to the templated PdtReverse for the input's arc type.
void PdtReverse(const FstClass &ifst,
                const vector<pair<int64, int64> > &parens,
                MutableFstClass *ofst);


// PDT SHORTESTPATH

struct PdtShortestPathOptions {
  QueueType queue_type;
  bool keep_parentheses;
  bool path_gc;

  PdtShortestPathOptions(QueueType qt = FIFO_QUEUE,
                         bool kp = false, bool gc = true)
      : queue_type(qt), keep_parentheses(kp), path_gc(gc) {}
};

typedef args::Package<const FstClass &,
                      const vector<pair<int64, int64> >&,
                      MutableFstClass *,
                      const PdtShortestPathOptions &> PdtShortestPathArgs;

template<class Arc>
void PdtShortestPath(PdtShortestPathArgs *args) {
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

  const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
  MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
  const PdtShortestPathOptions &opts = args->arg4;


  vector<pair<Label, Label> > parens(args->arg2.size());
  for (size_t i = 0; i < parens.size(); ++i) {
    parens[i].first = args->arg2[i].first;
    parens[i].second = args->arg2[i].second;
  }

  switch (opts.queue_type) {
    default:
      FSTERROR() << "Unknown queue type: " << opts.queue_type;
    case FIFO_QUEUE: {
      typedef FifoQueue<StateId> Queue;
      fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
                                                         opts.path_gc);
      ShortestPath(fst, parens, ofst, spopts);
      return;
    }
    case LIFO_QUEUE: {
      typedef LifoQueue<StateId> Queue;
      fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
                                                         opts.path_gc);
      ShortestPath(fst, parens, ofst, spopts);
      return;
    }
    case STATE_ORDER_QUEUE: {
      typedef StateOrderQueue<StateId> Queue;
      fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
                                                         opts.path_gc);
      ShortestPath(fst, parens, ofst, spopts);
      return;
    }
  }
}

void PdtShortestPath(const FstClass &ifst,
                     const vector<pair<int64, int64> > &parens,
                     MutableFstClass *ofst,
                     const PdtShortestPathOptions &opts =
                     PdtShortestPathOptions());

// PRINT INFO

// PrintPdtInfoArgs packs, in order: the input FST and the parenthesis label
// pairs.
typedef args::Package<const FstClass &,
                      const vector<pair<int64, int64> > &> PrintPdtInfoArgs;

// Arc-templated implementation of PrintPdtInfo. Builds a PdtInfo<Arc> from the
// input FST and its parenthesis pairs (converted to Arc::Label), then hands it
// to the PdtInfo-printing overload.
template<class Arc>
void PrintPdtInfo(PrintPdtInfoArgs *args) {
  typedef typename Arc::Label Label;
  const vector<pair<int64, int64> > &script_parens = args->arg2;

  const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
  vector<pair<Label, Label> > parens(script_parens.begin(),
                                     script_parens.end());

  PdtInfo<Arc> pdtinfo(fst, parens);
  PrintPdtInfo(pdtinfo);
}

// FstClass-level entry point (defined out of line, not visible here);
// presumably dispatches to the templated PrintPdtInfo for the input's arc
// type.
void PrintPdtInfo(const FstClass &ifst,
                  const vector<pair<int64, int64> > &parens);

}  // namespace script
}  // namespace fst


// Registers the arc-templated script operations declared in this header for
// ArcType, one REGISTER_FST_OPERATION per operation: PdtCompose, PdtExpand,
// PdtReplace, PdtReverse, PdtShortestPath and PrintPdtInfo. The last expansion
// has no trailing semicolon, so the caller supplies it, e.g.
// REGISTER_FST_PDT_OPERATIONS(SomeArc);
#define REGISTER_FST_PDT_OPERATIONS(ArcType)                                \
  REGISTER_FST_OPERATION(PdtCompose, ArcType, PdtComposeArgs);              \
  REGISTER_FST_OPERATION(PdtExpand, ArcType, PdtExpandArgs);                \
  REGISTER_FST_OPERATION(PdtReplace, ArcType, PdtReplaceArgs);              \
  REGISTER_FST_OPERATION(PdtReverse, ArcType, PdtReverseArgs);              \
  REGISTER_FST_OPERATION(PdtShortestPath, ArcType, PdtShortestPathArgs);    \
  REGISTER_FST_OPERATION(PrintPdtInfo, ArcType, PrintPdtInfoArgs)
#endif  // FST_EXTENSIONS_PDT_PDTSCRIPT_H_