diff options
Diffstat (limited to 'ptx/src/pass/deparamize_functions.rs')
-rw-r--r-- | ptx/src/pass/deparamize_functions.rs | 185 |
1 files changed, 121 insertions, 64 deletions
diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs index 04c8831..15125b0 100644 --- a/ptx/src/pass/deparamize_functions.rs +++ b/ptx/src/pass/deparamize_functions.rs @@ -1,5 +1,3 @@ -use std::collections::BTreeMap;
-
use super::*;
pub(super) fn run<'a, 'input>(
@@ -26,75 +24,73 @@ fn run_method<'input>( resolver: &mut GlobalStringIdentResolver2,
mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
- if method.func_decl.name.is_kernel() {
- return Ok(method);
- }
let is_declaration = method.body.is_none();
let mut body = Vec::new();
let mut remap_returns = Vec::new();
- for arg in method.func_decl.return_arguments.iter_mut() {
- match arg.state_space {
- ptx_parser::StateSpace::Param => {
- arg.state_space = ptx_parser::StateSpace::Reg;
- let old_name = arg.name;
- arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
- if is_declaration {
- continue;
+ if !method.func_decl.name.is_kernel() {
+ for arg in method.func_decl.return_arguments.iter_mut() {
+ match arg.state_space {
+ ptx_parser::StateSpace::Param => {
+ arg.state_space = ptx_parser::StateSpace::Reg;
+ let old_name = arg.name;
+ arg.name =
+ resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
+ if is_declaration {
+ continue;
+ }
+ remap_returns.push((old_name, arg.name, arg.v_type.clone()));
+ body.push(Statement::Variable(ast::Variable {
+ align: None,
+ name: old_name,
+ v_type: arg.v_type.clone(),
+ state_space: ptx_parser::StateSpace::Param,
+ array_init: Vec::new(),
+ }));
}
- remap_returns.push((old_name, arg.name, arg.v_type.clone()));
- body.push(Statement::Variable(ast::Variable {
- align: None,
- name: old_name,
- v_type: arg.v_type.clone(),
- state_space: ptx_parser::StateSpace::Param,
- array_init: Vec::new(),
- }));
+ ptx_parser::StateSpace::Reg => {}
+ _ => return Err(error_unreachable()),
}
- ptx_parser::StateSpace::Reg => {}
- _ => return Err(error_unreachable()),
}
- }
- for arg in method.func_decl.input_arguments.iter_mut() {
- match arg.state_space {
- ptx_parser::StateSpace::Param => {
- arg.state_space = ptx_parser::StateSpace::Reg;
- let old_name = arg.name;
- arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
- if is_declaration {
- continue;
+ for arg in method.func_decl.input_arguments.iter_mut() {
+ match arg.state_space {
+ ptx_parser::StateSpace::Param => {
+ arg.state_space = ptx_parser::StateSpace::Reg;
+ let old_name = arg.name;
+ arg.name =
+ resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
+ if is_declaration {
+ continue;
+ }
+ body.push(Statement::Variable(ast::Variable {
+ align: None,
+ name: old_name,
+ v_type: arg.v_type.clone(),
+ state_space: ptx_parser::StateSpace::Param,
+ array_init: Vec::new(),
+ }));
+ body.push(Statement::Instruction(ast::Instruction::St {
+ data: ast::StData {
+ qualifier: ast::LdStQualifier::Weak,
+ state_space: ast::StateSpace::Param,
+ caching: ast::StCacheOperator::Writethrough,
+ typ: arg.v_type.clone(),
+ },
+ arguments: ast::StArgs {
+ src1: old_name,
+ src2: arg.name,
+ },
+ }));
}
- body.push(Statement::Variable(ast::Variable {
- align: None,
- name: old_name,
- v_type: arg.v_type.clone(),
- state_space: ptx_parser::StateSpace::Param,
- array_init: Vec::new(),
- }));
- body.push(Statement::Instruction(ast::Instruction::St {
- data: ast::StData {
- qualifier: ast::LdStQualifier::Weak,
- state_space: ast::StateSpace::Param,
- caching: ast::StCacheOperator::Writethrough,
- typ: arg.v_type.clone(),
- },
- arguments: ast::StArgs {
- src1: old_name,
- src2: arg.name,
- },
- }));
+ ptx_parser::StateSpace::Reg => {}
+ _ => return Err(error_unreachable()),
}
- ptx_parser::StateSpace::Reg => {}
- _ => return Err(error_unreachable()),
}
}
- if remap_returns.is_empty() {
- return Ok(method);
- }
let body = method
.body
.map(|statements| {
for statement in statements {
- run_statement(&remap_returns, &mut body, statement)?;
+ run_statement(resolver, &remap_returns, &mut body, statement)?;
}
Ok::<_, TranslateError>(body)
})
@@ -110,28 +106,89 @@ fn run_method<'input>( }
fn run_statement<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
match statement {
- Statement::Instruction(ast::Instruction::Ret { .. }) => {
- for (old_name, new_name, type_) in remap_returns.iter().cloned() {
+ Statement::Instruction(ast::Instruction::Call {
+ mut data,
+ mut arguments,
+ }) => {
+ let mut post_st = Vec::new();
+ for ((type_, space), ident) in data
+ .input_arguments
+ .iter_mut()
+ .zip(arguments.input_arguments.iter_mut())
+ {
+ if *space == ptx_parser::StateSpace::Param {
+ *space = ptx_parser::StateSpace::Reg;
+ let old_name = *ident;
+ *ident = resolver
+ .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
+ result.push(Statement::Instruction(ast::Instruction::Ld {
+ data: ast::LdDetails {
+ qualifier: ast::LdStQualifier::Weak,
+ state_space: ast::StateSpace::Param,
+ caching: ast::LdCacheOperator::Cached,
+ typ: type_.clone(),
+ non_coherent: false,
+ },
+ arguments: ast::LdArgs {
+ dst: *ident,
+ src: old_name,
+ },
+ }));
+ }
+ }
+ for ((type_, space), ident) in data
+ .return_arguments
+ .iter_mut()
+ .zip(arguments.return_arguments.iter_mut())
+ {
+ if *space == ptx_parser::StateSpace::Param {
+ *space = ptx_parser::StateSpace::Reg;
+ let old_name = *ident;
+ *ident = resolver
+ .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
+ post_st.push(Statement::Instruction(ast::Instruction::St {
+ data: ast::StData {
+ qualifier: ast::LdStQualifier::Weak,
+ state_space: ast::StateSpace::Param,
+ caching: ast::StCacheOperator::Writethrough,
+ typ: type_.clone(),
+ },
+ arguments: ast::StArgs {
+ src1: old_name,
+ src2: *ident,
+ },
+ }));
+ }
+ }
+ result.push(Statement::Instruction(ast::Instruction::Call {
+ data,
+ arguments,
+ }));
+ result.extend(post_st.into_iter());
+ }
+ Statement::Instruction(ast::Instruction::Ret { data }) => {
+ for (old_name, new_name, type_) in remap_returns.iter() {
result.push(Statement::Instruction(ast::Instruction::Ld {
data: ast::LdDetails {
qualifier: ast::LdStQualifier::Weak,
- state_space: ast::StateSpace::Reg,
+ state_space: ast::StateSpace::Param,
caching: ast::LdCacheOperator::Cached,
- typ: type_,
+ typ: type_.clone(),
non_coherent: false,
},
arguments: ast::LdArgs {
- dst: new_name,
- src: old_name,
+ dst: *new_name,
+ src: *old_name,
},
}));
}
- result.push(statement);
+ result.push(Statement::Instruction(ast::Instruction::Ret { data }));
}
statement => {
result.push(statement);
|