aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-08-02 01:04:05 +0200
committerAndrzej Janik <[email protected]>2021-08-02 01:04:05 +0200
commitb4de21fbc5eaf33540f1121bfe7c6ba0acaff6c9 (patch)
tree87de1e064acd8889d3f494a4d8e9a1f071e2edaa /ptx
parent4a71fefb8a3886277dba23a4ae17247bb5e2f2e5 (diff)
downloadZLUDA-b4de21fbc5eaf33540f1121bfe7c6ba0acaff6c9.tar.gz
ZLUDA-b4de21fbc5eaf33540f1121bfe7c6ba0acaff6c9.zip
Use calls to OpenCL builtins when translating sregs, do SPIRV->LLVM conversion on every build
Diffstat (limited to 'ptx')
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/test/spirv_run/ntid.spvtxt40
-rw-r--r--ptx/src/test/spirv_run/vector4.ptx22
-rw-r--r--ptx/src/test/spirv_run/vector4.spvtxt99
-rw-r--r--ptx/src/translate.rs162
5 files changed, 243 insertions, 81 deletions
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 226043f..d5bc8dd 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -61,6 +61,7 @@ test_ptx!(block, [1u64], [2u64]);
test_ptx!(local_align, [1u64], [1u64]);
test_ptx!(call, [1u64], [2u64]);
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
+test_ptx!(vector4, [1u32, 2u32, 3u32, 4u32], [4u32]);
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
test_ptx!(ntid, [3u32], [4u32]);
test_ptx!(reg_local, [12u64], [13u64]);
diff --git a/ptx/src/test/spirv_run/ntid.spvtxt b/ptx/src/test/spirv_run/ntid.spvtxt
index 7b5a630..e5f343c 100644
--- a/ptx/src/test/spirv_run/ntid.spvtxt
+++ b/ptx/src/test/spirv_run/ntid.spvtxt
@@ -7,24 +7,27 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %28 = OpExtInstImport "OpenCL.std"
+ %31 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %1 "ntid" %gl_WorkGroupSize
- OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
+ OpEntryPoint Kernel %1 "ntid"
+ OpExecutionMode %1 ContractionOff
+ OpDecorate %24 LinkageAttributes "get_local_size" Import
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
- %v3ulong = OpTypeVector %ulong 3
-%_ptr_Input_v3ulong = OpTypePointer Input %v3ulong
-%gl_WorkGroupSize = OpVariable %_ptr_Input_v3ulong Input
- %33 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
+ %35 = OpTypeFunction %ulong %uint
+ %36 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
- %1 = OpFunction %void None %33
+ %uint_0 = OpConstant %uint 0
+ %24 = OpFunction %ulong None %35
+ %26 = OpFunctionParameter %uint
+ OpFunctionEnd
+ %1 = OpFunction %void None %36
%9 = OpFunctionParameter %ulong
%10 = OpFunctionParameter %ulong
- %26 = OpLabel
+ %29 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
@@ -38,13 +41,12 @@
%12 = OpLoad %ulong %3 Aligned 8
OpStore %5 %12
%14 = OpLoad %ulong %4
- %24 = OpConvertUToPtr %_ptr_Generic_uint %14
- %13 = OpLoad %uint %24 Aligned 4
+ %27 = OpConvertUToPtr %_ptr_Generic_uint %14
+ %13 = OpLoad %uint %27 Aligned 4
OpStore %6 %13
- %38 = OpLoad %v3ulong %gl_WorkGroupSize
- %23 = OpCompositeExtract %ulong %38 0
- %39 = OpBitcast %ulong %23
- %16 = OpUConvert %uint %39
+ %23 = OpFunctionCall %ulong %24 %uint_0
+ %40 = OpBitcast %ulong %23
+ %16 = OpUConvert %uint %40
%15 = OpCopyObject %uint %16
OpStore %7 %15
%18 = OpLoad %uint %6
@@ -53,7 +55,7 @@
OpStore %6 %17
%20 = OpLoad %ulong %5
%21 = OpLoad %uint %6
- %25 = OpConvertUToPtr %_ptr_Generic_uint %20
- OpStore %25 %21 Aligned 4
+ %28 = OpConvertUToPtr %_ptr_Generic_uint %20
+ OpStore %28 %21 Aligned 4
OpReturn
- OpFunctionEnd
+ OpFunctionEnd \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/vector4.ptx b/ptx/src/test/spirv_run/vector4.ptx
new file mode 100644
index 0000000..d010b70
--- /dev/null
+++ b/ptx/src/test/spirv_run/vector4.ptx
@@ -0,0 +1,22 @@
+.version 6.5
+.target sm_60
+.address_size 64
+
+.visible .entry vector4(
+ .param .u64 input_p,
+ .param .u64 output_p
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .v4 .u32 temp;
+ .reg .u32 temp_scalar;
+
+ ld.param.u64 in_addr, [input_p];
+ ld.param.u64 out_addr, [output_p];
+
+ ld.v4.u32 temp, [in_addr];
+ mov.b32 temp_scalar, temp.w;
+ st.u32 [out_addr], temp_scalar;
+ ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/vector4.spvtxt b/ptx/src/test/spirv_run/vector4.spvtxt
new file mode 100644
index 0000000..8253bf9
--- /dev/null
+++ b/ptx/src/test/spirv_run/vector4.spvtxt
@@ -0,0 +1,99 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %51 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %25 "vector"
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %v2uint = OpTypeVector %uint 2
+ %55 = OpTypeFunction %v2uint %v2uint
+%_ptr_Function_v2uint = OpTypePointer Function %v2uint
+%_ptr_Function_uint = OpTypePointer Function %uint
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
+ %ulong = OpTypeInt 64 0
+ %67 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint
+ %1 = OpFunction %v2uint None %55
+ %7 = OpFunctionParameter %v2uint
+ %24 = OpLabel
+ %3 = OpVariable %_ptr_Function_v2uint Function
+ %2 = OpVariable %_ptr_Function_v2uint Function
+ %4 = OpVariable %_ptr_Function_v2uint Function
+ %5 = OpVariable %_ptr_Function_uint Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ OpStore %3 %7
+ %59 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_0
+ %9 = OpLoad %uint %59
+ %8 = OpCopyObject %uint %9
+ OpStore %5 %8
+ %61 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_1
+ %11 = OpLoad %uint %61
+ %10 = OpCopyObject %uint %11
+ OpStore %6 %10
+ %13 = OpLoad %uint %5
+ %14 = OpLoad %uint %6
+ %12 = OpIAdd %uint %13 %14
+ OpStore %6 %12
+ %16 = OpLoad %uint %6
+ %15 = OpCopyObject %uint %16
+ %62 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0
+ OpStore %62 %15
+ %18 = OpLoad %uint %6
+ %17 = OpCopyObject %uint %18
+ %63 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1
+ OpStore %63 %17
+ %64 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1
+ %20 = OpLoad %uint %64
+ %19 = OpCopyObject %uint %20
+ %65 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0
+ OpStore %65 %19
+ %22 = OpLoad %v2uint %4
+ %21 = OpCopyObject %v2uint %22
+ OpStore %2 %21
+ %23 = OpLoad %v2uint %2
+ OpReturnValue %23
+ OpFunctionEnd
+ %25 = OpFunction %void None %67
+ %34 = OpFunctionParameter %ulong
+ %35 = OpFunctionParameter %ulong
+ %49 = OpLabel
+ %26 = OpVariable %_ptr_Function_ulong Function
+ %27 = OpVariable %_ptr_Function_ulong Function
+ %28 = OpVariable %_ptr_Function_ulong Function
+ %29 = OpVariable %_ptr_Function_ulong Function
+ %30 = OpVariable %_ptr_Function_v2uint Function
+ %31 = OpVariable %_ptr_Function_uint Function
+ %32 = OpVariable %_ptr_Function_uint Function
+ %33 = OpVariable %_ptr_Function_ulong Function
+ OpStore %26 %34
+ OpStore %27 %35
+ %36 = OpLoad %ulong %26 Aligned 8
+ OpStore %28 %36
+ %37 = OpLoad %ulong %27 Aligned 8
+ OpStore %29 %37
+ %39 = OpLoad %ulong %28
+ %46 = OpConvertUToPtr %_ptr_Generic_v2uint %39
+ %38 = OpLoad %v2uint %46 Aligned 8
+ OpStore %30 %38
+ %41 = OpLoad %v2uint %30
+ %40 = OpFunctionCall %v2uint %1 %41
+ OpStore %30 %40
+ %43 = OpLoad %v2uint %30
+ %47 = OpBitcast %ulong %43
+ %42 = OpCopyObject %ulong %47
+ OpStore %33 %42
+ %44 = OpLoad %ulong %29
+ %45 = OpLoad %v2uint %30
+ %48 = OpConvertUToPtr %_ptr_Generic_v2uint %44
+ OpStore %48 %45 Aligned 8
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 5fea075..6c2c594 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -448,7 +448,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
let opencl_id = emit_opencl_import(&mut builder);
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
- emit_builtins(&mut builder, &mut map, &id_defs);
+ //emit_builtins(&mut builder, &mut map, &id_defs);
let mut kernel_info = HashMap::new();
let build_options = emit_denorm_build_string(&call_map, &denorm_information);
emit_directives(
@@ -1250,7 +1250,8 @@ fn to_ssa<'input, 'b>(
&mut numeric_id_defs,
&mut (*func_decl).borrow_mut(),
)?;
- let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?;
+ let ssa_statements =
+ fix_special_registers(ptx_impl_imports, ssa_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
@@ -1269,6 +1270,7 @@ fn to_ssa<'input, 'b>(
}
fn fix_special_registers(
+ ptx_impl_imports: &mut HashMap<String, Directive>,
typed_statements: Vec<TypedStatement>,
numeric_id_defs: &mut NumericIdResolver,
) -> Result<Vec<TypedStatement>, TranslateError> {
@@ -1276,7 +1278,6 @@ fn fix_special_registers(
for s in typed_statements {
match s {
Statement::LoadVar(
- mut
details
@
LoadVarDetails {
@@ -1285,48 +1286,53 @@ fn fix_special_registers(
},
) => {
let index = details.member_index.unwrap().0;
- if index == 3 {
- result.push(Statement::Constant(ConstantDefinition {
- dst: details.arg.dst,
- typ: ast::ScalarType::U32,
- value: ast::ImmediateValue::U64(0),
- }));
- } else {
- let sreg_and_type = match numeric_id_defs.special_registers.get(details.arg.src)
- {
- Some(reg) => get_sreg_id_scalar_type(numeric_id_defs, reg),
- None => None,
- };
- let (sreg_src, scalar_typ, vector_width) = match sreg_and_type {
- Some(sreg_and_type) => sreg_and_type,
- None => {
- result.push(Statement::LoadVar(details));
- continue;
- }
- };
- let temp_id = numeric_id_defs
- .register_intermediate(Some((details.typ.clone(), details.state_space)));
- let real_dst = details.arg.dst;
- details.arg.dst = temp_id;
- result.push(Statement::LoadVar(LoadVarDetails {
- arg: Arg2 {
- src: sreg_src,
- dst: temp_id,
- },
- state_space: ast::StateSpace::Sreg,
- typ: ast::Type::Scalar(scalar_typ),
- member_index: Some((index, Some(vector_width))),
- }));
- result.push(Statement::Conversion(ImplicitConversion {
- src: temp_id,
- dst: real_dst,
- from_type: ast::Type::Scalar(scalar_typ),
- from_space: ast::StateSpace::Sreg,
- to_type: ast::Type::Scalar(ast::ScalarType::U32),
- to_space: ast::StateSpace::Sreg,
- kind: ConversionKind::Default,
- }));
- }
+ let sreg = numeric_id_defs
+ .special_registers
+ .get(details.arg.src)
+ .ok_or_else(|| error_unreachable())?;
+ let (ocl_name, ocl_type) = sreg.get_opencl_fn_type();
+ let index_constant = numeric_id_defs.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
+ )));
+ result.push(Statement::Constant(ConstantDefinition {
+ dst: index_constant,
+ typ: ast::ScalarType::U32,
+ value: ast::ImmediateValue::U64(index as u64),
+ }));
+ let fn_result = numeric_id_defs.register_intermediate(Some((
+ ast::Type::Scalar(ocl_type),
+ ast::StateSpace::Reg,
+ )));
+ let return_arguments =
+ vec![(fn_result, ast::Type::Scalar(ocl_type), ast::StateSpace::Reg)];
+ let input_arguments = vec![(
+ TypedOperand::Reg(index_constant),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
+ )];
+ let fn_call = register_external_fn_call(
+ numeric_id_defs,
+ ptx_impl_imports,
+ ocl_name.to_string(),
+ return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
+ input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
+ )?;
+ result.push(Statement::Call(ResolvedCall {
+ uniform: false,
+ return_arguments,
+ name: fn_call,
+ input_arguments,
+ }));
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: fn_result,
+ dst: details.arg.dst,
+ from_type: ast::Type::Scalar(ocl_type),
+ from_space: ast::StateSpace::Reg,
+ to_type: ast::Type::Scalar(ast::ScalarType::U32),
+ to_space: ast::StateSpace::Reg,
+ kind: ConversionKind::Default,
+ }));
}
s => result.push(s),
}
@@ -1721,8 +1727,8 @@ fn instruction_to_fn_call(
id_defs,
ptx_impl_imports,
fn_name,
- return_arguments,
- input_arguments,
+ return_arguments.iter().map(|(_, typ, state)| (typ, *state)),
+ input_arguments.iter().map(|(_, typ, state)| (typ, *state)),
)?;
Ok(Statement::Call(ResolvedCall {
uniform: false,
@@ -1732,12 +1738,12 @@ fn instruction_to_fn_call(
}))
}
-fn register_external_fn_call(
+fn register_external_fn_call<'a>(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
name: String,
- return_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
- input_arguments: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
+ return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
+ input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
match ptx_impl_imports.entry(name) {
hash_map::Entry::Vacant(entry) => {
@@ -1770,19 +1776,18 @@ fn register_external_fn_call(
}
}
-fn fn_arguments_to_variables(
+fn fn_arguments_to_variables<'a>(
id_defs: &mut NumericIdResolver,
- args: &[(ArgumentDescriptor<spirv::Word>, ast::Type, ast::StateSpace)],
+ args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
) -> Vec<ast::Variable<spirv::Word>> {
- args.iter()
- .map(|(_, typ, space)| ast::Variable {
- align: None,
- v_type: typ.clone(),
- state_space: *space,
- name: id_defs.register_intermediate(None),
- array_init: Vec::new(),
- })
- .collect::<Vec<_>>()
+ args.map(|(typ, space)| ast::Variable {
+ align: None,
+ v_type: typ.clone(),
+ state_space: space,
+ name: id_defs.register_intermediate(None),
+ array_init: Vec::new(),
+ })
+ .collect::<Vec<_>>()
}
fn arguments_to_resolved_arguments(
@@ -2226,7 +2231,7 @@ fn expand_arguments<'a, 'b>(
Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
- Statement::Constant(_) => return Err(error_unreachable()),
+ Statement::Constant(c) => result.push(Statement::Constant(c)),
}
}
Ok(result)
@@ -4686,6 +4691,19 @@ impl PtxSpecialRegister {
}
}
+ fn get_scalar_type(self) -> ast::ScalarType {
+ match self {
+ PtxSpecialRegister::Tid
+ | PtxSpecialRegister::Ntid
+ | PtxSpecialRegister::Ctaid
+ | PtxSpecialRegister::Nctaid => ast::ScalarType::U32,
+ PtxSpecialRegister::Tid64
+ | PtxSpecialRegister::Ntid64
+ | PtxSpecialRegister::Ctaid64
+ | PtxSpecialRegister::Nctaid64 => ast::ScalarType::U64,
+ }
+ }
+
fn get_builtin(self) -> spirv::BuiltIn {
match self {
PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
@@ -4701,6 +4719,23 @@ impl PtxSpecialRegister {
}
}
+ fn get_opencl_fn_type(self) -> (&'static str, ast::ScalarType) {
+ match self {
+ PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
+ ("_Z12get_local_idj", ast::ScalarType::U64)
+ }
+ PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => {
+ ("_Z14get_local_sizej", ast::ScalarType::U64)
+ }
+ PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => {
+ ("_Z12get_group_idj", ast::ScalarType::U64)
+ }
+ PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => {
+ ("_Z14get_num_groupsj", ast::ScalarType::U64)
+ }
+ }
+ }
+
fn normalized_sreg_and_type(self) -> Option<(PtxSpecialRegister, ast::ScalarType, u8)> {
match self {
PtxSpecialRegister::Tid => Some((PtxSpecialRegister::Tid64, ast::ScalarType::U64, 3)),
@@ -4743,6 +4778,8 @@ impl SpecialRegistersMap {
}
fn interface(&self) -> Vec<spirv::Word> {
+ return Vec::new();
+ /*
self.reg_to_id
.iter()
.filter_map(|(sreg, id)| {
@@ -4753,6 +4790,7 @@ impl SpecialRegistersMap {
}
})
.collect::<Vec<_>>()
+ */
}
fn get(&self, id: spirv::Word) -> Option<PtxSpecialRegister> {