// Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Authors: allauzen@google.com (Cyril Allauzen) // ttai@google.com (Terry Tai) // jpr@google.com (Jake Ratkiewicz) #ifndef FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_ #define FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_ #include <libgen.h> #include <string> #include <vector> using std::vector; #include <fst/extensions/far/far.h> #include <fst/string.h> namespace fst { // Construct a reader that provides FSTs from a file (stream) either on a // line-by-line basis or on a per-stream basis. Note that the freshly // constructed reader is already set to the first input. // // Sample Usage: // for (StringReader<Arc> reader(...); !reader.Done(); reader.Next()) { // Fst *fst = reader.GetVectorFst(); // } template <class A> class StringReader { public: typedef A Arc; typedef typename A::Label Label; typedef typename A::Weight Weight; typedef typename StringCompiler<A>::TokenType TokenType; enum EntryType { LINE = 1, FILE = 2 }; StringReader(istream &istrm, const string &source, EntryType entry_type, TokenType token_type, bool allow_negative_labels, const SymbolTable *syms = 0, Label unknown_label = kNoStateId) : nline_(0), strm_(istrm), source_(source), entry_type_(entry_type), token_type_(token_type), done_(false), compiler_(token_type, syms, unknown_label, allow_negative_labels) { Next(); // Initialize the reader to the first input. } bool Done() { return done_; } void Next() { VLOG(1) << "Processing source " << source_ << " at line " << nline_; if (!strm_) { // We're done if we have no more input. done_ = true; return; } if (entry_type_ == LINE) { getline(strm_, content_); ++nline_; } else { content_.clear(); string line; while (getline(strm_, line)) { ++nline_; content_.append(line); content_.append("\n"); } } if (!strm_ && content_.empty()) // We're also done if we read off all the done_ = true; // whitespace at the end of a file. } VectorFst<A> *GetVectorFst() { VectorFst<A> *fst = new VectorFst<A>; if (compiler_(content_, fst)) { return fst; } else { delete fst; return NULL; } } CompactFst<A, StringCompactor<A> > *GetCompactFst() { CompactFst<A, StringCompactor<A> > *fst = new CompactFst<A, StringCompactor<A> >; if (compiler_(content_, fst)) { return fst; } else { delete fst; return NULL; } } private: size_t nline_; istream &strm_; string source_; EntryType entry_type_; TokenType token_type_; bool done_; StringCompiler<A> compiler_; string content_; // The actual content of the input stream's next FST. DISALLOW_COPY_AND_ASSIGN(StringReader); }; // Compute the minimal length required to encode each line number as a decimal // number. int KeySize(const char *filename); template <class Arc> void FarCompileStrings(const vector<string> &in_fnames, const string &out_fname, const string &fst_type, const FarType &far_type, int32 generate_keys, FarEntryType fet, FarTokenType tt, const string &symbols_fname, const string &unknown_symbol, bool allow_negative_labels, bool file_list_input, const string &key_prefix, const string &key_suffix) { typename StringReader<Arc>::EntryType entry_type; if (fet == FET_LINE) { entry_type = StringReader<Arc>::LINE; } else if (fet == FET_FILE) { entry_type = StringReader<Arc>::FILE; } else { FSTERROR() << "FarCompileStrings: unknown entry type"; return; } typename StringCompiler<Arc>::TokenType token_type; if (tt == FTT_SYMBOL) { token_type = StringCompiler<Arc>::SYMBOL; } else if (tt == FTT_BYTE) { token_type = StringCompiler<Arc>::BYTE; } else if (tt == FTT_UTF8) { token_type = StringCompiler<Arc>::UTF8; } else { FSTERROR() << "FarCompileStrings: unknown token type"; return; } bool compact; if (fst_type.empty() || (fst_type == "vector")) { compact = false; } else if (fst_type == "compact") { compact = true; } else { FSTERROR() << "FarCompileStrings: unknown fst type: " << fst_type; return; } const SymbolTable *syms = 0; typename Arc::Label unknown_label = kNoLabel; if (!symbols_fname.empty()) { syms = SymbolTable::ReadText(symbols_fname, allow_negative_labels); if (!syms) { FSTERROR() << "FarCompileStrings: error reading symbol table: " << symbols_fname; return; } if (!unknown_symbol.empty()) { unknown_label = syms->Find(unknown_symbol); if (unknown_label == kNoLabel) { FSTERROR() << "FarCompileStrings: unknown label \"" << unknown_label << "\" missing from symbol table: " << symbols_fname; return; } } } FarWriter<Arc> *far_writer = FarWriter<Arc>::Create(out_fname, far_type); if (!far_writer) return; vector<string> inputs; if (file_list_input) { for (int i = 1; i < in_fnames.size(); ++i) { ifstream istrm(in_fnames[i].c_str()); string str; while (getline(istrm, str)) inputs.push_back(str); } } else { inputs = in_fnames; } for (int i = 0, n = 0; i < inputs.size(); ++i) { int key_size = generate_keys ? generate_keys : (entry_type == StringReader<Arc>::FILE ? 1 : KeySize(inputs[i].c_str())); ifstream istrm(inputs[i].c_str()); for (StringReader<Arc> reader( istrm, inputs[i], entry_type, token_type, allow_negative_labels, syms, unknown_label); !reader.Done(); reader.Next()) { ++n; const Fst<Arc> *fst; if (compact) fst = reader.GetCompactFst(); else fst = reader.GetVectorFst(); if (!fst) { FSTERROR() << "FarCompileStrings: compiling string number " << n << " in file " << inputs[i] << " failed with token_type = " << (tt == FTT_BYTE ? "byte" : (tt == FTT_UTF8 ? "utf8" : (tt == FTT_SYMBOL ? "symbol" : "unknown"))) << " and entry_type = " << (fet == FET_LINE ? "line" : (fet == FET_FILE ? "file" : "unknown")); delete far_writer; delete syms; return; } ostringstream keybuf; keybuf.width(key_size); keybuf.fill('0'); keybuf << n; string key; if (generate_keys > 0) { key = keybuf.str(); } else { char* filename = new char[inputs[i].size() + 1]; strcpy(filename, inputs[i].c_str()); key = basename(filename); if (entry_type != StringReader<Arc>::FILE) { key += "-"; key += keybuf.str(); } delete[] filename; } far_writer->Add(key_prefix + key + key_suffix, *fst); delete fst; } if (generate_keys == 0) n = 0; } delete far_writer; } } // namespace fst #endif // FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_