From 6ef19d65010164a7cc8408663eb189b64f44d26a Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 17 Sep 2021 18:31:12 +0000 Subject: Add early support for more sregs --- ptx/lib/zluda_ptx_impl.bc | Bin 31224 -> 31940 bytes ptx/lib/zluda_ptx_impl.cl | 13 +++++++ ptx/src/test/spirv_run/lanemask_lt.ptx | 25 +++++++++++++ ptx/src/test/spirv_run/lanemask_lt.spvtxt | 45 ++++++++++++++++++++++++ ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/translate.rs | 56 ++++++++++-------------------- 6 files changed, 103 insertions(+), 37 deletions(-) create mode 100644 ptx/src/test/spirv_run/lanemask_lt.ptx create mode 100644 ptx/src/test/spirv_run/lanemask_lt.spvtxt diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc index 175f4df..7aa12c8 100644 Binary files a/ptx/lib/zluda_ptx_impl.bc and b/ptx/lib/zluda_ptx_impl.bc differ diff --git a/ptx/lib/zluda_ptx_impl.cl b/ptx/lib/zluda_ptx_impl.cl index aca9327..d439795 100644 --- a/ptx/lib/zluda_ptx_impl.cl +++ b/ptx/lib/zluda_ptx_impl.cl @@ -296,6 +296,19 @@ atomic_add(atom_acq_rel_sys_shared_add_f64, memory_order_acq_rel, memory_order_a uint FUNC(activemask)() { return (uint)__builtin_amdgcn_uicmp(1, 0, 33); } + + uint FUNC(sreg_clock)() { + return (uint)__builtin_amdgcn_s_memtime(); + } + + // Taken from __ballot definition in hipamd/include/hip/amd_detail/amd_device_functions.h + // They return active threads, which I think is incorrect + extern __attribute__((const)) uint __ockl_lane_u32(); + uint FUNC(sreg_lanemask_lt)() { + uint lane_idx = __ockl_lane_u32(); + ulong mask = (1UL << lane_idx) - 1UL; + return (uint)mask; + } #endif void FUNC(__assertfail)( diff --git a/ptx/src/test/spirv_run/lanemask_lt.ptx b/ptx/src/test/spirv_run/lanemask_lt.ptx new file mode 100644 index 0000000..02b13ce --- /dev/null +++ b/ptx/src/test/spirv_run/lanemask_lt.ptx @@ -0,0 +1,25 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry lanemask_lt( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b32 temp; + .reg .b32 temp2; + .reg .b32 less_lane; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u32 temp, [in_addr]; + add.u32 temp2, temp, 1; + mov.u32 less_lane, %lanemask_lt; + add.u32 temp2, temp2, less_lane; + st.u32 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/lanemask_lt.spvtxt b/ptx/src/test/spirv_run/lanemask_lt.spvtxt new file mode 100644 index 0000000..0753c95 --- /dev/null +++ b/ptx/src/test/spirv_run/lanemask_lt.spvtxt @@ -0,0 +1,45 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %18 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "activemask" + OpExecutionMode %1 ContractionOff + OpDecorate %15 LinkageAttributes "__zluda_ptx_impl__activemask" Import + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %21 = OpTypeFunction %uint + %ulong = OpTypeInt 64 0 + %23 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %15 = OpFunction %uint None %21 + OpFunctionEnd + %1 = OpFunction %void None %23 + %6 = OpFunctionParameter %ulong + %7 = OpFunctionParameter %ulong + %14 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_uint Function + OpStore %2 %6 + OpStore %3 %7 + %8 = OpLoad %ulong %3 Aligned 8 + OpStore %4 %8 + %9 = OpFunctionCall %uint %15 + OpStore %5 %9 + %10 = OpLoad %ulong %4 + %11 = OpLoad %uint %5 + %12 = OpConvertUToPtr %_ptr_Generic_uint %10 + %13 = OpCopyObject %uint %11 + OpStore %12 %13 Aligned 4 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 0dcd0bb..d68cf17 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -210,6 +210,7 @@ test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]); test_ptx!(activemask, [0u32], [1u32]); test_ptx!(membar, [152731u32], [152731u32]); test_ptx!(func_ptr, [152731u64], [152732u64]); +test_ptx!(lanemask_lt, [187235u32], [187236u32]); struct DisplayError { err: T, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 39bd07e..15dcdd1 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -9,7 +9,9 @@ use rspirv::binary::{Assemble, Disassemble}; static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.bc"); -static ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__"; +const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__"; +const ZLUDA_PTX_PREFIX_SREG_CLOCK: &'static str = "__zluda_ptx_impl__sreg_clock"; +const ZLUDA_PTX_PREFIX_SREG_LANEMASK_LT: &'static str = "__zluda_ptx_impl__sreg_lanemask_lt"; quick_error! { #[derive(Debug)] @@ -1015,25 +1017,6 @@ fn compute_denorm_information<'input>( .collect() } -fn emit_builtins( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - id_defs: &GlobalStringIdResolver, -) { - for (reg, id) in id_defs.special_registers.builtins() { - let result_type = map.get_or_add( - builder, - SpirvType::pointer_to(reg.get_type(), spirv::StorageClass::Input), - ); - builder.variable(result_type, Some(id), spirv::StorageClass::Input, None); - builder.decorate( - id, - spirv::Decoration::BuiltIn, - [dr::Operand::BuiltIn(reg.get_builtin())].iter().cloned(), - ); - } -} - fn emit_function_header<'a>( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -4815,6 +4798,8 @@ enum PtxSpecialRegister { Ctaid64, Nctaid, Nctaid64, + Clock, + LanemaskLt, } impl PtxSpecialRegister { @@ -4824,6 +4809,8 @@ impl PtxSpecialRegister { "%ntid" => Some(Self::Ntid), "%ctaid" => Some(Self::Ctaid), "%nctaid" => Some(Self::Nctaid), + "%clock" => Some(Self::Clock), + "%lanemask_lt" => Some(Self::LanemaskLt), _ => None, } } @@ -4838,6 +4825,8 @@ impl PtxSpecialRegister { PtxSpecialRegister::Ctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3), PtxSpecialRegister::Nctaid => ast::Type::Vector(ast::ScalarType::U32, 4), PtxSpecialRegister::Nctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3), + PtxSpecialRegister::Clock => ast::Type::Scalar(ast::ScalarType::U32), + PtxSpecialRegister::LanemaskLt => ast::Type::Scalar(ast::ScalarType::U32), } } @@ -4846,7 +4835,9 @@ impl PtxSpecialRegister { PtxSpecialRegister::Tid | PtxSpecialRegister::Ntid | PtxSpecialRegister::Ctaid - | PtxSpecialRegister::Nctaid => ast::ScalarType::U32, + | PtxSpecialRegister::Nctaid + | PtxSpecialRegister::Clock + | PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32, PtxSpecialRegister::Tid64 | PtxSpecialRegister::Ntid64 | PtxSpecialRegister::Ctaid64 @@ -4854,21 +4845,6 @@ impl PtxSpecialRegister { } } - fn get_builtin(self) -> spirv::BuiltIn { - match self { - PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => { - spirv::BuiltIn::LocalInvocationId - } - PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => { - spirv::BuiltIn::EnqueuedWorkgroupSize - } - PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => spirv::BuiltIn::WorkgroupId, - PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => { - spirv::BuiltIn::NumWorkgroups - } - } - } - fn get_opencl_fn_type(self) -> (&'static str, ast::ScalarType) { match self { PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => { @@ -4883,6 +4859,10 @@ impl PtxSpecialRegister { PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => { ("_Z14get_num_groupsj", ast::ScalarType::U64) } + PtxSpecialRegister::Clock => (ZLUDA_PTX_PREFIX_SREG_CLOCK, ast::ScalarType::U32), + PtxSpecialRegister::LanemaskLt => { + (ZLUDA_PTX_PREFIX_SREG_LANEMASK_LT, ast::ScalarType::U32) + } } } @@ -4899,7 +4879,9 @@ impl PtxSpecialRegister { PtxSpecialRegister::Tid64 | PtxSpecialRegister::Ntid64 | PtxSpecialRegister::Ctaid64 - | PtxSpecialRegister::Nctaid64 => None, + | PtxSpecialRegister::Nctaid64 + | PtxSpecialRegister::Clock => None, + PtxSpecialRegister::LanemaskLt => None, } } } -- cgit v1.2.3