aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/pass/deparamize_functions.rs
blob: 15125b0e02750fee56122446a0671b891c430973 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
use super::*;

/// Entry point of the de-parameterization pass: rewrites every directive of a
/// module with `run_directive`.
///
/// Directive order is preserved. Processing stops at the first directive that
/// fails, and that error is returned.
pub(super) fn run<'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
    // `collect` into `Result<Vec<_>, _>` (inferred from the return type)
    // short-circuits on the first `Err`.
    directives
        .into_iter()
        .map(|directive| run_directive(resolver, directive))
        .collect()
}

/// Applies the pass to a single top-level directive.
///
/// Variable directives are passed through as-is. Method directives are
/// rewritten by `run_method`, and any error it reports is propagated.
fn run_directive<'input>(
    resolver: &mut GlobalStringIdentResolver2,
    directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
    match directive {
        Directive2::Method(method) => run_method(resolver, method).map(Directive2::Method),
        variable @ Directive2::Variable(..) => Ok(variable),
    }
}

fn run_method<'input>(
    resolver: &mut GlobalStringIdentResolver2,
    mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
    let is_declaration = method.body.is_none();
    let mut body = Vec::new();
    let mut remap_returns = Vec::new();
    if !method.func_decl.name.is_kernel() {
        for arg in method.func_decl.return_arguments.iter_mut() {
            match arg.state_space {
                ptx_parser::StateSpace::Param => {
                    arg.state_space = ptx_parser::StateSpace::Reg;
                    let old_name = arg.name;
                    arg.name =
                        resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
                    if is_declaration {
                        continue;
                    }
                    remap_returns.push((old_name, arg.name, arg.v_type.clone()));
                    body.push(Statement::Variable(ast::Variable {
                        align: None,
                        name: old_name,
                        v_type: arg.v_type.clone(),
                        state_space: ptx_parser::StateSpace::Param,
                        array_init: Vec::new(),
                    }));
                }
                ptx_parser::StateSpace::Reg => {}
                _ => return Err(error_unreachable()),
            }
        }
        for arg in method.func_decl.input_arguments.iter_mut() {
            match arg.state_space {
                ptx_parser::StateSpace::Param => {
                    arg.state_space = ptx_parser::StateSpace::Reg;
                    let old_name = arg.name;
                    arg.name =
                        resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
                    if is_declaration {
                        continue;
                    }
                    body.push(Statement::Variable(ast::Variable {
                        align: None,
                        name: old_name,
                        v_type: arg.v_type.clone(),
                        state_space: ptx_parser::StateSpace::Param,
                        array_init: Vec::new(),
                    }));
                    body.push(Statement::Instruction(ast::Instruction::St {
                        data: ast::StData {
                            qualifier: ast::LdStQualifier::Weak,
                            state_space: ast::StateSpace::Param,
                            caching: ast::StCacheOperator::Writethrough,
                            typ: arg.v_type.clone(),
                        },
                        arguments: ast::StArgs {
                            src1: old_name,
                            src2: arg.name,
                        },
                    }));
                }
                ptx_parser::StateSpace::Reg => {}
                _ => return Err(error_unreachable()),
            }
        }
    }
    let body = method
        .body
        .map(|statements| {
            for statement in statements {
                run_statement(resolver, &remap_returns, &mut body, statement)?;
            }
            Ok::<_, TranslateError>(body)
        })
        .transpose()?;
    Ok(Function2 {
        func_decl: method.func_decl,
        globals: method.globals,
        body,
        import_as: method.import_as,
        tuning: method.tuning,
        linkage: method.linkage,
    })
}

fn run_statement<'input>(
    resolver: &mut GlobalStringIdentResolver2<'input>,
    remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
    result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
    statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
    match statement {
        Statement::Instruction(ast::Instruction::Call {
            mut data,
            mut arguments,
        }) => {
            let mut post_st = Vec::new();
            for ((type_, space), ident) in data
                .input_arguments
                .iter_mut()
                .zip(arguments.input_arguments.iter_mut())
            {
                if *space == ptx_parser::StateSpace::Param {
                    *space = ptx_parser::StateSpace::Reg;
                    let old_name = *ident;
                    *ident = resolver
                        .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
                    result.push(Statement::Instruction(ast::Instruction::Ld {
                        data: ast::LdDetails {
                            qualifier: ast::LdStQualifier::Weak,
                            state_space: ast::StateSpace::Param,
                            caching: ast::LdCacheOperator::Cached,
                            typ: type_.clone(),
                            non_coherent: false,
                        },
                        arguments: ast::LdArgs {
                            dst: *ident,
                            src: old_name,
                        },
                    }));
                }
            }
            for ((type_, space), ident) in data
                .return_arguments
                .iter_mut()
                .zip(arguments.return_arguments.iter_mut())
            {
                if *space == ptx_parser::StateSpace::Param {
                    *space = ptx_parser::StateSpace::Reg;
                    let old_name = *ident;
                    *ident = resolver
                        .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
                    post_st.push(Statement::Instruction(ast::Instruction::St {
                        data: ast::StData {
                            qualifier: ast::LdStQualifier::Weak,
                            state_space: ast::StateSpace::Param,
                            caching: ast::StCacheOperator::Writethrough,
                            typ: type_.clone(),
                        },
                        arguments: ast::StArgs {
                            src1: old_name,
                            src2: *ident,
                        },
                    }));
                }
            }
            result.push(Statement::Instruction(ast::Instruction::Call {
                data,
                arguments,
            }));
            result.extend(post_st.into_iter());
        }
        Statement::Instruction(ast::Instruction::Ret { data }) => {
            for (old_name, new_name, type_) in remap_returns.iter() {
                result.push(Statement::Instruction(ast::Instruction::Ld {
                    data: ast::LdDetails {
                        qualifier: ast::LdStQualifier::Weak,
                        state_space: ast::StateSpace::Param,
                        caching: ast::LdCacheOperator::Cached,
                        typ: type_.clone(),
                        non_coherent: false,
                    },
                    arguments: ast::LdArgs {
                        dst: *new_name,
                        src: *old_name,
                    },
                }));
            }
            result.push(Statement::Instruction(ast::Instruction::Ret { data }));
        }
        statement => {
            result.push(statement);
        }
    }
    Ok(())
}