aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--ptx/src/ast.rs6
-rw-r--r--ptx/src/ptx.lalrpop38
-rw-r--r--ptx/src/translate.rs119
3 files changed, 87 insertions, 76 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index a0bb023..5432207 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -82,8 +82,8 @@ pub struct Module<'a> {
}
pub enum Directive<'a, P: ArgParams> {
- Variable(Variable<P::Id>),
- Method(Function<'a, &'a str, Statement<P>>),
+ Variable(LinkingDirective, Variable<P::Id>),
+ Method(LinkingDirective, Function<'a, &'a str, Statement<P>>),
}
#[derive(Hash, PartialEq, Eq, Copy, Clone)]
@@ -96,7 +96,7 @@ pub struct MethodDeclaration<'input, ID> {
pub return_arguments: Vec<Variable<ID>>,
pub name: MethodName<'input, ID>,
pub input_arguments: Vec<Variable<ID>>,
- pub shared_mem: Option<Variable<ID>>,
+ pub shared_mem: Option<ID>,
}
pub struct Function<'a, ID, S> {
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index e8370cd..b697317 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -343,10 +343,16 @@ TargetSpecifier = {
Directive: Option<ast::Directive<'input, ast::ParsedArgParams<'input>>> = {
AddressSize => None,
- <f:Function> => Some(ast::Directive::Method(f)),
+ <f:Function> => {
+ let (linking, func) = f;
+ Some(ast::Directive::Method(linking, func))
+ },
File => None,
Section => None,
- <v:ModuleVariable> ";" => Some(ast::Directive::Variable(v)),
+ <v:ModuleVariable> ";" => {
+ let (linking, var) = v;
+ Some(ast::Directive::Variable(linking, var))
+ },
! => {
let err = <>;
errors.push(err.error);
@@ -358,11 +364,13 @@ AddressSize = {
".address_size" U8Num
};
-Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = {
- LinkingDirectives
+Function: (ast::LinkingDirective, ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>>) = {
+ <linking:LinkingDirectives>
<func_directive:MethodDeclaration>
<tuning:TuningDirective*>
- <body:FunctionBody> => ast::Function{<>}
+ <body:FunctionBody> => {
+ (linking, ast::Function{func_directive, tuning, body})
+ }
};
LinkingDirective: ast::LinkingDirective = {
@@ -598,18 +606,18 @@ SharedVariable: ast::Variable<&'input str> = {
}
}
-ModuleVariable: ast::Variable<&'input str> = {
- LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => {
+ModuleVariable: (ast::LinkingDirective, ast::Variable<&'input str>) = {
+ <linking:LinkingDirectives> ".global" <def:GlobalVariableDefinitionNoArray> => {
let (align, v_type, name, array_init) = def;
let state_space = ast::StateSpace::Global;
- ast::Variable { align, v_type, state_space, name, array_init }
+ (linking, ast::Variable { align, v_type, state_space, name, array_init })
},
- LinkingDirectives ".shared" <def:GlobalVariableDefinitionNoArray> => {
+ <linking:LinkingDirectives> ".shared" <def:GlobalVariableDefinitionNoArray> => {
let (align, v_type, name, array_init) = def;
let state_space = ast::StateSpace::Shared;
- ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
+ (linking, ast::Variable { align, v_type, state_space, name, array_init: Vec::new() })
},
- <ldirs:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? {
+ <linking:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? {
let (align, t, name, arr_or_ptr) = var;
let (v_type, state_space, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => {
@@ -620,17 +628,17 @@ ModuleVariable: ast::Variable<&'input str> = {
}
}
ast::ArrayOrPointer::Pointer => {
- if !ldirs.contains(ast::LinkingDirective::EXTERN) {
+ if !linking.contains(ast::LinkingDirective::EXTERN) {
return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
}
if space == ".global" {
- (ast::Type::Scalar(t), ast::StateSpace::Global, Vec::new())
+ (ast::Type::Array(t, Vec::new()), ast::StateSpace::Global, Vec::new())
} else {
- (ast::Type::Scalar(t), ast::StateSpace::Shared, Vec::new())
+ (ast::Type::Array(t, Vec::new()), ast::StateSpace::Shared, Vec::new())
}
}
};
- Ok(ast::Variable{ align, v_type, state_space, name, array_init })
+ Ok((linking, ast::Variable{ align, v_type, state_space, name, array_init }))
}
}
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index a0b5077..6b9dcfb 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -172,14 +172,18 @@ impl TypeWordMap {
.or_insert_with(|| b.type_vector(None, base, len as u32))
}
SpirvType::Array(typ, array_dimensions) => {
- let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let (base_type, length) = match &*array_dimensions {
+ &[] => {
+ return self.get_or_add(b, SpirvType::Base(typ));
+ }
&[len] => {
+ let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let base = self.get_or_add_spirv_scalar(b, typ);
let len_const = b.constant_u32(u32_type, None, len);
(base, len_const)
}
array_dimensions => {
+ let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let base = self
.get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec()));
let len_const = b.constant_u32(u32_type, None, array_dimensions[0]);
@@ -221,7 +225,7 @@ impl TypeWordMap {
fn get_or_add_fn(
&mut self,
b: &mut dr::Builder,
- in_params: impl ExactSizeIterator<Item = SpirvType>,
+ in_params: impl Iterator<Item = SpirvType>,
mut out_params: impl ExactSizeIterator<Item = SpirvType>,
) -> (spirv::Word, spirv::Word) {
let (out_args, out_spirv_type) = if out_params.len() == 0 {
@@ -233,6 +237,7 @@ impl TypeWordMap {
self.get_or_add(b, arg_as_key),
)
} else {
+ // TODO: support multiple return values
todo!()
};
(
@@ -436,7 +441,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
let call_map = get_kernels_call_map(&directives);
- //let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
+ let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
@@ -528,7 +533,7 @@ fn emit_directives<'input>(
let empty_body = Vec::new();
for d in directives.iter() {
match d {
- Directive::Variable(var) => {
+ Directive::Variable(_, var) => {
emit_variable(builder, map, &var)?;
}
Directive::Method(f) => {
@@ -699,7 +704,6 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
transformation has a semantical meaning - we emit additional
"OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
*/
-/*
fn convert_dynamic_shared_memory_usage<'input>(
module: Vec<Directive<'input>>,
new_id: &mut impl FnMut() -> spirv::Word,
@@ -707,13 +711,16 @@ fn convert_dynamic_shared_memory_usage<'input>(
let mut extern_shared_decls = HashMap::new();
for dir in module.iter() {
match dir {
- Directive::Variable(ast::Variable {
- v_type: ast::Type::Pointer(p_type),
- state_space: ast::StateSpace::Shared,
- name,
- ..
- }) => {
- extern_shared_decls.insert(*name, p_type.clone());
+ Directive::Variable(
+ linking,
+ ast::Variable {
+ v_type: ast::Type::Array(p_type, dims),
+ state_space: ast::StateSpace::Shared,
+ name,
+ ..
+ },
+ ) if linking.contains(ast::LinkingDirective::EXTERN) && dims.len() == 0 => {
+ extern_shared_decls.insert(*name, *p_type);
}
_ => {}
}
@@ -732,14 +739,13 @@ fn convert_dynamic_shared_memory_usage<'input>(
body: Some(statements),
import_as,
tuning,
- uses_shared_mem,
}) => {
let call_key = (*func_decl).borrow().name;
let statements = statements
.into_iter()
.map(|statement| match statement {
Statement::Call(call) => {
- multi_hash_map_append(&mut directly_called_by, call.func, call_key);
+ multi_hash_map_append(&mut directly_called_by, call.name, call_key);
Statement::Call(call)
}
statement => statement.map_id(&mut |id, _| {
@@ -756,7 +762,6 @@ fn convert_dynamic_shared_memory_usage<'input>(
body: Some(statements),
import_as,
tuning,
- uses_shared_mem,
})
}
directive => directive,
@@ -775,7 +780,6 @@ fn convert_dynamic_shared_memory_usage<'input>(
body: Some(statements),
import_as,
tuning,
- uses_shared_mem,
}) => {
if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) {
return Directive::Method(Function {
@@ -784,21 +788,12 @@ fn convert_dynamic_shared_memory_usage<'input>(
body: Some(statements),
import_as,
tuning,
- uses_shared_mem,
});
}
let shared_id_param = new_id();
{
let mut func_decl = (*func_decl).borrow_mut();
- func_decl.input_arguments.push({
- ast::Variable {
- name: shared_id_param,
- align: None,
- v_type: ast::Type::Pointer(ast::ScalarType::B8, new_todo!()),
- state_space: ast::StateSpace::Shared,
- array_init: Vec::new(),
- }
- });
+ func_decl.shared_mem = Some(shared_id_param);
}
let statements = replace_uses_of_shared_memory(
new_id,
@@ -813,7 +808,6 @@ fn convert_dynamic_shared_memory_usage<'input>(
body: Some(statements),
import_as,
tuning,
- uses_shared_mem: true,
})
}
directive => directive,
@@ -835,8 +829,8 @@ fn replace_uses_of_shared_memory<'a>(
// We can safely skip checking call arguments,
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
- if methods_using_extern_shared.contains(&ast::MethodName::Func(call.func)) {
- call.param_list.push((
+ if methods_using_extern_shared.contains(&ast::MethodName::Func(call.name)) {
+ call.input_arguments.push((
shared_id_param,
ast::Type::Scalar(ast::ScalarType::B8),
ast::StateSpace::Shared,
@@ -854,13 +848,11 @@ fn replace_uses_of_shared_memory<'a>(
result.push(Statement::Conversion(ImplicitConversion {
src: shared_id_param,
dst: replacement_id,
- from_type: ast::Type::Pointer(ast::ScalarType::B8),
+ from_type: ast::Type::Scalar(ast::ScalarType::B8),
from_space: ast::StateSpace::Shared,
- to_type: ast::Type::Pointer((*scalar_type).into()),
+ to_type: ast::Type::Scalar(*scalar_type),
to_space: ast::StateSpace::Shared,
- kind: ConversionKind::PtrToPtr { spirv_ptr: true },
- src_
- dst_
+ kind: ConversionKind::PtrToPtr,
}));
replacement_id
} else {
@@ -912,7 +904,6 @@ fn get_callers_of_extern_shared_single<'a>(
}
}
}
-*/
type DenormCountMap<T> = HashMap<T, isize>;
@@ -948,7 +939,7 @@ fn compute_denorm_information<'input>(
let mut denorm_methods = HashMap::new();
for directive in module {
match directive {
- Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {}
+ Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {}
Directive::Method(Function {
func_decl,
body: Some(statements),
@@ -1158,14 +1149,17 @@ fn translate_directive<'input>(
d: ast::Directive<'input, ast::ParsedArgParams<'input>>,
) -> Result<Option<Directive<'input>>, TranslateError> {
Ok(match d {
- ast::Directive::Variable(var) => Some(Directive::Variable(ast::Variable {
- align: var.align,
- v_type: var.v_type.clone(),
- state_space: var.state_space,
- name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
- array_init: var.array_init,
- })),
- ast::Directive::Method(f) => {
+ ast::Directive::Variable(linking, var) => Some(Directive::Variable(
+ linking,
+ ast::Variable {
+ align: var.align,
+ v_type: var.v_type.clone(),
+ state_space: var.state_space,
+ name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
+ array_init: var.array_init,
+ },
+ )),
+ ast::Directive::Method(_, f) => {
translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method)
}
})
@@ -2576,7 +2570,7 @@ fn insert_implicit_conversions_impl(
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- spirv_input: impl ExactSizeIterator<Item = SpirvType>,
+ spirv_input: impl Iterator<Item = SpirvType>,
spirv_output: &[ast::Variable<spirv::Word>],
) -> (spirv::Word, spirv::Word) {
map.get_or_add_fn(
@@ -5597,7 +5591,7 @@ impl ast::ArgParams for ExpandedArgParams {
impl ArgParamsEx for ExpandedArgParams {}
enum Directive<'input> {
- Variable(ast::Variable<spirv::Word>),
+ Variable(ast::LinkingDirective, ast::Variable<spirv::Word>),
Method(Function<'input>),
}
@@ -7582,19 +7576,28 @@ impl<'a> ast::MethodDeclaration<'a, &'a str> {
}
impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
- fn effective_input_arguments(
- &self,
- ) -> impl ExactSizeIterator<Item = (spirv::Word, SpirvType)> + '_ {
+ fn effective_input_arguments(&self) -> impl Iterator<Item = (spirv::Word, SpirvType)> + '_ {
let is_kernel = self.name.is_kernel();
- self.input_arguments.iter().map(move |arg| {
- if !is_kernel && arg.state_space != ast::StateSpace::Reg {
- let spirv_type =
- SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
- (arg.name, spirv_type)
- } else {
- (arg.name, SpirvType::new(arg.v_type.clone()))
- }
- })
+ self.input_arguments
+ .iter()
+ .map(move |arg| {
+ if !is_kernel && arg.state_space != ast::StateSpace::Reg {
+ let spirv_type =
+ SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
+ (arg.name, spirv_type)
+ } else {
+ (arg.name, SpirvType::new(arg.v_type.clone()))
+ }
+ })
+ .chain(self.shared_mem.iter().map(|id| {
+ (
+ *id,
+ SpirvType::Pointer(
+ Box::new(SpirvType::Base(SpirvScalarKey::B8)),
+ spirv::StorageClass::Workgroup,
+ ),
+ )
+ }))
}
}