author     Andrzej Janik <[email protected]>   2020-10-31 21:28:15 +0100
committer  Andrzej Janik <[email protected]>   2020-10-31 21:28:15 +0100
commit     a82eb2081717c1fb48e140176fec0e5b5974a432 (patch)
tree       b5ca6934333d1707ed43a1e21a8f02f630929dc4
parent     861116f223081528cf1e32f5e1eddb733ac00241 (diff)
Implement atomic instructions
-rw-r--r--  level_zero/Cargo.toml | 6
-rw-r--r--  level_zero/src/ze.rs | 73
-rw-r--r--  notcuda/src/impl/module.rs | 4
-rw-r--r--  ptx/lib/notcuda_ptx_impl.cl | 121
-rw-r--r--  ptx/lib/notcuda_ptx_impl.spv | bin 0 -> 48348 bytes
-rw-r--r--  ptx/src/ast.rs | 92
-rw-r--r--  ptx/src/ptx.lalrpop | 215
-rw-r--r--  ptx/src/test/spirv_build/bar_sync.ptx | 10
-rw-r--r--  ptx/src/test/spirv_run/and.spvtxt | 4
-rw-r--r--  ptx/src/test/spirv_run/atom_add.ptx | 28
-rw-r--r--  ptx/src/test/spirv_run/atom_add.spvtxt | 84
-rw-r--r--  ptx/src/test/spirv_run/atom_cas.ptx | 24
-rw-r--r--  ptx/src/test/spirv_run/atom_cas.spvtxt | 77
-rw-r--r--  ptx/src/test/spirv_run/atom_inc.ptx | 26
-rw-r--r--  ptx/src/test/spirv_run/atom_inc.spvtxt | 89
-rw-r--r--  ptx/src/test/spirv_run/constant_f32.spvtxt | 6
-rw-r--r--  ptx/src/test/spirv_run/constant_negative.spvtxt | 4
-rw-r--r--  ptx/src/test/spirv_run/fma.spvtxt | 6
-rw-r--r--  ptx/src/test/spirv_run/mod.rs | 37
-rw-r--r--  ptx/src/test/spirv_run/mul_ftz.spvtxt | 4
-rw-r--r--  ptx/src/test/spirv_run/selp.spvtxt | 4
-rw-r--r--  ptx/src/test/spirv_run/shared_variable.ptx | 26
-rw-r--r--  ptx/src/test/spirv_run/shared_variable.spvtxt | 65
-rw-r--r--  ptx/src/translate.rs | 755
24 files changed, 1672 insertions(+), 88 deletions(-)
diff --git a/level_zero/Cargo.toml b/level_zero/Cargo.toml
index 97537b3..851159d 100644
--- a/level_zero/Cargo.toml
+++ b/level_zero/Cargo.toml
@@ -7,4 +7,8 @@ edition = "2018"
[lib]
[dependencies]
-level_zero-sys = { path = "../level_zero-sys" }
\ No newline at end of file
+level_zero-sys = { path = "../level_zero-sys" }
+
+[dependencies.ocl-core]
+version = "0.11"
+features = ["opencl_version_1_2", "opencl_version_2_0", "opencl_version_2_1"]
\ No newline at end of file
diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs
index 5ced5d0..f8a2c3b 100644
--- a/level_zero/src/ze.rs
+++ b/level_zero/src/ze.rs
@@ -238,7 +238,76 @@ impl Drop for CommandQueue {
pub struct Module(sys::ze_module_handle_t);
impl Module {
- pub fn new_spirv(
+ // HACK ALERT
+ // We use OpenCL for now to do SPIR-V linking, because Level Zero
+ // does not allow linking. Don't let the presence of zeModuleDynamicLink fool
+ // you: it's not currently possible to create non-compiled modules;
+ // zeModuleCreate always compiles (builds and links).
+ pub fn build_link_spirv<'a>(
+ ctx: &mut Context,
+ d: &Device,
+ binaries: &[&'a [u8]],
+ ) -> (Result<Self>, Option<BuildLog>) {
+ let ocl_program = match Self::build_link_spirv_impl(binaries) {
+ Err(_) => return (Err(sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN), None),
+ Ok(prog) => prog,
+ };
+ match ocl_core::get_program_info(&ocl_program, ocl_core::ProgramInfo::Binaries) {
+ Ok(ocl_core::ProgramInfoResult::Binaries(binaries)) => {
+ let (module, build_log) = Self::build_native(ctx, d, &binaries[0]);
+ (module, Some(build_log))
+ }
+ _ => return (Err(sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN), None),
+ }
+ }
+
+ fn build_link_spirv_impl<'a>(binaries: &[&'a [u8]]) -> ocl_core::Result<ocl_core::Program> {
+ let platforms = ocl_core::get_platform_ids()?;
+ let (platform, device) = platforms
+ .iter()
+ .find_map(|plat| {
+ let devices =
+ ocl_core::get_device_ids(plat, Some(ocl_core::DeviceType::GPU), None).ok()?;
+ for dev in devices {
+ let vendor =
+ ocl_core::get_device_info(dev, ocl_core::DeviceInfo::VendorId).ok()?;
+ if let ocl_core::DeviceInfoResult::VendorId(0x8086) = vendor {
+ let dev_type =
+ ocl_core::get_device_info(dev, ocl_core::DeviceInfo::Type).ok()?;
+ if let ocl_core::DeviceInfoResult::Type(ocl_core::DeviceType::GPU) =
+ dev_type
+ {
+ return Some((plat.clone(), dev));
+ }
+ }
+ }
+ None
+ })
+ .ok_or("")?;
+ let ctx_props = ocl_core::ContextProperties::new().platform(platform);
+ let ocl_ctx = ocl_core::create_context_from_type::<ocl_core::DeviceId>(
+ Some(&ctx_props),
+ ocl_core::DeviceType::GPU,
+ None,
+ None,
+ )?;
+ let mut programs = Vec::with_capacity(binaries.len());
+ for binary in binaries {
+ programs.push(ocl_core::create_program_with_il(&ocl_ctx, binary, None)?);
+ }
+ let options = CString::default();
+ ocl_core::link_program::<ocl_core::DeviceId, _>(
+ &ocl_ctx,
+ Some(&[device]),
+ &options,
+ &programs.iter().collect::<Vec<_>>(),
+ None,
+ None,
+ None,
+ )
+ }
+
+ pub fn build_spirv(
ctx: &mut Context,
d: &Device,
bin: &[u8],
@@ -247,7 +316,7 @@ impl Module {
Module::new(ctx, true, d, bin, opts)
}
- pub fn new_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> (Result<Self>, BuildLog) {
+ pub fn build_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> (Result<Self>, BuildLog) {
Module::new(ctx, false, d, bin, None)
}
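
A minimal sketch (not part of this commit) of how a caller chooses between the two build paths above, mirroring the logic added to ptx/src/test/spirv_run/mod.rs later in this diff. The wrapper name build_ptx_module is hypothetical, and the ze::Result / ze::BuildLog paths are assumed to be public, as the test code suggests:

// Hypothetical wrapper around the two entry points added above; in this commit the
// actual caller of build_link_spirv is the test harness in ptx/src/test/spirv_run/mod.rs.
fn build_ptx_module(
    ctx: &mut ze::Context,
    dev: &ze::Device,
    byte_il: &[u8],
    should_link_ptx_impl: Option<&'static [u8]>,
) -> (ze::Result<ze::Module>, Option<ze::BuildLog>) {
    match should_link_ptx_impl {
        // Kernels that use atom.inc/atom.dec need the OpenCL-built helper library,
        // so both SPIR-V binaries must go through the OpenCL linking path.
        Some(ptx_impl) => ze::Module::build_link_spirv(ctx, dev, &[ptx_impl, byte_il]),
        // Everything else compiles directly through Level Zero.
        None => {
            let (module, log) = ze::Module::build_spirv(ctx, dev, byte_il, None);
            (module, Some(log))
        }
    }
}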
diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs
index eea862b..35436c3 100644
--- a/notcuda/src/impl/module.rs
+++ b/notcuda/src/impl/module.rs
@@ -53,7 +53,7 @@ impl ModuleData {
Ok(_) if errors.len() > 0 => return Err(ModuleCompileError::Parse(errors, None)),
Ok(ast) => ast,
};
- let (spirv, all_arg_lens) = ptx::to_spirv(ast)?;
+ let (_, spirv, all_arg_lens) = ptx::to_spirv(ast)?;
let byte_il = unsafe {
slice::from_raw_parts::<u8>(
spirv.as_ptr() as *const _,
@@ -61,7 +61,7 @@ impl ModuleData {
)
};
let module = super::device::with_current_exclusive(|dev| {
- l0::Module::new_spirv(&mut dev.l0_context, &dev.base, byte_il, None)
+ l0::Module::build_spirv(&mut dev.l0_context, &dev.base, byte_il, None)
});
match module {
Ok((Ok(module), _)) => Ok(Mutex::new(Self {
diff --git a/ptx/lib/notcuda_ptx_impl.cl b/ptx/lib/notcuda_ptx_impl.cl
new file mode 100644
index 0000000..a0d487b
--- /dev/null
+++ b/ptx/lib/notcuda_ptx_impl.cl
@@ -0,0 +1,121 @@
+// Every time this file changes it must be rebuilt:
+// ocloc -file notcuda_ptx_impl.cl -64 -options "-cl-std=CL2.0" -out_dir . -device kbl -output_no_suffix -spv_only
+// Additionally you should strip names:
+// spirv-opt --strip-debug notcuda_ptx_impl.spv -o notcuda_ptx_impl.spv
+
+#define FUNC(NAME) __notcuda_ptx_impl__ ## NAME
+
+#define atomic_inc(NAME, SUCCESS, FAILURE, SCOPE, SPACE) \
+ uint FUNC(NAME)(SPACE uint* ptr, uint threshold) { \
+ uint expected = *ptr; \
+ uint desired; \
+ do { \
+ desired = (expected >= threshold) ? 0 : expected + 1; \
+ } while (!atomic_compare_exchange_strong_explicit((volatile SPACE atomic_uint*)ptr, &expected, desired, SUCCESS, FAILURE, SCOPE)); \
+ return expected; \
+ }
+
+#define atomic_dec(NAME, SUCCESS, FAILURE, SCOPE, SPACE) \
+ uint FUNC(NAME)(SPACE uint* ptr, uint threshold) { \
+ uint expected = *ptr; \
+ uint desired; \
+ do { \
+ desired = (expected == 0 || expected > threshold) ? threshold : expected - 1; \
+ } while (!atomic_compare_exchange_strong_explicit((volatile SPACE atomic_uint*)ptr, &expected, desired, SUCCESS, FAILURE, SCOPE)); \
+ return expected; \
+ }
+
+// We generate this whole matrix of functions instead of accepting memory_order and memory_scope parameters,
+// because ocloc emits broken SPIR-V (failing spirv-dis) when memory_order or memory_scope is a parameter
+
+// atom.inc
+atomic_inc(atom_relaxed_cta_generic_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, );
+atomic_inc(atom_acquire_cta_generic_inc, memory_order_acquire, memory_order_acquire, memory_scope_work_group, );
+atomic_inc(atom_release_cta_generic_inc, memory_order_release, memory_order_acquire, memory_scope_work_group, );
+atomic_inc(atom_acq_rel_cta_generic_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, );
+
+atomic_inc(atom_relaxed_gpu_generic_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, );
+atomic_inc(atom_acquire_gpu_generic_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, );
+atomic_inc(atom_release_gpu_generic_inc, memory_order_release, memory_order_acquire, memory_scope_device, );
+atomic_inc(atom_acq_rel_gpu_generic_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, );
+
+atomic_inc(atom_relaxed_sys_generic_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, );
+atomic_inc(atom_acquire_sys_generic_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, );
+atomic_inc(atom_release_sys_generic_inc, memory_order_release, memory_order_acquire, memory_scope_device, );
+atomic_inc(atom_acq_rel_sys_generic_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, );
+
+atomic_inc(atom_relaxed_cta_global_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __global);
+atomic_inc(atom_acquire_cta_global_inc, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __global);
+atomic_inc(atom_release_cta_global_inc, memory_order_release, memory_order_acquire, memory_scope_work_group, __global);
+atomic_inc(atom_acq_rel_cta_global_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __global);
+
+atomic_inc(atom_relaxed_gpu_global_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global);
+atomic_inc(atom_acquire_gpu_global_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __global);
+atomic_inc(atom_release_gpu_global_inc, memory_order_release, memory_order_acquire, memory_scope_device, __global);
+atomic_inc(atom_acq_rel_gpu_global_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global);
+
+atomic_inc(atom_relaxed_sys_global_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global);
+atomic_inc(atom_acquire_sys_global_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __global);
+atomic_inc(atom_release_sys_global_inc, memory_order_release, memory_order_acquire, memory_scope_device, __global);
+atomic_inc(atom_acq_rel_sys_global_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global);
+
+atomic_inc(atom_relaxed_cta_shared_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __local);
+atomic_inc(atom_acquire_cta_shared_inc, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __local);
+atomic_inc(atom_release_cta_shared_inc, memory_order_release, memory_order_acquire, memory_scope_work_group, __local);
+atomic_inc(atom_acq_rel_cta_shared_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __local);
+
+atomic_inc(atom_relaxed_gpu_shared_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local);
+atomic_inc(atom_acquire_gpu_shared_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __local);
+atomic_inc(atom_release_gpu_shared_inc, memory_order_release, memory_order_acquire, memory_scope_device, __local);
+atomic_inc(atom_acq_rel_gpu_shared_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local);
+
+atomic_inc(atom_relaxed_sys_shared_inc, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local);
+atomic_inc(atom_acquire_sys_shared_inc, memory_order_acquire, memory_order_acquire, memory_scope_device, __local);
+atomic_inc(atom_release_sys_shared_inc, memory_order_release, memory_order_acquire, memory_scope_device, __local);
+atomic_inc(atom_acq_rel_sys_shared_inc, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local);
+
+// atom.dec
+atomic_dec(atom_relaxed_cta_generic_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, );
+atomic_dec(atom_acquire_cta_generic_dec, memory_order_acquire, memory_order_acquire, memory_scope_work_group, );
+atomic_dec(atom_release_cta_generic_dec, memory_order_release, memory_order_acquire, memory_scope_work_group, );
+atomic_dec(atom_acq_rel_cta_generic_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, );
+
+atomic_dec(atom_relaxed_gpu_generic_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, );
+atomic_dec(atom_acquire_gpu_generic_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, );
+atomic_dec(atom_release_gpu_generic_dec, memory_order_release, memory_order_acquire, memory_scope_device, );
+atomic_dec(atom_acq_rel_gpu_generic_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, );
+
+atomic_dec(atom_relaxed_sys_generic_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, );
+atomic_dec(atom_acquire_sys_generic_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, );
+atomic_dec(atom_release_sys_generic_dec, memory_order_release, memory_order_acquire, memory_scope_device, );
+atomic_dec(atom_acq_rel_sys_generic_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, );
+
+atomic_dec(atom_relaxed_cta_global_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __global);
+atomic_dec(atom_acquire_cta_global_dec, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __global);
+atomic_dec(atom_release_cta_global_dec, memory_order_release, memory_order_acquire, memory_scope_work_group, __global);
+atomic_dec(atom_acq_rel_cta_global_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __global);
+
+atomic_dec(atom_relaxed_gpu_global_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global);
+atomic_dec(atom_acquire_gpu_global_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __global);
+atomic_dec(atom_release_gpu_global_dec, memory_order_release, memory_order_acquire, memory_scope_device, __global);
+atomic_dec(atom_acq_rel_gpu_global_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global);
+
+atomic_dec(atom_relaxed_sys_global_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __global);
+atomic_dec(atom_acquire_sys_global_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __global);
+atomic_dec(atom_release_sys_global_dec, memory_order_release, memory_order_acquire, memory_scope_device, __global);
+atomic_dec(atom_acq_rel_sys_global_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __global);
+
+atomic_dec(atom_relaxed_cta_shared_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_work_group, __local);
+atomic_dec(atom_acquire_cta_shared_dec, memory_order_acquire, memory_order_acquire, memory_scope_work_group, __local);
+atomic_dec(atom_release_cta_shared_dec, memory_order_release, memory_order_acquire, memory_scope_work_group, __local);
+atomic_dec(atom_acq_rel_cta_shared_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_work_group, __local);
+
+atomic_dec(atom_relaxed_gpu_shared_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local);
+atomic_dec(atom_acquire_gpu_shared_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __local);
+atomic_dec(atom_release_gpu_shared_dec, memory_order_release, memory_order_acquire, memory_scope_device, __local);
+atomic_dec(atom_acq_rel_gpu_shared_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local);
+
+atomic_dec(atom_relaxed_sys_shared_dec, memory_order_relaxed, memory_order_relaxed, memory_scope_device, __local);
+atomic_dec(atom_acquire_sys_shared_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __local);
+atomic_dec(atom_release_sys_shared_dec, memory_order_release, memory_order_acquire, memory_scope_device, __local);
+atomic_dec(atom_acq_rel_sys_shared_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local);
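
The helper names above encode ordering, scope, and address space as atom_<sema>_<scope>_<space>_<inc|dec>, prefixed by FUNC's __notcuda_ptx_impl__; these are the names the translator imports (see the LinkageAttributes Import decorations in atom_inc.spvtxt further down). A minimal sketch of that naming scheme, with local copies of the ptx/src/ast.rs enums and a hypothetical atom_inc_import_name helper that is not part of this commit:

// Local copies of the ast enums keep this sketch self-contained; atom_inc_import_name
// only illustrates the naming convention of the functions generated above.
#[derive(Copy, Clone)]
enum AtomSemantics { Relaxed, Acquire, Release, AcquireRelease }
#[derive(Copy, Clone)]
enum MemScope { Cta, Gpu, Sys }
#[derive(Copy, Clone)]
enum AtomSpace { Generic, Global, Shared }

fn atom_inc_import_name(sema: AtomSemantics, scope: MemScope, space: AtomSpace) -> String {
    let sema = match sema {
        AtomSemantics::Relaxed => "relaxed",
        AtomSemantics::Acquire => "acquire",
        AtomSemantics::Release => "release",
        AtomSemantics::AcquireRelease => "acq_rel",
    };
    let scope = match scope {
        MemScope::Cta => "cta",
        MemScope::Gpu => "gpu",
        MemScope::Sys => "sys",
    };
    let space = match space {
        AtomSpace::Generic => "generic",
        AtomSpace::Global => "global",
        AtomSpace::Shared => "shared",
    };
    format!("__notcuda_ptx_impl__atom_{}_{}_{}_inc", sema, scope, space)
}

fn main() {
    // The same name that atom_inc.spvtxt imports further down in this diff.
    assert_eq!(
        atom_inc_import_name(AtomSemantics::Relaxed, MemScope::Gpu, AtomSpace::Generic),
        "__notcuda_ptx_impl__atom_relaxed_gpu_generic_inc"
    );
}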
diff --git a/ptx/lib/notcuda_ptx_impl.spv b/ptx/lib/notcuda_ptx_impl.spv
new file mode 100644
index 0000000..36f37bb
--- /dev/null
+++ b/ptx/lib/notcuda_ptx_impl.spv
Binary files differ
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 1266ea4..ad8e87d 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -109,11 +109,12 @@ macro_rules! sub_type {
};
}
-// Pointer is used when doing SLM converison to SPIRV
sub_type! {
VariableRegType {
Scalar(ScalarType),
Vector(SizedScalarType, u8),
+ // The Pointer variant is used when passing an SLM pointer between
+ // function calls for dynamic SLM
Pointer(SizedScalarType, PointerStateSpace)
}
}
@@ -215,6 +216,11 @@ sub_enum!(SelpType {
F64,
});
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum BarDetails {
+ SyncAligned,
+}
+
pub trait UnwrapWithVec<E, To> {
fn unwrap_with(self, errs: &mut Vec<E>) -> To;
}
@@ -301,6 +307,7 @@ impl From<FnArgumentType> for Type {
sub_enum!(
PointerStateSpace : LdStateSpace {
+ Generic,
Global,
Const,
Shared,
@@ -372,6 +379,8 @@ sub_enum!(IntType {
S64
});
+sub_enum!(BitType { B8, B16, B32, B64 });
+
sub_enum!(UIntType { U8, U16, U32, U64 });
sub_enum!(SIntType { S8, S16, S32, S64 });
@@ -527,6 +536,9 @@ pub enum Instruction<P: ArgParams> {
Rcp(RcpDetails, Arg2<P>),
And(OrAndType, Arg3<P>),
Selp(SelpType, Arg4<P>),
+ Bar(BarDetails, Arg1Bar<P>),
+ Atom(AtomDetails, Arg3<P>),
+ AtomCas(AtomCasDetails, Arg4<P>),
}
#[derive(Copy, Clone)]
@@ -577,6 +589,10 @@ pub struct Arg1<P: ArgParams> {
pub src: P::Id, // it is a jump destination, but in terms of operands it is a source operand
}
+pub struct Arg1Bar<P: ArgParams> {
+ pub src: P::Operand,
+}
+
pub struct Arg2<P: ArgParams> {
pub dst: P::Id,
pub src: P::Operand,
@@ -712,12 +728,12 @@ impl From<LdStType> for PointerType {
pub enum LdStQualifier {
Weak,
Volatile,
- Relaxed(LdScope),
- Acquire(LdScope),
+ Relaxed(MemScope),
+ Acquire(MemScope),
}
#[derive(Copy, Clone, PartialEq, Eq)]
-pub enum LdScope {
+pub enum MemScope {
Cta,
Gpu,
Sys,
@@ -1051,6 +1067,74 @@ pub struct MinMaxFloat {
pub typ: FloatType,
}
+#[derive(Copy, Clone)]
+pub struct AtomDetails {
+ pub semantics: AtomSemantics,
+ pub scope: MemScope,
+ pub space: AtomSpace,
+ pub inner: AtomInnerDetails,
+}
+
+#[derive(Copy, Clone)]
+pub enum AtomSemantics {
+ Relaxed,
+ Acquire,
+ Release,
+ AcquireRelease,
+}
+
+#[derive(Copy, Clone)]
+pub enum AtomSpace {
+ Generic,
+ Global,
+ Shared,
+}
+
+#[derive(Copy, Clone)]
+pub enum AtomInnerDetails {
+ Bit { op: AtomBitOp, typ: BitType },
+ Unsigned { op: AtomUIntOp, typ: UIntType },
+ Signed { op: AtomSIntOp, typ: SIntType },
+ Float { op: AtomFloatOp, typ: FloatType },
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum AtomBitOp {
+ And,
+ Or,
+ Xor,
+ Exchange,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum AtomUIntOp {
+ Add,
+ Inc,
+ Dec,
+ Min,
+ Max,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum AtomSIntOp {
+ Add,
+ Min,
+ Max,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum AtomFloatOp {
+ Add,
+}
+
+#[derive(Copy, Clone)]
+pub struct AtomCasDetails {
+ pub semantics: AtomSemantics,
+ pub scope: MemScope,
+ pub space: AtomSpace,
+ pub typ: BitType
+}
+
pub enum NumsOrArrays<'a> {
Nums(Vec<(&'a str, u32)>),
Arrays(Vec<NumsOrArrays<'a>>),
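
A minimal sketch of how these new pieces compose for a single instruction, assuming the ast types above are in scope; example_atom_details is a hypothetical helper showing the AtomDetails the grammar below builds for `atom.shared.add.u32 d, [a], b;` (omitted ordering and scope qualifiers default to .relaxed and .gpu in the parser):

// Hypothetical helper, not part of this commit: the AtomDetails produced for
// `atom.shared.add.u32 d, [a], b;` -- .shared is taken verbatim, the rest are defaults.
fn example_atom_details() -> ast::AtomDetails {
    ast::AtomDetails {
        semantics: ast::AtomSemantics::Relaxed,
        scope: ast::MemScope::Gpu,
        space: ast::AtomSpace::Shared,
        inner: ast::AtomInnerDetails::Unsigned {
            op: ast::AtomUIntOp::Add,
            typ: ast::UIntType::U32,
        },
    }
}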
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index dfe5a5f..806a3fc 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -35,9 +35,12 @@ match {
"<", ">",
"|",
"=",
+ ".acq_rel",
".acquire",
+ ".add",
".address_size",
".align",
+ ".aligned",
".and",
".approx",
".b16",
@@ -45,14 +48,17 @@ match {
".b64",
".b8",
".ca",
+ ".cas",
".cg",
".const",
".cs",
".cta",
".cv",
+ ".dec",
".entry",
".eq",
".equ",
+ ".exch",
".extern",
".f16",
".f16x2",
@@ -69,6 +75,7 @@ match {
".gtu",
".hi",
".hs",
+ ".inc",
".le",
".leu",
".lo",
@@ -78,6 +85,8 @@ match {
".lt",
".ltu",
".lu",
+ ".max",
+ ".min",
".nan",
".NaN",
".ne",
@@ -88,6 +97,7 @@ match {
".pred",
".reg",
".relaxed",
+ ".release",
".rm",
".rmi",
".rn",
@@ -103,6 +113,7 @@ match {
".sat",
".section",
".shared",
+ ".sync",
".sys",
".target",
".to",
@@ -126,6 +137,9 @@ match {
"abs",
"add",
"and",
+ "atom",
+ "bar",
+ "barrier",
"bra",
"call",
"cvt",
@@ -162,6 +176,9 @@ ExtendedID : &'input str = {
"abs",
"add",
"and",
+ "atom",
+ "bar",
+ "barrier",
"bra",
"call",
"cvt",
@@ -372,6 +389,7 @@ StateSpaceSpecifier: ast::StateSpace = {
".param" => ast::StateSpace::Param, // used to prepare function call
};
+#[inline]
ScalarType: ast::ScalarType = {
".f16" => ast::ScalarType::F16,
".f16x2" => ast::ScalarType::F16x2,
@@ -438,6 +456,7 @@ Variable: ast::Variable<ast::VariableType, &'input str> = {
let v_type = ast::VariableType::Param(v_type);
ast::Variable {align, v_type, name, array_init}
},
+ SharedVariable,
};
RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
@@ -478,6 +497,32 @@ LocalVariable: ast::Variable<ast::VariableType, &'input str> = {
}
}
+SharedVariable: ast::Variable<ast::VariableType, &'input str> = {
+ ".shared" <var:VariableScalar<SizedScalarType>> => {
+ let (align, t, name) = var;
+ let v_type = ast::VariableGlobalType::Scalar(t);
+ ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
+ },
+ ".shared" <var:VariableVector<SizedScalarType>> => {
+ let (align, v_len, t, name) = var;
+ let v_type = ast::VariableGlobalType::Vector(t, v_len);
+ ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
+ },
+ ".shared" <var:VariableArrayOrPointer<SizedScalarType>> =>? {
+ let (align, t, name, arr_or_ptr) = var;
+ let (v_type, array_init) = match arr_or_ptr {
+ ast::ArrayOrPointer::Array { dimensions, init } => {
+ (ast::VariableGlobalType::Array(t, dimensions), init)
+ }
+ ast::ArrayOrPointer::Pointer => {
+ return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray });
+ }
+ };
+ Ok(ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init })
+ }
+}
+
+
ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => {
let (align, v_type, name, array_init) = def;
@@ -619,7 +664,10 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstMin,
InstMax,
InstRcp,
- InstSelp
+ InstSelp,
+ InstBar,
+ InstAtom,
+ InstAtomCas
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@@ -655,14 +703,14 @@ LdStType: ast::LdStType = {
LdStQualifier: ast::LdStQualifier = {
".weak" => ast::LdStQualifier::Weak,
".volatile" => ast::LdStQualifier::Volatile,
- ".relaxed" <s:LdScope> => ast::LdStQualifier::Relaxed(s),
- ".acquire" <s:LdScope> => ast::LdStQualifier::Acquire(s),
+ ".relaxed" <s:MemScope> => ast::LdStQualifier::Relaxed(s),
+ ".acquire" <s:MemScope> => ast::LdStQualifier::Acquire(s),
};
-LdScope: ast::LdScope = {
- ".cta" => ast::LdScope::Cta,
- ".gpu" => ast::LdScope::Gpu,
- ".sys" => ast::LdScope::Sys
+MemScope: ast::MemScope = {
+ ".cta" => ast::MemScope::Cta,
+ ".gpu" => ast::MemScope::Gpu,
+ ".sys" => ast::MemScope::Sys
};
LdStateSpace: ast::LdStateSpace = {
@@ -798,6 +846,13 @@ SIntType: ast::SIntType = {
".s64" => ast::SIntType::S64,
};
+FloatType: ast::FloatType = {
+ ".f16" => ast::FloatType::F16,
+ ".f16x2" => ast::FloatType::F16x2,
+ ".f32" => ast::FloatType::F32,
+ ".f64" => ast::FloatType::F64,
+};
+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add
@@ -1296,6 +1351,140 @@ SelpType: ast::SelpType = {
".f64" => ast::SelpType::F64,
};
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar
+InstBar: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "barrier" ".sync" ".aligned" <a:Arg1Bar> => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a),
+ "bar" ".sync" <a:Arg1Bar> => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a)
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
+// The documentation does not mention all supported operations:
+// * Operation .add requires .u32 or .s32 or .u64 or .f64 or .f16 or .f16x2 or .f32
+// * Operation .inc requires the .u32 type for the instruction
+// * Operation .dec requires the .u32 type for the instruction
+// Otherwise as documented
+InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> <op:AtomBitOp> <typ:AtomBitType> <a:Arg3Atom> => {
+ let details = ast::AtomDetails {
+ semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
+ scope: scope.unwrap_or(ast::MemScope::Gpu),
+ space: space.unwrap_or(ast::AtomSpace::Generic),
+ inner: ast::AtomInnerDetails::Bit { op, typ }
+ };
+ ast::Instruction::Atom(details,a)
+ },
+ "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> ".inc" ".u32" <a:Arg3Atom> => {
+ let details = ast::AtomDetails {
+ semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
+ scope: scope.unwrap_or(ast::MemScope::Gpu),
+ space: space.unwrap_or(ast::AtomSpace::Generic),
+ inner: ast::AtomInnerDetails::Unsigned {
+ op: ast::AtomUIntOp::Inc,
+ typ: ast::UIntType::U32
+ }
+ };
+ ast::Instruction::Atom(details,a)
+ },
+ "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> ".dec" ".u32" <a:Arg3Atom> => {
+ let details = ast::AtomDetails {
+ semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
+ scope: scope.unwrap_or(ast::MemScope::Gpu),
+ space: space.unwrap_or(ast::AtomSpace::Generic),
+ inner: ast::AtomInnerDetails::Unsigned {
+ op: ast::AtomUIntOp::Dec,
+ typ: ast::UIntType::U32
+ }
+ };
+ ast::Instruction::Atom(details,a)
+ },
+ "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> ".add" <typ:FloatType> <a:Arg3Atom> => {
+ let op = ast::AtomFloatOp::Add;
+ let details = ast::AtomDetails {
+ semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
+ scope: scope.unwrap_or(ast::MemScope::Gpu),
+ space: space.unwrap_or(ast::AtomSpace::Generic),
+ inner: ast::AtomInnerDetails::Float { op, typ }
+ };
+ ast::Instruction::Atom(details,a)
+ },
+ "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> <op: AtomUIntOp> <typ:AtomUIntType> <a:Arg3Atom> => {
+ let details = ast::AtomDetails {
+ semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
+ scope: scope.unwrap_or(ast::MemScope::Gpu),
+ space: space.unwrap_or(ast::AtomSpace::Generic),
+ inner: ast::AtomInnerDetails::Unsigned { op, typ }
+ };
+ ast::Instruction::Atom(details,a)
+ },
+ "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> <op: AtomSIntOp> <typ:AtomSIntType> <a:Arg3Atom> => {
+ let details = ast::AtomDetails {
+ semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
+ scope: scope.unwrap_or(ast::MemScope::Gpu),
+ space: space.unwrap_or(ast::AtomSpace::Generic),
+ inner: ast::AtomInnerDetails::Signed { op, typ }
+ };
+ ast::Instruction::Atom(details,a)
+ }
+}
+
+InstAtomCas: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> ".cas" <typ:AtomBitType> <a:Arg4Atom> => {
+ let details = ast::AtomCasDetails {
+ semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
+ scope: scope.unwrap_or(ast::MemScope::Gpu),
+ space: space.unwrap_or(ast::AtomSpace::Generic),
+ typ,
+ };
+ ast::Instruction::AtomCas(details,a)
+ },
+}
+
+AtomSemantics: ast::AtomSemantics = {
+ ".relaxed" => ast::AtomSemantics::Relaxed,
+ ".acquire" => ast::AtomSemantics::Acquire,
+ ".release" => ast::AtomSemantics::Release,
+ ".acq_rel" => ast::AtomSemantics::AcquireRelease
+}
+
+AtomSpace: ast::AtomSpace = {
+ ".global" => ast::AtomSpace::Global,
+ ".shared" => ast::AtomSpace::Shared
+}
+
+AtomBitOp: ast::AtomBitOp = {
+ ".and" => ast::AtomBitOp::And,
+ ".or" => ast::AtomBitOp::Or,
+ ".xor" => ast::AtomBitOp::Xor,
+ ".exch" => ast::AtomBitOp::Exchange,
+}
+
+AtomUIntOp: ast::AtomUIntOp = {
+ ".add" => ast::AtomUIntOp::Add,
+ ".min" => ast::AtomUIntOp::Min,
+ ".max" => ast::AtomUIntOp::Max,
+}
+
+AtomSIntOp: ast::AtomSIntOp = {
+ ".add" => ast::AtomSIntOp::Add,
+ ".min" => ast::AtomSIntOp::Min,
+ ".max" => ast::AtomSIntOp::Max,
+}
+
+AtomBitType: ast::BitType = {
+ ".b32" => ast::BitType::B32,
+ ".b64" => ast::BitType::B64,
+}
+
+AtomUIntType: ast::UIntType = {
+ ".u32" => ast::UIntType::U32,
+ ".u64" => ast::UIntType::U64,
+}
+
+AtomSIntType: ast::SIntType = {
+ ".s32" => ast::SIntType::S32,
+ ".s64" => ast::SIntType::S64,
+}
+
ArithDetails: ast::ArithDetails = {
<t:UIntType> => ast::ArithDetails::Unsigned(t),
<t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt {
@@ -1414,6 +1603,10 @@ Arg1: ast::Arg1<ast::ParsedArgParams<'input>> = {
<src:ExtendedID> => ast::Arg1{<>}
};
+Arg1Bar: ast::Arg1Bar<ast::ParsedArgParams<'input>> = {
+ <src:Operand> => ast::Arg1Bar{<>}
+};
+
Arg2: ast::Arg2<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src:Operand> => ast::Arg2{<>}
};
@@ -1448,10 +1641,18 @@ Arg3: ast::Arg3<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> => ast::Arg3{<>}
};
+Arg3Atom: ast::Arg3<ast::ParsedArgParams<'input>> = {
+ <dst:ExtendedID> "," "[" <src1:Operand> "]" "," <src2:Operand> => ast::Arg3{<>}
+};
+
Arg4: ast::Arg4<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> "," <src3:Operand> => ast::Arg4{<>}
};
+Arg4Atom: ast::Arg4<ast::ParsedArgParams<'input>> = {
+ <dst:ExtendedID> "," "[" <src1:Operand> "]" "," <src2:Operand> "," <src3:Operand> => ast::Arg4{<>}
+};
+
Arg4Setp: ast::Arg4Setp<ast::ParsedArgParams<'input>> = {
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> => ast::Arg4Setp{<>}
};
diff --git a/ptx/src/test/spirv_build/bar_sync.ptx b/ptx/src/test/spirv_build/bar_sync.ptx
new file mode 100644
index 0000000..54c6663
--- /dev/null
+++ b/ptx/src/test/spirv_build/bar_sync.ptx
@@ -0,0 +1,10 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry bar_sync()
+{
+ .reg .u32 temp_32;
+ bar.sync temp_32;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/and.spvtxt b/ptx/src/test/spirv_run/and.spvtxt
index 9b72477..8358c28 100644
--- a/ptx/src/test/spirv_run/and.spvtxt
+++ b/ptx/src/test/spirv_run/and.spvtxt
@@ -11,8 +11,8 @@ OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
-OpCapability FunctionFloatControlINTEL
-OpExtension "SPV_INTEL_float_controls2"
+; OpCapability FunctionFloatControlINTEL
+; OpExtension "SPV_INTEL_float_controls2"
%33 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "and"
diff --git a/ptx/src/test/spirv_run/atom_add.ptx b/ptx/src/test/spirv_run/atom_add.ptx
new file mode 100644
index 0000000..5d1f667
--- /dev/null
+++ b/ptx/src/test/spirv_run/atom_add.ptx
@@ -0,0 +1,28 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry atom_add(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .shared .align 4 .b8 shared_mem[1024];
+
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u32 temp1;
+ .reg .u32 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u32 temp1, [in_addr];
+ ld.u32 temp2, [in_addr+4];
+ st.shared.u32 [shared_mem], temp1;
+ atom.shared.add.u32 temp1, [shared_mem], temp2;
+ ld.shared.u32 temp2, [shared_mem];
+ st.u32 [out_addr], temp1;
+ st.u32 [out_addr+4], temp2;
+ ret;
+}
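
A self-contained sketch of the data flow this kernel is expected to produce; the test added to mod.rs further down runs it with input [2, 4] and expects output [2, 6]:

// Reference semantics of the kernel above, written as plain Rust over the two input words.
fn atom_add_reference(input: [u32; 2]) -> [u32; 2] {
    let temp1 = input[0];          // ld.u32 temp1, [in_addr]
    let temp2 = input[1];          // ld.u32 temp2, [in_addr+4]
    let mut shared_mem = temp1;    // st.shared.u32 [shared_mem], temp1
    let old = shared_mem;          // atom.shared.add.u32 returns the old value...
    shared_mem += temp2;           // ...and adds temp2 into shared memory
    [old, shared_mem]              // st.u32 [out_addr], temp1 / [out_addr+4], temp2 (reloaded)
}

fn main() {
    assert_eq!(atom_add_reference([2, 4]), [2, 6]);
}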
diff --git a/ptx/src/test/spirv_run/atom_add.spvtxt b/ptx/src/test/spirv_run/atom_add.spvtxt
new file mode 100644
index 0000000..2c83fe9
--- /dev/null
+++ b/ptx/src/test/spirv_run/atom_add.spvtxt
@@ -0,0 +1,84 @@
+; SPIR-V
+; Version: 1.3
+; Generator: rspirv
+; Bound: 55
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int8
+OpCapability Int16
+OpCapability Int64
+OpCapability Float16
+OpCapability Float64
+; OpCapability FunctionFloatControlINTEL
+; OpExtension "SPV_INTEL_float_controls2"
+%40 = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical64 OpenCL
+OpEntryPoint Kernel %1 "atom_add" %4
+OpDecorate %4 Alignment 4
+%41 = OpTypeVoid
+%42 = OpTypeInt 32 0
+%43 = OpTypeInt 8 0
+%44 = OpConstant %42 1024
+%45 = OpTypeArray %43 %44
+%46 = OpTypePointer Workgroup %45
+%4 = OpVariable %46 Workgroup
+%47 = OpTypeInt 64 0
+%48 = OpTypeFunction %41 %47 %47
+%49 = OpTypePointer Function %47
+%50 = OpTypePointer Function %42
+%51 = OpTypePointer Generic %42
+%27 = OpConstant %47 4
+%52 = OpTypePointer Workgroup %42
+%53 = OpConstant %42 1
+%54 = OpConstant %42 0
+%29 = OpConstant %47 4
+%1 = OpFunction %41 None %48
+%9 = OpFunctionParameter %47
+%10 = OpFunctionParameter %47
+%38 = OpLabel
+%2 = OpVariable %49 Function
+%3 = OpVariable %49 Function
+%5 = OpVariable %49 Function
+%6 = OpVariable %49 Function
+%7 = OpVariable %50 Function
+%8 = OpVariable %50 Function
+OpStore %2 %9
+OpStore %3 %10
+%12 = OpLoad %47 %2
+%11 = OpCopyObject %47 %12
+OpStore %5 %11
+%14 = OpLoad %47 %3
+%13 = OpCopyObject %47 %14
+OpStore %6 %13
+%16 = OpLoad %47 %5
+%31 = OpConvertUToPtr %51 %16
+%15 = OpLoad %42 %31
+OpStore %7 %15
+%18 = OpLoad %47 %5
+%28 = OpIAdd %47 %18 %27
+%32 = OpConvertUToPtr %51 %28
+%17 = OpLoad %42 %32
+OpStore %8 %17
+%19 = OpLoad %42 %7
+%33 = OpBitcast %52 %4
+OpStore %33 %19
+%21 = OpLoad %42 %8
+%34 = OpBitcast %52 %4
+%20 = OpAtomicIAdd %42 %34 %53 %54 %21
+OpStore %7 %20
+%35 = OpBitcast %52 %4
+%22 = OpLoad %42 %35
+OpStore %8 %22
+%23 = OpLoad %47 %6
+%24 = OpLoad %42 %7
+%36 = OpConvertUToPtr %51 %23
+OpStore %36 %24
+%25 = OpLoad %47 %6
+%26 = OpLoad %42 %8
+%30 = OpIAdd %47 %25 %29
+%37 = OpConvertUToPtr %51 %30
+OpStore %37 %26
+OpReturn
+OpFunctionEnd
\ No newline at end of file
diff --git a/ptx/src/test/spirv_run/atom_cas.ptx b/ptx/src/test/spirv_run/atom_cas.ptx
new file mode 100644
index 0000000..440a1cb
--- /dev/null
+++ b/ptx/src/test/spirv_run/atom_cas.ptx
@@ -0,0 +1,24 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry atom_cas(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u32 temp1;
+ .reg .u32 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u32 temp1, [in_addr];
+ atom.cas.b32 temp1, [in_addr+4], temp1, 100;
+ ld.u32 temp2, [in_addr+4];
+ st.u32 [out_addr], temp1;
+ st.u32 [out_addr+4], temp2;
+ ret;
+}
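
A self-contained sketch of this kernel's expected behaviour; the mod.rs test further down runs it with input [91, 91] and expects [91, 100] (the compare-and-swap matches, stores 100, and returns the old value 91):

// Reference semantics of the kernel above: atom.cas.b32 stores 100 into [in_addr+4]
// only if the word there equals temp1, and always returns the old value.
fn atom_cas_reference(input: [u32; 2]) -> [u32; 2] {
    let temp1 = input[0];          // ld.u32 temp1, [in_addr]
    let mut slot = input[1];       // the word at [in_addr+4]
    let old = slot;
    if slot == temp1 {
        slot = 100;
    }
    [old, slot]                    // st.u32 [out_addr], temp1(=old) / [out_addr+4], temp2(=slot)
}

fn main() {
    assert_eq!(atom_cas_reference([91, 91]), [91, 100]);
}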
diff --git a/ptx/src/test/spirv_run/atom_cas.spvtxt b/ptx/src/test/spirv_run/atom_cas.spvtxt
new file mode 100644
index 0000000..c5fb922
--- /dev/null
+++ b/ptx/src/test/spirv_run/atom_cas.spvtxt
@@ -0,0 +1,77 @@
+; SPIR-V
+; Version: 1.3
+; Generator: rspirv
+; Bound: 51
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int8
+OpCapability Int16
+OpCapability Int64
+OpCapability Float16
+OpCapability Float64
+; OpCapability FunctionFloatControlINTEL
+; OpExtension "SPV_INTEL_float_controls2"
+%41 = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical64 OpenCL
+OpEntryPoint Kernel %1 "atom_cas"
+%42 = OpTypeVoid
+%43 = OpTypeInt 64 0
+%44 = OpTypeFunction %42 %43 %43
+%45 = OpTypePointer Function %43
+%46 = OpTypeInt 32 0
+%47 = OpTypePointer Function %46
+%48 = OpTypePointer Generic %46
+%25 = OpConstant %43 4
+%27 = OpConstant %46 100
+%49 = OpConstant %46 1
+%50 = OpConstant %46 0
+%28 = OpConstant %43 4
+%30 = OpConstant %43 4
+%1 = OpFunction %42 None %44
+%8 = OpFunctionParameter %43
+%9 = OpFunctionParameter %43
+%39 = OpLabel
+%2 = OpVariable %45 Function
+%3 = OpVariable %45 Function
+%4 = OpVariable %45 Function
+%5 = OpVariable %45 Function
+%6 = OpVariable %47 Function
+%7 = OpVariable %47 Function
+OpStore %2 %8
+OpStore %3 %9
+%11 = OpLoad %43 %2
+%10 = OpCopyObject %43 %11
+OpStore %4 %10
+%13 = OpLoad %43 %3
+%12 = OpCopyObject %43 %13
+OpStore %5 %12
+%15 = OpLoad %43 %4
+%32 = OpConvertUToPtr %48 %15
+%14 = OpLoad %46 %32
+OpStore %6 %14
+%17 = OpLoad %43 %4
+%18 = OpLoad %46 %6
+%26 = OpIAdd %43 %17 %25
+%34 = OpConvertUToPtr %48 %26
+%35 = OpCopyObject %46 %18
+%33 = OpAtomicCompareExchange %46 %34 %49 %50 %50 %27 %35
+%16 = OpCopyObject %46 %33
+OpStore %6 %16
+%20 = OpLoad %43 %4
+%29 = OpIAdd %43 %20 %28
+%36 = OpConvertUToPtr %48 %29
+%19 = OpLoad %46 %36
+OpStore %7 %19
+%21 = OpLoad %43 %5
+%22 = OpLoad %46 %6
+%37 = OpConvertUToPtr %48 %21
+OpStore %37 %22
+%23 = OpLoad %43 %5
+%24 = OpLoad %46 %7
+%31 = OpIAdd %43 %23 %30
+%38 = OpConvertUToPtr %48 %31
+OpStore %38 %24
+OpReturn
+OpFunctionEnd
\ No newline at end of file
diff --git a/ptx/src/test/spirv_run/atom_inc.ptx b/ptx/src/test/spirv_run/atom_inc.ptx
new file mode 100644
index 0000000..ed3df08
--- /dev/null
+++ b/ptx/src/test/spirv_run/atom_inc.ptx
@@ -0,0 +1,26 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry atom_inc(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u32 temp1;
+ .reg .u32 temp2;
+ .reg .u32 temp3;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ atom.inc.u32 temp1, [in_addr], 101;
+ atom.global.inc.u32 temp2, [in_addr], 101;
+ ld.u32 temp3, [in_addr];
+ st.u32 [out_addr], temp1;
+ st.u32 [out_addr+4], temp2;
+ st.u32 [out_addr+8], temp3;
+ ret;
+}
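
A self-contained sketch of this kernel's expected behaviour, using the atom.inc rule `old >= threshold ? 0 : old + 1` implemented in notcuda_ptx_impl.cl; the mod.rs test further down runs it with input [100] and expects [100, 101, 0]:

// Reference semantics of atom.inc: the stored value wraps to 0 once it reaches the
// threshold, and the old value is returned.
fn atom_inc(word: &mut u32, threshold: u32) -> u32 {
    let old = *word;
    *word = if old >= threshold { 0 } else { old + 1 };
    old
}

fn main() {
    let mut word = 100u32;
    let temp1 = atom_inc(&mut word, 101); // atom.inc.u32        -> returns 100, stores 101
    let temp2 = atom_inc(&mut word, 101); // atom.global.inc.u32 -> returns 101, wraps to 0
    let temp3 = word;                     // ld.u32 temp3, [in_addr]
    assert_eq!([temp1, temp2, temp3], [100, 101, 0]);
}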
diff --git a/ptx/src/test/spirv_run/atom_inc.spvtxt b/ptx/src/test/spirv_run/atom_inc.spvtxt
new file mode 100644
index 0000000..6948cd9
--- /dev/null
+++ b/ptx/src/test/spirv_run/atom_inc.spvtxt
@@ -0,0 +1,89 @@
+; SPIR-V
+; Version: 1.3
+; Generator: rspirv
+; Bound: 60
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int8
+OpCapability Int16
+OpCapability Int64
+OpCapability Float16
+OpCapability Float64
+; OpCapability FunctionFloatControlINTEL
+; OpExtension "SPV_INTEL_float_controls2"
+%49 = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical64 OpenCL
+OpEntryPoint Kernel %1 "atom_inc"
+OpDecorate %40 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_generic_inc" Import
+OpDecorate %44 LinkageAttributes "__notcuda_ptx_impl__atom_relaxed_gpu_global_inc" Import
+%50 = OpTypeVoid
+%51 = OpTypeInt 32 0
+%52 = OpTypePointer Generic %51
+%53 = OpTypeFunction %51 %52 %51
+%54 = OpTypePointer CrossWorkgroup %51
+%55 = OpTypeFunction %51 %54 %51
+%56 = OpTypeInt 64 0
+%57 = OpTypeFunction %50 %56 %56
+%58 = OpTypePointer Function %56
+%59 = OpTypePointer Function %51
+%27 = OpConstant %51 101
+%28 = OpConstant %51 101
+%29 = OpConstant %56 4
+%31 = OpConstant %56 8
+%40 = OpFunction %51 None %53
+%42 = OpFunctionParameter %52
+%43 = OpFunctionParameter %51
+OpFunctionEnd
+%44 = OpFunction %51 None %55
+%46 = OpFunctionParameter %54
+%47 = OpFunctionParameter %51
+OpFunctionEnd
+%1 = OpFunction %50 None %57
+%9 = OpFunctionParameter %56
+%10 = OpFunctionParameter %56
+%39 = OpLabel
+%2 = OpVariable %58 Function
+%3 = OpVariable %58 Function
+%4 = OpVariable %58 Function
+%5 = OpVariable %58 Function
+%6 = OpVariable %59 Function
+%7 = OpVariable %59 Function
+%8 = OpVariable %59 Function
+OpStore %2 %9
+OpStore %3 %10
+%12 = OpLoad %56 %2
+%11 = OpCopyObject %56 %12
+OpStore %4 %11
+%14 = OpLoad %56 %3
+%13 = OpCopyObject %56 %14
+OpStore %5 %13
+%16 = OpLoad %56 %4
+%33 = OpConvertUToPtr %52 %16
+%15 = OpFunctionCall %51 %40 %33 %27
+OpStore %6 %15
+%18 = OpLoad %56 %4
+%34 = OpConvertUToPtr %54 %18
+%17 = OpFunctionCall %51 %44 %34 %28
+OpStore %7 %17
+%20 = OpLoad %56 %4
+%35 = OpConvertUToPtr %52 %20
+%19 = OpLoad %51 %35
+OpStore %8 %19
+%21 = OpLoad %56 %5
+%22 = OpLoad %51 %6
+%36 = OpConvertUToPtr %52 %21
+OpStore %36 %22
+%23 = OpLoad %56 %5
+%24 = OpLoad %51 %7
+%30 = OpIAdd %56 %23 %29
+%37 = OpConvertUToPtr %52 %30
+OpStore %37 %24
+%25 = OpLoad %56 %5
+%26 = OpLoad %51 %8
+%32 = OpIAdd %56 %25 %31
+%38 = OpConvertUToPtr %52 %32
+OpStore %38 %26
+OpReturn
+OpFunctionEnd
\ No newline at end of file
diff --git a/ptx/src/test/spirv_run/constant_f32.spvtxt b/ptx/src/test/spirv_run/constant_f32.spvtxt
index 905bec4..27c5f4e 100644
--- a/ptx/src/test/spirv_run/constant_f32.spvtxt
+++ b/ptx/src/test/spirv_run/constant_f32.spvtxt
@@ -11,12 +11,12 @@ OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
-OpCapability FunctionFloatControlINTEL
-OpExtension "SPV_INTEL_float_controls2"
+; OpCapability FunctionFloatControlINTEL
+; OpExtension "SPV_INTEL_float_controls2"
%24 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "constant_f32"
-OpDecorate %1 FunctionDenormModeINTEL 32 Preserve
+; OpDecorate %1 FunctionDenormModeINTEL 32 Preserve
%25 = OpTypeVoid
%26 = OpTypeInt 64 0
%27 = OpTypeFunction %25 %26 %26
diff --git a/ptx/src/test/spirv_run/constant_negative.spvtxt b/ptx/src/test/spirv_run/constant_negative.spvtxt
index 39e5d19..ec2ff72 100644
--- a/ptx/src/test/spirv_run/constant_negative.spvtxt
+++ b/ptx/src/test/spirv_run/constant_negative.spvtxt
@@ -11,8 +11,8 @@ OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
-OpCapability FunctionFloatControlINTEL
-OpExtension "SPV_INTEL_float_controls2"
+; OpCapability FunctionFloatControlINTEL
+; OpExtension "SPV_INTEL_float_controls2"
%24 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "constant_negative"
diff --git a/ptx/src/test/spirv_run/fma.spvtxt b/ptx/src/test/spirv_run/fma.spvtxt
index 734bf0f..4a90d09 100644
--- a/ptx/src/test/spirv_run/fma.spvtxt
+++ b/ptx/src/test/spirv_run/fma.spvtxt
@@ -11,12 +11,12 @@ OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
-OpCapability FunctionFloatControlINTEL
-OpExtension "SPV_INTEL_float_controls2"
+; OpCapability FunctionFloatControlINTEL
+; OpExtension "SPV_INTEL_float_controls2"
%37 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "fma"
-OpDecorate %1 FunctionDenormModeINTEL 32 Preserve
+; OpDecorate %1 FunctionDenormModeINTEL 32 Preserve
%38 = OpTypeVoid
%39 = OpTypeInt 64 0
%40 = OpTypeFunction %38 %39 %39
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 98b9630..40a9d64 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -86,12 +86,20 @@ test_ptx!(rcp, [2f32], [0.5f32]);
// 0x3f000000 is 0.5
// TODO: mul_ftz fails because IGC does not yet handle SPV_INTEL_float_controls2
// test_ptx!(mul_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0u32]);
-test_ptx!(mul_non_ftz, [0b1_00000000_10000000000000000000000u32, 0x3f000000u32], [0b1_00000000_01000000000000000000000u32]);
+test_ptx!(
+ mul_non_ftz,
+ [0b1_00000000_10000000000000000000000u32, 0x3f000000u32],
+ [0b1_00000000_01000000000000000000000u32]
+);
test_ptx!(constant_f32, [10f32], [5f32]);
test_ptx!(constant_negative, [-101i32], [101i32]);
test_ptx!(and, [6u32, 3u32], [2u32]);
test_ptx!(selp, [100u16, 200u16], [200u16]);
-test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]);
+test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]);
+test_ptx!(shared_variable, [513u64], [513u64]);
+test_ptx!(atom_cas, [91u32, 91u32], [91u32, 100u32]);
+test_ptx!(atom_inc, [100u32], [100u32, 101u32, 0u32]);
+test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]);
struct DisplayError<T: Debug> {
err: T,
@@ -124,7 +132,7 @@ fn test_ptx_assert<'a, T: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq>(
let name = CString::new(name)?;
let result = run_spirv(name.as_c_str(), notcuda_module, input, output)
.map_err(|err| DisplayError { err })?;
- assert_eq!(output, result.as_slice());
+ assert_eq!(result.as_slice(), output);
Ok(())
}
@@ -145,8 +153,8 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
let use_shared_mem = module
.kernel_info
.get(name.to_str().unwrap())
- .unwrap()
- .uses_shared_mem;
+ .map(|info| info.uses_shared_mem)
+ .unwrap_or(false);
let mut result = vec![0u8.into(); output.len()];
{
let mut drivers = ze::Driver::get()?;
@@ -155,11 +163,20 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
let mut devices = drv.devices()?;
let dev = devices.drain(0..1).next().unwrap();
let queue = ze::CommandQueue::new(&mut ctx, &dev)?;
- let (module, log) = ze::Module::new_spirv(&mut ctx, &dev, byte_il, None);
+ let (module, maybe_log) = match module.should_link_ptx_impl {
+ Some(ptx_impl) => ze::Module::build_link_spirv(&mut ctx, &dev, &[ptx_impl, byte_il]),
+ None => {
+ let (module, log) = ze::Module::build_spirv(&mut ctx, &dev, byte_il, None);
+ (module, Some(log))
+ }
+ };
let module = match module {
Ok(m) => m,
Err(err) => {
- let raw_err_string = log.get_cstring()?;
+ let raw_err_string = maybe_log
+ .map(|log| log.get_cstring())
+ .transpose()?
+ .unwrap_or(CString::default());
let err_string = raw_err_string.to_string_lossy();
panic!("{:?}\n{}", err, err_string);
}
@@ -215,7 +232,11 @@ fn test_spvtxt_assert<'a>(
ptr::null_mut(),
)
};
- assert!(result == spv_result_t::SPV_SUCCESS);
+ if result != spv_result_t::SPV_SUCCESS {
+ panic!("{:?}\n{}", result, unsafe {
+ str::from_utf8_unchecked(spirv_txt)
+ });
+ }
let mut parsed_spirv = Vec::<u32>::new();
let result = unsafe {
spirv_tools::spvBinaryParse(
diff --git a/ptx/src/test/spirv_run/mul_ftz.spvtxt b/ptx/src/test/spirv_run/mul_ftz.spvtxt
index da6a12a..56cec5a 100644
--- a/ptx/src/test/spirv_run/mul_ftz.spvtxt
+++ b/ptx/src/test/spirv_run/mul_ftz.spvtxt
@@ -11,8 +11,8 @@ OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
-OpCapability FunctionFloatControlINTEL
-OpExtension "SPV_INTEL_float_controls2"
+; OpCapability FunctionFloatControlINTEL
+; OpExtension "SPV_INTEL_float_controls2"
%30 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "mul_ftz"
diff --git a/ptx/src/test/spirv_run/selp.spvtxt b/ptx/src/test/spirv_run/selp.spvtxt
index dffd9af..6f73bc2 100644
--- a/ptx/src/test/spirv_run/selp.spvtxt
+++ b/ptx/src/test/spirv_run/selp.spvtxt
@@ -11,8 +11,8 @@ OpCapability Int16
OpCapability Int64
OpCapability Float16
OpCapability Float64
-OpCapability FunctionFloatControlINTEL
-OpExtension "SPV_INTEL_float_controls2"
+; OpCapability FunctionFloatControlINTEL
+; OpExtension "SPV_INTEL_float_controls2"
%31 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "selp"
diff --git a/ptx/src/test/spirv_run/shared_variable.ptx b/ptx/src/test/spirv_run/shared_variable.ptx
new file mode 100644
index 0000000..4f7eff3
--- /dev/null
+++ b/ptx/src/test/spirv_run/shared_variable.ptx
@@ -0,0 +1,26 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+
+.visible .entry shared_variable(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .shared .align 4 .b8 shared_mem1[128];
+
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp1;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.global.u64 temp1, [in_addr];
+ st.shared.u64 [shared_mem1], temp1;
+ ld.shared.u64 temp2, [shared_mem1];
+ st.global.u64 [out_addr], temp2;
+ ret;
+}
\ No newline at end of file
diff --git a/ptx/src/test/spirv_run/shared_variable.spvtxt b/ptx/src/test/spirv_run/shared_variable.spvtxt
new file mode 100644
index 0000000..1af2bd1
--- /dev/null
+++ b/ptx/src/test/spirv_run/shared_variable.spvtxt
@@ -0,0 +1,65 @@
+; SPIR-V
+; Version: 1.3
+; Generator: rspirv
+; Bound: 39
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int8
+OpCapability Int16
+OpCapability Int64
+OpCapability Float16
+OpCapability Float64
+; OpCapability FunctionFloatControlINTEL
+; OpExtension "SPV_INTEL_float_controls2"
+%27 = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical64 OpenCL
+OpEntryPoint Kernel %1 "shared_variable" %4
+OpDecorate %4 Alignment 4
+%28 = OpTypeVoid
+%29 = OpTypeInt 32 0
+%30 = OpTypeInt 8 0
+%31 = OpConstant %29 128
+%32 = OpTypeArray %30 %31
+%33 = OpTypePointer Workgroup %32
+%4 = OpVariable %33 Workgroup
+%34 = OpTypeInt 64 0
+%35 = OpTypeFunction %28 %34 %34
+%36 = OpTypePointer Function %34
+%37 = OpTypePointer CrossWorkgroup %34
+%38 = OpTypePointer Workgroup %34
+%1 = OpFunction %28 None %35
+%9 = OpFunctionParameter %34
+%10 = OpFunctionParameter %34
+%25 = OpLabel
+%2 = OpVariable %36 Function
+%3 = OpVariable %36 Function
+%5 = OpVariable %36 Function
+%6 = OpVariable %36 Function
+%7 = OpVariable %36 Function
+%8 = OpVariable %36 Function
+OpStore %2 %9
+OpStore %3 %10
+%12 = OpLoad %34 %2
+%11 = OpCopyObject %34 %12
+OpStore %5 %11
+%14 = OpLoad %34 %3
+%13 = OpCopyObject %34 %14
+OpStore %6 %13
+%16 = OpLoad %34 %5
+%21 = OpConvertUToPtr %37 %16
+%15 = OpLoad %34 %21
+OpStore %7 %15
+%17 = OpLoad %34 %7
+%22 = OpBitcast %38 %4
+OpStore %22 %17
+%23 = OpBitcast %38 %4
+%18 = OpLoad %34 %23
+OpStore %8 %18
+%19 = OpLoad %34 %6
+%20 = OpLoad %34 %8
+%24 = OpConvertUToPtr %37 %19
+OpStore %24 %20
+OpReturn
+OpFunctionEnd
\ No newline at end of file
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index a7025b1..6b07c0f 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,14 +1,13 @@
use crate::ast;
use half::f16;
use rspirv::{binary::Disassemble, dr};
+use std::collections::{hash_map, HashMap, HashSet};
use std::{borrow::Cow, hash::Hash, iter, mem};
-use std::{
- collections::{hash_map, HashMap, HashSet},
- convert::TryFrom,
-};
use rspirv::binary::Assemble;
+static NOTCUDA_PTX_IMPL: &'static [u8] = include_bytes!("../lib/notcuda_ptx_impl.spv");
+
quick_error! {
#[derive(Debug)]
pub enum TranslateError {
@@ -69,6 +68,7 @@ impl Into<spirv::StorageClass> for ast::PointerStateSpace {
ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup,
ast::PointerStateSpace::Param => spirv::StorageClass::Function,
+ ast::PointerStateSpace::Generic => spirv::StorageClass::Generic,
}
}
}
@@ -419,6 +419,7 @@ impl TypeWordMap {
pub struct Module {
pub spirv: dr::Module,
pub kernel_info: HashMap<String, KernelInfo>,
+ pub should_link_ptx_impl: Option<&'static [u8]>,
}
pub struct KernelInfo {
@@ -428,15 +429,22 @@ pub struct KernelInfo {
pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateError> {
let mut id_defs = GlobalStringIdResolver::new(1);
+ let mut ptx_impl_imports = HashMap::new();
let directives = ast
.directives
.into_iter()
- .map(|f| translate_directive(&mut id_defs, f))
+ .map(|directive| translate_directive(&mut id_defs, &mut ptx_impl_imports, directive))
.collect::<Result<Vec<_>, _>>()?;
+ let must_link_ptx_impl = ptx_impl_imports.len() > 0;
+ let directives = ptx_impl_imports
+ .into_iter()
+ .map(|(_, v)| v)
+ .chain(directives.into_iter())
+ .collect::<Vec<_>>();
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
- let mut directives =
- convert_dynamic_shared_memory_usage(&mut id_defs, directives, &mut || builder.id());
+ let call_map = get_call_map(&directives);
+ let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
@@ -448,32 +456,142 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
let mut map = TypeWordMap::new(&mut builder);
emit_builtins(&mut builder, &mut map, &id_defs);
let mut kernel_info = HashMap::new();
- for d in directives {
+ emit_directives(
+ &mut builder,
+ &mut map,
+ &id_defs,
+ opencl_id,
+ &denorm_information,
+ &call_map,
+ directives,
+ &mut kernel_info,
+ )?;
+ let spirv = builder.module();
+ Ok(Module {
+ spirv,
+ kernel_info,
+ should_link_ptx_impl: if must_link_ptx_impl {
+ Some(NOTCUDA_PTX_IMPL)
+ } else {
+ None
+ },
+ })
+}
+
+fn emit_directives<'input>(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ id_defs: &GlobalStringIdResolver<'input>,
+ opencl_id: spirv::Word,
+ denorm_information: &HashMap<CallgraphKey<'input>, HashMap<u8, spirv::FPDenormMode>>,
+ call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
+ directives: Vec<Directive>,
+ kernel_info: &mut HashMap<String, KernelInfo>,
+) -> Result<(), TranslateError> {
+ let empty_body = Vec::new();
+ for d in directives.iter() {
match d {
Directive::Variable(var) => {
- emit_variable(&mut builder, &mut map, &var)?;
+ emit_variable(builder, map, &var)?;
}
Directive::Method(f) => {
- let f_body = match f.body {
+ let f_body = match &f.body {
Some(f) => f,
- None => continue,
+ None => {
+ if f.import_as.is_some() {
+ &empty_body
+ } else {
+ continue;
+ }
+ }
};
- emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
+ for var in f.globals.iter() {
+ emit_variable(builder, map, var)?;
+ }
emit_function_header(
- &mut builder,
- &mut map,
+ builder,
+ map,
&id_defs,
- f.func_decl,
+ &f.globals,
+ &f.func_decl,
&denorm_information,
- &mut kernel_info,
+ call_map,
+ &directives,
+ kernel_info,
)?;
- emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
+ emit_function_body_ops(builder, map, opencl_id, &f_body)?;
builder.end_function()?;
+ if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) =
+ (&f.func_decl, &f.import_as)
+ {
+ builder.decorate(
+ *fn_id,
+ spirv::Decoration::LinkageAttributes,
+ &[
+ dr::Operand::LiteralString(name.clone()),
+ dr::Operand::LinkageType(spirv::LinkageType::Import),
+ ],
+ );
+ }
}
}
}
- let spirv = builder.module();
- Ok(Module { spirv, kernel_info })
+ Ok(())
+}
+
+fn get_call_map<'input>(
+ module: &[Directive<'input>],
+) -> HashMap<&'input str, HashSet<spirv::Word>> {
+ let mut directly_called_by = HashMap::new();
+ for directive in module {
+ match directive {
+ Directive::Method(Function {
+ func_decl,
+ body: Some(statements),
+ ..
+ }) => {
+ let call_key = CallgraphKey::new(&func_decl);
+ for statement in statements {
+ match statement {
+ Statement::Call(call) => {
+ multi_hash_map_append(&mut directly_called_by, call_key, call.func);
+ }
+ _ => {}
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ let mut result = HashMap::new();
+ for (method_key, children) in directly_called_by.iter() {
+ match method_key {
+ CallgraphKey::Kernel(name) => {
+ let mut visited = HashSet::new();
+ for child in children {
+ add_call_map_single(&directly_called_by, &mut visited, *child);
+ }
+ result.insert(*name, visited);
+ }
+ CallgraphKey::Func(_) => {}
+ }
+ }
+ result
+}
+
+fn add_call_map_single<'input>(
+ directly_called_by: &MultiHashMap<CallgraphKey<'input>, spirv::Word>,
+ visited: &mut HashSet<spirv::Word>,
+ current: spirv::Word,
+) {
+ if !visited.insert(current) {
+ return;
+ }
+ if let Some(children) = directly_called_by.get(&CallgraphKey::Func(current)) {
+ for child in children {
+ add_call_map_single(directly_called_by, visited, *child);
+ }
+ }
}
type MultiHashMap<K, V> = HashMap<K, Vec<V>>;
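
get_call_map and add_call_map_single compute, for each kernel, the transitive set of functions it may call; emit_function_header later uses this so a kernel's OpEntryPoint interface also lists globals declared in its callees. A self-contained sketch of the same closure walk (iterative rather than recursive, over plain u32 ids):

use std::collections::{HashMap, HashSet};

// Simplified version of the add_call_map_single walk above.
fn transitive_callees(direct: &HashMap<u32, Vec<u32>>, root: u32) -> HashSet<u32> {
    let mut visited = HashSet::new();
    let mut stack = direct.get(&root).cloned().unwrap_or_default();
    while let Some(func) = stack.pop() {
        if visited.insert(func) {
            if let Some(children) = direct.get(&func) {
                stack.extend(children.iter().copied());
            }
        }
    }
    visited
}

fn main() {
    // Kernel 1 calls function 2, which calls function 3: the kernel's call map is {2, 3}.
    let mut direct = HashMap::new();
    direct.insert(1u32, vec![2u32]);
    direct.insert(2u32, vec![3u32]);
    let expected: HashSet<u32> = [2u32, 3u32].iter().copied().collect();
    assert_eq!(transitive_callees(&direct, 1), expected);
}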
@@ -495,7 +613,6 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
// This pass looks for all uses of .extern .shared and converts them to
// an additional method argument
fn convert_dynamic_shared_memory_usage<'input>(
- id_defs: &mut GlobalStringIdResolver<'input>,
module: Vec<Directive<'input>>,
new_id: &mut impl FnMut() -> spirv::Word,
) -> Vec<Directive<'input>> {
@@ -524,6 +641,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
func_decl,
globals,
body: Some(statements),
+ import_as,
}) => {
let call_key = CallgraphKey::new(&func_decl);
let statements = statements
@@ -545,6 +663,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
func_decl,
globals,
body: Some(statements),
+ import_as,
})
}
directive => directive,
@@ -561,6 +680,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
mut func_decl,
globals,
body: Some(statements),
+ import_as,
}) => {
let call_key = CallgraphKey::new(&func_decl);
if !methods_using_extern_shared.contains(&call_key) {
@@ -568,6 +688,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
func_decl,
globals,
body: Some(statements),
+ import_as,
});
}
let shared_id_param = new_id();
@@ -625,6 +746,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
func_decl,
globals,
body: Some(new_statements),
+ import_as,
})
}
directive => directive,
@@ -744,15 +866,6 @@ fn denorm_count_map_update_impl<T: Eq + Hash>(
}
}
-fn denorm_count_map_merge<T: Eq + Hash + Copy>(
- dst: &mut DenormCountMap<T>,
- src: &DenormCountMap<T>,
-) {
- for (k, count) in src {
- denorm_count_map_update_impl(dst, *k, *count);
- }
-}
-
// HACK ALERT!
// This function is a "good enough" heuristic of whether to mark f16/f32 operations
// in the kernel as flushing denorms to zero or preserving them
@@ -763,7 +876,7 @@ fn compute_denorm_information<'input>(
module: &[Directive<'input>],
) -> HashMap<CallgraphKey<'input>, HashMap<u8, spirv::FPDenormMode>> {
let mut denorm_methods = HashMap::new();
- for directive in module.iter() {
+ for directive in module {
match directive {
Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {}
Directive::Method(Function {
@@ -861,9 +974,12 @@ fn emit_builtins(
fn emit_function_header<'a>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- global: &GlobalStringIdResolver<'a>,
- func_directive: ast::MethodDecl<spirv::Word>,
+ defined_globals: &GlobalStringIdResolver<'a>,
+ synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
+ func_directive: &ast::MethodDecl<spirv::Word>,
denorm_information: &HashMap<CallgraphKey<'a>, HashMap<u8, spirv::FPDenormMode>>,
+ call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
+ directives: &[Directive],
kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
if let ast::MethodDecl::Kernel {
@@ -884,22 +1000,49 @@ fn emit_function_header<'a>(
let (ret_type, func_type) = get_function_type(builder, map, &func_directive);
let fn_id = match func_directive {
ast::MethodDecl::Kernel { name, .. } => {
- let fn_id = global.get_id(name)?;
- let mut global_variables = global
+ let fn_id = defined_globals.get_id(name)?;
+ let mut global_variables = defined_globals
.variables_type_check
.iter()
.filter_map(|(k, t)| t.as_ref().map(|_| *k))
.collect::<Vec<_>>();
- let mut interface = global
+ let mut interface = defined_globals
.special_registers
.iter()
.map(|(_, id)| *id)
.collect::<Vec<_>>();
+ for ast::Variable { name, .. } in synthetic_globals {
+ interface.push(*name);
+ }
+ let empty_hash_set = HashSet::new();
+ let child_fns = call_map.get(name).unwrap_or(&empty_hash_set);
+ for directive in directives {
+ match directive {
+ Directive::Method(Function {
+ func_decl: ast::MethodDecl::Func(_, name, _),
+ globals,
+ ..
+ }) => {
+ if child_fns.contains(name) {
+ for var in globals {
+ interface.push(var.name);
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+
global_variables.append(&mut interface);
- builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
+ builder.entry_point(
+ spirv::ExecutionModel::Kernel,
+ fn_id,
+ *name,
+ global_variables,
+ );
fn_id
}
- ast::MethodDecl::Func(_, name, _) => name,
+ ast::MethodDecl::Func(_, name, _) => *name,
};
builder.begin_function(
ret_type,
@@ -934,9 +1077,10 @@ fn emit_function_header<'a>(
pub fn to_spirv<'a>(
ast: ast::Module<'a>,
-) -> Result<(Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
+) -> Result<(Option<&'static [u8]>, Vec<u32>, HashMap<String, Vec<usize>>), TranslateError> {
let module = to_spirv_module(ast)?;
Ok((
+ module.should_link_ptx_impl,
module.spirv.assemble(),
module
.kernel_info
@@ -977,11 +1121,14 @@ fn emit_memory_model(builder: &mut dr::Builder) {
fn translate_directive<'input>(
id_defs: &mut GlobalStringIdResolver<'input>,
+ ptx_impl_imports: &mut HashMap<String, Directive>,
d: ast::Directive<'input, ast::ParsedArgParams<'input>>,
) -> Result<Directive<'input>, TranslateError> {
Ok(match d {
ast::Directive::Variable(v) => Directive::Variable(translate_variable(id_defs, v)?),
- ast::Directive::Method(f) => Directive::Method(translate_function(id_defs, f)?),
+ ast::Directive::Method(f) => {
+ Directive::Method(translate_function(id_defs, ptx_impl_imports, f)?)
+ }
})
}
@@ -1000,10 +1147,11 @@ fn translate_variable<'a>(
fn translate_function<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
+ ptx_impl_imports: &mut HashMap<String, Directive>,
f: ast::ParsedFunction<'a>,
) -> Result<Function<'a>, TranslateError> {
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive);
- to_ssa(str_resolver, fn_resolver, fn_decl, f.body)
+ to_ssa(ptx_impl_imports, str_resolver, fn_resolver, fn_decl, f.body)
}
fn expand_kernel_params<'a, 'b>(
@@ -1043,6 +1191,7 @@ fn expand_fn_params<'a, 'b>(
}
fn to_ssa<'input, 'b>(
+ ptx_impl_imports: &mut HashMap<String, Directive>,
mut id_defs: FnStringIdResolver<'input, 'b>,
fn_defs: GlobalFnDeclResolver<'input, 'b>,
f_args: ast::MethodDecl<'input, spirv::Word>,
@@ -1055,6 +1204,7 @@ fn to_ssa<'input, 'b>(
func_decl: f_args,
body: None,
globals: Vec::new(),
+ import_as: None,
})
}
};
@@ -1071,19 +1221,90 @@ fn to_ssa<'input, 'b>(
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.unmut();
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
- let (f_body, globals) = extract_globals(labeled_statements);
+ let (f_body, globals) =
+ extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs);
Ok(Function {
func_decl: f_args,
globals: globals,
body: Some(f_body),
+ import_as: None,
})
}
-fn extract_globals(
+fn extract_globals<'input, 'b>(
sorted_statements: Vec<ExpandedStatement>,
-) -> (Vec<ExpandedStatement>, Vec<ExpandedStatement>) {
- // This fn will be used for SLM
- (sorted_statements, Vec::new())
+ ptx_impl_imports: &mut HashMap<String, Directive>,
+ id_def: &mut NumericIdResolver,
+) -> (
+ Vec<ExpandedStatement>,
+ Vec<ast::Variable<ast::VariableType, spirv::Word>>,
+) {
+ let mut local = Vec::with_capacity(sorted_statements.len());
+ let mut global = Vec::new();
+ for statement in sorted_statements {
+ match statement {
+ Statement::Variable(
+ var
+ @
+ ast::Variable {
+ v_type: ast::VariableType::Shared(_),
+ ..
+ },
+ )
+ | Statement::Variable(
+ var
+ @
+ ast::Variable {
+ v_type: ast::VariableType::Global(_),
+ ..
+ },
+ ) => global.push(var),
+ Statement::Instruction(ast::Instruction::Atom(
+ d
+ @
+ ast::AtomDetails {
+ inner:
+ ast::AtomInnerDetails::Unsigned {
+ op: ast::AtomUIntOp::Inc,
+ ..
+ },
+ ..
+ },
+ a,
+ )) => {
+ local.push(to_ptx_impl_atomic_call(
+ id_def,
+ ptx_impl_imports,
+ d,
+ a,
+ "inc",
+ ));
+ }
+ Statement::Instruction(ast::Instruction::Atom(
+ d
+ @
+ ast::AtomDetails {
+ inner:
+ ast::AtomInnerDetails::Unsigned {
+ op: ast::AtomUIntOp::Dec,
+ ..
+ },
+ ..
+ },
+ a,
+ )) => {
+ local.push(to_ptx_impl_atomic_call(
+ id_def,
+ ptx_impl_imports,
+ d,
+ a,
+ "dec",
+ ));
+ }
+ s => local.push(s),
+ }
+ }
+ (local, global)
}
fn normalize_variable_decls(directives: &mut Vec<Directive>) {
@@ -1269,6 +1490,15 @@ fn convert_to_typed_statements(
ast::Instruction::Selp(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Selp(d, a.cast())))
}
+ ast::Instruction::Bar(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Bar(d, a.cast())))
+ }
+ ast::Instruction::Atom(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Atom(d, a.cast())))
+ }
+ ast::Instruction::AtomCas(d, a) => result.push(Statement::Instruction(
+ ast::Instruction::AtomCas(d, a.cast()),
+ )),
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -1286,6 +1516,99 @@ fn convert_to_typed_statements(
Ok(result)
}
+fn to_ptx_impl_atomic_call(
+ id_defs: &mut NumericIdResolver,
+ ptx_impl_imports: &mut HashMap<String, Directive>,
+ details: ast::AtomDetails,
+ arg: ast::Arg3<ExpandedArgParams>,
+ op: &'static str,
+) -> ExpandedStatement {
+ let semantics = ptx_semantics_name(details.semantics);
+ let scope = ptx_scope_name(details.scope);
+ let space = ptx_space_name(details.space);
+ let fn_name = format!(
+ "__notcuda_ptx_impl__atom_{}_{}_{}_{}",
+ semantics, scope, space, op
+ );
+ // TODO: extract to a function
+ let ptr_space = match details.space {
+ ast::AtomSpace::Generic => ast::PointerStateSpace::Generic,
+ ast::AtomSpace::Global => ast::PointerStateSpace::Global,
+ ast::AtomSpace::Shared => ast::PointerStateSpace::Shared,
+ };
+ let fn_id = match ptx_impl_imports.entry(fn_name) {
+ hash_map::Entry::Vacant(entry) => {
+ let fn_id = id_defs.new_id(None);
+ let func_decl = ast::MethodDecl::Func::<spirv::Word>(
+ vec![ast::FnArgument {
+ align: None,
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
+ ast::ScalarType::U32,
+ )),
+ name: id_defs.new_id(None),
+ array_init: Vec::new(),
+ }],
+ fn_id,
+ vec![
+ ast::FnArgument {
+ align: None,
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
+ ast::SizedScalarType::U32,
+ ptr_space,
+ )),
+ name: id_defs.new_id(None),
+ array_init: Vec::new(),
+ },
+ ast::FnArgument {
+ align: None,
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
+ ast::ScalarType::U32,
+ )),
+ name: id_defs.new_id(None),
+ array_init: Vec::new(),
+ },
+ ],
+ );
+ let func = Function {
+ func_decl,
+ globals: Vec::new(),
+ body: None,
+ import_as: Some(entry.key().clone()),
+ };
+ entry.insert(Directive::Method(func));
+ fn_id
+ }
+ hash_map::Entry::Occupied(entry) => match entry.get() {
+ Directive::Method(Function {
+ func_decl: ast::MethodDecl::Func(_, name, _),
+ ..
+ }) => *name,
+ _ => unreachable!(),
+ },
+ };
+ Statement::Call(ResolvedCall {
+ uniform: false,
+ func: fn_id,
+ ret_params: vec![(
+ arg.dst,
+ ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ )],
+ param_list: vec![
+ (
+ arg.src1,
+ ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
+ ast::SizedScalarType::U32,
+ ptr_space,
+ )),
+ ),
+ (
+ arg.src2,
+ ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ),
+ ],
+ })
+}
+
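Editorial note on why inc and dec go through an imported helper rather than a native SPIR-V atomic: atom.inc and atom.dec wrap around a caller-supplied bound, which OpAtomicIIncrement/OpAtomicIDecrement cannot express, so extract_globals rewrites them into calls to functions presumably provided by ptx/lib/notcuda_ptx_impl.cl. For reference, the update rules those helpers have to implement, as defined by the PTX ISA (shown as plain Rust, not patch code; the destination register still receives the old value):

fn atom_inc_new_value(old: u32, bound: u32) -> u32 {
    if old >= bound { 0 } else { old + 1 }
}

fn atom_dec_new_value(old: u32, bound: u32) -> u32 {
    if old == 0 || old > bound { bound } else { old - 1 }
}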
fn to_resolved_fn_args<T>(
params: Vec<T>,
params_decl: &[ast::FnArgumentType],
@@ -1529,6 +1852,9 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
| (t, ArgumentSemantics::DefaultRelaxed)
| (t, ArgumentSemantics::PhysicalPointer) => t,
};
+ if let ast::Type::Array(_, _) = id_type {
+ return Ok(desc.op);
+ }
let generated_id = id_def.new_id(id_type.clone());
if !desc.is_dst {
result.push(Statement::LoadVar(
@@ -1916,6 +2242,12 @@ fn insert_implicit_conversions(
if let ast::Instruction::St(d, _) = &inst {
state_space = Some(d.state_space.to_ld_ss());
}
+ if let ast::Instruction::Atom(d, _) = &inst {
+ state_space = Some(d.space.to_ld_ss());
+ }
+ if let ast::Instruction::AtomCas(d, _) = &inst {
+ state_space = Some(d.space.to_ld_ss());
+ }
if let ast::Instruction::Mov(_, ast::Arg2Mov::Normal(_)) = &inst {
default_conversion_fn = should_bitcast_packed;
}
@@ -2387,6 +2719,52 @@ fn emit_function_body_ops(
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
builder.select(result_type, Some(a.dst), a.src3, a.src1, a.src2)?;
}
+ // TODO: implement named barriers
+ ast::Instruction::Bar(d, _) => {
+ let workgroup_scope = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(spirv::Scope::Workgroup as u32),
+ )?;
+ let barrier_semantics = match d {
+ ast::BarDetails::SyncAligned => map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+ )?,
+ };
+ builder.control_barrier(workgroup_scope, workgroup_scope, barrier_semantics)?;
+ }
+ ast::Instruction::Atom(details, arg) => {
+ emit_atom(builder, map, details, arg)?;
+ }
+ ast::Instruction::AtomCas(details, arg) => {
+ let result_type = map.get_or_add_scalar(builder, details.typ.into());
+ let memory_const = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(details.scope.to_spirv() as u32),
+ )?;
+ let semantics_const = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(details.semantics.to_spirv().bits()),
+ )?;
+ builder.atomic_compare_exchange(
+ result_type,
+ Some(arg.dst),
+ arg.src1,
+ memory_const,
+ semantics_const,
+ semantics_const,
+ arg.src3,
+ arg.src2,
+ )?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@@ -2417,6 +2795,99 @@ fn emit_function_body_ops(
Ok(())
}
+fn emit_atom(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ details: &ast::AtomDetails,
+ arg: &ast::Arg3<ExpandedArgParams>,
+) -> Result<(), TranslateError> {
+ let (spirv_op, typ) = match details.inner {
+ ast::AtomInnerDetails::Bit { op, typ } => {
+ let spirv_op = match op {
+ ast::AtomBitOp::And => dr::Builder::atomic_and,
+ ast::AtomBitOp::Or => dr::Builder::atomic_or,
+ ast::AtomBitOp::Xor => dr::Builder::atomic_xor,
+ ast::AtomBitOp::Exchange => dr::Builder::atomic_exchange,
+ };
+ (spirv_op, ast::ScalarType::from(typ))
+ }
+ ast::AtomInnerDetails::Unsigned { op, typ } => {
+ let spirv_op = match op {
+ ast::AtomUIntOp::Add => dr::Builder::atomic_i_add,
+ ast::AtomUIntOp::Inc | ast::AtomUIntOp::Dec => {
+ return Err(TranslateError::Unreachable);
+ }
+ ast::AtomUIntOp::Min => dr::Builder::atomic_u_min,
+ ast::AtomUIntOp::Max => dr::Builder::atomic_u_max,
+ };
+ (spirv_op, typ.into())
+ }
+ ast::AtomInnerDetails::Signed { op, typ } => {
+ let spirv_op = match op {
+ ast::AtomSIntOp::Add => dr::Builder::atomic_i_add,
+ ast::AtomSIntOp::Min => dr::Builder::atomic_s_min,
+ ast::AtomSIntOp::Max => dr::Builder::atomic_s_max,
+ };
+ (spirv_op, typ.into())
+ }
+ // TODO: Hardware is capable of this, implement it through a builtin
+ ast::AtomInnerDetails::Float { .. } => todo!(),
+ };
+ let result_type = map.get_or_add_scalar(builder, typ);
+ let memory_const = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(details.scope.to_spirv() as u32),
+ )?;
+ let semantics_const = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(details.semantics.to_spirv().bits()),
+ )?;
+ spirv_op(
+ builder,
+ result_type,
+ Some(arg.dst),
+ arg.src1,
+ memory_const,
+ semantics_const,
+ arg.src2,
+ )?;
+ Ok(())
+}
+
+#[derive(Clone)]
+struct PtxImplImport {
+ out_arg: ast::Type,
+ fn_id: u32,
+ in_args: Vec<ast::Type>,
+}
+
+fn ptx_semantics_name(sema: ast::AtomSemantics) -> &'static str {
+ match sema {
+ ast::AtomSemantics::Relaxed => "relaxed",
+ ast::AtomSemantics::Acquire => "acquire",
+ ast::AtomSemantics::Release => "release",
+ ast::AtomSemantics::AcquireRelease => "acq_rel",
+ }
+}
+
+fn ptx_scope_name(scope: ast::MemScope) -> &'static str {
+ match scope {
+ ast::MemScope::Cta => "cta",
+ ast::MemScope::Gpu => "gpu",
+ ast::MemScope::Sys => "sys",
+ }
+}
+
+fn ptx_space_name(space: ast::AtomSpace) -> &'static str {
+ match space {
+ ast::AtomSpace::Generic => "generic",
+ ast::AtomSpace::Global => "global",
+ ast::AtomSpace::Shared => "shared",
+ }
+}
+
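Editorial note: these three helpers, together with the format string in to_ptx_impl_atomic_call, define the naming scheme for the imported atomics. For example, a relaxed, gpu-scoped atom.global.inc resolves to __notcuda_ptx_impl__atom_relaxed_gpu_global_inc, which has to match the function name exported by the OpenCL implementation file. A small illustrative wrapper, not part of the patch:

fn atomic_impl_name(
    sema: ast::AtomSemantics,
    scope: ast::MemScope,
    space: ast::AtomSpace,
    op: &str,
) -> String {
    // Mirrors the format! call in to_ptx_impl_atomic_call.
    format!(
        "__notcuda_ptx_impl__atom_{}_{}_{}_{}",
        ptx_semantics_name(sema),
        ptx_scope_name(scope),
        ptx_space_name(space),
        op
    )
}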
fn emit_mul_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -2652,7 +3123,7 @@ fn emit_cvt(
map: &mut TypeWordMap,
dets: &ast::CvtDetails,
arg: &ast::Arg2<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
+) -> Result<(), TranslateError> {
match dets {
ast::CvtDetails::FloatFromFloat(desc) => {
if desc.dst == desc.src {
@@ -3011,7 +3482,7 @@ fn emit_implicit_conversion(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
cv: &ImplicitConversion,
-) -> Result<(), dr::Error> {
+) -> Result<(), TranslateError> {
let from_parts = cv.from.to_parts();
let to_parts = cv.to.to_parts();
match (from_parts.kind, to_parts.kind, cv.kind) {
@@ -3019,7 +3490,7 @@ fn emit_implicit_conversion(
let dst_type = map.get_or_add_scalar(builder, ast::ScalarType::B64);
builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
}
- (_, _, ConversionKind::BitToPtr(space)) => {
+ (_, _, ConversionKind::BitToPtr(_)) => {
let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
@@ -3782,8 +4253,9 @@ enum Directive<'input> {
struct Function<'input> {
pub func_decl: ast::MethodDecl<'input, spirv::Word>,
- pub globals: Vec<ExpandedStatement>,
+ pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>,
pub body: Option<Vec<ExpandedStatement>>,
+ import_as: Option<String>,
}
pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
@@ -4091,6 +4563,13 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?,
),
ast::Instruction::Selp(t, a) => ast::Instruction::Selp(t, a.map_selp(visitor, t)?),
+ ast::Instruction::Bar(d, a) => ast::Instruction::Bar(d, a.map(visitor)?),
+ ast::Instruction::Atom(d, a) => {
+ ast::Instruction::Atom(d, a.map_atom(visitor, d.inner.get_type(), d.space)?)
+ }
+ ast::Instruction::AtomCas(d, a) => {
+ ast::Instruction::AtomCas(d, a.map_atom(visitor, d.typ, d.space)?)
+ }
})
}
}
@@ -4337,6 +4816,9 @@ impl ast::Instruction<ExpandedArgParams> {
| ast::Instruction::Rcp(_, _)
| ast::Instruction::And(_, _)
| ast::Instruction::Selp(_, _)
+ | ast::Instruction::Bar(_, _)
+ | ast::Instruction::Atom(_, _)
+ | ast::Instruction::AtomCas(_, _)
| ast::Instruction::Mad(_, _) => None,
}
}
@@ -4358,6 +4840,9 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::And(_, _) => None,
ast::Instruction::Cvta(_, _) => None,
ast::Instruction::Selp(_, _) => None,
+ ast::Instruction::Bar(_, _) => None,
+ ast::Instruction::Atom(_, _) => None,
+ ast::Instruction::AtomCas(_, _) => None,
ast::Instruction::Sub(ast::ArithDetails::Signed(_), _) => None,
ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _) => None,
ast::Instruction::Add(ast::ArithDetails::Signed(_), _) => None,
@@ -4612,6 +5097,27 @@ impl<T: ArgParamsEx> ast::Arg1<T> {
}
}
+impl<T: ArgParamsEx> ast::Arg1Bar<T> {
+ fn cast<U: ArgParamsEx<Operand = T::Operand>>(self) -> ast::Arg1Bar<U> {
+ ast::Arg1Bar { src: self.src }
+ }
+
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ ) -> Result<ast::Arg1Bar<U>, TranslateError> {
+ let new_src = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ )?;
+ Ok(ast::Arg1Bar { src: new_src })
+ }
+}
+
impl<T: ArgParamsEx> ast::Arg2<T> {
fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg2<U> {
ast::Arg2 {
@@ -5022,6 +5528,43 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
+
+ fn map_atom<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ t: ast::ScalarType,
+ state_space: ast::AtomSpace,
+ ) -> Result<ast::Arg3<U>, TranslateError> {
+ let scalar_type = ast::ScalarType::from(t);
+ let dst = visitor.id(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(&ast::Type::Scalar(scalar_type)),
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ sema: ArgumentSemantics::PhysicalPointer,
+ },
+ &ast::Type::Pointer(
+ ast::PointerType::Scalar(scalar_type),
+ state_space.to_ld_ss(),
+ ),
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &ast::Type::Scalar(scalar_type),
+ )?;
+ Ok(ast::Arg3 { dst, src1, src2 })
+ }
}
impl<T: ArgParamsEx> ast::Arg4<T> {
@@ -5129,6 +5672,56 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
src3,
})
}
+
+ fn map_atom<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ t: ast::BitType,
+ state_space: ast::AtomSpace,
+ ) -> Result<ast::Arg4<U>, TranslateError> {
+ let scalar_type = ast::ScalarType::from(t);
+ let dst = visitor.id(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(&ast::Type::Scalar(scalar_type)),
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ sema: ArgumentSemantics::PhysicalPointer,
+ },
+ &ast::Type::Pointer(
+ ast::PointerType::Scalar(scalar_type),
+ state_space.to_ld_ss(),
+ ),
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &ast::Type::Scalar(scalar_type),
+ )?;
+ let src3 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src3,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &ast::Type::Scalar(scalar_type),
+ )?;
+ Ok(ast::Arg4 {
+ dst,
+ src1,
+ src2,
+ src3,
+ })
+ }
}
impl<T: ArgParamsEx> ast::Arg4Setp<T> {
@@ -5434,6 +6027,17 @@ impl ast::MinMaxDetails {
}
}
+impl ast::AtomInnerDetails {
+ fn get_type(&self) -> ast::ScalarType {
+ match self {
+ ast::AtomInnerDetails::Bit { typ, .. } => (*typ).into(),
+ ast::AtomInnerDetails::Unsigned { typ, .. } => (*typ).into(),
+ ast::AtomInnerDetails::Signed { typ, .. } => (*typ).into(),
+ ast::AtomInnerDetails::Float { typ, .. } => (*typ).into(),
+ }
+ }
+}
+
impl ast::SIntType {
fn from_size(width: u8) -> Self {
match width {
@@ -5509,6 +6113,37 @@ impl ast::MulDetails {
}
}
+impl ast::AtomSpace {
+ fn to_ld_ss(self) -> ast::LdStateSpace {
+ match self {
+ ast::AtomSpace::Generic => ast::LdStateSpace::Generic,
+ ast::AtomSpace::Global => ast::LdStateSpace::Global,
+ ast::AtomSpace::Shared => ast::LdStateSpace::Shared,
+ }
+ }
+}
+
+impl ast::MemScope {
+ fn to_spirv(self) -> spirv::Scope {
+ match self {
+ ast::MemScope::Cta => spirv::Scope::Workgroup,
+ ast::MemScope::Gpu => spirv::Scope::Device,
+ ast::MemScope::Sys => spirv::Scope::CrossDevice,
+ }
+ }
+}
+
+impl ast::AtomSemantics {
+ fn to_spirv(self) -> spirv::MemorySemantics {
+ match self {
+ ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED,
+ ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE,
+ ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE,
+ ast::AtomSemantics::AcquireRelease => spirv::MemorySemantics::ACQUIRE_RELEASE,
+ }
+ }
+}
+
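Editorial note: the two to_spirv impls above supply the scope and memory-semantics operands that emit_atom and the AtomCas path pass to the OpAtomic* instructions, each materialized as a U32 constant through get_or_add_constant. A worked example that just applies those mappings: gpu scope becomes Scope::Device and acquire ordering becomes the ACQUIRE semantics bit.

// Not part of the patch; applies the conversions defined above.
fn example_atomic_operands() -> (spirv::Scope, spirv::MemorySemantics) {
    (
        ast::MemScope::Gpu.to_spirv(),          // spirv::Scope::Device
        ast::AtomSemantics::Acquire.to_spirv(), // spirv::MemorySemantics::ACQUIRE
    )
}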
fn bitcast_logical_pointer(
operand: &ast::Type,
instr: &ast::Type,
@@ -5528,7 +6163,27 @@ fn bitcast_physical_pointer(
) -> Result<Option<ConversionKind>, TranslateError> {
match operand_type {
// array decays to a pointer
- ast::Type::Array(_, _) => todo!(),
+ ast::Type::Array(op_scalar_t, _) => {
+ if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
+ if ss == Some(*instr_space) {
+ if ast::Type::Scalar(*op_scalar_t) == ast::Type::from(instr_scalar_t.clone()) {
+ Ok(None)
+ } else {
+ Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
+ }
+ } else {
+ if ss == Some(ast::LdStateSpace::Generic)
+ || *instr_space == ast::LdStateSpace::Generic
+ {
+ Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
+ } else {
+ Err(TranslateError::MismatchedType)
+ }
+ }
+ } else {
+ Err(TranslateError::MismatchedType)
+ }
+ }
ast::Type::Scalar(ast::ScalarType::B64)
| ast::Type::Scalar(ast::ScalarType::U64)
| ast::Type::Scalar(ast::ScalarType::S64) => {