普通文本  |  304行  |  10.57 KB

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/lua-utils.h"

#include <cstdint>

// Compatibility shim for lua_dump(): Lua 5.3 takes a fourth `strip` argument
// that Lua 5.2 does not. Builds without TC3_AOSP (presumably linked against
// Lua 5.2 -- confirm) get this macro, which drops `strip`, so call sites in
// this file can always pass four arguments.
#ifndef TC3_AOSP
#define lua_dump(L, w, d, s) lua_dump((L), (w), (d))
#endif

namespace libtextclassifier3 {
namespace {
// Upvalue indices for the flatbuffer callback.
static constexpr int kSchemaArgId = 1;
static constexpr int kTypeArgId = 2;
static constexpr int kTableArgId = 3;

static constexpr luaL_Reg defaultlibs[] = {{"_G", luaopen_base},
                                           {LUA_TABLIBNAME, luaopen_table},
                                           {LUA_STRLIBNAME, luaopen_string},
                                           {LUA_BITLIBNAME, luaopen_bit32},
                                           {LUA_MATHLIBNAME, luaopen_math},
                                           {nullptr, nullptr}};

// Implementation of a lua_Writer that appends the data to a string.
int LuaStringWriter(lua_State *state, const void *data, size_t size,
                    void *result) {
  std::string *const result_string = static_cast<std::string *>(result);
  result_string->insert(result_string->size(), static_cast<const char *>(data),
                        size);
  return LUA_OK;
}

}  // namespace

// Creates the environment's Lua state via luaL_newstate(); the destructor
// closes it again.
LuaEnvironment::LuaEnvironment() : state_(luaL_newstate()) {}

// Closes the owned Lua state unless it is null.
LuaEnvironment::~LuaEnvironment() {
  if (state_ == nullptr) {
    return;
  }
  lua_close(state_);
}

// C entry point for Lua: recovers the Iterator via
// FromUpValue(kIteratorArgId) and forwards to its Next().
int LuaEnvironment::Iterator::NextCallback(lua_State *state) {
  Iterator *const iterator = FromUpValue<Iterator *>(kIteratorArgId, state);
  return iterator->Next(state);
}

// C entry point for Lua: recovers the Iterator via
// FromUpValue(kIteratorArgId) and forwards to its Length().
int LuaEnvironment::Iterator::LengthCallback(lua_State *state) {
  Iterator *const iterator = FromUpValue<Iterator *>(kIteratorArgId, state);
  return iterator->Length(state);
}

// C entry point for Lua: recovers the Iterator via
// FromUpValue(kIteratorArgId) and forwards to its Item().
int LuaEnvironment::Iterator::ItemCallback(lua_State *state) {
  Iterator *const iterator = FromUpValue<Iterator *>(kIteratorArgId, state);
  return iterator->Item(state);
}

// C entry point for Lua: recovers the Iterator via
// FromUpValue(kIteratorArgId) and forwards to its Iteritems().
int LuaEnvironment::Iterator::IteritemsCallback(lua_State *state) {
  Iterator *const iterator = FromUpValue<Iterator *>(kIteratorArgId, state);
  return iterator->Iteritems(state);
}

// Pushes onto `state` a proxy table for the flatbuffer `table` of type `type`.
// Fields are resolved lazily: the proxy's metatable holds, under kIndexKey, a
// GetFieldCallback closure whose three upvalues are `schema`, `type` and
// `table`.
//
// Every proxy gets its own metatable. A named registry metatable
// (luaL_newmetatable(state, name)) would be shared by all proxies pushed under
// the same name, and installing the closure for a later proxy would overwrite
// the one of every earlier proxy, making them read the wrong table/type.
// `name` is stored as the metatable's `__name` (the field luaL_newmetatable
// sets in Lua 5.3) purely for diagnostics.
void LuaEnvironment::PushFlatbuffer(const char *name,
                                    const reflection::Schema *schema,
                                    const reflection::Object *type,
                                    const flatbuffers::Table *table,
                                    lua_State *state) {
  lua_newtable(state);  // The proxy table.
  lua_newtable(state);  // Its private metatable.
  lua_pushstring(state, name);
  lua_setfield(state, -2, "__name");
  lua_pushlightuserdata(state, AsUserData(schema));
  lua_pushlightuserdata(state, AsUserData(type));
  lua_pushlightuserdata(state, AsUserData(table));
  lua_pushcclosure(state, &GetFieldCallback, /*n=*/3);
  lua_setfield(state, -2, kIndexKey);
  lua_setmetatable(state, -2);  // Leaves only the proxy table on the stack.
}

// Field-lookup closure installed by PushFlatbuffer: unpacks the schema, the
// object type and the table from its upvalues and delegates to GetField().
int LuaEnvironment::GetFieldCallback(lua_State *state) {
  const auto *const schema_arg =
      FromUpValue<reflection::Schema *>(kSchemaArgId, state);
  const auto *const type_arg =
      FromUpValue<reflection::Object *>(kTypeArgId, state);
  const auto *const table_arg =
      FromUpValue<flatbuffers::Table *>(kTableArgId, state);
  return GetField(schema_arg, type_arg, table_arg, state);
}

// Resolves `proxy[key]` for a lazy flatbuffer proxy created by
// PushFlatbuffer(); the key is read from the top of the Lua stack.
// Bool, integer, floating-point and string fields are pushed by value (an
// unset string reads as ""); table fields are pushed as a nested proxy.
// Returns the number of values pushed (1).
//
// Errors are raised with lua_error(), which uses the value on top of the stack
// (the requested key) as the error object and never returns; the `return 0`
// statements after it only keep the compiler happy.
int LuaEnvironment::GetField(const reflection::Schema *schema,
                             const reflection::Object *type,
                             const flatbuffers::Table *table,
                             lua_State *state) {
  // lua_tostring() yields nullptr for keys that are neither strings nor
  // numbers (e.g. `proxy[true]`); LookupByKey() must never see a null key.
  const char *field_name = lua_tostring(state, -1);
  if (field_name == nullptr) {
    TC3_LOG(ERROR) << "Expected a string field name.";
    lua_error(state);
    return 0;
  }
  const reflection::Field *field = type->fields()->LookupByKey(field_name);
  if (field == nullptr) {
    TC3_LOG(ERROR) << "Unknown field: " << field_name;
    lua_error(state);
    return 0;
  }
  // Provide primitive fields directly.
  const reflection::BaseType field_type = field->type()->base_type();
  switch (field_type) {
    case reflection::Bool:
      lua_pushboolean(state, table->GetField<uint8_t>(
                                 field->offset(), field->default_integer()));
      break;
    case reflection::Byte:
      lua_pushinteger(state, table->GetField<int8_t>(
                                 field->offset(), field->default_integer()));
      break;
    case reflection::UByte:
      lua_pushinteger(state, table->GetField<uint8_t>(
                                 field->offset(), field->default_integer()));
      break;
    case reflection::Short:
      lua_pushinteger(state, table->GetField<int16_t>(
                                 field->offset(), field->default_integer()));
      break;
    case reflection::UShort:
      lua_pushinteger(state, table->GetField<uint16_t>(
                                 field->offset(), field->default_integer()));
      break;
    case reflection::Int:
      lua_pushinteger(state, table->GetField<int32>(field->offset(),
                                                    field->default_integer()));
      break;
    case reflection::Long:
      lua_pushinteger(state, table->GetField<int64>(field->offset(),
                                                    field->default_integer()));
      break;
    case reflection::Float:
      lua_pushnumber(state, table->GetField<float>(field->offset(),
                                                   field->default_real()));
      break;
    case reflection::Double:
      lua_pushnumber(state, table->GetField<double>(field->offset(),
                                                    field->default_real()));
      break;
    case reflection::String: {
      const flatbuffers::String *string_value =
          table->GetPointer<const flatbuffers::String *>(field->offset());
      if (string_value != nullptr) {
        lua_pushlstring(state, string_value->data(), string_value->Length());
      } else {
        lua_pushlstring(state, "", 0);
      }
      break;
    }
    case reflection::Obj: {
      const flatbuffers::Table *field_table =
          table->GetPointer<const flatbuffers::Table *>(field->offset());
      if (field_table == nullptr) {
        TC3_LOG(ERROR) << "Field was not set in entity data.";
        lua_error(state);
        return 0;
      }
      const reflection::Object *field_type =
          schema->objects()->Get(field->type()->index());
      PushFlatbuffer(field->name()->c_str(), schema, field_type, field_table,
                     state);
      break;
    }
    default:
      TC3_LOG(ERROR) << "Unsupported type: " << field_type;
      lua_error(state);
      return 0;
  }
  return 1;
}

// Copies the Lua table on top of the stack into `buffer`. Every key must be a
// string naming a field of the buffer's type. Values are converted according
// to the field's base type, and nested tables are read recursively into the
// corresponding sub-buffer. The table itself is left on the stack.
// Returns LUA_OK on success. On failure it logs and calls lua_error(), which
// never returns; the `return LUA_ERRRUN` statements are unreachable fallbacks.
int LuaEnvironment::ReadFlatbuffer(ReflectiveFlatbuffer *buffer) {
  if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
    TC3_LOG(ERROR) << "Expected actions table, got: "
                   << lua_type(state_, /*idx=*/-1);
    lua_error(state_);
    return LUA_ERRRUN;
  }

  // During iteration the stack is [..., table, key, value].
  lua_pushnil(state_);
  while (lua_next(state_, /*idx=*/-2)) {
    // Only string keys can name a field. Checking first also avoids calling
    // lua_tolstring() on a numeric key, which would convert the key in place
    // and confuse the following lua_next().
    if (lua_type(state_, /*idx=*/-2) != LUA_TSTRING) {
      TC3_LOG(ERROR) << "Expected a string key, got: "
                     << lua_typename(state_, lua_type(state_, /*idx=*/-2));
      lua_error(state_);
      return LUA_ERRRUN;
    }
    const StringPiece key = ReadString(/*index=*/-2);
    const reflection::Field *field = buffer->GetFieldOrNull(key);
    if (field == nullptr) {
      TC3_LOG(ERROR) << "Unknown field: " << key.ToString();
      lua_error(state_);
      return LUA_ERRRUN;
    }
    switch (field->type()->base_type()) {
      case reflection::Obj: {
        // Read the nested table (the value on top of the stack), then carry on
        // with the remaining keys of this table instead of returning early.
        const int status = ReadFlatbuffer(buffer->Mutable(field));
        if (status != LUA_OK) {
          return status;
        }
        break;
      }
      case reflection::Bool:
        buffer->Set(field,
                    static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
        break;
      case reflection::Int:
        buffer->Set(field, static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
        break;
      case reflection::Long:
        buffer->Set(field,
                    static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
        break;
      case reflection::Float:
        buffer->Set(field,
                    static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
        break;
      case reflection::Double:
        buffer->Set(field,
                    static_cast<double>(lua_tonumber(state_, /*idx=*/-1)));
        break;
      case reflection::String: {
        buffer->Set(field, ReadString(/*index=*/-1));
        break;
      }
      default:
        TC3_LOG(ERROR) << "Unsupported type: " << field->type()->base_type();
        lua_error(state_);
        return LUA_ERRRUN;
    }
    // Pop the value but keep the key for the next lua_next() call.
    lua_pop(state_, 1);
  }
  // lua_next() popped the last key; only the table remains, owned by the
  // caller.
  return LUA_OK;
}

void LuaEnvironment::LoadDefaultLibraries() {
  for (const luaL_Reg *lib = defaultlibs; lib->func; lib++) {
    luaL_requiref(state_, lib->name, lib->func, 1);
    lua_pop(state_, 1); /* remove lib */
  }
}

void LuaEnvironment::PushValue(const Variant &value) {
  if (value.HasInt()) {
    lua_pushnumber(state_, value.IntValue());
  } else if (value.HasInt64()) {
    lua_pushnumber(state_, value.Int64Value());
  } else if (value.HasBool()) {
    lua_pushboolean(state_, value.BoolValue());
  } else if (value.HasFloat()) {
    lua_pushnumber(state_, value.FloatValue());
  } else if (value.HasDouble()) {
    lua_pushnumber(state_, value.DoubleValue());
  } else if (value.HasString()) {
    lua_pushlstring(state_, value.StringValue().data(),
                    value.StringValue().size());
  } else {
    TC3_LOG(FATAL) << "Unknown value type.";
  }
}

// Returns a view of the value at stack position `index` as produced by
// lua_tolstring(); the view is empty when lua_tolstring() yields no string.
// The view points into Lua-owned memory and may dangle once the value is
// removed from the stack.
StringPiece LuaEnvironment::ReadString(const int index) const {
  size_t num_bytes = 0;
  const char *const bytes = lua_tolstring(state_, index, &num_bytes);
  return StringPiece(bytes, num_bytes);
}

// Pushes a copy of `str`; the length is passed explicitly, so embedded NULs
// are preserved.
void LuaEnvironment::PushString(const StringPiece str) {
  const char *const chars = str.data();
  const auto length = str.size();
  lua_pushlstring(state_, chars, length);
}

// Pushes a lazy proxy for `table`, interpreting it as the schema's root type
// and naming it after that type.
void LuaEnvironment::PushFlatbuffer(const reflection::Schema *schema,
                                    const flatbuffers::Table *table) {
  const reflection::Object *const root_type = schema->root_table();
  PushFlatbuffer(root_type->name()->c_str(), schema, root_type, table, state_);
}

// Runs `func` inside lua_pcall(), so Lua errors raised while it runs are
// caught and reported as a status code. The top `num_args` stack values become
// the arguments of the protected closure. `func` returns how many results it
// pushed, and `num_results` are kept as for lua_pcall(). Returns lua_pcall()'s
// status (LUA_OK on success).
int LuaEnvironment::RunProtected(const std::function<int()> &func,
                                 const int num_args, const int num_results) {
  struct ProtectedCall {
    // Borrowed rather than copied: `func` outlives the lua_pcall() below, and
    // copying a std::function may allocate.
    const std::function<int()> *func;

    static int run(lua_State *state) {
      // Read the pointer to the ProtectedCall struct from the only upvalue.
      const ProtectedCall *call = static_cast<const ProtectedCall *>(
          lua_touserdata(state, lua_upvalueindex(1)));
      return (*call->func)();
    }
  };
  ProtectedCall protected_call = {&func};
  lua_pushlightuserdata(state_, &protected_call);
  lua_pushcclosure(state_, &ProtectedCall::run, /*n=*/1);
  // Put the closure before the arguments on the stack.
  if (num_args > 0) {
    lua_insert(state_, -(1 + num_args));
  }
  return lua_pcall(state_, num_args, num_results, /*errorfunc=*/0);
}

// Compiles `snippet` and appends the dumped bytecode (strip requested, see the
// lua_dump shim) to `*bytecode`. Returns false after logging if the snippet
// does not compile or cannot be dumped. The Lua stack is restored in every
// case.
bool LuaEnvironment::Compile(StringPiece snippet, std::string *bytecode) {
  const int load_status =
      luaL_loadbuffer(state_, snippet.data(), snippet.size(),
                      /*name=*/nullptr);
  if (load_status != LUA_OK) {
    // On failure the error message, not a chunk, is on top of the stack.
    TC3_LOG(ERROR) << "Could not compile lua snippet: "
                   << ReadString(/*index=*/-1).ToString();
    lua_pop(state_, 1);
    return false;
  }
  // The compiled chunk is on top of the stack: serialize it, then drop it.
  const bool dumped =
      lua_dump(state_, LuaStringWriter, bytecode, /*strip*/ 1) == LUA_OK;
  if (!dumped) {
    TC3_LOG(ERROR) << "Could not dump compiled lua snippet.";
  }
  lua_pop(state_, 1);
  return dumped;
}

// Convenience wrapper: compiles `snippet` in a throwaway LuaEnvironment and
// appends the bytecode to `*bytecode`. Returns false on failure.
bool Compile(StringPiece snippet, std::string *bytecode) {
  LuaEnvironment scratch_env;
  return scratch_env.Compile(snippet, bytecode);
}

}  // namespace libtextclassifier3