aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-10-01 20:28:57 +0200
committerAndrzej Janik <[email protected]>2020-10-01 20:28:57 +0200
commitbd3d440dba9a913e2214de89a151f9c2c34984fe (patch)
treee90bd1e36968d3abae492b0c5bf22791f119fb80
parent96a342e33f221803874ff897f4aa1aa3aae8e71c (diff)
downloadZLUDA-bd3d440dba9a913e2214de89a151f9c2c34984fe.tar.gz
ZLUDA-bd3d440dba9a913e2214de89a151f9c2c34984fe.zip
Implement or
-rw-r--r--ptx/src/ast.rs8
-rw-r--r--ptx/src/ptx.lalrpop17
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/translate.rs16
4 files changed, 41 insertions, 1 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index b509dfe..8c64ebf 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -345,6 +345,7 @@ pub enum Instruction<P: ArgParams> {
Call(CallInst<P>),
Abs(AbsDetails, Arg2<P>),
Mad(MulDetails, Arg4<P>),
+ Or(OrType, Arg3<P>),
}
#[derive(Copy, Clone)]
@@ -802,3 +803,10 @@ pub enum StCacheOperator {
pub struct RetData {
pub uniform: bool,
}
+
+sub_scalar_type!(OrType {
+ Pred,
+ B16,
+ B32,
+ B64,
+});
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index debdae7..d2d5be8 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -127,6 +127,7 @@ match {
"mov",
"mul",
"not",
+ "or",
"ret",
"setp",
"shl",
@@ -155,6 +156,7 @@ ExtendedID : &'input str = {
"mov",
"mul",
"not",
+ "or",
"ret",
"setp",
"shl",
@@ -445,7 +447,8 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstCvta,
InstCall,
InstAbs,
- InstMad
+ InstMad,
+ InstOr
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@@ -1048,6 +1051,18 @@ SignedIntType: ast::ScalarType = {
".s64" => ast::ScalarType::S64,
};
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-or
+InstOr: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "or" <d:OrType> <a:Arg3> => ast::Instruction::Or(d, a),
+};
+
+OrType: ast::OrType = {
+ ".pred" => ast::OrType::Pred,
+ ".b16" => ast::OrType::B16,
+ ".b32" => ast::OrType::B32,
+ ".b64" => ast::OrType::B64,
+}
+
Operand: ast::Operand<&'input str> = {
<r:ExtendedID> => ast::Operand::Reg(r),
<r:ExtendedID> "+" <o:Num> => {
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 6f516fd..99785a6 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -69,6 +69,7 @@ test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32, 10i32, 10i32]);
test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]);
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
test_ptx!(shr, [-2i32], [-1i32]);
+test_ptx!(or, [1u64, 2u64], [3u64]);
struct DisplayError<T: Debug> {
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index fe6a7dc..fb1b843 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -592,6 +592,9 @@ fn convert_to_typed_statements(
ast::Instruction::Shr(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Shr(d, a.cast())))
}
+ ast::Instruction::Or(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Or(d, a.cast())))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -1583,6 +1586,14 @@ fn emit_function_body_ops(
}
ast::MulDetails::Float(desc) => emit_mad_float(builder, map, desc, arg)?,
},
+ ast::Instruction::Or(t, a) => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
+ if *t == ast::OrType::Pred {
+ builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?;
+ } else {
+ builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
+ }
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(*typ));
@@ -2905,6 +2916,10 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
let is_wide = d.is_wide();
ast::Instruction::Mad(d, a.map(visitor, inst_type, is_wide)?)
}
+ ast::Instruction::Or(t, a) => ast::Instruction::Or(
+ t,
+ a.map_non_shift(visitor, ast::Type::Scalar(t.into()), false)?,
+ ),
})
}
}
@@ -3113,6 +3128,7 @@ impl ast::Instruction<ExpandedArgParams> {
| ast::Instruction::Ret(_)
| ast::Instruction::Abs(_, _)
| ast::Instruction::Call(_)
+ | ast::Instruction::Or(_, _)
| ast::Instruction::Mad(_, _) => None,
}
}