diff options
author | Andrzej Janik <[email protected]> | 2020-10-31 21:28:15 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-10-31 21:28:15 +0100 |
commit | a82eb2081717c1fb48e140176fec0e5b5974a432 (patch) | |
tree | b5ca6934333d1707ed43a1e21a8f02f630929dc4 | |
parent | 861116f223081528cf1e32f5e1eddb733ac00241 (diff) | |
download | ZLUDA-a82eb2081717c1fb48e140176fec0e5b5974a432.tar.gz ZLUDA-a82eb2081717c1fb48e140176fec0e5b5974a432.zip |
Implement atomic instructions
24 files changed, 1672 insertions, 88 deletions
diff --git a/level_zero/Cargo.toml b/level_zero/Cargo.toml index 97537b3..851159d 100644 --- a/level_zero/Cargo.toml +++ b/level_zero/Cargo.toml @@ -7,4 +7,8 @@ edition = "2018" [lib] [dependencies] -level_zero-sys = { path = "../level_zero-sys" }
\ No newline at end of file +level_zero-sys = { path = "../level_zero-sys" } + +[dependencies.ocl-core] +version = "0.11" +features = ["opencl_version_1_2", "opencl_version_2_0", "opencl_version_2_1"]
\ No newline at end of file diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index 5ced5d0..f8a2c3b 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -238,7 +238,76 @@ impl Drop for CommandQueue { pub struct Module(sys::ze_module_handle_t);
impl Module {
- pub fn new_spirv(
+ // HACK ALERT
+ // We use OpenCL for now to do SPIR-V linking, because Level Zero
+ // does not allow linking. Don't let the presence of zeModuleDynamicLink fool
+ // you: it's not currently possible to create non-compiled modules.
+ // zeModuleCreate always compiles (builds and links).
+ pub fn build_link_spirv<'a>(
+ ctx: &mut Context,
+ d: &Device,
+ binaries: &[&'a [u8]],
+ ) -> (Result<Self>, Option<BuildLog>) {
+ let ocl_program = match Self::build_link_spirv_impl(binaries) {
+ Err(_) => return (Err(sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN), None),
+ Ok(prog) => prog,
+ };
+ match ocl_core::get_program_info(&ocl_program, ocl_core::ProgramInfo::Binaries) {
+ Ok(ocl_core::ProgramInfoResult::Binaries(binaries)) => {
+ let (module, build_log) = Self::build_native(ctx, d, &binaries[0]);
+ (module, Some(build_log))
+ }
+ _ => return (Err(sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN), None),
+ }
+ }
+
+ fn build_link_spirv_impl<'a>(binaries: &[&'a [u8]]) -> ocl_core::Result<ocl_core::Program> {
+ let platforms = ocl_core::get_platform_ids()?;
+ let (platform, device) = platforms
+ .iter()
+ .find_map(|plat| {
+ let devices =
+ ocl_core::get_device_ids(plat, Some(ocl_core::DeviceType::GPU), None).ok()?;
+ for dev in devices {
+ let vendor =
+ ocl_core::get_device_info(dev, ocl_core::DeviceInfo::VendorId).ok()?;
+ if let ocl_core::DeviceInfoResult::VendorId(0x8086) = vendor {
+ let dev_type =
+ ocl_core::get_device_info(dev, ocl_core::DeviceInfo::Type).ok()?;
+ if let ocl_core::DeviceInfoResult::Type(ocl_core::DeviceType::GPU) =
+ dev_type
+ {
+ return Some((plat.clone(), dev));
+ }
+ }
+ }
+ None
+ })
+ .ok_or("")?;
+ let ctx_props = ocl_core::ContextProperties::new().platform(platform);
+ let ocl_ctx = ocl_core::create_context_from_type::<ocl_core::DeviceId>(
+ Some(&ctx_props),
+ ocl_core::DeviceType::GPU,
+ None,
+ None,
+ )?;
+ let mut programs = Vec::with_capacity(binaries.len());
+ for binary in binaries {
+ programs.push(ocl_core::create_program_with_il(&ocl_ctx, binary, None)?);
+ }
+ let options = CString::default();
+ ocl_core::link_program::<ocl_core::DeviceId, _>(
+ &ocl_ctx,
+ Some(&[device]),
+ &options,
+ &programs.iter().collect::<Vec<_>>(),
+ None,
+ None,
+ None,
+ )
+ }
+
+ pub fn build_spirv(
ctx: &mut Context,
d: &Device,
bin: &[u8],
@@ -247,7 +316,7 @@ impl Module { Module::new(ctx, true, d, bin, opts)
}
- pub fn new_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> (Result<Self>, BuildLog) {
+ pub fn build_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> (Result<Self>, BuildLog) {
Module::new(ctx, false, d, bin, None)
}
diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs index eea862b..35436c3 100644 --- a/notcuda/src/impl/module.rs +++ b/notcuda/src/impl/module.rs @@ -53,7 +53,7 @@ impl ModuleData { Ok(_) if errors.len() > 0 => return Err(ModuleCompileError::Parse(errors, None)), Ok(ast) => ast, }; - let (spirv, all_arg_lens) = ptx::to_spirv(ast)?; + let (_, spirv, all_arg_lens) = ptx::to_spirv(ast)?; let byte_il = unsafe { slice::from_raw_parts::<u8>( spirv.as_ptr() as *const _, @@ -61,7 +61,7 @@ impl ModuleData { ) }; let module = super::device::with_current_exclusive(|dev| { - l0::Module::new_spirv(&mut dev.l0_context, &dev.base, byte_il, None) + l0::Module::build_spirv(&mut dev.l0_context, &dev.base, byte_il, None) }); match module { Ok((Ok(module), _)) => Ok(Mutex::new(Self { diff --git a/ptx/lib/notcuda_ptx_impl.cl b/ptx/lib/notcuda_ptx_impl.cl new file mode 100644 index 0000000..a0d487b --- /dev/null +++ b/ptx/lib/notcuda_ptx_impl.cl @@ -0,0 +1,121 @@ +// Every time this file changes it must be rebuilt:
+// ocloc -file notcuda_ptx_impl.cl -64 -options "-cl-std=CL2.0" -out_dir . -device kbl -output_no_suffix -spv_only
+// Additionally you should strip names:
+// spirv-opt --strip-debug notcuda_ptx_impl.spv -o notcuda_ptx_impl.spv
+
+#define FUNC(NAME) __notcuda_ptx_impl__ ## NAME
+
+#define atomic_inc(NAME, SUCCESS, FAILURE, SCOPE, SPACE) \
+ uint FUNC(NAME)(SPACE uint* ptr, uint threshold) { \
+ uint expected = *ptr; \
+ uint desired; \
+ do { \
+ desired = (expected >= threshold) ? 0 : expected + 1; \
+ } while (!atomic_compare_exchange_strong_explicit((volatile SPACE atomic_uint*)ptr, &expected, desired, SUCCESS, FAILURE, SCOPE)); \
+ return expected; \
+ }
+
+#define atomic_dec(NAME, SUCCESS, FAILURE, SCOPE, SPACE) \
+ uint FUNC(NAME)(SPACE uint* ptr, uint threshold) { \
+ uint expected = *ptr; \
+ uint desired; \
+ do { \
+ desired = (expected == 0 || expected > threshold) ? threshold : expected - 1; \
+ } while (!atomic_compare_exchange_strong_explicit((volatile SPACE atomic_uint*)ptr, &expected, desired, SUCCESS, FAILURE, SCOPE)); \
+ return expected; \
+ }
+
+// We are doing all this mess instead of accepting memory_order and memory_scope parameters
+// because ocloc emits broken (failing spirv-dis) SPIR-V when memory_order or memory_scope is a parameter
+
+// atom.inc
+atomic_inc(atom_relaxed_cta_generic_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, );
+atomic_inc(atom_acquire_cta_generic_inc, memory_order_acquire, memory_order_acquire, memory_scope_work_group, );
+atomic_inc(atom_release_cta_generic_inc, memory_order_release, memory_order_acquire, memory_scope_work_group, );
+atomic_inc(atom_acq_rel_cta_generic_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, );
+
+atomic_inc(atom_relaxed_gpu_generic_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, );
+atomic_inc(atom_acquire_gpu_generic_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, );
+atomic_inc(atom_release_gpu_generic_inc, memory_order_release, memory_order_acquire, memory_scope_device, );
+atomic_inc(atom_acq_rel_gpu_generic_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, );
+
+atomic_inc(atom_relaxed_sys_generic_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, );
+atomic_inc(atom_acquire_sys_generic_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, );
+atomic_inc(atom_release_sys_generic_inc, memory_order_release, memory_order_acquire, memory_scope_device, );
+atomic_inc(atom_acq_rel_sys_generic_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, );
+
+atomic_inc(atom_relaxed_cta_global_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __global);
+atomic_inc(atom_acquire_cta_global_inc, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __global);
+atomic_inc(atom_release_cta_global_inc, memory_order_release, memory_order_acquire, memory_scope_work_group, __global);
+atomic_inc(atom_acq_rel_cta_global_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __global);
+
+atomic_inc(atom_relaxed_gpu_global_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global);
+atomic_inc(atom_acquire_gpu_global_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __global);
+atomic_inc(atom_release_gpu_global_inc, memory_order_release, memory_order_acquire, memory_scope_device, __global);
+atomic_inc(atom_acq_rel_gpu_global_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global);
+
+atomic_inc(atom_relaxed_sys_global_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global);
+atomic_inc(atom_acquire_sys_global_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __global);
+atomic_inc(atom_release_sys_global_inc, memory_order_release, memory_order_acquire, memory_scope_device, __global);
+atomic_inc(atom_acq_rel_sys_global_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global);
+
+atomic_inc(atom_relaxed_cta_shared_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __local);
+atomic_inc(atom_acquire_cta_shared_inc, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __local);
+atomic_inc(atom_release_cta_shared_inc, memory_order_release, memory_order_acquire, memory_scope_work_group, __local);
+atomic_inc(atom_acq_rel_cta_shared_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __local);
+
+atomic_inc(atom_relaxed_gpu_shared_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local);
+atomic_inc(atom_acquire_gpu_shared_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __local);
+atomic_inc(atom_release_gpu_shared_inc, memory_order_release, memory_order_acquire, memory_scope_device, __local);
+atomic_inc(atom_acq_rel_gpu_shared_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local);
+
+atomic_inc(atom_relaxed_sys_shared_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local);
+atomic_inc(atom_acquire_sys_shared_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __local);
+atomic_inc(atom_release_sys_shared_inc, memory_order_release, memory_order_acquire, memory_scope_device, __local);
+atomic_inc(atom_acq_rel_sys_shared_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local);
+
+// atom.dec
+atomic_dec(atom_relaxed_cta_generic_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, );
+atomic_dec(atom_acquire_cta_generic_dec, memory_order_acquire, memory_order_acquire, memory_scope_work_group, );
+atomic_dec(atom_release_cta_generic_dec, memory_order_release, memory_order_acquire, memory_scope_work_group, );
+atomic_dec(atom_acq_rel_cta_generic_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, );
+
+atomic_dec(atom_relaxed_gpu_generic_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, );
+atomic_dec(atom_acquire_gpu_generic_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, );
+atomic_dec(atom_release_gpu_generic_dec, memory_order_release, memory_order_acquire, memory_scope_device, );
+atomic_dec(atom_acq_rel_gpu_generic_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, );
+
+atomic_dec(atom_relaxed_sys_generic_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, );
+atomic_dec(atom_acquire_sys_generic_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, );
+atomic_dec(atom_release_sys_generic_dec, memory_order_release, memory_order_acquire, memory_scope_device, );
+atomic_dec(atom_acq_rel_sys_generic_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, );
+
+atomic_dec(atom_relaxed_cta_global_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __global);
+atomic_dec(atom_acquire_cta_global_dec, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __global);
+atomic_dec(atom_release_cta_global_dec, memory_order_release, memory_order_acquire, memory_scope_work_group, __global);
+atomic_dec(atom_acq_rel_cta_global_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __global);
+
+atomic_dec(atom_relaxed_gpu_global_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global);
+atomic_dec(atom_acquire_gpu_global_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __global);
+atomic_dec(atom_release_gpu_global_dec, memory_order_release, memory_order_acquire, memory_scope_device, __global);
+atomic_dec(atom_acq_rel_gpu_global_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global);
+
+atomic_dec(atom_relaxed_sys_global_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global);
+atomic_dec(atom_acquire_sys_global_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __global);
+atomic_dec(atom_release_sys_global_dec, memory_order_release, memory_order_acquire, memory_scope_device, __global);
+atomic_dec(atom_acq_rel_sys_global_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global);
+
+atomic_dec(atom_relaxed_cta_shared_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __local);
+atomic_dec(atom_acquire_cta_shared_dec, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __local);
+atomic_dec(atom_release_cta_shared_dec, memory_order_release, memory_order_acquire, memory_scope_work_group, __local);
+atomic_dec(atom_acq_rel_cta_shared_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __local);
+
+atomic_dec(atom_relaxed_gpu_shared_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local);
+atomic_dec(atom_acquire_gpu_shared_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __local);
+atomic_dec(atom_release_gpu_shared_dec, memory_order_release, memory_order_acquire, memory_scope_device, __local);
+atomic_dec(atom_acq_rel_gpu_shared_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local);
+
+atomic_dec(atom_relaxed_sys_shared_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local);
+atomic_dec(atom_acquire_sys_shared_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __local);
+atomic_dec(atom_release_sys_shared_dec, memory_order_release, memory_order_acquire, memory_scope_device, __local);
+atomic_dec(atom_acq_rel_sys_shared_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local);
diff --git a/ptx/lib/notcuda_ptx_impl.spv b/ptx/lib/notcuda_ptx_impl.spv Binary files differnew file mode 100644 index 0000000..36f37bb --- /dev/null +++ b/ptx/lib/notcuda_ptx_impl.spv diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 1266ea4..ad8e87d 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -109,11 +109,12 @@ macro_rules! sub_type { }; } -// Pointer is used when doing SLM converison to SPIRV sub_type! { VariableRegType { Scalar(ScalarType), Vector(SizedScalarType, u8), + // Pointer variant is used when passing around SLM pointer between + // function calls for dynamic SLM Pointer(SizedScalarType, PointerStateSpace) } } @@ -215,6 +216,11 @@ sub_enum!(SelpType { F64, }); +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum BarDetails { + SyncAligned, +} + pub trait UnwrapWithVec<E, To> { fn unwrap_with(self, errs: &mut Vec<E>) -> To; } @@ -301,6 +307,7 @@ impl From<FnArgumentType> for Type { sub_enum!( PointerStateSpace : LdStateSpace { + Generic, Global, Const, Shared, @@ -372,6 +379,8 @@ sub_enum!(IntType { S64 }); +sub_enum!(BitType { B8, B16, B32, B64 }); + sub_enum!(UIntType { U8, U16, U32, U64 }); sub_enum!(SIntType { S8, S16, S32, S64 }); @@ -527,6 +536,9 @@ pub enum Instruction<P: ArgParams> { Rcp(RcpDetails, Arg2<P>), And(OrAndType, Arg3<P>), Selp(SelpType, Arg4<P>), + Bar(BarDetails, Arg1Bar<P>), + Atom(AtomDetails, Arg3<P>), + AtomCas(AtomCasDetails, Arg4<P>), } #[derive(Copy, Clone)] @@ -577,6 +589,10 @@ pub struct Arg1<P: ArgParams> { pub src: P::Id, // it is a jump destination, but in terms of operands it is a source operand } +pub struct Arg1Bar<P: ArgParams> { + pub src: P::Operand, +} + pub struct Arg2<P: ArgParams> { pub dst: P::Id, pub src: P::Operand, @@ -712,12 +728,12 @@ impl From<LdStType> for PointerType { pub enum LdStQualifier { Weak, Volatile, - Relaxed(LdScope), - Acquire(LdScope), + Relaxed(MemScope), + Acquire(MemScope), } #[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdScope { +pub enum MemScope { Cta, Gpu, Sys, @@ -1051,6 
+1067,74 @@ pub struct MinMaxFloat { pub typ: FloatType, } +#[derive(Copy, Clone)] +pub struct AtomDetails { + pub semantics: AtomSemantics, + pub scope: MemScope, + pub space: AtomSpace, + pub inner: AtomInnerDetails, +} + +#[derive(Copy, Clone)] +pub enum AtomSemantics { + Relaxed, + Acquire, + Release, + AcquireRelease, +} + +#[derive(Copy, Clone)] +pub enum AtomSpace { + Generic, + Global, + Shared, +} + +#[derive(Copy, Clone)] +pub enum AtomInnerDetails { + Bit { op: AtomBitOp, typ: BitType }, + Unsigned { op: AtomUIntOp, typ: UIntType }, + Signed { op: AtomSIntOp, typ: SIntType }, + Float { op: AtomFloatOp, typ: FloatType }, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum AtomBitOp { + And, + Or, + Xor, + Exchange, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum AtomUIntOp { + Add, + Inc, + Dec, + Min, + Max, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum AtomSIntOp { + Add, + Min, + Max, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum AtomFloatOp { + Add, +} + +#[derive(Copy, Clone)] +pub struct AtomCasDetails { + pub semantics: AtomSemantics, + pub scope: MemScope, + pub space: AtomSpace, + pub typ: BitType +} + pub enum NumsOrArrays<'a> { Nums(Vec<(&'a str, u32)>), Arrays(Vec<NumsOrArrays<'a>>), diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index dfe5a5f..806a3fc 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -35,9 +35,12 @@ match { "<", ">", "|", "=", + ".acq_rel", ".acquire", + ".add", ".address_size", ".align", + ".aligned", ".and", ".approx", ".b16", @@ -45,14 +48,17 @@ match { ".b64", ".b8", ".ca", + ".cas", ".cg", ".const", ".cs", ".cta", ".cv", + ".dec", ".entry", ".eq", ".equ", + ".exch", ".extern", ".f16", ".f16x2", @@ -69,6 +75,7 @@ match { ".gtu", ".hi", ".hs", + ".inc", ".le", ".leu", ".lo", @@ -78,6 +85,8 @@ match { ".lt", ".ltu", ".lu", + ".max", + ".min", ".nan", ".NaN", ".ne", @@ -88,6 +97,7 @@ match { ".pred", ".reg", ".relaxed", + ".release", ".rm", ".rmi", ".rn", @@ -103,6 +113,7 @@ match 
{ ".sat", ".section", ".shared", + ".sync", ".sys", ".target", ".to", @@ -126,6 +137,9 @@ match { "abs", "add", "and", + "atom", + "bar", + "barrier", "bra", "call", "cvt", @@ -162,6 +176,9 @@ ExtendedID : &'input str = { "abs", "add", "and", + "atom", + "bar", + "barrier", "bra", "call", "cvt", @@ -372,6 +389,7 @@ StateSpaceSpecifier: ast::StateSpace = { ".param" => ast::StateSpace::Param, // used to prepare function call }; +#[inline] ScalarType: ast::ScalarType = { ".f16" => ast::ScalarType::F16, ".f16x2" => ast::ScalarType::F16x2, @@ -438,6 +456,7 @@ Variable: ast::Variable<ast::VariableType, &'input str> = { let v_type = ast::VariableType::Param(v_type); ast::Variable {align, v_type, name, array_init} }, + SharedVariable, }; RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = { @@ -478,6 +497,32 @@ LocalVariable: ast::Variable<ast::VariableType, &'input str> = { } } +SharedVariable: ast::Variable<ast::VariableType, &'input str> = { + ".shared" <var:VariableScalar<SizedScalarType>> => { + let (align, t, name) = var; + let v_type = ast::VariableGlobalType::Scalar(t); + ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + }, + ".shared" <var:VariableVector<SizedScalarType>> => { + let (align, v_len, t, name) = var; + let v_type = ast::VariableGlobalType::Vector(t, v_len); + ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + }, + ".shared" <var:VariableArrayOrPointer<SizedScalarType>> =>? 
{ + let (align, t, name, arr_or_ptr) = var; + let (v_type, array_init) = match arr_or_ptr { + ast::ArrayOrPointer::Array { dimensions, init } => { + (ast::VariableGlobalType::Array(t, dimensions), init) + } + ast::ArrayOrPointer::Pointer => { + return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); + } + }; + Ok(ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init }) + } +} + + ModuleVariable: ast::Variable<ast::VariableType, &'input str> = { LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => { let (align, v_type, name, array_init) = def; @@ -619,7 +664,10 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { InstMin, InstMax, InstRcp, - InstSelp + InstSelp, + InstBar, + InstAtom, + InstAtomCas }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -655,14 +703,14 @@ LdStType: ast::LdStType = { LdStQualifier: ast::LdStQualifier = { ".weak" => ast::LdStQualifier::Weak, ".volatile" => ast::LdStQualifier::Volatile, - ".relaxed" <s:LdScope> => ast::LdStQualifier::Relaxed(s), - ".acquire" <s:LdScope> => ast::LdStQualifier::Acquire(s), + ".relaxed" <s:MemScope> => ast::LdStQualifier::Relaxed(s), + ".acquire" <s:MemScope> => ast::LdStQualifier::Acquire(s), }; -LdScope: ast::LdScope = { - ".cta" => ast::LdScope::Cta, - ".gpu" => ast::LdScope::Gpu, - ".sys" => ast::LdScope::Sys +MemScope: ast::MemScope = { + ".cta" => ast::MemScope::Cta, + ".gpu" => ast::MemScope::Gpu, + ".sys" => ast::MemScope::Sys }; LdStateSpace: ast::LdStateSpace = { @@ -798,6 +846,13 @@ SIntType: ast::SIntType = { ".s64" => ast::SIntType::S64, }; +FloatType: ast::FloatType = { + ".f16" => ast::FloatType::F16, + ".f16x2" => ast::FloatType::F16x2, + ".f32" => ast::FloatType::F32, + ".f64" => ast::FloatType::F64, +}; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add // 
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add @@ -1296,6 +1351,140 @@ SelpType: ast::SelpType = { ".f64" => ast::SelpType::F64, }; +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar +InstBar: ast::Instruction<ast::ParsedArgParams<'input>> = { + "barrier" ".sync" ".aligned" <a:Arg1Bar> => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a), + "bar" ".sync" <a:Arg1Bar> => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a) +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom +// The documentation does not mention all spported operations: +// * Operation .add requires .u32 or .s32 or .u64 or .f64 or f16 or f16x2 or .f32 +// * Operation .inc requires .u32 type for instuction +// * Operation .dec requires .u32 type for instuction +// Otherwise as documented +InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = { + "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> <op:AtomBitOp> <typ:AtomBitType> <a:Arg3Atom> => { + let details = ast::AtomDetails { + semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), + scope: scope.unwrap_or(ast::MemScope::Gpu), + space: space.unwrap_or(ast::AtomSpace::Generic), + inner: ast::AtomInnerDetails::Bit { op, typ } + }; + ast::Instruction::Atom(details,a) + }, + "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> ".inc" ".u32" <a:Arg3Atom> => { + let details = ast::AtomDetails { + semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), + scope: scope.unwrap_or(ast::MemScope::Gpu), + space: space.unwrap_or(ast::AtomSpace::Generic), + inner: ast::AtomInnerDetails::Unsigned { + op: ast::AtomUIntOp::Inc, + typ: ast::UIntType::U32 + } + }; + ast::Instruction::Atom(details,a) + 
}, + "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> ".dec" ".u32" <a:Arg3Atom> => { + let details = ast::AtomDetails { + semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), + scope: scope.unwrap_or(ast::MemScope::Gpu), + space: space.unwrap_or(ast::AtomSpace::Generic), + inner: ast::AtomInnerDetails::Unsigned { + op: ast::AtomUIntOp::Dec, + typ: ast::UIntType::U32 + } + }; + ast::Instruction::Atom(details,a) + }, + "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> ".add" <typ:FloatType> <a:Arg3Atom> => { + let op = ast::AtomFloatOp::Add; + let details = ast::AtomDetails { + semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), + scope: scope.unwrap_or(ast::MemScope::Gpu), + space: space.unwrap_or(ast::AtomSpace::Generic), + inner: ast::AtomInnerDetails::Float { op, typ } + }; + ast::Instruction::Atom(details,a) + }, + "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> <op: AtomUIntOp> <typ:AtomUIntType> <a:Arg3Atom> => { + let details = ast::AtomDetails { + semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), + scope: scope.unwrap_or(ast::MemScope::Gpu), + space: space.unwrap_or(ast::AtomSpace::Generic), + inner: ast::AtomInnerDetails::Unsigned { op, typ } + }; + ast::Instruction::Atom(details,a) + }, + "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> <op: AtomSIntOp> <typ:AtomSIntType> <a:Arg3Atom> => { + let details = ast::AtomDetails { + semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), + scope: scope.unwrap_or(ast::MemScope::Gpu), + space: space.unwrap_or(ast::AtomSpace::Generic), + inner: ast::AtomInnerDetails::Signed { op, typ } + }; + ast::Instruction::Atom(details,a) + } +} + +InstAtomCas: ast::Instruction<ast::ParsedArgParams<'input>> = { + "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> ".cas" <typ:AtomBitType> <a:Arg4Atom> => { + let details = ast::AtomCasDetails { + semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), + scope: 
scope.unwrap_or(ast::MemScope::Gpu), + space: space.unwrap_or(ast::AtomSpace::Generic), + typ, + }; + ast::Instruction::AtomCas(details,a) + }, +} + +AtomSemantics: ast::AtomSemantics = { + ".relaxed" => ast::AtomSemantics::Relaxed, + ".acquire" => ast::AtomSemantics::Acquire, + ".release" => ast::AtomSemantics::Release, + ".acq_rel" => ast::AtomSemantics::AcquireRelease +} + +AtomSpace: ast::AtomSpace = { + ".global" => ast::AtomSpace::Global, + ".shared" => ast::AtomSpace::Shared +} + +AtomBitOp: ast::AtomBitOp = { + ".and" => ast::AtomBitOp::And, + ".or" => ast::AtomBitOp::Or, + ".xor" => ast::AtomBitOp::Xor, + ".exch" => ast::AtomBitOp::Exchange, +} + +AtomUIntOp: ast::AtomUIntOp = { + ".add" => ast::AtomUIntOp::Add, + ".min" => ast::AtomUIntOp::Min, + ".max" => ast::AtomUIntOp::Max, +} + +AtomSIntOp: ast::AtomSIntOp = { + ".add" => ast::AtomSIntOp::Add, + ".min" => ast::AtomSIntOp::Min, + ".max" => ast::AtomSIntOp::Max, +} + +AtomBitType: ast::BitType = { + ".b32" => ast::BitType::B32, + ".b64" => ast::BitType::B64, +} + +AtomUIntType: ast::UIntType = { + ".u32" => ast::UIntType::U32, + ".u64" => ast::UIntType::U64, +} + +AtomSIntType: ast::SIntType = { + ".s32" => ast::SIntType::S32, + ".s64" => ast::SIntType::S64, +} + ArithDetails: ast::ArithDetails = { <t:UIntType> => ast::ArithDetails::Unsigned(t), <t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt { @@ -1414,6 +1603,10 @@ Arg1: ast::Arg1<ast::ParsedArgParams<'input>> = { <src:ExtendedID> => ast::Arg1{<>} }; +Arg1Bar: ast::Arg1Bar<ast::ParsedArgParams<'input>> = { + <src:Operand> => ast::Arg1Bar{<>} +}; + Arg2: ast::Arg2<ast::ParsedArgParams<'input>> = { <dst:ExtendedID> "," <src:Operand> => ast::Arg2{<>} }; @@ -1448,10 +1641,18 @@ Arg3: ast::Arg3<ast::ParsedArgParams<'input>> = { <dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> => ast::Arg3{<>} }; +Arg3Atom: ast::Arg3<ast::ParsedArgParams<'input>> = { + <dst:ExtendedID> "," "[" <src1:Operand> "]" "," <src2:Operand> => ast::Arg3{<>} +}; + 
Arg4: ast::Arg4<ast::ParsedArgParams<'input>> = { <dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> "," <src3:Operand> => ast::Arg4{<>} }; +Arg4Atom: ast::Arg4<ast::ParsedArgParams<'input>> = { + <dst:ExtendedID> "," "[" <src1:Operand> "]" "," <src2:Operand> "," <src3:Operand> => ast::Arg4{<>} +}; + Arg4Setp: ast::Arg4Setp<ast::ParsedArgParams<'input>> = { <dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> => ast::Arg4Setp{<>} }; diff --git a/ptx/src/test/spirv_build/bar_sync.ptx b/ptx/src/test/spirv_build/bar_sync.ptx new file mode 100644 index 0000000..54c6663 --- /dev/null +++ b/ptx/src/test/spirv_build/bar_sync.ptx @@ -0,0 +1,10 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry bar_sync()
+{
+ .reg .u32 temp_32;
+ bar.sync temp_32;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/and.spvtxt b/ptx/src/test/spirv_run/and.spvtxt index 9b72477..8358c28 100644 --- a/ptx/src/test/spirv_run/and.spvtxt +++ b/ptx/src/test/spirv_run/and.spvtxt @@ -11,8 +11,8 @@ OpCapability Int16 OpCapability Int64 OpCapability Float16 OpCapability Float64 -OpCapability FunctionFloatControlINTEL -OpExtension "SPV_INTEL_float_controls2" +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" %33 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "and" diff --git a/ptx/src/test/spirv_run/atom_add.ptx b/ptx/src/test/spirv_run/atom_add.ptx new file mode 100644 index 0000000..5d1f667 --- /dev/null +++ b/ptx/src/test/spirv_run/atom_add.ptx @@ -0,0 +1,28 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry atom_add(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .shared .align 4 .b8 shared_mem[1024];
+
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u32 temp1;
+ .reg .u32 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u32 temp1, [in_addr];
+ ld.u32 temp2, [in_addr+4];
+ st.shared.u32 [shared_mem], temp1;
+ atom.shared.add.u32 temp1, [shared_mem], temp2;
+ ld.shared.u32 temp2, [shared_mem];
+ st.u32 [out_addr], temp1;
+ st.u32 [out_addr+4], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/atom_add.spvtxt b/ptx/src/test/spirv_run/atom_add.spvtxt new file mode 100644 index 0000000..2c83fe9 --- /dev/null +++ b/ptx/src/test/spirv_run/atom_add.spvtxt @@ -0,0 +1,84 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 55 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" +%40 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "atom_add" %4 +OpDecorate %4 Alignment 4 +%41 = OpTypeVoid +%42 = OpTypeInt 32 0 +%43 = OpTypeInt 8 0 +%44 = OpConstant %42 1024 +%45 = OpTypeArray %43 %44 +%46 = OpTypePointer Workgroup %45 +%4 = OpVariable %46 Workgroup +%47 = OpTypeInt 64 0 +%48 = OpTypeFunction %41 %47 %47 +%49 = OpTypePointer Function %47 +%50 = OpTypePointer Function %42 +%51 = OpTypePointer Generic %42 +%27 = OpConstant %47 4 +%52 = OpTypePointer Workgroup %42 +%53 = OpConstant %42 1 +%54 = OpConstant %42 0 +%29 = OpConstant %47 4 +%1 = OpFunction %41 None %48 +%9 = OpFunctionParameter %47 +%10 = OpFunctionParameter %47 +%38 = OpLabel +%2 = OpVariable %49 Function +%3 = OpVariable %49 Function +%5 = OpVariable %49 Function +%6 = OpVariable %49 Function +%7 = OpVariable %50 Function +%8 = OpVariable %50 Function +OpStore %2 %9 +OpStore %3 %10 +%12 = OpLoad %47 %2 +%11 = OpCopyObject %47 %12 +OpStore %5 %11 +%14 = OpLoad %47 %3 +%13 = OpCopyObject %47 %14 +OpStore %6 %13 +%16 = OpLoad %47 %5 +%31 = OpConvertUToPtr %51 %16 +%15 = OpLoad %42 %31 +OpStore %7 %15 +%18 = OpLoad %47 %5 +%28 = OpIAdd %47 %18 %27 +%32 = OpConvertUToPtr %51 %28 +%17 = OpLoad %42 %32 +OpStore %8 %17 +%19 = OpLoad %42 %7 +%33 = OpBitcast %52 %4 +OpStore %33 %19 +%21 = OpLoad %42 %8 +%34 = OpBitcast %52 %4 +%20 = OpAtomicIAdd %42 %34 %53 %54 %21 +OpStore %7 %20 +%35 = OpBitcast %52 %4 
+%22 = OpLoad %42 %35 +OpStore %8 %22 +%23 = OpLoad %47 %6 +%24 = OpLoad %42 %7 +%36 = OpConvertUToPtr %51 %23 +OpStore %36 %24 +%25 = OpLoad %47 %6 +%26 = OpLoad %42 %8 +%30 = OpIAdd %47 %25 %29 +%37 = OpConvertUToPtr %51 %30 +OpStore %37 %26 +OpReturn +OpFunctionEnd
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/atom_cas.ptx b/ptx/src/test/spirv_run/atom_cas.ptx new file mode 100644 index 0000000..440a1cb --- /dev/null +++ b/ptx/src/test/spirv_run/atom_cas.ptx @@ -0,0 +1,24 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry atom_cas(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u32 temp1;
+ .reg .u32 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u32 temp1, [in_addr];
+ atom.cas.b32 temp1, [in_addr+4], temp1, 100;
+ ld.u32 temp2, [in_addr+4];
+ st.u32 [out_addr], temp1;
+ st.u32 [out_addr+4], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/atom_cas.spvtxt b/ptx/src/test/spirv_run/atom_cas.spvtxt new file mode 100644 index 0000000..c5fb922 --- /dev/null +++ b/ptx/src/test/spirv_run/atom_cas.spvtxt @@ -0,0 +1,77 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 51 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" +%41 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "atom_cas" +%42 = OpTypeVoid +%43 = OpTypeInt 64 0 +%44 = OpTypeFunction %42 %43 %43 +%45 = OpTypePointer Function %43 +%46 = OpTypeInt 32 0 +%47 = OpTypePointer Function %46 +%48 = OpTypePointer Generic %46 +%25 = OpConstant %43 4 +%27 = OpConstant %46 100 +%49 = OpConstant %46 1 +%50 = OpConstant %46 0 +%28 = OpConstant %43 4 +%30 = OpConstant %43 4 +%1 = OpFunction %42 None %44 +%8 = OpFunctionParameter %43 +%9 = OpFunctionParameter %43 +%39 = OpLabel +%2 = OpVariable %45 Function +%3 = OpVariable %45 Function +%4 = OpVariable %45 Function +%5 = OpVariable %45 Function +%6 = OpVariable %47 Function +%7 = OpVariable %47 Function +OpStore %2 %8 +OpStore %3 %9 +%11 = OpLoad %43 %2 +%10 = OpCopyObject %43 %11 +OpStore %4 %10 +%13 = OpLoad %43 %3 +%12 = OpCopyObject %43 %13 +OpStore %5 %12 +%15 = OpLoad %43 %4 +%32 = OpConvertUToPtr %48 %15 +%14 = OpLoad %46 %32 +OpStore %6 %14 +%17 = OpLoad %43 %4 +%18 = OpLoad %46 %6 +%26 = OpIAdd %43 %17 %25 +%34 = OpConvertUToPtr %48 %26 +%35 = OpCopyObject %46 %18 +%33 = OpAtomicCompareExchange %46 %34 %49 %50 %50 %27 %35 +%16 = OpCopyObject %46 %33 +OpStore %6 %16 +%20 = OpLoad %43 %4 +%29 = OpIAdd %43 %20 %28 +%36 = OpConvertUToPtr %48 %29 +%19 = OpLoad %46 %36 +OpStore %7 %19 +%21 = OpLoad %43 %5 +%22 = OpLoad %46 %6 +%37 = OpConvertUToPtr %48 %21 +OpStore %37 %22 +%23 = OpLoad %43 %5 
+%24 = OpLoad %46 %7 +%31 = OpIAdd %43 %23 %30 +%38 = OpConvertUToPtr %48 %31 +OpStore %38 %24 +OpReturn +OpFunctionEnd
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/atom_inc.ptx b/ptx/src/test/spirv_run/atom_inc.ptx new file mode 100644 index 0000000..ed3df08 --- /dev/null +++ b/ptx/src/test/spirv_run/atom_inc.ptx @@ -0,0 +1,26 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry atom_inc(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u32 temp1;
+ .reg .u32 temp2;
+ .reg .u32 temp3;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ atom.inc.u32 temp1, [in_addr], 101;
+ atom.global.inc.u32 temp2, [in_addr], 101;
+ ld.u32 temp3, [in_addr];
+ st.u32 [out_addr], temp1;
+ st.u32 [out_addr+4], temp2;
+ st.u32 [out_addr+8], temp3;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/atom_inc.spvtxt b/ptx/src/test/spirv_run/atom_inc.spvtxt new file mode 100644 index 0000000..6948cd9 --- /dev/null +++ b/ptx/src/test/spirv_run/atom_inc.spvtxt @@ -0,0 +1,89 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 60 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" +%49 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "atom_inc" +OpDecorate %40 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_generic_inc" Import +OpDecorate %44 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_global_inc" Import +%50 = OpTypeVoid +%51 = OpTypeInt 32 0 +%52 = OpTypePointer Generic %51 +%53 = OpTypeFunction %51 %52 %51 +%54 = OpTypePointer CrossWorkgroup %51 +%55 = OpTypeFunction %51 %54 %51 +%56 = OpTypeInt 64 0 +%57 = OpTypeFunction %50 %56 %56 +%58 = OpTypePointer Function %56 +%59 = OpTypePointer Function %51 +%27 = OpConstant %51 101 +%28 = OpConstant %51 101 +%29 = OpConstant %56 4 +%31 = OpConstant %56 8 +%40 = OpFunction %51 None %53 +%42 = OpFunctionParameter %52 +%43 = OpFunctionParameter %51 +OpFunctionEnd +%44 = OpFunction %51 None %55 +%46 = OpFunctionParameter %54 +%47 = OpFunctionParameter %51 +OpFunctionEnd +%1 = OpFunction %50 None %57 +%9 = OpFunctionParameter %56 +%10 = OpFunctionParameter %56 +%39 = OpLabel +%2 = OpVariable %58 Function +%3 = OpVariable %58 Function +%4 = OpVariable %58 Function +%5 = OpVariable %58 Function +%6 = OpVariable %59 Function +%7 = OpVariable %59 Function +%8 = OpVariable %59 Function +OpStore %2 %9 +OpStore %3 %10 +%12 = OpLoad %56 %2 +%11 = OpCopyObject %56 %12 +OpStore %4 %11 +%14 = OpLoad %56 %3 +%13 = OpCopyObject %56 %14 +OpStore %5 %13 +%16 = OpLoad %56 %4 +%33 = OpConvertUToPtr %52 %16 +%15 
= OpFunctionCall %51 %40 %33 %27 +OpStore %6 %15 +%18 = OpLoad %56 %4 +%34 = OpConvertUToPtr %54 %18 +%17 = OpFunctionCall %51 %44 %34 %28 +OpStore %7 %17 +%20 = OpLoad %56 %4 +%35 = OpConvertUToPtr %52 %20 +%19 = OpLoad %51 %35 +OpStore %8 %19 +%21 = OpLoad %56 %5 +%22 = OpLoad %51 %6 +%36 = OpConvertUToPtr %52 %21 +OpStore %36 %22 +%23 = OpLoad %56 %5 +%24 = OpLoad %51 %7 +%30 = OpIAdd %56 %23 %29 +%37 = OpConvertUToPtr %52 %30 +OpStore %37 %24 +%25 = OpLoad %56 %5 +%26 = OpLoad %51 %8 +%32 = OpIAdd %56 %25 %31 +%38 = OpConvertUToPtr %52 %32 +OpStore %38 %26 +OpReturn +OpFunctionEnd
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/constant_f32.spvtxt b/ptx/src/test/spirv_run/constant_f32.spvtxt index 905bec4..27c5f4e 100644 --- a/ptx/src/test/spirv_run/constant_f32.spvtxt +++ b/ptx/src/test/spirv_run/constant_f32.spvtxt @@ -11,12 +11,12 @@ OpCapability Int16 OpCapability Int64 OpCapability Float16 OpCapability Float64 -OpCapability FunctionFloatControlINTEL -OpExtension "SPV_INTEL_float_controls2" +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" %24 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "constant_f32" -OpDecorate %1 FunctionDenormModeINTEL 32 Preserve +; OpDecorate %1 FunctionDenormModeINTEL 32 Preserve %25 = OpTypeVoid %26 = OpTypeInt 64 0 %27 = OpTypeFunction %25 %26 %26 diff --git a/ptx/src/test/spirv_run/constant_negative.spvtxt b/ptx/src/test/spirv_run/constant_negative.spvtxt index 39e5d19..ec2ff72 100644 --- a/ptx/src/test/spirv_run/constant_negative.spvtxt +++ b/ptx/src/test/spirv_run/constant_negative.spvtxt @@ -11,8 +11,8 @@ OpCapability Int16 OpCapability Int64 OpCapability Float16 OpCapability Float64 -OpCapability FunctionFloatControlINTEL -OpExtension "SPV_INTEL_float_controls2" +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" %24 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "constant_negative" diff --git a/ptx/src/test/spirv_run/fma.spvtxt b/ptx/src/test/spirv_run/fma.spvtxt index 734bf0f..4a90d09 100644 --- a/ptx/src/test/spirv_run/fma.spvtxt +++ b/ptx/src/test/spirv_run/fma.spvtxt @@ -11,12 +11,12 @@ OpCapability Int16 OpCapability Int64 OpCapability Float16 OpCapability Float64 -OpCapability FunctionFloatControlINTEL -OpExtension "SPV_INTEL_float_controls2" +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" %37 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "fma" -OpDecorate %1 
FunctionDenormModeINTEL 32 Preserve +; OpDecorate %1 FunctionDenormModeINTEL 32 Preserve %38 = OpTypeVoid %39 = OpTypeInt 64 0 %40 = OpTypeFunction %38 %39 %39 diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 98b9630..40a9d64 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -86,12 +86,20 @@ test_ptx!(rcp, [2f32], [0.5f32]); // 0x3f000000 is 0.5
// TODO: mul_ftz fails because IGC does not yet handle SPV_INTEL_float_controls2
// test_ptx!(mul_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0u32]);
-test_ptx!(mul_non_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0b1_00000000_01000000000000000000000u32]);
+test_ptx!(
+ mul_non_ftz,
+ [0b1_00000000_10000000000000000000000u32, 0x3f000000u32],
+ [0b1_00000000_01000000000000000000000u32]
+);
test_ptx!(constant_f32, [10f32], [5f32]);
test_ptx!(constant_negative, [-101i32], [101i32]);
test_ptx!(and, [6u32, 3u32], [2u32]);
test_ptx!(selp, [100u16, 200u16], [200u16]);
-test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]);
+test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]);
+test_ptx!(shared_variable, [513u64], [513u64]);
+test_ptx!(atom_cas, [91u32, 91u32], [91u32, 100u32]);
+test_ptx!(atom_inc, [100u32], [100u32, 101u32, 0u32]);
+test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]);
struct DisplayError<T: Debug> {
err: T,
@@ -124,7 +132,7 @@ fn test_ptx_assert<'a, T: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq>( let name = CString::new(name)?;
let result = run_spirv(name.as_c_str(), notcuda_module, input, output)
.map_err(|err| DisplayError { err })?;
- assert_eq!(output, result.as_slice());
+ assert_eq!(result.as_slice(), output);
Ok(())
}
@@ -145,8 +153,8 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>( let use_shared_mem = module
.kernel_info
.get(name.to_str().unwrap())
- .unwrap()
- .uses_shared_mem;
+ .map(|info| info.uses_shared_mem)
+ .unwrap_or(false);
let mut result = vec![0u8.into(); output.len()];
{
let mut drivers = ze::Driver::get()?;
@@ -155,11 +163,20 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>( let mut devices = drv.devices()?;
let dev = devices.drain(0..1).next().unwrap();
let queue = ze::CommandQueue::new(&mut ctx, &dev)?;
- let (module, log) = ze::Module::new_spirv(&mut ctx, &dev, byte_il, None);
+ let (module, maybe_log) = match module.should_link_ptx_impl {
+ Some(ptx_impl) => ze::Module::build_link_spirv(&mut ctx, &dev, &[ptx_impl, byte_il]),
+ None => {
+ let (module, log) = ze::Module::build_spirv(&mut ctx, &dev, byte_il, None);
+ (module, Some(log))
+ }
+ };
let module = match module {
Ok(m) => m,
Err(err) => {
- let raw_err_string = log.get_cstring()?;
+ let raw_err_string = maybe_log
+ .map(|log| log.get_cstring())
+ .transpose()?
+ .unwrap_or(CString::default());
let err_string = raw_err_string.to_string_lossy();
panic!("{:?}\n{}", err, err_string);
}
@@ -215,7 +232,11 @@ fn test_spvtxt_assert<'a>( ptr::null_mut(),
)
};
- assert!(result == spv_result_t::SPV_SUCCESS);
+ if result != spv_result_t::SPV_SUCCESS {
+ panic!("{:?}\n{}", result, unsafe {
+ str::from_utf8_unchecked(spirv_txt)
+ });
+ }
let mut parsed_spirv = Vec::<u32>::new();
let result = unsafe {
spirv_tools::spvBinaryParse(
diff --git a/ptx/src/test/spirv_run/mul_ftz.spvtxt b/ptx/src/test/spirv_run/mul_ftz.spvtxt index da6a12a..56cec5a 100644 --- a/ptx/src/test/spirv_run/mul_ftz.spvtxt +++ b/ptx/src/test/spirv_run/mul_ftz.spvtxt @@ -11,8 +11,8 @@ OpCapability Int16 OpCapability Int64 OpCapability Float16 OpCapability Float64 -OpCapability FunctionFloatControlINTEL -OpExtension "SPV_INTEL_float_controls2" +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" %30 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "mul_ftz" diff --git a/ptx/src/test/spirv_run/selp.spvtxt b/ptx/src/test/spirv_run/selp.spvtxt index dffd9af..6f73bc2 100644 --- a/ptx/src/test/spirv_run/selp.spvtxt +++ b/ptx/src/test/spirv_run/selp.spvtxt @@ -11,8 +11,8 @@ OpCapability Int16 OpCapability Int64 OpCapability Float16 OpCapability Float64 -OpCapability FunctionFloatControlINTEL -OpExtension "SPV_INTEL_float_controls2" +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" %31 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "selp" diff --git a/ptx/src/test/spirv_run/shared_variable.ptx b/ptx/src/test/spirv_run/shared_variable.ptx new file mode 100644 index 0000000..4f7eff3 --- /dev/null +++ b/ptx/src/test/spirv_run/shared_variable.ptx @@ -0,0 +1,26 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+
+.visible .entry shared_variable(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .shared .align 4 .b8 shared_mem1[128];
+
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp1;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.global.u64 temp1, [in_addr];
+ st.shared.u64 [shared_mem1], temp1;
+ ld.shared.u64 temp2, [shared_mem1];
+ st.global.u64 [out_addr], temp2;
+ ret;
+}
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/shared_variable.spvtxt b/ptx/src/test/spirv_run/shared_variable.spvtxt new file mode 100644 index 0000000..1af2bd1 --- /dev/null +++ b/ptx/src/test/spirv_run/shared_variable.spvtxt @@ -0,0 +1,65 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 39 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" +%27 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "shared_variable" %4 +OpDecorate %4 Alignment 4 +%28 = OpTypeVoid +%29 = OpTypeInt 32 0 +%30 = OpTypeInt 8 0 +%31 = OpConstant %29 128 +%32 = OpTypeArray %30 %31 +%33 = OpTypePointer Workgroup %32 +%4 = OpVariable %33 Workgroup +%34 = OpTypeInt 64 0 +%35 = OpTypeFunction %28 %34 %34 +%36 = OpTypePointer Function %34 +%37 = OpTypePointer CrossWorkgroup %34 +%38 = OpTypePointer Workgroup %34 +%1 = OpFunction %28 None %35 +%9 = OpFunctionParameter %34 +%10 = OpFunctionParameter %34 +%25 = OpLabel +%2 = OpVariable %36 Function +%3 = OpVariable %36 Function +%5 = OpVariable %36 Function +%6 = OpVariable %36 Function +%7 = OpVariable %36 Function +%8 = OpVariable %36 Function +OpStore %2 %9 +OpStore %3 %10 +%12 = OpLoad %34 %2 +%11 = OpCopyObject %34 %12 +OpStore %5 %11 +%14 = OpLoad %34 %3 +%13 = OpCopyObject %34 %14 +OpStore %6 %13 +%16 = OpLoad %34 %5 +%21 = OpConvertUToPtr %37 %16 +%15 = OpLoad %34 %21 +OpStore %7 %15 +%17 = OpLoad %34 %7 +%22 = OpBitcast %38 %4 +OpStore %22 %17 +%23 = OpBitcast %38 %4 +%18 = OpLoad %34 %23 +OpStore %8 %18 +%19 = OpLoad %34 %6 +%20 = OpLoad %34 %8 +%24 = OpConvertUToPtr %37 %19 +OpStore %24 %20 +OpReturn +OpFunctionEnd
\ No newline at end of file diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a7025b1..6b07c0f 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,14 +1,13 @@ use crate::ast;
use half::f16;
use rspirv::{binary::Disassemble, dr};
+use std::collections::{hash_map, HashMap, HashSet};
use std::{borrow::Cow, hash::Hash, iter, mem};
-use std::{
- collections::{hash_map, HashMap, HashSet},
- convert::TryFrom,
-};
use rspirv::binary::Assemble;
+static NOTCUDA_PTX_IMPL: &'static [u8] = include_bytes!("../lib/notcuda_ptx_impl.spv");
+
quick_error! {
#[derive(Debug)]
pub enum TranslateError {
@@ -69,6 +68,7 @@ impl Into<spirv::StorageClass> for ast::PointerStateSpace { ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup,
ast::PointerStateSpace::Param => spirv::StorageClass::Function,
+ ast::PointerStateSpace::Generic => spirv::StorageClass::Generic,
}
}
}
@@ -419,6 +419,7 @@ impl TypeWordMap { pub struct Module {
pub spirv: dr::Module,
pub kernel_info: HashMap<String, KernelInfo>,
+ pub should_link_ptx_impl: Option<&'static [u8]>,
}
pub struct KernelInfo {
@@ -428,15 +429,22 @@ pub struct KernelInfo { pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateError> {
let mut id_defs = GlobalStringIdResolver::new(1);
+ let mut ptx_impl_imports = HashMap::new();
let directives = ast
.directives
.into_iter()
- .map(|f| translate_directive(&mut id_defs, f))
+ .map(|directive| translate_directive(&mut id_defs, &mut ptx_impl_imports, directive))
.collect::<Result<Vec<_>, _>>()?;
+ let must_link_ptx_impl = ptx_impl_imports.len() > 0;
+ let directives = ptx_impl_imports
+ .into_iter()
+ .map(|(_, v)| v)
+ .chain(directives.into_iter())
+ .collect::<Vec<_>>();
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
- let mut directives =
- convert_dynamic_shared_memory_usage(&mut id_defs, directives, &mut || builder.id());
+ let call_map = get_call_map(&directives);
+ let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
@@ -448,32 +456,142 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro let mut map = TypeWordMap::new(&mut builder);
emit_builtins(&mut builder, &mut map, &id_defs);
let mut kernel_info = HashMap::new();
- for d in directives {
+ emit_directives(
+ &mut builder,
+ &mut map,
+ &id_defs,
+ opencl_id,
+ &denorm_information,
+ &call_map,
+ directives,
+ &mut kernel_info,
+ )?;
+ let spirv = builder.module();
+ Ok(Module {
+ spirv,
+ kernel_info,
+ should_link_ptx_impl: if must_link_ptx_impl {
+ Some(NOTCUDA_PTX_IMPL)
+ } else {
+ None
+ },
+ })
+}
+
+fn emit_directives<'input>(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ id_defs: &GlobalStringIdResolver<'input>,
+ opencl_id: spirv::Word,
+ denorm_information: &HashMap<CallgraphKey<'input>, HashMap<u8, spirv::FPDenormMode>>,
+ call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
+ directives: Vec<Directive>,
+ kernel_info: &mut HashMap<String, KernelInfo>,
+) -> Result<(), TranslateError> {
+ let empty_body = Vec::new();
+ for d in directives.iter() {
match d {
Directive::Variable(var) => {
- emit_variable(&mut builder, &mut map, &var)?;
+ emit_variable(builder, map, &var)?;
}
Directive::Method(f) => {
- let f_body = match f.body {
+ let f_body = match &f.body {
Some(f) => f,
- None => continue,
+ None => {
+ if f.import_as.is_some() {
+ &empty_body
+ } else {
+ continue;
+ }
+ }
};
- emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
+ for var in f.globals.iter() {
+ emit_variable(builder, map, var)?;
+ }
emit_function_header(
- &mut builder,
- &mut map,
+ builder,
+ map,
&id_defs,
- f.func_decl,
+ &f.globals,
+ &f.func_decl,
&denorm_information,
- &mut kernel_info,
+ call_map,
+ &directives,
+ kernel_info,
)?;
- emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
+ emit_function_body_ops(builder, map, opencl_id, &f_body)?;
builder.end_function()?;
+ if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) =
+ (&f.func_decl, &f.import_as)
+ {
+ builder.decorate(
+ *fn_id,
+ spirv::Decoration::LinkageAttributes,
+ &[
+ dr::Operand::LiteralString(name.clone()),
+ dr::Operand::LinkageType(spirv::LinkageType::Import),
+ ],
+ );
+ }
}
}
}
- let spirv = builder.module();
- Ok(Module { spirv, kernel_info })
+ Ok(())
+}
+
+fn get_call_map<'input>(
+ module: &[Directive<'input>],
+) -> HashMap<&'input str, HashSet<spirv::Word>> {
+ let mut directly_called_by = HashMap::new();
+ for directive in module {
+ match directive {
+ Directive::Method(Function {
+ func_decl,
+ body: Some(statements),
+ ..
+ }) => {
+ let call_key = CallgraphKey::new(&func_decl);
+ for statement in statements {
+ match statement {
+ Statement::Call(call) => {
+ multi_hash_map_append(&mut directly_called_by, call_key, call.func);
+ }
+ _ => {}
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ let mut result = HashMap::new();
+ for (method_key, children) in directly_called_by.iter() {
+ match method_key {
+ CallgraphKey::Kernel(name) => {
+ let mut visited = HashSet::new();
+ for child in children {
+ add_call_map_single(&directly_called_by, &mut visited, *child);
+ }
+ result.insert(*name, visited);
+ }
+ CallgraphKey::Func(_) => {}
+ }
+ }
+ result
+}
+
+fn add_call_map_single<'input>(
+ directly_called_by: &MultiHashMap<CallgraphKey<'input>, spirv::Word>,
+ visited: &mut HashSet<spirv::Word>,
+ current: spirv::Word,
+) {
+ if !visited.insert(current) {
+ return;
+ }
+ if let Some(children) = directly_called_by.get(&CallgraphKey::Func(current)) {
+ for child in children {
+ add_call_map_single(directly_called_by, visited, *child);
+ }
+ }
}
type MultiHashMap<K, V> = HashMap<K, Vec<V>>;
@@ -495,7 +613,6 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>, // This pass looks for all uses of .extern .shared and converts them to
// an additional method argument
fn convert_dynamic_shared_memory_usage<'input>(
- id_defs: &mut GlobalStringIdResolver<'input>,
module: Vec<Directive<'input>>,
new_id: &mut impl FnMut() -> spirv::Word,
) -> Vec<Directive<'input>> {
@@ -524,6 +641,7 @@ fn convert_dynamic_shared_memory_usage<'input>( func_decl,
globals,
body: Some(statements),
+ import_as,
}) => {
let call_key = CallgraphKey::new(&func_decl);
let statements = statements
@@ -545,6 +663,7 @@ fn convert_dynamic_shared_memory_usage<'input>( func_decl,
globals,
body: Some(statements),
+ import_as,
})
}
directive => directive,
@@ -561,6 +680,7 @@ fn convert_dynamic_shared_memory_usage<'input>( mut func_decl,
globals,
body: Some(statements),
+ import_as,
}) => {
let call_key = CallgraphKey::new(&func_decl);
if !methods_using_extern_shared.contains(&call_key) {
@@ -568,6 +688,7 @@ fn convert_dynamic_shared_memory_usage<'input>( func_decl,
globals,
body: Some(statements),
+ import_as,
});
}
let shared_id_param = new_id();
@@ -625,6 +746,7 @@ fn convert_dynamic_shared_memory_usage<'input>( func_decl,
globals,
body: Some(new_statements),
+ import_as,
})
}
directive => directive,
@@ -744,15 +866,6 @@ fn denorm_count_map_update_impl<T: Eq + Hash>( }
}
-fn denorm_count_map_merge<T: Eq + Hash + Copy>(
- dst: &mut DenormCountMap<T>,
- src: &DenormCountMap<T>,
-) {
- for (k, count) in src {
- denorm_count_map_update_impl(dst, *k, *count);
- }
-}
-
// HACK ALERT!
// This function is a "good enough" heuristic of whetever to mark f16/f32 operations
// in the kernel as flushing denorms to zero or preserving them
@@ -763,7 +876,7 @@ fn compute_denorm_information<'input>( module: &[Directive<'input>],
) -> HashMap<CallgraphKey<'input>, HashMap<u8, spirv::FPDenormMode>> {
let mut denorm_methods = HashMap::new();
- for directive in module.iter() {
+ for directive in module {
match directive {
Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {}
Directive::Method(Function {
@@ -861,9 +974,12 @@ fn emit_builtins( fn emit_function_header<'a>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- global: &GlobalStringIdResolver<'a>,
- func_directive: ast::MethodDecl<spirv::Word>,
+ defined_globals: &GlobalStringIdResolver<'a>,
+ synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
+ func_directive: &ast::MethodDecl<spirv::Word>,
denorm_information: &HashMap<CallgraphKey<'a>, HashMap<u8, spirv::FPDenormMode>>,
+ call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
+ direcitves: &[Directive],
kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
if let ast::MethodDecl::Kernel {
@@ -884,22 +1000,49 @@ fn emit_function_header<'a>( let (ret_type, func_type) = get_function_type(builder, map, &func_directive);
let fn_id = match func_directive {
ast::MethodDecl::Kernel { name, .. } => {
- let fn_id = global.get_id(name)?;
- let mut global_variables = global
+ let fn_id = defined_globals.get_id(name)?;
+ let mut global_variables = defined_globals
.variables_type_check
.iter()
.filter_map(|(k, t)| t.as_ref().map(|_| *k))
.collect::<Vec<_>>();
- let mut interface = global
+ let mut interface = defined_globals
.special_registers
.iter()
.map(|(_, id)| *id)
.collect::<Vec<_>>();
+ for ast::Variable { name, .. } in synthetic_globals {
+ interface.push(*name);
+ }
+ let empty_hash_set = HashSet::new();
+ let child_fns = call_map.get(name).unwrap_or(&empty_hash_set);
+ for directive in direcitves {
+ match directive {
+ Directive::Method(Function {
+ func_decl: ast::MethodDecl::Func(_, name, _),
+ globals,
+ ..
+ }) => {
+ if child_fns.contains(name) {
+ for var in globals {
+ interface.push(var.name);
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+
global_variables.append(&mut interface);
- builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
+ builder.entry_point(
+ spirv::ExecutionModel::Kernel,
+ fn_id,
+ *name,
+ global_variables,
+ );
fn_id
}
- ast::MethodDecl::Func(_, name, _) => name,
+ ast::MethodDecl::Func(_, name, _) => *name,
};
builder.begin_function(
ret_type,
@@ -934,9 +1077,10 @@ fn emit_function_header<'a>( pub fn to_spirv<'a>(
ast: ast::Module<'a>,
-) -> Result<(Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
+) -> Result<(Option<&'static [u8]>, Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
let module = to_spirv_module(ast)?;
Ok((
+ module.should_link_ptx_impl,
module.spirv.assemble(),
module
.kernel_info
@@ -977,11 +1121,14 @@ fn emit_memory_model(builder: &mut dr::Builder) { fn translate_directive<'input>(
id_defs: &mut GlobalStringIdResolver<'input>,
+ ptx_impl_imports: &mut HashMap<String, Directive>,
d: ast::Directive<'input, ast::ParsedArgParams<'input>>,
) -> Result<Directive<'input>, TranslateError> {
Ok(match d {
ast::Directive::Variable(v) => Directive::Variable(translate_variable(id_defs, v)?),
- ast::Directive::Method(f) => Directive::Method(translate_function(id_defs, f)?),
+ ast::Directive::Method(f) => {
+ Directive::Method(translate_function(id_defs, ptx_impl_imports, f)?)
+ }
})
}
@@ -1000,10 +1147,11 @@ fn translate_variable<'a>( fn translate_function<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
+ ptx_impl_imports: &mut HashMap<String, Directive>,
f: ast::ParsedFunction<'a>,
) -> Result<Function<'a>, TranslateError> {
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive);
- to_ssa(str_resolver, fn_resolver, fn_decl, f.body)
+ to_ssa(ptx_impl_imports, str_resolver, fn_resolver, fn_decl, f.body)
}
fn expand_kernel_params<'a, 'b>(
@@ -1043,6 +1191,7 @@ fn expand_fn_params<'a, 'b>( }
fn to_ssa<'input, 'b>(
+ ptx_impl_imports: &mut HashMap<String, Directive>,
mut id_defs: FnStringIdResolver<'input, 'b>,
fn_defs: GlobalFnDeclResolver<'input, 'b>,
f_args: ast::MethodDecl<'input, spirv::Word>,
@@ -1055,6 +1204,7 @@ fn to_ssa<'input, 'b>( func_decl: f_args,
body: None,
globals: Vec::new(),
+ import_as: None,
})
}
};
@@ -1071,19 +1221,90 @@ fn to_ssa<'input, 'b>( insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.unmut();
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
- let (f_body, globals) = extract_globals(labeled_statements);
+ let (f_body, globals) =
+ extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs);
Ok(Function {
func_decl: f_args,
globals: globals,
body: Some(f_body),
+ import_as: None,
})
}
-fn extract_globals(
+fn extract_globals<'input, 'b>(
sorted_statements: Vec<ExpandedStatement>,
-) -> (Vec<ExpandedStatement>, Vec<ExpandedStatement>) {
- // This fn will be used for SLM
- (sorted_statements, Vec::new())
+ ptx_impl_imports: &mut HashMap<String, Directive>,
+ id_def: &mut NumericIdResolver,
+) -> (
+ Vec<ExpandedStatement>,
+ Vec<ast::Variable<ast::VariableType, spirv::Word>>,
+) {
+ let mut local = Vec::with_capacity(sorted_statements.len());
+ let mut global = Vec::new();
+ for statement in sorted_statements {
+ match statement {
+ Statement::Variable(
+ var
+ @
+ ast::Variable {
+ v_type: ast::VariableType::Shared(_),
+ ..
+ },
+ )
+ | Statement::Variable(
+ var
+ @
+ ast::Variable {
+ v_type: ast::VariableType::Global(_),
+ ..
+ },
+ ) => global.push(var),
+ Statement::Instruction(ast::Instruction::Atom(
+ d
+ @
+ ast::AtomDetails {
+ inner:
+ ast::AtomInnerDetails::Unsigned {
+ op: ast::AtomUIntOp::Inc,
+ ..
+ },
+ ..
+ },
+ a,
+ )) => {
+ local.push(to_ptx_impl_atomic_call(
+ id_def,
+ ptx_impl_imports,
+ d,
+ a,
+ "inc",
+ ));
+ }
+ Statement::Instruction(ast::Instruction::Atom(
+ d
+ @
+ ast::AtomDetails {
+ inner:
+ ast::AtomInnerDetails::Unsigned {
+ op: ast::AtomUIntOp::Dec,
+ ..
+ },
+ ..
+ },
+ a,
+ )) => {
+ local.push(to_ptx_impl_atomic_call(
+ id_def,
+ ptx_impl_imports,
+ d,
+ a,
+ "dec",
+ ));
+ }
+ s => local.push(s),
+ }
+ }
+ (local, global)
}
fn normalize_variable_decls(directives: &mut Vec<Directive>) {
@@ -1269,6 +1490,15 @@ fn convert_to_typed_statements( ast::Instruction::Selp(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Selp(d, a.cast())))
}
+ ast::Instruction::Bar(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Bar(d, a.cast())))
+ }
+ ast::Instruction::Atom(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Atom(d, a.cast())))
+ }
+ ast::Instruction::AtomCas(d, a) => result.push(Statement::Instruction(
+ ast::Instruction::AtomCas(d, a.cast()),
+ )),
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -1286,6 +1516,99 @@ fn convert_to_typed_statements( Ok(result)
}
+fn to_ptx_impl_atomic_call(
+ id_defs: &mut NumericIdResolver,
+ ptx_impl_imports: &mut HashMap<String, Directive>,
+ details: ast::AtomDetails,
+ arg: ast::Arg3<ExpandedArgParams>,
+ op: &'static str,
+) -> ExpandedStatement {
+ let semantics = ptx_semantics_name(details.semantics);
+ let scope = ptx_scope_name(details.scope);
+ let space = ptx_space_name(details.space);
+ let fn_name = format!(
+ "__notcuda_ptx_impl__atom_{}_{}_{}_{}",
+ semantics, scope, space, op
+ );
+ // TODO: extract to a function
+ let ptr_space = match details.space {
+ ast::AtomSpace::Generic => ast::PointerStateSpace::Generic,
+ ast::AtomSpace::Global => ast::PointerStateSpace::Global,
+ ast::AtomSpace::Shared => ast::PointerStateSpace::Shared,
+ };
+ let fn_id = match ptx_impl_imports.entry(fn_name) {
+ hash_map::Entry::Vacant(entry) => {
+ let fn_id = id_defs.new_id(None);
+ let func_decl = ast::MethodDecl::Func::<spirv::Word>(
+ vec![ast::FnArgument {
+ align: None,
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
+ ast::ScalarType::U32,
+ )),
+ name: id_defs.new_id(None),
+ array_init: Vec::new(),
+ }],
+ fn_id,
+ vec![
+ ast::FnArgument {
+ align: None,
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
+ ast::SizedScalarType::U32,
+ ptr_space,
+ )),
+ name: id_defs.new_id(None),
+ array_init: Vec::new(),
+ },
+ ast::FnArgument {
+ align: None,
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
+ ast::ScalarType::U32,
+ )),
+ name: id_defs.new_id(None),
+ array_init: Vec::new(),
+ },
+ ],
+ );
+ let func = Function {
+ func_decl,
+ globals: Vec::new(),
+ body: None,
+ import_as: Some(entry.key().clone()),
+ };
+ entry.insert(Directive::Method(func));
+ fn_id
+ }
+ hash_map::Entry::Occupied(entry) => match entry.get() {
+ Directive::Method(Function {
+ func_decl: ast::MethodDecl::Func(_, name, _),
+ ..
+ }) => *name,
+ _ => unreachable!(),
+ },
+ };
+ Statement::Call(ResolvedCall {
+ uniform: false,
+ func: fn_id,
+ ret_params: vec![(
+ arg.dst,
+ ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ )],
+ param_list: vec![
+ (
+ arg.src1,
+ ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
+ ast::SizedScalarType::U32,
+ ptr_space,
+ )),
+ ),
+ (
+ arg.src2,
+ ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ),
+ ],
+ })
+}
+
fn to_resolved_fn_args<T>(
params: Vec<T>,
params_decl: &[ast::FnArgumentType],
@@ -1529,6 +1852,9 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( | (t, ArgumentSemantics::DefaultRelaxed)
| (t, ArgumentSemantics::PhysicalPointer) => t,
};
+ if let ast::Type::Array(_, _) = id_type {
+ return Ok(desc.op);
+ }
let generated_id = id_def.new_id(id_type.clone());
if !desc.is_dst {
result.push(Statement::LoadVar(
@@ -1916,6 +2242,12 @@ fn insert_implicit_conversions( if let ast::Instruction::St(d, _) = &inst {
state_space = Some(d.state_space.to_ld_ss());
}
+ if let ast::Instruction::Atom(d, _) = &inst {
+ state_space = Some(d.space.to_ld_ss());
+ }
+ if let ast::Instruction::AtomCas(d, _) = &inst {
+ state_space = Some(d.space.to_ld_ss());
+ }
if let ast::Instruction::Mov(_, ast::Arg2Mov::Normal(_)) = &inst {
default_conversion_fn = should_bitcast_packed;
}
@@ -2387,6 +2719,52 @@ fn emit_function_body_ops( let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
builder.select(result_type, Some(a.dst), a.src3, a.src2, a.src2)?;
}
+ // TODO: implement named barriers
+ ast::Instruction::Bar(d, _) => {
+ let workgroup_scope = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(spirv::Scope::Workgroup as u32),
+ )?;
+ let barrier_semantics = match d {
+ ast::BarDetails::SyncAligned => map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+ )?,
+ };
+ builder.control_barrier(workgroup_scope, workgroup_scope, barrier_semantics)?;
+ }
+ ast::Instruction::Atom(details, arg) => {
+ emit_atom(builder, map, details, arg)?;
+ }
+ ast::Instruction::AtomCas(details, arg) => {
+ let result_type = map.get_or_add_scalar(builder, details.typ.into());
+ let memory_const = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(details.scope.to_spirv() as u32),
+ )?;
+ let semantics_const = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(details.semantics.to_spirv().bits()),
+ )?;
+ builder.atomic_compare_exchange(
+ result_type,
+ Some(arg.dst),
+ arg.src1,
+ memory_const,
+ semantics_const,
+ semantics_const,
+ arg.src3,
+ arg.src2,
+ )?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@@ -2417,6 +2795,99 @@ fn emit_function_body_ops( Ok(())
}
+fn emit_atom(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ details: &ast::AtomDetails,
+ arg: &ast::Arg3<ExpandedArgParams>,
+) -> Result<(), TranslateError> {
+ let (spirv_op, typ) = match details.inner {
+ ast::AtomInnerDetails::Bit { op, typ } => {
+ let spirv_op = match op {
+ ast::AtomBitOp::And => dr::Builder::atomic_and,
+ ast::AtomBitOp::Or => dr::Builder::atomic_or,
+ ast::AtomBitOp::Xor => dr::Builder::atomic_xor,
+ ast::AtomBitOp::Exchange => dr::Builder::atomic_exchange,
+ };
+ (spirv_op, ast::ScalarType::from(typ))
+ }
+ ast::AtomInnerDetails::Unsigned { op, typ } => {
+ let spirv_op = match op {
+ ast::AtomUIntOp::Add => dr::Builder::atomic_i_add,
+ ast::AtomUIntOp::Inc | ast::AtomUIntOp::Dec => {
+ return Err(TranslateError::Unreachable);
+ }
+ ast::AtomUIntOp::Min => dr::Builder::atomic_u_min,
+ ast::AtomUIntOp::Max => dr::Builder::atomic_u_max,
+ };
+ (spirv_op, typ.into())
+ }
+ ast::AtomInnerDetails::Signed { op, typ } => {
+ let spirv_op = match op {
+ ast::AtomSIntOp::Add => dr::Builder::atomic_i_add,
+ ast::AtomSIntOp::Min => dr::Builder::atomic_s_min,
+ ast::AtomSIntOp::Max => dr::Builder::atomic_s_max,
+ };
+ (spirv_op, typ.into())
+ }
+ // TODO: Hardware is capable of this, implement it through builtin
+ ast::AtomInnerDetails::Float { .. } => todo!(),
+ };
+ let result_type = map.get_or_add_scalar(builder, typ);
+ let memory_const = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(details.scope.to_spirv() as u32),
+ )?;
+ let semantics_const = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(details.semantics.to_spirv().bits()),
+ )?;
+ spirv_op(
+ builder,
+ result_type,
+ Some(arg.dst),
+ arg.src1,
+ memory_const,
+ semantics_const,
+ arg.src2,
+ )?;
+ Ok(())
+}
+
// Describes a `__notcuda_ptx_impl__*` helper imported from the support
// library: its return type, the id assigned to it and its argument types.
// NOTE(review): no uses of this struct are visible in this chunk — it may be
// dead code or reserved for upcoming work; confirm before relying on it.
#[derive(Clone)]
struct PtxImplImport {
    out_arg: ast::Type,   // return type of the imported function
    fn_id: u32,           // id under which the import is referenced
    in_args: Vec<ast::Type>, // parameter types, in declaration order
}
+
+fn ptx_semantics_name(sema: ast::AtomSemantics) -> &'static str {
+ match sema {
+ ast::AtomSemantics::Relaxed => "relaxed",
+ ast::AtomSemantics::Acquire => "acquire",
+ ast::AtomSemantics::Release => "release",
+ ast::AtomSemantics::AcquireRelease => "acq_rel",
+ }
+}
+
+fn ptx_scope_name(scope: ast::MemScope) -> &'static str {
+ match scope {
+ ast::MemScope::Cta => "cta",
+ ast::MemScope::Gpu => "gpu",
+ ast::MemScope::Sys => "sys",
+ }
+}
+
+fn ptx_space_name(space: ast::AtomSpace) -> &'static str {
+ match space {
+ ast::AtomSpace::Generic => "generic",
+ ast::AtomSpace::Global => "global",
+ ast::AtomSpace::Shared => "shared",
+ }
+}
+
fn emit_mul_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -2652,7 +3123,7 @@ fn emit_cvt( map: &mut TypeWordMap,
dets: &ast::CvtDetails,
arg: &ast::Arg2<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
+) -> Result<(), TranslateError> {
match dets {
ast::CvtDetails::FloatFromFloat(desc) => {
if desc.dst == desc.src {
@@ -3011,7 +3482,7 @@ fn emit_implicit_conversion( builder: &mut dr::Builder,
map: &mut TypeWordMap,
cv: &ImplicitConversion,
-) -> Result<(), dr::Error> {
+) -> Result<(), TranslateError> {
let from_parts = cv.from.to_parts();
let to_parts = cv.to.to_parts();
match (from_parts.kind, to_parts.kind, cv.kind) {
@@ -3019,7 +3490,7 @@ fn emit_implicit_conversion( let dst_type = map.get_or_add_scalar(builder, ast::ScalarType::B64);
builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
}
- (_, _, ConversionKind::BitToPtr(space)) => {
+ (_, _, ConversionKind::BitToPtr(_)) => {
let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
@@ -3782,8 +4253,9 @@ enum Directive<'input> { struct Function<'input> {
pub func_decl: ast::MethodDecl<'input, spirv::Word>,
- pub globals: Vec<ExpandedStatement>,
+ pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>,
pub body: Option<Vec<ExpandedStatement>>,
+ import_as: Option<String>,
}
pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
@@ -4091,6 +4563,13 @@ impl<T: ArgParamsEx> ast::Instruction<T> { a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?,
),
ast::Instruction::Selp(t, a) => ast::Instruction::Selp(t, a.map_selp(visitor, t)?),
+ ast::Instruction::Bar(d, a) => ast::Instruction::Bar(d, a.map(visitor)?),
+ ast::Instruction::Atom(d, a) => {
+ ast::Instruction::Atom(d, a.map_atom(visitor, d.inner.get_type(), d.space)?)
+ }
+ ast::Instruction::AtomCas(d, a) => {
+ ast::Instruction::AtomCas(d, a.map_atom(visitor, d.typ, d.space)?)
+ }
})
}
}
@@ -4337,6 +4816,9 @@ impl ast::Instruction<ExpandedArgParams> { | ast::Instruction::Rcp(_, _)
| ast::Instruction::And(_, _)
| ast::Instruction::Selp(_, _)
+ | ast::Instruction::Bar(_, _)
+ | ast::Instruction::Atom(_, _)
+ | ast::Instruction::AtomCas(_, _)
| ast::Instruction::Mad(_, _) => None,
}
}
@@ -4358,6 +4840,9 @@ impl ast::Instruction<ExpandedArgParams> { ast::Instruction::And(_, _) => None,
ast::Instruction::Cvta(_, _) => None,
ast::Instruction::Selp(_, _) => None,
+ ast::Instruction::Bar(_, _) => None,
+ ast::Instruction::Atom(_, _) => None,
+ ast::Instruction::AtomCas(_, _) => None,
ast::Instruction::Sub(ast::ArithDetails::Signed(_), _) => None,
ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _) => None,
ast::Instruction::Add(ast::ArithDetails::Signed(_), _) => None,
@@ -4612,6 +5097,27 @@ impl<T: ArgParamsEx> ast::Arg1<T> { }
}
+impl<T: ArgParamsEx> ast::Arg1Bar<T> {
+ fn cast<U: ArgParamsEx<Operand = T::Operand>>(self) -> ast::Arg1Bar<U> {
+ ast::Arg1Bar { src: self.src }
+ }
+
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ ) -> Result<ast::Arg1Bar<U>, TranslateError> {
+ let new_src = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ )?;
+ Ok(ast::Arg1Bar { src: new_src })
+ }
+}
+
impl<T: ArgParamsEx> ast::Arg2<T> {
fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg2<U> {
ast::Arg2 {
@@ -5022,6 +5528,43 @@ impl<T: ArgParamsEx> ast::Arg3<T> { )?;
Ok(ast::Arg3 { dst, src1, src2 })
}
+
+ fn map_atom<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ t: ast::ScalarType,
+ state_space: ast::AtomSpace,
+ ) -> Result<ast::Arg3<U>, TranslateError> {
+ let scalar_type = ast::ScalarType::from(t);
+ let dst = visitor.id(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(&ast::Type::Scalar(scalar_type)),
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ sema: ArgumentSemantics::PhysicalPointer,
+ },
+ &ast::Type::Pointer(
+ ast::PointerType::Scalar(scalar_type),
+ state_space.to_ld_ss(),
+ ),
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &ast::Type::Scalar(scalar_type),
+ )?;
+ Ok(ast::Arg3 { dst, src1, src2 })
+ }
}
impl<T: ArgParamsEx> ast::Arg4<T> {
@@ -5129,6 +5672,56 @@ impl<T: ArgParamsEx> ast::Arg4<T> { src3,
})
}
+
+ fn map_atom<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ t: ast::BitType,
+ state_space: ast::AtomSpace,
+ ) -> Result<ast::Arg4<U>, TranslateError> {
+ let scalar_type = ast::ScalarType::from(t);
+ let dst = visitor.id(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(&ast::Type::Scalar(scalar_type)),
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ sema: ArgumentSemantics::PhysicalPointer,
+ },
+ &ast::Type::Pointer(
+ ast::PointerType::Scalar(scalar_type),
+ state_space.to_ld_ss(),
+ ),
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &ast::Type::Scalar(scalar_type),
+ )?;
+ let src3 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src3,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &ast::Type::Scalar(scalar_type),
+ )?;
+ Ok(ast::Arg4 {
+ dst,
+ src1,
+ src2,
+ src3,
+ })
+ }
}
impl<T: ArgParamsEx> ast::Arg4Setp<T> {
@@ -5434,6 +6027,17 @@ impl ast::MinMaxDetails { }
}
impl ast::AtomInnerDetails {
    // Returns the scalar type of the atom's value operand, erasing the
    // per-variant type-family distinction. The arms cannot be merged because
    // each variant stores a different `typ` enum (bit/uint/sint/float).
    fn get_type(&self) -> ast::ScalarType {
        match self {
            ast::AtomInnerDetails::Bit { typ, .. } => (*typ).into(),
            ast::AtomInnerDetails::Unsigned { typ, .. } => (*typ).into(),
            ast::AtomInnerDetails::Signed { typ, .. } => (*typ).into(),
            ast::AtomInnerDetails::Float { typ, .. } => (*typ).into(),
        }
    }
}
+
impl ast::SIntType {
fn from_size(width: u8) -> Self {
match width {
@@ -5509,6 +6113,37 @@ impl ast::MulDetails { }
}
impl ast::AtomSpace {
    // Widens the atom-specific state space to the general ld/st state space;
    // `atom` only accepts generic/global/shared, so the mapping is 1:1.
    fn to_ld_ss(self) -> ast::LdStateSpace {
        match self {
            ast::AtomSpace::Generic => ast::LdStateSpace::Generic,
            ast::AtomSpace::Global => ast::LdStateSpace::Global,
            ast::AtomSpace::Shared => ast::LdStateSpace::Shared,
        }
    }
}
+
+impl ast::MemScope {
+ fn to_spirv(self) -> spirv::Scope {
+ match self {
+ ast::MemScope::Cta => spirv::Scope::Workgroup,
+ ast::MemScope::Gpu => spirv::Scope::Device,
+ ast::MemScope::Sys => spirv::Scope::CrossDevice,
+ }
+ }
+}
+
+impl ast::AtomSemantics {
+ fn to_spirv(self) -> spirv::MemorySemantics {
+ match self {
+ ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED,
+ ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE,
+ ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE,
+ ast::AtomSemantics::AcquireRelease => spirv::MemorySemantics::ACQUIRE_RELEASE,
+ }
+ }
+}
+
fn bitcast_logical_pointer(
operand: &ast::Type,
instr: &ast::Type,
@@ -5528,7 +6163,27 @@ fn bitcast_physical_pointer( ) -> Result<Option<ConversionKind>, TranslateError> {
match operand_type {
// array decays to a pointer
- ast::Type::Array(_, _) => todo!(),
+ ast::Type::Array(op_scalar_t, _) => {
+ if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
+ if ss == Some(*instr_space) {
+ if ast::Type::Scalar(*op_scalar_t) == ast::Type::from(instr_scalar_t.clone()) {
+ Ok(None)
+ } else {
+ Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
+ }
+ } else {
+ if ss == Some(ast::LdStateSpace::Generic)
+ || *instr_space == ast::LdStateSpace::Generic
+ {
+ Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
+ } else {
+ Err(TranslateError::MismatchedType)
+ }
+ }
+ } else {
+ Err(TranslateError::MismatchedType)
+ }
+ }
ast::Type::Scalar(ast::ScalarType::B64)
| ast::Type::Scalar(ast::ScalarType::U64)
| ast::Type::Scalar(ast::ScalarType::S64) => {
|