diff options
author | Andrzej Janik <[email protected]> | 2024-09-26 18:54:15 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2024-09-26 18:54:15 +0200 |
commit | 820eaf8ada98ca480695ced2d93109401cce9c9f (patch) | |
tree | d14aa19e3d9979ad596bcad41e1c34006f7907fc /ptx | |
parent | c4e131519413fbb070e5df60c9c6716d894c09be (diff) | |
download | ZLUDA-820eaf8ada98ca480695ced2d93109401cce9c9f.tar.gz ZLUDA-820eaf8ada98ca480695ced2d93109401cce9c9f.zip |
Implement atomics
Diffstat (limited to 'ptx')
-rw-r--r-- | ptx/src/pass/emit_llvm.rs | 121 | ||||
-rw-r--r-- | ptx/src/pass/insert_explicit_load_store.rs | 51 |
2 files changed, 146 insertions, 26 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs
index 7f74d1a..bc5f745 100644
--- a/ptx/src/pass/emit_llvm.rs
+++ b/ptx/src/pass/emit_llvm.rs
@@ -19,15 +19,15 @@
// unsafe { LLVMBuildAdd(builder, lhs, rhs, dst) };
use std::convert::{TryFrom, TryInto};
-use std::ffi::CStr;
+use std::ffi::{CStr, NulError};
use std::ops::Deref;
use std::ptr;
use super::*;
use llvm_zluda::analysis::{LLVMVerifierFailureAction, LLVMVerifyModule};
use llvm_zluda::bit_writer::LLVMWriteBitcodeToMemoryBuffer;
-use llvm_zluda::core::*;
-use llvm_zluda::prelude::*;
+use llvm_zluda::{core::*, LLVMAtomicOrdering, LLVMAtomicRMWBinOp, LLVMZludaAtomicRMWBinOp};
+use llvm_zluda::{prelude::*, LLVMZludaBuildAtomicRMW};
use llvm_zluda::{LLVMCallConv, LLVMZludaBuildAlloca};
const LLVM_UNNAMED: &CStr = c"";
@@ -172,7 +172,7 @@ pub(super) fn run<'input>(
let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
for directive in directives {
match directive {
- Directive2::Variable(..) => todo!(),
+ Directive2::Variable(linking, variable) => emit_ctx.emit_global(linking, variable)?,
Directive2::Method(method) => emit_ctx.emit_method(method)?,
}
}
@@ -281,6 +281,43 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
}
Ok(())
}
+
+ fn emit_global(
+ &mut self,
+ linking: ast::LinkingDirective,
+ var: ptx_parser::Variable<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let name = self
+ .id_defs
+ .ident_map
+ .get(&var.name)
+ .map(|entry| {
+ entry
+ .name
+ .as_ref()
+ .map(|text| Ok::<_, NulError>(Cow::Owned(CString::new(&**text)?)))
+ })
+ .flatten()
+ .transpose()
+ .map_err(|_| error_unreachable())?
+ .unwrap_or(Cow::Borrowed(LLVM_UNNAMED));
+ let global = unsafe {
+ LLVMAddGlobalInAddressSpace(
+ self.module,
+ get_type(self.context, &var.v_type)?,
+ name.as_ptr(),
+ get_state_space(var.state_space)?,
+ )
+ };
+ self.resolver.register(var.name, global);
+ if let Some(align) = var.align {
+ unsafe { LLVMSetAlignment(global, align) };
+ }
+ if !var.array_init.is_empty() {
+ todo!()
+ }
+ Ok(())
+ }
}
fn get_input_argument_type(
@@ -419,7 +456,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
ast::Instruction::Rsqrt { data, arguments } => todo!(),
ast::Instruction::Selp { data, arguments } => todo!(),
ast::Instruction::Bar { data, arguments } => todo!(),
- ast::Instruction::Atom { data, arguments } => todo!(),
+ ast::Instruction::Atom { data, arguments } => self.emit_atom(data, arguments),
ast::Instruction::AtomCas { data, arguments } => todo!(),
ast::Instruction::Div { data, arguments } => todo!(),
ast::Instruction::Neg { data, arguments } => todo!(),
@@ -499,7 +536,14 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
});
Ok(())
}
- ConversionKind::PtrToPtr => todo!(),
+ ConversionKind::PtrToPtr => {
+ let src = self.resolver.value(conversion.src)?;
+ let dst_type = get_pointer_type(self.context, conversion.to_space)?;
+ self.resolver.with_result(conversion.dst, |dst| unsafe {
+ LLVMBuildAddrSpaceCast(builder, src, dst_type, dst)
+ });
+ Ok(())
+ }
ConversionKind::AddressOf => todo!(),
}
}
@@ -635,6 +679,51 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
});
Ok(())
}
+
+ fn emit_atom(
+ &mut self,
+ data: ptx_parser::AtomDetails,
+ arguments: ptx_parser::AtomArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let builder = self.builder;
+ let src1 = self.resolver.value(arguments.src1)?;
+ let src2 = self.resolver.value(arguments.src2)?;
+ let op = match data.op {
+ ptx_parser::AtomicOp::And => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAnd,
+ ptx_parser::AtomicOp::Or => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpOr,
+ ptx_parser::AtomicOp::Xor => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXor,
+ ptx_parser::AtomicOp::Exchange => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpXchg,
+ ptx_parser::AtomicOp::Add => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpAdd,
+ ptx_parser::AtomicOp::IncrementWrap => {
+ LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUIncWrap
+ }
+ ptx_parser::AtomicOp::DecrementWrap => {
+ LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUDecWrap
+ }
+ ptx_parser::AtomicOp::SignedMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMin,
+ ptx_parser::AtomicOp::UnsignedMin => {
+ LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMin
+ }
+ ptx_parser::AtomicOp::SignedMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpMax,
+ ptx_parser::AtomicOp::UnsignedMax => {
+ LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpUMax
+ }
+ ptx_parser::AtomicOp::FloatAdd => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFAdd,
+ ptx_parser::AtomicOp::FloatMin => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMin,
+ ptx_parser::AtomicOp::FloatMax => LLVMZludaAtomicRMWBinOp::LLVMZludaAtomicRMWBinOpFMax,
+ };
+ self.resolver.register(arguments.dst, unsafe {
+ LLVMZludaBuildAtomicRMW(
+ builder,
+ op,
+ src1,
+ src2,
+ get_scope(data.scope)?,
+ get_ordering(data.semantics),
+ )
+ });
+ Ok(())
+ }
}
fn get_pointer_type<'ctx>(
@@ -644,6 +733,26 @@ fn get_pointer_type<'ctx>(
Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(to_space)?) })
}
+// https://llvm.org/docs/AMDGPUUsage.html#memory-scopes
+fn get_scope(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
+ Ok(match scope {
+ ast::MemScope::Cta => c"workgroup-one-as",
+ ast::MemScope::Gpu => c"agent-one-as",
+ ast::MemScope::Sys => c"one-as",
+ ast::MemScope::Cluster => todo!(),
+ }
+ .as_ptr())
+}
+
+fn get_ordering(semantics: ast::AtomSemantics) -> LLVMAtomicOrdering {
+ match semantics {
+ ast::AtomSemantics::Relaxed => LLVMAtomicOrdering::LLVMAtomicOrderingMonotonic,
+ ast::AtomSemantics::Acquire => LLVMAtomicOrdering::LLVMAtomicOrderingAcquire,
+ ast::AtomSemantics::Release => LLVMAtomicOrdering::LLVMAtomicOrderingRelease,
+ ast::AtomSemantics::AcqRel => LLVMAtomicOrdering::LLVMAtomicOrderingAcquireRelease,
+ }
+}
+
fn get_type(context: LLVMContextRef, type_: &ast::Type) -> Result<LLVMTypeRef, TranslateError> {
Ok(match type_ {
ast::Type::Scalar(scalar) => get_scalar_type(context, *scalar),
diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs
index ec6498c..42988ea 100644
--- a/ptx/src/pass/insert_explicit_load_store.rs
+++ b/ptx/src/pass/insert_explicit_load_store.rs
@@ -52,7 +52,7 @@ fn run_method<'a, 'input>(
let new_name = visitor
.resolver
.register_unnamed(Some((arg.v_type.clone(), new_space)));
- visitor.input_argument(old_name, new_name, old_space);
+ visitor.input_argument(old_name, new_name, old_space)?;
arg.name = new_name;
arg.state_space = new_space;
}
@@ -154,7 +154,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
old_name: SpirvWord,
new_name: SpirvWord,
old_space: ast::StateSpace,
- ) -> Result<(), TranslateError> {
+ ) -> Result<bool, TranslateError> {
Ok(match old_space {
ast::StateSpace::Reg => {
self.variables.insert(
@@ -164,6 +164,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
type_: type_.clone(),
},
);
+ true
}
ast::StateSpace::Param => {
self.variables.insert(
@@ -174,19 +175,18 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
name: new_name,
},
);
+ true
}
// Good as-is
- ast::StateSpace::Local => {}
- // Will be pulled into global scope later
- ast::StateSpace::Generic
+ ast::StateSpace::Local
+ | ast::StateSpace::Generic
| ast::StateSpace::SharedCluster
| ast::StateSpace::Global
| ast::StateSpace::Const
| ast::StateSpace::SharedCta
- | ast::StateSpace::Shared => {}
- ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => {
- return Err(error_unreachable())
- }
+ | ast::StateSpace::Shared
+ | ast::StateSpace::ParamEntry
+ | ast::StateSpace::ParamFunc => return Err(error_unreachable()),
})
}
@@ -239,17 +239,28 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
}
fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) -> Result<(), TranslateError> {
- if var.state_space != ast::StateSpace::Local {
- let old_name = var.name;
- let old_space = var.state_space;
- let new_space = ast::StateSpace::Local;
- let new_name = self
- .resolver
- .register_unnamed(Some((var.v_type.clone(), new_space)));
- self.variable(&var.v_type, old_name, new_name, old_space)?;
- var.name = new_name;
- var.state_space = new_space;
- }
+ let old_space = match var.state_space {
+ space @ (ptx_parser::StateSpace::Reg | ptx_parser::StateSpace::Param) => space,
+ // Do nothing
+ ptx_parser::StateSpace::Local => return Ok(()),
+ // Handled by another pass
+ ptx_parser::StateSpace::Generic
+ | ptx_parser::StateSpace::SharedCluster
+ | ptx_parser::StateSpace::ParamEntry
+ | ptx_parser::StateSpace::Global
+ | ptx_parser::StateSpace::SharedCta
+ | ptx_parser::StateSpace::Const
+ | ptx_parser::StateSpace::Shared
+ | ptx_parser::StateSpace::ParamFunc => return Ok(()),
+ };
+ let old_name = var.name;
+ let new_space = ast::StateSpace::Local;
+ let new_name = self
+ .resolver
+ .register_unnamed(Some((var.v_type.clone(), new_space)));
+ self.variable(&var.v_type, old_name, new_name, old_space)?;
+ var.name = new_name;
+ var.state_space = new_space;
Ok(())
}
}
|