aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/pass/normalize_labels.rs
blob: 037e918d7c1282e5bb7419eebeb905f11f81a84e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
use std::{collections::HashSet, iter};

use super::*;

/// Removes labels that nothing jumps to and puts a fresh label at the front.
///
/// The pass works in two steps:
///
/// 1. Scan `func` and record every label id that is either the result of
///    [`jump_target`] for an instruction or one of the `if_true` / `if_false`
///    fields of a conditional statement.
/// 2. Build the output: first a `Statement::Label` whose id comes from
///    `id_def.register_intermediate(None)`, then every original statement in
///    its original order, skipping `Statement::Label`s whose id was not
///    recorded in step 1.
///
/// All non-label statements are passed through untouched.
pub(super) fn run(
    func: Vec<ExpandedStatement>,
    id_def: &mut NumericIdResolver,
) -> Vec<ExpandedStatement> {
    // Step 1: gather the ids of all labels that are referenced by a jump.
    let mut referenced = HashSet::new();
    for statement in &func {
        match statement {
            // `jump_target` yields at most one id; `extend` inserts it if present.
            Statement::Instruction(inst) => referenced.extend(jump_target(inst)),
            Statement::Conditional(branch) => {
                referenced.extend([branch.if_true, branch.if_false]);
            }
            // Listed exhaustively (no `_` arm) so that adding a new statement
            // kind forces this pass to be revisited.
            Statement::Variable(..)
            | Statement::LoadVar(..)
            | Statement::StoreVar(..)
            | Statement::RetValue(..)
            | Statement::Conversion(..)
            | Statement::Constant(..)
            | Statement::Label(..)
            | Statement::PtrAccess { .. }
            | Statement::VectorAccess { .. }
            | Statement::RepackVector(..)
            | Statement::FunctionPointer(..) => {}
        }
    }

    // Step 2: emit the fresh leading label followed by the surviving statements.
    let mut normalized = Vec::with_capacity(func.len() + 1);
    normalized.push(Statement::Label(id_def.register_intermediate(None)));
    for statement in func {
        if let Statement::Label(id) = &statement {
            if !referenced.contains(id) {
                // Nothing jumps here: drop the label.
                continue;
            }
        }
        normalized.push(statement);
    }
    normalized
}

/// Returns the label id an instruction jumps to, if any.
///
/// Only `ast::Instruction::Bra` is treated as a jump: its `arguments.src` is
/// returned. Every other instruction yields `None`.
fn jump_target<T: ast::Operand<Ident = SpirvWord>>(
    this: &ast::Instruction<T>,
) -> Option<SpirvWord> {
    if let ast::Instruction::Bra { arguments } = this {
        Some(arguments.src)
    } else {
        None
    }
}