diff options
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 1 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/non_scalar_ptr_offset.ptx | 22 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt | 60 | ||||
-rw-r--r-- | ptx/src/translate.rs | 19 |
4 files changed, 90 insertions, 12 deletions
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index c9ed9b1..ff48ae9 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -180,6 +180,7 @@ test_ptx!( ],
[0u32, 0u32, 0u32, 2u32]
);
+test_ptx!(non_scalar_ptr_offset, [1u32, 2u32, 3u32, 4u32], [7u32]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/non_scalar_ptr_offset.ptx b/ptx/src/test/spirv_run/non_scalar_ptr_offset.ptx new file mode 100644 index 0000000..14d3d2c --- /dev/null +++ b/ptx/src/test/spirv_run/non_scalar_ptr_offset.ptx @@ -0,0 +1,22 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry non_scalar_ptr_offset( + .param .u64 input_p, + .param .u64 output_p +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 x; + .reg .u32 y; + + ld.param.u64 in_addr, [input_p]; + ld.param.u64 out_addr, [output_p]; + + ld.global.v2.u32 {x,y}, [in_addr+8]; + add.u32 x, x, y; + st.global.u32 [out_addr], x; + ret; +} diff --git a/ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt b/ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt new file mode 100644 index 0000000..92dc7cc --- /dev/null +++ b/ptx/src/test/spirv_run/non_scalar_ptr_offset.spvtxt @@ -0,0 +1,60 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %27 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "non_scalar_ptr_offset" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %30 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint + %ulong_8 = OpConstant %ulong 8 + %v2uint = OpTypeVector %uint 2 +%_ptr_CrossWorkgroup_v2uint = OpTypePointer CrossWorkgroup %v2uint + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar +%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint + %1 = OpFunction %void None %30 + %9 = OpFunctionParameter %ulong + %10 = OpFunctionParameter %ulong + %25 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %2 %9 + OpStore %3 %10 + %11 = OpLoad %ulong %2 Aligned 8 + OpStore %4 %11 + %12 = OpLoad %ulong %3 Aligned 8 + OpStore %5 %12 + %13 = OpLoad %ulong %4 + %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_v2uint %13 + %38 = OpBitcast %_ptr_CrossWorkgroup_uchar %23 + %39 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %38 %ulong_8 + %22 = OpBitcast %_ptr_CrossWorkgroup_v2uint %39 + %8 = OpLoad %v2uint %22 Aligned 8 + %14 = OpCompositeExtract %uint %8 0 + %15 = OpCompositeExtract %uint %8 1 + OpStore %6 %14 + OpStore %7 %15 + %17 = OpLoad %uint %6 + %18 = OpLoad %uint %7 + %16 = OpIAdd %uint %17 %18 + OpStore %6 %16 + %19 = OpLoad %ulong %5 + %20 = OpLoad %uint %6 + %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %19 + OpStore %24 %20 Aligned 4 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index c2562c3..e0b82e8 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -2389,10 +2389,6 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { )));
Ok(id_add_result)
} else {
- let scalar_type = match typ {
- ast::Type::Scalar(underlying_type) => *underlying_type,
- _ => return Err(error_unreachable()),
- };
let id_constant_stmt = self.id_def.register_intermediate(
ast::Type::Scalar(ast::ScalarType::S64),
ast::StateSpace::Reg,
@@ -2404,7 +2400,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { }));
let dst = self.id_def.register_intermediate(typ.clone(), state_space);
self.func.push(Statement::PtrAccess(PtrAccess {
- underlying_type: scalar_type,
+ underlying_type: typ.clone(),
state_space: state_space,
dst,
ptr_src: reg,
@@ -3118,7 +3114,7 @@ fn emit_function_body_ops( );
let result_type = map.get_or_add(
builder,
- SpirvType::new(ast::Type::Pointer(*underlying_type, *state_space)),
+ SpirvType::pointer_to(underlying_type.clone(), state_space.to_spirv()),
);
let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?;
let temp = builder.in_bounds_ptr_access_chain(
@@ -4532,7 +4528,7 @@ fn convert_to_stateful_memory_access<'a, 'input>( };
let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::ScalarType::U8,
+ underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
@@ -4575,7 +4571,7 @@ fn convert_to_stateful_memory_access<'a, 'input>( )));
let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::ScalarType::U8,
+ underlying_type: ast::Type::Scalar(ast::ScalarType::U8),
state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
@@ -5497,7 +5493,6 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> { self,
visitor: &mut V,
) -> Result<PtrAccess<To>, TranslateError> {
- let ptr_type = ast::Type::Scalar(self.underlying_type.clone());
let new_dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
@@ -5505,7 +5500,7 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> { is_memory_access: false,
non_default_implicit_conversion: None,
},
- Some((&ptr_type, self.state_space)),
+ Some((&self.underlying_type, self.state_space)),
)?;
let new_ptr_src = visitor.id(
ArgumentDescriptor {
@@ -5514,7 +5509,7 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> { is_memory_access: false,
non_default_implicit_conversion: None,
},
- Some((&ptr_type, self.state_space)),
+ Some((&self.underlying_type, self.state_space)),
)?;
let new_constant_src = visitor.operand(
ArgumentDescriptor {
@@ -5707,7 +5702,7 @@ pub struct ArgumentDescriptor<Op> { }
pub struct PtrAccess<P: ast::ArgParams> {
- underlying_type: ast::ScalarType,
+ underlying_type: ast::Type,
state_space: ast::StateSpace,
dst: spirv::Word,
ptr_src: spirv::Word,
|