/*
 * Copyright 2015, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "RSScriptGroupFusion.h"

#include "Assert.h"
#include "Log.h"

#include "bcc/BCCContext.h"
#include "bcc/Source.h"
#include "bcinfo/MetadataExtractor.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

using llvm::Function;
using llvm::Module;

using std::string;

namespace bcc {

namespace {

// Looks up, in |newModule|, the exported invokable function recorded at
// |slot| of |source|'s export-func metadata, and materializes it.
//
// Returns nullptr (after logging) when the metadata has no name for the slot
// or |newModule| contains no function of that name. Passing a null function to
// Module::materialize() is not allowed, so the lookup is checked first.
const Function* getInvokeFunction(const Source& source, const int slot,
                                  Module* newModule) {
  bcinfo::MetadataExtractor &metadata = *source.getMetadata();
  const char* functionName = metadata.getExportFuncNameList()[slot];
  if (functionName == nullptr || !functionName[0]) {
    ALOGE("Invoke renaming (module %s slot %d): missing invoke function name",
          source.getName().c_str(), slot);
    return nullptr;
  }

  Function* func = newModule->getFunction(functionName);
  if (func == nullptr) {
    ALOGE("Invoke renaming (module %s slot %d): failed to find invoke "
          "function %s", source.getName().c_str(), slot, functionName);
    return nullptr;
  }

  // Materialize the function so that later the caller can inspect its argument
  // and return types.
  newModule->materialize(func);
  return func;
}

// Returns the kernel exported at foreach |slot| of |source|, as found in
// |mergedModule|. If |signature| is non-null, the slot's metadata signature is
// stored into it (before the module lookup is attempted).
//
// Returns nullptr (after logging) when the slot has no kernel name, when the
// kernel takes more than one input, or when |mergedModule| does not contain a
// function with the kernel's name.
const Function* getFunction(Module* mergedModule, const Source* source,
                            const int slot, uint32_t* signature) {
  bcinfo::MetadataExtractor &metadata = *source->getMetadata();
  const char* functionName = metadata.getExportForEachNameList()[slot];
  if (functionName == nullptr || !functionName[0]) {
    ALOGE("Kernel fusion (module %s slot %d): failed to find kernel function",
          source->getName().c_str(), slot);
    return nullptr;
  }

  if (metadata.getExportForEachInputCountList()[slot] > 1) {
    ALOGE("Kernel fusion (module %s function %s): cannot handle multiple inputs",
          source->getName().c_str(), functionName);
    return nullptr;
  }

  if (signature != nullptr) {
    *signature = metadata.getExportForEachSignatureList()[slot];
  }

  const Function* function = mergedModule->getFunction(functionName);
  if (function == nullptr) {
    // Report here so callers that only see nullptr still leave a trace.
    ALOGE("Kernel fusion (module %s function %s): kernel function not found "
          "in merged module", source->getName().c_str(), functionName);
  }

  return function;
}

// The whitelist of supported signature bits. Context or user data arguments are
// not currently supported in kernel fusion. To support them or any new kinds of
// arguments in the future, it requires not only listing the signature bits here,
// but also implementing additional necessary fusion logic in the getFusedFuncSig(),
// getFusedFuncType(), and fuseKernels() functions below.
constexpr uint32_t ExpectedSignatureBits = bcinfo::MD_SIG_In | bcinfo::MD_SIG_Out | bcinfo::MD_SIG_X | bcinfo::MD_SIG_Y | bcinfo::MD_SIG_Z | bcinfo::MD_SIG_Kernel; int getFusedFuncSig(const std::vector<Source*>& sources, const std::vector<int>& slots, uint32_t* retSig) { *retSig = 0; uint32_t firstSignature = 0; uint32_t signature = 0; auto slotIter = slots.begin(); for (const Source* source : sources) { const int slot = *slotIter++; bcinfo::MetadataExtractor &metadata = *source->getMetadata(); if (metadata.getExportForEachInputCountList()[slot] > 1) { ALOGE("Kernel fusion (module %s slot %d): cannot handle multiple inputs", source->getName().c_str(), slot); return -1; } signature = metadata.getExportForEachSignatureList()[slot]; if (signature & ~ExpectedSignatureBits) { ALOGE("Kernel fusion (module %s slot %d): Unexpected signature %x", source->getName().c_str(), slot, signature); return -1; } if (firstSignature == 0) { firstSignature = signature; } *retSig |= signature; } if (!bcinfo::MetadataExtractor::hasForEachSignatureIn(firstSignature)) { *retSig &= ~bcinfo::MD_SIG_In; } if (!bcinfo::MetadataExtractor::hasForEachSignatureOut(signature)) { *retSig &= ~bcinfo::MD_SIG_Out; } return 0; } llvm::FunctionType* getFusedFuncType(bcc::BCCContext& Context, const std::vector<Source*>& sources, const std::vector<int>& slots, Module* M, uint32_t* signature) { int error = getFusedFuncSig(sources, slots, signature); if (error < 0) { return nullptr; } const Function* firstF = getFunction(M, sources.front(), slots.front(), nullptr); bccAssert (firstF != nullptr); llvm::SmallVector<llvm::Type*, 8> ArgTys; if (bcinfo::MetadataExtractor::hasForEachSignatureIn(*signature)) { ArgTys.push_back(firstF->arg_begin()->getType()); } llvm::Type* I32Ty = llvm::IntegerType::get(Context.getLLVMContext(), 32); if (bcinfo::MetadataExtractor::hasForEachSignatureX(*signature)) { ArgTys.push_back(I32Ty); } if (bcinfo::MetadataExtractor::hasForEachSignatureY(*signature)) { 
ArgTys.push_back(I32Ty); } if (bcinfo::MetadataExtractor::hasForEachSignatureZ(*signature)) { ArgTys.push_back(I32Ty); } const Function* lastF = getFunction(M, sources.back(), slots.back(), nullptr); bccAssert (lastF != nullptr); llvm::Type* retTy = lastF->getReturnType(); return llvm::FunctionType::get(retTy, ArgTys, false); } } // anonymous namespace bool fuseKernels(bcc::BCCContext& Context, const std::vector<Source *>& sources, const std::vector<int>& slots, const std::string& fusedName, Module* mergedModule) { bccAssert(sources.size() == slots.size() && "sources and slots differ in size"); uint32_t fusedFunctionSignature; llvm::FunctionType* fusedType = getFusedFuncType(Context, sources, slots, mergedModule, &fusedFunctionSignature); if (fusedType == nullptr) { return false; } Function* fusedKernel = (Function*)(mergedModule->getOrInsertFunction(fusedName, fusedType)); llvm::LLVMContext& ctxt = Context.getLLVMContext(); llvm::BasicBlock* block = llvm::BasicBlock::Create(ctxt, "entry", fusedKernel); llvm::IRBuilder<> builder(block); Function::arg_iterator argIter = fusedKernel->arg_begin(); llvm::Value* dataElement = nullptr; if (bcinfo::MetadataExtractor::hasForEachSignatureIn(fusedFunctionSignature)) { dataElement = &*(argIter++); dataElement->setName("DataIn"); } llvm::Value* X = nullptr; if (bcinfo::MetadataExtractor::hasForEachSignatureX(fusedFunctionSignature)) { X = &*(argIter++); X->setName("x"); } llvm::Value* Y = nullptr; if (bcinfo::MetadataExtractor::hasForEachSignatureY(fusedFunctionSignature)) { Y = &*(argIter++); Y->setName("y"); } llvm::Value* Z = nullptr; if (bcinfo::MetadataExtractor::hasForEachSignatureZ(fusedFunctionSignature)) { Z = &*(argIter++); Z->setName("z"); } auto slotIter = slots.begin(); for (const Source* source : sources) { int slot = *slotIter; uint32_t inputFunctionSignature; const Function* inputFunction = getFunction(mergedModule, source, slot, &inputFunctionSignature); if (inputFunction == nullptr) { // Either failed to find 
the kernel function, or the function has multiple inputs. return false; } // Don't try to fuse a non-kernel if (!bcinfo::MetadataExtractor::hasForEachSignatureKernel(inputFunctionSignature)) { ALOGE("Kernel fusion (module %s function %s): not a kernel", source->getName().c_str(), inputFunction->getName().str().c_str()); return false; } std::vector<llvm::Value*> args; if (bcinfo::MetadataExtractor::hasForEachSignatureIn(inputFunctionSignature)) { if (dataElement == nullptr) { ALOGE("Kernel fusion (module %s function %s): expected input, but got null", source->getName().c_str(), inputFunction->getName().str().c_str()); return false; } const llvm::FunctionType* funcTy = inputFunction->getFunctionType(); llvm::Type* firstArgType = funcTy->getParamType(0); if (dataElement->getType() != firstArgType) { std::string msg; llvm::raw_string_ostream rso(msg); rso << "Mismatching argument type, expected "; firstArgType->print(rso); rso << ", received "; dataElement->getType()->print(rso); ALOGE("Kernel fusion (module %s function %s): %s", source->getName().c_str(), inputFunction->getName().str().c_str(), rso.str().c_str()); return false; } args.push_back(dataElement); } else { // Only the first kernel in a batch is allowed to have no input if (slotIter != slots.begin()) { ALOGE("Kernel fusion (module %s function %s): function not first in batch takes no input", source->getName().c_str(), inputFunction->getName().str().c_str()); return false; } } if (bcinfo::MetadataExtractor::hasForEachSignatureX(inputFunctionSignature)) { args.push_back(X); } if (bcinfo::MetadataExtractor::hasForEachSignatureY(inputFunctionSignature)) { args.push_back(Y); } if (bcinfo::MetadataExtractor::hasForEachSignatureZ(inputFunctionSignature)) { args.push_back(Z); } dataElement = builder.CreateCall((llvm::Value*)inputFunction, args); slotIter++; } if (fusedKernel->getReturnType()->isVoidTy()) { builder.CreateRetVoid(); } else { builder.CreateRet(dataElement); } llvm::NamedMDNode* ExportForEachNameMD = 
mergedModule->getOrInsertNamedMetadata("#rs_export_foreach_name"); llvm::MDString* nameMDStr = llvm::MDString::get(ctxt, fusedName); llvm::MDNode* nameMDNode = llvm::MDNode::get(ctxt, nameMDStr); ExportForEachNameMD->addOperand(nameMDNode); llvm::NamedMDNode* ExportForEachMD = mergedModule->getOrInsertNamedMetadata("#rs_export_foreach"); llvm::MDString* sigMDStr = llvm::MDString::get(ctxt, llvm::utostr(fusedFunctionSignature)); llvm::MDNode* sigMDNode = llvm::MDNode::get(ctxt, sigMDStr); ExportForEachMD->addOperand(sigMDNode); return true; } bool renameInvoke(BCCContext& Context, const Source* source, const int slot, const std::string& newName, Module* module) { const llvm::Function* F = getInvokeFunction(*source, slot, module); std::vector<llvm::Type*> params; for (auto I = F->arg_begin(), E = F->arg_end(); I != E; ++I) { params.push_back(I->getType()); } llvm::Type* returnTy = F->getReturnType(); llvm::FunctionType* batchFuncTy = llvm::FunctionType::get(returnTy, params, false); llvm::Function* newF = llvm::Function::Create(batchFuncTy, llvm::GlobalValue::ExternalLinkage, newName, module); llvm::BasicBlock* block = llvm::BasicBlock::Create(Context.getLLVMContext(), "entry", newF); llvm::IRBuilder<> builder(block); llvm::Function::arg_iterator argIter = newF->arg_begin(); llvm::Value* arg1 = &*(argIter++); builder.CreateCall((llvm::Value*)F, arg1); builder.CreateRetVoid(); llvm::NamedMDNode* ExportFuncNameMD = module->getOrInsertNamedMetadata("#rs_export_func"); llvm::MDString* strMD = llvm::MDString::get(module->getContext(), newName); llvm::MDNode* nodeMD = llvm::MDNode::get(module->getContext(), strMD); ExportFuncNameMD->addOperand(nodeMD); return true; } } // namespace bcc