summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author Andrzej Janik <[email protected]> 2020-09-14 21:45:56 +0200
committer Andrzej Janik <[email protected]> 2020-09-14 21:45:56 +0200
commit bb5025c9b17e3fc46e454ca8faab1e85e0361ba8 (patch)
tree 07df096e1ad16e8c9464aac17c99194e7257937e
parent 48dac435400117935624aed244d1442982c874e2 (diff)
download ZLUDA-bb5025c9b17e3fc46e454ca8faab1e85e0361ba8.tar.gz
ZLUDA-bb5025c9b17e3fc46e454ca8faab1e85e0361ba8.zip
Refactor implicit conversions and start implementing vector extract/insert
-rw-r--r-- ptx/src/ast.rs 14
-rw-r--r-- ptx/src/ptx.lalrpop 15
-rw-r--r-- ptx/src/test/spirv_run/call.spvtxt 52
-rw-r--r-- ptx/src/test/spirv_run/cvta.spvtxt 24
-rw-r--r-- ptx/src/test/spirv_run/ld_st_implicit.ptx 20
-rw-r--r-- ptx/src/test/spirv_run/ld_st_implicit.spvtxt 48
-rw-r--r-- ptx/src/test/spirv_run/mod.rs 64
-rw-r--r-- ptx/src/test/spirv_run/not.spvtxt 16
-rw-r--r-- ptx/src/test/spirv_run/shl.spvtxt 16
-rw-r--r-- ptx/src/test/spirv_run/vector.spvtxt 117
-rw-r--r-- ptx/src/translate.rs 1270
11 files changed, 1117 insertions, 539 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 7921930..078cb31 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -349,6 +349,7 @@ pub trait ArgParams {
type ID;
type Operand;
type CallOperand;
+ type VecOperand;
}
pub struct ParsedArgParams<'a> {
@@ -359,6 +360,7 @@ impl<'a> ArgParams for ParsedArgParams<'a> {
type ID = &'a str;
type Operand = Operand<&'a str>;
type CallOperand = CallOperand<&'a str>;
+ type VecOperand = (&'a str, u8);
}
pub struct Arg1<P: ArgParams> {
@@ -376,9 +378,9 @@ pub struct Arg2St<P: ArgParams> {
}
pub enum Arg2Vec<P: ArgParams> {
- Dst((P::ID, u8), P::ID),
- Src(P::ID, (P::ID, u8)),
- Both((P::ID, u8), (P::ID, u8)),
+ Dst(P::VecOperand, P::ID),
+ Src(P::ID, P::VecOperand),
+ Both(P::VecOperand, P::VecOperand),
}
pub struct Arg3<P: ArgParams> {
@@ -424,8 +426,7 @@ pub struct LdData {
pub qualifier: LdStQualifier,
pub state_space: LdStateSpace,
pub caching: LdCacheOperator,
- pub vector: Option<u8>,
- pub typ: ScalarType,
+ pub typ: Type,
}
#[derive(Copy, Clone, PartialEq, Eq)]
@@ -710,8 +711,7 @@ pub struct StData {
pub qualifier: LdStQualifier,
pub state_space: StStateSpace,
pub caching: StCacheOperator,
- pub vector: Option<u8>,
- pub typ: ScalarType,
+ pub typ: Type,
}
#[derive(PartialEq, Eq, Copy, Clone)]
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index fd419f5..6e5f5e3 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -269,10 +269,10 @@ ScalarType: ast::ScalarType = {
".f16" => ast::ScalarType::F16,
".f16x2" => ast::ScalarType::F16x2,
".pred" => ast::ScalarType::Pred,
- MemoryType
+ LdStScalarType
};
-MemoryType: ast::ScalarType = {
+LdStScalarType: ast::ScalarType = {
".b8" => ast::ScalarType::B8,
".b16" => ast::ScalarType::B16,
".b32" => ast::ScalarType::B32,
@@ -446,13 +446,12 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ExtendedID> "," "[" <src:Operand> "]" => {
+ "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:ExtendedID> "," "[" <src:Operand> "]" => {
ast::Instruction::Ld(
ast::LdData {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
state_space: ss.unwrap_or(ast::LdStateSpace::Generic),
caching: cop.unwrap_or(ast::LdCacheOperator::Cached),
- vector: v,
typ: t
},
ast::Arg2 { dst:dst, src:src }
@@ -460,6 +459,11 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
}
};
+LdStType: ast::Type = {
+ <v:VectorPrefix> <t:LdStScalarType> => ast::Type::Vector(t, v),
+ <t:LdStScalarType> => ast::Type::Scalar(t),
+}
+
LdStQualifier: ast::LdStQualifier = {
".weak" => ast::LdStQualifier::Weak,
".volatile" => ast::LdStQualifier::Volatile,
@@ -895,13 +899,12 @@ ShlType: ast::ShlType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <src1:Operand> "]" "," <src2:Operand> => {
+ "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> "[" <src1:Operand> "]" "," <src2:Operand> => {
ast::Instruction::St(
ast::StData {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
state_space: ss.unwrap_or(ast::StStateSpace::Generic),
caching: cop.unwrap_or(ast::StCacheOperator::Writeback),
- vector: v,
typ: t
},
ast::Arg2St { src1:src1, src2:src2 }
diff --git a/ptx/src/test/spirv_run/call.spvtxt b/ptx/src/test/spirv_run/call.spvtxt
index 001cda3..ca4685a 100644
--- a/ptx/src/test/spirv_run/call.spvtxt
+++ b/ptx/src/test/spirv_run/call.spvtxt
@@ -4,20 +4,20 @@
OpCapability Kernel
OpCapability Int64
OpCapability Int8
- %45 = OpExtInstImport "OpenCL.std"
+ %47 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %4 "call"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
- %48 = OpTypeFunction %void %ulong %ulong
+ %50 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
- %51 = OpTypeFunction %ulong %ulong
+ %53 = OpTypeFunction %ulong %ulong
%ulong_1 = OpConstant %ulong 1
- %4 = OpFunction %void None %48
+ %4 = OpFunction %void None %50
%12 = OpFunctionParameter %ulong
%13 = OpFunctionParameter %ulong
- %30 = OpLabel
+ %32 = OpLabel
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
@@ -38,7 +38,9 @@
%18 = OpLoad %ulong %28
OpStore %9 %18
%21 = OpLoad %ulong %9
- %20 = OpCopyObject %ulong %21
+ %29 = OpCopyObject %ulong %21
+ %30 = OpCopyObject %ulong %29
+ %20 = OpCopyObject %ulong %30
OpStore %10 %20
%23 = OpLoad %ulong %10
%22 = OpFunctionCall %ulong %1 %23
@@ -48,26 +50,26 @@
OpStore %9 %24
%26 = OpLoad %ulong %8
%27 = OpLoad %ulong %9
- %29 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %26
- OpStore %29 %27
+ %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %26
+ OpStore %31 %27
OpReturn
OpFunctionEnd
- %1 = OpFunction %ulong None %51
- %34 = OpFunctionParameter %ulong
- %43 = OpLabel
- %32 = OpVariable %_ptr_Function_ulong Function
- %31 = OpVariable %_ptr_Function_ulong Function
+ %1 = OpFunction %ulong None %53
+ %36 = OpFunctionParameter %ulong
+ %45 = OpLabel
+ %34 = OpVariable %_ptr_Function_ulong Function
%33 = OpVariable %_ptr_Function_ulong Function
- OpStore %32 %34
- %36 = OpLoad %ulong %32
- %35 = OpCopyObject %ulong %36
- OpStore %33 %35
- %38 = OpLoad %ulong %33
- %37 = OpIAdd %ulong %38 %ulong_1
- OpStore %33 %37
- %40 = OpLoad %ulong %33
- %39 = OpCopyObject %ulong %40
- OpStore %31 %39
- %41 = OpLoad %ulong %31
- OpReturnValue %41
+ %35 = OpVariable %_ptr_Function_ulong Function
+ OpStore %34 %36
+ %38 = OpLoad %ulong %34
+ %37 = OpCopyObject %ulong %38
+ OpStore %35 %37
+ %40 = OpLoad %ulong %35
+ %39 = OpIAdd %ulong %40 %ulong_1
+ OpStore %35 %39
+ %42 = OpLoad %ulong %35
+ %41 = OpCopyObject %ulong %42
+ OpStore %33 %41
+ %43 = OpLoad %ulong %33
+ OpReturnValue %43
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/cvta.spvtxt b/ptx/src/test/spirv_run/cvta.spvtxt
index e708613..84e7eac 100644
--- a/ptx/src/test/spirv_run/cvta.spvtxt
+++ b/ptx/src/test/spirv_run/cvta.spvtxt
@@ -4,20 +4,20 @@
OpCapability Kernel
OpCapability Int64
OpCapability Int8
- %25 = OpExtInstImport "OpenCL.std"
+ %29 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "cvta"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
- %28 = OpTypeFunction %void %ulong %ulong
+ %32 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%float = OpTypeFloat 32
%_ptr_Function_float = OpTypePointer Function %float
%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
- %1 = OpFunction %void None %28
+ %1 = OpFunction %void None %32
%7 = OpFunctionParameter %ulong
%8 = OpFunctionParameter %ulong
- %23 = OpLabel
+ %27 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
@@ -32,18 +32,22 @@
%11 = OpCopyObject %ulong %12
OpStore %5 %11
%14 = OpLoad %ulong %4
- %13 = OpCopyObject %ulong %14
+ %22 = OpCopyObject %ulong %14
+ %21 = OpCopyObject %ulong %22
+ %13 = OpCopyObject %ulong %21
OpStore %4 %13
%16 = OpLoad %ulong %5
- %15 = OpCopyObject %ulong %16
+ %24 = OpCopyObject %ulong %16
+ %23 = OpCopyObject %ulong %24
+ %15 = OpCopyObject %ulong %23
OpStore %5 %15
%18 = OpLoad %ulong %4
- %21 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18
- %17 = OpLoad %float %21
+ %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18
+ %17 = OpLoad %float %25
OpStore %6 %17
%19 = OpLoad %ulong %5
%20 = OpLoad %float %6
- %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19
- OpStore %22 %20
+ %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19
+ OpStore %26 %20
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/ld_st_implicit.ptx b/ptx/src/test/spirv_run/ld_st_implicit.ptx
new file mode 100644
index 0000000..8562286
--- /dev/null
+++ b/ptx/src/test/spirv_run/ld_st_implicit.ptx
@@ -0,0 +1,20 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry ld_st_implicit(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .b64 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.global.f32 temp, [in_addr];
+ st.global.f32 [out_addr], temp;
+ ret;
+}
\ No newline at end of file
diff --git a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt
new file mode 100644
index 0000000..e7dba5a
--- /dev/null
+++ b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt
@@ -0,0 +1,48 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %23 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "ld_st_implicit"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %26 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %float = OpTypeFloat 32
+%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
+ %uint = OpTypeInt 32 0
+ %1 = OpFunction %void None %26
+ %7 = OpFunctionParameter %ulong
+ %8 = OpFunctionParameter %ulong
+ %21 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %7
+ OpStore %3 %8
+ %10 = OpLoad %ulong %2
+ %9 = OpCopyObject %ulong %10
+ OpStore %4 %9
+ %12 = OpLoad %ulong %3
+ %11 = OpCopyObject %ulong %12
+ OpStore %5 %11
+ %14 = OpLoad %ulong %4
+ %17 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14
+ %18 = OpLoad %float %17
+ %30 = OpBitcast %ulong %18
+ %32 = OpUConvert %uint %30
+ %13 = OpBitcast %uint %32
+ OpStore %6 %13
+ %15 = OpLoad %ulong %5
+ %16 = OpLoad %ulong %6
+ %33 = OpBitcast %uint %16
+ %19 = OpUConvert %ulong %33
+ %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15
+ OpStore %20 %19
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index a04f0eb..fd50d3c 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -8,10 +8,12 @@ use spirv_headers::Word;
use spirv_tools_sys::{
spv_binary, spv_endianness_t, spv_parsed_instruction_t, spv_result_t, spv_target_env,
};
+use std::collections::hash_map::Entry;
use std::error;
use std::ffi::{c_void, CStr, CString};
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
+use std::hash::Hash;
use std::mem;
use std::slice;
use std::{borrow::Cow, collections::HashMap, env, fs, path::PathBuf, ptr, str};
@@ -41,6 +43,7 @@ macro_rules! test_ptx {
}
test_ptx!(ld_st, [1u64], [1u64]);
+test_ptx!(ld_st_implicit, [0.5f32], [0.5f32]);
test_ptx!(mov, [1u64], [1u64]);
test_ptx!(mul_lo, [1u64], [2u64]);
test_ptx!(mul_hi, [u64::max_value()], [1u64]);
@@ -214,14 +217,45 @@ fn test_spvtxt_assert<'a>(
}
}
}
- panic!(spirv_text);
+ panic!(spirv_text.to_string());
}
unsafe { spirv_tools::spvContextDestroy(spv_context) };
Ok(())
}
+struct EqMap<T>
+where
+ T: Eq + Copy + Hash,
+{
+ m1: HashMap<T, T>,
+ m2: HashMap<T, T>,
+}
+
+impl<T: Copy + Eq + Hash> EqMap<T> {
+ fn new() -> Self {
+ EqMap {
+ m1: HashMap::new(),
+ m2: HashMap::new(),
+ }
+ }
+
+ fn is_equal(&mut self, t1: T, t2: T) -> bool {
+ match (self.m1.entry(t1), self.m2.entry(t2)) {
+ (Entry::Occupied(entry1), Entry::Occupied(entry2)) => {
+ *entry1.get() == t2 && *entry2.get() == t1
+ }
+ (Entry::Vacant(entry1), Entry::Vacant(entry2)) => {
+ entry1.insert(t2);
+ entry2.insert(t1);
+ true
+ }
+ _ => false,
+ }
+ }
+}
+
fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool {
- let mut map = HashMap::new();
+ let mut map = EqMap::new();
if !is_option_equal(&fn1.def, &fn2.def, &mut map, is_instr_equal) {
return false;
}
@@ -247,7 +281,7 @@ fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool {
true
}
-fn is_block_equal(b1: &Block, b2: &Block, map: &mut HashMap<Word, Word>) -> bool {
+fn is_block_equal(b1: &Block, b2: &Block, map: &mut EqMap<Word>) -> bool {
if !is_option_equal(&b1.label, &b2.label, map, is_instr_equal) {
return false;
}
@@ -262,11 +296,7 @@ fn is_block_equal(b1: &Block, b2: &Block, map: &mut HashMap<Word, Word>) -> bool
true
}
-fn is_instr_equal(
- instr1: &Instruction,
- instr2: &Instruction,
- map: &mut HashMap<Word, Word>,
-) -> bool {
+fn is_instr_equal(instr1: &Instruction, instr2: &Instruction, map: &mut EqMap<Word>) -> bool {
if instr1.class.opcode != instr2.class.opcode {
return false;
}
@@ -306,24 +336,14 @@ fn is_instr_equal(
true
}
-fn is_word_equal(w1: &Word, w2: &Word, map: &mut HashMap<Word, Word>) -> bool {
- match map.entry(*w1) {
- std::collections::hash_map::Entry::Occupied(entry) => {
- if entry.get() != w2 {
- return false;
- }
- }
- std::collections::hash_map::Entry::Vacant(entry) => {
- entry.insert(*w2);
- }
- }
- true
+fn is_word_equal(t1: &Word, t2: &Word, map: &mut EqMap<Word>) -> bool {
+ map.is_equal(*t1, *t2)
}
-fn is_option_equal<T, F: FnOnce(&T, &T, &mut HashMap<Word, Word>) -> bool>(
+fn is_option_equal<T, F: FnOnce(&T, &T, &mut EqMap<Word>) -> bool>(
o1: &Option<T>,
o2: &Option<T>,
- map: &mut HashMap<Word, Word>,
+ map: &mut EqMap<Word>,
f: F,
) -> bool {
match (o1, o2) {
diff --git a/ptx/src/test/spirv_run/not.spvtxt b/ptx/src/test/spirv_run/not.spvtxt
index de340ed..b358858 100644
--- a/ptx/src/test/spirv_run/not.spvtxt
+++ b/ptx/src/test/spirv_run/not.spvtxt
@@ -4,18 +4,18 @@
OpCapability Kernel
OpCapability Int64
OpCapability Int8
- %24 = OpExtInstImport "OpenCL.std"
+ %26 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "not"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
- %27 = OpTypeFunction %void %ulong %ulong
+ %29 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
- %1 = OpFunction %void None %27
+ %1 = OpFunction %void None %29
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
- %22 = OpLabel
+ %24 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
@@ -35,11 +35,13 @@
%14 = OpLoad %ulong %20
OpStore %6 %14
%17 = OpLoad %ulong %6
- %16 = OpNot %ulong %17
+ %22 = OpCopyObject %ulong %17
+ %21 = OpNot %ulong %22
+ %16 = OpCopyObject %ulong %21
OpStore %7 %16
%18 = OpLoad %ulong %5
%19 = OpLoad %ulong %7
- %21 = OpConvertUToPtr %_ptr_Generic_ulong %18
- OpStore %21 %19
+ %23 = OpConvertUToPtr %_ptr_Generic_ulong %18
+ OpStore %23 %19
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/shl.spvtxt b/ptx/src/test/spirv_run/shl.spvtxt
index dbd2664..4843a65 100644
--- a/ptx/src/test/spirv_run/shl.spvtxt
+++ b/ptx/src/test/spirv_run/shl.spvtxt
@@ -4,20 +4,20 @@
OpCapability Kernel
OpCapability Int64
OpCapability Int8
- %25 = OpExtInstImport "OpenCL.std"
+ %27 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "shl"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
- %28 = OpTypeFunction %void %ulong %ulong
+ %30 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%uint = OpTypeInt 32 0
%uint_2 = OpConstant %uint 2
- %1 = OpFunction %void None %28
+ %1 = OpFunction %void None %30
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
- %23 = OpLabel
+ %25 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
@@ -37,11 +37,13 @@
%14 = OpLoad %ulong %21
OpStore %6 %14
%17 = OpLoad %ulong %6
- %16 = OpShiftLeftLogical %ulong %17 %uint_2
+ %23 = OpCopyObject %ulong %17
+ %22 = OpShiftLeftLogical %ulong %23 %uint_2
+ %16 = OpCopyObject %ulong %22
OpStore %7 %16
%18 = OpLoad %ulong %5
%19 = OpLoad %ulong %7
- %22 = OpConvertUToPtr %_ptr_Generic_ulong %18
- OpStore %22 %19
+ %24 = OpConvertUToPtr %_ptr_Generic_ulong %18
+ OpStore %24 %19
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt
index 6810fec..25dd80e 100644
--- a/ptx/src/test/spirv_run/vector.spvtxt
+++ b/ptx/src/test/spirv_run/vector.spvtxt
@@ -4,43 +4,92 @@
OpCapability Kernel
OpCapability Int64
OpCapability Int8
- %25 = OpExtInstImport "OpenCL.std"
+ %58 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %1 "add"
+ OpEntryPoint Kernel %31 "vector"
%void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %v2uint = OpTypeVector %uint 2
+ %62 = OpTypeFunction %v2uint %v2uint
+%_ptr_Function_v2uint = OpTypePointer Function %v2uint
+%_ptr_Function_uint = OpTypePointer Function %uint
%ulong = OpTypeInt 64 0
- %28 = OpTypeFunction %void %ulong %ulong
+ %66 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
-%_ptr_Generic_ulong = OpTypePointer Generic %ulong
- %ulong_1 = OpConstant %ulong 1
- %1 = OpFunction %void None %28
- %8 = OpFunctionParameter %ulong
- %9 = OpFunctionParameter %ulong
- %23 = OpLabel
- %2 = OpVariable %_ptr_Function_ulong Function
- %3 = OpVariable %_ptr_Function_ulong Function
- %4 = OpVariable %_ptr_Function_ulong Function
- %5 = OpVariable %_ptr_Function_ulong Function
- %6 = OpVariable %_ptr_Function_ulong Function
- %7 = OpVariable %_ptr_Function_ulong Function
- OpStore %2 %8
- OpStore %3 %9
- %11 = OpLoad %ulong %2
- %10 = OpCopyObject %ulong %11
- OpStore %4 %10
- %13 = OpLoad %ulong %3
- %12 = OpCopyObject %ulong %13
- OpStore %5 %12
- %15 = OpLoad %ulong %4
- %21 = OpConvertUToPtr %_ptr_Generic_ulong %15
- %14 = OpLoad %ulong %21
- OpStore %6 %14
- %17 = OpLoad %ulong %6
- %16 = OpIAdd %ulong %17 %ulong_1
- OpStore %7 %16
- %18 = OpLoad %ulong %5
- %19 = OpLoad %ulong %7
- %22 = OpConvertUToPtr %_ptr_Generic_ulong %18
- OpStore %22 %19
+%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint
+ %1 = OpFunction %v2uint None %62
+ %7 = OpFunctionParameter %v2uint
+ %30 = OpLabel
+ %3 = OpVariable %_ptr_Function_v2uint Function
+ %2 = OpVariable %_ptr_Function_v2uint Function
+ %4 = OpVariable %_ptr_Function_v2uint Function
+ %5 = OpVariable %_ptr_Function_uint Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ OpStore %3 %7
+ %9 = OpLoad %v2uint %3
+ %24 = OpCompositeExtract %uint %9 0
+ %8 = OpCopyObject %uint %24
+ OpStore %5 %8
+ %11 = OpLoad %v2uint %3
+ %25 = OpCompositeExtract %uint %11 1
+ %10 = OpCopyObject %uint %25
+ OpStore %6 %10
+ %13 = OpLoad %uint %5
+ %14 = OpLoad %uint %6
+ %12 = OpIAdd %uint %13 %14
+ OpStore %6 %12
+ %16 = OpLoad %uint %6
+ %26 = OpCopyObject %uint %16
+ %15 = OpCompositeInsert %uint %26 %15 0
+ OpStore %4 %15
+ %18 = OpLoad %uint %6
+ %27 = OpCopyObject %uint %18
+ %17 = OpCompositeInsert %uint %27 %17 1
+ OpStore %4 %17
+ %20 = OpLoad %v2uint %4
+ %29 = OpCompositeExtract %uint %20 1
+ %28 = OpCopyObject %uint %29
+ %19 = OpCompositeInsert %uint %28 %19 0
+ OpStore %4 %19
+ %22 = OpLoad %v2uint %4
+ %21 = OpCopyObject %v2uint %22
+ OpStore %2 %21
+ %23 = OpLoad %v2uint %2
+ OpReturnValue %23
+ OpFunctionEnd
+ %31 = OpFunction %void None %66
+ %40 = OpFunctionParameter %ulong
+ %41 = OpFunctionParameter %ulong
+ %56 = OpLabel
+ %32 = OpVariable %_ptr_Function_ulong Function
+ %33 = OpVariable %_ptr_Function_ulong Function
+ %34 = OpVariable %_ptr_Function_ulong Function
+ %35 = OpVariable %_ptr_Function_ulong Function
+ %36 = OpVariable %_ptr_Function_v2uint Function
+ %37 = OpVariable %_ptr_Function_uint Function
+ %38 = OpVariable %_ptr_Function_uint Function
+ %39 = OpVariable %_ptr_Function_ulong Function
+ OpStore %32 %40
+ OpStore %33 %41
+ %43 = OpLoad %ulong %32
+ %42 = OpCopyObject %ulong %43
+ OpStore %34 %42
+ %45 = OpLoad %ulong %33
+ %44 = OpCopyObject %ulong %45
+ OpStore %35 %44
+ %47 = OpLoad %ulong %34
+ %54 = OpConvertUToPtr %_ptr_Generic_v2uint %47
+ %46 = OpLoad %v2uint %54
+ OpStore %36 %46
+ %49 = OpLoad %v2uint %36
+ %48 = OpFunctionCall %v2uint %1 %49
+ OpStore %36 %48
+ %51 = OpLoad %v2uint %36
+ %50 = OpCopyObject %ulong %51
+ OpStore %39 %50
+ %52 = OpLoad %ulong %35
+ %53 = OpLoad %v2uint %36
+ %55 = OpConvertUToPtr %_ptr_Generic_v2uint %52
+ OpStore %55 %53
OpReturn
OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 7591722..57d3485 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,5 +1,5 @@
use crate::ast;
-use rspirv::dr;
+use rspirv::{binary::Disassemble, dr};
use std::collections::{hash_map, HashMap, HashSet};
use std::{borrow::Cow, iter, mem};
@@ -398,7 +398,8 @@ fn normalize_labels(
labels_in_use.insert(cond.if_true);
labels_in_use.insert(cond.if_false);
}
- Statement::Call(_)
+ Statement::Composite(_)
+ | Statement::Call(_)
| Statement::Variable(_)
| Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
@@ -528,13 +529,13 @@ fn insert_mem_ssa_statements<'a, 'b>(
ast::Instruction::Ret(d) => {
if let Some(out_param) = out_param {
let typ = id_def.get_type(out_param);
- let new_id = id_def.new_id(Some(typ));
+ let new_id = id_def.new_id(typ);
result.push(Statement::LoadVar(
ast::Arg2 {
dst: new_id,
src: out_param,
},
- typ,
+ typ.unwrap(),
));
result.push(Statement::RetValue(d, new_id));
} else {
@@ -561,19 +562,25 @@ fn insert_mem_ssa_statements<'a, 'b>(
| Statement::Conversion(_)
| Statement::RetValue(_, _)
| Statement::Constant(_) => unreachable!(),
+ Statement::Composite(_) => todo!(),
}
}
(f_args, result)
}
trait VisitVariable: Sized {
- fn visit_variable<'a, F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable<
+ 'a,
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> UnadornedStatement;
}
trait VisitVariableExpanded {
- fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable_extended<
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> ExpandedStatement;
@@ -585,8 +592,8 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
stmt: F,
) {
let mut post_statements = Vec::new();
- let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>| {
- let id_type = match (desc.typ, desc.is_pointer) {
+ let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>, _| {
+ let id_type = match (id_def.get_type(desc.op), desc.is_pointer) {
(Some(t), false) => t,
(Some(_), true) => ast::Type::Scalar(ast::ScalarType::B64),
(None, _) => return desc.op,
@@ -624,13 +631,15 @@ fn expand_arguments<'a, 'b>(
match s {
Statement::Call(call) => {
let mut visitor = FlattenArguments::new(&mut result, id_def);
- let new_call = call.map(&mut visitor);
+ let (new_call, post_stmts) = (call.map(&mut visitor), visitor.post_stmts);
result.push(Statement::Call(new_call));
+ result.extend(post_stmts);
}
Statement::Instruction(inst) => {
let mut visitor = FlattenArguments::new(&mut result, id_def);
- let new_inst = inst.map(&mut visitor);
+ let (new_inst, post_stmts) = (inst.map(&mut visitor), visitor.post_stmts);
result.push(Statement::Instruction(new_inst));
+ result.extend(post_stmts);
}
Statement::Variable(ast::Variable {
align,
@@ -646,7 +655,9 @@ fn expand_arguments<'a, 'b>(
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
- Statement::Conversion(_) | Statement::Constant(_) => unreachable!(),
+ Statement::Composite(_) | Statement::Conversion(_) | Statement::Constant(_) => {
+ unreachable!()
+ }
}
}
result
@@ -655,74 +666,79 @@ fn expand_arguments<'a, 'b>(
struct FlattenArguments<'a, 'b> {
func: &'b mut Vec<ExpandedStatement>,
id_def: &'b mut NumericIdResolver<'a>,
+ post_stmts: Vec<ExpandedStatement>,
}
impl<'a, 'b> FlattenArguments<'a, 'b> {
fn new(func: &'b mut Vec<ExpandedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
- FlattenArguments { func, id_def }
+ FlattenArguments {
+ func,
+ id_def,
+ post_stmts: Vec::new(),
+ }
}
}
impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
for FlattenArguments<'a, 'b>
{
- fn variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
+ fn variable(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ typ: Option<ast::Type>,
+ ) -> spirv::Word {
desc.op
}
- fn operand(&mut self, desc: ArgumentDescriptor<ast::Operand<spirv::Word>>) -> spirv::Word {
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ typ: ast::Type,
+ ) -> spirv::Word {
match desc.op {
- ast::Operand::Reg(r) => self.variable(desc.new_op(r)),
+ ast::Operand::Reg(r) => self.variable(desc.new_op(r), Some(typ)),
ast::Operand::Imm(x) => {
- if let Some(typ) = desc.typ {
- let scalar_t = if let ast::Type::Scalar(scalar) = typ {
- scalar
- } else {
- todo!()
- };
- let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id,
- typ: scalar_t,
- value: x,
- }));
- id
+ let scalar_t = if let ast::Type::Scalar(scalar) = typ {
+ scalar
} else {
todo!()
- }
+ };
+ let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id,
+ typ: scalar_t,
+ value: x,
+ }));
+ id
}
ast::Operand::RegOffset(reg, offset) => {
- if let Some(typ) = desc.typ {
- let scalar_t = if let ast::Type::Scalar(scalar) = typ {
- scalar
- } else {
- todo!()
- };
- let id_constant_stmt = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: scalar_t,
- value: offset as i128,
- }));
- let result_id = self.id_def.new_id(desc.typ);
- let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- ast::AddDetails::Int(ast::AddIntDesc {
- typ: int_type,
- saturate: false,
- }),
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
- result_id
+ let scalar_t = if let ast::Type::Scalar(scalar) = typ {
+ scalar
} else {
todo!()
- }
+ };
+ let id_constant_stmt = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: scalar_t,
+ value: offset as i128,
+ }));
+ let result_id = self.id_def.new_id(Some(typ));
+ let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
+ self.func.push(Statement::Instruction(
+ ast::Instruction::<ExpandedArgParams>::Add(
+ ast::AddDetails::Int(ast::AddIntDesc {
+ typ: int_type,
+ saturate: false,
+ }),
+ ast::Arg3 {
+ dst: result_id,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ ),
+ ));
+ result_id
}
}
}
@@ -730,18 +746,45 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
+ typ: ast::Type,
) -> spirv::Word {
match desc.op {
- ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg)),
- ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x))),
+ ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg), Some(typ)),
+ ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x)), typ),
}
}
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
- ) -> (spirv::Word, u8) {
- (self.variable(desc.new_op(desc.op.0)), desc.op.1)
+ typ: ast::MovVectorType,
+ ) -> spirv::Word {
+ let (vector_id, index) = desc.op;
+ let new_id = self.id_def.new_id(Some(ast::Type::Scalar(typ.into())));
+ let composite = if desc.is_dst {
+ Statement::Composite(CompositeAccess {
+ typ: typ,
+ dst: new_id,
+ src: vector_id,
+ index: index as u32,
+ is_write: true
+ })
+ } else {
+ Statement::Composite(CompositeAccess {
+ typ: typ,
+ dst: new_id,
+ src: vector_id,
+ index: index as u32,
+ is_write: false
+ })
+ };
+ if desc.is_dst {
+ self.post_stmts.push(composite);
+ new_id
+ } else {
+ self.func.push(composite);
+ new_id
+ }
}
}
@@ -768,48 +811,63 @@ fn insert_implicit_conversions(
match s {
Statement::Call(call) => insert_implicit_bitcasts(&mut result, id_def, call),
Statement::Instruction(inst) => match inst {
- ast::Instruction::Ld(ld, mut arg) => {
- arg.src = insert_implicit_conversions_ld_src(
- &mut result,
- ast::Type::Scalar(ld.typ),
+ ast::Instruction::Ld(ld, arg) => {
+ let pre_conv =
+ get_implicit_conversions_ld_src(id_def, ld.typ, ld.state_space, arg.src);
+ let post_conv = get_implicit_conversions_ld_dst(
id_def,
- ld.state_space,
- arg.src,
+ ld.typ,
+ arg.dst,
+ should_convert_relaxed_dst,
+ false,
);
- insert_with_implicit_conversion_dst(
+ insert_with_conversions(
&mut result,
- ld.typ,
id_def,
- should_convert_relaxed_dst,
arg,
+ pre_conv.into_iter(),
+ iter::empty(),
+ post_conv.into_iter().collect(),
+ |arg| &mut arg.src,
|arg| &mut arg.dst,
|arg| ast::Instruction::Ld(ld, arg),
- );
+ )
}
- ast::Instruction::St(st, mut arg) => {
- let arg_src2_type = id_def.get_type(arg.src2);
- if let Some(conv) = should_convert_relaxed_src(arg_src2_type, st.typ) {
- arg.src2 = insert_conversion_src(
- &mut result,
- id_def,
- arg.src2,
- arg_src2_type,
- ast::Type::Scalar(st.typ),
- conv,
- );
- }
- arg.src1 = insert_implicit_conversions_ld_src(
- &mut result,
- ast::Type::Scalar(st.typ),
+ ast::Instruction::St(st, arg) => {
+ let pre_conv = get_implicit_conversions_ld_dst(
id_def,
+ st.typ,
+ arg.src2,
+ should_convert_relaxed_src,
+ true,
+ );
+ let post_conv = get_implicit_conversions_ld_src(
+ id_def,
+ st.typ,
st.state_space.to_ld_ss(),
arg.src1,
);
- result.push(Statement::Instruction(ast::Instruction::St(st, arg)));
+ let (pre_conv_dest, post_conv) = if st.state_space == ast::StStateSpace::Param {
+ (Vec::new(), post_conv)
+ } else {
+ (post_conv, Vec::new())
+ };
+ insert_with_conversions(
+ &mut result,
+ id_def,
+ arg,
+ pre_conv.into_iter(),
+ pre_conv_dest.into_iter(),
+ post_conv,
+ |arg| &mut arg.src2,
+ |arg| &mut arg.src1,
+ |arg| ast::Instruction::St(st, arg),
+ )
}
inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst),
},
- s @ Statement::Conditional(_)
+ s @ Statement::Composite(_)
+ | s @ Statement::Conditional(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_)
@@ -950,10 +1008,10 @@ fn emit_function_body_ops(
builder.branch(arg.src)?;
}
ast::Instruction::Ld(data, arg) => {
- if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
+ if data.qualifier != ast::LdStQualifier::Weak {
todo!()
}
- let result_type = map.get_or_add_scalar(builder, data.typ);
+ let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
match data.state_space {
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
@@ -967,7 +1025,6 @@ fn emit_function_body_ops(
}
ast::Instruction::St(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak
- || data.vector.is_some()
|| (data.state_space != ast::StStateSpace::Generic
&& data.state_space != ast::StStateSpace::Param
&& data.state_space != ast::StStateSpace::Global)
@@ -1030,7 +1087,10 @@ fn emit_function_body_ops(
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
ast::Instruction::SetpBool(_, _) => todo!(),
- ast::Instruction::MovVector(_, _) => todo!(),
+ ast::Instruction::MovVector(t, arg) => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
+ builder.copy_object(result_type, Some(arg.dst()), arg.src())?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(*typ));
@@ -1042,6 +1102,19 @@ fn emit_function_body_ops(
Statement::RetValue(_, id) => {
builder.ret_value(*id)?;
}
+ Statement::Composite(c) => {
+ let result_type = map.get_or_add_scalar(builder, c.typ.into());
+ let result_id = Some(c.dst);
+ let indexes = [c.index];
+ if c.is_write {
+ let object = c.src;
+ let composite = c.dst;
+ builder.composite_insert(result_type, result_id, object, composite, indexes)?;
+ } else {
+ let composite = c.src;
+ builder.composite_extract(result_type, result_id, composite, indexes)?;
+ }
+ }
}
}
Ok(())
@@ -1188,7 +1261,7 @@ fn emit_setp(
match (setp.cmp_op, setp.typ.kind()) {
(ast::SetpCompareOp::Eq, ScalarKind::Signed)
| (ast::SetpCompareOp::Eq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Eq, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::Eq, ScalarKind::Bit) => {
builder.i_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Eq, ScalarKind::Float) => {
@@ -1196,14 +1269,14 @@ fn emit_setp(
}
(ast::SetpCompareOp::NotEq, ScalarKind::Signed)
| (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::NotEq, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => {
builder.i_not_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NotEq, ScalarKind::Float) => {
builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Less, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Less, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::Less, ScalarKind::Bit) => {
builder.u_less_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Less, ScalarKind::Signed) => {
@@ -1213,7 +1286,7 @@ fn emit_setp(
builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::LessOrEq, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => {
builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => {
@@ -1223,7 +1296,7 @@ fn emit_setp(
builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Greater, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Greater, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::Greater, ScalarKind::Bit) => {
builder.u_greater_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Greater, ScalarKind::Signed) => {
@@ -1233,7 +1306,7 @@ fn emit_setp(
builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => {
builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => {
@@ -1294,54 +1367,56 @@ fn emit_implicit_conversion(
map: &mut TypeWordMap,
cv: &ImplicitConversion,
) -> Result<(), dr::Error> {
- let (from_type, to_type) = match (cv.from, cv.to) {
- (ast::Type::Scalar(from), ast::Type::Scalar(to)) => (from, to),
- _ => todo!(),
- };
+ let from_parts = cv.from.to_parts();
+ let to_parts = cv.to.to_parts();
match cv.kind {
ConversionKind::Ptr(space) => {
let dst_type = map.get_or_add(
builder,
- SpirvType::Pointer(
- Box::new(SpirvType::Base(SpirvScalarKey::from(to_type))),
- space.to_spirv(),
- ),
+ SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()),
);
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
ConversionKind::Default => {
- if from_type.width() == to_type.width() {
- let dst_type = map.get_or_add_scalar(builder, to_type);
- if from_type.kind() != ScalarKind::Float && to_type.kind() != ScalarKind::Float {
+ if from_parts.width == to_parts.width {
+ let dst_type = map.get_or_add(builder, SpirvType::from(cv.from));
+ if from_parts.scalar_kind != ScalarKind::Float
+ && to_parts.scalar_kind != ScalarKind::Float
+ {
// It is noop, but another instruction expects result of this conversion
builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
} else {
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
}
} else {
- let as_unsigned_type = map.get_or_add_scalar(
- builder,
- ast::ScalarType::from_parts(from_type.width(), ScalarKind::Unsigned),
- );
- let as_unsigned = builder.bitcast(as_unsigned_type, None, cv.src)?;
- let as_unsigned_wide_type =
- ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned);
- let as_unsigned_wide_spirv = map.get_or_add_scalar(
+ // This block is safe because it's illegal to implictly convert between floating point instructions
+ let same_width_bit_type = map.get_or_add(
builder,
- ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned),
+ SpirvType::from(ast::Type::from_parts(TypeParts {
+ scalar_kind: ScalarKind::Bit,
+ ..from_parts
+ })),
);
- if to_type.kind() == ScalarKind::Unsigned || to_type.kind() == ScalarKind::Byte {
- builder.u_convert(as_unsigned_wide_spirv, Some(cv.dst), as_unsigned)?;
+ let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?;
+ let wide_bit_type = ast::Type::from_parts(TypeParts {
+ scalar_kind: ScalarKind::Bit,
+ ..to_parts
+ });
+ let wide_bit_type_spirv = map.get_or_add(builder, SpirvType::from(wide_bit_type));
+ if to_parts.scalar_kind == ScalarKind::Unsigned
+ || to_parts.scalar_kind == ScalarKind::Bit
+ {
+ builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?;
} else {
- let as_unsigned_wide =
- builder.u_convert(as_unsigned_wide_spirv, None, as_unsigned)?;
+ let wide_bit_value =
+ builder.u_convert(wide_bit_type_spirv, None, same_width_bit_value)?;
emit_implicit_conversion(
builder,
map,
&ImplicitConversion {
- src: as_unsigned_wide,
+ src: wide_bit_value,
dst: cv.dst,
- from: ast::Type::Scalar(as_unsigned_wide_type),
+ from: wide_bit_type,
to: cv.to,
kind: ConversionKind::Default,
},
@@ -1627,8 +1702,8 @@ struct NumericIdResolver<'b> {
}
impl<'b> NumericIdResolver<'b> {
- fn get_type(&self, id: spirv::Word) -> ast::Type {
- self.type_check[&id]
+ fn get_type(&self, id: spirv::Word) -> Option<ast::Type> {
+ self.type_check.get(&id).map(|x| *x)
}
fn new_id(&mut self, typ: Option<ast::Type>) -> spirv::Word {
@@ -1648,6 +1723,7 @@ enum Statement<I, P: ast::ArgParams> {
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
Call(ResolvedCall<P>),
+ Composite(CompositeAccess),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
Conversion(ImplicitConversion),
@@ -1671,31 +1747,37 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
.ret_params
.into_iter()
.map(|(id, typ)| {
- let new_id = visitor.variable(ArgumentDescriptor {
- op: id,
- typ: Some(typ.into()),
- is_dst: true,
- is_pointer: false,
- });
+ let new_id = visitor.variable(
+ ArgumentDescriptor {
+ op: id,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(typ.into()),
+ );
(new_id, typ)
})
.collect();
- let func = visitor.variable(ArgumentDescriptor {
- op: self.func,
- typ: None,
- is_dst: false,
- is_pointer: false,
- });
+ let func = visitor.variable(
+ ArgumentDescriptor {
+ op: self.func,
+ is_dst: false,
+ is_pointer: false,
+ },
+ None,
+ );
let param_list = self
.param_list
.into_iter()
.map(|(id, typ)| {
- let new_id = visitor.src_call_operand(ArgumentDescriptor {
- op: id,
- typ: Some(typ.into()),
- is_dst: false,
- is_pointer: false,
- });
+ let new_id = visitor.src_call_operand(
+ ArgumentDescriptor {
+ op: id,
+ is_dst: false,
+ is_pointer: false,
+ },
+ typ.into(),
+ );
(new_id, typ)
})
.collect();
@@ -1709,7 +1791,10 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
}
impl VisitVariable for ResolvedCall<NormalizedArgParams> {
- fn visit_variable<'a, F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable<
+ 'a,
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> UnadornedStatement {
@@ -1718,7 +1803,9 @@ impl VisitVariable for ResolvedCall<NormalizedArgParams> {
}
impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
- fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable_extended<
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> ExpandedStatement {
@@ -1750,6 +1837,7 @@ impl ast::ArgParams for NormalizedArgParams {
type ID = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
type CallOperand = ast::CallOperand<spirv::Word>;
+ type VecOperand = (spirv::Word, u8);
}
impl ArgParamsEx for NormalizedArgParams {
@@ -1766,6 +1854,7 @@ impl ast::ArgParams for ExpandedArgParams {
type ID = spirv::Word;
type Operand = spirv::Word;
type CallOperand = spirv::Word;
+ type VecOperand = spirv::Word;
}
impl ArgParamsEx for ExpandedArgParams {
@@ -1775,30 +1864,47 @@ impl ArgParamsEx for ExpandedArgParams {
}
trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
- fn variable(&mut self, desc: ArgumentDescriptor<T::ID>) -> U::ID;
- fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>) -> U::Operand;
- fn src_call_operand(&mut self, desc: ArgumentDescriptor<T::CallOperand>) -> U::CallOperand;
- fn src_vec_operand(&mut self, desc: ArgumentDescriptor<(T::ID, u8)>) -> (U::ID, u8);
+ fn variable(&mut self, desc: ArgumentDescriptor<T::ID>, typ: Option<ast::Type>) -> U::ID;
+ fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>, typ: ast::Type) -> U::Operand;
+ fn src_call_operand(
+ &mut self,
+ desc: ArgumentDescriptor<T::CallOperand>,
+ typ: ast::Type,
+ ) -> U::CallOperand;
+ fn src_vec_operand(
+ &mut self,
+ desc: ArgumentDescriptor<T::VecOperand>,
+ typ: ast::MovVectorType,
+ ) -> U::VecOperand;
}
impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
where
- T: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word,
+ T: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
{
- fn variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
- self(desc)
+ fn variable(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: Option<ast::Type>,
+ ) -> spirv::Word {
+ self(desc, t)
}
- fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
- self(desc)
+ fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>, t: ast::Type) -> spirv::Word {
+ self(desc, Some(t))
}
- fn src_call_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
- self(desc.new_op(desc.op))
+ fn src_call_operand(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: ast::Type,
+ ) -> spirv::Word {
+ self(desc, Some(t))
}
fn src_vec_operand(
&mut self,
- desc: ArgumentDescriptor<(spirv::Word, u8)>,
- ) -> (spirv::Word, u8) {
- (self(desc.new_op(desc.op.0)), desc.op.1)
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: ast::MovVectorType,
+ ) -> spirv::Word {
+ self(desc, Some(ast::Type::Scalar(t.into())))
}
}
@@ -1806,13 +1912,14 @@ impl<'a, T> ArgumentMapVisitor<ast::ParsedArgParams<'a>, NormalizedArgParams> fo
where
T: FnMut(&str) -> spirv::Word,
{
- fn variable(&mut self, desc: ArgumentDescriptor<&str>) -> spirv::Word {
+ fn variable(&mut self, desc: ArgumentDescriptor<&str>, _: Option<ast::Type>) -> spirv::Word {
self(desc.op)
}
fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
+ _: ast::Type,
) -> ast::Operand<spirv::Word> {
match desc.op {
ast::Operand::Reg(id) => ast::Operand::Reg(self(id)),
@@ -1824,6 +1931,7 @@ where
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<&str>>,
+ _: ast::Type,
) -> ast::CallOperand<spirv::Word> {
match desc.op {
ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(id)),
@@ -1831,15 +1939,18 @@ where
}
}
- fn src_vec_operand(&mut self, desc: ArgumentDescriptor<(&str, u8)>) -> (spirv::Word, u8) {
+ fn src_vec_operand(
+ &mut self,
+ desc: ArgumentDescriptor<(&str, u8)>,
+ _: ast::MovVectorType,
+ ) -> (spirv::Word, u8) {
(self(desc.op.0), desc.op.1)
}
}
-struct ArgumentDescriptor<T> {
- op: T,
+struct ArgumentDescriptor<Op> {
+ op: Op,
is_dst: bool,
- typ: Option<ast::Type>,
is_pointer: bool,
}
@@ -1848,7 +1959,6 @@ impl<T> ArgumentDescriptor<T> {
ArgumentDescriptor {
op: u,
is_dst: self.is_dst,
- typ: self.typ,
is_pointer: self.is_pointer,
}
}
@@ -1860,39 +1970,35 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
visitor: &mut V,
) -> ast::Instruction<U> {
match self {
- ast::Instruction::MovVector(_, _) => todo!(),
+ ast::Instruction::MovVector(t, a) => ast::Instruction::MovVector(t, a.map(visitor, t)),
ast::Instruction::Abs(_, _) => todo!(),
+ // Call instruction is converted to a call statement early on
ast::Instruction::Call(_) => unreachable!(),
ast::Instruction::Ld(d, a) => {
let inst_type = d.typ;
let src_is_pointer = d.state_space != ast::LdStateSpace::Param;
- ast::Instruction::Ld(
- d,
- a.map_ld(visitor, Some(ast::Type::Scalar(inst_type)), src_is_pointer),
- )
+ ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, src_is_pointer))
}
ast::Instruction::Mov(mov_type, a) => {
- ast::Instruction::Mov(mov_type, a.map(visitor, Some(mov_type.into())))
+ ast::Instruction::Mov(mov_type, a.map(visitor, mov_type.into()))
}
ast::Instruction::Mul(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Mul(d, a.map_non_shift(visitor, Some(inst_type)))
+ ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type))
}
ast::Instruction::Add(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Add(d, a.map_non_shift(visitor, Some(inst_type)))
+ ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type))
}
ast::Instruction::Setp(d, a) => {
let inst_type = d.typ;
- ast::Instruction::Setp(d, a.map(visitor, Some(ast::Type::Scalar(inst_type))))
+ ast::Instruction::Setp(d, a.map(visitor, ast::Type::Scalar(inst_type)))
}
ast::Instruction::SetpBool(d, a) => {
let inst_type = d.typ;
- ast::Instruction::SetpBool(d, a.map(visitor, Some(ast::Type::Scalar(inst_type))))
- }
- ast::Instruction::Not(t, a) => {
- ast::Instruction::Not(t, a.map(visitor, Some(t.to_type())))
+ ast::Instruction::SetpBool(d, a.map(visitor, ast::Type::Scalar(inst_type)))
}
+ ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, t.to_type())),
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (
@@ -1915,28 +2021,28 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t))
}
ast::Instruction::Shl(t, a) => {
- ast::Instruction::Shl(t, a.map_shift(visitor, Some(t.to_type())))
+ ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type()))
}
ast::Instruction::St(d, a) => {
let inst_type = d.typ;
let param_space = d.state_space == ast::StStateSpace::Param;
- ast::Instruction::St(
- d,
- a.map(visitor, Some(ast::Type::Scalar(inst_type)), param_space),
- )
+ ast::Instruction::St(d, a.map(visitor, inst_type, param_space))
}
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
ast::Instruction::Cvta(d, a) => {
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
- ast::Instruction::Cvta(d, a.map(visitor, Some(inst_type)))
+ ast::Instruction::Cvta(d, a.map(visitor, inst_type))
}
}
}
}
impl VisitVariable for ast::Instruction<NormalizedArgParams> {
- fn visit_variable<'a, F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable<
+ 'a,
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> UnadornedStatement {
@@ -1946,29 +2052,37 @@ impl VisitVariable for ast::Instruction<NormalizedArgParams> {
impl<T> ArgumentMapVisitor<NormalizedArgParams, NormalizedArgParams> for T
where
- T: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word,
+ T: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
{
- fn variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
- self(desc)
+ fn variable(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: Option<ast::Type>,
+ ) -> spirv::Word {
+ self(desc, t)
}
fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ t: ast::Type,
) -> ast::Operand<spirv::Word> {
match desc.op {
- ast::Operand::Reg(id) => ast::Operand::Reg(self(desc.new_op(id))),
+ ast::Operand::Reg(id) => ast::Operand::Reg(self(desc.new_op(id), Some(t))),
ast::Operand::Imm(imm) => ast::Operand::Imm(imm),
- ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(desc.new_op(id)), imm),
+ ast::Operand::RegOffset(id, imm) => {
+ ast::Operand::RegOffset(self(desc.new_op(id), Some(t)), imm)
+ }
}
}
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
+ t: ast::Type,
) -> ast::CallOperand<spirv::Word> {
match desc.op {
- ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(desc.new_op(id))),
+ ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(desc.new_op(id), Some(t))),
ast::CallOperand::Imm(imm) => ast::CallOperand::Imm(imm),
}
}
@@ -1976,11 +2090,74 @@ where
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
+ t: ast::MovVectorType,
) -> (spirv::Word, u8) {
- (self(desc.new_op(desc.op.0)), desc.op.1)
+ (
+ self(
+ desc.new_op(desc.op.0),
+ Some(ast::Type::Vector(t.into(), desc.op.1)),
+ ),
+ desc.op.1,
+ )
+ }
+}
+
+impl ast::Type {
+ fn to_parts(self) -> TypeParts {
+ match self {
+ ast::Type::Scalar(scalar) => TypeParts {
+ kind: TypeKind::Scalar,
+ scalar_kind: scalar.kind(),
+ width: scalar.width(),
+ components: 0,
+ },
+ ast::Type::Vector(scalar, components) => TypeParts {
+ kind: TypeKind::Vector,
+ scalar_kind: scalar.kind(),
+ width: scalar.width(),
+ components: components as u32,
+ },
+ ast::Type::Array(scalar, components) => TypeParts {
+ kind: TypeKind::Array,
+ scalar_kind: scalar.kind(),
+ width: scalar.width(),
+ components: components,
+ },
+ }
+ }
+
+ fn from_parts(t: TypeParts) -> Self {
+ match t.kind {
+ TypeKind::Scalar => {
+ ast::Type::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind))
+ }
+ TypeKind::Vector => ast::Type::Vector(
+ ast::ScalarType::from_parts(t.width, t.scalar_kind),
+ t.components as u8,
+ ),
+ TypeKind::Array => ast::Type::Array(
+ ast::ScalarType::from_parts(t.width, t.scalar_kind),
+ t.components,
+ ),
+ }
}
}
+#[derive(Eq, PartialEq, Copy, Clone)]
+struct TypeParts {
+ kind: TypeKind,
+ scalar_kind: ScalarKind,
+ width: u8,
+ components: u32,
+}
+
+#[derive(Eq, PartialEq, Copy, Clone)]
+enum TypeKind {
+ Scalar,
+ Vector,
+ Array,
+}
+
impl ast::Instruction<ExpandedArgParams> {
fn jump_target(&self) -> Option<spirv::Word> {
match self {
@@ -2005,7 +2182,9 @@ impl ast::Instruction<ExpandedArgParams> {
}
impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
- fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable_extended<
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> ExpandedStatement {
@@ -2016,6 +2195,29 @@ impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
type Arg2 = ast::Arg2<ExpandedArgParams>;
type Arg2St = ast::Arg2St<ExpandedArgParams>;
+struct CompositeAccess {
+ pub typ: ast::MovVectorType,
+ pub dst: spirv::Word,
+ pub src: spirv::Word,
+ pub index: u32,
+ pub is_write: bool
+}
+
+struct CompositeWrite {
+ pub typ: ast::MovVectorType,
+ pub dst: spirv::Word,
+ pub src_composite: spirv::Word,
+ pub src_scalar: spirv::Word,
+ pub index: u32,
+}
+
+struct CompositeRead {
+ pub typ: ast::MovVectorType,
+ pub dst: spirv::Word,
+ pub src: spirv::Word,
+ pub index: u32,
+}
+
struct ConstantDefinition {
pub dst: spirv::Word,
pub typ: ast::ScalarType,
@@ -2028,6 +2230,7 @@ struct BrachCondition {
if_false: spirv::Word,
}
+#[derive(Copy, Clone)]
struct ImplicitConversion {
src: spirv::Word,
dst: spirv::Word,
@@ -2036,7 +2239,7 @@ struct ImplicitConversion {
kind: ConversionKind,
}
-#[derive(Debug, PartialEq)]
+#[derive(Debug, PartialEq, Copy, Clone)]
enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
@@ -2084,12 +2287,14 @@ impl<T: ArgParamsEx> ast::Arg1<T> {
t: Option<ast::Type>,
) -> ast::Arg1<U> {
ast::Arg1 {
- src: visitor.variable(ArgumentDescriptor {
- op: self.src,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
+ src: visitor.variable(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
}
@@ -2098,43 +2303,51 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg2<U> {
ast::Arg2 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: t,
- is_dst: true,
- is_pointer: false,
- }),
- src: visitor.operand(ArgumentDescriptor {
- op: self.src,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(t),
+ ),
+ src: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
fn map_ld<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
is_src_pointer: bool,
) -> ast::Arg2<U> {
ast::Arg2 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: t,
- is_dst: true,
- is_pointer: false,
- }),
- src: visitor.operand(ArgumentDescriptor {
- op: self.src,
- typ: t,
- is_dst: false,
- is_pointer: is_src_pointer,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(t),
+ ),
+ src: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_pointer: is_src_pointer,
+ },
+ t,
+ ),
}
}
@@ -2145,18 +2358,22 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
src_t: ast::Type,
) -> ast::Arg2<U> {
ast::Arg2 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: Some(dst_t),
- is_dst: true,
- is_pointer: false,
- }),
- src: visitor.operand(ArgumentDescriptor {
- op: self.src,
- typ: Some(src_t),
- is_dst: false,
- is_pointer: false,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(dst_t),
+ ),
+ src: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ src_t,
+ ),
}
}
}
@@ -2165,22 +2382,26 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
param_space: bool,
) -> ast::Arg2St<U> {
ast::Arg2St {
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: param_space,
- is_pointer: !param_space,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: param_space,
+ is_pointer: !param_space,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
}
@@ -2189,107 +2410,149 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: ast::MovVectorType,
) -> ast::Arg2Vec<U> {
match self {
ast::Arg2Vec::Dst(dst, src) => ast::Arg2Vec::Dst(
- visitor.src_vec_operand(ArgumentDescriptor {
- op: dst,
- typ: Some(t),
- is_dst: true,
- is_pointer: false,
- }),
- visitor.variable(ArgumentDescriptor {
- op: src,
- typ: Some(t),
- is_dst: false,
- is_pointer: false,
- }),
+ visitor.src_vec_operand(
+ ArgumentDescriptor {
+ op: dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ t,
+ ),
+ visitor.variable(
+ ArgumentDescriptor {
+ op: src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ Some(ast::Type::Scalar(t.into())),
+ ),
),
- ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src (
- visitor.variable(ArgumentDescriptor {
- op: dst,
- typ: Some(t),
- is_dst: true,
- is_pointer: false,
- }),
- visitor.src_vec_operand(ArgumentDescriptor {
- op: src,
- typ: Some(t),
- is_dst: false,
- is_pointer: false,
- }),
+ ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src(
+ visitor.variable(
+ ArgumentDescriptor {
+ op: dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(ast::Type::Scalar(t.into())),
+ ),
+ visitor.src_vec_operand(
+ ArgumentDescriptor {
+ op: src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
),
- ast::Arg2Vec::Both(dst, src) => ast::Arg2Vec::Both (
- visitor.src_vec_operand(ArgumentDescriptor {
- op: dst,
- typ: Some(t),
- is_dst: true,
- is_pointer: false,
- }),
- visitor.src_vec_operand(ArgumentDescriptor {
- op: src,
- typ: Some(t),
- is_dst: false,
- is_pointer: false,
- }),
+ ast::Arg2Vec::Both(dst, src) => ast::Arg2Vec::Both(
+ visitor.src_vec_operand(
+ ArgumentDescriptor {
+ op: dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ t,
+ ),
+ visitor.src_vec_operand(
+ ArgumentDescriptor {
+ op: src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
),
}
}
}
+impl ast::Arg2Vec<ExpandedArgParams> {
+ fn dst(&self) -> spirv::Word {
+ match self {
+ ast::Arg2Vec::Dst(dst, _) | ast::Arg2Vec::Src(dst, _) | ast::Arg2Vec::Both(dst, _) => {
+ *dst
+ }
+ }
+ }
+
+ fn src(&self) -> spirv::Word {
+ match self {
+ ast::Arg2Vec::Dst(_, src) | ast::Arg2Vec::Src(_, src) | ast::Arg2Vec::Both(_, src) => {
+ *src
+ }
+ }
+ }
+}
+
impl<T: ArgParamsEx> ast::Arg3<T> {
fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg3<U> {
ast::Arg3 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: t,
- is_dst: true,
- is_pointer: false,
- }),
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(t),
+ ),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
fn map_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg3<U> {
ast::Arg3 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: t,
- is_dst: true,
- is_pointer: false,
- }),
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: Some(ast::Type::Scalar(ast::ScalarType::U32)),
- is_dst: false,
- is_pointer: false,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(t),
+ ),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ),
}
}
}
@@ -2298,35 +2561,43 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg4<U> {
ast::Arg4 {
- dst1: visitor.variable(ArgumentDescriptor {
- op: self.dst1,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
- is_dst: true,
- is_pointer: false,
- }),
- dst2: self.dst2.map(|dst2| {
- visitor.variable(ArgumentDescriptor {
- op: dst2,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ dst1: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst1,
is_dst: true,
is_pointer: false,
- })
- }),
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: t,
- is_dst: false,
- is_pointer: false,
+ },
+ Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ ),
+ dst2: self.dst2.map(|dst2| {
+ visitor.variable(
+ ArgumentDescriptor {
+ op: dst2,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ )
}),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
}
@@ -2335,41 +2606,51 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg5<U> {
ast::Arg5 {
- dst1: visitor.variable(ArgumentDescriptor {
- op: self.dst1,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
- is_dst: true,
- is_pointer: false,
- }),
- dst2: self.dst2.map(|dst2| {
- visitor.variable(ArgumentDescriptor {
- op: dst2,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ dst1: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst1,
is_dst: true,
is_pointer: false,
- })
- }),
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src3: visitor.operand(ArgumentDescriptor {
- op: self.src3,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
- is_dst: false,
- is_pointer: false,
+ },
+ Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ ),
+ dst2: self.dst2.map(|dst2| {
+ visitor.variable(
+ ArgumentDescriptor {
+ op: dst2,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ )
}),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src3: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src3,
+ is_dst: false,
+ is_pointer: false,
+ },
+ ast::Type::Scalar(ast::ScalarType::Pred),
+ ),
}
}
}
@@ -2395,9 +2676,9 @@ impl ast::StStateSpace {
}
}
-#[derive(Clone, Copy, PartialEq)]
+#[derive(Clone, Copy, PartialEq, Eq)]
enum ScalarKind {
- Byte,
+ Bit,
Unsigned,
Signed,
Float,
@@ -2438,10 +2719,10 @@ impl ast::ScalarType {
ast::ScalarType::S16 => ScalarKind::Signed,
ast::ScalarType::S32 => ScalarKind::Signed,
ast::ScalarType::S64 => ScalarKind::Signed,
- ast::ScalarType::B8 => ScalarKind::Byte,
- ast::ScalarType::B16 => ScalarKind::Byte,
- ast::ScalarType::B32 => ScalarKind::Byte,
- ast::ScalarType::B64 => ScalarKind::Byte,
+ ast::ScalarType::B8 => ScalarKind::Bit,
+ ast::ScalarType::B16 => ScalarKind::Bit,
+ ast::ScalarType::B32 => ScalarKind::Bit,
+ ast::ScalarType::B64 => ScalarKind::Bit,
ast::ScalarType::F16 => ScalarKind::Float,
ast::ScalarType::F32 => ScalarKind::Float,
ast::ScalarType::F64 => ScalarKind::Float,
@@ -2458,7 +2739,7 @@ impl ast::ScalarType {
8 => ast::ScalarType::F64,
_ => unreachable!(),
},
- ScalarKind::Byte => match width {
+ ScalarKind::Bit => match width {
1 => ast::ScalarType::B8,
2 => ast::ScalarType::B16,
4 => ast::ScalarType::B32,
@@ -2574,22 +2855,159 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
return false;
}
match inst.kind() {
- ScalarKind::Byte => operand.kind() != ScalarKind::Byte,
- ScalarKind::Float => operand.kind() == ScalarKind::Byte,
+ ScalarKind::Bit => operand.kind() != ScalarKind::Bit,
+ ScalarKind::Float => operand.kind() == ScalarKind::Bit,
ScalarKind::Signed => {
- operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Unsigned
+ operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned
}
ScalarKind::Unsigned => {
- operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Signed
+ operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed
}
- ScalarKind::Float2 => todo!(),
+ ScalarKind::Float2 => false,
ScalarKind::Pred => false,
}
}
+ (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
+ | (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => {
+ should_bitcast(ast::Type::Scalar(inst), ast::Type::Scalar(operand))
+ }
_ => false,
}
}
+fn insert_with_conversions<T, ToInstruction: FnOnce(T) -> ast::Instruction<ExpandedArgParams>>(
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
+ mut instr: T,
+ pre_conv_src: impl ExactSizeIterator<Item = ImplicitConversion>,
+ pre_conv_dst: impl ExactSizeIterator<Item = ImplicitConversion>,
+ mut post_conv: Vec<ImplicitConversion>,
+ mut src: impl FnMut(&mut T) -> &mut spirv::Word,
+ mut dst: impl FnMut(&mut T) -> &mut spirv::Word,
+ to_inst: ToInstruction,
+) {
+ insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_src, &mut src);
+ insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_dst, &mut dst);
+ if post_conv.len() > 0 {
+ let new_id = id_def.new_id(Some(post_conv[0].from));
+ post_conv[0].src = new_id;
+ post_conv.last_mut().unwrap().dst = *dst(&mut instr);
+ *dst(&mut instr) = new_id;
+ }
+ func.push(Statement::Instruction(to_inst(instr)));
+ for conv in post_conv {
+ func.push(Statement::Conversion(conv));
+ }
+}
+
+fn insert_with_conversions_pre_conv<T>(
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
+ mut instr: &mut T,
+ pre_conv: impl ExactSizeIterator<Item = ImplicitConversion>,
+ src: &mut impl FnMut(&mut T) -> &mut spirv::Word,
+) {
+ let pre_conv_len = pre_conv.len();
+ for (i, mut conv) in pre_conv.enumerate() {
+ let original_src = src(&mut instr);
+ if i == 0 {
+ conv.src = *original_src;
+ }
+ if i == pre_conv_len - 1 {
+ let new_id = id_def.new_id(Some(conv.to));
+ conv.dst = new_id;
+ *original_src = new_id;
+ }
+ func.push(Statement::Conversion(conv));
+ }
+}
+
+fn get_implicit_conversions_ld_dst<
+ ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
+>(
+ id_def: &mut NumericIdResolver,
+ instr_type: ast::Type,
+ dst: spirv::Word,
+ should_convert: ShouldConvert,
+ in_reverse: bool,
+) -> Option<ImplicitConversion> {
+ let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!());
+ if let Some(conv) = should_convert(dst_type, instr_type) {
+ Some(ImplicitConversion {
+ src: u32::max_value(),
+ dst: u32::max_value(),
+ from: if !in_reverse { dst_type } else { instr_type },
+ to: if !in_reverse { instr_type } else { dst_type },
+ kind: conv,
+ })
+ } else {
+ None
+ }
+}
+
+fn get_implicit_conversions_ld_src(
+ id_def: &mut NumericIdResolver,
+ instr_type: ast::Type,
+ state_space: ast::LdStateSpace,
+ src: spirv::Word,
+) -> Vec<ImplicitConversion> {
+ let src_type = id_def.get_type(src).unwrap_or_else(|| todo!());
+ match state_space {
+ ast::LdStateSpace::Param => {
+ if src_type != instr_type {
+ vec![
+ ImplicitConversion {
+ src: u32::max_value(),
+ dst: u32::max_value(),
+ from: src_type,
+ to: instr_type,
+ kind: ConversionKind::Default,
+ };
+ 1
+ ]
+ } else {
+ Vec::new()
+ }
+ }
+ ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
+ let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
+ mem::size_of::<usize>() as u8,
+ ScalarKind::Bit,
+ ));
+ let mut result = Vec::new();
+ // HACK ALERT
+ // IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an
+ // additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier
+ // TODO: error out if the src is not B64/U64/S64
+ if let ast::Type::Scalar(scalar_src_type) = src_type {
+ if scalar_src_type.kind() == ScalarKind::Signed {
+ result.push(ImplicitConversion {
+ src: u32::max_value(),
+ dst: u32::max_value(),
+ from: src_type,
+ to: new_src_type,
+ kind: ConversionKind::Default,
+ });
+ }
+ }
+ result.push(ImplicitConversion {
+ src: u32::max_value(),
+ dst: u32::max_value(),
+ from: src_type,
+ to: instr_type,
+ kind: ConversionKind::Ptr(state_space),
+ });
+ if result.len() == 2 {
+ let new_id = id_def.new_id(Some(new_src_type));
+ result[0].dst = new_id;
+ result[1].src = new_id;
+ result[1].from = new_src_type;
+ }
+ result
+ }
+ _ => todo!(),
+ }
+}
fn insert_implicit_conversions_ld_src(
func: &mut Vec<ExpandedStatement>,
instr_type: ast::Type,
@@ -2608,7 +3026,7 @@ fn insert_implicit_conversions_ld_src(
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
mem::size_of::<usize>() as u8,
- ScalarKind::Byte,
+ ScalarKind::Bit,
));
let new_src = insert_implicit_conversions_ld_src_impl(
func,
@@ -2640,8 +3058,8 @@ fn insert_implicit_conversions_ld_src_impl<
should_convert: ShouldConvert,
) -> spirv::Word {
let src_type = id_def.get_type(src);
- if let Some(conv) = should_convert(src_type, instr_type) {
- insert_conversion_src(func, id_def, src, src_type, instr_type, conv)
+ if let Some(conv) = should_convert(src_type.unwrap(), instr_type) {
+ insert_conversion_src(func, id_def, src, src_type.unwrap(), instr_type, conv)
} else {
src
}
@@ -2692,14 +3110,15 @@ fn insert_conversion_src(
temp_src
}
+/*
fn insert_with_implicit_conversion_dst<
T,
- ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
+ ShouldConvert: FnOnce(ast::StateSpace, ast::Type, ast::Type) -> Option<ConversionKind>,
Setter: Fn(&mut T) -> &mut spirv::Word,
ToInstruction: FnOnce(T) -> ast::Instruction<ExpandedArgParams>,
>(
func: &mut Vec<ExpandedStatement>,
- instr_type: ast::ScalarType,
+ instr_type: ast::Type,
id_def: &mut NumericIdResolver,
should_convert: ShouldConvert,
mut t: T,
@@ -2708,13 +3127,14 @@ fn insert_with_implicit_conversion_dst<
) {
let dst = setter(&mut t);
let dst_type = id_def.get_type(*dst);
- let dst_coercion = should_convert(dst_type, instr_type)
- .map(|conv| get_conversion_dst(id_def, dst, ast::Type::Scalar(instr_type), dst_type, conv));
+ let dst_coercion = should_convert(dst_type.unwrap(), instr_type)
+ .map(|conv| get_conversion_dst(id_def, dst, instr_type, dst_type.unwrap(), conv));
func.push(Statement::Instruction(to_inst(t)));
if let Some(conv) = dst_coercion {
func.push(conv);
}
}
+*/
#[must_use]
fn get_conversion_dst(
@@ -2739,14 +3159,14 @@ fn get_conversion_dst(
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
fn should_convert_relaxed_src(
src_type: ast::Type,
- instr_type: ast::ScalarType,
+ instr_type: ast::Type,
) -> Option<ConversionKind> {
- if src_type == ast::Type::Scalar(instr_type) {
+ if src_type == instr_type {
return None;
}
- match src_type {
- ast::Type::Scalar(src_type) => match instr_type.kind() {
- ScalarKind::Byte => {
+ match (src_type, instr_type) {
+ (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
+ ScalarKind::Bit => {
if instr_type.width() <= src_type.width() {
Some(ConversionKind::Default)
} else {
@@ -2761,7 +3181,7 @@ fn should_convert_relaxed_src(
}
}
ScalarKind::Float => {
- if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Byte {
+ if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Bit {
Some(ConversionKind::Default)
} else {
None
@@ -2770,6 +3190,10 @@ fn should_convert_relaxed_src(
ScalarKind::Float2 => todo!(),
ScalarKind::Pred => None,
},
+ (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
+ | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
+ should_convert_relaxed_src(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type))
+ }
_ => None,
}
}
@@ -2777,14 +3201,14 @@ fn should_convert_relaxed_src(
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
dst_type: ast::Type,
- instr_type: ast::ScalarType,
+ instr_type: ast::Type,
) -> Option<ConversionKind> {
- if dst_type == ast::Type::Scalar(instr_type) {
+ if dst_type == instr_type {
return None;
}
- match dst_type {
- ast::Type::Scalar(dst_type) => match instr_type.kind() {
- ScalarKind::Byte => {
+ match (dst_type, instr_type) {
+ (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
+ ScalarKind::Bit => {
if instr_type.width() <= dst_type.width() {
Some(ConversionKind::Default)
} else {
@@ -2812,7 +3236,7 @@ fn should_convert_relaxed_dst(
}
}
ScalarKind::Float => {
- if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Byte {
+ if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Bit {
Some(ConversionKind::Default)
} else {
None
@@ -2821,6 +3245,10 @@ fn should_convert_relaxed_dst(
ScalarKind::Float2 => todo!(),
ScalarKind::Pred => None,
},
+ (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
+ | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
+ should_convert_relaxed_dst(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type))
+ }
_ => None,
}
}
@@ -2831,13 +3259,13 @@ fn insert_implicit_bitcasts(
stmt: impl VisitVariableExpanded,
) {
let mut dst_coercion = None;
- let instr = stmt.visit_variable_extended(&mut |mut desc| {
- let id_type_from_instr = match desc.typ {
+ let instr = stmt.visit_variable_extended(&mut |mut desc, typ| {
+ let id_type_from_instr = match typ {
Some(t) => t,
None => return desc.op,
};
- let id_actual_type = id_def.get_type(desc.op);
- if should_bitcast(id_type_from_instr, id_def.get_type(desc.op)) {
+ let id_actual_type = id_def.get_type(desc.op).unwrap();
+ if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap()) {
if desc.is_dst {
dst_coercion = Some(get_conversion_dst(
id_def,
@@ -2970,14 +3398,14 @@ mod tests {
.collect::<Vec<_>>()
}
- fn assert_conversion_table<F: Fn(ast::Type, ast::ScalarType) -> Option<ConversionKind>>(
+ fn assert_conversion_table<F: Fn(ast::Type, ast::Type) -> Option<ConversionKind>>(
table: &'static str,
f: F,
) {
let conv_table = parse_conversion_table(table);
for (instr_idx, instr_type) in SCALAR_TYPES.iter().enumerate() {
for (op_idx, op_type) in SCALAR_TYPES.iter().enumerate() {
- let conversion = f(ast::Type::Scalar(*op_type), *instr_type);
+ let conversion = f(ast::Type::Scalar(*op_type), ast::Type::Scalar(*instr_type));
if instr_idx == op_idx {
assert_eq!(conversion, None);
} else {