diff options
Diffstat (limited to 'ptx/src/pass/emit_llvm.rs')
-rw-r--r-- | ptx/src/pass/emit_llvm.rs | 1933 |
1 files changed, 1815 insertions, 118 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 235ad7d..fa011a3 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -18,16 +18,23 @@ // while with plain LLVM-C it's just:
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
-use std::convert::{TryFrom, TryInto};
-use std::ffi::CStr;
+// AMDGPU LLVM backend support for llvm.experimental.constrained.* is incomplete.
+// Emitting @llvm.experimental.constrained.fdiv.f32(...) makes LLVm fail with
+// "LLVM ERROR: unsupported libcall legalization". Running with "-mllvm -print-before-all"
+// shows it fails inside amdgpu-isel. You can get a little bit furthr with "-mllvm -global-isel",
+// but it will too fail similarly, but with "unable to legalize instruction"
+
+use std::array::TryFromSliceError;
+use std::convert::TryInto;
+use std::ffi::{CStr, NulError};
use std::ops::Deref;
-use std::ptr;
+use std::{i8, ptr};
use super::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
-use llvm_zluda::core::*;
-use llvm_zluda::prelude::*;
+use llvm_zluda::{core::*, *};
+use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
const LLVM_UNNAMED: &CStr = c"";
@@ -172,7 +179,7 @@ pub(super) fn run<'input>( let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
for directive in directives {
match directive {
- Directive2::Variable(..) => todo!(),
+ Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?,
Directive2::Method(method) => emit_ctx.emit_method(method)?,
}
}
@@ -228,15 +235,18 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { })
.ok_or_else(|| error_unreachable())?;
let name = CString::new(name).map_err(|_| error_unreachable())?;
- let fn_type = get_function_type(
- self.context,
- func_decl.return_arguments.iter().map(|v| &v.v_type),
- func_decl
- .input_arguments
- .iter()
- .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
- )?;
- let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
+ let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
+ if fn_ == ptr::null_mut() {
+ let fn_type = get_function_type(
+ self.context,
+ func_decl.return_arguments.iter().map(|v| &v.v_type),
+ func_decl
+ .input_arguments
+ .iter()
+ .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
+ )?;
+ fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
+ }
if let ast::MethodName::Func(name) = func_decl.name {
self.resolver.register(name, fn_);
}
@@ -274,6 +284,14 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { unsafe { LLVMAppendBasicBlockInContext(self.context, fn_, LLVM_UNNAMED.as_ptr()) };
unsafe { LLVMPositionBuilderAtEnd(self.builder.get(), real_bb) };
let mut method_emitter = MethodEmitContext::new(self, fn_, variables_builder);
+ for var in func_decl.return_arguments {
+ method_emitter.emit_variable(var)?;
+ }
+ for statement in statements.iter() {
+ if let Statement::Label(label) = statement {
+ method_emitter.emit_label_initial(*label);
+ }
+ }
for statement in statements {
method_emitter.emit_statement(statement)?;
}
@@ -281,43 +299,146 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { }
Ok(())
}
+
+ fn emit_global(
+ &mut self,
+ _linking: ast::LinkingDirective,
+ var: ast::Variable<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let name = self
+ .id_defs
+ .ident_map
+ .get(&var.name)
+ .map(|entry| {
+ entry
+ .name
+ .as_ref()
+ .map(|text| Ok::<_, NulError>(Cow::Owned(CString::new(&**text)?)))
+ })
+ .flatten()
+ .transpose()
+ .map_err(|_| error_unreachable())?
+ .unwrap_or(Cow::Borrowed(LLVM_UNNAMED));
+ let global = unsafe {
+ LLVMAddGlobalInAddressSpace(
+ self.module,
+ get_type(self.context, &var.v_type)?,
+ name.as_ptr(),
+ get_state_space(var.state_space)?,
+ )
+ };
+ self.resolver.register(var.name, global);
+ if let Some(align) = var.align {
+ unsafe { LLVMSetAlignment(global, align) };
+ }
+ if !var.array_init.is_empty() {
+ self.emit_array_init(&var.v_type, &*var.array_init, global)?;
+ }
+ Ok(())
+ }
+
+ // TODO: instead of Vec<u8> we should emit a typed initializer
+ fn emit_array_init(
+ &mut self,
+ type_: &ast::Type,
+ array_init: &[u8],
+ global: *mut llvm_zluda::LLVMValue,
+ ) -> Result<(), TranslateError> {
+ match type_ {
+ ast::Type::Array(None, scalar, dimensions) => {
+ if dimensions.len() != 1 {
+ todo!()
+ }
+ if dimensions[0] as usize * scalar.size_of() as usize != array_init.len() {
+ return Err(error_unreachable());
+ }
+ let type_ = get_scalar_type(self.context, *scalar);
+ let mut elements = array_init
+ .chunks(scalar.size_of() as usize)
+ .map(|chunk| self.constant_from_bytes(*scalar, chunk, type_))
+ .collect::<Result<Vec<_>, _>>()
+ .map_err(|_| error_unreachable())?;
+ let initializer =
+ unsafe { LLVMConstArray2(type_, elements.as_mut_ptr(), elements.len() as u64) };
+ unsafe { LLVMSetInitializer(global, initializer) };
+ }
+ _ => todo!(),
+ }
+ Ok(())
+ }
+
+ fn constant_from_bytes(
+ &self,
+ scalar: ast::ScalarType,
+ bytes: &[u8],
+ llvm_type: LLVMTypeRef,
+ ) -> Result<LLVMValueRef, TryFromSliceError> {
+ Ok(match scalar {
+ ptx_parser::ScalarType::Pred
+ | ptx_parser::ScalarType::S8
+ | ptx_parser::ScalarType::B8
+ | ptx_parser::ScalarType::U8 => unsafe {
+ LLVMConstInt(llvm_type, u8::from_le_bytes(bytes.try_into()?) as u64, 0)
+ },
+ ptx_parser::ScalarType::S16
+ | ptx_parser::ScalarType::B16
+ | ptx_parser::ScalarType::U16 => unsafe {
+ LLVMConstInt(llvm_type, u16::from_le_bytes(bytes.try_into()?) as u64, 0)
+ },
+ ptx_parser::ScalarType::S32
+ | ptx_parser::ScalarType::B32
+ | ptx_parser::ScalarType::U32 => unsafe {
+ LLVMConstInt(llvm_type, u32::from_le_bytes(bytes.try_into()?) as u64, 0)
+ },
+ ptx_parser::ScalarType::F16 => todo!(),
+ ptx_parser::ScalarType::BF16 => todo!(),
+ ptx_parser::ScalarType::U64 => todo!(),
+ ptx_parser::ScalarType::S64 => todo!(),
+ ptx_parser::ScalarType::S16x2 => todo!(),
+ ptx_parser::ScalarType::F32 => todo!(),
+ ptx_parser::ScalarType::B64 => todo!(),
+ ptx_parser::ScalarType::F64 => todo!(),
+ ptx_parser::ScalarType::B128 => todo!(),
+ ptx_parser::ScalarType::U16x2 => todo!(),
+ ptx_parser::ScalarType::F16x2 => todo!(),
+ ptx_parser::ScalarType::BF16x2 => todo!(),
+ })
+ }
}
fn get_input_argument_type(
context: LLVMContextRef,
- v_type: &ptx_parser::Type,
- state_space: ptx_parser::StateSpace,
+ v_type: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<LLVMTypeRef, TranslateError> {
match state_space {
- ptx_parser::StateSpace::ParamEntry => {
+ ast::StateSpace::ParamEntry => {
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) })
}
- ptx_parser::StateSpace::Reg => get_type(context, v_type),
+ ast::StateSpace::Reg => get_type(context, v_type),
_ => return Err(error_unreachable()),
}
}
-struct MethodEmitContext<'a, 'input> {
+struct MethodEmitContext<'a> {
context: LLVMContextRef,
module: LLVMModuleRef,
method: LLVMValueRef,
builder: LLVMBuilderRef,
- id_defs: &'a GlobalStringIdentResolver2<'input>,
variables_builder: Builder,
resolver: &'a mut ResolveIdent,
}
-impl<'a, 'input> MethodEmitContext<'a, 'input> {
- fn new<'x>(
- parent: &'a mut ModuleEmitContext<'x, 'input>,
+impl<'a> MethodEmitContext<'a> {
+ fn new(
+ parent: &'a mut ModuleEmitContext,
method: LLVMValueRef,
variables_builder: Builder,
- ) -> MethodEmitContext<'a, 'input> {
+ ) -> MethodEmitContext<'a> {
MethodEmitContext {
context: parent.context,
module: parent.module,
builder: parent.builder.get(),
- id_defs: parent.id_defs,
variables_builder,
resolver: &mut parent.resolver,
method,
@@ -330,18 +451,17 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ) -> Result<(), TranslateError> {
Ok(match statement {
Statement::Variable(var) => self.emit_variable(var)?,
- Statement::Label(label) => self.emit_label(label),
+ Statement::Label(label) => self.emit_label_delayed(label)?,
Statement::Instruction(inst) => self.emit_instruction(inst)?,
- Statement::Conditional(_) => todo!(),
- Statement::LoadVar(var) => self.emit_load_variable(var)?,
- Statement::StoreVar(store) => self.emit_store_var(store)?,
+ Statement::Conditional(cond) => self.emit_conditional(cond)?,
Statement::Conversion(conversion) => self.emit_conversion(conversion)?,
Statement::Constant(constant) => self.emit_constant(constant)?,
- Statement::RetValue(_, _) => todo!(),
- Statement::PtrAccess(_) => todo!(),
- Statement::RepackVector(_) => todo!(),
+ Statement::RetValue(_, values) => self.emit_ret_value(values)?,
+ Statement::PtrAccess(ptr_access) => self.emit_ptr_access(ptr_access)?,
+ Statement::RepackVector(repack) => self.emit_vector_repack(repack)?,
Statement::FunctionPointer(_) => todo!(),
- Statement::VectorAccess(_) => todo!(),
+ Statement::VectorRead(vector_read) => self.emit_vector_read(vector_read)?,
+ Statement::VectorWrite(vector_write) => self.emit_vector_write(vector_write)?,
})
}
@@ -364,7 +484,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Ok(())
}
- fn emit_label(&mut self, label: SpirvWord) {
+ fn emit_label_initial(&mut self, label: SpirvWord) {
let block = unsafe {
LLVMAppendBasicBlockInContext(
self.context,
@@ -372,17 +492,18 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { self.resolver.get_or_add_raw(label),
)
};
+ self.resolver
+ .register(label, unsafe { LLVMBasicBlockAsValue(block) });
+ }
+
+ fn emit_label_delayed(&mut self, label: SpirvWord) -> Result<(), TranslateError> {
+ let block = self.resolver.value(label)?;
+ let block = unsafe { LLVMValueAsBasicBlock(block) };
let last_block = unsafe { LLVMGetInsertBlock(self.builder) };
if unsafe { LLVMGetBasicBlockTerminator(last_block) } == ptr::null_mut() {
unsafe { LLVMBuildBr(self.builder, block) };
}
unsafe { LLVMPositionBuilderAtEnd(self.builder, block) };
- }
-
- fn emit_store_var(&mut self, store: StoreVarDetails) -> Result<(), TranslateError> {
- let ptr = self.resolver.value(store.arg.src1)?;
- let value = self.resolver.value(store.arg.src2)?;
- unsafe { LLVMBuildStore(self.builder, value, ptr) };
Ok(())
}
@@ -395,50 +516,51 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
- ast::Instruction::Mul { data, arguments } => todo!(),
- ast::Instruction::Setp { data, arguments } => todo!(),
- ast::Instruction::SetpBool { data, arguments } => todo!(),
- ast::Instruction::Not { data, arguments } => todo!(),
- ast::Instruction::Or { data, arguments } => todo!(),
- ast::Instruction::And { data, arguments } => todo!(),
- ast::Instruction::Bra { arguments } => todo!(),
+ ast::Instruction::Mul { data, arguments } => self.emit_mul(data, arguments),
+ ast::Instruction::Setp { data, arguments } => self.emit_setp(data, arguments),
+ ast::Instruction::SetpBool { .. } => todo!(),
+ ast::Instruction::Not { data, arguments } => self.emit_not(data, arguments),
+ ast::Instruction::Or { data, arguments } => self.emit_or(data, arguments),
+ ast::Instruction::And { arguments, .. } => self.emit_and(arguments),
+ ast::Instruction::Bra { arguments } => self.emit_bra(arguments),
ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
- ast::Instruction::Cvt { data, arguments } => todo!(),
- ast::Instruction::Shr { data, arguments } => todo!(),
- ast::Instruction::Shl { data, arguments } => todo!(),
+ ast::Instruction::Cvt { data, arguments } => self.emit_cvt(data, arguments),
+ ast::Instruction::Shr { data, arguments } => self.emit_shr(data, arguments),
+ ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments),
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
- ast::Instruction::Cvta { data, arguments } => todo!(),
- ast::Instruction::Abs { data, arguments } => todo!(),
- ast::Instruction::Mad { data, arguments } => todo!(),
- ast::Instruction::Fma { data, arguments } => todo!(),
- ast::Instruction::Sub { data, arguments } => todo!(),
- ast::Instruction::Min { data, arguments } => todo!(),
- ast::Instruction::Max { data, arguments } => todo!(),
- ast::Instruction::Rcp { data, arguments } => todo!(),
- ast::Instruction::Sqrt { data, arguments } => todo!(),
- ast::Instruction::Rsqrt { data, arguments } => todo!(),
- ast::Instruction::Selp { data, arguments } => todo!(),
- ast::Instruction::Bar { data, arguments } => todo!(),
- ast::Instruction::Atom { data, arguments } => todo!(),
- ast::Instruction::AtomCas { data, arguments } => todo!(),
- ast::Instruction::Div { data, arguments } => todo!(),
- ast::Instruction::Neg { data, arguments } => todo!(),
- ast::Instruction::Sin { data, arguments } => todo!(),
- ast::Instruction::Cos { data, arguments } => todo!(),
- ast::Instruction::Lg2 { data, arguments } => todo!(),
- ast::Instruction::Ex2 { data, arguments } => todo!(),
- ast::Instruction::Clz { data, arguments } => todo!(),
- ast::Instruction::Brev { data, arguments } => todo!(),
- ast::Instruction::Popc { data, arguments } => todo!(),
- ast::Instruction::Xor { data, arguments } => todo!(),
- ast::Instruction::Rem { data, arguments } => todo!(),
- ast::Instruction::Bfe { data, arguments } => todo!(),
- ast::Instruction::Bfi { data, arguments } => todo!(),
- ast::Instruction::PrmtSlow { arguments } => todo!(),
- ast::Instruction::Prmt { data, arguments } => todo!(),
- ast::Instruction::Activemask { arguments } => todo!(),
- ast::Instruction::Membar { data } => todo!(),
+ ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
+ ast::Instruction::Abs { .. } => todo!(),
+ ast::Instruction::Mad { data, arguments } => self.emit_mad(data, arguments),
+ ast::Instruction::Fma { data, arguments } => self.emit_fma(data, arguments),
+ ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
+ ast::Instruction::Min { data, arguments } => self.emit_min(data, arguments),
+ ast::Instruction::Max { data, arguments } => self.emit_max(data, arguments),
+ ast::Instruction::Rcp { data, arguments } => self.emit_rcp(data, arguments),
+ ast::Instruction::Sqrt { data, arguments } => self.emit_sqrt(data, arguments),
+ ast::Instruction::Rsqrt { data, arguments } => self.emit_rsqrt(data, arguments),
+ ast::Instruction::Selp { data, arguments } => self.emit_selp(data, arguments),
+ ast::Instruction::Bar { .. } => todo!(),
+ ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
+ ast::Instruction::AtomCas { data, arguments } => self.emit_atom_cas(data, arguments),
+ ast::Instruction::Div { data, arguments } => self.emit_div(data, arguments),
+ ast::Instruction::Neg { data, arguments } => self.emit_neg(data, arguments),
+ ast::Instruction::Sin { data, arguments } => self.emit_sin(data, arguments),
+ ast::Instruction::Cos { data, arguments } => self.emit_cos(data, arguments),
+ ast::Instruction::Lg2 { data, arguments } => self.emit_lg2(data, arguments),
+ ast::Instruction::Ex2 { data, arguments } => self.emit_ex2(data, arguments),
+ ast::Instruction::Clz { data, arguments } => self.emit_clz(data, arguments),
+ ast::Instruction::Brev { data, arguments } => self.emit_brev(data, arguments),
+ ast::Instruction::Popc { data, arguments } => self.emit_popc(data, arguments),
+ ast::Instruction::Xor { data, arguments } => self.emit_xor(data, arguments),
+ ast::Instruction::Rem { data, arguments } => self.emit_rem(data, arguments),
+ ast::Instruction::PrmtSlow { .. } => todo!(),
+ ast::Instruction::Prmt { data, arguments } => self.emit_prmt(data, arguments),
+ ast::Instruction::Membar { data } => self.emit_membar(data),
ast::Instruction::Trap {} => todo!(),
+ // replaced by a function call
+ ast::Instruction::Bfe { .. }
+ | ast::Instruction::Bfi { .. }
+ | ast::Instruction::Activemask { .. } => return Err(error_unreachable()),
}
}
@@ -447,9 +569,6 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { data: ast::LdDetails,
arguments: ast::LdArgs<SpirvWord>,
) -> Result<(), TranslateError> {
- if data.non_coherent {
- todo!()
- }
if data.qualifier != ast::LdStQualifier::Weak {
todo!()
}
@@ -462,24 +581,25 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Ok(())
}
- fn emit_load_variable(&mut self, var: LoadVarDetails) -> Result<(), TranslateError> {
- if var.member_index.is_some() {
- todo!()
- }
- let builder = self.builder;
- let type_ = get_type(self.context, &var.typ)?;
- let ptr = self.resolver.value(var.arg.src)?;
- self.resolver.with_result(var.arg.dst, |dst| unsafe {
- LLVMBuildLoad2(builder, type_, ptr, dst)
- });
- Ok(())
- }
-
fn emit_conversion(&mut self, conversion: ImplicitConversion) -> Result<(), TranslateError> {
let builder = self.builder;
match conversion.kind {
- ConversionKind::Default => todo!(),
- ConversionKind::SignExtend => todo!(),
+ ConversionKind::Default => self.emit_conversion_default(
+ self.resolver.value(conversion.src)?,
+ conversion.dst,
+ &conversion.from_type,
+ conversion.from_space,
+ &conversion.to_type,
+ conversion.to_space,
+ ),
+ ConversionKind::SignExtend => {
+ let src = self.resolver.value(conversion.src)?;
+ let type_ = get_type(self.context, &conversion.to_type)?;
+ self.resolver.with_result(conversion.dst, |dst| unsafe {
+ LLVMBuildSExt(builder, src, type_, dst)
+ });
+ Ok(())
+ }
ConversionKind::BitToPtr => {
let src = self.resolver.value(conversion.src)?;
let type_ = get_pointer_type(self.context, conversion.to_space)?;
@@ -488,8 +608,131 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { });
Ok(())
}
- ConversionKind::PtrToPtr => todo!(),
- ConversionKind::AddressOf => todo!(),
+ ConversionKind::PtrToPtr => {
+ let src = self.resolver.value(conversion.src)?;
+ let dst_type = get_pointer_type(self.context, conversion.to_space)?;
+ self.resolver.with_result(conversion.dst, |dst| unsafe {
+ LLVMBuildAddrSpaceCast(builder, src, dst_type, dst)
+ });
+ Ok(())
+ }
+ ConversionKind::AddressOf => {
+ let src = self.resolver.value(conversion.src)?;
+ let dst_type = get_type(self.context, &conversion.to_type)?;
+ self.resolver.with_result(conversion.dst, |dst| unsafe {
+ LLVMBuildPtrToInt(self.builder, src, dst_type, dst)
+ });
+ Ok(())
+ }
+ }
+ }
+
+ fn emit_conversion_default(
+ &mut self,
+ src: LLVMValueRef,
+ dst: SpirvWord,
+ from_type: &ast::Type,
+ from_space: ast::StateSpace,
+ to_type: &ast::Type,
+ to_space: ast::StateSpace,
+ ) -> Result<(), TranslateError> {
+ match (from_type, to_type) {
+ (ast::Type::Scalar(from_type), ast::Type::Scalar(to_type_scalar)) => {
+ let from_layout = from_type.layout();
+ let to_layout = to_type.layout();
+ if from_layout.size() == to_layout.size() {
+ let dst_type = get_type(self.context, &to_type)?;
+ if from_type.kind() != ast::ScalarKind::Float
+ && to_type_scalar.kind() != ast::ScalarKind::Float
+ {
+ // It is noop, but another instruction expects result of this conversion
+ self.resolver.register(dst, src);
+ } else {
+ self.resolver.with_result(dst, |dst| unsafe {
+ LLVMBuildBitCast(self.builder, src, dst_type, dst)
+ });
+ }
+ Ok(())
+ } else {
+ // This block is safe because it's illegal to implictly convert between floating point values
+ let same_width_bit_type = unsafe {
+ LLVMIntTypeInContext(self.context, (from_layout.size() * 8) as u32)
+ };
+ let same_width_bit_value = unsafe {
+ LLVMBuildBitCast(
+ self.builder,
+ src,
+ same_width_bit_type,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ let wide_bit_type = match to_type_scalar.layout().size() {
+ 1 => ast::ScalarType::B8,
+ 2 => ast::ScalarType::B16,
+ 4 => ast::ScalarType::B32,
+ 8 => ast::ScalarType::B64,
+ _ => return Err(error_unreachable()),
+ };
+ let wide_bit_type_llvm = unsafe {
+ LLVMIntTypeInContext(self.context, (to_layout.size() * 8) as u32)
+ };
+ if to_type_scalar.kind() == ast::ScalarKind::Unsigned
+ || to_type_scalar.kind() == ast::ScalarKind::Bit
+ {
+ let llvm_fn = if to_type_scalar.size_of() >= from_type.size_of() {
+ LLVMBuildZExtOrBitCast
+ } else {
+ LLVMBuildTrunc
+ };
+ self.resolver.with_result(dst, |dst| unsafe {
+ llvm_fn(self.builder, same_width_bit_value, wide_bit_type_llvm, dst)
+ });
+ Ok(())
+ } else {
+ let conversion_fn = if from_type.kind() == ast::ScalarKind::Signed
+ && to_type_scalar.kind() == ast::ScalarKind::Signed
+ {
+ if to_type_scalar.size_of() >= from_type.size_of() {
+ LLVMBuildSExtOrBitCast
+ } else {
+ LLVMBuildTrunc
+ }
+ } else {
+ if to_type_scalar.size_of() >= from_type.size_of() {
+ LLVMBuildZExtOrBitCast
+ } else {
+ LLVMBuildTrunc
+ }
+ };
+ let wide_bit_value = unsafe {
+ conversion_fn(
+ self.builder,
+ same_width_bit_value,
+ wide_bit_type_llvm,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ self.emit_conversion_default(
+ wide_bit_value,
+ dst,
+ &wide_bit_type.into(),
+ from_space,
+ to_type,
+ to_space,
+ )
+ }
+ }
+ }
+ (ast::Type::Vector(..), ast::Type::Scalar(..))
+ | (ast::Type::Scalar(..), ast::Type::Array(..))
+ | (ast::Type::Array(..), ast::Type::Scalar(..)) => {
+ let dst_type = get_type(self.context, to_type)?;
+ self.resolver.with_result(dst, |dst| unsafe {
+ LLVMBuildBitCast(self.builder, src, dst_type, dst)
+ });
+ Ok(())
+ }
+ _ => todo!(),
}
}
@@ -514,8 +757,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { let src1 = self.resolver.value(arguments.src1)?;
let src2 = self.resolver.value(arguments.src2)?;
let fn_ = match data {
- ast::ArithDetails::Integer(integer) => LLVMBuildAdd,
- ast::ArithDetails::Float(float) => LLVMBuildFAdd,
+ ast::ArithDetails::Integer(..) => LLVMBuildAdd,
+ ast::ArithDetails::Float(..) => LLVMBuildFAdd,
};
self.resolver.with_result(arguments.dst, |dst| unsafe {
fn_(builder, src1, src2, dst)
@@ -525,8 +768,8 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_st(
&self,
- data: ptx_parser::StData,
- arguments: ptx_parser::StArgs<SpirvWord>,
+ data: ast::StData,
+ arguments: ast::StArgs<SpirvWord>,
) -> Result<(), TranslateError> {
let ptr = self.resolver.value(arguments.src1)?;
let value = self.resolver.value(arguments.src2)?;
@@ -537,14 +780,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Ok(())
}
- fn emit_ret(&self, _data: ptx_parser::RetData) {
+ fn emit_ret(&self, _data: ast::RetData) {
unsafe { LLVMBuildRetVoid(self.builder) };
}
fn emit_call(
&mut self,
- data: ptx_parser::CallDetails,
- arguments: ptx_parser::CallArgs<SpirvWord>,
+ data: ast::CallDetails,
+ arguments: ast::CallArgs<SpirvWord>,
) -> Result<(), TranslateError> {
if cfg!(debug_assertions) {
for (_, space) in data.return_arguments.iter() {
@@ -558,14 +801,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { }
}
}
- let name = match (&*data.return_arguments, &*arguments.return_arguments) {
- ([], []) => LLVM_UNNAMED.as_ptr(),
- ([(type_, _)], [dst]) => self.resolver.get_or_add_raw(*dst),
+ let name = match &*arguments.return_arguments {
+ [] => LLVM_UNNAMED.as_ptr(),
+ [dst] => self.resolver.get_or_add_raw(*dst),
_ => todo!(),
};
let type_ = get_function_type(
self.context,
- data.return_arguments.iter().map(|(type_, space)| type_),
+ data.return_arguments.iter().map(|(type_, ..)| type_),
data.input_arguments
.iter()
.map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)),
@@ -597,13 +840,1380 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_mov(
&mut self,
- _data: ptx_parser::MovDetails,
- arguments: ptx_parser::MovArgs<SpirvWord>,
+ _data: ast::MovDetails,
+ arguments: ast::MovArgs<SpirvWord>,
) -> Result<(), TranslateError> {
self.resolver
.register(arguments.dst, self.resolver.value(arguments.src)?);
Ok(())
}
+
+ fn emit_ptr_access(&mut self, ptr_access: PtrAccess<SpirvWord>) -> Result<(), TranslateError> {
+ let ptr_src = self.resolver.value(ptr_access.ptr_src)?;
+ let mut offset_src = self.resolver.value(ptr_access.offset_src)?;
+ let pointee_type = get_scalar_type(self.context, ast::ScalarType::B8);
+ self.resolver.with_result(ptr_access.dst, |dst| unsafe {
+ LLVMBuildInBoundsGEP2(self.builder, pointee_type, ptr_src, &mut offset_src, 1, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_and(&mut self, arguments: ast::AndArgs<SpirvWord>) -> Result<(), TranslateError> {
+ let builder = self.builder;
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildAnd(builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_atom(
+ &mut self,
+ data: ast::AtomDetails,
+ arguments: ast::AtomArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let builder = self.builder;
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ let op = match data.op {
+ ast::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd,
+ ast::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr,
+ ast::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor,
+ ast::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg,
+ ast::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd,
+ ast::AtomicOp::IncrementWrap => {
+ LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap
+ }
+ ast::AtomicOp::DecrementWrap => {
+ LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap
+ }
+ ast::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin,
+ ast::AtomicOp::UnsignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin,
+ ast::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax,
+ ast::AtomicOp::UnsignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax,
+ ast::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd,
+ ast::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin,
+ ast::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax,
+ };
+ self.resolver.register(arguments.dst, unsafe {
+ LLVMZludaBuildAtomicRMW(
+ builder,
+ op,
+ src1,
+ src2,
+ get_scope(data.scope)?,
+ get_ordering(data.semantics),
+ )
+ });
+ Ok(())
+ }
+
+ fn emit_atom_cas(
+ &mut self,
+ data: ast::AtomCasDetails,
+ arguments: ast::AtomCasArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ let src3 = self.resolver.value(arguments.src3)?;
+ let success_ordering = get_ordering(data.semantics);
+ let failure_ordering = get_ordering_failure(data.semantics);
+ let temp = unsafe {
+ LLVMZludaBuildAtomicCmpXchg(
+ self.builder,
+ src1,
+ src2,
+ src3,
+ get_scope(data.scope)?,
+ success_ordering,
+ failure_ordering,
+ )
+ };
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildExtractValue(self.builder, temp, 0, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_bra(&self, arguments: ast::BraArgs<SpirvWord>) -> Result<(), TranslateError> {
+ let src = self.resolver.value(arguments.src)?;
+ let src = unsafe { LLVMValueAsBasicBlock(src) };
+ unsafe { LLVMBuildBr(self.builder, src) };
+ Ok(())
+ }
+
+ fn emit_brev(
+ &mut self,
+ data: ast::ScalarType,
+ arguments: ast::BrevArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_fn = match data.size_of() {
+ 4 => c"llvm.bitreverse.i32",
+ 8 => c"llvm.bitreverse.i64",
+ _ => return Err(error_unreachable()),
+ };
+ let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
+ let type_ = get_scalar_type(self.context, data);
+ let fn_type = get_function_type(
+ self.context,
+ iter::once(&data.into()),
+ iter::once(Ok(type_)),
+ )?;
+ if fn_ == ptr::null_mut() {
+ fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
+ }
+ let mut src = self.resolver.value(arguments.src)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildCall2(self.builder, fn_type, fn_, &mut src, 1, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_ret_value(
+ &mut self,
+ values: Vec<(SpirvWord, ptx_parser::Type)>,
+ ) -> Result<(), TranslateError> {
+ match &*values {
+ [] => unsafe { LLVMBuildRetVoid(self.builder) },
+ [(value, type_)] => {
+ let value = self.resolver.value(*value)?;
+ let type_ = get_type(self.context, type_)?;
+ let value =
+ unsafe { LLVMBuildLoad2(self.builder, type_, value, LLVM_UNNAMED.as_ptr()) };
+ unsafe { LLVMBuildRet(self.builder, value) }
+ }
+ _ => todo!(),
+ };
+ Ok(())
+ }
+
+ fn emit_clz(
+ &mut self,
+ data: ptx_parser::ScalarType,
+ arguments: ptx_parser::ClzArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_fn = match data.size_of() {
+ 4 => c"llvm.ctlz.i32",
+ 8 => c"llvm.ctlz.i64",
+ _ => return Err(error_unreachable()),
+ };
+ let type_ = get_scalar_type(self.context, data.into());
+ let pred = get_scalar_type(self.context, ast::ScalarType::Pred);
+ let fn_type = get_function_type(
+ self.context,
+ iter::once(&ast::ScalarType::U32.into()),
+ [Ok(type_), Ok(pred)].into_iter(),
+ )?;
+ let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, llvm_fn.as_ptr()) };
+ if fn_ == ptr::null_mut() {
+ fn_ = unsafe { LLVMAddFunction(self.module, llvm_fn.as_ptr(), fn_type) };
+ }
+ let src = self.resolver.value(arguments.src)?;
+ let false_ = unsafe { LLVMConstInt(pred, 0, 0) };
+ let mut args = [src, false_];
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ fn_type,
+ fn_,
+ args.as_mut_ptr(),
+ args.len() as u32,
+ dst,
+ )
+ });
+ Ok(())
+ }
+
+ fn emit_mul(
+ &mut self,
+ data: ast::MulDetails,
+ arguments: ast::MulArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ self.emit_mul_impl(data, Some(arguments.dst), arguments.src1, arguments.src2)?;
+ Ok(())
+ }
+
+ fn emit_mul_impl(
+ &mut self,
+ data: ast::MulDetails,
+ dst: Option<SpirvWord>,
+ src1: SpirvWord,
+ src2: SpirvWord,
+ ) -> Result<LLVMValueRef, TranslateError> {
+ let mul_fn = match data {
+ ast::MulDetails::Integer { control, type_ } => match control {
+ ast::MulIntControl::Low => LLVMBuildMul,
+ ast::MulIntControl::High => return self.emit_mul_high(type_, dst, src1, src2),
+ ast::MulIntControl::Wide => {
+ return Ok(self.emit_mul_wide_impl(type_, dst, src1, src2)?.1)
+ }
+ },
+ ast::MulDetails::Float(..) => LLVMBuildFMul,
+ };
+ let src1 = self.resolver.value(src1)?;
+ let src2 = self.resolver.value(src2)?;
+ Ok(self
+ .resolver
+ .with_result_option(dst, |dst| unsafe { mul_fn(self.builder, src1, src2, dst) }))
+ }
+
+ fn emit_mul_high(
+ &mut self,
+ type_: ptx_parser::ScalarType,
+ dst: Option<SpirvWord>,
+ src1: SpirvWord,
+ src2: SpirvWord,
+ ) -> Result<LLVMValueRef, TranslateError> {
+ let (wide_type, wide_value) = self.emit_mul_wide_impl(type_, None, src1, src2)?;
+ let shift_constant =
+ unsafe { LLVMConstInt(wide_type, (type_.layout().size() * 8) as u64, 0) };
+ let shifted = unsafe {
+ LLVMBuildLShr(
+ self.builder,
+ wide_value,
+ shift_constant,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ let narrow_type = get_scalar_type(self.context, type_);
+ Ok(self.resolver.with_result_option(dst, |dst| unsafe {
+ LLVMBuildTrunc(self.builder, shifted, narrow_type, dst)
+ }))
+ }
+
+ fn emit_mul_wide_impl(
+ &mut self,
+ type_: ptx_parser::ScalarType,
+ dst: Option<SpirvWord>,
+ src1: SpirvWord,
+ src2: SpirvWord,
+ ) -> Result<(LLVMTypeRef, LLVMValueRef), TranslateError> {
+ let src1 = self.resolver.value(src1)?;
+ let src2 = self.resolver.value(src2)?;
+ let wide_type =
+ unsafe { LLVMIntTypeInContext(self.context, (type_.layout().size() * 8 * 2) as u32) };
+ let llvm_cast = match type_.kind() {
+ ptx_parser::ScalarKind::Signed => LLVMBuildSExt,
+ ptx_parser::ScalarKind::Unsigned => LLVMBuildZExt,
+ _ => return Err(error_unreachable()),
+ };
+ let src1 = unsafe { llvm_cast(self.builder, src1, wide_type, LLVM_UNNAMED.as_ptr()) };
+ let src2 = unsafe { llvm_cast(self.builder, src2, wide_type, LLVM_UNNAMED.as_ptr()) };
+ Ok((
+ wide_type,
+ self.resolver.with_result_option(dst, |dst| unsafe {
+ LLVMBuildMul(self.builder, src1, src2, dst)
+ }),
+ ))
+ }
+
+ fn emit_cos(
+ &mut self,
+ _data: ast::FlushToZero,
+ arguments: ast::CosArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32);
+ let cos = self.emit_intrinsic(
+ c"llvm.cos.f32",
+ Some(arguments.dst),
+ &ast::ScalarType::F32.into(),
+ vec![(self.resolver.value(arguments.src)?, llvm_f32)],
+ )?;
+ unsafe { LLVMZludaSetFastMathFlags(cos, LLVMZludaFastMathApproxFunc) }
+ Ok(())
+ }
+
+ fn emit_or(
+ &mut self,
+ _data: ptx_parser::ScalarType,
+ arguments: ptx_parser::OrArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildOr(self.builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_xor(
+ &mut self,
+ _data: ptx_parser::ScalarType,
+ arguments: ptx_parser::XorArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildXor(self.builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_vector_read(&mut self, vec_acccess: VectorRead) -> Result<(), TranslateError> {
+ let src = self.resolver.value(vec_acccess.vector_src)?;
+ let index = unsafe {
+ LLVMConstInt(
+ get_scalar_type(self.context, ast::ScalarType::B8),
+ vec_acccess.member as _,
+ 0,
+ )
+ };
+ self.resolver
+ .with_result(vec_acccess.scalar_dst, |dst| unsafe {
+ LLVMBuildExtractElement(self.builder, src, index, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_vector_write(&mut self, vector_write: VectorWrite) -> Result<(), TranslateError> {
+ let vector_src = self.resolver.value(vector_write.vector_src)?;
+ let scalar_src = self.resolver.value(vector_write.scalar_src)?;
+ let index = unsafe {
+ LLVMConstInt(
+ get_scalar_type(self.context, ast::ScalarType::B8),
+ vector_write.member as _,
+ 0,
+ )
+ };
+ self.resolver
+ .with_result(vector_write.vector_dst, |dst| unsafe {
+ LLVMBuildInsertElement(self.builder, vector_src, scalar_src, index, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_vector_repack(&mut self, repack: RepackVectorDetails) -> Result<(), TranslateError> {
+ let i8_type = get_scalar_type(self.context, ast::ScalarType::B8);
+ if repack.is_extract {
+ let src = self.resolver.value(repack.packed)?;
+ for (index, dst) in repack.unpacked.iter().enumerate() {
+ let index: *mut LLVMValue = unsafe { LLVMConstInt(i8_type, index as _, 0) };
+ self.resolver.with_result(*dst, |dst| unsafe {
+ LLVMBuildExtractElement(self.builder, src, index, dst)
+ });
+ }
+ } else {
+ let vector_type = get_type(
+ self.context,
+ &ast::Type::Vector(repack.unpacked.len() as u8, repack.typ),
+ )?;
+ let mut temp_vec = unsafe { LLVMGetUndef(vector_type) };
+ for (index, src_id) in repack.unpacked.iter().enumerate() {
+ let dst = if index == repack.unpacked.len() - 1 {
+ Some(repack.packed)
+ } else {
+ None
+ };
+ let scalar_src = self.resolver.value(*src_id)?;
+ let index = unsafe { LLVMConstInt(i8_type, index as _, 0) };
+ temp_vec = self.resolver.with_result_option(dst, |dst| unsafe {
+ LLVMBuildInsertElement(self.builder, temp_vec, scalar_src, index, dst)
+ });
+ }
+ }
+ Ok(())
+ }
+
+ fn emit_div(
+ &mut self,
+ data: ptx_parser::DivDetails,
+ arguments: ptx_parser::DivArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let integer_div = match data {
+ ptx_parser::DivDetails::Unsigned(_) => LLVMBuildUDiv,
+ ptx_parser::DivDetails::Signed(_) => LLVMBuildSDiv,
+ ptx_parser::DivDetails::Float(float_div) => {
+ return self.emit_div_float(float_div, arguments)
+ }
+ };
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ integer_div(self.builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_div_float(
+ &mut self,
+ float_div: ptx_parser::DivFloatDetails,
+ arguments: ptx_parser::DivArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let builder = self.builder;
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ let _rnd = match float_div.kind {
+ ptx_parser::DivFloatKind::Approx => ast::RoundingMode::NearestEven,
+ ptx_parser::DivFloatKind::ApproxFull => ast::RoundingMode::NearestEven,
+ ptx_parser::DivFloatKind::Rounding(rounding_mode) => rounding_mode,
+ };
+ let approx = match float_div.kind {
+ ptx_parser::DivFloatKind::Approx => {
+ LLVMZludaFastMathAllowReciprocal | LLVMZludaFastMathApproxFunc
+ }
+ ptx_parser::DivFloatKind::ApproxFull => LLVMZludaFastMathNone,
+ ptx_parser::DivFloatKind::Rounding(_) => LLVMZludaFastMathNone,
+ };
+ let fdiv = self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildFDiv(builder, src1, src2, dst)
+ });
+ unsafe { LLVMZludaSetFastMathFlags(fdiv, approx) };
+ if let ptx_parser::DivFloatKind::ApproxFull = float_div.kind {
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-div:
+ // div.full.f32 implements a relatively fast, full-range approximation that scales
+ // operands to achieve better accuracy, but is not fully IEEE 754 compliant and does not
+ // support rounding modifiers. The maximum ulp error is 2 across the full range of
+ // inputs.
+ // https://llvm.org/docs/LangRef.html#fpmath-metadata
+ let fpmath_value =
+ unsafe { LLVMConstReal(get_scalar_type(self.context, ast::ScalarType::F32), 2.0) };
+ let fpmath_value = unsafe { LLVMValueAsMetadata(fpmath_value) };
+ let mut md_node_content = [fpmath_value];
+ let md_node = unsafe {
+ LLVMMDNodeInContext2(
+ self.context,
+ md_node_content.as_mut_ptr(),
+ md_node_content.len(),
+ )
+ };
+ let md_node = unsafe { LLVMMetadataAsValue(self.context, md_node) };
+ let kind = unsafe {
+ LLVMGetMDKindIDInContext(
+ self.context,
+ "fpmath".as_ptr().cast(),
+ "fpmath".len() as u32,
+ )
+ };
+ unsafe { LLVMSetMetadata(fdiv, kind, md_node) };
+ }
+ Ok(())
+ }
+
+ fn emit_cvta(
+ &mut self,
+ data: ptx_parser::CvtaDetails,
+ arguments: ptx_parser::CvtaArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let (from_space, to_space) = match data.direction {
+ ptx_parser::CvtaDirection::GenericToExplicit => {
+ (ast::StateSpace::Generic, data.state_space)
+ }
+ ptx_parser::CvtaDirection::ExplicitToGeneric => {
+ (data.state_space, ast::StateSpace::Generic)
+ }
+ };
+ let from_type = get_pointer_type(self.context, from_space)?;
+ let dest_type = get_pointer_type(self.context, to_space)?;
+ let src = self.resolver.value(arguments.src)?;
+ let temp_ptr =
+ unsafe { LLVMBuildIntToPtr(self.builder, src, from_type, LLVM_UNNAMED.as_ptr()) };
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildAddrSpaceCast(self.builder, temp_ptr, dest_type, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_sub(
+ &mut self,
+ data: ptx_parser::ArithDetails,
+ arguments: ptx_parser::SubArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ match data {
+ ptx_parser::ArithDetails::Integer(arith_integer) => {
+ self.emit_sub_integer(arith_integer, arguments)
+ }
+ ptx_parser::ArithDetails::Float(arith_float) => {
+ self.emit_sub_float(arith_float, arguments)
+ }
+ }
+ }
+
+ fn emit_sub_integer(
+ &mut self,
+ arith_integer: ptx_parser::ArithInteger,
+ arguments: ptx_parser::SubArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ if arith_integer.saturate {
+ todo!()
+ }
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildSub(self.builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_sub_float(
+ &mut self,
+ arith_float: ptx_parser::ArithFloat,
+ arguments: ptx_parser::SubArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ if arith_float.saturate {
+ todo!()
+ }
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildFSub(self.builder, src1, src2, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_sin(
+ &mut self,
+ _data: ptx_parser::FlushToZero,
+ arguments: ptx_parser::SinArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_f32 = get_scalar_type(self.context, ast::ScalarType::F32);
+ let sin = self.emit_intrinsic(
+ c"llvm.sin.f32",
+ Some(arguments.dst),
+ &ast::ScalarType::F32.into(),
+ vec![(self.resolver.value(arguments.src)?, llvm_f32)],
+ )?;
+ unsafe { LLVMZludaSetFastMathFlags(sin, LLVMZludaFastMathApproxFunc) }
+ Ok(())
+ }
+
+ fn emit_intrinsic(
+ &mut self,
+ name: &CStr,
+ dst: Option<SpirvWord>,
+ return_type: &ast::Type,
+ arguments: Vec<(LLVMValueRef, LLVMTypeRef)>,
+ ) -> Result<LLVMValueRef, TranslateError> {
+ let fn_type = get_function_type(
+ self.context,
+ iter::once(return_type),
+ arguments.iter().map(|(_, type_)| Ok(*type_)),
+ )?;
+ let mut fn_ = unsafe { LLVMGetNamedFunction(self.module, name.as_ptr()) };
+ if fn_ == ptr::null_mut() {
+ fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
+ }
+ let mut arguments = arguments.iter().map(|(arg, _)| *arg).collect::<Vec<_>>();
+ Ok(self.resolver.with_result_option(dst, |dst| unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ fn_type,
+ fn_,
+ arguments.as_mut_ptr(),
+ arguments.len() as u32,
+ dst,
+ )
+ }))
+ }
+
+ fn emit_neg(
+ &mut self,
+ data: ptx_parser::TypeFtz,
+ arguments: ptx_parser::NegArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let src = self.resolver.value(arguments.src)?;
+ let llvm_fn = if data.type_.kind() == ptx_parser::ScalarKind::Float {
+ LLVMBuildFNeg
+ } else {
+ LLVMBuildNeg
+ };
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ llvm_fn(self.builder, src, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_not(
+ &mut self,
+ _data: ptx_parser::ScalarType,
+ arguments: ptx_parser::NotArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let src = self.resolver.value(arguments.src)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildNot(self.builder, src, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_setp(
+ &mut self,
+ data: ptx_parser::SetpData,
+ arguments: ptx_parser::SetpArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ if arguments.dst2.is_some() {
+ todo!()
+ }
+ match data.cmp_op {
+ ptx_parser::SetpCompareOp::Integer(setp_compare_int) => {
+ self.emit_setp_int(setp_compare_int, arguments)
+ }
+ ptx_parser::SetpCompareOp::Float(setp_compare_float) => {
+ self.emit_setp_float(setp_compare_float, arguments)
+ }
+ }
+ }
+
+ fn emit_setp_int(
+ &mut self,
+ setp: ptx_parser::SetpCompareInt,
+ arguments: ptx_parser::SetpArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let op = match setp {
+ ptx_parser::SetpCompareInt::Eq => LLVMIntPredicate::LLVMIntEQ,
+ ptx_parser::SetpCompareInt::NotEq => LLVMIntPredicate::LLVMIntNE,
+ ptx_parser::SetpCompareInt::UnsignedLess => LLVMIntPredicate::LLVMIntULT,
+ ptx_parser::SetpCompareInt::UnsignedLessOrEq => LLVMIntPredicate::LLVMIntULE,
+ ptx_parser::SetpCompareInt::UnsignedGreater => LLVMIntPredicate::LLVMIntUGT,
+ ptx_parser::SetpCompareInt::UnsignedGreaterOrEq => LLVMIntPredicate::LLVMIntUGE,
+ ptx_parser::SetpCompareInt::SignedLess => LLVMIntPredicate::LLVMIntSLT,
+ ptx_parser::SetpCompareInt::SignedLessOrEq => LLVMIntPredicate::LLVMIntSLE,
+ ptx_parser::SetpCompareInt::SignedGreater => LLVMIntPredicate::LLVMIntSGT,
+ ptx_parser::SetpCompareInt::SignedGreaterOrEq => LLVMIntPredicate::LLVMIntSGE,
+ };
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst1, |dst1| unsafe {
+ LLVMBuildICmp(self.builder, op, src1, src2, dst1)
+ });
+ Ok(())
+ }
+
+ fn emit_setp_float(
+ &mut self,
+ setp: ptx_parser::SetpCompareFloat,
+ arguments: ptx_parser::SetpArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let op = match setp {
+ ptx_parser::SetpCompareFloat::Eq => LLVMRealPredicate::LLVMRealOEQ,
+ ptx_parser::SetpCompareFloat::NotEq => LLVMRealPredicate::LLVMRealONE,
+ ptx_parser::SetpCompareFloat::Less => LLVMRealPredicate::LLVMRealOLT,
+ ptx_parser::SetpCompareFloat::LessOrEq => LLVMRealPredicate::LLVMRealOLE,
+ ptx_parser::SetpCompareFloat::Greater => LLVMRealPredicate::LLVMRealOGT,
+ ptx_parser::SetpCompareFloat::GreaterOrEq => LLVMRealPredicate::LLVMRealOGE,
+ ptx_parser::SetpCompareFloat::NanEq => LLVMRealPredicate::LLVMRealUEQ,
+ ptx_parser::SetpCompareFloat::NanNotEq => LLVMRealPredicate::LLVMRealUNE,
+ ptx_parser::SetpCompareFloat::NanLess => LLVMRealPredicate::LLVMRealULT,
+ ptx_parser::SetpCompareFloat::NanLessOrEq => LLVMRealPredicate::LLVMRealULE,
+ ptx_parser::SetpCompareFloat::NanGreater => LLVMRealPredicate::LLVMRealUGT,
+ ptx_parser::SetpCompareFloat::NanGreaterOrEq => LLVMRealPredicate::LLVMRealUGE,
+ ptx_parser::SetpCompareFloat::IsNotNan => LLVMRealPredicate::LLVMRealORD,
+ ptx_parser::SetpCompareFloat::IsAnyNan => LLVMRealPredicate::LLVMRealUNO,
+ };
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst1, |dst1| unsafe {
+ LLVMBuildFCmp(self.builder, op, src1, src2, dst1)
+ });
+ Ok(())
+ }
+
+ fn emit_conditional(&mut self, cond: BrachCondition) -> Result<(), TranslateError> {
+ let predicate = self.resolver.value(cond.predicate)?;
+ let if_true = self.resolver.value(cond.if_true)?;
+ let if_false = self.resolver.value(cond.if_false)?;
+ unsafe {
+ LLVMBuildCondBr(
+ self.builder,
+ predicate,
+ LLVMValueAsBasicBlock(if_true),
+ LLVMValueAsBasicBlock(if_false),
+ )
+ };
+ Ok(())
+ }
+
+ fn emit_cvt(
+ &mut self,
+ data: ptx_parser::CvtDetails,
+ arguments: ptx_parser::CvtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let dst_type = get_scalar_type(self.context, data.to);
+ let llvm_fn = match data.mode {
+ ptx_parser::CvtMode::ZeroExtend => LLVMBuildZExt,
+ ptx_parser::CvtMode::SignExtend => LLVMBuildSExt,
+ ptx_parser::CvtMode::Truncate => LLVMBuildTrunc,
+ ptx_parser::CvtMode::Bitcast => LLVMBuildBitCast,
+ ptx_parser::CvtMode::SaturateUnsignedToSigned => {
+ return self.emit_cvt_unsigned_to_signed_sat(data.from, data.to, arguments)
+ }
+ ptx_parser::CvtMode::SaturateSignedToUnsigned => {
+ return self.emit_cvt_signed_to_unsigned_sat(data.from, data.to, arguments)
+ }
+ ptx_parser::CvtMode::FPExtend { .. } => LLVMBuildFPExt,
+ ptx_parser::CvtMode::FPTruncate { .. } => LLVMBuildFPTrunc,
+ ptx_parser::CvtMode::FPRound {
+ integer_rounding, ..
+ } => {
+ return self.emit_cvt_float_to_int(
+ data.from,
+ data.to,
+ integer_rounding.unwrap_or(ast::RoundingMode::NearestEven),
+ arguments,
+ Some(LLVMBuildFPToSI),
+ )
+ }
+ ptx_parser::CvtMode::SignedFromFP { rounding, .. } => {
+ return self.emit_cvt_float_to_int(
+ data.from,
+ data.to,
+ rounding,
+ arguments,
+ Some(LLVMBuildFPToSI),
+ )
+ }
+ ptx_parser::CvtMode::UnsignedFromFP { rounding, .. } => {
+ return self.emit_cvt_float_to_int(
+ data.from,
+ data.to,
+ rounding,
+ arguments,
+ Some(LLVMBuildFPToUI),
+ )
+ }
+ ptx_parser::CvtMode::FPFromSigned(_) => todo!(),
+ ptx_parser::CvtMode::FPFromUnsigned(_) => todo!(),
+ };
+ let src = self.resolver.value(arguments.src)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ llvm_fn(self.builder, src, dst_type, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_cvt_unsigned_to_signed_sat(
+ &mut self,
+ from: ptx_parser::ScalarType,
+ to: ptx_parser::ScalarType,
+ arguments: ptx_parser::CvtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ // This looks dodgy, but it's fine. MAX bit pattern is always 0b11..1,
+ // so if it's downcast to a smaller type, it will be the maximum value
+ // of the smaller type
+ let max_value = match to {
+ ptx_parser::ScalarType::S8 => i8::MAX as u64,
+ ptx_parser::ScalarType::S16 => i16::MAX as u64,
+ ptx_parser::ScalarType::S32 => i32::MAX as u64,
+ ptx_parser::ScalarType::S64 => i64::MAX as u64,
+ _ => return Err(error_unreachable()),
+ };
+ let from_llvm = get_scalar_type(self.context, from);
+ let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
+ let clamped = self.emit_intrinsic(
+ c"llvm.umin",
+ None,
+ &from.into(),
+ vec![
+ (self.resolver.value(arguments.src)?, from_llvm),
+ (max, from_llvm),
+ ],
+ )?;
+ let resize_fn = if to.layout().size() >= from.layout().size() {
+ LLVMBuildSExtOrBitCast
+ } else {
+ LLVMBuildTrunc
+ };
+ let to_llvm = get_scalar_type(self.context, to);
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ resize_fn(self.builder, clamped, to_llvm, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_cvt_signed_to_unsigned_sat(
+ &mut self,
+ from: ptx_parser::ScalarType,
+ to: ptx_parser::ScalarType,
+ arguments: ptx_parser::CvtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let from_llvm = get_scalar_type(self.context, from);
+ let zero = unsafe { LLVMConstInt(from_llvm, 0, 0) };
+ let zero_clamp_intrinsic = format!("llvm.smax.{}\0", LLVMTypeDisplay(from));
+ let zero_clamped = self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(zero_clamp_intrinsic.as_bytes()) },
+ None,
+ &from.into(),
+ vec![
+ (self.resolver.value(arguments.src)?, from_llvm),
+ (zero, from_llvm),
+ ],
+ )?;
+ // zero_clamped is now unsigned
+ let max_value = match to {
+ ptx_parser::ScalarType::U8 => u8::MAX as u64,
+ ptx_parser::ScalarType::U16 => u16::MAX as u64,
+ ptx_parser::ScalarType::U32 => u32::MAX as u64,
+ ptx_parser::ScalarType::U64 => u64::MAX as u64,
+ _ => return Err(error_unreachable()),
+ };
+ let max = unsafe { LLVMConstInt(from_llvm, max_value, 0) };
+ let max_clamp_intrinsic = format!("llvm.umin.{}\0", LLVMTypeDisplay(from));
+ let fully_clamped = self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(max_clamp_intrinsic.as_bytes()) },
+ None,
+ &from.into(),
+ vec![(zero_clamped, from_llvm), (max, from_llvm)],
+ )?;
+ let resize_fn = if to.layout().size() >= from.layout().size() {
+ LLVMBuildZExtOrBitCast
+ } else {
+ LLVMBuildTrunc
+ };
+ let to_llvm = get_scalar_type(self.context, to);
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ resize_fn(self.builder, fully_clamped, to_llvm, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_cvt_float_to_int(
+ &mut self,
+ from: ast::ScalarType,
+ to: ast::ScalarType,
+ rounding: ast::RoundingMode,
+ arguments: ptx_parser::CvtArgs<SpirvWord>,
+ llvm_cast: Option<
+ unsafe extern "C" fn(
+ arg1: LLVMBuilderRef,
+ Val: LLVMValueRef,
+ DestTy: LLVMTypeRef,
+ Name: *const i8,
+ ) -> LLVMValueRef,
+ >,
+ ) -> Result<(), TranslateError> {
+ let prefix = match rounding {
+ ptx_parser::RoundingMode::NearestEven => "llvm.roundeven",
+ ptx_parser::RoundingMode::Zero => "llvm.trunc",
+ ptx_parser::RoundingMode::NegativeInf => "llvm.floor",
+ ptx_parser::RoundingMode::PositiveInf => "llvm.ceil",
+ };
+ let intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(from));
+ let rounded_float = self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
+ None,
+ &from.into(),
+ vec![(
+ self.resolver.value(arguments.src)?,
+ get_scalar_type(self.context, from),
+ )],
+ )?;
+ if let Some(llvm_cast) = llvm_cast {
+ let to = get_scalar_type(self.context, to);
+ let poisoned_dst =
+ unsafe { llvm_cast(self.builder, rounded_float, to, LLVM_UNNAMED.as_ptr()) };
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildFreeze(self.builder, poisoned_dst, dst)
+ });
+ } else {
+ self.resolver.register(arguments.dst, rounded_float);
+ }
+ // Using explicit saturation gives us worse codegen: it explicitly checks for out of bound
+ // values and NaNs. Using non-saturated fptosi/fptoui emits v_cvt_<TO>_<FROM> which
+ // saturates by default and we don't care about NaNs anyway
+ /*
+ let cast_intrinsic = format!(
+ "{}.{}.{}\0",
+ llvm_cast,
+ LLVMTypeDisplay(to),
+ LLVMTypeDisplay(from)
+ );
+ self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(cast_intrinsic.as_bytes()) },
+ Some(arguments.dst),
+ &to.into(),
+ vec![(rounded_float, get_scalar_type(self.context, from))],
+ )?;
+ */
+ Ok(())
+ }
+
+ fn emit_rsqrt(
+ &mut self,
+ data: ptx_parser::TypeFtz,
+ arguments: ptx_parser::RsqrtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let type_ = get_scalar_type(self.context, data.type_);
+ let intrinsic = match data.type_ {
+ ast::ScalarType::F32 => c"llvm.amdgcn.rsq.f32",
+ ast::ScalarType::F64 => c"llvm.amdgcn.rsq.f64",
+ _ => return Err(error_unreachable()),
+ };
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![(self.resolver.value(arguments.src)?, type_)],
+ )?;
+ Ok(())
+ }
+
+ fn emit_sqrt(
+ &mut self,
+ data: ptx_parser::RcpData,
+ arguments: ptx_parser::SqrtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let type_ = get_scalar_type(self.context, data.type_);
+ let intrinsic = match (data.type_, data.kind) {
+ (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.sqrt.f32",
+ (ast::ScalarType::F32, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f32",
+ (ast::ScalarType::F64, ast::RcpKind::Compliant(..)) => c"llvm.sqrt.f64",
+ _ => return Err(error_unreachable()),
+ };
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![(self.resolver.value(arguments.src)?, type_)],
+ )?;
+ Ok(())
+ }
+
+ fn emit_rcp(
+ &mut self,
+ data: ptx_parser::RcpData,
+ arguments: ptx_parser::RcpArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let type_ = get_scalar_type(self.context, data.type_);
+ let intrinsic = match (data.type_, data.kind) {
+ (ast::ScalarType::F32, ast::RcpKind::Approx) => c"llvm.amdgcn.rcp.f32",
+ (_, ast::RcpKind::Compliant(rnd)) => {
+ return self.emit_rcp_compliant(data, arguments, rnd)
+ }
+ _ => return Err(error_unreachable()),
+ };
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![(self.resolver.value(arguments.src)?, type_)],
+ )?;
+ Ok(())
+ }
+
+ fn emit_rcp_compliant(
+ &mut self,
+ data: ptx_parser::RcpData,
+ arguments: ptx_parser::RcpArgs<SpirvWord>,
+ _rnd: ast::RoundingMode,
+ ) -> Result<(), TranslateError> {
+ let type_ = get_scalar_type(self.context, data.type_);
+ let one = unsafe { LLVMConstReal(type_, 1.0) };
+ let src = self.resolver.value(arguments.src)?;
+ let rcp = self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildFDiv(self.builder, one, src, dst)
+ });
+ unsafe { LLVMZludaSetFastMathFlags(rcp, LLVMZludaFastMathAllowReciprocal) };
+ Ok(())
+ }
+
+ fn emit_shr(
+ &mut self,
+ data: ptx_parser::ShrData,
+ arguments: ptx_parser::ShrArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let shift_fn = match data.kind {
+ ptx_parser::RightShiftKind::Arithmetic => LLVMBuildAShr,
+ ptx_parser::RightShiftKind::Logical => LLVMBuildLShr,
+ };
+ self.emit_shift(
+ data.type_,
+ arguments.dst,
+ arguments.src1,
+ arguments.src2,
+ shift_fn,
+ )
+ }
+
+ fn emit_shl(
+ &mut self,
+ type_: ptx_parser::ScalarType,
+ arguments: ptx_parser::ShlArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ self.emit_shift(
+ type_,
+ arguments.dst,
+ arguments.src1,
+ arguments.src2,
+ LLVMBuildShl,
+ )
+ }
+
+ fn emit_shift(
+ &mut self,
+ type_: ast::ScalarType,
+ dst: SpirvWord,
+ src1: SpirvWord,
+ src2: SpirvWord,
+ llvm_fn: unsafe extern "C" fn(
+ LLVMBuilderRef,
+ LLVMValueRef,
+ LLVMValueRef,
+ *const i8,
+ ) -> LLVMValueRef,
+ ) -> Result<(), TranslateError> {
+ let src1 = self.resolver.value(src1)?;
+ let shift_size = self.resolver.value(src2)?;
+ let integer_bits = type_.layout().size() * 8;
+ let integer_bits_constant = unsafe {
+ LLVMConstInt(
+ get_scalar_type(self.context, ast::ScalarType::U32),
+ integer_bits as u64,
+ 0,
+ )
+ };
+ let should_clamp = unsafe {
+ LLVMBuildICmp(
+ self.builder,
+ LLVMIntPredicate::LLVMIntUGE,
+ shift_size,
+ integer_bits_constant,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ let llvm_type = get_scalar_type(self.context, type_);
+ let zero = unsafe { LLVMConstNull(llvm_type) };
+ let normalized_shift_size = if type_.layout().size() >= 4 {
+ unsafe {
+ LLVMBuildZExtOrBitCast(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr())
+ }
+ } else {
+ unsafe { LLVMBuildTrunc(self.builder, shift_size, llvm_type, LLVM_UNNAMED.as_ptr()) }
+ };
+ let shifted = unsafe {
+ llvm_fn(
+ self.builder,
+ src1,
+ normalized_shift_size,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ self.resolver.with_result(dst, |dst| unsafe {
+ LLVMBuildSelect(self.builder, should_clamp, zero, shifted, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_ex2(
+ &mut self,
+ data: ptx_parser::TypeFtz,
+ arguments: ptx_parser::Ex2Args<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let intrinsic = match data.type_ {
+ ast::ScalarType::F16 => c"llvm.amdgcn.exp2.f16",
+ ast::ScalarType::F32 => c"llvm.amdgcn.exp2.f32",
+ _ => return Err(error_unreachable()),
+ };
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![(
+ self.resolver.value(arguments.src)?,
+ get_scalar_type(self.context, data.type_),
+ )],
+ )?;
+ Ok(())
+ }
+
+ fn emit_lg2(
+ &mut self,
+ _data: ptx_parser::FlushToZero,
+ arguments: ptx_parser::Lg2Args<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ self.emit_intrinsic(
+ c"llvm.amdgcn.log.f32",
+ Some(arguments.dst),
+ &ast::ScalarType::F32.into(),
+ vec![(
+ self.resolver.value(arguments.src)?,
+ get_scalar_type(self.context, ast::ScalarType::F32.into()),
+ )],
+ )?;
+ Ok(())
+ }
+
+ fn emit_selp(
+ &mut self,
+ _data: ptx_parser::ScalarType,
+ arguments: ptx_parser::SelpArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ let src3 = self.resolver.value(arguments.src3)?;
+ self.resolver.with_result(arguments.dst, |dst_name| unsafe {
+ LLVMBuildSelect(self.builder, src3, src1, src2, dst_name)
+ });
+ Ok(())
+ }
+
+ fn emit_rem(
+ &mut self,
+ data: ptx_parser::ScalarType,
+ arguments: ptx_parser::RemArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_fn = match data.kind() {
+ ptx_parser::ScalarKind::Unsigned => LLVMBuildURem,
+ ptx_parser::ScalarKind::Signed => LLVMBuildSRem,
+ _ => return Err(error_unreachable()),
+ };
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ self.resolver.with_result(arguments.dst, |dst_name| unsafe {
+ llvm_fn(self.builder, src1, src2, dst_name)
+ });
+ Ok(())
+ }
+
+ fn emit_popc(
+ &mut self,
+ type_: ptx_parser::ScalarType,
+ arguments: ptx_parser::PopcArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let intrinsic = match type_ {
+ ast::ScalarType::B32 => c"llvm.ctpop.i32",
+ ast::ScalarType::B64 => c"llvm.ctpop.i64",
+ _ => return Err(error_unreachable()),
+ };
+ let llvm_type = get_scalar_type(self.context, type_);
+ self.emit_intrinsic(
+ intrinsic,
+ Some(arguments.dst),
+ &type_.into(),
+ vec![(self.resolver.value(arguments.src)?, llvm_type)],
+ )?;
+ Ok(())
+ }
+
+ fn emit_min(
+ &mut self,
+ data: ptx_parser::MinMaxDetails,
+ arguments: ptx_parser::MinArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_prefix = match data {
+ ptx_parser::MinMaxDetails::Signed(..) => "llvm.smin",
+ ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umin",
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
+ return Err(error_todo())
+ }
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.maxnum",
+ };
+ let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
+ let llvm_type = get_scalar_type(self.context, data.type_());
+ self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
+ Some(arguments.dst),
+ &data.type_().into(),
+ vec![
+ (self.resolver.value(arguments.src1)?, llvm_type),
+ (self.resolver.value(arguments.src2)?, llvm_type),
+ ],
+ )?;
+ Ok(())
+ }
+
+ fn emit_max(
+ &mut self,
+ data: ptx_parser::MinMaxDetails,
+ arguments: ptx_parser::MaxArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_prefix = match data {
+ ptx_parser::MinMaxDetails::Signed(..) => "llvm.smax",
+ ptx_parser::MinMaxDetails::Unsigned(..) => "llvm.umax",
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { nan: true, .. }) => {
+ return Err(error_todo())
+ }
+ ptx_parser::MinMaxDetails::Float(ptx_parser::MinMaxFloat { .. }) => "llvm.minnum",
+ };
+ let intrinsic = format!("{}.{}\0", llvm_prefix, LLVMTypeDisplay(data.type_()));
+ let llvm_type = get_scalar_type(self.context, data.type_());
+ self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
+ Some(arguments.dst),
+ &data.type_().into(),
+ vec![
+ (self.resolver.value(arguments.src1)?, llvm_type),
+ (self.resolver.value(arguments.src2)?, llvm_type),
+ ],
+ )?;
+ Ok(())
+ }
+
+ fn emit_fma(
+ &mut self,
+ data: ptx_parser::ArithFloat,
+ arguments: ptx_parser::FmaArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let intrinsic = format!("llvm.fma.{}\0", LLVMTypeDisplay(data.type_));
+ self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(intrinsic.as_bytes()) },
+ Some(arguments.dst),
+ &data.type_.into(),
+ vec![
+ (
+ self.resolver.value(arguments.src1)?,
+ get_scalar_type(self.context, data.type_),
+ ),
+ (
+ self.resolver.value(arguments.src2)?,
+ get_scalar_type(self.context, data.type_),
+ ),
+ (
+ self.resolver.value(arguments.src3)?,
+ get_scalar_type(self.context, data.type_),
+ ),
+ ],
+ )?;
+ Ok(())
+ }
+
+ fn emit_mad(
+ &mut self,
+ data: ptx_parser::MadDetails,
+ arguments: ptx_parser::MadArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let mul_control = match data {
+ ptx_parser::MadDetails::Float(mad_float) => {
+ return self.emit_fma(
+ mad_float,
+ ast::FmaArgs {
+ dst: arguments.dst,
+ src1: arguments.src1,
+ src2: arguments.src2,
+ src3: arguments.src3,
+ },
+ )
+ }
+ ptx_parser::MadDetails::Integer { saturate: true, .. } => return Err(error_todo()),
+ ptx_parser::MadDetails::Integer { type_, control, .. } => {
+ ast::MulDetails::Integer { control, type_ }
+ }
+ };
+ let temp = self.emit_mul_impl(mul_control, None, arguments.src1, arguments.src2)?;
+ let src3 = self.resolver.value(arguments.src3)?;
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildAdd(self.builder, temp, src3, dst)
+ });
+ Ok(())
+ }
+
+ fn emit_membar(&self, data: ptx_parser::MemScope) -> Result<(), TranslateError> {
+ unsafe {
+ LLVMZludaBuildFence(
+ self.builder,
+ LLVMAtomicOrdering::LLVMAtomicOrderingSequentiallyConsistent,
+ get_scope_membar(data)?,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ Ok(())
+ }
+
+ fn emit_prmt(
+ &mut self,
+ control: u16,
+ arguments: ptx_parser::PrmtArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let components = [
+ (control >> 0) & 0b1111,
+ (control >> 4) & 0b1111,
+ (control >> 8) & 0b1111,
+ (control >> 12) & 0b1111,
+ ];
+ if components.iter().any(|&c| c > 7) {
+ return Err(TranslateError::Todo);
+ }
+ let u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
+ let v4u8_type = get_type(self.context, &ast::Type::Vector(4, ast::ScalarType::U8))?;
+ let mut components = [
+ unsafe { LLVMConstInt(u32_type, components[0] as _, 0) },
+ unsafe { LLVMConstInt(u32_type, components[1] as _, 0) },
+ unsafe { LLVMConstInt(u32_type, components[2] as _, 0) },
+ unsafe { LLVMConstInt(u32_type, components[3] as _, 0) },
+ ];
+ let components_indices =
+ unsafe { LLVMConstVector(components.as_mut_ptr(), components.len() as u32) };
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src1_vector =
+ unsafe { LLVMBuildBitCast(self.builder, src1, v4u8_type, LLVM_UNNAMED.as_ptr()) };
+ let src2 = self.resolver.value(arguments.src2)?;
+ let src2_vector =
+ unsafe { LLVMBuildBitCast(self.builder, src2, v4u8_type, LLVM_UNNAMED.as_ptr()) };
+ self.resolver.with_result(arguments.dst, |dst| unsafe {
+ LLVMBuildShuffleVector(
+ self.builder,
+ src1_vector,
+ src2_vector,
+ components_indices,
+ dst,
+ )
+ });
+ Ok(())
+ }
+
+ /*
+ // Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
+ // Should be available in LLVM 19
+ fn with_rounding<T>(&mut self, rnd: ast::RoundingMode, fn_: impl FnOnce(&mut Self) -> T) -> T {
+ let mut u32_type = get_scalar_type(self.context, ast::ScalarType::U32);
+ let void_type = unsafe { LLVMVoidTypeInContext(self.context) };
+ let get_rounding = c"llvm.get.rounding";
+ let get_rounding_fn_type = unsafe { LLVMFunctionType(u32_type, ptr::null_mut(), 0, 0) };
+ let mut get_rounding_fn =
+ unsafe { LLVMGetNamedFunction(self.module, get_rounding.as_ptr()) };
+ if get_rounding_fn == ptr::null_mut() {
+ get_rounding_fn = unsafe {
+ LLVMAddFunction(self.module, get_rounding.as_ptr(), get_rounding_fn_type)
+ };
+ }
+ let set_rounding = c"llvm.set.rounding";
+ let set_rounding_fn_type = unsafe { LLVMFunctionType(void_type, &mut u32_type, 1, 0) };
+ let mut set_rounding_fn =
+ unsafe { LLVMGetNamedFunction(self.module, set_rounding.as_ptr()) };
+ if set_rounding_fn == ptr::null_mut() {
+ set_rounding_fn = unsafe {
+ LLVMAddFunction(self.module, set_rounding.as_ptr(), set_rounding_fn_type)
+ };
+ }
+ let mut preserved_rounding_mode = unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ get_rounding_fn_type,
+ get_rounding_fn,
+ ptr::null_mut(),
+ 0,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ let mut requested_rounding = unsafe {
+ LLVMConstInt(
+ get_scalar_type(self.context, ast::ScalarType::B32),
+ rounding_to_llvm(rnd) as u64,
+ 0,
+ )
+ };
+ unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ set_rounding_fn_type,
+ set_rounding_fn,
+ &mut requested_rounding,
+ 1,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ let result = fn_(self);
+ unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ set_rounding_fn_type,
+ set_rounding_fn,
+ &mut preserved_rounding_mode,
+ 1,
+ LLVM_UNNAMED.as_ptr(),
+ )
+ };
+ result
+ }
+ */
}
fn get_pointer_type<'ctx>(
@@ -613,6 +2223,45 @@ fn get_pointer_type<'ctx>( Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) })
}
+// https://llvm.org/docs/AMDGPUUsage.html#memory-scopes
+fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
+ Ok(match scope {
+ ast::MemScope::Cta => c"workgroup-one-as",
+ ast::MemScope::Gpu => c"agent-one-as",
+ ast::MemScope::Sys => c"one-as",
+ ast::MemScope::Cluster => todo!(),
+ }
+ .as_ptr())
+}
+
+fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
+ Ok(match scope {
+ ast::MemScope::Cta => c"workgroup",
+ ast::MemScope::Gpu => c"agent",
+ ast::MemScope::Sys => c"",
+ ast::MemScope::Cluster => todo!(),
+ }
+ .as_ptr())
+}
+
+fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
+ match semantics {
+ ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
+ ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
+ ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingRelease,
+ ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquireRelease,
+ }
+}
+
+fn get_ordering_failure(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
+ match semantics {
+ ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
+ ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
+ ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
+ ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
+ }
+}
+
fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, TranslateError> {
Ok(match type_ {
ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar),
@@ -670,8 +2319,7 @@ fn get_function_type<'a>( mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
) -> Result<LLVMTypeRef, TranslateError> {
- let mut input_args: Vec<*mut llvm_zluda::LLVMType> =
- input_args.collect::<Result<Vec<_>, _>>()?;
+ let mut input_args = input_args.collect::<Result<Vec<_>, _>>()?;
let return_type = match return_args.len() {
0 => unsafe { LLVMVoidTypeInContext(context) },
1 => get_type(context, return_args.next().unwrap())?,
@@ -747,8 +2395,57 @@ impl ResolveIdent { .ok_or_else(|| error_unreachable())
}
- fn with_result(&mut self, word: SpirvWord, fn_: impl FnOnce(*const i8) -> LLVMValueRef) {
+ fn with_result(
+ &mut self,
+ word: SpirvWord,
+ fn_: impl FnOnce(*const i8) -> LLVMValueRef,
+ ) -> LLVMValueRef {
let t = self.get_or_ad_impl(word, |dst| fn_(dst.as_ptr().cast()));
self.register(word, t);
+ t
+ }
+
+ fn with_result_option(
+ &mut self,
+ word: Option<SpirvWord>,
+ fn_: impl FnOnce(*const i8) -> LLVMValueRef,
+ ) -> LLVMValueRef {
+ match word {
+ Some(word) => self.with_result(word, fn_),
+ None => fn_(LLVM_UNNAMED.as_ptr()),
+ }
+ }
+}
+
+struct LLVMTypeDisplay(ast::ScalarType);
+
+impl std::fmt::Display for LLVMTypeDisplay {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self.0 {
+ ast::ScalarType::Pred => write!(f, "i1"),
+ ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => write!(f, "i8"),
+ ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => write!(f, "i16"),
+ ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => write!(f, "i32"),
+ ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => write!(f, "i64"),
+ ptx_parser::ScalarType::B128 => write!(f, "i128"),
+ ast::ScalarType::F16 => write!(f, "f16"),
+ ptx_parser::ScalarType::BF16 => write!(f, "bfloat"),
+ ast::ScalarType::F32 => write!(f, "f32"),
+ ast::ScalarType::F64 => write!(f, "f64"),
+ ptx_parser::ScalarType::S16x2 | ptx_parser::ScalarType::U16x2 => write!(f, "v2i16"),
+ ast::ScalarType::F16x2 => write!(f, "v2f16"),
+ ptx_parser::ScalarType::BF16x2 => write!(f, "v2bfloat"),
+ }
+ }
+}
+
+/*
+fn rounding_to_llvm(this: ast::RoundingMode) -> u32 {
+ match this {
+ ptx_parser::RoundingMode::Zero => 0,
+ ptx_parser::RoundingMode::NearestEven => 1,
+ ptx_parser::RoundingMode::PositiveInf => 2,
+ ptx_parser::RoundingMode::NegativeInf => 3,
}
}
+*/
|