aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/pass/normalize_predicates.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/pass/normalize_predicates.rs')
-rw-r--r--ptx/src/pass/normalize_predicates.rs44
1 files changed, 44 insertions, 0 deletions
diff --git a/ptx/src/pass/normalize_predicates.rs b/ptx/src/pass/normalize_predicates.rs
new file mode 100644
index 0000000..c971cfa
--- /dev/null
+++ b/ptx/src/pass/normalize_predicates.rs
@@ -0,0 +1,44 @@
+use super::*;
+use ptx_parser as ast;
+
+pub(crate) fn run(
+ func: Vec<NormalizedStatement>,
+ id_def: &mut NumericIdResolver,
+) -> Result<Vec<UnconditionalStatement>, TranslateError> {
+ let mut result = Vec::with_capacity(func.len());
+ for s in func {
+ match s {
+ Statement::Label(id) => result.push(Statement::Label(id)),
+ Statement::Instruction((pred, inst)) => {
+ if let Some(pred) = pred {
+ let if_true = id_def.register_intermediate(None);
+ let if_false = id_def.register_intermediate(None);
+ let folded_bra = match &inst {
+ ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
+ _ => None,
+ };
+ let mut branch = BrachCondition {
+ predicate: pred.label,
+ if_true: folded_bra.unwrap_or(if_true),
+ if_false,
+ };
+ if pred.not {
+ std::mem::swap(&mut branch.if_true, &mut branch.if_false);
+ }
+ result.push(Statement::Conditional(branch));
+ if folded_bra.is_none() {
+ result.push(Statement::Label(if_true));
+ result.push(Statement::Instruction(inst));
+ }
+ result.push(Statement::Label(if_false));
+ } else {
+ result.push(Statement::Instruction(inst));
+ }
+ }
+ Statement::Variable(var) => result.push(Statement::Variable(var)),
+ // Blocks are flattened when resolving ids
+ _ => return Err(error_unreachable()),
+ }
+ }
+ Ok(result)
+}