diff options
author | Andrzej Janik <[email protected]> | 2020-11-18 23:01:15 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-11-18 23:01:15 +0100 |
commit | c8078071fd1eedebd94c9fd0ad81d65607bb8da1 (patch) | |
tree | b00ad66fbc5224c450e8a20889866cb5dbe891fc | |
parent | b0b0c21a9b5eb973d366dda6899e4a94442f923d (diff) | |
download | ZLUDA-c8078071fd1eedebd94c9fd0ad81d65607bb8da1.tar.gz ZLUDA-c8078071fd1eedebd94c9fd0ad81d65607bb8da1.zip |
Implement PtrAdd transformations
-rw-r--r-- | ptx/src/test/spirv_run/cvta.spvtxt | 85 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/reg_local.spvtxt | 7 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt | 116 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt | 94 | ||||
-rw-r--r-- | ptx/src/translate.rs | 485 |
5 files changed, 406 insertions, 381 deletions
diff --git a/ptx/src/test/spirv_run/cvta.spvtxt b/ptx/src/test/spirv_run/cvta.spvtxt index cf6ff8b..2e9f028 100644 --- a/ptx/src/test/spirv_run/cvta.spvtxt +++ b/ptx/src/test/spirv_run/cvta.spvtxt @@ -7,48 +7,61 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %27 = OpExtInstImport "OpenCL.std" + %39 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "cvta" %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %30 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %43 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar +%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar %float = OpTypeFloat 32 %_ptr_Function_float = OpTypePointer Function %float + %ulong = OpTypeInt 64 0 +%_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float - %1 = OpFunction %void None %30 - %7 = OpFunctionParameter %ulong - %8 = OpFunctionParameter %ulong - %25 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function + %1 = OpFunction %void None %43 + %19 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %20 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %37 = OpLabel + %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %7 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %8 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %6 = OpVariable %_ptr_Function_float Function - OpStore %2 %7 - OpStore %3 %8 - %9 = OpLoad %ulong %2 - OpStore %4 %9 - %10 = OpLoad %ulong %3 - OpStore %5 %10 - %12 = OpLoad %ulong %4 - %20 = OpCopyObject %ulong %12 - %19 = OpCopyObject %ulong %20 - %11 = OpCopyObject %ulong %19 - OpStore %4 %11 - %14 = OpLoad %ulong %5 - %22 = OpCopyObject %ulong %14 - %21 = OpCopyObject %ulong %22 - %13 = OpCopyObject %ulong %21 - OpStore %5 %13 - %16 = OpLoad %ulong %4 - %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %16 - %15 = OpLoad %float %23 - OpStore %6 %15 - %17 = OpLoad %ulong %5 - %18 = OpLoad %float %6 - %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %17 - OpStore %24 %18 + OpStore %2 %19 + OpStore %3 %20 + %10 = OpBitcast %_ptr_Function_ulong %2 + %9 = OpLoad %ulong %10 + %21 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %9 + OpStore %7 %21 + %12 = OpBitcast %_ptr_Function_ulong %3 + %11 = OpLoad %ulong %12 + %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %11 + OpStore %8 %22 + %23 = OpLoad %_ptr_CrossWorkgroup_uchar %7 + %14 = OpConvertPtrToU %ulong %23 + %32 = OpCopyObject %ulong %14 + %31 = OpCopyObject %ulong %32 + %13 = OpCopyObject %ulong %31 + %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %13 + OpStore %7 %24 + %25 = OpLoad %_ptr_CrossWorkgroup_uchar %8 + %16 = OpConvertPtrToU %ulong %25 + %34 = OpCopyObject %ulong %16 + %33 = OpCopyObject %ulong %34 + %15 = OpCopyObject %ulong %33 + %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15 + OpStore %8 %26 + %27 = OpLoad %_ptr_CrossWorkgroup_uchar %7 + %17 = OpConvertPtrToU %ulong %27 + %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %17 + %28 = OpLoad %float %35 + OpStore %6 %28 + %29 = OpLoad %_ptr_CrossWorkgroup_uchar %8 + %18 = OpConvertPtrToU %ulong %29 + %30 = OpLoad %float %6 + %36 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18 + OpStore %36 %30 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt index 10ff639..5ce3689 100644 --- a/ptx/src/test/spirv_run/reg_local.spvtxt +++ b/ptx/src/test/spirv_run/reg_local.spvtxt @@ -24,6 +24,7 @@ %ulong_1 = OpConstant %ulong 1 %_ptr_Generic_ulong = OpTypePointer Generic %ulong %ulong_0 = OpConstant %ulong 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_0_0 = OpConstant %ulong 0 %1 = OpFunction %void None %37 %8 = OpFunctionParameter %ulong @@ -52,9 +53,9 @@ %27 = OpBitcast %_ptr_Generic_ulong %4 OpStore %27 %19 %28 = OpBitcast %_ptr_Generic_ulong %4 - %46 = OpBitcast %ulong %28 - %47 = OpIAdd %ulong %46 %ulong_0 - %21 = OpBitcast %_ptr_Generic_ulong %47 + %47 = OpBitcast %_ptr_Generic_uchar %28 + %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_0 + %21 = OpBitcast %_ptr_Generic_ulong %48 %29 = OpLoad %ulong %21 %15 = OpCopyObject %ulong %29 OpStore %7 %15 diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt index 963d88a..bf56a18 100644 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt @@ -7,51 +7,85 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %30 = OpExtInstImport "OpenCL.std" + %51 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "ld_st_offset" + OpEntryPoint Kernel %1 "stateful_ld_st_ntid" %gl_LocalInvocationID + OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %33 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong %uint = OpTypeInt 32 0 + %v4uint = OpTypeVector %uint 4 +%_ptr_UniformConstant_v4uint = OpTypePointer UniformConstant %v4uint +%gl_LocalInvocationID = OpVariable %_ptr_UniformConstant_v4uint UniformConstant + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %58 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar +%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar %_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 - %ulong_4_0 = OpConstant %ulong 4 - %1 = OpFunction %void None %33 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %28 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function + %ulong = OpTypeInt 64 0 +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %1 = OpFunction %void None %58 + %22 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %23 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %49 = OpLabel + %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 - OpStore %4 %10 - %11 = OpLoad %ulong %3 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %13 - %12 = OpLoad %uint %24 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %21 = OpIAdd %ulong %15 %ulong_4 - %25 = OpConvertUToPtr %_ptr_Generic_uint %21 - %14 = OpLoad %uint %25 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %uint %7 - %26 = OpConvertUToPtr %_ptr_Generic_uint %16 - OpStore %26 %17 - %18 = OpLoad %ulong %5 - %19 = OpLoad %uint %6 - %23 = OpIAdd %ulong %18 %ulong_4_0 - %27 = OpConvertUToPtr %_ptr_Generic_uint %23 - OpStore %27 %19 + %7 = OpVariable %_ptr_Function_ulong Function + %8 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %22 + OpStore %3 %23 + %13 = OpBitcast %_ptr_Function_ulong %2 + %45 = OpLoad %ulong %13 + %12 = OpCopyObject %ulong %45 + %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %12 + OpStore %10 %24 + %15 = OpBitcast %_ptr_Function_ulong %3 + %46 = OpLoad %ulong %15 + %14 = OpCopyObject %ulong %46 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14 + OpStore %11 %25 + %26 = OpLoad %_ptr_CrossWorkgroup_uchar %10 + %17 = OpConvertPtrToU %ulong %26 + %16 = OpCopyObject %ulong %17 + %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %16 + OpStore %10 %27 + %28 = OpLoad %_ptr_CrossWorkgroup_uchar %11 + %19 = OpConvertPtrToU %ulong %28 + %18 = OpCopyObject %ulong %19 + %29 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %18 + OpStore %11 %29 + %31 = OpLoad %v4uint %gl_LocalInvocationID + %44 = OpCompositeExtract %uint %31 0 + %30 = OpCopyObject %uint %44 + OpStore %6 %30 + %33 = OpLoad %uint %6 + %63 = OpBitcast %uint %33 + %32 = OpUConvert %ulong %63 + OpStore %7 %32 + %35 = OpLoad %_ptr_CrossWorkgroup_uchar %10 + %36 = OpLoad %ulong %7 + %64 = OpBitcast %_ptr_CrossWorkgroup_uchar %35 + %65 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %64 %36 + %34 = OpBitcast %_ptr_CrossWorkgroup_uchar %65 + OpStore %10 %34 + %38 = OpLoad %_ptr_CrossWorkgroup_uchar %11 + %39 = OpLoad %ulong %7 + %66 = OpBitcast %_ptr_CrossWorkgroup_uchar %38 + %67 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %66 %39 + %37 = OpBitcast %_ptr_CrossWorkgroup_uchar %67 + OpStore %11 %37 + %40 = OpLoad %_ptr_CrossWorkgroup_uchar %10 + %20 = OpConvertPtrToU %ulong %40 + %47 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %20 + %41 = OpLoad %ulong %47 + OpStore %8 %41 + %42 = OpLoad %_ptr_CrossWorkgroup_uchar %11 + %21 = OpConvertPtrToU %ulong %42 + %43 = OpLoad %ulong %8 + %48 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21 + OpStore %48 %43 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt index 963d88a..1d28996 100644 --- a/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt +++ b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt @@ -7,51 +7,61 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %30 = OpExtInstImport "OpenCL.std" + %43 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "ld_st_offset" + OpEntryPoint Kernel %1 "stateful_ld_st_simple" %void = OpTypeVoid + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %47 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar +%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar %ulong = OpTypeInt 64 0 - %33 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 - %ulong_4_0 = OpConstant %ulong 4 - %1 = OpFunction %void None %33 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %28 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 - OpStore %4 %10 - %11 = OpLoad %ulong %3 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %13 - %12 = OpLoad %uint %24 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %21 = OpIAdd %ulong %15 %ulong_4 - %25 = OpConvertUToPtr %_ptr_Generic_uint %21 - %14 = OpLoad %uint %25 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %uint %7 - %26 = OpConvertUToPtr %_ptr_Generic_uint %16 - OpStore %26 %17 - %18 = OpLoad %ulong %5 - %19 = OpLoad %uint %6 - %23 = OpIAdd %ulong %18 %ulong_4_0 - %27 = OpConvertUToPtr %_ptr_Generic_uint %23 - OpStore %27 %19 +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %1 = OpFunction %void None %47 + %23 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %24 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar + %41 = OpLabel + %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %9 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %12 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %8 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %23 + OpStore %3 %24 + %14 = OpBitcast %_ptr_Function_ulong %2 + %13 = OpLoad %ulong %14 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %13 + OpStore %9 %25 + %16 = OpBitcast %_ptr_Function_ulong %3 + %15 = OpLoad %ulong %16 + %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15 + OpStore %11 %26 + %27 = OpLoad %_ptr_CrossWorkgroup_uchar %9 + %18 = OpConvertPtrToU %ulong %27 + %36 = OpCopyObject %ulong %18 + %35 = OpCopyObject %ulong %36 + %17 = OpCopyObject %ulong %35 + %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %17 + OpStore %12 %28 + %29 = OpLoad %_ptr_CrossWorkgroup_uchar %11 + %20 = OpConvertPtrToU %ulong %29 + %38 = OpCopyObject %ulong %20 + %37 = OpCopyObject %ulong %38 + %19 = OpCopyObject %ulong %37 + %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %19 + OpStore %10 %30 + %31 = OpLoad %_ptr_CrossWorkgroup_uchar %12 + %21 = OpConvertPtrToU %ulong %31 + %39 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21 + %32 = OpLoad %ulong %39 + OpStore %8 %32 + %33 = OpLoad %_ptr_CrossWorkgroup_uchar %10 + %22 = OpConvertPtrToU %ulong %33 + %34 = OpLoad %ulong %8 + %40 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %22 + OpStore %40 %34 OpReturn OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 29fa93e..ef2c1c5 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -971,7 +971,7 @@ fn compute_denorm_information<'input>( Statement::Undef(_, _) => {}
Statement::Label(_) => {}
Statement::Variable(_) => {}
- Statement::PtrAdd { .. } => {}
+ Statement::PtrAccess { .. } => {}
}
}
denorm_methods.insert(method_key, flush_counter);
@@ -1635,6 +1635,7 @@ fn convert_to_typed_statements( },
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
+ Statement::Conditional(c) => result.push(Statement::Conditional(c)),
_ => return Err(TranslateError::Unreachable),
}
}
@@ -1868,7 +1869,7 @@ fn normalize_labels( | Statement::Constant(_)
| Statement::Label(_)
| Statement::Undef(_, _)
- | Statement::PtrAdd { .. } => {}
+ | Statement::PtrAccess { .. } => {}
}
}
iter::once(Statement::Label(id_def.new_non_variable(None)))
@@ -2004,6 +2005,9 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::Conversion(conv) => {
insert_mem_ssa_statement_default(id_def, &mut result, conv)?
}
+ Statement::PtrAccess(ptr_access) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)?
+ }
s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
_ => return Err(TranslateError::Unreachable),
}
@@ -2159,55 +2163,11 @@ fn expand_arguments<'a, 'b>( name,
array_init,
})),
- Statement::PtrAdd {
- underlying_type,
- state_space,
- dst,
- ptr_src,
- offset_src: constant_src,
- } => {
+ Statement::PtrAccess(ptr_access) => {
let mut visitor = FlattenArguments::new(&mut result, id_def);
- let sema = match state_space {
- ast::LdStateSpace::Const
- | ast::LdStateSpace::Global
- | ast::LdStateSpace::Shared
- | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer,
- ast::LdStateSpace::Local | ast::LdStateSpace::Param => {
- ArgumentSemantics::RegisterPointer
- }
- };
- let ptr_type = ast::Type::Pointer(underlying_type.clone(), state_space);
- let new_dst = visitor.id(
- ArgumentDescriptor {
- op: dst,
- is_dst: true,
- sema,
- },
- Some(&ptr_type),
- )?;
- let new_ptr_src = visitor.id(
- ArgumentDescriptor {
- op: ptr_src,
- is_dst: false,
- sema,
- },
- Some(&ptr_type),
- )?;
- let new_constant_src = visitor.id(
- ArgumentDescriptor {
- op: constant_src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- Some(&ast::Type::Scalar(ast::ScalarType::S64)),
- )?;
- result.push(Statement::PtrAdd {
- underlying_type,
- state_space,
- dst: new_dst,
- ptr_src: new_ptr_src,
- offset_src: new_constant_src,
- })
+ let (new_inst, post_stmts) = (ptr_access.map(&mut visitor)?, visitor.post_stmts);
+ result.push(Statement::PtrAccess(new_inst));
+ result.extend(post_stmts);
}
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
@@ -2288,13 +2248,13 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { value: ast::ImmediateValue::S64(offset as i64),
}));
let dst = self.id_def.new_non_variable(typ.clone());
- self.func.push(Statement::PtrAdd {
+ self.func.push(Statement::PtrAccess(PtrAccess {
underlying_type: underlying_type.clone(),
state_space: *state_space,
dst,
ptr_src: reg,
offset_src: id_constant_stmt,
- });
+ }));
return Ok(dst);
} else {
add_type = self.id_def.get_typed(reg)?;
@@ -2577,13 +2537,13 @@ fn insert_implicit_conversions( should_bitcast_wrapper,
None,
)?,
- Statement::PtrAdd {
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
offset_src: constant_src,
- } => {
+ }) => {
let visit_desc = VisitArgumentDescriptor {
desc: ArgumentDescriptor {
op: ptr_src,
@@ -2591,12 +2551,14 @@ fn insert_implicit_conversions( sema: ArgumentSemantics::PhysicalPointer,
},
typ: &ast::Type::Pointer(underlying_type.clone(), state_space),
- stmt_ctor: |new_ptr_src| Statement::PtrAdd {
- underlying_type,
- state_space,
- dst,
- ptr_src: new_ptr_src,
- offset_src: constant_src,
+ stmt_ctor: |new_ptr_src| {
+ Statement::PtrAccess(PtrAccess {
+ underlying_type,
+ state_space,
+ dst,
+ ptr_src: new_ptr_src,
+ offset_src: constant_src,
+ })
},
};
insert_implicit_conversions_impl(
@@ -3224,13 +3186,13 @@ fn emit_function_body_ops( let result_type = map.get_or_add(builder, SpirvType::from(t.clone()));
builder.undef(result_type, Some(*id));
}
- Statement::PtrAdd {
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
offset_src,
- } => {
+ }) => {
let u8_pointer = map.get_or_add(
builder,
SpirvType::from(ast::Type::Pointer(
@@ -4243,6 +4205,7 @@ fn expand_map_variables<'a, 'b>( // TODO: don't convert to ptr if the register is not ultimately used for ld/st
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
// argument expansion
+// TODO: propagate through calls?
fn convert_to_stateful_memory_access<'a>(
func_args: &mut SpirvMethodDecl,
func_body: Vec<TypedStatement>,
@@ -4397,183 +4360,70 @@ fn convert_to_stateful_memory_access<'a>( result.push(Statement::Variable(var));
}
}
- Statement::Instruction(ast::Instruction::Cvta(
- ast::CvtaDetails {
- to: ast::CvtaStateSpace::Global,
- size: ast::CvtaSize::U64,
- from: ast::CvtaStateSpace::Generic,
- },
+ Statement::Instruction(ast::Instruction::Add(
+ ast::ArithDetails::Unsigned(ast::UIntType::U64),
arg,
- )) if is_cvta_ptr_direct(&remapped_ids, &arg) => {
- let new_dst = *remapped_ids.get(&arg.dst).unwrap();
- let new_src = *remapped_ids.get(&arg.src.underlying().unwrap()).unwrap();
- result.push(Statement::Instruction(ast::Instruction::Mov(
- ast::MovDetails {
- typ: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- src_is_address: false,
- dst_width: 0,
- src_width: 0,
- relaxed_src2_conv: false,
- },
- ast::Arg2Mov::Normal(ast::Arg2MovNormal {
- dst: ast::IdOrVector::Reg(new_dst),
- src: ast::OperandOrVector::Reg(new_src),
- }),
- )));
- }
- Statement::Instruction(ast::Instruction::Ld(
- details
- @
- ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- ..
- },
+ ))
+ | Statement::Instruction(ast::Instruction::Add(
+ ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::S64,
+ saturate: false,
+ }),
arg,
- )) if is_param_ld_ptr_direct(&remapped_ids, &func_args_ptr, &arg) => {
- let new_dst = if let ast::IdOrVector::Reg(dst) = arg.dst {
- *remapped_ids.get(&dst).unwrap()
- } else {
- return Err(TranslateError::Unreachable);
+ )) if is_add_ptr_direct(&remapped_ids, &arg) => {
+ let (ptr, offset) = match arg.src1.underlying() {
+ Some(src1) if remapped_ids.contains_key(src1) => {
+ (remapped_ids.get(src1).unwrap(), arg.src2)
+ }
+ Some(src2) if remapped_ids.contains_key(src2) => {
+ (remapped_ids.get(src2).unwrap(), arg.src1)
+ }
+ _ => return Err(TranslateError::Unreachable),
};
- result.push(Statement::Instruction(ast::Instruction::Ld(
- ast::LdDetails {
- typ: ast::LdStType::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- ..details
- },
- ast::Arg2Ld {
- src: arg.src,
- dst: ast::IdOrVector::Reg(new_dst),
- },
- )));
- }
- Statement::Instruction(ast::Instruction::Ld(
- details
- @
- ast::LdDetails {
+ result.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
state_space: ast::LdStateSpace::Global,
- ..
- },
- arg,
- )) if is_ldst_global_ptr_direct(&remapped_ids, &arg.src) => {
- let new_src = if let ast::Operand::Reg(src) = arg.src {
- *remapped_ids.get(&src).unwrap()
- } else {
- return Err(TranslateError::Unreachable);
- };
- result.push(Statement::Instruction(ast::Instruction::Ld(
- details,
- ast::Arg2Ld {
- src: ast::Operand::Reg(new_src),
- ..arg
- },
- )));
- }
- Statement::Instruction(ast::Instruction::St(
- details
- @
- ast::StData {
- state_space: ast::StStateSpace::Global,
- ..
- },
- arg,
- )) if is_ldst_global_ptr_direct(&remapped_ids, &arg.src1) => {
- let new_src1 = if let ast::Operand::Reg(src1) = arg.src1 {
- *remapped_ids.get(&src1).unwrap()
- } else {
- return Err(TranslateError::Unreachable);
- };
- result.push(Statement::Instruction(ast::Instruction::St(
- details,
- ast::Arg2St {
- src1: ast::Operand::Reg(new_src1),
- ..arg
- },
- )));
+ dst: *remapped_ids.get(&arg.dst).unwrap(),
+ ptr_src: *ptr,
+ offset_src: offset,
+ }))
}
Statement::Instruction(inst) => {
let mut post_statements = Vec::new();
let new_statement =
- inst.visit_variable(&mut |arg_desc: ArgumentDescriptor<spirv::Word>, typ| {
- Ok(match remapped_ids.get(&arg_desc.op) {
- Some(new_id) => {
- let (old_type_full, _) = id_defs.get_typed(arg_desc.op)?;
- let old_type = old_type_full.clone();
- let converting_id = id_defs.new_non_variable(Some(old_type_full));
- if arg_desc.is_dst {
- post_statements.push(Statement::Conversion(
- ImplicitConversion {
- src: converting_id,
- dst: *new_id,
- from: old_type,
- to: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- kind: ConversionKind::BitToPtr(
- ast::LdStateSpace::Global,
- ),
- src_sema: ArgumentSemantics::Default,
- dst_sema: arg_desc.sema,
- },
- ));
- converting_id
- } else {
- result.push(Statement::Conversion(ImplicitConversion {
- src: *new_id,
- dst: converting_id,
- from: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- to: old_type,
- kind: ConversionKind::PtrToBit(ast::UIntType::U64),
- src_sema: arg_desc.sema,
- dst_sema: ArgumentSemantics::Default,
- }));
- converting_id
- }
- }
- None => match func_args_ptr.get(&arg_desc.op) {
- Some(new_id) => {
- if arg_desc.is_dst {
- return Err(TranslateError::Unreachable);
- }
- let (old_type, _) = id_defs.get_typed(arg_desc.op)?;
- let old_type_clone = old_type.clone();
- let converting_id = id_defs.new_non_variable(Some(old_type));
- result.push(Statement::Conversion(ImplicitConversion {
- src: *new_id,
- dst: converting_id,
- from: ast::Type::Pointer(
- ast::PointerType::Pointer(
- ast::ScalarType::U8,
- ast::LdStateSpace::Global,
- ),
- ast::LdStateSpace::Param,
- ),
- to: old_type_clone,
- kind: ConversionKind::PtrToPtr { spirv_ptr: false },
- src_sema: arg_desc.sema,
- dst_sema: ArgumentSemantics::Default,
- }));
- converting_id
- }
- None => arg_desc.op,
- },
- })
+ inst.visit_variable(&mut |arg_desc: ArgumentDescriptor<spirv::Word>, _| {
+ convert_to_stateful_memory_access_postprocess(
+ id_defs,
+ &remapped_ids,
+ &func_args_ptr,
+ &mut result,
+ &mut post_statements,
+ arg_desc,
+ )
+ })?;
+ result.push(new_statement);
+ for s in post_statements {
+ result.push(s);
+ }
+ }
+ Statement::Call(call) => {
+ let mut post_statements = Vec::new();
+ let new_statement =
+ call.visit_variable(&mut |arg_desc: ArgumentDescriptor<spirv::Word>, _| {
+ convert_to_stateful_memory_access_postprocess(
+ id_defs,
+ &remapped_ids,
+ &func_args_ptr,
+ &mut result,
+ &mut post_statements,
+ arg_desc,
+ )
})?;
result.push(new_statement);
for s in post_statements {
result.push(s);
}
}
- Statement::Call(call) => todo!(),
_ => return Err(TranslateError::Unreachable),
}
}
@@ -4588,33 +4438,84 @@ fn convert_to_stateful_memory_access<'a>( Ok(result)
}
-fn is_param_ld_ptr_direct(
- remapped_ids: &HashMap<u32, u32>,
- func_args_ptr: &HashSet<u32>,
- arg: &ast::Arg2Ld<TypedArgParams>,
-) -> bool {
- match (arg.src.underlying(), &arg.dst) {
- (Some(src), ast::IdOrVector::Reg(dst)) => {
- func_args_ptr.contains(src) && remapped_ids.contains_key(dst)
+fn convert_to_stateful_memory_access_postprocess(
+ id_defs: &mut NumericIdResolver,
+ remapped_ids: &HashMap<spirv::Word, spirv::Word>,
+ func_args_ptr: &HashSet<spirv::Word>,
+ result: &mut Vec<TypedStatement>,
+ post_statements: &mut Vec<TypedStatement>,
+ arg_desc: ArgumentDescriptor<spirv::Word>,
+) -> Result<spirv::Word, TranslateError> {
+ Ok(match remapped_ids.get(&arg_desc.op) {
+ Some(new_id) => {
+ let (old_type_full, _) = id_defs.get_typed(arg_desc.op)?;
+ let old_type = old_type_full.clone();
+ let converting_id = id_defs.new_non_variable(Some(old_type_full));
+ if arg_desc.is_dst {
+ post_statements.push(Statement::Conversion(ImplicitConversion {
+ src: converting_id,
+ dst: *new_id,
+ from: old_type,
+ to: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global),
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: arg_desc.sema,
+ }));
+ converting_id
+ } else {
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: *new_id,
+ dst: converting_id,
+ from: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ to: old_type,
+ kind: ConversionKind::PtrToBit(ast::UIntType::U64),
+ src_sema: arg_desc.sema,
+ dst_sema: ArgumentSemantics::Default,
+ }));
+ converting_id
+ }
}
- _ => false,
- }
+ None => match func_args_ptr.get(&arg_desc.op) {
+ Some(new_id) => {
+ if arg_desc.is_dst {
+ return Err(TranslateError::Unreachable);
+ }
+ let (old_type, _) = id_defs.get_typed(arg_desc.op)?;
+ let old_type_clone = old_type.clone();
+ let converting_id = id_defs.new_non_variable(Some(old_type));
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: *new_id,
+ dst: converting_id,
+ from: ast::Type::Pointer(
+ ast::PointerType::Pointer(ast::ScalarType::U8, ast::LdStateSpace::Global),
+ ast::LdStateSpace::Param,
+ ),
+ to: old_type_clone,
+ kind: ConversionKind::PtrToPtr { spirv_ptr: false },
+ src_sema: arg_desc.sema,
+ dst_sema: ArgumentSemantics::Default,
+ }));
+ converting_id
+ }
+ None => arg_desc.op,
+ },
+ })
}
-fn is_cvta_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg2<TypedArgParams>) -> bool {
- match arg.src.underlying() {
- Some(src) => remapped_ids.contains_key(src) && remapped_ids.contains_key(&arg.dst),
- None => false,
+fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {
+ if !remapped_ids.contains_key(&arg.dst) {
+ return false;
}
-}
-
-fn is_ldst_global_ptr_direct(
- remapped_ids: &HashMap<u32, u32>,
- src: &ast::Operand<spirv::Word>,
-) -> bool {
- match src.underlying() {
- Some(src) => remapped_ids.contains_key(src),
- None => false,
+ match arg.src1.underlying() {
+ Some(src1) if remapped_ids.contains_key(src1) => true,
+ Some(src2) if remapped_ids.contains_key(src2) => true,
+ _ => false,
}
}
@@ -4960,13 +4861,7 @@ enum Statement<I, P: ast::ArgParams> { Constant(ConstantDefinition),
RetValue(ast::RetData, spirv::Word),
Undef(ast::Type, spirv::Word),
- PtrAdd {
- underlying_type: ast::PointerType,
- state_space: ast::LdStateSpace,
- dst: spirv::Word,
- ptr_src: spirv::Word,
- offset_src: spirv::Word,
- },
+ PtrAccess(PtrAccess<P>),
}
impl ExpandedStatement {
@@ -5035,23 +4930,23 @@ impl ExpandedStatement { let id = f(id, true);
Statement::Undef(typ, id)
}
- Statement::PtrAdd {
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
offset_src: constant_src,
- } => {
+ }) => {
let dst = f(dst, true);
let ptr_src = f(ptr_src, false);
let constant_src = f(constant_src, false);
- Statement::PtrAdd {
+ Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
dst,
ptr_src,
offset_src: constant_src,
- }
+ })
}
}
}
@@ -5156,6 +5051,70 @@ impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> { }
}
+impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
+ fn map<To: ArgParamsEx<Id = spirv::Word>, V: ArgumentMapVisitor<P, To>>(
+ self,
+ visitor: &mut V,
+ ) -> Result<PtrAccess<To>, TranslateError> {
+ let sema = match self.state_space {
+ ast::LdStateSpace::Const
+ | ast::LdStateSpace::Global
+ | ast::LdStateSpace::Shared
+ | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer,
+ ast::LdStateSpace::Local | ast::LdStateSpace::Param => {
+ ArgumentSemantics::RegisterPointer
+ }
+ };
+ let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), self.state_space);
+ let new_dst = visitor.id(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema,
+ },
+ Some(&ptr_type),
+ )?;
+ let new_ptr_src = visitor.id(
+ ArgumentDescriptor {
+ op: self.ptr_src,
+ is_dst: false,
+ sema,
+ },
+ Some(&ptr_type),
+ )?;
+ let new_constant_src = visitor.operand(
+ ArgumentDescriptor {
+ op: self.offset_src,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &ast::Type::Scalar(ast::ScalarType::S64),
+ )?;
+ Ok(PtrAccess {
+ underlying_type: self.underlying_type,
+ state_space: self.state_space,
+ dst: new_dst,
+ ptr_src: new_ptr_src,
+ offset_src: new_constant_src,
+ })
+ }
+}
+
+impl VisitVariable for PtrAccess<TypedArgParams> {
+ fn visit_variable<
+ 'a,
+ F: FnMut(
+ ArgumentDescriptor<spirv::Word>,
+ Option<&ast::Type>,
+ ) -> Result<spirv::Word, TranslateError>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<TypedStatement, TranslateError> {
+ Ok(Statement::PtrAccess(self.map(f)?))
+ }
+}
+
pub trait ArgParamsEx: ast::ArgParams + Sized {
fn get_fn_decl<'x, 'b>(
id: &Self::Id,
@@ -5451,6 +5410,14 @@ pub struct ArgumentDescriptor<Op> { sema: ArgumentSemantics,
}
+pub struct PtrAccess<P: ast::ArgParams> {
+ underlying_type: ast::PointerType,
+ state_space: ast::LdStateSpace,
+ dst: spirv::Word,
+ ptr_src: spirv::Word,
+ offset_src: P::Operand,
+}
+
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum ArgumentSemantics {
// normal register access
|