diff options
34 files changed, 14776 insertions, 55 deletions
@@ -1,5 +1,7 @@ [workspace]
+resolver = "2"
+
members = [
"cuda_base",
"cuda_types",
@@ -15,6 +17,9 @@ members = [ "zluda_redirect",
"zluda_ml",
"ptx",
+ "ptx_parser",
+ "ptx_parser_macros",
+ "ptx_parser_macros_impl",
]
default-members = ["zluda_lib", "zluda_ml", "zluda_inject", "zluda_redirect"]
diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 2ac1f68..d485286 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -7,7 +7,7 @@ edition = "2018" [lib] [dependencies] -lalrpop-util = "0.19" +ptx_parser = { path = "../ptx_parser" } regex = "1" rspirv = "0.7" spirv_headers = "1.5" @@ -17,8 +17,12 @@ bit-vec = "0.6" half ="1.6" bitflags = "1.2" +[dependencies.lalrpop-util] +version = "0.19.12" +features = ["lexer"] + [build-dependencies.lalrpop] -version = "0.19" +version = "0.19.12" features = ["lexer"] [dev-dependencies] diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index d308479..358b8ce 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -16,6 +16,8 @@ pub enum PtxError { source: ParseFloatError, }, #[error("")] + Unsupported32Bit, + #[error("")] SyntaxError, #[error("")] NonF32Ftz, @@ -32,15 +34,9 @@ pub enum PtxError { #[error("")] NonExternPointer, #[error("{start}:{end}")] - UnrecognizedStatement { - start: usize, - end: usize, - }, + UnrecognizedStatement { start: usize, end: usize }, #[error("{start}:{end}")] - UnrecognizedDirective { - start: usize, - end: usize, - }, + UnrecognizedDirective { start: usize, end: usize }, } // For some weird reson this is illegal: @@ -576,11 +572,15 @@ impl CvtDetails { if saturate { if src.kind() == ScalarKind::Signed { if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() { - err.push(ParseError::from(PtxError::SyntaxError)); + err.push(ParseError::User { + error: PtxError::SyntaxError, + }); } } else { if dst == src || dst.size_of() >= src.size_of() { - err.push(ParseError::from(PtxError::SyntaxError)); + err.push(ParseError::User { + error: PtxError::SyntaxError, + }); } } } @@ -596,7 +596,9 @@ impl CvtDetails { err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>, ) -> Self { if flush_to_zero && dst != ScalarType::F32 { - err.push(ParseError::from(PtxError::NonF32Ftz)); + err.push(ParseError::from(lalrpop_util::ParseError::User { + error: PtxError::NonF32Ftz, + })); } CvtDetails::FloatFromInt(CvtDesc { dst, @@ -616,7 +618,9 @@ impl CvtDetails { err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>, ) -> Self { if flush_to_zero && src != ScalarType::F32 { - err.push(ParseError::from(PtxError::NonF32Ftz)); + err.push(ParseError::from(lalrpop_util::ParseError::User { + error: PtxError::NonF32Ftz, + })); } CvtDetails::IntFromFloat(CvtDesc { dst, diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 1cb9630..5e95dae 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -24,6 +24,7 @@ lalrpop_mod!( ); pub mod ast; +pub(crate) mod pass; #[cfg(test)] mod test; mod translate; diff --git a/ptx/src/pass/convert_dynamic_shared_memory_usage.rs b/ptx/src/pass/convert_dynamic_shared_memory_usage.rs new file mode 100644 index 0000000..1dac7fd --- /dev/null +++ b/ptx/src/pass/convert_dynamic_shared_memory_usage.rs @@ -0,0 +1,299 @@ +use std::collections::{BTreeMap, BTreeSet};
+
+use super::*;
+
+/*
+ PTX represents dynamically allocated shared local memory as
+ .extern .shared .b32 shared_mem[];
+ In SPIRV/OpenCL world this is expressed as an additional argument to the kernel
+ And in AMD compilation
+ This pass looks for all uses of .extern .shared and converts them to
+ an additional method argument
+ The question is how this artificial argument should be expressed. There are
+ several options:
+ * Straight conversion:
+ .shared .b32 shared_mem[]
+ * Introduce .param_shared statespace:
+ .param_shared .b32 shared_mem
+ or
+ .param_shared .b32 shared_mem[]
+ * Introduce .shared_ptr <SCALAR> type:
+ .param .shared_ptr .b32 shared_mem
+ * Reuse .ptr hint:
+ .param .u64 .ptr shared_mem
+ This is the most tempting, but also the most nonsensical, .ptr is just a
+ hint, which has no semantical meaning (and the output of our
+ transformation has a semantical meaning - we emit additional
+ "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
+*/
+pub(super) fn run<'input>(
+ module: Vec<Directive<'input>>,
+ kernels_methods_call_map: &MethodsCallMap<'input>,
+ new_id: &mut impl FnMut() -> SpirvWord,
+) -> Result<Vec<Directive<'input>>, TranslateError> {
+ let mut globals_shared = HashMap::new();
+ for dir in module.iter() {
+ match dir {
+ Directive::Variable(
+ _,
+ ast::Variable {
+ state_space: ast::StateSpace::Shared,
+ name,
+ v_type,
+ ..
+ },
+ ) => {
+ globals_shared.insert(*name, v_type.clone());
+ }
+ _ => {}
+ }
+ }
+ if globals_shared.len() == 0 {
+ return Ok(module);
+ }
+ let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet<SpirvWord>>::new();
+ let module = module
+ .into_iter()
+ .map(|directive| match directive {
+ Directive::Method(Function {
+ func_decl,
+ globals,
+ body: Some(statements),
+ import_as,
+ tuning,
+ linkage,
+ }) => {
+ let call_key = (*func_decl).borrow().name;
+ let statements = statements
+ .into_iter()
+ .map(|statement| {
+ statement.visit_map(
+ &mut |id, _: Option<(&ast::Type, ast::StateSpace)>, _, _| {
+ if let Some(_) = globals_shared.get(&id) {
+ methods_to_directly_used_shared_globals
+ .entry(call_key)
+ .or_insert_with(HashSet::new)
+ .insert(id);
+ }
+ Ok::<_, TranslateError>(id)
+ },
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok::<_, TranslateError>(Directive::Method(Function {
+ func_decl,
+ globals,
+ body: Some(statements),
+ import_as,
+ tuning,
+ linkage,
+ }))
+ }
+ directive => Ok(directive),
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
+ // make sure it gets propagated to `fn1` and `kernel`
+ let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared(
+ methods_to_directly_used_shared_globals,
+ kernels_methods_call_map,
+ );
+ // now visit every method declaration and inject those additional arguments
+ let mut directives = Vec::with_capacity(module.len());
+ for directive in module.into_iter() {
+ match directive {
+ Directive::Method(Function {
+ func_decl,
+ globals,
+ body: Some(statements),
+ import_as,
+ tuning,
+ linkage,
+ }) => {
+ let statements = {
+ let func_decl_ref = &mut (*func_decl).borrow_mut();
+ let method_name = func_decl_ref.name;
+ insert_arguments_remap_statements(
+ new_id,
+ kernels_methods_call_map,
+ &globals_shared,
+ &methods_to_indirectly_used_shared_globals,
+ method_name,
+ &mut directives,
+ func_decl_ref,
+ statements,
+ )?
+ };
+ directives.push(Directive::Method(Function {
+ func_decl,
+ globals,
+ body: Some(statements),
+ import_as,
+ tuning,
+ linkage,
+ }));
+ }
+ directive => directives.push(directive),
+ }
+ }
+ Ok(directives)
+}
+
+// We need to compute two kinds of information:
+// * If it's a kernel -> size of .shared globals in use (direct or indirect)
+// * If it's a function -> does it use .shared global (directly or indirectly)
+fn resolve_indirect_uses_of_globals_shared<'input>(
+ methods_use_of_globals_shared: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
+ kernels_methods_call_map: &MethodsCallMap<'input>,
+) -> HashMap<ast::MethodName<'input, SpirvWord>, BTreeSet<SpirvWord>> {
+ let mut result = HashMap::new();
+ for (method, callees) in kernels_methods_call_map.methods() {
+ let mut indirect_globals = methods_use_of_globals_shared
+ .get(&method)
+ .into_iter()
+ .flatten()
+ .copied()
+ .collect::<BTreeSet<_>>();
+ for &callee in callees {
+ indirect_globals.extend(
+ methods_use_of_globals_shared
+ .get(&ast::MethodName::Func(callee))
+ .into_iter()
+ .flatten()
+ .copied(),
+ );
+ }
+ result.insert(method, indirect_globals);
+ }
+ result
+}
+
+fn insert_arguments_remap_statements<'input>(
+ new_id: &mut impl FnMut() -> SpirvWord,
+ kernels_methods_call_map: &MethodsCallMap<'input>,
+ globals_shared: &HashMap<SpirvWord, ast::Type>,
+ methods_to_indirectly_used_shared_globals: &HashMap<
+ ast::MethodName<'input, SpirvWord>,
+ BTreeSet<SpirvWord>,
+ >,
+ method_name: ast::MethodName<SpirvWord>,
+ result: &mut Vec<Directive>,
+ func_decl_ref: &mut std::cell::RefMut<ast::MethodDeclaration<SpirvWord>>,
+ statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
+) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+ let remapped_globals_in_method =
+ if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) {
+ match method_name {
+ ast::MethodName::Func(..) => {
+ let remapped_globals = method_globals
+ .iter()
+ .map(|global| {
+ (
+ *global,
+ (
+ new_id(),
+ globals_shared
+ .get(&global)
+ .unwrap_or_else(|| todo!())
+ .clone(),
+ ),
+ )
+ })
+ .collect::<BTreeMap<_, _>>();
+ for (_, (new_shared_global_id, shared_global_type)) in remapped_globals.iter() {
+ func_decl_ref.input_arguments.push(ast::Variable {
+ align: None,
+ v_type: shared_global_type.clone(),
+ state_space: ast::StateSpace::Shared,
+ name: *new_shared_global_id,
+ array_init: Vec::new(),
+ });
+ }
+ remapped_globals
+ }
+ ast::MethodName::Kernel(..) => method_globals
+ .iter()
+ .map(|global| {
+ (
+ *global,
+ (
+ *global,
+ globals_shared
+ .get(&global)
+ .unwrap_or_else(|| todo!())
+ .clone(),
+ ),
+ )
+ })
+ .collect::<BTreeMap<_, _>>(),
+ }
+ } else {
+ return Ok(statements);
+ };
+ replace_uses_of_shared_memory(
+ new_id,
+ methods_to_indirectly_used_shared_globals,
+ statements,
+ remapped_globals_in_method,
+ )
+}
+
+fn replace_uses_of_shared_memory<'input>(
+ new_id: &mut impl FnMut() -> SpirvWord,
+ methods_to_indirectly_used_shared_globals: &HashMap<
+ ast::MethodName<'input, SpirvWord>,
+ BTreeSet<SpirvWord>,
+ >,
+ statements: Vec<ExpandedStatement>,
+ remapped_globals_in_method: BTreeMap<SpirvWord, (SpirvWord, ast::Type)>,
+) -> Result<Vec<ExpandedStatement>, TranslateError> {
+ let mut result = Vec::with_capacity(statements.len());
+ for statement in statements {
+ match statement {
+ Statement::Instruction(ast::Instruction::Call {
+ mut data,
+ mut arguments,
+ }) => {
+ // We can safely skip checking call arguments,
+ // because there's simply no way to pass shared ptr
+ // without converting it to .b64 first
+ if let Some(shared_globals_used_by_callee) =
+ methods_to_indirectly_used_shared_globals
+ .get(&ast::MethodName::Func(arguments.func))
+ {
+ for &shared_global_used_by_callee in shared_globals_used_by_callee {
+ let (remapped_shared_id, type_) = remapped_globals_in_method
+ .get(&shared_global_used_by_callee)
+ .unwrap_or_else(|| todo!());
+ data.input_arguments
+ .push((type_.clone(), ast::StateSpace::Shared));
+ arguments.input_arguments.push(*remapped_shared_id);
+ }
+ }
+ result.push(Statement::Instruction(ast::Instruction::Call {
+ data,
+ arguments,
+ }))
+ }
+ statement => {
+ let new_statement =
+ statement.visit_map(&mut |id,
+ _: Option<(&ast::Type, ast::StateSpace)>,
+ _,
+ _| {
+ Ok::<_, TranslateError>(
+ if let Some((remapped_shared_id, _)) =
+ remapped_globals_in_method.get(&id)
+ {
+ *remapped_shared_id
+ } else {
+ id
+ },
+ )
+ })?;
+ result.push(new_statement);
+ }
+ }
+ }
+ Ok(result)
+}
diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs b/ptx/src/pass/convert_to_stateful_memory_access.rs new file mode 100644 index 0000000..455a8c2 --- /dev/null +++ b/ptx/src/pass/convert_to_stateful_memory_access.rs @@ -0,0 +1,524 @@ +use super::*;
+use ptx_parser as ast;
+use std::{
+ collections::{BTreeSet, HashSet},
+ iter,
+ rc::Rc,
+};
+
+/*
+ Our goal here is to transform
+ .visible .entry foobar(.param .u64 input) {
+ .reg .b64 in_addr;
+ .reg .b64 in_addr2;
+ ld.param.u64 in_addr, [input];
+ cvta.to.global.u64 in_addr2, in_addr;
+ }
+ into:
+ .visible .entry foobar(.param .u8 input[]) {
+ .reg .u8 in_addr[];
+ .reg .u8 in_addr2[];
+ ld.param.u8[] in_addr, [input];
+ mov.u8[] in_addr2, in_addr;
+ }
+ or:
+ .visible .entry foobar(.reg .u8 input[]) {
+ .reg .u8 in_addr[];
+ .reg .u8 in_addr2[];
+ mov.u8[] in_addr, input;
+ mov.u8[] in_addr2, in_addr;
+ }
+ or:
+ .visible .entry foobar(.param ptr<u8, global> input) {
+ .reg ptr<u8, global> in_addr;
+ .reg ptr<u8, global> in_addr2;
+ ld.param.ptr<u8, global> in_addr, [input];
+ mov.ptr<u8, global> in_addr2, in_addr;
+ }
+*/
+// TODO: detect more patterns (mov, call via reg, call via param)
+// TODO: don't convert to ptr if the register is not ultimately used for ld/st
+// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
+// argument expansion
+// TODO: propagate out of calls and into calls
+pub(super) fn run<'a, 'input>(
+ func_args: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
+ func_body: Vec<TypedStatement>,
+ id_defs: &mut NumericIdResolver<'a>,
+) -> Result<
+ (
+ Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
+ Vec<TypedStatement>,
+ ),
+ TranslateError,
+> {
+ let mut method_decl = func_args.borrow_mut();
+ if !matches!(method_decl.name, ast::MethodName::Kernel(..)) {
+ drop(method_decl);
+ return Ok((func_args, func_body));
+ }
+ if Rc::strong_count(&func_args) != 1 {
+ return Err(error_unreachable());
+ }
+ let func_args_64bit = (*method_decl)
+ .input_arguments
+ .iter()
+ .filter_map(|arg| match arg.v_type {
+ ast::Type::Scalar(ast::ScalarType::U64)
+ | ast::Type::Scalar(ast::ScalarType::B64)
+ | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name),
+ _ => None,
+ })
+ .collect::<HashSet<_>>();
+ let mut stateful_markers = Vec::new();
+ let mut stateful_init_reg = HashMap::<_, Vec<_>>::new();
+ for statement in func_body.iter() {
+ match statement {
+ Statement::Instruction(ast::Instruction::Cvta {
+ data:
+ ast::CvtaDetails {
+ state_space: ast::StateSpace::Global,
+ direction: ast::CvtaDirection::GenericToExplicit,
+ },
+ arguments,
+ }) => {
+ if let (TypedOperand::Reg(dst), Some(src)) =
+ (arguments.dst, arguments.src.underlying_register())
+ {
+ if is_64_bit_integer(id_defs, src) && is_64_bit_integer(id_defs, dst) {
+ stateful_markers.push((dst, src));
+ }
+ }
+ }
+ Statement::Instruction(ast::Instruction::Ld {
+ data:
+ ast::LdDetails {
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::U64),
+ ..
+ },
+ arguments,
+ })
+ | Statement::Instruction(ast::Instruction::Ld {
+ data:
+ ast::LdDetails {
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::S64),
+ ..
+ },
+ arguments,
+ })
+ | Statement::Instruction(ast::Instruction::Ld {
+ data:
+ ast::LdDetails {
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::B64),
+ ..
+ },
+ arguments,
+ }) => {
+ if let (TypedOperand::Reg(dst), Some(src)) =
+ (arguments.dst, arguments.src.underlying_register())
+ {
+ if func_args_64bit.contains(&src) {
+ multi_hash_map_append(&mut stateful_init_reg, dst, src);
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ if stateful_markers.len() == 0 {
+ drop(method_decl);
+ return Ok((func_args, func_body));
+ }
+ let mut func_args_ptr = HashSet::new();
+ let mut regs_ptr_current = HashSet::new();
+ for (dst, src) in stateful_markers {
+ if let Some(func_args) = stateful_init_reg.get(&src) {
+ for a in func_args {
+ func_args_ptr.insert(*a);
+ regs_ptr_current.insert(src);
+ regs_ptr_current.insert(dst);
+ }
+ }
+ }
+ // BTreeSet here to have a stable order of iteration,
+ // unfortunately our tests rely on it
+ let mut regs_ptr_seen = BTreeSet::new();
+ while regs_ptr_current.len() > 0 {
+ let mut regs_ptr_new = HashSet::new();
+ for statement in func_body.iter() {
+ match statement {
+ Statement::Instruction(ast::Instruction::Add {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::U64,
+ saturate: false,
+ }),
+ arguments,
+ })
+ | Statement::Instruction(ast::Instruction::Add {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::S64,
+ saturate: false,
+ }),
+ arguments,
+ }) => {
+ // TODO: don't mark result of double pointer sub or double
+ // pointer add as ptr result
+ if let (TypedOperand::Reg(dst), Some(src1)) =
+ (arguments.dst, arguments.src1.underlying_register())
+ {
+ if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
+ regs_ptr_new.insert(dst);
+ }
+ } else if let (TypedOperand::Reg(dst), Some(src2)) =
+ (arguments.dst, arguments.src2.underlying_register())
+ {
+ if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
+ regs_ptr_new.insert(dst);
+ }
+ }
+ }
+
+ Statement::Instruction(ast::Instruction::Sub {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::U64,
+ saturate: false,
+ }),
+ arguments,
+ })
+ | Statement::Instruction(ast::Instruction::Sub {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::S64,
+ saturate: false,
+ }),
+ arguments,
+ }) => {
+ // TODO: don't mark result of double pointer sub or double
+ // pointer add as ptr result
+ if let (TypedOperand::Reg(dst), Some(src1)) =
+ (arguments.dst, arguments.src1.underlying_register())
+ {
+ if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
+ regs_ptr_new.insert(dst);
+ }
+ } else if let (TypedOperand::Reg(dst), Some(src2)) =
+ (arguments.dst, arguments.src2.underlying_register())
+ {
+ if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
+ regs_ptr_new.insert(dst);
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ for id in regs_ptr_current {
+ regs_ptr_seen.insert(id);
+ }
+ regs_ptr_current = regs_ptr_new;
+ }
+ drop(regs_ptr_current);
+ let mut remapped_ids = HashMap::new();
+ let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
+ for reg in regs_ptr_seen {
+ let new_id = id_defs.register_variable(
+ ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ ast::StateSpace::Reg,
+ );
+ result.push(Statement::Variable(ast::Variable {
+ align: None,
+ name: new_id,
+ array_init: Vec::new(),
+ v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ state_space: ast::StateSpace::Reg,
+ }));
+ remapped_ids.insert(reg, new_id);
+ }
+ for arg in (*method_decl).input_arguments.iter_mut() {
+ if !func_args_ptr.contains(&arg.name) {
+ continue;
+ }
+ let new_id = id_defs.register_variable(
+ ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ ast::StateSpace::Param,
+ );
+ let old_name = arg.name;
+ arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
+ arg.name = new_id;
+ remapped_ids.insert(old_name, new_id);
+ }
+ for statement in func_body {
+ match statement {
+ l @ Statement::Label(_) => result.push(l),
+ c @ Statement::Conditional(_) => result.push(c),
+ c @ Statement::Constant(..) => result.push(c),
+ Statement::Variable(var) => {
+ if !remapped_ids.contains_key(&var.name) {
+ result.push(Statement::Variable(var));
+ }
+ }
+ Statement::Instruction(ast::Instruction::Add {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::U64,
+ saturate: false,
+ }),
+ arguments,
+ })
+ | Statement::Instruction(ast::Instruction::Add {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::S64,
+ saturate: false,
+ }),
+ arguments,
+ }) if is_add_ptr_direct(&remapped_ids, &arguments) => {
+ let (ptr, offset) = match arguments.src1.underlying_register() {
+ Some(src1) if remapped_ids.contains_key(&src1) => {
+ (remapped_ids.get(&src1).unwrap(), arguments.src2)
+ }
+ Some(src2) if remapped_ids.contains_key(&src2) => {
+ (remapped_ids.get(&src2).unwrap(), arguments.src1)
+ }
+ _ => return Err(error_unreachable()),
+ };
+ let dst = arguments.dst.unwrap_reg()?;
+ result.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
+ state_space: ast::StateSpace::Global,
+ dst: *remapped_ids.get(&dst).unwrap(),
+ ptr_src: *ptr,
+ offset_src: offset,
+ }))
+ }
+ Statement::Instruction(ast::Instruction::Sub {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::U64,
+ saturate: false,
+ }),
+ arguments,
+ })
+ | Statement::Instruction(ast::Instruction::Sub {
+ data:
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: ast::ScalarType::S64,
+ saturate: false,
+ }),
+ arguments,
+ }) if is_sub_ptr_direct(&remapped_ids, &arguments) => {
+ let (ptr, offset) = match arguments.src1.underlying_register() {
+ Some(ref src1) => (remapped_ids.get(src1).unwrap(), arguments.src2),
+ _ => return Err(error_unreachable()),
+ };
+ let offset_neg = id_defs.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
+ )));
+ result.push(Statement::Instruction(ast::Instruction::Neg {
+ data: ast::TypeFtz {
+ type_: ast::ScalarType::S64,
+ flush_to_zero: None,
+ },
+ arguments: ast::NegArgs {
+ src: offset,
+ dst: TypedOperand::Reg(offset_neg),
+ },
+ }));
+ let dst = arguments.dst.unwrap_reg()?;
+ result.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
+ state_space: ast::StateSpace::Global,
+ dst: *remapped_ids.get(&dst).unwrap(),
+ ptr_src: *ptr,
+ offset_src: TypedOperand::Reg(offset_neg),
+ }))
+ }
+ inst @ Statement::Instruction(_) => {
+ let mut post_statements = Vec::new();
+ let new_statement = inst.visit_map(&mut FnVisitor::new(
+ |operand, type_space, is_dst, relaxed_conversion| {
+ convert_to_stateful_memory_access_postprocess(
+ id_defs,
+ &remapped_ids,
+ &mut result,
+ &mut post_statements,
+ operand,
+ type_space,
+ is_dst,
+ relaxed_conversion,
+ )
+ },
+ ))?;
+ result.push(new_statement);
+ result.extend(post_statements);
+ }
+ repack @ Statement::RepackVector(_) => {
+ let mut post_statements = Vec::new();
+ let new_statement = repack.visit_map(&mut FnVisitor::new(
+ |operand, type_space, is_dst, relaxed_conversion| {
+ convert_to_stateful_memory_access_postprocess(
+ id_defs,
+ &remapped_ids,
+ &mut result,
+ &mut post_statements,
+ operand,
+ type_space,
+ is_dst,
+ relaxed_conversion,
+ )
+ },
+ ))?;
+ result.push(new_statement);
+ result.extend(post_statements);
+ }
+ _ => return Err(error_unreachable()),
+ }
+ }
+ drop(method_decl);
+ Ok((func_args, result))
+}
+
+fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool {
+ match id_defs.get_typed(id) {
+ Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true,
+ _ => false,
+ }
+}
+
+fn is_add_ptr_direct(
+ remapped_ids: &HashMap<SpirvWord, SpirvWord>,
+ arg: &ast::AddArgs<TypedOperand>,
+) -> bool {
+ match arg.dst {
+ TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
+ return false
+ }
+ TypedOperand::Reg(dst) => {
+ if !remapped_ids.contains_key(&dst) {
+ return false;
+ }
+ if let Some(ref src1_reg) = arg.src1.underlying_register() {
+ if remapped_ids.contains_key(src1_reg) {
+ // don't trigger optimization when adding two pointers
+ if let Some(ref src2_reg) = arg.src2.underlying_register() {
+ return !remapped_ids.contains_key(src2_reg);
+ }
+ }
+ }
+ if let Some(ref src2_reg) = arg.src2.underlying_register() {
+ remapped_ids.contains_key(src2_reg)
+ } else {
+ false
+ }
+ }
+ }
+}
+
+fn is_sub_ptr_direct(
+ remapped_ids: &HashMap<SpirvWord, SpirvWord>,
+ arg: &ast::SubArgs<TypedOperand>,
+) -> bool {
+ match arg.dst {
+ TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
+ return false
+ }
+ TypedOperand::Reg(dst) => {
+ if !remapped_ids.contains_key(&dst) {
+ return false;
+ }
+ match arg.src1.underlying_register() {
+ Some(ref src1_reg) => {
+ if remapped_ids.contains_key(src1_reg) {
+ // don't trigger optimization when subtracting two pointers
+ arg.src2
+ .underlying_register()
+ .map_or(true, |ref src2_reg| !remapped_ids.contains_key(src2_reg))
+ } else {
+ false
+ }
+ }
+ None => false,
+ }
+ }
+ }
+}
+
+fn convert_to_stateful_memory_access_postprocess(
+ id_defs: &mut NumericIdResolver,
+ remapped_ids: &HashMap<SpirvWord, SpirvWord>,
+ result: &mut Vec<TypedStatement>,
+ post_statements: &mut Vec<TypedStatement>,
+ operand: TypedOperand,
+ type_space: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst: bool,
+ relaxed_conversion: bool,
+) -> Result<TypedOperand, TranslateError> {
+ operand.map(|operand, _| {
+ Ok(match remapped_ids.get(&operand) {
+ Some(new_id) => {
+ let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?;
+ // TODO: readd if required
+ if let Some((expected_type, expected_space)) = type_space {
+ let implicit_conversion = if relaxed_conversion {
+ if is_dst {
+ super::insert_implicit_conversions::should_convert_relaxed_dst_wrapper
+ } else {
+ super::insert_implicit_conversions::should_convert_relaxed_src_wrapper
+ }
+ } else {
+ super::insert_implicit_conversions::default_implicit_conversion
+ };
+ if implicit_conversion(
+ (new_operand_space, &new_operand_type),
+ (expected_space, expected_type),
+ )
+ .is_ok()
+ {
+ return Ok(*new_id);
+ }
+ }
+ let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?;
+ let converting_id = id_defs
+ .register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
+ let kind = if space_is_compatible(new_operand_space, ast::StateSpace::Reg) {
+ ConversionKind::Default
+ } else {
+ ConversionKind::PtrToPtr
+ };
+ if is_dst {
+ post_statements.push(Statement::Conversion(ImplicitConversion {
+ src: converting_id,
+ dst: *new_id,
+ from_type: old_operand_type,
+ from_space: old_operand_space,
+ to_type: new_operand_type,
+ to_space: new_operand_space,
+ kind,
+ }));
+ converting_id
+ } else {
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: *new_id,
+ dst: converting_id,
+ from_type: new_operand_type,
+ from_space: new_operand_space,
+ to_type: old_operand_type,
+ to_space: old_operand_space,
+ kind,
+ }));
+ converting_id
+ }
+ }
+ None => operand,
+ })
+ })
+}
diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs new file mode 100644 index 0000000..550c662 --- /dev/null +++ b/ptx/src/pass/convert_to_typed.rs @@ -0,0 +1,138 @@ +use super::*;
+use ptx_parser as ast;
+
+pub(crate) fn run(
+ func: Vec<UnconditionalStatement>,
+ fn_defs: &GlobalFnDeclResolver,
+ id_defs: &mut NumericIdResolver,
+) -> Result<Vec<TypedStatement>, TranslateError> {
+ let mut result = Vec::<TypedStatement>::with_capacity(func.len());
+ for s in func {
+ match s {
+ Statement::Instruction(inst) => match inst {
+ ast::Instruction::Mov {
+ data,
+ arguments:
+ ast::MovArgs {
+ dst: ast::ParsedOperand::Reg(dst_reg),
+ src: ast::ParsedOperand::Reg(src_reg),
+ },
+ } if fn_defs.fns.contains_key(&src_reg) => {
+ if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
+ return Err(error_mismatched_type());
+ }
+ result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
+ dst: dst_reg,
+ src: src_reg,
+ }));
+ }
+ ast::Instruction::Call { data, arguments } => {
+ let resolver = fn_defs.get_fn_sig_resolver(arguments.func)?;
+ let resolved_call = resolver.resolve_in_spirv_repr(data, arguments)?;
+ let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
+ let reresolved_call =
+ Statement::Instruction(ast::visit_map(resolved_call, &mut visitor)?);
+ visitor.func.push(reresolved_call);
+ visitor.func.extend(visitor.post_stmts);
+ }
+ inst => {
+ let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
+ let instruction = Statement::Instruction(ast::visit_map(inst, &mut visitor)?);
+ visitor.func.push(instruction);
+ visitor.func.extend(visitor.post_stmts);
+ }
+ },
+ Statement::Label(i) => result.push(Statement::Label(i)),
+ Statement::Variable(v) => result.push(Statement::Variable(v)),
+ Statement::Conditional(c) => result.push(Statement::Conditional(c)),
+ _ => return Err(error_unreachable()),
+ }
+ }
+ Ok(result)
+}
+
+struct VectorRepackVisitor<'a, 'b> {
+ func: &'b mut Vec<TypedStatement>,
+ id_def: &'b mut NumericIdResolver<'a>,
+ post_stmts: Option<TypedStatement>,
+}
+
+impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
+ fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
+ VectorRepackVisitor {
+ func,
+ id_def,
+ post_stmts: None,
+ }
+ }
+
+ fn convert_vector(
+ &mut self,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ typ: &ast::Type,
+ state_space: ast::StateSpace,
+ idx: Vec<SpirvWord>,
+ ) -> Result<SpirvWord, TranslateError> {
+ // mov.u32 foobar, {a,b};
+ let scalar_t = match typ {
+ ast::Type::Vector(_, scalar_t) => *scalar_t,
+ _ => return Err(error_mismatched_type()),
+ };
+ let temp_vec = self
+ .id_def
+ .register_intermediate(Some((typ.clone(), state_space)));
+ let statement = Statement::RepackVector(RepackVectorDetails {
+ is_extract: is_dst,
+ typ: scalar_t,
+ packed: temp_vec,
+ unpacked: idx,
+ relaxed_type_check,
+ });
+ if is_dst {
+ self.post_stmts = Some(statement);
+ } else {
+ self.func.push(statement);
+ }
+ Ok(temp_vec)
+ }
+}
+
+impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, TypedOperand, TranslateError>
+ for VectorRepackVisitor<'a, 'b>
+{
+ fn visit_ident(
+ &mut self,
+ ident: SpirvWord,
+ _: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+ _: bool,
+ _: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ Ok(ident)
+ }
+
+ fn visit(
+ &mut self,
+ op: ast::ParsedOperand<SpirvWord>,
+ type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<TypedOperand, TranslateError> {
+ Ok(match op {
+ ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg),
+ ast::ParsedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
+ ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x),
+ ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
+ ast::ParsedOperand::VecPack(vec) => {
+ let (type_, space) = type_space.ok_or_else(|| error_mismatched_type())?;
+ TypedOperand::Reg(self.convert_vector(
+ is_dst,
+ relaxed_type_check,
+ type_,
+ space,
+ vec,
+ )?)
+ }
+ })
+ }
+}
diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs new file mode 100644 index 0000000..5147b79 --- /dev/null +++ b/ptx/src/pass/emit_spirv.rs @@ -0,0 +1,2763 @@ +use super::*;
+use half::f16;
+use ptx_parser as ast;
+use rspirv::{binary::Assemble, dr};
+use std::{
+ collections::{HashMap, HashSet},
+ ffi::CString,
+ mem,
+};
+
+pub(super) fn run<'input>(
+ mut builder: dr::Builder,
+ id_defs: &GlobalStringIdResolver<'input>,
+ call_map: MethodsCallMap<'input>,
+ denorm_information: HashMap<
+ ptx_parser::MethodName<SpirvWord>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
+ directives: Vec<Directive<'input>>,
+) -> Result<(dr::Module, HashMap<String, KernelInfo>, CString), TranslateError> {
+ builder.set_version(1, 3);
+ emit_capabilities(&mut builder);
+ emit_extensions(&mut builder);
+ let opencl_id = emit_opencl_import(&mut builder);
+ emit_memory_model(&mut builder);
+ let mut map = TypeWordMap::new(&mut builder);
+ //emit_builtins(&mut builder, &mut map, &id_defs);
+ let mut kernel_info = HashMap::new();
+ let (build_options, should_flush_denorms) =
+ emit_denorm_build_string(&call_map, &denorm_information);
+ let (directives, globals_use_map) = get_globals_use_map(directives);
+ emit_directives(
+ &mut builder,
+ &mut map,
+ &id_defs,
+ opencl_id,
+ should_flush_denorms,
+ &call_map,
+ globals_use_map,
+ directives,
+ &mut kernel_info,
+ )?;
+ Ok((builder.module(), kernel_info, build_options))
+}
+
+fn emit_capabilities(builder: &mut dr::Builder) {
+ builder.capability(spirv::Capability::GenericPointer);
+ builder.capability(spirv::Capability::Linkage);
+ builder.capability(spirv::Capability::Addresses);
+ builder.capability(spirv::Capability::Kernel);
+ builder.capability(spirv::Capability::Int8);
+ builder.capability(spirv::Capability::Int16);
+ builder.capability(spirv::Capability::Int64);
+ builder.capability(spirv::Capability::Float16);
+ builder.capability(spirv::Capability::Float64);
+ builder.capability(spirv::Capability::DenormFlushToZero);
+ // TODO: re-enable when Intel float control extension works
+ //builder.capability(spirv::Capability::FunctionFloatControlINTEL);
+}
+
+// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html
+fn emit_extensions(builder: &mut dr::Builder) {
+ // TODO: re-enable when Intel float control extension works
+ //builder.extension("SPV_INTEL_float_controls2");
+ builder.extension("SPV_KHR_float_controls");
+ builder.extension("SPV_KHR_no_integer_wrap_decoration");
+}
+
+fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word {
+ builder.ext_inst_import("OpenCL.std")
+}
+
+fn emit_memory_model(builder: &mut dr::Builder) {
+ builder.memory_model(
+ spirv::AddressingModel::Physical64,
+ spirv::MemoryModel::OpenCL,
+ );
+}
+
+struct TypeWordMap {
+ void: spirv::Word,
+ complex: HashMap<SpirvType, SpirvWord>,
+ constants: HashMap<(SpirvType, u64), SpirvWord>,
+}
+
+impl TypeWordMap {
+ fn new(b: &mut dr::Builder) -> TypeWordMap {
+ let void = b.type_void(None);
+ TypeWordMap {
+ void: void,
+ complex: HashMap::<SpirvType, SpirvWord>::new(),
+ constants: HashMap::new(),
+ }
+ }
+
+ fn void(&self) -> spirv::Word {
+ self.void
+ }
+
+ fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> SpirvWord {
+ let key: SpirvScalarKey = t.into();
+ self.get_or_add_spirv_scalar(b, key)
+ }
+
+ fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> SpirvWord {
+ *self.complex.entry(SpirvType::Base(key)).or_insert_with(|| {
+ SpirvWord(match key {
+ SpirvScalarKey::B8 => b.type_int(None, 8, 0),
+ SpirvScalarKey::B16 => b.type_int(None, 16, 0),
+ SpirvScalarKey::B32 => b.type_int(None, 32, 0),
+ SpirvScalarKey::B64 => b.type_int(None, 64, 0),
+ SpirvScalarKey::F16 => b.type_float(None, 16),
+ SpirvScalarKey::F32 => b.type_float(None, 32),
+ SpirvScalarKey::F64 => b.type_float(None, 64),
+ SpirvScalarKey::Pred => b.type_bool(None),
+ SpirvScalarKey::F16x2 => todo!(),
+ })
+ })
+ }
+
+ fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> SpirvWord {
+ match t {
+ SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key),
+ SpirvType::Pointer(ref typ, storage) => {
+ let base = self.get_or_add(b, *typ.clone());
+ *self
+ .complex
+ .entry(t)
+ .or_insert_with(|| SpirvWord(b.type_pointer(None, storage, base.0)))
+ }
+ SpirvType::Vector(typ, len) => {
+ let base = self.get_or_add_spirv_scalar(b, typ);
+ *self
+ .complex
+ .entry(t)
+ .or_insert_with(|| SpirvWord(b.type_vector(None, base.0, len as u32)))
+ }
+ SpirvType::Array(typ, array_dimensions) => {
+ let (base_type, length) = match &*array_dimensions {
+ &[] => {
+ return self.get_or_add(b, SpirvType::Base(typ));
+ }
+ &[len] => {
+ let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
+ let base = self.get_or_add_spirv_scalar(b, typ);
+ let len_const = b.constant_u32(u32_type.0, None, len);
+ (base, len_const)
+ }
+ array_dimensions => {
+ let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
+ let base = self
+ .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec()));
+ let len_const = b.constant_u32(u32_type.0, None, array_dimensions[0]);
+ (base, len_const)
+ }
+ };
+ *self
+ .complex
+ .entry(SpirvType::Array(typ, array_dimensions))
+ .or_insert_with(|| SpirvWord(b.type_array(None, base_type.0, length)))
+ }
+ SpirvType::Func(ref out_params, ref in_params) => {
+ let out_t = match out_params {
+ Some(p) => self.get_or_add(b, *p.clone()),
+ None => SpirvWord(self.void()),
+ };
+ let in_t = in_params
+ .iter()
+ .map(|t| self.get_or_add(b, t.clone()).0)
+ .collect::<Vec<_>>();
+ *self
+ .complex
+ .entry(t)
+ .or_insert_with(|| SpirvWord(b.type_function(None, out_t.0, in_t)))
+ }
+ SpirvType::Struct(ref underlying) => {
+ let underlying_ids = underlying
+ .iter()
+ .map(|t| self.get_or_add_spirv_scalar(b, *t).0)
+ .collect::<Vec<_>>();
+ *self
+ .complex
+ .entry(t)
+ .or_insert_with(|| SpirvWord(b.type_struct(None, underlying_ids)))
+ }
+ }
+ }
+
+ fn get_or_add_fn(
+ &mut self,
+ b: &mut dr::Builder,
+ in_params: impl Iterator<Item = SpirvType>,
+ mut out_params: impl ExactSizeIterator<Item = SpirvType>,
+ ) -> (SpirvWord, SpirvWord) {
+ let (out_args, out_spirv_type) = if out_params.len() == 0 {
+ (None, SpirvWord(self.void()))
+ } else if out_params.len() == 1 {
+ let arg_as_key = out_params.next().unwrap();
+ (
+ Some(Box::new(arg_as_key.clone())),
+ self.get_or_add(b, arg_as_key),
+ )
+ } else {
+ // TODO: support multiple return values
+ todo!()
+ };
+ (
+ out_spirv_type,
+ self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::<Vec<_>>())),
+ )
+ }
+
+ fn get_or_add_constant(
+ &mut self,
+ b: &mut dr::Builder,
+ typ: &ast::Type,
+ init: &[u8],
+ ) -> Result<SpirvWord, TranslateError> {
+ Ok(match typ {
+ ast::Type::Scalar(t) => match t {
+ ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self
+ .get_or_add_constant_single::<u8, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| b.constant_u32(result_type, None, v as u32),
+ ),
+ ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self
+ .get_or_add_constant_single::<u16, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| b.constant_u32(result_type, None, v as u32),
+ ),
+ ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self
+ .get_or_add_constant_single::<u32, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| b.constant_u32(result_type, None, v),
+ ),
+ ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self
+ .get_or_add_constant_single::<u64, _, _>(
+ b,
+ *t,
+ init,
+ |v| v,
+ |b, result_type, v| b.constant_u64(result_type, None, v),
+ ),
+ ast::ScalarType::F16 => self.get_or_add_constant_single::<f16, _, _>(
+ b,
+ *t,
+ init,
+ |v| unsafe { mem::transmute::<_, u16>(v) } as u64,
+ |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()),
+ ),
+ ast::ScalarType::F32 => self.get_or_add_constant_single::<f32, _, _>(
+ b,
+ *t,
+ init,
+ |v| unsafe { mem::transmute::<_, u32>(v) } as u64,
+ |b, result_type, v| b.constant_f32(result_type, None, v),
+ ),
+ ast::ScalarType::F64 => self.get_or_add_constant_single::<f64, _, _>(
+ b,
+ *t,
+ init,
+ |v| unsafe { mem::transmute::<_, u64>(v) },
+ |b, result_type, v| b.constant_f64(result_type, None, v),
+ ),
+ ast::ScalarType::F16x2 => return Err(TranslateError::Todo),
+ ast::ScalarType::Pred => self.get_or_add_constant_single::<u8, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| {
+ if v == 0 {
+ b.constant_false(result_type, None)
+ } else {
+ b.constant_true(result_type, None)
+ }
+ },
+ ),
+ ast::ScalarType::S16x2
+ | ast::ScalarType::U16x2
+ | ast::ScalarType::BF16
+ | ast::ScalarType::BF16x2
+ | ast::ScalarType::B128 => todo!(),
+ },
+ ast::Type::Vector(len, typ) => {
+ let result_type =
+ self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len));
+ let size_of_t = typ.size_of();
+ let components = (0..*len)
+ .map(|x| {
+ Ok::<_, TranslateError>(
+ self.get_or_add_constant(
+ b,
+ &ast::Type::Scalar(*typ),
+ &init[((size_of_t as usize) * (x as usize))..],
+ )?
+ .0,
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ SpirvWord(b.constant_composite(result_type.0, None, components.into_iter()))
+ }
+ ast::Type::Array(_, typ, dims) => match dims.as_slice() {
+ [] => return Err(error_unreachable()),
+ [dim] => {
+ let result_type = self
+ .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim]));
+ let size_of_t = typ.size_of();
+ let components = (0..*dim)
+ .map(|x| {
+ Ok::<_, TranslateError>(
+ self.get_or_add_constant(
+ b,
+ &ast::Type::Scalar(*typ),
+ &init[((size_of_t as usize) * (x as usize))..],
+ )?
+ .0,
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ SpirvWord(b.constant_composite(result_type.0, None, components.into_iter()))
+ }
+ [first_dim, rest @ ..] => {
+ let result_type = self.get_or_add(
+ b,
+ SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()),
+ );
+ let size_of_t = rest
+ .iter()
+ .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y));
+ let components = (0..*first_dim)
+ .map(|x| {
+ Ok::<_, TranslateError>(
+ self.get_or_add_constant(
+ b,
+ &ast::Type::Array(None, *typ, rest.to_vec()),
+ &init[((size_of_t as usize) * (x as usize))..],
+ )?
+ .0,
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ SpirvWord(b.constant_composite(result_type.0, None, components.into_iter()))
+ }
+ },
+ ast::Type::Pointer(..) => return Err(error_unreachable()),
+ })
+ }
+
+ fn get_or_add_constant_single<
+ T: Copy,
+ CastAsU64: FnOnce(T) -> u64,
+ InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word,
+ >(
+ &mut self,
+ b: &mut dr::Builder,
+ key: ast::ScalarType,
+ init: &[u8],
+ cast: CastAsU64,
+ f: InsertConstant,
+ ) -> SpirvWord {
+ let value = unsafe { *(init.as_ptr() as *const T) };
+ let value_64 = cast(value);
+ let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64);
+ match self.constants.get(&ht_key) {
+ Some(value) => *value,
+ None => {
+ let spirv_type = self.get_or_add_scalar(b, key);
+ let result = SpirvWord(f(b, spirv_type.0, value));
+ self.constants.insert(ht_key, result);
+ result
+ }
+ }
+ }
+}
+
+#[derive(PartialEq, Eq, Hash, Clone)]
+enum SpirvType {
+ Base(SpirvScalarKey),
+ Vector(SpirvScalarKey, u8),
+ Array(SpirvScalarKey, Vec<u32>),
+ Pointer(Box<SpirvType>, spirv::StorageClass),
+ Func(Option<Box<SpirvType>>, Vec<SpirvType>),
+ Struct(Vec<SpirvScalarKey>),
+}
+
+impl SpirvType {
+ fn new(t: ast::Type) -> Self {
+ match t {
+ ast::Type::Scalar(t) => SpirvType::Base(t.into()),
+ ast::Type::Vector(len, typ) => SpirvType::Vector(typ.into(), len),
+ ast::Type::Array(_, t, len) => SpirvType::Array(t.into(), len),
+ ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer(
+ Box::new(SpirvType::Base(pointer_t.into())),
+ space_to_spirv(space),
+ ),
+ }
+ }
+
+ fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self {
+ let key = Self::new(t);
+ SpirvType::Pointer(Box::new(key), outer_space)
+ }
+}
+
+impl From<ast::ScalarType> for SpirvType {
+ fn from(t: ast::ScalarType) -> Self {
+ SpirvType::Base(t.into())
+ }
+}
+// SPIR-V integer type definitions are signless, more below:
+// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers
+// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a
+#[derive(PartialEq, Eq, Hash, Clone, Copy)]
+enum SpirvScalarKey {
+ B8,
+ B16,
+ B32,
+ B64,
+ F16,
+ F32,
+ F64,
+ Pred,
+ F16x2,
+}
+
+impl From<ast::ScalarType> for SpirvScalarKey {
+ fn from(t: ast::ScalarType) -> Self {
+ match t {
+ ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8,
+ ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
+ SpirvScalarKey::B16
+ }
+ ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => {
+ SpirvScalarKey::B32
+ }
+ ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => {
+ SpirvScalarKey::B64
+ }
+ ast::ScalarType::F16 => SpirvScalarKey::F16,
+ ast::ScalarType::F32 => SpirvScalarKey::F32,
+ ast::ScalarType::F64 => SpirvScalarKey::F64,
+ ast::ScalarType::F16x2 => SpirvScalarKey::F16x2,
+ ast::ScalarType::Pred => SpirvScalarKey::Pred,
+ ast::ScalarType::S16x2
+ | ast::ScalarType::U16x2
+ | ast::ScalarType::BF16
+ | ast::ScalarType::BF16x2
+ | ast::ScalarType::B128 => todo!(),
+ }
+ }
+}
+
+fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass {
+ match this {
+ ast::StateSpace::Const => spirv::StorageClass::UniformConstant,
+ ast::StateSpace::Generic => spirv::StorageClass::Generic,
+ ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup,
+ ast::StateSpace::Local => spirv::StorageClass::Function,
+ ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
+ ast::StateSpace::Param => spirv::StorageClass::Function,
+ ast::StateSpace::Reg => spirv::StorageClass::Function,
+ ast::StateSpace::Sreg => spirv::StorageClass::Input,
+ ast::StateSpace::ParamEntry
+ | ast::StateSpace::ParamFunc
+ | ast::StateSpace::SharedCluster
+ | ast::StateSpace::SharedCta => todo!(),
+ }
+}
+
+// TODO: remove this once we have pef-function support for denorms
+fn emit_denorm_build_string<'input>(
+ call_map: &MethodsCallMap,
+ denorm_information: &HashMap<
+ ast::MethodName<'input, SpirvWord>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
+) -> (CString, bool) {
+ let denorm_counts = denorm_information
+ .iter()
+ .map(|(method, meth_denorm)| {
+ let f16_count = meth_denorm
+ .get(&(mem::size_of::<f16>() as u8))
+ .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
+ .1;
+ let f32_count = meth_denorm
+ .get(&(mem::size_of::<f32>() as u8))
+ .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
+ .1;
+ (method, (f16_count + f32_count))
+ })
+ .collect::<HashMap<_, _>>();
+ let mut flush_over_preserve = 0;
+ for (kernel, children) in call_map.kernels() {
+ flush_over_preserve += *denorm_counts
+ .get(&ast::MethodName::Kernel(kernel))
+ .unwrap_or(&0);
+ for child_fn in children {
+ flush_over_preserve += *denorm_counts
+ .get(&ast::MethodName::Func(*child_fn))
+ .unwrap_or(&0);
+ }
+ }
+ if flush_over_preserve > 0 {
+ (
+ CString::new("-ze-take-global-address -ze-denorms-are-zero").unwrap(),
+ true,
+ )
+ } else {
+ (CString::new("-ze-take-global-address").unwrap(), false)
+ }
+}
+
+fn get_globals_use_map<'input>(
+ directives: Vec<Directive<'input>>,
+) -> (
+ Vec<Directive<'input>>,
+ HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
+) {
+ let mut known_globals = HashSet::new();
+ for directive in directives.iter() {
+ match directive {
+ Directive::Variable(_, ast::Variable { name, .. }) => {
+ known_globals.insert(*name);
+ }
+ Directive::Method(..) => {}
+ }
+ }
+ let mut symbol_uses_map = HashMap::new();
+ let directives = directives
+ .into_iter()
+ .map(|directive| match directive {
+ Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => directive,
+ Directive::Method(Function {
+ func_decl,
+ body: Some(mut statements),
+ globals,
+ import_as,
+ tuning,
+ linkage,
+ }) => {
+ let method_name = func_decl.borrow().name;
+ statements = statements
+ .into_iter()
+ .map(|statement| {
+ statement.visit_map(
+ &mut |symbol, _: Option<(&ast::Type, ast::StateSpace)>, _, _| {
+ if known_globals.contains(&symbol) {
+ multi_hash_map_append(
+ &mut symbol_uses_map,
+ method_name,
+ symbol,
+ );
+ }
+ Ok::<_, TranslateError>(symbol)
+ },
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()
+ .unwrap();
+ Directive::Method(Function {
+ func_decl,
+ body: Some(statements),
+ globals,
+ import_as,
+ tuning,
+ linkage,
+ })
+ }
+ })
+ .collect::<Vec<_>>();
+ (directives, symbol_uses_map)
+}
+
+fn emit_directives<'input>(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ id_defs: &GlobalStringIdResolver<'input>,
+ opencl_id: spirv::Word,
+ should_flush_denorms: bool,
+ call_map: &MethodsCallMap<'input>,
+ globals_use_map: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
+ directives: Vec<Directive<'input>>,
+ kernel_info: &mut HashMap<String, KernelInfo>,
+) -> Result<(), TranslateError> {
+ let empty_body = Vec::new();
+ for d in directives.iter() {
+ match d {
+ Directive::Variable(linking, var) => {
+ emit_variable(builder, map, id_defs, *linking, &var)?;
+ }
+ Directive::Method(f) => {
+ let f_body = match &f.body {
+ Some(f) => f,
+ None => {
+ if f.linkage.contains(ast::LinkingDirective::EXTERN) {
+ &empty_body
+ } else {
+ continue;
+ }
+ }
+ };
+ for var in f.globals.iter() {
+ emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?;
+ }
+ let func_decl = (*f.func_decl).borrow();
+ let fn_id = emit_function_header(
+ builder,
+ map,
+ &id_defs,
+ &*func_decl,
+ call_map,
+ &globals_use_map,
+ kernel_info,
+ )?;
+ if matches!(func_decl.name, ast::MethodName::Kernel(_)) {
+ if should_flush_denorms {
+ builder.execution_mode(
+ fn_id.0,
+ spirv_headers::ExecutionMode::DenormFlushToZero,
+ [16],
+ );
+ builder.execution_mode(
+ fn_id.0,
+ spirv_headers::ExecutionMode::DenormFlushToZero,
+ [32],
+ );
+ builder.execution_mode(
+ fn_id.0,
+ spirv_headers::ExecutionMode::DenormFlushToZero,
+ [64],
+ );
+ }
+ // FP contraction happens when compiling source -> PTX and is illegal at this stage (unless you force it in cuModuleLoadDataEx)
+ builder.execution_mode(
+ fn_id.0,
+ spirv_headers::ExecutionMode::ContractionOff,
+ [],
+ );
+ for t in f.tuning.iter() {
+ match *t {
+ ast::TuningDirective::MaxNtid(nx, ny, nz) => {
+ builder.execution_mode(
+ fn_id.0,
+ spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL,
+ [nx, ny, nz],
+ );
+ }
+ ast::TuningDirective::ReqNtid(nx, ny, nz) => {
+ builder.execution_mode(
+ fn_id.0,
+ spirv_headers::ExecutionMode::LocalSize,
+ [nx, ny, nz],
+ );
+ }
+ // Too architecture specific
+ ast::TuningDirective::MaxNReg(..)
+ | ast::TuningDirective::MinNCtaPerSm(..) => {}
+ }
+ }
+ }
+ emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?;
+ emit_function_linkage(builder, id_defs, f, fn_id)?;
+ builder.select_block(None)?;
+ builder.end_function()?;
+ }
+ }
+ }
+ Ok(())
+}
+
+fn emit_variable<'input>(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ id_defs: &GlobalStringIdResolver<'input>,
+ linking: ast::LinkingDirective,
+ var: &ast::Variable<SpirvWord>,
+) -> Result<(), TranslateError> {
+ let (must_init, st_class) = match var.state_space {
+ ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => {
+ (false, spirv::StorageClass::Function)
+ }
+ ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup),
+ ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
+ ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant),
+ ast::StateSpace::Generic => todo!(),
+ ast::StateSpace::Sreg => todo!(),
+ ast::StateSpace::ParamEntry
+ | ast::StateSpace::ParamFunc
+ | ast::StateSpace::SharedCluster
+ | ast::StateSpace::SharedCta => todo!(),
+ };
+ let initalizer = if var.array_init.len() > 0 {
+ Some(
+ map.get_or_add_constant(
+ builder,
+ &ast::Type::from(var.v_type.clone()),
+ &*var.array_init,
+ )?
+ .0,
+ )
+ } else if must_init {
+ let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone()));
+ Some(builder.constant_null(type_id.0, None))
+ } else {
+ None
+ };
+ let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class));
+ builder.variable(ptr_type_id.0, Some(var.name.0), st_class, initalizer);
+ if let Some(align) = var.align {
+ builder.decorate(
+ var.name.0,
+ spirv::Decoration::Alignment,
+ [dr::Operand::LiteralInt32(align)].iter().cloned(),
+ );
+ }
+ if var.state_space != ast::StateSpace::Shared
+ || !linking.contains(ast::LinkingDirective::EXTERN)
+ {
+ emit_linking_decoration(builder, id_defs, None, var.name, linking);
+ }
+ Ok(())
+}
+
+fn emit_function_header<'input>(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ defined_globals: &GlobalStringIdResolver<'input>,
+ func_decl: &ast::MethodDeclaration<'input, SpirvWord>,
+ call_map: &MethodsCallMap<'input>,
+ globals_use_map: &HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
+ kernel_info: &mut HashMap<String, KernelInfo>,
+) -> Result<SpirvWord, TranslateError> {
+ if let ast::MethodName::Kernel(name) = func_decl.name {
+ let args_lens = func_decl
+ .input_arguments
+ .iter()
+ .map(|param| {
+ (
+ type_size_of(¶m.v_type),
+ matches!(param.v_type, ast::Type::Pointer(..)),
+ )
+ })
+ .collect();
+ kernel_info.insert(
+ name.to_string(),
+ KernelInfo {
+ arguments_sizes: args_lens,
+ uses_shared_mem: func_decl.shared_mem.is_some(),
+ },
+ );
+ }
+ let (ret_type, func_type) = get_function_type(
+ builder,
+ map,
+ effective_input_arguments(func_decl).map(|(_, typ)| typ),
+ &func_decl.return_arguments,
+ );
+ let fn_id = match func_decl.name {
+ ast::MethodName::Kernel(name) => {
+ let fn_id = defined_globals.get_id(name)?;
+ let interface = globals_use_map
+ .get(&ast::MethodName::Kernel(name))
+ .into_iter()
+ .flatten()
+ .copied()
+ .chain({
+ call_map
+ .get_kernel_children(name)
+ .copied()
+ .flat_map(|subfunction| {
+ globals_use_map
+ .get(&ast::MethodName::Func(subfunction))
+ .into_iter()
+ .flatten()
+ .copied()
+ })
+ .into_iter()
+ })
+ .map(|word| word.0)
+ .collect::<Vec<spirv::Word>>();
+ builder.entry_point(spirv::ExecutionModel::Kernel, fn_id.0, name, interface);
+ fn_id
+ }
+ ast::MethodName::Func(name) => name,
+ };
+ builder.begin_function(
+ ret_type.0,
+ Some(fn_id.0),
+ spirv::FunctionControl::NONE,
+ func_type.0,
+ )?;
+ for (name, typ) in effective_input_arguments(func_decl) {
+ let result_type = map.get_or_add(builder, typ);
+ builder.function_parameter(Some(name.0), result_type.0)?;
+ }
+ Ok(fn_id)
+}
+
+pub fn type_size_of(this: &ast::Type) -> usize {
+ match this {
+ ast::Type::Scalar(typ) => typ.size_of() as usize,
+ ast::Type::Vector(len, typ) => (typ.size_of() as usize) * (*len as usize),
+ ast::Type::Array(_, typ, len) => len
+ .iter()
+ .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)),
+ ast::Type::Pointer(..) => mem::size_of::<usize>(),
+ }
+}
+fn emit_function_body_ops<'input>(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ id_defs: &GlobalStringIdResolver<'input>,
+ opencl: spirv::Word,
+ func: &[ExpandedStatement],
+) -> Result<(), TranslateError> {
+ for s in func {
+ match s {
+ Statement::Label(id) => {
+ if builder.selected_block().is_some() {
+ builder.branch(id.0)?;
+ }
+ builder.begin_block(Some(id.0))?;
+ }
+ _ => {
+ if builder.selected_block().is_none() && builder.selected_function().is_some() {
+ builder.begin_block(None)?;
+ }
+ }
+ }
+ match s {
+ Statement::Label(_) => (),
+ Statement::Variable(var) => {
+ emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?;
+ }
+ Statement::Constant(cnst) => {
+ let typ_id = map.get_or_add_scalar(builder, cnst.typ);
+ match (cnst.typ, cnst.value) {
+ (ast::ScalarType::B8, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32);
+ }
+ (ast::ScalarType::B16, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32);
+ }
+ (ast::ScalarType::B32, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32);
+ }
+ (ast::ScalarType::B64, ast::ImmediateValue::U64(value))
+ | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u64(typ_id.0, Some(cnst.dst.0), value);
+ }
+ (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32);
+ }
+ (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32);
+ }
+ (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32);
+ }
+ (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => {
+ builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as i64 as u64);
+ }
+ (ast::ScalarType::B8, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32);
+ }
+ (ast::ScalarType::B16, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32);
+ }
+ (ast::ScalarType::B32, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32);
+ }
+ (ast::ScalarType::B64, ast::ImmediateValue::S64(value))
+ | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64);
+ }
+ (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32);
+ }
+ (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32);
+ }
+ (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32);
+ }
+ (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => {
+ builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64);
+ }
+ (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => {
+ builder.constant_f32(
+ typ_id.0,
+ Some(cnst.dst.0),
+ f16::from_f32(value).to_f32(),
+ );
+ }
+ (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => {
+ builder.constant_f32(typ_id.0, Some(cnst.dst.0), value);
+ }
+ (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => {
+ builder.constant_f64(typ_id.0, Some(cnst.dst.0), value as f64);
+ }
+ (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => {
+ builder.constant_f32(
+ typ_id.0,
+ Some(cnst.dst.0),
+ f16::from_f64(value).to_f32(),
+ );
+ }
+ (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => {
+ builder.constant_f32(typ_id.0, Some(cnst.dst.0), value as f32);
+ }
+ (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => {
+ builder.constant_f64(typ_id.0, Some(cnst.dst.0), value);
+ }
+ (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => {
+ let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0;
+ if value == 0 {
+ builder.constant_false(bool_type, Some(cnst.dst.0));
+ } else {
+ builder.constant_true(bool_type, Some(cnst.dst.0));
+ }
+ }
+ (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => {
+ let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0;
+ if value == 0 {
+ builder.constant_false(bool_type, Some(cnst.dst.0));
+ } else {
+ builder.constant_true(bool_type, Some(cnst.dst.0));
+ }
+ }
+ _ => return Err(error_mismatched_type()),
+ }
+ }
+ Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?,
+ Statement::Conditional(bra) => {
+ builder.branch_conditional(
+ bra.predicate.0,
+ bra.if_true.0,
+ bra.if_false.0,
+ iter::empty(),
+ )?;
+ }
+ Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => {
+ // TODO: implement properly
+ let zero = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U64),
+ &vec_repr(0u64),
+ )?;
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::U64);
+ builder.copy_object(result_type.0, Some(dst.0), zero.0)?;
+ }
+ Statement::Instruction(inst) => match inst {
+ ast::Instruction::PrmtSlow { .. } | ast::Instruction::Trap { .. } => todo!(),
+ ast::Instruction::Call { data, arguments } => {
+ let (result_type, result_id) =
+ match (&*data.return_arguments, &*arguments.return_arguments) {
+ ([(type_, space)], [id]) => {
+ if *space != ast::StateSpace::Reg {
+ return Err(error_unreachable());
+ }
+ (
+ map.get_or_add(builder, SpirvType::new(type_.clone())).0,
+ Some(id.0),
+ )
+ }
+ ([], []) => (map.void(), None),
+ _ => todo!(),
+ };
+ let arg_list = arguments
+ .input_arguments
+ .iter()
+ .map(|id| id.0)
+ .collect::<Vec<_>>();
+ builder.function_call(result_type, result_id, arguments.func.0, arg_list)?;
+ }
+ ast::Instruction::Abs { data, arguments } => {
+ emit_abs(builder, map, opencl, data, arguments)?
+ }
+ // SPIR-V does not support marking jumps as guaranteed-converged
+ ast::Instruction::Bra { arguments, .. } => {
+ builder.branch(arguments.src.0)?;
+ }
+ ast::Instruction::Ld { data, arguments } => {
+ let mem_access = match data.qualifier {
+ ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE,
+ // ld.volatile does not match Volatile OpLoad nor Relaxed OpAtomicLoad
+ ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE,
+ _ => return Err(TranslateError::Todo),
+ };
+ let result_type =
+ map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone())));
+ builder.load(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src.0,
+ Some(mem_access | spirv::MemoryAccess::ALIGNED),
+ [dr::Operand::LiteralInt32(
+ type_size_of(&ast::Type::from(data.typ.clone())) as u32,
+ )]
+ .iter()
+ .cloned(),
+ )?;
+ }
+ ast::Instruction::St { data, arguments } => {
+ let mem_access = match data.qualifier {
+ ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE,
+ // st.volatile does not match Volatile OpStore nor Relaxed OpAtomicStore
+ ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE,
+ _ => return Err(TranslateError::Todo),
+ };
+ builder.store(
+ arguments.src1.0,
+ arguments.src2.0,
+ Some(mem_access | spirv::MemoryAccess::ALIGNED),
+ [dr::Operand::LiteralInt32(
+ type_size_of(&ast::Type::from(data.typ.clone())) as u32,
+ )]
+ .iter()
+ .cloned(),
+ )?;
+ }
+ // SPIR-V does not support ret as guaranteed-converged
+ ast::Instruction::Ret { .. } => builder.ret()?,
+ ast::Instruction::Mov { data, arguments } => {
+ let result_type =
+ map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone())));
+ builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?;
+ }
+ ast::Instruction::Mul { data, arguments } => match data {
+ ast::MulDetails::Integer { type_, control } => {
+ emit_mul_int(builder, map, opencl, *type_, *control, arguments)?
+ }
+ ast::MulDetails::Float(ref ctr) => {
+ emit_mul_float(builder, map, ctr, arguments)?
+ }
+ },
+ ast::Instruction::Add { data, arguments } => match data {
+ ast::ArithDetails::Integer(desc) => {
+ emit_add_int(builder, map, desc.type_.into(), desc.saturate, arguments)?
+ }
+ ast::ArithDetails::Float(desc) => {
+ emit_add_float(builder, map, desc, arguments)?
+ }
+ },
+ ast::Instruction::Setp { data, arguments } => {
+ if arguments.dst2.is_some() {
+ todo!()
+ }
+ emit_setp(builder, map, data, arguments)?;
+ }
+ ast::Instruction::Not { data, arguments } => {
+ let result_type = map.get_or_add(builder, SpirvType::from(*data));
+ let result_id = Some(arguments.dst.0);
+ let operand = arguments.src;
+ match data {
+ ast::ScalarType::Pred => {
+ logical_not(builder, result_type.0, result_id, operand.0)
+ }
+ _ => builder.not(result_type.0, result_id, operand.0),
+ }?;
+ }
+ ast::Instruction::Shl { data, arguments } => {
+ let full_type = ast::Type::Scalar(*data);
+ let size_of = type_size_of(&full_type);
+ let result_type = map.get_or_add(builder, SpirvType::new(full_type));
+ let offset_src = insert_shift_hack(builder, map, arguments.src2.0, size_of)?;
+ builder.shift_left_logical(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ offset_src,
+ )?;
+ }
+ ast::Instruction::Shr { data, arguments } => {
+ let full_type = ast::ScalarType::from(data.type_);
+ let size_of = full_type.size_of();
+ let result_type = map.get_or_add_scalar(builder, full_type).0;
+ let offset_src =
+ insert_shift_hack(builder, map, arguments.src2.0, size_of as usize)?;
+ match data.kind {
+ ptx_parser::RightShiftKind::Arithmetic => {
+ builder.shift_right_arithmetic(
+ result_type,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ offset_src,
+ )?;
+ }
+ ptx_parser::RightShiftKind::Logical => {
+ builder.shift_right_logical(
+ result_type,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ offset_src,
+ )?;
+ }
+ }
+ }
+ ast::Instruction::Cvt { data, arguments } => {
+ emit_cvt(builder, map, opencl, data, arguments)?;
+ }
+ ast::Instruction::Cvta { data, arguments } => {
+ // This would be only meaningful if const/slm/global pointers
+ // had a different format than generic pointers, but they don't pretty much by ptx definition
+ // Honestly, I have no idea why this instruction exists and is emitted by the compiler
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64);
+ builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?;
+ }
+ ast::Instruction::SetpBool { .. } => todo!(),
+ ast::Instruction::Mad { data, arguments } => match data {
+ ast::MadDetails::Integer {
+ type_,
+ control,
+ saturate,
+ } => {
+ if *saturate {
+ todo!()
+ }
+ if type_.kind() == ast::ScalarKind::Signed {
+ emit_mad_sint(builder, map, opencl, *type_, *control, arguments)?
+ } else {
+ emit_mad_uint(builder, map, opencl, *type_, *control, arguments)?
+ }
+ }
+ ast::MadDetails::Float(desc) => {
+ emit_mad_float(builder, map, opencl, desc, arguments)?
+ }
+ },
+ ast::Instruction::Fma { data, arguments } => {
+ emit_fma_float(builder, map, opencl, data, arguments)?
+ }
+ ast::Instruction::Or { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, *data).0;
+ if *data == ast::ScalarType::Pred {
+ builder.logical_or(
+ result_type,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ } else {
+ builder.bitwise_or(
+ result_type,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ }
+ }
+ ast::Instruction::Sub { data, arguments } => match data {
+ ast::ArithDetails::Integer(desc) => {
+ emit_sub_int(builder, map, desc.type_.into(), desc.saturate, arguments)?;
+ }
+ ast::ArithDetails::Float(desc) => {
+ emit_sub_float(builder, map, desc, arguments)?;
+ }
+ },
+ ast::Instruction::Min { data, arguments } => {
+ emit_min(builder, map, opencl, data, arguments)?;
+ }
+ ast::Instruction::Max { data, arguments } => {
+ emit_max(builder, map, opencl, data, arguments)?;
+ }
+ ast::Instruction::Rcp { data, arguments } => {
+ emit_rcp(builder, map, opencl, data, arguments)?;
+ }
+ ast::Instruction::And { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, *data);
+ if *data == ast::ScalarType::Pred {
+ builder.logical_and(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ } else {
+ builder.bitwise_and(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ }
+ }
+ ast::Instruction::Selp { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, *data);
+ builder.select(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src3.0,
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ }
+ // TODO: implement named barriers
+ ast::Instruction::Bar { data, arguments } => {
+ let workgroup_scope = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(spirv::Scope::Workgroup as u32),
+ )?;
+ let barrier_semantics = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+ )?;
+ builder.control_barrier(
+ workgroup_scope.0,
+ workgroup_scope.0,
+ barrier_semantics.0,
+ )?;
+ }
+ ast::Instruction::Atom { data, arguments } => {
+ emit_atom(builder, map, data, arguments)?;
+ }
+ ast::Instruction::AtomCas { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, data.type_);
+ let memory_const = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(scope_to_spirv(data.scope) as u32),
+ )?;
+ let semantics_const = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(semantics_to_spirv(data.semantics).bits()),
+ )?;
+ builder.atomic_compare_exchange(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ memory_const.0,
+ semantics_const.0,
+ semantics_const.0,
+ arguments.src3.0,
+ arguments.src2.0,
+ )?;
+ }
+ ast::Instruction::Div { data, arguments } => match data {
+ ast::DivDetails::Unsigned(t) => {
+ let result_type = map.get_or_add_scalar(builder, (*t).into());
+ builder.u_div(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ }
+ ast::DivDetails::Signed(t) => {
+ let result_type = map.get_or_add_scalar(builder, (*t).into());
+ builder.s_div(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ }
+ ast::DivDetails::Float(t) => {
+ let result_type = map.get_or_add_scalar(builder, t.type_.into());
+ builder.f_div(
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ emit_float_div_decoration(builder, arguments.dst, t.kind);
+ }
+ },
+ ast::Instruction::Sqrt { data, arguments } => {
+ emit_sqrt(builder, map, opencl, data, arguments)?;
+ }
+ ast::Instruction::Rsqrt { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, data.type_.into());
+ builder.ext_inst(
+ result_type.0,
+ Some(arguments.dst.0),
+ opencl,
+ spirv::CLOp::rsqrt as spirv::Word,
+ [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
+ )?;
+ }
+ ast::Instruction::Neg { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, data.type_);
+ let negate_func = if data.type_.kind() == ast::ScalarKind::Float {
+ dr::Builder::f_negate
+ } else {
+ dr::Builder::s_negate
+ };
+ negate_func(
+ builder,
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src.0,
+ )?;
+ }
+ ast::Instruction::Sin { arguments, .. } => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
+ builder.ext_inst(
+ result_type.0,
+ Some(arguments.dst.0),
+ opencl,
+ spirv::CLOp::sin as u32,
+ [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
+ )?;
+ }
+ ast::Instruction::Cos { arguments, .. } => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
+ builder.ext_inst(
+ result_type.0,
+ Some(arguments.dst.0),
+ opencl,
+ spirv::CLOp::cos as u32,
+ [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
+ )?;
+ }
+ ast::Instruction::Lg2 { arguments, .. } => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
+ builder.ext_inst(
+ result_type.0,
+ Some(arguments.dst.0),
+ opencl,
+ spirv::CLOp::log2 as u32,
+ [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
+ )?;
+ }
+ ast::Instruction::Ex2 { arguments, .. } => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
+ builder.ext_inst(
+ result_type.0,
+ Some(arguments.dst.0),
+ opencl,
+ spirv::CLOp::exp2 as u32,
+ [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
+ )?;
+ }
+ ast::Instruction::Clz { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, (*data).into());
+ builder.ext_inst(
+ result_type.0,
+ Some(arguments.dst.0),
+ opencl,
+ spirv::CLOp::clz as u32,
+ [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
+ )?;
+ }
+ ast::Instruction::Brev { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, (*data).into());
+ builder.bit_reverse(result_type.0, Some(arguments.dst.0), arguments.src.0)?;
+ }
+ ast::Instruction::Popc { data, arguments } => {
+ let result_type = map.get_or_add_scalar(builder, (*data).into());
+ builder.bit_count(result_type.0, Some(arguments.dst.0), arguments.src.0)?;
+ }
+ ast::Instruction::Xor { data, arguments } => {
+ let builder_fn: fn(
+ &mut dr::Builder,
+ u32,
+ Option<u32>,
+ u32,
+ u32,
+ ) -> Result<u32, dr::Error> = match data {
+ ast::ScalarType::Pred => emit_logical_xor_spirv,
+ _ => dr::Builder::bitwise_xor,
+ };
+ let result_type = map.get_or_add_scalar(builder, (*data).into());
+ builder_fn(
+ builder,
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ }
+ ast::Instruction::Bfe { .. }
+ | ast::Instruction::Bfi { .. }
+ | ast::Instruction::Activemask { .. } => {
+ // Should have beeen replaced with a funciton call earlier
+ return Err(error_unreachable());
+ }
+
+ ast::Instruction::Rem { data, arguments } => {
+ let builder_fn = if data.kind() == ast::ScalarKind::Signed {
+ dr::Builder::s_mod
+ } else {
+ dr::Builder::u_mod
+ };
+ let result_type = map.get_or_add_scalar(builder, (*data).into());
+ builder_fn(
+ builder,
+ result_type.0,
+ Some(arguments.dst.0),
+ arguments.src1.0,
+ arguments.src2.0,
+ )?;
+ }
+ ast::Instruction::Prmt { data, arguments } => {
+ let control = *data as u32;
+ let components = [
+ (control >> 0) & 0b1111,
+ (control >> 4) & 0b1111,
+ (control >> 8) & 0b1111,
+ (control >> 12) & 0b1111,
+ ];
+ if components.iter().any(|&c| c > 7) {
+ return Err(TranslateError::Todo);
+ }
+ let vec4_b8_type =
+ map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4));
+ let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32);
+ let src1_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src1.0)?;
+ let src2_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src2.0)?;
+ let dst_vector = builder.vector_shuffle(
+ vec4_b8_type.0,
+ None,
+ src1_vector,
+ src2_vector,
+ components,
+ )?;
+ builder.bitcast(b32_type.0, Some(arguments.dst.0), dst_vector)?;
+ }
+ ast::Instruction::Membar { data } => {
+ let (scope, semantics) = match data {
+ ast::MemScope::Cta => (
+ spirv::Scope::Workgroup,
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+ ast::MemScope::Gpu => (
+ spirv::Scope::Device,
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+ ast::MemScope::Sys => (
+ spirv::Scope::CrossDevice,
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+
+ ast::MemScope::Cluster => todo!(),
+ };
+ let spirv_scope = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(scope as u32),
+ )?;
+ let spirv_semantics = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(semantics),
+ )?;
+ builder.memory_barrier(spirv_scope.0, spirv_semantics.0)?;
+ }
+ },
+ Statement::LoadVar(details) => {
+ emit_load_var(builder, map, details)?;
+ }
+ Statement::StoreVar(details) => {
+ let dst_ptr = match details.member_index {
+ Some(index) => {
+ let result_ptr_type = map.get_or_add(
+ builder,
+ SpirvType::pointer_to(
+ details.typ.clone(),
+ spirv::StorageClass::Function,
+ ),
+ );
+ let index_spirv = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(index as u32),
+ )?;
+ builder.in_bounds_access_chain(
+ result_ptr_type.0,
+ None,
+ details.arg.src1.0,
+ [index_spirv.0].iter().copied(),
+ )?
+ }
+ None => details.arg.src1.0,
+ };
+ builder.store(dst_ptr, details.arg.src2.0, None, iter::empty())?;
+ }
+ Statement::RetValue(_, id) => {
+ builder.ret_value(id.0)?;
+ }
+ Statement::PtrAccess(PtrAccess {
+ underlying_type,
+ state_space,
+ dst,
+ ptr_src,
+ offset_src,
+ }) => {
+ let u8_pointer = map.get_or_add(
+ builder,
+ SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)),
+ );
+ let result_type = map.get_or_add(
+ builder,
+ SpirvType::pointer_to(underlying_type.clone(), space_to_spirv(*state_space)),
+ );
+ let ptr_src_u8 = builder.bitcast(u8_pointer.0, None, ptr_src.0)?;
+ let temp = builder.in_bounds_ptr_access_chain(
+ u8_pointer.0,
+ None,
+ ptr_src_u8,
+ offset_src.0,
+ iter::empty(),
+ )?;
+ builder.bitcast(result_type.0, Some(dst.0), temp)?;
+ }
+ Statement::RepackVector(repack) => {
+ if repack.is_extract {
+ let scalar_type = map.get_or_add_scalar(builder, repack.typ);
+ for (index, dst_id) in repack.unpacked.iter().enumerate() {
+ builder.composite_extract(
+ scalar_type.0,
+ Some(dst_id.0),
+ repack.packed.0,
+ [index as u32].iter().copied(),
+ )?;
+ }
+ } else {
+ let vector_type = map.get_or_add(
+ builder,
+ SpirvType::Vector(
+ SpirvScalarKey::from(repack.typ),
+ repack.unpacked.len() as u8,
+ ),
+ );
+ let mut temp_vec = builder.undef(vector_type.0, None);
+ for (index, src_id) in repack.unpacked.iter().enumerate() {
+ temp_vec = builder.composite_insert(
+ vector_type.0,
+ None,
+ src_id.0,
+ temp_vec,
+ [index as u32].iter().copied(),
+ )?;
+ }
+ builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
+ }
+ }
+ }
+ }
+ Ok(())
+}
+
+fn emit_function_linkage<'input>(
+ builder: &mut dr::Builder,
+ id_defs: &GlobalStringIdResolver<'input>,
+ f: &Function,
+ fn_name: SpirvWord,
+) -> Result<(), TranslateError> {
+ if f.linkage == ast::LinkingDirective::NONE {
+ return Ok(());
+ };
+ let linking_name = match f.func_decl.borrow().name {
+ // According to SPIR-V rules linkage attributes are invalid on kernels
+ ast::MethodName::Kernel(..) => return Ok(()),
+ ast::MethodName::Func(fn_id) => f.import_as.as_deref().map_or_else(
+ || match id_defs.reverse_variables.get(&fn_id) {
+ Some(fn_name) => Ok(fn_name),
+ None => Err(error_unknown_symbol()),
+ },
+ Result::Ok,
+ )?,
+ };
+ emit_linking_decoration(builder, id_defs, Some(linking_name), fn_name, f.linkage);
+ Ok(())
+}
+
+fn get_function_type(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ spirv_input: impl Iterator<Item = SpirvType>,
+ spirv_output: &[ast::Variable<SpirvWord>],
+) -> (SpirvWord, SpirvWord) {
+ map.get_or_add_fn(
+ builder,
+ spirv_input,
+ spirv_output
+ .iter()
+ .map(|var| SpirvType::new(var.v_type.clone())),
+ )
+}
+
+fn emit_linking_decoration<'input>(
+ builder: &mut dr::Builder,
+ id_defs: &GlobalStringIdResolver<'input>,
+ name_override: Option<&str>,
+ name: SpirvWord,
+ linking: ast::LinkingDirective,
+) {
+ if linking == ast::LinkingDirective::NONE {
+ return;
+ }
+ if linking.contains(ast::LinkingDirective::VISIBLE) {
+ let string_name =
+ name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
+ builder.decorate(
+ name.0,
+ spirv::Decoration::LinkageAttributes,
+ [
+ dr::Operand::LiteralString(string_name.to_string()),
+ dr::Operand::LinkageType(spirv::LinkageType::Export),
+ ]
+ .iter()
+ .cloned(),
+ );
+ } else if linking.contains(ast::LinkingDirective::EXTERN) {
+ let string_name =
+ name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
+ builder.decorate(
+ name.0,
+ spirv::Decoration::LinkageAttributes,
+ [
+ dr::Operand::LiteralString(string_name.to_string()),
+ dr::Operand::LinkageType(spirv::LinkageType::Import),
+ ]
+ .iter()
+ .cloned(),
+ );
+ }
+ // TODO: handle LinkingDirective::WEAK
+}
+
+fn effective_input_arguments<'a>(
+ this: &'a ast::MethodDeclaration<'a, SpirvWord>,
+) -> impl Iterator<Item = (SpirvWord, SpirvType)> + 'a {
+ let is_kernel = matches!(this.name, ast::MethodName::Kernel(_));
+ this.input_arguments.iter().map(move |arg| {
+ if !is_kernel && arg.state_space != ast::StateSpace::Reg {
+ let spirv_type =
+ SpirvType::pointer_to(arg.v_type.clone(), space_to_spirv(arg.state_space));
+ (arg.name, spirv_type)
+ } else {
+ (arg.name, SpirvType::new(arg.v_type.clone()))
+ }
+ })
+}
+
+fn emit_implicit_conversion(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ cv: &ImplicitConversion,
+) -> Result<(), TranslateError> {
+ let from_parts = to_parts(&cv.from_type);
+ let to_parts = to_parts(&cv.to_type);
+ match (from_parts.kind, to_parts.kind, &cv.kind) {
+ (_, _, &ConversionKind::BitToPtr) => {
+ let dst_type = map.get_or_add(
+ builder,
+ SpirvType::pointer_to(cv.to_type.clone(), space_to_spirv(cv.to_space)),
+ );
+ builder.convert_u_to_ptr(dst_type.0, Some(cv.dst.0), cv.src.0)?;
+ }
+ (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => {
+ if from_parts.width == to_parts.width {
+ let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ if from_parts.scalar_kind != ast::ScalarKind::Float
+ && to_parts.scalar_kind != ast::ScalarKind::Float
+ {
+ // It is noop, but another instruction expects result of this conversion
+ builder.copy_object(dst_type.0, Some(cv.dst.0), cv.src.0)?;
+ } else {
+ builder.bitcast(dst_type.0, Some(cv.dst.0), cv.src.0)?;
+ }
+ } else {
+ // This block is safe because it's illegal to implictly convert between floating point values
+ let same_width_bit_type = map.get_or_add(
+ builder,
+ SpirvType::new(type_from_parts(TypeParts {
+ scalar_kind: ast::ScalarKind::Bit,
+ ..from_parts
+ })),
+ );
+ let same_width_bit_value =
+ builder.bitcast(same_width_bit_type.0, None, cv.src.0)?;
+ let wide_bit_type = type_from_parts(TypeParts {
+ scalar_kind: ast::ScalarKind::Bit,
+ ..to_parts
+ });
+ let wide_bit_type_spirv =
+ map.get_or_add(builder, SpirvType::new(wide_bit_type.clone()));
+ if to_parts.scalar_kind == ast::ScalarKind::Unsigned
+ || to_parts.scalar_kind == ast::ScalarKind::Bit
+ {
+ builder.u_convert(
+ wide_bit_type_spirv.0,
+ Some(cv.dst.0),
+ same_width_bit_value,
+ )?;
+ } else {
+ let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed
+ && to_parts.scalar_kind == ast::ScalarKind::Signed
+ {
+ dr::Builder::s_convert
+ } else {
+ dr::Builder::u_convert
+ };
+ let wide_bit_value =
+ conversion_fn(builder, wide_bit_type_spirv.0, None, same_width_bit_value)?;
+ emit_implicit_conversion(
+ builder,
+ map,
+ &ImplicitConversion {
+ src: SpirvWord(wide_bit_value),
+ dst: cv.dst,
+ from_type: wide_bit_type,
+ from_space: cv.from_space,
+ to_type: cv.to_type.clone(),
+ to_space: cv.to_space,
+ kind: ConversionKind::Default,
+ },
+ )?;
+ }
+ }
+ }
+ (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => {
+ let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ builder.s_convert(result_type.0, Some(cv.dst.0), cv.src.0)?;
+ }
+ (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default)
+ | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default)
+ | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => {
+ let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ builder.bitcast(into_type.0, Some(cv.dst.0), cv.src.0)?;
+ }
+ (_, _, &ConversionKind::PtrToPtr) => {
+ let result_type = map.get_or_add(
+ builder,
+ SpirvType::Pointer(
+ Box::new(SpirvType::new(cv.to_type.clone())),
+ space_to_spirv(cv.to_space),
+ ),
+ );
+ if cv.to_space == ast::StateSpace::Generic && cv.from_space != ast::StateSpace::Generic
+ {
+ let src = if cv.from_type != cv.to_type {
+ let temp_type = map.get_or_add(
+ builder,
+ SpirvType::Pointer(
+ Box::new(SpirvType::new(cv.to_type.clone())),
+ space_to_spirv(cv.from_space),
+ ),
+ );
+ builder.bitcast(temp_type.0, None, cv.src.0)?
+ } else {
+ cv.src.0
+ };
+ builder.ptr_cast_to_generic(result_type.0, Some(cv.dst.0), src)?;
+ } else if cv.from_space == ast::StateSpace::Generic
+ && cv.to_space != ast::StateSpace::Generic
+ {
+ let src = if cv.from_type != cv.to_type {
+ let temp_type = map.get_or_add(
+ builder,
+ SpirvType::Pointer(
+ Box::new(SpirvType::new(cv.to_type.clone())),
+ space_to_spirv(cv.from_space),
+ ),
+ );
+ builder.bitcast(temp_type.0, None, cv.src.0)?
+ } else {
+ cv.src.0
+ };
+ builder.generic_cast_to_ptr(result_type.0, Some(cv.dst.0), src)?;
+ } else {
+ builder.bitcast(result_type.0, Some(cv.dst.0), cv.src.0)?;
+ }
+ }
+ (_, _, &ConversionKind::AddressOf) => {
+ let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ builder.convert_ptr_to_u(dst_type.0, Some(cv.dst.0), cv.src.0)?;
+ }
+ (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) => {
+ let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ builder.convert_ptr_to_u(result_type.0, Some(cv.dst.0), cv.src.0)?;
+ }
+ (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => {
+ let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ builder.convert_u_to_ptr(result_type.0, Some(cv.dst.0), cv.src.0)?;
+ }
+ _ => unreachable!(),
+ }
+ Ok(())
+}
+
+fn vec_repr<T: Copy>(t: T) -> Vec<u8> {
+ let mut result = vec![0; mem::size_of::<T>()];
+ unsafe { std::ptr::copy_nonoverlapping(&t, result.as_mut_ptr() as *mut _, 1) };
+ result
+}
+
+fn emit_abs(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ d: &ast::TypeFtz,
+ arg: &ast::AbsArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ let scalar_t = ast::ScalarType::from(d.type_);
+ let result_type = map.get_or_add(builder, SpirvType::from(scalar_t));
+ let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed {
+ spirv::CLOp::s_abs
+ } else {
+ spirv::CLOp::fabs
+ };
+ builder.ext_inst(
+ result_type.0,
+ Some(arg.dst.0),
+ opencl,
+ cl_abs as spirv::Word,
+ [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
+ )?;
+ Ok(())
+}
+
+fn emit_mul_int(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ type_: ast::ScalarType,
+ control: ast::MulIntControl,
+ arg: &ast::MulArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ let inst_type = map.get_or_add(builder, SpirvType::from(type_));
+ match control {
+ ast::MulIntControl::Low => {
+ builder.i_mul(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
+ }
+ ast::MulIntControl::High => {
+ let opencl_inst = if type_.kind() == ast::ScalarKind::Signed {
+ spirv::CLOp::s_mul_hi
+ } else {
+ spirv::CLOp::u_mul_hi
+ };
+ builder.ext_inst(
+ inst_type.0,
+ Some(arg.dst.0),
+ opencl,
+ opencl_inst as spirv::Word,
+ [
+ dr::Operand::IdRef(arg.src1.0),
+ dr::Operand::IdRef(arg.src2.0),
+ ]
+ .iter()
+ .cloned(),
+ )?;
+ }
+ ast::MulIntControl::Wide => {
+ let instr_width = type_.size_of();
+ let instr_kind = type_.kind();
+ let dst_type = scalar_from_parts(instr_width * 2, instr_kind);
+ let dst_type_id = map.get_or_add_scalar(builder, dst_type);
+ let (src1, src2) = if type_.kind() == ast::ScalarKind::Signed {
+ let src1 = builder.s_convert(dst_type_id.0, None, arg.src1.0)?;
+ let src2 = builder.s_convert(dst_type_id.0, None, arg.src2.0)?;
+ (src1, src2)
+ } else {
+ let src1 = builder.u_convert(dst_type_id.0, None, arg.src1.0)?;
+ let src2 = builder.u_convert(dst_type_id.0, None, arg.src2.0)?;
+ (src1, src2)
+ };
+ builder.i_mul(dst_type_id.0, Some(arg.dst.0), src1, src2)?;
+ builder.decorate(arg.dst.0, spirv::Decoration::NoSignedWrap, iter::empty());
+ }
+ }
+ Ok(())
+}
+
+fn emit_mul_float(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ ctr: &ast::ArithFloat,
+ arg: &ast::MulArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ if ctr.saturate {
+ todo!()
+ }
+ let result_type = map.get_or_add_scalar(builder, ctr.type_.into());
+ builder.f_mul(result_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
+ emit_rounding_decoration(builder, arg.dst, ctr.rounding);
+ Ok(())
+}
+
+fn scalar_from_parts(width: u8, kind: ast::ScalarKind) -> ast::ScalarType {
+ match kind {
+ ast::ScalarKind::Float => match width {
+ 2 => ast::ScalarType::F16,
+ 4 => ast::ScalarType::F32,
+ 8 => ast::ScalarType::F64,
+ _ => unreachable!(),
+ },
+ ast::ScalarKind::Bit => match width {
+ 1 => ast::ScalarType::B8,
+ 2 => ast::ScalarType::B16,
+ 4 => ast::ScalarType::B32,
+ 8 => ast::ScalarType::B64,
+ _ => unreachable!(),
+ },
+ ast::ScalarKind::Signed => match width {
+ 1 => ast::ScalarType::S8,
+ 2 => ast::ScalarType::S16,
+ 4 => ast::ScalarType::S32,
+ 8 => ast::ScalarType::S64,
+ _ => unreachable!(),
+ },
+ ast::ScalarKind::Unsigned => match width {
+ 1 => ast::ScalarType::U8,
+ 2 => ast::ScalarType::U16,
+ 4 => ast::ScalarType::U32,
+ 8 => ast::ScalarType::U64,
+ _ => unreachable!(),
+ },
+ ast::ScalarKind::Pred => ast::ScalarType::Pred,
+ }
+}
+
+fn emit_rounding_decoration(
+ builder: &mut dr::Builder,
+ dst: SpirvWord,
+ rounding: Option<ast::RoundingMode>,
+) {
+ if let Some(rounding) = rounding {
+ builder.decorate(
+ dst.0,
+ spirv::Decoration::FPRoundingMode,
+ [rounding_to_spirv(rounding)].iter().cloned(),
+ );
+ }
+}
+
+fn rounding_to_spirv(this: ast::RoundingMode) -> rspirv::dr::Operand {
+ let mode = match this {
+ ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE,
+ ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ,
+ ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP,
+ ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN,
+ };
+ rspirv::dr::Operand::FPRoundingMode(mode)
+}
+
+fn emit_add_int(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ typ: ast::ScalarType,
+ saturate: bool,
+ arg: &ast::AddArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ if saturate {
+ todo!()
+ }
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)));
+ builder.i_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
+ Ok(())
+}
+
+fn emit_add_float(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ desc: &ast::ArithFloat,
+ arg: &ast::AddArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)));
+ builder.f_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
+ emit_rounding_decoration(builder, arg.dst, desc.rounding);
+ Ok(())
+}
+
+fn emit_setp(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ setp: &ast::SetpData,
+ arg: &ast::SetpArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ let result_type = map
+ .get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred))
+ .0;
+ let result_id = Some(arg.dst1.0);
+ let operand_1 = arg.src1.0;
+ let operand_2 = arg.src2.0;
+ match setp.cmp_op {
+ ast::SetpCompareOp::Integer(ast::SetpCompareInt::Eq) => {
+ builder.i_equal(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Float(ast::SetpCompareFloat::Eq) => {
+ builder.f_ord_equal(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Integer(ast::SetpCompareInt::NotEq) => {
+ builder.i_not_equal(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Float(ast::SetpCompareFloat::NotEq) => {
+ builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLess) => {
+ builder.u_less_than(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLess) => {
+ builder.s_less_than(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Float(ast::SetpCompareFloat::Less) => {
+ builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLessOrEq) => {
+ builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLessOrEq) => {
+ builder.s_less_than_equal(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Float(ast::SetpCompareFloat::LessOrEq) => {
+ builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreater) => {
+ builder.u_greater_than(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreater) => {
+ builder.s_greater_than(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Float(ast::SetpCompareFloat::Greater) => {
+ builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreaterOrEq) => {
+ builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreaterOrEq) => {
+ builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Float(ast::SetpCompareFloat::GreaterOrEq) => {
+ builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanEq) => {
+ builder.f_unord_equal(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanNotEq) => {
+ builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLess) => {
+ builder.f_unord_less_than(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLessOrEq) => {
+ builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreater) => {
+ builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreaterOrEq) => {
+ builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2)
+ }
+ ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsAnyNan) => {
+ let temp1 = builder.is_nan(result_type, None, operand_1)?;
+ let temp2 = builder.is_nan(result_type, None, operand_2)?;
+ builder.logical_or(result_type, result_id, temp1, temp2)
+ }
+ ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsNotNan) => {
+ let temp1 = builder.is_nan(result_type, None, operand_1)?;
+ let temp2 = builder.is_nan(result_type, None, operand_2)?;
+ let any_nan = builder.logical_or(result_type, None, temp1, temp2)?;
+ logical_not(builder, result_type, result_id, any_nan)
+ }
+ _ => todo!(),
+ }?;
+ Ok(())
+}
+
+// HACK ALERT
+// Temporary workaround until IGC gets its shit together
+// Currently IGC carries two copies of SPIRV-LLVM translator
+// a new one in /llvm-spirv/ and old one in /IGC/AdaptorOCL/SPIRV/.
+// Obviously, old and buggy one is used for compiling L0 SPIRV
+// https://github.com/intel/intel-graphics-compiler/issues/148
+fn logical_not(
+ builder: &mut dr::Builder,
+ result_type: spirv::Word,
+ result_id: Option<spirv::Word>,
+ operand: spirv::Word,
+) -> Result<spirv::Word, dr::Error> {
+ let const_true = builder.constant_true(result_type, None);
+ let const_false = builder.constant_false(result_type, None);
+ builder.select(result_type, result_id, operand, const_false, const_true)
+}
+
+// HACK ALERT
+// For some reason IGC fails linking if the value and shift size are of different type
+fn insert_shift_hack(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ offset_var: spirv::Word,
+ size_of: usize,
+) -> Result<spirv::Word, TranslateError> {
+ let result_type = match size_of {
+ 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16),
+ 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64),
+ 4 => return Ok(offset_var),
+ _ => return Err(error_unreachable()),
+ };
+ Ok(builder.u_convert(result_type.0, None, offset_var)?)
+}
+
+fn emit_cvt(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ dets: &ast::CvtDetails,
+ arg: &ast::CvtArgs<SpirvWord>,
+) -> Result<(), TranslateError> {
+ match dets.mode {
+ ptx_parser::CvtMode::SignExtend => {
+ let cv = ImplicitConversion {
+ src: arg.src,
+ dst: arg.dst,
+ from_type: dets.from.into(),
+ from_space: ast::StateSpace::Reg,
+ to_type: dets.to.into(),
+ to_space: ast::StateSpace::Reg,
+ kind: ConversionKind::SignExtend,
+ };
+ emit_implicit_conversion(builder, map, &cv)?;
+ }
+ ptx_parser::CvtMode::ZeroExtend
+ | ptx_parser::CvtMode::Truncate
+ | ptx_parser::CvtMode::Bitcast => {
+ let cv = ImplicitConversion {
+ src: arg.src,
+ dst: arg.dst,
+ from_type: dets.from.into(),
+ from_space: ast::StateSpace::Reg,
+ to_type: dets.to.into(),
+ to_space: ast::StateSpace::Reg,
+ kind: ConversionKind::Default,
+ };
+ emit_implicit_conversion(builder, map, &cv)?;
+ }
+ ptx_parser::CvtMode::SaturateUnsignedToSigned => {
+ let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
+ builder.sat_convert_u_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?;
+ }
+ ptx_parser::CvtMode::SaturateSignedToUnsigned => {
+ let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
+ builder.sat_convert_s_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?;
+ }
+ ptx_parser::CvtMode::FPExtend { flush_to_zero } => {
+ let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
+ builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?;
+ }
+ ptx_parser::CvtMode::FPTruncate {
+ rounding,
+ flush_to_zero,
+ } => {
+ let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
+ builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?;
+ emit_rounding_decoration(builder, arg.dst, Some(rounding));
+ }
+ ptx_parser::CvtMode::FPRound {
+ integer_rounding,
+ flush_to_zero,
+ } => {
+ if flush_to_zero == Some(true) {
+ todo!()
+ }
+ let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
+ match integer_rounding {
+ Some(ast::RoundingMode::NearestEven) => {
+ builder.ext_inst(
+ result_type.0,
+ Some(arg.dst.0),
+ opencl,
+ spirv::CLOp::rint as u32,
+ [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
+ )?;
+ }
+ Some(ast::RoundingMode::Zero) => {
+ builder.ext_inst(
+ result_type.0,
+ Some(arg.dst.0),
+ opencl,
+ spirv::CLOp::trunc as u32,
+ [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
+ )?;
+ }
+ Some(ast::RoundingMode::NegativeInf) => {
+ builder.ext_inst(
+ result_type.0,
+ Some(arg.dst.0),
+ opencl,
+ spirv::CLOp::floor as u32,
+ [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
+ )?;
+ }
+ Some(ast::RoundingMode::PositiveInf) => {
+ builder.ext_inst(
+ result_type.0,
+ Some(arg.dst.0),
+ opencl,
+ spirv::CLOp::ceil as u32,
+ [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
+ )?;
+ }
+ None => {
+ builder.copy_object(result_type.0, Some(arg.dst.0), arg.src.0)?;
+ }
+ }
+ }
+ ptx_parser::CvtMode::SignedFromFP {
+ rounding,
+ flush_to_zero,
+ } => {
+ let dest_t: ast::ScalarType = dets.to.into();
+ let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
+ builder.convert_f_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?;
+ emit_rounding_decoration(builder, arg.dst, Some(rounding));
+ }
+ ptx_parser::CvtMode::UnsignedFromFP {
+ rounding,
+ flush_to_zero,
+ } => {
+ let dest_t: ast::ScalarType = dets.to.into();
+ let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
+ builder.convert_f_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?;
+ emit_rounding_decoration(builder, arg.dst, Some(rounding));
+ }
+ ptx_parser::CvtMode::FPFromSigned(rounding) => {
+ let dest_t: ast::ScalarType = dets.to.into();
+ let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
+ builder.convert_s_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?;
+ emit_rounding_decoration(builder, arg.dst, Some(rounding));
+ }
+ ptx_parser::CvtMode::FPFromUnsigned(rounding) => {
+ let dest_t: ast::ScalarType = dets.to.into();
+ let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
+ builder.convert_u_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?;
+ emit_rounding_decoration(builder, arg.dst, Some(rounding));
+ }
+ }
+ Ok(())
+}
+
+fn emit_mad_uint(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ type_: ast::ScalarType,
+ control: ast::MulIntControl,
+ arg: &ast::MadArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ let inst_type = map
+ .get_or_add(builder, SpirvType::from(ast::ScalarType::from(type_)))
+ .0;
+ match control {
+ ast::MulIntControl::Low => {
+ let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?;
+ builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?;
+ }
+ ast::MulIntControl::High => {
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst.0),
+ opencl,
+ spirv::CLOp::u_mad_hi as spirv::Word,
+ [
+ dr::Operand::IdRef(arg.src1.0),
+ dr::Operand::IdRef(arg.src2.0),
+ dr::Operand::IdRef(arg.src3.0),
+ ]
+ .iter()
+ .cloned(),
+ )?;
+ }
+ ast::MulIntControl::Wide => todo!(),
+ };
+ Ok(())
+}
+
+fn emit_mad_sint(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ type_: ast::ScalarType,
+ control: ast::MulIntControl,
+ arg: &ast::MadArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ let inst_type = map.get_or_add(builder, SpirvType::from(type_)).0;
+ match control {
+ ast::MulIntControl::Low => {
+ let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?;
+ builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?;
+ }
+ ast::MulIntControl::High => {
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst.0),
+ opencl,
+ spirv::CLOp::s_mad_hi as spirv::Word,
+ [
+ dr::Operand::IdRef(arg.src1.0),
+ dr::Operand::IdRef(arg.src2.0),
+ dr::Operand::IdRef(arg.src3.0),
+ ]
+ .iter()
+ .cloned(),
+ )?;
+ }
+ ast::MulIntControl::Wide => todo!(),
+ };
+ Ok(())
+}
+
+fn emit_mad_float(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::ArithFloat,
+ arg: &ast::MadArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ let inst_type = map
+ .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
+ .0;
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst.0),
+ opencl,
+ spirv::CLOp::mad as spirv::Word,
+ [
+ dr::Operand::IdRef(arg.src1.0),
+ dr::Operand::IdRef(arg.src2.0),
+ dr::Operand::IdRef(arg.src3.0),
+ ]
+ .iter()
+ .cloned(),
+ )?;
+ Ok(())
+}
+
+fn emit_fma_float(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::ArithFloat,
+ arg: &ast::FmaArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ let inst_type = map
+ .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
+ .0;
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst.0),
+ opencl,
+ spirv::CLOp::fma as spirv::Word,
+ [
+ dr::Operand::IdRef(arg.src1.0),
+ dr::Operand::IdRef(arg.src2.0),
+ dr::Operand::IdRef(arg.src3.0),
+ ]
+ .iter()
+ .cloned(),
+ )?;
+ Ok(())
+}
+
+fn emit_sub_int(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ typ: ast::ScalarType,
+ saturate: bool,
+ arg: &ast::SubArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ if saturate {
+ todo!()
+ }
+ let inst_type = map
+ .get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)))
+ .0;
+ builder.i_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
+ Ok(())
+}
+
+fn emit_sub_float(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ desc: &ast::ArithFloat,
+ arg: &ast::SubArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ let inst_type = map
+ .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
+ .0;
+ builder.f_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
+ emit_rounding_decoration(builder, arg.dst, desc.rounding);
+ Ok(())
+}
+
+fn emit_min(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::MinMaxDetails,
+ arg: &ast::MinArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ let cl_op = match desc {
+ ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min,
+ ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
+ ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
+ };
+ let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_()));
+ builder.ext_inst(
+ inst_type.0,
+ Some(arg.dst.0),
+ opencl,
+ cl_op as spirv::Word,
+ [
+ dr::Operand::IdRef(arg.src1.0),
+ dr::Operand::IdRef(arg.src2.0),
+ ]
+ .iter()
+ .cloned(),
+ )?;
+ Ok(())
+}
+
+fn emit_max(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::MinMaxDetails,
+ arg: &ast::MaxArgs<SpirvWord>,
+) -> Result<(), dr::Error> {
+ let cl_op = match desc {
+ ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max,
+ ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
+ ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
+ };
+ let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_()));
+ builder.ext_inst(
+ inst_type.0,
+ Some(arg.dst.0),
+ opencl,
+ cl_op as spirv::Word,
+ [
+ dr::Operand::IdRef(arg.src1.0),
+ dr::Operand::IdRef(arg.src2.0),
+ ]
+ .iter()
+ .cloned(),
+ )?;
+ Ok(())
+}
+
+fn emit_rcp(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::RcpData,
+ arg: &ast::RcpArgs<SpirvWord>,
+) -> Result<(), TranslateError> {
+ let is_f64 = desc.type_ == ast::ScalarType::F64;
+ let (instr_type, constant) = if is_f64 {
+ (ast::ScalarType::F64, vec_repr(1.0f64))
+ } else {
+ (ast::ScalarType::F32, vec_repr(1.0f32))
+ };
+ let result_type = map.get_or_add_scalar(builder, instr_type);
+ let rounding = match desc.kind {
+ ptx_parser::RcpKind::Approx => {
+ builder.ext_inst(
+ result_type.0,
+ Some(arg.dst.0),
+ opencl,
+ spirv::CLOp::native_recip as u32,
+ [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
+ )?;
+ return Ok(());
+ }
+ ptx_parser::RcpKind::Compliant(rounding) => rounding,
+ };
+ let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?;
+ builder.f_div(result_type.0, Some(arg.dst.0), one.0, arg.src.0)?;
+ emit_rounding_decoration(builder, arg.dst, Some(rounding));
+ builder.decorate(
+ arg.dst.0,
+ spirv::Decoration::FPFastMathMode,
+ [dr::Operand::FPFastMathMode(
+ spirv::FPFastMathMode::ALLOW_RECIP,
+ )]
+ .iter()
+ .cloned(),
+ );
+ Ok(())
+}
+
+fn emit_atom(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ details: &ast::AtomDetails,
+ arg: &ast::AtomArgs<SpirvWord>,
+) -> Result<(), TranslateError> {
+ let spirv_op = match details.op {
+ ptx_parser::AtomicOp::And => dr::Builder::atomic_and,
+ ptx_parser::AtomicOp::Or => dr::Builder::atomic_or,
+ ptx_parser::AtomicOp::Xor => dr::Builder::atomic_xor,
+ ptx_parser::AtomicOp::Exchange => dr::Builder::atomic_exchange,
+ ptx_parser::AtomicOp::Add => dr::Builder::atomic_i_add,
+ ptx_parser::AtomicOp::IncrementWrap | ptx_parser::AtomicOp::DecrementWrap => {
+ return Err(error_unreachable())
+ }
+ ptx_parser::AtomicOp::SignedMin => dr::Builder::atomic_s_min,
+ ptx_parser::AtomicOp::UnsignedMin => dr::Builder::atomic_u_min,
+ ptx_parser::AtomicOp::SignedMax => dr::Builder::atomic_s_max,
+ ptx_parser::AtomicOp::UnsignedMax => dr::Builder::atomic_u_max,
+ ptx_parser::AtomicOp::FloatAdd => dr::Builder::atomic_f_add_ext,
+ ptx_parser::AtomicOp::FloatMin => todo!(),
+ ptx_parser::AtomicOp::FloatMax => todo!(),
+ };
+ let result_type = map.get_or_add(builder, SpirvType::new(details.type_.clone()));
+ let memory_const = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(scope_to_spirv(details.scope) as u32),
+ )?;
+ let semantics_const = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(semantics_to_spirv(details.semantics).bits()),
+ )?;
+ spirv_op(
+ builder,
+ result_type.0,
+ Some(arg.dst.0),
+ arg.src1.0,
+ memory_const.0,
+ semantics_const.0,
+ arg.src2.0,
+ )?;
+ Ok(())
+}
+
+fn scope_to_spirv(this: ast::MemScope) -> spirv::Scope {
+ match this {
+ ast::MemScope::Cta => spirv::Scope::Workgroup,
+ ast::MemScope::Gpu => spirv::Scope::Device,
+ ast::MemScope::Sys => spirv::Scope::CrossDevice,
+ ptx_parser::MemScope::Cluster => todo!(),
+ }
+}
+
+fn semantics_to_spirv(this: ast::AtomSemantics) -> spirv::MemorySemantics {
+ match this {
+ ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED,
+ ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE,
+ ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE,
+ ast::AtomSemantics::AcqRel => spirv::MemorySemantics::ACQUIRE_RELEASE,
+ }
+}
+
+fn emit_float_div_decoration(builder: &mut dr::Builder, dst: SpirvWord, kind: ast::DivFloatKind) {
+ match kind {
+ ast::DivFloatKind::Approx => {
+ builder.decorate(
+ dst.0,
+ spirv::Decoration::FPFastMathMode,
+ [dr::Operand::FPFastMathMode(
+ spirv::FPFastMathMode::ALLOW_RECIP,
+ )]
+ .iter()
+ .cloned(),
+ );
+ }
+ ast::DivFloatKind::Rounding(rnd) => {
+ emit_rounding_decoration(builder, dst, Some(rnd));
+ }
+ ast::DivFloatKind::ApproxFull => {}
+ }
+}
+
+fn emit_sqrt(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ details: &ast::RcpData,
+ a: &ast::SqrtArgs<SpirvWord>,
+) -> Result<(), TranslateError> {
+ let result_type = map.get_or_add_scalar(builder, details.type_.into());
+ let (ocl_op, rounding) = match details.kind {
+ ast::RcpKind::Approx => (spirv::CLOp::sqrt, None),
+ ast::RcpKind::Compliant(rnd) => (spirv::CLOp::sqrt, Some(rnd)),
+ };
+ builder.ext_inst(
+ result_type.0,
+ Some(a.dst.0),
+ opencl,
+ ocl_op as spirv::Word,
+ [dr::Operand::IdRef(a.src.0)].iter().cloned(),
+ )?;
+ emit_rounding_decoration(builder, a.dst, rounding);
+ Ok(())
+}
+
+// TODO: check what kind of assembly do we emit
+fn emit_logical_xor_spirv(
+ builder: &mut dr::Builder,
+ result_type: spirv::Word,
+ result_id: Option<spirv::Word>,
+ op1: spirv::Word,
+ op2: spirv::Word,
+) -> Result<spirv::Word, dr::Error> {
+ let temp_or = builder.logical_or(result_type, None, op1, op2)?;
+ let temp_and = builder.logical_and(result_type, None, op1, op2)?;
+ let temp_neg = logical_not(builder, result_type, None, temp_and)?;
+ builder.logical_and(result_type, result_id, temp_or, temp_neg)
+}
+
+fn emit_load_var(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ details: &LoadVarDetails,
+) -> Result<(), TranslateError> {
+ let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone()));
+ match details.member_index {
+ Some((index, Some(width))) => {
+ let vector_type = match details.typ {
+ ast::Type::Scalar(scalar_t) => ast::Type::Vector(width, scalar_t),
+ _ => return Err(error_mismatched_type()),
+ };
+ let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
+ let vector_temp = builder.load(
+ vector_type_spirv.0,
+ None,
+ details.arg.src.0,
+ None,
+ iter::empty(),
+ )?;
+ builder.composite_extract(
+ result_type.0,
+ Some(details.arg.dst.0),
+ vector_temp,
+ [index as u32].iter().copied(),
+ )?;
+ }
+ Some((index, None)) => {
+ let result_ptr_type = map.get_or_add(
+ builder,
+ SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function),
+ );
+ let index_spirv = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(index as u32),
+ )?;
+ let src = builder.in_bounds_access_chain(
+ result_ptr_type.0,
+ None,
+ details.arg.src.0,
+ [index_spirv.0].iter().copied(),
+ )?;
+ builder.load(
+ result_type.0,
+ Some(details.arg.dst.0),
+ src,
+ None,
+ iter::empty(),
+ )?;
+ }
+ None => {
+ builder.load(
+ result_type.0,
+ Some(details.arg.dst.0),
+ details.arg.src.0,
+ None,
+ iter::empty(),
+ )?;
+ }
+ };
+ Ok(())
+}
+
+fn to_parts(this: &ast::Type) -> TypeParts {
+ match this {
+ ast::Type::Scalar(scalar) => TypeParts {
+ kind: TypeKind::Scalar,
+ state_space: ast::StateSpace::Reg,
+ scalar_kind: scalar.kind(),
+ width: scalar.size_of(),
+ components: Vec::new(),
+ },
+ ast::Type::Vector(components, scalar) => TypeParts {
+ kind: TypeKind::Vector,
+ state_space: ast::StateSpace::Reg,
+ scalar_kind: scalar.kind(),
+ width: scalar.size_of(),
+ components: vec![*components as u32],
+ },
+ ast::Type::Array(_, scalar, components) => TypeParts {
+ kind: TypeKind::Array,
+ state_space: ast::StateSpace::Reg,
+ scalar_kind: scalar.kind(),
+ width: scalar.size_of(),
+ components: components.clone(),
+ },
+ ast::Type::Pointer(scalar, space) => TypeParts {
+ kind: TypeKind::Pointer,
+ state_space: *space,
+ scalar_kind: scalar.kind(),
+ width: scalar.size_of(),
+ components: Vec::new(),
+ },
+ }
+}
+
+fn type_from_parts(t: TypeParts) -> ast::Type {
+ match t.kind {
+ TypeKind::Scalar => ast::Type::Scalar(scalar_from_parts(t.width, t.scalar_kind)),
+ TypeKind::Vector => ast::Type::Vector(
+ t.components[0] as u8,
+ scalar_from_parts(t.width, t.scalar_kind),
+ ),
+ TypeKind::Array => ast::Type::Array(
+ None,
+ scalar_from_parts(t.width, t.scalar_kind),
+ t.components,
+ ),
+ TypeKind::Pointer => {
+ ast::Type::Pointer(scalar_from_parts(t.width, t.scalar_kind), t.state_space)
+ }
+ }
+}
+
+#[derive(Eq, PartialEq, Clone)]
+struct TypeParts {
+ kind: TypeKind,
+ scalar_kind: ast::ScalarKind,
+ width: u8,
+ state_space: ast::StateSpace,
+ components: Vec<u32>,
+}
+
+#[derive(Eq, PartialEq, Copy, Clone)]
+enum TypeKind {
+ Scalar,
+ Vector,
+ Array,
+ Pointer,
+}
diff --git a/ptx/src/pass/expand_arguments.rs b/ptx/src/pass/expand_arguments.rs new file mode 100644 index 0000000..d0c7c98 --- /dev/null +++ b/ptx/src/pass/expand_arguments.rs @@ -0,0 +1,181 @@ +use super::*;
+use ptx_parser as ast;
+
+pub(super) fn run<'a, 'b>(
+ func: Vec<TypedStatement>,
+ id_def: &'b mut MutableNumericIdResolver<'a>,
+) -> Result<Vec<ExpandedStatement>, TranslateError> {
+ let mut result = Vec::with_capacity(func.len());
+ for s in func {
+ match s {
+ Statement::Label(id) => result.push(Statement::Label(id)),
+ Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
+ Statement::LoadVar(details) => result.push(Statement::LoadVar(details)),
+ Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
+ Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
+ Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
+ Statement::Constant(c) => result.push(Statement::Constant(c)),
+ Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)),
+ s => {
+ let (new_statement, post_stmts) = {
+ let mut visitor = FlattenArguments::new(&mut result, id_def);
+ (s.visit_map(&mut visitor)?, visitor.post_stmts)
+ };
+ result.push(new_statement);
+ result.extend(post_stmts);
+ }
+ }
+ }
+ Ok(result)
+}
+
+struct FlattenArguments<'a, 'b> {
+ func: &'b mut Vec<ExpandedStatement>,
+ id_def: &'b mut MutableNumericIdResolver<'a>,
+ post_stmts: Vec<ExpandedStatement>,
+}
+
+impl<'a, 'b> FlattenArguments<'a, 'b> {
+ fn new(
+ func: &'b mut Vec<ExpandedStatement>,
+ id_def: &'b mut MutableNumericIdResolver<'a>,
+ ) -> Self {
+ FlattenArguments {
+ func,
+ id_def,
+ post_stmts: Vec::new(),
+ }
+ }
+
+ fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
+ Ok(name)
+ }
+
+ fn reg_offset(
+ &mut self,
+ reg: SpirvWord,
+ offset: i32,
+ type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+ _is_dst: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ let (type_, state_space) = if let Some((type_, state_space)) = type_space {
+ (type_, state_space)
+ } else {
+ return Err(TranslateError::UntypedSymbol);
+ };
+ if state_space == ast::StateSpace::Reg || state_space == ast::StateSpace::Sreg {
+ let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
+ if !space_is_compatible(reg_space, ast::StateSpace::Reg) {
+ return Err(error_mismatched_type());
+ }
+ let reg_scalar_type = match reg_type {
+ ast::Type::Scalar(underlying_type) => underlying_type,
+ _ => return Err(error_mismatched_type()),
+ };
+ let id_constant_stmt = self
+ .id_def
+ .register_intermediate(reg_type.clone(), ast::StateSpace::Reg);
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: reg_scalar_type,
+ value: ast::ImmediateValue::S64(offset as i64),
+ }));
+ let arith_details = match reg_scalar_type.kind() {
+ ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: reg_scalar_type,
+ saturate: false,
+ }),
+ ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: reg_scalar_type,
+ saturate: false,
+ })
+ }
+ _ => return Err(error_unreachable()),
+ };
+ let id_add_result = self.id_def.register_intermediate(reg_type, state_space);
+ self.func
+ .push(Statement::Instruction(ast::Instruction::Add {
+ data: arith_details,
+ arguments: ast::AddArgs {
+ dst: id_add_result,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ }));
+ Ok(id_add_result)
+ } else {
+ let id_constant_stmt = self.id_def.register_intermediate(
+ ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
+ );
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: ast::ScalarType::S64,
+ value: ast::ImmediateValue::S64(offset as i64),
+ }));
+ let dst = self
+ .id_def
+ .register_intermediate(type_.clone(), state_space);
+ self.func.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: type_.clone(),
+ state_space: state_space,
+ dst,
+ ptr_src: reg,
+ offset_src: id_constant_stmt,
+ }));
+ Ok(dst)
+ }
+ }
+
+ fn immediate(
+ &mut self,
+ value: ast::ImmediateValue,
+ type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+ ) -> Result<SpirvWord, TranslateError> {
+ let (scalar_t, state_space) =
+ if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
+ (*scalar, state_space)
+ } else {
+ return Err(TranslateError::UntypedSymbol);
+ };
+ let id = self
+ .id_def
+ .register_intermediate(ast::Type::Scalar(scalar_t), state_space);
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id,
+ typ: scalar_t,
+ value,
+ }));
+ Ok(id)
+ }
+}
+
+impl<'a, 'b> ast::VisitorMap<TypedOperand, SpirvWord, TranslateError> for FlattenArguments<'a, 'b> {
+ fn visit(
+ &mut self,
+ args: TypedOperand,
+ type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+ is_dst: bool,
+ _relaxed_type_check: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ match args {
+ TypedOperand::Reg(r) => self.reg(r),
+ TypedOperand::Imm(x) => self.immediate(x, type_space),
+ TypedOperand::RegOffset(reg, offset) => {
+ self.reg_offset(reg, offset, type_space, is_dst)
+ }
+ TypedOperand::VecMember(..) => Err(error_unreachable()),
+ }
+ }
+
+ fn visit_ident(
+ &mut self,
+ name: <TypedOperand as ptx_parser::Operand>::Ident,
+ _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+ _is_dst: bool,
+ _relaxed_type_check: bool,
+ ) -> Result<<SpirvWord as ptx_parser::Operand>::Ident, TranslateError> {
+ self.reg(name)
+ }
+}
diff --git a/ptx/src/pass/extract_globals.rs b/ptx/src/pass/extract_globals.rs new file mode 100644 index 0000000..680a5ee --- /dev/null +++ b/ptx/src/pass/extract_globals.rs @@ -0,0 +1,282 @@ +use super::*;
+
+pub(super) fn run<'input, 'b>(
+ sorted_statements: Vec<ExpandedStatement>,
+ ptx_impl_imports: &mut HashMap<String, Directive>,
+ id_def: &mut NumericIdResolver,
+) -> Result<(Vec<ExpandedStatement>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
+ let mut local = Vec::with_capacity(sorted_statements.len());
+ let mut global = Vec::new();
+ for statement in sorted_statements {
+ match statement {
+ Statement::Variable(
+ var @ ast::Variable {
+ state_space: ast::StateSpace::Shared,
+ ..
+ },
+ )
+ | Statement::Variable(
+ var @ ast::Variable {
+ state_space: ast::StateSpace::Global,
+ ..
+ },
+ ) => global.push(var),
+ Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => {
+ let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", scalar_to_ptx_name(data)].concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Bfe { data, arguments },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => {
+ let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", scalar_to_ptx_name(data)].concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Bfi { data, arguments },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Brev { data, arguments }) => {
+ let fn_name: String =
+ [ZLUDA_PTX_PREFIX, "brev_", scalar_to_ptx_name(data)].concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Brev { data, arguments },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Activemask { arguments }) => {
+ let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Activemask { arguments },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Atom {
+ data:
+ data @ ast::AtomDetails {
+ op: ast::AtomicOp::IncrementWrap,
+ semantics,
+ scope,
+ space,
+ ..
+ },
+ arguments,
+ }) => {
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "atom_",
+ semantics_to_ptx_name(semantics),
+ "_",
+ scope_to_ptx_name(scope),
+ "_",
+ space_to_ptx_name(space),
+ "_inc",
+ ]
+ .concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Atom { data, arguments },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Atom {
+ data:
+ data @ ast::AtomDetails {
+ op: ast::AtomicOp::DecrementWrap,
+ semantics,
+ scope,
+ space,
+ ..
+ },
+ arguments,
+ }) => {
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "atom_",
+ semantics_to_ptx_name(semantics),
+ "_",
+ scope_to_ptx_name(scope),
+ "_",
+ space_to_ptx_name(space),
+ "_dec",
+ ]
+ .concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Atom { data, arguments },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Atom {
+ data:
+ data @ ast::AtomDetails {
+ op: ast::AtomicOp::FloatAdd,
+ semantics,
+ scope,
+ space,
+ ..
+ },
+ arguments,
+ }) => {
+ let scalar_type = match data.type_ {
+ ptx_parser::Type::Scalar(scalar) => scalar,
+ _ => return Err(error_unreachable()),
+ };
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "atom_",
+ semantics_to_ptx_name(semantics),
+ "_",
+ scope_to_ptx_name(scope),
+ "_",
+ space_to_ptx_name(space),
+ "_add_",
+ scalar_to_ptx_name(scalar_type),
+ ]
+ .concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Atom { data, arguments },
+ fn_name,
+ )?);
+ }
+ s => local.push(s),
+ }
+ }
+ Ok((local, global))
+}
+
+fn instruction_to_fn_call(
+ id_defs: &mut NumericIdResolver,
+ ptx_impl_imports: &mut HashMap<String, Directive>,
+ inst: ast::Instruction<SpirvWord>,
+ fn_name: String,
+) -> Result<ExpandedStatement, TranslateError> {
+ let mut arguments = Vec::new();
+ ast::visit_map(inst, &mut |operand,
+ type_space: Option<(
+ &ast::Type,
+ ast::StateSpace,
+ )>,
+ is_dst,
+ _| {
+ let (typ, space) = match type_space {
+ Some((typ, space)) => (typ.clone(), space),
+ None => return Err(error_unreachable()),
+ };
+ arguments.push((operand, is_dst, typ, space));
+ Ok(SpirvWord(0))
+ })?;
+ let return_arguments_count = arguments
+ .iter()
+ .position(|(desc, is_dst, _, _)| !is_dst)
+ .unwrap_or(arguments.len());
+ let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
+ let fn_id = register_external_fn_call(
+ id_defs,
+ ptx_impl_imports,
+ fn_name,
+ return_arguments
+ .iter()
+ .map(|(_, _, typ, state)| (typ, *state)),
+ input_arguments
+ .iter()
+ .map(|(_, _, typ, state)| (typ, *state)),
+ )?;
+ Ok(Statement::Instruction(ast::Instruction::Call {
+ data: ast::CallDetails {
+ uniform: false,
+ return_arguments: return_arguments
+ .iter()
+ .map(|(_, _, typ, state)| (typ.clone(), *state))
+ .collect::<Vec<_>>(),
+ input_arguments: input_arguments
+ .iter()
+ .map(|(_, _, typ, state)| (typ.clone(), *state))
+ .collect::<Vec<_>>(),
+ },
+ arguments: ast::CallArgs {
+ return_arguments: return_arguments
+ .iter()
+ .map(|(name, _, _, _)| *name)
+ .collect::<Vec<_>>(),
+ func: fn_id,
+ input_arguments: input_arguments
+ .iter()
+ .map(|(name, _, _, _)| *name)
+ .collect::<Vec<_>>(),
+ },
+ }))
+}
+
+fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
+ match this {
+ ast::ScalarType::B8 => "b8",
+ ast::ScalarType::B16 => "b16",
+ ast::ScalarType::B32 => "b32",
+ ast::ScalarType::B64 => "b64",
+ ast::ScalarType::B128 => "b128",
+ ast::ScalarType::U8 => "u8",
+ ast::ScalarType::U16 => "u16",
+ ast::ScalarType::U16x2 => "u16x2",
+ ast::ScalarType::U32 => "u32",
+ ast::ScalarType::U64 => "u64",
+ ast::ScalarType::S8 => "s8",
+ ast::ScalarType::S16 => "s16",
+ ast::ScalarType::S16x2 => "s16x2",
+ ast::ScalarType::S32 => "s32",
+ ast::ScalarType::S64 => "s64",
+ ast::ScalarType::F16 => "f16",
+ ast::ScalarType::F16x2 => "f16x2",
+ ast::ScalarType::F32 => "f32",
+ ast::ScalarType::F64 => "f64",
+ ast::ScalarType::BF16 => "bf16",
+ ast::ScalarType::BF16x2 => "bf16x2",
+ ast::ScalarType::Pred => "pred",
+ }
+}
+
+fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str {
+ match this {
+ ast::AtomSemantics::Relaxed => "relaxed",
+ ast::AtomSemantics::Acquire => "acquire",
+ ast::AtomSemantics::Release => "release",
+ ast::AtomSemantics::AcqRel => "acq_rel",
+ }
+}
+
+fn scope_to_ptx_name(this: ast::MemScope) -> &'static str {
+ match this {
+ ast::MemScope::Cta => "cta",
+ ast::MemScope::Gpu => "gpu",
+ ast::MemScope::Sys => "sys",
+ ast::MemScope::Cluster => "cluster",
+ }
+}
+
+fn space_to_ptx_name(this: ast::StateSpace) -> &'static str {
+ match this {
+ ast::StateSpace::Generic => "generic",
+ ast::StateSpace::Global => "global",
+ ast::StateSpace::Shared => "shared",
+ ast::StateSpace::Reg => "reg",
+ ast::StateSpace::Const => "const",
+ ast::StateSpace::Local => "local",
+ ast::StateSpace::Param => "param",
+ ast::StateSpace::Sreg => "sreg",
+ ast::StateSpace::SharedCluster => "shared_cluster",
+ ast::StateSpace::ParamEntry => "param_entry",
+ ast::StateSpace::SharedCta => "shared_cta",
+ ast::StateSpace::ParamFunc => "param_func",
+ }
+}
diff --git a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs new file mode 100644 index 0000000..c029016 --- /dev/null +++ b/ptx/src/pass/fix_special_registers.rs @@ -0,0 +1,130 @@ +use super::*;
+use std::collections::HashMap;
+
+pub(super) fn run<'a, 'b, 'input>(
+ ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
+ typed_statements: Vec<TypedStatement>,
+ numeric_id_defs: &'a mut NumericIdResolver<'b>,
+) -> Result<Vec<TypedStatement>, TranslateError> {
+ let result = Vec::with_capacity(typed_statements.len());
+ let mut sreg_sresolver = SpecialRegisterResolver {
+ ptx_impl_imports,
+ numeric_id_defs,
+ result,
+ };
+ for statement in typed_statements {
+ let statement = statement.visit_map(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(statement);
+ }
+ Ok(sreg_sresolver.result)
+}
+
+struct SpecialRegisterResolver<'a, 'b, 'input> {
+ ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
+ numeric_id_defs: &'a mut NumericIdResolver<'b>,
+ result: Vec<TypedStatement>,
+}
+
+impl<'a, 'b, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
+ for SpecialRegisterResolver<'a, 'b, 'input>
+{
+ fn visit(
+ &mut self,
+ operand: TypedOperand,
+ _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+ is_dst: bool,
+ _relaxed_type_check: bool,
+ ) -> Result<TypedOperand, TranslateError> {
+ operand.map(|name, vector_index| self.replace_sreg(name, is_dst, vector_index))
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: SpirvWord,
+ _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+ is_dst: bool,
+ _relaxed_type_check: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ self.replace_sreg(args, is_dst, None)
+ }
+}
+
+impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
+ fn replace_sreg(
+ &mut self,
+ name: SpirvWord,
+ is_dst: bool,
+ vector_index: Option<u8>,
+ ) -> Result<SpirvWord, TranslateError> {
+ if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) {
+ if is_dst {
+ return Err(error_mismatched_type());
+ }
+ let input_arguments = match (vector_index, sreg.get_function_input_type()) {
+ (Some(idx), Some(inp_type)) => {
+ if inp_type != ast::ScalarType::U8 {
+ return Err(TranslateError::Unreachable);
+ }
+ let constant = self.numeric_id_defs.register_intermediate(Some((
+ ast::Type::Scalar(inp_type),
+ ast::StateSpace::Reg,
+ )));
+ self.result.push(Statement::Constant(ConstantDefinition {
+ dst: constant,
+ typ: inp_type,
+ value: ast::ImmediateValue::U64(idx as u64),
+ }));
+ vec![(
+ TypedOperand::Reg(constant),
+ ast::Type::Scalar(inp_type),
+ ast::StateSpace::Reg,
+ )]
+ }
+ (None, None) => Vec::new(),
+ _ => return Err(error_mismatched_type()),
+ };
+ let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
+ let return_type = sreg.get_function_return_type();
+ let fn_result = self.numeric_id_defs.register_intermediate(Some((
+ ast::Type::Scalar(return_type),
+ ast::StateSpace::Reg,
+ )));
+ let return_arguments = vec![(
+ fn_result,
+ ast::Type::Scalar(return_type),
+ ast::StateSpace::Reg,
+ )];
+ let fn_call = register_external_fn_call(
+ self.numeric_id_defs,
+ self.ptx_impl_imports,
+ ocl_fn_name.to_string(),
+ return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
+ input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
+ )?;
+ let data = ast::CallDetails {
+ uniform: false,
+ return_arguments: return_arguments
+ .iter()
+ .map(|(_, typ, space)| (typ.clone(), *space))
+ .collect(),
+ input_arguments: input_arguments
+ .iter()
+ .map(|(_, typ, space)| (typ.clone(), *space))
+ .collect(),
+ };
+ let arguments = ast::CallArgs {
+ return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
+ func: fn_call,
+ input_arguments: input_arguments.iter().map(|(name, _, _)| *name).collect(),
+ };
+ self.result
+ .push(Statement::Instruction(ast::Instruction::Call {
+ data,
+ arguments,
+ }));
+ Ok(fn_result)
+ } else {
+ Ok(name)
+ }
+ }
+}
diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs new file mode 100644 index 0000000..25e80f0 --- /dev/null +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -0,0 +1,432 @@ +use std::mem;
+
+use super::*;
+use ptx_parser as ast;
+
+/*
+ There are several kinds of implicit conversions in PTX:
+ * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
+ * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
+ - ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
+ semantics are to first zext/chop/bitcast `y` as needed and then do
+ documented special ld/st/cvt conversion rules for destination operands
+ - st.param [x] y (used as function return arguments) same rule as above applies
+ - generic/global ld: for instruction `ld x, [y]`, y must be of type
+ b64/u64/s64, which is bitcast to a pointer, dereferenced and then
+ documented special ld/st/cvt conversion rules are applied to dst
+ - generic/global st: for instruction `st [x], y`, x must be of type
+ b64/u64/s64, which is bitcast to a pointer
+*/
+pub(super) fn run(
+ func: Vec<ExpandedStatement>,
+ id_def: &mut MutableNumericIdResolver,
+) -> Result<Vec<ExpandedStatement>, TranslateError> {
+ let mut result = Vec::with_capacity(func.len());
+ for s in func.into_iter() {
+ match s {
+ Statement::Instruction(inst) => {
+ insert_implicit_conversions_impl(
+ &mut result,
+ id_def,
+ Statement::Instruction(inst),
+ )?;
+ }
+ Statement::PtrAccess(access) => {
+ insert_implicit_conversions_impl(
+ &mut result,
+ id_def,
+ Statement::PtrAccess(access),
+ )?;
+ }
+ Statement::RepackVector(repack) => {
+ insert_implicit_conversions_impl(
+ &mut result,
+ id_def,
+ Statement::RepackVector(repack),
+ )?;
+ }
+ s @ Statement::Conditional(_)
+ | s @ Statement::Conversion(_)
+ | s @ Statement::Label(_)
+ | s @ Statement::Constant(_)
+ | s @ Statement::Variable(_)
+ | s @ Statement::LoadVar(..)
+ | s @ Statement::StoreVar(..)
+ | s @ Statement::RetValue(..)
+ | s @ Statement::FunctionPointer(..) => result.push(s),
+ }
+ }
+ Ok(result)
+}
+
+fn insert_implicit_conversions_impl(
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut MutableNumericIdResolver,
+ stmt: ExpandedStatement,
+) -> Result<(), TranslateError> {
+ let mut post_conv = Vec::new();
+ let statement = stmt.visit_map::<SpirvWord, TranslateError>(
+ &mut |operand,
+ type_state: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst,
+ relaxed_type_check| {
+ let (instr_type, instruction_space) = match type_state {
+ None => return Ok(operand),
+ Some(t) => t,
+ };
+ let (operand_type, operand_space) = id_def.get_typed(operand)?;
+ let conversion_fn = if relaxed_type_check {
+ if is_dst {
+ should_convert_relaxed_dst_wrapper
+ } else {
+ should_convert_relaxed_src_wrapper
+ }
+ } else {
+ default_implicit_conversion
+ };
+ match conversion_fn(
+ (operand_space, &operand_type),
+ (instruction_space, instr_type),
+ )? {
+ Some(conv_kind) => {
+ let conv_output = if is_dst { &mut post_conv } else { &mut *func };
+ let mut from_type = instr_type.clone();
+ let mut from_space = instruction_space;
+ let mut to_type = operand_type;
+ let mut to_space = operand_space;
+ let mut src =
+ id_def.register_intermediate(instr_type.clone(), instruction_space);
+ let mut dst = operand;
+ let result = Ok::<_, TranslateError>(src);
+ if !is_dst {
+ mem::swap(&mut src, &mut dst);
+ mem::swap(&mut from_type, &mut to_type);
+ mem::swap(&mut from_space, &mut to_space);
+ }
+ conv_output.push(Statement::Conversion(ImplicitConversion {
+ src,
+ dst,
+ from_type,
+ from_space,
+ to_type,
+ to_space,
+ kind: conv_kind,
+ }));
+ result
+ }
+ None => Ok(operand),
+ }
+ },
+ )?;
+ func.push(statement);
+ func.append(&mut post_conv);
+ Ok(())
+}
+
+pub(crate) fn default_implicit_conversion(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if instruction_space == ast::StateSpace::Reg {
+ if space_is_compatible(operand_space, ast::StateSpace::Reg) {
+ if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
+ (operand_type, instruction_type)
+ {
+ if scalar.kind() == ast::ScalarKind::Bit
+ && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
+ {
+ return Ok(Some(ConversionKind::Default));
+ }
+ }
+ } else if is_addressable(operand_space) {
+ return Ok(Some(ConversionKind::AddressOf));
+ }
+ }
+ if !space_is_compatible(instruction_space, operand_space) {
+ default_implicit_conversion_space(
+ (operand_space, operand_type),
+ (instruction_space, instruction_type),
+ )
+ } else if instruction_type != operand_type {
+ default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
+ } else {
+ Ok(None)
+ }
+}
+
+fn is_addressable(this: ast::StateSpace) -> bool {
+ match this {
+ ast::StateSpace::Const
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Global
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => true,
+ ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false,
+ ast::StateSpace::SharedCluster
+ | ast::StateSpace::SharedCta
+ | ast::StateSpace::ParamEntry
+ | ast::StateSpace::ParamFunc => todo!(),
+ }
+}
+
+// Space is different
+fn default_implicit_conversion_space(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
+ || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
+ {
+ Ok(Some(ConversionKind::PtrToPtr))
+ } else if space_is_compatible(operand_space, ast::StateSpace::Reg) {
+ match operand_type {
+ ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
+ if *operand_ptr_space == instruction_space =>
+ {
+ if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
+ } else {
+ Ok(None)
+ }
+ }
+ // TODO: 32 bit
+ ast::Type::Scalar(ast::ScalarType::B64)
+ | ast::Type::Scalar(ast::ScalarType::U64)
+ | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
+ ast::StateSpace::Global
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Const
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
+ _ => Err(error_mismatched_type()),
+ },
+ ast::Type::Scalar(ast::ScalarType::B32)
+ | ast::Type::Scalar(ast::ScalarType::U32)
+ | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
+ ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
+ Ok(Some(ConversionKind::BitToPtr))
+ }
+ _ => Err(error_mismatched_type()),
+ },
+ _ => Err(error_mismatched_type()),
+ }
+ } else if space_is_compatible(instruction_space, ast::StateSpace::Reg) {
+ match instruction_type {
+ ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
+ if operand_space == *instruction_ptr_space =>
+ {
+ if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
+ } else {
+ Ok(None)
+ }
+ }
+ _ => Err(error_mismatched_type()),
+ }
+ } else {
+ Err(error_mismatched_type())
+ }
+}
+
+// Space is same, but type is different
+fn default_implicit_conversion_type(
+ space: ast::StateSpace,
+ operand_type: &ast::Type,
+ instruction_type: &ast::Type,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if space_is_compatible(space, ast::StateSpace::Reg) {
+ if should_bitcast(instruction_type, operand_type) {
+ Ok(Some(ConversionKind::Default))
+ } else {
+ Err(TranslateError::MismatchedType)
+ }
+ } else {
+ Ok(Some(ConversionKind::PtrToPtr))
+ }
+}
+
+fn coerces_to_generic(this: ast::StateSpace) -> bool {
+ match this {
+ ast::StateSpace::Global
+ | ast::StateSpace::Const
+ | ast::StateSpace::Local
+ | ptx_parser::StateSpace::SharedCta
+ | ast::StateSpace::SharedCluster
+ | ast::StateSpace::Shared => true,
+ ast::StateSpace::Reg
+ | ast::StateSpace::Param
+ | ast::StateSpace::ParamEntry
+ | ast::StateSpace::ParamFunc
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Sreg => false,
+ }
+}
+
+fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
+ match (instr, operand) {
+ (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
+ if inst.size_of() != operand.size_of() {
+ return false;
+ }
+ match inst.kind() {
+ ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
+ ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
+ ast::ScalarKind::Signed => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Unsigned
+ }
+ ast::ScalarKind::Unsigned => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Signed
+ }
+ ast::ScalarKind::Pred => false,
+ }
+ }
+ (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
+ | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
+ should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
+ }
+ _ => false,
+ }
+}
+
+pub(crate) fn should_convert_relaxed_dst_wrapper(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if !space_is_compatible(operand_space, instruction_space) {
+ return Err(TranslateError::MismatchedType);
+ }
+ if operand_type == instruction_type {
+ return Ok(None);
+ }
+ match should_convert_relaxed_dst(operand_type, instruction_type) {
+ conv @ Some(_) => Ok(conv),
+ None => Err(TranslateError::MismatchedType),
+ }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
+fn should_convert_relaxed_dst(
+ dst_type: &ast::Type,
+ instr_type: &ast::Type,
+) -> Option<ConversionKind> {
+ if dst_type == instr_type {
+ return None;
+ }
+ match (dst_type, instr_type) {
+ (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
+ ast::ScalarKind::Bit => {
+ if instr_type.size_of() <= dst_type.size_of() {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Signed => {
+ if dst_type.kind() != ast::ScalarKind::Float {
+ if instr_type.size_of() == dst_type.size_of() {
+ Some(ConversionKind::Default)
+ } else if instr_type.size_of() < dst_type.size_of() {
+ Some(ConversionKind::SignExtend)
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Unsigned => {
+ if instr_type.size_of() <= dst_type.size_of()
+ && dst_type.kind() != ast::ScalarKind::Float
+ {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Float => {
+ if instr_type.size_of() <= dst_type.size_of()
+ && dst_type.kind() == ast::ScalarKind::Bit
+ {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Pred => None,
+ },
+ (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
+ | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
+ should_convert_relaxed_dst(
+ &ast::Type::Scalar(*dst_type),
+ &ast::Type::Scalar(*instr_type),
+ )
+ }
+ _ => None,
+ }
+}
+
+pub(crate) fn should_convert_relaxed_src_wrapper(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if !space_is_compatible(operand_space, instruction_space) {
+ return Err(error_mismatched_type());
+ }
+ if operand_type == instruction_type {
+ return Ok(None);
+ }
+ match should_convert_relaxed_src(operand_type, instruction_type) {
+ conv @ Some(_) => Ok(conv),
+ None => Err(error_mismatched_type()),
+ }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
+fn should_convert_relaxed_src(
+ src_type: &ast::Type,
+ instr_type: &ast::Type,
+) -> Option<ConversionKind> {
+ if src_type == instr_type {
+ return None;
+ }
+ match (src_type, instr_type) {
+ (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
+ ast::ScalarKind::Bit => {
+ if instr_type.size_of() <= src_type.size_of() {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
+ if instr_type.size_of() <= src_type.size_of()
+ && src_type.kind() != ast::ScalarKind::Float
+ {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Float => {
+ if instr_type.size_of() <= src_type.size_of()
+ && src_type.kind() == ast::ScalarKind::Bit
+ {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Pred => None,
+ },
+ (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
+ | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
+ should_convert_relaxed_src(
+ &ast::Type::Scalar(*dst_type),
+ &ast::Type::Scalar(*instr_type),
+ )
+ }
+ _ => None,
+ }
+}
diff --git a/ptx/src/pass/insert_mem_ssa_statements.rs b/ptx/src/pass/insert_mem_ssa_statements.rs new file mode 100644 index 0000000..e314b05 --- /dev/null +++ b/ptx/src/pass/insert_mem_ssa_statements.rs @@ -0,0 +1,275 @@ +use super::*;
+use ptx_parser as ast;
+
+/*
+ How do we handle arguments:
+ - input .params in kernels
+ .param .b64 in_arg
+ get turned into this SPIR-V:
+ %1 = OpFunctionParameter %ulong
+ %2 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %1
+ We do this for two reasons. One, common treatment for argument-declared
+ .param variables and .param variables inside function (we assume that
+ at SPIR-V level every .param is a pointer in Function storage class)
+ - input .params in functions
+ .param .b64 in_arg
+ get turned into this SPIR-V:
+ %1 = OpFunctionParameter %_ptr_Function_ulong
+ - input .regs
+ .reg .b64 in_arg
+ get turned into the same SPIR-V as kernel .params:
+ %1 = OpFunctionParameter %ulong
+ %2 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %1
+ - output .regs
+ .reg .b64 out_arg
+ get just a variable declaration:
+ %2 = OpVariable %%_ptr_Function_ulong Function
+ - output .params don't exist, they have been moved to input positions
+ by an earlier pass
+ Distinguishing betweem kernel .params and function .params is not the
+ cleanest solution. Alternatively, we could "deparamize" all kernel .param
+ arguments by turning them into .reg arguments like this:
+ .param .b64 arg -> .reg ptr<.b64,.param> arg
+ This has the massive downside that this transformation would have to run
+ very early and would muddy up already difficult code. It's simpler to just
+ have an if here
+*/
+pub(super) fn run<'a, 'b>(
+ func: Vec<TypedStatement>,
+ id_def: &mut NumericIdResolver,
+ fn_decl: &'a mut ast::MethodDeclaration<'b, SpirvWord>,
+) -> Result<Vec<TypedStatement>, TranslateError> {
+ let mut result = Vec::with_capacity(func.len());
+ for arg in fn_decl.input_arguments.iter_mut() {
+ insert_mem_ssa_argument(
+ id_def,
+ &mut result,
+ arg,
+ matches!(fn_decl.name, ast::MethodName::Kernel(_)),
+ );
+ }
+ for arg in fn_decl.return_arguments.iter() {
+ insert_mem_ssa_argument_reg_return(&mut result, arg);
+ }
+ for s in func {
+ match s {
+ Statement::Instruction(inst) => match inst {
+ ast::Instruction::Ret { data } => {
+ // TODO: handle multiple output args
+ match &fn_decl.return_arguments[..] {
+ [return_reg] => {
+ let new_id = id_def.register_intermediate(Some((
+ return_reg.v_type.clone(),
+ ast::StateSpace::Reg,
+ )));
+ result.push(Statement::LoadVar(LoadVarDetails {
+ arg: ast::LdArgs {
+ dst: new_id,
+ src: return_reg.name,
+ },
+ typ: return_reg.v_type.clone(),
+ member_index: None,
+ }));
+ result.push(Statement::RetValue(data, new_id));
+ }
+ [] => result.push(Statement::Instruction(ast::Instruction::Ret { data })),
+ _ => unimplemented!(),
+ }
+ }
+ inst => insert_mem_ssa_statement_default(
+ id_def,
+ &mut result,
+ Statement::Instruction(inst),
+ )?,
+ },
+ Statement::Conditional(bra) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conditional(bra))?
+ }
+ Statement::Conversion(conv) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conversion(conv))?
+ }
+ Statement::PtrAccess(ptr_access) => insert_mem_ssa_statement_default(
+ id_def,
+ &mut result,
+ Statement::PtrAccess(ptr_access),
+ )?,
+ Statement::RepackVector(repack) => insert_mem_ssa_statement_default(
+ id_def,
+ &mut result,
+ Statement::RepackVector(repack),
+ )?,
+ Statement::FunctionPointer(func_ptr) => insert_mem_ssa_statement_default(
+ id_def,
+ &mut result,
+ Statement::FunctionPointer(func_ptr),
+ )?,
+ s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => {
+ result.push(s)
+ }
+ _ => return Err(error_unreachable()),
+ }
+ }
+ Ok(result)
+}
+
+fn insert_mem_ssa_argument(
+ id_def: &mut NumericIdResolver,
+ func: &mut Vec<TypedStatement>,
+ arg: &mut ast::Variable<SpirvWord>,
+ is_kernel: bool,
+) {
+ if !is_kernel && arg.state_space == ast::StateSpace::Param {
+ return;
+ }
+ let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
+ func.push(Statement::Variable(ast::Variable {
+ align: arg.align,
+ v_type: arg.v_type.clone(),
+ state_space: ast::StateSpace::Reg,
+ name: arg.name,
+ array_init: Vec::new(),
+ }));
+ func.push(Statement::StoreVar(StoreVarDetails {
+ arg: ast::StArgs {
+ src1: arg.name,
+ src2: new_id,
+ },
+ typ: arg.v_type.clone(),
+ member_index: None,
+ }));
+ arg.name = new_id;
+}
+
+fn insert_mem_ssa_argument_reg_return(
+ func: &mut Vec<TypedStatement>,
+ arg: &ast::Variable<SpirvWord>,
+) {
+ func.push(Statement::Variable(ast::Variable {
+ align: arg.align,
+ v_type: arg.v_type.clone(),
+ state_space: arg.state_space,
+ name: arg.name,
+ array_init: arg.array_init.clone(),
+ }));
+}
+
+fn insert_mem_ssa_statement_default<'a, 'input>(
+ id_def: &'a mut NumericIdResolver<'input>,
+ func: &'a mut Vec<TypedStatement>,
+ stmt: TypedStatement,
+) -> Result<(), TranslateError> {
+ let mut visitor = InsertMemSSAVisitor {
+ id_def,
+ func,
+ post_statements: Vec::new(),
+ };
+ let new_stmt = stmt.visit_map(&mut visitor)?;
+ visitor.func.push(new_stmt);
+ visitor.func.extend(visitor.post_statements);
+ Ok(())
+}
+
+struct InsertMemSSAVisitor<'a, 'input> {
+ id_def: &'a mut NumericIdResolver<'input>,
+ func: &'a mut Vec<TypedStatement>,
+ post_statements: Vec<TypedStatement>,
+}
+
+impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
+ fn symbol(
+ &mut self,
+ symbol: SpirvWord,
+ member_index: Option<u8>,
+ expected: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ if expected.is_none() {
+ return Ok(symbol);
+ };
+ let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
+ if !space_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable {
+ return Ok(symbol);
+ };
+ let member_index = match member_index {
+ Some(idx) => {
+ let vector_width = match var_type {
+ ast::Type::Vector(width, scalar_t) => {
+ var_type = ast::Type::Scalar(scalar_t);
+ width
+ }
+ _ => return Err(error_mismatched_type()),
+ };
+ Some((
+ idx,
+ if self.id_def.special_registers.get(symbol).is_some() {
+ Some(vector_width)
+ } else {
+ None
+ },
+ ))
+ }
+ None => None,
+ };
+ let generated_id = self
+ .id_def
+ .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg)));
+ if !is_dst {
+ self.func.push(Statement::LoadVar(LoadVarDetails {
+ arg: ast::LdArgs {
+ dst: generated_id,
+ src: symbol,
+ },
+ typ: var_type,
+ member_index,
+ }));
+ } else {
+ self.post_statements
+ .push(Statement::StoreVar(StoreVarDetails {
+ arg: ast::StArgs {
+ src1: symbol,
+ src2: generated_id,
+ },
+ typ: var_type,
+ member_index: member_index.map(|(idx, _)| idx),
+ }));
+ }
+ Ok(generated_id)
+ }
+}
+
+impl<'a, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
+ for InsertMemSSAVisitor<'a, 'input>
+{
+ fn visit(
+ &mut self,
+ operand: TypedOperand,
+ type_space: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst: bool,
+ _relaxed_type_check: bool,
+ ) -> Result<TypedOperand, TranslateError> {
+ Ok(match operand {
+ TypedOperand::Reg(reg) => {
+ TypedOperand::Reg(self.symbol(reg, None, type_space, is_dst)?)
+ }
+ TypedOperand::RegOffset(reg, offset) => {
+ TypedOperand::RegOffset(self.symbol(reg, None, type_space, is_dst)?, offset)
+ }
+ op @ TypedOperand::Imm(..) => op,
+ TypedOperand::VecMember(symbol, index) => {
+ TypedOperand::Reg(self.symbol(symbol, Some(index), type_space, is_dst)?)
+ }
+ })
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: SpirvWord,
+ type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ self.symbol(args, None, type_space, is_dst)
+ }
+}
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs new file mode 100644 index 0000000..2be6297 --- /dev/null +++ b/ptx/src/pass/mod.rs @@ -0,0 +1,1677 @@ +use ptx_parser as ast;
+use rspirv::{binary::Assemble, dr};
+use std::hash::Hash;
+use std::num::NonZeroU8;
+use std::{
+ borrow::Cow,
+ cell::RefCell,
+ collections::{hash_map, HashMap, HashSet},
+ ffi::CString,
+ iter,
+ marker::PhantomData,
+ mem,
+ rc::Rc,
+};
+
+mod convert_dynamic_shared_memory_usage;
+mod convert_to_stateful_memory_access;
+mod convert_to_typed;
+mod emit_spirv;
+mod expand_arguments;
+mod extract_globals;
+mod fix_special_registers;
+mod insert_implicit_conversions;
+mod insert_mem_ssa_statements;
+mod normalize_identifiers;
+mod normalize_labels;
+mod normalize_predicates;
+
+static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
+static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
+const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__";
+
+pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
+ let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1));
+ let mut ptx_impl_imports = HashMap::new();
+ let directives = ast
+ .directives
+ .into_iter()
+ .filter_map(|directive| {
+ translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose()
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ let directives = hoist_function_globals(directives);
+ let must_link_ptx_impl = ptx_impl_imports.len() > 0;
+ let mut directives = ptx_impl_imports
+ .into_iter()
+ .map(|(_, v)| v)
+ .chain(directives.into_iter())
+ .collect::<Vec<_>>();
+ let mut builder = dr::Builder::new();
+ builder.reserve_ids(id_defs.current_id().0);
+ let call_map = MethodsCallMap::new(&directives);
+ let mut directives =
+ convert_dynamic_shared_memory_usage::run(directives, &call_map, &mut || {
+ SpirvWord(builder.id())
+ })?;
+ normalize_variable_decls(&mut directives);
+ let denorm_information = compute_denorm_information(&directives);
+ let (spirv, kernel_info, build_options) =
+ emit_spirv::run(builder, &id_defs, call_map, denorm_information, directives)?;
+ Ok(Module {
+ spirv,
+ kernel_info,
+ should_link_ptx_impl: if must_link_ptx_impl {
+ Some((ZLUDA_PTX_IMPL_INTEL, ZLUDA_PTX_IMPL_AMD))
+ } else {
+ None
+ },
+ build_options,
+ })
+}
+
+fn translate_directive<'input, 'a>(
+ id_defs: &'a mut GlobalStringIdResolver<'input>,
+ ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
+ d: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
+) -> Result<Option<Directive<'input>>, TranslateError> {
+ Ok(match d {
+ ast::Directive::Variable(linking, var) => Some(Directive::Variable(
+ linking,
+ ast::Variable {
+ align: var.align,
+ v_type: var.v_type.clone(),
+ state_space: var.state_space,
+ name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
+ array_init: var.array_init,
+ },
+ )),
+ ast::Directive::Method(linkage, f) => {
+ translate_function(id_defs, ptx_impl_imports, linkage, f)?.map(Directive::Method)
+ }
+ })
+}
+
+type ParsedFunction<'a> = ast::Function<'a, &'a str, ast::Statement<ast::ParsedOperand<&'a str>>>;
+
+fn translate_function<'input, 'a>(
+ id_defs: &'a mut GlobalStringIdResolver<'input>,
+ ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
+ linkage: ast::LinkingDirective,
+ f: ParsedFunction<'input>,
+) -> Result<Option<Function<'input>>, TranslateError> {
+ let import_as = match &f.func_directive {
+ ast::MethodDeclaration {
+ name: ast::MethodName::Func(func_name),
+ ..
+ } if *func_name == "__assertfail" || *func_name == "vprintf" => {
+ Some([ZLUDA_PTX_PREFIX, func_name].concat())
+ }
+ _ => None,
+ };
+ let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?;
+ let mut func = to_ssa(
+ ptx_impl_imports,
+ str_resolver,
+ fn_resolver,
+ fn_decl,
+ f.body,
+ f.tuning,
+ linkage,
+ )?;
+ func.import_as = import_as;
+ if func.import_as.is_some() {
+ ptx_impl_imports.insert(
+ func.import_as.as_ref().unwrap().clone(),
+ Directive::Method(func),
+ );
+ Ok(None)
+ } else {
+ Ok(Some(func))
+ }
+}
+
+fn to_ssa<'input, 'b>(
+ ptx_impl_imports: &'b mut HashMap<String, Directive<'input>>,
+ mut id_defs: FnStringIdResolver<'input, 'b>,
+ fn_defs: GlobalFnDeclResolver<'input, 'b>,
+ func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
+ f_body: Option<Vec<ast::Statement<ast::ParsedOperand<&'input str>>>>,
+ tuning: Vec<ast::TuningDirective>,
+ linkage: ast::LinkingDirective,
+) -> Result<Function<'input>, TranslateError> {
+ //deparamize_function_decl(&func_decl)?;
+ let f_body = match f_body {
+ Some(vec) => vec,
+ None => {
+ return Ok(Function {
+ func_decl: func_decl,
+ body: None,
+ globals: Vec::new(),
+ import_as: None,
+ tuning,
+ linkage,
+ })
+ }
+ };
+ let normalized_ids = normalize_identifiers::run(&mut id_defs, &fn_defs, f_body)?;
+ let mut numeric_id_defs = id_defs.finish();
+ let unadorned_statements = normalize_predicates::run(normalized_ids, &mut numeric_id_defs)?;
+ let typed_statements =
+ convert_to_typed::run(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
+ let typed_statements =
+ fix_special_registers::run(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
+ let (func_decl, typed_statements) =
+ convert_to_stateful_memory_access::run(func_decl, typed_statements, &mut numeric_id_defs)?;
+ let ssa_statements = insert_mem_ssa_statements::run(
+ typed_statements,
+ &mut numeric_id_defs,
+ &mut (*func_decl).borrow_mut(),
+ )?;
+ let mut numeric_id_defs = numeric_id_defs.finish();
+ let expanded_statements = expand_arguments::run(ssa_statements, &mut numeric_id_defs)?;
+ let expanded_statements =
+ insert_implicit_conversions::run(expanded_statements, &mut numeric_id_defs)?;
+ let mut numeric_id_defs = numeric_id_defs.unmut();
+ let labeled_statements = normalize_labels::run(expanded_statements, &mut numeric_id_defs);
+ let (f_body, globals) =
+ extract_globals::run(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
+ Ok(Function {
+ func_decl: func_decl,
+ globals: globals,
+ body: Some(f_body),
+ import_as: None,
+ tuning,
+ linkage,
+ })
+}
+
+pub struct Module {
+ pub spirv: dr::Module,
+ pub kernel_info: HashMap<String, KernelInfo>,
+ pub should_link_ptx_impl: Option<(&'static [u8], &'static [u8])>,
+ pub build_options: CString,
+}
+
+impl Module {
+ pub fn assemble(&self) -> Vec<u32> {
+ self.spirv.assemble()
+ }
+}
+
+struct GlobalStringIdResolver<'input> {
+ current_id: SpirvWord,
+ variables: HashMap<Cow<'input, str>, SpirvWord>,
+ reverse_variables: HashMap<SpirvWord, &'input str>,
+ variables_type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
+ special_registers: SpecialRegistersMap,
+ fns: HashMap<SpirvWord, FnSigMapper<'input>>,
+}
+
+impl<'input> GlobalStringIdResolver<'input> {
+ fn new(start_id: SpirvWord) -> Self {
+ Self {
+ current_id: start_id,
+ variables: HashMap::new(),
+ reverse_variables: HashMap::new(),
+ variables_type_check: HashMap::new(),
+ special_registers: SpecialRegistersMap::new(),
+ fns: HashMap::new(),
+ }
+ }
+
+ fn get_or_add_def(&mut self, id: &'input str) -> SpirvWord {
+ self.get_or_add_impl(id, None)
+ }
+
+ fn get_or_add_def_typed(
+ &mut self,
+ id: &'input str,
+ typ: ast::Type,
+ state_space: ast::StateSpace,
+ is_variable: bool,
+ ) -> SpirvWord {
+ self.get_or_add_impl(id, Some((typ, state_space, is_variable)))
+ }
+
+ fn get_or_add_impl(
+ &mut self,
+ id: &'input str,
+ typ: Option<(ast::Type, ast::StateSpace, bool)>,
+ ) -> SpirvWord {
+ let id = match self.variables.entry(Cow::Borrowed(id)) {
+ hash_map::Entry::Occupied(e) => *(e.get()),
+ hash_map::Entry::Vacant(e) => {
+ let numeric_id = self.current_id;
+ e.insert(numeric_id);
+ self.reverse_variables.insert(numeric_id, id);
+ self.current_id.0 += 1;
+ numeric_id
+ }
+ };
+ self.variables_type_check.insert(id, typ);
+ id
+ }
+
+ fn get_id(&self, id: &str) -> Result<SpirvWord, TranslateError> {
+ self.variables
+ .get(id)
+ .copied()
+ .ok_or_else(error_unknown_symbol)
+ }
+
+ fn current_id(&self) -> SpirvWord {
+ self.current_id
+ }
+
+ fn start_fn<'b>(
+ &'b mut self,
+ header: &'b ast::MethodDeclaration<'input, &'input str>,
+ ) -> Result<
+ (
+ FnStringIdResolver<'input, 'b>,
+ GlobalFnDeclResolver<'input, 'b>,
+ Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
+ ),
+ TranslateError,
+ > {
+ // In case a function decl was inserted earlier we want to use its id
+ let name_id = self.get_or_add_def(header.name());
+ let mut fn_resolver = FnStringIdResolver {
+ current_id: &mut self.current_id,
+ global_variables: &self.variables,
+ global_type_check: &self.variables_type_check,
+ special_registers: &mut self.special_registers,
+ variables: vec![HashMap::new(); 1],
+ type_check: HashMap::new(),
+ };
+ let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments);
+ let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments);
+ let name = match header.name {
+ ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
+ ast::MethodName::Func(_) => ast::MethodName::Func(name_id),
+ };
+ let fn_decl = ast::MethodDeclaration {
+ return_arguments,
+ name,
+ input_arguments,
+ shared_mem: None,
+ };
+ let new_fn_decl = if !matches!(fn_decl.name, ast::MethodName::Kernel(_)) {
+ let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl);
+ let new_fn_decl = resolver.func_decl.clone();
+ self.fns.insert(name_id, resolver);
+ new_fn_decl
+ } else {
+ Rc::new(RefCell::new(fn_decl))
+ };
+ Ok((
+ fn_resolver,
+ GlobalFnDeclResolver { fns: &self.fns },
+ new_fn_decl,
+ ))
+ }
+}
+
+fn rename_fn_params<'a, 'b>(
+ fn_resolver: &mut FnStringIdResolver<'a, 'b>,
+ args: &'b [ast::Variable<&'a str>],
+) -> Vec<ast::Variable<SpirvWord>> {
+ args.iter()
+ .map(|a| ast::Variable {
+ name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true),
+ v_type: a.v_type.clone(),
+ state_space: a.state_space,
+ align: a.align,
+ array_init: a.array_init.clone(),
+ })
+ .collect()
+}
+
+pub struct KernelInfo {
+ pub arguments_sizes: Vec<(usize, bool)>,
+ pub uses_shared_mem: bool,
+}
+
+#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
+enum PtxSpecialRegister {
+ Tid,
+ Ntid,
+ Ctaid,
+ Nctaid,
+ Clock,
+ LanemaskLt,
+}
+
+impl PtxSpecialRegister {
+ fn try_parse(s: &str) -> Option<Self> {
+ match s {
+ "%tid" => Some(Self::Tid),
+ "%ntid" => Some(Self::Ntid),
+ "%ctaid" => Some(Self::Ctaid),
+ "%nctaid" => Some(Self::Nctaid),
+ "%clock" => Some(Self::Clock),
+ "%lanemask_lt" => Some(Self::LanemaskLt),
+ _ => None,
+ }
+ }
+
+ fn get_type(self) -> ast::Type {
+ match self {
+ PtxSpecialRegister::Tid
+ | PtxSpecialRegister::Ntid
+ | PtxSpecialRegister::Ctaid
+ | PtxSpecialRegister::Nctaid => ast::Type::Vector(4, self.get_function_return_type()),
+ _ => ast::Type::Scalar(self.get_function_return_type()),
+ }
+ }
+
+ fn get_function_return_type(self) -> ast::ScalarType {
+ match self {
+ PtxSpecialRegister::Tid => ast::ScalarType::U32,
+ PtxSpecialRegister::Ntid => ast::ScalarType::U32,
+ PtxSpecialRegister::Ctaid => ast::ScalarType::U32,
+ PtxSpecialRegister::Nctaid => ast::ScalarType::U32,
+ PtxSpecialRegister::Clock => ast::ScalarType::U32,
+ PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32,
+ }
+ }
+
+ fn get_function_input_type(self) -> Option<ast::ScalarType> {
+ match self {
+ PtxSpecialRegister::Tid
+ | PtxSpecialRegister::Ntid
+ | PtxSpecialRegister::Ctaid
+ | PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8),
+ PtxSpecialRegister::Clock | PtxSpecialRegister::LanemaskLt => None,
+ }
+ }
+
+ fn get_unprefixed_function_name(self) -> &'static str {
+ match self {
+ PtxSpecialRegister::Tid => "sreg_tid",
+ PtxSpecialRegister::Ntid => "sreg_ntid",
+ PtxSpecialRegister::Ctaid => "sreg_ctaid",
+ PtxSpecialRegister::Nctaid => "sreg_nctaid",
+ PtxSpecialRegister::Clock => "sreg_clock",
+ PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt",
+ }
+ }
+}
+
+struct SpecialRegistersMap {
+ reg_to_id: HashMap<PtxSpecialRegister, SpirvWord>,
+ id_to_reg: HashMap<SpirvWord, PtxSpecialRegister>,
+}
+
+impl SpecialRegistersMap {
+ fn new() -> Self {
+ SpecialRegistersMap {
+ reg_to_id: HashMap::new(),
+ id_to_reg: HashMap::new(),
+ }
+ }
+
+ fn get(&self, id: SpirvWord) -> Option<PtxSpecialRegister> {
+ self.id_to_reg.get(&id).copied()
+ }
+
+ fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
+ match self.reg_to_id.entry(reg) {
+ hash_map::Entry::Occupied(e) => *e.get(),
+ hash_map::Entry::Vacant(e) => {
+ let numeric_id = SpirvWord(current_id.0);
+ current_id.0 += 1;
+ e.insert(numeric_id);
+ self.id_to_reg.insert(numeric_id, reg);
+ numeric_id
+ }
+ }
+ }
+}
+
+struct FnStringIdResolver<'input, 'b> {
+ current_id: &'b mut SpirvWord,
+ global_variables: &'b HashMap<Cow<'input, str>, SpirvWord>,
+ global_type_check: &'b HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
+ special_registers: &'b mut SpecialRegistersMap,
+ variables: Vec<HashMap<Cow<'input, str>, SpirvWord>>,
+ type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
+}
+
+impl<'a, 'b> FnStringIdResolver<'a, 'b> {
+ fn finish(self) -> NumericIdResolver<'b> {
+ NumericIdResolver {
+ current_id: self.current_id,
+ global_type_check: self.global_type_check,
+ type_check: self.type_check,
+ special_registers: self.special_registers,
+ }
+ }
+
+ fn start_block(&mut self) {
+ self.variables.push(HashMap::new())
+ }
+
+ fn end_block(&mut self) {
+ self.variables.pop();
+ }
+
+ fn get_id(&mut self, id: &str) -> Result<SpirvWord, TranslateError> {
+ for scope in self.variables.iter().rev() {
+ match scope.get(id) {
+ Some(id) => return Ok(*id),
+ None => continue,
+ }
+ }
+ match self.global_variables.get(id) {
+ Some(id) => Ok(*id),
+ None => {
+ let sreg = PtxSpecialRegister::try_parse(id).ok_or_else(error_unknown_symbol)?;
+ Ok(self.special_registers.get_or_add(self.current_id, sreg))
+ }
+ }
+ }
+
+ fn add_def(
+ &mut self,
+ id: &'a str,
+ typ: Option<(ast::Type, ast::StateSpace)>,
+ is_variable: bool,
+ ) -> SpirvWord {
+ let numeric_id = *self.current_id;
+ self.variables
+ .last_mut()
+ .unwrap()
+ .insert(Cow::Borrowed(id), numeric_id);
+ self.type_check.insert(
+ numeric_id,
+ typ.map(|(typ, space)| (typ, space, is_variable)),
+ );
+ self.current_id.0 += 1;
+ numeric_id
+ }
+
+ #[must_use]
+ fn add_defs(
+ &mut self,
+ base_id: &'a str,
+ count: u32,
+ typ: ast::Type,
+ state_space: ast::StateSpace,
+ is_variable: bool,
+ ) -> impl Iterator<Item = SpirvWord> {
+ let numeric_id = *self.current_id;
+ for i in 0..count {
+ self.variables.last_mut().unwrap().insert(
+ Cow::Owned(format!("{}{}", base_id, i)),
+ SpirvWord(numeric_id.0 + i),
+ );
+ self.type_check.insert(
+ SpirvWord(numeric_id.0 + i),
+ Some((typ.clone(), state_space, is_variable)),
+ );
+ }
+ self.current_id.0 += count;
+ (0..count)
+ .into_iter()
+ .map(move |i| SpirvWord(i + numeric_id.0))
+ }
+}
+
+struct NumericIdResolver<'b> {
+ current_id: &'b mut SpirvWord,
+ global_type_check: &'b HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
+ type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
+ special_registers: &'b mut SpecialRegistersMap,
+}
+
+impl<'b> NumericIdResolver<'b> {
+ fn finish(self) -> MutableNumericIdResolver<'b> {
+ MutableNumericIdResolver { base: self }
+ }
+
+ fn get_typed(
+ &self,
+ id: SpirvWord,
+ ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> {
+ match self.type_check.get(&id) {
+ Some(Some(x)) => Ok(x.clone()),
+ Some(None) => Err(TranslateError::UntypedSymbol),
+ None => match self.special_registers.get(id) {
+ Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
+ None => match self.global_type_check.get(&id) {
+ Some(Some(result)) => Ok(result.clone()),
+ Some(None) | None => Err(TranslateError::UntypedSymbol),
+ },
+ },
+ }
+ }
+
+ // This is for identifiers which will be emitted later as OpVariable
+ // They are candidates for insertion of LoadVar/StoreVar
+ fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord {
+ let new_id = *self.current_id;
+ self.type_check
+ .insert(new_id, Some((typ, state_space, true)));
+ self.current_id.0 += 1;
+ new_id
+ }
+
+ fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
+ let new_id = *self.current_id;
+ self.type_check
+ .insert(new_id, typ.map(|(t, space)| (t, space, false)));
+ self.current_id.0 += 1;
+ new_id
+ }
+}
+
+struct MutableNumericIdResolver<'b> {
+ base: NumericIdResolver<'b>,
+}
+
+impl<'b> MutableNumericIdResolver<'b> {
+ fn unmut(self) -> NumericIdResolver<'b> {
+ self.base
+ }
+
+ fn get_typed(&self, id: SpirvWord) -> Result<(ast::Type, ast::StateSpace), TranslateError> {
+ self.base.get_typed(id).map(|(t, space, _)| (t, space))
+ }
+
+ fn register_intermediate(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord {
+ self.base.register_intermediate(Some((typ, state_space)))
+ }
+}
+
+quick_error! {
+ #[derive(Debug)]
+ pub enum TranslateError {
+ UnknownSymbol {}
+ UntypedSymbol {}
+ MismatchedType {}
+ Spirv(err: rspirv::dr::Error) {
+ from()
+ display("{}", err)
+ cause(err)
+ }
+ Unreachable {}
+ Todo {}
+ }
+}
+
+#[cfg(debug_assertions)]
+fn error_unreachable() -> TranslateError {
+ unreachable!()
+}
+
+#[cfg(not(debug_assertions))]
+fn error_unreachable() -> TranslateError {
+ TranslateError::Unreachable
+}
+
+fn error_unknown_symbol() -> TranslateError {
+ panic!()
+}
+
+#[cfg(not(debug_assertions))]
+fn error_unknown_symbol() -> TranslateError {
+ TranslateError::UnknownSymbol
+}
+
+fn error_mismatched_type() -> TranslateError {
+ panic!()
+}
+
+#[cfg(not(debug_assertions))]
+fn error_mismatched_type() -> TranslateError {
+ TranslateError::MismatchedType
+}
+
+pub struct GlobalFnDeclResolver<'input, 'a> {
+ fns: &'a HashMap<SpirvWord, FnSigMapper<'input>>,
+}
+
+impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
+ fn get_fn_sig_resolver(&self, id: SpirvWord) -> Result<&FnSigMapper<'input>, TranslateError> {
+ self.fns.get(&id).ok_or_else(error_unknown_symbol)
+ }
+}
+
+struct FnSigMapper<'input> {
+ // true - stays as return argument
+ // false - is moved to input argument
+ return_param_args: Vec<bool>,
+ func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
+}
+
+impl<'input> FnSigMapper<'input> {
+ fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, SpirvWord>) -> Self {
+ let return_param_args = method
+ .return_arguments
+ .iter()
+ .map(|a| a.state_space != ast::StateSpace::Param)
+ .collect::<Vec<_>>();
+ let mut new_return_arguments = Vec::new();
+ for arg in method.return_arguments.into_iter() {
+ if arg.state_space == ast::StateSpace::Param {
+ method.input_arguments.push(arg);
+ } else {
+ new_return_arguments.push(arg);
+ }
+ }
+ method.return_arguments = new_return_arguments;
+ FnSigMapper {
+ return_param_args,
+ func_decl: Rc::new(RefCell::new(method)),
+ }
+ }
+
+ fn resolve_in_spirv_repr(
+ &self,
+ data: ast::CallDetails,
+ arguments: ast::CallArgs<ast::ParsedOperand<SpirvWord>>,
+ ) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
+ let func_decl = (*self.func_decl).borrow();
+ let mut data_return = Vec::new();
+ let mut arguments_return = Vec::new();
+ let mut data_input = data.input_arguments;
+ let mut arguments_input = arguments.input_arguments;
+ let mut func_decl_return_iter = func_decl.return_arguments.iter();
+ let mut func_decl_input_iter = func_decl.input_arguments[arguments_input.len()..].iter();
+ for (idx, id) in arguments.return_arguments.iter().enumerate() {
+ let stays_as_return = match self.return_param_args.get(idx) {
+ Some(x) => *x,
+ None => return Err(TranslateError::MismatchedType),
+ };
+ if stays_as_return {
+ if let Some(var) = func_decl_return_iter.next() {
+ data_return.push((var.v_type.clone(), var.state_space));
+ arguments_return.push(*id);
+ } else {
+ return Err(TranslateError::MismatchedType);
+ }
+ } else {
+ if let Some(var) = func_decl_input_iter.next() {
+ data_input.push((var.v_type.clone(), var.state_space));
+ arguments_input.push(ast::ParsedOperand::Reg(*id));
+ } else {
+ return Err(TranslateError::MismatchedType);
+ }
+ }
+ }
+ if arguments_return.len() != func_decl.return_arguments.len()
+ || arguments_input.len() != func_decl.input_arguments.len()
+ {
+ return Err(TranslateError::MismatchedType);
+ }
+ let data = ast::CallDetails {
+ uniform: data.uniform,
+ return_arguments: data_return,
+ input_arguments: data_input,
+ };
+ let arguments = ast::CallArgs {
+ func: arguments.func,
+ return_arguments: arguments_return,
+ input_arguments: arguments_input,
+ };
+ Ok(ast::Instruction::Call { data, arguments })
+ }
+}
+
+enum Statement<I, P: ast::Operand> {
+ Label(SpirvWord),
+ Variable(ast::Variable<P::Ident>),
+ Instruction(I),
+ // SPIR-V compatible replacement for PTX predicates
+ Conditional(BrachCondition),
+ LoadVar(LoadVarDetails),
+ StoreVar(StoreVarDetails),
+ Conversion(ImplicitConversion),
+ Constant(ConstantDefinition),
+ RetValue(ast::RetData, SpirvWord),
+ PtrAccess(PtrAccess<P>),
+ RepackVector(RepackVectorDetails),
+ FunctionPointer(FunctionPointerDetails),
+}
+
+impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
+ fn visit_map<To: ast::Operand<Ident = SpirvWord>, Err>(
+ self,
+ visitor: &mut impl ast::VisitorMap<T, To, Err>,
+ ) -> std::result::Result<Statement<ast::Instruction<To>, To>, Err> {
+ Ok(match self {
+ Statement::Instruction(i) => {
+ return ast::visit_map(i, visitor).map(Statement::Instruction)
+ }
+ Statement::Label(label) => {
+ Statement::Label(visitor.visit_ident(label, None, false, false)?)
+ }
+ Statement::Variable(var) => {
+ let name = visitor.visit_ident(
+ var.name,
+ Some((&var.v_type, var.state_space)),
+ true,
+ false,
+ )?;
+ Statement::Variable(ast::Variable {
+ align: var.align,
+ v_type: var.v_type,
+ state_space: var.state_space,
+ name,
+ array_init: var.array_init,
+ })
+ }
+ Statement::Conditional(conditional) => {
+ let predicate = visitor.visit_ident(
+ conditional.predicate,
+ Some((&ast::ScalarType::Pred.into(), ast::StateSpace::Reg)),
+ false,
+ false,
+ )?;
+ let if_true = visitor.visit_ident(conditional.if_true, None, false, false)?;
+ let if_false = visitor.visit_ident(conditional.if_false, None, false, false)?;
+ Statement::Conditional(BrachCondition {
+ predicate,
+ if_true,
+ if_false,
+ })
+ }
+ Statement::LoadVar(LoadVarDetails {
+ arg,
+ typ,
+ member_index,
+ }) => {
+ let dst = visitor.visit_ident(
+ arg.dst,
+ Some((&typ, ast::StateSpace::Reg)),
+ true,
+ false,
+ )?;
+ let src = visitor.visit_ident(
+ arg.src,
+ Some((&typ, ast::StateSpace::Local)),
+ false,
+ false,
+ )?;
+ Statement::LoadVar(LoadVarDetails {
+ arg: ast::LdArgs { dst, src },
+ typ,
+ member_index,
+ })
+ }
+ Statement::StoreVar(StoreVarDetails {
+ arg,
+ typ,
+ member_index,
+ }) => {
+ let src1 = visitor.visit_ident(
+ arg.src1,
+ Some((&typ, ast::StateSpace::Local)),
+ false,
+ false,
+ )?;
+ let src2 = visitor.visit_ident(
+ arg.src2,
+ Some((&typ, ast::StateSpace::Reg)),
+ false,
+ false,
+ )?;
+ Statement::StoreVar(StoreVarDetails {
+ arg: ast::StArgs { src1, src2 },
+ typ,
+ member_index,
+ })
+ }
+ Statement::Conversion(ImplicitConversion {
+ src,
+ dst,
+ from_type,
+ to_type,
+ from_space,
+ to_space,
+ kind,
+ }) => {
+ let dst = visitor.visit_ident(
+ dst,
+ Some((&to_type, ast::StateSpace::Reg)),
+ true,
+ false,
+ )?;
+ let src = visitor.visit_ident(
+ src,
+ Some((&from_type, ast::StateSpace::Reg)),
+ false,
+ false,
+ )?;
+ Statement::Conversion(ImplicitConversion {
+ src,
+ dst,
+ from_type,
+ to_type,
+ from_space,
+ to_space,
+ kind,
+ })
+ }
+ Statement::Constant(ConstantDefinition { dst, typ, value }) => {
+ let dst = visitor.visit_ident(
+ dst,
+ Some((&typ.into(), ast::StateSpace::Reg)),
+ true,
+ false,
+ )?;
+ Statement::Constant(ConstantDefinition { dst, typ, value })
+ }
+ Statement::RetValue(data, value) => {
+ // TODO:
+ // We should report type here
+ let value = visitor.visit_ident(value, None, false, false)?;
+ Statement::RetValue(data, value)
+ }
+ Statement::PtrAccess(PtrAccess {
+ underlying_type,
+ state_space,
+ dst,
+ ptr_src,
+ offset_src,
+ }) => {
+ let dst =
+ visitor.visit_ident(dst, Some((&underlying_type, state_space)), true, false)?;
+ let ptr_src = visitor.visit_ident(
+ ptr_src,
+ Some((&underlying_type, state_space)),
+ false,
+ false,
+ )?;
+ let offset_src = visitor.visit(
+ offset_src,
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
+ )),
+ false,
+ false,
+ )?;
+ Statement::PtrAccess(PtrAccess {
+ underlying_type,
+ state_space,
+ dst,
+ ptr_src,
+ offset_src,
+ })
+ }
+ Statement::RepackVector(RepackVectorDetails {
+ is_extract,
+ typ,
+ packed,
+ unpacked,
+ relaxed_type_check,
+ }) => {
+ let (packed, unpacked) = if is_extract {
+ let unpacked = unpacked
+ .into_iter()
+ .map(|ident| {
+ visitor.visit_ident(
+ ident,
+ Some((&typ.into(), ast::StateSpace::Reg)),
+ true,
+ relaxed_type_check,
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ let packed = visitor.visit_ident(
+ packed,
+ Some((
+ &ast::Type::Vector(unpacked.len() as u8, typ),
+ ast::StateSpace::Reg,
+ )),
+ false,
+ false,
+ )?;
+ (packed, unpacked)
+ } else {
+ let packed = visitor.visit_ident(
+ packed,
+ Some((
+ &ast::Type::Vector(unpacked.len() as u8, typ),
+ ast::StateSpace::Reg,
+ )),
+ true,
+ false,
+ )?;
+ let unpacked = unpacked
+ .into_iter()
+ .map(|ident| {
+ visitor.visit_ident(
+ ident,
+ Some((&typ.into(), ast::StateSpace::Reg)),
+ false,
+ relaxed_type_check,
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ (packed, unpacked)
+ };
+ Statement::RepackVector(RepackVectorDetails {
+ is_extract,
+ typ,
+ packed,
+ unpacked,
+ relaxed_type_check,
+ })
+ }
+ Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => {
+ let dst = visitor.visit_ident(
+ dst,
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::U64),
+ ast::StateSpace::Reg,
+ )),
+ true,
+ false,
+ )?;
+ let src = visitor.visit_ident(src, None, false, false)?;
+ Statement::FunctionPointer(FunctionPointerDetails { dst, src })
+ }
+ })
+ }
+}
+
+struct BrachCondition {
+ predicate: SpirvWord,
+ if_true: SpirvWord,
+ if_false: SpirvWord,
+}
+struct LoadVarDetails {
+ arg: ast::LdArgs<SpirvWord>,
+ typ: ast::Type,
+ // (index, vector_width)
+ // HACK ALERT
+ // For some reason IGC explodes when you try to load from builtin vectors
+ // using OpInBoundsAccessChain, the one true way to do it is to
+ // OpLoad+OpCompositeExtract
+ member_index: Option<(u8, Option<u8>)>,
+}
+
+struct StoreVarDetails {
+ arg: ast::StArgs<SpirvWord>,
+ typ: ast::Type,
+ member_index: Option<u8>,
+}
+
+#[derive(Clone)]
+struct ImplicitConversion {
+ src: SpirvWord,
+ dst: SpirvWord,
+ from_type: ast::Type,
+ to_type: ast::Type,
+ from_space: ast::StateSpace,
+ to_space: ast::StateSpace,
+ kind: ConversionKind,
+}
+
+#[derive(PartialEq, Clone)]
+enum ConversionKind {
+ Default,
+ // zero-extend/chop/bitcast depending on types
+ SignExtend,
+ BitToPtr,
+ PtrToPtr,
+ AddressOf,
+}
+
+struct ConstantDefinition {
+ pub dst: SpirvWord,
+ pub typ: ast::ScalarType,
+ pub value: ast::ImmediateValue,
+}
+
+pub struct PtrAccess<T> {
+ underlying_type: ast::Type,
+ state_space: ast::StateSpace,
+ dst: SpirvWord,
+ ptr_src: SpirvWord,
+ offset_src: T,
+}
+
+struct RepackVectorDetails {
+ is_extract: bool,
+ typ: ast::ScalarType,
+ packed: SpirvWord,
+ unpacked: Vec<SpirvWord>,
+ relaxed_type_check: bool,
+}
+
+struct FunctionPointerDetails {
+ dst: SpirvWord,
+ src: SpirvWord,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+struct SpirvWord(spirv::Word);
+
+impl From<spirv::Word> for SpirvWord {
+ fn from(value: spirv::Word) -> Self {
+ Self(value)
+ }
+}
+impl From<SpirvWord> for spirv::Word {
+ fn from(value: SpirvWord) -> Self {
+ value.0
+ }
+}
+
+impl ast::Operand for SpirvWord {
+ type Ident = Self;
+
+ fn from_ident(ident: Self::Ident) -> Self {
+ ident
+ }
+}
+
+fn pred_map_variable<U, T, F: FnMut(T) -> Result<U, TranslateError>>(
+ this: ast::PredAt<T>,
+ f: &mut F,
+) -> Result<ast::PredAt<U>, TranslateError> {
+ let new_label = f(this.label)?;
+ Ok(ast::PredAt {
+ not: this.not,
+ label: new_label,
+ })
+}
+
+pub(crate) enum Directive<'input> {
+ Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
+ Method(Function<'input>),
+}
+
+pub(crate) struct Function<'input> {
+ pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
+ pub globals: Vec<ast::Variable<SpirvWord>>,
+ pub body: Option<Vec<ExpandedStatement>>,
+ import_as: Option<String>,
+ tuning: Vec<ast::TuningDirective>,
+ linkage: ast::LinkingDirective,
+}
+
+type ExpandedStatement = Statement<ast::Instruction<SpirvWord>, SpirvWord>;
+
+type NormalizedStatement = Statement<
+ (
+ Option<ast::PredAt<SpirvWord>>,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ),
+ ast::ParsedOperand<SpirvWord>,
+>;
+
+type UnconditionalStatement =
+ Statement<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;
+
+type TypedStatement = Statement<ast::Instruction<TypedOperand>, TypedOperand>;
+
+#[derive(Copy, Clone)]
+enum TypedOperand {
+ Reg(SpirvWord),
+ RegOffset(SpirvWord, i32),
+ Imm(ast::ImmediateValue),
+ VecMember(SpirvWord, u8),
+}
+
+impl TypedOperand {
+ fn map<Err>(
+ self,
+ fn_: impl FnOnce(SpirvWord, Option<u8>) -> Result<SpirvWord, Err>,
+ ) -> Result<Self, Err> {
+ Ok(match self {
+ TypedOperand::Reg(reg) => TypedOperand::Reg(fn_(reg, None)?),
+ TypedOperand::RegOffset(reg, off) => TypedOperand::RegOffset(fn_(reg, None)?, off),
+ TypedOperand::Imm(imm) => TypedOperand::Imm(imm),
+ TypedOperand::VecMember(reg, idx) => TypedOperand::VecMember(fn_(reg, Some(idx))?, idx),
+ })
+ }
+
+ fn underlying_register(&self) -> Option<SpirvWord> {
+ match self {
+ Self::Reg(r) | Self::RegOffset(r, _) | Self::VecMember(r, _) => Some(*r),
+ Self::Imm(_) => None,
+ }
+ }
+
+ fn unwrap_reg(&self) -> Result<SpirvWord, TranslateError> {
+ match self {
+ TypedOperand::Reg(reg) => Ok(*reg),
+ _ => Err(error_unreachable()),
+ }
+ }
+}
+
+impl ast::Operand for TypedOperand {
+ type Ident = SpirvWord;
+
+ fn from_ident(ident: Self::Ident) -> Self {
+ TypedOperand::Reg(ident)
+ }
+}
+
+impl<Fn> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
+ for FnVisitor<TypedOperand, TypedOperand, TranslateError, Fn>
+where
+ Fn: FnMut(
+ TypedOperand,
+ Option<(&ast::Type, ast::StateSpace)>,
+ bool,
+ bool,
+ ) -> Result<TypedOperand, TranslateError>,
+{
+ fn visit(
+ &mut self,
+ args: TypedOperand,
+ type_space: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<TypedOperand, TranslateError> {
+ (self.fn_)(args, type_space, is_dst, relaxed_type_check)
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: SpirvWord,
+ type_space: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ match (self.fn_)(
+ TypedOperand::Reg(args),
+ type_space,
+ is_dst,
+ relaxed_type_check,
+ )? {
+ TypedOperand::Reg(reg) => Ok(reg),
+ _ => Err(TranslateError::Unreachable),
+ }
+ }
+}
+
+struct FnVisitor<
+ T,
+ U,
+ Err,
+ Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result<U, Err>,
+> {
+ fn_: Fn,
+ _marker: PhantomData<fn(T) -> Result<U, Err>>,
+}
+
+impl<
+ T,
+ U,
+ Err,
+ Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result<U, Err>,
+ > FnVisitor<T, U, Err, Fn>
+{
+ fn new(fn_: Fn) -> Self {
+ Self {
+ fn_,
+ _marker: PhantomData,
+ }
+ }
+}
+
+fn space_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
+ this == other
+ || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
+ || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
+}
+
+fn register_external_fn_call<'a>(
+ id_defs: &mut NumericIdResolver,
+ ptx_impl_imports: &mut HashMap<String, Directive>,
+ name: String,
+ return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
+ input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
+) -> Result<SpirvWord, TranslateError> {
+ match ptx_impl_imports.entry(name) {
+ hash_map::Entry::Vacant(entry) => {
+ let fn_id = id_defs.register_intermediate(None);
+ let return_arguments = fn_arguments_to_variables(id_defs, return_arguments);
+ let input_arguments = fn_arguments_to_variables(id_defs, input_arguments);
+ let func_decl = ast::MethodDeclaration::<SpirvWord> {
+ return_arguments,
+ name: ast::MethodName::Func(fn_id),
+ input_arguments,
+ shared_mem: None,
+ };
+ let func = Function {
+ func_decl: Rc::new(RefCell::new(func_decl)),
+ globals: Vec::new(),
+ body: None,
+ import_as: Some(entry.key().clone()),
+ tuning: Vec::new(),
+ linkage: ast::LinkingDirective::EXTERN,
+ };
+ entry.insert(Directive::Method(func));
+ Ok(fn_id)
+ }
+ hash_map::Entry::Occupied(entry) => match entry.get() {
+ Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
+ ast::MethodName::Func(fn_id) => Ok(fn_id),
+ ast::MethodName::Kernel(_) => Err(error_unreachable()),
+ },
+ _ => Err(error_unreachable()),
+ },
+ }
+}
+
+fn fn_arguments_to_variables<'a>(
+ id_defs: &mut NumericIdResolver,
+ args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
+) -> Vec<ast::Variable<SpirvWord>> {
+ args.map(|(typ, space)| ast::Variable {
+ align: None,
+ v_type: typ.clone(),
+ state_space: space,
+ name: id_defs.register_intermediate(None),
+ array_init: Vec::new(),
+ })
+ .collect::<Vec<_>>()
+}
+
+fn hoist_function_globals(directives: Vec<Directive>) -> Vec<Directive> {
+ let mut result = Vec::with_capacity(directives.len());
+ for directive in directives {
+ match directive {
+ Directive::Method(method) => {
+ for variable in method.globals {
+ result.push(Directive::Variable(ast::LinkingDirective::NONE, variable));
+ }
+ result.push(Directive::Method(Function {
+ globals: Vec::new(),
+ ..method
+ }))
+ }
+ _ => result.push(directive),
+ }
+ }
+ result
+}
+
+struct MethodsCallMap<'input> {
+ map: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
+}
+
+impl<'input> MethodsCallMap<'input> {
+ fn new(module: &[Directive<'input>]) -> Self {
+ let mut directly_called_by = HashMap::new();
+ for directive in module {
+ match directive {
+ Directive::Method(Function {
+ func_decl,
+ body: Some(statements),
+ ..
+ }) => {
+ let call_key: ast::MethodName<_> = (**func_decl).borrow().name;
+ if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
+ entry.insert(Vec::new());
+ }
+ for statement in statements {
+ match statement {
+ Statement::Instruction(ast::Instruction::Call { data, arguments }) => {
+ multi_hash_map_append(
+ &mut directly_called_by,
+ call_key,
+ arguments.func,
+ );
+ }
+ _ => {}
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ let mut result = HashMap::new();
+ for (&method_key, children) in directly_called_by.iter() {
+ let mut visited = HashSet::new();
+ for child in children {
+ Self::add_call_map_single(&directly_called_by, &mut visited, *child);
+ }
+ result.insert(method_key, visited);
+ }
+ MethodsCallMap { map: result }
+ }
+
+ fn add_call_map_single(
+ directly_called_by: &HashMap<ast::MethodName<'input, SpirvWord>, Vec<SpirvWord>>,
+ visited: &mut HashSet<SpirvWord>,
+ current: SpirvWord,
+ ) {
+ if !visited.insert(current) {
+ return;
+ }
+ if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) {
+ for child in children {
+ Self::add_call_map_single(directly_called_by, visited, *child);
+ }
+ }
+ }
+
+ fn get_kernel_children(&self, name: &'input str) -> impl Iterator<Item = &SpirvWord> {
+ self.map
+ .get(&ast::MethodName::Kernel(name))
+ .into_iter()
+ .flatten()
+ }
+
+ fn kernels(&self) -> impl Iterator<Item = (&'input str, &HashSet<SpirvWord>)> {
+ self.map
+ .iter()
+ .filter_map(|(method, children)| match method {
+ ast::MethodName::Kernel(kernel) => Some((*kernel, children)),
+ ast::MethodName::Func(..) => None,
+ })
+ }
+
+ fn methods(
+ &self,
+ ) -> impl Iterator<Item = (ast::MethodName<'input, SpirvWord>, &HashSet<SpirvWord>)> {
+ self.map
+ .iter()
+ .map(|(method, children)| (*method, children))
+ }
+
+ fn visit_callees(&self, method: ast::MethodName<'input, SpirvWord>, f: impl FnMut(SpirvWord)) {
+ self.map
+ .get(&method)
+ .into_iter()
+ .flatten()
+ .copied()
+ .for_each(f);
+ }
+}
+
+fn multi_hash_map_append<
+ K: Eq + std::hash::Hash,
+ V,
+ Collection: std::iter::Extend<V> + std::default::Default,
+>(
+ m: &mut HashMap<K, Collection>,
+ key: K,
+ value: V,
+) {
+ match m.entry(key) {
+ hash_map::Entry::Occupied(mut entry) => {
+ entry.get_mut().extend(iter::once(value));
+ }
+ hash_map::Entry::Vacant(entry) => {
+ entry.insert(Default::default()).extend(iter::once(value));
+ }
+ }
+}
+
+fn normalize_variable_decls(directives: &mut Vec<Directive>) {
+ for directive in directives {
+ match directive {
+ Directive::Method(Function {
+ body: Some(func), ..
+ }) => {
+ func[1..].sort_by_key(|s| match s {
+ Statement::Variable(_) => 0,
+ _ => 1,
+ });
+ }
+ _ => (),
+ }
+ }
+}
+
+// HACK ALERT!
+// This function is a "good enough" heuristic of whetever to mark f16/f32 operations
+// in the kernel as flushing denorms to zero or preserving them
+// PTX support per-instruction ftz information. Unfortunately SPIR-V has no
+// such capability, so instead we guesstimate which use is more common in the kernel
+// and emit suitable execution mode
+fn compute_denorm_information<'input>(
+ module: &[Directive<'input>],
+) -> HashMap<ast::MethodName<'input, SpirvWord>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
+ let mut denorm_methods = HashMap::new();
+ for directive in module {
+ match directive {
+ Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {}
+ Directive::Method(Function {
+ func_decl,
+ body: Some(statements),
+ ..
+ }) => {
+ let mut flush_counter = DenormCountMap::new();
+ let method_key = (**func_decl).borrow().name;
+ for statement in statements {
+ match statement {
+ Statement::Instruction(inst) => {
+ if let Some((flush, width)) = flush_to_zero(inst) {
+ denorm_count_map_update(&mut flush_counter, width, flush);
+ }
+ }
+ Statement::LoadVar(..) => {}
+ Statement::StoreVar(..) => {}
+ Statement::Conditional(_) => {}
+ Statement::Conversion(_) => {}
+ Statement::Constant(_) => {}
+ Statement::RetValue(_, _) => {}
+ Statement::Label(_) => {}
+ Statement::Variable(_) => {}
+ Statement::PtrAccess { .. } => {}
+ Statement::RepackVector(_) => {}
+ Statement::FunctionPointer(_) => {}
+ }
+ }
+ denorm_methods.insert(method_key, flush_counter);
+ }
+ }
+ }
+ denorm_methods
+ .into_iter()
+ .map(|(name, v)| {
+ let width_to_denorm = v
+ .into_iter()
+ .map(|(k, flush_over_preserve)| {
+ let mode = if flush_over_preserve > 0 {
+ spirv::FPDenormMode::FlushToZero
+ } else {
+ spirv::FPDenormMode::Preserve
+ };
+ (k, (mode, flush_over_preserve))
+ })
+ .collect();
+ (name, width_to_denorm)
+ })
+ .collect()
+}
+
+fn flush_to_zero(this: &ast::Instruction<SpirvWord>) -> Option<(bool, u8)> {
+ match this {
+ ast::Instruction::Ld { .. } => None,
+ ast::Instruction::St { .. } => None,
+ ast::Instruction::Mov { .. } => None,
+ ast::Instruction::Not { .. } => None,
+ ast::Instruction::Bra { .. } => None,
+ ast::Instruction::Shl { .. } => None,
+ ast::Instruction::Shr { .. } => None,
+ ast::Instruction::Ret { .. } => None,
+ ast::Instruction::Call { .. } => None,
+ ast::Instruction::Or { .. } => None,
+ ast::Instruction::And { .. } => None,
+ ast::Instruction::Cvta { .. } => None,
+ ast::Instruction::Selp { .. } => None,
+ ast::Instruction::Bar { .. } => None,
+ ast::Instruction::Atom { .. } => None,
+ ast::Instruction::AtomCas { .. } => None,
+ ast::Instruction::Sub {
+ data: ast::ArithDetails::Integer(_),
+ ..
+ } => None,
+ ast::Instruction::Add {
+ data: ast::ArithDetails::Integer(_),
+ ..
+ } => None,
+ ast::Instruction::Mul {
+ data: ast::MulDetails::Integer { .. },
+ ..
+ } => None,
+ ast::Instruction::Mad {
+ data: ast::MadDetails::Integer { .. },
+ ..
+ } => None,
+ ast::Instruction::Min {
+ data: ast::MinMaxDetails::Signed(_),
+ ..
+ } => None,
+ ast::Instruction::Min {
+ data: ast::MinMaxDetails::Unsigned(_),
+ ..
+ } => None,
+ ast::Instruction::Max {
+ data: ast::MinMaxDetails::Signed(_),
+ ..
+ } => None,
+ ast::Instruction::Max {
+ data: ast::MinMaxDetails::Unsigned(_),
+ ..
+ } => None,
+ ast::Instruction::Cvt {
+ data:
+ ast::CvtDetails {
+ mode:
+ ast::CvtMode::ZeroExtend
+ | ast::CvtMode::SignExtend
+ | ast::CvtMode::Truncate
+ | ast::CvtMode::Bitcast
+ | ast::CvtMode::SaturateUnsignedToSigned
+ | ast::CvtMode::SaturateSignedToUnsigned
+ | ast::CvtMode::FPFromSigned(_)
+ | ast::CvtMode::FPFromUnsigned(_),
+ ..
+ },
+ ..
+ } => None,
+ ast::Instruction::Div {
+ data: ast::DivDetails::Unsigned(_),
+ ..
+ } => None,
+ ast::Instruction::Div {
+ data: ast::DivDetails::Signed(_),
+ ..
+ } => None,
+ ast::Instruction::Clz { .. } => None,
+ ast::Instruction::Brev { .. } => None,
+ ast::Instruction::Popc { .. } => None,
+ ast::Instruction::Xor { .. } => None,
+ ast::Instruction::Bfe { .. } => None,
+ ast::Instruction::Bfi { .. } => None,
+ ast::Instruction::Rem { .. } => None,
+ ast::Instruction::Prmt { .. } => None,
+ ast::Instruction::Activemask { .. } => None,
+ ast::Instruction::Membar { .. } => None,
+ ast::Instruction::Sub {
+ data: ast::ArithDetails::Float(float_control),
+ ..
+ }
+ | ast::Instruction::Add {
+ data: ast::ArithDetails::Float(float_control),
+ ..
+ }
+ | ast::Instruction::Mul {
+ data: ast::MulDetails::Float(float_control),
+ ..
+ }
+ | ast::Instruction::Mad {
+ data: ast::MadDetails::Float(float_control),
+ ..
+ } => float_control
+ .flush_to_zero
+ .map(|ftz| (ftz, float_control.type_.size_of())),
+ ast::Instruction::Fma { data, .. } => {
+ data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
+ }
+ ast::Instruction::Setp { data, .. } => {
+ data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
+ }
+ ast::Instruction::SetpBool { data, .. } => data
+ .base
+ .flush_to_zero
+ .map(|ftz| (ftz, data.base.type_.size_of())),
+ ast::Instruction::Abs { data, .. }
+ | ast::Instruction::Rsqrt { data, .. }
+ | ast::Instruction::Neg { data, .. }
+ | ast::Instruction::Ex2 { data, .. } => {
+ data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
+ }
+ ast::Instruction::Min {
+ data: ast::MinMaxDetails::Float(float_control),
+ ..
+ }
+ | ast::Instruction::Max {
+ data: ast::MinMaxDetails::Float(float_control),
+ ..
+ } => float_control
+ .flush_to_zero
+ .map(|ftz| (ftz, ast::ScalarType::from(float_control.type_).size_of())),
+ ast::Instruction::Sqrt { data, .. } | ast::Instruction::Rcp { data, .. } => {
+ data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
+ }
+ // Modifier .ftz can only be specified when either .dtype or .atype
+ // is .f32 and applies only to single precision (.f32) inputs and results.
+ ast::Instruction::Cvt {
+ data:
+ ast::CvtDetails {
+ mode:
+ ast::CvtMode::FPExtend { flush_to_zero }
+ | ast::CvtMode::FPTruncate { flush_to_zero, .. }
+ | ast::CvtMode::FPRound { flush_to_zero, .. }
+ | ast::CvtMode::SignedFromFP { flush_to_zero, .. }
+ | ast::CvtMode::UnsignedFromFP { flush_to_zero, .. },
+ ..
+ },
+ ..
+ } => flush_to_zero.map(|ftz| (ftz, 4)),
+ ast::Instruction::Div {
+ data:
+ ast::DivDetails::Float(ast::DivFloatDetails {
+ type_,
+ flush_to_zero,
+ ..
+ }),
+ ..
+ } => flush_to_zero.map(|ftz| (ftz, type_.size_of())),
+ ast::Instruction::Sin { data, .. }
+ | ast::Instruction::Cos { data, .. }
+ | ast::Instruction::Lg2 { data, .. } => {
+ Some((data.flush_to_zero, mem::size_of::<f32>() as u8))
+ }
+ ptx_parser::Instruction::PrmtSlow { .. } => None,
+ ptx_parser::Instruction::Trap {} => None,
+ }
+}
+
+type DenormCountMap<T> = HashMap<T, isize>;
+
+fn denorm_count_map_update<T: Eq + Hash>(map: &mut DenormCountMap<T>, key: T, value: bool) {
+ let num_value = if value { 1 } else { -1 };
+ denorm_count_map_update_impl(map, key, num_value);
+}
+
+fn denorm_count_map_update_impl<T: Eq + Hash>(
+ map: &mut DenormCountMap<T>,
+ key: T,
+ num_value: isize,
+) {
+ match map.entry(key) {
+ hash_map::Entry::Occupied(mut counter) => {
+ *(counter.get_mut()) += num_value;
+ }
+ hash_map::Entry::Vacant(entry) => {
+ entry.insert(num_value);
+ }
+ }
+}
diff --git a/ptx/src/pass/normalize_identifiers.rs b/ptx/src/pass/normalize_identifiers.rs new file mode 100644 index 0000000..b598345 --- /dev/null +++ b/ptx/src/pass/normalize_identifiers.rs @@ -0,0 +1,80 @@ +use super::*;
+use ptx_parser as ast;
+
+pub(crate) fn run<'input, 'b>(
+ id_defs: &mut FnStringIdResolver<'input, 'b>,
+ fn_defs: &GlobalFnDeclResolver<'input, 'b>,
+ func: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
+) -> Result<Vec<NormalizedStatement>, TranslateError> {
+ for s in func.iter() {
+ match s {
+ ast::Statement::Label(id) => {
+ id_defs.add_def(*id, None, false);
+ }
+ _ => (),
+ }
+ }
+ let mut result = Vec::new();
+ for s in func {
+ expand_map_variables(id_defs, fn_defs, &mut result, s)?;
+ }
+ Ok(result)
+}
+
+fn expand_map_variables<'a, 'b>(
+ id_defs: &mut FnStringIdResolver<'a, 'b>,
+ fn_defs: &GlobalFnDeclResolver<'a, 'b>,
+ result: &mut Vec<NormalizedStatement>,
+ s: ast::Statement<ast::ParsedOperand<&'a str>>,
+) -> Result<(), TranslateError> {
+ match s {
+ ast::Statement::Block(block) => {
+ id_defs.start_block();
+ for s in block {
+ expand_map_variables(id_defs, fn_defs, result, s)?;
+ }
+ id_defs.end_block();
+ }
+ ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)),
+ ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
+ p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id)))
+ .transpose()?,
+ ast::visit_map(i, &mut |id,
+ _: Option<(&ast::Type, ast::StateSpace)>,
+ _: bool,
+ _: bool| {
+ id_defs.get_id(id)
+ })?,
+ ))),
+ ast::Statement::Variable(var) => {
+ let var_type = var.var.v_type.clone();
+ match var.count {
+ Some(count) => {
+ for new_id in
+ id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true)
+ {
+ result.push(Statement::Variable(ast::Variable {
+ align: var.var.align,
+ v_type: var.var.v_type.clone(),
+ state_space: var.var.state_space,
+ name: new_id,
+ array_init: var.var.array_init.clone(),
+ }))
+ }
+ }
+ None => {
+ let new_id =
+ id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true);
+ result.push(Statement::Variable(ast::Variable {
+ align: var.var.align,
+ v_type: var.var.v_type.clone(),
+ state_space: var.var.state_space,
+ name: new_id,
+ array_init: var.var.array_init,
+ }));
+ }
+ }
+ }
+ };
+ Ok(())
+}
diff --git a/ptx/src/pass/normalize_labels.rs b/ptx/src/pass/normalize_labels.rs new file mode 100644 index 0000000..097d87c --- /dev/null +++ b/ptx/src/pass/normalize_labels.rs @@ -0,0 +1,48 @@ +use std::{collections::HashSet, iter};
+
+use super::*;
+
+pub(super) fn run(
+ func: Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
+) -> Vec<ExpandedStatement> {
+ let mut labels_in_use = HashSet::new();
+ for s in func.iter() {
+ match s {
+ Statement::Instruction(i) => {
+ if let Some(target) = jump_target(i) {
+ labels_in_use.insert(target);
+ }
+ }
+ Statement::Conditional(cond) => {
+ labels_in_use.insert(cond.if_true);
+ labels_in_use.insert(cond.if_false);
+ }
+ Statement::Variable(..)
+ | Statement::LoadVar(..)
+ | Statement::StoreVar(..)
+ | Statement::RetValue(..)
+ | Statement::Conversion(..)
+ | Statement::Constant(..)
+ | Statement::Label(..)
+ | Statement::PtrAccess { .. }
+ | Statement::RepackVector(..)
+ | Statement::FunctionPointer(..) => {}
+ }
+ }
+ iter::once(Statement::Label(id_def.register_intermediate(None)))
+ .chain(func.into_iter().filter(|s| match s {
+ Statement::Label(i) => labels_in_use.contains(i),
+ _ => true,
+ }))
+ .collect::<Vec<_>>()
+}
+
+fn jump_target<T: ast::Operand<Ident = SpirvWord>>(
+ this: &ast::Instruction<T>,
+) -> Option<SpirvWord> {
+ match this {
+ ast::Instruction::Bra { arguments } => Some(arguments.src),
+ _ => None,
+ }
+}
diff --git a/ptx/src/pass/normalize_predicates.rs b/ptx/src/pass/normalize_predicates.rs new file mode 100644 index 0000000..c971cfa --- /dev/null +++ b/ptx/src/pass/normalize_predicates.rs @@ -0,0 +1,44 @@ +use super::*;
+use ptx_parser as ast;
+
+pub(crate) fn run(
+ func: Vec<NormalizedStatement>,
+ id_def: &mut NumericIdResolver,
+) -> Result<Vec<UnconditionalStatement>, TranslateError> {
+ let mut result = Vec::with_capacity(func.len());
+ for s in func {
+ match s {
+ Statement::Label(id) => result.push(Statement::Label(id)),
+ Statement::Instruction((pred, inst)) => {
+ if let Some(pred) = pred {
+ let if_true = id_def.register_intermediate(None);
+ let if_false = id_def.register_intermediate(None);
+ let folded_bra = match &inst {
+ ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
+ _ => None,
+ };
+ let mut branch = BrachCondition {
+ predicate: pred.label,
+ if_true: folded_bra.unwrap_or(if_true),
+ if_false,
+ };
+ if pred.not {
+ std::mem::swap(&mut branch.if_true, &mut branch.if_false);
+ }
+ result.push(Statement::Conditional(branch));
+ if folded_bra.is_none() {
+ result.push(Statement::Label(if_true));
+ result.push(Statement::Instruction(inst));
+ }
+ result.push(Statement::Label(if_false));
+ } else {
+ result.push(Statement::Instruction(inst));
+ }
+ }
+ Statement::Variable(var) => result.push(Statement::Variable(var)),
+ // Blocks are flattened when resolving ids
+ _ => return Err(error_unreachable()),
+ }
+ }
+ Ok(result)
+}
diff --git a/ptx/src/test/spirv_run/clz.spvtxt b/ptx/src/test/spirv_run/clz.spvtxt index 9a7f254..1feb5a0 100644 --- a/ptx/src/test/spirv_run/clz.spvtxt +++ b/ptx/src/test/spirv_run/clz.spvtxt @@ -7,20 +7,24 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %22 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "clz" + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong + %25 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %24 + %1 = OpFunction %void None %25 %7 = OpFunctionParameter %ulong %8 = OpFunctionParameter %ulong - %19 = OpLabel + %20 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -37,11 +41,12 @@ %11 = OpLoad %uint %17 Aligned 4 OpStore %6 %11 %14 = OpLoad %uint %6 - %13 = OpExtInst %uint %21 clz %14 + %18 = OpExtInst %uint %22 clz %14 + %13 = OpCopyObject %uint %18 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %uint %6 - %18 = OpConvertUToPtr %_ptr_Generic_uint %15 - OpStore %18 %16 Aligned 4 + %19 = OpConvertUToPtr %_ptr_Generic_uint %15 + OpStore %19 %16 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt b/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt index 5f4b050..92322ec 100644 --- a/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt +++ b/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt @@ -7,6 +7,9 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" %24 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "cvt_s16_s8" @@ -45,9 +48,7 @@ %32 = OpBitcast %uint %15 %34 = OpUConvert %uchar %32 %20 = OpCopyObject %uchar %34 - %35 = OpBitcast %uchar %20 - %37 = OpSConvert %ushort %35 - %19 = OpCopyObject %ushort %37 + %19 = OpSConvert %ushort %20 %14 = OpSConvert %uint %19 OpStore %6 %14 %16 = OpLoad %ulong %5 diff --git a/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt b/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt index 3f46103..1165290 100644 --- a/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt +++ b/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt @@ -7,9 +7,13 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" %24 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "cvt_s64_s32" + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %ulong = OpTypeInt 64 0 %27 = OpTypeFunction %void %ulong %ulong @@ -40,9 +44,7 @@ %12 = OpCopyObject %uint %18 OpStore %6 %12 %15 = OpLoad %uint %6 - %32 = OpBitcast %uint %15 - %33 = OpSConvert %ulong %32 - %14 = OpCopyObject %ulong %33 + %14 = OpSConvert %ulong %15 OpStore %7 %14 %16 = OpLoad %ulong %5 %17 = OpLoad %ulong %7 diff --git a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt index b676049..07b228e 100644 --- a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt +++ b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt @@ -7,9 +7,13 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" %25 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "cvt_sat_s_u" + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %ulong = OpTypeInt 64 0 %28 = OpTypeFunction %void %ulong %ulong @@ -42,7 +46,7 @@ %15 = OpSatConvertSToU %uint %16 OpStore %7 %15 %18 = OpLoad %uint %7 - %17 = OpBitcast %uint %18 + %17 = OpCopyObject %uint %18 OpStore %8 %17 %19 = OpLoad %ulong %5 %20 = OpLoad %uint %8 diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f5dfa64..a798720 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -1,3 +1,4 @@ +use crate::pass;
use crate::ptx;
use crate::translate;
use hip_runtime_sys::hipError_t;
@@ -385,10 +386,8 @@ fn test_spvtxt_assert<'a>( spirv_txt: &'a [u8],
spirv_file_name: &'a str,
) -> Result<(), Box<dyn error::Error + 'a>> {
- let mut errors = Vec::new();
- let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
- assert!(errors.len() == 0);
- let spirv_module = translate::to_spirv_module(ast)?;
+ let ast = ptx_parser::parse_module_checked(ptx_txt).unwrap();
+ let spirv_module = pass::to_spirv_module(ast)?;
let spv_context =
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
assert!(spv_context != ptr::null_mut());
diff --git a/ptx/src/test/spirv_run/popc.spvtxt b/ptx/src/test/spirv_run/popc.spvtxt index 845add7..c41e792 100644 --- a/ptx/src/test/spirv_run/popc.spvtxt +++ b/ptx/src/test/spirv_run/popc.spvtxt @@ -7,20 +7,24 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %22 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "popc" + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong + %25 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %24 + %1 = OpFunction %void None %25 %7 = OpFunctionParameter %ulong %8 = OpFunctionParameter %ulong - %19 = OpLabel + %20 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -37,11 +41,12 @@ %11 = OpLoad %uint %17 Aligned 4 OpStore %6 %11 %14 = OpLoad %uint %6 - %13 = OpBitCount %uint %14 + %18 = OpBitCount %uint %14 + %13 = OpCopyObject %uint %18 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %uint %6 - %18 = OpConvertUToPtr %_ptr_Generic_uint %15 - OpStore %18 %16 Aligned 4 + %19 = OpConvertUToPtr %_ptr_Generic_uint %15 + OpStore %19 %16 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector.ptx b/ptx/src/test/spirv_run/vector.ptx index 90b8ad3..ba07e15 100644 --- a/ptx/src/test/spirv_run/vector.ptx +++ b/ptx/src/test/spirv_run/vector.ptx @@ -1,4 +1,4 @@ -// Excersise as many features of vector types as possible
+// Exercise as many features of vector types as possible
.version 6.5
.target sm_60
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index db1063b..9b422fd 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1608,17 +1608,13 @@ fn extract_globals<'input, 'b>( for statement in sorted_statements {
match statement {
Statement::Variable(
- var
- @
- ast::Variable {
+ var @ ast::Variable {
state_space: ast::StateSpace::Shared,
..
},
)
| Statement::Variable(
- var
- @
- ast::Variable {
+ var @ ast::Variable {
state_space: ast::StateSpace::Global,
..
},
@@ -1660,9 +1656,7 @@ fn extract_globals<'input, 'b>( )?);
}
Statement::Instruction(ast::Instruction::Atom(
- details
- @
- ast::AtomDetails {
+ details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Inc,
@@ -1691,9 +1685,7 @@ fn extract_globals<'input, 'b>( )?);
}
Statement::Instruction(ast::Instruction::Atom(
- details
- @
- ast::AtomDetails {
+ details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Dec,
@@ -1722,9 +1714,7 @@ fn extract_globals<'input, 'b>( )?);
}
Statement::Instruction(ast::Instruction::Atom(
- details
- @
- ast::AtomDetails {
+ details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Float {
op: ast::AtomFloatOp::Add,
diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml new file mode 100644 index 0000000..9032de5 --- /dev/null +++ b/ptx_parser/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "ptx_parser" +version = "0.0.0" +authors = ["Andrzej Janik <[email protected]>"] +edition = "2021" + +[lib] + +[dependencies] +logos = "0.14" +winnow = { version = "0.6.18" } +#winnow = { version = "0.6.18", features = ["debug"] } +ptx_parser_macros = { path = "../ptx_parser_macros" } +thiserror = "1.0" +bitflags = "1.2" +rustc-hash = "2.0.0" +derive_more = { version = "1", features = ["display"] } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs new file mode 100644 index 0000000..d0dc303 --- /dev/null +++ b/ptx_parser/src/ast.rs @@ -0,0 +1,1695 @@ +use super::{
+ AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp,
+ StateSpace, VectorPrefix,
+};
+use crate::{PtxError, PtxParserState};
+use bitflags::bitflags;
+use std::{cmp::Ordering, num::NonZeroU8};
+
+pub enum Statement<P: Operand> {
+ Label(P::Ident),
+ Variable(MultiVariable<P::Ident>),
+ Instruction(Option<PredAt<P::Ident>>, Instruction<P>),
+ Block(Vec<Statement<P>>),
+}
+
+// We define the instruction enum through the macro instead of normally, because we have some of how
+// we use this type in the compilee. Each instruction can be logically split into two parts:
+// properties that define instruction semantics (e.g. is memory load volatile?) that don't change
+// during compilation and arguments (e.g. memory load source and destination) that evolve during
+// compilation. To support compilation passes we need to be able to visit (and change) every
+// argument in a generic way. This macro has visibility over all the fields. Consequently, we use it
+// to generate visitor functions. There re three functions to support three different semantics:
+// visit-by-ref, visit-by-mutable-ref, visit-and-map. In a previous version of the compiler it was
+// done by hand and was very limiting (we supported only visit-and-map).
+// The visitor must implement appropriate visitor trait defined below this macro. For convenience,
+// we implemented visitors for some corresponding FnMut(...) types.
+// Properties in this macro are used to encode information about the instruction arguments (what
+// Rust type is used for it post-parsing, what PTX type does it expect, what PTX address space does
+// it expect, etc.).
+// This information is then available to a visitor.
+ptx_parser_macros::generate_instruction_type!(
+ pub enum Instruction<T: Operand> {
+ Mov {
+ type: { &data.typ },
+ data: MovDetails,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Ld {
+ type: { &data.typ },
+ data: LdDetails,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ relaxed_type_check: true,
+ },
+ src: {
+ repr: T,
+ space: { data.state_space },
+ }
+ }
+ },
+ Add {
+ type: { Type::from(data.type_()) },
+ data: ArithDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ St {
+ type: { &data.typ },
+ data: StData,
+ arguments<T>: {
+ src1: {
+ repr: T,
+ space: { data.state_space },
+ },
+ src2: {
+ repr: T,
+ relaxed_type_check: true,
+ }
+ }
+ },
+ Mul {
+ type: { Type::from(data.type_()) },
+ data: MulDetails,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: { Type::from(data.dst_type()) },
+ },
+ src1: T,
+ src2: T,
+ }
+ },
+ Setp {
+ data: SetpData,
+ arguments<T>: {
+ dst1: {
+ repr: T,
+ type: Type::from(ScalarType::Pred)
+ },
+ dst2: {
+ repr: Option<T>,
+ type: Type::from(ScalarType::Pred)
+ },
+ src1: {
+ repr: T,
+ type: Type::from(data.type_),
+ },
+ src2: {
+ repr: T,
+ type: Type::from(data.type_),
+ }
+ }
+ },
+ SetpBool {
+ data: SetpBoolData,
+ arguments<T>: {
+ dst1: {
+ repr: T,
+ type: Type::from(ScalarType::Pred)
+ },
+ dst2: {
+ repr: Option<T>,
+ type: Type::from(ScalarType::Pred)
+ },
+ src1: {
+ repr: T,
+ type: Type::from(data.base.type_),
+ },
+ src2: {
+ repr: T,
+ type: Type::from(data.base.type_),
+ },
+ src3: {
+ repr: T,
+ type: Type::from(ScalarType::Pred)
+ }
+ }
+ },
+ Not {
+ data: ScalarType,
+ type: { Type::Scalar(data.clone()) },
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Or {
+ data: ScalarType,
+ type: { Type::Scalar(data.clone()) },
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ And {
+ data: ScalarType,
+ type: { Type::Scalar(data.clone()) },
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Bra {
+ type: !,
+ arguments<T::Ident>: {
+ src: T
+ }
+ },
+ Call {
+ data: CallDetails,
+ arguments: CallArgs<T>,
+ visit: arguments.visit(data, visitor)?,
+ visit_mut: arguments.visit_mut(data, visitor)?,
+ map: Instruction::Call{ arguments: arguments.map(&data, visitor)?, data }
+ },
+ Cvt {
+ data: CvtDetails,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: { Type::Scalar(data.to) },
+ // TODO: double check
+ relaxed_type_check: true,
+ },
+ src: {
+ repr: T,
+ type: { Type::Scalar(data.from) },
+ relaxed_type_check: true,
+ },
+ }
+ },
+ Shr {
+ data: ShrData,
+ type: { Type::Scalar(data.type_.clone()) },
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: {
+ repr: T,
+ type: { Type::Scalar(ScalarType::U32) },
+ },
+ }
+ },
+ Shl {
+ data: ScalarType,
+ type: { Type::Scalar(data.clone()) },
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: {
+ repr: T,
+ type: { Type::Scalar(ScalarType::U32) },
+ },
+ }
+ },
+ Ret {
+ data: RetData
+ },
+ Cvta {
+ data: CvtaDetails,
+ type: { Type::Scalar(ScalarType::B64) },
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Abs {
+ data: TypeFtz,
+ type: { Type::Scalar(data.type_) },
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Mad {
+ type: { Type::from(data.type_()) },
+ data: MadDetails,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: { Type::from(data.dst_type()) },
+ },
+ src1: T,
+ src2: T,
+ src3: T,
+ }
+ },
+ Fma {
+ type: { Type::from(data.type_) },
+ data: ArithFloat,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ src3: T,
+ }
+ },
+ Sub {
+ type: { Type::from(data.type_()) },
+ data: ArithDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Min {
+ type: { Type::from(data.type_()) },
+ data: MinMaxDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Max {
+ type: { Type::from(data.type_()) },
+ data: MinMaxDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Rcp {
+ type: { Type::from(data.type_) },
+ data: RcpData,
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Sqrt {
+ type: { Type::from(data.type_) },
+ data: RcpData,
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Rsqrt {
+ type: { Type::from(data.type_) },
+ data: TypeFtz,
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Selp {
+ type: { Type::Scalar(data.clone()) },
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ src3: {
+ repr: T,
+ type: Type::Scalar(ScalarType::Pred)
+ },
+ }
+ },
+ Bar {
+ type: Type::Scalar(ScalarType::U32),
+ data: BarData,
+ arguments<T>: {
+ src1: T,
+ src2: Option<T>,
+ }
+ },
+ Atom {
+ type: &data.type_,
+ data: AtomDetails,
+ arguments<T>: {
+ dst: T,
+ src1: {
+ repr: T,
+ space: { data.space },
+ },
+ src2: T,
+ }
+ },
+ AtomCas {
+ type: Type::Scalar(data.type_),
+ data: AtomCasDetails,
+ arguments<T>: {
+ dst: T,
+ src1: {
+ repr: T,
+ space: { data.space },
+ },
+ src2: T,
+ src3: T,
+ }
+ },
+ Div {
+ type: Type::Scalar(data.type_()),
+ data: DivDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Neg {
+ type: Type::Scalar(data.type_),
+ data: TypeFtz,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Sin {
+ type: Type::Scalar(ScalarType::F32),
+ data: FlushToZero,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Cos {
+ type: Type::Scalar(ScalarType::F32),
+ data: FlushToZero,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Lg2 {
+ type: Type::Scalar(ScalarType::F32),
+ data: FlushToZero,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Ex2 {
+ type: Type::Scalar(ScalarType::F32),
+ data: TypeFtz,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Clz {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ src: T
+ }
+ },
+ Brev {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Popc {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ src: T
+ }
+ },
+ Xor {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T
+ }
+ },
+ Rem {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T
+ }
+ },
+ Bfe {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ src3: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ }
+ },
+ Bfi {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ src3: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ src4: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ }
+ },
+ PrmtSlow {
+ type: Type::Scalar(ScalarType::U32),
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ src3: T
+ }
+ },
+ Prmt {
+ type: Type::Scalar(ScalarType::B32),
+ data: u16,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T
+ }
+ },
+ Activemask {
+ type: Type::Scalar(ScalarType::B32),
+ arguments<T>: {
+ dst: T
+ }
+ },
+ Membar {
+ data: MemScope
+ },
+ Trap { }
+ }
+);
+
+pub trait Visitor<T: Operand, Err> {
+ fn visit(
+ &mut self,
+ args: &T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err>;
+ fn visit_ident(
+ &mut self,
+ args: &T::Ident,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err>;
+}
+
+impl<
+ T: Operand,
+ Err,
+ Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>,
+ > Visitor<T, Err> for Fn
+{
+ fn visit(
+ &mut self,
+ args: &T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err> {
+ (self)(args, type_space, is_dst, relaxed_type_check)
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: &T::Ident,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err> {
+ (self)(
+ &T::from_ident(*args),
+ type_space,
+ is_dst,
+ relaxed_type_check,
+ )
+ }
+}
+
+pub trait VisitorMut<T: Operand, Err> {
+ fn visit(
+ &mut self,
+ args: &mut T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err>;
+ fn visit_ident(
+ &mut self,
+ args: &mut T::Ident,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err>;
+}
+
+pub trait VisitorMap<From: Operand, To: Operand, Err> {
+ fn visit(
+ &mut self,
+ args: From,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<To, Err>;
+ fn visit_ident(
+ &mut self,
+ args: From::Ident,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<To::Ident, Err>;
+}
+
+impl<T: Copy, U: Copy, Err, Fn> VisitorMap<ParsedOperand<T>, ParsedOperand<U>, Err> for Fn
+where
+ Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result<U, Err>,
+{
+ fn visit(
+ &mut self,
+ args: ParsedOperand<T>,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<ParsedOperand<U>, Err> {
+ Ok(match args {
+ ParsedOperand::Reg(ident) => {
+ ParsedOperand::Reg((self)(ident, type_space, is_dst, relaxed_type_check)?)
+ }
+ ParsedOperand::RegOffset(ident, imm) => ParsedOperand::RegOffset(
+ (self)(ident, type_space, is_dst, relaxed_type_check)?,
+ imm,
+ ),
+ ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm),
+ ParsedOperand::VecMember(ident, index) => ParsedOperand::VecMember(
+ (self)(ident, type_space, is_dst, relaxed_type_check)?,
+ index,
+ ),
+ ParsedOperand::VecPack(vec) => ParsedOperand::VecPack(
+ vec.into_iter()
+ .map(|ident| (self)(ident, type_space, is_dst, relaxed_type_check))
+ .collect::<Result<Vec<_>, _>>()?,
+ ),
+ })
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<U, Err> {
+ (self)(args, type_space, is_dst, relaxed_type_check)
+ }
+}
+
+impl<T: Operand<Ident = T>, U: Operand<Ident = U>, Err, Fn> VisitorMap<T, U, Err> for Fn
+where
+ Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result<U, Err>,
+{
+ fn visit(
+ &mut self,
+ args: T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<U, Err> {
+ (self)(args, type_space, is_dst, relaxed_type_check)
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<U, Err> {
+ (self)(args, type_space, is_dst, relaxed_type_check)
+ }
+}
+
+trait VisitOperand<Err> {
+ type Operand: Operand;
+ #[allow(unused)] // Used by generated code
+ fn visit(&self, fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err>;
+ #[allow(unused)] // Used by generated code
+ fn visit_mut(
+ &mut self,
+ fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
+ ) -> Result<(), Err>;
+}
+
+impl<T: Operand, Err> VisitOperand<Err> for T {
+ type Operand = Self;
+ fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
+ fn_(self)
+ }
+ fn visit_mut(
+ &mut self,
+ mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
+ ) -> Result<(), Err> {
+ fn_(self)
+ }
+}
+
+impl<T: Operand, Err> VisitOperand<Err> for Option<T> {
+ type Operand = T;
+ fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
+ if let Some(x) = self {
+ fn_(x)?;
+ }
+ Ok(())
+ }
+ fn visit_mut(
+ &mut self,
+ mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
+ ) -> Result<(), Err> {
+ if let Some(x) = self {
+ fn_(x)?;
+ }
+ Ok(())
+ }
+}
+
+impl<T: Operand, Err> VisitOperand<Err> for Vec<T> {
+ type Operand = T;
+ fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
+ for o in self {
+ fn_(o)?;
+ }
+ Ok(())
+ }
+ fn visit_mut(
+ &mut self,
+ mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
+ ) -> Result<(), Err> {
+ for o in self {
+ fn_(o)?;
+ }
+ Ok(())
+ }
+}
+
+trait MapOperand<Err>: Sized {
+ type Input;
+ type Output<U>;
+ #[allow(unused)] // Used by generated code
+ fn map<U>(
+ self,
+ fn_: impl FnOnce(Self::Input) -> Result<U, Err>,
+ ) -> Result<Self::Output<U>, Err>;
+}
+
+impl<T: Operand, Err> MapOperand<Err> for T {
+ type Input = Self;
+ type Output<U> = U;
+ fn map<U>(self, fn_: impl FnOnce(T) -> Result<U, Err>) -> Result<U, Err> {
+ fn_(self)
+ }
+}
+
+impl<T: Operand, Err> MapOperand<Err> for Option<T> {
+ type Input = T;
+ type Output<U> = Option<U>;
+ fn map<U>(self, fn_: impl FnOnce(T) -> Result<U, Err>) -> Result<Option<U>, Err> {
+ self.map(|x| fn_(x)).transpose()
+ }
+}
+
+pub struct MultiVariable<ID> {
+ pub var: Variable<ID>,
+ pub count: Option<u32>,
+}
+
+#[derive(Clone)]
+pub struct Variable<ID> {
+ pub align: Option<u32>,
+ pub v_type: Type,
+ pub state_space: StateSpace,
+ pub name: ID,
+ pub array_init: Vec<u8>,
+}
+
+pub struct PredAt<ID> {
+ pub not: bool,
+ pub label: ID,
+}
+
+#[derive(PartialEq, Eq, Clone, Hash)]
+pub enum Type {
+ // .param.b32 foo;
+ Scalar(ScalarType),
+ // .param.v2.b32 foo;
+ Vector(u8, ScalarType),
+ // .param.b32 foo[4];
+ Array(Option<NonZeroU8>, ScalarType, Vec<u32>),
+ Pointer(ScalarType, StateSpace),
+}
+
+impl Type {
+ pub(crate) fn maybe_vector(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
+ match vector {
+ Some(prefix) => Type::Vector(prefix.len().get(), scalar),
+ None => Type::Scalar(scalar),
+ }
+ }
+
+ pub(crate) fn maybe_vector_parsed(prefix: Option<NonZeroU8>, scalar: ScalarType) -> Self {
+ match prefix {
+ Some(prefix) => Type::Vector(prefix.get(), scalar),
+ None => Type::Scalar(scalar),
+ }
+ }
+
+ pub(crate) fn maybe_array(
+ prefix: Option<NonZeroU8>,
+ scalar: ScalarType,
+ array: Option<Vec<u32>>,
+ ) -> Self {
+ match array {
+ Some(dimensions) => Type::Array(prefix, scalar, dimensions),
+ None => Self::maybe_vector_parsed(prefix, scalar),
+ }
+ }
+}
+
+impl ScalarType {
+ pub fn size_of(self) -> u8 {
+ match self {
+ ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1,
+ ScalarType::U16
+ | ScalarType::S16
+ | ScalarType::B16
+ | ScalarType::F16
+ | ScalarType::BF16 => 2,
+ ScalarType::U32
+ | ScalarType::S32
+ | ScalarType::B32
+ | ScalarType::F32
+ | ScalarType::U16x2
+ | ScalarType::S16x2
+ | ScalarType::F16x2
+ | ScalarType::BF16x2 => 4,
+ ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => 8,
+ ScalarType::B128 => 16,
+ ScalarType::Pred => 1,
+ }
+ }
+
+ pub fn kind(self) -> ScalarKind {
+ match self {
+ ScalarType::U8 => ScalarKind::Unsigned,
+ ScalarType::U16 => ScalarKind::Unsigned,
+ ScalarType::U16x2 => ScalarKind::Unsigned,
+ ScalarType::U32 => ScalarKind::Unsigned,
+ ScalarType::U64 => ScalarKind::Unsigned,
+ ScalarType::S8 => ScalarKind::Signed,
+ ScalarType::S16 => ScalarKind::Signed,
+ ScalarType::S16x2 => ScalarKind::Signed,
+ ScalarType::S32 => ScalarKind::Signed,
+ ScalarType::S64 => ScalarKind::Signed,
+ ScalarType::B8 => ScalarKind::Bit,
+ ScalarType::B16 => ScalarKind::Bit,
+ ScalarType::B32 => ScalarKind::Bit,
+ ScalarType::B64 => ScalarKind::Bit,
+ ScalarType::B128 => ScalarKind::Bit,
+ ScalarType::F16 => ScalarKind::Float,
+ ScalarType::F16x2 => ScalarKind::Float,
+ ScalarType::F32 => ScalarKind::Float,
+ ScalarType::F64 => ScalarKind::Float,
+ ScalarType::BF16 => ScalarKind::Float,
+ ScalarType::BF16x2 => ScalarKind::Float,
+ ScalarType::Pred => ScalarKind::Pred,
+ }
+ }
+}
+
+#[derive(Clone, Copy, PartialEq, Eq)]
+pub enum ScalarKind {
+ Bit,
+ Unsigned,
+ Signed,
+ Float,
+ Pred,
+}
+impl From<ScalarType> for Type {
+ fn from(value: ScalarType) -> Self {
+ Type::Scalar(value)
+ }
+}
+
+#[derive(Clone)]
+pub struct MovDetails {
+ pub typ: super::Type,
+ pub src_is_address: bool,
+ // two fields below are in use by member moves
+ pub dst_width: u8,
+ pub src_width: u8,
+ // This is in use by auto-generated movs
+ pub relaxed_src2_conv: bool,
+}
+
+impl MovDetails {
+ pub(crate) fn new(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
+ MovDetails {
+ typ: Type::maybe_vector(vector, scalar),
+ src_is_address: false,
+ dst_width: 0,
+ src_width: 0,
+ relaxed_src2_conv: false,
+ }
+ }
+}
+
+#[derive(Clone)]
+pub enum ParsedOperand<Ident> {
+ Reg(Ident),
+ RegOffset(Ident, i32),
+ Imm(ImmediateValue),
+ VecMember(Ident, u8),
+ VecPack(Vec<Ident>),
+}
+
+impl<Ident: Copy> Operand for ParsedOperand<Ident> {
+ type Ident = Ident;
+
+ fn from_ident(ident: Self::Ident) -> Self {
+ ParsedOperand::Reg(ident)
+ }
+}
+
+pub trait Operand: Sized {
+ type Ident: Copy;
+
+ fn from_ident(ident: Self::Ident) -> Self;
+}
+
+#[derive(Copy, Clone)]
+pub enum ImmediateValue {
+ U64(u64),
+ S64(i64),
+ F32(f32),
+ F64(f64),
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum StCacheOperator {
+ Writeback,
+ L2Only,
+ Streaming,
+ Writethrough,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum LdCacheOperator {
+ Cached,
+ L2Only,
+ Streaming,
+ LastUse,
+ Uncached,
+}
+
+#[derive(Copy, Clone)]
+pub enum ArithDetails {
+ Integer(ArithInteger),
+ Float(ArithFloat),
+}
+
+impl ArithDetails {
+ pub fn type_(&self) -> ScalarType {
+ match self {
+ ArithDetails::Integer(t) => t.type_,
+ ArithDetails::Float(arith) => arith.type_,
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub struct ArithInteger {
+ pub type_: ScalarType,
+ pub saturate: bool,
+}
+
+#[derive(Copy, Clone)]
+pub struct ArithFloat {
+ pub type_: ScalarType,
+ pub rounding: Option<RoundingMode>,
+ pub flush_to_zero: Option<bool>,
+ pub saturate: bool,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum LdStQualifier {
+ Weak,
+ Volatile,
+ Relaxed(MemScope),
+ Acquire(MemScope),
+ Release(MemScope),
+}
+
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum RoundingMode {
+ NearestEven,
+ Zero,
+ NegativeInf,
+ PositiveInf,
+}
+
+pub struct LdDetails {
+ pub qualifier: LdStQualifier,
+ pub state_space: StateSpace,
+ pub caching: LdCacheOperator,
+ pub typ: Type,
+ pub non_coherent: bool,
+}
+
+pub struct StData {
+ pub qualifier: LdStQualifier,
+ pub state_space: StateSpace,
+ pub caching: StCacheOperator,
+ pub typ: Type,
+}
+
+#[derive(Copy, Clone)]
+pub struct RetData {
+ pub uniform: bool,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum TuningDirective {
+ MaxNReg(u32),
+ MaxNtid(u32, u32, u32),
+ ReqNtid(u32, u32, u32),
+ MinNCtaPerSm(u32),
+}
+
+pub struct MethodDeclaration<'input, ID> {
+ pub return_arguments: Vec<Variable<ID>>,
+ pub name: MethodName<'input, ID>,
+ pub input_arguments: Vec<Variable<ID>>,
+ pub shared_mem: Option<ID>,
+}
+
+impl<'input> MethodDeclaration<'input, &'input str> {
+ pub fn name(&self) -> &'input str {
+ match self.name {
+ MethodName::Kernel(n) => n,
+ MethodName::Func(n) => n,
+ }
+ }
+}
+
+#[derive(Hash, PartialEq, Eq, Copy, Clone)]
+pub enum MethodName<'input, ID> {
+ Kernel(&'input str),
+ Func(ID),
+}
+
+bitflags! {
+ pub struct LinkingDirective: u8 {
+ const NONE = 0b000;
+ const EXTERN = 0b001;
+ const VISIBLE = 0b10;
+ const WEAK = 0b100;
+ }
+}
+
+pub struct Function<'a, ID, S> {
+ pub func_directive: MethodDeclaration<'a, ID>,
+ pub tuning: Vec<TuningDirective>,
+ pub body: Option<Vec<S>>,
+}
+
+pub enum Directive<'input, O: Operand> {
+ Variable(LinkingDirective, Variable<O::Ident>),
+ Method(
+ LinkingDirective,
+ Function<'input, &'input str, Statement<O>>,
+ ),
+}
+
+pub struct Module<'input> {
+ pub version: (u8, u8),
+ pub directives: Vec<Directive<'input, ParsedOperand<&'input str>>>,
+}
+
+#[derive(Copy, Clone)]
+pub enum MulDetails {
+ Integer {
+ type_: ScalarType,
+ control: MulIntControl,
+ },
+ Float(ArithFloat),
+}
+
+impl MulDetails {
+ pub fn type_(&self) -> ScalarType {
+ match self {
+ MulDetails::Integer { type_, .. } => *type_,
+ MulDetails::Float(arith) => arith.type_,
+ }
+ }
+
+ pub fn dst_type(&self) -> ScalarType {
+ match self {
+ MulDetails::Integer {
+ type_,
+ control: MulIntControl::Wide,
+ } => match type_ {
+ ScalarType::U16 => ScalarType::U32,
+ ScalarType::S16 => ScalarType::S32,
+ ScalarType::U32 => ScalarType::U64,
+ ScalarType::S32 => ScalarType::S64,
+ _ => unreachable!(),
+ },
+ _ => self.type_(),
+ }
+ }
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum MulIntControl {
+ Low,
+ High,
+ Wide,
+}
+
+pub struct SetpData {
+ pub type_: ScalarType,
+ pub flush_to_zero: Option<bool>,
+ pub cmp_op: SetpCompareOp,
+}
+
+impl SetpData {
+ pub(crate) fn try_parse(
+ state: &mut PtxParserState,
+ cmp_op: super::RawSetpCompareOp,
+ ftz: bool,
+ type_: ScalarType,
+ ) -> Self {
+ let flush_to_zero = match (ftz, type_) {
+ (_, ScalarType::F32) => Some(ftz),
+ (true, _) => {
+ state.errors.push(PtxError::NonF32Ftz);
+ None
+ }
+ _ => None
+ };
+ let type_kind = type_.kind();
+ let cmp_op = if type_kind == ScalarKind::Float {
+ SetpCompareOp::Float(SetpCompareFloat::from(cmp_op))
+ } else {
+ match SetpCompareInt::try_from((cmp_op, type_kind)) {
+ Ok(op) => SetpCompareOp::Integer(op),
+ Err(err) => {
+ state.errors.push(err);
+ SetpCompareOp::Integer(SetpCompareInt::Eq)
+ }
+ }
+ };
+ Self {
+ type_,
+ flush_to_zero,
+ cmp_op,
+ }
+ }
+}
+
+pub struct SetpBoolData {
+ pub base: SetpData,
+ pub bool_op: SetpBoolPostOp,
+ pub negate_src3: bool,
+}
+
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum SetpCompareOp {
+ Integer(SetpCompareInt),
+ Float(SetpCompareFloat),
+}
+
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum SetpCompareInt {
+ Eq,
+ NotEq,
+ UnsignedLess,
+ UnsignedLessOrEq,
+ UnsignedGreater,
+ UnsignedGreaterOrEq,
+ SignedLess,
+ SignedLessOrEq,
+ SignedGreater,
+ SignedGreaterOrEq,
+}
+
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum SetpCompareFloat {
+ Eq,
+ NotEq,
+ Less,
+ LessOrEq,
+ Greater,
+ GreaterOrEq,
+ NanEq,
+ NanNotEq,
+ NanLess,
+ NanLessOrEq,
+ NanGreater,
+ NanGreaterOrEq,
+ IsNotNan,
+ IsAnyNan,
+}
+
+impl TryFrom<(RawSetpCompareOp, ScalarKind)> for SetpCompareInt {
+ type Error = PtxError;
+
+ fn try_from((value, kind): (RawSetpCompareOp, ScalarKind)) -> Result<Self, PtxError> {
+ match (value, kind) {
+ (RawSetpCompareOp::Eq, _) => Ok(SetpCompareInt::Eq),
+ (RawSetpCompareOp::Ne, _) => Ok(SetpCompareInt::NotEq),
+ (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, ScalarKind::Signed) => {
+ Ok(SetpCompareInt::SignedLess)
+ }
+ (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, _) => Ok(SetpCompareInt::UnsignedLess),
+ (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, ScalarKind::Signed) => {
+ Ok(SetpCompareInt::SignedLessOrEq)
+ }
+ (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, _) => {
+ Ok(SetpCompareInt::UnsignedLessOrEq)
+ }
+ (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, ScalarKind::Signed) => {
+ Ok(SetpCompareInt::SignedGreater)
+ }
+ (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, _) => Ok(SetpCompareInt::UnsignedGreater),
+ (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, ScalarKind::Signed) => {
+ Ok(SetpCompareInt::SignedGreaterOrEq)
+ }
+ (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, _) => {
+ Ok(SetpCompareInt::UnsignedGreaterOrEq)
+ }
+ (RawSetpCompareOp::Equ, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Neu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Ltu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Leu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Gtu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Geu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Num, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Nan, _) => Err(PtxError::WrongType),
+ }
+ }
+}
+
+impl From<RawSetpCompareOp> for SetpCompareFloat {
+ fn from(value: RawSetpCompareOp) -> Self {
+ match value {
+ RawSetpCompareOp::Eq => SetpCompareFloat::Eq,
+ RawSetpCompareOp::Ne => SetpCompareFloat::NotEq,
+ RawSetpCompareOp::Lt => SetpCompareFloat::Less,
+ RawSetpCompareOp::Le => SetpCompareFloat::LessOrEq,
+ RawSetpCompareOp::Gt => SetpCompareFloat::Greater,
+ RawSetpCompareOp::Ge => SetpCompareFloat::GreaterOrEq,
+ RawSetpCompareOp::Lo => SetpCompareFloat::Less,
+ RawSetpCompareOp::Ls => SetpCompareFloat::LessOrEq,
+ RawSetpCompareOp::Hi => SetpCompareFloat::Greater,
+ RawSetpCompareOp::Hs => SetpCompareFloat::GreaterOrEq,
+ RawSetpCompareOp::Equ => SetpCompareFloat::NanEq,
+ RawSetpCompareOp::Neu => SetpCompareFloat::NanNotEq,
+ RawSetpCompareOp::Ltu => SetpCompareFloat::NanLess,
+ RawSetpCompareOp::Leu => SetpCompareFloat::NanLessOrEq,
+ RawSetpCompareOp::Gtu => SetpCompareFloat::NanGreater,
+ RawSetpCompareOp::Geu => SetpCompareFloat::NanGreaterOrEq,
+ RawSetpCompareOp::Num => SetpCompareFloat::IsNotNan,
+ RawSetpCompareOp::Nan => SetpCompareFloat::IsAnyNan,
+ }
+ }
+}
+
+pub struct CallDetails {
+ pub uniform: bool,
+ pub return_arguments: Vec<(Type, StateSpace)>,
+ pub input_arguments: Vec<(Type, StateSpace)>,
+}
+
+pub struct CallArgs<T: Operand> {
+ pub return_arguments: Vec<T::Ident>,
+ pub func: T::Ident,
+ pub input_arguments: Vec<T>,
+}
+
+impl<T: Operand> CallArgs<T> {
+ #[allow(dead_code)] // Used by generated code
+ fn visit<Err>(
+ &self,
+ details: &CallDetails,
+ visitor: &mut impl Visitor<T, Err>,
+ ) -> Result<(), Err> {
+ for (param, (type_, space)) in self
+ .return_arguments
+ .iter()
+ .zip(details.return_arguments.iter())
+ {
+ visitor.visit_ident(param, Some((type_, *space)), true, false)?;
+ }
+ visitor.visit_ident(&self.func, None, false, false)?;
+ for (param, (type_, space)) in self
+ .input_arguments
+ .iter()
+ .zip(details.input_arguments.iter())
+ {
+ visitor.visit(param, Some((type_, *space)), false, false)?;
+ }
+ Ok(())
+ }
+
+ #[allow(dead_code)] // Used by generated code
+ fn visit_mut<Err>(
+ &mut self,
+ details: &CallDetails,
+ visitor: &mut impl VisitorMut<T, Err>,
+ ) -> Result<(), Err> {
+ for (param, (type_, space)) in self
+ .return_arguments
+ .iter_mut()
+ .zip(details.return_arguments.iter())
+ {
+ visitor.visit_ident(param, Some((type_, *space)), true, false)?;
+ }
+ visitor.visit_ident(&mut self.func, None, false, false)?;
+ for (param, (type_, space)) in self
+ .input_arguments
+ .iter_mut()
+ .zip(details.input_arguments.iter())
+ {
+ visitor.visit(param, Some((type_, *space)), false, false)?;
+ }
+ Ok(())
+ }
+
+ #[allow(dead_code)] // Used by generated code
+ fn map<U: Operand, Err>(
+ self,
+ details: &CallDetails,
+ visitor: &mut impl VisitorMap<T, U, Err>,
+ ) -> Result<CallArgs<U>, Err> {
+ let return_arguments = self
+ .return_arguments
+ .into_iter()
+ .zip(details.return_arguments.iter())
+ .map(|(param, (type_, space))| {
+ visitor.visit_ident(param, Some((type_, *space)), true, false)
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ let func = visitor.visit_ident(self.func, None, false, false)?;
+ let input_arguments = self
+ .input_arguments
+ .into_iter()
+ .zip(details.input_arguments.iter())
+ .map(|(param, (type_, space))| {
+ visitor.visit(param, Some((type_, *space)), false, false)
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok(CallArgs {
+ return_arguments,
+ func,
+ input_arguments,
+ })
+ }
+}
+
+pub struct CvtDetails {
+ pub from: ScalarType,
+ pub to: ScalarType,
+ pub mode: CvtMode,
+}
+
+pub enum CvtMode {
+ // int from int
+ ZeroExtend,
+ SignExtend,
+ Truncate,
+ Bitcast,
+ SaturateUnsignedToSigned,
+ SaturateSignedToUnsigned,
+ // float from float
+ FPExtend {
+ flush_to_zero: Option<bool>,
+ },
+ FPTruncate {
+ // float rounding
+ rounding: RoundingMode,
+ flush_to_zero: Option<bool>,
+ },
+ FPRound {
+ integer_rounding: Option<RoundingMode>,
+ flush_to_zero: Option<bool>,
+ },
+ // int from float
+ SignedFromFP {
+ rounding: RoundingMode,
+ flush_to_zero: Option<bool>,
+ }, // integer rounding
+ UnsignedFromFP {
+ rounding: RoundingMode,
+ flush_to_zero: Option<bool>,
+ }, // integer rounding
+ // float from int, ftz is allowed in the grammar, but clearly nonsensical
+ FPFromSigned(RoundingMode), // float rounding
+ FPFromUnsigned(RoundingMode), // float rounding
+}
+
+impl CvtDetails {
+ pub(crate) fn new(
+ errors: &mut Vec<PtxError>,
+ rnd: Option<RawRoundingMode>,
+ ftz: bool,
+ saturate: bool,
+ dst: ScalarType,
+ src: ScalarType,
+ ) -> Self {
+ if saturate && dst.kind() == ScalarKind::Float {
+ errors.push(PtxError::SyntaxError);
+ }
+ // Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results.
+ let flush_to_zero = match (dst, src) {
+ (ScalarType::F32, _) | (_, ScalarType::F32) => Some(ftz),
+ _ => {
+ if ftz {
+ errors.push(PtxError::NonF32Ftz);
+ }
+ None
+ }
+ };
+ let rounding = rnd.map(Into::into);
+ let mut unwrap_rounding = || match rounding {
+ Some(rnd) => rnd,
+ None => {
+ errors.push(PtxError::SyntaxError);
+ RoundingMode::NearestEven
+ }
+ };
+ let mode = match (dst.kind(), src.kind()) {
+ (ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) {
+ Ordering::Less => CvtMode::FPTruncate {
+ rounding: unwrap_rounding(),
+ flush_to_zero,
+ },
+ Ordering::Equal => CvtMode::FPRound {
+ integer_rounding: rounding,
+ flush_to_zero,
+ },
+ Ordering::Greater => {
+ if rounding.is_some() {
+ errors.push(PtxError::SyntaxError);
+ }
+ CvtMode::FPExtend { flush_to_zero }
+ }
+ },
+ (ScalarKind::Unsigned, ScalarKind::Float) => CvtMode::UnsignedFromFP {
+ rounding: unwrap_rounding(),
+ flush_to_zero,
+ },
+ (ScalarKind::Signed, ScalarKind::Float) => CvtMode::SignedFromFP {
+ rounding: unwrap_rounding(),
+ flush_to_zero,
+ },
+ (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()),
+ (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()),
+ (ScalarKind::Signed, ScalarKind::Unsigned) if saturate => {
+ CvtMode::SaturateUnsignedToSigned
+ }
+ (ScalarKind::Unsigned, ScalarKind::Signed) if saturate => {
+ CvtMode::SaturateSignedToUnsigned
+ }
+ (ScalarKind::Unsigned, ScalarKind::Signed)
+ | (ScalarKind::Signed, ScalarKind::Unsigned)
+ if dst.size_of() == src.size_of() =>
+ {
+ CvtMode::Bitcast
+ }
+ (ScalarKind::Unsigned, ScalarKind::Unsigned)
+ | (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) {
+ Ordering::Less => CvtMode::Truncate,
+ Ordering::Equal => CvtMode::Bitcast,
+ Ordering::Greater => {
+ if src.kind() == ScalarKind::Signed {
+ CvtMode::SignExtend
+ } else {
+ CvtMode::ZeroExtend
+ }
+ }
+ },
+ (ScalarKind::Unsigned, ScalarKind::Signed) => CvtMode::SaturateSignedToUnsigned,
+ (_, _) => {
+ errors.push(PtxError::SyntaxError);
+ CvtMode::Bitcast
+ }
+ };
+ CvtDetails {
+ mode,
+ to: dst,
+ from: src,
+ }
+ }
+}
+
+pub struct CvtIntToIntDesc {
+ pub dst: ScalarType,
+ pub src: ScalarType,
+ pub saturate: bool,
+}
+
+pub struct CvtDesc {
+ pub rounding: Option<RoundingMode>,
+ pub flush_to_zero: Option<bool>,
+ pub saturate: bool,
+ pub dst: ScalarType,
+ pub src: ScalarType,
+}
+
+pub struct ShrData {
+ pub type_: ScalarType,
+ pub kind: RightShiftKind,
+}
+
+pub enum RightShiftKind {
+ Arithmetic,
+ Logical,
+}
+
+pub struct CvtaDetails {
+ pub state_space: StateSpace,
+ pub direction: CvtaDirection,
+}
+
+pub enum CvtaDirection {
+ GenericToExplicit,
+ ExplicitToGeneric,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub struct TypeFtz {
+ pub flush_to_zero: Option<bool>,
+ pub type_: ScalarType,
+}
+
+#[derive(Copy, Clone)]
+pub enum MadDetails {
+ Integer {
+ control: MulIntControl,
+ saturate: bool,
+ type_: ScalarType,
+ },
+ Float(ArithFloat),
+}
+
+impl MadDetails {
+ pub fn dst_type(&self) -> ScalarType {
+ match self {
+ MadDetails::Integer {
+ type_,
+ control: MulIntControl::Wide,
+ ..
+ } => match type_ {
+ ScalarType::U16 => ScalarType::U32,
+ ScalarType::S16 => ScalarType::S32,
+ ScalarType::U32 => ScalarType::U64,
+ ScalarType::S32 => ScalarType::S64,
+ _ => unreachable!(),
+ },
+ _ => self.type_(),
+ }
+ }
+
+ fn type_(&self) -> ScalarType {
+ match self {
+ MadDetails::Integer { type_, .. } => *type_,
+ MadDetails::Float(arith) => arith.type_,
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub enum MinMaxDetails {
+ Signed(ScalarType),
+ Unsigned(ScalarType),
+ Float(MinMaxFloat),
+}
+
+impl MinMaxDetails {
+ pub fn type_(&self) -> ScalarType {
+ match self {
+ MinMaxDetails::Signed(t) => *t,
+ MinMaxDetails::Unsigned(t) => *t,
+ MinMaxDetails::Float(float) => float.type_,
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub struct MinMaxFloat {
+ pub flush_to_zero: Option<bool>,
+ pub nan: bool,
+ pub type_: ScalarType,
+}
+
+#[derive(Copy, Clone)]
+pub struct RcpData {
+ pub kind: RcpKind,
+ pub flush_to_zero: Option<bool>,
+ pub type_: ScalarType,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum RcpKind {
+ Approx,
+ Compliant(RoundingMode),
+}
+
+pub struct BarData {
+ pub aligned: bool,
+}
+
+pub struct AtomDetails {
+ pub type_: Type,
+ pub semantics: AtomSemantics,
+ pub scope: MemScope,
+ pub space: StateSpace,
+ pub op: AtomicOp,
+}
+
+#[derive(Copy, Clone)]
+pub enum AtomicOp {
+ And,
+ Or,
+ Xor,
+ Exchange,
+ Add,
+ IncrementWrap,
+ DecrementWrap,
+ SignedMin,
+ UnsignedMin,
+ SignedMax,
+ UnsignedMax,
+ FloatAdd,
+ FloatMin,
+ FloatMax,
+}
+
+impl AtomicOp {
+ pub(crate) fn new(op: super::RawAtomicOp, kind: ScalarKind) -> Self {
+ use super::RawAtomicOp;
+ match (op, kind) {
+ (RawAtomicOp::And, _) => Self::And,
+ (RawAtomicOp::Or, _) => Self::Or,
+ (RawAtomicOp::Xor, _) => Self::Xor,
+ (RawAtomicOp::Exch, _) => Self::Exchange,
+ (RawAtomicOp::Add, ScalarKind::Float) => Self::FloatAdd,
+ (RawAtomicOp::Add, _) => Self::Add,
+ (RawAtomicOp::Inc, _) => Self::IncrementWrap,
+ (RawAtomicOp::Dec, _) => Self::DecrementWrap,
+ (RawAtomicOp::Min, ScalarKind::Signed) => Self::SignedMin,
+ (RawAtomicOp::Min, ScalarKind::Float) => Self::FloatMin,
+ (RawAtomicOp::Min, _) => Self::UnsignedMin,
+ (RawAtomicOp::Max, ScalarKind::Signed) => Self::SignedMax,
+ (RawAtomicOp::Max, ScalarKind::Float) => Self::FloatMax,
+ (RawAtomicOp::Max, _) => Self::UnsignedMax,
+ }
+ }
+}
+
+pub struct AtomCasDetails {
+ pub type_: ScalarType,
+ pub semantics: AtomSemantics,
+ pub scope: MemScope,
+ pub space: StateSpace,
+}
+
+#[derive(Copy, Clone)]
+pub enum DivDetails {
+ Unsigned(ScalarType),
+ Signed(ScalarType),
+ Float(DivFloatDetails),
+}
+
+impl DivDetails {
+ pub fn type_(&self) -> ScalarType {
+ match self {
+ DivDetails::Unsigned(t) => *t,
+ DivDetails::Signed(t) => *t,
+ DivDetails::Float(float) => float.type_,
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub struct DivFloatDetails {
+ pub type_: ScalarType,
+ pub flush_to_zero: Option<bool>,
+ pub kind: DivFloatKind,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum DivFloatKind {
+ Approx,
+ ApproxFull,
+ Rounding(RoundingMode),
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub struct FlushToZero {
+ pub flush_to_zero: bool,
+}
diff --git a/ptx_parser/src/check_args.py b/ptx_parser/src/check_args.py new file mode 100644 index 0000000..04ffdb9 --- /dev/null +++ b/ptx_parser/src/check_args.py @@ -0,0 +1,69 @@ +import os, sys, subprocess
+
+
+SPACE = [".reg", ".sreg", ".param", ".param::entry", ".param::func", ".local", ".global", ".const", ".shared", ".shared::cta", ".shared::cluster"]
+TYPE_AND_INIT = ["", " = 1", "[1]", "[1] = {1}"]
+MULTIVAR = ["", "<1>" ]
+VECTOR = ["", ".v2" ]
+
+HEADER = """
+ .version 8.5
+ .target sm_90
+ .address_size 64
+"""
+
+
+def directive(space, variable, multivar, vector):
+ return """{3}
+ {0} {4} .b32 variable{2} {1};
+ """.format(space, variable, multivar, HEADER, vector)
+
+def entry_arg(space, variable, multivar, vector):
+ return """{3}
+ .entry foobar ( {0} {4} .b32 variable{2} {1})
+ {{
+ ret;
+ }}
+ """.format(space, variable, multivar, HEADER, vector)
+
+
+def fn_arg(space, variable, multivar, vector):
+ return """{3}
+ .func foobar ( {0} {4} .b32 variable{2} {1})
+ {{
+ ret;
+ }}
+ """.format(space, variable, multivar, HEADER, vector)
+
+
+def fn_body(space, variable, multivar, vector):
+ return """{3}
+ .func foobar ()
+ {{
+ {0} {4} .b32 variable{2} {1};
+ ret;
+ }}
+ """.format(space, variable, multivar, HEADER, vector)
+
+
+def generate(generator):
+ legal = []
+ for space in SPACE:
+ for init in TYPE_AND_INIT:
+ for multi in MULTIVAR:
+ for vector in VECTOR:
+ ptx = generator(space, init, multi, vector)
+ if 0 == subprocess.call(["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.6\\bin\\ptxas.exe", "-arch", "sm_90", "-ias", ptx], stdout = subprocess.DEVNULL): #
+ legal.append((space, vector, init, multi))
+ print(generator.__name__)
+ print(legal)
+
+
+def main():
+ generate(directive)
+ generate(entry_arg)
+ generate(fn_arg)
+ generate(fn_body)
+
+if __name__ == "__main__":
+ main()
diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs new file mode 100644 index 0000000..f842ace --- /dev/null +++ b/ptx_parser/src/lib.rs @@ -0,0 +1,3269 @@ +use derive_more::Display; +use logos::Logos; +use ptx_parser_macros::derive_parser; +use rustc_hash::FxHashMap; +use std::fmt::Debug; +use std::iter; +use std::num::{NonZeroU8, ParseFloatError, ParseIntError}; +use winnow::ascii::dec_uint; +use winnow::combinator::*; +use winnow::error::{ErrMode, ErrorKind}; +use winnow::stream::Accumulate; +use winnow::token::any; +use winnow::{ + error::{ContextError, ParserError}, + stream::{Offset, Stream, StreamIsPartial}, + PResult, +}; +use winnow::{prelude::*, Stateful}; + +mod ast; +pub use ast::*; + +impl From<RawMulIntControl> for ast::MulIntControl { + fn from(value: RawMulIntControl) -> Self { + match value { + RawMulIntControl::Lo => ast::MulIntControl::Low, + RawMulIntControl::Hi => ast::MulIntControl::High, + RawMulIntControl::Wide => ast::MulIntControl::Wide, + } + } +} + +impl From<RawStCacheOperator> for ast::StCacheOperator { + fn from(value: RawStCacheOperator) -> Self { + match value { + RawStCacheOperator::Wb => ast::StCacheOperator::Writeback, + RawStCacheOperator::Cg => ast::StCacheOperator::L2Only, + RawStCacheOperator::Cs => ast::StCacheOperator::Streaming, + RawStCacheOperator::Wt => ast::StCacheOperator::Writethrough, + } + } +} + +impl From<RawLdCacheOperator> for ast::LdCacheOperator { + fn from(value: RawLdCacheOperator) -> Self { + match value { + RawLdCacheOperator::Ca => ast::LdCacheOperator::Cached, + RawLdCacheOperator::Cg => ast::LdCacheOperator::L2Only, + RawLdCacheOperator::Cs => ast::LdCacheOperator::Streaming, + RawLdCacheOperator::Lu => ast::LdCacheOperator::LastUse, + RawLdCacheOperator::Cv => ast::LdCacheOperator::Uncached, + } + } +} + +impl From<RawLdStQualifier> for ast::LdStQualifier { + fn from(value: RawLdStQualifier) -> Self { + match value { + RawLdStQualifier::Weak => ast::LdStQualifier::Weak, + RawLdStQualifier::Volatile => ast::LdStQualifier::Volatile, + } + } +} + +impl From<RawRoundingMode> for ast::RoundingMode { + fn from(value: RawRoundingMode) -> Self { + match value { + RawRoundingMode::Rn | RawRoundingMode::Rni => ast::RoundingMode::NearestEven, + RawRoundingMode::Rz | RawRoundingMode::Rzi => ast::RoundingMode::Zero, + RawRoundingMode::Rm | RawRoundingMode::Rmi => ast::RoundingMode::NegativeInf, + RawRoundingMode::Rp | RawRoundingMode::Rpi => ast::RoundingMode::PositiveInf, + } + } +} + +impl VectorPrefix { + pub(crate) fn len(self) -> NonZeroU8 { + unsafe { + match self { + VectorPrefix::V2 => NonZeroU8::new_unchecked(2), + VectorPrefix::V4 => NonZeroU8::new_unchecked(4), + VectorPrefix::V8 => NonZeroU8::new_unchecked(8), + } + } + } +} + +struct PtxParserState<'a, 'input> { + errors: &'a mut Vec<PtxError>, + function_declarations: + FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>, +} + +impl<'a, 'input> PtxParserState<'a, 'input> { + fn new(errors: &'a mut Vec<PtxError>) -> Self { + Self { + errors, + function_declarations: FxHashMap::default(), + } + } + + fn record_function(&mut self, function_decl: &MethodDeclaration<'input, &'input str>) { + let name = match function_decl.name { + MethodName::Kernel(name) => name, + MethodName::Func(name) => name, + }; + let return_arguments = Self::get_type_space(&*function_decl.return_arguments); + let input_arguments = Self::get_type_space(&*function_decl.input_arguments); + // TODO: check if declarations match + self.function_declarations + .insert(name, (return_arguments, input_arguments)); + } + + fn get_type_space(input_arguments: &[Variable<&str>]) -> Vec<(Type, StateSpace)> { + input_arguments + .iter() + .map(|var| (var.v_type.clone(), var.state_space)) + .collect::<Vec<_>>() + } +} + +impl<'a, 'input> Debug for PtxParserState<'a, 'input> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PtxParserState") + .field("errors", &self.errors) /* .field("function_decl", &self.function_decl) */ + .finish() + } +} + +type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'a, 'input>>; + +fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { + any.verify_map(|t| { + if let Token::Ident(text) = t { + Some(text) + } else if let Some(text) = t.opcode_text() { + Some(text) + } else { + None + } + }) + .parse_next(stream) +} + +fn dot_ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { + any.verify_map(|t| { + if let Token::DotIdent(text) = t { + Some(text) + } else { + None + } + }) + .parse_next(stream) +} + +fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, u32, bool)> { + any.verify_map(|t| { + Some(match t { + Token::Hex(s) => { + if s.ends_with('U') { + (&s[2..s.len() - 1], 16, true) + } else { + (&s[2..], 16, false) + } + } + Token::Decimal(s) => { + let radix = if s.starts_with('0') { 8 } else { 10 }; + if s.ends_with('U') { + (&s[..s.len() - 1], radix, true) + } else { + (s, radix, false) + } + } + _ => return None, + }) + }) + .parse_next(stream) +} + +fn take_error<'a, 'input: 'a, O, E>( + mut parser: impl Parser<PtxParser<'a, 'input>, Result<O, (O, PtxError)>, E>, +) -> impl Parser<PtxParser<'a, 'input>, O, E> { + move |input: &mut PtxParser<'a, 'input>| { + Ok(match parser.parse_next(input)? { + Ok(x) => x, + Err((x, err)) => { + input.state.errors.push(err); + x + } + }) + } +} + +fn int_immediate<'a, 'input>(input: &mut PtxParser<'a, 'input>) -> PResult<ast::ImmediateValue> { + take_error((opt(Token::Minus), num).map(|(neg, x)| { + let (num, radix, is_unsigned) = x; + if neg.is_some() { + match i64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::S64(-x)), + Err(err) => Err((ast::ImmediateValue::S64(0), PtxError::from(err))), + } + } else if is_unsigned { + match u64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::U64(x)), + Err(err) => Err((ast::ImmediateValue::U64(0), PtxError::from(err))), + } + } else { + match i64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::S64(x)), + Err(_) => match u64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::U64(x)), + Err(err) => Err((ast::ImmediateValue::U64(0), PtxError::from(err))), + }, + } + } + })) + .parse_next(input) +} + +fn f32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f32> { + take_error(any.verify_map(|t| match t { + Token::F32(f) => Some(match u32::from_str_radix(&f[2..], 16) { + Ok(x) => Ok(f32::from_bits(x)), + Err(err) => Err((0.0, PtxError::from(err))), + }), + _ => None, + })) + .parse_next(stream) +} + +fn f64<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f64> { + take_error(any.verify_map(|t| match t { + Token::F64(f) => Some(match u64::from_str_radix(&f[2..], 16) { + Ok(x) => Ok(f64::from_bits(x)), + Err(err) => Err((0.0, PtxError::from(err))), + }), + _ => None, + })) + .parse_next(stream) +} + +fn s32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<i32> { + take_error((opt(Token::Minus), num).map(|(sign, x)| { + let (text, radix, _) = x; + match i32::from_str_radix(text, radix) { + Ok(x) => Ok(if sign.is_some() { -x } else { x }), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + +fn u8<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u8> { + take_error(num.map(|x| { + let (text, radix, _) = x; + match u8::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + +fn u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u32> { + take_error(num.map(|x| { + let (text, radix, _) = x; + match u32::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + +fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::ImmediateValue> { + alt(( + int_immediate, + f32.map(ast::ImmediateValue::F32), + f64.map(ast::ImmediateValue::F64), + )) + .parse_next(stream) +} + +pub fn parse_module_unchecked<'input>(text: &'input str) -> Option<ast::Module<'input>> { + let lexer = Token::lexer(text); + let input = lexer.collect::<Result<Vec<_>, _>>().ok()?; + let mut errors = Vec::new(); + let state = PtxParserState::new(&mut errors); + let parser = PtxParser { + state, + input: &input[..], + }; + let parsing_result = module.parse(parser).ok(); + if !errors.is_empty() { + None + } else { + parsing_result + } +} + +pub fn parse_module_checked<'input>( + text: &'input str, +) -> Result<ast::Module<'input>, Vec<PtxError>> { + let mut lexer = Token::lexer(text); + let mut errors = Vec::new(); + let mut tokens = Vec::new(); + loop { + let maybe_token = match lexer.next() { + Some(maybe_token) => maybe_token, + None => break, + }; + match maybe_token { + Ok(token) => tokens.push(token), + Err(mut err) => { + err.0 = lexer.span(); + errors.push(PtxError::from(err)) + } + } + } + if !errors.is_empty() { + return Err(errors); + } + let parse_result = { + let state = PtxParserState::new(&mut errors); + let parser = PtxParser { + state, + input: &tokens[..], + }; + module + .parse(parser) + .map_err(|err| PtxError::Parser(err.into_inner())) + }; + match parse_result { + Ok(result) if errors.is_empty() => Ok(result), + Ok(_) => Err(errors), + Err(err) => { + errors.push(err); + Err(errors) + } + } +} + +fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module<'input>> { + ( + version, + target, + opt(address_size), + repeat_without_none(directive), + eof, + ) + .map(|(version, _, _, directives, _)| ast::Module { + version, + directives, + }) + .parse_next(stream) +} + +fn address_size<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + (Token::DotAddressSize, u8_literal(64)) + .void() + .parse_next(stream) +} + +fn version<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u8, u8)> { + (Token::DotVersion, u8, Token::Dot, u8) + .map(|(_, major, _, minor)| (major, minor)) + .parse_next(stream) +} + +fn target<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32, Option<char>)> { + preceded(Token::DotTarget, ident.and_then(shader_model)).parse_next(stream) +} + +fn shader_model<'a>(stream: &mut &str) -> PResult<(u32, Option<char>)> { + ( + "sm_", + dec_uint, + opt(any.verify(|c: &char| c.is_ascii_lowercase())), + eof, + ) + .map(|(_, digits, arch_variant, _)| (digits, arch_variant)) + .parse_next(stream) +} + +fn directive<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<Option<ast::Directive<'input, ast::ParsedOperand<&'input str>>>> { + alt(( + function.map(|(linking, func)| Some(ast::Directive::Method(linking, func))), + file.map(|_| None), + section.map(|_| None), + (module_variable, Token::Semicolon) + .map(|((linking, var), _)| Some(ast::Directive::Variable(linking, var))), + )) + .parse_next(stream) +} + +fn module_variable<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<(ast::LinkingDirective, ast::Variable<&'input str>)> { + let linking = linking_directives.parse_next(stream)?; + let var = global_space + .flat_map(|space| multi_variable(linking.contains(LinkingDirective::EXTERN), space)) + // TODO: support multi var in globals + .map(|multi_var| multi_var.var) + .parse_next(stream)?; + Ok((linking, var)) +} + +fn file<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotFile, + u32, + Token::String, + opt((Token::Comma, u32, Token::Comma, u32)), + ) + .void() + .parse_next(stream) +} + +fn section<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotSection.void(), + dot_ident.void(), + Token::LBrace.void(), + repeat::<_, _, (), _, _>(0.., section_dwarf_line), + Token::RBrace.void(), + ) + .void() + .parse_next(stream) +} + +fn section_dwarf_line<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + alt(( + (section_label, Token::Colon).void(), + (Token::DotB32, section_label, opt((Token::Add, u32))).void(), + (Token::DotB64, section_label, opt((Token::Add, u32))).void(), + ( + any_bit_type, + separated::<_, _, (), _, _, _, _>(1.., u32, Token::Comma), + ) + .void(), + )) + .parse_next(stream) +} + +fn any_bit_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + alt((Token::DotB8, Token::DotB16, Token::DotB32, Token::DotB64)) + .void() + .parse_next(stream) +} + +fn section_label<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + alt((ident, dot_ident)).void().parse_next(stream) +} + +fn function<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<( + ast::LinkingDirective, + ast::Function<'input, &'input str, ast::Statement<ParsedOperand<&'input str>>>, +)> { + let (linking, function) = ( + linking_directives, + method_declaration, + repeat(0.., tuning_directive), + function_body, + ) + .map(|(linking, func_directive, tuning, body)| { + ( + linking, + ast::Function { + func_directive, + tuning, + body, + }, + ) + }) + .parse_next(stream)?; + stream.state.record_function(&function.func_directive); + Ok((linking, function)) +} + +fn linking_directives<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::LinkingDirective> { + repeat( + 0.., + dispatch! { any; + Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN), + Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE), + Token::DotWeak => empty.value(ast::LinkingDirective::WEAK), + _ => fail + }, + ) + .fold(|| ast::LinkingDirective::NONE, |x, y| x | y) + .parse_next(stream) +} + +fn tuning_directive<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::TuningDirective> { + dispatch! {any; + Token::DotMaxnreg => u32.map(ast::TuningDirective::MaxNReg), + Token::DotMaxntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)), + Token::DotReqntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)), + Token::DotMinnctapersm => u32.map(ast::TuningDirective::MinNCtaPerSm), + _ => fail + } + .parse_next(stream) +} + +fn method_declaration<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::MethodDeclaration<'input, &'input str>> { + dispatch! {any; + Token::DotEntry => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{ + return_arguments: Vec::new(), name: ast::MethodName::Kernel(name), input_arguments, shared_mem: None + }), + Token::DotFunc => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| { + let return_arguments = return_arguments.unwrap_or_else(|| Vec::new()); + let name = ast::MethodName::Func(name); + ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } + }), + _ => fail + } + .parse_next(stream) +} + +fn fn_arguments<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<Vec<ast::Variable<&'input str>>> { + delimited( + Token::LParen, + separated(0.., fn_input, Token::Comma), + Token::RParen, + ) + .parse_next(stream) +} + +fn kernel_arguments<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<Vec<ast::Variable<&'input str>>> { + delimited( + Token::LParen, + separated(0.., kernel_input, Token::Comma), + Token::RParen, + ) + .parse_next(stream) +} + +fn kernel_input<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::Variable<&'input str>> { + preceded(Token::DotParam, method_parameter(StateSpace::Param)).parse_next(stream) +} + +fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Variable<&'input str>> { + dispatch! { any; + Token::DotParam => method_parameter(StateSpace::Param), + Token::DotReg => method_parameter(StateSpace::Reg), + _ => fail + } + .parse_next(stream) +} + +fn tuple1to3_u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32, u32, u32)> { + struct Tuple3AccumulateU32 { + index: usize, + value: (u32, u32, u32), + } + + impl Accumulate<u32> for Tuple3AccumulateU32 { + fn initial(_: Option<usize>) -> Self { + Self { + index: 0, + value: (1, 1, 1), + } + } + + fn accumulate(&mut self, value: u32) { + match self.index { + 0 => { + self.value = (value, self.value.1, self.value.2); + self.index = 1; + } + 1 => { + self.value = (self.value.0, value, self.value.2); + self.index = 2; + } + 2 => { + self.value = (self.value.0, self.value.1, value); + self.index = 3; + } + _ => unreachable!(), + } + } + } + + separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..=3, u32, Token::Comma) + .map(|acc| acc.value) + .parse_next(stream) +} + +fn function_body<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<Option<Vec<ast::Statement<ParsedOperandStr<'input>>>>> { + dispatch! {any; + Token::LBrace => terminated(repeat_without_none(statement), Token::RBrace).map(Some), + Token::Semicolon => empty.map(|_| None), + _ => fail + } + .parse_next(stream) +} + +fn statement<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<Option<Statement<ParsedOperandStr<'input>>>> { + alt(( + label.map(Some), + debug_directive.map(|_| None), + terminated( + method_space + .flat_map(|space| multi_variable(false, space)) + .map(|var| Some(Statement::Variable(var))), + Token::Semicolon, + ), + predicated_instruction.map(Some), + pragma.map(|_| None), + block_statement.map(Some), + )) + .parse_next(stream) +} + +fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + (Token::DotPragma, Token::String, Token::Semicolon) + .void() + .parse_next(stream) +} + +fn method_parameter<'a, 'input: 'a>( + state_space: StateSpace, +) -> impl Parser<PtxParser<'a, 'input>, Variable<&'input str>, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + let (align, vector, type_, name) = variable_declaration.parse_next(stream)?; + let array_dimensions = if state_space != StateSpace::Reg { + opt(array_dimensions).parse_next(stream)? + } else { + None + }; + // TODO: push this check into array_dimensions(...) + if let Some(ref dims) = array_dimensions { + if dims[0] == 0 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + } + Ok(Variable { + align, + v_type: Type::maybe_array(vector, type_, array_dimensions), + state_space, + name, + array_init: Vec::new(), + }) + } +} + +// TODO: split to a separate type +fn variable_declaration<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<(Option<u32>, Option<NonZeroU8>, ScalarType, &'input str)> { + ( + opt(align.verify(|x| x.count_ones() == 1)), + vector_prefix, + scalar_type, + ident, + ) + .parse_next(stream) +} + +fn multi_variable<'a, 'input: 'a>( + extern_: bool, + state_space: StateSpace, +) -> impl Parser<PtxParser<'a, 'input>, MultiVariable<&'input str>, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + let ((align, vector, type_, name), count) = ( + variable_declaration, + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names + opt(delimited(Token::Lt, u32.verify(|x| *x != 0), Token::Gt)), + ) + .parse_next(stream)?; + if count.is_some() { + return Ok(MultiVariable { + var: Variable { + align, + v_type: Type::maybe_vector_parsed(vector, type_), + state_space, + name, + array_init: Vec::new(), + }, + count, + }); + } + let mut array_dimensions = if state_space != StateSpace::Reg { + opt(array_dimensions).parse_next(stream)? + } else { + None + }; + let initializer = match state_space { + StateSpace::Global | StateSpace::Const => match array_dimensions { + Some(ref mut dimensions) => { + opt(array_initializer(vector, type_, dimensions)).parse_next(stream)? + } + None => opt(value_initializer(vector, type_)).parse_next(stream)?, + }, + _ => None, + }; + if let Some(ref dims) = array_dimensions { + if !extern_ && dims[0] == 0 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + } + Ok(MultiVariable { + var: Variable { + align, + v_type: Type::maybe_array(vector, type_, array_dimensions), + state_space, + name, + array_init: initializer.unwrap_or(Vec::new()), + }, + count, + }) + } +} + +fn array_initializer<'a, 'input: 'a>( + vector: Option<NonZeroU8>, + type_: ScalarType, + array_dimensions: &mut Vec<u32>, +) -> impl Parser<PtxParser<'a, 'input>, Vec<u8>, ContextError> + '_ { + move |stream: &mut PtxParser<'a, 'input>| { + Token::Eq.parse_next(stream)?; + let mut result = Vec::new(); + // TODO: vector constants and multi dim arrays + if vector.is_some() || array_dimensions[0] == 0 || array_dimensions.len() > 1 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + delimited( + Token::LBrace, + separated( + 0..=array_dimensions[0] as usize, + single_value_append(&mut result, type_), + Token::Comma, + ), + Token::RBrace, + ) + .parse_next(stream)?; + // pad with zeros + let result_size = type_.size_of() as usize * array_dimensions[0] as usize; + result.extend(iter::repeat(0u8).take(result_size - result.len())); + Ok(result) + } +} + +fn value_initializer<'a, 'input: 'a>( + vector: Option<NonZeroU8>, + type_: ScalarType, +) -> impl Parser<PtxParser<'a, 'input>, Vec<u8>, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + Token::Eq.parse_next(stream)?; + let mut result = Vec::new(); + // TODO: vector constants + if vector.is_some() { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + single_value_append(&mut result, type_).parse_next(stream)?; + Ok(result) + } +} + +fn single_value_append<'a, 'input: 'a>( + accumulator: &mut Vec<u8>, + type_: ScalarType, +) -> impl Parser<PtxParser<'a, 'input>, (), ContextError> + '_ { + move |stream: &mut PtxParser<'a, 'input>| { + let value = immediate_value.parse_next(stream)?; + match (type_, value) { + (ScalarType::U8 | ScalarType::B8, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u8).to_le_bytes()) + } + (ScalarType::U8 | ScalarType::B8, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u8).to_le_bytes()) + } + (ScalarType::U16 | ScalarType::B16, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u16).to_le_bytes()) + } + (ScalarType::U16 | ScalarType::B16, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u16).to_le_bytes()) + } + (ScalarType::U32 | ScalarType::B32, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u32).to_le_bytes()) + } + (ScalarType::U32 | ScalarType::B32, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u32).to_le_bytes()) + } + (ScalarType::U64 | ScalarType::B64, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u64).to_le_bytes()) + } + (ScalarType::U64 | ScalarType::B64, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u64).to_le_bytes()) + } + (ScalarType::S8, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i8).to_le_bytes()) + } + (ScalarType::S8, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i8).to_le_bytes()) + } + (ScalarType::S16, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i16).to_le_bytes()) + } + (ScalarType::S16, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i16).to_le_bytes()) + } + (ScalarType::S32, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i32).to_le_bytes()) + } + (ScalarType::S32, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i32).to_le_bytes()) + } + (ScalarType::S64, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i64).to_le_bytes()) + } + (ScalarType::S64, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i64).to_le_bytes()) + } + (ScalarType::F32, ImmediateValue::F32(x)) => { + accumulator.extend_from_slice(&x.to_le_bytes()) + } + (ScalarType::F64, ImmediateValue::F64(x)) => { + accumulator.extend_from_slice(&x.to_le_bytes()) + } + _ => return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)), + } + Ok(()) + } +} + +fn array_dimensions<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Vec<u32>> { + let dimension = delimited( + Token::LBracket, + opt(u32).verify(|dim| *dim != Some(0)), + Token::RBracket, + ) + .parse_next(stream)?; + let result = vec![dimension.unwrap_or(0)]; + repeat_fold_0_or_more( + delimited( + Token::LBracket, + u32.verify(|dim| *dim != 0), + Token::RBracket, + ), + move || result, + |mut result: Vec<u32>, x| { + result.push(x); + result + }, + stream, + ) +} + +// Copied and fixed from Winnow sources (fold_repeat0_) +// Winnow Repeat::fold takes FnMut() -> Result to initalize accumulator, +// this really should be FnOnce() -> Result +fn repeat_fold_0_or_more<I, O, E, F, G, H, R>( + mut f: F, + init: H, + mut g: G, + input: &mut I, +) -> PResult<R, E> +where + I: Stream, + F: Parser<I, O, E>, + G: FnMut(R, O) -> R, + H: FnOnce() -> R, + E: ParserError<I>, +{ + use winnow::error::ErrMode; + let mut res = init(); + loop { + let start = input.checkpoint(); + match f.parse_next(input) { + Ok(o) => { + res = g(res, o); + } + Err(ErrMode::Backtrack(_)) => { + input.reset(&start); + return Ok(res); + } + Err(e) => { + return Err(e); + } + } + } +} + +fn global_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<StateSpace> { + alt(( + Token::DotGlobal.value(StateSpace::Global), + Token::DotConst.value(StateSpace::Const), + Token::DotShared.value(StateSpace::Shared), + )) + .parse_next(stream) +} + +fn method_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<StateSpace> { + alt(( + Token::DotReg.value(StateSpace::Reg), + Token::DotLocal.value(StateSpace::Local), + Token::DotParam.value(StateSpace::Param), + global_space, + )) + .parse_next(stream) +} + +fn align<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u32> { + preceded(Token::DotAlign, u32).parse_next(stream) +} + +fn vector_prefix<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Option<NonZeroU8>> { + opt(alt(( + Token::DotV2.value(unsafe { NonZeroU8::new_unchecked(2) }), + Token::DotV4.value(unsafe { NonZeroU8::new_unchecked(4) }), + Token::DotV8.value(unsafe { NonZeroU8::new_unchecked(8) }), + ))) + .parse_next(stream) +} + +fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ScalarType> { + any.verify_map(|t| { + Some(match t { + Token::DotS8 => ScalarType::S8, + Token::DotS16 => ScalarType::S16, + Token::DotS16x2 => ScalarType::S16x2, + Token::DotS32 => ScalarType::S32, + Token::DotS64 => ScalarType::S64, + Token::DotU8 => ScalarType::U8, + Token::DotU16 => ScalarType::U16, + Token::DotU16x2 => ScalarType::U16x2, + Token::DotU32 => ScalarType::U32, + Token::DotU64 => ScalarType::U64, + Token::DotB8 => ScalarType::B8, + Token::DotB16 => ScalarType::B16, + Token::DotB32 => ScalarType::B32, + Token::DotB64 => ScalarType::B64, + Token::DotB128 => ScalarType::B128, + Token::DotPred => ScalarType::Pred, + Token::DotF16 => ScalarType::F16, + Token::DotF16x2 => ScalarType::F16x2, + Token::DotF32 => ScalarType::F32, + Token::DotF64 => ScalarType::F64, + Token::DotBF16 => ScalarType::BF16, + Token::DotBF16x2 => ScalarType::BF16x2, + _ => return None, + }) + }) + .parse_next(stream) +} + +fn predicated_instruction<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::Statement<ParsedOperandStr<'input>>> { + (opt(pred_at), parse_instruction, Token::Semicolon) + .map(|(p, i, _)| ast::Statement::Instruction(p, i)) + .parse_next(stream) +} + +fn pred_at<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::PredAt<&'input str>> { + (Token::At, opt(Token::Exclamation), ident) + .map(|(_, not, label)| ast::PredAt { + not: not.is_some(), + label, + }) + .parse_next(stream) +} + +fn label<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::Statement<ParsedOperandStr<'input>>> { + terminated(ident, Token::Colon) + .map(|l| ast::Statement::Label(l)) + .parse_next(stream) +} + +fn debug_directive<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotLoc, + u32, + u32, + u32, + opt(( + Token::Comma, + ident_literal("function_name"), + ident, + dispatch! { any; + Token::Comma => (ident_literal("inlined_at"), u32, u32, u32).void(), + Token::Plus => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(), + _ => fail + }, + )), + ) + .void() + .parse_next(stream) +} + +fn block_statement<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::Statement<ParsedOperandStr<'input>>> { + delimited(Token::LBrace, repeat_without_none(statement), Token::RBrace) + .map(|s| ast::Statement::Block(s)) + .parse_next(stream) +} + +fn repeat_without_none<Input: Stream, Output, Error: ParserError<Input>>( + parser: impl Parser<Input, Option<Output>, Error>, +) -> impl Parser<Input, Vec<Output>, Error> { + repeat(0.., parser).fold(Vec::new, |mut acc: Vec<_>, item| { + if let Some(item) = item { + acc.push(item); + } + acc + }) +} + +fn ident_literal< + 'a, + 'input, + I: Stream<Token = Token<'input>> + StreamIsPartial, + E: ParserError<I>, +>( + s: &'input str, +) -> impl Parser<I, (), E> + 'input { + move |stream: &mut I| { + any.verify(|t| matches!(t, Token::Ident(text) if *text == s)) + .void() + .parse_next(stream) + } +} + +fn u8_literal<'a, 'input>(x: u8) -> impl Parser<PtxParser<'a, 'input>, (), ContextError> { + move |stream: &mut PtxParser| u8.verify(|t| *t == x).void().parse_next(stream) +} + +impl<Ident> ast::ParsedOperand<Ident> { + fn parse<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, + ) -> PResult<ast::ParsedOperand<&'input str>> { + use winnow::combinator::*; + use winnow::token::any; + fn vector_index<'input>(inp: &'input str) -> Result<u8, PtxError> { + match inp { + ".x" | ".r" => Ok(0), + ".y" | ".g" => Ok(1), + ".z" | ".b" => Ok(2), + ".w" | ".a" => Ok(3), + _ => Err(PtxError::WrongVectorElement), + } + } + fn ident_operands<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, + ) -> PResult<ast::ParsedOperand<&'input str>> { + let main_ident = ident.parse_next(stream)?; + alt(( + preceded(Token::Plus, s32) + .map(move |offset| ast::ParsedOperand::RegOffset(main_ident, offset)), + take_error(dot_ident.map(move |suffix| { + let vector_index = vector_index(suffix) + .map_err(move |e| (ast::ParsedOperand::VecMember(main_ident, 0), e))?; + Ok(ast::ParsedOperand::VecMember(main_ident, vector_index)) + })), + empty.value(ast::ParsedOperand::Reg(main_ident)), + )) + .parse_next(stream) + } + fn vector_operand<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, + ) -> PResult<Vec<&'input str>> { + let (_, r1, _, r2) = (Token::LBrace, ident, Token::Comma, ident).parse_next(stream)?; + // TODO: parse .v8 literals + dispatch! {any; + Token::RBrace => empty.map(|_| vec![r1, r2]), + Token::Comma => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), + _ => fail + } + .parse_next(stream) + } + alt(( + ident_operands, + immediate_value.map(ast::ParsedOperand::Imm), + vector_operand.map(ast::ParsedOperand::VecPack), + )) + .parse_next(stream) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum PtxError { + #[error("{source}")] + ParseInt { + #[from] + source: ParseIntError, + }, + #[error("{source}")] + ParseFloat { + #[from] + source: ParseFloatError, + }, + #[error("{source}")] + Lexer { + #[from] + source: TokenError, + }, + #[error("")] + Parser(ContextError), + #[error("")] + Todo, + #[error("")] + SyntaxError, + #[error("")] + NonF32Ftz, + #[error("")] + Unsupported32Bit, + #[error("")] + WrongType, + #[error("")] + UnknownFunction, + #[error("")] + MalformedCall, + #[error("")] + WrongArrayType, + #[error("")] + WrongVectorElement, + #[error("")] + MultiArrayVariable, + #[error("")] + ZeroDimensionArray, + #[error("")] + ArrayInitalizer, + #[error("")] + NonExternPointer, + #[error("{start}:{end}")] + UnrecognizedStatement { start: usize, end: usize }, + #[error("{start}:{end}")] + UnrecognizedDirective { start: usize, end: usize }, +} + +#[derive(Debug)] +struct ReverseStream<'a, T>(pub &'a [T]); + +impl<'i, T> Stream for ReverseStream<'i, T> +where + T: Clone + ::std::fmt::Debug, +{ + type Token = T; + type Slice = &'i [T]; + + type IterOffsets = + std::iter::Enumerate<std::iter::Cloned<std::iter::Rev<std::slice::Iter<'i, T>>>>; + + type Checkpoint = &'i [T]; + + fn iter_offsets(&self) -> Self::IterOffsets { + self.0.iter().rev().cloned().enumerate() + } + + fn eof_offset(&self) -> usize { + self.0.len() + } + + fn next_token(&mut self) -> Option<Self::Token> { + let (token, next) = self.0.split_last()?; + self.0 = next; + Some(token.clone()) + } + + fn offset_for<P>(&self, predicate: P) -> Option<usize> + where + P: Fn(Self::Token) -> bool, + { + self.0.iter().rev().position(|b| predicate(b.clone())) + } + + fn offset_at(&self, tokens: usize) -> Result<usize, winnow::error::Needed> { + if let Some(needed) = tokens + .checked_sub(self.0.len()) + .and_then(std::num::NonZeroUsize::new) + { + Err(winnow::error::Needed::Size(needed)) + } else { + Ok(tokens) + } + } + + fn next_slice(&mut self, offset: usize) -> Self::Slice { + let offset = self.0.len() - offset; + let (next, slice) = self.0.split_at(offset); + self.0 = next; + slice + } + + fn checkpoint(&self) -> Self::Checkpoint { + self.0 + } + + fn reset(&mut self, checkpoint: &Self::Checkpoint) { + self.0 = checkpoint; + } + + fn raw(&self) -> &dyn std::fmt::Debug { + self + } +} + +impl<'a, T> Offset<&'a [T]> for ReverseStream<'a, T> { + fn offset_from(&self, start: &&'a [T]) -> usize { + let fst = start.as_ptr(); + let snd = self.0.as_ptr(); + + debug_assert!( + snd <= fst, + "`Offset::offset_from({snd:?}, {fst:?})` only accepts slices of `self`" + ); + (fst as usize - snd as usize) / std::mem::size_of::<T>() + } +} + +impl<'a, T> StreamIsPartial for ReverseStream<'a, T> { + type PartialState = (); + + fn complete(&mut self) -> Self::PartialState {} + + fn restore_partial(&mut self, _state: Self::PartialState) {} + + fn is_partial_supported() -> bool { + false + } +} + +impl<'input, I: Stream<Token = Self> + StreamIsPartial, E: ParserError<I>> Parser<I, Self, E> + for Token<'input> +{ + fn parse_next(&mut self, input: &mut I) -> PResult<Self, E> { + any.verify(|t| t == self).parse_next(input) + } +} + +fn bra<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::Instruction<ParsedOperandStr<'input>>> { + preceded( + opt(Token::DotUni), + any.verify_map(|t| match t { + Token::Ident(ident) => Some(ast::Instruction::Bra { + arguments: BraArgs { src: ident }, + }), + _ => None, + }), + ) + .parse_next(stream) +} + +fn call<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::Instruction<ParsedOperandStr<'input>>> { + let (uni, return_arguments, name, input_arguments) = ( + opt(Token::DotUni), + opt(( + Token::LParen, + separated(1.., ident, Token::Comma).map(|x: Vec<_>| x), + Token::RParen, + Token::Comma, + ) + .map(|(_, arguments, _, _)| arguments)), + ident, + opt(( + Token::Comma.void(), + Token::LParen.void(), + separated(1.., ParsedOperand::<&'input str>::parse, Token::Comma).map(|x: Vec<_>| x), + Token::RParen.void(), + ) + .map(|(_, _, arguments, _)| arguments)), + ) + .parse_next(stream)?; + let uniform = uni.is_some(); + let recorded_fn = match stream.state.function_declarations.get(name) { + Some(decl) => decl, + None => { + stream.state.errors.push(PtxError::UnknownFunction); + return Ok(empty_call(uniform, name)); + } + }; + let return_arguments = return_arguments.unwrap_or(Vec::new()); + let input_arguments = input_arguments.unwrap_or(Vec::new()); + if recorded_fn.0.len() != return_arguments.len() || recorded_fn.1.len() != input_arguments.len() + { + stream.state.errors.push(PtxError::MalformedCall); + return Ok(empty_call(uniform, name)); + } + let data = CallDetails { + uniform, + return_arguments: recorded_fn.0.clone(), + input_arguments: recorded_fn.1.clone(), + }; + let arguments = CallArgs { + return_arguments, + func: name, + input_arguments, + }; + Ok(ast::Instruction::Call { data, arguments }) +} + +fn empty_call<'input>( + uniform: bool, + name: &'input str, +) -> ast::Instruction<ParsedOperandStr<'input>> { + ast::Instruction::Call { + data: CallDetails { + uniform, + return_arguments: Vec::new(), + input_arguments: Vec::new(), + }, + arguments: CallArgs { + return_arguments: Vec::new(), + func: name, + input_arguments: Vec::new(), + }, + } +} + +type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>; + +#[derive(Clone, PartialEq, Default, Debug, Display)] +#[display("({}:{})", _0.start, _0.end)] +pub struct TokenError(std::ops::Range<usize>); + +impl std::error::Error for TokenError {} + +// This macro is responsible for generating parser code for instruction parser. +// Instruction parsing is by far the most complex part of parsing PTX code: +// * There are tens of instruction kinds, each with slightly different parsing rules +// * After parsing, each instruction needs to do some early validation and generate a specific, +// strongly-typed object. We want strong-typing because we have a single PTX parser frontend, but +// there can be multiple different code emitter backends +// * Most importantly, instruction modifiers can come in aby order, so e.g. both +// `ld.relaxed.global.u32 a, b` and `ld.global.relaxed.u32 a, b` are equally valid. This makes +// classic parsing generators fail: if we tried to generate parsing rules that cover every possible +// ordering we'd need thousands of rules. This is not a purely theoretical problem. NVCC and Clang +// will always emit modifiers in the correct order, but people who write inline assembly usually +// get it wrong (even first party developers) +// +// This macro exists purely to generate repetitive code for parsing each instruction. It is +// _not_ self-contained and is _not_ general-purpose: it relies on certain types and functions from +// the enclosing module +// +// derive_parser!(...) input is split into three parts: +// * Token type definition +// * Partial enums +// * Parsing definitions +// +// Token type definition: +// This is the enum type that will be usesby the instruction parser. For every instruction and +// modifier, derive_parser!(...) will add appropriate variant into this type. So e.g. if there is a +// rule for for `bar.sync` then those two variants wil be appended to the Token enum: +// #[token("bar")] Bar, +// #[token(".sync")] DotSync, +// +// Partial enums: +// With proper annotations, derive_parser!(...) parsing definitions are able to interpret +// instruction modifiers as variants of a single enum type. So e.g. for definitions `ld.u32` and +// `ld.u64` the macro can generate `enum ScalarType { U32, U64 }`. The problem is that for some +// (but not all) of those generated enum types we want to add some attributes and additional +// variants. In order to do so, you need to define this enum and derive_parser!(...) will append to +// the type instead of creating a new type. This is sort of replacement for partial classes known +// from C# +// +// Parsing definitions: +// Parsing definitions consist of a list of patterns and rules: +// * Pattern consists of: +// * Opcode: `ld` +// * Modifiers, always start with a dot: `.global`, `.relaxed`. Optionals are enclosed in braces +// * Arguments: `a`, `b`. Optionals are enclosed in braces +// * Code block: => { <code expression> }. Code blocks implictly take all modifiers ansd arguments +// as parameters. All modifiers and arguments are passed to the code block: +// * If it is an alternative (as defined in rules list later): +// * If it is mandatory then its type is Foo (as defined by the relevant rule) +// * If it is optional then its type is Option<Foo> +// * Otherwise: +// * If it is mandatory then it is skipped +// * If it is optional then its type is `bool` +// * List of rules. They are associated with the preceding patterns (until different opcode or +// different rules). Rules are used to resolve modifiers. There are two types of rules: +// * Normal rule: `.foobar: FoobarEnum => { .a, .b, .c }`. This means that instead of `.foobar` we +// expecte one of `.a`, `.b`, `.c` and will emit value FoobarEnum::DotA, FoobarEnum::DotB, +// FoobarEnum::DotC appropriately +// * Type-only rule: `FoobarEnum => { .a, .b, .c }` this means that all the occurences of `.a` will +// emit FoobarEnum::DotA to the code block. This helps to avoid copy-paste errors +// Additionally, you can opt out from the usual parsing rule generation with a special `<=` pattern. +// See `call` instruction to see it in action +derive_parser!( + #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)] + #[logos(skip r"(?:\s+)|(?://[^\n\r]*[\n\r]*)|(?:/\*[^*]*\*+(?:[^/*][^*]*\*+)*/)")] + #[logos(error = TokenError)] + enum Token<'input> { + #[token(",")] + Comma, + #[token(".")] + Dot, + #[token(":")] + Colon, + #[token(";")] + Semicolon, + #[token("@")] + At, + #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] + Ident(&'input str), + #[regex(r"\.[a-zA-Z][a-zA-Z0-9_$]*|\.[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] + DotIdent(&'input str), + #[regex(r#""[^"]*""#)] + String, + #[token("|")] + Pipe, + #[token("!")] + Exclamation, + #[token("(")] + LParen, + #[token(")")] + RParen, + #[token("[")] + LBracket, + #[token("]")] + RBracket, + #[token("{")] + LBrace, + #[token("}")] + RBrace, + #[token("<")] + Lt, + #[token(">")] + Gt, + #[regex(r"0[fF][0-9a-zA-Z]{8}", |lex| lex.slice())] + F32(&'input str), + #[regex(r"0[dD][0-9a-zA-Z]{16}", |lex| lex.slice())] + F64(&'input str), + #[regex(r"0[xX][0-9a-zA-Z]+U?", |lex| lex.slice())] + Hex(&'input str), + #[regex(r"[0-9]+U?", |lex| lex.slice())] + Decimal(&'input str), + #[token("-")] + Minus, + #[token("+")] + Plus, + #[token("=")] + Eq, + #[token(".version")] + DotVersion, + #[token(".loc")] + DotLoc, + #[token(".reg")] + DotReg, + #[token(".align")] + DotAlign, + #[token(".pragma")] + DotPragma, + #[token(".maxnreg")] + DotMaxnreg, + #[token(".maxntid")] + DotMaxntid, + #[token(".reqntid")] + DotReqntid, + #[token(".minnctapersm")] + DotMinnctapersm, + #[token(".entry")] + DotEntry, + #[token(".func")] + DotFunc, + #[token(".extern")] + DotExtern, + #[token(".visible")] + DotVisible, + #[token(".target")] + DotTarget, + #[token(".address_size")] + DotAddressSize, + #[token(".action")] + DotSection, + #[token(".file")] + DotFile + } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum StateSpace { + Reg, + Generic, + Sreg, + } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum MemScope { } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum ScalarType { } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum SetpBoolPostOp { } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum AtomSemantics { } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov + mov{.vec}.type d, a => { + Instruction::Mov { + data: ast::MovDetails::new(vec, type_), + arguments: MovArgs { dst: d, src: a }, + } + } + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .pred, + .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st + st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::St { + data: StData { + qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: cop.unwrap_or(RawStCacheOperator::Wb).into(), + typ: ast::Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } + } + st.volatile{.ss}{.vec}.type [a], b => { + Instruction::St { + data: StData { + qualifier: volatile.into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: ast::Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } + } + st.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::St { + data: StData { + qualifier: ast::LdStQualifier::Relaxed(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: ast::Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } + } + st.release.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::St { + data: StData { + qualifier: ast::LdStQualifier::Release(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: ast::Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } + } + st.mmio.relaxed.sys{.global}.type [a], b => { + state.errors.push(PtxError::Todo); + Instruction::St { + data: ast::StData { + qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), + state_space: global.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: type_.into() + }, + arguments: ast::StArgs { src1:a, src2:b } + } + } + .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} }; + .level::eviction_priority: EvictionPriority = + { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate }; + .level::cache_hint = { .L2::cache_hint }; + .cop: RawStCacheOperator = { .wb, .cg, .cs, .wt }; + .scope: MemScope = { .cta, .cluster, .gpu, .sys }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .b8, .b16, .b32, .b64, .b128, + .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .f32, .f64 }; + RawLdStQualifier = { .weak, .volatile }; + StateSpace = { .global }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld + ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => { + let (a, unified) = a; + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || unified || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(), + typ: ast::Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + ld.volatile{.ss}{.level::prefetch_size}{.vec}.type d, [a] => { + if level_prefetch_size.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: volatile.into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: ast::Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + ld.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Relaxed(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: ast::Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + ld.acquire.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Acquire(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: ast::Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + ld.mmio.relaxed.sys{.global}.type d, [a] => { + state.errors.push(PtxError::Todo); + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), + state_space: global.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: type_.into(), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + .ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} }; + .cop: RawLdCacheOperator = { .ca, .cg, .cs, .lu, .cv }; + .level::eviction_priority: EvictionPriority = + { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate }; + .level::cache_hint = { .L2::cache_hint }; + .level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B }; + .scope: MemScope = { .cta, .cluster, .gpu, .sys }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .b8, .b16, .b32, .b64, .b128, + .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .f32, .f64 }; + RawLdStQualifier = { .weak, .volatile }; + StateSpace = { .global }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld-global-nc + ld.global{.cop}.nc{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { + if cop.is_some() && level_eviction_priority.is_some() { + state.errors.push(PtxError::SyntaxError); + } + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Weak, + state_space: global, + caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(), + typ: Type::maybe_vector(vec, type_), + non_coherent: true + }, + arguments: LdArgs { dst:d, src:a } + } + } + .cop: RawLdCacheOperator = { .ca, .cg, .cs }; + .level::eviction_priority: EvictionPriority = + { .L1::evict_normal, .L1::evict_unchanged, + .L1::evict_first, .L1::evict_last, .L1::no_allocate}; + .level::cache_hint = { .L2::cache_hint }; + .level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .b8, .b16, .b32, .b64, .b128, + .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .f32, .f64 }; + StateSpace = { .global }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-add + add.type d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Integer( + ast::ArithInteger { + type_, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.sat}.s32 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Integer( + ast::ArithInteger { + type_: s32, + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s64, + .u16x2, .s16x2 }; + ScalarType = { .s32 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add + add{.rnd}{.ftz}{.sat}.f32 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.rnd}.f64 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f64, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add + add{.rnd}{.ftz}{.sat}.f16 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16x2, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.rnd}.bf16 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.rnd}.bf16x2 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16x2, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul + mul.mode.type d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Integer { + type_, + control: mode.into() + }, + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .mode: RawMulIntControl = { .hi, .lo }; + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + // "The .wide suffix is supported only for 16- and 32-bit integer types" + mul.wide.type d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Integer { + type_, + control: wide.into() + }, + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .u16, .u32, + .s16, .s32 }; + RawMulIntControl = { .wide }; + + mul{.rnd}{.ftz}{.sat}.f32 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: f32, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}.f64 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: f64, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + mul{.rnd}{.ftz}{.sat}.f16 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: f16, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: f16x2, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}.bf16 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: bf16, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}.bf16x2 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: bf16x2, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp + setp.CmpOp{.ftz}.type p[|q], a, b => { + let data = ast::SetpData::try_parse(state, cmpop, ftz, type_); + ast::Instruction::Setp { + data, + arguments: SetpArgs { dst1: p, dst2: q, src1: a, src2: b } + } + } + setp.CmpOp.BoolOp{.ftz}.type p[|q], a, b, {!}c => { + let (negate_src3, c) = c; + let base = ast::SetpData::try_parse(state, cmpop, ftz, type_); + let data = ast::SetpBoolData { + base, + bool_op: boolop, + negate_src3 + }; + ast::Instruction::SetpBool { + data, + arguments: SetpBoolArgs { dst1: p, dst2: q, src1: a, src2: b, src3: c } + } + } + .CmpOp: RawSetpCompareOp = { .eq, .ne, .lt, .le, .gt, .ge, + .lo, .ls, .hi, .hs, // signed + .equ, .neu, .ltu, .leu, .gtu, .geu, .num, .nan }; // float-only + .BoolOp: SetpBoolPostOp = { .and, .or, .xor }; + .type: ScalarType = { .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64, + .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not + not.type d, a => { + ast::Instruction::Not { + data: type_, + arguments: NotArgs { dst: d, src: a } + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-or + or.type d, a, b => { + ast::Instruction::Or { + data: type_, + arguments: OrArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-and + and.type d, a, b => { + ast::Instruction::And { + data: type_, + arguments: AndArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra + bra <= { bra(stream) } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call + call <= { call(stream) } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt + cvt{.ifrnd}{.ftz}{.sat}.dtype.atype d, a => { + let data = ast::CvtDetails::new(&mut state.errors, ifrnd, ftz, sat, dtype, atype); + let arguments = ast::CvtArgs { dst: d, src: a }; + ast::Instruction::Cvt { + data, arguments + } + } + // cvt.frnd2{.relu}{.satfinite}.f16.f32 d, a; + // cvt.frnd2{.relu}{.satfinite}.f16x2.f32 d, a, b; + // cvt.frnd2{.relu}{.satfinite}.bf16.f32 d, a; + // cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b; + // cvt.rna{.satfinite}.tf32.f32 d, a; + // cvt.frnd2{.relu}.tf32.f32 d, a; + // cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b; + // cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a; + // cvt.rn.{.relu}.f16x2.f8x2type d, a; + + .ifrnd: RawRoundingMode = { .rn, .rz, .rm, .rp, .rni, .rzi, .rmi, .rpi }; + .frnd2: RawRoundingMode = { .rn, .rz }; + .dtype: ScalarType = { .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .bf16, .f16, .f32, .f64 }; + .atype: ScalarType = { .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .bf16, .f16, .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl + shl.type d, a, b => { + ast::Instruction::Shl { data: type_, arguments: ShlArgs { dst: d, src1: a, src2: b } } + } + .type: ScalarType = { .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr + shr.type d, a, b => { + let kind = if type_.kind() == ast::ScalarKind::Signed { RightShiftKind::Arithmetic} else { RightShiftKind::Logical }; + ast::Instruction::Shr { + data: ast::ShrData { type_, kind }, + arguments: ShrArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvta + cvta.space.size p, a => { + if size != ScalarType::U64 { + state.errors.push(PtxError::Unsupported32Bit); + } + let data = ast::CvtaDetails { + state_space: space, + direction: ast::CvtaDirection::ExplicitToGeneric + }; + let arguments = ast::CvtaArgs { + dst: p, src: a + }; + ast::Instruction::Cvta { + data, arguments + } + } + cvta.to.space.size p, a => { + if size != ScalarType::U64 { + state.errors.push(PtxError::Unsupported32Bit); + } + let data = ast::CvtaDetails { + state_space: space, + direction: ast::CvtaDirection::GenericToExplicit + }; + let arguments = ast::CvtaArgs { + dst: p, src: a + }; + ast::Instruction::Cvta { + data, arguments + } + } + .space: StateSpace = { .const, .global, .local, .shared{::cta, ::cluster}, .param{::entry} }; + .size: ScalarType = { .u32, .u64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-abs + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-abs + abs.type d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: None, + type_ + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f32 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.f64 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: None, + type_: f64 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f16 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: Some(ftz), + type_: f16 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f16x2 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: Some(ftz), + type_: f16x2 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.bf16 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: None, + type_: bf16 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.bf16x2 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: None, + type_: bf16x2 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + .type: ScalarType = { .s16, .s32, .s64 }; + ScalarType = { .f32, .f64, .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad + mad.mode.type d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_, + control: mode.into(), + saturate: false + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + .mode: RawMulIntControl = { .hi, .lo }; + + // The .wide suffix is supported only for 16-bit and 32-bit integer types. + mad.wide.type d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_, + control: wide.into(), + saturate: false + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .type: ScalarType = { .u16, .u32, + .s16, .s32 }; + RawMulIntControl = { .wide }; + + mad.hi.sat.s32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_: s32, + control: hi.into(), + saturate: true + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + RawMulIntControl = { .hi }; + ScalarType = { .s32 }; + + mad{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: None, + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + mad.rnd{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + mad.rnd.f64 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ast::ArithFloat { + type_: f64, + rounding: Some(rnd.into()), + flush_to_zero: None, + saturate: false + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-fma + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-fma + fma.rnd{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f32, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + fma.rnd.f64 d, a, b, c => { + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f64, + rounding: Some(rnd.into()), + flush_to_zero: None, + saturate: false + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + fma.rnd{.ftz}{.sat}.f16 d, a, b, c => { + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f16, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + //fma.rnd{.ftz}{.sat}.f16x2 d, a, b, c; + //fma.rnd{.ftz}.relu.f16 d, a, b, c; + //fma.rnd{.ftz}.relu.f16x2 d, a, b, c; + //fma.rnd{.relu}.bf16 d, a, b, c; + //fma.rnd{.relu}.bf16x2 d, a, b, c; + //fma.rnd.oob.{relu}.type d, a, b, c; + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub + sub.type d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Integer( + ArithInteger { + type_, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub.sat.s32 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Integer( + ArithInteger { + type_: s32, + saturate: true + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + ScalarType = { .s32 }; + + sub{.rnd}{.ftz}{.sat}.f32 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.f64 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f64, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + sub{.rnd}{.ftz}{.sat}.f16 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16x2, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.bf16 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.bf16x2 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16x2, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-min + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-min + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-min + min.atype d, a, b => { + ast::Instruction::Min { + data: if atype.kind() == ast::ScalarKind::Signed { + ast::MinMaxDetails::Signed(atype) + } else { + ast::MinMaxDetails::Unsigned(atype) + }, + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + //min{.relu}.btype d, a, b => { todo!() } + min.btype d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Signed(btype), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + .atype: ScalarType = { .u16, .u32, .u64, + .u16x2, .s16, .s64 }; + .btype: ScalarType = { .s16x2, .s32 }; + + //min{.ftz}{.NaN}{.xorsign.abs}.f32 d, a, b; + min{.ftz}{.NaN}.f32 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f32 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min.f64 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan: false, + type_: f64 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f32, .f64 }; + + //min{.ftz}{.NaN}{.xorsign.abs}.f16 d, a, b; + //min{.ftz}{.NaN}{.xorsign.abs}.f16x2 d, a, b; + //min{.NaN}{.xorsign.abs}.bf16 d, a, b; + //min{.NaN}{.xorsign.abs}.bf16x2 d, a, b; + min{.ftz}{.NaN}.f16 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min{.ftz}{.NaN}.f16x2 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16x2 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min{.NaN}.bf16 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min{.NaN}.bf16x2 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16x2 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-max + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-max + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-max + max.atype d, a, b => { + ast::Instruction::Max { + data: if atype.kind() == ast::ScalarKind::Signed { + ast::MinMaxDetails::Signed(atype) + } else { + ast::MinMaxDetails::Unsigned(atype) + }, + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + //max{.relu}.btype d, a, b => { todo!() } + max.btype d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Signed(btype), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + .atype: ScalarType = { .u16, .u32, .u64, + .u16x2, .s16, .s64 }; + .btype: ScalarType = { .s16x2, .s32 }; + + //max{.ftz}{.NaN}{.xorsign.abs}.f32 d, a, b; + max{.ftz}{.NaN}.f32 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f32 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max.f64 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan: false, + type_: f64 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f32, .f64 }; + + //max{.ftz}{.NaN}{.xorsign.abs}.f16 d, a, b; + //max{.ftz}{.NaN}{.xorsign.abs}.f16x2 d, a, b; + //max{.NaN}{.xorsign.abs}.bf16 d, a, b; + //max{.NaN}{.xorsign.abs}.bf16x2 d, a, b; + max{.ftz}{.NaN}.f16 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max{.ftz}{.NaN}.f16x2 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16x2 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max{.NaN}.bf16 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max{.NaN}.bf16x2 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16x2 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp-approx-ftz-f64 + rcp.approx{.ftz}.type d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Approx, + flush_to_zero: Some(ftz), + type_ + }, + arguments: RcpArgs { dst: d, src: a } + } + } + rcp.rnd{.ftz}.f32 d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Compliant(rnd.into()), + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: RcpArgs { dst: d, src: a } + } + } + rcp.rnd.f64 d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Compliant(rnd.into()), + flush_to_zero: None, + type_: f64 + }, + arguments: RcpArgs { dst: d, src: a } + } + } + .type: ScalarType = { .f32, .f64 }; + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sqrt + sqrt.approx{.ftz}.f32 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Approx, + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + sqrt.rnd{.ftz}.f32 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Compliant(rnd.into()), + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + sqrt.rnd.f64 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Compliant(rnd.into()), + flush_to_zero: None, + type_: f64 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64 + rsqrt.approx{.ftz}.f32 d, a => { + ast::Instruction::Rsqrt { + data: ast::TypeFtz { + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: RsqrtArgs { dst: d, src: a } + } + } + rsqrt.approx.f64 d, a => { + ast::Instruction::Rsqrt { + data: ast::TypeFtz { + flush_to_zero: None, + type_: f64 + }, + arguments: RsqrtArgs { dst: d, src: a } + } + } + rsqrt.approx.ftz.f64 d, a => { + ast::Instruction::Rsqrt { + data: ast::TypeFtz { + flush_to_zero: None, + type_: f64 + }, + arguments: RsqrtArgs { dst: d, src: a } + } + } + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-selp + selp.type d, a, b, c => { + ast::Instruction::Selp { + data: type_, + arguments: SelpArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .type: ScalarType = { .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar + barrier{.cta}.sync{.aligned} a{, b} => { + let _ = cta; + ast::Instruction::Bar { + data: ast::BarData { aligned }, + arguments: BarArgs { src1: a, src2: b } + } + } + //barrier{.cta}.arrive{.aligned} a, b; + //barrier{.cta}.red.popc{.aligned}.u32 d, a{, b}, {!}c; + //barrier{.cta}.red.op{.aligned}.pred p, a{, b}, {!}c; + bar{.cta}.sync a{, b} => { + let _ = cta; + ast::Instruction::Bar { + data: ast::BarData { aligned: true }, + arguments: BarArgs { src1: a, src2: b } + } + } + //bar{.cta}.arrive a, b; + //bar{.cta}.red.popc.u32 d, a{, b}, {!}c; + //bar{.cta}.red.op.pred p, a{, b}, {!}c; + //.op = { .and, .or }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom + atom{.sem}{.scope}{.space}.op{.level::cache_hint}.type d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: space.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(op, type_.kind()), + type_: type_.into() + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.space}.cas.cas_type d, [a], b, c => { + ast::Instruction::AtomCas { + data: AtomCasDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: space.unwrap_or(StateSpace::Generic), + type_: cas_type + }, + arguments: AtomCasArgs { dst: d, src1: a, src2: b, src3: c } + } + } + atom{.sem}{.scope}{.space}.exch{.level::cache_hint}.b128 d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: space.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(exch, b128.kind()), + type_: b128.into() + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.global}.float_op{.level::cache_hint}.vec_32_bit.f32 d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: global.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(float_op, f32.kind()), + type_: ast::Type::Vector(vec_32_bit.len().get(), f32) + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.global}.float_op.noftz{.level::cache_hint}{.vec_16_bit}.half_word_type d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: global.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(float_op, half_word_type.kind()), + type_: ast::Type::maybe_vector(vec_16_bit, half_word_type) + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.global}.float_op.noftz{.level::cache_hint}{.vec_32_bit}.packed_type d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: global.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(float_op, packed_type.kind()), + type_: ast::Type::maybe_vector(vec_32_bit, packed_type) + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + .space: StateSpace = { .global, .shared{::cta, ::cluster} }; + .sem: AtomSemantics = { .relaxed, .acquire, .release, .acq_rel }; + .scope: MemScope = { .cta, .cluster, .gpu, .sys }; + .op: RawAtomicOp = { .and, .or, .xor, + .exch, + .add, .inc, .dec, + .min, .max }; + .level::cache_hint = { .L2::cache_hint }; + .type: ScalarType = { .b32, .b64, .u32, .u64, .s32, .s64, .f32, .f64 }; + .cas_type: ScalarType = { .b32, .b64, .u32, .u64, .s32, .s64, .f32, .f64, .b16, .b128 }; + .half_word_type: ScalarType = { .f16, .bf16 }; + .packed_type: ScalarType = { .f16x2, .bf16x2 }; + .vec_16_bit: VectorPrefix = { .v2, .v4, .v8 }; + .vec_32_bit: VectorPrefix = { .v2, .v4 }; + .float_op: RawAtomicOp = { .add, .min, .max }; + ScalarType = { .b16, .b128, .f32 }; + StateSpace = { .global }; + RawAtomicOp = { .exch }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-div + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div + div.type d, a, b => { + ast::Instruction::Div { + data: if type_.kind() == ast::ScalarKind::Signed { + ast::DivDetails::Signed(type_) + } else { + ast::DivDetails::Unsigned(type_) + }, + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + + div.approx{.ftz}.f32 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f32, + flush_to_zero: Some(ftz), + kind: ast::DivFloatKind::Approx + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + div.full{.ftz}.f32 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f32, + flush_to_zero: Some(ftz), + kind: ast::DivFloatKind::ApproxFull + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + div.rnd{.ftz}.f32 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f32, + flush_to_zero: Some(ftz), + kind: ast::DivFloatKind::Rounding(rnd.into()) + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + div.rnd.f64 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f64, + flush_to_zero: None, + kind: ast::DivFloatKind::Rounding(rnd.into()) + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-neg + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-neg + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-neg + neg.type d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .s16, .s32, .s64 }; + + neg{.ftz}.f32 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f32, + flush_to_zero: Some(ftz) + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg.f64 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f64, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg{.ftz}.f16 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f16, + flush_to_zero: Some(ftz) + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg{.ftz}.f16x2 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f16x2, + flush_to_zero: Some(ftz) + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg.bf16 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: bf16, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg.bf16x2 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: bf16x2, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + ScalarType = { .f32, .f64, .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sin + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-cos + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-lg2 + sin.approx{.ftz}.f32 d, a => { + ast::Instruction::Sin { + data: ast::FlushToZero { + flush_to_zero: ftz + }, + arguments: SinArgs { dst: d, src: a, }, + } + } + cos.approx{.ftz}.f32 d, a => { + ast::Instruction::Cos { + data: ast::FlushToZero { + flush_to_zero: ftz + }, + arguments: CosArgs { dst: d, src: a, }, + } + } + lg2.approx{.ftz}.f32 d, a => { + ast::Instruction::Lg2 { + data: ast::FlushToZero { + flush_to_zero: ftz + }, + arguments: Lg2Args { dst: d, src: a, }, + } + } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-ex2 + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-ex2 + ex2.approx{.ftz}.f32 d, a => { + ast::Instruction::Ex2 { + data: ast::TypeFtz { + type_: f32, + flush_to_zero: Some(ftz) + }, + arguments: Ex2Args { dst: d, src: a, }, + } + } + ex2.approx.atype d, a => { + ast::Instruction::Ex2 { + data: ast::TypeFtz { + type_: atype, + flush_to_zero: None + }, + arguments: Ex2Args { dst: d, src: a, }, + } + } + ex2.approx.ftz.btype d, a => { + ast::Instruction::Ex2 { + data: ast::TypeFtz { + type_: btype, + flush_to_zero: Some(true) + }, + arguments: Ex2Args { dst: d, src: a, }, + } + } + .atype: ScalarType = { .f16, .f16x2 }; + .btype: ScalarType = { .bf16, .bf16x2 }; + ScalarType = { .f32 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-clz + clz.type d, a => { + ast::Instruction::Clz { + data: type_, + arguments: ClzArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-brev + brev.type d, a => { + ast::Instruction::Brev { + data: type_, + arguments: BrevArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-popc + popc.type d, a => { + ast::Instruction::Popc { + data: type_, + arguments: PopcArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-xor + xor.type d, a, b => { + ast::Instruction::Xor { + data: type_, + arguments: XorArgs { dst: d, src1: a, src2: b, }, + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-rem + rem.type d, a, b => { + ast::Instruction::Rem { + data: type_, + arguments: RemArgs { dst: d, src1: a, src2: b, }, + } + } + .type: ScalarType = { .u16, .u32, .u64, .s16, .s32, .s64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-bfe + bfe.type d, a, b, c => { + ast::Instruction::Bfe { + data: type_, + arguments: BfeArgs { dst: d, src1: a, src2: b, src3: c }, + } + } + .type: ScalarType = { .u32, .u64, .s32, .s64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-bfi + bfi.type f, a, b, c, d => { + ast::Instruction::Bfi { + data: type_, + arguments: BfiArgs { dst: f, src1: a, src2: b, src3: c, src4: d }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt + // prmt.b32{.mode} d, a, b, c; + // .mode = { .f4e, .b4e, .rc8, .ecl, .ecr, .rc16 }; + prmt.b32 d, a, b, c => { + match c { + ast::ParsedOperand::Imm(ImmediateValue::S64(control)) => ast::Instruction::Prmt { + data: control as u16, + arguments: PrmtArgs { + dst: d, src1: a, src2: b + } + }, + _ => ast::Instruction::PrmtSlow { + arguments: PrmtSlowArgs { + dst: d, src1: a, src2: b, src3: c + } + } + } + } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-activemask + activemask.b32 d => { + ast::Instruction::Activemask { + arguments: ActivemaskArgs { dst: d } + } + } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar + // fence{.sem}.scope; + // fence.op_restrict.release.cluster; + // fence.proxy.proxykind; + // fence.proxy.to_proxykind::from_proxykind.release.scope; + // fence.proxy.to_proxykind::from_proxykind.acquire.scope [addr], size; + //membar.proxy.proxykind; + //.sem = { .sc, .acq_rel }; + //.scope = { .cta, .cluster, .gpu, .sys }; + //.proxykind = { .alias, .async, async.global, .async.shared::{cta, cluster} }; + //.op_restrict = { .mbarrier_init }; + //.to_proxykind::from_proxykind = {.tensormap::generic}; + + membar.level => { + ast::Instruction::Membar { data: level } + } + membar.gl => { + ast::Instruction::Membar { data: MemScope::Gpu } + } + .level: MemScope = { .cta, .sys }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret + ret{.uni} => { + Instruction::Ret { data: RetData { uniform: uni } } + } + +); + +#[cfg(test)] +mod tests { + use super::target; + use super::PtxParserState; + use super::Token; + use logos::Logos; + use winnow::prelude::*; + + #[test] + fn sm_11() { + let tokens = Token::lexer(".target sm_11") + .collect::<Result<Vec<_>, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: PtxParserState::new(), + }; + assert_eq!(target.parse(stream).unwrap(), (11, None)); + } + + #[test] + fn sm_90a() { + let tokens = Token::lexer(".target sm_90a") + .collect::<Result<Vec<_>, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: PtxParserState::new(), + }; + assert_eq!(target.parse(stream).unwrap(), (90, Some('a'))); + } + + #[test] + fn sm_90ab() { + let tokens = Token::lexer(".target sm_90ab") + .collect::<Result<Vec<_>, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: PtxParserState::new(), + }; + assert!(target.parse(stream).is_err()); + } +} diff --git a/ptx_parser_macros/Cargo.toml b/ptx_parser_macros/Cargo.toml new file mode 100644 index 0000000..62a5081 --- /dev/null +++ b/ptx_parser_macros/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "ptx_parser_macros" +version = "0.0.0" +authors = ["Andrzej Janik <[email protected]>"] +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +ptx_parser_macros_impl = { path = "../ptx_parser_macros_impl" } +convert_case = "0.6.0" +rustc-hash = "2.0.0" +syn = "2.0.67" +quote = "1.0" +proc-macro2 = "1.0.86" +either = "1.13.0" diff --git a/ptx_parser_macros/src/lib.rs b/ptx_parser_macros/src/lib.rs new file mode 100644 index 0000000..5f47fac --- /dev/null +++ b/ptx_parser_macros/src/lib.rs @@ -0,0 +1,1023 @@ +use either::Either;
+use ptx_parser_macros_impl::parser;
+use proc_macro2::{Span, TokenStream};
+use quote::{format_ident, quote, ToTokens};
+use rustc_hash::{FxHashMap, FxHashSet};
+use std::{collections::hash_map, hash::Hash, iter, rc::Rc};
+use syn::{
+ parse_macro_input, parse_quote, punctuated::Punctuated, Ident, ItemEnum, Token, Type, TypePath,
+ Variant,
+};
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#vectors
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#alternate-floating-point-data-formats
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-floating-point-data-types
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-integer-data-types
+#[rustfmt::skip]
+static POSTFIX_MODIFIERS: &[&str] = &[
+ ".v2", ".v4", ".v8",
+ ".s8", ".s16", ".s16x2", ".s32", ".s64",
+ ".u8", ".u16", ".u16x2", ".u32", ".u64",
+ ".f16", ".f16x2", ".f32", ".f64",
+ ".b8", ".b16", ".b32", ".b64", ".b128",
+ ".pred",
+ ".bf16", ".bf16x2", ".e4m3", ".e5m2", ".tf32",
+];
+
+static POSTFIX_TYPES: &[&str] = &["ScalarType", "VectorPrefix"];
+
+struct OpcodeDefinitions {
+ definitions: Vec<SingleOpcodeDefinition>,
+ block_selection: Vec<Vec<(Option<Vec<parser::DotModifier>>, usize)>>,
+}
+
+impl OpcodeDefinitions {
+ fn new(opcode: &Ident, definitions: Vec<SingleOpcodeDefinition>) -> Self {
+ let mut selections = vec![None; definitions.len()];
+ let mut generation = 0usize;
+ loop {
+ let mut selected_something = false;
+ let unselected = selections
+ .iter()
+ .enumerate()
+ .filter_map(|(idx, s)| if s.is_none() { Some(idx) } else { None })
+ .collect::<Vec<_>>();
+ match &*unselected {
+ [] => break,
+ [remaining] => {
+ selections[*remaining] = Some((None, generation));
+ break;
+ }
+ _ => {}
+ }
+ 'check_definitions: for i in unselected.iter().copied() {
+ let mut candidates = definitions[i]
+ .unordered_modifiers
+ .iter()
+ .chain(definitions[i].ordered_modifiers.iter())
+ .filter(|modifier| match modifier {
+ DotModifierRef::Direct {
+ optional: false, ..
+ }
+ | DotModifierRef::Indirect {
+ optional: false, ..
+ } => true,
+ _ => false,
+ })
+ .collect::<Vec<_>>();
+ candidates.sort_by_key(|modifier| match modifier {
+ DotModifierRef::Direct { .. } => 1,
+ DotModifierRef::Indirect { value, .. } => value.alternatives.len(),
+ });
+ // Attempt every modifier
+ 'check_candidates: for candidate_modifier in candidates {
+ // check all other unselected patterns
+ for j in unselected.iter().copied() {
+ if i == j {
+ continue;
+ }
+ let candidate_set = match candidate_modifier {
+ DotModifierRef::Direct { value, .. } => Either::Left(iter::once(value)),
+ DotModifierRef::Indirect { value, .. } => {
+ Either::Right(value.alternatives.iter())
+ }
+ };
+ for candidate_value in candidate_set {
+ if definitions[j].possible_modifiers.contains(candidate_value) {
+ continue 'check_candidates;
+ }
+ }
+ }
+ // it's unique
+ let candidate_vec = match candidate_modifier {
+ DotModifierRef::Direct { value, .. } => vec![value.clone()],
+ DotModifierRef::Indirect { value, .. } => {
+ value.alternatives.iter().cloned().collect::<Vec<_>>()
+ }
+ };
+ selections[i] = Some((Some(candidate_vec), generation));
+ selected_something = true;
+ continue 'check_definitions;
+ }
+ }
+ if !selected_something {
+ panic!(
+ "Failed to generate pattern selection for `{}`. State: {:?}",
+ opcode,
+ selections.into_iter().rev().collect::<Vec<_>>()
+ );
+ }
+ generation += 1;
+ }
+ let mut block_selection = Vec::new();
+ for current_generation in 0usize.. {
+ let mut current_generation_definitions = Vec::new();
+ for (idx, selection) in selections.iter_mut().enumerate() {
+ match selection {
+ Some((modifier_set, generation)) => {
+ if *generation == current_generation {
+ current_generation_definitions.push((modifier_set.clone(), idx));
+ *selection = None;
+ }
+ }
+ None => {}
+ }
+ }
+ if current_generation_definitions.is_empty() {
+ break;
+ }
+ block_selection.push(current_generation_definitions);
+ }
+ #[cfg(debug_assertions)]
+ {
+ let selected = block_selection
+ .iter()
+ .map(|x| x.len())
+ .reduce(|x, y| x + y)
+ .unwrap();
+ if selected != definitions.len() {
+ panic!(
+ "Internal error when generating pattern selection for `{}`: {:?}",
+ opcode, &block_selection
+ );
+ }
+ }
+ Self {
+ definitions,
+ block_selection,
+ }
+ }
+
+ fn get_enum_types(
+ parse_definitions: &[parser::OpcodeDefinition],
+ ) -> FxHashMap<syn::Type, FxHashSet<parser::DotModifier>> {
+ let mut result = FxHashMap::default();
+ for parser::OpcodeDefinition(_, rules) in parse_definitions.iter() {
+ for rule in rules {
+ let type_ = match rule.type_ {
+ Some(ref type_) => type_.clone(),
+ None => continue,
+ };
+ let insert_values = |set: &mut FxHashSet<_>| {
+ for value in rule.alternatives.iter().cloned() {
+ set.insert(value);
+ }
+ };
+ match result.entry(type_) {
+ hash_map::Entry::Occupied(mut entry) => insert_values(entry.get_mut()),
+ hash_map::Entry::Vacant(entry) => {
+ insert_values(entry.insert(FxHashSet::default()))
+ }
+ };
+ }
+ }
+ result
+ }
+}
+
+struct SingleOpcodeDefinition {
+ possible_modifiers: FxHashSet<parser::DotModifier>,
+ unordered_modifiers: Vec<DotModifierRef>,
+ ordered_modifiers: Vec<DotModifierRef>,
+ arguments: parser::Arguments,
+ code_block: parser::CodeBlock,
+}
+
+impl SingleOpcodeDefinition {
+ fn function_arguments_declarations(&self) -> impl Iterator<Item = TokenStream> + '_ {
+ self.unordered_modifiers
+ .iter()
+ .chain(self.ordered_modifiers.iter())
+ .filter_map(|modf| {
+ let type_ = modf.type_of();
+ type_.map(|t| {
+ let name = modf.ident();
+ quote! { #name : #t }
+ })
+ })
+ .chain(self.arguments.0.iter().map(|arg| {
+ let name = &arg.ident;
+ let arg_type = if arg.unified {
+ quote! { (ParsedOperandStr<'input>, bool) }
+ } else if arg.can_be_negated {
+ quote! { (bool, ParsedOperandStr<'input>) }
+ } else {
+ quote! { ParsedOperandStr<'input> }
+ };
+ if arg.optional {
+ quote! { #name : Option<#arg_type> }
+ } else {
+ quote! { #name : #arg_type }
+ }
+ }))
+ }
+
+ fn function_arguments(&self) -> impl Iterator<Item = TokenStream> + '_ {
+ self.unordered_modifiers
+ .iter()
+ .chain(self.ordered_modifiers.iter())
+ .filter_map(|modf| {
+ let type_ = modf.type_of();
+ type_.map(|_| {
+ let name = modf.ident();
+ quote! { #name }
+ })
+ })
+ .chain(self.arguments.0.iter().map(|arg| {
+ let name = &arg.ident;
+ quote! { #name }
+ }))
+ }
+
+ fn extract_and_insert(
+ definitions: &mut FxHashMap<Ident, Vec<SingleOpcodeDefinition>>,
+ special_definitions: &mut FxHashMap<Ident, proc_macro2::Group>,
+ parser::OpcodeDefinition(pattern_seq, rules): parser::OpcodeDefinition,
+ ) {
+ let (mut named_rules, mut unnamed_rules) = gather_rules(rules);
+ let mut last_opcode = pattern_seq.0.last().unwrap().0 .0.name.clone();
+ for (opcode_decl, code_block) in pattern_seq.0.into_iter().rev() {
+ let current_opcode = opcode_decl.0.name.clone();
+ if last_opcode != current_opcode {
+ named_rules = FxHashMap::default();
+ unnamed_rules = FxHashMap::default();
+ }
+ let parser::OpcodeDecl(instruction, arguments) = opcode_decl;
+ if code_block.special {
+ if !instruction.modifiers.is_empty() || !arguments.0.is_empty() {
+ panic!(
+ "`{}`: no modifiers or arguments are allowed in parser definition.",
+ instruction.name
+ );
+ }
+ special_definitions.insert(instruction.name, code_block.code);
+ continue;
+ }
+ let mut possible_modifiers = FxHashSet::default();
+ let mut unordered_modifiers = instruction
+ .modifiers
+ .into_iter()
+ .map(|parser::MaybeDotModifier { optional, modifier }| {
+ match named_rules.get(&modifier) {
+ Some(alts) => {
+ possible_modifiers.extend(alts.alternatives.iter().cloned());
+ if alts.alternatives.len() == 1 && alts.type_.is_none() {
+ DotModifierRef::Direct {
+ optional,
+ value: alts.alternatives[0].clone(),
+ name: modifier,
+ type_: alts.type_.clone(),
+ }
+ } else {
+ DotModifierRef::Indirect {
+ optional,
+ value: alts.clone(),
+ name: modifier,
+ }
+ }
+ }
+ None => {
+ let type_ = unnamed_rules.get(&modifier).cloned();
+ possible_modifiers.insert(modifier.clone());
+ DotModifierRef::Direct {
+ optional,
+ value: modifier.clone(),
+ name: modifier,
+ type_,
+ }
+ }
+ }
+ })
+ .collect::<Vec<_>>();
+ let ordered_modifiers = Self::extract_ordered_modifiers(&mut unordered_modifiers);
+ let entry = Self {
+ possible_modifiers,
+ unordered_modifiers,
+ ordered_modifiers,
+ arguments,
+ code_block,
+ };
+ multihash_extend(definitions, current_opcode.clone(), entry);
+ last_opcode = current_opcode;
+ }
+ }
+
+ fn extract_ordered_modifiers(
+ unordered_modifiers: &mut Vec<DotModifierRef>,
+ ) -> Vec<DotModifierRef> {
+ let mut result = Vec::new();
+ loop {
+ let is_ordered = match unordered_modifiers.last() {
+ Some(DotModifierRef::Direct { value, .. }) => {
+ let name = value.to_string();
+ POSTFIX_MODIFIERS.contains(&&*name)
+ }
+ Some(DotModifierRef::Indirect { value, .. }) => {
+ let type_ = value.type_.to_token_stream().to_string();
+ //panic!("{} {}", type_, POSTFIX_TYPES.contains(&&*type_));
+ POSTFIX_TYPES.contains(&&*type_)
+ }
+ None => break,
+ };
+ if is_ordered {
+ result.push(unordered_modifiers.pop().unwrap());
+ } else {
+ break;
+ }
+ }
+ if unordered_modifiers.len() == 1 {
+ result.push(unordered_modifiers.pop().unwrap());
+ }
+ result.reverse();
+ result
+ }
+}
+
+fn gather_rules(
+ rules: Vec<parser::Rule>,
+) -> (
+ FxHashMap<parser::DotModifier, Rc<parser::Rule>>,
+ FxHashMap<parser::DotModifier, Type>,
+) {
+ let mut named = FxHashMap::default();
+ let mut unnamed = FxHashMap::default();
+ for rule in rules {
+ match rule.modifier {
+ Some(ref modifier) => {
+ named.insert(modifier.clone(), Rc::new(rule));
+ }
+ None => unnamed.extend(
+ rule.alternatives
+ .into_iter()
+ .map(|alt| (alt, rule.type_.as_ref().unwrap().clone())),
+ ),
+ }
+ }
+ (named, unnamed)
+}
+
+#[proc_macro]
+pub fn derive_parser(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ let parse_definitions = parse_macro_input!(tokens as ptx_parser_macros_impl::parser::ParseDefinitions);
+ let mut definitions = FxHashMap::default();
+ let mut special_definitions = FxHashMap::default();
+ let types = OpcodeDefinitions::get_enum_types(&parse_definitions.definitions);
+ let enum_types_tokens = emit_enum_types(types, parse_definitions.additional_enums);
+ for definition in parse_definitions.definitions.into_iter() {
+ SingleOpcodeDefinition::extract_and_insert(
+ &mut definitions,
+ &mut special_definitions,
+ definition,
+ );
+ }
+ let definitions = definitions
+ .into_iter()
+ .map(|(k, v)| {
+ let v = OpcodeDefinitions::new(&k, v);
+ (k, v)
+ })
+ .collect::<FxHashMap<_, _>>();
+ let mut token_enum = parse_definitions.token_type;
+ let (all_opcode, all_modifier) = write_definitions_into_tokens(
+ &definitions,
+ special_definitions.keys(),
+ &mut token_enum.variants,
+ );
+ let token_impl = emit_parse_function(&token_enum.ident, &definitions, &special_definitions, all_opcode, all_modifier);
+ let tokens = quote! {
+ #enum_types_tokens
+
+ #token_enum
+
+ #token_impl
+ };
+ tokens.into()
+}
+
+fn emit_enum_types(
+ types: FxHashMap<syn::Type, FxHashSet<parser::DotModifier>>,
+ mut existing_enums: FxHashMap<Ident, ItemEnum>,
+) -> TokenStream {
+ let token_types = types.into_iter().filter_map(|(type_, variants)| {
+ match type_ {
+ syn::Type::Path(TypePath {
+ qself: None,
+ ref path,
+ }) => {
+ if let Some(ident) = path.get_ident() {
+ if let Some(enum_) = existing_enums.get_mut(ident) {
+ enum_.variants.extend(variants.into_iter().map(|modifier| {
+ let ident = modifier.variant_capitalized();
+ let variant: syn::Variant = syn::parse_quote! {
+ #ident
+ };
+ variant
+ }));
+ return None;
+ }
+ }
+ }
+ _ => {}
+ }
+ let variants = variants.iter().map(|v| v.variant_capitalized());
+ Some(quote! {
+ #[derive(Copy, Clone, PartialEq, Eq, Hash)]
+ enum #type_ {
+ #(#variants),*
+ }
+ })
+ });
+ let mut result = TokenStream::new();
+ for tokens in token_types {
+ tokens.to_tokens(&mut result);
+ }
+ for (_, enum_) in existing_enums {
+ quote! { #enum_ }.to_tokens(&mut result);
+ }
+ result
+}
+
+fn emit_parse_function(
+ type_name: &Ident,
+ defs: &FxHashMap<Ident, OpcodeDefinitions>,
+ special_defs: &FxHashMap<Ident, proc_macro2::Group>,
+ all_opcode: Vec<&Ident>,
+ all_modifier: FxHashSet<&parser::DotModifier>,
+) -> TokenStream {
+ use std::fmt::Write;
+ let fns_ = defs
+ .iter()
+ .map(|(opcode, defs)| {
+ defs.definitions.iter().enumerate().map(|(idx, def)| {
+ let mut fn_name = opcode.to_string();
+ write!(&mut fn_name, "_{}", idx).ok();
+ let fn_name = Ident::new(&fn_name, Span::call_site());
+ let code_block = &def.code_block.code;
+ let args = def.function_arguments_declarations();
+ quote! {
+ fn #fn_name<'input>(state: &mut PtxParserState, #(#args),* ) -> Instruction<ParsedOperandStr<'input>> #code_block
+ }
+ })
+ })
+ .flatten();
+ let selectors = defs.iter().map(|(opcode, def)| {
+ let opcode_variant = Ident::new(&capitalize(&opcode.to_string()), opcode.span());
+ let mut result = TokenStream::new();
+ let mut selectors = TokenStream::new();
+ quote! {
+ if false {
+ unsafe { std::hint::unreachable_unchecked() }
+ }
+ }
+ .to_tokens(&mut selectors);
+ let mut has_default_selector = false;
+ for selection_layer in def.block_selection.iter() {
+ for (selection_key, selected_definition) in selection_layer {
+ let def_parser = emit_definition_parser(type_name, (opcode,*selected_definition), &def.definitions[*selected_definition]);
+ match selection_key {
+ Some(selection_keys) => {
+ let selection_keys = selection_keys.iter().map(|k| k.dot_capitalized());
+ quote! {
+ else if false #(|| modifiers.contains(& #type_name :: #selection_keys))* {
+ #def_parser
+ }
+ }
+ .to_tokens(&mut selectors);
+ }
+ None => {
+ has_default_selector = true;
+ quote! {
+ else {
+ #def_parser
+ }
+ }
+ .to_tokens(&mut selectors);
+ }
+ }
+ }
+ }
+ if !has_default_selector {
+ quote! {
+ else {
+ return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token))
+ }
+ }
+ .to_tokens(&mut selectors);
+ }
+ quote! {
+ #opcode_variant => {
+ let modifers_start = stream.checkpoint();
+ let modifiers = take_while(0.., Token::modifier).parse_next(stream)?;
+ #selectors
+ }
+ }
+ .to_tokens(&mut result);
+ result
+ }).chain(special_defs.iter().map(|(opcode, code)| {
+ let opcode_variant = Ident::new(&capitalize(&opcode.to_string()), opcode.span());
+ quote! {
+ #opcode_variant => { #code? }
+ }
+ }));
+ let opcodes = all_opcode.into_iter().map(|op_ident| {
+ let op = op_ident.to_string();
+ let variant = Ident::new(&capitalize(&op), op_ident.span());
+ let value = op;
+ quote! {
+ #type_name :: #variant => Some(#value),
+ }
+ });
+ let modifier_names = iter::once(Ident::new("DotUnified", Span::call_site()))
+ .chain(all_modifier.iter().map(|m| m.dot_capitalized()));
+ quote! {
+ impl<'input> #type_name<'input> {
+ fn opcode_text(self) -> Option<&'static str> {
+ match self {
+ #(#opcodes)*
+ _ => None
+ }
+ }
+
+ fn modifier(self) -> bool {
+ match self {
+ #(
+ #type_name :: #modifier_names => true,
+ )*
+ _ => false
+ }
+ }
+ }
+
+ #(#fns_)*
+
+ fn parse_instruction<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> winnow::error::PResult<Instruction<ParsedOperandStr<'input>>>
+ {
+ use winnow::Parser;
+ use winnow::token::*;
+ use winnow::combinator::*;
+ let opcode = any.parse_next(stream)?;
+ let modifiers_start = stream.checkpoint();
+ Ok(match opcode {
+ #(
+ #type_name :: #selectors
+ )*
+ _ => return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token))
+ })
+ }
+ }
+}
+
+fn emit_definition_parser(
+ token_type: &Ident,
+ (opcode, fn_idx): (&Ident, usize),
+ definition: &SingleOpcodeDefinition,
+) -> TokenStream {
+ let return_error_ref = quote! {
+ return Err(winnow::error::ErrMode::from_error_kind(&stream, winnow::error::ErrorKind::Token))
+ };
+ let return_error = quote! {
+ return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token))
+ };
+ let ordered_parse_declarations = definition.ordered_modifiers.iter().map(|modifier| {
+ modifier.type_of().map(|type_| {
+ let name = modifier.ident();
+ quote! {
+ let #name : #type_;
+ }
+ })
+ });
+ let ordered_parse = definition.ordered_modifiers.iter().rev().map(|modifier| {
+ let arg_name = modifier.ident();
+ match modifier {
+ DotModifierRef::Direct { optional, value, type_: None, .. } => {
+ let variant = value.dot_capitalized();
+ if *optional {
+ quote! {
+ #arg_name = opt(any.verify(|t| *t == #token_type :: #variant)).parse_next(&mut stream)?.is_some();
+ }
+ } else {
+ quote! {
+ any.verify(|t| *t == #token_type :: #variant).parse_next(&mut stream)?;
+ }
+ }
+ }
+ DotModifierRef::Direct { optional: false, type_: Some(type_), name, value } => {
+ let variable = name.ident();
+ let variant = value.dot_capitalized();
+ let parsed_variant = value.variant_capitalized();
+ quote! {
+ any.verify(|t| *t == #token_type :: #variant).parse_next(&mut stream)?;
+ #variable = #type_ :: #parsed_variant;
+ }
+ }
+ DotModifierRef::Direct { optional: true, type_: Some(_), .. } => { todo!() }
+ DotModifierRef::Indirect { optional, value, .. } => {
+ let variants = value.alternatives.iter().map(|alt| {
+ let type_ = value.type_.as_ref().unwrap();
+ let token_variant = alt.dot_capitalized();
+ let parsed_variant = alt.variant_capitalized();
+ quote! {
+ #token_type :: #token_variant => #type_ :: #parsed_variant,
+ }
+ });
+ if *optional {
+ quote! {
+ #arg_name = opt(any.verify_map(|tok| {
+ Some(match tok {
+ #(#variants)*
+ _ => return None
+ })
+ })).parse_next(&mut stream)?;
+ }
+ } else {
+ quote! {
+ #arg_name = any.verify_map(|tok| {
+ Some(match tok {
+ #(#variants)*
+ _ => return None
+ })
+ }).parse_next(&mut stream)?;
+ }
+ }
+ }
+ }
+ });
+ let unordered_parse_declarations = definition.unordered_modifiers.iter().map(|modifier| {
+ let name = modifier.ident();
+ let type_ = modifier.type_of_check();
+ quote! {
+ let mut #name : #type_ = std::default::Default::default();
+ }
+ });
+ let unordered_parse = definition
+ .unordered_modifiers
+ .iter()
+ .map(|modifier| match modifier {
+ DotModifierRef::Direct {
+ name,
+ value,
+ type_: None,
+ ..
+ } => {
+ let name = name.ident();
+ let token_variant = value.dot_capitalized();
+ quote! {
+ #token_type :: #token_variant => {
+ if #name {
+ #return_error_ref;
+ }
+ #name = true;
+ }
+ }
+ }
+ DotModifierRef::Direct {
+ name,
+ value,
+ type_: Some(type_),
+ ..
+ } => {
+ let variable = name.ident();
+ let token_variant = value.dot_capitalized();
+ let enum_variant = value.variant_capitalized();
+ quote! {
+ #token_type :: #token_variant => {
+ if #variable.is_some() {
+ #return_error_ref;
+ }
+ #variable = Some(#type_ :: #enum_variant);
+ }
+ }
+ }
+ DotModifierRef::Indirect { value, name, .. } => {
+ let variable = name.ident();
+ let type_ = value.type_.as_ref().unwrap();
+ let alternatives = value.alternatives.iter().map(|alt| {
+ let token_variant = alt.dot_capitalized();
+ let enum_variant = alt.variant_capitalized();
+ quote! {
+ #token_type :: #token_variant => {
+ if #variable.is_some() {
+ #return_error_ref;
+ }
+ #variable = Some(#type_ :: #enum_variant);
+ }
+ }
+ });
+ quote! {
+ #(#alternatives)*
+ }
+ }
+ });
+ let unordered_parse_validations =
+ definition
+ .unordered_modifiers
+ .iter()
+ .map(|modifier| match modifier {
+ DotModifierRef::Direct {
+ optional: false,
+ name,
+ type_: None,
+ ..
+ } => {
+ let variable = name.ident();
+ quote! {
+ if !#variable {
+ #return_error;
+ }
+ }
+ }
+ DotModifierRef::Direct {
+ optional: false,
+ name,
+ type_: Some(_),
+ ..
+ } => {
+ let variable = name.ident();
+ quote! {
+ let #variable = match #variable {
+ Some(x) => x,
+ None => #return_error
+ };
+ }
+ }
+ DotModifierRef::Indirect {
+ optional: false,
+ name,
+ ..
+ } => {
+ let variable = name.ident();
+ quote! {
+ let #variable = match #variable {
+ Some(x) => x,
+ None => #return_error
+ };
+ }
+ }
+ DotModifierRef::Direct { optional: true, .. }
+ | DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(),
+ });
+ let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| {
+ let comma = if idx == 0 || arg.pre_pipe {
+ quote! { empty }
+ } else {
+ quote! { any.verify(|t| *t == #token_type::Comma).void() }
+ };
+ let pre_bracket = if arg.pre_bracket {
+ quote! {
+ any.verify(|t| *t == #token_type::LBracket).void()
+ }
+ } else {
+ quote! {
+ empty
+ }
+ };
+ let pre_pipe = if arg.pre_pipe {
+ quote! {
+ any.verify(|t| *t == #token_type::Pipe).void()
+ }
+ } else {
+ quote! {
+ empty
+ }
+ };
+ let can_be_negated = if arg.can_be_negated {
+ quote! {
+ opt(any.verify(|t| *t == #token_type::Not)).map(|o| o.is_some())
+ }
+ } else {
+ quote! {
+ empty
+ }
+ };
+ let operand = {
+ quote! {
+ ParsedOperandStr::parse
+ }
+ };
+ let post_bracket = if arg.post_bracket {
+ quote! {
+ any.verify(|t| *t == #token_type::RBracket).void()
+ }
+ } else {
+ quote! {
+ empty
+ }
+ };
+ let unified = if arg.unified {
+ quote! {
+ opt(any.verify(|t| *t == #token_type::DotUnified).void()).map(|u| u.is_some())
+ }
+ } else {
+ quote! {
+ empty
+ }
+ };
+ let pattern = quote! {
+ (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket, #unified)
+ };
+ let arg_name = &arg.ident;
+ if arg.unified && arg.can_be_negated {
+ panic!("TODO: argument can't be both prefixed by `!` and suffixed by `.unified`")
+ }
+ let inner_parser = if arg.unified {
+ quote! {
+ #pattern.map(|(_, _, _, _, name, _, unified)| (name, unified))
+ }
+ } else if arg.can_be_negated {
+ quote! {
+ #pattern.map(|(_, _, _, negated, name, _, _)| (negated, name))
+ }
+ } else {
+ quote! {
+ #pattern.map(|(_, _, _, _, name, _, _)| name)
+ }
+ };
+ if arg.optional {
+ quote! {
+ let #arg_name = opt(#inner_parser).parse_next(stream)?;
+ }
+ } else {
+ quote! {
+ let #arg_name = #inner_parser.parse_next(stream)?;
+ }
+ }
+ });
+ let fn_args = definition.function_arguments();
+ let fn_name = format_ident!("{}_{}", opcode, fn_idx);
+ let fn_call = quote! {
+ #fn_name(&mut stream.state, #(#fn_args),* )
+ };
+ quote! {
+ #(#unordered_parse_declarations)*
+ #(#ordered_parse_declarations)*
+ {
+ let mut stream = ReverseStream(modifiers);
+ #(#ordered_parse)*
+ let mut stream: &[#token_type] = stream.0;
+ for token in stream.iter().copied() {
+ match token {
+ #(#unordered_parse)*
+ _ => #return_error_ref
+ }
+ }
+ }
+ #(#unordered_parse_validations)*
+ #(#arguments_parse)*
+ #fn_call
+ }
+}
+
+fn write_definitions_into_tokens<'a>(
+ defs: &'a FxHashMap<Ident, OpcodeDefinitions>,
+ special_definitions: impl Iterator<Item = &'a Ident>,
+ variants: &mut Punctuated<Variant, Token![,]>,
+) -> (Vec<&'a Ident>, FxHashSet<&'a parser::DotModifier>) {
+ let mut all_opcodes = Vec::new();
+ let mut all_modifiers = FxHashSet::default();
+ for (opcode, definitions) in defs.iter() {
+ all_opcodes.push(opcode);
+ let opcode_as_string = opcode.to_string();
+ let variant_name = Ident::new(&capitalize(&opcode_as_string), opcode.span());
+ let arg: Variant = syn::parse_quote! {
+ #[token(#opcode_as_string)]
+ #variant_name
+ };
+ variants.push(arg);
+ for definition in definitions.definitions.iter() {
+ for modifier in definition.possible_modifiers.iter() {
+ all_modifiers.insert(modifier);
+ }
+ }
+ }
+ for opcode in special_definitions {
+ all_opcodes.push(opcode);
+ let opcode_as_string = opcode.to_string();
+ let variant_name = Ident::new(&capitalize(&opcode_as_string), opcode.span());
+ let arg: Variant = syn::parse_quote! {
+ #[token(#opcode_as_string)]
+ #variant_name
+ };
+ variants.push(arg);
+ }
+ for modifier in all_modifiers.iter() {
+ let modifier_as_string = modifier.to_string();
+ let variant_name = modifier.dot_capitalized();
+ let arg: Variant = syn::parse_quote! {
+ #[token(#modifier_as_string)]
+ #variant_name
+ };
+ variants.push(arg);
+ }
+ variants.push(parse_quote! {
+ #[token(".unified")]
+ DotUnified
+ });
+ (all_opcodes, all_modifiers)
+}
+
+fn capitalize(s: &str) -> String {
+ let mut c = s.chars();
+ match c.next() {
+ None => String::new(),
+ Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
+ }
+}
+
+fn multihash_extend<K: Eq + Hash, V>(multimap: &mut FxHashMap<K, Vec<V>>, k: K, v: V) {
+ match multimap.entry(k) {
+ hash_map::Entry::Occupied(mut entry) => entry.get_mut().push(v),
+ hash_map::Entry::Vacant(entry) => {
+ entry.insert(vec![v]);
+ }
+ }
+}
+
+enum DotModifierRef {
+ Direct {
+ optional: bool,
+ value: parser::DotModifier,
+ name: parser::DotModifier,
+ type_: Option<Type>,
+ },
+ Indirect {
+ optional: bool,
+ name: parser::DotModifier,
+ value: Rc<parser::Rule>,
+ },
+}
+
+impl DotModifierRef {
+ fn ident(&self) -> Ident {
+ match self {
+ DotModifierRef::Direct { name, .. } => name.ident(),
+ DotModifierRef::Indirect { name, .. } => name.ident(),
+ }
+ }
+
+ fn type_of(&self) -> Option<syn::Type> {
+ Some(match self {
+ DotModifierRef::Direct {
+ optional: true,
+ type_: None,
+ ..
+ } => syn::parse_quote! { bool },
+ DotModifierRef::Direct {
+ optional: false,
+ type_: None,
+ ..
+ } => return None,
+ DotModifierRef::Direct {
+ optional: true,
+ type_: Some(type_),
+ ..
+ } => syn::parse_quote! { Option<#type_> },
+ DotModifierRef::Direct {
+ optional: false,
+ type_: Some(type_),
+ ..
+ } => type_.clone(),
+ DotModifierRef::Indirect {
+ optional, value, ..
+ } => {
+ let type_ = value
+ .type_
+ .as_ref()
+ .expect("Indirect modifer must have a type");
+ if *optional {
+ syn::parse_quote! { Option<#type_> }
+ } else {
+ type_.clone()
+ }
+ }
+ })
+ }
+
+ fn type_of_check(&self) -> syn::Type {
+ match self {
+ DotModifierRef::Direct { type_: None, .. } => syn::parse_quote! { bool },
+ DotModifierRef::Direct {
+ type_: Some(type_), ..
+ } => syn::parse_quote! { Option<#type_> },
+ DotModifierRef::Indirect { value, .. } => {
+ let type_ = value
+ .type_
+ .as_ref()
+ .expect("Indirect modifer must have a type");
+ syn::parse_quote! { Option<#type_> }
+ }
+ }
+ }
+}
+
+#[proc_macro]
+pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ let input = parse_macro_input!(tokens as ptx_parser_macros_impl::GenerateInstructionType);
+ let mut result = proc_macro2::TokenStream::new();
+ input.emit_arg_types(&mut result);
+ input.emit_instruction_type(&mut result);
+ input.emit_visit(&mut result);
+ input.emit_visit_mut(&mut result);
+ input.emit_visit_map(&mut result);
+ result.into()
+}
diff --git a/ptx_parser_macros_impl/Cargo.toml b/ptx_parser_macros_impl/Cargo.toml new file mode 100644 index 0000000..96f3b74 --- /dev/null +++ b/ptx_parser_macros_impl/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "ptx_parser_macros_impl" +version = "0.0.0" +authors = ["Andrzej Janik <[email protected]>"] +edition = "2021" + +[lib] + +[dependencies] +syn = { version = "2.0.67", features = ["extra-traits", "full"] } +quote = "1.0" +proc-macro2 = "1.0.86" +rustc-hash = "2.0.0" diff --git a/ptx_parser_macros_impl/src/lib.rs b/ptx_parser_macros_impl/src/lib.rs new file mode 100644 index 0000000..2f2c87a --- /dev/null +++ b/ptx_parser_macros_impl/src/lib.rs @@ -0,0 +1,881 @@ +use proc_macro2::TokenStream;
+use quote::{format_ident, quote, ToTokens};
+use syn::{
+ braced, parse::Parse, punctuated::Punctuated, token, Expr, Ident, LitBool, PathSegment, Token,
+ Type, TypeParam, Visibility,
+};
+
+pub mod parser;
+
+pub struct GenerateInstructionType {
+ pub visibility: Option<Visibility>,
+ pub name: Ident,
+ pub type_parameters: Punctuated<TypeParam, Token![,]>,
+ pub short_parameters: Punctuated<Ident, Token![,]>,
+ pub variants: Punctuated<InstructionVariant, Token![,]>,
+}
+
+impl GenerateInstructionType {
+ pub fn emit_arg_types(&self, tokens: &mut TokenStream) {
+ for v in self.variants.iter() {
+ v.emit_type(&self.visibility, tokens);
+ }
+ }
+
+ pub fn emit_instruction_type(&self, tokens: &mut TokenStream) {
+ let vis = &self.visibility;
+ let type_name = &self.name;
+ let type_parameters = &self.type_parameters;
+ let variants = self.variants.iter().map(|v| v.emit_variant());
+ quote! {
+ #vis enum #type_name<#type_parameters> {
+ #(#variants),*
+ }
+ }
+ .to_tokens(tokens);
+ }
+
+ pub fn emit_visit(&self, tokens: &mut TokenStream) {
+ self.emit_visit_impl(VisitKind::Ref, tokens, InstructionVariant::emit_visit)
+ }
+
+ pub fn emit_visit_mut(&self, tokens: &mut TokenStream) {
+ self.emit_visit_impl(
+ VisitKind::RefMut,
+ tokens,
+ InstructionVariant::emit_visit_mut,
+ )
+ }
+
+ pub fn emit_visit_map(&self, tokens: &mut TokenStream) {
+ self.emit_visit_impl(VisitKind::Map, tokens, InstructionVariant::emit_visit_map)
+ }
+
+ fn emit_visit_impl(
+ &self,
+ kind: VisitKind,
+ tokens: &mut TokenStream,
+ mut fn_: impl FnMut(&InstructionVariant, &Ident, &mut TokenStream),
+ ) {
+ let type_name = &self.name;
+ let type_parameters = &self.type_parameters;
+ let short_parameters = &self.short_parameters;
+ let mut inner_tokens = TokenStream::new();
+ for v in self.variants.iter() {
+ fn_(v, type_name, &mut inner_tokens);
+ }
+ let visit_ref = kind.reference();
+ let visitor_type = format_ident!("Visitor{}", kind.type_suffix());
+ let visit_fn = format_ident!("visit{}", kind.fn_suffix());
+ let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map {
+ (
+ quote! { <#type_parameters, To: Operand, Err> },
+ quote! { <#short_parameters, To, Err> },
+ quote! { std::result::Result<#type_name<To>, Err> },
+ )
+ } else {
+ (
+ quote! { <#type_parameters, Err> },
+ quote! { <#short_parameters, Err> },
+ quote! { std::result::Result<(), Err> },
+ )
+ };
+ quote! {
+ pub fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type {
+ Ok(match i {
+ #inner_tokens
+ })
+ }
+ }.to_tokens(tokens);
+ if kind == VisitKind::Map {
+ return;
+ }
+ }
+}
+
+#[derive(Clone, Copy, PartialEq, Eq)]
+enum VisitKind {
+ Ref,
+ RefMut,
+ Map,
+}
+
+impl VisitKind {
+ fn fn_suffix(self) -> &'static str {
+ match self {
+ VisitKind::Ref => "",
+ VisitKind::RefMut => "_mut",
+ VisitKind::Map => "_map",
+ }
+ }
+
+ fn type_suffix(self) -> &'static str {
+ match self {
+ VisitKind::Ref => "",
+ VisitKind::RefMut => "Mut",
+ VisitKind::Map => "Map",
+ }
+ }
+
+ fn reference(self) -> Option<proc_macro2::TokenStream> {
+ match self {
+ VisitKind::Ref => Some(quote! { & }),
+ VisitKind::RefMut => Some(quote! { &mut }),
+ VisitKind::Map => None,
+ }
+ }
+}
+
+impl Parse for GenerateInstructionType {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ let visibility = if !input.peek(Token![enum]) {
+ Some(input.parse::<Visibility>()?)
+ } else {
+ None
+ };
+ input.parse::<Token![enum]>()?;
+ let name = input.parse::<Ident>()?;
+ input.parse::<Token![<]>()?;
+ let type_parameters = Punctuated::parse_separated_nonempty(input)?;
+ let short_parameters = type_parameters
+ .iter()
+ .map(|p: &TypeParam| p.ident.clone())
+ .collect();
+ input.parse::<Token![>]>()?;
+ let variants_buffer;
+ braced!(variants_buffer in input);
+ let variants = variants_buffer.parse_terminated(InstructionVariant::parse, Token![,])?;
+ Ok(Self {
+ visibility,
+ name,
+ type_parameters,
+ short_parameters,
+ variants,
+ })
+ }
+}
+
+pub struct InstructionVariant {
+ pub name: Ident,
+ pub type_: Option<Option<Expr>>,
+ pub space: Option<Expr>,
+ pub data: Option<Type>,
+ pub arguments: Option<Arguments>,
+ pub visit: Option<Expr>,
+ pub visit_mut: Option<Expr>,
+ pub map: Option<Expr>,
+}
+
+impl InstructionVariant {
+ fn args_name(&self) -> Ident {
+ format_ident!("{}Args", self.name)
+ }
+
+ fn emit_variant(&self) -> TokenStream {
+ let name = &self.name;
+ let data = match &self.data {
+ None => {
+ quote! {}
+ }
+ Some(data_type) => {
+ quote! {
+ data: #data_type,
+ }
+ }
+ };
+ let arguments = match &self.arguments {
+ None => {
+ quote! {}
+ }
+ Some(args) => {
+ let args_name = self.args_name();
+ match &args {
+ Arguments::Def(InstructionArguments { generic: None, .. }) => {
+ quote! {
+ arguments: #args_name,
+ }
+ }
+ Arguments::Def(InstructionArguments {
+ generic: Some(generics),
+ ..
+ }) => {
+ quote! {
+ arguments: #args_name <#generics>,
+ }
+ }
+ Arguments::Decl(type_) => quote! {
+ arguments: #type_,
+ },
+ }
+ }
+ };
+ quote! {
+ #name { #data #arguments }
+ }
+ }
+
+ fn emit_visit(&self, enum_: &Ident, tokens: &mut TokenStream) {
+ self.emit_visit_impl(&self.visit, enum_, tokens, InstructionArguments::emit_visit)
+ }
+
+ fn emit_visit_mut(&self, enum_: &Ident, tokens: &mut TokenStream) {
+ self.emit_visit_impl(
+ &self.visit_mut,
+ enum_,
+ tokens,
+ InstructionArguments::emit_visit_mut,
+ )
+ }
+
+ fn emit_visit_impl(
+ &self,
+ visit_fn: &Option<Expr>,
+ enum_: &Ident,
+ tokens: &mut TokenStream,
+ mut fn_: impl FnMut(&InstructionArguments, &Option<Option<Expr>>, &Option<Expr>) -> TokenStream,
+ ) {
+ let name = &self.name;
+ let arguments = match &self.arguments {
+ None => {
+ quote! {
+ #enum_ :: #name { .. } => { }
+ }
+ .to_tokens(tokens);
+ return;
+ }
+ Some(Arguments::Decl(_)) => {
+ quote! {
+ #enum_ :: #name { data, arguments } => { #visit_fn }
+ }
+ .to_tokens(tokens);
+ return;
+ }
+ Some(Arguments::Def(args)) => args,
+ };
+ let data = &self.data.as_ref().map(|_| quote! { data,});
+ let arg_calls = fn_(arguments, &self.type_, &self.space);
+ quote! {
+ #enum_ :: #name { #data arguments } => {
+ #arg_calls
+ }
+ }
+ .to_tokens(tokens);
+ }
+
+ fn emit_visit_map(&self, enum_: &Ident, tokens: &mut TokenStream) {
+ let name = &self.name;
+ let data = &self.data.as_ref().map(|_| quote! { data,});
+ let arguments = match self.arguments {
+ None => None,
+ Some(Arguments::Decl(_)) => {
+ let map = self.map.as_ref().unwrap();
+ quote! {
+ #enum_ :: #name { #data arguments } => {
+ #map
+ }
+ }
+ .to_tokens(tokens);
+ return;
+ }
+ Some(Arguments::Def(ref def)) => Some(def),
+ };
+ let arguments_ident = &self.arguments.as_ref().map(|_| quote! { arguments,});
+ let mut arg_calls = None;
+ let arguments_init = arguments.as_ref().map(|arguments| {
+ let arg_type = self.args_name();
+ arg_calls = Some(arguments.emit_visit_map(&self.type_, &self.space));
+ let arg_names = arguments.fields.iter().map(|arg| &arg.name);
+ quote! {
+ arguments: #arg_type { #(#arg_names),* }
+ }
+ });
+ quote! {
+ #enum_ :: #name { #data #arguments_ident } => {
+ #arg_calls
+ #enum_ :: #name { #data #arguments_init }
+ }
+ }
+ .to_tokens(tokens);
+ }
+
+ fn emit_type(&self, vis: &Option<Visibility>, tokens: &mut TokenStream) {
+ let arguments = match self.arguments {
+ Some(Arguments::Def(ref a)) => a,
+ Some(Arguments::Decl(_)) => return,
+ None => return,
+ };
+ let name = self.args_name();
+ let type_parameters = if arguments.generic.is_some() {
+ Some(quote! { <T> })
+ } else {
+ None
+ };
+ let fields = arguments.fields.iter().map(|f| f.emit_field(vis));
+ quote! {
+ #vis struct #name #type_parameters {
+ #(#fields),*
+ }
+ }
+ .to_tokens(tokens);
+ }
+}
+
+impl Parse for InstructionVariant {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ let name = input.parse::<Ident>()?;
+ let properties_buffer;
+ braced!(properties_buffer in input);
+ let properties = properties_buffer.parse_terminated(VariantProperty::parse, Token![,])?;
+ let mut type_ = None;
+ let mut space = None;
+ let mut data = None;
+ let mut arguments = None;
+ let mut visit = None;
+ let mut visit_mut = None;
+ let mut map = None;
+ for property in properties {
+ match property {
+ VariantProperty::Type(t) => type_ = Some(t),
+ VariantProperty::Space(s) => space = Some(s),
+ VariantProperty::Data(d) => data = Some(d),
+ VariantProperty::Arguments(a) => arguments = Some(a),
+ VariantProperty::Visit(e) => visit = Some(e),
+ VariantProperty::VisitMut(e) => visit_mut = Some(e),
+ VariantProperty::Map(e) => map = Some(e),
+ }
+ }
+ Ok(Self {
+ name,
+ type_,
+ space,
+ data,
+ arguments,
+ visit,
+ visit_mut,
+ map,
+ })
+ }
+}
+
+enum VariantProperty {
+ Type(Option<Expr>),
+ Space(Expr),
+ Data(Type),
+ Arguments(Arguments),
+ Visit(Expr),
+ VisitMut(Expr),
+ Map(Expr),
+}
+
+impl VariantProperty {
+ pub fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ let lookahead = input.lookahead1();
+ Ok(if lookahead.peek(Token![type]) {
+ input.parse::<Token![type]>()?;
+ input.parse::<Token![:]>()?;
+ VariantProperty::Type(if input.peek(Token![!]) {
+ input.parse::<Token![!]>()?;
+ None
+ } else {
+ Some(input.parse::<Expr>()?)
+ })
+ } else if lookahead.peek(Ident) {
+ let key = input.parse::<Ident>()?;
+ match &*key.to_string() {
+ "data" => {
+ input.parse::<Token![:]>()?;
+ VariantProperty::Data(input.parse::<Type>()?)
+ }
+ "space" => {
+ input.parse::<Token![:]>()?;
+ VariantProperty::Space(input.parse::<Expr>()?)
+ }
+ "arguments" => {
+ let generics = if input.peek(Token![<]) {
+ input.parse::<Token![<]>()?;
+ let gen_params =
+ Punctuated::<PathSegment, syn::token::PathSep>::parse_separated_nonempty(input)?;
+ input.parse::<Token![>]>()?;
+ Some(gen_params)
+ } else {
+ None
+ };
+ input.parse::<Token![:]>()?;
+ if input.peek(token::Brace) {
+ let fields;
+ braced!(fields in input);
+ VariantProperty::Arguments(Arguments::Def(InstructionArguments::parse(
+ generics, &fields,
+ )?))
+ } else {
+ VariantProperty::Arguments(Arguments::Decl(input.parse::<Type>()?))
+ }
+ }
+ "visit" => {
+ input.parse::<Token![:]>()?;
+ VariantProperty::Visit(input.parse::<Expr>()?)
+ }
+ "visit_mut" => {
+ input.parse::<Token![:]>()?;
+ VariantProperty::VisitMut(input.parse::<Expr>()?)
+ }
+ "map" => {
+ input.parse::<Token![:]>()?;
+ VariantProperty::Map(input.parse::<Expr>()?)
+ }
+ x => {
+ return Err(syn::Error::new(
+ key.span(),
+ format!(
+ "Unexpected key `{}`. Expected `type`, `data`, `arguments`, `visit, `visit_mut` or `map`.",
+ x
+ ),
+ ))
+ }
+ }
+ } else {
+ return Err(lookahead.error());
+ })
+ }
+}
+
+pub enum Arguments {
+ Decl(Type),
+ Def(InstructionArguments),
+}
+
+pub struct InstructionArguments {
+ pub generic: Option<Punctuated<PathSegment, syn::token::PathSep>>,
+ pub fields: Punctuated<ArgumentField, Token![,]>,
+}
+
+impl InstructionArguments {
+ pub fn parse(
+ generic: Option<Punctuated<PathSegment, syn::token::PathSep>>,
+ input: syn::parse::ParseStream,
+ ) -> syn::Result<Self> {
+ let fields = Punctuated::<ArgumentField, Token![,]>::parse_terminated_with(
+ input,
+ ArgumentField::parse,
+ )?;
+ Ok(Self { generic, fields })
+ }
+
+ fn emit_visit(
+ &self,
+ parent_type: &Option<Option<Expr>>,
+ parent_space: &Option<Expr>,
+ ) -> TokenStream {
+ self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit)
+ }
+
+ fn emit_visit_mut(
+ &self,
+ parent_type: &Option<Option<Expr>>,
+ parent_space: &Option<Expr>,
+ ) -> TokenStream {
+ self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_mut)
+ }
+
+ fn emit_visit_map(
+ &self,
+ parent_type: &Option<Option<Expr>>,
+ parent_space: &Option<Expr>,
+ ) -> TokenStream {
+ self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_map)
+ }
+
+ fn emit_visit_impl(
+ &self,
+ parent_type: &Option<Option<Expr>>,
+ parent_space: &Option<Expr>,
+ mut fn_: impl FnMut(&ArgumentField, &Option<Option<Expr>>, &Option<Expr>, bool) -> TokenStream,
+ ) -> TokenStream {
+ let is_ident = if let Some(ref generic) = self.generic {
+ generic.len() > 1
+ } else {
+ false
+ };
+ let field_calls = self
+ .fields
+ .iter()
+ .map(|f| fn_(f, parent_type, parent_space, is_ident));
+ quote! {
+ #(#field_calls)*
+ }
+ }
+}
+
+pub struct ArgumentField {
+ pub name: Ident,
+ pub is_dst: bool,
+ pub repr: Type,
+ pub space: Option<Expr>,
+ pub type_: Option<Expr>,
+ pub relaxed_type_check: bool,
+}
+
+impl ArgumentField {
+ fn parse_block(
+ input: syn::parse::ParseStream,
+ ) -> syn::Result<(Type, Option<Expr>, Option<Expr>, Option<bool>, bool)> {
+ let content;
+ braced!(content in input);
+ let all_fields =
+ Punctuated::<ExprOrPath, Token![,]>::parse_terminated_with(&content, |content| {
+ let lookahead = content.lookahead1();
+ Ok(if lookahead.peek(Token![type]) {
+ content.parse::<Token![type]>()?;
+ content.parse::<Token![:]>()?;
+ ExprOrPath::Type(content.parse::<Expr>()?)
+ } else if lookahead.peek(Ident) {
+ let name_ident = content.parse::<Ident>()?;
+ content.parse::<Token![:]>()?;
+ match &*name_ident.to_string() {
+ "relaxed_type_check" => {
+ ExprOrPath::RelaxedTypeCheck(content.parse::<LitBool>()?.value)
+ }
+ "repr" => ExprOrPath::Repr(content.parse::<Type>()?),
+ "space" => ExprOrPath::Space(content.parse::<Expr>()?),
+ "dst" => {
+ let ident = content.parse::<LitBool>()?;
+ ExprOrPath::Dst(ident.value)
+ }
+ name => {
+ return Err(syn::Error::new(
+ name_ident.span(),
+ format!("Unexpected key `{}`, expected `repr` or `space", name),
+ ))
+ }
+ }
+ } else {
+ return Err(lookahead.error());
+ })
+ })?;
+ let mut repr = None;
+ let mut type_ = None;
+ let mut space = None;
+ let mut is_dst = None;
+ let mut relaxed_type_check = false;
+ for exp_or_path in all_fields {
+ match exp_or_path {
+ ExprOrPath::Repr(r) => repr = Some(r),
+ ExprOrPath::Type(t) => type_ = Some(t),
+ ExprOrPath::Space(s) => space = Some(s),
+ ExprOrPath::Dst(x) => is_dst = Some(x),
+ ExprOrPath::RelaxedTypeCheck(relaxed) => relaxed_type_check = relaxed,
+ }
+ }
+ Ok((repr.unwrap(), type_, space, is_dst, relaxed_type_check))
+ }
+
+ fn parse_basic(input: &syn::parse::ParseBuffer) -> syn::Result<Type> {
+ input.parse::<Type>()
+ }
+
+ fn emit_visit(
+ &self,
+ parent_type: &Option<Option<Expr>>,
+ parent_space: &Option<Expr>,
+ is_ident: bool,
+ ) -> TokenStream {
+ self.emit_visit_impl(parent_type, parent_space, is_ident, false)
+ }
+
+ fn emit_visit_mut(
+ &self,
+ parent_type: &Option<Option<Expr>>,
+ parent_space: &Option<Expr>,
+ is_ident: bool,
+ ) -> TokenStream {
+ self.emit_visit_impl(parent_type, parent_space, is_ident, true)
+ }
+
+ fn emit_visit_impl(
+ &self,
+ parent_type: &Option<Option<Expr>>,
+ parent_space: &Option<Expr>,
+ is_ident: bool,
+ is_mut: bool,
+ ) -> TokenStream {
+ let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) {
+ (Some(type_), _) => (false, Some(type_)),
+ (None, None) => panic!("No type set"),
+ (None, Some(None)) => (true, None),
+ (None, Some(Some(type_))) => (false, Some(type_)),
+ };
+ let space = self
+ .space
+ .as_ref()
+ .or(parent_space.as_ref())
+ .map(|space| quote! { #space })
+ .unwrap_or_else(|| quote! { StateSpace::Reg });
+ let is_dst = self.is_dst;
+ let relaxed_type_check = self.relaxed_type_check;
+ let name = &self.name;
+ let type_space = if is_typeless {
+ quote! {
+ let type_space = None;
+ }
+ } else {
+ quote! {
+ let type_ = #type_;
+ let space = #space;
+ let type_space = Some((std::borrow::Borrow::<Type>::borrow(&type_), space));
+ }
+ };
+ if is_ident {
+ if is_mut {
+ quote! {
+ {
+ #type_space
+ visitor.visit_ident(&mut arguments.#name, type_space, #is_dst, #relaxed_type_check)?;
+ }
+ }
+ } else {
+ quote! {
+ {
+ #type_space
+ visitor.visit_ident(& arguments.#name, type_space, #is_dst, #relaxed_type_check)?;
+ }
+ }
+ }
+ } else {
+ let (operand_fn, arguments_name) = if is_mut {
+ (
+ quote! {
+ VisitOperand::visit_mut
+ },
+ quote! {
+ &mut arguments.#name
+ },
+ )
+ } else {
+ (
+ quote! {
+ VisitOperand::visit
+ },
+ quote! {
+ & arguments.#name
+ },
+ )
+ };
+ quote! {{
+ #type_space
+ #operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))?;
+ }}
+ }
+ }
+
+ fn emit_visit_map(
+ &self,
+ parent_type: &Option<Option<Expr>>,
+ parent_space: &Option<Expr>,
+ is_ident: bool,
+ ) -> TokenStream {
+ let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) {
+ (Some(type_), _) => (false, Some(type_)),
+ (None, None) => panic!("No type set"),
+ (None, Some(None)) => (true, None),
+ (None, Some(Some(type_))) => (false, Some(type_)),
+ };
+ let space = self
+ .space
+ .as_ref()
+ .or(parent_space.as_ref())
+ .map(|space| quote! { #space })
+ .unwrap_or_else(|| quote! { StateSpace::Reg });
+ let is_dst = self.is_dst;
+ let relaxed_type_check = self.relaxed_type_check;
+ let name = &self.name;
+ let type_space = if is_typeless {
+ quote! {
+ let type_space = None;
+ }
+ } else {
+ quote! {
+ let type_ = #type_;
+ let space = #space;
+ let type_space = Some((std::borrow::Borrow::<Type>::borrow(&type_), space));
+ }
+ };
+ let map_call = if is_ident {
+ quote! {
+ visitor.visit_ident(arguments.#name, type_space, #is_dst, #relaxed_type_check)?
+ }
+ } else {
+ quote! {
+ MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))?
+ }
+ };
+ quote! {
+ let #name = {
+ #type_space
+ #map_call
+ };
+ }
+ }
+
+ fn is_dst(name: &Ident) -> syn::Result<bool> {
+ if name.to_string().starts_with("dst") {
+ Ok(true)
+ } else if name.to_string().starts_with("src") {
+ Ok(false)
+ } else {
+ return Err(syn::Error::new(
+ name.span(),
+ format!(
+ "Could not guess if `{}` is a read or write argument. Name should start with `dst` or `src`",
+ name
+ ),
+ ));
+ }
+ }
+
+ fn emit_field(&self, vis: &Option<Visibility>) -> TokenStream {
+ let name = &self.name;
+ let type_ = &self.repr;
+ quote! {
+ #vis #name: #type_
+ }
+ }
+}
+
+impl Parse for ArgumentField {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ let name = input.parse::<Ident>()?;
+
+ input.parse::<Token![:]>()?;
+ let lookahead = input.lookahead1();
+ let (repr, type_, space, is_dst, relaxed_type_check) = if lookahead.peek(token::Brace) {
+ Self::parse_block(input)?
+ } else if lookahead.peek(syn::Ident) {
+ (Self::parse_basic(input)?, None, None, None, false)
+ } else {
+ return Err(lookahead.error());
+ };
+ let is_dst = match is_dst {
+ Some(x) => x,
+ None => Self::is_dst(&name)?,
+ };
+ Ok(Self {
+ name,
+ is_dst,
+ repr,
+ type_,
+ space,
+ relaxed_type_check
+ })
+ }
+}
+
+enum ExprOrPath {
+ Repr(Type),
+ Type(Expr),
+ Space(Expr),
+ Dst(bool),
+ RelaxedTypeCheck(bool),
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use proc_macro2::Span;
+ use quote::{quote, ToTokens};
+
+ fn to_string(x: impl ToTokens) -> String {
+ quote! { #x }.to_string()
+ }
+
+ #[test]
+ fn parse_argument_field_basic() {
+ let input = quote! {
+ dst: P::Operand
+ };
+ let arg = syn::parse2::<ArgumentField>(input).unwrap();
+ assert_eq!("dst", arg.name.to_string());
+ assert_eq!("P :: Operand", to_string(arg.repr));
+ assert!(matches!(arg.type_, None));
+ }
+
+ #[test]
+ fn parse_argument_field_block() {
+ let input = quote! {
+ dst: {
+ type: ScalarType::U32,
+ space: StateSpace::Global,
+ repr: P::Operand,
+ }
+ };
+ let arg = syn::parse2::<ArgumentField>(input).unwrap();
+ assert_eq!("dst", arg.name.to_string());
+ assert_eq!("ScalarType :: U32", to_string(arg.type_.unwrap()));
+ assert_eq!("StateSpace :: Global", to_string(arg.space.unwrap()));
+ assert_eq!("P :: Operand", to_string(arg.repr));
+ }
+
+ #[test]
+ fn parse_argument_field_block_untyped() {
+ let input = quote! {
+ dst: {
+ repr: P::Operand,
+ }
+ };
+ let arg = syn::parse2::<ArgumentField>(input).unwrap();
+ assert_eq!("dst", arg.name.to_string());
+ assert_eq!("P :: Operand", to_string(arg.repr));
+ assert!(matches!(arg.type_, None));
+ }
+
+ #[test]
+ fn parse_variant_complex() {
+ let input = quote! {
+ Ld {
+ type: ScalarType::U32,
+ space: StateSpace::Global,
+ data: LdDetails,
+ arguments<P>: {
+ dst: {
+ repr: P::Operand,
+ type: ScalarType::U32,
+ space: StateSpace::Shared,
+ },
+ src: P::Operand,
+ },
+ }
+ };
+ let variant = syn::parse2::<InstructionVariant>(input).unwrap();
+ assert_eq!("Ld", variant.name.to_string());
+ assert_eq!("ScalarType :: U32", to_string(variant.type_.unwrap()));
+ assert_eq!("StateSpace :: Global", to_string(variant.space.unwrap()));
+ assert_eq!("LdDetails", to_string(variant.data.unwrap()));
+ let arguments = if let Some(Arguments::Def(a)) = variant.arguments {
+ a
+ } else {
+ panic!()
+ };
+ assert_eq!("P", to_string(arguments.generic));
+ let mut fields = arguments.fields.into_iter();
+ let dst = fields.next().unwrap();
+ assert_eq!("P :: Operand", to_string(dst.repr));
+ assert_eq!("ScalarType :: U32", to_string(dst.type_));
+ assert_eq!("StateSpace :: Shared", to_string(dst.space));
+ let src = fields.next().unwrap();
+ assert_eq!("P :: Operand", to_string(src.repr));
+ assert!(matches!(src.type_, None));
+ assert!(matches!(src.space, None));
+ }
+
+ #[test]
+ fn visit_variant_empty() {
+ let input = quote! {
+ Ret {
+ data: RetData
+ }
+ };
+ let variant = syn::parse2::<InstructionVariant>(input).unwrap();
+ let mut output = TokenStream::new();
+ variant.emit_visit(&Ident::new("Instruction", Span::call_site()), &mut output);
+ assert_eq!(output.to_string(), "Instruction :: Ret { .. } => { }");
+ }
+}
diff --git a/ptx_parser_macros_impl/src/parser.rs b/ptx_parser_macros_impl/src/parser.rs new file mode 100644 index 0000000..f1cd738 --- /dev/null +++ b/ptx_parser_macros_impl/src/parser.rs @@ -0,0 +1,844 @@ +use proc_macro2::Span;
+use proc_macro2::TokenStream;
+use quote::quote;
+use quote::ToTokens;
+use rustc_hash::FxHashMap;
+use std::fmt::Write;
+use syn::bracketed;
+use syn::parse::Peek;
+use syn::punctuated::Punctuated;
+use syn::spanned::Spanned;
+use syn::LitInt;
+use syn::Type;
+use syn::{braced, parse::Parse, token, Ident, ItemEnum, Token};
+
+pub struct ParseDefinitions {
+ pub token_type: ItemEnum,
+ pub additional_enums: FxHashMap<Ident, ItemEnum>,
+ pub definitions: Vec<OpcodeDefinition>,
+}
+
+impl Parse for ParseDefinitions {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ let token_type = input.parse::<ItemEnum>()?;
+ let mut additional_enums = FxHashMap::default();
+ while input.peek(Token![#]) {
+ let enum_ = input.parse::<ItemEnum>()?;
+ additional_enums.insert(enum_.ident.clone(), enum_);
+ }
+ let mut definitions = Vec::new();
+ while !input.is_empty() {
+ definitions.push(input.parse::<OpcodeDefinition>()?);
+ }
+ Ok(Self {
+ token_type,
+ additional_enums,
+ definitions,
+ })
+ }
+}
+
+pub struct OpcodeDefinition(pub Patterns, pub Vec<Rule>);
+
+impl Parse for OpcodeDefinition {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ let patterns = input.parse::<Patterns>()?;
+ let mut rules = Vec::new();
+ while Rule::peek(input) {
+ rules.push(input.parse::<Rule>()?);
+ input.parse::<Token![;]>()?;
+ }
+ Ok(Self(patterns, rules))
+ }
+}
+
+pub struct Patterns(pub Vec<(OpcodeDecl, CodeBlock)>);
+
+impl Parse for Patterns {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ let mut result = Vec::new();
+ loop {
+ if !OpcodeDecl::peek(input) {
+ break;
+ }
+ let decl = input.parse::<OpcodeDecl>()?;
+ let code_block = input.parse::<CodeBlock>()?;
+ result.push((decl, code_block))
+ }
+ Ok(Self(result))
+ }
+}
+
+pub struct OpcodeDecl(pub Instruction, pub Arguments);
+
+impl OpcodeDecl {
+ fn peek(input: syn::parse::ParseStream) -> bool {
+ Instruction::peek(input) && !input.peek2(Token![=])
+ }
+}
+
+impl Parse for OpcodeDecl {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ Ok(Self(
+ input.parse::<Instruction>()?,
+ input.parse::<Arguments>()?,
+ ))
+ }
+}
+
+pub struct CodeBlock {
+ pub special: bool,
+ pub code: proc_macro2::Group,
+}
+
+impl Parse for CodeBlock {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ let lookahead = input.lookahead1();
+ let (special, code) = if lookahead.peek(Token![<]) {
+ input.parse::<Token![<]>()?;
+ input.parse::<Token![=]>()?;
+ //input.parse::<Token![>]>()?;
+ (true, input.parse::<proc_macro2::Group>()?)
+ } else if lookahead.peek(Token![=]) {
+ input.parse::<Token![=]>()?;
+ input.parse::<Token![>]>()?;
+ (false, input.parse::<proc_macro2::Group>()?)
+ } else {
+ return Err(lookahead.error());
+ };
+ Ok(Self { special, code })
+ }
+}
+
+pub struct Rule {
+ pub modifier: Option<DotModifier>,
+ pub type_: Option<Type>,
+ pub alternatives: Vec<DotModifier>,
+}
+
+impl Rule {
+ fn peek(input: syn::parse::ParseStream) -> bool {
+ DotModifier::peek(input)
+ || (input.peek(Ident) && input.peek2(Token![=]) && !input.peek3(Token![>]))
+ }
+
+ fn parse_alternatives(input: syn::parse::ParseStream) -> syn::Result<Vec<DotModifier>> {
+ let mut result = Vec::new();
+ Self::parse_with_alternative(input, &mut result)?;
+ loop {
+ if !input.peek(Token![,]) {
+ break;
+ }
+ input.parse::<Token![,]>()?;
+ Self::parse_with_alternative(input, &mut result)?;
+ }
+ Ok(result)
+ }
+
+ fn parse_with_alternative(
+ input: &syn::parse::ParseBuffer,
+ result: &mut Vec<DotModifier>,
+ ) -> Result<(), syn::Error> {
+ input.parse::<Token![.]>()?;
+ let part1 = input.parse::<IdentLike>()?;
+ if input.peek(token::Brace) {
+ result.push(DotModifier {
+ part1: part1.clone(),
+ part2: None,
+ });
+ let suffix_content;
+ braced!(suffix_content in input);
+ let suffixes = Punctuated::<IdentOrTypeSuffix, Token![,]>::parse_separated_nonempty(
+ &suffix_content,
+ )?;
+ for part2 in suffixes {
+ result.push(DotModifier {
+ part1: part1.clone(),
+ part2: Some(part2),
+ });
+ }
+ } else if IdentOrTypeSuffix::peek(input) {
+ let part2 = Some(IdentOrTypeSuffix::parse(input)?);
+ result.push(DotModifier { part1, part2 });
+ } else {
+ result.push(DotModifier { part1, part2: None });
+ }
+ Ok(())
+ }
+}
+
+#[derive(PartialEq, Eq, Hash, Clone)]
+struct IdentOrTypeSuffix(IdentLike);
+
+impl IdentOrTypeSuffix {
+ fn span(&self) -> Span {
+ self.0.span()
+ }
+
+ fn peek(input: syn::parse::ParseStream) -> bool {
+ input.peek(Token![::])
+ }
+}
+
+impl ToTokens for IdentOrTypeSuffix {
+ fn to_tokens(&self, tokens: &mut TokenStream) {
+ let ident = &self.0;
+ quote! { :: #ident }.to_tokens(tokens)
+ }
+}
+
+impl Parse for IdentOrTypeSuffix {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ input.parse::<Token![::]>()?;
+ Ok(Self(input.parse::<IdentLike>()?))
+ }
+}
+
+impl Parse for Rule {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ let (modifier, type_) = if DotModifier::peek(input) {
+ let modifier = Some(input.parse::<DotModifier>()?);
+ if input.peek(Token![:]) {
+ input.parse::<Token![:]>()?;
+ (modifier, Some(input.parse::<Type>()?))
+ } else {
+ (modifier, None)
+ }
+ } else {
+ (None, Some(input.parse::<Type>()?))
+ };
+ input.parse::<Token![=]>()?;
+ let content;
+ braced!(content in input);
+ let alternatives = Self::parse_alternatives(&content)?;
+ Ok(Self {
+ modifier,
+ type_,
+ alternatives,
+ })
+ }
+}
+
+pub struct Instruction {
+ pub name: Ident,
+ pub modifiers: Vec<MaybeDotModifier>,
+}
+impl Instruction {
+ fn peek(input: syn::parse::ParseStream) -> bool {
+ input.peek(Ident)
+ }
+}
+
+impl Parse for Instruction {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ let instruction = input.parse::<Ident>()?;
+ let mut modifiers = Vec::new();
+ loop {
+ if !MaybeDotModifier::peek(input) {
+ break;
+ }
+ modifiers.push(MaybeDotModifier::parse(input)?);
+ }
+ Ok(Self {
+ name: instruction,
+ modifiers,
+ })
+ }
+}
+
+pub struct MaybeDotModifier {
+ pub optional: bool,
+ pub modifier: DotModifier,
+}
+
+impl MaybeDotModifier {
+ fn peek(input: syn::parse::ParseStream) -> bool {
+ input.peek(token::Brace) || DotModifier::peek(input)
+ }
+}
+
+impl Parse for MaybeDotModifier {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ Ok(if input.peek(token::Brace) {
+ let content;
+ braced!(content in input);
+ let modifier = DotModifier::parse(&content)?;
+ Self {
+ modifier,
+ optional: true,
+ }
+ } else {
+ let modifier = DotModifier::parse(input)?;
+ Self {
+ modifier,
+ optional: false,
+ }
+ })
+ }
+}
+
+#[derive(PartialEq, Eq, Hash, Clone)]
+pub struct DotModifier {
+ part1: IdentLike,
+ part2: Option<IdentOrTypeSuffix>,
+}
+
+impl std::fmt::Display for DotModifier {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, ".")?;
+ self.part1.fmt(f)?;
+ if let Some(ref part2) = self.part2 {
+ write!(f, "::")?;
+ part2.0.fmt(f)?;
+ }
+ Ok(())
+ }
+}
+
+impl std::fmt::Debug for DotModifier {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ std::fmt::Display::fmt(&self, f)
+ }
+}
+
+impl DotModifier {
+ pub fn span(&self) -> Span {
+ let part1 = self.part1.span();
+ if let Some(ref part2) = self.part2 {
+ part1.join(part2.span()).unwrap_or(part1)
+ } else {
+ part1
+ }
+ }
+
+ pub fn ident(&self) -> Ident {
+ let mut result = String::new();
+ write!(&mut result, "{}", self.part1).unwrap();
+ if let Some(ref part2) = self.part2 {
+ write!(&mut result, "_{}", part2.0).unwrap();
+ } else {
+ match self.part1 {
+ IdentLike::Type(_) | IdentLike::Const(_) => result.push('_'),
+ IdentLike::Ident(_) | IdentLike::Integer(_) => {}
+ }
+ }
+ Ident::new(&result.to_ascii_lowercase(), self.span())
+ }
+
+ pub fn variant_capitalized(&self) -> Ident {
+ self.capitalized_impl(String::new())
+ }
+
+ pub fn dot_capitalized(&self) -> Ident {
+ self.capitalized_impl("Dot".to_string())
+ }
+
+ fn capitalized_impl(&self, prefix: String) -> Ident {
+ let mut temp = String::new();
+ write!(&mut temp, "{}", &self.part1).unwrap();
+ if let Some(IdentOrTypeSuffix(ref part2)) = self.part2 {
+ write!(&mut temp, "_{}", part2).unwrap();
+ }
+ let mut result = prefix;
+ let mut capitalize = true;
+ for c in temp.chars() {
+ if c == '_' {
+ capitalize = true;
+ continue;
+ }
+ // Special hack to emit `BF16`` instead of `Bf16``
+ let c = if capitalize || c == 'f' && result.ends_with('B') {
+ capitalize = false;
+ c.to_ascii_uppercase()
+ } else {
+ c
+ };
+ result.push(c);
+ }
+ Ident::new(&result, self.span())
+ }
+
+ pub fn tokens(&self) -> TokenStream {
+ let part1 = &self.part1;
+ let part2 = &self.part2;
+ match self.part2 {
+ None => quote! { . #part1 },
+ Some(_) => quote! { . #part1 #part2 },
+ }
+ }
+
+ fn peek(input: syn::parse::ParseStream) -> bool {
+ input.peek(Token![.])
+ }
+}
+
+impl Parse for DotModifier {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ input.parse::<Token![.]>()?;
+ let part1 = input.parse::<IdentLike>()?;
+ if IdentOrTypeSuffix::peek(input) {
+ let part2 = Some(IdentOrTypeSuffix::parse(input)?);
+ Ok(Self { part1, part2 })
+ } else {
+ Ok(Self { part1, part2: None })
+ }
+ }
+}
+
+#[derive(PartialEq, Eq, Hash, Clone)]
+enum IdentLike {
+ Type(Token![type]),
+ Const(Token![const]),
+ Ident(Ident),
+ Integer(LitInt),
+}
+
+impl IdentLike {
+ fn span(&self) -> Span {
+ match self {
+ IdentLike::Type(c) => c.span(),
+ IdentLike::Const(t) => t.span(),
+ IdentLike::Ident(i) => i.span(),
+ IdentLike::Integer(l) => l.span(),
+ }
+ }
+}
+
+impl std::fmt::Display for IdentLike {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ IdentLike::Type(_) => f.write_str("type"),
+ IdentLike::Const(_) => f.write_str("const"),
+ IdentLike::Ident(ident) => write!(f, "{}", ident),
+ IdentLike::Integer(integer) => write!(f, "{}", integer),
+ }
+ }
+}
+
+impl ToTokens for IdentLike {
+ fn to_tokens(&self, tokens: &mut TokenStream) {
+ match self {
+ IdentLike::Type(_) => quote! { type }.to_tokens(tokens),
+ IdentLike::Const(_) => quote! { const }.to_tokens(tokens),
+ IdentLike::Ident(ident) => quote! { #ident }.to_tokens(tokens),
+ IdentLike::Integer(int) => quote! { #int }.to_tokens(tokens),
+ }
+ }
+}
+
+impl Parse for IdentLike {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ let lookahead = input.lookahead1();
+ Ok(if lookahead.peek(Token![const]) {
+ IdentLike::Const(input.parse::<Token![const]>()?)
+ } else if lookahead.peek(Token![type]) {
+ IdentLike::Type(input.parse::<Token![type]>()?)
+ } else if lookahead.peek(Ident) {
+ IdentLike::Ident(input.parse::<Ident>()?)
+ } else if lookahead.peek(LitInt) {
+ IdentLike::Integer(input.parse::<LitInt>()?)
+ } else {
+ return Err(lookahead.error());
+ })
+ }
+}
+
+// Arguments decalaration can loook like this:
+// a{, b}
+// That's why we don't parse Arguments as Punctuated<Argument, Token![,]>
+#[derive(PartialEq, Eq)]
+pub struct Arguments(pub Vec<Argument>);
+
+impl Parse for Arguments {
+ fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
+ let mut result = Vec::new();
+ loop {
+ if input.peek(Token![,]) {
+ input.parse::<Token![,]>()?;
+ }
+ let mut optional = false;
+ let mut can_be_negated = false;
+ let mut pre_pipe = false;
+ let ident;
+ let lookahead = input.lookahead1();
+ if lookahead.peek(token::Brace) {
+ let content;
+ braced!(content in input);
+ let lookahead = content.lookahead1();
+ if lookahead.peek(Token![!]) {
+ content.parse::<Token![!]>()?;
+ can_be_negated = true;
+ ident = input.parse::<Ident>()?;
+ } else if lookahead.peek(Token![,]) {
+ optional = true;
+ content.parse::<Token![,]>()?;
+ ident = content.parse::<Ident>()?;
+ } else {
+ return Err(lookahead.error());
+ }
+ } else if lookahead.peek(token::Bracket) {
+ let bracketed;
+ bracketed!(bracketed in input);
+ if bracketed.peek(Token![|]) {
+ optional = true;
+ bracketed.parse::<Token![|]>()?;
+ pre_pipe = true;
+ ident = bracketed.parse::<Ident>()?;
+ } else {
+ let mut sub_args = Self::parse(&bracketed)?;
+ sub_args.0.first_mut().unwrap().pre_bracket = true;
+ sub_args.0.last_mut().unwrap().post_bracket = true;
+ if peek_brace_token(input, Token![.]) {
+ let optional_suffix;
+ braced!(optional_suffix in input);
+ optional_suffix.parse::<Token![.]>()?;
+ let unified_ident = optional_suffix.parse::<Ident>()?;
+ if unified_ident.to_string() != "unified" {
+ return Err(syn::Error::new(
+ unified_ident.span(),
+ format!("Exptected `unified`, got `{}`", unified_ident),
+ ));
+ }
+ for a in sub_args.0.iter_mut() {
+ a.unified = true;
+ }
+ }
+ result.extend(sub_args.0);
+ continue;
+ }
+ } else if lookahead.peek(Ident) {
+ ident = input.parse::<Ident>()?;
+ } else if lookahead.peek(Token![|]) {
+ input.parse::<Token![|]>()?;
+ pre_pipe = true;
+ ident = input.parse::<Ident>()?;
+ } else {
+ break;
+ }
+ result.push(Argument {
+ optional,
+ pre_pipe,
+ can_be_negated,
+ pre_bracket: false,
+ ident,
+ post_bracket: false,
+ unified: false,
+ });
+ }
+ Ok(Self(result))
+ }
+}
+
+// This is effectively input.peek(token::Brace) && input.peek2(Token![.])
+// input.peek2 is supposed to skip over next token, but it skips over whole
+// braced token group. Not sure if it's a bug
+fn peek_brace_token<T: Peek>(input: syn::parse::ParseStream, _t: T) -> bool {
+ use syn::token::Token;
+ let cursor = input.cursor();
+ cursor
+ .group(proc_macro2::Delimiter::Brace)
+ .map_or(false, |(content, ..)| T::Token::peek(content))
+}
+
+#[derive(PartialEq, Eq)]
+pub struct Argument {
+ pub optional: bool,
+ pub pre_bracket: bool,
+ pub pre_pipe: bool,
+ pub can_be_negated: bool,
+ pub ident: Ident,
+ pub post_bracket: bool,
+ pub unified: bool,
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{Arguments, DotModifier, MaybeDotModifier};
+ use quote::{quote, ToTokens};
+
+ #[test]
+ fn parse_modifier_complex() {
+ let input = quote! {
+ .level::eviction_priority
+ };
+ let modifier = syn::parse2::<DotModifier>(input).unwrap();
+ assert_eq!(
+ ". level :: eviction_priority",
+ modifier.tokens().to_string()
+ );
+ }
+
+ #[test]
+ fn parse_modifier_optional() {
+ let input = quote! {
+ { .level::eviction_priority }
+ };
+ let maybe_modifider = syn::parse2::<MaybeDotModifier>(input).unwrap();
+ assert_eq!(
+ ". level :: eviction_priority",
+ maybe_modifider.modifier.tokens().to_string()
+ );
+ assert!(maybe_modifider.optional);
+ }
+
+ #[test]
+ fn parse_type_token() {
+ let input = quote! {
+ . type
+ };
+ let maybe_modifier = syn::parse2::<MaybeDotModifier>(input).unwrap();
+ assert_eq!(". type", maybe_modifier.modifier.tokens().to_string());
+ assert!(!maybe_modifier.optional);
+ }
+
+ #[test]
+ fn arguments_memory() {
+ let input = quote! {
+ [a], b
+ };
+ let arguments = syn::parse2::<Arguments>(input).unwrap();
+ let a = &arguments.0[0];
+ assert!(!a.optional);
+ assert_eq!("a", a.ident.to_string());
+ assert!(a.pre_bracket);
+ assert!(!a.pre_pipe);
+ assert!(a.post_bracket);
+ assert!(!a.can_be_negated);
+ let b = &arguments.0[1];
+ assert!(!b.optional);
+ assert_eq!("b", b.ident.to_string());
+ assert!(!b.pre_bracket);
+ assert!(!b.pre_pipe);
+ assert!(!b.post_bracket);
+ assert!(!b.can_be_negated);
+ }
+
+ #[test]
+ fn arguments_optional() {
+ let input = quote! {
+ b{, cache_policy}
+ };
+ let arguments = syn::parse2::<Arguments>(input).unwrap();
+ let b = &arguments.0[0];
+ assert!(!b.optional);
+ assert_eq!("b", b.ident.to_string());
+ assert!(!b.pre_bracket);
+ assert!(!b.pre_pipe);
+ assert!(!b.post_bracket);
+ assert!(!b.can_be_negated);
+ let cache_policy = &arguments.0[1];
+ assert!(cache_policy.optional);
+ assert_eq!("cache_policy", cache_policy.ident.to_string());
+ assert!(!cache_policy.pre_bracket);
+ assert!(!cache_policy.pre_pipe);
+ assert!(!cache_policy.post_bracket);
+ assert!(!cache_policy.can_be_negated);
+ }
+
+ #[test]
+ fn arguments_optional_pred() {
+ let input = quote! {
+ p[|q], a
+ };
+ let arguments = syn::parse2::<Arguments>(input).unwrap();
+ assert_eq!(arguments.0.len(), 3);
+ let p = &arguments.0[0];
+ assert!(!p.optional);
+ assert_eq!("p", p.ident.to_string());
+ assert!(!p.pre_bracket);
+ assert!(!p.pre_pipe);
+ assert!(!p.post_bracket);
+ assert!(!p.can_be_negated);
+ let q = &arguments.0[1];
+ assert!(q.optional);
+ assert_eq!("q", q.ident.to_string());
+ assert!(!q.pre_bracket);
+ assert!(q.pre_pipe);
+ assert!(!q.post_bracket);
+ assert!(!q.can_be_negated);
+ let a = &arguments.0[2];
+ assert!(!a.optional);
+ assert_eq!("a", a.ident.to_string());
+ assert!(!a.pre_bracket);
+ assert!(!a.pre_pipe);
+ assert!(!a.post_bracket);
+ assert!(!a.can_be_negated);
+ }
+
+ #[test]
+ fn arguments_optional_with_negate() {
+ let input = quote! {
+ b, {!}c
+ };
+ let arguments = syn::parse2::<Arguments>(input).unwrap();
+ assert_eq!(arguments.0.len(), 2);
+ let b = &arguments.0[0];
+ assert!(!b.optional);
+ assert_eq!("b", b.ident.to_string());
+ assert!(!b.pre_bracket);
+ assert!(!b.pre_pipe);
+ assert!(!b.post_bracket);
+ assert!(!b.can_be_negated);
+ let c = &arguments.0[1];
+ assert!(!c.optional);
+ assert_eq!("c", c.ident.to_string());
+ assert!(!c.pre_bracket);
+ assert!(!c.pre_pipe);
+ assert!(!c.post_bracket);
+ assert!(c.can_be_negated);
+ }
+
+ #[test]
+ fn arguments_tex() {
+ let input = quote! {
+ d[|p], [a{, b}, c], dpdx, dpdy {, e}
+ };
+ let arguments = syn::parse2::<Arguments>(input).unwrap();
+ assert_eq!(arguments.0.len(), 8);
+ {
+ let d = &arguments.0[0];
+ assert!(!d.optional);
+ assert_eq!("d", d.ident.to_string());
+ assert!(!d.pre_bracket);
+ assert!(!d.pre_pipe);
+ assert!(!d.post_bracket);
+ assert!(!d.can_be_negated);
+ }
+ {
+ let p = &arguments.0[1];
+ assert!(p.optional);
+ assert_eq!("p", p.ident.to_string());
+ assert!(!p.pre_bracket);
+ assert!(p.pre_pipe);
+ assert!(!p.post_bracket);
+ assert!(!p.can_be_negated);
+ }
+ {
+ let a = &arguments.0[2];
+ assert!(!a.optional);
+ assert_eq!("a", a.ident.to_string());
+ assert!(a.pre_bracket);
+ assert!(!a.pre_pipe);
+ assert!(!a.post_bracket);
+ assert!(!a.can_be_negated);
+ }
+ {
+ let b = &arguments.0[3];
+ assert!(b.optional);
+ assert_eq!("b", b.ident.to_string());
+ assert!(!b.pre_bracket);
+ assert!(!b.pre_pipe);
+ assert!(!b.post_bracket);
+ assert!(!b.can_be_negated);
+ }
+ {
+ let c = &arguments.0[4];
+ assert!(!c.optional);
+ assert_eq!("c", c.ident.to_string());
+ assert!(!c.pre_bracket);
+ assert!(!c.pre_pipe);
+ assert!(c.post_bracket);
+ assert!(!c.can_be_negated);
+ }
+ {
+ let dpdx = &arguments.0[5];
+ assert!(!dpdx.optional);
+ assert_eq!("dpdx", dpdx.ident.to_string());
+ assert!(!dpdx.pre_bracket);
+ assert!(!dpdx.pre_pipe);
+ assert!(!dpdx.post_bracket);
+ assert!(!dpdx.can_be_negated);
+ }
+ {
+ let dpdy = &arguments.0[6];
+ assert!(!dpdy.optional);
+ assert_eq!("dpdy", dpdy.ident.to_string());
+ assert!(!dpdy.pre_bracket);
+ assert!(!dpdy.pre_pipe);
+ assert!(!dpdy.post_bracket);
+ assert!(!dpdy.can_be_negated);
+ }
+ {
+ let e = &arguments.0[7];
+ assert!(e.optional);
+ assert_eq!("e", e.ident.to_string());
+ assert!(!e.pre_bracket);
+ assert!(!e.pre_pipe);
+ assert!(!e.post_bracket);
+ assert!(!e.can_be_negated);
+ }
+ }
+
+ #[test]
+ fn rule_multi() {
+ let input = quote! {
+ .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} }
+ };
+ let rule = syn::parse2::<super::Rule>(input).unwrap();
+ assert_eq!(". ss", rule.modifier.unwrap().tokens().to_string());
+ assert_eq!(
+ "StateSpace",
+ rule.type_.unwrap().to_token_stream().to_string()
+ );
+ let alts = rule
+ .alternatives
+ .iter()
+ .map(|m| m.tokens().to_string())
+ .collect::<Vec<_>>();
+ assert_eq!(
+ vec![
+ ". global",
+ ". local",
+ ". param",
+ ". param :: func",
+ ". shared",
+ ". shared :: cta",
+ ". shared :: cluster"
+ ],
+ alts
+ );
+ }
+
+ #[test]
+ fn rule_multi2() {
+ let input = quote! {
+ .cop: StCacheOperator = { .wb, .cg, .cs, .wt }
+ };
+ let rule = syn::parse2::<super::Rule>(input).unwrap();
+ assert_eq!(". cop", rule.modifier.unwrap().tokens().to_string());
+ assert_eq!(
+ "StCacheOperator",
+ rule.type_.unwrap().to_token_stream().to_string()
+ );
+ let alts = rule
+ .alternatives
+ .iter()
+ .map(|m| m.tokens().to_string())
+ .collect::<Vec<_>>();
+ assert_eq!(vec![". wb", ". cg", ". cs", ". wt",], alts);
+ }
+
+ #[test]
+ fn args_unified() {
+ let input = quote! {
+ d, [a]{.unified}{, cache_policy}
+ };
+ let args = syn::parse2::<super::Arguments>(input).unwrap();
+ let a = &args.0[1];
+ assert!(!a.optional);
+ assert_eq!("a", a.ident.to_string());
+ assert!(a.pre_bracket);
+ assert!(!a.pre_pipe);
+ assert!(a.post_bracket);
+ assert!(!a.can_be_negated);
+ assert!(a.unified);
+ }
+
+ #[test]
+ fn special_block() {
+ let input = quote! {
+ bra <= { bra(stream) }
+ };
+ syn::parse2::<super::OpcodeDefinition>(input).unwrap();
+ }
+}
|