summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Andrzej Janik <[email protected]> 2020-11-05 21:39:34 +0100
committer: Andrzej Janik <[email protected]> 2020-11-05 21:39:34 +0100
commit: 8e409254b3f30577a840885f6d7a56b27f4c2611 (patch)
tree: 163c60d33c90475077a9efffd7011d5c9be760b0
parent: 96702d86c96ef2d14795a71af43015a8eacd0a94 (diff)
download: ZLUDA-8e409254b3f30577a840885f6d7a56b27f4c2611.tar.gz
download: ZLUDA-8e409254b3f30577a840885f6d7a56b27f4c2611.zip
Fix same width float-to-float conversions
-rw-r--r--  ptx/src/ptx.lalrpop                    4
-rw-r--r--  ptx/src/test/spirv_run/cvt_rni.ptx    25
-rw-r--r--  ptx/src/test/spirv_run/cvt_rni.spvtxt 63
-rw-r--r--  ptx/src/test/spirv_run/mod.rs          1
-rw-r--r--  ptx/src/translate.rs                  57
5 files changed, 139 insertions, 11 deletions
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 584ef84..31c2356 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -1068,7 +1068,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
}
), a)
},
- "cvt" <r:RoundingModeFloat?> <f:".ftz"?> <s:".sat"?> ".f32" ".f32" <a:Arg2> => {
+ "cvt" <r:RoundingModeInt?> <f:".ftz"?> <s:".sat"?> ".f32" ".f32" <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
ast::CvtDesc {
rounding: r,
@@ -1112,7 +1112,7 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
}
), a)
},
- "cvt" <r:RoundingModeFloat?> <s:".sat"?> ".f64" ".f64" <a:Arg2> => {
+ "cvt" <r:RoundingModeInt?> <s:".sat"?> ".f64" ".f64" <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(
ast::CvtDesc {
rounding: r,
diff --git a/ptx/src/test/spirv_run/cvt_rni.ptx b/ptx/src/test/spirv_run/cvt_rni.ptx
new file mode 100644
index 0000000..ecf20f8
--- /dev/null
+++ b/ptx/src/test/spirv_run/cvt_rni.ptx
@@ -0,0 +1,25 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry cvt_rni(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .f32 temp1;
+ .reg .f32 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.f32 temp1, [in_addr];
+ ld.f32 temp2, [in_addr+4];
+ cvt.rni.f32.f32 temp1, temp1;
+ cvt.rni.f32.f32 temp2, temp2;
+ st.f32 [out_addr], temp1;
+ st.f32 [out_addr+4], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/cvt_rni.spvtxt b/ptx/src/test/spirv_run/cvt_rni.spvtxt
new file mode 100644
index 0000000..cad84a2
--- /dev/null
+++ b/ptx/src/test/spirv_run/cvt_rni.spvtxt
@@ -0,0 +1,63 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %34 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "cvt_rni"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %37 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Generic_float = OpTypePointer Generic %float
+ %ulong_4 = OpConstant %ulong 4
+ %ulong_4_0 = OpConstant %ulong 4
+ %1 = OpFunction %void None %37
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %32 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_float Function
+ %7 = OpVariable %_ptr_Function_float Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %10 = OpLoad %ulong %2
+ OpStore %4 %10
+ %11 = OpLoad %ulong %3
+ OpStore %5 %11
+ %13 = OpLoad %ulong %4
+ %28 = OpConvertUToPtr %_ptr_Generic_float %13
+ %12 = OpLoad %float %28
+ OpStore %6 %12
+ %15 = OpLoad %ulong %4
+ %25 = OpIAdd %ulong %15 %ulong_4
+ %29 = OpConvertUToPtr %_ptr_Generic_float %25
+ %14 = OpLoad %float %29
+ OpStore %7 %14
+ %17 = OpLoad %float %6
+ %16 = OpExtInst %float %34 rint %17
+ OpStore %6 %16
+ %19 = OpLoad %float %7
+ %18 = OpExtInst %float %34 rint %19
+ OpStore %7 %18
+ %20 = OpLoad %ulong %5
+ %21 = OpLoad %float %6
+ %30 = OpConvertUToPtr %_ptr_Generic_float %20
+ OpStore %30 %21
+ %22 = OpLoad %ulong %5
+ %23 = OpLoad %float %7
+ %27 = OpIAdd %ulong %22 %ulong_4_0
+ %31 = OpConvertUToPtr %_ptr_Generic_float %27
+ OpStore %31 %23
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 3fa82ba..163caac 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -108,6 +108,7 @@ test_ptx!(sin, [std::f32::consts::PI/2f32], [1f32]);
test_ptx!(cos, [std::f32::consts::PI], [-1f32]);
test_ptx!(lg2, [512f32], [9f32]);
test_ptx!(ex2, [10f32], [1024f32]);
+test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 7a0dd08..9519951 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -2813,7 +2813,7 @@ fn emit_function_body_ops(
}
}
ast::Instruction::Cvt(dets, arg) => {
- emit_cvt(builder, map, dets, arg)?;
+ emit_cvt(builder, map, opencl, dets, arg)?;
}
ast::Instruction::Cvta(_, arg) => {
// This would be only meaningful if const/slm/global pointers
@@ -3410,21 +3410,63 @@ fn emit_max(
fn emit_cvt(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
+ opencl: spirv::Word,
dets: &ast::CvtDetails,
arg: &ast::Arg2<ExpandedArgParams>,
) -> Result<(), TranslateError> {
match dets {
ast::CvtDetails::FloatFromFloat(desc) => {
- if desc.dst == desc.src {
- return Ok(());
- }
if desc.saturate {
todo!()
}
let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- builder.f_convert(result_type, Some(arg.dst), arg.src)?;
- emit_rounding_decoration(builder, arg.dst, desc.rounding);
+ if desc.dst == desc.src {
+ match desc.rounding {
+ Some(ast::RoundingMode::NearestEven) => {
+ builder.ext_inst(
+ result_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::rint as u32,
+ [arg.src],
+ )?;
+ }
+ Some(ast::RoundingMode::Zero) => {
+ builder.ext_inst(
+ result_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::trunc as u32,
+ [arg.src],
+ )?;
+ }
+ Some(ast::RoundingMode::NegativeInf) => {
+ builder.ext_inst(
+ result_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::floor as u32,
+ [arg.src],
+ )?;
+ }
+ Some(ast::RoundingMode::PositiveInf) => {
+ builder.ext_inst(
+ result_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::ceil as u32,
+ [arg.src],
+ )?;
+ }
+ None => {
+ builder.copy_object(result_type, Some(arg.dst), arg.src)?;
+ }
+ }
+ } else {
+ builder.f_convert(result_type, Some(arg.dst), arg.src)?;
+ emit_rounding_decoration(builder, arg.dst, desc.rounding);
+ }
}
ast::CvtDetails::FloatFromInt(desc) => {
if desc.saturate {
@@ -3451,9 +3493,6 @@ fn emit_cvt(
emit_saturating_decoration(builder, arg.dst, desc.saturate);
}
ast::CvtDetails::IntFromInt(desc) => {
- if desc.dst == desc.src {
- return Ok(());
- }
let dest_t: ast::ScalarType = desc.dst.into();
let src_t: ast::ScalarType = desc.src.into();
// first do shortening/widening