aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-09-17 16:24:25 +0000
committerAndrzej Janik <[email protected]>2021-09-17 16:24:25 +0000
commit5b2352723fb251b64317737167b609a0a11651a6 (patch)
tree35562f2fdf7b39b56ebe0bb74113b5f1ff2d8d88
parentc37223fe673f2f45a533e338b74ae9325748588a (diff)
downloadZLUDA-5b2352723fb251b64317737167b609a0a11651a6.tar.gz
ZLUDA-5b2352723fb251b64317737167b609a0a11651a6.zip
Implement function pointers and activemask
-rw-r--r--ptx/lib/zluda_ptx_impl.bcbin30788 -> 31224 bytes
-rw-r--r--ptx/lib/zluda_ptx_impl.cl5
-rw-r--r--ptx/src/test/spirv_run/activemask.spvtxt18
-rw-r--r--ptx/src/test/spirv_run/func_ptr.ptx31
-rw-r--r--ptx/src/test/spirv_run/func_ptr.spvtxt73
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/translate.rs116
7 files changed, 214 insertions, 30 deletions
diff --git a/ptx/lib/zluda_ptx_impl.bc b/ptx/lib/zluda_ptx_impl.bc
index 6a2a51c..175f4df 100644
--- a/ptx/lib/zluda_ptx_impl.bc
+++ b/ptx/lib/zluda_ptx_impl.bc
Binary files differ
diff --git a/ptx/lib/zluda_ptx_impl.cl b/ptx/lib/zluda_ptx_impl.cl
index 9171ef9..aca9327 100644
--- a/ptx/lib/zluda_ptx_impl.cl
+++ b/ptx/lib/zluda_ptx_impl.cl
@@ -291,6 +291,11 @@ atomic_add(atom_acq_rel_sys_shared_add_f64, memory_order_acq_rel, memory_order_a
ulong FUNC(brev_b64)(ulong base) {
return __llvm_bitreverse_i64(base);
}
+
+ // Taken from __ballot definition in hipamd/include/hip/amd_detail/amd_device_functions.h
+ uint FUNC(activemask)() {
+ return (uint)__builtin_amdgcn_uicmp(1, 0, 33);
+ }
#endif
void FUNC(__assertfail)(
diff --git a/ptx/src/test/spirv_run/activemask.spvtxt b/ptx/src/test/spirv_run/activemask.spvtxt
index c4ad55d..0753c95 100644
--- a/ptx/src/test/spirv_run/activemask.spvtxt
+++ b/ptx/src/test/spirv_run/activemask.spvtxt
@@ -7,21 +7,22 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %16 = OpExtInstImport "OpenCL.std"
+ %18 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "activemask"
OpExecutionMode %1 ContractionOff
+ OpDecorate %15 LinkageAttributes "__zluda_ptx_impl__activemask" Import
%void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %21 = OpTypeFunction %uint
%ulong = OpTypeInt 64 0
- %19 = OpTypeFunction %void %ulong %ulong
+ %23 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
- %uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
- %v4uint = OpTypeVector %uint 4
- %bool = OpTypeBool
- %true = OpConstantTrue %bool
%_ptr_Generic_uint = OpTypePointer Generic %uint
- %1 = OpFunction %void None %19
+ %15 = OpFunction %uint None %21
+ OpFunctionEnd
+ %1 = OpFunction %void None %23
%6 = OpFunctionParameter %ulong
%7 = OpFunctionParameter %ulong
%14 = OpLabel
@@ -33,8 +34,7 @@
OpStore %3 %7
%8 = OpLoad %ulong %3 Aligned 8
OpStore %4 %8
- %26 = OpSubgroupBallotKHR %v4uint %true
- %9 = OpCompositeExtract %uint %26 0
+ %9 = OpFunctionCall %uint %15
OpStore %5 %9
%10 = OpLoad %ulong %4
%11 = OpLoad %uint %5
diff --git a/ptx/src/test/spirv_run/func_ptr.ptx b/ptx/src/test/spirv_run/func_ptr.ptx
new file mode 100644
index 0000000..aa94f2b
--- /dev/null
+++ b/ptx/src/test/spirv_run/func_ptr.ptx
@@ -0,0 +1,31 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.func (.reg .f32 out) foobar(.reg .f32 x, .reg .f32 y)
+{
+ add.f32 out, x, y;
+ ret;
+}
+
+.visible .entry func_ptr(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp;
+ .reg .u64 temp2;
+ .reg .u64 f_addr;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u64 temp, [in_addr];
+ add.u64 temp2, temp, 1;
+ mov.u64 f_addr, foobar;
+ add.u64 temp2, temp2, f_addr;
+ st.u64 [out_addr], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/func_ptr.spvtxt b/ptx/src/test/spirv_run/func_ptr.spvtxt
new file mode 100644
index 0000000..adc71eb
--- /dev/null
+++ b/ptx/src/test/spirv_run/func_ptr.spvtxt
@@ -0,0 +1,73 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %38 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %11 "func_ptr"
+ OpExecutionMode %11 ContractionOff
+ %void = OpTypeVoid
+ %float = OpTypeFloat 32
+ %41 = OpTypeFunction %float %float %float
+%_ptr_Function_float = OpTypePointer Function %float
+ %ulong = OpTypeInt 64 0
+ %44 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %ulong_1 = OpConstant %ulong 1
+ %ulong_0 = OpConstant %ulong 0
+ %1 = OpFunction %float None %41
+ %5 = OpFunctionParameter %float
+ %6 = OpFunctionParameter %float
+ %10 = OpLabel
+ %3 = OpVariable %_ptr_Function_float Function
+ %4 = OpVariable %_ptr_Function_float Function
+ %2 = OpVariable %_ptr_Function_float Function
+ OpStore %3 %5
+ OpStore %4 %6
+ %8 = OpLoad %float %3
+ %9 = OpLoad %float %4
+ %7 = OpFAdd %float %8 %9
+ OpStore %2 %7
+ OpFunctionEnd
+ %11 = OpFunction %void None %44
+ %19 = OpFunctionParameter %ulong
+ %20 = OpFunctionParameter %ulong
+ %36 = OpLabel
+ %12 = OpVariable %_ptr_Function_ulong Function
+ %13 = OpVariable %_ptr_Function_ulong Function
+ %14 = OpVariable %_ptr_Function_ulong Function
+ %15 = OpVariable %_ptr_Function_ulong Function
+ %16 = OpVariable %_ptr_Function_ulong Function
+ %17 = OpVariable %_ptr_Function_ulong Function
+ %18 = OpVariable %_ptr_Function_ulong Function
+ OpStore %12 %19
+ OpStore %13 %20
+ %21 = OpLoad %ulong %12 Aligned 8
+ OpStore %14 %21
+ %22 = OpLoad %ulong %13 Aligned 8
+ OpStore %15 %22
+ %24 = OpLoad %ulong %14
+ %34 = OpConvertUToPtr %_ptr_Generic_ulong %24
+ %23 = OpLoad %ulong %34 Aligned 8
+ OpStore %16 %23
+ %26 = OpLoad %ulong %16
+ %25 = OpIAdd %ulong %26 %ulong_1
+ OpStore %17 %25
+ %27 = OpCopyObject %ulong %ulong_0
+ OpStore %18 %27
+ %29 = OpLoad %ulong %17
+ %30 = OpLoad %ulong %18
+ %28 = OpIAdd %ulong %29 %30
+ OpStore %17 %28
+ %31 = OpLoad %ulong %15
+ %32 = OpLoad %ulong %17
+ %35 = OpConvertUToPtr %_ptr_Generic_ulong %31
+ OpStore %35 %32 Aligned 8
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index f6b556e..0dcd0bb 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -209,6 +209,7 @@ test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]);
test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]);
test_ptx!(activemask, [0u32], [1u32]);
test_ptx!(membar, [152731u32], [152731u32]);
+test_ptx!(func_ptr, [152731u64], [152732u64]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index e015062..39bd07e 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -5,7 +5,7 @@ use std::cell::RefCell;
use std::collections::{hash_map, HashMap, HashSet};
use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem, rc::Rc};
-use rspirv::binary::Assemble;
+use rspirv::binary::{Assemble, Disassemble};
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.spv");
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.bc");
@@ -607,6 +607,7 @@ fn emit_directives<'input>(
}
}
emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?;
+ builder.select_block(None)?;
builder.end_function()?;
if let (
ast::MethodDeclaration {
@@ -988,6 +989,7 @@ fn compute_denorm_information<'input>(
Statement::Variable(_) => {}
Statement::PtrAccess { .. } => {}
Statement::RepackVector(_) => {}
+ Statement::FunctionPointer(_) => {}
}
}
denorm_methods.insert(method_key, flush_counter);
@@ -1411,6 +1413,15 @@ fn extract_globals<'input, 'b>(
fn_name,
)?);
}
+ Statement::Instruction(ast::Instruction::Activemask { arg }) => {
+ let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Activemask { arg },
+ fn_name,
+ )?);
+ }
Statement::Instruction(ast::Instruction::Atom(
details
@
@@ -1596,6 +1607,21 @@ fn convert_to_typed_statements(
for s in func {
match s {
Statement::Instruction(inst) => match inst {
+ ast::Instruction::Mov(
+ mov,
+ ast::Arg2Mov {
+ dst: ast::Operand::Reg(dst_reg),
+ src: ast::Operand::Reg(src_reg),
+ },
+ ) if fn_defs.fns.contains_key(&src_reg) => {
+ if mov.typ != ast::Type::Scalar(ast::ScalarType::U64) {
+ return Err(TranslateError::MismatchedType);
+ }
+ result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
+ dst: dst_reg,
+ src: src_reg,
+ }));
+ }
ast::Instruction::Call(call) => {
let resolver = fn_defs.get_fn_sig_resolver(call.func)?;
let resolved_call = resolver.resolve_in_spirv_repr(call)?;
@@ -1724,7 +1750,7 @@ fn instruction_to_fn_call(
let return_arguments_count = arguments
.iter()
.position(|(desc, _, _)| !desc.is_dst)
- .unwrap_or(0);
+ .unwrap_or(arguments.len());
let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
let fn_id = register_external_fn_call(
id_defs,
@@ -1826,7 +1852,8 @@ fn normalize_labels(
| Statement::Constant(..)
| Statement::Label(..)
| Statement::PtrAccess { .. }
- | Statement::RepackVector(..) => {}
+ | Statement::RepackVector(..)
+ | Statement::FunctionPointer(..) => {}
}
}
iter::once(Statement::Label(id_def.register_intermediate(None)))
@@ -1984,6 +2011,9 @@ fn insert_mem_ssa_statements<'a, 'b>(
Statement::RepackVector(repack) => {
insert_mem_ssa_statement_default(id_def, &mut result, repack)?
}
+ Statement::FunctionPointer(func_ptr) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, func_ptr)?
+ }
s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
_ => return Err(error_unreachable()),
}
@@ -2235,6 +2265,7 @@ fn expand_arguments<'a, 'b>(
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
Statement::Constant(c) => result.push(Statement::Constant(c)),
+ Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)),
}
}
Ok(result)
@@ -2421,7 +2452,8 @@ fn insert_implicit_conversions(
| s @ Statement::Variable(_)
| s @ Statement::LoadVar(..)
| s @ Statement::StoreVar(..)
- | s @ Statement::RetValue(_, _) => result.push(s),
+ | s @ Statement::RetValue(..)
+ | s @ Statement::FunctionPointer(..) => result.push(s),
}
}
Ok(result)
@@ -2653,6 +2685,16 @@ fn emit_function_body_ops<'input>(
iter::empty(),
)?;
}
+ Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => {
+ // TODO: implement properly
+ let zero = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U64),
+ &vec_repr(0u64),
+ )?;
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::U64);
+ builder.copy_object(result_type, Some(*dst), zero)?;
+ }
Statement::Instruction(inst) => match inst {
ast::Instruction::Abs(d, arg) => emit_abs(builder, map, opencl, d, arg)?,
ast::Instruction::Call(_) => unreachable!(),
@@ -2975,14 +3017,13 @@ fn emit_function_body_ops<'input>(
let result_type = map.get_or_add_scalar(builder, (*typ).into());
builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
}
- ast::Instruction::Bfe { .. } => {
- // Should have beeen replaced with a funciton call earlier
- return Err(error_unreachable());
- }
- ast::Instruction::Bfi { .. } => {
+ ast::Instruction::Bfe { .. }
+ | ast::Instruction::Bfi { .. }
+ | ast::Instruction::Activemask { .. } => {
// Should have beeen replaced with a funciton call earlier
return Err(error_unreachable());
}
+
ast::Instruction::Rem { typ, arg } => {
let builder_fn = if typ.kind() == ast::ScalarKind::Signed {
dr::Builder::s_mod
@@ -3017,18 +3058,6 @@ fn emit_function_body_ops<'input>(
)?;
builder.bitcast(b32_type, Some(arg.dst), dst_vector)?;
}
- ast::Instruction::Activemask { arg } => {
- let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32);
- let vec4_b32_type =
- map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B32, 4));
- let pred_true = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::Pred),
- &[1],
- )?;
- let dst_vector = builder.subgroup_ballot_khr(vec4_b32_type, None, pred_true)?;
- builder.composite_extract(b32_type, Some(arg.src), dst_vector, [0])?;
- }
ast::Instruction::Membar { level } => {
let (scope, semantics) = match level {
ast::MemScope::Cta => (
@@ -5293,6 +5322,44 @@ impl<'b> MutableNumericIdResolver<'b> {
}
}
+struct FunctionPointerDetails {
+ dst: spirv::Word,
+ src: spirv::Word,
+}
+
+impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitable<T, U>
+ for FunctionPointerDetails
+{
+ fn visit(
+ self,
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::FunctionPointer(FunctionPointerDetails {
+ dst: visitor.id(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::U64),
+ ast::StateSpace::Reg,
+ )),
+ )?,
+ src: visitor.id(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ None,
+ )?,
+ }))
+ }
+}
+
enum Statement<I, P: ast::ArgParams> {
Label(u32),
Variable(ast::Variable<P::Id>),
@@ -5307,6 +5374,7 @@ enum Statement<I, P: ast::ArgParams> {
RetValue(ast::RetData, spirv::Word),
PtrAccess(PtrAccess<P>),
RepackVector(RepackVectorDetails),
+ FunctionPointer(FunctionPointerDetails),
}
impl ExpandedStatement {
@@ -5399,6 +5467,12 @@ impl ExpandedStatement {
..repack
})
}
+ Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => {
+ Statement::FunctionPointer(FunctionPointerDetails {
+ dst: f(dst, true),
+ src: f(src, false),
+ })
+ }
}
}
}