diff options
author | Andrzej Janik <[email protected]> | 2020-09-14 21:45:56 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-09-14 21:45:56 +0200 |
commit | bb5025c9b17e3fc46e454ca8faab1e85e0361ba8 (patch) | |
tree | 07df096e1ad16e8c9464aac17c99194e7257937e /ptx | |
parent | 48dac435400117935624aed244d1442982c874e2 (diff) | |
download | ZLUDA-bb5025c9b17e3fc46e454ca8faab1e85e0361ba8.tar.gz ZLUDA-bb5025c9b17e3fc46e454ca8faab1e85e0361ba8.zip |
Refactor implicit conversions and start implementing vector extract/insert
Diffstat (limited to 'ptx')
-rw-r--r-- | ptx/src/ast.rs | 14 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 15 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/call.spvtxt | 52 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/cvta.spvtxt | 24 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/ld_st_implicit.ptx | 20 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/ld_st_implicit.spvtxt | 48 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 64 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/not.spvtxt | 16 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/shl.spvtxt | 16 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/vector.spvtxt | 117 | ||||
-rw-r--r-- | ptx/src/translate.rs | 1270 |
11 files changed, 1117 insertions, 539 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 7921930..078cb31 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -349,6 +349,7 @@ pub trait ArgParams { type ID; type Operand; type CallOperand; + type VecOperand; } pub struct ParsedArgParams<'a> { @@ -359,6 +360,7 @@ impl<'a> ArgParams for ParsedArgParams<'a> { type ID = &'a str; type Operand = Operand<&'a str>; type CallOperand = CallOperand<&'a str>; + type VecOperand = (&'a str, u8); } pub struct Arg1<P: ArgParams> { @@ -376,9 +378,9 @@ pub struct Arg2St<P: ArgParams> { } pub enum Arg2Vec<P: ArgParams> { - Dst((P::ID, u8), P::ID), - Src(P::ID, (P::ID, u8)), - Both((P::ID, u8), (P::ID, u8)), + Dst(P::VecOperand, P::ID), + Src(P::ID, P::VecOperand), + Both(P::VecOperand, P::VecOperand), } pub struct Arg3<P: ArgParams> { @@ -424,8 +426,7 @@ pub struct LdData { pub qualifier: LdStQualifier, pub state_space: LdStateSpace, pub caching: LdCacheOperator, - pub vector: Option<u8>, - pub typ: ScalarType, + pub typ: Type, } #[derive(Copy, Clone, PartialEq, Eq)] @@ -710,8 +711,7 @@ pub struct StData { pub qualifier: LdStQualifier, pub state_space: StStateSpace, pub caching: StCacheOperator, - pub vector: Option<u8>, - pub typ: ScalarType, + pub typ: Type, } #[derive(PartialEq, Eq, Copy, Clone)] diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index fd419f5..6e5f5e3 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -269,10 +269,10 @@ ScalarType: ast::ScalarType = { ".f16" => ast::ScalarType::F16, ".f16x2" => ast::ScalarType::F16x2, ".pred" => ast::ScalarType::Pred, - MemoryType + LdStScalarType }; -MemoryType: ast::ScalarType = { +LdStScalarType: ast::ScalarType = { ".b8" => ast::ScalarType::B8, ".b16" => ast::ScalarType::B16, ".b32" => ast::ScalarType::B32, @@ -446,13 +446,12 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = { - "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ExtendedID> "," "[" <src:Operand> "]" => { + "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:ExtendedID> "," "[" <src:Operand> "]" => { ast::Instruction::Ld( ast::LdData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), state_space: ss.unwrap_or(ast::LdStateSpace::Generic), caching: cop.unwrap_or(ast::LdCacheOperator::Cached), - vector: v, typ: t }, ast::Arg2 { dst:dst, src:src } @@ -460,6 +459,11 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = { } }; +LdStType: ast::Type = { + <v:VectorPrefix> <t:LdStScalarType> => ast::Type::Vector(t, v), + <t:LdStScalarType> => ast::Type::Scalar(t), +} + LdStQualifier: ast::LdStQualifier = { ".weak" => ast::LdStQualifier::Weak, ".volatile" => ast::LdStQualifier::Volatile, @@ -895,13 +899,12 @@ ShlType: ast::ShlType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // Warning: NVIDIA documentation is incorrect, you can specify scope only once InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = { - "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <src1:Operand> "]" "," <src2:Operand> => { + "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> "[" <src1:Operand> "]" "," <src2:Operand> => { ast::Instruction::St( ast::StData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), state_space: ss.unwrap_or(ast::StStateSpace::Generic), caching: cop.unwrap_or(ast::StCacheOperator::Writeback), - vector: v, typ: t }, ast::Arg2St { src1:src1, src2:src2 } diff --git a/ptx/src/test/spirv_run/call.spvtxt b/ptx/src/test/spirv_run/call.spvtxt index 001cda3..ca4685a 100644 --- a/ptx/src/test/spirv_run/call.spvtxt +++ b/ptx/src/test/spirv_run/call.spvtxt @@ -4,20 +4,20 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 - %45 = OpExtInstImport "OpenCL.std" + %47 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %4 "call" %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %48 = OpTypeFunction %void %ulong %ulong + %50 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %51 = OpTypeFunction %ulong %ulong + %53 = OpTypeFunction %ulong %ulong %ulong_1 = OpConstant %ulong 1 - %4 = OpFunction %void None %48 + %4 = OpFunction %void None %50 %12 = OpFunctionParameter %ulong %13 = OpFunctionParameter %ulong - %30 = OpLabel + %32 = OpLabel %5 = OpVariable %_ptr_Function_ulong Function %6 = OpVariable %_ptr_Function_ulong Function %7 = OpVariable %_ptr_Function_ulong Function @@ -38,7 +38,9 @@ %18 = OpLoad %ulong %28 OpStore %9 %18 %21 = OpLoad %ulong %9 - %20 = OpCopyObject %ulong %21 + %29 = OpCopyObject %ulong %21 + %30 = OpCopyObject %ulong %29 + %20 = OpCopyObject %ulong %30 OpStore %10 %20 %23 = OpLoad %ulong %10 %22 = OpFunctionCall %ulong %1 %23 @@ -48,26 +50,26 @@ OpStore %9 %24 %26 = OpLoad %ulong %8 %27 = OpLoad %ulong %9 - %29 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %26 - OpStore %29 %27 + %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %26 + OpStore %31 %27 OpReturn OpFunctionEnd - %1 = OpFunction %ulong None %51 - %34 = OpFunctionParameter %ulong - %43 = OpLabel - %32 = OpVariable %_ptr_Function_ulong Function - %31 = OpVariable %_ptr_Function_ulong Function + %1 = OpFunction %ulong None %53 + %36 = OpFunctionParameter %ulong + %45 = OpLabel + %34 = OpVariable %_ptr_Function_ulong Function %33 = OpVariable %_ptr_Function_ulong Function - OpStore %32 %34 - %36 = OpLoad %ulong %32 - %35 = OpCopyObject %ulong %36 - OpStore %33 %35 - %38 = OpLoad %ulong %33 - %37 = OpIAdd %ulong %38 %ulong_1 - OpStore %33 %37 - %40 = OpLoad %ulong %33 - %39 = OpCopyObject %ulong %40 - OpStore %31 %39 - %41 = OpLoad %ulong %31 - OpReturnValue %41 + %35 = OpVariable %_ptr_Function_ulong Function + OpStore %34 %36 + %38 = OpLoad %ulong %34 + %37 = OpCopyObject %ulong %38 + OpStore %35 %37 + %40 = OpLoad %ulong %35 + %39 = OpIAdd %ulong %40 %ulong_1 + OpStore %35 %39 + %42 = OpLoad %ulong %35 + %41 = OpCopyObject %ulong %42 + OpStore %33 %41 + %43 = OpLoad %ulong %33 + OpReturnValue %43 OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvta.spvtxt b/ptx/src/test/spirv_run/cvta.spvtxt index e708613..84e7eac 100644 --- a/ptx/src/test/spirv_run/cvta.spvtxt +++ b/ptx/src/test/spirv_run/cvta.spvtxt @@ -4,20 +4,20 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 - %25 = OpExtInstImport "OpenCL.std" + %29 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "cvta" %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %28 = OpTypeFunction %void %ulong %ulong + %32 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %float = OpTypeFloat 32 %_ptr_Function_float = OpTypePointer Function %float %_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float - %1 = OpFunction %void None %28 + %1 = OpFunction %void None %32 %7 = OpFunctionParameter %ulong %8 = OpFunctionParameter %ulong - %23 = OpLabel + %27 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -32,18 +32,22 @@ %11 = OpCopyObject %ulong %12 OpStore %5 %11 %14 = OpLoad %ulong %4 - %13 = OpCopyObject %ulong %14 + %22 = OpCopyObject %ulong %14 + %21 = OpCopyObject %ulong %22 + %13 = OpCopyObject %ulong %21 OpStore %4 %13 %16 = OpLoad %ulong %5 - %15 = OpCopyObject %ulong %16 + %24 = OpCopyObject %ulong %16 + %23 = OpCopyObject %ulong %24 + %15 = OpCopyObject %ulong %23 OpStore %5 %15 %18 = OpLoad %ulong %4 - %21 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18 - %17 = OpLoad %float %21 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18 + %17 = OpLoad %float %25 OpStore %6 %17 %19 = OpLoad %ulong %5 %20 = OpLoad %float %6 - %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19 - OpStore %22 %20 + %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19 + OpStore %26 %20 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/ld_st_implicit.ptx b/ptx/src/test/spirv_run/ld_st_implicit.ptx new file mode 100644 index 0000000..8562286 --- /dev/null +++ b/ptx/src/test/spirv_run/ld_st_implicit.ptx @@ -0,0 +1,20 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry ld_st_implicit(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .b64 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.global.f32 temp, [in_addr];
+ st.global.f32 [out_addr], temp;
+ ret;
+}
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/ld_st_implicit.spvtxt b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt new file mode 100644 index 0000000..e7dba5a --- /dev/null +++ b/ptx/src/test/spirv_run/ld_st_implicit.spvtxt @@ -0,0 +1,48 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %23 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "ld_st_implicit" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %26 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %float = OpTypeFloat 32 +%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float + %uint = OpTypeInt 32 0 + %1 = OpFunction %void None %26 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %21 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %7 + OpStore %3 %8 + %10 = OpLoad %ulong %2 + %9 = OpCopyObject %ulong %10 + OpStore %4 %9 + %12 = OpLoad %ulong %3 + %11 = OpCopyObject %ulong %12 + OpStore %5 %11 + %14 = OpLoad %ulong %4 + %17 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %14 + %18 = OpLoad %float %17 + %30 = OpBitcast %ulong %18 + %32 = OpUConvert %uint %30 + %13 = OpBitcast %uint %32 + OpStore %6 %13 + %15 = OpLoad %ulong %5 + %16 = OpLoad %ulong %6 + %33 = OpBitcast %uint %16 + %19 = OpUConvert %ulong %33 + %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %15 + OpStore %20 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index a04f0eb..fd50d3c 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -8,10 +8,12 @@ use spirv_headers::Word; use spirv_tools_sys::{
spv_binary, spv_endianness_t, spv_parsed_instruction_t, spv_result_t, spv_target_env,
};
+use std::collections::hash_map::Entry;
use std::error;
use std::ffi::{c_void, CStr, CString};
use std::fmt;
use std::fmt::{Debug, Display, Formatter};
+use std::hash::Hash;
use std::mem;
use std::slice;
use std::{borrow::Cow, collections::HashMap, env, fs, path::PathBuf, ptr, str};
@@ -41,6 +43,7 @@ macro_rules! test_ptx { }
test_ptx!(ld_st, [1u64], [1u64]);
+test_ptx!(ld_st_implicit, [0.5f32], [0.5f32]);
test_ptx!(mov, [1u64], [1u64]);
test_ptx!(mul_lo, [1u64], [2u64]);
test_ptx!(mul_hi, [u64::max_value()], [1u64]);
@@ -214,14 +217,45 @@ fn test_spvtxt_assert<'a>( }
}
}
- panic!(spirv_text);
+ panic!(spirv_text.to_string());
}
unsafe { spirv_tools::spvContextDestroy(spv_context) };
Ok(())
}
+struct EqMap<T>
+where
+ T: Eq + Copy + Hash,
+{
+ m1: HashMap<T, T>,
+ m2: HashMap<T, T>,
+}
+
+impl<T: Copy + Eq + Hash> EqMap<T> {
+ fn new() -> Self {
+ EqMap {
+ m1: HashMap::new(),
+ m2: HashMap::new(),
+ }
+ }
+
+ fn is_equal(&mut self, t1: T, t2: T) -> bool {
+ match (self.m1.entry(t1), self.m2.entry(t2)) {
+ (Entry::Occupied(entry1), Entry::Occupied(entry2)) => {
+ *entry1.get() == t2 && *entry2.get() == t1
+ }
+ (Entry::Vacant(entry1), Entry::Vacant(entry2)) => {
+ entry1.insert(t2);
+ entry2.insert(t1);
+ true
+ }
+ _ => false,
+ }
+ }
+}
+
fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool {
- let mut map = HashMap::new();
+ let mut map = EqMap::new();
if !is_option_equal(&fn1.def, &fn2.def, &mut map, is_instr_equal) {
return false;
}
@@ -247,7 +281,7 @@ fn is_spirv_fn_equal(fn1: &Function, fn2: &Function) -> bool { true
}
-fn is_block_equal(b1: &Block, b2: &Block, map: &mut HashMap<Word, Word>) -> bool {
+fn is_block_equal(b1: &Block, b2: &Block, map: &mut EqMap<Word>) -> bool {
if !is_option_equal(&b1.label, &b2.label, map, is_instr_equal) {
return false;
}
@@ -262,11 +296,7 @@ fn is_block_equal(b1: &Block, b2: &Block, map: &mut HashMap<Word, Word>) -> bool true
}
-fn is_instr_equal(
- instr1: &Instruction,
- instr2: &Instruction,
- map: &mut HashMap<Word, Word>,
-) -> bool {
+fn is_instr_equal(instr1: &Instruction, instr2: &Instruction, map: &mut EqMap<Word>) -> bool {
if instr1.class.opcode != instr2.class.opcode {
return false;
}
@@ -306,24 +336,14 @@ fn is_instr_equal( true
}
-fn is_word_equal(w1: &Word, w2: &Word, map: &mut HashMap<Word, Word>) -> bool {
- match map.entry(*w1) {
- std::collections::hash_map::Entry::Occupied(entry) => {
- if entry.get() != w2 {
- return false;
- }
- }
- std::collections::hash_map::Entry::Vacant(entry) => {
- entry.insert(*w2);
- }
- }
- true
+fn is_word_equal(t1: &Word, t2: &Word, map: &mut EqMap<Word>) -> bool {
+ map.is_equal(*t1, *t2)
}
-fn is_option_equal<T, F: FnOnce(&T, &T, &mut HashMap<Word, Word>) -> bool>(
+fn is_option_equal<T, F: FnOnce(&T, &T, &mut EqMap<Word>) -> bool>(
o1: &Option<T>,
o2: &Option<T>,
- map: &mut HashMap<Word, Word>,
+ map: &mut EqMap<Word>,
f: F,
) -> bool {
match (o1, o2) {
diff --git a/ptx/src/test/spirv_run/not.spvtxt b/ptx/src/test/spirv_run/not.spvtxt index de340ed..b358858 100644 --- a/ptx/src/test/spirv_run/not.spvtxt +++ b/ptx/src/test/spirv_run/not.spvtxt @@ -4,18 +4,18 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 - %24 = OpExtInstImport "OpenCL.std" + %26 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "not" %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %27 = OpTypeFunction %void %ulong %ulong + %29 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Generic_ulong = OpTypePointer Generic %ulong - %1 = OpFunction %void None %27 + %1 = OpFunction %void None %29 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong - %22 = OpLabel + %24 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -35,11 +35,13 @@ %14 = OpLoad %ulong %20 OpStore %6 %14 %17 = OpLoad %ulong %6 - %16 = OpNot %ulong %17 + %22 = OpCopyObject %ulong %17 + %21 = OpNot %ulong %22 + %16 = OpCopyObject %ulong %21 OpStore %7 %16 %18 = OpLoad %ulong %5 %19 = OpLoad %ulong %7 - %21 = OpConvertUToPtr %_ptr_Generic_ulong %18 - OpStore %21 %19 + %23 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %23 %19 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shl.spvtxt b/ptx/src/test/spirv_run/shl.spvtxt index dbd2664..4843a65 100644 --- a/ptx/src/test/spirv_run/shl.spvtxt +++ b/ptx/src/test/spirv_run/shl.spvtxt @@ -4,20 +4,20 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 - %25 = OpExtInstImport "OpenCL.std" + %27 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "shl" %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %28 = OpTypeFunction %void %ulong %ulong + %30 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Generic_ulong = OpTypePointer Generic %ulong %uint = OpTypeInt 32 0 %uint_2 = OpConstant %uint 2 - %1 = OpFunction %void None %28 + %1 = OpFunction %void None %30 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong - %23 = OpLabel + %25 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -37,11 +37,13 @@ %14 = OpLoad %ulong %21 OpStore %6 %14 %17 = OpLoad %ulong %6 - %16 = OpShiftLeftLogical %ulong %17 %uint_2 + %23 = OpCopyObject %ulong %17 + %22 = OpShiftLeftLogical %ulong %23 %uint_2 + %16 = OpCopyObject %ulong %22 OpStore %7 %16 %18 = OpLoad %ulong %5 %19 = OpLoad %ulong %7 - %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 - OpStore %22 %19 + %24 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %24 %19 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt index 6810fec..25dd80e 100644 --- a/ptx/src/test/spirv_run/vector.spvtxt +++ b/ptx/src/test/spirv_run/vector.spvtxt @@ -4,43 +4,92 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 - %25 = OpExtInstImport "OpenCL.std" + %58 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "add" + OpEntryPoint Kernel %31 "vector" %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %v2uint = OpTypeVector %uint 2 + %62 = OpTypeFunction %v2uint %v2uint +%_ptr_Function_v2uint = OpTypePointer Function %v2uint +%_ptr_Function_uint = OpTypePointer Function %uint %ulong = OpTypeInt 64 0 - %28 = OpTypeFunction %void %ulong %ulong + %66 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_1 = OpConstant %ulong 1 - %1 = OpFunction %void None %28 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %23 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %11 = OpLoad %ulong %2 - %10 = OpCopyObject %ulong %11 - OpStore %4 %10 - %13 = OpLoad %ulong %3 - %12 = OpCopyObject %ulong %13 - OpStore %5 %12 - %15 = OpLoad %ulong %4 - %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 - %14 = OpLoad %ulong %21 - OpStore %6 %14 - %17 = OpLoad %ulong %6 - %16 = OpIAdd %ulong %17 %ulong_1 - OpStore %7 %16 - %18 = OpLoad %ulong %5 - %19 = OpLoad %ulong %7 - %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 - OpStore %22 %19 +%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint + %1 = OpFunction %v2uint None %62 + %7 = OpFunctionParameter %v2uint + %30 = OpLabel + %3 = OpVariable %_ptr_Function_v2uint Function + %2 = OpVariable %_ptr_Function_v2uint Function + %4 = OpVariable %_ptr_Function_v2uint Function + %5 = OpVariable %_ptr_Function_uint Function + %6 = OpVariable %_ptr_Function_uint Function + OpStore %3 %7 + %9 = OpLoad %v2uint %3 + %24 = OpCompositeExtract %uint %9 0 + %8 = OpCopyObject %uint %24 + OpStore %5 %8 + %11 = OpLoad %v2uint %3 + %25 = OpCompositeExtract %uint %11 1 + %10 = OpCopyObject %uint %25 + OpStore %6 %10 + %13 = OpLoad %uint %5 + %14 = OpLoad %uint %6 + %12 = OpIAdd %uint %13 %14 + OpStore %6 %12 + %16 = OpLoad %uint %6 + %26 = OpCopyObject %uint %16 + %15 = OpCompositeInsert %uint %26 %15 0 + OpStore %4 %15 + %18 = OpLoad %uint %6 + %27 = OpCopyObject %uint %18 + %17 = OpCompositeInsert %uint %27 %17 1 + OpStore %4 %17 + %20 = OpLoad %v2uint %4 + %29 = OpCompositeExtract %uint %20 1 + %28 = OpCopyObject %uint %29 + %19 = OpCompositeInsert %uint %28 %19 0 + OpStore %4 %19 + %22 = OpLoad %v2uint %4 + %21 = OpCopyObject %v2uint %22 + OpStore %2 %21 + %23 = OpLoad %v2uint %2 + OpReturnValue %23 + OpFunctionEnd + %31 = OpFunction %void None %66 + %40 = OpFunctionParameter %ulong + %41 = OpFunctionParameter %ulong + %56 = OpLabel + %32 = OpVariable %_ptr_Function_ulong Function + %33 = OpVariable %_ptr_Function_ulong Function + %34 = OpVariable %_ptr_Function_ulong Function + %35 = OpVariable %_ptr_Function_ulong Function + %36 = OpVariable %_ptr_Function_v2uint Function + %37 = OpVariable %_ptr_Function_uint Function + %38 = OpVariable %_ptr_Function_uint Function + %39 = OpVariable %_ptr_Function_ulong Function + OpStore %32 %40 + OpStore %33 %41 + %43 = OpLoad %ulong %32 + %42 = OpCopyObject %ulong %43 + OpStore %34 %42 + %45 = OpLoad %ulong %33 + %44 = OpCopyObject %ulong %45 + OpStore %35 %44 + %47 = OpLoad %ulong %34 + %54 = OpConvertUToPtr %_ptr_Generic_v2uint %47 + %46 = OpLoad %v2uint %54 + OpStore %36 %46 + %49 = OpLoad %v2uint %36 + %48 = OpFunctionCall %v2uint %1 %49 + OpStore %36 %48 + %51 = OpLoad %v2uint %36 + %50 = OpCopyObject %ulong %51 + OpStore %39 %50 + %52 = OpLoad %ulong %35 + %53 = OpLoad %v2uint %36 + %55 = OpConvertUToPtr %_ptr_Generic_v2uint %52 + OpStore %55 %53 OpReturn OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7591722..57d3485 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,5 +1,5 @@ use crate::ast;
-use rspirv::dr;
+use rspirv::{binary::Disassemble, dr};
use std::collections::{hash_map, HashMap, HashSet};
use std::{borrow::Cow, iter, mem};
@@ -398,7 +398,8 @@ fn normalize_labels( labels_in_use.insert(cond.if_true);
labels_in_use.insert(cond.if_false);
}
- Statement::Call(_)
+ Statement::Composite(_)
+ | Statement::Call(_)
| Statement::Variable(_)
| Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
@@ -528,13 +529,13 @@ fn insert_mem_ssa_statements<'a, 'b>( ast::Instruction::Ret(d) => {
if let Some(out_param) = out_param {
let typ = id_def.get_type(out_param);
- let new_id = id_def.new_id(Some(typ));
+ let new_id = id_def.new_id(typ);
result.push(Statement::LoadVar(
ast::Arg2 {
dst: new_id,
src: out_param,
},
- typ,
+ typ.unwrap(),
));
result.push(Statement::RetValue(d, new_id));
} else {
@@ -561,19 +562,25 @@ fn insert_mem_ssa_statements<'a, 'b>( | Statement::Conversion(_)
| Statement::RetValue(_, _)
| Statement::Constant(_) => unreachable!(),
+ Statement::Composite(_) => todo!(),
}
}
(f_args, result)
}
trait VisitVariable: Sized {
- fn visit_variable<'a, F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable<
+ 'a,
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> UnadornedStatement;
}
trait VisitVariableExpanded {
- fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable_extended<
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> ExpandedStatement;
@@ -585,8 +592,8 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( stmt: F,
) {
let mut post_statements = Vec::new();
- let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>| {
- let id_type = match (desc.typ, desc.is_pointer) {
+ let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>, _| {
+ let id_type = match (id_def.get_type(desc.op), desc.is_pointer) {
(Some(t), false) => t,
(Some(_), true) => ast::Type::Scalar(ast::ScalarType::B64),
(None, _) => return desc.op,
@@ -624,13 +631,15 @@ fn expand_arguments<'a, 'b>( match s {
Statement::Call(call) => {
let mut visitor = FlattenArguments::new(&mut result, id_def);
- let new_call = call.map(&mut visitor);
+ let (new_call, post_stmts) = (call.map(&mut visitor), visitor.post_stmts);
result.push(Statement::Call(new_call));
+ result.extend(post_stmts);
}
Statement::Instruction(inst) => {
let mut visitor = FlattenArguments::new(&mut result, id_def);
- let new_inst = inst.map(&mut visitor);
+ let (new_inst, post_stmts) = (inst.map(&mut visitor), visitor.post_stmts);
result.push(Statement::Instruction(new_inst));
+ result.extend(post_stmts);
}
Statement::Variable(ast::Variable {
align,
@@ -646,7 +655,9 @@ fn expand_arguments<'a, 'b>( Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
- Statement::Conversion(_) | Statement::Constant(_) => unreachable!(),
+ Statement::Composite(_) | Statement::Conversion(_) | Statement::Constant(_) => {
+ unreachable!()
+ }
}
}
result
@@ -655,74 +666,79 @@ fn expand_arguments<'a, 'b>( struct FlattenArguments<'a, 'b> {
func: &'b mut Vec<ExpandedStatement>,
id_def: &'b mut NumericIdResolver<'a>,
+ post_stmts: Vec<ExpandedStatement>,
}
impl<'a, 'b> FlattenArguments<'a, 'b> {
fn new(func: &'b mut Vec<ExpandedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
- FlattenArguments { func, id_def }
+ FlattenArguments {
+ func,
+ id_def,
+ post_stmts: Vec::new(),
+ }
}
}
impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
for FlattenArguments<'a, 'b>
{
- fn variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
+ fn variable(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ typ: Option<ast::Type>,
+ ) -> spirv::Word {
desc.op
}
- fn operand(&mut self, desc: ArgumentDescriptor<ast::Operand<spirv::Word>>) -> spirv::Word {
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ typ: ast::Type,
+ ) -> spirv::Word {
match desc.op {
- ast::Operand::Reg(r) => self.variable(desc.new_op(r)),
+ ast::Operand::Reg(r) => self.variable(desc.new_op(r), Some(typ)),
ast::Operand::Imm(x) => {
- if let Some(typ) = desc.typ {
- let scalar_t = if let ast::Type::Scalar(scalar) = typ {
- scalar
- } else {
- todo!()
- };
- let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id,
- typ: scalar_t,
- value: x,
- }));
- id
+ let scalar_t = if let ast::Type::Scalar(scalar) = typ {
+ scalar
} else {
todo!()
- }
+ };
+ let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id,
+ typ: scalar_t,
+ value: x,
+ }));
+ id
}
ast::Operand::RegOffset(reg, offset) => {
- if let Some(typ) = desc.typ {
- let scalar_t = if let ast::Type::Scalar(scalar) = typ {
- scalar
- } else {
- todo!()
- };
- let id_constant_stmt = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: scalar_t,
- value: offset as i128,
- }));
- let result_id = self.id_def.new_id(desc.typ);
- let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- ast::AddDetails::Int(ast::AddIntDesc {
- typ: int_type,
- saturate: false,
- }),
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
- result_id
+ let scalar_t = if let ast::Type::Scalar(scalar) = typ {
+ scalar
} else {
todo!()
- }
+ };
+ let id_constant_stmt = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: scalar_t,
+ value: offset as i128,
+ }));
+ let result_id = self.id_def.new_id(Some(typ));
+ let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
+ self.func.push(Statement::Instruction(
+ ast::Instruction::<ExpandedArgParams>::Add(
+ ast::AddDetails::Int(ast::AddIntDesc {
+ typ: int_type,
+ saturate: false,
+ }),
+ ast::Arg3 {
+ dst: result_id,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ ),
+ ));
+ result_id
}
}
}
@@ -730,18 +746,45 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams> fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
+ typ: ast::Type,
) -> spirv::Word {
match desc.op {
- ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg)),
- ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x))),
+ ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg), Some(typ)),
+ ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x)), typ),
}
}
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
- ) -> (spirv::Word, u8) {
- (self.variable(desc.new_op(desc.op.0)), desc.op.1)
+ typ: ast::MovVectorType,
+ ) -> spirv::Word {
+ let (vector_id, index) = desc.op;
+ let new_id = self.id_def.new_id(Some(ast::Type::Scalar(typ.into())));
+ let composite = if desc.is_dst {
+ Statement::Composite(CompositeAccess {
+ typ: typ,
+ dst: new_id,
+ src: vector_id,
+ index: index as u32,
+ is_write: true
+ })
+ } else {
+ Statement::Composite(CompositeAccess {
+ typ: typ,
+ dst: new_id,
+ src: vector_id,
+ index: index as u32,
+ is_write: false
+ })
+ };
+ if desc.is_dst {
+ self.post_stmts.push(composite);
+ new_id
+ } else {
+ self.func.push(composite);
+ new_id
+ }
}
}
@@ -768,48 +811,63 @@ fn insert_implicit_conversions( match s {
Statement::Call(call) => insert_implicit_bitcasts(&mut result, id_def, call),
Statement::Instruction(inst) => match inst {
- ast::Instruction::Ld(ld, mut arg) => {
- arg.src = insert_implicit_conversions_ld_src(
- &mut result,
- ast::Type::Scalar(ld.typ),
+ ast::Instruction::Ld(ld, arg) => {
+ let pre_conv =
+ get_implicit_conversions_ld_src(id_def, ld.typ, ld.state_space, arg.src);
+ let post_conv = get_implicit_conversions_ld_dst(
id_def,
- ld.state_space,
- arg.src,
+ ld.typ,
+ arg.dst,
+ should_convert_relaxed_dst,
+ false,
);
- insert_with_implicit_conversion_dst(
+ insert_with_conversions(
&mut result,
- ld.typ,
id_def,
- should_convert_relaxed_dst,
arg,
+ pre_conv.into_iter(),
+ iter::empty(),
+ post_conv.into_iter().collect(),
+ |arg| &mut arg.src,
|arg| &mut arg.dst,
|arg| ast::Instruction::Ld(ld, arg),
- );
+ )
}
- ast::Instruction::St(st, mut arg) => {
- let arg_src2_type = id_def.get_type(arg.src2);
- if let Some(conv) = should_convert_relaxed_src(arg_src2_type, st.typ) {
- arg.src2 = insert_conversion_src(
- &mut result,
- id_def,
- arg.src2,
- arg_src2_type,
- ast::Type::Scalar(st.typ),
- conv,
- );
- }
- arg.src1 = insert_implicit_conversions_ld_src(
- &mut result,
- ast::Type::Scalar(st.typ),
+ ast::Instruction::St(st, arg) => {
+ let pre_conv = get_implicit_conversions_ld_dst(
id_def,
+ st.typ,
+ arg.src2,
+ should_convert_relaxed_src,
+ true,
+ );
+ let post_conv = get_implicit_conversions_ld_src(
+ id_def,
+ st.typ,
st.state_space.to_ld_ss(),
arg.src1,
);
- result.push(Statement::Instruction(ast::Instruction::St(st, arg)));
+ let (pre_conv_dest, post_conv) = if st.state_space == ast::StStateSpace::Param {
+ (Vec::new(), post_conv)
+ } else {
+ (post_conv, Vec::new())
+ };
+ insert_with_conversions(
+ &mut result,
+ id_def,
+ arg,
+ pre_conv.into_iter(),
+ pre_conv_dest.into_iter(),
+ post_conv,
+ |arg| &mut arg.src2,
+ |arg| &mut arg.src1,
+ |arg| ast::Instruction::St(st, arg),
+ )
}
inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst),
},
- s @ Statement::Conditional(_)
+ s @ Statement::Composite(_)
+ | s @ Statement::Conditional(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_)
@@ -950,10 +1008,10 @@ fn emit_function_body_ops( builder.branch(arg.src)?;
}
ast::Instruction::Ld(data, arg) => {
- if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
+ if data.qualifier != ast::LdStQualifier::Weak {
todo!()
}
- let result_type = map.get_or_add_scalar(builder, data.typ);
+ let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
match data.state_space {
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
@@ -967,7 +1025,6 @@ fn emit_function_body_ops( }
ast::Instruction::St(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak
- || data.vector.is_some()
|| (data.state_space != ast::StStateSpace::Generic
&& data.state_space != ast::StStateSpace::Param
&& data.state_space != ast::StStateSpace::Global)
@@ -1030,7 +1087,10 @@ fn emit_function_body_ops( builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
ast::Instruction::SetpBool(_, _) => todo!(),
- ast::Instruction::MovVector(_, _) => todo!(),
+ ast::Instruction::MovVector(t, arg) => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
+ builder.copy_object(result_type, Some(arg.dst()), arg.src())?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(*typ));
@@ -1042,6 +1102,19 @@ fn emit_function_body_ops( Statement::RetValue(_, id) => {
builder.ret_value(*id)?;
}
+ Statement::Composite(c) => {
+ let result_type = map.get_or_add_scalar(builder, c.typ.into());
+ let result_id = Some(c.dst);
+ let indexes = [c.index];
+ if c.is_write {
+ let object = c.src;
+ let composite = c.dst;
+ builder.composite_insert(result_type, result_id, object, composite, indexes)?;
+ } else {
+ let composite = c.src;
+ builder.composite_extract(result_type, result_id, composite, indexes)?;
+ }
+ }
}
}
Ok(())
@@ -1188,7 +1261,7 @@ fn emit_setp( match (setp.cmp_op, setp.typ.kind()) {
(ast::SetpCompareOp::Eq, ScalarKind::Signed)
| (ast::SetpCompareOp::Eq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Eq, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::Eq, ScalarKind::Bit) => {
builder.i_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Eq, ScalarKind::Float) => {
@@ -1196,14 +1269,14 @@ fn emit_setp( }
(ast::SetpCompareOp::NotEq, ScalarKind::Signed)
| (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::NotEq, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => {
builder.i_not_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NotEq, ScalarKind::Float) => {
builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Less, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Less, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::Less, ScalarKind::Bit) => {
builder.u_less_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Less, ScalarKind::Signed) => {
@@ -1213,7 +1286,7 @@ fn emit_setp( builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::LessOrEq, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => {
builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => {
@@ -1223,7 +1296,7 @@ fn emit_setp( builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Greater, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Greater, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::Greater, ScalarKind::Bit) => {
builder.u_greater_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Greater, ScalarKind::Signed) => {
@@ -1233,7 +1306,7 @@ fn emit_setp( builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => {
builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => {
@@ -1294,54 +1367,56 @@ fn emit_implicit_conversion( map: &mut TypeWordMap,
cv: &ImplicitConversion,
) -> Result<(), dr::Error> {
- let (from_type, to_type) = match (cv.from, cv.to) {
- (ast::Type::Scalar(from), ast::Type::Scalar(to)) => (from, to),
- _ => todo!(),
- };
+ let from_parts = cv.from.to_parts();
+ let to_parts = cv.to.to_parts();
match cv.kind {
ConversionKind::Ptr(space) => {
let dst_type = map.get_or_add(
builder,
- SpirvType::Pointer(
- Box::new(SpirvType::Base(SpirvScalarKey::from(to_type))),
- space.to_spirv(),
- ),
+ SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()),
);
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
ConversionKind::Default => {
- if from_type.width() == to_type.width() {
- let dst_type = map.get_or_add_scalar(builder, to_type);
- if from_type.kind() != ScalarKind::Float && to_type.kind() != ScalarKind::Float {
+ if from_parts.width == to_parts.width {
+ let dst_type = map.get_or_add(builder, SpirvType::from(cv.from));
+ if from_parts.scalar_kind != ScalarKind::Float
+ && to_parts.scalar_kind != ScalarKind::Float
+ {
// It is noop, but another instruction expects result of this conversion
builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
} else {
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
}
} else {
- let as_unsigned_type = map.get_or_add_scalar(
- builder,
- ast::ScalarType::from_parts(from_type.width(), ScalarKind::Unsigned),
- );
- let as_unsigned = builder.bitcast(as_unsigned_type, None, cv.src)?;
- let as_unsigned_wide_type =
- ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned);
- let as_unsigned_wide_spirv = map.get_or_add_scalar(
+ // This block is safe because it's illegal to implictly convert between floating point instructions
+ let same_width_bit_type = map.get_or_add(
builder,
- ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned),
+ SpirvType::from(ast::Type::from_parts(TypeParts {
+ scalar_kind: ScalarKind::Bit,
+ ..from_parts
+ })),
);
- if to_type.kind() == ScalarKind::Unsigned || to_type.kind() == ScalarKind::Byte {
- builder.u_convert(as_unsigned_wide_spirv, Some(cv.dst), as_unsigned)?;
+ let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?;
+ let wide_bit_type = ast::Type::from_parts(TypeParts {
+ scalar_kind: ScalarKind::Bit,
+ ..to_parts
+ });
+ let wide_bit_type_spirv = map.get_or_add(builder, SpirvType::from(wide_bit_type));
+ if to_parts.scalar_kind == ScalarKind::Unsigned
+ || to_parts.scalar_kind == ScalarKind::Bit
+ {
+ builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?;
} else {
- let as_unsigned_wide =
- builder.u_convert(as_unsigned_wide_spirv, None, as_unsigned)?;
+ let wide_bit_value =
+ builder.u_convert(wide_bit_type_spirv, None, same_width_bit_value)?;
emit_implicit_conversion(
builder,
map,
&ImplicitConversion {
- src: as_unsigned_wide,
+ src: wide_bit_value,
dst: cv.dst,
- from: ast::Type::Scalar(as_unsigned_wide_type),
+ from: wide_bit_type,
to: cv.to,
kind: ConversionKind::Default,
},
@@ -1627,8 +1702,8 @@ struct NumericIdResolver<'b> { }
impl<'b> NumericIdResolver<'b> {
- fn get_type(&self, id: spirv::Word) -> ast::Type {
- self.type_check[&id]
+ fn get_type(&self, id: spirv::Word) -> Option<ast::Type> {
+ self.type_check.get(&id).map(|x| *x)
}
fn new_id(&mut self, typ: Option<ast::Type>) -> spirv::Word {
@@ -1648,6 +1723,7 @@ enum Statement<I, P: ast::ArgParams> { LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
Call(ResolvedCall<P>),
+ Composite(CompositeAccess),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
Conversion(ImplicitConversion),
@@ -1671,31 +1747,37 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> { .ret_params
.into_iter()
.map(|(id, typ)| {
- let new_id = visitor.variable(ArgumentDescriptor {
- op: id,
- typ: Some(typ.into()),
- is_dst: true,
- is_pointer: false,
- });
+ let new_id = visitor.variable(
+ ArgumentDescriptor {
+ op: id,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(typ.into()),
+ );
(new_id, typ)
})
.collect();
- let func = visitor.variable(ArgumentDescriptor {
- op: self.func,
- typ: None,
- is_dst: false,
- is_pointer: false,
- });
+ let func = visitor.variable(
+ ArgumentDescriptor {
+ op: self.func,
+ is_dst: false,
+ is_pointer: false,
+ },
+ None,
+ );
let param_list = self
.param_list
.into_iter()
.map(|(id, typ)| {
- let new_id = visitor.src_call_operand(ArgumentDescriptor {
- op: id,
- typ: Some(typ.into()),
- is_dst: false,
- is_pointer: false,
- });
+ let new_id = visitor.src_call_operand(
+ ArgumentDescriptor {
+ op: id,
+ is_dst: false,
+ is_pointer: false,
+ },
+ typ.into(),
+ );
(new_id, typ)
})
.collect();
@@ -1709,7 +1791,10 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> { }
impl VisitVariable for ResolvedCall<NormalizedArgParams> {
- fn visit_variable<'a, F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable<
+ 'a,
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> UnadornedStatement {
@@ -1718,7 +1803,9 @@ impl VisitVariable for ResolvedCall<NormalizedArgParams> { }
impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
- fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable_extended<
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> ExpandedStatement {
@@ -1750,6 +1837,7 @@ impl ast::ArgParams for NormalizedArgParams { type ID = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
type CallOperand = ast::CallOperand<spirv::Word>;
+ type VecOperand = (spirv::Word, u8);
}
impl ArgParamsEx for NormalizedArgParams {
@@ -1766,6 +1854,7 @@ impl ast::ArgParams for ExpandedArgParams { type ID = spirv::Word;
type Operand = spirv::Word;
type CallOperand = spirv::Word;
+ type VecOperand = spirv::Word;
}
impl ArgParamsEx for ExpandedArgParams {
@@ -1775,30 +1864,47 @@ impl ArgParamsEx for ExpandedArgParams { }
trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
- fn variable(&mut self, desc: ArgumentDescriptor<T::ID>) -> U::ID;
- fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>) -> U::Operand;
- fn src_call_operand(&mut self, desc: ArgumentDescriptor<T::CallOperand>) -> U::CallOperand;
- fn src_vec_operand(&mut self, desc: ArgumentDescriptor<(T::ID, u8)>) -> (U::ID, u8);
+ fn variable(&mut self, desc: ArgumentDescriptor<T::ID>, typ: Option<ast::Type>) -> U::ID;
+ fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>, typ: ast::Type) -> U::Operand;
+ fn src_call_operand(
+ &mut self,
+ desc: ArgumentDescriptor<T::CallOperand>,
+ typ: ast::Type,
+ ) -> U::CallOperand;
+ fn src_vec_operand(
+ &mut self,
+ desc: ArgumentDescriptor<T::VecOperand>,
+ typ: ast::MovVectorType,
+ ) -> U::VecOperand;
}
impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
where
- T: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word,
+ T: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
{
- fn variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
- self(desc)
+ fn variable(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: Option<ast::Type>,
+ ) -> spirv::Word {
+ self(desc, t)
}
- fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
- self(desc)
+ fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>, t: ast::Type) -> spirv::Word {
+ self(desc, Some(t))
}
- fn src_call_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
- self(desc.new_op(desc.op))
+ fn src_call_operand(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: ast::Type,
+ ) -> spirv::Word {
+ self(desc, Some(t))
}
fn src_vec_operand(
&mut self,
- desc: ArgumentDescriptor<(spirv::Word, u8)>,
- ) -> (spirv::Word, u8) {
- (self(desc.new_op(desc.op.0)), desc.op.1)
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: ast::MovVectorType,
+ ) -> spirv::Word {
+ self(desc, Some(ast::Type::Scalar(t.into())))
}
}
@@ -1806,13 +1912,14 @@ impl<'a, T> ArgumentMapVisitor<ast::ParsedArgParams<'a>, NormalizedArgParams> fo where
T: FnMut(&str) -> spirv::Word,
{
- fn variable(&mut self, desc: ArgumentDescriptor<&str>) -> spirv::Word {
+ fn variable(&mut self, desc: ArgumentDescriptor<&str>, _: Option<ast::Type>) -> spirv::Word {
self(desc.op)
}
fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
+ _: ast::Type,
) -> ast::Operand<spirv::Word> {
match desc.op {
ast::Operand::Reg(id) => ast::Operand::Reg(self(id)),
@@ -1824,6 +1931,7 @@ where fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<&str>>,
+ _: ast::Type,
) -> ast::CallOperand<spirv::Word> {
match desc.op {
ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(id)),
@@ -1831,15 +1939,18 @@ where }
}
- fn src_vec_operand(&mut self, desc: ArgumentDescriptor<(&str, u8)>) -> (spirv::Word, u8) {
+ fn src_vec_operand(
+ &mut self,
+ desc: ArgumentDescriptor<(&str, u8)>,
+ _: ast::MovVectorType,
+ ) -> (spirv::Word, u8) {
(self(desc.op.0), desc.op.1)
}
}
-struct ArgumentDescriptor<T> {
- op: T,
+struct ArgumentDescriptor<Op> {
+ op: Op,
is_dst: bool,
- typ: Option<ast::Type>,
is_pointer: bool,
}
@@ -1848,7 +1959,6 @@ impl<T> ArgumentDescriptor<T> { ArgumentDescriptor {
op: u,
is_dst: self.is_dst,
- typ: self.typ,
is_pointer: self.is_pointer,
}
}
@@ -1860,39 +1970,35 @@ impl<T: ArgParamsEx> ast::Instruction<T> { visitor: &mut V,
) -> ast::Instruction<U> {
match self {
- ast::Instruction::MovVector(_, _) => todo!(),
+ ast::Instruction::MovVector(t, a) => ast::Instruction::MovVector(t, a.map(visitor, t)),
ast::Instruction::Abs(_, _) => todo!(),
+ // Call instruction is converted to a call statement early on
ast::Instruction::Call(_) => unreachable!(),
ast::Instruction::Ld(d, a) => {
let inst_type = d.typ;
let src_is_pointer = d.state_space != ast::LdStateSpace::Param;
- ast::Instruction::Ld(
- d,
- a.map_ld(visitor, Some(ast::Type::Scalar(inst_type)), src_is_pointer),
- )
+ ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, src_is_pointer))
}
ast::Instruction::Mov(mov_type, a) => {
- ast::Instruction::Mov(mov_type, a.map(visitor, Some(mov_type.into())))
+ ast::Instruction::Mov(mov_type, a.map(visitor, mov_type.into()))
}
ast::Instruction::Mul(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Mul(d, a.map_non_shift(visitor, Some(inst_type)))
+ ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type))
}
ast::Instruction::Add(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Add(d, a.map_non_shift(visitor, Some(inst_type)))
+ ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type))
}
ast::Instruction::Setp(d, a) => {
let inst_type = d.typ;
- ast::Instruction::Setp(d, a.map(visitor, Some(ast::Type::Scalar(inst_type))))
+ ast::Instruction::Setp(d, a.map(visitor, ast::Type::Scalar(inst_type)))
}
ast::Instruction::SetpBool(d, a) => {
let inst_type = d.typ;
- ast::Instruction::SetpBool(d, a.map(visitor, Some(ast::Type::Scalar(inst_type))))
- }
- ast::Instruction::Not(t, a) => {
- ast::Instruction::Not(t, a.map(visitor, Some(t.to_type())))
+ ast::Instruction::SetpBool(d, a.map(visitor, ast::Type::Scalar(inst_type)))
}
+ ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, t.to_type())),
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (
@@ -1915,28 +2021,28 @@ impl<T: ArgParamsEx> ast::Instruction<T> { ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t))
}
ast::Instruction::Shl(t, a) => {
- ast::Instruction::Shl(t, a.map_shift(visitor, Some(t.to_type())))
+ ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type()))
}
ast::Instruction::St(d, a) => {
let inst_type = d.typ;
let param_space = d.state_space == ast::StStateSpace::Param;
- ast::Instruction::St(
- d,
- a.map(visitor, Some(ast::Type::Scalar(inst_type)), param_space),
- )
+ ast::Instruction::St(d, a.map(visitor, inst_type, param_space))
}
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
ast::Instruction::Cvta(d, a) => {
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
- ast::Instruction::Cvta(d, a.map(visitor, Some(inst_type)))
+ ast::Instruction::Cvta(d, a.map(visitor, inst_type))
}
}
}
}
impl VisitVariable for ast::Instruction<NormalizedArgParams> {
- fn visit_variable<'a, F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable<
+ 'a,
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> UnadornedStatement {
@@ -1946,29 +2052,37 @@ impl VisitVariable for ast::Instruction<NormalizedArgParams> { impl<T> ArgumentMapVisitor<NormalizedArgParams, NormalizedArgParams> for T
where
- T: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word,
+ T: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
{
- fn variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
- self(desc)
+ fn variable(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: Option<ast::Type>,
+ ) -> spirv::Word {
+ self(desc, t)
}
fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ t: ast::Type,
) -> ast::Operand<spirv::Word> {
match desc.op {
- ast::Operand::Reg(id) => ast::Operand::Reg(self(desc.new_op(id))),
+ ast::Operand::Reg(id) => ast::Operand::Reg(self(desc.new_op(id), Some(t))),
ast::Operand::Imm(imm) => ast::Operand::Imm(imm),
- ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(desc.new_op(id)), imm),
+ ast::Operand::RegOffset(id, imm) => {
+ ast::Operand::RegOffset(self(desc.new_op(id), Some(t)), imm)
+ }
}
}
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
+ t: ast::Type,
) -> ast::CallOperand<spirv::Word> {
match desc.op {
- ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(desc.new_op(id))),
+ ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(desc.new_op(id), Some(t))),
ast::CallOperand::Imm(imm) => ast::CallOperand::Imm(imm),
}
}
@@ -1976,11 +2090,74 @@ where fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
+ t: ast::MovVectorType,
) -> (spirv::Word, u8) {
- (self(desc.new_op(desc.op.0)), desc.op.1)
+ (
+ self(
+ desc.new_op(desc.op.0),
+ Some(ast::Type::Vector(t.into(), desc.op.1)),
+ ),
+ desc.op.1,
+ )
+ }
+}
+
+impl ast::Type {
+ fn to_parts(self) -> TypeParts {
+ match self {
+ ast::Type::Scalar(scalar) => TypeParts {
+ kind: TypeKind::Scalar,
+ scalar_kind: scalar.kind(),
+ width: scalar.width(),
+ components: 0,
+ },
+ ast::Type::Vector(scalar, components) => TypeParts {
+ kind: TypeKind::Vector,
+ scalar_kind: scalar.kind(),
+ width: scalar.width(),
+ components: components as u32,
+ },
+ ast::Type::Array(scalar, components) => TypeParts {
+ kind: TypeKind::Array,
+ scalar_kind: scalar.kind(),
+ width: scalar.width(),
+ components: components,
+ },
+ }
+ }
+
+ fn from_parts(t: TypeParts) -> Self {
+ match t.kind {
+ TypeKind::Scalar => {
+ ast::Type::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind))
+ }
+ TypeKind::Vector => ast::Type::Vector(
+ ast::ScalarType::from_parts(t.width, t.scalar_kind),
+ t.components as u8,
+ ),
+ TypeKind::Array => ast::Type::Array(
+ ast::ScalarType::from_parts(t.width, t.scalar_kind),
+ t.components,
+ ),
+ }
}
}
+#[derive(Eq, PartialEq, Copy, Clone)]
+struct TypeParts {
+ kind: TypeKind,
+ scalar_kind: ScalarKind,
+ width: u8,
+ components: u32,
+}
+
+#[derive(Eq, PartialEq, Copy, Clone)]
+enum TypeKind {
+ Scalar,
+ Vector,
+ Array,
+}
+
impl ast::Instruction<ExpandedArgParams> {
fn jump_target(&self) -> Option<spirv::Word> {
match self {
@@ -2005,7 +2182,9 @@ impl ast::Instruction<ExpandedArgParams> { }
impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
- fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable_extended<
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> ExpandedStatement {
@@ -2016,6 +2195,29 @@ impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> { type Arg2 = ast::Arg2<ExpandedArgParams>;
type Arg2St = ast::Arg2St<ExpandedArgParams>;
+struct CompositeAccess {
+ pub typ: ast::MovVectorType,
+ pub dst: spirv::Word,
+ pub src: spirv::Word,
+ pub index: u32,
+ pub is_write: bool
+}
+
+struct CompositeWrite {
+ pub typ: ast::MovVectorType,
+ pub dst: spirv::Word,
+ pub src_composite: spirv::Word,
+ pub src_scalar: spirv::Word,
+ pub index: u32,
+}
+
+struct CompositeRead {
+ pub typ: ast::MovVectorType,
+ pub dst: spirv::Word,
+ pub src: spirv::Word,
+ pub index: u32,
+}
+
struct ConstantDefinition {
pub dst: spirv::Word,
pub typ: ast::ScalarType,
@@ -2028,6 +2230,7 @@ struct BrachCondition { if_false: spirv::Word,
}
+#[derive(Copy, Clone)]
struct ImplicitConversion {
src: spirv::Word,
dst: spirv::Word,
@@ -2036,7 +2239,7 @@ struct ImplicitConversion { kind: ConversionKind,
}
-#[derive(Debug, PartialEq)]
+#[derive(Debug, PartialEq, Copy, Clone)]
enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
@@ -2084,12 +2287,14 @@ impl<T: ArgParamsEx> ast::Arg1<T> { t: Option<ast::Type>,
) -> ast::Arg1<U> {
ast::Arg1 {
- src: visitor.variable(ArgumentDescriptor {
- op: self.src,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
+ src: visitor.variable(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
}
@@ -2098,43 +2303,51 @@ impl<T: ArgParamsEx> ast::Arg2<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg2<U> {
ast::Arg2 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: t,
- is_dst: true,
- is_pointer: false,
- }),
- src: visitor.operand(ArgumentDescriptor {
- op: self.src,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(t),
+ ),
+ src: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
fn map_ld<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
is_src_pointer: bool,
) -> ast::Arg2<U> {
ast::Arg2 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: t,
- is_dst: true,
- is_pointer: false,
- }),
- src: visitor.operand(ArgumentDescriptor {
- op: self.src,
- typ: t,
- is_dst: false,
- is_pointer: is_src_pointer,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(t),
+ ),
+ src: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_pointer: is_src_pointer,
+ },
+ t,
+ ),
}
}
@@ -2145,18 +2358,22 @@ impl<T: ArgParamsEx> ast::Arg2<T> { src_t: ast::Type,
) -> ast::Arg2<U> {
ast::Arg2 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: Some(dst_t),
- is_dst: true,
- is_pointer: false,
- }),
- src: visitor.operand(ArgumentDescriptor {
- op: self.src,
- typ: Some(src_t),
- is_dst: false,
- is_pointer: false,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(dst_t),
+ ),
+ src: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ src_t,
+ ),
}
}
}
@@ -2165,22 +2382,26 @@ impl<T: ArgParamsEx> ast::Arg2St<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
param_space: bool,
) -> ast::Arg2St<U> {
ast::Arg2St {
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: param_space,
- is_pointer: !param_space,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: param_space,
+ is_pointer: !param_space,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
}
@@ -2189,107 +2410,149 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: ast::MovVectorType,
) -> ast::Arg2Vec<U> {
match self {
ast::Arg2Vec::Dst(dst, src) => ast::Arg2Vec::Dst(
- visitor.src_vec_operand(ArgumentDescriptor {
- op: dst,
- typ: Some(t),
- is_dst: true,
- is_pointer: false,
- }),
- visitor.variable(ArgumentDescriptor {
- op: src,
- typ: Some(t),
- is_dst: false,
- is_pointer: false,
- }),
+ visitor.src_vec_operand(
+ ArgumentDescriptor {
+ op: dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ t,
+ ),
+ visitor.variable(
+ ArgumentDescriptor {
+ op: src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ Some(ast::Type::Scalar(t.into())),
+ ),
),
- ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src (
- visitor.variable(ArgumentDescriptor {
- op: dst,
- typ: Some(t),
- is_dst: true,
- is_pointer: false,
- }),
- visitor.src_vec_operand(ArgumentDescriptor {
- op: src,
- typ: Some(t),
- is_dst: false,
- is_pointer: false,
- }),
+ ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src(
+ visitor.variable(
+ ArgumentDescriptor {
+ op: dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(ast::Type::Scalar(t.into())),
+ ),
+ visitor.src_vec_operand(
+ ArgumentDescriptor {
+ op: src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
),
- ast::Arg2Vec::Both(dst, src) => ast::Arg2Vec::Both (
- visitor.src_vec_operand(ArgumentDescriptor {
- op: dst,
- typ: Some(t),
- is_dst: true,
- is_pointer: false,
- }),
- visitor.src_vec_operand(ArgumentDescriptor {
- op: src,
- typ: Some(t),
- is_dst: false,
- is_pointer: false,
- }),
+ ast::Arg2Vec::Both(dst, src) => ast::Arg2Vec::Both(
+ visitor.src_vec_operand(
+ ArgumentDescriptor {
+ op: dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ t,
+ ),
+ visitor.src_vec_operand(
+ ArgumentDescriptor {
+ op: src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
),
}
}
}
+impl ast::Arg2Vec<ExpandedArgParams> {
+ fn dst(&self) -> spirv::Word {
+ match self {
+ ast::Arg2Vec::Dst(dst, _) | ast::Arg2Vec::Src(dst, _) | ast::Arg2Vec::Both(dst, _) => {
+ *dst
+ }
+ }
+ }
+
+ fn src(&self) -> spirv::Word {
+ match self {
+ ast::Arg2Vec::Dst(_, src) | ast::Arg2Vec::Src(_, src) | ast::Arg2Vec::Both(_, src) => {
+ *src
+ }
+ }
+ }
+}
+
impl<T: ArgParamsEx> ast::Arg3<T> {
fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg3<U> {
ast::Arg3 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: t,
- is_dst: true,
- is_pointer: false,
- }),
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(t),
+ ),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
fn map_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg3<U> {
ast::Arg3 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: t,
- is_dst: true,
- is_pointer: false,
- }),
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: Some(ast::Type::Scalar(ast::ScalarType::U32)),
- is_dst: false,
- is_pointer: false,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(t),
+ ),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ),
}
}
}
@@ -2298,35 +2561,43 @@ impl<T: ArgParamsEx> ast::Arg4<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg4<U> {
ast::Arg4 {
- dst1: visitor.variable(ArgumentDescriptor {
- op: self.dst1,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
- is_dst: true,
- is_pointer: false,
- }),
- dst2: self.dst2.map(|dst2| {
- visitor.variable(ArgumentDescriptor {
- op: dst2,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ dst1: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst1,
is_dst: true,
is_pointer: false,
- })
- }),
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: t,
- is_dst: false,
- is_pointer: false,
+ },
+ Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ ),
+ dst2: self.dst2.map(|dst2| {
+ visitor.variable(
+ ArgumentDescriptor {
+ op: dst2,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ )
}),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
}
@@ -2335,41 +2606,51 @@ impl<T: ArgParamsEx> ast::Arg5<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg5<U> {
ast::Arg5 {
- dst1: visitor.variable(ArgumentDescriptor {
- op: self.dst1,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
- is_dst: true,
- is_pointer: false,
- }),
- dst2: self.dst2.map(|dst2| {
- visitor.variable(ArgumentDescriptor {
- op: dst2,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ dst1: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst1,
is_dst: true,
is_pointer: false,
- })
- }),
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src3: visitor.operand(ArgumentDescriptor {
- op: self.src3,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
- is_dst: false,
- is_pointer: false,
+ },
+ Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ ),
+ dst2: self.dst2.map(|dst2| {
+ visitor.variable(
+ ArgumentDescriptor {
+ op: dst2,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ )
}),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src3: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src3,
+ is_dst: false,
+ is_pointer: false,
+ },
+ ast::Type::Scalar(ast::ScalarType::Pred),
+ ),
}
}
}
@@ -2395,9 +2676,9 @@ impl ast::StStateSpace { }
}
-#[derive(Clone, Copy, PartialEq)]
+#[derive(Clone, Copy, PartialEq, Eq)]
enum ScalarKind {
- Byte,
+ Bit,
Unsigned,
Signed,
Float,
@@ -2438,10 +2719,10 @@ impl ast::ScalarType { ast::ScalarType::S16 => ScalarKind::Signed,
ast::ScalarType::S32 => ScalarKind::Signed,
ast::ScalarType::S64 => ScalarKind::Signed,
- ast::ScalarType::B8 => ScalarKind::Byte,
- ast::ScalarType::B16 => ScalarKind::Byte,
- ast::ScalarType::B32 => ScalarKind::Byte,
- ast::ScalarType::B64 => ScalarKind::Byte,
+ ast::ScalarType::B8 => ScalarKind::Bit,
+ ast::ScalarType::B16 => ScalarKind::Bit,
+ ast::ScalarType::B32 => ScalarKind::Bit,
+ ast::ScalarType::B64 => ScalarKind::Bit,
ast::ScalarType::F16 => ScalarKind::Float,
ast::ScalarType::F32 => ScalarKind::Float,
ast::ScalarType::F64 => ScalarKind::Float,
@@ -2458,7 +2739,7 @@ impl ast::ScalarType { 8 => ast::ScalarType::F64,
_ => unreachable!(),
},
- ScalarKind::Byte => match width {
+ ScalarKind::Bit => match width {
1 => ast::ScalarType::B8,
2 => ast::ScalarType::B16,
4 => ast::ScalarType::B32,
@@ -2574,22 +2855,159 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { return false;
}
match inst.kind() {
- ScalarKind::Byte => operand.kind() != ScalarKind::Byte,
- ScalarKind::Float => operand.kind() == ScalarKind::Byte,
+ ScalarKind::Bit => operand.kind() != ScalarKind::Bit,
+ ScalarKind::Float => operand.kind() == ScalarKind::Bit,
ScalarKind::Signed => {
- operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Unsigned
+ operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned
}
ScalarKind::Unsigned => {
- operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Signed
+ operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed
}
- ScalarKind::Float2 => todo!(),
+ ScalarKind::Float2 => false,
ScalarKind::Pred => false,
}
}
+ (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
+ | (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => {
+ should_bitcast(ast::Type::Scalar(inst), ast::Type::Scalar(operand))
+ }
_ => false,
}
}
+fn insert_with_conversions<T, ToInstruction: FnOnce(T) -> ast::Instruction<ExpandedArgParams>>(
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
+ mut instr: T,
+ pre_conv_src: impl ExactSizeIterator<Item = ImplicitConversion>,
+ pre_conv_dst: impl ExactSizeIterator<Item = ImplicitConversion>,
+ mut post_conv: Vec<ImplicitConversion>,
+ mut src: impl FnMut(&mut T) -> &mut spirv::Word,
+ mut dst: impl FnMut(&mut T) -> &mut spirv::Word,
+ to_inst: ToInstruction,
+) {
+ insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_src, &mut src);
+ insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_dst, &mut dst);
+ if post_conv.len() > 0 {
+ let new_id = id_def.new_id(Some(post_conv[0].from));
+ post_conv[0].src = new_id;
+ post_conv.last_mut().unwrap().dst = *dst(&mut instr);
+ *dst(&mut instr) = new_id;
+ }
+ func.push(Statement::Instruction(to_inst(instr)));
+ for conv in post_conv {
+ func.push(Statement::Conversion(conv));
+ }
+}
+
+fn insert_with_conversions_pre_conv<T>(
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
+ mut instr: &mut T,
+ pre_conv: impl ExactSizeIterator<Item = ImplicitConversion>,
+ src: &mut impl FnMut(&mut T) -> &mut spirv::Word,
+) {
+ let pre_conv_len = pre_conv.len();
+ for (i, mut conv) in pre_conv.enumerate() {
+ let original_src = src(&mut instr);
+ if i == 0 {
+ conv.src = *original_src;
+ }
+ if i == pre_conv_len - 1 {
+ let new_id = id_def.new_id(Some(conv.to));
+ conv.dst = new_id;
+ *original_src = new_id;
+ }
+ func.push(Statement::Conversion(conv));
+ }
+}
+
+fn get_implicit_conversions_ld_dst<
+ ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
+>(
+ id_def: &mut NumericIdResolver,
+ instr_type: ast::Type,
+ dst: spirv::Word,
+ should_convert: ShouldConvert,
+ in_reverse: bool,
+) -> Option<ImplicitConversion> {
+ let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!());
+ if let Some(conv) = should_convert(dst_type, instr_type) {
+ Some(ImplicitConversion {
+ src: u32::max_value(),
+ dst: u32::max_value(),
+ from: if !in_reverse { dst_type } else { instr_type },
+ to: if !in_reverse { instr_type } else { dst_type },
+ kind: conv,
+ })
+ } else {
+ None
+ }
+}
+
+fn get_implicit_conversions_ld_src(
+ id_def: &mut NumericIdResolver,
+ instr_type: ast::Type,
+ state_space: ast::LdStateSpace,
+ src: spirv::Word,
+) -> Vec<ImplicitConversion> {
+ let src_type = id_def.get_type(src).unwrap_or_else(|| todo!());
+ match state_space {
+ ast::LdStateSpace::Param => {
+ if src_type != instr_type {
+ vec![
+ ImplicitConversion {
+ src: u32::max_value(),
+ dst: u32::max_value(),
+ from: src_type,
+ to: instr_type,
+ kind: ConversionKind::Default,
+ };
+ 1
+ ]
+ } else {
+ Vec::new()
+ }
+ }
+ ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
+ let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
+ mem::size_of::<usize>() as u8,
+ ScalarKind::Bit,
+ ));
+ let mut result = Vec::new();
+ // HACK ALERT
+ // IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an
+ // additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier
+ // TODO: error out if the src is not B64/U64/S64
+ if let ast::Type::Scalar(scalar_src_type) = src_type {
+ if scalar_src_type.kind() == ScalarKind::Signed {
+ result.push(ImplicitConversion {
+ src: u32::max_value(),
+ dst: u32::max_value(),
+ from: src_type,
+ to: new_src_type,
+ kind: ConversionKind::Default,
+ });
+ }
+ }
+ result.push(ImplicitConversion {
+ src: u32::max_value(),
+ dst: u32::max_value(),
+ from: src_type,
+ to: instr_type,
+ kind: ConversionKind::Ptr(state_space),
+ });
+ if result.len() == 2 {
+ let new_id = id_def.new_id(Some(new_src_type));
+ result[0].dst = new_id;
+ result[1].src = new_id;
+ result[1].from = new_src_type;
+ }
+ result
+ }
+ _ => todo!(),
+ }
+}
fn insert_implicit_conversions_ld_src(
func: &mut Vec<ExpandedStatement>,
instr_type: ast::Type,
@@ -2608,7 +3026,7 @@ fn insert_implicit_conversions_ld_src( ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
mem::size_of::<usize>() as u8,
- ScalarKind::Byte,
+ ScalarKind::Bit,
));
let new_src = insert_implicit_conversions_ld_src_impl(
func,
@@ -2640,8 +3058,8 @@ fn insert_implicit_conversions_ld_src_impl< should_convert: ShouldConvert,
) -> spirv::Word {
let src_type = id_def.get_type(src);
- if let Some(conv) = should_convert(src_type, instr_type) {
- insert_conversion_src(func, id_def, src, src_type, instr_type, conv)
+ if let Some(conv) = should_convert(src_type.unwrap(), instr_type) {
+ insert_conversion_src(func, id_def, src, src_type.unwrap(), instr_type, conv)
} else {
src
}
@@ -2692,14 +3110,15 @@ fn insert_conversion_src( temp_src
}
+/*
fn insert_with_implicit_conversion_dst<
T,
- ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
+ ShouldConvert: FnOnce(ast::StateSpace, ast::Type, ast::Type) -> Option<ConversionKind>,
Setter: Fn(&mut T) -> &mut spirv::Word,
ToInstruction: FnOnce(T) -> ast::Instruction<ExpandedArgParams>,
>(
func: &mut Vec<ExpandedStatement>,
- instr_type: ast::ScalarType,
+ instr_type: ast::Type,
id_def: &mut NumericIdResolver,
should_convert: ShouldConvert,
mut t: T,
@@ -2708,13 +3127,14 @@ fn insert_with_implicit_conversion_dst< ) {
let dst = setter(&mut t);
let dst_type = id_def.get_type(*dst);
- let dst_coercion = should_convert(dst_type, instr_type)
- .map(|conv| get_conversion_dst(id_def, dst, ast::Type::Scalar(instr_type), dst_type, conv));
+ let dst_coercion = should_convert(dst_type.unwrap(), instr_type)
+ .map(|conv| get_conversion_dst(id_def, dst, instr_type, dst_type.unwrap(), conv));
func.push(Statement::Instruction(to_inst(t)));
if let Some(conv) = dst_coercion {
func.push(conv);
}
}
+*/
#[must_use]
fn get_conversion_dst(
@@ -2739,14 +3159,14 @@ fn get_conversion_dst( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
fn should_convert_relaxed_src(
src_type: ast::Type,
- instr_type: ast::ScalarType,
+ instr_type: ast::Type,
) -> Option<ConversionKind> {
- if src_type == ast::Type::Scalar(instr_type) {
+ if src_type == instr_type {
return None;
}
- match src_type {
- ast::Type::Scalar(src_type) => match instr_type.kind() {
- ScalarKind::Byte => {
+ match (src_type, instr_type) {
+ (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
+ ScalarKind::Bit => {
if instr_type.width() <= src_type.width() {
Some(ConversionKind::Default)
} else {
@@ -2761,7 +3181,7 @@ fn should_convert_relaxed_src( }
}
ScalarKind::Float => {
- if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Byte {
+ if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Bit {
Some(ConversionKind::Default)
} else {
None
@@ -2770,6 +3190,10 @@ fn should_convert_relaxed_src( ScalarKind::Float2 => todo!(),
ScalarKind::Pred => None,
},
+ (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
+ | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
+ should_convert_relaxed_src(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type))
+ }
_ => None,
}
}
@@ -2777,14 +3201,14 @@ fn should_convert_relaxed_src( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
dst_type: ast::Type,
- instr_type: ast::ScalarType,
+ instr_type: ast::Type,
) -> Option<ConversionKind> {
- if dst_type == ast::Type::Scalar(instr_type) {
+ if dst_type == instr_type {
return None;
}
- match dst_type {
- ast::Type::Scalar(dst_type) => match instr_type.kind() {
- ScalarKind::Byte => {
+ match (dst_type, instr_type) {
+ (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
+ ScalarKind::Bit => {
if instr_type.width() <= dst_type.width() {
Some(ConversionKind::Default)
} else {
@@ -2812,7 +3236,7 @@ fn should_convert_relaxed_dst( }
}
ScalarKind::Float => {
- if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Byte {
+ if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Bit {
Some(ConversionKind::Default)
} else {
None
@@ -2821,6 +3245,10 @@ fn should_convert_relaxed_dst( ScalarKind::Float2 => todo!(),
ScalarKind::Pred => None,
},
+ (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
+ | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
+ should_convert_relaxed_dst(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type))
+ }
_ => None,
}
}
@@ -2831,13 +3259,13 @@ fn insert_implicit_bitcasts( stmt: impl VisitVariableExpanded,
) {
let mut dst_coercion = None;
- let instr = stmt.visit_variable_extended(&mut |mut desc| {
- let id_type_from_instr = match desc.typ {
+ let instr = stmt.visit_variable_extended(&mut |mut desc, typ| {
+ let id_type_from_instr = match typ {
Some(t) => t,
None => return desc.op,
};
- let id_actual_type = id_def.get_type(desc.op);
- if should_bitcast(id_type_from_instr, id_def.get_type(desc.op)) {
+ let id_actual_type = id_def.get_type(desc.op).unwrap();
+ if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap()) {
if desc.is_dst {
dst_coercion = Some(get_conversion_dst(
id_def,
@@ -2970,14 +3398,14 @@ mod tests { .collect::<Vec<_>>()
}
- fn assert_conversion_table<F: Fn(ast::Type, ast::ScalarType) -> Option<ConversionKind>>(
+ fn assert_conversion_table<F: Fn(ast::Type, ast::Type) -> Option<ConversionKind>>(
table: &'static str,
f: F,
) {
let conv_table = parse_conversion_table(table);
for (instr_idx, instr_type) in SCALAR_TYPES.iter().enumerate() {
for (op_idx, op_type) in SCALAR_TYPES.iter().enumerate() {
- let conversion = f(ast::Type::Scalar(*op_type), *instr_type);
+ let conversion = f(ast::Type::Scalar(*op_type), ast::Type::Scalar(*instr_type));
if instr_idx == op_idx {
assert_eq!(conversion, None);
} else {
|