/*
 * Copyright 2015, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "Assert.h"
#include "Log.h"
#include "RSTransforms.h"

#include "bcinfo/MetadataExtractor.h"

#include <string>

#include <llvm/Pass.h>
#include <llvm/IR/DIBuilder.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/InstIterator.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Module.h>
#include <llvm/ADT/SetVector.h>

namespace {

// Directory recorded for the compile unit / file of the generated code.
constexpr char DEBUG_SOURCE_PATH[] = "/opt/renderscriptdebugger/1";
// File name that the compiler-generated ".expand" functions are attributed to.
constexpr char DEBUG_GENERATED_FILE[] = "generated.rs";
// Display name of the global variable whose debug type describes the
// expanded kernel prototype.
constexpr char DEBUG_PROTOTYPE_VAR_NAME[] = "rsDebugOuterForeachT";
// Named metadata node that lists the module's compile units.
constexpr char DEBUG_COMPILE_UNIT_MDNAME[] = "llvm.dbg.cu";

/*
 * LLVM pass to attach debug information to the bits of code
 * generated by the compiler.
 *
 * The pass targets the "<name>.expand" functions that exist for exported
 * forEach kernels and reduction accumulators. Each one receives a
 * DISubprogram in a new "generated.rs" compile unit, debug variables for its
 * arguments and for the rsIndex local (when present), and the location
 * generated.rs:1 on every instruction. The compile unit in which the kernel
 * prototype variable was found is removed from llvm.dbg.cu afterwards.
 */
class RSAddDebugInfoPass : public llvm::ModulePass {

public:
  // Pass ID
  static char ID;

  RSAddDebugInfoPass() : ModulePass(ID), kernelTypeMD(nullptr),
      sourceFileName(nullptr), emptyExpr(nullptr), abiMetaCU(nullptr),
      indexVarType(nullptr) {
  }

  // @brief Attach debug info to every expanded kernel in Module.
  //
  // @return false if the RS metadata could not be extracted (the module is
  //         left untouched), true otherwise.
  bool runOnModule(llvm::Module &Module) override {
    // Gather information about this bcc module.
    bcinfo::MetadataExtractor me(&Module);
    if (!me.extract()) {
      ALOGE("Could not extract metadata from module!");
      return false;
    }

    const size_t nForEachKernels = me.getExportForEachSignatureCount();
    const char **forEachKernels = me.getExportForEachNameList();
    const bcinfo::MetadataExtractor::Reduce *reductions =
        me.getExportReduceList();
    const size_t nReductions = me.getExportReduceCount();

    // Collect the ".expand" wrappers; the set vector de-duplicates while
    // keeping a deterministic order.
    llvm::SmallSetVector<llvm::Function *, 16> expandFuncs{};
    auto pushExpanded = [&](const char *const name) -> void {
      bccAssert(name && *name && (::strcmp(name, ".") != 0));

      const std::string expandName = std::string(name) + ".expand";
      if (llvm::Function *const func = Module.getFunction(expandName))
        expandFuncs.insert(func);
    };

    for (size_t i = 0; i < nForEachKernels; ++i)
      pushExpanded(forEachKernels[i]);

    for (size_t i = 0; i < nReductions; ++i) {
      const bcinfo::MetadataExtractor::Reduce &reduction = reductions[i];
      pushExpanded(reduction.mAccumulatorName);
    }

    // Set up the debug info builder.
    llvm::DIBuilder DebugInfo(Module);
    initializeDebugInfo(DebugInfo, Module);

    for (const auto &expandFunc : expandFuncs) {
      // Attach DI metadata to each generated function.
      // No inlining has occurred at this point so it's safe to name match
      // without worrying about inlined function bodies.
      attachDebugInfo(DebugInfo, *expandFunc);
    }

    DebugInfo.finalize();

    cleanupDebugInfo(Module);

    return true;
  }

private:

  // @brief Initialize the debug info generation.
  //
  // This method does a couple of things:
  // * Create the compile unit and file used for the generated code.
  // * Look up debug metadata for kernel ABI and store it if present.
  // * Store a couple of useful pieces of debug metadata in member
  //   variables so they do not have to be created multiple times.
  //
  // Per-module state is reset first, so a pass instance that runs on more
  // than one module never reuses metadata found in an earlier module.
  void initializeDebugInfo(llvm::DIBuilder &DebugInfo,
                           const llvm::Module &Module) {
    llvm::LLVMContext &ctx = Module.getContext();

    abiMetaCU = nullptr;
    indexVarType = nullptr;

    // Start generating debug information for bcc-generated code.
    DebugInfo.createCompileUnit(llvm::dwarf::DW_LANG_GOOGLE_RenderScript,
                                DEBUG_GENERATED_FILE, DEBUG_SOURCE_PATH,
                                "RS", false, "", 0);

    // Pre-generate and save useful pieces of debug metadata.
    sourceFileName = DebugInfo.createFile(DEBUG_GENERATED_FILE, DEBUG_SOURCE_PATH);
    emptyExpr = DebugInfo.createExpression();

    // Lookup compile unit with kernel ABI debug metadata.
    llvm::NamedMDNode *mdCompileUnitList =
        Module.getNamedMetadata(DEBUG_COMPILE_UNIT_MDNAME);
    bccAssert(mdCompileUnitList != nullptr &&
        "DebugInfo pass could not find any existing compile units.");

    llvm::DIGlobalVariable *kernelPrototypeVarMD = nullptr;
    // Guard explicitly as well: if bccAssert is compiled out, a missing list
    // must fall through to the fallback types below instead of crashing.
    if (mdCompileUnitList != nullptr) {
      for (llvm::MDNode* CUNode : mdCompileUnitList->operands()) {
        if (auto *CU = llvm::dyn_cast<llvm::DICompileUnit>(CUNode)) {
          for (llvm::DIGlobalVariable* GV : CU->getGlobalVariables()) {
            if (GV->getDisplayName() == DEBUG_PROTOTYPE_VAR_NAME) {
              kernelPrototypeVarMD = GV;
              abiMetaCU = CU;
              break;
            }
          }
          if (kernelPrototypeVarMD != nullptr)
            break;
        }
      }
    }

    // Lookup the expanded function interface type metadata.
    llvm::MDTuple *kernelPrototypeMD = nullptr;
    if (kernelPrototypeVarMD != nullptr) {
      // Dig into the metadata to look for function prototype:
      // variable type -> derived type -> derived type -> subroutine type.
      llvm::DIDerivedType *DT = nullptr;
      DT = llvm::cast<llvm::DIDerivedType>(kernelPrototypeVarMD->getType());
      DT = llvm::cast<llvm::DIDerivedType>(DT->getBaseType());
      llvm::DISubroutineType *ST = llvm::cast<llvm::DISubroutineType>(DT->getBaseType());
      // Tolerate a missing type array; the void() fallback below covers it.
      kernelPrototypeMD =
          llvm::dyn_cast_or_null<llvm::MDTuple>(ST->getRawTypeArray());

      // Operand 2 of the type array supplies the rsIndex variable type.
      // Only read it if the prototype actually has that many entries.
      if (kernelPrototypeMD != nullptr &&
          kernelPrototypeMD->getNumOperands() > 2)
        indexVarType = llvm::dyn_cast_or_null<llvm::DIType>(
            kernelPrototypeMD->getOperand(2));
    }
    // Fall back to the function type of void() if there is no proper debug info.
    if (kernelPrototypeMD == nullptr)
      kernelPrototypeMD = llvm::MDTuple::get(ctx, {nullptr});
    // Fall back to unspecified type if we don't have a proper index type.
    if (indexVarType == nullptr)
      indexVarType = DebugInfo.createBasicType("uint32_t", 32, 32,
          llvm::dwarf::DW_ATE_unsigned);

    // Capture the expanded kernel type debug info.
    kernelTypeMD = DebugInfo.createSubroutineType(kernelPrototypeMD);
  }

  /// @brief Add debug information to a generated function.
  ///
  /// This procedure adds the following pieces of debug information
  /// to the function specified by Func:
  /// * Entry for the function to the current compile unit.
  /// * Adds debug info entries for each function argument.
  /// * Adds debug info entry for the rsIndex local variable.
  /// * File/line information to each instruction set to generated.rs:1.
  void attachDebugInfo(llvm::DIBuilder &DebugInfo, llvm::Function &Func) {
    // Lookup the current thread coordinate variable.
    llvm::AllocaInst* indexVar = nullptr;
    for (llvm::Instruction &inst : llvm::instructions(Func)) {
      if (auto *allocaInst = llvm::dyn_cast<llvm::AllocaInst>(&inst)) {
        if (allocaInst->getName() == bcc::BCC_INDEX_VAR_NAME) {
          indexVar = allocaInst;
          break;
        }
      }
    }

    // Create function-level debug metadata.
    llvm::DISubprogram *ExpandedFunc = DebugInfo.createFunction(
        sourceFileName, // scope
        Func.getName(), Func.getName(),
        sourceFileName, 1, kernelTypeMD,
        false, true, 1, 0, false
    );
    Func.setSubprogram(ExpandedFunc);

    // Single location (generated.rs:1 inside the new subprogram) used for
    // every declaration and instruction below.
    const llvm::DebugLoc genLoc = llvm::DebugLoc::get(1, 1, ExpandedFunc);

    // IRBuilder for allocating variables for arguments.
    llvm::IRBuilder<> ir(&*Func.getEntryBlock().begin());

    // Walk through the argument list and expanded function prototype
    // debuginfo in lockstep to create debug entries for
    // the expanded function arguments. Operand 0 of the type array is the
    // return type, so argument types start at index 1.
    unsigned argIdx = 1;
    llvm::MDTuple *argTypes = kernelTypeMD->getTypeArray().get();
    for (llvm::Argument &arg : Func.getArgumentList()) {
      // Stop processing arguments if we run out of debug info.
      if (argIdx >= argTypes->getNumOperands())
        break;

      // Create debuginfo entry for the argument and advance.
      llvm::DILocalVariable *argVarDI = DebugInfo.createParameterVariable(
          ExpandedFunc, arg.getName(), argIdx, sourceFileName, 1,
          llvm::cast<llvm::DIType>(argTypes->getOperand(argIdx).get()),
          true, 0
      );

      // Annotate the argument variable in the IR: spill the argument to a
      // stack slot, reload it, and declare the slot as the debug variable.
      llvm::AllocaInst *argVar =
          ir.CreateAlloca(arg.getType(), nullptr, arg.getName() + ".var");
      llvm::StoreInst *argStore = ir.CreateStore(&arg, argVar);
      llvm::LoadInst *loadedVar = ir.CreateLoad(argVar, arg.getName() + ".l");
      DebugInfo.insertDeclare(argVar, argVarDI, emptyExpr, genLoc, loadedVar);

      // Redirect every use except the spill store to the reloaded value.
      // Advance the iterator before calling Use::set(): set() moves the use
      // onto loadedVar's use list, so a plain range-for over arg.uses()
      // would continue along the wrong list and skip most remaining uses.
      for (auto ui = arg.use_begin(), ue = arg.use_end(); ui != ue;) {
        llvm::Use &u = *ui++;
        if (u.getUser() != argStore)
          u.set(loadedVar);
      }
      argIdx++;
    }

    // Annotate the index variable with metadata.
    if (indexVar) {
      // Debug information for loop index variable.
      llvm::DILocalVariable *indexVarDI = DebugInfo.createAutoVariable(
          ExpandedFunc, bcc::BCC_INDEX_VAR_NAME, sourceFileName, 1,
          indexVarType, true
      );

      // Insert declaration annotation in the instruction stream, then keep
      // the alloca ahead of the declare that references it.
      llvm::Instruction *decl = DebugInfo.insertDeclare(
          indexVar, indexVarDI, emptyExpr, genLoc, indexVar);
      indexVar->moveBefore(decl);
    }

    // Attach location information to each instruction in the function.
    for (llvm::Instruction &inst : llvm::instructions(Func)) {
      inst.setDebugLoc(genLoc);
    }
  }

  // @brief Clean up the debug info.
  //
  // At the moment, it only finds the compile unit in which the kernel
  // prototype variable was found (abiMetaCU) and removes it from
  // llvm.dbg.cu, keeping every other compile unit in its original order.
  void cleanupDebugInfo(llvm::Module& Module) {
    if (abiMetaCU == nullptr)
      return;

    llvm::NamedMDNode *debugMD =
        Module.getNamedMetadata(DEBUG_COMPILE_UNIT_MDNAME);
    if (debugMD == nullptr)
      return;

    // Remove the compile unit with the runtime interface DI.
    llvm::SmallVector<llvm::MDNode*, 4> unitsTmp;
    for (llvm::MDNode *cu : debugMD->operands())
      if (cu != abiMetaCU)
        unitsTmp.push_back(cu);
    debugMD->eraseFromParent();
    debugMD = Module.getOrInsertNamedMetadata(DEBUG_COMPILE_UNIT_MDNAME);
    for (llvm::MDNode *cu : unitsTmp)
      debugMD->addOperand(cu);
  }

private:
  // private attributes
  llvm::DISubroutineType* kernelTypeMD;   // type shared by all .expand subprograms
  llvm::DIFile *sourceFileName;           // generated.rs file node
  llvm::DIExpression *emptyExpr;          // empty expression for dbg.declare
  llvm::DICompileUnit *abiMetaCU;         // CU holding the prototype var, or null
  llvm::DIType *indexVarType;             // type for the rsIndex variable

}; // end class RSAddDebugInfoPass

// Out-of-line definition of the static pass ID referenced by the constructor.
char RSAddDebugInfoPass::ID{0};

// Register the pass with LLVM's pass registry under the name "addrsdi".
// (Already internal: this is inside the anonymous namespace.)
llvm::RegisterPass<RSAddDebugInfoPass>
    RegisterRSAddDebugInfo("addrsdi", "Add RS DebugInfo Pass");

} // end anonymous namespace

namespace bcc {

/// @brief Factory for the RenderScript debug-info module pass.
///
/// @return a newly heap-allocated RSAddDebugInfoPass. The caller is
///         responsible for the object (presumably by handing it to a pass
///         manager).
llvm::ModulePass *createRSAddDebugInfoPass() {
  RSAddDebugInfoPass *const pass = new RSAddDebugInfoPass();
  return pass;
}

}  // end namespace bcc