diff options
Diffstat (limited to 'ptx/src/pass/insert_explicit_load_store.rs')
-rw-r--r-- | ptx/src/pass/insert_explicit_load_store.rs | 90 |
1 files changed, 61 insertions, 29 deletions
diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index ec6498c..60c4a14 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -1,7 +1,4 @@ use super::*;
-use ptx_parser::VisitorMap;
-use rustc_hash::FxHashSet;
-
// This pass:
// * Turns all .local, .param and .reg in-body variables into .local variables
// (if _not_ an input method argument)
@@ -40,9 +37,6 @@ fn run_method<'a, 'input>( method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let mut func_decl = method.func_decl;
- for arg in func_decl.return_arguments.iter_mut() {
- visitor.visit_variable(arg)?;
- }
let is_kernel = func_decl.name.is_kernel();
if is_kernel {
for arg in func_decl.input_arguments.iter_mut() {
@@ -52,17 +46,21 @@ fn run_method<'a, 'input>( let new_name = visitor
.resolver
.register_unnamed(Some((arg.v_type.clone(), new_space)));
- visitor.input_argument(old_name, new_name, old_space);
+ visitor.input_argument(old_name, new_name, old_space)?;
arg.name = new_name;
arg.state_space = new_space;
}
};
+ for arg in func_decl.return_arguments.iter_mut() {
+ visitor.visit_variable(arg)?;
+ }
+ let return_arguments = &func_decl.return_arguments[..];
let body = method
.body
.map(move |statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
- run_statement(&mut visitor, &mut result, statement)?;
+ run_statement(&mut visitor, return_arguments, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
@@ -79,10 +77,33 @@ fn run_method<'a, 'input>( fn run_statement<'a, 'input>(
visitor: &mut InsertMemSSAVisitor<'a, 'input>,
+ return_arguments: &[ast::Variable<SpirvWord>],
result: &mut Vec<ExpandedStatement>,
statement: ExpandedStatement,
) -> Result<(), TranslateError> {
match statement {
+ Statement::Instruction(ast::Instruction::Ret { data }) => {
+ let statement = if return_arguments.is_empty() {
+ Statement::Instruction(ast::Instruction::Ret { data })
+ } else {
+ Statement::RetValue(
+ data,
+ return_arguments
+ .iter()
+ .map(|arg| {
+ if arg.state_space != ast::StateSpace::Local {
+ return Err(error_unreachable());
+ }
+ Ok((arg.name, arg.v_type.clone()))
+ })
+ .collect::<Result<Vec<_>, _>>()?,
+ )
+ };
+ let new_statement = statement.visit_map(visitor)?;
+ result.extend(visitor.pre.drain(..).map(Statement::Instruction));
+ result.push(new_statement);
+ result.extend(visitor.post.drain(..).map(Statement::Instruction));
+ }
Statement::Variable(mut var) => {
visitor.visit_variable(&mut var)?;
result.push(Statement::Variable(var));
@@ -154,7 +175,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { old_name: SpirvWord,
new_name: SpirvWord,
old_space: ast::StateSpace,
- ) -> Result<(), TranslateError> {
+ ) -> Result<bool, TranslateError> {
Ok(match old_space {
ast::StateSpace::Reg => {
self.variables.insert(
@@ -164,6 +185,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { type_: type_.clone(),
},
);
+ true
}
ast::StateSpace::Param => {
self.variables.insert(
@@ -174,19 +196,18 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { name: new_name,
},
);
+ true
}
// Good as-is
- ast::StateSpace::Local => {}
- // Will be pulled into global scope later
- ast::StateSpace::Generic
+ ast::StateSpace::Local
+ | ast::StateSpace::Generic
| ast::StateSpace::SharedCluster
| ast::StateSpace::Global
| ast::StateSpace::Const
| ast::StateSpace::SharedCta
- | ast::StateSpace::Shared => {}
- ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => {
- return Err(error_unreachable())
- }
+ | ast::StateSpace::Shared
+ | ast::StateSpace::ParamEntry
+ | ast::StateSpace::ParamFunc => return Err(error_unreachable()),
})
}
@@ -239,17 +260,28 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { }
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
- if var.state_space != ast::StateSpace::Local {
- let old_name = var.name;
- let old_space = var.state_space;
- let new_space = ast::StateSpace::Local;
- let new_name = self
- .resolver
- .register_unnamed(Some((var.v_type.clone(), new_space)));
- self.variable(&var.v_type, old_name, new_name, old_space)?;
- var.name = new_name;
- var.state_space = new_space;
- }
+ let old_space = match var.state_space {
+ space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
+ // Do nothing
+ ptx_parser::StateSpace::Local => return Ok(()),
+ // Handled by another pass
+ ptx_parser::StateSpace::Generic
+ | ptx_parser::StateSpace::SharedCluster
+ | ptx_parser::StateSpace::ParamEntry
+ | ptx_parser::StateSpace::Global
+ | ptx_parser::StateSpace::SharedCta
+ | ptx_parser::StateSpace::Const
+ | ptx_parser::StateSpace::Shared
+ | ptx_parser::StateSpace::ParamFunc => return Ok(()),
+ };
+ let old_name = var.name;
+ let new_space = ast::StateSpace::Local;
+ let new_name = self
+ .resolver
+ .register_unnamed(Some((var.v_type.clone(), new_space)));
+ self.variable(&var.v_type, old_name, new_name, old_space)?;
+ var.name = new_name;
+ var.state_space = new_space;
Ok(())
}
}
@@ -260,9 +292,9 @@ impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError> fn visit(
&mut self,
ident: SpirvWord,
- type_space: Option<(&ast::Type, ast::StateSpace)>,
+ _type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
- relaxed_type_check: bool,
+ _relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
if let Some(remap) = self.variables.get(&ident) {
match remap {
|