// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// simd_wrappers.h: some inline functions wrapping SIMD intrinsics,
// extending the set of such functions from fixedpoint.h.

#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_
#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_

#include <algorithm>
#include <type_traits>
#include "../fixedpoint/fixedpoint.h"

namespace gemmlowp {

// Maps a (scalar type, scalar count) pair to the type of a single register.
// Unspecialized, a "register" is just one scalar. The platform headers
// included at the end of this file presumably specialize it for SIMD vector
// types (TODO confirm against simd_wrappers_*.h).
template <typename tScalar, int tCount>
struct RegisterType {
  typedef tScalar Type;
};

// Scalar minimum of two 32-bit integers.
inline std::int32_t Min(std::int32_t a, std::int32_t b) {
  return (b < a) ? b : a;
}

// Scalar maximum of two 32-bit integers.
inline std::int32_t Max(std::int32_t a, std::int32_t b) {
  return (a < b) ? b : a;
}

// Scalar multiply-accumulate: *acc += lhs * rhs.
// The product is computed in 32-bit signed arithmetic, so callers must keep
// it in range (signed overflow is undefined behavior).
inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) {
  const std::int32_t product = lhs * rhs;
  *acc = *acc + product;
}

// Storage for kScalarCount scalars of type ScalarType, held as an array of
// registers of type RegisterType<ScalarType, kScalarCount>::Type. Each
// register holds kRegisterLanes scalars; with the unspecialized RegisterType
// that is exactly one.
template <typename tScalarType, int tScalarCount>
struct RegisterBuffer {
  using ScalarType = tScalarType;
  static constexpr int kScalarCount = tScalarCount;
  using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type;
  static_assert(!(kScalarCount & (kScalarCount - 1)),
                "kScalarCount must be a power of two");
  static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, "");
  // Number of scalars packed into one register.
  static constexpr int kRegisterLanes =
      sizeof(RegisterType) / sizeof(ScalarType);
  // Number of registers needed for kScalarCount scalars: the total byte size
  // divided by the register byte size, rounded up.
  static constexpr int kRegisterCount =
      (sizeof(ScalarType) * kScalarCount + sizeof(RegisterType) - 1) /
      sizeof(RegisterType);

  RegisterType reg[kRegisterCount];
};

// A kRows x kCols block of scalars backed by a RegisterBuffer holding
// kRows * kCols scalars. The broadcasting helpers in this file index it
// column-major (element (r, c) at position r + c * kRows).
template <typename tScalarType, int tRows, int tCols>
struct RegisterBlock {
  using ScalarType = tScalarType;
  static constexpr int kRows = tRows;
  static constexpr int kCols = tCols;
  static constexpr int kScalarCount = tRows * tCols;
  using BufferType = RegisterBuffer<ScalarType, kScalarCount>;
  using RegisterType = typename BufferType::RegisterType;
  static constexpr int kRegisterLanes = BufferType::kRegisterLanes;
  static constexpr int kRegisterCount = BufferType::kRegisterCount;

  BufferType buf;
};

// Register-by-register sum of two blocks of the same type, using Add() on
// each pair of registers.
template <typename RegisterBlockType>
struct RegisterBlockAddImpl {
  static RegisterBlockType Run(const RegisterBlockType& lhs,
                               const RegisterBlockType& rhs) {
    RegisterBlockType sum;
    int idx = 0;
    while (idx < RegisterBlockType::kRegisterCount) {
      sum.buf.reg[idx] = Add(lhs.buf.reg[idx], rhs.buf.reg[idx]);
      ++idx;
    }
    return sum;
  }
};

template <typename RegisterBlockType>
RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs,
                                   const RegisterBlockType& rhs) {
  return RegisterBlockAddImpl<RegisterBlockType>::Run(lhs, rhs);
}

// kValue is true when the operands of a broadcasting binary op should be
// swapped. That happens when RhsType holds more scalars than LhsType, or
// holds the same number of scalars but has more rows.
template <typename LhsType, typename RhsType>
struct ShouldFlipLhsRhs {
  static constexpr bool kValue =
      LhsType::kScalarCount != RhsType::kScalarCount
          ? LhsType::kScalarCount < RhsType::kScalarCount
          : LhsType::kRows < RhsType::kRows;
};

// Reorders the operands of a broadcasting binary op according to
// ShouldFlipLhsRhs. Callers then pass the "flipped" operands to the Impl
// templates, presumably so those only have to handle one operand ordering.
//
// Primary template (Flip == false): operands are kept in their given order.
template <typename LhsType, typename RhsType,
          bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue>
struct FlipLhsRhs {
  using FlippedLhsType = LhsType;
  using FlippedRhsType = RhsType;
  static const FlippedLhsType& FlippedLhs(const LhsType& lhs,
                                          const RhsType& /* rhs */) {
    return lhs;
  }
  static const FlippedRhsType& FlippedRhs(const LhsType& /* lhs */,
                                          const RhsType& rhs) {
    return rhs;
  }
};

// Flip == true: the two operands (and their types) trade places.
template <typename LhsType, typename RhsType>
struct FlipLhsRhs<LhsType, RhsType, true> {
  using FlippedLhsType = RhsType;
  using FlippedRhsType = LhsType;
  static const FlippedLhsType& FlippedLhs(const LhsType& /* lhs */,
                                          const RhsType& rhs) {
    return rhs;
  }
  static const FlippedRhsType& FlippedRhs(const LhsType& lhs,
                                          const RhsType& /* rhs */) {
    return lhs;
  }
};

// Result shape of a broadcasting binary op: the larger of the two operands'
// row counts and the larger of their column counts.
template <typename Lhs, typename Rhs>
struct BroadcastBinaryOpShape {
  static constexpr int kRows =
      (Lhs::kRows < Rhs::kRows) ? Rhs::kRows : Lhs::kRows;
  static constexpr int kCols =
      (Lhs::kCols < Rhs::kCols) ? Rhs::kCols : Lhs::kCols;
};

// Register block type produced by a broadcasting binary op on Lhs and Rhs.
// It uses Lhs's scalar type and the broadcast shape of the two operands.
template <typename Lhs, typename Rhs>
struct BroadcastBinaryOpRegisterBlock {
  using ScalarType = typename Lhs::ScalarType;
  using Shape = BroadcastBinaryOpShape<Lhs, Rhs>;
  using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
};

// Elementwise Add() of two register blocks with broadcasting. Each operand
// dimension must equal the result dimension or be 1; a dimension of 1 means
// that single row/column is reused. This is the scalar path only (one scalar
// per register entry). All blocks are indexed column-major: r + c * rows.
template <typename Lhs, typename Rhs>
struct BroadcastAddImpl {
  using ResultBlockType =
      typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
  static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
    static constexpr int Rows = ResultBlockType::kRows;
    static constexpr int Cols = ResultBlockType::kCols;
    static constexpr int LhsRows = Lhs::kRows;
    static constexpr int LhsCols = Lhs::kCols;
    static constexpr int RhsRows = Rhs::kRows;
    static constexpr int RhsCols = Rhs::kCols;

    static_assert(ResultBlockType::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Lhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Rhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(LhsRows == Rows || LhsRows == 1, "");
    static_assert(RhsRows == Rows || RhsRows == 1, "");
    static_assert(LhsCols == Cols || LhsCols == 1, "");
    static_assert(RhsCols == Cols || RhsCols == 1, "");

    ResultBlockType sum;
    for (int col = 0; col < Cols; ++col) {
      // Offset of the operand column feeding result column `col`. It is
      // column 0 when that operand is broadcast along columns.
      const int lhs_base = (LhsCols == Cols ? col : 0) * LhsRows;
      const int rhs_base = (RhsCols == Cols ? col : 0) * RhsRows;
      for (int row = 0; row < Rows; ++row) {
        const int lhs_index = lhs_base + (LhsRows == Rows ? row : 0);
        const int rhs_index = rhs_base + (RhsRows == Rows ? row : 0);
        sum.buf.reg[row + col * Rows] =
            Add(lhs.buf.reg[lhs_index], rhs.buf.reg[rhs_index]);
      }
    }
    return sum;
  }
};

template <typename Lhs, typename Rhs>
typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd(
    const Lhs& lhs, const Rhs& rhs) {
  using Flip = FlipLhsRhs<Lhs, Rhs>;
  return BroadcastAddImpl<
      typename Flip::FlippedLhsType,
      typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
                                          Flip::FlippedRhs(lhs, rhs));
}

// Elementwise Mul() of two register blocks with broadcasting. Each operand
// dimension must equal the result dimension or be 1; a dimension of 1 means
// that single row/column is reused. This is the scalar path only (one scalar
// per register entry). All blocks are indexed column-major: r + c * rows.
template <typename Lhs, typename Rhs>
struct BroadcastMulImpl {
  using ResultBlockType =
      typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type;
  static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) {
    static constexpr int Rows = ResultBlockType::kRows;
    static constexpr int Cols = ResultBlockType::kCols;
    static constexpr int LhsRows = Lhs::kRows;
    static constexpr int LhsCols = Lhs::kCols;
    static constexpr int RhsRows = Rhs::kRows;
    static constexpr int RhsCols = Rhs::kCols;

    static_assert(ResultBlockType::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Lhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Rhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(LhsRows == Rows || LhsRows == 1, "");
    static_assert(RhsRows == Rows || RhsRows == 1, "");
    static_assert(LhsCols == Cols || LhsCols == 1, "");
    static_assert(RhsCols == Cols || RhsCols == 1, "");

    ResultBlockType product;
    for (int col = 0; col < Cols; ++col) {
      // Offset of the operand column feeding result column `col`. It is
      // column 0 when that operand is broadcast along columns.
      const int lhs_base = (LhsCols == Cols ? col : 0) * LhsRows;
      const int rhs_base = (RhsCols == Cols ? col : 0) * RhsRows;
      for (int row = 0; row < Rows; ++row) {
        const int lhs_index = lhs_base + (LhsRows == Rows ? row : 0);
        const int rhs_index = rhs_base + (RhsRows == Rows ? row : 0);
        product.buf.reg[row + col * Rows] =
            Mul(lhs.buf.reg[lhs_index], rhs.buf.reg[rhs_index]);
      }
    }
    return product;
  }
};

template <typename Lhs, typename Rhs>
typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul(
    const Lhs& lhs, const Rhs& rhs) {
  using Flip = FlipLhsRhs<Lhs, Rhs>;
  return BroadcastMulImpl<
      typename Flip::FlippedLhsType,
      typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs),
                                          Flip::FlippedRhs(lhs, rhs));
}

// Accumulates the broadcast elementwise product lhs * rhs into *acc, using
// MulAdd() per element. The accumulator's shape defines the iteration space.
// Each operand dimension must match it or be 1. This is the scalar path only;
// blocks are indexed column-major.
template <typename Lhs, typename Rhs, typename Acc>
struct BroadcastMulAddImpl {
  static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
    static constexpr int Rows = Acc::kRows;
    static constexpr int Cols = Acc::kCols;
    static constexpr int LhsRows = Lhs::kRows;
    static constexpr int LhsCols = Lhs::kCols;
    static constexpr int RhsRows = Rhs::kRows;
    static constexpr int RhsCols = Rhs::kCols;

    static_assert(Acc::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Lhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(Rhs::kRegisterLanes == 1,
                  "This path is only for scalar values");
    static_assert(LhsRows == Rows || LhsRows == 1, "");
    static_assert(RhsRows == Rows || RhsRows == 1, "");
    static_assert(LhsCols == Cols || LhsCols == 1, "");
    static_assert(RhsCols == Cols || RhsCols == 1, "");

    for (int col = 0; col < Cols; ++col) {
      // Offset of the operand column feeding accumulator column `col`.
      const int lhs_base = (LhsCols == Cols ? col : 0) * LhsRows;
      const int rhs_base = (RhsCols == Cols ? col : 0) * RhsRows;
      for (int row = 0; row < Rows; ++row) {
        const int lhs_index = lhs_base + (LhsRows == Rows ? row : 0);
        const int rhs_index = rhs_base + (RhsRows == Rows ? row : 0);
        MulAdd(lhs.buf.reg[lhs_index], rhs.buf.reg[rhs_index],
               &acc->buf.reg[row + col * Rows]);
      }
    }
  }
};

template <typename Lhs, typename Rhs, typename Acc>
void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) {
  using Flip = FlipLhsRhs<Lhs, Rhs>;
  BroadcastMulAddImpl<typename Flip::FlippedLhsType,
                      typename Flip::FlippedRhsType,
                      Acc>::Run(Flip::FlippedLhs(lhs, rhs),
                                Flip::FlippedRhs(lhs, rhs), acc);
}

// Primary template, deliberately unusable. Supported (block, source)
// combinations are handled by partial specializations. Instantiating this
// one for any real source type trips the static_assert.
template <typename RegisterBlockType, typename SrcObjectType>
struct LoadImpl {
  static_assert(std::is_same<SrcObjectType, void>::value,
                "This generic impl should never be hit");
};

// Loads a Rows x Cols register block from a column-major MatrixMap, starting
// at (row, col). It walks `Rows` consecutive elements from the start of each
// of the `Cols` source columns, so the block is filled column by column.
// Each source element is converted to ScalarType on assignment.
template <typename ScalarType, int Rows, int Cols, typename SrcScalarType>
struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
                MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
  using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>;
  static RegisterBlockType Run(const SrcObjectType& src, int row, int col) {
    // This generic path stores one scalar per register entry (it writes
    // Rows * Cols entries of buf.reg). Wider register types need a dedicated
    // specialization.
    static_assert(RegisterBlockType::kRegisterLanes == 1,
                  "This path is only for scalar values");
    RegisterBlockType result;
    int i = 0;
    for (int c = 0; c < Cols; c++) {
      // The pointer must have the *source* element type. Declaring it as
      // ScalarType* fails to compile whenever SrcScalarType is a different
      // type from ScalarType.
      const SrcScalarType* src_ptr = src.data(row, col + c);
      for (int r = 0; r < Rows; r++) {
        result.buf.reg[i++] = *src_ptr++;
      }
    }
    return result;
  }
};

// Loads Rows * Cols consecutive entries of a VectorMap, starting at `pos`,
// into a register block. A Col-shaped vector must be loaded into a column
// block (Cols == 1) and a Row-shaped vector into a row block (Rows == 1).
template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
          VectorShape Shape>
struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
                VectorMap<SrcScalarType, Shape>> {
  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
  using SrcObjectType = VectorMap<SrcScalarType, Shape>;
  static RegisterBlockType Run(const SrcObjectType& src, int pos) {
    static_assert(Shape == VectorShape::Col || Rows == 1, "");
    static_assert(Shape == VectorShape::Row || Cols == 1, "");
    // One scalar per register entry: Rows * Cols entries of buf.reg are
    // written below.
    static_assert(RegisterBlockType::kRegisterLanes == 1,
                  "This path is only for scalar values");
    RegisterBlockType result;
    for (int i = 0; i < Rows * Cols; i++) {
      result.buf.reg[i] = src(pos + i);
    }
    return result;
  }
  // (row, col) form used by the three-argument Load(). A column vector is
  // indexed by the row coordinate and a row vector by the column coordinate.
  static RegisterBlockType Run(const SrcObjectType& src, int row, int col) {
    return Run(src, Shape == VectorShape::Row ? col : row);
  }
};

// Fills every entry of a register block with the single value src(0) of a
// VectorDup; the position argument(s) are ignored. The block orientation must
// match the vector shape, as for VectorMap.
template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
          VectorShape Shape>
struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>,
                VectorDup<SrcScalarType, Shape>> {
  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
  using SrcObjectType = VectorDup<SrcScalarType, Shape>;
  static RegisterBlockType Run(const SrcObjectType& src, int) {
    static_assert(Shape == VectorShape::Col || Rows == 1, "");
    static_assert(Shape == VectorShape::Row || Cols == 1, "");
    // One scalar per register entry: Rows * Cols entries of buf.reg are
    // written below.
    static_assert(RegisterBlockType::kRegisterLanes == 1,
                  "This path is only for scalar values");
    RegisterBlockType result;
    for (int i = 0; i < Rows * Cols; i++) {
      result.buf.reg[i] = src(0);
    }
    return result;
  }
  // (row, col) form used by the three-argument Load(). The position does not
  // matter for a duplicated value.
  static RegisterBlockType Run(const SrcObjectType& src, int, int) {
    return Run(src, 0);
  }
};

template <typename RegisterBlockType, typename SrcObjectType>
RegisterBlockType Load(const SrcObjectType& src, int row, int col) {
  return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, row, col);
}

template <typename RegisterBlockType, typename SrcObjectType>
RegisterBlockType Load(const SrcObjectType& src, int pos) {
  return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, pos);
}

// Copies kScalarCount consecutive scalars starting at `src` into a register
// block, in order. Scalar-register path only.
template <typename RegisterBlockType>
struct LoadContiguousImpl {
  using ScalarType = typename RegisterBlockType::ScalarType;
  static_assert(RegisterBlockType::kRegisterLanes == 1,
                "This path is only for scalar values");
  static RegisterBlockType Run(const ScalarType* src) {
    RegisterBlockType block;
    for (int n = 0; n < RegisterBlockType::kScalarCount; ++n) {
      block.buf.reg[n] = *src++;
    }
    return block;
  }
};

template <typename RegisterBlockType>
RegisterBlockType LoadContiguous(
    const typename RegisterBlockType::ScalarType* src) {
  return LoadContiguousImpl<RegisterBlockType>::Run(src);
}

// Shape (kRows x kCols) of the block actually loaded from a vector source
// that will be broadcast against a BroadcastRows x BroadcastCols block.
// The primary template defines no members, so unsupported sources fail to
// compile when kRows/kCols are used.
template <int BroadcastRows, int BroadcastCols, typename SrcObjectType>
struct LoadForBroadcastingShape {};

// VectorMap: keeps the broadcast extent along the vector's own orientation
// and collapses the other dimension to 1.
template <int BroadcastRows, int BroadcastCols, typename ScalarType,
          VectorShape Shape>
struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
                                VectorMap<ScalarType, Shape>> {
  static constexpr int kRows = (Shape != VectorShape::Col) ? 1 : BroadcastRows;
  static constexpr int kCols = (Shape != VectorShape::Row) ? 1 : BroadcastCols;
};

// VectorDup: a single value, loaded as a 1x1 block.
template <int BroadcastRows, int BroadcastCols, typename ScalarType,
          VectorShape Shape>
struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols,
                                VectorDup<ScalarType, Shape>> {
  static constexpr int kRows = 1;
  static constexpr int kCols = 1;
};

// Register block type returned by LoadForBroadcasting. It has
// RegisterBlockType's scalar type and the shape given by
// LoadForBroadcastingShape for SrcObjectType.
template <typename RegisterBlockType, typename SrcObjectType>
struct LoadForBroadcastingRegisterBlock {
  using ScalarType = typename RegisterBlockType::ScalarType;
  using Shape = LoadForBroadcastingShape<RegisterBlockType::kRows,
                                         RegisterBlockType::kCols,
                                         SrcObjectType>;
  using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>;
};

// Primary template, deliberately unusable. Supported source types are handled
// by partial specializations. Instantiating this one for any real source type
// trips the static_assert.
template <typename RegisterBlockType, typename SrcObjectType>
struct LoadForBroadcastingImpl {
  static_assert(std::is_same<SrcObjectType, void>::value,
                "This generic impl should never be hit");
};

// Loads the slice of a VectorMap needed to broadcast it against a
// Rows x Cols block. A Col vector yields a Rows x 1 block of src(pos + r);
// a Row vector yields a 1 x Cols block of src(pos + c). Scalar path only.
template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
          VectorShape Shape>
struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
                               VectorMap<SrcScalarType, Shape>> {
  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
  using SrcObjectType = VectorMap<SrcScalarType, Shape>;
  using ResultBlockType =
      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
                                                SrcObjectType>::Type;
  static_assert(ResultBlockType::kRegisterLanes == 1,
                "This path is only for scalar values");
  static ResultBlockType Run(const SrcObjectType& src, int pos) {
    ResultBlockType result;
    for (int c = 0; c < ResultBlockType::kCols; c++) {
      for (int r = 0; r < ResultBlockType::kRows; r++) {
        // Offset along the vector: the row index for a column vector, the
        // column index for a row vector.
        const int i = Shape == VectorShape::Col ? r : c;
        result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i);
      }
    }
    return result;
  }
  // (row, col) form called by the three-argument LoadForBroadcasting().
  // Without it, that overload cannot compile for VectorMap sources. A column
  // vector is indexed by row and a row vector by col.
  static ResultBlockType Run(const SrcObjectType& src, int row, int col) {
    return Run(src, Shape == VectorShape::Row ? col : row);
  }
};

// Loads a VectorDup for broadcasting. The result shape is given by
// LoadForBroadcastingShape (1x1 for VectorDup) and every entry is src(0).
// Position arguments are ignored. Scalar path only.
template <typename ScalarType, int Rows, int Cols, typename SrcScalarType,
          VectorShape Shape>
struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>,
                               VectorDup<SrcScalarType, Shape>> {
  using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>;
  using SrcObjectType = VectorDup<SrcScalarType, Shape>;
  using ResultBlockType =
      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
                                                SrcObjectType>::Type;
  static_assert(ResultBlockType::kRegisterLanes == 1,
                "This path is only for scalar values");
  static ResultBlockType Run(const SrcObjectType& src, int) {
    ResultBlockType result;
    for (int c = 0; c < ResultBlockType::kCols; c++) {
      for (int r = 0; r < ResultBlockType::kRows; r++) {
        result.buf.reg[r + c * ResultBlockType::kRows] = src(0);
      }
    }
    return result;
  }
  // (row, col) form called by the three-argument LoadForBroadcasting().
  // Without it, that overload cannot compile for VectorDup sources. The
  // position does not matter for a duplicated value.
  static ResultBlockType Run(const SrcObjectType& src, int, int) {
    return Run(src, 0);
  }
};

template <typename RegisterBlockType, typename SrcObjectType>
typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
                                          SrcObjectType>::Type
LoadForBroadcasting(const SrcObjectType& src, int row, int col) {
  return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(
      src, row, col);
}

template <typename RegisterBlockType, typename SrcObjectType>
typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
                                          SrcObjectType>::Type
LoadForBroadcasting(const SrcObjectType& src, int pos) {
  return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src,
                                                                        pos);
}

template <int ConstantValue, typename RegisterBlockType>
struct AddConstantImpl {
  static void Run(RegisterBlockType* block) {
    using RegisterType = typename RegisterBlockType::RegisterType;
    const RegisterType dup = Dup<RegisterType>(ConstantValue);
    for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) {
      block->buf.reg[i] = Add(block->buf.reg[i], dup);
    }
  }
};

template <typename RegisterBlockType>
struct AddConstantImpl<0, RegisterBlockType> {
  static void Run(RegisterBlockType*) {
    // This is a no-op.
  }
};

template <int ConstantValue, typename RegisterBlockType>
void AddConstant(RegisterBlockType* block) {
  AddConstantImpl<ConstantValue, RegisterBlockType>::Run(block);
}

// Shorthand aliases for register buffers of the common integer element types.
template <int Count>
using RegBufferInt32 = RegisterBuffer<std::int32_t, Count>;
template <int Count>
using RegBufferInt16 = RegisterBuffer<std::int16_t, Count>;
template <int Count>
using RegBufferUint8 = RegisterBuffer<std::uint8_t, Count>;

// Shorthand aliases for register blocks of the common integer element types.
template <int Rows, int Cols>
using RegBlockInt32 = RegisterBlock<std::int32_t, Rows, Cols>;
template <int Rows, int Cols>
using RegBlockInt16 = RegisterBlock<std::int16_t, Rows, Cols>;
template <int Rows, int Cols>
using RegBlockUint8 = RegisterBlock<std::uint8_t, Rows, Cols>;

}  // end namespace gemmlowp

#if defined GEMMLOWP_NEON
#include "simd_wrappers_neon.h"
#elif defined GEMMLOWP_SSE4
#include "simd_wrappers_sse.h"
#elif defined GEMMLOWP_MSA
#include "simd_wrappers_msa.h"
#endif

#endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_