#include <algorithm>
#include <stdexcept>

#include "trie.h"

namespace marisa {
namespace {

// Result collector handed to predict_callback_() by predict_depth_first_()
// when keys were requested. Every call stores the reported key ID and key
// (into whichever of the two containers is valid) at the next result index,
// and returns whether fewer than `max_num_results` results have been stored
// so far (presumably telling the traversal whether to keep going).
template <typename T, typename U>
class PredictCallback {
 public:
  PredictCallback(T key_ids, U keys, std::size_t max_num_results)
      : key_ids_(key_ids),
        keys_(keys),
        max_num_results_(max_num_results),
        num_results_(0) {}

  // Member-wise copy, including the running result count.
  PredictCallback(const PredictCallback &other)
      : key_ids_(other.key_ids_),
        keys_(other.keys_),
        max_num_results_(other.max_num_results_),
        num_results_(other.num_results_) {}

  // Records one (key_id, key) result at index `num_results_`.
  bool operator()(marisa::UInt32 key_id, const std::string &key) {
    const std::size_t slot = num_results_;
    if (key_ids_.is_valid()) {
      key_ids_.insert(slot, key_id);
    }
    if (keys_.is_valid()) {
      keys_.insert(slot, key);
    }
    num_results_ = slot + 1;
    return num_results_ < max_num_results_;
  }

 private:
  T key_ids_;
  U keys_;
  const std::size_t max_num_results_;
  std::size_t num_results_;

  // Disallows assignment.
  PredictCallback &operator=(const PredictCallback &);
};

}  // namespace

// Returns the key whose ID is `key_id`.
// Throws MARISA_STATE_ERROR if the trie is empty and MARISA_PARAM_ERROR if
// `key_id` is not smaller than the number of keys.
std::string Trie::restore(UInt32 key_id) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
  std::string result;
  restore_(key_id, &result);
  return result;
}

// Appends the key whose ID is `key_id` to `*key`; existing contents of
// `*key` are left in place in front of it.
// Throws MARISA_STATE_ERROR if the trie is empty, MARISA_PARAM_ERROR if
// `key_id` is out of range or `key` is NULL.
void Trie::restore(UInt32 key_id, std::string *key) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
  MARISA_THROW_IF(!key, MARISA_PARAM_ERROR);
  restore_(key_id, key);
}

// Copies the key whose ID is `key_id` into `key_buf` and returns its full
// length. Bytes past `key_buf_size` are dropped, and a terminating '\0' is
// written only when the length is smaller than `key_buf_size`, so a return
// value >= `key_buf_size` means the buffer was too small.
// Throws MARISA_STATE_ERROR if the trie is empty, MARISA_PARAM_ERROR if
// `key_id` is out of range or `key_buf` is NULL with a non-zero size.
std::size_t Trie::restore(UInt32 key_id, char *key_buf,
    std::size_t key_buf_size) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
  MARISA_THROW_IF(key_buf == NULL && key_buf_size != 0, MARISA_PARAM_ERROR);
  return restore_(key_id, key_buf, key_buf_size);
}

// Returns the ID of the key equal to the NUL-terminated string `str`, or
// notfound() if no such key is registered.
// Throws MARISA_STATE_ERROR on an empty trie, MARISA_PARAM_ERROR if `str`
// is NULL.
UInt32 Trie::lookup(const char *str) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
  const CQuery query(str);
  return lookup_<CQuery>(query);
}

// Same as above for the `length` bytes starting at `ptr` (may contain '\0').
// `ptr` may be NULL only when `length` is 0.
UInt32 Trie::lookup(const char *ptr, std::size_t length) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
  const Query query(ptr, length);
  return lookup_<const Query &>(query);
}

std::size_t Trie::find(const char *str,
    UInt32 *key_ids, std::size_t *key_lengths,
    std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
  return find_<CQuery>(CQuery(str),
      MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
}

std::size_t Trie::find(const char *ptr, std::size_t length,
    UInt32 *key_ids, std::size_t *key_lengths,
    std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
  return find_<const Query &>(Query(ptr, length),
      MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
}

std::size_t Trie::find(const char *str,
    std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
    std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
  return find_<CQuery>(CQuery(str),
      MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
}

std::size_t Trie::find(const char *ptr, std::size_t length,
    std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
    std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
  return find_<const Query &>(Query(ptr, length),
      MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
}

// Returns the ID of the shortest registered key that is a prefix of the
// query, storing its length in `*key_length` when `key_length` is not NULL.
// Returns notfound() when no key is a prefix of the query.
// Throws MARISA_STATE_ERROR on an empty trie and MARISA_PARAM_ERROR for a
// NULL query (or NULL `ptr` with non-zero length).
UInt32 Trie::find_first(const char *str,
    std::size_t *key_length) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
  const CQuery query(str);
  return find_first_<CQuery>(query, key_length);
}

UInt32 Trie::find_first(const char *ptr, std::size_t length,
    std::size_t *key_length) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
  const Query query(ptr, length);
  return find_first_<const Query &>(query, key_length);
}

// Like find_first(), but returns the longest registered key that is a prefix
// of the query.
UInt32 Trie::find_last(const char *str,
    std::size_t *key_length) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
  const CQuery query(str);
  return find_last_<CQuery>(query, key_length);
}

UInt32 Trie::find_last(const char *ptr, std::size_t length,
    std::size_t *key_length) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
  const Query query(ptr, length);
  return find_last_<const Query &>(query, key_length);
}

// Predictive search front end. When keys are requested (`keys` != NULL) the
// depth-first variant is used; otherwise the breadth-first variant is used.
// Returns the number of results reported (at most `max_num_results`).
// Throws MARISA_STATE_ERROR on an empty trie and MARISA_PARAM_ERROR for a
// NULL query (or NULL `ptr` with non-zero length).
std::size_t Trie::predict(const char *str,
    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
  if (keys != NULL) {
    return predict_depth_first(str, key_ids, keys, max_num_results);
  }
  return predict_breadth_first(str, key_ids, keys, max_num_results);
}

std::size_t Trie::predict(const char *ptr, std::size_t length,
    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
  if (keys != NULL) {
    return predict_depth_first(ptr, length, key_ids, keys, max_num_results);
  }
  return predict_breadth_first(ptr, length, key_ids, keys, max_num_results);
}

std::size_t Trie::predict(const char *str,
    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
  if (keys != NULL) {
    return predict_depth_first(str, key_ids, keys, max_num_results);
  }
  return predict_breadth_first(str, key_ids, keys, max_num_results);
}

std::size_t Trie::predict(const char *ptr, std::size_t length,
    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
  if (keys != NULL) {
    return predict_depth_first(ptr, length, key_ids, keys, max_num_results);
  }
  return predict_breadth_first(ptr, length, key_ids, keys, max_num_results);
}

std::size_t Trie::predict_breadth_first(const char *str,
    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
  return predict_breadth_first_<CQuery>(CQuery(str),
      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}

std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
  return predict_breadth_first_<const Query &>(Query(ptr, length),
      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}

std::size_t Trie::predict_breadth_first(const char *str,
    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
  return predict_breadth_first_<CQuery>(CQuery(str),
      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}

std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
    std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
  return predict_breadth_first_<const Query &>(Query(ptr, length),
      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}

std::size_t Trie::predict_depth_first(const char *str,
    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
  return predict_depth_first_<CQuery>(CQuery(str),
      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}

std::size_t Trie::predict_depth_first(const char *ptr, std::size_t length,
    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
  return predict_depth_first_<const Query &>(Query(ptr, length),
      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}

std::size_t Trie::predict_depth_first(
    const char *str, std::vector<UInt32> *key_ids,
    std::vector<std::string> *keys, std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
  return predict_depth_first_<CQuery>(CQuery(str),
      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}

std::size_t Trie::predict_depth_first(
    const char *ptr, std::size_t length, std::vector<UInt32> *key_ids,
    std::vector<std::string> *keys, std::size_t max_num_results) const {
  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
  return predict_depth_first_<const Query &>(Query(ptr, length),
      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
}

// Appends the key for `key_id` to `*key` (no argument checking).
// The walk goes from the key's node up to the root, so bytes are produced
// back to front. The whole newly appended region is reversed once at the
// end. Each linked segment is reversed in place right after it is appended,
// so that the final reversal gives it the intended orientation.
void Trie::restore_(UInt32 key_id, std::string *key) const {
  const std::size_t base = key->length();
  for (UInt32 node = key_id_to_node(key_id); node != 0;
       node = get_parent(node)) {
    if (!has_link(node)) {
      *key += labels_[node];
      continue;
    }
    const std::size_t segment_begin = key->length();
    if (has_trie()) {
      trie_->trie_restore(get_link(node), key);
    } else {
      tail_restore(node, key);
    }
    std::reverse(key->begin() + segment_begin, key->end());
  }
  std::reverse(key->begin() + base, key->end());
}

// Appends the bytes on the path from `node` toward the root of this trie
// (stopping once the parent is node 0) to `*key`. Linked nodes are expanded
// through the next-level trie or the tail. No reversal is done here. `node`
// itself is always processed, because the first step is unconditional.
void Trie::trie_restore(UInt32 node, std::string *key) const {
  for (;;) {
    if (!has_link(node)) {
      *key += labels_[node];
    } else if (has_trie()) {
      trie_->trie_restore(get_link(node), key);
    } else {
      tail_restore(node, key);
    }
    node = get_parent(node);
    if (node == 0) {
      break;
    }
  }
}

// Appends the tail entry referenced by the linked node `node` to `*key`.
// The tail offset combines links_ (scaled by 256) with the node's label
// byte. In binary mode the entry length is the gap to the next link's
// offset. In text mode the entry is read up to its terminating '\0'.
void Trie::tail_restore(UInt32 node, std::string *key) const {
  const UInt32 link_id = link_flags_.rank1(node);
  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
  const char * const entry = reinterpret_cast<const char *>(tail_[offset]);
  if (tail_.mode() != MARISA_BINARY_TAIL) {
    key->append(entry);
    return;
  }
  const UInt32 next_offset = (links_[link_id + 1] * 256)
      + labels_[link_flags_.select1(link_id + 1)];
  key->append(entry, next_offset - offset);
}

// Buffer variant of restore_(): writes the key for `key_id` into `key_buf`
// and returns the full key length, even when it exceeds `key_buf_size`.
// Bytes that do not fit are counted but not written. Segment reversal, the
// final reversal and the '\0' terminator are applied only while the running
// length is still smaller than `key_buf_size`.
std::size_t Trie::restore_(UInt32 key_id, char *key_buf,
    std::size_t key_buf_size) const {
  std::size_t length = 0;
  for (UInt32 node = key_id_to_node(key_id); node != 0;
       node = get_parent(node)) {
    if (!has_link(node)) {
      if (length < key_buf_size) {
        key_buf[length] = labels_[node];
      }
      ++length;
      continue;
    }
    const std::size_t segment_begin = length;
    if (has_trie()) {
      trie_->trie_restore(get_link(node), key_buf, key_buf_size, length);
    } else {
      tail_restore(node, key_buf, key_buf_size, length);
    }
    if (length < key_buf_size) {
      std::reverse(key_buf + segment_begin, key_buf + length);
    }
  }
  if (length < key_buf_size) {
    key_buf[length] = '\0';
    std::reverse(key_buf, key_buf + length);
  }
  return length;
}

// Buffer variant of trie_restore(std::string*). Writes bytes at
// `key_buf[pos]` only while `pos < key_buf_size`, and always advances `pos`,
// so `pos` ends up counting every produced byte.
void Trie::trie_restore(UInt32 node, char *key_buf,
    std::size_t key_buf_size, std::size_t &pos) const {
  for (;;) {
    if (!has_link(node)) {
      if (pos < key_buf_size) {
        key_buf[pos] = labels_[node];
      }
      ++pos;
    } else if (has_trie()) {
      trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
    } else {
      tail_restore(node, key_buf, key_buf_size, pos);
    }
    node = get_parent(node);
    if (node == 0) {
      break;
    }
  }
}

// Buffer variant of tail_restore(std::string*). Copies the tail entry of the
// linked node `node` into `key_buf` starting at `pos`. A byte is written only
// while `pos < key_buf_size`, but `pos` is advanced for every byte.
void Trie::tail_restore(UInt32 node, char *key_buf,
    std::size_t key_buf_size, std::size_t &pos) const {
  const UInt32 link_id = link_flags_.rank1(node);
  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
  const UInt8 *src = tail_[offset];
  if (tail_.mode() != MARISA_BINARY_TAIL) {
    // Text tail: copy up to (not including) the terminating '\0'.
    for (; *src != '\0'; ++src, ++pos) {
      if (pos < key_buf_size) {
        key_buf[pos] = *src;
      }
    }
    return;
  }
  // Binary tail: the length is the gap to the next link's offset.
  const UInt32 length = (links_[link_id + 1] * 256)
      + labels_[link_flags_.select1(link_id + 1)] - offset;
  const UInt8 * const end = src + length;
  for (; src != end; ++src, ++pos) {
    if (pos < key_buf_size) {
      key_buf[pos] = *src;
    }
  }
}

// Exact-match search: follows child edges with find_child() until the query
// is consumed. Returns the key ID of the reached node if it is terminal,
// otherwise notfound(). Also returns notfound() as soon as no matching child
// exists.
template <typename T>
UInt32 Trie::lookup_(T query) const {
  UInt32 node = 0;
  for (std::size_t pos = 0; !query.ends_at(pos); ) {
    // find_child() advances `node` and `pos` on success.
    if (!find_child<T>(node, query, pos)) {
      return notfound();
    }
  }
  if (!terminal_flags_[node]) {
    return notfound();
  }
  return node_to_key_id(node);
}

// Matches the query, starting at `pos`, against the byte chain that runs
// from `node` toward the root of this trie (the walk stops once the parent
// is node 0). Linked nodes are matched through the next-level trie or the
// tail.
//
// Return protocol:
//   - `pos` unchanged: the very first label/segment did not match, i.e.
//     nothing was consumed (the caller can presumably try another candidate);
//   - mismatch(): the chain diverged after something was consumed, or the
//     query ended before the whole chain was matched;
//   - otherwise: the position just past the matched chain.
template <typename T>
std::size_t Trie::trie_match(UInt32 node, T query,
    std::size_t pos) const {
  // First step: a failure here is reported as-is (pos or mismatch()).
  if (has_link(node)) {
    std::size_t next_pos;
    if (has_trie()) {
      next_pos = trie_->trie_match<T>(get_link(node), query, pos);
    } else {
      next_pos = tail_match<T>(node, get_link_id(node), query, pos);
    }
    if ((next_pos == mismatch()) || (next_pos == pos)) {
      return next_pos;
    }
    pos = next_pos;
  } else if (labels_[node] != query[pos]) {
    return pos;
  } else {
    ++pos;
  }
  node = get_parent(node);
  // Remaining steps: part of the query is already consumed, so every failure
  // (including the query running out) is a mismatch of the whole chain.
  while (node != 0) {
    if (query.ends_at(pos)) {
      return mismatch();
    }
    if (has_link(node)) {
      std::size_t next_pos;
      if (has_trie()) {
        next_pos = trie_->trie_match<T>(get_link(node), query, pos);
      } else {
        next_pos = tail_match<T>(node, get_link_id(node), query, pos);
      }
      if ((next_pos == mismatch()) || (next_pos == pos)) {
        return mismatch();
      }
      pos = next_pos;
    } else if (labels_[node] != query[pos]) {
      return mismatch();
    } else {
      ++pos;
    }
    node = get_parent(node);
  }
  return pos;
}

// Explicit instantiations for the two query types used by the public API
// (NUL-terminated CQuery and pointer+length Query).
template std::size_t Trie::trie_match<CQuery>(UInt32 node,
    CQuery query, std::size_t pos) const;
template std::size_t Trie::trie_match<const Query &>(UInt32 node,
    const Query &query, std::size_t pos) const;

// Matches the query, starting at `pos`, against the tail entry of the linked
// node `node` (`link_id` is its link index). Returns `pos` if the first byte
// differs (nothing consumed), mismatch() if a later byte differs or the query
// ends inside the entry, and otherwise the position just past the entry.
template <typename T>
std::size_t Trie::tail_match(UInt32 node, UInt32 link_id,
    T query, std::size_t pos) const {
  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
  const UInt8 * const entry = tail_[offset];
  if (entry[0] != query[pos]) {
    return pos;
  }
  if (tail_.mode() != MARISA_BINARY_TAIL) {
    // Text tail: compare up to the terminating '\0'.
    std::size_t i = 1;
    for (; entry[i] != '\0'; ++i) {
      if (query.ends_at(pos + i) || (entry[i] != query[pos + i])) {
        return mismatch();
      }
    }
    return pos + i;
  }
  // Binary tail: the length is the gap to the next link's offset.
  const UInt32 length = (links_[link_id + 1] * 256)
      + labels_[link_flags_.select1(link_id + 1)] - offset;
  for (UInt32 i = 1; i < length; ++i) {
    if (query.ends_at(pos + i) || (entry[i] != query[pos + i])) {
      return mismatch();
    }
  }
  return pos + length;
}

template std::size_t Trie::tail_match<CQuery>(UInt32 node,
    UInt32 link_id, CQuery query, std::size_t pos) const;
template std::size_t Trie::tail_match<const Query &>(UInt32 node,
    UInt32 link_id, const Query &query, std::size_t pos) const;

// Common-prefix search core. Descends from the root along the query. At
// every terminal node met (the root included) it records the key ID and the
// number of query bytes consumed so far, stopping after `max_num_results`
// hits. Returns the number of hits recorded.
template <typename T, typename U, typename V>
std::size_t Trie::find_(T query, U key_ids, V key_lengths,
    std::size_t max_num_results) const {
  std::size_t count = 0;
  if (max_num_results == 0) {
    return count;
  }
  UInt32 node = 0;
  std::size_t pos = 0;
  for (;;) {
    if (terminal_flags_[node]) {
      if (key_ids.is_valid()) {
        key_ids.insert(count, node_to_key_id(node));
      }
      if (key_lengths.is_valid()) {
        key_lengths.insert(count, pos);
      }
      ++count;
      if (count >= max_num_results) {
        return count;
      }
    }
    // find_child() advances `node` and `pos` on success.
    if (query.ends_at(pos) || !find_child<T>(node, query, pos)) {
      break;
    }
  }
  return count;
}

// Returns the key ID of the first terminal node met while descending along
// the query (the shortest key that is a prefix of it), storing the consumed
// length in `*key_length` when non-NULL. Returns notfound() if the descent
// stops before any terminal node is reached.
template <typename T>
UInt32 Trie::find_first_(T query, std::size_t *key_length) const {
  UInt32 node = 0;
  std::size_t pos = 0;
  while (!terminal_flags_[node]) {
    if (query.ends_at(pos) || !find_child<T>(node, query, pos)) {
      return notfound();
    }
  }
  if (key_length != NULL) {
    *key_length = pos;
  }
  return node_to_key_id(node);
}

// Returns the key ID of the last terminal node met while descending along
// the query (the longest key that is a prefix of it), storing its length in
// `*key_length` when non-NULL. Returns notfound() if no terminal node was
// met.
template <typename T>
UInt32 Trie::find_last_(T query, std::size_t *key_length) const {
  UInt32 node = 0;
  std::size_t pos = 0;
  UInt32 last_node = notfound();
  std::size_t last_pos = mismatch();
  for (;;) {
    if (terminal_flags_[node]) {
      last_node = node;
      last_pos = pos;
    }
    if (query.ends_at(pos) || !find_child<T>(node, query, pos)) {
      break;
    }
  }
  if (last_node == notfound()) {
    return notfound();
  }
  if (key_length != NULL) {
    *key_length = last_pos;
  }
  return node_to_key_id(last_node);
}

// Breadth-first predictive search core.
//
// 1. Descends along the whole query with predict_child(); returns 0 if the
//    query leaves the trie.
// 2. Records the reached node itself if it is terminal.
// 3. Enumerates its descendants one depth level at a time. [node_begin,
//    node_end) is the node range of the current level. Its key IDs are taken
//    as the contiguous range [node_to_key_id(node_begin),
//    node_to_key_id(node_end)). The range for the next level is derived from
//    the first children of node_begin and node_end.
//
// At most `max_num_results` results are recorded; the return value is the
// number recorded (capped at `max_num_results`). Keys are rebuilt with
// restore() only when the `keys` container is valid.
template <typename T, typename U, typename V>
std::size_t Trie::predict_breadth_first_(T query, U key_ids, V keys,
    std::size_t max_num_results) const {
  if (max_num_results == 0) {
    return 0;
  }
  UInt32 node = 0;
  std::size_t pos = 0;
  // NULL: no key text is collected during the descent (presumably the last
  // parameter of predict_child() is an optional key output).
  while (!query.ends_at(pos)) {
    if (!predict_child<T>(node, query, pos, NULL)) {
      return 0;
    }
  }
  std::string key;
  std::size_t count = 0;
  if (terminal_flags_[node]) {
    const UInt32 key_id = node_to_key_id(node);
    if (key_ids.is_valid()) {
      key_ids.insert(count, key_id);
    }
    if (keys.is_valid()) {
      restore(key_id, &key);
      keys.insert(count, key);
    }
    if (++count >= max_num_results) {
      return count;
    }
  }
  // A 0 bit at the first child position means the reached node has no
  // children.
  const UInt32 louds_pos = get_child(node);
  if (!louds_[louds_pos]) {
    return count;
  }
  UInt32 node_begin = louds_pos_to_node(louds_pos, node);
  UInt32 node_end = louds_pos_to_node(get_child(node + 1), node + 1);
  while (node_begin < node_end) {
    const UInt32 key_id_begin = node_to_key_id(node_begin);
    const UInt32 key_id_end = node_to_key_id(node_end);
    // IDs and keys are written in two separate passes over the same key-ID
    // range, each stopping at max_num_results.
    if (key_ids.is_valid()) {
      UInt32 temp_count = count;
      for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
        key_ids.insert(temp_count, key_id);
        if (++temp_count >= max_num_results) {
          break;
        }
      }
    }
    if (keys.is_valid()) {
      UInt32 temp_count = count;
      for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
        key.clear();
        restore(key_id, &key);
        keys.insert(temp_count, key);
        if (++temp_count >= max_num_results) {
          break;
        }
      }
    }
    count += key_id_end - key_id_begin;
    if (count >= max_num_results) {
      return max_num_results;
    }
    // Move to the next depth level.
    node_begin = louds_pos_to_node(get_child(node_begin), node_begin);
    node_end = louds_pos_to_node(get_child(node_end), node_end);
  }
  return count;
}

// Depth-first predictive search core.
//
// When keys are requested, the whole search is delegated to
// predict_callback_() with a PredictCallback that stores IDs and keys.
// Otherwise only key IDs are produced:
//   1. Descend along the query with predict_child(); return 0 if it leaves
//      the trie.
//   2. Record the reached node if it is terminal.
//   3. Walk its descendants with an explicit stack of Cells. There is one
//      Cell per depth level, holding a cursor into the LOUDS bits
//      (louds_pos), the current node at that level, and the next key ID to
//      report. A 1 bit means "visit the current node, then descend into its
//      children"; a 0 bit ends the current parent's child list and pops one
//      level.
//
// Returns the number of key IDs recorded (at most `max_num_results`).
template <typename T, typename U, typename V>
std::size_t Trie::predict_depth_first_(T query, U key_ids, V keys,
    std::size_t max_num_results) const {
  if (max_num_results == 0) {
    return 0;
  } else if (keys.is_valid()) {
    PredictCallback<U, V> callback(key_ids, keys, max_num_results);
    return predict_callback_(query, callback);
  }

  UInt32 node = 0;
  std::size_t pos = 0;
  while (!query.ends_at(pos)) {
    if (!predict_child<T>(node, query, pos, NULL)) {
      return 0;
    }
  }
  std::size_t count = 0;
  if (terminal_flags_[node]) {
    if (key_ids.is_valid()) {
      key_ids.insert(count, node_to_key_id(node));
    }
    if (++count >= max_num_results) {
      return count;
    }
  }
  Cell cell;
  cell.set_louds_pos(get_child(node));
  // No children below the reached node: nothing more to report.
  if (!louds_[cell.louds_pos()]) {
    return count;
  }
  cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
  cell.set_key_id(node_to_key_id(cell.node()));
  Vector<Cell> stack;
  stack.push_back(cell);
  // stack_pos is the number of active levels. Cells at indices >= stack_pos
  // are kept and reused when that depth is re-entered (see below), so their
  // cursors continue where they stopped.
  std::size_t stack_pos = 1;
  while (stack_pos != 0) {
    Cell &cur = stack[stack_pos - 1];
    if (!louds_[cur.louds_pos()]) {
      // End of the current parent's child list: skip the 0 bit and pop.
      cur.set_louds_pos(cur.louds_pos() + 1);
      --stack_pos;
      continue;
    }
    cur.set_louds_pos(cur.louds_pos() + 1);
    if (terminal_flags_[cur.node()]) {
      if (key_ids.is_valid()) {
        key_ids.insert(count, cur.key_id());
      }
      if (++count >= max_num_results) {
        return count;
      }
      cur.set_key_id(cur.key_id() + 1);
    }
    // Create the Cell for the next depth only the first time that depth is
    // reached; afterwards the existing Cell is reused.
    if (stack_pos == stack.size()) {
      cell.set_louds_pos(get_child(cur.node()));
      cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
      cell.set_key_id(node_to_key_id(cell.node()));
      stack.push_back(cell);
    }
    // Re-index rather than using `cur`: push_back() above may have
    // reallocated the stack (assumed Vector semantics).
    stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
    ++stack_pos;
  }
  return count;
}

// Predictive counterpart of trie_match(). It matches the query, starting at
// `pos`, against the byte chain that runs from `node` toward the root of
// this trie (the walk stops once the parent is node 0). Unlike trie_match(),
// the query is allowed to end partway through the chain.
//
// Return protocol (same as trie_match()):
//   - `pos` unchanged: the very first label/segment did not match, i.e.
//     nothing was consumed;
//   - mismatch(): the chain diverged from the query after something was
//     consumed;
//   - otherwise: the position just past the consumed part of the query.
// If the query ends inside the chain and `key` is not NULL, the unmatched
// remainder of the chain is appended to `*key` (via trie_restore() here, or
// by the tail/next-level helpers for a linked segment).
template <typename T>
std::size_t Trie::trie_prefix_match(UInt32 node, T query,
    std::size_t pos, std::string *key) const {
  // First step: a failure here is reported as-is (pos or mismatch()).
  if (has_link(node)) {
    std::size_t next_pos;
    if (has_trie()) {
      next_pos = trie_->trie_prefix_match<T>(get_link(node), query, pos, key);
    } else {
      next_pos = tail_prefix_match<T>(
          node, get_link_id(node), query, pos, key);
    }
    if ((next_pos == mismatch()) || (next_pos == pos)) {
      return next_pos;
    }
    pos = next_pos;
  } else if (labels_[node] != query[pos]) {
    return pos;
  } else {
    ++pos;
  }
  node = get_parent(node);
  while (node != 0) {
    if (query.ends_at(pos)) {
      // The query is exhausted: hand back the rest of the chain.
      if (key != NULL) {
        trie_restore(node, key);
      }
      return pos;
    }
    if (has_link(node)) {
      std::size_t next_pos;
      if (has_trie()) {
        next_pos = trie_->trie_prefix_match<T>(
            get_link(node), query, pos, key);
      } else {
        next_pos = tail_prefix_match<T>(
            node, get_link_id(node), query, pos, key);
      }
      // Part of the chain has already been consumed. A linked segment that
      // fails on its first byte (next_pos == pos) is therefore a divergence,
      // not a "nothing matched" result, so report mismatch() as trie_match()
      // does. Returning next_pos here would let the caller treat the partial
      // chain as a successful match.
      if ((next_pos == mismatch()) || (next_pos == pos)) {
        return mismatch();
      }
      pos = next_pos;
    } else if (labels_[node] != query[pos]) {
      return mismatch();
    } else {
      ++pos;
    }
    node = get_parent(node);
  }
  return pos;
}

template std::size_t Trie::trie_prefix_match<CQuery>(UInt32 node,
    CQuery query, std::size_t pos, std::string *key) const;
template std::size_t Trie::trie_prefix_match<const Query &>(UInt32 node,
    const Query &query, std::size_t pos, std::string *key) const;

// Predictive counterpart of tail_match(). It matches the query, starting at
// `pos`, against the tail entry of the linked node `node` (`link_id` is its
// link index), allowing the query to end inside the entry.
//
// Returns `pos` if the first byte differs (nothing consumed), and mismatch()
// if a later byte differs. If the query ends inside the entry, the remaining
// bytes of the entry are appended to `*key` (when `key` is not NULL) and the
// position where the query ended is returned. Otherwise the position just
// past the entry is returned.
template <typename T>
std::size_t Trie::tail_prefix_match(UInt32 node, UInt32 link_id,
    T query, std::size_t pos, std::string *key) const {
  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
  const UInt8 *ptr = tail_[offset];
  if (*ptr != query[pos]) {
    return pos;
  } else if (tail_.mode() == MARISA_BINARY_TAIL) {
    // Binary tail: the length is the gap to the next link's offset.
    const UInt32 length = (links_[link_id + 1] * 256)
        + labels_[link_flags_.select1(link_id + 1)] - offset;
    for (UInt32 i = 1; i < length; ++i) {
      if (query.ends_at(pos + i)) {
        if (key != NULL) {
          key->append(reinterpret_cast<const char *>(ptr + i), length - i);
        }
        return pos + i;
      } else if (ptr[i] != query[pos + i]) {
        return mismatch();
      }
    }
    return pos + length;
  } else {
    // Text tail: the entry runs up to its terminating '\0'.
    for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
      if (query.ends_at(pos)) {
        if (key != NULL) {
          key->append(reinterpret_cast<const char *>(ptr));
        }
        return pos;
      } else if (*ptr != query[pos]) {
        return mismatch();
      }
    }
    return pos;
  }
}

// Explicit instantiations for the two query types used by the public API
// (NUL-terminated CQuery and pointer+length Query).
template std::size_t Trie::tail_prefix_match<CQuery>(
    UInt32 node, UInt32 link_id,
    CQuery query, std::size_t pos, std::string *key) const;
template std::size_t Trie::tail_prefix_match<const Query &>(
    UInt32 node, UInt32 link_id,
    const Query &query, std::size_t pos, std::string *key) const;

}  // namespace marisa