about summary refs log tree commit diff homepage
path: root/zluda_blas/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'zluda_blas/src/lib.rs')
-rw-r--r-- zluda_blas/src/lib.rs 123
1 files changed, 123 insertions, 0 deletions
diff --git a/zluda_blas/src/lib.rs b/zluda_blas/src/lib.rs
index 0d50a99..c4b0bc5 100644
--- a/zluda_blas/src/lib.rs
+++ b/zluda_blas/src/lib.rs
@@ -916,3 +916,126 @@ unsafe fn dtrsm(
ldb,
))
}
+
+/// Implements `cublasGemmBatchedEx` on top of rocBLAS.
+///
+/// Translates the cuBLAS operation, data-type, compute-type and algorithm
+/// enums to their rocBLAS equivalents, then forwards to
+/// `rocblas_gemm_batched_ex` and maps the status back with `to_cuda`.
+///
+/// rocBLAS's `gemm_ex` family takes a separate output matrix D, whereas
+/// cuBLAS updates C in place; `c` is therefore passed as both the C input
+/// and the D output, with the same type and leading dimension.
+unsafe fn gemm_batched_ex(
+    handle: cublasHandle_t,
+    transa: cublasOperation_t,
+    transb: cublasOperation_t,
+    m: i32,
+    n: i32,
+    k: i32,
+    alpha: *const std::ffi::c_void,
+    // `a`, `b`, `c` are device arrays of per-batch matrix pointers.
+    a: *const *const std::ffi::c_void,
+    atype: cudaDataType_t,
+    lda: i32,
+    b: *const *const std::ffi::c_void,
+    btype: cudaDataType_t,
+    ldb: i32,
+    beta: *const std::ffi::c_void,
+    c: *const *mut std::ffi::c_void,
+    ctype: cudaDataType_t,
+    ldc: i32,
+    batch_count: i32,
+    compute_type: cublasComputeType_t,
+    algo: cublasGemmAlgo_t,
+) -> cublasStatus_t {
+    // Convert each cuBLAS enum to the corresponding rocBLAS enum.
+    let transa = op_from_cuda(transa);
+    let transb = op_from_cuda(transb);
+    let atype = type_from_cuda(atype);
+    let btype = type_from_cuda(btype);
+    let ctype = type_from_cuda(ctype);
+    let compute_type = to_compute_type(compute_type);
+    let algo = to_algo(algo);
+    to_cuda(rocblas_gemm_batched_ex(
+        handle.cast(),
+        transa,
+        transb,
+        m,
+        n,
+        k,
+        alpha,
+        a.cast(),
+        atype,
+        lda,
+        b.cast(),
+        btype,
+        ldb,
+        beta,
+        // C input.
+        c.cast(),
+        ctype,
+        ldc,
+        // D output: same buffers as C — in-place update, matching cuBLAS
+        // semantics. `cast_mut` only adjusts the outer pointer's constness.
+        c.cast_mut().cast(),
+        ctype,
+        ldc,
+        batch_count,
+        compute_type,
+        algo,
+        // solution_index: 0 presumably lets rocBLAS choose — TODO confirm
+        // against the rocBLAS gemm_ex documentation.
+        0,
+        rocblas_gemm_flags::rocblas_gemm_flags_none.0,
+    ))
+}
+
+/// Implements `cublasGemmStridedBatchedEx` on top of rocBLAS.
+///
+/// Like `cublasGemmBatchedEx`, but batches are laid out in single buffers
+/// with fixed strides (`stride_a`/`stride_b`/`stride_c`) between
+/// consecutive matrices instead of arrays of pointers.
+///
+/// Enum arguments are converted to rocBLAS equivalents and the call is
+/// forwarded to `rocblas_gemm_strided_batched_ex`; the status is mapped
+/// back with `to_cuda`. rocBLAS takes a separate output matrix D while
+/// cuBLAS updates C in place, so `c` (with `ctype`, `ldc`, `stride_c`)
+/// is passed for both the C input and the D output.
+unsafe fn gemm_strided_batched_ex(
+    handle: cublasHandle_t,
+    transa: cublasOperation_t,
+    transb: cublasOperation_t,
+    m: ::std::os::raw::c_int,
+    n: ::std::os::raw::c_int,
+    k: ::std::os::raw::c_int,
+    alpha: *const ::std::os::raw::c_void,
+    a: *const ::std::os::raw::c_void,
+    atype: cudaDataType,
+    lda: ::std::os::raw::c_int,
+    stride_a: ::std::os::raw::c_longlong,
+    b: *const ::std::os::raw::c_void,
+    btype: cudaDataType,
+    ldb: ::std::os::raw::c_int,
+    stride_b: ::std::os::raw::c_longlong,
+    beta: *const ::std::os::raw::c_void,
+    c: *mut ::std::os::raw::c_void,
+    ctype: cudaDataType,
+    ldc: ::std::os::raw::c_int,
+    stride_c: ::std::os::raw::c_longlong,
+    batch_count: ::std::os::raw::c_int,
+    compute_type: cublasComputeType_t,
+    algo: cublasGemmAlgo_t,
+) -> cublasStatus_t {
+    // Convert each cuBLAS enum to the corresponding rocBLAS enum.
+    let transa = op_from_cuda(transa);
+    let transb = op_from_cuda(transb);
+    let atype = type_from_cuda(atype);
+    let btype = type_from_cuda(btype);
+    let ctype = type_from_cuda(ctype);
+    let compute_type = to_compute_type(compute_type);
+    let algo = to_algo(algo);
+    to_cuda(rocblas_gemm_strided_batched_ex(
+        handle.cast(),
+        transa,
+        transb,
+        m,
+        n,
+        k,
+        alpha,
+        a,
+        atype,
+        lda,
+        stride_a,
+        b,
+        btype,
+        ldb,
+        stride_b,
+        beta,
+        // C input.
+        c,
+        ctype,
+        ldc,
+        stride_c,
+        // D output: same buffer/layout as C — in-place update, matching
+        // cuBLAS semantics.
+        c,
+        ctype,
+        ldc,
+        stride_c,
+        batch_count,
+        compute_type,
+        algo,
+        // solution_index: 0 presumably lets rocBLAS choose — TODO confirm
+        // against the rocBLAS gemm_ex documentation.
+        0,
+        rocblas_gemm_flags::rocblas_gemm_flags_none.0,
+    ))
+}