// Copyright (c) 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Ensures Data Rules are followed according to the specifications.

#include "source/val/validate.h"

#include <cassert>
#include <sstream>
#include <string>

#include "source/diagnostic.h"
#include "source/opcode.h"
#include "source/operand.h"
#include "source/val/instruction.h"
#include "source/val/validation_state.h"

namespace spvtools {
namespace val {
namespace {

// Validates that the number of components in the vector is valid.
// Vector types can only be parameterized as having 2, 3, or 4 components.
// If the Vector16 capability is added, 8 and 16 components are also allowed.
spv_result_t ValidateVecNumComponents(ValidationState_t& _,
                                      const Instruction* inst) {
  // Operand 2 specifies the number of components in the vector.
  auto num_components = inst->GetOperandAs<const uint32_t>(2);
  if (num_components == 2 || num_components == 3 || num_components == 4) {
    return SPV_SUCCESS;
  }
  if (num_components == 8 || num_components == 16) {
    if (_.HasCapability(SpvCapabilityVector16)) {
      return SPV_SUCCESS;
    }
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Having " << num_components << " components for "
           << spvOpcodeString(inst->opcode())
           << " requires the Vector16 capability";
  }
  return _.diag(SPV_ERROR_INVALID_DATA, inst)
         << "Illegal number of components (" << num_components << ") for "
         << spvOpcodeString(inst->opcode());
}

// Validates that the number of bits specifed for a float type is valid.
// Scalar floating-point types can be parameterized only with 32-bits.
// Float16 capability allows using a 16-bit OpTypeFloat.
// Float16Buffer capability allows creation of a 16-bit OpTypeFloat.
// Float64 capability allows using a 64-bit OpTypeFloat.
spv_result_t ValidateFloatSize(ValidationState_t& _, const Instruction* inst) {
  // Operand 1 is the number of bits for this float
  auto num_bits = inst->GetOperandAs<const uint32_t>(1);
  if (num_bits == 32) {
    return SPV_SUCCESS;
  }
  if (num_bits == 16) {
    if (_.features().declare_float16_type) {
      return SPV_SUCCESS;
    }
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Using a 16-bit floating point "
           << "type requires the Float16 or Float16Buffer capability,"
              " or an extension that explicitly enables 16-bit floating point.";
  }
  if (num_bits == 64) {
    if (_.HasCapability(SpvCapabilityFloat64)) {
      return SPV_SUCCESS;
    }
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Using a 64-bit floating point "
           << "type requires the Float64 capability.";
  }
  return _.diag(SPV_ERROR_INVALID_DATA, inst)
         << "Invalid number of bits (" << num_bits << ") used for OpTypeFloat.";
}

// Validates that the number of bits specified for an Int type is valid.
// Scalar integer types can be parameterized only with 32-bits.
// Int8, Int16, and Int64 capabilities allow using 8-bit, 16-bit, and 64-bit
// integers, respectively.
spv_result_t ValidateIntSize(ValidationState_t& _, const Instruction* inst) {
  // Operand 1 is the number of bits for this integer.
  auto num_bits = inst->GetOperandAs<const uint32_t>(1);
  if (num_bits == 32) {
    return SPV_SUCCESS;
  }
  if (num_bits == 8) {
    if (_.features().declare_int8_type) {
      return SPV_SUCCESS;
    }
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Using an 8-bit integer type requires the Int8 capability,"
              " or an extension that explicitly enables 8-bit integers.";
  }
  if (num_bits == 16) {
    if (_.features().declare_int16_type) {
      return SPV_SUCCESS;
    }
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Using a 16-bit integer type requires the Int16 capability,"
              " or an extension that explicitly enables 16-bit integers.";
  }
  if (num_bits == 64) {
    if (_.HasCapability(SpvCapabilityInt64)) {
      return SPV_SUCCESS;
    }
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Using a 64-bit integer type requires the Int64 capability.";
  }
  return _.diag(SPV_ERROR_INVALID_DATA, inst)
         << "Invalid number of bits (" << num_bits << ") used for OpTypeInt.";
}

// Validates that the matrix is parameterized with floating-point types.
spv_result_t ValidateMatrixColumnType(ValidationState_t& _,
                                      const Instruction* inst) {
  // Find the component type of matrix columns (must be vector).
  // Operand 1 is the <id> of the type specified for matrix columns.
  auto type_id = inst->GetOperandAs<const uint32_t>(1);
  auto col_type_instr = _.FindDef(type_id);
  if (col_type_instr->opcode() != SpvOpTypeVector) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "Columns in a matrix must be of type vector.";
  }

  // Trace back once more to find out the type of components in the vector.
  // Operand 1 is the <id> of the type of data in the vector.
  auto comp_type_id =
      col_type_instr->words()[col_type_instr->operands()[1].offset];
  auto comp_type_instruction = _.FindDef(comp_type_id);
  if (comp_type_instruction->opcode() != SpvOpTypeFloat) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be "
                                                   "parameterized with "
                                                   "floating-point types.";
  }
  return SPV_SUCCESS;
}

// Validates that the matrix has 2,3, or 4 columns.
spv_result_t ValidateMatrixNumCols(ValidationState_t& _,
                                   const Instruction* inst) {
  // Operand 2 is the number of columns in the matrix.
  auto num_cols = inst->GetOperandAs<const uint32_t>(2);
  if (num_cols != 2 && num_cols != 3 && num_cols != 4) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Matrix types can only be "
                                                   "parameterized as having "
                                                   "only 2, 3, or 4 columns.";
  }
  return SPV_SUCCESS;
}

// Validates that OpSpecConstant specializes to either int or float type.
spv_result_t ValidateSpecConstNumerical(ValidationState_t& _,
                                        const Instruction* inst) {
  // Operand 0 is the <id> of the type that we're specializing to.
  auto type_id = inst->GetOperandAs<const uint32_t>(0);
  auto type_instruction = _.FindDef(type_id);
  auto type_opcode = type_instruction->opcode();
  if (type_opcode != SpvOpTypeInt && type_opcode != SpvOpTypeFloat) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Specialization constant "
                                                   "must be an integer or "
                                                   "floating-point number.";
  }
  return SPV_SUCCESS;
}

// Validates that OpSpecConstantTrue and OpSpecConstantFalse specialize to bool.
spv_result_t ValidateSpecConstBoolean(ValidationState_t& _,
                                      const Instruction* inst) {
  // Find out the type that we're specializing to.
  auto type_instruction = _.FindDef(inst->type_id());
  if (type_instruction->opcode() != SpvOpTypeBool) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "Specialization constant must be a boolean type.";
  }
  return SPV_SUCCESS;
}

// Records the <id> of the forward pointer to be used for validation.
spv_result_t ValidateForwardPointer(ValidationState_t& _,
                                    const Instruction* inst) {
  // Record the <id> (which is operand 0) to ensure it's used properly.
  // OpTypeStruct can only include undefined pointers that are
  // previously declared as a ForwardPointer
  return (_.RegisterForwardPointer(inst->GetOperandAs<uint32_t>(0)));
}

// Validates that any undefined component of the struct is a forward pointer.
// It is valid to declare a forward pointer, and use its <id> as one of the
// components of a struct.
spv_result_t ValidateStruct(ValidationState_t& _, const Instruction* inst) {
  // Struct components are operands 1, 2, etc.
  for (unsigned i = 1; i < inst->operands().size(); i++) {
    auto type_id = inst->GetOperandAs<const uint32_t>(i);
    auto type_instruction = _.FindDef(type_id);
    if (type_instruction == nullptr && !_.IsForwardPointer(type_id)) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << "Forward reference operands in an OpTypeStruct must first be "
                "declared using OpTypeForwardPointer.";
    }
  }
  return SPV_SUCCESS;
}

}  // namespace

// Validates that Data Rules are followed according to the specifications.
// (Data Rules subsection of 2.16.1 Universal Validation Rules)
spv_result_t DataRulesPass(ValidationState_t& _, const Instruction* inst) {
  switch (inst->opcode()) {
    case SpvOpTypeVector: {
      if (auto error = ValidateVecNumComponents(_, inst)) return error;
      break;
    }
    case SpvOpTypeFloat: {
      if (auto error = ValidateFloatSize(_, inst)) return error;
      break;
    }
    case SpvOpTypeInt: {
      if (auto error = ValidateIntSize(_, inst)) return error;
      break;
    }
    case SpvOpTypeMatrix: {
      if (auto error = ValidateMatrixColumnType(_, inst)) return error;
      if (auto error = ValidateMatrixNumCols(_, inst)) return error;
      break;
    }
    // TODO(ehsan): Add OpSpecConstantComposite validation code.
    // TODO(ehsan): Add OpSpecConstantOp validation code (if any).
    case SpvOpSpecConstant: {
      if (auto error = ValidateSpecConstNumerical(_, inst)) return error;
      break;
    }
    case SpvOpSpecConstantFalse:
    case SpvOpSpecConstantTrue: {
      if (auto error = ValidateSpecConstBoolean(_, inst)) return error;
      break;
    }
    case SpvOpTypeForwardPointer: {
      if (auto error = ValidateForwardPointer(_, inst)) return error;
      break;
    }
    case SpvOpTypeStruct: {
      if (auto error = ValidateStruct(_, inst)) return error;
      break;
    }
    // TODO(ehsan): add more data rules validation here.
    default: { break; }
  }

  return SPV_SUCCESS;
}

}  // namespace val
}  // namespace spvtools