summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-06-17 02:53:46 +0200
committerAndrzej Janik <[email protected]>2020-06-17 02:53:46 +0200
commit279e6246ba0ac3fc7b499497514d324c5bce1a78 (patch)
tree531118161eef27711d718f29045fd7c491678a53
parent4a0edf0e14cb6efb7d7a203adb5d4a0303b45d90 (diff)
downloadZLUDA-279e6246ba0ac3fc7b499497514d324c5bce1a78.tar.gz
ZLUDA-279e6246ba0ac3fc7b499497514d324c5bce1a78.zip
Finish implementing implicit conversions
-rw-r--r--level_zero-sys/build.rs2
-rw-r--r--ptx/src/ast.rs9
-rw-r--r--ptx/src/lib.rs3
-rw-r--r--ptx/src/ptx.lalrpop4
-rw-r--r--ptx/src/test/spirv_run/mod.rs13
-rw-r--r--ptx/src/translate.rs497
6 files changed, 404 insertions, 124 deletions
diff --git a/level_zero-sys/build.rs b/level_zero-sys/build.rs
index 8575c8c..883ded0 100644
--- a/level_zero-sys/build.rs
+++ b/level_zero-sys/build.rs
@@ -1,5 +1,7 @@
fn main() {
println!("cargo:rustc-link-lib=dylib=ze_loader");
+ // TODO: make this windows-only
+ println!("cargo:rustc-link-search=native=C:\\Windows\\System32");
println!("cargo:rerun-if-changed=build.rs");
} \ No newline at end of file
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 5d43a26..bf8ea0d 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -137,7 +137,7 @@ pub enum Instruction<ID> {
Bra(BraData, Arg1<ID>),
Cvt(CvtData, Arg2<ID>),
Shl(ShlData, Arg3<ID>),
- St(StData, Arg2<ID>),
+ St(StData, Arg2St<ID>),
Ret(RetData),
}
@@ -150,6 +150,11 @@ pub struct Arg2<ID> {
pub src: Operand<ID>,
}
+pub struct Arg2St<ID> {
+ pub src1: Operand<ID>,
+ pub src2: Operand<ID>,
+}
+
pub struct Arg2Mov<ID> {
pub dst: ID,
pub src: MovOperand<ID>,
@@ -264,7 +269,7 @@ pub struct StData {
pub typ: ScalarType,
}
-#[derive(PartialEq, Eq)]
+#[derive(PartialEq, Eq, Copy, Clone)]
pub enum StStateSpace {
Generic,
Global,
diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs
index 4a61f4e..0b0fd71 100644
--- a/ptx/src/lib.rs
+++ b/ptx/src/lib.rs
@@ -12,8 +12,7 @@ extern crate rspirv;
extern crate spirv_headers as spirv;
lalrpop_mod!(
- #[allow(dead_code)]
- #[allow(unused_imports)]
+ #[allow(warnings)]
ptx
);
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 79290da..22b91af 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -386,7 +386,7 @@ ShlType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
InstSt: ast::Instruction<&'input str> = {
- "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <dst:ID> "]" "," <src:Operand> => {
+ "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <src1:Operand> "]" "," <src2:Operand> => {
ast::Instruction::St(
ast::StData {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
@@ -395,7 +395,7 @@ InstSt: ast::Instruction<&'input str> = {
vector: v,
typ: t
},
- ast::Arg2{dst:dst, src:src}
+ ast::Arg2St { src1:src1, src2:src2 }
)
}
};
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 765d67a..bb27431 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -57,7 +57,7 @@ fn test_ptx_assert<'a, T: From<u8> + ze::SafeRepr + Debug + Copy + PartialEq>(
Ok(())
}
-fn run_spirv<T: From<u8> + ze::SafeRepr + Copy>(
+fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
name: &CStr,
spirv: &[u32],
input: &[T],
@@ -84,15 +84,16 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy>(
let event_pool = ze::EventPool::new(&drv, 3, Some(&[&dev]))?;
let ev0 = ze::Event::new(&event_pool, 0)?;
let ev1 = ze::Event::new(&event_pool, 1)?;
+ let ev2 = ze::Event::new(&event_pool, 2)?;
let mut cmd_list = ze::CommandList::new(&dev)?;
- let out_b_ptr: ze::BufferPtrMut<T> = (&mut out_b).into();
+ let out_b_ptr_mut: ze::BufferPtrMut<T> = (&mut out_b).into();
cmd_list.append_memory_copy(inp_b_ptr_mut, input, None, Some(&ev0))?;
- cmd_list.append_memory_fill(out_b_ptr, 0u8.into(), Some(&ev1))?;
+ cmd_list.append_memory_fill(out_b_ptr_mut, 0u8.into(), Some(&ev1))?;
kernel.set_group_size(1, 1, 1)?;
kernel.set_arg_buffer(0, inp_b_ptr_mut)?;
- kernel.set_arg_buffer(1, out_b_ptr)?;
- cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], None, &[&ev0, &ev1])?;
- cmd_list.append_memory_copy(result.as_mut_slice(), inp_b_ptr_mut, None, Some(&ev0))?;
+ kernel.set_arg_buffer(1, out_b_ptr_mut)?;
+ cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&ev2), &[&ev0, &ev1])?;
+ cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, Some(&ev2))?;
queue.execute(cmd_list)?;
Ok(result)
}
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index ad87af8..4444ba7 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -3,7 +3,7 @@ use bit_vec::BitVec;
use rspirv::dr;
use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap, HashSet};
-use std::{borrow::Cow, fmt};
+use std::{borrow::Cow, fmt, mem};
use rspirv::binary::{Assemble, Disassemble};
@@ -86,7 +86,7 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, dr::Error> {
emit_function(&mut builder, &mut map, f)?;
}
let module = builder.module();
- dbg!(print!("{}", module.disassemble()));
+ println!("{}", module.disassemble());
Ok(module.assemble())
}
@@ -206,8 +206,8 @@ fn collect_var_definitions<'a>(
documented special ld/st/cvt conversion rules for destination operands
- generic ld: for instruction `ld x, [y]`, y must be of type b64/u64/s64,
which is bitcast to a pointer, dereferenced and then documented special
- ld/st/cvt conversion rules are applied
- - generic ld: for instruction `ld [x], y`, x must be of type b64/u64/s64,
+ ld/st/cvt conversion rules are applied to dst
+ - generic st: for instruction `st [x], y`, x must be of type b64/u64/s64,
which is bitcast to a pointer
*/
fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
@@ -226,41 +226,56 @@ fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
match s {
Statement::Instruction(inst) => match inst {
ast::Instruction::Ld(ld, mut arg) => {
- let new_arg_src = arg.src.map_id(&mut |arg_src| {
+ arg.src = arg.src.map_id(&mut |arg_src| {
insert_implicit_conversions_ld_src(
&mut result,
ast::Type::Scalar(ld.typ),
type_check,
new_id,
- |instr, op| ld.state_space.should_convert(instr, op),
+ ld.state_space,
arg_src,
)
});
- arg.src = new_arg_src;
- insert_implicit_bitcasts(
- false,
- true,
+ insert_with_implicit_conversion_dst(
&mut result,
+ ld.typ,
type_check,
new_id,
- ast::Instruction::Ld(ld, arg),
+ should_convert_relaxed_dst,
+ arg,
+ |arg| &mut arg.dst,
+ |arg| ast::Instruction::Ld(ld, arg),
);
}
ast::Instruction::St(st, mut arg) => {
- let arg_dst_type = type_check(arg.dst);
- let new_dst = new_id();
- result.push(Statement::Converison(ImplicitConversion{
- src: arg.dst,
- dst: new_dst,
- from: arg_dst_type,
- to: ast::Type::Scalar(st.typ),
- kind: ConversionKind::Ptr
- }));
- arg.dst = new_dst;
- }
- inst @ _ => {
- insert_implicit_bitcasts(true, true, &mut result, type_check, new_id, inst)
+ arg.src2 = arg.src2.map_id(&mut |arg_src| {
+ let arg_src_type = type_check(arg_src);
+ if let Some(conv) = should_convert_relaxed_src(arg_src_type, st.typ) {
+ insert_conversion_src(
+ &mut result,
+ new_id,
+ arg_src,
+ arg_src_type,
+ ast::Type::Scalar(st.typ),
+ conv,
+ )
+ } else {
+ arg_src
+ }
+ });
+ arg.src1 = arg.src1.map_id(&mut |arg_src| {
+ insert_implicit_conversions_ld_src(
+ &mut result,
+ ast::Type::Scalar(st.typ),
+ type_check,
+ new_id,
+ st.state_space.to_ld_ss(),
+ arg_src,
+ )
+ });
+ result.push(Statement::Instruction(ast::Instruction::St(st, arg)));
}
+ inst @ _ => insert_implicit_bitcasts(&mut result, type_check, new_id, inst),
},
s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s),
Statement::Converison(_) => unreachable!(),
@@ -386,11 +401,15 @@ fn emit_function_body_ops(
{
todo!()
}
- let src = match arg.src {
+ let dst = match arg.src1 {
ast::Operand::Reg(id) => id,
_ => todo!(),
};
- builder.store(arg.dst, src, None, &[])?;
+ let src = match arg.src2 {
+ ast::Operand::Reg(id) => id,
+ _ => todo!(),
+ };
+ builder.store(dst, src, None, &[])?;
}
// SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?,
@@ -417,17 +436,18 @@ fn emit_implicit_conversion(
builder,
SpirvType::Pointer(to_type, spirv_headers::StorageClass::Generic),
);
- builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
+ builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
ConversionKind::Default => {
if from_type.width() == to_type.width() {
+ let dst_type = map.get_or_add_scalar(builder, to_type);
if from_type.kind() == ScalarKind::Unsigned && to_type.kind() == ScalarKind::Byte
|| from_type.kind() == ScalarKind::Byte
&& to_type.kind() == ScalarKind::Unsigned
{
- return Ok(());
+ // It is noop, but another instruction expects result of this conversion
+ builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
}
- let dst_type = map.get_or_add_scalar(builder, to_type);
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
} else {
let as_unsigned_type = map.get_or_add_scalar(
@@ -1025,8 +1045,10 @@ struct ImplicitConversion {
kind: ConversionKind,
}
+#[derive(Debug, PartialEq)]
enum ConversionKind {
- Default, // zero-extend/chop/bitcast depending on types
+ Default,
+ // zero-extend/chop/bitcast depending on types
SignExtend,
Ptr,
}
@@ -1136,10 +1158,7 @@ impl<T> ast::Instruction<T> {
ast::Instruction::Not(_, a) => a.visit_id(f),
ast::Instruction::Cvt(_, a) => a.visit_id(f),
ast::Instruction::Shl(_, a) => a.visit_id(f),
- ast::Instruction::St(_, a) => {
- f(false, &a.dst);
- a.src.visit_id(f);
- }
+ ast::Instruction::St(_, a) => a.visit_id(f),
ast::Instruction::Bra(_, a) => a.visit_id(f),
ast::Instruction::Ret(_) => (),
}
@@ -1156,10 +1175,7 @@ impl<T> ast::Instruction<T> {
ast::Instruction::Not(_, a) => a.visit_id_mut(f),
ast::Instruction::Cvt(_, a) => a.visit_id_mut(f),
ast::Instruction::Shl(_, a) => a.visit_id_mut(f),
- ast::Instruction::St(_, a) => {
- f(false, &mut a.dst);
- a.src.visit_id_mut(f);
- }
+ ast::Instruction::St(_, a) => a.visit_id_mut(f),
ast::Instruction::Bra(_, a) => a.visit_id_mut(f),
ast::Instruction::Ret(_) => (),
}
@@ -1245,6 +1261,25 @@ impl<T> ast::Arg2<T> {
}
}
+impl<T> ast::Arg2St<T> {
+ fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2St<U> {
+ ast::Arg2St {
+ src1: self.src1.map_id(f),
+ src2: self.src2.map_id(f),
+ }
+ }
+
+ fn visit_id<F: FnMut(bool, &T)>(&self, f: &mut F) {
+ self.src1.visit_id(f);
+ self.src2.visit_id(f);
+ }
+
+ fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
+ self.src1.visit_id_mut(f);
+ self.src2.visit_id_mut(f);
+ }
+}
+
impl<T> ast::Arg2Mov<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2Mov<U> {
ast::Arg2Mov {
@@ -1388,6 +1423,18 @@ impl<T> ast::MovOperand<T> {
}
}
+impl ast::StStateSpace {
+ fn to_ld_ss(self) -> ast::LdStateSpace {
+ match self {
+ ast::StStateSpace::Generic => ast::LdStateSpace::Generic,
+ ast::StStateSpace::Global => ast::LdStateSpace::Global,
+ ast::StStateSpace::Local => ast::LdStateSpace::Local,
+ ast::StStateSpace::Param => ast::LdStateSpace::Param,
+ ast::StStateSpace::Shared => ast::LdStateSpace::Shared,
+ }
+ }
+}
+
#[derive(Clone, Copy, PartialEq)]
enum ScalarKind {
Byte,
@@ -1491,72 +1538,197 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
}
}
-impl ast::LdStateSpace {
- fn should_convert(self, instr_type: ast::Type, op_type: ast::Type) -> Option<ConversionKind> {
- match self {
- ast::LdStateSpace::Param => {
- if instr_type != op_type {
- Some(ConversionKind::Default)
- } else {
- None
- }
- }
- ast::LdStateSpace::Generic => Some(ConversionKind::Ptr),
- _ => todo!(),
+fn insert_implicit_conversions_ld_src<
+ TypeCheck: Fn(spirv::Word) -> ast::Type,
+ NewId: FnMut() -> spirv::Word,
+>(
+ func: &mut Vec<Statement>,
+ instr_type: ast::Type,
+ type_check: &TypeCheck,
+ new_id: &mut NewId,
+ state_space: ast::LdStateSpace,
+ src: spirv::Word,
+) -> spirv::Word {
+ match state_space {
+ ast::LdStateSpace::Param => insert_implicit_conversions_ld_src_impl(
+ func,
+ type_check,
+ new_id,
+ instr_type,
+ src,
+ should_convert_ld_param_src,
+ ),
+ ast::LdStateSpace::Generic => {
+ let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
+ mem::size_of::<usize>() as u8,
+ ScalarKind::Byte,
+ ));
+ let new_src = insert_implicit_conversions_ld_src_impl(
+ func,
+ type_check,
+ new_id,
+ new_src_type,
+ src,
+ should_convert_ld_generic_src_to_bitcast,
+ );
+ insert_conversion_src(
+ func,
+ new_id,
+ new_src,
+ new_src_type,
+ instr_type,
+ ConversionKind::Ptr,
+ )
}
+ _ => todo!(),
}
}
-fn insert_forced_bitcast_src<
+fn insert_implicit_conversions_ld_src_impl<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
+ ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
>(
func: &mut Vec<Statement>,
- op_type: ast::Type,
type_check: &TypeCheck,
new_id: &mut NewId,
+ instr_type: ast::Type,
src: spirv::Word,
+ should_convert: ShouldConvert,
) -> spirv::Word {
let src_type = type_check(src);
- if src_type == op_type {
- return src;
+ if let Some(conv) = should_convert(src_type, instr_type) {
+ insert_conversion_src(func, new_id, src, src_type, instr_type, conv)
+ } else {
+ src
+ }
+}
+
+fn should_convert_ld_param_src(
+ src_type: ast::Type,
+ instr_type: ast::Type,
+) -> Option<ConversionKind> {
+ if src_type != instr_type {
+ return Some(ConversionKind::Default);
+ }
+ None
+}
+
+// HACK ALERT
+// IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an
+// additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier
+fn should_convert_ld_generic_src_to_bitcast(
+ src_type: ast::Type,
+ _instr_type: ast::Type,
+) -> Option<ConversionKind> {
+ if let ast::Type::Scalar(src_type) = src_type {
+ if src_type.kind() == ScalarKind::Signed {
+ return Some(ConversionKind::Default);
+ }
}
- let new_src = new_id();
+ None
+}
+
+#[must_use]
+fn insert_conversion_src<NewId: FnMut() -> spirv::Word>(
+ func: &mut Vec<Statement>,
+ new_id: &mut NewId,
+ src: spirv::Word,
+ src_type: ast::Type,
+ instr_type: ast::Type,
+ conv: ConversionKind,
+) -> spirv::Word {
+ let temp_src = new_id();
func.push(Statement::Converison(ImplicitConversion {
src: src,
- dst: new_src,
+ dst: temp_src,
from: src_type,
- to: op_type,
- kind: ConversionKind::Default,
+ to: instr_type,
+ kind: conv,
}));
- new_src
+ temp_src
}
-fn insert_implicit_conversions_ld_src<
+fn insert_with_implicit_conversion_dst<
+ T,
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
- ShouldConvert: Fn(ast::Type, ast::Type) -> Option<ConversionKind>,
+ ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
+ Setter: Fn(&mut T) -> &mut spirv::Word,
+ ToInstruction: FnOnce(T) -> ast::Instruction<spirv::Word>,
>(
func: &mut Vec<Statement>,
- instr_type: ast::Type,
+ instr_type: ast::ScalarType,
type_check: &TypeCheck,
new_id: &mut NewId,
should_convert: ShouldConvert,
- src: spirv::Word,
-) -> spirv::Word {
- let src_type = type_check(src);
- if let Some(conv_kind) = should_convert(src_type, instr_type) {
- let new_src = new_id();
- func.push(Statement::Converison(ImplicitConversion {
- src: src,
- dst: new_src,
- from: src_type,
- to: instr_type,
- kind: conv_kind,
- }));
- new_src
- } else {
- src
+ mut t: T,
+ setter: Setter,
+ to_inst: ToInstruction,
+) {
+ let dst = setter(&mut t);
+ let dst_type = type_check(*dst);
+ let dst_coercion = should_convert(dst_type, instr_type)
+ .map(|conv| get_conversion_dst(new_id, dst, ast::Type::Scalar(instr_type), dst_type, conv));
+ func.push(Statement::Instruction(to_inst(t)));
+ if let Some(conv) = dst_coercion {
+ func.push(conv);
+ }
+}
+
+#[must_use]
+fn get_conversion_dst<NewId: FnMut() -> spirv::Word>(
+ new_id: &mut NewId,
+ dst: &mut spirv::Word,
+ instr_type: ast::Type,
+ dst_type: ast::Type,
+ kind: ConversionKind,
+) -> Statement {
+ let original_dst = *dst;
+ let temp_dst = new_id();
+ *dst = temp_dst;
+ Statement::Converison(ImplicitConversion {
+ src: temp_dst,
+ dst: original_dst,
+ from: instr_type,
+ to: dst_type,
+ kind: kind,
+ })
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
+fn should_convert_relaxed_src(
+ src_type: ast::Type,
+ instr_type: ast::ScalarType,
+) -> Option<ConversionKind> {
+ if src_type == ast::Type::Scalar(instr_type) {
+ return None;
+ }
+ match src_type {
+ ast::Type::Scalar(src_type) => match instr_type.kind() {
+ ScalarKind::Byte => {
+ if instr_type.width() <= src_type.width() {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ScalarKind::Signed | ScalarKind::Unsigned => {
+ if instr_type.width() <= src_type.width() && src_type.kind() != ScalarKind::Float {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ScalarKind::Float => {
+ if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Byte {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ },
+ _ => None,
}
}
@@ -1578,8 +1750,14 @@ fn should_convert_relaxed_dst(
}
}
ScalarKind::Signed => {
- if instr_type.width() <= dst_type.width() && dst_type.kind() != ScalarKind::Float {
- Some(ConversionKind::SignExtend)
+ if dst_type.kind() != ScalarKind::Float {
+ if instr_type.width() == dst_type.width() {
+ Some(ConversionKind::Default)
+ } else if instr_type.width() < dst_type.width() {
+ Some(ConversionKind::SignExtend)
+ } else {
+ None
+ }
} else {
None
}
@@ -1592,7 +1770,7 @@ fn should_convert_relaxed_dst(
}
}
ScalarKind::Float => {
- if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Float {
+ if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Byte {
Some(ConversionKind::Default)
} else {
None
@@ -1607,8 +1785,6 @@ fn insert_implicit_bitcasts<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
>(
- do_src_bitcast: bool,
- do_dst_bitcast: bool,
func: &mut Vec<Statement>,
type_check: &TypeCheck,
new_id: &mut NewId,
@@ -1617,37 +1793,32 @@ fn insert_implicit_bitcasts<
let mut dst_coercion = None;
if let Some(instr_type) = instr.get_type() {
instr.visit_id_mut(&mut |is_dst, id| {
- if (is_dst && !do_dst_bitcast) || (!is_dst && !do_src_bitcast) {
- return;
- }
let id_type = type_check(*id);
if should_bitcast(instr_type, type_check(*id)) {
- let replacement_id = new_id();
if is_dst {
- dst_coercion = Some(ImplicitConversion {
- src: replacement_id,
- dst: *id,
- from: instr_type,
- to: id_type,
- kind: ConversionKind::Default,
- });
- *id = replacement_id;
+ dst_coercion = Some(get_conversion_dst(
+ new_id,
+ id,
+ instr_type,
+ id_type,
+ ConversionKind::Default,
+ ));
} else {
- func.push(Statement::Converison(ImplicitConversion {
- src: *id,
- dst: replacement_id,
- from: id_type,
- to: instr_type,
- kind: ConversionKind::Default,
- }));
- *id = replacement_id;
+ *id = insert_conversion_src(
+ func,
+ new_id,
+ *id,
+ id_type,
+ instr_type,
+ ConversionKind::Default,
+ );
}
}
});
}
func.push(Statement::Instruction(instr));
if let Some(cond) = dst_coercion {
- func.push(Statement::Converison(cond));
+ func.push(cond);
}
}
@@ -1771,7 +1942,7 @@ mod tests {
vec![BasicBlock {
start: StmtIndex(0),
pred: vec![],
- succ: vec![]
+ succ: vec![],
}]
);
}
@@ -1791,7 +1962,7 @@ mod tests {
vec![BasicBlock {
start: StmtIndex(0),
pred: vec![BBIndex(0)],
- succ: vec![BBIndex(0)]
+ succ: vec![BBIndex(0)],
}]
);
}
@@ -2032,37 +2203,37 @@ mod tests {
BasicBlock {
start: StmtIndex(0),
pred: vec![],
- succ: vec![BBIndex(1)]
+ succ: vec![BBIndex(1)],
},
BasicBlock {
start: StmtIndex(3),
pred: vec![BBIndex(0), BBIndex(5)],
- succ: vec![BBIndex(2), BBIndex(6)]
+ succ: vec![BBIndex(2), BBIndex(6)],
},
BasicBlock {
start: StmtIndex(6),
pred: vec![BBIndex(1)],
- succ: vec![BBIndex(3), BBIndex(4)]
+ succ: vec![BBIndex(3), BBIndex(4)],
},
BasicBlock {
start: StmtIndex(9),
pred: vec![BBIndex(2)],
- succ: vec![BBIndex(5)]
+ succ: vec![BBIndex(5)],
},
BasicBlock {
start: StmtIndex(13),
pred: vec![BBIndex(2)],
- succ: vec![BBIndex(5)]
+ succ: vec![BBIndex(5)],
},
BasicBlock {
start: StmtIndex(16),
pred: vec![BBIndex(3), BBIndex(4)],
- succ: vec![BBIndex(1)]
+ succ: vec![BBIndex(1)],
},
BasicBlock {
start: StmtIndex(18),
pred: vec![BBIndex(1)],
- succ: vec![]
+ succ: vec![],
},
]
);
@@ -2350,4 +2521,106 @@ mod tests {
}
panic!()
}
+
+ static SCALAR_TYPES: [ast::ScalarType; 15] = [
+ ast::ScalarType::B8,
+ ast::ScalarType::B16,
+ ast::ScalarType::B32,
+ ast::ScalarType::B64,
+ ast::ScalarType::S8,
+ ast::ScalarType::S16,
+ ast::ScalarType::S32,
+ ast::ScalarType::S64,
+ ast::ScalarType::U8,
+ ast::ScalarType::U16,
+ ast::ScalarType::U32,
+ ast::ScalarType::U64,
+ ast::ScalarType::F16,
+ ast::ScalarType::F32,
+ ast::ScalarType::F64,
+ ];
+
+ static RELAXED_SRC_CONVERSION_TABLE: &'static str =
+ "b8 - chop chop chop - chop chop chop - chop chop chop chop chop chop
+ b16 inv - chop chop inv - chop chop inv - chop chop - chop chop
+ b32 inv inv - chop inv inv - chop inv inv - chop inv - chop
+ b64 inv inv inv - inv inv inv - inv inv inv - inv inv -
+ s8 - chop chop chop - chop chop chop - chop chop chop inv inv inv
+ s16 inv - chop chop inv - chop chop inv - chop chop inv inv inv
+ s32 inv inv - chop inv inv - chop inv inv - chop inv inv inv
+ s64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
+ u8 - chop chop chop - chop chop chop - chop chop chop inv inv inv
+ u16 inv - chop chop inv - chop chop inv - chop chop inv inv inv
+ u32 inv inv - chop inv inv - chop inv inv - chop inv inv inv
+ u64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
+ f16 inv - chop chop inv inv inv inv inv inv inv inv - inv inv
+ f32 inv inv - chop inv inv inv inv inv inv inv inv inv - inv
+ f64 inv inv inv - inv inv inv inv inv inv inv inv inv inv -";
+
+ static RELAXED_DST_CONVERSION_TABLE: &'static str =
+ "b8 - zext zext zext - zext zext zext - zext zext zext zext zext zext
+ b16 inv - zext zext inv - zext zext inv - zext zext - zext zext
+ b32 inv inv - zext inv inv - zext inv inv - zext inv - zext
+ b64 inv inv inv - inv inv inv - inv inv inv - inv inv -
+ s8 - sext sext sext - sext sext sext - sext sext sext inv inv inv
+ s16 inv - sext sext inv - sext sext inv - sext sext inv inv inv
+ s32 inv inv - sext inv inv - sext inv inv - sext inv inv inv
+ s64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
+ u8 - zext zext zext - zext zext zext - zext zext zext inv inv inv
+ u16 inv - zext zext inv - zext zext inv - zext zext inv inv inv
+ u32 inv inv - zext inv inv - zext inv inv - zext inv inv inv
+ u64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
+ f16 inv - zext zext inv inv inv inv inv inv inv inv - inv inv
+ f32 inv inv - zext inv inv inv inv inv inv inv inv inv - inv
+ f64 inv inv inv - inv inv inv inv inv inv inv inv inv inv -";
+
+ fn table_entry_to_conversion(entry: &'static str) -> Option<ConversionKind> {
+ match entry {
+ "-" => Some(ConversionKind::Default),
+ "inv" => None,
+ "zext" => Some(ConversionKind::Default),
+ "chop" => Some(ConversionKind::Default),
+ "sext" => Some(ConversionKind::SignExtend),
+ _ => unreachable!(),
+ }
+ }
+
+ fn parse_conversion_table(table: &'static str) -> Vec<Vec<Option<ConversionKind>>> {
+ table
+ .lines()
+ .map(|line| {
+ line.split_ascii_whitespace()
+ .skip(1)
+ .map(table_entry_to_conversion)
+ .collect::<Vec<_>>()
+ })
+ .collect::<Vec<_>>()
+ }
+
+ fn assert_conversion_table<F: Fn(ast::Type, ast::ScalarType) -> Option<ConversionKind>>(
+ table: &'static str,
+ f: F,
+ ) {
+ let conv_table = parse_conversion_table(table);
+ for (instr_idx, instr_type) in SCALAR_TYPES.iter().enumerate() {
+ for (op_idx, op_type) in SCALAR_TYPES.iter().enumerate() {
+ let conversion = f(ast::Type::Scalar(*op_type), *instr_type);
+ if instr_idx == op_idx {
+ assert_eq!(conversion, None);
+ } else {
+ assert_eq!(conversion, conv_table[instr_idx][op_idx]);
+ }
+ }
+ }
+ }
+
+ #[test]
+ fn should_convert_relaxed_src_all_combinations() {
+ assert_conversion_table(RELAXED_SRC_CONVERSION_TABLE, should_convert_relaxed_src);
+ }
+
+ #[test]
+ fn should_convert_relaxed_dst_all_combinations() {
+ assert_conversion_table(RELAXED_DST_CONVERSION_TABLE, should_convert_relaxed_dst);
+ }
}