diff options
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r-- | ptx/src/translate.rs | 54 |
1 files changed, 32 insertions, 22 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 50f37fb..76a2714 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,7 +1,7 @@ use crate::ast;
use half::f16;
use rspirv::dr;
-use std::{collections::BTreeSet, borrow::Cow, ffi::CString, hash::Hash, iter, mem};
+use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem};
use std::{
collections::{hash_map, HashMap, HashSet},
convert::TryInto,
@@ -1254,16 +1254,11 @@ fn expand_fn_params<'a, 'b>( args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
args.map(|a| {
- let mut var_type = a.v_type.to_func_type();
- let mut is_variable = false;
- var_type = match a.v_type {
- ast::FnArgumentType::Reg(_) => {
- is_variable = true;
- var_type
- }
- ast::FnArgumentType::Shared => var_type.param_pointer_to(ast::LdStateSpace::Shared)?,
- ast::FnArgumentType::Param(_) => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
+ let is_variable = match a.v_type {
+ ast::FnArgumentType::Reg(_) => true,
+ _ => false,
};
+ let var_type = a.v_type.to_func_type();
Ok(ast::FnArgument {
name: fn_resolver.add_def(a.name, Some(var_type), is_variable),
v_type: a.v_type.clone(),
@@ -1301,8 +1296,12 @@ fn to_ssa<'input, 'b>( convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?;
let typed_statements =
convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
- let ssa_statements =
- insert_mem_ssa_statements(typed_statements, &mut numeric_id_defs, &mut spirv_decl)?;
+ let ssa_statements = insert_mem_ssa_statements(
+ typed_statements,
+ &mut numeric_id_defs,
+ &f_args,
+ &mut spirv_decl,
+ )?;
let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
@@ -1925,11 +1924,16 @@ fn normalize_predicates( fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver,
+ ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>,
fn_decl: &mut SpirvMethodDecl,
) -> Result<Vec<TypedStatement>, TranslateError> {
+ let is_func = match ast_fn_decl {
+ ast::MethodDecl::Func(..) => true,
+ ast::MethodDecl::Kernel { .. } => false,
+ };
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.output.iter() {
- match type_to_variable_type(&arg.v_type)? {
+ match type_to_variable_type(&arg.v_type, is_func)? {
Some(var_type) => {
result.push(Statement::Variable(ast::Variable {
align: arg.align,
@@ -1941,25 +1945,25 @@ fn insert_mem_ssa_statements<'a, 'b>( None => return Err(TranslateError::Unreachable),
}
}
- for arg in fn_decl.input.iter_mut() {
- match type_to_variable_type(&arg.v_type)? {
+ for (index, spirv_arg) in fn_decl.input.iter_mut().enumerate() {
+ match type_to_variable_type(&spirv_arg.v_type, is_func)? {
Some(var_type) => {
- let typ = arg.v_type.clone();
+ let typ = spirv_arg.v_type.clone();
let new_id = id_def.new_non_variable(Some(typ.clone()));
result.push(Statement::Variable(ast::Variable {
- align: arg.align,
+ align: spirv_arg.align,
v_type: var_type,
- name: arg.name,
- array_init: arg.array_init.clone(),
+ name: spirv_arg.name,
+ array_init: spirv_arg.array_init.clone(),
}));
result.push(Statement::StoreVar(
ast::Arg2St {
- src1: arg.name,
+ src1: spirv_arg.name,
src2: new_id,
},
typ,
));
- arg.name = new_id;
+ spirv_arg.name = new_id;
}
None => {}
}
@@ -2015,7 +2019,10 @@ fn insert_mem_ssa_statements<'a, 'b>( Ok(result)
}
-fn type_to_variable_type(t: &ast::Type) -> Result<Option<ast::VariableType>, TranslateError> {
+fn type_to_variable_type(
+ t: &ast::Type,
+ is_func: bool,
+) -> Result<Option<ast::VariableType>, TranslateError> {
Ok(match t {
ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))),
ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector(
@@ -2031,6 +2038,9 @@ fn type_to_variable_type(t: &ast::Type) -> Result<Option<ast::VariableType>, Tra len.clone(),
))),
ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => {
+ if is_func {
+ return Ok(None);
+ }
Some(ast::VariableType::Reg(ast::VariableRegType::Pointer(
scalar_type
.clone()
|