about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author Andrzej Janik <[email protected]> 2021-06-04 00:48:51 +0200
committer Andrzej Janik <[email protected]> 2021-06-04 00:48:51 +0200
commit f70abd065bc7651f75b5f41475a862f509fd68bd (patch)
tree 8ef96c306c51c5cf97eb7aefec4b47b0850158ce
parent 2e6f7e3fdc6176279644f7bd02f8fb09195d6298 (diff)
download ZLUDA-f70abd065bc7651f75b5f41475a862f509fd68bd.tar.gz
ZLUDA-f70abd065bc7651f75b5f41475a862f509fd68bd.zip
Continue attempts at fixing code emission for method args
-rw-r--r-- ptx/src/ast.rs 1
-rw-r--r-- ptx/src/ptx.lalrpop 4
-rw-r--r-- ptx/src/translate.rs 323
3 files changed, 224 insertions, 104 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 3ad61e5..a0bb023 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -96,6 +96,7 @@ pub struct MethodDeclaration<'input, ID> {
pub return_arguments: Vec<Variable<ID>>,
pub name: MethodName<'input, ID>,
pub input_arguments: Vec<Variable<ID>>,
+ pub shared_mem: Option<Variable<ID>>,
}
pub struct Function<'a, ID, S> {
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 2253f85..e8370cd 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -392,12 +392,12 @@ MethodDeclaration: ast::MethodDeclaration<'input, &'input str> = {
".entry" <name:ExtendedID> <input_arguments:KernelArguments> => {
let return_arguments = Vec::new();
let name = ast::MethodName::Kernel(name);
- ast::MethodDeclaration{ return_arguments, name, input_arguments }
+ ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None }
},
".func" <return_arguments:FnArguments?> <name:ExtendedID> <input_arguments:FnArguments> => {
let return_arguments = return_arguments.unwrap_or_else(|| Vec::new());
let name = ast::MethodName::Func(name);
- ast::MethodDeclaration{ return_arguments, name, input_arguments }
+ ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None }
}
};
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 90a28b7..6d5d5bc 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -562,7 +562,6 @@ fn emit_directives<'input>(
call_map,
&directives,
kernel_info,
- f.uses_shared_mem,
)?;
for t in f.tuning.iter() {
match *t {
@@ -1038,10 +1037,9 @@ fn emit_function_header<'a>(
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
direcitves: &[Directive],
kernel_info: &mut HashMap<String, KernelInfo>,
- uses_shared_mem: bool,
) -> Result<spirv::Word, TranslateError> {
if let ast::MethodName::Kernel(name) = func_decl.name {
- let input_args = if !uses_shared_mem {
+ let input_args = if func_decl.shared_mem.is_none() {
func_decl.input_arguments.as_slice()
} else {
&func_decl.input_arguments[0..func_decl.input_arguments.len() - 1]
@@ -1054,7 +1052,7 @@ fn emit_function_header<'a>(
name.to_string(),
KernelInfo {
arguments_sizes: args_lens,
- uses_shared_mem: uses_shared_mem,
+ uses_shared_mem: func_decl.shared_mem.is_some(),
},
);
}
@@ -1218,7 +1216,7 @@ fn rename_fn_params<'a, 'b>(
) -> Vec<ast::Variable<spirv::Word>> {
args.iter()
.map(|a| ast::Variable {
- name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), false),
+ name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true),
v_type: a.v_type.clone(),
state_space: a.state_space,
align: a.align,
@@ -1245,7 +1243,6 @@ fn to_ssa<'input, 'b>(
globals: Vec::new(),
import_as: None,
tuning,
- uses_shared_mem: false,
})
}
};
@@ -1276,7 +1273,6 @@ fn to_ssa<'input, 'b>(
body: Some(f_body),
import_as: None,
tuning,
- uses_shared_mem: false,
})
}
@@ -1529,18 +1525,8 @@ fn convert_to_typed_statements(
match s {
Statement::Instruction(inst) => match inst {
ast::Instruction::Call(call) => {
- // TODO: error out if lengths don't match
- let fn_def = (**fn_defs.get_fn_decl(call.func)?).borrow();
- let return_arguments =
- to_resolved_fn_args(call.ret_params, &*fn_def.return_arguments);
- let input_arguments =
- to_resolved_fn_args(call.param_list, &*fn_def.input_arguments);
- let resolved_call = ResolvedCall {
- uniform: call.uniform,
- return_arguments,
- name: call.func,
- input_arguments,
- };
+ let resolver = fn_defs.get_fn_sig_resolver(call.func)?;
+ let resolved_call = resolver.resolve_in_spirv_repr(call)?;
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let reresolved_call = resolved_call.visit(&mut visitor)?;
visitor.func.push(reresolved_call);
@@ -1683,6 +1669,7 @@ fn to_ptx_impl_atomic_call(
array_init: Vec::new(),
},
],
+ shared_mem: None,
};
let func = Function {
func_decl: Rc::new(RefCell::new(func_decl)),
@@ -1690,7 +1677,6 @@ fn to_ptx_impl_atomic_call(
body: None,
import_as: Some(entry.key().clone()),
tuning: Vec::new(),
- uses_shared_mem: false,
};
entry.insert(Directive::Method(func));
fn_id
@@ -1772,6 +1758,7 @@ fn to_ptx_impl_bfe_call(
array_init: Vec::new(),
},
],
+ shared_mem: None,
};
let func = Function {
func_decl: Rc::new(RefCell::new(func_decl)),
@@ -1779,7 +1766,6 @@ fn to_ptx_impl_bfe_call(
body: None,
import_as: Some(entry.key().clone()),
tuning: Vec::new(),
- uses_shared_mem: false,
};
entry.insert(Directive::Method(func));
fn_id
@@ -1871,6 +1857,7 @@ fn to_ptx_impl_bfi_call(
array_init: Vec::new(),
},
],
+ shared_mem: None,
};
let func = Function {
func_decl: Rc::new(RefCell::new(func_decl)),
@@ -1878,7 +1865,6 @@ fn to_ptx_impl_bfi_call(
body: None,
import_as: Some(entry.key().clone()),
tuning: Vec::new(),
- uses_shared_mem: false,
};
entry.insert(Directive::Method(func));
fn_id
@@ -2009,42 +1995,44 @@ fn normalize_predicates(
Ok(result)
}
+/*
+ How we handle arguments:
+ - input .params
+ .param .b64 in_arg
+ get turned into this SPIR-V:
+ %1 = OpFunctionParameter %ulong
+ %2 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %1
+ We do this for two reasons. One, common treatment for argument-declared
+ .param variables and .param variables inside a function (we assume that
+ at SPIR-V level every .param is a pointer in Function storage class). Two,
+ PTX devs in their infinite wisdom decided that .reg arguments are writable.
+ - input .regs
+ .reg .b64 in_arg
+ get turned into this SPIR-V:
+ %1 = OpFunctionParameter %ulong
+ %2 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %1
+ with the difference that %2 is defined as a variable and not a temporary
+ - output .regs
+ .reg .b64 out_arg
+ get just a variable declaration:
+ %2 = OpVariable %_ptr_Function_ulong Function
+ - output .params
+ .param .b64 out_arg
+ get treated the same as input .params, because there's no difference
+*/
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver,
fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
- for arg in fn_decl.return_arguments.iter() {
- result.push(Statement::Variable(ast::Variable {
- align: arg.align,
- v_type: arg.v_type.clone(),
- state_space: arg.state_space,
- name: arg.name,
- array_init: arg.array_init.clone(),
- }));
- }
for arg in fn_decl.input_arguments.iter_mut() {
- let typ = arg.v_type.clone();
- let state_space = arg.state_space;
- let new_id = id_def.register_intermediate(Some((typ.clone(), state_space)));
- result.push(Statement::Variable(ast::Variable {
- align: arg.align,
- v_type: arg.v_type.clone(),
- state_space: arg.state_space,
- name: arg.name,
- array_init: Vec::new(),
- }));
- result.push(Statement::StoreVar(StoreVarDetails {
- arg: ast::Arg2St {
- src1: arg.name,
- src2: new_id,
- },
- state_space,
- typ,
- member_index: None,
- }));
- arg.name = new_id;
+ insert_mem_ssa_argument(id_def, &mut result, arg);
+ }
+ for arg in fn_decl.return_arguments.iter() {
+ insert_mem_ssa_argument_reg_return(&mut result, arg);
}
for s in func {
match s {
@@ -2054,22 +2042,26 @@ fn insert_mem_ssa_statements<'a, 'b>(
Statement::Instruction(inst) => match inst {
ast::Instruction::Ret(d) => {
// TODO: handle multiple output args
- if let &[out_param] = &fn_decl.return_arguments.as_slice() {
- let (typ, space, _) = id_def.get_typed(out_param.name)?;
- let new_id = id_def.register_intermediate(Some((typ.clone(), space)));
- result.push(Statement::LoadVar(LoadVarDetails {
- arg: ast::Arg2 {
- dst: new_id,
- src: out_param.name,
- },
- // TODO: ret with stateful conversion
- state_space: new_todo!(),
- typ: typ.clone(),
- member_index: None,
- }));
- result.push(Statement::RetValue(d, new_id));
- } else {
- result.push(Statement::Instruction(ast::Instruction::Ret(d)))
+ match &fn_decl.return_arguments[..] {
+ [return_reg] => {
+ let new_id = id_def.register_intermediate(Some((
+ return_reg.v_type.clone(),
+ ast::StateSpace::Reg,
+ )));
+ result.push(Statement::LoadVar(LoadVarDetails {
+ arg: ast::Arg2 {
+ dst: new_id,
+ src: return_reg.name,
+ },
+ // TODO: ret with stateful conversion
+ state_space: ast::StateSpace::Reg,
+ typ: return_reg.v_type.clone(),
+ member_index: None,
+ }));
+ result.push(Statement::RetValue(d, new_id));
+ }
+ [] => result.push(Statement::Instruction(ast::Instruction::Ret(d))),
+ _ => unimplemented!(),
}
}
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
@@ -2107,6 +2099,43 @@ fn insert_mem_ssa_statements<'a, 'b>(
Ok(result)
}
+fn insert_mem_ssa_argument(
+ id_def: &mut NumericIdResolver,
+ func: &mut Vec<TypedStatement>,
+ arg: &mut ast::Variable<spirv::Word>,
+) {
+ let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
+ func.push(Statement::Variable(ast::Variable {
+ align: arg.align,
+ v_type: arg.v_type.clone(),
+ state_space: ast::StateSpace::Reg,
+ name: arg.name,
+ array_init: Vec::new(),
+ }));
+ func.push(Statement::StoreVar(StoreVarDetails {
+ arg: ast::Arg2St {
+ src1: arg.name,
+ src2: new_id,
+ },
+ typ: arg.v_type.clone(),
+ member_index: None,
+ }));
+ arg.name = new_id;
+}
+
+fn insert_mem_ssa_argument_reg_return(
+ func: &mut Vec<TypedStatement>,
+ arg: &ast::Variable<spirv::Word>,
+) {
+ func.push(Statement::Variable(ast::Variable {
+ align: arg.align,
+ v_type: arg.v_type.clone(),
+ state_space: arg.state_space,
+ name: arg.name,
+ array_init: arg.array_init.clone(),
+ }));
+}
+
trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
fn visit(
self,
@@ -2202,7 +2231,6 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
src1: symbol,
src2: generated_id,
},
- state_space: ast::StateSpace::Reg,
typ: var_type,
member_index: member_index.map(|(idx, _)| idx),
}));
@@ -4162,10 +4190,10 @@ fn emit_load_var(
Ok(())
}
-fn normalize_identifiers<'a, 'b>(
- id_defs: &mut FnStringIdResolver<'a, 'b>,
- fn_defs: &GlobalFnDeclResolver<'a, 'b>,
- func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
+fn normalize_identifiers<'input, 'b>(
+ id_defs: &mut FnStringIdResolver<'input, 'b>,
+ fn_defs: &GlobalFnDeclResolver<'input, 'b>,
+ func: Vec<ast::Statement<ast::ParsedArgParams<'input>>>,
) -> Result<Vec<NormalizedStatement>, TranslateError> {
for s in func.iter() {
match s {
@@ -4796,12 +4824,92 @@ impl SpecialRegistersMap {
}
}
+struct FnSigMapper<'input> {
+ // true - stays as return argument
+ // false - is moved to input argument
+ return_param_args: Vec<bool>,
+ func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
+}
+
+impl<'input> FnSigMapper<'input> {
+ fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, spirv::Word>) -> Self {
+ let return_param_args = method
+ .return_arguments
+ .iter()
+ .map(|a| a.state_space != ast::StateSpace::Param)
+ .collect::<Vec<_>>();
+ let mut new_return_arguments = Vec::new();
+ for arg in method.return_arguments.into_iter() {
+ if arg.state_space == ast::StateSpace::Param {
+ method.input_arguments.push(arg);
+ } else {
+ new_return_arguments.push(arg);
+ }
+ }
+ method.return_arguments = new_return_arguments;
+ FnSigMapper {
+ return_param_args,
+ func_decl: Rc::new(RefCell::new(method)),
+ }
+ }
+
+ fn resolve_in_spirv_repr(
+ &self,
+ call_inst: ast::CallInst<NormalizedArgParams>,
+ ) -> Result<ResolvedCall<NormalizedArgParams>, TranslateError> {
+ let func_decl = (*self.func_decl).borrow();
+ let mut return_arguments = Vec::new();
+ let mut input_arguments = call_inst
+ .param_list
+ .into_iter()
+ .zip(func_decl.input_arguments.iter())
+ .map(|(id, var)| (id, var.v_type.clone(), var.state_space))
+ .collect::<Vec<_>>();
+ let mut func_decl_return_iter = func_decl.return_arguments.iter();
+ let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter();
+ for (idx, id) in call_inst.ret_params.iter().enumerate() {
+ let stays_as_return = match self.return_param_args.get(idx) {
+ Some(x) => *x,
+ None => return Err(TranslateError::MismatchedType),
+ };
+ if stays_as_return {
+ if let Some(var) = func_decl_return_iter.next() {
+ return_arguments.push((*id, var.v_type.clone(), var.state_space));
+ } else {
+ return Err(TranslateError::MismatchedType);
+ }
+ } else {
+ if let Some(var) = func_decl_input_iter.next() {
+ input_arguments.push((
+ ast::Operand::Reg(*id),
+ var.v_type.clone(),
+ var.state_space,
+ ));
+ } else {
+ return Err(TranslateError::MismatchedType);
+ }
+ }
+ }
+ if return_arguments.len() != func_decl.return_arguments.len()
+ || input_arguments.len() != func_decl.input_arguments.len()
+ {
+ return Err(TranslateError::MismatchedType);
+ }
+ Ok(ResolvedCall {
+ return_arguments,
+ input_arguments,
+ uniform: call_inst.uniform,
+ name: call_inst.func,
+ })
+ }
+}
+
struct GlobalStringIdResolver<'input> {
current_id: spirv::Word,
variables: HashMap<Cow<'input, str>, spirv::Word>,
variables_type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: SpecialRegistersMap,
- fns: HashMap<spirv::Word, Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>>,
+ fns: HashMap<spirv::Word, FnSigMapper<'input>>,
}
impl<'input> GlobalStringIdResolver<'input> {
@@ -4885,45 +4993,36 @@ impl<'input> GlobalStringIdResolver<'input> {
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
ast::MethodName::Func(_) => ast::MethodName::Func(name_id),
};
- let new_fn_decl = Rc::new(RefCell::new(ast::MethodDeclaration {
+ let fn_decl = ast::MethodDeclaration {
return_arguments,
name,
input_arguments,
- }));
- self.fns.insert(name_id, Rc::clone(&new_fn_decl));
+ shared_mem: None,
+ };
+ let new_fn_decl = if !fn_decl.name.is_kernel() {
+ let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl);
+ let new_fn_decl = resolver.func_decl.clone();
+ self.fns.insert(name_id, resolver);
+ new_fn_decl
+ } else {
+ Rc::new(RefCell::new(fn_decl))
+ };
Ok((
fn_resolver,
- GlobalFnDeclResolver {
- variables: &self.variables,
- fns: &self.fns,
- },
+ GlobalFnDeclResolver { fns: &self.fns },
new_fn_decl,
))
}
}
pub struct GlobalFnDeclResolver<'input, 'a> {
- variables: &'a HashMap<Cow<'input, str>, spirv::Word>,
- fns: &'a HashMap<spirv::Word, Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>>,
+ fns: &'a HashMap<spirv::Word, FnSigMapper<'input>>,
}
impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
- fn get_fn_decl(
- &self,
- id: spirv::Word,
- ) -> Result<&Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>, TranslateError> {
+ fn get_fn_sig_resolver(&self, id: spirv::Word) -> Result<&FnSigMapper<'input>, TranslateError> {
self.fns.get(&id).ok_or(TranslateError::UnknownSymbol)
}
-
- fn get_fn_decl_str(
- &self,
- id: &str,
- ) -> Result<&'a Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>, TranslateError> {
- match self.variables.get(id).map(|var_id| self.fns.get(var_id)) {
- Some(Some(fn_d)) => Ok(fn_d),
- _ => Err(TranslateError::UnknownSymbol),
- }
- }
}
struct FnStringIdResolver<'input, 'b> {
@@ -5209,7 +5308,6 @@ struct LoadVarDetails {
struct StoreVarDetails {
arg: ast::Arg2St<ExpandedArgParams>,
typ: ast::Type,
- state_space: ast::StateSpace,
member_index: Option<u8>,
}
@@ -5300,7 +5398,7 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
self,
visitor: &mut V,
) -> Result<ResolvedCall<To>, TranslateError> {
- let ret_params = self
+ let return_arguments = self
.return_arguments
.into_iter()
.map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
@@ -5324,7 +5422,7 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
},
None,
)?;
- let param_list = self
+ let input_arguments = self
.input_arguments
.into_iter()
.map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
@@ -5342,9 +5440,9 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
.collect::<Result<Vec<_>, _>>()?;
Ok(ResolvedCall {
uniform: self.uniform,
- return_arguments: ret_params,
+ return_arguments,
name: func,
- input_arguments: param_list,
+ input_arguments,
})
}
}
@@ -5485,7 +5583,6 @@ struct Function<'input> {
pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
pub globals: Vec<ast::Variable<spirv::Word>>,
pub body: Option<Vec<ExpandedStatement>>,
- pub uses_shared_mem: bool,
import_as: Option<String>,
tuning: Vec<ast::TuningDirective>,
}
@@ -7185,6 +7282,19 @@ fn default_implicit_conversion_space(
},
_ => Err(TranslateError::MismatchedType),
}
+ } else if instruction_space.is_compatible(ast::StateSpace::Reg) {
+ match instruction_type {
+ ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
+ if operand_space == *instruction_ptr_space =>
+ {
+ if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
+ } else {
+ Ok(None)
+ }
+ }
+ _ => Err(TranslateError::MismatchedType),
+ }
} else {
Err(TranslateError::MismatchedType)
}
@@ -7432,6 +7542,15 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> {
}
}
+impl<'input, ID> ast::MethodName<'input, ID> {
+ fn is_kernel(&self) -> bool {
+ match self {
+ ast::MethodName::Kernel(..) => true,
+ ast::MethodName::Func(..) => false,
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;