aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
Diffstat (limited to 'ptx')
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/test/spirv_run/non_scalar_ptr_offset.ptx22
-rw-r--r--ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt60
-rw-r--r--ptx/src/translate.rs19
4 files changed, 90 insertions, 12 deletions
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index c9ed9b1..ff48ae9 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -180,6 +180,7 @@ test_ptx!(
],
[0u32, 0u32, 0u32, 2u32]
);
+test_ptx!(non_scalar_ptr_offset, [1u32, 2u32, 3u32, 4u32], [7u32]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/non_scalar_ptr_offset.ptx b/ptx/src/test/spirv_run/non_scalar_ptr_offset.ptx
new file mode 100644
index 0000000..14d3d2c
--- /dev/null
+++ b/ptx/src/test/spirv_run/non_scalar_ptr_offset.ptx
@@ -0,0 +1,22 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry non_scalar_ptr_offset(
+ .param .u64 input_p,
+ .param .u64 output_p
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u32 x;
+ .reg .u32 y;
+
+ ld.param.u64 in_addr, [input_p];
+ ld.param.u64 out_addr, [output_p];
+
+ ld.global.v2.u32 {x,y}, [in_addr+8];
+ add.u32 x, x, y;
+ st.global.u32 [out_addr], x;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt b/ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt
new file mode 100644
index 0000000..92dc7cc
--- /dev/null
+++ b/ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt
@@ -0,0 +1,60 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %27 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "non_scalar_ptr_offset"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %30 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+ %ulong_8 = OpConstant %ulong 8
+ %v2uint = OpTypeVector %uint 2
+%_ptr_CrossWorkgroup_v2uint = OpTypePointer CrossWorkgroup %v2uint
+ %uchar = OpTypeInt 8 0
+%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
+%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
+ %1 = OpFunction %void None %30
+ %9 = OpFunctionParameter %ulong
+ %10 = OpFunctionParameter %ulong
+ %25 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ %7 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %9
+ OpStore %3 %10
+ %11 = OpLoad %ulong %2 Aligned 8
+ OpStore %4 %11
+ %12 = OpLoad %ulong %3 Aligned 8
+ OpStore %5 %12
+ %13 = OpLoad %ulong %4
+ %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_v2uint %13
+ %38 = OpBitcast %_ptr_CrossWorkgroup_uchar %23
+ %39 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %38 %ulong_8
+ %22 = OpBitcast %_ptr_CrossWorkgroup_v2uint %39
+ %8 = OpLoad %v2uint %22 Aligned 8
+ %14 = OpCompositeExtract %uint %8 0
+ %15 = OpCompositeExtract %uint %8 1
+ OpStore %6 %14
+ OpStore %7 %15
+ %17 = OpLoad %uint %6
+ %18 = OpLoad %uint %7
+ %16 = OpIAdd %uint %17 %18
+ OpStore %6 %16
+ %19 = OpLoad %ulong %5
+ %20 = OpLoad %uint %6
+ %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %19
+ OpStore %24 %20 Aligned 4
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index c2562c3..e0b82e8 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -2389,10 +2389,6 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
)));
Ok(id_add_result)
} else {
- let scalar_type = match typ {
- ast::Type::Scalar(underlying_type) => *underlying_type,
- _ => return Err(error_unreachable()),
- };
let id_constant_stmt = self.id_def.register_intermediate(
ast::Type::Scalar(ast::ScalarType::S64),
ast::StateSpace::Reg,
@@ -2404,7 +2400,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
}));
let dst = self.id_def.register_intermediate(typ.clone(), state_space);
self.func.push(Statement::PtrAccess(PtrAccess {
- underlying_type: scalar_type,
+ underlying_type: typ.clone(),
state_space: state_space,
dst,
ptr_src: reg,
@@ -3118,7 +3114,7 @@ fn emit_function_body_ops(
);
let result_type = map.get_or_add(
builder,
- SpirvType::new(ast::Type::Pointer(*underlying_type, *state_space)),
+ SpirvType::pointer_to(underlying_type.clone(), state_space.to_spirv()),
);
let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?;
let temp = builder.in_bounds_ptr_access_chain(
@@ -4532,7 +4528,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
};
let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::ScalarType::U8,
+ underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
@@ -4575,7 +4571,7 @@ fn convert_to_stateful_memory_access<'a, 'input>(
)));
let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::ScalarType::U8,
+ underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
@@ -5497,7 +5493,6 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
self,
visitor: &mut V,
) -> Result<PtrAccess<To>, TranslateError> {
- let ptr_type = ast::Type::Scalar(self.underlying_type.clone());
let new_dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
@@ -5505,7 +5500,7 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
is_memory_access: false,
non_default_implicit_conversion: None,
},
- Some((&ptr_type, self.state_space)),
+ Some((&self.underlying_type, self.state_space)),
)?;
let new_ptr_src = visitor.id(
ArgumentDescriptor {
@@ -5514,7 +5509,7 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
is_memory_access: false,
non_default_implicit_conversion: None,
},
- Some((&ptr_type, self.state_space)),
+ Some((&self.underlying_type, self.state_space)),
)?;
let new_constant_src = visitor.operand(
ArgumentDescriptor {
@@ -5707,7 +5702,7 @@ pub struct ArgumentDescriptor<Op> {
}
pub struct PtrAccess<P: ast::ArgParams> {
- underlying_type: ast::ScalarType,
+ underlying_type: ast::Type,
state_space: ast::StateSpace,
dst: spirv::Word,
ptr_src: spirv::Word,