aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/pass/hoist_globals.rs
blob: 718c05242a1bc15191317c69047ad846ed19da31 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
use super::*;

pub(super) fn run<'input>(
    directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
    // Hoists `.global`, `.const` and `.shared` variable declarations out of method
    // bodies to module scope. Each directive is processed in input order; any
    // variables hoisted out of a method are appended to the output *before* that
    // method, so every hoisted declaration precedes the function it came from.
    // The first error reported by `run_directive` aborts the pass.
    let capacity = directives.len();
    directives.into_iter().try_fold(
        Vec::with_capacity(capacity),
        |mut output, mut directive| -> Result<_, TranslateError> {
            run_directive(&mut output, &mut directive)?;
            output.push(directive);
            Ok(output)
        },
    )
}

fn run_directive<'input>(
    result: &mut Vec<Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>>,
    directive: &mut Directive2<'input, ptx_parser::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
    match directive {
        Directive2::Variable(..) => {}
        Directive2::Method(function2) => run_function(result, function2),
    }
    Ok(())
}

/// Moves every `.global`, `.const` and `.shared` variable declared in `function`'s
/// body out into `result` as a top-level `Directive2::Variable`.
///
/// Hoisted variables are pushed in the order they appear in the body and are
/// tagged with `ast::LinkingDirective::NONE`. All other statements stay in the
/// body in their original order. A function whose `body` is `None` is left
/// unchanged.
fn run_function<'input>(
    result: &mut Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
    function: &mut Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) {
    if let Some(statements) = function.body.take() {
        function.body = Some(
            statements
                .into_iter()
                .filter_map(|statement| match statement {
                    // Declarations in these state spaces are hoisted to module scope.
                    Statement::Variable(
                        var @ ast::Variable {
                            state_space:
                                ast::StateSpace::Global
                                | ast::StateSpace::Const
                                | ast::StateSpace::Shared,
                            ..
                        },
                    ) => {
                        result.push(Directive2::Variable(ast::LinkingDirective::NONE, var));
                        None
                    }
                    kept => Some(kept),
                })
                .collect(),
        );
    }
}