author    Andrzej Janik <[email protected]>  2024-02-15 16:54:04 +0000
committer Andrzej Janik <[email protected]>  2024-02-15 16:54:04 +0000
commit    13bf965784bcee152e9591e1b35bc73c60eda723 (patch)
tree      5fd552a27f0b6878f056af36c6259ed2f457d1ba
parent    8fef0e4fe7e6ee6d1f2b556694faffee0c6b9648 (diff)
Add missing bits and pieces
-rw-r--r--  ptx/src/emit.rs                                                                              25
-rw-r--r--  ptx/src/test/spirv_run/mod.rs                                                                 2
-rw-r--r--  ptx/src/test/spirv_run/set_f16x2.ll                                                          68
-rw-r--r--  ptx/src/test/spirv_run/set_f16x2.ptx (renamed from ptx/src/test/spirv_run/st_f16x2.ptx)       8
-rw-r--r--  ptx/src/test/spirv_run/st_f16x2.ll                                                           43
-rw-r--r--  zluda/src/impl/device.rs                                                                      4
-rw-r--r--  zluda_blas/src/cublas.rs                                                                     49
-rw-r--r--  zluda_blas/src/lib.rs                                                                       123
8 files changed, 267 insertions, 55 deletions
diff --git a/ptx/src/emit.rs b/ptx/src/emit.rs
index ac054d4..94cc973 100644
--- a/ptx/src/emit.rs
+++ b/ptx/src/emit.rs
@@ -1176,15 +1176,26 @@ fn emit_inst_set(
) -> Result<(), TranslateError> {
let builder = ctx.builder.get();
let temp_result = emit_inst_setp_float(ctx, details.cmp_op, None, arg.src1, arg.src2)?;
- if details.src_type != ast::ScalarType::F16x2 || details.dst_type == ast::ScalarType::F16x2 {
+ if details.src_type != ast::ScalarType::F16x2 {
+ return Err(TranslateError::todo());
+ }
+ if details.dst_type.is_integer() && details.dst_type.size_of() == mem::size_of::<u32>() as u8 {
+ let b16vec2_type = get_llvm_type(ctx, &ast::Type::Vector(ast::ScalarType::B16, 2))?;
+ let b16vec2_result =
+ unsafe { LLVMBuildSExt(builder, temp_result, b16vec2_type, LLVM_UNNAMED) };
+
+ let u32_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::U32))?;
+ ctx.names.register_result(arg.dst, |dst_name| unsafe {
+ LLVMBuildBitCast(builder, b16vec2_result, u32_type, dst_name)
+ });
+ } else if matches!(details.dst_type, ast::ScalarType::F16x2) {
+ let f16x2_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::F16x2))?;
+ ctx.names.register_result(arg.dst, |dst_name| unsafe {
+ LLVMBuildUIToFP(builder, temp_result, f16x2_type, dst_name)
+ });
+ } else {
return Err(TranslateError::todo());
}
- let b16vec2_type = get_llvm_type(ctx, &ast::Type::Vector(ast::ScalarType::B16, 2))?;
- let b16vec2_result = unsafe { LLVMBuildSExt(builder, temp_result, b16vec2_type, LLVM_UNNAMED) };
- let u32_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::U32))?;
- ctx.names.register_result(arg.dst, |dst_name| unsafe {
- LLVMBuildBitCast(builder, b16vec2_result, u32_type, dst_name)
- });
Ok(())
}
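
The new emit path above distinguishes two destinations for `set` with an f16x2 source: a 32-bit integer destination takes the per-lane comparison result sign-extended to <2 x i16> and bitcast to i32 (0xFFFF or 0x0000 per lane), while an f16x2 destination converts the predicate to 1.0/0.0 half values via uitofp. The sketch below is a host-side model of those semantics, not code from this patch; it assumes the `half` crate that the spirv_run tests already use, and the function names are illustrative only. Applying it to the inputs added to the test further down reproduces the expected outputs 0xffffu32 and 0x3C000000.

    // Host-side model of the two lowering paths above (illustrative only).
    // Assumes the `half` crate that the spirv_run tests already depend on.
    use half::f16;

    // set.gtu.u32.f16x2: per-lane unordered greater-than, integer destination;
    // each true lane becomes 0xFFFF, each false lane 0x0000, packed into a u32.
    fn set_gtu_u32_f16x2(a: u32, b: u32) -> u32 {
        let mut out = 0u32;
        for lane in 0..2 {
            let x = f16::from_bits((a >> (16 * lane)) as u16).to_f32();
            let y = f16::from_bits((b >> (16 * lane)) as u16).to_f32();
            // "unordered" comparisons are also true when either operand is NaN
            if x.is_nan() || y.is_nan() || x > y {
                out |= 0xFFFFu32 << (16 * lane);
            }
        }
        out
    }

    // set.eq.f16x2.f16x2: per-lane ordered equality, f16x2 destination;
    // each true lane becomes the half value 1.0 (0x3C00), each false lane 0.0.
    fn set_eq_f16x2_f16x2(a: u32, b: u32) -> u32 {
        let mut out = 0u32;
        for lane in 0..2 {
            let x = f16::from_bits((a >> (16 * lane)) as u16).to_f32();
            let y = f16::from_bits((b >> (16 * lane)) as u16).to_f32();
            if x == y {
                out |= (f16::from_f32(1.0).to_bits() as u32) << (16 * lane);
            }
        }
        out
    }
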
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index e640765..a65240c 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -336,7 +336,7 @@ test_ptx!(
[f16::from_f32(2.0), f16::from_f32(3.0)],
[f16::from_f32(2.0), f16::from_f32(5.0)]
);
-test_ptx!(st_f16x2, [0xc1690e6eu32, 0x13739444u32], [0xffffu32]);
+test_ptx!(set_f16x2, [0xc1690e6eu32, 0x13739444u32, 0x424834CC, 0x4248B4CC], [0xffffu32, 0x3C000000]);
test_ptx!(
dp4a,
[0xde3032f5u32, 0x2474fe15, 0xf51d8d6c],
diff --git a/ptx/src/test/spirv_run/set_f16x2.ll b/ptx/src/test/spirv_run/set_f16x2.ll
new file mode 100644
index 0000000..4a2c8ea
--- /dev/null
+++ b/ptx/src/test/spirv_run/set_f16x2.ll
@@ -0,0 +1,68 @@
+target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7"
+target triple = "amdgcn-amd-amdhsa"
+
+define protected amdgpu_kernel void @set_f16x2(ptr addrspace(4) byref(i64) %"41", ptr addrspace(4) byref(i64) %"42") #0 {
+"59":
+ %"11" = alloca i1, align 1, addrspace(5)
+ store i1 false, ptr addrspace(5) %"11", align 1
+ %"12" = alloca i1, align 1, addrspace(5)
+ store i1 false, ptr addrspace(5) %"12", align 1
+ %"4" = alloca i64, align 8, addrspace(5)
+ %"5" = alloca i64, align 8, addrspace(5)
+ %"6" = alloca i32, align 4, addrspace(5)
+ %"7" = alloca i32, align 4, addrspace(5)
+ %"8" = alloca i32, align 4, addrspace(5)
+ %"9" = alloca i32, align 4, addrspace(5)
+ %"10" = alloca <2 x half>, align 4, addrspace(5)
+ %"13" = load i64, ptr addrspace(4) %"41", align 8
+ store i64 %"13", ptr addrspace(5) %"4", align 8
+ %"14" = load i64, ptr addrspace(4) %"42", align 8
+ store i64 %"14", ptr addrspace(5) %"5", align 8
+ %"16" = load i64, ptr addrspace(5) %"4", align 8
+ %"44" = inttoptr i64 %"16" to ptr
+ %"43" = load i32, ptr %"44", align 4
+ store i32 %"43", ptr addrspace(5) %"6", align 4
+ %"18" = load i64, ptr addrspace(5) %"4", align 8
+ %"45" = inttoptr i64 %"18" to ptr
+ %"61" = getelementptr inbounds i8, ptr %"45", i64 4
+ %"46" = load i32, ptr %"61", align 4
+ store i32 %"46", ptr addrspace(5) %"7", align 4
+ %"20" = load i64, ptr addrspace(5) %"4", align 8
+ %"47" = inttoptr i64 %"20" to ptr
+ %"63" = getelementptr inbounds i8, ptr %"47", i64 8
+ %"48" = load i32, ptr %"63", align 4
+ store i32 %"48", ptr addrspace(5) %"8", align 4
+ %"22" = load i64, ptr addrspace(5) %"4", align 8
+ %"49" = inttoptr i64 %"22" to ptr
+ %"65" = getelementptr inbounds i8, ptr %"49", i64 12
+ %"50" = load i32, ptr %"65", align 4
+ store i32 %"50", ptr addrspace(5) %"9", align 4
+ %"24" = load i32, ptr addrspace(5) %"6", align 4
+ %"25" = load i32, ptr addrspace(5) %"7", align 4
+ %"52" = bitcast i32 %"24" to <2 x half>
+ %"53" = bitcast i32 %"25" to <2 x half>
+ %0 = fcmp ugt <2 x half> %"52", %"53"
+ %1 = sext <2 x i1> %0 to <2 x i16>
+ %"51" = bitcast <2 x i16> %1 to i32
+ store i32 %"51", ptr addrspace(5) %"6", align 4
+ %"27" = load i32, ptr addrspace(5) %"8", align 4
+ %"28" = load i32, ptr addrspace(5) %"9", align 4
+ %"55" = bitcast i32 %"27" to <2 x half>
+ %"56" = bitcast i32 %"28" to <2 x half>
+ %2 = fcmp oeq <2 x half> %"55", %"56"
+ %"54" = uitofp <2 x i1> %2 to <2 x half>
+ %"26" = bitcast <2 x half> %"54" to i32
+ store i32 %"26", ptr addrspace(5) %"8", align 4
+ %"29" = load i64, ptr addrspace(5) %"5", align 8
+ %"30" = load i32, ptr addrspace(5) %"6", align 4
+ %"57" = inttoptr i64 %"29" to ptr
+ store i32 %"30", ptr %"57", align 4
+ %"31" = load i64, ptr addrspace(5) %"5", align 8
+ %"32" = load i32, ptr addrspace(5) %"8", align 4
+ %"58" = inttoptr i64 %"31" to ptr
+ %"67" = getelementptr inbounds i8, ptr %"58", i64 4
+ store i32 %"32", ptr %"67", align 4
+ ret void
+}
+
+attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee,ieee" "denormal-fp-math-f32"="ieee,ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
diff --git a/ptx/src/test/spirv_run/st_f16x2.ptx b/ptx/src/test/spirv_run/set_f16x2.ptx
index b386f68..420dbbf 100644
--- a/ptx/src/test/spirv_run/st_f16x2.ptx
+++ b/ptx/src/test/spirv_run/set_f16x2.ptx
@@ -2,7 +2,7 @@
.target sm_53
.address_size 64
-.visible .entry st_f16x2(
+.visible .entry set_f16x2(
.param .u64 input,
.param .u64 output
)
@@ -11,6 +11,8 @@
.reg .u64 out_addr;
.reg .b32 temp0;
.reg .b32 temp1;
+ .reg .b32 temp2;
+ .reg .b32 temp3;
.reg .f16x2 sela;
ld.param.u64 in_addr, [input];
@@ -18,7 +20,11 @@
ld.u32 temp0, [in_addr];
ld.u32 temp1, [in_addr+4];
+ ld.u32 temp2, [in_addr+8];
+ ld.u32 temp3, [in_addr+12];
set.gtu.u32.f16x2 temp0, temp0, temp1;
+ set.eq.f16x2.f16x2 temp2, temp2, temp3;
st.b32 [out_addr], temp0;
+ st.b32 [out_addr+4], temp2;
ret;
}
diff --git a/ptx/src/test/spirv_run/st_f16x2.ll b/ptx/src/test/spirv_run/st_f16x2.ll
deleted file mode 100644
index 69fd33b..0000000
--- a/ptx/src/test/spirv_run/st_f16x2.ll
+++ /dev/null
@@ -1,43 +0,0 @@
-target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7"
-target triple = "amdgcn-amd-amdhsa"
-
-define protected amdgpu_kernel void @st_f16x2(ptr addrspace(4) byref(i64) %"24", ptr addrspace(4) byref(i64) %"25") #0 {
-"34":
- %"9" = alloca i1, align 1, addrspace(5)
- store i1 false, ptr addrspace(5) %"9", align 1
- %"10" = alloca i1, align 1, addrspace(5)
- store i1 false, ptr addrspace(5) %"10", align 1
- %"4" = alloca i64, align 8, addrspace(5)
- %"5" = alloca i64, align 8, addrspace(5)
- %"6" = alloca i32, align 4, addrspace(5)
- %"7" = alloca i32, align 4, addrspace(5)
- %"8" = alloca <2 x half>, align 4, addrspace(5)
- %"11" = load i64, ptr addrspace(4) %"24", align 8
- store i64 %"11", ptr addrspace(5) %"4", align 8
- %"12" = load i64, ptr addrspace(4) %"25", align 8
- store i64 %"12", ptr addrspace(5) %"5", align 8
- %"14" = load i64, ptr addrspace(5) %"4", align 8
- %"27" = inttoptr i64 %"14" to ptr
- %"26" = load i32, ptr %"27", align 4
- store i32 %"26", ptr addrspace(5) %"6", align 4
- %"16" = load i64, ptr addrspace(5) %"4", align 8
- %"28" = inttoptr i64 %"16" to ptr
- %"36" = getelementptr inbounds i8, ptr %"28", i64 4
- %"29" = load i32, ptr %"36", align 4
- store i32 %"29", ptr addrspace(5) %"7", align 4
- %"18" = load i32, ptr addrspace(5) %"6", align 4
- %"19" = load i32, ptr addrspace(5) %"7", align 4
- %"31" = bitcast i32 %"18" to <2 x half>
- %"32" = bitcast i32 %"19" to <2 x half>
- %0 = fcmp ugt <2 x half> %"31", %"32"
- %1 = sext <2 x i1> %0 to <2 x i16>
- %"30" = bitcast <2 x i16> %1 to i32
- store i32 %"30", ptr addrspace(5) %"6", align 4
- %"20" = load i64, ptr addrspace(5) %"5", align 8
- %"21" = load i32, ptr addrspace(5) %"6", align 4
- %"33" = inttoptr i64 %"20" to ptr
- store i32 %"21", ptr %"33", align 4
- ret void
-}
-
-attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee,ieee" "denormal-fp-math-f32"="ieee,ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" }
diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs
index 4a97b3b..59201e2 100644
--- a/zluda/src/impl/device.rs
+++ b/zluda/src/impl/device.rs
@@ -162,7 +162,9 @@ pub(crate) unsafe fn get_attribute(
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_DEFERRED_MAPPING_CUDA_ARRAY_SUPPORTED
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED
| CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH
- | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_UNIFIED_FUNCTION_POINTERS => {
+ | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_UNIFIED_FUNCTION_POINTERS
+ // Possibly true, used by llama.cpp
+ | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED => {
*pi = 0;
return Ok(());
}
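
The device.rs change reports the attribute as 0 through the normal success path rather than failing the query. The sketch below is an illustrative caller-side check, not part of the patch: programs such as llama.cpp probe this attribute through the driver API and choose an allocation strategy based on the answer, so a successful 0 keeps them on the non-VMM path. The simplified FFI declaration and the attribute's numeric value (102 in the CUDA headers) are assumptions made for this sketch.

    use std::os::raw::c_int;

    type CUresult = c_int; // 0 == CUDA_SUCCESS
    type CUdevice = c_int;
    const CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED: c_int = 102;

    extern "system" {
        // Simplified binding of the CUDA driver API entry point (sketch only).
        fn cuDeviceGetAttribute(pi: *mut c_int, attrib: c_int, dev: CUdevice) -> CUresult;
    }

    unsafe fn virtual_memory_management_supported(dev: CUdevice) -> bool {
        let mut value = 0;
        let status = cuDeviceGetAttribute(
            &mut value,
            CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED,
            dev,
        );
        // With this patch ZLUDA answers successfully with 0, so callers take
        // the fallback path instead of treating the query as an error.
        status == 0 && value != 0
    }
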
diff --git a/zluda_blas/src/cublas.rs b/zluda_blas/src/cublas.rs
index b0bf587..16cee4b 100644
--- a/zluda_blas/src/cublas.rs
+++ b/zluda_blas/src/cublas.rs
@@ -3926,7 +3926,28 @@ pub unsafe extern "system" fn cublasGemmBatchedEx(
computeType: cublasComputeType_t,
algo: cublasGemmAlgo_t,
) -> cublasStatus_t {
- crate::unsupported()
+ crate::gemm_batched_ex(
+ handle,
+ transa,
+ transb,
+ m,
+ n,
+ k,
+ alpha,
+ Aarray,
+ Atype,
+ lda,
+ Barray,
+ Btype,
+ ldb,
+ beta,
+ Carray,
+ Ctype,
+ ldc,
+ batchCount,
+ computeType,
+ algo,
+ )
}
#[no_mangle]
@@ -3955,7 +3976,31 @@ pub unsafe extern "system" fn cublasGemmStridedBatchedEx(
computeType: cublasComputeType_t,
algo: cublasGemmAlgo_t,
) -> cublasStatus_t {
- crate::unsupported()
+ crate::gemm_strided_batched_ex(
+ handle,
+ transa,
+ transb,
+ m,
+ n,
+ k,
+ alpha,
+ A,
+ Atype,
+ lda,
+ strideA,
+ B,
+ Btype,
+ ldb,
+ strideB,
+ beta,
+ C,
+ Ctype,
+ ldc,
+ strideC,
+ batchCount,
+ computeType,
+ algo,
+ )
}
#[no_mangle]
diff --git a/zluda_blas/src/lib.rs b/zluda_blas/src/lib.rs
index 0d50a99..c4b0bc5 100644
--- a/zluda_blas/src/lib.rs
+++ b/zluda_blas/src/lib.rs
@@ -916,3 +916,126 @@ unsafe fn dtrsm(
ldb,
))
}
+
+unsafe fn gemm_batched_ex(
+ handle: cublasHandle_t,
+ transa: cublasOperation_t,
+ transb: cublasOperation_t,
+ m: i32,
+ n: i32,
+ k: i32,
+ alpha: *const std::ffi::c_void,
+ a: *const *const std::ffi::c_void,
+ atype: cudaDataType_t,
+ lda: i32,
+ b: *const *const std::ffi::c_void,
+ btype: cudaDataType_t,
+ ldb: i32,
+ beta: *const std::ffi::c_void,
+ c: *const *mut std::ffi::c_void,
+ ctype: cudaDataType_t,
+ ldc: i32,
+ batch_count: i32,
+ compute_type: cublasComputeType_t,
+ algo: cublasGemmAlgo_t,
+) -> cublasStatus_t {
+ let transa = op_from_cuda(transa);
+ let transb = op_from_cuda(transb);
+ let atype = type_from_cuda(atype);
+ let btype = type_from_cuda(btype);
+ let ctype = type_from_cuda(ctype);
+ let compute_type = to_compute_type(compute_type);
+ let algo = to_algo(algo);
+ to_cuda(rocblas_gemm_batched_ex(
+ handle.cast(),
+ transa,
+ transb,
+ m,
+ n,
+ k,
+ alpha,
+ a.cast(),
+ atype,
+ lda,
+ b.cast(),
+ btype,
+ ldb,
+ beta,
+ c.cast(),
+ ctype,
+ ldc,
+ c.cast_mut().cast(),
+ ctype,
+ ldc,
+ batch_count,
+ compute_type,
+ algo,
+ 0,
+ rocblas_gemm_flags::rocblas_gemm_flags_none.0,
+ ))
+}
+
+unsafe fn gemm_strided_batched_ex(
+ handle: cublasHandle_t,
+ transa: cublasOperation_t,
+ transb: cublasOperation_t,
+ m: ::std::os::raw::c_int,
+ n: ::std::os::raw::c_int,
+ k: ::std::os::raw::c_int,
+ alpha: *const ::std::os::raw::c_void,
+ a: *const ::std::os::raw::c_void,
+ atype: cudaDataType,
+ lda: ::std::os::raw::c_int,
+ stride_a: ::std::os::raw::c_longlong,
+ b: *const ::std::os::raw::c_void,
+ btype: cudaDataType,
+ ldb: ::std::os::raw::c_int,
+ stride_b: ::std::os::raw::c_longlong,
+ beta: *const ::std::os::raw::c_void,
+ c: *mut ::std::os::raw::c_void,
+ ctype: cudaDataType,
+ ldc: ::std::os::raw::c_int,
+ stride_c: ::std::os::raw::c_longlong,
+ batch_count: ::std::os::raw::c_int,
+ compute_type: cublasComputeType_t,
+ algo: cublasGemmAlgo_t,
+) -> cublasStatus_t {
+ let transa = op_from_cuda(transa);
+ let transb = op_from_cuda(transb);
+ let atype = type_from_cuda(atype);
+ let btype = type_from_cuda(btype);
+ let ctype = type_from_cuda(ctype);
+ let compute_type = to_compute_type(compute_type);
+ let algo = to_algo(algo);
+ to_cuda(rocblas_gemm_strided_batched_ex(
+ handle.cast(),
+ transa,
+ transb,
+ m,
+ n,
+ k,
+ alpha,
+ a,
+ atype,
+ lda,
+ stride_a,
+ b,
+ btype,
+ ldb,
+ stride_b,
+ beta,
+ c,
+ ctype,
+ ldc,
+ stride_c,
+ c,
+ ctype,
+ ldc,
+ stride_c,
+ batch_count,
+ compute_type,
+ algo,
+ 0,
+ rocblas_gemm_flags::rocblas_gemm_flags_none.0,
+ ))
+}
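
A note on the rocBLAS calls above: the rocblas_gemm_*_ex family takes separate C and D matrices and computes D = alpha*op(A)*op(B) + beta*C, whereas the cuBLAS GemmEx entry points update C in place. That is why both wrappers pass the C pointer, type, leading dimension, and stride a second time in the D slots, followed by a solution index of 0 and rocblas_gemm_flags_none. The sketch below is a naive reference model of what one forwarded strided-batched call computes, under simplifying assumptions (f32 data, column-major layout, no transposes); it is illustrative only, not the wrapper itself.

    // Reference semantics: for each batch i, D_i = alpha*A_i*B_i + beta*C_i,
    // with C serving as both the beta input and the output (in-place update).
    fn gemm_strided_batched_ref(
        alpha: f32, a: &[f32], lda: usize, stride_a: usize,
        b: &[f32], ldb: usize, stride_b: usize,
        beta: f32, c: &mut [f32], ldc: usize, stride_c: usize,
        m: usize, n: usize, k: usize, batch_count: usize,
    ) {
        for batch in 0..batch_count {
            for col in 0..n {
                for row in 0..m {
                    let mut acc = 0.0f32;
                    for p in 0..k {
                        // column-major element (row, p) of A and (p, col) of B
                        acc += a[batch * stride_a + p * lda + row]
                            * b[batch * stride_b + col * ldb + p];
                    }
                    let idx = batch * stride_c + col * ldc + row;
                    c[idx] = alpha * acc + beta * c[idx];
                }
            }
        }
    }
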