/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/regex-match.h"

#include <memory>

#include "annotator/types.h"
#include "utils/lua-utils.h"

#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lualib.h"
#ifdef __cplusplus
}
#endif

namespace libtextclassifier3 {
namespace {

// Provide a lua environment for running regex match post verification.
// It sets up and exposes the match data as well as the context.
class LuaVerifier : private LuaEnvironment {
 public:
  static std::unique_ptr<LuaVerifier> Create(
      const std::string& context, const std::string& verifier_code,
      const UniLib::RegexMatcher* matcher);

  bool Verify(bool* result);

 private:
  explicit LuaVerifier(const std::string& context,
                       const std::string& verifier_code,
                       const UniLib::RegexMatcher* matcher)
      : context_(context), verifier_code_(verifier_code), matcher_(matcher) {}
  bool Initialize();

  // Provides details of a capturing group to lua.
  int GetCapturingGroup();

  const std::string& context_;
  const std::string& verifier_code_;
  const UniLib::RegexMatcher* matcher_;
};

bool LuaVerifier::Initialize() {
  // Every step below manipulates the lua stack, so it is executed as a
  // protected call: a failure yields a status code instead of a lua panic.
  const auto setup_globals = [this] {
    LoadDefaultLibraries();

    // `context`: the context string this verifier was created with.
    PushString(context_);
    lua_setglobal(state_, "context");

    // `match`: entry `match[i]` is produced on demand by GetCapturingGroup
    // and carries `begin` (span start), `end` (span end) and `text`.
    BindTable<LuaVerifier, &LuaVerifier::GetCapturingGroup>("match");
    lua_setglobal(state_, "match");
    return LUA_OK;
  };
  const auto setup_status = RunProtected(setup_globals);
  return setup_status == LUA_OK;
}

std::unique_ptr<LuaVerifier> LuaVerifier::Create(
    const std::string& context, const std::string& verifier_code,
    const UniLib::RegexMatcher* matcher) {
  // The constructor is private, so std::make_unique is not an option here.
  std::unique_ptr<LuaVerifier> instance(
      new LuaVerifier(context, verifier_code, matcher));
  if (instance->Initialize()) {
    return instance;
  }
  TC3_LOG(ERROR) << "Could not initialize lua environment.";
  return nullptr;
}

int LuaVerifier::GetCapturingGroup() {
  if (lua_type(state_, /*idx=*/-1) != LUA_TNUMBER) {
    TC3_LOG(ERROR) << "Unexpected type for match group lookup: "
                   << lua_type(state_, /*idx=*/-1);
    lua_error(state_);
    return 0;
  }
  const int group_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
  int status = UniLib::RegexMatcher::kNoError;
  const CodepointSpan span = {matcher_->Start(group_id, &status),
                              matcher_->End(group_id, &status)};
  std::string text = matcher_->Group(group_id, &status).ToUTF8String();
  if (status != UniLib::RegexMatcher::kNoError) {
    TC3_LOG(ERROR) << "Could not extract span from capturing group.";
    lua_error(state_);
    return 0;
  }
  lua_newtable(state_);
  lua_pushinteger(state_, span.first);
  lua_setfield(state_, /*idx=*/-2, "begin");
  lua_pushinteger(state_, span.second);
  lua_setfield(state_, /*idx=*/-2, "end");
  PushString(text);
  lua_setfield(state_, /*idx=*/-2, "text");
  return 1;
}

bool LuaVerifier::Verify(bool* result) {
  if (luaL_loadbuffer(state_, verifier_code_.data(), verifier_code_.size(),
                      /*name=*/nullptr) != LUA_OK) {
    TC3_LOG(ERROR) << "Could not load verifier snippet.";
    return false;
  }

  if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
    TC3_LOG(ERROR) << "Could not run verifier snippet.";
    return false;
  }

  if (RunProtected(
          [this, result] {
            if (lua_type(state_, /*idx=*/-1) != LUA_TBOOLEAN) {
              TC3_LOG(ERROR) << "Unexpected verification result type: "
                             << lua_type(state_, /*idx=*/-1);
              lua_error(state_);
              return LUA_ERRRUN;
            }
            *result = lua_toboolean(state_, /*idx=*/-1);
            return LUA_OK;
          },
          /*num_args=*/1) != LUA_OK) {
    TC3_LOG(ERROR) << "Could not read lua result.";
    return false;
  }
  return true;
}

}  // namespace

// Copies the text of capturing group `group_id` from `matcher` into the
// flatbuffer field addressed by `field_path`.
//
// Returns false without calling ParseAndSet when the group text cannot be
// read or is empty; otherwise returns the result of ParseAndSet.
bool SetFieldFromCapturingGroup(const int group_id,
                                const FlatbufferFieldPath* field_path,
                                const UniLib::RegexMatcher* matcher,
                                ReflectiveFlatbuffer* flatbuffer) {
  int group_status = UniLib::RegexMatcher::kNoError;
  std::string value = matcher->Group(group_id, &group_status).ToUTF8String();
  const bool has_value =
      group_status == UniLib::RegexMatcher::kNoError && !value.empty();
  // Short-circuit: the flatbuffer is only touched when there is a value.
  return has_value && flatbuffer->ParseAndSet(field_path, value);
}

// Runs `lua_verifier_code` as a post-verification step for the regex match
// currently held by `matcher`, with `context` exposed to the snippet.
//
// Returns the boolean produced by the snippet. Returns false if the lua
// verifier could not be created, or if the snippet could not be executed or
// did not yield a boolean.
bool VerifyMatch(const std::string& context,
                 const UniLib::RegexMatcher* matcher,
                 const std::string& lua_verifier_code) {
  auto verifier = LuaVerifier::Create(context, lua_verifier_code, matcher);
  if (verifier == nullptr) {
    TC3_LOG(ERROR) << "Could not create verifier.";
    return false;
  }
  // The snippet's verdict; only meaningful when Verify() succeeds.
  bool verified = false;
  if (!verifier->Verify(&verified)) {
    // Distinct from the creation failure above so the two are
    // distinguishable in logs.
    TC3_LOG(ERROR) << "Could not run verifier.";
    return false;
  }
  return verified;
}

}  // namespace libtextclassifier3