#include <sstream> #include <marisa.h> #include "assert.h" namespace { class FindCallback { public: FindCallback(std::vector<marisa::UInt32> *key_ids, std::vector<std::size_t> *key_lengths) : key_ids_(key_ids), key_lengths_(key_lengths) {} FindCallback(const FindCallback &callback) : key_ids_(callback.key_ids_), key_lengths_(callback.key_lengths_) {} bool operator()(marisa::UInt32 key_id, std::size_t key_length) const { key_ids_->push_back(key_id); key_lengths_->push_back(key_length); return true; } private: std::vector<marisa::UInt32> *key_ids_; std::vector<std::size_t> *key_lengths_; // Disallows assignment. FindCallback &operator=(const FindCallback &); }; class PredictCallback { public: PredictCallback(std::vector<marisa::UInt32> *key_ids, std::vector<std::string> *keys) : key_ids_(key_ids), keys_(keys) {} PredictCallback(const PredictCallback &callback) : key_ids_(callback.key_ids_), keys_(callback.keys_) {} bool operator()(marisa::UInt32 key_id, const std::string &key) const { key_ids_->push_back(key_id); keys_->push_back(key); return true; } private: std::vector<marisa::UInt32> *key_ids_; std::vector<std::string> *keys_; // Disallows assignment. PredictCallback &operator=(const PredictCallback &); }; void TestTrie() { TEST_START(); marisa::Trie trie; ASSERT(trie.num_tries() == 0); ASSERT(trie.num_keys() == 0); ASSERT(trie.num_nodes() == 0); ASSERT(trie.total_size() == (sizeof(marisa::UInt32) * 23)); std::vector<std::string> keys; trie.build(keys); ASSERT(trie.num_tries() == 1); ASSERT(trie.num_keys() == 0); ASSERT(trie.num_nodes() == 1); keys.push_back("apple"); keys.push_back("and"); keys.push_back("Bad"); keys.push_back("apple"); keys.push_back("app"); std::vector<marisa::UInt32> key_ids; trie.build(keys, &key_ids, 1 | MARISA_WITHOUT_TAIL | MARISA_LABEL_ORDER); ASSERT(trie.num_tries() == 1); ASSERT(trie.num_keys() == 4); ASSERT(trie.num_nodes() == 11); ASSERT(key_ids.size() == 5); ASSERT(key_ids[0] == 3); ASSERT(key_ids[1] == 1); ASSERT(key_ids[2] == 0); ASSERT(key_ids[3] == 3); ASSERT(key_ids[4] == 2); char key_buf[256]; std::size_t key_length; for (std::size_t i = 0; i < keys.size(); ++i) { key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf)); ASSERT(trie[keys[i]] == key_ids[i]); ASSERT(trie[key_ids[i]] == keys[i]); ASSERT(key_length == keys[i].length()); ASSERT(keys[i] == key_buf); } trie.clear(); ASSERT(trie.num_tries() == 0); ASSERT(trie.num_keys() == 0); ASSERT(trie.num_nodes() == 0); ASSERT(trie.total_size() == (sizeof(marisa::UInt32) * 23)); trie.build(keys, &key_ids, 1 | MARISA_WITHOUT_TAIL | MARISA_WEIGHT_ORDER); ASSERT(trie.num_tries() == 1); ASSERT(trie.num_keys() == 4); ASSERT(trie.num_nodes() == 11); ASSERT(key_ids.size() == 5); ASSERT(key_ids[0] == 3); ASSERT(key_ids[1] == 1); ASSERT(key_ids[2] == 2); ASSERT(key_ids[3] == 3); ASSERT(key_ids[4] == 0); for (std::size_t i = 0; i < keys.size(); ++i) { ASSERT(trie[keys[i]] == key_ids[i]); ASSERT(trie[key_ids[i]] == keys[i]); } ASSERT(trie["appl"] == trie.notfound()); ASSERT(trie["applex"] == trie.notfound()); ASSERT(trie.find_first("ap") == trie.notfound()); ASSERT(trie.find_first("applex") == trie["app"]); ASSERT(trie.find_last("ap") == trie.notfound()); ASSERT(trie.find_last("applex") == trie["apple"]); std::vector<marisa::UInt32> ids; ASSERT(trie.find("ap", &ids) == 0); ASSERT(trie.find("applex", &ids) == 2); ASSERT(ids.size() == 2); ASSERT(ids[0] == trie["app"]); ASSERT(ids[1] == trie["apple"]); std::vector<std::size_t> lengths; ASSERT(trie.find("Baddie", &ids, &lengths) == 1); ASSERT(ids.size() == 3); ASSERT(ids[2] == trie["Bad"]); ASSERT(lengths.size() == 1); ASSERT(lengths[0] == 3); ASSERT(trie.find_callback("anderson", FindCallback(&ids, &lengths)) == 1); ASSERT(ids.size() == 4); ASSERT(ids[3] == trie["and"]); ASSERT(lengths.size() == 2); ASSERT(lengths[1] == 3); ASSERT(trie.predict("") == 4); ASSERT(trie.predict("a") == 3); ASSERT(trie.predict("ap") == 2); ASSERT(trie.predict("app") == 2); ASSERT(trie.predict("appl") == 1); ASSERT(trie.predict("apple") == 1); ASSERT(trie.predict("appleX") == 0); ASSERT(trie.predict("X") == 0); ids.clear(); ASSERT(trie.predict("a", &ids) == 3); ASSERT(ids.size() == 3); ASSERT(ids[0] == trie["app"]); ASSERT(ids[1] == trie["and"]); ASSERT(ids[2] == trie["apple"]); std::vector<std::string> strs; ASSERT(trie.predict("a", &ids, &strs) == 3); ASSERT(ids.size() == 6); ASSERT(ids[3] == trie["app"]); ASSERT(ids[4] == trie["apple"]); ASSERT(ids[5] == trie["and"]); ASSERT(strs[0] == "app"); ASSERT(strs[1] == "apple"); ASSERT(strs[2] == "and"); TEST_END(); } void TestPrefixTrie() { TEST_START(); std::vector<std::string> keys; keys.push_back("after"); keys.push_back("bar"); keys.push_back("car"); keys.push_back("caster"); marisa::Trie trie; std::vector<marisa::UInt32> key_ids; trie.build(keys, &key_ids, 1 | MARISA_PREFIX_TRIE | MARISA_TEXT_TAIL | MARISA_LABEL_ORDER); ASSERT(trie.num_tries() == 1); ASSERT(trie.num_keys() == 4); ASSERT(trie.num_nodes() == 7); char key_buf[256]; std::size_t key_length; for (std::size_t i = 0; i < keys.size(); ++i) { key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf)); ASSERT(trie[keys[i]] == key_ids[i]); ASSERT(trie[key_ids[i]] == keys[i]); ASSERT(key_length == keys[i].length()); ASSERT(keys[i] == key_buf); } key_length = trie.restore(key_ids[0], NULL, 0); ASSERT(key_length == keys[0].length()); EXCEPT(trie.restore(key_ids[0], NULL, 5), MARISA_PARAM_ERROR); key_length = trie.restore(key_ids[0], key_buf, 5); ASSERT(key_length == keys[0].length()); key_length = trie.restore(key_ids[0], key_buf, 6); ASSERT(key_length == keys[0].length()); trie.build(keys, &key_ids, 2 | MARISA_PREFIX_TRIE | MARISA_WITHOUT_TAIL | MARISA_WEIGHT_ORDER); ASSERT(trie.num_tries() == 2); ASSERT(trie.num_keys() == 4); ASSERT(trie.num_nodes() == 16); for (std::size_t i = 0; i < keys.size(); ++i) { key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf)); ASSERT(trie[keys[i]] == key_ids[i]); ASSERT(trie[key_ids[i]] == keys[i]); ASSERT(key_length == keys[i].length()); ASSERT(keys[i] == key_buf); } key_length = trie.restore(key_ids[0], NULL, 0); ASSERT(key_length == keys[0].length()); EXCEPT(trie.restore(key_ids[0], NULL, 5), MARISA_PARAM_ERROR); key_length = trie.restore(key_ids[0], key_buf, 5); ASSERT(key_length == keys[0].length()); key_length = trie.restore(key_ids[0], key_buf, 6); ASSERT(key_length == keys[0].length()); trie.build(keys, &key_ids, 2 | MARISA_PREFIX_TRIE | MARISA_TEXT_TAIL | MARISA_LABEL_ORDER); ASSERT(trie.num_tries() == 2); ASSERT(trie.num_keys() == 4); ASSERT(trie.num_nodes() == 14); for (std::size_t i = 0; i < keys.size(); ++i) { key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf)); ASSERT(trie[keys[i]] == key_ids[i]); ASSERT(trie[key_ids[i]] == keys[i]); ASSERT(key_length == keys[i].length()); ASSERT(keys[i] == key_buf); } trie.save("trie-test.dat"); trie.clear(); marisa::Mapper mapper; trie.mmap(&mapper, "trie-test.dat"); ASSERT(mapper.is_open()); ASSERT(trie.num_tries() == 2); ASSERT(trie.num_keys() == 4); ASSERT(trie.num_nodes() == 14); for (std::size_t i = 0; i < keys.size(); ++i) { key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf)); ASSERT(trie[keys[i]] == key_ids[i]); ASSERT(trie[key_ids[i]] == keys[i]); ASSERT(key_length == keys[i].length()); ASSERT(keys[i] == key_buf); } std::stringstream stream; trie.write(stream); trie.clear(); trie.read(stream); ASSERT(trie.num_tries() == 2); ASSERT(trie.num_keys() == 4); ASSERT(trie.num_nodes() == 14); for (std::size_t i = 0; i < keys.size(); ++i) { key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf)); ASSERT(trie[keys[i]] == key_ids[i]); ASSERT(trie[key_ids[i]] == keys[i]); ASSERT(key_length == keys[i].length()); ASSERT(keys[i] == key_buf); } trie.build(keys, &key_ids, 3 | MARISA_PREFIX_TRIE | MARISA_WITHOUT_TAIL | MARISA_WEIGHT_ORDER); ASSERT(trie.num_tries() == 3); ASSERT(trie.num_keys() == 4); ASSERT(trie.num_nodes() == 19); for (std::size_t i = 0; i < keys.size(); ++i) { key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf)); ASSERT(trie[keys[i]] == key_ids[i]); ASSERT(trie[key_ids[i]] == keys[i]); ASSERT(key_length == keys[i].length()); ASSERT(keys[i] == key_buf); } ASSERT(trie["ca"] == trie.notfound()); ASSERT(trie["card"] == trie.notfound()); std::size_t length = 0; ASSERT(trie.find_first("ca") == trie.notfound()); ASSERT(trie.find_first("car") == trie["car"]); ASSERT(trie.find_first("card", &length) == trie["car"]); ASSERT(length == 3); ASSERT(trie.find_last("afte") == trie.notfound()); ASSERT(trie.find_last("after") == trie["after"]); ASSERT(trie.find_last("afternoon", &length) == trie["after"]); ASSERT(length == 5); { std::vector<marisa::UInt32> ids; std::vector<std::size_t> lengths; ASSERT(trie.find("card", &ids, &lengths) == 1); ASSERT(ids.size() == 1); ASSERT(ids[0] == trie["car"]); ASSERT(lengths.size() == 1); ASSERT(lengths[0] == 3); ASSERT(trie.predict("ca", &ids) == 2); ASSERT(ids.size() == 3); ASSERT(ids[1] == trie["car"]); ASSERT(ids[2] == trie["caster"]); ASSERT(trie.predict("ca", &ids, NULL, 1) == 1); ASSERT(ids.size() == 4); ASSERT(ids[3] == trie["car"]); std::vector<std::string> strs; ASSERT(trie.predict("ca", &ids, &strs, 1) == 1); ASSERT(ids.size() == 5); ASSERT(ids[4] == trie["car"]); ASSERT(strs.size() == 1); ASSERT(strs[0] == "car"); ASSERT(trie.predict_callback("", PredictCallback(&ids, &strs)) == 4); ASSERT(ids.size() == 9); ASSERT(ids[5] == trie["car"]); ASSERT(ids[6] == trie["caster"]); ASSERT(ids[7] == trie["after"]); ASSERT(ids[8] == trie["bar"]); ASSERT(strs.size() == 5); ASSERT(strs[1] == "car"); ASSERT(strs[2] == "caster"); ASSERT(strs[3] == "after"); ASSERT(strs[4] == "bar"); } { marisa::UInt32 ids[10]; std::size_t lengths[10]; ASSERT(trie.find("card", ids, lengths, 10) == 1); ASSERT(ids[0] == trie["car"]); ASSERT(lengths[0] == 3); ASSERT(trie.predict("ca", ids, NULL, 10) == 2); ASSERT(ids[0] == trie["car"]); ASSERT(ids[1] == trie["caster"]); ASSERT(trie.predict("ca", ids, NULL, 1) == 1); ASSERT(ids[0] == trie["car"]); std::string strs[10]; ASSERT(trie.predict("ca", ids, strs, 1) == 1); ASSERT(ids[0] == trie["car"]); ASSERT(strs[0] == "car"); ASSERT(trie.predict("", ids, strs, 10) == 4); ASSERT(ids[0] == trie["car"]); ASSERT(ids[1] == trie["caster"]); ASSERT(ids[2] == trie["after"]); ASSERT(ids[3] == trie["bar"]); ASSERT(strs[0] == "car"); ASSERT(strs[1] == "caster"); ASSERT(strs[2] == "after"); ASSERT(strs[3] == "bar"); } TEST_END(); } void TestPatriciaTrie() { TEST_START(); std::vector<std::string> keys; keys.push_back("bach"); keys.push_back("bet"); keys.push_back("chat"); keys.push_back("check"); keys.push_back("check"); marisa::Trie trie; std::vector<marisa::UInt32> key_ids; trie.build(keys, &key_ids, 1); ASSERT(trie.num_tries() == 1); ASSERT(trie.num_keys() == 4); ASSERT(trie.num_nodes() == 7); ASSERT(key_ids.size() == 5); ASSERT(key_ids[0] == 2); ASSERT(key_ids[1] == 3); ASSERT(key_ids[2] == 1); ASSERT(key_ids[3] == 0); ASSERT(key_ids[4] == 0); char key_buf[256]; std::size_t key_length; for (std::size_t i = 0; i < keys.size(); ++i) { key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf)); ASSERT(trie[keys[i]] == key_ids[i]); ASSERT(trie[key_ids[i]] == keys[i]); ASSERT(key_length == keys[i].length()); ASSERT(keys[i] == key_buf); } trie.build(keys, &key_ids, 2 | MARISA_WITHOUT_TAIL); ASSERT(trie.num_tries() == 2); ASSERT(trie.num_keys() == 4); ASSERT(trie.num_nodes() == 17); for (std::size_t i = 0; i < keys.size(); ++i) { key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf)); ASSERT(trie[keys[i]] == key_ids[i]); ASSERT(trie[key_ids[i]] == keys[i]); ASSERT(key_length == keys[i].length()); ASSERT(keys[i] == key_buf); } trie.build(keys, &key_ids, 2); ASSERT(trie.num_tries() == 2); ASSERT(trie.num_keys() == 4); ASSERT(trie.num_nodes() == 14); for (std::size_t i = 0; i < keys.size(); ++i) { key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf)); ASSERT(trie[keys[i]] == key_ids[i]); ASSERT(trie[key_ids[i]] == keys[i]); ASSERT(key_length == keys[i].length()); ASSERT(keys[i] == key_buf); } trie.build(keys, &key_ids, 3 | MARISA_WITHOUT_TAIL); ASSERT(trie.num_tries() == 3); ASSERT(trie.num_keys() == 4); ASSERT(trie.num_nodes() == 20); for (std::size_t i = 0; i < keys.size(); ++i) { key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf)); ASSERT(trie[keys[i]] == key_ids[i]); ASSERT(trie[key_ids[i]] == keys[i]); ASSERT(key_length == keys[i].length()); ASSERT(keys[i] == key_buf); } std::stringstream stream; trie.write(stream); trie.clear(); trie.read(stream); ASSERT(trie.num_tries() == 3); ASSERT(trie.num_keys() == 4); ASSERT(trie.num_nodes() == 20); for (std::size_t i = 0; i < keys.size(); ++i) { key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf)); ASSERT(trie[keys[i]] == key_ids[i]); ASSERT(trie[key_ids[i]] == keys[i]); ASSERT(key_length == keys[i].length()); ASSERT(keys[i] == key_buf); } TEST_END(); } void TestEmptyString() { TEST_START(); std::vector<std::string> keys; keys.push_back(""); marisa::Trie trie; std::vector<marisa::UInt32> key_ids; trie.build(keys, &key_ids); ASSERT(trie.num_tries() == 1); ASSERT(trie.num_keys() == 1); ASSERT(trie.num_nodes() == 1); ASSERT(key_ids.size() == 1); ASSERT(key_ids[0] == 0); ASSERT(trie[""] == 0); ASSERT(trie[(marisa::UInt32)0] == ""); ASSERT(trie["x"] == trie.notfound()); ASSERT(trie.find_first("") == 0); ASSERT(trie.find_first("x") == 0); ASSERT(trie.find_last("") == 0); ASSERT(trie.find_last("x") == 0); std::vector<marisa::UInt32> ids; ASSERT(trie.find("xyz", &ids) == 1); ASSERT(ids.size() == 1); ASSERT(ids[0] == trie[""]); std::vector<std::size_t> lengths; ASSERT(trie.find("xyz", &ids, &lengths) == 1); ASSERT(ids.size() == 2); ASSERT(ids[0] == trie[""]); ASSERT(ids[1] == trie[""]); ASSERT(lengths.size() == 1); ASSERT(lengths[0] == 0); ASSERT(trie.find_callback("xyz", FindCallback(&ids, &lengths)) == 1); ASSERT(ids.size() == 3); ASSERT(ids[2] == trie[""]); ASSERT(lengths.size() == 2); ASSERT(lengths[1] == 0); ASSERT(trie.predict("xyz", &ids) == 0); ASSERT(trie.predict("", &ids) == 1); ASSERT(ids.size() == 4); ASSERT(ids[3] == trie[""]); std::vector<std::string> strs; ASSERT(trie.predict("", &ids, &strs) == 1); ASSERT(ids.size() == 5); ASSERT(ids[4] == trie[""]); ASSERT(strs[0] == ""); TEST_END(); } void TestBinaryKey() { TEST_START(); std::string binary_key = "NP"; binary_key += '\0'; binary_key += "Trie"; std::vector<std::string> keys; keys.push_back(binary_key); marisa::Trie trie; std::vector<marisa::UInt32> key_ids; trie.build(keys, &key_ids, 1 | MARISA_WITHOUT_TAIL); ASSERT(trie.num_tries() == 1); ASSERT(trie.num_keys() == 1); ASSERT(trie.num_nodes() == 8); ASSERT(key_ids.size() == 1); char key_buf[256]; std::size_t key_length; key_length = trie.restore(0, key_buf, sizeof(key_buf)); ASSERT(trie[keys[0]] == key_ids[0]); ASSERT(trie[key_ids[0]] == keys[0]); ASSERT(std::string(key_buf, key_length) == keys[0]); trie.build(keys, &key_ids, 1 | MARISA_PREFIX_TRIE | MARISA_BINARY_TAIL); ASSERT(trie.num_tries() == 1); ASSERT(trie.num_keys() == 1); ASSERT(trie.num_nodes() == 2); ASSERT(key_ids.size() == 1); key_length = trie.restore(0, key_buf, sizeof(key_buf)); ASSERT(trie[keys[0]] == key_ids[0]); ASSERT(trie[key_ids[0]] == keys[0]); ASSERT(std::string(key_buf, key_length) == keys[0]); trie.build(keys, &key_ids, 1 | MARISA_PREFIX_TRIE | MARISA_TEXT_TAIL); ASSERT(trie.num_tries() == 1); ASSERT(trie.num_keys() == 1); ASSERT(trie.num_nodes() == 2); ASSERT(key_ids.size() == 1); key_length = trie.restore(0, key_buf, sizeof(key_buf)); ASSERT(trie[keys[0]] == key_ids[0]); ASSERT(trie[key_ids[0]] == keys[0]); ASSERT(std::string(key_buf, key_length) == keys[0]); std::vector<marisa::UInt32> ids; ASSERT(trie.predict_breadth_first("", &ids) == 1); ASSERT(ids.size() == 1); ASSERT(ids[0] == key_ids[0]); std::vector<std::string> strs; ASSERT(trie.predict_depth_first("NP", &ids, &strs) == 1); ASSERT(ids.size() == 2); ASSERT(ids[1] == key_ids[0]); ASSERT(strs[0] == keys[0]); TEST_END(); } } // namespace int main() { TestTrie(); TestPrefixTrie(); TestPatriciaTrie(); TestEmptyString(); TestBinaryKey(); return 0; }