// weight_test.h // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Copyright 2005-2010 Google, Inc. // Author: riley@google.com (Michael Riley) // // \file // Regression test for Fst weights. #include <cstdlib> #include <ctime> #include <fst/expectation-weight.h> #include <fst/float-weight.h> #include <fst/random-weight.h> #include "./weight-tester.h" DEFINE_int32(seed, -1, "random seed"); DEFINE_int32(repeat, 100000, "number of test repetitions"); using fst::TropicalWeight; using fst::TropicalWeightGenerator; using fst::TropicalWeightTpl; using fst::TropicalWeightGenerator_; using fst::LogWeight; using fst::LogWeightGenerator; using fst::LogWeightTpl; using fst::LogWeightGenerator_; using fst::MinMaxWeight; using fst::MinMaxWeightGenerator; using fst::MinMaxWeightTpl; using fst::MinMaxWeightGenerator_; using fst::StringWeight; using fst::StringWeightGenerator; using fst::GallicWeight; using fst::GallicWeightGenerator; using fst::LexicographicWeight; using fst::LexicographicWeightGenerator; using fst::ProductWeight; using fst::ProductWeightGenerator; using fst::PowerWeight; using fst::PowerWeightGenerator; using fst::SignedLogWeightTpl; using fst::SignedLogWeightGenerator_; using fst::ExpectationWeight; using fst::SparsePowerWeight; using fst::SparsePowerWeightGenerator; using fst::STRING_LEFT; using fst::STRING_RIGHT; using fst::WeightTester; template <class T> void TestTemplatedWeights(int repeat, int seed) { TropicalWeightGenerator_<T> tropical_generator(seed); WeightTester<TropicalWeightTpl<T>, TropicalWeightGenerator_<T> > tropical_tester(tropical_generator); tropical_tester.Test(repeat); LogWeightGenerator_<T> log_generator(seed); WeightTester<LogWeightTpl<T>, LogWeightGenerator_<T> > log_tester(log_generator); log_tester.Test(repeat); MinMaxWeightGenerator_<T> minmax_generator(seed); WeightTester<MinMaxWeightTpl<T>, MinMaxWeightGenerator_<T> > minmax_tester(minmax_generator); minmax_tester.Test(repeat); SignedLogWeightGenerator_<T> signedlog_generator(seed); WeightTester<SignedLogWeightTpl<T>, SignedLogWeightGenerator_<T> > signedlog_tester(signedlog_generator); signedlog_tester.Test(repeat); } int main(int argc, char **argv) { std::set_new_handler(FailedNewHandler); SetFlags(argv[0], &argc, &argv, true); int seed = FLAGS_seed >= 0 ? FLAGS_seed : time(0); LOG(INFO) << "Seed = " << seed; TestTemplatedWeights<float>(FLAGS_repeat, seed); TestTemplatedWeights<double>(FLAGS_repeat, seed); FLAGS_fst_weight_parentheses = "()"; TestTemplatedWeights<float>(FLAGS_repeat, seed); TestTemplatedWeights<double>(FLAGS_repeat, seed); FLAGS_fst_weight_parentheses = ""; // Make sure type names for templated weights are consistent CHECK(TropicalWeight::Type() == "tropical"); CHECK(TropicalWeightTpl<double>::Type() != TropicalWeightTpl<float>::Type()); CHECK(LogWeight::Type() == "log"); CHECK(LogWeightTpl<double>::Type() != LogWeightTpl<float>::Type()); TropicalWeightTpl<double> w(15.0); TropicalWeight tw(15.0); StringWeightGenerator<int> left_string_generator(seed); WeightTester<StringWeight<int>, StringWeightGenerator<int> > left_string_tester(left_string_generator); left_string_tester.Test(FLAGS_repeat); StringWeightGenerator<int, STRING_RIGHT> right_string_generator(seed); WeightTester<StringWeight<int, STRING_RIGHT>, StringWeightGenerator<int, STRING_RIGHT> > right_string_tester(right_string_generator); right_string_tester.Test(FLAGS_repeat); typedef GallicWeight<int, TropicalWeight> TropicalGallicWeight; typedef GallicWeightGenerator<int, TropicalWeightGenerator> TropicalGallicWeightGenerator; TropicalGallicWeightGenerator tropical_gallic_generator(seed); WeightTester<TropicalGallicWeight, TropicalGallicWeightGenerator> tropical_gallic_tester(tropical_gallic_generator); tropical_gallic_tester.Test(FLAGS_repeat); typedef ProductWeight<TropicalWeight, TropicalWeight> TropicalProductWeight; typedef ProductWeightGenerator<TropicalWeightGenerator, TropicalWeightGenerator> TropicalProductWeightGenerator; TropicalProductWeightGenerator tropical_product_generator(seed); WeightTester<TropicalProductWeight, TropicalProductWeightGenerator> tropical_product_weight_tester(tropical_product_generator); tropical_product_weight_tester.Test(FLAGS_repeat); typedef PowerWeight<TropicalWeight, 3> TropicalCubeWeight; typedef PowerWeightGenerator<TropicalWeightGenerator, 3> TropicalCubeWeightGenerator; TropicalCubeWeightGenerator tropical_cube_generator(seed); WeightTester<TropicalCubeWeight, TropicalCubeWeightGenerator> tropical_cube_weight_tester(tropical_cube_generator); tropical_cube_weight_tester.Test(FLAGS_repeat); typedef ProductWeight<TropicalWeight, TropicalProductWeight> SecondNestedProductWeight; typedef ProductWeightGenerator<TropicalWeightGenerator, TropicalProductWeightGenerator> SecondNestedProductWeightGenerator; SecondNestedProductWeightGenerator second_nested_product_generator(seed); WeightTester<SecondNestedProductWeight, SecondNestedProductWeightGenerator> second_nested_product_weight_tester(second_nested_product_generator); second_nested_product_weight_tester.Test(FLAGS_repeat); // This only works with fst_weight_parentheses = "()" typedef ProductWeight<TropicalProductWeight, TropicalWeight> FirstNestedProductWeight; typedef ProductWeightGenerator<TropicalProductWeightGenerator, TropicalWeightGenerator> FirstNestedProductWeightGenerator; FirstNestedProductWeightGenerator first_nested_product_generator(seed); WeightTester<FirstNestedProductWeight, FirstNestedProductWeightGenerator> first_nested_product_weight_tester(first_nested_product_generator); typedef PowerWeight<FirstNestedProductWeight, 3> NestedProductCubeWeight; typedef PowerWeightGenerator<FirstNestedProductWeightGenerator, 3> NestedProductCubeWeightGenerator; NestedProductCubeWeightGenerator nested_product_cube_generator(seed); WeightTester<NestedProductCubeWeight, NestedProductCubeWeightGenerator> nested_product_cube_weight_tester(nested_product_cube_generator); typedef SparsePowerWeight<NestedProductCubeWeight, size_t > SparseNestedProductCubeWeight; typedef SparsePowerWeightGenerator<NestedProductCubeWeightGenerator, size_t, 3> SparseNestedProductCubeWeightGenerator; SparseNestedProductCubeWeightGenerator sparse_nested_product_cube_generator(seed); WeightTester<SparseNestedProductCubeWeight, SparseNestedProductCubeWeightGenerator> sparse_nested_product_cube_weight_tester( sparse_nested_product_cube_generator); typedef SparsePowerWeight<LogWeight, size_t > LogSparsePowerWeight; typedef SparsePowerWeightGenerator<LogWeightGenerator, size_t, 3> LogSparsePowerWeightGenerator; LogSparsePowerWeightGenerator log_sparse_power_weight_generator(seed); WeightTester<LogSparsePowerWeight, LogSparsePowerWeightGenerator> log_sparse_power_weight_tester( log_sparse_power_weight_generator); typedef ExpectationWeight<LogWeight, LogWeight> LogLogExpectWeight; typedef ProductWeightGenerator<LogWeightGenerator, LogWeightGenerator, LogLogExpectWeight> LogLogExpectWeightGenerator; LogLogExpectWeightGenerator log_log_expect_weight_generator(seed); WeightTester<LogLogExpectWeight, LogLogExpectWeightGenerator> log_log_expect_weight_tester(log_log_expect_weight_generator); typedef ExpectationWeight<LogWeight, LogSparsePowerWeight> LogLogSparseExpectWeight; typedef ProductWeightGenerator< LogWeightGenerator, LogSparsePowerWeightGenerator, LogLogSparseExpectWeight> LogLogSparseExpectWeightGenerator; LogLogSparseExpectWeightGenerator log_logsparse_expect_weight_generator(seed); WeightTester<LogLogSparseExpectWeight, LogLogSparseExpectWeightGenerator> log_logsparse_expect_weight_tester(log_logsparse_expect_weight_generator); // Test all product weight I/O with parentheses FLAGS_fst_weight_parentheses = "()"; first_nested_product_weight_tester.Test(FLAGS_repeat); nested_product_cube_weight_tester.Test(FLAGS_repeat); log_sparse_power_weight_tester.Test(1); sparse_nested_product_cube_weight_tester.Test(1); tropical_product_weight_tester.Test(5); second_nested_product_weight_tester.Test(5); tropical_gallic_tester.Test(5); tropical_cube_weight_tester.Test(5); FLAGS_fst_weight_parentheses = ""; log_sparse_power_weight_tester.Test(1); log_log_expect_weight_tester.Test(1, false); // disables division log_logsparse_expect_weight_tester.Test(1, false); typedef LexicographicWeight<TropicalWeight, TropicalWeight> TropicalLexicographicWeight; typedef LexicographicWeightGenerator<TropicalWeightGenerator, TropicalWeightGenerator> TropicalLexicographicWeightGenerator; TropicalLexicographicWeightGenerator tropical_lexicographic_generator(seed); WeightTester<TropicalLexicographicWeight, TropicalLexicographicWeightGenerator> tropical_lexicographic_tester(tropical_lexicographic_generator); tropical_lexicographic_tester.Test(FLAGS_repeat); cout << "PASS" << endl; return 0; }