#ifndef __FST_IO_H__ #define __FST_IO_H__ // fst-io.h // This is a copy of the OPENFST SDK application sample files ... // except for the main functions ifdef'ed out // 2007, 2008 Nuance Communications // // print-main.h compile-main.h // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // // \file // Classes and functions to compile a binary Fst from textual input. // Includes helper function for fstcompile.cc that templates the main // on the arc type to support multiple and extensible arc types. #include <fstream> #include <sstream> #include "fst/lib/fst.h" #include "fst/lib/fstlib.h" #include "fst/lib/fst-decl.h" #include "fst/lib/vector-fst.h" #include "fst/lib/arcsort.h" #include "fst/lib/invert.h" namespace fst { template <class A> class FstPrinter { public: typedef A Arc; typedef typename A::StateId StateId; typedef typename A::Label Label; typedef typename A::Weight Weight; FstPrinter(const Fst<A> &fst, const SymbolTable *isyms, const SymbolTable *osyms, const SymbolTable *ssyms, bool accep) : fst_(fst), isyms_(isyms), osyms_(osyms), ssyms_(ssyms), accep_(accep && fst.Properties(kAcceptor, true)), ostrm_(0) {} // Print Fst to an output strm void Print(ostream *ostrm, const string &dest) { ostrm_ = ostrm; dest_ = dest; StateId start = fst_.Start(); if (start == kNoStateId) return; // initial state first PrintState(start); for (StateIterator< Fst<A> > siter(fst_); !siter.Done(); siter.Next()) { StateId s = siter.Value(); if (s != start) PrintState(s); } } private: // Maximum line length in text file. static const int kLineLen = 8096; void PrintId(int64 id, const SymbolTable *syms, const char *name) const { if (syms) { string symbol = syms->Find(id); if (symbol == "") { LOG(ERROR) << "FstPrinter: Integer " << id << " is not mapped to any textual symbol" << ", symbol table = " << syms->Name() << ", destination = " << dest_; exit(1); } *ostrm_ << symbol; } else { *ostrm_ << id; } } void PrintStateId(StateId s) const { PrintId(s, ssyms_, "state ID"); } void PrintILabel(Label l) const { PrintId(l, isyms_, "arc input label"); } void PrintOLabel(Label l) const { PrintId(l, osyms_, "arc output label"); } void PrintState(StateId s) const { bool output = false; for (ArcIterator< Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next()) { Arc arc = aiter.Value(); PrintStateId(s); *ostrm_ << "\t"; PrintStateId(arc.nextstate); *ostrm_ << "\t"; PrintILabel(arc.ilabel); if (!accep_) { *ostrm_ << "\t"; PrintOLabel(arc.olabel); } if (arc.weight != Weight::One()) *ostrm_ << "\t" << arc.weight; *ostrm_ << "\n"; output = true; } Weight final = fst_.Final(s); if (final != Weight::Zero() || !output) { PrintStateId(s); if (final != Weight::One()) { *ostrm_ << "\t" << final; } *ostrm_ << "\n"; } } const Fst<A> &fst_; const SymbolTable *isyms_; // ilabel symbol table const SymbolTable *osyms_; // olabel symbol table const SymbolTable *ssyms_; // slabel symbol table bool accep_; // print as acceptor when possible ostream *ostrm_; // binary FST destination string dest_; // binary FST destination name DISALLOW_EVIL_CONSTRUCTORS(FstPrinter); }; #if 0 // Main function for fstprint templated on the arc type. template <class Arc> int PrintMain(int argc, char **argv, istream &istrm, const FstReadOptions &opts) { Fst<Arc> *fst = Fst<Arc>::Read(istrm, opts); if (!fst) return 1; string dest = "standard output"; ostream *ostrm = &std::cout; if (argc == 3) { dest = argv[2]; ostrm = new ofstream(argv[2]); if (!*ostrm) { LOG(ERROR) << argv[0] << ": Open failed, file = " << argv[2]; return 0; } } ostrm->precision(9); const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0; if (!FLAGS_isymbols.empty() && !FLAGS_numeric) { isyms = SymbolTable::ReadText(FLAGS_isymbols); if (!isyms) exit(1); } if (!FLAGS_osymbols.empty() && !FLAGS_numeric) { osyms = SymbolTable::ReadText(FLAGS_osymbols); if (!osyms) exit(1); } if (!FLAGS_ssymbols.empty() && !FLAGS_numeric) { ssyms = SymbolTable::ReadText(FLAGS_ssymbols); if (!ssyms) exit(1); } if (!isyms && !FLAGS_numeric) isyms = fst->InputSymbols(); if (!osyms && !FLAGS_numeric) osyms = fst->OutputSymbols(); FstPrinter<Arc> fstprinter(*fst, isyms, osyms, ssyms, FLAGS_acceptor); fstprinter.Print(ostrm, dest); if (isyms && !FLAGS_save_isymbols.empty()) isyms->WriteText(FLAGS_save_isymbols); if (osyms && !FLAGS_save_osymbols.empty()) osyms->WriteText(FLAGS_save_osymbols); if (ostrm != &std::cout) delete ostrm; return 0; } #endif template <class A> class FstReader { public: typedef A Arc; typedef typename A::StateId StateId; typedef typename A::Label Label; typedef typename A::Weight Weight; FstReader(istream &istrm, const string &source, const SymbolTable *isyms, const SymbolTable *osyms, const SymbolTable *ssyms, bool accep, bool ikeep, bool okeep, bool nkeep) : nline_(0), source_(source), isyms_(isyms), osyms_(osyms), ssyms_(ssyms), nstates_(0), keep_state_numbering_(nkeep) { char line[kLineLen]; while (istrm.getline(line, kLineLen)) { ++nline_; vector<char *> col; SplitToVector(line, "\n\t ", &col, true); if (col.size() == 0 || col[0][0] == '\0') // empty line continue; if (col.size() > 5 || col.size() > 4 && accep || col.size() == 3 && !accep) { LOG(ERROR) << "FstReader: Bad number of columns, source = " << source_ << ", line = " << nline_; exit(1); } StateId s = StrToStateId(col[0]); while (s >= fst_.NumStates()) fst_.AddState(); if (nline_ == 1) fst_.SetStart(s); Arc arc; StateId d = s; switch (col.size()) { case 1: fst_.SetFinal(s, Weight::One()); break; case 2: fst_.SetFinal(s, StrToWeight(col[1], true)); break; case 3: arc.nextstate = d = StrToStateId(col[1]); arc.ilabel = StrToILabel(col[2]); arc.olabel = arc.ilabel; arc.weight = Weight::One(); fst_.AddArc(s, arc); break; case 4: arc.nextstate = d = StrToStateId(col[1]); arc.ilabel = StrToILabel(col[2]); if (accep) { arc.olabel = arc.ilabel; arc.weight = StrToWeight(col[3], false); } else { arc.olabel = StrToOLabel(col[3]); arc.weight = Weight::One(); } fst_.AddArc(s, arc); break; case 5: arc.nextstate = d = StrToStateId(col[1]); arc.ilabel = StrToILabel(col[2]); arc.olabel = StrToOLabel(col[3]); arc.weight = StrToWeight(col[4], false); fst_.AddArc(s, arc); } while (d >= fst_.NumStates()) fst_.AddState(); } if (ikeep) fst_.SetInputSymbols(isyms); if (okeep) fst_.SetOutputSymbols(osyms); } const VectorFst<A> &Fst() const { return fst_; } private: // Maximum line length in text file. static const int kLineLen = 8096; int64 StrToId(const char *s, const SymbolTable *syms, const char *name) const { int64 n; if (syms) { n = syms->Find(s); if (n < 0) { LOG(ERROR) << "FstReader: Symbol \"" << s << "\" is not mapped to any integer " << name << ", symbol table = " << syms->Name() << ", source = " << source_ << ", line = " << nline_; exit(1); } } else { char *p; n = strtoll(s, &p, 10); if (p < s + strlen(s) || n < 0) { LOG(ERROR) << "FstReader: Bad " << name << " integer = \"" << s << "\", source = " << source_ << ", line = " << nline_; exit(1); } } return n; } StateId StrToStateId(const char *s) { StateId n = StrToId(s, ssyms_, "state ID"); if (keep_state_numbering_) return n; // remap state IDs to make dense set typename hash_map<StateId, StateId>::const_iterator it = states_.find(n); if (it == states_.end()) { states_[n] = nstates_; return nstates_++; } else { return it->second; } } StateId StrToILabel(const char *s) const { return StrToId(s, isyms_, "arc ilabel"); } StateId StrToOLabel(const char *s) const { return StrToId(s, osyms_, "arc olabel"); } Weight StrToWeight(const char *s, bool allow_zero) const { Weight w; istringstream strm(s); strm >> w; if (strm.fail() || !allow_zero && w == Weight::Zero()) { LOG(ERROR) << "FstReader: Bad weight = \"" << s << "\", source = " << source_ << ", line = " << nline_; exit(1); } return w; } VectorFst<A> fst_; size_t nline_; string source_; // text FST source name const SymbolTable *isyms_; // ilabel symbol table const SymbolTable *osyms_; // olabel symbol table const SymbolTable *ssyms_; // slabel symbol table hash_map<StateId, StateId> states_; // state ID map StateId nstates_; // number of seen states bool keep_state_numbering_; DISALLOW_EVIL_CONSTRUCTORS(FstReader); }; #if 0 // Main function for fstcompile templated on the arc type. Last two // arguments unneeded since fstcompile passes the arc type as a flag // unlike the other mains, which infer the arc type from an input Fst. template <class Arc> int CompileMain(int argc, char **argv, istream& /* strm */, const FstReadOptions & /* opts */) { char *ifilename = "standard input"; istream *istrm = &std::cin; if (argc > 1 && strcmp(argv[1], "-") != 0) { ifilename = argv[1]; istrm = new ifstream(ifilename); if (!*istrm) { LOG(ERROR) << argv[0] << ": Open failed, file = " << ifilename; return 1; } } const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0; if (!FLAGS_isymbols.empty()) { isyms = SymbolTable::ReadText(FLAGS_isymbols); if (!isyms) exit(1); } if (!FLAGS_osymbols.empty()) { osyms = SymbolTable::ReadText(FLAGS_osymbols); if (!osyms) exit(1); } if (!FLAGS_ssymbols.empty()) { ssyms = SymbolTable::ReadText(FLAGS_ssymbols); if (!ssyms) exit(1); } FstReader<Arc> fstreader(*istrm, ifilename, isyms, osyms, ssyms, FLAGS_acceptor, FLAGS_keep_isymbols, FLAGS_keep_osymbols, FLAGS_keep_state_numbering); const Fst<Arc> *fst = &fstreader.Fst(); if (FLAGS_fst_type != "vector") { fst = Convert<Arc>(*fst, FLAGS_fst_type); if (!fst) return 1; } fst->Write(argc > 2 ? argv[2] : ""); if (istrm != &std::cin) delete istrm; return 0; } #endif } // namespace fst #endif /* __FST_IO_H__ */