aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-09-24 01:31:50 +0200
committerAndrzej Janik <[email protected]>2021-09-24 01:31:50 +0200
commit370c0bd09ef5b49e327368fb1899c1692bb8eff4 (patch)
tree6935830138f9105947a8e95c266e42d05d31661e
parent9609f86033e9bff1e080ab9f7e856ed8ce3bd93d (diff)
downloadZLUDA-370c0bd09ef5b49e327368fb1899c1692bb8eff4.tar.gz
ZLUDA-370c0bd09ef5b49e327368fb1899c1692bb8eff4.zip
Start implementing .shared unification
-rw-r--r--ptx/src/ptx.lalrpop3
-rw-r--r--ptx/src/test/spirv_run/mod.rs2
-rw-r--r--ptx/src/test/spirv_run/shared_unify_extern.ptx34
-rw-r--r--ptx/src/test/spirv_run/shared_unify_extern.spvtxt62
-rw-r--r--ptx/src/translate.rs323
5 files changed, 305 insertions, 119 deletions
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index fa3cfec..5c4811c 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -1970,6 +1970,9 @@ ArgCall: (Vec<&'input str>, &'input str, Vec<ast::Operand<&'input str>>) = {
"(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => {
(ret_params, func, param_list)
},
+ "(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> => {
+ (ret_params, func, Vec::new())
+ },
<func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => (Vec::new(), func, param_list),
<func:ExtendedID> => (Vec::new(), func, Vec::<ast::Operand<_>>::new()),
};
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index be34d0f..b7fd386 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -221,6 +221,8 @@ test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]);
test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]);
test_ptx!(activemask, [0u32], [1u32]);
test_ptx!(membar, [152731u32], [152731u32]);
+test_ptx!(shared_unify_extern, [7681u64], [15362u64]);
+
test_ptx!(func_ptr);
test_ptx!(lanemask_lt);
test_ptx!(extern_func);
diff --git a/ptx/src/test/spirv_run/shared_unify_extern.ptx b/ptx/src/test/spirv_run/shared_unify_extern.ptx
new file mode 100644
index 0000000..8b406b2
--- /dev/null
+++ b/ptx/src/test/spirv_run/shared_unify_extern.ptx
@@ -0,0 +1,34 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.extern .shared .b32 shared_ex[];
+.shared .b32 shared_mod[4];
+
+
+.func (.reg .b64 out) load_from_shared()
+{
+ ld.shared.u64 out, [shared_mod];
+ ret;
+}
+
+.visible .entry shared_unify_extern(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp1;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u64 temp1, [in_addr];
+ st.shared.u64 [shared_ex], temp1;
+ call (temp2), load_from_shared;
+ add.u64 temp2, temp2, temp1;
+ st.u64 [out_addr], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/shared_unify_extern.spvtxt b/ptx/src/test/spirv_run/shared_unify_extern.spvtxt
new file mode 100644
index 0000000..9b62045
--- /dev/null
+++ b/ptx/src/test/spirv_run/shared_unify_extern.spvtxt
@@ -0,0 +1,62 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %30 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %2 "shared_ptr_take_address" %1
+ OpExecutionMode %2 ContractionOff
+ OpDecorate %1 Alignment 4
+ OpDecorate %1 LinkageAttributes "shared_mem" Import
+ %void = OpTypeVoid
+ %uchar = OpTypeInt 8 0
+%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
+ %1 = OpVariable %_ptr_Workgroup_uchar Workgroup
+ %ulong = OpTypeInt 64 0
+ %35 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
+%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
+ %2 = OpFunction %void None %35
+ %10 = OpFunctionParameter %ulong
+ %11 = OpFunctionParameter %ulong
+ %28 = OpLabel
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ %8 = OpVariable %_ptr_Function_ulong Function
+ %9 = OpVariable %_ptr_Function_ulong Function
+ OpStore %3 %10
+ OpStore %4 %11
+ %12 = OpLoad %ulong %3 Aligned 8
+ OpStore %5 %12
+ %13 = OpLoad %ulong %4 Aligned 8
+ OpStore %6 %13
+ %23 = OpConvertPtrToU %ulong %1
+ %14 = OpCopyObject %ulong %23
+ OpStore %7 %14
+ %16 = OpLoad %ulong %5
+ %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16
+ %15 = OpLoad %ulong %24 Aligned 8
+ OpStore %8 %15
+ %17 = OpLoad %ulong %7
+ %18 = OpLoad %ulong %8
+ %25 = OpConvertUToPtr %_ptr_Workgroup_ulong %17
+ OpStore %25 %18 Aligned 8
+ %20 = OpLoad %ulong %7
+ %26 = OpConvertUToPtr %_ptr_Workgroup_ulong %20
+ %19 = OpLoad %ulong %26 Aligned 8
+ OpStore %9 %19
+ %21 = OpLoad %ulong %6
+ %22 = OpLoad %ulong %9
+ %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21
+ OpStore %27 %22 Aligned 8
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 2af7534..e96cdc2 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -443,7 +443,8 @@ pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, Trans
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
let call_map = get_kernels_call_map(&directives);
- //let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
+ let mut directives =
+ convert_dynamic_shared_memory_usage(directives, &call_map, &mut || builder.id());
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
@@ -607,7 +608,7 @@ fn emit_directives<'input>(
}
}
emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?;
- emit_function_linkage(builder, id_defs, f, fn_id);
+ emit_function_linkage(builder, id_defs, f, fn_id)?;
builder.select_block(None)?;
builder.end_function()?;
}
@@ -683,7 +684,7 @@ fn get_kernels_call_map<'input>(
}
fn add_call_map_single<'input>(
- directly_called_by: &MultiHashMap<ast::MethodName<'input, spirv::Word>, spirv::Word>,
+ directly_called_by: &HashMap<ast::MethodName<'input, spirv::Word>, Vec<spirv::Word>>,
visited: &mut HashSet<spirv::Word>,
current: spirv::Word,
) {
@@ -697,15 +698,21 @@ fn add_call_map_single<'input>(
}
}
-type MultiHashMap<K, V> = HashMap<K, Vec<V>>;
-
-fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>, key: K, value: V) {
+fn multi_hash_map_append<
+ K: Eq + std::hash::Hash,
+ V,
+ Collection: std::iter::Extend<V> + std::default::Default,
+>(
+ m: &mut HashMap<K, Collection>,
+ key: K,
+ value: V,
+) {
match m.entry(key) {
hash_map::Entry::Occupied(mut entry) => {
- entry.get_mut().push(value);
+ entry.get_mut().extend(iter::once(value));
}
hash_map::Entry::Vacant(entry) => {
- entry.insert(vec![value]);
+ entry.insert(Default::default());
}
}
}
@@ -713,7 +720,8 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
/*
PTX represents dynamically allocated shared local memory as
.extern .shared .b32 shared_mem[];
- In SPIRV/OpenCL world this is expressed as an additional argument
+ In SPIRV/OpenCL world this is expressed as an additional argument to the kernel
+ And in AMD compilation
This pass looks for all uses of .extern .shared and converts them to
an additional method argument
The question is how this artificial argument should be expressed. There are
@@ -735,30 +743,35 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
*/
fn convert_dynamic_shared_memory_usage<'input>(
module: Vec<Directive<'input>>,
+ kernels_methods_call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
new_id: &mut impl FnMut() -> spirv::Word,
) -> Vec<Directive<'input>> {
- let mut extern_shared_decls = HashMap::new();
+ let mut globals_shared = HashMap::new();
for dir in module.iter() {
match dir {
Directive::Variable(
linking,
ast::Variable {
- v_type: ast::Type::Array(p_type, dims),
state_space: ast::StateSpace::Shared,
name,
+ v_type,
..
},
- ) if linking.contains(ast::LinkingDirective::EXTERN) && dims.len() == 0 => {
- extern_shared_decls.insert(*name, *p_type);
+ ) => {
+ let size = if linking.contains(ast::LinkingDirective::EXTERN) {
+ GlobalSharedSize::ExternUnsized
+ } else {
+ GlobalSharedSize::Sized((*v_type).size_of())
+ };
+ globals_shared.insert(*name, (size, v_type.clone()));
}
_ => {}
}
}
- if extern_shared_decls.len() == 0 {
+ if globals_shared.len() == 0 {
return module;
}
- let mut methods_using_extern_shared = HashSet::new();
- let mut directly_called_by = MultiHashMap::new();
+ let mut methods_to_globals_shared_direct_only_use = HashMap::<_, GlobalSharedSize>::new();
let module = module
.into_iter()
.map(|directive| match directive {
@@ -773,17 +786,21 @@ fn convert_dynamic_shared_memory_usage<'input>(
let call_key = (*func_decl).borrow().name;
let statements = statements
.into_iter()
- .map(|statement| match statement {
- Statement::Call(call) => {
- multi_hash_map_append(&mut directly_called_by, call.name, call_key);
- Statement::Call(call)
- }
- statement => statement.map_id(&mut |id, _| {
- if extern_shared_decls.contains_key(&id) {
- methods_using_extern_shared.insert(call_key);
+ .map(|statement| {
+ statement.map_id(&mut |id, _| {
+ if let Some((size, _)) = globals_shared.get(&id) {
+ match methods_to_globals_shared_direct_only_use.entry(call_key) {
+ hash_map::Entry::Occupied(mut e) => {
+ let original_size = *e.get();
+ e.insert(original_size.fold(*size));
+ }
+ hash_map::Entry::Vacant(mut e) => {
+ e.insert(*size);
+ }
+ }
}
id
- }),
+ })
})
.collect();
Directive::Method(Function {
@@ -800,11 +817,15 @@ fn convert_dynamic_shared_memory_usage<'input>(
.collect::<Vec<_>>();
// If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
// make sure it gets propagated to `fn1` and `kernel`
- get_callers_of_extern_shared(&mut methods_using_extern_shared, &directly_called_by);
+ let (kernels_to_global_shared, functions_to_global_shared) =
+ resolve_indirect_uses_of_globals_shared(
+ methods_to_globals_shared_direct_only_use,
+ kernels_methods_call_map,
+ );
// now visit every method declaration and inject those additional arguments
- module
- .into_iter()
- .map(|directive| match directive {
+ let mut result = Vec::with_capacity(module.len());
+ for directive in module.into_iter() {
+ match directive {
Directive::Method(Function {
func_decl,
globals,
@@ -813,46 +834,119 @@ fn convert_dynamic_shared_memory_usage<'input>(
tuning,
linkage,
}) => {
- if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) {
- return Directive::Method(Function {
- func_decl,
- globals,
- body: Some(statements),
- import_as,
- tuning,
- linkage,
- });
- }
- let shared_id_param = new_id();
- {
- let mut func_decl = (*func_decl).borrow_mut();
- func_decl.shared_mem = Some(shared_id_param);
- }
- let statements = replace_uses_of_shared_memory(
- new_id,
- &extern_shared_decls,
- &mut methods_using_extern_shared,
- shared_id_param,
- statements,
- );
- Directive::Method(Function {
+ let statements = {
+ let func_decl_ref = &mut (*func_decl).borrow_mut();
+ let method_name = func_decl_ref.name;
+ insert_arguments_remap_statements(
+ method_name,
+ &kernels_to_global_shared,
+ new_id,
+ &mut result,
+ &functions_to_global_shared,
+ func_decl_ref,
+ &globals_shared,
+ statements,
+ )
+ };
+ result.push(Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
tuning,
linkage,
- })
+ }));
}
- directive => directive,
- })
- .collect::<Vec<_>>()
+ directive => result.push(directive),
+ }
+ }
+ result
+}
+
+fn insert_arguments_remap_statements(
+ method_name: ast::MethodName<u32>,
+ kernels_to_global_shared: &HashMap<&str, GlobalSharedSize>,
+ new_id: &mut impl FnMut() -> u32,
+ result: &mut Vec<Directive>,
+ functions_to_global_shared: &HashSet<u32>,
+ func_decl_ref: &mut std::cell::RefMut<ast::MethodDeclaration<u32>>,
+ globals_shared: &HashMap<u32, (GlobalSharedSize, ast::Type)>,
+ statements: Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>>,
+) -> Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>> {
+ let shared_id_param = match method_name {
+ ast::MethodName::Kernel(kernel_name) => {
+ let globals_shared_size = match kernels_to_global_shared.get(kernel_name) {
+ Some(s) => *s,
+ None => return statements,
+ };
+ let shared_id_param = new_id();
+ let (linkage, type_) = match globals_shared_size {
+ GlobalSharedSize::ExternUnsized => (
+ ast::LinkingDirective::EXTERN,
+ ast::Type::Array(ast::ScalarType::U8, Vec::new()),
+ ),
+ GlobalSharedSize::Sized(size) => (
+ ast::LinkingDirective::NONE,
+ ast::Type::Array(ast::ScalarType::U8, vec![size as u32]),
+ ),
+ };
+ result.push(Directive::Variable(
+ linkage,
+ ast::Variable {
+ align: None,
+ v_type: type_,
+ state_space: ast::StateSpace::Shared,
+ name: shared_id_param,
+ array_init: Vec::new(),
+ },
+ ));
+ shared_id_param
+ }
+ ast::MethodName::Func(function_name) => {
+ if !functions_to_global_shared.contains(&function_name) {
+ return statements;
+ }
+ let shared_id_param = new_id();
+ func_decl_ref.input_arguments.push(ast::Variable {
+ align: None,
+ v_type: ast::Type::Pointer(ast::ScalarType::B8, ast::StateSpace::Shared),
+ state_space: ast::StateSpace::Reg,
+ name: shared_id_param,
+ array_init: Vec::new(),
+ });
+ shared_id_param
+ }
+ };
+ replace_uses_of_shared_memory(
+ new_id,
+ globals_shared,
+ functions_to_global_shared,
+ shared_id_param,
+ statements,
+ )
+}
+
+#[derive(Clone, Copy, PartialEq, Eq, Hash)]
+enum GlobalSharedSize {
+ ExternUnsized,
+ Sized(usize),
+}
+
+impl GlobalSharedSize {
+ fn fold(self, other: GlobalSharedSize) -> GlobalSharedSize {
+ match (self, other) {
+ (GlobalSharedSize::Sized(s1), GlobalSharedSize::Sized(s2)) => {
+ GlobalSharedSize::Sized(usize::max(s1, s2))
+ }
+ _ => GlobalSharedSize::ExternUnsized,
+ }
+ }
}
fn replace_uses_of_shared_memory<'a>(
new_id: &mut impl FnMut() -> spirv::Word,
- extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
- methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
+ extern_shared_decls: &HashMap<spirv::Word, (GlobalSharedSize, ast::Type)>,
+ methods_using_extern_shared: &HashSet<spirv::Word>,
shared_id_param: spirv::Word,
statements: Vec<ExpandedStatement>,
) -> Vec<ExpandedStatement> {
@@ -863,7 +957,7 @@ fn replace_uses_of_shared_memory<'a>(
// We can safely skip checking call arguments,
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
- if methods_using_extern_shared.contains(&ast::MethodName::Func(call.name)) {
+ if methods_using_extern_shared.contains(&call.name) {
call.input_arguments.push((
shared_id_param,
ast::Type::Scalar(ast::ScalarType::B8),
@@ -874,8 +968,8 @@ fn replace_uses_of_shared_memory<'a>(
}
statement => {
let new_statement = statement.map_id(&mut |id, _| {
- if let Some(scalar_type) = extern_shared_decls.get(&id) {
- if *scalar_type == ast::ScalarType::B8 {
+ if let Some((_, type_)) = extern_shared_decls.get(&id) {
+ if *type_ == ast::Type::Scalar(ast::ScalarType::B8) {
return shared_id_param;
}
let replacement_id = new_id();
@@ -884,7 +978,7 @@ fn replace_uses_of_shared_memory<'a>(
dst: replacement_id,
from_type: ast::Type::Scalar(ast::ScalarType::B8),
from_space: ast::StateSpace::Shared,
- to_type: ast::Type::Scalar(*scalar_type),
+ to_type: type_.clone(),
to_space: ast::StateSpace::Shared,
kind: ConversionKind::PtrToPtr,
}));
@@ -900,43 +994,40 @@ fn replace_uses_of_shared_memory<'a>(
result
}
-fn get_callers_of_extern_shared<'a>(
- methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
- directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
-) {
- let direct_uses_of_extern_shared = methods_using_extern_shared
- .iter()
- .filter_map(|method| {
- if let ast::MethodName::Func(f_id) = method {
- Some(*f_id)
- } else {
- None
- }
- })
- .collect::<Vec<_>>();
- for fn_id in direct_uses_of_extern_shared {
- get_callers_of_extern_shared_single(methods_using_extern_shared, directly_called_by, fn_id);
- }
-}
-
-fn get_callers_of_extern_shared_single<'a>(
- methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
- directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
- fn_id: spirv::Word,
-) {
- if let Some(callers) = directly_called_by.get(&fn_id) {
- for caller in callers {
- if methods_using_extern_shared.insert(*caller) {
- if let ast::MethodName::Func(caller_fn) = caller {
- get_callers_of_extern_shared_single(
- methods_using_extern_shared,
- directly_called_by,
- *caller_fn,
- );
+// We need to compute two kinds of information:
+// * If it's a kernel -> size of .shared globals in use (direct or indirect)
+// * If it's a function -> does it use .shared global (directly or indirectly)
+fn resolve_indirect_uses_of_globals_shared<'input>(
+ mut methods_use_of_globals_shared: HashMap<
+ ast::MethodName<'input, spirv::Word>,
+ GlobalSharedSize,
+ >,
+ kernels_methods_call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
+) -> (HashMap<&'input str, GlobalSharedSize>, HashSet<spirv::Word>) {
+ let mut kernel_use = HashMap::new();
+ let mut functions_using_global = HashSet::new();
+ let empty = HashSet::new();
+ for (method, globals) in methods_use_of_globals_shared.iter() {
+ match method {
+ ast::MethodName::Kernel(kernel_name) => {
+ let mut size = *globals;
+ for &called_subfunction in
+ kernels_methods_call_map.get(kernel_name).unwrap_or(&empty)
+ {
+ if let Some(new_size) = methods_use_of_globals_shared
+ .get(&ast::MethodName::Func(called_subfunction))
+ {
+ size = size.fold(*new_size);
+ }
}
+ kernel_use.insert(*kernel_name, size);
+ }
+ ast::MethodName::Func(fn_id) => {
+ functions_using_global.insert(*fn_id);
}
}
}
+ (kernel_use, functions_using_global)
}
type DenormCountMap<T> = HashMap<T, isize>;
@@ -3480,7 +3571,10 @@ fn emit_variable<'input>(
[dr::Operand::LiteralInt32(align)].iter().cloned(),
);
}
- emit_linking_decoration(builder, id_defs, None, var.name, linking);
+ if var.state_space != ast::StateSpace::Shared || !linking.contains(ast::LinkingDirective::EXTERN)
+ {
+ emit_linking_decoration(builder, id_defs, None, var.name, linking);
+ }
Ok(())
}
@@ -3494,9 +3588,9 @@ fn emit_linking_decoration<'input>(
if linking == ast::LinkingDirective::NONE {
return;
}
- let string_name =
- name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
if linking.contains(ast::LinkingDirective::VISIBLE) {
+ let string_name =
+ name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
builder.decorate(
name,
spirv::Decoration::LinkageAttributes,
@@ -3508,6 +3602,8 @@ fn emit_linking_decoration<'input>(
.cloned(),
);
} else if linking.contains(ast::LinkingDirective::EXTERN) {
+ let string_name =
+ name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
builder.decorate(
name,
spirv::Decoration::LinkageAttributes,
@@ -4454,7 +4550,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
})
.collect::<HashSet<_>>();
let mut stateful_markers = Vec::new();
- let mut stateful_init_reg = MultiHashMap::new();
+ let mut stateful_init_reg = HashMap::<_, Vec<_>>::new();
for statement in func_body.iter() {
match statement {
Statement::Instruction(ast::Instruction::Cvta(
@@ -7863,26 +7959,15 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> {
impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
fn effective_input_arguments(&self) -> impl Iterator<Item = (spirv::Word, SpirvType)> + '_ {
let is_kernel = self.name.is_kernel();
- self.input_arguments
- .iter()
- .map(move |arg| {
- if !is_kernel && arg.state_space != ast::StateSpace::Reg {
- let spirv_type =
- SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
- (arg.name, spirv_type)
- } else {
- (arg.name, SpirvType::new(arg.v_type.clone()))
- }
- })
- .chain(self.shared_mem.iter().map(|id| {
- (
- *id,
- SpirvType::Pointer(
- Box::new(SpirvType::Base(SpirvScalarKey::B8)),
- spirv::StorageClass::Workgroup,
- ),
- )
- }))
+ self.input_arguments.iter().map(move |arg| {
+ if !is_kernel && arg.state_space != ast::StateSpace::Reg {
+ let spirv_type =
+ SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
+ (arg.name, spirv_type)
+ } else {
+ (arg.name, SpirvType::new(arg.v_type.clone()))
+ }
+ })
}
}