From c84d257bb72be1e047bf9ceb02a4c0bdaf220b4f Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 16 Sep 2024 17:20:46 +0200 Subject: Refactor type-of-function resolution --- ptx/src/pass/mod.rs | 1 + ptx/src/pass/resolve_function_pointers.rs | 82 +++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 ptx/src/pass/resolve_function_pointers.rs diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 9277de4..04d3e49 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -29,6 +29,7 @@ mod normalize_identifiers2; mod normalize_labels; mod normalize_predicates; mod normalize_predicates2; +mod resolve_function_pointers; static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); diff --git a/ptx/src/pass/resolve_function_pointers.rs b/ptx/src/pass/resolve_function_pointers.rs new file mode 100644 index 0000000..9aaa694 --- /dev/null +++ b/ptx/src/pass/resolve_function_pointers.rs @@ -0,0 +1,82 @@ +use super::*; +use ptx_parser as ast; +use rustc_hash::FxHashSet; + +pub(crate) fn run<'input>( + directives: Vec>, +) -> Result>, TranslateError> { + let mut functions = FxHashSet::default(); + directives + .into_iter() + .map(|directive| run_directive(&mut functions, directive)) + .collect::, _>>() +} + +fn run_directive<'input>( + functions: &mut FxHashSet, + directive: UnconditionalDirective<'input>, +) -> Result, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(method) => { + { + let func_decl = method.func_decl.borrow(); + match func_decl.name { + ptx_parser::MethodName::Kernel(_) => {} + ptx_parser::MethodName::Func(name) => { + functions.insert(name); + } + } + } + Directive2::Method(run_method(functions, method)?) + } + }) +} + +fn run_method<'input>( + functions: &mut FxHashSet, + method: UnconditionalFunction<'input>, +) -> Result, TranslateError> { + let body = method + .body + .map(|statements| { + statements + .into_iter() + .map(|statement| run_statement(functions, statement)) + .collect::, _>>() + }) + .transpose()?; + Ok(Function2 { + func_decl: method.func_decl, + globals: method.globals, + body, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + }) +} + +fn run_statement<'input>( + functions: &mut FxHashSet, + statement: UnconditionalStatement, +) -> Result { + Ok(match statement { + Statement::Instruction(ast::Instruction::Mov { + data, + arguments: + ast::MovArgs { + dst: ast::ParsedOperand::Reg(dst_reg), + src: ast::ParsedOperand::Reg(src_reg), + }, + }) if functions.contains(&src_reg) => { + if data.typ != ast::Type::Scalar(ast::ScalarType::U64) { + return Err(error_mismatched_type()); + } + UnconditionalStatement::FunctionPointer(FunctionPointerDetails { + dst: dst_reg, + src: src_reg, + }) + } + s => s, + }) +} -- cgit v1.2.3