aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/pass/insert_explicit_load_store.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/pass/insert_explicit_load_store.rs')
-rw-r--r--ptx/src/pass/insert_explicit_load_store.rs42
1 files changed, 42 insertions, 0 deletions
diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs
index 60c4a14..702f733 100644
--- a/ptx/src/pass/insert_explicit_load_store.rs
+++ b/ptx/src/pass/insert_explicit_load_store.rs
@@ -122,6 +122,13 @@ fn run_statement<'a, 'input>(
result.push(Statement::Instruction(instruction));
result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
+ Statement::PtrAccess(ptr_access) => {
+ let statement = Statement::PtrAccess(visitor.visit_ptr_access(ptr_access)?);
+ let statement = statement.visit_map(visitor)?;
+ result.extend(visitor.pre.drain(..).map(Statement::Instruction));
+ result.push(statement);
+ result.extend(visitor.post.drain(..).map(Statement::Instruction));
+ }
s => {
let new_statement = s.visit_map(visitor)?;
result.extend(visitor.pre.drain(..).map(Statement::Instruction));
@@ -259,6 +266,41 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
Ok(ast::Instruction::Ld { data, arguments })
}
+ fn visit_ptr_access(
+ &mut self,
+ ptr_access: PtrAccess<SpirvWord>,
+ ) -> Result<PtrAccess<SpirvWord>, TranslateError> {
+ let (old_space, new_space, name) = match self.variables.get(&ptr_access.ptr_src) {
+ Some(RemapAction::LDStSpaceChange {
+ old_space,
+ new_space,
+ name,
+ }) => (*old_space, *new_space, *name),
+ Some(RemapAction::PreLdPostSt { .. }) | None => return Ok(ptr_access),
+ };
+ if ptr_access.state_space != old_space {
+ return Err(error_mismatched_type());
+ }
+ // Propagate space changes in dst
+ let new_dst = self
+ .resolver
+ .register_unnamed(Some((ptr_access.underlying_type.clone(), new_space)));
+ self.variables.insert(
+ ptr_access.dst,
+ RemapAction::LDStSpaceChange {
+ old_space,
+ new_space,
+ name: new_dst,
+ },
+ );
+ Ok(PtrAccess {
+ ptr_src: name,
+ dst: new_dst,
+ state_space: new_space,
+ ..ptr_access
+ })
+ }
+
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
let old_space = match var.state_space {
space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,