From c92abba2bb884a4dba8ca5e3df4d46a30878f27e Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 23 Sep 2024 16:33:46 +0200 Subject: Refactor compilation passes (#270) The overarching goal is to refactor all passes so they are module-scoped and not function-scoped. Additionally, make improvements to the most egregiously buggy/unfit passes (so the code is ready for the next major features: linking, ftz handling) and continue adding more code to the LLVM backend --- ptx/src/pass/normalize_predicates2.rs | 84 +++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 ptx/src/pass/normalize_predicates2.rs (limited to 'ptx/src/pass/normalize_predicates2.rs') diff --git a/ptx/src/pass/normalize_predicates2.rs b/ptx/src/pass/normalize_predicates2.rs new file mode 100644 index 0000000..d91e23c --- /dev/null +++ b/ptx/src/pass/normalize_predicates2.rs @@ -0,0 +1,84 @@ +use super::*; +use ptx_parser as ast; + +pub(crate) fn run<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec>, +) -> Result>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directive: NormalizedDirective2<'input>, +) -> Result, TranslateError> { + Ok(match directive { + Directive2::Variable(linking, var) => Directive2::Variable(linking, var), + Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), + }) +} + +fn run_method<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + method: NormalizedFunction2<'input>, +) -> Result, TranslateError> { + let body = method + .body + .map(|statements| { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + run_statement(resolver, &mut result, statement)?; + } + Ok::<_, TranslateError>(result) + }) + .transpose()?; + Ok(Function2 { + func_decl: method.func_decl, + globals: method.globals, + body, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + }) +} + +fn run_statement<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + result: &mut Vec, + statement: NormalizedStatement, +) -> Result<(), TranslateError> { + Ok(match statement { + Statement::Label(label) => result.push(Statement::Label(label)), + Statement::Variable(var) => result.push(Statement::Variable(var)), + Statement::Instruction((predicate, instruction)) => { + if let Some(pred) = predicate { + let if_true = resolver.register_unnamed(None); + let if_false = resolver.register_unnamed(None); + let folded_bra = match &instruction { + ast::Instruction::Bra { arguments, .. } => Some(arguments.src), + _ => None, + }; + let mut branch = BrachCondition { + predicate: pred.label, + if_true: folded_bra.unwrap_or(if_true), + if_false, + }; + if pred.not { + std::mem::swap(&mut branch.if_true, &mut branch.if_false); + } + result.push(Statement::Conditional(branch)); + if folded_bra.is_none() { + result.push(Statement::Label(if_true)); + result.push(Statement::Instruction(instruction)); + } + result.push(Statement::Label(if_false)); + } else { + result.push(Statement::Instruction(instruction)); + } + } + _ => return Err(error_unreachable()), + }) +} -- cgit v1.2.3