aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-09-03 01:45:08 +0200
committerAndrzej Janik <[email protected]>2020-09-03 01:45:08 +0200
commitde734305cfe8124c1a3a4a0adfee143e4ff5b680 (patch)
tree2939440ee2a67088ddda74f36a147b8035e128b2
parent0f4a4c634b3dd9e1117cb843fcde59498ac2ae07 (diff)
downloadZLUDA-de734305cfe8124c1a3a4a0adfee143e4ff5b680.tar.gz
ZLUDA-de734305cfe8124c1a3a4a0adfee143e4ff5b680.zip
Start refactoring SPIRV module generation in preparation for support of functions
-rw-r--r--notcuda/src/impl/module.rs5
-rw-r--r--ptx/src/ast.rs25
-rw-r--r--ptx/src/ptx.lalrpop19
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/translate.rs193
5 files changed, 141 insertions, 102 deletions
diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs
index feae40b..491778a 100644
--- a/notcuda/src/impl/module.rs
+++ b/notcuda/src/impl/module.rs
@@ -14,10 +14,7 @@ pub enum ModuleCompileError<'a> {
}
impl<'a> ModuleCompileError<'a> {
- pub fn get_build_log(&self) {
-
- }
-
+ pub fn get_build_log(&self) {}
}
impl<'a> From<ptx::SpirvError> for ModuleCompileError<'a> {
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 5de1db6..7550d55 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -48,24 +48,25 @@ impl<
pub struct Module<'a> {
pub version: (u8, u8),
- pub functions: Vec<Function<'a>>,
+ pub functions: Vec<ParsedFunction<'a>>,
}
-pub enum FunctionReturn<'a> {
- Func(Vec<Argument<'a>>),
- Kernel,
+pub enum FunctionHeader<'a, P: ArgParams> {
+ Func(Vec<Argument<P>>, P::ID),
+ Kernel(&'a str),
}
-pub struct Function<'a> {
- pub func_directive: FunctionReturn<'a>,
- pub name: &'a str,
- pub args: Vec<Argument<'a>>,
- pub body: Option<Vec<Statement<ParsedArgParams<'a>>>>,
+pub struct Function<'a, P: ArgParams, S> {
+ pub func_directive: FunctionHeader<'a, P>,
+ pub args: Vec<Argument<P>>,
+ pub body: Option<Vec<S>>,
}
+pub type ParsedFunction<'a> = Function<'a, ParsedArgParams<'a>, Statement<ParsedArgParams<'a>>>;
+
#[derive(Default)]
-pub struct Argument<'a> {
- pub name: &'a str,
+pub struct Argument<P: ArgParams> {
+ pub name: P::ID,
pub a_type: ScalarType,
pub length: u32,
}
@@ -231,7 +232,7 @@ pub struct CallData {
pub struct AbsDetails {
pub flush_to_zero: bool,
- pub typ: ScalarType
+ pub typ: ScalarType,
}
pub struct ArgCall<P: ArgParams> {
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 7438e97..7e38b78 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -189,7 +189,7 @@ TargetSpecifier = {
"map_f64_to_f32"
};
-Directive: Option<ast::Function<'input>> = {
+Directive: Option<ast::Function<'input, ast::ParsedArgParams<'input>, ast::Statement<ast::ParsedArgParams<'input>>>> = {
AddressSize => None,
<f:Function> => Some(f),
File => None,
@@ -200,12 +200,11 @@ AddressSize = {
".address_size" Num
};
-Function: ast::Function<'input> = {
+Function: ast::Function<'input, ast::ParsedArgParams<'input>, ast::Statement<ast::ParsedArgParams<'input>>> = {
LinkingDirective*
- <func_directive:FunctionReturn>
- <name:ExtendedID>
+ <func_directive:FunctionHeader>
<args:Arguments>
- <body:FunctionBody> => ast::Function{<>}
+ <body:FunctionBody> => ast::Function{<>}
};
LinkingDirective = {
@@ -214,17 +213,17 @@ LinkingDirective = {
".weak"
};
-FunctionReturn: ast::FunctionReturn<'input> = {
- ".entry" => ast::FunctionReturn::Kernel,
- ".func" <args:Arguments?> => ast::FunctionReturn::Func(args.unwrap_or_else(|| Vec::new()))
+FunctionHeader: ast::FunctionHeader<'input, ast::ParsedArgParams<'input>> = {
+ ".entry" <name:ExtendedID> => ast::FunctionHeader::Kernel(name),
+ ".func" <args:Arguments?> <name:ExtendedID> => ast::FunctionHeader::Func(args.unwrap_or_else(|| Vec::new()), name)
};
-Arguments: Vec<ast::Argument<'input>> = {
+Arguments: Vec<ast::Argument<ast::ParsedArgParams<'input>>> = {
"(" <args:Comma<FunctionInput>> ")" => args
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
-FunctionInput: ast::Argument<'input> = {
+FunctionInput: ast::Argument<ast::ParsedArgParams<'input>> = {
".param" <_type:ScalarType> <name:ExtendedID> => {
ast::Argument {a_type: _type, name: name, length: 1 }
},
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 8883669..9ea0100 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -52,6 +52,7 @@ test_ptx!(cvt_sat_s_u, [-1i32], [0i32]);
test_ptx!(cvta, [3.0f32], [3.0f32]);
test_ptx!(block, [1u64], [2u64]);
test_ptx!(local_align, [1u64], [1u64]);
+test_ptx!(call, [1u64], [2u64]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 642e6ec..8cf3aca 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -154,7 +154,7 @@ impl TypeWordMap {
}
}
-pub fn to_spirv_module(ast: ast::Module) -> Result<dr::Module, dr::Error> {
+pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, dr::Error> {
let mut builder = dr::Builder::new();
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
builder.set_version(1, 3);
@@ -163,13 +163,21 @@ pub fn to_spirv_module(ast: ast::Module) -> Result<dr::Module, dr::Error> {
let opencl_id = emit_opencl_import(&mut builder);
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
- for f in ast.functions {
- emit_function(&mut builder, &mut map, opencl_id, f)?;
+ let mut id_defs = GlobalStringIdResolver::new(builder.id());
+ let ssa_functions = ast
+ .functions
+ .into_iter()
+ .map(|f| to_ssa_function(&mut id_defs, opencl_id, f))
+ .collect::<Vec<_>>();
+ for f in ssa_functions {
+ emit_function_args(&mut builder, &mut map, &*f.args);
+ emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.body)?;
+ builder.end_function()?;
}
Ok(builder.module())
}
-pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, dr::Error> {
+pub fn to_spirv<'a>(ast: ast::Module<'a>) -> Result<Vec<u32>, dr::Error> {
let module = to_spirv_module(ast)?;
Ok(module.assemble())
}
@@ -196,28 +204,28 @@ fn emit_memory_model(builder: &mut dr::Builder) {
);
}
-fn emit_function<'a>(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
+fn to_ssa_function<'a>(
+ id_defs: &mut GlobalStringIdResolver<'a>,
opencl_id: spirv::Word,
- f: ast::Function<'a>,
-) -> Result<spirv::Word, rspirv::dr::Error> {
- let func_type = get_function_type(builder, map, &f.args);
- let func_id =
- builder.begin_function(map.void(), None, spirv::FunctionControl::NONE, func_type)?;
- match f.func_directive {
- ast::FunctionReturn::Kernel => {
- builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[])
- }
- _ => todo!(),
+ f: ast::ParsedFunction<'a>,
+) -> ExpandedFunction<'a> {
+ let ids_start = id_defs.current_id();
+ let fn_resolver = FnStringIdResolver::new(id_defs);
+ let f_header = match f.func_directive {
+ ast::FunctionHeader::Kernel(name) => todo!(),
+ ast::FunctionHeader::Func(ret_params, name) => todo!(),
+ };
+ let f_args = todo!();
+ let f_body = Some(to_ssa(
+ fn_resolver,
+ &f.args,
+ f.body.unwrap_or_else(|| todo!()),
+ ));
+ ExpandedFunction {
+ func_directive: f_header,
+ args: f_args,
+ body: f_body,
}
- let (mut func_body, unique_ids) = to_ssa(&f.args, f.body.unwrap_or_else(|| todo!()));
- let id_offset = builder.reserve_ids(unique_ids);
- emit_function_args(builder, id_offset, map, &f.args);
- func_body = apply_id_offset(func_body, id_offset);
- emit_function_body_ops(builder, map, opencl_id, &func_body)?;
- builder.end_function()?;
- Ok(func_id)
}
fn apply_id_offset(func_body: Vec<ExpandedStatement>, id_offset: u32) -> Vec<ExpandedStatement> {
@@ -228,16 +236,19 @@ fn apply_id_offset(func_body: Vec<ExpandedStatement>, id_offset: u32) -> Vec<Exp
}
fn to_ssa<'a, 'b>(
- f_args: &'b [ast::Argument<'a>],
+ mut id_defs: FnStringIdResolver<'a, 'b>,
+ f_args: &'b [ast::Argument<ast::ParsedArgParams<'a>>],
f_body: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
-) -> (Vec<ExpandedStatement>, spirv::Word) {
- let (normalized_ids, mut id_def) = normalize_identifiers(&f_args, f_body);
- let normalized_statements = normalize_predicates(normalized_ids, &mut id_def);
- let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut id_def);
- let expanded_statements = expand_arguments(ssa_statements, &mut id_def);
- let expanded_statements = insert_implicit_conversions(expanded_statements, &mut id_def);
- let labeled_statements = normalize_labels(expanded_statements, &mut id_def);
- (labeled_statements, id_def.ids_count())
+) -> Vec<ExpandedStatement> {
+ let normalized_ids = normalize_identifiers(&mut id_defs, &f_args, f_body);
+ let mut numeric_id_defs = id_defs.finish();
+ let normalized_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
+ let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut numeric_id_defs);
+ let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs);
+ let expanded_statements =
+ insert_implicit_conversions(expanded_statements, &mut numeric_id_defs);
+ let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
+ labeled_statements
}
fn normalize_labels(
@@ -391,9 +402,9 @@ fn insert_mem_ssa_statements(
result
}
-fn expand_arguments(
+fn expand_arguments<'a, 'b, 'c>(
func: Vec<NormalizedStatement>,
- id_def: &mut NumericIdResolver,
+ id_def: &'c mut NumericIdResolver<'a, 'b>,
) -> Vec<ExpandedStatement> {
let mut result = Vec::with_capacity(func.len());
for s in func {
@@ -416,18 +427,23 @@ fn expand_arguments(
result
}
-struct FlattenArguments<'a> {
- func: &'a mut Vec<ExpandedStatement>,
- id_def: &'a mut NumericIdResolver,
+struct FlattenArguments<'a, 'b, 'c> {
+ func: &'c mut Vec<ExpandedStatement>,
+ id_def: &'c mut NumericIdResolver<'a, 'b>,
}
-impl<'a> FlattenArguments<'a> {
- fn new(func: &'a mut Vec<ExpandedStatement>, id_def: &'a mut NumericIdResolver) -> Self {
+impl<'a, 'b, 'c> FlattenArguments<'a, 'b, 'c> {
+ fn new(
+ func: &'c mut Vec<ExpandedStatement>,
+ id_def: &'c mut NumericIdResolver<'a, 'b>,
+ ) -> Self {
FlattenArguments { func, id_def }
}
}
-impl<'a> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams> for FlattenArguments<'a> {
+impl<'a, 'b, 'c> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
+ for FlattenArguments<'a, 'b, 'c>
+{
fn dst_variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
desc.op
}
@@ -577,18 +593,17 @@ fn insert_implicit_conversions(
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- args: &[ast::Argument],
+ args: &[ast::Argument<ast::ParsedArgParams>],
) -> spirv::Word {
map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::from(arg.a_type)))
}
fn emit_function_args(
builder: &mut dr::Builder,
- id_offset: spirv::Word,
map: &mut TypeWordMap,
- args: &[ast::Argument],
+ args: &[ast::Argument<ExpandedArgParams>],
) {
- let mut id = id_offset;
+ let mut id = todo!();
for arg in args {
let result_type = map.get_or_add_scalar(builder, arg.a_type);
let inst = dr::Instruction::new(
@@ -606,9 +621,9 @@ fn emit_function_body_ops(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
- func: &[ExpandedStatement],
+ func: &Option<Vec<ExpandedStatement>>,
) -> Result<(), dr::Error> {
- for s in func {
+ for s in func.as_ref().unwrap() {
match s {
Statement::Label(id) => {
if builder.block.is_some() {
@@ -1079,10 +1094,10 @@ fn emit_implicit_conversion(
// TODO: support scopes
fn normalize_identifiers<'a, 'b>(
- args: &'b [ast::Argument<'a>],
+ id_defs: &mut FnStringIdResolver<'a, 'b>,
+ args: &[ast::Argument<ast::ParsedArgParams<'a>>],
func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
-) -> (Vec<ast::Statement<NormalizedArgParams>>, NumericIdResolver) {
- let mut id_defs = StringIdResolver::new();
+) -> Vec<ast::Statement<NormalizedArgParams>> {
for arg in args {
id_defs.add_def(arg.name, Some(ast::Type::Scalar(arg.a_type)));
}
@@ -1096,13 +1111,13 @@ fn normalize_identifiers<'a, 'b>(
}
let mut result = Vec::new();
for s in func {
- expand_map_variables(&mut id_defs, &mut result, s);
+ expand_map_variables(id_defs, &mut result, s);
}
- (result, id_defs.finish())
+ result
}
-fn expand_map_variables<'a>(
- id_defs: &mut StringIdResolver<'a>,
+fn expand_map_variables<'a, 'b>(
+ id_defs: &mut FnStringIdResolver<'a, 'b>,
result: &mut Vec<ast::Statement<NormalizedArgParams>>,
s: ast::Statement<ast::ParsedArgParams<'a>>,
) {
@@ -1145,24 +1160,53 @@ fn expand_map_variables<'a>(
}
}
-struct StringIdResolver<'a> {
+struct GlobalStringIdResolver<'a> {
current_id: spirv::Word,
+ variables: HashMap<Cow<'a, str>, spirv::Word>,
+}
+
+impl<'a> GlobalStringIdResolver<'a> {
+ fn new(start_id: spirv::Word) -> Self {
+ Self {
+ current_id: start_id,
+ variables: HashMap::new(),
+ }
+ }
+
+ fn add_def(&mut self, id: &'a str) -> spirv::Word {
+ let numeric_id = self.current_id;
+ self.variables.insert(Cow::Borrowed(id), numeric_id);
+ self.current_id += 1;
+ numeric_id
+ }
+
+ fn reserve_id(&mut self) {
+ self.current_id += 1;
+ }
+
+ fn current_id(&self) -> spirv::Word {
+ self.current_id
+ }
+}
+
+struct FnStringIdResolver<'a, 'b> {
+ global: &'b mut GlobalStringIdResolver<'a>,
variables: Vec<HashMap<Cow<'a, str>, spirv::Word>>,
type_check: HashMap<u32, ast::Type>,
}
-impl<'a> StringIdResolver<'a> {
- fn new() -> Self {
- StringIdResolver {
- current_id: 0u32,
+impl<'a, 'b> FnStringIdResolver<'a, 'b> {
+ fn new(global: &'b mut GlobalStringIdResolver<'a>) -> Self {
+ Self {
+ global: global,
variables: vec![HashMap::new(); 1],
type_check: HashMap::new(),
}
}
- fn finish(self) -> NumericIdResolver {
+ fn finish(self) -> NumericIdResolver<'a, 'b> {
NumericIdResolver {
- current_id: self.current_id,
+ global: self.global,
type_check: self.type_check,
}
}
@@ -1175,18 +1219,18 @@ impl<'a> StringIdResolver<'a> {
self.variables.pop();
}
- fn get_id(&self, id: &'a str) -> spirv::Word {
+ fn get_id(&self, id: &str) -> spirv::Word {
for scope in self.variables.iter().rev() {
match scope.get(id) {
Some(id) => return *id,
None => continue,
}
}
- panic!()
+ self.global.variables[id]
}
fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
- let numeric_id = self.current_id;
+ let numeric_id = self.global.current_id;
self.variables
.last_mut()
.unwrap()
@@ -1194,7 +1238,7 @@ impl<'a> StringIdResolver<'a> {
if let Some(typ) = typ {
self.type_check.insert(numeric_id, typ);
}
- self.current_id += 1;
+ self.global.current_id += 1;
numeric_id
}
@@ -1205,7 +1249,7 @@ impl<'a> StringIdResolver<'a> {
count: u32,
typ: ast::Type,
) -> impl Iterator<Item = spirv::Word> {
- let numeric_id = self.current_id;
+ let numeric_id = self.global.current_id;
for i in 0..count {
self.variables
.last_mut()
@@ -1213,33 +1257,29 @@ impl<'a> StringIdResolver<'a> {
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
self.type_check.insert(numeric_id + i, typ);
}
- self.current_id += count;
+ self.global.current_id += count;
(0..count).into_iter().map(move |i| i + numeric_id)
}
}
-struct NumericIdResolver {
- current_id: spirv::Word,
+struct NumericIdResolver<'a, 'b> {
+ global: &'b mut GlobalStringIdResolver<'a>,
type_check: HashMap<u32, ast::Type>,
}
-impl NumericIdResolver {
+impl<'a, 'b> NumericIdResolver<'a, 'b> {
fn get_type(&self, id: spirv::Word) -> ast::Type {
self.type_check[&id]
}
fn new_id(&mut self, typ: Option<ast::Type>) -> spirv::Word {
- let new_id = self.current_id;
+ let new_id = self.global.current_id;
if let Some(typ) = typ {
self.type_check.insert(new_id, typ);
}
- self.current_id += 1;
+ self.global.current_id += 1;
new_id
}
-
- fn ids_count(&self) -> spirv::Word {
- self.current_id
- }
}
enum Statement<I> {
@@ -1284,6 +1324,7 @@ impl ast::ArgParams for NormalizedArgParams {
enum ExpandedArgParams {}
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>>;
+type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>;
impl ast::ArgParams for ExpandedArgParams {
type ID = spirv::Word;