author     Andrzej Janik <[email protected]>  2020-10-01 00:44:58 +0200
committer  Andrzej Janik <[email protected]>  2020-10-01 18:11:57 +0200
commit     3e92921275473e3dc028ff5159a17179af6047ba (patch)
tree       1ecfe9c7ebe27785c2b132675224e3cf1de03631
parent     1e0b35be4bd52ea9273ba36adf224ec0a47eb7f8 (diff)
download   ZLUDA-3e92921275473e3dc028ff5159a17179af6047ba.tar.gz
           ZLUDA-3e92921275473e3dc028ff5159a17179af6047ba.zip
Fix remaining bugs in vector destructuring and in the process improve implicit conversions
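
Two things change here. First, vector destructuring (e.g. `ld.global.v4.u8 {t1, t2, t3, t4}, [addr]` followed by member movs) now emits correctly typed per-lane composite reads: `CompositeRead` carries the source vector length and an optional destination-semantics override, and each extracted lane gets a scalar id instead of a vector-typed one. Second, the ad-hoc per-instruction conversion helpers (`insert_implicit_bitcasts`, `get_implicit_conversions_ld_src`/`_dst`, `insert_with_conversions`) are replaced by a single `insert_implicit_conversions_impl` pass that selects a conversion rule from each operand's `ArgumentSemantics`. Below is a minimal, self-contained sketch of that dispatch, assuming simplified types; the semantics variants and rule names mirror the diff, while the rule bodies are stand-ins, not the translator's real logic.

```rust
// Sketch of the conversion-rule dispatch in the rewritten pass.
// ArgumentSemantics mirrors ptx/src/translate.rs; the rule functions are
// simplified stand-ins for should_bitcast_wrapper, should_convert_relaxed_*_wrapper,
// bitcast_physical_pointer, force_bitcast and force_bitcast_ptr_to_bit.

#[derive(Clone, Copy, PartialEq, Eq)]
enum ArgumentSemantics {
    Default,
    DefaultRelaxed, // new in this commit: ld/st-style relaxed checking for generated movs
    PhysicalPointer,
    RegisterPointer,
    Address,
}

#[derive(Debug)]
enum ConversionKind {
    Default,
    PtrToBit,
    BitToPtr,
}

type Type = &'static str; // placeholder for ast::Type
type ConversionFn = fn(Type, Type) -> Option<ConversionKind>;

fn bitcast_if_different(op: Type, instr: Type) -> Option<ConversionKind> {
    (op != instr).then_some(ConversionKind::Default)
}
fn relaxed_src(op: Type, instr: Type) -> Option<ConversionKind> {
    (op != instr).then_some(ConversionKind::Default) // stand-in for the PTX relaxed source-operand table
}
fn relaxed_dst(op: Type, instr: Type) -> Option<ConversionKind> {
    (op != instr).then_some(ConversionKind::Default) // stand-in for the PTX relaxed destination-operand table
}
fn bit_to_ptr(_: Type, _: Type) -> Option<ConversionKind> { Some(ConversionKind::BitToPtr) }
fn ptr_to_bit(_: Type, _: Type) -> Option<ConversionKind> { Some(ConversionKind::PtrToBit) }

fn pick_rule(sema: ArgumentSemantics, is_dst: bool, default_fn: ConversionFn) -> ConversionFn {
    match sema {
        ArgumentSemantics::Default => default_fn,
        ArgumentSemantics::DefaultRelaxed if is_dst => relaxed_dst,
        ArgumentSemantics::DefaultRelaxed => relaxed_src,
        ArgumentSemantics::PhysicalPointer => bit_to_ptr,
        ArgumentSemantics::RegisterPointer => bitcast_if_different,
        ArgumentSemantics::Address => ptr_to_bit,
    }
}

fn main() {
    // An auto-generated member mov marks its src2 as DefaultRelaxed, so a
    // u8 lane feeding a u16 register gets an implicit widening conversion.
    let rule = pick_rule(ArgumentSemantics::DefaultRelaxed, false, bitcast_if_different);
    println!("{:?}", rule("u8", "u16")); // Some(Default)
}
```
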
-rw-r--r--  Cargo.toml                                      4
-rw-r--r--  ptx/Cargo.toml                                  2
-rw-r--r--  ptx/src/ast.rs                                 11
-rw-r--r--  ptx/src/test/spirv_run/ld_st_implicit.spvtxt   16
-rw-r--r--  ptx/src/test/spirv_run/mul_wide.spvtxt         10
-rw-r--r--  ptx/src/test/spirv_run/vector_extract.ptx       3
-rw-r--r--  ptx/src/test/spirv_run/vector_extract.spvtxt  199
-rw-r--r--  ptx/src/translate.rs                          673
8 files changed, 433 insertions, 485 deletions
diff --git a/Cargo.toml b/Cargo.toml
index ed5d1f1..42be95a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,5 +11,5 @@ members = [
]
[patch.crates-io]
-rspirv = { git = 'https://github.com/vosen/rspirv', rev = '4523d54d785faff59c1e928dd1f210c531a70258' }
-spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '4523d54d785faff59c1e928dd1f210c531a70258' }
\ No newline at end of file
+rspirv = { git = 'https://github.com/vosen/rspirv', rev = '0f5761918624f4a95107c14abe64946c5c5f60ce' }
+spirv_headers = { git = 'https://github.com/vosen/rspirv', rev = '0f5761918624f4a95107c14abe64946c5c5f60ce' }
\ No newline at end of file
diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml
index 42d60cb..96ab9d0 100644
--- a/ptx/Cargo.toml
+++ b/ptx/Cargo.toml
@@ -10,7 +10,7 @@ edition = "2018"
lalrpop-util = "0.19"
regex = "1"
rspirv = "0.6"
-spirv_headers = "1.4"
+spirv_headers = "~1.4.2"
quick-error = "1.2"
bit-vec = "0.6"
half ="1.6"
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 7edfa70..097e19c 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -463,14 +463,14 @@ pub enum CallOperand<ID> {
pub enum IdOrVector<ID> {
Reg(ID),
- Vec(Vec<ID>)
+ Vec(Vec<ID>),
}
pub enum OperandOrVector<ID> {
Reg(ID),
RegOffset(ID, i32),
Imm(u32),
- Vec(Vec<ID>)
+ Vec(Vec<ID>),
}
impl<T> From<Operand<T>> for OperandOrVector<T> {
@@ -536,6 +536,8 @@ pub struct MovDetails {
// two fields below are in use by member moves
pub dst_width: u8,
pub src_width: u8,
+ // This is in use by auto-generated movs
+ pub relaxed_src2_conv: bool,
}
impl MovDetails {
@@ -544,7 +546,8 @@ impl MovDetails {
typ,
src_is_address: false,
dst_width: 0,
- src_width: 0
+ src_width: 0,
+ relaxed_src2_conv: false,
}
}
}
@@ -560,7 +563,7 @@ pub struct MulIntDesc {
pub control: MulIntControl,
}
-#[derive(Copy, Clone)]
+#[derive(Copy, Clone, PartialEq, Eq)]
pub enum MulIntControl {
Low,
High,
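
The only semantic addition to the AST is the `relaxed_src2_conv` flag on `MovDetails`: member movs that the translator generates while scalarizing a vector destructure mark their second source, so the conversion pass applies the relaxed ld/st rules to it rather than the strict mov rules. A small sketch of how such an auto-generated mov is described, with `MovDetails` mirrored from the diff (the `typ` field reduced to a string here) and the construction itself illustrative rather than an exact call site:

```rust
// Mirror of ast::MovDetails after this commit (simplified: Type reduced to a string).
struct MovDetails {
    typ: &'static str,       // instruction type, e.g. "u16"
    src_is_address: bool,
    dst_width: u8,           // vector width of the destination, used by member moves
    src_width: u8,           // vector width of the source, used by member moves
    relaxed_src2_conv: bool, // new: src2 follows relaxed (ld/st-style) conversion rules
}

impl MovDetails {
    fn new(typ: &'static str) -> Self {
        MovDetails {
            typ,
            src_is_address: false,
            dst_width: 0,
            src_width: 0,
            relaxed_src2_conv: false,
        }
    }
}

fn main() {
    // An auto-generated member mov produced while destructuring a 4-wide vector:
    // dst is one lane of a v4 composite, and src2 may need a relaxed widening.
    let auto_mov = MovDetails {
        dst_width: 4,
        relaxed_src2_conv: true,
        ..MovDetails::new("u16")
    };
    assert!(auto_mov.relaxed_src2_conv && auto_mov.dst_width == 4);
}
```
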
diff --git a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt
index 249af90..d4d9499 100644
--- a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt
+++ b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt
@@ -2,8 +2,10 @@
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
- OpCapability Int64
OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
OpCapability Float64
%23 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
@@ -33,17 +35,17 @@
%11 = OpCopyObject %ulong %12
OpStore %5 %11
%14 = OpLoad %ulong %4
- %17 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14
- %18 = OpLoad %float %17
- %31 = OpBitcast %uint %18
+ %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14
+ %17 = OpLoad %float %18
+ %31 = OpBitcast %uint %17
%13 = OpUConvert %ulong %31
OpStore %6 %13
%15 = OpLoad %ulong %5
%16 = OpLoad %ulong %6
+ %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15
%32 = OpBitcast %ulong %16
%33 = OpUConvert %uint %32
- %19 = OpBitcast %float %33
- %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15
- OpStore %20 %19
+ %20 = OpBitcast %float %33
+ OpStore %19 %20
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mul_wide.spvtxt b/ptx/src/test/spirv_run/mul_wide.spvtxt
index 274612c..8ac0459 100644
--- a/ptx/src/test/spirv_run/mul_wide.spvtxt
+++ b/ptx/src/test/spirv_run/mul_wide.spvtxt
@@ -2,8 +2,10 @@
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
- OpCapability Int64
OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
OpCapability Float64
%32 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
@@ -57,8 +59,8 @@
OpStore %8 %19
%22 = OpLoad %ulong %5
%23 = OpLoad %ulong %8
- %28 = OpCopyObject %ulong %23
- %29 = OpConvertUToPtr %_ptr_Generic_ulong %22
- OpStore %29 %28
+ %28 = OpConvertUToPtr %_ptr_Generic_ulong %22
+ %29 = OpCopyObject %ulong %23
+ OpStore %28 %29
OpReturn
OpFunctionEnd
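
Elsewhere in this commit, the destination of `mul.wide`/`mad.wide` is now typed through a new `ast::Type::widen()` helper (see the translate.rs diff below) rather than with the instruction type itself. A self-contained sketch of the widening rule, with a simplified `Scalar` standing in for `ast::ScalarType` and width measured in bytes as in the diff:

```rust
// Sketch mirroring ast::Type::widen(): a wide multiply's destination uses the
// instruction type with its width doubled; only integer and bit types qualify.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum ScalarKind { Signed, Unsigned, Bit, Float }

#[derive(Clone, Copy, Debug)]
struct Scalar { kind: ScalarKind, width: u8 } // width in bytes, as in ScalarType::width()

fn widen(t: Scalar) -> Result<Scalar, &'static str> {
    match t.kind {
        // 8-byte (64-bit) operands have no wider counterpart and are rejected
        ScalarKind::Signed | ScalarKind::Unsigned | ScalarKind::Bit if t.width != 8 => {
            Ok(Scalar { width: t.width * 2, ..t })
        }
        _ => Err("mismatched type: operand cannot be widened"),
    }
}

fn main() {
    // mul.wide.u32 d, a, b;  -- the destination d is effectively 64-bit
    let dst = widen(Scalar { kind: ScalarKind::Unsigned, width: 4 }).unwrap();
    assert_eq!(dst.width, 8);
    // floats and 64-bit integers are rejected
    assert!(widen(Scalar { kind: ScalarKind::Float, width: 4 }).is_err());
    assert!(widen(Scalar { kind: ScalarKind::Unsigned, width: 8 }).is_err());
}
```
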
diff --git a/ptx/src/test/spirv_run/vector_extract.ptx b/ptx/src/test/spirv_run/vector_extract.ptx
index 8624f8a..111f7c0 100644
--- a/ptx/src/test/spirv_run/vector_extract.ptx
+++ b/ptx/src/test/spirv_run/vector_extract.ptx
@@ -15,6 +15,9 @@
.reg .u16 temp4;
.reg .v4.u16 foo;
+ ld.param.u64 in_addr, [input_p];
+ ld.param.u64 out_addr, [output_p];
+
ld.global.v4.u8 {temp1, temp2, temp3, temp4}, [in_addr];
mov.v4.u16 foo, {temp2, temp3, temp4, temp1};
mov.v4.u16 {temp3, temp4, temp1, temp2}, foo;
diff --git a/ptx/src/test/spirv_run/vector_extract.spvtxt b/ptx/src/test/spirv_run/vector_extract.spvtxt
index ff0ee97..45df3a8 100644
--- a/ptx/src/test/spirv_run/vector_extract.spvtxt
+++ b/ptx/src/test/spirv_run/vector_extract.spvtxt
@@ -2,96 +2,123 @@
OpCapability Linkage
OpCapability Addresses
OpCapability Kernel
- OpCapability Int64
OpCapability Int8
- %60 = OpExtInstImport "OpenCL.std"
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %75 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %31 "vector"
+ OpEntryPoint Kernel %1 "vector_extract"
%void = OpTypeVoid
- %uint = OpTypeInt 32 0
- %v2uint = OpTypeVector %uint 2
- %64 = OpTypeFunction %v2uint %v2uint
-%_ptr_Function_v2uint = OpTypePointer Function %v2uint
-%_ptr_Function_uint = OpTypePointer Function %uint
%ulong = OpTypeInt 64 0
- %68 = OpTypeFunction %void %ulong %ulong
+ %78 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
-%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint
- %1 = OpFunction %v2uint None %64
- %7 = OpFunctionParameter %v2uint
- %30 = OpLabel
- %3 = OpVariable %_ptr_Function_v2uint Function
- %2 = OpVariable %_ptr_Function_v2uint Function
- %4 = OpVariable %_ptr_Function_v2uint Function
- %5 = OpVariable %_ptr_Function_uint Function
- %6 = OpVariable %_ptr_Function_uint Function
- OpStore %3 %7
- %9 = OpLoad %v2uint %3
- %27 = OpCompositeExtract %uint %9 0
- %8 = OpCopyObject %uint %27
- OpStore %5 %8
- %11 = OpLoad %v2uint %3
- %28 = OpCompositeExtract %uint %11 1
- %10 = OpCopyObject %uint %28
- OpStore %6 %10
- %13 = OpLoad %uint %5
- %14 = OpLoad %uint %6
- %12 = OpIAdd %uint %13 %14
- OpStore %6 %12
- %16 = OpLoad %v2uint %4
- %17 = OpLoad %uint %6
- %15 = OpCompositeInsert %v2uint %17 %16 0
- OpStore %4 %15
- %19 = OpLoad %v2uint %4
- %20 = OpLoad %uint %6
- %18 = OpCompositeInsert %v2uint %20 %19 1
- OpStore %4 %18
- %22 = OpLoad %v2uint %4
- %23 = OpLoad %v2uint %4
- %29 = OpCompositeExtract %uint %23 1
- %21 = OpCompositeInsert %v2uint %29 %22 0
- OpStore %4 %21
- %25 = OpLoad %v2uint %4
- %24 = OpCopyObject %v2uint %25
- OpStore %2 %24
- %26 = OpLoad %v2uint %2
- OpReturnValue %26
- OpFunctionEnd
- %31 = OpFunction %void None %68
- %40 = OpFunctionParameter %ulong
- %41 = OpFunctionParameter %ulong
- %58 = OpLabel
- %32 = OpVariable %_ptr_Function_ulong Function
- %33 = OpVariable %_ptr_Function_ulong Function
- %34 = OpVariable %_ptr_Function_ulong Function
- %35 = OpVariable %_ptr_Function_ulong Function
- %36 = OpVariable %_ptr_Function_v2uint Function
- %37 = OpVariable %_ptr_Function_uint Function
- %38 = OpVariable %_ptr_Function_uint Function
- %39 = OpVariable %_ptr_Function_ulong Function
- OpStore %32 %40
- OpStore %33 %41
- %43 = OpLoad %ulong %32
- %42 = OpCopyObject %ulong %43
- OpStore %34 %42
- %45 = OpLoad %ulong %33
- %44 = OpCopyObject %ulong %45
- OpStore %35 %44
- %47 = OpLoad %ulong %34
- %54 = OpConvertUToPtr %_ptr_Generic_v2uint %47
- %46 = OpLoad %v2uint %54
- OpStore %36 %46
- %49 = OpLoad %v2uint %36
- %48 = OpFunctionCall %v2uint %1 %49
- OpStore %36 %48
- %51 = OpLoad %v2uint %36
- %55 = OpBitcast %ulong %51
- %56 = OpCopyObject %ulong %55
- %50 = OpCopyObject %ulong %56
- OpStore %39 %50
- %52 = OpLoad %ulong %35
- %53 = OpLoad %v2uint %36
- %57 = OpConvertUToPtr %_ptr_Generic_v2uint %52
- OpStore %57 %53
+ %ushort = OpTypeInt 16 0
+%_ptr_Function_ushort = OpTypePointer Function %ushort
+ %v4ushort = OpTypeVector %ushort 4
+%_ptr_Function_v4ushort = OpTypePointer Function %v4ushort
+ %uchar = OpTypeInt 8 0
+ %v4uchar = OpTypeVector %uchar 4
+%_ptr_CrossWorkgroup_v4uchar = OpTypePointer CrossWorkgroup %v4uchar
+ %1 = OpFunction %void None %78
+ %11 = OpFunctionParameter %ulong
+ %12 = OpFunctionParameter %ulong
+ %73 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ushort Function
+ %7 = OpVariable %_ptr_Function_ushort Function
+ %8 = OpVariable %_ptr_Function_ushort Function
+ %9 = OpVariable %_ptr_Function_ushort Function
+ %10 = OpVariable %_ptr_Function_v4ushort Function
+ OpStore %2 %11
+ OpStore %3 %12
+ %14 = OpLoad %ulong %2
+ %13 = OpCopyObject %ulong %14
+ OpStore %4 %13
+ %16 = OpLoad %ulong %3
+ %15 = OpCopyObject %ulong %16
+ OpStore %5 %15
+ %21 = OpLoad %ulong %4
+ %63 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %21
+ %45 = OpLoad %v4uchar %63
+ %64 = OpCompositeExtract %uchar %45 0
+ %87 = OpBitcast %uchar %64
+ %17 = OpUConvert %ushort %87
+ %65 = OpCompositeExtract %uchar %45 1
+ %88 = OpBitcast %uchar %65
+ %18 = OpUConvert %ushort %88
+ %66 = OpCompositeExtract %uchar %45 2
+ %89 = OpBitcast %uchar %66
+ %19 = OpUConvert %ushort %89
+ %67 = OpCompositeExtract %uchar %45 3
+ %90 = OpBitcast %uchar %67
+ %20 = OpUConvert %ushort %90
+ OpStore %6 %17
+ OpStore %7 %18
+ OpStore %8 %19
+ OpStore %9 %20
+ %23 = OpLoad %ushort %7
+ %24 = OpLoad %ushort %8
+ %25 = OpLoad %ushort %9
+ %26 = OpLoad %ushort %6
+ %46 = OpUndef %v4ushort
+ %47 = OpCompositeInsert %v4ushort %23 %46 0
+ %48 = OpCompositeInsert %v4ushort %24 %47 1
+ %49 = OpCompositeInsert %v4ushort %25 %48 2
+ %50 = OpCompositeInsert %v4ushort %26 %49 3
+ %22 = OpCopyObject %v4ushort %50
+ OpStore %10 %22
+ %31 = OpLoad %v4ushort %10
+ %51 = OpCopyObject %v4ushort %31
+ %27 = OpCompositeExtract %ushort %51 0
+ %28 = OpCompositeExtract %ushort %51 1
+ %29 = OpCompositeExtract %ushort %51 2
+ %30 = OpCompositeExtract %ushort %51 3
+ OpStore %8 %27
+ OpStore %9 %28
+ OpStore %6 %29
+ OpStore %7 %30
+ %36 = OpLoad %ushort %8
+ %37 = OpLoad %ushort %9
+ %38 = OpLoad %ushort %6
+ %39 = OpLoad %ushort %7
+ %53 = OpUndef %v4ushort
+ %54 = OpCompositeInsert %v4ushort %36 %53 0
+ %55 = OpCompositeInsert %v4ushort %37 %54 1
+ %56 = OpCompositeInsert %v4ushort %38 %55 2
+ %57 = OpCompositeInsert %v4ushort %39 %56 3
+ %52 = OpCopyObject %v4ushort %57
+ %32 = OpCompositeExtract %ushort %52 0
+ %33 = OpCompositeExtract %ushort %52 1
+ %34 = OpCompositeExtract %ushort %52 2
+ %35 = OpCompositeExtract %ushort %52 3
+ OpStore %9 %32
+ OpStore %6 %33
+ OpStore %7 %34
+ OpStore %8 %35
+ %40 = OpLoad %ulong %5
+ %41 = OpLoad %ushort %6
+ %42 = OpLoad %ushort %7
+ %43 = OpLoad %ushort %8
+ %44 = OpLoad %ushort %9
+ %58 = OpUndef %v4uchar
+ %91 = OpBitcast %ushort %41
+ %68 = OpUConvert %uchar %91
+ %59 = OpCompositeInsert %v4uchar %68 %58 0
+ %92 = OpBitcast %ushort %42
+ %69 = OpUConvert %uchar %92
+ %60 = OpCompositeInsert %v4uchar %69 %59 1
+ %93 = OpBitcast %ushort %43
+ %70 = OpUConvert %uchar %93
+ %61 = OpCompositeInsert %v4uchar %70 %60 2
+ %94 = OpBitcast %ushort %44
+ %71 = OpUConvert %uchar %94
+ %62 = OpCompositeInsert %v4uchar %71 %61 3
+ %72 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %40
+ OpStore %72 %62
OpReturn
OpFunctionEnd
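
The regenerated output above shows the fixed destructuring end to end: the packed `v4.u8` global load is extracted lane by lane (`OpCompositeExtract` plus `OpUConvert` to `%ushort`), the member movs shuffle the `.u16` registers through a `v4ushort` composite, and the final store narrows each lane back to `%uchar` before `OpCompositeInsert`. As a behavioral model only (plain Rust, not translator code; both helper names are hypothetical), the load and store directions amount to:

```rust
// Model of what the SPIR-V above now does for
//   ld.global.v4.u8 {t1, t2, t3, t4}, [addr];
// each lane is extracted from the packed vector and widened to the destination
// register type with a relaxed, zero-extending conversion.
fn destructure_v4_u8_into_u16(packed: [u8; 4]) -> [u16; 4] {
    let mut lanes = [0u16; 4];
    for (i, lane) in packed.iter().enumerate() {
        // OpCompositeExtract %uchar ... i  followed by  OpUConvert %ushort
        lanes[i] = *lane as u16;
    }
    lanes
}

// The reverse direction used by the final vector store: narrow each .u16
// register back to a u8 lane and insert it into the packed vector.
fn pack_u16_into_v4_u8(lanes: [u16; 4]) -> [u8; 4] {
    let mut packed = [0u8; 4];
    for (i, lane) in lanes.iter().enumerate() {
        // OpUConvert %uchar followed by OpCompositeInsert
        packed[i] = *lane as u8;
    }
    packed
}

fn main() {
    let loaded = destructure_v4_u8_into_u16([1, 2, 3, 4]);
    assert_eq!(loaded, [1, 2, 3, 4]);
    assert_eq!(pack_u16_into_v4_u8(loaded), [1, 2, 3, 4]);
}
```
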
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 981da86..37cef00 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -843,6 +843,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
(_, ArgumentSemantics::Address) => return Ok(desc.op),
(t, ArgumentSemantics::RegisterPointer)
| (t, ArgumentSemantics::Default)
+ | (t, ArgumentSemantics::DefaultRelaxed)
| (t, ArgumentSemantics::PhysicalPointer) => t,
};
let generated_id = id_def.new_id(id_type);
@@ -933,17 +934,19 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn insert_composite_read(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver<'a>,
- (scalar_type, vec_len): (ast::ScalarType, u8),
+ typ: (ast::ScalarType, u8),
scalar_dst: Option<spirv::Word>,
+ scalar_sema_override: Option<ArgumentSemantics>,
composite_src: (spirv::Word, u8),
) -> spirv::Word {
- let new_id =
- scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Vector(scalar_type, vec_len)));
+ let new_id = scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Scalar(typ.0)));
func.push(Statement::Composite(CompositeRead {
- typ: scalar_type,
+ typ: typ.0,
dst: new_id,
+ dst_semantics_override: scalar_sema_override,
src_composite: composite_src.0,
src_index: composite_src.1 as u32,
+ src_len: typ.1 as u32,
}));
new_id
}
@@ -963,7 +966,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
) -> Result<spirv::Word, TranslateError> {
let (reg, offset) = desc.op;
match desc.sema {
- ArgumentSemantics::Default => {
+ ArgumentSemantics::Default | ArgumentSemantics::DefaultRelaxed => {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
scalar
} else {
@@ -1049,18 +1052,19 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn member_src(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
- (scalar_type, vec_len): (ast::ScalarType, u8),
+ typ: (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
if desc.is_dst {
return Err(TranslateError::Unreachable);
}
- let new_id = self.id_def.new_id(ast::Type::Vector(scalar_type, vec_len));
- self.func.push(Statement::Composite(CompositeRead {
- typ: scalar_type,
- dst: new_id,
- src_composite: desc.op.0,
- src_index: desc.op.1 as u32,
- }));
+ let new_id = Self::insert_composite_read(
+ self.func,
+ self.id_def,
+ typ,
+ None,
+ Some(desc.sema),
+ desc.op,
+ );
Ok(new_id)
}
@@ -1077,10 +1081,11 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
let newer_id = self.id_def.new_id(typ);
self.func.push(Statement::Instruction(ast::Instruction::Mov(
ast::MovDetails {
- typ: typ,
+ typ: ast::Type::Scalar(scalar_type),
src_is_address: false,
- dst_width: 0,
- src_width: vec_len,
+ dst_width: vec_len,
+ src_width: 0,
+ relaxed_src2_conv: desc.sema == ArgumentSemantics::DefaultRelaxed,
},
ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
(newer_id, idx as u8),
@@ -1099,6 +1104,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
self.id_def,
(scalar_type, vec_len),
Some(*id),
+ Some(desc.sema),
(new_id, idx as u8),
);
}
@@ -1144,9 +1150,9 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr
fn src_member_operand(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
- (scalar_type, vec_len): (ast::ScalarType, u8),
+ typ: (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
- self.member_src(desc, (scalar_type, vec_len))
+ self.member_src(desc, typ)
}
fn id_or_vector(
@@ -1195,123 +1201,41 @@ fn insert_implicit_conversions(
let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
match s {
- Statement::Call(call) => insert_implicit_bitcasts(&mut result, id_def, call)?,
- Statement::Instruction(inst) => match inst {
- ast::Instruction::Ld(ld, arg) => {
- let pre_conv = get_implicit_conversions_ld_src(
- id_def,
- ld.typ,
- ld.state_space,
- arg.src,
- false,
- )?;
- let post_conv = get_implicit_conversions_ld_dst(
- id_def,
- ld.typ,
- arg.dst,
- should_convert_relaxed_dst,
- false,
- )?;
- insert_with_conversions(
- &mut result,
- id_def,
- arg,
- pre_conv.into_iter(),
- iter::empty(),
- post_conv.into_iter().collect(),
- |arg| &mut arg.src,
- |arg| &mut arg.dst,
- |arg| ast::Instruction::Ld(ld, arg),
- )
+ Statement::Call(call) => insert_implicit_conversions_impl(
+ &mut result,
+ id_def,
+ call,
+ should_bitcast_wrapper,
+ None,
+ )?,
+ Statement::Instruction(inst) => {
+ let mut default_conversion_fn = should_bitcast_wrapper
+ as fn(_, _, _) -> Result<Option<ConversionKind>, TranslateError>;
+ let mut state_space = None;
+ if let ast::Instruction::Ld(d, _) = &inst {
+ state_space = Some(d.state_space);
}
- ast::Instruction::St(st, arg) => {
- let pre_conv = get_implicit_conversions_ld_dst(
- id_def,
- st.typ,
- arg.src2,
- should_convert_relaxed_src,
- true,
- )?;
- let post_conv = get_implicit_conversions_ld_src(
- id_def,
- st.typ,
- st.state_space.to_ld_ss(),
- arg.src1,
- true,
- )?;
- let (pre_conv_dest, post_conv) = if st.state_space == ast::StStateSpace::Param
- || st.state_space == ast::StStateSpace::Local
- {
- (Vec::new(), post_conv)
- } else {
- (post_conv, Vec::new())
- };
- insert_with_conversions(
- &mut result,
- id_def,
- arg,
- pre_conv.into_iter(),
- pre_conv_dest.into_iter(),
- post_conv,
- |arg| &mut arg.src2,
- |arg| &mut arg.src1,
- |arg| ast::Instruction::St(st, arg),
- )
+ if let ast::Instruction::St(d, _) = &inst {
+ state_space = Some(d.state_space.to_ld_ss());
}
- ast::Instruction::Mov(d, ast::Arg2Mov::Normal(mut arg)) => {
- // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov-2
- // TODO: handle the case of mixed vector/scalar implicit conversions
- let inst_typ_is_bit = match d.typ {
- ast::Type::Scalar(t) => ast::ScalarType::from(t).kind() == ScalarKind::Bit,
- ast::Type::Vector(_, _) => false,
- ast::Type::Array(_, _) => false,
- };
- let mut did_vector_implicit = false;
- let mut post_conv = None;
- if inst_typ_is_bit {
- let src_type = id_def.get_typed(arg.src)?;
- if let ast::Type::Vector(_, _) = src_type {
- arg.src = insert_conversion_src(
- &mut result,
- id_def,
- arg.src,
- src_type,
- d.typ.into(),
- ConversionKind::Default,
- );
- did_vector_implicit = true;
- }
- let dst_type = id_def.get_typed(arg.dst)?;
- if let ast::Type::Vector(_, _) = dst_type {
- post_conv = Some(get_conversion_dst(
- id_def,
- &mut arg.dst,
- d.typ.into(),
- dst_type,
- ConversionKind::Default,
- ));
- did_vector_implicit = true;
- }
- }
- if did_vector_implicit {
- result.push(Statement::Instruction(ast::Instruction::Mov(
- d,
- ast::Arg2Mov::Normal(arg),
- )));
- } else {
- insert_implicit_bitcasts(
- &mut result,
- id_def,
- ast::Instruction::Mov(d, ast::Arg2Mov::Normal(arg)),
- )?;
- }
- if let Some(post_conv) = post_conv {
- result.push(post_conv);
- }
+ if let ast::Instruction::Mov(_, ast::Arg2Mov::Normal(_)) = &inst {
+ default_conversion_fn = should_bitcast_packed;
}
- inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst)?,
- },
- Statement::Composite(c) => insert_implicit_bitcasts(&mut result, id_def, c)?,
+ insert_implicit_conversions_impl(
+ &mut result,
+ id_def,
+ inst,
+ default_conversion_fn,
+ state_space,
+ )?;
+ }
+ Statement::Composite(composite) => insert_implicit_conversions_impl(
+ &mut result,
+ id_def,
+ composite,
+ should_bitcast_wrapper,
+ None,
+ )?,
s @ Statement::Conditional(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
@@ -1326,6 +1250,77 @@ fn insert_implicit_conversions(
Ok(result)
}
+fn insert_implicit_conversions_impl(
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut MutableNumericIdResolver,
+ stmt: impl VisitVariableExpanded,
+ default_conversion_fn: fn(
+ ast::Type,
+ ast::Type,
+ Option<ast::LdStateSpace>,
+ ) -> Result<Option<ConversionKind>, TranslateError>,
+ state_space: Option<ast::LdStateSpace>,
+) -> Result<(), TranslateError> {
+ let mut post_conv = Vec::new();
+ let statement = stmt.visit_variable_extended(&mut |desc, typ| {
+ let instr_type = match typ {
+ None => return Ok(desc.op),
+ Some(t) => t,
+ };
+ let operand_type = id_def.get_typed(desc.op)?;
+ let mut conversion_fn = default_conversion_fn;
+ match desc.sema {
+ ArgumentSemantics::Default => {}
+ ArgumentSemantics::DefaultRelaxed => {
+ if desc.is_dst {
+ conversion_fn = should_convert_relaxed_dst_wrapper;
+ } else {
+ conversion_fn = should_convert_relaxed_src_wrapper;
+ }
+ }
+ ArgumentSemantics::PhysicalPointer => {
+ conversion_fn = bitcast_physical_pointer;
+ }
+ ArgumentSemantics::RegisterPointer => {
+ conversion_fn = force_bitcast;
+ }
+ ArgumentSemantics::Address => {
+ conversion_fn = force_bitcast_ptr_to_bit;
+ }
+ };
+ match conversion_fn(operand_type, instr_type, state_space)? {
+ Some(conv_kind) => {
+ let conv_output = if desc.is_dst {
+ &mut post_conv
+ } else {
+ &mut *func
+ };
+ let mut from = instr_type;
+ let mut to = operand_type;
+ let mut src = id_def.new_id(instr_type);
+ let mut dst = desc.op;
+ let result = Ok(src);
+ if !desc.is_dst {
+ mem::swap(&mut src, &mut dst);
+ mem::swap(&mut from, &mut to);
+ }
+ conv_output.push(Statement::Conversion(ImplicitConversion {
+ src,
+ dst,
+ from,
+ to,
+ kind: conv_kind,
+ }));
+ result
+ }
+ None => Ok(desc.op),
+ }
+ })?;
+ func.push(statement);
+ func.append(&mut post_conv);
+ Ok(())
+}
+
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -1505,7 +1500,11 @@ fn emit_function_body_ops(
composite_src,
scalar_src,
)) => {
- let result_type = map.get_or_add(builder, SpirvType::from(d.typ));
+ let scalar_type = d.typ.get_scalar()?;
+ let result_type = map.get_or_add(
+ builder,
+ SpirvType::from(ast::Type::Vector(scalar_type, d.dst_width)),
+ );
let result_id = Some(dst.0);
builder.composite_insert(
result_type,
@@ -1545,8 +1544,8 @@ fn emit_function_body_ops(
// Obviously, old and buggy one is used for compiling L0 SPIRV
// https://github.com/intel/intel-graphics-compiler/issues/148
let type_pred = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
- let const_true = builder.constant_true(type_pred);
- let const_false = builder.constant_false(type_pred);
+ let const_true = builder.constant_true(type_pred, None);
+ let const_false = builder.constant_false(type_pred, None);
builder.select(result_type, result_id, operand, const_false, const_true)
}
_ => builder.not(result_type, result_id, operand),
@@ -2700,12 +2699,9 @@ where
fn src_member_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- (scalar_type, vec_len): (ast::ScalarType, u8),
+ (scalar_type, _): (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
- self(
- desc.new_op(desc.op),
- Some(ast::Type::Vector(scalar_type.into(), vec_len)),
- )
+ self(desc.new_op(desc.op), Some(ast::Type::Scalar(scalar_type)))
}
}
@@ -2793,6 +2789,8 @@ pub struct ArgumentDescriptor<Op> {
pub enum ArgumentSemantics {
// normal register access
Default,
+ // normal register access with relaxed conversion rules (ld/st)
+ DefaultRelaxed,
// st/ld global
PhysicalPointer,
// st/ld .param, .local
@@ -2834,11 +2832,12 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
}
ast::Instruction::Mul(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type)?)
+ let is_wide = d.is_wide();
+ ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type, is_wide)?)
}
ast::Instruction::Add(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type)?)
+ ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type, false)?)
}
ast::Instruction::Setp(d, a) => {
let inst_type = d.typ;
@@ -2889,7 +2888,8 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
}
ast::Instruction::Mad(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Mad(d, a.map(visitor, inst_type)?)
+ let is_wide = d.is_wide();
+ ast::Instruction::Mad(d, a.map(visitor, inst_type, is_wide)?)
}
})
}
@@ -3004,6 +3004,27 @@ where
}
impl ast::Type {
+ fn widen(self) -> Result<Self, TranslateError> {
+ match self {
+ ast::Type::Scalar(scalar) => {
+ let kind = scalar.kind();
+ let width = scalar.width();
+ if (kind != ScalarKind::Signed
+ && kind != ScalarKind::Unsigned
+ && kind != ScalarKind::Bit)
+ || (width == 8)
+ {
+ return Err(TranslateError::MismatchedType);
+ }
+ Ok(ast::Type::Scalar(ast::ScalarType::from_parts(
+ width * 2,
+ kind,
+ )))
+ }
+ _ => Err(TranslateError::Unreachable),
+ }
+ }
+
fn to_parts(self) -> TypeParts {
match self {
ast::Type::Scalar(scalar) => TypeParts {
@@ -3102,8 +3123,10 @@ type Arg2St = ast::Arg2St<ExpandedArgParams>;
struct CompositeRead {
pub typ: ast::ScalarType,
pub dst: spirv::Word,
+ pub dst_semantics_override: Option<ArgumentSemantics>,
pub src_composite: spirv::Word,
pub src_index: u32,
+ pub src_len: u32,
}
impl VisitVariableExpanded for CompositeRead {
@@ -3116,12 +3139,15 @@ impl VisitVariableExpanded for CompositeRead {
self,
f: &mut F,
) -> Result<ExpandedStatement, TranslateError> {
+ let dst_sema = self
+ .dst_semantics_override
+ .unwrap_or(ArgumentSemantics::Default);
Ok(Statement::Composite(CompositeRead {
dst: f(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ sema: dst_sema,
},
Some(ast::Type::Scalar(self.typ)),
)?,
@@ -3131,7 +3157,7 @@ impl VisitVariableExpanded for CompositeRead {
is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(self.typ)),
+ Some(ast::Type::Vector(self.typ, self.src_len as u8)),
)?,
..self
}))
@@ -3328,7 +3354,7 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ sema: ArgumentSemantics::DefaultRelaxed,
},
t.into(),
)?;
@@ -3380,7 +3406,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ sema: ArgumentSemantics::DefaultRelaxed,
},
t,
)?;
@@ -3429,9 +3455,9 @@ impl<P: ArgParamsEx> ast::Arg2MovNormal<P> {
op: self.src,
is_dst: false,
sema: if details.src_is_address {
- ArgumentSemantics::RegisterPointer
+ ArgumentSemantics::Address
} else {
- ArgumentSemantics::PhysicalPointer
+ ArgumentSemantics::Default
},
},
details.typ.into(),
@@ -3476,13 +3502,14 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
) -> Result<ast::Arg2MovMember<U>, TranslateError> {
match self {
ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => {
+ let scalar_type = details.typ.get_scalar()?;
let dst = visitor.id(
ArgumentDescriptor {
op: dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(details.typ.into()),
+ Some(ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let src1 = visitor.id(
ArgumentDescriptor {
@@ -3490,7 +3517,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(details.typ.into()),
+ Some(ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let src2 = visitor.id(
ArgumentDescriptor {
@@ -3498,6 +3525,8 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
is_dst: false,
sema: if details.src_is_address {
ArgumentSemantics::Address
+ } else if details.relaxed_src2_conv {
+ ArgumentSemantics::DefaultRelaxed
} else {
ArgumentSemantics::Default
},
@@ -3527,13 +3556,14 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
Ok(ast::Arg2MovMember::Src(dst, src))
}
ast::Arg2MovMember::Both((dst, len), composite_src, src) => {
+ let scalar_type = details.typ.get_scalar()?;
let dst = visitor.id(
ArgumentDescriptor {
op: dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(details.typ.into()),
+ Some(ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let composite_src = visitor.id(
ArgumentDescriptor {
@@ -3541,16 +3571,19 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(details.typ.into()),
+ Some(ast::Type::Vector(scalar_type, details.dst_width)),
)?;
- let scalar_typ = details.typ.get_scalar()?;
let src = visitor.src_member_operand(
ArgumentDescriptor {
op: src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ sema: if details.relaxed_src2_conv {
+ ArgumentSemantics::DefaultRelaxed
+ } else {
+ ArgumentSemantics::Default
+ },
},
- (scalar_typ.into(), details.src_width),
+ (scalar_type.into(), details.src_width),
)?;
Ok(ast::Arg2MovMember::Both((dst, len), composite_src, src))
}
@@ -3570,7 +3603,8 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ typ: ast::Type,
+ is_wide: bool,
) -> Result<ast::Arg3<U>, TranslateError> {
let dst = visitor.id(
ArgumentDescriptor {
@@ -3578,7 +3612,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(t),
+ Some(if is_wide { typ.widen()? } else { typ }),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -3586,7 +3620,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
is_dst: false,
sema: ArgumentSemantics::Default,
},
- t,
+ typ,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
@@ -3594,7 +3628,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
is_dst: false,
sema: ArgumentSemantics::Default,
},
- t,
+ typ,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -3646,6 +3680,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
self,
visitor: &mut V,
t: ast::Type,
+ is_wide: bool,
) -> Result<ast::Arg4<U>, TranslateError> {
let dst = visitor.id(
ArgumentDescriptor {
@@ -3653,7 +3688,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(t),
+ Some(if is_wide { t.widen()? } else { t }),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -4050,6 +4085,54 @@ impl<T> ast::OperandOrVector<T> {
}
}
+impl ast::MulDetails {
+ fn is_wide(&self) -> bool {
+ match self {
+ ast::MulDetails::Int(desc) => desc.control == ast::MulIntControl::Wide,
+ ast::MulDetails::Float(_) => false,
+ }
+ }
+}
+
+fn force_bitcast(
+ operand: ast::Type,
+ instr: ast::Type,
+ _: Option<ast::LdStateSpace>,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if instr != operand {
+ Ok(Some(ConversionKind::Default))
+ } else {
+ Ok(None)
+ }
+}
+
+fn bitcast_physical_pointer(
+ operand_type: ast::Type,
+ _: ast::Type,
+ ss: Option<ast::LdStateSpace>,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ match operand_type {
+ ast::Type::Scalar(ast::ScalarType::B64)
+ | ast::Type::Scalar(ast::ScalarType::U64)
+ | ast::Type::Scalar(ast::ScalarType::S64) => {
+ if let Some(space) = ss {
+ Ok(Some(ConversionKind::BitToPtr(space)))
+ } else {
+ Err(TranslateError::Unreachable)
+ }
+ }
+ _ => Err(TranslateError::MismatchedType),
+ }
+}
+
+fn force_bitcast_ptr_to_bit(
+ _: ast::Type,
+ _: ast::Type,
+ _: Option<ast::LdStateSpace>,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ Ok(Some(ConversionKind::PtrToBit))
+}
+
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
@@ -4077,187 +4160,50 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
}
}
-fn insert_with_conversions<T, ToInstruction: FnOnce(T) -> ast::Instruction<ExpandedArgParams>>(
- func: &mut Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver,
- mut instr: T,
- pre_conv_src: impl ExactSizeIterator<Item = ImplicitConversion>,
- pre_conv_dst: impl ExactSizeIterator<Item = ImplicitConversion>,
- mut post_conv: Vec<ImplicitConversion>,
- mut src: impl FnMut(&mut T) -> &mut spirv::Word,
- mut dst: impl FnMut(&mut T) -> &mut spirv::Word,
- to_inst: ToInstruction,
-) {
- insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_src, &mut src);
- insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_dst, &mut dst);
- if post_conv.len() > 0 {
- let new_id = id_def.new_id(post_conv[0].from);
- post_conv[0].src = new_id;
- post_conv.last_mut().unwrap().dst = *dst(&mut instr);
- *dst(&mut instr) = new_id;
- }
- func.push(Statement::Instruction(to_inst(instr)));
- for conv in post_conv {
- func.push(Statement::Conversion(conv));
- }
-}
-
-fn insert_with_conversions_pre_conv<T>(
- func: &mut Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver,
- mut instr: &mut T,
- pre_conv: impl ExactSizeIterator<Item = ImplicitConversion>,
- src: &mut impl FnMut(&mut T) -> &mut spirv::Word,
-) {
- let pre_conv_len = pre_conv.len();
- for (i, mut conv) in pre_conv.enumerate() {
- let original_src = src(&mut instr);
- if i == 0 {
- conv.src = *original_src;
- }
- if i == pre_conv_len - 1 {
- let new_id = id_def.new_id(conv.to);
- conv.dst = new_id;
- *original_src = new_id;
+fn should_bitcast_packed(
+ operand: ast::Type,
+ instr: ast::Type,
+ ss: Option<ast::LdStateSpace>,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
+ (operand, instr)
+ {
+ if scalar.kind() == ScalarKind::Bit
+ && scalar.width() == (vec_underlying_type.width() * vec_len)
+ {
+ return Ok(Some(ConversionKind::Default));
}
- func.push(Statement::Conversion(conv));
}
+ should_bitcast_wrapper(operand, instr, ss)
}
-fn get_implicit_conversions_ld_dst<
- ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
->(
- id_def: &mut MutableNumericIdResolver,
- instr_type: ast::Type,
- dst: spirv::Word,
- should_convert: ShouldConvert,
- in_reverse: bool,
-) -> Result<Option<ImplicitConversion>, TranslateError> {
- let dst_type = id_def.get_typed(dst)?;
- if let Some(conv) = should_convert(dst_type, instr_type) {
- Ok(Some(ImplicitConversion {
- src: u32::max_value(),
- dst: u32::max_value(),
- from: if !in_reverse { instr_type } else { dst_type },
- to: if !in_reverse { dst_type } else { instr_type },
- kind: conv,
- }))
- } else {
- Ok(None)
+fn should_bitcast_wrapper(
+ operand: ast::Type,
+ instr: ast::Type,
+ _: Option<ast::LdStateSpace>,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if instr == operand {
+ return Ok(None);
}
-}
-
-fn get_implicit_conversions_ld_src(
- id_def: &mut MutableNumericIdResolver,
- instr_type: ast::Type,
- state_space: ast::LdStateSpace,
- src: spirv::Word,
- in_reverse_param_local: bool,
-) -> Result<Vec<ImplicitConversion>, TranslateError> {
- let src_type = id_def.get_typed(src)?;
- match state_space {
- ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
- if src_type != instr_type {
- Ok(vec![
- ImplicitConversion {
- src: u32::max_value(),
- dst: u32::max_value(),
- from: if !in_reverse_param_local {
- src_type
- } else {
- instr_type
- },
- to: if !in_reverse_param_local {
- instr_type
- } else {
- src_type
- },
- kind: ConversionKind::Default,
- };
- 1
- ])
- } else {
- Ok(Vec::new())
- }
- }
- ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
- let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
- mem::size_of::<usize>() as u8,
- ScalarKind::Bit,
- ));
- let mut result = Vec::new();
- // HACK ALERT
- // IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an
- // additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier
- // TODO: error out if the src is not B64/U64/S64
- if let ast::Type::Scalar(scalar_src_type) = src_type {
- if scalar_src_type.kind() == ScalarKind::Signed {
- result.push(ImplicitConversion {
- src: u32::max_value(),
- dst: u32::max_value(),
- from: src_type,
- to: new_src_type,
- kind: ConversionKind::Default,
- });
- }
- }
- result.push(ImplicitConversion {
- src: u32::max_value(),
- dst: u32::max_value(),
- from: src_type,
- to: instr_type,
- kind: ConversionKind::BitToPtr(state_space),
- });
- if result.len() == 2 {
- let new_id = id_def.new_id(new_src_type);
- result[0].dst = new_id;
- result[1].src = new_id;
- result[1].from = new_src_type;
- }
- Ok(result)
- }
- _ => Err(TranslateError::Todo),
+ if should_bitcast(instr, operand) {
+ Ok(Some(ConversionKind::Default))
+ } else {
+ Err(TranslateError::MismatchedType)
}
}
-#[must_use]
-fn insert_conversion_src(
- func: &mut Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver,
- src: spirv::Word,
+fn should_convert_relaxed_src_wrapper(
src_type: ast::Type,
instr_type: ast::Type,
- conv: ConversionKind,
-) -> spirv::Word {
- let temp_src = id_def.new_id(instr_type);
- func.push(Statement::Conversion(ImplicitConversion {
- src: src,
- dst: temp_src,
- from: src_type,
- to: instr_type,
- kind: conv,
- }));
- temp_src
-}
-
-#[must_use]
-fn get_conversion_dst(
- id_def: &mut MutableNumericIdResolver,
- dst: &mut spirv::Word,
- instr_type: ast::Type,
- dst_type: ast::Type,
- kind: ConversionKind,
-) -> ExpandedStatement {
- let original_dst = *dst;
- let temp_dst = id_def.new_id(instr_type);
- *dst = temp_dst;
- Statement::Conversion(ImplicitConversion {
- src: temp_dst,
- dst: original_dst,
- from: instr_type,
- to: dst_type,
- kind: kind,
- })
+ _: Option<ast::LdStateSpace>,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if src_type == instr_type {
+ return Ok(None);
+ }
+ match should_convert_relaxed_src(src_type, instr_type) {
+ conv @ Some(_) => Ok(conv),
+ None => Err(TranslateError::MismatchedType),
+ }
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
@@ -4302,6 +4248,20 @@ fn should_convert_relaxed_src(
}
}
+fn should_convert_relaxed_dst_wrapper(
+ dst_type: ast::Type,
+ instr_type: ast::Type,
+ _: Option<ast::LdStateSpace>,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if dst_type == instr_type {
+ return Ok(None);
+ }
+ match should_convert_relaxed_dst(dst_type, instr_type) {
+ conv @ Some(_) => Ok(conv),
+ None => Err(TranslateError::MismatchedType),
+ }
+}
+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
dst_type: ast::Type,
@@ -4357,55 +4317,6 @@ fn should_convert_relaxed_dst(
}
}
-fn insert_implicit_bitcasts(
- func: &mut Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver,
- stmt: impl VisitVariableExpanded,
-) -> Result<(), TranslateError> {
- let mut dst_coercion = None;
- let instr = stmt.visit_variable_extended(&mut |mut desc, typ| {
- let id_type_from_instr = match typ {
- Some(t) => t,
- None => return Ok(desc.op),
- };
- let id_actual_type = id_def.get_typed(desc.op)?;
- let conv_kind = if desc.sema == ArgumentSemantics::Address {
- Some(ConversionKind::PtrToBit)
- } else if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) {
- Some(ConversionKind::Default)
- } else {
- None
- };
- if let Some(conv_kind) = conv_kind {
- if desc.is_dst {
- dst_coercion = Some(get_conversion_dst(
- id_def,
- &mut desc.op,
- id_type_from_instr,
- id_actual_type,
- conv_kind,
- ));
- Ok(desc.op)
- } else {
- Ok(insert_conversion_src(
- func,
- id_def,
- desc.op,
- id_actual_type,
- id_type_from_instr,
- conv_kind,
- ))
- }
- } else {
- Ok(desc.op)
- }
- })?;
- func.push(instr);
- if let Some(cond) = dst_coercion {
- func.push(cond);
- }
- Ok(())
-}
impl<'a> ast::MethodDecl<'a, &'a str> {
fn name(&self) -> &'a str {
match self {