// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "source/val/validate.h"
#include "source/opcode.h"
#include "source/val/instruction.h"
#include "source/val/validation_state.h"
namespace spvtools {
namespace val {
namespace {
// Validates OpConstantTrue / OpConstantFalse / OpSpecConstantTrue /
// OpSpecConstantFalse: the Result Type <id> must name an OpTypeBool
// definition. Returns SPV_SUCCESS or an SPV_ERROR_INVALID_ID diagnostic.
spv_result_t ValidateConstantBool(ValidationState_t& _,
                                  const Instruction* inst) {
  const auto result_type_id = inst->type_id();
  const auto declared_type = _.FindDef(result_type_id);
  const bool names_bool_type =
      declared_type && declared_type->opcode() == SpvOpTypeBool;
  if (names_bool_type) return SPV_SUCCESS;

  return _.diag(SPV_ERROR_INVALID_ID, inst)
         << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> '"
         << _.getIdName(result_type_id) << "' is not a boolean type.";
}
// Helper for ValidateConstantComposite when the Result Type is OpTypeVector:
// the constituent count must equal the vector's component count and every
// constituent must be a constant (or undef) whose type <id> is exactly the
// vector's component type <id>.
spv_result_t ValidateConstantCompositeVector(ValidationState_t& _,
                                             const Instruction* inst,
                                             const Instruction* result_type,
                                             const std::string& opcode_name,
                                             size_t constituent_count) {
  // OpTypeVector operands: 0 = Result <id>, 1 = Component Type,
  // 2 = Component Count.
  const auto component_count = result_type->GetOperandAs<uint32_t>(2);
  if (component_count != constituent_count) {
    // TODO: Output ID's on diagnostic
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << opcode_name
           << " Constituent <id> count does not match "
              "Result Type <id> '"
           << _.getIdName(result_type->id())
           << "'s vector component count.";
  }
  const auto component_type =
      _.FindDef(result_type->GetOperandAs<uint32_t>(1));
  if (!component_type) {
    return _.diag(SPV_ERROR_INVALID_ID, result_type)
           << "Component type is not defined.";
  }
  // Constituent <id>s start at operand 2 (after Result Type and Result <id>).
  for (size_t constituent_index = 2;
       constituent_index < inst->operands().size(); constituent_index++) {
    const auto constituent_id =
        inst->GetOperandAs<uint32_t>(constituent_index);
    const auto constituent = _.FindDef(constituent_id);
    if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << opcode_name << " Constituent <id> '"
             << _.getIdName(constituent_id)
             << "' is not a constant or undef.";
    }
    const auto constituent_result_type = _.FindDef(constituent->type_id());
    // Compare type <id>s, not opcodes: two distinct scalar types can share an
    // opcode (e.g. 32-bit vs. 64-bit OpTypeFloat). An opcode comparison would
    // accept a mismatched width. This matches the array/struct checks below.
    if (!constituent_result_type ||
        component_type->id() != constituent_result_type->id()) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << opcode_name << " Constituent <id> '"
             << _.getIdName(constituent_id)
             << "'s type does not match Result Type <id> '"
             << _.getIdName(result_type->id()) << "'s vector element type.";
    }
  }
  return SPV_SUCCESS;
}

// Helper for ValidateConstantComposite when the Result Type is OpTypeMatrix:
// the constituent count must equal the column count, and every constituent
// must be a constant composite (or undef) whose type is a vector with the
// column's component type and component count.
spv_result_t ValidateConstantCompositeMatrix(ValidationState_t& _,
                                             const Instruction* inst,
                                             const Instruction* result_type,
                                             const std::string& opcode_name,
                                             size_t constituent_count) {
  // OpTypeMatrix operands: 0 = Result <id>, 1 = Column Type,
  // 2 = Column Count.
  const auto column_count = result_type->GetOperandAs<uint32_t>(2);
  if (column_count != constituent_count) {
    // TODO: Output ID's on diagnostic
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << opcode_name
           << " Constituent <id> count does not match "
              "Result Type <id> '"
           << _.getIdName(result_type->id()) << "'s matrix column count.";
  }
  const auto column_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1));
  if (!column_type) {
    return _.diag(SPV_ERROR_INVALID_ID, result_type)
           << "Column type is not defined.";
  }
  const auto component_count = column_type->GetOperandAs<uint32_t>(2);
  const auto component_type =
      _.FindDef(column_type->GetOperandAs<uint32_t>(1));
  if (!component_type) {
    return _.diag(SPV_ERROR_INVALID_ID, column_type)
           << "Component type is not defined.";
  }
  for (size_t constituent_index = 2;
       constituent_index < inst->operands().size(); constituent_index++) {
    const auto constituent_id =
        inst->GetOperandAs<uint32_t>(constituent_index);
    const auto constituent = _.FindDef(constituent_id);
    if (!constituent ||
        !(SpvOpConstantComposite == constituent->opcode() ||
          SpvOpSpecConstantComposite == constituent->opcode() ||
          SpvOpUndef == constituent->opcode())) {
      // The message says "... or undef" because the spec does not say
      // undef is a constant.
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << opcode_name << " Constituent <id> '"
             << _.getIdName(constituent_id)
             << "' is not a constant composite or undef.";
    }
    const auto vector = _.FindDef(constituent->type_id());
    if (!vector) {
      return _.diag(SPV_ERROR_INVALID_ID, constituent)
             << "Result type is not defined.";
    }
    if (column_type->opcode() != vector->opcode()) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << opcode_name << " Constituent <id> '"
             << _.getIdName(constituent_id)
             << "' type does not match Result Type <id> '"
             << _.getIdName(result_type->id()) << "'s matrix column type.";
    }
    // Compare the vector's component type <id> directly instead of
    // dereferencing FindDef() on it. The previous code crashed if that <id>
    // had no definition.
    if (component_type->id() != vector->GetOperandAs<uint32_t>(1)) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << opcode_name << " Constituent <id> '"
             << _.getIdName(constituent_id)
             << "' component type does not match Result Type <id> '"
             << _.getIdName(result_type->id())
             << "'s matrix column component type.";
    }
    if (component_count != vector->GetOperandAs<uint32_t>(2)) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << opcode_name << " Constituent <id> '"
             << _.getIdName(constituent_id)
             << "' vector component count does not match Result Type <id> '"
             << _.getIdName(result_type->id())
             << "'s vector component count.";
    }
  }
  return SPV_SUCCESS;
}

// Helper for ValidateConstantComposite when the Result Type is OpTypeArray:
// when the array length is a known 32-bit integer constant it must equal the
// constituent count, and every constituent must be a constant (or undef) of
// exactly the element type.
spv_result_t ValidateConstantCompositeArray(ValidationState_t& _,
                                            const Instruction* inst,
                                            const Instruction* result_type,
                                            const std::string& opcode_name,
                                            size_t constituent_count) {
  // OpTypeArray operands: 0 = Result <id>, 1 = Element Type, 2 = Length.
  const auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1));
  if (!element_type) {
    return _.diag(SPV_ERROR_INVALID_ID, result_type)
           << "Element type is not defined.";
  }
  const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2));
  if (!length) {
    return _.diag(SPV_ERROR_INVALID_ID, result_type)
           << "Length is not defined.";
  }
  bool is_int32;
  bool is_const;
  uint32_t value;
  std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id());
  // A non-constant (e.g. specialization-constant) length cannot be checked.
  if (is_int32 && is_const && value != constituent_count) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << opcode_name
           << " Constituent count does not match "
              "Result Type <id> '"
           << _.getIdName(result_type->id()) << "'s array length.";
  }
  for (size_t constituent_index = 2;
       constituent_index < inst->operands().size(); constituent_index++) {
    const auto constituent_id =
        inst->GetOperandAs<uint32_t>(constituent_index);
    const auto constituent = _.FindDef(constituent_id);
    if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << opcode_name << " Constituent <id> '"
             << _.getIdName(constituent_id)
             << "' is not a constant or undef.";
    }
    const auto constituent_type = _.FindDef(constituent->type_id());
    if (!constituent_type) {
      return _.diag(SPV_ERROR_INVALID_ID, constituent)
             << "Result type is not defined.";
    }
    if (element_type->id() != constituent_type->id()) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << opcode_name << " Constituent <id> '"
             << _.getIdName(constituent_id)
             << "'s type does not match Result Type <id> '"
             << _.getIdName(result_type->id()) << "'s array element type.";
    }
  }
  return SPV_SUCCESS;
}

// Helper for ValidateConstantComposite when the Result Type is OpTypeStruct:
// the constituent count must equal the member count, and constituent i must
// be a constant (or undef) whose type <id> is member i's type <id>.
spv_result_t ValidateConstantCompositeStruct(ValidationState_t& _,
                                             const Instruction* inst,
                                             const Instruction* result_type,
                                             const std::string& opcode_name,
                                             size_t constituent_count) {
  // OpTypeStruct words: opcode/word-count, Result <id>, then member types.
  const auto member_count = result_type->words().size() - 2;
  if (member_count != constituent_count) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << opcode_name << " Constituent <id> '"
           << _.getIdName(inst->type_id())
           << "' count does not match Result Type <id> '"
           << _.getIdName(result_type->id()) << "'s struct member count.";
  }
  // Constituent operand i (from 2) pairs with struct operand i - 1 (from 1).
  for (uint32_t constituent_index = 2, member_index = 1;
       constituent_index < inst->operands().size();
       constituent_index++, member_index++) {
    const auto constituent_id =
        inst->GetOperandAs<uint32_t>(constituent_index);
    const auto constituent = _.FindDef(constituent_id);
    if (!constituent || !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << opcode_name << " Constituent <id> '"
             << _.getIdName(constituent_id)
             << "' is not a constant or undef.";
    }
    const auto constituent_type = _.FindDef(constituent->type_id());
    if (!constituent_type) {
      return _.diag(SPV_ERROR_INVALID_ID, constituent)
             << "Result type is not defined.";
    }
    const auto member_type_id =
        result_type->GetOperandAs<uint32_t>(member_index);
    const auto member_type = _.FindDef(member_type_id);
    if (!member_type || member_type->id() != constituent_type->id()) {
      return _.diag(SPV_ERROR_INVALID_ID, inst)
             << opcode_name << " Constituent <id> '"
             << _.getIdName(constituent_id)
             << "' type does not match the Result Type <id> '"
             << _.getIdName(result_type->id()) << "'s member type.";
    }
  }
  return SPV_SUCCESS;
}

// Validates OpConstantComposite / OpSpecConstantComposite. The Result Type
// must be a composite type, and the constituents must agree with it in count
// and type. The per-kind checks live in the helpers above. Returns
// SPV_SUCCESS or the first SPV_ERROR_INVALID_ID diagnostic found.
spv_result_t ValidateConstantComposite(ValidationState_t& _,
                                       const Instruction* inst) {
  const std::string opcode_name =
      std::string("Op") + spvOpcodeString(inst->opcode());
  const auto result_type = _.FindDef(inst->type_id());
  if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) {
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << opcode_name << " Result Type <id> '"
           << _.getIdName(inst->type_id()) << "' is not a composite type.";
  }
  // Instruction words: opcode/word-count, Result Type, Result <id>, then one
  // word per constituent.
  const size_t constituent_count = inst->words().size() - 3;
  switch (result_type->opcode()) {
    case SpvOpTypeVector:
      return ValidateConstantCompositeVector(_, inst, result_type, opcode_name,
                                             constituent_count);
    case SpvOpTypeMatrix:
      return ValidateConstantCompositeMatrix(_, inst, result_type, opcode_name,
                                             constituent_count);
    case SpvOpTypeArray:
      return ValidateConstantCompositeArray(_, inst, result_type, opcode_name,
                                            constituent_count);
    case SpvOpTypeStruct:
      return ValidateConstantCompositeStruct(_, inst, result_type, opcode_name,
                                             constituent_count);
    default:
      break;
  }
  return SPV_SUCCESS;
}
// Validates OpConstantSampler: the Result Type <id> must name an
// OpTypeSampler definition. Returns SPV_SUCCESS or an SPV_ERROR_INVALID_ID
// diagnostic attached to the OpConstantSampler instruction.
spv_result_t ValidateConstantSampler(ValidationState_t& _,
                                     const Instruction* inst) {
  const auto result_type = _.FindDef(inst->type_id());
  if (!result_type || result_type->opcode() != SpvOpTypeSampler) {
    // Anchor the diagnostic on the offending constant, as the other
    // validators in this file do. The previous code passed |result_type|.
    // That is null exactly when the type is undefined, and otherwise it
    // points the diagnostic at the type declaration instead of the constant.
    return _.diag(SPV_ERROR_INVALID_ID, inst)
           << "OpConstantSampler Result Type <id> '"
           << _.getIdName(inst->type_id()) << "' is not a sampler type.";
  }
  return SPV_SUCCESS;
}
// Returns true if |instruction| (the words of a type declaration) declares a
// type that may have a null value according to the SPIR-V spec.
// Array, matrix and vector types are nullable when their element type is.
// Struct types are nullable when every member type is. Those ids are resolved
// through |_| and checked recursively; an id with no definition counts as not
// nullable.
bool IsTypeNullable(const std::vector<uint32_t>& instruction,
                    const ValidationState_t& _) {
  uint16_t opcode;
  uint16_t word_count;
  spvOpcodeSplit(instruction[0], &word_count, &opcode);

  // Resolves |type_id| and applies the nullability test to its definition.
  const auto is_nullable_id = [&_](uint32_t type_id) {
    const auto definition = _.FindDef(type_id);
    return definition != nullptr && IsTypeNullable(definition->words(), _);
  };

  switch (static_cast<SpvOp>(opcode)) {
    case SpvOpTypeBool:
    case SpvOpTypeInt:
    case SpvOpTypeFloat:
    case SpvOpTypePointer:
    case SpvOpTypeEvent:
    case SpvOpTypeDeviceEvent:
    case SpvOpTypeReserveId:
    case SpvOpTypeQueue:
      return true;
    case SpvOpTypeArray:
    case SpvOpTypeMatrix:
    case SpvOpTypeVector:
      // Word 2 holds the element / column / component type <id>.
      return is_nullable_id(instruction[2]);
    case SpvOpTypeStruct: {
      // Words 2.. hold the member type <id>s; all must be nullable.
      for (size_t word = 2; word < instruction.size(); ++word) {
        if (!is_nullable_id(instruction[word])) return false;
      }
      return true;
    }
    default:
      return false;
  }
}
// Validates OpConstantNull: the Result Type <id> must name a type that can
// hold a null value (as decided by IsTypeNullable). Returns SPV_SUCCESS or an
// SPV_ERROR_INVALID_ID diagnostic.
spv_result_t ValidateConstantNull(ValidationState_t& _,
                                  const Instruction* inst) {
  const auto type_id = inst->type_id();
  if (const auto declared_type = _.FindDef(type_id)) {
    if (IsTypeNullable(declared_type->words(), _)) return SPV_SUCCESS;
  }
  return _.diag(SPV_ERROR_INVALID_ID, inst)
         << "OpConstantNull Result Type <id> '" << _.getIdName(type_id)
         << "' cannot have a null value.";
}
spv_result_t ValidateSpecConstantOp(ValidationState_t& _,
const Instruction* inst) {
const auto op = inst->GetOperandAs<SpvOp>(2);
// The binary parser already ensures that the op is valid for *some*
// environment. Here we check restrictions.
switch(op) {
case SpvOpQuantizeToF16:
if (!_.HasCapability(SpvCapabilityShader)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Specialization constant operation " << spvOpcodeString(op)
<< " requires Shader capability";
}
break;
case SpvOpUConvert:
if (!_.features().uconvert_spec_constant_op &&
!_.HasCapability(SpvCapabilityKernel)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "UConvert requires Kernel capability or extension "
"SPV_AMD_gpu_shader_int16";
}
break;
case SpvOpConvertFToS:
case SpvOpConvertSToF:
case SpvOpConvertFToU:
case SpvOpConvertUToF:
case SpvOpConvertPtrToU:
case SpvOpConvertUToPtr:
case SpvOpGenericCastToPtr:
case SpvOpPtrCastToGeneric:
case SpvOpBitcast:
case SpvOpFNegate:
case SpvOpFAdd:
case SpvOpFSub:
case SpvOpFMul:
case SpvOpFDiv:
case SpvOpFRem:
case SpvOpFMod:
case SpvOpAccessChain:
case SpvOpInBoundsAccessChain:
case SpvOpPtrAccessChain:
case SpvOpInBoundsPtrAccessChain:
if (!_.HasCapability(SpvCapabilityKernel)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Specialization constant operation " << spvOpcodeString(op)
<< " requires Kernel capability";
}
break;
default:
break;
}
// TODO(dneto): Validate result type and arguments to the various operations.
return SPV_SUCCESS;
}
} // namespace
// Entry point of the constant-validation pass. Routes |inst| to the validator
// for its opcode and returns that validator's result. Instructions that are
// not handled here are accepted as-is.
spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) {
  switch (inst->opcode()) {
    case SpvOpConstantTrue:
    case SpvOpConstantFalse:
    case SpvOpSpecConstantTrue:
    case SpvOpSpecConstantFalse:
      return ValidateConstantBool(_, inst);
    case SpvOpConstantComposite:
    case SpvOpSpecConstantComposite:
      return ValidateConstantComposite(_, inst);
    case SpvOpConstantSampler:
      return ValidateConstantSampler(_, inst);
    case SpvOpConstantNull:
      return ValidateConstantNull(_, inst);
    case SpvOpSpecConstantOp:
      return ValidateSpecConstantOp(_, inst);
    default:
      return SPV_SUCCESS;
  }
}
} // namespace val
} // namespace spvtools