summaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
Diffstat (limited to 'ptx')
-rw-r--r--ptx/src/ast.rs15
-rw-r--r--ptx/src/ptx.lalrpop24
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/test/spirv_run/selp.ptx23
-rw-r--r--ptx/src/test/spirv_run/selp.spvtxt65
-rw-r--r--ptx/src/translate.rs71
6 files changed, 198 insertions, 1 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index f4502af..1266ea4 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -201,6 +201,20 @@ sub_enum!(LdStScalarType {
F64,
});
+sub_enum!(SelpType {
+ B16,
+ B32,
+ B64,
+ U16,
+ U32,
+ U64,
+ S16,
+ S32,
+ S64,
+ F32,
+ F64,
+});
+
pub trait UnwrapWithVec<E, To> {
fn unwrap_with(self, errs: &mut Vec<E>) -> To;
}
@@ -512,6 +526,7 @@ pub enum Instruction<P: ArgParams> {
Max(MinMaxDetails, Arg3<P>),
Rcp(RcpDetails, Arg2<P>),
And(OrAndType, Arg3<P>),
+ Selp(SelpType, Arg4<P>),
}
#[derive(Copy, Clone)]
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 7414443..025f0be 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -142,6 +142,7 @@ match {
"or",
"rcp",
"ret",
+ "selp",
"setp",
"shl",
"shr",
@@ -176,6 +177,7 @@ ExtendedID : &'input str = {
"or",
"rcp",
"ret",
+ "selp",
"setp",
"shl",
"shr",
@@ -614,7 +616,8 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstSub,
InstMin,
InstMax,
- InstRcp
+ InstRcp,
+ InstSelp
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@@ -1271,6 +1274,25 @@ MinMaxDetails: ast::MinMaxDetails = {
)
}
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-selp
+InstSelp: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "selp" <t:SelpType> <a:Arg4> => ast::Instruction::Selp(t, a),
+};
+
+SelpType: ast::SelpType = {
+ ".b16" => ast::SelpType::B16,
+ ".b32" => ast::SelpType::B32,
+ ".b64" => ast::SelpType::B64,
+ ".u16" => ast::SelpType::U16,
+ ".u32" => ast::SelpType::U32,
+ ".u64" => ast::SelpType::U64,
+ ".s16" => ast::SelpType::S16,
+ ".s32" => ast::SelpType::S32,
+ ".s64" => ast::SelpType::S64,
+ ".f32" => ast::SelpType::F32,
+ ".f64" => ast::SelpType::F64,
+};
+
ArithDetails: ast::ArithDetails = {
<t:UIntType> => ast::ArithDetails::Unsigned(t),
<t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt {
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index dfdec72..f336055 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -90,6 +90,7 @@ test_ptx!(mul_non_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32],
test_ptx!(constant_f32, [10f32], [5f32]);
test_ptx!(constant_negative, [-101i32], [101i32]);
test_ptx!(and, [6u32, 3u32], [2u32]);
+test_ptx!(selp, [100u16, 200u16], [200u16]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/selp.ptx b/ptx/src/test/spirv_run/selp.ptx
new file mode 100644
index 0000000..79171dc
--- /dev/null
+++ b/ptx/src/test/spirv_run/selp.ptx
@@ -0,0 +1,23 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry selp(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u16 temp1;
+ .reg .u16 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u16 temp1, [in_addr];
+ ld.u16 temp2, [in_addr + 2];
+ selp.u16 temp1, temp1, temp2, 0;
+ st.u16 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/selp.spvtxt b/ptx/src/test/spirv_run/selp.spvtxt
new file mode 100644
index 0000000..dffd9af
--- /dev/null
+++ b/ptx/src/test/spirv_run/selp.spvtxt
@@ -0,0 +1,65 @@
+; SPIR-V
+; Version: 1.3
+; Generator: rspirv
+; Bound: 40
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int8
+OpCapability Int16
+OpCapability Int64
+OpCapability Float16
+OpCapability Float64
+OpCapability FunctionFloatControlINTEL
+OpExtension "SPV_INTEL_float_controls2"
+%31 = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical64 OpenCL
+OpEntryPoint Kernel %1 "selp"
+%32 = OpTypeVoid
+%33 = OpTypeInt 64 0
+%34 = OpTypeFunction %32 %33 %33
+%35 = OpTypePointer Function %33
+%36 = OpTypeInt 16 0
+%37 = OpTypePointer Function %36
+%38 = OpTypePointer Generic %36
+%23 = OpConstant %33 2
+%39 = OpTypeBool
+%25 = OpConstantFalse %39
+%1 = OpFunction %32 None %34
+%8 = OpFunctionParameter %33
+%9 = OpFunctionParameter %33
+%29 = OpLabel
+%2 = OpVariable %35 Function
+%3 = OpVariable %35 Function
+%4 = OpVariable %35 Function
+%5 = OpVariable %35 Function
+%6 = OpVariable %37 Function
+%7 = OpVariable %37 Function
+OpStore %2 %8
+OpStore %3 %9
+%11 = OpLoad %33 %2
+%10 = OpCopyObject %33 %11
+OpStore %4 %10
+%13 = OpLoad %33 %3
+%12 = OpCopyObject %33 %13
+OpStore %5 %12
+%15 = OpLoad %33 %4
+%26 = OpConvertUToPtr %38 %15
+%14 = OpLoad %36 %26
+OpStore %6 %14
+%17 = OpLoad %33 %4
+%24 = OpIAdd %33 %17 %23
+%27 = OpConvertUToPtr %38 %24
+%16 = OpLoad %36 %27
+OpStore %7 %16
+%19 = OpLoad %36 %6
+%20 = OpLoad %36 %7
+%18 = OpSelect %36 %25 %20 %20
+OpStore %6 %18
+%21 = OpLoad %33 %5
+%22 = OpLoad %36 %6
+%28 = OpConvertUToPtr %38 %21
+OpStore %28 %22
+OpReturn
+OpFunctionEnd \ No newline at end of file
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index c699cc4..9d73742 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1266,6 +1266,9 @@ fn convert_to_typed_statements(
ast::Instruction::And(d, a) => {
result.push(Statement::Instruction(ast::Instruction::And(d, a.cast())))
}
+ ast::Instruction::Selp(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Selp(d, a.cast())))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -2159,6 +2162,22 @@ fn emit_function_body_ops(
(ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => {
builder.constant_f64(typ_id, Some(cnst.dst), value);
}
+ (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => {
+ let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
+ if value == 0 {
+ builder.constant_false(bool_type, Some(cnst.dst));
+ } else {
+ builder.constant_true(bool_type, Some(cnst.dst));
+ }
+ }
+ (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => {
+ let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
+ if value == 0 {
+ builder.constant_false(bool_type, Some(cnst.dst));
+ } else {
+ builder.constant_true(bool_type, Some(cnst.dst));
+ }
+ }
_ => return Err(TranslateError::MismatchedType),
}
}
@@ -2362,6 +2381,10 @@ fn emit_function_body_ops(
builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?;
}
}
+ ast::Instruction::Selp(t, a) => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
+ builder.select(result_type, Some(a.dst), a.src3, a.src2, a.src2)?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@@ -4056,6 +4079,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
t,
a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?,
),
+ ast::Instruction::Selp(t, a) => ast::Instruction::Selp(t, a.map_selp(visitor, t)?),
})
}
}
@@ -4301,6 +4325,7 @@ impl ast::Instruction<ExpandedArgParams> {
| ast::Instruction::Max(_, _)
| ast::Instruction::Rcp(_, _)
| ast::Instruction::And(_, _)
+ | ast::Instruction::Selp(_, _)
| ast::Instruction::Mad(_, _) => None,
}
}
@@ -4321,6 +4346,7 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Or(_, _) => None,
ast::Instruction::And(_, _) => None,
ast::Instruction::Cvta(_, _) => None,
+ ast::Instruction::Selp(_, _) => None,
ast::Instruction::Sub(ast::ArithDetails::Signed(_), _) => None,
ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _) => None,
ast::Instruction::Add(ast::ArithDetails::Signed(_), _) => None,
@@ -5047,6 +5073,51 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
src3,
})
}
+
+ fn map_selp<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ t: ast::SelpType,
+ ) -> Result<ast::Arg4<U>, TranslateError> {
+ let dst = visitor.id(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(&ast::Type::Scalar(t.into())),
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &ast::Type::Scalar(t.into()),
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &ast::Type::Scalar(t.into()),
+ )?;
+ let src3 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src3,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ )?;
+ Ok(ast::Arg4 {
+ dst,
+ src1,
+ src2,
+ src3,
+ })
+ }
}
impl<T: ArgParamsEx> ast::Arg4Setp<T> {