/*
 * Copyright (C) 2011 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LATINIME_CORRECTION_H
#define LATINIME_CORRECTION_H

#include <cassert>
#include <cstring> // for memset()
#include <stdint.h>

#include "correction_state.h"
#include "defines.h"
#include "proximity_info_state.h"

namespace latinime {

class ProximityInfo;

class Correction {
 public:
    typedef enum {
        TRAVERSE_ALL_ON_TERMINAL,
        TRAVERSE_ALL_NOT_ON_TERMINAL,
        UNRELATED,
        ON_TERMINAL,
        NOT_ON_TERMINAL
    } CorrectionType;

    Correction()
            : mProximityInfo(0), mUseFullEditDistance(false), mDoAutoCompletion(false),
              mMaxEditDistance(0), mMaxDepth(0), mInputSize(0), mSpaceProximityPos(0),
              mMissingSpacePos(0), mTerminalInputIndex(0), mTerminalOutputIndex(0), mMaxErrors(0),
              mTotalTraverseCount(0), mNeedsToTraverseAllNodes(false), mOutputIndex(0),
              mInputIndex(0), mEquivalentCharCount(0), mProximityCount(0), mExcessiveCount(0),
              mTransposedCount(0), mSkippedCount(0), mTransposedPos(0), mExcessivePos(0),
              mSkipPos(0), mLastCharExceeded(false), mMatching(false), mProximityMatching(false),
              mAdditionalProximityMatching(false), mExceeding(false), mTransposing(false),
              mSkipping(false), mProximityInfoState() {
        memset(mWord, 0, sizeof(mWord));
        memset(mDistances, 0, sizeof(mDistances));
        memset(mEditDistanceTable, 0, sizeof(mEditDistanceTable));
        // NOTE: mCorrectionStates is an array of instances.
        // No need to initialize it explicitly here.
    }

    virtual ~Correction() {}
    void resetCorrection();
    void initCorrection(
            const ProximityInfo *pi, const int inputSize, const int maxWordLength);
    void initCorrectionState(const int rootPos, const int childCount, const bool traverseAll);

    // TODO: remove
    void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos,
            const int spaceProximityPos, const int missingSpacePos, const bool useFullEditDistance,
            const bool doAutoCompletion, const int maxErrors);
    void checkState();
    bool sameAsTyped();
    bool initProcessState(const int index);

    int getInputIndex() const;

    bool needsToPrune() const;

    int pushAndGetTotalTraverseCount() {
        return ++mTotalTraverseCount;
    }

    int getFreqForSplitMultipleWords(
            const int *freqArray, const int *wordLengthArray, const int wordCount,
            const bool isSpaceProximity, const unsigned short *word);
    int getFinalProbability(const int probability, unsigned short **word, int *wordLength);
    int getFinalProbabilityForSubQueue(const int probability, unsigned short **word,
            int *wordLength, const int inputSize);

    CorrectionType processCharAndCalcState(const int32_t c, const bool isTerminal);

    /////////////////////////
    // Tree helper methods
    int goDownTree(const int parentIndex, const int childCount, const int firstChildPos);

    inline int getTreeSiblingPos(const int index) const {
        return mCorrectionStates[index].mSiblingPos;
    }

    inline void setTreeSiblingPos(const int index, const int pos) {
        mCorrectionStates[index].mSiblingPos = pos;
    }

    inline int getTreeParentIndex(const int index) const {
        return mCorrectionStates[index].mParentIndex;
    }

    class RankingAlgorithm {
     public:
        static int calculateFinalProbability(const int inputIndex, const int depth,
                const int probability, int *editDistanceTable, const Correction *correction,
                const int inputSize);
        static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
                const int wordCount, const Correction *correction, const bool isSpaceProximity,
                const unsigned short *word);
        static float calcNormalizedScore(const unsigned short *before, const int beforeLength,
                const unsigned short *after, const int afterLength, const int score);
        static int editDistance(const unsigned short *before,
                const int beforeLength, const unsigned short *after, const int afterLength);
     private:
        static const int CODE_SPACE = ' ';
        static const int MAX_INITIAL_SCORE = 255;
    };

    // proximity info state
    void initInputParams(const ProximityInfo *proximityInfo, const int32_t *inputCodes,
            const int inputSize, const int *xCoordinates, const int *yCoordinates) {
        mProximityInfoState.initInputParams(0, MAX_POINT_TO_KEY_LENGTH,
                proximityInfo, inputCodes, inputSize, xCoordinates, yCoordinates, 0, 0, false);
    }

    const unsigned short *getPrimaryInputWord() const {
        return mProximityInfoState.getPrimaryInputWord();
    }

    unsigned short getPrimaryCharAt(const int index) const {
        return mProximityInfoState.getPrimaryCharAt(index);
    }

 private:
    DISALLOW_COPY_AND_ASSIGN(Correction);

    /////////////////////////
    // static inline utils //
    /////////////////////////
    static const int TWO_31ST_DIV_255 = S_INT_MAX / 255;
    static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) {
        return (num < TWO_31ST_DIV_255 ? 255 * num : S_INT_MAX);
    }

    static const int TWO_31ST_DIV_2 = S_INT_MAX / 2;
    inline static void multiplyIntCapped(const int multiplier, int *base) {
        const int temp = *base;
        if (temp != S_INT_MAX) {
            // Branch if multiplier == 2 for the optimization
            if (multiplier < 0) {
                if (DEBUG_DICT) {
                    assert(false);
                }
                AKLOGI("--- Invalid multiplier: %d", multiplier);
            } else if (multiplier == 0) {
                *base = 0;
            } else if (multiplier == 2) {
                *base = TWO_31ST_DIV_2 >= temp ? temp << 1 : S_INT_MAX;
            } else {
                // TODO: This overflow check gives a wrong answer when, for example,
                //       temp = 2^16 + 1 and multiplier = 2^17 + 1.
                //       Fix this behavior.
                const int tempRetval = temp * multiplier;
                *base = tempRetval >= temp ? tempRetval : S_INT_MAX;
            }
        }
    }

    inline static int powerIntCapped(const int base, const int n) {
        if (n <= 0) return 1;
        if (base == 2) {
            return n < 31 ? 1 << n : S_INT_MAX;
        } else {
            int ret = base;
            for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret);
            return ret;
        }
    }

    inline static void multiplyRate(const int rate, int *freq) {
        if (*freq != S_INT_MAX) {
            if (*freq > 1000000) {
                *freq /= 100;
                multiplyIntCapped(rate, freq);
            } else {
                multiplyIntCapped(rate, freq);
                *freq /= 100;
            }
        }
    }

    inline int getSpaceProximityPos() const {
        return mSpaceProximityPos;
    }
    inline int getMissingSpacePos() const {
        return mMissingSpacePos;
    }

    inline int getSkipPos() const {
        return mSkipPos;
    }

    inline int getExcessivePos() const {
        return mExcessivePos;
    }

    inline int getTransposedPos() const {
        return mTransposedPos;
    }

    inline void incrementInputIndex();
    inline void incrementOutputIndex();
    inline void startToTraverseAllNodes();
    inline bool isSingleQuote(const unsigned short c);
    inline CorrectionType processSkipChar(
            const int32_t c, const bool isTerminal, const bool inputIndexIncremented);
    inline CorrectionType processUnrelatedCorrectionType();
    inline void addCharToCurrentWord(const int32_t c);
    inline int getFinalProbabilityInternal(const int probability, unsigned short **word,
            int *wordLength, const int inputSize);

    static const int TYPED_LETTER_MULTIPLIER = 2;
    static const int FULL_WORD_MULTIPLIER = 2;
    const ProximityInfo *mProximityInfo;

    bool mUseFullEditDistance;
    bool mDoAutoCompletion;
    int mMaxEditDistance;
    int mMaxDepth;
    int mInputSize;
    int mSpaceProximityPos;
    int mMissingSpacePos;
    int mTerminalInputIndex;
    int mTerminalOutputIndex;
    int mMaxErrors;

    uint8_t mTotalTraverseCount;

    // The following arrays are state buffer.
    unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
    int mDistances[MAX_WORD_LENGTH_INTERNAL];

    // Edit distance calculation requires a buffer with (N+1)^2 length for the input length N.
    // Caveat: Do not create multiple tables per thread as this table eats up RAM a lot.
    int mEditDistanceTable[(MAX_WORD_LENGTH_INTERNAL + 1) * (MAX_WORD_LENGTH_INTERNAL + 1)];

    CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL];

    // The following member variables are being used as cache values of the correction state.
    bool mNeedsToTraverseAllNodes;
    int mOutputIndex;
    int mInputIndex;

    int mEquivalentCharCount;
    int mProximityCount;
    int mExcessiveCount;
    int mTransposedCount;
    int mSkippedCount;

    int mTransposedPos;
    int mExcessivePos;
    int mSkipPos;

    bool mLastCharExceeded;

    bool mMatching;
    bool mProximityMatching;
    bool mAdditionalProximityMatching;
    bool mExceeding;
    bool mTransposing;
    bool mSkipping;
    ProximityInfoState mProximityInfoState;
};
} // namespace latinime
#endif // LATINIME_CORRECTION_H