aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-05-07 18:22:09 +0200
committerAndrzej Janik <[email protected]>2021-05-07 18:22:09 +0200
commit425edfcdd49a4fa49d480f1b078c55dba4709e29 (patch)
treeb4c558356c5c8bc3341f90c3f2a4c56dad1876ab
parent7f051ad20ec933f78ce4539020a25fab3503011c (diff)
downloadZLUDA-425edfcdd49a4fa49d480f1b078c55dba4709e29.tar.gz
ZLUDA-425edfcdd49a4fa49d480f1b078c55dba4709e29.zip
Simplify typing
-rw-r--r--ptx/src/ast.rs21
-rw-r--r--ptx/src/ptx.lalrpop21
-rw-r--r--ptx/src/translate.rs524
-rw-r--r--zluda_dump/src/lib.rs14
4 files changed, 247 insertions, 333 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 364ec01..e45a6fb 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -1,6 +1,6 @@
use half::f16;
use lalrpop_util::{lexer::Token, ParseError};
-use std::{convert::From, mem, num::ParseFloatError, str::FromStr};
+use std::{convert::From, mem, num::ParseFloatError, rc::Rc, str::FromStr};
use std::{marker::PhantomData, num::ParseIntError};
#[derive(Debug, thiserror::Error)]
@@ -86,19 +86,20 @@ pub enum Directive<'a, P: ArgParams> {
Method(Function<'a, &'a str, Statement<P>>),
}
-pub enum MethodDecl<'a, ID> {
- Func(Vec<FnArgument<ID>>, ID, Vec<FnArgument<ID>>),
- Kernel {
- name: &'a str,
- in_args: Vec<KernelArgument<ID>>,
- },
+#[derive(Hash, PartialEq, Eq, Copy, Clone)]
+pub enum MethodName<'input, ID> {
+ Kernel(&'input str),
+ Func(ID),
}
-pub type FnArgument<ID> = Variable<ID>;
-pub type KernelArgument<ID> = Variable<ID>;
+pub struct MethodDeclaration<'input, ID> {
+ pub return_arguments: Vec<Variable<ID>>,
+ pub name: MethodName<'input, ID>,
+ pub input_arguments: Vec<Variable<ID>>,
+}
pub struct Function<'a, ID, S> {
- pub func_directive: MethodDecl<'a, ID>,
+ pub func_directive: MethodDeclaration<'a, ID>,
pub tuning: Vec<TuningDirective>,
pub body: Option<Vec<S>>,
}
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 8fee7c2..78ebf1d 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -360,7 +360,7 @@ AddressSize = {
Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = {
LinkingDirectives
- <func_directive:MethodDecl>
+ <func_directive:MethodDeclaration>
<tuning:TuningDirective*>
<body:FunctionBody> => ast::Function{<>}
};
@@ -388,19 +388,24 @@ LinkingDirectives: ast::LinkingDirective = {
}
}
-MethodDecl: ast::MethodDecl<'input, &'input str> = {
- ".entry" <name:ExtendedID> <in_args:KernelArguments> =>
- ast::MethodDecl::Kernel{ name, in_args },
- ".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => {
- ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
+MethodDeclaration: ast::MethodDeclaration<'input, &'input str> = {
+ ".entry" <name:ExtendedID> <input_arguments:KernelArguments> => {
+ let return_arguments = Vec::new();
+ let name = ast::MethodName::Kernel(name);
+ ast::MethodDeclaration{ return_arguments, name, input_arguments }
+ },
+ ".func" <return_arguments:FnArguments?> <name:ExtendedID> <input_arguments:FnArguments> => {
+ let return_arguments = return_arguments.unwrap_or_else(|| Vec::new());
+ let name = ast::MethodName::Func(name);
+ ast::MethodDeclaration{ return_arguments, name, input_arguments }
}
};
-KernelArguments: Vec<ast::KernelArgument<&'input str>> = {
+KernelArguments: Vec<ast::Variable<&'input str>> = {
"(" <args:Comma<KernelInput>> ")" => args
};
-FnArguments: Vec<ast::FnArgument<&'input str>> = {
+FnArguments: Vec<ast::Variable<&'input str>> = {
"(" <args:Comma<FnInput>> ")" => args
};
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 1a2eda3..88ef51b 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,7 +1,9 @@
use crate::ast;
+use core::borrow;
use half::f16;
use rspirv::dr;
-use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem};
+use std::{borrow::Borrow, cell::RefCell};
+use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem, rc::Rc};
use std::{
collections::{hash_map, HashMap, HashSet},
convert::TryInto,
@@ -458,7 +460,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
.collect::<Vec<_>>();
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
- let call_map = get_call_map(&directives);
+ let call_map = get_kernels_call_map(&directives);
let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
@@ -496,9 +498,12 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
}
// TODO: remove this once we have perf-function support for denorms
-fn emit_denorm_build_string(
+fn emit_denorm_build_string<'input>(
call_map: &HashMap<&str, HashSet<u32>>,
- denorm_information: &HashMap<MethodName, HashMap<u8, (spirv::FPDenormMode, isize)>>,
+ denorm_information: &HashMap<
+ ast::MethodName<'input, spirv::Word>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
) -> CString {
let denorm_counts = denorm_information
.iter()
@@ -516,10 +521,12 @@ fn emit_denorm_build_string(
.collect::<HashMap<_, _>>();
let mut flush_over_preserve = 0;
for (kernel, children) in call_map {
- flush_over_preserve += *denorm_counts.get(&MethodName::Kernel(kernel)).unwrap_or(&0);
+ flush_over_preserve += *denorm_counts
+ .get(&ast::MethodName::Kernel(kernel))
+ .unwrap_or(&0);
for child_fn in children {
flush_over_preserve += *denorm_counts
- .get(&MethodName::Func(*child_fn))
+ .get(&ast::MethodName::Func(*child_fn))
.unwrap_or(&0);
}
}
@@ -535,9 +542,12 @@ fn emit_directives<'input>(
map: &mut TypeWordMap,
id_defs: &GlobalStringIdResolver<'input>,
opencl_id: spirv::Word,
- denorm_information: &HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
+ denorm_information: &HashMap<
+ ast::MethodName<'input, spirv::Word>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
- directives: Vec<Directive>,
+ directives: Vec<Directive<'input>>,
kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
let empty_body = Vec::new();
@@ -560,16 +570,18 @@ fn emit_directives<'input>(
for var in f.globals.iter() {
emit_variable(builder, map, var)?;
}
+ let func_decl = (*f.func_decl).borrow();
let fn_id = emit_function_header(
builder,
map,
&id_defs,
&f.globals,
- &f.spirv_decl,
+ &*func_decl,
&denorm_information,
call_map,
&directives,
kernel_info,
+ f.uses_shared_mem,
)?;
for t in f.tuning.iter() {
match *t {
@@ -594,8 +606,13 @@ fn emit_directives<'input>(
}
emit_function_body_ops(builder, map, opencl_id, &f_body)?;
builder.end_function()?;
- if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) =
- (&f.func_decl, &f.import_as)
+ if let (
+ ast::MethodDeclaration {
+ name: ast::MethodName::Func(fn_id),
+ ..
+ },
+ Some(name),
+ ) = (&*func_decl, &f.import_as)
{
builder.decorate(
*fn_id,
@@ -614,7 +631,7 @@ fn emit_directives<'input>(
Ok(())
}
-fn get_call_map<'input>(
+fn get_kernels_call_map<'input>(
module: &[Directive<'input>],
) -> HashMap<&'input str, HashSet<spirv::Word>> {
let mut directly_called_by = HashMap::new();
@@ -625,7 +642,7 @@ fn get_call_map<'input>(
body: Some(statements),
..
}) => {
- let call_key = MethodName::new(&func_decl);
+ let call_key: ast::MethodName<_> = (**func_decl).borrow().name;
if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
entry.insert(Vec::new());
}
@@ -644,28 +661,28 @@ fn get_call_map<'input>(
let mut result = HashMap::new();
for (method_key, children) in directly_called_by.iter() {
match method_key {
- MethodName::Kernel(name) => {
+ ast::MethodName::Kernel(name) => {
let mut visited = HashSet::new();
for child in children {
add_call_map_single(&directly_called_by, &mut visited, *child);
}
result.insert(*name, visited);
}
- MethodName::Func(_) => {}
+ ast::MethodName::Func(_) => {}
}
}
result
}
fn add_call_map_single<'input>(
- directly_called_by: &MultiHashMap<MethodName<'input>, spirv::Word>,
+ directly_called_by: &MultiHashMap<ast::MethodName<'input, spirv::Word>, spirv::Word>,
visited: &mut HashSet<spirv::Word>,
current: spirv::Word,
) {
if !visited.insert(current) {
return;
}
- if let Some(children) = directly_called_by.get(&MethodName::Func(current)) {
+ if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) {
for child in children {
add_call_map_single(directly_called_by, visited, *child);
}
@@ -739,10 +756,10 @@ fn convert_dynamic_shared_memory_usage<'input>(
globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
+ uses_shared_mem,
}) => {
- let call_key = MethodName::new(&func_decl);
+ let call_key = (*func_decl).borrow().name;
let statements = statements
.into_iter()
.map(|statement| match statement {
@@ -763,8 +780,8 @@ fn convert_dynamic_shared_memory_usage<'input>(
globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
+ uses_shared_mem,
})
}
directive => directive,
@@ -782,30 +799,32 @@ fn convert_dynamic_shared_memory_usage<'input>(
globals,
body: Some(statements),
import_as,
- mut spirv_decl,
tuning,
+ uses_shared_mem,
}) => {
- if !methods_using_extern_shared.contains(&spirv_decl.name) {
+ if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) {
return Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
+ uses_shared_mem,
});
}
let shared_id_param = new_id();
- spirv_decl.input.push({
- ast::Variable {
- name: shared_id_param,
- align: None,
- v_type: ast::Type::Pointer(ast::ScalarType::B8),
- state_space: ast::StateSpace::Shared,
- array_init: Vec::new(),
- }
- });
- spirv_decl.uses_shared_mem = true;
+ {
+ let mut func_decl = (*func_decl).borrow_mut();
+ func_decl.input_arguments.push({
+ ast::Variable {
+ name: shared_id_param,
+ align: None,
+ v_type: ast::Type::Pointer(ast::ScalarType::B8),
+ state_space: ast::StateSpace::Shared,
+ array_init: Vec::new(),
+ }
+ });
+ }
let statements = replace_uses_of_shared_memory(
new_id,
&extern_shared_decls,
@@ -818,8 +837,8 @@ fn convert_dynamic_shared_memory_usage<'input>(
globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
+ uses_shared_mem: true,
})
}
directive => directive,
@@ -830,7 +849,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
fn replace_uses_of_shared_memory<'a>(
new_id: &mut impl FnMut() -> spirv::Word,
extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
+ methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
shared_id_param: spirv::Word,
statements: Vec<ExpandedStatement>,
) -> Vec<ExpandedStatement> {
@@ -841,7 +860,7 @@ fn replace_uses_of_shared_memory<'a>(
// We can safely skip checking call arguments,
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
- if methods_using_extern_shared.contains(&MethodName::Func(call.func)) {
+ if methods_using_extern_shared.contains(&ast::MethodName::Func(call.func)) {
call.param_list.push((
shared_id_param,
ast::Type::Scalar(ast::ScalarType::B8),
@@ -881,13 +900,13 @@ fn replace_uses_of_shared_memory<'a>(
}
fn get_callers_of_extern_shared<'a>(
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
- directly_called_by: &MultiHashMap<spirv::Word, MethodName<'a>>,
+ methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
+ directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
) {
let direct_uses_of_extern_shared = methods_using_extern_shared
.iter()
.filter_map(|method| {
- if let MethodName::Func(f_id) = method {
+ if let ast::MethodName::Func(f_id) = method {
Some(*f_id)
} else {
None
@@ -900,14 +919,14 @@ fn get_callers_of_extern_shared<'a>(
}
fn get_callers_of_extern_shared_single<'a>(
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
- directly_called_by: &MultiHashMap<spirv::Word, MethodName<'a>>,
+ methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
+ directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
fn_id: spirv::Word,
) {
if let Some(callers) = directly_called_by.get(&fn_id) {
for caller in callers {
if methods_using_extern_shared.insert(*caller) {
- if let MethodName::Func(caller_fn) = caller {
+ if let ast::MethodName::Func(caller_fn) = caller {
get_callers_of_extern_shared_single(
methods_using_extern_shared,
directly_called_by,
@@ -949,7 +968,7 @@ fn denorm_count_map_update_impl<T: Eq + Hash>(
// and emit suitable execution mode
fn compute_denorm_information<'input>(
module: &[Directive<'input>],
-) -> HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
+) -> HashMap<ast::MethodName<'input, spirv::Word>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
let mut denorm_methods = HashMap::new();
for directive in module {
match directive {
@@ -960,7 +979,7 @@ fn compute_denorm_information<'input>(
..
}) => {
let mut flush_counter = DenormCountMap::new();
- let method_key = MethodName::new(func_decl);
+ let method_key = (**func_decl).borrow().name;
for statement in statements {
match statement {
Statement::Instruction(inst) => {
@@ -1004,21 +1023,6 @@ fn compute_denorm_information<'input>(
.collect()
}
-#[derive(Hash, PartialEq, Eq, Copy, Clone)]
-enum MethodName<'input> {
- Kernel(&'input str),
- Func(spirv::Word),
-}
-
-impl<'input> MethodName<'input> {
- fn new(decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
- match decl {
- ast::MethodDecl::Kernel { name, .. } => MethodName::Kernel(name),
- ast::MethodDecl::Func(_, id, _) => MethodName::Func(*id),
- }
- }
-}
-
fn emit_builtins(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -1047,17 +1051,21 @@ fn emit_function_header<'a>(
map: &mut TypeWordMap,
defined_globals: &GlobalStringIdResolver<'a>,
synthetic_globals: &[ast::Variable<spirv::Word>],
- func_decl: &SpirvMethodDecl<'a>,
- _denorm_information: &HashMap<MethodName<'a>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
+ func_decl: &ast::MethodDeclaration<'a, spirv::Word>,
+ _denorm_information: &HashMap<
+ ast::MethodName<'a, spirv::Word>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
direcitves: &[Directive],
kernel_info: &mut HashMap<String, KernelInfo>,
+ uses_shared_mem: bool,
) -> Result<spirv::Word, TranslateError> {
- if let MethodName::Kernel(name) = func_decl.name {
- let input_args = if !func_decl.uses_shared_mem {
- func_decl.input.as_slice()
+ if let ast::MethodName::Kernel(name) = func_decl.name {
+ let input_args = if !uses_shared_mem {
+ func_decl.input_arguments.as_slice()
} else {
- &func_decl.input[0..func_decl.input.len() - 1]
+ &func_decl.input_arguments[0..func_decl.input_arguments.len() - 1]
};
let args_lens = input_args
.iter()
@@ -1067,14 +1075,18 @@ fn emit_function_header<'a>(
name.to_string(),
KernelInfo {
arguments_sizes: args_lens,
- uses_shared_mem: func_decl.uses_shared_mem,
+ uses_shared_mem: uses_shared_mem,
},
);
}
- let (ret_type, func_type) =
- get_function_type(builder, map, &func_decl.input, &func_decl.output);
+ let (ret_type, func_type) = get_function_type(
+ builder,
+ map,
+ &func_decl.input_arguments,
+ &func_decl.return_arguments,
+ );
let fn_id = match func_decl.name {
- MethodName::Kernel(name) => {
+ ast::MethodName::Kernel(name) => {
let fn_id = defined_globals.get_id(name)?;
let mut global_variables = defined_globals
.variables_type_check
@@ -1090,15 +1102,16 @@ fn emit_function_header<'a>(
for directive in direcitves {
match directive {
Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- globals,
- ..
+ func_decl, globals, ..
}) => {
- if child_fns.contains(name) {
- for var in globals {
- interface.push(var.name);
+ match (**func_decl).borrow().name {
+ ast::MethodName::Func(name) => {
+ for var in globals {
+ interface.push(var.name);
+ }
}
- }
+ ast::MethodName::Kernel(_) => {}
+ };
}
_ => {}
}
@@ -1107,7 +1120,7 @@ fn emit_function_header<'a>(
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
fn_id
}
- MethodName::Func(name) => name,
+ ast::MethodName::Func(name) => name,
};
builder.begin_function(
ret_type,
@@ -1130,7 +1143,7 @@ fn emit_function_header<'a>(
}
}
*/
- for input in &func_decl.input {
+ for input in &func_decl.input_arguments {
let result_type = map.get_or_add(
builder,
SpirvType::new(input.v_type.clone(), input.state_space),
@@ -1225,9 +1238,10 @@ fn translate_function<'a>(
f: ast::ParsedFunction<'a>,
) -> Result<Option<Function<'a>>, TranslateError> {
let import_as = match &f.func_directive {
- ast::MethodDecl::Func(_, "__assertfail", _) => {
- Some("__zluda_ptx_impl____assertfail".to_owned())
- }
+ ast::MethodDeclaration {
+ name: ast::MethodName::Func("__assertfail"),
+ ..
+ } => Some("__zluda_ptx_impl____assertfail".to_owned()),
_ => None,
};
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?;
@@ -1253,10 +1267,10 @@ fn translate_function<'a>(
fn expand_kernel_params<'a, 'b>(
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: impl Iterator<Item = &'b ast::KernelArgument<&'a str>>,
-) -> Result<Vec<ast::KernelArgument<spirv::Word>>, TranslateError> {
+ args: impl Iterator<Item = &'b ast::Variable<&'a str>>,
+) -> Result<Vec<ast::Variable<spirv::Word>>, TranslateError> {
args.map(|a| {
- Ok(ast::KernelArgument {
+ Ok(ast::Variable {
name: fn_resolver.add_def(
a.name,
Some((
@@ -1274,42 +1288,39 @@ fn expand_kernel_params<'a, 'b>(
.collect::<Result<_, _>>()
}
-fn expand_fn_params<'a, 'b>(
+fn rename_fn_params<'a, 'b>(
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
-) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
- args.map(|a| {
- let is_variable = a.state_space == ast::StateSpace::Reg;
- Ok(ast::FnArgument {
- name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), is_variable),
+ args: &'b [ast::Variable<&'a str>],
+) -> Vec<ast::Variable<spirv::Word>> {
+ args.iter()
+ .map(|a| ast::Variable {
+ name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), false),
v_type: a.v_type.clone(),
state_space: a.state_space,
align: a.align,
- array_init: Vec::new(),
+ array_init: a.array_init.clone(),
})
- })
- .collect()
+ .collect()
}
fn to_ssa<'input, 'b>(
ptx_impl_imports: &mut HashMap<String, Directive>,
mut id_defs: FnStringIdResolver<'input, 'b>,
fn_defs: GlobalFnDeclResolver<'input, 'b>,
- f_args: ast::MethodDecl<'input, spirv::Word>,
+ func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
tuning: Vec<ast::TuningDirective>,
) -> Result<Function<'input>, TranslateError> {
- let mut spirv_decl = SpirvMethodDecl::new(&f_args);
let f_body = match f_body {
Some(vec) => vec,
None => {
return Ok(Function {
- func_decl: f_args,
+ func_decl: func_decl,
body: None,
globals: Vec::new(),
import_as: None,
- spirv_decl,
tuning,
+ uses_shared_mem: false,
})
}
};
@@ -1323,8 +1334,7 @@ fn to_ssa<'input, 'b>(
let ssa_statements = insert_mem_ssa_statements(
typed_statements,
&mut numeric_id_defs,
- &f_args,
- &mut spirv_decl,
+ &mut (*func_decl).borrow_mut(),
)?;
let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.finish();
@@ -1336,12 +1346,12 @@ fn to_ssa<'input, 'b>(
let (f_body, globals) =
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs);
Ok(Function {
- func_decl: f_args,
+ func_decl: func_decl,
globals: globals,
body: Some(f_body),
import_as: None,
- spirv_decl,
tuning,
+ uses_shared_mem: false,
})
}
@@ -1573,9 +1583,9 @@ fn convert_to_typed_statements(
Statement::Instruction(inst) => match inst {
ast::Instruction::Call(call) => {
// TODO: error out if lengths don't match
- let fn_def = fn_defs.get_fn_decl(call.func)?;
- let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
- let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params);
+ let fn_def = (**fn_defs.get_fn_decl(call.func)?).borrow();
+ let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.return_arguments);
+ let in_args = to_resolved_fn_args(call.param_list, &*fn_def.input_arguments);
let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args
.into_iter()
.partition(|(_, _, space)| *space == ast::StateSpace::Param);
@@ -1731,24 +1741,24 @@ fn to_ptx_impl_atomic_call(
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
let fn_id = id_defs.register_intermediate(None);
- let func_decl = ast::MethodDecl::Func::<spirv::Word>(
- vec![ast::FnArgument {
+ let func_decl = ast::MethodDeclaration::<spirv::Word> {
+ return_arguments: vec![ast::Variable {
align: None,
v_type: ast::Type::Scalar(scalar_typ),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
- fn_id,
- vec![
- ast::FnArgument {
+ name: ast::MethodName::Func(fn_id),
+ input_arguments: vec![
+ ast::Variable {
align: None,
v_type: ast::Type::Pointer(typ),
state_space: ptr_space,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
v_type: ast::Type::Scalar(scalar_typ),
state_space: ast::StateSpace::Reg,
@@ -1756,24 +1766,23 @@ fn to_ptx_impl_atomic_call(
array_init: Vec::new(),
},
],
- );
- let spirv_decl = SpirvMethodDecl::new(&func_decl);
+ };
let func = Function {
- func_decl,
+ func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
- spirv_decl,
tuning: Vec::new(),
+ uses_shared_mem: false,
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- ..
- }) => *name,
+ Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
+ ast::MethodName::Func(fn_id) => fn_id,
+ ast::MethodName::Kernel(_) => unreachable!(),
+ },
_ => unreachable!(),
},
};
@@ -1810,31 +1819,31 @@ fn to_ptx_impl_bfe_call(
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
let fn_id = id_defs.register_intermediate(None);
- let func_decl = ast::MethodDecl::Func::<spirv::Word>(
- vec![ast::FnArgument {
+ let func_decl = ast::MethodDeclaration::<spirv::Word> {
+ return_arguments: vec![ast::Variable {
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
- fn_id,
- vec![
- ast::FnArgument {
+ name: ast::MethodName::Func(fn_id),
+ input_arguments: vec![
+ ast::Variable {
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
@@ -1842,24 +1851,23 @@ fn to_ptx_impl_bfe_call(
array_init: Vec::new(),
},
],
- );
- let spirv_decl = SpirvMethodDecl::new(&func_decl);
+ };
let func = Function {
- func_decl,
+ func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
- spirv_decl,
tuning: Vec::new(),
+ uses_shared_mem: false,
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- ..
- }) => *name,
+ Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
+ ast::MethodName::Func(fn_id) => fn_id,
+ ast::MethodName::Kernel(_) => unreachable!(),
+ },
_ => unreachable!(),
},
};
@@ -1903,38 +1911,38 @@ fn to_ptx_impl_bfi_call(
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
let fn_id = id_defs.register_intermediate(None);
- let func_decl = ast::MethodDecl::Func::<spirv::Word>(
- vec![ast::FnArgument {
+ let func_decl = ast::MethodDeclaration::<spirv::Word> {
+ return_arguments: vec![ast::Variable {
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
- fn_id,
- vec![
- ast::FnArgument {
+ name: ast::MethodName::Func(fn_id),
+ input_arguments: vec![
+ ast::Variable {
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
@@ -1942,24 +1950,23 @@ fn to_ptx_impl_bfi_call(
array_init: Vec::new(),
},
],
- );
- let spirv_decl = SpirvMethodDecl::new(&func_decl);
+ };
let func = Function {
- func_decl,
+ func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
- spirv_decl,
tuning: Vec::new(),
+ uses_shared_mem: false,
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- ..
- }) => *name,
+ Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
+ ast::MethodName::Func(fn_id) => fn_id,
+ ast::MethodName::Kernel(_) => unreachable!(),
+ },
_ => unreachable!(),
},
};
@@ -1994,12 +2001,12 @@ fn to_ptx_impl_bfi_call(
fn to_resolved_fn_args<T>(
params: Vec<T>,
- params_decl: &[(ast::Type, ast::StateSpace)],
+ params_decl: &[ast::Variable<spirv::Word>],
) -> Vec<(T, ast::Type, ast::StateSpace)> {
params
.into_iter()
.zip(params_decl.iter())
- .map(|(id, (typ, space))| (id, typ.clone(), *space))
+ .map(|(id, var)| (id, var.v_type.clone(), var.state_space))
.collect::<Vec<_>>()
}
@@ -2084,11 +2091,10 @@ fn normalize_predicates(
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver,
- _: &'a ast::MethodDecl<'b, spirv::Word>,
- fn_decl: &mut SpirvMethodDecl,
+ fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
- for arg in fn_decl.output.iter() {
+ for arg in fn_decl.return_arguments.iter() {
result.push(Statement::Variable(ast::Variable {
align: arg.align,
v_type: arg.v_type.clone(),
@@ -2097,27 +2103,27 @@ fn insert_mem_ssa_statements<'a, 'b>(
array_init: arg.array_init.clone(),
}));
}
- for spirv_arg in fn_decl.input.iter_mut() {
- let typ = spirv_arg.v_type.clone();
- let state_space = spirv_arg.state_space;
+ for arg in fn_decl.input_arguments.iter_mut() {
+ let typ = arg.v_type.clone();
+ let state_space = arg.state_space;
let new_id = id_def.register_intermediate(Some((typ.clone(), state_space)));
result.push(Statement::Variable(ast::Variable {
- align: spirv_arg.align,
- v_type: spirv_arg.v_type.clone(),
- state_space: spirv_arg.state_space,
- name: spirv_arg.name,
- array_init: spirv_arg.array_init.clone(),
+ align: arg.align,
+ v_type: arg.v_type.clone(),
+ state_space: arg.state_space,
+ name: arg.name,
+ array_init: Vec::new(),
}));
result.push(Statement::StoreVar(StoreVarDetails {
arg: ast::Arg2St {
- src1: spirv_arg.name,
+ src1: arg.name,
src2: new_id,
},
state_space,
typ,
member_index: None,
}));
- spirv_arg.name = new_id;
+ arg.name = new_id;
}
for s in func {
match s {
@@ -2127,7 +2133,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
Statement::Instruction(inst) => match inst {
ast::Instruction::Ret(d) => {
// TODO: handle multiple output args
- if let &[out_param] = &fn_decl.output.as_slice() {
+ if let &[out_param] = &fn_decl.return_arguments.as_slice() {
let (typ, space, _) = id_def.get_typed(out_param.name)?;
let new_id = id_def.register_intermediate(Some((typ.clone(), space)));
result.push(Statement::LoadVar(LoadVarDetails {
@@ -5081,15 +5087,10 @@ struct GlobalStringIdResolver<'input> {
variables: HashMap<Cow<'input, str>, spirv::Word>,
variables_type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: SpecialRegistersMap,
- fns: HashMap<spirv::Word, FnDecl>,
+ fns: HashMap<spirv::Word, Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>>,
}
-pub struct FnDecl {
- ret_vals: Vec<(ast::Type, ast::StateSpace)>,
- params: Vec<(ast::Type, ast::StateSpace)>,
-}
-
-impl<'a> GlobalStringIdResolver<'a> {
+impl<'input> GlobalStringIdResolver<'input> {
fn new(start_id: spirv::Word) -> Self {
Self {
current_id: start_id,
@@ -5100,13 +5101,13 @@ impl<'a> GlobalStringIdResolver<'a> {
}
}
- fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word {
+ fn get_or_add_def(&mut self, id: &'input str) -> spirv::Word {
self.get_or_add_impl(id, None)
}
fn get_or_add_def_typed(
&mut self,
- id: &'a str,
+ id: &'input str,
typ: ast::Type,
state_space: ast::StateSpace,
is_variable: bool,
@@ -5116,7 +5117,7 @@ impl<'a> GlobalStringIdResolver<'a> {
fn get_or_add_impl(
&mut self,
- id: &'a str,
+ id: &'input str,
typ: Option<(ast::Type, ast::StateSpace, bool)>,
) -> spirv::Word {
let id = match self.variables.entry(Cow::Borrowed(id)) {
@@ -5145,12 +5146,12 @@ impl<'a> GlobalStringIdResolver<'a> {
fn start_fn<'b>(
&'b mut self,
- header: &'b ast::MethodDecl<'a, &'a str>,
+ header: &'b ast::MethodDeclaration<'input, &'input str>,
) -> Result<
(
- FnStringIdResolver<'a, 'b>,
- GlobalFnDeclResolver<'a, 'b>,
- ast::MethodDecl<'a, spirv::Word>,
+ FnStringIdResolver<'input, 'b>,
+ GlobalFnDeclResolver<'input, 'b>,
+ Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
),
TranslateError,
> {
@@ -5164,30 +5165,18 @@ impl<'a> GlobalStringIdResolver<'a> {
variables: vec![HashMap::new(); 1],
type_check: HashMap::new(),
};
- let new_fn_decl = match header {
- ast::MethodDecl::Kernel { name, in_args } => ast::MethodDecl::Kernel {
- name,
- in_args: expand_kernel_params(&mut fn_resolver, in_args.iter())?,
- },
- ast::MethodDecl::Func(ret_params, _, params) => {
- let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter())?;
- let params_ids = expand_fn_params(&mut fn_resolver, params.iter())?;
- self.fns.insert(
- name_id,
- FnDecl {
- ret_vals: ret_params_ids
- .iter()
- .map(|p| (p.v_type.clone(), p.state_space))
- .collect(),
- params: params_ids
- .iter()
- .map(|p| (p.v_type.clone(), p.state_space))
- .collect(),
- },
- );
- ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
- }
+ let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments);
+ let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments);
+ let name = match header.name {
+ ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
+ ast::MethodName::Func(_) => ast::MethodName::Func(name_id),
};
+ let new_fn_decl = Rc::new(RefCell::new(ast::MethodDeclaration {
+ return_arguments,
+ name,
+ input_arguments,
+ }));
+ self.fns.insert(name_id, Rc::clone(&new_fn_decl));
Ok((
fn_resolver,
GlobalFnDeclResolver {
@@ -5201,15 +5190,21 @@ impl<'a> GlobalStringIdResolver<'a> {
pub struct GlobalFnDeclResolver<'input, 'a> {
variables: &'a HashMap<Cow<'input, str>, spirv::Word>,
- fns: &'a HashMap<spirv::Word, FnDecl>,
+ fns: &'a HashMap<spirv::Word, Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>>,
}
impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
- fn get_fn_decl(&self, id: spirv::Word) -> Result<&FnDecl, TranslateError> {
+ fn get_fn_decl(
+ &self,
+ id: spirv::Word,
+ ) -> Result<&Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>, TranslateError> {
self.fns.get(&id).ok_or(TranslateError::UnknownSymbol)
}
- fn get_fn_decl_str(&self, id: &str) -> Result<&'a FnDecl, TranslateError> {
+ fn get_fn_decl_str(
+ &self,
+ id: &str,
+ ) -> Result<&'a Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>, TranslateError> {
match self.variables.get(id).map(|var_id| self.fns.get(var_id)) {
Some(Some(fn_d)) => Ok(fn_d),
_ => Err(TranslateError::UnknownSymbol),
@@ -5713,21 +5708,9 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab
}
}
-pub trait ArgParamsEx: ast::ArgParams + Sized {
- fn get_fn_decl<'x, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'x, 'b>,
- ) -> Result<&'b FnDecl, TranslateError>;
-}
+pub trait ArgParamsEx: ast::ArgParams + Sized {}
-impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
- fn get_fn_decl<'x, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'x, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl_str(id)
- }
-}
+impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {}
enum NormalizedArgParams {}
@@ -5736,14 +5719,7 @@ impl ast::ArgParams for NormalizedArgParams {
type Operand = ast::Operand<spirv::Word>;
}
-impl ArgParamsEx for NormalizedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
-}
+impl ArgParamsEx for NormalizedArgParams {}
type NormalizedStatement = Statement<
(
@@ -5762,14 +5738,7 @@ impl ast::ArgParams for TypedArgParams {
type Operand = TypedOperand;
}
-impl ArgParamsEx for TypedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
-}
+impl ArgParamsEx for TypedArgParams {}
#[derive(Copy, Clone)]
enum TypedOperand {
@@ -5800,14 +5769,7 @@ impl ast::ArgParams for ExpandedArgParams {
type Operand = spirv::Word;
}
-impl ArgParamsEx for ExpandedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
-}
+impl ArgParamsEx for ExpandedArgParams {}
enum Directive<'input> {
Variable(ast::Variable<spirv::Word>),
@@ -5815,10 +5777,10 @@ enum Directive<'input> {
}
struct Function<'input> {
- pub func_decl: ast::MethodDecl<'input, spirv::Word>,
- pub spirv_decl: SpirvMethodDecl<'input>,
+ pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
pub globals: Vec<ast::Variable<spirv::Word>>,
pub body: Option<Vec<ExpandedStatement>>,
+ pub uses_shared_mem: bool,
import_as: Option<String>,
tuning: Vec<ast::TuningDirective>,
}
@@ -7671,73 +7633,11 @@ fn should_convert_relaxed_dst(
}
}
-impl<'a> ast::MethodDecl<'a, &'a str> {
+impl<'a> ast::MethodDeclaration<'a, &'a str> {
fn name(&self) -> &'a str {
- match self {
- ast::MethodDecl::Kernel { name, .. } => name,
- ast::MethodDecl::Func(_, name, _) => name,
- }
- }
-}
-
-struct SpirvMethodDecl<'input> {
- input: Vec<ast::Variable<spirv::Word>>,
- output: Vec<ast::Variable<spirv::Word>>,
- name: MethodName<'input>,
- uses_shared_mem: bool,
-}
-
-impl<'input> SpirvMethodDecl<'input> {
- fn new(ast_decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
- let (input, output) = match ast_decl {
- ast::MethodDecl::Kernel { in_args, .. } => {
- let spirv_input = in_args
- .iter()
- .map(|var| ast::Variable {
- name: var.name,
- align: var.align,
- v_type: var.v_type.clone(),
- state_space: var.state_space,
- array_init: var.array_init.clone(),
- })
- .collect();
- (spirv_input, Vec::new())
- }
- ast::MethodDecl::Func(out_args, _, in_args) => {
- let (param_output, non_param_output): (Vec<_>, Vec<_>) = out_args
- .iter()
- .partition(|var| var.state_space == ast::StateSpace::Param);
- let spirv_output = non_param_output
- .into_iter()
- .cloned()
- .map(|var| ast::Variable {
- name: var.name,
- align: var.align,
- v_type: var.v_type.clone(),
- state_space: var.state_space,
- array_init: var.array_init.clone(),
- })
- .collect();
- let spirv_input = param_output
- .into_iter()
- .cloned()
- .chain(in_args.iter().cloned())
- .map(|var| ast::Variable {
- name: var.name,
- align: var.align,
- v_type: var.v_type.clone(),
- state_space: var.state_space,
- array_init: var.array_init.clone(),
- })
- .collect();
- (spirv_input, spirv_output)
- }
- };
- SpirvMethodDecl {
- input,
- output,
- name: MethodName::new(ast_decl),
- uses_shared_mem: false,
+ match self.name {
+ ast::MethodName::Kernel(name) => name,
+ ast::MethodName::Func(name) => name,
}
}
}
diff --git a/zluda_dump/src/lib.rs b/zluda_dump/src/lib.rs
index 4ea449c..f168930 100644
--- a/zluda_dump/src/lib.rs
+++ b/zluda_dump/src/lib.rs
@@ -191,7 +191,10 @@ unsafe fn record_module_image(module: CUmodule, image: &str) {
unsafe fn try_dump_module_image(image: &str) -> Result<(), Box<dyn Error>> {
let mut dump_path = get_dump_dir()?;
- dump_path.push(format!("module_{:04}.ptx", MODULES.as_ref().unwrap().len() - 1));
+ dump_path.push(format!(
+ "module_{:04}.ptx",
+ MODULES.as_ref().unwrap().len() - 1
+ ));
let mut file = File::create(dump_path)?;
file.write_all(image.as_bytes())?;
Ok(())
@@ -217,10 +220,15 @@ unsafe fn to_str<T>(image: *const T) -> Option<&'static str> {
fn directive_to_kernel(dir: &ast::Directive<ast::ParsedArgParams>) -> Option<(String, Vec<usize>)> {
match dir {
ast::Directive::Method(ast::Function {
- func_directive: ast::MethodDecl::Kernel { name, in_args },
+ func_directive:
+ ast::MethodDeclaration {
+ name: ast::MethodName::Kernel(name),
+ input_arguments,
+ ..
+ },
..
}) => {
- let arg_sizes = in_args
+ let arg_sizes = input_arguments
.iter()
.map(|arg| ast::Type::from(arg.v_type.clone()).size_of())
.collect();