diff options
Diffstat (limited to 'zluda_blas/src/lib.rs')
-rw-r--r-- | zluda_blas/src/lib.rs | 123 |
1 file changed, 123 insertions, 0 deletions
diff --git a/zluda_blas/src/lib.rs b/zluda_blas/src/lib.rs index 0d50a99..c4b0bc5 100644 --- a/zluda_blas/src/lib.rs +++ b/zluda_blas/src/lib.rs @@ -916,3 +916,126 @@ unsafe fn dtrsm( ldb, )) } + +unsafe fn gemm_batched_ex( + handle: cublasHandle_t, + transa: cublasOperation_t, + transb: cublasOperation_t, + m: i32, + n: i32, + k: i32, + alpha: *const std::ffi::c_void, + a: *const *const std::ffi::c_void, + atype: cudaDataType_t, + lda: i32, + b: *const *const std::ffi::c_void, + btype: cudaDataType_t, + ldb: i32, + beta: *const std::ffi::c_void, + c: *const *mut std::ffi::c_void, + ctype: cudaDataType_t, + ldc: i32, + batch_count: i32, + compute_type: cublasComputeType_t, + algo: cublasGemmAlgo_t, +) -> cublasStatus_t { + let transa = op_from_cuda(transa); + let transb = op_from_cuda(transb); + let atype = type_from_cuda(atype); + let btype = type_from_cuda(btype); + let ctype = type_from_cuda(ctype); + let compute_type = to_compute_type(compute_type); + let algo = to_algo(algo); + to_cuda(rocblas_gemm_batched_ex( + handle.cast(), + transa, + transb, + m, + n, + k, + alpha, + a.cast(), + atype, + lda, + b.cast(), + btype, + ldb, + beta, + c.cast(), + ctype, + ldc, + c.cast_mut().cast(), + ctype, + ldc, + batch_count, + compute_type, + algo, + 0, + rocblas_gemm_flags::rocblas_gemm_flags_none.0, + )) +} + +unsafe fn gemm_strided_batched_ex( + handle: cublasHandle_t, + transa: cublasOperation_t, + transb: cublasOperation_t, + m: ::std::os::raw::c_int, + n: ::std::os::raw::c_int, + k: ::std::os::raw::c_int, + alpha: *const ::std::os::raw::c_void, + a: *const ::std::os::raw::c_void, + atype: cudaDataType, + lda: ::std::os::raw::c_int, + stride_a: ::std::os::raw::c_longlong, + b: *const ::std::os::raw::c_void, + btype: cudaDataType, + ldb: ::std::os::raw::c_int, + stride_b: ::std::os::raw::c_longlong, + beta: *const ::std::os::raw::c_void, + c: *mut ::std::os::raw::c_void, + ctype: cudaDataType, + ldc: ::std::os::raw::c_int, + stride_c: ::std::os::raw::c_longlong, + 
batch_count: ::std::os::raw::c_int, + compute_type: cublasComputeType_t, + algo: cublasGemmAlgo_t, +) -> cublasStatus_t { + let transa = op_from_cuda(transa); + let transb = op_from_cuda(transb); + let atype = type_from_cuda(atype); + let btype = type_from_cuda(btype); + let ctype = type_from_cuda(ctype); + let compute_type = to_compute_type(compute_type); + let algo = to_algo(algo); + to_cuda(rocblas_gemm_strided_batched_ex( + handle.cast(), + transa, + transb, + m, + n, + k, + alpha, + a, + atype, + lda, + stride_a, + b, + btype, + ldb, + stride_b, + beta, + c, + ctype, + ldc, + stride_c, + c, + ctype, + ldc, + stride_c, + batch_count, + compute_type, + algo, + 0, + rocblas_gemm_flags::rocblas_gemm_flags_none.0, + )) +} |