aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-09-02 22:47:27 +0200
committerAndrzej Janik <[email protected]>2020-09-02 22:47:27 +0200
commit0f4a4c634b3dd9e1117cb843fcde59498ac2ae07 (patch)
tree36a16a1c75c0989a215e7c59b4682868b9fbf433 /ptx
parentefd83981b8d4d26f25389db933bf70756f060f37 (diff)
downloadZLUDA-0f4a4c634b3dd9e1117cb843fcde59498ac2ae07.tar.gz
ZLUDA-0f4a4c634b3dd9e1117cb843fcde59498ac2ae07.zip
Add support for declaring __local variables and their alignment
Diffstat (limited to 'ptx')
-rw-r--r--ptx/src/test/spirv_run/local_align.ptx21
-rw-r--r--ptx/src/test/spirv_run/local_align.spvtxt38
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/translate.rs68
4 files changed, 105 insertions, 23 deletions
diff --git a/ptx/src/test/spirv_run/local_align.ptx b/ptx/src/test/spirv_run/local_align.ptx
new file mode 100644
index 0000000..6e10de3
--- /dev/null
+++ b/ptx/src/test/spirv_run/local_align.ptx
@@ -0,0 +1,21 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry local_align(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .local .align 8 .b8 __local_depot0[8];
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u64 temp, [in_addr];
+ st.u64 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/local_align.spvtxt b/ptx/src/test/spirv_run/local_align.spvtxt
new file mode 100644
index 0000000..beefb76
--- /dev/null
+++ b/ptx/src/test/spirv_run/local_align.spvtxt
@@ -0,0 +1,38 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %1 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %5 "local_align"
+ OpDecorate %8 Alignment 8
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %4 = OpTypeFunction %void %ulong %ulong
+ %uchar = OpTypeInt 8 0
+%_arr_uchar_8 = OpTypeArray %uchar %8
+%_ptr_Function__arr_uchar_8 = OpTypePointer Function %_arr_uchar_8
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %5 = OpFunction %void None %4
+ %6 = OpFunctionParameter %ulong
+ %7 = OpFunctionParameter %ulong
+ %18 = OpLabel
+ %8 = OpVariable %_ptr_Function__arr_uchar_8 Workgroup
+ %9 = OpVariable %_ptr_Function_ulong Function
+ %10 = OpVariable %_ptr_Function_ulong Function
+ %11 = OpVariable %_ptr_Function_ulong Function
+ OpStore %9 %6
+ OpStore %10 %7
+ %13 = OpLoad %ulong %9
+ %16 = OpConvertUToPtr %_ptr_Generic_ulong %13
+ %12 = OpLoad %ulong %16
+ OpStore %11 %12
+ %14 = OpLoad %ulong %10
+ %15 = OpLoad %ulong %11
+ %17 = OpConvertUToPtr %_ptr_Generic_ulong %14
+ OpStore %17 %15
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 23852a1..8883669 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -51,6 +51,7 @@ test_ptx!(shl, [11u64], [44u64]);
test_ptx!(cvt_sat_s_u, [-1i32], [0i32]);
test_ptx!(cvta, [3.0f32], [3.0f32]);
test_ptx!(block, [1u64], [2u64]);
+test_ptx!(local_align, [1u64], [1u64]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 3fe01cf..642e6ec 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -5,20 +5,21 @@ use std::{borrow::Cow, iter, mem};
use rspirv::binary::Assemble;
-#[derive(PartialEq, Eq, Hash, Clone, Copy)]
+#[derive(PartialEq, Eq, Hash, Clone)]
enum SpirvType {
Base(SpirvScalarKey),
- Pointer(SpirvScalarKey, spirv::StorageClass),
+ Array(SpirvScalarKey, u32),
+ Pointer(Box<SpirvType>, spirv::StorageClass),
}
impl SpirvType {
fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
let key = match t {
- ast::Type::Scalar(typ) => SpirvScalarKey::from(typ),
- ast::Type::ExtendedScalar(typ) => SpirvScalarKey::from(typ),
- ast::Type::Array(_, _) => todo!(),
+ ast::Type::Scalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
+ ast::Type::ExtendedScalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
+ ast::Type::Array(typ, len) => SpirvType::Array(SpirvScalarKey::from(typ), len),
};
- SpirvType::Pointer(key, sc)
+ SpirvType::Pointer(Box::new(key), sc)
}
}
@@ -27,7 +28,7 @@ impl From<ast::Type> for SpirvType {
match t {
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::ExtendedScalar(t) => SpirvType::Base(t.into()),
- ast::Type::Array(_, _) => todo!(),
+ ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
}
}
}
@@ -126,13 +127,20 @@ impl TypeWordMap {
fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
match t {
SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key),
- SpirvType::Pointer(typ, storage) => {
- let base = self.get_or_add_spirv_scalar(b, typ);
+ SpirvType::Pointer(ref typ, storage) => {
+ let base = self.get_or_add(b, *typ.clone());
*self
.complex
.entry(t)
.or_insert_with(|| b.type_pointer(None, storage, base))
}
+ SpirvType::Array(typ, len) => {
+ let base = self.get_or_add_spirv_scalar(b, typ);
+ *self
+ .complex
+ .entry(t)
+ .or_insert_with(|| b.type_array(base, len))
+ }
}
}
@@ -248,7 +256,7 @@ fn normalize_labels(
labels_in_use.insert(cond.if_true);
labels_in_use.insert(cond.if_false);
}
- Statement::Variable(_, _, _)
+ Statement::Variable(_, _, _, _)
| Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
| Statement::Conversion(_)
@@ -298,9 +306,9 @@ fn normalize_predicates(
result.push(Statement::Instruction(inst));
}
}
- ast::Statement::Variable(var) => {
- result.push(Statement::Variable(var.name, var.v_type, var.space))
- }
+ ast::Statement::Variable(var) => result.push(Statement::Variable(
+ var.name, var.v_type, var.space, var.align,
+ )),
// Blocks are flattened when resolving ids
ast::Statement::Block(_) => unreachable!(),
}
@@ -373,7 +381,7 @@ fn insert_mem_ssa_statements(
bra.predicate = generated_id;
result.push(Statement::Conditional(bra));
}
- s @ Statement::Variable(_, _, _) | s @ Statement::Label(_) => result.push(s),
+ s @ Statement::Variable(_, _, _, _) | s @ Statement::Label(_) => result.push(s),
Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
| Statement::Conversion(_)
@@ -395,7 +403,9 @@ fn expand_arguments(
let new_inst = inst.map(&mut visitor);
result.push(Statement::Instruction(new_inst));
}
- Statement::Variable(id, typ, ss) => result.push(Statement::Variable(id, typ, ss)),
+ Statement::Variable(id, typ, ss, align) => {
+ result.push(Statement::Variable(id, typ, ss, align))
+ }
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
@@ -555,7 +565,7 @@ fn insert_implicit_conversions(
s @ Statement::Conditional(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
- | s @ Statement::Variable(_, _, _)
+ | s @ Statement::Variable(_, _, _, _)
| s @ Statement::LoadVar(_, _)
| s @ Statement::StoreVar(_, _) => result.push(s),
Statement::Conversion(_) => unreachable!(),
@@ -614,15 +624,24 @@ fn emit_function_body_ops(
}
match s {
Statement::Label(_) => (),
- Statement::Variable(id, typ, ss) => {
+ Statement::Variable(id, typ, ss, align) => {
let type_id = map.get_or_add(
builder,
SpirvType::new_pointer(*typ, spirv::StorageClass::Function),
);
- if *ss != ast::StateSpace::Reg {
- todo!()
+ let st_class = match ss {
+ ast::StateSpace::Reg => spirv::StorageClass::Function,
+ ast::StateSpace::Local => spirv::StorageClass::Workgroup,
+ _ => todo!(),
+ };
+ builder.variable(type_id, Some(*id), st_class, None);
+ if let Some(align) = align {
+ builder.decorate(
+ *id,
+ spirv::Decoration::Alignment,
+ &[dr::Operand::LiteralInt32(*align)],
+ );
}
- builder.variable(type_id, Some(*id), spirv::StorageClass::Function, None);
}
Statement::Constant(cnst) => {
let typ_id = map.get_or_add_scalar(builder, cnst.typ);
@@ -1006,7 +1025,10 @@ fn emit_implicit_conversion(
ConversionKind::Ptr(space) => {
let dst_type = map.get_or_add(
builder,
- SpirvType::Pointer(SpirvScalarKey::from(to_type), space.to_spirv()),
+ SpirvType::Pointer(
+ Box::new(SpirvType::Base(SpirvScalarKey::from(to_type))),
+ space.to_spirv(),
+ ),
);
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
@@ -1221,7 +1243,7 @@ impl NumericIdResolver {
}
enum Statement<I> {
- Variable(spirv::Word, ast::Type, ast::StateSpace),
+ Variable(spirv::Word, ast::Type, ast::StateSpace, Option<u32>),
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
Label(u32),
@@ -1235,7 +1257,7 @@ enum Statement<I> {
impl Statement<ast::Instruction<ExpandedArgParams>> {
fn visit_variable<F: FnMut(spirv::Word) -> spirv::Word>(self, f: &mut F) -> Self {
match self {
- Statement::Variable(id, t, ss) => Statement::Variable(f(id), t, ss),
+ Statement::Variable(id, t, ss, align) => Statement::Variable(f(id), t, ss, align),
Statement::LoadVar(a, t) => {
Statement::LoadVar(a.map(&mut reduced_visitor(f), Some(t)), t)
}