// push.h
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Author: allauzen@cs.nyu.edu (Cyril Allauzen)
//
// \file
// Functions to reweight/push an FST.
#ifndef FST_LIB_PUSH_H__
#define FST_LIB_PUSH_H__
#include "fst/lib/factor-weight.h"
#include "fst/lib/fst.h"
#include "fst/lib/map.h"
#include "fst/lib/reweight.h"
#include "fst/lib/shortest-distance.h"
namespace fst {
// In-place weight pushing: redistributes the weights of FST in the
// direction selected by TYPE.
//
// When pushing towards the initial state, in the resulting machine
// the sum of the weights of the outgoing transitions and of the final
// weight at every non-initial state equals One(). When pushing
// towards the final states, the same property holds on the reverse
// machine.
//
// Requirements on Weight: left distributive when pushing towards the
// initial state, right distributive when pushing towards the final
// states.
template <class Arc>
void Push(MutableFst<Arc> *fst, ReweightType type) {
  typedef typename Arc::Weight Weight;

  // The flag handed to ShortestDistance is set exactly when pushing
  // towards the initial state.
  const bool toward_initial = (type == REWEIGHT_TO_INITIAL);

  // Per-state distances, used as the potentials for reweighting.
  vector<Weight> potentials;
  ShortestDistance(*fst, &potentials, toward_initial);

  Reweight(fst, potentials, type);
}
// Bit flags for the PTYPE argument of Push(ifst, ofst, ptype); each
// flag occupies its own bit so the two can be OR-ed together.
const uint32 kPushWeights = 1U << 0;  // 0x0001: push weights.
const uint32 kPushLabels = 1U << 1;   // 0x0002: push labels.
// OFST obtained from IFST by pushing weights and/or labels according
// to PTYPE in the direction defined by RTYPE. Weight needs to be
// left distributive when pushing weights towards the initial state
// and right distributive when pushing weights towards the final
// states.
template <class Arc, ReweightType rtype>
void Push(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, uint32 ptype) {
if (ptype == kPushWeights) {
*ofst = ifst;
Push(ofst, rtype);
} else if (ptype & kPushLabels) {
const StringType stype = rtype == REWEIGHT_TO_INITIAL
? STRING_LEFT
: STRING_RIGHT;
vector<typename GallicArc<Arc, stype>::Weight> gdistance;
VectorFst< GallicArc<Arc, stype> > gfst;
Map(ifst, &gfst, ToGallicMapper<Arc, stype>());
if (ptype == (kPushWeights | kPushLabels)) {
ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL);
} else {
MapFst<Arc, Arc, RmWeightMapper<Arc> >
uwfst(ifst, RmWeightMapper<Arc>());
MapFst<Arc, GallicArc<Arc, stype>, ToGallicMapper<Arc, stype> >
guwfst(uwfst, ToGallicMapper<Arc, stype>());
ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL);
}
Reweight(&gfst, gdistance, rtype);
FactorWeightFst< GallicArc<Arc, stype>, GallicFactor<typename Arc::Label,
typename Arc::Weight, stype> > fwfst(gfst);
Map(fwfst, ofst, FromGallicMapper<Arc, stype>());
} else {
*ofst = ifst;
}
}
} // namespace fst
#endif /* FST_LIB_PUSH_H__ */