diff options
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r-- | ptx/src/translate.rs | 3316 |
1 files changed, 1544 insertions, 1772 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7170950..c2562c3 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,11 +1,9 @@ use crate::ast;
use half::f16;
use rspirv::dr;
-use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem};
-use std::{
- collections::{hash_map, HashMap, HashSet},
- convert::TryInto,
-};
+use std::cell::RefCell;
+use std::collections::{hash_map, HashMap, HashSet};
+use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem, rc::Rc};
use rspirv::binary::Assemble;
@@ -48,64 +46,21 @@ enum SpirvType { }
impl SpirvType {
- fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
- let key = t.into();
- SpirvType::Pointer(Box::new(key), sc)
- }
-}
-
-impl From<ast::Type> for SpirvType {
- fn from(t: ast::Type) -> Self {
+ fn new(t: ast::Type) -> Self {
match t {
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
- ast::Type::Pointer(pointer_t, state_space) => SpirvType::Pointer(
- Box::new(SpirvType::from(ast::Type::from(pointer_t))),
- state_space.to_spirv(),
+ ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer(
+ Box::new(SpirvType::Base(pointer_t.into())),
+ space.to_spirv(),
),
}
}
-}
-impl From<ast::PointerType> for ast::Type {
- fn from(t: ast::PointerType) -> Self {
- match t {
- ast::PointerType::Scalar(t) => ast::Type::Scalar(t),
- ast::PointerType::Vector(t, len) => ast::Type::Vector(t, len),
- ast::PointerType::Array(t, dims) => ast::Type::Array(t, dims),
- ast::PointerType::Pointer(t, space) => {
- ast::Type::Pointer(ast::PointerType::Scalar(t), space)
- }
- }
- }
-}
-
-impl ast::Type {
- fn param_pointer_to(self, space: ast::LdStateSpace) -> Result<Self, TranslateError> {
- Ok(match self {
- ast::Type::Scalar(t) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
- ast::Type::Vector(t, len) => {
- ast::Type::Pointer(ast::PointerType::Vector(t, len), space)
- }
- ast::Type::Array(t, _) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
- ast::Type::Pointer(ast::PointerType::Scalar(t), space) => {
- ast::Type::Pointer(ast::PointerType::Pointer(t, space), space)
- }
- ast::Type::Pointer(_, _) => return Err(error_unreachable()),
- })
- }
-}
-
-impl Into<spirv::StorageClass> for ast::PointerStateSpace {
- fn into(self) -> spirv::StorageClass {
- match self {
- ast::PointerStateSpace::Const => spirv::StorageClass::UniformConstant,
- ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
- ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup,
- ast::PointerStateSpace::Param => spirv::StorageClass::Function,
- ast::PointerStateSpace::Generic => spirv::StorageClass::Generic,
- }
+ fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self {
+ let key = Self::new(t);
+ SpirvType::Pointer(Box::new(key), outer_space)
}
}
@@ -213,14 +168,18 @@ impl TypeWordMap { .or_insert_with(|| b.type_vector(None, base, len as u32))
}
SpirvType::Array(typ, array_dimensions) => {
- let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let (base_type, length) = match &*array_dimensions {
+ &[] => {
+ return self.get_or_add(b, SpirvType::Base(typ));
+ }
&[len] => {
+ let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let base = self.get_or_add_spirv_scalar(b, typ);
let len_const = b.constant_u32(u32_type, None, len);
(base, len_const)
}
array_dimensions => {
+ let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let base = self
.get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec()));
let len_const = b.constant_u32(u32_type, None, array_dimensions[0]);
@@ -262,7 +221,7 @@ impl TypeWordMap { fn get_or_add_fn(
&mut self,
b: &mut dr::Builder,
- in_params: impl ExactSizeIterator<Item = SpirvType>,
+ in_params: impl Iterator<Item = SpirvType>,
mut out_params: impl ExactSizeIterator<Item = SpirvType>,
) -> (spirv::Word, spirv::Word) {
let (out_args, out_spirv_type) = if out_params.len() == 0 {
@@ -274,6 +233,7 @@ impl TypeWordMap { self.get_or_add(b, arg_as_key),
)
} else {
+ // TODO: support multiple return values
todo!()
};
(
@@ -410,18 +370,7 @@ impl TypeWordMap { b.constant_composite(result_type, None, components.into_iter())
}
},
- ast::Type::Pointer(typ, state_space) => {
- let base_t = typ.clone().into();
- let base = self.get_or_add_constant(b, &base_t, &[])?;
- let result_type = self.get_or_add(
- b,
- SpirvType::Pointer(
- Box::new(SpirvType::from(base_t)),
- (*state_space).to_spirv(),
- ),
- );
- b.variable(result_type, None, (*state_space).to_spirv(), Some(base))
- }
+ ast::Type::Pointer(..) => return Err(error_unreachable()),
})
}
@@ -487,7 +436,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro .collect::<Vec<_>>();
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
- let call_map = get_call_map(&directives);
+ let call_map = get_kernels_call_map(&directives);
let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
@@ -525,9 +474,12 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro }
// TODO: remove this once we have perf-function support for denorms
-fn emit_denorm_build_string(
+fn emit_denorm_build_string<'input>(
call_map: &HashMap<&str, HashSet<u32>>,
- denorm_information: &HashMap<MethodName, HashMap<u8, (spirv::FPDenormMode, isize)>>,
+ denorm_information: &HashMap<
+ ast::MethodName<'input, spirv::Word>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
) -> CString {
let denorm_counts = denorm_information
.iter()
@@ -545,10 +497,12 @@ fn emit_denorm_build_string( .collect::<HashMap<_, _>>();
let mut flush_over_preserve = 0;
for (kernel, children) in call_map {
- flush_over_preserve += *denorm_counts.get(&MethodName::Kernel(kernel)).unwrap_or(&0);
+ flush_over_preserve += *denorm_counts
+ .get(&ast::MethodName::Kernel(kernel))
+ .unwrap_or(&0);
for child_fn in children {
flush_over_preserve += *denorm_counts
- .get(&MethodName::Func(*child_fn))
+ .get(&ast::MethodName::Func(*child_fn))
.unwrap_or(&0);
}
}
@@ -564,15 +518,18 @@ fn emit_directives<'input>( map: &mut TypeWordMap,
id_defs: &GlobalStringIdResolver<'input>,
opencl_id: spirv::Word,
- denorm_information: &HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
+ denorm_information: &HashMap<
+ ast::MethodName<'input, spirv::Word>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
- directives: Vec<Directive>,
+ directives: Vec<Directive<'input>>,
kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
let empty_body = Vec::new();
for d in directives.iter() {
match d {
- Directive::Variable(var) => {
+ Directive::Variable(_, var) => {
emit_variable(builder, map, &var)?;
}
Directive::Method(f) => {
@@ -589,12 +546,13 @@ fn emit_directives<'input>( for var in f.globals.iter() {
emit_variable(builder, map, var)?;
}
+ let func_decl = (*f.func_decl).borrow();
let fn_id = emit_function_header(
builder,
map,
&id_defs,
&f.globals,
- &f.spirv_decl,
+ &*func_decl,
&denorm_information,
call_map,
&directives,
@@ -623,8 +581,13 @@ fn emit_directives<'input>( }
emit_function_body_ops(builder, map, opencl_id, &f_body)?;
builder.end_function()?;
- if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) =
- (&f.func_decl, &f.import_as)
+ if let (
+ ast::MethodDeclaration {
+ name: ast::MethodName::Func(fn_id),
+ ..
+ },
+ Some(name),
+ ) = (&*func_decl, &f.import_as)
{
builder.decorate(
*fn_id,
@@ -643,7 +606,7 @@ fn emit_directives<'input>( Ok(())
}
-fn get_call_map<'input>(
+fn get_kernels_call_map<'input>(
module: &[Directive<'input>],
) -> HashMap<&'input str, HashSet<spirv::Word>> {
let mut directly_called_by = HashMap::new();
@@ -654,14 +617,14 @@ fn get_call_map<'input>( body: Some(statements),
..
}) => {
- let call_key = MethodName::new(&func_decl);
+ let call_key: ast::MethodName<_> = (**func_decl).borrow().name;
if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
entry.insert(Vec::new());
}
for statement in statements {
match statement {
Statement::Call(call) => {
- multi_hash_map_append(&mut directly_called_by, call_key, call.func);
+ multi_hash_map_append(&mut directly_called_by, call_key, call.name);
}
_ => {}
}
@@ -673,28 +636,28 @@ fn get_call_map<'input>( let mut result = HashMap::new();
for (method_key, children) in directly_called_by.iter() {
match method_key {
- MethodName::Kernel(name) => {
+ ast::MethodName::Kernel(name) => {
let mut visited = HashSet::new();
for child in children {
add_call_map_single(&directly_called_by, &mut visited, *child);
}
result.insert(*name, visited);
}
- MethodName::Func(_) => {}
+ ast::MethodName::Func(_) => {}
}
}
result
}
fn add_call_map_single<'input>(
- directly_called_by: &MultiHashMap<MethodName<'input>, spirv::Word>,
+ directly_called_by: &MultiHashMap<ast::MethodName<'input, spirv::Word>, spirv::Word>,
visited: &mut HashSet<spirv::Word>,
current: spirv::Word,
) {
if !visited.insert(current) {
return;
}
- if let Some(children) = directly_called_by.get(&MethodName::Func(current)) {
+ if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) {
for child in children {
add_call_map_single(directly_called_by, visited, *child);
}
@@ -714,11 +677,29 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>, }
}
-// PTX represents dynamically allocated shared local memory as
-// .extern .shared .align 4 .b8 shared_mem[];
-// In SPIRV/OpenCL world this is expressed as an additional argument
-// This pass looks for all uses of .extern .shared and converts them to
-// an additional method argument
+/*
+ PTX represents dynamically allocated shared local memory as
+ .extern .shared .b32 shared_mem[];
+ In SPIRV/OpenCL world this is expressed as an additional argument
+ This pass looks for all uses of .extern .shared and converts them to
+ an additional method argument
+ The question is how this artificial argument should be expressed. There are
+ several options:
+ * Straight conversion:
+ .shared .b32 shared_mem[]
+ * Introduce .param_shared statespace:
+ .param_shared .b32 shared_mem
+ or
+ .param_shared .b32 shared_mem[]
+ * Introduce .shared_ptr <SCALAR> type:
+ .param .shared_ptr .b32 shared_mem
+ * Reuse .ptr hint:
+ .param .u64 .ptr shared_mem
+ This is the most tempting, but also the most nonsensical, .ptr is just a
+ hint, which has no semantical meaning (and the output of our
+ transformation has a semantical meaning - we emit additional
+ "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
+*/
fn convert_dynamic_shared_memory_usage<'input>(
module: Vec<Directive<'input>>,
new_id: &mut impl FnMut() -> spirv::Word,
@@ -726,12 +707,16 @@ fn convert_dynamic_shared_memory_usage<'input>( let mut extern_shared_decls = HashMap::new();
for dir in module.iter() {
match dir {
- Directive::Variable(var) => {
- if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) =
- var.v_type
- {
- extern_shared_decls.insert(var.name, p_type);
- }
+ Directive::Variable(
+ linking,
+ ast::Variable {
+ v_type: ast::Type::Array(p_type, dims),
+ state_space: ast::StateSpace::Shared,
+ name,
+ ..
+ },
+ ) if linking.contains(ast::LinkingDirective::EXTERN) && dims.len() == 0 => {
+ extern_shared_decls.insert(*name, *p_type);
}
_ => {}
}
@@ -749,15 +734,14 @@ fn convert_dynamic_shared_memory_usage<'input>( globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
}) => {
- let call_key = MethodName::new(&func_decl);
+ let call_key = (*func_decl).borrow().name;
let statements = statements
.into_iter()
.map(|statement| match statement {
Statement::Call(call) => {
- multi_hash_map_append(&mut directly_called_by, call.func, call_key);
+ multi_hash_map_append(&mut directly_called_by, call.name, call_key);
Statement::Call(call)
}
statement => statement.map_id(&mut |id, _| {
@@ -773,7 +757,6 @@ fn convert_dynamic_shared_memory_usage<'input>( globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
})
}
@@ -792,66 +775,34 @@ fn convert_dynamic_shared_memory_usage<'input>( globals,
body: Some(statements),
import_as,
- mut spirv_decl,
tuning,
}) => {
- if !methods_using_extern_shared.contains(&spirv_decl.name) {
+ if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) {
return Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
});
}
let shared_id_param = new_id();
- spirv_decl.input.push({
- ast::Variable {
- align: None,
- v_type: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Shared,
- ),
- array_init: Vec::new(),
- name: shared_id_param,
- }
- });
- spirv_decl.uses_shared_mem = true;
- let shared_var_id = new_id();
- let shared_var = ExpandedStatement::Variable(ast::Variable {
- align: None,
- name: shared_var_id,
- array_init: Vec::new(),
- v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
- ast::SizedScalarType::B8,
- ast::PointerStateSpace::Shared,
- )),
- });
- let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
- arg: ast::Arg2St {
- src1: shared_var_id,
- src2: shared_id_param,
- },
- typ: ast::Type::Scalar(ast::ScalarType::B8),
- member_index: None,
- });
- let mut new_statements = vec![shared_var, shared_var_st];
- replace_uses_of_shared_memory(
- &mut new_statements,
+ {
+ let mut func_decl = (*func_decl).borrow_mut();
+ func_decl.shared_mem = Some(shared_id_param);
+ }
+ let statements = replace_uses_of_shared_memory(
new_id,
&extern_shared_decls,
&mut methods_using_extern_shared,
shared_id_param,
- shared_var_id,
statements,
);
Directive::Method(Function {
func_decl,
globals,
- body: Some(new_statements),
+ body: Some(statements),
import_as,
- spirv_decl,
tuning,
})
}
@@ -861,47 +812,43 @@ fn convert_dynamic_shared_memory_usage<'input>( }
fn replace_uses_of_shared_memory<'a>(
- result: &mut Vec<ExpandedStatement>,
new_id: &mut impl FnMut() -> spirv::Word,
- extern_shared_decls: &HashMap<spirv::Word, ast::SizedScalarType>,
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
+ extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
+ methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
shared_id_param: spirv::Word,
- shared_var_id: spirv::Word,
statements: Vec<ExpandedStatement>,
-) {
+) -> Vec<ExpandedStatement> {
+ let mut result = Vec::with_capacity(statements.len());
for statement in statements {
match statement {
Statement::Call(mut call) => {
// We can safely skip checking call arguments,
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
- if methods_using_extern_shared.contains(&MethodName::Func(call.func)) {
- call.param_list
- .push((shared_id_param, ast::FnArgumentType::Shared));
+ if methods_using_extern_shared.contains(&ast::MethodName::Func(call.name)) {
+ call.input_arguments.push((
+ shared_id_param,
+ ast::Type::Scalar(ast::ScalarType::B8),
+ ast::StateSpace::Shared,
+ ));
}
result.push(Statement::Call(call))
}
statement => {
let new_statement = statement.map_id(&mut |id, _| {
- if let Some(typ) = extern_shared_decls.get(&id) {
- if *typ == ast::SizedScalarType::B8 {
- return shared_var_id;
+ if let Some(scalar_type) = extern_shared_decls.get(&id) {
+ if *scalar_type == ast::ScalarType::B8 {
+ return shared_id_param;
}
let replacement_id = new_id();
result.push(Statement::Conversion(ImplicitConversion {
- src: shared_var_id,
+ src: shared_id_param,
dst: replacement_id,
- from: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::B8),
- ast::LdStateSpace::Shared,
- ),
- to: ast::Type::Pointer(
- ast::PointerType::Scalar((*typ).into()),
- ast::LdStateSpace::Shared,
- ),
- kind: ConversionKind::PtrToPtr { spirv_ptr: true },
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
+ from_type: ast::Type::Scalar(ast::ScalarType::B8),
+ from_space: ast::StateSpace::Shared,
+ to_type: ast::Type::Scalar(*scalar_type),
+ to_space: ast::StateSpace::Shared,
+ kind: ConversionKind::PtrToPtr,
}));
replacement_id
} else {
@@ -912,16 +859,17 @@ fn replace_uses_of_shared_memory<'a>( }
}
}
+ result
}
fn get_callers_of_extern_shared<'a>(
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
- directly_called_by: &MultiHashMap<spirv::Word, MethodName<'a>>,
+ methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
+ directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
) {
let direct_uses_of_extern_shared = methods_using_extern_shared
.iter()
.filter_map(|method| {
- if let MethodName::Func(f_id) = method {
+ if let ast::MethodName::Func(f_id) = method {
Some(*f_id)
} else {
None
@@ -934,14 +882,14 @@ fn get_callers_of_extern_shared<'a>( }
fn get_callers_of_extern_shared_single<'a>(
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
- directly_called_by: &MultiHashMap<spirv::Word, MethodName<'a>>,
+ methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
+ directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
fn_id: spirv::Word,
) {
if let Some(callers) = directly_called_by.get(&fn_id) {
for caller in callers {
if methods_using_extern_shared.insert(*caller) {
- if let MethodName::Func(caller_fn) = caller {
+ if let ast::MethodName::Func(caller_fn) = caller {
get_callers_of_extern_shared_single(
methods_using_extern_shared,
directly_called_by,
@@ -983,18 +931,18 @@ fn denorm_count_map_update_impl<T: Eq + Hash>( // and emit suitable execution mode
fn compute_denorm_information<'input>(
module: &[Directive<'input>],
-) -> HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
+) -> HashMap<ast::MethodName<'input, spirv::Word>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
let mut denorm_methods = HashMap::new();
for directive in module {
match directive {
- Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {}
+ Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {}
Directive::Method(Function {
func_decl,
body: Some(statements),
..
}) => {
let mut flush_counter = DenormCountMap::new();
- let method_key = MethodName::new(func_decl);
+ let method_key = (**func_decl).borrow().name;
for statement in statements {
match statement {
Statement::Instruction(inst) => {
@@ -1038,21 +986,6 @@ fn compute_denorm_information<'input>( .collect()
}
-#[derive(Hash, PartialEq, Eq, Copy, Clone)]
-enum MethodName<'input> {
- Kernel(&'input str),
- Func(spirv::Word),
-}
-
-impl<'input> MethodName<'input> {
- fn new(decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
- match decl {
- ast::MethodDecl::Kernel { name, .. } => MethodName::Kernel(name),
- ast::MethodDecl::Func(_, id, _) => MethodName::Func(*id),
- }
- }
-}
-
fn emit_builtins(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -1061,10 +994,7 @@ fn emit_builtins( for (reg, id) in id_defs.special_registers.builtins() {
let result_type = map.get_or_add(
builder,
- SpirvType::Pointer(
- Box::new(SpirvType::from(reg.get_type())),
- spirv::StorageClass::Input,
- ),
+ SpirvType::pointer_to(reg.get_type(), spirv::StorageClass::Input),
);
builder.variable(result_type, Some(id), spirv::StorageClass::Input, None);
builder.decorate(
@@ -1079,18 +1009,21 @@ fn emit_function_header<'a>( builder: &mut dr::Builder,
map: &mut TypeWordMap,
defined_globals: &GlobalStringIdResolver<'a>,
- synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
- func_decl: &SpirvMethodDecl<'a>,
- _denorm_information: &HashMap<MethodName<'a>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
+ synthetic_globals: &[ast::Variable<spirv::Word>],
+ func_decl: &ast::MethodDeclaration<'a, spirv::Word>,
+ _denorm_information: &HashMap<
+ ast::MethodName<'a, spirv::Word>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
direcitves: &[Directive],
kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<spirv::Word, TranslateError> {
- if let MethodName::Kernel(name) = func_decl.name {
- let input_args = if !func_decl.uses_shared_mem {
- func_decl.input.as_slice()
+ if let ast::MethodName::Kernel(name) = func_decl.name {
+ let input_args = if func_decl.shared_mem.is_none() {
+ func_decl.input_arguments.as_slice()
} else {
- &func_decl.input[0..func_decl.input.len() - 1]
+ &func_decl.input_arguments[0..func_decl.input_arguments.len() - 1]
};
let args_lens = input_args
.iter()
@@ -1100,14 +1033,18 @@ fn emit_function_header<'a>( name.to_string(),
KernelInfo {
arguments_sizes: args_lens,
- uses_shared_mem: func_decl.uses_shared_mem,
+ uses_shared_mem: func_decl.shared_mem.is_some(),
},
);
}
- let (ret_type, func_type) =
- get_function_type(builder, map, &func_decl.input, &func_decl.output);
+ let (ret_type, func_type) = get_function_type(
+ builder,
+ map,
+ func_decl.effective_input_arguments().map(|(_, typ)| typ),
+ &func_decl.return_arguments,
+ );
let fn_id = match func_decl.name {
- MethodName::Kernel(name) => {
+ ast::MethodName::Kernel(name) => {
let fn_id = defined_globals.get_id(name)?;
let mut global_variables = defined_globals
.variables_type_check
@@ -1123,15 +1060,18 @@ fn emit_function_header<'a>( for directive in direcitves {
match directive {
Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- globals,
- ..
+ func_decl, globals, ..
}) => {
- if child_fns.contains(name) {
- for var in globals {
- interface.push(var.name);
+ match (**func_decl).borrow().name {
+ ast::MethodName::Func(name) => {
+ if child_fns.contains(&name) {
+ for var in globals {
+ interface.push(var.name);
+ }
+ }
}
- }
+ ast::MethodName::Kernel(_) => {}
+ };
}
_ => {}
}
@@ -1140,7 +1080,7 @@ fn emit_function_header<'a>( builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
fn_id
}
- MethodName::Func(name) => name,
+ ast::MethodName::Func(name) => name,
};
builder.begin_function(
ret_type,
@@ -1163,9 +1103,9 @@ fn emit_function_header<'a>( }
}
*/
- for input in &func_decl.input {
- let result_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone()));
- builder.function_parameter(Some(input.name), result_type)?;
+ for (name, typ) in func_decl.effective_input_arguments() {
+ let result_type = map.get_or_add(builder, typ);
+ builder.function_parameter(Some(name), result_type)?;
}
Ok(fn_id)
}
@@ -1207,55 +1147,32 @@ fn translate_directive<'input>( d: ast::Directive<'input, ast::ParsedArgParams<'input>>,
) -> Result<Option<Directive<'input>>, TranslateError> {
Ok(match d {
- ast::Directive::Variable(v) => Some(Directive::Variable(translate_variable(id_defs, v)?)),
- ast::Directive::Method(f) => {
+ ast::Directive::Variable(linking, var) => Some(Directive::Variable(
+ linking,
+ ast::Variable {
+ align: var.align,
+ v_type: var.v_type.clone(),
+ state_space: var.state_space,
+ name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
+ array_init: var.array_init,
+ },
+ )),
+ ast::Directive::Method(_, f) => {
translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method)
}
})
}
-fn translate_variable<'a>(
- id_defs: &mut GlobalStringIdResolver<'a>,
- var: ast::Variable<ast::VariableType, &'a str>,
-) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> {
- let (space, var_type) = var.v_type.to_type();
- let mut is_variable = false;
- let var_type = match space {
- ast::StateSpace::Reg => {
- is_variable = true;
- var_type
- }
- ast::StateSpace::Const => var_type.param_pointer_to(ast::LdStateSpace::Const)?,
- ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?,
- ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?,
- ast::StateSpace::Shared => {
- // If it's a pointer it will be translated to a method parameter later
- if let ast::Type::Pointer(..) = var_type {
- is_variable = true;
- var_type
- } else {
- var_type.param_pointer_to(ast::LdStateSpace::Shared)?
- }
- }
- ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
- };
- Ok(ast::Variable {
- align: var.align,
- v_type: var.v_type,
- name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable),
- array_init: var.array_init,
- })
-}
-
fn translate_function<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
ptx_impl_imports: &mut HashMap<String, Directive<'a>>,
f: ast::ParsedFunction<'a>,
) -> Result<Option<Function<'a>>, TranslateError> {
let import_as = match &f.func_directive {
- ast::MethodDecl::Func(_, "__assertfail", _) => {
- Some("__zluda_ptx_impl____assertfail".to_owned())
- }
+ ast::MethodDeclaration {
+ name: ast::MethodName::Func("__assertfail"),
+ ..
+ } => Some("__zluda_ptx_impl____assertfail".to_owned()),
_ => None,
};
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?;
@@ -1279,63 +1196,38 @@ fn translate_function<'a>( }
}
-fn expand_kernel_params<'a, 'b>(
+fn rename_fn_params<'a, 'b>(
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: impl Iterator<Item = &'b ast::KernelArgument<&'a str>>,
-) -> Result<Vec<ast::KernelArgument<spirv::Word>>, TranslateError> {
- args.map(|a| {
- Ok(ast::KernelArgument {
- name: fn_resolver.add_def(
- a.name,
- Some(ast::Type::from(a.v_type.clone()).param_pointer_to(ast::LdStateSpace::Param)?),
- false,
- ),
+ args: &'b [ast::Variable<&'a str>],
+) -> Vec<ast::Variable<spirv::Word>> {
+ args.iter()
+ .map(|a| ast::Variable {
+ name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true),
v_type: a.v_type.clone(),
+ state_space: a.state_space,
align: a.align,
- array_init: Vec::new(),
+ array_init: a.array_init.clone(),
})
- })
- .collect::<Result<_, _>>()
-}
-
-fn expand_fn_params<'a, 'b>(
- fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
-) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
- args.map(|a| {
- let is_variable = match a.v_type {
- ast::FnArgumentType::Reg(_) => true,
- _ => false,
- };
- let var_type = a.v_type.to_func_type();
- Ok(ast::FnArgument {
- name: fn_resolver.add_def(a.name, Some(var_type), is_variable),
- v_type: a.v_type.clone(),
- align: a.align,
- array_init: Vec::new(),
- })
- })
- .collect()
+ .collect()
}
fn to_ssa<'input, 'b>(
ptx_impl_imports: &mut HashMap<String, Directive>,
mut id_defs: FnStringIdResolver<'input, 'b>,
fn_defs: GlobalFnDeclResolver<'input, 'b>,
- f_args: ast::MethodDecl<'input, spirv::Word>,
+ func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
tuning: Vec<ast::TuningDirective>,
) -> Result<Function<'input>, TranslateError> {
- let mut spirv_decl = SpirvMethodDecl::new(&f_args);
+ //deparamize_function_decl(&func_decl)?;
let f_body = match f_body {
Some(vec) => vec,
None => {
return Ok(Function {
- func_decl: f_args,
+ func_decl: func_decl,
body: None,
globals: Vec::new(),
import_as: None,
- spirv_decl,
tuning,
})
}
@@ -1345,15 +1237,14 @@ fn to_ssa<'input, 'b>( let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
- let typed_statements =
- convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
+ let (func_decl, typed_statements) =
+ convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
typed_statements,
&mut numeric_id_defs,
- &f_args,
- &mut spirv_decl,
+ &mut (*func_decl).borrow_mut(),
)?;
- let ssa_statements = fix_builtins(ssa_statements, &mut numeric_id_defs)?;
+ let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
@@ -1363,16 +1254,15 @@ fn to_ssa<'input, 'b>( let (f_body, globals) =
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs);
Ok(Function {
- func_decl: f_args,
+ func_decl: func_decl,
globals: globals,
body: Some(f_body),
import_as: None,
- spirv_decl,
tuning,
})
}
-fn fix_builtins(
+fn fix_special_registers(
typed_statements: Vec<TypedStatement>,
numeric_id_defs: &mut NumericIdResolver,
) -> Result<Vec<TypedStatement>, TranslateError> {
@@ -1408,7 +1298,8 @@ fn fix_builtins( continue;
}
};
- let temp_id = numeric_id_defs.new_non_variable(Some(details.typ.clone()));
+ let temp_id = numeric_id_defs
+ .register_intermediate(Some((details.typ.clone(), details.state_space)));
let real_dst = details.arg.dst;
details.arg.dst = temp_id;
result.push(Statement::LoadVar(LoadVarDetails {
@@ -1416,17 +1307,18 @@ fn fix_builtins( src: sreg_src,
dst: temp_id,
},
+ state_space: ast::StateSpace::Sreg,
typ: ast::Type::Scalar(scalar_typ),
member_index: Some((index, Some(vector_width))),
}));
result.push(Statement::Conversion(ImplicitConversion {
src: temp_id,
dst: real_dst,
- from: ast::Type::Scalar(scalar_typ),
- to: ast::Type::Scalar(ast::ScalarType::U32),
+ from_type: ast::Type::Scalar(scalar_typ),
+ from_space: ast::StateSpace::Sreg,
+ to_type: ast::Type::Scalar(ast::ScalarType::U32),
+ to_space: ast::StateSpace::Sreg,
kind: ConversionKind::Default,
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
}));
}
}
@@ -1456,10 +1348,7 @@ fn extract_globals<'input, 'b>( sorted_statements: Vec<ExpandedStatement>,
ptx_impl_imports: &mut HashMap<String, Directive>,
id_def: &mut NumericIdResolver,
-) -> (
- Vec<ExpandedStatement>,
- Vec<ast::Variable<ast::VariableType, spirv::Word>>,
-) {
+) -> (Vec<ExpandedStatement>, Vec<ast::Variable<spirv::Word>>) {
let mut local = Vec::with_capacity(sorted_statements.len());
let mut global = Vec::new();
for statement in sorted_statements {
@@ -1468,7 +1357,7 @@ fn extract_globals<'input, 'b>( var
@
ast::Variable {
- v_type: ast::VariableType::Shared(_),
+ state_space: ast::StateSpace::Shared,
..
},
)
@@ -1476,7 +1365,7 @@ fn extract_globals<'input, 'b>( var
@
ast::Variable {
- v_type: ast::VariableType::Global(_),
+ state_space: ast::StateSpace::Global,
..
},
) => global.push(var),
@@ -1505,7 +1394,7 @@ fn extract_globals<'input, 'b>( d,
a,
"inc",
- ast::SizedScalarType::U32,
+ ast::ScalarType::U32,
));
}
Statement::Instruction(ast::Instruction::Atom(
@@ -1527,7 +1416,7 @@ fn extract_globals<'input, 'b>( d,
a,
"dec",
- ast::SizedScalarType::U32,
+ ast::ScalarType::U32,
));
}
Statement::Instruction(ast::Instruction::Atom(
@@ -1553,10 +1442,9 @@ fn extract_globals<'input, 'b>( space,
};
let (op, typ) = match typ {
- ast::FloatType::F32 => ("add_f32", ast::SizedScalarType::F32),
- ast::FloatType::F64 => ("add_f64", ast::SizedScalarType::F64),
- ast::FloatType::F16 => unreachable!(),
- ast::FloatType::F16x2 => unreachable!(),
+ ast::ScalarType::F32 => ("add_f32", ast::ScalarType::F32),
+ ast::ScalarType::F64 => ("add_f64", ast::ScalarType::F64),
+ _ => unreachable!(),
};
local.push(to_ptx_impl_atomic_call(
id_def,
@@ -1599,47 +1487,13 @@ fn convert_to_typed_statements( match s {
Statement::Instruction(inst) => match inst {
ast::Instruction::Call(call) => {
- // TODO: error out if lengths don't match
- let fn_def = fn_defs.get_fn_decl(call.func)?;
- let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
- let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params);
- let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args
- .into_iter()
- .partition(|(_, arg_type)| arg_type.is_param());
- let normalized_input_args = out_params
- .into_iter()
- .map(|(id, typ)| (ast::Operand::Reg(id), typ))
- .chain(in_args.into_iter())
- .collect();
- let resolved_call = ResolvedCall {
- uniform: call.uniform,
- ret_params: out_non_params,
- func: call.func,
- param_list: normalized_input_args,
- };
+ let resolver = fn_defs.get_fn_sig_resolver(call.func)?;
+ let resolved_call = resolver.resolve_in_spirv_repr(call)?;
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let reresolved_call = resolved_call.visit(&mut visitor)?;
visitor.func.push(reresolved_call);
visitor.func.extend(visitor.post_stmts);
}
- ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => {
- if let Some(src_id) = src.underlying() {
- let (typ, _) = id_defs.get_typed(*src_id)?;
- let take_address = match typ {
- ast::Type::Scalar(_) => false,
- ast::Type::Vector(_, _) => false,
- ast::Type::Array(_, _) => true,
- ast::Type::Pointer(_, _) => true,
- };
- d.src_is_address = take_address;
- }
- let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
- let instruction = Statement::Instruction(
- ast::Instruction::Mov(d, ast::Arg2Mov { dst, src }).map(&mut visitor)?,
- );
- visitor.func.push(instruction);
- visitor.func.extend(visitor.post_stmts);
- }
inst => {
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let instruction = Statement::Instruction(inst.map(&mut visitor)?);
@@ -1674,8 +1528,14 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { fn convert_vector(
&mut self,
is_dst: bool,
- vector_sema: ArgumentSemantics,
+ non_default_implicit_conversion: Option<
+ fn(
+ (ast::StateSpace, &ast::Type),
+ (ast::StateSpace, &ast::Type),
+ ) -> Result<Option<ConversionKind>, TranslateError>,
+ >,
typ: &ast::Type,
+ state_space: ast::StateSpace,
idx: Vec<spirv::Word>,
) -> Result<spirv::Word, TranslateError> {
// mov.u32 foobar, {a,b};
@@ -1683,13 +1543,15 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { ast::Type::Vector(scalar_t, _) => *scalar_t,
_ => return Err(TranslateError::MismatchedType),
};
- let temp_vec = self.id_def.new_non_variable(Some(typ.clone()));
+ let temp_vec = self
+ .id_def
+ .register_intermediate(Some((typ.clone(), state_space)));
let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst,
typ: scalar_t,
packed: temp_vec,
unpacked: idx,
- vector_sema,
+ non_default_implicit_conversion,
});
if is_dst {
self.post_stmts = Some(statement);
@@ -1706,7 +1568,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams> fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- _: Option<&ast::Type>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
Ok(desc.op)
}
@@ -1715,15 +1577,20 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams> &mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
ast::Operand::Reg(reg) => TypedOperand::Reg(reg),
ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
ast::Operand::Imm(x) => TypedOperand::Imm(x),
ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
- ast::Operand::VecPack(vec) => {
- TypedOperand::Reg(self.convert_vector(desc.is_dst, desc.sema, typ, vec)?)
- }
+ ast::Operand::VecPack(vec) => TypedOperand::Reg(self.convert_vector(
+ desc.is_dst,
+ desc.non_default_implicit_conversion,
+ typ,
+ state_space,
+ vec,
+ )?),
})
}
}
@@ -1735,7 +1602,7 @@ fn to_ptx_impl_atomic_call( details: ast::AtomDetails,
arg: ast::Arg3<ExpandedArgParams>,
op: &'static str,
- typ: ast::SizedScalarType,
+ typ: ast::ScalarType,
) -> ExpandedStatement {
let semantics = ptx_semantics_name(details.semantics);
let scope = ptx_scope_name(details.scope);
@@ -1745,75 +1612,70 @@ fn to_ptx_impl_atomic_call( semantics, scope, space, op
);
// TODO: extract to a function
- let ptr_space = match details.space {
- ast::AtomSpace::Generic => ast::PointerStateSpace::Generic,
- ast::AtomSpace::Global => ast::PointerStateSpace::Global,
- ast::AtomSpace::Shared => ast::PointerStateSpace::Shared,
- };
+ let ptr_space = details.space;
let scalar_typ = ast::ScalarType::from(typ);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_non_variable(None);
- let func_decl = ast::MethodDecl::Func::<spirv::Word>(
- vec![ast::FnArgument {
+ let fn_id = id_defs.register_intermediate(None);
+ let func_decl = ast::MethodDeclaration::<spirv::Word> {
+ return_arguments: vec![ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(scalar_typ),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
- fn_id,
- vec![
- ast::FnArgument {
+ name: ast::MethodName::Func(fn_id),
+ input_arguments: vec![
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
- typ, ptr_space,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Pointer(typ, ptr_space),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(scalar_typ),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
- );
- let spirv_decl = SpirvMethodDecl::new(&func_decl);
+ shared_mem: None,
+ };
let func = Function {
- func_decl,
+ func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
- spirv_decl,
tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- ..
- }) => *name,
+ Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
+ ast::MethodName::Func(fn_id) => fn_id,
+ ast::MethodName::Kernel(_) => unreachable!(),
+ },
_ => unreachable!(),
},
};
Statement::Call(ResolvedCall {
uniform: false,
- func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
- )],
- param_list: vec![
+ name: fn_id,
+ return_arguments: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)],
+ input_arguments: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(typ, ptr_space)),
+ ast::Type::Pointer(typ, ptr_space),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
+ ast::Type::Scalar(scalar_typ),
+ ast::StateSpace::Reg,
),
],
})
@@ -1822,93 +1684,92 @@ fn to_ptx_impl_atomic_call( fn to_ptx_impl_bfe_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
- typ: ast::IntType,
+ typ: ast::ScalarType,
arg: ast::Arg4<ExpandedArgParams>,
) -> ExpandedStatement {
let prefix = "__zluda_ptx_impl__";
let suffix = match typ {
- ast::IntType::U32 => "bfe_u32",
- ast::IntType::U64 => "bfe_u64",
- ast::IntType::S32 => "bfe_s32",
- ast::IntType::S64 => "bfe_s64",
+ ast::ScalarType::U32 => "bfe_u32",
+ ast::ScalarType::U64 => "bfe_u64",
+ ast::ScalarType::S32 => "bfe_s32",
+ ast::ScalarType::S64 => "bfe_s64",
_ => unreachable!(),
};
let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_non_variable(None);
- let func_decl = ast::MethodDecl::Func::<spirv::Word>(
- vec![ast::FnArgument {
+ let fn_id = id_defs.register_intermediate(None);
+ let func_decl = ast::MethodDeclaration::<spirv::Word> {
+ return_arguments: vec![ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
- fn_id,
- vec![
- ast::FnArgument {
+ name: ast::MethodName::Func(fn_id),
+ input_arguments: vec![
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
- );
- let spirv_decl = SpirvMethodDecl::new(&func_decl);
+ shared_mem: None,
+ };
let func = Function {
- func_decl,
+ func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
- spirv_decl,
tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- ..
- }) => *name,
+ Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
+ ast::MethodName::Func(fn_id) => fn_id,
+ ast::MethodName::Kernel(_) => unreachable!(),
+ },
_ => unreachable!(),
},
};
Statement::Call(ResolvedCall {
uniform: false,
- func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- )],
- param_list: vec![
+ name: fn_id,
+ return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
+ input_arguments: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
(
arg.src3,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
],
})
@@ -1917,117 +1778,107 @@ fn to_ptx_impl_bfe_call( fn to_ptx_impl_bfi_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
- typ: ast::BitType,
+ typ: ast::ScalarType,
arg: ast::Arg5<ExpandedArgParams>,
) -> ExpandedStatement {
let prefix = "__zluda_ptx_impl__";
let suffix = match typ {
- ast::BitType::B32 => "bfi_b32",
- ast::BitType::B64 => "bfi_b64",
- ast::BitType::B8 | ast::BitType::B16 => unreachable!(),
+ ast::ScalarType::B32 => "bfi_b32",
+ ast::ScalarType::B64 => "bfi_b64",
+ _ => unreachable!(),
};
let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_non_variable(None);
- let func_decl = ast::MethodDecl::Func::<spirv::Word>(
- vec![ast::FnArgument {
+ let fn_id = id_defs.register_intermediate(None);
+ let func_decl = ast::MethodDeclaration::<spirv::Word> {
+ return_arguments: vec![ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
- fn_id,
- vec![
- ast::FnArgument {
+ name: ast::MethodName::Func(fn_id),
+ input_arguments: vec![
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
- );
- let spirv_decl = SpirvMethodDecl::new(&func_decl);
+ shared_mem: None,
+ };
let func = Function {
- func_decl,
+ func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
- spirv_decl,
tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- ..
- }) => *name,
+ Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
+ ast::MethodName::Func(fn_id) => fn_id,
+ ast::MethodName::Kernel(_) => unreachable!(),
+ },
_ => unreachable!(),
},
};
Statement::Call(ResolvedCall {
uniform: false,
- func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- )],
- param_list: vec![
+ name: fn_id,
+ return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
+ input_arguments: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src3,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
(
arg.src4,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
],
})
}
-fn to_resolved_fn_args<T>(
- params: Vec<T>,
- params_decl: &[ast::FnArgumentType],
-) -> Vec<(T, ast::FnArgumentType)> {
- params
- .into_iter()
- .zip(params_decl.iter())
- .map(|(id, typ)| (id, typ.clone()))
- .collect::<Vec<_>>()
-}
-
fn normalize_labels(
func: Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
@@ -2056,7 +1907,7 @@ fn normalize_labels( | Statement::RepackVector(..) => {}
}
}
- iter::once(Statement::Label(id_def.new_non_variable(None)))
+ iter::once(Statement::Label(id_def.register_intermediate(None)))
.chain(func.into_iter().filter(|s| match s {
Statement::Label(i) => labels_in_use.contains(i),
_ => true,
@@ -2074,8 +1925,8 @@ fn normalize_predicates( Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Instruction((pred, inst)) => {
if let Some(pred) = pred {
- let if_true = id_def.new_non_variable(None);
- let if_false = id_def.new_non_variable(None);
+ let if_true = id_def.register_intermediate(None);
+ let if_false = id_def.register_intermediate(None);
let folded_bra = match &inst {
ast::Instruction::Bra(_, arg) => Some(arg.src),
_ => None,
@@ -2106,53 +1957,52 @@ fn normalize_predicates( Ok(result)
}
+/*
+ How do we handle arguments:
+ - input .params in kernels
+ .param .b64 in_arg
+ get turned into this SPIR-V:
+ %1 = OpFunctionParameter %ulong
+ %2 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %1
+ We do this for two reasons. One, common treatment for argument-declared
+ .param variables and .param variables inside function (we assume that
+ at SPIR-V level every .param is a pointer in Function storage class)
+ - input .params in functions
+ .param .b64 in_arg
+ get turned into this SPIR-V:
+ %1 = OpFunctionParameter %_ptr_Function_ulong
+ - input .regs
+ .reg .b64 in_arg
+ get turned into the same SPIR-V as kernel .params:
+ %1 = OpFunctionParameter %ulong
+ %2 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %1
+ - output .regs
+ .reg .b64 out_arg
+ get just a variable declaration:
+ %2 = OpVariable %%_ptr_Function_ulong Function
+ - output .params don't exist, they have been moved to input positions
+ by an earlier pass
+ Distinguishing betweem kernel .params and function .params is not the
+ cleanest solution. Alternatively, we could "deparamize" all kernel .param
+ arguments by turning them into .reg arguments like this:
+ .param .b64 arg -> .reg ptr<.b64,.param> arg
+ This has the massive downside that this transformation would have to run
+ very early and would muddy up already difficult code. It's simpler to just
+ have an if here
+*/
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver,
- ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>,
- fn_decl: &mut SpirvMethodDecl,
+ fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>,
) -> Result<Vec<TypedStatement>, TranslateError> {
- let is_func = match ast_fn_decl {
- ast::MethodDecl::Func(..) => true,
- ast::MethodDecl::Kernel { .. } => false,
- };
let mut result = Vec::with_capacity(func.len());
- for arg in fn_decl.output.iter() {
- match type_to_variable_type(&arg.v_type, is_func)? {
- Some(var_type) => {
- result.push(Statement::Variable(ast::Variable {
- align: arg.align,
- v_type: var_type,
- name: arg.name,
- array_init: arg.array_init.clone(),
- }));
- }
- None => return Err(error_unreachable()),
- }
+ for arg in fn_decl.input_arguments.iter_mut() {
+ insert_mem_ssa_argument(id_def, &mut result, arg, fn_decl.name.is_kernel());
}
- for spirv_arg in fn_decl.input.iter_mut() {
- match type_to_variable_type(&spirv_arg.v_type, is_func)? {
- Some(var_type) => {
- let typ = spirv_arg.v_type.clone();
- let new_id = id_def.new_non_variable(Some(typ.clone()));
- result.push(Statement::Variable(ast::Variable {
- align: spirv_arg.align,
- v_type: var_type,
- name: spirv_arg.name,
- array_init: spirv_arg.array_init.clone(),
- }));
- result.push(Statement::StoreVar(StoreVarDetails {
- arg: ast::Arg2St {
- src1: spirv_arg.name,
- src2: new_id,
- },
- typ,
- member_index: None,
- }));
- spirv_arg.name = new_id;
- }
- None => {}
- }
+ for arg in fn_decl.return_arguments.iter() {
+ insert_mem_ssa_argument_reg_return(&mut result, arg);
}
for s in func {
match s {
@@ -2162,32 +2012,41 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::Instruction(inst) => match inst {
ast::Instruction::Ret(d) => {
// TODO: handle multiple output args
- if let &[out_param] = &fn_decl.output.as_slice() {
- let (typ, _) = id_def.get_typed(out_param.name)?;
- let new_id = id_def.new_non_variable(Some(typ.clone()));
- result.push(Statement::LoadVar(LoadVarDetails {
- arg: ast::Arg2 {
- dst: new_id,
- src: out_param.name,
- },
- typ: typ.clone(),
- member_index: None,
- }));
- result.push(Statement::RetValue(d, new_id));
- } else {
- result.push(Statement::Instruction(ast::Instruction::Ret(d)))
+ match &fn_decl.return_arguments[..] {
+ [return_reg] => {
+ let new_id = id_def.register_intermediate(Some((
+ return_reg.v_type.clone(),
+ ast::StateSpace::Reg,
+ )));
+ result.push(Statement::LoadVar(LoadVarDetails {
+ arg: ast::Arg2 {
+ dst: new_id,
+ src: return_reg.name,
+ },
+ // TODO: ret with stateful conversion
+ state_space: ast::StateSpace::Reg,
+ typ: return_reg.v_type.clone(),
+ member_index: None,
+ }));
+ result.push(Statement::RetValue(d, new_id));
+ }
+ [] => result.push(Statement::Instruction(ast::Instruction::Ret(d))),
+ _ => unimplemented!(),
}
}
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
},
Statement::Conditional(mut bra) => {
- let generated_id =
- id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
+ let generated_id = id_def.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )));
result.push(Statement::LoadVar(LoadVarDetails {
arg: Arg2 {
dst: generated_id,
src: bra.predicate,
},
+ state_space: ast::StateSpace::Reg,
typ: ast::Type::Scalar(ast::ScalarType::Pred),
member_index: None,
}));
@@ -2210,39 +2069,45 @@ fn insert_mem_ssa_statements<'a, 'b>( Ok(result)
}
-fn type_to_variable_type(
- t: &ast::Type,
- is_func: bool,
-) -> Result<Option<ast::VariableType>, TranslateError> {
- Ok(match t {
- ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))),
- ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector(
- (*typ)
- .try_into()
- .map_err(|_| TranslateError::MismatchedType)?,
- *len,
- ))),
- ast::Type::Array(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Array(
- (*typ)
- .try_into()
- .map_err(|_| TranslateError::MismatchedType)?,
- len.clone(),
- ))),
- ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => {
- if is_func {
- return Ok(None);
- }
- Some(ast::VariableType::Reg(ast::VariableRegType::Pointer(
- scalar_type
- .clone()
- .try_into()
- .map_err(|_| error_unreachable())?,
- (*space).try_into().map_err(|_| error_unreachable())?,
- )))
- }
- ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
- _ => return Err(error_unreachable()),
- })
+fn insert_mem_ssa_argument(
+ id_def: &mut NumericIdResolver,
+ func: &mut Vec<TypedStatement>,
+ arg: &mut ast::Variable<spirv::Word>,
+ is_kernel: bool,
+) {
+ if !is_kernel && arg.state_space == ast::StateSpace::Param {
+ return;
+ }
+ let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
+ func.push(Statement::Variable(ast::Variable {
+ align: arg.align,
+ v_type: arg.v_type.clone(),
+ state_space: ast::StateSpace::Reg,
+ name: arg.name,
+ array_init: Vec::new(),
+ }));
+ func.push(Statement::StoreVar(StoreVarDetails {
+ arg: ast::Arg2St {
+ src1: arg.name,
+ src2: new_id,
+ },
+ typ: arg.v_type.clone(),
+ member_index: None,
+ }));
+ arg.name = new_id;
+}
+
+fn insert_mem_ssa_argument_reg_return(
+ func: &mut Vec<TypedStatement>,
+ arg: &ast::Variable<spirv::Word>,
+) {
+ func.push(Statement::Variable(ast::Variable {
+ align: arg.align,
+ v_type: arg.v_type.clone(),
+ state_space: arg.state_space,
+ name: arg.name,
+ array_init: arg.array_init.clone(),
+ }));
}
trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
@@ -2259,6 +2124,7 @@ struct VisitArgumentDescriptor< > {
desc: ArgumentDescriptor<spirv::Word>,
typ: &'a ast::Type,
+ state_space: ast::StateSpace,
stmt_ctor: Ctor,
}
@@ -2273,7 +2139,9 @@ impl< self,
visitor: &mut impl ArgumentMapVisitor<T, U>,
) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
- Ok((self.stmt_ctor)(visitor.id(self.desc, Some(self.typ))?))
+ Ok((self.stmt_ctor)(
+ visitor.id(self.desc, Some((self.typ, self.state_space)))?,
+ ))
}
}
@@ -2287,14 +2155,14 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { fn symbol(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, Option<u8>)>,
- expected_type: Option<&ast::Type>,
+ expected: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
let symbol = desc.op.0;
- if expected_type.is_none() {
+ if expected.is_none() {
return Ok(symbol);
};
- let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?;
- if !is_variable {
+ let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
+ if !var_space.is_compatible(ast::StateSpace::Reg) || !is_variable {
return Ok(symbol);
};
let member_index = match desc.op.1 {
@@ -2317,13 +2185,16 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { }
None => None,
};
- let generated_id = self.id_def.new_non_variable(Some(var_type.clone()));
+ let generated_id = self
+ .id_def
+ .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg)));
if !desc.is_dst {
self.func.push(Statement::LoadVar(LoadVarDetails {
arg: Arg2 {
dst: generated_id,
src: symbol,
},
+ state_space: ast::StateSpace::Reg,
typ: var_type,
member_index,
}));
@@ -2348,7 +2219,7 @@ impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams> fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- typ: Option<&ast::Type>,
+ typ: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self.symbol(desc.new_op((desc.op, None)), typ)
}
@@ -2357,18 +2228,20 @@ impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams> &mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
TypedOperand::Reg(reg) => {
- TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
- }
- TypedOperand::RegOffset(reg, offset) => {
- TypedOperand::RegOffset(self.symbol(desc.new_op((reg, None)), Some(typ))?, offset)
+ TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?)
}
+ TypedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(
+ self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?,
+ offset,
+ ),
op @ TypedOperand::Imm(..) => op,
- TypedOperand::VecMember(symbol, index) => {
- TypedOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
- }
+ TypedOperand::VecMember(symbol, index) => TypedOperand::Reg(
+ self.symbol(desc.new_op((symbol, Some(index))), Some((typ, state_space)))?,
+ ),
})
}
}
@@ -2411,11 +2284,13 @@ fn expand_arguments<'a, 'b>( Statement::Variable(ast::Variable {
align,
v_type,
+ state_space,
name,
array_init,
}) => result.push(Statement::Variable(ast::Variable {
align,
v_type,
+ state_space,
name,
array_init,
})),
@@ -2464,7 +2339,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn reg(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- _: Option<&ast::Type>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
Ok(desc.op)
}
@@ -2473,108 +2348,86 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { &mut self,
desc: ArgumentDescriptor<(spirv::Word, i32)>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
let (reg, offset) = desc.op;
- let add_type;
- match typ {
- ast::Type::Pointer(underlying_type, state_space) => {
- let reg_typ = self.id_def.get_typed(reg)?;
- if let ast::Type::Pointer(_, _) = reg_typ {
- let id_constant_stmt = self.id_def.new_non_variable(typ.clone());
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: ast::ScalarType::S64,
- value: ast::ImmediateValue::S64(offset as i64),
- }));
- let dst = self.id_def.new_non_variable(typ.clone());
- self.func.push(Statement::PtrAccess(PtrAccess {
- underlying_type: underlying_type.clone(),
- state_space: *state_space,
- dst,
- ptr_src: reg,
- offset_src: id_constant_stmt,
- }));
- return Ok(dst);
- } else {
- add_type = self.id_def.get_typed(reg)?;
- }
- }
- _ => {
- add_type = typ.clone();
+ if !desc.is_memory_access {
+ let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
+ if !reg_space.is_compatible(ast::StateSpace::Reg) {
+ return Err(TranslateError::MismatchedType);
}
- };
- let (width, kind) = match add_type {
- ast::Type::Scalar(scalar_t) => {
- let kind = match scalar_t.kind() {
- kind @ ScalarKind::Bit
- | kind @ ScalarKind::Unsigned
- | kind @ ScalarKind::Signed => kind,
- ScalarKind::Float => return Err(TranslateError::MismatchedType),
- ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
- ScalarKind::Pred => return Err(TranslateError::MismatchedType),
- };
- (scalar_t.size_of(), kind)
- }
- _ => return Err(TranslateError::MismatchedType),
- };
- let arith_detail = if kind == ScalarKind::Signed {
- ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::from_size(width),
- saturate: false,
- })
- } else {
- ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
- };
- let id_constant_stmt = self.id_def.new_non_variable(add_type.clone());
- let result_id = self.id_def.new_non_variable(add_type);
- // TODO: check for edge cases around min value/max value/wrapping
- if offset < 0 && kind != ScalarKind::Signed {
+ let reg_scalar_type = match reg_type {
+ ast::Type::Scalar(underlying_type) => underlying_type,
+ _ => return Err(TranslateError::MismatchedType),
+ };
+ let id_constant_stmt = self
+ .id_def
+ .register_intermediate(reg_type.clone(), ast::StateSpace::Reg);
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
- typ: ast::ScalarType::from_parts(width, kind),
- value: ast::ImmediateValue::U64(-(offset as i64) as u64),
+ typ: reg_scalar_type,
+ value: ast::ImmediateValue::S64(offset as i64),
}));
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Sub(
- arith_detail,
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
+ let arith_details = match reg_scalar_type.kind() {
+ ast::ScalarKind::Signed => ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: reg_scalar_type,
+ saturate: false,
+ }),
+ ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
+ ast::ArithDetails::Unsigned(reg_scalar_type)
+ }
+ _ => return Err(error_unreachable()),
+ };
+ let id_add_result = self.id_def.register_intermediate(reg_type, state_space);
+ self.func.push(Statement::Instruction(ast::Instruction::Add(
+ arith_details,
+ ast::Arg3 {
+ dst: id_add_result,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ )));
+ Ok(id_add_result)
} else {
+ let scalar_type = match typ {
+ ast::Type::Scalar(underlying_type) => *underlying_type,
+ _ => return Err(error_unreachable()),
+ };
+ let id_constant_stmt = self.id_def.register_intermediate(
+ ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
+ );
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
- typ: ast::ScalarType::from_parts(width, kind),
+ typ: ast::ScalarType::S64,
value: ast::ImmediateValue::S64(offset as i64),
}));
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- arith_detail,
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
+ let dst = self.id_def.register_intermediate(typ.clone(), state_space);
+ self.func.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: scalar_type,
+ state_space: state_space,
+ dst,
+ ptr_src: reg,
+ offset_src: id_constant_stmt,
+ }));
+ Ok(dst)
}
- Ok(result_id)
}
fn immediate(
&mut self,
desc: ArgumentDescriptor<ast::ImmediateValue>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
*scalar
} else {
todo!()
};
- let id = self.id_def.new_non_variable(ast::Type::Scalar(scalar_t));
+ let id = self
+ .id_def
+ .register_intermediate(ast::Type::Scalar(scalar_t), state_space);
self.func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
@@ -2588,7 +2441,7 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self.reg(desc, t)
}
@@ -2597,12 +2450,13 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr &mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
match desc.op {
- TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
- TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ),
+ TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some((typ, state_space))),
+ TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ, state_space),
TypedOperand::RegOffset(reg, offset) => {
- self.reg_offset(desc.new_op((reg, offset)), typ)
+ self.reg_offset(desc.new_op((reg, offset)), typ, state_space)
}
TypedOperand::VecMember(..) => Err(error_unreachable()),
}
@@ -2630,79 +2484,18 @@ fn insert_implicit_conversions( let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
match s {
- Statement::Call(call) => insert_implicit_conversions_impl(
- &mut result,
- id_def,
- call,
- should_bitcast_wrapper,
- None,
- )?,
+ Statement::Call(call) => {
+ insert_implicit_conversions_impl(&mut result, id_def, call)?;
+ }
Statement::Instruction(inst) => {
- let mut default_conversion_fn =
- should_bitcast_wrapper as for<'a> fn(&'a ast::Type, &'a ast::Type, _) -> _;
- let mut state_space = None;
- if let ast::Instruction::Ld(d, _) = &inst {
- state_space = Some(d.state_space);
- }
- if let ast::Instruction::St(d, _) = &inst {
- state_space = Some(d.state_space.to_ld_ss());
- }
- if let ast::Instruction::Atom(d, _) = &inst {
- state_space = Some(d.space.to_ld_ss());
- }
- if let ast::Instruction::AtomCas(d, _) = &inst {
- state_space = Some(d.space.to_ld_ss());
- }
- if let ast::Instruction::Mov(..) = &inst {
- default_conversion_fn = should_bitcast_packed;
- }
- insert_implicit_conversions_impl(
- &mut result,
- id_def,
- inst,
- default_conversion_fn,
- state_space,
- )?;
+ insert_implicit_conversions_impl(&mut result, id_def, inst)?;
}
- Statement::PtrAccess(PtrAccess {
- underlying_type,
- state_space,
- dst,
- ptr_src,
- offset_src: constant_src,
- }) => {
- let visit_desc = VisitArgumentDescriptor {
- desc: ArgumentDescriptor {
- op: ptr_src,
- is_dst: false,
- sema: ArgumentSemantics::PhysicalPointer,
- },
- typ: &ast::Type::Pointer(underlying_type.clone(), state_space),
- stmt_ctor: |new_ptr_src| {
- Statement::PtrAccess(PtrAccess {
- underlying_type,
- state_space,
- dst,
- ptr_src: new_ptr_src,
- offset_src: constant_src,
- })
- },
- };
- insert_implicit_conversions_impl(
- &mut result,
- id_def,
- visit_desc,
- bitcast_physical_pointer,
- Some(state_space),
- )?;
+ Statement::PtrAccess(access) => {
+ insert_implicit_conversions_impl(&mut result, id_def, access)?;
+ }
+ Statement::RepackVector(repack) => {
+ insert_implicit_conversions_impl(&mut result, id_def, repack)?;
}
- Statement::RepackVector(repack) => insert_implicit_conversions_impl(
- &mut result,
- id_def,
- repack,
- should_bitcast_wrapper,
- None,
- )?,
s @ Statement::Conditional(_)
| s @ Statement::Conversion(_)
| s @ Statement::Label(_)
@@ -2720,72 +2513,56 @@ fn insert_implicit_conversions_impl( func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
stmt: impl Visitable<ExpandedArgParams, ExpandedArgParams>,
- default_conversion_fn: for<'a> fn(
- &'a ast::Type,
- &'a ast::Type,
- Option<ast::LdStateSpace>,
- ) -> Result<Option<ConversionKind>, TranslateError>,
- state_space: Option<ast::LdStateSpace>,
) -> Result<(), TranslateError> {
let mut post_conv = Vec::new();
- let statement = stmt.visit(
- &mut |desc: ArgumentDescriptor<spirv::Word>, typ: Option<&ast::Type>| {
- let instr_type = match typ {
+ let statement =
+ stmt.visit(&mut |desc: ArgumentDescriptor<spirv::Word>,
+ typ: Option<(&ast::Type, ast::StateSpace)>| {
+ let (instr_type, instruction_space) = match typ {
None => return Ok(desc.op),
Some(t) => t,
};
- let operand_type = id_def.get_typed(desc.op)?;
- let mut conversion_fn = default_conversion_fn;
- match desc.sema {
- ArgumentSemantics::Default => {}
- ArgumentSemantics::DefaultRelaxed => {
- if desc.is_dst {
- conversion_fn = should_convert_relaxed_dst_wrapper;
- } else {
- conversion_fn = should_convert_relaxed_src_wrapper;
- }
- }
- ArgumentSemantics::PhysicalPointer => {
- conversion_fn = bitcast_physical_pointer;
- }
- ArgumentSemantics::RegisterPointer => {
- conversion_fn = bitcast_register_pointer;
- }
- ArgumentSemantics::Address => {
- conversion_fn = force_bitcast_ptr_to_bit;
- }
- };
- match conversion_fn(&operand_type, instr_type, state_space)? {
+ let (operand_type, operand_space) = id_def.get_typed(desc.op)?;
+ let conversion_fn = desc
+ .non_default_implicit_conversion
+ .unwrap_or(default_implicit_conversion);
+ match conversion_fn(
+ (operand_space, &operand_type),
+ (instruction_space, instr_type),
+ )? {
Some(conv_kind) => {
let conv_output = if desc.is_dst {
&mut post_conv
} else {
&mut *func
};
- let mut from = instr_type.clone();
- let mut to = operand_type;
- let mut src = id_def.new_non_variable(instr_type.clone());
+ let mut from_type = instr_type.clone();
+ let mut from_space = instruction_space;
+ let mut to_type = operand_type;
+ let mut to_space = operand_space;
+ let mut src =
+ id_def.register_intermediate(instr_type.clone(), instruction_space);
let mut dst = desc.op;
let result = Ok(src);
if !desc.is_dst {
mem::swap(&mut src, &mut dst);
- mem::swap(&mut from, &mut to);
+ mem::swap(&mut from_type, &mut to_type);
+ mem::swap(&mut from_space, &mut to_space);
}
conv_output.push(Statement::Conversion(ImplicitConversion {
src,
dst,
- from,
- to,
+ from_type,
+ from_space,
+ to_type,
+ to_space,
kind: conv_kind,
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
}));
result
}
None => Ok(desc.op),
}
- },
- )?;
+ })?;
func.push(statement);
func.append(&mut post_conv);
Ok(())
@@ -2794,17 +2571,15 @@ fn insert_implicit_conversions_impl( fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- spirv_input: &[ast::Variable<ast::Type, spirv::Word>],
- spirv_output: &[ast::Variable<ast::Type, spirv::Word>],
+ spirv_input: impl Iterator<Item = SpirvType>,
+ spirv_output: &[ast::Variable<spirv::Word>],
) -> (spirv::Word, spirv::Word) {
map.get_or_add_fn(
builder,
- spirv_input
- .iter()
- .map(|var| SpirvType::from(var.v_type.clone())),
+ spirv_input,
spirv_output
.iter()
- .map(|var| SpirvType::from(var.v_type.clone())),
+ .map(|var| SpirvType::new(var.v_type.clone())),
)
}
@@ -2831,20 +2606,25 @@ fn emit_function_body_ops( match s {
Statement::Label(_) => (),
Statement::Call(call) => {
- let (result_type, result_id) = match &*call.ret_params {
- [(id, typ)] => (
- map.get_or_add(builder, SpirvType::from(typ.to_func_type())),
- Some(*id),
- ),
+ let (result_type, result_id) = match &*call.return_arguments {
+ [(id, typ, space)] => {
+ if *space != ast::StateSpace::Reg {
+ return Err(error_unreachable());
+ }
+ (
+ map.get_or_add(builder, SpirvType::new(typ.clone())),
+ Some(*id),
+ )
+ }
[] => (map.void(), None),
_ => todo!(),
};
let arg_list = call
- .param_list
+ .input_arguments
.iter()
- .map(|(id, _)| *id)
+ .map(|(id, _, _)| *id)
.collect::<Vec<_>>();
- builder.function_call(result_type, result_id, call.func, arg_list)?;
+ builder.function_call(result_type, result_id, call.name, arg_list)?;
}
Statement::Variable(var) => {
emit_variable(builder, map, var)?;
@@ -2966,7 +2746,7 @@ fn emit_function_body_ops( todo!()
}
let result_type =
- map.get_or_add(builder, SpirvType::from(ast::Type::from(data.typ.clone())));
+ map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone())));
builder.load(
result_type,
Some(arg.dst),
@@ -2998,7 +2778,7 @@ fn emit_function_body_ops( ast::Instruction::Ret(_) => builder.ret()?,
ast::Instruction::Mov(d, arg) => {
let result_type =
- map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone())));
+ map.get_or_add(builder, SpirvType::new(ast::Type::from(d.typ.clone())));
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
ast::Instruction::Mul(mul, arg) => match mul {
@@ -3026,20 +2806,20 @@ fn emit_function_body_ops( emit_setp(builder, map, setp, arg)?;
}
ast::Instruction::Not(t, a) => {
- let result_type = map.get_or_add(builder, SpirvType::from(t.to_type()));
+ let result_type = map.get_or_add(builder, SpirvType::from(*t));
let result_id = Some(a.dst);
let operand = a.src;
match t {
- ast::BooleanType::Pred => {
+ ast::ScalarType::Pred => {
logical_not(builder, result_type, result_id, operand)
}
_ => builder.not(result_type, result_id, operand),
}?;
}
ast::Instruction::Shl(t, a) => {
- let full_type = t.to_type();
+ let full_type = ast::Type::Scalar(*t);
let size_of = full_type.size_of();
- let result_type = map.get_or_add(builder, SpirvType::from(full_type));
+ let result_type = map.get_or_add(builder, SpirvType::new(full_type));
let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?;
builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?;
}
@@ -3048,7 +2828,7 @@ fn emit_function_body_ops( let size_of = full_type.size_of();
let result_type = map.get_or_add_scalar(builder, full_type);
let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?;
- if t.signed() {
+ if t.kind() == ast::ScalarKind::Signed {
builder.shift_right_arithmetic(
result_type,
Some(a.dst),
@@ -3088,7 +2868,7 @@ fn emit_function_body_ops( },
ast::Instruction::Or(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
- if *t == ast::BooleanType::Pred {
+ if *t == ast::ScalarType::Pred {
builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?;
} else {
builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
@@ -3116,7 +2896,7 @@ fn emit_function_body_ops( }
ast::Instruction::And(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
- if *t == ast::BooleanType::Pred {
+ if *t == ast::ScalarType::Pred {
builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?;
} else {
builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?;
@@ -3202,7 +2982,7 @@ fn emit_function_body_ops( }
ast::Instruction::Neg(details, arg) => {
let result_type = map.get_or_add_scalar(builder, details.typ);
- let negate_func = if details.typ.kind() == ScalarKind::Float {
+ let negate_func = if details.typ.kind() == ast::ScalarKind::Float {
dr::Builder::f_negate
} else {
dr::Builder::s_negate
@@ -3269,7 +3049,7 @@ fn emit_function_body_ops( }
ast::Instruction::Xor { typ, arg } => {
let builder_fn = match typ {
- ast::BooleanType::Pred => emit_logical_xor_spirv,
+ ast::ScalarType::Pred => emit_logical_xor_spirv,
_ => dr::Builder::bitwise_xor,
};
let result_type = map.get_or_add_scalar(builder, (*typ).into());
@@ -3284,7 +3064,7 @@ fn emit_function_body_ops( return Err(error_unreachable());
}
ast::Instruction::Rem { typ, arg } => {
- let builder_fn = if typ.is_signed() {
+ let builder_fn = if typ.kind() == ast::ScalarKind::Signed {
dr::Builder::s_mod
} else {
dr::Builder::u_mod
@@ -3301,7 +3081,7 @@ fn emit_function_body_ops( Some(index) => {
let result_ptr_type = map.get_or_add(
builder,
- SpirvType::new_pointer(
+ SpirvType::pointer_to(
details.typ.clone(),
spirv::StorageClass::Function,
),
@@ -3334,14 +3114,11 @@ fn emit_function_body_ops( }) => {
let u8_pointer = map.get_or_add(
builder,
- SpirvType::from(ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- *state_space,
- )),
+ SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)),
);
let result_type = map.get_or_add(
builder,
- SpirvType::from(ast::Type::Pointer(underlying_type.clone(), *state_space)),
+ SpirvType::new(ast::Type::Pointer(*underlying_type, *state_space)),
);
let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?;
let temp = builder.in_bounds_ptr_access_chain(
@@ -3553,11 +3330,16 @@ fn ptx_scope_name(scope: ast::MemScope) -> &'static str { }
}
-fn ptx_space_name(space: ast::AtomSpace) -> &'static str {
+fn ptx_space_name(space: ast::StateSpace) -> &'static str {
match space {
- ast::AtomSpace::Generic => "generic",
- ast::AtomSpace::Global => "global",
- ast::AtomSpace::Shared => "shared",
+ ast::StateSpace::Generic => "generic",
+ ast::StateSpace::Global => "global",
+ ast::StateSpace::Shared => "shared",
+ ast::StateSpace::Reg => "reg",
+ ast::StateSpace::Const => "const",
+ ast::StateSpace::Local => "local",
+ ast::StateSpace::Param => "param",
+ ast::StateSpace::Sreg => "sreg",
}
}
@@ -3612,14 +3394,17 @@ fn vec_repr<T: Copy>(t: T) -> Vec<u8> { fn emit_variable(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- var: &ast::Variable<ast::VariableType, spirv::Word>,
+ var: &ast::Variable<spirv::Word>,
) -> Result<(), TranslateError> {
- let (must_init, st_class) = match var.v_type {
- ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => {
+ let (must_init, st_class) = match var.state_space {
+ ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => {
(false, spirv::StorageClass::Function)
}
- ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup),
- ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup),
+ ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup),
+ ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
+ ast::StateSpace::Const => todo!(),
+ ast::StateSpace::Generic => todo!(),
+ ast::StateSpace::Sreg => todo!(),
};
let initalizer = if var.array_init.len() > 0 {
Some(map.get_or_add_constant(
@@ -3628,18 +3413,12 @@ fn emit_variable( &*var.array_init,
)?)
} else if must_init {
- let type_id = map.get_or_add(
- builder,
- SpirvType::from(ast::Type::from(var.v_type.clone())),
- );
+ let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone()));
Some(builder.constant_null(type_id, None))
} else {
None
};
- let ptr_type_id = map.get_or_add(
- builder,
- SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class),
- );
+ let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class));
builder.variable(ptr_type_id, Some(var.name), st_class, initalizer);
if let Some(align) = var.align {
builder.decorate(
@@ -3777,7 +3556,7 @@ fn emit_min( ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
};
- let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
+ let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type()));
builder.ext_inst(
inst_type,
Some(arg.dst),
@@ -3802,7 +3581,7 @@ fn emit_max( ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
};
- let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
+ let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type()));
builder.ext_inst(
inst_type,
Some(arg.dst),
@@ -3882,7 +3661,7 @@ fn emit_cvt( }
let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- if desc.src.is_signed() {
+ if desc.src.kind() == ast::ScalarKind::Signed {
builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?;
} else {
builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?;
@@ -3892,7 +3671,7 @@ fn emit_cvt( ast::CvtDetails::IntFromFloat(desc) => {
let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- if desc.dst.is_signed() {
+ if desc.dst.kind() == ast::ScalarKind::Signed {
builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?;
} else {
builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?;
@@ -3904,7 +3683,7 @@ fn emit_cvt( let dest_t: ast::ScalarType = desc.dst.into();
let src_t: ast::ScalarType = desc.src.into();
// first do shortening/widening
- let src = if desc.dst.width() != desc.src.width() {
+ let src = if desc.dst.size_of() != desc.src.size_of() {
let new_dst = if dest_t.kind() == src_t.kind() {
arg.dst
} else {
@@ -3913,14 +3692,14 @@ fn emit_cvt( let cv = ImplicitConversion {
src: arg.src,
dst: new_dst,
- from: ast::Type::Scalar(src_t),
- to: ast::Type::Scalar(ast::ScalarType::from_parts(
+ from_type: ast::Type::Scalar(src_t),
+ from_space: ast::StateSpace::Reg,
+ to_type: ast::Type::Scalar(ast::ScalarType::from_parts(
dest_t.size_of(),
src_t.kind(),
)),
+ to_space: ast::StateSpace::Reg,
kind: ConversionKind::Default,
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
};
emit_implicit_conversion(builder, map, &cv)?;
new_dst
@@ -3933,7 +3712,7 @@ fn emit_cvt( // now do actual conversion
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.saturate {
- if desc.dst.is_signed() {
+ if desc.dst.kind() == ast::ScalarKind::Signed {
builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?;
} else {
builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?;
@@ -3989,60 +3768,60 @@ fn emit_setp( let operand_1 = arg.src1;
let operand_2 = arg.src2;
match (setp.cmp_op, setp.typ.kind()) {
- (ast::SetpCompareOp::Eq, ScalarKind::Signed)
- | (ast::SetpCompareOp::Eq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Eq, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::Eq, ast::ScalarKind::Signed)
+ | (ast::SetpCompareOp::Eq, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::Eq, ast::ScalarKind::Bit) => {
builder.i_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Eq, ScalarKind::Float) => {
+ (ast::SetpCompareOp::Eq, ast::ScalarKind::Float) => {
builder.f_ord_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::NotEq, ScalarKind::Signed)
- | (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::NotEq, ast::ScalarKind::Signed)
+ | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Bit) => {
builder.i_not_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::NotEq, ScalarKind::Float) => {
+ (ast::SetpCompareOp::NotEq, ast::ScalarKind::Float) => {
builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Less, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Less, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::Less, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::Less, ast::ScalarKind::Bit) => {
builder.u_less_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Less, ScalarKind::Signed) => {
+ (ast::SetpCompareOp::Less, ast::ScalarKind::Signed) => {
builder.s_less_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Less, ScalarKind::Float) => {
+ (ast::SetpCompareOp::Less, ast::ScalarKind::Float) => {
builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Bit) => {
builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => {
+ (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Signed) => {
builder.s_less_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => {
+ (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Float) => {
builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Greater, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Greater, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::Greater, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::Greater, ast::ScalarKind::Bit) => {
builder.u_greater_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Greater, ScalarKind::Signed) => {
+ (ast::SetpCompareOp::Greater, ast::ScalarKind::Signed) => {
builder.s_greater_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Greater, ScalarKind::Float) => {
+ (ast::SetpCompareOp::Greater, ast::ScalarKind::Float) => {
builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Bit) => {
builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => {
+ (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Signed) => {
builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
+ (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Float) => {
builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NanEq, _) => {
@@ -4222,7 +4001,7 @@ fn emit_abs( ) -> Result<(), dr::Error> {
let scalar_t = ast::ScalarType::from(d.typ);
let result_type = map.get_or_add(builder, SpirvType::from(scalar_t));
- let cl_abs = if scalar_t.kind() == ScalarKind::Signed {
+ let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed {
spirv::CLOp::s_abs
} else {
spirv::CLOp::fabs
@@ -4272,22 +4051,21 @@ fn emit_implicit_conversion( map: &mut TypeWordMap,
cv: &ImplicitConversion,
) -> Result<(), TranslateError> {
- let from_parts = cv.from.to_parts();
- let to_parts = cv.to.to_parts();
- match (from_parts.kind, to_parts.kind, cv.kind) {
- (_, _, ConversionKind::PtrToBit(typ)) => {
- let dst_type = map.get_or_add_scalar(builder, typ.into());
- builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
- }
- (_, _, ConversionKind::BitToPtr(_)) => {
- let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
+ let from_parts = cv.from_type.to_parts();
+ let to_parts = cv.to_type.to_parts();
+ match (from_parts.kind, to_parts.kind, &cv.kind) {
+ (_, _, &ConversionKind::BitToPtr) => {
+ let dst_type = map.get_or_add(
+ builder,
+ SpirvType::pointer_to(cv.to_type.clone(), cv.to_space.to_spirv()),
+ );
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
- (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
+ (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => {
if from_parts.width == to_parts.width {
- let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
- if from_parts.scalar_kind != ScalarKind::Float
- && to_parts.scalar_kind != ScalarKind::Float
+ let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ if from_parts.scalar_kind != ast::ScalarKind::Float
+ && to_parts.scalar_kind != ast::ScalarKind::Float
{
// It is noop, but another instruction expects result of this conversion
builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
@@ -4295,28 +4073,28 @@ fn emit_implicit_conversion( builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
}
} else {
- // This block is safe because it's illegal to implictly convert between floating point instructions
+ // This block is safe because it's illegal to implictly convert between floating point values
let same_width_bit_type = map.get_or_add(
builder,
- SpirvType::from(ast::Type::from_parts(TypeParts {
- scalar_kind: ScalarKind::Bit,
+ SpirvType::new(ast::Type::from_parts(TypeParts {
+ scalar_kind: ast::ScalarKind::Bit,
..from_parts
})),
);
let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?;
let wide_bit_type = ast::Type::from_parts(TypeParts {
- scalar_kind: ScalarKind::Bit,
+ scalar_kind: ast::ScalarKind::Bit,
..to_parts
});
let wide_bit_type_spirv =
- map.get_or_add(builder, SpirvType::from(wide_bit_type.clone()));
- if to_parts.scalar_kind == ScalarKind::Unsigned
- || to_parts.scalar_kind == ScalarKind::Bit
+ map.get_or_add(builder, SpirvType::new(wide_bit_type.clone()));
+ if to_parts.scalar_kind == ast::ScalarKind::Unsigned
+ || to_parts.scalar_kind == ast::ScalarKind::Bit
{
builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?;
} else {
- let conversion_fn = if from_parts.scalar_kind == ScalarKind::Signed
- && to_parts.scalar_kind == ScalarKind::Signed
+ let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed
+ && to_parts.scalar_kind == ast::ScalarKind::Signed
{
dr::Builder::s_convert
} else {
@@ -4330,40 +4108,48 @@ fn emit_implicit_conversion( &ImplicitConversion {
src: wide_bit_value,
dst: cv.dst,
- from: wide_bit_type,
- to: cv.to.clone(),
+ from_type: wide_bit_type,
+ from_space: cv.from_space,
+ to_type: cv.to_type.clone(),
+ to_space: cv.to_space,
kind: ConversionKind::Default,
- src_sema: cv.src_sema,
- dst_sema: cv.dst_sema,
},
)?;
}
}
}
- (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => {
- let result_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
+ (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => {
+ let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
builder.s_convert(result_type, Some(cv.dst), cv.src)?;
}
- (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
- | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default)
- | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
- let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
+ (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default)
+ | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default)
+ | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => {
+ let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
}
- (_, _, ConversionKind::PtrToPtr { spirv_ptr }) => {
- let result_type = if spirv_ptr {
- map.get_or_add(
- builder,
- SpirvType::Pointer(
- Box::new(SpirvType::from(cv.to.clone())),
- spirv::StorageClass::Function,
- ),
- )
- } else {
- map.get_or_add(builder, SpirvType::from(cv.to.clone()))
- };
+ (_, _, &ConversionKind::PtrToPtr) => {
+ let result_type = map.get_or_add(
+ builder,
+ SpirvType::Pointer(
+ Box::new(SpirvType::new(cv.to_type.clone())),
+ cv.to_space.to_spirv(),
+ ),
+ );
builder.bitcast(result_type, Some(cv.dst), cv.src)?;
}
+ (_, _, &ConversionKind::AddressOf) => {
+ let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
+ }
+ (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) => {
+ let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ builder.convert_ptr_to_u(result_type, Some(cv.dst), cv.src)?;
+ }
+ (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => {
+ let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ builder.convert_u_to_ptr(result_type, Some(cv.dst), cv.src)?;
+ }
_ => unreachable!(),
}
Ok(())
@@ -4374,14 +4160,14 @@ fn emit_load_var( map: &mut TypeWordMap,
details: &LoadVarDetails,
) -> Result<(), TranslateError> {
- let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone()));
+ let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone()));
match details.member_index {
Some((index, Some(width))) => {
let vector_type = match details.typ {
ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
_ => return Err(TranslateError::MismatchedType),
};
- let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type));
+ let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
let vector_temp = builder.load(
vector_type_spirv,
None,
@@ -4399,7 +4185,7 @@ fn emit_load_var( Some((index, None)) => {
let result_ptr_type = map.get_or_add(
builder,
- SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function),
+ SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function),
);
let index_spirv = map.get_or_add_constant(
builder,
@@ -4427,10 +4213,10 @@ fn emit_load_var( Ok(())
}
-fn normalize_identifiers<'a, 'b>(
- id_defs: &mut FnStringIdResolver<'a, 'b>,
- fn_defs: &GlobalFnDeclResolver<'a, 'b>,
- func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
+fn normalize_identifiers<'input, 'b>(
+ id_defs: &mut FnStringIdResolver<'input, 'b>,
+ fn_defs: &GlobalFnDeclResolver<'input, 'b>,
+ func: Vec<ast::Statement<ast::ParsedArgParams<'input>>>,
) -> Result<Vec<NormalizedStatement>, TranslateError> {
for s in func.iter() {
match s {
@@ -4468,48 +4254,28 @@ fn expand_map_variables<'a, 'b>( i.map_variable(&mut |id| id_defs.get_id(id))?,
))),
ast::Statement::Variable(var) => {
- let mut var_type = ast::Type::from(var.var.v_type.clone());
- let mut is_variable = false;
- var_type = match var.var.v_type {
- ast::VariableType::Reg(_) => {
- is_variable = true;
- var_type
- }
- ast::VariableType::Shared(_) => {
- // If it's a pointer it will be translated to a method parameter later
- if let ast::Type::Pointer(..) = var_type {
- is_variable = true;
- var_type
- } else {
- var_type.param_pointer_to(ast::LdStateSpace::Shared)?
- }
- }
- ast::VariableType::Global(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Global)?
- }
- ast::VariableType::Param(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Param)?
- }
- ast::VariableType::Local(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Local)?
- }
- };
+ let var_type = var.var.v_type.clone();
match var.count {
Some(count) => {
- for new_id in id_defs.add_defs(var.var.name, count, var_type, is_variable) {
+ for new_id in
+ id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true)
+ {
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
+ state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init.clone(),
}))
}
}
None => {
- let new_id = id_defs.add_def(var.var.name, Some(var_type), is_variable);
+ let new_id =
+ id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true);
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
+ state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init,
}));
@@ -4520,18 +4286,62 @@ fn expand_map_variables<'a, 'b>( Ok(())
}
+/*
+ Our goal here is to transform
+ .visible .entry foobar(.param .u64 input) {
+ .reg .b64 in_addr;
+ .reg .b64 in_addr2;
+ ld.param.u64 in_addr, [input];
+ cvta.to.global.u64 in_addr2, in_addr;
+ }
+ into:
+ .visible .entry foobar(.param .u8 input[]) {
+ .reg .u8 in_addr[];
+ .reg .u8 in_addr2[];
+ ld.param.u8[] in_addr, [input];
+ mov.u8[] in_addr2, in_addr;
+ }
+ or:
+ .visible .entry foobar(.reg .u8 input[]) {
+ .reg .u8 in_addr[];
+ .reg .u8 in_addr2[];
+ mov.u8[] in_addr, input;
+ mov.u8[] in_addr2, in_addr;
+ }
+ or:
+ .visible .entry foobar(.param ptr<u8, global> input) {
+ .reg ptr<u8, global> in_addr;
+ .reg ptr<u8, global> in_addr2;
+ ld.param.ptr<u8, global> in_addr, [input];
+ mov.ptr<u8, global> in_addr2, in_addr;
+ }
+*/
// TODO: detect more patterns (mov, call via reg, call via param)
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
// argument expansion
-// TODO: propagate through calls?
-fn convert_to_stateful_memory_access<'a>(
- func_args: &mut SpirvMethodDecl,
+// TODO: propagate out of calls and into calls
+fn convert_to_stateful_memory_access<'a, 'input>(
+ func_args: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
func_body: Vec<TypedStatement>,
id_defs: &mut NumericIdResolver<'a>,
-) -> Result<Vec<TypedStatement>, TranslateError> {
- let func_args_64bit = func_args
- .input
+) -> Result<
+ (
+ Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
+ Vec<TypedStatement>,
+ ),
+ TranslateError,
+> {
+ let mut method_decl = func_args.borrow_mut();
+ if !method_decl.name.is_kernel() {
+ drop(method_decl);
+ return Ok((func_args, func_body));
+ }
+ if Rc::strong_count(&func_args) != 1 {
+ return Err(error_unreachable());
+ }
+ let func_args_64bit = (*method_decl)
+ .input_arguments
.iter()
.filter_map(|arg| match arg.v_type {
ast::Type::Scalar(ast::ScalarType::U64)
@@ -4546,9 +4356,9 @@ fn convert_to_stateful_memory_access<'a>( match statement {
Statement::Instruction(ast::Instruction::Cvta(
ast::CvtaDetails {
- to: ast::CvtaStateSpace::Global,
+ to: ast::StateSpace::Global,
size: ast::CvtaSize::U64,
- from: ast::CvtaStateSpace::Generic,
+ from: ast::StateSpace::Generic,
},
arg,
)) => {
@@ -4562,24 +4372,24 @@ fn convert_to_stateful_memory_access<'a>( }
Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::LdStType::Scalar(ast::LdStScalarType::U64),
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::U64),
..
},
arg,
))
| Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::LdStType::Scalar(ast::LdStScalarType::S64),
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::S64),
..
},
arg,
))
| Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::LdStType::Scalar(ast::LdStScalarType::B64),
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::B64),
..
},
arg,
@@ -4595,6 +4405,10 @@ fn convert_to_stateful_memory_access<'a>( _ => {}
}
}
+ if stateful_markers.len() == 0 {
+ drop(method_decl);
+ return Ok((func_args, func_body));
+ }
let mut func_args_ptr = HashSet::new();
let mut regs_ptr_current = HashSet::new();
for (dst, src) in stateful_markers {
@@ -4614,23 +4428,23 @@ fn convert_to_stateful_memory_access<'a>( for statement in func_body.iter() {
match statement {
Statement::Instruction(ast::Instruction::Add(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Add(
ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
+ typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
))
| Statement::Instruction(ast::Instruction::Sub(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Sub(
ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
+ typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
@@ -4661,21 +4475,32 @@ fn convert_to_stateful_memory_access<'a>( let mut remapped_ids = HashMap::new();
let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
for reg in regs_ptr_seen {
- let new_id = id_defs.new_variable(ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ));
+ let new_id = id_defs.register_variable(
+ ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ ast::StateSpace::Reg,
+ );
result.push(Statement::Variable(ast::Variable {
align: None,
name: new_id,
array_init: Vec::new(),
- v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
- ast::SizedScalarType::U8,
- ast::PointerStateSpace::Global,
- )),
+ v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ state_space: ast::StateSpace::Reg,
}));
remapped_ids.insert(reg, new_id);
}
+ for arg in (*method_decl).input_arguments.iter_mut() {
+ if !func_args_ptr.contains(&arg.name) {
+ continue;
+ }
+ let new_id = id_defs.register_variable(
+ ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ ast::StateSpace::Param,
+ );
+ let old_name = arg.name;
+ arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
+ arg.name = new_id;
+ remapped_ids.insert(old_name, new_id);
+ }
for statement in func_body {
match statement {
l @ Statement::Label(_) => result.push(l),
@@ -4686,12 +4511,12 @@ fn convert_to_stateful_memory_access<'a>( }
}
Statement::Instruction(ast::Instruction::Add(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Add(
ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
+ typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
@@ -4707,20 +4532,20 @@ fn convert_to_stateful_memory_access<'a>( };
let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
- state_space: ast::LdStateSpace::Global,
+ underlying_type: ast::ScalarType::U8,
+ state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
offset_src: offset,
}))
}
Statement::Instruction(ast::Instruction::Sub(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Sub(
ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
+ typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
@@ -4734,8 +4559,10 @@ fn convert_to_stateful_memory_access<'a>( }
_ => return Err(error_unreachable()),
};
- let offset_neg =
- id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64)));
+ let offset_neg = id_defs.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
+ )));
result.push(Statement::Instruction(ast::Instruction::Neg(
ast::NegDetails {
typ: ast::ScalarType::S64,
@@ -4748,8 +4575,8 @@ fn convert_to_stateful_memory_access<'a>( )));
let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
- state_space: ast::LdStateSpace::Global,
+ underlying_type: ast::ScalarType::U8,
+ state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
offset_src: TypedOperand::Reg(offset_neg),
@@ -4757,151 +4584,116 @@ fn convert_to_stateful_memory_access<'a>( }
Statement::Instruction(inst) => {
let mut post_statements = Vec::new();
- let new_statement = inst.visit(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>| {
+ let new_statement =
+ inst.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
- &func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
expected_type,
)
- },
- )?;
+ })?;
result.push(new_statement);
result.extend(post_statements);
}
Statement::Call(call) => {
let mut post_statements = Vec::new();
- let new_statement = call.visit(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>| {
+ let new_statement =
+ call.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
- &func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
expected_type,
)
- },
- )?;
+ })?;
result.push(new_statement);
result.extend(post_statements);
}
Statement::RepackVector(pack) => {
let mut post_statements = Vec::new();
- let new_statement = pack.visit(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>| {
+ let new_statement =
+ pack.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
- &func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
expected_type,
)
- },
- )?;
+ })?;
result.push(new_statement);
result.extend(post_statements);
}
_ => return Err(error_unreachable()),
}
}
- for arg in func_args.input.iter_mut() {
- if func_args_ptr.contains(&arg.name) {
- arg.v_type = ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- );
- }
- }
- Ok(result)
+ drop(method_decl);
+ Ok((func_args, result))
}
fn convert_to_stateful_memory_access_postprocess(
id_defs: &mut NumericIdResolver,
remapped_ids: &HashMap<spirv::Word, spirv::Word>,
- func_args_ptr: &HashSet<spirv::Word>,
result: &mut Vec<TypedStatement>,
post_statements: &mut Vec<TypedStatement>,
arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>,
+ expected_type: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
Ok(match remapped_ids.get(&arg_desc.op) {
Some(new_id) => {
- // We skip conversion here to trigger PtrAcces in a later pass
- let old_type = match expected_type {
- Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id),
- _ => id_defs.get_typed(arg_desc.op)?.0,
+ let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?;
+ if let Some((expected_type, expected_space)) = expected_type {
+ let implicit_conversion = arg_desc
+ .non_default_implicit_conversion
+ .unwrap_or(default_implicit_conversion);
+ if implicit_conversion(
+ (new_operand_space, &new_operand_type),
+ (expected_space, expected_type),
+ )
+ .is_ok()
+ {
+ return Ok(*new_id);
+ }
+ }
+ let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?;
+ let converting_id =
+ id_defs.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
+ let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) {
+ ConversionKind::Default
+ } else {
+ ConversionKind::PtrToPtr
};
- let old_type_clone = old_type.clone();
- let converting_id = id_defs.new_non_variable(Some(old_type_clone));
if arg_desc.is_dst {
post_statements.push(Statement::Conversion(ImplicitConversion {
src: converting_id,
dst: *new_id,
- from: old_type,
- to: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global),
- src_sema: ArgumentSemantics::Default,
- dst_sema: arg_desc.sema,
+ from_type: old_operand_type,
+ from_space: old_operand_space,
+ to_type: new_operand_type,
+ to_space: new_operand_space,
+ kind,
}));
converting_id
} else {
result.push(Statement::Conversion(ImplicitConversion {
src: *new_id,
dst: converting_id,
- from: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- to: old_type,
- kind: ConversionKind::PtrToBit(ast::UIntType::U64),
- src_sema: arg_desc.sema,
- dst_sema: ArgumentSemantics::Default,
+ from_type: new_operand_type,
+ from_space: new_operand_space,
+ to_type: old_operand_type,
+ to_space: old_operand_space,
+ kind,
}));
converting_id
}
}
- None => match func_args_ptr.get(&arg_desc.op) {
- Some(new_id) => {
- if arg_desc.is_dst {
- return Err(error_unreachable());
- }
- // We skip conversion here to trigger PtrAcces in a later pass
- let old_type = match expected_type {
- Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id),
- _ => id_defs.get_typed(arg_desc.op)?.0,
- };
- let old_type_clone = old_type.clone();
- let converting_id = id_defs.new_non_variable(Some(old_type));
- result.push(Statement::Conversion(ImplicitConversion {
- src: *new_id,
- dst: converting_id,
- from: ast::Type::Pointer(
- ast::PointerType::Pointer(ast::ScalarType::U8, ast::LdStateSpace::Global),
- ast::LdStateSpace::Param,
- ),
- to: old_type_clone,
- kind: ConversionKind::PtrToPtr { spirv_ptr: false },
- src_sema: arg_desc.sema,
- dst_sema: ArgumentSemantics::Default,
- }));
- converting_id
- }
- None => arg_desc.op,
- },
+ None => arg_desc.op,
})
}
@@ -4925,9 +4717,9 @@ fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgP fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
match id_defs.get_typed(id) {
- Ok((ast::Type::Scalar(ast::ScalarType::U64), _))
- | Ok((ast::Type::Scalar(ast::ScalarType::S64), _))
- | Ok((ast::Type::Scalar(ast::ScalarType::B64), _)) => true,
+ Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true,
_ => false,
}
}
@@ -5055,20 +4847,95 @@ impl SpecialRegistersMap { }
}
+struct FnSigMapper<'input> {
+ // true - stays as return argument
+ // false - is moved to input argument
+ return_param_args: Vec<bool>,
+ func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
+}
+
+impl<'input> FnSigMapper<'input> {
+ fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, spirv::Word>) -> Self {
+ let return_param_args = method
+ .return_arguments
+ .iter()
+ .map(|a| a.state_space != ast::StateSpace::Param)
+ .collect::<Vec<_>>();
+ let mut new_return_arguments = Vec::new();
+ for arg in method.return_arguments.into_iter() {
+ if arg.state_space == ast::StateSpace::Param {
+ method.input_arguments.push(arg);
+ } else {
+ new_return_arguments.push(arg);
+ }
+ }
+ method.return_arguments = new_return_arguments;
+ FnSigMapper {
+ return_param_args,
+ func_decl: Rc::new(RefCell::new(method)),
+ }
+ }
+
+ fn resolve_in_spirv_repr(
+ &self,
+ call_inst: ast::CallInst<NormalizedArgParams>,
+ ) -> Result<ResolvedCall<NormalizedArgParams>, TranslateError> {
+ let func_decl = (*self.func_decl).borrow();
+ let mut return_arguments = Vec::new();
+ let mut input_arguments = call_inst
+ .param_list
+ .into_iter()
+ .zip(func_decl.input_arguments.iter())
+ .map(|(id, var)| (id, var.v_type.clone(), var.state_space))
+ .collect::<Vec<_>>();
+ let mut func_decl_return_iter = func_decl.return_arguments.iter();
+ let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter();
+ for (idx, id) in call_inst.ret_params.iter().enumerate() {
+ let stays_as_return = match self.return_param_args.get(idx) {
+ Some(x) => *x,
+ None => return Err(TranslateError::MismatchedType),
+ };
+ if stays_as_return {
+ if let Some(var) = func_decl_return_iter.next() {
+ return_arguments.push((*id, var.v_type.clone(), var.state_space));
+ } else {
+ return Err(TranslateError::MismatchedType);
+ }
+ } else {
+ if let Some(var) = func_decl_input_iter.next() {
+ input_arguments.push((
+ ast::Operand::Reg(*id),
+ var.v_type.clone(),
+ var.state_space,
+ ));
+ } else {
+ return Err(TranslateError::MismatchedType);
+ }
+ }
+ }
+ if return_arguments.len() != func_decl.return_arguments.len()
+ || input_arguments.len() != func_decl.input_arguments.len()
+ {
+ return Err(TranslateError::MismatchedType);
+ }
+ Ok(ResolvedCall {
+ return_arguments,
+ input_arguments,
+ uniform: call_inst.uniform,
+ name: call_inst.func,
+ })
+ }
+}
+
struct GlobalStringIdResolver<'input> {
current_id: spirv::Word,
variables: HashMap<Cow<'input, str>, spirv::Word>,
- variables_type_check: HashMap<u32, Option<(ast::Type, bool)>>,
+ variables_type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: SpecialRegistersMap,
- fns: HashMap<spirv::Word, FnDecl>,
-}
-
-pub struct FnDecl {
- ret_vals: Vec<ast::FnArgumentType>,
- params: Vec<ast::FnArgumentType>,
+ fns: HashMap<spirv::Word, FnSigMapper<'input>>,
}
-impl<'a> GlobalStringIdResolver<'a> {
+impl<'input> GlobalStringIdResolver<'input> {
fn new(start_id: spirv::Word) -> Self {
Self {
current_id: start_id,
@@ -5079,20 +4946,25 @@ impl<'a> GlobalStringIdResolver<'a> { }
}
- fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word {
+ fn get_or_add_def(&mut self, id: &'input str) -> spirv::Word {
self.get_or_add_impl(id, None)
}
fn get_or_add_def_typed(
&mut self,
- id: &'a str,
+ id: &'input str,
typ: ast::Type,
+ state_space: ast::StateSpace,
is_variable: bool,
) -> spirv::Word {
- self.get_or_add_impl(id, Some((typ, is_variable)))
+ self.get_or_add_impl(id, Some((typ, state_space, is_variable)))
}
- fn get_or_add_impl(&mut self, id: &'a str, typ: Option<(ast::Type, bool)>) -> spirv::Word {
+ fn get_or_add_impl(
+ &mut self,
+ id: &'input str,
+ typ: Option<(ast::Type, ast::StateSpace, bool)>,
+ ) -> spirv::Word {
let id = match self.variables.entry(Cow::Borrowed(id)) {
hash_map::Entry::Occupied(e) => *(e.get()),
hash_map::Entry::Vacant(e) => {
@@ -5119,12 +4991,12 @@ impl<'a> GlobalStringIdResolver<'a> { fn start_fn<'b>(
&'b mut self,
- header: &'b ast::MethodDecl<'a, &'a str>,
+ header: &'b ast::MethodDeclaration<'input, &'input str>,
) -> Result<
(
- FnStringIdResolver<'a, 'b>,
- GlobalFnDeclResolver<'a, 'b>,
- ast::MethodDecl<'a, spirv::Word>,
+ FnStringIdResolver<'input, 'b>,
+ GlobalFnDeclResolver<'input, 'b>,
+ Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
),
TranslateError,
> {
@@ -5138,60 +5010,51 @@ impl<'a> GlobalStringIdResolver<'a> { variables: vec![HashMap::new(); 1],
type_check: HashMap::new(),
};
- let new_fn_decl = match header {
- ast::MethodDecl::Kernel { name, in_args } => ast::MethodDecl::Kernel {
- name,
- in_args: expand_kernel_params(&mut fn_resolver, in_args.iter())?,
- },
- ast::MethodDecl::Func(ret_params, _, params) => {
- let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter())?;
- let params_ids = expand_fn_params(&mut fn_resolver, params.iter())?;
- self.fns.insert(
- name_id,
- FnDecl {
- ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(),
- params: params_ids.iter().map(|p| p.v_type.clone()).collect(),
- },
- );
- ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
- }
+ let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments);
+ let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments);
+ let name = match header.name {
+ ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
+ ast::MethodName::Func(_) => ast::MethodName::Func(name_id),
+ };
+ let fn_decl = ast::MethodDeclaration {
+ return_arguments,
+ name,
+ input_arguments,
+ shared_mem: None,
+ };
+ let new_fn_decl = if !fn_decl.name.is_kernel() {
+ let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl);
+ let new_fn_decl = resolver.func_decl.clone();
+ self.fns.insert(name_id, resolver);
+ new_fn_decl
+ } else {
+ Rc::new(RefCell::new(fn_decl))
};
Ok((
fn_resolver,
- GlobalFnDeclResolver {
- variables: &self.variables,
- fns: &self.fns,
- },
+ GlobalFnDeclResolver { fns: &self.fns },
new_fn_decl,
))
}
}
pub struct GlobalFnDeclResolver<'input, 'a> {
- variables: &'a HashMap<Cow<'input, str>, spirv::Word>,
- fns: &'a HashMap<spirv::Word, FnDecl>,
+ fns: &'a HashMap<spirv::Word, FnSigMapper<'input>>,
}
impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
- fn get_fn_decl(&self, id: spirv::Word) -> Result<&FnDecl, TranslateError> {
+ fn get_fn_sig_resolver(&self, id: spirv::Word) -> Result<&FnSigMapper<'input>, TranslateError> {
self.fns.get(&id).ok_or(TranslateError::UnknownSymbol)
}
-
- fn get_fn_decl_str(&self, id: &str) -> Result<&'a FnDecl, TranslateError> {
- match self.variables.get(id).map(|var_id| self.fns.get(var_id)) {
- Some(Some(fn_d)) => Ok(fn_d),
- _ => Err(TranslateError::UnknownSymbol),
- }
- }
}
struct FnStringIdResolver<'input, 'b> {
current_id: &'b mut spirv::Word,
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
- global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
+ global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: &'b mut SpecialRegistersMap,
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
- type_check: HashMap<u32, Option<(ast::Type, bool)>>,
+ type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
}
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
@@ -5229,14 +5092,21 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { }
}
- fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>, is_variable: bool) -> spirv::Word {
+ fn add_def(
+ &mut self,
+ id: &'a str,
+ typ: Option<(ast::Type, ast::StateSpace)>,
+ is_variable: bool,
+ ) -> spirv::Word {
let numeric_id = *self.current_id;
self.variables
.last_mut()
.unwrap()
.insert(Cow::Borrowed(id), numeric_id);
- self.type_check
- .insert(numeric_id, typ.map(|t| (t, is_variable)));
+ self.type_check.insert(
+ numeric_id,
+ typ.map(|(typ, space)| (typ, space, is_variable)),
+ );
*self.current_id += 1;
numeric_id
}
@@ -5247,6 +5117,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { base_id: &'a str,
count: u32,
typ: ast::Type,
+ state_space: ast::StateSpace,
is_variable: bool,
) -> impl Iterator<Item = spirv::Word> {
let numeric_id = *self.current_id;
@@ -5255,8 +5126,10 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .last_mut()
.unwrap()
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
- self.type_check
- .insert(numeric_id + i, Some((typ.clone(), is_variable)));
+ self.type_check.insert(
+ numeric_id + i,
+ Some((typ.clone(), state_space, is_variable)),
+ );
}
*self.current_id += count;
(0..count).into_iter().map(move |i| i + numeric_id)
@@ -5265,8 +5138,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { struct NumericIdResolver<'b> {
current_id: &'b mut spirv::Word,
- global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
- type_check: HashMap<u32, Option<(ast::Type, bool)>>,
+ global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
+ type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: &'b mut SpecialRegistersMap,
}
@@ -5275,12 +5148,15 @@ impl<'b> NumericIdResolver<'b> { MutableNumericIdResolver { base: self }
}
- fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, bool), TranslateError> {
+ fn get_typed(
+ &self,
+ id: spirv::Word,
+ ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> {
match self.type_check.get(&id) {
Some(Some(x)) => Ok(x.clone()),
Some(None) => Err(TranslateError::UntypedSymbol),
None => match self.special_registers.get(id) {
- Some(x) => Ok((x.get_type(), true)),
+ Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
None => match self.global_type_check.get(&id) {
Some(Some(result)) => Ok(result.clone()),
Some(None) | None => Err(TranslateError::UntypedSymbol),
@@ -5291,16 +5167,18 @@ impl<'b> NumericIdResolver<'b> { // This is for identifiers which will be emitted later as OpVariable
// They are candidates for insertion of LoadVar/StoreVar
- fn new_variable(&mut self, typ: ast::Type) -> spirv::Word {
+ fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> spirv::Word {
let new_id = *self.current_id;
- self.type_check.insert(new_id, Some((typ, true)));
+ self.type_check
+ .insert(new_id, Some((typ, state_space, true)));
*self.current_id += 1;
new_id
}
- fn new_non_variable(&mut self, typ: Option<ast::Type>) -> spirv::Word {
+ fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> spirv::Word {
let new_id = *self.current_id;
- self.type_check.insert(new_id, typ.map(|t| (t, false)));
+ self.type_check
+ .insert(new_id, typ.map(|(t, space)| (t, space, false)));
*self.current_id += 1;
new_id
}
@@ -5315,18 +5193,22 @@ impl<'b> MutableNumericIdResolver<'b> { self.base
}
- fn get_typed(&self, id: spirv::Word) -> Result<ast::Type, TranslateError> {
- self.base.get_typed(id).map(|(t, _)| t)
+ fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, ast::StateSpace), TranslateError> {
+ self.base.get_typed(id).map(|(t, space, _)| (t, space))
}
- fn new_non_variable(&mut self, typ: ast::Type) -> spirv::Word {
- self.base.new_non_variable(Some(typ))
+ fn register_intermediate(
+ &mut self,
+ typ: ast::Type,
+ state_space: ast::StateSpace,
+ ) -> spirv::Word {
+ self.base.register_intermediate(Some((typ, state_space)))
}
}
enum Statement<I, P: ast::ArgParams> {
Label(u32),
- Variable(ast::Variable<ast::VariableType, P::Id>),
+ Variable(ast::Variable<P::Id>),
Instruction(I),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
@@ -5349,7 +5231,8 @@ impl ExpandedStatement { Statement::Variable(var)
}
Statement::Instruction(inst) => inst
- .visit(&mut |arg: ArgumentDescriptor<_>, _: Option<&ast::Type>| {
+ .visit(&mut |arg: ArgumentDescriptor<_>,
+ _: Option<(&ast::Type, ast::StateSpace)>| {
Ok(f(arg.op, arg.is_dst))
})
.unwrap(),
@@ -5364,16 +5247,17 @@ impl ExpandedStatement { Statement::StoreVar(details)
}
Statement::Call(mut call) => {
- for (id, typ) in call.ret_params.iter_mut() {
- let is_dst = match typ {
- ast::FnArgumentType::Reg(_) => true,
- ast::FnArgumentType::Param(_) => false,
- ast::FnArgumentType::Shared => false,
+ for (id, _, space) in call.return_arguments.iter_mut() {
+ let is_dst = match space {
+ ast::StateSpace::Reg => true,
+ ast::StateSpace::Param => false,
+ ast::StateSpace::Shared => false,
+ _ => todo!(),
};
*id = f(*id, is_dst);
}
- call.func = f(call.func, false);
- for (id, _) in call.param_list.iter_mut() {
+ call.name = f(call.name, false);
+ for (id, _, _) in call.input_arguments.iter_mut() {
*id = f(*id, false);
}
Statement::Call(call)
@@ -5435,6 +5319,7 @@ impl ExpandedStatement { struct LoadVarDetails {
arg: ast::Arg2<ExpandedArgParams>,
typ: ast::Type,
+ state_space: ast::StateSpace,
// (index, vector_width)
// HACK ALERT
// For some reason IGC explodes when you try to load from builtin vectors
@@ -5454,7 +5339,12 @@ struct RepackVectorDetails { typ: ast::ScalarType,
packed: spirv::Word,
unpacked: Vec<spirv::Word>,
- vector_sema: ArgumentSemantics,
+ non_default_implicit_conversion: Option<
+ fn(
+ (ast::StateSpace, &ast::Type),
+ (ast::StateSpace, &ast::Type),
+ ) -> Result<Option<ConversionKind>, TranslateError>,
+ >,
}
impl RepackVectorDetails {
@@ -5470,13 +5360,17 @@ impl RepackVectorDetails { ArgumentDescriptor {
op: self.packed,
is_dst: !self.is_extract,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)),
+ Some((
+ &ast::Type::Vector(self.typ, self.unpacked.len() as u8),
+ ast::StateSpace::Reg,
+ )),
)?;
let scalar_type = self.typ;
let is_extract = self.is_extract;
- let vector_sema = self.vector_sema;
+ let non_default_implicit_conversion = self.non_default_implicit_conversion;
let vector = self
.unpacked
.into_iter()
@@ -5485,9 +5379,10 @@ impl RepackVectorDetails { ArgumentDescriptor {
op: id,
is_dst: is_extract,
- sema: vector_sema,
+ is_memory_access: false,
+ non_default_implicit_conversion,
},
- Some(&ast::Type::Scalar(scalar_type)),
+ Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)),
)
})
.collect::<Result<_, _>>()?;
@@ -5496,7 +5391,7 @@ impl RepackVectorDetails { typ: self.typ,
packed: scalar,
unpacked: vector,
- vector_sema,
+ non_default_implicit_conversion,
})
}
}
@@ -5514,18 +5409,18 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab struct ResolvedCall<P: ast::ArgParams> {
pub uniform: bool,
- pub ret_params: Vec<(P::Id, ast::FnArgumentType)>,
- pub func: P::Id,
- pub param_list: Vec<(P::Operand, ast::FnArgumentType)>,
+ pub return_arguments: Vec<(P::Id, ast::Type, ast::StateSpace)>,
+ pub name: P::Id,
+ pub input_arguments: Vec<(P::Operand, ast::Type, ast::StateSpace)>,
}
impl<T: ast::ArgParams> ResolvedCall<T> {
fn cast<U: ast::ArgParams<Id = T::Id, Operand = T::Operand>>(self) -> ResolvedCall<U> {
ResolvedCall {
uniform: self.uniform,
- ret_params: self.ret_params,
- func: self.func,
- param_list: self.param_list,
+ return_arguments: self.return_arguments,
+ name: self.name,
+ input_arguments: self.input_arguments,
}
}
}
@@ -5535,49 +5430,53 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> { self,
visitor: &mut V,
) -> Result<ResolvedCall<To>, TranslateError> {
- let ret_params = self
- .ret_params
+ let return_arguments = self
+ .return_arguments
.into_iter()
- .map::<Result<_, TranslateError>, _>(|(id, typ)| {
+ .map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.id(
ArgumentDescriptor {
op: id,
- is_dst: !typ.is_param(),
- sema: typ.semantics(),
+ is_dst: space != ast::StateSpace::Param,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&typ.to_func_type()),
+ Some((&typ, space)),
)?;
- Ok((new_id, typ))
+ Ok((new_id, typ, space))
})
.collect::<Result<Vec<_>, _>>()?;
let func = visitor.id(
ArgumentDescriptor {
- op: self.func,
+ op: self.name,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
None,
)?;
- let param_list = self
- .param_list
+ let input_arguments = self
+ .input_arguments
.into_iter()
- .map::<Result<_, TranslateError>, _>(|(id, typ)| {
+ .map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.operand(
ArgumentDescriptor {
op: id,
is_dst: false,
- sema: typ.semantics(),
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- &typ.to_func_type(),
+ &typ,
+ space,
)?;
- Ok((new_id, typ))
+ Ok((new_id, typ, space))
})
.collect::<Result<Vec<_>, _>>()?;
Ok(ResolvedCall {
uniform: self.uniform,
- ret_params,
- func,
- param_list,
+ return_arguments,
+ name: func,
+ input_arguments,
})
}
}
@@ -5598,39 +5497,34 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> { self,
visitor: &mut V,
) -> Result<PtrAccess<To>, TranslateError> {
- let sema = match self.state_space {
- ast::LdStateSpace::Const
- | ast::LdStateSpace::Global
- | ast::LdStateSpace::Shared
- | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer,
- ast::LdStateSpace::Local | ast::LdStateSpace::Param => {
- ArgumentSemantics::RegisterPointer
- }
- };
- let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), self.state_space);
+ let ptr_type = ast::Type::Scalar(self.underlying_type.clone());
let new_dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ptr_type),
+ Some((&ptr_type, self.state_space)),
)?;
let new_ptr_src = visitor.id(
ArgumentDescriptor {
op: self.ptr_src,
is_dst: false,
- sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ptr_type),
+ Some((&ptr_type, self.state_space)),
)?;
let new_constant_src = visitor.operand(
ArgumentDescriptor {
op: self.offset_src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
)?;
Ok(PtrAccess {
underlying_type: self.underlying_type,
@@ -5653,21 +5547,9 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab }
}
-pub trait ArgParamsEx: ast::ArgParams + Sized {
- fn get_fn_decl<'x, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'x, 'b>,
- ) -> Result<&'b FnDecl, TranslateError>;
-}
+pub trait ArgParamsEx: ast::ArgParams + Sized {}
-impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
- fn get_fn_decl<'x, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'x, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl_str(id)
- }
-}
+impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {}
enum NormalizedArgParams {}
@@ -5676,14 +5558,7 @@ impl ast::ArgParams for NormalizedArgParams { type Operand = ast::Operand<spirv::Word>;
}
-impl ArgParamsEx for NormalizedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
-}
+impl ArgParamsEx for NormalizedArgParams {}
type NormalizedStatement = Statement<
(
@@ -5702,14 +5577,7 @@ impl ast::ArgParams for TypedArgParams { type Operand = TypedOperand;
}
-impl ArgParamsEx for TypedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
-}
+impl ArgParamsEx for TypedArgParams {}
#[derive(Copy, Clone)]
enum TypedOperand {
@@ -5740,24 +5608,16 @@ impl ast::ArgParams for ExpandedArgParams { type Operand = spirv::Word;
}
-impl ArgParamsEx for ExpandedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
-}
+impl ArgParamsEx for ExpandedArgParams {}
enum Directive<'input> {
- Variable(ast::Variable<ast::VariableType, spirv::Word>),
+ Variable(ast::LinkingDirective, ast::Variable<spirv::Word>),
Method(Function<'input>),
}
struct Function<'input> {
- pub func_decl: ast::MethodDecl<'input, spirv::Word>,
- pub spirv_decl: SpirvMethodDecl<'input>,
- pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>,
+ pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
+ pub globals: Vec<ast::Variable<spirv::Word>>,
pub body: Option<Vec<ExpandedStatement>>,
import_as: Option<String>,
tuning: Vec<ast::TuningDirective>,
@@ -5767,12 +5627,13 @@ pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> { fn id(
&mut self,
desc: ArgumentDescriptor<T::Id>,
- typ: Option<&ast::Type>,
+ typ: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<U::Id, TranslateError>;
fn operand(
&mut self,
desc: ArgumentDescriptor<T::Operand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<U::Operand, TranslateError>;
}
@@ -5780,13 +5641,13 @@ impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
+ Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError>,
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
@@ -5795,8 +5656,9 @@ where &mut self,
desc: ArgumentDescriptor<spirv::Word>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
- self(desc, Some(typ))
+ self(desc, Some((typ, state_space)))
}
}
@@ -5807,7 +5669,7 @@ where fn id(
&mut self,
desc: ArgumentDescriptor<&str>,
- _: Option<&ast::Type>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self(desc.op)
}
@@ -5816,6 +5678,7 @@ where &mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<ast::Operand<spirv::Word>, TranslateError> {
Ok(match desc.op {
ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?),
@@ -5824,7 +5687,7 @@ where ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member),
ast::Operand::VecPack(ref ids) => ast::Operand::VecPack(
ids.into_iter()
- .map(|id| self.id(desc.new_op(id), Some(typ)))
+ .map(|id| self.id(desc.new_op(id), Some((typ, state_space))))
.collect::<Result<Vec<_>, _>>()?,
),
})
@@ -5834,37 +5697,30 @@ where pub struct ArgumentDescriptor<Op> {
op: Op,
is_dst: bool,
- sema: ArgumentSemantics,
+ is_memory_access: bool,
+ non_default_implicit_conversion: Option<
+ fn(
+ (ast::StateSpace, &ast::Type),
+ (ast::StateSpace, &ast::Type),
+ ) -> Result<Option<ConversionKind>, TranslateError>,
+ >,
}
pub struct PtrAccess<P: ast::ArgParams> {
- underlying_type: ast::PointerType,
- state_space: ast::LdStateSpace,
+ underlying_type: ast::ScalarType,
+ state_space: ast::StateSpace,
dst: spirv::Word,
ptr_src: spirv::Word,
offset_src: P::Operand,
}
-#[derive(Copy, Clone, PartialEq, Eq, Debug)]
-pub enum ArgumentSemantics {
- // normal register access
- Default,
- // normal register access with relaxed conversion rules (ld/st)
- DefaultRelaxed,
- // st/ld global
- PhysicalPointer,
- // st/ld .param, .local
- RegisterPointer,
- // mov of .local/.global variables
- Address,
-}
-
impl<T> ArgumentDescriptor<T> {
fn new_op<U>(&self, u: U) -> ArgumentDescriptor<U> {
ArgumentDescriptor {
op: u,
is_dst: self.is_dst,
- sema: self.sema,
+ is_memory_access: self.is_memory_access,
+ non_default_implicit_conversion: self.non_default_implicit_conversion,
}
}
}
@@ -5905,7 +5761,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> { let inst_type = d.typ;
ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
}
- ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?),
+ ast::Instruction::Not(t, a) => {
+ ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?)
+ }
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (
@@ -5928,7 +5786,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> { ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?)
}
ast::Instruction::Shl(t, a) => {
- ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?)
+ ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?)
}
ast::Instruction::Shr(t, a) => {
ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?)
@@ -6101,17 +5959,19 @@ impl ImplicitConversion { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: self.dst_sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&self.to),
+ Some((&self.to_type, self.to_space)),
)?;
let new_src = visitor.id(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: self.src_sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&self.from),
+ Some((&self.from_type, self.from_space)),
)?;
Ok(Statement::Conversion({
ImplicitConversion {
@@ -6138,13 +5998,13 @@ impl<T> ArgumentMapVisitor<TypedArgParams, TypedArgParams> for T where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
+ Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError>,
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
@@ -6153,12 +6013,15 @@ where &mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
- TypedOperand::Reg(id) => TypedOperand::Reg(self(desc.new_op(id), Some(typ))?),
+ TypedOperand::Reg(id) => {
+ TypedOperand::Reg(self(desc.new_op(id), Some((typ, state_space)))?)
+ }
TypedOperand::Imm(imm) => TypedOperand::Imm(imm),
TypedOperand::RegOffset(id, imm) => {
- TypedOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm)
+ TypedOperand::RegOffset(self(desc.new_op(id), Some((typ, state_space)))?, imm)
}
TypedOperand::VecMember(reg, index) => {
let scalar_type = match typ {
@@ -6166,7 +6029,10 @@ where _ => return Err(error_unreachable()),
};
let vec_type = ast::Type::Vector(scalar_type, index + 1);
- TypedOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index)
+ TypedOperand::VecMember(
+ self(desc.new_op(reg), Some((&vec_type, state_space)))?,
+ index,
+ )
}
})
}
@@ -6178,9 +6044,9 @@ impl ast::Type { ast::Type::Scalar(scalar) => {
let kind = scalar.kind();
let width = scalar.size_of();
- if (kind != ScalarKind::Signed
- && kind != ScalarKind::Unsigned
- && kind != ScalarKind::Bit)
+ if (kind != ast::ScalarKind::Signed
+ && kind != ast::ScalarKind::Unsigned
+ && kind != ast::ScalarKind::Bit)
|| (width == 8)
{
return Err(TranslateError::MismatchedType);
@@ -6198,57 +6064,32 @@ impl ast::Type { match self {
ast::Type::Scalar(scalar) => TypeParts {
kind: TypeKind::Scalar,
+ state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: Vec::new(),
- state_space: ast::LdStateSpace::Global,
},
ast::Type::Vector(scalar, components) => TypeParts {
kind: TypeKind::Vector,
+ state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: vec![*components as u32],
- state_space: ast::LdStateSpace::Global,
},
ast::Type::Array(scalar, components) => TypeParts {
kind: TypeKind::Array,
+ state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: components.clone(),
- state_space: ast::LdStateSpace::Global,
},
- ast::Type::Pointer(ast::PointerType::Scalar(scalar), state_space) => TypeParts {
- kind: TypeKind::PointerScalar,
+ ast::Type::Pointer(scalar, space) => TypeParts {
+ kind: TypeKind::Pointer,
+ state_space: *space,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: Vec::new(),
- state_space: *state_space,
- },
- ast::Type::Pointer(ast::PointerType::Vector(scalar, len), state_space) => TypeParts {
- kind: TypeKind::PointerVector,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: vec![*len as u32],
- state_space: *state_space,
},
- ast::Type::Pointer(ast::PointerType::Array(scalar, components), state_space) => {
- TypeParts {
- kind: TypeKind::PointerArray,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: components.clone(),
- state_space: *state_space,
- }
- }
- ast::Type::Pointer(ast::PointerType::Pointer(scalar, inner_space), state_space) => {
- TypeParts {
- kind: TypeKind::PointerPointer,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: vec![*inner_space as u32],
- state_space: *state_space,
- }
- }
}
}
@@ -6265,29 +6106,8 @@ impl ast::Type { ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.components,
),
- TypeKind::PointerScalar => ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)),
- t.state_space,
- ),
- TypeKind::PointerVector => ast::Type::Pointer(
- ast::PointerType::Vector(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- t.components[0] as u8,
- ),
- t.state_space,
- ),
- TypeKind::PointerArray => ast::Type::Pointer(
- ast::PointerType::Array(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- t.components,
- ),
- t.state_space,
- ),
- TypeKind::PointerPointer => ast::Type::Pointer(
- ast::PointerType::Pointer(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- unsafe { mem::transmute::<_, ast::LdStateSpace>(t.components[0] as u8) },
- ),
+ TypeKind::Pointer => ast::Type::Pointer(
+ ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.state_space,
),
}
@@ -6300,7 +6120,7 @@ impl ast::Type { ast::Type::Array(typ, len) => len
.iter()
.fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)),
- ast::Type::Pointer(_, _) => mem::size_of::<usize>(),
+ ast::Type::Pointer(..) => mem::size_of::<usize>(),
}
}
}
@@ -6308,10 +6128,10 @@ impl ast::Type { #[derive(Eq, PartialEq, Clone)]
struct TypeParts {
kind: TypeKind,
- scalar_kind: ScalarKind,
+ scalar_kind: ast::ScalarKind,
width: u8,
+ state_space: ast::StateSpace,
components: Vec<u32>,
- state_space: ast::LdStateSpace,
}
#[derive(Eq, PartialEq, Copy, Clone)]
@@ -6319,10 +6139,7 @@ enum TypeKind { Scalar,
Vector,
Array,
- PointerScalar,
- PointerVector,
- PointerArray,
- PointerPointer,
+ Pointer,
}
impl ast::Instruction<ExpandedArgParams> {
@@ -6450,21 +6267,21 @@ struct BrachCondition { struct ImplicitConversion {
src: spirv::Word,
dst: spirv::Word,
- from: ast::Type,
- to: ast::Type,
+ from_type: ast::Type,
+ to_type: ast::Type,
+ from_space: ast::StateSpace,
+ to_space: ast::StateSpace,
kind: ConversionKind,
- src_sema: ArgumentSemantics,
- dst_sema: ArgumentSemantics,
}
-#[derive(PartialEq, Copy, Clone)]
+#[derive(PartialEq, Clone)]
enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
SignExtend,
- BitToPtr(ast::LdStateSpace),
- PtrToBit(ast::UIntType),
- PtrToPtr { spirv_ptr: bool },
+ BitToPtr,
+ PtrToPtr,
+ AddressOf,
}
impl<T> ast::PredAt<T> {
@@ -6512,13 +6329,14 @@ impl<T: ArgParamsEx> ast::Arg1<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<ast::Arg1<U>, TranslateError> {
let new_src = visitor.id(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
)?;
@@ -6535,9 +6353,11 @@ impl<T: ArgParamsEx> ast::Arg1Bar<T> { ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg1Bar { src: new_src })
}
@@ -6553,17 +6373,21 @@ impl<T: ArgParamsEx> ast::Arg2<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let new_src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2 {
dst: new_dst,
@@ -6581,17 +6405,21 @@ impl<T: ArgParamsEx> ast::Arg2<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
dst_t,
+ ast::StateSpace::Reg,
)?;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
src_t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2 { dst, src })
}
@@ -6607,26 +6435,21 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::DefaultRelaxed,
+ is_memory_access: false,
+ non_default_implicit_conversion: Some(should_convert_relaxed_dst_wrapper),
},
&ast::Type::from(details.typ.clone()),
+ ast::StateSpace::Reg,
)?;
- let is_logical_ptr = details.state_space == ast::LdStateSpace::Param
- || details.state_space == ast::LdStateSpace::Local;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: if is_logical_ptr {
- ArgumentSemantics::RegisterPointer
- } else {
- ArgumentSemantics::PhysicalPointer
- },
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::from(details.typ.clone()),
- details.state_space,
- ),
+ &details.typ,
+ details.state_space,
)?;
Ok(ast::Arg2Ld { dst, src })
}
@@ -6638,30 +6461,25 @@ impl<T: ArgParamsEx> ast::Arg2St<T> { visitor: &mut V,
details: &ast::StData,
) -> Result<ast::Arg2St<U>, TranslateError> {
- let is_logical_ptr = details.state_space == ast::StStateSpace::Param
- || details.state_space == ast::StStateSpace::Local;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: if is_logical_ptr {
- ArgumentSemantics::RegisterPointer
- } else {
- ArgumentSemantics::PhysicalPointer
- },
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::from(details.typ.clone()),
- details.state_space.to_ld_ss(),
- ),
+ &details.typ,
+ details.state_space,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::DefaultRelaxed,
+ is_memory_access: false,
+ non_default_implicit_conversion: Some(should_convert_relaxed_src_wrapper),
},
&details.typ.clone().into(),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2St { src1, src2 })
}
@@ -6677,21 +6495,21 @@ impl<T: ArgParamsEx> ast::Arg2Mov<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&details.typ.clone().into(),
+ ast::StateSpace::Reg,
)?;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: if details.src_is_address {
- ArgumentSemantics::Address
- } else {
- ArgumentSemantics::Default
- },
+ is_memory_access: false,
+ non_default_implicit_conversion: Some(implicit_conversion_mov),
},
&details.typ.clone().into(),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2Mov { dst, src })
}
@@ -6713,25 +6531,31 @@ impl<T: ArgParamsEx> ast::Arg3<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
wide_type.as_ref().unwrap_or(typ),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -6745,25 +6569,31 @@ impl<T: ArgParamsEx> ast::Arg3<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -6772,35 +6602,38 @@ impl<T: ArgParamsEx> ast::Arg3<T> { self,
visitor: &mut V,
t: ast::ScalarType,
- state_space: ast::AtomSpace,
+ state_space: ast::StateSpace,
) -> Result<ast::Arg3<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::PhysicalPointer,
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::Scalar(scalar_type),
- state_space.to_ld_ss(),
- ),
+ &ast::Type::Scalar(scalar_type),
+ state_space,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -6822,33 +6655,41 @@ impl<T: ArgParamsEx> ast::Arg4<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
wide_type.as_ref().unwrap_or(t),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6861,39 +6702,47 @@ impl<T: ArgParamsEx> ast::Arg4<T> { fn map_selp<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::SelpType,
+ t: ast::ScalarType,
) -> Result<ast::Arg4<U>, TranslateError> {
let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6906,44 +6755,49 @@ impl<T: ArgParamsEx> ast::Arg4<T> { fn map_atom<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::BitType,
- state_space: ast::AtomSpace,
+ t: ast::ScalarType,
+ state_space: ast::StateSpace,
) -> Result<ast::Arg4<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::PhysicalPointer,
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::Scalar(scalar_type),
- state_space.to_ld_ss(),
- ),
+ &ast::Type::Scalar(scalar_type),
+ state_space,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6962,34 +6816,42 @@ impl<T: ArgParamsEx> ast::Arg4<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
let u32_type = ast::Type::Scalar(ast::ScalarType::U32);
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&u32_type,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&u32_type,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -7010,9 +6872,13 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> { ArgumentDescriptor {
op: self.dst1,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)?;
let dst2 = self
.dst2
@@ -7021,9 +6887,13 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> { ArgumentDescriptor {
op: dst2,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)
})
.transpose()?;
@@ -7031,17 +6901,21 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> { ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4Setp {
dst1,
@@ -7062,41 +6936,51 @@ impl<T: ArgParamsEx> ast::Arg5<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
base_type,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
base_type,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
base_type,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
let src4 = visitor.operand(
ArgumentDescriptor {
op: self.src4,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg5 {
dst,
@@ -7118,9 +7002,13 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> { ArgumentDescriptor {
op: self.dst1,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)?;
let dst2 = self
.dst2
@@ -7129,9 +7017,13 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> { ArgumentDescriptor {
op: dst2,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)
})
.transpose()?;
@@ -7139,25 +7031,31 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> { ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg5Setp {
dst1,
@@ -7195,115 +7093,41 @@ impl ast::Operand<spirv::Word> { }
}
-impl ast::StStateSpace {
- fn to_ld_ss(self) -> ast::LdStateSpace {
- match self {
- ast::StStateSpace::Generic => ast::LdStateSpace::Generic,
- ast::StStateSpace::Global => ast::LdStateSpace::Global,
- ast::StStateSpace::Local => ast::LdStateSpace::Local,
- ast::StStateSpace::Param => ast::LdStateSpace::Param,
- ast::StStateSpace::Shared => ast::LdStateSpace::Shared,
- }
- }
-}
-
-#[derive(Clone, Copy, PartialEq, Eq)]
-enum ScalarKind {
- Bit,
- Unsigned,
- Signed,
- Float,
- Float2,
- Pred,
-}
-
impl ast::ScalarType {
- fn kind(self) -> ScalarKind {
- match self {
- ast::ScalarType::U8 => ScalarKind::Unsigned,
- ast::ScalarType::U16 => ScalarKind::Unsigned,
- ast::ScalarType::U32 => ScalarKind::Unsigned,
- ast::ScalarType::U64 => ScalarKind::Unsigned,
- ast::ScalarType::S8 => ScalarKind::Signed,
- ast::ScalarType::S16 => ScalarKind::Signed,
- ast::ScalarType::S32 => ScalarKind::Signed,
- ast::ScalarType::S64 => ScalarKind::Signed,
- ast::ScalarType::B8 => ScalarKind::Bit,
- ast::ScalarType::B16 => ScalarKind::Bit,
- ast::ScalarType::B32 => ScalarKind::Bit,
- ast::ScalarType::B64 => ScalarKind::Bit,
- ast::ScalarType::F16 => ScalarKind::Float,
- ast::ScalarType::F32 => ScalarKind::Float,
- ast::ScalarType::F64 => ScalarKind::Float,
- ast::ScalarType::F16x2 => ScalarKind::Float2,
- ast::ScalarType::Pred => ScalarKind::Pred,
- }
- }
-
- fn from_parts(width: u8, kind: ScalarKind) -> Self {
+ fn from_parts(width: u8, kind: ast::ScalarKind) -> Self {
match kind {
- ScalarKind::Float => match width {
+ ast::ScalarKind::Float => match width {
2 => ast::ScalarType::F16,
4 => ast::ScalarType::F32,
8 => ast::ScalarType::F64,
_ => unreachable!(),
},
- ScalarKind::Bit => match width {
+ ast::ScalarKind::Bit => match width {
1 => ast::ScalarType::B8,
2 => ast::ScalarType::B16,
4 => ast::ScalarType::B32,
8 => ast::ScalarType::B64,
_ => unreachable!(),
},
- ScalarKind::Signed => match width {
+ ast::ScalarKind::Signed => match width {
1 => ast::ScalarType::S8,
2 => ast::ScalarType::S16,
4 => ast::ScalarType::S32,
8 => ast::ScalarType::S64,
_ => unreachable!(),
},
- ScalarKind::Unsigned => match width {
+ ast::ScalarKind::Unsigned => match width {
1 => ast::ScalarType::U8,
2 => ast::ScalarType::U16,
4 => ast::ScalarType::U32,
8 => ast::ScalarType::U64,
_ => unreachable!(),
},
- ScalarKind::Float2 => match width {
+ ast::ScalarKind::Float2 => match width {
4 => ast::ScalarType::F16x2,
_ => unreachable!(),
},
- ScalarKind::Pred => ast::ScalarType::Pred,
- }
- }
-}
-
-impl ast::BooleanType {
- fn to_type(self) -> ast::Type {
- match self {
- ast::BooleanType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
- ast::BooleanType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
- ast::BooleanType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
- ast::BooleanType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
- }
- }
-}
-
-impl ast::ShlType {
- fn to_type(self) -> ast::Type {
- match self {
- ast::ShlType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
- ast::ShlType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
- ast::ShlType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
- }
- }
-}
-
-impl ast::ShrType {
- fn signed(&self) -> bool {
- match self {
- ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true,
- _ => false,
+ ast::ScalarKind::Pred => ast::ScalarType::Pred,
}
}
}
@@ -7359,49 +7183,47 @@ impl ast::AtomInnerDetails { }
}
-impl ast::SIntType {
- fn from_size(width: u8) -> Self {
- match width {
- 1 => ast::SIntType::S8,
- 2 => ast::SIntType::S16,
- 4 => ast::SIntType::S32,
- 8 => ast::SIntType::S64,
- _ => unreachable!(),
+impl ast::StateSpace {
+ fn to_spirv(self) -> spirv::StorageClass {
+ match self {
+ ast::StateSpace::Const => spirv::StorageClass::UniformConstant,
+ ast::StateSpace::Generic => spirv::StorageClass::Generic,
+ ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup,
+ ast::StateSpace::Local => spirv::StorageClass::Function,
+ ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
+ ast::StateSpace::Param => spirv::StorageClass::Function,
+ ast::StateSpace::Reg => spirv::StorageClass::Function,
+ ast::StateSpace::Sreg => spirv::StorageClass::Input,
}
}
-}
-impl ast::UIntType {
- fn from_size(width: u8) -> Self {
- match width {
- 1 => ast::UIntType::U8,
- 2 => ast::UIntType::U16,
- 4 => ast::UIntType::U32,
- 8 => ast::UIntType::U64,
- _ => unreachable!(),
- }
+ fn is_compatible(self, other: ast::StateSpace) -> bool {
+ self == other
+ || self == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
+ || self == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
}
-}
-impl ast::LdStateSpace {
- fn to_spirv(self) -> spirv::StorageClass {
+ fn coerces_to_generic(self) -> bool {
match self {
- ast::LdStateSpace::Const => spirv::StorageClass::UniformConstant,
- ast::LdStateSpace::Generic => spirv::StorageClass::Generic,
- ast::LdStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
- ast::LdStateSpace::Local => spirv::StorageClass::Function,
- ast::LdStateSpace::Shared => spirv::StorageClass::Workgroup,
- ast::LdStateSpace::Param => spirv::StorageClass::Function,
+ ast::StateSpace::Global
+ | ast::StateSpace::Const
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => true,
+ ast::StateSpace::Reg
+ | ast::StateSpace::Param
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Sreg => false,
}
}
-}
-impl From<ast::FnArgumentType> for ast::VariableType {
- fn from(t: ast::FnArgumentType) -> Self {
- match t {
- ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
- ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
- ast::FnArgumentType::Shared => todo!(),
+ fn is_addressable(self) -> bool {
+ match self {
+ ast::StateSpace::Const
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Global
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => true,
+ ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false,
}
}
}
@@ -7427,16 +7249,6 @@ impl ast::MulDetails { }
}
-impl ast::AtomSpace {
- fn to_ld_ss(self) -> ast::LdStateSpace {
- match self {
- ast::AtomSpace::Generic => ast::LdStateSpace::Generic,
- ast::AtomSpace::Global => ast::LdStateSpace::Global,
- ast::AtomSpace::Shared => ast::LdStateSpace::Shared,
- }
- }
-}
-
impl ast::MemScope {
fn to_spirv(self) -> spirv::Scope {
match self {
@@ -7458,109 +7270,96 @@ impl ast::AtomSemantics { }
}
-impl ast::FnArgumentType {
- fn semantics(&self) -> ArgumentSemantics {
- match self {
- ast::FnArgumentType::Reg(_) => ArgumentSemantics::Default,
- ast::FnArgumentType::Param(_) => ArgumentSemantics::RegisterPointer,
- ast::FnArgumentType::Shared => ArgumentSemantics::PhysicalPointer,
- }
- }
-}
-
-fn bitcast_register_pointer(
- operand_type: &ast::Type,
- instr_type: &ast::Type,
- ss: Option<ast::LdStateSpace>,
+fn default_implicit_conversion(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- bitcast_physical_pointer(operand_type, instr_type, ss)
+ if !instruction_space.is_compatible(operand_space) {
+ default_implicit_conversion_space(
+ (operand_space, operand_type),
+ (instruction_space, instruction_type),
+ )
+ } else if instruction_type != operand_type {
+ default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
+ } else {
+ Ok(None)
+ }
}
-fn bitcast_physical_pointer(
- operand_type: &ast::Type,
- instr_type: &ast::Type,
- ss: Option<ast::LdStateSpace>,
+// Space is different
+fn default_implicit_conversion_space(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- match operand_type {
- // array decays to a pointer
- ast::Type::Array(op_scalar_t, _) => {
- if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
- if ss == Some(*instr_space) {
- if ast::Type::Scalar(*op_scalar_t) == ast::Type::from(instr_scalar_t.clone()) {
- Ok(None)
- } else {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- }
+ if (instruction_space == ast::StateSpace::Generic && operand_space.coerces_to_generic())
+ || (operand_space == ast::StateSpace::Generic && instruction_space.coerces_to_generic())
+ {
+ Ok(Some(ConversionKind::PtrToPtr))
+ } else if operand_space.is_compatible(ast::StateSpace::Reg) {
+ match operand_type {
+ ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
+ if *operand_ptr_space == instruction_space =>
+ {
+ if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
} else {
- if ss == Some(ast::LdStateSpace::Generic)
- || *instr_space == ast::LdStateSpace::Generic
- {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- } else {
- Err(TranslateError::MismatchedType)
- }
+ Ok(None)
}
- } else {
- Err(TranslateError::MismatchedType)
- }
- }
- ast::Type::Scalar(ast::ScalarType::B64)
- | ast::Type::Scalar(ast::ScalarType::U64)
- | ast::Type::Scalar(ast::ScalarType::S64) => {
- if let Some(space) = ss {
- Ok(Some(ConversionKind::BitToPtr(space)))
- } else {
- Err(error_unreachable())
- }
- }
- ast::Type::Scalar(ast::ScalarType::B32)
- | ast::Type::Scalar(ast::ScalarType::U32)
- | ast::Type::Scalar(ast::ScalarType::S32) => match ss {
- Some(ast::LdStateSpace::Shared)
- | Some(ast::LdStateSpace::Generic)
- | Some(ast::LdStateSpace::Param)
- | Some(ast::LdStateSpace::Local) => {
- Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared)))
}
+ // TODO: 32 bit
+ ast::Type::Scalar(ast::ScalarType::B64)
+ | ast::Type::Scalar(ast::ScalarType::U64)
+ | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
+ ast::StateSpace::Global
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Const
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
+ _ => Err(TranslateError::MismatchedType),
+ },
+ ast::Type::Scalar(ast::ScalarType::B32)
+ | ast::Type::Scalar(ast::ScalarType::U32)
+ | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
+ ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
+ Ok(Some(ConversionKind::BitToPtr))
+ }
+ _ => Err(TranslateError::MismatchedType),
+ },
_ => Err(TranslateError::MismatchedType),
- },
- ast::Type::Pointer(op_scalar_t, op_space) => {
- if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
- if op_space == instr_space {
- if op_scalar_t == instr_scalar_t {
- Ok(None)
- } else {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- }
+ }
+ } else if instruction_space.is_compatible(ast::StateSpace::Reg) {
+ match instruction_type {
+ ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
+ if operand_space == *instruction_ptr_space =>
+ {
+ if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
} else {
- if *op_space == ast::LdStateSpace::Generic
- || *instr_space == ast::LdStateSpace::Generic
- {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- } else {
- Err(TranslateError::MismatchedType)
- }
+ Ok(None)
}
- } else {
- Err(TranslateError::MismatchedType)
}
+ _ => Err(TranslateError::MismatchedType),
}
- _ => Err(TranslateError::MismatchedType),
+ } else {
+ Err(TranslateError::MismatchedType)
}
}
-fn force_bitcast_ptr_to_bit(
- _: &ast::Type,
- instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
+// Space is same, but type is different
+fn default_implicit_conversion_type(
+ space: ast::StateSpace,
+ operand_type: &ast::Type,
+ instruction_type: &ast::Type,
) -> Result<Option<ConversionKind>, TranslateError> {
- // TODO: verify this on f32, u16 and the like
- if let ast::Type::Scalar(scalar_t) = instr_type {
- if let Ok(int_type) = (*scalar_t).try_into() {
- return Ok(Some(ConversionKind::PtrToBit(int_type)));
+ if space.is_compatible(ast::StateSpace::Reg) {
+ if should_bitcast(instruction_type, operand_type) {
+ Ok(Some(ConversionKind::Default))
+ } else {
+ Err(TranslateError::MismatchedType)
}
+ } else {
+ Ok(Some(ConversionKind::PtrToPtr))
}
- Err(TranslateError::MismatchedType)
}
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
@@ -7570,16 +7369,18 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { return false;
}
match inst.kind() {
- ScalarKind::Bit => operand.kind() != ScalarKind::Bit,
- ScalarKind::Float => operand.kind() == ScalarKind::Bit,
- ScalarKind::Signed => {
- operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned
+ ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
+ ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
+ ast::ScalarKind::Signed => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Unsigned
}
- ScalarKind::Unsigned => {
- operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed
+ ast::ScalarKind::Unsigned => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Signed
}
- ScalarKind::Float2 => false,
- ScalarKind::Pred => false,
+ ast::ScalarKind::Float2 => false,
+ ast::ScalarKind::Pred => false,
}
}
(ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
@@ -7590,47 +7391,45 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { }
}
-fn should_bitcast_packed(
- operand: &ast::Type,
- instr: &ast::Type,
- ss: Option<ast::LdStateSpace>,
+fn implicit_conversion_mov(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
- (operand, instr)
- {
- if scalar.kind() == ScalarKind::Bit
- && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
+ // instruction_space is always reg
+ if operand_space.is_compatible(ast::StateSpace::Reg) {
+ if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
+ (operand_type, instruction_type)
{
- return Ok(Some(ConversionKind::Default));
- }
- }
- should_bitcast_wrapper(operand, instr, ss)
-}
-
-fn should_bitcast_wrapper(
- operand: &ast::Type,
- instr: &ast::Type,
- _: Option<ast::LdStateSpace>,
-) -> Result<Option<ConversionKind>, TranslateError> {
- if instr == operand {
- return Ok(None);
- }
- if should_bitcast(instr, operand) {
- Ok(Some(ConversionKind::Default))
- } else {
- Err(TranslateError::MismatchedType)
- }
+ if scalar.kind() == ast::ScalarKind::Bit
+ && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
+ {
+ return Ok(Some(ConversionKind::Default));
+ }
+ }
+ // TODO: verify .params addressability:
+ // * kernel arg
+ // * func arg
+ // * variable
+ } else if operand_space.is_addressable() {
+ return Ok(Some(ConversionKind::AddressOf));
+ }
+ default_implicit_conversion(
+ (operand_space, operand_type),
+ (instruction_space, instruction_type),
+ )
}
fn should_convert_relaxed_src_wrapper(
- src_type: &ast::Type,
- instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- if src_type == instr_type {
+ if !operand_space.is_compatible(instruction_space) {
+ return Err(TranslateError::MismatchedType);
+ }
+ if operand_type == instruction_type {
return Ok(None);
}
- match should_convert_relaxed_src(src_type, instr_type) {
+ match should_convert_relaxed_src(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
}
@@ -7646,32 +7445,33 @@ fn should_convert_relaxed_src( }
match (src_type, instr_type) {
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
- ScalarKind::Bit => {
+ ast::ScalarKind::Bit => {
if instr_type.size_of() <= src_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Signed | ScalarKind::Unsigned => {
+ ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= src_type.size_of()
- && src_type.kind() != ScalarKind::Float
+ && src_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float => {
- if instr_type.size_of() <= src_type.size_of() && src_type.kind() == ScalarKind::Bit
+ ast::ScalarKind::Float => {
+ if instr_type.size_of() <= src_type.size_of()
+ && src_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float2 => todo!(),
- ScalarKind::Pred => None,
+ ast::ScalarKind::Float2 => todo!(),
+ ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
@@ -7685,14 +7485,16 @@ fn should_convert_relaxed_src( }
fn should_convert_relaxed_dst_wrapper(
- dst_type: &ast::Type,
- instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- if dst_type == instr_type {
+ if !operand_space.is_compatible(instruction_space) {
+ return Err(TranslateError::MismatchedType);
+ }
+ if operand_type == instruction_type {
return Ok(None);
}
- match should_convert_relaxed_dst(dst_type, instr_type) {
+ match should_convert_relaxed_dst(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
}
@@ -7708,15 +7510,15 @@ fn should_convert_relaxed_dst( }
match (dst_type, instr_type) {
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
- ScalarKind::Bit => {
+ ast::ScalarKind::Bit => {
if instr_type.size_of() <= dst_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Signed => {
- if dst_type.kind() != ScalarKind::Float {
+ ast::ScalarKind::Signed => {
+ if dst_type.kind() != ast::ScalarKind::Float {
if instr_type.size_of() == dst_type.size_of() {
Some(ConversionKind::Default)
} else if instr_type.size_of() < dst_type.size_of() {
@@ -7728,25 +7530,26 @@ fn should_convert_relaxed_dst( None
}
}
- ScalarKind::Unsigned => {
+ ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= dst_type.size_of()
- && dst_type.kind() != ScalarKind::Float
+ && dst_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float => {
- if instr_type.size_of() <= dst_type.size_of() && dst_type.kind() == ScalarKind::Bit
+ ast::ScalarKind::Float => {
+ if instr_type.size_of() <= dst_type.size_of()
+ && dst_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float2 => todo!(),
- ScalarKind::Pred => None,
+ ast::ScalarKind::Float2 => todo!(),
+ ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
@@ -7759,77 +7562,46 @@ fn should_convert_relaxed_dst( }
}
-impl<'a> ast::MethodDecl<'a, &'a str> {
+impl<'a> ast::MethodDeclaration<'a, &'a str> {
fn name(&self) -> &'a str {
- match self {
- ast::MethodDecl::Kernel { name, .. } => name,
- ast::MethodDecl::Func(_, name, _) => name,
+ match self.name {
+ ast::MethodName::Kernel(name) => name,
+ ast::MethodName::Func(name) => name,
}
}
}
-struct SpirvMethodDecl<'input> {
- input: Vec<ast::Variable<ast::Type, spirv::Word>>,
- output: Vec<ast::Variable<ast::Type, spirv::Word>>,
- name: MethodName<'input>,
- uses_shared_mem: bool,
+impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
+ fn effective_input_arguments(&self) -> impl Iterator<Item = (spirv::Word, SpirvType)> + '_ {
+ let is_kernel = self.name.is_kernel();
+ self.input_arguments
+ .iter()
+ .map(move |arg| {
+ if !is_kernel && arg.state_space != ast::StateSpace::Reg {
+ let spirv_type =
+ SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
+ (arg.name, spirv_type)
+ } else {
+ (arg.name, SpirvType::new(arg.v_type.clone()))
+ }
+ })
+ .chain(self.shared_mem.iter().map(|id| {
+ (
+ *id,
+ SpirvType::Pointer(
+ Box::new(SpirvType::Base(SpirvScalarKey::B8)),
+ spirv::StorageClass::Workgroup,
+ ),
+ )
+ }))
+ }
}
-impl<'input> SpirvMethodDecl<'input> {
- fn new(ast_decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
- let (input, output) = match ast_decl {
- ast::MethodDecl::Kernel { in_args, .. } => {
- let spirv_input = in_args
- .iter()
- .map(|var| {
- let v_type = match &var.v_type {
- ast::KernelArgumentType::Normal(t) => {
- ast::FnArgumentType::Param(t.clone())
- }
- ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared,
- };
- ast::Variable {
- name: var.name,
- align: var.align,
- v_type: v_type.to_kernel_type(),
- array_init: var.array_init.clone(),
- }
- })
- .collect();
- (spirv_input, Vec::new())
- }
- ast::MethodDecl::Func(out_args, _, in_args) => {
- let (param_output, non_param_output): (Vec<_>, Vec<_>) =
- out_args.iter().partition(|var| var.v_type.is_param());
- let spirv_output = non_param_output
- .into_iter()
- .cloned()
- .map(|var| ast::Variable {
- name: var.name,
- align: var.align,
- v_type: var.v_type.to_func_type(),
- array_init: var.array_init.clone(),
- })
- .collect();
- let spirv_input = param_output
- .into_iter()
- .cloned()
- .chain(in_args.iter().cloned())
- .map(|var| ast::Variable {
- name: var.name,
- align: var.align,
- v_type: var.v_type.to_func_type(),
- array_init: var.array_init.clone(),
- })
- .collect();
- (spirv_input, spirv_output)
- }
- };
- SpirvMethodDecl {
- input,
- output,
- name: MethodName::new(ast_decl),
- uses_shared_mem: false,
+impl<'input, ID> ast::MethodName<'input, ID> {
+ fn is_kernel(&self) -> bool {
+ match self {
+ ast::MethodName::Kernel(..) => true,
+ ast::MethodName::Func(..) => false,
}
}
}
|