diff options
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r-- | ptx/src/translate.rs | 289 |
1 files changed, 261 insertions, 28 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index cccf6ad..604b4ef 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,7 +1,7 @@ use crate::ast;
use half::f16;
use rspirv::{binary::Disassemble, dr};
-use std::{borrow::Cow, iter, mem};
+use std::{borrow::Cow, hash::Hash, iter, mem};
use std::{
collections::{hash_map, HashMap, HashSet},
convert::TryFrom,
@@ -438,6 +438,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro let mut directives =
convert_dynamic_shared_memory_usage(&mut id_defs, directives, &mut || builder.id());
normalize_variable_decls(&mut directives);
+ let denorm_information = compute_denorm_information(&directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
builder.set_version(1, 3);
emit_capabilities(&mut builder);
@@ -463,6 +464,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro &mut map,
&id_defs,
f.func_decl,
+ &denorm_information,
&mut kernel_info,
)?;
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
@@ -523,10 +525,7 @@ fn convert_dynamic_shared_memory_usage<'input>( globals,
body: Some(statements),
}) => {
- let call_key = match func_decl {
- ast::MethodDecl::Kernel { name, .. } => CallgraphKey::Kernel(name),
- ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id),
- };
+ let call_key = CallgraphKey::new(&func_decl);
let statements = statements
.into_iter()
.map(|statement| match statement {
@@ -563,10 +562,7 @@ fn convert_dynamic_shared_memory_usage<'input>( globals,
body: Some(statements),
}) => {
- let call_key = match func_decl {
- ast::MethodDecl::Kernel { name, .. } => CallgraphKey::Kernel(name),
- ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id),
- };
+ let call_key = CallgraphKey::new(&func_decl);
if !methods_using_extern_shared.contains(&call_key) {
return Directive::Method(Function {
func_decl,
@@ -726,12 +722,171 @@ fn get_callers_of_extern_shared_single<'a>( }
}
+type DenormCountMap<T> = HashMap<T, isize>;
+
+/// Records a single ftz observation for `key`: the counter goes up by one
+/// when `value` is `true` and down by one when it is `false`.
+fn denorm_count_map_update<T: Eq + Hash>(map: &mut DenormCountMap<T>, key: T, value: bool) {
+    let delta = match value {
+        true => 1,
+        false => -1,
+    };
+    denorm_count_map_update_impl(map, key, delta);
+}
+
+/// Adds `num_value` to the running counter stored under `key`.
+///
+/// A key that is not in the map yet starts from zero, so after the call its
+/// counter is exactly `num_value`.
+fn denorm_count_map_update_impl<T: Eq + Hash>(
+    map: &mut DenormCountMap<T>,
+    key: T,
+    num_value: isize,
+) {
+    // The entry API does a single lookup and covers both the occupied and the
+    // vacant case without a hand-written match.
+    *map.entry(key).or_insert(0) += num_value;
+}
+
+/// Folds every counter of `src` into `dst`; counters of keys present in both
+/// maps are summed, keys only present in `src` are copied over.
+fn denorm_count_map_merge<T: Eq + Hash + Copy>(
+    dst: &mut DenormCountMap<T>,
+    src: &DenormCountMap<T>,
+) {
+    src.iter()
+        .for_each(|(&key, &delta)| denorm_count_map_update_impl(dst, key, delta));
+}
+
+// HACK ALERT!
+// This function is a "good enough" heuristic for deciding whether to mark
+// f16/f32 operations in the kernel as flushing denorms to zero or preserving them.
+// PTX supports per-instruction ftz information. Unfortunately SPIR-V has no
+// such capability, so instead we guesstimate which use is more common in the
+// kernel and emit a suitable execution mode.
+//
+// The result maps kernel name -> (operand size as reported by
+// `Instruction::flush_to_zero` -> execution mode). Votes are collected per
+// method and then summed over everything a kernel reaches through direct
+// calls. A size gets `DenormFlushToZero` only when ftz uses strictly outnumber
+// non-ftz uses; a tie resolves to `DenormPreserve`.
+fn compute_denorm_information<'input>(
+    module: &[Directive<'input>],
+) -> HashMap<&'input str, HashMap<u8, spirv::ExecutionMode>> {
+    let mut direct_func_calls = MultiHashMap::new();
+    let mut denorm_methods = HashMap::new();
+    for directive in module {
+        match directive {
+            // Nothing to count for variables and body-less declarations
+            Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {}
+            Directive::Method(Function {
+                func_decl,
+                body: Some(statements),
+                ..
+            }) => {
+                let mut flush_counter = DenormCountMap::new();
+                let method_key = CallgraphKey::new(func_decl);
+                for statement in statements {
+                    // Deliberately exhaustive: a new `Statement` variant must
+                    // be classified here explicitly.
+                    match statement {
+                        Statement::Instruction(inst) => {
+                            if let Some((flush, width)) = inst.flush_to_zero() {
+                                denorm_count_map_update(&mut flush_counter, width, flush);
+                            }
+                        }
+                        Statement::LoadVar(_, _) => {}
+                        Statement::StoreVar(_, _) => {}
+                        Statement::Call(ResolvedCall { func, .. }) => {
+                            // Remember the call edge so callee votes can be
+                            // attributed to the calling kernel later on
+                            multi_hash_map_append(&mut direct_func_calls, method_key, *func);
+                        }
+                        Statement::Composite(_) => {}
+                        Statement::Conditional(_) => {}
+                        Statement::Conversion(_) => {}
+                        Statement::Constant(_) => {}
+                        Statement::RetValue(_, _) => {}
+                        Statement::Undef(_, _) => {}
+                        Statement::Label(_) => {}
+                        Statement::Variable(_) => {}
+                    }
+                }
+                denorm_methods.insert(method_key, flush_counter);
+            }
+        }
+    }
+    let summed_denorm_methods = sum_up_denorm_use(module, denorm_methods, &direct_func_calls);
+    summed_denorm_methods
+        .into_iter()
+        // Every kernel is kept, so a plain `map` is enough here
+        .map(|(name, v)| {
+            let width_to_denorm = v
+                .into_iter()
+                .map(|(k, ftz_over_preserve)| {
+                    let mode = if ftz_over_preserve > 0 {
+                        spirv::ExecutionMode::DenormFlushToZero
+                    } else {
+                        spirv::ExecutionMode::DenormPreserve
+                    };
+                    (k, mode)
+                })
+                .collect();
+            (name, width_to_denorm)
+        })
+        .collect()
+}
+
+/// Builds the per-kernel ftz vote totals.
+///
+/// For every kernel key in `denorm_methods` the kernel's own counters are
+/// summed with the counters of every function reachable from it through
+/// `direct_func_calls`; each reachable function is counted at most once per
+/// kernel. Non-kernel functions get no entry of their own in the result.
+/// `module` is currently not consulted.
+fn sum_up_denorm_use<'input>(
+    module: &[Directive<'input>],
+    denorm_methods: HashMap<CallgraphKey<'input>, DenormCountMap<u8>>,
+    direct_func_calls: &MultiHashMap<CallgraphKey<'input>, spirv::Word>,
+) -> HashMap<&'input str, DenormCountMap<u8>> {
+    let mut per_kernel = HashMap::new();
+    for (method_key, own_counts) in denorm_methods.iter() {
+        let kernel_name = match method_key {
+            CallgraphKey::Kernel(name) => *name,
+            CallgraphKey::Func(_) => continue,
+        };
+        let mut total = own_counts.clone();
+        // Fresh per kernel: a function shared by two kernels counts for both
+        let mut seen = HashSet::new();
+        let callees = direct_func_calls.get(method_key).into_iter().flatten();
+        for callee in callees {
+            sum_up_denorm_use_single(
+                &denorm_methods,
+                direct_func_calls,
+                &mut total,
+                &mut seen,
+                *callee,
+            );
+        }
+        per_kernel.insert(kernel_name, total);
+    }
+    per_kernel
+}
+
+/// Recursive step of the call-graph walk: merges the counters of function
+/// `current` into `sum` and then descends into the functions it calls.
+/// `visited` guards against counting a function twice and against call cycles.
+fn sum_up_denorm_use_single<'input>(
+    denorm_methods: &HashMap<CallgraphKey<'input>, DenormCountMap<u8>>,
+    direct_func_calls: &MultiHashMap<CallgraphKey<'input>, spirv::Word>,
+    sum: &mut DenormCountMap<u8>,
+    visited: &mut HashSet<spirv::Word>,
+    current: spirv::Word,
+) {
+    let first_visit = visited.insert(current);
+    if !first_visit {
+        return;
+    }
+    let key = CallgraphKey::Func(current);
+    if let Some(own_counts) = denorm_methods.get(&key) {
+        denorm_count_map_merge(sum, own_counts);
+    }
+    let callees = direct_func_calls.get(&key).into_iter().flatten();
+    for &callee in callees {
+        sum_up_denorm_use_single(denorm_methods, direct_func_calls, sum, visited, callee);
+    }
+}
+
#[derive(Hash, PartialEq, Eq, Copy, Clone)]
enum CallgraphKey<'input> {
Kernel(&'input str),
Func(spirv::Word),
}
+impl<'input> CallgraphKey<'input> {
+    /// Derives the call-graph key of a method declaration: kernels are keyed
+    /// by their name, plain functions by their id.
+    fn new(decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
+        match decl {
+            ast::MethodDecl::Func(_, id, _) => Self::Func(*id),
+            ast::MethodDecl::Kernel { name, .. } => Self::Kernel(name),
+        }
+    }
+}
+
fn emit_builtins(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -764,6 +919,7 @@ fn emit_function_header<'a>( map: &mut TypeWordMap,
global: &GlobalStringIdResolver<'a>,
func_directive: ast::MethodDecl<spirv::Word>,
+ denorm_information: &HashMap<&'a str, HashMap<u8, spirv::ExecutionMode>>,
kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
if let ast::MethodDecl::Kernel {
@@ -797,6 +953,11 @@ fn emit_function_header<'a>( .collect::<Vec<_>>();
global_variables.append(&mut interface);
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
+ if let Some(exec_modes) = denorm_information.get(name) {
+ for (size_of, exec_mode) in exec_modes {
+ builder.execution_mode(fn_id, *exec_mode, [(*size_of as u32) * 8])
+ }
+ }
fn_id
}
ast::MethodDecl::Func(_, name, _) => name,
@@ -844,9 +1005,14 @@ fn emit_capabilities(builder: &mut dr::Builder) { builder.capability(spirv::Capability::Int64);
builder.capability(spirv::Capability::Float16);
builder.capability(spirv::Capability::Float64);
+ builder.capability(spirv::Capability::DenormFlushToZero);
+ builder.capability(spirv::Capability::DenormPreserve);
}
-fn emit_extensions(_: &mut dr::Builder) {}
+// Declares the SPIR-V extensions used by the emitted module.
+// SPV_KHR_float_controls specification:
+// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html
+fn emit_extensions(builder: &mut dr::Builder) {
+    let float_controls = "SPV_KHR_float_controls";
+    builder.extension(float_controls);
+}
fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word {
builder.ext_inst_import("OpenCL.std")
@@ -2088,7 +2254,7 @@ fn emit_function_body_ops( ast::MulDetails::Unsigned(ref ctr) => {
emit_mul_uint(builder, map, opencl, ctr, arg)?
}
- ast::MulDetails::Float(_) => todo!(),
+ ast::MulDetails::Float(ref ctr) => emit_mul_float(builder, map, ctr, arg)?,
},
ast::Instruction::Add(add, arg) => match add {
ast::ArithDetails::Signed(ref desc) => {
@@ -2215,15 +2381,27 @@ fn emit_function_body_ops( Ok(())
}
+/// Emits a floating-point multiplication `arg.dst = arg.src1 * arg.src2`
+/// through `builder.f_mul`, followed by the rounding decoration requested by
+/// `ctr.rounding`.
+///
+/// Saturating multiplication is not implemented yet and hits `todo!()`.
+fn emit_mul_float(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    ctr: &ast::ArithFloat,
+    arg: &ast::Arg3<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+    if ctr.saturate {
+        todo!()
+    }
+    let dst = arg.dst;
+    let scalar_type = map.get_or_add_scalar(builder, ctr.typ.into());
+    builder.f_mul(scalar_type, Some(dst), arg.src1, arg.src2)?;
+    emit_rounding_decoration(builder, dst, ctr.rounding);
+    Ok(())
+}
+
fn emit_rcp(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
desc: &ast::RcpDetails,
a: &ast::Arg2<ExpandedArgParams>,
) -> Result<(), TranslateError> {
- if desc.flush_to_zero {
- todo!()
- }
let (instr_type, constant) = if desc.is_f64 {
(ast::ScalarType::F64, vec_repr(1.0f64))
} else {
@@ -2360,9 +2538,6 @@ fn emit_add_float( desc: &ast::ArithFloat,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
- if desc.flush_to_zero {
- todo!()
- }
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
builder.f_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
emit_rounding_decoration(builder, arg.dst, desc.rounding);
@@ -2375,9 +2550,6 @@ fn emit_sub_float( desc: &ast::ArithFloat,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
- if desc.flush_to_zero {
- todo!()
- }
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
builder.f_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
emit_rounding_decoration(builder, arg.dst, desc.rounding);
@@ -2441,7 +2613,7 @@ fn emit_cvt( if desc.dst == desc.src {
return Ok(());
}
- if desc.saturate || desc.flush_to_zero {
+ if desc.saturate {
todo!()
}
let dest_t: ast::ScalarType = desc.dst.into();
@@ -2450,7 +2622,7 @@ fn emit_cvt( emit_rounding_decoration(builder, arg.dst, desc.rounding);
}
ast::CvtDetails::FloatFromInt(desc) => {
- if desc.saturate || desc.flush_to_zero {
+ if desc.saturate {
todo!()
}
let dest_t: ast::ScalarType = desc.dst.into();
@@ -2463,9 +2635,6 @@ fn emit_cvt( emit_rounding_decoration(builder, arg.dst, desc.rounding);
}
ast::CvtDetails::IntFromFloat(desc) => {
- if desc.flush_to_zero {
- todo!()
- }
let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.dst.is_signed() {
@@ -2561,9 +2730,6 @@ fn emit_setp( setp: &ast::SetpData,
arg: &ast::Arg4Setp<ExpandedArgParams>,
) -> Result<(), dr::Error> {
- if setp.flush_to_zero {
- todo!()
- }
let result_type = map.get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred));
let result_id = Some(arg.dst1);
let operand_1 = arg.src1;
@@ -4122,6 +4288,73 @@ impl ast::Instruction<ExpandedArgParams> { | ast::Instruction::Mad(_, _) => None,
}
}
+
+    // .wide instructions don't support ftz, so it's enough to just look at the
+    // type declared by the instruction.
+    //
+    // Returns `Some((ftz, size))` when the instruction carries ftz information,
+    // where `size` is what `size_of()` reports for the instruction's type (or
+    // the literal 4/8 used below), and `None` otherwise.
+    fn flush_to_zero(&self) -> Option<(bool, u8)> {
+        match self {
+            // Instructions and integer flavours without any ftz information
+            ast::Instruction::Ld(_, _)
+            | ast::Instruction::St(_, _)
+            | ast::Instruction::Mov(_, _)
+            | ast::Instruction::Not(_, _)
+            | ast::Instruction::Bra(_, _)
+            | ast::Instruction::Shl(_, _)
+            | ast::Instruction::Shr(_, _)
+            | ast::Instruction::Ret(_)
+            | ast::Instruction::Call(_)
+            | ast::Instruction::Or(_, _)
+            | ast::Instruction::Cvta(_, _)
+            | ast::Instruction::Sub(ast::ArithDetails::Signed(_), _)
+            | ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _)
+            | ast::Instruction::Add(ast::ArithDetails::Signed(_), _)
+            | ast::Instruction::Add(ast::ArithDetails::Unsigned(_), _)
+            | ast::Instruction::Mul(ast::MulDetails::Unsigned(_), _)
+            | ast::Instruction::Mul(ast::MulDetails::Signed(_), _)
+            | ast::Instruction::Mad(ast::MulDetails::Unsigned(_), _)
+            | ast::Instruction::Mad(ast::MulDetails::Signed(_), _)
+            | ast::Instruction::Min(ast::MinMaxDetails::Signed(_), _)
+            | ast::Instruction::Min(ast::MinMaxDetails::Unsigned(_), _)
+            | ast::Instruction::Max(ast::MinMaxDetails::Signed(_), _)
+            | ast::Instruction::Max(ast::MinMaxDetails::Unsigned(_), _)
+            | ast::Instruction::Cvt(ast::CvtDetails::IntFromInt(_), _) => None,
+            ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
+            | ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
+            | ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
+            | ast::Instruction::Mad(ast::MulDetails::Float(float_control), _) => {
+                let ftz = float_control.flush_to_zero?;
+                Some((ftz, ast::ScalarType::from(float_control.typ).size_of()))
+            }
+            ast::Instruction::Setp(details, _) => {
+                let ftz = details.flush_to_zero?;
+                Some((ftz, details.typ.size_of()))
+            }
+            ast::Instruction::SetpBool(details, _) => {
+                let ftz = details.flush_to_zero?;
+                Some((ftz, details.typ.size_of()))
+            }
+            ast::Instruction::Abs(details, _) => {
+                let ftz = details.flush_to_zero?;
+                Some((ftz, details.typ.size_of()))
+            }
+            ast::Instruction::Min(ast::MinMaxDetails::Float(float_control), _)
+            | ast::Instruction::Max(ast::MinMaxDetails::Float(float_control), _) => {
+                let ftz = float_control.flush_to_zero?;
+                Some((ftz, ast::ScalarType::from(float_control.typ).size_of()))
+            }
+            ast::Instruction::Rcp(details, _) => {
+                let ftz = details.flush_to_zero?;
+                let width = if details.is_f64 { 8 } else { 4 };
+                Some((ftz, width))
+            }
+            // Modifier .ftz can only be specified when either .dtype or .atype
+            // is .f32 and applies only to single precision (.f32) inputs and results.
+            ast::Instruction::Cvt(
+                ast::CvtDetails::FloatFromFloat(ast::CvtDesc { flush_to_zero, .. }),
+                _,
+            )
+            | ast::Instruction::Cvt(
+                ast::CvtDetails::FloatFromInt(ast::CvtDesc { flush_to_zero, .. }),
+                _,
+            )
+            | ast::Instruction::Cvt(
+                ast::CvtDetails::IntFromFloat(ast::CvtDesc { flush_to_zero, .. }),
+                _,
+            ) => flush_to_zero.map(|ftz| (ftz, 4)),
+        }
+    }
}
impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
|