// replace.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: riley@google.com (Michael Riley) // // \file // Recursively replace Fst arcs with other Fst(s) returning a PDT. #ifndef FST_EXTENSIONS_PDT_REPLACE_H__ #define FST_EXTENSIONS_PDT_REPLACE_H__ #include <fst/replace.h> namespace fst { // Hash to paren IDs template <typename S> struct ReplaceParenHash { size_t operator()(const pair<size_t, S> &p) const { return p.first + p.second * kPrime; } private: static const size_t kPrime = 7853; }; template <typename S> const size_t ReplaceParenHash<S>::kPrime; // Builds a pushdown transducer (PDT) from an RTN specification // identical to that in fst/lib/replace.h. The result is a PDT // encoded as the FST 'ofst' where some transitions are labeled with // open or close parentheses. To be interpreted as a PDT, the parens // must balance on a path (see PdtExpand()). The open/close // parenthesis label pairs are returned in 'parens'. template <class Arc> void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >& ifst_array, MutableFst<Arc> *ofst, vector<pair<typename Arc::Label, typename Arc::Label> > *parens, typename Arc::Label root) { typedef typename Arc::Label Label; typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; ofst->DeleteStates(); parens->clear(); unordered_map<Label, size_t> label2id; for (size_t i = 0; i < ifst_array.size(); ++i) label2id[ifst_array[i].first] = i; Label max_label = kNoLabel; deque<size_t> non_term_queue; // Queue of non-terminals to replace unordered_set<Label> non_term_set; // Set of non-terminals to replace non_term_queue.push_back(root); non_term_set.insert(root); // PDT state corr. to ith replace FST start state. vector<StateId> fst_start(ifst_array.size(), kNoLabel); // PDT state, weight pairs corr. to ith replace FST final state & weights. vector< vector<pair<StateId, Weight> > > fst_final(ifst_array.size()); // Builds single Fst combining all referenced input Fsts. Leaves in the // non-termnals for now. Tabulate the PDT states that correspond to // the start and final states of the input Fsts. for (StateId soff = 0; !non_term_queue.empty(); soff = ofst->NumStates()) { Label label = non_term_queue.front(); non_term_queue.pop_front(); size_t fst_id = label2id[label]; const Fst<Arc> *ifst = ifst_array[fst_id].second; for (StateIterator< Fst<Arc> > siter(*ifst); !siter.Done(); siter.Next()) { StateId is = siter.Value(); StateId os = ofst->AddState(); if (is == ifst->Start()) { fst_start[fst_id] = os; if (label == root) ofst->SetStart(os); } if (ifst->Final(is) != Weight::Zero()) { if (label == root) ofst->SetFinal(os, ifst->Final(is)); fst_final[fst_id].push_back(make_pair(os, ifst->Final(is))); } for (ArcIterator< Fst<Arc> > aiter(*ifst, is); !aiter.Done(); aiter.Next()) { Arc arc = aiter.Value(); if (max_label == kNoLabel || arc.olabel > max_label) max_label = arc.olabel; typename unordered_map<Label, size_t>::const_iterator it = label2id.find(arc.olabel); if (it != label2id.end()) { size_t nfst_id = it->second; if (ifst_array[nfst_id].second->Start() == -1) continue; if (non_term_set.count(arc.olabel) == 0) { non_term_queue.push_back(arc.olabel); non_term_set.insert(arc.olabel); } } arc.nextstate += soff; ofst->AddArc(os, arc); } } } // Changes each non-terminal transition to an open parenthesis // transition redirected to the PDT state that corresponds to the // start state of the input FST for the non-terminal. Adds close parenthesis // transitions from the PDT states corr. to the final states of the // input FST for the non-terminal to the former destination state of the // non-terminal transition. typedef MutableArcIterator< MutableFst<Arc> > MIter; typedef unordered_map<pair<size_t, StateId >, size_t, ReplaceParenHash<StateId> > ParenMap; // Parenthesis pair ID per fst, state pair. ParenMap paren_map; // # of parenthesis pairs per fst. vector<size_t> nparens(ifst_array.size(), 0); // Initial open parenthesis label Label first_paren = max_label + 1; for (StateIterator< Fst<Arc> > siter(*ofst); !siter.Done(); siter.Next()) { StateId os = siter.Value(); MIter *aiter = new MIter(ofst, os); for (size_t n = 0; !aiter->Done(); aiter->Next(), ++n) { Arc arc = aiter->Value(); typename unordered_map<Label, size_t>::const_iterator lit = label2id.find(arc.olabel); if (lit != label2id.end()) { size_t nfst_id = lit->second; // Get parentheses. Ensures distinct parenthesis pair per // non-terminal and destination state but otherwise reuses them. Label open_paren = kNoLabel, close_paren = kNoLabel; pair<size_t, StateId> paren_key(nfst_id, arc.nextstate); typename ParenMap::const_iterator pit = paren_map.find(paren_key); if (pit != paren_map.end()) { size_t paren_id = pit->second; open_paren = (*parens)[paren_id].first; close_paren = (*parens)[paren_id].second; } else { size_t paren_id = nparens[nfst_id]++; open_paren = first_paren + 2 * paren_id; close_paren = open_paren + 1; paren_map[paren_key] = paren_id; if (paren_id >= parens->size()) parens->push_back(make_pair(open_paren, close_paren)); } // Sets open parenthesis. Arc sarc(open_paren, open_paren, arc.weight, fst_start[nfst_id]); aiter->SetValue(sarc); // Adds close parentheses. for (size_t i = 0; i < fst_final[nfst_id].size(); ++i) { pair<StateId, Weight> &p = fst_final[nfst_id][i]; Arc farc(close_paren, close_paren, p.second, arc.nextstate); ofst->AddArc(p.first, farc); if (os == p.first) { // Invalidated iterator delete aiter; aiter = new MIter(ofst, os); aiter->Seek(n); } } } } delete aiter; } } } // namespace fst #endif // FST_EXTENSIONS_PDT_REPLACE_H__