summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-11-18 23:01:15 +0100
committerAndrzej Janik <[email protected]>2020-11-18 23:01:15 +0100
commitc8078071fd1eedebd94c9fd0ad81d65607bb8da1 (patch)
treeb00ad66fbc5224c450e8a20889866cb5dbe891fc
parentb0b0c21a9b5eb973d366dda6899e4a94442f923d (diff)
downloadZLUDA-c8078071fd1eedebd94c9fd0ad81d65607bb8da1.tar.gz
ZLUDA-c8078071fd1eedebd94c9fd0ad81d65607bb8da1.zip
Implement PtrAdd transformations
-rw-r--r--ptx/src/test/spirv_run/cvta.spvtxt85
-rw-r--r--ptx/src/test/spirv_run/reg_local.spvtxt7
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt116
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt94
-rw-r--r--ptx/src/translate.rs485
5 files changed, 406 insertions, 381 deletions
diff --git a/ptx/src/test/spirv_run/cvta.spvtxt b/ptx/src/test/spirv_run/cvta.spvtxt
index cf6ff8b..2e9f028 100644
--- a/ptx/src/test/spirv_run/cvta.spvtxt
+++ b/ptx/src/test/spirv_run/cvta.spvtxt
@@ -7,48 +7,61 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %27 = OpExtInstImport "OpenCL.std"
+ %39 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "cvta"
%void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %30 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uchar = OpTypeInt 8 0
+%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
+ %43 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
+%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar
%float = OpTypeFloat 32
%_ptr_Function_float = OpTypePointer Function %float
+ %ulong = OpTypeInt 64 0
+%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
- %1 = OpFunction %void None %30
- %7 = OpFunctionParameter %ulong
- %8 = OpFunctionParameter %ulong
- %25 = OpLabel
- %2 = OpVariable %_ptr_Function_ulong Function
- %3 = OpVariable %_ptr_Function_ulong Function
- %4 = OpVariable %_ptr_Function_ulong Function
- %5 = OpVariable %_ptr_Function_ulong Function
+ %1 = OpFunction %void None %43
+ %19 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %20 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %37 = OpLabel
+ %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %7 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %8 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%6 = OpVariable %_ptr_Function_float Function
- OpStore %2 %7
- OpStore %3 %8
- %9 = OpLoad %ulong %2
- OpStore %4 %9
- %10 = OpLoad %ulong %3
- OpStore %5 %10
- %12 = OpLoad %ulong %4
- %20 = OpCopyObject %ulong %12
- %19 = OpCopyObject %ulong %20
- %11 = OpCopyObject %ulong %19
- OpStore %4 %11
- %14 = OpLoad %ulong %5
- %22 = OpCopyObject %ulong %14
- %21 = OpCopyObject %ulong %22
- %13 = OpCopyObject %ulong %21
- OpStore %5 %13
- %16 = OpLoad %ulong %4
- %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %16
- %15 = OpLoad %float %23
- OpStore %6 %15
- %17 = OpLoad %ulong %5
- %18 = OpLoad %float %6
- %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %17
- OpStore %24 %18
+ OpStore %2 %19
+ OpStore %3 %20
+ %10 = OpBitcast %_ptr_Function_ulong %2
+ %9 = OpLoad %ulong %10
+ %21 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %9
+ OpStore %7 %21
+ %12 = OpBitcast %_ptr_Function_ulong %3
+ %11 = OpLoad %ulong %12
+ %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %11
+ OpStore %8 %22
+ %23 = OpLoad %_ptr_CrossWorkgroup_uchar %7
+ %14 = OpConvertPtrToU %ulong %23
+ %32 = OpCopyObject %ulong %14
+ %31 = OpCopyObject %ulong %32
+ %13 = OpCopyObject %ulong %31
+ %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %13
+ OpStore %7 %24
+ %25 = OpLoad %_ptr_CrossWorkgroup_uchar %8
+ %16 = OpConvertPtrToU %ulong %25
+ %34 = OpCopyObject %ulong %16
+ %33 = OpCopyObject %ulong %34
+ %15 = OpCopyObject %ulong %33
+ %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15
+ OpStore %8 %26
+ %27 = OpLoad %_ptr_CrossWorkgroup_uchar %7
+ %17 = OpConvertPtrToU %ulong %27
+ %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %17
+ %28 = OpLoad %float %35
+ OpStore %6 %28
+ %29 = OpLoad %_ptr_CrossWorkgroup_uchar %8
+ %18 = OpConvertPtrToU %ulong %29
+ %30 = OpLoad %float %6
+ %36 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18
+ OpStore %36 %30
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt
index 10ff639..5ce3689 100644
--- a/ptx/src/test/spirv_run/reg_local.spvtxt
+++ b/ptx/src/test/spirv_run/reg_local.spvtxt
@@ -24,6 +24,7 @@
%ulong_1 = OpConstant %ulong 1
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_0 = OpConstant %ulong 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_0_0 = OpConstant %ulong 0
%1 = OpFunction %void None %37
%8 = OpFunctionParameter %ulong
@@ -52,9 +53,9 @@
%27 = OpBitcast %_ptr_Generic_ulong %4
OpStore %27 %19
%28 = OpBitcast %_ptr_Generic_ulong %4
- %46 = OpBitcast %ulong %28
- %47 = OpIAdd %ulong %46 %ulong_0
- %21 = OpBitcast %_ptr_Generic_ulong %47
+ %47 = OpBitcast %_ptr_Generic_uchar %28
+ %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_0
+ %21 = OpBitcast %_ptr_Generic_ulong %48
%29 = OpLoad %ulong %21
%15 = OpCopyObject %ulong %29
OpStore %7 %15
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt
index 963d88a..bf56a18 100644
--- a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt
+++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt
@@ -7,51 +7,85 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %30 = OpExtInstImport "OpenCL.std"
+ %51 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %1 "ld_st_offset"
+ OpEntryPoint Kernel %1 "stateful_ld_st_ntid" %gl_LocalInvocationID
+ OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
%void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %33 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
%uint = OpTypeInt 32 0
+ %v4uint = OpTypeVector %uint 4
+%_ptr_UniformConstant_v4uint = OpTypePointer UniformConstant %v4uint
+%gl_LocalInvocationID = OpVariable %_ptr_UniformConstant_v4uint UniformConstant
+ %uchar = OpTypeInt 8 0
+%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
+ %58 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
+%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar
%_ptr_Function_uint = OpTypePointer Function %uint
-%_ptr_Generic_uint = OpTypePointer Generic %uint
- %ulong_4 = OpConstant %ulong 4
- %ulong_4_0 = OpConstant %ulong 4
- %1 = OpFunction %void None %33
- %8 = OpFunctionParameter %ulong
- %9 = OpFunctionParameter %ulong
- %28 = OpLabel
- %2 = OpVariable %_ptr_Function_ulong Function
- %3 = OpVariable %_ptr_Function_ulong Function
- %4 = OpVariable %_ptr_Function_ulong Function
- %5 = OpVariable %_ptr_Function_ulong Function
+ %ulong = OpTypeInt 64 0
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
+ %1 = OpFunction %void None %58
+ %22 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %23 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %49 = OpLabel
+ %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%6 = OpVariable %_ptr_Function_uint Function
- %7 = OpVariable %_ptr_Function_uint Function
- OpStore %2 %8
- OpStore %3 %9
- %10 = OpLoad %ulong %2
- OpStore %4 %10
- %11 = OpLoad %ulong %3
- OpStore %5 %11
- %13 = OpLoad %ulong %4
- %24 = OpConvertUToPtr %_ptr_Generic_uint %13
- %12 = OpLoad %uint %24
- OpStore %6 %12
- %15 = OpLoad %ulong %4
- %21 = OpIAdd %ulong %15 %ulong_4
- %25 = OpConvertUToPtr %_ptr_Generic_uint %21
- %14 = OpLoad %uint %25
- OpStore %7 %14
- %16 = OpLoad %ulong %5
- %17 = OpLoad %uint %7
- %26 = OpConvertUToPtr %_ptr_Generic_uint %16
- OpStore %26 %17
- %18 = OpLoad %ulong %5
- %19 = OpLoad %uint %6
- %23 = OpIAdd %ulong %18 %ulong_4_0
- %27 = OpConvertUToPtr %_ptr_Generic_uint %23
- OpStore %27 %19
+ %7 = OpVariable %_ptr_Function_ulong Function
+ %8 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %22
+ OpStore %3 %23
+ %13 = OpBitcast %_ptr_Function_ulong %2
+ %45 = OpLoad %ulong %13
+ %12 = OpCopyObject %ulong %45
+ %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %12
+ OpStore %10 %24
+ %15 = OpBitcast %_ptr_Function_ulong %3
+ %46 = OpLoad %ulong %15
+ %14 = OpCopyObject %ulong %46
+ %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14
+ OpStore %11 %25
+ %26 = OpLoad %_ptr_CrossWorkgroup_uchar %10
+ %17 = OpConvertPtrToU %ulong %26
+ %16 = OpCopyObject %ulong %17
+ %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %16
+ OpStore %10 %27
+ %28 = OpLoad %_ptr_CrossWorkgroup_uchar %11
+ %19 = OpConvertPtrToU %ulong %28
+ %18 = OpCopyObject %ulong %19
+ %29 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %18
+ OpStore %11 %29
+ %31 = OpLoad %v4uint %gl_LocalInvocationID
+ %44 = OpCompositeExtract %uint %31 0
+ %30 = OpCopyObject %uint %44
+ OpStore %6 %30
+ %33 = OpLoad %uint %6
+ %63 = OpBitcast %uint %33
+ %32 = OpUConvert %ulong %63
+ OpStore %7 %32
+ %35 = OpLoad %_ptr_CrossWorkgroup_uchar %10
+ %36 = OpLoad %ulong %7
+ %64 = OpBitcast %_ptr_CrossWorkgroup_uchar %35
+ %65 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %64 %36
+ %34 = OpBitcast %_ptr_CrossWorkgroup_uchar %65
+ OpStore %10 %34
+ %38 = OpLoad %_ptr_CrossWorkgroup_uchar %11
+ %39 = OpLoad %ulong %7
+ %66 = OpBitcast %_ptr_CrossWorkgroup_uchar %38
+ %67 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %66 %39
+ %37 = OpBitcast %_ptr_CrossWorkgroup_uchar %67
+ OpStore %11 %37
+ %40 = OpLoad %_ptr_CrossWorkgroup_uchar %10
+ %20 = OpConvertPtrToU %ulong %40
+ %47 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %20
+ %41 = OpLoad %ulong %47
+ OpStore %8 %41
+ %42 = OpLoad %_ptr_CrossWorkgroup_uchar %11
+ %21 = OpConvertPtrToU %ulong %42
+ %43 = OpLoad %ulong %8
+ %48 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21
+ OpStore %48 %43
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt
index 963d88a..1d28996 100644
--- a/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt
+++ b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt
@@ -7,51 +7,61 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %30 = OpExtInstImport "OpenCL.std"
+ %43 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %1 "ld_st_offset"
+ OpEntryPoint Kernel %1 "stateful_ld_st_simple"
%void = OpTypeVoid
+ %uchar = OpTypeInt 8 0
+%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
+ %47 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
+%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar
%ulong = OpTypeInt 64 0
- %33 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
- %uint = OpTypeInt 32 0
-%_ptr_Function_uint = OpTypePointer Function %uint
-%_ptr_Generic_uint = OpTypePointer Generic %uint
- %ulong_4 = OpConstant %ulong 4
- %ulong_4_0 = OpConstant %ulong 4
- %1 = OpFunction %void None %33
- %8 = OpFunctionParameter %ulong
- %9 = OpFunctionParameter %ulong
- %28 = OpLabel
- %2 = OpVariable %_ptr_Function_ulong Function
- %3 = OpVariable %_ptr_Function_ulong Function
- %4 = OpVariable %_ptr_Function_ulong Function
- %5 = OpVariable %_ptr_Function_ulong Function
- %6 = OpVariable %_ptr_Function_uint Function
- %7 = OpVariable %_ptr_Function_uint Function
- OpStore %2 %8
- OpStore %3 %9
- %10 = OpLoad %ulong %2
- OpStore %4 %10
- %11 = OpLoad %ulong %3
- OpStore %5 %11
- %13 = OpLoad %ulong %4
- %24 = OpConvertUToPtr %_ptr_Generic_uint %13
- %12 = OpLoad %uint %24
- OpStore %6 %12
- %15 = OpLoad %ulong %4
- %21 = OpIAdd %ulong %15 %ulong_4
- %25 = OpConvertUToPtr %_ptr_Generic_uint %21
- %14 = OpLoad %uint %25
- OpStore %7 %14
- %16 = OpLoad %ulong %5
- %17 = OpLoad %uint %7
- %26 = OpConvertUToPtr %_ptr_Generic_uint %16
- OpStore %26 %17
- %18 = OpLoad %ulong %5
- %19 = OpLoad %uint %6
- %23 = OpIAdd %ulong %18 %ulong_4_0
- %27 = OpConvertUToPtr %_ptr_Generic_uint %23
- OpStore %27 %19
+%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
+ %1 = OpFunction %void None %47
+ %23 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %24 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
+ %41 = OpLabel
+ %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %9 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %12 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %8 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %23
+ OpStore %3 %24
+ %14 = OpBitcast %_ptr_Function_ulong %2
+ %13 = OpLoad %ulong %14
+ %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %13
+ OpStore %9 %25
+ %16 = OpBitcast %_ptr_Function_ulong %3
+ %15 = OpLoad %ulong %16
+ %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15
+ OpStore %11 %26
+ %27 = OpLoad %_ptr_CrossWorkgroup_uchar %9
+ %18 = OpConvertPtrToU %ulong %27
+ %36 = OpCopyObject %ulong %18
+ %35 = OpCopyObject %ulong %36
+ %17 = OpCopyObject %ulong %35
+ %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %17
+ OpStore %12 %28
+ %29 = OpLoad %_ptr_CrossWorkgroup_uchar %11
+ %20 = OpConvertPtrToU %ulong %29
+ %38 = OpCopyObject %ulong %20
+ %37 = OpCopyObject %ulong %38
+ %19 = OpCopyObject %ulong %37
+ %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %19
+ OpStore %10 %30
+ %31 = OpLoad %_ptr_CrossWorkgroup_uchar %12
+ %21 = OpConvertPtrToU %ulong %31
+ %39 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21
+ %32 = OpLoad %ulong %39
+ OpStore %8 %32
+ %33 = OpLoad %_ptr_CrossWorkgroup_uchar %10
+ %22 = OpConvertPtrToU %ulong %33
+ %34 = OpLoad %ulong %8
+ %40 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %22
+ OpStore %40 %34
OpReturn
OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 29fa93e..ef2c1c5 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -971,7 +971,7 @@ fn compute_denorm_information<'input>(
Statement::Undef(_, _) => {}
Statement::Label(_) => {}
Statement::Variable(_) => {}
- Statement::PtrAdd { .. } => {}
+ Statement::PtrAccess { .. } => {}
}
}
denorm_methods.insert(method_key, flush_counter);
@@ -1635,6 +1635,7 @@ fn convert_to_typed_statements(
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
+ Statement::Conditional(c) => result.push(Statement::Conditional(c)),
_ => return Err(TranslateError::Unreachable),
}
}
@@ -1868,7 +1869,7 @@ fn normalize_labels(
| Statement::Constant(_)
| Statement::Label(_)
| Statement::Undef(_, _)
- | Statement::PtrAdd { .. } => {}
+ | Statement::PtrAccess { .. } => {}
}
}
iter::once(Statement::Label(id_def.new_non_variable(None)))
@@ -2004,6 +2005,9 @@ fn insert_mem_ssa_statements<'a, 'b>(
Statement::Conversion(conv) => {
insert_mem_ssa_statement_default(id_def, &mut result, conv)?
}
+ Statement::PtrAccess(ptr_access) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)?
+ }
s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
_ => return Err(TranslateError::Unreachable),
}
@@ -2159,55 +2163,11 @@ fn expand_arguments<'a, 'b>(
name,
array_init,
})),
- Statement::PtrAdd {
- underlying_type,
- state_space,
- dst,
- ptr_src,
- offset_src: constant_src,
- } => {
+ Statement::PtrAccess(ptr_access) => {
let mut visitor = FlattenArguments::new(&mut result, id_def);
- let sema = match state_space {
- ast::LdStateSpace::Const
- | ast::LdStateSpace::Global
- | ast::LdStateSpace::Shared
- | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer,
- ast::LdStateSpace::Local | ast::LdStateSpace::Param => {
- ArgumentSemantics::RegisterPointer
- }
- };
- let ptr_type = ast::Type::Pointer(underlying_type.clone(), state_space);
- let new_dst = visitor.id(
- ArgumentDescriptor {
- op: dst,
- is_dst: true,
- sema,
- },
- Some(&ptr_type),
- )?;
- let new_ptr_src = visitor.id(
- ArgumentDescriptor {
- op: ptr_src,
- is_dst: false,
- sema,
- },
- Some(&ptr_type),
- )?;
- let new_constant_src = visitor.id(
- ArgumentDescriptor {
- op: constant_src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- Some(&ast::Type::Scalar(ast::ScalarType::S64)),
- )?;
- result.push(Statement::PtrAdd {
- underlying_type,
- state_space,
- dst: new_dst,
- ptr_src: new_ptr_src,
- offset_src: new_constant_src,
- })
+ let (new_inst, post_stmts) = (ptr_access.map(&mut visitor)?, visitor.post_stmts);
+ result.push(Statement::PtrAccess(new_inst));
+ result.extend(post_stmts);
}
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
@@ -2288,13 +2248,13 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
value: ast::ImmediateValue::S64(offset as i64),
}));
let dst = self.id_def.new_non_variable(typ.clone());
- self.func.push(Statement::PtrAdd {
+ self.func.push(Statement::PtrAccess(PtrAccess {
underlying_type: underlying_type.clone(),
state_space: *state_space,
dst,
ptr_src: reg,
offset_src: id_constant_stmt,
- });
+ }));
return Ok(dst);
} else {
add_type = self.id_def.get_typed(reg)?;
@@ -2577,13 +2537,13 @@ fn insert_implicit_conversions(
should_bitcast_wrapper,
None,
)?,
- Statement::PtrAdd {
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
offset_src: constant_src,
- } => {
+ }) => {
let visit_desc = VisitArgumentDescriptor {
desc: ArgumentDescriptor {
op: ptr_src,
@@ -2591,12 +2551,14 @@ fn insert_implicit_conversions(
sema: ArgumentSemantics::PhysicalPointer,
},
typ: &ast::Type::Pointer(underlying_type.clone(), state_space),
- stmt_ctor: |new_ptr_src| Statement::PtrAdd {
- underlying_type,
- state_space,
- dst,
- ptr_src: new_ptr_src,
- offset_src: constant_src,
+ stmt_ctor: |new_ptr_src| {
+ Statement::PtrAccess(PtrAccess {
+ underlying_type,
+ state_space,
+ dst,
+ ptr_src: new_ptr_src,
+ offset_src: constant_src,
+ })
},
};
insert_implicit_conversions_impl(
@@ -3224,13 +3186,13 @@ fn emit_function_body_ops(
let result_type = map.get_or_add(builder, SpirvType::from(t.clone()));
builder.undef(result_type, Some(*id));
}
- Statement::PtrAdd {
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
offset_src,
- } => {
+ }) => {
let u8_pointer = map.get_or_add(
builder,
SpirvType::from(ast::Type::Pointer(
@@ -4243,6 +4205,7 @@ fn expand_map_variables<'a, 'b>(
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
// argument expansion
+// TODO: propagate through calls?
fn convert_to_stateful_memory_access<'a>(
func_args: &mut SpirvMethodDecl,
func_body: Vec<TypedStatement>,
@@ -4397,183 +4360,70 @@ fn convert_to_stateful_memory_access<'a>(
result.push(Statement::Variable(var));
}
}
- Statement::Instruction(ast::Instruction::Cvta(
- ast::CvtaDetails {
- to: ast::CvtaStateSpace::Global,
- size: ast::CvtaSize::U64,
- from: ast::CvtaStateSpace::Generic,
- },
+ Statement::Instruction(ast::Instruction::Add(
+ ast::ArithDetails::Unsigned(ast::UIntType::U64),
arg,
- )) if is_cvta_ptr_direct(&remapped_ids, &arg) => {
- let new_dst = *remapped_ids.get(&arg.dst).unwrap();
- let new_src = *remapped_ids.get(&arg.src.underlying().unwrap()).unwrap();
- result.push(Statement::Instruction(ast::Instruction::Mov(
- ast::MovDetails {
- typ: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- src_is_address: false,
- dst_width: 0,
- src_width: 0,
- relaxed_src2_conv: false,
- },
- ast::Arg2Mov::Normal(ast::Arg2MovNormal {
- dst: ast::IdOrVector::Reg(new_dst),
- src: ast::OperandOrVector::Reg(new_src),
- }),
- )));
- }
- Statement::Instruction(ast::Instruction::Ld(
- details
- @
- ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- ..
- },
+ ))
+ | Statement::Instruction(ast::Instruction::Add(
+ ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::S64,
+ saturate: false,
+ }),
arg,
- )) if is_param_ld_ptr_direct(&remapped_ids, &func_args_ptr, &arg) => {
- let new_dst = if let ast::IdOrVector::Reg(dst) = arg.dst {
- *remapped_ids.get(&dst).unwrap()
- } else {
- return Err(TranslateError::Unreachable);
+ )) if is_add_ptr_direct(&remapped_ids, &arg) => {
+ let (ptr, offset) = match arg.src1.underlying() {
+ Some(src1) if remapped_ids.contains_key(src1) => {
+ (remapped_ids.get(src1).unwrap(), arg.src2)
+ }
+ Some(src2) if remapped_ids.contains_key(src2) => {
+ (remapped_ids.get(src2).unwrap(), arg.src1)
+ }
+ _ => return Err(TranslateError::Unreachable),
};
- result.push(Statement::Instruction(ast::Instruction::Ld(
- ast::LdDetails {
- typ: ast::LdStType::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- ..details
- },
- ast::Arg2Ld {
- src: arg.src,
- dst: ast::IdOrVector::Reg(new_dst),
- },
- )));
- }
- Statement::Instruction(ast::Instruction::Ld(
- details
- @
- ast::LdDetails {
+ result.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
state_space: ast::LdStateSpace::Global,
- ..
- },
- arg,
- )) if is_ldst_global_ptr_direct(&remapped_ids, &arg.src) => {
- let new_src = if let ast::Operand::Reg(src) = arg.src {
- *remapped_ids.get(&src).unwrap()
- } else {
- return Err(TranslateError::Unreachable);
- };
- result.push(Statement::Instruction(ast::Instruction::Ld(
- details,
- ast::Arg2Ld {
- src: ast::Operand::Reg(new_src),
- ..arg
- },
- )));
- }
- Statement::Instruction(ast::Instruction::St(
- details
- @
- ast::StData {
- state_space: ast::StStateSpace::Global,
- ..
- },
- arg,
- )) if is_ldst_global_ptr_direct(&remapped_ids, &arg.src1) => {
- let new_src1 = if let ast::Operand::Reg(src1) = arg.src1 {
- *remapped_ids.get(&src1).unwrap()
- } else {
- return Err(TranslateError::Unreachable);
- };
- result.push(Statement::Instruction(ast::Instruction::St(
- details,
- ast::Arg2St {
- src1: ast::Operand::Reg(new_src1),
- ..arg
- },
- )));
+ dst: *remapped_ids.get(&arg.dst).unwrap(),
+ ptr_src: *ptr,
+ offset_src: offset,
+ }))
}
Statement::Instruction(inst) => {
let mut post_statements = Vec::new();
let new_statement =
- inst.visit_variable(&mut |arg_desc: ArgumentDescriptor<spirv::Word>, typ| {
- Ok(match remapped_ids.get(&arg_desc.op) {
- Some(new_id) => {
- let (old_type_full, _) = id_defs.get_typed(arg_desc.op)?;
- let old_type = old_type_full.clone();
- let converting_id = id_defs.new_non_variable(Some(old_type_full));
- if arg_desc.is_dst {
- post_statements.push(Statement::Conversion(
- ImplicitConversion {
- src: converting_id,
- dst: *new_id,
- from: old_type,
- to: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- kind: ConversionKind::BitToPtr(
- ast::LdStateSpace::Global,
- ),
- src_sema: ArgumentSemantics::Default,
- dst_sema: arg_desc.sema,
- },
- ));
- converting_id
- } else {
- result.push(Statement::Conversion(ImplicitConversion {
- src: *new_id,
- dst: converting_id,
- from: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- to: old_type,
- kind: ConversionKind::PtrToBit(ast::UIntType::U64),
- src_sema: arg_desc.sema,
- dst_sema: ArgumentSemantics::Default,
- }));
- converting_id
- }
- }
- None => match func_args_ptr.get(&arg_desc.op) {
- Some(new_id) => {
- if arg_desc.is_dst {
- return Err(TranslateError::Unreachable);
- }
- let (old_type, _) = id_defs.get_typed(arg_desc.op)?;
- let old_type_clone = old_type.clone();
- let converting_id = id_defs.new_non_variable(Some(old_type));
- result.push(Statement::Conversion(ImplicitConversion {
- src: *new_id,
- dst: converting_id,
- from: ast::Type::Pointer(
- ast::PointerType::Pointer(
- ast::ScalarType::U8,
- ast::LdStateSpace::Global,
- ),
- ast::LdStateSpace::Param,
- ),
- to: old_type_clone,
- kind: ConversionKind::PtrToPtr { spirv_ptr: false },
- src_sema: arg_desc.sema,
- dst_sema: ArgumentSemantics::Default,
- }));
- converting_id
- }
- None => arg_desc.op,
- },
- })
+ inst.visit_variable(&mut |arg_desc: ArgumentDescriptor<spirv::Word>, _| {
+ convert_to_stateful_memory_access_postprocess(
+ id_defs,
+ &remapped_ids,
+ &func_args_ptr,
+ &mut result,
+ &mut post_statements,
+ arg_desc,
+ )
+ })?;
+ result.push(new_statement);
+ for s in post_statements {
+ result.push(s);
+ }
+ }
+ Statement::Call(call) => {
+ let mut post_statements = Vec::new();
+ let new_statement =
+ call.visit_variable(&mut |arg_desc: ArgumentDescriptor<spirv::Word>, _| {
+ convert_to_stateful_memory_access_postprocess(
+ id_defs,
+ &remapped_ids,
+ &func_args_ptr,
+ &mut result,
+ &mut post_statements,
+ arg_desc,
+ )
})?;
result.push(new_statement);
for s in post_statements {
result.push(s);
}
}
- Statement::Call(call) => todo!(),
_ => return Err(TranslateError::Unreachable),
}
}
@@ -4588,33 +4438,84 @@ fn convert_to_stateful_memory_access<'a>(
Ok(result)
}
-fn is_param_ld_ptr_direct(
- remapped_ids: &HashMap<u32, u32>,
- func_args_ptr: &HashSet<u32>,
- arg: &ast::Arg2Ld<TypedArgParams>,
-) -> bool {
- match (arg.src.underlying(), &arg.dst) {
- (Some(src), ast::IdOrVector::Reg(dst)) => {
- func_args_ptr.contains(src) && remapped_ids.contains_key(dst)
+fn convert_to_stateful_memory_access_postprocess(
+ id_defs: &mut NumericIdResolver,
+ remapped_ids: &HashMap<spirv::Word, spirv::Word>,
+ func_args_ptr: &HashSet<spirv::Word>,
+ result: &mut Vec<TypedStatement>,
+ post_statements: &mut Vec<TypedStatement>,
+ arg_desc: ArgumentDescriptor<spirv::Word>,
+) -> Result<spirv::Word, TranslateError> {
+ Ok(match remapped_ids.get(&arg_desc.op) {
+ Some(new_id) => {
+ let (old_type_full, _) = id_defs.get_typed(arg_desc.op)?;
+ let old_type = old_type_full.clone();
+ let converting_id = id_defs.new_non_variable(Some(old_type_full));
+ if arg_desc.is_dst {
+ post_statements.push(Statement::Conversion(ImplicitConversion {
+ src: converting_id,
+ dst: *new_id,
+ from: old_type,
+ to: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global),
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: arg_desc.sema,
+ }));
+ converting_id
+ } else {
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: *new_id,
+ dst: converting_id,
+ from: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ to: old_type,
+ kind: ConversionKind::PtrToBit(ast::UIntType::U64),
+ src_sema: arg_desc.sema,
+ dst_sema: ArgumentSemantics::Default,
+ }));
+ converting_id
+ }
}
- _ => false,
- }
+ None => match func_args_ptr.get(&arg_desc.op) {
+ Some(new_id) => {
+ if arg_desc.is_dst {
+ return Err(TranslateError::Unreachable);
+ }
+ let (old_type, _) = id_defs.get_typed(arg_desc.op)?;
+ let old_type_clone = old_type.clone();
+ let converting_id = id_defs.new_non_variable(Some(old_type));
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: *new_id,
+ dst: converting_id,
+ from: ast::Type::Pointer(
+ ast::PointerType::Pointer(ast::ScalarType::U8, ast::LdStateSpace::Global),
+ ast::LdStateSpace::Param,
+ ),
+ to: old_type_clone,
+ kind: ConversionKind::PtrToPtr { spirv_ptr: false },
+ src_sema: arg_desc.sema,
+ dst_sema: ArgumentSemantics::Default,
+ }));
+ converting_id
+ }
+ None => arg_desc.op,
+ },
+ })
}
-fn is_cvta_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg2<TypedArgParams>) -> bool {
- match arg.src.underlying() {
- Some(src) => remapped_ids.contains_key(src) && remapped_ids.contains_key(&arg.dst),
- None => false,
+fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {
+ if !remapped_ids.contains_key(&arg.dst) {
+ return false;
}
-}
-
-fn is_ldst_global_ptr_direct(
- remapped_ids: &HashMap<u32, u32>,
- src: &ast::Operand<spirv::Word>,
-) -> bool {
- match src.underlying() {
- Some(src) => remapped_ids.contains_key(src),
- None => false,
+ match arg.src1.underlying() {
+ Some(src1) if remapped_ids.contains_key(src1) => true,
+ Some(src2) if remapped_ids.contains_key(src2) => true,
+ _ => false,
}
}
@@ -4960,13 +4861,7 @@ enum Statement<I, P: ast::ArgParams> {
Constant(ConstantDefinition),
RetValue(ast::RetData, spirv::Word),
Undef(ast::Type, spirv::Word),
- PtrAdd {
- underlying_type: ast::PointerType,
- state_space: ast::LdStateSpace,
- dst: spirv::Word,
- ptr_src: spirv::Word,
- offset_src: spirv::Word,
- },
+ PtrAccess(PtrAccess<P>),
}
impl ExpandedStatement {
@@ -5035,23 +4930,23 @@ impl ExpandedStatement {
let id = f(id, true);
Statement::Undef(typ, id)
}
- Statement::PtrAdd {
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
offset_src: constant_src,
- } => {
+ }) => {
let dst = f(dst, true);
let ptr_src = f(ptr_src, false);
let constant_src = f(constant_src, false);
- Statement::PtrAdd {
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
offset_src: constant_src,
- }
+ })
}
}
}
@@ -5156,6 +5051,70 @@ impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
}
}
+impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
+ fn map<To: ArgParamsEx<Id = spirv::Word>, V: ArgumentMapVisitor<P, To>>(
+ self,
+ visitor: &mut V,
+ ) -> Result<PtrAccess<To>, TranslateError> {
+ let sema = match self.state_space {
+ ast::LdStateSpace::Const
+ | ast::LdStateSpace::Global
+ | ast::LdStateSpace::Shared
+ | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer,
+ ast::LdStateSpace::Local | ast::LdStateSpace::Param => {
+ ArgumentSemantics::RegisterPointer
+ }
+ };
+ let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), self.state_space);
+ let new_dst = visitor.id(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema,
+ },
+ Some(&ptr_type),
+ )?;
+ let new_ptr_src = visitor.id(
+ ArgumentDescriptor {
+ op: self.ptr_src,
+ is_dst: false,
+ sema,
+ },
+ Some(&ptr_type),
+ )?;
+ let new_constant_src = visitor.operand(
+ ArgumentDescriptor {
+ op: self.offset_src,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &ast::Type::Scalar(ast::ScalarType::S64),
+ )?;
+ Ok(PtrAccess {
+ underlying_type: self.underlying_type,
+ state_space: self.state_space,
+ dst: new_dst,
+ ptr_src: new_ptr_src,
+ offset_src: new_constant_src,
+ })
+ }
+}
+
+impl VisitVariable for PtrAccess<TypedArgParams> {
+ fn visit_variable<
+ 'a,
+ F: FnMut(
+ ArgumentDescriptor<spirv::Word>,
+ Option<&ast::Type>,
+ ) -> Result<spirv::Word, TranslateError>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<TypedStatement, TranslateError> {
+ Ok(Statement::PtrAccess(self.map(f)?))
+ }
+}
+
pub trait ArgParamsEx: ast::ArgParams + Sized {
fn get_fn_decl<'x, 'b>(
id: &Self::Id,
@@ -5451,6 +5410,14 @@ pub struct ArgumentDescriptor<Op> {
sema: ArgumentSemantics,
}
+pub struct PtrAccess<P: ast::ArgParams> {
+ underlying_type: ast::PointerType,
+ state_space: ast::LdStateSpace,
+ dst: spirv::Word,
+ ptr_src: spirv::Word,
+ offset_src: P::Operand,
+}
+
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum ArgumentSemantics {
// normal register access