aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/pass/insert_explicit_load_store.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/pass/insert_explicit_load_store.rs')
-rw-r--r--ptx/src/pass/insert_explicit_load_store.rs101
1 files changed, 83 insertions, 18 deletions
diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs
index e8f01cd..ec6498c 100644
--- a/ptx/src/pass/insert_explicit_load_store.rs
+++ b/ptx/src/pass/insert_explicit_load_store.rs
@@ -41,10 +41,9 @@ fn run_method<'a, 'input>(
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let mut func_decl = method.func_decl;
for arg in func_decl.return_arguments.iter_mut() {
- visitor.visit_variable(arg);
+ visitor.visit_variable(arg)?;
}
let is_kernel = func_decl.name.is_kernel();
- // let mut prelude = Vec::with_capacity(method.body.as_ref().map(Vec::len).unwrap_or(0));
if is_kernel {
for arg in func_decl.input_arguments.iter_mut() {
let old_name = arg.name;
@@ -85,23 +84,29 @@ fn run_statement<'a, 'input>(
) -> Result<(), TranslateError> {
match statement {
Statement::Variable(mut var) => {
- visitor.visit_variable(&mut var);
+ visitor.visit_variable(&mut var)?;
result.push(Statement::Variable(var));
}
Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
let instruction = visitor.visit_ld(data, arguments)?;
let instruction = ast::visit_map(instruction, visitor)?;
+ result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(Statement::Instruction(instruction));
+ result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
- Statement::Instruction(ast::Instruction::St {
- data,
- mut arguments,
- }) => {
+ Statement::Instruction(ast::Instruction::St { data, arguments }) => {
let instruction = visitor.visit_st(data, arguments)?;
let instruction = ast::visit_map(instruction, visitor)?;
+ result.extend(visitor.pre.drain(..).map(Statement::Instruction));
result.push(Statement::Instruction(instruction));
+ result.extend(visitor.post.drain(..).map(Statement::Instruction));
+ }
+ s => {
+ let new_statement = s.visit_map(visitor)?;
+ result.extend(visitor.pre.drain(..).map(Statement::Instruction));
+ result.push(new_statement);
+ result.extend(visitor.post.drain(..).map(Statement::Instruction));
}
- s => result.push(s.visit_map(visitor)?),
}
Ok(())
}
@@ -109,6 +114,8 @@ fn run_statement<'a, 'input>(
struct InsertMemSSAVisitor<'a, 'input> {
resolver: &'a mut GlobalStringIdentResolver2<'input>,
variables: FxHashMap<SpirvWord, RemapAction>,
+ pre: Vec<ast::Instruction<SpirvWord>>,
+ post: Vec<ast::Instruction<SpirvWord>>,
}
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
@@ -116,6 +123,8 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
Self {
resolver,
variables: FxHashMap::default(),
+ pre: Vec::new(),
+ post: Vec::new(),
}
}
@@ -141,14 +150,20 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
fn variable(
&mut self,
+ type_: &ast::Type,
old_name: SpirvWord,
new_name: SpirvWord,
old_space: ast::StateSpace,
) -> Result<(), TranslateError> {
Ok(match old_space {
ast::StateSpace::Reg => {
- self.variables
- .insert(old_name, RemapAction::PreLdPostSt(new_name));
+ self.variables.insert(
+ old_name,
+ RemapAction::PreLdPostSt {
+ name: new_name,
+ type_: type_.clone(),
+ },
+ );
}
ast::StateSpace::Param => {
self.variables.insert(
@@ -182,7 +197,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
if let Some(remap) = self.variables.get(&arguments.src1) {
match remap {
- RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()),
+ RemapAction::PreLdPostSt { .. } => {}
RemapAction::LDStSpaceChange {
old_space,
new_space,
@@ -206,7 +221,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
if let Some(remap) = self.variables.get(&arguments.src) {
match remap {
- RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()),
+ RemapAction::PreLdPostSt { .. } => {}
RemapAction::LDStSpaceChange {
old_space,
new_space,
@@ -223,7 +238,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
Ok(ast::Instruction::Ld { data, arguments })
}
- fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) {
+ fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
if var.state_space != ast::StateSpace::Local {
let old_name = var.name;
let old_space = var.state_space;
@@ -231,10 +246,11 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
let new_name = self
.resolver
.register_unnamed(Some((var.v_type.clone(), new_space)));
- self.variable(old_name, new_name, old_space);
+ self.variable(&var.v_type, old_name, new_name, old_space)?;
var.name = new_name;
var.state_space = new_space;
}
+ Ok(())
}
}
@@ -243,12 +259,58 @@ impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
{
fn visit(
&mut self,
- args: SpirvWord,
+ ident: SpirvWord,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
- todo!()
+ if let Some(remap) = self.variables.get(&ident) {
+ match remap {
+ RemapAction::PreLdPostSt { name, type_ } => {
+ if is_dst {
+ let temp = self
+ .resolver
+ .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
+ self.post.push(ast::Instruction::St {
+ data: ast::StData {
+ state_space: ast::StateSpace::Local,
+ qualifier: ast::LdStQualifier::Weak,
+ caching: ast::StCacheOperator::Writethrough,
+ typ: type_.clone(),
+ },
+ arguments: ast::StArgs {
+ src1: *name,
+ src2: temp,
+ },
+ });
+ Ok(temp)
+ } else {
+ let temp = self
+ .resolver
+ .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg)));
+ self.pre.push(ast::Instruction::Ld {
+ data: ast::LdDetails {
+ state_space: ast::StateSpace::Local,
+ qualifier: ast::LdStQualifier::Weak,
+ caching: ast::LdCacheOperator::Cached,
+ typ: type_.clone(),
+ non_coherent: false,
+ },
+ arguments: ast::LdArgs {
+ dst: temp,
+ src: *name,
+ },
+ });
+ Ok(temp)
+ }
+ }
+ RemapAction::LDStSpaceChange { .. } => {
+ return Err(error_mismatched_type());
+ }
+ }
+ } else {
+ Ok(ident)
+ }
}
fn visit_ident(
@@ -262,9 +324,12 @@ impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
}
}
-#[derive(Clone, Copy)]
+#[derive(Clone)]
enum RemapAction {
- PreLdPostSt(SpirvWord),
+ PreLdPostSt {
+ name: SpirvWord,
+ type_: ast::Type,
+ },
LDStSpaceChange {
old_space: ast::StateSpace,
new_space: ast::StateSpace,