aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2024-09-16 17:20:46 +0200
committerAndrzej Janik <[email protected]>2024-09-16 17:20:46 +0200
commitc84d257bb72be1e047bf9ceb02a4c0bdaf220b4f (patch)
treefbc1a9bb76ffa17a7d325359641b4b2c56b6ae1f
parente87388bc352601201960458c2768b571c5947696 (diff)
downloadZLUDA-c84d257bb72be1e047bf9ceb02a4c0bdaf220b4f.tar.gz
ZLUDA-c84d257bb72be1e047bf9ceb02a4c0bdaf220b4f.zip
Refactor type-of-function resolution
-rw-r--r--ptx/src/pass/mod.rs1
-rw-r--r--ptx/src/pass/resolve_function_pointers.rs82
2 files changed, 83 insertions, 0 deletions
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs
index 9277de4..04d3e49 100644
--- a/ptx/src/pass/mod.rs
+++ b/ptx/src/pass/mod.rs
@@ -29,6 +29,7 @@ mod normalize_identifiers2;
mod normalize_labels;
mod normalize_predicates;
mod normalize_predicates2;
+mod resolve_function_pointers;
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
diff --git a/ptx/src/pass/resolve_function_pointers.rs b/ptx/src/pass/resolve_function_pointers.rs
new file mode 100644
index 0000000..9aaa694
--- /dev/null
+++ b/ptx/src/pass/resolve_function_pointers.rs
@@ -0,0 +1,82 @@
+use super::*;
+use ptx_parser as ast;
+use rustc_hash::FxHashSet;
+
+pub(crate) fn run<'input>(
+ directives: Vec<UnconditionalDirective<'input>>,
+) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
+ let mut functions = FxHashSet::default();
+ directives
+ .into_iter()
+ .map(|directive| run_directive(&mut functions, directive))
+ .collect::<Result<Vec<_>, _>>()
+}
+
+fn run_directive<'input>(
+ functions: &mut FxHashSet<SpirvWord>,
+ directive: UnconditionalDirective<'input>,
+) -> Result<UnconditionalDirective<'input>, TranslateError> {
+ Ok(match directive {
+ var @ Directive2::Variable(..) => var,
+ Directive2::Method(method) => {
+ {
+ let func_decl = method.func_decl.borrow();
+ match func_decl.name {
+ ptx_parser::MethodName::Kernel(_) => {}
+ ptx_parser::MethodName::Func(name) => {
+ functions.insert(name);
+ }
+ }
+ }
+ Directive2::Method(run_method(functions, method)?)
+ }
+ })
+}
+
+fn run_method<'input>(
+ functions: &mut FxHashSet<SpirvWord>,
+ method: UnconditionalFunction<'input>,
+) -> Result<UnconditionalFunction<'input>, TranslateError> {
+ let body = method
+ .body
+ .map(|statements| {
+ statements
+ .into_iter()
+ .map(|statement| run_statement(functions, statement))
+ .collect::<Result<Vec<_>, _>>()
+ })
+ .transpose()?;
+ Ok(Function2 {
+ func_decl: method.func_decl,
+ globals: method.globals,
+ body,
+ import_as: method.import_as,
+ tuning: method.tuning,
+ linkage: method.linkage,
+ })
+}
+
+fn run_statement<'input>(
+ functions: &mut FxHashSet<SpirvWord>,
+ statement: UnconditionalStatement,
+) -> Result<UnconditionalStatement, TranslateError> {
+ Ok(match statement {
+ Statement::Instruction(ast::Instruction::Mov {
+ data,
+ arguments:
+ ast::MovArgs {
+ dst: ast::ParsedOperand::Reg(dst_reg),
+ src: ast::ParsedOperand::Reg(src_reg),
+ },
+ }) if functions.contains(&src_reg) => {
+ if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
+ return Err(error_mismatched_type());
+ }
+ UnconditionalStatement::FunctionPointer(FunctionPointerDetails {
+ dst: dst_reg,
+ src: src_reg,
+ })
+ }
+ s => s,
+ })
+}