// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/opt/licm_pass.h"

#include <queue>
#include <utility>

#include "source/opt/module.h"
#include "source/opt/pass.h"

namespace spvtools {
namespace opt {

// Runs the pass by delegating to ProcessIRContext() and returning its status
// unchanged.
Pass::Status LICMPass::Process() { return ProcessIRContext(); }

// Applies ProcessFunction to the functions of the current module, in the
// module's iteration order, folding each result into a single status with
// CombineStatus. Iteration stops as soon as the folded status is
// Status::Failure. Returns the folded status (SuccessWithoutChange when the
// module has no functions).
Pass::Status LICMPass::ProcessIRContext() {
  Status combined = Status::SuccessWithoutChange;

  for (auto& function : *get_module()) {
    combined = CombineStatus(combined, ProcessFunction(&function));
    if (combined == Status::Failure) {
      // No point visiting further functions once the pass has failed.
      break;
    }
  }

  return combined;
}

// Runs ProcessLoop on every loop of |f| that is not nested inside another
// loop; nested loops are reached recursively from ProcessLoop. Per-loop
// results are folded with CombineStatus and iteration stops once the folded
// status is Status::Failure. Returns the folded status.
Pass::Status LICMPass::ProcessFunction(Function* f) {
  LoopDescriptor* loops = context()->GetLoopDescriptor(f);
  Status combined = Status::SuccessWithoutChange;

  for (auto loop_it = loops->begin();
       combined != Status::Failure && loop_it != loops->end(); ++loop_it) {
    Loop& candidate = *loop_it;
    // Only outermost loops are entry points here; ProcessLoop descends into
    // their nested loops itself.
    if (!candidate.IsNested()) {
      combined = CombineStatus(combined, ProcessLoop(&candidate, f));
    }
  }

  return combined;
}

// Hoists invariant instructions out of |loop| (a loop of function |f|) after
// first doing the same, recursively, for each loop nested directly in it.
//
// Blocks of |loop| are visited starting from the loop header and then through
// a worklist that AnalyseAndHoistFromBB extends with the in-loop
// dominator-tree children of each visited block.
//
// Returns the CombineStatus-fold of every step. Processing stops at the first
// step that leaves the folded status at Status::Failure; in particular, no
// hoisting is attempted in |loop| itself when a nested loop has already
// failed.
Pass::Status LICMPass::ProcessLoop(Loop* loop, Function* f) {
  Status status = Status::SuccessWithoutChange;

  // Process all nested loops first.
  for (auto nl = loop->begin(); nl != loop->end() && status != Status::Failure;
       ++nl) {
    status = CombineStatus(status, ProcessLoop(*nl, f));
  }

  // A nested loop failed: bail out instead of continuing to move instructions
  // around in the enclosing loop. This matches the failure guards on the
  // loops above and below.
  if (status == Status::Failure) {
    return status;
  }

  // Worklist of blocks still to analyse; seeded by the header-block call.
  std::vector<BasicBlock*> loop_bbs{};
  status = CombineStatus(
      status,
      AnalyseAndHoistFromBB(loop, f, loop->GetHeaderBlock(), &loop_bbs));

  // Index-based on purpose: AnalyseAndHoistFromBB appends to |loop_bbs| while
  // we walk it, which would invalidate iterators. Elements are never removed.
  for (size_t i = 0; i < loop_bbs.size() && status != Status::Failure; ++i) {
    BasicBlock* bb = loop_bbs[i];
    status =
        CombineStatus(status, AnalyseAndHoistFromBB(loop, f, bb, &loop_bbs));
  }

  return status;
}

// Analyses basic block |bb| of function |f| on behalf of |loop|.
//
// If |bb| is immediately contained in |loop| (see
// IsImmediatelyContainedInLoop), every instruction visited by
// BasicBlock::WhileEachInst for which Loop::ShouldHoistInstruction returns
// true is passed to HoistInstruction. Afterwards, each dominator-tree child of
// |bb| that lies inside |loop| is appended to |loop_bbs| so the caller can
// visit it later.
//
// Returns Status::Failure if the instruction walk reports failure (in which
// case |loop_bbs| is left untouched), Status::SuccessWithChange if at least
// one instruction was hoisted, and Status::SuccessWithoutChange otherwise.
Pass::Status LICMPass::AnalyseAndHoistFromBB(
    Loop* loop, Function* f, BasicBlock* bb,
    std::vector<BasicBlock*>* loop_bbs) {
  bool hoisted_any = false;

  // Per-instruction callback. Returning false signals a failed hoist;
  // presumably WhileEachInst stops at the first false -- confirm against
  // BasicBlock::WhileEachInst.
  std::function<bool(Instruction*)> try_hoist =
      [this, loop, &hoisted_any](Instruction* candidate) {
        if (!loop->ShouldHoistInstruction(this->context(), candidate)) {
          return true;
        }
        if (!HoistInstruction(loop, candidate)) {
          return false;
        }
        hoisted_any = true;
        return true;
      };

  // Blocks that belong to a different loop are not hoisted from here; they
  // are still expanded below so that the walk reaches every block in |loop|.
  if (IsImmediatelyContainedInLoop(loop, f, bb) &&
      !bb->WhileEachInst(try_hoist, false)) {
    return Status::Failure;
  }

  // NOTE(review): keep this query after the hoisting step above --
  // HoistInstruction may create a pre-header block, which presumably changes
  // the CFG the dominator analysis is built from.
  DominatorTree& dom_tree = context()->GetDominatorAnalysis(f)->GetDomTree();

  for (DominatorTreeNode* dom_child : *dom_tree.GetTreeNode(bb)) {
    auto* const child_bb = dom_child->bb_;
    if (loop->IsInsideLoop(child_bb)) {
      loop_bbs->push_back(child_bb);
    }
  }

  if (hoisted_any) {
    return Status::SuccessWithChange;
  }
  return Status::SuccessWithoutChange;
}

// Returns true when the loop descriptor of |f| maps the id of |bb| to exactly
// |loop|. Presumably the descriptor records the innermost loop containing each
// block, so blocks of loops nested in |loop| yield false -- confirm against
// LoopDescriptor::operator[].
bool LICMPass::IsImmediatelyContainedInLoop(Loop* loop, Function* f,
                                            BasicBlock* bb) {
  LoopDescriptor& descriptor = *context()->GetLoopDescriptor(f);
  const auto owning_loop = descriptor[bb->id()];
  return loop == owning_loop;
}

// Moves |inst| into the pre-header block of |loop| and records that block as
// the instruction's owner in the IR context.
//
// Returns false, without touching |inst|, when no pre-header block could be
// obtained; returns true once the instruction has been moved.
bool LICMPass::HoistInstruction(Loop* loop, Instruction* inst) {
  // TODO(1841): pre-header creation can fail; that failure is currently only
  // reported to the caller through the `false` return value.
  BasicBlock* const target_bb = loop->GetOrCreatePreHeaderBlock();
  if (target_bb == nullptr) {
    return false;
  }

  // By default insert just before the last instruction of the pre-header.
  Instruction* insert_before = &*target_bb->tail();

  // If that last instruction is preceded by a merge instruction, insert
  // before the merge instead so the two stay adjacent.
  if (Instruction* const prev = insert_before->PreviousNode()) {
    const auto prev_opcode = prev->opcode();
    if (prev_opcode == SpvOpLoopMerge || prev_opcode == SpvOpSelectionMerge) {
      insert_before = prev;
    }
  }

  inst->InsertBefore(insert_before);
  context()->set_instr_block(inst, target_bb);
  return true;
}

}  // namespace opt
}  // namespace spvtools