diff options
-rw-r--r-- | ptx/src/ast.rs | 13 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 18 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 2 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/shr.ptx | 21 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/shr.spvtxt | 50 | ||||
-rw-r--r-- | ptx/src/translate.rs | 24 |
6 files changed, 128 insertions, 0 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 097e19c..b509dfe 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -339,6 +339,7 @@ pub enum Instruction<P: ArgParams> { Cvt(CvtDetails, Arg2<P>), Cvta(CvtaDetails, Arg2<P>), Shl(ShlType, Arg3<P>), + Shr(ShrType, Arg3<P>), St(StData, Arg2St<P>), Ret(RetData), Call(CallInst<P>), @@ -762,6 +763,18 @@ pub enum ShlType { B64, } +sub_scalar_type!(ShrType { + B16, + B32, + B64, + U16, + U32, + U64, + S16, + S32, + S64, +}); + pub struct StData { pub qualifier: LdStQualifier, pub state_space: StStateSpace, diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index ba3fc2b..debdae7 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -439,6 +439,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { InstBra, InstCvt, InstShl, + InstShr, InstSt, InstRet, InstCvta, @@ -918,6 +919,23 @@ ShlType: ast::ShlType = { ".b64" => ast::ShlType::B64, }; +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr +InstShr: ast::Instruction<ast::ParsedArgParams<'input>> = { + "shr" <t:ShrType> <a:Arg3> => ast::Instruction::Shr(t, a) +}; + +ShrType: ast::ShrType = { + ".b16" => ast::ShrType::B16, + ".b32" => ast::ShrType::B32, + ".b64" => ast::ShrType::B64, + ".u16" => ast::ShrType::U16, + ".u32" => ast::ShrType::U32, + ".u64" => ast::ShrType::U64, + ".s16" => ast::ShrType::S16, + ".s32" => ast::ShrType::S32, + ".s64" => ast::ShrType::S64, +}; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // Warning: NVIDIA documentation is incorrect, you can specify scope only once InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = { diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 5a16755..6f516fd 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -68,6 +68,8 @@ test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]); test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32, 10i32, 10i32]);
test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]);
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
+test_ptx!(shr, [-2i32], [-1i32]);
+
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/shr.ptx b/ptx/src/test/spirv_run/shr.ptx new file mode 100644 index 0000000..0a12fa7 --- /dev/null +++ b/ptx/src/test/spirv_run/shr.ptx @@ -0,0 +1,21 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry shr(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .s32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.s32 temp, [in_addr];
+ shr.s32 temp, temp, 1;
+ st.s32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/shr.spvtxt b/ptx/src/test/spirv_run/shr.spvtxt new file mode 100644 index 0000000..417839d --- /dev/null +++ b/ptx/src/test/spirv_run/shr.spvtxt @@ -0,0 +1,50 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %24 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "shr" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %27 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %uint_1 = OpConstant %uint 1 + %1 = OpFunction %void None %27 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %22 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + OpStore %2 %7 + OpStore %3 %8 + %10 = OpLoad %ulong %2 + %9 = OpCopyObject %ulong %10 + OpStore %4 %9 + %12 = OpLoad %ulong %3 + %11 = OpCopyObject %ulong %12 + OpStore %5 %11 + %14 = OpLoad %ulong %4 + %20 = OpConvertUToPtr %_ptr_Generic_uint %14 + %13 = OpLoad %uint %20 + OpStore %6 %13 + %16 = OpLoad %uint %6 + %15 = OpShiftRightArithmetic %uint %16 %uint_1 + OpStore %6 %15 + %17 = OpLoad %ulong %5 + %18 = OpLoad %uint %6 + %21 = OpConvertUToPtr %_ptr_Generic_uint %17 + OpStore %21 %18 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 37cef00..fe6a7dc 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -589,6 +589,9 @@ fn convert_to_typed_statements( ast::Instruction::Mad(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Mad(d, a.cast())))
}
+ ast::Instruction::Shr(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Shr(d, a.cast())))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -1555,6 +1558,14 @@ fn emit_function_body_ops( let result_type = map.get_or_add(builder, SpirvType::from(t.to_type()));
builder.shift_left_logical(result_type, Some(a.dst), a.src1, a.src2)?;
}
+ ast::Instruction::Shr(t, a) => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
+ if t.signed() {
+ builder.shift_right_arithmetic(result_type, Some(a.dst), a.src1, a.src2)?;
+ } else {
+ builder.shift_right_logical(result_type, Some(a.dst), a.src1, a.src2)?;
+ }
+ }
ast::Instruction::Cvt(dets, arg) => {
emit_cvt(builder, map, dets, arg)?;
}
@@ -2874,6 +2885,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> { ast::Instruction::Shl(t, a) => {
ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type())?)
}
+ ast::Instruction::Shr(t, a) => {
+ ast::Instruction::Shr(t, a.map_shift(visitor, ast::Type::Scalar(t.into()))?)
+ }
ast::Instruction::St(d, a) => {
let inst_type = d.typ;
let is_param = d.state_space == ast::StStateSpace::Param
@@ -3094,6 +3108,7 @@ impl ast::Instruction<ExpandedArgParams> { | ast::Instruction::Cvt(_, _)
| ast::Instruction::Cvta(_, _)
| ast::Instruction::Shl(_, _)
+ | ast::Instruction::Shr(_, _)
| ast::Instruction::St(_, _)
| ast::Instruction::Ret(_)
| ast::Instruction::Abs(_, _)
@@ -4009,6 +4024,15 @@ impl ast::ShlType { }
}
+impl ast::ShrType {
+ fn signed(&self) -> bool {
+ match self {
+ ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true,
+ _ => false,
+ }
+ }
+}
+
impl ast::AddDetails {
fn get_type(&self) -> ast::Type {
match self {
|