summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--ptx/src/ptx.lalrpop32
-rw-r--r--ptx/src/test/spirv_run/fma.ptx25
-rw-r--r--ptx/src/test/spirv_run/fma.spvtxt72
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/translate.rs15
5 files changed, 142 insertions, 3 deletions
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 025f0be..dfe5a5f 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -131,6 +131,7 @@ match {
"cvt",
"cvta",
"debug",
+ "fma",
"ld",
"mad",
"map_f64_to_f32",
@@ -166,6 +167,7 @@ ExtendedID : &'input str = {
"cvt",
"cvta",
"debug",
+ "fma",
"ld",
"mad",
"map_f64_to_f32",
@@ -1185,7 +1187,8 @@ InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad
InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mad" <d:MulDetails> <a:Arg4> => ast::Instruction::Mad(d, a),
- "mad" ".hi" ".sat" ".s32" => todo!()
+ "mad" ".hi" ".sat" ".s32" => todo!(),
+ "fma" <f:ArithFloatMustRound> <a:Arg4> => ast::Instruction::Mad(ast::MulDetails::Float(f), a),
};
SignedIntType: ast::ScalarType = {
@@ -1333,6 +1336,33 @@ ArithFloat: ast::ArithFloat = {
},
}
+ArithFloatMustRound: ast::ArithFloat = {
+ <rn:RoundingModeFloat> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
+ typ: ast::FloatType::F32,
+ rounding: Some(rn),
+ flush_to_zero: Some(ftz.is_some()),
+ saturate: sat.is_some(),
+ },
+ <rn:RoundingModeFloat> ".f64" => ast::ArithFloat {
+ typ: ast::FloatType::F64,
+ rounding: Some(rn),
+ flush_to_zero: None,
+ saturate: false,
+ },
+ ".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
+ typ: ast::FloatType::F16,
+ rounding: Some(ast::RoundingMode::NearestEven),
+ flush_to_zero: Some(ftz.is_some()),
+ saturate: sat.is_some(),
+ },
+ ".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
+ typ: ast::FloatType::F16x2,
+ rounding: Some(ast::RoundingMode::NearestEven),
+ flush_to_zero: Some(ftz.is_some()),
+ saturate: sat.is_some(),
+ },
+}
+
Operand: ast::Operand<&'input str> = {
<r:ExtendedID> => ast::Operand::Reg(r),
<r:ExtendedID> "+" <offset:S32Num> => ast::Operand::RegOffset(r, offset),
diff --git a/ptx/src/test/spirv_run/fma.ptx b/ptx/src/test/spirv_run/fma.ptx
new file mode 100644
index 0000000..171d306
--- /dev/null
+++ b/ptx/src/test/spirv_run/fma.ptx
@@ -0,0 +1,25 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry fma(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .f32 temp1;
+ .reg .f32 temp2;
+ .reg .f32 temp3;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.f32 temp1, [in_addr];
+ ld.f32 temp2, [in_addr+4];
+ ld.f32 temp3, [in_addr+8];
+ fma.rn.f32 temp1, temp1, temp2, temp3;
+ st.f32 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/fma.spvtxt b/ptx/src/test/spirv_run/fma.spvtxt
new file mode 100644
index 0000000..734bf0f
--- /dev/null
+++ b/ptx/src/test/spirv_run/fma.spvtxt
@@ -0,0 +1,72 @@
+; SPIR-V
+; Version: 1.3
+; Generator: rspirv
+; Bound: 45
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int8
+OpCapability Int16
+OpCapability Int64
+OpCapability Float16
+OpCapability Float64
+OpCapability FunctionFloatControlINTEL
+OpExtension "SPV_INTEL_float_controls2"
+%37 = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical64 OpenCL
+OpEntryPoint Kernel %1 "fma"
+OpDecorate %1 FunctionDenormModeINTEL 32 Preserve
+%38 = OpTypeVoid
+%39 = OpTypeInt 64 0
+%40 = OpTypeFunction %38 %39 %39
+%41 = OpTypePointer Function %39
+%42 = OpTypeFloat 32
+%43 = OpTypePointer Function %42
+%44 = OpTypePointer Generic %42
+%27 = OpConstant %39 4
+%29 = OpConstant %39 8
+%1 = OpFunction %38 None %40
+%9 = OpFunctionParameter %39
+%10 = OpFunctionParameter %39
+%35 = OpLabel
+%2 = OpVariable %41 Function
+%3 = OpVariable %41 Function
+%4 = OpVariable %41 Function
+%5 = OpVariable %41 Function
+%6 = OpVariable %43 Function
+%7 = OpVariable %43 Function
+%8 = OpVariable %43 Function
+OpStore %2 %9
+OpStore %3 %10
+%12 = OpLoad %39 %2
+%11 = OpCopyObject %39 %12
+OpStore %4 %11
+%14 = OpLoad %39 %3
+%13 = OpCopyObject %39 %14
+OpStore %5 %13
+%16 = OpLoad %39 %4
+%31 = OpConvertUToPtr %44 %16
+%15 = OpLoad %42 %31
+OpStore %6 %15
+%18 = OpLoad %39 %4
+%28 = OpIAdd %39 %18 %27
+%32 = OpConvertUToPtr %44 %28
+%17 = OpLoad %42 %32
+OpStore %7 %17
+%20 = OpLoad %39 %4
+%30 = OpIAdd %39 %20 %29
+%33 = OpConvertUToPtr %44 %30
+%19 = OpLoad %42 %33
+OpStore %8 %19
+%22 = OpLoad %42 %6
+%23 = OpLoad %42 %7
+%24 = OpLoad %42 %8
+%21 = OpExtInst %42 %37 mad %22 %23 %24
+OpStore %6 %21
+%25 = OpLoad %39 %5
+%26 = OpLoad %42 %6
+%34 = OpConvertUToPtr %44 %25
+OpStore %34 %26
+OpReturn
+OpFunctionEnd \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index f336055..98b9630 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -91,6 +91,7 @@ test_ptx!(constant_f32, [10f32], [5f32]);
test_ptx!(constant_negative, [-101i32], [101i32]);
test_ptx!(and, [6u32, 3u32], [2u32]);
test_ptx!(selp, [100u16, 200u16], [200u16]);
+test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 9d73742..a7025b1 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -2343,7 +2343,9 @@ fn emit_function_body_ops(
ast::MulDetails::Unsigned(ref desc) => {
emit_mad_uint(builder, map, opencl, desc, arg)?
}
- ast::MulDetails::Float(desc) => emit_mad_float(builder, map, desc, arg)?,
+ ast::MulDetails::Float(desc) => {
+ emit_mad_float(builder, map, opencl, desc, arg)?
+ }
},
ast::Instruction::Or(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
@@ -2560,10 +2562,19 @@ fn emit_mad_sint(
fn emit_mad_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
+ opencl: spirv::Word,
desc: &ast::ArithFloat,
arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> {
- todo!()
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::mad as spirv::Word,
+ [arg.src1, arg.src2, arg.src3],
+ )?;
+ Ok(())
}
fn emit_add_float(