/*
* Copyright 2010, The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "slang_rs_object_ref_count.h"
#include <list>
#include "clang/AST/DeclGroup.h"
#include "clang/AST/Expr.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/OperationKinds.h"
#include "clang/AST/Stmt.h"
#include "clang/AST/StmtVisitor.h"
#include "slang_assert.h"
#include "slang.h"
#include "slang_rs_ast_replace.h"
#include "slang_rs_export_type.h"
namespace slang {
/* Even though those two arrays are of size DataTypeMax, only entries that
* correspond to object types will be set.
*/
clang::FunctionDecl *
RSObjectRefCount::RSSetObjectFD[DataTypeMax];
clang::FunctionDecl *
RSObjectRefCount::RSClearObjectFD[DataTypeMax];
void RSObjectRefCount::GetRSRefCountingFunctions(clang::ASTContext &C) {
for (unsigned i = 0; i < DataTypeMax; i++) {
RSSetObjectFD[i] = nullptr;
RSClearObjectFD[i] = nullptr;
}
clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
E = TUDecl->decls_end(); I != E; I++) {
if ((I->getKind() >= clang::Decl::firstFunction) &&
(I->getKind() <= clang::Decl::lastFunction)) {
clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*I);
// points to RSSetObjectFD or RSClearObjectFD
clang::FunctionDecl **RSObjectFD;
if (FD->getName() == "rsSetObject") {
slangAssert((FD->getNumParams() == 2) &&
"Invalid rsSetObject function prototype (# params)");
RSObjectFD = RSSetObjectFD;
} else if (FD->getName() == "rsClearObject") {
slangAssert((FD->getNumParams() == 1) &&
"Invalid rsClearObject function prototype (# params)");
RSObjectFD = RSClearObjectFD;
} else {
continue;
}
const clang::ParmVarDecl *PVD = FD->getParamDecl(0);
clang::QualType PVT = PVD->getOriginalType();
// The first parameter must be a pointer like rs_allocation*
slangAssert(PVT->isPointerType() &&
"Invalid rs{Set,Clear}Object function prototype (pointer param)");
// The rs object type passed to the FD
clang::QualType RST = PVT->getPointeeType();
DataType DT = RSExportPrimitiveType::GetRSSpecificType(RST.getTypePtr());
slangAssert(RSExportPrimitiveType::IsRSObjectType(DT)
&& "must be RS object type");
if (DT >= 0 && DT < DataTypeMax) {
RSObjectFD[DT] = FD;
} else {
slangAssert(false && "incorrect type");
}
}
}
}
namespace {
unsigned CountRSObjectTypes(const clang::Type *T);
clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
clang::Expr *DstExpr,
clang::Expr *SrcExpr,
clang::SourceLocation StartLoc,
clang::SourceLocation Loc);
// This function constructs a new CompoundStmt from the input StmtList.
clang::CompoundStmt* BuildCompoundStmt(clang::ASTContext &C,
std::list<clang::Stmt*> &StmtList, clang::SourceLocation Loc) {
unsigned NewStmtCount = StmtList.size();
unsigned CompoundStmtCount = 0;
clang::Stmt **CompoundStmtList;
CompoundStmtList = new clang::Stmt*[NewStmtCount];
std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
std::list<clang::Stmt*>::const_iterator E = StmtList.end();
for ( ; I != E; I++) {
CompoundStmtList[CompoundStmtCount++] = *I;
}
slangAssert(CompoundStmtCount == NewStmtCount);
clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
C, llvm::makeArrayRef(CompoundStmtList, CompoundStmtCount), Loc, Loc);
delete [] CompoundStmtList;
return CS;
}
void AppendAfterStmt(clang::ASTContext &C,
clang::CompoundStmt *CS,
clang::Stmt *S,
std::list<clang::Stmt*> &StmtList) {
slangAssert(CS);
clang::CompoundStmt::body_iterator bI = CS->body_begin();
clang::CompoundStmt::body_iterator bE = CS->body_end();
clang::Stmt **UpdatedStmtList =
new clang::Stmt*[CS->size() + StmtList.size()];
unsigned UpdatedStmtCount = 0;
unsigned Once = 0;
for ( ; bI != bE; bI++) {
if (!S && ((*bI)->getStmtClass() == clang::Stmt::ReturnStmtClass)) {
// If we come across a return here, we don't have anything we can
// reasonably replace. We should have already inserted our destructor
// code in the proper spot, so we just clean up and return.
delete [] UpdatedStmtList;
return;
}
UpdatedStmtList[UpdatedStmtCount++] = *bI;
if ((*bI == S) && !Once) {
Once++;
std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
std::list<clang::Stmt*>::const_iterator E = StmtList.end();
for ( ; I != E; I++) {
UpdatedStmtList[UpdatedStmtCount++] = *I;
}
}
}
slangAssert(Once <= 1);
// When S is nullptr, we are appending to the end of the CompoundStmt.
if (!S) {
slangAssert(Once == 0);
std::list<clang::Stmt*>::const_iterator I = StmtList.begin();
std::list<clang::Stmt*>::const_iterator E = StmtList.end();
for ( ; I != E; I++) {
UpdatedStmtList[UpdatedStmtCount++] = *I;
}
}
CS->setStmts(C, llvm::makeArrayRef(UpdatedStmtList, UpdatedStmtCount));
delete [] UpdatedStmtList;
}
// This class visits a compound statement and collects a list of all the exiting
// statements, such as any return statement in any sub-block, and any
// break/continue statement that would resume outside the current scope.
// We do not handle the case for goto statements that leave a local scope.
class DestructorVisitor : public clang::StmtVisitor<DestructorVisitor> {
private:
// The loop depth of the currently visited node.
int mLoopDepth;
// The switch statement depth of the currently visited node.
// Note that this is tracked separately from the loop depth because
// SwitchStmt-contained ContinueStmt's should have destructors for the
// corresponding loop scope.
int mSwitchDepth;
// Output of the visitor: the statements that should be replaced by compound
// statements, each of which contains rsClearObject() calls followed by the
// original statement.
std::vector<clang::Stmt*> mExitingStmts;
public:
DestructorVisitor() : mLoopDepth(0), mSwitchDepth(0) {}
const std::vector<clang::Stmt*>& getExitingStmts() const {
return mExitingStmts;
}
void VisitStmt(clang::Stmt *S);
void VisitBreakStmt(clang::BreakStmt *BS);
void VisitContinueStmt(clang::ContinueStmt *CS);
void VisitDoStmt(clang::DoStmt *DS);
void VisitForStmt(clang::ForStmt *FS);
void VisitReturnStmt(clang::ReturnStmt *RS);
void VisitSwitchStmt(clang::SwitchStmt *SS);
void VisitWhileStmt(clang::WhileStmt *WS);
};
void DestructorVisitor::VisitStmt(clang::Stmt *S) {
for (clang::Stmt* Child : S->children()) {
if (Child) {
Visit(Child);
}
}
}
void DestructorVisitor::VisitBreakStmt(clang::BreakStmt *BS) {
VisitStmt(BS);
if ((mLoopDepth == 0) && (mSwitchDepth == 0)) {
mExitingStmts.push_back(BS);
}
}
void DestructorVisitor::VisitContinueStmt(clang::ContinueStmt *CS) {
VisitStmt(CS);
if (mLoopDepth == 0) {
// Switch statements can have nested continues.
mExitingStmts.push_back(CS);
}
}
void DestructorVisitor::VisitDoStmt(clang::DoStmt *DS) {
mLoopDepth++;
VisitStmt(DS);
mLoopDepth--;
}
void DestructorVisitor::VisitForStmt(clang::ForStmt *FS) {
mLoopDepth++;
VisitStmt(FS);
mLoopDepth--;
}
void DestructorVisitor::VisitReturnStmt(clang::ReturnStmt *RS) {
mExitingStmts.push_back(RS);
}
void DestructorVisitor::VisitSwitchStmt(clang::SwitchStmt *SS) {
mSwitchDepth++;
VisitStmt(SS);
mSwitchDepth--;
}
void DestructorVisitor::VisitWhileStmt(clang::WhileStmt *WS) {
mLoopDepth++;
VisitStmt(WS);
mLoopDepth--;
}
clang::Expr *ClearSingleRSObject(clang::ASTContext &C,
clang::Expr *RefRSVar,
clang::SourceLocation Loc) {
slangAssert(RefRSVar);
const clang::Type *T = RefRSVar->getType().getTypePtr();
slangAssert(!T->isArrayType() &&
"Should not be destroying arrays with this function");
clang::FunctionDecl *ClearObjectFD = RSObjectRefCount::GetRSClearObjectFD(T);
slangAssert((ClearObjectFD != nullptr) &&
"rsClearObject doesn't cover all RS object types");
clang::QualType ClearObjectFDType = ClearObjectFD->getType();
clang::QualType ClearObjectFDArgType =
ClearObjectFD->getParamDecl(0)->getOriginalType();
// Example destructor for "rs_font localFont;"
//
// (CallExpr 'void'
// (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
// (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
// (UnaryOperator 'rs_font *' prefix '&'
// (DeclRefExpr 'rs_font':'rs_font' Var='localFont')))
// Get address of targeted RS object
clang::Expr *AddrRefRSVar =
new(C) clang::UnaryOperator(RefRSVar,
clang::UO_AddrOf,
ClearObjectFDArgType,
clang::VK_RValue,
clang::OK_Ordinary,
Loc);
clang::Expr *RefRSClearObjectFD =
clang::DeclRefExpr::Create(C,
clang::NestedNameSpecifierLoc(),
clang::SourceLocation(),
ClearObjectFD,
false,
ClearObjectFD->getLocation(),
ClearObjectFDType,
clang::VK_RValue,
nullptr);
clang::Expr *RSClearObjectFP =
clang::ImplicitCastExpr::Create(C,
C.getPointerType(ClearObjectFDType),
clang::CK_FunctionToPointerDecay,
RefRSClearObjectFD,
nullptr,
clang::VK_RValue);
llvm::SmallVector<clang::Expr*, 1> ArgList;
ArgList.push_back(AddrRefRSVar);
clang::CallExpr *RSClearObjectCall =
new(C) clang::CallExpr(C,
RSClearObjectFP,
ArgList,
ClearObjectFD->getCallResultType(),
clang::VK_RValue,
Loc);
return RSClearObjectCall;
}
static int ArrayDim(const clang::Type *T) {
if (!T || !T->isArrayType()) {
return 0;
}
const clang::ConstantArrayType *CAT =
static_cast<const clang::ConstantArrayType *>(T);
return static_cast<int>(CAT->getSize().getSExtValue());
}
clang::Stmt *ClearStructRSObject(
clang::ASTContext &C,
clang::DeclContext *DC,
clang::Expr *RefRSStruct,
clang::SourceLocation StartLoc,
clang::SourceLocation Loc);
clang::Stmt *ClearArrayRSObject(
clang::ASTContext &C,
clang::DeclContext *DC,
clang::Expr *RefRSArr,
clang::SourceLocation StartLoc,
clang::SourceLocation Loc) {
const clang::Type *BaseType = RefRSArr->getType().getTypePtr();
slangAssert(BaseType->isArrayType());
int NumArrayElements = ArrayDim(BaseType);
// Actually extract out the base RS object type for use later
BaseType = BaseType->getArrayElementTypeNoTypeQual();
if (NumArrayElements <= 0) {
return nullptr;
}
// Example destructor loop for "rs_font fontArr[10];"
//
// (ForStmt
// (DeclStmt
// (VarDecl used rsIntIter 'int' cinit
// (IntegerLiteral 'int' 0)))
// (BinaryOperator 'int' '<'
// (ImplicitCastExpr int LValueToRValue
// (DeclRefExpr 'int' Var='rsIntIter'))
// (IntegerLiteral 'int' 10)
// nullptr << CondVar >>
// (UnaryOperator 'int' postfix '++'
// (DeclRefExpr 'int' Var='rsIntIter'))
// (CallExpr 'void'
// (ImplicitCastExpr 'void (*)(rs_font *)' <FunctionToPointerDecay>
// (DeclRefExpr 'void (rs_font *)' FunctionDecl='rsClearObject'))
// (UnaryOperator 'rs_font *' prefix '&'
// (ArraySubscriptExpr 'rs_font':'rs_font'
// (ImplicitCastExpr 'rs_font *' <ArrayToPointerDecay>
// (DeclRefExpr 'rs_font [10]' Var='fontArr'))
// (DeclRefExpr 'int' Var='rsIntIter'))))))
// Create helper variable for iterating through elements
static unsigned sIterCounter = 0;
std::stringstream UniqueIterName;
UniqueIterName << "rsIntIter" << sIterCounter++;
clang::IdentifierInfo *II = &C.Idents.get(UniqueIterName.str());
clang::VarDecl *IIVD =
clang::VarDecl::Create(C,
DC,
StartLoc,
Loc,
II,
C.IntTy,
C.getTrivialTypeSourceInfo(C.IntTy),
clang::SC_None);
// Mark "rsIntIter" as used
IIVD->markUsed(C);
// Form the actual destructor loop
// for (Init; Cond; Inc)
// RSClearObjectCall;
// Init -> "int rsIntIter = 0"
clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
IIVD->setInit(Int0);
clang::Decl *IID = (clang::Decl *)IIVD;
clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
clang::Stmt *Init = new(C) clang::DeclStmt(DGR, Loc, Loc);
// Cond -> "rsIntIter < NumArrayElements"
clang::DeclRefExpr *RefrsIntIterLValue =
clang::DeclRefExpr::Create(C,
clang::NestedNameSpecifierLoc(),
clang::SourceLocation(),
IIVD,
false,
Loc,
C.IntTy,
clang::VK_LValue,
nullptr);
clang::Expr *RefrsIntIterRValue =
clang::ImplicitCastExpr::Create(C,
RefrsIntIterLValue->getType(),
clang::CK_LValueToRValue,
RefrsIntIterLValue,
nullptr,
clang::VK_RValue);
clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
clang::BinaryOperator *Cond =
new(C) clang::BinaryOperator(RefrsIntIterRValue,
NumArrayElementsExpr,
clang::BO_LT,
C.IntTy,
clang::VK_RValue,
clang::OK_Ordinary,
Loc,
false);
// Inc -> "rsIntIter++"
clang::UnaryOperator *Inc =
new(C) clang::UnaryOperator(RefrsIntIterLValue,
clang::UO_PostInc,
C.IntTy,
clang::VK_RValue,
clang::OK_Ordinary,
Loc);
// Body -> "rsClearObject(&VD[rsIntIter]);"
// Destructor loop operates on individual array elements
clang::Expr *RefRSArrPtr =
clang::ImplicitCastExpr::Create(C,
C.getPointerType(BaseType->getCanonicalTypeInternal()),
clang::CK_ArrayToPointerDecay,
RefRSArr,
nullptr,
clang::VK_RValue);
clang::Expr *RefRSArrPtrSubscript =
new(C) clang::ArraySubscriptExpr(RefRSArrPtr,
RefrsIntIterRValue,
BaseType->getCanonicalTypeInternal(),
clang::VK_RValue,
clang::OK_Ordinary,
Loc);
DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
clang::Stmt *RSClearObjectCall = nullptr;
if (BaseType->isArrayType()) {
RSClearObjectCall =
ClearArrayRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
} else if (DT == DataTypeUnknown) {
RSClearObjectCall =
ClearStructRSObject(C, DC, RefRSArrPtrSubscript, StartLoc, Loc);
} else {
RSClearObjectCall = ClearSingleRSObject(C, RefRSArrPtrSubscript, Loc);
}
clang::ForStmt *DestructorLoop =
new(C) clang::ForStmt(C,
Init,
Cond,
nullptr, // no condVar
Inc,
RSClearObjectCall,
Loc,
Loc,
Loc);
return DestructorLoop;
}
unsigned CountRSObjectTypes(const clang::Type *T) {
slangAssert(T);
unsigned RSObjectCount = 0;
if (T->isArrayType()) {
return CountRSObjectTypes(T->getArrayElementTypeNoTypeQual());
}
DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
if (DT != DataTypeUnknown) {
return (RSExportPrimitiveType::IsRSObjectType(DT) ? 1 : 0);
}
if (T->isUnionType()) {
clang::RecordDecl *RD = T->getAsUnionType()->getDecl();
RD = RD->getDefinition();
for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
FE = RD->field_end();
FI != FE;
FI++) {
const clang::FieldDecl *FD = *FI;
const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
if (CountRSObjectTypes(FT)) {
slangAssert(false && "can't have unions with RS object types!");
return 0;
}
}
}
if (!T->isStructureType()) {
return 0;
}
clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
RD = RD->getDefinition();
for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
FE = RD->field_end();
FI != FE;
FI++) {
const clang::FieldDecl *FD = *FI;
const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
if (CountRSObjectTypes(FT)) {
// Sub-structs should only count once (as should arrays, etc.)
RSObjectCount++;
}
}
return RSObjectCount;
}
clang::Stmt *ClearStructRSObject(
clang::ASTContext &C,
clang::DeclContext *DC,
clang::Expr *RefRSStruct,
clang::SourceLocation StartLoc,
clang::SourceLocation Loc) {
const clang::Type *BaseType = RefRSStruct->getType().getTypePtr();
slangAssert(!BaseType->isArrayType());
// Structs should show up as unknown primitive types
slangAssert(RSExportPrimitiveType::GetRSSpecificType(BaseType) ==
DataTypeUnknown);
unsigned FieldsToDestroy = CountRSObjectTypes(BaseType);
slangAssert(FieldsToDestroy != 0);
unsigned StmtCount = 0;
clang::Stmt **StmtArray = new clang::Stmt*[FieldsToDestroy];
for (unsigned i = 0; i < FieldsToDestroy; i++) {
StmtArray[i] = nullptr;
}
// Populate StmtArray by creating a destructor for each RS object field
clang::RecordDecl *RD = BaseType->getAsStructureType()->getDecl();
RD = RD->getDefinition();
for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
FE = RD->field_end();
FI != FE;
FI++) {
// We just look through all field declarations to see if we find a
// declaration for an RS object type (or an array of one).
bool IsArrayType = false;
clang::FieldDecl *FD = *FI;
const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
const clang::Type *OrigType = FT;
while (FT && FT->isArrayType()) {
FT = FT->getArrayElementTypeNoTypeQual();
IsArrayType = true;
}
// Pass a DeclarationNameInfo with a valid DeclName, since name equality
// gets asserted during CodeGen.
clang::DeclarationNameInfo FDDeclNameInfo(FD->getDeclName(),
FD->getLocation());
if (RSExportPrimitiveType::IsRSObjectType(FT)) {
clang::DeclAccessPair FoundDecl =
clang::DeclAccessPair::make(FD, clang::AS_none);
clang::MemberExpr *RSObjectMember =
clang::MemberExpr::Create(C,
RefRSStruct,
false,
clang::SourceLocation(),
clang::NestedNameSpecifierLoc(),
clang::SourceLocation(),
FD,
FoundDecl,
FDDeclNameInfo,
nullptr,
OrigType->getCanonicalTypeInternal(),
clang::VK_RValue,
clang::OK_Ordinary);
slangAssert(StmtCount < FieldsToDestroy);
if (IsArrayType) {
StmtArray[StmtCount++] = ClearArrayRSObject(C,
DC,
RSObjectMember,
StartLoc,
Loc);
} else {
StmtArray[StmtCount++] = ClearSingleRSObject(C,
RSObjectMember,
Loc);
}
} else if (FT->isStructureType() && CountRSObjectTypes(FT)) {
// In this case, we have a nested struct. We may not end up filling all
// of the spaces in StmtArray (sub-structs should handle themselves
// with separate compound statements).
clang::DeclAccessPair FoundDecl =
clang::DeclAccessPair::make(FD, clang::AS_none);
clang::MemberExpr *RSObjectMember =
clang::MemberExpr::Create(C,
RefRSStruct,
false,
clang::SourceLocation(),
clang::NestedNameSpecifierLoc(),
clang::SourceLocation(),
FD,
FoundDecl,
clang::DeclarationNameInfo(),
nullptr,
OrigType->getCanonicalTypeInternal(),
clang::VK_RValue,
clang::OK_Ordinary);
if (IsArrayType) {
StmtArray[StmtCount++] = ClearArrayRSObject(C,
DC,
RSObjectMember,
StartLoc,
Loc);
} else {
StmtArray[StmtCount++] = ClearStructRSObject(C,
DC,
RSObjectMember,
StartLoc,
Loc);
}
}
}
slangAssert(StmtCount > 0);
clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
delete [] StmtArray;
return CS;
}
clang::Stmt *CreateSingleRSSetObject(clang::ASTContext &C,
clang::Expr *DstExpr,
clang::Expr *SrcExpr,
clang::SourceLocation StartLoc,
clang::SourceLocation Loc) {
const clang::Type *T = DstExpr->getType().getTypePtr();
clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(T);
slangAssert((SetObjectFD != nullptr) &&
"rsSetObject doesn't cover all RS object types");
clang::QualType SetObjectFDType = SetObjectFD->getType();
clang::QualType SetObjectFDArgType[2];
SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
clang::Expr *RefRSSetObjectFD =
clang::DeclRefExpr::Create(C,
clang::NestedNameSpecifierLoc(),
clang::SourceLocation(),
SetObjectFD,
false,
Loc,
SetObjectFDType,
clang::VK_RValue,
nullptr);
clang::Expr *RSSetObjectFP =
clang::ImplicitCastExpr::Create(C,
C.getPointerType(SetObjectFDType),
clang::CK_FunctionToPointerDecay,
RefRSSetObjectFD,
nullptr,
clang::VK_RValue);
llvm::SmallVector<clang::Expr*, 2> ArgList;
ArgList.push_back(new(C) clang::UnaryOperator(DstExpr,
clang::UO_AddrOf,
SetObjectFDArgType[0],
clang::VK_RValue,
clang::OK_Ordinary,
Loc));
ArgList.push_back(SrcExpr);
clang::CallExpr *RSSetObjectCall =
new(C) clang::CallExpr(C,
RSSetObjectFP,
ArgList,
SetObjectFD->getCallResultType(),
clang::VK_RValue,
Loc);
return RSSetObjectCall;
}
clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
clang::Expr *LHS,
clang::Expr *RHS,
clang::SourceLocation StartLoc,
clang::SourceLocation Loc);
/*static clang::Stmt *CreateArrayRSSetObject(clang::ASTContext &C,
clang::Expr *DstArr,
clang::Expr *SrcArr,
clang::SourceLocation StartLoc,
clang::SourceLocation Loc) {
clang::DeclContext *DC = nullptr;
const clang::Type *BaseType = DstArr->getType().getTypePtr();
slangAssert(BaseType->isArrayType());
int NumArrayElements = ArrayDim(BaseType);
// Actually extract out the base RS object type for use later
BaseType = BaseType->getArrayElementTypeNoTypeQual();
clang::Stmt *StmtArray[2] = {nullptr};
int StmtCtr = 0;
if (NumArrayElements <= 0) {
return nullptr;
}
// Create helper variable for iterating through elements
clang::IdentifierInfo& II = C.Idents.get("rsIntIter");
clang::VarDecl *IIVD =
clang::VarDecl::Create(C,
DC,
StartLoc,
Loc,
&II,
C.IntTy,
C.getTrivialTypeSourceInfo(C.IntTy),
clang::SC_None,
clang::SC_None);
clang::Decl *IID = (clang::Decl *)IIVD;
clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(C, &IID, 1);
StmtArray[StmtCtr++] = new(C) clang::DeclStmt(DGR, Loc, Loc);
// Form the actual loop
// for (Init; Cond; Inc)
// RSSetObjectCall;
// Init -> "rsIntIter = 0"
clang::DeclRefExpr *RefrsIntIter =
clang::DeclRefExpr::Create(C,
clang::NestedNameSpecifierLoc(),
IIVD,
Loc,
C.IntTy,
clang::VK_RValue,
nullptr);
clang::Expr *Int0 = clang::IntegerLiteral::Create(C,
llvm::APInt(C.getTypeSize(C.IntTy), 0), C.IntTy, Loc);
clang::BinaryOperator *Init =
new(C) clang::BinaryOperator(RefrsIntIter,
Int0,
clang::BO_Assign,
C.IntTy,
clang::VK_RValue,
clang::OK_Ordinary,
Loc);
// Cond -> "rsIntIter < NumArrayElements"
clang::Expr *NumArrayElementsExpr = clang::IntegerLiteral::Create(C,
llvm::APInt(C.getTypeSize(C.IntTy), NumArrayElements), C.IntTy, Loc);
clang::BinaryOperator *Cond =
new(C) clang::BinaryOperator(RefrsIntIter,
NumArrayElementsExpr,
clang::BO_LT,
C.IntTy,
clang::VK_RValue,
clang::OK_Ordinary,
Loc);
// Inc -> "rsIntIter++"
clang::UnaryOperator *Inc =
new(C) clang::UnaryOperator(RefrsIntIter,
clang::UO_PostInc,
C.IntTy,
clang::VK_RValue,
clang::OK_Ordinary,
Loc);
// Body -> "rsSetObject(&Dst[rsIntIter], Src[rsIntIter]);"
// Loop operates on individual array elements
clang::Expr *DstArrPtr =
clang::ImplicitCastExpr::Create(C,
C.getPointerType(BaseType->getCanonicalTypeInternal()),
clang::CK_ArrayToPointerDecay,
DstArr,
nullptr,
clang::VK_RValue);
clang::Expr *DstArrPtrSubscript =
new(C) clang::ArraySubscriptExpr(DstArrPtr,
RefrsIntIter,
BaseType->getCanonicalTypeInternal(),
clang::VK_RValue,
clang::OK_Ordinary,
Loc);
clang::Expr *SrcArrPtr =
clang::ImplicitCastExpr::Create(C,
C.getPointerType(BaseType->getCanonicalTypeInternal()),
clang::CK_ArrayToPointerDecay,
SrcArr,
nullptr,
clang::VK_RValue);
clang::Expr *SrcArrPtrSubscript =
new(C) clang::ArraySubscriptExpr(SrcArrPtr,
RefrsIntIter,
BaseType->getCanonicalTypeInternal(),
clang::VK_RValue,
clang::OK_Ordinary,
Loc);
DataType DT = RSExportPrimitiveType::GetRSSpecificType(BaseType);
clang::Stmt *RSSetObjectCall = nullptr;
if (BaseType->isArrayType()) {
RSSetObjectCall = CreateArrayRSSetObject(C, DstArrPtrSubscript,
SrcArrPtrSubscript,
StartLoc, Loc);
} else if (DT == DataTypeUnknown) {
RSSetObjectCall = CreateStructRSSetObject(C, DstArrPtrSubscript,
SrcArrPtrSubscript,
StartLoc, Loc);
} else {
RSSetObjectCall = CreateSingleRSSetObject(C, DstArrPtrSubscript,
SrcArrPtrSubscript,
StartLoc, Loc);
}
clang::ForStmt *DestructorLoop =
new(C) clang::ForStmt(C,
Init,
Cond,
nullptr, // no condVar
Inc,
RSSetObjectCall,
Loc,
Loc,
Loc);
StmtArray[StmtCtr++] = DestructorLoop;
slangAssert(StmtCtr == 2);
clang::CompoundStmt *CS =
new(C) clang::CompoundStmt(C, StmtArray, StmtCtr, Loc, Loc);
return CS;
} */
clang::Stmt *CreateStructRSSetObject(clang::ASTContext &C,
clang::Expr *LHS,
clang::Expr *RHS,
clang::SourceLocation StartLoc,
clang::SourceLocation Loc) {
clang::QualType QT = LHS->getType();
const clang::Type *T = QT.getTypePtr();
slangAssert(T->isStructureType());
slangAssert(!RSExportPrimitiveType::IsRSObjectType(T));
// Keep an extra slot for the original copy (memcpy)
unsigned FieldsToSet = CountRSObjectTypes(T) + 1;
unsigned StmtCount = 0;
clang::Stmt **StmtArray = new clang::Stmt*[FieldsToSet];
for (unsigned i = 0; i < FieldsToSet; i++) {
StmtArray[i] = nullptr;
}
clang::RecordDecl *RD = T->getAsStructureType()->getDecl();
RD = RD->getDefinition();
for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
FE = RD->field_end();
FI != FE;
FI++) {
bool IsArrayType = false;
clang::FieldDecl *FD = *FI;
const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
const clang::Type *OrigType = FT;
if (!CountRSObjectTypes(FT)) {
// Skip to next if we don't have any viable RS object types
continue;
}
clang::DeclAccessPair FoundDecl =
clang::DeclAccessPair::make(FD, clang::AS_none);
clang::MemberExpr *DstMember =
clang::MemberExpr::Create(C,
LHS,
false,
clang::SourceLocation(),
clang::NestedNameSpecifierLoc(),
clang::SourceLocation(),
FD,
FoundDecl,
clang::DeclarationNameInfo(),
nullptr,
OrigType->getCanonicalTypeInternal(),
clang::VK_RValue,
clang::OK_Ordinary);
clang::MemberExpr *SrcMember =
clang::MemberExpr::Create(C,
RHS,
false,
clang::SourceLocation(),
clang::NestedNameSpecifierLoc(),
clang::SourceLocation(),
FD,
FoundDecl,
clang::DeclarationNameInfo(),
nullptr,
OrigType->getCanonicalTypeInternal(),
clang::VK_RValue,
clang::OK_Ordinary);
if (FT->isArrayType()) {
FT = FT->getArrayElementTypeNoTypeQual();
IsArrayType = true;
}
DataType DT = RSExportPrimitiveType::GetRSSpecificType(FT);
if (IsArrayType) {
clang::DiagnosticsEngine &DiagEngine = C.getDiagnostics();
DiagEngine.Report(
clang::FullSourceLoc(Loc, C.getSourceManager()),
DiagEngine.getCustomDiagID(
clang::DiagnosticsEngine::Error,
"Arrays of RS object types within structures cannot be copied"));
// TODO(srhines): Support setting arrays of RS objects
// StmtArray[StmtCount++] =
// CreateArrayRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
} else if (DT == DataTypeUnknown) {
StmtArray[StmtCount++] =
CreateStructRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
} else if (RSExportPrimitiveType::IsRSObjectType(DT)) {
StmtArray[StmtCount++] =
CreateSingleRSSetObject(C, DstMember, SrcMember, StartLoc, Loc);
} else {
slangAssert(false);
}
}
slangAssert(StmtCount < FieldsToSet);
// We still need to actually do the overall struct copy. For simplicity,
// we just do a straight-up assignment (which will still preserve all
// the proper RS object reference counts).
clang::BinaryOperator *CopyStruct =
new(C) clang::BinaryOperator(LHS, RHS, clang::BO_Assign, QT,
clang::VK_RValue, clang::OK_Ordinary, Loc,
false);
StmtArray[StmtCount++] = CopyStruct;
clang::CompoundStmt *CS = new(C) clang::CompoundStmt(
C, llvm::makeArrayRef(StmtArray, StmtCount), Loc, Loc);
delete [] StmtArray;
return CS;
}
} // namespace
void RSObjectRefCount::Scope::InsertStmt(const clang::ASTContext &C,
clang::Stmt *NewStmt) {
std::vector<clang::Stmt*> newBody;
for (clang::Stmt* S1 : mCS->body()) {
if (S1 == mCurrent) {
newBody.push_back(NewStmt);
}
newBody.push_back(S1);
}
mCS->setStmts(C, newBody);
}
void RSObjectRefCount::Scope::ReplaceStmt(const clang::ASTContext &C,
clang::Stmt *NewStmt) {
std::vector<clang::Stmt*> newBody;
for (clang::Stmt* S1 : mCS->body()) {
if (S1 == mCurrent) {
newBody.push_back(NewStmt);
} else {
newBody.push_back(S1);
}
}
mCS->setStmts(C, newBody);
}
void RSObjectRefCount::Scope::ReplaceExpr(const clang::ASTContext& C,
clang::Expr* OldExpr,
clang::Expr* NewExpr) {
RSASTReplace R(C);
R.ReplaceStmt(mCurrent, OldExpr, NewExpr);
}
void RSObjectRefCount::Scope::ReplaceRSObjectAssignment(
clang::BinaryOperator *AS) {
clang::QualType QT = AS->getType();
clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
DataTypeRSAllocation)->getASTContext();
clang::SourceLocation Loc = AS->getExprLoc();
clang::SourceLocation StartLoc = AS->getLHS()->getExprLoc();
clang::Stmt *UpdatedStmt = nullptr;
if (!RSExportPrimitiveType::IsRSObjectType(QT.getTypePtr())) {
// By definition, this is a struct assignment if we get here
UpdatedStmt =
CreateStructRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
} else {
UpdatedStmt =
CreateSingleRSSetObject(C, AS->getLHS(), AS->getRHS(), StartLoc, Loc);
}
RSASTReplace R(C);
R.ReplaceStmt(mCS, AS, UpdatedStmt);
}
void RSObjectRefCount::Scope::AppendRSObjectInit(
clang::VarDecl *VD,
clang::DeclStmt *DS,
DataType DT,
clang::Expr *InitExpr) {
slangAssert(VD);
if (!InitExpr) {
return;
}
clang::ASTContext &C = RSObjectRefCount::GetRSSetObjectFD(
DataTypeRSAllocation)->getASTContext();
clang::SourceLocation Loc = RSObjectRefCount::GetRSSetObjectFD(
DataTypeRSAllocation)->getLocation();
clang::SourceLocation StartLoc = RSObjectRefCount::GetRSSetObjectFD(
DataTypeRSAllocation)->getInnerLocStart();
if (DT == DataTypeIsStruct) {
const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
clang::DeclRefExpr *RefRSVar =
clang::DeclRefExpr::Create(C,
clang::NestedNameSpecifierLoc(),
clang::SourceLocation(),
VD,
false,
Loc,
T->getCanonicalTypeInternal(),
clang::VK_RValue,
nullptr);
clang::Stmt *RSSetObjectOps =
CreateStructRSSetObject(C, RefRSVar, InitExpr, StartLoc, Loc);
std::list<clang::Stmt*> StmtList;
StmtList.push_back(RSSetObjectOps);
AppendAfterStmt(C, mCS, DS, StmtList);
return;
}
clang::FunctionDecl *SetObjectFD = RSObjectRefCount::GetRSSetObjectFD(DT);
slangAssert((SetObjectFD != nullptr) &&
"rsSetObject doesn't cover all RS object types");
clang::QualType SetObjectFDType = SetObjectFD->getType();
clang::QualType SetObjectFDArgType[2];
SetObjectFDArgType[0] = SetObjectFD->getParamDecl(0)->getOriginalType();
SetObjectFDArgType[1] = SetObjectFD->getParamDecl(1)->getOriginalType();
clang::Expr *RefRSSetObjectFD =
clang::DeclRefExpr::Create(C,
clang::NestedNameSpecifierLoc(),
clang::SourceLocation(),
SetObjectFD,
false,
Loc,
SetObjectFDType,
clang::VK_RValue,
nullptr);
clang::Expr *RSSetObjectFP =
clang::ImplicitCastExpr::Create(C,
C.getPointerType(SetObjectFDType),
clang::CK_FunctionToPointerDecay,
RefRSSetObjectFD,
nullptr,
clang::VK_RValue);
const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
clang::DeclRefExpr *RefRSVar =
clang::DeclRefExpr::Create(C,
clang::NestedNameSpecifierLoc(),
clang::SourceLocation(),
VD,
false,
Loc,
T->getCanonicalTypeInternal(),
clang::VK_RValue,
nullptr);
llvm::SmallVector<clang::Expr*, 2> ArgList;
ArgList.push_back(new(C) clang::UnaryOperator(RefRSVar,
clang::UO_AddrOf,
SetObjectFDArgType[0],
clang::VK_RValue,
clang::OK_Ordinary,
Loc));
ArgList.push_back(InitExpr);
clang::CallExpr *RSSetObjectCall =
new(C) clang::CallExpr(C,
RSSetObjectFP,
ArgList,
SetObjectFD->getCallResultType(),
clang::VK_RValue,
Loc);
std::list<clang::Stmt*> StmtList;
StmtList.push_back(RSSetObjectCall);
AppendAfterStmt(C, mCS, DS, StmtList);
}
void RSObjectRefCount::Scope::InsertLocalVarDestructors() {
if (mRSO.empty()) {
return;
}
clang::DeclContext* DC = mRSO.front()->getDeclContext();
clang::ASTContext& C = DC->getParentASTContext();
clang::SourceManager& SM = C.getSourceManager();
const auto& OccursBefore = [&SM] (clang::SourceLocation L1, clang::SourceLocation L2)->bool {
return SM.isBeforeInTranslationUnit(L1, L2);
};
typedef std::map<clang::SourceLocation, clang::Stmt*, decltype(OccursBefore)> DMap;
DMap dtors(OccursBefore);
// Create rsClearObject calls. Note the DMap entries are sorted by the SourceLocation.
for (clang::VarDecl* VD : mRSO) {
clang::SourceLocation Loc = VD->getSourceRange().getBegin();
clang::Stmt* RSClearObjectCall = ClearRSObject(VD, DC);
dtors.insert(std::make_pair(Loc, RSClearObjectCall));
}
DestructorVisitor Visitor;
Visitor.Visit(mCS);
// Replace each exiting statement with a block that contains the original statement
// and added rsClearObject() calls before it.
for (clang::Stmt* S : Visitor.getExitingStmts()) {
const clang::SourceLocation currentLoc = S->getLocStart();
DMap::iterator firstDtorIter = dtors.begin();
DMap::iterator currentDtorIter = firstDtorIter;
DMap::iterator lastDtorIter = dtors.end();
while (currentDtorIter != lastDtorIter &&
OccursBefore(currentDtorIter->first, currentLoc)) {
currentDtorIter++;
}
if (currentDtorIter == firstDtorIter) {
continue;
}
std::list<clang::Stmt*> Stmts;
// Insert rsClearObject() calls for all rsObjects declared before the current statement
for(DMap::iterator it = firstDtorIter; it != currentDtorIter; it++) {
Stmts.push_back(it->second);
}
Stmts.push_back(S);
RSASTReplace R(C);
clang::CompoundStmt* CS = BuildCompoundStmt(C, Stmts, S->getLocEnd());
R.ReplaceStmt(mCS, S, CS);
}
std::list<clang::Stmt*> Stmts;
for(auto LocCallPair : dtors) {
Stmts.push_back(LocCallPair.second);
}
AppendAfterStmt(C, mCS, nullptr, Stmts);
}
clang::Stmt *RSObjectRefCount::Scope::ClearRSObject(
clang::VarDecl *VD,
clang::DeclContext *DC) {
slangAssert(VD);
clang::ASTContext &C = VD->getASTContext();
clang::SourceLocation Loc = VD->getLocation();
clang::SourceLocation StartLoc = VD->getInnerLocStart();
const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
// Reference expr to target RS object variable
clang::DeclRefExpr *RefRSVar =
clang::DeclRefExpr::Create(C,
clang::NestedNameSpecifierLoc(),
clang::SourceLocation(),
VD,
false,
Loc,
T->getCanonicalTypeInternal(),
clang::VK_RValue,
nullptr);
if (T->isArrayType()) {
return ClearArrayRSObject(C, DC, RefRSVar, StartLoc, Loc);
}
DataType DT = RSExportPrimitiveType::GetRSSpecificType(T);
if (DT == DataTypeUnknown ||
DT == DataTypeIsStruct) {
return ClearStructRSObject(C, DC, RefRSVar, StartLoc, Loc);
}
slangAssert((RSExportPrimitiveType::IsRSObjectType(DT)) &&
"Should be RS object");
return ClearSingleRSObject(C, RefRSVar, Loc);
}
bool RSObjectRefCount::InitializeRSObject(clang::VarDecl *VD,
DataType *DT,
clang::Expr **InitExpr) {
slangAssert(VD && DT && InitExpr);
const clang::Type *T = RSExportType::GetTypeOfDecl(VD);
// Loop through array types to get to base type
while (T && T->isArrayType()) {
T = T->getArrayElementTypeNoTypeQual();
}
bool DataTypeIsStructWithRSObject = false;
*DT = RSExportPrimitiveType::GetRSSpecificType(T);
if (*DT == DataTypeUnknown) {
if (RSExportPrimitiveType::IsStructureTypeWithRSObject(T)) {
*DT = DataTypeIsStruct;
DataTypeIsStructWithRSObject = true;
} else {
return false;
}
}
bool DataTypeIsRSObject = false;
if (DataTypeIsStructWithRSObject) {
DataTypeIsRSObject = true;
} else {
DataTypeIsRSObject = RSExportPrimitiveType::IsRSObjectType(*DT);
}
*InitExpr = VD->getInit();
if (!DataTypeIsRSObject && *InitExpr) {
// If we already have an initializer for a matrix type, we are done.
return DataTypeIsRSObject;
}
clang::Expr *ZeroInitializer =
CreateEmptyInitListExpr(VD->getASTContext(), VD->getLocation());
if (ZeroInitializer) {
ZeroInitializer->setType(T->getCanonicalTypeInternal());
VD->setInit(ZeroInitializer);
}
return DataTypeIsRSObject;
}
clang::Expr *RSObjectRefCount::CreateEmptyInitListExpr(
clang::ASTContext &C,
const clang::SourceLocation &Loc) {
// We can cheaply construct a zero initializer by just creating an empty
// initializer list. Clang supports this extension to C(99), and will create
// any necessary constructs to zero out the entire variable.
llvm::SmallVector<clang::Expr*, 1> EmptyInitList;
return new(C) clang::InitListExpr(C, Loc, EmptyInitList, Loc);
}
clang::CompoundStmt* RSObjectRefCount::CreateRetStmtWithTempVar(
clang::ASTContext& C,
clang::DeclContext* DC,
clang::ReturnStmt* RS,
const unsigned id) {
std::list<clang::Stmt*> NewStmts;
// Since we insert rsClearObj() calls before the return statement, we need
// to make sure none of the cleared RS objects are referenced in the
// return statement.
// For that, we create a new local variable named .rs.retval, assign the
// original return expression to it, make all necessary rsClearObj()
// calls, then return .rs.retval. Note rsClearObj() is not called on
// .rs.retval.
clang::SourceLocation Loc = RS->getLocStart();
std::stringstream ss;
ss << ".rs.retval" << id;
llvm::StringRef VarName(ss.str());
clang::Expr* RetVal = RS->getRetValue();
const clang::QualType RetTy = RetVal->getType();
clang::VarDecl* RSRetValDecl = clang::VarDecl::Create(
C, // AST context
DC, // Decl context
Loc, // Start location
Loc, // Id location
&C.Idents.get(VarName), // Id
RetTy, // Type
C.getTrivialTypeSourceInfo(RetTy), // Type info
clang::SC_None // Storage class
);
const clang::Type *T = RetTy.getTypePtr();
clang::Expr *ZeroInitializer =
RSObjectRefCount::CreateEmptyInitListExpr(C, Loc);
ZeroInitializer->setType(T->getCanonicalTypeInternal());
RSRetValDecl->setInit(ZeroInitializer);
clang::Decl* Decls[] = { RSRetValDecl };
const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
C, Decls, sizeof(Decls) / sizeof(*Decls));
clang::DeclStmt* DS = new (C) clang::DeclStmt(DGR, Loc, Loc);
NewStmts.push_back(DS);
clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
C,
clang::NestedNameSpecifierLoc(), // QualifierLoc
Loc, // TemplateKWLoc
RSRetValDecl,
false, // RefersToEnclosingVariableOrCapture
Loc, // NameLoc
RetTy,
clang::VK_LValue
);
clang::Stmt* SetRetTempVar = CreateSingleRSSetObject(C, DRE, RetVal, Loc, Loc);
NewStmts.push_back(SetRetTempVar);
// Creates a new return statement
clang::ReturnStmt* NewRet = new (C) clang::ReturnStmt(RS->getReturnLoc());
clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
C,
RetTy,
clang::CK_LValueToRValue,
DRE,
nullptr,
clang::VK_RValue
);
NewRet->setRetValue(CastExpr);
NewStmts.push_back(NewRet);
return BuildCompoundStmt(C, NewStmts, Loc);
}
void RSObjectRefCount::VisitDeclStmt(clang::DeclStmt *DS) {
VisitStmt(DS);
getCurrentScope()->setCurrentStmt(DS);
for (clang::DeclStmt::decl_iterator I = DS->decl_begin(), E = DS->decl_end();
I != E;
I++) {
clang::Decl *D = *I;
if (D->getKind() == clang::Decl::Var) {
clang::VarDecl *VD = static_cast<clang::VarDecl*>(D);
DataType DT = DataTypeUnknown;
clang::Expr *InitExpr = nullptr;
if (InitializeRSObject(VD, &DT, &InitExpr)) {
// We need to zero-init all RS object types (including matrices), ...
getCurrentScope()->AppendRSObjectInit(VD, DS, DT, InitExpr);
// ... but, only add to the list of RS objects if we have some
// non-matrix RS object fields.
if (CountRSObjectTypes(VD->getType().getTypePtr())) {
getCurrentScope()->addRSObject(VD);
}
}
}
}
}
void RSObjectRefCount::VisitCallExpr(clang::CallExpr* CE) {
clang::QualType RetTy;
const clang::FunctionDecl* FD = CE->getDirectCallee();
if (FD) {
// Direct calls
RetTy = FD->getReturnType();
} else {
// Indirect calls through function pointers
const clang::Expr* Callee = CE->getCallee();
const clang::Type* CalleeType = Callee->getType().getTypePtr();
const clang::PointerType* PtrType = CalleeType->getAs<clang::PointerType>();
if (!PtrType) {
return;
}
const clang::Type* PointeeType = PtrType->getPointeeType().getTypePtr();
const clang::FunctionType* FuncType = PointeeType->getAs<clang::FunctionType>();
if (!FuncType) {
return;
}
RetTy = FuncType->getReturnType();
}
if (CountRSObjectTypes(RetTy.getTypePtr())==0) {
return;
}
clang::SourceLocation Loc = CE->getSourceRange().getBegin();
std::stringstream ss;
ss << ".rs.tmp" << getNextID();
llvm::StringRef VarName(ss.str());
clang::VarDecl* TempVarDecl = clang::VarDecl::Create(
mCtx, // AST context
GetDeclContext(), // Decl context
Loc, // Start location
Loc, // Id location
&mCtx.Idents.get(VarName), // Id
RetTy, // Type
mCtx.getTrivialTypeSourceInfo(RetTy), // Type info
clang::SC_None // Storage class
);
TempVarDecl->setInit(CE);
clang::Decl* Decls[] = { TempVarDecl };
const clang::DeclGroupRef DGR = clang::DeclGroupRef::Create(
mCtx, Decls, sizeof(Decls) / sizeof(*Decls));
clang::DeclStmt* DS = new (mCtx) clang::DeclStmt(DGR, Loc, Loc);
getCurrentScope()->InsertStmt(mCtx, DS);
clang::DeclRefExpr* DRE = clang::DeclRefExpr::Create(
mCtx, // AST context
clang::NestedNameSpecifierLoc(), // QualifierLoc
Loc, // TemplateKWLoc
TempVarDecl,
false, // RefersToEnclosingVariableOrCapture
Loc, // NameLoc
RetTy,
clang::VK_LValue
);
clang::Expr* CastExpr = clang::ImplicitCastExpr::Create(
mCtx,
RetTy,
clang::CK_LValueToRValue,
DRE,
nullptr,
clang::VK_RValue
);
getCurrentScope()->ReplaceExpr(mCtx, CE, CastExpr);
// Register TempVarDecl for destruction call (rsClearObj).
getCurrentScope()->addRSObject(TempVarDecl);
}
void RSObjectRefCount::VisitCompoundStmt(clang::CompoundStmt *CS) {
if (!emptyScope()) {
getCurrentScope()->setCurrentStmt(CS);
}
if (!CS->body_empty()) {
// Push a new scope
Scope *S = new Scope(CS);
mScopeStack.push_back(S);
VisitStmt(CS);
// Destroy the scope
slangAssert((getCurrentScope() == S) && "Corrupted scope stack!");
S->InsertLocalVarDestructors();
mScopeStack.pop_back();
delete S;
}
}
void RSObjectRefCount::VisitBinAssign(clang::BinaryOperator *AS) {
getCurrentScope()->setCurrentStmt(AS);
clang::QualType QT = AS->getType();
if (CountRSObjectTypes(QT.getTypePtr())) {
getCurrentScope()->ReplaceRSObjectAssignment(AS);
}
}
void RSObjectRefCount::VisitReturnStmt(clang::ReturnStmt *RS) {
getCurrentScope()->setCurrentStmt(RS);
// If there is no local rsObject declared so far, no need to transform the
// return statement.
bool RSObjDeclared = false;
for (const Scope* S : mScopeStack) {
if (S->hasRSObject()) {
RSObjDeclared = true;
break;
}
}
if (!RSObjDeclared) {
return;
}
// If the return statement does not return anything, or if it does not return
// a rsObject, no need to transform it.
clang::Expr* RetVal = RS->getRetValue();
if (!RetVal || CountRSObjectTypes(RetVal->getType().getTypePtr()) == 0) {
return;
}
// Transform the return statement so that it does not potentially return or
// reference a rsObject that has been cleared.
clang::CompoundStmt* NewRS;
NewRS = CreateRetStmtWithTempVar(mCtx, GetDeclContext(), RS, getNextID());
getCurrentScope()->ReplaceStmt(mCtx, NewRS);
}
void RSObjectRefCount::VisitStmt(clang::Stmt *S) {
getCurrentScope()->setCurrentStmt(S);
for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
I != E;
I++) {
if (clang::Stmt *Child = *I) {
Visit(Child);
}
}
}
// This function walks the list of global variables and (potentially) creates
// a single global static destructor function that properly decrements
// reference counts on the contained RS object types.
clang::FunctionDecl *RSObjectRefCount::CreateStaticGlobalDtor() {
Init();
clang::DeclContext *DC = mCtx.getTranslationUnitDecl();
clang::SourceLocation loc;
llvm::StringRef SR(".rs.dtor");
clang::IdentifierInfo &II = mCtx.Idents.get(SR);
clang::DeclarationName N(&II);
clang::FunctionProtoType::ExtProtoInfo EPI;
clang::QualType T = mCtx.getFunctionType(mCtx.VoidTy,
llvm::ArrayRef<clang::QualType>(), EPI);
clang::FunctionDecl *FD = nullptr;
// Generate rsClearObject() call chains for every global variable
// (whether static or extern).
std::list<clang::Stmt *> StmtList;
for (clang::DeclContext::decl_iterator I = DC->decls_begin(),
E = DC->decls_end(); I != E; I++) {
clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*I);
if (VD) {
if (CountRSObjectTypes(VD->getType().getTypePtr())) {
if (!FD) {
// Only create FD if we are going to use it.
FD = clang::FunctionDecl::Create(mCtx, DC, loc, loc, N, T, nullptr,
clang::SC_None);
}
// Mark VD as used. It might be unused, except for the destructor.
// 'markUsed' has side-effects that are caused only if VD is not already
// used. Hence no need for an extra check here.
VD->markUsed(mCtx);
// Make sure to create any helpers within the function's DeclContext,
// not the one associated with the global translation unit.
clang::Stmt *RSClearObjectCall = Scope::ClearRSObject(VD, FD);
StmtList.push_back(RSClearObjectCall);
}
}
}
// Nothing needs to be destroyed, so don't emit a dtor.
if (StmtList.empty()) {
return nullptr;
}
clang::CompoundStmt *CS = BuildCompoundStmt(mCtx, StmtList, loc);
FD->setBody(CS);
return FD;
}
bool HasRSObjectType(const clang::Type *T) {
return CountRSObjectTypes(T) != 0;
}
} // namespace slang