// Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: jpr@google.com (Jake Ratkiewicz) // Convenience file for including all PDT operations at once, and/or // registering them for new arc types. #ifndef FST_EXTENSIONS_PDT_PDTSCRIPT_H_ #define FST_EXTENSIONS_PDT_PDTSCRIPT_H_ #include <utility> using std::pair; using std::make_pair; #include <vector> using std::vector; #include <fst/compose.h> // for ComposeOptions #include <fst/util.h> #include <fst/script/fst-class.h> #include <fst/script/arg-packs.h> #include <fst/script/shortest-path.h> #include <fst/extensions/pdt/compose.h> #include <fst/extensions/pdt/expand.h> #include <fst/extensions/pdt/info.h> #include <fst/extensions/pdt/replace.h> #include <fst/extensions/pdt/reverse.h> #include <fst/extensions/pdt/shortest-path.h> namespace fst { namespace script { // PDT COMPOSE typedef args::Package<const FstClass &, const FstClass &, const vector<pair<int64, int64> >&, MutableFstClass *, const ComposeOptions &, bool> PdtComposeArgs; template<class Arc> void PdtCompose(PdtComposeArgs *args) { const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>()); const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>()); MutableFst<Arc> *ofst = args->arg4->GetMutableFst<Arc>(); vector<pair<typename Arc::Label, typename Arc::Label> > parens( args->arg3.size()); for (size_t i = 0; i < parens.size(); ++i) { parens[i].first = args->arg3[i].first; parens[i].second = args->arg3[i].second; } if (args->arg6) { Compose(ifst1, parens, ifst2, ofst, args->arg5); } else { Compose(ifst1, ifst2, parens, ofst, args->arg5); } } void PdtCompose(const FstClass & ifst1, const FstClass & ifst2, const vector<pair<int64, int64> > &parens, MutableFstClass *ofst, const ComposeOptions &copts, bool left_pdt); // PDT EXPAND struct PdtExpandOptions { bool connect; bool keep_parentheses; WeightClass weight_threshold; PdtExpandOptions(bool c = true, bool k = false, WeightClass w = WeightClass::Zero()) : connect(c), keep_parentheses(k), weight_threshold(w) {} }; typedef args::Package<const FstClass &, const vector<pair<int64, int64> >&, MutableFstClass *, PdtExpandOptions> PdtExpandArgs; template<class Arc> void PdtExpand(PdtExpandArgs *args) { const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); vector<pair<typename Arc::Label, typename Arc::Label> > parens( args->arg2.size()); for (size_t i = 0; i < parens.size(); ++i) { parens[i].first = args->arg2[i].first; parens[i].second = args->arg2[i].second; } Expand(fst, parens, ofst, ExpandOptions<Arc>( args->arg4.connect, args->arg4.keep_parentheses, *(args->arg4.weight_threshold.GetWeight<typename Arc::Weight>()))); } void PdtExpand(const FstClass &ifst, const vector<pair<int64, int64> > &parens, MutableFstClass *ofst, const PdtExpandOptions &opts); void PdtExpand(const FstClass &ifst, const vector<pair<int64, int64> > &parens, MutableFstClass *ofst, bool connect); // PDT REPLACE typedef args::Package<const vector<pair<int64, const FstClass*> > &, MutableFstClass *, vector<pair<int64, int64> > *, const int64 &> PdtReplaceArgs; template<class Arc> void PdtReplace(PdtReplaceArgs *args) { vector<pair<typename Arc::Label, const Fst<Arc> *> > tuples( args->arg1.size()); for (size_t i = 0; i < tuples.size(); ++i) { tuples[i].first = args->arg1[i].first; tuples[i].second = (args->arg1[i].second)->GetFst<Arc>(); } MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>(); vector<pair<typename Arc::Label, typename Arc::Label> > parens( args->arg3->size()); for (size_t i = 0; i < parens.size(); ++i) { parens[i].first = args->arg3->at(i).first; parens[i].second = args->arg3->at(i).second; } Replace(tuples, ofst, &parens, args->arg4); // now copy parens back args->arg3->resize(parens.size()); for (size_t i = 0; i < parens.size(); ++i) { (*args->arg3)[i].first = parens[i].first; (*args->arg3)[i].second = parens[i].second; } } void PdtReplace(const vector<pair<int64, const FstClass*> > &fst_tuples, MutableFstClass *ofst, vector<pair<int64, int64> > *parens, const int64 &root); // PDT REVERSE typedef args::Package<const FstClass &, const vector<pair<int64, int64> >&, MutableFstClass *> PdtReverseArgs; template<class Arc> void PdtReverse(PdtReverseArgs *args) { const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); vector<pair<typename Arc::Label, typename Arc::Label> > parens( args->arg2.size()); for (size_t i = 0; i < parens.size(); ++i) { parens[i].first = args->arg2[i].first; parens[i].second = args->arg2[i].second; } Reverse(fst, parens, ofst); } void PdtReverse(const FstClass &ifst, const vector<pair<int64, int64> > &parens, MutableFstClass *ofst); // PDT SHORTESTPATH struct PdtShortestPathOptions { QueueType queue_type; bool keep_parentheses; bool path_gc; PdtShortestPathOptions(QueueType qt = FIFO_QUEUE, bool kp = false, bool gc = true) : queue_type(qt), keep_parentheses(kp), path_gc(gc) {} }; typedef args::Package<const FstClass &, const vector<pair<int64, int64> >&, MutableFstClass *, const PdtShortestPathOptions &> PdtShortestPathArgs; template<class Arc> void PdtShortestPath(PdtShortestPathArgs *args) { typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>(); const PdtShortestPathOptions &opts = args->arg4; vector<pair<Label, Label> > parens(args->arg2.size()); for (size_t i = 0; i < parens.size(); ++i) { parens[i].first = args->arg2[i].first; parens[i].second = args->arg2[i].second; } switch (opts.queue_type) { default: FSTERROR() << "Unknown queue type: " << opts.queue_type; case FIFO_QUEUE: { typedef FifoQueue<StateId> Queue; fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses, opts.path_gc); ShortestPath(fst, parens, ofst, spopts); return; } case LIFO_QUEUE: { typedef LifoQueue<StateId> Queue; fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses, opts.path_gc); ShortestPath(fst, parens, ofst, spopts); return; } case STATE_ORDER_QUEUE: { typedef StateOrderQueue<StateId> Queue; fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses, opts.path_gc); ShortestPath(fst, parens, ofst, spopts); return; } } } void PdtShortestPath(const FstClass &ifst, const vector<pair<int64, int64> > &parens, MutableFstClass *ofst, const PdtShortestPathOptions &opts = PdtShortestPathOptions()); // PRINT INFO typedef args::Package<const FstClass &, const vector<pair<int64, int64> > &> PrintPdtInfoArgs; template<class Arc> void PrintPdtInfo(PrintPdtInfoArgs *args) { const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>()); vector<pair<typename Arc::Label, typename Arc::Label> > parens( args->arg2.size()); for (size_t i = 0; i < parens.size(); ++i) { parens[i].first = args->arg2[i].first; parens[i].second = args->arg2[i].second; } PdtInfo<Arc> pdtinfo(fst, parens); PrintPdtInfo(pdtinfo); } void PrintPdtInfo(const FstClass &ifst, const vector<pair<int64, int64> > &parens); } // namespace script } // namespace fst #define REGISTER_FST_PDT_OPERATIONS(ArcType) \ REGISTER_FST_OPERATION(PdtCompose, ArcType, PdtComposeArgs); \ REGISTER_FST_OPERATION(PdtExpand, ArcType, PdtExpandArgs); \ REGISTER_FST_OPERATION(PdtReplace, ArcType, PdtReplaceArgs); \ REGISTER_FST_OPERATION(PdtReverse, ArcType, PdtReverseArgs); \ REGISTER_FST_OPERATION(PdtShortestPath, ArcType, PdtShortestPathArgs); \ REGISTER_FST_OPERATION(PrintPdtInfo, ArcType, PrintPdtInfoArgs) #endif // FST_EXTENSIONS_PDT_PDTSCRIPT_H_