// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "source/opt/build_module.h"
#include "source/opt/cfg.h"
#include "source/opt/ir_context.h"
#include "source/opt/pass.h"
#include "source/opt/propagator.h"

namespace spvtools {
namespace opt {
namespace {

using ::testing::UnorderedElementsAre;

// Test fixture that builds an IRContext from SPIR-V assembly text, runs an
// SSAPropagator over every function of the resulting module, and keeps the
// values recorded by the test's visit callback in |values_|.
class PropagatorTest : public testing::Test {
 protected:
  // Drops the module context and all recorded values after each test.
  void TearDown() override {
    ctx_.reset(nullptr);
    values_.clear();
    values_vec_.clear();
  }

  // Builds |ctx_| from the SPIR-V assembly text |input|.
  //
  // Raises a fatal gtest failure if no context was produced.  A fatal failure
  // raised here only returns from this helper, not from the calling test, so
  // callers that need to stop should wrap the call in
  // ASSERT_NO_FATAL_FAILURE().
  void Assemble(const std::string& input) {
    ctx_ = BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, input);
    ASSERT_NE(nullptr, ctx_) << "Assembling failed for shader:\n"
                             << input << "\n";
  }

  // Runs a propagator that uses |visit_fn| on every function of the module
  // held by |ctx_|.
  //
  // Returns true if SSAPropagator::Run() returned true for at least one
  // function.  If no module has been assembled, records a non-fatal test
  // failure and returns false instead of dereferencing a null context.
  bool Propagate(const SSAPropagator::VisitFunction& visit_fn) {
    if (ctx_ == nullptr) {
      ADD_FAILURE() << "Propagate() called without an assembled module";
      return false;
    }

    SSAPropagator propagator(ctx_.get(), visit_fn);
    bool retval = false;
    for (auto& fn : *ctx_->module()) {
      retval |= propagator.Run(&fn);
    }
    return retval;
  }

  // Rebuilds |values_vec_| from the mapped values of |values_| (in ascending
  // key order, since |values_| is a std::map) and returns a reference to it.
  // The reference is invalidated by the next call and by TearDown().
  const std::vector<uint32_t>& GetValues() {
    values_vec_.clear();
    values_vec_.reserve(values_.size());
    for (const auto& it : values_) {
      values_vec_.push_back(it.second);
    }
    return values_vec_;
  }

  // Context owning the module under test; null until Assemble() succeeds.
  std::unique_ptr<IRContext> ctx_;
  // Values recorded by the tests' visit callbacks, keyed by a SPIR-V id.
  std::map<uint32_t, uint32_t> values_;
  // Scratch storage backing the reference returned by GetValues().
  std::vector<uint32_t> values_vec_;
};

// Checks that the propagator visits every OpStore of a constant in a
// single-block function and that the visit callback can record the stored
// constant for each destination id.
TEST_F(PropagatorTest, LocalPropagate) {
  const std::string spv_asm = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %main "main" %outparm
               OpExecutionMode %main OriginUpperLeft
               OpSource GLSL 450
               OpName %main "main"
               OpName %x "x"
               OpName %y "y"
               OpName %z "z"
               OpName %outparm "outparm"
               OpDecorate %outparm Location 0
       %void = OpTypeVoid
          %3 = OpTypeFunction %void
        %int = OpTypeInt 32 1
%_ptr_Function_int = OpTypePointer Function %int
      %int_4 = OpConstant %int 4
      %int_3 = OpConstant %int 3
      %int_1 = OpConstant %int 1
%_ptr_Output_int = OpTypePointer Output %int
    %outparm = OpVariable %_ptr_Output_int Output
       %main = OpFunction %void None %3
          %5 = OpLabel
          %x = OpVariable %_ptr_Function_int Function
          %y = OpVariable %_ptr_Function_int Function
          %z = OpVariable %_ptr_Function_int Function
               OpStore %x %int_4
               OpStore %y %int_3
               OpStore %z %int_1
         %20 = OpLoad %int %z
               OpStore %outparm %20
               OpReturn
               OpFunctionEnd
               )";
  // A fatal failure inside Assemble() only returns from that helper; stop the
  // test here so the code below never runs against a null context.
  ASSERT_NO_FATAL_FAILURE(Assemble(spv_asm));

  // For every store whose stored object is an OpConstant, remember the
  // constant's literal word under the id being stored to and report the
  // instruction as interesting.  Everything else is reported as varying.
  const auto visit_fn = [this](Instruction* instr, BasicBlock** dest_bb) {
    *dest_bb = nullptr;
    if (instr->opcode() == SpvOpStore) {
      uint32_t lhs_id = instr->GetSingleWordOperand(0);
      uint32_t rhs_id = instr->GetSingleWordOperand(1);
      Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id);
      if (rhs_def->opcode() == SpvOpConstant) {
        // Operand 2 of OpConstant is the literal value (after the type id and
        // the result id).
        uint32_t val = rhs_def->GetSingleWordOperand(2);
        values_[lhs_id] = val;
        return SSAPropagator::kInteresting;
      }
    }
    return SSAPropagator::kVarying;
  };

  EXPECT_TRUE(Propagate(visit_fn));
  // Unsigned literals match the uint32_t element type of GetValues().
  EXPECT_THAT(GetValues(), UnorderedElementsAre(4u, 3u, 1u));
}

// Checks that values recorded for the two arms of a conditional are visible to
// the visit callback when it reaches the OpPhi that merges them, so the Phi
// itself can be recorded with the same constant.
TEST_F(PropagatorTest, PropagateThroughPhis) {
  const std::string spv_asm = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %main "main" %x %outparm
               OpExecutionMode %main OriginUpperLeft
               OpSource GLSL 450
               OpName %main "main"
               OpName %x "x"
               OpName %outparm "outparm"
               OpDecorate %x Flat
               OpDecorate %x Location 0
               OpDecorate %outparm Location 0
       %void = OpTypeVoid
          %3 = OpTypeFunction %void
        %int = OpTypeInt 32 1
       %bool = OpTypeBool
%_ptr_Function_int = OpTypePointer Function %int
      %int_4 = OpConstant %int 4
      %int_3 = OpConstant %int 3
      %int_1 = OpConstant %int 1
%_ptr_Input_int = OpTypePointer Input %int
          %x = OpVariable %_ptr_Input_int Input
%_ptr_Output_int = OpTypePointer Output %int
    %outparm = OpVariable %_ptr_Output_int Output
       %main = OpFunction %void None %3
          %4 = OpLabel
          %5 = OpLoad %int %x
          %6 = OpSGreaterThan %bool %5 %int_3
               OpSelectionMerge %25 None
               OpBranchConditional %6 %22 %23
         %22 = OpLabel
          %7 = OpLoad %int %int_4
               OpBranch %25
         %23 = OpLabel
          %8 = OpLoad %int %int_4
               OpBranch %25
         %25 = OpLabel
         %35 = OpPhi %int %7 %22 %8 %23
               OpStore %outparm %35
               OpReturn
               OpFunctionEnd
               )";

  // A fatal failure inside Assemble() only returns from that helper; stop the
  // test here so the code below never runs against a null context.
  ASSERT_NO_FATAL_FAILURE(Assemble(spv_asm));

  // Set by the visit callback to the OpPhi it saw; stays null if the
  // propagator never visits one.
  Instruction* phi_instr = nullptr;
  const auto visit_fn = [this, &phi_instr](Instruction* instr,
                                           BasicBlock** dest_bb) {
    *dest_bb = nullptr;
    if (instr->opcode() == SpvOpLoad) {
      // Operand 2 of OpLoad is the id being loaded from.  When that id is an
      // OpConstant, record its literal word as the value of the load.
      uint32_t rhs_id = instr->GetSingleWordOperand(2);
      Instruction* rhs_def = ctx_->get_def_use_mgr()->GetDef(rhs_id);
      if (rhs_def->opcode() == SpvOpConstant) {
        uint32_t val = rhs_def->GetSingleWordOperand(2);
        values_[instr->result_id()] = val;
        return SSAPropagator::kInteresting;
      }
    } else if (instr->opcode() == SpvOpPhi) {
      phi_instr = instr;
      // Start from a defined status so a Phi without incoming operands does
      // not return an indeterminate value.
      SSAPropagator::PropStatus retval = SSAPropagator::kNotInteresting;
      // Operands 0 and 1 are the type and result ids; incoming values sit at
      // the even indices from 2 on, each followed by its predecessor label.
      for (uint32_t i = 2; i < instr->NumOperands(); i += 2) {
        uint32_t phi_arg_id = instr->GetSingleWordOperand(i);
        auto it = values_.find(phi_arg_id);
        if (it != values_.end()) {
          EXPECT_EQ(it->second, 4u);
          retval = SSAPropagator::kInteresting;
          values_[instr->result_id()] = it->second;
        } else {
          // An incoming value has not been recorded yet; give up on this
          // visit of the Phi.
          retval = SSAPropagator::kNotInteresting;
          break;
        }
      }
      return retval;
    }

    return SSAPropagator::kVarying;
  };

  EXPECT_TRUE(Propagate(visit_fn));

  // The propagator should've concluded that the Phi instruction has a constant
  // value of 4.  The null check is fatal because the lookup below dereferences
  // |phi_instr|.
  ASSERT_NE(phi_instr, nullptr);
  // Use find() rather than operator[] so a missing entry is reported without
  // inserting a default value into |values_|.
  const auto phi_value = values_.find(phi_instr->result_id());
  ASSERT_TRUE(phi_value != values_.end());
  EXPECT_EQ(phi_value->second, 4u);

  EXPECT_THAT(GetValues(), UnorderedElementsAre(4u, 4u, 4u));
}

}  // namespace
}  // namespace opt
}  // namespace spvtools