aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/translate.rs
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-09-24 01:31:50 +0200
committerAndrzej Janik <[email protected]>2021-09-24 01:31:50 +0200
commit370c0bd09ef5b49e327368fb1899c1692bb8eff4 (patch)
tree6935830138f9105947a8e95c266e42d05d31661e /ptx/src/translate.rs
parent9609f86033e9bff1e080ab9f7e856ed8ce3bd93d (diff)
downloadZLUDA-370c0bd09ef5b49e327368fb1899c1692bb8eff4.tar.gz
ZLUDA-370c0bd09ef5b49e327368fb1899c1692bb8eff4.zip
Start implementing .shared unification
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r--ptx/src/translate.rs323
1 files changed, 204 insertions, 119 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 2af7534..e96cdc2 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -443,7 +443,8 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
let call_map = get_kernels_call_map(&directives);
- //let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
+ let mut directives =
+ convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id());
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
@@ -607,7 +608,7 @@ fn emit_directives<'input>(
}
}
emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?;
- emit_function_linkage(builder, id_defs, f, fn_id);
+ emit_function_linkage(builder, id_defs, f, fn_id)?;
builder.select_block(None)?;
builder.end_function()?;
}
@@ -683,7 +684,7 @@ fn get_kernels_call_map<'input>(
}
fn add_call_map_single<'input>(
- directly_called_by: &MultiHashMap<ast::MethodName<'input, spirv::Word>, spirv::Word>,
+ directly_called_by: &HashMap<ast::MethodName<'input, spirv::Word>, Vec<spirv::Word>>,
visited: &mut HashSet<spirv::Word>,
current: spirv::Word,
) {
@@ -697,15 +698,21 @@ fn add_call_map_single<'input>(
}
}
-type MultiHashMap<K, V> = HashMap<K, Vec<V>>;
-
-fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>, key: K, value: V) {
+/// Appends `value` to the collection stored under `key` in `m`.
+/// If `key` is not present yet, a default (empty) collection is created
+/// first and `value` becomes its first element.
+/// Works for any collection that is `Default + Extend<V>` (e.g. `Vec`, `HashSet`).
+fn multi_hash_map_append<
+ K: Eq + std::hash::Hash,
+ V,
+ Collection: std::iter::Extend<V> + std::default::Default,
+>(
+ m: &mut HashMap<K, Collection>,
+ key: K,
+ value: V,
+) {
match m.entry(key) {
hash_map::Entry::Occupied(mut entry) => {
- entry.get_mut().push(value);
+ entry.get_mut().extend(iter::once(value));
}
hash_map::Entry::Vacant(entry) => {
- entry.insert(vec![value]);
+ // `insert` returns a mutable reference to the new collection; the value
+ // must be added to it, otherwise the first append for a key is lost
+ entry.insert(Default::default()).extend(iter::once(value));
}
}
}
@@ -713,7 +720,8 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
/*
PTX represents dynamically allocated shared local memory as
.extern .shared .b32 shared_mem[];
- In SPIRV/OpenCL world this is expressed as an additional argument
+ In SPIRV/OpenCL world this is expressed as an additional argument to the kernel
+ And in AMD compilation (TODO: this sentence is unfinished; the AMD case is not described yet)
This pass looks for all uses of .extern .shared and converts them to
an additional method argument
The question is how this artificial argument should be expressed. There are
@@ -735,30 +743,35 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
*/
fn convert_dynamic_shared_memory_usage<'input>(
module: Vec<Directive<'input>>,
+ kernels_methods_call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
new_id: &mut impl FnMut() -> spirv::Word,
) -> Vec<Directive<'input>> {
- let mut extern_shared_decls = HashMap::new();
+ let mut globals_shared = HashMap::new();
for dir in module.iter() {
match dir {
Directive::Variable(
linking,
ast::Variable {
- v_type: ast::Type::Array(p_type, dims),
state_space: ast::StateSpace::Shared,
name,
+ v_type,
..
},
- ) if linking.contains(ast::LinkingDirective::EXTERN) && dims.len() == 0 => {
- extern_shared_decls.insert(*name, *p_type);
+ ) => {
+ let size = if linking.contains(ast::LinkingDirective::EXTERN) {
+ GlobalSharedSize::ExternUnsized
+ } else {
+ GlobalSharedSize::Sized((*v_type).size_of())
+ };
+ globals_shared.insert(*name, (size, v_type.clone()));
}
_ => {}
}
}
- if extern_shared_decls.len() == 0 {
+ if globals_shared.len() == 0 {
return module;
}
- let mut methods_using_extern_shared = HashSet::new();
- let mut directly_called_by = MultiHashMap::new();
+ let mut methods_to_globals_shared_direct_only_use = HashMap::<_, GlobalSharedSize>::new();
let module = module
.into_iter()
.map(|directive| match directive {
@@ -773,17 +786,21 @@ fn convert_dynamic_shared_memory_usage<'input>(
let call_key = (*func_decl).borrow().name;
let statements = statements
.into_iter()
- .map(|statement| match statement {
- Statement::Call(call) => {
- multi_hash_map_append(&mut directly_called_by, call.name, call_key);
- Statement::Call(call)
- }
- statement => statement.map_id(&mut |id, _| {
- if extern_shared_decls.contains_key(&id) {
- methods_using_extern_shared.insert(call_key);
+ .map(|statement| {
+ statement.map_id(&mut |id, _| {
+ if let Some((size, _)) = globals_shared.get(&id) {
+ match methods_to_globals_shared_direct_only_use.entry(call_key) {
+ hash_map::Entry::Occupied(mut e) => {
+ let original_size = *e.get();
+ e.insert(original_size.fold(*size));
+ }
+ hash_map::Entry::Vacant(mut e) => {
+ e.insert(*size);
+ }
+ }
}
id
- }),
+ })
})
.collect();
Directive::Method(Function {
@@ -800,11 +817,15 @@ fn convert_dynamic_shared_memory_usage<'input>(
.collect::<Vec<_>>();
// If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
// make sure it gets propagated to `fn1` and `kernel`
- get_callers_of_extern_shared(&mut methods_using_extern_shared, &directly_called_by);
+ let (kernels_to_global_shared, functions_to_global_shared) =
+ resolve_indirect_uses_of_globals_shared(
+ methods_to_globals_shared_direct_only_use,
+ kernels_methods_call_map,
+ );
// now visit every method declaration and inject those additional arguments
- module
- .into_iter()
- .map(|directive| match directive {
+ let mut result = Vec::with_capacity(module.len());
+ for directive in module.into_iter() {
+ match directive {
Directive::Method(Function {
func_decl,
globals,
@@ -813,46 +834,119 @@ fn convert_dynamic_shared_memory_usage<'input>(
tuning,
linkage,
}) => {
- if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) {
- return Directive::Method(Function {
- func_decl,
- globals,
- body: Some(statements),
- import_as,
- tuning,
- linkage,
- });
- }
- let shared_id_param = new_id();
- {
- let mut func_decl = (*func_decl).borrow_mut();
- func_decl.shared_mem = Some(shared_id_param);
- }
- let statements = replace_uses_of_shared_memory(
- new_id,
- &extern_shared_decls,
- &mut methods_using_extern_shared,
- shared_id_param,
- statements,
- );
- Directive::Method(Function {
+ let statements = {
+ let func_decl_ref = &mut (*func_decl).borrow_mut();
+ let method_name = func_decl_ref.name;
+ insert_arguments_remap_statements(
+ method_name,
+ &kernels_to_global_shared,
+ new_id,
+ &mut result,
+ &functions_to_global_shared,
+ func_decl_ref,
+ &globals_shared,
+ statements,
+ )
+ };
+ result.push(Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
tuning,
linkage,
- })
+ }));
}
- directive => directive,
- })
- .collect::<Vec<_>>()
+ directive => result.push(directive),
+ }
+ }
+ result
+}
+
+// Gives a single method access to the unified shared memory and rewrites its body.
+// * Kernel: if the kernel has an entry in `kernels_to_global_shared`, a new `.shared`
+//   byte-array variable is pushed onto `result` (unsized + EXTERN linkage for
+//   `ExternUnsized`, sized + no linkage for `Sized`)
+// * Function: if its id is in `functions_to_global_shared`, an extra input argument
+//   (pointer to `.b8` in shared space) is appended to the declaration
+// In both cases the statements are then passed through `replace_uses_of_shared_memory`
+// together with the freshly created id. Methods that do not qualify get their
+// statements back unchanged.
+fn insert_arguments_remap_statements(
+ method_name: ast::MethodName<u32>,
+ kernels_to_global_shared: &HashMap<&str, GlobalSharedSize>,
+ new_id: &mut impl FnMut() -> u32,
+ result: &mut Vec<Directive>,
+ functions_to_global_shared: &HashSet<u32>,
+ func_decl_ref: &mut std::cell::RefMut<ast::MethodDeclaration<u32>>,
+ globals_shared: &HashMap<u32, (GlobalSharedSize, ast::Type)>,
+ statements: Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>>,
+) -> Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>> {
+ let shared_id_param = match method_name {
+ ast::MethodName::Kernel(kernel_name) => {
+ // Kernels without an entry are returned untouched
+ let globals_shared_size = match kernels_to_global_shared.get(kernel_name) {
+ Some(s) => *s,
+ None => return statements,
+ };
+ let shared_id_param = new_id();
+ let (linkage, type_) = match globals_shared_size {
+ GlobalSharedSize::ExternUnsized => (
+ ast::LinkingDirective::EXTERN,
+ ast::Type::Array(ast::ScalarType::U8, Vec::new()),
+ ),
+ // NOTE(review): `size as u32` truncates silently if `size` exceeds u32::MAX
+ GlobalSharedSize::Sized(size) => (
+ ast::LinkingDirective::NONE,
+ ast::Type::Array(ast::ScalarType::U8, vec![size as u32]),
+ ),
+ };
+ // The backing variable is emitted as its own module-level directive
+ result.push(Directive::Variable(
+ linkage,
+ ast::Variable {
+ align: None,
+ v_type: type_,
+ state_space: ast::StateSpace::Shared,
+ name: shared_id_param,
+ array_init: Vec::new(),
+ },
+ ));
+ shared_id_param
+ }
+ ast::MethodName::Func(function_name) => {
+ // Functions outside the set are returned untouched
+ if !functions_to_global_shared.contains(&function_name) {
+ return statements;
+ }
+ let shared_id_param = new_id();
+ func_decl_ref.input_arguments.push(ast::Variable {
+ align: None,
+ v_type: ast::Type::Pointer(ast::ScalarType::B8, ast::StateSpace::Shared),
+ state_space: ast::StateSpace::Reg,
+ name: shared_id_param,
+ array_init: Vec::new(),
+ });
+ shared_id_param
+ }
+ };
+ replace_uses_of_shared_memory(
+ new_id,
+ globals_shared,
+ functions_to_global_shared,
+ shared_id_param,
+ statements,
+ )
+}
+
+// Size of the `.shared` globals used by a method: either unknown at compile
+// time (`ExternUnsized`) or a known size (`Sized`, filled from `Type::size_of()`
+// by the pass that collects the globals).
+#[derive(Clone, Copy, PartialEq, Eq, Hash)]
+enum GlobalSharedSize {
+ ExternUnsized,
+ Sized(usize),
+}
+
+impl GlobalSharedSize {
+ // Combines two sizes: two known sizes give the larger of the two,
+ // any combination involving `ExternUnsized` gives `ExternUnsized`
+ fn fold(self, other: GlobalSharedSize) -> GlobalSharedSize {
+ match (self, other) {
+ (GlobalSharedSize::Sized(s1), GlobalSharedSize::Sized(s2)) => {
+ GlobalSharedSize::Sized(usize::max(s1, s2))
+ }
+ _ => GlobalSharedSize::ExternUnsized,
+ }
+ }
}
fn replace_uses_of_shared_memory<'a>(
new_id: &mut impl FnMut() -> spirv::Word,
- extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
- methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
+ extern_shared_decls: &HashMap<spirv::Word, (GlobalSharedSize, ast::Type)>,
+ methods_using_extern_shared: &HashSet<spirv::Word>,
shared_id_param: spirv::Word,
statements: Vec<ExpandedStatement>,
) -> Vec<ExpandedStatement> {
@@ -863,7 +957,7 @@ fn replace_uses_of_shared_memory<'a>(
// We can safely skip checking call arguments,
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
- if methods_using_extern_shared.contains(&ast::MethodName::Func(call.name)) {
+ if methods_using_extern_shared.contains(&call.name) {
call.input_arguments.push((
shared_id_param,
ast::Type::Scalar(ast::ScalarType::B8),
@@ -874,8 +968,8 @@ fn replace_uses_of_shared_memory<'a>(
}
statement => {
let new_statement = statement.map_id(&mut |id, _| {
- if let Some(scalar_type) = extern_shared_decls.get(&id) {
- if *scalar_type == ast::ScalarType::B8 {
+ if let Some((_, type_)) = extern_shared_decls.get(&id) {
+ if *type_ == ast::Type::Scalar(ast::ScalarType::B8) {
return shared_id_param;
}
let replacement_id = new_id();
@@ -884,7 +978,7 @@ fn replace_uses_of_shared_memory<'a>(
dst: replacement_id,
from_type: ast::Type::Scalar(ast::ScalarType::B8),
from_space: ast::StateSpace::Shared,
- to_type: ast::Type::Scalar(*scalar_type),
+ to_type: type_.clone(),
to_space: ast::StateSpace::Shared,
kind: ConversionKind::PtrToPtr,
}));
@@ -900,43 +994,40 @@ fn replace_uses_of_shared_memory<'a>(
result
}
-fn get_callers_of_extern_shared<'a>(
- methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
- directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
-) {
- let direct_uses_of_extern_shared = methods_using_extern_shared
- .iter()
- .filter_map(|method| {
- if let ast::MethodName::Func(f_id) = method {
- Some(*f_id)
- } else {
- None
- }
- })
- .collect::<Vec<_>>();
- for fn_id in direct_uses_of_extern_shared {
- get_callers_of_extern_shared_single(methods_using_extern_shared, directly_called_by, fn_id);
- }
-}
-
-fn get_callers_of_extern_shared_single<'a>(
- methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
- directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
- fn_id: spirv::Word,
-) {
- if let Some(callers) = directly_called_by.get(&fn_id) {
- for caller in callers {
- if methods_using_extern_shared.insert(*caller) {
- if let ast::MethodName::Func(caller_fn) = caller {
- get_callers_of_extern_shared_single(
- methods_using_extern_shared,
- directly_called_by,
- *caller_fn,
- );
+// We need to compute two kinds of information:
+// * If it's a kernel -> size of .shared globals in use (direct or indirect)
+// * If it's a function -> does it use .shared global (directly or indirectly)
+// Returns `(kernel name -> folded size, set of function ids)`.
+// NOTE(review): only methods that are keys of `methods_use_of_globals_shared` are
+// visited. A kernel that is not a key of that map gets no entry in the result, even
+// if a function listed for it in `kernels_methods_call_map` is a key. Likewise the
+// returned function set is exactly the `Func` keys of the map. Confirm whether
+// purely indirect users are meant to be covered here.
+fn resolve_indirect_uses_of_globals_shared<'input>(
+ mut methods_use_of_globals_shared: HashMap<
+ ast::MethodName<'input, spirv::Word>,
+ GlobalSharedSize,
+ >,
+ kernels_methods_call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
+) -> (HashMap<&'input str, GlobalSharedSize>, HashSet<spirv::Word>) {
+ let mut kernel_use = HashMap::new();
+ let mut functions_using_global = HashSet::new();
+ // Stand-in for kernels that have no entry in the call map
+ let empty = HashSet::new();
+ for (method, globals) in methods_use_of_globals_shared.iter() {
+ match method {
+ ast::MethodName::Kernel(kernel_name) => {
+ // Start from the kernel's own use and fold in every called function's use
+ let mut size = *globals;
+ for &called_subfunction in
+ kernels_methods_call_map.get(kernel_name).unwrap_or(&empty)
+ {
+ if let Some(new_size) = methods_use_of_globals_shared
+ .get(&ast::MethodName::Func(called_subfunction))
+ {
+ size = size.fold(*new_size);
+ }
}
+ kernel_use.insert(*kernel_name, size);
+ }
+ ast::MethodName::Func(fn_id) => {
+ functions_using_global.insert(*fn_id);
}
}
}
+ (kernel_use, functions_using_global)
}
type DenormCountMap<T> = HashMap<T, isize>;
@@ -3480,7 +3571,10 @@ fn emit_variable<'input>(
[dr::Operand::LiteralInt32(align)].iter().cloned(),
);
}
- emit_linking_decoration(builder, id_defs, None, var.name, linking);
+ if var.state_space != ast::StateSpace::Shared || !linking.contains(ast::LinkingDirective::EXTERN)
+ {
+ emit_linking_decoration(builder, id_defs, None, var.name, linking);
+ }
Ok(())
}
@@ -3494,9 +3588,9 @@ fn emit_linking_decoration<'input>(
if linking == ast::LinkingDirective::NONE {
return;
}
- let string_name =
- name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
if linking.contains(ast::LinkingDirective::VISIBLE) {
+ let string_name =
+ name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
builder.decorate(
name,
spirv::Decoration::LinkageAttributes,
@@ -3508,6 +3602,8 @@ fn emit_linking_decoration<'input>(
.cloned(),
);
} else if linking.contains(ast::LinkingDirective::EXTERN) {
+ let string_name =
+ name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
builder.decorate(
name,
spirv::Decoration::LinkageAttributes,
@@ -4454,7 +4550,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
})
.collect::<HashSet<_>>();
let mut stateful_markers = Vec::new();
- let mut stateful_init_reg = MultiHashMap::new();
+ let mut stateful_init_reg = HashMap::<_, Vec<_>>::new();
for statement in func_body.iter() {
match statement {
Statement::Instruction(ast::Instruction::Cvta(
@@ -7863,26 +7959,15 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> {
impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
+ /// Yields `(id, SPIR-V type)` for every declared input argument.
+ /// For non-kernel methods, an argument whose state space is not `Reg` is
+ /// typed as a pointer to its declared type in that state space; every
+ /// other argument keeps its declared type.
fn effective_input_arguments(&self) -> impl Iterator<Item = (spirv::Word, SpirvType)> + '_ {
let is_kernel = self.name.is_kernel();
- self.input_arguments
- .iter()
- .map(move |arg| {
- if !is_kernel && arg.state_space != ast::StateSpace::Reg {
- let spirv_type =
- SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
- (arg.name, spirv_type)
- } else {
- (arg.name, SpirvType::new(arg.v_type.clone()))
- }
- })
- .chain(self.shared_mem.iter().map(|id| {
- (
- *id,
- SpirvType::Pointer(
- Box::new(SpirvType::Base(SpirvScalarKey::B8)),
- spirv::StorageClass::Workgroup,
- ),
- )
- }))
+ self.input_arguments.iter().map(move |arg| {
+ if !is_kernel && arg.state_space != ast::StateSpace::Reg {
+ let spirv_type =
+ SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
+ (arg.name, spirv_type)
+ } else {
+ (arg.name, SpirvType::new(arg.v_type.clone()))
+ }
+ })
}
}