// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/opt/combine_access_chains.h"

#include <utility>

#include "source/opt/constants.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_context.h"

namespace spvtools {
namespace opt {

Pass::Status CombineAccessChains::Process() {
  bool modified = false;

  for (auto& function : *get_module()) {
    modified |= ProcessFunction(function);
  }

  return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
}

bool CombineAccessChains::ProcessFunction(Function& function) {
  bool modified = false;

  cfg()->ForEachBlockInReversePostOrder(
      function.entry().get(), [&modified, this](BasicBlock* block) {
        block->ForEachInst([&modified, this](Instruction* inst) {
          switch (inst->opcode()) {
            case SpvOpAccessChain:
            case SpvOpInBoundsAccessChain:
            case SpvOpPtrAccessChain:
            case SpvOpInBoundsPtrAccessChain:
              modified |= CombineAccessChain(inst);
              break;
            default:
              break;
          }
        });
      });

  return modified;
}

uint32_t CombineAccessChains::GetConstantValue(
    const analysis::Constant* constant_inst) {
  if (constant_inst->type()->AsInteger()->width() <= 32) {
    if (constant_inst->type()->AsInteger()->IsSigned()) {
      return static_cast<uint32_t>(constant_inst->GetS32());
    } else {
      return constant_inst->GetU32();
    }
  } else {
    assert(false);
    return 0u;
  }
}

uint32_t CombineAccessChains::GetArrayStride(const Instruction* inst) {
  uint32_t array_stride = 0;
  context()->get_decoration_mgr()->WhileEachDecoration(
      inst->type_id(), SpvDecorationArrayStride,
      [&array_stride](const Instruction& decoration) {
        assert(decoration.opcode() != SpvOpDecorateId);
        if (decoration.opcode() == SpvOpDecorate) {
          array_stride = decoration.GetSingleWordInOperand(1);
        } else {
          array_stride = decoration.GetSingleWordInOperand(2);
        }
        return false;
      });
  return array_stride;
}

const analysis::Type* CombineAccessChains::GetIndexedType(Instruction* inst) {
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  analysis::TypeManager* type_mgr = context()->get_type_mgr();

  Instruction* base_ptr = def_use_mgr->GetDef(inst->GetSingleWordInOperand(0));
  const analysis::Type* type = type_mgr->GetType(base_ptr->type_id());
  assert(type->AsPointer());
  type = type->AsPointer()->pointee_type();
  std::vector<uint32_t> element_indices;
  uint32_t starting_index = 1;
  if (IsPtrAccessChain(inst->opcode())) {
    // Skip the first index of OpPtrAccessChain as it does not affect type
    // resolution.
    starting_index = 2;
  }
  for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) {
    Instruction* index_inst =
        def_use_mgr->GetDef(inst->GetSingleWordInOperand(i));
    const analysis::Constant* index_constant =
        context()->get_constant_mgr()->GetConstantFromInst(index_inst);
    if (index_constant) {
      uint32_t index_value = GetConstantValue(index_constant);
      element_indices.push_back(index_value);
    } else {
      // This index must not matter to resolve the type in valid SPIR-V.
      element_indices.push_back(0);
    }
  }
  type = type_mgr->GetMemberType(type, element_indices);
  return type;
}

bool CombineAccessChains::CombineIndices(Instruction* ptr_input,
                                         Instruction* inst,
                                         std::vector<Operand>* new_operands) {
  analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr();
  analysis::ConstantManager* constant_mgr = context()->get_constant_mgr();

  Instruction* last_index_inst = def_use_mgr->GetDef(
      ptr_input->GetSingleWordInOperand(ptr_input->NumInOperands() - 1));
  const analysis::Constant* last_index_constant =
      constant_mgr->GetConstantFromInst(last_index_inst);

  Instruction* element_inst =
      def_use_mgr->GetDef(inst->GetSingleWordInOperand(1));
  const analysis::Constant* element_constant =
      constant_mgr->GetConstantFromInst(element_inst);

  // Combine the last index of the AccessChain (|ptr_inst|) with the element
  // operand of the PtrAccessChain (|inst|).
  const bool combining_element_operands =
      IsPtrAccessChain(inst->opcode()) &&
      IsPtrAccessChain(ptr_input->opcode()) && ptr_input->NumInOperands() == 2;
  uint32_t new_value_id = 0;
  const analysis::Type* type = GetIndexedType(ptr_input);
  if (last_index_constant && element_constant) {
    // Combine the constants.
    uint32_t new_value = GetConstantValue(last_index_constant) +
                         GetConstantValue(element_constant);
    const analysis::Constant* new_value_constant =
        constant_mgr->GetConstant(last_index_constant->type(), {new_value});
    Instruction* new_value_inst =
        constant_mgr->GetDefiningInstruction(new_value_constant);
    new_value_id = new_value_inst->result_id();
  } else if (!type->AsStruct() || combining_element_operands) {
    // Generate an addition of the two indices.
    InstructionBuilder builder(
        context(), inst,
        IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
    Instruction* addition = builder.AddIAdd(last_index_inst->type_id(),
                                            last_index_inst->result_id(),
                                            element_inst->result_id());
    new_value_id = addition->result_id();
  } else {
    // Indexing into structs must be constant, so bail out here.
    return false;
  }
  new_operands->push_back({SPV_OPERAND_TYPE_ID, {new_value_id}});
  return true;
}

bool CombineAccessChains::CreateNewInputOperands(
    Instruction* ptr_input, Instruction* inst,
    std::vector<Operand>* new_operands) {
  // Start by copying all the input operands of the feeder access chain.
  for (uint32_t i = 0; i != ptr_input->NumInOperands() - 1; ++i) {
    new_operands->push_back(ptr_input->GetInOperand(i));
  }

  // Deal with the last index of the feeder access chain.
  if (IsPtrAccessChain(inst->opcode())) {
    // The last index of the feeder should be combined with the element operand
    // of |inst|.
    if (!CombineIndices(ptr_input, inst, new_operands)) return false;
  } else {
    // The indices aren't being combined so now add the last index operand of
    // |ptr_input|.
    new_operands->push_back(
        ptr_input->GetInOperand(ptr_input->NumInOperands() - 1));
  }

  // Copy the remaining index operands.
  uint32_t starting_index = IsPtrAccessChain(inst->opcode()) ? 2 : 1;
  for (uint32_t i = starting_index; i < inst->NumInOperands(); ++i) {
    new_operands->push_back(inst->GetInOperand(i));
  }

  return true;
}

bool CombineAccessChains::CombineAccessChain(Instruction* inst) {
  assert((inst->opcode() == SpvOpPtrAccessChain ||
          inst->opcode() == SpvOpAccessChain ||
          inst->opcode() == SpvOpInBoundsAccessChain ||
          inst->opcode() == SpvOpInBoundsPtrAccessChain) &&
         "Wrong opcode. Expected an access chain.");

  Instruction* ptr_input =
      context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(0));
  if (ptr_input->opcode() != SpvOpAccessChain &&
      ptr_input->opcode() != SpvOpInBoundsAccessChain &&
      ptr_input->opcode() != SpvOpPtrAccessChain &&
      ptr_input->opcode() != SpvOpInBoundsPtrAccessChain) {
    return false;
  }

  if (Has64BitIndices(inst) || Has64BitIndices(ptr_input)) return false;

  // Handles the following cases:
  // 1. |ptr_input| is an index-less access chain. Replace the pointer
  //    in |inst| with |ptr_input|'s pointer.
  // 2. |inst| is a index-less access chain. Change |inst| to an
  //    OpCopyObject.
  // 3. |inst| is not a pointer access chain.
  //    |inst|'s indices are appended to |ptr_input|'s indices.
  // 4. |ptr_input| is not pointer access chain.
  //    |inst| is a pointer access chain.
  //    |inst|'s element operand is combined with the last index in
  //    |ptr_input| to form a new operand.
  // 5. |ptr_input| is a pointer access chain.
  //    Like the above scenario, |inst|'s element operand is combined
  //    with |ptr_input|'s last index. This results is either a
  //    combined element operand or combined regular index.

  // TODO(alan-baker): Support this properly. Requires analyzing the
  // size/alignment of the type and converting the stride into an element
  // index.
  uint32_t array_stride = GetArrayStride(ptr_input);
  if (array_stride != 0) return false;

  if (ptr_input->NumInOperands() == 1) {
    // The input is effectively a no-op.
    inst->SetInOperand(0, {ptr_input->GetSingleWordInOperand(0)});
    context()->AnalyzeUses(inst);
  } else if (inst->NumInOperands() == 1) {
    // |inst| is a no-op, change it to a copy. Instruction simplification will
    // clean it up.
    inst->SetOpcode(SpvOpCopyObject);
  } else {
    std::vector<Operand> new_operands;
    if (!CreateNewInputOperands(ptr_input, inst, &new_operands)) return false;

    // Update the instruction.
    inst->SetOpcode(UpdateOpcode(inst->opcode(), ptr_input->opcode()));
    inst->SetInOperands(std::move(new_operands));
    context()->AnalyzeUses(inst);
  }
  return true;
}

SpvOp CombineAccessChains::UpdateOpcode(SpvOp base_opcode, SpvOp input_opcode) {
  auto IsInBounds = [](SpvOp opcode) {
    return opcode == SpvOpInBoundsPtrAccessChain ||
           opcode == SpvOpInBoundsAccessChain;
  };

  if (input_opcode == SpvOpInBoundsPtrAccessChain) {
    if (!IsInBounds(base_opcode)) return SpvOpPtrAccessChain;
  } else if (input_opcode == SpvOpInBoundsAccessChain) {
    if (!IsInBounds(base_opcode)) return SpvOpAccessChain;
  }

  return input_opcode;
}

bool CombineAccessChains::IsPtrAccessChain(SpvOp opcode) {
  return opcode == SpvOpPtrAccessChain || opcode == SpvOpInBoundsPtrAccessChain;
}

bool CombineAccessChains::Has64BitIndices(Instruction* inst) {
  for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
    Instruction* index_inst =
        context()->get_def_use_mgr()->GetDef(inst->GetSingleWordInOperand(i));
    const analysis::Type* index_type =
        context()->get_type_mgr()->GetType(index_inst->type_id());
    if (!index_type->AsInteger() || index_type->AsInteger()->width() != 32)
      return true;
  }
  return false;
}

}  // namespace opt
}  // namespace spvtools