diff options
author | Andrzej Janik <[email protected]> | 2020-10-01 00:44:58 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-10-01 18:11:57 +0200 |
commit | 3e92921275473e3dc028ff5159a17179af6047ba (patch) | |
tree | 1ecfe9c7ebe27785c2b132675224e3cf1de03631 | |
parent | 1e0b35be4bd52ea9273ba36adf224ec0a47eb7f8 (diff) | |
download | ZLUDA-3e92921275473e3dc028ff5159a17179af6047ba.tar.gz ZLUDA-3e92921275473e3dc028ff5159a17179af6047ba.zip |
Fix remaining bugs in vector destructuring and in the process improve implicit conversions
-rw-r--r-- | Cargo.toml | 4 | ||||
-rw-r--r-- | ptx/Cargo.toml | 2 | ||||
-rw-r--r-- | ptx/src/ast.rs | 11 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/ld_st_implicit.spvtxt | 16 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mul_wide.spvtxt | 10 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/vector_extract.ptx | 3 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/vector_extract.spvtxt | 199 | ||||
-rw-r--r-- | ptx/src/translate.rs | 673 |
8 files changed, 433 insertions, 485 deletions
@@ -11,5 +11,5 @@ members = [ ]
[patch.crates-io]
-rspirv = { git = 'https://github.com/vosen/rspirv', rev = '4523d54d785faff59c1e928dd1f210c531a70258' }
-spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '4523d54d785faff59c1e928dd1f210c531a70258' }
\ No newline at end of file +rspirv = { git = 'https://github.com/vosen/rspirv', rev = '0f5761918624f4a95107c14abe64946c5c5f60ce' }
+spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '0f5761918624f4a95107c14abe64946c5c5f60ce' }
\ No newline at end of file diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 42d60cb..96ab9d0 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -10,7 +10,7 @@ edition = "2018" lalrpop-util = "0.19" regex = "1" rspirv = "0.6" -spirv_headers = "1.4" +spirv_headers = "~1.4.2" quick-error = "1.2" bit-vec = "0.6" half ="1.6" diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 7edfa70..097e19c 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -463,14 +463,14 @@ pub enum CallOperand<ID> { pub enum IdOrVector<ID> { Reg(ID), - Vec(Vec<ID>) + Vec(Vec<ID>), } pub enum OperandOrVector<ID> { Reg(ID), RegOffset(ID, i32), Imm(u32), - Vec(Vec<ID>) + Vec(Vec<ID>), } impl<T> From<Operand<T>> for OperandOrVector<T> { @@ -536,6 +536,8 @@ pub struct MovDetails { // two fields below are in use by member moves pub dst_width: u8, pub src_width: u8, + // This is in use by auto-generated movs + pub relaxed_src2_conv: bool, } impl MovDetails { @@ -544,7 +546,8 @@ impl MovDetails { typ, src_is_address: false, dst_width: 0, - src_width: 0 + src_width: 0, + relaxed_src2_conv: false, } } } @@ -560,7 +563,7 @@ pub struct MulIntDesc { pub control: MulIntControl, } -#[derive(Copy, Clone)] +#[derive(Copy, Clone, PartialEq, Eq)] pub enum MulIntControl { Low, High, diff --git a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt index 249af90..d4d9499 100644 --- a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt +++ b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt @@ -2,8 +2,10 @@ OpCapability Linkage OpCapability Addresses OpCapability Kernel - OpCapability Int64 OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 OpCapability Float64 %23 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL @@ -33,17 +35,17 @@ %11 = OpCopyObject %ulong %12 OpStore %5 %11 %14 = OpLoad %ulong %4 - %17 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14 - %18 = OpLoad %float %17 - %31 = OpBitcast %uint %18 + %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14 + %17 = OpLoad %float %18 + %31 = OpBitcast %uint %17 %13 = OpUConvert %ulong %31 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %ulong %6 + %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15 %32 = OpBitcast %ulong %16 %33 = OpUConvert %uint %32 - %19 = OpBitcast %float %33 - %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15 - OpStore %20 %19 + %20 = OpBitcast %float %33 + OpStore %19 %20 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mul_wide.spvtxt b/ptx/src/test/spirv_run/mul_wide.spvtxt index 274612c..8ac0459 100644 --- a/ptx/src/test/spirv_run/mul_wide.spvtxt +++ b/ptx/src/test/spirv_run/mul_wide.spvtxt @@ -2,8 +2,10 @@ OpCapability Linkage OpCapability Addresses OpCapability Kernel - OpCapability Int64 OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 OpCapability Float64 %32 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL @@ -57,8 +59,8 @@ OpStore %8 %19 %22 = OpLoad %ulong %5 %23 = OpLoad %ulong %8 - %28 = OpCopyObject %ulong %23 - %29 = OpConvertUToPtr %_ptr_Generic_ulong %22 - OpStore %29 %28 + %28 = OpConvertUToPtr %_ptr_Generic_ulong %22 + %29 = OpCopyObject %ulong %23 + OpStore %28 %29 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector_extract.ptx b/ptx/src/test/spirv_run/vector_extract.ptx index 8624f8a..111f7c0 100644 --- a/ptx/src/test/spirv_run/vector_extract.ptx +++ b/ptx/src/test/spirv_run/vector_extract.ptx @@ -15,6 +15,9 @@ .reg .u16 temp4; .reg .v4.u16 foo; + ld.param.u64 in_addr, [input_p]; + ld.param.u64 out_addr, [output_p]; + ld.global.v4.u8 {temp1, temp2, temp3, temp4}, [in_addr]; mov.v4.u16 foo, {temp2, temp3, temp4, temp1}; mov.v4.u16 {temp3, temp4, temp1, temp2}, foo; diff --git a/ptx/src/test/spirv_run/vector_extract.spvtxt b/ptx/src/test/spirv_run/vector_extract.spvtxt index ff0ee97..45df3a8 100644 --- a/ptx/src/test/spirv_run/vector_extract.spvtxt +++ b/ptx/src/test/spirv_run/vector_extract.spvtxt @@ -2,96 +2,123 @@ OpCapability Linkage OpCapability Addresses OpCapability Kernel - OpCapability Int64 OpCapability Int8 - %60 = OpExtInstImport "OpenCL.std" + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %75 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %31 "vector" + OpEntryPoint Kernel %1 "vector_extract" %void = OpTypeVoid - %uint = OpTypeInt 32 0 - %v2uint = OpTypeVector %uint 2 - %64 = OpTypeFunction %v2uint %v2uint -%_ptr_Function_v2uint = OpTypePointer Function %v2uint -%_ptr_Function_uint = OpTypePointer Function %uint %ulong = OpTypeInt 64 0 - %68 = OpTypeFunction %void %ulong %ulong + %78 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint - %1 = OpFunction %v2uint None %64 - %7 = OpFunctionParameter %v2uint - %30 = OpLabel - %3 = OpVariable %_ptr_Function_v2uint Function - %2 = OpVariable %_ptr_Function_v2uint Function - %4 = OpVariable %_ptr_Function_v2uint Function - %5 = OpVariable %_ptr_Function_uint Function - %6 = OpVariable %_ptr_Function_uint Function - OpStore %3 %7 - %9 = OpLoad %v2uint %3 - %27 = OpCompositeExtract %uint %9 0 - %8 = OpCopyObject %uint %27 - OpStore %5 %8 - %11 = OpLoad %v2uint %3 - %28 = OpCompositeExtract %uint %11 1 - %10 = OpCopyObject %uint %28 - OpStore %6 %10 - %13 = OpLoad %uint %5 - %14 = OpLoad %uint %6 - %12 = OpIAdd %uint %13 %14 - OpStore %6 %12 - %16 = OpLoad %v2uint %4 - %17 = OpLoad %uint %6 - %15 = OpCompositeInsert %v2uint %17 %16 0 - OpStore %4 %15 - %19 = OpLoad %v2uint %4 - %20 = OpLoad %uint %6 - %18 = OpCompositeInsert %v2uint %20 %19 1 - OpStore %4 %18 - %22 = OpLoad %v2uint %4 - %23 = OpLoad %v2uint %4 - %29 = OpCompositeExtract %uint %23 1 - %21 = OpCompositeInsert %v2uint %29 %22 0 - OpStore %4 %21 - %25 = OpLoad %v2uint %4 - %24 = OpCopyObject %v2uint %25 - OpStore %2 %24 - %26 = OpLoad %v2uint %2 - OpReturnValue %26 - OpFunctionEnd - %31 = OpFunction %void None %68 - %40 = OpFunctionParameter %ulong - %41 = OpFunctionParameter %ulong - %58 = OpLabel - %32 = OpVariable %_ptr_Function_ulong Function - %33 = OpVariable %_ptr_Function_ulong Function - %34 = OpVariable %_ptr_Function_ulong Function - %35 = OpVariable %_ptr_Function_ulong Function - %36 = OpVariable %_ptr_Function_v2uint Function - %37 = OpVariable %_ptr_Function_uint Function - %38 = OpVariable %_ptr_Function_uint Function - %39 = OpVariable %_ptr_Function_ulong Function - OpStore %32 %40 - OpStore %33 %41 - %43 = OpLoad %ulong %32 - %42 = OpCopyObject %ulong %43 - OpStore %34 %42 - %45 = OpLoad %ulong %33 - %44 = OpCopyObject %ulong %45 - OpStore %35 %44 - %47 = OpLoad %ulong %34 - %54 = OpConvertUToPtr %_ptr_Generic_v2uint %47 - %46 = OpLoad %v2uint %54 - OpStore %36 %46 - %49 = OpLoad %v2uint %36 - %48 = OpFunctionCall %v2uint %1 %49 - OpStore %36 %48 - %51 = OpLoad %v2uint %36 - %55 = OpBitcast %ulong %51 - %56 = OpCopyObject %ulong %55 - %50 = OpCopyObject %ulong %56 - OpStore %39 %50 - %52 = OpLoad %ulong %35 - %53 = OpLoad %v2uint %36 - %57 = OpConvertUToPtr %_ptr_Generic_v2uint %52 - OpStore %57 %53 + %ushort = OpTypeInt 16 0 +%_ptr_Function_ushort = OpTypePointer Function %ushort + %v4ushort = OpTypeVector %ushort 4 +%_ptr_Function_v4ushort = OpTypePointer Function %v4ushort + %uchar = OpTypeInt 8 0 + %v4uchar = OpTypeVector %uchar 4 +%_ptr_CrossWorkgroup_v4uchar = OpTypePointer CrossWorkgroup %v4uchar + %1 = OpFunction %void None %78 + %11 = OpFunctionParameter %ulong + %12 = OpFunctionParameter %ulong + %73 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ushort Function + %7 = OpVariable %_ptr_Function_ushort Function + %8 = OpVariable %_ptr_Function_ushort Function + %9 = OpVariable %_ptr_Function_ushort Function + %10 = OpVariable %_ptr_Function_v4ushort Function + OpStore %2 %11 + OpStore %3 %12 + %14 = OpLoad %ulong %2 + %13 = OpCopyObject %ulong %14 + OpStore %4 %13 + %16 = OpLoad %ulong %3 + %15 = OpCopyObject %ulong %16 + OpStore %5 %15 + %21 = OpLoad %ulong %4 + %63 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %21 + %45 = OpLoad %v4uchar %63 + %64 = OpCompositeExtract %uchar %45 0 + %87 = OpBitcast %uchar %64 + %17 = OpUConvert %ushort %87 + %65 = OpCompositeExtract %uchar %45 1 + %88 = OpBitcast %uchar %65 + %18 = OpUConvert %ushort %88 + %66 = OpCompositeExtract %uchar %45 2 + %89 = OpBitcast %uchar %66 + %19 = OpUConvert %ushort %89 + %67 = OpCompositeExtract %uchar %45 3 + %90 = OpBitcast %uchar %67 + %20 = OpUConvert %ushort %90 + OpStore %6 %17 + OpStore %7 %18 + OpStore %8 %19 + OpStore %9 %20 + %23 = OpLoad %ushort %7 + %24 = OpLoad %ushort %8 + %25 = OpLoad %ushort %9 + %26 = OpLoad %ushort %6 + %46 = OpUndef %v4ushort + %47 = OpCompositeInsert %v4ushort %23 %46 0 + %48 = OpCompositeInsert %v4ushort %24 %47 1 + %49 = OpCompositeInsert %v4ushort %25 %48 2 + %50 = OpCompositeInsert %v4ushort %26 %49 3 + %22 = OpCopyObject %v4ushort %50 + OpStore %10 %22 + %31 = OpLoad %v4ushort %10 + %51 = OpCopyObject %v4ushort %31 + %27 = OpCompositeExtract %ushort %51 0 + %28 = OpCompositeExtract %ushort %51 1 + %29 = OpCompositeExtract %ushort %51 2 + %30 = OpCompositeExtract %ushort %51 3 + OpStore %8 %27 + OpStore %9 %28 + OpStore %6 %29 + OpStore %7 %30 + %36 = OpLoad %ushort %8 + %37 = OpLoad %ushort %9 + %38 = OpLoad %ushort %6 + %39 = OpLoad %ushort %7 + %53 = OpUndef %v4ushort + %54 = OpCompositeInsert %v4ushort %36 %53 0 + %55 = OpCompositeInsert %v4ushort %37 %54 1 + %56 = OpCompositeInsert %v4ushort %38 %55 2 + %57 = OpCompositeInsert %v4ushort %39 %56 3 + %52 = OpCopyObject %v4ushort %57 + %32 = OpCompositeExtract %ushort %52 0 + %33 = OpCompositeExtract %ushort %52 1 + %34 = OpCompositeExtract %ushort %52 2 + %35 = OpCompositeExtract %ushort %52 3 + OpStore %9 %32 + OpStore %6 %33 + OpStore %7 %34 + OpStore %8 %35 + %40 = OpLoad %ulong %5 + %41 = OpLoad %ushort %6 + %42 = OpLoad %ushort %7 + %43 = OpLoad %ushort %8 + %44 = OpLoad %ushort %9 + %58 = OpUndef %v4uchar + %91 = OpBitcast %ushort %41 + %68 = OpUConvert %uchar %91 + %59 = OpCompositeInsert %v4uchar %68 %58 0 + %92 = OpBitcast %ushort %42 + %69 = OpUConvert %uchar %92 + %60 = OpCompositeInsert %v4uchar %69 %59 1 + %93 = OpBitcast %ushort %43 + %70 = OpUConvert %uchar %93 + %61 = OpCompositeInsert %v4uchar %70 %60 2 + %94 = OpBitcast %ushort %44 + %71 = OpUConvert %uchar %94 + %62 = OpCompositeInsert %v4uchar %71 %61 3 + %72 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %40 + OpStore %72 %62 OpReturn OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 981da86..37cef00 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -843,6 +843,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( (_, ArgumentSemantics::Address) => return Ok(desc.op),
(t, ArgumentSemantics::RegisterPointer)
| (t, ArgumentSemantics::Default)
+ | (t, ArgumentSemantics::DefaultRelaxed)
| (t, ArgumentSemantics::PhysicalPointer) => t,
};
let generated_id = id_def.new_id(id_type);
@@ -933,17 +934,19 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn insert_composite_read(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver<'a>,
- (scalar_type, vec_len): (ast::ScalarType, u8),
+ typ: (ast::ScalarType, u8),
scalar_dst: Option<spirv::Word>,
+ scalar_sema_override: Option<ArgumentSemantics>,
composite_src: (spirv::Word, u8),
) -> spirv::Word {
- let new_id =
- scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Vector(scalar_type, vec_len)));
+ let new_id = scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Scalar(typ.0)));
func.push(Statement::Composite(CompositeRead {
- typ: scalar_type,
+ typ: typ.0,
dst: new_id,
+ dst_semantics_override: scalar_sema_override,
src_composite: composite_src.0,
src_index: composite_src.1 as u32,
+ src_len: typ.1 as u32,
}));
new_id
}
@@ -963,7 +966,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { ) -> Result<spirv::Word, TranslateError> {
let (reg, offset) = desc.op;
match desc.sema {
- ArgumentSemantics::Default => {
+ ArgumentSemantics::Default | ArgumentSemantics::DefaultRelaxed => {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
scalar
} else {
@@ -1049,18 +1052,19 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn member_src(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
- (scalar_type, vec_len): (ast::ScalarType, u8),
+ typ: (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
if desc.is_dst {
return Err(TranslateError::Unreachable);
}
- let new_id = self.id_def.new_id(ast::Type::Vector(scalar_type, vec_len));
- self.func.push(Statement::Composite(CompositeRead {
- typ: scalar_type,
- dst: new_id,
- src_composite: desc.op.0,
- src_index: desc.op.1 as u32,
- }));
+ let new_id = Self::insert_composite_read(
+ self.func,
+ self.id_def,
+ typ,
+ None,
+ Some(desc.sema),
+ desc.op,
+ );
Ok(new_id)
}
@@ -1077,10 +1081,11 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { let newer_id = self.id_def.new_id(typ);
self.func.push(Statement::Instruction(ast::Instruction::Mov(
ast::MovDetails {
- typ: typ,
+ typ: ast::Type::Scalar(scalar_type),
src_is_address: false,
- dst_width: 0,
- src_width: vec_len,
+ dst_width: vec_len,
+ src_width: 0,
+ relaxed_src2_conv: desc.sema == ArgumentSemantics::DefaultRelaxed,
},
ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
(newer_id, idx as u8),
@@ -1099,6 +1104,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { self.id_def,
(scalar_type, vec_len),
Some(*id),
+ Some(desc.sema),
(new_id, idx as u8),
);
}
@@ -1144,9 +1150,9 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr fn src_member_operand(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
- (scalar_type, vec_len): (ast::ScalarType, u8),
+ typ: (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
- self.member_src(desc, (scalar_type, vec_len))
+ self.member_src(desc, typ)
}
fn id_or_vector(
@@ -1195,123 +1201,41 @@ fn insert_implicit_conversions( let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
match s {
- Statement::Call(call) => insert_implicit_bitcasts(&mut result, id_def, call)?,
- Statement::Instruction(inst) => match inst {
- ast::Instruction::Ld(ld, arg) => {
- let pre_conv = get_implicit_conversions_ld_src(
- id_def,
- ld.typ,
- ld.state_space,
- arg.src,
- false,
- )?;
- let post_conv = get_implicit_conversions_ld_dst(
- id_def,
- ld.typ,
- arg.dst,
- should_convert_relaxed_dst,
- false,
- )?;
- insert_with_conversions(
- &mut result,
- id_def,
- arg,
- pre_conv.into_iter(),
- iter::empty(),
- post_conv.into_iter().collect(),
- |arg| &mut arg.src,
- |arg| &mut arg.dst,
- |arg| ast::Instruction::Ld(ld, arg),
- )
+ Statement::Call(call) => insert_implicit_conversions_impl(
+ &mut result,
+ id_def,
+ call,
+ should_bitcast_wrapper,
+ None,
+ )?,
+ Statement::Instruction(inst) => {
+ let mut default_conversion_fn = should_bitcast_wrapper
+ as fn(_, _, _) -> Result<Option<ConversionKind>, TranslateError>;
+ let mut state_space = None;
+ if let ast::Instruction::Ld(d, _) = &inst {
+ state_space = Some(d.state_space);
}
- ast::Instruction::St(st, arg) => {
- let pre_conv = get_implicit_conversions_ld_dst(
- id_def,
- st.typ,
- arg.src2,
- should_convert_relaxed_src,
- true,
- )?;
- let post_conv = get_implicit_conversions_ld_src(
- id_def,
- st.typ,
- st.state_space.to_ld_ss(),
- arg.src1,
- true,
- )?;
- let (pre_conv_dest, post_conv) = if st.state_space == ast::StStateSpace::Param
- || st.state_space == ast::StStateSpace::Local
- {
- (Vec::new(), post_conv)
- } else {
- (post_conv, Vec::new())
- };
- insert_with_conversions(
- &mut result,
- id_def,
- arg,
- pre_conv.into_iter(),
- pre_conv_dest.into_iter(),
- post_conv,
- |arg| &mut arg.src2,
- |arg| &mut arg.src1,
- |arg| ast::Instruction::St(st, arg),
- )
+ if let ast::Instruction::St(d, _) = &inst {
+ state_space = Some(d.state_space.to_ld_ss());
}
- ast::Instruction::Mov(d, ast::Arg2Mov::Normal(mut arg)) => {
- // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov-2
- // TODO: handle the case of mixed vector/scalar implicit conversions
- let inst_typ_is_bit = match d.typ {
- ast::Type::Scalar(t) => ast::ScalarType::from(t).kind() == ScalarKind::Bit,
- ast::Type::Vector(_, _) => false,
- ast::Type::Array(_, _) => false,
- };
- let mut did_vector_implicit = false;
- let mut post_conv = None;
- if inst_typ_is_bit {
- let src_type = id_def.get_typed(arg.src)?;
- if let ast::Type::Vector(_, _) = src_type {
- arg.src = insert_conversion_src(
- &mut result,
- id_def,
- arg.src,
- src_type,
- d.typ.into(),
- ConversionKind::Default,
- );
- did_vector_implicit = true;
- }
- let dst_type = id_def.get_typed(arg.dst)?;
- if let ast::Type::Vector(_, _) = dst_type {
- post_conv = Some(get_conversion_dst(
- id_def,
- &mut arg.dst,
- d.typ.into(),
- dst_type,
- ConversionKind::Default,
- ));
- did_vector_implicit = true;
- }
- }
- if did_vector_implicit {
- result.push(Statement::Instruction(ast::Instruction::Mov(
- d,
- ast::Arg2Mov::Normal(arg),
- )));
- } else {
- insert_implicit_bitcasts(
- &mut result,
- id_def,
- ast::Instruction::Mov(d, ast::Arg2Mov::Normal(arg)),
- )?;
- }
- if let Some(post_conv) = post_conv {
- result.push(post_conv);
- }
+ if let ast::Instruction::Mov(_, ast::Arg2Mov::Normal(_)) = &inst {
+ default_conversion_fn = should_bitcast_packed;
}
- inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst)?,
- },
- Statement::Composite(c) => insert_implicit_bitcasts(&mut result, id_def, c)?,
+ insert_implicit_conversions_impl(
+ &mut result,
+ id_def,
+ inst,
+ default_conversion_fn,
+ state_space,
+ )?;
+ }
+ Statement::Composite(composite) => insert_implicit_conversions_impl(
+ &mut result,
+ id_def,
+ composite,
+ should_bitcast_wrapper,
+ None,
+ )?,
s @ Statement::Conditional(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
@@ -1326,6 +1250,77 @@ fn insert_implicit_conversions( Ok(result)
}
+fn insert_implicit_conversions_impl(
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut MutableNumericIdResolver,
+ stmt: impl VisitVariableExpanded,
+ default_conversion_fn: fn(
+ ast::Type,
+ ast::Type,
+ Option<ast::LdStateSpace>,
+ ) -> Result<Option<ConversionKind>, TranslateError>,
+ state_space: Option<ast::LdStateSpace>,
+) -> Result<(), TranslateError> {
+ let mut post_conv = Vec::new();
+ let statement = stmt.visit_variable_extended(&mut |desc, typ| {
+ let instr_type = match typ {
+ None => return Ok(desc.op),
+ Some(t) => t,
+ };
+ let operand_type = id_def.get_typed(desc.op)?;
+ let mut conversion_fn = default_conversion_fn;
+ match desc.sema {
+ ArgumentSemantics::Default => {}
+ ArgumentSemantics::DefaultRelaxed => {
+ if desc.is_dst {
+ conversion_fn = should_convert_relaxed_dst_wrapper;
+ } else {
+ conversion_fn = should_convert_relaxed_src_wrapper;
+ }
+ }
+ ArgumentSemantics::PhysicalPointer => {
+ conversion_fn = bitcast_physical_pointer;
+ }
+ ArgumentSemantics::RegisterPointer => {
+ conversion_fn = force_bitcast;
+ }
+ ArgumentSemantics::Address => {
+ conversion_fn = force_bitcast_ptr_to_bit;
+ }
+ };
+ match conversion_fn(operand_type, instr_type, state_space)? {
+ Some(conv_kind) => {
+ let conv_output = if desc.is_dst {
+ &mut post_conv
+ } else {
+ &mut *func
+ };
+ let mut from = instr_type;
+ let mut to = operand_type;
+ let mut src = id_def.new_id(instr_type);
+ let mut dst = desc.op;
+ let result = Ok(src);
+ if !desc.is_dst {
+ mem::swap(&mut src, &mut dst);
+ mem::swap(&mut from, &mut to);
+ }
+ conv_output.push(Statement::Conversion(ImplicitConversion {
+ src,
+ dst,
+ from,
+ to,
+ kind: conv_kind,
+ }));
+ result
+ }
+ None => Ok(desc.op),
+ }
+ })?;
+ func.push(statement);
+ func.append(&mut post_conv);
+ Ok(())
+}
+
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -1505,7 +1500,11 @@ fn emit_function_body_ops( composite_src,
scalar_src,
)) => {
- let result_type = map.get_or_add(builder, SpirvType::from(d.typ));
+ let scalar_type = d.typ.get_scalar()?;
+ let result_type = map.get_or_add(
+ builder,
+ SpirvType::from(ast::Type::Vector(scalar_type, d.dst_width)),
+ );
let result_id = Some(dst.0);
builder.composite_insert(
result_type,
@@ -1545,8 +1544,8 @@ fn emit_function_body_ops( // Obviously, old and buggy one is used for compiling L0 SPIRV
// https://github.com/intel/intel-graphics-compiler/issues/148
let type_pred = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
- let const_true = builder.constant_true(type_pred);
- let const_false = builder.constant_false(type_pred);
+ let const_true = builder.constant_true(type_pred, None);
+ let const_false = builder.constant_false(type_pred, None);
builder.select(result_type, result_id, operand, const_false, const_true)
}
_ => builder.not(result_type, result_id, operand),
@@ -2700,12 +2699,9 @@ where fn src_member_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- (scalar_type, vec_len): (ast::ScalarType, u8),
+ (scalar_type, _): (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
- self(
- desc.new_op(desc.op),
- Some(ast::Type::Vector(scalar_type.into(), vec_len)),
- )
+ self(desc.new_op(desc.op), Some(ast::Type::Scalar(scalar_type)))
}
}
@@ -2793,6 +2789,8 @@ pub struct ArgumentDescriptor<Op> { pub enum ArgumentSemantics {
// normal register access
Default,
+ // normal register access with relaxed conversion rules (ld/st)
+ DefaultRelaxed,
// st/ld global
PhysicalPointer,
// st/ld .param, .local
@@ -2834,11 +2832,12 @@ impl<T: ArgParamsEx> ast::Instruction<T> { }
ast::Instruction::Mul(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type)?)
+ let is_wide = d.is_wide();
+ ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type, is_wide)?)
}
ast::Instruction::Add(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type)?)
+ ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type, false)?)
}
ast::Instruction::Setp(d, a) => {
let inst_type = d.typ;
@@ -2889,7 +2888,8 @@ impl<T: ArgParamsEx> ast::Instruction<T> { }
ast::Instruction::Mad(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Mad(d, a.map(visitor, inst_type)?)
+ let is_wide = d.is_wide();
+ ast::Instruction::Mad(d, a.map(visitor, inst_type, is_wide)?)
}
})
}
@@ -3004,6 +3004,27 @@ where }
impl ast::Type {
+ fn widen(self) -> Result<Self, TranslateError> {
+ match self {
+ ast::Type::Scalar(scalar) => {
+ let kind = scalar.kind();
+ let width = scalar.width();
+ if (kind != ScalarKind::Signed
+ && kind != ScalarKind::Unsigned
+ && kind != ScalarKind::Bit)
+ || (width == 8)
+ {
+ return Err(TranslateError::MismatchedType);
+ }
+ Ok(ast::Type::Scalar(ast::ScalarType::from_parts(
+ width * 2,
+ kind,
+ )))
+ }
+ _ => Err(TranslateError::Unreachable),
+ }
+ }
+
fn to_parts(self) -> TypeParts {
match self {
ast::Type::Scalar(scalar) => TypeParts {
@@ -3102,8 +3123,10 @@ type Arg2St = ast::Arg2St<ExpandedArgParams>; struct CompositeRead {
pub typ: ast::ScalarType,
pub dst: spirv::Word,
+ pub dst_semantics_override: Option<ArgumentSemantics>,
pub src_composite: spirv::Word,
pub src_index: u32,
+ pub src_len: u32,
}
impl VisitVariableExpanded for CompositeRead {
@@ -3116,12 +3139,15 @@ impl VisitVariableExpanded for CompositeRead { self,
f: &mut F,
) -> Result<ExpandedStatement, TranslateError> {
+ let dst_sema = self
+ .dst_semantics_override
+ .unwrap_or(ArgumentSemantics::Default);
Ok(Statement::Composite(CompositeRead {
dst: f(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ sema: dst_sema,
},
Some(ast::Type::Scalar(self.typ)),
)?,
@@ -3131,7 +3157,7 @@ impl VisitVariableExpanded for CompositeRead { is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(self.typ)),
+ Some(ast::Type::Vector(self.typ, self.src_len as u8)),
)?,
..self
}))
@@ -3328,7 +3354,7 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ sema: ArgumentSemantics::DefaultRelaxed,
},
t.into(),
)?;
@@ -3380,7 +3406,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> { ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ sema: ArgumentSemantics::DefaultRelaxed,
},
t,
)?;
@@ -3429,9 +3455,9 @@ impl<P: ArgParamsEx> ast::Arg2MovNormal<P> { op: self.src,
is_dst: false,
sema: if details.src_is_address {
- ArgumentSemantics::RegisterPointer
+ ArgumentSemantics::Address
} else {
- ArgumentSemantics::PhysicalPointer
+ ArgumentSemantics::Default
},
},
details.typ.into(),
@@ -3476,13 +3502,14 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> { ) -> Result<ast::Arg2MovMember<U>, TranslateError> {
match self {
ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => {
+ let scalar_type = details.typ.get_scalar()?;
let dst = visitor.id(
ArgumentDescriptor {
op: dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(details.typ.into()),
+ Some(ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let src1 = visitor.id(
ArgumentDescriptor {
@@ -3490,7 +3517,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> { is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(details.typ.into()),
+ Some(ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let src2 = visitor.id(
ArgumentDescriptor {
@@ -3498,6 +3525,8 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> { is_dst: false,
sema: if details.src_is_address {
ArgumentSemantics::Address
+ } else if details.relaxed_src2_conv {
+ ArgumentSemantics::DefaultRelaxed
} else {
ArgumentSemantics::Default
},
@@ -3527,13 +3556,14 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> { Ok(ast::Arg2MovMember::Src(dst, src))
}
ast::Arg2MovMember::Both((dst, len), composite_src, src) => {
+ let scalar_type = details.typ.get_scalar()?;
let dst = visitor.id(
ArgumentDescriptor {
op: dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(details.typ.into()),
+ Some(ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let composite_src = visitor.id(
ArgumentDescriptor {
@@ -3541,16 +3571,19 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> { is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(details.typ.into()),
+ Some(ast::Type::Vector(scalar_type, details.dst_width)),
)?;
- let scalar_typ = details.typ.get_scalar()?;
let src = visitor.src_member_operand(
ArgumentDescriptor {
op: src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ sema: if details.relaxed_src2_conv {
+ ArgumentSemantics::DefaultRelaxed
+ } else {
+ ArgumentSemantics::Default
+ },
},
- (scalar_typ.into(), details.src_width),
+ (scalar_type.into(), details.src_width),
)?;
Ok(ast::Arg2MovMember::Both((dst, len), composite_src, src))
}
@@ -3570,7 +3603,8 @@ impl<T: ArgParamsEx> ast::Arg3<T> { fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ typ: ast::Type,
+ is_wide: bool,
) -> Result<ast::Arg3<U>, TranslateError> {
let dst = visitor.id(
ArgumentDescriptor {
@@ -3578,7 +3612,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> { is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(t),
+ Some(if is_wide { typ.widen()? } else { typ }),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -3586,7 +3620,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> { is_dst: false,
sema: ArgumentSemantics::Default,
},
- t,
+ typ,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
@@ -3594,7 +3628,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> { is_dst: false,
sema: ArgumentSemantics::Default,
},
- t,
+ typ,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -3646,6 +3680,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> { self,
visitor: &mut V,
t: ast::Type,
+ is_wide: bool,
) -> Result<ast::Arg4<U>, TranslateError> {
let dst = visitor.id(
ArgumentDescriptor {
@@ -3653,7 +3688,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> { is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(t),
+ Some(if is_wide { t.widen()? } else { t }),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -4050,6 +4085,54 @@ impl<T> ast::OperandOrVector<T> { }
}
+impl ast::MulDetails {
+ fn is_wide(&self) -> bool {
+ match self {
+ ast::MulDetails::Int(desc) => desc.control == ast::MulIntControl::Wide,
+ ast::MulDetails::Float(_) => false,
+ }
+ }
+}
+
+fn force_bitcast(
+ operand: ast::Type,
+ instr: ast::Type,
+ _: Option<ast::LdStateSpace>,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if instr != operand {
+ Ok(Some(ConversionKind::Default))
+ } else {
+ Ok(None)
+ }
+}
+
+fn bitcast_physical_pointer(
+ operand_type: ast::Type,
+ _: ast::Type,
+ ss: Option<ast::LdStateSpace>,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ match operand_type {
+ ast::Type::Scalar(ast::ScalarType::B64)
+ | ast::Type::Scalar(ast::ScalarType::U64)
+ | ast::Type::Scalar(ast::ScalarType::S64) => {
+ if let Some(space) = ss {
+ Ok(Some(ConversionKind::BitToPtr(space)))
+ } else {
+ Err(TranslateError::Unreachable)
+ }
+ }
+ _ => Err(TranslateError::MismatchedType),
+ }
+}
+
+fn force_bitcast_ptr_to_bit(
+ _: ast::Type,
+ _: ast::Type,
+ _: Option<ast::LdStateSpace>,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ Ok(Some(ConversionKind::PtrToBit))
+}
+
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
@@ -4077,187 +4160,50 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { }
}
-fn insert_with_conversions<T, ToInstruction: FnOnce(T) -> ast::Instruction<ExpandedArgParams>>(
- func: &mut Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver,
- mut instr: T,
- pre_conv_src: impl ExactSizeIterator<Item = ImplicitConversion>,
- pre_conv_dst: impl ExactSizeIterator<Item = ImplicitConversion>,
- mut post_conv: Vec<ImplicitConversion>,
- mut src: impl FnMut(&mut T) -> &mut spirv::Word,
- mut dst: impl FnMut(&mut T) -> &mut spirv::Word,
- to_inst: ToInstruction,
-) {
- insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_src, &mut src);
- insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_dst, &mut dst);
- if post_conv.len() > 0 {
- let new_id = id_def.new_id(post_conv[0].from);
- post_conv[0].src = new_id;
- post_conv.last_mut().unwrap().dst = *dst(&mut instr);
- *dst(&mut instr) = new_id;
- }
- func.push(Statement::Instruction(to_inst(instr)));
- for conv in post_conv {
- func.push(Statement::Conversion(conv));
- }
-}
-
-fn insert_with_conversions_pre_conv<T>(
- func: &mut Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver,
- mut instr: &mut T,
- pre_conv: impl ExactSizeIterator<Item = ImplicitConversion>,
- src: &mut impl FnMut(&mut T) -> &mut spirv::Word,
-) {
- let pre_conv_len = pre_conv.len();
- for (i, mut conv) in pre_conv.enumerate() {
- let original_src = src(&mut instr);
- if i == 0 {
- conv.src = *original_src;
- }
- if i == pre_conv_len - 1 {
- let new_id = id_def.new_id(conv.to);
- conv.dst = new_id;
- *original_src = new_id;
+fn should_bitcast_packed(
+ operand: ast::Type,
+ instr: ast::Type,
+ ss: Option<ast::LdStateSpace>,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
+ (operand, instr)
+ {
+ if scalar.kind() == ScalarKind::Bit
+ && scalar.width() == (vec_underlying_type.width() * vec_len)
+ {
+ return Ok(Some(ConversionKind::Default));
}
- func.push(Statement::Conversion(conv));
}
+ should_bitcast_wrapper(operand, instr, ss)
}
-fn get_implicit_conversions_ld_dst<
- ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
->(
- id_def: &mut MutableNumericIdResolver,
- instr_type: ast::Type,
- dst: spirv::Word,
- should_convert: ShouldConvert,
- in_reverse: bool,
-) -> Result<Option<ImplicitConversion>, TranslateError> {
- let dst_type = id_def.get_typed(dst)?;
- if let Some(conv) = should_convert(dst_type, instr_type) {
- Ok(Some(ImplicitConversion {
- src: u32::max_value(),
- dst: u32::max_value(),
- from: if !in_reverse { instr_type } else { dst_type },
- to: if !in_reverse { dst_type } else { instr_type },
- kind: conv,
- }))
- } else {
- Ok(None)
+fn should_bitcast_wrapper(
+ operand: ast::Type,
+ instr: ast::Type,
+ _: Option<ast::LdStateSpace>,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if instr == operand {
+ return Ok(None);
}
-}
-
-fn get_implicit_conversions_ld_src(
- id_def: &mut MutableNumericIdResolver,
- instr_type: ast::Type,
- state_space: ast::LdStateSpace,
- src: spirv::Word,
- in_reverse_param_local: bool,
-) -> Result<Vec<ImplicitConversion>, TranslateError> {
- let src_type = id_def.get_typed(src)?;
- match state_space {
- ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
- if src_type != instr_type {
- Ok(vec![
- ImplicitConversion {
- src: u32::max_value(),
- dst: u32::max_value(),
- from: if !in_reverse_param_local {
- src_type
- } else {
- instr_type
- },
- to: if !in_reverse_param_local {
- instr_type
- } else {
- src_type
- },
- kind: ConversionKind::Default,
- };
- 1
- ])
- } else {
- Ok(Vec::new())
- }
- }
- ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
- let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
- mem::size_of::<usize>() as u8,
- ScalarKind::Bit,
- ));
- let mut result = Vec::new();
- // HACK ALERT
- // IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an
- // additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier
- // TODO: error out if the src is not B64/U64/S64
- if let ast::Type::Scalar(scalar_src_type) = src_type {
- if scalar_src_type.kind() == ScalarKind::Signed {
- result.push(ImplicitConversion {
- src: u32::max_value(),
- dst: u32::max_value(),
- from: src_type,
- to: new_src_type,
- kind: ConversionKind::Default,
- });
- }
- }
- result.push(ImplicitConversion {
- src: u32::max_value(),
- dst: u32::max_value(),
- from: src_type,
- to: instr_type,
- kind: ConversionKind::BitToPtr(state_space),
- });
- if result.len() == 2 {
- let new_id = id_def.new_id(new_src_type);
- result[0].dst = new_id;
- result[1].src = new_id;
- result[1].from = new_src_type;
- }
- Ok(result)
- }
- _ => Err(TranslateError::Todo),
+ if should_bitcast(instr, operand) {
+ Ok(Some(ConversionKind::Default))
+ } else {
+ Err(TranslateError::MismatchedType)
}
}
-#[must_use]
-fn insert_conversion_src(
- func: &mut Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver,
- src: spirv::Word,
+fn should_convert_relaxed_src_wrapper(
src_type: ast::Type,
instr_type: ast::Type,
- conv: ConversionKind,
-) -> spirv::Word {
- let temp_src = id_def.new_id(instr_type);
- func.push(Statement::Conversion(ImplicitConversion {
- src: src,
- dst: temp_src,
- from: src_type,
- to: instr_type,
- kind: conv,
- }));
- temp_src
-}
-
-#[must_use]
-fn get_conversion_dst(
- id_def: &mut MutableNumericIdResolver,
- dst: &mut spirv::Word,
- instr_type: ast::Type,
- dst_type: ast::Type,
- kind: ConversionKind,
-) -> ExpandedStatement {
- let original_dst = *dst;
- let temp_dst = id_def.new_id(instr_type);
- *dst = temp_dst;
- Statement::Conversion(ImplicitConversion {
- src: temp_dst,
- dst: original_dst,
- from: instr_type,
- to: dst_type,
- kind: kind,
- })
+ _: Option<ast::LdStateSpace>,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if src_type == instr_type {
+ return Ok(None);
+ }
+ match should_convert_relaxed_src(src_type, instr_type) {
+ conv @ Some(_) => Ok(conv),
+ None => Err(TranslateError::MismatchedType),
+ }
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
@@ -4302,6 +4248,20 @@ fn should_convert_relaxed_src( }
}
+fn should_convert_relaxed_dst_wrapper(
+ dst_type: ast::Type,
+ instr_type: ast::Type,
+ _: Option<ast::LdStateSpace>,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if dst_type == instr_type {
+ return Ok(None);
+ }
+ match should_convert_relaxed_dst(dst_type, instr_type) {
+ conv @ Some(_) => Ok(conv),
+ None => Err(TranslateError::MismatchedType),
+ }
+}
+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
dst_type: ast::Type,
@@ -4357,55 +4317,6 @@ fn should_convert_relaxed_dst( }
}
-fn insert_implicit_bitcasts(
- func: &mut Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver,
- stmt: impl VisitVariableExpanded,
-) -> Result<(), TranslateError> {
- let mut dst_coercion = None;
- let instr = stmt.visit_variable_extended(&mut |mut desc, typ| {
- let id_type_from_instr = match typ {
- Some(t) => t,
- None => return Ok(desc.op),
- };
- let id_actual_type = id_def.get_typed(desc.op)?;
- let conv_kind = if desc.sema == ArgumentSemantics::Address {
- Some(ConversionKind::PtrToBit)
- } else if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) {
- Some(ConversionKind::Default)
- } else {
- None
- };
- if let Some(conv_kind) = conv_kind {
- if desc.is_dst {
- dst_coercion = Some(get_conversion_dst(
- id_def,
- &mut desc.op,
- id_type_from_instr,
- id_actual_type,
- conv_kind,
- ));
- Ok(desc.op)
- } else {
- Ok(insert_conversion_src(
- func,
- id_def,
- desc.op,
- id_actual_type,
- id_type_from_instr,
- conv_kind,
- ))
- }
- } else {
- Ok(desc.op)
- }
- })?;
- func.push(instr);
- if let Some(cond) = dst_coercion {
- func.push(cond);
- }
- Ok(())
-}
impl<'a> ast::MethodDecl<'a, &'a str> {
fn name(&self) -> &'a str {
match self {
|