/*
* Copyright (C) 2011 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LATINIME_CORRECTION_H
#define LATINIME_CORRECTION_H
#include <cassert>
#include <cstring> // for memset()
#include <stdint.h>
#include "correction_state.h"
#include "defines.h"
#include "proximity_info_state.h"
namespace latinime {
class ProximityInfo;
/**
 * Working state for matching typed input against candidate words.
 *
 * An instance owns the output word buffer, the edit distance table, one CorrectionState per
 * depth, cached counters mirroring the current correction state, and a ProximityInfoState.
 * It also provides private saturating integer helpers (capped at S_INT_MAX) for score
 * computation. Copying is disallowed via DISALLOW_COPY_AND_ASSIGN.
 */
class Correction {
 public:
    // Result of processing one character; returned by processCharAndCalcState() and the
    // private process*() helpers.
    typedef enum {
        TRAVERSE_ALL_ON_TERMINAL,
        TRAVERSE_ALL_NOT_ON_TERMINAL,
        UNRELATED,
        ON_TERMINAL,
        NOT_ON_TERMINAL
    } CorrectionType;

    // Zero-initializes every scalar member and clears the word, distance and edit distance
    // buffers.
    Correction()
            : mProximityInfo(0), mUseFullEditDistance(false), mDoAutoCompletion(false),
              mMaxEditDistance(0), mMaxDepth(0), mInputSize(0), mSpaceProximityPos(0),
              mMissingSpacePos(0), mTerminalInputIndex(0), mTerminalOutputIndex(0), mMaxErrors(0),
              mTotalTraverseCount(0), mNeedsToTraverseAllNodes(false), mOutputIndex(0),
              mInputIndex(0), mEquivalentCharCount(0), mProximityCount(0), mExcessiveCount(0),
              mTransposedCount(0), mSkippedCount(0), mTransposedPos(0), mExcessivePos(0),
              mSkipPos(0), mLastCharExceeded(false), mMatching(false), mProximityMatching(false),
              mAdditionalProximityMatching(false), mExceeding(false), mTransposing(false),
              mSkipping(false), mProximityInfoState() {
        memset(mWord, 0, sizeof(mWord));
        memset(mDistances, 0, sizeof(mDistances));
        memset(mEditDistanceTable, 0, sizeof(mEditDistanceTable));
        // NOTE: mCorrectionStates is an array of instances.
        // No need to initialize it explicitly here.
    }

    virtual ~Correction() {}

    void resetCorrection();
    void initCorrection(
            const ProximityInfo *pi, const int inputSize, const int maxWordLength);
    void initCorrectionState(const int rootPos, const int childCount, const bool traverseAll);

    // TODO: remove
    void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos,
            const int spaceProximityPos, const int missingSpacePos, const bool useFullEditDistance,
            const bool doAutoCompletion, const int maxErrors);
    void checkState();
    bool sameAsTyped();
    bool initProcessState(const int index);

    int getInputIndex() const;

    bool needsToPrune() const;

    // Increments the total traverse counter and returns the updated value.
    int pushAndGetTotalTraverseCount() {
        return ++mTotalTraverseCount;
    }

    int getFreqForSplitMultipleWords(
            const int *freqArray, const int *wordLengthArray, const int wordCount,
            const bool isSpaceProximity, const unsigned short *word);
    int getFinalProbability(const int probability, unsigned short **word, int *wordLength);
    int getFinalProbabilityForSubQueue(const int probability, unsigned short **word,
            int *wordLength, const int inputSize);

    CorrectionType processCharAndCalcState(const int32_t c, const bool isTerminal);

    /////////////////////////
    // Tree helper methods
    int goDownTree(const int parentIndex, const int childCount, const int firstChildPos);

    // Returns the sibling position stored in the correction state at |index|.
    // |index| is not range-checked here.
    inline int getTreeSiblingPos(const int index) const {
        return mCorrectionStates[index].mSiblingPos;
    }

    // Stores |pos| as the sibling position of the correction state at |index|.
    // |index| is not range-checked here.
    inline void setTreeSiblingPos(const int index, const int pos) {
        mCorrectionStates[index].mSiblingPos = pos;
    }

    // Returns the parent index stored in the correction state at |index|.
    // |index| is not range-checked here.
    inline int getTreeParentIndex(const int index) const {
        return mCorrectionStates[index].mParentIndex;
    }

    // Static scoring routines; the implementations are not part of this header.
    class RankingAlgorithm {
     public:
        static int calculateFinalProbability(const int inputIndex, const int depth,
                const int probability, int *editDistanceTable, const Correction *correction,
                const int inputSize);
        static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
                const int wordCount, const Correction *correction, const bool isSpaceProximity,
                const unsigned short *word);
        static float calcNormalizedScore(const unsigned short *before, const int beforeLength,
                const unsigned short *after, const int afterLength, const int score);
        static int editDistance(const unsigned short *before,
                const int beforeLength, const unsigned short *after, const int afterLength);
     private:
        static const int CODE_SPACE = ' ';
        static const int MAX_INITIAL_SCORE = 255;
    };

    // proximity info state
    // Forwards the input description to mProximityInfoState. The literal arguments
    // (0, MAX_POINT_TO_KEY_LENGTH, ..., 0, 0, false) are fixed by this wrapper; their meaning
    // is defined by ProximityInfoState::initInputParams().
    void initInputParams(const ProximityInfo *proximityInfo, const int32_t *inputCodes,
            const int inputSize, const int *xCoordinates, const int *yCoordinates) {
        mProximityInfoState.initInputParams(0, MAX_POINT_TO_KEY_LENGTH,
                proximityInfo, inputCodes, inputSize, xCoordinates, yCoordinates, 0, 0, false);
    }

    // Delegates to ProximityInfoState::getPrimaryInputWord().
    const unsigned short *getPrimaryInputWord() const {
        return mProximityInfoState.getPrimaryInputWord();
    }

    // Delegates to ProximityInfoState::getPrimaryCharAt().
    unsigned short getPrimaryCharAt(const int index) const {
        return mProximityInfoState.getPrimaryCharAt(index);
    }

 private:
    DISALLOW_COPY_AND_ASSIGN(Correction);

    /////////////////////////
    // static inline utils //
    /////////////////////////
    static const int TWO_31ST_DIV_255 = S_INT_MAX / 255;

    // Returns 255 * num, or S_INT_MAX when num is not below S_INT_MAX / 255.
    static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) {
        return (num < TWO_31ST_DIV_255 ? 255 * num : S_INT_MAX);
    }

    static const int TWO_31ST_DIV_2 = S_INT_MAX / 2;

    // Multiplies *base by |multiplier| in place, saturating at S_INT_MAX.
    //  - *base already equal to S_INT_MAX is left untouched.
    //  - A negative multiplier is rejected: it is logged (and asserts when DEBUG_DICT is set)
    //    and *base is left unchanged.
    //  - A zero multiplier sets *base to 0.
    inline static void multiplyIntCapped(const int multiplier, int *base) {
        const int temp = *base;
        if (temp != S_INT_MAX) {
            if (multiplier < 0) {
                if (DEBUG_DICT) {
                    assert(false);
                }
                AKLOGI("--- Invalid multiplier: %d", multiplier);
            } else if (multiplier == 0) {
                *base = 0;
            } else if (multiplier == 2) {
                // Branch if multiplier == 2 for the optimization
                *base = TWO_31ST_DIV_2 >= temp ? temp << 1 : S_INT_MAX;
            } else {
                // Compute the product in 64 bits: a 32-bit signed multiply overflows with
                // undefined behavior, and a wrapped result can still compare >= temp
                // (e.g. temp = 2^16 + 1, multiplier = 2^17 + 1), defeating the cap.
                // The "product >= temp" test is kept so that inputs which do not overflow
                // produce exactly the same result as before.
                const int64_t product = static_cast<int64_t>(temp) * multiplier;
                *base = (product >= temp && product <= S_INT_MAX)
                        ? static_cast<int>(product) : S_INT_MAX;
            }
        }
    }

    // Returns |base| raised to the power |n| using multiplyIntCapped(), so the result
    // saturates at S_INT_MAX. Returns 1 when n <= 0. Powers of two use a shift.
    inline static int powerIntCapped(const int base, const int n) {
        if (n <= 0) return 1;
        if (base == 2) {
            // 1 << 31 does not fit in a signed 32-bit int, hence the n < 31 guard.
            return n < 31 ? 1 << n : S_INT_MAX;
        } else {
            int ret = base;
            for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret);
            return ret;
        }
    }

    // Scales *freq by rate / 100 in place (i.e. |rate| acts as a percentage).
    // *freq equal to S_INT_MAX is left untouched. For values above 1000000 the division is
    // done before the multiplication; otherwise the multiplication is done first.
    inline static void multiplyRate(const int rate, int *freq) {
        if (*freq != S_INT_MAX) {
            if (*freq > 1000000) {
                *freq /= 100;
                multiplyIntCapped(rate, freq);
            } else {
                multiplyIntCapped(rate, freq);
                *freq /= 100;
            }
        }
    }

    inline int getSpaceProximityPos() const {
        return mSpaceProximityPos;
    }
    inline int getMissingSpacePos() const {
        return mMissingSpacePos;
    }

    inline int getSkipPos() const {
        return mSkipPos;
    }

    inline int getExcessivePos() const {
        return mExcessivePos;
    }

    inline int getTransposedPos() const {
        return mTransposedPos;
    }

    inline void incrementInputIndex();
    inline void incrementOutputIndex();
    inline void startToTraverseAllNodes();
    inline bool isSingleQuote(const unsigned short c);
    inline CorrectionType processSkipChar(
            const int32_t c, const bool isTerminal, const bool inputIndexIncremented);
    inline CorrectionType processUnrelatedCorrectionType();
    inline void addCharToCurrentWord(const int32_t c);
    inline int getFinalProbabilityInternal(const int probability, unsigned short **word,
            int *wordLength, const int inputSize);

    static const int TYPED_LETTER_MULTIPLIER = 2;
    static const int FULL_WORD_MULTIPLIER = 2;
    const ProximityInfo *mProximityInfo;

    bool mUseFullEditDistance;
    bool mDoAutoCompletion;
    int mMaxEditDistance;
    int mMaxDepth;
    int mInputSize;
    int mSpaceProximityPos;
    int mMissingSpacePos;
    int mTerminalInputIndex;
    int mTerminalOutputIndex;
    int mMaxErrors;

    // Plain int, matching the return type of pushAndGetTotalTraverseCount(); an 8-bit
    // counter would silently wrap back to 0 after 255 increments.
    int mTotalTraverseCount;

    // The following arrays are state buffer.
    unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
    int mDistances[MAX_WORD_LENGTH_INTERNAL];

    // Edit distance calculation requires a buffer with (N+1)^2 length for the input length N.
    // Caveat: Do not create multiple tables per thread as this table eats up RAM a lot.
    int mEditDistanceTable[(MAX_WORD_LENGTH_INTERNAL + 1) * (MAX_WORD_LENGTH_INTERNAL + 1)];

    CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL];

    // The following member variables are being used as cache values of the correction state.
    bool mNeedsToTraverseAllNodes;
    int mOutputIndex;
    int mInputIndex;

    int mEquivalentCharCount;
    int mProximityCount;
    int mExcessiveCount;
    int mTransposedCount;
    int mSkippedCount;

    int mTransposedPos;
    int mExcessivePos;
    int mSkipPos;

    bool mLastCharExceeded;

    bool mMatching;
    bool mProximityMatching;
    bool mAdditionalProximityMatching;
    bool mExceeding;
    bool mTransposing;
    bool mSkipping;
    ProximityInfoState mProximityInfoState;
};
} // namespace latinime
#endif // LATINIME_CORRECTION_H