// reweight.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)
//
// \file
// Function to reweight an FST.

#ifndef FST_LIB_REWEIGHT_H__
#define FST_LIB_REWEIGHT_H__

#include <vector>
using std::vector;

#include <fst/mutable-fst.h>


namespace fst {

// Selects the direction in which Reweight() applies the potentials.
enum ReweightType {
  REWEIGHT_TO_INITIAL,  // Reweight towards the initial state.
  REWEIGHT_TO_FINAL     // Reweight towards the final states.
};

// Reweights FST in place using the per-state potentials in POTENTIAL
// (indexed by state id), in the direction given by TYPE. Weight must
// be left distributive to reweight towards the initial state and
// right distributive to reweight towards the final states; if the
// required property is missing, an error is logged, kError is set on
// FST and nothing else is changed. An FST with no states is left
// untouched.
//
// For an arc of weight w whose origin state has potential p and whose
// destination state has potential q, the new weight is p\wq when
// reweighting towards the initial state and pw/q when reweighting
// towards the final states. An arc is left unchanged when p is
// Zero(), when q is Zero(), or when the destination state has no
// entry in POTENTIAL.
//
// Final weights: towards the initial state, the final weight of a
// state with non-Zero() potential p is left-divided by p. Towards
// the final states, the final weight is multiplied on the left by p,
// where p is taken to be Zero() for states past the end of POTENTIAL.
//
// If the potential of the start state is neither One() nor Zero(),
// it is compensated for at the start state: directly on its arcs and
// final weight when the FST is initial acyclic, otherwise through a
// new start state with an arc (labels 0, 0) to the old start state.
template <class Arc>
void Reweight(MutableFst<Arc> *fst,
              const vector<typename Arc::Weight> &potential,
              ReweightType type) {
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;
  typedef MutableArcIterator< MutableFst<Arc> > ArcEditor;

  // Nothing to do for an FST without states.
  if (fst->NumStates() == 0)
    return;

  const bool to_initial = (type == REWEIGHT_TO_INITIAL);
  const bool to_final = (type == REWEIGHT_TO_FINAL);

  // Check that the semiring supports the requested direction.
  if (to_final && !(Weight::Properties() & kRightSemiring)) {
    FSTERROR() << "Reweight: Reweighting to the final states requires "
               << "Weight to be right distributive: "
               << Weight::Type();
    fst->SetProperties(kError, kError);
    return;
  }

  if (to_initial && !(Weight::Properties() & kLeftSemiring)) {
    FSTERROR() << "Reweight: Reweighting to the initial state requires "
               << "Weight to be left distributive: "
               << Weight::Type();
    fst->SetProperties(kError, kError);
    return;
  }

  StateIterator< MutableFst<Arc> > siter(*fst);

  // First pass: states that have an entry in the potentials vector.
  // The iterator is deliberately left positioned on the first state
  // without an entry so that the second pass can resume from there.
  while (!siter.Done()) {
    const StateId s = siter.Value();
    if (s == potential.size())
      break;
    const Weight &p = potential[s];
    if (p != Weight::Zero()) {
      for (ArcEditor aiter(fst, s); !aiter.Done(); aiter.Next()) {
        Arc arc = aiter.Value();
        // Leave the arc alone if its destination has no potential
        // entry or a Zero() potential. The unsigned comparison also
        // rejects negative (invalid) destination ids.
        if (arc.nextstate >= potential.size() ||
            potential[arc.nextstate] == Weight::Zero())
          continue;
        const Weight &q = potential[arc.nextstate];
        if (to_initial)
          arc.weight = Divide(Times(arc.weight, q), p, DIVIDE_LEFT);
        if (to_final)
          arc.weight = Divide(Times(p, arc.weight), q, DIVIDE_RIGHT);
        aiter.SetValue(arc);
      }
      if (to_initial)
        fst->SetFinal(s, Divide(fst->Final(s), p, DIVIDE_LEFT));
    }
    // Applied even when p is Zero().
    if (to_final)
      fst->SetFinal(s, Times(p, fst->Final(s)));
    siter.Next();
  }

  // Second pass: states past the end of the potentials vector are
  // treated as having potential Zero().
  if (to_final) {
    for (; !siter.Done(); siter.Next()) {
      const StateId s = siter.Value();
      fst->SetFinal(s, Times(Weight::Zero(), fst->Final(s)));
    }
  }

  // Compensate for the potential of the start state. A start state
  // without an entry in the potentials vector (including an invalid,
  // negative start id, via the unsigned comparison) counts as Zero().
  const Weight start_potential = fst->Start() < potential.size() ?
      potential[fst->Start()] : Weight::Zero();
  if ((start_potential != Weight::One()) &&
      (start_potential != Weight::Zero())) {
    const Weight factor = to_initial ?
        start_potential :
        Divide(Weight::One(), start_potential, DIVIDE_RIGHT);
    if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
      // The start state has no incoming arcs: fold the factor into
      // its outgoing arcs and its final weight.
      const StateId start = fst->Start();
      for (ArcEditor aiter(fst, start); !aiter.Done(); aiter.Next()) {
        Arc arc = aiter.Value();
        arc.weight = Times(factor, arc.weight);
        aiter.SetValue(arc);
      }
      fst->SetFinal(start, Times(factor, fst->Final(start)));
    } else {
      // Otherwise add a new start state that carries the factor on a
      // single arc to the old start state.
      const StateId super_start = fst->AddState();
      fst->AddArc(super_start, Arc(0, 0, factor, fst->Start()));
      fst->SetStart(super_start);
    }
  }

  fst->SetProperties(ReweightProperties(
                         fst->Properties(kFstProperties, false)),
                     kFstProperties);
}

}  // namespace fst

#endif  // FST_LIB_REWEIGHT_H__