// reweight.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: allauzen@google.com (Cyril Allauzen) // // \file // Function to reweight an FST. #ifndef FST_LIB_REWEIGHT_H__ #define FST_LIB_REWEIGHT_H__ #include <vector> using std::vector; #include <fst/mutable-fst.h> namespace fst { enum ReweightType { REWEIGHT_TO_INITIAL, REWEIGHT_TO_FINAL }; // Reweight FST according to the potentials defined by the POTENTIAL // vector in the direction defined by TYPE. Weight needs to be left // distributive when reweighting towards the initial state and right // distributive when reweighting towards the final states. // // An arc of weight w, with an origin state of potential p and // destination state of potential q, is reweighted by p\wq when // reweighting towards the initial state and by pw/q when reweighting // towards the final states. template <class Arc> void Reweight(MutableFst<Arc> *fst, const vector<typename Arc::Weight> &potential, ReweightType type) { typedef typename Arc::Weight Weight; if (fst->NumStates() == 0) return; if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) { FSTERROR() << "Reweight: Reweighting to the final states requires " << "Weight to be right distributive: " << Weight::Type(); fst->SetProperties(kError, kError); return; } if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) { FSTERROR() << "Reweight: Reweighting to the initial state requires " << "Weight to be left distributive: " << Weight::Type(); fst->SetProperties(kError, kError); return; } StateIterator< MutableFst<Arc> > sit(*fst); for (; !sit.Done(); sit.Next()) { typename Arc::StateId state = sit.Value(); if (state == potential.size()) break; typename Arc::Weight weight = potential[state]; if (weight != Weight::Zero()) { for (MutableArcIterator< MutableFst<Arc> > ait(fst, state); !ait.Done(); ait.Next()) { Arc arc = ait.Value(); if (arc.nextstate >= potential.size()) continue; typename Arc::Weight nextweight = potential[arc.nextstate]; if (nextweight == Weight::Zero()) continue; if (type == REWEIGHT_TO_INITIAL) arc.weight = Divide(Times(arc.weight, nextweight), weight, DIVIDE_LEFT); if (type == REWEIGHT_TO_FINAL) arc.weight = Divide(Times(weight, arc.weight), nextweight, DIVIDE_RIGHT); ait.SetValue(arc); } if (type == REWEIGHT_TO_INITIAL) fst->SetFinal(state, Divide(fst->Final(state), weight, DIVIDE_LEFT)); } if (type == REWEIGHT_TO_FINAL) fst->SetFinal(state, Times(weight, fst->Final(state))); } // This handles elements past the end of the potentials array. for (; !sit.Done(); sit.Next()) { typename Arc::StateId state = sit.Value(); if (type == REWEIGHT_TO_FINAL) fst->SetFinal(state, Times(Weight::Zero(), fst->Final(state))); } typename Arc::Weight startweight = fst->Start() < potential.size() ? potential[fst->Start()] : Weight::Zero(); if ((startweight != Weight::One()) && (startweight != Weight::Zero())) { if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) { typename Arc::StateId state = fst->Start(); for (MutableArcIterator< MutableFst<Arc> > ait(fst, state); !ait.Done(); ait.Next()) { Arc arc = ait.Value(); if (type == REWEIGHT_TO_INITIAL) arc.weight = Times(startweight, arc.weight); else arc.weight = Times( Divide(Weight::One(), startweight, DIVIDE_RIGHT), arc.weight); ait.SetValue(arc); } if (type == REWEIGHT_TO_INITIAL) fst->SetFinal(state, Times(startweight, fst->Final(state))); else fst->SetFinal(state, Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT), fst->Final(state))); } else { typename Arc::StateId state = fst->AddState(); Weight w = type == REWEIGHT_TO_INITIAL ? startweight : Divide(Weight::One(), startweight, DIVIDE_RIGHT); Arc arc(0, 0, w, fst->Start()); fst->AddArc(state, arc); fst->SetStart(state); } } fst->SetProperties(ReweightProperties( fst->Properties(kFstProperties, false)), kFstProperties); } } // namespace fst #endif // FST_LIB_REWEIGHT_H_