HELLO·Android
系统源代码
IT资讯
技术文章
我的收藏
注册
登录
-
我收藏的文章
创建代码块
我的代码块
我的账号
Oreo
|
8.0.0_r4
下载
查看原文件
收藏
根目录
frameworks
rs
cpp
ScriptIntrinsicBLAS.cpp
/*
 * Copyright (C) 2015 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "RenderScript.h"
#include "rsCppInternal.h"

// Number of elements in a fixed-size C array.
#define NELEM(m) (sizeof(m) / sizeof((m)[0]))

using android::RSC::Allocation;
using android::RSC::Element;
using android::RSC::RS;
using android::RSC::RS_ERROR_INVALID_ELEMENT;
using android::RSC::RS_ERROR_INVALID_PARAMETER;
using android::RSC::RS_SUCCESS;
using android::RSC::ScriptIntrinsicBLAS;
using android::RSC::sp;

// ScriptIntrinsicBLAS APIS

ScriptIntrinsicBLAS::ScriptIntrinsicBLAS(sp<RS> rs, sp<const Element> e)
    : ScriptIntrinsic(rs, RS_SCRIPT_INTRINSIC_ID_BLAS, e) {
}

/**
 * Creates the BLAS intrinsic for @p rs. The intrinsic is always created with
 * a U32 element; the element type of each operation is checked per call by
 * the validate*() helpers below.
 */
sp<ScriptIntrinsicBLAS> ScriptIntrinsicBLAS::create(const sp<RS>& rs) {
    return new ScriptIntrinsicBLAS(rs, Element::U32(rs));
}

// Selects which alpha/beta slots of RsBlasCall setUpBLASCall() fills in.
enum RsBlasDataType {
    SINGLE,
    DOUBLE,
    SINGLE_COMPLEX,
    DOUBLE_COMPLEX
};

/**
 * Packs the arguments of one BLAS invocation into a zero-initialized
 * RsBlasCall.
 *
 * The integer mode arguments are cast to their RsBlas* enum types. Only the
 * alpha/beta values that match @p dataType are written; the scalar arguments
 * for the other precisions are ignored. Any RsBlasCall field not assigned
 * here stays zero because of the memset.
 */
static RsBlasCall setUpBLASCall(RsBlasDataType dataType, RsBlasFunction func,
                                int TransA, int TransB, int Side, int Uplo, int Diag,
                                int M, int N, int K, int incX, int incY, int KL, int KU,
                                float alphaF, float betaF, double alphaD, double betaD,
                                float alphaCX, float alphaCY, float betaCX, float betaCY,
                                double alphaZX, double alphaZY, double betaZX, double betaZY) {
    RsBlasCall call;
    memset(&call, 0, sizeof(call));
    call.func = func;
    call.transA = static_cast<RsBlasTranspose>(TransA);
    call.transB = static_cast<RsBlasTranspose>(TransB);
    call.side = static_cast<RsBlasSide>(Side);
    call.uplo = static_cast<RsBlasUplo>(Uplo);
    call.diag = static_cast<RsBlasDiag>(Diag);
    call.M = M;
    call.N = N;
    call.K = K;

    switch (dataType) {
        case SINGLE:
            // For Single-precision BLAS.
            call.alpha.f = alphaF;
            call.beta.f = betaF;
            break;
        case DOUBLE:
            // For Double-precision BLAS.
            call.alpha.d = alphaD;
            call.beta.d = betaD;
            break;
        case SINGLE_COMPLEX:
            // For Single-precision complex BLAS.
            call.alpha.c.r = alphaCX;
            call.alpha.c.i = alphaCY;
            call.beta.c.r = betaCX;
            call.beta.c.i = betaCY;
            break;
        case DOUBLE_COMPLEX:
            // For Double-precision complex BLAS.
            call.alpha.z.r = alphaZX;
            call.alpha.z.i = alphaZY;
            call.beta.z.r = betaZX;
            call.beta.z.i = betaZY;
            break;
        default:
            break;
    }

    call.incX = incX;
    call.incY = incY;
    call.KL = KL;
    call.KU = KU;

    return call;
}

/**
 * Launches the BLAS intrinsic on slot 0 with A, B and C as the input
 * allocations and @p call as the user-data blob. There is no output
 * allocation. Unused inputs are passed by callers as 0.
 */
static void dispatchBLASCall(RS* mRS, RsContext con, RsScript id, RsBlasCall& call,
                             RsAllocation A, RsAllocation B, RsAllocation C) {
    RsAllocation in_allocs[3] = {A, B, C};
    tryDispatch(mRS, RS::dispatch->ScriptForEachMulti(con, id, 0, in_allocs, NELEM(in_allocs),
                                                      nullptr, &call, sizeof(call), nullptr, 0));
}

/*
 * Per-precision launch helpers. Each forwards its arguments to
 * setUpBLASCall() and passes 0 for every alpha/beta slot that belongs to a
 * different precision. It then launches the intrinsic through
 * dispatchBLASCall().
 */

// Real single-precision (float alpha/beta).
static void nScriptIntrinsicBLAS_Single(RS* mRS, RsContext con, RsScript id, RsBlasFunction func,
                                        int TransA, int TransB, int Side, int Uplo, int Diag,
                                        int M, int N, int K, float alpha,
                                        RsAllocation A, RsAllocation B, float beta, RsAllocation C,
                                        int incX, int incY, int KL, int KU) {
    RsBlasCall call = setUpBLASCall(SINGLE, func, TransA, TransB, Side, Uplo, Diag,
                                    M, N, K, incX, incY, KL, KU,
                                    alpha, beta, 0.0, 0.0,
                                    0.0f, 0.0f, 0.0f, 0.0f,
                                    0.0, 0.0, 0.0, 0.0);
    dispatchBLASCall(mRS, con, id, call, A, B, C);
}

// Real double-precision (double alpha/beta).
static void nScriptIntrinsicBLAS_Double(RS* mRS, RsContext con, RsScript id, RsBlasFunction func,
                                        int TransA, int TransB, int Side, int Uplo, int Diag,
                                        int M, int N, int K, double alpha,
                                        RsAllocation A, RsAllocation B, double beta, RsAllocation C,
                                        int incX, int incY, int KL, int KU) {
    RsBlasCall call = setUpBLASCall(DOUBLE, func, TransA, TransB, Side, Uplo, Diag,
                                    M, N, K, incX, incY, KL, KU,
                                    0.0f, 0.0f, alpha, beta,
                                    0.0f, 0.0f, 0.0f, 0.0f,
                                    0.0, 0.0, 0.0, 0.0);
    dispatchBLASCall(mRS, con, id, call, A, B, C);
}

// Complex single-precision; (alphaX, alphaY) and (betaX, betaY) are the
// real/imaginary parts written to alpha.c and beta.c.
static void nScriptIntrinsicBLAS_Complex(RS* mRS, RsContext con, RsScript id, RsBlasFunction func,
                                         int TransA, int TransB, int Side, int Uplo, int Diag,
                                         int M, int N, int K, float alphaX, float alphaY,
                                         RsAllocation A, RsAllocation B,
                                         float betaX, float betaY, RsAllocation C,
                                         int incX, int incY, int KL, int KU) {
    RsBlasCall call = setUpBLASCall(SINGLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
                                    M, N, K, incX, incY, KL, KU,
                                    0.0f, 0.0f, 0.0, 0.0,
                                    alphaX, alphaY, betaX, betaY,
                                    0.0, 0.0, 0.0, 0.0);
    dispatchBLASCall(mRS, con, id, call, A, B, C);
}

// Complex double-precision; (alphaX, alphaY) and (betaX, betaY) are the
// real/imaginary parts written to alpha.z and beta.z.
static void nScriptIntrinsicBLAS_Z(RS* mRS, RsContext con, RsScript id, RsBlasFunction func,
                                   int TransA, int TransB, int Side, int Uplo, int Diag,
                                   int M, int N, int K, double alphaX, double alphaY,
                                   RsAllocation A, RsAllocation B,
                                   double betaX, double betaY, RsAllocation C,
                                   int incX, int incY, int KL, int KU) {
    RsBlasCall call = setUpBLASCall(DOUBLE_COMPLEX, func, TransA, TransB, Side, Uplo, Diag,
                                    M, N, K, incX, incY, KL, KU,
                                    0.0f, 0.0f, 0.0, 0.0,
                                    0.0f, 0.0f, 0.0f, 0.0f,
                                    alphaX, alphaY, betaX, betaY);
    dispatchBLASCall(mRS, con, id, call, A, B, C);
}

/**
 * Launches RsBlas_bnnm. a_offset and b_offset are masked to their low 8 bits
 * before being stored; c_offset and c_mult_int are stored unchanged.
 */
static void nScriptIntrinsicBLAS_BNNM(RS* mRS, RsContext con, RsScript id, int M, int N, int K,
                                      RsAllocation A, int a_offset, RsAllocation B, int b_offset,
                                      RsAllocation C, int c_offset, int c_mult_int) {
    RsBlasCall call;
    memset(&call, 0, sizeof(call));
    call.func = RsBlas_bnnm;
    call.M = M;
    call.N = N;
    call.K = K;
    call.a_offset = a_offset & 0xFF;
    call.b_offset = b_offset & 0xFF;
    call.c_offset = c_offset;
    call.c_mult_int = c_mult_int;

    dispatchBLASCall(mRS, con, id, call, A, B, C);
}

/**
 * Level 2 BLAS
 *
 * The validate*() helpers report problems through mRS->throwError() and keep
 * going; they never return early, so all checks run.
 * NOTE(review): whether the subsequent launch is suppressed after an error
 * depends on tryDispatch(), which is not defined in this file — confirm.
 */

/**
 * Validates GEMV/GBMV operands. A has M = getY() rows and N = getX() columns.
 * X and Y must be 1-D vectors with positive increments. Their X dimensions
 * must be 1 + (len - 1) * inc, where len is N/M for RsBlasNoTrans and M/N
 * otherwise.
 */
static void validateGEMV(RS* mRS, const sp<const Element>& e, RsBlasTranspose TransA,
                         const sp<Allocation>& A, const sp<Allocation>& X, int incX,
                         const sp<Allocation>& Y, int incY) {
    int M = static_cast<int>(A->getType()->getY());
    int N = static_cast<int>(A->getType()->getX());
    if (!A->getType()->getElement()->isCompatible(e) ||
        !X->getType()->getElement()->isCompatible(e) ||
        !Y->getType()->getElement()->isCompatible(e)) {
        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    }
    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    }

    if (incX <= 0 || incY <= 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    }
    int expectedXDim = -1, expectedYDim = -1;
    if (TransA == RsBlasNoTrans) {
        expectedXDim = 1 + (N - 1) * incX;
        expectedYDim = 1 + (M - 1) * incY;
    } else {
        expectedXDim = 1 + (M - 1) * incX;
        expectedYDim = 1 + (N - 1) * incY;
    }
    if (static_cast<int>(X->getType()->getX()) != expectedXDim ||
        static_cast<int>(Y->getType()->getX()) != expectedYDim) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GEMV");
    }
}

/*
 * xGEMV: validate, then dispatch with M = A->getType()->getY() and
 * N = A->getType()->getX(). A is input 0, X is input 1 and Y is input 2.
 */

void ScriptIntrinsicBLAS::SGEMV(RsBlasTranspose TransA, float alpha, const sp<Allocation>& A,
                                const sp<Allocation>& X, int incX, float beta,
                                const sp<Allocation>& Y, int incY) {
    validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
    int M = A->getType()->getY();
    int N = A->getType()->getX();
    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemv,
                                TransA, 0, 0, 0, 0, M, N, 0,
                                alpha, A->getID(), X->getID(),
                                beta, Y->getID(), incX, incY, 0, 0);
}

void ScriptIntrinsicBLAS::DGEMV(RsBlasTranspose TransA, double alpha, const sp<Allocation>& A,
                                const sp<Allocation>& X, int incX, double beta,
                                const sp<Allocation>& Y, int incY) {
    validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
    int M = A->getType()->getY();
    int N = A->getType()->getX();
    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemv,
                                TransA, 0, 0, 0, 0, M, N, 0,
                                alpha, A->getID(), X->getID(),
                                beta, Y->getID(), incX, incY, 0, 0);
}

void ScriptIntrinsicBLAS::CGEMV(RsBlasTranspose TransA, Float2 alpha, const sp<Allocation>& A,
                                const sp<Allocation>& X, int incX, Float2 beta,
                                const sp<Allocation>& Y, int incY) {
    validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
    int M = A->getType()->getY();
    int N = A->getType()->getX();
    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemv,
                                 TransA, 0, 0, 0, 0, M, N, 0,
                                 alpha.x, alpha.y, A->getID(), X->getID(),
                                 beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
}

void ScriptIntrinsicBLAS::ZGEMV(RsBlasTranspose TransA, Double2 alpha, const sp<Allocation>& A,
                                const sp<Allocation>& X, int incX, Double2 beta,
                                const sp<Allocation>& Y, int incY) {
    validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
    int M = A->getType()->getY();
    int N = A->getType()->getX();
    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemv,
                           TransA, 0, 0, 0, 0, M, N, 0,
                           alpha.x, alpha.y, A->getID(), X->getID(),
                           beta.x, beta.y, Y->getID(), incX, incY, 0, 0);
}

/*
 * xGBMV: same checks as GEMV plus KL >= 0 and KU >= 0. KL and KU are
 * forwarded to the call's KL/KU fields.
 */

void ScriptIntrinsicBLAS::SGBMV(RsBlasTranspose TransA, int KL, int KU, float alpha,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX,
                                float beta, const sp<Allocation>& Y, int incY) {
    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
    validateGEMV(mRS, Element::F32(mRS), TransA, A, X, incX, Y, incY);
    if (KL < 0 || KU < 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
    }
    int M = A->getType()->getY();
    int N = A->getType()->getX();
    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgbmv,
                                TransA, 0, 0, 0, 0, M, N, 0,
                                alpha, A->getID(), X->getID(),
                                beta, Y->getID(), incX, incY, KL, KU);
}

void ScriptIntrinsicBLAS::DGBMV(RsBlasTranspose TransA, int KL, int KU, double alpha,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX,
                                double beta, const sp<Allocation>& Y, int incY) {
    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
    validateGEMV(mRS, Element::F64(mRS), TransA, A, X, incX, Y, incY);
    if (KL < 0 || KU < 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
    }
    int M = A->getType()->getY();
    int N = A->getType()->getX();
    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgbmv,
                                TransA, 0, 0, 0, 0, M, N, 0,
                                alpha, A->getID(), X->getID(),
                                beta, Y->getID(), incX, incY, KL, KU);
}

void ScriptIntrinsicBLAS::CGBMV(RsBlasTranspose TransA, int KL, int KU, Float2 alpha,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX,
                                Float2 beta, const sp<Allocation>& Y, int incY) {
    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
    validateGEMV(mRS, Element::F32_2(mRS), TransA, A, X, incX, Y, incY);
    if (KL < 0 || KU < 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
    }
    int M = A->getType()->getY();
    int N = A->getType()->getX();
    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgbmv,
                                 TransA, 0, 0, 0, 0, M, N, 0,
                                 alpha.x, alpha.y, A->getID(), X->getID(),
                                 beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
}

void ScriptIntrinsicBLAS::ZGBMV(RsBlasTranspose TransA, int KL, int KU, Double2 alpha,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX,
                                Double2 beta, const sp<Allocation>& Y, int incY) {
    // GBMV has the same validation requirements as GEMV + KL and KU >= 0
    validateGEMV(mRS, Element::F64_2(mRS), TransA, A, X, incX, Y, incY);
    if (KL < 0 || KU < 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "KL and KU must be greater than or equal to 0");
    }
    int M = A->getType()->getY();
    int N = A->getType()->getX();
    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgbmv,
                           TransA, 0, 0, 0, 0, M, N, 0,
                           alpha.x, alpha.y, A->getID(), X->getID(),
                           beta.x, beta.y, Y->getID(), incX, incY, KL, KU);
}

/**
 * Validates TRMV/TBMV/TRSV/TBSV operands. A must be square (N = getY() must
 * equal getX()), and X must be a 1-D vector with a positive increment and an
 * X dimension of 1 + (N - 1) * incX. Uplo, TransA and Diag are not inspected.
 */
static void validateTRMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo,
                         RsBlasTranspose TransA, RsBlasDiag Diag,
                         const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    int N = static_cast<int>(A->getType()->getY());
    if (static_cast<int>(A->getType()->getX()) != N) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for TRMV");
    }
    if (!A->getType()->getElement()->isCompatible(e) ||
        !X->getType()->getElement()->isCompatible(e)) {
        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    }
    if (X->getType()->getY() > 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    }

    if (incX <= 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    }
    int expectedXDim = 1 + (N - 1) * incX;
    if (static_cast<int>(X->getType()->getX()) != expectedXDim) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TRMV");
    }
}

/**
 * Validates TPMV/TPSV operands and returns the matrix order N.
 *
 * Ap must be 1-D and hold exactly N * (N + 1) / 2 elements. N is recovered
 * as floor(sqrt(2 * len)). For len = N(N+1)/2, 2 * len = N^2 + N lies
 * strictly between N^2 and (N+1)^2, so the truncation yields N exactly.
 * X must be 1-D with a positive increment and X dimension 1 + (N - 1) * incX.
 */
static int validateTPMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo,
                        RsBlasTranspose TransA, RsBlasDiag Diag,
                        const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
    if (!Ap->getType()->getElement()->isCompatible(e) ||
        !X->getType()->getElement()->isCompatible(e)) {
        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    }
    if (X->getType()->getY() > 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    }

    if (Ap->getType()->getY() > 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
    }

    int N = static_cast<int>(sqrt(static_cast<double>(Ap->getType()->getX()) * 2));
    if (static_cast<int>(Ap->getType()->getX()) != ((N * (N + 1)) / 2)) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
    }
    if (incX <= 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    }
    int expectedXDim = 1 + (N - 1) * incX;
    if (static_cast<int>(X->getType()->getX()) != expectedXDim) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for TPMV");
    }

    return N;
}

/*
 * xTRMV: validateTRMV, then dispatch with N = A->getType()->getY(). A is
 * input 0 and X is input 1. alpha/beta are unused (0) and input 2 is empty.
 */

void ScriptIntrinsicBLAS::STRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strmv,
                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::DTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrmv,
                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::CTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrmv,
                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
                                 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::ZTRMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrmv,
                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
}

/*
 * xTBMV: K >= 0 is checked first, then the TRMV checks. K is forwarded as
 * the call's K.
 */

void ScriptIntrinsicBLAS::STBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    // TBMV has the same requirements as TRMV + K >= 0
    if (K < 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    }
    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbmv,
                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::DTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    // TBMV has the same requirements as TRMV + K >= 0
    if (K < 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    }
    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbmv,
                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::CTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    // TBMV has the same requirements as TRMV + K >= 0
    if (K < 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    }
    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbmv,
                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
                                 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::ZTBMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    // TBMV has the same requirements as TRMV + K >= 0
    if (K < 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    }
    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbmv,
                           TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
}

/*
 * xTPMV: N is taken from the packed length of Ap (see validateTPMV). Ap is
 * input 0 and X is input 1.
 */

void ScriptIntrinsicBLAS::STPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
    int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpmv,
                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::DTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
    int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpmv,
                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::CTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
    int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpmv,
                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
                                 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::ZTPMV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
    int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpmv,
                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
                           Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
}

// xTRSV: validated exactly like TRMV; only the dispatched function differs.

void ScriptIntrinsicBLAS::STRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    // TRSV is the same as TRMV
    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_strsv,
                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::DTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    // TRSV is the same as TRMV
    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtrsv,
                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::CTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    // TRSV is the same as TRMV
    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctrsv,
                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
                                 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::ZTRSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    // TRSV is the same as TRMV
    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztrsv,
                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
}

/*
 * xTBSV: TRMV checks, then K >= 0. K == 0 is accepted, so the error message
 * says "greater than or equal to 0", matching xTBMV. The previous message,
 * "Number of diagonals must be positive", wrongly implied that K == 0 was
 * rejected.
 */

void ScriptIntrinsicBLAS::STBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    // TBSV is the same as TRMV + K >= 0
    validateTRMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    if (K < 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    }
    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stbsv,
                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::DTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    // TBSV is the same as TRMV + K >= 0
    validateTRMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    if (K < 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    }
    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtbsv,
                                TransA, 0, 0, Uplo, Diag, 0, N, K, 0,
                                A->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::CTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    // TBSV is the same as TRMV + K >= 0
    validateTRMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    if (K < 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    }
    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctbsv,
                                 TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
                                 A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::ZTBSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag, int K,
                                const sp<Allocation>& A, const sp<Allocation>& X, int incX) {
    // TBSV is the same as TRMV + K >= 0
    validateTRMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, A, X, incX);
    int N = A->getType()->getY();
    if (K < 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    }
    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztbsv,
                           TransA, 0, 0, Uplo, Diag, 0, N, K, 0, 0,
                           A->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
}

// xTPSV: validated exactly like TPMV; only the dispatched function differs.

void ScriptIntrinsicBLAS::STPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
    // TPSV is same as TPMV
    int N = validateTPMV(mRS, Element::F32(mRS), Uplo, TransA, Diag, Ap, X, incX);
    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_stpsv,
                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::DTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
    // TPSV is same as TPMV
    int N = validateTPMV(mRS, Element::F64(mRS), Uplo, TransA, Diag, Ap, X, incX);
    nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dtpsv,
                                TransA, 0, 0, Uplo, Diag, 0, N, 0, 0,
                                Ap->getID(), X->getID(), 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::CTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
    // TPSV is same as TPMV
    int N = validateTPMV(mRS, Element::F32_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
    nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_ctpsv,
                                 TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
                                 Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::ZTPSV(RsBlasUplo Uplo, RsBlasTranspose TransA, RsBlasDiag Diag,
                                const sp<Allocation>& Ap, const sp<Allocation>& X, int incX) {
    // TPSV is same as TPMV
    int N = validateTPMV(mRS, Element::F64_2(mRS), Uplo, TransA, Diag, Ap, X, incX);
    nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_ztpsv,
                           TransA, 0, 0, Uplo, Diag, 0, N, 0, 0, 0,
                           Ap->getID(), X->getID(), 0, 0, 0, incX, 0, 0, 0);
}

/**
 * Level 2, S and D only
 */

/**
 * Validates SYMV/SBMV operands and returns N. A must be square (N = getY()),
 * and X and Y must be 1-D with positive increments and X dimensions of
 * 1 + (N - 1) * inc.
 */
static int validateSYMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo,
                        const sp<Allocation>& A, const sp<Allocation>& X,
                        const sp<Allocation>& Y, int incX, int incY) {
    int N = static_cast<int>(A->getType()->getY());
    if (static_cast<int>(A->getType()->getX()) != N) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a square matrix for SYMV");
    }
    if (!A->getType()->getElement()->isCompatible(e) ||
        !X->getType()->getElement()->isCompatible(e) ||
        !Y->getType()->getElement()->isCompatible(e)) {
        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    }
    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    }

    if (incX <= 0 || incY <= 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    }
    int expectedXDim = 1 + (N - 1) * incX;
    if (static_cast<int>(X->getType()->getX()) != expectedXDim) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
    }
    int expectedYDim = 1 + (N - 1) * incY;
    if (static_cast<int>(Y->getType()->getX()) != expectedYDim) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYMV");
    }
    return N;
}

/**
 * Validates SPMV operands and returns N, which is derived from the packed
 * length of Ap exactly as in validateTPMV. X and Y must be 1-D with positive
 * increments and X dimensions of 1 + (N - 1) * inc.
 */
static int validateSPMV(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo,
                        const sp<Allocation>& Ap, const sp<Allocation>& X, int incX,
                        const sp<Allocation>& Y, int incY) {
    if (!Ap->getType()->getElement()->isCompatible(e) ||
        !X->getType()->getElement()->isCompatible(e) ||
        !Y->getType()->getElement()->isCompatible(e)) {
        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    }
    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    }

    if (Ap->getType()->getY() > 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
    }

    int N = static_cast<int>(sqrt(static_cast<double>(Ap->getType()->getX()) * 2));
    if (static_cast<int>(Ap->getType()->getX()) != ((N * (N + 1)) / 2)) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
    }
    if (incX <= 0 || incY <= 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    }
    int expectedXDim = 1 + (N - 1) * incX;
    if (static_cast<int>(X->getType()->getX()) != expectedXDim) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
    }
    int expectedYDim = 1 + (N - 1) * incY;
    if (static_cast<int>(Y->getType()->getX()) != expectedYDim) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPMV");
    }

    return N;
}

/**
 * Validates GER operands. A has M = getY() rows and N = getX() columns, both
 * of which must be at least 1. X must have X dimension 1 + (M - 1) * incX and
 * Y must have X dimension 1 + (N - 1) * incY.
 */
static void validateGER(RS* mRS, const sp<const Element>& e, const sp<Allocation>& X, int incX,
                        const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
    if (!A->getType()->getElement()->isCompatible(e) ||
        !X->getType()->getElement()->isCompatible(e) ||
        !Y->getType()->getElement()->isCompatible(e)) {
        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    }

    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    }

    int M = static_cast<int>(A->getType()->getY());
    int N = static_cast<int>(A->getType()->getX());

    if (N < 1 || M < 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "M and N must be 1 or greater for GER");
    }
    if (incX <= 0 || incY <= 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    }
    int expectedXDim = 1 + (M - 1) * incX;
    if (static_cast<int>(X->getType()->getX()) != expectedXDim) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
    }
    int expectedYDim = 1 + (N - 1) * incY;
    if (static_cast<int>(Y->getType()->getX()) != expectedYDim) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GER");
    }
}

/**
 * Validates SYR operands and returns N = A->getType()->getX().
 * NOTE(review): the "symmetric" error below is raised only when A is not
 * square; symmetry of the contents is not checked here.
 */
static int validateSYR(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo,
                       const sp<Allocation>& X, int incX, const sp<Allocation>& A) {
    if (!A->getType()->getElement()->isCompatible(e) ||
        !X->getType()->getElement()->isCompatible(e)) {
        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    }

    int N = static_cast<int>(A->getType()->getX());

    if (X->getType()->getY() > 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    }
    if (N != static_cast<int>(A->getType()->getY())) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
    }
    if (incX <= 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    }
    int expectedXDim = 1 + (N - 1) * incX;
    if (static_cast<int>(X->getType()->getX()) != expectedXDim) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR");
    }
    return N;
}

/**
 * Validates SPR operands and returns N, which is derived from the packed
 * length of Ap exactly as in validateTPMV.
 */
static int validateSPR(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo,
                       const sp<Allocation>& X, int incX, const sp<Allocation>& Ap) {
    if (!Ap->getType()->getElement()->isCompatible(e) ||
        !X->getType()->getElement()->isCompatible(e)) {
        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    }
    if (X->getType()->getY() > 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    }

    if (Ap->getType()->getY() > 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
    }

    int N = static_cast<int>(sqrt(static_cast<double>(Ap->getType()->getX()) * 2));
    if (static_cast<int>(Ap->getType()->getX()) != ((N * (N + 1)) / 2)) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
    }
    if (incX <= 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    }
    int expectedXDim = 1 + (N - 1) * incX;
    if (static_cast<int>(X->getType()->getX()) != expectedXDim) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR");
    }

    return N;
}

/**
 * Validates SYR2 operands and returns N = A->getType()->getX(). The
 * dimension error now names SYR2; it previously said "SYR" because of a
 * copy-paste from validateSYR.
 * NOTE(review): as in validateSYR, the "symmetric" error is a squareness check.
 */
static int validateSYR2(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo,
                        const sp<Allocation>& X, int incX, const sp<Allocation>& Y, int incY,
                        const sp<Allocation>& A) {
    if (!A->getType()->getElement()->isCompatible(e) ||
        !X->getType()->getElement()->isCompatible(e) ||
        !Y->getType()->getElement()->isCompatible(e)) {
        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    }

    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    }

    int N = static_cast<int>(A->getType()->getX());

    if (N != static_cast<int>(A->getType()->getY())) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "A must be a symmetric matrix");
    }
    if (incX <= 0 || incY <= 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    }
    int expectedXDim = 1 + (N - 1) * incX;
    int expectedYDim = 1 + (N - 1) * incY;
    if (static_cast<int>(X->getType()->getX()) != expectedXDim ||
        static_cast<int>(Y->getType()->getX()) != expectedYDim) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SYR2");
    }
    return N;
}

/**
 * Validates SPR2 operands and returns N, which is derived from the packed
 * length of Ap exactly as in validateTPMV.
 */
static int validateSPR2(RS* mRS, const sp<const Element>& e, RsBlasUplo Uplo,
                        const sp<Allocation>& X, int incX, const sp<Allocation>& Y, int incY,
                        const sp<Allocation>& Ap) {
    if (!Ap->getType()->getElement()->isCompatible(e) ||
        !X->getType()->getElement()->isCompatible(e) ||
        !Y->getType()->getElement()->isCompatible(e)) {
        mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type");
    }
    if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1");
    }

    if (Ap->getType()->getY() > 1) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Ap must have a Y dimension of 0 or 1");
    }

    int N = static_cast<int>(sqrt(static_cast<double>(Ap->getType()->getX()) * 2));
    if (static_cast<int>(Ap->getType()->getX()) != ((N * (N + 1)) / 2)) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid dimension for Ap");
    }
    if (incX <= 0 || incY <= 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0");
    }
    int expectedXDim = 1 + (N - 1) * incX;
    int expectedYDim = 1 + (N - 1) * incY;
    if (static_cast<int>(X->getType()->getX()) != expectedXDim ||
        static_cast<int>(Y->getType()->getX()) != expectedYDim) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for SPR2");
    }

    return N;
}

// SSYMV: dispatches RsBlas_ssymv with N from validateSYMV; A, X, Y are inputs 0, 1, 2.
void ScriptIntrinsicBLAS::SSYMV(RsBlasUplo Uplo, float alpha, const sp<Allocation>& A,
                                const sp<Allocation>& X, int incX, float beta,
                                const sp<Allocation>& Y, int incY) {
    int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymv,
                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
                                A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
}

// SSBMV: K >= 0 is checked first, then the SYMV checks; K is forwarded as the call's K.
void ScriptIntrinsicBLAS::SSBMV(RsBlasUplo Uplo, int K, float alpha, const sp<Allocation>& A,
                                const sp<Allocation>& X, int incX, float beta,
                                const sp<Allocation>& Y, int incY) {
    // SBMV is the same as SYMV + K >= 0
    if (K < 0) {
        mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0");
    }
    int N = validateSYMV(mRS, Element::F32(mRS), Uplo, A, X, Y, incX, incY);
    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssbmv,
                                0, 0, 0, Uplo, 0, 0, N, K, alpha,
                                A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
}

// SSPMV: dispatches RsBlas_sspmv with N derived from the packed length of Ap.
void ScriptIntrinsicBLAS::SSPMV(RsBlasUplo Uplo, float alpha, const sp<Allocation>& Ap,
                                const sp<Allocation>& X, int incX, float beta,
                                const sp<Allocation>& Y, int incY) {
    int N = validateSPMV(mRS, Element::F32(mRS), Uplo, Ap, X, incX, Y, incY);
    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspmv,
                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
                                Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0);
}

// SGER: X, Y and A are inputs 0, 1 and 2; beta is passed as 0.
void ScriptIntrinsicBLAS::SGER(float alpha, const sp<Allocation>& X, int incX,
                               const sp<Allocation>& Y, int incY, const sp<Allocation>& A) {
    int M = A->getType()->getY();
    int N = A->getType()->getX();
    validateGER(mRS, Element::F32(mRS), X, incX, Y, incY, A);
    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sger,
                                0, 0, 0, 0, 0, M, N, 0, alpha,
                                X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0);
}

// SSYR: X is input 0, A is input 1; beta is 0 and input 2 is empty.
void ScriptIntrinsicBLAS::SSYR(RsBlasUplo Uplo, float alpha, const sp<Allocation>& X, int incX,
                               const sp<Allocation>& A) {
    int N = validateSYR(mRS, Element::F32(mRS), Uplo, X, incX, A);
    nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr,
                                0, 0, 0, Uplo, 0, 0, N, 0, alpha,
                                X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0);
}

void ScriptIntrinsicBLAS::SSPR(RsBlasUplo Uplo, float alpha, const sp
& X, int incX, const sp
& Ap) { int N = validateSPR(mRS, Element::F32(mRS), Uplo, X, incX, Ap); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::SSYR2(RsBlasUplo Uplo, float alpha, const sp
& X, int incX, const sp
& Y, int incY, const sp
& A) { int N = validateSYR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, A); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::SSPR2(RsBlasUplo Uplo, float alpha, const sp
& X, int incX, const sp
& Y, int incY, const sp
& Ap) { int N = validateSPR2(mRS, Element::F32(mRS), Uplo, X, incX, Y, incY, Ap); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::DSYMV(RsBlasUplo Uplo, double alpha, const sp
& A, const sp
& X, int incX, double beta, const sp
& Y, int incY) { int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::DSBMV(RsBlasUplo Uplo, int K, double alpha, const sp
& A, const sp
& X, int incX, double beta, const sp
& Y, int incY) { // SBMV is the same as SYMV + K >= 0 if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be greater than or equal to 0"); } int N = validateSYMV(mRS, Element::F64(mRS), Uplo, A, X, Y, incX, incY); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha, A->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::DSPMV(RsBlasUplo Uplo, double alpha, const sp
& Ap, const sp
& X, int incX, double beta, const sp
& Y, int incY) { int N = validateSPMV(mRS, Element::F64(mRS), Uplo, Ap, X, incX, Y, incY); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, Ap->getID(), X->getID(), beta, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::DGER(double alpha, const sp
& X, int incX, const sp
& Y, int incY, const sp
& A) { int M = A->getType()->getY(); int N = A->getType()->getX(); validateGER(mRS, Element::F64(mRS), X, incX, Y, incY, A); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dger, 0, 0, 0, 0, 0, M, N, 0, alpha, X->getID(), Y->getID(), 0.f, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::DSYR(RsBlasUplo Uplo, double alpha, const sp
& X, int incX, const sp
& A) { int N = validateSYR(mRS, Element::F64(mRS), Uplo, X, incX, A); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), A->getID(), 0.f, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::DSPR(RsBlasUplo Uplo, double alpha, const sp
& X, int incX, const sp
& Ap) { int N = validateSPR(mRS, Element::F64(mRS), Uplo, X, incX, Ap); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), Ap->getID(), 0.f, 0, incX, 0, 0, 0); } void ScriptIntrinsicBLAS::DSYR2(RsBlasUplo Uplo, double alpha, const sp
& X, int incX, const sp
& Y, int incY, const sp
& A) { int N = validateSYR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, A); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), Y->getID(), 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::DSPR2(RsBlasUplo Uplo, double alpha, const sp
& X, int incX, const sp
& Y, int incY, const sp
& Ap) { int N = validateSPR2(mRS, Element::F64(mRS), Uplo, X, incX, Y, incY, Ap); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dspr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, X->getID(), Y->getID(), 0, Ap->getID(), incX, incY, 0, 0); } /** * Level 2, C and Z only */ static void validateGERU(RS* mRS, const sp
& e, const sp
& X, int incX, const sp
& Y, int incY, const sp
& A) { if (!A->getType()->getElement()->isCompatible(e) || !X->getType()->getElement()->isCompatible(e) || !Y->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (X->getType()->getY() > 1 || Y->getType()->getY() > 1) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "BLAS vectors must have Y dimension of 0 or 1"); } int M = A->getType()->getY(); int N = A->getType()->getX(); if (incX <= 0 || incY <= 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Vector increments must be greater than 0"); } int expectedXDim = 1 + (M - 1) * incX; if ((int)X->getType()->getX() != expectedXDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU"); } int expectedYDim = 1 + (N - 1) * incY; if ((int)Y->getType()->getX() != expectedYDim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Incorrect vector dimensions for GERU"); } } void ScriptIntrinsicBLAS::CHEMV(RsBlasUplo Uplo, Float2 alpha, const sp
& A, const sp
& X, int incX, Float2 beta, const sp
& Y, int incY) { // HEMV is the same as SYR2 validation-wise int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::CHBMV(RsBlasUplo Uplo, int K, Float2 alpha, const sp
& A, const sp
& X, int incX, Float2 beta, const sp
& Y, int incY) { // HBMV is the same as SYR2 validation-wise int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV"); } nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::CHPMV(RsBlasUplo Uplo, Float2 alpha, const sp
& Ap, const sp
& X, int incX, Float2 beta, const sp
& Y, int incY) { // HPMV is the same as SPR2 int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::CGERU(Float2 alpha, const sp
& X, int incX, const sp
& Y, int incY, const sp
& A) { validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A); int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::CGERC(Float2 alpha, const sp
& X, int incX, const sp
& Y, int incY, const sp
& A) { // Same as GERU validateGERU(mRS, Element::F32_2(mRS), X, incX, Y, incY, A); int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::CHER(RsBlasUplo Uplo, float alpha, const sp
& X, int incX, const sp
& A) { // Same as SYR int N = validateSYR(mRS, Element::F32_2(mRS), Uplo, X, incX, A); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X->getID(), 0, 0, 0, A->getID(), incX, 0, 0, 0); } void ScriptIntrinsicBLAS::CHPR(RsBlasUplo Uplo, float alpha, const sp
& X, int incX, const sp
& Ap) { // Equivalent to SPR for validation int N = validateSPR(mRS, Element::F32_2(mRS), Uplo, X, incX, Ap); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X->getID(), 0, 0, 0, Ap->getID(), incX, 0, 0, 0); } void ScriptIntrinsicBLAS::CHER2(RsBlasUplo Uplo, Float2 alpha, const sp
& X, int incX, const sp
& Y, int incY, const sp
& A) { // Same as SYR2 int N = validateSYR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, A); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::CHPR2(RsBlasUplo Uplo, Float2 alpha, const sp
& X, int incX, const sp
& Y, int incY, const sp
& Ap) { // Same as SPR2 int N = validateSPR2(mRS, Element::F32_2(mRS), Uplo, X, incX, Y, incY, Ap); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_chpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, Ap->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZHEMV(RsBlasUplo Uplo, Double2 alpha, const sp
& A, const sp
& X, int incX, Double2 beta, const sp
& Y, int incY) { // HEMV is the same as SYR2 validation-wise int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhemv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, A->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZHBMV(RsBlasUplo Uplo, int K, Double2 alpha, const sp
& A, const sp
& X, int incX, Double2 beta, const sp
& Y, int incY) { // HBMV is the same as SYR2 validation-wise int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); if (K < 0) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "K must be 0 or greater for HBMV"); } nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhbmv, 0, 0, 0, Uplo, 0, 0, N, K, alpha.x, alpha.y, A->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZHPMV(RsBlasUplo Uplo, Double2 alpha, const sp
& Ap, const sp
& X, int incX, Double2 beta, const sp
& Y, int incY) { // HPMV is the same as SPR2 int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpmv, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, Ap->getID(), X->getID(), beta.x, beta.y, Y->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZGERU(Double2 alpha, const sp
& X, int incX, const sp
& Y, int incY, const sp
& A) { validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A); int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgeru, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZGERC(Double2 alpha, const sp
& X, int incX, const sp
& Y, int incY, const sp
& A) { // Same as GERU validateGERU(mRS, Element::F64_2(mRS), X, incX, Y, incY, A); int M = A->getType()->getY(); int N = A->getType()->getX(); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgerc, 0, 0, 0, 0, 0, M, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZHER(RsBlasUplo Uplo, double alpha, const sp
& X, int incX, const sp
& A) { // Same as SYR int N = validateSYR(mRS, Element::F64_2(mRS), Uplo, X, incX, A); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X->getID(), 0, 0, 0, A->getID(), incX, 0, 0, 0); } void ScriptIntrinsicBLAS::ZHPR(RsBlasUplo Uplo, double alpha, const sp
& X, int incX, const sp
& Ap) { // Equivalent to SPR for validation int N = validateSPR(mRS, Element::F64_2(mRS), Uplo, X, incX, Ap); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr, 0, 0, 0, Uplo, 0, 0, N, 0, alpha, 0, X->getID(), 0, 0, 0, Ap->getID(), incX, 0, 0, 0); } void ScriptIntrinsicBLAS::ZHER2(RsBlasUplo Uplo, Double2 alpha, const sp
& X, int incX, const sp
& Y, int incY, const sp
& A) { // Same as SYR2 int N = validateSYR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, A); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zher2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, A->getID(), incX, incY, 0, 0); } void ScriptIntrinsicBLAS::ZHPR2(RsBlasUplo Uplo, Double2 alpha, const sp
& X, int incX, const sp
& Y, int incY, const sp
& Ap) { // Same as SPR2 int N = validateSPR2(mRS, Element::F64_2(mRS), Uplo, X, incX, Y, incY, Ap); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zhpr2, 0, 0, 0, Uplo, 0, 0, N, 0, alpha.x, alpha.y, X->getID(), Y->getID(), 0, 0, Ap->getID(), incX, incY, 0, 0); } /** * Level 3 BLAS */ static void validateL3(RS* mRS, const sp
& e, int TransA, int TransB, int Side, const sp
& A, const sp
& B, const sp
& C) { int aM = -1, aN = -1, bM = -1, bN = -1, cM = -1, cN = -1; if ((A != nullptr && !A->getType()->getElement()->isCompatible(e)) || (B != nullptr && !B->getType()->getElement()->isCompatible(e)) || (C != nullptr && !C->getType()->getElement()->isCompatible(e))) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } if (C == nullptr) { // Since matrix C is used to store the result, it cannot be null. mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Allocation C cannot be null"); } cM = C->getType()->getY(); cN = C->getType()->getX(); if (Side == RsBlasRight) { if ((A == nullptr && B != nullptr) || (A != nullptr && B == nullptr)) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Provided Matrix A without Matrix B, or vice versa"); } if (B != nullptr) { bM = A->getType()->getY(); bN = A->getType()->getX(); } if (A != nullptr) { aM = B->getType()->getY(); aN = B->getType()->getX(); } } else { if (A != nullptr) { if (TransA == RsBlasTrans || TransA == RsBlasConjTrans) { aN = A->getType()->getY(); aM = A->getType()->getX(); } else { aM = A->getType()->getY(); aN = A->getType()->getX(); } } if (B != nullptr) { if (TransB == RsBlasTrans || TransB == RsBlasConjTrans) { bN = B->getType()->getY(); bM = B->getType()->getX(); } else { bM = B->getType()->getY(); bN = B->getType()->getX(); } } } if (A != nullptr && B != nullptr && C != nullptr) { if (aN != bM || aM != cM || bN != cN) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); } } else if (A != nullptr && C != nullptr) { // A and C only, for SYRK if (cM != cN) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix C is not symmetric"); } if (aM != cM) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); } } else if (A != nullptr && B != nullptr) { // A and B only if (aN != bM) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Called BLAS with invalid dimensions"); } } } void ScriptIntrinsicBLAS::SGEMM(RsBlasTranspose TransA, 
RsBlasTranspose TransB, float alpha, const sp
& A, const sp
& B, float beta, const sp
& C) { validateL3(mRS, Element::F32(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA != RsBlasNoTrans) { M = A->getType()->getX(); K = A->getType()->getY(); } else { M = A->getType()->getY(); K = A->getType()->getX(); } if (TransB != RsBlasNoTrans) { N = B->getType()->getY(); } else { N = B->getType()->getX(); } nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_sgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A->getID(), B->getID(), beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::DGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, double alpha, const sp
& A, const sp
& B, double beta, const sp
& C) { validateL3(mRS, Element::F64(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA != RsBlasNoTrans) { M = A->getType()->getX(); K = A->getType()->getY(); } else { M = A->getType()->getY(); K = A->getType()->getX(); } if (TransB != RsBlasNoTrans) { N = B->getType()->getY(); } else { N = B->getType()->getX(); } nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha, A->getID(), B->getID(), beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::CGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Float2 alpha, const sp
& A, const sp
& B, Float2 beta, const sp
& C) { validateL3(mRS, Element::F32_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA != RsBlasNoTrans) { M = A->getType()->getX(); K = A->getType()->getY(); } else { M = A->getType()->getY(); K = A->getType()->getX(); } if (TransB != RsBlasNoTrans) { N = B->getType()->getY(); } else { N = B->getType()->getX(); } nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_cgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A->getID(), B->getID(), beta.x, beta.y, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::ZGEMM(RsBlasTranspose TransA, RsBlasTranspose TransB, Double2 alpha, const sp
& A, const sp
& B, Double2 beta, const sp
& C) { validateL3(mRS, Element::F64_2(mRS), TransA, TransB, 0, A, B, C); int M = -1, N = -1, K = -1; if (TransA != RsBlasNoTrans) { M = A->getType()->getX(); K = A->getType()->getY(); } else { M = A->getType()->getY(); K = A->getType()->getX(); } if (TransB != RsBlasNoTrans) { N = B->getType()->getY(); } else { N = B->getType()->getX(); } nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zgemm, TransA, TransB, 0, 0, 0, M, N, K, alpha.x, alpha.y, A->getID(), B->getID(), beta.x, beta.y, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::SSYMM(RsBlasSide Side, RsBlasUplo Uplo, float alpha, const sp
& A, const sp
& B, float beta, const sp
& C) { //For SYMM, Matrix A should be symmetric if (A->getType()->getX() != A->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); } validateL3(mRS, Element::F32(mRS), 0, 0, Side, A, B, C); nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssymm, 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, alpha, A->getID(), B->getID(), beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::DSYMM(RsBlasSide Side, RsBlasUplo Uplo, double alpha, const sp
& A, const sp
& B, double beta, const sp
& C) { if (A->getType()->getX() != A->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); } validateL3(mRS, Element::F64(mRS), 0, 0, Side, A, B, C); nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsymm, 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, alpha, A->getID(), B->getID(), beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::CSYMM(RsBlasSide Side, RsBlasUplo Uplo, Float2 alpha, const sp
& A, const sp
& B, Float2 beta, const sp
& C) { if (A->getType()->getX() != A->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); } validateL3(mRS, Element::F32_2(mRS), 0, 0, Side, A, B, C); nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csymm, 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, alpha.x, alpha.y, A->getID(), B->getID(), beta.x, beta.y, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::ZSYMM(RsBlasSide Side, RsBlasUplo Uplo, Double2 alpha, const sp
& A, const sp
& B, Double2 beta, const sp
& C) { if (A->getType()->getX() != A->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Matrix A is not symmetric"); } validateL3(mRS, Element::F64_2(mRS), 0, 0, Side, A, B, C); nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsymm, 0, 0, Side, Uplo, 0, C->getType()->getY(), C->getType()->getX(), 0, alpha.x, alpha.y, A->getID(), B->getID(), beta.x, beta.y, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::SSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, const sp
& A, float beta, const sp
& C) { validateL3(mRS, Element::F32(mRS), Trans, 0, 0, A, nullptr, C); int K = -1; if (Trans != RsBlasNoTrans) { K = A->getType()->getY(); } else { K = A->getType()->getX(); } nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyrk, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, alpha, A->getID(), 0, beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::DSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, const sp
& A, double beta, const sp
& C) { validateL3(mRS, Element::F64(mRS), Trans, 0, 0, A, nullptr, C); int K = -1; if (Trans != RsBlasNoTrans) { K = A->getType()->getY(); } else { K = A->getType()->getX(); } nScriptIntrinsicBLAS_Double(mRS, mRS->getContext(), getID(), RsBlas_dsyrk, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, alpha, A->getID(), 0, beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::CSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Float2 alpha, const sp
& A, Float2 beta, const sp
& C) { validateL3(mRS, Element::F32_2(mRS), Trans, 0, 0, A, nullptr, C); int K = -1; if (Trans != RsBlasNoTrans) { K = A->getType()->getY(); } else { K = A->getType()->getX(); } nScriptIntrinsicBLAS_Complex(mRS, mRS->getContext(), getID(), RsBlas_csyrk, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, alpha.x, alpha.y, A->getID(), 0, beta.x, beta.y, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::ZSYRK(RsBlasUplo Uplo, RsBlasTranspose Trans, Double2 alpha, const sp
& A, Double2 beta, const sp
& C) { validateL3(mRS, Element::F64_2(mRS), Trans, 0, 0, A, nullptr, C); int K = -1; if (Trans != RsBlasNoTrans) { K = A->getType()->getY(); } else { K = A->getType()->getX(); } nScriptIntrinsicBLAS_Z(mRS, mRS->getContext(), getID(), RsBlas_zsyrk, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, alpha.x, alpha.y, A->getID(), 0, beta.x, beta.y, C->getID(), 0, 0, 0, 0); } static void validateSYR2K(RS* mRS, const sp
& e, RsBlasTranspose Trans, const sp
& A, const sp
& B, const sp
& C) { if (!A->getType()->getElement()->isCompatible(e) || !B->getType()->getElement()->isCompatible(e) || !C->getType()->getElement()->isCompatible(e)) { mRS->throwError(RS_ERROR_INVALID_ELEMENT, "Called BLAS with wrong Element type"); } int Cdim = -1; // A is n x k if no transpose, k x n if transpose // C is n x n if (Trans == RsBlasTrans) { // check columns versus C Cdim = A->getType()->getX(); } else { // check rows versus C Cdim = A->getType()->getY(); } if ((int)C->getType()->getX() != Cdim || (int)C->getType()->getY() != Cdim) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid symmetric matrix in SYR2K"); } // A dims == B dims if (A->getType()->getX() != B->getType()->getX() || A->getType()->getY() != B->getType()->getY()) { mRS->throwError(RS_ERROR_INVALID_PARAMETER, "Invalid A and B in SYR2K"); } } void ScriptIntrinsicBLAS::SSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, float alpha, const sp
& A, const sp
& B, float beta, const sp
& C) { validateSYR2K(mRS, Element::F32(mRS), Trans, A, B, C); int K = -1; if (Trans != RsBlasNoTrans) { K = A->getType()->getY(); } else { K = A->getType()->getX(); } nScriptIntrinsicBLAS_Single(mRS, mRS->getContext(), getID(), RsBlas_ssyr2k, Trans, 0, 0, Uplo, 0, 0, C->getType()->getX(), K, alpha, A->getID(), B->getID(), beta, C->getID(), 0, 0, 0, 0); } void ScriptIntrinsicBLAS::DSYR2K(RsBlasUplo Uplo, RsBlasTranspose Trans, double alpha, const sp