/*
 * Copyright (C) 2008 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * Operations on an Object.
 */
#include "Dalvik.h"

/*
 * Look up an instance field declared directly by "clazz".  Superclasses
 * are not consulted.
 *
 * Both "fieldName" and "signature" must match.  The Java language forbids
 * two fields sharing a name, but the Java VM spec permits same-named
 * fields of different types, so a name match alone must not end the search.
 *
 * Returns the matching InstField, or NULL if there is none.  (Does not
 * throw an exception.)
 */
InstField* dvmFindInstanceField(const ClassObject* clazz,
    const char* fieldName, const char* signature)
{
    int idx;

    assert(clazz != NULL);

    for (idx = 0; idx < clazz->ifieldCount; idx++) {
        InstField* candidate = &clazz->ifields[idx];

        if (strcmp(fieldName, candidate->field.name) != 0)
            continue;
        if (strcmp(signature, candidate->field.signature) != 0)
            continue;

        return candidate;
    }

    return NULL;
}

/*
 * Look up an instance field in "clazz" or, failing that, in each of its
 * superclasses in turn, nearest first.
 *
 * Interfaces are not searched: interface fields are inherently
 * public/static/final, so they can never be instance fields.
 *
 * Returns the first match found, or NULL if the field can't be found.
 * (Does not throw an exception.)
 */
InstField* dvmFindInstanceFieldHier(const ClassObject* clazz,
    const char* fieldName, const char* signature)
{
    /* Walk from clazz toward the root, checking each class's own fields. */
    do {
        InstField* match = dvmFindInstanceField(clazz, fieldName, signature);

        if (match != NULL)
            return match;

        clazz = clazz->super;
    } while (clazz != NULL);

    return NULL;
}


/*
 * Find a static field declared directly by "clazz".  Superclasses and
 * interfaces are not searched here; dvmFindStaticFieldHier() does that.
 *
 * Both the name and the type signature must match.  As with instance
 * fields, the Java VM spec (unlike the Java language) allows two fields
 * to share a name as long as their types differ.  A name match with the
 * wrong signature therefore must not end the search: a later field with
 * the same name may carry the requested signature.
 *
 * Returns NULL if the field can't be found.  (Does not throw an exception.)
 */
StaticField* dvmFindStaticField(const ClassObject* clazz,
    const char* fieldName, const char* signature)
{
    StaticField* pField;
    int i;

    assert(clazz != NULL);

    pField = clazz->sfields;
    for (i = 0; i < clazz->sfieldCount; i++, pField++) {
        if (strcmp(fieldName, pField->field.name) == 0 &&
            strcmp(signature, pField->field.signature) == 0)
        {
            return pField;
        }
    }

    return NULL;
}

/*
 * Find a static field in "clazz", in the interfaces it implements, or in
 * its superclasses.
 *
 * At each level of the class hierarchy the search order is:
 *   1. the class's own static fields;
 *   2. the iftable entries from index super->iftableCount upward;
 *   3. then the superclass, by repeating the same steps one level up.
 * Interfaces inherited from the superclass are not checked at the current
 * level; the superclass step handles them.
 *
 * (Note the iftable set may have been stripped down because of redundancy
 * with the superclass; see notes in createIftable.)
 *
 * Returns NULL if the field can't be found.  (Does not throw an exception.)
 */
StaticField* dvmFindStaticFieldHier(const ClassObject* clazz,
    const char* fieldName, const char* signature)
{
    do {
        StaticField* found;
        int firstOwnIface;
        int ifaceIdx;

        found = dvmFindStaticField(clazz, fieldName, signature);
        if (found != NULL)
            return found;

        firstOwnIface = 0;
        if (clazz->super != NULL) {
            assert(clazz->iftableCount >= clazz->super->iftableCount);
            firstOwnIface = clazz->super->iftableCount;
        }

        for (ifaceIdx = firstOwnIface; ifaceIdx < clazz->iftableCount;
            ifaceIdx++)
        {
            found = dvmFindStaticField(clazz->iftable[ifaceIdx].clazz,
                fieldName, signature);
            if (found != NULL)
                return found;
        }

        clazz = clazz->super;
    } while (clazz != NULL);

    return NULL;
}

/*
 * Check whether "method" has exactly the given name, return type, and
 * parameter types.  "argTypes" holds "argCount" type descriptor strings,
 * in declaration order.
 *
 * Returns 0 when everything matches and 1 otherwise (strcmp-style, so
 * callers test for == 0).
 */
static inline int compareMethodHelper(Method* method, const char* methodName,
    const char* returnType, size_t argCount, const char** argTypes)
{
    const DexProto* proto = &method->prototype;
    DexParameterIterator iterator;
    size_t argIdx;

    if (strcmp(methodName, method->name) != 0)
        return 1;
    if (strcmp(returnType, dexProtoGetReturnType(proto)) != 0)
        return 1;
    if (dexProtoGetParameterCount(proto) != argCount)
        return 1;

    dexParameterIteratorInit(&iterator, proto);

    for (argIdx = 0; argIdx < argCount; argIdx++) {
        const char* paramType = dexParameterIteratorNextDescriptor(&iterator);

        /* The method's parameter list ended early, or the types differ. */
        if (paramType == NULL || strcmp(argTypes[argIdx], paramType) != 0)
            return 1;
    }

    /* Every given type matched; the method must have no extra parameters. */
    return (dexParameterIteratorNextDescriptor(&iterator) == NULL) ? 0 : 1;
}

/*
 * Get the count of arguments in the given method descriptor string,
 * and also find a pointer to the return type.
 */
static inline size_t countArgsAndFindReturnType(const char* descriptor,
    const char** pReturnType) 
{
    size_t count = 0;
    bool bogus = false;
    bool done = false;

    assert(*descriptor == '(');
    descriptor++;
    
    while (!done) {
        switch (*descriptor) {
            case 'B': case 'C': case 'D': case 'F':
            case 'I': case 'J': case 'S': case 'Z': {
                count++;
                break;
            }
            case '[': {
                do {
                    descriptor++;
                } while (*descriptor == '[');
                /*
                 * Don't increment count, as it will be taken care of
                 * by the next iteration. Also, decrement descriptor
                 * to compensate for the increment below the switch.
                 */
                descriptor--;
                break;
            }
            case 'L': {
                do {
                    descriptor++;
                } while ((*descriptor != ';') && (*descriptor != '\0'));
                count++;
                if (*descriptor == '\0') {
                    /* Bogus descriptor. */
                    done = true;
                    bogus = true;
                }
                break;
            }
            case ')': {
                /* 
                 * Note: The loop will exit after incrementing descriptor
                 * one more time, so it then points at the return type.
                 */
                done = true;
                break;
            }
            default: {
                /* Bogus descriptor. */
                done = true;
                bogus = true;
                break;
            }
        }

        descriptor++;
    }

    if (bogus) {
        *pReturnType = NULL;
        return 0;
    }

    *pReturnType = descriptor;
    return count;
}

/*
 * Split the parameter list of a method descriptor into individual type
 * descriptor strings.
 *
 * The "argCount" NUL-terminated strings are written back to back into
 * "buffer", and a pointer to each is stored in argTypes[0..argCount-1].
 * "descriptor" must point at the opening '(' and must already have been
 * validated by countArgsAndFindReturnType(); no bounds checking is done
 * here.  The caller sizes "buffer" to hold every parameter character plus
 * one '\0' per parameter.
 */
static inline void copyTypes(char* buffer, const char** argTypes,
    size_t argCount, const char* descriptor)
{
    const char* src = descriptor + 1;   /* step over the '(' */
    char* dst = buffer;
    size_t argIdx;

    for (argIdx = 0; argIdx < argCount; argIdx++) {
        char ch;

        argTypes[argIdx] = dst;

        /* Any leading '[' markers, then the element type's first char. */
        while (*src == '[')
            *dst++ = *src++;
        ch = *src++;
        *dst++ = ch;

        /* A class type continues through its terminating ';'. */
        if (ch == 'L') {
            do {
                ch = *src++;
                *dst++ = ch;
            } while (ch != ';');
        }

        *dst++ = '\0';
    }
}

/*
 * Search "clazz" for a method named "name" whose full method descriptor
 * (e.g. "(ILjava/lang/String;)V") equals "descriptor".
 *
 * "findVirtual" selects the virtualMethods list (true) or the
 * directMethods list (false).  When "isHier" is true the superclass chain
 * is searched as well; otherwise only "clazz" itself is examined.
 *
 * Returns the first match, or NULL if there is none.  A malformed
 * descriptor logs a warning and returns NULL.
 */
static Method* findMethodInListByDescriptor(const ClassObject* clazz,
    bool findVirtual, bool isHier, const char* name, const char* descriptor)
{
    const char* returnType;
    size_t argCount = countArgsAndFindReturnType(descriptor, &returnType);

    if (returnType == NULL) {
        LOGW("Bogus method descriptor: %s\n", descriptor);
        return NULL;
    }

    /*
     * Make buffer big enough for all the argument type characters and
     * one '\0' per argument. The "- 2" is because "returnType -
     * descriptor" includes two parens.
     *
     * Both arrays get one spare element.  A variable-length array must
     * have a size greater than zero, and a no-argument descriptor such as
     * "()V" would otherwise produce zero-length arrays (undefined
     * behavior).
     */
    size_t paramChars = (size_t) (returnType - descriptor) - 2;
    char buffer[argCount + paramChars + 1];
    const char* argTypes[argCount + 1];

    copyTypes(buffer, argTypes, argCount, descriptor);

    while (clazz != NULL) {
        Method* methods;
        size_t methodCount;
        size_t i;

        if (findVirtual) {
            methods = clazz->virtualMethods;
            methodCount = clazz->virtualMethodCount;
        } else {
            methods = clazz->directMethods;
            methodCount = clazz->directMethodCount;
        }

        for (i = 0; i < methodCount; i++) {
            Method* method = &methods[i];
            if (compareMethodHelper(method, name, returnType, argCount,
                            argTypes) == 0) {
                return method;
            }
        }

        if (! isHier) {
            break;
        }

        clazz = clazz->super;
    }

    return NULL;
}

/*
 * Search "clazz" for a method whose name and prototype match "name" and
 * "proto", as judged by dvmCompareNameProtoAndMethod().
 *
 * "findVirtual" picks the virtualMethods list (true) or the directMethods
 * list (false).  When "isHier" is true the search continues up the
 * superclass chain; otherwise only "clazz" itself is examined.
 *
 * Returns the first match, or NULL if there is none.
 */
static Method* findMethodInListByProto(const ClassObject* clazz,
    bool findVirtual, bool isHier, const char* name, const DexProto* proto)
{
    for (; clazz != NULL; clazz = isHier ? clazz->super : NULL) {
        Method* list =
            findVirtual ? clazz->virtualMethods : clazz->directMethods;
        size_t count =
            findVirtual ? clazz->virtualMethodCount : clazz->directMethodCount;
        size_t idx;

        for (idx = 0; idx < count; idx++) {
            if (dvmCompareNameProtoAndMethod(name, proto, &list[idx]) == 0)
                return &list[idx];
        }
    }

    return NULL;
}

/*
 * Find a "virtual" method declared by "clazz", identified by name and full
 * method descriptor string.  The superclass is not searched.
 *
 * Returns NULL if the method can't be found.  (Does not throw an exception.)
 *
 * TODO? - throw IncompatibleClassChangeError if a match is found in the
 * directMethods list, rather than NotFoundError.  Note we could have been
 * called by dvmFindVirtualMethodHier though.
 */
Method* dvmFindVirtualMethodByDescriptor(const ClassObject* clazz,
    const char* methodName, const char* descriptor)
{
    const bool searchVirtuals = true;
    const bool searchSupers = false;

    return findMethodInListByDescriptor(clazz, searchVirtuals, searchSupers,
        methodName, descriptor);
}


/*
 * Find a "virtual" method declared by "clazz" using only its name; the
 * prototype is ignored and the first name match wins.  This is only
 * useful in limited circumstances, e.g. when searching for a member of an
 * annotation class.
 *
 * Does not chase into the superclass.
 *
 * Returns NULL if the method can't be found.  (Does not throw an exception.)
 */
Method* dvmFindVirtualMethodByName(const ClassObject* clazz,
    const char* methodName)
{
    Method* meth = clazz->virtualMethods;
    int remaining = clazz->virtualMethodCount;

    for (; remaining > 0; remaining--, meth++) {
        if (strcmp(meth->name, methodName) == 0)
            return meth;
    }

    return NULL;
}

/*
 * Find a "virtual" method declared by "clazz", identified by name and
 * DexProto prototype.  The superclass is not searched.
 *
 * Returns NULL if the method can't be found.  (Does not throw an exception.)
 */
Method* dvmFindVirtualMethod(const ClassObject* clazz, const char* methodName,
    const DexProto* proto)
{
    return findMethodInListByProto(clazz, /*findVirtual=*/ true,
        /*isHier=*/ false, methodName, proto);
}

/*
 * Find a "virtual" method by name and full method descriptor string,
 * looking first in "clazz" and then in each superclass in turn.
 *
 * Returns NULL if the method can't be found.  (Does not throw an exception.)
 */
Method* dvmFindVirtualMethodHierByDescriptor(const ClassObject* clazz,
    const char* methodName, const char* descriptor)
{
    return findMethodInListByDescriptor(clazz, /*findVirtual=*/ true,
        /*isHier=*/ true, methodName, descriptor);
}

/*
 * Find a "virtual" method by name and DexProto prototype, looking first
 * in "clazz" and then in each superclass in turn.
 *
 * Returns NULL if the method can't be found.  (Does not throw an exception.)
 */
Method* dvmFindVirtualMethodHier(const ClassObject* clazz,
    const char* methodName, const DexProto* proto)
{
    return findMethodInListByProto(clazz, /*findVirtual=*/ true,
        /*isHier=*/ true, methodName, proto);
}

/*
 * Find a "direct" method (static, private, or "<*init>") declared by
 * "clazz", identified by name and full method descriptor string.  The
 * superclass is not searched.
 *
 * Returns NULL if the method can't be found.  (Does not throw an exception.)
 */
Method* dvmFindDirectMethodByDescriptor(const ClassObject* clazz,
    const char* methodName, const char* descriptor)
{
    return findMethodInListByDescriptor(clazz, /*findVirtual=*/ false,
        /*isHier=*/ false, methodName, descriptor);
}

/*
 * Find a "direct" method by name and full method descriptor string,
 * looking first in "clazz" and then in each superclass in turn.  Walking
 * the hierarchy is only meaningful for static methods, but the lookup
 * works for any direct method.
 *
 * Returns NULL if the method can't be found.  (Does not throw an exception.)
 */
Method* dvmFindDirectMethodHierByDescriptor(const ClassObject* clazz,
    const char* methodName, const char* descriptor)
{
    return findMethodInListByDescriptor(clazz, /*findVirtual=*/ false,
        /*isHier=*/ true, methodName, descriptor);
}

/*
 * Find a "direct" method (static, private, or "<*init>") declared by
 * "clazz", identified by name and DexProto prototype.  The superclass is
 * not searched.
 *
 * Returns NULL if the method can't be found.  (Does not throw an exception.)
 */
Method* dvmFindDirectMethod(const ClassObject* clazz, const char* methodName,
    const DexProto* proto)
{
    return findMethodInListByProto(clazz, /*findVirtual=*/ false,
        /*isHier=*/ false, methodName, proto);
}

/*
 * Find a "direct" method by name and DexProto prototype, looking first in
 * "clazz" and then in each superclass in turn.
 *
 * Returns NULL if the method can't be found.  (Does not throw an exception.)
 */
Method* dvmFindDirectMethodHier(const ClassObject* clazz,
    const char* methodName, const DexProto* proto)
{
    return findMethodInListByProto(clazz, /*findVirtual=*/ false,
        /*isHier=*/ true, methodName, proto);
}

/*
 * Resolve "meth" to the implementation that "clazz" actually dispatches
 * to, via clazz's vtable.  "meth" may have been obtained from a superclass
 * or from an interface that "clazz" implements.
 *
 * Private methods have no vtable entry and are returned unchanged.  For a
 * method declared in an interface, the interface is located in clazz's
 * iftable, and the interface-relative index is translated into a vtable
 * index.
 *
 * Returns the resolved method.  Returns NULL with an exception pending
 * in two cases:
 *   - IncompatibleClassChangeError when the interface is not implemented
 *     by "clazz";
 *   - AbstractMethodError when the resolved vtable entry is abstract.
 *
 * (This is used for reflection and JNI "call method" calls.)
 */
const Method* dvmGetVirtualizedMethod(const ClassObject* clazz,
    const Method* meth)
{
    Method* target;
    int vtableIdx;

    assert(!dvmIsStaticMethod(meth));

    /* Private methods are never placed in the vtable. */
    if (dvmIsPrivateMethod(meth))
        return meth;

    if (!dvmIsInterfaceClass(meth->clazz)) {
        /* Declared in a class: the method index is already a vtable index. */
        vtableIdx = meth->methodIndex;
    } else {
        /*
         * Declared in an interface: scan clazz's iftable for that
         * interface and map the method index through its
         * methodIndexArray.
         *
         * TODO: use the interface cache.
         */
        int ifIdx = 0;

        while (ifIdx < clazz->iftableCount &&
               clazz->iftable[ifIdx].clazz != meth->clazz)
        {
            ifIdx++;
        }
        if (ifIdx == clazz->iftableCount) {
            dvmThrowException("Ljava/lang/IncompatibleClassChangeError;",
                "invoking method from interface not implemented by class");
            return NULL;
        }

        vtableIdx = clazz->iftable[ifIdx].methodIndexArray[meth->methodIndex];
    }

    assert(vtableIdx >= 0 && vtableIdx < clazz->vtableCount);
    target = clazz->vtable[vtableIdx];

    /* An abstract vtable entry has no code to execute. */
    if (dvmIsAbstractMethod(target)) {
        dvmThrowException("Ljava/lang/AbstractMethodError;", NULL);
        return NULL;
    }
    assert(!dvmIsMirandaMethod(target));

    return target;
}

/*
 * Get the source file for a method.
 */
const char* dvmGetMethodSourceFile(const Method* meth)
{
    /*
     * TODO: A method's debug info can override the default source
     * file for a class, so we should account for that possibility
     * here.
     */
    return meth->clazz->sourceFile;
}

/*
 * Dump some information about an object to the verbose log.
 *
 * The header line shows the object's address, its class descriptor, and
 * the class's objectSize.  It is followed by one line per entry in
 * obj->clazz->ifields, giving the field name, signature, access flags,
 * and current value:
 *   - 'F' and 'D' fields are printed as doubles;
 *   - 'J' fields are read as long, and 'Z' fields as boolean;
 *   - everything else is read as int.
 * All non-floating values are printed in hex.
 *
 * A NULL object, or one with no class, is reported and skipped.
 */
void dvmDumpObject(const Object* obj)
{
    ClassObject* clazz;
    int i;

    if (obj == NULL || obj->clazz == NULL) {
        LOGW("Null or malformed object not dumped\n");
        return;
    }

    clazz = obj->clazz;
    LOGV("----- Object dump: %p (%s, %d bytes) -----\n",
        obj, clazz->descriptor, (int) clazz->objectSize);
    //printHexDump(obj, clazz->objectSize);
    LOGV("  Fields:\n");
    for (i = 0; i < clazz->ifieldCount; i++) {
        const InstField* pField = &clazz->ifields[i];
        char type = pField->field.signature[0];

        if (type == 'F' || type == 'D') {
            double dval;

            if (type == 'F')
                dval = dvmGetFieldFloat(obj, pField->byteOffset);
            else
                dval = dvmGetFieldDouble(obj, pField->byteOffset);

            LOGV("  %2d: '%s' '%s' flg=%04x %.3f\n", i, pField->field.name,
                pField->field.signature, pField->field.accessFlags, dval);
        } else {
            long long lval;

            if (type == 'J')
                lval = dvmGetFieldLong(obj, pField->byteOffset);
            else if (type == 'Z')
                lval = dvmGetFieldBoolean(obj, pField->byteOffset);
            else
                lval = dvmGetFieldInt(obj, pField->byteOffset);

            /* %llx requires an unsigned long long argument. */
            LOGV("  %2d: '%s' '%s' af=%04x 0x%llx\n", i, pField->field.name,
                pField->field.signature, pField->field.accessFlags,
                (unsigned long long) lval);
        }
    }
}