From de36980cbe6ea6387555d34c8485db3bf04e1968 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Thu, 19 Nov 2020 01:57:02 +0100 Subject: Fix remaining bugs --- ptx/src/translate.rs | 54 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 50f37fb..76a2714 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,7 +1,7 @@ use crate::ast; use half::f16; use rspirv::dr; -use std::{collections::BTreeSet, borrow::Cow, ffi::CString, hash::Hash, iter, mem}; +use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem}; use std::{ collections::{hash_map, HashMap, HashSet}, convert::TryInto, @@ -1254,16 +1254,11 @@ fn expand_fn_params<'a, 'b>( args: impl Iterator>, ) -> Result>, TranslateError> { args.map(|a| { - let mut var_type = a.v_type.to_func_type(); - let mut is_variable = false; - var_type = match a.v_type { - ast::FnArgumentType::Reg(_) => { - is_variable = true; - var_type - } - ast::FnArgumentType::Shared => var_type.param_pointer_to(ast::LdStateSpace::Shared)?, - ast::FnArgumentType::Param(_) => var_type.param_pointer_to(ast::LdStateSpace::Param)?, + let is_variable = match a.v_type { + ast::FnArgumentType::Reg(_) => true, + _ => false, }; + let var_type = a.v_type.to_func_type(); Ok(ast::FnArgument { name: fn_resolver.add_def(a.name, Some(var_type), is_variable), v_type: a.v_type.clone(), @@ -1301,8 +1296,12 @@ fn to_ssa<'input, 'b>( convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?; let typed_statements = convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; - let ssa_statements = - insert_mem_ssa_statements(typed_statements, &mut numeric_id_defs, &mut spirv_decl)?; + let ssa_statements = insert_mem_ssa_statements( + typed_statements, + &mut numeric_id_defs, + &f_args, + &mut spirv_decl, + )?; let mut numeric_id_defs = numeric_id_defs.finish(); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = @@ -1925,11 +1924,16 @@ fn normalize_predicates( fn insert_mem_ssa_statements<'a, 'b>( func: Vec, id_def: &mut NumericIdResolver, + ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>, fn_decl: &mut SpirvMethodDecl, ) -> Result, TranslateError> { + let is_func = match ast_fn_decl { + ast::MethodDecl::Func(..) => true, + ast::MethodDecl::Kernel { .. } => false, + }; let mut result = Vec::with_capacity(func.len()); for arg in fn_decl.output.iter() { - match type_to_variable_type(&arg.v_type)? { + match type_to_variable_type(&arg.v_type, is_func)? { Some(var_type) => { result.push(Statement::Variable(ast::Variable { align: arg.align, @@ -1941,25 +1945,25 @@ fn insert_mem_ssa_statements<'a, 'b>( None => return Err(TranslateError::Unreachable), } } - for arg in fn_decl.input.iter_mut() { - match type_to_variable_type(&arg.v_type)? { + for (index, spirv_arg) in fn_decl.input.iter_mut().enumerate() { + match type_to_variable_type(&spirv_arg.v_type, is_func)? { Some(var_type) => { - let typ = arg.v_type.clone(); + let typ = spirv_arg.v_type.clone(); let new_id = id_def.new_non_variable(Some(typ.clone())); result.push(Statement::Variable(ast::Variable { - align: arg.align, + align: spirv_arg.align, v_type: var_type, - name: arg.name, - array_init: arg.array_init.clone(), + name: spirv_arg.name, + array_init: spirv_arg.array_init.clone(), })); result.push(Statement::StoreVar( ast::Arg2St { - src1: arg.name, + src1: spirv_arg.name, src2: new_id, }, typ, )); - arg.name = new_id; + spirv_arg.name = new_id; } None => {} } @@ -2015,7 +2019,10 @@ fn insert_mem_ssa_statements<'a, 'b>( Ok(result) } -fn type_to_variable_type(t: &ast::Type) -> Result, TranslateError> { +fn type_to_variable_type( + t: &ast::Type, + is_func: bool, +) -> Result, TranslateError> { Ok(match t { ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))), ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector( @@ -2031,6 +2038,9 @@ fn type_to_variable_type(t: &ast::Type) -> Result, Tra len.clone(), ))), ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => { + if is_func { + return Ok(None); + } Some(ast::VariableType::Reg(ast::VariableRegType::Pointer( scalar_type .clone() -- cgit v1.2.3