/*
 * Copyright 2015, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "RSScriptGroupFusion.h"

#include "Assert.h"
#include "Log.h"
#include "bcc/BCCContext.h"
#include "bcc/Source.h"
#include "bcinfo/MetadataExtractor.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

using llvm::Function;
using llvm::Module;

using std::string;

namespace bcc {

namespace {

const Function* getInvokeFunction(const Source& source, const int slot,
                                  Module* newModule) {

  bcinfo::MetadataExtractor &metadata = *source.getMetadata();
  const char* functionName = metadata.getExportFuncNameList()[slot];
  Function* func = newModule->getFunction(functionName);
  // Materialize the function so that later the caller can inspect its argument
  // and return types.
  newModule->materialize(func);
  return func;
}

const Function*
getFunction(Module* mergedModule, const Source* source, const int slot,
            uint32_t* signature) {

  bcinfo::MetadataExtractor &metadata = *source->getMetadata();
  const char* functionName = metadata.getExportForEachNameList()[slot];
  if (functionName == nullptr || !functionName[0]) {
    ALOGE("Kernel fusion (module %s slot %d): failed to find kernel function",
          source->getName().c_str(), slot);
    return nullptr;
  }

  if (metadata.getExportForEachInputCountList()[slot] > 1) {
    ALOGE("Kernel fusion (module %s function %s): cannot handle multiple inputs",
          source->getName().c_str(), functionName);
    return nullptr;
  }

  if (signature != nullptr) {
    *signature = metadata.getExportForEachSignatureList()[slot];
  }

  const Function* function = mergedModule->getFunction(functionName);

  return function;
}

// The whitelist of supported signature bits. Context or user data arguments are
// not currently supported in kernel fusion. To support them or any new kinds of
// arguments in the future, it requires not only listing the signature bits here,
// but also implementing additional necessary fusion logic in the getFusedFuncSig(),
// getFusedFuncType(), and fuseKernels() functions below.
constexpr uint32_t ExpectedSignatureBits =
        bcinfo::MD_SIG_In |
        bcinfo::MD_SIG_Out |
        bcinfo::MD_SIG_X |
        bcinfo::MD_SIG_Y |
        bcinfo::MD_SIG_Z |
        bcinfo::MD_SIG_Kernel;

int getFusedFuncSig(const std::vector<Source*>& sources,
                    const std::vector<int>& slots,
                    uint32_t* retSig) {
  *retSig = 0;
  uint32_t firstSignature = 0;
  uint32_t signature = 0;
  auto slotIter = slots.begin();
  for (const Source* source : sources) {
    const int slot = *slotIter++;
    bcinfo::MetadataExtractor &metadata = *source->getMetadata();

    if (metadata.getExportForEachInputCountList()[slot] > 1) {
      ALOGE("Kernel fusion (module %s slot %d): cannot handle multiple inputs",
            source->getName().c_str(), slot);
      return -1;
    }

    signature = metadata.getExportForEachSignatureList()[slot];
    if (signature & ~ExpectedSignatureBits) {
      ALOGE("Kernel fusion (module %s slot %d): Unexpected signature %x",
            source->getName().c_str(), slot, signature);
      return -1;
    }

    if (firstSignature == 0) {
      firstSignature = signature;
    }

    *retSig |= signature;
  }

  if (!bcinfo::MetadataExtractor::hasForEachSignatureIn(firstSignature)) {
    *retSig &= ~bcinfo::MD_SIG_In;
  }

  if (!bcinfo::MetadataExtractor::hasForEachSignatureOut(signature)) {
    *retSig &= ~bcinfo::MD_SIG_Out;
  }

  return 0;
}

llvm::FunctionType* getFusedFuncType(bcc::BCCContext& Context,
                                     const std::vector<Source*>& sources,
                                     const std::vector<int>& slots,
                                     Module* M,
                                     uint32_t* signature) {
  int error = getFusedFuncSig(sources, slots, signature);

  if (error < 0) {
    return nullptr;
  }

  const Function* firstF = getFunction(M, sources.front(), slots.front(), nullptr);

  bccAssert (firstF != nullptr);

  llvm::SmallVector<llvm::Type*, 8> ArgTys;

  if (bcinfo::MetadataExtractor::hasForEachSignatureIn(*signature)) {
    ArgTys.push_back(firstF->arg_begin()->getType());
  }

  llvm::Type* I32Ty = llvm::IntegerType::get(Context.getLLVMContext(), 32);
  if (bcinfo::MetadataExtractor::hasForEachSignatureX(*signature)) {
    ArgTys.push_back(I32Ty);
  }
  if (bcinfo::MetadataExtractor::hasForEachSignatureY(*signature)) {
    ArgTys.push_back(I32Ty);
  }
  if (bcinfo::MetadataExtractor::hasForEachSignatureZ(*signature)) {
    ArgTys.push_back(I32Ty);
  }

  const Function* lastF = getFunction(M, sources.back(), slots.back(), nullptr);

  bccAssert (lastF != nullptr);

  llvm::Type* retTy = lastF->getReturnType();

  return llvm::FunctionType::get(retTy, ArgTys, false);
}

}  // anonymous namespace

bool fuseKernels(bcc::BCCContext& Context,
                 const std::vector<Source *>& sources,
                 const std::vector<int>& slots,
                 const std::string& fusedName,
                 Module* mergedModule) {
  bccAssert(sources.size() == slots.size() && "sources and slots differ in size");

  uint32_t fusedFunctionSignature;

  llvm::FunctionType* fusedType =
          getFusedFuncType(Context, sources, slots, mergedModule, &fusedFunctionSignature);

  if (fusedType == nullptr) {
    return false;
  }

  Function* fusedKernel =
          (Function*)(mergedModule->getOrInsertFunction(fusedName, fusedType));

  llvm::LLVMContext& ctxt = Context.getLLVMContext();

  llvm::BasicBlock* block = llvm::BasicBlock::Create(ctxt, "entry", fusedKernel);
  llvm::IRBuilder<> builder(block);

  Function::arg_iterator argIter = fusedKernel->arg_begin();

  llvm::Value* dataElement = nullptr;
  if (bcinfo::MetadataExtractor::hasForEachSignatureIn(fusedFunctionSignature)) {
    dataElement = &*(argIter++);
    dataElement->setName("DataIn");
  }

  llvm::Value* X = nullptr;
  if (bcinfo::MetadataExtractor::hasForEachSignatureX(fusedFunctionSignature)) {
    X = &*(argIter++);
    X->setName("x");
  }

  llvm::Value* Y = nullptr;
  if (bcinfo::MetadataExtractor::hasForEachSignatureY(fusedFunctionSignature)) {
    Y = &*(argIter++);
    Y->setName("y");
  }

  llvm::Value* Z = nullptr;
  if (bcinfo::MetadataExtractor::hasForEachSignatureZ(fusedFunctionSignature)) {
    Z = &*(argIter++);
    Z->setName("z");
  }

  auto slotIter = slots.begin();
  for (const Source* source : sources) {
    int slot = *slotIter;

    uint32_t inputFunctionSignature;
    const Function* inputFunction =
            getFunction(mergedModule, source, slot, &inputFunctionSignature);
    if (inputFunction == nullptr) {
      // Either failed to find the kernel function, or the function has multiple inputs.
      return false;
    }

    // Don't try to fuse a non-kernel
    if (!bcinfo::MetadataExtractor::hasForEachSignatureKernel(inputFunctionSignature)) {
      ALOGE("Kernel fusion (module %s function %s): not a kernel",
            source->getName().c_str(), inputFunction->getName().str().c_str());
      return false;
    }

    std::vector<llvm::Value*> args;

    if (bcinfo::MetadataExtractor::hasForEachSignatureIn(inputFunctionSignature)) {
      if (dataElement == nullptr) {
        ALOGE("Kernel fusion (module %s function %s): expected input, but got null",
              source->getName().c_str(), inputFunction->getName().str().c_str());
        return false;
      }

      const llvm::FunctionType* funcTy = inputFunction->getFunctionType();
      llvm::Type* firstArgType = funcTy->getParamType(0);

      if (dataElement->getType() != firstArgType) {
        std::string msg;
        llvm::raw_string_ostream rso(msg);
        rso << "Mismatching argument type, expected ";
        firstArgType->print(rso);
        rso << ", received ";
        dataElement->getType()->print(rso);
        ALOGE("Kernel fusion (module %s function %s): %s", source->getName().c_str(),
              inputFunction->getName().str().c_str(), rso.str().c_str());
        return false;
      }

      args.push_back(dataElement);
    } else {
      // Only the first kernel in a batch is allowed to have no input
      if (slotIter != slots.begin()) {
        ALOGE("Kernel fusion (module %s function %s): function not first in batch takes no input",
              source->getName().c_str(), inputFunction->getName().str().c_str());
        return false;
      }
    }

    if (bcinfo::MetadataExtractor::hasForEachSignatureX(inputFunctionSignature)) {
      args.push_back(X);
    }

    if (bcinfo::MetadataExtractor::hasForEachSignatureY(inputFunctionSignature)) {
      args.push_back(Y);
    }

    if (bcinfo::MetadataExtractor::hasForEachSignatureZ(inputFunctionSignature)) {
      args.push_back(Z);
    }

    dataElement = builder.CreateCall((llvm::Value*)inputFunction, args);

    slotIter++;
  }

  if (fusedKernel->getReturnType()->isVoidTy()) {
    builder.CreateRetVoid();
  } else {
    builder.CreateRet(dataElement);
  }

  llvm::NamedMDNode* ExportForEachNameMD =
    mergedModule->getOrInsertNamedMetadata("#rs_export_foreach_name");

  llvm::MDString* nameMDStr = llvm::MDString::get(ctxt, fusedName);
  llvm::MDNode* nameMDNode = llvm::MDNode::get(ctxt, nameMDStr);
  ExportForEachNameMD->addOperand(nameMDNode);

  llvm::NamedMDNode* ExportForEachMD =
    mergedModule->getOrInsertNamedMetadata("#rs_export_foreach");
  llvm::MDString* sigMDStr = llvm::MDString::get(ctxt,
                                                 llvm::utostr(fusedFunctionSignature));
  llvm::MDNode* sigMDNode = llvm::MDNode::get(ctxt, sigMDStr);
  ExportForEachMD->addOperand(sigMDNode);

  return true;
}

bool renameInvoke(BCCContext& Context, const Source* source, const int slot,
                  const std::string& newName, Module* module) {
  const llvm::Function* F = getInvokeFunction(*source, slot, module);
  std::vector<llvm::Type*> params;
  for (auto I = F->arg_begin(), E = F->arg_end(); I != E; ++I) {
    params.push_back(I->getType());
  }
  llvm::Type* returnTy = F->getReturnType();

  llvm::FunctionType* batchFuncTy =
          llvm::FunctionType::get(returnTy, params, false);

  llvm::Function* newF =
          llvm::Function::Create(batchFuncTy,
                                 llvm::GlobalValue::ExternalLinkage, newName,
                                 module);

  llvm::BasicBlock* block = llvm::BasicBlock::Create(Context.getLLVMContext(),
                                                     "entry", newF);
  llvm::IRBuilder<> builder(block);

  llvm::Function::arg_iterator argIter = newF->arg_begin();
  llvm::Value* arg1 = &*(argIter++);
  builder.CreateCall((llvm::Value*)F, arg1);

  builder.CreateRetVoid();

  llvm::NamedMDNode* ExportFuncNameMD =
          module->getOrInsertNamedMetadata("#rs_export_func");
  llvm::MDString* strMD = llvm::MDString::get(module->getContext(), newName);
  llvm::MDNode* nodeMD = llvm::MDNode::get(module->getContext(), strMD);
  ExportFuncNameMD->addOperand(nodeMD);

  return true;
}

}  // namespace bcc