C++程序  |  186行  |  5.24 KB

//
// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//

#include "compiler/DetectCallDepth.h"
#include "compiler/InfoSink.h"

DetectCallDepth::FunctionNode::FunctionNode(const TString& fname)
    : name(fname),
      visit(PreVisit)
{
}

const TString& DetectCallDepth::FunctionNode::getName() const
{
    return name;
}

void DetectCallDepth::FunctionNode::addCallee(
    DetectCallDepth::FunctionNode* callee)
{
    for (size_t i = 0; i < callees.size(); ++i) {
        if (callees[i] == callee)
            return;
    }
    callees.push_back(callee);
}

int DetectCallDepth::FunctionNode::detectCallDepth(DetectCallDepth* detectCallDepth, int depth)
{
    ASSERT(visit == PreVisit);
    ASSERT(detectCallDepth);

    int maxDepth = depth;
    visit = InVisit;
    for (size_t i = 0; i < callees.size(); ++i) {
        switch (callees[i]->visit) {
            case InVisit:
                // cycle detected, i.e., recursion detected.
                return kInfiniteCallDepth;
            case PostVisit:
                break;
            case PreVisit: {
                // Check before we recurse so we don't go too depth
                if (detectCallDepth->checkExceedsMaxDepth(depth))
                    return depth;
                int callDepth = callees[i]->detectCallDepth(detectCallDepth, depth + 1);
                // Check after we recurse so we can exit immediately and provide info.
                if (detectCallDepth->checkExceedsMaxDepth(callDepth)) {
                    detectCallDepth->getInfoSink().info << "<-" << callees[i]->getName();
                    return callDepth;
                }
                maxDepth = std::max(callDepth, maxDepth);
                break;
            }
            default:
                UNREACHABLE();
                break;
        }
    }
    visit = PostVisit;
    return maxDepth;
}

void DetectCallDepth::FunctionNode::reset()
{
    visit = PreVisit;
}

DetectCallDepth::DetectCallDepth(TInfoSink& infoSink, bool limitCallStackDepth, int maxCallStackDepth)
    : TIntermTraverser(true, false, true, false),
      currentFunction(NULL),
      infoSink(infoSink),
      maxDepth(limitCallStackDepth ? maxCallStackDepth : FunctionNode::kInfiniteCallDepth)
{
}

DetectCallDepth::~DetectCallDepth()
{
    for (size_t i = 0; i < functions.size(); ++i)
        delete functions[i];
}

bool DetectCallDepth::visitAggregate(Visit visit, TIntermAggregate* node)
{
    switch (node->getOp())
    {
        case EOpPrototype:
            // Function declaration.
            // Don't add FunctionNode here because node->getName() is the
            // unmangled function name.
            break;
        case EOpFunction: {
            // Function definition.
            if (visit == PreVisit) {
                currentFunction = findFunctionByName(node->getName());
                if (currentFunction == NULL) {
                    currentFunction = new FunctionNode(node->getName());
                    functions.push_back(currentFunction);
                }
            } else if (visit == PostVisit) {
                currentFunction = NULL;
            }
            break;
        }
        case EOpFunctionCall: {
            // Function call.
            if (visit == PreVisit) {
                FunctionNode* func = findFunctionByName(node->getName());
                if (func == NULL) {
                    func = new FunctionNode(node->getName());
                    functions.push_back(func);
                }
                if (currentFunction)
                    currentFunction->addCallee(func);
            }
            break;
        }
        default:
            break;
    }
    return true;
}

bool DetectCallDepth::checkExceedsMaxDepth(int depth)
{
    return depth >= maxDepth;
}

void DetectCallDepth::resetFunctionNodes()
{
    for (size_t i = 0; i < functions.size(); ++i) {
        functions[i]->reset();
    }
}

DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepthForFunction(FunctionNode* func)
{
    currentFunction = NULL;
    resetFunctionNodes();

    int maxCallDepth = func->detectCallDepth(this, 1);

    if (maxCallDepth == FunctionNode::kInfiniteCallDepth)
        return kErrorRecursion;

    if (maxCallDepth >= maxDepth)
        return kErrorMaxDepthExceeded;

    return kErrorNone;
}

DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepth()
{
    if (maxDepth != FunctionNode::kInfiniteCallDepth) {
        // Check all functions because the driver may fail on them
        // TODO: Before detectingRecursion, strip unused functions.
        for (size_t i = 0; i < functions.size(); ++i) {
            ErrorCode error = detectCallDepthForFunction(functions[i]);
            if (error != kErrorNone)
                return error;
        }
    } else {
        FunctionNode* main = findFunctionByName("main(");
        if (main == NULL)
            return kErrorMissingMain;

        return detectCallDepthForFunction(main);
    }

    return kErrorNone;
}

DetectCallDepth::FunctionNode* DetectCallDepth::findFunctionByName(
    const TString& name)
{
    for (size_t i = 0; i < functions.size(); ++i) {
        if (functions[i]->getName() == name)
            return functions[i];
    }
    return NULL;
}