aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--ptx/src/ast.rs13
-rw-r--r--ptx/src/ptx.lalrpop18
-rw-r--r--ptx/src/test/spirv_run/mod.rs2
-rw-r--r--ptx/src/test/spirv_run/shr.ptx21
-rw-r--r--ptx/src/test/spirv_run/shr.spvtxt50
-rw-r--r--ptx/src/translate.rs24
6 files changed, 128 insertions, 0 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 097e19c..b509dfe 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -339,6 +339,7 @@ pub enum Instruction<P: ArgParams> {
Cvt(CvtDetails, Arg2<P>),
Cvta(CvtaDetails, Arg2<P>),
Shl(ShlType, Arg3<P>),
+ Shr(ShrType, Arg3<P>),
St(StData, Arg2St<P>),
Ret(RetData),
Call(CallInst<P>),
@@ -762,6 +763,18 @@ pub enum ShlType {
B64,
}
+sub_scalar_type!(ShrType {
+ B16,
+ B32,
+ B64,
+ U16,
+ U32,
+ U64,
+ S16,
+ S32,
+ S64,
+});
+
pub struct StData {
pub qualifier: LdStQualifier,
pub state_space: StStateSpace,
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index ba3fc2b..debdae7 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -439,6 +439,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstBra,
InstCvt,
InstShl,
+ InstShr,
InstSt,
InstRet,
InstCvta,
@@ -918,6 +919,23 @@ ShlType: ast::ShlType = {
".b64" => ast::ShlType::B64,
};
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr
+InstShr: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "shr" <t:ShrType> <a:Arg3> => ast::Instruction::Shr(t, a)
+};
+
+ShrType: ast::ShrType = {
+ ".b16" => ast::ShrType::B16,
+ ".b32" => ast::ShrType::B32,
+ ".b64" => ast::ShrType::B64,
+ ".u16" => ast::ShrType::U16,
+ ".u32" => ast::ShrType::U32,
+ ".u64" => ast::ShrType::U64,
+ ".s16" => ast::ShrType::S16,
+ ".s32" => ast::ShrType::S32,
+ ".s64" => ast::ShrType::S64,
+};
+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 5a16755..6f516fd 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -68,6 +68,8 @@ test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]);
test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32, 10i32, 10i32]);
test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]);
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
+test_ptx!(shr, [-2i32], [-1i32]);
+
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/shr.ptx b/ptx/src/test/spirv_run/shr.ptx
new file mode 100644
index 0000000..0a12fa7
--- /dev/null
+++ b/ptx/src/test/spirv_run/shr.ptx
@@ -0,0 +1,21 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry shr(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .s32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.s32 temp, [in_addr];
+ shr.s32 temp, temp, 1;
+ st.s32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/shr.spvtxt b/ptx/src/test/spirv_run/shr.spvtxt
new file mode 100644
index 0000000..417839d
--- /dev/null
+++ b/ptx/src/test/spirv_run/shr.spvtxt
@@ -0,0 +1,50 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %24 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "shr"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %27 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %uint_1 = OpConstant %uint 1
+ %1 = OpFunction %void None %27
+ %7 = OpFunctionParameter %ulong
+ %8 = OpFunctionParameter %ulong
+ %22 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %7
+ OpStore %3 %8
+ %10 = OpLoad %ulong %2
+ %9 = OpCopyObject %ulong %10
+ OpStore %4 %9
+ %12 = OpLoad %ulong %3
+ %11 = OpCopyObject %ulong %12
+ OpStore %5 %11
+ %14 = OpLoad %ulong %4
+ %20 = OpConvertUToPtr %_ptr_Generic_uint %14
+ %13 = OpLoad %uint %20
+ OpStore %6 %13
+ %16 = OpLoad %uint %6
+ %15 = OpShiftRightArithmetic %uint %16 %uint_1
+ OpStore %6 %15
+ %17 = OpLoad %ulong %5
+ %18 = OpLoad %uint %6
+ %21 = OpConvertUToPtr %_ptr_Generic_uint %17
+ OpStore %21 %18
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 37cef00..fe6a7dc 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -589,6 +589,9 @@ fn convert_to_typed_statements(
ast::Instruction::Mad(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Mad(d, a.cast())))
}
+ ast::Instruction::Shr(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Shr(d, a.cast())))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -1555,6 +1558,14 @@ fn emit_function_body_ops(
let result_type = map.get_or_add(builder, SpirvType::from(t.to_type()));
builder.shift_left_logical(result_type, Some(a.dst), a.src1, a.src2)?;
}
+ ast::Instruction::Shr(t, a) => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
+ if t.signed() {
+ builder.shift_right_arithmetic(result_type, Some(a.dst), a.src1, a.src2)?;
+ } else {
+ builder.shift_right_logical(result_type, Some(a.dst), a.src1, a.src2)?;
+ }
+ }
ast::Instruction::Cvt(dets, arg) => {
emit_cvt(builder, map, dets, arg)?;
}
@@ -2874,6 +2885,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Shl(t, a) => {
ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type())?)
}
+ ast::Instruction::Shr(t, a) => {
+ ast::Instruction::Shr(t, a.map_shift(visitor, ast::Type::Scalar(t.into()))?)
+ }
ast::Instruction::St(d, a) => {
let inst_type = d.typ;
let is_param = d.state_space == ast::StStateSpace::Param
@@ -3094,6 +3108,7 @@ impl ast::Instruction<ExpandedArgParams> {
| ast::Instruction::Cvt(_, _)
| ast::Instruction::Cvta(_, _)
| ast::Instruction::Shl(_, _)
+ | ast::Instruction::Shr(_, _)
| ast::Instruction::St(_, _)
| ast::Instruction::Ret(_)
| ast::Instruction::Abs(_, _)
@@ -4009,6 +4024,15 @@ impl ast::ShlType {
}
}
+impl ast::ShrType {
+ fn signed(&self) -> bool {
+ match self {
+ ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true,
+ _ => false,
+ }
+ }
+}
+
impl ast::AddDetails {
fn get_type(&self) -> ast::Type {
match self {