#ifndef _RSGBUILTINFUNCTIONS_HPP
#define _RSGBUILTINFUNCTIONS_HPP
/*-------------------------------------------------------------------------
 * drawElements Quality Program Random Shader Generator
 * ----------------------------------------------------
 *
 * Copyright 2014 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 *//*!
 * \file
 * \brief Built-in Functions.
 *//*--------------------------------------------------------------------*/

#include "rsgDefs.hpp"
#include "rsgExpression.hpp"
#include "rsgUtils.hpp"
#include "deMath.h"

namespace rsg
{

// Template for built-in functions with form "GenType func(GenType val)".
template <class GetValueRangeWeight, class ComputeValueRange, class Evaluate>
class UnaryBuiltinVecFunc : public Expression
{
public:
								UnaryBuiltinVecFunc		(GeneratorState& state, const char* function, ConstValueRangeAccess valueRange);
	virtual						~UnaryBuiltinVecFunc	(void);

	Expression*					createNextChild			(GeneratorState& state);
	void						tokenize				(GeneratorState& state, TokenStream& str) const;

	void						evaluate				(ExecutionContext& execCtx);
	ExecConstValueAccess		getValue				(void) const { return m_value.getValue(m_inValueRange.getType()); }

	static float				getWeight				(const GeneratorState& state, ConstValueRangeAccess valueRange);

private:
	std::string					m_function;
	ValueRange					m_inValueRange;
	ExecValueStorage			m_value;
	Expression*					m_child;
};

template <class GetValueRangeWeight, class ComputeValueRange, class Evaluate>
UnaryBuiltinVecFunc<GetValueRangeWeight, ComputeValueRange, Evaluate>::UnaryBuiltinVecFunc (GeneratorState& state, const char* function, ConstValueRangeAccess valueRange)
	: m_function		(function)
	, m_inValueRange	(valueRange.getType())
	, m_child			(DE_NULL)
{
	DE_UNREF(state);
	DE_ASSERT(valueRange.getType().isFloatOrVec());

	m_value.setStorage(valueRange.getType());

	// Compute input value range
	for (int ndx = 0; ndx < m_inValueRange.getType().getNumElements(); ndx++)
	{
		ConstValueRangeAccess	outRange	= valueRange.component(ndx);
		ValueRangeAccess		inRange		= m_inValueRange.asAccess().component(ndx);

		ComputeValueRange()(outRange.getMin().asFloat(), outRange.getMax().asFloat(), inRange.getMin().asFloat(), inRange.getMax().asFloat());
	}
}

template <class GetValueRangeWeight, class ComputeValueRange, class Evaluate>
UnaryBuiltinVecFunc<GetValueRangeWeight, ComputeValueRange, Evaluate>::~UnaryBuiltinVecFunc (void)
{
	delete m_child;
}

template <class GetValueRangeWeight, class ComputeValueRange, class Evaluate>
Expression* UnaryBuiltinVecFunc<GetValueRangeWeight, ComputeValueRange, Evaluate>::createNextChild (GeneratorState& state)
{
	if (m_child)
		return DE_NULL;

	m_child = Expression::createRandom(state, m_inValueRange);
	return m_child;
}

template <class GetValueRangeWeight, class ComputeValueRange, class Evaluate>
void UnaryBuiltinVecFunc<GetValueRangeWeight, ComputeValueRange, Evaluate>::tokenize (GeneratorState& state, TokenStream& str) const
{
	str << Token(m_function.c_str()) << Token::LEFT_PAREN;
	m_child->tokenize(state, str);
	str << Token::RIGHT_PAREN;
}

template <class GetValueRangeWeight, class ComputeValueRange, class Evaluate>
void UnaryBuiltinVecFunc<GetValueRangeWeight, ComputeValueRange, Evaluate>::evaluate (ExecutionContext& execCtx)
{
	m_child->evaluate(execCtx);

	ExecConstValueAccess	srcValue	= m_child->getValue();
	ExecValueAccess			dstValue	= m_value.getValue(m_inValueRange.getType());

	for (int elemNdx = 0; elemNdx < m_inValueRange.getType().getNumElements(); elemNdx++)
	{
		ExecConstValueAccess	srcComp		= srcValue.component(elemNdx);
		ExecValueAccess			dstComp		= dstValue.component(elemNdx);

		for (int compNdx = 0; compNdx < EXEC_VEC_WIDTH; compNdx++)
			dstComp.asFloat(compNdx) = Evaluate()(srcComp.asFloat(compNdx));
	}
}

template <class GetValueRangeWeight, class ComputeValueRange, class Evaluate>
float UnaryBuiltinVecFunc<GetValueRangeWeight, ComputeValueRange, Evaluate>::getWeight (const GeneratorState& state, ConstValueRangeAccess valueRange)
{
	// \todo [2011-06-14 pyry] Void support?
	if (!valueRange.getType().isFloatOrVec())
		return 0.0f;

	int availableLevels = state.getShaderParameters().maxExpressionDepth - state.getExpressionDepth();

	if (availableLevels < getConservativeValueExprDepth(state, valueRange) + 1)
		return 0.0f;

	// Compute value range weight
	float combinedWeight = 1.0f;
	for (int elemNdx = 0; elemNdx < valueRange.getType().getNumElements(); elemNdx++)
	{
		float elemWeight = GetValueRangeWeight()(valueRange.component(elemNdx).getMin().asFloat(), valueRange.component(elemNdx).getMax().asFloat());
		combinedWeight *= elemWeight;
	}

	return combinedWeight;
}

// Proxy template.
template <class C>
struct GetUnaryBuiltinVecWeight
{
	inline float operator() (float outMin, float outMax) const { return C::getCompWeight(outMin, outMax); }
};

template <class C>
struct ComputeUnaryBuiltinVecRange
{
	inline void operator() (float outMin, float outMax, float& inMin, float& inMax) const { C::computeValueRange(outMin, outMax, inMin, inMax); }
};

template <class C>
struct EvaluateUnaryBuiltinVec
{
	inline float operator() (float inVal) const { return C::evaluateComp(inVal); }
};

template <class C>
class UnaryBuiltinVecTemplateProxy : public UnaryBuiltinVecFunc<GetUnaryBuiltinVecWeight<C>, ComputeUnaryBuiltinVecRange<C>, EvaluateUnaryBuiltinVec<C> >
{
public:
	UnaryBuiltinVecTemplateProxy (GeneratorState& state, const char* function, ConstValueRangeAccess valueRange)
		: UnaryBuiltinVecFunc<GetUnaryBuiltinVecWeight<C>, ComputeUnaryBuiltinVecRange<C>, EvaluateUnaryBuiltinVec<C> >(state, function, valueRange)
	{
	}
};

// Template for trigonometric function group.
template <class C>
class UnaryTrigonometricFunc : public UnaryBuiltinVecTemplateProxy<C>
{
public:
	UnaryTrigonometricFunc (GeneratorState& state, const char* function, ConstValueRangeAccess valueRange)
		: UnaryBuiltinVecTemplateProxy<C>(state, function, valueRange)
	{
	}

	static inline float getCompWeight (float outMin, float outMax)
	{
		if (Scalar::min<float>() == outMin || Scalar::max<float>() == outMax)
			return 1.0f; // Infinite value range, anything goes

		// Transform range
		float inMin, inMax;
		if (!C::transformValueRange(outMin, outMax, inMin, inMax))
			return 0.0f; // Not possible to transform value range (out of range perhaps)

		// Quantize
		if (!quantizeFloatRange(inMin, inMax))
			return 0.0f; // Not possible to quantize - would cause accuracy issues

		if (outMin == outMax)
			return 1.0f; // Constant value and passed quantization

		// Evaluate new intersection
		float intersectionLen	= C::evaluateComp(inMax) - C::evaluateComp(inMin);
		float valRangeLen		= outMax - outMin;

		return deFloatMax(0.1f, intersectionLen/valRangeLen);
	}

	static inline void computeValueRange (float outMin, float outMax, float& inMin, float& inMax)
	{
		DE_VERIFY(C::transformValueRange(outMin, outMax, inMin, inMax));
		DE_VERIFY(quantizeFloatRange(inMin, inMax));
		DE_ASSERT(inMin <= inMax);
	}

	static float getWeight (const GeneratorState& state, ConstValueRangeAccess valueRange)
	{
		if (state.getProgramParameters().trigonometricBaseWeight <= 0.0f)
			return 0.0f;

		return UnaryBuiltinVecTemplateProxy<C>::getWeight(state, valueRange) * state.getProgramParameters().trigonometricBaseWeight;
	}
};

class SinOp : public UnaryTrigonometricFunc<SinOp>
{
public:
	SinOp (GeneratorState& state, ConstValueRangeAccess valueRange)
		: UnaryTrigonometricFunc<SinOp>(state, "sin", valueRange)
	{
	}

	static inline bool transformValueRange (float outMin, float outMax, float& inMin, float& inMax)
	{
		if (outMax < -1.0f || outMin > 1.0f)
			return false;

		inMin = (outMin >= -1.0f) ? deFloatAsin(outMin) : -0.5f*DE_PI;
		inMax = (outMax <= +1.0f) ? deFloatAsin(outMax) : +0.5f*DE_PI;

		return true;
	}

	static inline float evaluateComp (float inVal)
	{
		return deFloatSin(inVal);
	}
};

class CosOp : public UnaryTrigonometricFunc<CosOp>
{
public:
	CosOp (GeneratorState& state, ConstValueRangeAccess valueRange)
		: UnaryTrigonometricFunc<CosOp>(state, "cos", valueRange)
	{
	}

	static inline bool transformValueRange (float outMin, float outMax, float& inMin, float& inMax)
	{
		if (outMax < -1.0f || outMin > 1.0f)
			return false;

		inMax = (outMin >= -1.0f) ? deFloatAcos(outMin) : +DE_PI;
		inMin = (outMax <= +1.0f) ? deFloatAcos(outMax) : -DE_PI;

		return true;
	}

	static inline float evaluateComp (float inVal)
	{
		return deFloatCos(inVal);
	}
};

class TanOp : public UnaryTrigonometricFunc<TanOp>
{
public:
	TanOp (GeneratorState& state, ConstValueRangeAccess valueRange)
		: UnaryTrigonometricFunc<TanOp>(state, "tan", valueRange)
	{
	}

	static inline bool transformValueRange (float outMin, float outMax, float& inMin, float& inMax)
	{
		// \note Currently tan() is limited to -4..4 range. Otherwise we will run into accuracy issues
		const float rangeMin = -4.0f;
		const float rangeMax = +4.0f;

		if (outMax < rangeMin || outMin > rangeMax)
			return false;

		inMin = deFloatAtanOver(deFloatMax(outMin, rangeMin));
		inMax = deFloatAtanOver(deFloatMin(outMax, rangeMax));

		return true;
	}

	static inline float evaluateComp (float inVal)
	{
		return deFloatTan(inVal);
	}
};

class AsinOp : public UnaryTrigonometricFunc<AsinOp>
{
public:
	AsinOp (GeneratorState& state, ConstValueRangeAccess valueRange)
		: UnaryTrigonometricFunc<AsinOp>(state, "asin", valueRange)
	{
	}

	static inline bool transformValueRange (float outMin, float outMax, float& inMin, float& inMax)
	{
		const float rangeMin = -DE_PI/2.0f;
		const float rangeMax = +DE_PI/2.0f;

		if (outMax < rangeMin || outMin > rangeMax)
			return false; // Out of range

		inMin = deFloatSin(deFloatMax(outMin, rangeMin));
		inMax = deFloatSin(deFloatMin(outMax, rangeMax));

		return true;
	}

	static inline float evaluateComp (float inVal)
	{
		return deFloatAsin(inVal);
	}
};

class AcosOp : public UnaryTrigonometricFunc<AcosOp>
{
public:
	AcosOp (GeneratorState& state, ConstValueRangeAccess valueRange)
		: UnaryTrigonometricFunc<AcosOp>(state, "acos", valueRange)
	{
	}

	static inline bool transformValueRange (float outMin, float outMax, float& inMin, float& inMax)
	{
		const float rangeMin = 0.0f;
		const float rangeMax = DE_PI;

		if (outMax < rangeMin || outMin > rangeMax)
			return false; // Out of range

		inMax = deFloatCos(deFloatMax(outMin, rangeMin));
		inMin = deFloatCos(deFloatMin(outMax, rangeMax));

		return true;
	}

	static inline float evaluateComp (float inVal)
	{
		return deFloatAcos(inVal);
	}
};

class AtanOp : public UnaryTrigonometricFunc<AtanOp>
{
public:
	AtanOp (GeneratorState& state, ConstValueRangeAccess valueRange)
		: UnaryTrigonometricFunc<AtanOp>(state, "atan", valueRange)
	{
	}

	static inline bool transformValueRange (float outMin, float outMax, float& inMin, float& inMax)
	{
		// \note For accuracy reasons output range is limited to -1..1
		const float rangeMin = -1.0f;
		const float rangeMax = +1.0f;

		if (outMax < rangeMin || outMin > rangeMax)
			return false; // Out of range

		inMin = deFloatTan(deFloatMax(outMin, rangeMin));
		inMax = deFloatTan(deFloatMin(outMax, rangeMax));

		return true;
	}

	static inline float evaluateComp (float inVal)
	{
		return deFloatAtanOver(inVal);
	}
};

// Template for exponential function group.
// \todo [2011-07-07 pyry] Shares most of the code with Trigonometric variant..
template <class C>
class UnaryExponentialFunc : public UnaryBuiltinVecTemplateProxy<C>
{
public:
	UnaryExponentialFunc (GeneratorState& state, const char* function, ConstValueRangeAccess valueRange)
		: UnaryBuiltinVecTemplateProxy<C>(state, function, valueRange)
	{
	}

	static inline float getCompWeight (float outMin, float outMax)
	{
		if (Scalar::min<float>() == outMin || Scalar::max<float>() == outMax)
			return 1.0f; // Infinite value range, anything goes

		// Transform range
		float inMin, inMax;
		if (!C::transformValueRange(outMin, outMax, inMin, inMax))
			return 0.0f; // Not possible to transform value range (out of range perhaps)

		// Quantize
		if (!quantizeFloatRange(inMin, inMax))
			return 0.0f; // Not possible to quantize - would cause accuracy issues

		if (outMin == outMax)
			return 1.0f; // Constant value and passed quantization

		// Evaluate new intersection
		float intersectionLen	= C::evaluateComp(inMax) - C::evaluateComp(inMin);
		float valRangeLen		= outMax - outMin;

		return deFloatMax(0.1f, intersectionLen/valRangeLen);
	}

	static inline void computeValueRange (float outMin, float outMax, float& inMin, float& inMax)
	{
		DE_VERIFY(C::transformValueRange(outMin, outMax, inMin, inMax));
		DE_VERIFY(quantizeFloatRange(inMin, inMax));
		DE_ASSERT(inMin <= inMax);
	}

	static float getWeight (const GeneratorState& state, ConstValueRangeAccess valueRange)
	{
		if (state.getProgramParameters().exponentialBaseWeight <= 0.0f)
			return 0.0f;

		return UnaryBuiltinVecTemplateProxy<C>::getWeight(state, valueRange) * state.getProgramParameters().exponentialBaseWeight;
	}
};

class ExpOp : public UnaryExponentialFunc<ExpOp>
{
public:
	ExpOp (GeneratorState& state, ConstValueRangeAccess valueRange)
		: UnaryExponentialFunc<ExpOp>(state, "exp", valueRange)
	{
	}

	static inline bool transformValueRange (float outMin, float outMax, float& inMin, float& inMax)
	{
		// Limited due to accuracy reasons, should be 0..+inf
		const float rangeMin = 0.1f;
		const float rangeMax = 10.0f;

		if (outMax < rangeMin || outMin > rangeMax)
			return false; // Out of range

		inMin = deFloatLog(deFloatMax(outMin, rangeMin));
		inMax = deFloatLog(deFloatMin(outMax, rangeMax));

		return true;
	}

	static inline float evaluateComp (float inVal)
	{
		return deFloatExp(inVal);
	}
};

class LogOp : public UnaryExponentialFunc<LogOp>
{
public:
	LogOp (GeneratorState& state, ConstValueRangeAccess valueRange)
		: UnaryExponentialFunc<LogOp>(state, "log", valueRange)
	{
	}

	static inline bool transformValueRange (float outMin, float outMax, float& inMin, float& inMax)
	{
		// Limited due to accuracy reasons, should be -inf..+inf
		const float rangeMin = 0.1f;
		const float rangeMax = 6.0f;

		if (outMax < rangeMin || outMin > rangeMax)
			return false; // Out of range

		inMin = deFloatExp(deFloatMax(outMin, rangeMin));
		inMax = deFloatExp(deFloatMin(outMax, rangeMax));

		return true;
	}

	static inline float evaluateComp (float inVal)
	{
		return deFloatLog(inVal);
	}
};

class Exp2Op : public UnaryExponentialFunc<Exp2Op>
{
public:
	Exp2Op (GeneratorState& state, ConstValueRangeAccess valueRange)
		: UnaryExponentialFunc<Exp2Op>(state, "exp2", valueRange)
	{
	}

	static inline bool transformValueRange (float outMin, float outMax, float& inMin, float& inMax)
	{
		// Limited due to accuracy reasons, should be 0..+inf
		const float rangeMin = 0.1f;
		const float rangeMax = 10.0f;

		if (outMax < rangeMin || outMin > rangeMax)
			return false; // Out of range

		inMin = deFloatLog2(deFloatMax(outMin, rangeMin));
		inMax = deFloatLog2(deFloatMin(outMax, rangeMax));

		return true;
	}

	static inline float evaluateComp (float inVal)
	{
		return deFloatExp2(inVal);
	}
};

class Log2Op : public UnaryExponentialFunc<Log2Op>
{
public:
	Log2Op (GeneratorState& state, ConstValueRangeAccess valueRange)
		: UnaryExponentialFunc<Log2Op>(state, "log2", valueRange)
	{
	}

	static inline bool transformValueRange (float outMin, float outMax, float& inMin, float& inMax)
	{
		// Limited due to accuracy reasons, should be -inf..+inf
		const float rangeMin = 0.1f;
		const float rangeMax = 6.0f;

		if (outMax < rangeMin || outMin > rangeMax)
			return false; // Out of range

		inMin = deFloatExp2(deFloatMax(outMin, rangeMin));
		inMax = deFloatExp2(deFloatMin(outMax, rangeMax));

		return true;
	}

	static inline float evaluateComp (float inVal)
	{
		return deFloatLog2(inVal);
	}
};

class SqrtOp : public UnaryExponentialFunc<SqrtOp>
{
public:
	SqrtOp (GeneratorState& state, ConstValueRangeAccess valueRange)
		: UnaryExponentialFunc<SqrtOp>(state, "sqrt", valueRange)
	{
	}

	static inline bool transformValueRange (float outMin, float outMax, float& inMin, float& inMax)
	{
		// Limited due to accuracy reasons, should be 0..+inf
		const float rangeMin = 0.0f;
		const float rangeMax = 4.0f;

		if (outMax < rangeMin || outMin > rangeMax)
			return false; // Out of range

		inMin = deFloatMax(outMin, rangeMin);
		inMax = deFloatMin(outMax, rangeMax);

		inMin *= inMin;
		inMax *= inMax;

		return true;
	}

	static inline float evaluateComp (float inVal)
	{
		return deFloatSqrt(inVal);
	}
};

class InvSqrtOp : public UnaryExponentialFunc<InvSqrtOp>
{
public:
	InvSqrtOp (GeneratorState& state, ConstValueRangeAccess valueRange)
		: UnaryExponentialFunc<InvSqrtOp>(state, "inversesqrt", valueRange)
	{
	}

	static inline bool transformValueRange (float outMin, float outMax, float& inMin, float& inMax)
	{
		// Limited due to accuracy reasons
		const float rangeMin = 0.4f;
		const float rangeMax = 3.0f;

		if (outMax < rangeMin || outMin > rangeMax)
			return false; // Out of range

		inMax = 1.0f/deFloatMax(outMin, rangeMin);
		inMin = 1.0f/deFloatMin(outMax, rangeMax);

		inMin *= inMin;
		inMax *= inMax;

		return true;
	}

	static inline float evaluateComp (float inVal)
	{
		return 1.0f/deFloatSqrt(inVal);
	}
};

} // rsg

#endif // _RSGBUILTINFUNCTIONS_HPP