aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-06-05 00:46:41 +0200
committerAndrzej Janik <[email protected]>2021-06-05 00:49:27 +0200
commit90960fd9239b9972dfffbff6ce26ce2642ec50af (patch)
tree212660fa421d73ba1ea9a73514477a64a47240d8 /ptx
parentf70abd065bc7651f75b5f41475a862f509fd68bd (diff)
downloadZLUDA-90960fd9239b9972dfffbff6ce26ce2642ec50af.tar.gz
ZLUDA-90960fd9239b9972dfffbff6ce26ce2642ec50af.zip
Fix method arg load generation
Diffstat (limited to 'ptx')
-rw-r--r--ptx/src/translate.rs67
1 files changed, 47 insertions, 20 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 6d5d5bc..c4efe55 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1059,7 +1059,7 @@ fn emit_function_header<'a>(
let (ret_type, func_type) = get_function_type(
builder,
map,
- &func_decl.input_arguments,
+ func_decl.effective_input_arguments().map(|(_, typ)| typ),
&func_decl.return_arguments,
);
let fn_id = match func_decl.name {
@@ -1120,9 +1120,9 @@ fn emit_function_header<'a>(
}
}
*/
- for input in &func_decl.input_arguments {
- let result_type = map.get_or_add(builder, SpirvType::new(input.v_type.clone()));
- builder.function_parameter(Some(input.name), result_type)?;
+ for (name, typ) in func_decl.effective_input_arguments() {
+ let result_type = map.get_or_add(builder, typ);
+ builder.function_parameter(Some(name), result_type)?;
}
Ok(fn_id)
}
@@ -1233,7 +1233,7 @@ fn to_ssa<'input, 'b>(
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
tuning: Vec<ast::TuningDirective>,
) -> Result<Function<'input>, TranslateError> {
- deparamize_function_decl(&func_decl)?;
+ //deparamize_function_decl(&func_decl)?;
let f_body = match f_body {
Some(vec) => vec,
None => {
@@ -1997,30 +1997,38 @@ fn normalize_predicates(
/*
How do we handle arguments:
- - input .params
+ - input .params in kernels
.param .b64 in_arg
get turned into this SPIR-V:
%1 = OpFunctionParameter %ulong
- %2 = OpVariable %%_ptr_Function_ulong Function
+ %2 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %1
We do this for two reasons. One, common treatment for argument-declared
.param variables and .param variables inside function (we assume that
- at SPIR-V level every .param is a pointer in Function storage class). Two,
- PTX devs in their infinite wisdom decided that .reg arguments are writable
+ at SPIR-V level every .param is a pointer in Function storage class)
+ - input .params in functions
+ .param .b64 in_arg
+ get turned into this SPIR-V:
+ %1 = OpFunctionParameter %_ptr_Function_ulong
- input .regs
.reg .b64 in_arg
- get turned into this SPIR-V:
+ get turned into the same SPIR-V as kernel .params:
%1 = OpFunctionParameter %ulong
- %2 = OpVariable %%_ptr_Function_ulong Function
+ %2 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %1
- with the difference that %2 is defined as a variable and not temp
- output .regs
.reg .b64 out_arg
get just a variable declaration:
%2 = OpVariable %%_ptr_Function_ulong Function
- - output .params
- .param .b64 out_arg
- get treated the same as input .params, because there's no difference
+ - output .params don't exist, they have been moved to input positions
+ by an earlier pass
+ Distinguishing betweem kernel .params and function .params is not the
+ cleanest solution. Alternatively, we could "deparamize" all kernel .param
+ arguments by turning them into .reg arguments like this:
+ .param .b64 arg -> .reg ptr<.b64,.param> arg
+ This has the massive downside that this transformation would have to run
+ very early and would muddy up already difficult code. It's simpler to just
+ have an if here
*/
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
@@ -2029,7 +2037,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.input_arguments.iter_mut() {
- insert_mem_ssa_argument(id_def, &mut result, arg);
+ insert_mem_ssa_argument(id_def, &mut result, arg, fn_decl.name.is_kernel());
}
for arg in fn_decl.return_arguments.iter() {
insert_mem_ssa_argument_reg_return(&mut result, arg);
@@ -2103,7 +2111,11 @@ fn insert_mem_ssa_argument(
id_def: &mut NumericIdResolver,
func: &mut Vec<TypedStatement>,
arg: &mut ast::Variable<spirv::Word>,
+ is_kernel: bool,
) {
+ if !is_kernel && arg.state_space == ast::StateSpace::Param {
+ return;
+ }
let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
func.push(Statement::Variable(ast::Variable {
align: arg.align,
@@ -2559,14 +2571,12 @@ fn insert_implicit_conversions_impl(
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- spirv_input: &[ast::Variable<spirv::Word>],
+ spirv_input: impl ExactSizeIterator<Item = SpirvType>,
spirv_output: &[ast::Variable<spirv::Word>],
) -> (spirv::Word, spirv::Word) {
map.get_or_add_fn(
builder,
- spirv_input
- .iter()
- .map(|var| SpirvType::new(var.v_type.clone())),
+ spirv_input,
spirv_output
.iter()
.map(|var| SpirvType::new(var.v_type.clone())),
@@ -7542,6 +7552,23 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> {
}
}
+impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
+ fn effective_input_arguments(
+ &self,
+ ) -> impl ExactSizeIterator<Item = (spirv::Word, SpirvType)> + '_ {
+ let is_kernel = self.name.is_kernel();
+ self.input_arguments.iter().map(move |arg| {
+ if !is_kernel {
+ let spirv_type =
+ SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
+ (arg.name, spirv_type)
+ } else {
+ (arg.name, SpirvType::new(arg.v_type.clone()))
+ }
+ })
+ }
+}
+
impl<'input, ID> ast::MethodName<'input, ID> {
fn is_kernel(&self) -> bool {
match self {