// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "gmock/gmock.h"
#include "source/opt/iterator.h"
#include "source/opt/loop_descriptor.h"
#include "source/opt/pass.h"
#include "source/opt/scalar_analysis.h"
#include "source/opt/tree_iterator.h"
#include "test/opt/assembly_builder.h"
#include "test/opt/function_utils.h"
#include "test/opt/pass_fixture.h"
#include "test/opt/pass_utils.h"

namespace spvtools {
namespace opt {
namespace {

using ::testing::UnorderedElementsAre;
using ScalarAnalysisTest = PassTest<::testing::Test>;

/*
Generated from the following GLSL + --eliminate-local-multi-store

#version 410 core
layout (location = 1) out float array[10];
void main() {
  for (int i = 0; i < 10; ++i) {
    array[i] = array[i+1];
  }
}
*/
TEST_F(ScalarAnalysisTest, BasicEvolutionTest) {
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main" %24
               OpExecutionMode %4 OriginUpperLeft
               OpSource GLSL 410
               OpName %4 "main"
               OpName %24 "array"
               OpDecorate %24 Location 1
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %7 = OpTypePointer Function %6
          %9 = OpConstant %6 0
         %16 = OpConstant %6 10
         %17 = OpTypeBool
         %19 = OpTypeFloat 32
         %20 = OpTypeInt 32 0
         %21 = OpConstant %20 10
         %22 = OpTypeArray %19 %21
         %23 = OpTypePointer Output %22
         %24 = OpVariable %23 Output
         %27 = OpConstant %6 1
         %29 = OpTypePointer Output %19
          %4 = OpFunction %2 None %3
          %5 = OpLabel
               OpBranch %10
         %10 = OpLabel
         %35 = OpPhi %6 %9 %5 %34 %13
               OpLoopMerge %12 %13 None
               OpBranch %14
         %14 = OpLabel
         %18 = OpSLessThan %17 %35 %16
               OpBranchConditional %18 %11 %12
         %11 = OpLabel
         %28 = OpIAdd %6 %35 %27
         %30 = OpAccessChain %29 %24 %28
         %31 = OpLoad %19 %30
         %32 = OpAccessChain %29 %24 %35
               OpStore %32 %31
               OpBranch %13
         %13 = OpLabel
         %34 = OpIAdd %6 %35 %27
               OpBranch %10
         %12 = OpLabel
               OpReturn
               OpFunctionEnd
  )";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  Module* module = context->module();
  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 4);
  ScalarEvolutionAnalysis analysis{context.get()};

  const Instruction* store = nullptr;
  const Instruction* load = nullptr;
  for (const Instruction& inst : *spvtest::GetBasicBlock(f, 11)) {
    if (inst.opcode() == SpvOp::SpvOpStore) {
      store = &inst;
    }
    if (inst.opcode() == SpvOp::SpvOpLoad) {
      load = &inst;
    }
  }

  EXPECT_NE(load, nullptr);
  EXPECT_NE(store, nullptr);

  Instruction* access_chain =
      context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));

  Instruction* child = context->get_def_use_mgr()->GetDef(
      access_chain->GetSingleWordInOperand(1));
  const SENode* node = analysis.AnalyzeInstruction(child);

  EXPECT_NE(node, nullptr);

  // Unsimplified node should have the form of ADD(REC(0,1), 1)
  EXPECT_EQ(node->GetType(), SENode::Add);

  const SENode* child_1 = node->GetChild(0);
  EXPECT_TRUE(child_1->GetType() == SENode::Constant ||
              child_1->GetType() == SENode::RecurrentAddExpr);

  const SENode* child_2 = node->GetChild(1);
  EXPECT_TRUE(child_2->GetType() == SENode::Constant ||
              child_2->GetType() == SENode::RecurrentAddExpr);

  SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node));
  // Simplified should be in the form of REC(1,1)
  EXPECT_EQ(simplified->GetType(), SENode::RecurrentAddExpr);

  EXPECT_EQ(simplified->GetChild(0)->GetType(), SENode::Constant);
  EXPECT_EQ(simplified->GetChild(0)->AsSEConstantNode()->FoldToSingleValue(),
            1);

  EXPECT_EQ(simplified->GetChild(1)->GetType(), SENode::Constant);
  EXPECT_EQ(simplified->GetChild(1)->AsSEConstantNode()->FoldToSingleValue(),
            1);

  EXPECT_EQ(simplified->GetChild(0), simplified->GetChild(1));
}

/*
Generated from the following GLSL + --eliminate-local-multi-store

#version 410 core
layout (location = 1) out float array[10];
layout (location = 2) flat in int loop_invariant;
void main() {
  for (int i = 0; i < 10; ++i) {
    array[i] = array[i+loop_invariant];
  }
}

*/
TEST_F(ScalarAnalysisTest, LoadTest) {
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %3 "array"
               OpName %4 "loop_invariant"
               OpDecorate %3 Location 1
               OpDecorate %4 Flat
               OpDecorate %4 Location 2
          %5 = OpTypeVoid
          %6 = OpTypeFunction %5
          %7 = OpTypeInt 32 1
          %8 = OpTypePointer Function %7
          %9 = OpConstant %7 0
         %10 = OpConstant %7 10
         %11 = OpTypeBool
         %12 = OpTypeFloat 32
         %13 = OpTypeInt 32 0
         %14 = OpConstant %13 10
         %15 = OpTypeArray %12 %14
         %16 = OpTypePointer Output %15
          %3 = OpVariable %16 Output
         %17 = OpTypePointer Input %7
          %4 = OpVariable %17 Input
         %18 = OpTypePointer Output %12
         %19 = OpConstant %7 1
          %2 = OpFunction %5 None %6
         %20 = OpLabel
               OpBranch %21
         %21 = OpLabel
         %22 = OpPhi %7 %9 %20 %23 %24
               OpLoopMerge %25 %24 None
               OpBranch %26
         %26 = OpLabel
         %27 = OpSLessThan %11 %22 %10
               OpBranchConditional %27 %28 %25
         %28 = OpLabel
         %29 = OpLoad %7 %4
         %30 = OpIAdd %7 %22 %29
         %31 = OpAccessChain %18 %3 %30
         %32 = OpLoad %12 %31
         %33 = OpAccessChain %18 %3 %22
               OpStore %33 %32
               OpBranch %24
         %24 = OpLabel
         %23 = OpIAdd %7 %22 %19
               OpBranch %21
         %25 = OpLabel
               OpReturn
               OpFunctionEnd
)";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  Module* module = context->module();
  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ScalarEvolutionAnalysis analysis{context.get()};

  const Instruction* load = nullptr;
  for (const Instruction& inst : *spvtest::GetBasicBlock(f, 28)) {
    if (inst.opcode() == SpvOp::SpvOpLoad) {
      load = &inst;
    }
  }

  EXPECT_NE(load, nullptr);

  Instruction* access_chain =
      context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));

  Instruction* child = context->get_def_use_mgr()->GetDef(
      access_chain->GetSingleWordInOperand(1));
  //  const SENode* node =
  //  analysis.GetNodeFromInstruction(child->unique_id());

  const SENode* node = analysis.AnalyzeInstruction(child);

  EXPECT_NE(node, nullptr);

  // Unsimplified node should have the form of ADD(REC(0,1), X)
  EXPECT_EQ(node->GetType(), SENode::Add);

  const SENode* child_1 = node->GetChild(0);
  EXPECT_TRUE(child_1->GetType() == SENode::ValueUnknown ||
              child_1->GetType() == SENode::RecurrentAddExpr);

  const SENode* child_2 = node->GetChild(1);
  EXPECT_TRUE(child_2->GetType() == SENode::ValueUnknown ||
              child_2->GetType() == SENode::RecurrentAddExpr);

  SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node));
  EXPECT_EQ(simplified->GetType(), SENode::RecurrentAddExpr);

  const SERecurrentNode* rec = simplified->AsSERecurrentNode();

  EXPECT_NE(rec->GetChild(0), rec->GetChild(1));

  EXPECT_EQ(rec->GetOffset()->GetType(), SENode::ValueUnknown);

  EXPECT_EQ(rec->GetCoefficient()->GetType(), SENode::Constant);
  EXPECT_EQ(rec->GetCoefficient()->AsSEConstantNode()->FoldToSingleValue(), 1u);
}

/*
Generated from the following GLSL + --eliminate-local-multi-store

#version 410 core
layout (location = 1) out float array[10];
layout (location = 2) flat in int loop_invariant;
void main() {
  array[0] = array[loop_invariant * 2 + 4 + 5 - 24 - loop_invariant -
loop_invariant+ 16 * 3];
}

*/
TEST_F(ScalarAnalysisTest, SimplifySimple) {
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %3 "array"
               OpName %4 "loop_invariant"
               OpDecorate %3 Location 1
               OpDecorate %4 Flat
               OpDecorate %4 Location 2
          %5 = OpTypeVoid
          %6 = OpTypeFunction %5
          %7 = OpTypeFloat 32
          %8 = OpTypeInt 32 0
          %9 = OpConstant %8 10
         %10 = OpTypeArray %7 %9
         %11 = OpTypePointer Output %10
          %3 = OpVariable %11 Output
         %12 = OpTypeInt 32 1
         %13 = OpConstant %12 0
         %14 = OpTypePointer Input %12
          %4 = OpVariable %14 Input
         %15 = OpConstant %12 2
         %16 = OpConstant %12 4
         %17 = OpConstant %12 5
         %18 = OpConstant %12 24
         %19 = OpConstant %12 48
         %20 = OpTypePointer Output %7
          %2 = OpFunction %5 None %6
         %21 = OpLabel
         %22 = OpLoad %12 %4
         %23 = OpIMul %12 %22 %15
         %24 = OpIAdd %12 %23 %16
         %25 = OpIAdd %12 %24 %17
         %26 = OpISub %12 %25 %18
         %28 = OpISub %12 %26 %22
         %30 = OpISub %12 %28 %22
         %31 = OpIAdd %12 %30 %19
         %32 = OpAccessChain %20 %3 %31
         %33 = OpLoad %7 %32
         %34 = OpAccessChain %20 %3 %13
               OpStore %34 %33
               OpReturn
               OpFunctionEnd
    )";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  Module* module = context->module();
  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ScalarEvolutionAnalysis analysis{context.get()};

  const Instruction* load = nullptr;
  for (const Instruction& inst : *spvtest::GetBasicBlock(f, 21)) {
    if (inst.opcode() == SpvOp::SpvOpLoad && inst.result_id() == 33) {
      load = &inst;
    }
  }

  EXPECT_NE(load, nullptr);

  Instruction* access_chain =
      context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));

  Instruction* child = context->get_def_use_mgr()->GetDef(
      access_chain->GetSingleWordInOperand(1));

  const SENode* node = analysis.AnalyzeInstruction(child);

  // Unsimplified is a very large graph with an add at the top.
  EXPECT_NE(node, nullptr);
  EXPECT_EQ(node->GetType(), SENode::Add);

  // Simplified node should resolve down to a constant expression as the loads
  // will eliminate themselves.
  SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node));

  EXPECT_EQ(simplified->GetType(), SENode::Constant);
  EXPECT_EQ(simplified->AsSEConstantNode()->FoldToSingleValue(), 33u);
}

/*
Generated from the following GLSL + --eliminate-local-multi-store

#version 410 core
layout(location = 0) in vec4 c;
layout (location = 1) out float array[10];
void main() {
  int N = int(c.x);
  for (int i = 0; i < 10; ++i) {
    array[i] = array[i];
    array[i] = array[i-1];
    array[i] = array[i+1];
    array[i+1] = array[i+1];
    array[i+N] = array[i+N];
    array[i] = array[i+N];
  }
}

*/
TEST_F(ScalarAnalysisTest, Simplify) {
  const std::string text = R"(               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %4 "main" %12 %33
               OpExecutionMode %4 OriginUpperLeft
               OpSource GLSL 410
               OpName %4 "main"
               OpName %8 "N"
               OpName %12 "c"
               OpName %19 "i"
               OpName %33 "array"
               OpDecorate %12 Location 0
               OpDecorate %33 Location 1
          %2 = OpTypeVoid
          %3 = OpTypeFunction %2
          %6 = OpTypeInt 32 1
          %7 = OpTypePointer Function %6
          %9 = OpTypeFloat 32
         %10 = OpTypeVector %9 4
         %11 = OpTypePointer Input %10
         %12 = OpVariable %11 Input
         %13 = OpTypeInt 32 0
         %14 = OpConstant %13 0
         %15 = OpTypePointer Input %9
         %20 = OpConstant %6 0
         %27 = OpConstant %6 10
         %28 = OpTypeBool
         %30 = OpConstant %13 10
         %31 = OpTypeArray %9 %30
         %32 = OpTypePointer Output %31
         %33 = OpVariable %32 Output
         %36 = OpTypePointer Output %9
         %42 = OpConstant %6 1
          %4 = OpFunction %2 None %3
          %5 = OpLabel
          %8 = OpVariable %7 Function
         %19 = OpVariable %7 Function
         %16 = OpAccessChain %15 %12 %14
         %17 = OpLoad %9 %16
         %18 = OpConvertFToS %6 %17
               OpStore %8 %18
               OpStore %19 %20
               OpBranch %21
         %21 = OpLabel
         %78 = OpPhi %6 %20 %5 %77 %24
               OpLoopMerge %23 %24 None
               OpBranch %25
         %25 = OpLabel
         %29 = OpSLessThan %28 %78 %27
               OpBranchConditional %29 %22 %23
         %22 = OpLabel
         %37 = OpAccessChain %36 %33 %78
         %38 = OpLoad %9 %37
         %39 = OpAccessChain %36 %33 %78
               OpStore %39 %38
         %43 = OpISub %6 %78 %42
         %44 = OpAccessChain %36 %33 %43
         %45 = OpLoad %9 %44
         %46 = OpAccessChain %36 %33 %78
               OpStore %46 %45
         %49 = OpIAdd %6 %78 %42
         %50 = OpAccessChain %36 %33 %49
         %51 = OpLoad %9 %50
         %52 = OpAccessChain %36 %33 %78
               OpStore %52 %51
         %54 = OpIAdd %6 %78 %42
         %56 = OpIAdd %6 %78 %42
         %57 = OpAccessChain %36 %33 %56
         %58 = OpLoad %9 %57
         %59 = OpAccessChain %36 %33 %54
               OpStore %59 %58
         %62 = OpIAdd %6 %78 %18
         %65 = OpIAdd %6 %78 %18
         %66 = OpAccessChain %36 %33 %65
         %67 = OpLoad %9 %66
         %68 = OpAccessChain %36 %33 %62
               OpStore %68 %67
         %72 = OpIAdd %6 %78 %18
         %73 = OpAccessChain %36 %33 %72
         %74 = OpLoad %9 %73
         %75 = OpAccessChain %36 %33 %78
               OpStore %75 %74
               OpBranch %24
         %24 = OpLabel
         %77 = OpIAdd %6 %78 %42
               OpStore %19 %77
               OpBranch %21
         %23 = OpLabel
               OpReturn
               OpFunctionEnd
)";
  // clang-format on
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  Module* module = context->module();
  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 4);
  ScalarEvolutionAnalysis analysis{context.get()};

  const Instruction* loads[6];
  const Instruction* stores[6];
  int load_count = 0;
  int store_count = 0;

  for (const Instruction& inst : *spvtest::GetBasicBlock(f, 22)) {
    if (inst.opcode() == SpvOp::SpvOpLoad) {
      loads[load_count] = &inst;
      ++load_count;
    }
    if (inst.opcode() == SpvOp::SpvOpStore) {
      stores[store_count] = &inst;
      ++store_count;
    }
  }

  EXPECT_EQ(load_count, 6);
  EXPECT_EQ(store_count, 6);

  Instruction* load_access_chain;
  Instruction* store_access_chain;
  Instruction* load_child;
  Instruction* store_child;
  SENode* load_node;
  SENode* store_node;
  SENode* subtract_node;
  SENode* simplified_node;

  // Testing [i] - [i] == 0
  load_access_chain =
      context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
  store_access_chain =
      context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0));

  load_child = context->get_def_use_mgr()->GetDef(
      load_access_chain->GetSingleWordInOperand(1));
  store_child = context->get_def_use_mgr()->GetDef(
      store_access_chain->GetSingleWordInOperand(1));

  load_node = analysis.AnalyzeInstruction(load_child);
  store_node = analysis.AnalyzeInstruction(store_child);

  subtract_node = analysis.CreateSubtraction(store_node, load_node);
  simplified_node = analysis.SimplifyExpression(subtract_node);
  EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);

  // Testing [i] - [i-1] == 1
  load_access_chain =
      context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
  store_access_chain =
      context->get_def_use_mgr()->GetDef(stores[1]->GetSingleWordInOperand(0));

  load_child = context->get_def_use_mgr()->GetDef(
      load_access_chain->GetSingleWordInOperand(1));
  store_child = context->get_def_use_mgr()->GetDef(
      store_access_chain->GetSingleWordInOperand(1));

  load_node = analysis.AnalyzeInstruction(load_child);
  store_node = analysis.AnalyzeInstruction(store_child);

  subtract_node = analysis.CreateSubtraction(store_node, load_node);
  simplified_node = analysis.SimplifyExpression(subtract_node);

  EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 1u);

  // Testing [i] - [i+1] == -1
  load_access_chain =
      context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0));
  store_access_chain =
      context->get_def_use_mgr()->GetDef(stores[2]->GetSingleWordInOperand(0));

  load_child = context->get_def_use_mgr()->GetDef(
      load_access_chain->GetSingleWordInOperand(1));
  store_child = context->get_def_use_mgr()->GetDef(
      store_access_chain->GetSingleWordInOperand(1));

  load_node = analysis.AnalyzeInstruction(load_child);
  store_node = analysis.AnalyzeInstruction(store_child);

  subtract_node = analysis.CreateSubtraction(store_node, load_node);
  simplified_node = analysis.SimplifyExpression(subtract_node);
  EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), -1);

  // Testing [i+1] - [i+1] == 0
  load_access_chain =
      context->get_def_use_mgr()->GetDef(loads[3]->GetSingleWordInOperand(0));
  store_access_chain =
      context->get_def_use_mgr()->GetDef(stores[3]->GetSingleWordInOperand(0));

  load_child = context->get_def_use_mgr()->GetDef(
      load_access_chain->GetSingleWordInOperand(1));
  store_child = context->get_def_use_mgr()->GetDef(
      store_access_chain->GetSingleWordInOperand(1));

  load_node = analysis.AnalyzeInstruction(load_child);
  store_node = analysis.AnalyzeInstruction(store_child);

  subtract_node = analysis.CreateSubtraction(store_node, load_node);
  simplified_node = analysis.SimplifyExpression(subtract_node);
  EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);

  // Testing [i+N] - [i+N] == 0
  load_access_chain =
      context->get_def_use_mgr()->GetDef(loads[4]->GetSingleWordInOperand(0));
  store_access_chain =
      context->get_def_use_mgr()->GetDef(stores[4]->GetSingleWordInOperand(0));

  load_child = context->get_def_use_mgr()->GetDef(
      load_access_chain->GetSingleWordInOperand(1));
  store_child = context->get_def_use_mgr()->GetDef(
      store_access_chain->GetSingleWordInOperand(1));

  load_node = analysis.AnalyzeInstruction(load_child);
  store_node = analysis.AnalyzeInstruction(store_child);

  subtract_node = analysis.CreateSubtraction(store_node, load_node);

  simplified_node = analysis.SimplifyExpression(subtract_node);
  EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);

  // Testing [i] - [i+N] == -N
  load_access_chain =
      context->get_def_use_mgr()->GetDef(loads[5]->GetSingleWordInOperand(0));
  store_access_chain =
      context->get_def_use_mgr()->GetDef(stores[5]->GetSingleWordInOperand(0));

  load_child = context->get_def_use_mgr()->GetDef(
      load_access_chain->GetSingleWordInOperand(1));
  store_child = context->get_def_use_mgr()->GetDef(
      store_access_chain->GetSingleWordInOperand(1));

  load_node = analysis.AnalyzeInstruction(load_child);
  store_node = analysis.AnalyzeInstruction(store_child);

  subtract_node = analysis.CreateSubtraction(store_node, load_node);
  simplified_node = analysis.SimplifyExpression(subtract_node);
  EXPECT_EQ(simplified_node->GetType(), SENode::Negative);
}

/*
Generated from the following GLSL + --eliminate-local-multi-store

#version 430
layout(location = 1) out float array[10];
layout(location = 2) flat in int loop_invariant;
void main(void) {
  for (int i = 0; i < 10; ++i) {
    array[i * 2 + i * 5] = array[i * i * 2];
    array[i * 2] = array[i * 5];
  }
}

*/

TEST_F(ScalarAnalysisTest, SimplifyMultiplyInductions) {
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %5 "i"
               OpName %3 "array"
               OpName %4 "loop_invariant"
               OpDecorate %3 Location 1
               OpDecorate %4 Flat
               OpDecorate %4 Location 2
          %6 = OpTypeVoid
          %7 = OpTypeFunction %6
          %8 = OpTypeInt 32 1
          %9 = OpTypePointer Function %8
         %10 = OpConstant %8 0
         %11 = OpConstant %8 10
         %12 = OpTypeBool
         %13 = OpTypeFloat 32
         %14 = OpTypeInt 32 0
         %15 = OpConstant %14 10
         %16 = OpTypeArray %13 %15
         %17 = OpTypePointer Output %16
          %3 = OpVariable %17 Output
         %18 = OpConstant %8 2
         %19 = OpConstant %8 5
         %20 = OpTypePointer Output %13
         %21 = OpConstant %8 1
         %22 = OpTypePointer Input %8
          %4 = OpVariable %22 Input
          %2 = OpFunction %6 None %7
         %23 = OpLabel
          %5 = OpVariable %9 Function
               OpStore %5 %10
               OpBranch %24
         %24 = OpLabel
         %25 = OpPhi %8 %10 %23 %26 %27
               OpLoopMerge %28 %27 None
               OpBranch %29
         %29 = OpLabel
         %30 = OpSLessThan %12 %25 %11
               OpBranchConditional %30 %31 %28
         %31 = OpLabel
         %32 = OpIMul %8 %25 %18
         %33 = OpIMul %8 %25 %19
         %34 = OpIAdd %8 %32 %33
         %35 = OpIMul %8 %25 %25
         %36 = OpIMul %8 %35 %18
         %37 = OpAccessChain %20 %3 %36
         %38 = OpLoad %13 %37
         %39 = OpAccessChain %20 %3 %34
               OpStore %39 %38
         %40 = OpIMul %8 %25 %18
         %41 = OpIMul %8 %25 %19
         %42 = OpAccessChain %20 %3 %41
         %43 = OpLoad %13 %42
         %44 = OpAccessChain %20 %3 %40
               OpStore %44 %43
               OpBranch %27
         %27 = OpLabel
         %26 = OpIAdd %8 %25 %21
               OpStore %5 %26
               OpBranch %24
         %28 = OpLabel
               OpReturn
               OpFunctionEnd
    )";
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  Module* module = context->module();
  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ScalarEvolutionAnalysis analysis{context.get()};

  const Instruction* loads[2] = {nullptr, nullptr};
  const Instruction* stores[2] = {nullptr, nullptr};
  int load_count = 0;
  int store_count = 0;

  for (const Instruction& inst : *spvtest::GetBasicBlock(f, 31)) {
    if (inst.opcode() == SpvOp::SpvOpLoad) {
      loads[load_count] = &inst;
      ++load_count;
    }
    if (inst.opcode() == SpvOp::SpvOpStore) {
      stores[store_count] = &inst;
      ++store_count;
    }
  }

  EXPECT_EQ(load_count, 2);
  EXPECT_EQ(store_count, 2);

  Instruction* load_access_chain =
      context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
  Instruction* store_access_chain =
      context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0));

  Instruction* load_child = context->get_def_use_mgr()->GetDef(
      load_access_chain->GetSingleWordInOperand(1));
  Instruction* store_child = context->get_def_use_mgr()->GetDef(
      store_access_chain->GetSingleWordInOperand(1));

  SENode* store_node = analysis.AnalyzeInstruction(store_child);

  SENode* store_simplified = analysis.SimplifyExpression(store_node);

  load_access_chain =
      context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
  store_access_chain =
      context->get_def_use_mgr()->GetDef(stores[1]->GetSingleWordInOperand(0));
  load_child = context->get_def_use_mgr()->GetDef(
      load_access_chain->GetSingleWordInOperand(1));
  store_child = context->get_def_use_mgr()->GetDef(
      store_access_chain->GetSingleWordInOperand(1));

  SENode* second_store =
      analysis.SimplifyExpression(analysis.AnalyzeInstruction(store_child));
  SENode* second_load =
      analysis.SimplifyExpression(analysis.AnalyzeInstruction(load_child));
  SENode* combined_add = analysis.SimplifyExpression(
      analysis.CreateAddNode(second_load, second_store));

  // We're checking that the two recurrent expression have been correctly
  // folded. In store_simplified they will have been folded as the entire
  // expression was simplified as one. In combined_add the two expressions have
  // been simplified one after the other which means the recurrent expressions
  // aren't exactly the same but should still be folded as they are with respect
  // to the same loop.
  EXPECT_EQ(combined_add, store_simplified);
}

/*
Generated from the following GLSL + --eliminate-local-multi-store

#version 430
void main(void) {
    for (int i = 0; i < 10; --i) {
        array[i] = array[i];
    }
}

*/

TEST_F(ScalarAnalysisTest, SimplifyNegativeSteps) {
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %5 "i"
               OpName %3 "array"
               OpName %4 "loop_invariant"
               OpDecorate %3 Location 1
               OpDecorate %4 Flat
               OpDecorate %4 Location 2
          %6 = OpTypeVoid
          %7 = OpTypeFunction %6
          %8 = OpTypeInt 32 1
          %9 = OpTypePointer Function %8
         %10 = OpConstant %8 0
         %11 = OpConstant %8 10
         %12 = OpTypeBool
         %13 = OpTypeFloat 32
         %14 = OpTypeInt 32 0
         %15 = OpConstant %14 10
         %16 = OpTypeArray %13 %15
         %17 = OpTypePointer Output %16
          %3 = OpVariable %17 Output
         %18 = OpTypePointer Output %13
         %19 = OpConstant %8 1
         %20 = OpTypePointer Input %8
          %4 = OpVariable %20 Input
          %2 = OpFunction %6 None %7
         %21 = OpLabel
          %5 = OpVariable %9 Function
               OpStore %5 %10
               OpBranch %22
         %22 = OpLabel
         %23 = OpPhi %8 %10 %21 %24 %25
               OpLoopMerge %26 %25 None
               OpBranch %27
         %27 = OpLabel
         %28 = OpSLessThan %12 %23 %11
               OpBranchConditional %28 %29 %26
         %29 = OpLabel
         %30 = OpAccessChain %18 %3 %23
         %31 = OpLoad %13 %30
         %32 = OpAccessChain %18 %3 %23
               OpStore %32 %31
               OpBranch %25
         %25 = OpLabel
         %24 = OpISub %8 %23 %19
               OpStore %5 %24
               OpBranch %22
         %26 = OpLabel
               OpReturn
               OpFunctionEnd
    )";
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  Module* module = context->module();
  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ScalarEvolutionAnalysis analysis{context.get()};

  const Instruction* loads[1] = {nullptr};
  int load_count = 0;

  for (const Instruction& inst : *spvtest::GetBasicBlock(f, 29)) {
    if (inst.opcode() == SpvOp::SpvOpLoad) {
      loads[load_count] = &inst;
      ++load_count;
    }
  }

  EXPECT_EQ(load_count, 1);

  Instruction* load_access_chain =
      context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
  Instruction* load_child = context->get_def_use_mgr()->GetDef(
      load_access_chain->GetSingleWordInOperand(1));

  SENode* load_node = analysis.AnalyzeInstruction(load_child);

  EXPECT_TRUE(load_node);
  EXPECT_EQ(load_node->GetType(), SENode::RecurrentAddExpr);
  EXPECT_TRUE(load_node->AsSERecurrentNode());

  SENode* child_1 = load_node->AsSERecurrentNode()->GetCoefficient();
  SENode* child_2 = load_node->AsSERecurrentNode()->GetOffset();

  EXPECT_EQ(child_1->GetType(), SENode::Constant);
  EXPECT_EQ(child_2->GetType(), SENode::Constant);

  EXPECT_EQ(child_1->AsSEConstantNode()->FoldToSingleValue(), -1);
  EXPECT_EQ(child_2->AsSEConstantNode()->FoldToSingleValue(), 0u);

  SERecurrentNode* load_simplified =
      analysis.SimplifyExpression(load_node)->AsSERecurrentNode();

  EXPECT_TRUE(load_simplified);
  EXPECT_EQ(load_node, load_simplified);

  EXPECT_EQ(load_simplified->GetType(), SENode::RecurrentAddExpr);
  EXPECT_TRUE(load_simplified->AsSERecurrentNode());

  SENode* simplified_child_1 =
      load_simplified->AsSERecurrentNode()->GetCoefficient();
  SENode* simplified_child_2 =
      load_simplified->AsSERecurrentNode()->GetOffset();

  EXPECT_EQ(child_1, simplified_child_1);
  EXPECT_EQ(child_2, simplified_child_2);
}

/*
Generated from the following GLSL + --eliminate-local-multi-store

#version 430
void main(void) {
    for (int i = 0; i < 10; --i) {
        array[i] = array[i];
    }
}

*/

TEST_F(ScalarAnalysisTest, SimplifyInductionsAndLoads) {
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %5 "i"
               OpName %3 "array"
               OpName %4 "N"
               OpDecorate %3 Location 1
               OpDecorate %4 Flat
               OpDecorate %4 Location 2
          %6 = OpTypeVoid
          %7 = OpTypeFunction %6
          %8 = OpTypeInt 32 1
          %9 = OpTypePointer Function %8
         %10 = OpConstant %8 0
         %11 = OpConstant %8 10
         %12 = OpTypeBool
         %13 = OpTypeFloat 32
         %14 = OpTypeInt 32 0
         %15 = OpConstant %14 10
         %16 = OpTypeArray %13 %15
         %17 = OpTypePointer Output %16
          %3 = OpVariable %17 Output
         %18 = OpConstant %8 2
         %19 = OpTypePointer Input %8
          %4 = OpVariable %19 Input
         %20 = OpTypePointer Output %13
         %21 = OpConstant %8 1
          %2 = OpFunction %6 None %7
         %22 = OpLabel
          %5 = OpVariable %9 Function
               OpStore %5 %10
               OpBranch %23
         %23 = OpLabel
         %24 = OpPhi %8 %10 %22 %25 %26
               OpLoopMerge %27 %26 None
               OpBranch %28
         %28 = OpLabel
         %29 = OpSLessThan %12 %24 %11
               OpBranchConditional %29 %30 %27
         %30 = OpLabel
         %31 = OpLoad %8 %4
         %32 = OpIMul %8 %18 %31
         %33 = OpIAdd %8 %24 %32
         %35 = OpIAdd %8 %24 %31
         %36 = OpAccessChain %20 %3 %35
         %37 = OpLoad %13 %36
         %38 = OpAccessChain %20 %3 %33
               OpStore %38 %37
         %39 = OpIMul %8 %18 %24
         %41 = OpIMul %8 %18 %31
         %42 = OpIAdd %8 %39 %41
         %43 = OpIAdd %8 %42 %21
         %44 = OpIMul %8 %18 %24
         %46 = OpIAdd %8 %44 %31
         %47 = OpIAdd %8 %46 %21
         %48 = OpAccessChain %20 %3 %47
         %49 = OpLoad %13 %48
         %50 = OpAccessChain %20 %3 %43
               OpStore %50 %49
               OpBranch %26
         %26 = OpLabel
         %25 = OpISub %8 %24 %21
               OpStore %5 %25
               OpBranch %23
         %27 = OpLabel
               OpReturn
               OpFunctionEnd
    )";
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  Module* module = context->module();
  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ScalarEvolutionAnalysis analysis{context.get()};

  std::vector<const Instruction*> loads{};
  std::vector<const Instruction*> stores{};

  for (const Instruction& inst : *spvtest::GetBasicBlock(f, 30)) {
    if (inst.opcode() == SpvOp::SpvOpLoad) {
      loads.push_back(&inst);
    }
    if (inst.opcode() == SpvOp::SpvOpStore) {
      stores.push_back(&inst);
    }
  }

  EXPECT_EQ(loads.size(), 3u);
  EXPECT_EQ(stores.size(), 2u);
  {
    Instruction* store_access_chain = context->get_def_use_mgr()->GetDef(
        stores[0]->GetSingleWordInOperand(0));

    Instruction* store_child = context->get_def_use_mgr()->GetDef(
        store_access_chain->GetSingleWordInOperand(1));

    SENode* store_node = analysis.AnalyzeInstruction(store_child);

    SENode* store_simplified = analysis.SimplifyExpression(store_node);

    Instruction* load_access_chain =
        context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));

    Instruction* load_child = context->get_def_use_mgr()->GetDef(
        load_access_chain->GetSingleWordInOperand(1));

    SENode* load_node = analysis.AnalyzeInstruction(load_child);

    SENode* load_simplified = analysis.SimplifyExpression(load_node);

    SENode* difference =
        analysis.CreateSubtraction(store_simplified, load_simplified);

    SENode* difference_simplified = analysis.SimplifyExpression(difference);

    // Check that i+2*N  -  i*N, turns into just N when both sides have already
    // been simplified into a single recurrent expression.
    EXPECT_EQ(difference_simplified->GetType(), SENode::ValueUnknown);

    // Check that the inverse, i*N - i+2*N turns into -N.
    SENode* difference_inverse = analysis.SimplifyExpression(
        analysis.CreateSubtraction(load_simplified, store_simplified));

    EXPECT_EQ(difference_inverse->GetType(), SENode::Negative);
    EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), SENode::ValueUnknown);
    EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified);
  }

  {
    Instruction* store_access_chain = context->get_def_use_mgr()->GetDef(
        stores[1]->GetSingleWordInOperand(0));

    Instruction* store_child = context->get_def_use_mgr()->GetDef(
        store_access_chain->GetSingleWordInOperand(1));
    SENode* store_node = analysis.AnalyzeInstruction(store_child);
    SENode* store_simplified = analysis.SimplifyExpression(store_node);

    Instruction* load_access_chain =
        context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0));

    Instruction* load_child = context->get_def_use_mgr()->GetDef(
        load_access_chain->GetSingleWordInOperand(1));

    SENode* load_node = analysis.AnalyzeInstruction(load_child);

    SENode* load_simplified = analysis.SimplifyExpression(load_node);

    SENode* difference =
        analysis.CreateSubtraction(store_simplified, load_simplified);
    SENode* difference_simplified = analysis.SimplifyExpression(difference);

    // Check that 2*i + 2*N + 1  -  2*i + N + 1, turns into just N when both
    // sides have already been simplified into a single recurrent expression.
    EXPECT_EQ(difference_simplified->GetType(), SENode::ValueUnknown);

    // Check that the inverse, (2*i + N + 1)  -  (2*i + 2*N + 1) turns into -N.
    SENode* difference_inverse = analysis.SimplifyExpression(
        analysis.CreateSubtraction(load_simplified, store_simplified));

    EXPECT_EQ(difference_inverse->GetType(), SENode::Negative);
    EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), SENode::ValueUnknown);
    EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified);
  }
}

/* Generated from the following GLSL + --eliminate-local-multi-store

  #version 430
  layout(location = 1) out float array[10];
  layout(location = 2) flat in int N;
  void main(void) {
    int step = 0;
    for (int i = 0; i < N; i += step) {
      step++;
    }
  }
*/
TEST_F(ScalarAnalysisTest, InductionWithVariantStep) {
  const std::string text = R"(
               OpCapability Shader
          %1 = OpExtInstImport "GLSL.std.450"
               OpMemoryModel Logical GLSL450
               OpEntryPoint Fragment %2 "main" %3 %4
               OpExecutionMode %2 OriginUpperLeft
               OpSource GLSL 430
               OpName %2 "main"
               OpName %5 "step"
               OpName %6 "i"
               OpName %3 "N"
               OpName %4 "array"
               OpDecorate %3 Flat
               OpDecorate %3 Location 2
               OpDecorate %4 Location 1
          %7 = OpTypeVoid
          %8 = OpTypeFunction %7
          %9 = OpTypeInt 32 1
         %10 = OpTypePointer Function %9
         %11 = OpConstant %9 0
         %12 = OpTypePointer Input %9
          %3 = OpVariable %12 Input
         %13 = OpTypeBool
         %14 = OpConstant %9 1
         %15 = OpTypeFloat 32
         %16 = OpTypeInt 32 0
         %17 = OpConstant %16 10
         %18 = OpTypeArray %15 %17
         %19 = OpTypePointer Output %18
          %4 = OpVariable %19 Output
          %2 = OpFunction %7 None %8
         %20 = OpLabel
          %5 = OpVariable %10 Function
          %6 = OpVariable %10 Function
               OpStore %5 %11
               OpStore %6 %11
               OpBranch %21
         %21 = OpLabel
         %22 = OpPhi %9 %11 %20 %23 %24
         %25 = OpPhi %9 %11 %20 %26 %24
               OpLoopMerge %27 %24 None
               OpBranch %28
         %28 = OpLabel
         %29 = OpLoad %9 %3
         %30 = OpSLessThan %13 %25 %29
               OpBranchConditional %30 %31 %27
         %31 = OpLabel
         %23 = OpIAdd %9 %22 %14
               OpStore %5 %23
               OpBranch %24
         %24 = OpLabel
         %26 = OpIAdd %9 %25 %23
               OpStore %6 %26
               OpBranch %21
         %27 = OpLabel
               OpReturn
               OpFunctionEnd
  )";
  std::unique_ptr<IRContext> context =
      BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
                  SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  Module* module = context->module();
  EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
                             << text << std::endl;
  const Function* f = spvtest::GetFunction(module, 2);
  ScalarEvolutionAnalysis analysis{context.get()};

  std::vector<const Instruction*> phis{};

  for (const Instruction& inst : *spvtest::GetBasicBlock(f, 21)) {
    if (inst.opcode() == SpvOp::SpvOpPhi) {
      phis.push_back(&inst);
    }
  }

  EXPECT_EQ(phis.size(), 2u);
  SENode* phi_node_1 = analysis.AnalyzeInstruction(phis[0]);
  SENode* phi_node_2 = analysis.AnalyzeInstruction(phis[1]);
  phi_node_1->DumpDot(std::cout, true);
  EXPECT_NE(phi_node_1, nullptr);
  EXPECT_NE(phi_node_2, nullptr);

  EXPECT_EQ(phi_node_1->GetType(), SENode::RecurrentAddExpr);
  EXPECT_EQ(phi_node_2->GetType(), SENode::CanNotCompute);

  SENode* simplified_1 = analysis.SimplifyExpression(phi_node_1);
  SENode* simplified_2 = analysis.SimplifyExpression(phi_node_2);

  EXPECT_EQ(simplified_1->GetType(), SENode::RecurrentAddExpr);
  EXPECT_EQ(simplified_2->GetType(), SENode::CanNotCompute);
}

}  // namespace
}  // namespace opt
}  // namespace spvtools