aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-11-07 16:14:37 +0100
committerAndrzej Janik <[email protected]>2020-11-07 16:14:37 +0100
commit62d14cdffe57134fc89099672ee2954ee413b440 (patch)
tree802f9e9f2b21a1cc3f9cc8471fc0968b567057c7
parentac6265f257654180f6661c406a025313190448c4 (diff)
downloadZLUDA-62d14cdffe57134fc89099672ee2954ee413b440.tar.gz
ZLUDA-62d14cdffe57134fc89099672ee2954ee413b440.zip
Fix ftz behavior slightly
-rw-r--r--ptx/src/test/spirv_run/mod.rs14
-rw-r--r--ptx/src/test/spirv_run/mul_ftz.spvtxt119
-rw-r--r--ptx/src/translate.rs55
3 files changed, 114 insertions, 74 deletions
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 5bbe45a..bd74508 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -83,8 +83,11 @@ test_ptx!(extern_shared_call, [121u64], [123u64]);
test_ptx!(rcp, [2f32], [0.5f32]);
// 0b1_00000000_10000000000000000000000u32 is a large denormal
// 0x3f000000 is 0.5
-// TODO: mul_ftz fails because IGC does not yet handle SPV_INTEL_float_controls2
-// test_ptx!(mul_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0u32]);
+test_ptx!(
+ mul_ftz,
+ [0b1_00000000_10000000000000000000000u32, 0x3f000000u32],
+ [0b1_00000000_00000000000000000000000u32]
+);
test_ptx!(
mul_non_ftz,
[0b1_00000000_10000000000000000000000u32, 0x3f000000u32],
@@ -196,7 +199,12 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
let (module, maybe_log) = match module.should_link_ptx_impl {
Some(ptx_impl) => ze::Module::build_link_spirv(&mut ctx, &dev, &[ptx_impl, byte_il]),
None => {
- let (module, log) = ze::Module::build_spirv(&mut ctx, &dev, byte_il, None);
+ let (module, log) = ze::Module::build_spirv(
+ &mut ctx,
+ &dev,
+ byte_il,
+ Some(module.build_options.as_c_str()),
+ );
(module, Some(log))
}
};
diff --git a/ptx/src/test/spirv_run/mul_ftz.spvtxt b/ptx/src/test/spirv_run/mul_ftz.spvtxt
index 56cec5a..3e80ae3 100644
--- a/ptx/src/test/spirv_run/mul_ftz.spvtxt
+++ b/ptx/src/test/spirv_run/mul_ftz.spvtxt
@@ -1,64 +1,55 @@
-; SPIR-V
-; Version: 1.3
-; Generator: rspirv
-; Bound: 38
-OpCapability GenericPointer
-OpCapability Linkage
-OpCapability Addresses
-OpCapability Kernel
-OpCapability Int8
-OpCapability Int16
-OpCapability Int64
-OpCapability Float16
-OpCapability Float64
-; OpCapability FunctionFloatControlINTEL
-; OpExtension "SPV_INTEL_float_controls2"
-%30 = OpExtInstImport "OpenCL.std"
-OpMemoryModel Physical64 OpenCL
-OpEntryPoint Kernel %1 "mul_ftz"
-OpDecorate %1 FunctionDenormModeINTEL 32 FlushToZero
-%31 = OpTypeVoid
-%32 = OpTypeInt 64 0
-%33 = OpTypeFunction %31 %32 %32
-%34 = OpTypePointer Function %32
-%35 = OpTypeFloat 32
-%36 = OpTypePointer Function %35
-%37 = OpTypePointer Generic %35
-%23 = OpConstant %32 4
-%1 = OpFunction %31 None %33
-%8 = OpFunctionParameter %32
-%9 = OpFunctionParameter %32
-%28 = OpLabel
-%2 = OpVariable %34 Function
-%3 = OpVariable %34 Function
-%4 = OpVariable %34 Function
-%5 = OpVariable %34 Function
-%6 = OpVariable %36 Function
-%7 = OpVariable %36 Function
-OpStore %2 %8
-OpStore %3 %9
-%11 = OpLoad %32 %2
-%10 = OpCopyObject %32 %11
-OpStore %4 %10
-%13 = OpLoad %32 %3
-%12 = OpCopyObject %32 %13
-OpStore %5 %12
-%15 = OpLoad %32 %4
-%25 = OpConvertUToPtr %37 %15
-%14 = OpLoad %35 %25
-OpStore %6 %14
-%17 = OpLoad %32 %4
-%24 = OpIAdd %32 %17 %23
-%26 = OpConvertUToPtr %37 %24
-%16 = OpLoad %35 %26
-OpStore %7 %16
-%19 = OpLoad %35 %6
-%20 = OpLoad %35 %7
-%18 = OpFMul %35 %19 %20
-OpStore %6 %18
-%21 = OpLoad %32 %5
-%22 = OpLoad %35 %6
-%27 = OpConvertUToPtr %37 %21
-OpStore %27 %22
-OpReturn
-OpFunctionEnd
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %28 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "mul_ftz"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %31 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Generic_float = OpTypePointer Generic %float
+ %ulong_4 = OpConstant %ulong 4
+ %1 = OpFunction %void None %31
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %26 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_float Function
+ %7 = OpVariable %_ptr_Function_float Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %10 = OpLoad %ulong %2
+ OpStore %4 %10
+ %11 = OpLoad %ulong %3
+ OpStore %5 %11
+ %13 = OpLoad %ulong %4
+ %23 = OpConvertUToPtr %_ptr_Generic_float %13
+ %12 = OpLoad %float %23
+ OpStore %6 %12
+ %15 = OpLoad %ulong %4
+ %22 = OpIAdd %ulong %15 %ulong_4
+ %24 = OpConvertUToPtr %_ptr_Generic_float %22
+ %14 = OpLoad %float %24
+ OpStore %7 %14
+ %17 = OpLoad %float %6
+ %18 = OpLoad %float %7
+ %16 = OpFMul %float %17 %18
+ OpStore %6 %16
+ %19 = OpLoad %ulong %5
+ %20 = OpLoad %float %6
+ %25 = OpConvertUToPtr %_ptr_Generic_float %19
+ OpStore %25 %20
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 365d1e8..c0e15f2 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,7 +1,7 @@
use crate::ast;
use half::f16;
use rspirv::{binary::Disassemble, dr};
-use std::{borrow::Cow, convert::TryFrom, hash::Hash, iter, mem};
+use std::{borrow::Cow, convert::TryFrom, ffi::CString, hash::Hash, iter, mem};
use std::{
collections::{hash_map, HashMap, HashSet},
convert::TryInto,
@@ -448,6 +448,7 @@ pub struct Module {
pub spirv: dr::Module,
pub kernel_info: HashMap<String, KernelInfo>,
pub should_link_ptx_impl: Option<&'static [u8]>,
+ pub build_options: CString,
}
pub struct KernelInfo {
@@ -484,6 +485,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
let mut map = TypeWordMap::new(&mut builder);
emit_builtins(&mut builder, &mut map, &id_defs);
let mut kernel_info = HashMap::new();
+ let build_options = emit_denorm_build_string(&call_map, &denorm_information);
emit_directives(
&mut builder,
&mut map,
@@ -503,15 +505,51 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
} else {
None
},
+ build_options,
})
}
+// TODO: remove this once we have perf-function support for denorms
+fn emit_denorm_build_string(
+ call_map: &HashMap<&str, HashSet<u32>>,
+ denorm_information: &HashMap<MethodName, HashMap<u8, (spirv::FPDenormMode, isize)>>,
+) -> CString {
+ let denorm_counts = denorm_information
+ .iter()
+ .map(|(method, meth_denorm)| {
+ let f16_count = meth_denorm
+ .get(&(mem::size_of::<f16>() as u8))
+ .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
+ .1;
+ let f32_count = meth_denorm
+ .get(&(mem::size_of::<f32>() as u8))
+ .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
+ .1;
+ (method, (f16_count + f32_count))
+ })
+ .collect::<HashMap<_, _>>();
+ let mut flush_over_preserve = 0;
+ for (kernel, children) in call_map {
+ flush_over_preserve += *denorm_counts.get(&MethodName::Kernel(kernel)).unwrap_or(&0);
+ for child_fn in children {
+ flush_over_preserve += *denorm_counts
+ .get(&MethodName::Func(*child_fn))
+ .unwrap_or(&0);
+ }
+ }
+ if flush_over_preserve > 0 {
+ CString::new("-cl-denorms-are-zero").unwrap()
+ } else {
+ CString::default()
+ }
+}
+
fn emit_directives<'input>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
id_defs: &GlobalStringIdResolver<'input>,
opencl_id: spirv::Word,
- denorm_information: &HashMap<MethodName<'input>, HashMap<u8, spirv::FPDenormMode>>,
+ denorm_information: &HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
directives: Vec<Directive>,
kernel_info: &mut HashMap<String, KernelInfo>,
@@ -579,6 +617,9 @@ fn get_call_map<'input>(
..
}) => {
let call_key = MethodName::new(&func_decl);
+ if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
+ entry.insert(Vec::new());
+ }
for statement in statements {
match statement {
Statement::Call(call) => {
@@ -895,7 +936,7 @@ fn denorm_count_map_update_impl<T: Eq + Hash>(
// and emit suitable execution mode
fn compute_denorm_information<'input>(
module: &[Directive<'input>],
-) -> HashMap<MethodName<'input>, HashMap<u8, spirv::FPDenormMode>> {
+) -> HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
let mut denorm_methods = HashMap::new();
for directive in module {
match directive {
@@ -937,13 +978,13 @@ fn compute_denorm_information<'input>(
.map(|(name, v)| {
let width_to_denorm = v
.into_iter()
- .map(|(k, ftz_over_preserve)| {
- let mode = if ftz_over_preserve > 0 {
+ .map(|(k, flush_over_preserve)| {
+ let mode = if flush_over_preserve > 0 {
spirv::FPDenormMode::FlushToZero
} else {
spirv::FPDenormMode::Preserve
};
- (k, mode)
+ (k, (mode, flush_over_preserve))
})
.collect();
(name, width_to_denorm)
@@ -999,7 +1040,7 @@ fn emit_function_header<'a>(
defined_globals: &GlobalStringIdResolver<'a>,
synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
func_decl: &SpirvMethodDecl<'a>,
- _denorm_information: &HashMap<MethodName<'a>, HashMap<u8, spirv::FPDenormMode>>,
+ _denorm_information: &HashMap<MethodName<'a>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
direcitves: &[Directive],
kernel_info: &mut HashMap<String, KernelInfo>,