diff options
-rw-r--r-- | ptx/src/ast.rs | 6 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 38 | ||||
-rw-r--r-- | ptx/src/translate.rs | 119 |
3 files changed, 87 insertions, 76 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index a0bb023..5432207 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -82,8 +82,8 @@ pub struct Module<'a> { } pub enum Directive<'a, P: ArgParams> { - Variable(Variable<P::Id>), - Method(Function<'a, &'a str, Statement<P>>), + Variable(LinkingDirective, Variable<P::Id>), + Method(LinkingDirective, Function<'a, &'a str, Statement<P>>), } #[derive(Hash, PartialEq, Eq, Copy, Clone)] @@ -96,7 +96,7 @@ pub struct MethodDeclaration<'input, ID> { pub return_arguments: Vec<Variable<ID>>, pub name: MethodName<'input, ID>, pub input_arguments: Vec<Variable<ID>>, - pub shared_mem: Option<Variable<ID>>, + pub shared_mem: Option<ID>, } pub struct Function<'a, ID, S> { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index e8370cd..b697317 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -343,10 +343,16 @@ TargetSpecifier = { Directive: Option<ast::Directive<'input, ast::ParsedArgParams<'input>>> = { AddressSize => None, - <f:Function> => Some(ast::Directive::Method(f)), + <f:Function> => { + let (linking, func) = f; + Some(ast::Directive::Method(linking, func)) + }, File => None, Section => None, - <v:ModuleVariable> ";" => Some(ast::Directive::Variable(v)), + <v:ModuleVariable> ";" => { + let (linking, var) = v; + Some(ast::Directive::Variable(linking, var)) + }, ! => { let err = <>; errors.push(err.error); @@ -358,11 +364,13 @@ AddressSize = { ".address_size" U8Num }; -Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = { - LinkingDirectives +Function: (ast::LinkingDirective, ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>>) = { + <linking:LinkingDirectives> <func_directive:MethodDeclaration> <tuning:TuningDirective*> - <body:FunctionBody> => ast::Function{<>} + <body:FunctionBody> => { + (linking, ast::Function{func_directive, tuning, body}) + } }; LinkingDirective: ast::LinkingDirective = { @@ -598,18 +606,18 @@ SharedVariable: ast::Variable<&'input str> = { } } -ModuleVariable: ast::Variable<&'input str> = { - LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => { +ModuleVariable: (ast::LinkingDirective, ast::Variable<&'input str>) = { + <linking:LinkingDirectives> ".global" <def:GlobalVariableDefinitionNoArray> => { let (align, v_type, name, array_init) = def; let state_space = ast::StateSpace::Global; - ast::Variable { align, v_type, state_space, name, array_init } + (linking, ast::Variable { align, v_type, state_space, name, array_init }) }, - LinkingDirectives ".shared" <def:GlobalVariableDefinitionNoArray> => { + <linking:LinkingDirectives> ".shared" <def:GlobalVariableDefinitionNoArray> => { let (align, v_type, name, array_init) = def; let state_space = ast::StateSpace::Shared; - ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } + (linking, ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }) }, - <ldirs:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? { + <linking:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? { let (align, t, name, arr_or_ptr) = var; let (v_type, state_space, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { @@ -620,17 +628,17 @@ ModuleVariable: ast::Variable<&'input str> = { } } ast::ArrayOrPointer::Pointer => { - if !ldirs.contains(ast::LinkingDirective::EXTERN) { + if !linking.contains(ast::LinkingDirective::EXTERN) { return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); } if space == ".global" { - (ast::Type::Scalar(t), ast::StateSpace::Global, Vec::new()) + (ast::Type::Array(t, Vec::new()), ast::StateSpace::Global, Vec::new()) } else { - (ast::Type::Scalar(t), ast::StateSpace::Shared, Vec::new()) + (ast::Type::Array(t, Vec::new()), ast::StateSpace::Shared, Vec::new()) } } }; - Ok(ast::Variable{ align, v_type, state_space, name, array_init }) + Ok((linking, ast::Variable{ align, v_type, state_space, name, array_init })) } } diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a0b5077..6b9dcfb 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -172,14 +172,18 @@ impl TypeWordMap { .or_insert_with(|| b.type_vector(None, base, len as u32))
}
SpirvType::Array(typ, array_dimensions) => {
- let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let (base_type, length) = match &*array_dimensions {
+ &[] => {
+ return self.get_or_add(b, SpirvType::Base(typ));
+ }
&[len] => {
+ let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let base = self.get_or_add_spirv_scalar(b, typ);
let len_const = b.constant_u32(u32_type, None, len);
(base, len_const)
}
array_dimensions => {
+ let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let base = self
.get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec()));
let len_const = b.constant_u32(u32_type, None, array_dimensions[0]);
@@ -221,7 +225,7 @@ impl TypeWordMap { fn get_or_add_fn(
&mut self,
b: &mut dr::Builder,
- in_params: impl ExactSizeIterator<Item = SpirvType>,
+ in_params: impl Iterator<Item = SpirvType>,
mut out_params: impl ExactSizeIterator<Item = SpirvType>,
) -> (spirv::Word, spirv::Word) {
let (out_args, out_spirv_type) = if out_params.len() == 0 {
@@ -233,6 +237,7 @@ impl TypeWordMap { self.get_or_add(b, arg_as_key),
)
} else {
+ // TODO: support multiple return values
todo!()
};
(
@@ -436,7 +441,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
let call_map = get_kernels_call_map(&directives);
- //let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
+ let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
@@ -528,7 +533,7 @@ fn emit_directives<'input>( let empty_body = Vec::new();
for d in directives.iter() {
match d {
- Directive::Variable(var) => {
+ Directive::Variable(_, var) => {
emit_variable(builder, map, &var)?;
}
Directive::Method(f) => {
@@ -699,7 +704,6 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>, transformation has a semantical meaning - we emit additional
"OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
*/
-/*
fn convert_dynamic_shared_memory_usage<'input>(
module: Vec<Directive<'input>>,
new_id: &mut impl FnMut() -> spirv::Word,
@@ -707,13 +711,16 @@ fn convert_dynamic_shared_memory_usage<'input>( let mut extern_shared_decls = HashMap::new();
for dir in module.iter() {
match dir {
- Directive::Variable(ast::Variable {
- v_type: ast::Type::Pointer(p_type),
- state_space: ast::StateSpace::Shared,
- name,
- ..
- }) => {
- extern_shared_decls.insert(*name, p_type.clone());
+ Directive::Variable(
+ linking,
+ ast::Variable {
+ v_type: ast::Type::Array(p_type, dims),
+ state_space: ast::StateSpace::Shared,
+ name,
+ ..
+ },
+ ) if linking.contains(ast::LinkingDirective::EXTERN) && dims.len() == 0 => {
+ extern_shared_decls.insert(*name, *p_type);
}
_ => {}
}
@@ -732,14 +739,13 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements),
import_as,
tuning,
- uses_shared_mem,
}) => {
let call_key = (*func_decl).borrow().name;
let statements = statements
.into_iter()
.map(|statement| match statement {
Statement::Call(call) => {
- multi_hash_map_append(&mut directly_called_by, call.func, call_key);
+ multi_hash_map_append(&mut directly_called_by, call.name, call_key);
Statement::Call(call)
}
statement => statement.map_id(&mut |id, _| {
@@ -756,7 +762,6 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements),
import_as,
tuning,
- uses_shared_mem,
})
}
directive => directive,
@@ -775,7 +780,6 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements),
import_as,
tuning,
- uses_shared_mem,
}) => {
if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) {
return Directive::Method(Function {
@@ -784,21 +788,12 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements),
import_as,
tuning,
- uses_shared_mem,
});
}
let shared_id_param = new_id();
{
let mut func_decl = (*func_decl).borrow_mut();
- func_decl.input_arguments.push({
- ast::Variable {
- name: shared_id_param,
- align: None,
- v_type: ast::Type::Pointer(ast::ScalarType::B8, new_todo!()),
- state_space: ast::StateSpace::Shared,
- array_init: Vec::new(),
- }
- });
+ func_decl.shared_mem = Some(shared_id_param);
}
let statements = replace_uses_of_shared_memory(
new_id,
@@ -813,7 +808,6 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements),
import_as,
tuning,
- uses_shared_mem: true,
})
}
directive => directive,
@@ -835,8 +829,8 @@ fn replace_uses_of_shared_memory<'a>( // We can safely skip checking call arguments,
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
- if methods_using_extern_shared.contains(&ast::MethodName::Func(call.func)) {
- call.param_list.push((
+ if methods_using_extern_shared.contains(&ast::MethodName::Func(call.name)) {
+ call.input_arguments.push((
shared_id_param,
ast::Type::Scalar(ast::ScalarType::B8),
ast::StateSpace::Shared,
@@ -854,13 +848,11 @@ fn replace_uses_of_shared_memory<'a>( result.push(Statement::Conversion(ImplicitConversion {
src: shared_id_param,
dst: replacement_id,
- from_type: ast::Type::Pointer(ast::ScalarType::B8),
+ from_type: ast::Type::Scalar(ast::ScalarType::B8),
from_space: ast::StateSpace::Shared,
- to_type: ast::Type::Pointer((*scalar_type).into()),
+ to_type: ast::Type::Scalar(*scalar_type),
to_space: ast::StateSpace::Shared,
- kind: ConversionKind::PtrToPtr { spirv_ptr: true },
- src_
- dst_
+ kind: ConversionKind::PtrToPtr,
}));
replacement_id
} else {
@@ -912,7 +904,6 @@ fn get_callers_of_extern_shared_single<'a>( }
}
}
-*/
type DenormCountMap<T> = HashMap<T, isize>;
@@ -948,7 +939,7 @@ fn compute_denorm_information<'input>( let mut denorm_methods = HashMap::new();
for directive in module {
match directive {
- Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {}
+ Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {}
Directive::Method(Function {
func_decl,
body: Some(statements),
@@ -1158,14 +1149,17 @@ fn translate_directive<'input>( d: ast::Directive<'input, ast::ParsedArgParams<'input>>,
) -> Result<Option<Directive<'input>>, TranslateError> {
Ok(match d {
- ast::Directive::Variable(var) => Some(Directive::Variable(ast::Variable {
- align: var.align,
- v_type: var.v_type.clone(),
- state_space: var.state_space,
- name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
- array_init: var.array_init,
- })),
- ast::Directive::Method(f) => {
+ ast::Directive::Variable(linking, var) => Some(Directive::Variable(
+ linking,
+ ast::Variable {
+ align: var.align,
+ v_type: var.v_type.clone(),
+ state_space: var.state_space,
+ name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
+ array_init: var.array_init,
+ },
+ )),
+ ast::Directive::Method(_, f) => {
translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method)
}
})
@@ -2576,7 +2570,7 @@ fn insert_implicit_conversions_impl( fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- spirv_input: impl ExactSizeIterator<Item = SpirvType>,
+ spirv_input: impl Iterator<Item = SpirvType>,
spirv_output: &[ast::Variable<spirv::Word>],
) -> (spirv::Word, spirv::Word) {
map.get_or_add_fn(
@@ -5597,7 +5591,7 @@ impl ast::ArgParams for ExpandedArgParams { impl ArgParamsEx for ExpandedArgParams {}
enum Directive<'input> {
- Variable(ast::Variable<spirv::Word>),
+ Variable(ast::LinkingDirective, ast::Variable<spirv::Word>),
Method(Function<'input>),
}
@@ -7582,19 +7576,28 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> { }
impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
- fn effective_input_arguments(
- &self,
- ) -> impl ExactSizeIterator<Item = (spirv::Word, SpirvType)> + '_ {
+ fn effective_input_arguments(&self) -> impl Iterator<Item = (spirv::Word, SpirvType)> + '_ {
let is_kernel = self.name.is_kernel();
- self.input_arguments.iter().map(move |arg| {
- if !is_kernel && arg.state_space != ast::StateSpace::Reg {
- let spirv_type =
- SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
- (arg.name, spirv_type)
- } else {
- (arg.name, SpirvType::new(arg.v_type.clone()))
- }
- })
+ self.input_arguments
+ .iter()
+ .map(move |arg| {
+ if !is_kernel && arg.state_space != ast::StateSpace::Reg {
+ let spirv_type =
+ SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
+ (arg.name, spirv_type)
+ } else {
+ (arg.name, SpirvType::new(arg.v_type.clone()))
+ }
+ })
+ .chain(self.shared_mem.iter().map(|id| {
+ (
+ *id,
+ SpirvType::Pointer(
+ Box::new(SpirvType::Base(SpirvScalarKey::B8)),
+ spirv::StorageClass::Workgroup,
+ ),
+ )
+ }))
}
}
|