aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src')
-rw-r--r--ptx/src/ast.rs28
-rw-r--r--ptx/src/lib.rs1
-rw-r--r--ptx/src/pass/convert_dynamic_shared_memory_usage.rs299
-rw-r--r--ptx/src/pass/convert_to_stateful_memory_access.rs524
-rw-r--r--ptx/src/pass/convert_to_typed.rs138
-rw-r--r--ptx/src/pass/emit_spirv.rs2763
-rw-r--r--ptx/src/pass/expand_arguments.rs181
-rw-r--r--ptx/src/pass/extract_globals.rs282
-rw-r--r--ptx/src/pass/fix_special_registers.rs130
-rw-r--r--ptx/src/pass/insert_implicit_conversions.rs432
-rw-r--r--ptx/src/pass/insert_mem_ssa_statements.rs275
-rw-r--r--ptx/src/pass/mod.rs1677
-rw-r--r--ptx/src/pass/normalize_identifiers.rs80
-rw-r--r--ptx/src/pass/normalize_labels.rs48
-rw-r--r--ptx/src/pass/normalize_predicates.rs44
-rw-r--r--ptx/src/test/spirv_run/clz.spvtxt19
-rw-r--r--ptx/src/test/spirv_run/cvt_s16_s8.spvtxt7
-rw-r--r--ptx/src/test/spirv_run/cvt_s64_s32.spvtxt8
-rw-r--r--ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt6
-rw-r--r--ptx/src/test/spirv_run/mod.rs7
-rw-r--r--ptx/src/test/spirv_run/popc.spvtxt19
-rw-r--r--ptx/src/test/spirv_run/vector.ptx2
-rw-r--r--ptx/src/translate.rs20
23 files changed, 6937 insertions, 53 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index d308479..358b8ce 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -16,6 +16,8 @@ pub enum PtxError {
source: ParseFloatError,
},
#[error("")]
+ Unsupported32Bit,
+ #[error("")]
SyntaxError,
#[error("")]
NonF32Ftz,
@@ -32,15 +34,9 @@ pub enum PtxError {
#[error("")]
NonExternPointer,
#[error("{start}:{end}")]
- UnrecognizedStatement {
- start: usize,
- end: usize,
- },
+ UnrecognizedStatement { start: usize, end: usize },
#[error("{start}:{end}")]
- UnrecognizedDirective {
- start: usize,
- end: usize,
- },
+ UnrecognizedDirective { start: usize, end: usize },
}
// For some weird reson this is illegal:
@@ -576,11 +572,15 @@ impl CvtDetails {
if saturate {
if src.kind() == ScalarKind::Signed {
if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() {
- err.push(ParseError::from(PtxError::SyntaxError));
+ err.push(ParseError::User {
+ error: PtxError::SyntaxError,
+ });
}
} else {
if dst == src || dst.size_of() >= src.size_of() {
- err.push(ParseError::from(PtxError::SyntaxError));
+ err.push(ParseError::User {
+ error: PtxError::SyntaxError,
+ });
}
}
}
@@ -596,7 +596,9 @@ impl CvtDetails {
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self {
if flush_to_zero && dst != ScalarType::F32 {
- err.push(ParseError::from(PtxError::NonF32Ftz));
+ err.push(ParseError::from(lalrpop_util::ParseError::User {
+ error: PtxError::NonF32Ftz,
+ }));
}
CvtDetails::FloatFromInt(CvtDesc {
dst,
@@ -616,7 +618,9 @@ impl CvtDetails {
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self {
if flush_to_zero && src != ScalarType::F32 {
- err.push(ParseError::from(PtxError::NonF32Ftz));
+ err.push(ParseError::from(lalrpop_util::ParseError::User {
+ error: PtxError::NonF32Ftz,
+ }));
}
CvtDetails::IntFromFloat(CvtDesc {
dst,
diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs
index 1cb9630..5e95dae 100644
--- a/ptx/src/lib.rs
+++ b/ptx/src/lib.rs
@@ -24,6 +24,7 @@ lalrpop_mod!(
);
pub mod ast;
+pub(crate) mod pass;
#[cfg(test)]
mod test;
mod translate;
diff --git a/ptx/src/pass/convert_dynamic_shared_memory_usage.rs b/ptx/src/pass/convert_dynamic_shared_memory_usage.rs
new file mode 100644
index 0000000..1dac7fd
--- /dev/null
+++ b/ptx/src/pass/convert_dynamic_shared_memory_usage.rs
@@ -0,0 +1,299 @@
+use std::collections::{BTreeMap, BTreeSet};
+
+use super::*;
+
+/*
+ PTX represents dynamically allocated shared local memory as
+ .extern .shared .b32 shared_mem[];
+ In SPIRV/OpenCL world this is expressed as an additional argument to the kernel
+ And in AMD compilation
+ This pass looks for all uses of .extern .shared and converts them to
+ an additional method argument
+ The question is how this artificial argument should be expressed. There are
+ several options:
+ * Straight conversion:
+ .shared .b32 shared_mem[]
+ * Introduce .param_shared statespace:
+ .param_shared .b32 shared_mem
+ or
+ .param_shared .b32 shared_mem[]
+ * Introduce .shared_ptr <SCALAR> type:
+ .param .shared_ptr .b32 shared_mem
+ * Reuse .ptr hint:
+ .param .u64 .ptr shared_mem
+ This is the most tempting, but also the most nonsensical, .ptr is just a
+ hint, which has no semantical meaning (and the output of our
+ transformation has a semantical meaning - we emit additional
+ "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
+*/
+pub(super) fn run<'input>(
+ module: Vec<Directive<'input>>,
+ kernels_methods_call_map: &MethodsCallMap<'input>,
+ new_id: &mut impl FnMut() -> SpirvWord,
+) -> Result<Vec<Directive<'input>>, TranslateError> {
+ let mut globals_shared = HashMap::new();
+ for dir in module.iter() {
+ match dir {
+ Directive::Variable(
+ _,
+ ast::Variable {
+ state_space: ast::StateSpace::Shared,
+ name,
+ v_type,
+ ..
+ },
+ ) => {
+ globals_shared.insert(*name, v_type.clone());
+ }
+ _ => {}
+ }
+ }
+ if globals_shared.len() == 0 {
+ return Ok(module);
+ }
+ let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet<SpirvWord>>::new();
+ let module = module
+ .into_iter()
+ .map(|directive| match directive {
+ Directive::Method(Function {
+ func_decl,
+ globals,
+ body: Some(statements),
+ import_as,
+ tuning,
+ linkage,
+ }) => {
+ let call_key = (*func_decl).borrow().name;
+ let statements = statements
+ .into_iter()
+ .map(|statement| {
+ statement.visit_map(
+ &mut |id, _: Option<(&ast::Type, ast::StateSpace)>, _, _| {
+ if let Some(_) = globals_shared.get(&id) {
+ methods_to_directly_used_shared_globals
+ .entry(call_key)
+ .or_insert_with(HashSet::new)
+ .insert(id);
+ }
+ Ok::<_, TranslateError>(id)
+ },
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok::<_, TranslateError>(Directive::Method(Function {
+ func_decl,
+ globals,
+ body: Some(statements),
+ import_as,
+ tuning,
+ linkage,
+ }))
+ }
+ directive => Ok(directive),
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
+ // make sure it gets propagated to `fn1` and `kernel`
+ let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared(
+ methods_to_directly_used_shared_globals,
+ kernels_methods_call_map,
+ );
+ // now visit every method declaration and inject those additional arguments
+ let mut directives = Vec::with_capacity(module.len());
+ for directive in module.into_iter() {
+ match directive {
+ Directive::Method(Function {
+ func_decl,
+ globals,
+ body: Some(statements),
+ import_as,
+ tuning,
+ linkage,
+ }) => {
+ let statements = {
+ let func_decl_ref = &mut (*func_decl).borrow_mut();
+ let method_name = func_decl_ref.name;
+ insert_arguments_remap_statements(
+ new_id,
+ kernels_methods_call_map,
+ &globals_shared,
+ &methods_to_indirectly_used_shared_globals,
+ method_name,
+ &mut directives,
+ func_decl_ref,
+ statements,
+ )?
+ };
+ directives.push(Directive::Method(Function {
+ func_decl,
+ globals,
+ body: Some(statements),
+ import_as,
+ tuning,
+ linkage,
+ }));
+ }
+ directive => directives.push(directive),
+ }
+ }
+ Ok(directives)
+}
+
+// We need to compute two kinds of information:
+// * If it's a kernel -> size of .shared globals in use (direct or indirect)
+// * If it's a function -> does it use .shared global (directly or indirectly)
+fn resolve_indirect_uses_of_globals_shared<'input>(
+ methods_use_of_globals_shared: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
+ kernels_methods_call_map: &MethodsCallMap<'input>,
+) -> HashMap<ast::MethodName<'input, SpirvWord>, BTreeSet<SpirvWord>> {
+ let mut result = HashMap::new();
+ for (method, callees) in kernels_methods_call_map.methods() {
+ let mut indirect_globals = methods_use_of_globals_shared
+ .get(&method)
+ .into_iter()
+ .flatten()
+ .copied()
+ .collect::<BTreeSet<_>>();
+ for &callee in callees {
+ indirect_globals.extend(
+ methods_use_of_globals_shared
+ .get(&ast::MethodName::Func(callee))
+ .into_iter()
+ .flatten()
+ .copied(),
+ );
+ }
+ result.insert(method, indirect_globals);
+ }
+ result
+}
+
+fn insert_arguments_remap_statements<'input>(
+ new_id: &mut impl FnMut() -> SpirvWord,
+ kernels_methods_call_map: &MethodsCallMap<'input>,
+ globals_shared: &HashMap<SpirvWord, ast::Type>,
+ methods_to_indirectly_used_shared_globals: &HashMap<
+ ast::MethodName<'input, SpirvWord>,
+ BTreeSet<SpirvWord>,
+ >,
+ method_name: ast::MethodName<SpirvWord>,
+ result: &mut Vec<Directive>,
+ func_decl_ref: &mut std::cell::RefMut<ast::MethodDeclaration<SpirvWord>>,
+ statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
+) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+ let remapped_globals_in_method =
+ if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) {
+ match method_name {
+ ast::MethodName::Func(..) => {
+ let remapped_globals = method_globals
+ .iter()
+ .map(|global| {
+ (
+ *global,
+ (
+ new_id(),
+ globals_shared
+ .get(&global)
+ .unwrap_or_else(|| todo!())
+ .clone(),
+ ),
+ )
+ })
+ .collect::<BTreeMap<_, _>>();
+ for (_, (new_shared_global_id, shared_global_type)) in remapped_globals.iter() {
+ func_decl_ref.input_arguments.push(ast::Variable {
+ align: None,
+ v_type: shared_global_type.clone(),
+ state_space: ast::StateSpace::Shared,
+ name: *new_shared_global_id,
+ array_init: Vec::new(),
+ });
+ }
+ remapped_globals
+ }
+ ast::MethodName::Kernel(..) => method_globals
+ .iter()
+ .map(|global| {
+ (
+ *global,
+ (
+ *global,
+ globals_shared
+ .get(&global)
+ .unwrap_or_else(|| todo!())
+ .clone(),
+ ),
+ )
+ })
+ .collect::<BTreeMap<_, _>>(),
+ }
+ } else {
+ return Ok(statements);
+ };
+ replace_uses_of_shared_memory(
+ new_id,
+ methods_to_indirectly_used_shared_globals,
+ statements,
+ remapped_globals_in_method,
+ )
+}
+
+fn replace_uses_of_shared_memory<'input>(
+ new_id: &mut impl FnMut() -> SpirvWord,
+ methods_to_indirectly_used_shared_globals: &HashMap<
+ ast::MethodName<'input, SpirvWord>,
+ BTreeSet<SpirvWord>,
+ >,
+ statements: Vec<ExpandedStatement>,
+ remapped_globals_in_method: BTreeMap<SpirvWord, (SpirvWord, ast::Type)>,
+) -> Result<Vec<ExpandedStatement>, TranslateError> {
+ let mut result = Vec::with_capacity(statements.len());
+ for statement in statements {
+ match statement {
+ Statement::Instruction(ast::Instruction::Call {
+ mut data,
+ mut arguments,
+ }) => {
+ // We can safely skip checking call arguments,
+ // because there's simply no way to pass shared ptr
+ // without converting it to .b64 first
+ if let Some(shared_globals_used_by_callee) =
+ methods_to_indirectly_used_shared_globals
+ .get(&ast::MethodName::Func(arguments.func))
+ {
+ for &shared_global_used_by_callee in shared_globals_used_by_callee {
+ let (remapped_shared_id, type_) = remapped_globals_in_method
+ .get(&shared_global_used_by_callee)
+ .unwrap_or_else(|| todo!());
+ data.input_arguments
+ .push((type_.clone(), ast::StateSpace::Shared));
+ arguments.input_arguments.push(*remapped_shared_id);
+ }
+ }
+ result.push(Statement::Instruction(ast::Instruction::Call {
+ data,
+ arguments,
+ }))
+ }
+ statement => {
+ let new_statement =
+ statement.visit_map(&mut |id,
+ _: Option<(&ast::Type, ast::StateSpace)>,
+ _,
+ _| {
+ Ok::<_, TranslateError>(
+ if let Some((remapped_shared_id, _)) =
+ remapped_globals_in_method.get(&id)
+ {
+ *remapped_shared_id
+ } else {
+ id
+ },
+ )
+ })?;
+ result.push(new_statement);
+ }
+ }
+ }
+ Ok(result)
+}
diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs b/ptx/src/pass/convert_to_stateful_memory_access.rs
new file mode 100644
index 0000000..455a8c2
--- /dev/null
+++ b/ptx/src/pass/convert_to_stateful_memory_access.rs
@@ -0,0 +1,524 @@
+use super::*;
+use ptx_parser as ast;
+use std::{
+ collections::{BTreeSet, HashSet},
+ iter,
+ rc::Rc,
+};
+
+/*
+ Our goal here is to transform
+ .visible .entry foobar(.param .u64 input) {
+ .reg .b64 in_addr;
+ .reg .b64 in_addr2;
+ ld.param.u64 in_addr, [input];
+ cvta.to.global.u64 in_addr2, in_addr;
+ }
+ into:
+ .visible .entry foobar(.param .u8 input[]) {
+ .reg .u8 in_addr[];
+ .reg .u8 in_addr2[];
+ ld.param.u8[] in_addr, [input];
+ mov.u8[] in_addr2, in_addr;
+ }
+ or:
+ .visible .entry foobar(.reg .u8 input[]) {
+ .reg .u8 in_addr[];
+ .reg .u8 in_addr2[];
+ mov.u8[] in_addr, input;
+ mov.u8[] in_addr2, in_addr;
+ }
+ or:
+ .visible .entry foobar(.param ptr<u8, global> input) {
+ .reg ptr<u8, global> in_addr;
+ .reg ptr<u8, global> in_addr2;
+ ld.param.ptr<u8, global> in_addr, [input];
+ mov.ptr<u8, global> in_addr2, in_addr;
+ }
+*/
+// TODO: detect more patterns (mov, call via reg, call via param)
+// TODO: don't convert to ptr if the register is not ultimately used for ld/st
+// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
+// argument expansion
+// TODO: propagate out of calls and into calls
+pub(super) fn run<'a, 'input>(
+ func_args: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
+ func_body: Vec<TypedStatement>,
+ id_defs: &mut NumericIdResolver<'a>,
+) -> Result<
+ (
+ Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
+ Vec<TypedStatement>,
+ ),
+ TranslateError,
+> {
+ let mut method_decl = func_args.borrow_mut();
+ if !matches!(method_decl.name, ast::MethodName::Kernel(..)) {
+ drop(method_decl);
+ return Ok((func_args, func_body));
+ }
+ if Rc::strong_count(&func_args) != 1 {
+ return Err(error_unreachable());
+ }
+ let func_args_64bit = (*method_decl)
+ .input_arguments
+ .iter()
+ .filter_map(|arg| match arg.v_type {
+ ast::Type::Scalar(ast::ScalarType::U64)
+ | ast::Type::Scalar(ast::ScalarType::B64)
+ | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name),
+ _ => None,
+ })
+ .collect::<HashSet<_>>();
+ let mut stateful_markers = Vec::new();
+ let mut stateful_init_reg = HashMap::<_, Vec<_>>::new();
+ for statement in func_body.iter() {
+ match statement {
+ Statement::Instruction(ast::Instruction::Cvta {
+ data:
+ ast::CvtaDetails {
+ state_space: ast::StateSpace::Global,
+ direction: ast::CvtaDirection::GenericToExplicit,
+ },
+ arguments,
+ }) => {
+ if let (TypedOperand::Reg(dst), Some(src)) =
+ (arguments.dst, arguments.src.underlying_register())
+ {
+ if is_64_bit_integer(id_defs, src) && is_64_bit_integer(id_defs, dst) {
+ stateful_markers.push((dst, src));
+ }
+ }
+ }
+ Statement::Instruction(ast::Instruction::Ld {
+ data:
+ ast::LdDetails {
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::U64),
+ ..
+ },
+ arguments,
+ })
+ | Statement::Instruction(ast::Instruction::Ld {
+ data:
+ ast::LdDetails {
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::S64),
+ ..
+ },
+ arguments,
+ })
+ | Statement::Instruction(ast::Instruction::Ld {
+ data:
+ ast::LdDetails {
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::B64),
+ ..
+ },
+ arguments,
+ }) => {
+ if let (TypedOperand::Reg(dst), Some(src)) =
+ (arguments.dst, arguments.src.underlying_register())
+ {
+ if func_args_64bit.contains(&src) {
+ multi_hash_map_append(&mut stateful_init_reg, dst, src);
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ if stateful_markers.len() == 0 {
+ drop(method_decl);
+ return Ok((func_args, func_body));
+ }
+ let mut func_args_ptr = HashSet::new();
+ let mut regs_ptr_current = HashSet::new();
+ for (dst, src) in stateful_markers {
+ if let Some(func_args) = stateful_init_reg.get(&src) {
+ for a in func_args {
+ func_args_ptr.insert(*a);
+ regs_ptr_current.insert(src);
+ regs_ptr_current.insert(dst);
+ }
+ }
+ }
+ // BTreeSet here to have a stable order of iteration,
+ // unfortunately our tests rely on it
+ let mut regs_ptr_seen = BTreeSet::new();
+ while regs_ptr_current.len() > 0 {
+ let mut regs_ptr_new = HashSet::new();
+ for statement in func_body.iter() {
+ match statement {
+ Statement::Instruction(ast::Instruction::Add {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::U64,
+ saturate: false,
+ }),
+ arguments,
+ })
+ | Statement::Instruction(ast::Instruction::Add {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::S64,
+ saturate: false,
+ }),
+ arguments,
+ }) => {
+ // TODO: don't mark result of double pointer sub or double
+ // pointer add as ptr result
+ if let (TypedOperand::Reg(dst), Some(src1)) =
+ (arguments.dst, arguments.src1.underlying_register())
+ {
+ if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
+ regs_ptr_new.insert(dst);
+ }
+ } else if let (TypedOperand::Reg(dst), Some(src2)) =
+ (arguments.dst, arguments.src2.underlying_register())
+ {
+ if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
+ regs_ptr_new.insert(dst);
+ }
+ }
+ }
+
+ Statement::Instruction(ast::Instruction::Sub {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::U64,
+ saturate: false,
+ }),
+ arguments,
+ })
+ | Statement::Instruction(ast::Instruction::Sub {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::S64,
+ saturate: false,
+ }),
+ arguments,
+ }) => {
+ // TODO: don't mark result of double pointer sub or double
+ // pointer add as ptr result
+ if let (TypedOperand::Reg(dst), Some(src1)) =
+ (arguments.dst, arguments.src1.underlying_register())
+ {
+ if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
+ regs_ptr_new.insert(dst);
+ }
+ } else if let (TypedOperand::Reg(dst), Some(src2)) =
+ (arguments.dst, arguments.src2.underlying_register())
+ {
+ if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
+ regs_ptr_new.insert(dst);
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ for id in regs_ptr_current {
+ regs_ptr_seen.insert(id);
+ }
+ regs_ptr_current = regs_ptr_new;
+ }
+ drop(regs_ptr_current);
+ let mut remapped_ids = HashMap::new();
+ let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
+ for reg in regs_ptr_seen {
+ let new_id = id_defs.register_variable(
+ ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ ast::StateSpace::Reg,
+ );
+ result.push(Statement::Variable(ast::Variable {
+ align: None,
+ name: new_id,
+ array_init: Vec::new(),
+ v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ state_space: ast::StateSpace::Reg,
+ }));
+ remapped_ids.insert(reg, new_id);
+ }
+ for arg in (*method_decl).input_arguments.iter_mut() {
+ if !func_args_ptr.contains(&arg.name) {
+ continue;
+ }
+ let new_id = id_defs.register_variable(
+ ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ ast::StateSpace::Param,
+ );
+ let old_name = arg.name;
+ arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
+ arg.name = new_id;
+ remapped_ids.insert(old_name, new_id);
+ }
+ for statement in func_body {
+ match statement {
+ l @ Statement::Label(_) => result.push(l),
+ c @ Statement::Conditional(_) => result.push(c),
+ c @ Statement::Constant(..) => result.push(c),
+ Statement::Variable(var) => {
+ if !remapped_ids.contains_key(&var.name) {
+ result.push(Statement::Variable(var));
+ }
+ }
+ Statement::Instruction(ast::Instruction::Add {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::U64,
+ saturate: false,
+ }),
+ arguments,
+ })
+ | Statement::Instruction(ast::Instruction::Add {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::S64,
+ saturate: false,
+ }),
+ arguments,
+ }) if is_add_ptr_direct(&remapped_ids, &arguments) => {
+ let (ptr, offset) = match arguments.src1.underlying_register() {
+ Some(src1) if remapped_ids.contains_key(&src1) => {
+ (remapped_ids.get(&src1).unwrap(), arguments.src2)
+ }
+ Some(src2) if remapped_ids.contains_key(&src2) => {
+ (remapped_ids.get(&src2).unwrap(), arguments.src1)
+ }
+ _ => return Err(error_unreachable()),
+ };
+ let dst = arguments.dst.unwrap_reg()?;
+ result.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
+ state_space: ast::StateSpace::Global,
+ dst: *remapped_ids.get(&dst).unwrap(),
+ ptr_src: *ptr,
+ offset_src: offset,
+ }))
+ }
+ Statement::Instruction(ast::Instruction::Sub {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::U64,
+ saturate: false,
+ }),
+ arguments,
+ })
+ | Statement::Instruction(ast::Instruction::Sub {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::S64,
+ saturate: false,
+ }),
+ arguments,
+ }) if is_sub_ptr_direct(&remapped_ids, &arguments) => {
+ let (ptr, offset) = match arguments.src1.underlying_register() {
+ Some(ref src1) => (remapped_ids.get(src1).unwrap(), arguments.src2),
+ _ => return Err(error_unreachable()),
+ };
+ let offset_neg = id_defs.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
+ )));
+ result.push(Statement::Instruction(ast::Instruction::Neg {
+ data: ast::TypeFtz {
+ type_: ast::ScalarType::S64,
+ flush_to_zero: None,
+ },
+ arguments: ast::NegArgs {
+ src: offset,
+ dst: TypedOperand::Reg(offset_neg),
+ },
+ }));
+ let dst = arguments.dst.unwrap_reg()?;
+ result.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
+ state_space: ast::StateSpace::Global,
+ dst: *remapped_ids.get(&dst).unwrap(),
+ ptr_src: *ptr,
+ offset_src: TypedOperand::Reg(offset_neg),
+ }))
+ }
+ inst @ Statement::Instruction(_) => {
+ let mut post_statements = Vec::new();
+ let new_statement = inst.visit_map(&mut FnVisitor::new(
+ |operand, type_space, is_dst, relaxed_conversion| {
+ convert_to_stateful_memory_access_postprocess(
+ id_defs,
+ &remapped_ids,
+ &mut result,
+ &mut post_statements,
+ operand,
+ type_space,
+ is_dst,
+ relaxed_conversion,
+ )
+ },
+ ))?;
+ result.push(new_statement);
+ result.extend(post_statements);
+ }
+ repack @ Statement::RepackVector(_) => {
+ let mut post_statements = Vec::new();
+ let new_statement = repack.visit_map(&mut FnVisitor::new(
+ |operand, type_space, is_dst, relaxed_conversion| {
+ convert_to_stateful_memory_access_postprocess(
+ id_defs,
+ &remapped_ids,
+ &mut result,
+ &mut post_statements,
+ operand,
+ type_space,
+ is_dst,
+ relaxed_conversion,
+ )
+ },
+ ))?;
+ result.push(new_statement);
+ result.extend(post_statements);
+ }
+ _ => return Err(error_unreachable()),
+ }
+ }
+ drop(method_decl);
+ Ok((func_args, result))
+}
+
+fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool {
+ match id_defs.get_typed(id) {
+ Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true,
+ _ => false,
+ }
+}
+
+fn is_add_ptr_direct(
+ remapped_ids: &HashMap<SpirvWord, SpirvWord>,
+ arg: &ast::AddArgs<TypedOperand>,
+) -> bool {
+ match arg.dst {
+ TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
+ return false
+ }
+ TypedOperand::Reg(dst) => {
+ if !remapped_ids.contains_key(&dst) {
+ return false;
+ }
+ if let Some(ref src1_reg) = arg.src1.underlying_register() {
+ if remapped_ids.contains_key(src1_reg) {
+ // don't trigger optimization when adding two pointers
+ if let Some(ref src2_reg) = arg.src2.underlying_register() {
+ return !remapped_ids.contains_key(src2_reg);
+ }
+ }
+ }
+ if let Some(ref src2_reg) = arg.src2.underlying_register() {
+ remapped_ids.contains_key(src2_reg)
+ } else {
+ false
+ }
+ }
+ }
+}
+
+fn is_sub_ptr_direct(
+ remapped_ids: &HashMap<SpirvWord, SpirvWord>,
+ arg: &ast::SubArgs<TypedOperand>,
+) -> bool {
+ match arg.dst {
+ TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
+ return false
+ }
+ TypedOperand::Reg(dst) => {
+ if !remapped_ids.contains_key(&dst) {
+ return false;
+ }
+ match arg.src1.underlying_register() {
+ Some(ref src1_reg) => {
+ if remapped_ids.contains_key(src1_reg) {
+ // don't trigger optimization when subtracting two pointers
+ arg.src2
+ .underlying_register()
+ .map_or(true, |ref src2_reg| !remapped_ids.contains_key(src2_reg))
+ } else {
+ false
+ }
+ }
+ None => false,
+ }
+ }
+ }
+}
+
+fn convert_to_stateful_memory_access_postprocess(
+ id_defs: &mut NumericIdResolver,
+ remapped_ids: &HashMap<SpirvWord, SpirvWord>,
+ result: &mut Vec<TypedStatement>,
+ post_statements: &mut Vec<TypedStatement>,
+ operand: TypedOperand,
+ type_space: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst: bool,
+ relaxed_conversion: bool,
+) -> Result<TypedOperand, TranslateError> {
+ operand.map(|operand, _| {
+ Ok(match remapped_ids.get(&operand) {
+ Some(new_id) => {
+ let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?;
+ // TODO: readd if required
+ if let Some((expected_type, expected_space)) = type_space {
+ let implicit_conversion = if relaxed_conversion {
+ if is_dst {
+ super::insert_implicit_conversions::should_convert_relaxed_dst_wrapper
+ } else {
+ super::insert_implicit_conversions::should_convert_relaxed_src_wrapper
+ }
+ } else {
+ super::insert_implicit_conversions::default_implicit_conversion
+ };
+ if implicit_conversion(
+ (new_operand_space, &new_operand_type),
+ (expected_space, expected_type),
+ )
+ .is_ok()
+ {
+ return Ok(*new_id);
+ }
+ }
+ let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?;
+ let converting_id = id_defs
+ .register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
+ let kind = if space_is_compatible(new_operand_space, ast::StateSpace::Reg) {
+ ConversionKind::Default
+ } else {
+ ConversionKind::PtrToPtr
+ };
+ if is_dst {
+ post_statements.push(Statement::Conversion(ImplicitConversion {
+ src: converting_id,
+ dst: *new_id,
+ from_type: old_operand_type,
+ from_space: old_operand_space,
+ to_type: new_operand_type,
+ to_space: new_operand_space,
+ kind,
+ }));
+ converting_id
+ } else {
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: *new_id,
+ dst: converting_id,
+ from_type: new_operand_type,
+ from_space: new_operand_space,
+ to_type: old_operand_type,
+ to_space: old_operand_space,
+ kind,
+ }));
+ converting_id
+ }
+ }
+ None => operand,
+ })
+ })
+}
diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs
new file mode 100644
index 0000000..550c662
--- /dev/null
+++ b/ptx/src/pass/convert_to_typed.rs
@@ -0,0 +1,138 @@
+use super::*;
+use ptx_parser as ast;
+
+pub(crate) fn run(
+ func: Vec<UnconditionalStatement>,
+ fn_defs: &GlobalFnDeclResolver,
+ id_defs: &mut NumericIdResolver,
+) -> Result<Vec<TypedStatement>, TranslateError> {
+ let mut result = Vec::<TypedStatement>::with_capacity(func.len());
+ for s in func {
+ match s {
+ Statement::Instruction(inst) => match inst {
+ ast::Instruction::Mov {
+ data,
+ arguments:
+ ast::MovArgs {
+ dst: ast::ParsedOperand::Reg(dst_reg),
+ src: ast::ParsedOperand::Reg(src_reg),
+ },
+ } if fn_defs.fns.contains_key(&src_reg) => {
+ if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
+ return Err(error_mismatched_type());
+ }
+ result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
+ dst: dst_reg,
+ src: src_reg,
+ }));
+ }
+ ast::Instruction::Call { data, arguments } => {
+ let resolver = fn_defs.get_fn_sig_resolver(arguments.func)?;
+ let resolved_call = resolver.resolve_in_spirv_repr(data, arguments)?;
+ let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
+ let reresolved_call =
+ Statement::Instruction(ast::visit_map(resolved_call, &mut visitor)?);
+ visitor.func.push(reresolved_call);
+ visitor.func.extend(visitor.post_stmts);
+ }
+ inst => {
+ let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
+ let instruction = Statement::Instruction(ast::visit_map(inst, &mut visitor)?);
+ visitor.func.push(instruction);
+ visitor.func.extend(visitor.post_stmts);
+ }
+ },
+ Statement::Label(i) => result.push(Statement::Label(i)),
+ Statement::Variable(v) => result.push(Statement::Variable(v)),
+ Statement::Conditional(c) => result.push(Statement::Conditional(c)),
+ _ => return Err(error_unreachable()),
+ }
+ }
+ Ok(result)
+}
+
+struct VectorRepackVisitor<'a, 'b> {
+ func: &'b mut Vec<TypedStatement>,
+ id_def: &'b mut NumericIdResolver<'a>,
+ post_stmts: Option<TypedStatement>,
+}
+
+impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
+ fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
+ VectorRepackVisitor {
+ func,
+ id_def,
+ post_stmts: None,
+ }
+ }
+
+ fn convert_vector(
+ &mut self,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ typ: &ast::Type,
+ state_space: ast::StateSpace,
+ idx: Vec<SpirvWord>,
+ ) -> Result<SpirvWord, TranslateError> {
+ // mov.u32 foobar, {a,b};
+ let scalar_t = match typ {
+ ast::Type::Vector(_, scalar_t) => *scalar_t,
+ _ => return Err(error_mismatched_type()),
+ };
+ let temp_vec = self
+ .id_def
+ .register_intermediate(Some((typ.clone(), state_space)));
+ let statement = Statement::RepackVector(RepackVectorDetails {
+ is_extract: is_dst,
+ typ: scalar_t,
+ packed: temp_vec,
+ unpacked: idx,
+ relaxed_type_check,
+ });
+ if is_dst {
+ self.post_stmts = Some(statement);
+ } else {
+ self.func.push(statement);
+ }
+ Ok(temp_vec)
+ }
+}
+
+impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, TypedOperand, TranslateError>
+ for VectorRepackVisitor<'a, 'b>
+{
+ fn visit_ident(
+ &mut self,
+ ident: SpirvWord,
+ _: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+ _: bool,
+ _: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ Ok(ident)
+ }
+
+ fn visit(
+ &mut self,
+ op: ast::ParsedOperand<SpirvWord>,
+ type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<TypedOperand, TranslateError> {
+ Ok(match op {
+ ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg),
+ ast::ParsedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
+ ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x),
+ ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
+ ast::ParsedOperand::VecPack(vec) => {
+ let (type_, space) = type_space.ok_or_else(|| error_mismatched_type())?;
+ TypedOperand::Reg(self.convert_vector(
+ is_dst,
+ relaxed_type_check,
+ type_,
+ space,
+ vec,
+ )?)
+ }
+ })
+ }
+}
diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs
new file mode 100644
index 0000000..5147b79
--- /dev/null
+++ b/ptx/src/pass/emit_spirv.rs
@@ -0,0 +1,2763 @@
+use super::*;
+use half::f16;
+use ptx_parser as ast;
+use rspirv::{binary::Assemble, dr};
+use std::{
+ collections::{HashMap, HashSet},
+ ffi::CString,
+ mem,
+};
+
+pub(super) fn run<'input>(
+ mut builder: dr::Builder,
+ id_defs: &GlobalStringIdResolver<'input>,
+ call_map: MethodsCallMap<'input>,
+ denorm_information: HashMap<
+ ptx_parser::MethodName<SpirvWord>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
+ directives: Vec<Directive<'input>>,
+) -> Result<(dr::Module, HashMap<String, KernelInfo>, CString), TranslateError> {
+ builder.set_version(1, 3);
+ emit_capabilities(&mut builder);
+ emit_extensions(&mut builder);
+ let opencl_id = emit_opencl_import(&mut builder);
+ emit_memory_model(&mut builder);
+ let mut map = TypeWordMap::new(&mut builder);
+ //emit_builtins(&mut builder, &mut map, &id_defs);
+ let mut kernel_info = HashMap::new();
+ let (build_options, should_flush_denorms) =
+ emit_denorm_build_string(&call_map, &denorm_information);
+ let (directives, globals_use_map) = get_globals_use_map(directives);
+ emit_directives(
+ &mut builder,
+ &mut map,
+ &id_defs,
+ opencl_id,
+ should_flush_denorms,
+ &call_map,
+ globals_use_map,
+ directives,
+ &mut kernel_info,
+ )?;
+ Ok((builder.module(), kernel_info, build_options))
+}
+
+fn emit_capabilities(builder: &mut dr::Builder) {
+ builder.capability(spirv::Capability::GenericPointer);
+ builder.capability(spirv::Capability::Linkage);
+ builder.capability(spirv::Capability::Addresses);
+ builder.capability(spirv::Capability::Kernel);
+ builder.capability(spirv::Capability::Int8);
+ builder.capability(spirv::Capability::Int16);
+ builder.capability(spirv::Capability::Int64);
+ builder.capability(spirv::Capability::Float16);
+ builder.capability(spirv::Capability::Float64);
+ builder.capability(spirv::Capability::DenormFlushToZero);
+ // TODO: re-enable when Intel float control extension works
+ //builder.capability(spirv::Capability::FunctionFloatControlINTEL);
+}
+
+// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html
+fn emit_extensions(builder: &mut dr::Builder) {
+ // TODO: re-enable when Intel float control extension works
+ //builder.extension("SPV_INTEL_float_controls2");
+ builder.extension("SPV_KHR_float_controls");
+ builder.extension("SPV_KHR_no_integer_wrap_decoration");
+}
+
+fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word {
+ builder.ext_inst_import("OpenCL.std")
+}
+
+fn emit_memory_model(builder: &mut dr::Builder) {
+ builder.memory_model(
+ spirv::AddressingModel::Physical64,
+ spirv::MemoryModel::OpenCL,
+ );
+}
+
+struct TypeWordMap {
+ void: spirv::Word,
+ complex: HashMap<SpirvType, SpirvWord>,
+ constants: HashMap<(SpirvType, u64), SpirvWord>,
+}
+
+impl TypeWordMap {
+ fn new(b: &mut dr::Builder) -> TypeWordMap {
+ let void = b.type_void(None);
+ TypeWordMap {
+ void: void,
+ complex: HashMap::<SpirvType, SpirvWord>::new(),
+ constants: HashMap::new(),
+ }
+ }
+
+ fn void(&self) -> spirv::Word {
+ self.void
+ }
+
+ fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> SpirvWord {
+ let key: SpirvScalarKey = t.into();
+ self.get_or_add_spirv_scalar(b, key)
+ }
+
+ fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> SpirvWord {
+ *self.complex.entry(SpirvType::Base(key)).or_insert_with(|| {
+ SpirvWord(match key {
+ SpirvScalarKey::B8 => b.type_int(None, 8, 0),
+ SpirvScalarKey::B16 => b.type_int(None, 16, 0),
+ SpirvScalarKey::B32 => b.type_int(None, 32, 0),
+ SpirvScalarKey::B64 => b.type_int(None, 64, 0),
+ SpirvScalarKey::F16 => b.type_float(None, 16),
+ SpirvScalarKey::F32 => b.type_float(None, 32),
+ SpirvScalarKey::F64 => b.type_float(None, 64),
+ SpirvScalarKey::Pred => b.type_bool(None),
+ SpirvScalarKey::F16x2 => todo!(),
+ })
+ })
+ }
+
+ fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> SpirvWord {
+ match t {
+ SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key),
+ SpirvType::Pointer(ref typ, storage) => {
+ let base = self.get_or_add(b, *typ.clone());
+ *self
+ .complex
+ .entry(t)
+ .or_insert_with(|| SpirvWord(b.type_pointer(None, storage, base.0)))
+ }
+ SpirvType::Vector(typ, len) => {
+ let base = self.get_or_add_spirv_scalar(b, typ);
+ *self
+ .complex
+ .entry(t)
+ .or_insert_with(|| SpirvWord(b.type_vector(None, base.0, len as u32)))
+ }
+ SpirvType::Array(typ, array_dimensions) => {
+ let (base_type, length) = match &*array_dimensions {
+ &[] => {
+ return self.get_or_add(b, SpirvType::Base(typ));
+ }
+ &[len] => {
+ let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
+ let base = self.get_or_add_spirv_scalar(b, typ);
+ let len_const = b.constant_u32(u32_type.0, None, len);
+ (base, len_const)
+ }
+ array_dimensions => {
+ let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
+ let base = self
+ .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec()));
+ let len_const = b.constant_u32(u32_type.0, None, array_dimensions[0]);
+ (base, len_const)
+ }
+ };
+ *self
+ .complex
+ .entry(SpirvType::Array(typ, array_dimensions))
+ .or_insert_with(|| SpirvWord(b.type_array(None, base_type.0, length)))
+ }
+ SpirvType::Func(ref out_params, ref in_params) => {
+ let out_t = match out_params {
+ Some(p) => self.get_or_add(b, *p.clone()),
+ None => SpirvWord(self.void()),
+ };
+ let in_t = in_params
+ .iter()
+ .map(|t| self.get_or_add(b, t.clone()).0)
+ .collect::<Vec<_>>();
+ *self
+ .complex
+ .entry(t)
+ .or_insert_with(|| SpirvWord(b.type_function(None, out_t.0, in_t)))
+ }
+ SpirvType::Struct(ref underlying) => {
+ let underlying_ids = underlying
+ .iter()
+ .map(|t| self.get_or_add_spirv_scalar(b, *t).0)
+ .collect::<Vec<_>>();
+ *self
+ .complex
+ .entry(t)
+ .or_insert_with(|| SpirvWord(b.type_struct(None, underlying_ids)))
+ }
+ }
+ }
+
+ fn get_or_add_fn(
+ &mut self,
+ b: &mut dr::Builder,
+ in_params: impl Iterator<Item = SpirvType>,
+ mut out_params: impl ExactSizeIterator<Item = SpirvType>,
+ ) -> (SpirvWord, SpirvWord) {
+ let (out_args, out_spirv_type) = if out_params.len() == 0 {
+ (None, SpirvWord(self.void()))
+ } else if out_params.len() == 1 {
+ let arg_as_key = out_params.next().unwrap();
+ (
+ Some(Box::new(arg_as_key.clone())),
+ self.get_or_add(b, arg_as_key),
+ )
+ } else {
+ // TODO: support multiple return values
+ todo!()
+ };
+ (
+ out_spirv_type,
+ self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::<Vec<_>>())),
+ )
+ }
+
+ fn get_or_add_constant(
+ &mut self,
+ b: &mut dr::Builder,
+ typ: &ast::Type,
+ init: &[u8],
+ ) -> Result<SpirvWord, TranslateError> {
+ Ok(match typ {
+ ast::Type::Scalar(t) => match t {
+ ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self
+ .get_or_add_constant_single::<u8, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| b.constant_u32(result_type, None, v as u32),
+ ),
+ ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self
+ .get_or_add_constant_single::<u16, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| b.constant_u32(result_type, None, v as u32),
+ ),
+ ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self
+ .get_or_add_constant_single::<u32, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| b.constant_u32(result_type, None, v),
+ ),
+ ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self
+ .get_or_add_constant_single::<u64, _, _>(
+ b,
+ *t,
+ init,
+ |v| v,
+ |b, result_type, v| b.constant_u64(result_type, None, v),
+ ),
+ ast::ScalarType::F16 => self.get_or_add_constant_single::<f16, _, _>(
+ b,
+ *t,
+ init,
+ |v| unsafe { mem::transmute::<_, u16>(v) } as u64,
+ |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()),
+ ),
+ ast::ScalarType::F32 => self.get_or_add_constant_single::<f32, _, _>(
+ b,
+ *t,
+ init,
+ |v| unsafe { mem::transmute::<_, u32>(v) } as u64,
+ |b, result_type, v| b.constant_f32(result_type, None, v),
+ ),
+ ast::ScalarType::F64 => self.get_or_add_constant_single::<f64, _, _>(
+ b,
+ *t,
+ init,
+ |v| unsafe { mem::transmute::<_, u64>(v) },
+ |b, result_type, v| b.constant_f64(result_type, None, v),
+ ),
+ ast::ScalarType::F16x2 => return Err(TranslateError::Todo),
+ ast::ScalarType::Pred => self.get_or_add_constant_single::<u8, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| {
+ if v == 0 {
+ b.constant_false(result_type, None)
+ } else {
+ b.constant_true(result_type, None)
+ }
+ },
+ ),
+ ast::ScalarType::S16x2
+ | ast::ScalarType::U16x2
+ | ast::ScalarType::BF16
+ | ast::ScalarType::BF16x2
+ | ast::ScalarType::B128 => todo!(),
+ },
+ ast::Type::Vector(len, typ) => {
+ let result_type =
+ self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len));
+ let size_of_t = typ.size_of();
+ let components = (0..*len)
+ .map(|x| {
+ Ok::<_, TranslateError>(
+ self.get_or_add_constant(
+ b,
+ &ast::Type::Scalar(*typ),
+ &init[((size_of_t as usize) * (x as usize))..],
+ )?
+ .0,
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ SpirvWord(b.constant_composite(result_type.0, None, components.into_iter()))
+ }
+ ast::Type::Array(_, typ, dims) => match dims.as_slice() {
+ [] => return Err(error_unreachable()),
+ [dim] => {
+ let result_type = self
+ .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim]));
+ let size_of_t = typ.size_of();
+ let components = (0..*dim)
+ .map(|x| {
+ Ok::<_, TranslateError>(
+ self.get_or_add_constant(
+ b,
+ &ast::Type::Scalar(*typ),
+ &init[((size_of_t as usize) * (x as usize))..],
+ )?
+ .0,
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ SpirvWord(b.constant_composite(result_type.0, None, components.into_iter()))
+ }
+ [first_dim, rest @ ..] => {
+ let result_type = self.get_or_add(
+ b,
+ SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()),
+ );
+ let size_of_t = rest
+ .iter()
+ .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y));
+ let components = (0..*first_dim)
+ .map(|x| {
+ Ok::<_, TranslateError>(
+ self.get_or_add_constant(
+ b,
+ &ast::Type::Array(None, *typ, rest.to_vec()),
+ &init[((size_of_t as usize) * (x as usize))..],
+ )?
+ .0,
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ SpirvWord(b.constant_composite(result_type.0, None, components.into_iter()))
+ }
+ },
+ ast::Type::Pointer(..) => return Err(error_unreachable()),
+ })
+ }
+
+ fn get_or_add_constant_single<
+ T: Copy,
+ CastAsU64: FnOnce(T) -> u64,
+ InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word,
+ >(
+ &mut self,
+ b: &mut dr::Builder,
+ key: ast::ScalarType,
+ init: &[u8],
+ cast: CastAsU64,
+ f: InsertConstant,
+ ) -> SpirvWord {
+ let value = unsafe { *(init.as_ptr() as *const T) };
+ let value_64 = cast(value);
+ let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64);
+ match self.constants.get(&ht_key) {
+ Some(value) => *value,
+ None => {
+ let spirv_type = self.get_or_add_scalar(b, key);
+ let result = SpirvWord(f(b, spirv_type.0, value));
+ self.constants.insert(ht_key, result);
+ result
+ }
+ }
+ }
+}
+
+#[derive(PartialEq, Eq, Hash, Clone)]
+enum SpirvType {
+ Base(SpirvScalarKey),
+ Vector(SpirvScalarKey, u8),
+ Array(SpirvScalarKey, Vec<u32>),
+ Pointer(Box<SpirvType>, spirv::StorageClass),
+ Func(Option<Box<SpirvType>>, Vec<SpirvType>),
+ Struct(Vec<SpirvScalarKey>),
+}
+
+impl SpirvType {
+ fn new(t: ast::Type) -> Self {
+ match t {
+ ast::Type::Scalar(t) => SpirvType::Base(t.into()),
+ ast::Type::Vector(len, typ) => SpirvType::Vector(typ.into(), len),
+ ast::Type::Array(_, t, len) => SpirvType::Array(t.into(), len),
+ ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer(
+ Box::new(SpirvType::Base(pointer_t.into())),
+ space_to_spirv(space),
+ ),
+ }
+ }
+
+ fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self {
+ let key = Self::new(t);
+ SpirvType::Pointer(Box::new(key), outer_space)
+ }
+}
+
+impl From<ast::ScalarType> for SpirvType {
+ fn from(t: ast::ScalarType) -> Self {
+ SpirvType::Base(t.into())
+ }
+}
+// SPIR-V integer type definitions are signless, more below:
+// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers
+// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a
+#[derive(PartialEq, Eq, Hash, Clone, Copy)]
+enum SpirvScalarKey {
+ B8,
+ B16,
+ B32,
+ B64,
+ F16,
+ F32,
+ F64,
+ Pred,
+ F16x2,
+}
+
+impl From<ast::ScalarType> for SpirvScalarKey {
+ fn from(t: ast::ScalarType) -> Self {
+ match t {
+ ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8,
+ ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
+ SpirvScalarKey::B16
+ }
+ ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => {
+ SpirvScalarKey::B32
+ }
+ ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => {
+ SpirvScalarKey::B64
+ }
+ ast::ScalarType::F16 => SpirvScalarKey::F16,
+ ast::ScalarType::F32 => SpirvScalarKey::F32,
+ ast::ScalarType::F64 => SpirvScalarKey::F64,
+ ast::ScalarType::F16x2 => SpirvScalarKey::F16x2,
+ ast::ScalarType::Pred => SpirvScalarKey::Pred,
+ ast::ScalarType::S16x2
+ | ast::ScalarType::U16x2
+ | ast::ScalarType::BF16
+ | ast::ScalarType::BF16x2
+ | ast::ScalarType::B128 => todo!(),
+ }
+ }
+}
+
+fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass {
+ match this {
+ ast::StateSpace::Const => spirv::StorageClass::UniformConstant,
+ ast::StateSpace::Generic => spirv::StorageClass::Generic,
+ ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup,
+ ast::StateSpace::Local => spirv::StorageClass::Function,
+ ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
+ ast::StateSpace::Param => spirv::StorageClass::Function,
+ ast::StateSpace::Reg => spirv::StorageClass::Function,
+ ast::StateSpace::Sreg => spirv::StorageClass::Input,
+ ast::StateSpace::ParamEntry
+ | ast::StateSpace::ParamFunc
+ | ast::StateSpace::SharedCluster
+ | ast::StateSpace::SharedCta => todo!(),
+ }
+}
+
+// TODO: remove this once we have per-function support for denorms
+fn emit_denorm_build_string<'input>(
+ call_map: &MethodsCallMap,
+ denorm_information: &HashMap<
+ ast::MethodName<'input, SpirvWord>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
+) -> (CString, bool) {
+ let denorm_counts = denorm_information
+ .iter()
+ .map(|(method, meth_denorm)| {
+ let f16_count = meth_denorm
+ .get(&(mem::size_of::<f16>() as u8))
+ .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
+ .1;
+ let f32_count = meth_denorm
+ .get(&(mem::size_of::<f32>() as u8))
+ .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
+ .1;
+ (method, (f16_count + f32_count))
+ })
+ .collect::<HashMap<_, _>>();
+ let mut flush_over_preserve = 0;
+ for (kernel, children) in call_map.kernels() {
+ flush_over_preserve += *denorm_counts
+ .get(&ast::MethodName::Kernel(kernel))
+ .unwrap_or(&0);
+ for child_fn in children {
+ flush_over_preserve += *denorm_counts
+ .get(&ast::MethodName::Func(*child_fn))
+ .unwrap_or(&0);
+ }
+ }
+ if flush_over_preserve > 0 {
+ (
+ CString::new("-ze-take-global-address -ze-denorms-are-zero").unwrap(),
+ true,
+ )
+ } else {
+ (CString::new("-ze-take-global-address").unwrap(), false)
+ }
+}
+
+fn get_globals_use_map<'input>(
+ directives: Vec<Directive<'input>>,
+) -> (
+ Vec<Directive<'input>>,
+ HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
+) {
+ let mut known_globals = HashSet::new();
+ for directive in directives.iter() {
+ match directive {
+ Directive::Variable(_, ast::Variable { name, .. }) => {
+ known_globals.insert(*name);
+ }
+ Directive::Method(..) => {}
+ }
+ }
+ let mut symbol_uses_map = HashMap::new();
+ let directives = directives
+ .into_iter()
+ .map(|directive| match directive {
+ Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => directive,
+ Directive::Method(Function {
+ func_decl,
+ body: Some(mut statements),
+ globals,
+ import_as,
+ tuning,
+ linkage,
+ }) => {
+ let method_name = func_decl.borrow().name;
+ statements = statements
+ .into_iter()
+ .map(|statement| {
+ statement.visit_map(
+ &mut |symbol, _: Option<(&ast::Type, ast::StateSpace)>, _, _| {
+ if known_globals.contains(&symbol) {
+ multi_hash_map_append(
+ &mut symbol_uses_map,
+ method_name,
+ symbol,
+ );
+ }
+ Ok::<_, TranslateError>(symbol)
+ },
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()
+ .unwrap();
+ Directive::Method(Function {
+ func_decl,
+ body: Some(statements),
+ globals,
+ import_as,
+ tuning,
+ linkage,
+ })
+ }
+ })
+ .collect::<Vec<_>>();
+ (directives, symbol_uses_map)
+}
+
+fn emit_directives<'input>(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ id_defs: &GlobalStringIdResolver<'input>,
+ opencl_id: spirv::Word,
+ should_flush_denorms: bool,
+ call_map: &MethodsCallMap<'input>,
+ globals_use_map: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
+ directives: Vec<Directive<'input>>,
+ kernel_info: &mut HashMap<String, KernelInfo>,
+) -> Result<(), TranslateError> {
+ let empty_body = Vec::new();
+ for d in directives.iter() {
+ match d {
+ Directive::Variable(linking, var) => {
+ emit_variable(builder, map, id_defs, *linking, &var)?;
+ }
+ Directive::Method(f) => {
+ let f_body = match &f.body {
+ Some(f) => f,
+ None => {
+ if f.linkage.contains(ast::LinkingDirective::EXTERN) {
+ &empty_body
+ } else {
+ continue;
+ }
+ }
+ };
+ for var in f.globals.iter() {
+ emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?;
+ }
+ let func_decl = (*f.func_decl).borrow();
+ let fn_id = emit_function_header(
+ builder,
+ map,
+ &id_defs,
+ &*func_decl,
+ call_map,
+ &globals_use_map,
+ kernel_info,
+ )?;
+ if matches!(func_decl.name, ast::MethodName::Kernel(_)) {
+ if should_flush_denorms {
+ builder.execution_mode(
+ fn_id.0,
+ spirv_headers::ExecutionMode::DenormFlushToZero,
+ [16],
+ );
+ builder.execution_mode(
+ fn_id.0,
+ spirv_headers::ExecutionMode::DenormFlushToZero,
+ [32],
+ );
+ builder.execution_mode(
+ fn_id.0,
+ spirv_headers::ExecutionMode::DenormFlushToZero,
+ [64],
+ );
+ }
+ // FP contraction happens when compiling source -> PTX and is illegal at this stage (unless you force it in cuModuleLoadDataEx)
+ builder.execution_mode(
+ fn_id.0,
+ spirv_headers::ExecutionMode::ContractionOff,
+ [],
+ );
+ for t in f.tuning.iter() {
+ match *t {
+ ast::TuningDirective::MaxNtid(nx, ny, nz) => {
+ builder.execution_mode(
+ fn_id.0,
+ spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL,
+ [nx, ny, nz],
+ );
+ }
+ ast::TuningDirective::ReqNtid(nx, ny, nz) => {
+ builder.execution_mode(
+ fn_id.0,
+ spirv_headers::ExecutionMode::LocalSize,
+ [nx, ny, nz],
+ );
+ }
+ // Too architecture specific
+ ast::TuningDirective::MaxNReg(..)
+ | ast::TuningDirective::MinNCtaPerSm(..) => {}
+ }
+ }
+ }
+ emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?;
+ emit_function_linkage(builder, id_defs, f, fn_id)?;
+ builder.select_block(None)?;
+ builder.end_function()?;
+ }
+ }
+ }
+ Ok(())
+}
+
+fn emit_variable<'input>(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ id_defs: &GlobalStringIdResolver<'input>,
+ linking: ast::LinkingDirective,
+ var: &ast::Variable<SpirvWord>,
+) -> Result<(), TranslateError> {
+ let (must_init, st_class) = match var.state_space {
+ ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => {
+ (false, spirv::StorageClass::Function)
+ }
+ ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup),
+ ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
+ ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant),
+ ast::StateSpace::Generic => todo!(),
+ ast::StateSpace::Sreg => todo!(),
+ ast::StateSpace::ParamEntry
+ | ast::StateSpace::ParamFunc
+ | ast::StateSpace::SharedCluster
+ | ast::StateSpace::SharedCta => todo!(),
+ };
+ let initalizer = if var.array_init.len() > 0 {
+ Some(
+ map.get_or_add_constant(
+ builder,
+ &ast::Type::from(var.v_type.clone()),
+ &*var.array_init,
+ )?
+ .0,
+ )
+ } else if must_init {
+ let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone()));
+ Some(builder.constant_null(type_id.0, None))
+ } else {
+ None
+ };
+ let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class));
+ builder.variable(ptr_type_id.0, Some(var.name.0), st_class, initalizer);
+ if let Some(align) = var.align {
+ builder.decorate(
+ var.name.0,
+ spirv::Decoration::Alignment,
+ [dr::Operand::LiteralInt32(align)].iter().cloned(),
+ );
+ }
+ if var.state_space != ast::StateSpace::Shared
+ || !linking.contains(ast::LinkingDirective::EXTERN)
+ {
+ emit_linking_decoration(builder, id_defs, None, var.name, linking);
+ }
+ Ok(())
+}
+
+fn emit_function_header<'input>(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ defined_globals: &GlobalStringIdResolver<'input>,
+ func_decl: &ast::MethodDeclaration<'input, SpirvWord>,
+ call_map: &MethodsCallMap<'input>,
+ globals_use_map: &HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
+ kernel_info: &mut HashMap<String, KernelInfo>,
+) -> Result<SpirvWord, TranslateError> {
+ if let ast::MethodName::Kernel(name) = func_decl.name {
+ let args_lens = func_decl
+ .input_arguments
+ .iter()
+ .map(|param| {
+ (
+ type_size_of(&param.v_type),
+ matches!(param.v_type, ast::Type::Pointer(..)),
+ )
+ })
+ .collect();
+ kernel_info.insert(
+ name.to_string(),
+ KernelInfo {
+ arguments_sizes: args_lens,
+ uses_shared_mem: func_decl.shared_mem.is_some(),
+ },
+ );
+ }
+ let (ret_type, func_type) = get_function_type(
+ builder,
+ map,
+ effective_input_arguments(func_decl).map(|(_, typ)| typ),
+ &func_decl.return_arguments,
+ );
+ let fn_id = match func_decl.name {
+ ast::MethodName::Kernel(name) => {
+ let fn_id = defined_globals.get_id(name)?;
+ let interface = globals_use_map
+ .get(&ast::MethodName::Kernel(name))
+ .into_iter()
+ .flatten()
+ .copied()
+ .chain({
+ call_map
+ .get_kernel_children(name)
+ .copied()
+ .flat_map(|subfunction| {
+ globals_use_map
+ .get(&ast::MethodName::Func(subfunction))
+ .into_iter()
+ .flatten()
+ .copied()
+ })
+ .into_iter()
+ })
+ .map(|word| word.0)
+ .collect::<Vec<spirv::Word>>();
+ builder.entry_point(spirv::ExecutionModel::Kernel, fn_id.0, name, interface);
+ fn_id
+ }
+ ast::MethodName::Func(name) => name,
+ };
+ builder.begin_function(
+ ret_type.0,
+ Some(fn_id.0),
+ spirv::FunctionControl::NONE,
+ func_type.0,
+ )?;
+ for (name, typ) in effective_input_arguments(func_decl) {
+ let result_type = map.get_or_add(builder, typ);
+ builder.function_parameter(Some(name.0), result_type.0)?;
+ }
+ Ok(fn_id)
+}
+
+pub fn type_size_of(this: &ast::Type) -> usize {
+ match this {
+ ast::Type::Scalar(typ) => typ.size_of() as usize,
+ ast::Type::Vector(len, typ) => (typ.size_of() as usize) * (*len as usize),
+ ast::Type::Array(_, typ, len) => len
+ .iter()
+ .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)),
+ ast::Type::Pointer(..) => mem::size_of::<usize>(),
+ }
+}
+fn emit_function_body_ops<'input>(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ id_defs: &GlobalStringIdResolver<'input>,
+ opencl: spirv::Word,
+ func: &[ExpandedStatement],
+) -> Result<(), TranslateError> {
+ for s in func {
+ match s {
+ Statement::Label(id) => {
+ if builder.selected_block().is_some() {
+ builder.branch(id.0)?;
+ }
+ builder.begin_block(Some(id.0))?;
+ }
+ _ => {
+ if builder.selected_block().is_none() && builder.selected_function().is_some() {
+ builder.begin_block(None)?;
+ }
+ }
+ }
+ match s {
+ Statement::Label(_) => (),
+ Statement::Variable(var) => {
+ emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?;
+ }
+ Statement::Constant(cnst) => {
+ let typ_id = map.get_or_add_scalar(builder, cnst.typ);
+ match (cnst.typ, cnst.value) {
+ (ast::ScalarType::B8, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32);
+ }
+ (ast::ScalarType::B16, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32);
+ }
+ (ast::ScalarType::B32, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32);
+ }
+ (ast::ScalarType::B64, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u64(typ_id.0, Some(cnst.dst.0), value);
+ }
+ (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32);
+ }
+ (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32);
+ }
+ (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32);
+ }
+ (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as i64 as u64);
+ }
+ (ast::ScalarType::B8, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32);
+ }
+ (ast::ScalarType::B16, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32);
+ }
+ (ast::ScalarType::B32, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32);
+ }
+ (ast::ScalarType::B64, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64);
+ }
+ (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32);
+ }
+ (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32);
+ }
+ (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32);
+ }
+ (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64);
+ }
+ (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => {
+ builder.constant_f32(
+ typ_id.0,
+ Some(cnst.dst.0),
+ f16::from_f32(value).to_f32(),
+ );
+ }
+ (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => {
+ builder.constant_f32(typ_id.0, Some(cnst.dst.0), value);
+ }
+ (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => {
+ builder.constant_f64(typ_id.0, Some(cnst.dst.0), value as f64);
+ }
+ (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => {
+ builder.constant_f32(
+ typ_id.0,
+ Some(cnst.dst.0),
+ f16::from_f64(value).to_f32(),
+ );
+ }
+ (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => {
+ builder.constant_f32(typ_id.0, Some(cnst.dst.0), value as f32);
+ }
+ (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => {
+ builder.constant_f64(typ_id.0, Some(cnst.dst.0), value);
+ }
+ (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => {
+ let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0;
+ if value == 0 {
+ builder.constant_false(bool_type, Some(cnst.dst.0));
+ } else {
+ builder.constant_true(bool_type, Some(cnst.dst.0));
+ }
+ }
+ (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => {
+ let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0;
+ if value == 0 {
+ builder.constant_false(bool_type, Some(cnst.dst.0));
+ } else {
+ builder.constant_true(bool_type, Some(cnst.dst.0));
+ }
+ }
+ _ => return Err(error_mismatched_type()),
+ }
+ }
+ Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?,
+ Statement::Conditional(bra) => {
+ builder.branch_conditional(
+ bra.predicate.0,
+ bra.if_true.0,
+ bra.if_false.0,
+ iter::empty(),
+ )?;
+ }
+ Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => {
+ // TODO: implement properly
+ let zero = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U64),
+ &vec_repr(0u64),
+ )?;
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::U64);
+ builder.copy_object(result_type.0, Some(dst.0), zero.0)?;
+ }
+ Statement::Instruction(inst) => match inst {
+ ast::Instruction::PrmtSlow { .. } | ast::Instruction::Trap { .. } => todo!(),
+ ast::Instruction::Call { data, arguments } => {
+ let (result_type, result_id) =
+ match (&*data.return_arguments, &*arguments.return_arguments) {
+ ([(type_, space)], [id]) => {
+ if *space != ast::StateSpace::Reg {
+ return Err(error_unreachable());
+ }
+ (
+ map.get_or_add(builder, SpirvType::new(type_.clone())).0,
+ Some(id.0),
+ )
+ }
+ ([], []) => (map.void(), None),
+ _ => todo!(),
+ };
+ let arg_list = arguments
+ .input_arguments
+ .iter()
+ .map(|id| id.0)
+ .collect::<Vec<_>>();
+ builder.function_call(result_type, result_id, arguments.func.0, arg_list)?;
+ }
+ ast::Instruction::Abs { data, arguments } => {
+ emit_abs(builder, map, opencl, data, arguments)?
+ }
+ // SPIR-V does not support marking jumps as guaranteed-converged
+ ast::Instruction::Bra { arguments, .. } => {
+ builder.branch(arguments.src.0)?;
+ }
+ ast::Instruction::Ld { data, arguments } => {
+ let mem_access = match data.qualifier {
+ ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE,
+ // ld.volatile does not match Volatile OpLoad nor Relaxed OpAtomicLoad
+ ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE,
+ _ => return Err(TranslateError::Todo),
+ };
+ let result_type =
+ map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone())));
+ builder.load(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src.0,
+ Some(mem_access | spirv::MemoryAccess::ALIGNED),
+ [dr::Operand::LiteralInt32(
+ type_size_of(&ast::Type::from(data.typ.clone())) as u32,
+ )]
+ .iter()
+ .cloned(),
+ )?;
+ }
+ ast::Instruction::St { data, arguments } => {
+ let mem_access = match data.qualifier {
+ ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE,
+ // st.volatile does not match Volatile OpStore nor Relaxed OpAtomicStore
+ ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE,
+ _ => return Err(TranslateError::Todo),
+ };
+ builder.store(
+ arguments.src1.0,
+ arguments.src2.0,
+ Some(mem_access | spirv::MemoryAccess::ALIGNED),
+ [dr::Operand::LiteralInt32(
+ type_size_of(&ast::Type::from(data.typ.clone())) as u32,
+ )]
+ .iter()
+ .cloned(),
+ )?;
+ }
+ // SPIR-V does not support ret as guaranteed-converged
+ ast::Instruction::Ret { .. } => builder.ret()?,
+ ast::Instruction::Mov { data, arguments } => {
+ let result_type =
+ map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone())));
+ builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?;
+ }
+ ast::Instruction::Mul { data, arguments } => match data {
+ ast::MulDetails::Integer { type_, control } => {
+ emit_mul_int(builder, map, opencl, *type_, *control, arguments)?
+ }
+ ast::MulDetails::Float(ref ctr) => {
+ emit_mul_float(builder, map, ctr, arguments)?
+ }
+ },
+ ast::Instruction::Add { data, arguments } => match data {
+ ast::ArithDetails::Integer(desc) => {
+ emit_add_int(builder, map, desc.type_.into(), desc.saturate, arguments)?
+ }
+ ast::ArithDetails::Float(desc) => {
+ emit_add_float(builder, map, desc, arguments)?
+ }
+ },
+ ast::Instruction::Setp { data, arguments } => {
+ if arguments.dst2.is_some() {
+ todo!()
+ }
+ emit_setp(builder, map, data, arguments)?;
+ }
+ ast::Instruction::Not { data, arguments } => {
+ let result_type = map.get_or_add(builder, SpirvType::from(*data));
+ let result_id = Some(arguments.dst.0);
+ let operand = arguments.src;
+ match data {
+ ast::ScalarType::Pred => {
+ logical_not(builder, result_type.0, result_id, operand.0)
+ }
+ _ => builder.not(result_type.0, result_id, operand.0),
+ }?;
+ }
+ ast::Instruction::Shl { data, arguments } => {
+ let full_type = ast::Type::Scalar(*data);
+ let size_of = type_size_of(&full_type);
+ let result_type = map.get_or_add(builder, SpirvType::new(full_type));
+ let offset_src = insert_shift_hack(builder, map, arguments.src2.0, size_of)?;
+ builder.shift_left_logical(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ offset_src,
+ )?;
+ }
+ ast::Instruction::Shr { data, arguments } => {
+ let full_type = ast::ScalarType::from(data.type_);
+ let size_of = full_type.size_of();
+ let result_type = map.get_or_add_scalar(builder, full_type).0;
+ let offset_src =
+ insert_shift_hack(builder, map, arguments.src2.0, size_of as usize)?;
+ match data.kind {
+ ptx_parser::RightShiftKind::Arithmetic => {
+ builder.shift_right_arithmetic(
+ result_type,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ offset_src,
+ )?;
+ }
+ ptx_parser::RightShiftKind::Logical => {
+ builder.shift_right_logical(
+ result_type,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ offset_src,
+ )?;
+ }
+ }
+ }
+ ast::Instruction::Cvt { data, arguments } => {
+ emit_cvt(builder, map, opencl, data, arguments)?;
+ }
+ ast::Instruction::Cvta { data, arguments } => {
+ // This would be only meaningful if const/slm/global pointers
+ // had a different format than generic pointers, but they don't pretty much by ptx definition
+ // Honestly, I have no idea why this instruction exists and is emitted by the compiler
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64);
+ builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?;
+ }
+ ast::Instruction::SetpBool { .. } => todo!(),
+ ast::Instruction::Mad { data, arguments } => match data {
+ ast::MadDetails::Integer {
+ type_,
+ control,
+ saturate,
+ } => {
+ if *saturate {
+ todo!()
+ }
+ if type_.kind() == ast::ScalarKind::Signed {
+ emit_mad_sint(builder, map, opencl, *type_, *control, arguments)?
+ } else {
+ emit_mad_uint(builder, map, opencl, *type_, *control, arguments)?
+ }
+ }
+ ast::MadDetails::Float(desc) => {
+ emit_mad_float(builder, map, opencl, desc, arguments)?
+ }
+ },
+ ast::Instruction::Fma { data, arguments } => {
+ emit_fma_float(builder, map, opencl, data, arguments)?
+ }
+ ast::Instruction::Or { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, *data).0;
+ if *data == ast::ScalarType::Pred {
+ builder.logical_or(
+ result_type,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ } else {
+ builder.bitwise_or(
+ result_type,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ }
+ }
+ ast::Instruction::Sub { data, arguments } => match data {
+ ast::ArithDetails::Integer(desc) => {
+ emit_sub_int(builder, map, desc.type_.into(), desc.saturate, arguments)?;
+ }
+ ast::ArithDetails::Float(desc) => {
+ emit_sub_float(builder, map, desc, arguments)?;
+ }
+ },
+ ast::Instruction::Min { data, arguments } => {
+ emit_min(builder, map, opencl, data, arguments)?;
+ }
+ ast::Instruction::Max { data, arguments } => {
+ emit_max(builder, map, opencl, data, arguments)?;
+ }
+ ast::Instruction::Rcp { data, arguments } => {
+ emit_rcp(builder, map, opencl, data, arguments)?;
+ }
+ ast::Instruction::And { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, *data);
+ if *data == ast::ScalarType::Pred {
+ builder.logical_and(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ } else {
+ builder.bitwise_and(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ }
+ }
+ ast::Instruction::Selp { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, *data);
+ builder.select(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src3.0,
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ }
+ // TODO: implement named barriers
+ ast::Instruction::Bar { data, arguments } => {
+ let workgroup_scope = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(spirv::Scope::Workgroup as u32),
+ )?;
+ let barrier_semantics = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+ )?;
+ builder.control_barrier(
+ workgroup_scope.0,
+ workgroup_scope.0,
+ barrier_semantics.0,
+ )?;
+ }
+ ast::Instruction::Atom { data, arguments } => {
+ emit_atom(builder, map, data, arguments)?;
+ }
+ ast::Instruction::AtomCas { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, data.type_);
+ let memory_const = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(scope_to_spirv(data.scope) as u32),
+ )?;
+ let semantics_const = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(semantics_to_spirv(data.semantics).bits()),
+ )?;
+ builder.atomic_compare_exchange(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ memory_const.0,
+ semantics_const.0,
+ semantics_const.0,
+ arguments.src3.0,
+ arguments.src2.0,
+ )?;
+ }
+ ast::Instruction::Div { data, arguments } => match data {
+ ast::DivDetails::Unsigned(t) => {
+ let result_type = map.get_or_add_scalar(builder, (*t).into());
+ builder.u_div(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ }
+ ast::DivDetails::Signed(t) => {
+ let result_type = map.get_or_add_scalar(builder, (*t).into());
+ builder.s_div(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ }
+ ast::DivDetails::Float(t) => {
+ let result_type = map.get_or_add_scalar(builder, t.type_.into());
+ builder.f_div(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ emit_float_div_decoration(builder, arguments.dst, t.kind);
+ }
+ },
+ ast::Instruction::Sqrt { data, arguments } => {
+ emit_sqrt(builder, map, opencl, data, arguments)?;
+ }
+ ast::Instruction::Rsqrt { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, data.type_.into());
+ builder.ext_inst(
+ result_type.0,
+ Some(arguments.dst.0),
+ opencl,
+ spirv::CLOp::rsqrt as spirv::Word,
+ [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
+ )?;
+ }
+ ast::Instruction::Neg { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, data.type_);
+ let negate_func = if data.type_.kind() == ast::ScalarKind::Float {
+ dr::Builder::f_negate
+ } else {
+ dr::Builder::s_negate
+ };
+ negate_func(
+ builder,
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src.0,
+ )?;
+ }
+ ast::Instruction::Sin { arguments, .. } => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
+ builder.ext_inst(
+ result_type.0,
+ Some(arguments.dst.0),
+ opencl,
+ spirv::CLOp::sin as u32,
+ [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
+ )?;
+ }
+ ast::Instruction::Cos { arguments, .. } => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
+ builder.ext_inst(
+ result_type.0,
+ Some(arguments.dst.0),
+ opencl,
+ spirv::CLOp::cos as u32,
+ [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
+ )?;
+ }
+ ast::Instruction::Lg2 { arguments, .. } => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
+ builder.ext_inst(
+ result_type.0,
+ Some(arguments.dst.0),
+ opencl,
+ spirv::CLOp::log2 as u32,
+ [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
+ )?;
+ }
+ ast::Instruction::Ex2 { arguments, .. } => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
+ builder.ext_inst(
+ result_type.0,
+ Some(arguments.dst.0),
+ opencl,
+ spirv::CLOp::exp2 as u32,
+ [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
+ )?;
+ }
+ ast::Instruction::Clz { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, (*data).into());
+ builder.ext_inst(
+ result_type.0,
+ Some(arguments.dst.0),
+ opencl,
+ spirv::CLOp::clz as u32,
+ [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
+ )?;
+ }
+ ast::Instruction::Brev { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, (*data).into());
+ builder.bit_reverse(result_type.0, Some(arguments.dst.0), arguments.src.0)?;
+ }
+ ast::Instruction::Popc { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, (*data).into());
+ builder.bit_count(result_type.0, Some(arguments.dst.0), arguments.src.0)?;
+ }
+ ast::Instruction::Xor { data, arguments } => {
+ let builder_fn: fn(
+ &mut dr::Builder,
+ u32,
+ Option<u32>,
+ u32,
+ u32,
+ ) -> Result<u32, dr::Error> = match data {
+ ast::ScalarType::Pred => emit_logical_xor_spirv,
+ _ => dr::Builder::bitwise_xor,
+ };
+ let result_type = map.get_or_add_scalar(builder, (*data).into());
+ builder_fn(
+ builder,
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ }
+ ast::Instruction::Bfe { .. }
+ | ast::Instruction::Bfi { .. }
+ | ast::Instruction::Activemask { .. } => {
+ // Should have beeen replaced with a funciton call earlier
+ return Err(error_unreachable());
+ }
+
+ ast::Instruction::Rem { data, arguments } => {
+ let builder_fn = if data.kind() == ast::ScalarKind::Signed {
+ dr::Builder::s_mod
+ } else {
+ dr::Builder::u_mod
+ };
+ let result_type = map.get_or_add_scalar(builder, (*data).into());
+ builder_fn(
+ builder,
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ }
+ ast::Instruction::Prmt { data, arguments } => {
+ let control = *data as u32;
+ let components = [
+ (control >> 0) & 0b1111,
+ (control >> 4) & 0b1111,
+ (control >> 8) & 0b1111,
+ (control >> 12) & 0b1111,
+ ];
+ if components.iter().any(|&c| c > 7) {
+ return Err(TranslateError::Todo);
+ }
+ let vec4_b8_type =
+ map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4));
+ let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32);
+ let src1_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src1.0)?;
+ let src2_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src2.0)?;
+ let dst_vector = builder.vector_shuffle(
+ vec4_b8_type.0,
+ None,
+ src1_vector,
+ src2_vector,
+ components,
+ )?;
+ builder.bitcast(b32_type.0, Some(arguments.dst.0), dst_vector)?;
+ }
+ ast::Instruction::Membar { data } => {
+ let (scope, semantics) = match data {
+ ast::MemScope::Cta => (
+ spirv::Scope::Workgroup,
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+ ast::MemScope::Gpu => (
+ spirv::Scope::Device,
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+ ast::MemScope::Sys => (
+ spirv::Scope::CrossDevice,
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+
+ ast::MemScope::Cluster => todo!(),
+ };
+ let spirv_scope = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(scope as u32),
+ )?;
+ let spirv_semantics = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(semantics),
+ )?;
+ builder.memory_barrier(spirv_scope.0, spirv_semantics.0)?;
+ }
+ },
+ Statement::LoadVar(details) => {
+ emit_load_var(builder, map, details)?;
+ }
+ Statement::StoreVar(details) => {
+ let dst_ptr = match details.member_index {
+ Some(index) => {
+ let result_ptr_type = map.get_or_add(
+ builder,
+ SpirvType::pointer_to(
+ details.typ.clone(),
+ spirv::StorageClass::Function,
+ ),
+ );
+ let index_spirv = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(index as u32),
+ )?;
+ builder.in_bounds_access_chain(
+ result_ptr_type.0,
+ None,
+ details.arg.src1.0,
+ [index_spirv.0].iter().copied(),
+ )?
+ }
+ None => details.arg.src1.0,
+ };
+ builder.store(dst_ptr, details.arg.src2.0, None, iter::empty())?;
+ }
+ Statement::RetValue(_, id) => {
+ builder.ret_value(id.0)?;
+ }
+ Statement::PtrAccess(PtrAccess {
+ underlying_type,
+ state_space,
+ dst,
+ ptr_src,
+ offset_src,
+ }) => {
+ let u8_pointer = map.get_or_add(
+ builder,
+ SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)),
+ );
+ let result_type = map.get_or_add(
+ builder,
+ SpirvType::pointer_to(underlying_type.clone(), space_to_spirv(*state_space)),
+ );
+ let ptr_src_u8 = builder.bitcast(u8_pointer.0, None, ptr_src.0)?;
+ let temp = builder.in_bounds_ptr_access_chain(
+ u8_pointer.0,
+ None,
+ ptr_src_u8,
+ offset_src.0,
+ iter::empty(),
+ )?;
+ builder.bitcast(result_type.0, Some(dst.0), temp)?;
+ }
+ Statement::RepackVector(repack) => {
+ if repack.is_extract {
+ let scalar_type = map.get_or_add_scalar(builder, repack.typ);
+ for (index, dst_id) in repack.unpacked.iter().enumerate() {
+ builder.composite_extract(
+ scalar_type.0,
+ Some(dst_id.0),
+ repack.packed.0,
+ [index as u32].iter().copied(),
+ )?;
+ }
+ } else {
+ let vector_type = map.get_or_add(
+ builder,
+ SpirvType::Vector(
+ SpirvScalarKey::from(repack.typ),
+ repack.unpacked.len() as u8,
+ ),
+ );
+ let mut temp_vec = builder.undef(vector_type.0, None);
+ for (index, src_id) in repack.unpacked.iter().enumerate() {
+ temp_vec = builder.composite_insert(
+ vector_type.0,
+ None,
+ src_id.0,
+ temp_vec,
+ [index as u32].iter().copied(),
+ )?;
+ }
+ builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
+ }
+ }
+ }
+ }
+ Ok(())
+}
+
+fn emit_function_linkage<'input>(
+ builder: &mut dr::Builder,
+ id_defs: &GlobalStringIdResolver<'input>,
+ f: &Function,
+ fn_name: SpirvWord,
+) -> Result<(), TranslateError> {
+ if f.linkage == ast::LinkingDirective::NONE {
+ return Ok(());
+ };
+ let linking_name = match f.func_decl.borrow().name {
+ // According to SPIR-V rules linkage attributes are invalid on kernels
+ ast::MethodName::Kernel(..) => return Ok(()),
+ ast::MethodName::Func(fn_id) => f.import_as.as_deref().map_or_else(
+ || match id_defs.reverse_variables.get(&fn_id) {
+ Some(fn_name) => Ok(fn_name),
+ None => Err(error_unknown_symbol()),
+ },
+ Result::Ok,
+ )?,
+ };
+ emit_linking_decoration(builder, id_defs, Some(linking_name), fn_name, f.linkage);
+ Ok(())
+}
+
+fn get_function_type(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ spirv_input: impl Iterator<Item = SpirvType>,
+ spirv_output: &[ast::Variable<SpirvWord>],
+) -> (SpirvWord, SpirvWord) {
+ map.get_or_add_fn(
+ builder,
+ spirv_input,
+ spirv_output
+ .iter()
+ .map(|var| SpirvType::new(var.v_type.clone())),
+ )
+}
+
+fn emit_linking_decoration<'input>(
+ builder: &mut dr::Builder,
+ id_defs: &GlobalStringIdResolver<'input>,
+ name_override: Option<&str>,
+ name: SpirvWord,
+ linking: ast::LinkingDirective,
+) {
+ if linking == ast::LinkingDirective::NONE {
+ return;
+ }
+ if linking.contains(ast::LinkingDirective::VISIBLE) {
+ let string_name =
+ name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
+ builder.decorate(
+ name.0,
+ spirv::Decoration::LinkageAttributes,
+ [
+ dr::Operand::LiteralString(string_name.to_string()),
+ dr::Operand::LinkageType(spirv::LinkageType::Export),
+ ]
+ .iter()
+ .cloned(),
+ );
+ } else if linking.contains(ast::LinkingDirective::EXTERN) {
+ let string_name =
+ name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
+ builder.decorate(
+ name.0,
+ spirv::Decoration::LinkageAttributes,
+ [
+ dr::Operand::LiteralString(string_name.to_string()),
+ dr::Operand::LinkageType(spirv::LinkageType::Import),
+ ]
+ .iter()
+ .cloned(),
+ );
+ }
+ // TODO: handle LinkingDirective::WEAK
+}
+
+fn effective_input_arguments<'a>(
+ this: &'a ast::MethodDeclaration<'a, SpirvWord>,
+) -> impl Iterator<Item = (SpirvWord, SpirvType)> + 'a {
+ let is_kernel = matches!(this.name, ast::MethodName::Kernel(_));
+ this.input_arguments.iter().map(move |arg| {
+ if !is_kernel && arg.state_space != ast::StateSpace::Reg {
+ let spirv_type =
+ SpirvType::pointer_to(arg.v_type.clone(), space_to_spirv(arg.state_space));
+ (arg.name, spirv_type)
+ } else {
+ (arg.name, SpirvType::new(arg.v_type.clone()))
+ }
+ })
+}
+
/// Emits SPIR-V for a single implicit-conversion node.
///
/// Dispatches on `(source type kind, destination type kind, conversion kind)`:
/// * `BitToPtr`   - integer reinterpreted as pointer (`OpConvertUToPtr`)
/// * `Default`    - same-width reinterpretation or widening between scalars,
///                  bitcasts between aggregates and scalars, and
///                  pointer <-> integer conversions
/// * `SignExtend` - `OpSConvert`
/// * `PtrToPtr`   - pointer casts, possibly through the generic address space
/// * `AddressOf`  - pointer observed as integer (`OpConvertPtrToU`)
///
/// Panics (`unreachable!`) on combinations the earlier passes never produce.
fn emit_implicit_conversion(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    cv: &ImplicitConversion,
) -> Result<(), TranslateError> {
    let from_parts = to_parts(&cv.from_type);
    let to_parts = to_parts(&cv.to_type);
    match (from_parts.kind, to_parts.kind, &cv.kind) {
        (_, _, &ConversionKind::BitToPtr) => {
            let dst_type = map.get_or_add(
                builder,
                SpirvType::pointer_to(cv.to_type.clone(), space_to_spirv(cv.to_space)),
            );
            builder.convert_u_to_ptr(dst_type.0, Some(cv.dst.0), cv.src.0)?;
        }
        (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => {
            if from_parts.width == to_parts.width {
                // Same width: no value change, only a type-system change.
                let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
                if from_parts.scalar_kind != ast::ScalarKind::Float
                    && to_parts.scalar_kind != ast::ScalarKind::Float
                {
                    // It is noop, but another instruction expects result of this conversion
                    builder.copy_object(dst_type.0, Some(cv.dst.0), cv.src.0)?;
                } else {
                    builder.bitcast(dst_type.0, Some(cv.dst.0), cv.src.0)?;
                }
            } else {
                // Width change: reinterpret the source as a bit type of the
                // same width, resize that, then (for signed destinations)
                // recurse once to reinterpret the widened bits.
                // This block is safe because it's illegal to implictly convert between floating point values
                let same_width_bit_type = map.get_or_add(
                    builder,
                    SpirvType::new(type_from_parts(TypeParts {
                        scalar_kind: ast::ScalarKind::Bit,
                        ..from_parts
                    })),
                );
                let same_width_bit_value =
                    builder.bitcast(same_width_bit_type.0, None, cv.src.0)?;
                let wide_bit_type = type_from_parts(TypeParts {
                    scalar_kind: ast::ScalarKind::Bit,
                    ..to_parts
                });
                let wide_bit_type_spirv =
                    map.get_or_add(builder, SpirvType::new(wide_bit_type.clone()));
                if to_parts.scalar_kind == ast::ScalarKind::Unsigned
                    || to_parts.scalar_kind == ast::ScalarKind::Bit
                {
                    builder.u_convert(
                        wide_bit_type_spirv.0,
                        Some(cv.dst.0),
                        same_width_bit_value,
                    )?;
                } else {
                    // Sign-extend only when both ends are signed; otherwise zero-extend
                    let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed
                        && to_parts.scalar_kind == ast::ScalarKind::Signed
                    {
                        dr::Builder::s_convert
                    } else {
                        dr::Builder::u_convert
                    };
                    let wide_bit_value =
                        conversion_fn(builder, wide_bit_type_spirv.0, None, same_width_bit_value)?;
                    // Recurse once: wide bit value -> destination type (widths now equal)
                    emit_implicit_conversion(
                        builder,
                        map,
                        &ImplicitConversion {
                            src: SpirvWord(wide_bit_value),
                            dst: cv.dst,
                            from_type: wide_bit_type,
                            from_space: cv.from_space,
                            to_type: cv.to_type.clone(),
                            to_space: cv.to_space,
                            kind: ConversionKind::Default,
                        },
                    )?;
                }
            }
        }
        (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => {
            let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
            builder.s_convert(result_type.0, Some(cv.dst.0), cv.src.0)?;
        }
        // Same-size reinterpretations between aggregates and scalars
        (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default)
        | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default)
        | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => {
            let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
            builder.bitcast(into_type.0, Some(cv.dst.0), cv.src.0)?;
        }
        (_, _, &ConversionKind::PtrToPtr) => {
            let result_type = map.get_or_add(
                builder,
                SpirvType::Pointer(
                    Box::new(SpirvType::new(cv.to_type.clone())),
                    space_to_spirv(cv.to_space),
                ),
            );
            // Casts to/from the generic address space need the dedicated
            // OpPtrCastToGeneric / OpGenericCastToPtr opcodes; everything
            // else is a plain bitcast.
            if cv.to_space == ast::StateSpace::Generic && cv.from_space != ast::StateSpace::Generic
            {
                let src = if cv.from_type != cv.to_type {
                    // Adjust the pointee type first, while still in the source space
                    let temp_type = map.get_or_add(
                        builder,
                        SpirvType::Pointer(
                            Box::new(SpirvType::new(cv.to_type.clone())),
                            space_to_spirv(cv.from_space),
                        ),
                    );
                    builder.bitcast(temp_type.0, None, cv.src.0)?
                } else {
                    cv.src.0
                };
                builder.ptr_cast_to_generic(result_type.0, Some(cv.dst.0), src)?;
            } else if cv.from_space == ast::StateSpace::Generic
                && cv.to_space != ast::StateSpace::Generic
            {
                let src = if cv.from_type != cv.to_type {
                    // Adjust the pointee type first, while still in the source space
                    let temp_type = map.get_or_add(
                        builder,
                        SpirvType::Pointer(
                            Box::new(SpirvType::new(cv.to_type.clone())),
                            space_to_spirv(cv.from_space),
                        ),
                    );
                    builder.bitcast(temp_type.0, None, cv.src.0)?
                } else {
                    cv.src.0
                };
                builder.generic_cast_to_ptr(result_type.0, Some(cv.dst.0), src)?;
            } else {
                builder.bitcast(result_type.0, Some(cv.dst.0), cv.src.0)?;
            }
        }
        (_, _, &ConversionKind::AddressOf) => {
            let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
            builder.convert_ptr_to_u(dst_type.0, Some(cv.dst.0), cv.src.0)?;
        }
        (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) => {
            let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
            builder.convert_ptr_to_u(result_type.0, Some(cv.dst.0), cv.src.0)?;
        }
        (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => {
            let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
            builder.convert_u_to_ptr(result_type.0, Some(cv.dst.0), cv.src.0)?;
        }
        // Any other combination is never produced by the earlier passes
        _ => unreachable!(),
    }
    Ok(())
}
+
/// Returns the native (in-memory) byte representation of `t`.
fn vec_repr<T: Copy>(t: T) -> Vec<u8> {
    let size = mem::size_of::<T>();
    let mut bytes = vec![0u8; size];
    // SAFETY: `bytes` holds exactly `size_of::<T>()` bytes and `t` is a live
    // `Copy` value, so copying a single `T` into the buffer is in-bounds and
    // the regions cannot overlap.
    unsafe {
        std::ptr::copy_nonoverlapping(&t as *const T as *const u8, bytes.as_mut_ptr(), size)
    };
    bytes
}
+
+fn emit_abs(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ d: &ast::TypeFtz,
+ arg: &ast::AbsArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ let scalar_t = ast::ScalarType::from(d.type_);
+ let result_type = map.get_or_add(builder, SpirvType::from(scalar_t));
+ let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed {
+ spirv::CLOp::s_abs
+ } else {
+ spirv::CLOp::fabs
+ };
+ builder.ext_inst(
+ result_type.0,
+ Some(arg.dst.0),
+ opencl,
+ cl_abs as spirv::Word,
+ [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
+ )?;
+ Ok(())
+}
+
/// Emits an integer `mul` in one of its three PTX flavors:
/// * `.lo`   - plain `OpIMul`, keeping the low half of the product
/// * `.hi`   - OpenCL `s_mul_hi`/`u_mul_hi`, keeping the high half
/// * `.wide` - operands widened to twice their width, full product kept
fn emit_mul_int(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    opencl: spirv::Word,
    type_: ast::ScalarType,
    control: ast::MulIntControl,
    arg: &ast::MulArgs<SpirvWord>,
) -> Result<(), dr::Error> {
    let inst_type = map.get_or_add(builder, SpirvType::from(type_));
    match control {
        ast::MulIntControl::Low => {
            builder.i_mul(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
        }
        ast::MulIntControl::High => {
            // Signedness only affects the high half of the product
            let opencl_inst = if type_.kind() == ast::ScalarKind::Signed {
                spirv::CLOp::s_mul_hi
            } else {
                spirv::CLOp::u_mul_hi
            };
            builder.ext_inst(
                inst_type.0,
                Some(arg.dst.0),
                opencl,
                opencl_inst as spirv::Word,
                [
                    dr::Operand::IdRef(arg.src1.0),
                    dr::Operand::IdRef(arg.src2.0),
                ]
                .iter()
                .cloned(),
            )?;
        }
        ast::MulIntControl::Wide => {
            // Widen both operands to double width, extending according to the
            // operand signedness, then multiply at the destination width.
            let instr_width = type_.size_of();
            let instr_kind = type_.kind();
            let dst_type = scalar_from_parts(instr_width * 2, instr_kind);
            let dst_type_id = map.get_or_add_scalar(builder, dst_type);
            let (src1, src2) = if type_.kind() == ast::ScalarKind::Signed {
                let src1 = builder.s_convert(dst_type_id.0, None, arg.src1.0)?;
                let src2 = builder.s_convert(dst_type_id.0, None, arg.src2.0)?;
                (src1, src2)
            } else {
                let src1 = builder.u_convert(dst_type_id.0, None, arg.src1.0)?;
                let src2 = builder.u_convert(dst_type_id.0, None, arg.src2.0)?;
                (src1, src2)
            };
            builder.i_mul(dst_type_id.0, Some(arg.dst.0), src1, src2)?;
            // The double-width product cannot overflow.
            // NOTE(review): NoSignedWrap is applied for unsigned types too -- confirm intended
            builder.decorate(arg.dst.0, spirv::Decoration::NoSignedWrap, iter::empty());
        }
    }
    Ok(())
}
+
+fn emit_mul_float(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ ctr: &ast::ArithFloat,
+ arg: &ast::MulArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ if ctr.saturate {
+ todo!()
+ }
+ let result_type = map.get_or_add_scalar(builder, ctr.type_.into());
+ builder.f_mul(result_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
+ emit_rounding_decoration(builder, arg.dst, ctr.rounding);
+ Ok(())
+}
+
+fn scalar_from_parts(width: u8, kind: ast::ScalarKind) -> ast::ScalarType {
+ match kind {
+ ast::ScalarKind::Float => match width {
+ 2 => ast::ScalarType::F16,
+ 4 => ast::ScalarType::F32,
+ 8 => ast::ScalarType::F64,
+ _ => unreachable!(),
+ },
+ ast::ScalarKind::Bit => match width {
+ 1 => ast::ScalarType::B8,
+ 2 => ast::ScalarType::B16,
+ 4 => ast::ScalarType::B32,
+ 8 => ast::ScalarType::B64,
+ _ => unreachable!(),
+ },
+ ast::ScalarKind::Signed => match width {
+ 1 => ast::ScalarType::S8,
+ 2 => ast::ScalarType::S16,
+ 4 => ast::ScalarType::S32,
+ 8 => ast::ScalarType::S64,
+ _ => unreachable!(),
+ },
+ ast::ScalarKind::Unsigned => match width {
+ 1 => ast::ScalarType::U8,
+ 2 => ast::ScalarType::U16,
+ 4 => ast::ScalarType::U32,
+ 8 => ast::ScalarType::U64,
+ _ => unreachable!(),
+ },
+ ast::ScalarKind::Pred => ast::ScalarType::Pred,
+ }
+}
+
+fn emit_rounding_decoration(
+ builder: &mut dr::Builder,
+ dst: SpirvWord,
+ rounding: Option<ast::RoundingMode>,
+) {
+ if let Some(rounding) = rounding {
+ builder.decorate(
+ dst.0,
+ spirv::Decoration::FPRoundingMode,
+ [rounding_to_spirv(rounding)].iter().cloned(),
+ );
+ }
+}
+
+fn rounding_to_spirv(this: ast::RoundingMode) -> rspirv::dr::Operand {
+ let mode = match this {
+ ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE,
+ ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ,
+ ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP,
+ ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN,
+ };
+ rspirv::dr::Operand::FPRoundingMode(mode)
+}
+
+fn emit_add_int(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ typ: ast::ScalarType,
+ saturate: bool,
+ arg: &ast::AddArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ if saturate {
+ todo!()
+ }
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)));
+ builder.i_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
+ Ok(())
+}
+
+fn emit_add_float(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ desc: &ast::ArithFloat,
+ arg: &ast::AddArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)));
+ builder.f_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
+ emit_rounding_decoration(builder, arg.dst, desc.rounding);
+ Ok(())
+}
+
/// Emits the comparison of a `setp` instruction into `arg.dst1`.
///
/// Integer comparisons carry their signedness in the compare op; float
/// comparisons come in ordered (`f_ord_*`) and unordered/NaN-accepting
/// (`f_unord_*`) flavors. `IsAnyNan`/`IsNotNan` have no direct SPIR-V
/// opcode and are synthesized from `OpIsNan`. Remaining compare ops are
/// not implemented yet (`todo!`).
fn emit_setp(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    setp: &ast::SetpData,
    arg: &ast::SetpArgs<SpirvWord>,
) -> Result<(), dr::Error> {
    // setp always produces a predicate (bool) result
    let result_type = map
        .get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred))
        .0;
    let result_id = Some(arg.dst1.0);
    let operand_1 = arg.src1.0;
    let operand_2 = arg.src2.0;
    match setp.cmp_op {
        // Equality / inequality
        ast::SetpCompareOp::Integer(ast::SetpCompareInt::Eq) => {
            builder.i_equal(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Float(ast::SetpCompareFloat::Eq) => {
            builder.f_ord_equal(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Integer(ast::SetpCompareInt::NotEq) => {
            builder.i_not_equal(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Float(ast::SetpCompareFloat::NotEq) => {
            builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
        }
        // Orderings
        ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLess) => {
            builder.u_less_than(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLess) => {
            builder.s_less_than(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Float(ast::SetpCompareFloat::Less) => {
            builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLessOrEq) => {
            builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLessOrEq) => {
            builder.s_less_than_equal(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Float(ast::SetpCompareFloat::LessOrEq) => {
            builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreater) => {
            builder.u_greater_than(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreater) => {
            builder.s_greater_than(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Float(ast::SetpCompareFloat::Greater) => {
            builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreaterOrEq) => {
            builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreaterOrEq) => {
            builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Float(ast::SetpCompareFloat::GreaterOrEq) => {
            builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
        }
        // NaN-accepting (unordered) float comparisons
        ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanEq) => {
            builder.f_unord_equal(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanNotEq) => {
            builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLess) => {
            builder.f_unord_less_than(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLessOrEq) => {
            builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreater) => {
            builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2)
        }
        ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreaterOrEq) => {
            builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2)
        }
        // No direct opcode: isnan(a) || isnan(b)
        ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsAnyNan) => {
            let temp1 = builder.is_nan(result_type, None, operand_1)?;
            let temp2 = builder.is_nan(result_type, None, operand_2)?;
            builder.logical_or(result_type, result_id, temp1, temp2)
        }
        // No direct opcode: !(isnan(a) || isnan(b))
        ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsNotNan) => {
            let temp1 = builder.is_nan(result_type, None, operand_1)?;
            let temp2 = builder.is_nan(result_type, None, operand_2)?;
            let any_nan = builder.logical_or(result_type, None, temp1, temp2)?;
            logical_not(builder, result_type, result_id, any_nan)
        }
        _ => todo!(),
    }?;
    Ok(())
}
+
// HACK ALERT
// Temporary workaround until IGC gets its shit together
// Currently IGC carries two copies of SPIRV-LLVM translator
// a new one in /llvm-spirv/ and old one in /IGC/AdaptorOCL/SPIRV/.
// Obviously, old and buggy one is used for compiling L0 SPIRV
// https://github.com/intel/intel-graphics-compiler/issues/148
// Because of the above, boolean negation is emitted as an OpSelect
// instead of OpLogicalNot.
fn logical_not(
    builder: &mut dr::Builder,
    result_type: spirv::Word,
    result_id: Option<spirv::Word>,
    operand: spirv::Word,
) -> Result<spirv::Word, dr::Error> {
    let const_true = builder.constant_true(result_type, None);
    let const_false = builder.constant_false(result_type, None);
    // select(x, false, true) == !x
    builder.select(result_type, result_id, operand, const_false, const_true)
}
+
+// HACK ALERT
+// For some reason IGC fails linking if the value and shift size are of different type
+fn insert_shift_hack(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ offset_var: spirv::Word,
+ size_of: usize,
+) -> Result<spirv::Word, TranslateError> {
+ let result_type = match size_of {
+ 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16),
+ 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64),
+ 4 => return Ok(offset_var),
+ _ => return Err(error_unreachable()),
+ };
+ Ok(builder.u_convert(result_type.0, None, offset_var)?)
+}
+
+/// Emits SPIR-V for the PTX `cvt` instruction, dispatching on the conversion
+/// mode precomputed by the parser (`dets.mode`).
+fn emit_cvt(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    opencl: spirv::Word,
+    dets: &ast::CvtDetails,
+    arg: &ast::CvtArgs<SpirvWord>,
+) -> Result<(), TranslateError> {
+    match dets.mode {
+        // Integer widening/narrowing/bitcasting is funneled through the
+        // generic implicit-conversion machinery rather than emitted inline.
+        ptx_parser::CvtMode::SignExtend => {
+            let cv = ImplicitConversion {
+                src: arg.src,
+                dst: arg.dst,
+                from_type: dets.from.into(),
+                from_space: ast::StateSpace::Reg,
+                to_type: dets.to.into(),
+                to_space: ast::StateSpace::Reg,
+                kind: ConversionKind::SignExtend,
+            };
+            emit_implicit_conversion(builder, map, &cv)?;
+        }
+        ptx_parser::CvtMode::ZeroExtend
+        | ptx_parser::CvtMode::Truncate
+        | ptx_parser::CvtMode::Bitcast => {
+            let cv = ImplicitConversion {
+                src: arg.src,
+                dst: arg.dst,
+                from_type: dets.from.into(),
+                from_space: ast::StateSpace::Reg,
+                to_type: dets.to.into(),
+                to_space: ast::StateSpace::Reg,
+                kind: ConversionKind::Default,
+            };
+            emit_implicit_conversion(builder, map, &cv)?;
+        }
+        // Saturating integer conversions map directly to dedicated SPIR-V ops.
+        ptx_parser::CvtMode::SaturateUnsignedToSigned => {
+            let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
+            builder.sat_convert_u_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?;
+        }
+        ptx_parser::CvtMode::SaturateSignedToUnsigned => {
+            let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
+            builder.sat_convert_s_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?;
+        }
+        // NOTE(review): flush_to_zero is accepted but ignored in this arm —
+        // confirm ftz semantics are handled (or rejected) elsewhere.
+        ptx_parser::CvtMode::FPExtend { flush_to_zero } => {
+            let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
+            builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?;
+        }
+        // NOTE(review): flush_to_zero is ignored here as well.
+        ptx_parser::CvtMode::FPTruncate {
+            rounding,
+            flush_to_zero,
+        } => {
+            let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
+            builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?;
+            emit_rounding_decoration(builder, arg.dst, Some(rounding));
+        }
+        // Float -> float rounding to an integral value: map the PTX rounding
+        // mode onto the OpenCL extended-instruction set.
+        ptx_parser::CvtMode::FPRound {
+            integer_rounding,
+            flush_to_zero,
+        } => {
+            if flush_to_zero == Some(true) {
+                todo!()
+            }
+            let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
+            match integer_rounding {
+                Some(ast::RoundingMode::NearestEven) => {
+                    builder.ext_inst(
+                        result_type.0,
+                        Some(arg.dst.0),
+                        opencl,
+                        spirv::CLOp::rint as u32,
+                        [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
+                    )?;
+                }
+                Some(ast::RoundingMode::Zero) => {
+                    builder.ext_inst(
+                        result_type.0,
+                        Some(arg.dst.0),
+                        opencl,
+                        spirv::CLOp::trunc as u32,
+                        [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
+                    )?;
+                }
+                Some(ast::RoundingMode::NegativeInf) => {
+                    builder.ext_inst(
+                        result_type.0,
+                        Some(arg.dst.0),
+                        opencl,
+                        spirv::CLOp::floor as u32,
+                        [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
+                    )?;
+                }
+                Some(ast::RoundingMode::PositiveInf) => {
+                    builder.ext_inst(
+                        result_type.0,
+                        Some(arg.dst.0),
+                        opencl,
+                        spirv::CLOp::ceil as u32,
+                        [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
+                    )?;
+                }
+                // No integer rounding requested: same-type cvt is a plain copy.
+                None => {
+                    builder.copy_object(result_type.0, Some(arg.dst.0), arg.src.0)?;
+                }
+            }
+        }
+        // Float <-> integer conversions; rounding is recorded as a decoration
+        // on the result id rather than encoded in the opcode.
+        ptx_parser::CvtMode::SignedFromFP {
+            rounding,
+            flush_to_zero,
+        } => {
+            let dest_t: ast::ScalarType = dets.to.into();
+            let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
+            builder.convert_f_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?;
+            emit_rounding_decoration(builder, arg.dst, Some(rounding));
+        }
+        ptx_parser::CvtMode::UnsignedFromFP {
+            rounding,
+            flush_to_zero,
+        } => {
+            let dest_t: ast::ScalarType = dets.to.into();
+            let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
+            builder.convert_f_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?;
+            emit_rounding_decoration(builder, arg.dst, Some(rounding));
+        }
+        ptx_parser::CvtMode::FPFromSigned(rounding) => {
+            let dest_t: ast::ScalarType = dets.to.into();
+            let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
+            builder.convert_s_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?;
+            emit_rounding_decoration(builder, arg.dst, Some(rounding));
+        }
+        ptx_parser::CvtMode::FPFromUnsigned(rounding) => {
+            let dest_t: ast::ScalarType = dets.to.into();
+            let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
+            builder.convert_u_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?;
+            emit_rounding_decoration(builder, arg.dst, Some(rounding));
+        }
+    }
+    Ok(())
+}
+
+/// Emits SPIR-V for unsigned integer `mad` (multiply-add).
+///
+/// `mad.lo` is lowered to an explicit multiply followed by an add; `mad.hi`
+/// maps to the OpenCL `u_mad_hi` extended instruction. `mad.wide` is not yet
+/// supported.
+fn emit_mad_uint(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    opencl: spirv::Word,
+    type_: ast::ScalarType,
+    control: ast::MulIntControl,
+    arg: &ast::MadArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+    // `type_` is already an `ast::ScalarType`; the previous
+    // `ast::ScalarType::from(type_)` round-trip was a redundant identity
+    // conversion (emit_mad_sint already uses the direct form).
+    let inst_type = map.get_or_add(builder, SpirvType::from(type_)).0;
+    match control {
+        ast::MulIntControl::Low => {
+            // dst = src3 + (src1 * src2), keeping the low half of the product.
+            let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?;
+            builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?;
+        }
+        ast::MulIntControl::High => {
+            builder.ext_inst(
+                inst_type,
+                Some(arg.dst.0),
+                opencl,
+                spirv::CLOp::u_mad_hi as spirv::Word,
+                [
+                    dr::Operand::IdRef(arg.src1.0),
+                    dr::Operand::IdRef(arg.src2.0),
+                    dr::Operand::IdRef(arg.src3.0),
+                ]
+                .iter()
+                .cloned(),
+            )?;
+        }
+        ast::MulIntControl::Wide => todo!(),
+    };
+    Ok(())
+}
+
+/// Emits SPIR-V for signed integer `mad`: `mad.lo` becomes multiply + add,
+/// `mad.hi` maps to OpenCL `s_mad_hi`; `mad.wide` is not yet supported.
+fn emit_mad_sint(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    opencl: spirv::Word,
+    type_: ast::ScalarType,
+    control: ast::MulIntControl,
+    arg: &ast::MadArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+    let inst_type = map.get_or_add(builder, SpirvType::from(type_)).0;
+    match control {
+        ast::MulIntControl::Low => {
+            // dst = src3 + (src1 * src2), keeping the low half of the product.
+            let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?;
+            builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?;
+        }
+        ast::MulIntControl::High => {
+            builder.ext_inst(
+                inst_type,
+                Some(arg.dst.0),
+                opencl,
+                spirv::CLOp::s_mad_hi as spirv::Word,
+                [
+                    dr::Operand::IdRef(arg.src1.0),
+                    dr::Operand::IdRef(arg.src2.0),
+                    dr::Operand::IdRef(arg.src3.0),
+                ]
+                .iter()
+                .cloned(),
+            )?;
+        }
+        ast::MulIntControl::Wide => todo!(),
+    };
+    Ok(())
+}
+
+/// Emits SPIR-V for floating-point `mad` via the OpenCL `mad` extended
+/// instruction (dst = src1 * src2 + src3).
+fn emit_mad_float(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    opencl: spirv::Word,
+    desc: &ast::ArithFloat,
+    arg: &ast::MadArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+    let inst_type = map
+        .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
+        .0;
+    builder.ext_inst(
+        inst_type,
+        Some(arg.dst.0),
+        opencl,
+        spirv::CLOp::mad as spirv::Word,
+        [
+            dr::Operand::IdRef(arg.src1.0),
+            dr::Operand::IdRef(arg.src2.0),
+            dr::Operand::IdRef(arg.src3.0),
+        ]
+        .iter()
+        .cloned(),
+    )?;
+    Ok(())
+}
+
+/// Emits SPIR-V for floating-point `fma` via the OpenCL `fma` extended
+/// instruction (fused multiply-add, single rounding).
+fn emit_fma_float(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    opencl: spirv::Word,
+    desc: &ast::ArithFloat,
+    arg: &ast::FmaArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+    let inst_type = map
+        .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
+        .0;
+    builder.ext_inst(
+        inst_type,
+        Some(arg.dst.0),
+        opencl,
+        spirv::CLOp::fma as spirv::Word,
+        [
+            dr::Operand::IdRef(arg.src1.0),
+            dr::Operand::IdRef(arg.src2.0),
+            dr::Operand::IdRef(arg.src3.0),
+        ]
+        .iter()
+        .cloned(),
+    )?;
+    Ok(())
+}
+
+/// Emits SPIR-V for integer `sub`. Saturating subtraction is not yet
+/// implemented and panics via `todo!()`.
+fn emit_sub_int(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    typ: ast::ScalarType,
+    saturate: bool,
+    arg: &ast::SubArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+    if saturate {
+        todo!()
+    }
+    let inst_type = map
+        .get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)))
+        .0;
+    builder.i_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
+    Ok(())
+}
+
+/// Emits SPIR-V for floating-point `sub`, attaching the requested rounding
+/// mode (if any) as a decoration on the result.
+fn emit_sub_float(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    desc: &ast::ArithFloat,
+    arg: &ast::SubArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+    let inst_type = map
+        .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
+        .0;
+    builder.f_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
+    emit_rounding_decoration(builder, arg.dst, desc.rounding);
+    Ok(())
+}
+
+/// Emits SPIR-V for `min` by selecting the signed/unsigned/float variant of
+/// the OpenCL min extended instruction.
+fn emit_min(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    opencl: spirv::Word,
+    desc: &ast::MinMaxDetails,
+    arg: &ast::MinArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+    let cl_op = match desc {
+        ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min,
+        ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
+        ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
+    };
+    let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_()));
+    builder.ext_inst(
+        inst_type.0,
+        Some(arg.dst.0),
+        opencl,
+        cl_op as spirv::Word,
+        [
+            dr::Operand::IdRef(arg.src1.0),
+            dr::Operand::IdRef(arg.src2.0),
+        ]
+        .iter()
+        .cloned(),
+    )?;
+    Ok(())
+}
+
+/// Emits SPIR-V for `max` by selecting the signed/unsigned/float variant of
+/// the OpenCL max extended instruction (mirror of `emit_min`).
+fn emit_max(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    opencl: spirv::Word,
+    desc: &ast::MinMaxDetails,
+    arg: &ast::MaxArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+    let cl_op = match desc {
+        ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max,
+        ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
+        ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
+    };
+    let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_()));
+    builder.ext_inst(
+        inst_type.0,
+        Some(arg.dst.0),
+        opencl,
+        cl_op as spirv::Word,
+        [
+            dr::Operand::IdRef(arg.src1.0),
+            dr::Operand::IdRef(arg.src2.0),
+        ]
+        .iter()
+        .cloned(),
+    )?;
+    Ok(())
+}
+
+/// Emits SPIR-V for `rcp` (reciprocal). The approximate form uses OpenCL
+/// `native_recip`; the IEEE-compliant form is emitted as `1.0 / x` with a
+/// rounding decoration and FPFastMathMode(ALLOW_RECIP) on the result.
+fn emit_rcp(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    opencl: spirv::Word,
+    desc: &ast::RcpData,
+    arg: &ast::RcpArgs<SpirvWord>,
+) -> Result<(), TranslateError> {
+    // Everything narrower than f64 is computed in f32.
+    let is_f64 = desc.type_ == ast::ScalarType::F64;
+    let (instr_type, constant) = if is_f64 {
+        (ast::ScalarType::F64, vec_repr(1.0f64))
+    } else {
+        (ast::ScalarType::F32, vec_repr(1.0f32))
+    };
+    let result_type = map.get_or_add_scalar(builder, instr_type);
+    let rounding = match desc.kind {
+        ptx_parser::RcpKind::Approx => {
+            builder.ext_inst(
+                result_type.0,
+                Some(arg.dst.0),
+                opencl,
+                spirv::CLOp::native_recip as u32,
+                [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
+            )?;
+            return Ok(());
+        }
+        ptx_parser::RcpKind::Compliant(rounding) => rounding,
+    };
+    let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?;
+    builder.f_div(result_type.0, Some(arg.dst.0), one.0, arg.src.0)?;
+    emit_rounding_decoration(builder, arg.dst, Some(rounding));
+    builder.decorate(
+        arg.dst.0,
+        spirv::Decoration::FPFastMathMode,
+        [dr::Operand::FPFastMathMode(
+            spirv::FPFastMathMode::ALLOW_RECIP,
+        )]
+        .iter()
+        .cloned(),
+    );
+    Ok(())
+}
+
+/// Emits SPIR-V for `atom` by dispatching the PTX atomic op to the matching
+/// rspirv builder method. Inc/dec wrap variants are handled earlier by a
+/// separate pass and are unreachable here; float min/max are unimplemented.
+fn emit_atom(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    details: &ast::AtomDetails,
+    arg: &ast::AtomArgs<SpirvWord>,
+) -> Result<(), TranslateError> {
+    // All selected builder methods share the same signature, so the chosen
+    // one can be stored as a plain fn value and invoked uniformly below.
+    let spirv_op = match details.op {
+        ptx_parser::AtomicOp::And => dr::Builder::atomic_and,
+        ptx_parser::AtomicOp::Or => dr::Builder::atomic_or,
+        ptx_parser::AtomicOp::Xor => dr::Builder::atomic_xor,
+        ptx_parser::AtomicOp::Exchange => dr::Builder::atomic_exchange,
+        ptx_parser::AtomicOp::Add => dr::Builder::atomic_i_add,
+        ptx_parser::AtomicOp::IncrementWrap | ptx_parser::AtomicOp::DecrementWrap => {
+            return Err(error_unreachable())
+        }
+        ptx_parser::AtomicOp::SignedMin => dr::Builder::atomic_s_min,
+        ptx_parser::AtomicOp::UnsignedMin => dr::Builder::atomic_u_min,
+        ptx_parser::AtomicOp::SignedMax => dr::Builder::atomic_s_max,
+        ptx_parser::AtomicOp::UnsignedMax => dr::Builder::atomic_u_max,
+        ptx_parser::AtomicOp::FloatAdd => dr::Builder::atomic_f_add_ext,
+        ptx_parser::AtomicOp::FloatMin => todo!(),
+        ptx_parser::AtomicOp::FloatMax => todo!(),
+    };
+    let result_type = map.get_or_add(builder, SpirvType::new(details.type_.clone()));
+    // Scope and memory-semantics operands must be constant ids in SPIR-V.
+    let memory_const = map.get_or_add_constant(
+        builder,
+        &ast::Type::Scalar(ast::ScalarType::U32),
+        &vec_repr(scope_to_spirv(details.scope) as u32),
+    )?;
+    let semantics_const = map.get_or_add_constant(
+        builder,
+        &ast::Type::Scalar(ast::ScalarType::U32),
+        &vec_repr(semantics_to_spirv(details.semantics).bits()),
+    )?;
+    spirv_op(
+        builder,
+        result_type.0,
+        Some(arg.dst.0),
+        arg.src1.0,
+        memory_const.0,
+        semantics_const.0,
+        arg.src2.0,
+    )?;
+    Ok(())
+}
+
+/// Maps a PTX memory scope onto the corresponding SPIR-V scope; cluster
+/// scope is not yet supported.
+fn scope_to_spirv(this: ast::MemScope) -> spirv::Scope {
+    match this {
+        ast::MemScope::Cta => spirv::Scope::Workgroup,
+        ast::MemScope::Gpu => spirv::Scope::Device,
+        ast::MemScope::Sys => spirv::Scope::CrossDevice,
+        ptx_parser::MemScope::Cluster => todo!(),
+    }
+}
+
+/// Maps PTX atomic ordering semantics onto SPIR-V memory-semantics flags.
+fn semantics_to_spirv(this: ast::AtomSemantics) -> spirv::MemorySemantics {
+    match this {
+        ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED,
+        ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE,
+        ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE,
+        ast::AtomSemantics::AcqRel => spirv::MemorySemantics::ACQUIRE_RELEASE,
+    }
+}
+
+/// Decorates the result of a float division according to the `div` variant:
+/// `approx` allows reciprocal fast-math, `rn/rz/...` records the rounding
+/// mode, and `full` needs no decoration.
+fn emit_float_div_decoration(builder: &mut dr::Builder, dst: SpirvWord, kind: ast::DivFloatKind) {
+    match kind {
+        ast::DivFloatKind::Approx => {
+            builder.decorate(
+                dst.0,
+                spirv::Decoration::FPFastMathMode,
+                [dr::Operand::FPFastMathMode(
+                    spirv::FPFastMathMode::ALLOW_RECIP,
+                )]
+                .iter()
+                .cloned(),
+            );
+        }
+        ast::DivFloatKind::Rounding(rnd) => {
+            emit_rounding_decoration(builder, dst, Some(rnd));
+        }
+        ast::DivFloatKind::ApproxFull => {}
+    }
+}
+
+/// Emits SPIR-V for `sqrt` via the OpenCL `sqrt` extended instruction,
+/// decorating the result with the rounding mode for the compliant form.
+fn emit_sqrt(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    opencl: spirv::Word,
+    details: &ast::RcpData,
+    a: &ast::SqrtArgs<SpirvWord>,
+) -> Result<(), TranslateError> {
+    let result_type = map.get_or_add_scalar(builder, details.type_.into());
+    // NOTE(review): both arms select CLOp::sqrt — the approximate form could
+    // presumably use native_sqrt; confirm whether this is intentional.
+    let (ocl_op, rounding) = match details.kind {
+        ast::RcpKind::Approx => (spirv::CLOp::sqrt, None),
+        ast::RcpKind::Compliant(rnd) => (spirv::CLOp::sqrt, Some(rnd)),
+    };
+    builder.ext_inst(
+        result_type.0,
+        Some(a.dst.0),
+        opencl,
+        ocl_op as spirv::Word,
+        [dr::Operand::IdRef(a.src.0)].iter().cloned(),
+    )?;
+    emit_rounding_decoration(builder, a.dst, rounding);
+    Ok(())
+}
+
+// TODO: check what kind of assembly do we emit
+/// Emits a logical XOR as `(a || b) && !(a && b)`, since SPIR-V has no
+/// OpLogicalXor (and OpLogicalNot is avoided — see `logical_not`).
+fn emit_logical_xor_spirv(
+    builder: &mut dr::Builder,
+    result_type: spirv::Word,
+    result_id: Option<spirv::Word>,
+    op1: spirv::Word,
+    op2: spirv::Word,
+) -> Result<spirv::Word, dr::Error> {
+    let temp_or = builder.logical_or(result_type, None, op1, op2)?;
+    let temp_and = builder.logical_and(result_type, None, op1, op2)?;
+    let temp_neg = logical_not(builder, result_type, None, temp_and)?;
+    builder.logical_and(result_type, result_id, temp_or, temp_neg)
+}
+
+/// Emits the load of a variable into a register.
+///
+/// `details.member_index` selects one of three shapes:
+/// * `Some((i, Some(width)))` — load the whole `width`-element vector, then
+///   extract component `i`;
+/// * `Some((i, None))` — access-chain into member `i` and load just that
+///   element;
+/// * `None` — plain scalar/aggregate load.
+fn emit_load_var(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    details: &LoadVarDetails,
+) -> Result<(), TranslateError> {
+    let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone()));
+    match details.member_index {
+        Some((index, Some(width))) => {
+            // Vector-load-then-extract path: details.typ is the scalar element
+            // type; rebuild the full vector type to load from.
+            let vector_type = match details.typ {
+                ast::Type::Scalar(scalar_t) => ast::Type::Vector(width, scalar_t),
+                _ => return Err(error_mismatched_type()),
+            };
+            let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
+            let vector_temp = builder.load(
+                vector_type_spirv.0,
+                None,
+                details.arg.src.0,
+                None,
+                iter::empty(),
+            )?;
+            builder.composite_extract(
+                result_type.0,
+                Some(details.arg.dst.0),
+                vector_temp,
+                [index as u32].iter().copied(),
+            )?;
+        }
+        Some((index, None)) => {
+            // Pointer-indexing path: compute &src[index] and load through it.
+            let result_ptr_type = map.get_or_add(
+                builder,
+                SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function),
+            );
+            let index_spirv = map.get_or_add_constant(
+                builder,
+                &ast::Type::Scalar(ast::ScalarType::U32),
+                &vec_repr(index as u32),
+            )?;
+            let src = builder.in_bounds_access_chain(
+                result_ptr_type.0,
+                None,
+                details.arg.src.0,
+                [index_spirv.0].iter().copied(),
+            )?;
+            builder.load(
+                result_type.0,
+                Some(details.arg.dst.0),
+                src,
+                None,
+                iter::empty(),
+            )?;
+        }
+        None => {
+            builder.load(
+                result_type.0,
+                Some(details.arg.dst.0),
+                details.arg.src.0,
+                None,
+                iter::empty(),
+            )?;
+        }
+    };
+    Ok(())
+}
+
+/// Decomposes an `ast::Type` into a flat `TypeParts` record (kind, scalar
+/// kind, width, state space, component list). Inverse of `type_from_parts`.
+fn to_parts(this: &ast::Type) -> TypeParts {
+    match this {
+        ast::Type::Scalar(scalar) => TypeParts {
+            kind: TypeKind::Scalar,
+            state_space: ast::StateSpace::Reg,
+            scalar_kind: scalar.kind(),
+            width: scalar.size_of(),
+            components: Vec::new(),
+        },
+        ast::Type::Vector(components, scalar) => TypeParts {
+            kind: TypeKind::Vector,
+            state_space: ast::StateSpace::Reg,
+            scalar_kind: scalar.kind(),
+            width: scalar.size_of(),
+            // Vectors store their element count as the single component entry.
+            components: vec![*components as u32],
+        },
+        ast::Type::Array(_, scalar, components) => TypeParts {
+            kind: TypeKind::Array,
+            state_space: ast::StateSpace::Reg,
+            scalar_kind: scalar.kind(),
+            width: scalar.size_of(),
+            components: components.clone(),
+        },
+        ast::Type::Pointer(scalar, space) => TypeParts {
+            kind: TypeKind::Pointer,
+            // Only pointers carry a meaningful state space.
+            state_space: *space,
+            scalar_kind: scalar.kind(),
+            width: scalar.size_of(),
+            components: Vec::new(),
+        },
+    }
+}
+
+/// Reassembles an `ast::Type` from a `TypeParts` record. Inverse of
+/// `to_parts`; note the array's first field is reconstructed as `None`.
+fn type_from_parts(t: TypeParts) -> ast::Type {
+    match t.kind {
+        TypeKind::Scalar => ast::Type::Scalar(scalar_from_parts(t.width, t.scalar_kind)),
+        TypeKind::Vector => ast::Type::Vector(
+            t.components[0] as u8,
+            scalar_from_parts(t.width, t.scalar_kind),
+        ),
+        TypeKind::Array => ast::Type::Array(
+            None,
+            scalar_from_parts(t.width, t.scalar_kind),
+            t.components,
+        ),
+        TypeKind::Pointer => {
+            ast::Type::Pointer(scalar_from_parts(t.width, t.scalar_kind), t.state_space)
+        }
+    }
+}
+
+/// Flat, field-by-field decomposition of an `ast::Type`, produced by
+/// `to_parts` and consumed by `type_from_parts`.
+#[derive(Eq, PartialEq, Clone)]
+struct TypeParts {
+    kind: TypeKind,
+    scalar_kind: ast::ScalarKind,
+    // Scalar width in bytes.
+    width: u8,
+    // Meaningful only for pointers; Reg otherwise.
+    state_space: ast::StateSpace,
+    // Vector length (one entry) or array dimensions; empty for scalars/pointers.
+    components: Vec<u32>,
+}
+
+/// Discriminant for the four `ast::Type` shapes tracked by `TypeParts`.
+#[derive(Eq, PartialEq, Copy, Clone)]
+enum TypeKind {
+    Scalar,
+    Vector,
+    Array,
+    Pointer,
+}
diff --git a/ptx/src/pass/expand_arguments.rs b/ptx/src/pass/expand_arguments.rs
new file mode 100644
index 0000000..d0c7c98
--- /dev/null
+++ b/ptx/src/pass/expand_arguments.rs
@@ -0,0 +1,181 @@
+use super::*;
+use ptx_parser as ast;
+
+/// Pass entry point: flattens compound operands (immediates, reg+offset,
+/// vector members) in each instruction into plain register ids, emitting the
+/// auxiliary statements they need before/after the instruction.
+pub(super) fn run<'a, 'b>(
+    func: Vec<TypedStatement>,
+    id_def: &'b mut MutableNumericIdResolver<'a>,
+) -> Result<Vec<ExpandedStatement>, TranslateError> {
+    let mut result = Vec::with_capacity(func.len());
+    for s in func {
+        match s {
+            // Statements with no operand arguments pass through unchanged.
+            Statement::Label(id) => result.push(Statement::Label(id)),
+            Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
+            Statement::LoadVar(details) => result.push(Statement::LoadVar(details)),
+            Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
+            Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
+            Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
+            Statement::Constant(c) => result.push(Statement::Constant(c)),
+            Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)),
+            s => {
+                // The visitor pushes pre-statements into `result` directly and
+                // collects post-statements to append after the instruction.
+                let (new_statement, post_stmts) = {
+                    let mut visitor = FlattenArguments::new(&mut result, id_def);
+                    (s.visit_map(&mut visitor)?, visitor.post_stmts)
+                };
+                result.push(new_statement);
+                result.extend(post_stmts);
+            }
+        }
+    }
+    Ok(result)
+}
+
+/// Operand-flattening visitor: rewrites each compound operand into a plain
+/// register, appending helper statements to `func` (before the instruction)
+/// and `post_stmts` (after it).
+struct FlattenArguments<'a, 'b> {
+    func: &'b mut Vec<ExpandedStatement>,
+    id_def: &'b mut MutableNumericIdResolver<'a>,
+    post_stmts: Vec<ExpandedStatement>,
+}
+
+impl<'a, 'b> FlattenArguments<'a, 'b> {
+    fn new(
+        func: &'b mut Vec<ExpandedStatement>,
+        id_def: &'b mut MutableNumericIdResolver<'a>,
+    ) -> Self {
+        FlattenArguments {
+            func,
+            id_def,
+            post_stmts: Vec::new(),
+        }
+    }
+
+    // A bare register needs no rewriting.
+    fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
+        Ok(name)
+    }
+
+    /// Rewrites `reg+offset`: for register-space operands the offset is added
+    /// with an integer `add`; for memory operands it becomes a `PtrAccess`.
+    fn reg_offset(
+        &mut self,
+        reg: SpirvWord,
+        offset: i32,
+        type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+        _is_dst: bool,
+    ) -> Result<SpirvWord, TranslateError> {
+        let (type_, state_space) = if let Some((type_, state_space)) = type_space {
+            (type_, state_space)
+        } else {
+            return Err(TranslateError::UntypedSymbol);
+        };
+        if state_space == ast::StateSpace::Reg || state_space == ast::StateSpace::Sreg {
+            let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
+            if !space_is_compatible(reg_space, ast::StateSpace::Reg) {
+                return Err(error_mismatched_type());
+            }
+            let reg_scalar_type = match reg_type {
+                ast::Type::Scalar(underlying_type) => underlying_type,
+                _ => return Err(error_mismatched_type()),
+            };
+            // Materialize the offset as a constant of the register's type...
+            let id_constant_stmt = self
+                .id_def
+                .register_intermediate(reg_type.clone(), ast::StateSpace::Reg);
+            self.func.push(Statement::Constant(ConstantDefinition {
+                dst: id_constant_stmt,
+                typ: reg_scalar_type,
+                value: ast::ImmediateValue::S64(offset as i64),
+            }));
+            // NOTE(review): both arms below build identical ArithInteger
+            // values; they could be collapsed into a single pattern.
+            let arith_details = match reg_scalar_type.kind() {
+                ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
+                    type_: reg_scalar_type,
+                    saturate: false,
+                }),
+                ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
+                    ast::ArithDetails::Integer(ast::ArithInteger {
+                        type_: reg_scalar_type,
+                        saturate: false,
+                    })
+                }
+                _ => return Err(error_unreachable()),
+            };
+            // ...then emit `add dst, reg, offset` and use dst in its place.
+            let id_add_result = self.id_def.register_intermediate(reg_type, state_space);
+            self.func
+                .push(Statement::Instruction(ast::Instruction::Add {
+                    data: arith_details,
+                    arguments: ast::AddArgs {
+                        dst: id_add_result,
+                        src1: reg,
+                        src2: id_constant_stmt,
+                    },
+                }));
+            Ok(id_add_result)
+        } else {
+            // Memory operand: byte offsets are always S64 here.
+            let id_constant_stmt = self.id_def.register_intermediate(
+                ast::Type::Scalar(ast::ScalarType::S64),
+                ast::StateSpace::Reg,
+            );
+            self.func.push(Statement::Constant(ConstantDefinition {
+                dst: id_constant_stmt,
+                typ: ast::ScalarType::S64,
+                value: ast::ImmediateValue::S64(offset as i64),
+            }));
+            let dst = self
+                .id_def
+                .register_intermediate(type_.clone(), state_space);
+            self.func.push(Statement::PtrAccess(PtrAccess {
+                underlying_type: type_.clone(),
+                state_space: state_space,
+                dst,
+                ptr_src: reg,
+                offset_src: id_constant_stmt,
+            }));
+            Ok(dst)
+        }
+    }
+
+    /// Rewrites an immediate operand into a fresh register defined by a
+    /// `Constant` statement.
+    fn immediate(
+        &mut self,
+        value: ast::ImmediateValue,
+        type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+    ) -> Result<SpirvWord, TranslateError> {
+        let (scalar_t, state_space) =
+            if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
+                (*scalar, state_space)
+            } else {
+                return Err(TranslateError::UntypedSymbol);
+            };
+        let id = self
+            .id_def
+            .register_intermediate(ast::Type::Scalar(scalar_t), state_space);
+        self.func.push(Statement::Constant(ConstantDefinition {
+            dst: id,
+            typ: scalar_t,
+            value,
+        }));
+        Ok(id)
+    }
+}
+
+/// Visitor glue: routes each operand kind to the matching flattening helper.
+/// Vector members must already have been lowered by an earlier pass.
+impl<'a, 'b> ast::VisitorMap<TypedOperand, SpirvWord, TranslateError> for FlattenArguments<'a, 'b> {
+    fn visit(
+        &mut self,
+        args: TypedOperand,
+        type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+        is_dst: bool,
+        _relaxed_type_check: bool,
+    ) -> Result<SpirvWord, TranslateError> {
+        match args {
+            TypedOperand::Reg(r) => self.reg(r),
+            TypedOperand::Imm(x) => self.immediate(x, type_space),
+            TypedOperand::RegOffset(reg, offset) => {
+                self.reg_offset(reg, offset, type_space, is_dst)
+            }
+            // Vector members are expanded before this pass runs.
+            TypedOperand::VecMember(..) => Err(error_unreachable()),
+        }
+    }
+
+    fn visit_ident(
+        &mut self,
+        name: <TypedOperand as ptx_parser::Operand>::Ident,
+        _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+        _is_dst: bool,
+        _relaxed_type_check: bool,
+    ) -> Result<<SpirvWord as ptx_parser::Operand>::Ident, TranslateError> {
+        self.reg(name)
+    }
+}
diff --git a/ptx/src/pass/extract_globals.rs b/ptx/src/pass/extract_globals.rs
new file mode 100644
index 0000000..680a5ee
--- /dev/null
+++ b/ptx/src/pass/extract_globals.rs
@@ -0,0 +1,282 @@
+use super::*;
+
+/// Pass entry point: splits function-local statements from module-level
+/// (shared/global) variable declarations, and rewrites instructions with no
+/// SPIR-V equivalent (bfe/bfi/brev/activemask and some atomics) into calls
+/// to zluda_ptx_impl helper functions registered in `ptx_impl_imports`.
+pub(super) fn run<'input, 'b>(
+    sorted_statements: Vec<ExpandedStatement>,
+    ptx_impl_imports: &mut HashMap<String, Directive>,
+    id_def: &mut NumericIdResolver,
+) -> Result<(Vec<ExpandedStatement>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
+    let mut local = Vec::with_capacity(sorted_statements.len());
+    let mut global = Vec::new();
+    for statement in sorted_statements {
+        match statement {
+            // Shared/global variables are hoisted out to module scope.
+            Statement::Variable(
+                var @ ast::Variable {
+                    state_space: ast::StateSpace::Shared,
+                    ..
+                },
+            )
+            | Statement::Variable(
+                var @ ast::Variable {
+                    state_space: ast::StateSpace::Global,
+                    ..
+                },
+            ) => global.push(var),
+            Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => {
+                let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", scalar_to_ptx_name(data)].concat();
+                local.push(instruction_to_fn_call(
+                    id_def,
+                    ptx_impl_imports,
+                    ast::Instruction::Bfe { data, arguments },
+                    fn_name,
+                )?);
+            }
+            Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => {
+                let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", scalar_to_ptx_name(data)].concat();
+                local.push(instruction_to_fn_call(
+                    id_def,
+                    ptx_impl_imports,
+                    ast::Instruction::Bfi { data, arguments },
+                    fn_name,
+                )?);
+            }
+            Statement::Instruction(ast::Instruction::Brev { data, arguments }) => {
+                let fn_name: String =
+                    [ZLUDA_PTX_PREFIX, "brev_", scalar_to_ptx_name(data)].concat();
+                local.push(instruction_to_fn_call(
+                    id_def,
+                    ptx_impl_imports,
+                    ast::Instruction::Brev { data, arguments },
+                    fn_name,
+                )?);
+            }
+            Statement::Instruction(ast::Instruction::Activemask { arguments }) => {
+                let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat();
+                local.push(instruction_to_fn_call(
+                    id_def,
+                    ptx_impl_imports,
+                    ast::Instruction::Activemask { arguments },
+                    fn_name,
+                )?);
+            }
+            // atom.inc / atom.dec / atom.add.<float> have no direct SPIR-V
+            // lowering here; the helper name encodes semantics, scope, space
+            // (and element type for float add).
+            Statement::Instruction(ast::Instruction::Atom {
+                data:
+                    data @ ast::AtomDetails {
+                        op: ast::AtomicOp::IncrementWrap,
+                        semantics,
+                        scope,
+                        space,
+                        ..
+                    },
+                arguments,
+            }) => {
+                let fn_name = [
+                    ZLUDA_PTX_PREFIX,
+                    "atom_",
+                    semantics_to_ptx_name(semantics),
+                    "_",
+                    scope_to_ptx_name(scope),
+                    "_",
+                    space_to_ptx_name(space),
+                    "_inc",
+                ]
+                .concat();
+                local.push(instruction_to_fn_call(
+                    id_def,
+                    ptx_impl_imports,
+                    ast::Instruction::Atom { data, arguments },
+                    fn_name,
+                )?);
+            }
+            Statement::Instruction(ast::Instruction::Atom {
+                data:
+                    data @ ast::AtomDetails {
+                        op: ast::AtomicOp::DecrementWrap,
+                        semantics,
+                        scope,
+                        space,
+                        ..
+                    },
+                arguments,
+            }) => {
+                let fn_name = [
+                    ZLUDA_PTX_PREFIX,
+                    "atom_",
+                    semantics_to_ptx_name(semantics),
+                    "_",
+                    scope_to_ptx_name(scope),
+                    "_",
+                    space_to_ptx_name(space),
+                    "_dec",
+                ]
+                .concat();
+                local.push(instruction_to_fn_call(
+                    id_def,
+                    ptx_impl_imports,
+                    ast::Instruction::Atom { data, arguments },
+                    fn_name,
+                )?);
+            }
+            Statement::Instruction(ast::Instruction::Atom {
+                data:
+                    data @ ast::AtomDetails {
+                        op: ast::AtomicOp::FloatAdd,
+                        semantics,
+                        scope,
+                        space,
+                        ..
+                    },
+                arguments,
+            }) => {
+                let scalar_type = match data.type_ {
+                    ptx_parser::Type::Scalar(scalar) => scalar,
+                    _ => return Err(error_unreachable()),
+                };
+                let fn_name = [
+                    ZLUDA_PTX_PREFIX,
+                    "atom_",
+                    semantics_to_ptx_name(semantics),
+                    "_",
+                    scope_to_ptx_name(scope),
+                    "_",
+                    space_to_ptx_name(space),
+                    "_add_",
+                    scalar_to_ptx_name(scalar_type),
+                ]
+                .concat();
+                local.push(instruction_to_fn_call(
+                    id_def,
+                    ptx_impl_imports,
+                    ast::Instruction::Atom { data, arguments },
+                    fn_name,
+                )?);
+            }
+            // Everything else stays in the function body untouched.
+            s => local.push(s),
+        }
+    }
+    Ok((local, global))
+}
+
+/// Converts an instruction into a `call` to an external helper function of
+/// the given name, registering that helper in `ptx_impl_imports`. Operands
+/// are partitioned into return (dst) and input (src) arguments.
+fn instruction_to_fn_call(
+    id_defs: &mut NumericIdResolver,
+    ptx_impl_imports: &mut HashMap<String, Directive>,
+    inst: ast::Instruction<SpirvWord>,
+    fn_name: String,
+) -> Result<ExpandedStatement, TranslateError> {
+    // Collect every operand with its type/space/direction; the instruction
+    // itself is discarded afterwards (visit_map result is ignored).
+    let mut arguments = Vec::new();
+    ast::visit_map(inst, &mut |operand,
+                               type_space: Option<(
+        &ast::Type,
+        ast::StateSpace,
+    )>,
+                               is_dst,
+                               _| {
+        let (typ, space) = match type_space {
+            Some((typ, space)) => (typ.clone(), space),
+            None => return Err(error_unreachable()),
+        };
+        arguments.push((operand, is_dst, typ, space));
+        Ok(SpirvWord(0))
+    })?;
+    // Assumes all dst operands precede the first src operand — TODO confirm
+    // this invariant holds for every rewritten instruction.
+    // NOTE(review): `desc` is unused; `_` would silence the warning.
+    let return_arguments_count = arguments
+        .iter()
+        .position(|(desc, is_dst, _, _)| !is_dst)
+        .unwrap_or(arguments.len());
+    let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
+    let fn_id = register_external_fn_call(
+        id_defs,
+        ptx_impl_imports,
+        fn_name,
+        return_arguments
+            .iter()
+            .map(|(_, _, typ, state)| (typ, *state)),
+        input_arguments
+            .iter()
+            .map(|(_, _, typ, state)| (typ, *state)),
+    )?;
+    Ok(Statement::Instruction(ast::Instruction::Call {
+        data: ast::CallDetails {
+            uniform: false,
+            return_arguments: return_arguments
+                .iter()
+                .map(|(_, _, typ, state)| (typ.clone(), *state))
+                .collect::<Vec<_>>(),
+            input_arguments: input_arguments
+                .iter()
+                .map(|(_, _, typ, state)| (typ.clone(), *state))
+                .collect::<Vec<_>>(),
+        },
+        arguments: ast::CallArgs {
+            return_arguments: return_arguments
+                .iter()
+                .map(|(name, _, _, _)| *name)
+                .collect::<Vec<_>>(),
+            func: fn_id,
+            input_arguments: input_arguments
+                .iter()
+                .map(|(name, _, _, _)| *name)
+                .collect::<Vec<_>>(),
+        },
+    }))
+}
+
+/// Returns the PTX spelling of a scalar type, used to build helper-function
+/// name suffixes.
+fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
+    match this {
+        ast::ScalarType::B8 => "b8",
+        ast::ScalarType::B16 => "b16",
+        ast::ScalarType::B32 => "b32",
+        ast::ScalarType::B64 => "b64",
+        ast::ScalarType::B128 => "b128",
+        ast::ScalarType::U8 => "u8",
+        ast::ScalarType::U16 => "u16",
+        ast::ScalarType::U16x2 => "u16x2",
+        ast::ScalarType::U32 => "u32",
+        ast::ScalarType::U64 => "u64",
+        ast::ScalarType::S8 => "s8",
+        ast::ScalarType::S16 => "s16",
+        ast::ScalarType::S16x2 => "s16x2",
+        ast::ScalarType::S32 => "s32",
+        ast::ScalarType::S64 => "s64",
+        ast::ScalarType::F16 => "f16",
+        ast::ScalarType::F16x2 => "f16x2",
+        ast::ScalarType::F32 => "f32",
+        ast::ScalarType::F64 => "f64",
+        ast::ScalarType::BF16 => "bf16",
+        ast::ScalarType::BF16x2 => "bf16x2",
+        ast::ScalarType::Pred => "pred",
+    }
+}
+
+/// Returns the PTX spelling of atomic ordering semantics, for helper names.
+fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str {
+    match this {
+        ast::AtomSemantics::Relaxed => "relaxed",
+        ast::AtomSemantics::Acquire => "acquire",
+        ast::AtomSemantics::Release => "release",
+        ast::AtomSemantics::AcqRel => "acq_rel",
+    }
+}
+
+/// Returns the PTX spelling of a memory scope, for helper names.
+fn scope_to_ptx_name(this: ast::MemScope) -> &'static str {
+    match this {
+        ast::MemScope::Cta => "cta",
+        ast::MemScope::Gpu => "gpu",
+        ast::MemScope::Sys => "sys",
+        ast::MemScope::Cluster => "cluster",
+    }
+}
+
+/// Returns the PTX spelling of a state space, for helper names.
+fn space_to_ptx_name(this: ast::StateSpace) -> &'static str {
+    match this {
+        ast::StateSpace::Generic => "generic",
+        ast::StateSpace::Global => "global",
+        ast::StateSpace::Shared => "shared",
+        ast::StateSpace::Reg => "reg",
+        ast::StateSpace::Const => "const",
+        ast::StateSpace::Local => "local",
+        ast::StateSpace::Param => "param",
+        ast::StateSpace::Sreg => "sreg",
+        ast::StateSpace::SharedCluster => "shared_cluster",
+        ast::StateSpace::ParamEntry => "param_entry",
+        ast::StateSpace::SharedCta => "shared_cta",
+        ast::StateSpace::ParamFunc => "param_func",
+    }
+}
diff --git a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs
new file mode 100644
index 0000000..c029016
--- /dev/null
+++ b/ptx/src/pass/fix_special_registers.rs
@@ -0,0 +1,130 @@
+use super::*;
+use std::collections::HashMap;
+
+/// Pass entry point: replaces reads of special registers (tid, ntid, ...)
+/// with calls to external getter functions. The resolver may emit extra
+/// statements into `result` before each rewritten statement is pushed.
+pub(super) fn run<'a, 'b, 'input>(
+    ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
+    typed_statements: Vec<TypedStatement>,
+    numeric_id_defs: &'a mut NumericIdResolver<'b>,
+) -> Result<Vec<TypedStatement>, TranslateError> {
+    let result = Vec::with_capacity(typed_statements.len());
+    let mut sreg_sresolver = SpecialRegisterResolver {
+        ptx_impl_imports,
+        numeric_id_defs,
+        result,
+    };
+    for statement in typed_statements {
+        let statement = statement.visit_map(&mut sreg_sresolver)?;
+        sreg_sresolver.result.push(statement);
+    }
+    Ok(sreg_sresolver.result)
+}
+
+/// Visitor state for the special-register pass: owns the output statement
+/// buffer plus the import table and id resolver used to create getter calls.
+struct SpecialRegisterResolver<'a, 'b, 'input> {
+    ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
+    numeric_id_defs: &'a mut NumericIdResolver<'b>,
+    result: Vec<TypedStatement>,
+}
+
+/// Visitor glue: funnels every operand/ident through `replace_sreg`,
+/// forwarding the vector index for `VecMember`-style operands.
+impl<'a, 'b, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
+    for SpecialRegisterResolver<'a, 'b, 'input>
+{
+    fn visit(
+        &mut self,
+        operand: TypedOperand,
+        _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+        is_dst: bool,
+        _relaxed_type_check: bool,
+    ) -> Result<TypedOperand, TranslateError> {
+        operand.map(|name, vector_index| self.replace_sreg(name, is_dst, vector_index))
+    }
+
+    fn visit_ident(
+        &mut self,
+        args: SpirvWord,
+        _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+        is_dst: bool,
+        _relaxed_type_check: bool,
+    ) -> Result<SpirvWord, TranslateError> {
+        self.replace_sreg(args, is_dst, None)
+    }
+}
+
+impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
+    /// If `name` refers to a special register, emits a call to the matching
+    /// external getter (passing the vector component index as a u8 argument
+    /// when applicable) and returns the call's result id; otherwise returns
+    /// `name` unchanged. Writing to a special register is an error.
+    fn replace_sreg(
+        &mut self,
+        name: SpirvWord,
+        is_dst: bool,
+        vector_index: Option<u8>,
+    ) -> Result<SpirvWord, TranslateError> {
+        if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) {
+            if is_dst {
+                return Err(error_mismatched_type())
+            }
+            let input_arguments = match (vector_index, sreg.get_function_input_type()) {
+                (Some(idx), Some(inp_type)) => {
+                    // Getters taking a component index expect exactly a u8.
+                    if inp_type != ast::ScalarType::U8 {
+                        return Err(TranslateError::Unreachable);
+                    }
+                    let constant = self.numeric_id_defs.register_intermediate(Some((
+                        ast::Type::Scalar(inp_type),
+                        ast::StateSpace::Reg,
+                    )));
+                    self.result.push(Statement::Constant(ConstantDefinition {
+                        dst: constant,
+                        typ: inp_type,
+                        value: ast::ImmediateValue::U64(idx as u64),
+                    }));
+                    vec![(
+                        TypedOperand::Reg(constant),
+                        ast::Type::Scalar(inp_type),
+                        ast::StateSpace::Reg,
+                    )]
+                }
+                (None, None) => Vec::new(),
+                // A vector index without a parameter (or vice versa) is a bug.
+                _ => return Err(error_mismatched_type()),
+            };
+            let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
+            let return_type = sreg.get_function_return_type();
+            let fn_result = self.numeric_id_defs.register_intermediate(Some((
+                ast::Type::Scalar(return_type),
+                ast::StateSpace::Reg,
+            )));
+            let return_arguments = vec![(
+                fn_result,
+                ast::Type::Scalar(return_type),
+                ast::StateSpace::Reg,
+            )];
+            // NOTE(review): `ocl_fn_name` is already a String — `.to_string()`
+            // makes a redundant copy.
+            let fn_call = register_external_fn_call(
+                self.numeric_id_defs,
+                self.ptx_impl_imports,
+                ocl_fn_name.to_string(),
+                return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
+                input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
+            )?;
+            let data = ast::CallDetails {
+                uniform: false,
+                return_arguments: return_arguments
+                    .iter()
+                    .map(|(_, typ, space)| (typ.clone(), *space))
+                    .collect(),
+                input_arguments: input_arguments
+                    .iter()
+                    .map(|(_, typ, space)| (typ.clone(), *space))
+                    .collect(),
+            };
+            let arguments = ast::CallArgs {
+                return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
+                func: fn_call,
+                input_arguments: input_arguments.iter().map(|(name, _, _)| *name).collect(),
+            };
+            self.result
+                .push(Statement::Instruction(ast::Instruction::Call {
+                    data,
+                    arguments,
+                }));
+            Ok(fn_result)
+        } else {
+            Ok(name)
+        }
+    }
+}
diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs
new file mode 100644
index 0000000..25e80f0
--- /dev/null
+++ b/ptx/src/pass/insert_implicit_conversions.rs
@@ -0,0 +1,432 @@
+use std::mem;
+
+use super::*;
+use ptx_parser as ast;
+
+/*
+ There are several kinds of implicit conversions in PTX:
+ * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
+ * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
+ - ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
+ semantics are to first zext/chop/bitcast `y` as needed and then do
+ documented special ld/st/cvt conversion rules for destination operands
+ - st.param [x] y (used as function return arguments) same rule as above applies
+ - generic/global ld: for instruction `ld x, [y]`, y must be of type
+ b64/u64/s64, which is bitcast to a pointer, dereferenced and then
+ documented special ld/st/cvt conversion rules are applied to dst
+ - generic/global st: for instruction `st [x], y`, x must be of type
+ b64/u64/s64, which is bitcast to a pointer
+*/
+/// Pass entry point: inserts the implicit conversions required by PTX typing
+/// rules around the operands of every statement that carries typed operands
+/// (instructions, pointer accesses, vector repacks). All other statement
+/// kinds are passed through unchanged.
+pub(super) fn run(
+    func: Vec<ExpandedStatement>,
+    id_def: &mut MutableNumericIdResolver,
+) -> Result<Vec<ExpandedStatement>, TranslateError> {
+    let mut result = Vec::with_capacity(func.len());
+    for s in func.into_iter() {
+        match s {
+            Statement::Instruction(inst) => {
+                insert_implicit_conversions_impl(
+                    &mut result,
+                    id_def,
+                    Statement::Instruction(inst),
+                )?;
+            }
+            Statement::PtrAccess(access) => {
+                insert_implicit_conversions_impl(
+                    &mut result,
+                    id_def,
+                    Statement::PtrAccess(access),
+                )?;
+            }
+            Statement::RepackVector(repack) => {
+                insert_implicit_conversions_impl(
+                    &mut result,
+                    id_def,
+                    Statement::RepackVector(repack),
+                )?;
+            }
+            // Statements with no typed operands cannot need conversions.
+            s @ Statement::Conditional(_)
+            | s @ Statement::Conversion(_)
+            | s @ Statement::Label(_)
+            | s @ Statement::Constant(_)
+            | s @ Statement::Variable(_)
+            | s @ Statement::LoadVar(..)
+            | s @ Statement::StoreVar(..)
+            | s @ Statement::RetValue(..)
+            | s @ Statement::FunctionPointer(..) => result.push(s),
+        }
+    }
+    Ok(result)
+}
+
+/// Visits every operand of `stmt`. When the operand's recorded type/space
+/// differs from what the instruction expects, emits an `ImplicitConversion`
+/// through a fresh intermediate id: source-operand conversions are pushed
+/// before the statement, destination-operand conversions after it (collected
+/// in `post_conv`).
+fn insert_implicit_conversions_impl(
+    func: &mut Vec<ExpandedStatement>,
+    id_def: &mut MutableNumericIdResolver,
+    stmt: ExpandedStatement,
+) -> Result<(), TranslateError> {
+    let mut post_conv = Vec::new();
+    let statement = stmt.visit_map::<SpirvWord, TranslateError>(
+        &mut |operand,
+              type_state: Option<(&ast::Type, ast::StateSpace)>,
+              is_dst,
+              relaxed_type_check| {
+            // Operands with no type requirement are left untouched.
+            let (instr_type, instruction_space) = match type_state {
+                None => return Ok(operand),
+                Some(t) => t,
+            };
+            let (operand_type, operand_space) = id_def.get_typed(operand)?;
+            // Relaxed ld/st/cvt rules differ between source and destination
+            // operands; everything else uses the default (strict) rule.
+            let conversion_fn = if relaxed_type_check {
+                if is_dst {
+                    should_convert_relaxed_dst_wrapper
+                } else {
+                    should_convert_relaxed_src_wrapper
+                }
+            } else {
+                default_implicit_conversion
+            };
+            match conversion_fn(
+                (operand_space, &operand_type),
+                (instruction_space, instr_type),
+            )? {
+                Some(conv_kind) => {
+                    let conv_output = if is_dst { &mut post_conv } else { &mut *func };
+                    // Initialized for the dst direction (instr -> operand);
+                    // the swap below flips everything for the src direction.
+                    let mut from_type = instr_type.clone();
+                    let mut from_space = instruction_space;
+                    let mut to_type = operand_type;
+                    let mut to_space = operand_space;
+                    let mut src =
+                        id_def.register_intermediate(instr_type.clone(), instruction_space);
+                    let mut dst = operand;
+                    // The rewritten statement always refers to the fresh
+                    // intermediate, captured here before any swap.
+                    let result = Ok::<_, TranslateError>(src);
+                    if !is_dst {
+                        mem::swap(&mut src, &mut dst);
+                        mem::swap(&mut from_type, &mut to_type);
+                        mem::swap(&mut from_space, &mut to_space);
+                    }
+                    conv_output.push(Statement::Conversion(ImplicitConversion {
+                        src,
+                        dst,
+                        from_type,
+                        from_space,
+                        to_type,
+                        to_space,
+                        kind: conv_kind,
+                    }));
+                    result
+                }
+                None => Ok(operand),
+            }
+        },
+    )?;
+    func.push(statement);
+    func.append(&mut post_conv);
+    Ok(())
+}
+
+/// Default (strict) conversion rule: decides whether moving a value between
+/// the operand's and the instruction's type/space needs a conversion and of
+/// which kind. `Ok(None)` means no conversion is required.
+pub(crate) fn default_implicit_conversion(
+    (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+    (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+    if instruction_space == ast::StateSpace::Reg {
+        if space_is_compatible(operand_space, ast::StateSpace::Reg) {
+            // A vector register may be reinterpreted as a bit-typed scalar of
+            // exactly the same total width.
+            if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
+                (operand_type, instruction_type)
+            {
+                if scalar.kind() == ast::ScalarKind::Bit
+                    && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
+                {
+                    return Ok(Some(ConversionKind::Default));
+                }
+            }
+        } else if is_addressable(operand_space) {
+            // An addressable variable used where a register is expected:
+            // take its address.
+            return Ok(Some(ConversionKind::AddressOf));
+        }
+    }
+    if !space_is_compatible(instruction_space, operand_space) {
+        default_implicit_conversion_space(
+            (operand_space, operand_type),
+            (instruction_space, instruction_type),
+        )
+    } else if instruction_type != operand_type {
+        default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
+    } else {
+        Ok(None)
+    }
+}
+
+/// True for state spaces whose variables live in memory and thus have an
+/// address that can be taken.
+fn is_addressable(this: ast::StateSpace) -> bool {
+    match this {
+        ast::StateSpace::Const
+        | ast::StateSpace::Generic
+        | ast::StateSpace::Global
+        | ast::StateSpace::Local
+        | ast::StateSpace::Shared => true,
+        ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false,
+        // Not handled by this pass yet.
+        ast::StateSpace::SharedCluster
+        | ast::StateSpace::SharedCta
+        | ast::StateSpace::ParamEntry
+        | ast::StateSpace::ParamFunc => todo!(),
+    }
+}
+
+// Space is different: decide how to bridge two incompatible state spaces.
+fn default_implicit_conversion_space(
+    (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+    (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+    // Either side being generic while the other coerces to generic is a plain
+    // pointer-to-pointer conversion.
+    if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
+        || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
+    {
+        Ok(Some(ConversionKind::PtrToPtr))
+    } else if space_is_compatible(operand_space, ast::StateSpace::Reg) {
+        match operand_type {
+            // Register already holds a pointer into the instruction's space;
+            // convert only if the pointee type differs.
+            ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
+                if *operand_ptr_space == instruction_space =>
+            {
+                if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
+                    Ok(Some(ConversionKind::PtrToPtr))
+                } else {
+                    Ok(None)
+                }
+            }
+            // 64-bit integer registers may be bit-cast to a pointer in any
+            // addressable space.
+            // TODO: 32 bit
+            ast::Type::Scalar(ast::ScalarType::B64)
+            | ast::Type::Scalar(ast::ScalarType::U64)
+            | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
+                ast::StateSpace::Global
+                | ast::StateSpace::Generic
+                | ast::StateSpace::Const
+                | ast::StateSpace::Local
+                | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
+                _ => Err(error_mismatched_type()),
+            },
+            // 32-bit integers may only become pointers into const/local/shared.
+            ast::Type::Scalar(ast::ScalarType::B32)
+            | ast::Type::Scalar(ast::ScalarType::U32)
+            | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
+                ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
+                    Ok(Some(ConversionKind::BitToPtr))
+                }
+                _ => Err(error_mismatched_type()),
+            },
+            _ => Err(error_mismatched_type()),
+        }
+    } else if space_is_compatible(instruction_space, ast::StateSpace::Reg) {
+        // Mirror case: the instruction expects a register-held pointer into
+        // the operand's space.
+        match instruction_type {
+            ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
+                if operand_space == *instruction_ptr_space =>
+            {
+                if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
+                    Ok(Some(ConversionKind::PtrToPtr))
+                } else {
+                    Ok(None)
+                }
+            }
+            _ => Err(error_mismatched_type()),
+        }
+    } else {
+        Err(error_mismatched_type())
+    }
+}
+
+// Space is same, but type is different
+fn default_implicit_conversion_type(
+    space: ast::StateSpace,
+    operand_type: &ast::Type,
+    instruction_type: &ast::Type,
+) -> Result<Option<ConversionKind>, TranslateError> {
+    if space_is_compatible(space, ast::StateSpace::Reg) {
+        // Register values admit only the same-size bitcasts allowed by
+        // `should_bitcast`.
+        // NOTE(review): this returns `TranslateError::MismatchedType` directly
+        // while sibling functions use `error_mismatched_type()` — presumably
+        // equivalent; worth unifying.
+        if should_bitcast(instruction_type, operand_type) {
+            Ok(Some(ConversionKind::Default))
+        } else {
+            Err(TranslateError::MismatchedType)
+        }
+    } else {
+        // In memory spaces, reinterpret through a pointer conversion.
+        Ok(Some(ConversionKind::PtrToPtr))
+    }
+}
+
+/// True for state spaces whose pointers implicitly coerce to the generic
+/// address space.
+fn coerces_to_generic(this: ast::StateSpace) -> bool {
+    match this {
+        ast::StateSpace::Global
+        | ast::StateSpace::Const
+        | ast::StateSpace::Local
+        | ptx_parser::StateSpace::SharedCta
+        | ast::StateSpace::SharedCluster
+        | ast::StateSpace::Shared => true,
+        ast::StateSpace::Reg
+        | ast::StateSpace::Param
+        | ast::StateSpace::ParamEntry
+        | ast::StateSpace::ParamFunc
+        | ast::StateSpace::Generic
+        | ast::StateSpace::Sreg => false,
+    }
+}
+
+/// Auto-bitcast rule for same-size types: bit-typed scalars cast to/from
+/// anything non-bit, floats only to/from bit, and signed/unsigned integers
+/// to/from each other (or bit). Predicates never bitcast. Vectors and arrays
+/// defer to their element type.
+fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
+    match (instr, operand) {
+        (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
+            if inst.size_of() != operand.size_of() {
+                return false;
+            }
+            match inst.kind() {
+                ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
+                ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
+                ast::ScalarKind::Signed => {
+                    operand.kind() == ast::ScalarKind::Bit
+                        || operand.kind() == ast::ScalarKind::Unsigned
+                }
+                ast::ScalarKind::Unsigned => {
+                    operand.kind() == ast::ScalarKind::Bit
+                        || operand.kind() == ast::ScalarKind::Signed
+                }
+                ast::ScalarKind::Pred => false,
+            }
+        }
+        (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
+        | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
+            should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
+        }
+        _ => false,
+    }
+}
+
+/// Relaxed type-checking entry point for destination operands: spaces must
+/// already be compatible, identical types need no conversion, otherwise the
+/// relaxed destination table decides (or the mismatch is an error).
+pub(crate) fn should_convert_relaxed_dst_wrapper(
+    (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+    (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+    if !space_is_compatible(operand_space, instruction_space) {
+        return Err(TranslateError::MismatchedType);
+    }
+    if operand_type == instruction_type {
+        return Ok(None);
+    }
+    match should_convert_relaxed_dst(operand_type, instruction_type) {
+        conv @ Some(_) => Ok(conv),
+        None => Err(TranslateError::MismatchedType),
+    }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
+fn should_convert_relaxed_dst(
+    dst_type: &ast::Type,
+    instr_type: &ast::Type,
+) -> Option<ConversionKind> {
+    if dst_type == instr_type {
+        return None;
+    }
+    match (dst_type, instr_type) {
+        (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
+            // Bit-typed results fit any destination at least as wide.
+            ast::ScalarKind::Bit => {
+                if instr_type.size_of() <= dst_type.size_of() {
+                    Some(ConversionKind::Default)
+                } else {
+                    None
+                }
+            }
+            // Signed results sign-extend into wider non-float destinations.
+            ast::ScalarKind::Signed => {
+                if dst_type.kind() != ast::ScalarKind::Float {
+                    if instr_type.size_of() == dst_type.size_of() {
+                        Some(ConversionKind::Default)
+                    } else if instr_type.size_of() < dst_type.size_of() {
+                        Some(ConversionKind::SignExtend)
+                    } else {
+                        None
+                    }
+                } else {
+                    None
+                }
+            }
+            ast::ScalarKind::Unsigned => {
+                if instr_type.size_of() <= dst_type.size_of()
+                    && dst_type.kind() != ast::ScalarKind::Float
+                {
+                    Some(ConversionKind::Default)
+                } else {
+                    None
+                }
+            }
+            // Float results only go to a bit-typed destination of at least
+            // the same size.
+            ast::ScalarKind::Float => {
+                if instr_type.size_of() <= dst_type.size_of()
+                    && dst_type.kind() == ast::ScalarKind::Bit
+                {
+                    Some(ConversionKind::Default)
+                } else {
+                    None
+                }
+            }
+            ast::ScalarKind::Pred => None,
+        },
+        // Vectors and arrays follow their element type.
+        (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
+        | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
+            should_convert_relaxed_dst(
+                &ast::Type::Scalar(*dst_type),
+                &ast::Type::Scalar(*instr_type),
+            )
+        }
+        _ => None,
+    }
+}
+
+/// Relaxed type-checking entry point for source operands; mirrors the
+/// destination wrapper above.
+pub(crate) fn should_convert_relaxed_src_wrapper(
+    (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+    (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+    if !space_is_compatible(operand_space, instruction_space) {
+        return Err(error_mismatched_type());
+    }
+    if operand_type == instruction_type {
+        return Ok(None);
+    }
+    match should_convert_relaxed_src(operand_type, instruction_type) {
+        conv @ Some(_) => Ok(conv),
+        None => Err(error_mismatched_type()),
+    }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
+fn should_convert_relaxed_src(
+    src_type: &ast::Type,
+    instr_type: &ast::Type,
+) -> Option<ConversionKind> {
+    if src_type == instr_type {
+        return None;
+    }
+    match (src_type, instr_type) {
+        (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
+            ast::ScalarKind::Bit => {
+                if instr_type.size_of() <= src_type.size_of() {
+                    Some(ConversionKind::Default)
+                } else {
+                    None
+                }
+            }
+            // Unlike destinations, sources never sign-extend: the instruction
+            // simply reads the low bits of a wider non-float source.
+            ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
+                if instr_type.size_of() <= src_type.size_of()
+                    && src_type.kind() != ast::ScalarKind::Float
+                {
+                    Some(ConversionKind::Default)
+                } else {
+                    None
+                }
+            }
+            ast::ScalarKind::Float => {
+                if instr_type.size_of() <= src_type.size_of()
+                    && src_type.kind() == ast::ScalarKind::Bit
+                {
+                    Some(ConversionKind::Default)
+                } else {
+                    None
+                }
+            }
+            ast::ScalarKind::Pred => None,
+        },
+        // Element-wise rule for vectors and arrays.
+        (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
+        | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
+            should_convert_relaxed_src(
+                &ast::Type::Scalar(*dst_type),
+                &ast::Type::Scalar(*instr_type),
+            )
+        }
+        _ => None,
+    }
+}
diff --git a/ptx/src/pass/insert_mem_ssa_statements.rs b/ptx/src/pass/insert_mem_ssa_statements.rs
new file mode 100644
index 0000000..e314b05
--- /dev/null
+++ b/ptx/src/pass/insert_mem_ssa_statements.rs
@@ -0,0 +1,275 @@
+use super::*;
+use ptx_parser as ast;
+
+/*
+ How do we handle arguments:
+ - input .params in kernels
+ .param .b64 in_arg
+ get turned into this SPIR-V:
+ %1 = OpFunctionParameter %ulong
+ %2 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %1
+ We do this for two reasons. One, common treatment for argument-declared
+ .param variables and .param variables inside function (we assume that
+ at SPIR-V level every .param is a pointer in Function storage class)
+ - input .params in functions
+ .param .b64 in_arg
+ get turned into this SPIR-V:
+ %1 = OpFunctionParameter %_ptr_Function_ulong
+ - input .regs
+ .reg .b64 in_arg
+ get turned into the same SPIR-V as kernel .params:
+ %1 = OpFunctionParameter %ulong
+ %2 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %1
+ - output .regs
+ .reg .b64 out_arg
+ get just a variable declaration:
+ %2 = OpVariable %%_ptr_Function_ulong Function
+ - output .params don't exist, they have been moved to input positions
+ by an earlier pass
+ Distinguishing betweem kernel .params and function .params is not the
+ cleanest solution. Alternatively, we could "deparamize" all kernel .param
+ arguments by turning them into .reg arguments like this:
+ .param .b64 arg -> .reg ptr<.b64,.param> arg
+ This has the massive downside that this transformation would have to run
+ very early and would muddy up already difficult code. It's simpler to just
+ have an if here
+*/
+/// Runs the mem-SSA pass over one function: rewrites arguments according to
+/// the .param/.reg rules described in the module comment above, then rewrites
+/// each statement so register variables are accessed only through explicit
+/// LoadVar/StoreVar. A `ret` with a declared return value becomes a load of
+/// the return variable followed by `RetValue`.
+pub(super) fn run<'a, 'b>(
+    func: Vec<TypedStatement>,
+    id_def: &mut NumericIdResolver,
+    fn_decl: &'a mut ast::MethodDeclaration<'b, SpirvWord>,
+) -> Result<Vec<TypedStatement>, TranslateError> {
+    let mut result = Vec::with_capacity(func.len());
+    for arg in fn_decl.input_arguments.iter_mut() {
+        insert_mem_ssa_argument(
+            id_def,
+            &mut result,
+            arg,
+            matches!(fn_decl.name, ast::MethodName::Kernel(_)),
+        );
+    }
+    for arg in fn_decl.return_arguments.iter() {
+        insert_mem_ssa_argument_reg_return(&mut result, arg);
+    }
+    for s in func {
+        match s {
+            Statement::Instruction(inst) => match inst {
+                ast::Instruction::Ret { data } => {
+                    // TODO: handle multiple output args
+                    match &fn_decl.return_arguments[..] {
+                        [return_reg] => {
+                            // Load the return variable into a fresh id and
+                            // return that value.
+                            let new_id = id_def.register_intermediate(Some((
+                                return_reg.v_type.clone(),
+                                ast::StateSpace::Reg,
+                            )));
+                            result.push(Statement::LoadVar(LoadVarDetails {
+                                arg: ast::LdArgs {
+                                    dst: new_id,
+                                    src: return_reg.name,
+                                },
+                                typ: return_reg.v_type.clone(),
+                                member_index: None,
+                            }));
+                            result.push(Statement::RetValue(data, new_id));
+                        }
+                        [] => result.push(Statement::Instruction(ast::Instruction::Ret { data })),
+                        _ => unimplemented!(),
+                    }
+                }
+                inst => insert_mem_ssa_statement_default(
+                    id_def,
+                    &mut result,
+                    Statement::Instruction(inst),
+                )?,
+            },
+            Statement::Conditional(bra) => {
+                insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conditional(bra))?
+            }
+            Statement::Conversion(conv) => {
+                insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conversion(conv))?
+            }
+            Statement::PtrAccess(ptr_access) => insert_mem_ssa_statement_default(
+                id_def,
+                &mut result,
+                Statement::PtrAccess(ptr_access),
+            )?,
+            Statement::RepackVector(repack) => insert_mem_ssa_statement_default(
+                id_def,
+                &mut result,
+                Statement::RepackVector(repack),
+            )?,
+            Statement::FunctionPointer(func_ptr) => insert_mem_ssa_statement_default(
+                id_def,
+                &mut result,
+                Statement::FunctionPointer(func_ptr),
+            )?,
+            // These carry no register operands to rewrite.
+            s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => {
+                result.push(s)
+            }
+            // Any other statement kind should not exist at this stage.
+            _ => return Err(error_unreachable()),
+        }
+    }
+    Ok(result)
+}
+
+/// Rewrites an input argument: a register-space local variable is declared
+/// with the argument's original name, the incoming value (renamed to a fresh
+/// id) is stored into it, and the declaration's name is replaced by the fresh
+/// id. Function (non-kernel) `.param` arguments are left untouched — see the
+/// module comment above.
+fn insert_mem_ssa_argument(
+    id_def: &mut NumericIdResolver,
+    func: &mut Vec<TypedStatement>,
+    arg: &mut ast::Variable<SpirvWord>,
+    is_kernel: bool,
+) {
+    if !is_kernel && arg.state_space == ast::StateSpace::Param {
+        return;
+    }
+    let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
+    func.push(Statement::Variable(ast::Variable {
+        align: arg.align,
+        v_type: arg.v_type.clone(),
+        state_space: ast::StateSpace::Reg,
+        name: arg.name,
+        array_init: Vec::new(),
+    }));
+    func.push(Statement::StoreVar(StoreVarDetails {
+        arg: ast::StArgs {
+            src1: arg.name,
+            src2: new_id,
+        },
+        typ: arg.v_type.clone(),
+        member_index: None,
+    }));
+    arg.name = new_id;
+}
+
+/// Declares a local variable for a return register argument; the `ret`
+/// handling in `run` later loads it to produce the return value.
+fn insert_mem_ssa_argument_reg_return(
+    func: &mut Vec<TypedStatement>,
+    arg: &ast::Variable<SpirvWord>,
+) {
+    func.push(Statement::Variable(ast::Variable {
+        align: arg.align,
+        v_type: arg.v_type.clone(),
+        state_space: arg.state_space,
+        name: arg.name,
+        array_init: arg.array_init.clone(),
+    }));
+}
+
+/// Applies `InsertMemSSAVisitor` to one statement: loads are emitted before
+/// the rewritten statement, deferred stores (for destinations) after it.
+fn insert_mem_ssa_statement_default<'a, 'input>(
+    id_def: &'a mut NumericIdResolver<'input>,
+    func: &'a mut Vec<TypedStatement>,
+    stmt: TypedStatement,
+) -> Result<(), TranslateError> {
+    let mut visitor = InsertMemSSAVisitor {
+        id_def,
+        func,
+        post_statements: Vec::new(),
+    };
+    let new_stmt = stmt.visit_map(&mut visitor)?;
+    visitor.func.push(new_stmt);
+    visitor.func.extend(visitor.post_statements);
+    Ok(())
+}
+
+/// Operand visitor for the mem-SSA pass: replaces references to register
+/// variables with fresh intermediates, emitting LoadVar into `func` for
+/// sources and deferring StoreVar into `post_statements` for destinations.
+struct InsertMemSSAVisitor<'a, 'input> {
+    id_def: &'a mut NumericIdResolver<'input>,
+    func: &'a mut Vec<TypedStatement>,
+    post_statements: Vec<TypedStatement>,
+}
+
+impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
+    /// Rewrites one symbol. Symbols with no expected type, symbols outside
+    /// the register space, and non-variables pass through unchanged.
+    /// `member_index` selects a single element of a vector variable.
+    fn symbol(
+        &mut self,
+        symbol: SpirvWord,
+        member_index: Option<u8>,
+        expected: Option<(&ast::Type, ast::StateSpace)>,
+        is_dst: bool,
+    ) -> Result<SpirvWord, TranslateError> {
+        if expected.is_none() {
+            return Ok(symbol);
+        };
+        let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
+        if !space_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable {
+            return Ok(symbol);
+        };
+        let member_index = match member_index {
+            Some(idx) => {
+                // Element access: the intermediate carries the scalar element
+                // type, not the whole vector.
+                let vector_width = match var_type {
+                    ast::Type::Vector(width, scalar_t) => {
+                        var_type = ast::Type::Scalar(scalar_t);
+                        width
+                    }
+                    _ => return Err(error_mismatched_type()),
+                };
+                Some((
+                    idx,
+                    // The vector width is only recorded for special registers.
+                    if self.id_def.special_registers.get(symbol).is_some() {
+                        Some(vector_width)
+                    } else {
+                        None
+                    },
+                ))
+            }
+            None => None,
+        };
+        let generated_id = self
+            .id_def
+            .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg)));
+        if !is_dst {
+            self.func.push(Statement::LoadVar(LoadVarDetails {
+                arg: ast::LdArgs {
+                    dst: generated_id,
+                    src: symbol,
+                },
+                typ: var_type,
+                member_index,
+            }));
+        } else {
+            // Stores must run after the statement that produces the value.
+            self.post_statements
+                .push(Statement::StoreVar(StoreVarDetails {
+                    arg: ast::StArgs {
+                        src1: symbol,
+                        src2: generated_id,
+                    },
+                    typ: var_type,
+                    member_index: member_index.map(|(idx, _)| idx),
+                }));
+        }
+        Ok(generated_id)
+    }
+}
+
+impl<'a, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
+    for InsertMemSSAVisitor<'a, 'input>
+{
+    fn visit(
+        &mut self,
+        operand: TypedOperand,
+        type_space: Option<(&ast::Type, ast::StateSpace)>,
+        is_dst: bool,
+        _relaxed_type_check: bool,
+    ) -> Result<TypedOperand, TranslateError> {
+        Ok(match operand {
+            TypedOperand::Reg(reg) => {
+                TypedOperand::Reg(self.symbol(reg, None, type_space, is_dst)?)
+            }
+            TypedOperand::RegOffset(reg, offset) => {
+                TypedOperand::RegOffset(self.symbol(reg, None, type_space, is_dst)?, offset)
+            }
+            // Immediates are not variables; nothing to load or store.
+            op @ TypedOperand::Imm(..) => op,
+            // Element access collapses to a scalar register reference.
+            TypedOperand::VecMember(symbol, index) => {
+                TypedOperand::Reg(self.symbol(symbol, Some(index), type_space, is_dst)?)
+            }
+        })
+    }
+
+    fn visit_ident(
+        &mut self,
+        args: SpirvWord,
+        type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+        is_dst: bool,
+        relaxed_type_check: bool,
+    ) -> Result<SpirvWord, TranslateError> {
+        self.symbol(args, None, type_space, is_dst)
+    }
+}
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs
new file mode 100644
index 0000000..2be6297
--- /dev/null
+++ b/ptx/src/pass/mod.rs
@@ -0,0 +1,1677 @@
+use ptx_parser as ast;
+use rspirv::{binary::Assemble, dr};
+use std::hash::Hash;
+use std::num::NonZeroU8;
+use std::{
+ borrow::Cow,
+ cell::RefCell,
+ collections::{hash_map, HashMap, HashSet},
+ ffi::CString,
+ iter,
+ marker::PhantomData,
+ mem,
+ rc::Rc,
+};
+
+mod convert_dynamic_shared_memory_usage;
+mod convert_to_stateful_memory_access;
+mod convert_to_typed;
+mod emit_spirv;
+mod expand_arguments;
+mod extract_globals;
+mod fix_special_registers;
+mod insert_implicit_conversions;
+mod insert_mem_ssa_statements;
+mod normalize_identifiers;
+mod normalize_labels;
+mod normalize_predicates;
+
+static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
+static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
+const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__";
+
+/// Translates a parsed PTX module into a SPIR-V [`Module`], running the full
+/// pipeline: per-directive translation, global hoisting, dynamic shared
+/// memory conversion, variable-declaration normalization, denormal-mode
+/// analysis and final SPIR-V emission.
+pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
+    let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1));
+    let mut ptx_impl_imports = HashMap::new();
+    let directives = ast
+        .directives
+        .into_iter()
+        .filter_map(|directive| {
+            translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose()
+        })
+        .collect::<Result<Vec<_>, _>>()?;
+    let directives = hoist_function_globals(directives);
+    // Only link the zluda_ptx_impl runtime if something imported from it.
+    let must_link_ptx_impl = ptx_impl_imports.len() > 0;
+    // Imported implementation functions are emitted ahead of user directives.
+    let mut directives = ptx_impl_imports
+        .into_iter()
+        .map(|(_, v)| v)
+        .chain(directives.into_iter())
+        .collect::<Vec<_>>();
+    let mut builder = dr::Builder::new();
+    // Ids below `current_id` are already assigned by the string resolver.
+    builder.reserve_ids(id_defs.current_id().0);
+    let call_map = MethodsCallMap::new(&directives);
+    let mut directives =
+        convert_dynamic_shared_memory_usage::run(directives, &call_map, &mut || {
+            SpirvWord(builder.id())
+        })?;
+    normalize_variable_decls(&mut directives);
+    let denorm_information = compute_denorm_information(&directives);
+    let (spirv, kernel_info, build_options) =
+        emit_spirv::run(builder, &id_defs, call_map, denorm_information, directives)?;
+    Ok(Module {
+        spirv,
+        kernel_info,
+        should_link_ptx_impl: if must_link_ptx_impl {
+            Some((ZLUDA_PTX_IMPL_INTEL, ZLUDA_PTX_IMPL_AMD))
+        } else {
+            None
+        },
+        build_options,
+    })
+}
+
+/// Translates one top-level directive, renaming string identifiers to
+/// numeric `SpirvWord`s. Returns `Ok(None)` when the directive is absorbed
+/// elsewhere (e.g. a function redirected into `ptx_impl_imports`).
+fn translate_directive<'input, 'a>(
+    id_defs: &'a mut GlobalStringIdResolver<'input>,
+    ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
+    d: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
+) -> Result<Option<Directive<'input>>, TranslateError> {
+    Ok(match d {
+        ast::Directive::Variable(linking, var) => Some(Directive::Variable(
+            linking,
+            ast::Variable {
+                align: var.align,
+                v_type: var.v_type.clone(),
+                state_space: var.state_space,
+                name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
+                array_init: var.array_init,
+            },
+        )),
+        ast::Directive::Method(linkage, f) => {
+            translate_function(id_defs, ptx_impl_imports, linkage, f)?.map(Directive::Method)
+        }
+    })
+}
+
+/// A function as produced by the parser: string identifiers, parsed operands.
+type ParsedFunction<'a> = ast::Function<'a, &'a str, ast::Statement<ast::ParsedOperand<&'a str>>>;
+
+/// Translates one function or kernel into SSA form via `to_ssa`. The special
+/// declarations `__assertfail` and `vprintf` are renamed with the zluda
+/// prefix and moved into `ptx_impl_imports` (returning `None`) instead of
+/// being emitted as regular directives.
+fn translate_function<'input, 'a>(
+    id_defs: &'a mut GlobalStringIdResolver<'input>,
+    ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
+    linkage: ast::LinkingDirective,
+    f: ParsedFunction<'input>,
+) -> Result<Option<Function<'input>>, TranslateError> {
+    let import_as = match &f.func_directive {
+        ast::MethodDeclaration {
+            name: ast::MethodName::Func(func_name),
+            ..
+        } if *func_name == "__assertfail" || *func_name == "vprintf" => {
+            Some([ZLUDA_PTX_PREFIX, func_name].concat())
+        }
+        _ => None,
+    };
+    let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?;
+    let mut func = to_ssa(
+        ptx_impl_imports,
+        str_resolver,
+        fn_resolver,
+        fn_decl,
+        f.body,
+        f.tuning,
+        linkage,
+    )?;
+    func.import_as = import_as;
+    if func.import_as.is_some() {
+        ptx_impl_imports.insert(
+            func.import_as.as_ref().unwrap().clone(),
+            Directive::Method(func),
+        );
+        Ok(None)
+    } else {
+        Ok(Some(func))
+    }
+}
+
+/// Runs the whole per-function pass pipeline, taking a parsed body to its
+/// final SSA form: identifier normalization, predicate lowering, typing,
+/// special-register fixup, stateful-memory-access conversion, mem-SSA,
+/// argument expansion, implicit conversions, label normalization and global
+/// extraction. A declaration without a body short-circuits to an empty
+/// `Function`.
+fn to_ssa<'input, 'b>(
+    ptx_impl_imports: &'b mut HashMap<String, Directive<'input>>,
+    mut id_defs: FnStringIdResolver<'input, 'b>,
+    fn_defs: GlobalFnDeclResolver<'input, 'b>,
+    func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
+    f_body: Option<Vec<ast::Statement<ast::ParsedOperand<&'input str>>>>,
+    tuning: Vec<ast::TuningDirective>,
+    linkage: ast::LinkingDirective,
+) -> Result<Function<'input>, TranslateError> {
+    //deparamize_function_decl(&func_decl)?;
+    let f_body = match f_body {
+        Some(vec) => vec,
+        None => {
+            return Ok(Function {
+                func_decl: func_decl,
+                body: None,
+                globals: Vec::new(),
+                import_as: None,
+                tuning,
+                linkage,
+            })
+        }
+    };
+    let normalized_ids = normalize_identifiers::run(&mut id_defs, &fn_defs, f_body)?;
+    let mut numeric_id_defs = id_defs.finish();
+    let unadorned_statements = normalize_predicates::run(normalized_ids, &mut numeric_id_defs)?;
+    let typed_statements =
+        convert_to_typed::run(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
+    let typed_statements =
+        fix_special_registers::run(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
+    let (func_decl, typed_statements) =
+        convert_to_stateful_memory_access::run(func_decl, typed_statements, &mut numeric_id_defs)?;
+    let ssa_statements = insert_mem_ssa_statements::run(
+        typed_statements,
+        &mut numeric_id_defs,
+        &mut (*func_decl).borrow_mut(),
+    )?;
+    // Resolver changes representation between pass groups.
+    let mut numeric_id_defs = numeric_id_defs.finish();
+    let expanded_statements = expand_arguments::run(ssa_statements, &mut numeric_id_defs)?;
+    let expanded_statements =
+        insert_implicit_conversions::run(expanded_statements, &mut numeric_id_defs)?;
+    let mut numeric_id_defs = numeric_id_defs.unmut();
+    let labeled_statements = normalize_labels::run(expanded_statements, &mut numeric_id_defs);
+    let (f_body, globals) =
+        extract_globals::run(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
+    Ok(Function {
+        func_decl: func_decl,
+        globals: globals,
+        body: Some(f_body),
+        import_as: None,
+        tuning,
+        linkage,
+    })
+}
+
+/// Result of translating a PTX module: the SPIR-V module, per-kernel
+/// metadata, the optional implementation libraries to link (Intel SPIR-V,
+/// AMD bitcode) and compiler build options.
+pub struct Module {
+    pub spirv: dr::Module,
+    pub kernel_info: HashMap<String, KernelInfo>,
+    pub should_link_ptx_impl: Option<(&'static [u8], &'static [u8])>,
+    pub build_options: CString,
+}
+
+impl Module {
+    /// Serializes the SPIR-V module into its binary word stream.
+    pub fn assemble(&self) -> Vec<u32> {
+        self.spirv.assemble()
+    }
+}
+
+/// Module-level mapping between PTX string identifiers and numeric
+/// `SpirvWord` ids, plus the recorded type/space of each id, special-register
+/// usage and per-function signature mappers.
+struct GlobalStringIdResolver<'input> {
+    current_id: SpirvWord,
+    variables: HashMap<Cow<'input, str>, SpirvWord>,
+    reverse_variables: HashMap<SpirvWord, &'input str>,
+    variables_type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
+    special_registers: SpecialRegistersMap,
+    fns: HashMap<SpirvWord, FnSigMapper<'input>>,
+}
+
+impl<'input> GlobalStringIdResolver<'input> {
+    /// Creates an empty resolver; `start_id` is the first id to hand out.
+    fn new(start_id: SpirvWord) -> Self {
+        Self {
+            current_id: start_id,
+            variables: HashMap::new(),
+            reverse_variables: HashMap::new(),
+            variables_type_check: HashMap::new(),
+            special_registers: SpecialRegistersMap::new(),
+            fns: HashMap::new(),
+        }
+    }
+
+    /// Resolves `id`, assigning a fresh numeric id on first use; records no
+    /// type information.
+    fn get_or_add_def(&mut self, id: &'input str) -> SpirvWord {
+        self.get_or_add_impl(id, None)
+    }
+
+    /// Like `get_or_add_def`, but also records the identifier's type, state
+    /// space and whether it denotes a variable.
+    fn get_or_add_def_typed(
+        &mut self,
+        id: &'input str,
+        typ: ast::Type,
+        state_space: ast::StateSpace,
+        is_variable: bool,
+    ) -> SpirvWord {
+        self.get_or_add_impl(id, Some((typ, state_space, is_variable)))
+    }
+
+    fn get_or_add_impl(
+        &mut self,
+        id: &'input str,
+        typ: Option<(ast::Type, ast::StateSpace, bool)>,
+    ) -> SpirvWord {
+        let id = match self.variables.entry(Cow::Borrowed(id)) {
+            hash_map::Entry::Occupied(e) => *(e.get()),
+            hash_map::Entry::Vacant(e) => {
+                let numeric_id = self.current_id;
+                e.insert(numeric_id);
+                self.reverse_variables.insert(numeric_id, id);
+                self.current_id.0 += 1;
+                numeric_id
+            }
+        };
+        // Later definitions overwrite earlier type information for this id.
+        self.variables_type_check.insert(id, typ);
+        id
+    }
+
+    /// Looks up an already-defined identifier; unknown names are an error.
+    fn get_id(&self, id: &str) -> Result<SpirvWord, TranslateError> {
+        self.variables
+            .get(id)
+            .copied()
+            .ok_or_else(error_unknown_symbol)
+    }
+
+    /// First id not yet assigned (used to reserve SPIR-V ids).
+    fn current_id(&self) -> SpirvWord {
+        self.current_id
+    }
+
+    /// Opens a function scope: renames the function's parameters, builds the
+    /// renamed declaration and returns the per-function string resolver, the
+    /// global declaration resolver and the shared declaration handle.
+    fn start_fn<'b>(
+        &'b mut self,
+        header: &'b ast::MethodDeclaration<'input, &'input str>,
+    ) -> Result<
+        (
+            FnStringIdResolver<'input, 'b>,
+            GlobalFnDeclResolver<'input, 'b>,
+            Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
+        ),
+        TranslateError,
+    > {
+        // In case a function decl was inserted earlier we want to use its id
+        let name_id = self.get_or_add_def(header.name());
+        let mut fn_resolver = FnStringIdResolver {
+            current_id: &mut self.current_id,
+            global_variables: &self.variables,
+            global_type_check: &self.variables_type_check,
+            special_registers: &mut self.special_registers,
+            // One initial scope for the function's own parameters.
+            variables: vec![HashMap::new(); 1],
+            type_check: HashMap::new(),
+        };
+        let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments);
+        let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments);
+        // Kernel names stay as strings; function names use the numeric id.
+        let name = match header.name {
+            ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
+            ast::MethodName::Func(_) => ast::MethodName::Func(name_id),
+        };
+        let fn_decl = ast::MethodDeclaration {
+            return_arguments,
+            name,
+            input_arguments,
+            shared_mem: None,
+        };
+        // Non-kernel functions additionally get a signature remapped to the
+        // SPIR-V calling convention, registered for later call resolution.
+        let new_fn_decl = if !matches!(fn_decl.name, ast::MethodName::Kernel(_)) {
+            let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl);
+            let new_fn_decl = resolver.func_decl.clone();
+            self.fns.insert(name_id, resolver);
+            new_fn_decl
+        } else {
+            Rc::new(RefCell::new(fn_decl))
+        };
+        Ok((
+            fn_resolver,
+            GlobalFnDeclResolver { fns: &self.fns },
+            new_fn_decl,
+        ))
+    }
+}
+
+/// Renames each parameter through the function-scope resolver, preserving
+/// its type, state space, alignment and initializer.
+fn rename_fn_params<'a, 'b>(
+    fn_resolver: &mut FnStringIdResolver<'a, 'b>,
+    args: &'b [ast::Variable<&'a str>],
+) -> Vec<ast::Variable<SpirvWord>> {
+    args.iter()
+        .map(|a| ast::Variable {
+            name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true),
+            v_type: a.v_type.clone(),
+            state_space: a.state_space,
+            align: a.align,
+            array_init: a.array_init.clone(),
+        })
+        .collect()
+}
+
+/// Per-kernel metadata exposed to the runtime: argument sizes (with a flag —
+/// presumably "is pointer"; TODO confirm against consumers) and whether the
+/// kernel uses shared memory.
+pub struct KernelInfo {
+    pub arguments_sizes: Vec<(usize, bool)>,
+    pub uses_shared_mem: bool,
+}
+
+/// PTX special registers recognized by the translator; reads of these are
+/// lowered to calls into the zluda_ptx_impl runtime.
+#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
+enum PtxSpecialRegister {
+    Tid,
+    Ntid,
+    Ctaid,
+    Nctaid,
+    Clock,
+    LanemaskLt,
+}
+
+impl PtxSpecialRegister {
+    /// Parses a special-register name (with the leading `%`); `None` for
+    /// unrecognized names.
+    fn try_parse(s: &str) -> Option<Self> {
+        match s {
+            "%tid" => Some(Self::Tid),
+            "%ntid" => Some(Self::Ntid),
+            "%ctaid" => Some(Self::Ctaid),
+            "%nctaid" => Some(Self::Nctaid),
+            "%clock" => Some(Self::Clock),
+            "%lanemask_lt" => Some(Self::LanemaskLt),
+            _ => None,
+        }
+    }
+
+    /// Type of the register as seen by PTX code: the grid-index registers are
+    /// 4-element vectors, the rest scalars.
+    fn get_type(self) -> ast::Type {
+        match self {
+            PtxSpecialRegister::Tid
+            | PtxSpecialRegister::Ntid
+            | PtxSpecialRegister::Ctaid
+            | PtxSpecialRegister::Nctaid => ast::Type::Vector(4, self.get_function_return_type()),
+            _ => ast::Type::Scalar(self.get_function_return_type()),
+        }
+    }
+
+    /// Scalar return type of the runtime function implementing the register.
+    fn get_function_return_type(self) -> ast::ScalarType {
+        match self {
+            PtxSpecialRegister::Tid => ast::ScalarType::U32,
+            PtxSpecialRegister::Ntid => ast::ScalarType::U32,
+            PtxSpecialRegister::Ctaid => ast::ScalarType::U32,
+            PtxSpecialRegister::Nctaid => ast::ScalarType::U32,
+            PtxSpecialRegister::Clock => ast::ScalarType::U32,
+            PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32,
+        }
+    }
+
+    /// Input type of the runtime function: vector registers take a u8
+    /// component index; scalar registers take no argument.
+    fn get_function_input_type(self) -> Option<ast::ScalarType> {
+        match self {
+            PtxSpecialRegister::Tid
+            | PtxSpecialRegister::Ntid
+            | PtxSpecialRegister::Ctaid
+            | PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8),
+            PtxSpecialRegister::Clock | PtxSpecialRegister::LanemaskLt => None,
+        }
+    }
+
+    /// Runtime function name without the `__zluda_ptx_impl__` prefix.
+    fn get_unprefixed_function_name(self) -> &'static str {
+        match self {
+            PtxSpecialRegister::Tid => "sreg_tid",
+            PtxSpecialRegister::Ntid => "sreg_ntid",
+            PtxSpecialRegister::Ctaid => "sreg_ctaid",
+            PtxSpecialRegister::Nctaid => "sreg_nctaid",
+            PtxSpecialRegister::Clock => "sreg_clock",
+            PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt",
+        }
+    }
+}
+
+struct SpecialRegistersMap {
+ reg_to_id: HashMap<PtxSpecialRegister, SpirvWord>,
+ id_to_reg: HashMap<SpirvWord, PtxSpecialRegister>,
+}
+
+impl SpecialRegistersMap {
+ fn new() -> Self {
+ SpecialRegistersMap {
+ reg_to_id: HashMap::new(),
+ id_to_reg: HashMap::new(),
+ }
+ }
+
+ fn get(&self, id: SpirvWord) -> Option<PtxSpecialRegister> {
+ self.id_to_reg.get(&id).copied()
+ }
+
+ fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
+ match self.reg_to_id.entry(reg) {
+ hash_map::Entry::Occupied(e) => *e.get(),
+ hash_map::Entry::Vacant(e) => {
+ let numeric_id = SpirvWord(current_id.0);
+ current_id.0 += 1;
+ e.insert(numeric_id);
+ self.id_to_reg.insert(numeric_id, reg);
+ numeric_id
+ }
+ }
+ }
+}
+
/// Scoped resolver that maps the string identifiers of a single function to
/// fresh numeric ids, layered on top of the module-wide tables.
struct FnStringIdResolver<'input, 'b> {
    // Module-wide id counter; advanced for every definition made here.
    current_id: &'b mut SpirvWord,
    global_variables: &'b HashMap<Cow<'input, str>, SpirvWord>,
    global_type_check: &'b HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
    special_registers: &'b mut SpecialRegistersMap,
    // Stack of lexical scopes; the innermost scope is the last element.
    variables: Vec<HashMap<Cow<'input, str>, SpirvWord>>,
    // Per-id metadata: (type, state space, is_variable); `None` = untyped.
    type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
}

impl<'a, 'b> FnStringIdResolver<'a, 'b> {
    /// Drops the string->id scope maps, keeping only the numeric metadata
    /// needed by later passes.
    fn finish(self) -> NumericIdResolver<'b> {
        NumericIdResolver {
            current_id: self.current_id,
            global_type_check: self.global_type_check,
            type_check: self.type_check,
            special_registers: self.special_registers,
        }
    }

    // Opens a new lexical scope (a `{ ... }` block in the PTX source).
    fn start_block(&mut self) {
        self.variables.push(HashMap::new())
    }

    fn end_block(&mut self) {
        self.variables.pop();
    }

    /// Resolves `id`: innermost scope outward, then module globals, and
    /// finally special registers (allocating an id for those on first use).
    fn get_id(&mut self, id: &str) -> Result<SpirvWord, TranslateError> {
        for scope in self.variables.iter().rev() {
            match scope.get(id) {
                Some(id) => return Ok(*id),
                None => continue,
            }
        }
        match self.global_variables.get(id) {
            Some(id) => Ok(*id),
            None => {
                let sreg = PtxSpecialRegister::try_parse(id).ok_or_else(error_unknown_symbol)?;
                Ok(self.special_registers.get_or_add(self.current_id, sreg))
            }
        }
    }

    /// Defines `id` in the innermost scope and returns its fresh numeric id.
    /// `typ == None` registers an untyped symbol (e.g. a label).
    fn add_def(
        &mut self,
        id: &'a str,
        typ: Option<(ast::Type, ast::StateSpace)>,
        is_variable: bool,
    ) -> SpirvWord {
        let numeric_id = *self.current_id;
        self.variables
            .last_mut()
            .unwrap()
            .insert(Cow::Borrowed(id), numeric_id);
        self.type_check.insert(
            numeric_id,
            typ.map(|(typ, space)| (typ, space, is_variable)),
        );
        self.current_id.0 += 1;
        numeric_id
    }

    /// Defines `count` variables named `base_id0`, `base_id1`, ... (PTX
    /// parameterized register declarations) over a contiguous id range and
    /// returns an iterator over the new ids.
    #[must_use]
    fn add_defs(
        &mut self,
        base_id: &'a str,
        count: u32,
        typ: ast::Type,
        state_space: ast::StateSpace,
        is_variable: bool,
    ) -> impl Iterator<Item = SpirvWord> {
        let numeric_id = *self.current_id;
        for i in 0..count {
            self.variables.last_mut().unwrap().insert(
                Cow::Owned(format!("{}{}", base_id, i)),
                SpirvWord(numeric_id.0 + i),
            );
            self.type_check.insert(
                SpirvWord(numeric_id.0 + i),
                Some((typ.clone(), state_space, is_variable)),
            );
        }
        self.current_id.0 += count;
        (0..count)
            .into_iter()
            .map(move |i| SpirvWord(i + numeric_id.0))
    }
}
+
/// Resolver used once identifiers are numeric: can mint new ids and answer
/// type queries, but no longer knows source-level names.
struct NumericIdResolver<'b> {
    current_id: &'b mut SpirvWord,
    global_type_check: &'b HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
    type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
    special_registers: &'b mut SpecialRegistersMap,
}

impl<'b> NumericIdResolver<'b> {
    fn finish(self) -> MutableNumericIdResolver<'b> {
        MutableNumericIdResolver { base: self }
    }

    /// Returns `(type, state_space, is_variable)` for `id`, trying the
    /// function-local table first, then special registers, then globals.
    /// An id that is known but untyped yields `UntypedSymbol`.
    fn get_typed(
        &self,
        id: SpirvWord,
    ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> {
        match self.type_check.get(&id) {
            Some(Some(x)) => Ok(x.clone()),
            Some(None) => Err(TranslateError::UntypedSymbol),
            None => match self.special_registers.get(id) {
                Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
                None => match self.global_type_check.get(&id) {
                    Some(Some(result)) => Ok(result.clone()),
                    Some(None) | None => Err(TranslateError::UntypedSymbol),
                },
            },
        }
    }

    // This is for identifiers which will be emitted later as OpVariable
    // They are candidates for insertion of LoadVar/StoreVar
    fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord {
        let new_id = *self.current_id;
        self.type_check
            .insert(new_id, Some((typ, state_space, true)));
        self.current_id.0 += 1;
        new_id
    }

    /// Mints an id for an SSA-style temporary; `typ == None` leaves it
    /// untyped (e.g. for function ids).
    fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
        let new_id = *self.current_id;
        self.type_check
            .insert(new_id, typ.map(|(t, space)| (t, space, false)));
        self.current_id.0 += 1;
        new_id
    }
}
+
+struct MutableNumericIdResolver<'b> {
+ base: NumericIdResolver<'b>,
+}
+
+impl<'b> MutableNumericIdResolver<'b> {
+ fn unmut(self) -> NumericIdResolver<'b> {
+ self.base
+ }
+
+ fn get_typed(&self, id: SpirvWord) -> Result<(ast::Type, ast::StateSpace), TranslateError> {
+ self.base.get_typed(id).map(|(t, space, _)| (t, space))
+ }
+
+ fn register_intermediate(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord {
+ self.base.register_intermediate(Some((typ, state_space)))
+ }
+}
+
quick_error! {
    /// Errors produced while translating PTX to SPIR-V.
    #[derive(Debug)]
    pub enum TranslateError {
        // An identifier was used that was never defined.
        UnknownSymbol {}
        // An identifier is known but carries no type information.
        UntypedSymbol {}
        // Operand and declaration types disagree.
        MismatchedType {}
        // Error bubbled up from the rspirv builder.
        Spirv(err: rspirv::dr::Error) {
            from()
            display("{}", err)
            cause(err)
        }
        // Internal invariant violation; see `error_unreachable`.
        Unreachable {}
        // Construct that is not implemented yet.
        Todo {}
    }
}
+
// Convention used by this module's error constructors: debug builds panic at
// the offending spot (so a backtrace points at the root cause), release
// builds return a recoverable error value instead.
#[cfg(debug_assertions)]
fn error_unreachable() -> TranslateError {
    unreachable!()
}

#[cfg(not(debug_assertions))]
fn error_unreachable() -> TranslateError {
    TranslateError::Unreachable
}
+
+fn error_unknown_symbol() -> TranslateError {
+ panic!()
+}
+
+#[cfg(not(debug_assertions))]
+fn error_unknown_symbol() -> TranslateError {
+ TranslateError::UnknownSymbol
+}
+
+fn error_mismatched_type() -> TranslateError {
+ panic!()
+}
+
+#[cfg(not(debug_assertions))]
+fn error_mismatched_type() -> TranslateError {
+ TranslateError::MismatchedType
+}
+
/// Read-only access to the per-function signature mappers of the whole
/// module, keyed by numeric function id.
pub struct GlobalFnDeclResolver<'input, 'a> {
    fns: &'a HashMap<SpirvWord, FnSigMapper<'input>>,
}

impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
    /// Looks up the signature mapper for function `id`; unknown ids are an
    /// error (panic in debug builds, see `error_unknown_symbol`).
    fn get_fn_sig_resolver(&self, id: SpirvWord) -> Result<&FnSigMapper<'input>, TranslateError> {
        self.fns.get(&id).ok_or_else(error_unknown_symbol)
    }
}
+
/// Records how a PTX function signature was rewritten into its SPIR-V
/// compatible shape, so call sites can be rewritten the same way.
struct FnSigMapper<'input> {
    // true - stays as return argument
    // false - is moved to input argument
    return_param_args: Vec<bool>,
    func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
}

impl<'input> FnSigMapper<'input> {
    /// Rewrites a declaration for SPIR-V: every `.param`-space return
    /// argument is moved to the end of the input argument list (in order),
    /// and `return_param_args` remembers which original return slots stayed.
    fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, SpirvWord>) -> Self {
        let return_param_args = method
            .return_arguments
            .iter()
            .map(|a| a.state_space != ast::StateSpace::Param)
            .collect::<Vec<_>>();
        let mut new_return_arguments = Vec::new();
        for arg in method.return_arguments.into_iter() {
            if arg.state_space == ast::StateSpace::Param {
                method.input_arguments.push(arg);
            } else {
                new_return_arguments.push(arg);
            }
        }
        method.return_arguments = new_return_arguments;
        FnSigMapper {
            return_param_args,
            func_decl: Rc::new(RefCell::new(method)),
        }
    }

    /// Rewrites a call site to match the remapped declaration: return
    /// arguments whose slot was demoted become trailing input arguments.
    /// Errors if the call's argument counts do not line up with the
    /// declaration.
    fn resolve_in_spirv_repr(
        &self,
        data: ast::CallDetails,
        arguments: ast::CallArgs<ast::ParsedOperand<SpirvWord>>,
    ) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
        let func_decl = (*self.func_decl).borrow();
        let mut data_return = Vec::new();
        let mut arguments_return = Vec::new();
        let mut data_input = data.input_arguments;
        let mut arguments_input = arguments.input_arguments;
        let mut func_decl_return_iter = func_decl.return_arguments.iter();
        // Demoted return args live after the call's explicit inputs in the
        // remapped declaration, hence the slice offset.
        let mut func_decl_input_iter = func_decl.input_arguments[arguments_input.len()..].iter();
        for (idx, id) in arguments.return_arguments.iter().enumerate() {
            let stays_as_return = match self.return_param_args.get(idx) {
                Some(x) => *x,
                None => return Err(TranslateError::MismatchedType),
            };
            if stays_as_return {
                if let Some(var) = func_decl_return_iter.next() {
                    data_return.push((var.v_type.clone(), var.state_space));
                    arguments_return.push(*id);
                } else {
                    return Err(TranslateError::MismatchedType);
                }
            } else {
                if let Some(var) = func_decl_input_iter.next() {
                    data_input.push((var.v_type.clone(), var.state_space));
                    arguments_input.push(ast::ParsedOperand::Reg(*id));
                } else {
                    return Err(TranslateError::MismatchedType);
                }
            }
        }
        // Final sanity check: the rewritten call must exactly match the
        // remapped declaration's arity.
        if arguments_return.len() != func_decl.return_arguments.len()
            || arguments_input.len() != func_decl.input_arguments.len()
        {
            return Err(TranslateError::MismatchedType);
        }
        let data = ast::CallDetails {
            uniform: data.uniform,
            return_arguments: data_return,
            input_arguments: data_input,
        };
        let arguments = ast::CallArgs {
            func: arguments.func,
            return_arguments: arguments_return,
            input_arguments: arguments_input,
        };
        Ok(ast::Instruction::Call { data, arguments })
    }
}
+
/// A statement in the translator's intermediate representation: either a
/// (possibly rewritten) PTX instruction `I`, or one of the synthetic
/// statement kinds the passes introduce.
enum Statement<I, P: ast::Operand> {
    Label(SpirvWord),
    Variable(ast::Variable<P::Ident>),
    Instruction(I),
    // SPIR-V compatible replacement for PTX predicates
    Conditional(BrachCondition),
    // Synthetic load/store pairs inserted around variable accesses.
    LoadVar(LoadVarDetails),
    StoreVar(StoreVarDetails),
    // Implicit type/space conversion materialized as an explicit step.
    Conversion(ImplicitConversion),
    Constant(ConstantDefinition),
    RetValue(ast::RetData, SpirvWord),
    PtrAccess(PtrAccess<P>),
    // Pack/unpack between a vector value and its scalar components.
    RepackVector(RepackVectorDetails),
    FunctionPointer(FunctionPointerDetails),
}
+
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
    /// Rebuilds this statement by threading every operand/identifier through
    /// `visitor`. Visit calls receive `(value, Some((type, state_space)),
    /// is_dst, relaxed_type_check)`; destinations are visited with
    /// `is_dst == true`, and identifiers whose type is unknown here (labels,
    /// function refs) are visited with `None` for the type/space pair.
    fn visit_map<To: ast::Operand<Ident = SpirvWord>, Err>(
        self,
        visitor: &mut impl ast::VisitorMap<T, To, Err>,
    ) -> std::result::Result<Statement<ast::Instruction<To>, To>, Err> {
        Ok(match self {
            // Real instructions delegate to the ast-level visitor.
            Statement::Instruction(i) => {
                return ast::visit_map(i, visitor).map(Statement::Instruction)
            }
            Statement::Label(label) => {
                Statement::Label(visitor.visit_ident(label, None, false, false)?)
            }
            Statement::Variable(var) => {
                let name = visitor.visit_ident(
                    var.name,
                    Some((&var.v_type, var.state_space)),
                    true,
                    false,
                )?;
                Statement::Variable(ast::Variable {
                    align: var.align,
                    v_type: var.v_type,
                    state_space: var.state_space,
                    name,
                    array_init: var.array_init,
                })
            }
            Statement::Conditional(conditional) => {
                let predicate = visitor.visit_ident(
                    conditional.predicate,
                    Some((&ast::ScalarType::Pred.into(), ast::StateSpace::Reg)),
                    false,
                    false,
                )?;
                let if_true = visitor.visit_ident(conditional.if_true, None, false, false)?;
                let if_false = visitor.visit_ident(conditional.if_false, None, false, false)?;
                Statement::Conditional(BrachCondition {
                    predicate,
                    if_true,
                    if_false,
                })
            }
            // LoadVar reads a Local-space source into a Reg-space dst.
            Statement::LoadVar(LoadVarDetails {
                arg,
                typ,
                member_index,
            }) => {
                let dst = visitor.visit_ident(
                    arg.dst,
                    Some((&typ, ast::StateSpace::Reg)),
                    true,
                    false,
                )?;
                let src = visitor.visit_ident(
                    arg.src,
                    Some((&typ, ast::StateSpace::Local)),
                    false,
                    false,
                )?;
                Statement::LoadVar(LoadVarDetails {
                    arg: ast::LdArgs { dst, src },
                    typ,
                    member_index,
                })
            }
            // StoreVar writes a Reg-space value into a Local-space slot.
            Statement::StoreVar(StoreVarDetails {
                arg,
                typ,
                member_index,
            }) => {
                let src1 = visitor.visit_ident(
                    arg.src1,
                    Some((&typ, ast::StateSpace::Local)),
                    false,
                    false,
                )?;
                let src2 = visitor.visit_ident(
                    arg.src2,
                    Some((&typ, ast::StateSpace::Reg)),
                    false,
                    false,
                )?;
                Statement::StoreVar(StoreVarDetails {
                    arg: ast::StArgs { src1, src2 },
                    typ,
                    member_index,
                })
            }
            Statement::Conversion(ImplicitConversion {
                src,
                dst,
                from_type,
                to_type,
                from_space,
                to_space,
                kind,
            }) => {
                let dst = visitor.visit_ident(
                    dst,
                    Some((&to_type, ast::StateSpace::Reg)),
                    true,
                    false,
                )?;
                let src = visitor.visit_ident(
                    src,
                    Some((&from_type, ast::StateSpace::Reg)),
                    false,
                    false,
                )?;
                Statement::Conversion(ImplicitConversion {
                    src,
                    dst,
                    from_type,
                    to_type,
                    from_space,
                    to_space,
                    kind,
                })
            }
            Statement::Constant(ConstantDefinition { dst, typ, value }) => {
                let dst = visitor.visit_ident(
                    dst,
                    Some((&typ.into(), ast::StateSpace::Reg)),
                    true,
                    false,
                )?;
                Statement::Constant(ConstantDefinition { dst, typ, value })
            }
            Statement::RetValue(data, value) => {
                // TODO:
                // We should report type here
                let value = visitor.visit_ident(value, None, false, false)?;
                Statement::RetValue(data, value)
            }
            Statement::PtrAccess(PtrAccess {
                underlying_type,
                state_space,
                dst,
                ptr_src,
                offset_src,
            }) => {
                let dst =
                    visitor.visit_ident(dst, Some((&underlying_type, state_space)), true, false)?;
                let ptr_src = visitor.visit_ident(
                    ptr_src,
                    Some((&underlying_type, state_space)),
                    false,
                    false,
                )?;
                // The byte offset is always treated as a signed 64-bit value.
                let offset_src = visitor.visit(
                    offset_src,
                    Some((
                        &ast::Type::Scalar(ast::ScalarType::S64),
                        ast::StateSpace::Reg,
                    )),
                    false,
                    false,
                )?;
                Statement::PtrAccess(PtrAccess {
                    underlying_type,
                    state_space,
                    dst,
                    ptr_src,
                    offset_src,
                })
            }
            // RepackVector: the dst side (unpacked scalars when extracting,
            // the packed vector otherwise) is visited with is_dst == true.
            Statement::RepackVector(RepackVectorDetails {
                is_extract,
                typ,
                packed,
                unpacked,
                relaxed_type_check,
            }) => {
                let (packed, unpacked) = if is_extract {
                    let unpacked = unpacked
                        .into_iter()
                        .map(|ident| {
                            visitor.visit_ident(
                                ident,
                                Some((&typ.into(), ast::StateSpace::Reg)),
                                true,
                                relaxed_type_check,
                            )
                        })
                        .collect::<Result<Vec<_>, _>>()?;
                    let packed = visitor.visit_ident(
                        packed,
                        Some((
                            &ast::Type::Vector(unpacked.len() as u8, typ),
                            ast::StateSpace::Reg,
                        )),
                        false,
                        false,
                    )?;
                    (packed, unpacked)
                } else {
                    let packed = visitor.visit_ident(
                        packed,
                        Some((
                            &ast::Type::Vector(unpacked.len() as u8, typ),
                            ast::StateSpace::Reg,
                        )),
                        true,
                        false,
                    )?;
                    let unpacked = unpacked
                        .into_iter()
                        .map(|ident| {
                            visitor.visit_ident(
                                ident,
                                Some((&typ.into(), ast::StateSpace::Reg)),
                                false,
                                relaxed_type_check,
                            )
                        })
                        .collect::<Result<Vec<_>, _>>()?;
                    (packed, unpacked)
                };
                Statement::RepackVector(RepackVectorDetails {
                    is_extract,
                    typ,
                    packed,
                    unpacked,
                    relaxed_type_check,
                })
            }
            // Function pointers are materialized as u64 register values.
            Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => {
                let dst = visitor.visit_ident(
                    dst,
                    Some((
                        &ast::Type::Scalar(ast::ScalarType::U64),
                        ast::StateSpace::Reg,
                    )),
                    true,
                    false,
                )?;
                let src = visitor.visit_ident(src, None, false, false)?;
                Statement::FunctionPointer(FunctionPointerDetails { dst, src })
            }
        })
    }
}
+
/// Two-way branch on a predicate register, replacing a PTX `@p bra`.
/// NOTE(review): the name looks like a typo for `BranchCondition`; renaming
/// would touch every user, so it is only flagged here.
struct BrachCondition {
    predicate: SpirvWord,
    if_true: SpirvWord,
    if_false: SpirvWord,
}
/// Synthetic load from a variable slot into a register.
struct LoadVarDetails {
    arg: ast::LdArgs<SpirvWord>,
    typ: ast::Type,
    // (index, vector_width)
    // HACK ALERT
    // For some reason IGC explodes when you try to load from builtin vectors
    // using OpInBoundsAccessChain, the one true way to do it is to
    // OpLoad+OpCompositeExtract
    member_index: Option<(u8, Option<u8>)>,
}

/// Synthetic store of a register into a variable slot.
struct StoreVarDetails {
    arg: ast::StArgs<SpirvWord>,
    typ: ast::Type,
    // When set, only this member of the destination aggregate is written.
    member_index: Option<u8>,
}

/// A conversion step inserted by the translator (not present in the
/// source PTX) between `src` and `dst`.
#[derive(Clone)]
struct ImplicitConversion {
    src: SpirvWord,
    dst: SpirvWord,
    from_type: ast::Type,
    to_type: ast::Type,
    from_space: ast::StateSpace,
    to_space: ast::StateSpace,
    kind: ConversionKind,
}

#[derive(PartialEq, Clone)]
enum ConversionKind {
    // zero-extend/chop/bitcast depending on types
    Default,
    SignExtend,
    // Reinterpret an integer value as a pointer.
    BitToPtr,
    // Pointer-to-pointer cast (e.g. across state spaces).
    PtrToPtr,
    AddressOf,
}

/// Definition of an immediate constant as a register value.
struct ConstantDefinition {
    pub dst: SpirvWord,
    pub typ: ast::ScalarType,
    pub value: ast::ImmediateValue,
}

/// `dst = ptr_src + offset_src` within `state_space`, pointing at values of
/// `underlying_type`.
pub struct PtrAccess<T> {
    underlying_type: ast::Type,
    state_space: ast::StateSpace,
    dst: SpirvWord,
    ptr_src: SpirvWord,
    offset_src: T,
}

/// Conversion between a packed vector (`packed`) and its scalar components
/// (`unpacked`); `is_extract` selects the direction.
struct RepackVectorDetails {
    is_extract: bool,
    typ: ast::ScalarType,
    packed: SpirvWord,
    unpacked: Vec<SpirvWord>,
    relaxed_type_check: bool,
}

/// Takes the address of function `src` into register `dst`.
struct FunctionPointerDetails {
    dst: SpirvWord,
    src: SpirvWord,
}
+
// Newtype over `spirv::Word` so PTX-side numeric ids cannot be mixed up
// with arbitrary integers.
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
struct SpirvWord(spirv::Word);

impl From<spirv::Word> for SpirvWord {
    fn from(value: spirv::Word) -> Self {
        Self(value)
    }
}
impl From<SpirvWord> for spirv::Word {
    fn from(value: SpirvWord) -> Self {
        value.0
    }
}

// A bare id is its own operand (identity mapping).
impl ast::Operand for SpirvWord {
    type Ident = Self;

    fn from_ident(ident: Self::Ident) -> Self {
        ident
    }
}
+
+fn pred_map_variable<U, T, F: FnMut(T) -> Result<U, TranslateError>>(
+ this: ast::PredAt<T>,
+ f: &mut F,
+) -> Result<ast::PredAt<U>, TranslateError> {
+ let new_label = f(this.label)?;
+ Ok(ast::PredAt {
+ not: this.not,
+ label: new_label,
+ })
+}
+
/// A top-level item of the translated module: a global variable or a
/// function/kernel.
pub(crate) enum Directive<'input> {
    Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
    Method(Function<'input>),
}

/// A function or kernel plus everything collected about it during the
/// passes.
pub(crate) struct Function<'input> {
    pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
    // Variables to be hoisted to module scope (see `hoist_function_globals`).
    pub globals: Vec<ast::Variable<SpirvWord>>,
    // `None` for declarations without a body (e.g. extern imports).
    pub body: Option<Vec<ExpandedStatement>>,
    import_as: Option<String>,
    tuning: Vec<ast::TuningDirective>,
    linkage: ast::LinkingDirective,
}

// Statement type after all passes: flat operands, no predicates.
type ExpandedStatement = Statement<ast::Instruction<SpirvWord>, SpirvWord>;

// Statement type right after identifier normalization: instructions still
// carry their optional guard predicate.
type NormalizedStatement = Statement<
    (
        Option<ast::PredAt<SpirvWord>>,
        ast::Instruction<ast::ParsedOperand<SpirvWord>>,
    ),
    ast::ParsedOperand<SpirvWord>,
>;

// After predicate lowering: no guards, operands still parsed form.
type UnconditionalStatement =
    Statement<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;

// After type resolution: operands are `TypedOperand`s.
type TypedStatement = Statement<ast::Instruction<TypedOperand>, TypedOperand>;
+
/// An operand after type resolution.
#[derive(Copy, Clone)]
enum TypedOperand {
    Reg(SpirvWord),
    // Register plus a constant byte offset (e.g. `[ptr+4]`).
    RegOffset(SpirvWord, i32),
    Imm(ast::ImmediateValue),
    // One component of a vector register.
    VecMember(SpirvWord, u8),
}

impl TypedOperand {
    /// Applies `fn_` to the underlying register (passing the vector member
    /// index when there is one); immediates pass through untouched.
    fn map<Err>(
        self,
        fn_: impl FnOnce(SpirvWord, Option<u8>) -> Result<SpirvWord, Err>,
    ) -> Result<Self, Err> {
        Ok(match self {
            TypedOperand::Reg(reg) => TypedOperand::Reg(fn_(reg, None)?),
            TypedOperand::RegOffset(reg, off) => TypedOperand::RegOffset(fn_(reg, None)?, off),
            TypedOperand::Imm(imm) => TypedOperand::Imm(imm),
            TypedOperand::VecMember(reg, idx) => TypedOperand::VecMember(fn_(reg, Some(idx))?, idx),
        })
    }

    /// The register backing this operand, if any (`None` for immediates).
    fn underlying_register(&self) -> Option<SpirvWord> {
        match self {
            Self::Reg(r) | Self::RegOffset(r, _) | Self::VecMember(r, _) => Some(*r),
            Self::Imm(_) => None,
        }
    }

    /// Expects a plain register; anything else is an internal error.
    fn unwrap_reg(&self) -> Result<SpirvWord, TranslateError> {
        match self {
            TypedOperand::Reg(reg) => Ok(*reg),
            _ => Err(error_unreachable()),
        }
    }
}

impl ast::Operand for TypedOperand {
    type Ident = SpirvWord;

    fn from_ident(ident: Self::Ident) -> Self {
        TypedOperand::Reg(ident)
    }
}
+
+impl<Fn> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
+ for FnVisitor<TypedOperand, TypedOperand, TranslateError, Fn>
+where
+ Fn: FnMut(
+ TypedOperand,
+ Option<(&ast::Type, ast::StateSpace)>,
+ bool,
+ bool,
+ ) -> Result<TypedOperand, TranslateError>,
+{
+ fn visit(
+ &mut self,
+ args: TypedOperand,
+ type_space: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<TypedOperand, TranslateError> {
+ (self.fn_)(args, type_space, is_dst, relaxed_type_check)
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: SpirvWord,
+ type_space: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ match (self.fn_)(
+ TypedOperand::Reg(args),
+ type_space,
+ is_dst,
+ relaxed_type_check,
+ )? {
+ TypedOperand::Reg(reg) => Ok(reg),
+ _ => Err(TranslateError::Unreachable),
+ }
+ }
+}
+
/// Adapter that turns a closure into an `ast::VisitorMap` implementation.
struct FnVisitor<
    T,
    U,
    Err,
    Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result<U, Err>,
> {
    fn_: Fn,
    // Ties the otherwise-unused T/U/Err parameters to the struct without
    // affecting ownership or auto traits.
    _marker: PhantomData<fn(T) -> Result<U, Err>>,
}

impl<
        T,
        U,
        Err,
        Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result<U, Err>,
    > FnVisitor<T, U, Err, Fn>
{
    fn new(fn_: Fn) -> Self {
        Self {
            fn_,
            _marker: PhantomData,
        }
    }
}
+
+fn space_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
+ this == other
+ || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
+ || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
+}
+
/// Returns the id of the helper function `name`, registering an `EXTERN`
/// declaration (with fresh anonymous arguments) on first use. Subsequent
/// calls with the same name return the cached id; the argument iterators
/// are only consulted the first time.
fn register_external_fn_call<'a>(
    id_defs: &mut NumericIdResolver,
    ptx_impl_imports: &mut HashMap<String, Directive>,
    name: String,
    return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
    input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Result<SpirvWord, TranslateError> {
    match ptx_impl_imports.entry(name) {
        hash_map::Entry::Vacant(entry) => {
            // The function id itself is untyped.
            let fn_id = id_defs.register_intermediate(None);
            let return_arguments = fn_arguments_to_variables(id_defs, return_arguments);
            let input_arguments = fn_arguments_to_variables(id_defs, input_arguments);
            let func_decl = ast::MethodDeclaration::<SpirvWord> {
                return_arguments,
                name: ast::MethodName::Func(fn_id),
                input_arguments,
                shared_mem: None,
            };
            let func = Function {
                func_decl: Rc::new(RefCell::new(func_decl)),
                globals: Vec::new(),
                // Body is supplied by the runtime implementation module.
                body: None,
                import_as: Some(entry.key().clone()),
                tuning: Vec::new(),
                linkage: ast::LinkingDirective::EXTERN,
            };
            entry.insert(Directive::Method(func));
            Ok(fn_id)
        }
        hash_map::Entry::Occupied(entry) => match entry.get() {
            Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
                ast::MethodName::Func(fn_id) => Ok(fn_id),
                // Imports are always plain functions, never kernels.
                ast::MethodName::Kernel(_) => Err(error_unreachable()),
            },
            _ => Err(error_unreachable()),
        },
    }
}
+
+fn fn_arguments_to_variables<'a>(
+ id_defs: &mut NumericIdResolver,
+ args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
+) -> Vec<ast::Variable<SpirvWord>> {
+ args.map(|(typ, space)| ast::Variable {
+ align: None,
+ v_type: typ.clone(),
+ state_space: space,
+ name: id_defs.register_intermediate(None),
+ array_init: Vec::new(),
+ })
+ .collect::<Vec<_>>()
+}
+
+fn hoist_function_globals(directives: Vec<Directive>) -> Vec<Directive> {
+ let mut result = Vec::with_capacity(directives.len());
+ for directive in directives {
+ match directive {
+ Directive::Method(method) => {
+ for variable in method.globals {
+ result.push(Directive::Variable(ast::LinkingDirective::NONE, variable));
+ }
+ result.push(Directive::Method(Function {
+ globals: Vec::new(),
+ ..method
+ }))
+ }
+ _ => result.push(directive),
+ }
+ }
+ result
+}
+
/// Transitive call graph of the module: for each method, the set of
/// function ids reachable from its body.
struct MethodsCallMap<'input> {
    map: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
}

impl<'input> MethodsCallMap<'input> {
    /// Builds the map in two steps: collect direct callees per method, then
    /// take the transitive closure via depth-first traversal.
    fn new(module: &[Directive<'input>]) -> Self {
        let mut directly_called_by = HashMap::new();
        for directive in module {
            match directive {
                Directive::Method(Function {
                    func_decl,
                    body: Some(statements),
                    ..
                }) => {
                    let call_key: ast::MethodName<_> = (**func_decl).borrow().name;
                    // Record the method even if it calls nothing.
                    if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
                        entry.insert(Vec::new());
                    }
                    for statement in statements {
                        match statement {
                            Statement::Instruction(ast::Instruction::Call { data, arguments }) => {
                                multi_hash_map_append(
                                    &mut directly_called_by,
                                    call_key,
                                    arguments.func,
                                );
                            }
                            _ => {}
                        }
                    }
                }
                _ => {}
            }
        }
        // Transitive closure: everything reachable from each method's
        // direct callees.
        let mut result = HashMap::new();
        for (&method_key, children) in directly_called_by.iter() {
            let mut visited = HashSet::new();
            for child in children {
                Self::add_call_map_single(&directly_called_by, &mut visited, *child);
            }
            result.insert(method_key, visited);
        }
        MethodsCallMap { map: result }
    }

    // Depth-first walk; `visited` doubles as the cycle guard.
    fn add_call_map_single(
        directly_called_by: &HashMap<ast::MethodName<'input, SpirvWord>, Vec<SpirvWord>>,
        visited: &mut HashSet<SpirvWord>,
        current: SpirvWord,
    ) {
        if !visited.insert(current) {
            return;
        }
        if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) {
            for child in children {
                Self::add_call_map_single(directly_called_by, visited, *child);
            }
        }
    }

    /// All functions transitively reachable from kernel `name`.
    fn get_kernel_children(&self, name: &'input str) -> impl Iterator<Item = &SpirvWord> {
        self.map
            .get(&ast::MethodName::Kernel(name))
            .into_iter()
            .flatten()
    }

    /// Iterates kernels only, with their reachable-function sets.
    fn kernels(&self) -> impl Iterator<Item = (&'input str, &HashSet<SpirvWord>)> {
        self.map
            .iter()
            .filter_map(|(method, children)| match method {
                ast::MethodName::Kernel(kernel) => Some((*kernel, children)),
                ast::MethodName::Func(..) => None,
            })
    }

    /// Iterates all methods (kernels and functions) with their callee sets.
    fn methods(
        &self,
    ) -> impl Iterator<Item = (ast::MethodName<'input, SpirvWord>, &HashSet<SpirvWord>)> {
        self.map
            .iter()
            .map(|(method, children)| (*method, children))
    }

    /// Invokes `f` for every function reachable from `method`.
    fn visit_callees(&self, method: ast::MethodName<'input, SpirvWord>, f: impl FnMut(SpirvWord)) {
        self.map
            .get(&method)
            .into_iter()
            .flatten()
            .copied()
            .for_each(f);
    }
}
+
/// Appends `value` to the collection stored under `key`, creating an empty
/// collection first if the key is not present yet.
fn multi_hash_map_append<
    K: Eq + std::hash::Hash,
    V,
    Collection: std::iter::Extend<V> + std::default::Default,
>(
    m: &mut HashMap<K, Collection>,
    key: K,
    value: V,
) {
    // `entry().or_default()` replaces the occupied/vacant two-arm match:
    // one hash lookup, same behavior.
    m.entry(key).or_default().extend(std::iter::once(value));
}
+
// Moves all `Statement::Variable` declarations to the front of each function
// body. Element 0 is skipped — presumably the function's entry label; confirm
// against the pass that builds the body. `sort_by_key` is stable, so the
// relative order of the variables (and of the remaining statements) is
// preserved.
fn normalize_variable_decls(directives: &mut Vec<Directive>) {
    for directive in directives {
        match directive {
            Directive::Method(Function {
                body: Some(func), ..
            }) => {
                func[1..].sort_by_key(|s| match s {
                    Statement::Variable(_) => 0,
                    _ => 1,
                });
            }
            _ => (),
        }
    }
}
+
// HACK ALERT!
// This function is a "good enough" heuristic of whether to mark f16/f32 operations
// in the kernel as flushing denorms to zero or preserving them
// PTX supports per-instruction ftz information. Unfortunately SPIR-V has no
// such capability, so instead we guesstimate which use is more common in the kernel
// and emit suitable execution mode
//
// Returns, per method, a map from float width (in bytes) to the chosen
// denorm mode plus the raw flush-vs-preserve balance that produced it
// (positive = more ftz uses).
fn compute_denorm_information<'input>(
    module: &[Directive<'input>],
) -> HashMap<ast::MethodName<'input, SpirvWord>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
    let mut denorm_methods = HashMap::new();
    for directive in module {
        match directive {
            Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {}
            Directive::Method(Function {
                func_decl,
                body: Some(statements),
                ..
            }) => {
                // Tally +1 for each ftz use and -1 for each preserve use,
                // keyed by the float width of the instruction.
                let mut flush_counter = DenormCountMap::new();
                let method_key = (**func_decl).borrow().name;
                for statement in statements {
                    match statement {
                        Statement::Instruction(inst) => {
                            if let Some((flush, width)) = flush_to_zero(inst) {
                                denorm_count_map_update(&mut flush_counter, width, flush);
                            }
                        }
                        // Synthetic statements carry no ftz information.
                        Statement::LoadVar(..) => {}
                        Statement::StoreVar(..) => {}
                        Statement::Conditional(_) => {}
                        Statement::Conversion(_) => {}
                        Statement::Constant(_) => {}
                        Statement::RetValue(_, _) => {}
                        Statement::Label(_) => {}
                        Statement::Variable(_) => {}
                        Statement::PtrAccess { .. } => {}
                        Statement::RepackVector(_) => {}
                        Statement::FunctionPointer(_) => {}
                    }
                }
                denorm_methods.insert(method_key, flush_counter);
            }
        }
    }
    // Majority vote per width; ties (0) fall to Preserve.
    denorm_methods
        .into_iter()
        .map(|(name, v)| {
            let width_to_denorm = v
                .into_iter()
                .map(|(k, flush_over_preserve)| {
                    let mode = if flush_over_preserve > 0 {
                        spirv::FPDenormMode::FlushToZero
                    } else {
                        spirv::FPDenormMode::Preserve
                    };
                    (k, (mode, flush_over_preserve))
                })
                .collect();
            (name, width_to_denorm)
        })
        .collect()
}
+
+fn flush_to_zero(this: &ast::Instruction<SpirvWord>) -> Option<(bool, u8)> {
+ match this {
+ ast::Instruction::Ld { .. } => None,
+ ast::Instruction::St { .. } => None,
+ ast::Instruction::Mov { .. } => None,
+ ast::Instruction::Not { .. } => None,
+ ast::Instruction::Bra { .. } => None,
+ ast::Instruction::Shl { .. } => None,
+ ast::Instruction::Shr { .. } => None,
+ ast::Instruction::Ret { .. } => None,
+ ast::Instruction::Call { .. } => None,
+ ast::Instruction::Or { .. } => None,
+ ast::Instruction::And { .. } => None,
+ ast::Instruction::Cvta { .. } => None,
+ ast::Instruction::Selp { .. } => None,
+ ast::Instruction::Bar { .. } => None,
+ ast::Instruction::Atom { .. } => None,
+ ast::Instruction::AtomCas { .. } => None,
+ ast::Instruction::Sub {
+ data: ast::ArithDetails::Integer(_),
+ ..
+ } => None,
+ ast::Instruction::Add {
+ data: ast::ArithDetails::Integer(_),
+ ..
+ } => None,
+ ast::Instruction::Mul {
+ data: ast::MulDetails::Integer { .. },
+ ..
+ } => None,
+ ast::Instruction::Mad {
+ data: ast::MadDetails::Integer { .. },
+ ..
+ } => None,
+ ast::Instruction::Min {
+ data: ast::MinMaxDetails::Signed(_),
+ ..
+ } => None,
+ ast::Instruction::Min {
+ data: ast::MinMaxDetails::Unsigned(_),
+ ..
+ } => None,
+ ast::Instruction::Max {
+ data: ast::MinMaxDetails::Signed(_),
+ ..
+ } => None,
+ ast::Instruction::Max {
+ data: ast::MinMaxDetails::Unsigned(_),
+ ..
+ } => None,
+ ast::Instruction::Cvt {
+ data:
+ ast::CvtDetails {
+ mode:
+ ast::CvtMode::ZeroExtend
+ | ast::CvtMode::SignExtend
+ | ast::CvtMode::Truncate
+ | ast::CvtMode::Bitcast
+ | ast::CvtMode::SaturateUnsignedToSigned
+ | ast::CvtMode::SaturateSignedToUnsigned
+ | ast::CvtMode::FPFromSigned(_)
+ | ast::CvtMode::FPFromUnsigned(_),
+ ..
+ },
+ ..
+ } => None,
+ ast::Instruction::Div {
+ data: ast::DivDetails::Unsigned(_),
+ ..
+ } => None,
+ ast::Instruction::Div {
+ data: ast::DivDetails::Signed(_),
+ ..
+ } => None,
+ ast::Instruction::Clz { .. } => None,
+ ast::Instruction::Brev { .. } => None,
+ ast::Instruction::Popc { .. } => None,
+ ast::Instruction::Xor { .. } => None,
+ ast::Instruction::Bfe { .. } => None,
+ ast::Instruction::Bfi { .. } => None,
+ ast::Instruction::Rem { .. } => None,
+ ast::Instruction::Prmt { .. } => None,
+ ast::Instruction::Activemask { .. } => None,
+ ast::Instruction::Membar { .. } => None,
+ ast::Instruction::Sub {
+ data: ast::ArithDetails::Float(float_control),
+ ..
+ }
+ | ast::Instruction::Add {
+ data: ast::ArithDetails::Float(float_control),
+ ..
+ }
+ | ast::Instruction::Mul {
+ data: ast::MulDetails::Float(float_control),
+ ..
+ }
+ | ast::Instruction::Mad {
+ data: ast::MadDetails::Float(float_control),
+ ..
+ } => float_control
+ .flush_to_zero
+ .map(|ftz| (ftz, float_control.type_.size_of())),
+ ast::Instruction::Fma { data, .. } => {
+ data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
+ }
+ ast::Instruction::Setp { data, .. } => {
+ data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
+ }
+ ast::Instruction::SetpBool { data, .. } => data
+ .base
+ .flush_to_zero
+ .map(|ftz| (ftz, data.base.type_.size_of())),
+ ast::Instruction::Abs { data, .. }
+ | ast::Instruction::Rsqrt { data, .. }
+ | ast::Instruction::Neg { data, .. }
+ | ast::Instruction::Ex2 { data, .. } => {
+ data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
+ }
+ ast::Instruction::Min {
+ data: ast::MinMaxDetails::Float(float_control),
+ ..
+ }
+ | ast::Instruction::Max {
+ data: ast::MinMaxDetails::Float(float_control),
+ ..
+ } => float_control
+ .flush_to_zero
+ .map(|ftz| (ftz, ast::ScalarType::from(float_control.type_).size_of())),
+ ast::Instruction::Sqrt { data, .. } | ast::Instruction::Rcp { data, .. } => {
+ data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
+ }
+ // Modifier .ftz can only be specified when either .dtype or .atype
+ // is .f32 and applies only to single precision (.f32) inputs and results.
+ ast::Instruction::Cvt {
+ data:
+ ast::CvtDetails {
+ mode:
+ ast::CvtMode::FPExtend { flush_to_zero }
+ | ast::CvtMode::FPTruncate { flush_to_zero, .. }
+ | ast::CvtMode::FPRound { flush_to_zero, .. }
+ | ast::CvtMode::SignedFromFP { flush_to_zero, .. }
+ | ast::CvtMode::UnsignedFromFP { flush_to_zero, .. },
+ ..
+ },
+ ..
+ } => flush_to_zero.map(|ftz| (ftz, 4)),
+ ast::Instruction::Div {
+ data:
+ ast::DivDetails::Float(ast::DivFloatDetails {
+ type_,
+ flush_to_zero,
+ ..
+ }),
+ ..
+ } => flush_to_zero.map(|ftz| (ftz, type_.size_of())),
+ ast::Instruction::Sin { data, .. }
+ | ast::Instruction::Cos { data, .. }
+ | ast::Instruction::Lg2 { data, .. } => {
+ Some((data.flush_to_zero, mem::size_of::<f32>() as u8))
+ }
+ ptx_parser::Instruction::PrmtSlow { .. } => None,
+ ptx_parser::Instruction::Trap {} => None,
+ }
+}
+
// Running tally per key: positive totals mean "flush to zero" was seen more
// often than "preserve", negative totals mean the opposite.
type DenormCountMap<T> = HashMap<T, isize>;

/// Votes `+1` (flush) or `-1` (preserve) for `key`.
fn denorm_count_map_update<T: Eq + Hash>(map: &mut DenormCountMap<T>, key: T, value: bool) {
    let delta = if value { 1 } else { -1 };
    denorm_count_map_update_impl(map, key, delta);
}

/// Adds `num_value` to the tally for `key`, starting from zero for keys
/// not seen before.
fn denorm_count_map_update_impl<T: Eq + Hash>(
    map: &mut DenormCountMap<T>,
    key: T,
    num_value: isize,
) {
    *map.entry(key).or_insert(0) += num_value;
}
diff --git a/ptx/src/pass/normalize_identifiers.rs b/ptx/src/pass/normalize_identifiers.rs
new file mode 100644
index 0000000..b598345
--- /dev/null
+++ b/ptx/src/pass/normalize_identifiers.rs
@@ -0,0 +1,80 @@
+use super::*;
+use ptx_parser as ast;
+
+/// Identifier-normalization pass: rewrites a function body so that every
+/// string identifier (`&str`) is replaced by a numeric id, producing
+/// `NormalizedStatement`s for the later passes.
+pub(crate) fn run<'input, 'b>(
+    id_defs: &mut FnStringIdResolver<'input, 'b>,
+    fn_defs: &GlobalFnDeclResolver<'input, 'b>,
+    func: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
+) -> Result<Vec<NormalizedStatement>, TranslateError> {
+    // First pass: pre-register every label so that forward branches
+    // (a `bra` that appears before its target label) can resolve.
+    for s in func.iter() {
+        match s {
+            ast::Statement::Label(id) => {
+                id_defs.add_def(*id, None, false);
+            }
+            _ => (),
+        }
+    }
+    // Second pass: map every statement, expanding blocks and
+    // parameterized variable declarations along the way.
+    let mut result = Vec::new();
+    for s in func {
+        expand_map_variables(id_defs, fn_defs, &mut result, s)?;
+    }
+    Ok(result)
+}
+
+/// Maps a single statement from string identifiers to numeric ids,
+/// appending the normalized result(s) to `result`. Recurses into `{ }`
+/// blocks and flattens them — block structure does not survive this pass.
+fn expand_map_variables<'a, 'b>(
+    id_defs: &mut FnStringIdResolver<'a, 'b>,
+    fn_defs: &GlobalFnDeclResolver<'a, 'b>,
+    result: &mut Vec<NormalizedStatement>,
+    s: ast::Statement<ast::ParsedOperand<&'a str>>,
+) -> Result<(), TranslateError> {
+    match s {
+        ast::Statement::Block(block) => {
+            // A block introduces a fresh identifier scope; names declared
+            // inside must not leak to the enclosing scope.
+            id_defs.start_block();
+            for s in block {
+                expand_map_variables(id_defs, fn_defs, result, s)?;
+            }
+            id_defs.end_block();
+        }
+        // Labels were pre-registered by `run`, so this lookup must succeed.
+        ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)),
+        ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
+            // Map the guard predicate (if any), then every operand of the
+            // instruction itself; type/space/read-write hints are unused here.
+            p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id)))
+                .transpose()?,
+            ast::visit_map(i, &mut |id,
+                                    _: Option<(&ast::Type, ast::StateSpace)>,
+                                    _: bool,
+                                    _: bool| {
+                id_defs.get_id(id)
+            })?,
+        ))),
+        ast::Statement::Variable(var) => {
+            let var_type = var.var.v_type.clone();
+            match var.count {
+                // Parameterized declaration (e.g. `.reg .b32 r<4>` in PTX —
+                // TODO confirm): expands into `count` independent variables,
+                // each with its own fresh numeric id.
+                Some(count) => {
+                    for new_id in
+                        id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true)
+                    {
+                        result.push(Statement::Variable(ast::Variable {
+                            align: var.var.align,
+                            v_type: var.var.v_type.clone(),
+                            state_space: var.var.state_space,
+                            name: new_id,
+                            array_init: var.var.array_init.clone(),
+                        }))
+                    }
+                }
+                // Plain declaration: one variable, one new id.
+                None => {
+                    let new_id =
+                        id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true);
+                    result.push(Statement::Variable(ast::Variable {
+                        align: var.var.align,
+                        v_type: var.var.v_type.clone(),
+                        state_space: var.var.state_space,
+                        name: new_id,
+                        array_init: var.var.array_init,
+                    }));
+                }
+            }
+        }
+    };
+    Ok(())
+}
diff --git a/ptx/src/pass/normalize_labels.rs b/ptx/src/pass/normalize_labels.rs
new file mode 100644
index 0000000..097d87c
--- /dev/null
+++ b/ptx/src/pass/normalize_labels.rs
@@ -0,0 +1,48 @@
+use std::{collections::HashSet, iter};
+
+use super::*;
+
+/// Label-normalization pass: drops labels that no branch targets and
+/// prepends a fresh entry label to the function body.
+pub(super) fn run(
+    func: Vec<ExpandedStatement>,
+    id_def: &mut NumericIdResolver,
+) -> Vec<ExpandedStatement> {
+    // Collect every label that is actually referenced by a jump or a
+    // conditional branch.
+    let mut labels_in_use = HashSet::new();
+    for s in func.iter() {
+        match s {
+            Statement::Instruction(i) => {
+                if let Some(target) = jump_target(i) {
+                    labels_in_use.insert(target);
+                }
+            }
+            Statement::Conditional(cond) => {
+                // Both arms of a conditional are live targets.
+                labels_in_use.insert(cond.if_true);
+                labels_in_use.insert(cond.if_false);
+            }
+            // Exhaustive on purpose: adding a new Statement variant forces
+            // a decision here about whether it can reference a label.
+            Statement::Variable(..)
+            | Statement::LoadVar(..)
+            | Statement::StoreVar(..)
+            | Statement::RetValue(..)
+            | Statement::Conversion(..)
+            | Statement::Constant(..)
+            | Statement::Label(..)
+            | Statement::PtrAccess { .. }
+            | Statement::RepackVector(..)
+            | Statement::FunctionPointer(..) => {}
+        }
+    }
+    // Emit a synthetic entry label, then keep every statement except
+    // labels nothing jumps to.
+    iter::once(Statement::Label(id_def.register_intermediate(None)))
+        .chain(func.into_iter().filter(|s| match s {
+            Statement::Label(i) => labels_in_use.contains(i),
+            _ => true,
+        }))
+        .collect::<Vec<_>>()
+}
+
+// Returns the branch target of an instruction, or `None` when the
+// instruction does not transfer control. Only unconditional `bra` carries
+// a label operand here; predicated branches have already been rewritten
+// into `Statement::Conditional` by an earlier pass.
+fn jump_target<T: ast::Operand<Ident = SpirvWord>>(
+    this: &ast::Instruction<T>,
+) -> Option<SpirvWord> {
+    match this {
+        ast::Instruction::Bra { arguments } => Some(arguments.src),
+        _ => None,
+    }
+}
diff --git a/ptx/src/pass/normalize_predicates.rs b/ptx/src/pass/normalize_predicates.rs
new file mode 100644
index 0000000..c971cfa
--- /dev/null
+++ b/ptx/src/pass/normalize_predicates.rs
@@ -0,0 +1,44 @@
+use super::*;
+use ptx_parser as ast;
+
+/// Predicate-normalization pass: rewrites every guarded instruction
+/// (`@p inst`) into an explicit conditional branch plus an unguarded
+/// instruction, so later passes only ever see unconditional statements.
+pub(crate) fn run(
+    func: Vec<NormalizedStatement>,
+    id_def: &mut NumericIdResolver,
+) -> Result<Vec<UnconditionalStatement>, TranslateError> {
+    let mut result = Vec::with_capacity(func.len());
+    for s in func {
+        match s {
+            Statement::Label(id) => result.push(Statement::Label(id)),
+            Statement::Instruction((pred, inst)) => {
+                if let Some(pred) = pred {
+                    // Fresh labels for the taken / not-taken paths.
+                    // NOTE(review): `if_true` is allocated even when the
+                    // folded-bra shortcut below makes it unused — harmless,
+                    // just a wasted id.
+                    let if_true = id_def.register_intermediate(None);
+                    let if_false = id_def.register_intermediate(None);
+                    // `@p bra target` folds directly into a conditional
+                    // branch to `target` instead of emitting a trampoline.
+                    let folded_bra = match &inst {
+                        ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
+                        _ => None,
+                    };
+                    // NOTE(review): `BrachCondition` is a pre-existing typo
+                    // for "BranchCondition" in its declaring module.
+                    let mut branch = BrachCondition {
+                        predicate: pred.label,
+                        if_true: folded_bra.unwrap_or(if_true),
+                        if_false,
+                    };
+                    // `@!p` negates the guard: swap the branch arms.
+                    if pred.not {
+                        std::mem::swap(&mut branch.if_true, &mut branch.if_false);
+                    }
+                    result.push(Statement::Conditional(branch));
+                    if folded_bra.is_none() {
+                        // Non-branch instruction: emit it on the taken path.
+                        result.push(Statement::Label(if_true));
+                        result.push(Statement::Instruction(inst));
+                    }
+                    // Control rejoins here when the predicate is false.
+                    result.push(Statement::Label(if_false));
+                } else {
+                    result.push(Statement::Instruction(inst));
+                }
+            }
+            Statement::Variable(var) => result.push(Statement::Variable(var)),
+            // Blocks are flattened when resolving ids
+            _ => return Err(error_unreachable()),
+        }
+    }
+    Ok(result)
+}
diff --git a/ptx/src/test/spirv_run/clz.spvtxt b/ptx/src/test/spirv_run/clz.spvtxt
index 9a7f254..1feb5a0 100644
--- a/ptx/src/test/spirv_run/clz.spvtxt
+++ b/ptx/src/test/spirv_run/clz.spvtxt
@@ -7,20 +7,24 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %21 = OpExtInstImport "OpenCL.std"
+ OpCapability DenormFlushToZero
+ OpExtension "SPV_KHR_float_controls"
+ OpExtension "SPV_KHR_no_integer_wrap_decoration"
+ %22 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "clz"
+ OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
- %24 = OpTypeFunction %void %ulong %ulong
+ %25 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
- %1 = OpFunction %void None %24
+ %1 = OpFunction %void None %25
%7 = OpFunctionParameter %ulong
%8 = OpFunctionParameter %ulong
- %19 = OpLabel
+ %20 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
@@ -37,11 +41,12 @@
%11 = OpLoad %uint %17 Aligned 4
OpStore %6 %11
%14 = OpLoad %uint %6
- %13 = OpExtInst %uint %21 clz %14
+ %18 = OpExtInst %uint %22 clz %14
+ %13 = OpCopyObject %uint %18
OpStore %6 %13
%15 = OpLoad %ulong %5
%16 = OpLoad %uint %6
- %18 = OpConvertUToPtr %_ptr_Generic_uint %15
- OpStore %18 %16 Aligned 4
+ %19 = OpConvertUToPtr %_ptr_Generic_uint %15
+ OpStore %19 %16 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt b/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt
index 5f4b050..92322ec 100644
--- a/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt
+++ b/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt
@@ -7,6 +7,9 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
+ OpCapability DenormFlushToZero
+ OpExtension "SPV_KHR_float_controls"
+ OpExtension "SPV_KHR_no_integer_wrap_decoration"
%24 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "cvt_s16_s8"
@@ -45,9 +48,7 @@
%32 = OpBitcast %uint %15
%34 = OpUConvert %uchar %32
%20 = OpCopyObject %uchar %34
- %35 = OpBitcast %uchar %20
- %37 = OpSConvert %ushort %35
- %19 = OpCopyObject %ushort %37
+ %19 = OpSConvert %ushort %20
%14 = OpSConvert %uint %19
OpStore %6 %14
%16 = OpLoad %ulong %5
diff --git a/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt b/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt
index 3f46103..1165290 100644
--- a/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt
+++ b/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt
@@ -7,9 +7,13 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
+ OpCapability DenormFlushToZero
+ OpExtension "SPV_KHR_float_controls"
+ OpExtension "SPV_KHR_no_integer_wrap_decoration"
%24 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "cvt_s64_s32"
+ OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%27 = OpTypeFunction %void %ulong %ulong
@@ -40,9 +44,7 @@
%12 = OpCopyObject %uint %18
OpStore %6 %12
%15 = OpLoad %uint %6
- %32 = OpBitcast %uint %15
- %33 = OpSConvert %ulong %32
- %14 = OpCopyObject %ulong %33
+ %14 = OpSConvert %ulong %15
OpStore %7 %14
%16 = OpLoad %ulong %5
%17 = OpLoad %ulong %7
diff --git a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt
index b676049..07b228e 100644
--- a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt
+++ b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt
@@ -7,9 +7,13 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
+ OpCapability DenormFlushToZero
+ OpExtension "SPV_KHR_float_controls"
+ OpExtension "SPV_KHR_no_integer_wrap_decoration"
%25 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "cvt_sat_s_u"
+ OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%28 = OpTypeFunction %void %ulong %ulong
@@ -42,7 +46,7 @@
%15 = OpSatConvertSToU %uint %16
OpStore %7 %15
%18 = OpLoad %uint %7
- %17 = OpBitcast %uint %18
+ %17 = OpCopyObject %uint %18
OpStore %8 %17
%19 = OpLoad %ulong %5
%20 = OpLoad %uint %8
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index f5dfa64..a798720 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -1,3 +1,4 @@
+use crate::pass;
use crate::ptx;
use crate::translate;
use hip_runtime_sys::hipError_t;
@@ -385,10 +386,8 @@ fn test_spvtxt_assert<'a>(
spirv_txt: &'a [u8],
spirv_file_name: &'a str,
) -> Result<(), Box<dyn error::Error + 'a>> {
- let mut errors = Vec::new();
- let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
- assert!(errors.len() == 0);
- let spirv_module = translate::to_spirv_module(ast)?;
+ let ast = ptx_parser::parse_module_checked(ptx_txt).unwrap();
+ let spirv_module = pass::to_spirv_module(ast)?;
let spv_context =
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
assert!(spv_context != ptr::null_mut());
diff --git a/ptx/src/test/spirv_run/popc.spvtxt b/ptx/src/test/spirv_run/popc.spvtxt
index 845add7..c41e792 100644
--- a/ptx/src/test/spirv_run/popc.spvtxt
+++ b/ptx/src/test/spirv_run/popc.spvtxt
@@ -7,20 +7,24 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %21 = OpExtInstImport "OpenCL.std"
+ OpCapability DenormFlushToZero
+ OpExtension "SPV_KHR_float_controls"
+ OpExtension "SPV_KHR_no_integer_wrap_decoration"
+ %22 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "popc"
+ OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
- %24 = OpTypeFunction %void %ulong %ulong
+ %25 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
- %1 = OpFunction %void None %24
+ %1 = OpFunction %void None %25
%7 = OpFunctionParameter %ulong
%8 = OpFunctionParameter %ulong
- %19 = OpLabel
+ %20 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
@@ -37,11 +41,12 @@
%11 = OpLoad %uint %17 Aligned 4
OpStore %6 %11
%14 = OpLoad %uint %6
- %13 = OpBitCount %uint %14
+ %18 = OpBitCount %uint %14
+ %13 = OpCopyObject %uint %18
OpStore %6 %13
%15 = OpLoad %ulong %5
%16 = OpLoad %uint %6
- %18 = OpConvertUToPtr %_ptr_Generic_uint %15
- OpStore %18 %16 Aligned 4
+ %19 = OpConvertUToPtr %_ptr_Generic_uint %15
+ OpStore %19 %16 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/vector.ptx b/ptx/src/test/spirv_run/vector.ptx
index 90b8ad3..ba07e15 100644
--- a/ptx/src/test/spirv_run/vector.ptx
+++ b/ptx/src/test/spirv_run/vector.ptx
@@ -1,4 +1,4 @@
-// Excersise as many features of vector types as possible
+// Exercise as many features of vector types as possible
.version 6.5
.target sm_60
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index db1063b..9b422fd 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1608,17 +1608,13 @@ fn extract_globals<'input, 'b>(
for statement in sorted_statements {
match statement {
Statement::Variable(
- var
- @
- ast::Variable {
+ var @ ast::Variable {
state_space: ast::StateSpace::Shared,
..
},
)
| Statement::Variable(
- var
- @
- ast::Variable {
+ var @ ast::Variable {
state_space: ast::StateSpace::Global,
..
},
@@ -1660,9 +1656,7 @@ fn extract_globals<'input, 'b>(
)?);
}
Statement::Instruction(ast::Instruction::Atom(
- details
- @
- ast::AtomDetails {
+ details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Inc,
@@ -1691,9 +1685,7 @@ fn extract_globals<'input, 'b>(
)?);
}
Statement::Instruction(ast::Instruction::Atom(
- details
- @
- ast::AtomDetails {
+ details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Dec,
@@ -1722,9 +1714,7 @@ fn extract_globals<'input, 'b>(
)?);
}
Statement::Instruction(ast::Instruction::Atom(
- details
- @
- ast::AtomDetails {
+ details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Float {
op: ast::AtomFloatOp::Add,