aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/pass/normalize_predicates2.rs
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2024-09-23 16:33:46 +0200
committerGitHub <[email protected]>2024-09-23 16:33:46 +0200
commitc92abba2bb884a4dba8ca5e3df4d46a30878f27e (patch)
tree89bab98e3071aedd12f755bfde8a7c7382138ed7 /ptx/src/pass/normalize_predicates2.rs
parent46def3e7e09dbf4d3e7287a72bfecb73e6e429c5 (diff)
downloadZLUDA-c92abba2bb884a4dba8ca5e3df4d46a30878f27e.tar.gz
ZLUDA-c92abba2bb884a4dba8ca5e3df4d46a30878f27e.zip
Refactor compilation passes (#270)
The overarching goal is to refactor all passes so they are module-scoped and not function-scoped. Additionally, make improvements to the most egregiously buggy/unfit passes (so the code is ready for the next major features: linking, ftz handling) and continue adding more code to the LLVM backend
Diffstat (limited to 'ptx/src/pass/normalize_predicates2.rs')
-rw-r--r--ptx/src/pass/normalize_predicates2.rs84
1 files changed, 84 insertions, 0 deletions
diff --git a/ptx/src/pass/normalize_predicates2.rs b/ptx/src/pass/normalize_predicates2.rs
new file mode 100644
index 0000000..d91e23c
--- /dev/null
+++ b/ptx/src/pass/normalize_predicates2.rs
@@ -0,0 +1,84 @@
+use super::*;
+use ptx_parser as ast;
+
+pub(crate) fn run<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ directives: Vec<NormalizedDirective2<'input>>,
+) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
+ directives
+ .into_iter()
+ .map(|directive| run_directive(resolver, directive))
+ .collect::<Result<Vec<_>, _>>()
+}
+
+fn run_directive<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ directive: NormalizedDirective2<'input>,
+) -> Result<UnconditionalDirective<'input>, TranslateError> {
+ Ok(match directive {
+ Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
+ Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
+ })
+}
+
+fn run_method<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ method: NormalizedFunction2<'input>,
+) -> Result<UnconditionalFunction<'input>, TranslateError> {
+ let body = method
+ .body
+ .map(|statements| {
+ let mut result = Vec::with_capacity(statements.len());
+ for statement in statements {
+ run_statement(resolver, &mut result, statement)?;
+ }
+ Ok::<_, TranslateError>(result)
+ })
+ .transpose()?;
+ Ok(Function2 {
+ func_decl: method.func_decl,
+ globals: method.globals,
+ body,
+ import_as: method.import_as,
+ tuning: method.tuning,
+ linkage: method.linkage,
+ })
+}
+
+fn run_statement<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ result: &mut Vec<UnconditionalStatement>,
+ statement: NormalizedStatement,
+) -> Result<(), TranslateError> {
+ Ok(match statement {
+ Statement::Label(label) => result.push(Statement::Label(label)),
+ Statement::Variable(var) => result.push(Statement::Variable(var)),
+ Statement::Instruction((predicate, instruction)) => {
+ if let Some(pred) = predicate {
+ let if_true = resolver.register_unnamed(None);
+ let if_false = resolver.register_unnamed(None);
+ let folded_bra = match &instruction {
+ ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
+ _ => None,
+ };
+ let mut branch = BrachCondition {
+ predicate: pred.label,
+ if_true: folded_bra.unwrap_or(if_true),
+ if_false,
+ };
+ if pred.not {
+ std::mem::swap(&mut branch.if_true, &mut branch.if_false);
+ }
+ result.push(Statement::Conditional(branch));
+ if folded_bra.is_none() {
+ result.push(Statement::Label(if_true));
+ result.push(Statement::Instruction(instruction));
+ }
+ result.push(Statement::Label(if_false));
+ } else {
+ result.push(Statement::Instruction(instruction));
+ }
+ }
+ _ => return Err(error_unreachable()),
+ })
+}