// relabel.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: johans@google.com (Johan Schalkwyk) // // \file // Functions and classes to relabel an Fst (either on input or output) // #ifndef FST_LIB_RELABEL_H__ #define FST_LIB_RELABEL_H__ #include <tr1/unordered_map> using std::tr1::unordered_map; using std::tr1::unordered_multimap; #include <string> #include <utility> using std::pair; using std::make_pair; #include <vector> using std::vector; #include <fst/cache.h> #include <fst/test-properties.h> #include <tr1/unordered_map> using std::tr1::unordered_map; using std::tr1::unordered_multimap; namespace fst { // // Relabels either the input labels or output labels. The old to // new labels are specified using a vector of pair<Label,Label>. // Any label associations not specified are assumed to be identity // mapping. // // \param fst input fst, must be mutable // \param ipairs vector of input label pairs indicating old to new mapping // \param opairs vector of output label pairs indicating old to new mapping // template <class A> void Relabel( MutableFst<A> *fst, const vector<pair<typename A::Label, typename A::Label> >& ipairs, const vector<pair<typename A::Label, typename A::Label> >& opairs) { typedef typename A::StateId StateId; typedef typename A::Label Label; uint64 props = fst->Properties(kFstProperties, false); // construct label to label hash. unordered_map<Label, Label> input_map; for (size_t i = 0; i < ipairs.size(); ++i) { input_map[ipairs[i].first] = ipairs[i].second; } unordered_map<Label, Label> output_map; for (size_t i = 0; i < opairs.size(); ++i) { output_map[opairs[i].first] = opairs[i].second; } for (StateIterator<MutableFst<A> > siter(*fst); !siter.Done(); siter.Next()) { StateId s = siter.Value(); for (MutableArcIterator<MutableFst<A> > aiter(fst, s); !aiter.Done(); aiter.Next()) { A arc = aiter.Value(); // relabel input // only relabel if relabel pair defined typename unordered_map<Label, Label>::iterator it = input_map.find(arc.ilabel); if (it != input_map.end()) { if (it->second == kNoLabel) { FSTERROR() << "Input symbol id " << arc.ilabel << " missing from target vocabulary"; fst->SetProperties(kError, kError); return; } arc.ilabel = it->second; } // relabel output it = output_map.find(arc.olabel); if (it != output_map.end()) { if (it->second == kNoLabel) { FSTERROR() << "Output symbol id " << arc.olabel << " missing from target vocabulary"; fst->SetProperties(kError, kError); return; } arc.olabel = it->second; } aiter.SetValue(arc); } } fst->SetProperties(RelabelProperties(props), kFstProperties); } // // Relabels either the input labels or output labels. The old to // new labels mappings are specified using an input Symbol set. // Any label associations not specified are assumed to be identity // mapping. // // \param fst input fst, must be mutable // \param new_isymbols symbol set indicating new mapping of input symbols // \param new_osymbols symbol set indicating new mapping of output symbols // template<class A> void Relabel(MutableFst<A> *fst, const SymbolTable* new_isymbols, const SymbolTable* new_osymbols) { Relabel(fst, fst->InputSymbols(), new_isymbols, true, fst->OutputSymbols(), new_osymbols, true); } template<class A> void Relabel(MutableFst<A> *fst, const SymbolTable* old_isymbols, const SymbolTable* new_isymbols, bool attach_new_isymbols, const SymbolTable* old_osymbols, const SymbolTable* new_osymbols, bool attach_new_osymbols) { typedef typename A::StateId StateId; typedef typename A::Label Label; vector<pair<Label, Label> > ipairs; if (old_isymbols && new_isymbols) { for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); syms_iter.Next()) { string isymbol = syms_iter.Symbol(); int isymbol_val = syms_iter.Value(); int new_isymbol_val = new_isymbols->Find(isymbol); ipairs.push_back(make_pair(isymbol_val, new_isymbol_val)); } if (attach_new_isymbols) fst->SetInputSymbols(new_isymbols); } vector<pair<Label, Label> > opairs; if (old_osymbols && new_osymbols) { for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); syms_iter.Next()) { string osymbol = syms_iter.Symbol(); int osymbol_val = syms_iter.Value(); int new_osymbol_val = new_osymbols->Find(osymbol); opairs.push_back(make_pair(osymbol_val, new_osymbol_val)); } if (attach_new_osymbols) fst->SetOutputSymbols(new_osymbols); } // call relabel using vector of relabel pairs. Relabel(fst, ipairs, opairs); } typedef CacheOptions RelabelFstOptions; template <class A> class RelabelFst; // // \class RelabelFstImpl // \brief Implementation for delayed relabeling // // Relabels an FST from one symbol set to another. Relabeling // can either be on input or output space. RelabelFst implements // a delayed version of the relabel. Arcs are relabeled on the fly // and not cached. I.e each request is recomputed. // template<class A> class RelabelFstImpl : public CacheImpl<A> { friend class StateIterator< RelabelFst<A> >; public: using FstImpl<A>::SetType; using FstImpl<A>::SetProperties; using FstImpl<A>::WriteHeader; using FstImpl<A>::SetInputSymbols; using FstImpl<A>::SetOutputSymbols; using CacheImpl<A>::PushArc; using CacheImpl<A>::HasArcs; using CacheImpl<A>::HasFinal; using CacheImpl<A>::HasStart; using CacheImpl<A>::SetArcs; using CacheImpl<A>::SetFinal; using CacheImpl<A>::SetStart; typedef A Arc; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef CacheState<A> State; RelabelFstImpl(const Fst<A>& fst, const vector<pair<Label, Label> >& ipairs, const vector<pair<Label, Label> >& opairs, const RelabelFstOptions &opts) : CacheImpl<A>(opts), fst_(fst.Copy()), relabel_input_(false), relabel_output_(false) { uint64 props = fst.Properties(kCopyProperties, false); SetProperties(RelabelProperties(props)); SetType("relabel"); // create input label map if (ipairs.size() > 0) { for (size_t i = 0; i < ipairs.size(); ++i) { input_map_[ipairs[i].first] = ipairs[i].second; } relabel_input_ = true; } // create output label map if (opairs.size() > 0) { for (size_t i = 0; i < opairs.size(); ++i) { output_map_[opairs[i].first] = opairs[i].second; } relabel_output_ = true; } } RelabelFstImpl(const Fst<A>& fst, const SymbolTable* old_isymbols, const SymbolTable* new_isymbols, const SymbolTable* old_osymbols, const SymbolTable* new_osymbols, const RelabelFstOptions &opts) : CacheImpl<A>(opts), fst_(fst.Copy()), relabel_input_(false), relabel_output_(false) { SetType("relabel"); uint64 props = fst.Properties(kCopyProperties, false); SetProperties(RelabelProperties(props)); SetInputSymbols(old_isymbols); SetOutputSymbols(old_osymbols); if (old_isymbols && new_isymbols && old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) { for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); syms_iter.Next()) { input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol()); } SetInputSymbols(new_isymbols); relabel_input_ = true; } if (old_osymbols && new_osymbols && old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) { for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); syms_iter.Next()) { output_map_[syms_iter.Value()] = new_osymbols->Find(syms_iter.Symbol()); } SetOutputSymbols(new_osymbols); relabel_output_ = true; } } RelabelFstImpl(const RelabelFstImpl<A>& impl) : CacheImpl<A>(impl), fst_(impl.fst_->Copy(true)), input_map_(impl.input_map_), output_map_(impl.output_map_), relabel_input_(impl.relabel_input_), relabel_output_(impl.relabel_output_) { SetType("relabel"); SetProperties(impl.Properties(), kCopyProperties); SetInputSymbols(impl.InputSymbols()); SetOutputSymbols(impl.OutputSymbols()); } ~RelabelFstImpl() { delete fst_; } StateId Start() { if (!HasStart()) { StateId s = fst_->Start(); SetStart(s); } return CacheImpl<A>::Start(); } Weight Final(StateId s) { if (!HasFinal(s)) { SetFinal(s, fst_->Final(s)); } return CacheImpl<A>::Final(s); } size_t NumArcs(StateId s) { if (!HasArcs(s)) { Expand(s); } return CacheImpl<A>::NumArcs(s); } size_t NumInputEpsilons(StateId s) { if (!HasArcs(s)) { Expand(s); } return CacheImpl<A>::NumInputEpsilons(s); } size_t NumOutputEpsilons(StateId s) { if (!HasArcs(s)) { Expand(s); } return CacheImpl<A>::NumOutputEpsilons(s); } uint64 Properties() const { return Properties(kFstProperties); } // Set error if found; return FST impl properties. uint64 Properties(uint64 mask) const { if ((mask & kError) && fst_->Properties(kError, false)) SetProperties(kError, kError); return FstImpl<Arc>::Properties(mask); } void InitArcIterator(StateId s, ArcIteratorData<A>* data) { if (!HasArcs(s)) { Expand(s); } CacheImpl<A>::InitArcIterator(s, data); } void Expand(StateId s) { for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) { A arc = aiter.Value(); // relabel input if (relabel_input_) { typename unordered_map<Label, Label>::iterator it = input_map_.find(arc.ilabel); if (it != input_map_.end()) { arc.ilabel = it->second; } } // relabel output if (relabel_output_) { typename unordered_map<Label, Label>::iterator it = output_map_.find(arc.olabel); if (it != output_map_.end()) { arc.olabel = it->second; } } PushArc(s, arc); } SetArcs(s); } private: const Fst<A> *fst_; unordered_map<Label, Label> input_map_; unordered_map<Label, Label> output_map_; bool relabel_input_; bool relabel_output_; void operator=(const RelabelFstImpl<A> &); // disallow }; // // \class RelabelFst // \brief Delayed implementation of arc relabeling // // This class attaches interface to implementation and handles // reference counting, delegating most methods to ImplToFst. template <class A> class RelabelFst : public ImplToFst< RelabelFstImpl<A> > { public: friend class ArcIterator< RelabelFst<A> >; friend class StateIterator< RelabelFst<A> >; typedef A Arc; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename A::StateId StateId; typedef CacheState<A> State; typedef RelabelFstImpl<A> Impl; RelabelFst(const Fst<A>& fst, const vector<pair<Label, Label> >& ipairs, const vector<pair<Label, Label> >& opairs) : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {} RelabelFst(const Fst<A>& fst, const vector<pair<Label, Label> >& ipairs, const vector<pair<Label, Label> >& opairs, const RelabelFstOptions &opts) : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {} RelabelFst(const Fst<A>& fst, const SymbolTable* new_isymbols, const SymbolTable* new_osymbols) : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols, fst.OutputSymbols(), new_osymbols, RelabelFstOptions())) {} RelabelFst(const Fst<A>& fst, const SymbolTable* new_isymbols, const SymbolTable* new_osymbols, const RelabelFstOptions &opts) : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols, fst.OutputSymbols(), new_osymbols, opts)) {} RelabelFst(const Fst<A>& fst, const SymbolTable* old_isymbols, const SymbolTable* new_isymbols, const SymbolTable* old_osymbols, const SymbolTable* new_osymbols) : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols, new_osymbols, RelabelFstOptions())) {} RelabelFst(const Fst<A>& fst, const SymbolTable* old_isymbols, const SymbolTable* new_isymbols, const SymbolTable* old_osymbols, const SymbolTable* new_osymbols, const RelabelFstOptions &opts) : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols, new_osymbols, opts)) {} // See Fst<>::Copy() for doc. RelabelFst(const RelabelFst<A> &fst, bool safe = false) : ImplToFst<Impl>(fst, safe) {} // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc. virtual RelabelFst<A> *Copy(bool safe = false) const { return new RelabelFst<A>(*this, safe); } virtual void InitStateIterator(StateIteratorData<A> *data) const; virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { return GetImpl()->InitArcIterator(s, data); } private: // Makes visible to friends. Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } void operator=(const RelabelFst<A> &fst); // disallow }; // Specialization for RelabelFst. template<class A> class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> { public: typedef typename A::StateId StateId; explicit StateIterator(const RelabelFst<A> &fst) : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {} bool Done() const { return siter_.Done(); } StateId Value() const { return s_; } void Next() { if (!siter_.Done()) { ++s_; siter_.Next(); } } void Reset() { s_ = 0; siter_.Reset(); } private: bool Done_() const { return Done(); } StateId Value_() const { return Value(); } void Next_() { Next(); } void Reset_() { Reset(); } const RelabelFstImpl<A> *impl_; StateIterator< Fst<A> > siter_; StateId s_; DISALLOW_COPY_AND_ASSIGN(StateIterator); }; // Specialization for RelabelFst. template <class A> class ArcIterator< RelabelFst<A> > : public CacheArcIterator< RelabelFst<A> > { public: typedef typename A::StateId StateId; ArcIterator(const RelabelFst<A> &fst, StateId s) : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) { if (!fst.GetImpl()->HasArcs(s)) fst.GetImpl()->Expand(s); } private: DISALLOW_COPY_AND_ASSIGN(ArcIterator); }; template <class A> inline void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const { data->base = new StateIterator< RelabelFst<A> >(*this); } // Useful alias when using StdArc. typedef RelabelFst<StdArc> StdRelabelFst; } // namespace fst #endif // FST_LIB_RELABEL_H__