From 279e6246ba0ac3fc7b499497514d324c5bce1a78 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 17 Jun 2020 02:53:46 +0200 Subject: Finish implementing implicit conversions --- level_zero-sys/build.rs | 2 + ptx/src/ast.rs | 9 +- ptx/src/lib.rs | 3 +- ptx/src/ptx.lalrpop | 4 +- ptx/src/test/spirv_run/mod.rs | 13 +- ptx/src/translate.rs | 497 ++++++++++++++++++++++++++++++++---------- 6 files changed, 404 insertions(+), 124 deletions(-) diff --git a/level_zero-sys/build.rs b/level_zero-sys/build.rs index 8575c8c..883ded0 100644 --- a/level_zero-sys/build.rs +++ b/level_zero-sys/build.rs @@ -1,5 +1,7 @@ fn main() { println!("cargo:rustc-link-lib=dylib=ze_loader"); + // TODO: make this windows-only + println!("cargo:rustc-link-search=native=C:\\Windows\\System32"); println!("cargo:rerun-if-changed=build.rs"); } \ No newline at end of file diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 5d43a26..bf8ea0d 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -137,7 +137,7 @@ pub enum Instruction { Bra(BraData, Arg1), Cvt(CvtData, Arg2), Shl(ShlData, Arg3), - St(StData, Arg2), + St(StData, Arg2St), Ret(RetData), } @@ -150,6 +150,11 @@ pub struct Arg2 { pub src: Operand, } +pub struct Arg2St { + pub src1: Operand, + pub src2: Operand, +} + pub struct Arg2Mov { pub dst: ID, pub src: MovOperand, @@ -264,7 +269,7 @@ pub struct StData { pub typ: ScalarType, } -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Copy, Clone)] pub enum StStateSpace { Generic, Global, diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 4a61f4e..0b0fd71 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -12,8 +12,7 @@ extern crate rspirv; extern crate spirv_headers as spirv; lalrpop_mod!( - #[allow(dead_code)] - #[allow(unused_imports)] + #[allow(warnings)] ptx ); diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 79290da..22b91af 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -386,7 +386,7 @@ ShlType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // Warning: NVIDIA documentation is incorrect, you can specify scope only once InstSt: ast::Instruction<&'input str> = { - "st" "[" "]" "," => { + "st" "[" "]" "," => { ast::Instruction::St( ast::StData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), @@ -395,7 +395,7 @@ InstSt: ast::Instruction<&'input str> = { vector: v, typ: t }, - ast::Arg2{dst:dst, src:src} + ast::Arg2St { src1:src1, src2:src2 } ) } }; diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 765d67a..bb27431 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -57,7 +57,7 @@ fn test_ptx_assert<'a, T: From + ze::SafeRepr + Debug + Copy + PartialEq>( Ok(()) } -fn run_spirv + ze::SafeRepr + Copy>( +fn run_spirv + ze::SafeRepr + Copy + Debug>( name: &CStr, spirv: &[u32], input: &[T], @@ -84,15 +84,16 @@ fn run_spirv + ze::SafeRepr + Copy>( let event_pool = ze::EventPool::new(&drv, 3, Some(&[&dev]))?; let ev0 = ze::Event::new(&event_pool, 0)?; let ev1 = ze::Event::new(&event_pool, 1)?; + let ev2 = ze::Event::new(&event_pool, 2)?; let mut cmd_list = ze::CommandList::new(&dev)?; - let out_b_ptr: ze::BufferPtrMut = (&mut out_b).into(); + let out_b_ptr_mut: ze::BufferPtrMut = (&mut out_b).into(); cmd_list.append_memory_copy(inp_b_ptr_mut, input, None, Some(&ev0))?; - cmd_list.append_memory_fill(out_b_ptr, 0u8.into(), Some(&ev1))?; + cmd_list.append_memory_fill(out_b_ptr_mut, 0u8.into(), Some(&ev1))?; kernel.set_group_size(1, 1, 1)?; kernel.set_arg_buffer(0, inp_b_ptr_mut)?; - kernel.set_arg_buffer(1, out_b_ptr)?; - cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], None, &[&ev0, &ev1])?; - cmd_list.append_memory_copy(result.as_mut_slice(), inp_b_ptr_mut, None, Some(&ev0))?; + kernel.set_arg_buffer(1, out_b_ptr_mut)?; + cmd_list.append_launch_kernel(&kernel, &[1, 1, 1], Some(&ev2), &[&ev0, &ev1])?; + cmd_list.append_memory_copy(result.as_mut_slice(), out_b_ptr_mut, None, Some(&ev2))?; queue.execute(cmd_list)?; Ok(result) } diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index ad87af8..4444ba7 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -3,7 +3,7 @@ use bit_vec::BitVec; use rspirv::dr; use std::cell::RefCell; use std::collections::{BTreeMap, HashMap, HashSet}; -use std::{borrow::Cow, fmt}; +use std::{borrow::Cow, fmt, mem}; use rspirv::binary::{Assemble, Disassemble}; @@ -86,7 +86,7 @@ pub fn to_spirv(ast: ast::Module) -> Result, dr::Error> { emit_function(&mut builder, &mut map, f)?; } let module = builder.module(); - dbg!(print!("{}", module.disassemble())); + println!("{}", module.disassemble()); Ok(module.assemble()) } @@ -206,8 +206,8 @@ fn collect_var_definitions<'a>( documented special ld/st/cvt conversion rules for destination operands - generic ld: for instruction `ld x, [y]`, y must be of type b64/u64/s64, which is bitcast to a pointer, dereferenced and then documented special - ld/st/cvt conversion rules are applied - - generic ld: for instruction `ld [x], y`, x must be of type b64/u64/s64, + ld/st/cvt conversion rules are applied to dst + - generic st: for instruction `st [x], y`, x must be of type b64/u64/s64, which is bitcast to a pointer */ fn insert_implicit_conversions ast::Type>( @@ -226,41 +226,56 @@ fn insert_implicit_conversions ast::Type>( match s { Statement::Instruction(inst) => match inst { ast::Instruction::Ld(ld, mut arg) => { - let new_arg_src = arg.src.map_id(&mut |arg_src| { + arg.src = arg.src.map_id(&mut |arg_src| { insert_implicit_conversions_ld_src( &mut result, ast::Type::Scalar(ld.typ), type_check, new_id, - |instr, op| ld.state_space.should_convert(instr, op), + ld.state_space, arg_src, ) }); - arg.src = new_arg_src; - insert_implicit_bitcasts( - false, - true, + insert_with_implicit_conversion_dst( &mut result, + ld.typ, type_check, new_id, - ast::Instruction::Ld(ld, arg), + should_convert_relaxed_dst, + arg, + |arg| &mut arg.dst, + |arg| ast::Instruction::Ld(ld, arg), ); } ast::Instruction::St(st, mut arg) => { - let arg_dst_type = type_check(arg.dst); - let new_dst = new_id(); - result.push(Statement::Converison(ImplicitConversion{ - src: arg.dst, - dst: new_dst, - from: arg_dst_type, - to: ast::Type::Scalar(st.typ), - kind: ConversionKind::Ptr - })); - arg.dst = new_dst; - } - inst @ _ => { - insert_implicit_bitcasts(true, true, &mut result, type_check, new_id, inst) + arg.src2 = arg.src2.map_id(&mut |arg_src| { + let arg_src_type = type_check(arg_src); + if let Some(conv) = should_convert_relaxed_src(arg_src_type, st.typ) { + insert_conversion_src( + &mut result, + new_id, + arg_src, + arg_src_type, + ast::Type::Scalar(st.typ), + conv, + ) + } else { + arg_src + } + }); + arg.src1 = arg.src1.map_id(&mut |arg_src| { + insert_implicit_conversions_ld_src( + &mut result, + ast::Type::Scalar(st.typ), + type_check, + new_id, + st.state_space.to_ld_ss(), + arg_src, + ) + }); + result.push(Statement::Instruction(ast::Instruction::St(st, arg))); } + inst @ _ => insert_implicit_bitcasts(&mut result, type_check, new_id, inst), }, s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s), Statement::Converison(_) => unreachable!(), @@ -386,11 +401,15 @@ fn emit_function_body_ops( { todo!() } - let src = match arg.src { + let dst = match arg.src1 { ast::Operand::Reg(id) => id, _ => todo!(), }; - builder.store(arg.dst, src, None, &[])?; + let src = match arg.src2 { + ast::Operand::Reg(id) => id, + _ => todo!(), + }; + builder.store(dst, src, None, &[])?; } // SPIR-V does not support ret as guaranteed-converged ast::Instruction::Ret(_) => builder.ret()?, @@ -417,17 +436,18 @@ fn emit_implicit_conversion( builder, SpirvType::Pointer(to_type, spirv_headers::StorageClass::Generic), ); - builder.bitcast(dst_type, Some(cv.dst), cv.src)?; + builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } ConversionKind::Default => { if from_type.width() == to_type.width() { + let dst_type = map.get_or_add_scalar(builder, to_type); if from_type.kind() == ScalarKind::Unsigned && to_type.kind() == ScalarKind::Byte || from_type.kind() == ScalarKind::Byte && to_type.kind() == ScalarKind::Unsigned { - return Ok(()); + // It is noop, but another instruction expects result of this conversion + builder.copy_object(dst_type, Some(cv.dst), cv.src)?; } - let dst_type = map.get_or_add_scalar(builder, to_type); builder.bitcast(dst_type, Some(cv.dst), cv.src)?; } else { let as_unsigned_type = map.get_or_add_scalar( @@ -1025,8 +1045,10 @@ struct ImplicitConversion { kind: ConversionKind, } +#[derive(Debug, PartialEq)] enum ConversionKind { - Default, // zero-extend/chop/bitcast depending on types + Default, + // zero-extend/chop/bitcast depending on types SignExtend, Ptr, } @@ -1136,10 +1158,7 @@ impl ast::Instruction { ast::Instruction::Not(_, a) => a.visit_id(f), ast::Instruction::Cvt(_, a) => a.visit_id(f), ast::Instruction::Shl(_, a) => a.visit_id(f), - ast::Instruction::St(_, a) => { - f(false, &a.dst); - a.src.visit_id(f); - } + ast::Instruction::St(_, a) => a.visit_id(f), ast::Instruction::Bra(_, a) => a.visit_id(f), ast::Instruction::Ret(_) => (), } @@ -1156,10 +1175,7 @@ impl ast::Instruction { ast::Instruction::Not(_, a) => a.visit_id_mut(f), ast::Instruction::Cvt(_, a) => a.visit_id_mut(f), ast::Instruction::Shl(_, a) => a.visit_id_mut(f), - ast::Instruction::St(_, a) => { - f(false, &mut a.dst); - a.src.visit_id_mut(f); - } + ast::Instruction::St(_, a) => a.visit_id_mut(f), ast::Instruction::Bra(_, a) => a.visit_id_mut(f), ast::Instruction::Ret(_) => (), } @@ -1245,6 +1261,25 @@ impl ast::Arg2 { } } +impl ast::Arg2St { + fn map_id U>(self, f: &mut F) -> ast::Arg2St { + ast::Arg2St { + src1: self.src1.map_id(f), + src2: self.src2.map_id(f), + } + } + + fn visit_id(&self, f: &mut F) { + self.src1.visit_id(f); + self.src2.visit_id(f); + } + + fn visit_id_mut(&mut self, f: &mut F) { + self.src1.visit_id_mut(f); + self.src2.visit_id_mut(f); + } +} + impl ast::Arg2Mov { fn map_id U>(self, f: &mut F) -> ast::Arg2Mov { ast::Arg2Mov { @@ -1388,6 +1423,18 @@ impl ast::MovOperand { } } +impl ast::StStateSpace { + fn to_ld_ss(self) -> ast::LdStateSpace { + match self { + ast::StStateSpace::Generic => ast::LdStateSpace::Generic, + ast::StStateSpace::Global => ast::LdStateSpace::Global, + ast::StStateSpace::Local => ast::LdStateSpace::Local, + ast::StStateSpace::Param => ast::LdStateSpace::Param, + ast::StStateSpace::Shared => ast::LdStateSpace::Shared, + } + } +} + #[derive(Clone, Copy, PartialEq)] enum ScalarKind { Byte, @@ -1491,72 +1538,197 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { } } -impl ast::LdStateSpace { - fn should_convert(self, instr_type: ast::Type, op_type: ast::Type) -> Option { - match self { - ast::LdStateSpace::Param => { - if instr_type != op_type { - Some(ConversionKind::Default) - } else { - None - } - } - ast::LdStateSpace::Generic => Some(ConversionKind::Ptr), - _ => todo!(), +fn insert_implicit_conversions_ld_src< + TypeCheck: Fn(spirv::Word) -> ast::Type, + NewId: FnMut() -> spirv::Word, +>( + func: &mut Vec, + instr_type: ast::Type, + type_check: &TypeCheck, + new_id: &mut NewId, + state_space: ast::LdStateSpace, + src: spirv::Word, +) -> spirv::Word { + match state_space { + ast::LdStateSpace::Param => insert_implicit_conversions_ld_src_impl( + func, + type_check, + new_id, + instr_type, + src, + should_convert_ld_param_src, + ), + ast::LdStateSpace::Generic => { + let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts( + mem::size_of::() as u8, + ScalarKind::Byte, + )); + let new_src = insert_implicit_conversions_ld_src_impl( + func, + type_check, + new_id, + new_src_type, + src, + should_convert_ld_generic_src_to_bitcast, + ); + insert_conversion_src( + func, + new_id, + new_src, + new_src_type, + instr_type, + ConversionKind::Ptr, + ) } + _ => todo!(), } } -fn insert_forced_bitcast_src< +fn insert_implicit_conversions_ld_src_impl< TypeCheck: Fn(spirv::Word) -> ast::Type, NewId: FnMut() -> spirv::Word, + ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option, >( func: &mut Vec, - op_type: ast::Type, type_check: &TypeCheck, new_id: &mut NewId, + instr_type: ast::Type, src: spirv::Word, + should_convert: ShouldConvert, ) -> spirv::Word { let src_type = type_check(src); - if src_type == op_type { - return src; + if let Some(conv) = should_convert(src_type, instr_type) { + insert_conversion_src(func, new_id, src, src_type, instr_type, conv) + } else { + src + } +} + +fn should_convert_ld_param_src( + src_type: ast::Type, + instr_type: ast::Type, +) -> Option { + if src_type != instr_type { + return Some(ConversionKind::Default); + } + None +} + +// HACK ALERT +// IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an +// additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier +fn should_convert_ld_generic_src_to_bitcast( + src_type: ast::Type, + _instr_type: ast::Type, +) -> Option { + if let ast::Type::Scalar(src_type) = src_type { + if src_type.kind() == ScalarKind::Signed { + return Some(ConversionKind::Default); + } } - let new_src = new_id(); + None +} + +#[must_use] +fn insert_conversion_src spirv::Word>( + func: &mut Vec, + new_id: &mut NewId, + src: spirv::Word, + src_type: ast::Type, + instr_type: ast::Type, + conv: ConversionKind, +) -> spirv::Word { + let temp_src = new_id(); func.push(Statement::Converison(ImplicitConversion { src: src, - dst: new_src, + dst: temp_src, from: src_type, - to: op_type, - kind: ConversionKind::Default, + to: instr_type, + kind: conv, })); - new_src + temp_src } -fn insert_implicit_conversions_ld_src< +fn insert_with_implicit_conversion_dst< + T, TypeCheck: Fn(spirv::Word) -> ast::Type, NewId: FnMut() -> spirv::Word, - ShouldConvert: Fn(ast::Type, ast::Type) -> Option, + ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option, + Setter: Fn(&mut T) -> &mut spirv::Word, + ToInstruction: FnOnce(T) -> ast::Instruction, >( func: &mut Vec, - instr_type: ast::Type, + instr_type: ast::ScalarType, type_check: &TypeCheck, new_id: &mut NewId, should_convert: ShouldConvert, - src: spirv::Word, -) -> spirv::Word { - let src_type = type_check(src); - if let Some(conv_kind) = should_convert(src_type, instr_type) { - let new_src = new_id(); - func.push(Statement::Converison(ImplicitConversion { - src: src, - dst: new_src, - from: src_type, - to: instr_type, - kind: conv_kind, - })); - new_src - } else { - src + mut t: T, + setter: Setter, + to_inst: ToInstruction, +) { + let dst = setter(&mut t); + let dst_type = type_check(*dst); + let dst_coercion = should_convert(dst_type, instr_type) + .map(|conv| get_conversion_dst(new_id, dst, ast::Type::Scalar(instr_type), dst_type, conv)); + func.push(Statement::Instruction(to_inst(t))); + if let Some(conv) = dst_coercion { + func.push(conv); + } +} + +#[must_use] +fn get_conversion_dst spirv::Word>( + new_id: &mut NewId, + dst: &mut spirv::Word, + instr_type: ast::Type, + dst_type: ast::Type, + kind: ConversionKind, +) -> Statement { + let original_dst = *dst; + let temp_dst = new_id(); + *dst = temp_dst; + Statement::Converison(ImplicitConversion { + src: temp_dst, + dst: original_dst, + from: instr_type, + to: dst_type, + kind: kind, + }) +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands +fn should_convert_relaxed_src( + src_type: ast::Type, + instr_type: ast::ScalarType, +) -> Option { + if src_type == ast::Type::Scalar(instr_type) { + return None; + } + match src_type { + ast::Type::Scalar(src_type) => match instr_type.kind() { + ScalarKind::Byte => { + if instr_type.width() <= src_type.width() { + Some(ConversionKind::Default) + } else { + None + } + } + ScalarKind::Signed | ScalarKind::Unsigned => { + if instr_type.width() <= src_type.width() && src_type.kind() != ScalarKind::Float { + Some(ConversionKind::Default) + } else { + None + } + } + ScalarKind::Float => { + if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Byte { + Some(ConversionKind::Default) + } else { + None + } + } + }, + _ => None, } } @@ -1578,8 +1750,14 @@ fn should_convert_relaxed_dst( } } ScalarKind::Signed => { - if instr_type.width() <= dst_type.width() && dst_type.kind() != ScalarKind::Float { - Some(ConversionKind::SignExtend) + if dst_type.kind() != ScalarKind::Float { + if instr_type.width() == dst_type.width() { + Some(ConversionKind::Default) + } else if instr_type.width() < dst_type.width() { + Some(ConversionKind::SignExtend) + } else { + None + } } else { None } @@ -1592,7 +1770,7 @@ fn should_convert_relaxed_dst( } } ScalarKind::Float => { - if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Float { + if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Byte { Some(ConversionKind::Default) } else { None @@ -1607,8 +1785,6 @@ fn insert_implicit_bitcasts< TypeCheck: Fn(spirv::Word) -> ast::Type, NewId: FnMut() -> spirv::Word, >( - do_src_bitcast: bool, - do_dst_bitcast: bool, func: &mut Vec, type_check: &TypeCheck, new_id: &mut NewId, @@ -1617,37 +1793,32 @@ fn insert_implicit_bitcasts< let mut dst_coercion = None; if let Some(instr_type) = instr.get_type() { instr.visit_id_mut(&mut |is_dst, id| { - if (is_dst && !do_dst_bitcast) || (!is_dst && !do_src_bitcast) { - return; - } let id_type = type_check(*id); if should_bitcast(instr_type, type_check(*id)) { - let replacement_id = new_id(); if is_dst { - dst_coercion = Some(ImplicitConversion { - src: replacement_id, - dst: *id, - from: instr_type, - to: id_type, - kind: ConversionKind::Default, - }); - *id = replacement_id; + dst_coercion = Some(get_conversion_dst( + new_id, + id, + instr_type, + id_type, + ConversionKind::Default, + )); } else { - func.push(Statement::Converison(ImplicitConversion { - src: *id, - dst: replacement_id, - from: id_type, - to: instr_type, - kind: ConversionKind::Default, - })); - *id = replacement_id; + *id = insert_conversion_src( + func, + new_id, + *id, + id_type, + instr_type, + ConversionKind::Default, + ); } } }); } func.push(Statement::Instruction(instr)); if let Some(cond) = dst_coercion { - func.push(Statement::Converison(cond)); + func.push(cond); } } @@ -1771,7 +1942,7 @@ mod tests { vec![BasicBlock { start: StmtIndex(0), pred: vec![], - succ: vec![] + succ: vec![], }] ); } @@ -1791,7 +1962,7 @@ mod tests { vec![BasicBlock { start: StmtIndex(0), pred: vec![BBIndex(0)], - succ: vec![BBIndex(0)] + succ: vec![BBIndex(0)], }] ); } @@ -2032,37 +2203,37 @@ mod tests { BasicBlock { start: StmtIndex(0), pred: vec![], - succ: vec![BBIndex(1)] + succ: vec![BBIndex(1)], }, BasicBlock { start: StmtIndex(3), pred: vec![BBIndex(0), BBIndex(5)], - succ: vec![BBIndex(2), BBIndex(6)] + succ: vec![BBIndex(2), BBIndex(6)], }, BasicBlock { start: StmtIndex(6), pred: vec![BBIndex(1)], - succ: vec![BBIndex(3), BBIndex(4)] + succ: vec![BBIndex(3), BBIndex(4)], }, BasicBlock { start: StmtIndex(9), pred: vec![BBIndex(2)], - succ: vec![BBIndex(5)] + succ: vec![BBIndex(5)], }, BasicBlock { start: StmtIndex(13), pred: vec![BBIndex(2)], - succ: vec![BBIndex(5)] + succ: vec![BBIndex(5)], }, BasicBlock { start: StmtIndex(16), pred: vec![BBIndex(3), BBIndex(4)], - succ: vec![BBIndex(1)] + succ: vec![BBIndex(1)], }, BasicBlock { start: StmtIndex(18), pred: vec![BBIndex(1)], - succ: vec![] + succ: vec![], }, ] ); @@ -2350,4 +2521,106 @@ mod tests { } panic!() } + + static SCALAR_TYPES: [ast::ScalarType; 15] = [ + ast::ScalarType::B8, + ast::ScalarType::B16, + ast::ScalarType::B32, + ast::ScalarType::B64, + ast::ScalarType::S8, + ast::ScalarType::S16, + ast::ScalarType::S32, + ast::ScalarType::S64, + ast::ScalarType::U8, + ast::ScalarType::U16, + ast::ScalarType::U32, + ast::ScalarType::U64, + ast::ScalarType::F16, + ast::ScalarType::F32, + ast::ScalarType::F64, + ]; + + static RELAXED_SRC_CONVERSION_TABLE: &'static str = + "b8 - chop chop chop - chop chop chop - chop chop chop chop chop chop + b16 inv - chop chop inv - chop chop inv - chop chop - chop chop + b32 inv inv - chop inv inv - chop inv inv - chop inv - chop + b64 inv inv inv - inv inv inv - inv inv inv - inv inv - + s8 - chop chop chop - chop chop chop - chop chop chop inv inv inv + s16 inv - chop chop inv - chop chop inv - chop chop inv inv inv + s32 inv inv - chop inv inv - chop inv inv - chop inv inv inv + s64 inv inv inv - inv inv inv - inv inv inv - inv inv inv + u8 - chop chop chop - chop chop chop - chop chop chop inv inv inv + u16 inv - chop chop inv - chop chop inv - chop chop inv inv inv + u32 inv inv - chop inv inv - chop inv inv - chop inv inv inv + u64 inv inv inv - inv inv inv - inv inv inv - inv inv inv + f16 inv - chop chop inv inv inv inv inv inv inv inv - inv inv + f32 inv inv - chop inv inv inv inv inv inv inv inv inv - inv + f64 inv inv inv - inv inv inv inv inv inv inv inv inv inv -"; + + static RELAXED_DST_CONVERSION_TABLE: &'static str = + "b8 - zext zext zext - zext zext zext - zext zext zext zext zext zext + b16 inv - zext zext inv - zext zext inv - zext zext - zext zext + b32 inv inv - zext inv inv - zext inv inv - zext inv - zext + b64 inv inv inv - inv inv inv - inv inv inv - inv inv - + s8 - sext sext sext - sext sext sext - sext sext sext inv inv inv + s16 inv - sext sext inv - sext sext inv - sext sext inv inv inv + s32 inv inv - sext inv inv - sext inv inv - sext inv inv inv + s64 inv inv inv - inv inv inv - inv inv inv - inv inv inv + u8 - zext zext zext - zext zext zext - zext zext zext inv inv inv + u16 inv - zext zext inv - zext zext inv - zext zext inv inv inv + u32 inv inv - zext inv inv - zext inv inv - zext inv inv inv + u64 inv inv inv - inv inv inv - inv inv inv - inv inv inv + f16 inv - zext zext inv inv inv inv inv inv inv inv - inv inv + f32 inv inv - zext inv inv inv inv inv inv inv inv inv - inv + f64 inv inv inv - inv inv inv inv inv inv inv inv inv inv -"; + + fn table_entry_to_conversion(entry: &'static str) -> Option { + match entry { + "-" => Some(ConversionKind::Default), + "inv" => None, + "zext" => Some(ConversionKind::Default), + "chop" => Some(ConversionKind::Default), + "sext" => Some(ConversionKind::SignExtend), + _ => unreachable!(), + } + } + + fn parse_conversion_table(table: &'static str) -> Vec>> { + table + .lines() + .map(|line| { + line.split_ascii_whitespace() + .skip(1) + .map(table_entry_to_conversion) + .collect::>() + }) + .collect::>() + } + + fn assert_conversion_table Option>( + table: &'static str, + f: F, + ) { + let conv_table = parse_conversion_table(table); + for (instr_idx, instr_type) in SCALAR_TYPES.iter().enumerate() { + for (op_idx, op_type) in SCALAR_TYPES.iter().enumerate() { + let conversion = f(ast::Type::Scalar(*op_type), *instr_type); + if instr_idx == op_idx { + assert_eq!(conversion, None); + } else { + assert_eq!(conversion, conv_table[instr_idx][op_idx]); + } + } + } + } + + #[test] + fn should_convert_relaxed_src_all_combinations() { + assert_conversion_table(RELAXED_SRC_CONVERSION_TABLE, should_convert_relaxed_src); + } + + #[test] + fn should_convert_relaxed_dst_all_combinations() { + assert_conversion_table(RELAXED_DST_CONVERSION_TABLE, should_convert_relaxed_dst); + } } -- cgit v1.2.3