summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-09-03 20:58:35 +0200
committerAndrzej Janik <[email protected]>2020-09-03 20:58:35 +0200
commitbbb3a6c5cbaff3430191ef4858aa16be8320ce77 (patch)
tree53499d025cb7909967b20dcc8c6d88b045cb8c5d
parentde734305cfe8124c1a3a4a0adfee143e4ff5b680 (diff)
downloadZLUDA-bbb3a6c5cbaff3430191ef4858aa16be8320ce77.tar.gz
ZLUDA-bbb3a6c5cbaff3430191ef4858aa16be8320ce77.zip
Finish up cleanup for PTX function support
-rw-r--r--notcuda/src/impl/export_table.rs2
-rw-r--r--ptx/src/test/spirv_run/mod.rs17
-rw-r--r--ptx/src/translate.rs168
3 files changed, 83 insertions, 104 deletions
diff --git a/notcuda/src/impl/export_table.rs b/notcuda/src/impl/export_table.rs
index afd9077..233c496 100644
--- a/notcuda/src/impl/export_table.rs
+++ b/notcuda/src/impl/export_table.rs
@@ -8,7 +8,7 @@ use super::{context, device, module, Decuda, Encuda};
use std::mem;
use std::os::raw::{c_uint, c_ulong, c_ushort};
use std::{
- ffi::{c_void, CStr, CString},
+ ffi::{c_void, CStr},
ptr, slice,
};
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 9ea0100..9f62292 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -190,14 +190,17 @@ fn test_spvtxt_assert<'a>(
ptr::null_mut()
)
};
- assert_eq!(result, spv_result_t::SPV_SUCCESS);
- let raw_text = unsafe {
- std::slice::from_raw_parts((*spv_text).str_ as *const u8, (*spv_text).length)
- };
- let spv_from_ptx_text = unsafe { str::from_utf8_unchecked(raw_text) };
- // TODO: stop leaking kernel text
unsafe { spirv_tools::spvContextDestroy(spv_context) };
- panic!(spv_from_ptx_text);
+ if result == spv_result_t::SPV_SUCCESS {
+ let raw_text = unsafe {
+ std::slice::from_raw_parts((*spv_text).str_ as *const u8, (*spv_text).length)
+ };
+ let spv_from_ptx_text = unsafe { str::from_utf8_unchecked(raw_text) };
+ // TODO: stop leaking kernel text
+ panic!(spv_from_ptx_text);
+ } else {
+ panic!(ptx_mod.disassemble());
+ }
}
unsafe { spirv_tools::spvContextDestroy(spv_context) };
Ok(())
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 8cf3aca..34d8c12 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -155,7 +155,14 @@ impl TypeWordMap {
}
pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, dr::Error> {
+ let mut id_defs = GlobalStringIdResolver::new(1);
+ let ssa_functions = ast
+ .functions
+ .into_iter()
+ .map(|f| to_ssa_function(&mut id_defs, f))
+ .collect::<Vec<_>>();
let mut builder = dr::Builder::new();
+ builder.reserve_ids(id_defs.current_id());
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
builder.set_version(1, 3);
emit_capabilities(&mut builder);
@@ -163,13 +170,8 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, dr::Error
let opencl_id = emit_opencl_import(&mut builder);
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
- let mut id_defs = GlobalStringIdResolver::new(builder.id());
- let ssa_functions = ast
- .functions
- .into_iter()
- .map(|f| to_ssa_function(&mut id_defs, opencl_id, f))
- .collect::<Vec<_>>();
for f in ssa_functions {
+ emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive, &*f.args)?;
emit_function_args(&mut builder, &mut map, &*f.args);
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.body)?;
builder.end_function()?;
@@ -177,6 +179,31 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, dr::Error
Ok(builder.module())
}
+fn emit_function_header(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ global: &GlobalStringIdResolver,
+ func_directive: ast::FunctionHeader<ExpandedArgParams>,
+ params: &[ast::Argument<ExpandedArgParams>],
+) -> Result<(), dr::Error> {
+ let func_type = get_function_type(builder, map, params);
+ let (fn_id, ret_type) = match func_directive {
+ ast::FunctionHeader::Kernel(name) => {
+ let fn_id = global.get_id(name);
+ builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, &[]);
+ (fn_id, map.void())
+ }
+ ast::FunctionHeader::Func(params, name) => todo!(),
+ };
+ builder.begin_function(
+ ret_type,
+ Some(fn_id),
+ spirv::FunctionControl::NONE,
+ func_type,
+ )?;
+ Ok(())
+}
+
pub fn to_spirv<'a>(ast: ast::Module<'a>) -> Result<Vec<u32>, dr::Error> {
let module = to_spirv_module(ast)?;
Ok(module.assemble())
@@ -206,21 +233,19 @@ fn emit_memory_model(builder: &mut dr::Builder) {
fn to_ssa_function<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
- opencl_id: spirv::Word,
f: ast::ParsedFunction<'a>,
) -> ExpandedFunction<'a> {
- let ids_start = id_defs.current_id();
- let fn_resolver = FnStringIdResolver::new(id_defs);
+ let mut fn_resolver = FnStringIdResolver::new(id_defs, f.func_directive.name());
let f_header = match f.func_directive {
- ast::FunctionHeader::Kernel(name) => todo!(),
- ast::FunctionHeader::Func(ret_params, name) => todo!(),
+ ast::FunctionHeader::Kernel(name) => ast::FunctionHeader::Kernel(name),
+ ast::FunctionHeader::Func(ret_params, name) => {
+ let name_id = fn_resolver.add_global_def(name);
+ let ret_ids = expand_fn_params(&mut fn_resolver, ret_params);
+ ast::FunctionHeader::Func(ret_ids, name_id)
+ }
};
- let f_args = todo!();
- let f_body = Some(to_ssa(
- fn_resolver,
- &f.args,
- f.body.unwrap_or_else(|| todo!()),
- ));
+ let f_args = expand_fn_params(&mut fn_resolver, f.args);
+ let f_body = Some(to_ssa(fn_resolver, f.body.unwrap_or_else(|| Vec::new())));
ExpandedFunction {
func_directive: f_header,
args: f_args,
@@ -228,19 +253,24 @@ fn to_ssa_function<'a>(
}
}
-fn apply_id_offset(func_body: Vec<ExpandedStatement>, id_offset: u32) -> Vec<ExpandedStatement> {
- func_body
- .into_iter()
- .map(|s| s.visit_variable(&mut |id| id + id_offset))
+fn expand_fn_params<'a, 'b>(
+ fn_resolver: &mut FnStringIdResolver<'a, 'b>,
+ args: Vec<ast::Argument<ast::ParsedArgParams<'a>>>,
+) -> Vec<ast::Argument<ExpandedArgParams>> {
+ args.into_iter()
+ .map(|a| ast::Argument {
+ name: fn_resolver.add_def(a.name, Some(ast::Type::Scalar(a.a_type))),
+ a_type: a.a_type,
+ length: a.length,
+ })
.collect()
}
fn to_ssa<'a, 'b>(
mut id_defs: FnStringIdResolver<'a, 'b>,
- f_args: &'b [ast::Argument<ast::ParsedArgParams<'a>>],
f_body: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
) -> Vec<ExpandedStatement> {
- let normalized_ids = normalize_identifiers(&mut id_defs, &f_args, f_body);
+ let normalized_ids = normalize_identifiers(&mut id_defs, f_body);
let mut numeric_id_defs = id_defs.finish();
let normalized_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut numeric_id_defs);
@@ -593,7 +623,7 @@ fn insert_implicit_conversions(
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- args: &[ast::Argument<ast::ParsedArgParams>],
+ args: &[ast::Argument<ExpandedArgParams>],
) -> spirv::Word {
map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::from(arg.a_type)))
}
@@ -603,17 +633,15 @@ fn emit_function_args(
map: &mut TypeWordMap,
args: &[ast::Argument<ExpandedArgParams>],
) {
- let mut id = todo!();
for arg in args {
let result_type = map.get_or_add_scalar(builder, arg.a_type);
let inst = dr::Instruction::new(
spirv::Op::FunctionParameter,
Some(result_type),
- Some(id),
+ Some(arg.name),
Vec::new(),
);
builder.function.as_mut().unwrap().parameters.push(inst);
- id += 1;
}
}
@@ -1095,12 +1123,8 @@ fn emit_implicit_conversion(
// TODO: support scopes
fn normalize_identifiers<'a, 'b>(
id_defs: &mut FnStringIdResolver<'a, 'b>,
- args: &[ast::Argument<ast::ParsedArgParams<'a>>],
func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
) -> Vec<ast::Statement<NormalizedArgParams>> {
- for arg in args {
- id_defs.add_def(arg.name, Some(ast::Type::Scalar(arg.a_type)));
- }
for s in func.iter() {
match s {
ast::Statement::Label(id) => {
@@ -1180,8 +1204,8 @@ impl<'a> GlobalStringIdResolver<'a> {
numeric_id
}
- fn reserve_id(&mut self) {
- self.current_id += 1;
+ fn get_id(&self, id: &str) -> spirv::Word {
+ self.variables[id]
}
fn current_id(&self) -> spirv::Word {
@@ -1196,7 +1220,8 @@ struct FnStringIdResolver<'a, 'b> {
}
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
- fn new(global: &'b mut GlobalStringIdResolver<'a>) -> Self {
+ fn new(global: &'b mut GlobalStringIdResolver<'a>, f_name: &'a str) -> Self {
+ global.add_def(f_name);
Self {
global: global,
variables: vec![HashMap::new(); 1],
@@ -1229,6 +1254,10 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
self.global.variables[id]
}
+ fn add_global_def(&mut self, id: &'a str) -> spirv::Word {
+ self.global.add_def(id)
+ }
+
fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
let numeric_id = self.global.current_id;
self.variables
@@ -1294,25 +1323,6 @@ enum Statement<I> {
Constant(ConstantDefinition),
}
-impl Statement<ast::Instruction<ExpandedArgParams>> {
- fn visit_variable<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
- match self {
- Statement::Variable(id, t, ss, align) => Statement::Variable(f(id), t, ss, align),
- Statement::LoadVar(a, t) => {
- Statement::LoadVar(a.map(&mut reduced_visitor(f), Some(t)), t)
- }
- Statement::StoreVar(a, t) => {
- Statement::StoreVar(a.map(&mut reduced_visitor(f), Some(t)), t)
- }
- Statement::Label(id) => Statement::Label(f(id)),
- Statement::Instruction(inst) => Statement::Instruction(inst.visit_variable(f)),
- Statement::Conditional(bra) => Statement::Conditional(bra.map(f)),
- Statement::Conversion(conv) => Statement::Conversion(conv.map(f)),
- Statement::Constant(cons) => Statement::Constant(cons.map(f)),
- }
- }
-}
-
enum NormalizedArgParams {}
type NormalizedStatement = Statement<ast::Instruction<NormalizedArgParams>>;
@@ -1513,18 +1523,7 @@ where
}
}
-fn reduced_visitor<'a>(
- f: &'a mut impl FnMut(spirv::Word) -> spirv::Word,
-) -> impl FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word + 'a {
- move |desc| f(desc.op)
-}
-
impl ast::Instruction<ExpandedArgParams> {
- fn visit_variable<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
- let mut visitor = reduced_visitor(f);
- self.map(&mut visitor)
- }
-
fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
self,
f: &mut F,
@@ -1562,32 +1561,12 @@ struct ConstantDefinition {
pub value: i128,
}
-impl ConstantDefinition {
- fn map<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
- Self {
- dst: f(self.dst),
- typ: self.typ,
- value: self.value,
- }
- }
-}
-
struct BrachCondition {
predicate: spirv::Word,
if_true: spirv::Word,
if_false: spirv::Word,
}
-impl BrachCondition {
- fn map<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
- Self {
- predicate: f(self.predicate),
- if_true: f(self.if_true),
- if_false: f(self.if_false),
- }
- }
-}
-
struct ImplicitConversion {
src: spirv::Word,
dst: spirv::Word,
@@ -1604,18 +1583,6 @@ enum ConversionKind {
Ptr(ast::LdStateSpace),
}
-impl ImplicitConversion {
- fn map<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
- Self {
- src: f(self.src),
- dst: f(self.dst),
- from: self.from,
- to: self.to,
- kind: self.kind,
- }
- }
-}
-
impl<T> ast::PredAt<T> {
fn map_variable<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::PredAt<U> {
ast::PredAt {
@@ -2354,6 +2321,15 @@ fn insert_implicit_bitcasts(
}
}
+impl<'a> ast::FunctionHeader<'a, ast::ParsedArgParams<'a>> {
+ fn name(&self) -> &'a str {
+ match self {
+ ast::FunctionHeader::Kernel(name) => name,
+ ast::FunctionHeader::Func(_, name) => name,
+ }
+ }
+}
+
// CFGs below taken from "Modern Compiler Implementation in Java"
#[cfg(test)]
mod tests {