// We use Raw LLVM-C bindings here because using inkwell is just not worth it. // Specifically the issue is with builder functions. We maintain the mapping // between ZLUDA identifiers and LLVM values. When using inkwell, LLVM values // are kept as instances `AnyValueEnum`. Now look at the signature of // `Builder::build_int_add(...)`: // pub fn build_int_add>(&self, lhs: T, rhs: T, name: &str, ) -> Result // At this point both lhs and rhs are `AnyValueEnum`. To call // `build_int_add(...)` we would have to do something like this: // if let (Ok(lhs), Ok(rhs)) = (lhs.as_int(), rhs.as_int()) { // builder.build_int_add(lhs, rhs, dst)?; // } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_pointer(), rhs.as_pointer()) { // builder.build_int_add(lhs, rhs, dst)?; // } else if let (Ok(lhs), Ok(rhs)) = (lhs.as_vector(), rhs.as_vector()) { // builder.build_int_add(lhs, rhs, dst)?; // } else { // return Err(error_unrachable()); // } // while with plain LLVM-C it's just: // unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) }; // AMDGPU LLVM backend support for llvm.experimental.constrained.* is incomplete. // Emitting @llvm.experimental.constrained.fdiv.f32(...) makes LLVm fail with // "LLVM ERROR: unsupported libcall legalization". Running with "-mllvm -print-before-all" // shows it fails inside amdgpu-isel. 
You can get a little bit furthr with "-mllvm -global-isel", // but it will too fail similarly, but with "unable to legalize instruction" use std::array::TryFromSliceError; use std::convert::TryInto; use std::ffi::{CStr, NulError}; use std::ops::Deref; use std::{i8, ptr}; use super::*; use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule}; use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer; use llvm_zluda::{core::*, *}; use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW}; use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca}; const LLVM_UNNAMED: &CStr = c""; // https://llvm.org/docs/AMDGPUUsage.html#address-spaces const GENERIC_ADDRESS_SPACE: u32 = 0; const GLOBAL_ADDRESS_SPACE: u32 = 1; const SHARED_ADDRESS_SPACE: u32 = 3; const CONSTANT_ADDRESS_SPACE: u32 = 4; const PRIVATE_ADDRESS_SPACE: u32 = 5; struct Context(LLVMContextRef); impl Context { fn new() -> Self { Self(unsafe { LLVMContextCreate() }) } fn get(&self) -> LLVMContextRef { self.0 } } impl Drop for Context { fn drop(&mut self) { unsafe { LLVMContextDispose(self.0); } } } struct Module(LLVMModuleRef); impl Module { fn new(ctx: &Context, name: &CStr) -> Self { Self(unsafe { LLVMModuleCreateWithNameInContext(name.as_ptr(), ctx.get()) }) } fn get(&self) -> LLVMModuleRef { self.0 } fn verify(&self) -> Result<(), Message> { let mut err = ptr::null_mut(); let error = unsafe { LLVMVerifyModule( self.get(), LLVMVerifierFailureAction::LLVMReturnStatusAction, &mut err, ) }; if error == 1 && err != ptr::null_mut() { Err(Message(unsafe { CStr::from_ptr(err) })) } else { Ok(()) } } fn write_bitcode_to_memory(&self) -> MemoryBuffer { let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) }; MemoryBuffer(memory_buffer) } fn write_to_stderr(&self) { unsafe { LLVMDumpModule(self.get()) }; } } impl Drop for Module { fn drop(&mut self) { unsafe { LLVMDisposeModule(self.0); } } } struct Builder(LLVMBuilderRef); impl Builder { fn new(ctx: &Context) -> Self { Self::new_raw(ctx.get()) 
} fn new_raw(ctx: LLVMContextRef) -> Self { Self(unsafe { LLVMCreateBuilderInContext(ctx) }) } fn get(&self) -> LLVMBuilderRef { self.0 } } impl Drop for Builder { fn drop(&mut self) { unsafe { LLVMDisposeBuilder(self.0); } } } struct Message(&'static CStr); impl Drop for Message { fn drop(&mut self) { unsafe { LLVMDisposeMessage(self.0.as_ptr().cast_mut()); } } } impl std::fmt::Debug for Message { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Debug::fmt(&self.0, f) } } pub struct MemoryBuffer(LLVMMemoryBufferRef); impl Drop for MemoryBuffer { fn drop(&mut self) { unsafe { LLVMDisposeMemoryBuffer(self.0); } } } impl Deref for MemoryBuffer { type Target = [u8]; fn deref(&self) -> &Self::Target { let data = unsafe { LLVMGetBufferStart(self.0) }; let len = unsafe { LLVMGetBufferSize(self.0) }; unsafe { std::slice::from_raw_parts(data.cast(), len) } } } pub(super) fn run<'input>( id_defs: GlobalStringIdentResolver2<'input>, directives: Vec, SpirvWord>>, ) -> Result { let context = Context::new(); let module = Module::new(&context, LLVM_UNNAMED); let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs); for directive in directives { match directive { Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?, Directive2::Method(method) => emit_ctx.emit_method(method)?, } } module.write_to_stderr(); if let Err(err) = module.verify() { panic!("{:?}", err); } Ok(module.write_bitcode_to_memory()) } struct ModuleEmitContext<'a, 'input> { context: LLVMContextRef, module: LLVMModuleRef, builder: Builder, id_defs: &'a GlobalStringIdentResolver2<'input>, resolver: ResolveIdent, } impl<'a, 'input> ModuleEmitContext<'a, 'input> { fn new( context: &Context, module: &Module, id_defs: &'a GlobalStringIdentResolver2<'input>, ) -> Self { ModuleEmitContext { context: context.get(), module: module.get(), builder: Builder::new(context), id_defs, resolver: ResolveIdent::new(&id_defs), } } fn kernel_call_convention() -> 
u32 { LLVMCallConv::LLVMAMDGPUKERNELCallConv as u32 } fn func_call_convention() -> u32 { LLVMCallConv::LLVMCCallConv as u32 } fn emit_method( &mut self, method: Function2<'input, ast::Instruction, SpirvWord>, ) -> Result<(), TranslateError> { let func_decl = method.func_decl; let name = method .import_as .as_deref() .or_else(|| match func_decl.name { ast::MethodName::Kernel(name) => Some(name), ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(), }) .ok_or_else(|| error_unreachable())?; let name = CString::new(name).map_err(|_| error_unreachable())?; let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; if fn_ == ptr::null_mut() { let fn_type = get_function_type( self.context, func_decl.return_arguments.iter().map(|v| &v.v_type), func_decl .input_arguments .iter() .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)), )?; fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; } if let ast::MethodName::Func(name) = func_decl.name { self.resolver.register(name, fn_); } for (i, param) in func_decl.input_arguments.iter().enumerate() { let value = unsafe { LLVMGetParam(fn_, i as u32) }; let name = self.resolver.get_or_add(param.name); unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; self.resolver.register(param.name, value); if func_decl.name.is_kernel() { let attr_kind = unsafe { LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len()) }; let attr = unsafe { LLVMCreateTypeAttribute( self.context, attr_kind, get_type(self.context, ¶m.v_type)?, ) }; unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) }; } } let call_conv = if func_decl.name.is_kernel() { Self::kernel_call_convention() } else { Self::func_call_convention() }; unsafe { LLVMSetFunctionCallConv(fn_, call_conv) }; if let Some(statements) = method.body { let variables_bb = unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) }; let variables_builder = 
Builder::new_raw(self.context); unsafe { LLVMPositionBuilderAtEnd(variables_builder.get(), variables_bb) }; let real_bb = unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) }; unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) }; let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder); for var in func_decl.return_arguments { method_emitter.emit_variable(var)?; } for statement in statements.iter() { if let Statement::Label(label) = statement { method_emitter.emit_label_initial(*label); } } for statement in statements { method_emitter.emit_statement(statement)?; } unsafe { LLVMBuildBr(method_emitter.variables_builder.get(), real_bb) }; } Ok(()) } fn emit_global( &mut self, _linking: ast::LinkingDirective, var: ast::Variable, ) -> Result<(), TranslateError> { let name = self .id_defs .ident_map .get(&var.name) .map(|entry| { entry .name .as_ref() .map(|text| Ok::<_, NulError>(Cow::Owned(CString::new(&**text)?))) }) .flatten() .transpose() .map_err(|_| error_unreachable())? 
.unwrap_or(Cow::Borrowed(LLVM_UNNAMED)); let global = unsafe { LLVMAddGlobalInAddressSpace( self.module, get_type(self.context, &var.v_type)?, name.as_ptr(), get_state_space(var.state_space)?, ) }; self.resolver.register(var.name, global); if let Some(align) = var.align { unsafe { LLVMSetAlignment(global, align) }; } if !var.array_init.is_empty() { self.emit_array_init(&var.v_type, &*var.array_init, global)?; } Ok(()) } // TODO: instead of Vec we should emit a typed initializer fn emit_array_init( &mut self, type_: &ast::Type, array_init: &[u8], global: *mut llvm_zluda::LLVMValue, ) -> Result<(), TranslateError> { match type_ { ast::Type::Array(None, scalar, dimensions) => { if dimensions.len() != 1 { todo!() } if dimensions[0] as usize * scalar.size_of() as usize != array_init.len() { return Err(error_unreachable()); } let type_ = get_scalar_type(self.context, *scalar); let mut elements = array_init .chunks(scalar.size_of() as usize) .map(|chunk| self.constant_from_bytes(*scalar, chunk, type_)) .collect::, _>>() .map_err(|_| error_unreachable())?; let initializer = unsafe { LLVMConstArray2(type_, elements.as_mut_ptr(), elements.len() as u64) }; unsafe { LLVMSetInitializer(global, initializer) }; } _ => todo!(), } Ok(()) } fn constant_from_bytes( &self, scalar: ast::ScalarType, bytes: &[u8], llvm_type: LLVMTypeRef, ) -> Result { Ok(match scalar { ptx_parser::ScalarType::Pred | ptx_parser::ScalarType::S8 | ptx_parser::ScalarType::B8 | ptx_parser::ScalarType::U8 => unsafe { LLVMConstInt(llvm_type, u8::from_le_bytes(bytes.try_into()?) as u64, 0) }, ptx_parser::ScalarType::S16 | ptx_parser::ScalarType::B16 | ptx_parser::ScalarType::U16 => unsafe { LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0) }, ptx_parser::ScalarType::S32 | ptx_parser::ScalarType::B32 | ptx_parser::ScalarType::U32 => unsafe { LLVMConstInt(llvm_type, u32::from_le_bytes(bytes.try_into()?) 
as u64, 0) }, ptx_parser::ScalarType::F16 => todo!(), ptx_parser::ScalarType::BF16 => todo!(), ptx_parser::ScalarType::U64 => todo!(), ptx_parser::ScalarType::S64 => todo!(), ptx_parser::ScalarType::S16x2 => todo!(), ptx_parser::ScalarType::F32 => todo!(), ptx_parser::ScalarType::B64 => todo!(), ptx_parser::ScalarType::F64 => todo!(), ptx_parser::ScalarType::B128 => todo!(), ptx_parser::ScalarType::U16x2 => todo!(), ptx_parser::ScalarType::F16x2 => todo!(), ptx_parser::ScalarType::BF16x2 => todo!(), }) } } fn get_input_argument_type( context: LLVMContextRef, v_type: &ast::Type, state_space: ast::StateSpace, ) -> Result { match state_space { ast::StateSpace::ParamEntry => { Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) }) } ast::StateSpace::Reg => get_type(context, v_type), _ => return Err(error_unreachable()), } } struct MethodEmitContext<'a> { context: LLVMContextRef, module: LLVMModuleRef, method: LLVMValueRef, builder: LLVMBuilderRef, variables_builder: Builder, resolver: &'a mut ResolveIdent, } impl<'a> MethodEmitContext<'a> { fn new( parent: &'a mut ModuleEmitContext, method: LLVMValueRef, variables_builder: Builder, ) -> MethodEmitContext<'a> { MethodEmitContext { context: parent.context, module: parent.module, builder: parent.builder.get(), variables_builder, resolver: &mut parent.resolver, method, } } fn emit_statement( &mut self, statement: Statement, SpirvWord>, ) -> Result<(), TranslateError> { Ok(match statement { Statement::Variable(var) => self.emit_variable(var)?, Statement::Label(label) => self.emit_label_delayed(label)?, Statement::Instruction(inst) => self.emit_instruction(inst)?, Statement::Conditional(cond) => self.emit_conditional(cond)?, Statement::Conversion(conversion) => self.emit_conversion(conversion)?, Statement::Constant(constant) => self.emit_constant(constant)?, Statement::RetValue(_, values) => self.emit_ret_value(values)?, Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?, 
Statement::RepackVector(repack) => self.emit_vector_repack(repack)?,
            Statement::FunctionPointer(_) => todo!(),
            Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
            Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
        })
    }

    /// Allocates storage for `var` through `variables_builder`, in the address
    /// space from `get_state_space`, and registers the resulting pointer under
    /// `var.name`. Non-empty initializers are not implemented (`todo!()`).
    fn emit_variable(&mut self, var: ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
        let alloca = unsafe {
            LLVMZludaBuildAlloca(
                self.variables_builder.get(),
                get_type(self.context, &var.v_type)?,
                get_state_space(var.state_space)?,
                self.resolver.get_or_add_raw(var.name),
            )
        };
        self.resolver.register(var.name, alloca);
        if let Some(align) = var.align {
            unsafe { LLVMSetAlignment(alloca, align) };
        }
        if !var.array_init.is_empty() {
            todo!()
        }
        Ok(())
    }

    /// Appends the basic block for `label` to the current function and registers
    /// it, so branches emitted before the label's position can already target it.
    fn emit_label_initial(&mut self, label: SpirvWord) {
        let block = unsafe {
            LLVMAppendBasicBlockInContext(
                self.context,
                self.method,
                self.resolver.get_or_add_raw(label),
            )
        };
        self.resolver
            .register(label, unsafe { LLVMBasicBlockAsValue(block) });
    }

    /// Moves the builder into `label`'s (already created) block. If the current
    /// block has no terminator yet, it first falls through with a branch to it.
    fn emit_label_delayed(&mut self, label: SpirvWord) -> Result<(), TranslateError> {
        let block = self.resolver.value(label)?;
        let block = unsafe { LLVMValueAsBasicBlock(block) };
        let last_block = unsafe { LLVMGetInsertBlock(self.builder) };
        if unsafe { LLVMGetBasicBlockTerminator(last_block) }.is_null() {
            unsafe { LLVMBuildBr(self.builder, block) };
        }
        unsafe { LLVMPositionBuilderAtEnd(self.builder, block) };
        Ok(())
    }

    /// Dispatches one PTX instruction to its dedicated emitter.
    fn emit_instruction(
        &mut self,
        inst: ast::Instruction<SpirvWord>,
    ) -> Result<(), TranslateError> {
        match inst {
            ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments),
            ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
            ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
            ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
            ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments),
            ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments),
            ast::Instruction::SetpBool { .. } => todo!(),
            ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments),
            ast::Instruction::Or { data, arguments } => self.emit_or(data, arguments),
            ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
            ast::Instruction::Bra { arguments } => self.emit_bra(arguments),
            ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
            ast::Instruction::Cvt { data, arguments } => self.emit_cvt(data, arguments),
            ast::Instruction::Shr { data, arguments } => self.emit_shr(data, arguments),
            ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments),
            ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
            ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
            ast::Instruction::Abs { .. } => todo!(),
            ast::Instruction::Mad { data, arguments } => self.emit_mad(data, arguments),
            ast::Instruction::Fma { data, arguments } => self.emit_fma(data, arguments),
            ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
            ast::Instruction::Min { data, arguments } => self.emit_min(data, arguments),
            ast::Instruction::Max { data, arguments } => self.emit_max(data, arguments),
            ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments),
            ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments),
            ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments),
            ast::Instruction::Selp { data, arguments } => self.emit_selp(data, arguments),
            ast::Instruction::Bar { .. } => todo!(),
            ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
            ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
            ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments),
            ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments),
            ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments),
            ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
            ast::Instruction::Lg2 { data, arguments } => self.emit_lg2(data, arguments),
            ast::Instruction::Ex2 { data, arguments } => self.emit_ex2(data, arguments),
            ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments),
            ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments),
            ast::Instruction::Popc { data, arguments } => self.emit_popc(data, arguments),
            ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments),
            ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments),
            ast::Instruction::PrmtSlow { .. } => todo!(),
            ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments),
            ast::Instruction::Membar { data } => self.emit_membar(data),
            ast::Instruction::Trap {} => todo!(),
            // replaced by a function call
            ast::Instruction::Bfe { .. }
            | ast::Instruction::Bfi { .. }
            | ast::Instruction::Activemask { ..
} => return Err(error_unreachable()),
        }
    }

    /// Emits a load of `data.typ` from the pointer `arguments.src`.
    /// Only `LdStQualifier::Weak` loads are implemented; others hit `todo!()`.
    fn emit_ld(
        &mut self,
        data: ast::LdDetails,
        arguments: ast::LdArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        if data.qualifier != ast::LdStQualifier::Weak {
            todo!()
        }
        let builder = self.builder;
        let type_ = get_type(self.context, &data.typ)?;
        let ptr = self.resolver.value(arguments.src)?;
        self.resolver.with_result(arguments.dst, |dst| unsafe {
            LLVMBuildLoad2(builder, type_, ptr, dst)
        });
        Ok(())
    }

    /// Lowers an `ImplicitConversion` statement according to its `ConversionKind`.
    fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> {
        let builder = self.builder;
        match conversion.kind {
            ConversionKind::Default => self.emit_conversion_default(
                self.resolver.value(conversion.src)?,
                conversion.dst,
                &conversion.from_type,
                conversion.from_space,
                &conversion.to_type,
                conversion.to_space,
            ),
            ConversionKind::SignExtend => {
                let src = self.resolver.value(conversion.src)?;
                let type_ = get_type(self.context, &conversion.to_type)?;
                self.resolver.with_result(conversion.dst, |dst| unsafe {
                    LLVMBuildSExt(builder, src, type_, dst)
                });
                Ok(())
            }
            ConversionKind::BitToPtr => {
                let src = self.resolver.value(conversion.src)?;
                let type_ = get_pointer_type(self.context, conversion.to_space)?;
                self.resolver.with_result(conversion.dst, |dst| unsafe {
                    LLVMBuildIntToPtr(builder, src, type_, dst)
                });
                Ok(())
            }
            ConversionKind::PtrToPtr => {
                let src = self.resolver.value(conversion.src)?;
                let dst_type = get_pointer_type(self.context, conversion.to_space)?;
                self.resolver.with_result(conversion.dst, |dst| unsafe {
                    LLVMBuildAddrSpaceCast(builder, src, dst_type, dst)
                });
                Ok(())
            }
            ConversionKind::AddressOf => {
                let src = self.resolver.value(conversion.src)?;
                let dst_type = get_type(self.context, &conversion.to_type)?;
                self.resolver.with_result(conversion.dst, |dst| unsafe {
                    LLVMBuildPtrToInt(builder, src, dst_type, dst)
                });
                Ok(())
            }
        }
    }

    /// Lowers a `ConversionKind::Default` conversion of `src` into `dst`.
    ///
    /// * Scalars of equal size: a no-op alias when neither side is a float,
    ///   otherwise a bitcast.
    /// * Scalars of different size: bitcast to a same-width integer, extend or
    ///   truncate to the target width (sign-extending only signed -> signed),
    ///   then recurse to reach the exact target type.
    /// * Vector -> scalar and scalar <-> array: a plain bitcast.
    ///
    /// Any other combination hits `todo!()`.
    fn emit_conversion_default(
        &mut self,
        src: LLVMValueRef,
        dst: SpirvWord,
        from_type: &ast::Type,
        from_space: ast::StateSpace,
        to_type: &ast::Type,
        to_space: ast::StateSpace,
    ) -> Result<(), TranslateError> {
        match (from_type, to_type) {
            (ast::Type::Scalar(from_type), ast::Type::Scalar(to_type_scalar)) => {
                let from_layout = from_type.layout();
                let to_layout = to_type.layout();
                if from_layout.size() == to_layout.size() {
                    let dst_type = get_type(self.context, &to_type)?;
                    if from_type.kind() != ast::ScalarKind::Float
                        && to_type_scalar.kind() != ast::ScalarKind::Float
                    {
                        // It is noop, but another instruction expects result of this conversion
                        self.resolver.register(dst, src);
                    } else {
                        self.resolver.with_result(dst, |dst| unsafe {
                            LLVMBuildBitCast(self.builder, src, dst_type, dst)
                        });
                    }
                    Ok(())
                } else {
                    // This block is safe because it's illegal to implicitly convert between floating point values
                    let same_width_bit_type = unsafe {
                        LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32)
                    };
                    let same_width_bit_value = unsafe {
                        LLVMBuildBitCast(
                            self.builder,
                            src,
                            same_width_bit_type,
                            LLVM_UNNAMED.as_ptr(),
                        )
                    };
                    let wide_bit_type = match to_type_scalar.layout().size() {
                        1 => ast::ScalarType::B8,
                        2 => ast::ScalarType::B16,
                        4 => ast::ScalarType::B32,
                        8 => ast::ScalarType::B64,
                        _ => return Err(error_unreachable()),
                    };
                    let wide_bit_type_llvm = unsafe {
                        LLVMIntTypeInContext(self.context, (to_layout.size() * 8) as u32)
                    };
                    if to_type_scalar.kind() == ast::ScalarKind::Unsigned
                        || to_type_scalar.kind() == ast::ScalarKind::Bit
                    {
                        let llvm_fn = if to_type_scalar.size_of() >= from_type.size_of() {
                            LLVMBuildZExtOrBitCast
                        } else {
                            LLVMBuildTrunc
                        };
                        self.resolver.with_result(dst, |dst| unsafe {
                            llvm_fn(self.builder, same_width_bit_value, wide_bit_type_llvm, dst)
                        });
                        Ok(())
                    } else {
                        let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
                            && to_type_scalar.kind() == ast::ScalarKind::Signed
                        {
                            if to_type_scalar.size_of() >= from_type.size_of() {
                                LLVMBuildSExtOrBitCast
                            } else {
                                LLVMBuildTrunc
                            }
                        } else if to_type_scalar.size_of() >= from_type.size_of() {
                            LLVMBuildZExtOrBitCast
                        } else {
                            LLVMBuildTrunc
                        };
                        let wide_bit_value = unsafe {
                            conversion_fn(
                                self.builder,
                                same_width_bit_value,
                                wide_bit_type_llvm,
                                LLVM_UNNAMED.as_ptr(),
                            )
                        };
                        self.emit_conversion_default(
                            wide_bit_value,
                            dst,
                            &wide_bit_type.into(),
                            from_space,
                            to_type,
                            to_space,
                        )
                    }
                }
            }
            (ast::Type::Vector(..), ast::Type::Scalar(..))
            | (ast::Type::Scalar(..), ast::Type::Array(..))
            | (ast::Type::Array(..), ast::Type::Scalar(..)) => {
                let dst_type = get_type(self.context, to_type)?;
                self.resolver.with_result(dst, |dst| unsafe {
                    LLVMBuildBitCast(self.builder, src, dst_type, dst)
                });
                Ok(())
            }
            _ => todo!(),
        }
    }

    /// Registers an LLVM constant for an immediate value; no instruction is emitted.
    fn emit_constant(&mut self, constant: ConstantDefinition) -> Result<(), TranslateError> {
        let type_ = get_scalar_type(self.context, constant.typ);
        let value = match constant.value {
            ast::ImmediateValue::U64(x) => unsafe { LLVMConstInt(type_, x, 0) },
            ast::ImmediateValue::S64(x) => unsafe { LLVMConstInt(type_, x as u64, 0) },
            ast::ImmediateValue::F32(x) => unsafe { LLVMConstReal(type_, x as f64) },
            ast::ImmediateValue::F64(x) => unsafe { LLVMConstReal(type_, x) },
        };
        self.resolver.register(constant.dst, value);
        Ok(())
    }

    /// Emits `add` for integers or `fadd` for floats.
    fn emit_add(
        &mut self,
        data: ast::ArithDetails,
        arguments: ast::AddArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        let builder = self.builder;
        let src1 = self.resolver.value(arguments.src1)?;
        let src2 = self.resolver.value(arguments.src2)?;
        let fn_ = match data {
            ast::ArithDetails::Integer(..) => LLVMBuildAdd,
            ast::ArithDetails::Float(..)
=> LLVMBuildFAdd,
        };
        self.resolver.with_result(arguments.dst, |dst| unsafe {
            fn_(builder, src1, src2, dst)
        });
        Ok(())
    }

    /// Emits a store of `arguments.src2` through the pointer `arguments.src1`.
    /// Only `LdStQualifier::Weak` stores are implemented; others hit `todo!()`.
    fn emit_st(
        &self,
        data: ast::StData,
        arguments: ast::StArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        let ptr = self.resolver.value(arguments.src1)?;
        let value = self.resolver.value(arguments.src2)?;
        if data.qualifier != ast::LdStQualifier::Weak {
            todo!()
        }
        unsafe { LLVMBuildStore(self.builder, value, ptr) };
        Ok(())
    }

    /// Emits `ret void`.
    fn emit_ret(&self, _data: ast::RetData) {
        unsafe { LLVMBuildRetVoid(self.builder) };
    }

    /// Emits a call to `arguments.func`.
    ///
    /// All arguments must live in the register state space (checked only in
    /// debug builds), and at most one return value is supported (more hit
    /// `todo!()`). The call result, if any, is registered under that return name.
    fn emit_call(
        &mut self,
        data: ast::CallDetails,
        arguments: ast::CallArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        if cfg!(debug_assertions) {
            for (_, space) in data.return_arguments.iter() {
                if *space != ast::StateSpace::Reg {
                    panic!()
                }
            }
            for (_, space) in data.input_arguments.iter() {
                if *space != ast::StateSpace::Reg {
                    panic!()
                }
            }
        }
        let name = match &*arguments.return_arguments {
            [] => LLVM_UNNAMED.as_ptr(),
            [dst] => self.resolver.get_or_add_raw(*dst),
            _ => todo!(),
        };
        let type_ = get_function_type(
            self.context,
            data.return_arguments.iter().map(|(type_, ..)| type_),
            data.input_arguments
                .iter()
                .map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)),
        )?;
        let mut input_arguments = arguments
            .input_arguments
            .iter()
            .map(|arg| self.resolver.value(*arg))
            .collect::<Result<Vec<_>, _>>()?;
        let llvm_fn = unsafe {
            LLVMBuildCall2(
                self.builder,
                type_,
                self.resolver.value(arguments.func)?,
                input_arguments.as_mut_ptr(),
                input_arguments.len() as u32,
                name,
            )
        };
        match &*arguments.return_arguments {
            [] => {}
            [name] => {
                self.resolver.register(*name, llvm_fn);
            }
            _ => todo!(),
        }
        Ok(())
    }

    /// `mov` emits no instruction: `dst` is registered as an alias of `src`'s value.
    fn emit_mov(
        &mut self,
        _data: ast::MovDetails,
        arguments: ast::MovArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        self.resolver
            .register(arguments.dst, self.resolver.value(arguments.src)?);
        Ok(())
    }

    /// Emits an in-bounds GEP over `i8`, i.e. `ptr_src` advanced by `offset_src` bytes.
    fn emit_ptr_access(&mut self, ptr_access: PtrAccess<SpirvWord>) -> Result<(), TranslateError> {
        let ptr_src = self.resolver.value(ptr_access.ptr_src)?;
        let mut offset_src = self.resolver.value(ptr_access.offset_src)?;
        let pointee_type = get_scalar_type(self.context, ast::ScalarType::B8);
        self.resolver.with_result(ptr_access.dst, |dst| unsafe {
            LLVMBuildInBoundsGEP2(self.builder, pointee_type, ptr_src, &mut offset_src, 1, dst)
        });
        Ok(())
    }

    /// Emits a bitwise `and`.
    fn emit_and(&mut self, arguments: ast::AndArgs<SpirvWord>) -> Result<(), TranslateError> {
        let builder = self.builder;
        let src1 = self.resolver.value(arguments.src1)?;
        let src2 = self.resolver.value(arguments.src2)?;
        self.resolver.with_result(arguments.dst, |dst| unsafe {
            LLVMBuildAnd(builder, src1, src2, dst)
        });
        Ok(())
    }

    /// Emits an `atomicrmw` for PTX `atom`, mapping the PTX operation, scope and
    /// semantics to their LLVM equivalents. The result is registered unnamed.
    fn emit_atom(
        &mut self,
        data: ast::AtomDetails,
        arguments: ast::AtomArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        let builder = self.builder;
        let src1 = self.resolver.value(arguments.src1)?;
        let src2 = self.resolver.value(arguments.src2)?;
        let op = match data.op {
            ast::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd,
            ast::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr,
            ast::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor,
            ast::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg,
            ast::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd,
            ast::AtomicOp::IncrementWrap => {
                LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap
            }
            ast::AtomicOp::DecrementWrap => {
                LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap
            }
            ast::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin,
            ast::AtomicOp::UnsignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin,
            ast::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax,
            ast::AtomicOp::UnsignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax,
            ast::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd,
            ast::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin,
            ast::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax,
        };
        self.resolver.register(arguments.dst, unsafe {
            LLVMZludaBuildAtomicRMW(
                builder,
                op,
                src1,
                src2,
                get_scope(data.scope)?,
                get_ordering(data.semantics),
            )
        });
        Ok(())
    }

    /// Emits a `cmpxchg` for PTX `atom.cas`. `dst` receives element 0 of the
    /// `cmpxchg` result pair (the value loaded from memory).
    fn emit_atom_cas(
        &mut self,
        data: ast::AtomCasDetails,
        arguments: ast::AtomCasArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        let src1 = self.resolver.value(arguments.src1)?;
        let src2 = self.resolver.value(arguments.src2)?;
        let src3 = self.resolver.value(arguments.src3)?;
        let success_ordering = get_ordering(data.semantics);
        let failure_ordering = get_ordering_failure(data.semantics);
        let temp = unsafe {
            LLVMZludaBuildAtomicCmpXchg(
                self.builder,
                src1,
                src2,
                src3,
                get_scope(data.scope)?,
                success_ordering,
                failure_ordering,
            )
        };
        self.resolver.with_result(arguments.dst, |dst| unsafe {
            LLVMBuildExtractValue(self.builder, temp, 0, dst)
        });
        Ok(())
    }

    /// Emits an unconditional branch to the label block registered for `arguments.src`.
    fn emit_bra(&self, arguments: ast::BraArgs<SpirvWord>) -> Result<(), TranslateError> {
        let src = self.resolver.value(arguments.src)?;
        let src = unsafe { LLVMValueAsBasicBlock(src) };
        unsafe { LLVMBuildBr(self.builder, src) };
        Ok(())
    }

    /// Emits PTX `brev` through `llvm.bitreverse.i32`/`.i64`, declaring the
    /// intrinsic on first use. Other widths are an internal error.
    fn emit_brev(
        &mut self,
        data: ast::ScalarType,
        arguments: ast::BrevArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        let llvm_fn = match data.size_of() {
            4 => c"llvm.bitreverse.i32",
            8 => c"llvm.bitreverse.i64",
            _ => return Err(error_unreachable()),
        };
        let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
        let type_ = get_scalar_type(self.context, data);
        let fn_type = get_function_type(
            self.context,
            iter::once(&data.into()),
            iter::once(Ok(type_)),
        )?;
        if fn_.is_null() {
            fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
        }
        let mut src = self.resolver.value(arguments.src)?;
        self.resolver.with_result(arguments.dst, |dst| unsafe {
            LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst)
        });
        Ok(())
    }

    /// Emits the function return: `ret void` for no values, or a load of the
    /// single return variable followed by `ret`. Multiple values hit `todo!()`.
    fn emit_ret_value(
        &mut self,
        values: Vec<(SpirvWord, ptx_parser::Type)>,
    ) -> Result<(), TranslateError> {
        match &*values {
            [] => unsafe { LLVMBuildRetVoid(self.builder) },
            [(value, type_)] => {
                let value =
self.resolver.value(*value)?;
                let type_ = get_type(self.context, type_)?;
                let value =
                    unsafe { LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) };
                unsafe { LLVMBuildRet(self.builder, value) }
            }
            _ => todo!(),
        };
        Ok(())
    }

    /// Emits PTX `clz` (count leading zeros) through `llvm.ctlz.i32`/`.i64`.
    ///
    /// PTX defines the `clz` result as 32 bits for both `.b32` and `.b64`
    /// sources, while `llvm.ctlz.iN` must be declared as `iN (iN, i1)`. The
    /// intrinsic is therefore declared with its real return type and the 64-bit
    /// result is truncated to 32 bits. (Declaring `llvm.ctlz.i64` as returning
    /// `i32`, as before, makes the LLVM verifier reject the module.)
    fn emit_clz(
        &mut self,
        data: ptx_parser::ScalarType,
        arguments: ptx_parser::ClzArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        let (llvm_fn, is_64bit) = match data.size_of() {
            4 => (c"llvm.ctlz.i32", false),
            8 => (c"llvm.ctlz.i64", true),
            _ => return Err(error_unreachable()),
        };
        let type_ = get_scalar_type(self.context, data);
        let pred = get_scalar_type(self.context, ast::ScalarType::Pred);
        let fn_type = get_function_type(
            self.context,
            iter::once(&data.into()),
            [Ok(type_), Ok(pred)].into_iter(),
        )?;
        let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
        if fn_.is_null() {
            fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
        }
        let src = self.resolver.value(arguments.src)?;
        // `is_zero_poison = false`: a zero input yields the bit width, not poison.
        let false_ = unsafe { LLVMConstInt(pred, 0, 0) };
        let mut args = [src, false_];
        if is_64bit {
            let count = unsafe {
                LLVMBuildCall2(
                    self.builder,
                    fn_type,
                    fn_,
                    args.as_mut_ptr(),
                    args.len() as u32,
                    LLVM_UNNAMED.as_ptr(),
                )
            };
            let u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
            // The count is at most 64, so truncation is lossless.
            self.resolver.with_result(arguments.dst, |dst| unsafe {
                LLVMBuildTrunc(self.builder, count, u32_type, dst)
            });
        } else {
            self.resolver.with_result(arguments.dst, |dst| unsafe {
                LLVMBuildCall2(
                    self.builder,
                    fn_type,
                    fn_,
                    args.as_mut_ptr(),
                    args.len() as u32,
                    dst,
                )
            });
        }
        Ok(())
    }

    /// Emits PTX `mul` into `arguments.dst`.
    fn emit_mul(
        &mut self,
        data: ast::MulDetails,
        arguments: ast::MulArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        self.emit_mul_impl(data, Some(arguments.dst), arguments.src1, arguments.src2)?;
        Ok(())
    }

    /// Emits `mul.lo`/`mul.hi`/`mul.wide` for integers or `fmul` for floats and
    /// returns the resulting value (named after `dst` when given).
    fn emit_mul_impl(
        &mut self,
        data: ast::MulDetails,
        dst: Option<SpirvWord>,
        src1: SpirvWord,
        src2: SpirvWord,
    ) -> Result<LLVMValueRef, TranslateError> {
        let mul_fn = match data {
            ast::MulDetails::Integer { control, type_ } => match control {
                ast::MulIntControl::Low => LLVMBuildMul,
                ast::MulIntControl::High => return self.emit_mul_high(type_, dst, src1, src2),
                ast::MulIntControl::Wide => {
                    return Ok(self.emit_mul_wide_impl(type_, dst, src1, src2)?.1)
                }
            },
            ast::MulDetails::Float(..) => LLVMBuildFMul,
        };
        let src1 = self.resolver.value(src1)?;
        let src2 = self.resolver.value(src2)?;
        Ok(self
            .resolver
            .with_result_option(dst, |dst| unsafe { mul_fn(self.builder, src1, src2, dst) }))
    }

    /// High half of the product: widen-multiply, shift right by the operand's
    /// bit width, then truncate back to the operand type.
    fn emit_mul_high(
        &mut self,
        type_: ptx_parser::ScalarType,
        dst: Option<SpirvWord>,
        src1: SpirvWord,
        src2: SpirvWord,
    ) -> Result<LLVMValueRef, TranslateError> {
        let (wide_type, wide_value) = self.emit_mul_wide_impl(type_, None, src1, src2)?;
        let shift_constant =
            unsafe { LLVMConstInt(wide_type, (type_.layout().size() * 8) as u64, 0) };
        let shifted = unsafe {
            LLVMBuildLShr(
                self.builder,
                wide_value,
                shift_constant,
                LLVM_UNNAMED.as_ptr(),
            )
        };
        let narrow_type = get_scalar_type(self.context, type_);
        Ok(self.resolver.with_result_option(dst, |dst| unsafe {
            LLVMBuildTrunc(self.builder, shifted, narrow_type, dst)
        }))
    }

    /// Double-width product: both operands are sign- (signed types) or zero-
    /// (unsigned types) extended to twice their width and multiplied. Returns
    /// the wide type together with the product. Other kinds are an internal error.
    fn emit_mul_wide_impl(
        &mut self,
        type_: ptx_parser::ScalarType,
        dst: Option<SpirvWord>,
        src1: SpirvWord,
        src2: SpirvWord,
    ) -> Result<(LLVMTypeRef, LLVMValueRef), TranslateError> {
        let src1 = self.resolver.value(src1)?;
        let src2 = self.resolver.value(src2)?;
        let wide_type =
            unsafe { LLVMIntTypeInContext(self.context, (type_.layout().size() * 8 * 2) as u32) };
        let llvm_cast = match type_.kind() {
            ptx_parser::ScalarKind::Signed => LLVMBuildSExt,
            ptx_parser::ScalarKind::Unsigned => LLVMBuildZExt,
            _ => return Err(error_unreachable()),
        };
        let src1 = unsafe { llvm_cast(self.builder, src1, wide_type, LLVM_UNNAMED.as_ptr()) };
        let src2 = unsafe { llvm_cast(self.builder, src2, wide_type, LLVM_UNNAMED.as_ptr()) };
        Ok((
            wide_type,
            self.resolver.with_result_option(dst, |dst| unsafe {
                LLVMBuildMul(self.builder, src1, src2, dst)
            }),
        ))
    }

    /// Emits `llvm.cos.f32` and marks the call with the approximate-function
    /// fast-math flag.
    fn emit_cos(
        &mut self,
        _data: ast::FlushToZero,
        arguments: ast::CosArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32);
        let cos = self.emit_intrinsic(
            c"llvm.cos.f32",
            Some(arguments.dst),
            &ast::ScalarType::F32.into(),
            vec![(self.resolver.value(arguments.src)?, llvm_f32)],
        )?;
        unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
        Ok(())
    }

    /// Emits a bitwise `or`.
    fn emit_or(
        &mut self,
        _data: ptx_parser::ScalarType,
        arguments: ptx_parser::OrArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        let src1 = self.resolver.value(arguments.src1)?;
        let src2 = self.resolver.value(arguments.src2)?;
        self.resolver.with_result(arguments.dst, |dst| unsafe {
            LLVMBuildOr(self.builder, src1, src2, dst)
        });
        Ok(())
    }

    /// Emits a bitwise `xor`.
    fn emit_xor(
        &mut self,
        _data: ptx_parser::ScalarType,
        arguments: ptx_parser::XorArgs<SpirvWord>,
    ) -> Result<(), TranslateError> {
        let src1 = self.resolver.value(arguments.src1)?;
        let src2 = self.resolver.value(arguments.src2)?;
        self.resolver.with_result(arguments.dst, |dst| unsafe {
            LLVMBuildXor(self.builder, src1, src2, dst)
        });
        Ok(())
    }

    /// Extracts element `member` of a vector into `scalar_dst`.
    fn emit_vector_read(&mut self, vec_access: VectorRead) -> Result<(), TranslateError> {
        let src = self.resolver.value(vec_access.vector_src)?;
        let index = unsafe {
            LLVMConstInt(
                get_scalar_type(self.context, ast::ScalarType::B8),
                vec_access.member as _,
                0,
            )
        };
        self.resolver
            .with_result(vec_access.scalar_dst, |dst| unsafe {
                LLVMBuildExtractElement(self.builder, src, index, dst)
            });
        Ok(())
    }

    /// Inserts `scalar_src` at element `member` of a vector, producing `vector_dst`.
    fn emit_vector_write(&mut self, vector_write: VectorWrite) -> Result<(), TranslateError> {
        let vector_src = self.resolver.value(vector_write.vector_src)?;
        let scalar_src = self.resolver.value(vector_write.scalar_src)?;
        let index = unsafe {
            LLVMConstInt(
                get_scalar_type(self.context, ast::ScalarType::B8),
                vector_write.member as _,
                0,
            )
        };
        self.resolver
            .with_result(vector_write.vector_dst, |dst| unsafe {
                LLVMBuildInsertElement(self.builder, vector_src, scalar_src, index, dst)
            });
        Ok(())
    }

    /// Unpacks a vector into scalars (`is_extract`) or packs scalars into a vector.
    fn emit_vector_repack(&mut self, repack: RepackVectorDetails) -> Result<(), TranslateError> {
        let i8_type = get_scalar_type(self.context, ast::ScalarType::B8);
        if repack.is_extract {
            let src = self.resolver.value(repack.packed)?;
            for (index, dst) in repack.unpacked.iter().enumerate() {
                let index: *mut LLVMValue = unsafe { LLVMConstInt(i8_type, index as _, 0) };
                self.resolver.with_result(*dst, |dst|
unsafe { LLVMBuildExtractElement(self.builder, src, index, dst) }); } } else { let vector_type = get_type( self.context, &ast::Type::Vector(repack.unpacked.len() as u8, repack.typ), )?; let mut temp_vec = unsafe { LLVMGetUndef(vector_type) }; for (index, src_id) in repack.unpacked.iter().enumerate() { let dst = if index == repack.unpacked.len() - 1 { Some(repack.packed) } else { None }; let scalar_src = self.resolver.value(*src_id)?; let index = unsafe { LLVMConstInt(i8_type, index as _, 0) }; temp_vec = self.resolver.with_result_option(dst, |dst| unsafe { LLVMBuildInsertElement(self.builder, temp_vec, scalar_src, index, dst) }); } } Ok(()) } fn emit_div( &mut self, data: ptx_parser::DivDetails, arguments: ptx_parser::DivArgs, ) -> Result<(), TranslateError> { let integer_div = match data { ptx_parser::DivDetails::Unsigned(_) => LLVMBuildUDiv, ptx_parser::DivDetails::Signed(_) => LLVMBuildSDiv, ptx_parser::DivDetails::Float(float_div) => { return self.emit_div_float(float_div, arguments) } }; let src1 = self.resolver.value(arguments.src1)?; let src2 = self.resolver.value(arguments.src2)?; self.resolver.with_result(arguments.dst, |dst| unsafe { integer_div(self.builder, src1, src2, dst) }); Ok(()) } fn emit_div_float( &mut self, float_div: ptx_parser::DivFloatDetails, arguments: ptx_parser::DivArgs, ) -> Result<(), TranslateError> { let builder = self.builder; let src1 = self.resolver.value(arguments.src1)?; let src2 = self.resolver.value(arguments.src2)?; let _rnd = match float_div.kind { ptx_parser::DivFloatKind::Approx => ast::RoundingMode::NearestEven, ptx_parser::DivFloatKind::ApproxFull => ast::RoundingMode::NearestEven, ptx_parser::DivFloatKind::Rounding(rounding_mode) => rounding_mode, }; let approx = match float_div.kind { ptx_parser::DivFloatKind::Approx => { LLVMZludaFastMathAllowReciprocal | LLVMZludaFastMathApproxFunc } ptx_parser::DivFloatKind::ApproxFull => LLVMZludaFastMathNone, ptx_parser::DivFloatKind::Rounding(_) => LLVMZludaFastMathNone, }; let 
fdiv = self.resolver.with_result(arguments.dst, |dst| unsafe { LLVMBuildFDiv(builder, src1, src2, dst) }); unsafe { LLVMZludaSetFastMathFlags(fdiv, approx) }; if let ptx_parser::DivFloatKind::ApproxFull = float_div.kind { // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-div: // div.full.f32 implements a relatively fast, full-range approximation that scales // operands to achieve better accuracy, but is not fully IEEE 754 compliant and does not // support rounding modifiers. The maximum ulp error is 2 across the full range of // inputs. // https://llvm.org/docs/LangRef.html#fpmath-metadata let fpmath_value = unsafe { LLVMConstReal(get_scalar_type(self.context, ast::ScalarType::F32), 2.0) }; let fpmath_value = unsafe { LLVMValueAsMetadata(fpmath_value) }; let mut md_node_content = [fpmath_value]; let md_node = unsafe { LLVMMDNodeInContext2( self.context, md_node_content.as_mut_ptr(), md_node_content.len(), ) }; let md_node = unsafe { LLVMMetadataAsValue(self.context, md_node) }; let kind = unsafe { LLVMGetMDKindIDInContext( self.context, "fpmath".as_ptr().cast(), "fpmath".len() as u32, ) }; unsafe { LLVMSetMetadata(fdiv, kind, md_node) }; } Ok(()) } fn emit_cvta( &mut self, data: ptx_parser::CvtaDetails, arguments: ptx_parser::CvtaArgs, ) -> Result<(), TranslateError> { let (from_space, to_space) = match data.direction { ptx_parser::CvtaDirection::GenericToExplicit => { (ast::StateSpace::Generic, data.state_space) } ptx_parser::CvtaDirection::ExplicitToGeneric => { (data.state_space, ast::StateSpace::Generic) } }; let from_type = get_pointer_type(self.context, from_space)?; let dest_type = get_pointer_type(self.context, to_space)?; let src = self.resolver.value(arguments.src)?; let temp_ptr = unsafe { LLVMBuildIntToPtr(self.builder, src, from_type, LLVM_UNNAMED.as_ptr()) }; self.resolver.with_result(arguments.dst, |dst| unsafe { LLVMBuildAddrSpaceCast(self.builder, temp_ptr, dest_type, dst) }); Ok(()) } fn emit_sub( &mut self, 
data: ptx_parser::ArithDetails, arguments: ptx_parser::SubArgs, ) -> Result<(), TranslateError> { match data { ptx_parser::ArithDetails::Integer(arith_integer) => { self.emit_sub_integer(arith_integer, arguments) } ptx_parser::ArithDetails::Float(arith_float) => { self.emit_sub_float(arith_float, arguments) } } } fn emit_sub_integer( &mut self, arith_integer: ptx_parser::ArithInteger, arguments: ptx_parser::SubArgs, ) -> Result<(), TranslateError> { if arith_integer.saturate { todo!() } let src1 = self.resolver.value(arguments.src1)?; let src2 = self.resolver.value(arguments.src2)?; self.resolver.with_result(arguments.dst, |dst| unsafe { LLVMBuildSub(self.builder, src1, src2, dst) }); Ok(()) } fn emit_sub_float( &mut self, arith_float: ptx_parser::ArithFloat, arguments: ptx_parser::SubArgs, ) -> Result<(), TranslateError> { if arith_float.saturate { todo!() } let src1 = self.resolver.value(arguments.src1)?; let src2 = self.resolver.value(arguments.src2)?; self.resolver.with_result(arguments.dst, |dst| unsafe { LLVMBuildFSub(self.builder, src1, src2, dst) }); Ok(()) } fn emit_sin( &mut self, _data: ptx_parser::FlushToZero, arguments: ptx_parser::SinArgs, ) -> Result<(), TranslateError> { let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32); let sin = self.emit_intrinsic( c"llvm.sin.f32", Some(arguments.dst), &ast::ScalarType::F32.into(), vec![(self.resolver.value(arguments.src)?, llvm_f32)], )?; unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) } Ok(()) } fn emit_intrinsic( &mut self, name: &CStr, dst: Option, return_type: &ast::Type, arguments: Vec<(LLVMValueRef, LLVMTypeRef)>, ) -> Result { let fn_type = get_function_type( self.context, iter::once(return_type), arguments.iter().map(|(_, type_)| Ok(*type_)), )?; let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) }; if fn_ == ptr::null_mut() { fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; } let mut arguments = arguments.iter().map(|(arg, 
_)| *arg).collect::>(); Ok(self.resolver.with_result_option(dst, |dst| unsafe { LLVMBuildCall2( self.builder, fn_type, fn_, arguments.as_mut_ptr(), arguments.len() as u32, dst, ) })) } fn emit_neg( &mut self, data: ptx_parser::TypeFtz, arguments: ptx_parser::NegArgs, ) -> Result<(), TranslateError> { let src = self.resolver.value(arguments.src)?; let llvm_fn = if data.type_.kind() == ptx_parser::ScalarKind::Float { LLVMBuildFNeg } else { LLVMBuildNeg }; self.resolver.with_result(arguments.dst, |dst| unsafe { llvm_fn(self.builder, src, dst) }); Ok(()) } fn emit_not( &mut self, _data: ptx_parser::ScalarType, arguments: ptx_parser::NotArgs, ) -> Result<(), TranslateError> { let src = self.resolver.value(arguments.src)?; self.resolver.with_result(arguments.dst, |dst| unsafe { LLVMBuildNot(self.builder, src, dst) }); Ok(()) } fn emit_setp( &mut self, data: ptx_parser::SetpData, arguments: ptx_parser::SetpArgs, ) -> Result<(), TranslateError> { if arguments.dst2.is_some() { todo!() } match data.cmp_op { ptx_parser::SetpCompareOp::Integer(setp_compare_int) => { self.emit_setp_int(setp_compare_int, arguments) } ptx_parser::SetpCompareOp::Float(setp_compare_float) => { self.emit_setp_float(setp_compare_float, arguments) } } } fn emit_setp_int( &mut self, setp: ptx_parser::SetpCompareInt, arguments: ptx_parser::SetpArgs, ) -> Result<(), TranslateError> { let op = match setp { ptx_parser::SetpCompareInt::Eq => LLVMIntPredicate::LLVMIntEQ, ptx_parser::SetpCompareInt::NotEq => LLVMIntPredicate::LLVMIntNE, ptx_parser::SetpCompareInt::UnsignedLess => LLVMIntPredicate::LLVMIntULT, ptx_parser::SetpCompareInt::UnsignedLessOrEq => LLVMIntPredicate::LLVMIntULE, ptx_parser::SetpCompareInt::UnsignedGreater => LLVMIntPredicate::LLVMIntUGT, ptx_parser::SetpCompareInt::UnsignedGreaterOrEq => LLVMIntPredicate::LLVMIntUGE, ptx_parser::SetpCompareInt::SignedLess => LLVMIntPredicate::LLVMIntSLT, ptx_parser::SetpCompareInt::SignedLessOrEq => LLVMIntPredicate::LLVMIntSLE, 
ptx_parser::SetpCompareInt::SignedGreater => LLVMIntPredicate::LLVMIntSGT, ptx_parser::SetpCompareInt::SignedGreaterOrEq => LLVMIntPredicate::LLVMIntSGE, }; let src1 = self.resolver.value(arguments.src1)?; let src2 = self.resolver.value(arguments.src2)?; self.resolver.with_result(arguments.dst1, |dst1| unsafe { LLVMBuildICmp(self.builder, op, src1, src2, dst1) }); Ok(()) } fn emit_setp_float( &mut self, setp: ptx_parser::SetpCompareFloat, arguments: ptx_parser::SetpArgs, ) -> Result<(), TranslateError> { let op = match setp { ptx_parser::SetpCompareFloat::Eq => LLVMRealPredicate::LLVMRealOEQ, ptx_parser::SetpCompareFloat::NotEq => LLVMRealPredicate::LLVMRealONE, ptx_parser::SetpCompareFloat::Less => LLVMRealPredicate::LLVMRealOLT, ptx_parser::SetpCompareFloat::LessOrEq => LLVMRealPredicate::LLVMRealOLE, ptx_parser::SetpCompareFloat::Greater => LLVMRealPredicate::LLVMRealOGT, ptx_parser::SetpCompareFloat::GreaterOrEq => LLVMRealPredicate::LLVMRealOGE, ptx_parser::SetpCompareFloat::NanEq => LLVMRealPredicate::LLVMRealUEQ, ptx_parser::SetpCompareFloat::NanNotEq => LLVMRealPredicate::LLVMRealUNE, ptx_parser::SetpCompareFloat::NanLess => LLVMRealPredicate::LLVMRealULT, ptx_parser::SetpCompareFloat::NanLessOrEq => LLVMRealPredicate::LLVMRealULE, ptx_parser::SetpCompareFloat::NanGreater => LLVMRealPredicate::LLVMRealUGT, ptx_parser::SetpCompareFloat::NanGreaterOrEq => LLVMRealPredicate::LLVMRealUGE, ptx_parser::SetpCompareFloat::IsNotNan => LLVMRealPredicate::LLVMRealORD, ptx_parser::SetpCompareFloat::IsAnyNan => LLVMRealPredicate::LLVMRealUNO, }; let src1 = self.resolver.value(arguments.src1)?; let src2 = self.resolver.value(arguments.src2)?; self.resolver.with_result(arguments.dst1, |dst1| unsafe { LLVMBuildFCmp(self.builder, op, src1, src2, dst1) }); Ok(()) } fn emit_conditional(&mut self, cond: BrachCondition) -> Result<(), TranslateError> { let predicate = self.resolver.value(cond.predicate)?; let if_true = self.resolver.value(cond.if_true)?; let if_false = 
self.resolver.value(cond.if_false)?; unsafe { LLVMBuildCondBr( self.builder, predicate, LLVMValueAsBasicBlock(if_true), LLVMValueAsBasicBlock(if_false), ) }; Ok(()) } fn emit_cvt( &mut self, data: ptx_parser::CvtDetails, arguments: ptx_parser::CvtArgs, ) -> Result<(), TranslateError> { let dst_type = get_scalar_type(self.context, data.to); let llvm_fn = match data.mode { ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt, ptx_parser::CvtMode::SignExtend => LLVMBuildSExt, ptx_parser::CvtMode::Truncate => LLVMBuildTrunc, ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast, ptx_parser::CvtMode::SaturateUnsignedToSigned => { return self.emit_cvt_unsigned_to_signed_sat(data.from, data.to, arguments) } ptx_parser::CvtMode::SaturateSignedToUnsigned => { return self.emit_cvt_signed_to_unsigned_sat(data.from, data.to, arguments) } ptx_parser::CvtMode::FPExtend { flush_to_zero } => LLVMBuildFPExt, ptx_parser::CvtMode::FPTruncate { rounding, flush_to_zero, } => LLVMBuildFPTrunc, ptx_parser::CvtMode::FPRound { integer_rounding, flush_to_zero, } => { return self.emit_cvt_float_to_int( data.from, data.to, integer_rounding.unwrap_or(ast::RoundingMode::NearestEven), arguments, Some(LLVMBuildFPToSI), ) } ptx_parser::CvtMode::SignedFromFP { rounding, flush_to_zero, } => { return self.emit_cvt_float_to_int( data.from, data.to, rounding, arguments, Some(LLVMBuildFPToSI), ) } ptx_parser::CvtMode::UnsignedFromFP { rounding, flush_to_zero, } => { return self.emit_cvt_float_to_int( data.from, data.to, rounding, arguments, Some(LLVMBuildFPToUI), ) } ptx_parser::CvtMode::FPFromSigned(rounding_mode) => todo!(), ptx_parser::CvtMode::FPFromUnsigned(rounding_mode) => todo!(), }; let src = self.resolver.value(arguments.src)?; self.resolver.with_result(arguments.dst, |dst| unsafe { llvm_fn(self.builder, src, dst_type, dst) }); Ok(()) } fn emit_cvt_unsigned_to_signed_sat( &mut self, from: ptx_parser::ScalarType, to: ptx_parser::ScalarType, arguments: ptx_parser::CvtArgs, ) -> Result<(), TranslateError> 
{ // This looks dodgy, but it's fine. MAX bit pattern is always 0b11..1, // so if it's downcast to a smaller type, it will be the maximum value // of the smaller type let max_value = match to { ptx_parser::ScalarType::S8 => i8::MAX as u64, ptx_parser::ScalarType::S16 => i16::MAX as u64, ptx_parser::ScalarType::S32 => i32::MAX as u64, ptx_parser::ScalarType::S64 => i64::MAX as u64, _ => return Err(error_unreachable()), }; let from_llvm = get_scalar_type(self.context, from); let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) }; let clamped = self.emit_intrinsic( c"llvm.umin", None, &from.into(), vec![ (self.resolver.value(arguments.src)?, from_llvm), (max, from_llvm), ], )?; let resize_fn = if to.layout().size() >= from.layout().size() { LLVMBuildSExtOrBitCast } else { LLVMBuildTrunc }; let to_llvm = get_scalar_type(self.context, to); self.resolver.with_result(arguments.dst, |dst| unsafe { resize_fn(self.builder, clamped, to_llvm, dst) }); Ok(()) } fn emit_cvt_signed_to_unsigned_sat( &mut self, from: ptx_parser::ScalarType, to: ptx_parser::ScalarType, arguments: ptx_parser::CvtArgs, ) -> Result<(), TranslateError> { let from_llvm = get_scalar_type(self.context, from); let zero = unsafe { LLVMConstInt(from_llvm, 0, 0) }; let zero_clamp_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from)); let zero_clamped = self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) }, None, &from.into(), vec![ (self.resolver.value(arguments.src)?, from_llvm), (zero, from_llvm), ], )?; // zero_clamped is now unsigned let max_value = match to { ptx_parser::ScalarType::U8 => u8::MAX as u64, ptx_parser::ScalarType::U16 => u16::MAX as u64, ptx_parser::ScalarType::U32 => u32::MAX as u64, ptx_parser::ScalarType::U64 => u64::MAX as u64, _ => return Err(error_unreachable()), }; let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) }; let max_clamp_intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from)); let fully_clamped = 
self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) }, None, &from.into(), vec![(zero_clamped, from_llvm), (max, from_llvm)], )?; let resize_fn = if to.layout().size() >= from.layout().size() { LLVMBuildZExtOrBitCast } else { LLVMBuildTrunc }; let to_llvm = get_scalar_type(self.context, to); self.resolver.with_result(arguments.dst, |dst| unsafe { resize_fn(self.builder, fully_clamped, to_llvm, dst) }); Ok(()) } fn emit_cvt_float_to_int( &mut self, from: ast::ScalarType, to: ast::ScalarType, rounding: ast::RoundingMode, arguments: ptx_parser::CvtArgs, llvm_cast: Option< unsafe extern "C" fn( arg1: LLVMBuilderRef, Val: LLVMValueRef, DestTy: LLVMTypeRef, Name: *const i8, ) -> LLVMValueRef, >, ) -> Result<(), TranslateError> { let prefix = match rounding { ptx_parser::RoundingMode::NearestEven => "llvm.roundeven", ptx_parser::RoundingMode::Zero => "llvm.trunc", ptx_parser::RoundingMode::NegativeInf => "llvm.floor", ptx_parser::RoundingMode::PositiveInf => "llvm.ceil", }; let intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(from)); let rounded_float = self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, None, &from.into(), vec![( self.resolver.value(arguments.src)?, get_scalar_type(self.context, from), )], )?; if let Some(llvm_cast) = llvm_cast { let to = get_scalar_type(self.context, to); let poisoned_dst = unsafe { llvm_cast(self.builder, rounded_float, to, LLVM_UNNAMED.as_ptr()) }; self.resolver.with_result(arguments.dst, |dst| unsafe { LLVMBuildFreeze(self.builder, poisoned_dst, dst) }); } else { self.resolver.register(arguments.dst, rounded_float); } // Using explicit saturation gives us worse codegen: it explicitly checks for out of bound // values and NaNs. 
Using non-saturated fptosi/fptoui emits v_cvt__ which // saturates by default and we don't care about NaNs anyway /* let cast_intrinsic = format!( "{}.{}.{}\0", llvm_cast, LLVMTypeDisplay(to), LLVMTypeDisplay(from) ); self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) }, Some(arguments.dst), &to.into(), vec![(rounded_float, get_scalar_type(self.context, from))], )?; */ Ok(()) } fn emit_rsqrt( &mut self, data: ptx_parser::TypeFtz, arguments: ptx_parser::RsqrtArgs, ) -> Result<(), TranslateError> { let type_ = get_scalar_type(self.context, data.type_); let intrinsic = match data.type_ { ast::ScalarType::F32 => c"llvm.amdgcn.rsq.f32", ast::ScalarType::F64 => c"llvm.amdgcn.rsq.f64", _ => return Err(error_unreachable()), }; self.emit_intrinsic( intrinsic, Some(arguments.dst), &data.type_.into(), vec![(self.resolver.value(arguments.src)?, type_)], )?; Ok(()) } fn emit_sqrt( &mut self, data: ptx_parser::RcpData, arguments: ptx_parser::SqrtArgs, ) -> Result<(), TranslateError> { let type_ = get_scalar_type(self.context, data.type_); let intrinsic = match (data.type_, data.kind) { (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.sqrt.f32", (ast::ScalarType::F32, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f32", (ast::ScalarType::F64, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f64", _ => return Err(error_unreachable()), }; self.emit_intrinsic( intrinsic, Some(arguments.dst), &data.type_.into(), vec![(self.resolver.value(arguments.src)?, type_)], )?; Ok(()) } fn emit_rcp( &mut self, data: ptx_parser::RcpData, arguments: ptx_parser::RcpArgs, ) -> Result<(), TranslateError> { let type_ = get_scalar_type(self.context, data.type_); let intrinsic = match (data.type_, data.kind) { (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.rcp.f32", (_, ast::RcpKind::Compliant(rnd)) => { return self.emit_rcp_compliant(data, arguments, rnd) } _ => return Err(error_unreachable()), }; self.emit_intrinsic( intrinsic, 
Some(arguments.dst), &data.type_.into(), vec![(self.resolver.value(arguments.src)?, type_)], )?; Ok(()) } fn emit_rcp_compliant( &mut self, data: ptx_parser::RcpData, arguments: ptx_parser::RcpArgs, _rnd: ast::RoundingMode, ) -> Result<(), TranslateError> { let type_ = get_scalar_type(self.context, data.type_); let one = unsafe { LLVMConstReal(type_, 1.0) }; let src = self.resolver.value(arguments.src)?; let rcp = self.resolver.with_result(arguments.dst, |dst| unsafe { LLVMBuildFDiv(self.builder, one, src, dst) }); unsafe { LLVMZludaSetFastMathFlags(rcp, LLVMZludaFastMathAllowReciprocal) }; Ok(()) } fn emit_shr( &mut self, data: ptx_parser::ShrData, arguments: ptx_parser::ShrArgs, ) -> Result<(), TranslateError> { let shift_fn = match data.kind { ptx_parser::RightShiftKind::Arithmetic => LLVMBuildAShr, ptx_parser::RightShiftKind::Logical => LLVMBuildLShr, }; self.emit_shift( data.type_, arguments.dst, arguments.src1, arguments.src2, shift_fn, ) } fn emit_shl( &mut self, type_: ptx_parser::ScalarType, arguments: ptx_parser::ShlArgs, ) -> Result<(), TranslateError> { self.emit_shift( type_, arguments.dst, arguments.src1, arguments.src2, LLVMBuildShl, ) } fn emit_shift( &mut self, type_: ast::ScalarType, dst: SpirvWord, src1: SpirvWord, src2: SpirvWord, llvm_fn: unsafe extern "C" fn( LLVMBuilderRef, LLVMValueRef, LLVMValueRef, *const i8, ) -> LLVMValueRef, ) -> Result<(), TranslateError> { let src1 = self.resolver.value(src1)?; let shift_size = self.resolver.value(src2)?; let integer_bits = type_.layout().size() * 8; let integer_bits_constant = unsafe { LLVMConstInt( get_scalar_type(self.context, ast::ScalarType::U32), integer_bits as u64, 0, ) }; let should_clamp = unsafe { LLVMBuildICmp( self.builder, LLVMIntPredicate::LLVMIntUGE, shift_size, integer_bits_constant, LLVM_UNNAMED.as_ptr(), ) }; let llvm_type = get_scalar_type(self.context, type_); let zero = unsafe { LLVMConstNull(llvm_type) }; let normalized_shift_size = if type_.layout().size() >= 4 { unsafe { 
LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) } } else { unsafe { LLVMBuildTrunc(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) } }; let shifted = unsafe { llvm_fn( self.builder, src1, normalized_shift_size, LLVM_UNNAMED.as_ptr(), ) }; self.resolver.with_result(dst, |dst| unsafe { LLVMBuildSelect(self.builder, should_clamp, zero, shifted, dst) }); Ok(()) } fn emit_ex2( &mut self, data: ptx_parser::TypeFtz, arguments: ptx_parser::Ex2Args, ) -> Result<(), TranslateError> { let intrinsic = match data.type_ { ast::ScalarType::F16 => c"llvm.amdgcn.exp2.f16", ast::ScalarType::F32 => c"llvm.amdgcn.exp2.f32", _ => return Err(error_unreachable()), }; self.emit_intrinsic( intrinsic, Some(arguments.dst), &data.type_.into(), vec![( self.resolver.value(arguments.src)?, get_scalar_type(self.context, data.type_), )], )?; Ok(()) } fn emit_lg2( &mut self, _data: ptx_parser::FlushToZero, arguments: ptx_parser::Lg2Args, ) -> Result<(), TranslateError> { self.emit_intrinsic( c"llvm.amdgcn.log.f32", Some(arguments.dst), &ast::ScalarType::F32.into(), vec![( self.resolver.value(arguments.src)?, get_scalar_type(self.context, ast::ScalarType::F32.into()), )], )?; Ok(()) } fn emit_selp( &mut self, _data: ptx_parser::ScalarType, arguments: ptx_parser::SelpArgs, ) -> Result<(), TranslateError> { let src1 = self.resolver.value(arguments.src1)?; let src2 = self.resolver.value(arguments.src2)?; let src3 = self.resolver.value(arguments.src3)?; self.resolver.with_result(arguments.dst, |dst_name| unsafe { LLVMBuildSelect(self.builder, src3, src1, src2, dst_name) }); Ok(()) } fn emit_rem( &mut self, data: ptx_parser::ScalarType, arguments: ptx_parser::RemArgs, ) -> Result<(), TranslateError> { let llvm_fn = match data.kind() { ptx_parser::ScalarKind::Unsigned => LLVMBuildURem, ptx_parser::ScalarKind::Signed => LLVMBuildSRem, _ => return Err(error_unreachable()), }; let src1 = self.resolver.value(arguments.src1)?; let src2 = 
self.resolver.value(arguments.src2)?; self.resolver.with_result(arguments.dst, |dst_name| unsafe { llvm_fn(self.builder, src1, src2, dst_name) }); Ok(()) } fn emit_popc( &mut self, type_: ptx_parser::ScalarType, arguments: ptx_parser::PopcArgs, ) -> Result<(), TranslateError> { let intrinsic = match type_ { ast::ScalarType::B32 => c"llvm.ctpop.i32", ast::ScalarType::B64 => c"llvm.ctpop.i64", _ => return Err(error_unreachable()), }; let llvm_type = get_scalar_type(self.context, type_); self.emit_intrinsic( intrinsic, Some(arguments.dst), &type_.into(), vec![(self.resolver.value(arguments.src)?, llvm_type)], )?; Ok(()) } fn emit_min( &mut self, data: ptx_parser::MinMaxDetails, arguments: ptx_parser::MinArgs, ) -> Result<(), TranslateError> { let llvm_prefix = match data { ptx_parser::MinMaxDetails::Signed(..) => "llvm.smin", ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umin", ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => { return Err(error_todo()) } ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum", }; let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_())); let llvm_type = get_scalar_type(self.context, data.type_()); self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, Some(arguments.dst), &data.type_().into(), vec![ (self.resolver.value(arguments.src1)?, llvm_type), (self.resolver.value(arguments.src2)?, llvm_type), ], )?; Ok(()) } fn emit_max( &mut self, data: ptx_parser::MinMaxDetails, arguments: ptx_parser::MaxArgs, ) -> Result<(), TranslateError> { let llvm_prefix = match data { ptx_parser::MinMaxDetails::Signed(..) => "llvm.smax", ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umax", ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => { return Err(error_todo()) } ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. 
}) => "llvm.minnum", }; let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_())); let llvm_type = get_scalar_type(self.context, data.type_()); self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, Some(arguments.dst), &data.type_().into(), vec![ (self.resolver.value(arguments.src1)?, llvm_type), (self.resolver.value(arguments.src2)?, llvm_type), ], )?; Ok(()) } fn emit_fma( &mut self, data: ptx_parser::ArithFloat, arguments: ptx_parser::FmaArgs, ) -> Result<(), TranslateError> { let intrinsic = format!("llvm.fma.{}\0", LLVMTypeDisplay(data.type_)); self.emit_intrinsic( unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) }, Some(arguments.dst), &data.type_.into(), vec![ ( self.resolver.value(arguments.src1)?, get_scalar_type(self.context, data.type_), ), ( self.resolver.value(arguments.src2)?, get_scalar_type(self.context, data.type_), ), ( self.resolver.value(arguments.src3)?, get_scalar_type(self.context, data.type_), ), ], )?; Ok(()) } fn emit_mad( &mut self, data: ptx_parser::MadDetails, arguments: ptx_parser::MadArgs, ) -> Result<(), TranslateError> { let mul_control = match data { ptx_parser::MadDetails::Float(mad_float) => { return self.emit_fma( mad_float, ast::FmaArgs { dst: arguments.dst, src1: arguments.src1, src2: arguments.src2, src3: arguments.src3, }, ) } ptx_parser::MadDetails::Integer { saturate: true, .. } => return Err(error_todo()), ptx_parser::MadDetails::Integer { type_, control, .. 
} => { ast::MulDetails::Integer { control, type_ } } }; let temp = self.emit_mul_impl(mul_control, None, arguments.src1, arguments.src2)?; let src3 = self.resolver.value(arguments.src3)?; self.resolver.with_result(arguments.dst, |dst| unsafe { LLVMBuildAdd(self.builder, temp, src3, dst) }); Ok(()) } fn emit_membar(&self, data: ptx_parser::MemScope) -> Result<(), TranslateError> { unsafe { LLVMZludaBuildFence( self.builder, LLVMAtomicOrdering::LLVMAtomicOrderingSequentiallyConsistent, get_scope_membar(data)?, LLVM_UNNAMED.as_ptr(), ) }; Ok(()) } fn emit_prmt( &mut self, control: u16, arguments: ptx_parser::PrmtArgs, ) -> Result<(), TranslateError> { let components = [ (control >> 0) & 0b1111, (control >> 4) & 0b1111, (control >> 8) & 0b1111, (control >> 12) & 0b1111, ]; if components.iter().any(|&c| c > 7) { return Err(TranslateError::Todo); } let u32_type = get_scalar_type(self.context, ast::ScalarType::U32); let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?; let mut components = [ unsafe { LLVMConstInt(u32_type, components[0] as _, 0) }, unsafe { LLVMConstInt(u32_type, components[1] as _, 0) }, unsafe { LLVMConstInt(u32_type, components[2] as _, 0) }, unsafe { LLVMConstInt(u32_type, components[3] as _, 0) }, ]; let components_indices = unsafe { LLVMConstVector(components.as_mut_ptr(), components.len() as u32) }; let src1 = self.resolver.value(arguments.src1)?; let src1_vector = unsafe { LLVMBuildBitCast(self.builder, src1, v4u8_type, LLVM_UNNAMED.as_ptr()) }; let src2 = self.resolver.value(arguments.src2)?; let src2_vector = unsafe { LLVMBuildBitCast(self.builder, src2, v4u8_type, LLVM_UNNAMED.as_ptr()) }; self.resolver.with_result(arguments.dst, |dst| unsafe { LLVMBuildShuffleVector( self.builder, src1_vector, src2_vector, components_indices, dst, ) }); Ok(()) } /* // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding` // Should be available in LLVM 19 fn with_rounding(&mut self, rnd: ast::RoundingMode, fn_: 
impl FnOnce(&mut Self) -> T) -> T {
        let mut u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
        let void_type = unsafe { LLVMVoidTypeInContext(self.context) };
        let get_rounding = c"llvm.get.rounding";
        let get_rounding_fn_type = unsafe { LLVMFunctionType(u32_type, ptr::null_mut(), 0, 0) };
        let mut get_rounding_fn =
            unsafe { LLVMGetNamedFunction(self.module, get_rounding.as_ptr()) };
        if get_rounding_fn == ptr::null_mut() {
            get_rounding_fn = unsafe {
                LLVMAddFunction(self.module, get_rounding.as_ptr(), get_rounding_fn_type)
            };
        }
        let set_rounding = c"llvm.set.rounding";
        let set_rounding_fn_type = unsafe { LLVMFunctionType(void_type, &mut u32_type, 1, 0) };
        let mut set_rounding_fn =
            unsafe { LLVMGetNamedFunction(self.module, set_rounding.as_ptr()) };
        if set_rounding_fn == ptr::null_mut() {
            set_rounding_fn = unsafe {
                LLVMAddFunction(self.module, set_rounding.as_ptr(), set_rounding_fn_type)
            };
        }
        let mut preserved_rounding_mode = unsafe {
            LLVMBuildCall2(
                self.builder,
                get_rounding_fn_type,
                get_rounding_fn,
                ptr::null_mut(),
                0,
                LLVM_UNNAMED.as_ptr(),
            )
        };
        let mut requested_rounding = unsafe {
            LLVMConstInt(
                get_scalar_type(self.context, ast::ScalarType::B32),
                rounding_to_llvm(rnd) as u64,
                0,
            )
        };
        unsafe {
            LLVMBuildCall2(
                self.builder,
                set_rounding_fn_type,
                set_rounding_fn,
                &mut requested_rounding,
                1,
                LLVM_UNNAMED.as_ptr(),
            )
        };
        let result = fn_(self);
        unsafe {
            LLVMBuildCall2(
                self.builder,
                set_rounding_fn_type,
                set_rounding_fn,
                &mut preserved_rounding_mode,
                1,
                LLVM_UNNAMED.as_ptr(),
            )
        };
        result
    }
    */
}

/// Returns the opaque LLVM pointer type for the given PTX state space.
///
/// # Errors
/// Propagates the error from [`get_state_space`] for state spaces that have no
/// address-space mapping yet (e.g. `.param`, `.shared::cta`).
fn get_pointer_type(
    context: LLVMContextRef,
    to_space: ast::StateSpace,
) -> Result<LLVMTypeRef, TranslateError> {
    Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) })
}

/// Maps a PTX memory scope to an AMDGPU `syncscope` name (the `one-as` variants).
///
/// The returned pointer refers to a `'static` C string literal, so it stays
/// valid for the lifetime of the program.
/// https://llvm.org/docs/AMDGPUUsage.html#memory-scopes
///
/// # Errors
/// `TranslateError::Todo` for `.cluster` scope, which is not supported yet
/// (previously this panicked even though the function returns `Result`).
fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
    Ok(match scope {
        ast::MemScope::Cta => c"workgroup-one-as",
        ast::MemScope::Gpu => c"agent-one-as",
        ast::MemScope::Sys => c"one-as",
        ast::MemScope::Cluster => return Err(TranslateError::Todo),
    }
    .as_ptr())
}

/// Maps a PTX memory scope to the AMDGPU `syncscope` name used when emitting
/// fences for `membar` (the empty string is LLVM's system scope).
///
/// # Errors
/// `TranslateError::Todo` for `.cluster` scope, which is not supported yet.
fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
    Ok(match scope {
        ast::MemScope::Cta => c"workgroup",
        ast::MemScope::Gpu => c"agent",
        ast::MemScope::Sys => c"",
        ast::MemScope::Cluster => return Err(TranslateError::Todo),
    }
    .as_ptr())
}

/// Maps PTX atomic semantics to the LLVM success ordering (`.relaxed` becomes
/// `monotonic`).
fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
    match semantics {
        ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
        ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
        ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingRelease,
        ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquireRelease,
    }
}

/// Maps PTX atomic semantics to an LLVM failure ordering.
///
/// A failed compare-exchange performs no store, so LLVM forbids `release`
/// and `acq_rel` as failure orderings. Those are lowered to `acquire` here.
fn get_ordering_failure(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
    match semantics {
        ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
        ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
        ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
        ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
    }
}

/// Lowers a PTX type to an LLVM type.
///
/// Vectors become LLVM vectors. Arrays become nested LLVM arrays whose
/// outermost dimension is the first PTX dimension, and array elements may
/// themselves be vectors. An array with no dimensions becomes a zero-length
/// array.
///
/// # Errors
/// Propagates [`get_pointer_type`] errors for pointers into unsupported state
/// spaces.
fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, TranslateError> {
    Ok(match type_ {
        ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar),
        ast::Type::Vector(size, scalar) => {
            let base_type = get_scalar_type(context, *scalar);
            unsafe { LLVMVectorType(base_type, *size as u32) }
        }
        ast::Type::Array(vec, scalar, dimensions) => {
            let mut underlying_type = get_scalar_type(context, *scalar);
            if let Some(size) = vec {
                underlying_type = unsafe { LLVMVectorType(underlying_type, size.get() as u32) };
            }
            if dimensions.is_empty() {
                return Ok(unsafe { LLVMArrayType2(underlying_type, 0) });
            }
            // Fold from the innermost (last) dimension outwards so that `[2][3]`
            // becomes `[2 x [3 x T]]`.
            dimensions
                .iter()
                .rfold(underlying_type, |result, dimension| unsafe {
                    LLVMArrayType2(result, *dimension as u64)
                })
        }
        ast::Type::Pointer(_, space) => get_pointer_type(context, *space)?,
    })
}

/// Lowers a PTX scalar type to an LLVM type.
///
/// Signed, unsigned and untyped bit types of the same width share one LLVM
/// integer type; `.pred` becomes `i1`.
///
/// # Panics
/// Packed 2-element types (`u16x2`, `s16x2`, `f16x2`, `bf16x2`) are not
/// implemented yet. This function cannot return an error, so they hit
/// `todo!()`.
fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeRef {
    match type_ {
        ast::ScalarType::Pred => unsafe { LLVMInt1TypeInContext(context) },
        ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => unsafe {
            LLVMInt8TypeInContext(context)
        },
        ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => unsafe {
            LLVMInt16TypeInContext(context)
        },
        ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => unsafe {
            LLVMInt32TypeInContext(context)
        },
        ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => unsafe {
            LLVMInt64TypeInContext(context)
        },
        ast::ScalarType::B128 => unsafe { LLVMInt128TypeInContext(context) },
        ast::ScalarType::F16 => unsafe { LLVMHalfTypeInContext(context) },
        ast::ScalarType::F32 => unsafe { LLVMFloatTypeInContext(context) },
        ast::ScalarType::F64 => unsafe { LLVMDoubleTypeInContext(context) },
        ast::ScalarType::BF16 => unsafe { LLVMBFloatTypeInContext(context) },
        ast::ScalarType::U16x2 => todo!(),
        ast::ScalarType::S16x2 => todo!(),
        ast::ScalarType::F16x2 => todo!(),
        ast::ScalarType::BF16x2 => todo!(),
    }
}

/// Builds a non-variadic LLVM function type.
///
/// * `return_args` - zero return values produce `void`, one produces that type.
/// * `input_args` - already-lowered parameter types; the first `Err` is
///   propagated.
///
/// # Errors
/// Any error from `input_args` or from lowering the return type, and
/// `TranslateError::Todo` for multiple return values (previously this
/// panicked).
fn get_function_type<'a>(
    context: LLVMContextRef,
    mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
    input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
) -> Result<LLVMTypeRef, TranslateError> {
    let mut input_args = input_args.collect::<Result<Vec<_>, _>>()?;
    let return_type = match return_args.len() {
        0 => unsafe { LLVMVoidTypeInContext(context) },
        // `len() == 1` guarantees `next()` yields a value.
        1 => get_type(context, return_args.next().unwrap())?,
        _ => return Err(TranslateError::Todo),
    };
    Ok(unsafe {
        LLVMFunctionType(
            return_type,
            input_args.as_mut_ptr(),
            input_args.len() as u32,
            0,
        )
    })
}

/// Maps a PTX state space to an AMDGPU LLVM address space number.
///
/// https://llvm.org/docs/AMDGPUUsage.html#address-spaces
///
/// # Errors
/// `TranslateError::Todo` for state spaces without a mapping yet (`.param`,
/// function `.param`, `.shared::cta`, `.shared::cluster`).
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
    match space {
        ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
        ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
        ast::StateSpace::Param => Err(TranslateError::Todo),
        ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE),
        ast::StateSpace::ParamFunc => Err(TranslateError::Todo),
        ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
        ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
        ast::StateSpace::Const => Ok(CONSTANT_ADDRESS_SPACE),
        ast::StateSpace::Shared => Ok(SHARED_ADDRESS_SPACE),
        ast::StateSpace::SharedCta => Err(TranslateError::Todo),
        ast::StateSpace::SharedCluster => Err(TranslateError::Todo),
    }
}

/// Maps ZLUDA identifiers (`SpirvWord`) to LLVM value names and LLVM values.
struct ResolveIdent {
    /// Cached value names. Each is the decimal form of the word followed by
    /// a trailing `'\0'`, so it can be handed to LLVM-C as a C string.
    words: HashMap<SpirvWord, String>,
    /// LLVM values registered for each identifier.
    values: HashMap<SpirvWord, LLVMValueRef>,
}

impl ResolveIdent {
    /// Creates an empty resolver. The identifier resolver argument is
    /// currently unused.
    fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self {
        ResolveIdent {
            words: HashMap::new(),
            values: HashMap::new(),
        }
    }

    /// Looks up (or creates and caches) the name of `word` and passes it to
    /// `fn_`.
    ///
    /// The slice given to `fn_` excludes the trailing `'\0'`. The byte right
    /// after the slice is still that NUL, which is why callers may use
    /// `slice.as_ptr()` as a C string.
    fn get_or_ad_impl<'a, T>(&'a mut self, word: SpirvWord, fn_: impl FnOnce(&'a str) -> T) -> T {
        let str = match self.words.entry(word) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let mut text = word.0.to_string();
                text.push('\0');
                entry.insert(text)
            }
        };
        fn_(&str[..str.len() - 1])
    }

    /// Returns the cached name for `word` (without the trailing NUL).
    fn get_or_add(&mut self, word: SpirvWord) -> &str {
        self.get_or_ad_impl(word, |x| x)
    }

    /// Returns the cached name for `word` as a NUL-terminated C string pointer.
    ///
    /// NOTE(review): the pointer refers to the `String` buffer stored in
    /// `words`. It presumably stays valid as long as that entry is never
    /// modified or removed — confirm no caller relies on it longer.
    fn get_or_add_raw(&mut self, word: SpirvWord) -> *const i8 {
        self.get_or_add(word).as_ptr().cast()
    }

    /// Records `v` as the LLVM value of `word`, replacing any previous value.
    fn register(&mut self, word: SpirvWord, v: LLVMValueRef) {
        self.values.insert(word, v);
    }

    /// Returns the LLVM value previously registered for `word`.
    ///
    /// # Errors
    /// Returns `error_unreachable()` if `word` was never registered.
    fn value(&self, word: SpirvWord) -> Result<LLVMValueRef, TranslateError> {
        self.values
            .get(&word)
            .copied()
            .ok_or_else(|| error_unreachable())
    }

    /// Calls `fn_` with the C-string name of `word` and registers the value
    /// it builds as `word`'s value.
    fn with_result(
        &mut self,
        word: SpirvWord,
        fn_: impl FnOnce(*const i8) -> LLVMValueRef,
    ) -> LLVMValueRef {
        let t = self.get_or_ad_impl(word, |dst| fn_(dst.as_ptr().cast()));
        self.register(word, t);
        t
    }

    /// Behaves like [`ResolveIdent::with_result`] when `word` is present.
    ///
    /// Otherwise it builds an unnamed value and registers nothing.
    fn with_result_option(
        &mut self,
        word: Option<SpirvWord>,
        fn_: impl FnOnce(*const i8) -> LLVMValueRef,
    ) -> LLVMValueRef {
        match word {
            Some(word) => self.with_result(word, fn_),
            None => fn_(LLVM_UNNAMED.as_ptr()),
        }
    }
}

/// Formats a PTX scalar type as the LLVM type-name suffix (e.g. `i32`,
/// `v2f16`).
struct LLVMTypeDisplay(ast::ScalarType);

impl std::fmt::Display for LLVMTypeDisplay {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0 {
            ast::ScalarType::Pred => write!(f, "i1"),
            ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"),
            ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"),
            ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"),
            ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"),
            ast::ScalarType::B128 => write!(f, "i128"),
            ast::ScalarType::F16 => write!(f, "f16"),
            ast::ScalarType::BF16 => write!(f, "bfloat"),
            ast::ScalarType::F32 => write!(f, "f32"),
            ast::ScalarType::F64 => write!(f, "f64"),
            ast::ScalarType::S16x2 | ast::ScalarType::U16x2 => write!(f, "v2i16"),
            ast::ScalarType::F16x2 => write!(f, "v2f16"),
            ast::ScalarType::BF16x2 => write!(f, "v2bfloat"),
        }
    }
}

/*
fn rounding_to_llvm(this: ast::RoundingMode) -> u32 {
    match this {
        ptx_parser::RoundingMode::Zero => 0,
        ptx_parser::RoundingMode::NearestEven => 1,
        ptx_parser::RoundingMode::PositiveInf => 2,
        ptx_parser::RoundingMode::NegativeInf => 3,
    }
}
*/