aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-06-11 00:00:56 +0200
committerAndrzej Janik <[email protected]>2021-06-11 00:00:56 +0200
commitf0771e1fb6bb95e3f22b8bfa3a9efd3bfe88c946 (patch)
treee9311d5bdc0a910a0f50a20e2d06262fcbf72bc2
parent994cfb338655048ac274f913582aed214102b3d9 (diff)
downloadZLUDA-f0771e1fb6bb95e3f22b8bfa3a9efd3bfe88c946.tar.gz
ZLUDA-f0771e1fb6bb95e3f22b8bfa3a9efd3bfe88c946.zip
Slightly improve stateful optimization
-rw-r--r--ptx/src/translate.rs156
-rw-r--r--zluda_dump/src/lib.rs21
2 files changed, 95 insertions, 82 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 4c1c0e7..511d763 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1239,8 +1239,8 @@ fn to_ssa<'input, 'b>(
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
- let typed_statements =
- convert_to_stateful_memory_access(&func_decl, typed_statements, &mut numeric_id_defs)?;
+ let (func_decl, typed_statements) =
+ convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
typed_statements,
&mut numeric_id_defs,
@@ -4311,14 +4311,27 @@ fn expand_map_variables<'a, 'b>(
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
// argument expansion
-// TODO: propagate through calls?
+// TODO: propagate out of calls and into calls
fn convert_to_stateful_memory_access<'a, 'input>(
- func_args: &Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
+ func_args: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
func_body: Vec<TypedStatement>,
id_defs: &mut NumericIdResolver<'a>,
-) -> Result<Vec<TypedStatement>, TranslateError> {
- let mut func_args = func_args.borrow_mut();
- let func_args_64bit = (*func_args)
+) -> Result<
+ (
+ Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
+ Vec<TypedStatement>,
+ ),
+ TranslateError,
+> {
+ let mut method_decl = func_args.borrow_mut();
+ if !method_decl.name.is_kernel() {
+ drop(method_decl);
+ return Ok((func_args, func_body));
+ }
+ if Rc::strong_count(&func_args) != 1 {
+ return Err(error_unreachable());
+ }
+ let func_args_64bit = (*method_decl)
.input_arguments
.iter()
.filter_map(|arg| match arg.v_type {
@@ -4462,6 +4475,18 @@ fn convert_to_stateful_memory_access<'a, 'input>(
}));
remapped_ids.insert(reg, new_id);
}
+ for arg in (*method_decl).input_arguments.iter_mut() {
+ let new_id = id_defs.register_variable(
+ ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ ast::StateSpace::Reg,
+ );
+ let old_name = arg.name;
+ if func_args_ptr.contains(&arg.name) {
+ arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
+ arg.name = new_id;
+ }
+ remapped_ids.insert(old_name, new_id);
+ }
for statement in func_body {
match statement {
l @ Statement::Label(_) => result.push(l),
@@ -4550,7 +4575,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
- &func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
@@ -4567,7 +4591,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
- &func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
@@ -4584,7 +4607,6 @@ fn convert_to_stateful_memory_access<'a, 'input>(
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
- &func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
@@ -4597,81 +4619,69 @@ fn convert_to_stateful_memory_access<'a, 'input>(
_ => return Err(error_unreachable()),
}
}
- for arg in (*func_args).input_arguments.iter_mut() {
- if func_args_ptr.contains(&arg.name) {
- arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
- arg.state_space = ast::StateSpace::Reg;
- }
- }
- Ok(result)
+ drop(method_decl);
+ Ok((func_args, result))
}
fn convert_to_stateful_memory_access_postprocess(
id_defs: &mut NumericIdResolver,
remapped_ids: &HashMap<spirv::Word, spirv::Word>,
- func_args_ptr: &HashSet<spirv::Word>,
result: &mut Vec<TypedStatement>,
post_statements: &mut Vec<TypedStatement>,
arg_desc: ArgumentDescriptor<spirv::Word>,
expected_type: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
- Ok(
- match remapped_ids
- .get(&arg_desc.op)
- .or_else(|| func_args_ptr.get(&arg_desc.op))
- {
- Some(new_id) => {
- let (new_operand_type, new_operand_space, is_variable) =
- id_defs.get_typed(*new_id)?;
- if let Some((expected_type, expected_space)) = expected_type {
- let implicit_conversion = arg_desc
- .non_default_implicit_conversion
- .unwrap_or(default_implicit_conversion);
- if implicit_conversion(
- (new_operand_space, &new_operand_type),
- (expected_space, expected_type),
- )
- .is_ok()
- {
- return Ok(*new_id);
- }
- }
- let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?;
- let new_operand_type_clone = new_operand_type.clone();
- let converting_id = id_defs
- .register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
- let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) {
- ConversionKind::Default
- } else {
- ConversionKind::PtrToPtr
- };
- if arg_desc.is_dst {
- post_statements.push(Statement::Conversion(ImplicitConversion {
- src: converting_id,
- dst: *new_id,
- from_type: old_operand_type,
- from_space: old_operand_space,
- to_type: new_operand_type,
- to_space: new_operand_space,
- kind,
- }));
- converting_id
- } else {
- result.push(Statement::Conversion(ImplicitConversion {
- src: *new_id,
- dst: converting_id,
- from_type: new_operand_type,
- from_space: new_operand_space,
- to_type: old_operand_type,
- to_space: old_operand_space,
- kind,
- }));
- converting_id
+ Ok(match remapped_ids.get(&arg_desc.op) {
+ Some(new_id) => {
+ let (new_operand_type, new_operand_space, is_variable) = id_defs.get_typed(*new_id)?;
+ if let Some((expected_type, expected_space)) = expected_type {
+ let implicit_conversion = arg_desc
+ .non_default_implicit_conversion
+ .unwrap_or(default_implicit_conversion);
+ if implicit_conversion(
+ (new_operand_space, &new_operand_type),
+ (expected_space, expected_type),
+ )
+ .is_ok()
+ {
+ return Ok(*new_id);
}
}
- None => arg_desc.op,
- },
- )
+ let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?;
+ let new_operand_type_clone = new_operand_type.clone();
+ let converting_id =
+ id_defs.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
+ let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) {
+ ConversionKind::Default
+ } else {
+ ConversionKind::PtrToPtr
+ };
+ if arg_desc.is_dst {
+ post_statements.push(Statement::Conversion(ImplicitConversion {
+ src: converting_id,
+ dst: *new_id,
+ from_type: old_operand_type,
+ from_space: old_operand_space,
+ to_type: new_operand_type,
+ to_space: new_operand_space,
+ kind,
+ }));
+ converting_id
+ } else {
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: *new_id,
+ dst: converting_id,
+ from_type: new_operand_type,
+ from_space: new_operand_space,
+ to_type: old_operand_type,
+ to_space: old_operand_space,
+ kind,
+ }));
+ converting_id
+ }
+ }
+ None => arg_desc.op,
+ })
}
fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {
diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs
index f168930..ffd1498 100644
--- a/zluda_dump/src/lib.rs
+++ b/zluda_dump/src/lib.rs
@@ -219,15 +219,18 @@ unsafe fn to_str<T>(image: *const T) -> Option<&'static str> {
fn directive_to_kernel(dir: &ast::Directive<ast::ParsedArgParams>) -> Option<(String, Vec<usize>)> {
match dir {
- ast::Directive::Method(ast::Function {
- func_directive:
- ast::MethodDeclaration {
- name: ast::MethodName::Kernel(name),
- input_arguments,
- ..
- },
- ..
- }) => {
+ ast::Directive::Method(
+ _,
+ ast::Function {
+ func_directive:
+ ast::MethodDeclaration {
+ name: ast::MethodName::Kernel(name),
+ input_arguments,
+ ..
+ },
+ ..
+ },
+ ) => {
let arg_sizes = input_arguments
.iter()
.map(|arg| ast::Type::from(arg.v_type.clone()).size_of())