// Copyright 2017 The Gemmlowp Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // simd_wrappers.h: some inline functions wrapping SIMD intrinsics, // extending the set of such functions from fixedpoint.h. #ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ #define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_ #include <algorithm> #include <type_traits> #include "../fixedpoint/fixedpoint.h" namespace gemmlowp { template <typename ScalarType, int ScalarCount> struct RegisterType { using Type = ScalarType; }; inline std::int32_t Min(std::int32_t a, std::int32_t b) { return std::min(a, b); } inline std::int32_t Max(std::int32_t a, std::int32_t b) { return std::max(a, b); } inline void MulAdd(std::int32_t lhs, std::int32_t rhs, std::int32_t* acc) { *acc += lhs * rhs; } template <typename tScalarType, int tScalarCount> struct RegisterBuffer { using ScalarType = tScalarType; static constexpr int kScalarCount = tScalarCount; using RegisterType = typename RegisterType<ScalarType, kScalarCount>::Type; static_assert((kScalarCount & (kScalarCount - 1)) == 0, "kScalarCount must be a power of two"); static_assert(sizeof(RegisterType) % sizeof(ScalarType) == 0, ""); static constexpr int kRegisterLanes = sizeof(RegisterType) / sizeof(ScalarType); static constexpr int kRegisterCount = (kScalarCount * sizeof(ScalarType) + sizeof(RegisterType) - 1) / sizeof(RegisterType); RegisterType reg[kRegisterCount]; }; template <typename tScalarType, int tRows, int tCols> struct RegisterBlock { using ScalarType = tScalarType; static constexpr int kRows = tRows; static constexpr int kCols = tCols; static constexpr int kScalarCount = kRows * kCols; using BufferType = RegisterBuffer<ScalarType, kScalarCount>; using RegisterType = typename BufferType::RegisterType; static constexpr int kRegisterCount = BufferType::kRegisterCount; static constexpr int kRegisterLanes = BufferType::kRegisterLanes; BufferType buf; }; template <typename RegisterBlockType> struct RegisterBlockAddImpl { static RegisterBlockType Run(const RegisterBlockType& lhs, const RegisterBlockType& rhs) { RegisterBlockType result; for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) { result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]); } return result; } }; template <typename RegisterBlockType> RegisterBlockType RegisterBlockAdd(const RegisterBlockType& lhs, const RegisterBlockType& rhs) { return RegisterBlockAddImpl<RegisterBlockType>::Run(lhs, rhs); } template <typename LhsType, typename RhsType> struct ShouldFlipLhsRhs { static constexpr bool kValue = (LhsType::kScalarCount < RhsType::kScalarCount) || (LhsType::kScalarCount == RhsType::kScalarCount && (LhsType::kRows < RhsType::kRows)); }; template <typename LhsType, typename RhsType, bool Flip = ShouldFlipLhsRhs<LhsType, RhsType>::kValue> struct FlipLhsRhs { using FlippedLhsType = LhsType; using FlippedRhsType = RhsType; static const FlippedLhsType& FlippedLhs(const LhsType& lhs, const RhsType& rhs) { return lhs; } static const FlippedRhsType& FlippedRhs(const LhsType& lhs, const RhsType& rhs) { return rhs; } }; template <typename LhsType, typename RhsType> struct FlipLhsRhs<LhsType, RhsType, true> { using FlippedLhsType = RhsType; using FlippedRhsType = LhsType; static const FlippedLhsType& FlippedLhs(const LhsType& lhs, const RhsType& rhs) { return rhs; } static const FlippedRhsType& FlippedRhs(const LhsType& lhs, const RhsType& rhs) { return lhs; } }; template <typename Lhs, typename Rhs> struct BroadcastBinaryOpShape { static constexpr int kRows = Lhs::kRows > Rhs::kRows ? Lhs::kRows : Rhs::kRows; static constexpr int kCols = Lhs::kCols > Rhs::kCols ? Lhs::kCols : Rhs::kCols; }; template <typename Lhs, typename Rhs> struct BroadcastBinaryOpRegisterBlock { using Shape = BroadcastBinaryOpShape<Lhs, Rhs>; using ScalarType = typename Lhs::ScalarType; using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>; }; template <typename Lhs, typename Rhs> struct BroadcastAddImpl { using ResultBlockType = typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { ResultBlockType result; static constexpr int Rows = ResultBlockType::kRows; static constexpr int Cols = ResultBlockType::kCols; static constexpr int LhsRows = Lhs::kRows; static constexpr int LhsCols = Lhs::kCols; static constexpr int RhsRows = Rhs::kRows; static constexpr int RhsCols = Rhs::kCols; static_assert(LhsRows == Rows || LhsRows == 1, ""); static_assert(RhsRows == Rows || RhsRows == 1, ""); static_assert(LhsCols == Cols || LhsCols == 1, ""); static_assert(RhsCols == Cols || RhsCols == 1, ""); static_assert(ResultBlockType::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Lhs::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Rhs::kRegisterLanes == 1, "This path is only for scalar values"); for (int c = 0; c < Cols; c++) { const int lhs_c = LhsCols == Cols ? c : 0; const int rhs_c = RhsCols == Cols ? c : 0; for (int r = 0; r < Rows; r++) { const int lhs_r = LhsRows == Rows ? r : 0; const int rhs_r = RhsRows == Rows ? r : 0; result.buf.reg[r + c * Rows] = Add(lhs.buf.reg[lhs_r + lhs_c * LhsRows], rhs.buf.reg[rhs_r + rhs_c * RhsRows]); } } return result; } }; template <typename Lhs, typename Rhs> typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastAdd( const Lhs& lhs, const Rhs& rhs) { using Flip = FlipLhsRhs<Lhs, Rhs>; return BroadcastAddImpl< typename Flip::FlippedLhsType, typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), Flip::FlippedRhs(lhs, rhs)); } template <typename Lhs, typename Rhs> struct BroadcastMulImpl { using ResultBlockType = typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type; static ResultBlockType Run(const Lhs& lhs, const Rhs& rhs) { ResultBlockType result; static constexpr int Rows = ResultBlockType::kRows; static constexpr int Cols = ResultBlockType::kCols; static constexpr int LhsRows = Lhs::kRows; static constexpr int LhsCols = Lhs::kCols; static constexpr int RhsRows = Rhs::kRows; static constexpr int RhsCols = Rhs::kCols; static_assert(ResultBlockType::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Lhs::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Rhs::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(LhsRows == Rows || LhsRows == 1, ""); static_assert(RhsRows == Rows || RhsRows == 1, ""); static_assert(LhsCols == Cols || LhsCols == 1, ""); static_assert(RhsCols == Cols || RhsCols == 1, ""); for (int c = 0; c < Cols; c++) { const int lhs_c = LhsCols == Cols ? c : 0; const int rhs_c = RhsCols == Cols ? c : 0; for (int r = 0; r < Rows; r++) { const int lhs_r = LhsRows == Rows ? r : 0; const int rhs_r = RhsRows == Rows ? r : 0; result.buf.reg[r + c * Rows] = Mul(lhs.buf.reg[lhs_r + lhs_c * LhsRows], rhs.buf.reg[rhs_r + rhs_c * RhsRows]); } } return result; } }; template <typename Lhs, typename Rhs> typename BroadcastBinaryOpRegisterBlock<Lhs, Rhs>::Type BroadcastMul( const Lhs& lhs, const Rhs& rhs) { using Flip = FlipLhsRhs<Lhs, Rhs>; return BroadcastMulImpl< typename Flip::FlippedLhsType, typename Flip::FlippedRhsType>::Run(Flip::FlippedLhs(lhs, rhs), Flip::FlippedRhs(lhs, rhs)); } template <typename Lhs, typename Rhs, typename Acc> struct BroadcastMulAddImpl { static void Run(const Lhs& lhs, const Rhs& rhs, Acc* acc) { static constexpr int Rows = Acc::kRows; static constexpr int Cols = Acc::kCols; static constexpr int LhsRows = Lhs::kRows; static constexpr int LhsCols = Lhs::kCols; static constexpr int RhsRows = Rhs::kRows; static constexpr int RhsCols = Rhs::kCols; static_assert(Acc::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Lhs::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(Rhs::kRegisterLanes == 1, "This path is only for scalar values"); static_assert(LhsRows == Rows || LhsRows == 1, ""); static_assert(RhsRows == Rows || RhsRows == 1, ""); static_assert(LhsCols == Cols || LhsCols == 1, ""); static_assert(RhsCols == Cols || RhsCols == 1, ""); for (int c = 0; c < Cols; c++) { const int lhs_c = LhsCols == Cols ? c : 0; const int rhs_c = RhsCols == Cols ? c : 0; for (int r = 0; r < Rows; r++) { const int lhs_r = LhsRows == Rows ? r : 0; const int rhs_r = RhsRows == Rows ? r : 0; MulAdd(lhs.buf.reg[lhs_r + lhs_c * LhsRows], rhs.buf.reg[rhs_r + rhs_c * RhsRows], &acc->buf.reg[r + c * Rows]); } } } }; template <typename Lhs, typename Rhs, typename Acc> void BroadcastMulAdd(const Lhs& lhs, const Rhs& rhs, Acc* acc) { using Flip = FlipLhsRhs<Lhs, Rhs>; BroadcastMulAddImpl<typename Flip::FlippedLhsType, typename Flip::FlippedRhsType, Acc>::Run(Flip::FlippedLhs(lhs, rhs), Flip::FlippedRhs(lhs, rhs), acc); } template <typename RegisterBlockType, typename SrcObjectType> struct LoadImpl { static_assert(std::is_same<SrcObjectType, void>::value, "This generic impl should never be hit"); }; template <typename ScalarType, int Rows, int Cols, typename SrcScalarType> struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, MatrixMap<SrcScalarType, MapOrder::ColMajor>> { using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; using SrcObjectType = MatrixMap<SrcScalarType, MapOrder::ColMajor>; static RegisterBlockType Run(const SrcObjectType& src, int row, int col) { RegisterBlockType result; int i = 0; for (int c = 0; c < Cols; c++) { const ScalarType* src_ptr = src.data(row, col + c); for (int r = 0; r < Rows; r++) { result.buf.reg[i++] = *src_ptr++; } } return result; } }; template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, VectorShape Shape> struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, VectorMap<SrcScalarType, Shape>> { using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; using SrcObjectType = VectorMap<SrcScalarType, Shape>; static RegisterBlockType Run(const SrcObjectType& src, int pos) { static_assert(Shape == VectorShape::Col || Rows == 1, ""); static_assert(Shape == VectorShape::Row || Cols == 1, ""); RegisterBlockType result; for (int i = 0; i < Rows * Cols; i++) { result.buf.reg[i] = src(pos + i); } return result; } }; template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, VectorShape Shape> struct LoadImpl<RegisterBlock<ScalarType, Rows, Cols>, VectorDup<SrcScalarType, Shape>> { using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; using SrcObjectType = VectorDup<SrcScalarType, Shape>; static RegisterBlockType Run(const SrcObjectType& src, int) { static_assert(Shape == VectorShape::Col || Rows == 1, ""); static_assert(Shape == VectorShape::Row || Cols == 1, ""); RegisterBlockType result; for (int i = 0; i < Rows * Cols; i++) { result.buf.reg[i] = src(0); } return result; } }; template <typename RegisterBlockType, typename SrcObjectType> RegisterBlockType Load(const SrcObjectType& src, int row, int col) { return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, row, col); } template <typename RegisterBlockType, typename SrcObjectType> RegisterBlockType Load(const SrcObjectType& src, int pos) { return LoadImpl<RegisterBlockType, SrcObjectType>::Run(src, pos); } template <typename RegisterBlockType> struct LoadContiguousImpl { using ScalarType = typename RegisterBlockType::ScalarType; static_assert(RegisterBlockType::kRegisterLanes == 1, "This path is only for scalar values"); static RegisterBlockType Run(const ScalarType* src) { RegisterBlockType result; for (int i = 0; i < RegisterBlockType::kScalarCount; i++) { result.buf.reg[i] = src[i]; } return result; } }; template <typename RegisterBlockType> RegisterBlockType LoadContiguous( const typename RegisterBlockType::ScalarType* src) { return LoadContiguousImpl<RegisterBlockType>::Run(src); } template <int BroadcastRows, int BroadcastCols, typename SrcObjectType> struct LoadForBroadcastingShape {}; template <int BroadcastRows, int BroadcastCols, typename ScalarType, VectorShape Shape> struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols, VectorMap<ScalarType, Shape>> { static constexpr int kRows = Shape == VectorShape::Col ? BroadcastRows : 1; static constexpr int kCols = Shape == VectorShape::Row ? BroadcastCols : 1; }; template <int BroadcastRows, int BroadcastCols, typename ScalarType, VectorShape Shape> struct LoadForBroadcastingShape<BroadcastRows, BroadcastCols, VectorDup<ScalarType, Shape>> { static constexpr int kRows = 1; static constexpr int kCols = 1; }; template <typename RegisterBlockType, typename SrcObjectType> struct LoadForBroadcastingRegisterBlock { using Shape = LoadForBroadcastingShape<RegisterBlockType::kRows, RegisterBlockType::kCols, SrcObjectType>; using ScalarType = typename RegisterBlockType::ScalarType; using Type = RegisterBlock<ScalarType, Shape::kRows, Shape::kCols>; }; template <typename RegisterBlockType, typename SrcObjectType> struct LoadForBroadcastingImpl { static_assert(std::is_same<SrcObjectType, void>::value, "This generic impl should never be hit"); }; template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, VectorShape Shape> struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>, VectorMap<SrcScalarType, Shape>> { using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; using SrcObjectType = VectorMap<SrcScalarType, Shape>; using ResultBlockType = typename LoadForBroadcastingRegisterBlock<RegisterBlockType, SrcObjectType>::Type; static_assert(ResultBlockType::kRegisterLanes == 1, "This path is only for scalar values"); static ResultBlockType Run(const SrcObjectType& src, int pos) { ResultBlockType result; for (int c = 0; c < ResultBlockType::kCols; c++) { for (int r = 0; r < ResultBlockType::kRows; r++) { const int i = Shape == VectorShape::Col ? r : c; result.buf.reg[r + c * ResultBlockType::kRows] = src(pos + i); } } return result; } }; template <typename ScalarType, int Rows, int Cols, typename SrcScalarType, VectorShape Shape> struct LoadForBroadcastingImpl<RegisterBlock<ScalarType, Rows, Cols>, VectorDup<SrcScalarType, Shape>> { using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; using SrcObjectType = VectorDup<SrcScalarType, Shape>; using ResultBlockType = typename LoadForBroadcastingRegisterBlock<RegisterBlockType, SrcObjectType>::Type; static_assert(ResultBlockType::kRegisterLanes == 1, "This path is only for scalar values"); static ResultBlockType Run(const SrcObjectType& src, int) { ResultBlockType result; for (int c = 0; c < ResultBlockType::kCols; c++) { for (int r = 0; r < ResultBlockType::kRows; r++) { result.buf.reg[r + c * ResultBlockType::kRows] = src(0); } } return result; } }; template <typename RegisterBlockType, typename SrcObjectType> typename LoadForBroadcastingRegisterBlock<RegisterBlockType, SrcObjectType>::Type LoadForBroadcasting(const SrcObjectType& src, int row, int col) { return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run( src, row, col); } template <typename RegisterBlockType, typename SrcObjectType> typename LoadForBroadcastingRegisterBlock<RegisterBlockType, SrcObjectType>::Type LoadForBroadcasting(const SrcObjectType& src, int pos) { return LoadForBroadcastingImpl<RegisterBlockType, SrcObjectType>::Run(src, pos); } template <int ConstantValue, typename RegisterBlockType> struct AddConstantImpl { static void Run(RegisterBlockType* block) { using RegisterType = typename RegisterBlockType::RegisterType; const RegisterType dup = Dup<RegisterType>(ConstantValue); for (int i = 0; i < RegisterBlockType::kRegisterCount; i++) { block->buf.reg[i] = Add(block->buf.reg[i], dup); } } }; template <typename RegisterBlockType> struct AddConstantImpl<0, RegisterBlockType> { static void Run(RegisterBlockType*) { // This is a no-op. } }; template <int ConstantValue, typename RegisterBlockType> void AddConstant(RegisterBlockType* block) { AddConstantImpl<ConstantValue, RegisterBlockType>::Run(block); } template <int N> using RegBufferInt32 = RegisterBuffer<std::int32_t, N>; template <int N> using RegBufferInt16 = RegisterBuffer<std::int16_t, N>; template <int N> using RegBufferUint8 = RegisterBuffer<std::uint8_t, N>; template <int R, int C> using RegBlockInt32 = RegisterBlock<std::int32_t, R, C>; template <int R, int C> using RegBlockInt16 = RegisterBlock<std::int16_t, R, C>; template <int R, int C> using RegBlockUint8 = RegisterBlock<std::uint8_t, R, C>; } // end namespace gemmlowp #if defined GEMMLOWP_NEON #include "simd_wrappers_neon.h" #elif defined GEMMLOWP_SSE4 #include "simd_wrappers_sse.h" #elif defined GEMMLOWP_MSA #include "simd_wrappers_msa.h" #endif #endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_H_