diff options
author | Andrzej Janik <[email protected]> | 2024-02-15 16:54:04 +0000 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2024-02-15 16:54:04 +0000 |
commit | 13bf965784bcee152e9591e1b35bc73c60eda723 (patch) | |
tree | 5fd552a27f0b6878f056af36c6259ed2f457d1ba | |
parent | 8fef0e4fe7e6ee6d1f2b556694faffee0c6b9648 (diff) | |
download | ZLUDA-13bf965784bcee152e9591e1b35bc73c60eda723.tar.gz ZLUDA-13bf965784bcee152e9591e1b35bc73c60eda723.zip |
Add missing bits and pieces
-rw-r--r-- | ptx/src/emit.rs | 25 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 2 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/set_f16x2.ll | 68 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/set_f16x2.ptx (renamed from ptx/src/test/spirv_run/st_f16x2.ptx) | 8 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/st_f16x2.ll | 43 | ||||
-rw-r--r-- | zluda/src/impl/device.rs | 4 | ||||
-rw-r--r-- | zluda_blas/src/cublas.rs | 49 | ||||
-rw-r--r-- | zluda_blas/src/lib.rs | 123 |
8 files changed, 267 insertions, 55 deletions
diff --git a/ptx/src/emit.rs b/ptx/src/emit.rs index ac054d4..94cc973 100644 --- a/ptx/src/emit.rs +++ b/ptx/src/emit.rs @@ -1176,15 +1176,26 @@ fn emit_inst_set( ) -> Result<(), TranslateError> { let builder = ctx.builder.get(); let temp_result = emit_inst_setp_float(ctx, details.cmp_op, None, arg.src1, arg.src2)?; - if details.src_type != ast::ScalarType::F16x2 || details.dst_type == ast::ScalarType::F16x2 { + if details.src_type != ast::ScalarType::F16x2 { + return Err(TranslateError::todo()); + } + if details.dst_type.is_integer() && details.dst_type.size_of() == mem::size_of::<u32>() as u8 { + let b16vec2_type = get_llvm_type(ctx, &ast::Type::Vector(ast::ScalarType::B16, 2))?; + let b16vec2_result = + unsafe { LLVMBuildSExt(builder, temp_result, b16vec2_type, LLVM_UNNAMED) }; + + let u32_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::U32))?; + ctx.names.register_result(arg.dst, |dst_name| unsafe { + LLVMBuildBitCast(builder, b16vec2_result, u32_type, dst_name) + }); + } else if matches!(details.dst_type, ast::ScalarType::F16x2) { + let f16x2_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::F16x2))?; + ctx.names.register_result(arg.dst, |dst_name| unsafe { + LLVMBuildUIToFP(builder, temp_result, f16x2_type, dst_name) + }); + } else { return Err(TranslateError::todo()); } - let b16vec2_type = get_llvm_type(ctx, &ast::Type::Vector(ast::ScalarType::B16, 2))?; - let b16vec2_result = unsafe { LLVMBuildSExt(builder, temp_result, b16vec2_type, LLVM_UNNAMED) }; - let u32_type = get_llvm_type(ctx, &ast::Type::Scalar(ast::ScalarType::U32))?; - ctx.names.register_result(arg.dst, |dst_name| unsafe { - LLVMBuildBitCast(builder, b16vec2_result, u32_type, dst_name) - }); Ok(()) } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index e640765..a65240c 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -336,7 +336,7 @@ test_ptx!( [f16::from_f32(2.0), f16::from_f32(3.0)],
[f16::from_f32(2.0), f16::from_f32(5.0)]
);
-test_ptx!(st_f16x2, [0xc1690e6eu32, 0x13739444u32], [0xffffu32]);
+test_ptx!(set_f16x2, [0xc1690e6eu32, 0x13739444u32, 0x424834CC, 0x4248B4CC], [0xffffu32, 0x3C000000]);
test_ptx!(
dp4a,
[0xde3032f5u32, 0x2474fe15, 0xf51d8d6c],
diff --git a/ptx/src/test/spirv_run/set_f16x2.ll b/ptx/src/test/spirv_run/set_f16x2.ll new file mode 100644 index 0000000..4a2c8ea --- /dev/null +++ b/ptx/src/test/spirv_run/set_f16x2.ll @@ -0,0 +1,68 @@ +target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7" +target triple = "amdgcn-amd-amdhsa" + +define protected amdgpu_kernel void @set_f16x2(ptr addrspace(4) byref(i64) %"41", ptr addrspace(4) byref(i64) %"42") #0 { +"59": + %"11" = alloca i1, align 1, addrspace(5) + store i1 false, ptr addrspace(5) %"11", align 1 + %"12" = alloca i1, align 1, addrspace(5) + store i1 false, ptr addrspace(5) %"12", align 1 + %"4" = alloca i64, align 8, addrspace(5) + %"5" = alloca i64, align 8, addrspace(5) + %"6" = alloca i32, align 4, addrspace(5) + %"7" = alloca i32, align 4, addrspace(5) + %"8" = alloca i32, align 4, addrspace(5) + %"9" = alloca i32, align 4, addrspace(5) + %"10" = alloca <2 x half>, align 4, addrspace(5) + %"13" = load i64, ptr addrspace(4) %"41", align 8 + store i64 %"13", ptr addrspace(5) %"4", align 8 + %"14" = load i64, ptr addrspace(4) %"42", align 8 + store i64 %"14", ptr addrspace(5) %"5", align 8 + %"16" = load i64, ptr addrspace(5) %"4", align 8 + %"44" = inttoptr i64 %"16" to ptr + %"43" = load i32, ptr %"44", align 4 + store i32 %"43", ptr addrspace(5) %"6", align 4 + %"18" = load i64, ptr addrspace(5) %"4", align 8 + %"45" = inttoptr i64 %"18" to ptr + %"61" = getelementptr inbounds i8, ptr %"45", i64 4 + %"46" = load i32, ptr %"61", align 4 + store i32 %"46", ptr addrspace(5) %"7", align 4 + %"20" = load i64, ptr addrspace(5) %"4", align 8 + %"47" = inttoptr i64 %"20" to ptr + %"63" = getelementptr inbounds i8, ptr %"47", i64 8 + %"48" = load i32, ptr %"63", align 4 + store i32 %"48", ptr addrspace(5) %"8", align 4 + %"22" = load i64, ptr addrspace(5) %"4", align 8 + %"49" = inttoptr i64 %"22" to ptr + %"65" = getelementptr inbounds i8, ptr %"49", i64 12 + %"50" = load i32, ptr %"65", align 4 + store i32 %"50", ptr addrspace(5) %"9", align 4 + %"24" = load i32, ptr addrspace(5) %"6", align 4 + %"25" = load i32, ptr addrspace(5) %"7", align 4 + %"52" = bitcast i32 %"24" to <2 x half> + %"53" = bitcast i32 %"25" to <2 x half> + %0 = fcmp ugt <2 x half> %"52", %"53" + %1 = sext <2 x i1> %0 to <2 x i16> + %"51" = bitcast <2 x i16> %1 to i32 + store i32 %"51", ptr addrspace(5) %"6", align 4 + %"27" = load i32, ptr addrspace(5) %"8", align 4 + %"28" = load i32, ptr addrspace(5) %"9", align 4 + %"55" = bitcast i32 %"27" to <2 x half> + %"56" = bitcast i32 %"28" to <2 x half> + %2 = fcmp oeq <2 x half> %"55", %"56" + %"54" = uitofp <2 x i1> %2 to <2 x half> + %"26" = bitcast <2 x half> %"54" to i32 + store i32 %"26", ptr addrspace(5) %"8", align 4 + %"29" = load i64, ptr addrspace(5) %"5", align 8 + %"30" = load i32, ptr addrspace(5) %"6", align 4 + %"57" = inttoptr i64 %"29" to ptr + store i32 %"30", ptr %"57", align 4 + %"31" = load i64, ptr addrspace(5) %"5", align 8 + %"32" = load i32, ptr addrspace(5) %"8", align 4 + %"58" = inttoptr i64 %"31" to ptr + %"67" = getelementptr inbounds i8, ptr %"58", i64 4 + store i32 %"32", ptr %"67", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee,ieee" "denormal-fp-math-f32"="ieee,ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/spirv_run/st_f16x2.ptx b/ptx/src/test/spirv_run/set_f16x2.ptx index b386f68..420dbbf 100644 --- a/ptx/src/test/spirv_run/st_f16x2.ptx +++ b/ptx/src/test/spirv_run/set_f16x2.ptx @@ -2,7 +2,7 @@ .target sm_53 .address_size 64 -.visible .entry st_f16x2( +.visible .entry set_f16x2( .param .u64 input, .param .u64 output ) @@ -11,6 +11,8 @@ .reg .u64 out_addr; .reg .b32 temp0; .reg .b32 temp1; + .reg .b32 temp2; + .reg .b32 temp3; .reg .f16x2 sela; ld.param.u64 in_addr, [input]; @@ -18,7 +20,11 @@ ld.u32 temp0, [in_addr]; ld.u32 temp1, [in_addr+4]; + ld.u32 temp2, [in_addr+8]; + ld.u32 temp3, [in_addr+12]; set.gtu.u32.f16x2 temp0, temp0, temp1; + set.eq.f16x2.f16x2 temp2, temp2, temp3; st.b32 [out_addr], temp0; + st.b32 [out_addr+4], temp2; ret; } diff --git a/ptx/src/test/spirv_run/st_f16x2.ll b/ptx/src/test/spirv_run/st_f16x2.ll deleted file mode 100644 index 69fd33b..0000000 --- a/ptx/src/test/spirv_run/st_f16x2.ll +++ /dev/null @@ -1,43 +0,0 @@ -target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7" -target triple = "amdgcn-amd-amdhsa" - -define protected amdgpu_kernel void @st_f16x2(ptr addrspace(4) byref(i64) %"24", ptr addrspace(4) byref(i64) %"25") #0 { -"34": - %"9" = alloca i1, align 1, addrspace(5) - store i1 false, ptr addrspace(5) %"9", align 1 - %"10" = alloca i1, align 1, addrspace(5) - store i1 false, ptr addrspace(5) %"10", align 1 - %"4" = alloca i64, align 8, addrspace(5) - %"5" = alloca i64, align 8, addrspace(5) - %"6" = alloca i32, align 4, addrspace(5) - %"7" = alloca i32, align 4, addrspace(5) - %"8" = alloca <2 x half>, align 4, addrspace(5) - %"11" = load i64, ptr addrspace(4) %"24", align 8 - store i64 %"11", ptr addrspace(5) %"4", align 8 - %"12" = load i64, ptr addrspace(4) %"25", align 8 - store i64 %"12", ptr addrspace(5) %"5", align 8 - %"14" = load i64, ptr addrspace(5) %"4", align 8 - %"27" = inttoptr i64 %"14" to ptr - %"26" = load i32, ptr %"27", align 4 - store i32 %"26", ptr addrspace(5) %"6", align 4 - %"16" = load i64, ptr addrspace(5) %"4", align 8 - %"28" = inttoptr i64 %"16" to ptr - %"36" = getelementptr inbounds i8, ptr %"28", i64 4 - %"29" = load i32, ptr %"36", align 4 - store i32 %"29", ptr addrspace(5) %"7", align 4 - %"18" = load i32, ptr addrspace(5) %"6", align 4 - %"19" = load i32, ptr addrspace(5) %"7", align 4 - %"31" = bitcast i32 %"18" to <2 x half> - %"32" = bitcast i32 %"19" to <2 x half> - %0 = fcmp ugt <2 x half> %"31", %"32" - %1 = sext <2 x i1> %0 to <2 x i16> - %"30" = bitcast <2 x i16> %1 to i32 - store i32 %"30", ptr addrspace(5) %"6", align 4 - %"20" = load i64, ptr addrspace(5) %"5", align 8 - %"21" = load i32, ptr addrspace(5) %"6", align 4 - %"33" = inttoptr i64 %"20" to ptr - store i32 %"21", ptr %"33", align 4 - ret void -} - -attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee,ieee" "denormal-fp-math-f32"="ieee,ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs index 4a97b3b..59201e2 100644 --- a/zluda/src/impl/device.rs +++ b/zluda/src/impl/device.rs @@ -162,7 +162,9 @@ pub(crate) unsafe fn get_attribute( | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_DEFERRED_MAPPING_CUDA_ARRAY_SUPPORTED | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_IPC_EVENT_SUPPORTED | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH - | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_UNIFIED_FUNCTION_POINTERS => { + | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_UNIFIED_FUNCTION_POINTERS + // Possibly true, used by llama.cpp + | CUdevice_attribute::CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED => { *pi = 0; return Ok(()); } diff --git a/zluda_blas/src/cublas.rs b/zluda_blas/src/cublas.rs index b0bf587..16cee4b 100644 --- a/zluda_blas/src/cublas.rs +++ b/zluda_blas/src/cublas.rs @@ -3926,7 +3926,28 @@ pub unsafe extern "system" fn cublasGemmBatchedEx( computeType: cublasComputeType_t, algo: cublasGemmAlgo_t, ) -> cublasStatus_t { - crate::unsupported() + crate::gemm_batched_ex( + handle, + transa, + transb, + m, + n, + k, + alpha, + Aarray, + Atype, + lda, + Barray, + Btype, + ldb, + beta, + Carray, + Ctype, + ldc, + batchCount, + computeType, + algo, + ) } #[no_mangle] @@ -3955,7 +3976,31 @@ pub unsafe extern "system" fn cublasGemmStridedBatchedEx( computeType: cublasComputeType_t, algo: cublasGemmAlgo_t, ) -> cublasStatus_t { - crate::unsupported() + crate::gemm_strided_batched_ex( + handle, + transa, + transb, + m, + n, + k, + alpha, + A, + Atype, + lda, + strideA, + B, + Btype, + ldb, + strideB, + beta, + C, + Ctype, + ldc, + strideC, + batchCount, + computeType, + algo, + ) } #[no_mangle] diff --git a/zluda_blas/src/lib.rs b/zluda_blas/src/lib.rs index 0d50a99..c4b0bc5 100644 --- a/zluda_blas/src/lib.rs +++ b/zluda_blas/src/lib.rs @@ -916,3 +916,126 @@ unsafe fn dtrsm( ldb, )) } + +unsafe fn gemm_batched_ex( + handle: cublasHandle_t, + transa: cublasOperation_t, + transb: cublasOperation_t, + m: i32, + n: i32, + k: i32, + alpha: *const std::ffi::c_void, + a: *const *const std::ffi::c_void, + atype: cudaDataType_t, + lda: i32, + b: *const *const std::ffi::c_void, + btype: cudaDataType_t, + ldb: i32, + beta: *const std::ffi::c_void, + c: *const *mut std::ffi::c_void, + ctype: cudaDataType_t, + ldc: i32, + batch_count: i32, + compute_type: cublasComputeType_t, + algo: cublasGemmAlgo_t, +) -> cublasStatus_t { + let transa = op_from_cuda(transa); + let transb = op_from_cuda(transb); + let atype = type_from_cuda(atype); + let btype = type_from_cuda(btype); + let ctype = type_from_cuda(ctype); + let compute_type = to_compute_type(compute_type); + let algo = to_algo(algo); + to_cuda(rocblas_gemm_batched_ex( + handle.cast(), + transa, + transb, + m, + n, + k, + alpha, + a.cast(), + atype, + lda, + b.cast(), + btype, + ldb, + beta, + c.cast(), + ctype, + ldc, + c.cast_mut().cast(), + ctype, + ldc, + batch_count, + compute_type, + algo, + 0, + rocblas_gemm_flags::rocblas_gemm_flags_none.0, + )) +} + +unsafe fn gemm_strided_batched_ex( + handle: cublasHandle_t, + transa: cublasOperation_t, + transb: cublasOperation_t, + m: ::std::os::raw::c_int, + n: ::std::os::raw::c_int, + k: ::std::os::raw::c_int, + alpha: *const ::std::os::raw::c_void, + a: *const ::std::os::raw::c_void, + atype: cudaDataType, + lda: ::std::os::raw::c_int, + stride_a: ::std::os::raw::c_longlong, + b: *const ::std::os::raw::c_void, + btype: cudaDataType, + ldb: ::std::os::raw::c_int, + stride_b: ::std::os::raw::c_longlong, + beta: *const ::std::os::raw::c_void, + c: *mut ::std::os::raw::c_void, + ctype: cudaDataType, + ldc: ::std::os::raw::c_int, + stride_c: ::std::os::raw::c_longlong, + batch_count: ::std::os::raw::c_int, + compute_type: cublasComputeType_t, + algo: cublasGemmAlgo_t, +) -> cublasStatus_t { + let transa = op_from_cuda(transa); + let transb = op_from_cuda(transb); + let atype = type_from_cuda(atype); + let btype = type_from_cuda(btype); + let ctype = type_from_cuda(ctype); + let compute_type = to_compute_type(compute_type); + let algo = to_algo(algo); + to_cuda(rocblas_gemm_strided_batched_ex( + handle.cast(), + transa, + transb, + m, + n, + k, + alpha, + a, + atype, + lda, + stride_a, + b, + btype, + ldb, + stride_b, + beta, + c, + ctype, + ldc, + stride_c, + c, + ctype, + ldc, + stride_c, + batch_count, + compute_type, + algo, + 0, + rocblas_gemm_flags::rocblas_gemm_flags_none.0, + )) +} |