From ff449289eb6fe4e429be25ef829ff10144146056 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sat, 1 Aug 2020 01:09:57 +0200 Subject: Implement shift left --- ptx/src/ast.rs | 9 +++++++-- ptx/src/ptx.lalrpop | 8 +++++--- ptx/src/test/spirv_run/mod.rs | 1 + ptx/src/test/spirv_run/shl.ptx | 22 +++++++++++++++++++++ ptx/src/test/spirv_run/shl.spvtxt | 41 +++++++++++++++++++++++++++++++++++++++ ptx/src/translate.rs | 36 ++++++++++++++++++++++++++++++---- 6 files changed, 108 insertions(+), 9 deletions(-) create mode 100644 ptx/src/test/spirv_run/shl.ptx create mode 100644 ptx/src/test/spirv_run/shl.spvtxt diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 158ec8d..ec562fe 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -179,7 +179,7 @@ pub enum Instruction { Not(NotType, Arg2

), Bra(BraData, Arg1

), Cvt(CvtData, Arg2

), - Shl(ShlData, Arg3

), + Shl(ShlType, Arg3

), St(StData, Arg2St

), Ret(RetData), } @@ -400,7 +400,12 @@ pub struct BraData { pub struct CvtData {} -pub struct ShlData {} +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum ShlType { + B16, + B32, + B64, +} pub struct StData { pub qualifier: LdStQualifier, diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index d525fbe..bd5678e 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -606,11 +606,13 @@ CvtType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl InstShl: ast::Instruction> = { - "shl" ShlType => ast::Instruction::Shl(ast::ShlData{}, a) + "shl" => ast::Instruction::Shl(t, a) }; -ShlType = { - ".b16", ".b32", ".b64" +ShlType: ast::ShlType = { + ".b16" => ast::ShlType::B16, + ".b32" => ast::ShlType::B32, + ".b64" => ast::ShlType::B64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index b4414d9..14a48be 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -47,6 +47,7 @@ test_ptx!(add, [1u64], [2u64]); test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]); test_ptx!(bra, [10u64], [11u64]); test_ptx!(not, [0u64], [u64::max_value()]); +test_ptx!(shl, [11u64], [44u64]); struct DisplayError { err: T, diff --git a/ptx/src/test/spirv_run/shl.ptx b/ptx/src/test/spirv_run/shl.ptx new file mode 100644 index 0000000..e888741 --- /dev/null +++ b/ptx/src/test/spirv_run/shl.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry shl( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + .reg .u64 temp2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u64 temp, [in_addr]; + shl.b64 temp2, temp, 2; + st.u64 [out_addr], temp2; + ret; +} diff --git a/ptx/src/test/spirv_run/shl.spvtxt b/ptx/src/test/spirv_run/shl.spvtxt new file mode 100644 index 0000000..131bd9e --- /dev/null +++ b/ptx/src/test/spirv_run/shl.spvtxt @@ -0,0 +1,41 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %1 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %5 "shl" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %4 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_0 = OpTypeInt 64 0 + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 + %5 = OpFunction %void None %4 + %6 = OpFunctionParameter %ulong + %7 = OpFunctionParameter %ulong + %21 = OpLabel + %8 = OpVariable %_ptr_Function_ulong Function + %9 = OpVariable %_ptr_Function_ulong Function + %10 = OpVariable %_ptr_Function_ulong Function + %11 = OpVariable %_ptr_Function_ulong Function + OpStore %8 %6 + OpStore %9 %7 + %13 = OpLoad %ulong %8 + %19 = OpConvertUToPtr %_ptr_Generic_ulong %13 + %12 = OpLoad %ulong %19 + OpStore %10 %12 + %15 = OpLoad %ulong_0 %10 + %14 = OpShiftLeftLogical %ulong_0 %15 %uint_2 + OpStore %11 %14 + %16 = OpLoad %ulong %9 + %17 = OpLoad %ulong %11 + %20 = OpConvertUToPtr %_ptr_Generic_ulong %16 + OpStore %20 %17 + OpReturn + OpFunctionEnd + \ No newline at end of file diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a6e627f..7091fc9 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -668,6 +668,10 @@ fn emit_function_body_ops( _ => builder.not(result_type, result_id, operand), }?; } + ast::Instruction::Shl(t, a) => { + let result_type = map.get_or_add(builder, SpirvType::from(t.to_type())); + builder.shift_left_logical(result_type, Some(a.dst), a.src1, a.src2)?; + } _ => todo!(), }, Statement::LoadVar(arg, typ) => { @@ -1121,11 +1125,11 @@ impl ast::Instruction { } ast::Instruction::Mul(d, a) => { let inst_type = d.get_type(); - ast::Instruction::Mul(d, a.map(visitor, Some(inst_type))) + ast::Instruction::Mul(d, a.map_non_shift(visitor, Some(inst_type))) } ast::Instruction::Add(d, a) => { let inst_type = d.get_type(); - ast::Instruction::Add(d, a.map(visitor, Some(inst_type))) + ast::Instruction::Add(d, a.map_non_shift(visitor, Some(inst_type))) } ast::Instruction::Setp(d, a) => { let inst_type = d.typ; @@ -1139,7 +1143,9 @@ impl ast::Instruction { ast::Instruction::Not(t, a.map(visitor, Some(t.to_type()))) } ast::Instruction::Cvt(_, _) => todo!(), - ast::Instruction::Shl(_, _) => todo!(), + ast::Instruction::Shl(t, a) => { + ast::Instruction::Shl(t, a.map_shift(visitor, Some(t.to_type()))) + } ast::Instruction::St(d, a) => { let inst_type = d.typ; ast::Instruction::St(d, a.map(visitor, Some(ast::Type::Scalar(inst_type)))) @@ -1365,7 +1371,7 @@ impl ast::Arg2Mov { } impl ast::Arg3 { - fn map>( + fn map_non_shift>( self, visitor: &mut V, t: Option, @@ -1376,6 +1382,18 @@ impl ast::Arg3 { src2: visitor.src_operand(self.src2, t), } } + + fn map_shift>( + self, + visitor: &mut V, + t: Option, + ) -> ast::Arg3 { + ast::Arg3 { + dst: visitor.dst_variable(self.dst, t), + src1: visitor.src_operand(self.src1, t), + src2: visitor.src_operand(self.src2, Some(ast::Type::Scalar(ast::ScalarType::U32))), + } + } } impl ast::Arg4 { @@ -1533,6 +1551,16 @@ impl ast::NotType { } } +impl ast::ShlType { + fn to_type(self) -> ast::Type { + match self { + ast::ShlType::B16 => ast::Type::Scalar(ast::ScalarType::B16), + ast::ShlType::B32 => ast::Type::Scalar(ast::ScalarType::B32), + ast::ShlType::B64 => ast::Type::Scalar(ast::ScalarType::B64), + } + } +} + impl ast::AddDetails { fn get_type(&self) -> ast::Type { match self { -- cgit v1.2.3