aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/pass/emit_llvm.rs
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2024-09-26 18:54:15 +0200
committerAndrzej Janik <[email protected]>2024-09-26 18:54:15 +0200
commit820eaf8ada98ca480695ced2d93109401cce9c9f (patch)
treed14aa19e3d9979ad596bcad41e1c34006f7907fc /ptx/src/pass/emit_llvm.rs
parentc4e131519413fbb070e5df60c9c6716d894c09be (diff)
downloadZLUDA-820eaf8ada98ca480695ced2d93109401cce9c9f.tar.gz
ZLUDA-820eaf8ada98ca480695ced2d93109401cce9c9f.zip
Implement atomics
Diffstat (limited to 'ptx/src/pass/emit_llvm.rs')
-rw-r--r--ptx/src/pass/emit_llvm.rs121
1 files changed, 115 insertions, 6 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs
index 7f74d1a..bc5f745 100644
--- a/ptx/src/pass/emit_llvm.rs
+++ b/ptx/src/pass/emit_llvm.rs
@@ -19,15 +19,15 @@
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
use std::convert::{TryFrom, TryInto};
-use std::ffi::CStr;
+use std::ffi::{CStr, NulError};
use std::ops::Deref;
use std::ptr;
use super::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
-use llvm_zluda::core::*;
-use llvm_zluda::prelude::*;
+use llvm_zluda::{core::*, LLVMAtomicOrdering, LLVMAtomicRMWBinOp, LLVMZludaAtomicRMWBinOp};
+use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
const LLVM_UNNAMED: &CStr = c"";
@@ -172,7 +172,7 @@ pub(super) fn run<'input>(
let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
for directive in directives {
match directive {
- Directive2::Variable(..) => todo!(),
+ Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?,
Directive2::Method(method) => emit_ctx.emit_method(method)?,
}
}
@@ -281,6 +281,43 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
}
Ok(())
}
+
+ fn emit_global(
+ &mut self,
+ linking: ast::LinkingDirective,
+ var: ptx_parser::Variable<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let name = self
+ .id_defs
+ .ident_map
+ .get(&var.name)
+ .map(|entry| {
+ entry
+ .name
+ .as_ref()
+ .map(|text| Ok::<_, NulError>(Cow::Owned(CString::new(&**text)?)))
+ })
+ .flatten()
+ .transpose()
+ .map_err(|_| error_unreachable())?
+ .unwrap_or(Cow::Borrowed(LLVM_UNNAMED));
+ let global = unsafe {
+ LLVMAddGlobalInAddressSpace(
+ self.module,
+ get_type(self.context, &var.v_type)?,
+ name.as_ptr(),
+ get_state_space(var.state_space)?,
+ )
+ };
+ self.resolver.register(var.name, global);
+ if let Some(align) = var.align {
+ unsafe { LLVMSetAlignment(global, align) };
+ }
+ if !var.array_init.is_empty() {
+ todo!()
+ }
+ Ok(())
+ }
}
fn get_input_argument_type(
@@ -419,7 +456,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
ast::Instruction::Rsqrt { data, arguments } => todo!(),
ast::Instruction::Selp { data, arguments } => todo!(),
ast::Instruction::Bar { data, arguments } => todo!(),
- ast::Instruction::Atom { data, arguments } => todo!(),
+ ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
ast::Instruction::AtomCas { data, arguments } => todo!(),
ast::Instruction::Div { data, arguments } => todo!(),
ast::Instruction::Neg { data, arguments } => todo!(),
@@ -499,7 +536,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
});
Ok(())
}
- ConversionKind::PtrToPtr => todo!(),
+ ConversionKind::PtrToPtr => {
+ let src = self.resolver.value(conversion.src)?;
+ let dst_type = get_pointer_type(self.context, conversion.to_space)?;
+ self.resolver.with_result(conversion.dst, |dst| unsafe {
+ LLVMBuildAddrSpaceCast(builder, src, dst_type, dst)
+ });
+ Ok(())
+ }
ConversionKind::AddressOf => todo!(),
}
}
@@ -635,6 +679,51 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
});
Ok(())
}
+
+ fn emit_atom(
+ &mut self,
+ data: ptx_parser::AtomDetails,
+ arguments: ptx_parser::AtomArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let builder = self.builder;
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ let op = match data.op {
+ ptx_parser::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd,
+ ptx_parser::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr,
+ ptx_parser::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor,
+ ptx_parser::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg,
+ ptx_parser::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd,
+ ptx_parser::AtomicOp::IncrementWrap => {
+ LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap
+ }
+ ptx_parser::AtomicOp::DecrementWrap => {
+ LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap
+ }
+ ptx_parser::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin,
+ ptx_parser::AtomicOp::UnsignedMin => {
+ LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin
+ }
+ ptx_parser::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax,
+ ptx_parser::AtomicOp::UnsignedMax => {
+ LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax
+ }
+ ptx_parser::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd,
+ ptx_parser::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin,
+ ptx_parser::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax,
+ };
+ self.resolver.register(arguments.dst, unsafe {
+ LLVMZludaBuildAtomicRMW(
+ builder,
+ op,
+ src1,
+ src2,
+ get_scope(data.scope)?,
+ get_ordering(data.semantics),
+ )
+ });
+ Ok(())
+ }
}
fn get_pointer_type<'ctx>(
@@ -644,6 +733,26 @@ fn get_pointer_type<'ctx>(
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) })
}
+// https://llvm.org/docs/AMDGPUUsage.html#memory-scopes
+fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
+ Ok(match scope {
+ ast::MemScope::Cta => c"workgroup-one-as",
+ ast::MemScope::Gpu => c"agent-one-as",
+ ast::MemScope::Sys => c"one-as",
+ ast::MemScope::Cluster => todo!(),
+ }
+ .as_ptr())
+}
+
+fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
+ match semantics {
+ ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
+ ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
+ ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingRelease,
+ ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquireRelease,
+ }
+}
+
fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, TranslateError> {
Ok(match type_ {
ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar),