// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "eight_bit_int_gemm.h"

#include <memory>

// gemmlowp symbols should have hidden visibility.
// currently this is ensured in the build system by
// passing -finlines-visibility-hidden. TODO: it would be
// safer to hardcode it here with some #pragma's.
#include "../public/gemmlowp.h"

// Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
// code. This code path consists of a number of meta-programmed, automatically
// generated GEMM kernels that are suitable for some sizes of input matrices.
// Due to the fact that the generated code relies heavily on loop unrolling,
// inling and currying of runtime parameters the size of the generated binary
// is quite significant (approx. 200kb) which might be prohibitive in
// low-memory situations.

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
#include "../meta/multi_thread_gemm.h"
#endif

namespace gemmlowp {
namespace eight_bit_int_gemm {
namespace {

// To be used as template parameter for GlobalLock.
// GlobalLock<EightBitIntGemmLockId> is the global lock
// on EightBitIntGemm entry points, protecting
// EightBitIntGemm's global state.
struct EightBitIntGemmLockId;

// Global state: consists of one global GemmContext instance.
GemmContext* global_context;

GemmContext* GetOrCreateGlobalContext() {
  if (!global_context) {
    global_context = new GemmContext;
  }
  return global_context;
}

void DestroyGlobalContext() {
  delete global_context;
  global_context = nullptr;
}

template <bool transpose_a, bool transpose_b, bool transpose_c>
void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k,
                         const std::uint8_t* a, std::int32_t a_offset, int lda,
                         const std::uint8_t* b, std::int32_t b_offset, int ldb,
                         std::uint8_t* c, std::int32_t c_offset,
                         std::int32_t c_mult_int, std::int32_t c_shift, int ldc,
                         BitDepthSetting bit_depth) {
  const int lhs_offset = a_offset;
  const int rhs_offset = b_offset;
  const int result_offset = c_offset;
  const int result_mult_int = c_mult_int;
  const int result_shift = c_shift;

  static const MapOrder ResultOrder =
      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder LhsOrder =
      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder RhsOrder =
      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;

  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
  MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc);

  switch (bit_depth) {
#define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS)     \
  case BitDepthSetting::BIT_DEPTH_SETTING:                                 \
    Gemm<std::uint8_t, BIT_DEPTH_PARAMS>(                                  \
        context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \
        result_mult_int, result_shift);                                    \
    return;
    GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams)
    GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams)
    default:
      abort();
#undef GEMMLOWP_HANDLE_BIT_DEPTH
  }
}

template <bool transpose_a, bool transpose_b, bool transpose_c>
void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k,
                              const std::uint8_t* a, std::int32_t a_offset,
                              int lda, const std::uint8_t* b,
                              std::int32_t b_offset, int ldb, std::int32_t* c,
                              int ldc, BitDepthSetting bit_depth) {
  const int lhs_offset = a_offset;
  const int rhs_offset = b_offset;

  static const MapOrder ResultOrder =
      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder LhsOrder =
      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
  static const MapOrder RhsOrder =
      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;

  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
  MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc);

  auto empty_pipeline = std::make_tuple();

  switch (bit_depth) {
#define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
  case BitDepthSetting::BIT_DEPTH_SETTING:                                   \
    GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>(    \
        context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \
    return;
    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams)
    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams)
    default:
      abort();
#undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32
  }
}

class Scratch {
 public:
  Scratch() : buffer_(), size_(0) {}

  void AssureSize(std::int32_t required_size) {
    if (size_ >= required_size) {
      return;
    }
    buffer_.reset(new std::uint8_t[required_size]);
    size_ = required_size;
  }

  void Clear() {
    buffer_.reset(nullptr);
    size_ = 0;
  }

  std::uint8_t* buffer() { return buffer_.get(); }

 private:
  std::unique_ptr<std::uint8_t[]> buffer_;
  std::int32_t size_;
};

Scratch* global_scratch = nullptr;

Scratch* GetOrCreateGlobalScratch() {
  if (global_scratch == nullptr) {
    global_scratch = new Scratch();
  }
  return global_scratch;
}

void DestroyGlobalScratch() {
  delete global_scratch;
  global_scratch = nullptr;
}

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)

bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) {
  // Is it row major and nicely packed?
  if (transpose && stride == cols) {
    return true;
  }

  // Is it a one row vector? (a vector is both row and column major)
  if (rows == 1) {
    return true;
  }

  return false;
}

bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) {
  // Is it column major and nicely packed?
  if (!transpose && stride == rows) {
    return true;
  }

  // Is it a one column vector? (a vector is both row and column major)
  if (cols == 1) {
    return true;
  }

  return false;
}

bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c,
                           int m, int n, int k, int lda, int ldb, int ldc,
                           BitDepthSetting depth_setting) {
  // Meta fastpath only supports 8bit x 8bit and k up to 2048.
  if (depth_setting != BitDepthSetting::A8B8 || k > 2048) {
    return false;
  }

  // The first operand needs to be a row major matrix or a vector.
  if (!IsRowMajorOrVector(transpose_a, lda, m, k)) {
    return false;
  }

  // The second operand needs to be a column major matrix or a vector.
  if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) {
    return false;
  }

  // The result can either be a row major matrix, a column major matrix or
  // a vector.
  if (IsRowMajorOrVector(transpose_c, ldc, m, n)) {
    return true;
  }

  if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) {
    return true;
  }

  return false;
}

// Assure enough scratch memory is allocated and run the fast path gemm.
void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, int m, int n, int k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t sum_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, bool result_transpose,
                           std::int32_t result_stride, std::uint8_t* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    scratch->AssureSize(
        meta::gemm_q8_scratch(m, n, k, context->max_num_threads()));
    meta::multi_thread_gemm_q8(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        lhs, rhs, m, n, k, lhs_offset, rhs_offset, sum_offset,
        multiplicative_offset, shift, result);
  } else {
    scratch->AssureSize(
        meta::gemm_q8_scratch(n, m, k, context->max_num_threads()));
    meta::multi_thread_gemm_q8(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        rhs, lhs, n, m, k, rhs_offset, lhs_offset, sum_offset,
        multiplicative_offset, shift, result);
  }
}

// Assure enough scratch memory is allocated and run the 8bit to float fast
// path gemm.
void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, int m, int n, int k,
                   std::int32_t lhs_offset, std::int32_t rhs_offset,
                   float result_offset, bool result_transpose,
                   std::int32_t result_stride, float* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    scratch->AssureSize(
        meta::gemm_f_scratch(m, n, k, context->max_num_threads()));
    meta::multi_thread_gemm_f(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset, result);
  } else {
    scratch->AssureSize(
        meta::gemm_f_scratch(n, m, k, context->max_num_threads()));
    meta::multi_thread_gemm_f(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        rhs, lhs, n, m, k, rhs_offset, lhs_offset, result_offset, result);
  }
}

#endif

}  // end anonymous namespace

// Public interface entry points

void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, int lda, const std::uint8_t* b,
                     std::int32_t b_offset, int ldb, std::uint8_t* c,
                     std::int32_t c_offset, std::int32_t c_mult_int,
                     std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                          c_mult_int, c_shift, transpose_c, ldc, c);
    return;
  }
#endif

#define GEMMLOWP_HANDLE_CASE(ta, tb, tc)                                    \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {        \
    EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b,  \
                                    b_offset, ldb, c, c_offset, c_mult_int, \
                                    c_shift, ldc, bit_depth);               \
  }

  GEMMLOWP_HANDLE_CASE(false, false, false)
  GEMMLOWP_HANDLE_CASE(false, false, true)
  GEMMLOWP_HANDLE_CASE(false, true, false)
  GEMMLOWP_HANDLE_CASE(false, true, true)
  GEMMLOWP_HANDLE_CASE(true, false, false)
  GEMMLOWP_HANDLE_CASE(true, false, true)
  GEMMLOWP_HANDLE_CASE(true, true, false)
  GEMMLOWP_HANDLE_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_CASE
}

void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, std::int32_t lda,
                     const std::uint8_t* b, std::int32_t b_offset,
                     std::int32_t ldb, float* c, float c_offset,
                     std::int32_t ldc, BitDepthSetting bit_depth) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                  transpose_c, ldc, c);
    return;
  }
#endif

  // TODO(maciekc): implement a float output stage, get rid of scratch memory.
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (transpose_c) {
    scratch->AssureSize(m * ldc * sizeof(std::int32_t));
  } else {
    scratch->AssureSize(n * ldc * sizeof(std::int32_t));
  }
  std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer());

#define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc)                               \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {         \
    EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \
                                         b, b_offset, ldb, temp_c, ldc,      \
                                         bit_depth);                         \
  }

  GEMMLOWP_HANDLE_INT32_CASE(false, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_INT32_CASE

  if (transpose_c) {
    // Row major.
    for (int i = 0; i < m; ++i) {
      float* dest_row = c + i * ldc;
      std::int32_t* src_row = temp_c + i * ldc;
      for (int j = 0; j < n; ++j) {
        dest_row[j] = static_cast<float>(src_row[j]) * c_offset;
      }
    }
  } else {
    // Column major.
    for (int i = 0; i < n; ++i) {
      float* dest_column = c + i * ldc;
      std::int32_t* src_column = temp_c + i * ldc;
      for (int j = 0; j < m; ++j) {
        dest_column[j] = static_cast<float>(src_column[j]) * c_offset;
      }
    }
  }
}

void SetMaxNumThreads(int n) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();
  context->set_max_num_threads(n);
}

void FreePersistentResources() {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  DestroyGlobalContext();
  DestroyGlobalScratch();
}

}  // namespace eight_bit_int_gemm
}  // namespace gemmlowp