aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src')
-rw-r--r--ptx/src/ast.rs7
-rw-r--r--ptx/src/ptx.lalrpop36
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/test/spirv_run/neg.ptx21
-rw-r--r--ptx/src/test/spirv_run/neg.spvtxt47
-rw-r--r--ptx/src/translate.rs20
6 files changed, 131 insertions, 1 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index f00ddce..7f2fc9a 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -542,6 +542,7 @@ pub enum Instruction<P: ArgParams> {
Div(DivDetails, Arg3<P>),
Sqrt(SqrtDetails, Arg2<P>),
Rsqrt(RsqrtDetails, Arg2<P>),
+ Neg(NegDetails, Arg2<P>),
}
#[derive(Copy, Clone)]
@@ -1183,6 +1184,12 @@ pub struct RsqrtDetails {
pub flush_to_zero: bool,
}
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub struct NegDetails {
+ pub typ: ScalarType,
+ pub flush_to_zero: Option<bool>,
+}
+
impl<'a> NumsOrArrays<'a> {
pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
self.normalize_dimensions(dimensions)?;
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 4cf4255..9d2adec 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -156,6 +156,7 @@ match {
"min",
"mov",
"mul",
+ "neg",
"not",
"or",
"rcp",
@@ -198,6 +199,7 @@ ExtendedID : &'input str = {
"min",
"mov",
"mul",
+ "neg",
"not",
"or",
"rcp",
@@ -684,6 +686,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstDiv,
InstSqrt,
InstRsqrt,
+ InstNeg,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@@ -1577,6 +1580,39 @@ InstRsqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
},
}
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-neg
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-neg
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-neg
+InstNeg: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "neg" <ftz:".ftz"?> <typ:NegTypeFtz> <a:Arg2> => {
+ let details = ast::NegDetails {
+ typ,
+ flush_to_zero: Some(ftz.is_some()),
+ };
+ ast::Instruction::Neg(details, a)
+ },
+ "neg" <typ:NegTypeNonFtz> <a:Arg2> => {
+ let details = ast::NegDetails {
+ typ,
+ flush_to_zero: None,
+ };
+ ast::Instruction::Neg(details, a)
+ },
+}
+
+NegTypeFtz: ast::ScalarType = {
+ ".f16" => ast::ScalarType::F16,
+ ".f16x2" => ast::ScalarType::F16x2,
+ ".f32" => ast::ScalarType::F32,
+}
+
+NegTypeNonFtz: ast::ScalarType = {
+ ".s16" => ast::ScalarType::S16,
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
+ ".f64" => ast::ScalarType::F64
+}
+
ArithDetails: ast::ArithDetails = {
<t:UIntType> => ast::ArithDetails::Unsigned(t),
<t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt {
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 4e9d39f..7ba3c4d 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -104,6 +104,7 @@ test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]);
test_ptx!(div_approx, [1f32, 2f32], [0.5f32]);
test_ptx!(sqrt, [0.25f32], [0.5f32]);
test_ptx!(rsqrt, [0.25f64], [2f64]);
+test_ptx!(neg, [181i32], [-181i32]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/neg.ptx b/ptx/src/test/spirv_run/neg.ptx
new file mode 100644
index 0000000..60fe162
--- /dev/null
+++ b/ptx/src/test/spirv_run/neg.ptx
@@ -0,0 +1,21 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry neg(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .s32 temp1;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.s32 temp1, [in_addr];
+ neg.s32 temp1, temp1;
+ st.s32 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/neg.spvtxt b/ptx/src/test/spirv_run/neg.spvtxt
new file mode 100644
index 0000000..b358858
--- /dev/null
+++ b/ptx/src/test/spirv_run/neg.spvtxt
@@ -0,0 +1,47 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %26 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "not"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %29 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %1 = OpFunction %void None %29
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %24 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %11 = OpLoad %ulong %2
+ %10 = OpCopyObject %ulong %11
+ OpStore %4 %10
+ %13 = OpLoad %ulong %3
+ %12 = OpCopyObject %ulong %13
+ OpStore %5 %12
+ %15 = OpLoad %ulong %4
+ %20 = OpConvertUToPtr %_ptr_Generic_ulong %15
+ %14 = OpLoad %ulong %20
+ OpStore %6 %14
+ %17 = OpLoad %ulong %6
+ %22 = OpCopyObject %ulong %17
+ %21 = OpNot %ulong %22
+ %16 = OpCopyObject %ulong %21
+ OpStore %7 %16
+ %18 = OpLoad %ulong %5
+ %19 = OpLoad %ulong %7
+ %23 = OpConvertUToPtr %_ptr_Generic_ulong %18
+ OpStore %23 %19
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index c351ccd..36e15f9 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1511,6 +1511,9 @@ fn convert_to_typed_statements(
ast::Instruction::Rsqrt(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast())))
}
+ ast::Instruction::Neg(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Neg(d, a.cast())))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -2805,6 +2808,15 @@ fn emit_function_body_ops(
&[a.src],
)?;
}
+ ast::Instruction::Neg(details, arg) => {
+ let result_type = map.get_or_add_scalar(builder, details.typ);
+ let negate_func = if details.typ.kind() == ScalarKind::Float {
+ dr::Builder::f_negate
+ } else {
+ dr::Builder::s_negate
+ };
+ negate_func(builder, result_type, Some(arg.dst), arg.src)?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@@ -3406,7 +3418,7 @@ fn emit_setp(
(ast::SetpCompareOp::NanGreaterOrEq, _) => {
builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
- _ => todo!()
+ _ => todo!(),
}?;
Ok(())
}
@@ -4678,6 +4690,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Rsqrt(d, a) => {
ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
}
+ ast::Instruction::Neg(d, a) => {
+ ast::Instruction::Neg(d, a.map(visitor, &ast::Type::Scalar(d.typ))?)
+ }
})
}
}
@@ -4984,6 +4999,9 @@ impl ast::Instruction<ExpandedArgParams> {
details.flush_to_zero,
ast::ScalarType::from(details.typ).size_of(),
)),
+ ast::Instruction::Neg(details, _) => details
+ .flush_to_zero
+ .map(|ftz| (ftz, details.typ.size_of())),
}
}
}