/*
 * Copyright 2010, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "slang_rs_export_func.h"

#include <cstdio>  // fprintf, stderr
#include <string>

#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"

#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"

#include "slang_assert.h"
#include "slang_rs_context.h"

namespace slang {

namespace {

// Checks that FD is acceptable as an exported (invokable) function.
//
// The only requirement enforced here is that FD returns void; otherwise an
// error is reported through Context at FD's location.
//
// Returns true if FD is valid, false if an error was reported.
bool ValidateFuncDecl(slang::RSContext *Context,
                      const clang::FunctionDecl *FD) {
  slangAssert(Context && FD);
  const clang::ASTContext &C = FD->getASTContext();
  if (FD->getReturnType().getCanonicalType() != C.VoidTy) {
    Context->ReportError(
        FD->getLocation(),
        "invokable non-static functions are required to return void");
    return false;
  }
  return true;
}

}  // namespace

// Creates the export descriptor for the invokable function FD.
//
// After validating FD (see ValidateFuncDecl), a new RSExportFunc is
// allocated. If FD takes parameters, a synthetic struct record (named via
// CreateDummyName("helper_func_param", <function name>)) is built with one
// field per parameter, using each parameter's identifier and original type.
// That record is exported through RSExportType::Create and stored as the
// function's parameter packet type.
//
// Default argument values are not supported: a note is printed to stderr and
// the default is otherwise ignored.
//
// Returns the new RSExportFunc, or nullptr if FD fails validation or the
// parameter packet type cannot be exported.
RSExportFunc *RSExportFunc::Create(RSContext *Context,
                                   const clang::FunctionDecl *FD) {
  llvm::StringRef Name = FD->getName();
  slangAssert(!Name.empty() && "Function must have a name");

  if (!ValidateFuncDecl(Context, FD)) {
    return nullptr;
  }

  RSExportFunc *F = new RSExportFunc(Context, Name, FD);

  // A function without parameters has no parameter packet.
  if (FD->getNumParams() == 0) {
    F->mParamPacketType = nullptr;
    return F;
  }

  clang::ASTContext &Ctx = Context->getASTContext();

  // Synthesize a struct whose fields mirror the function's parameters.
  std::string Id = CreateDummyName("helper_func_param", F->getName());
  clang::RecordDecl *RD =
      clang::RecordDecl::Create(Ctx, clang::TTK_Struct,
                                Ctx.getTranslationUnitDecl(),
                                clang::SourceLocation(),
                                clang::SourceLocation(),
                                &Ctx.Idents.get(Id));

  for (unsigned i = 0; i < FD->getNumParams(); ++i) {
    const clang::ParmVarDecl *PVD = FD->getParamDecl(i);

    if (PVD->hasDefaultArg()) {
      llvm::StringRef ParamName = PVD->getName();
      fprintf(stderr, "Note: parameter '%s' in function '%s' has default "
                      "value which is not supported\n",
              ParamName.str().c_str(), F->getName().c_str());
    }

    // Named 'Field' (not 'FD') so it does not shadow the function parameter.
    clang::FieldDecl *Field =
        clang::FieldDecl::Create(Ctx,
                                 RD,
                                 clang::SourceLocation(),
                                 clang::SourceLocation(),
                                 PVD->getIdentifier(),
                                 PVD->getOriginalType(),
                                 /* TInfo = */ nullptr,
                                 /* BitWidth = */ nullptr,
                                 /* Mutable = */ false,
                                 /* InitStyle = */ clang::ICIS_NoInit);
    RD->addDecl(Field);
  }

  RD->completeDefinition();

  clang::QualType T = Ctx.getTagDeclType(RD);
  slangAssert(!T.isNull());

  RSExportType *ET =
      RSExportType::Create(Context, T.getTypePtr(), NotLegacyKernelArgument);

  if (ET == nullptr) {
    fprintf(stderr, "Failed to export the function %s. There's at least one "
                    "parameter whose type is not supported by the "
                    "reflection\n", F->getName().c_str());
    // NOTE(review): F is not freed on this path. Presumably ownership of
    // exportables lies with the RSContext; confirm before adding a delete.
    return nullptr;
  }

  slangAssert((ET->getClass() == RSExportType::ExportClassRecord) &&
              "Parameter packet must be a record");

  F->mParamPacketType = static_cast<RSExportRecordType *>(ET);
  return F;
}

// Checks whether the LLVM struct type ParamTy is layout-compatible with this
// function's parameter packet type.
//
// - A null ParamTy matches only when the function has no parameters.
// - A non-null ParamTy never matches a parameterless function.
// - Otherwise both must have the same number of fields. For each field,
//   either the LLVM types are identical, or the field offsets and alloc
//   sizes (per the context's DataLayout) must agree.
//
// Returns true if the types are compatible.
bool RSExportFunc::checkParameterPacketType(llvm::StructType *ParamTy) const {
  if (ParamTy == nullptr)
    return !hasParam();
  if (!hasParam())
    return false;

  slangAssert(mParamPacketType != nullptr);

  const RSExportRecordType *ERT = mParamPacketType;

  // Must have the same number of elements.
  if (ERT->getFields().size() != ParamTy->getNumElements())
    return false;

  // The DataLayout does not change across iterations; look it up once.
  const llvm::DataLayout *DL = getRSContext()->getDataLayout();
  const llvm::StructLayout *ParamTySL = DL->getStructLayout(ParamTy);

  unsigned Index = 0;
  for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
                                                FE = ERT->fields_end();
       FI != FE; ++FI, ++Index) {
    const RSExportRecordType::Field *F = *FI;

    llvm::Type *T1 = F->getType()->getLLVMType();
    llvm::Type *T2 = ParamTy->getTypeAtIndex(Index);

    // Fast path: identical LLVM types are trivially compatible.
    if (T1 == T2)
      continue;

    // Offsets within the enclosing struct must agree.
    size_t T1Offset = F->getOffsetInParent();
    size_t T2Offset = ParamTySL->getElementOffset(Index);
    if (T1Offset != T2Offset)
      return false;

    // Allocation sizes must agree.
    size_t T1Size = F->getType()->getAllocSize();
    size_t T2Size = DL->getTypeAllocSize(T2);
    if (T1Size != T2Size)
      return false;
  }

  return true;
}

}  // namespace slang