/*
 * Copyright (C) 2011 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LATINIME_CORRECTION_H
#define LATINIME_CORRECTION_H

#include <cstring> // for memset()
#include <stdint.h> // for int64_t

#include "correction_state.h"
#include "defines.h"
#include "proximity_info_state.h"

namespace latinime {

class ProximityInfo;

class Correction {
 public:
    // Outcome reported by processCharAndCalcState() and the private process* helpers.
    typedef enum {
        TRAVERSE_ALL_ON_TERMINAL,
        TRAVERSE_ALL_NOT_ON_TERMINAL,
        UNRELATED,
        ON_TERMINAL,
        NOT_ON_TERMINAL
    } CorrectionType;

    Correction()
            : mProximityInfo(0), mUseFullEditDistance(false), mDoAutoCompletion(false),
              mMaxEditDistance(0), mMaxDepth(0), mInputSize(0), mSpaceProximityPos(0),
              mMissingSpacePos(0), mTerminalInputIndex(0), mTerminalOutputIndex(0), mMaxErrors(0),
              mTotalTraverseCount(0), mNeedsToTraverseAllNodes(false), mOutputIndex(0),
              mInputIndex(0), mEquivalentCharCount(0), mProximityCount(0), mExcessiveCount(0),
              mTransposedCount(0), mSkippedCount(0), mTransposedPos(0), mExcessivePos(0),
              mSkipPos(0), mLastCharExceeded(false), mMatching(false), mProximityMatching(false),
              mAdditionalProximityMatching(false), mExceeding(false), mTransposing(false),
              mSkipping(false), mProximityInfoState() {
        memset(mWord, 0, sizeof(mWord));
        memset(mDistances, 0, sizeof(mDistances));
        memset(mEditDistanceTable, 0, sizeof(mEditDistanceTable));
        // NOTE: mCorrectionStates is an array of instances.
        // No need to initialize it explicitly here.
    }

    // Non virtual inline destructor -- never inherit this class
    ~Correction() {}
    void resetCorrection();
    void initCorrection(const ProximityInfo *pi, const int inputSize, const int maxDepth);
    void initCorrectionState(const int rootPos, const int childCount, const bool traverseAll);

    // TODO: remove
    void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos,
            const int spaceProximityPos, const int missingSpacePos, const bool useFullEditDistance,
            const bool doAutoCompletion, const int maxErrors);
    void checkState() const;
    bool sameAsTyped() const;
    bool initProcessState(const int index);

    int getInputIndex() const;

    bool needsToPrune() const;

    // Increments the traverse counter and returns the new value.
    int pushAndGetTotalTraverseCount() {
        return ++mTotalTraverseCount;
    }

    int getFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
            const int wordCount, const bool isSpaceProximity, const int *word) const;
    int getFinalProbability(const int probability, int **word, int *wordLength);
    int getFinalProbabilityForSubQueue(const int probability, int **word, int *wordLength,
            const int inputSize);

    CorrectionType processCharAndCalcState(const int c, const bool isTerminal);

    /////////////////////////
    // Tree helper methods
    int goDownTree(const int parentIndex, const int childCount, const int firstChildPos);

    // Accessors for the per-output-index traversal state stored in mCorrectionStates.
    // No bounds checking: index is expected to be in [0, MAX_WORD_LENGTH).
    inline int getTreeSiblingPos(const int index) const {
        return mCorrectionStates[index].mSiblingPos;
    }

    inline void setTreeSiblingPos(const int index, const int pos) {
        mCorrectionStates[index].mSiblingPos = pos;
    }

    inline int getTreeParentIndex(const int index) const {
        return mCorrectionStates[index].mParentIndex;
    }

    // Scoring helpers; implementations live outside this header.
    class RankingAlgorithm {
     public:
        static int calculateFinalProbability(const int inputIndex, const int depth,
                const int probability, int *editDistanceTable, const Correction *correction,
                const int inputSize);
        static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
                const int wordCount, const Correction *correction, const bool isSpaceProximity,
                const int *word);
        static float calcNormalizedScore(const int *before, const int beforeLength,
                const int *after, const int afterLength, const int score);
        static int editDistance(const int *before, const int beforeLength, const int *after,
                const int afterLength);
     private:
        static const int MAX_INITIAL_SCORE = 255;
    };

    // proximity info state
    // Forwards the typed input to mProximityInfoState. The extra arguments (0, 0, false) and
    // MAX_VALUE_FOR_WEIGHTING are fixed here; their meaning is defined by
    // ProximityInfoState::initInputParams.
    void initInputParams(const ProximityInfo *proximityInfo, const int *inputCodes,
            const int inputSize, const int *xCoordinates, const int *yCoordinates) {
        mProximityInfoState.initInputParams(0, static_cast<float>(MAX_VALUE_FOR_WEIGHTING),
                proximityInfo, inputCodes, inputSize, xCoordinates, yCoordinates, 0, 0, false);
    }

    const int *getPrimaryInputWord() const {
        return mProximityInfoState.getPrimaryInputWord();
    }

    int getPrimaryCodePointAt(const int index) const {
        return mProximityInfoState.getPrimaryCodePointAt(index);
    }

 private:
    DISALLOW_COPY_AND_ASSIGN(Correction);

    /////////////////////////
    // static inline utils //
    /////////////////////////
    static const int TWO_31ST_DIV_255 = S_INT_MAX / 255;
    // Returns 255 * num, or S_INT_MAX when num >= S_INT_MAX / 255.
    static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) {
        return (num < TWO_31ST_DIV_255 ? 255 * num : S_INT_MAX);
    }

    // NOTE(review): no longer used by multiplyIntCapped; kept in case correction.cpp uses it.
    static const int TWO_31ST_DIV_2 = S_INT_MAX / 2;

    // Multiplies *base by multiplier in place, saturating to the int range.
    //  - *base == S_INT_MAX is treated as "already capped" and left untouched.
    //  - A negative multiplier is rejected. It is logged (and asserted in debug builds), and
    //    *base is left unchanged.
    // The product is formed in 64 bits, so the overflow test is exact. The previous
    // "result >= operand" check relied on signed overflow (undefined behavior). It accepted
    // wrapped products such as (2^16 + 1) * (2^17 + 1), and it turned negative bases into
    // S_INT_MAX for every multiplier other than 1 and 2.
    AK_FORCE_INLINE static void multiplyIntCapped(const int multiplier, int *base) {
        const int temp = *base;
        if (temp == S_INT_MAX) {
            return;
        }
        if (multiplier < 0) {
            if (DEBUG_DICT) {
                ASSERT(false);
            }
            AKLOGI("--- Invalid multiplier: %d", multiplier);
            return;
        }
        const int64_t product = static_cast<int64_t>(temp) * static_cast<int64_t>(multiplier);
        const int64_t minIntValue = -static_cast<int64_t>(S_INT_MAX) - 1;
        if (product > static_cast<int64_t>(S_INT_MAX)) {
            *base = S_INT_MAX;
        } else if (product < minIntValue) {
            *base = static_cast<int>(minIntValue);
        } else {
            *base = static_cast<int>(product);
        }
    }

    // Returns base^n, capped via multiplyIntCapped; returns 1 for n <= 0.
    AK_FORCE_INLINE static int powerIntCapped(const int base, const int n) {
        if (n <= 0) return 1;
        if (base == 2) {
            return n < 31 ? 1 << n : S_INT_MAX;
        }
        int ret = base;
        for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret);
        return ret;
    }

    // Scales *freq by rate / 100 (capped), unless *freq is already S_INT_MAX. Large values are
    // divided first so the intermediate product is less likely to hit the cap.
    AK_FORCE_INLINE static void multiplyRate(const int rate, int *freq) {
        if (*freq != S_INT_MAX) {
            if (*freq > 1000000) {
                *freq /= 100;
                multiplyIntCapped(rate, freq);
            } else {
                multiplyIntCapped(rate, freq);
                *freq /= 100;
            }
        }
    }

    inline int getSpaceProximityPos() const {
        return mSpaceProximityPos;
    }
    inline int getMissingSpacePos() const {
        return mMissingSpacePos;
    }

    inline int getSkipPos() const {
        return mSkipPos;
    }

    inline int getExcessivePos() const {
        return mExcessivePos;
    }

    inline int getTransposedPos() const {
        return mTransposedPos;
    }

    inline void incrementInputIndex();
    inline void incrementOutputIndex();
    inline void startToTraverseAllNodes();
    inline bool isSingleQuote(const int c);
    inline CorrectionType processSkipChar(const int c, const bool isTerminal,
            const bool inputIndexIncremented);
    inline CorrectionType processUnrelatedCorrectionType();
    inline void addCharToCurrentWord(const int c);
    inline int getFinalProbabilityInternal(const int probability, int **word, int *wordLength,
            const int inputSize);

    static const int TYPED_LETTER_MULTIPLIER = 2;
    static const int FULL_WORD_MULTIPLIER = 2;
    const ProximityInfo *mProximityInfo;

    bool mUseFullEditDistance;
    bool mDoAutoCompletion;
    int mMaxEditDistance;
    int mMaxDepth;
    int mInputSize;
    int mSpaceProximityPos;
    int mMissingSpacePos;
    // Input/output indices captured by processSkipChar()/processUnrelatedCorrectionType() and
    // consumed by getFinalProbabilityInternal().
    int mTerminalInputIndex;
    int mTerminalOutputIndex;
    int mMaxErrors;

    int mTotalTraverseCount;

    // The following arrays are state buffer.
    // mWord holds the code points of the word being built, one per output index.
    int mWord[MAX_WORD_LENGTH];
    int mDistances[MAX_WORD_LENGTH];

    // Edit distance calculation requires a buffer with (N+1)^2 length for the input length N.
    // Caveat: Do not create multiple tables per thread as this table eats up RAM a lot.
    int mEditDistanceTable[(MAX_WORD_LENGTH + 1) * (MAX_WORD_LENGTH + 1)];

    // One snapshot per output index; written by incrementOutputIndex().
    CorrectionState mCorrectionStates[MAX_WORD_LENGTH];

    // The following member variables are being used as cache values of the correction state.
    bool mNeedsToTraverseAllNodes;
    int mOutputIndex;
    int mInputIndex;

    int mEquivalentCharCount;
    int mProximityCount;
    int mExcessiveCount;
    int mTransposedCount;
    int mSkippedCount;

    int mTransposedPos;
    int mExcessivePos;
    int mSkipPos;

    bool mLastCharExceeded;

    bool mMatching;
    bool mProximityMatching;
    bool mAdditionalProximityMatching;
    bool mExceeding;
    bool mTransposing;
    bool mSkipping;
    ProximityInfoState mProximityInfoState;
};

// Moves the cached input cursor forward by one code point.
inline void Correction::incrementInputIndex() {
    mInputIndex += 1;
}

// Advances the output cursor and snapshots the current cached correction state into the new
// mCorrectionStates slot. Tree-navigation fields (parent, child count, sibling position) are
// inherited from the previous slot; all other fields come from the member-variable cache.
// NOTE(review): there is no bounds check. mOutputIndex is presumably kept below MAX_WORD_LENGTH
// by the callers; confirm.
AK_FORCE_INLINE void Correction::incrementOutputIndex() {
    ++mOutputIndex;
    const CorrectionState &previous = mCorrectionStates[mOutputIndex - 1];
    CorrectionState &next = mCorrectionStates[mOutputIndex];

    // Tree position carried over from the previous output index.
    next.mParentIndex = previous.mParentIndex;
    next.mChildCount = previous.mChildCount;
    next.mSiblingPos = previous.mSiblingPos;

    // Cursor and traversal mode.
    next.mInputIndex = mInputIndex;
    next.mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes;

    // Error counters.
    next.mEquivalentCharCount = mEquivalentCharCount;
    next.mProximityCount = mProximityCount;
    next.mTransposedCount = mTransposedCount;
    next.mExcessiveCount = mExcessiveCount;
    next.mSkippedCount = mSkippedCount;

    // Error positions.
    next.mSkipPos = mSkipPos;
    next.mTransposedPos = mTransposedPos;
    next.mExcessivePos = mExcessivePos;

    // Per-character flags.
    next.mLastCharExceeded = mLastCharExceeded;
    next.mMatching = mMatching;
    next.mProximityMatching = mProximityMatching;
    next.mAdditionalProximityMatching = mAdditionalProximityMatching;
    next.mTransposing = mTransposing;
    next.mExceeding = mExceeding;
    next.mSkipping = mSkipping;
}

inline void Correction::startToTraverseAllNodes() {
    mNeedsToTraverseAllNodes = true;
}

// Returns true when the dictionary character c is a single quote but the code point the user
// typed at the current input index is not.
AK_FORCE_INLINE bool Correction::isSingleQuote(const int c) {
    const int typedCodePoint = getPrimaryCodePointAt(mInputIndex);
    const bool candidateIsQuote = (c == KEYCODE_SINGLE_QUOTE);
    const bool typedIsQuote = (typedCodePoint == KEYCODE_SINGLE_QUOTE);
    return candidateIsQuote && !typedIsQuote;
}

// Appends c to the current word and records the terminal input/output indices. The input index
// is rolled back by one if the caller already advanced it. Then advances the output index.
// Returns TRAVERSE_ALL_ON_TERMINAL only when traversing all nodes on a terminal character, and
// TRAVERSE_ALL_NOT_ON_TERMINAL otherwise.
AK_FORCE_INLINE Correction::CorrectionType Correction::processSkipChar(const int c,
        const bool isTerminal, const bool inputIndexIncremented) {
    addCharToCurrentWord(c);
    mTerminalInputIndex = inputIndexIncremented ? mInputIndex - 1 : mInputIndex;
    // Must be captured before incrementOutputIndex() moves mOutputIndex.
    mTerminalOutputIndex = mOutputIndex;
    incrementOutputIndex();
    const bool onTerminalWhileTraversingAll = mNeedsToTraverseAllNodes && isTerminal;
    return onTerminalWhileTraversingAll ? TRAVERSE_ALL_ON_TERMINAL : TRAVERSE_ALL_NOT_ON_TERMINAL;
}

// Records the current cursors as the terminal indices and reports UNRELATED.
inline Correction::CorrectionType Correction::processUnrelatedCorrectionType() {
    // The terminal indices must be set before returning any CorrectionType, including this one.
    mTerminalOutputIndex = mOutputIndex;
    mTerminalInputIndex = mInputIndex;
    return UNRELATED;
}

// Fills one row of the edit-distance DP table. The recurrence takes the minimum over the
// left / upper / diagonal neighbours, plus an adjacent-transposition step. Characters are
// compared after toBaseLowerCase().
//
// The table is viewed as rows of (inputSize + 1) ints: dp[r][col] is
// editDistanceTable[r * (inputSize + 1) + col]. Rows 0 .. outputLength - 1 must already be
// computed; this writes row outputLength, i.e. dp[outputLength][0 .. inputSize].
//
// editDistanceTable: DP storage, at least (outputLength + 1) * (inputSize + 1) ints.
// input, inputSize: the input sequence (the columns).
// output, outputLength: the word built so far (the rows); outputLength is expected to be >= 1.
AK_FORCE_INLINE static void calcEditDistanceOneStep(int *editDistanceTable, const int *input,
        const int inputSize, const int *output, const int outputLength) {
    // TODO: Make sure that editDistance[0 ~ MAX_WORD_LENGTH] is not touched.
    const int rowWidth = inputSize + 1;
    int *const row = editDistanceTable + outputLength * rowWidth;
    const int *const rowAbove = row - rowWidth;
    const int *const rowTwoAbove = (outputLength >= 2) ? rowAbove - rowWidth : 0;

    const int lastOut = toBaseLowerCase(output[outputLength - 1]);
    const int secondLastOut = (outputLength >= 2) ? toBaseLowerCase(output[outputLength - 2]) : 0;

    row[0] = outputLength;
    // Lower-cased input[col - 2]; only consulted once col >= 2.
    int prevIn = 0;
    for (int col = 1; col <= inputSize; ++col) {
        const int curIn = toBaseLowerCase(input[col - 1]);
        const int fromLeft = row[col - 1] + 1;
        const int fromAbove = rowAbove[col] + 1;
        const int fromDiagonal = rowAbove[col - 1] + ((curIn == lastOut) ? 0 : 1);
        int best = min(fromLeft, min(fromAbove, fromDiagonal));
        // Adjacent transposition: the last two output characters appear swapped at
        // input[col - 2], input[col - 1].
        const bool swapped = col >= 2 && rowTwoAbove && curIn == secondLastOut
                && prevIn == lastOut;
        if (swapped) {
            best = min(best, rowTwoAbove[col - 2] + 1);
        }
        row[col] = best;
        prevIn = curIn;
    }
}

// Stores c at the current output index and extends mEditDistanceTable by the row for the
// word of length mOutputIndex + 1, measured against the primary input word.
AK_FORCE_INLINE void Correction::addCharToCurrentWord(const int c) {
    mWord[mOutputIndex] = c;
    const int newLength = mOutputIndex + 1;
    calcEditDistanceOneStep(mEditDistanceTable, getPrimaryInputWord(), mInputSize, mWord,
            newLength);
}

inline int Correction::getFinalProbabilityInternal(const int probability, int **word,
        int *wordLength, const int inputSize) {
    const int outputIndex = mTerminalOutputIndex;
    const int inputIndex = mTerminalInputIndex;
    *wordLength = outputIndex + 1;
    *word = mWord;
    int finalProbability= Correction::RankingAlgorithm::calculateFinalProbability(
            inputIndex, outputIndex, probability, mEditDistanceTable, this, inputSize);
    return finalProbability;
}

} // namespace latinime
#endif // LATINIME_CORRECTION_H