diff options
Diffstat (limited to 'ptx/src/pass')
21 files changed, 2423 insertions, 6758 deletions
diff --git a/ptx/src/pass/convert_dynamic_shared_memory_usage.rs b/ptx/src/pass/convert_dynamic_shared_memory_usage.rs deleted file mode 100644 index 1dac7fd..0000000 --- a/ptx/src/pass/convert_dynamic_shared_memory_usage.rs +++ /dev/null @@ -1,299 +0,0 @@ -use std::collections::{BTreeMap, BTreeSet};
-
-use super::*;
-
-/*
- PTX represents dynamically allocated shared local memory as
- .extern .shared .b32 shared_mem[];
- In SPIRV/OpenCL world this is expressed as an additional argument to the kernel
- And in AMD compilation
- This pass looks for all uses of .extern .shared and converts them to
- an additional method argument
- The question is how this artificial argument should be expressed. There are
- several options:
- * Straight conversion:
- .shared .b32 shared_mem[]
- * Introduce .param_shared statespace:
- .param_shared .b32 shared_mem
- or
- .param_shared .b32 shared_mem[]
- * Introduce .shared_ptr <SCALAR> type:
- .param .shared_ptr .b32 shared_mem
- * Reuse .ptr hint:
- .param .u64 .ptr shared_mem
- This is the most tempting, but also the most nonsensical, .ptr is just a
- hint, which has no semantical meaning (and the output of our
- transformation has a semantical meaning - we emit additional
- "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
-*/
-pub(super) fn run<'input>(
- module: Vec<Directive<'input>>,
- kernels_methods_call_map: &MethodsCallMap<'input>,
- new_id: &mut impl FnMut() -> SpirvWord,
-) -> Result<Vec<Directive<'input>>, TranslateError> {
- let mut globals_shared = HashMap::new();
- for dir in module.iter() {
- match dir {
- Directive::Variable(
- _,
- ast::Variable {
- state_space: ast::StateSpace::Shared,
- name,
- v_type,
- ..
- },
- ) => {
- globals_shared.insert(*name, v_type.clone());
- }
- _ => {}
- }
- }
- if globals_shared.len() == 0 {
- return Ok(module);
- }
- let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet<SpirvWord>>::new();
- let module = module
- .into_iter()
- .map(|directive| match directive {
- Directive::Method(Function {
- func_decl,
- globals,
- body: Some(statements),
- import_as,
- tuning,
- linkage,
- }) => {
- let call_key = (*func_decl).borrow().name;
- let statements = statements
- .into_iter()
- .map(|statement| {
- statement.visit_map(
- &mut |id, _: Option<(&ast::Type, ast::StateSpace)>, _, _| {
- if let Some(_) = globals_shared.get(&id) {
- methods_to_directly_used_shared_globals
- .entry(call_key)
- .or_insert_with(HashSet::new)
- .insert(id);
- }
- Ok::<_, TranslateError>(id)
- },
- )
- })
- .collect::<Result<Vec<_>, _>>()?;
- Ok::<_, TranslateError>(Directive::Method(Function {
- func_decl,
- globals,
- body: Some(statements),
- import_as,
- tuning,
- linkage,
- }))
- }
- directive => Ok(directive),
- })
- .collect::<Result<Vec<_>, _>>()?;
- // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
- // make sure it gets propagated to `fn1` and `kernel`
- let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared(
- methods_to_directly_used_shared_globals,
- kernels_methods_call_map,
- );
- // now visit every method declaration and inject those additional arguments
- let mut directives = Vec::with_capacity(module.len());
- for directive in module.into_iter() {
- match directive {
- Directive::Method(Function {
- func_decl,
- globals,
- body: Some(statements),
- import_as,
- tuning,
- linkage,
- }) => {
- let statements = {
- let func_decl_ref = &mut (*func_decl).borrow_mut();
- let method_name = func_decl_ref.name;
- insert_arguments_remap_statements(
- new_id,
- kernels_methods_call_map,
- &globals_shared,
- &methods_to_indirectly_used_shared_globals,
- method_name,
- &mut directives,
- func_decl_ref,
- statements,
- )?
- };
- directives.push(Directive::Method(Function {
- func_decl,
- globals,
- body: Some(statements),
- import_as,
- tuning,
- linkage,
- }));
- }
- directive => directives.push(directive),
- }
- }
- Ok(directives)
-}
-
-// We need to compute two kinds of information:
-// * If it's a kernel -> size of .shared globals in use (direct or indirect)
-// * If it's a function -> does it use .shared global (directly or indirectly)
-fn resolve_indirect_uses_of_globals_shared<'input>(
- methods_use_of_globals_shared: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
- kernels_methods_call_map: &MethodsCallMap<'input>,
-) -> HashMap<ast::MethodName<'input, SpirvWord>, BTreeSet<SpirvWord>> {
- let mut result = HashMap::new();
- for (method, callees) in kernels_methods_call_map.methods() {
- let mut indirect_globals = methods_use_of_globals_shared
- .get(&method)
- .into_iter()
- .flatten()
- .copied()
- .collect::<BTreeSet<_>>();
- for &callee in callees {
- indirect_globals.extend(
- methods_use_of_globals_shared
- .get(&ast::MethodName::Func(callee))
- .into_iter()
- .flatten()
- .copied(),
- );
- }
- result.insert(method, indirect_globals);
- }
- result
-}
-
-fn insert_arguments_remap_statements<'input>(
- new_id: &mut impl FnMut() -> SpirvWord,
- kernels_methods_call_map: &MethodsCallMap<'input>,
- globals_shared: &HashMap<SpirvWord, ast::Type>,
- methods_to_indirectly_used_shared_globals: &HashMap<
- ast::MethodName<'input, SpirvWord>,
- BTreeSet<SpirvWord>,
- >,
- method_name: ast::MethodName<SpirvWord>,
- result: &mut Vec<Directive>,
- func_decl_ref: &mut std::cell::RefMut<ast::MethodDeclaration<SpirvWord>>,
- statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
-) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
- let remapped_globals_in_method =
- if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) {
- match method_name {
- ast::MethodName::Func(..) => {
- let remapped_globals = method_globals
- .iter()
- .map(|global| {
- (
- *global,
- (
- new_id(),
- globals_shared
- .get(&global)
- .unwrap_or_else(|| todo!())
- .clone(),
- ),
- )
- })
- .collect::<BTreeMap<_, _>>();
- for (_, (new_shared_global_id, shared_global_type)) in remapped_globals.iter() {
- func_decl_ref.input_arguments.push(ast::Variable {
- align: None,
- v_type: shared_global_type.clone(),
- state_space: ast::StateSpace::Shared,
- name: *new_shared_global_id,
- array_init: Vec::new(),
- });
- }
- remapped_globals
- }
- ast::MethodName::Kernel(..) => method_globals
- .iter()
- .map(|global| {
- (
- *global,
- (
- *global,
- globals_shared
- .get(&global)
- .unwrap_or_else(|| todo!())
- .clone(),
- ),
- )
- })
- .collect::<BTreeMap<_, _>>(),
- }
- } else {
- return Ok(statements);
- };
- replace_uses_of_shared_memory(
- new_id,
- methods_to_indirectly_used_shared_globals,
- statements,
- remapped_globals_in_method,
- )
-}
-
-fn replace_uses_of_shared_memory<'input>(
- new_id: &mut impl FnMut() -> SpirvWord,
- methods_to_indirectly_used_shared_globals: &HashMap<
- ast::MethodName<'input, SpirvWord>,
- BTreeSet<SpirvWord>,
- >,
- statements: Vec<ExpandedStatement>,
- remapped_globals_in_method: BTreeMap<SpirvWord, (SpirvWord, ast::Type)>,
-) -> Result<Vec<ExpandedStatement>, TranslateError> {
- let mut result = Vec::with_capacity(statements.len());
- for statement in statements {
- match statement {
- Statement::Instruction(ast::Instruction::Call {
- mut data,
- mut arguments,
- }) => {
- // We can safely skip checking call arguments,
- // because there's simply no way to pass shared ptr
- // without converting it to .b64 first
- if let Some(shared_globals_used_by_callee) =
- methods_to_indirectly_used_shared_globals
- .get(&ast::MethodName::Func(arguments.func))
- {
- for &shared_global_used_by_callee in shared_globals_used_by_callee {
- let (remapped_shared_id, type_) = remapped_globals_in_method
- .get(&shared_global_used_by_callee)
- .unwrap_or_else(|| todo!());
- data.input_arguments
- .push((type_.clone(), ast::StateSpace::Shared));
- arguments.input_arguments.push(*remapped_shared_id);
- }
- }
- result.push(Statement::Instruction(ast::Instruction::Call {
- data,
- arguments,
- }))
- }
- statement => {
- let new_statement =
- statement.visit_map(&mut |id,
- _: Option<(&ast::Type, ast::StateSpace)>,
- _,
- _| {
- Ok::<_, TranslateError>(
- if let Some((remapped_shared_id, _)) =
- remapped_globals_in_method.get(&id)
- {
- *remapped_shared_id
- } else {
- id
- },
- )
- })?;
- result.push(new_statement);
- }
- }
- }
- Ok(result)
-}
diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs b/ptx/src/pass/convert_to_stateful_memory_access.rs deleted file mode 100644 index 3b8fa93..0000000 --- a/ptx/src/pass/convert_to_stateful_memory_access.rs +++ /dev/null @@ -1,524 +0,0 @@ -use super::*;
-use ptx_parser as ast;
-use std::{
- collections::{BTreeSet, HashSet},
- iter,
- rc::Rc,
-};
-
-/*
- Our goal here is to transform
- .visible .entry foobar(.param .u64 input) {
- .reg .b64 in_addr;
- .reg .b64 in_addr2;
- ld.param.u64 in_addr, [input];
- cvta.to.global.u64 in_addr2, in_addr;
- }
- into:
- .visible .entry foobar(.param .u8 input[]) {
- .reg .u8 in_addr[];
- .reg .u8 in_addr2[];
- ld.param.u8[] in_addr, [input];
- mov.u8[] in_addr2, in_addr;
- }
- or:
- .visible .entry foobar(.reg .u8 input[]) {
- .reg .u8 in_addr[];
- .reg .u8 in_addr2[];
- mov.u8[] in_addr, input;
- mov.u8[] in_addr2, in_addr;
- }
- or:
- .visible .entry foobar(.param ptr<u8, global> input) {
- .reg ptr<u8, global> in_addr;
- .reg ptr<u8, global> in_addr2;
- ld.param.ptr<u8, global> in_addr, [input];
- mov.ptr<u8, global> in_addr2, in_addr;
- }
-*/
-// TODO: detect more patterns (mov, call via reg, call via param)
-// TODO: don't convert to ptr if the register is not ultimately used for ld/st
-// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
-// argument expansion
-// TODO: propagate out of calls and into calls
-pub(super) fn run<'a, 'input>(
- func_args: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
- func_body: Vec<TypedStatement>,
- id_defs: &mut NumericIdResolver<'a>,
-) -> Result<
- (
- Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
- Vec<TypedStatement>,
- ),
- TranslateError,
-> {
- let mut method_decl = func_args.borrow_mut();
- if !matches!(method_decl.name, ast::MethodName::Kernel(..)) {
- drop(method_decl);
- return Ok((func_args, func_body));
- }
- if Rc::strong_count(&func_args) != 1 {
- return Err(error_unreachable());
- }
- let func_args_64bit = (*method_decl)
- .input_arguments
- .iter()
- .filter_map(|arg| match arg.v_type {
- ast::Type::Scalar(ast::ScalarType::U64)
- | ast::Type::Scalar(ast::ScalarType::B64)
- | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name),
- _ => None,
- })
- .collect::<HashSet<_>>();
- let mut stateful_markers = Vec::new();
- let mut stateful_init_reg = HashMap::<_, Vec<_>>::new();
- for statement in func_body.iter() {
- match statement {
- Statement::Instruction(ast::Instruction::Cvta {
- data:
- ast::CvtaDetails {
- state_space: ast::StateSpace::Global,
- direction: ast::CvtaDirection::GenericToExplicit,
- },
- arguments,
- }) => {
- if let (TypedOperand::Reg(dst), Some(src)) =
- (arguments.dst, arguments.src.underlying_register())
- {
- if is_64_bit_integer(id_defs, src) && is_64_bit_integer(id_defs, dst) {
- stateful_markers.push((dst, src));
- }
- }
- }
- Statement::Instruction(ast::Instruction::Ld {
- data:
- ast::LdDetails {
- state_space: ast::StateSpace::Param,
- typ: ast::Type::Scalar(ast::ScalarType::U64),
- ..
- },
- arguments,
- })
- | Statement::Instruction(ast::Instruction::Ld {
- data:
- ast::LdDetails {
- state_space: ast::StateSpace::Param,
- typ: ast::Type::Scalar(ast::ScalarType::S64),
- ..
- },
- arguments,
- })
- | Statement::Instruction(ast::Instruction::Ld {
- data:
- ast::LdDetails {
- state_space: ast::StateSpace::Param,
- typ: ast::Type::Scalar(ast::ScalarType::B64),
- ..
- },
- arguments,
- }) => {
- if let (TypedOperand::Reg(dst), Some(src)) =
- (arguments.dst, arguments.src.underlying_register())
- {
- if func_args_64bit.contains(&src) {
- multi_hash_map_append(&mut stateful_init_reg, dst, src);
- }
- }
- }
- _ => {}
- }
- }
- if stateful_markers.len() == 0 {
- drop(method_decl);
- return Ok((func_args, func_body));
- }
- let mut func_args_ptr = HashSet::new();
- let mut regs_ptr_current = HashSet::new();
- for (dst, src) in stateful_markers {
- if let Some(func_args) = stateful_init_reg.get(&src) {
- for a in func_args {
- func_args_ptr.insert(*a);
- regs_ptr_current.insert(src);
- regs_ptr_current.insert(dst);
- }
- }
- }
- // BTreeSet here to have a stable order of iteration,
- // unfortunately our tests rely on it
- let mut regs_ptr_seen = BTreeSet::new();
- while regs_ptr_current.len() > 0 {
- let mut regs_ptr_new = HashSet::new();
- for statement in func_body.iter() {
- match statement {
- Statement::Instruction(ast::Instruction::Add {
- data:
- ast::ArithDetails::Integer(ast::ArithInteger {
- type_: ast::ScalarType::U64,
- saturate: false,
- }),
- arguments,
- })
- | Statement::Instruction(ast::Instruction::Add {
- data:
- ast::ArithDetails::Integer(ast::ArithInteger {
- type_: ast::ScalarType::S64,
- saturate: false,
- }),
- arguments,
- }) => {
- // TODO: don't mark result of double pointer sub or double
- // pointer add as ptr result
- if let (TypedOperand::Reg(dst), Some(src1)) =
- (arguments.dst, arguments.src1.underlying_register())
- {
- if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
- regs_ptr_new.insert(dst);
- }
- } else if let (TypedOperand::Reg(dst), Some(src2)) =
- (arguments.dst, arguments.src2.underlying_register())
- {
- if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
- regs_ptr_new.insert(dst);
- }
- }
- }
-
- Statement::Instruction(ast::Instruction::Sub {
- data:
- ast::ArithDetails::Integer(ast::ArithInteger {
- type_: ast::ScalarType::U64,
- saturate: false,
- }),
- arguments,
- })
- | Statement::Instruction(ast::Instruction::Sub {
- data:
- ast::ArithDetails::Integer(ast::ArithInteger {
- type_: ast::ScalarType::S64,
- saturate: false,
- }),
- arguments,
- }) => {
- // TODO: don't mark result of double pointer sub or double
- // pointer add as ptr result
- if let (TypedOperand::Reg(dst), Some(src1)) =
- (arguments.dst, arguments.src1.underlying_register())
- {
- if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) {
- regs_ptr_new.insert(dst);
- }
- } else if let (TypedOperand::Reg(dst), Some(src2)) =
- (arguments.dst, arguments.src2.underlying_register())
- {
- if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) {
- regs_ptr_new.insert(dst);
- }
- }
- }
- _ => {}
- }
- }
- for id in regs_ptr_current {
- regs_ptr_seen.insert(id);
- }
- regs_ptr_current = regs_ptr_new;
- }
- drop(regs_ptr_current);
- let mut remapped_ids = HashMap::new();
- let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
- for reg in regs_ptr_seen {
- let new_id = id_defs.register_variable(
- ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
- ast::StateSpace::Reg,
- );
- result.push(Statement::Variable(ast::Variable {
- align: None,
- name: new_id,
- array_init: Vec::new(),
- v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
- state_space: ast::StateSpace::Reg,
- }));
- remapped_ids.insert(reg, new_id);
- }
- for arg in (*method_decl).input_arguments.iter_mut() {
- if !func_args_ptr.contains(&arg.name) {
- continue;
- }
- let new_id = id_defs.register_variable(
- ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
- ast::StateSpace::Param,
- );
- let old_name = arg.name;
- arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
- arg.name = new_id;
- remapped_ids.insert(old_name, new_id);
- }
- for statement in func_body {
- match statement {
- l @ Statement::Label(_) => result.push(l),
- c @ Statement::Conditional(_) => result.push(c),
- c @ Statement::Constant(..) => result.push(c),
- Statement::Variable(var) => {
- if !remapped_ids.contains_key(&var.name) {
- result.push(Statement::Variable(var));
- }
- }
- Statement::Instruction(ast::Instruction::Add {
- data:
- ast::ArithDetails::Integer(ast::ArithInteger {
- type_: ast::ScalarType::U64,
- saturate: false,
- }),
- arguments,
- })
- | Statement::Instruction(ast::Instruction::Add {
- data:
- ast::ArithDetails::Integer(ast::ArithInteger {
- type_: ast::ScalarType::S64,
- saturate: false,
- }),
- arguments,
- }) if is_add_ptr_direct(&remapped_ids, &arguments) => {
- let (ptr, offset) = match arguments.src1.underlying_register() {
- Some(src1) if remapped_ids.contains_key(&src1) => {
- (remapped_ids.get(&src1).unwrap(), arguments.src2)
- }
- Some(src2) if remapped_ids.contains_key(&src2) => {
- (remapped_ids.get(&src2).unwrap(), arguments.src1)
- }
- _ => return Err(error_unreachable()),
- };
- let dst = arguments.dst.unwrap_reg()?;
- result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
- state_space: ast::StateSpace::Global,
- dst: *remapped_ids.get(&dst).unwrap(),
- ptr_src: *ptr,
- offset_src: offset,
- }))
- }
- Statement::Instruction(ast::Instruction::Sub {
- data:
- ast::ArithDetails::Integer(ast::ArithInteger {
- type_: ast::ScalarType::U64,
- saturate: false,
- }),
- arguments,
- })
- | Statement::Instruction(ast::Instruction::Sub {
- data:
- ast::ArithDetails::Integer(ast::ArithInteger {
- type_: ast::ScalarType::S64,
- saturate: false,
- }),
- arguments,
- }) if is_sub_ptr_direct(&remapped_ids, &arguments) => {
- let (ptr, offset) = match arguments.src1.underlying_register() {
- Some(ref src1) => (remapped_ids.get(src1).unwrap(), arguments.src2),
- _ => return Err(error_unreachable()),
- };
- let offset_neg = id_defs.register_intermediate(Some((
- ast::Type::Scalar(ast::ScalarType::S64),
- ast::StateSpace::Reg,
- )));
- result.push(Statement::Instruction(ast::Instruction::Neg {
- data: ast::TypeFtz {
- type_: ast::ScalarType::S64,
- flush_to_zero: None,
- },
- arguments: ast::NegArgs {
- src: offset,
- dst: TypedOperand::Reg(offset_neg),
- },
- }));
- let dst = arguments.dst.unwrap_reg()?;
- result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
- state_space: ast::StateSpace::Global,
- dst: *remapped_ids.get(&dst).unwrap(),
- ptr_src: *ptr,
- offset_src: TypedOperand::Reg(offset_neg),
- }))
- }
- inst @ Statement::Instruction(_) => {
- let mut post_statements = Vec::new();
- let new_statement = inst.visit_map(&mut FnVisitor::new(
- |operand, type_space, is_dst, relaxed_conversion| {
- convert_to_stateful_memory_access_postprocess(
- id_defs,
- &remapped_ids,
- &mut result,
- &mut post_statements,
- operand,
- type_space,
- is_dst,
- relaxed_conversion,
- )
- },
- ))?;
- result.push(new_statement);
- result.extend(post_statements);
- }
- repack @ Statement::RepackVector(_) => {
- let mut post_statements = Vec::new();
- let new_statement = repack.visit_map(&mut FnVisitor::new(
- |operand, type_space, is_dst, relaxed_conversion| {
- convert_to_stateful_memory_access_postprocess(
- id_defs,
- &remapped_ids,
- &mut result,
- &mut post_statements,
- operand,
- type_space,
- is_dst,
- relaxed_conversion,
- )
- },
- ))?;
- result.push(new_statement);
- result.extend(post_statements);
- }
- _ => return Err(error_unreachable()),
- }
- }
- drop(method_decl);
- Ok((func_args, result))
-}
-
-fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool {
- match id_defs.get_typed(id) {
- Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _))
- | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _))
- | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true,
- _ => false,
- }
-}
-
-fn is_add_ptr_direct(
- remapped_ids: &HashMap<SpirvWord, SpirvWord>,
- arg: &ast::AddArgs<TypedOperand>,
-) -> bool {
- match arg.dst {
- TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
- return false
- }
- TypedOperand::Reg(dst) => {
- if !remapped_ids.contains_key(&dst) {
- return false;
- }
- if let Some(ref src1_reg) = arg.src1.underlying_register() {
- if remapped_ids.contains_key(src1_reg) {
- // don't trigger optimization when adding two pointers
- if let Some(ref src2_reg) = arg.src2.underlying_register() {
- return !remapped_ids.contains_key(src2_reg);
- }
- }
- }
- if let Some(ref src2_reg) = arg.src2.underlying_register() {
- remapped_ids.contains_key(src2_reg)
- } else {
- false
- }
- }
- }
-}
-
-fn is_sub_ptr_direct(
- remapped_ids: &HashMap<SpirvWord, SpirvWord>,
- arg: &ast::SubArgs<TypedOperand>,
-) -> bool {
- match arg.dst {
- TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
- return false
- }
- TypedOperand::Reg(dst) => {
- if !remapped_ids.contains_key(&dst) {
- return false;
- }
- match arg.src1.underlying_register() {
- Some(ref src1_reg) => {
- if remapped_ids.contains_key(src1_reg) {
- // don't trigger optimization when subtracting two pointers
- arg.src2
- .underlying_register()
- .map_or(true, |ref src2_reg| !remapped_ids.contains_key(src2_reg))
- } else {
- false
- }
- }
- None => false,
- }
- }
- }
-}
-
-fn convert_to_stateful_memory_access_postprocess(
- id_defs: &mut NumericIdResolver,
- remapped_ids: &HashMap<SpirvWord, SpirvWord>,
- result: &mut Vec<TypedStatement>,
- post_statements: &mut Vec<TypedStatement>,
- operand: TypedOperand,
- type_space: Option<(&ast::Type, ast::StateSpace)>,
- is_dst: bool,
- relaxed_conversion: bool,
-) -> Result<TypedOperand, TranslateError> {
- operand.map(|operand, _| {
- Ok(match remapped_ids.get(&operand) {
- Some(new_id) => {
- let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?;
- // TODO: readd if required
- if let Some((expected_type, expected_space)) = type_space {
- let implicit_conversion = if relaxed_conversion {
- if is_dst {
- super::insert_implicit_conversions::should_convert_relaxed_dst_wrapper
- } else {
- super::insert_implicit_conversions::should_convert_relaxed_src_wrapper
- }
- } else {
- super::insert_implicit_conversions::default_implicit_conversion
- };
- if implicit_conversion(
- (new_operand_space, &new_operand_type),
- (expected_space, expected_type),
- )
- .is_ok()
- {
- return Ok(*new_id);
- }
- }
- let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?;
- let converting_id = id_defs
- .register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
- let kind = if new_operand_space == ast::StateSpace::Reg {
- ConversionKind::Default
- } else {
- ConversionKind::PtrToPtr
- };
- if is_dst {
- post_statements.push(Statement::Conversion(ImplicitConversion {
- src: converting_id,
- dst: *new_id,
- from_type: old_operand_type,
- from_space: old_operand_space,
- to_type: new_operand_type,
- to_space: new_operand_space,
- kind,
- }));
- converting_id
- } else {
- result.push(Statement::Conversion(ImplicitConversion {
- src: *new_id,
- dst: converting_id,
- from_type: new_operand_type,
- from_space: new_operand_space,
- to_type: old_operand_type,
- to_space: old_operand_space,
- kind,
- }));
- converting_id
- }
- }
- None => operand,
- })
- })
-}
diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs deleted file mode 100644 index 550c662..0000000 --- a/ptx/src/pass/convert_to_typed.rs +++ /dev/null @@ -1,138 +0,0 @@ -use super::*;
-use ptx_parser as ast;
-
-pub(crate) fn run(
- func: Vec<UnconditionalStatement>,
- fn_defs: &GlobalFnDeclResolver,
- id_defs: &mut NumericIdResolver,
-) -> Result<Vec<TypedStatement>, TranslateError> {
- let mut result = Vec::<TypedStatement>::with_capacity(func.len());
- for s in func {
- match s {
- Statement::Instruction(inst) => match inst {
- ast::Instruction::Mov {
- data,
- arguments:
- ast::MovArgs {
- dst: ast::ParsedOperand::Reg(dst_reg),
- src: ast::ParsedOperand::Reg(src_reg),
- },
- } if fn_defs.fns.contains_key(&src_reg) => {
- if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
- return Err(error_mismatched_type());
- }
- result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
- dst: dst_reg,
- src: src_reg,
- }));
- }
- ast::Instruction::Call { data, arguments } => {
- let resolver = fn_defs.get_fn_sig_resolver(arguments.func)?;
- let resolved_call = resolver.resolve_in_spirv_repr(data, arguments)?;
- let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
- let reresolved_call =
- Statement::Instruction(ast::visit_map(resolved_call, &mut visitor)?);
- visitor.func.push(reresolved_call);
- visitor.func.extend(visitor.post_stmts);
- }
- inst => {
- let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
- let instruction = Statement::Instruction(ast::visit_map(inst, &mut visitor)?);
- visitor.func.push(instruction);
- visitor.func.extend(visitor.post_stmts);
- }
- },
- Statement::Label(i) => result.push(Statement::Label(i)),
- Statement::Variable(v) => result.push(Statement::Variable(v)),
- Statement::Conditional(c) => result.push(Statement::Conditional(c)),
- _ => return Err(error_unreachable()),
- }
- }
- Ok(result)
-}
-
-struct VectorRepackVisitor<'a, 'b> {
- func: &'b mut Vec<TypedStatement>,
- id_def: &'b mut NumericIdResolver<'a>,
- post_stmts: Option<TypedStatement>,
-}
-
-impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
- fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
- VectorRepackVisitor {
- func,
- id_def,
- post_stmts: None,
- }
- }
-
- fn convert_vector(
- &mut self,
- is_dst: bool,
- relaxed_type_check: bool,
- typ: &ast::Type,
- state_space: ast::StateSpace,
- idx: Vec<SpirvWord>,
- ) -> Result<SpirvWord, TranslateError> {
- // mov.u32 foobar, {a,b};
- let scalar_t = match typ {
- ast::Type::Vector(_, scalar_t) => *scalar_t,
- _ => return Err(error_mismatched_type()),
- };
- let temp_vec = self
- .id_def
- .register_intermediate(Some((typ.clone(), state_space)));
- let statement = Statement::RepackVector(RepackVectorDetails {
- is_extract: is_dst,
- typ: scalar_t,
- packed: temp_vec,
- unpacked: idx,
- relaxed_type_check,
- });
- if is_dst {
- self.post_stmts = Some(statement);
- } else {
- self.func.push(statement);
- }
- Ok(temp_vec)
- }
-}
-
-impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, TypedOperand, TranslateError>
- for VectorRepackVisitor<'a, 'b>
-{
- fn visit_ident(
- &mut self,
- ident: SpirvWord,
- _: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
- _: bool,
- _: bool,
- ) -> Result<SpirvWord, TranslateError> {
- Ok(ident)
- }
-
- fn visit(
- &mut self,
- op: ast::ParsedOperand<SpirvWord>,
- type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
- is_dst: bool,
- relaxed_type_check: bool,
- ) -> Result<TypedOperand, TranslateError> {
- Ok(match op {
- ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg),
- ast::ParsedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
- ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x),
- ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
- ast::ParsedOperand::VecPack(vec) => {
- let (type_, space) = type_space.ok_or_else(|| error_mismatched_type())?;
- TypedOperand::Reg(self.convert_vector(
- is_dst,
- relaxed_type_check,
- type_,
- space,
- vec,
- )?)
- }
- })
- }
-}
diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs index 04c8831..15125b0 100644 --- a/ptx/src/pass/deparamize_functions.rs +++ b/ptx/src/pass/deparamize_functions.rs @@ -1,5 +1,3 @@ -use std::collections::BTreeMap;
-
use super::*;
pub(super) fn run<'a, 'input>(
@@ -26,75 +24,73 @@ fn run_method<'input>( resolver: &mut GlobalStringIdentResolver2,
mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
- if method.func_decl.name.is_kernel() {
- return Ok(method);
- }
let is_declaration = method.body.is_none();
let mut body = Vec::new();
let mut remap_returns = Vec::new();
- for arg in method.func_decl.return_arguments.iter_mut() {
- match arg.state_space {
- ptx_parser::StateSpace::Param => {
- arg.state_space = ptx_parser::StateSpace::Reg;
- let old_name = arg.name;
- arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
- if is_declaration {
- continue;
+ if !method.func_decl.name.is_kernel() {
+ for arg in method.func_decl.return_arguments.iter_mut() {
+ match arg.state_space {
+ ptx_parser::StateSpace::Param => {
+ arg.state_space = ptx_parser::StateSpace::Reg;
+ let old_name = arg.name;
+ arg.name =
+ resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
+ if is_declaration {
+ continue;
+ }
+ remap_returns.push((old_name, arg.name, arg.v_type.clone()));
+ body.push(Statement::Variable(ast::Variable {
+ align: None,
+ name: old_name,
+ v_type: arg.v_type.clone(),
+ state_space: ptx_parser::StateSpace::Param,
+ array_init: Vec::new(),
+ }));
}
- remap_returns.push((old_name, arg.name, arg.v_type.clone()));
- body.push(Statement::Variable(ast::Variable {
- align: None,
- name: old_name,
- v_type: arg.v_type.clone(),
- state_space: ptx_parser::StateSpace::Param,
- array_init: Vec::new(),
- }));
+ ptx_parser::StateSpace::Reg => {}
+ _ => return Err(error_unreachable()),
}
- ptx_parser::StateSpace::Reg => {}
- _ => return Err(error_unreachable()),
}
- }
- for arg in method.func_decl.input_arguments.iter_mut() {
- match arg.state_space {
- ptx_parser::StateSpace::Param => {
- arg.state_space = ptx_parser::StateSpace::Reg;
- let old_name = arg.name;
- arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
- if is_declaration {
- continue;
+ for arg in method.func_decl.input_arguments.iter_mut() {
+ match arg.state_space {
+ ptx_parser::StateSpace::Param => {
+ arg.state_space = ptx_parser::StateSpace::Reg;
+ let old_name = arg.name;
+ arg.name =
+ resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
+ if is_declaration {
+ continue;
+ }
+ body.push(Statement::Variable(ast::Variable {
+ align: None,
+ name: old_name,
+ v_type: arg.v_type.clone(),
+ state_space: ptx_parser::StateSpace::Param,
+ array_init: Vec::new(),
+ }));
+ body.push(Statement::Instruction(ast::Instruction::St {
+ data: ast::StData {
+ qualifier: ast::LdStQualifier::Weak,
+ state_space: ast::StateSpace::Param,
+ caching: ast::StCacheOperator::Writethrough,
+ typ: arg.v_type.clone(),
+ },
+ arguments: ast::StArgs {
+ src1: old_name,
+ src2: arg.name,
+ },
+ }));
}
- body.push(Statement::Variable(ast::Variable {
- align: None,
- name: old_name,
- v_type: arg.v_type.clone(),
- state_space: ptx_parser::StateSpace::Param,
- array_init: Vec::new(),
- }));
- body.push(Statement::Instruction(ast::Instruction::St {
- data: ast::StData {
- qualifier: ast::LdStQualifier::Weak,
- state_space: ast::StateSpace::Param,
- caching: ast::StCacheOperator::Writethrough,
- typ: arg.v_type.clone(),
- },
- arguments: ast::StArgs {
- src1: old_name,
- src2: arg.name,
- },
- }));
+ ptx_parser::StateSpace::Reg => {}
+ _ => return Err(error_unreachable()),
}
- ptx_parser::StateSpace::Reg => {}
- _ => return Err(error_unreachable()),
}
}
- if remap_returns.is_empty() {
- return Ok(method);
- }
let body = method
.body
.map(|statements| {
for statement in statements {
- run_statement(&remap_returns, &mut body, statement)?;
+ run_statement(resolver, &remap_returns, &mut body, statement)?;
}
Ok::<_, TranslateError>(body)
})
@@ -110,28 +106,89 @@ fn run_method<'input>( }
fn run_statement<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<(), TranslateError> {
match statement {
- Statement::Instruction(ast::Instruction::Ret { .. }) => {
- for (old_name, new_name, type_) in remap_returns.iter().cloned() {
+ Statement::Instruction(ast::Instruction::Call {
+ mut data,
+ mut arguments,
+ }) => {
+ let mut post_st = Vec::new();
+ for ((type_, space), ident) in data
+ .input_arguments
+ .iter_mut()
+ .zip(arguments.input_arguments.iter_mut())
+ {
+ if *space == ptx_parser::StateSpace::Param {
+ *space = ptx_parser::StateSpace::Reg;
+ let old_name = *ident;
+ *ident = resolver
+ .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
+ result.push(Statement::Instruction(ast::Instruction::Ld {
+ data: ast::LdDetails {
+ qualifier: ast::LdStQualifier::Weak,
+ state_space: ast::StateSpace::Param,
+ caching: ast::LdCacheOperator::Cached,
+ typ: type_.clone(),
+ non_coherent: false,
+ },
+ arguments: ast::LdArgs {
+ dst: *ident,
+ src: old_name,
+ },
+ }));
+ }
+ }
+ for ((type_, space), ident) in data
+ .return_arguments
+ .iter_mut()
+ .zip(arguments.return_arguments.iter_mut())
+ {
+ if *space == ptx_parser::StateSpace::Param {
+ *space = ptx_parser::StateSpace::Reg;
+ let old_name = *ident;
+ *ident = resolver
+ .register_unnamed(Some((type_.clone(), ptx_parser::StateSpace::Reg)));
+ post_st.push(Statement::Instruction(ast::Instruction::St {
+ data: ast::StData {
+ qualifier: ast::LdStQualifier::Weak,
+ state_space: ast::StateSpace::Param,
+ caching: ast::StCacheOperator::Writethrough,
+ typ: type_.clone(),
+ },
+ arguments: ast::StArgs {
+ src1: old_name,
+ src2: *ident,
+ },
+ }));
+ }
+ }
+ result.push(Statement::Instruction(ast::Instruction::Call {
+ data,
+ arguments,
+ }));
+ result.extend(post_st.into_iter());
+ }
+ Statement::Instruction(ast::Instruction::Ret { data }) => {
+ for (old_name, new_name, type_) in remap_returns.iter() {
result.push(Statement::Instruction(ast::Instruction::Ld {
data: ast::LdDetails {
qualifier: ast::LdStQualifier::Weak,
- state_space: ast::StateSpace::Reg,
+ state_space: ast::StateSpace::Param,
caching: ast::LdCacheOperator::Cached,
- typ: type_,
+ typ: type_.clone(),
non_coherent: false,
},
arguments: ast::LdArgs {
- dst: new_name,
- src: old_name,
+ dst: *new_name,
+ src: *old_name,
},
}));
}
- result.push(statement);
+ result.push(Statement::Instruction(ast::Instruction::Ret { data }));
}
statement => {
result.push(statement);
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 235ad7d..fa011a3 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -18,16 +18,23 @@ // while with plain LLVM-C it's just:
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
-use std::convert::{TryFrom, TryInto};
-use std::ffi::CStr;
+// AMDGPU LLVM backend support for llvm.experimental.constrained.* is incomplete.
+// Emitting @llvm.experimental.constrained.fdiv.f32(...) makes LLVm fail with
+// "LLVM ERROR: unsupported libcall legalization". Running with "-mllvm -print-before-all"
+// shows it fails inside amdgpu-isel. You can get a little bit furthr with "-mllvm -global-isel",
+// but it will too fail similarly, but with "unable to legalize instruction"
+
+use std::array::TryFromSliceError;
+use std::convert::TryInto;
+use std::ffi::{CStr, NulError};
use std::ops::Deref;
-use std::ptr;
+use std::{i8, ptr};
use super::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
-use llvm_zluda::core::*;
-use llvm_zluda::prelude::*;
+use llvm_zluda::{core::*, *};
+use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
const LLVM_UNNAMED: &CStr = c"";
@@ -172,7 +179,7 @@ pub(super) fn run<'input>( let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
for directive in directives {
match directive {
- Directive2::Variable(..) => todo!(),
+ Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?,
Directive2::Method(method) => emit_ctx.emit_method(method)?,
}
}
@@ -228,15 +235,18 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { })
.ok_or_else(|| error_unreachable())?;
let name = CString::new(name).map_err(|_| error_unreachable())?;
- let fn_type = get_function_type(
- self.context,
- func_decl.return_arguments.iter().map(|v| &v.v_type),
- func_decl
- .input_arguments
- .iter()
- .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
- )?;
- let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
+ let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
+ if fn_ == ptr::null_mut() {
+ let fn_type = get_function_type(
+ self.context,
+ func_decl.return_arguments.iter().map(|v| &v.v_type),
+ func_decl
+ .input_arguments
+ .iter()
+ .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
+ )?;
+ fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
+ }
if let ast::MethodName::Func(name) = func_decl.name {
self.resolver.register(name, fn_);
}
@@ -274,6 +284,14 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) };
let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder);
+ for var in func_decl.return_arguments {
+ method_emitter.emit_variable(var)?;
+ }
+ for statement in statements.iter() {
+ if let Statement::Label(label) = statement {
+ method_emitter.emit_label_initial(*label);
+ }
+ }
for statement in statements {
method_emitter.emit_statement(statement)?;
}
@@ -281,43 +299,146 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { }
Ok(())
}
+
+ fn emit_global(
+ &mut self,
+ _linking: ast::LinkingDirective,
+ var: ast::Variable<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let name = self
+ .id_defs
+ .ident_map
+ .get(&var.name)
+ .map(|entry| {
+ entry
+ .name
+ .as_ref()
+ .map(|text| Ok::<_, NulError>(Cow::Owned(CString::new(&**text)?)))
+ })
+ .flatten()
+ .transpose()
+ .map_err(|_| error_unreachable())?
+ .unwrap_or(Cow::Borrowed(LLVM_UNNAMED));
+ let global = unsafe {
+ LLVMAddGlobalInAddressSpace(
+ self.module,
+ get_type(self.context, &var.v_type)?,
+ name.as_ptr(),
+ get_state_space(var.state_space)?,
+ )
+ };
+ self.resolver.register(var.name, global);
+ if let Some(align) = var.align {
+ unsafe { LLVMSetAlignment(global, align) };
+ }
+ if !var.array_init.is_empty() {
+ self.emit_array_init(&var.v_type, &*var.array_init, global)?;
+ }
+ Ok(())
+ }
+
+ // TODO: instead of Vec<u8> we should emit a typed initializer
+ fn emit_array_init(
+ &mut self,
+ type_: &ast::Type,
+ array_init: &[u8],
+ global: *mut llvm_zluda::LLVMValue,
+ ) -> Result<(), TranslateError> {
+ match type_ {
+ ast::Type::Array(None, scalar, dimensions) => {
+ if dimensions.len() != 1 {
+ todo!()
+ }
+ if dimensions[0] as usize * scalar.size_of() as usize != array_init.len() {
+ return Err(error_unreachable());
+ }
+ let type_ = get_scalar_type(self.context, *scalar);
+ let mut elements = array_init
+ .chunks(scalar.size_of() as usize)
+ .map(|chunk| self.constant_from_bytes(*scalar, chunk, type_))
+ .collect::<Result<Vec<_>, _>>()
+ .map_err(|_| error_unreachable())?;
+ let initializer =
+ unsafe { LLVMConstArray2(type_, elements.as_mut_ptr(), elements.len() as u64) };
+ unsafe { LLVMSetInitializer(global, initializer) };
+ }
+ _ => todo!(),
+ }
+ Ok(())
+ }
+
+ fn constant_from_bytes(
+ &self,
+ scalar: ast::ScalarType,
+ bytes: &[u8],
+ llvm_type: LLVMTypeRef,
+ ) -> Result<LLVMValueRef, TryFromSliceError> {
+ Ok(match scalar {
+ ptx_parser::ScalarType::Pred
+ | ptx_parser::ScalarType::S8
+ | ptx_parser::ScalarType::B8
+ | ptx_parser::ScalarType::U8 => unsafe {
+ LLVMConstInt(llvm_type, u8::from_le_bytes(bytes.try_into()?) as u64, 0)
+ },
+ ptx_parser::ScalarType::S16
+ | ptx_parser::ScalarType::B16
+ | ptx_parser::ScalarType::U16 => unsafe {
+ LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0)
+ },
+ ptx_parser::ScalarType::S32
+ | ptx_parser::ScalarType::B32
+ | ptx_parser::ScalarType::U32 => unsafe {
+ LLVMConstInt(llvm_type, u32::from_le_bytes(bytes.try_into()?) as u64, 0)
+ },
+ ptx_parser::ScalarType::F16 => todo!(),
+ ptx_parser::ScalarType::BF16 => todo!(),
+ ptx_parser::ScalarType::U64 => todo!(),
+ ptx_parser::ScalarType::S64 => todo!(),
+ ptx_parser::ScalarType::S16x2 => todo!(),
+ ptx_parser::ScalarType::F32 => todo!(),
+ ptx_parser::ScalarType::B64 => todo!(),
+ ptx_parser::ScalarType::F64 => todo!(),
+ ptx_parser::ScalarType::B128 => todo!(),
+ ptx_parser::ScalarType::U16x2 => todo!(),
+ ptx_parser::ScalarType::F16x2 => todo!(),
+ ptx_parser::ScalarType::BF16x2 => todo!(),
+ })
+ }
}
fn get_input_argument_type(
context: LLVMContextRef,
- v_type: &ptx_parser::Type,
- state_space: ptx_parser::StateSpace,
+ v_type: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<LLVMTypeRef, TranslateError> {
match state_space {
- ptx_parser::StateSpace::ParamEntry => {
+ ast::StateSpace::ParamEntry => {
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) })
}
- ptx_parser::StateSpace::Reg => get_type(context, v_type),
+ ast::StateSpace::Reg => get_type(context, v_type),
_ => return Err(error_unreachable()),
}
}
-struct MethodEmitContext<'a, 'input> {
+struct MethodEmitContext<'a> {
context: LLVMContextRef,
module: LLVMModuleRef,
method: LLVMValueRef,
builder: LLVMBuilderRef,
- id_defs: &'a GlobalStringIdentResolver2<'input>,
variables_builder: Builder,
resolver: &'a mut ResolveIdent,
}
-impl<'a, 'input> MethodEmitContext<'a, 'input> {
- fn new<'x>(
- parent: &'a mut ModuleEmitContext<'x, 'input>,
+impl<'a> MethodEmitContext<'a> {
+ fn new(
+ parent: &'a mut ModuleEmitContext,
method: LLVMValueRef,
variables_builder: Builder,
- ) -> MethodEmitContext<'a, 'input> {
+ ) -> MethodEmitContext<'a> {
MethodEmitContext {
context: parent.context,
module: parent.module,
builder: parent.builder.get(),
- id_defs: parent.id_defs,
variables_builder,
resolver: &mut parent.resolver,
method,
@@ -330,18 +451,17 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ) -> Result<(), TranslateError> {
Ok(match statement {
Statement::Variable(var) => self.emit_variable(var)?,
- Statement::Label(label) => self.emit_label(label),
+ Statement::Label(label) => self.emit_label_delayed(label)?,
Statement::Instruction(inst) => self.emit_instruction(inst)?,
- Statement::Conditional(_) => todo!(),
- Statement::LoadVar(var) => self.emit_load_variable(var)?,
- Statement::StoreVar(store) => self.emit_store_var(store)?,
+ Statement::Conditional(cond) => self.emit_conditional(cond)?,
Statement::Conversion(conversion) => self.emit_conversion(conversion)?,
Statement::Constant(constant) => self.emit_constant(constant)?,
- Statement::RetValue(_, _) => todo!(),
- Statement::PtrAccess(_) => todo!(),
- Statement::RepackVector(_) => todo!(),
+ Statement::RetValue(_, values) => self.emit_ret_value(values)?,
+ Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?,
+ Statement::RepackVector(repack) => self.emit_vector_repack(repack)?,
Statement::FunctionPointer(_) => todo!(),
- Statement::VectorAccess(_) => todo!(),
+ Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
+ Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
})
}
@@ -364,7 +484,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Ok(())
}
- fn emit_label(&mut self, label: SpirvWord) {
+ fn emit_label_initial(&mut self, label: SpirvWord) {
let block = unsafe {
LLVMAppendBasicBlockInContext(
self.context,
@@ -372,17 +492,18 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { self.resolver.get_or_add_raw(label),
)
};
+ self.resolver
+ .register(label, unsafe { LLVMBasicBlockAsValue(block) });
+ }
+
+ fn emit_label_delayed(&mut self, label: SpirvWord) -> Result<(), TranslateError> {
+ let block = self.resolver.value(label)?;
+ let block = unsafe { LLVMValueAsBasicBlock(block) };
let last_block = unsafe { LLVMGetInsertBlock(self.builder) };
if unsafe { LLVMGetBasicBlockTerminator(last_block) } == ptr::null_mut() {
unsafe { LLVMBuildBr(self.builder, block) };
}
unsafe { LLVMPositionBuilderAtEnd(self.builder, block) };
- }
-
- fn emit_store_var(&mut self, store: StoreVarDetails) -> Result<(), TranslateError> {
- let ptr = self.resolver.value(store.arg.src1)?;
- let value = self.resolver.value(store.arg.src2)?;
- unsafe { LLVMBuildStore(self.builder, value, ptr) };
Ok(())
}
@@ -395,50 +516,51 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
- ast::Instruction::Mul { data, arguments } => todo!(),
- ast::Instruction::Setp { data, arguments } => todo!(),
- ast::Instruction::SetpBool { data, arguments } => todo!(),
- ast::Instruction::Not { data, arguments } => todo!(),
- ast::Instruction::Or { data, arguments } => todo!(),
- ast::Instruction::And { data, arguments } => todo!(),
- ast::Instruction::Bra { arguments } => todo!(),
+ ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments),
+ ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments),
+ ast::Instruction::SetpBool { .. } => todo!(),
+ ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments),
+ ast::Instruction::Or { data, arguments } => self.emit_or(data, arguments),
+ ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
+ ast::Instruction::Bra { arguments } => self.emit_bra(arguments),
ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
- ast::Instruction::Cvt { data, arguments } => todo!(),
- ast::Instruction::Shr { data, arguments } => todo!(),
- ast::Instruction::Shl { data, arguments } => todo!(),
+ ast::Instruction::Cvt { data, arguments } => self.emit_cvt(data, arguments),
+ ast::Instruction::Shr { data, arguments } => self.emit_shr(data, arguments),
+ ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments),
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
- ast::Instruction::Cvta { data, arguments } => todo!(),
- ast::Instruction::Abs { data, arguments } => todo!(),
- ast::Instruction::Mad { data, arguments } => todo!(),
- ast::Instruction::Fma { data, arguments } => todo!(),
- ast::Instruction::Sub { data, arguments } => todo!(),
- ast::Instruction::Min { data, arguments } => todo!(),
- ast::Instruction::Max { data, arguments } => todo!(),
- ast::Instruction::Rcp { data, arguments } => todo!(),
- ast::Instruction::Sqrt { data, arguments } => todo!(),
- ast::Instruction::Rsqrt { data, arguments } => todo!(),
- ast::Instruction::Selp { data, arguments } => todo!(),
- ast::Instruction::Bar { data, arguments } => todo!(),
- ast::Instruction::Atom { data, arguments } => todo!(),
- ast::Instruction::AtomCas { data, arguments } => todo!(),
- ast::Instruction::Div { data, arguments } => todo!(),
- ast::Instruction::Neg { data, arguments } => todo!(),
- ast::Instruction::Sin { data, arguments } => todo!(),
- ast::Instruction::Cos { data, arguments } => todo!(),
- ast::Instruction::Lg2 { data, arguments } => todo!(),
- ast::Instruction::Ex2 { data, arguments } => todo!(),
- ast::Instruction::Clz { data, arguments } => todo!(),
- ast::Instruction::Brev { data, arguments } => todo!(),
- ast::Instruction::Popc { data, arguments } => todo!(),
- ast::Instruction::Xor { data, arguments } => todo!(),
- ast::Instruction::Rem { data, arguments } => todo!(),
- ast::Instruction::Bfe { data, arguments } => todo!(),
- ast::Instruction::Bfi { data, arguments } => todo!(),
- ast::Instruction::PrmtSlow { arguments } => todo!(),
- ast::Instruction::Prmt { data, arguments } => todo!(),
- ast::Instruction::Activemask { arguments } => todo!(),
- ast::Instruction::Membar { data } => todo!(),
+ ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
+ ast::Instruction::Abs { .. } => todo!(),
+ ast::Instruction::Mad { data, arguments } => self.emit_mad(data, arguments),
+ ast::Instruction::Fma { data, arguments } => self.emit_fma(data, arguments),
+ ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
+ ast::Instruction::Min { data, arguments } => self.emit_min(data, arguments),
+ ast::Instruction::Max { data, arguments } => self.emit_max(data, arguments),
+ ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments),
+ ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments),
+ ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments),
+ ast::Instruction::Selp { data, arguments } => self.emit_selp(data, arguments),
+ ast::Instruction::Bar { .. } => todo!(),
+ ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
+ ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
+ ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments),
+ ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments),
+ ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments),
+ ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
+ ast::Instruction::Lg2 { data, arguments } => self.emit_lg2(data, arguments),
+ ast::Instruction::Ex2 { data, arguments } => self.emit_ex2(data, arguments),
+ ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments),
+ ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments),
+ ast::Instruction::Popc { data, arguments } => self.emit_popc(data, arguments),
+ ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments),
+ ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments),
+ ast::Instruction::PrmtSlow { .. } => todo!(),
+ ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments),
+ ast::Instruction::Membar { data } => self.emit_membar(data),
ast::Instruction::Trap {} => todo!(),
+ // replaced by a function call
+ ast::Instruction::Bfe { .. }
+ | ast::Instruction::Bfi { .. }
+ | ast::Instruction::Activemask { .. } => return Err(error_unreachable()),
}
}
@@ -447,9 +569,6 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { data: ast::LdDetails,
arguments: ast::LdArgs<SpirvWord>,
) -> Result<(), TranslateError> {
- if data.non_coherent {
- todo!()
- }
if data.qualifier != ast::LdStQualifier::Weak {
todo!()
}
@@ -462,24 +581,25 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Ok(())
}
- fn emit_load_variable(&mut self, var: LoadVarDetails) -> Result<(), TranslateError> {
- if var.member_index.is_some() {
- todo!()
- }
- let builder = self.builder;
- let type_ = get_type(self.context, &var.typ)?;
- let ptr = self.resolver.value(var.arg.src)?;
- self.resolver.with_result(var.arg.dst, |dst| unsafe {
- LLVMBuildLoad2(builder, type_, ptr, dst)
- });
- Ok(())
- }
-
fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> {
let builder = self.builder;
match conversion.kind {
- ConversionKind::Default => todo!(),
- ConversionKind::SignExtend => todo!(),
+ ConversionKind::Default => self.emit_conversion_default(
+ self.resolver.value(conversion.src)?,
+ conversion.dst,
+ &conversion.from_type,
+ conversion.from_space,
+ &conversion.to_type,
+ conversion.to_space,
+ ),
+ ConversionKind::SignExtend => {
+ let src = self.resolver.value(conversion.src)?;
+ let type_ = get_type(self.context, &conversion.to_type)?;
+ self.resolver.with_result(conversion.dst, |dst| unsafe {
+ LLVMBuildSExt(builder, src, type_, dst)
+ });
+ Ok(())
+ }
ConversionKind::BitToPtr => {
let src = self.resolver.value(conversion.src)?;
let type_ = get_pointer_type(self.context, conversion.to_space)?;
@@ -488,8 +608,131 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { });
Ok(())
}
- ConversionKind::PtrToPtr => todo!(),
- ConversionKind::AddressOf => todo!(),
+ ConversionKind::PtrToPtr => {
+ let src = self.resolver.value(conversion.src)?;
+ let dst_type = get_pointer_type(self.context, conversion.to_space)?;
+ self.resolver.with_result(conversion.dst, |dst| unsafe {
+ LLVMBuildAddrSpaceCast(builder, src, dst_type, dst)
+ });
+ Ok(())
+ }
+ ConversionKind::AddressOf => {
+ let src = self.resolver.value(conversion.src)?;
+ let dst_type = get_type(self.context, &conversion.to_type)?;
+ self.resolver.with_result(conversion.dst, |dst| unsafe {
+ LLVMBuildPtrToInt(self.builder, src, dst_type, dst)
+ });
+ Ok(())
+ }
+ }
+ }
+
+ fn emit_conversion_default(
+ &mut self,
+ src: LLVMValueRef,
+ dst: SpirvWord,
+ from_type: &ast::Type,
+ from_space: ast::StateSpace,
+ to_type: &ast::Type,
+ to_space: ast::StateSpace,
+ ) -> Result<(), TranslateError> {
+ match (from_type, to_type) {
+ (ast::Type::Scalar(from_type), ast::Type::Scalar(to_type_scalar)) => {
+ let from_layout = from_type.layout();
+ let to_layout = to_type.layout();
+ if from_layout.size() == to_layout.size() {
+ let dst_type = get_type(self.context, &to_type)?;
+ if from_type.kind() != ast::ScalarKind::Float
+ && to_type_scalar.kind() != ast::ScalarKind::Float
+ {
+ // It is noop, but another instruction expects result of this conversion
+ self.resolver.register(dst, src);
+ } else {
+ self.resolver.with_result(dst, |dst| unsafe {
+ LLVMBuildBitCast(self.builder, src, dst_type, dst)
+ });
+ }
+ Ok(())
+ } else {
+ // This block is safe because it's illegal to implictly convert between floating point values
+ let same_width_bit_type = unsafe {
+ LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32)
+ };
+ let same_width_bit_value = unsafe {
+ LLVMBuildBitCast(
+ self.builder,
+ src,
+ same_width_bit_type,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ let wide_bit_type = match to_type_scalar.layout().size() {
+ 1 => ast::ScalarType::B8,
+ 2 => ast::ScalarType::B16,
+ 4 => ast::ScalarType::B32,
+ 8 => ast::ScalarType::B64,
+ _ => return Err(error_unreachable()),
+ };
+ let wide_bit_type_llvm = unsafe {
+ LLVMIntTypeInContext(self.context, (to_layout.size() * 8) as u32)
+ };
+ if to_type_scalar.kind() == ast::ScalarKind::Unsigned
+ || to_type_scalar.kind() == ast::ScalarKind::Bit
+ {
+ let llvm_fn = if to_type_scalar.size_of() >= from_type.size_of() {
+ LLVMBuildZExtOrBitCast
+ } else {
+ LLVMBuildTrunc
+ };
+ self.resolver.with_result(dst, |dst| unsafe {
+ llvm_fn(self.builder, same_width_bit_value, wide_bit_type_llvm, dst)
+ });
+ Ok(())
+ } else {
+ let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
+ && to_type_scalar.kind() == ast::ScalarKind::Signed
+ {
+ if to_type_scalar.size_of() >= from_type.size_of() {
+ LLVMBuildSExtOrBitCast
+ } else {
+ LLVMBuildTrunc
+ }
+ } else {
+ if to_type_scalar.size_of() >= from_type.size_of() {
+ LLVMBuildZExtOrBitCast
+ } else {
+ LLVMBuildTrunc
+ }
+ };
+ let wide_bit_value = unsafe {
+ conversion_fn(
+ self.builder,
+ same_width_bit_value,
+ wide_bit_type_llvm,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ self.emit_conversion_default(
+ wide_bit_value,
+ dst,
+ &wide_bit_type.into(),
+ from_space,
+ to_type,
+ to_space,
+ )
+ }
+ }
+ }
+ (ast::Type::Vector(..), ast::Type::Scalar(..))
+ | (ast::Type::Scalar(..), ast::Type::Array(..))
+ | (ast::Type::Array(..), ast::Type::Scalar(..)) => {
+ let dst_type = get_type(self.context, to_type)?;
+ self.resolver.with_result(dst, |dst| unsafe {
+ LLVMBuildBitCast(self.builder, src, dst_type, dst)
+ });
+ Ok(())
+ }
+ _ => todo!(),
}
}
@@ -514,8 +757,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let fn_ = match data {
- ast::ArithDetails::Integer(integer) => LLVMBuildAdd,
- ast::ArithDetails::Float(float) => LLVMBuildFAdd,
+ ast::ArithDetails::Integer(..) => LLVMBuildAdd,
+ ast::ArithDetails::Float(..) => LLVMBuildFAdd,
};
self.resolver.with_result(arguments.dst, |dst| unsafe {
fn_(builder, src1, src2, dst)
@@ -525,8 +768,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_st(
&self,
- data: ptx_parser::StData,
- arguments: ptx_parser::StArgs<SpirvWord>,
+ data: ast::StData,
+ arguments: ast::StArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let ptr = self.resolver.value(arguments.src1)?;
let value = self.resolver.value(arguments.src2)?;
@@ -537,14 +780,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Ok(())
}
- fn emit_ret(&self, _data: ptx_parser::RetData) {
+ fn emit_ret(&self, _data: ast::RetData) {
unsafe { LLVMBuildRetVoid(self.builder) };
}
fn emit_call(
&mut self,
- data: ptx_parser::CallDetails,
- arguments: ptx_parser::CallArgs<SpirvWord>,
+ data: ast::CallDetails,
+ arguments: ast::CallArgs<SpirvWord>,
) -> Result<(), TranslateError> {
if cfg!(debug_assertions) {
for (_, space) in data.return_arguments.iter() {
@@ -558,14 +801,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { }
}
}
- let name = match (&*data.return_arguments, &*arguments.return_arguments) {
- ([], []) => LLVM_UNNAMED.as_ptr(),
- ([(type_, _)], [dst]) => self.resolver.get_or_add_raw(*dst),
+ let name = match &*arguments.return_arguments {
+ [] => LLVM_UNNAMED.as_ptr(),
+ [dst] => self.resolver.get_or_add_raw(*dst),
_ => todo!(),
};
let type_ = get_function_type(
self.context,
- data.return_arguments.iter().map(|(type_, space)| type_),
+ data.return_arguments.iter().map(|(type_, ..)| type_),
data.input_arguments
.iter()
.map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)),
@@ -597,13 +840,1380 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_mov(
&mut self,
- _data: ptx_parser::MovDetails,
- arguments: ptx_parser::MovArgs<SpirvWord>,
+ _data: ast::MovDetails,
+ arguments: ast::MovArgs<SpirvWord>,
) -> Result<(), TranslateError> {
self.resolver
.register(arguments.dst, self.resolver.value(arguments.src)?);
Ok(())
}
+
+ fn emit_ptr_access(&mut self, ptr_access: PtrAccess<SpirvWord>) -> Result<(), TranslateError> {
+ let ptr_src = self.resolver.value(ptr_access.ptr_src)?;
+ let mut offset_src = self.resolver.value(ptr_access.offset_src)?;
+ let pointee_type = get_scalar_type(self.context, ast::ScalarType::B8);
+ self.resolver.with_result(ptr_access.dst, |dst| unsafe {
+ LLVMBuildInBoundsGEP2(self.builder, pointee_type, ptr_src, &mut offset_src, 1, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_and(&mut self, arguments: ast::AndArgs<SpirvWord>) -> Result<(), TranslateError> {
+ let builder = self.builder;
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildAnd(builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_atom(
+ &mut self,
+ data: ast::AtomDetails,
+ arguments: ast::AtomArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let builder = self.builder;
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ let op = match data.op {
+ ast::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd,
+ ast::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr,
+ ast::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor,
+ ast::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg,
+ ast::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd,
+ ast::AtomicOp::IncrementWrap => {
+ LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap
+ }
+ ast::AtomicOp::DecrementWrap => {
+ LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap
+ }
+ ast::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin,
+ ast::AtomicOp::UnsignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin,
+ ast::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax,
+ ast::AtomicOp::UnsignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax,
+ ast::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd,
+ ast::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin,
+ ast::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax,
+ };
+ self.resolver.register(arguments.dst, unsafe {
+ LLVMZludaBuildAtomicRMW(
+ builder,
+ op,
+ src1,
+ src2,
+ get_scope(data.scope)?,
+ get_ordering(data.semantics),
+ )
+ });
+ Ok(())
+ }
+
+ fn emit_atom_cas(
+ &mut self,
+ data: ast::AtomCasDetails,
+ arguments: ast::AtomCasArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ let src3 = self.resolver.value(arguments.src3)?;
+ let success_ordering = get_ordering(data.semantics);
+ let failure_ordering = get_ordering_failure(data.semantics);
+ let temp = unsafe {
+ LLVMZludaBuildAtomicCmpXchg(
+ self.builder,
+ src1,
+ src2,
+ src3,
+ get_scope(data.scope)?,
+ success_ordering,
+ failure_ordering,
+ )
+ };
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildExtractValue(self.builder, temp, 0, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_bra(&self, arguments: ast::BraArgs<SpirvWord>) -> Result<(), TranslateError> {
+ let src = self.resolver.value(arguments.src)?;
+ let src = unsafe { LLVMValueAsBasicBlock(src) };
+ unsafe { LLVMBuildBr(self.builder, src) };
+ Ok(())
+ }
+
+ fn emit_brev(
+ &mut self,
+ data: ast::ScalarType,
+ arguments: ast::BrevArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_fn = match data.size_of() {
+ 4 => c"llvm.bitreverse.i32",
+ 8 => c"llvm.bitreverse.i64",
+ _ => return Err(error_unreachable()),
+ };
+ let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
+ let type_ = get_scalar_type(self.context, data);
+ let fn_type = get_function_type(
+ self.context,
+ iter::once(&data.into()),
+ iter::once(Ok(type_)),
+ )?;
+ if fn_ == ptr::null_mut() {
+ fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
+ }
+ let mut src = self.resolver.value(arguments.src)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_ret_value(
+ &mut self,
+ values: Vec<(SpirvWord, ptx_parser::Type)>,
+ ) -> Result<(), TranslateError> {
+ match &*values {
+ [] => unsafe { LLVMBuildRetVoid(self.builder) },
+ [(value, type_)] => {
+ let value = self.resolver.value(*value)?;
+ let type_ = get_type(self.context, type_)?;
+ let value =
+ unsafe { LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) };
+ unsafe { LLVMBuildRet(self.builder, value) }
+ }
+ _ => todo!(),
+ };
+ Ok(())
+ }
+
+ fn emit_clz(
+ &mut self,
+ data: ptx_parser::ScalarType,
+ arguments: ptx_parser::ClzArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_fn = match data.size_of() {
+ 4 => c"llvm.ctlz.i32",
+ 8 => c"llvm.ctlz.i64",
+ _ => return Err(error_unreachable()),
+ };
+ let type_ = get_scalar_type(self.context, data.into());
+ let pred = get_scalar_type(self.context, ast::ScalarType::Pred);
+ let fn_type = get_function_type(
+ self.context,
+ iter::once(&ast::ScalarType::U32.into()),
+ [Ok(type_), Ok(pred)].into_iter(),
+ )?;
+ let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
+ if fn_ == ptr::null_mut() {
+ fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
+ }
+ let src = self.resolver.value(arguments.src)?;
+ let false_ = unsafe { LLVMConstInt(pred, 0, 0) };
+ let mut args = [src, false_];
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ fn_type,
+ fn_,
+ args.as_mut_ptr(),
+ args.len() as u32,
+ dst,
+ )
+ });
+ Ok(())
+ }
+
+ fn emit_mul(
+ &mut self,
+ data: ast::MulDetails,
+ arguments: ast::MulArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ self.emit_mul_impl(data, Some(arguments.dst), arguments.src1, arguments.src2)?;
+ Ok(())
+ }
+
+ fn emit_mul_impl(
+ &mut self,
+ data: ast::MulDetails,
+ dst: Option<SpirvWord>,
+ src1: SpirvWord,
+ src2: SpirvWord,
+ ) -> Result<LLVMValueRef, TranslateError> {
+ let mul_fn = match data {
+ ast::MulDetails::Integer { control, type_ } => match control {
+ ast::MulIntControl::Low => LLVMBuildMul,
+ ast::MulIntControl::High => return self.emit_mul_high(type_, dst, src1, src2),
+ ast::MulIntControl::Wide => {
+ return Ok(self.emit_mul_wide_impl(type_, dst, src1, src2)?.1)
+ }
+ },
+ ast::MulDetails::Float(..) => LLVMBuildFMul,
+ };
+ let src1 = self.resolver.value(src1)?;
+ let src2 = self.resolver.value(src2)?;
+ Ok(self
+ .resolver
+ .with_result_option(dst, |dst| unsafe { mul_fn(self.builder, src1, src2, dst) }))
+ }
+
+ fn emit_mul_high(
+ &mut self,
+ type_: ptx_parser::ScalarType,
+ dst: Option<SpirvWord>,
+ src1: SpirvWord,
+ src2: SpirvWord,
+ ) -> Result<LLVMValueRef, TranslateError> {
+ let (wide_type, wide_value) = self.emit_mul_wide_impl(type_, None, src1, src2)?;
+ let shift_constant =
+ unsafe { LLVMConstInt(wide_type, (type_.layout().size() * 8) as u64, 0) };
+ let shifted = unsafe {
+ LLVMBuildLShr(
+ self.builder,
+ wide_value,
+ shift_constant,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ let narrow_type = get_scalar_type(self.context, type_);
+ Ok(self.resolver.with_result_option(dst, |dst| unsafe {
+ LLVMBuildTrunc(self.builder, shifted, narrow_type, dst)
+ }))
+ }
+
+ fn emit_mul_wide_impl(
+ &mut self,
+ type_: ptx_parser::ScalarType,
+ dst: Option<SpirvWord>,
+ src1: SpirvWord,
+ src2: SpirvWord,
+ ) -> Result<(LLVMTypeRef, LLVMValueRef), TranslateError> {
+ let src1 = self.resolver.value(src1)?;
+ let src2 = self.resolver.value(src2)?;
+ let wide_type =
+ unsafe { LLVMIntTypeInContext(self.context, (type_.layout().size() * 8 * 2) as u32) };
+ let llvm_cast = match type_.kind() {
+ ptx_parser::ScalarKind::Signed => LLVMBuildSExt,
+ ptx_parser::ScalarKind::Unsigned => LLVMBuildZExt,
+ _ => return Err(error_unreachable()),
+ };
+ let src1 = unsafe { llvm_cast(self.builder, src1, wide_type, LLVM_UNNAMED.as_ptr()) };
+ let src2 = unsafe { llvm_cast(self.builder, src2, wide_type, LLVM_UNNAMED.as_ptr()) };
+ Ok((
+ wide_type,
+ self.resolver.with_result_option(dst, |dst| unsafe {
+ LLVMBuildMul(self.builder, src1, src2, dst)
+ }),
+ ))
+ }
+
+ fn emit_cos(
+ &mut self,
+ _data: ast::FlushToZero,
+ arguments: ast::CosArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32);
+ let cos = self.emit_intrinsic(
+ c"llvm.cos.f32",
+ Some(arguments.dst),
+ &ast::ScalarType::F32.into(),
+ vec![(self.resolver.value(arguments.src)?, llvm_f32)],
+ )?;
+ unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
+ Ok(())
+ }
+
+ fn emit_or(
+ &mut self,
+ _data: ptx_parser::ScalarType,
+ arguments: ptx_parser::OrArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildOr(self.builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_xor(
+ &mut self,
+ _data: ptx_parser::ScalarType,
+ arguments: ptx_parser::XorArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildXor(self.builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_vector_read(&mut self, vec_acccess: VectorRead) -> Result<(), TranslateError> {
+ let src = self.resolver.value(vec_acccess.vector_src)?;
+ let index = unsafe {
+ LLVMConstInt(
+ get_scalar_type(self.context, ast::ScalarType::B8),
+ vec_acccess.member as _,
+ 0,
+ )
+ };
+ self.resolver
+ .with_result(vec_acccess.scalar_dst, |dst| unsafe {
+ LLVMBuildExtractElement(self.builder, src, index, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_vector_write(&mut self, vector_write: VectorWrite) -> Result<(), TranslateError> {
+ let vector_src = self.resolver.value(vector_write.vector_src)?;
+ let scalar_src = self.resolver.value(vector_write.scalar_src)?;
+ let index = unsafe {
+ LLVMConstInt(
+ get_scalar_type(self.context, ast::ScalarType::B8),
+ vector_write.member as _,
+ 0,
+ )
+ };
+ self.resolver
+ .with_result(vector_write.vector_dst, |dst| unsafe {
+ LLVMBuildInsertElement(self.builder, vector_src, scalar_src, index, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_vector_repack(&mut self, repack: RepackVectorDetails) -> Result<(), TranslateError> {
+ let i8_type = get_scalar_type(self.context, ast::ScalarType::B8);
+ if repack.is_extract {
+ let src = self.resolver.value(repack.packed)?;
+ for (index, dst) in repack.unpacked.iter().enumerate() {
+ let index: *mut LLVMValue = unsafe { LLVMConstInt(i8_type, index as _, 0) };
+ self.resolver.with_result(*dst, |dst| unsafe {
+ LLVMBuildExtractElement(self.builder, src, index, dst)
+ });
+ }
+ } else {
+ let vector_type = get_type(
+ self.context,
+ &ast::Type::Vector(repack.unpacked.len() as u8, repack.typ),
+ )?;
+ let mut temp_vec = unsafe { LLVMGetUndef(vector_type) };
+ for (index, src_id) in repack.unpacked.iter().enumerate() {
+ let dst = if index == repack.unpacked.len() - 1 {
+ Some(repack.packed)
+ } else {
+ None
+ };
+ let scalar_src = self.resolver.value(*src_id)?;
+ let index = unsafe { LLVMConstInt(i8_type, index as _, 0) };
+ temp_vec = self.resolver.with_result_option(dst, |dst| unsafe {
+ LLVMBuildInsertElement(self.builder, temp_vec, scalar_src, index, dst)
+ });
+ }
+ }
+ Ok(())
+ }
+
+ fn emit_div(
+ &mut self,
+ data: ptx_parser::DivDetails,
+ arguments: ptx_parser::DivArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let integer_div = match data {
+ ptx_parser::DivDetails::Unsigned(_) => LLVMBuildUDiv,
+ ptx_parser::DivDetails::Signed(_) => LLVMBuildSDiv,
+ ptx_parser::DivDetails::Float(float_div) => {
+ return self.emit_div_float(float_div, arguments)
+ }
+ };
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ integer_div(self.builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_div_float(
+ &mut self,
+ float_div: ptx_parser::DivFloatDetails,
+ arguments: ptx_parser::DivArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let builder = self.builder;
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ let _rnd = match float_div.kind {
+ ptx_parser::DivFloatKind::Approx => ast::RoundingMode::NearestEven,
+ ptx_parser::DivFloatKind::ApproxFull => ast::RoundingMode::NearestEven,
+ ptx_parser::DivFloatKind::Rounding(rounding_mode) => rounding_mode,
+ };
+ let approx = match float_div.kind {
+ ptx_parser::DivFloatKind::Approx => {
+ LLVMZludaFastMathAllowReciprocal | LLVMZludaFastMathApproxFunc
+ }
+ ptx_parser::DivFloatKind::ApproxFull => LLVMZludaFastMathNone,
+ ptx_parser::DivFloatKind::Rounding(_) => LLVMZludaFastMathNone,
+ };
+ let fdiv = self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildFDiv(builder, src1, src2, dst)
+ });
+ unsafe { LLVMZludaSetFastMathFlags(fdiv, approx) };
+ if let ptx_parser::DivFloatKind::ApproxFull = float_div.kind {
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-div:
+ // div.full.f32 implements a relatively fast, full-range approximation that scales
+ // operands to achieve better accuracy, but is not fully IEEE 754 compliant and does not
+ // support rounding modifiers. The maximum ulp error is 2 across the full range of
+ // inputs.
+ // https://llvm.org/docs/LangRef.html#fpmath-metadata
+ let fpmath_value =
+ unsafe { LLVMConstReal(get_scalar_type(self.context, ast::ScalarType::F32), 2.0) };
+ let fpmath_value = unsafe { LLVMValueAsMetadata(fpmath_value) };
+ let mut md_node_content = [fpmath_value];
+ let md_node = unsafe {
+ LLVMMDNodeInContext2(
+ self.context,
+ md_node_content.as_mut_ptr(),
+ md_node_content.len(),
+ )
+ };
+ let md_node = unsafe { LLVMMetadataAsValue(self.context, md_node) };
+ let kind = unsafe {
+ LLVMGetMDKindIDInContext(
+ self.context,
+ "fpmath".as_ptr().cast(),
+ "fpmath".len() as u32,
+ )
+ };
+ unsafe { LLVMSetMetadata(fdiv, kind, md_node) };
+ }
+ Ok(())
+ }
+
+ fn emit_cvta(
+ &mut self,
+ data: ptx_parser::CvtaDetails,
+ arguments: ptx_parser::CvtaArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let (from_space, to_space) = match data.direction {
+ ptx_parser::CvtaDirection::GenericToExplicit => {
+ (ast::StateSpace::Generic, data.state_space)
+ }
+ ptx_parser::CvtaDirection::ExplicitToGeneric => {
+ (data.state_space, ast::StateSpace::Generic)
+ }
+ };
+ let from_type = get_pointer_type(self.context, from_space)?;
+ let dest_type = get_pointer_type(self.context, to_space)?;
+ let src = self.resolver.value(arguments.src)?;
+ let temp_ptr =
+ unsafe { LLVMBuildIntToPtr(self.builder, src, from_type, LLVM_UNNAMED.as_ptr()) };
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildAddrSpaceCast(self.builder, temp_ptr, dest_type, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_sub(
+ &mut self,
+ data: ptx_parser::ArithDetails,
+ arguments: ptx_parser::SubArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ match data {
+ ptx_parser::ArithDetails::Integer(arith_integer) => {
+ self.emit_sub_integer(arith_integer, arguments)
+ }
+ ptx_parser::ArithDetails::Float(arith_float) => {
+ self.emit_sub_float(arith_float, arguments)
+ }
+ }
+ }
+
+ fn emit_sub_integer(
+ &mut self,
+ arith_integer: ptx_parser::ArithInteger,
+ arguments: ptx_parser::SubArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ if arith_integer.saturate {
+ todo!()
+ }
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildSub(self.builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_sub_float(
+ &mut self,
+ arith_float: ptx_parser::ArithFloat,
+ arguments: ptx_parser::SubArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ if arith_float.saturate {
+ todo!()
+ }
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildFSub(self.builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_sin(
+ &mut self,
+ _data: ptx_parser::FlushToZero,
+ arguments: ptx_parser::SinArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32);
+ let sin = self.emit_intrinsic(
+ c"llvm.sin.f32",
+ Some(arguments.dst),
+ &ast::ScalarType::F32.into(),
+ vec![(self.resolver.value(arguments.src)?, llvm_f32)],
+ )?;
+ unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) }
+ Ok(())
+ }
+
+ fn emit_intrinsic(
+ &mut self,
+ name: &CStr,
+ dst: Option<SpirvWord>,
+ return_type: &ast::Type,
+ arguments: Vec<(LLVMValueRef, LLVMTypeRef)>,
+ ) -> Result<LLVMValueRef, TranslateError> {
+ let fn_type = get_function_type(
+ self.context,
+ iter::once(return_type),
+ arguments.iter().map(|(_, type_)| Ok(*type_)),
+ )?;
+ let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
+ if fn_ == ptr::null_mut() {
+ fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
+ }
+ let mut arguments = arguments.iter().map(|(arg, _)| *arg).collect::<Vec<_>>();
+ Ok(self.resolver.with_result_option(dst, |dst| unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ fn_type,
+ fn_,
+ arguments.as_mut_ptr(),
+ arguments.len() as u32,
+ dst,
+ )
+ }))
+ }
+
+ fn emit_neg(
+ &mut self,
+ data: ptx_parser::TypeFtz,
+ arguments: ptx_parser::NegArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let src = self.resolver.value(arguments.src)?;
+ let llvm_fn = if data.type_.kind() == ptx_parser::ScalarKind::Float {
+ LLVMBuildFNeg
+ } else {
+ LLVMBuildNeg
+ };
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ llvm_fn(self.builder, src, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_not(
+ &mut self,
+ _data: ptx_parser::ScalarType,
+ arguments: ptx_parser::NotArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let src = self.resolver.value(arguments.src)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildNot(self.builder, src, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_setp(
+ &mut self,
+ data: ptx_parser::SetpData,
+ arguments: ptx_parser::SetpArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ if arguments.dst2.is_some() {
+ todo!()
+ }
+ match data.cmp_op {
+ ptx_parser::SetpCompareOp::Integer(setp_compare_int) => {
+ self.emit_setp_int(setp_compare_int, arguments)
+ }
+ ptx_parser::SetpCompareOp::Float(setp_compare_float) => {
+ self.emit_setp_float(setp_compare_float, arguments)
+ }
+ }
+ }
+
+ fn emit_setp_int(
+ &mut self,
+ setp: ptx_parser::SetpCompareInt,
+ arguments: ptx_parser::SetpArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let op = match setp {
+ ptx_parser::SetpCompareInt::Eq => LLVMIntPredicate::LLVMIntEQ,
+ ptx_parser::SetpCompareInt::NotEq => LLVMIntPredicate::LLVMIntNE,
+ ptx_parser::SetpCompareInt::UnsignedLess => LLVMIntPredicate::LLVMIntULT,
+ ptx_parser::SetpCompareInt::UnsignedLessOrEq => LLVMIntPredicate::LLVMIntULE,
+ ptx_parser::SetpCompareInt::UnsignedGreater => LLVMIntPredicate::LLVMIntUGT,
+ ptx_parser::SetpCompareInt::UnsignedGreaterOrEq => LLVMIntPredicate::LLVMIntUGE,
+ ptx_parser::SetpCompareInt::SignedLess => LLVMIntPredicate::LLVMIntSLT,
+ ptx_parser::SetpCompareInt::SignedLessOrEq => LLVMIntPredicate::LLVMIntSLE,
+ ptx_parser::SetpCompareInt::SignedGreater => LLVMIntPredicate::LLVMIntSGT,
+ ptx_parser::SetpCompareInt::SignedGreaterOrEq => LLVMIntPredicate::LLVMIntSGE,
+ };
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst1, |dst1| unsafe {
+ LLVMBuildICmp(self.builder, op, src1, src2, dst1)
+ });
+ Ok(())
+ }
+
+ fn emit_setp_float(
+ &mut self,
+ setp: ptx_parser::SetpCompareFloat,
+ arguments: ptx_parser::SetpArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let op = match setp {
+ ptx_parser::SetpCompareFloat::Eq => LLVMRealPredicate::LLVMRealOEQ,
+ ptx_parser::SetpCompareFloat::NotEq => LLVMRealPredicate::LLVMRealONE,
+ ptx_parser::SetpCompareFloat::Less => LLVMRealPredicate::LLVMRealOLT,
+ ptx_parser::SetpCompareFloat::LessOrEq => LLVMRealPredicate::LLVMRealOLE,
+ ptx_parser::SetpCompareFloat::Greater => LLVMRealPredicate::LLVMRealOGT,
+ ptx_parser::SetpCompareFloat::GreaterOrEq => LLVMRealPredicate::LLVMRealOGE,
+ ptx_parser::SetpCompareFloat::NanEq => LLVMRealPredicate::LLVMRealUEQ,
+ ptx_parser::SetpCompareFloat::NanNotEq => LLVMRealPredicate::LLVMRealUNE,
+ ptx_parser::SetpCompareFloat::NanLess => LLVMRealPredicate::LLVMRealULT,
+ ptx_parser::SetpCompareFloat::NanLessOrEq => LLVMRealPredicate::LLVMRealULE,
+ ptx_parser::SetpCompareFloat::NanGreater => LLVMRealPredicate::LLVMRealUGT,
+ ptx_parser::SetpCompareFloat::NanGreaterOrEq => LLVMRealPredicate::LLVMRealUGE,
+ ptx_parser::SetpCompareFloat::IsNotNan => LLVMRealPredicate::LLVMRealORD,
+ ptx_parser::SetpCompareFloat::IsAnyNan => LLVMRealPredicate::LLVMRealUNO,
+ };
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst1, |dst1| unsafe {
+ LLVMBuildFCmp(self.builder, op, src1, src2, dst1)
+ });
+ Ok(())
+ }
+
+ fn emit_conditional(&mut self, cond: BrachCondition) -> Result<(), TranslateError> {
+ let predicate = self.resolver.value(cond.predicate)?;
+ let if_true = self.resolver.value(cond.if_true)?;
+ let if_false = self.resolver.value(cond.if_false)?;
+ unsafe {
+ LLVMBuildCondBr(
+ self.builder,
+ predicate,
+ LLVMValueAsBasicBlock(if_true),
+ LLVMValueAsBasicBlock(if_false),
+ )
+ };
+ Ok(())
+ }
+
+ fn emit_cvt(
+ &mut self,
+ data: ptx_parser::CvtDetails,
+ arguments: ptx_parser::CvtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let dst_type = get_scalar_type(self.context, data.to);
+ let llvm_fn = match data.mode {
+ ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt,
+ ptx_parser::CvtMode::SignExtend => LLVMBuildSExt,
+ ptx_parser::CvtMode::Truncate => LLVMBuildTrunc,
+ ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
+ ptx_parser::CvtMode::SaturateUnsignedToSigned => {
+ return self.emit_cvt_unsigned_to_signed_sat(data.from, data.to, arguments)
+ }
+ ptx_parser::CvtMode::SaturateSignedToUnsigned => {
+ return self.emit_cvt_signed_to_unsigned_sat(data.from, data.to, arguments)
+ }
+ ptx_parser::CvtMode::FPExtend { .. } => LLVMBuildFPExt,
+ ptx_parser::CvtMode::FPTruncate { .. } => LLVMBuildFPTrunc,
+ ptx_parser::CvtMode::FPRound {
+ integer_rounding, ..
+ } => {
+ return self.emit_cvt_float_to_int(
+ data.from,
+ data.to,
+ integer_rounding.unwrap_or(ast::RoundingMode::NearestEven),
+ arguments,
+ Some(LLVMBuildFPToSI),
+ )
+ }
+ ptx_parser::CvtMode::SignedFromFP { rounding, .. } => {
+ return self.emit_cvt_float_to_int(
+ data.from,
+ data.to,
+ rounding,
+ arguments,
+ Some(LLVMBuildFPToSI),
+ )
+ }
+ ptx_parser::CvtMode::UnsignedFromFP { rounding, .. } => {
+ return self.emit_cvt_float_to_int(
+ data.from,
+ data.to,
+ rounding,
+ arguments,
+ Some(LLVMBuildFPToUI),
+ )
+ }
+ ptx_parser::CvtMode::FPFromSigned(_) => todo!(),
+ ptx_parser::CvtMode::FPFromUnsigned(_) => todo!(),
+ };
+ let src = self.resolver.value(arguments.src)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ llvm_fn(self.builder, src, dst_type, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_cvt_unsigned_to_signed_sat(
+ &mut self,
+ from: ptx_parser::ScalarType,
+ to: ptx_parser::ScalarType,
+ arguments: ptx_parser::CvtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ // This looks dodgy, but it's fine. MAX bit pattern is always 0b11..1,
+ // so if it's downcast to a smaller type, it will be the maximum value
+ // of the smaller type
+ let max_value = match to {
+ ptx_parser::ScalarType::S8 => i8::MAX as u64,
+ ptx_parser::ScalarType::S16 => i16::MAX as u64,
+ ptx_parser::ScalarType::S32 => i32::MAX as u64,
+ ptx_parser::ScalarType::S64 => i64::MAX as u64,
+ _ => return Err(error_unreachable()),
+ };
+ let from_llvm = get_scalar_type(self.context, from);
+ let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
+ let clamped = self.emit_intrinsic(
+ c"llvm.umin",
+ None,
+ &from.into(),
+ vec![
+ (self.resolver.value(arguments.src)?, from_llvm),
+ (max, from_llvm),
+ ],
+ )?;
+ let resize_fn = if to.layout().size() >= from.layout().size() {
+ LLVMBuildSExtOrBitCast
+ } else {
+ LLVMBuildTrunc
+ };
+ let to_llvm = get_scalar_type(self.context, to);
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ resize_fn(self.builder, clamped, to_llvm, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_cvt_signed_to_unsigned_sat(
+ &mut self,
+ from: ptx_parser::ScalarType,
+ to: ptx_parser::ScalarType,
+ arguments: ptx_parser::CvtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let from_llvm = get_scalar_type(self.context, from);
+ let zero = unsafe { LLVMConstInt(from_llvm, 0, 0) };
+ let zero_clamp_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from));
+ let zero_clamped = self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) },
+ None,
+ &from.into(),
+ vec![
+ (self.resolver.value(arguments.src)?, from_llvm),
+ (zero, from_llvm),
+ ],
+ )?;
+ // zero_clamped is now unsigned
+ let max_value = match to {
+ ptx_parser::ScalarType::U8 => u8::MAX as u64,
+ ptx_parser::ScalarType::U16 => u16::MAX as u64,
+ ptx_parser::ScalarType::U32 => u32::MAX as u64,
+ ptx_parser::ScalarType::U64 => u64::MAX as u64,
+ _ => return Err(error_unreachable()),
+ };
+ let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
+ let max_clamp_intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from));
+ let fully_clamped = self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) },
+ None,
+ &from.into(),
+ vec![(zero_clamped, from_llvm), (max, from_llvm)],
+ )?;
+ let resize_fn = if to.layout().size() >= from.layout().size() {
+ LLVMBuildZExtOrBitCast
+ } else {
+ LLVMBuildTrunc
+ };
+ let to_llvm = get_scalar_type(self.context, to);
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ resize_fn(self.builder, fully_clamped, to_llvm, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_cvt_float_to_int(
+ &mut self,
+ from: ast::ScalarType,
+ to: ast::ScalarType,
+ rounding: ast::RoundingMode,
+ arguments: ptx_parser::CvtArgs<SpirvWord>,
+ llvm_cast: Option<
+ unsafe extern "C" fn(
+ arg1: LLVMBuilderRef,
+ Val: LLVMValueRef,
+ DestTy: LLVMTypeRef,
+ Name: *const i8,
+ ) -> LLVMValueRef,
+ >,
+ ) -> Result<(), TranslateError> {
+ let prefix = match rounding {
+ ptx_parser::RoundingMode::NearestEven => "llvm.roundeven",
+ ptx_parser::RoundingMode::Zero => "llvm.trunc",
+ ptx_parser::RoundingMode::NegativeInf => "llvm.floor",
+ ptx_parser::RoundingMode::PositiveInf => "llvm.ceil",
+ };
+ let intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(from));
+ let rounded_float = self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
+ None,
+ &from.into(),
+ vec![(
+ self.resolver.value(arguments.src)?,
+ get_scalar_type(self.context, from),
+ )],
+ )?;
+ if let Some(llvm_cast) = llvm_cast {
+ let to = get_scalar_type(self.context, to);
+ let poisoned_dst =
+ unsafe { llvm_cast(self.builder, rounded_float, to, LLVM_UNNAMED.as_ptr()) };
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildFreeze(self.builder, poisoned_dst, dst)
+ });
+ } else {
+ self.resolver.register(arguments.dst, rounded_float);
+ }
+ // Using explicit saturation gives us worse codegen: it explicitly checks for out of bound
+ // values and NaNs. Using non-saturated fptosi/fptoui emits v_cvt_<TO>_<FROM> which
+ // saturates by default and we don't care about NaNs anyway
+ /*
+ let cast_intrinsic = format!(
+ "{}.{}.{}\0",
+ llvm_cast,
+ LLVMTypeDisplay(to),
+ LLVMTypeDisplay(from)
+ );
+ self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) },
+ Some(arguments.dst),
+ &to.into(),
+ vec![(rounded_float, get_scalar_type(self.context, from))],
+ )?;
+ */
+ Ok(())
+ }
+
+ fn emit_rsqrt(
+ &mut self,
+ data: ptx_parser::TypeFtz,
+ arguments: ptx_parser::RsqrtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let type_ = get_scalar_type(self.context, data.type_);
+ let intrinsic = match data.type_ {
+ ast::ScalarType::F32 => c"llvm.amdgcn.rsq.f32",
+ ast::ScalarType::F64 => c"llvm.amdgcn.rsq.f64",
+ _ => return Err(error_unreachable()),
+ };
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![(self.resolver.value(arguments.src)?, type_)],
+ )?;
+ Ok(())
+ }
+
+ fn emit_sqrt(
+ &mut self,
+ data: ptx_parser::RcpData,
+ arguments: ptx_parser::SqrtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let type_ = get_scalar_type(self.context, data.type_);
+ let intrinsic = match (data.type_, data.kind) {
+ (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.sqrt.f32",
+ (ast::ScalarType::F32, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f32",
+ (ast::ScalarType::F64, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f64",
+ _ => return Err(error_unreachable()),
+ };
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![(self.resolver.value(arguments.src)?, type_)],
+ )?;
+ Ok(())
+ }
+
+ fn emit_rcp(
+ &mut self,
+ data: ptx_parser::RcpData,
+ arguments: ptx_parser::RcpArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let type_ = get_scalar_type(self.context, data.type_);
+ let intrinsic = match (data.type_, data.kind) {
+ (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.rcp.f32",
+ (_, ast::RcpKind::Compliant(rnd)) => {
+ return self.emit_rcp_compliant(data, arguments, rnd)
+ }
+ _ => return Err(error_unreachable()),
+ };
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![(self.resolver.value(arguments.src)?, type_)],
+ )?;
+ Ok(())
+ }
+
+ fn emit_rcp_compliant(
+ &mut self,
+ data: ptx_parser::RcpData,
+ arguments: ptx_parser::RcpArgs<SpirvWord>,
+ _rnd: ast::RoundingMode,
+ ) -> Result<(), TranslateError> {
+ let type_ = get_scalar_type(self.context, data.type_);
+ let one = unsafe { LLVMConstReal(type_, 1.0) };
+ let src = self.resolver.value(arguments.src)?;
+ let rcp = self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildFDiv(self.builder, one, src, dst)
+ });
+ unsafe { LLVMZludaSetFastMathFlags(rcp, LLVMZludaFastMathAllowReciprocal) };
+ Ok(())
+ }
+
+ fn emit_shr(
+ &mut self,
+ data: ptx_parser::ShrData,
+ arguments: ptx_parser::ShrArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let shift_fn = match data.kind {
+ ptx_parser::RightShiftKind::Arithmetic => LLVMBuildAShr,
+ ptx_parser::RightShiftKind::Logical => LLVMBuildLShr,
+ };
+ self.emit_shift(
+ data.type_,
+ arguments.dst,
+ arguments.src1,
+ arguments.src2,
+ shift_fn,
+ )
+ }
+
+ fn emit_shl(
+ &mut self,
+ type_: ptx_parser::ScalarType,
+ arguments: ptx_parser::ShlArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ self.emit_shift(
+ type_,
+ arguments.dst,
+ arguments.src1,
+ arguments.src2,
+ LLVMBuildShl,
+ )
+ }
+
+ fn emit_shift(
+ &mut self,
+ type_: ast::ScalarType,
+ dst: SpirvWord,
+ src1: SpirvWord,
+ src2: SpirvWord,
+ llvm_fn: unsafe extern "C" fn(
+ LLVMBuilderRef,
+ LLVMValueRef,
+ LLVMValueRef,
+ *const i8,
+ ) -> LLVMValueRef,
+ ) -> Result<(), TranslateError> {
+ let src1 = self.resolver.value(src1)?;
+ let shift_size = self.resolver.value(src2)?;
+ let integer_bits = type_.layout().size() * 8;
+ let integer_bits_constant = unsafe {
+ LLVMConstInt(
+ get_scalar_type(self.context, ast::ScalarType::U32),
+ integer_bits as u64,
+ 0,
+ )
+ };
+ let should_clamp = unsafe {
+ LLVMBuildICmp(
+ self.builder,
+ LLVMIntPredicate::LLVMIntUGE,
+ shift_size,
+ integer_bits_constant,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ let llvm_type = get_scalar_type(self.context, type_);
+ let zero = unsafe { LLVMConstNull(llvm_type) };
+ let normalized_shift_size = if type_.layout().size() >= 4 {
+ unsafe {
+ LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr())
+ }
+ } else {
+ unsafe { LLVMBuildTrunc(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) }
+ };
+ let shifted = unsafe {
+ llvm_fn(
+ self.builder,
+ src1,
+ normalized_shift_size,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ self.resolver.with_result(dst, |dst| unsafe {
+ LLVMBuildSelect(self.builder, should_clamp, zero, shifted, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_ex2(
+ &mut self,
+ data: ptx_parser::TypeFtz,
+ arguments: ptx_parser::Ex2Args<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let intrinsic = match data.type_ {
+ ast::ScalarType::F16 => c"llvm.amdgcn.exp2.f16",
+ ast::ScalarType::F32 => c"llvm.amdgcn.exp2.f32",
+ _ => return Err(error_unreachable()),
+ };
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![(
+ self.resolver.value(arguments.src)?,
+ get_scalar_type(self.context, data.type_),
+ )],
+ )?;
+ Ok(())
+ }
+
+ fn emit_lg2(
+ &mut self,
+ _data: ptx_parser::FlushToZero,
+ arguments: ptx_parser::Lg2Args<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ self.emit_intrinsic(
+ c"llvm.amdgcn.log.f32",
+ Some(arguments.dst),
+ &ast::ScalarType::F32.into(),
+ vec![(
+ self.resolver.value(arguments.src)?,
+ get_scalar_type(self.context, ast::ScalarType::F32.into()),
+ )],
+ )?;
+ Ok(())
+ }
+
+ fn emit_selp(
+ &mut self,
+ _data: ptx_parser::ScalarType,
+ arguments: ptx_parser::SelpArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ let src3 = self.resolver.value(arguments.src3)?;
+ self.resolver.with_result(arguments.dst, |dst_name| unsafe {
+ LLVMBuildSelect(self.builder, src3, src1, src2, dst_name)
+ });
+ Ok(())
+ }
+
+ fn emit_rem(
+ &mut self,
+ data: ptx_parser::ScalarType,
+ arguments: ptx_parser::RemArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_fn = match data.kind() {
+ ptx_parser::ScalarKind::Unsigned => LLVMBuildURem,
+ ptx_parser::ScalarKind::Signed => LLVMBuildSRem,
+ _ => return Err(error_unreachable()),
+ };
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst_name| unsafe {
+ llvm_fn(self.builder, src1, src2, dst_name)
+ });
+ Ok(())
+ }
+
+ fn emit_popc(
+ &mut self,
+ type_: ptx_parser::ScalarType,
+ arguments: ptx_parser::PopcArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let intrinsic = match type_ {
+ ast::ScalarType::B32 => c"llvm.ctpop.i32",
+ ast::ScalarType::B64 => c"llvm.ctpop.i64",
+ _ => return Err(error_unreachable()),
+ };
+ let llvm_type = get_scalar_type(self.context, type_);
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &type_.into(),
+ vec![(self.resolver.value(arguments.src)?, llvm_type)],
+ )?;
+ Ok(())
+ }
+
+ fn emit_min(
+ &mut self,
+ data: ptx_parser::MinMaxDetails,
+ arguments: ptx_parser::MinArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_prefix = match data {
+ ptx_parser::MinMaxDetails::Signed(..) => "llvm.smin",
+ ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umin",
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
+ return Err(error_todo())
+ }
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
+ };
+ let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
+ let llvm_type = get_scalar_type(self.context, data.type_());
+ self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
+ Some(arguments.dst),
+ &data.type_().into(),
+ vec![
+ (self.resolver.value(arguments.src1)?, llvm_type),
+ (self.resolver.value(arguments.src2)?, llvm_type),
+ ],
+ )?;
+ Ok(())
+ }
+
+ fn emit_max(
+ &mut self,
+ data: ptx_parser::MinMaxDetails,
+ arguments: ptx_parser::MaxArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_prefix = match data {
+ ptx_parser::MinMaxDetails::Signed(..) => "llvm.smax",
+ ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umax",
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
+ return Err(error_todo())
+ }
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
+ };
+ let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
+ let llvm_type = get_scalar_type(self.context, data.type_());
+ self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
+ Some(arguments.dst),
+ &data.type_().into(),
+ vec![
+ (self.resolver.value(arguments.src1)?, llvm_type),
+ (self.resolver.value(arguments.src2)?, llvm_type),
+ ],
+ )?;
+ Ok(())
+ }
+
+ fn emit_fma(
+ &mut self,
+ data: ptx_parser::ArithFloat,
+ arguments: ptx_parser::FmaArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let intrinsic = format!("llvm.fma.{}\0", LLVMTypeDisplay(data.type_));
+ self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![
+ (
+ self.resolver.value(arguments.src1)?,
+ get_scalar_type(self.context, data.type_),
+ ),
+ (
+ self.resolver.value(arguments.src2)?,
+ get_scalar_type(self.context, data.type_),
+ ),
+ (
+ self.resolver.value(arguments.src3)?,
+ get_scalar_type(self.context, data.type_),
+ ),
+ ],
+ )?;
+ Ok(())
+ }
+
+ fn emit_mad(
+ &mut self,
+ data: ptx_parser::MadDetails,
+ arguments: ptx_parser::MadArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let mul_control = match data {
+ ptx_parser::MadDetails::Float(mad_float) => {
+ return self.emit_fma(
+ mad_float,
+ ast::FmaArgs {
+ dst: arguments.dst,
+ src1: arguments.src1,
+ src2: arguments.src2,
+ src3: arguments.src3,
+ },
+ )
+ }
+ ptx_parser::MadDetails::Integer { saturate: true, .. } => return Err(error_todo()),
+ ptx_parser::MadDetails::Integer { type_, control, .. } => {
+ ast::MulDetails::Integer { control, type_ }
+ }
+ };
+ let temp = self.emit_mul_impl(mul_control, None, arguments.src1, arguments.src2)?;
+ let src3 = self.resolver.value(arguments.src3)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildAdd(self.builder, temp, src3, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_membar(&self, data: ptx_parser::MemScope) -> Result<(), TranslateError> {
+ unsafe {
+ LLVMZludaBuildFence(
+ self.builder,
+ LLVMAtomicOrdering::LLVMAtomicOrderingSequentiallyConsistent,
+ get_scope_membar(data)?,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ Ok(())
+ }
+
+ fn emit_prmt(
+ &mut self,
+ control: u16,
+ arguments: ptx_parser::PrmtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let components = [
+ (control >> 0) & 0b1111,
+ (control >> 4) & 0b1111,
+ (control >> 8) & 0b1111,
+ (control >> 12) & 0b1111,
+ ];
+ if components.iter().any(|&c| c > 7) {
+ return Err(TranslateError::Todo);
+ }
+ let u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
+ let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?;
+ let mut components = [
+ unsafe { LLVMConstInt(u32_type, components[0] as _, 0) },
+ unsafe { LLVMConstInt(u32_type, components[1] as _, 0) },
+ unsafe { LLVMConstInt(u32_type, components[2] as _, 0) },
+ unsafe { LLVMConstInt(u32_type, components[3] as _, 0) },
+ ];
+ let components_indices =
+ unsafe { LLVMConstVector(components.as_mut_ptr(), components.len() as u32) };
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src1_vector =
+ unsafe { LLVMBuildBitCast(self.builder, src1, v4u8_type, LLVM_UNNAMED.as_ptr()) };
+ let src2 = self.resolver.value(arguments.src2)?;
+ let src2_vector =
+ unsafe { LLVMBuildBitCast(self.builder, src2, v4u8_type, LLVM_UNNAMED.as_ptr()) };
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildShuffleVector(
+ self.builder,
+ src1_vector,
+ src2_vector,
+ components_indices,
+ dst,
+ )
+ });
+ Ok(())
+ }
+
+ /*
+ // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
+ // Should be available in LLVM 19
+ fn with_rounding<T>(&mut self, rnd: ast::RoundingMode, fn_: impl FnOnce(&mut Self) -> T) -> T {
+ let mut u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
+ let void_type = unsafe { LLVMVoidTypeInContext(self.context) };
+ let get_rounding = c"llvm.get.rounding";
+ let get_rounding_fn_type = unsafe { LLVMFunctionType(u32_type, ptr::null_mut(), 0, 0) };
+ let mut get_rounding_fn =
+ unsafe { LLVMGetNamedFunction(self.module, get_rounding.as_ptr()) };
+ if get_rounding_fn == ptr::null_mut() {
+ get_rounding_fn = unsafe {
+ LLVMAddFunction(self.module, get_rounding.as_ptr(), get_rounding_fn_type)
+ };
+ }
+ let set_rounding = c"llvm.set.rounding";
+ let set_rounding_fn_type = unsafe { LLVMFunctionType(void_type, &mut u32_type, 1, 0) };
+ let mut set_rounding_fn =
+ unsafe { LLVMGetNamedFunction(self.module, set_rounding.as_ptr()) };
+ if set_rounding_fn == ptr::null_mut() {
+ set_rounding_fn = unsafe {
+ LLVMAddFunction(self.module, set_rounding.as_ptr(), set_rounding_fn_type)
+ };
+ }
+ let mut preserved_rounding_mode = unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ get_rounding_fn_type,
+ get_rounding_fn,
+ ptr::null_mut(),
+ 0,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ let mut requested_rounding = unsafe {
+ LLVMConstInt(
+ get_scalar_type(self.context, ast::ScalarType::B32),
+ rounding_to_llvm(rnd) as u64,
+ 0,
+ )
+ };
+ unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ set_rounding_fn_type,
+ set_rounding_fn,
+ &mut requested_rounding,
+ 1,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ let result = fn_(self);
+ unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ set_rounding_fn_type,
+ set_rounding_fn,
+ &mut preserved_rounding_mode,
+ 1,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ result
+ }
+ */
}
fn get_pointer_type<'ctx>(
@@ -613,6 +2223,45 @@ fn get_pointer_type<'ctx>( Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) })
}
+// https://llvm.org/docs/AMDGPUUsage.html#memory-scopes
+fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
+ Ok(match scope {
+ ast::MemScope::Cta => c"workgroup-one-as",
+ ast::MemScope::Gpu => c"agent-one-as",
+ ast::MemScope::Sys => c"one-as",
+ ast::MemScope::Cluster => todo!(),
+ }
+ .as_ptr())
+}
+
+fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
+ Ok(match scope {
+ ast::MemScope::Cta => c"workgroup",
+ ast::MemScope::Gpu => c"agent",
+ ast::MemScope::Sys => c"",
+ ast::MemScope::Cluster => todo!(),
+ }
+ .as_ptr())
+}
+
+fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
+ match semantics {
+ ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
+ ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
+ ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingRelease,
+ ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquireRelease,
+ }
+}
+
+fn get_ordering_failure(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
+ match semantics {
+ ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
+ ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
+ ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
+ ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
+ }
+}
+
fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, TranslateError> {
Ok(match type_ {
ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar),
@@ -670,8 +2319,7 @@ fn get_function_type<'a>( mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
) -> Result<LLVMTypeRef, TranslateError> {
- let mut input_args: Vec<*mut llvm_zluda::LLVMType> =
- input_args.collect::<Result<Vec<_>, _>>()?;
+ let mut input_args = input_args.collect::<Result<Vec<_>, _>>()?;
let return_type = match return_args.len() {
0 => unsafe { LLVMVoidTypeInContext(context) },
1 => get_type(context, return_args.next().unwrap())?,
@@ -747,8 +2395,57 @@ impl ResolveIdent { .ok_or_else(|| error_unreachable())
}
- fn with_result(&mut self, word: SpirvWord, fn_: impl FnOnce(*const i8) -> LLVMValueRef) {
+ fn with_result(
+ &mut self,
+ word: SpirvWord,
+ fn_: impl FnOnce(*const i8) -> LLVMValueRef,
+ ) -> LLVMValueRef {
let t = self.get_or_ad_impl(word, |dst| fn_(dst.as_ptr().cast()));
self.register(word, t);
+ t
+ }
+
+ fn with_result_option(
+ &mut self,
+ word: Option<SpirvWord>,
+ fn_: impl FnOnce(*const i8) -> LLVMValueRef,
+ ) -> LLVMValueRef {
+ match word {
+ Some(word) => self.with_result(word, fn_),
+ None => fn_(LLVM_UNNAMED.as_ptr()),
+ }
+ }
+}
+
+struct LLVMTypeDisplay(ast::ScalarType);
+
+impl std::fmt::Display for LLVMTypeDisplay {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self.0 {
+ ast::ScalarType::Pred => write!(f, "i1"),
+ ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"),
+ ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"),
+ ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"),
+ ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"),
+ ptx_parser::ScalarType::B128 => write!(f, "i128"),
+ ast::ScalarType::F16 => write!(f, "f16"),
+ ptx_parser::ScalarType::BF16 => write!(f, "bfloat"),
+ ast::ScalarType::F32 => write!(f, "f32"),
+ ast::ScalarType::F64 => write!(f, "f64"),
+ ptx_parser::ScalarType::S16x2 | ptx_parser::ScalarType::U16x2 => write!(f, "v2i16"),
+ ast::ScalarType::F16x2 => write!(f, "v2f16"),
+ ptx_parser::ScalarType::BF16x2 => write!(f, "v2bfloat"),
+ }
+ }
+}
+
+/*
+fn rounding_to_llvm(this: ast::RoundingMode) -> u32 {
+ match this {
+ ptx_parser::RoundingMode::Zero => 0,
+ ptx_parser::RoundingMode::NearestEven => 1,
+ ptx_parser::RoundingMode::PositiveInf => 2,
+ ptx_parser::RoundingMode::NegativeInf => 3,
}
}
+*/
diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs deleted file mode 100644 index 120a477..0000000 --- a/ptx/src/pass/emit_spirv.rs +++ /dev/null @@ -1,2762 +0,0 @@ -use super::*;
-use half::f16;
-use ptx_parser as ast;
-use rspirv::{binary::Assemble, dr};
-use std::{
- collections::{HashMap, HashSet},
- ffi::CString,
- mem,
-};
-
-pub(super) fn run<'input>(
- mut builder: dr::Builder,
- id_defs: &GlobalStringIdResolver<'input>,
- call_map: MethodsCallMap<'input>,
- denorm_information: HashMap<
- ptx_parser::MethodName<SpirvWord>,
- HashMap<u8, (spirv::FPDenormMode, isize)>,
- >,
- directives: Vec<Directive<'input>>,
-) -> Result<(dr::Module, HashMap<String, KernelInfo>, CString), TranslateError> {
- builder.set_version(1, 3);
- emit_capabilities(&mut builder);
- emit_extensions(&mut builder);
- let opencl_id = emit_opencl_import(&mut builder);
- emit_memory_model(&mut builder);
- let mut map = TypeWordMap::new(&mut builder);
- //emit_builtins(&mut builder, &mut map, &id_defs);
- let mut kernel_info = HashMap::new();
- let (build_options, should_flush_denorms) =
- emit_denorm_build_string(&call_map, &denorm_information);
- let (directives, globals_use_map) = get_globals_use_map(directives);
- emit_directives(
- &mut builder,
- &mut map,
- &id_defs,
- opencl_id,
- should_flush_denorms,
- &call_map,
- globals_use_map,
- directives,
- &mut kernel_info,
- )?;
- Ok((builder.module(), kernel_info, build_options))
-}
-
-fn emit_capabilities(builder: &mut dr::Builder) {
- builder.capability(spirv::Capability::GenericPointer);
- builder.capability(spirv::Capability::Linkage);
- builder.capability(spirv::Capability::Addresses);
- builder.capability(spirv::Capability::Kernel);
- builder.capability(spirv::Capability::Int8);
- builder.capability(spirv::Capability::Int16);
- builder.capability(spirv::Capability::Int64);
- builder.capability(spirv::Capability::Float16);
- builder.capability(spirv::Capability::Float64);
- builder.capability(spirv::Capability::DenormFlushToZero);
- // TODO: re-enable when Intel float control extension works
- //builder.capability(spirv::Capability::FunctionFloatControlINTEL);
-}
-
-// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html
-fn emit_extensions(builder: &mut dr::Builder) {
- // TODO: re-enable when Intel float control extension works
- //builder.extension("SPV_INTEL_float_controls2");
- builder.extension("SPV_KHR_float_controls");
- builder.extension("SPV_KHR_no_integer_wrap_decoration");
-}
-
-fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word {
- builder.ext_inst_import("OpenCL.std")
-}
-
-fn emit_memory_model(builder: &mut dr::Builder) {
- builder.memory_model(
- spirv::AddressingModel::Physical64,
- spirv::MemoryModel::OpenCL,
- );
-}
-
-struct TypeWordMap {
- void: spirv::Word,
- complex: HashMap<SpirvType, SpirvWord>,
- constants: HashMap<(SpirvType, u64), SpirvWord>,
-}
-
-impl TypeWordMap {
- fn new(b: &mut dr::Builder) -> TypeWordMap {
- let void = b.type_void(None);
- TypeWordMap {
- void: void,
- complex: HashMap::<SpirvType, SpirvWord>::new(),
- constants: HashMap::new(),
- }
- }
-
- fn void(&self) -> spirv::Word {
- self.void
- }
-
- fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> SpirvWord {
- let key: SpirvScalarKey = t.into();
- self.get_or_add_spirv_scalar(b, key)
- }
-
- fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> SpirvWord {
- *self.complex.entry(SpirvType::Base(key)).or_insert_with(|| {
- SpirvWord(match key {
- SpirvScalarKey::B8 => b.type_int(None, 8, 0),
- SpirvScalarKey::B16 => b.type_int(None, 16, 0),
- SpirvScalarKey::B32 => b.type_int(None, 32, 0),
- SpirvScalarKey::B64 => b.type_int(None, 64, 0),
- SpirvScalarKey::F16 => b.type_float(None, 16),
- SpirvScalarKey::F32 => b.type_float(None, 32),
- SpirvScalarKey::F64 => b.type_float(None, 64),
- SpirvScalarKey::Pred => b.type_bool(None),
- SpirvScalarKey::F16x2 => todo!(),
- })
- })
- }
-
- fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> SpirvWord {
- match t {
- SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key),
- SpirvType::Pointer(ref typ, storage) => {
- let base = self.get_or_add(b, *typ.clone());
- *self
- .complex
- .entry(t)
- .or_insert_with(|| SpirvWord(b.type_pointer(None, storage, base.0)))
- }
- SpirvType::Vector(typ, len) => {
- let base = self.get_or_add_spirv_scalar(b, typ);
- *self
- .complex
- .entry(t)
- .or_insert_with(|| SpirvWord(b.type_vector(None, base.0, len as u32)))
- }
- SpirvType::Array(typ, array_dimensions) => {
- let (base_type, length) = match &*array_dimensions {
- &[] => {
- return self.get_or_add(b, SpirvType::Base(typ));
- }
- &[len] => {
- let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
- let base = self.get_or_add_spirv_scalar(b, typ);
- let len_const = b.constant_u32(u32_type.0, None, len);
- (base, len_const)
- }
- array_dimensions => {
- let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
- let base = self
- .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec()));
- let len_const = b.constant_u32(u32_type.0, None, array_dimensions[0]);
- (base, len_const)
- }
- };
- *self
- .complex
- .entry(SpirvType::Array(typ, array_dimensions))
- .or_insert_with(|| SpirvWord(b.type_array(None, base_type.0, length)))
- }
- SpirvType::Func(ref out_params, ref in_params) => {
- let out_t = match out_params {
- Some(p) => self.get_or_add(b, *p.clone()),
- None => SpirvWord(self.void()),
- };
- let in_t = in_params
- .iter()
- .map(|t| self.get_or_add(b, t.clone()).0)
- .collect::<Vec<_>>();
- *self
- .complex
- .entry(t)
- .or_insert_with(|| SpirvWord(b.type_function(None, out_t.0, in_t)))
- }
- SpirvType::Struct(ref underlying) => {
- let underlying_ids = underlying
- .iter()
- .map(|t| self.get_or_add_spirv_scalar(b, *t).0)
- .collect::<Vec<_>>();
- *self
- .complex
- .entry(t)
- .or_insert_with(|| SpirvWord(b.type_struct(None, underlying_ids)))
- }
- }
- }
-
- fn get_or_add_fn(
- &mut self,
- b: &mut dr::Builder,
- in_params: impl Iterator<Item = SpirvType>,
- mut out_params: impl ExactSizeIterator<Item = SpirvType>,
- ) -> (SpirvWord, SpirvWord) {
- let (out_args, out_spirv_type) = if out_params.len() == 0 {
- (None, SpirvWord(self.void()))
- } else if out_params.len() == 1 {
- let arg_as_key = out_params.next().unwrap();
- (
- Some(Box::new(arg_as_key.clone())),
- self.get_or_add(b, arg_as_key),
- )
- } else {
- // TODO: support multiple return values
- todo!()
- };
- (
- out_spirv_type,
- self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::<Vec<_>>())),
- )
- }
-
- fn get_or_add_constant(
- &mut self,
- b: &mut dr::Builder,
- typ: &ast::Type,
- init: &[u8],
- ) -> Result<SpirvWord, TranslateError> {
- Ok(match typ {
- ast::Type::Scalar(t) => match t {
- ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self
- .get_or_add_constant_single::<u8, _, _>(
- b,
- *t,
- init,
- |v| v as u64,
- |b, result_type, v| b.constant_u32(result_type, None, v as u32),
- ),
- ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self
- .get_or_add_constant_single::<u16, _, _>(
- b,
- *t,
- init,
- |v| v as u64,
- |b, result_type, v| b.constant_u32(result_type, None, v as u32),
- ),
- ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self
- .get_or_add_constant_single::<u32, _, _>(
- b,
- *t,
- init,
- |v| v as u64,
- |b, result_type, v| b.constant_u32(result_type, None, v),
- ),
- ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self
- .get_or_add_constant_single::<u64, _, _>(
- b,
- *t,
- init,
- |v| v,
- |b, result_type, v| b.constant_u64(result_type, None, v),
- ),
- ast::ScalarType::F16 => self.get_or_add_constant_single::<f16, _, _>(
- b,
- *t,
- init,
- |v| unsafe { mem::transmute::<_, u16>(v) } as u64,
- |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()),
- ),
- ast::ScalarType::F32 => self.get_or_add_constant_single::<f32, _, _>(
- b,
- *t,
- init,
- |v| unsafe { mem::transmute::<_, u32>(v) } as u64,
- |b, result_type, v| b.constant_f32(result_type, None, v),
- ),
- ast::ScalarType::F64 => self.get_or_add_constant_single::<f64, _, _>(
- b,
- *t,
- init,
- |v| unsafe { mem::transmute::<_, u64>(v) },
- |b, result_type, v| b.constant_f64(result_type, None, v),
- ),
- ast::ScalarType::F16x2 => return Err(TranslateError::Todo),
- ast::ScalarType::Pred => self.get_or_add_constant_single::<u8, _, _>(
- b,
- *t,
- init,
- |v| v as u64,
- |b, result_type, v| {
- if v == 0 {
- b.constant_false(result_type, None)
- } else {
- b.constant_true(result_type, None)
- }
- },
- ),
- ast::ScalarType::S16x2
- | ast::ScalarType::U16x2
- | ast::ScalarType::BF16
- | ast::ScalarType::BF16x2
- | ast::ScalarType::B128 => todo!(),
- },
- ast::Type::Vector(len, typ) => {
- let result_type =
- self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len));
- let size_of_t = typ.size_of();
- let components = (0..*len)
- .map(|x| {
- Ok::<_, TranslateError>(
- self.get_or_add_constant(
- b,
- &ast::Type::Scalar(*typ),
- &init[((size_of_t as usize) * (x as usize))..],
- )?
- .0,
- )
- })
- .collect::<Result<Vec<_>, _>>()?;
- SpirvWord(b.constant_composite(result_type.0, None, components.into_iter()))
- }
- ast::Type::Array(_, typ, dims) => match dims.as_slice() {
- [] => return Err(error_unreachable()),
- [dim] => {
- let result_type = self
- .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim]));
- let size_of_t = typ.size_of();
- let components = (0..*dim)
- .map(|x| {
- Ok::<_, TranslateError>(
- self.get_or_add_constant(
- b,
- &ast::Type::Scalar(*typ),
- &init[((size_of_t as usize) * (x as usize))..],
- )?
- .0,
- )
- })
- .collect::<Result<Vec<_>, _>>()?;
- SpirvWord(b.constant_composite(result_type.0, None, components.into_iter()))
- }
- [first_dim, rest @ ..] => {
- let result_type = self.get_or_add(
- b,
- SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()),
- );
- let size_of_t = rest
- .iter()
- .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y));
- let components = (0..*first_dim)
- .map(|x| {
- Ok::<_, TranslateError>(
- self.get_or_add_constant(
- b,
- &ast::Type::Array(None, *typ, rest.to_vec()),
- &init[((size_of_t as usize) * (x as usize))..],
- )?
- .0,
- )
- })
- .collect::<Result<Vec<_>, _>>()?;
- SpirvWord(b.constant_composite(result_type.0, None, components.into_iter()))
- }
- },
- ast::Type::Pointer(..) => return Err(error_unreachable()),
- })
- }
-
- fn get_or_add_constant_single<
- T: Copy,
- CastAsU64: FnOnce(T) -> u64,
- InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word,
- >(
- &mut self,
- b: &mut dr::Builder,
- key: ast::ScalarType,
- init: &[u8],
- cast: CastAsU64,
- f: InsertConstant,
- ) -> SpirvWord {
- let value = unsafe { *(init.as_ptr() as *const T) };
- let value_64 = cast(value);
- let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64);
- match self.constants.get(&ht_key) {
- Some(value) => *value,
- None => {
- let spirv_type = self.get_or_add_scalar(b, key);
- let result = SpirvWord(f(b, spirv_type.0, value));
- self.constants.insert(ht_key, result);
- result
- }
- }
- }
-}
-
-#[derive(PartialEq, Eq, Hash, Clone)]
-enum SpirvType {
- Base(SpirvScalarKey),
- Vector(SpirvScalarKey, u8),
- Array(SpirvScalarKey, Vec<u32>),
- Pointer(Box<SpirvType>, spirv::StorageClass),
- Func(Option<Box<SpirvType>>, Vec<SpirvType>),
- Struct(Vec<SpirvScalarKey>),
-}
-
-impl SpirvType {
- fn new(t: ast::Type) -> Self {
- match t {
- ast::Type::Scalar(t) => SpirvType::Base(t.into()),
- ast::Type::Vector(len, typ) => SpirvType::Vector(typ.into(), len),
- ast::Type::Array(_, t, len) => SpirvType::Array(t.into(), len),
- ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer(
- Box::new(SpirvType::Base(pointer_t.into())),
- space_to_spirv(space),
- ),
- }
- }
-
- fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self {
- let key = Self::new(t);
- SpirvType::Pointer(Box::new(key), outer_space)
- }
-}
-
-impl From<ast::ScalarType> for SpirvType {
- fn from(t: ast::ScalarType) -> Self {
- SpirvType::Base(t.into())
- }
-}
-// SPIR-V integer type definitions are signless, more below:
-// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers
-// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a
-#[derive(PartialEq, Eq, Hash, Clone, Copy)]
-enum SpirvScalarKey {
- B8,
- B16,
- B32,
- B64,
- F16,
- F32,
- F64,
- Pred,
- F16x2,
-}
-
-impl From<ast::ScalarType> for SpirvScalarKey {
- fn from(t: ast::ScalarType) -> Self {
- match t {
- ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8,
- ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
- SpirvScalarKey::B16
- }
- ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => {
- SpirvScalarKey::B32
- }
- ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => {
- SpirvScalarKey::B64
- }
- ast::ScalarType::F16 => SpirvScalarKey::F16,
- ast::ScalarType::F32 => SpirvScalarKey::F32,
- ast::ScalarType::F64 => SpirvScalarKey::F64,
- ast::ScalarType::F16x2 => SpirvScalarKey::F16x2,
- ast::ScalarType::Pred => SpirvScalarKey::Pred,
- ast::ScalarType::S16x2
- | ast::ScalarType::U16x2
- | ast::ScalarType::BF16
- | ast::ScalarType::BF16x2
- | ast::ScalarType::B128 => todo!(),
- }
- }
-}
-
-fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass {
- match this {
- ast::StateSpace::Const => spirv::StorageClass::UniformConstant,
- ast::StateSpace::Generic => spirv::StorageClass::Generic,
- ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup,
- ast::StateSpace::Local => spirv::StorageClass::Function,
- ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
- ast::StateSpace::Param => spirv::StorageClass::Function,
- ast::StateSpace::Reg => spirv::StorageClass::Function,
- ast::StateSpace::ParamEntry
- | ast::StateSpace::ParamFunc
- | ast::StateSpace::SharedCluster
- | ast::StateSpace::SharedCta => todo!(),
- }
-}
-
-// TODO: remove this once we have pef-function support for denorms
-fn emit_denorm_build_string<'input>(
- call_map: &MethodsCallMap,
- denorm_information: &HashMap<
- ast::MethodName<'input, SpirvWord>,
- HashMap<u8, (spirv::FPDenormMode, isize)>,
- >,
-) -> (CString, bool) {
- let denorm_counts = denorm_information
- .iter()
- .map(|(method, meth_denorm)| {
- let f16_count = meth_denorm
- .get(&(mem::size_of::<f16>() as u8))
- .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
- .1;
- let f32_count = meth_denorm
- .get(&(mem::size_of::<f32>() as u8))
- .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
- .1;
- (method, (f16_count + f32_count))
- })
- .collect::<HashMap<_, _>>();
- let mut flush_over_preserve = 0;
- for (kernel, children) in call_map.kernels() {
- flush_over_preserve += *denorm_counts
- .get(&ast::MethodName::Kernel(kernel))
- .unwrap_or(&0);
- for child_fn in children {
- flush_over_preserve += *denorm_counts
- .get(&ast::MethodName::Func(*child_fn))
- .unwrap_or(&0);
- }
- }
- if flush_over_preserve > 0 {
- (
- CString::new("-ze-take-global-address -ze-denorms-are-zero").unwrap(),
- true,
- )
- } else {
- (CString::new("-ze-take-global-address").unwrap(), false)
- }
-}
-
-fn get_globals_use_map<'input>(
- directives: Vec<Directive<'input>>,
-) -> (
- Vec<Directive<'input>>,
- HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
-) {
- let mut known_globals = HashSet::new();
- for directive in directives.iter() {
- match directive {
- Directive::Variable(_, ast::Variable { name, .. }) => {
- known_globals.insert(*name);
- }
- Directive::Method(..) => {}
- }
- }
- let mut symbol_uses_map = HashMap::new();
- let directives = directives
- .into_iter()
- .map(|directive| match directive {
- Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => directive,
- Directive::Method(Function {
- func_decl,
- body: Some(mut statements),
- globals,
- import_as,
- tuning,
- linkage,
- }) => {
- let method_name = func_decl.borrow().name;
- statements = statements
- .into_iter()
- .map(|statement| {
- statement.visit_map(
- &mut |symbol, _: Option<(&ast::Type, ast::StateSpace)>, _, _| {
- if known_globals.contains(&symbol) {
- multi_hash_map_append(
- &mut symbol_uses_map,
- method_name,
- symbol,
- );
- }
- Ok::<_, TranslateError>(symbol)
- },
- )
- })
- .collect::<Result<Vec<_>, _>>()
- .unwrap();
- Directive::Method(Function {
- func_decl,
- body: Some(statements),
- globals,
- import_as,
- tuning,
- linkage,
- })
- }
- })
- .collect::<Vec<_>>();
- (directives, symbol_uses_map)
-}
-
-fn emit_directives<'input>(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- id_defs: &GlobalStringIdResolver<'input>,
- opencl_id: spirv::Word,
- should_flush_denorms: bool,
- call_map: &MethodsCallMap<'input>,
- globals_use_map: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
- directives: Vec<Directive<'input>>,
- kernel_info: &mut HashMap<String, KernelInfo>,
-) -> Result<(), TranslateError> {
- let empty_body = Vec::new();
- for d in directives.iter() {
- match d {
- Directive::Variable(linking, var) => {
- emit_variable(builder, map, id_defs, *linking, &var)?;
- }
- Directive::Method(f) => {
- let f_body = match &f.body {
- Some(f) => f,
- None => {
- if f.linkage.contains(ast::LinkingDirective::EXTERN) {
- &empty_body
- } else {
- continue;
- }
- }
- };
- for var in f.globals.iter() {
- emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?;
- }
- let func_decl = (*f.func_decl).borrow();
- let fn_id = emit_function_header(
- builder,
- map,
- &id_defs,
- &*func_decl,
- call_map,
- &globals_use_map,
- kernel_info,
- )?;
- if matches!(func_decl.name, ast::MethodName::Kernel(_)) {
- if should_flush_denorms {
- builder.execution_mode(
- fn_id.0,
- spirv_headers::ExecutionMode::DenormFlushToZero,
- [16],
- );
- builder.execution_mode(
- fn_id.0,
- spirv_headers::ExecutionMode::DenormFlushToZero,
- [32],
- );
- builder.execution_mode(
- fn_id.0,
- spirv_headers::ExecutionMode::DenormFlushToZero,
- [64],
- );
- }
- // FP contraction happens when compiling source -> PTX and is illegal at this stage (unless you force it in cuModuleLoadDataEx)
- builder.execution_mode(
- fn_id.0,
- spirv_headers::ExecutionMode::ContractionOff,
- [],
- );
- for t in f.tuning.iter() {
- match *t {
- ast::TuningDirective::MaxNtid(nx, ny, nz) => {
- builder.execution_mode(
- fn_id.0,
- spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL,
- [nx, ny, nz],
- );
- }
- ast::TuningDirective::ReqNtid(nx, ny, nz) => {
- builder.execution_mode(
- fn_id.0,
- spirv_headers::ExecutionMode::LocalSize,
- [nx, ny, nz],
- );
- }
- // Too architecture specific
- ast::TuningDirective::MaxNReg(..)
- | ast::TuningDirective::MinNCtaPerSm(..) => {}
- }
- }
- }
- emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?;
- emit_function_linkage(builder, id_defs, f, fn_id)?;
- builder.select_block(None)?;
- builder.end_function()?;
- }
- }
- }
- Ok(())
-}
-
-fn emit_variable<'input>(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- id_defs: &GlobalStringIdResolver<'input>,
- linking: ast::LinkingDirective,
- var: &ast::Variable<SpirvWord>,
-) -> Result<(), TranslateError> {
- let (must_init, st_class) = match var.state_space {
- ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => {
- (false, spirv::StorageClass::Function)
- }
- ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup),
- ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
- ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant),
- ast::StateSpace::Generic => todo!(),
- ast::StateSpace::ParamEntry
- | ast::StateSpace::ParamFunc
- | ast::StateSpace::SharedCluster
- | ast::StateSpace::SharedCta => todo!(),
- };
- let initalizer = if var.array_init.len() > 0 {
- Some(
- map.get_or_add_constant(
- builder,
- &ast::Type::from(var.v_type.clone()),
- &*var.array_init,
- )?
- .0,
- )
- } else if must_init {
- let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone()));
- Some(builder.constant_null(type_id.0, None))
- } else {
- None
- };
- let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class));
- builder.variable(ptr_type_id.0, Some(var.name.0), st_class, initalizer);
- if let Some(align) = var.align {
- builder.decorate(
- var.name.0,
- spirv::Decoration::Alignment,
- [dr::Operand::LiteralInt32(align)].iter().cloned(),
- );
- }
- if var.state_space != ast::StateSpace::Shared
- || !linking.contains(ast::LinkingDirective::EXTERN)
- {
- emit_linking_decoration(builder, id_defs, None, var.name, linking);
- }
- Ok(())
-}
-
-fn emit_function_header<'input>(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- defined_globals: &GlobalStringIdResolver<'input>,
- func_decl: &ast::MethodDeclaration<'input, SpirvWord>,
- call_map: &MethodsCallMap<'input>,
- globals_use_map: &HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
- kernel_info: &mut HashMap<String, KernelInfo>,
-) -> Result<SpirvWord, TranslateError> {
- if let ast::MethodName::Kernel(name) = func_decl.name {
- let args_lens = func_decl
- .input_arguments
- .iter()
- .map(|param| {
- (
- type_size_of(¶m.v_type),
- matches!(param.v_type, ast::Type::Pointer(..)),
- )
- })
- .collect();
- kernel_info.insert(
- name.to_string(),
- KernelInfo {
- arguments_sizes: args_lens,
- uses_shared_mem: func_decl.shared_mem.is_some(),
- },
- );
- }
- let (ret_type, func_type) = get_function_type(
- builder,
- map,
- effective_input_arguments(func_decl).map(|(_, typ)| typ),
- &func_decl.return_arguments,
- );
- let fn_id = match func_decl.name {
- ast::MethodName::Kernel(name) => {
- let fn_id = defined_globals.get_id(name)?;
- let interface = globals_use_map
- .get(&ast::MethodName::Kernel(name))
- .into_iter()
- .flatten()
- .copied()
- .chain({
- call_map
- .get_kernel_children(name)
- .copied()
- .flat_map(|subfunction| {
- globals_use_map
- .get(&ast::MethodName::Func(subfunction))
- .into_iter()
- .flatten()
- .copied()
- })
- .into_iter()
- })
- .map(|word| word.0)
- .collect::<Vec<spirv::Word>>();
- builder.entry_point(spirv::ExecutionModel::Kernel, fn_id.0, name, interface);
- fn_id
- }
- ast::MethodName::Func(name) => name,
- };
- builder.begin_function(
- ret_type.0,
- Some(fn_id.0),
- spirv::FunctionControl::NONE,
- func_type.0,
- )?;
- for (name, typ) in effective_input_arguments(func_decl) {
- let result_type = map.get_or_add(builder, typ);
- builder.function_parameter(Some(name.0), result_type.0)?;
- }
- Ok(fn_id)
-}
-
-pub fn type_size_of(this: &ast::Type) -> usize {
- match this {
- ast::Type::Scalar(typ) => typ.size_of() as usize,
- ast::Type::Vector(len, typ) => (typ.size_of() as usize) * (*len as usize),
- ast::Type::Array(_, typ, len) => len
- .iter()
- .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)),
- ast::Type::Pointer(..) => mem::size_of::<usize>(),
- }
-}
-fn emit_function_body_ops<'input>(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- id_defs: &GlobalStringIdResolver<'input>,
- opencl: spirv::Word,
- func: &[ExpandedStatement],
-) -> Result<(), TranslateError> {
- for s in func {
- match s {
- Statement::Label(id) => {
- if builder.selected_block().is_some() {
- builder.branch(id.0)?;
- }
- builder.begin_block(Some(id.0))?;
- }
- _ => {
- if builder.selected_block().is_none() && builder.selected_function().is_some() {
- builder.begin_block(None)?;
- }
- }
- }
- match s {
- Statement::Label(_) => (),
- Statement::Variable(var) => {
- emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?;
- }
- Statement::Constant(cnst) => {
- let typ_id = map.get_or_add_scalar(builder, cnst.typ);
- match (cnst.typ, cnst.value) {
- (ast::ScalarType::B8, ast::ImmediateValue::U64(value))
- | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => {
- builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32);
- }
- (ast::ScalarType::B16, ast::ImmediateValue::U64(value))
- | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => {
- builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32);
- }
- (ast::ScalarType::B32, ast::ImmediateValue::U64(value))
- | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => {
- builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32);
- }
- (ast::ScalarType::B64, ast::ImmediateValue::U64(value))
- | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => {
- builder.constant_u64(typ_id.0, Some(cnst.dst.0), value);
- }
- (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => {
- builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32);
- }
- (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => {
- builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32);
- }
- (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => {
- builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32);
- }
- (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => {
- builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as i64 as u64);
- }
- (ast::ScalarType::B8, ast::ImmediateValue::S64(value))
- | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => {
- builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32);
- }
- (ast::ScalarType::B16, ast::ImmediateValue::S64(value))
- | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => {
- builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32);
- }
- (ast::ScalarType::B32, ast::ImmediateValue::S64(value))
- | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => {
- builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32);
- }
- (ast::ScalarType::B64, ast::ImmediateValue::S64(value))
- | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => {
- builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64);
- }
- (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => {
- builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32);
- }
- (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => {
- builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32);
- }
- (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => {
- builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32);
- }
- (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => {
- builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64);
- }
- (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => {
- builder.constant_f32(
- typ_id.0,
- Some(cnst.dst.0),
- f16::from_f32(value).to_f32(),
- );
- }
- (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => {
- builder.constant_f32(typ_id.0, Some(cnst.dst.0), value);
- }
- (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => {
- builder.constant_f64(typ_id.0, Some(cnst.dst.0), value as f64);
- }
- (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => {
- builder.constant_f32(
- typ_id.0,
- Some(cnst.dst.0),
- f16::from_f64(value).to_f32(),
- );
- }
- (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => {
- builder.constant_f32(typ_id.0, Some(cnst.dst.0), value as f32);
- }
- (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => {
- builder.constant_f64(typ_id.0, Some(cnst.dst.0), value);
- }
- (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => {
- let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0;
- if value == 0 {
- builder.constant_false(bool_type, Some(cnst.dst.0));
- } else {
- builder.constant_true(bool_type, Some(cnst.dst.0));
- }
- }
- (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => {
- let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0;
- if value == 0 {
- builder.constant_false(bool_type, Some(cnst.dst.0));
- } else {
- builder.constant_true(bool_type, Some(cnst.dst.0));
- }
- }
- _ => return Err(error_mismatched_type()),
- }
- }
- Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?,
- Statement::Conditional(bra) => {
- builder.branch_conditional(
- bra.predicate.0,
- bra.if_true.0,
- bra.if_false.0,
- iter::empty(),
- )?;
- }
- Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => {
- // TODO: implement properly
- let zero = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U64),
- &vec_repr(0u64),
- )?;
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::U64);
- builder.copy_object(result_type.0, Some(dst.0), zero.0)?;
- }
- Statement::Instruction(inst) => match inst {
- ast::Instruction::PrmtSlow { .. } | ast::Instruction::Trap { .. } => todo!(),
- ast::Instruction::Call { data, arguments } => {
- let (result_type, result_id) =
- match (&*data.return_arguments, &*arguments.return_arguments) {
- ([(type_, space)], [id]) => {
- if *space != ast::StateSpace::Reg {
- return Err(error_unreachable());
- }
- (
- map.get_or_add(builder, SpirvType::new(type_.clone())).0,
- Some(id.0),
- )
- }
- ([], []) => (map.void(), None),
- _ => todo!(),
- };
- let arg_list = arguments
- .input_arguments
- .iter()
- .map(|id| id.0)
- .collect::<Vec<_>>();
- builder.function_call(result_type, result_id, arguments.func.0, arg_list)?;
- }
- ast::Instruction::Abs { data, arguments } => {
- emit_abs(builder, map, opencl, data, arguments)?
- }
- // SPIR-V does not support marking jumps as guaranteed-converged
- ast::Instruction::Bra { arguments, .. } => {
- builder.branch(arguments.src.0)?;
- }
- ast::Instruction::Ld { data, arguments } => {
- let mem_access = match data.qualifier {
- ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE,
- // ld.volatile does not match Volatile OpLoad nor Relaxed OpAtomicLoad
- ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE,
- _ => return Err(TranslateError::Todo),
- };
- let result_type =
- map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone())));
- builder.load(
- result_type.0,
- Some(arguments.dst.0),
- arguments.src.0,
- Some(mem_access | spirv::MemoryAccess::ALIGNED),
- [dr::Operand::LiteralInt32(
- type_size_of(&ast::Type::from(data.typ.clone())) as u32,
- )]
- .iter()
- .cloned(),
- )?;
- }
- ast::Instruction::St { data, arguments } => {
- let mem_access = match data.qualifier {
- ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE,
- // st.volatile does not match Volatile OpStore nor Relaxed OpAtomicStore
- ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE,
- _ => return Err(TranslateError::Todo),
- };
- builder.store(
- arguments.src1.0,
- arguments.src2.0,
- Some(mem_access | spirv::MemoryAccess::ALIGNED),
- [dr::Operand::LiteralInt32(
- type_size_of(&ast::Type::from(data.typ.clone())) as u32,
- )]
- .iter()
- .cloned(),
- )?;
- }
- // SPIR-V does not support ret as guaranteed-converged
- ast::Instruction::Ret { .. } => builder.ret()?,
- ast::Instruction::Mov { data, arguments } => {
- let result_type =
- map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone())));
- builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?;
- }
- ast::Instruction::Mul { data, arguments } => match data {
- ast::MulDetails::Integer { type_, control } => {
- emit_mul_int(builder, map, opencl, *type_, *control, arguments)?
- }
- ast::MulDetails::Float(ref ctr) => {
- emit_mul_float(builder, map, ctr, arguments)?
- }
- },
- ast::Instruction::Add { data, arguments } => match data {
- ast::ArithDetails::Integer(desc) => {
- emit_add_int(builder, map, desc.type_.into(), desc.saturate, arguments)?
- }
- ast::ArithDetails::Float(desc) => {
- emit_add_float(builder, map, desc, arguments)?
- }
- },
- ast::Instruction::Setp { data, arguments } => {
- if arguments.dst2.is_some() {
- todo!()
- }
- emit_setp(builder, map, data, arguments)?;
- }
- ast::Instruction::Not { data, arguments } => {
- let result_type = map.get_or_add(builder, SpirvType::from(*data));
- let result_id = Some(arguments.dst.0);
- let operand = arguments.src;
- match data {
- ast::ScalarType::Pred => {
- logical_not(builder, result_type.0, result_id, operand.0)
- }
- _ => builder.not(result_type.0, result_id, operand.0),
- }?;
- }
- ast::Instruction::Shl { data, arguments } => {
- let full_type = ast::Type::Scalar(*data);
- let size_of = type_size_of(&full_type);
- let result_type = map.get_or_add(builder, SpirvType::new(full_type));
- let offset_src = insert_shift_hack(builder, map, arguments.src2.0, size_of)?;
- builder.shift_left_logical(
- result_type.0,
- Some(arguments.dst.0),
- arguments.src1.0,
- offset_src,
- )?;
- }
- ast::Instruction::Shr { data, arguments } => {
- let full_type = ast::ScalarType::from(data.type_);
- let size_of = full_type.size_of();
- let result_type = map.get_or_add_scalar(builder, full_type).0;
- let offset_src =
- insert_shift_hack(builder, map, arguments.src2.0, size_of as usize)?;
- match data.kind {
- ptx_parser::RightShiftKind::Arithmetic => {
- builder.shift_right_arithmetic(
- result_type,
- Some(arguments.dst.0),
- arguments.src1.0,
- offset_src,
- )?;
- }
- ptx_parser::RightShiftKind::Logical => {
- builder.shift_right_logical(
- result_type,
- Some(arguments.dst.0),
- arguments.src1.0,
- offset_src,
- )?;
- }
- }
- }
- ast::Instruction::Cvt { data, arguments } => {
- emit_cvt(builder, map, opencl, data, arguments)?;
- }
- ast::Instruction::Cvta { data, arguments } => {
- // This would be only meaningful if const/slm/global pointers
- // had a different format than generic pointers, but they don't pretty much by ptx definition
- // Honestly, I have no idea why this instruction exists and is emitted by the compiler
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64);
- builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?;
- }
- ast::Instruction::SetpBool { .. } => todo!(),
- ast::Instruction::Mad { data, arguments } => match data {
- ast::MadDetails::Integer {
- type_,
- control,
- saturate,
- } => {
- if *saturate {
- todo!()
- }
- if type_.kind() == ast::ScalarKind::Signed {
- emit_mad_sint(builder, map, opencl, *type_, *control, arguments)?
- } else {
- emit_mad_uint(builder, map, opencl, *type_, *control, arguments)?
- }
- }
- ast::MadDetails::Float(desc) => {
- emit_mad_float(builder, map, opencl, desc, arguments)?
- }
- },
- ast::Instruction::Fma { data, arguments } => {
- emit_fma_float(builder, map, opencl, data, arguments)?
- }
- ast::Instruction::Or { data, arguments } => {
- let result_type = map.get_or_add_scalar(builder, *data).0;
- if *data == ast::ScalarType::Pred {
- builder.logical_or(
- result_type,
- Some(arguments.dst.0),
- arguments.src1.0,
- arguments.src2.0,
- )?;
- } else {
- builder.bitwise_or(
- result_type,
- Some(arguments.dst.0),
- arguments.src1.0,
- arguments.src2.0,
- )?;
- }
- }
- ast::Instruction::Sub { data, arguments } => match data {
- ast::ArithDetails::Integer(desc) => {
- emit_sub_int(builder, map, desc.type_.into(), desc.saturate, arguments)?;
- }
- ast::ArithDetails::Float(desc) => {
- emit_sub_float(builder, map, desc, arguments)?;
- }
- },
- ast::Instruction::Min { data, arguments } => {
- emit_min(builder, map, opencl, data, arguments)?;
- }
- ast::Instruction::Max { data, arguments } => {
- emit_max(builder, map, opencl, data, arguments)?;
- }
- ast::Instruction::Rcp { data, arguments } => {
- emit_rcp(builder, map, opencl, data, arguments)?;
- }
- ast::Instruction::And { data, arguments } => {
- let result_type = map.get_or_add_scalar(builder, *data);
- if *data == ast::ScalarType::Pred {
- builder.logical_and(
- result_type.0,
- Some(arguments.dst.0),
- arguments.src1.0,
- arguments.src2.0,
- )?;
- } else {
- builder.bitwise_and(
- result_type.0,
- Some(arguments.dst.0),
- arguments.src1.0,
- arguments.src2.0,
- )?;
- }
- }
- ast::Instruction::Selp { data, arguments } => {
- let result_type = map.get_or_add_scalar(builder, *data);
- builder.select(
- result_type.0,
- Some(arguments.dst.0),
- arguments.src3.0,
- arguments.src1.0,
- arguments.src2.0,
- )?;
- }
- // TODO: implement named barriers
- ast::Instruction::Bar { data, arguments } => {
- let workgroup_scope = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(spirv::Scope::Workgroup as u32),
- )?;
- let barrier_semantics = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(
- spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
- | spirv::MemorySemantics::WORKGROUP_MEMORY
- | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
- ),
- )?;
- builder.control_barrier(
- workgroup_scope.0,
- workgroup_scope.0,
- barrier_semantics.0,
- )?;
- }
- ast::Instruction::Atom { data, arguments } => {
- emit_atom(builder, map, data, arguments)?;
- }
- ast::Instruction::AtomCas { data, arguments } => {
- let result_type = map.get_or_add_scalar(builder, data.type_);
- let memory_const = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(scope_to_spirv(data.scope) as u32),
- )?;
- let semantics_const = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(semantics_to_spirv(data.semantics).bits()),
- )?;
- builder.atomic_compare_exchange(
- result_type.0,
- Some(arguments.dst.0),
- arguments.src1.0,
- memory_const.0,
- semantics_const.0,
- semantics_const.0,
- arguments.src3.0,
- arguments.src2.0,
- )?;
- }
- ast::Instruction::Div { data, arguments } => match data {
- ast::DivDetails::Unsigned(t) => {
- let result_type = map.get_or_add_scalar(builder, (*t).into());
- builder.u_div(
- result_type.0,
- Some(arguments.dst.0),
- arguments.src1.0,
- arguments.src2.0,
- )?;
- }
- ast::DivDetails::Signed(t) => {
- let result_type = map.get_or_add_scalar(builder, (*t).into());
- builder.s_div(
- result_type.0,
- Some(arguments.dst.0),
- arguments.src1.0,
- arguments.src2.0,
- )?;
- }
- ast::DivDetails::Float(t) => {
- let result_type = map.get_or_add_scalar(builder, t.type_.into());
- builder.f_div(
- result_type.0,
- Some(arguments.dst.0),
- arguments.src1.0,
- arguments.src2.0,
- )?;
- emit_float_div_decoration(builder, arguments.dst, t.kind);
- }
- },
- ast::Instruction::Sqrt { data, arguments } => {
- emit_sqrt(builder, map, opencl, data, arguments)?;
- }
- ast::Instruction::Rsqrt { data, arguments } => {
- let result_type = map.get_or_add_scalar(builder, data.type_.into());
- builder.ext_inst(
- result_type.0,
- Some(arguments.dst.0),
- opencl,
- spirv::CLOp::rsqrt as spirv::Word,
- [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
- )?;
- }
- ast::Instruction::Neg { data, arguments } => {
- let result_type = map.get_or_add_scalar(builder, data.type_);
- let negate_func = if data.type_.kind() == ast::ScalarKind::Float {
- dr::Builder::f_negate
- } else {
- dr::Builder::s_negate
- };
- negate_func(
- builder,
- result_type.0,
- Some(arguments.dst.0),
- arguments.src.0,
- )?;
- }
- ast::Instruction::Sin { arguments, .. } => {
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
- builder.ext_inst(
- result_type.0,
- Some(arguments.dst.0),
- opencl,
- spirv::CLOp::sin as u32,
- [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
- )?;
- }
- ast::Instruction::Cos { arguments, .. } => {
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
- builder.ext_inst(
- result_type.0,
- Some(arguments.dst.0),
- opencl,
- spirv::CLOp::cos as u32,
- [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
- )?;
- }
- ast::Instruction::Lg2 { arguments, .. } => {
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
- builder.ext_inst(
- result_type.0,
- Some(arguments.dst.0),
- opencl,
- spirv::CLOp::log2 as u32,
- [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
- )?;
- }
- ast::Instruction::Ex2 { arguments, .. } => {
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
- builder.ext_inst(
- result_type.0,
- Some(arguments.dst.0),
- opencl,
- spirv::CLOp::exp2 as u32,
- [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
- )?;
- }
- ast::Instruction::Clz { data, arguments } => {
- let result_type = map.get_or_add_scalar(builder, (*data).into());
- builder.ext_inst(
- result_type.0,
- Some(arguments.dst.0),
- opencl,
- spirv::CLOp::clz as u32,
- [dr::Operand::IdRef(arguments.src.0)].iter().cloned(),
- )?;
- }
- ast::Instruction::Brev { data, arguments } => {
- let result_type = map.get_or_add_scalar(builder, (*data).into());
- builder.bit_reverse(result_type.0, Some(arguments.dst.0), arguments.src.0)?;
- }
- ast::Instruction::Popc { data, arguments } => {
- let result_type = map.get_or_add_scalar(builder, (*data).into());
- builder.bit_count(result_type.0, Some(arguments.dst.0), arguments.src.0)?;
- }
- ast::Instruction::Xor { data, arguments } => {
- let builder_fn: fn(
- &mut dr::Builder,
- u32,
- Option<u32>,
- u32,
- u32,
- ) -> Result<u32, dr::Error> = match data {
- ast::ScalarType::Pred => emit_logical_xor_spirv,
- _ => dr::Builder::bitwise_xor,
- };
- let result_type = map.get_or_add_scalar(builder, (*data).into());
- builder_fn(
- builder,
- result_type.0,
- Some(arguments.dst.0),
- arguments.src1.0,
- arguments.src2.0,
- )?;
- }
- ast::Instruction::Bfe { .. }
- | ast::Instruction::Bfi { .. }
- | ast::Instruction::Activemask { .. } => {
- // Should have beeen replaced with a funciton call earlier
- return Err(error_unreachable());
- }
-
- ast::Instruction::Rem { data, arguments } => {
- let builder_fn = if data.kind() == ast::ScalarKind::Signed {
- dr::Builder::s_mod
- } else {
- dr::Builder::u_mod
- };
- let result_type = map.get_or_add_scalar(builder, (*data).into());
- builder_fn(
- builder,
- result_type.0,
- Some(arguments.dst.0),
- arguments.src1.0,
- arguments.src2.0,
- )?;
- }
- ast::Instruction::Prmt { data, arguments } => {
- let control = *data as u32;
- let components = [
- (control >> 0) & 0b1111,
- (control >> 4) & 0b1111,
- (control >> 8) & 0b1111,
- (control >> 12) & 0b1111,
- ];
- if components.iter().any(|&c| c > 7) {
- return Err(TranslateError::Todo);
- }
- let vec4_b8_type =
- map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4));
- let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32);
- let src1_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src1.0)?;
- let src2_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src2.0)?;
- let dst_vector = builder.vector_shuffle(
- vec4_b8_type.0,
- None,
- src1_vector,
- src2_vector,
- components,
- )?;
- builder.bitcast(b32_type.0, Some(arguments.dst.0), dst_vector)?;
- }
- ast::Instruction::Membar { data } => {
- let (scope, semantics) = match data {
- ast::MemScope::Cta => (
- spirv::Scope::Workgroup,
- spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
- | spirv::MemorySemantics::WORKGROUP_MEMORY
- | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
- ),
- ast::MemScope::Gpu => (
- spirv::Scope::Device,
- spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
- | spirv::MemorySemantics::WORKGROUP_MEMORY
- | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
- ),
- ast::MemScope::Sys => (
- spirv::Scope::CrossDevice,
- spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
- | spirv::MemorySemantics::WORKGROUP_MEMORY
- | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
- ),
-
- ast::MemScope::Cluster => todo!(),
- };
- let spirv_scope = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(scope as u32),
- )?;
- let spirv_semantics = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(semantics),
- )?;
- builder.memory_barrier(spirv_scope.0, spirv_semantics.0)?;
- }
- },
- Statement::LoadVar(details) => {
- emit_load_var(builder, map, details)?;
- }
- Statement::StoreVar(details) => {
- let dst_ptr = match details.member_index {
- Some(index) => {
- let result_ptr_type = map.get_or_add(
- builder,
- SpirvType::pointer_to(
- details.typ.clone(),
- spirv::StorageClass::Function,
- ),
- );
- let index_spirv = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(index as u32),
- )?;
- builder.in_bounds_access_chain(
- result_ptr_type.0,
- None,
- details.arg.src1.0,
- [index_spirv.0].iter().copied(),
- )?
- }
- None => details.arg.src1.0,
- };
- builder.store(dst_ptr, details.arg.src2.0, None, iter::empty())?;
- }
- Statement::RetValue(_, id) => {
- builder.ret_value(id.0)?;
- }
- Statement::PtrAccess(PtrAccess {
- underlying_type,
- state_space,
- dst,
- ptr_src,
- offset_src,
- }) => {
- let u8_pointer = map.get_or_add(
- builder,
- SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)),
- );
- let result_type = map.get_or_add(
- builder,
- SpirvType::pointer_to(underlying_type.clone(), space_to_spirv(*state_space)),
- );
- let ptr_src_u8 = builder.bitcast(u8_pointer.0, None, ptr_src.0)?;
- let temp = builder.in_bounds_ptr_access_chain(
- u8_pointer.0,
- None,
- ptr_src_u8,
- offset_src.0,
- iter::empty(),
- )?;
- builder.bitcast(result_type.0, Some(dst.0), temp)?;
- }
- Statement::RepackVector(repack) => {
- if repack.is_extract {
- let scalar_type = map.get_or_add_scalar(builder, repack.typ);
- for (index, dst_id) in repack.unpacked.iter().enumerate() {
- builder.composite_extract(
- scalar_type.0,
- Some(dst_id.0),
- repack.packed.0,
- [index as u32].iter().copied(),
- )?;
- }
- } else {
- let vector_type = map.get_or_add(
- builder,
- SpirvType::Vector(
- SpirvScalarKey::from(repack.typ),
- repack.unpacked.len() as u8,
- ),
- );
- let mut temp_vec = builder.undef(vector_type.0, None);
- for (index, src_id) in repack.unpacked.iter().enumerate() {
- temp_vec = builder.composite_insert(
- vector_type.0,
- None,
- src_id.0,
- temp_vec,
- [index as u32].iter().copied(),
- )?;
- }
- builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
- }
- }
- Statement::VectorAccess(vector_access) => todo!(),
- }
- }
- Ok(())
-}
-
-fn emit_function_linkage<'input>(
- builder: &mut dr::Builder,
- id_defs: &GlobalStringIdResolver<'input>,
- f: &Function,
- fn_name: SpirvWord,
-) -> Result<(), TranslateError> {
- if f.linkage == ast::LinkingDirective::NONE {
- return Ok(());
- };
- let linking_name = match f.func_decl.borrow().name {
- // According to SPIR-V rules linkage attributes are invalid on kernels
- ast::MethodName::Kernel(..) => return Ok(()),
- ast::MethodName::Func(fn_id) => f.import_as.as_deref().map_or_else(
- || match id_defs.reverse_variables.get(&fn_id) {
- Some(fn_name) => Ok(fn_name),
- None => Err(error_unknown_symbol()),
- },
- Result::Ok,
- )?,
- };
- emit_linking_decoration(builder, id_defs, Some(linking_name), fn_name, f.linkage);
- Ok(())
-}
-
-fn get_function_type(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- spirv_input: impl Iterator<Item = SpirvType>,
- spirv_output: &[ast::Variable<SpirvWord>],
-) -> (SpirvWord, SpirvWord) {
- map.get_or_add_fn(
- builder,
- spirv_input,
- spirv_output
- .iter()
- .map(|var| SpirvType::new(var.v_type.clone())),
- )
-}
-
-fn emit_linking_decoration<'input>(
- builder: &mut dr::Builder,
- id_defs: &GlobalStringIdResolver<'input>,
- name_override: Option<&str>,
- name: SpirvWord,
- linking: ast::LinkingDirective,
-) {
- if linking == ast::LinkingDirective::NONE {
- return;
- }
- if linking.contains(ast::LinkingDirective::VISIBLE) {
- let string_name =
- name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
- builder.decorate(
- name.0,
- spirv::Decoration::LinkageAttributes,
- [
- dr::Operand::LiteralString(string_name.to_string()),
- dr::Operand::LinkageType(spirv::LinkageType::Export),
- ]
- .iter()
- .cloned(),
- );
- } else if linking.contains(ast::LinkingDirective::EXTERN) {
- let string_name =
- name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
- builder.decorate(
- name.0,
- spirv::Decoration::LinkageAttributes,
- [
- dr::Operand::LiteralString(string_name.to_string()),
- dr::Operand::LinkageType(spirv::LinkageType::Import),
- ]
- .iter()
- .cloned(),
- );
- }
- // TODO: handle LinkingDirective::WEAK
-}
-
-fn effective_input_arguments<'a>(
- this: &'a ast::MethodDeclaration<'a, SpirvWord>,
-) -> impl Iterator<Item = (SpirvWord, SpirvType)> + 'a {
- let is_kernel = matches!(this.name, ast::MethodName::Kernel(_));
- this.input_arguments.iter().map(move |arg| {
- if !is_kernel && arg.state_space != ast::StateSpace::Reg {
- let spirv_type =
- SpirvType::pointer_to(arg.v_type.clone(), space_to_spirv(arg.state_space));
- (arg.name, spirv_type)
- } else {
- (arg.name, SpirvType::new(arg.v_type.clone()))
- }
- })
-}
-
-fn emit_implicit_conversion(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- cv: &ImplicitConversion,
-) -> Result<(), TranslateError> {
- let from_parts = to_parts(&cv.from_type);
- let to_parts = to_parts(&cv.to_type);
- match (from_parts.kind, to_parts.kind, &cv.kind) {
- (_, _, &ConversionKind::BitToPtr) => {
- let dst_type = map.get_or_add(
- builder,
- SpirvType::pointer_to(cv.to_type.clone(), space_to_spirv(cv.to_space)),
- );
- builder.convert_u_to_ptr(dst_type.0, Some(cv.dst.0), cv.src.0)?;
- }
- (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => {
- if from_parts.width == to_parts.width {
- let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
- if from_parts.scalar_kind != ast::ScalarKind::Float
- && to_parts.scalar_kind != ast::ScalarKind::Float
- {
- // It is noop, but another instruction expects result of this conversion
- builder.copy_object(dst_type.0, Some(cv.dst.0), cv.src.0)?;
- } else {
- builder.bitcast(dst_type.0, Some(cv.dst.0), cv.src.0)?;
- }
- } else {
- // This block is safe because it's illegal to implictly convert between floating point values
- let same_width_bit_type = map.get_or_add(
- builder,
- SpirvType::new(type_from_parts(TypeParts {
- scalar_kind: ast::ScalarKind::Bit,
- ..from_parts
- })),
- );
- let same_width_bit_value =
- builder.bitcast(same_width_bit_type.0, None, cv.src.0)?;
- let wide_bit_type = type_from_parts(TypeParts {
- scalar_kind: ast::ScalarKind::Bit,
- ..to_parts
- });
- let wide_bit_type_spirv =
- map.get_or_add(builder, SpirvType::new(wide_bit_type.clone()));
- if to_parts.scalar_kind == ast::ScalarKind::Unsigned
- || to_parts.scalar_kind == ast::ScalarKind::Bit
- {
- builder.u_convert(
- wide_bit_type_spirv.0,
- Some(cv.dst.0),
- same_width_bit_value,
- )?;
- } else {
- let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed
- && to_parts.scalar_kind == ast::ScalarKind::Signed
- {
- dr::Builder::s_convert
- } else {
- dr::Builder::u_convert
- };
- let wide_bit_value =
- conversion_fn(builder, wide_bit_type_spirv.0, None, same_width_bit_value)?;
- emit_implicit_conversion(
- builder,
- map,
- &ImplicitConversion {
- src: SpirvWord(wide_bit_value),
- dst: cv.dst,
- from_type: wide_bit_type,
- from_space: cv.from_space,
- to_type: cv.to_type.clone(),
- to_space: cv.to_space,
- kind: ConversionKind::Default,
- },
- )?;
- }
- }
- }
- (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => {
- let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
- builder.s_convert(result_type.0, Some(cv.dst.0), cv.src.0)?;
- }
- (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default)
- | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default)
- | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => {
- let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
- builder.bitcast(into_type.0, Some(cv.dst.0), cv.src.0)?;
- }
- (_, _, &ConversionKind::PtrToPtr) => {
- let result_type = map.get_or_add(
- builder,
- SpirvType::Pointer(
- Box::new(SpirvType::new(cv.to_type.clone())),
- space_to_spirv(cv.to_space),
- ),
- );
- if cv.to_space == ast::StateSpace::Generic && cv.from_space != ast::StateSpace::Generic
- {
- let src = if cv.from_type != cv.to_type {
- let temp_type = map.get_or_add(
- builder,
- SpirvType::Pointer(
- Box::new(SpirvType::new(cv.to_type.clone())),
- space_to_spirv(cv.from_space),
- ),
- );
- builder.bitcast(temp_type.0, None, cv.src.0)?
- } else {
- cv.src.0
- };
- builder.ptr_cast_to_generic(result_type.0, Some(cv.dst.0), src)?;
- } else if cv.from_space == ast::StateSpace::Generic
- && cv.to_space != ast::StateSpace::Generic
- {
- let src = if cv.from_type != cv.to_type {
- let temp_type = map.get_or_add(
- builder,
- SpirvType::Pointer(
- Box::new(SpirvType::new(cv.to_type.clone())),
- space_to_spirv(cv.from_space),
- ),
- );
- builder.bitcast(temp_type.0, None, cv.src.0)?
- } else {
- cv.src.0
- };
- builder.generic_cast_to_ptr(result_type.0, Some(cv.dst.0), src)?;
- } else {
- builder.bitcast(result_type.0, Some(cv.dst.0), cv.src.0)?;
- }
- }
- (_, _, &ConversionKind::AddressOf) => {
- let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
- builder.convert_ptr_to_u(dst_type.0, Some(cv.dst.0), cv.src.0)?;
- }
- (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) => {
- let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
- builder.convert_ptr_to_u(result_type.0, Some(cv.dst.0), cv.src.0)?;
- }
- (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => {
- let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
- builder.convert_u_to_ptr(result_type.0, Some(cv.dst.0), cv.src.0)?;
- }
- _ => unreachable!(),
- }
- Ok(())
-}
-
-fn vec_repr<T: Copy>(t: T) -> Vec<u8> {
- let mut result = vec![0; mem::size_of::<T>()];
- unsafe { std::ptr::copy_nonoverlapping(&t, result.as_mut_ptr() as *mut _, 1) };
- result
-}
-
-fn emit_abs(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- d: &ast::TypeFtz,
- arg: &ast::AbsArgs<SpirvWord>,
-) -> Result<(), dr::Error> {
- let scalar_t = ast::ScalarType::from(d.type_);
- let result_type = map.get_or_add(builder, SpirvType::from(scalar_t));
- let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed {
- spirv::CLOp::s_abs
- } else {
- spirv::CLOp::fabs
- };
- builder.ext_inst(
- result_type.0,
- Some(arg.dst.0),
- opencl,
- cl_abs as spirv::Word,
- [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
- )?;
- Ok(())
-}
-
-fn emit_mul_int(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- type_: ast::ScalarType,
- control: ast::MulIntControl,
- arg: &ast::MulArgs<SpirvWord>,
-) -> Result<(), dr::Error> {
- let inst_type = map.get_or_add(builder, SpirvType::from(type_));
- match control {
- ast::MulIntControl::Low => {
- builder.i_mul(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
- }
- ast::MulIntControl::High => {
- let opencl_inst = if type_.kind() == ast::ScalarKind::Signed {
- spirv::CLOp::s_mul_hi
- } else {
- spirv::CLOp::u_mul_hi
- };
- builder.ext_inst(
- inst_type.0,
- Some(arg.dst.0),
- opencl,
- opencl_inst as spirv::Word,
- [
- dr::Operand::IdRef(arg.src1.0),
- dr::Operand::IdRef(arg.src2.0),
- ]
- .iter()
- .cloned(),
- )?;
- }
- ast::MulIntControl::Wide => {
- let instr_width = type_.size_of();
- let instr_kind = type_.kind();
- let dst_type = scalar_from_parts(instr_width * 2, instr_kind);
- let dst_type_id = map.get_or_add_scalar(builder, dst_type);
- let (src1, src2) = if type_.kind() == ast::ScalarKind::Signed {
- let src1 = builder.s_convert(dst_type_id.0, None, arg.src1.0)?;
- let src2 = builder.s_convert(dst_type_id.0, None, arg.src2.0)?;
- (src1, src2)
- } else {
- let src1 = builder.u_convert(dst_type_id.0, None, arg.src1.0)?;
- let src2 = builder.u_convert(dst_type_id.0, None, arg.src2.0)?;
- (src1, src2)
- };
- builder.i_mul(dst_type_id.0, Some(arg.dst.0), src1, src2)?;
- builder.decorate(arg.dst.0, spirv::Decoration::NoSignedWrap, iter::empty());
- }
- }
- Ok(())
-}
-
-fn emit_mul_float(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- ctr: &ast::ArithFloat,
- arg: &ast::MulArgs<SpirvWord>,
-) -> Result<(), dr::Error> {
- if ctr.saturate {
- todo!()
- }
- let result_type = map.get_or_add_scalar(builder, ctr.type_.into());
- builder.f_mul(result_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
- emit_rounding_decoration(builder, arg.dst, ctr.rounding);
- Ok(())
-}
-
-fn scalar_from_parts(width: u8, kind: ast::ScalarKind) -> ast::ScalarType {
- match kind {
- ast::ScalarKind::Float => match width {
- 2 => ast::ScalarType::F16,
- 4 => ast::ScalarType::F32,
- 8 => ast::ScalarType::F64,
- _ => unreachable!(),
- },
- ast::ScalarKind::Bit => match width {
- 1 => ast::ScalarType::B8,
- 2 => ast::ScalarType::B16,
- 4 => ast::ScalarType::B32,
- 8 => ast::ScalarType::B64,
- _ => unreachable!(),
- },
- ast::ScalarKind::Signed => match width {
- 1 => ast::ScalarType::S8,
- 2 => ast::ScalarType::S16,
- 4 => ast::ScalarType::S32,
- 8 => ast::ScalarType::S64,
- _ => unreachable!(),
- },
- ast::ScalarKind::Unsigned => match width {
- 1 => ast::ScalarType::U8,
- 2 => ast::ScalarType::U16,
- 4 => ast::ScalarType::U32,
- 8 => ast::ScalarType::U64,
- _ => unreachable!(),
- },
- ast::ScalarKind::Pred => ast::ScalarType::Pred,
- }
-}
-
-fn emit_rounding_decoration(
- builder: &mut dr::Builder,
- dst: SpirvWord,
- rounding: Option<ast::RoundingMode>,
-) {
- if let Some(rounding) = rounding {
- builder.decorate(
- dst.0,
- spirv::Decoration::FPRoundingMode,
- [rounding_to_spirv(rounding)].iter().cloned(),
- );
- }
-}
-
-fn rounding_to_spirv(this: ast::RoundingMode) -> rspirv::dr::Operand {
- let mode = match this {
- ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE,
- ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ,
- ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP,
- ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN,
- };
- rspirv::dr::Operand::FPRoundingMode(mode)
-}
-
-fn emit_add_int(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- typ: ast::ScalarType,
- saturate: bool,
- arg: &ast::AddArgs<SpirvWord>,
-) -> Result<(), dr::Error> {
- if saturate {
- todo!()
- }
- let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)));
- builder.i_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
- Ok(())
-}
-
-fn emit_add_float(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- desc: &ast::ArithFloat,
- arg: &ast::AddArgs<SpirvWord>,
-) -> Result<(), dr::Error> {
- let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)));
- builder.f_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
- emit_rounding_decoration(builder, arg.dst, desc.rounding);
- Ok(())
-}
-
-fn emit_setp(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- setp: &ast::SetpData,
- arg: &ast::SetpArgs<SpirvWord>,
-) -> Result<(), dr::Error> {
- let result_type = map
- .get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred))
- .0;
- let result_id = Some(arg.dst1.0);
- let operand_1 = arg.src1.0;
- let operand_2 = arg.src2.0;
- match setp.cmp_op {
- ast::SetpCompareOp::Integer(ast::SetpCompareInt::Eq) => {
- builder.i_equal(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Float(ast::SetpCompareFloat::Eq) => {
- builder.f_ord_equal(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Integer(ast::SetpCompareInt::NotEq) => {
- builder.i_not_equal(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Float(ast::SetpCompareFloat::NotEq) => {
- builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLess) => {
- builder.u_less_than(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLess) => {
- builder.s_less_than(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Float(ast::SetpCompareFloat::Less) => {
- builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLessOrEq) => {
- builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLessOrEq) => {
- builder.s_less_than_equal(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Float(ast::SetpCompareFloat::LessOrEq) => {
- builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreater) => {
- builder.u_greater_than(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreater) => {
- builder.s_greater_than(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Float(ast::SetpCompareFloat::Greater) => {
- builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreaterOrEq) => {
- builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreaterOrEq) => {
- builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Float(ast::SetpCompareFloat::GreaterOrEq) => {
- builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanEq) => {
- builder.f_unord_equal(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanNotEq) => {
- builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLess) => {
- builder.f_unord_less_than(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLessOrEq) => {
- builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreater) => {
- builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreaterOrEq) => {
- builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2)
- }
- ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsAnyNan) => {
- let temp1 = builder.is_nan(result_type, None, operand_1)?;
- let temp2 = builder.is_nan(result_type, None, operand_2)?;
- builder.logical_or(result_type, result_id, temp1, temp2)
- }
- ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsNotNan) => {
- let temp1 = builder.is_nan(result_type, None, operand_1)?;
- let temp2 = builder.is_nan(result_type, None, operand_2)?;
- let any_nan = builder.logical_or(result_type, None, temp1, temp2)?;
- logical_not(builder, result_type, result_id, any_nan)
- }
- _ => todo!(),
- }?;
- Ok(())
-}
-
-// HACK ALERT
-// Temporary workaround until IGC gets its shit together
-// Currently IGC carries two copies of SPIRV-LLVM translator
-// a new one in /llvm-spirv/ and old one in /IGC/AdaptorOCL/SPIRV/.
-// Obviously, old and buggy one is used for compiling L0 SPIRV
-// https://github.com/intel/intel-graphics-compiler/issues/148
-fn logical_not(
- builder: &mut dr::Builder,
- result_type: spirv::Word,
- result_id: Option<spirv::Word>,
- operand: spirv::Word,
-) -> Result<spirv::Word, dr::Error> {
- let const_true = builder.constant_true(result_type, None);
- let const_false = builder.constant_false(result_type, None);
- builder.select(result_type, result_id, operand, const_false, const_true)
-}
-
-// HACK ALERT
-// For some reason IGC fails linking if the value and shift size are of different type
-fn insert_shift_hack(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- offset_var: spirv::Word,
- size_of: usize,
-) -> Result<spirv::Word, TranslateError> {
- let result_type = match size_of {
- 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16),
- 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64),
- 4 => return Ok(offset_var),
- _ => return Err(error_unreachable()),
- };
- Ok(builder.u_convert(result_type.0, None, offset_var)?)
-}
-
-fn emit_cvt(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- dets: &ast::CvtDetails,
- arg: &ast::CvtArgs<SpirvWord>,
-) -> Result<(), TranslateError> {
- match dets.mode {
- ptx_parser::CvtMode::SignExtend => {
- let cv = ImplicitConversion {
- src: arg.src,
- dst: arg.dst,
- from_type: dets.from.into(),
- from_space: ast::StateSpace::Reg,
- to_type: dets.to.into(),
- to_space: ast::StateSpace::Reg,
- kind: ConversionKind::SignExtend,
- };
- emit_implicit_conversion(builder, map, &cv)?;
- }
- ptx_parser::CvtMode::ZeroExtend
- | ptx_parser::CvtMode::Truncate
- | ptx_parser::CvtMode::Bitcast => {
- let cv = ImplicitConversion {
- src: arg.src,
- dst: arg.dst,
- from_type: dets.from.into(),
- from_space: ast::StateSpace::Reg,
- to_type: dets.to.into(),
- to_space: ast::StateSpace::Reg,
- kind: ConversionKind::Default,
- };
- emit_implicit_conversion(builder, map, &cv)?;
- }
- ptx_parser::CvtMode::SaturateUnsignedToSigned => {
- let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
- builder.sat_convert_u_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?;
- }
- ptx_parser::CvtMode::SaturateSignedToUnsigned => {
- let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
- builder.sat_convert_s_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?;
- }
- ptx_parser::CvtMode::FPExtend { flush_to_zero } => {
- let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
- builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?;
- }
- ptx_parser::CvtMode::FPTruncate {
- rounding,
- flush_to_zero,
- } => {
- let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
- builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?;
- emit_rounding_decoration(builder, arg.dst, Some(rounding));
- }
- ptx_parser::CvtMode::FPRound {
- integer_rounding,
- flush_to_zero,
- } => {
- if flush_to_zero == Some(true) {
- todo!()
- }
- let result_type = map.get_or_add(builder, SpirvType::from(dets.to));
- match integer_rounding {
- Some(ast::RoundingMode::NearestEven) => {
- builder.ext_inst(
- result_type.0,
- Some(arg.dst.0),
- opencl,
- spirv::CLOp::rint as u32,
- [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
- )?;
- }
- Some(ast::RoundingMode::Zero) => {
- builder.ext_inst(
- result_type.0,
- Some(arg.dst.0),
- opencl,
- spirv::CLOp::trunc as u32,
- [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
- )?;
- }
- Some(ast::RoundingMode::NegativeInf) => {
- builder.ext_inst(
- result_type.0,
- Some(arg.dst.0),
- opencl,
- spirv::CLOp::floor as u32,
- [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
- )?;
- }
- Some(ast::RoundingMode::PositiveInf) => {
- builder.ext_inst(
- result_type.0,
- Some(arg.dst.0),
- opencl,
- spirv::CLOp::ceil as u32,
- [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
- )?;
- }
- None => {
- builder.copy_object(result_type.0, Some(arg.dst.0), arg.src.0)?;
- }
- }
- }
- ptx_parser::CvtMode::SignedFromFP {
- rounding,
- flush_to_zero,
- } => {
- let dest_t: ast::ScalarType = dets.to.into();
- let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- builder.convert_f_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?;
- emit_rounding_decoration(builder, arg.dst, Some(rounding));
- }
- ptx_parser::CvtMode::UnsignedFromFP {
- rounding,
- flush_to_zero,
- } => {
- let dest_t: ast::ScalarType = dets.to.into();
- let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- builder.convert_f_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?;
- emit_rounding_decoration(builder, arg.dst, Some(rounding));
- }
- ptx_parser::CvtMode::FPFromSigned(rounding) => {
- let dest_t: ast::ScalarType = dets.to.into();
- let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- builder.convert_s_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?;
- emit_rounding_decoration(builder, arg.dst, Some(rounding));
- }
- ptx_parser::CvtMode::FPFromUnsigned(rounding) => {
- let dest_t: ast::ScalarType = dets.to.into();
- let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- builder.convert_u_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?;
- emit_rounding_decoration(builder, arg.dst, Some(rounding));
- }
- }
- Ok(())
-}
-
-fn emit_mad_uint(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- type_: ast::ScalarType,
- control: ast::MulIntControl,
- arg: &ast::MadArgs<SpirvWord>,
-) -> Result<(), dr::Error> {
- let inst_type = map
- .get_or_add(builder, SpirvType::from(ast::ScalarType::from(type_)))
- .0;
- match control {
- ast::MulIntControl::Low => {
- let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?;
- builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?;
- }
- ast::MulIntControl::High => {
- builder.ext_inst(
- inst_type,
- Some(arg.dst.0),
- opencl,
- spirv::CLOp::u_mad_hi as spirv::Word,
- [
- dr::Operand::IdRef(arg.src1.0),
- dr::Operand::IdRef(arg.src2.0),
- dr::Operand::IdRef(arg.src3.0),
- ]
- .iter()
- .cloned(),
- )?;
- }
- ast::MulIntControl::Wide => todo!(),
- };
- Ok(())
-}
-
-fn emit_mad_sint(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- type_: ast::ScalarType,
- control: ast::MulIntControl,
- arg: &ast::MadArgs<SpirvWord>,
-) -> Result<(), dr::Error> {
- let inst_type = map.get_or_add(builder, SpirvType::from(type_)).0;
- match control {
- ast::MulIntControl::Low => {
- let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?;
- builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?;
- }
- ast::MulIntControl::High => {
- builder.ext_inst(
- inst_type,
- Some(arg.dst.0),
- opencl,
- spirv::CLOp::s_mad_hi as spirv::Word,
- [
- dr::Operand::IdRef(arg.src1.0),
- dr::Operand::IdRef(arg.src2.0),
- dr::Operand::IdRef(arg.src3.0),
- ]
- .iter()
- .cloned(),
- )?;
- }
- ast::MulIntControl::Wide => todo!(),
- };
- Ok(())
-}
-
-fn emit_mad_float(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- desc: &ast::ArithFloat,
- arg: &ast::MadArgs<SpirvWord>,
-) -> Result<(), dr::Error> {
- let inst_type = map
- .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
- .0;
- builder.ext_inst(
- inst_type,
- Some(arg.dst.0),
- opencl,
- spirv::CLOp::mad as spirv::Word,
- [
- dr::Operand::IdRef(arg.src1.0),
- dr::Operand::IdRef(arg.src2.0),
- dr::Operand::IdRef(arg.src3.0),
- ]
- .iter()
- .cloned(),
- )?;
- Ok(())
-}
-
-fn emit_fma_float(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- desc: &ast::ArithFloat,
- arg: &ast::FmaArgs<SpirvWord>,
-) -> Result<(), dr::Error> {
- let inst_type = map
- .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
- .0;
- builder.ext_inst(
- inst_type,
- Some(arg.dst.0),
- opencl,
- spirv::CLOp::fma as spirv::Word,
- [
- dr::Operand::IdRef(arg.src1.0),
- dr::Operand::IdRef(arg.src2.0),
- dr::Operand::IdRef(arg.src3.0),
- ]
- .iter()
- .cloned(),
- )?;
- Ok(())
-}
-
-fn emit_sub_int(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- typ: ast::ScalarType,
- saturate: bool,
- arg: &ast::SubArgs<SpirvWord>,
-) -> Result<(), dr::Error> {
- if saturate {
- todo!()
- }
- let inst_type = map
- .get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)))
- .0;
- builder.i_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
- Ok(())
-}
-
-fn emit_sub_float(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- desc: &ast::ArithFloat,
- arg: &ast::SubArgs<SpirvWord>,
-) -> Result<(), dr::Error> {
- let inst_type = map
- .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_)))
- .0;
- builder.f_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?;
- emit_rounding_decoration(builder, arg.dst, desc.rounding);
- Ok(())
-}
-
-fn emit_min(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- desc: &ast::MinMaxDetails,
- arg: &ast::MinArgs<SpirvWord>,
-) -> Result<(), dr::Error> {
- let cl_op = match desc {
- ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min,
- ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
- ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
- };
- let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_()));
- builder.ext_inst(
- inst_type.0,
- Some(arg.dst.0),
- opencl,
- cl_op as spirv::Word,
- [
- dr::Operand::IdRef(arg.src1.0),
- dr::Operand::IdRef(arg.src2.0),
- ]
- .iter()
- .cloned(),
- )?;
- Ok(())
-}
-
-fn emit_max(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- desc: &ast::MinMaxDetails,
- arg: &ast::MaxArgs<SpirvWord>,
-) -> Result<(), dr::Error> {
- let cl_op = match desc {
- ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max,
- ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
- ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
- };
- let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_()));
- builder.ext_inst(
- inst_type.0,
- Some(arg.dst.0),
- opencl,
- cl_op as spirv::Word,
- [
- dr::Operand::IdRef(arg.src1.0),
- dr::Operand::IdRef(arg.src2.0),
- ]
- .iter()
- .cloned(),
- )?;
- Ok(())
-}
-
-fn emit_rcp(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- desc: &ast::RcpData,
- arg: &ast::RcpArgs<SpirvWord>,
-) -> Result<(), TranslateError> {
- let is_f64 = desc.type_ == ast::ScalarType::F64;
- let (instr_type, constant) = if is_f64 {
- (ast::ScalarType::F64, vec_repr(1.0f64))
- } else {
- (ast::ScalarType::F32, vec_repr(1.0f32))
- };
- let result_type = map.get_or_add_scalar(builder, instr_type);
- let rounding = match desc.kind {
- ptx_parser::RcpKind::Approx => {
- builder.ext_inst(
- result_type.0,
- Some(arg.dst.0),
- opencl,
- spirv::CLOp::native_recip as u32,
- [dr::Operand::IdRef(arg.src.0)].iter().cloned(),
- )?;
- return Ok(());
- }
- ptx_parser::RcpKind::Compliant(rounding) => rounding,
- };
- let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?;
- builder.f_div(result_type.0, Some(arg.dst.0), one.0, arg.src.0)?;
- emit_rounding_decoration(builder, arg.dst, Some(rounding));
- builder.decorate(
- arg.dst.0,
- spirv::Decoration::FPFastMathMode,
- [dr::Operand::FPFastMathMode(
- spirv::FPFastMathMode::ALLOW_RECIP,
- )]
- .iter()
- .cloned(),
- );
- Ok(())
-}
-
-fn emit_atom(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- details: &ast::AtomDetails,
- arg: &ast::AtomArgs<SpirvWord>,
-) -> Result<(), TranslateError> {
- let spirv_op = match details.op {
- ptx_parser::AtomicOp::And => dr::Builder::atomic_and,
- ptx_parser::AtomicOp::Or => dr::Builder::atomic_or,
- ptx_parser::AtomicOp::Xor => dr::Builder::atomic_xor,
- ptx_parser::AtomicOp::Exchange => dr::Builder::atomic_exchange,
- ptx_parser::AtomicOp::Add => dr::Builder::atomic_i_add,
- ptx_parser::AtomicOp::IncrementWrap | ptx_parser::AtomicOp::DecrementWrap => {
- return Err(error_unreachable())
- }
- ptx_parser::AtomicOp::SignedMin => dr::Builder::atomic_s_min,
- ptx_parser::AtomicOp::UnsignedMin => dr::Builder::atomic_u_min,
- ptx_parser::AtomicOp::SignedMax => dr::Builder::atomic_s_max,
- ptx_parser::AtomicOp::UnsignedMax => dr::Builder::atomic_u_max,
- ptx_parser::AtomicOp::FloatAdd => dr::Builder::atomic_f_add_ext,
- ptx_parser::AtomicOp::FloatMin => todo!(),
- ptx_parser::AtomicOp::FloatMax => todo!(),
- };
- let result_type = map.get_or_add(builder, SpirvType::new(details.type_.clone()));
- let memory_const = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(scope_to_spirv(details.scope) as u32),
- )?;
- let semantics_const = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(semantics_to_spirv(details.semantics).bits()),
- )?;
- spirv_op(
- builder,
- result_type.0,
- Some(arg.dst.0),
- arg.src1.0,
- memory_const.0,
- semantics_const.0,
- arg.src2.0,
- )?;
- Ok(())
-}
-
-fn scope_to_spirv(this: ast::MemScope) -> spirv::Scope {
- match this {
- ast::MemScope::Cta => spirv::Scope::Workgroup,
- ast::MemScope::Gpu => spirv::Scope::Device,
- ast::MemScope::Sys => spirv::Scope::CrossDevice,
- ptx_parser::MemScope::Cluster => todo!(),
- }
-}
-
-fn semantics_to_spirv(this: ast::AtomSemantics) -> spirv::MemorySemantics {
- match this {
- ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED,
- ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE,
- ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE,
- ast::AtomSemantics::AcqRel => spirv::MemorySemantics::ACQUIRE_RELEASE,
- }
-}
-
-fn emit_float_div_decoration(builder: &mut dr::Builder, dst: SpirvWord, kind: ast::DivFloatKind) {
- match kind {
- ast::DivFloatKind::Approx => {
- builder.decorate(
- dst.0,
- spirv::Decoration::FPFastMathMode,
- [dr::Operand::FPFastMathMode(
- spirv::FPFastMathMode::ALLOW_RECIP,
- )]
- .iter()
- .cloned(),
- );
- }
- ast::DivFloatKind::Rounding(rnd) => {
- emit_rounding_decoration(builder, dst, Some(rnd));
- }
- ast::DivFloatKind::ApproxFull => {}
- }
-}
-
-fn emit_sqrt(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- details: &ast::RcpData,
- a: &ast::SqrtArgs<SpirvWord>,
-) -> Result<(), TranslateError> {
- let result_type = map.get_or_add_scalar(builder, details.type_.into());
- let (ocl_op, rounding) = match details.kind {
- ast::RcpKind::Approx => (spirv::CLOp::sqrt, None),
- ast::RcpKind::Compliant(rnd) => (spirv::CLOp::sqrt, Some(rnd)),
- };
- builder.ext_inst(
- result_type.0,
- Some(a.dst.0),
- opencl,
- ocl_op as spirv::Word,
- [dr::Operand::IdRef(a.src.0)].iter().cloned(),
- )?;
- emit_rounding_decoration(builder, a.dst, rounding);
- Ok(())
-}
-
-// TODO: check what kind of assembly do we emit
-fn emit_logical_xor_spirv(
- builder: &mut dr::Builder,
- result_type: spirv::Word,
- result_id: Option<spirv::Word>,
- op1: spirv::Word,
- op2: spirv::Word,
-) -> Result<spirv::Word, dr::Error> {
- let temp_or = builder.logical_or(result_type, None, op1, op2)?;
- let temp_and = builder.logical_and(result_type, None, op1, op2)?;
- let temp_neg = logical_not(builder, result_type, None, temp_and)?;
- builder.logical_and(result_type, result_id, temp_or, temp_neg)
-}
-
-fn emit_load_var(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- details: &LoadVarDetails,
-) -> Result<(), TranslateError> {
- let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone()));
- match details.member_index {
- Some((index, Some(width))) => {
- let vector_type = match details.typ {
- ast::Type::Scalar(scalar_t) => ast::Type::Vector(width, scalar_t),
- _ => return Err(error_mismatched_type()),
- };
- let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
- let vector_temp = builder.load(
- vector_type_spirv.0,
- None,
- details.arg.src.0,
- None,
- iter::empty(),
- )?;
- builder.composite_extract(
- result_type.0,
- Some(details.arg.dst.0),
- vector_temp,
- [index as u32].iter().copied(),
- )?;
- }
- Some((index, None)) => {
- let result_ptr_type = map.get_or_add(
- builder,
- SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function),
- );
- let index_spirv = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(index as u32),
- )?;
- let src = builder.in_bounds_access_chain(
- result_ptr_type.0,
- None,
- details.arg.src.0,
- [index_spirv.0].iter().copied(),
- )?;
- builder.load(
- result_type.0,
- Some(details.arg.dst.0),
- src,
- None,
- iter::empty(),
- )?;
- }
- None => {
- builder.load(
- result_type.0,
- Some(details.arg.dst.0),
- details.arg.src.0,
- None,
- iter::empty(),
- )?;
- }
- };
- Ok(())
-}
-
-fn to_parts(this: &ast::Type) -> TypeParts {
- match this {
- ast::Type::Scalar(scalar) => TypeParts {
- kind: TypeKind::Scalar,
- state_space: ast::StateSpace::Reg,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: Vec::new(),
- },
- ast::Type::Vector(components, scalar) => TypeParts {
- kind: TypeKind::Vector,
- state_space: ast::StateSpace::Reg,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: vec![*components as u32],
- },
- ast::Type::Array(_, scalar, components) => TypeParts {
- kind: TypeKind::Array,
- state_space: ast::StateSpace::Reg,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: components.clone(),
- },
- ast::Type::Pointer(scalar, space) => TypeParts {
- kind: TypeKind::Pointer,
- state_space: *space,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: Vec::new(),
- },
- }
-}
-
-fn type_from_parts(t: TypeParts) -> ast::Type {
- match t.kind {
- TypeKind::Scalar => ast::Type::Scalar(scalar_from_parts(t.width, t.scalar_kind)),
- TypeKind::Vector => ast::Type::Vector(
- t.components[0] as u8,
- scalar_from_parts(t.width, t.scalar_kind),
- ),
- TypeKind::Array => ast::Type::Array(
- None,
- scalar_from_parts(t.width, t.scalar_kind),
- t.components,
- ),
- TypeKind::Pointer => {
- ast::Type::Pointer(scalar_from_parts(t.width, t.scalar_kind), t.state_space)
- }
- }
-}
-
-#[derive(Eq, PartialEq, Clone)]
-struct TypeParts {
- kind: TypeKind,
- scalar_kind: ast::ScalarKind,
- width: u8,
- state_space: ast::StateSpace,
- components: Vec<u32>,
-}
-
-#[derive(Eq, PartialEq, Copy, Clone)]
-enum TypeKind {
- Scalar,
- Vector,
- Array,
- Pointer,
-}
diff --git a/ptx/src/pass/expand_arguments.rs b/ptx/src/pass/expand_arguments.rs deleted file mode 100644 index e496c75..0000000 --- a/ptx/src/pass/expand_arguments.rs +++ /dev/null @@ -1,181 +0,0 @@ -use super::*;
-use ptx_parser as ast;
-
-pub(super) fn run<'a, 'b>(
- func: Vec<TypedStatement>,
- id_def: &'b mut MutableNumericIdResolver<'a>,
-) -> Result<Vec<ExpandedStatement>, TranslateError> {
- let mut result = Vec::with_capacity(func.len());
- for s in func {
- match s {
- Statement::Label(id) => result.push(Statement::Label(id)),
- Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
- Statement::LoadVar(details) => result.push(Statement::LoadVar(details)),
- Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
- Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
- Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
- Statement::Constant(c) => result.push(Statement::Constant(c)),
- Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)),
- s => {
- let (new_statement, post_stmts) = {
- let mut visitor = FlattenArguments::new(&mut result, id_def);
- (s.visit_map(&mut visitor)?, visitor.post_stmts)
- };
- result.push(new_statement);
- result.extend(post_stmts);
- }
- }
- }
- Ok(result)
-}
-
-struct FlattenArguments<'a, 'b> {
- func: &'b mut Vec<ExpandedStatement>,
- id_def: &'b mut MutableNumericIdResolver<'a>,
- post_stmts: Vec<ExpandedStatement>,
-}
-
-impl<'a, 'b> FlattenArguments<'a, 'b> {
- fn new(
- func: &'b mut Vec<ExpandedStatement>,
- id_def: &'b mut MutableNumericIdResolver<'a>,
- ) -> Self {
- FlattenArguments {
- func,
- id_def,
- post_stmts: Vec::new(),
- }
- }
-
- fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
- Ok(name)
- }
-
- fn reg_offset(
- &mut self,
- reg: SpirvWord,
- offset: i32,
- type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
- _is_dst: bool,
- ) -> Result<SpirvWord, TranslateError> {
- let (type_, state_space) = if let Some((type_, state_space)) = type_space {
- (type_, state_space)
- } else {
- return Err(TranslateError::UntypedSymbol);
- };
- if state_space == ast::StateSpace::Reg {
- let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
- if reg_space != ast::StateSpace::Reg {
- return Err(error_mismatched_type());
- }
- let reg_scalar_type = match reg_type {
- ast::Type::Scalar(underlying_type) => underlying_type,
- _ => return Err(error_mismatched_type()),
- };
- let id_constant_stmt = self
- .id_def
- .register_intermediate(reg_type.clone(), ast::StateSpace::Reg);
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: reg_scalar_type,
- value: ast::ImmediateValue::S64(offset as i64),
- }));
- let arith_details = match reg_scalar_type.kind() {
- ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
- type_: reg_scalar_type,
- saturate: false,
- }),
- ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
- ast::ArithDetails::Integer(ast::ArithInteger {
- type_: reg_scalar_type,
- saturate: false,
- })
- }
- _ => return Err(error_unreachable()),
- };
- let id_add_result = self.id_def.register_intermediate(reg_type, state_space);
- self.func
- .push(Statement::Instruction(ast::Instruction::Add {
- data: arith_details,
- arguments: ast::AddArgs {
- dst: id_add_result,
- src1: reg,
- src2: id_constant_stmt,
- },
- }));
- Ok(id_add_result)
- } else {
- let id_constant_stmt = self.id_def.register_intermediate(
- ast::Type::Scalar(ast::ScalarType::S64),
- ast::StateSpace::Reg,
- );
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: ast::ScalarType::S64,
- value: ast::ImmediateValue::S64(offset as i64),
- }));
- let dst = self
- .id_def
- .register_intermediate(type_.clone(), state_space);
- self.func.push(Statement::PtrAccess(PtrAccess {
- underlying_type: type_.clone(),
- state_space: state_space,
- dst,
- ptr_src: reg,
- offset_src: id_constant_stmt,
- }));
- Ok(dst)
- }
- }
-
- fn immediate(
- &mut self,
- value: ast::ImmediateValue,
- type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
- ) -> Result<SpirvWord, TranslateError> {
- let (scalar_t, state_space) =
- if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
- (*scalar, state_space)
- } else {
- return Err(TranslateError::UntypedSymbol);
- };
- let id = self
- .id_def
- .register_intermediate(ast::Type::Scalar(scalar_t), state_space);
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id,
- typ: scalar_t,
- value,
- }));
- Ok(id)
- }
-}
-
-impl<'a, 'b> ast::VisitorMap<TypedOperand, SpirvWord, TranslateError> for FlattenArguments<'a, 'b> {
- fn visit(
- &mut self,
- args: TypedOperand,
- type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
- is_dst: bool,
- _relaxed_type_check: bool,
- ) -> Result<SpirvWord, TranslateError> {
- match args {
- TypedOperand::Reg(r) => self.reg(r),
- TypedOperand::Imm(x) => self.immediate(x, type_space),
- TypedOperand::RegOffset(reg, offset) => {
- self.reg_offset(reg, offset, type_space, is_dst)
- }
- TypedOperand::VecMember(..) => Err(error_unreachable()),
- }
- }
-
- fn visit_ident(
- &mut self,
- name: <TypedOperand as ptx_parser::Operand>::Ident,
- _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
- _is_dst: bool,
- _relaxed_type_check: bool,
- ) -> Result<<SpirvWord as ptx_parser::Operand>::Ident, TranslateError> {
- self.reg(name)
- }
-}
diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs index 3dabf40..f2de786 100644 --- a/ptx/src/pass/expand_operands.rs +++ b/ptx/src/pass/expand_operands.rs @@ -189,15 +189,12 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { fn vec_member(
&mut self,
- vector_src: SpirvWord,
+ vector_ident: SpirvWord,
member: u8,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
- if is_dst {
- return Err(error_mismatched_type());
- }
- let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? {
+ let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_ident)? {
(ast::Type::Vector(vector_width, scalar_t), space) => {
(*vector_width, *scalar_t, *space)
}
@@ -206,35 +203,46 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { let temporary = self
.resolver
.register_unnamed(Some((scalar_type.into(), space)));
- self.result.push(Statement::VectorAccess(VectorAccess {
- scalar_type,
- vector_width,
- dst: temporary,
- src: vector_src,
- member: member,
- }));
+ if is_dst {
+ self.post_stmts.push(Statement::VectorWrite(VectorWrite {
+ scalar_type,
+ vector_width,
+ vector_dst: vector_ident,
+ vector_src: vector_ident,
+ scalar_src: temporary,
+ member,
+ }));
+ } else {
+ self.result.push(Statement::VectorRead(VectorRead {
+ scalar_type,
+ vector_width,
+ scalar_dst: temporary,
+ vector_src: vector_ident,
+ member,
+ }));
+ }
Ok(temporary)
}
fn vec_pack(
&mut self,
- vecs: Vec<SpirvWord>,
+ vector_elements: Vec<SpirvWord>,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
- let (scalar_t, state_space) = match type_space {
- Some((ast::Type::Vector(_, scalar_t), space)) => (*scalar_t, space),
+ let (width, scalar_t, state_space) = match type_space {
+ Some((ast::Type::Vector(width, scalar_t), space)) => (*width, *scalar_t, space),
_ => return Err(error_mismatched_type()),
};
- let temp_vec = self
+ let temporary_vector = self
.resolver
- .register_unnamed(Some((scalar_t.into(), state_space)));
+ .register_unnamed(Some((ast::Type::Vector(width, scalar_t), state_space)));
let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst,
typ: scalar_t,
- packed: temp_vec,
- unpacked: vecs,
+ packed: temporary_vector,
+ unpacked: vector_elements,
relaxed_type_check,
});
if is_dst {
@@ -242,7 +250,7 @@ impl<'a, 'input> FlattenArguments<'a, 'input> { } else {
self.result.push(statement);
}
- Ok(temp_vec)
+ Ok(temporary_vector)
}
}
@@ -273,7 +281,7 @@ impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, Translate fn visit_ident(
&mut self,
- name: <TypedOperand as ast::Operand>::Ident,
+ name: SpirvWord,
_type_space: Option<(&ast::Type, ast::StateSpace)>,
_is_dst: bool,
_relaxed_type_check: bool,
diff --git a/ptx/src/pass/extract_globals.rs b/ptx/src/pass/extract_globals.rs deleted file mode 100644 index 2912366..0000000 --- a/ptx/src/pass/extract_globals.rs +++ /dev/null @@ -1,281 +0,0 @@ -use super::*;
-
-pub(super) fn run<'input, 'b>(
- sorted_statements: Vec<ExpandedStatement>,
- ptx_impl_imports: &mut HashMap<String, Directive>,
- id_def: &mut NumericIdResolver,
-) -> Result<(Vec<ExpandedStatement>, Vec<ast::Variable<SpirvWord>>), TranslateError> {
- let mut local = Vec::with_capacity(sorted_statements.len());
- let mut global = Vec::new();
- for statement in sorted_statements {
- match statement {
- Statement::Variable(
- var @ ast::Variable {
- state_space: ast::StateSpace::Shared,
- ..
- },
- )
- | Statement::Variable(
- var @ ast::Variable {
- state_space: ast::StateSpace::Global,
- ..
- },
- ) => global.push(var),
- Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => {
- let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", scalar_to_ptx_name(data)].concat();
- local.push(instruction_to_fn_call(
- id_def,
- ptx_impl_imports,
- ast::Instruction::Bfe { data, arguments },
- fn_name,
- )?);
- }
- Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => {
- let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", scalar_to_ptx_name(data)].concat();
- local.push(instruction_to_fn_call(
- id_def,
- ptx_impl_imports,
- ast::Instruction::Bfi { data, arguments },
- fn_name,
- )?);
- }
- Statement::Instruction(ast::Instruction::Brev { data, arguments }) => {
- let fn_name: String =
- [ZLUDA_PTX_PREFIX, "brev_", scalar_to_ptx_name(data)].concat();
- local.push(instruction_to_fn_call(
- id_def,
- ptx_impl_imports,
- ast::Instruction::Brev { data, arguments },
- fn_name,
- )?);
- }
- Statement::Instruction(ast::Instruction::Activemask { arguments }) => {
- let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat();
- local.push(instruction_to_fn_call(
- id_def,
- ptx_impl_imports,
- ast::Instruction::Activemask { arguments },
- fn_name,
- )?);
- }
- Statement::Instruction(ast::Instruction::Atom {
- data:
- data @ ast::AtomDetails {
- op: ast::AtomicOp::IncrementWrap,
- semantics,
- scope,
- space,
- ..
- },
- arguments,
- }) => {
- let fn_name = [
- ZLUDA_PTX_PREFIX,
- "atom_",
- semantics_to_ptx_name(semantics),
- "_",
- scope_to_ptx_name(scope),
- "_",
- space_to_ptx_name(space),
- "_inc",
- ]
- .concat();
- local.push(instruction_to_fn_call(
- id_def,
- ptx_impl_imports,
- ast::Instruction::Atom { data, arguments },
- fn_name,
- )?);
- }
- Statement::Instruction(ast::Instruction::Atom {
- data:
- data @ ast::AtomDetails {
- op: ast::AtomicOp::DecrementWrap,
- semantics,
- scope,
- space,
- ..
- },
- arguments,
- }) => {
- let fn_name = [
- ZLUDA_PTX_PREFIX,
- "atom_",
- semantics_to_ptx_name(semantics),
- "_",
- scope_to_ptx_name(scope),
- "_",
- space_to_ptx_name(space),
- "_dec",
- ]
- .concat();
- local.push(instruction_to_fn_call(
- id_def,
- ptx_impl_imports,
- ast::Instruction::Atom { data, arguments },
- fn_name,
- )?);
- }
- Statement::Instruction(ast::Instruction::Atom {
- data:
- data @ ast::AtomDetails {
- op: ast::AtomicOp::FloatAdd,
- semantics,
- scope,
- space,
- ..
- },
- arguments,
- }) => {
- let scalar_type = match data.type_ {
- ptx_parser::Type::Scalar(scalar) => scalar,
- _ => return Err(error_unreachable()),
- };
- let fn_name = [
- ZLUDA_PTX_PREFIX,
- "atom_",
- semantics_to_ptx_name(semantics),
- "_",
- scope_to_ptx_name(scope),
- "_",
- space_to_ptx_name(space),
- "_add_",
- scalar_to_ptx_name(scalar_type),
- ]
- .concat();
- local.push(instruction_to_fn_call(
- id_def,
- ptx_impl_imports,
- ast::Instruction::Atom { data, arguments },
- fn_name,
- )?);
- }
- s => local.push(s),
- }
- }
- Ok((local, global))
-}
-
-fn instruction_to_fn_call(
- id_defs: &mut NumericIdResolver,
- ptx_impl_imports: &mut HashMap<String, Directive>,
- inst: ast::Instruction<SpirvWord>,
- fn_name: String,
-) -> Result<ExpandedStatement, TranslateError> {
- let mut arguments = Vec::new();
- ast::visit_map(inst, &mut |operand,
- type_space: Option<(
- &ast::Type,
- ast::StateSpace,
- )>,
- is_dst,
- _| {
- let (typ, space) = match type_space {
- Some((typ, space)) => (typ.clone(), space),
- None => return Err(error_unreachable()),
- };
- arguments.push((operand, is_dst, typ, space));
- Ok(SpirvWord(0))
- })?;
- let return_arguments_count = arguments
- .iter()
- .position(|(desc, is_dst, _, _)| !is_dst)
- .unwrap_or(arguments.len());
- let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
- let fn_id = register_external_fn_call(
- id_defs,
- ptx_impl_imports,
- fn_name,
- return_arguments
- .iter()
- .map(|(_, _, typ, state)| (typ, *state)),
- input_arguments
- .iter()
- .map(|(_, _, typ, state)| (typ, *state)),
- )?;
- Ok(Statement::Instruction(ast::Instruction::Call {
- data: ast::CallDetails {
- uniform: false,
- return_arguments: return_arguments
- .iter()
- .map(|(_, _, typ, state)| (typ.clone(), *state))
- .collect::<Vec<_>>(),
- input_arguments: input_arguments
- .iter()
- .map(|(_, _, typ, state)| (typ.clone(), *state))
- .collect::<Vec<_>>(),
- },
- arguments: ast::CallArgs {
- return_arguments: return_arguments
- .iter()
- .map(|(name, _, _, _)| *name)
- .collect::<Vec<_>>(),
- func: fn_id,
- input_arguments: input_arguments
- .iter()
- .map(|(name, _, _, _)| *name)
- .collect::<Vec<_>>(),
- },
- }))
-}
-
-fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
- match this {
- ast::ScalarType::B8 => "b8",
- ast::ScalarType::B16 => "b16",
- ast::ScalarType::B32 => "b32",
- ast::ScalarType::B64 => "b64",
- ast::ScalarType::B128 => "b128",
- ast::ScalarType::U8 => "u8",
- ast::ScalarType::U16 => "u16",
- ast::ScalarType::U16x2 => "u16x2",
- ast::ScalarType::U32 => "u32",
- ast::ScalarType::U64 => "u64",
- ast::ScalarType::S8 => "s8",
- ast::ScalarType::S16 => "s16",
- ast::ScalarType::S16x2 => "s16x2",
- ast::ScalarType::S32 => "s32",
- ast::ScalarType::S64 => "s64",
- ast::ScalarType::F16 => "f16",
- ast::ScalarType::F16x2 => "f16x2",
- ast::ScalarType::F32 => "f32",
- ast::ScalarType::F64 => "f64",
- ast::ScalarType::BF16 => "bf16",
- ast::ScalarType::BF16x2 => "bf16x2",
- ast::ScalarType::Pred => "pred",
- }
-}
-
-fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str {
- match this {
- ast::AtomSemantics::Relaxed => "relaxed",
- ast::AtomSemantics::Acquire => "acquire",
- ast::AtomSemantics::Release => "release",
- ast::AtomSemantics::AcqRel => "acq_rel",
- }
-}
-
-fn scope_to_ptx_name(this: ast::MemScope) -> &'static str {
- match this {
- ast::MemScope::Cta => "cta",
- ast::MemScope::Gpu => "gpu",
- ast::MemScope::Sys => "sys",
- ast::MemScope::Cluster => "cluster",
- }
-}
-
-fn space_to_ptx_name(this: ast::StateSpace) -> &'static str {
- match this {
- ast::StateSpace::Generic => "generic",
- ast::StateSpace::Global => "global",
- ast::StateSpace::Shared => "shared",
- ast::StateSpace::Reg => "reg",
- ast::StateSpace::Const => "const",
- ast::StateSpace::Local => "local",
- ast::StateSpace::Param => "param",
- ast::StateSpace::SharedCluster => "shared_cluster",
- ast::StateSpace::ParamEntry => "param_entry",
- ast::StateSpace::SharedCta => "shared_cta",
- ast::StateSpace::ParamFunc => "param_func",
- }
-}
diff --git a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs deleted file mode 100644 index c029016..0000000 --- a/ptx/src/pass/fix_special_registers.rs +++ /dev/null @@ -1,130 +0,0 @@ -use super::*;
-use std::collections::HashMap;
-
-pub(super) fn run<'a, 'b, 'input>(
- ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
- typed_statements: Vec<TypedStatement>,
- numeric_id_defs: &'a mut NumericIdResolver<'b>,
-) -> Result<Vec<TypedStatement>, TranslateError> {
- let result = Vec::with_capacity(typed_statements.len());
- let mut sreg_sresolver = SpecialRegisterResolver {
- ptx_impl_imports,
- numeric_id_defs,
- result,
- };
- for statement in typed_statements {
- let statement = statement.visit_map(&mut sreg_sresolver)?;
- sreg_sresolver.result.push(statement);
- }
- Ok(sreg_sresolver.result)
-}
-
-struct SpecialRegisterResolver<'a, 'b, 'input> {
- ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
- numeric_id_defs: &'a mut NumericIdResolver<'b>,
- result: Vec<TypedStatement>,
-}
-
-impl<'a, 'b, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
- for SpecialRegisterResolver<'a, 'b, 'input>
-{
- fn visit(
- &mut self,
- operand: TypedOperand,
- _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
- is_dst: bool,
- _relaxed_type_check: bool,
- ) -> Result<TypedOperand, TranslateError> {
- operand.map(|name, vector_index| self.replace_sreg(name, is_dst, vector_index))
- }
-
- fn visit_ident(
- &mut self,
- args: SpirvWord,
- _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
- is_dst: bool,
- _relaxed_type_check: bool,
- ) -> Result<SpirvWord, TranslateError> {
- self.replace_sreg(args, is_dst, None)
- }
-}
-
-impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> {
- fn replace_sreg(
- &mut self,
- name: SpirvWord,
- is_dst: bool,
- vector_index: Option<u8>,
- ) -> Result<SpirvWord, TranslateError> {
- if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) {
- if is_dst {
- return Err(error_mismatched_type());
- }
- let input_arguments = match (vector_index, sreg.get_function_input_type()) {
- (Some(idx), Some(inp_type)) => {
- if inp_type != ast::ScalarType::U8 {
- return Err(TranslateError::Unreachable);
- }
- let constant = self.numeric_id_defs.register_intermediate(Some((
- ast::Type::Scalar(inp_type),
- ast::StateSpace::Reg,
- )));
- self.result.push(Statement::Constant(ConstantDefinition {
- dst: constant,
- typ: inp_type,
- value: ast::ImmediateValue::U64(idx as u64),
- }));
- vec![(
- TypedOperand::Reg(constant),
- ast::Type::Scalar(inp_type),
- ast::StateSpace::Reg,
- )]
- }
- (None, None) => Vec::new(),
- _ => return Err(error_mismatched_type()),
- };
- let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
- let return_type = sreg.get_function_return_type();
- let fn_result = self.numeric_id_defs.register_intermediate(Some((
- ast::Type::Scalar(return_type),
- ast::StateSpace::Reg,
- )));
- let return_arguments = vec![(
- fn_result,
- ast::Type::Scalar(return_type),
- ast::StateSpace::Reg,
- )];
- let fn_call = register_external_fn_call(
- self.numeric_id_defs,
- self.ptx_impl_imports,
- ocl_fn_name.to_string(),
- return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
- input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
- )?;
- let data = ast::CallDetails {
- uniform: false,
- return_arguments: return_arguments
- .iter()
- .map(|(_, typ, space)| (typ.clone(), *space))
- .collect(),
- input_arguments: input_arguments
- .iter()
- .map(|(_, typ, space)| (typ.clone(), *space))
- .collect(),
- };
- let arguments = ast::CallArgs {
- return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
- func: fn_call,
- input_arguments: input_arguments.iter().map(|(name, _, _)| *name).collect(),
- };
- self.result
- .push(Statement::Instruction(ast::Instruction::Call {
- data,
- arguments,
- }));
- Ok(fn_result)
- } else {
- Ok(name)
- }
- }
-}
diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers2.rs index 97f6356..8c3b794 100644 --- a/ptx/src/pass/fix_special_registers2.rs +++ b/ptx/src/pass/fix_special_registers2.rs @@ -31,10 +31,10 @@ pub(super) fn run<'a, 'input>( sreg_to_function,
result: Vec::new(),
};
- directives
- .into_iter()
- .map(|directive| run_directive(&mut visitor, directive))
- .collect::<Result<Vec<_>, _>>()
+ for directive in directives.into_iter() {
+ result.push(run_directive(&mut visitor, directive)?);
+ }
+ Ok(result)
}
fn run_directive<'a, 'input>(
@@ -112,7 +112,7 @@ impl<'a, 'b, 'input> is_dst: bool,
_relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
- self.replace_sreg(args, None, is_dst)
+ Ok(self.replace_sreg(args, None, is_dst)?.unwrap_or(args))
}
}
@@ -122,7 +122,7 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> { name: SpirvWord,
vector_index: Option<u8>,
is_dst: bool,
- ) -> Result<SpirvWord, TranslateError> {
+ ) -> Result<Option<SpirvWord>, TranslateError> {
if let Some(sreg) = self.special_registers.get(name) {
if is_dst {
return Err(error_mismatched_type());
@@ -179,30 +179,33 @@ impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> { data,
arguments,
}));
- Ok(fn_result)
+ Ok(Some(fn_result))
} else {
- Ok(name)
+ Ok(None)
}
}
}
-pub fn map_operand<T, U, Err>(
+pub fn map_operand<T: Copy, Err>(
this: ast::ParsedOperand<T>,
- fn_: &mut impl FnMut(T, Option<u8>) -> Result<U, Err>,
-) -> Result<ast::ParsedOperand<U>, Err> {
+ fn_: &mut impl FnMut(T, Option<u8>) -> Result<Option<T>, Err>,
+) -> Result<ast::ParsedOperand<T>, Err> {
Ok(match this {
- ast::ParsedOperand::Reg(ident) => ast::ParsedOperand::Reg(fn_(ident, None)?),
+ ast::ParsedOperand::Reg(ident) => {
+ ast::ParsedOperand::Reg(fn_(ident, None)?.unwrap_or(ident))
+ }
ast::ParsedOperand::RegOffset(ident, offset) => {
- ast::ParsedOperand::RegOffset(fn_(ident, None)?, offset)
+ ast::ParsedOperand::RegOffset(fn_(ident, None)?.unwrap_or(ident), offset)
}
ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
- ast::ParsedOperand::VecMember(ident, member) => {
- ast::ParsedOperand::Reg(fn_(ident, Some(member))?)
- }
+ ast::ParsedOperand::VecMember(ident, member) => match fn_(ident, Some(member))? {
+ Some(ident) => ast::ParsedOperand::Reg(ident),
+ None => ast::ParsedOperand::VecMember(ident, member),
+ },
ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
idents
.into_iter()
- .map(|ident| fn_(ident, None))
+ .map(|ident| Ok(fn_(ident, None)?.unwrap_or(ident)))
.collect::<Result<Vec<_>, _>>()?,
),
})
diff --git a/ptx/src/pass/hoist_globals.rs b/ptx/src/pass/hoist_globals.rs index 753172a..718c052 100644 --- a/ptx/src/pass/hoist_globals.rs +++ b/ptx/src/pass/hoist_globals.rs @@ -5,7 +5,7 @@ pub(super) fn run<'input>( ) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
let mut result = Vec::with_capacity(directives.len());
for mut directive in directives.into_iter() {
- run_directive(&mut result, &mut directive);
+ run_directive(&mut result, &mut directive)?;
result.push(directive);
}
Ok(result)
diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index ec6498c..60c4a14 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -1,7 +1,4 @@ use super::*;
-use ptx_parser::VisitorMap;
-use rustc_hash::FxHashSet;
-
// This pass:
// * Turns all .local, .param and .reg in-body variables into .local variables
// (if _not_ an input method argument)
@@ -40,9 +37,6 @@ fn run_method<'a, 'input>( method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
let mut func_decl = method.func_decl;
- for arg in func_decl.return_arguments.iter_mut() {
- visitor.visit_variable(arg)?;
- }
let is_kernel = func_decl.name.is_kernel();
if is_kernel {
for arg in func_decl.input_arguments.iter_mut() {
@@ -52,17 +46,21 @@ fn run_method<'a, 'input>( let new_name = visitor
.resolver
.register_unnamed(Some((arg.v_type.clone(), new_space)));
- visitor.input_argument(old_name, new_name, old_space);
+ visitor.input_argument(old_name, new_name, old_space)?;
arg.name = new_name;
arg.state_space = new_space;
}
};
+ for arg in func_decl.return_arguments.iter_mut() {
+ visitor.visit_variable(arg)?;
+ }
+ let return_arguments = &func_decl.return_arguments[..];
let body = method
.body
.map(move |statements| {
let mut result = Vec::with_capacity(statements.len());
for statement in statements {
- run_statement(&mut visitor, &mut result, statement)?;
+ run_statement(&mut visitor, return_arguments, &mut result, statement)?;
}
Ok::<_, TranslateError>(result)
})
@@ -79,10 +77,33 @@ fn run_method<'a, 'input>( fn run_statement<'a, 'input>(
visitor: &mut InsertMemSSAVisitor<'a, 'input>,
+ return_arguments: &[ast::Variable<SpirvWord>],
result: &mut Vec<ExpandedStatement>,
statement: ExpandedStatement,
) -> Result<(), TranslateError> {
match statement {
+ Statement::Instruction(ast::Instruction::Ret { data }) => {
+ let statement = if return_arguments.is_empty() {
+ Statement::Instruction(ast::Instruction::Ret { data })
+ } else {
+ Statement::RetValue(
+ data,
+ return_arguments
+ .iter()
+ .map(|arg| {
+ if arg.state_space != ast::StateSpace::Local {
+ return Err(error_unreachable());
+ }
+ Ok((arg.name, arg.v_type.clone()))
+ })
+ .collect::<Result<Vec<_>, _>>()?,
+ )
+ };
+ let new_statement = statement.visit_map(visitor)?;
+ result.extend(visitor.pre.drain(..).map(Statement::Instruction));
+ result.push(new_statement);
+ result.extend(visitor.post.drain(..).map(Statement::Instruction));
+ }
Statement::Variable(mut var) => {
visitor.visit_variable(&mut var)?;
result.push(Statement::Variable(var));
@@ -154,7 +175,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { old_name: SpirvWord,
new_name: SpirvWord,
old_space: ast::StateSpace,
- ) -> Result<(), TranslateError> {
+ ) -> Result<bool, TranslateError> {
Ok(match old_space {
ast::StateSpace::Reg => {
self.variables.insert(
@@ -164,6 +185,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { type_: type_.clone(),
},
);
+ true
}
ast::StateSpace::Param => {
self.variables.insert(
@@ -174,19 +196,18 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { name: new_name,
},
);
+ true
}
// Good as-is
- ast::StateSpace::Local => {}
- // Will be pulled into global scope later
- ast::StateSpace::Generic
+ ast::StateSpace::Local
+ | ast::StateSpace::Generic
| ast::StateSpace::SharedCluster
| ast::StateSpace::Global
| ast::StateSpace::Const
| ast::StateSpace::SharedCta
- | ast::StateSpace::Shared => {}
- ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => {
- return Err(error_unreachable())
- }
+ | ast::StateSpace::Shared
+ | ast::StateSpace::ParamEntry
+ | ast::StateSpace::ParamFunc => return Err(error_unreachable()),
})
}
@@ -239,17 +260,28 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { }
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
- if var.state_space != ast::StateSpace::Local {
- let old_name = var.name;
- let old_space = var.state_space;
- let new_space = ast::StateSpace::Local;
- let new_name = self
- .resolver
- .register_unnamed(Some((var.v_type.clone(), new_space)));
- self.variable(&var.v_type, old_name, new_name, old_space)?;
- var.name = new_name;
- var.state_space = new_space;
- }
+ let old_space = match var.state_space {
+ space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
+ // Do nothing
+ ptx_parser::StateSpace::Local => return Ok(()),
+ // Handled by another pass
+ ptx_parser::StateSpace::Generic
+ | ptx_parser::StateSpace::SharedCluster
+ | ptx_parser::StateSpace::ParamEntry
+ | ptx_parser::StateSpace::Global
+ | ptx_parser::StateSpace::SharedCta
+ | ptx_parser::StateSpace::Const
+ | ptx_parser::StateSpace::Shared
+ | ptx_parser::StateSpace::ParamFunc => return Ok(()),
+ };
+ let old_name = var.name;
+ let new_space = ast::StateSpace::Local;
+ let new_name = self
+ .resolver
+ .register_unnamed(Some((var.v_type.clone(), new_space)));
+ self.variable(&var.v_type, old_name, new_name, old_space)?;
+ var.name = new_name;
+ var.state_space = new_space;
Ok(())
}
}
@@ -260,9 +292,9 @@ impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError> fn visit(
&mut self,
ident: SpirvWord,
- type_space: Option<(&ast::Type, ast::StateSpace)>,
+ _type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
- relaxed_type_check: bool,
+ _relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
if let Some(remap) = self.variables.get(&ident) {
match remap {
diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs deleted file mode 100644 index c04fa09..0000000 --- a/ptx/src/pass/insert_implicit_conversions.rs +++ /dev/null @@ -1,438 +0,0 @@ -use std::mem;
-
-use super::*;
-use ptx_parser as ast;
-
-/*
- There are several kinds of implicit conversions in PTX:
- * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
- * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
- - ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
- semantics are to first zext/chop/bitcast `y` as needed and then do
- documented special ld/st/cvt conversion rules for destination operands
- - st.param [x] y (used as function return arguments) same rule as above applies
- - generic/global ld: for instruction `ld x, [y]`, y must be of type
- b64/u64/s64, which is bitcast to a pointer, dereferenced and then
- documented special ld/st/cvt conversion rules are applied to dst
- - generic/global st: for instruction `st [x], y`, x must be of type
- b64/u64/s64, which is bitcast to a pointer
-*/
-pub(super) fn run(
- func: Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver,
-) -> Result<Vec<ExpandedStatement>, TranslateError> {
- let mut result = Vec::with_capacity(func.len());
- for s in func.into_iter() {
- match s {
- Statement::Instruction(inst) => {
- insert_implicit_conversions_impl(
- &mut result,
- id_def,
- Statement::Instruction(inst),
- )?;
- }
- Statement::PtrAccess(access) => {
- insert_implicit_conversions_impl(
- &mut result,
- id_def,
- Statement::PtrAccess(access),
- )?;
- }
- Statement::RepackVector(repack) => {
- insert_implicit_conversions_impl(
- &mut result,
- id_def,
- Statement::RepackVector(repack),
- )?;
- }
- Statement::VectorAccess(vector_access) => {
- insert_implicit_conversions_impl(
- &mut result,
- id_def,
- Statement::VectorAccess(vector_access),
- )?;
- }
- s @ Statement::Conditional(_)
- | s @ Statement::Conversion(_)
- | s @ Statement::Label(_)
- | s @ Statement::Constant(_)
- | s @ Statement::Variable(_)
- | s @ Statement::LoadVar(..)
- | s @ Statement::StoreVar(..)
- | s @ Statement::RetValue(..)
- | s @ Statement::FunctionPointer(..) => result.push(s),
- }
- }
- Ok(result)
-}
-
-fn insert_implicit_conversions_impl(
- func: &mut Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver,
- stmt: ExpandedStatement,
-) -> Result<(), TranslateError> {
- let mut post_conv = Vec::new();
- let statement = stmt.visit_map::<SpirvWord, TranslateError>(
- &mut |operand,
- type_state: Option<(&ast::Type, ast::StateSpace)>,
- is_dst,
- relaxed_type_check| {
- let (instr_type, instruction_space) = match type_state {
- None => return Ok(operand),
- Some(t) => t,
- };
- let (operand_type, operand_space) = id_def.get_typed(operand)?;
- let conversion_fn = if relaxed_type_check {
- if is_dst {
- should_convert_relaxed_dst_wrapper
- } else {
- should_convert_relaxed_src_wrapper
- }
- } else {
- default_implicit_conversion
- };
- match conversion_fn(
- (operand_space, &operand_type),
- (instruction_space, instr_type),
- )? {
- Some(conv_kind) => {
- let conv_output = if is_dst { &mut post_conv } else { &mut *func };
- let mut from_type = instr_type.clone();
- let mut from_space = instruction_space;
- let mut to_type = operand_type;
- let mut to_space = operand_space;
- let mut src =
- id_def.register_intermediate(instr_type.clone(), instruction_space);
- let mut dst = operand;
- let result = Ok::<_, TranslateError>(src);
- if !is_dst {
- mem::swap(&mut src, &mut dst);
- mem::swap(&mut from_type, &mut to_type);
- mem::swap(&mut from_space, &mut to_space);
- }
- conv_output.push(Statement::Conversion(ImplicitConversion {
- src,
- dst,
- from_type,
- from_space,
- to_type,
- to_space,
- kind: conv_kind,
- }));
- result
- }
- None => Ok(operand),
- }
- },
- )?;
- func.push(statement);
- func.append(&mut post_conv);
- Ok(())
-}
-
-pub(crate) fn default_implicit_conversion(
- (operand_space, operand_type): (ast::StateSpace, &ast::Type),
- (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
-) -> Result<Option<ConversionKind>, TranslateError> {
- if instruction_space == ast::StateSpace::Reg {
- if operand_space == ast::StateSpace::Reg {
- if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
- (operand_type, instruction_type)
- {
- if scalar.kind() == ast::ScalarKind::Bit
- && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
- {
- return Ok(Some(ConversionKind::Default));
- }
- }
- } else if is_addressable(operand_space) {
- return Ok(Some(ConversionKind::AddressOf));
- }
- }
- if instruction_space != operand_space {
- default_implicit_conversion_space(
- (operand_space, operand_type),
- (instruction_space, instruction_type),
- )
- } else if instruction_type != operand_type {
- default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
- } else {
- Ok(None)
- }
-}
-
-fn is_addressable(this: ast::StateSpace) -> bool {
- match this {
- ast::StateSpace::Const
- | ast::StateSpace::Generic
- | ast::StateSpace::Global
- | ast::StateSpace::Local
- | ast::StateSpace::Shared => true,
- ast::StateSpace::Param | ast::StateSpace::Reg => false,
- ast::StateSpace::SharedCluster
- | ast::StateSpace::SharedCta
- | ast::StateSpace::ParamEntry
- | ast::StateSpace::ParamFunc => todo!(),
- }
-}
-
-// Space is different
-fn default_implicit_conversion_space(
- (operand_space, operand_type): (ast::StateSpace, &ast::Type),
- (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
-) -> Result<Option<ConversionKind>, TranslateError> {
- if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
- || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
- {
- Ok(Some(ConversionKind::PtrToPtr))
- } else if operand_space == ast::StateSpace::Reg {
- match operand_type {
- ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
- if *operand_ptr_space == instruction_space =>
- {
- if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
- Ok(Some(ConversionKind::PtrToPtr))
- } else {
- Ok(None)
- }
- }
- // TODO: 32 bit
- ast::Type::Scalar(ast::ScalarType::B64)
- | ast::Type::Scalar(ast::ScalarType::U64)
- | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
- ast::StateSpace::Global
- | ast::StateSpace::Generic
- | ast::StateSpace::Const
- | ast::StateSpace::Local
- | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
- _ => Err(error_mismatched_type()),
- },
- ast::Type::Scalar(ast::ScalarType::B32)
- | ast::Type::Scalar(ast::ScalarType::U32)
- | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
- ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
- Ok(Some(ConversionKind::BitToPtr))
- }
- _ => Err(error_mismatched_type()),
- },
- _ => Err(error_mismatched_type()),
- }
- } else if instruction_space == ast::StateSpace::Reg {
- match instruction_type {
- ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
- if operand_space == *instruction_ptr_space =>
- {
- if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
- Ok(Some(ConversionKind::PtrToPtr))
- } else {
- Ok(None)
- }
- }
- _ => Err(error_mismatched_type()),
- }
- } else {
- Err(error_mismatched_type())
- }
-}
-
-// Space is same, but type is different
-fn default_implicit_conversion_type(
- space: ast::StateSpace,
- operand_type: &ast::Type,
- instruction_type: &ast::Type,
-) -> Result<Option<ConversionKind>, TranslateError> {
- if space == ast::StateSpace::Reg {
- if should_bitcast(instruction_type, operand_type) {
- Ok(Some(ConversionKind::Default))
- } else {
- Err(TranslateError::MismatchedType)
- }
- } else {
- Ok(Some(ConversionKind::PtrToPtr))
- }
-}
-
-fn coerces_to_generic(this: ast::StateSpace) -> bool {
- match this {
- ast::StateSpace::Global
- | ast::StateSpace::Const
- | ast::StateSpace::Local
- | ptx_parser::StateSpace::SharedCta
- | ast::StateSpace::SharedCluster
- | ast::StateSpace::Shared => true,
- ast::StateSpace::Reg
- | ast::StateSpace::Param
- | ast::StateSpace::ParamEntry
- | ast::StateSpace::ParamFunc
- | ast::StateSpace::Generic => false,
- }
-}
-
-fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
- match (instr, operand) {
- (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
- if inst.size_of() != operand.size_of() {
- return false;
- }
- match inst.kind() {
- ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
- ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
- ast::ScalarKind::Signed => {
- operand.kind() == ast::ScalarKind::Bit
- || operand.kind() == ast::ScalarKind::Unsigned
- }
- ast::ScalarKind::Unsigned => {
- operand.kind() == ast::ScalarKind::Bit
- || operand.kind() == ast::ScalarKind::Signed
- }
- ast::ScalarKind::Pred => false,
- }
- }
- (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
- | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
- should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
- }
- _ => false,
- }
-}
-
-pub(crate) fn should_convert_relaxed_dst_wrapper(
- (operand_space, operand_type): (ast::StateSpace, &ast::Type),
- (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
-) -> Result<Option<ConversionKind>, TranslateError> {
- if operand_space != instruction_space {
- return Err(TranslateError::MismatchedType);
- }
- if operand_type == instruction_type {
- return Ok(None);
- }
- match should_convert_relaxed_dst(operand_type, instruction_type) {
- conv @ Some(_) => Ok(conv),
- None => Err(TranslateError::MismatchedType),
- }
-}
-
-// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
-fn should_convert_relaxed_dst(
- dst_type: &ast::Type,
- instr_type: &ast::Type,
-) -> Option<ConversionKind> {
- if dst_type == instr_type {
- return None;
- }
- match (dst_type, instr_type) {
- (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
- ast::ScalarKind::Bit => {
- if instr_type.size_of() <= dst_type.size_of() {
- Some(ConversionKind::Default)
- } else {
- None
- }
- }
- ast::ScalarKind::Signed => {
- if dst_type.kind() != ast::ScalarKind::Float {
- if instr_type.size_of() == dst_type.size_of() {
- Some(ConversionKind::Default)
- } else if instr_type.size_of() < dst_type.size_of() {
- Some(ConversionKind::SignExtend)
- } else {
- None
- }
- } else {
- None
- }
- }
- ast::ScalarKind::Unsigned => {
- if instr_type.size_of() <= dst_type.size_of()
- && dst_type.kind() != ast::ScalarKind::Float
- {
- Some(ConversionKind::Default)
- } else {
- None
- }
- }
- ast::ScalarKind::Float => {
- if instr_type.size_of() <= dst_type.size_of()
- && dst_type.kind() == ast::ScalarKind::Bit
- {
- Some(ConversionKind::Default)
- } else {
- None
- }
- }
- ast::ScalarKind::Pred => None,
- },
- (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
- | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
- should_convert_relaxed_dst(
- &ast::Type::Scalar(*dst_type),
- &ast::Type::Scalar(*instr_type),
- )
- }
- _ => None,
- }
-}
-
-pub(crate) fn should_convert_relaxed_src_wrapper(
- (operand_space, operand_type): (ast::StateSpace, &ast::Type),
- (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
-) -> Result<Option<ConversionKind>, TranslateError> {
- if operand_space != instruction_space {
- return Err(error_mismatched_type());
- }
- if operand_type == instruction_type {
- return Ok(None);
- }
- match should_convert_relaxed_src(operand_type, instruction_type) {
- conv @ Some(_) => Ok(conv),
- None => Err(error_mismatched_type()),
- }
-}
-
-// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
-fn should_convert_relaxed_src(
- src_type: &ast::Type,
- instr_type: &ast::Type,
-) -> Option<ConversionKind> {
- if src_type == instr_type {
- return None;
- }
- match (src_type, instr_type) {
- (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
- ast::ScalarKind::Bit => {
- if instr_type.size_of() <= src_type.size_of() {
- Some(ConversionKind::Default)
- } else {
- None
- }
- }
- ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
- if instr_type.size_of() <= src_type.size_of()
- && src_type.kind() != ast::ScalarKind::Float
- {
- Some(ConversionKind::Default)
- } else {
- None
- }
- }
- ast::ScalarKind::Float => {
- if instr_type.size_of() <= src_type.size_of()
- && src_type.kind() == ast::ScalarKind::Bit
- {
- Some(ConversionKind::Default)
- } else {
- None
- }
- }
- ast::ScalarKind::Pred => None,
- },
- (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
- | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
- should_convert_relaxed_src(
- &ast::Type::Scalar(*dst_type),
- &ast::Type::Scalar(*instr_type),
- )
- }
- _ => None,
- }
-}
diff --git a/ptx/src/pass/insert_mem_ssa_statements.rs b/ptx/src/pass/insert_mem_ssa_statements.rs deleted file mode 100644 index 150109b..0000000 --- a/ptx/src/pass/insert_mem_ssa_statements.rs +++ /dev/null @@ -1,275 +0,0 @@ -use super::*;
-use ptx_parser as ast;
-
-/*
- How do we handle arguments:
- - input .params in kernels
- .param .b64 in_arg
- get turned into this SPIR-V:
- %1 = OpFunctionParameter %ulong
- %2 = OpVariable %_ptr_Function_ulong Function
- OpStore %2 %1
- We do this for two reasons. One, common treatment for argument-declared
- .param variables and .param variables inside function (we assume that
- at SPIR-V level every .param is a pointer in Function storage class)
- - input .params in functions
- .param .b64 in_arg
- get turned into this SPIR-V:
- %1 = OpFunctionParameter %_ptr_Function_ulong
- - input .regs
- .reg .b64 in_arg
- get turned into the same SPIR-V as kernel .params:
- %1 = OpFunctionParameter %ulong
- %2 = OpVariable %_ptr_Function_ulong Function
- OpStore %2 %1
- - output .regs
- .reg .b64 out_arg
- get just a variable declaration:
- %2 = OpVariable %%_ptr_Function_ulong Function
- - output .params don't exist, they have been moved to input positions
- by an earlier pass
- Distinguishing betweem kernel .params and function .params is not the
- cleanest solution. Alternatively, we could "deparamize" all kernel .param
- arguments by turning them into .reg arguments like this:
- .param .b64 arg -> .reg ptr<.b64,.param> arg
- This has the massive downside that this transformation would have to run
- very early and would muddy up already difficult code. It's simpler to just
- have an if here
-*/
-pub(super) fn run<'a, 'b>(
- func: Vec<TypedStatement>,
- id_def: &mut NumericIdResolver,
- fn_decl: &'a mut ast::MethodDeclaration<'b, SpirvWord>,
-) -> Result<Vec<TypedStatement>, TranslateError> {
- let mut result = Vec::with_capacity(func.len());
- for arg in fn_decl.input_arguments.iter_mut() {
- insert_mem_ssa_argument(
- id_def,
- &mut result,
- arg,
- matches!(fn_decl.name, ast::MethodName::Kernel(_)),
- );
- }
- for arg in fn_decl.return_arguments.iter() {
- insert_mem_ssa_argument_reg_return(&mut result, arg);
- }
- for s in func {
- match s {
- Statement::Instruction(inst) => match inst {
- ast::Instruction::Ret { data } => {
- // TODO: handle multiple output args
- match &fn_decl.return_arguments[..] {
- [return_reg] => {
- let new_id = id_def.register_intermediate(Some((
- return_reg.v_type.clone(),
- ast::StateSpace::Reg,
- )));
- result.push(Statement::LoadVar(LoadVarDetails {
- arg: ast::LdArgs {
- dst: new_id,
- src: return_reg.name,
- },
- typ: return_reg.v_type.clone(),
- member_index: None,
- }));
- result.push(Statement::RetValue(data, new_id));
- }
- [] => result.push(Statement::Instruction(ast::Instruction::Ret { data })),
- _ => unimplemented!(),
- }
- }
- inst => insert_mem_ssa_statement_default(
- id_def,
- &mut result,
- Statement::Instruction(inst),
- )?,
- },
- Statement::Conditional(bra) => {
- insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conditional(bra))?
- }
- Statement::Conversion(conv) => {
- insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conversion(conv))?
- }
- Statement::PtrAccess(ptr_access) => insert_mem_ssa_statement_default(
- id_def,
- &mut result,
- Statement::PtrAccess(ptr_access),
- )?,
- Statement::RepackVector(repack) => insert_mem_ssa_statement_default(
- id_def,
- &mut result,
- Statement::RepackVector(repack),
- )?,
- Statement::FunctionPointer(func_ptr) => insert_mem_ssa_statement_default(
- id_def,
- &mut result,
- Statement::FunctionPointer(func_ptr),
- )?,
- s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => {
- result.push(s)
- }
- _ => return Err(error_unreachable()),
- }
- }
- Ok(result)
-}
-
-fn insert_mem_ssa_argument(
- id_def: &mut NumericIdResolver,
- func: &mut Vec<TypedStatement>,
- arg: &mut ast::Variable<SpirvWord>,
- is_kernel: bool,
-) {
- if !is_kernel && arg.state_space == ast::StateSpace::Param {
- return;
- }
- let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
- func.push(Statement::Variable(ast::Variable {
- align: arg.align,
- v_type: arg.v_type.clone(),
- state_space: ast::StateSpace::Reg,
- name: arg.name,
- array_init: Vec::new(),
- }));
- func.push(Statement::StoreVar(StoreVarDetails {
- arg: ast::StArgs {
- src1: arg.name,
- src2: new_id,
- },
- typ: arg.v_type.clone(),
- member_index: None,
- }));
- arg.name = new_id;
-}
-
-fn insert_mem_ssa_argument_reg_return(
- func: &mut Vec<TypedStatement>,
- arg: &ast::Variable<SpirvWord>,
-) {
- func.push(Statement::Variable(ast::Variable {
- align: arg.align,
- v_type: arg.v_type.clone(),
- state_space: arg.state_space,
- name: arg.name,
- array_init: arg.array_init.clone(),
- }));
-}
-
-fn insert_mem_ssa_statement_default<'a, 'input>(
- id_def: &'a mut NumericIdResolver<'input>,
- func: &'a mut Vec<TypedStatement>,
- stmt: TypedStatement,
-) -> Result<(), TranslateError> {
- let mut visitor = InsertMemSSAVisitor {
- id_def,
- func,
- post_statements: Vec::new(),
- };
- let new_stmt = stmt.visit_map(&mut visitor)?;
- visitor.func.push(new_stmt);
- visitor.func.extend(visitor.post_statements);
- Ok(())
-}
-
-struct InsertMemSSAVisitor<'a, 'input> {
- id_def: &'a mut NumericIdResolver<'input>,
- func: &'a mut Vec<TypedStatement>,
- post_statements: Vec<TypedStatement>,
-}
-
-impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
- fn symbol(
- &mut self,
- symbol: SpirvWord,
- member_index: Option<u8>,
- expected: Option<(&ast::Type, ast::StateSpace)>,
- is_dst: bool,
- ) -> Result<SpirvWord, TranslateError> {
- if expected.is_none() {
- return Ok(symbol);
- };
- let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
- if var_space != ast::StateSpace::Reg || !is_variable {
- return Ok(symbol);
- };
- let member_index = match member_index {
- Some(idx) => {
- let vector_width = match var_type {
- ast::Type::Vector(width, scalar_t) => {
- var_type = ast::Type::Scalar(scalar_t);
- width
- }
- _ => return Err(error_mismatched_type()),
- };
- Some((
- idx,
- if self.id_def.special_registers.get(symbol).is_some() {
- Some(vector_width)
- } else {
- None
- },
- ))
- }
- None => None,
- };
- let generated_id = self
- .id_def
- .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg)));
- if !is_dst {
- self.func.push(Statement::LoadVar(LoadVarDetails {
- arg: ast::LdArgs {
- dst: generated_id,
- src: symbol,
- },
- typ: var_type,
- member_index,
- }));
- } else {
- self.post_statements
- .push(Statement::StoreVar(StoreVarDetails {
- arg: ast::StArgs {
- src1: symbol,
- src2: generated_id,
- },
- typ: var_type,
- member_index: member_index.map(|(idx, _)| idx),
- }));
- }
- Ok(generated_id)
- }
-}
-
-impl<'a, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
- for InsertMemSSAVisitor<'a, 'input>
-{
- fn visit(
- &mut self,
- operand: TypedOperand,
- type_space: Option<(&ast::Type, ast::StateSpace)>,
- is_dst: bool,
- _relaxed_type_check: bool,
- ) -> Result<TypedOperand, TranslateError> {
- Ok(match operand {
- TypedOperand::Reg(reg) => {
- TypedOperand::Reg(self.symbol(reg, None, type_space, is_dst)?)
- }
- TypedOperand::RegOffset(reg, offset) => {
- TypedOperand::RegOffset(self.symbol(reg, None, type_space, is_dst)?, offset)
- }
- op @ TypedOperand::Imm(..) => op,
- TypedOperand::VecMember(symbol, index) => {
- TypedOperand::Reg(self.symbol(symbol, Some(index), type_space, is_dst)?)
- }
- })
- }
-
- fn visit_ident(
- &mut self,
- args: SpirvWord,
- type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
- is_dst: bool,
- relaxed_type_check: bool,
- ) -> Result<SpirvWord, TranslateError> {
- self.symbol(args, None, type_space, is_dst)
- }
-}
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 0e233ed..ef131b4 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1,84 +1,43 @@ use ptx_parser as ast;
-use rspirv::{binary::Assemble, dr};
+use quick_error::quick_error;
use rustc_hash::FxHashMap;
use std::hash::Hash;
-use std::num::NonZeroU8;
use std::{
borrow::Cow,
- cell::RefCell,
- collections::{hash_map, HashMap, HashSet},
+ collections::{hash_map, HashMap},
ffi::CString,
iter,
- marker::PhantomData,
- mem,
- rc::Rc,
};
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
-mod convert_dynamic_shared_memory_usage;
-mod convert_to_stateful_memory_access;
-mod convert_to_typed;
mod deparamize_functions;
pub(crate) mod emit_llvm;
-mod emit_spirv;
-mod expand_arguments;
mod expand_operands;
-mod extract_globals;
-mod fix_special_registers;
mod fix_special_registers2;
mod hoist_globals;
mod insert_explicit_load_store;
-mod insert_implicit_conversions;
mod insert_implicit_conversions2;
-mod insert_mem_ssa_statements;
-mod normalize_identifiers;
mod normalize_identifiers2;
-mod normalize_labels;
-mod normalize_predicates;
mod normalize_predicates2;
+mod replace_instructions_with_function_calls;
mod resolve_function_pointers;
-static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
-static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
-const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__";
+static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
+const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl_";
-pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
- let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1));
- let mut ptx_impl_imports = HashMap::new();
- let directives = ast
- .directives
- .into_iter()
- .filter_map(|directive| {
- translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose()
- })
- .collect::<Result<Vec<_>, _>>()?;
- let directives = hoist_function_globals(directives);
- let must_link_ptx_impl = ptx_impl_imports.len() > 0;
- let mut directives = ptx_impl_imports
- .into_iter()
- .map(|(_, v)| v)
- .chain(directives.into_iter())
- .collect::<Vec<_>>();
- let mut builder = dr::Builder::new();
- builder.reserve_ids(id_defs.current_id().0);
- let call_map = MethodsCallMap::new(&directives);
- let mut directives =
- convert_dynamic_shared_memory_usage::run(directives, &call_map, &mut || {
- SpirvWord(builder.id())
- })?;
- normalize_variable_decls(&mut directives);
- let denorm_information = compute_denorm_information(&directives);
- todo!()
- /*
- let llvm_ir: emit_llvm::MemoryBuffer = emit_llvm::run(&id_defs, call_map, directives)?;
- Ok(Module {
- llvm_ir,
- kernel_info: HashMap::new(),
- }) */
+quick_error! {
+ #[derive(Debug)]
+ pub enum TranslateError {
+ UnknownSymbol {}
+ UntypedSymbol {}
+ MismatchedType {}
+ Unreachable {}
+ Todo {}
+ }
}
-pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
+pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
@@ -86,11 +45,11 @@ pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, Trans let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
let directives = resolve_function_pointers::run(directives)?;
let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
- let directives: Vec<Directive2<'_, ptx_parser::Instruction<SpirvWord>, SpirvWord>> =
- expand_operands::run(&mut flat_resolver, directives)?;
+ let directives = expand_operands::run(&mut flat_resolver, directives)?;
let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
+ let directives = replace_instructions_with_function_calls::run(&mut flat_resolver, directives)?;
let directives = hoist_globals::run(directives)?;
let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
Ok(Module {
@@ -99,254 +58,15 @@ pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, Trans })
}
-fn translate_directive<'input, 'a>(
- id_defs: &'a mut GlobalStringIdResolver<'input>,
- ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
- d: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
-) -> Result<Option<Directive<'input>>, TranslateError> {
- Ok(match d {
- ast::Directive::Variable(linking, var) => Some(Directive::Variable(
- linking,
- ast::Variable {
- align: var.align,
- v_type: var.v_type.clone(),
- state_space: var.state_space,
- name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
- array_init: var.array_init,
- },
- )),
- ast::Directive::Method(linkage, f) => {
- translate_function(id_defs, ptx_impl_imports, linkage, f)?.map(Directive::Method)
- }
- })
-}
-
-type ParsedFunction<'a> = ast::Function<'a, &'a str, ast::Statement<ast::ParsedOperand<&'a str>>>;
-
-fn translate_function<'input, 'a>(
- id_defs: &'a mut GlobalStringIdResolver<'input>,
- ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
- linkage: ast::LinkingDirective,
- f: ParsedFunction<'input>,
-) -> Result<Option<Function<'input>>, TranslateError> {
- let import_as = match &f.func_directive {
- ast::MethodDeclaration {
- name: ast::MethodName::Func(func_name),
- ..
- } if *func_name == "__assertfail" || *func_name == "vprintf" => {
- Some([ZLUDA_PTX_PREFIX, func_name].concat())
- }
- _ => None,
- };
- let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?;
- let mut func = to_ssa(
- ptx_impl_imports,
- str_resolver,
- fn_resolver,
- fn_decl,
- f.body,
- f.tuning,
- linkage,
- )?;
- func.import_as = import_as;
- if func.import_as.is_some() {
- ptx_impl_imports.insert(
- func.import_as.as_ref().unwrap().clone(),
- Directive::Method(func),
- );
- Ok(None)
- } else {
- Ok(Some(func))
- }
-}
-
-fn to_ssa<'input, 'b>(
- ptx_impl_imports: &'b mut HashMap<String, Directive<'input>>,
- mut id_defs: FnStringIdResolver<'input, 'b>,
- fn_defs: GlobalFnDeclResolver<'input, 'b>,
- func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
- f_body: Option<Vec<ast::Statement<ast::ParsedOperand<&'input str>>>>,
- tuning: Vec<ast::TuningDirective>,
- linkage: ast::LinkingDirective,
-) -> Result<Function<'input>, TranslateError> {
- //deparamize_function_decl(&func_decl)?;
- let f_body = match f_body {
- Some(vec) => vec,
- None => {
- return Ok(Function {
- func_decl: func_decl,
- body: None,
- globals: Vec::new(),
- import_as: None,
- tuning,
- linkage,
- })
- }
- };
- let normalized_ids = normalize_identifiers::run(&mut id_defs, &fn_defs, f_body)?;
- let mut numeric_id_defs = id_defs.finish();
- let unadorned_statements = normalize_predicates::run(normalized_ids, &mut numeric_id_defs)?;
- let typed_statements =
- convert_to_typed::run(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
- let typed_statements =
- fix_special_registers::run(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
- let (func_decl, typed_statements) =
- convert_to_stateful_memory_access::run(func_decl, typed_statements, &mut numeric_id_defs)?;
- let ssa_statements = insert_mem_ssa_statements::run(
- typed_statements,
- &mut numeric_id_defs,
- &mut (*func_decl).borrow_mut(),
- )?;
- let mut numeric_id_defs = numeric_id_defs.finish();
- let expanded_statements = expand_arguments::run(ssa_statements, &mut numeric_id_defs)?;
- let expanded_statements =
- insert_implicit_conversions::run(expanded_statements, &mut numeric_id_defs)?;
- let mut numeric_id_defs = numeric_id_defs.unmut();
- let labeled_statements = normalize_labels::run(expanded_statements, &mut numeric_id_defs);
- let (f_body, globals) =
- extract_globals::run(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?;
- Ok(Function {
- func_decl: func_decl,
- globals: globals,
- body: Some(f_body),
- import_as: None,
- tuning,
- linkage,
- })
-}
-
pub struct Module {
pub llvm_ir: emit_llvm::MemoryBuffer,
pub kernel_info: HashMap<String, KernelInfo>,
}
-struct GlobalStringIdResolver<'input> {
- current_id: SpirvWord,
- variables: HashMap<Cow<'input, str>, SpirvWord>,
- pub(crate) reverse_variables: HashMap<SpirvWord, &'input str>,
- variables_type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
- special_registers: SpecialRegistersMap,
- fns: HashMap<SpirvWord, FnSigMapper<'input>>,
-}
-
-impl<'input> GlobalStringIdResolver<'input> {
- fn new(start_id: SpirvWord) -> Self {
- Self {
- current_id: start_id,
- variables: HashMap::new(),
- reverse_variables: HashMap::new(),
- variables_type_check: HashMap::new(),
- special_registers: SpecialRegistersMap::new(),
- fns: HashMap::new(),
- }
- }
-
- fn get_or_add_def(&mut self, id: &'input str) -> SpirvWord {
- self.get_or_add_impl(id, None)
- }
-
- fn get_or_add_def_typed(
- &mut self,
- id: &'input str,
- typ: ast::Type,
- state_space: ast::StateSpace,
- is_variable: bool,
- ) -> SpirvWord {
- self.get_or_add_impl(id, Some((typ, state_space, is_variable)))
- }
-
- fn get_or_add_impl(
- &mut self,
- id: &'input str,
- typ: Option<(ast::Type, ast::StateSpace, bool)>,
- ) -> SpirvWord {
- let id = match self.variables.entry(Cow::Borrowed(id)) {
- hash_map::Entry::Occupied(e) => *(e.get()),
- hash_map::Entry::Vacant(e) => {
- let numeric_id = self.current_id;
- e.insert(numeric_id);
- self.reverse_variables.insert(numeric_id, id);
- self.current_id.0 += 1;
- numeric_id
- }
- };
- self.variables_type_check.insert(id, typ);
- id
- }
-
- fn get_id(&self, id: &str) -> Result<SpirvWord, TranslateError> {
- self.variables
- .get(id)
- .copied()
- .ok_or_else(error_unknown_symbol)
+impl Module {
+ pub fn linked_bitcode(&self) -> &[u8] {
+ ZLUDA_PTX_IMPL
}
-
- fn current_id(&self) -> SpirvWord {
- self.current_id
- }
-
- fn start_fn<'b>(
- &'b mut self,
- header: &'b ast::MethodDeclaration<'input, &'input str>,
- ) -> Result<
- (
- FnStringIdResolver<'input, 'b>,
- GlobalFnDeclResolver<'input, 'b>,
- Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
- ),
- TranslateError,
- > {
- // In case a function decl was inserted earlier we want to use its id
- let name_id = self.get_or_add_def(header.name());
- let mut fn_resolver = FnStringIdResolver {
- current_id: &mut self.current_id,
- global_variables: &self.variables,
- global_type_check: &self.variables_type_check,
- special_registers: &mut self.special_registers,
- variables: vec![HashMap::new(); 1],
- type_check: HashMap::new(),
- };
- let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments);
- let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments);
- let name = match header.name {
- ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
- ast::MethodName::Func(_) => ast::MethodName::Func(name_id),
- };
- let fn_decl = ast::MethodDeclaration {
- return_arguments,
- name,
- input_arguments,
- shared_mem: None,
- };
- let new_fn_decl = if !matches!(fn_decl.name, ast::MethodName::Kernel(_)) {
- let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl);
- let new_fn_decl = resolver.func_decl.clone();
- self.fns.insert(name_id, resolver);
- new_fn_decl
- } else {
- Rc::new(RefCell::new(fn_decl))
- };
- Ok((
- fn_resolver,
- GlobalFnDeclResolver { fns: &self.fns },
- new_fn_decl,
- ))
- }
-}
-
-fn rename_fn_params<'a, 'b>(
- fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: &'b [ast::Variable<&'a str>],
-) -> Vec<ast::Variable<SpirvWord>> {
- args.iter()
- .map(|a| ast::Variable {
- name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true),
- v_type: a.v_type.clone(),
- state_space: a.state_space,
- align: a.align,
- array_init: a.array_init.clone(),
- })
- .collect()
}
pub struct KernelInfo {
@@ -365,18 +85,6 @@ enum PtxSpecialRegister { }
impl PtxSpecialRegister {
- fn try_parse(s: &str) -> Option<Self> {
- match s {
- "%tid" => Some(Self::Tid),
- "%ntid" => Some(Self::Ntid),
- "%ctaid" => Some(Self::Ctaid),
- "%nctaid" => Some(Self::Nctaid),
- "%clock" => Some(Self::Clock),
- "%lanemask_lt" => Some(Self::LanemaskLt),
- _ => None,
- }
- }
-
fn as_str(self) -> &'static str {
match self {
Self::Tid => "%tid",
@@ -431,216 +139,24 @@ impl PtxSpecialRegister { }
}
-struct SpecialRegistersMap {
- reg_to_id: HashMap<PtxSpecialRegister, SpirvWord>,
- id_to_reg: HashMap<SpirvWord, PtxSpecialRegister>,
-}
-
-impl SpecialRegistersMap {
- fn new() -> Self {
- SpecialRegistersMap {
- reg_to_id: HashMap::new(),
- id_to_reg: HashMap::new(),
- }
- }
-
- fn get(&self, id: SpirvWord) -> Option<PtxSpecialRegister> {
- self.id_to_reg.get(&id).copied()
- }
-
- fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
- match self.reg_to_id.entry(reg) {
- hash_map::Entry::Occupied(e) => *e.get(),
- hash_map::Entry::Vacant(e) => {
- let numeric_id = SpirvWord(current_id.0);
- current_id.0 += 1;
- e.insert(numeric_id);
- self.id_to_reg.insert(numeric_id, reg);
- numeric_id
- }
- }
- }
-}
-
-struct FnStringIdResolver<'input, 'b> {
- current_id: &'b mut SpirvWord,
- global_variables: &'b HashMap<Cow<'input, str>, SpirvWord>,
- global_type_check: &'b HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
- special_registers: &'b mut SpecialRegistersMap,
- variables: Vec<HashMap<Cow<'input, str>, SpirvWord>>,
- type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
-}
-
-impl<'a, 'b> FnStringIdResolver<'a, 'b> {
- fn finish(self) -> NumericIdResolver<'b> {
- NumericIdResolver {
- current_id: self.current_id,
- global_type_check: self.global_type_check,
- type_check: self.type_check,
- special_registers: self.special_registers,
- }
- }
-
- fn start_block(&mut self) {
- self.variables.push(HashMap::new())
- }
-
- fn end_block(&mut self) {
- self.variables.pop();
- }
-
- fn get_id(&mut self, id: &str) -> Result<SpirvWord, TranslateError> {
- for scope in self.variables.iter().rev() {
- match scope.get(id) {
- Some(id) => return Ok(*id),
- None => continue,
- }
- }
- match self.global_variables.get(id) {
- Some(id) => Ok(*id),
- None => {
- let sreg = PtxSpecialRegister::try_parse(id).ok_or_else(error_unknown_symbol)?;
- Ok(self.special_registers.get_or_add(self.current_id, sreg))
- }
- }
- }
-
- fn add_def(
- &mut self,
- id: &'a str,
- typ: Option<(ast::Type, ast::StateSpace)>,
- is_variable: bool,
- ) -> SpirvWord {
- let numeric_id = *self.current_id;
- self.variables
- .last_mut()
- .unwrap()
- .insert(Cow::Borrowed(id), numeric_id);
- self.type_check.insert(
- numeric_id,
- typ.map(|(typ, space)| (typ, space, is_variable)),
- );
- self.current_id.0 += 1;
- numeric_id
- }
-
- #[must_use]
- fn add_defs(
- &mut self,
- base_id: &'a str,
- count: u32,
- typ: ast::Type,
- state_space: ast::StateSpace,
- is_variable: bool,
- ) -> impl Iterator<Item = SpirvWord> {
- let numeric_id = *self.current_id;
- for i in 0..count {
- self.variables.last_mut().unwrap().insert(
- Cow::Owned(format!("{}{}", base_id, i)),
- SpirvWord(numeric_id.0 + i),
- );
- self.type_check.insert(
- SpirvWord(numeric_id.0 + i),
- Some((typ.clone(), state_space, is_variable)),
- );
- }
- self.current_id.0 += count;
- (0..count)
- .into_iter()
- .map(move |i| SpirvWord(i + numeric_id.0))
- }
-}
-
-struct NumericIdResolver<'b> {
- current_id: &'b mut SpirvWord,
- global_type_check: &'b HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
- type_check: HashMap<SpirvWord, Option<(ast::Type, ast::StateSpace, bool)>>,
- special_registers: &'b mut SpecialRegistersMap,
-}
-
-impl<'b> NumericIdResolver<'b> {
- fn finish(self) -> MutableNumericIdResolver<'b> {
- MutableNumericIdResolver { base: self }
- }
-
- fn get_typed(
- &self,
- id: SpirvWord,
- ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> {
- match self.type_check.get(&id) {
- Some(Some(x)) => Ok(x.clone()),
- Some(None) => Err(TranslateError::UntypedSymbol),
- None => match self.special_registers.get(id) {
- Some(x) => Ok((x.get_type(), ast::StateSpace::Reg, true)),
- None => match self.global_type_check.get(&id) {
- Some(Some(result)) => Ok(result.clone()),
- Some(None) | None => Err(TranslateError::UntypedSymbol),
- },
- },
- }
- }
-
- // This is for identifiers which will be emitted later as OpVariable
- // They are candidates for insertion of LoadVar/StoreVar
- fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord {
- let new_id = *self.current_id;
- self.type_check
- .insert(new_id, Some((typ, state_space, true)));
- self.current_id.0 += 1;
- new_id
- }
-
- fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
- let new_id = *self.current_id;
- self.type_check
- .insert(new_id, typ.map(|(t, space)| (t, space, false)));
- self.current_id.0 += 1;
- new_id
- }
-}
-
-struct MutableNumericIdResolver<'b> {
- base: NumericIdResolver<'b>,
-}
-
-impl<'b> MutableNumericIdResolver<'b> {
- fn unmut(self) -> NumericIdResolver<'b> {
- self.base
- }
-
- fn get_typed(&self, id: SpirvWord) -> Result<(ast::Type, ast::StateSpace), TranslateError> {
- self.base.get_typed(id).map(|(t, space, _)| (t, space))
- }
-
- fn register_intermediate(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord {
- self.base.register_intermediate(Some((typ, state_space)))
- }
+#[cfg(debug_assertions)]
+fn error_unreachable() -> TranslateError {
+ unreachable!()
}
-quick_error! {
- #[derive(Debug)]
- pub enum TranslateError {
- UnknownSymbol {}
- UntypedSymbol {}
- MismatchedType {}
- Spirv(err: rspirv::dr::Error) {
- from()
- display("{}", err)
- cause(err)
- }
- Unreachable {}
- Todo {}
- }
+#[cfg(not(debug_assertions))]
+fn error_unreachable() -> TranslateError {
+ TranslateError::Unreachable
}
#[cfg(debug_assertions)]
-fn error_unreachable() -> TranslateError {
+fn error_todo() -> TranslateError {
unreachable!()
}
#[cfg(not(debug_assertions))]
-fn error_unreachable() -> TranslateError {
- TranslateError::Unreachable
+fn error_todo() -> TranslateError {
+ TranslateError::Todo
}
#[cfg(debug_assertions)]
@@ -663,112 +179,20 @@ fn error_mismatched_type() -> TranslateError { TranslateError::MismatchedType
}
-pub struct GlobalFnDeclResolver<'input, 'a> {
- fns: &'a HashMap<SpirvWord, FnSigMapper<'input>>,
-}
-
-impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
- fn get_fn_sig_resolver(&self, id: SpirvWord) -> Result<&FnSigMapper<'input>, TranslateError> {
- self.fns.get(&id).ok_or_else(error_unknown_symbol)
- }
-}
-
-struct FnSigMapper<'input> {
- // true - stays as return argument
- // false - is moved to input argument
- return_param_args: Vec<bool>,
- func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
-}
-
-impl<'input> FnSigMapper<'input> {
- fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, SpirvWord>) -> Self {
- let return_param_args = method
- .return_arguments
- .iter()
- .map(|a| a.state_space != ast::StateSpace::Param)
- .collect::<Vec<_>>();
- let mut new_return_arguments = Vec::new();
- for arg in method.return_arguments.into_iter() {
- if arg.state_space == ast::StateSpace::Param {
- method.input_arguments.push(arg);
- } else {
- new_return_arguments.push(arg);
- }
- }
- method.return_arguments = new_return_arguments;
- FnSigMapper {
- return_param_args,
- func_decl: Rc::new(RefCell::new(method)),
- }
- }
-
- fn resolve_in_spirv_repr(
- &self,
- data: ast::CallDetails,
- arguments: ast::CallArgs<ast::ParsedOperand<SpirvWord>>,
- ) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
- let func_decl = (*self.func_decl).borrow();
- let mut data_return = Vec::new();
- let mut arguments_return = Vec::new();
- let mut data_input = data.input_arguments;
- let mut arguments_input = arguments.input_arguments;
- let mut func_decl_return_iter = func_decl.return_arguments.iter();
- let mut func_decl_input_iter = func_decl.input_arguments[arguments_input.len()..].iter();
- for (idx, id) in arguments.return_arguments.iter().enumerate() {
- let stays_as_return = match self.return_param_args.get(idx) {
- Some(x) => *x,
- None => return Err(TranslateError::MismatchedType),
- };
- if stays_as_return {
- if let Some(var) = func_decl_return_iter.next() {
- data_return.push((var.v_type.clone(), var.state_space));
- arguments_return.push(*id);
- } else {
- return Err(TranslateError::MismatchedType);
- }
- } else {
- if let Some(var) = func_decl_input_iter.next() {
- data_input.push((var.v_type.clone(), var.state_space));
- arguments_input.push(ast::ParsedOperand::Reg(*id));
- } else {
- return Err(TranslateError::MismatchedType);
- }
- }
- }
- if arguments_return.len() != func_decl.return_arguments.len()
- || arguments_input.len() != func_decl.input_arguments.len()
- {
- return Err(TranslateError::MismatchedType);
- }
- let data = ast::CallDetails {
- uniform: data.uniform,
- return_arguments: data_return,
- input_arguments: data_input,
- };
- let arguments = ast::CallArgs {
- func: arguments.func,
- return_arguments: arguments_return,
- input_arguments: arguments_input,
- };
- Ok(ast::Instruction::Call { data, arguments })
- }
-}
-
enum Statement<I, P: ast::Operand> {
Label(SpirvWord),
Variable(ast::Variable<P::Ident>),
Instruction(I),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
- LoadVar(LoadVarDetails),
- StoreVar(StoreVarDetails),
Conversion(ImplicitConversion),
Constant(ConstantDefinition),
- RetValue(ast::RetData, SpirvWord),
+ RetValue(ast::RetData, Vec<(SpirvWord, ast::Type)>),
PtrAccess(PtrAccess<P>),
RepackVector(RepackVectorDetails),
FunctionPointer(FunctionPointerDetails),
- VectorAccess(VectorAccess),
+ VectorRead(VectorRead),
+ VectorWrite(VectorWrite),
}
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
@@ -813,52 +237,6 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> { if_false,
})
}
- Statement::LoadVar(LoadVarDetails {
- arg,
- typ,
- member_index,
- }) => {
- let dst = visitor.visit_ident(
- arg.dst,
- Some((&typ, ast::StateSpace::Reg)),
- true,
- false,
- )?;
- let src = visitor.visit_ident(
- arg.src,
- Some((&typ, ast::StateSpace::Local)),
- false,
- false,
- )?;
- Statement::LoadVar(LoadVarDetails {
- arg: ast::LdArgs { dst, src },
- typ,
- member_index,
- })
- }
- Statement::StoreVar(StoreVarDetails {
- arg,
- typ,
- member_index,
- }) => {
- let src1 = visitor.visit_ident(
- arg.src1,
- Some((&typ, ast::StateSpace::Local)),
- false,
- false,
- )?;
- let src2 = visitor.visit_ident(
- arg.src2,
- Some((&typ, ast::StateSpace::Reg)),
- false,
- false,
- )?;
- Statement::StoreVar(StoreVarDetails {
- arg: ast::StArgs { src1, src2 },
- typ,
- member_index,
- })
- }
Statement::Conversion(ImplicitConversion {
src,
dst,
@@ -900,9 +278,20 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> { Statement::Constant(ConstantDefinition { dst, typ, value })
}
Statement::RetValue(data, value) => {
- // TODO:
- // We should report type here
- let value = visitor.visit_ident(value, None, false, false)?;
+ let value = value
+ .into_iter()
+ .map(|(ident, type_)| {
+ Ok((
+ visitor.visit_ident(
+ ident,
+ Some((&type_, ast::StateSpace::Local)),
+ false,
+ false,
+ )?,
+ type_,
+ ))
+ })
+ .collect::<Result<Vec<_>, _>>()?;
Statement::RetValue(data, value)
}
Statement::PtrAccess(PtrAccess {
@@ -937,33 +326,69 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> { offset_src,
})
}
- Statement::VectorAccess(VectorAccess {
+ Statement::VectorRead(VectorRead {
scalar_type,
vector_width,
- dst,
- src: vector_src,
+ scalar_dst: dst,
+ vector_src,
member,
}) => {
+ let scalar_t = scalar_type.into();
+ let vector_t = ast::Type::Vector(vector_width, scalar_type);
let dst: SpirvWord = visitor.visit_ident(
dst,
- Some((&scalar_type.into(), ast::StateSpace::Reg)),
+ Some((&scalar_t, ast::StateSpace::Reg)),
true,
false,
)?;
let src = visitor.visit_ident(
vector_src,
- Some((
- &ast::Type::Vector(vector_width, scalar_type),
- ast::StateSpace::Reg,
- )),
+ Some((&vector_t, ast::StateSpace::Reg)),
false,
false,
)?;
- Statement::VectorAccess(VectorAccess {
+ Statement::VectorRead(VectorRead {
+ scalar_type,
+ vector_width,
+ scalar_dst: dst,
+ vector_src: src,
+ member,
+ })
+ }
+ Statement::VectorWrite(VectorWrite {
+ scalar_type,
+ vector_width,
+ vector_dst,
+ vector_src,
+ scalar_src,
+ member,
+ }) => {
+ let scalar_t = scalar_type.into();
+ let vector_t = ast::Type::Vector(vector_width, scalar_type);
+ let vector_dst = visitor.visit_ident(
+ vector_dst,
+ Some((&vector_t, ast::StateSpace::Reg)),
+ true,
+ false,
+ )?;
+ let vector_src = visitor.visit_ident(
+ vector_src,
+ Some((&vector_t, ast::StateSpace::Reg)),
+ false,
+ false,
+ )?;
+ let scalar_src = visitor.visit_ident(
+ scalar_src,
+ Some((&scalar_t, ast::StateSpace::Reg)),
+ false,
+ false,
+ )?;
+ Statement::VectorWrite(VectorWrite {
+ vector_dst,
+ vector_src,
+ scalar_src,
scalar_type,
vector_width,
- dst,
- src,
member,
})
}
@@ -1049,22 +474,6 @@ struct BrachCondition { if_true: SpirvWord,
if_false: SpirvWord,
}
-struct LoadVarDetails {
- arg: ast::LdArgs<SpirvWord>,
- typ: ast::Type,
- // (index, vector_width)
- // HACK ALERT
- // For some reason IGC explodes when you try to load from builtin vectors
- // using OpInBoundsAccessChain, the one true way to do it is to
- // OpLoad+OpCompositeExtract
- member_index: Option<(u8, Option<u8>)>,
-}
-
-struct StoreVarDetails {
- arg: ast::StArgs<SpirvWord>,
- typ: ast::Type,
- member_index: Option<u8>,
-}
#[derive(Clone)]
struct ImplicitConversion {
@@ -1115,14 +524,14 @@ struct FunctionPointerDetails { }
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
-struct SpirvWord(spirv::Word);
+pub struct SpirvWord(u32);
-impl From<spirv::Word> for SpirvWord {
- fn from(value: spirv::Word) -> Self {
+impl From<u32> for SpirvWord {
+ fn from(value: u32) -> Self {
Self(value)
}
}
-impl From<SpirvWord> for spirv::Word {
+impl From<SpirvWord> for u32 {
fn from(value: SpirvWord) -> Self {
value.0
}
@@ -1136,31 +545,6 @@ impl ast::Operand for SpirvWord { }
}
-fn pred_map_variable<U, T, F: FnMut(T) -> Result<U, TranslateError>>(
- this: ast::PredAt<T>,
- f: &mut F,
-) -> Result<ast::PredAt<U>, TranslateError> {
- let new_label = f(this.label)?;
- Ok(ast::PredAt {
- not: this.not,
- label: new_label,
- })
-}
-
-pub(crate) enum Directive<'input> {
- Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
- Method(Function<'input>),
-}
-
-pub(crate) struct Function<'input> {
- pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
- pub globals: Vec<ast::Variable<SpirvWord>>,
- pub body: Option<Vec<ExpandedStatement>>,
- import_as: Option<String>,
- tuning: Vec<ast::TuningDirective>,
- linkage: ast::LinkingDirective,
-}
-
type ExpandedStatement = Statement<ast::Instruction<SpirvWord>, SpirvWord>;
type NormalizedStatement = Statement<
@@ -1171,577 +555,12 @@ type NormalizedStatement = Statement< ast::ParsedOperand<SpirvWord>,
>;
-type UnconditionalStatement =
- Statement<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;
-
-type TypedStatement = Statement<ast::Instruction<TypedOperand>, TypedOperand>;
-
-#[derive(Copy, Clone)]
-enum TypedOperand {
- Reg(SpirvWord),
- RegOffset(SpirvWord, i32),
- Imm(ast::ImmediateValue),
- VecMember(SpirvWord, u8),
-}
-
-impl TypedOperand {
- fn map<Err>(
- self,
- fn_: impl FnOnce(SpirvWord, Option<u8>) -> Result<SpirvWord, Err>,
- ) -> Result<Self, Err> {
- Ok(match self {
- TypedOperand::Reg(reg) => TypedOperand::Reg(fn_(reg, None)?),
- TypedOperand::RegOffset(reg, off) => TypedOperand::RegOffset(fn_(reg, None)?, off),
- TypedOperand::Imm(imm) => TypedOperand::Imm(imm),
- TypedOperand::VecMember(reg, idx) => TypedOperand::VecMember(fn_(reg, Some(idx))?, idx),
- })
- }
-
- fn underlying_register(&self) -> Option<SpirvWord> {
- match self {
- Self::Reg(r) | Self::RegOffset(r, _) | Self::VecMember(r, _) => Some(*r),
- Self::Imm(_) => None,
- }
- }
-
- fn unwrap_reg(&self) -> Result<SpirvWord, TranslateError> {
- match self {
- TypedOperand::Reg(reg) => Ok(*reg),
- _ => Err(error_unreachable()),
- }
- }
-}
-
-impl ast::Operand for TypedOperand {
- type Ident = SpirvWord;
-
- fn from_ident(ident: Self::Ident) -> Self {
- TypedOperand::Reg(ident)
- }
-}
-
-impl<Fn> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
- for FnVisitor<TypedOperand, TypedOperand, TranslateError, Fn>
-where
- Fn: FnMut(
- TypedOperand,
- Option<(&ast::Type, ast::StateSpace)>,
- bool,
- bool,
- ) -> Result<TypedOperand, TranslateError>,
-{
- fn visit(
- &mut self,
- args: TypedOperand,
- type_space: Option<(&ast::Type, ast::StateSpace)>,
- is_dst: bool,
- relaxed_type_check: bool,
- ) -> Result<TypedOperand, TranslateError> {
- (self.fn_)(args, type_space, is_dst, relaxed_type_check)
- }
-
- fn visit_ident(
- &mut self,
- args: SpirvWord,
- type_space: Option<(&ast::Type, ast::StateSpace)>,
- is_dst: bool,
- relaxed_type_check: bool,
- ) -> Result<SpirvWord, TranslateError> {
- match (self.fn_)(
- TypedOperand::Reg(args),
- type_space,
- is_dst,
- relaxed_type_check,
- )? {
- TypedOperand::Reg(reg) => Ok(reg),
- _ => Err(TranslateError::Unreachable),
- }
- }
-}
-
-struct FnVisitor<
- T,
- U,
- Err,
- Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result<U, Err>,
-> {
- fn_: Fn,
- _marker: PhantomData<fn(T) -> Result<U, Err>>,
-}
-
-impl<
- T,
- U,
- Err,
- Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result<U, Err>,
- > FnVisitor<T, U, Err, Fn>
-{
- fn new(fn_: Fn) -> Self {
- Self {
- fn_,
- _marker: PhantomData,
- }
- }
-}
-
-fn register_external_fn_call<'a>(
- id_defs: &mut NumericIdResolver,
- ptx_impl_imports: &mut HashMap<String, Directive>,
- name: String,
- return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
- input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
-) -> Result<SpirvWord, TranslateError> {
- match ptx_impl_imports.entry(name) {
- hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.register_intermediate(None);
- let return_arguments = fn_arguments_to_variables(id_defs, return_arguments);
- let input_arguments = fn_arguments_to_variables(id_defs, input_arguments);
- let func_decl = ast::MethodDeclaration::<SpirvWord> {
- return_arguments,
- name: ast::MethodName::Func(fn_id),
- input_arguments,
- shared_mem: None,
- };
- let func = Function {
- func_decl: Rc::new(RefCell::new(func_decl)),
- globals: Vec::new(),
- body: None,
- import_as: Some(entry.key().clone()),
- tuning: Vec::new(),
- linkage: ast::LinkingDirective::EXTERN,
- };
- entry.insert(Directive::Method(func));
- Ok(fn_id)
- }
- hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
- ast::MethodName::Func(fn_id) => Ok(fn_id),
- ast::MethodName::Kernel(_) => Err(error_unreachable()),
- },
- _ => Err(error_unreachable()),
- },
- }
-}
-
-fn fn_arguments_to_variables<'a>(
- id_defs: &mut NumericIdResolver,
- args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
-) -> Vec<ast::Variable<SpirvWord>> {
- args.map(|(typ, space)| ast::Variable {
- align: None,
- v_type: typ.clone(),
- state_space: space,
- name: id_defs.register_intermediate(None),
- array_init: Vec::new(),
- })
- .collect::<Vec<_>>()
-}
-
-fn hoist_function_globals(directives: Vec<Directive>) -> Vec<Directive> {
- let mut result = Vec::with_capacity(directives.len());
- for directive in directives {
- match directive {
- Directive::Method(method) => {
- for variable in method.globals {
- result.push(Directive::Variable(ast::LinkingDirective::NONE, variable));
- }
- result.push(Directive::Method(Function {
- globals: Vec::new(),
- ..method
- }))
- }
- _ => result.push(directive),
- }
- }
- result
-}
-
-struct MethodsCallMap<'input> {
- map: HashMap<ast::MethodName<'input, SpirvWord>, HashSet<SpirvWord>>,
-}
-
-impl<'input> MethodsCallMap<'input> {
- fn new(module: &[Directive<'input>]) -> Self {
- let mut directly_called_by = HashMap::new();
- for directive in module {
- match directive {
- Directive::Method(Function {
- func_decl,
- body: Some(statements),
- ..
- }) => {
- let call_key: ast::MethodName<_> = (**func_decl).borrow().name;
- if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
- entry.insert(Vec::new());
- }
- for statement in statements {
- match statement {
- Statement::Instruction(ast::Instruction::Call { data, arguments }) => {
- multi_hash_map_append(
- &mut directly_called_by,
- call_key,
- arguments.func,
- );
- }
- _ => {}
- }
- }
- }
- _ => {}
- }
- }
- let mut result = HashMap::new();
- for (&method_key, children) in directly_called_by.iter() {
- let mut visited = HashSet::new();
- for child in children {
- Self::add_call_map_single(&directly_called_by, &mut visited, *child);
- }
- result.insert(method_key, visited);
- }
- MethodsCallMap { map: result }
- }
-
- fn add_call_map_single(
- directly_called_by: &HashMap<ast::MethodName<'input, SpirvWord>, Vec<SpirvWord>>,
- visited: &mut HashSet<SpirvWord>,
- current: SpirvWord,
- ) {
- if !visited.insert(current) {
- return;
- }
- if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) {
- for child in children {
- Self::add_call_map_single(directly_called_by, visited, *child);
- }
- }
- }
-
- fn get_kernel_children(&self, name: &'input str) -> impl Iterator<Item = &SpirvWord> {
- self.map
- .get(&ast::MethodName::Kernel(name))
- .into_iter()
- .flatten()
- }
-
- fn kernels(&self) -> impl Iterator<Item = (&'input str, &HashSet<SpirvWord>)> {
- self.map
- .iter()
- .filter_map(|(method, children)| match method {
- ast::MethodName::Kernel(kernel) => Some((*kernel, children)),
- ast::MethodName::Func(..) => None,
- })
- }
-
- fn methods(
- &self,
- ) -> impl Iterator<Item = (ast::MethodName<'input, SpirvWord>, &HashSet<SpirvWord>)> {
- self.map
- .iter()
- .map(|(method, children)| (*method, children))
- }
-
- fn visit_callees(&self, method: ast::MethodName<'input, SpirvWord>, f: impl FnMut(SpirvWord)) {
- self.map
- .get(&method)
- .into_iter()
- .flatten()
- .copied()
- .for_each(f);
- }
-}
-
-fn multi_hash_map_append<
- K: Eq + std::hash::Hash,
- V,
- Collection: std::iter::Extend<V> + std::default::Default,
->(
- m: &mut HashMap<K, Collection>,
- key: K,
- value: V,
-) {
- match m.entry(key) {
- hash_map::Entry::Occupied(mut entry) => {
- entry.get_mut().extend(iter::once(value));
- }
- hash_map::Entry::Vacant(entry) => {
- entry.insert(Default::default()).extend(iter::once(value));
- }
- }
-}
-
-fn normalize_variable_decls(directives: &mut Vec<Directive>) {
- for directive in directives {
- match directive {
- Directive::Method(Function {
- body: Some(func), ..
- }) => {
- func[1..].sort_by_key(|s| match s {
- Statement::Variable(_) => 0,
- _ => 1,
- });
- }
- _ => (),
- }
- }
-}
-
-// HACK ALERT!
-// This function is a "good enough" heuristic of whetever to mark f16/f32 operations
-// in the kernel as flushing denorms to zero or preserving them
-// PTX support per-instruction ftz information. Unfortunately SPIR-V has no
-// such capability, so instead we guesstimate which use is more common in the kernel
-// and emit suitable execution mode
-fn compute_denorm_information<'input>(
- module: &[Directive<'input>],
-) -> HashMap<ast::MethodName<'input, SpirvWord>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
- let mut denorm_methods = HashMap::new();
- for directive in module {
- match directive {
- Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {}
- Directive::Method(Function {
- func_decl,
- body: Some(statements),
- ..
- }) => {
- let mut flush_counter = DenormCountMap::new();
- let method_key = (**func_decl).borrow().name;
- for statement in statements {
- match statement {
- Statement::Instruction(inst) => {
- if let Some((flush, width)) = flush_to_zero(inst) {
- denorm_count_map_update(&mut flush_counter, width, flush);
- }
- }
- Statement::LoadVar(..) => {}
- Statement::StoreVar(..) => {}
- Statement::Conditional(_) => {}
- Statement::Conversion(_) => {}
- Statement::Constant(_) => {}
- Statement::RetValue(_, _) => {}
- Statement::Label(_) => {}
- Statement::Variable(_) => {}
- Statement::PtrAccess { .. } => {}
- Statement::VectorAccess { .. } => {}
- Statement::RepackVector(_) => {}
- Statement::FunctionPointer(_) => {}
- }
- }
- denorm_methods.insert(method_key, flush_counter);
- }
- }
- }
- denorm_methods
- .into_iter()
- .map(|(name, v)| {
- let width_to_denorm = v
- .into_iter()
- .map(|(k, flush_over_preserve)| {
- let mode = if flush_over_preserve > 0 {
- spirv::FPDenormMode::FlushToZero
- } else {
- spirv::FPDenormMode::Preserve
- };
- (k, (mode, flush_over_preserve))
- })
- .collect();
- (name, width_to_denorm)
- })
- .collect()
-}
-
-fn flush_to_zero(this: &ast::Instruction<SpirvWord>) -> Option<(bool, u8)> {
- match this {
- ast::Instruction::Ld { .. } => None,
- ast::Instruction::St { .. } => None,
- ast::Instruction::Mov { .. } => None,
- ast::Instruction::Not { .. } => None,
- ast::Instruction::Bra { .. } => None,
- ast::Instruction::Shl { .. } => None,
- ast::Instruction::Shr { .. } => None,
- ast::Instruction::Ret { .. } => None,
- ast::Instruction::Call { .. } => None,
- ast::Instruction::Or { .. } => None,
- ast::Instruction::And { .. } => None,
- ast::Instruction::Cvta { .. } => None,
- ast::Instruction::Selp { .. } => None,
- ast::Instruction::Bar { .. } => None,
- ast::Instruction::Atom { .. } => None,
- ast::Instruction::AtomCas { .. } => None,
- ast::Instruction::Sub {
- data: ast::ArithDetails::Integer(_),
- ..
- } => None,
- ast::Instruction::Add {
- data: ast::ArithDetails::Integer(_),
- ..
- } => None,
- ast::Instruction::Mul {
- data: ast::MulDetails::Integer { .. },
- ..
- } => None,
- ast::Instruction::Mad {
- data: ast::MadDetails::Integer { .. },
- ..
- } => None,
- ast::Instruction::Min {
- data: ast::MinMaxDetails::Signed(_),
- ..
- } => None,
- ast::Instruction::Min {
- data: ast::MinMaxDetails::Unsigned(_),
- ..
- } => None,
- ast::Instruction::Max {
- data: ast::MinMaxDetails::Signed(_),
- ..
- } => None,
- ast::Instruction::Max {
- data: ast::MinMaxDetails::Unsigned(_),
- ..
- } => None,
- ast::Instruction::Cvt {
- data:
- ast::CvtDetails {
- mode:
- ast::CvtMode::ZeroExtend
- | ast::CvtMode::SignExtend
- | ast::CvtMode::Truncate
- | ast::CvtMode::Bitcast
- | ast::CvtMode::SaturateUnsignedToSigned
- | ast::CvtMode::SaturateSignedToUnsigned
- | ast::CvtMode::FPFromSigned(_)
- | ast::CvtMode::FPFromUnsigned(_),
- ..
- },
- ..
- } => None,
- ast::Instruction::Div {
- data: ast::DivDetails::Unsigned(_),
- ..
- } => None,
- ast::Instruction::Div {
- data: ast::DivDetails::Signed(_),
- ..
- } => None,
- ast::Instruction::Clz { .. } => None,
- ast::Instruction::Brev { .. } => None,
- ast::Instruction::Popc { .. } => None,
- ast::Instruction::Xor { .. } => None,
- ast::Instruction::Bfe { .. } => None,
- ast::Instruction::Bfi { .. } => None,
- ast::Instruction::Rem { .. } => None,
- ast::Instruction::Prmt { .. } => None,
- ast::Instruction::Activemask { .. } => None,
- ast::Instruction::Membar { .. } => None,
- ast::Instruction::Sub {
- data: ast::ArithDetails::Float(float_control),
- ..
- }
- | ast::Instruction::Add {
- data: ast::ArithDetails::Float(float_control),
- ..
- }
- | ast::Instruction::Mul {
- data: ast::MulDetails::Float(float_control),
- ..
- }
- | ast::Instruction::Mad {
- data: ast::MadDetails::Float(float_control),
- ..
- } => float_control
- .flush_to_zero
- .map(|ftz| (ftz, float_control.type_.size_of())),
- ast::Instruction::Fma { data, .. } => {
- data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
- }
- ast::Instruction::Setp { data, .. } => {
- data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
- }
- ast::Instruction::SetpBool { data, .. } => data
- .base
- .flush_to_zero
- .map(|ftz| (ftz, data.base.type_.size_of())),
- ast::Instruction::Abs { data, .. }
- | ast::Instruction::Rsqrt { data, .. }
- | ast::Instruction::Neg { data, .. }
- | ast::Instruction::Ex2 { data, .. } => {
- data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
- }
- ast::Instruction::Min {
- data: ast::MinMaxDetails::Float(float_control),
- ..
- }
- | ast::Instruction::Max {
- data: ast::MinMaxDetails::Float(float_control),
- ..
- } => float_control
- .flush_to_zero
- .map(|ftz| (ftz, ast::ScalarType::from(float_control.type_).size_of())),
- ast::Instruction::Sqrt { data, .. } | ast::Instruction::Rcp { data, .. } => {
- data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of()))
- }
- // Modifier .ftz can only be specified when either .dtype or .atype
- // is .f32 and applies only to single precision (.f32) inputs and results.
- ast::Instruction::Cvt {
- data:
- ast::CvtDetails {
- mode:
- ast::CvtMode::FPExtend { flush_to_zero }
- | ast::CvtMode::FPTruncate { flush_to_zero, .. }
- | ast::CvtMode::FPRound { flush_to_zero, .. }
- | ast::CvtMode::SignedFromFP { flush_to_zero, .. }
- | ast::CvtMode::UnsignedFromFP { flush_to_zero, .. },
- ..
- },
- ..
- } => flush_to_zero.map(|ftz| (ftz, 4)),
- ast::Instruction::Div {
- data:
- ast::DivDetails::Float(ast::DivFloatDetails {
- type_,
- flush_to_zero,
- ..
- }),
- ..
- } => flush_to_zero.map(|ftz| (ftz, type_.size_of())),
- ast::Instruction::Sin { data, .. }
- | ast::Instruction::Cos { data, .. }
- | ast::Instruction::Lg2 { data, .. } => {
- Some((data.flush_to_zero, mem::size_of::<f32>() as u8))
- }
- ptx_parser::Instruction::PrmtSlow { .. } => None,
- ptx_parser::Instruction::Trap {} => None,
- }
-}
-
-type DenormCountMap<T> = HashMap<T, isize>;
-
-fn denorm_count_map_update<T: Eq + Hash>(map: &mut DenormCountMap<T>, key: T, value: bool) {
- let num_value = if value { 1 } else { -1 };
- denorm_count_map_update_impl(map, key, num_value);
-}
-
-fn denorm_count_map_update_impl<T: Eq + Hash>(
- map: &mut DenormCountMap<T>,
- key: T,
- num_value: isize,
-) {
- match map.entry(key) {
- hash_map::Entry::Occupied(mut counter) => {
- *(counter.get_mut()) += num_value;
- }
- hash_map::Entry::Vacant(entry) => {
- entry.insert(num_value);
- }
- }
-}
-
-pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> {
+enum Directive2<'input, Instruction, Operand: ast::Operand> {
Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
Method(Function2<'input, Instruction, Operand>),
}
-pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> {
+struct Function2<'input, Instruction, Operand: ast::Operand> {
pub func_decl: ast::MethodDeclaration<'input, SpirvWord>,
pub globals: Vec<ast::Variable<SpirvWord>>,
pub body: Option<Vec<Statement<Instruction, Operand>>>,
@@ -1861,6 +680,41 @@ impl<'input, 'b> ScopedResolver<'input, 'b> { scope.flush(self.flat_resolver);
}
+ fn add_or_get_in_current_scope_untyped(
+ &mut self,
+ name: &'input str,
+ ) -> Result<SpirvWord, TranslateError> {
+ let current_scope = self.scopes.last_mut().unwrap();
+ Ok(
+ match current_scope.name_to_ident.entry(Cow::Borrowed(name)) {
+ hash_map::Entry::Occupied(occupied_entry) => {
+ let ident = *occupied_entry.get();
+ let entry = current_scope
+ .ident_map
+ .get(&ident)
+ .ok_or_else(|| error_unreachable())?;
+ if entry.type_space.is_some() {
+ return Err(error_unknown_symbol());
+ }
+ ident
+ }
+ hash_map::Entry::Vacant(vacant_entry) => {
+ let new_id = self.flat_resolver.current_id;
+ self.flat_resolver.current_id.0 += 1;
+ vacant_entry.insert(new_id);
+ current_scope.ident_map.insert(
+ new_id,
+ IdentEntry {
+ name: Some(Cow::Borrowed(name)),
+ type_space: None,
+ },
+ );
+ new_id
+ }
+ },
+ )
+ }
+
fn add(
&mut self,
name: Cow<'input, str>,
@@ -1949,19 +803,6 @@ impl SpecialRegistersMap2 { self.id_to_reg.get(&id).copied()
}
- fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
- match self.reg_to_id.entry(reg) {
- hash_map::Entry::Occupied(e) => *e.get(),
- hash_map::Entry::Vacant(e) => {
- let numeric_id = SpirvWord(current_id.0);
- current_id.0 += 1;
- e.insert(numeric_id);
- self.id_to_reg.insert(numeric_id, reg);
- numeric_id
- }
- }
- }
-
fn generate_declarations<'a, 'input>(
resolver: &'a mut GlobalStringIdentResolver2<'input>,
) -> impl ExactSizeIterator<
@@ -1975,7 +816,7 @@ impl SpecialRegistersMap2 { let name =
ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
let return_type = sreg.get_function_return_type();
- let input_type = sreg.get_function_return_type();
+ let input_type = sreg.get_function_input_type();
(
sreg,
ast::MethodDeclaration {
@@ -1988,14 +829,17 @@ impl SpecialRegistersMap2 { array_init: Vec::new(),
}],
name: name,
- input_arguments: vec![ast::Variable {
- align: None,
- v_type: input_type.into(),
- state_space: ast::StateSpace::Reg,
- name: resolver
- .register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))),
- array_init: Vec::new(),
- }],
+ input_arguments: input_type
+ .into_iter()
+ .map(|type_| ast::Variable {
+ align: None,
+ v_type: type_.into(),
+ state_space: ast::StateSpace::Reg,
+ name: resolver
+ .register_unnamed(Some((type_.into(), ast::StateSpace::Reg))),
+ array_init: Vec::new(),
+ })
+ .collect::<Vec<_>>(),
shared_mem: None,
},
)
@@ -2003,10 +847,49 @@ impl SpecialRegistersMap2 { }
}
-pub struct VectorAccess {
+pub struct VectorRead {
scalar_type: ast::ScalarType,
vector_width: u8,
- dst: SpirvWord,
- src: SpirvWord,
+ scalar_dst: SpirvWord,
+ vector_src: SpirvWord,
member: u8,
}
+
+pub struct VectorWrite {
+ scalar_type: ast::ScalarType,
+ vector_width: u8,
+ vector_dst: SpirvWord,
+ vector_src: SpirvWord,
+ scalar_src: SpirvWord,
+ member: u8,
+}
+
+fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str {
+ match this {
+ ast::ScalarType::B8 => "b8",
+ ast::ScalarType::B16 => "b16",
+ ast::ScalarType::B32 => "b32",
+ ast::ScalarType::B64 => "b64",
+ ast::ScalarType::B128 => "b128",
+ ast::ScalarType::U8 => "u8",
+ ast::ScalarType::U16 => "u16",
+ ast::ScalarType::U16x2 => "u16x2",
+ ast::ScalarType::U32 => "u32",
+ ast::ScalarType::U64 => "u64",
+ ast::ScalarType::S8 => "s8",
+ ast::ScalarType::S16 => "s16",
+ ast::ScalarType::S16x2 => "s16x2",
+ ast::ScalarType::S32 => "s32",
+ ast::ScalarType::S64 => "s64",
+ ast::ScalarType::F16 => "f16",
+ ast::ScalarType::F16x2 => "f16x2",
+ ast::ScalarType::F32 => "f32",
+ ast::ScalarType::F64 => "f64",
+ ast::ScalarType::BF16 => "bf16",
+ ast::ScalarType::BF16x2 => "bf16x2",
+ ast::ScalarType::Pred => "pred",
+ }
+}
+
+type UnconditionalStatement =
+ Statement<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;
diff --git a/ptx/src/pass/normalize_identifiers.rs b/ptx/src/pass/normalize_identifiers.rs deleted file mode 100644 index b598345..0000000 --- a/ptx/src/pass/normalize_identifiers.rs +++ /dev/null @@ -1,80 +0,0 @@ -use super::*;
-use ptx_parser as ast;
-
-pub(crate) fn run<'input, 'b>(
- id_defs: &mut FnStringIdResolver<'input, 'b>,
- fn_defs: &GlobalFnDeclResolver<'input, 'b>,
- func: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
-) -> Result<Vec<NormalizedStatement>, TranslateError> {
- for s in func.iter() {
- match s {
- ast::Statement::Label(id) => {
- id_defs.add_def(*id, None, false);
- }
- _ => (),
- }
- }
- let mut result = Vec::new();
- for s in func {
- expand_map_variables(id_defs, fn_defs, &mut result, s)?;
- }
- Ok(result)
-}
-
-fn expand_map_variables<'a, 'b>(
- id_defs: &mut FnStringIdResolver<'a, 'b>,
- fn_defs: &GlobalFnDeclResolver<'a, 'b>,
- result: &mut Vec<NormalizedStatement>,
- s: ast::Statement<ast::ParsedOperand<&'a str>>,
-) -> Result<(), TranslateError> {
- match s {
- ast::Statement::Block(block) => {
- id_defs.start_block();
- for s in block {
- expand_map_variables(id_defs, fn_defs, result, s)?;
- }
- id_defs.end_block();
- }
- ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)),
- ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
- p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id)))
- .transpose()?,
- ast::visit_map(i, &mut |id,
- _: Option<(&ast::Type, ast::StateSpace)>,
- _: bool,
- _: bool| {
- id_defs.get_id(id)
- })?,
- ))),
- ast::Statement::Variable(var) => {
- let var_type = var.var.v_type.clone();
- match var.count {
- Some(count) => {
- for new_id in
- id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true)
- {
- result.push(Statement::Variable(ast::Variable {
- align: var.var.align,
- v_type: var.var.v_type.clone(),
- state_space: var.var.state_space,
- name: new_id,
- array_init: var.var.array_init.clone(),
- }))
- }
- }
- None => {
- let new_id =
- id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true);
- result.push(Statement::Variable(ast::Variable {
- align: var.var.align,
- v_type: var.var.v_type.clone(),
- state_space: var.var.state_space,
- name: new_id,
- array_init: var.var.array_init,
- }));
- }
- }
- }
- };
- Ok(())
-}
diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs index beaf08b..5155886 100644 --- a/ptx/src/pass/normalize_identifiers2.rs +++ b/ptx/src/pass/normalize_identifiers2.rs @@ -1,6 +1,5 @@ use super::*;
use ptx_parser as ast;
-use rustc_hash::FxHashMap;
pub(crate) fn run<'input, 'b>(
resolver: &mut ScopedResolver<'input, 'b>,
@@ -37,7 +36,7 @@ fn run_method<'input, 'b>( let name = match method.func_directive.name {
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
ast::MethodName::Func(text) => {
- ast::MethodName::Func(resolver.add(Cow::Borrowed(text), None)?)
+ ast::MethodName::Func(resolver.add_or_get_in_current_scope_untyped(text)?)
}
};
resolver.start_scope();
diff --git a/ptx/src/pass/normalize_labels.rs b/ptx/src/pass/normalize_labels.rs deleted file mode 100644 index 037e918..0000000 --- a/ptx/src/pass/normalize_labels.rs +++ /dev/null @@ -1,49 +0,0 @@ -use std::{collections::HashSet, iter};
-
-use super::*;
-
-pub(super) fn run(
- func: Vec<ExpandedStatement>,
- id_def: &mut NumericIdResolver,
-) -> Vec<ExpandedStatement> {
- let mut labels_in_use = HashSet::new();
- for s in func.iter() {
- match s {
- Statement::Instruction(i) => {
- if let Some(target) = jump_target(i) {
- labels_in_use.insert(target);
- }
- }
- Statement::Conditional(cond) => {
- labels_in_use.insert(cond.if_true);
- labels_in_use.insert(cond.if_false);
- }
- Statement::Variable(..)
- | Statement::LoadVar(..)
- | Statement::StoreVar(..)
- | Statement::RetValue(..)
- | Statement::Conversion(..)
- | Statement::Constant(..)
- | Statement::Label(..)
- | Statement::PtrAccess { .. }
- | Statement::VectorAccess { .. }
- | Statement::RepackVector(..)
- | Statement::FunctionPointer(..) => {}
- }
- }
- iter::once(Statement::Label(id_def.register_intermediate(None)))
- .chain(func.into_iter().filter(|s| match s {
- Statement::Label(i) => labels_in_use.contains(i),
- _ => true,
- }))
- .collect::<Vec<_>>()
-}
-
-fn jump_target<T: ast::Operand<Ident = SpirvWord>>(
- this: &ast::Instruction<T>,
-) -> Option<SpirvWord> {
- match this {
- ast::Instruction::Bra { arguments } => Some(arguments.src),
- _ => None,
- }
-}
diff --git a/ptx/src/pass/normalize_predicates.rs b/ptx/src/pass/normalize_predicates.rs deleted file mode 100644 index c971cfa..0000000 --- a/ptx/src/pass/normalize_predicates.rs +++ /dev/null @@ -1,44 +0,0 @@ -use super::*;
-use ptx_parser as ast;
-
-pub(crate) fn run(
- func: Vec<NormalizedStatement>,
- id_def: &mut NumericIdResolver,
-) -> Result<Vec<UnconditionalStatement>, TranslateError> {
- let mut result = Vec::with_capacity(func.len());
- for s in func {
- match s {
- Statement::Label(id) => result.push(Statement::Label(id)),
- Statement::Instruction((pred, inst)) => {
- if let Some(pred) = pred {
- let if_true = id_def.register_intermediate(None);
- let if_false = id_def.register_intermediate(None);
- let folded_bra = match &inst {
- ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
- _ => None,
- };
- let mut branch = BrachCondition {
- predicate: pred.label,
- if_true: folded_bra.unwrap_or(if_true),
- if_false,
- };
- if pred.not {
- std::mem::swap(&mut branch.if_true, &mut branch.if_false);
- }
- result.push(Statement::Conditional(branch));
- if folded_bra.is_none() {
- result.push(Statement::Label(if_true));
- result.push(Statement::Instruction(inst));
- }
- result.push(Statement::Label(if_false));
- } else {
- result.push(Statement::Instruction(inst));
- }
- }
- Statement::Variable(var) => result.push(Statement::Variable(var)),
- // Blocks are flattened when resolving ids
- _ => return Err(error_unreachable()),
- }
- }
- Ok(result)
-}
diff --git a/ptx/src/pass/replace_instructions_with_function_calls.rs b/ptx/src/pass/replace_instructions_with_function_calls.rs new file mode 100644 index 0000000..70d77d3 --- /dev/null +++ b/ptx/src/pass/replace_instructions_with_function_calls.rs @@ -0,0 +1,187 @@ +use super::*;
+
+pub(super) fn run<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
+) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+ let mut fn_declarations = FxHashMap::default();
+ let remapped_directives = directives
+ .into_iter()
+ .map(|directive| run_directive(resolver, &mut fn_declarations, directive))
+ .collect::<Result<Vec<_>, _>>()?;
+ let mut result = fn_declarations
+ .into_iter()
+ .map(|(_, (return_arguments, name, input_arguments))| {
+ Directive2::Method(Function2 {
+ func_decl: ast::MethodDeclaration {
+ return_arguments,
+ name: ast::MethodName::Func(name),
+ input_arguments,
+ shared_mem: None,
+ },
+ globals: Vec::new(),
+ body: None,
+ import_as: None,
+ tuning: Vec::new(),
+ linkage: ast::LinkingDirective::EXTERN,
+ })
+ })
+ .collect::<Vec<_>>();
+ result.extend(remapped_directives);
+ Ok(result)
+}
+
+fn run_directive<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ fn_declarations: &mut FxHashMap<
+ Cow<'input, str>,
+ (
+ Vec<ast::Variable<SpirvWord>>,
+ SpirvWord,
+ Vec<ast::Variable<SpirvWord>>,
+ ),
+ >,
+ directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+ Ok(match directive {
+ var @ Directive2::Variable(..) => var,
+ Directive2::Method(mut method) => {
+ method.body = method
+ .body
+ .map(|statements| run_statements(resolver, fn_declarations, statements))
+ .transpose()?;
+ Directive2::Method(method)
+ }
+ })
+}
+
+fn run_statements<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ fn_declarations: &mut FxHashMap<
+ Cow<'input, str>,
+ (
+ Vec<ast::Variable<SpirvWord>>,
+ SpirvWord,
+ Vec<ast::Variable<SpirvWord>>,
+ ),
+ >,
+ statements: Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
+) -> Result<Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+ statements
+ .into_iter()
+ .map(|statement| {
+ Ok(match statement {
+ Statement::Instruction(instruction) => {
+ Statement::Instruction(run_instruction(resolver, fn_declarations, instruction)?)
+ }
+ s => s,
+ })
+ })
+ .collect::<Result<Vec<_>, _>>()
+}
+
+fn run_instruction<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ fn_declarations: &mut FxHashMap<
+ Cow<'input, str>,
+ (
+ Vec<ast::Variable<SpirvWord>>,
+ SpirvWord,
+ Vec<ast::Variable<SpirvWord>>,
+ ),
+ >,
+ instruction: ptx_parser::Instruction<SpirvWord>,
+) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
+ Ok(match instruction {
+ i @ ptx_parser::Instruction::Activemask { .. } => {
+ to_call(resolver, fn_declarations, "activemask".into(), i)?
+ }
+ i @ ptx_parser::Instruction::Bfe { data, .. } => {
+ let name = ["bfe_", scalar_to_ptx_name(data)].concat();
+ to_call(resolver, fn_declarations, name.into(), i)?
+ }
+ i @ ptx_parser::Instruction::Bfi { data, .. } => {
+ let name = ["bfi_", scalar_to_ptx_name(data)].concat();
+ to_call(resolver, fn_declarations, name.into(), i)?
+ }
+ i => i,
+ })
+}
+
+fn to_call<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ fn_declarations: &mut FxHashMap<
+ Cow<'input, str>,
+ (
+ Vec<ast::Variable<SpirvWord>>,
+ SpirvWord,
+ Vec<ast::Variable<SpirvWord>>,
+ ),
+ >,
+ name: Cow<'input, str>,
+ i: ast::Instruction<SpirvWord>,
+) -> Result<ptx_parser::Instruction<SpirvWord>, TranslateError> {
+ let mut data_return = Vec::new();
+ let mut data_input = Vec::new();
+ let mut arguments_return = Vec::new();
+ let mut arguments_input = Vec::new();
+ ast::visit(&i, &mut |name: &SpirvWord,
+ type_space: Option<(
+ &ptx_parser::Type,
+ ptx_parser::StateSpace,
+ )>,
+ is_dst: bool,
+ _: bool| {
+ let (type_, space) = type_space.ok_or_else(error_mismatched_type)?;
+ if is_dst {
+ data_return.push((type_.clone(), space));
+ arguments_return.push(*name);
+ } else {
+ data_input.push((type_.clone(), space));
+ arguments_input.push(*name);
+ };
+ Ok::<_, TranslateError>(())
+ })?;
+ let fn_name = match fn_declarations.entry(name) {
+ hash_map::Entry::Occupied(occupied_entry) => occupied_entry.get().1,
+ hash_map::Entry::Vacant(vacant_entry) => {
+ let name = vacant_entry.key().clone();
+ let full_name = [ZLUDA_PTX_PREFIX, &*name].concat();
+ let name = resolver.register_named(Cow::Owned(full_name.clone()), None);
+ vacant_entry.insert((
+ to_variables(resolver, &data_return),
+ name,
+ to_variables(resolver, &data_input),
+ ));
+ name
+ }
+ };
+ Ok(ast::Instruction::Call {
+ data: ptx_parser::CallDetails {
+ uniform: false,
+ return_arguments: data_return,
+ input_arguments: data_input,
+ },
+ arguments: ptx_parser::CallArgs {
+ return_arguments: arguments_return,
+ func: fn_name,
+ input_arguments: arguments_input,
+ },
+ })
+}
+
+fn to_variables<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ arguments: &Vec<(ptx_parser::Type, ptx_parser::StateSpace)>,
+) -> Vec<ptx_parser::Variable<SpirvWord>> {
+ arguments
+ .iter()
+ .map(|(type_, space)| ast::Variable {
+ align: None,
+ v_type: type_.clone(),
+ state_space: *space,
+ name: resolver.register_unnamed(Some((type_.clone(), *space))),
+ array_init: Vec::new(),
+ })
+ .collect::<Vec<_>>()
+}
|