diff options
Diffstat (limited to 'ptx')
-rw-r--r-- | ptx/src/ptx.lalrpop | 32 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/fma.ptx | 25 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/fma.spvtxt | 72 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 1 | ||||
-rw-r--r-- | ptx/src/translate.rs | 15 |
5 files changed, 142 insertions, 3 deletions
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 025f0be..dfe5a5f 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -131,6 +131,7 @@ match { "cvt", "cvta", "debug", + "fma", "ld", "mad", "map_f64_to_f32", @@ -166,6 +167,7 @@ ExtendedID : &'input str = { "cvt", "cvta", "debug", + "fma", "ld", "mad", "map_f64_to_f32", @@ -1185,7 +1187,8 @@ InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = { "mad" <d:MulDetails> <a:Arg4> => ast::Instruction::Mad(d, a), - "mad" ".hi" ".sat" ".s32" => todo!() + "mad" ".hi" ".sat" ".s32" => todo!(), + "fma" <f:ArithFloatMustRound> <a:Arg4> => ast::Instruction::Mad(ast::MulDetails::Float(f), a), }; SignedIntType: ast::ScalarType = { @@ -1333,6 +1336,33 @@ ArithFloat: ast::ArithFloat = { }, } +ArithFloatMustRound: ast::ArithFloat = { + <rn:RoundingModeFloat> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat { + typ: ast::FloatType::F32, + rounding: Some(rn), + flush_to_zero: Some(ftz.is_some()), + saturate: sat.is_some(), + }, + <rn:RoundingModeFloat> ".f64" => ast::ArithFloat { + typ: ast::FloatType::F64, + rounding: Some(rn), + flush_to_zero: None, + saturate: false, + }, + ".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat { + typ: ast::FloatType::F16, + rounding: Some(ast::RoundingMode::NearestEven), + flush_to_zero: Some(ftz.is_some()), + saturate: sat.is_some(), + }, + ".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat { + typ: ast::FloatType::F16x2, + rounding: Some(ast::RoundingMode::NearestEven), + flush_to_zero: Some(ftz.is_some()), + saturate: sat.is_some(), + }, +} + Operand: ast::Operand<&'input str> = { <r:ExtendedID> => ast::Operand::Reg(r), <r:ExtendedID> "+" <offset:S32Num> => ast::Operand::RegOffset(r, offset), diff --git a/ptx/src/test/spirv_run/fma.ptx b/ptx/src/test/spirv_run/fma.ptx new file mode 100644 index 0000000..171d306 --- /dev/null +++ b/ptx/src/test/spirv_run/fma.ptx @@ -0,0 +1,25 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry fma(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .f32 temp1;
+ .reg .f32 temp2;
+ .reg .f32 temp3;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.f32 temp1, [in_addr];
+ ld.f32 temp2, [in_addr+4];
+ ld.f32 temp3, [in_addr+8];
+ fma.rn.f32 temp1, temp1, temp2, temp3;
+ st.f32 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/fma.spvtxt b/ptx/src/test/spirv_run/fma.spvtxt new file mode 100644 index 0000000..734bf0f --- /dev/null +++ b/ptx/src/test/spirv_run/fma.spvtxt @@ -0,0 +1,72 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 45 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +OpCapability FunctionFloatControlINTEL +OpExtension "SPV_INTEL_float_controls2" +%37 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "fma" +OpDecorate %1 FunctionDenormModeINTEL 32 Preserve +%38 = OpTypeVoid +%39 = OpTypeInt 64 0 +%40 = OpTypeFunction %38 %39 %39 +%41 = OpTypePointer Function %39 +%42 = OpTypeFloat 32 +%43 = OpTypePointer Function %42 +%44 = OpTypePointer Generic %42 +%27 = OpConstant %39 4 +%29 = OpConstant %39 8 +%1 = OpFunction %38 None %40 +%9 = OpFunctionParameter %39 +%10 = OpFunctionParameter %39 +%35 = OpLabel +%2 = OpVariable %41 Function +%3 = OpVariable %41 Function +%4 = OpVariable %41 Function +%5 = OpVariable %41 Function +%6 = OpVariable %43 Function +%7 = OpVariable %43 Function +%8 = OpVariable %43 Function +OpStore %2 %9 +OpStore %3 %10 +%12 = OpLoad %39 %2 +%11 = OpCopyObject %39 %12 +OpStore %4 %11 +%14 = OpLoad %39 %3 +%13 = OpCopyObject %39 %14 +OpStore %5 %13 +%16 = OpLoad %39 %4 +%31 = OpConvertUToPtr %44 %16 +%15 = OpLoad %42 %31 +OpStore %6 %15 +%18 = OpLoad %39 %4 +%28 = OpIAdd %39 %18 %27 +%32 = OpConvertUToPtr %44 %28 +%17 = OpLoad %42 %32 +OpStore %7 %17 +%20 = OpLoad %39 %4 +%30 = OpIAdd %39 %20 %29 +%33 = OpConvertUToPtr %44 %30 +%19 = OpLoad %42 %33 +OpStore %8 %19 +%22 = OpLoad %42 %6 +%23 = OpLoad %42 %7 +%24 = OpLoad %42 %8 +%21 = OpExtInst %42 %37 mad %22 %23 %24 +OpStore %6 %21 +%25 = OpLoad %39 %5 +%26 = OpLoad %42 %6 +%34 = OpConvertUToPtr %44 %25 +OpStore %34 %26 +OpReturn +OpFunctionEnd
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f336055..98b9630 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -91,6 +91,7 @@ test_ptx!(constant_f32, [10f32], [5f32]); test_ptx!(constant_negative, [-101i32], [101i32]);
test_ptx!(and, [6u32, 3u32], [2u32]);
test_ptx!(selp, [100u16, 200u16], [200u16]);
+test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 9d73742..a7025b1 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -2343,7 +2343,9 @@ fn emit_function_body_ops( ast::MulDetails::Unsigned(ref desc) => {
emit_mad_uint(builder, map, opencl, desc, arg)?
}
- ast::MulDetails::Float(desc) => emit_mad_float(builder, map, desc, arg)?,
+ ast::MulDetails::Float(desc) => {
+ emit_mad_float(builder, map, opencl, desc, arg)?
+ }
},
ast::Instruction::Or(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
@@ -2560,10 +2562,19 @@ fn emit_mad_sint( fn emit_mad_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
+ opencl: spirv::Word,
desc: &ast::ArithFloat,
arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> {
- todo!()
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::mad as spirv::Word,
+ [arg.src1, arg.src2, arg.src3],
+ )?;
+ Ok(())
}
fn emit_add_float(
|