aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--ptx/src/ast.rs10
-rw-r--r--ptx/src/ptx.lalrpop9
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/test/spirv_run/not.ptx22
-rw-r--r--ptx/src/test/spirv_run/not.spvtxt39
-rw-r--r--ptx/src/translate.rs28
6 files changed, 100 insertions, 9 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index bbc5815..158ec8d 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -176,7 +176,7 @@ pub enum Instruction<P: ArgParams> {
Add(AddDetails, Arg3<P>),
Setp(SetpData, Arg4<P>),
SetpBool(SetpBoolData, Arg5<P>),
- Not(NotData, Arg2<P>),
+ Not(NotType, Arg2<P>),
Bra(BraData, Arg1<P>),
Cvt(CvtData, Arg2<P>),
Shl(ShlData, Arg3<P>),
@@ -386,7 +386,13 @@ pub struct SetpBoolData {
pub bool_op: SetpBoolPostOp,
}
-pub struct NotData {}
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum NotType {
+ Pred,
+ B16,
+ B32,
+ B64,
+}
pub struct BraData {
pub uniform: bool,
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index af26765..d525fbe 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -557,11 +557,14 @@ SetpType: ast::ScalarType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not
InstNot: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "not" NotType <a:Arg2> => ast::Instruction::Not(ast::NotData{}, a)
+ "not" <t:NotType> <a:Arg2> => ast::Instruction::Not(t, a)
};
-NotType = {
- ".pred", ".b16", ".b32", ".b64"
+NotType: ast::NotType = {
+ ".pred" => ast::NotType::Pred,
+ ".b16" => ast::NotType::B16,
+ ".b32" => ast::NotType::B32,
+ ".b64" => ast::NotType::B64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index c90e487..b4414d9 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -46,6 +46,7 @@ test_ptx!(mul_hi, [u64::max_value()], [1u64]);
test_ptx!(add, [1u64], [2u64]);
test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]);
test_ptx!(bra, [10u64], [11u64]);
+test_ptx!(not, [0u64], [u64::max_value()]);
struct DisplayError<T: Display + Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/not.ptx b/ptx/src/test/spirv_run/not.ptx
new file mode 100644
index 0000000..6182134
--- /dev/null
+++ b/ptx/src/test/spirv_run/not.ptx
@@ -0,0 +1,22 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry not(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u64 temp, [in_addr];
+ not.b64 temp2, temp;
+ st.u64 [out_addr], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/not.spvtxt b/ptx/src/test/spirv_run/not.spvtxt
new file mode 100644
index 0000000..518e995
--- /dev/null
+++ b/ptx/src/test/spirv_run/not.spvtxt
@@ -0,0 +1,39 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %1 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %5 "not"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %4 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %ulong_0 = OpTypeInt 64 0
+ %5 = OpFunction %void None %4
+ %6 = OpFunctionParameter %ulong
+ %7 = OpFunctionParameter %ulong
+ %20 = OpLabel
+ %8 = OpVariable %_ptr_Function_ulong Function
+ %9 = OpVariable %_ptr_Function_ulong Function
+ %10 = OpVariable %_ptr_Function_ulong Function
+ %11 = OpVariable %_ptr_Function_ulong Function
+ OpStore %8 %6
+ OpStore %9 %7
+ %13 = OpLoad %ulong %8
+ %18 = OpConvertUToPtr %_ptr_Generic_ulong %13
+ %12 = OpLoad %ulong %18
+ OpStore %10 %12
+ %15 = OpLoad %ulong_0 %10
+ %14 = OpNot %ulong_0 %15
+ OpStore %11 %14
+ %16 = OpLoad %ulong %9
+ %17 = OpLoad %ulong %11
+ %19 = OpConvertUToPtr %_ptr_Generic_ulong %16
+ OpStore %19 %17
+ OpReturn
+ OpFunctionEnd
+ \ No newline at end of file
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index c40e554..a6e627f 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -659,6 +659,15 @@ fn emit_function_body_ops(
}
emit_setp(builder, map, setp, arg)?;
}
+ ast::Instruction::Not(t, a) => {
+ let result_type = map.get_or_add(builder, SpirvType::from(t.to_type()));
+ let result_id = Some(a.dst);
+ let operand = a.src;
+ match t {
+ ast::NotType::Pred => builder.logical_not(result_type, result_id, operand),
+ _ => builder.not(result_type, result_id, operand),
+ }?;
+ }
_ => todo!(),
},
Statement::LoadVar(arg, typ) => {
@@ -887,9 +896,7 @@ fn expand_map_variables<'a>(
s: ast::Statement<ast::ParsedArgParams<'a>>,
) {
match s {
- ast::Statement::Label(name) => {
- result.push(ast::Statement::Label(id_defs.get_id(name)))
- }
+ ast::Statement::Label(name) => result.push(ast::Statement::Label(id_defs.get_id(name))),
ast::Statement::Instruction(p, i) => result.push(ast::Statement::Instruction(
p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))),
i.map_variable(&mut |id| id_defs.get_id(id)),
@@ -1128,7 +1135,9 @@ impl<T: ast::ArgParams> ast::Instruction<T> {
let inst_type = d.typ;
ast::Instruction::SetpBool(d, a.map(visitor, Some(ast::Type::Scalar(inst_type))))
}
- ast::Instruction::Not(_, _) => todo!(),
+ ast::Instruction::Not(t, a) => {
+ ast::Instruction::Not(t, a.map(visitor, Some(t.to_type())))
+ }
ast::Instruction::Cvt(_, _) => todo!(),
ast::Instruction::Shl(_, _) => todo!(),
ast::Instruction::St(d, a) => {
@@ -1513,6 +1522,17 @@ impl ast::ScalarType {
}
}
+impl ast::NotType {
+ fn to_type(self) -> ast::Type {
+ match self {
+ ast::NotType::Pred => ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred),
+ ast::NotType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
+ ast::NotType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
+ ast::NotType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
+ }
+ }
+}
+
impl ast::AddDetails {
fn get_type(&self) -> ast::Type {
match self {