diff options
-rw-r--r-- | ptx/src/ast.rs | 7 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 36 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 1 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/neg.ptx | 21 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/neg.spvtxt | 47 | ||||
-rw-r--r-- | ptx/src/translate.rs | 20 |
6 files changed, 131 insertions, 1 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index f00ddce..7f2fc9a 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -542,6 +542,7 @@ pub enum Instruction<P: ArgParams> { Div(DivDetails, Arg3<P>), Sqrt(SqrtDetails, Arg2<P>), Rsqrt(RsqrtDetails, Arg2<P>), + Neg(NegDetails, Arg2<P>), } #[derive(Copy, Clone)] @@ -1183,6 +1184,12 @@ pub struct RsqrtDetails { pub flush_to_zero: bool, } +#[derive(Copy, Clone, Eq, PartialEq)] +pub struct NegDetails { + pub typ: ScalarType, + pub flush_to_zero: Option<bool>, +} + impl<'a> NumsOrArrays<'a> { pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> { self.normalize_dimensions(dimensions)?; diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 4cf4255..9d2adec 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -156,6 +156,7 @@ match { "min", "mov", "mul", + "neg", "not", "or", "rcp", @@ -198,6 +199,7 @@ ExtendedID : &'input str = { "min", "mov", "mul", + "neg", "not", "or", "rcp", @@ -684,6 +686,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { InstDiv, InstSqrt, InstRsqrt, + InstNeg, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -1577,6 +1580,39 @@ InstRsqrt: ast::Instruction<ast::ParsedArgParams<'input>> = { }, } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-neg +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-neg +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-neg +InstNeg: ast::Instruction<ast::ParsedArgParams<'input>> = { + "neg" <ftz:".ftz"?> <typ:NegTypeFtz> <a:Arg2> => { + let details = ast::NegDetails { + typ, + flush_to_zero: Some(ftz.is_some()), + }; + ast::Instruction::Neg(details, a) + }, + "neg" <typ:NegTypeNonFtz> <a:Arg2> => { + let details = ast::NegDetails { + typ, + flush_to_zero: None, + }; + ast::Instruction::Neg(details, a) + }, +} + +NegTypeFtz: ast::ScalarType = { + ".f16" => ast::ScalarType::F16, + ".f16x2" => ast::ScalarType::F16x2, + ".f32" => ast::ScalarType::F32, +} + +NegTypeNonFtz: ast::ScalarType = { + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f64" => ast::ScalarType::F64 +} + ArithDetails: ast::ArithDetails = { <t:UIntType> => ast::ArithDetails::Unsigned(t), <t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt { diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 4e9d39f..7ba3c4d 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -104,6 +104,7 @@ test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]); test_ptx!(div_approx, [1f32, 2f32], [0.5f32]);
test_ptx!(sqrt, [0.25f32], [0.5f32]);
test_ptx!(rsqrt, [0.25f64], [2f64]);
+test_ptx!(neg, [181i32], [-181i32]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/neg.ptx b/ptx/src/test/spirv_run/neg.ptx new file mode 100644 index 0000000..60fe162 --- /dev/null +++ b/ptx/src/test/spirv_run/neg.ptx @@ -0,0 +1,21 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry neg(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .s32 temp1;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.s32 temp1, [in_addr];
+ neg.s32 temp1, temp1;
+ st.s32 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/neg.spvtxt b/ptx/src/test/spirv_run/neg.spvtxt new file mode 100644 index 0000000..b358858 --- /dev/null +++ b/ptx/src/test/spirv_run/neg.spvtxt @@ -0,0 +1,47 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %26 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "not" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %29 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %1 = OpFunction %void None %29 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %24 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %20 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %20 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %22 = OpCopyObject %ulong %17 + %21 = OpNot %ulong %22 + %16 = OpCopyObject %ulong %21 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %23 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %23 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c351ccd..36e15f9 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1511,6 +1511,9 @@ fn convert_to_typed_statements( ast::Instruction::Rsqrt(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast())))
}
+ ast::Instruction::Neg(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Neg(d, a.cast())))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -2805,6 +2808,15 @@ fn emit_function_body_ops( &[a.src],
)?;
}
+ ast::Instruction::Neg(details, arg) => {
+ let result_type = map.get_or_add_scalar(builder, details.typ);
+ let negate_func = if details.typ.kind() == ScalarKind::Float {
+ dr::Builder::f_negate
+ } else {
+ dr::Builder::s_negate
+ };
+ negate_func(builder, result_type, Some(arg.dst), arg.src)?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@@ -3406,7 +3418,7 @@ fn emit_setp( (ast::SetpCompareOp::NanGreaterOrEq, _) => {
builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
- _ => todo!()
+ _ => todo!(),
}?;
Ok(())
}
@@ -4678,6 +4690,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> { ast::Instruction::Rsqrt(d, a) => {
ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
}
+ ast::Instruction::Neg(d, a) => {
+ ast::Instruction::Neg(d, a.map(visitor, &ast::Type::Scalar(d.typ))?)
+ }
})
}
}
@@ -4984,6 +4999,9 @@ impl ast::Instruction<ExpandedArgParams> { details.flush_to_zero,
ast::ScalarType::from(details.typ).size_of(),
)),
+ ast::Instruction::Neg(details, _) => details
+ .flush_to_zero
+ .map(|ftz| (ftz, details.typ.size_of())),
}
}
}
|