// Copyright 2016 The SwiftShader Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "ValidateLimitations.h" #include "InfoSink.h" #include "InitializeParseContext.h" #include "ParseHelper.h" namespace { bool IsLoopIndex(const TIntermSymbol* symbol, const TLoopStack& stack) { for (TLoopStack::const_iterator i = stack.begin(); i != stack.end(); ++i) { if (i->index.id == symbol->getId()) return true; } return false; } void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) { for (TLoopStack::iterator i = stack.begin(); i != stack.end(); ++i) { if (i->index.id == symbol->getId()) { ASSERT(i->loop); i->loop->setUnrollFlag(true); return; } } UNREACHABLE(0); } // Traverses a node to check if it represents a constant index expression. // Definition: // constant-index-expressions are a superset of constant-expressions. // Constant-index-expressions can include loop indices as defined in // GLSL ES 1.0 spec, Appendix A, section 4. // The following are constant-index-expressions: // - Constant expressions // - Loop indices as defined in section 4 // - Expressions composed of both of the above class ValidateConstIndexExpr : public TIntermTraverser { public: ValidateConstIndexExpr(const TLoopStack& stack) : mValid(true), mLoopStack(stack) {} // Returns true if the parsed node represents a constant index expression. bool isValid() const { return mValid; } virtual void visitSymbol(TIntermSymbol* symbol) { // Only constants and loop indices are allowed in a // constant index expression. if (mValid) { mValid = (symbol->getQualifier() == EvqConstExpr) || IsLoopIndex(symbol, mLoopStack); } } private: bool mValid; const TLoopStack& mLoopStack; }; // Traverses a node to check if it uses a loop index. // If an int loop index is used in its body as a sampler array index, // mark the loop for unroll. class ValidateLoopIndexExpr : public TIntermTraverser { public: ValidateLoopIndexExpr(TLoopStack& stack) : mUsesFloatLoopIndex(false), mUsesIntLoopIndex(false), mLoopStack(stack) {} bool usesFloatLoopIndex() const { return mUsesFloatLoopIndex; } bool usesIntLoopIndex() const { return mUsesIntLoopIndex; } virtual void visitSymbol(TIntermSymbol* symbol) { if (IsLoopIndex(symbol, mLoopStack)) { switch (symbol->getBasicType()) { case EbtFloat: mUsesFloatLoopIndex = true; break; case EbtUInt: mUsesIntLoopIndex = true; MarkLoopForUnroll(symbol, mLoopStack); break; case EbtInt: mUsesIntLoopIndex = true; MarkLoopForUnroll(symbol, mLoopStack); break; default: UNREACHABLE(symbol->getBasicType()); } } } private: bool mUsesFloatLoopIndex; bool mUsesIntLoopIndex; TLoopStack& mLoopStack; }; } // namespace ValidateLimitations::ValidateLimitations(GLenum shaderType, TInfoSinkBase& sink) : mShaderType(shaderType), mSink(sink), mNumErrors(0) { } bool ValidateLimitations::visitBinary(Visit, TIntermBinary* node) { // Check if loop index is modified in the loop body. validateOperation(node, node->getLeft()); // Check indexing. switch (node->getOp()) { case EOpIndexDirect: validateIndexing(node); break; case EOpIndexIndirect: validateIndexing(node); break; default: break; } return true; } bool ValidateLimitations::visitUnary(Visit, TIntermUnary* node) { // Check if loop index is modified in the loop body. validateOperation(node, node->getOperand()); return true; } bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate* node) { switch (node->getOp()) { case EOpFunctionCall: validateFunctionCall(node); break; default: break; } return true; } bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node) { if (!validateLoopType(node)) return false; TLoopInfo info; memset(&info, 0, sizeof(TLoopInfo)); info.loop = node; if (!validateForLoopHeader(node, &info)) return false; TIntermNode* body = node->getBody(); if (body) { mLoopStack.push_back(info); body->traverse(this); mLoopStack.pop_back(); } // The loop is fully processed - no need to visit children. return false; } void ValidateLimitations::error(TSourceLoc loc, const char *reason, const char* token) { mSink.prefix(EPrefixError); mSink.location(loc); mSink << "'" << token << "' : " << reason << "\n"; ++mNumErrors; } bool ValidateLimitations::withinLoopBody() const { return !mLoopStack.empty(); } bool ValidateLimitations::isLoopIndex(const TIntermSymbol* symbol) const { return IsLoopIndex(symbol, mLoopStack); } bool ValidateLimitations::validateLoopType(TIntermLoop* node) { TLoopType type = node->getType(); if (type == ELoopFor) return true; // Reject while and do-while loops. error(node->getLine(), "This type of loop is not allowed", type == ELoopWhile ? "while" : "do"); return false; } bool ValidateLimitations::validateForLoopHeader(TIntermLoop* node, TLoopInfo* info) { ASSERT(node->getType() == ELoopFor); // // The for statement has the form: // for ( init-declaration ; condition ; expression ) statement // if (!validateForLoopInit(node, info)) return false; if (!validateForLoopCond(node, info)) return false; if (!validateForLoopExpr(node, info)) return false; return true; } bool ValidateLimitations::validateForLoopInit(TIntermLoop* node, TLoopInfo* info) { TIntermNode* init = node->getInit(); if (!init) { error(node->getLine(), "Missing init declaration", "for"); return false; } // // init-declaration has the form: // type-specifier identifier = constant-expression // TIntermAggregate* decl = init->getAsAggregate(); if (!decl || (decl->getOp() != EOpDeclaration)) { error(init->getLine(), "Invalid init declaration", "for"); return false; } // To keep things simple do not allow declaration list. TIntermSequence& declSeq = decl->getSequence(); if (declSeq.size() != 1) { error(decl->getLine(), "Invalid init declaration", "for"); return false; } TIntermBinary* declInit = declSeq[0]->getAsBinaryNode(); if (!declInit || (declInit->getOp() != EOpInitialize)) { error(decl->getLine(), "Invalid init declaration", "for"); return false; } TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode(); if (!symbol) { error(declInit->getLine(), "Invalid init declaration", "for"); return false; } // The loop index has type int or float. TBasicType type = symbol->getBasicType(); if (!IsInteger(type) && (type != EbtFloat)) { error(symbol->getLine(), "Invalid type for loop index", getBasicString(type)); return false; } // The loop index is initialized with constant expression. if (!isConstExpr(declInit->getRight())) { error(declInit->getLine(), "Loop index cannot be initialized with non-constant expression", symbol->getSymbol().c_str()); return false; } info->index.id = symbol->getId(); return true; } bool ValidateLimitations::validateForLoopCond(TIntermLoop* node, TLoopInfo* info) { TIntermNode* cond = node->getCondition(); if (!cond) { error(node->getLine(), "Missing condition", "for"); return false; } // // condition has the form: // loop_index relational_operator constant_expression // TIntermBinary* binOp = cond->getAsBinaryNode(); if (!binOp) { error(node->getLine(), "Invalid condition", "for"); return false; } // Loop index should be to the left of relational operator. TIntermSymbol* symbol = binOp->getLeft()->getAsSymbolNode(); if (!symbol) { error(binOp->getLine(), "Invalid condition", "for"); return false; } if (symbol->getId() != info->index.id) { error(symbol->getLine(), "Expected loop index", symbol->getSymbol().c_str()); return false; } // Relational operator is one of: > >= < <= == or !=. switch (binOp->getOp()) { case EOpEqual: case EOpNotEqual: case EOpLessThan: case EOpGreaterThan: case EOpLessThanEqual: case EOpGreaterThanEqual: break; default: error(binOp->getLine(), "Invalid relational operator", getOperatorString(binOp->getOp())); break; } // Loop index must be compared with a constant. if (!isConstExpr(binOp->getRight())) { error(binOp->getLine(), "Loop index cannot be compared with non-constant expression", symbol->getSymbol().c_str()); return false; } return true; } bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node, TLoopInfo* info) { TIntermNode* expr = node->getExpression(); if (!expr) { error(node->getLine(), "Missing expression", "for"); return false; } // for expression has one of the following forms: // loop_index++ // loop_index-- // loop_index += constant_expression // loop_index -= constant_expression // ++loop_index // --loop_index // The last two forms are not specified in the spec, but I am assuming // its an oversight. TIntermUnary* unOp = expr->getAsUnaryNode(); TIntermBinary* binOp = unOp ? nullptr : expr->getAsBinaryNode(); TOperator op = EOpNull; TIntermSymbol* symbol = nullptr; if (unOp) { op = unOp->getOp(); symbol = unOp->getOperand()->getAsSymbolNode(); } else if (binOp) { op = binOp->getOp(); symbol = binOp->getLeft()->getAsSymbolNode(); } // The operand must be loop index. if (!symbol) { error(expr->getLine(), "Invalid expression", "for"); return false; } if (symbol->getId() != info->index.id) { error(symbol->getLine(), "Expected loop index", symbol->getSymbol().c_str()); return false; } // The operator is one of: ++ -- += -=. switch (op) { case EOpPostIncrement: case EOpPostDecrement: case EOpPreIncrement: case EOpPreDecrement: ASSERT((unOp != NULL) && (binOp == NULL)); break; case EOpAddAssign: case EOpSubAssign: ASSERT((unOp == NULL) && (binOp != NULL)); break; default: error(expr->getLine(), "Invalid operator", getOperatorString(op)); return false; } // Loop index must be incremented/decremented with a constant. if (binOp != NULL) { if (!isConstExpr(binOp->getRight())) { error(binOp->getLine(), "Loop index cannot be modified by non-constant expression", symbol->getSymbol().c_str()); return false; } } return true; } bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node) { ASSERT(node->getOp() == EOpFunctionCall); // If not within loop body, there is nothing to check. if (!withinLoopBody()) return true; // List of param indices for which loop indices are used as argument. typedef std::vector<int> ParamIndex; ParamIndex pIndex; TIntermSequence& params = node->getSequence(); for (TIntermSequence::size_type i = 0; i < params.size(); ++i) { TIntermSymbol* symbol = params[i]->getAsSymbolNode(); if (symbol && isLoopIndex(symbol)) pIndex.push_back(i); } // If none of the loop indices are used as arguments, // there is nothing to check. if (pIndex.empty()) return true; bool valid = true; TSymbolTable& symbolTable = GetGlobalParseContext()->symbolTable; TSymbol* symbol = symbolTable.find(node->getName(), GetGlobalParseContext()->getShaderVersion()); ASSERT(symbol && symbol->isFunction()); TFunction* function = static_cast<TFunction*>(symbol); for (ParamIndex::const_iterator i = pIndex.begin(); i != pIndex.end(); ++i) { const TParameter& param = function->getParam(*i); TQualifier qual = param.type->getQualifier(); if ((qual == EvqOut) || (qual == EvqInOut)) { error(params[*i]->getLine(), "Loop index cannot be used as argument to a function out or inout parameter", params[*i]->getAsSymbolNode()->getSymbol().c_str()); valid = false; } } return valid; } bool ValidateLimitations::validateOperation(TIntermOperator* node, TIntermNode* operand) { // Check if loop index is modified in the loop body. if (!withinLoopBody() || !node->modifiesState()) return true; const TIntermSymbol* symbol = operand->getAsSymbolNode(); if (symbol && isLoopIndex(symbol)) { error(node->getLine(), "Loop index cannot be statically assigned to within the body of the loop", symbol->getSymbol().c_str()); } return true; } bool ValidateLimitations::isConstExpr(TIntermNode* node) { ASSERT(node); return node->getAsConstantUnion() != nullptr; } bool ValidateLimitations::isConstIndexExpr(TIntermNode* node) { ASSERT(node); ValidateConstIndexExpr validate(mLoopStack); node->traverse(&validate); return validate.isValid(); } bool ValidateLimitations::validateIndexing(TIntermBinary* node) { ASSERT((node->getOp() == EOpIndexDirect) || (node->getOp() == EOpIndexIndirect)); bool valid = true; TIntermTyped* index = node->getRight(); // The index expression must have integral type. if (!index->isScalarInt()) { error(index->getLine(), "Index expression must have integral type", index->getCompleteString().c_str()); valid = false; } // The index expession must be a constant-index-expression unless // the operand is a uniform in a vertex shader. TIntermTyped* operand = node->getLeft(); bool skip = (mShaderType == GL_VERTEX_SHADER) && (operand->getQualifier() == EvqUniform); if (!skip && !isConstIndexExpr(index)) { error(index->getLine(), "Index expression must be constant", "[]"); valid = false; } return valid; }