diff options
author | Andrzej Janik <[email protected]> | 2020-05-26 00:33:32 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-05-26 00:33:32 +0200 |
commit | 4a0edf0e14cb6efb7d7a203adb5d4a0303b45d90 (patch) | |
tree | c212b495e55af1793b4d5f6810c5eb2b99c8d091 | |
parent | 9f60990765301af9a359100b94192137f466d351 (diff) | |
download | ZLUDA-4a0edf0e14cb6efb7d7a203adb5d4a0303b45d90.tar.gz ZLUDA-4a0edf0e14cb6efb7d7a203adb5d4a0303b45d90.zip |
Start implementing implicit conversions
-rw-r--r-- | ptx/src/ast.rs | 8 | ||||
-rw-r--r-- | ptx/src/translate.rs | 601 |
2 files changed, 455 insertions, 154 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index ce9a596..5d43a26 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -200,7 +200,7 @@ pub struct LdData { pub typ: ScalarType, } -#[derive(PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq)] pub enum LdStQualifier { Weak, Volatile, @@ -208,14 +208,14 @@ pub enum LdStQualifier { Acquire(LdScope), } -#[derive(PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq)] pub enum LdScope { Cta, Gpu, Sys, } -#[derive(PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq)] pub enum LdStateSpace { Generic, Const, @@ -225,7 +225,7 @@ pub enum LdStateSpace { Shared, } -#[derive(PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq)] pub enum LdCacheOperator { Cached, L2Only, diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a186772..ad87af8 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -5,7 +5,7 @@ use std::cell::RefCell; use std::collections::{BTreeMap, HashMap, HashSet};
use std::{borrow::Cow, fmt};
-use rspirv::binary::Assemble;
+use rspirv::binary::{Assemble, Disassemble};
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum SpirvType {
@@ -86,6 +86,7 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, dr::Error> { emit_function(&mut builder, &mut map, f)?;
}
let module = builder.module();
+ dbg!(print!("{}", module.disassemble()));
Ok(module.assemble())
}
@@ -122,19 +123,44 @@ fn emit_function<'a>( if f.kernel {
builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[]);
}
+ let (mut func_body, bbs, _, unique_ids) = to_ssa(&f.args, f.body);
+ let id_offset = builder.reserve_ids(unique_ids);
+ emit_function_args(builder, id_offset, map, &f.args);
+ apply_id_offset(&mut func_body, id_offset);
+ emit_function_body_ops(builder, map, &func_body, &bbs)?;
+ builder.end_function()?;
+ Ok(func_id)
+}
+
+/// Adds `id_offset` to every id visited by `visit_id_mut` in each statement of `func_body`.
+fn apply_id_offset(func_body: &mut Vec<Statement>, id_offset: u32) {
+    func_body
+        .iter_mut()
+        .for_each(|statement| statement.visit_id_mut(&mut |_, id| *id += id_offset));
+}
+
+/// Runs the builder-independent part of function translation.
+///
+/// Steps, in order: collect ids for arguments and labels, collect variable definitions,
+/// normalize identifiers, insert implicit conversions, build basic blocks, compute
+/// reverse postorder, immediate dominators and dominance frontiers, then call
+/// `ssa_legalize`.
+///
+/// Returns the function body, its basic blocks, the phi data returned by `ssa_legalize`
+/// and the id count that the caller passes to `builder.reserve_ids`.
+fn to_ssa<'a>(
+ f_args: &[ast::Argument],
+ f_body: Vec<ast::Statement<&'a str>>,
+) -> (
+ Vec<Statement>,
+ Vec<BasicBlock>,
+ Vec<Vec<PhiDef>>,
+ spirv::Word,
+) {
let mut contant_ids = HashMap::new();
- collect_arg_ids(&mut contant_ids, &f.args);
- collect_label_ids(&mut contant_ids, &f.body);
- let registers = collect_registers(&f.body);
- let (normalized_ids, unique_ids, type_check) =
- normalize_identifiers(f.body, &contant_ids, registers);
+ let mut type_check = HashMap::new();
+ collect_arg_ids(&mut contant_ids, &mut type_check, &f_args);
+ collect_label_ids(&mut contant_ids, &f_body);
+ let registers = collect_var_definitions(&f_args, &f_body);
+ let (normalized_ids, unique_ids) =
+ normalize_identifiers(f_body, &contant_ids, &mut type_check, registers);
let (mut func_body, unique_ids) =
- insert_implicit_conversion(normalized_ids, unique_ids, &|x| type_check[&x]);
+ insert_implicit_conversions(normalized_ids, unique_ids, &|x| type_check[&x]);
let bbs = get_basic_blocks(&func_body);
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
let dom_fronts = dominance_frontiers(&bbs, &doms);
- let (_, unique_ids) = ssa_legalize(
+ let (phis, unique_ids) = ssa_legalize(
&mut func_body,
contant_ids.len() as u32,
unique_ids,
@@ -142,15 +168,17 @@ fn emit_function<'a>( &doms,
&dom_fronts,
);
- let id_offset = builder.reserve_ids(unique_ids);
- emit_function_args(builder, id_offset, map, &f.args);
- emit_function_body_ops(builder, id_offset, map, &func_body, &bbs)?;
- builder.end_function()?;
- Ok(func_id)
+ (func_body, bbs, phis, unique_ids)
}
-fn collect_registers<'a>(body: &[ast::Statement<&'a str>]) -> HashMap<Cow<'a, str>, ast::Type> {
+fn collect_var_definitions<'a>(
+ args: &[ast::Argument<'a>],
+ body: &[ast::Statement<&'a str>],
+) -> HashMap<Cow<'a, str>, ast::Type> {
let mut result = HashMap::new();
+ for param in args {
+ result.insert(Cow::Borrowed(param.name), ast::Type::Scalar(param.a_type));
+ }
for s in body {
match s {
ast::Statement::Variable(var) => match var.count {
@@ -170,12 +198,19 @@ fn collect_registers<'a>(body: &[ast::Statement<&'a str>]) -> HashMap<Cow<'a, st }
/*
- There are three kinds of implicit conversions in PTX:
+ There are several kinds of implicit conversions in PTX:
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
* special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
- * pointer dereference in st/ld: not documented, but for instruction `ld.<space>.<type> x, [y]` semantics are x = *(<type>*)y
+ - ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
+ semantics are to first zext/chop/bitcast `y` as needed and then do
+ documented special ld/st/cvt conversion rules for destination operands
+ - generic ld: for instruction `ld x, [y]`, y must be of type b64/u64/s64,
+ which is bitcast to a pointer, dereferenced and then documented special
+ ld/st/cvt conversion rules are applied
+ - generic st: for instruction `st [x], y`, x must be of type b64/u64/s64,
+ which is bitcast to a pointer
*/
-fn insert_implicit_conversion<TypeCheck: Fn(spirv::Word) -> ast::Type>(
+fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
normalized_ids: Vec<Statement>,
unique_ids: spirv::Word,
type_check: &TypeCheck,
@@ -190,16 +225,42 @@ fn insert_implicit_conversion<TypeCheck: Fn(spirv::Word) -> ast::Type>( for s in normalized_ids.into_iter() {
match s {
Statement::Instruction(inst) => match inst {
- ast::Instruction::Add(add, arg) => {
- arg.insert_implicit_conversions(
+ ast::Instruction::Ld(ld, mut arg) => {
+ let new_arg_src = arg.src.map_id(&mut |arg_src| {
+ insert_implicit_conversions_ld_src(
+ &mut result,
+ ast::Type::Scalar(ld.typ),
+ type_check,
+ new_id,
+ |instr, op| ld.state_space.should_convert(instr, op),
+ arg_src,
+ )
+ });
+ arg.src = new_arg_src;
+ insert_implicit_bitcasts(
+ false,
+ true,
&mut result,
- ast::Type::Scalar(add.typ),
type_check,
new_id,
- |arg| Statement::Instruction(ast::Instruction::Add(add, arg)),
+ ast::Instruction::Ld(ld, arg),
);
}
- _ => todo!(),
+ ast::Instruction::St(st, mut arg) => {
+ let arg_dst_type = type_check(arg.dst);
+ let new_dst = new_id();
+ result.push(Statement::Converison(ImplicitConversion{
+ src: arg.dst,
+ dst: new_dst,
+ from: arg_dst_type,
+ to: ast::Type::Scalar(st.typ),
+ kind: ConversionKind::Ptr
+ }));
+ arg.dst = new_dst;
+ }
+ inst @ _ => {
+ insert_implicit_bitcasts(true, true, &mut result, type_check, new_id, inst)
+ }
},
s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s),
Statement::Converison(_) => unreachable!(),
@@ -236,10 +297,15 @@ fn emit_function_args( }
}
-fn collect_arg_ids<'a>(result: &mut HashMap<&'a str, spirv::Word>, args: &'a [ast::Argument<'a>]) {
+/// Assigns consecutive ids to the function arguments, starting at the current size of
+/// `result`, and records each argument's scalar type in `type_check` under the same id.
+fn collect_arg_ids<'a>(
+ result: &mut HashMap<&'a str, spirv::Word>,
+ type_check: &mut HashMap<spirv::Word, ast::Type>,
+ args: &'a [ast::Argument<'a>],
+) {
let mut id = result.len() as u32;
for arg in args {
result.insert(arg.name, id);
+ type_check.insert(id, ast::Type::Scalar(arg.a_type));
id += 1;
}
}
@@ -263,7 +329,6 @@ fn collect_label_ids<'a>( fn emit_function_body_ops(
builder: &mut dr::Builder,
- id_offset: spirv::Word,
map: &mut TypeWordMap,
func: &[Statement],
cfg: &[BasicBlock],
@@ -276,56 +341,40 @@ fn emit_function_body_ops( continue;
}
let header_id = if let Statement::Label(id) = body[0] {
- Some(id_offset + id)
+ Some(id)
} else {
None
};
builder.begin_block(header_id)?;
for s in body {
match s {
- // If block startd with a label it has already been emitted,
+ // If block starts with a label it has already been emitted,
// all other labels in the block are unused
Statement::Label(_) => (),
- Statement::Converison(_) => todo!(),
+ Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?,
Statement::Conditional(bra) => {
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
}
Statement::Instruction(inst) => match inst {
// SPIR-V does not support marking jumps as guaranteed-converged
ast::Instruction::Bra(_, arg) => {
- builder.branch(arg.src + id_offset)?;
+ builder.branch(arg.src)?;
}
ast::Instruction::Ld(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
todo!()
}
let src = match arg.src {
- ast::Operand::Reg(id) => id + id_offset,
+ ast::Operand::Reg(id) => id,
_ => todo!(),
};
let result_type = map.get_or_add_scalar(builder, data.typ);
match data.state_space {
ast::LdStateSpace::Generic => {
- // TODO: make the cast optional
- let ptr_result_type = map.get_or_add(
- builder,
- SpirvType::Pointer(
- data.typ,
- spirv::StorageClass::CrossWorkgroup,
- ),
- );
- let bitcast =
- builder.convert_u_to_ptr(ptr_result_type, None, src)?;
- builder.load(
- result_type,
- Some(arg.dst + id_offset),
- bitcast,
- None,
- [],
- )?;
+ builder.load(result_type, Some(arg.dst), src, None, [])?;
}
ast::LdStateSpace::Param => {
- builder.copy_object(result_type, Some(arg.dst + id_offset), src)?;
+ builder.copy_object(result_type, Some(arg.dst), src)?;
}
_ => todo!(),
}
@@ -338,17 +387,10 @@ fn emit_function_body_ops( todo!()
}
let src = match arg.src {
- ast::Operand::Reg(id) => id + id_offset,
+ ast::Operand::Reg(id) => id,
_ => todo!(),
};
- // TODO make cast optional
- let ptr_result_type = map.get_or_add(
- builder,
- SpirvType::Pointer(data.typ, spirv::StorageClass::CrossWorkgroup),
- );
- let bitcast =
- builder.convert_u_to_ptr(ptr_result_type, None, arg.dst + id_offset)?;
- builder.store(bitcast, src, None, &[])?;
+ builder.store(arg.dst, src, None, &[])?;
}
// SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?,
@@ -360,12 +402,76 @@ fn emit_function_body_ops( Ok(())
}
+/// Emits the SPIR-V instruction(s) that implement one `ImplicitConversion`.
+///
+/// Only scalar-to-scalar conversions are handled; any other type pair hits `todo!()`.
+/// * `ConversionKind::Ptr` bitcasts `cv.src` to a pointer to the target scalar type in the
+///   `Generic` storage class.
+/// * `ConversionKind::Default` with equal widths bitcasts the value, except for the
+///   unsigned <-> byte pair, which is copied with `copy_object`. With different widths the
+///   value is bitcast to the unsigned type of the source width, resized with `u_convert`
+///   and, unless the target kind is unsigned or byte, finished by a recursive call.
+/// * `ConversionKind::SignExtend` is not implemented yet.
+///
+/// Every implemented path defines `cv.dst`: the statement that consumes the conversion has
+/// already been rewritten to refer to that id.
+fn emit_implicit_conversion(
+    builder: &mut dr::Builder,
+    map: &mut TypeWordMap,
+    cv: &ImplicitConversion,
+) -> Result<(), dr::Error> {
+    let (from_type, to_type) = match (cv.from, cv.to) {
+        (ast::Type::Scalar(from), ast::Type::Scalar(to)) => (from, to),
+        _ => todo!(),
+    };
+    match cv.kind {
+        ConversionKind::Ptr => {
+            let dst_type = map.get_or_add(
+                builder,
+                SpirvType::Pointer(to_type, spirv_headers::StorageClass::Generic),
+            );
+            builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
+        }
+        ConversionKind::Default => {
+            if from_type.width() == to_type.width() {
+                let dst_type = map.get_or_add_scalar(builder, to_type);
+                let unsigned_byte_pair = from_type.kind() == ScalarKind::Unsigned
+                    && to_type.kind() == ScalarKind::Byte
+                    || from_type.kind() == ScalarKind::Byte
+                        && to_type.kind() == ScalarKind::Unsigned;
+                if unsigned_byte_pair {
+                    // The value is unchanged, but returning early here would leave `cv.dst`
+                    // undefined while a later instruction still reads it.
+                    builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
+                } else {
+                    builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
+                }
+            } else {
+                // NOTE(review): when `from_type` is already unsigned this bitcast targets
+                // the operand's own type - confirm the SPIR-V consumer accepts that.
+                let as_unsigned_type = map.get_or_add_scalar(
+                    builder,
+                    ast::ScalarType::from_parts(from_type.width(), ScalarKind::Unsigned),
+                );
+                let as_unsigned = builder.bitcast(as_unsigned_type, None, cv.src)?;
+                let as_unsigned_wide_type =
+                    ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned);
+                let as_unsigned_wide_spirv = map.get_or_add_scalar(builder, as_unsigned_wide_type);
+                if to_type.kind() == ScalarKind::Unsigned || to_type.kind() == ScalarKind::Byte {
+                    builder.u_convert(as_unsigned_wide_spirv, Some(cv.dst), as_unsigned)?;
+                } else {
+                    // Resize first, then let the equal-width path produce the final type.
+                    let as_unsigned_wide =
+                        builder.u_convert(as_unsigned_wide_spirv, None, as_unsigned)?;
+                    emit_implicit_conversion(
+                        builder,
+                        map,
+                        &ImplicitConversion {
+                            src: as_unsigned_wide,
+                            dst: cv.dst,
+                            from: ast::Type::Scalar(as_unsigned_wide_type),
+                            to: cv.to,
+                            kind: ConversionKind::Default,
+                        },
+                    )?;
+                }
+            }
+        }
+        ConversionKind::SignExtend => todo!(),
+    }
+    Ok(())
+}
+
// TODO: support scopes
fn normalize_identifiers<'a>(
func: Vec<ast::Statement<&'a str>>,
constant_identifiers: &HashMap<&'a str, spirv::Word>, // arguments and labels can't be redefined
+ type_map: &mut HashMap<spirv::Word, ast::Type>,
types: HashMap<Cow<'a, str>, ast::Type>,
-) -> (Vec<Statement>, spirv::Word, HashMap<spirv::Word, ast::Type>) {
+) -> (Vec<Statement>, spirv::Word) {
let mut result = Vec::with_capacity(func.len());
let mut id: u32 = constant_identifiers.len() as u32;
let mut remapped_ids = HashMap::new();
@@ -389,11 +495,12 @@ fn normalize_identifiers<'a>( for s in func {
Statement::from_ast(s, &mut result, &mut get_or_add);
}
- let mut type_map = HashMap::with_capacity(types.len());
- for (old_id, new_id) in remapped_ids {
- type_map.insert(new_id, types[old_id]);
- }
- (result, id, type_map)
+ type_map.extend(
+ remapped_ids
+ .into_iter()
+ .map(|(old_id, new_id)| (new_id, types[old_id])),
+ );
+ (result, id)
}
fn ssa_legalize(
@@ -911,10 +1018,17 @@ impl BrachCondition { }
+/// A conversion statement inserted between an operand and the instruction using it;
+/// lowered to SPIR-V by `emit_implicit_conversion`.
struct ImplicitConversion {
- dst: spirv::Word,
+ // id of the value being converted
src: spirv::Word,
+ // id that receives the converted value
+ dst: spirv::Word,
+ // type of `src`
from: ast::Type,
+ // type of `dst`
to: ast::Type,
+ // how the conversion is carried out
+ kind: ConversionKind,
+}
+
+/// Selects how `emit_implicit_conversion` lowers an `ImplicitConversion`.
+enum ConversionKind {
+ Default, // zero-extend/chop/bitcast depending on types
+ // chosen by `should_convert_relaxed_dst` for signed instruction types;
+ // lowering is still `todo!()`
+ SignExtend,
+ // bitcast of the source value to a pointer to the target scalar type
+ Ptr,
}
impl ImplicitConversion {
@@ -1050,6 +1164,16 @@ impl<T> ast::Instruction<T> { ast::Instruction::Ret(_) => (),
}
}
+
+    /// Returns the instruction's type wrapped as `ast::Type::Scalar`.
+    ///
+    /// `Ret` yields `None`; `Add`, `Ld` and `St` yield their `typ`. Every other
+    /// instruction is still `todo!()`.
+    fn get_type(&self) -> Option<ast::Type> {
+        let scalar = match self {
+            ast::Instruction::Ret(_) => return None,
+            ast::Instruction::Add(add, _) => add.typ,
+            ast::Instruction::Ld(ld, _) => ld.typ,
+            ast::Instruction::St(st, _) => st.typ,
+            _ => todo!(),
+        };
+        Some(ast::Type::Scalar(scalar))
+    }
}
impl<T: Copy> ast::Instruction<T> {
@@ -1162,31 +1286,6 @@ impl<T> ast::Arg3<T> { }
}
-impl ast::Arg3<spirv::Word> {
- fn insert_implicit_conversions<
- TypeCheck: Fn(spirv::Word) -> ast::Type,
- NewId: FnMut() -> spirv::Word,
- NewStatement: FnOnce(Self) -> Statement,
- >(
- self,
- func: &mut Vec<Statement>,
- op_type: ast::Type,
- type_check: &TypeCheck,
- new_id: &mut NewId,
- new_statement: NewStatement,
- ) {
- let src1 = self
- .src1
- .insert_implicit_conversion(func, op_type, type_check, new_id);
- let src2 = self
- .src2
- .insert_implicit_conversion(func, op_type, type_check, new_id);
- insert_implicit_conversion_dst(func, op_type, type_check, new_id, self.dst, |dst| {
- new_statement(Self { dst, src1, src2 })
- });
- }
-}
-
impl<T> ast::Arg4<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg4<U> {
ast::Arg4 {
@@ -1266,37 +1365,6 @@ impl<T> ast::Operand<T> { }
}
-impl ast::Operand<spirv::Word> {
- fn insert_implicit_conversion<
- TypeCheck: Fn(spirv::Word) -> ast::Type,
- NewId: FnMut() -> spirv::Word,
- >(
- self,
- func: &mut Vec<Statement>,
- op_type: ast::Type,
- type_check: &TypeCheck,
- new_id: &mut NewId,
- ) -> Self {
- match self {
- ast::Operand::Reg(src) => {
- if type_check(src) == op_type {
- return self;
- }
- let new_src = new_id();
- func.push(Statement::Converison(ImplicitConversion {
- src: src,
- dst: new_src,
- from: type_check(src),
- to: op_type,
- }));
- ast::Operand::Reg(new_src)
- }
- o @ ast::Operand::Imm(_) => o,
- ast::Operand::RegOffset(_, _) => todo!(),
- }
- }
-}
-
impl<T> ast::MovOperand<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::MovOperand<U> {
match self {
@@ -1320,29 +1388,266 @@ impl<T> ast::MovOperand<T> { }
}
-fn insert_implicit_conversion_dst<
+/// Classification of an `ast::ScalarType` that ignores its width; see
+/// `ast::ScalarType::kind` for the mapping.
+#[derive(Clone, Copy, PartialEq)]
+enum ScalarKind {
+ // the `B8`..`B64` types
+ Byte,
+ // the `U8`..`U64` types
+ Unsigned,
+ // the `S8`..`S64` types
+ Signed,
+ // the `F16`..`F64` types
+ Float,
+}
+
+impl ast::ScalarType {
+    /// Size of the type in bytes (1, 2, 4 or 8).
+    fn width(self) -> u8 {
+        match self {
+            ast::ScalarType::U8 | ast::ScalarType::S8 | ast::ScalarType::B8 => 1,
+            ast::ScalarType::U16
+            | ast::ScalarType::S16
+            | ast::ScalarType::B16
+            | ast::ScalarType::F16 => 2,
+            ast::ScalarType::U32
+            | ast::ScalarType::S32
+            | ast::ScalarType::B32
+            | ast::ScalarType::F32 => 4,
+            ast::ScalarType::U64
+            | ast::ScalarType::S64
+            | ast::ScalarType::B64
+            | ast::ScalarType::F64 => 8,
+        }
+    }
+
+    /// Kind of the type (byte, unsigned, signed or float) regardless of width.
+    fn kind(self) -> ScalarKind {
+        match self {
+            ast::ScalarType::U8
+            | ast::ScalarType::U16
+            | ast::ScalarType::U32
+            | ast::ScalarType::U64 => ScalarKind::Unsigned,
+            ast::ScalarType::S8
+            | ast::ScalarType::S16
+            | ast::ScalarType::S32
+            | ast::ScalarType::S64 => ScalarKind::Signed,
+            ast::ScalarType::B8
+            | ast::ScalarType::B16
+            | ast::ScalarType::B32
+            | ast::ScalarType::B64 => ScalarKind::Byte,
+            ast::ScalarType::F16 | ast::ScalarType::F32 | ast::ScalarType::F64 => {
+                ScalarKind::Float
+            }
+        }
+    }
+
+    /// Inverse of `width`/`kind`: rebuilds the scalar type from a byte width and a kind.
+    ///
+    /// Hits `unreachable!()` for a combination that has no scalar type (for example a
+    /// one-byte float or any width other than 1, 2, 4, 8).
+    fn from_parts(width: u8, kind: ScalarKind) -> Self {
+        match (kind, width) {
+            (ScalarKind::Float, 2) => ast::ScalarType::F16,
+            (ScalarKind::Float, 4) => ast::ScalarType::F32,
+            (ScalarKind::Float, 8) => ast::ScalarType::F64,
+            (ScalarKind::Byte, 1) => ast::ScalarType::B8,
+            (ScalarKind::Byte, 2) => ast::ScalarType::B16,
+            (ScalarKind::Byte, 4) => ast::ScalarType::B32,
+            (ScalarKind::Byte, 8) => ast::ScalarType::B64,
+            (ScalarKind::Signed, 1) => ast::ScalarType::S8,
+            (ScalarKind::Signed, 2) => ast::ScalarType::S16,
+            (ScalarKind::Signed, 4) => ast::ScalarType::S32,
+            (ScalarKind::Signed, 8) => ast::ScalarType::S64,
+            (ScalarKind::Unsigned, 1) => ast::ScalarType::U8,
+            (ScalarKind::Unsigned, 2) => ast::ScalarType::U16,
+            (ScalarKind::Unsigned, 4) => ast::ScalarType::U32,
+            (ScalarKind::Unsigned, 8) => ast::ScalarType::U64,
+            _ => unreachable!(),
+        }
+    }
+}
+
+/// Decides whether an operand of type `operand` needs a bitcast to be used by an
+/// instruction of type `instr`.
+///
+/// Only scalar pairs of equal width qualify. Within those: a byte-typed instruction
+/// accepts any non-byte operand, a float instruction only a byte operand, and signed and
+/// unsigned instructions accept a byte operand or one of the opposite signedness.
+fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
+    let (inst, operand) = match (instr, operand) {
+        (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => (inst, operand),
+        _ => return false,
+    };
+    if inst.width() != operand.width() {
+        return false;
+    }
+    match (inst.kind(), operand.kind()) {
+        (ScalarKind::Byte, ScalarKind::Byte) => false,
+        (ScalarKind::Byte, _) => true,
+        (ScalarKind::Float, ScalarKind::Byte) => true,
+        (ScalarKind::Float, _) => false,
+        (ScalarKind::Signed, ScalarKind::Byte) | (ScalarKind::Signed, ScalarKind::Unsigned) => {
+            true
+        }
+        (ScalarKind::Signed, _) => false,
+        (ScalarKind::Unsigned, ScalarKind::Byte) | (ScalarKind::Unsigned, ScalarKind::Signed) => {
+            true
+        }
+        (ScalarKind::Unsigned, _) => false,
+    }
+}
+
+impl ast::LdStateSpace {
+    /// Picks the conversion a load source needs for this state space.
+    ///
+    /// `Generic` always asks for a pointer conversion; `Param` asks for a default
+    /// conversion only when the two types differ. Other state spaces are `todo!()`.
+    fn should_convert(self, instr_type: ast::Type, op_type: ast::Type) -> Option<ConversionKind> {
+        match self {
+            ast::LdStateSpace::Generic => Some(ConversionKind::Ptr),
+            ast::LdStateSpace::Param if instr_type == op_type => None,
+            ast::LdStateSpace::Param => Some(ConversionKind::Default),
+            _ => todo!(),
+        }
+    }
+}
+
+/// Makes sure a source operand has exactly the type `op_type`.
+///
+/// Returns `src` unchanged when `type_check(src)` already equals `op_type`. Otherwise
+/// takes a fresh id from `new_id`, appends a `ConversionKind::Default` conversion from
+/// `src` to that id onto `func`, and returns the fresh id.
+fn insert_forced_bitcast_src<
TypeCheck: Fn(spirv::Word) -> ast::Type,
NewId: FnMut() -> spirv::Word,
- NewStatement: FnOnce(spirv::Word) -> Statement,
>(
func: &mut Vec<Statement>,
op_type: ast::Type,
type_check: &TypeCheck,
new_id: &mut NewId,
- dst: spirv::Word,
- new_statement: NewStatement,
-) {
- if type_check(dst) == op_type {
- func.push(new_statement(dst));
- } else {
- let new_dst = new_id();
- func.push(new_statement(new_dst));
+ src: spirv::Word,
+) -> spirv::Word {
+ let src_type = type_check(src);
+ if src_type == op_type {
+ return src;
+ }
+ let new_src = new_id();
+ func.push(Statement::Converison(ImplicitConversion {
+ src: src,
+ dst: new_src,
+ from: src_type,
+ to: op_type,
+ kind: ConversionKind::Default,
+ }));
+ new_src
+}
+
+/// Converts the source operand of a load when `should_convert` asks for it.
+///
+/// `should_convert` is called as `(instruction type, operand type)`, the same order as
+/// `ast::LdStateSpace::should_convert`. When it returns a conversion kind, a fresh id is
+/// taken from `new_id`, a conversion from `src` to that id is appended to `func`, and
+/// the fresh id is returned. Otherwise `src` is returned untouched.
+fn insert_implicit_conversions_ld_src<
+ TypeCheck: Fn(spirv::Word) -> ast::Type,
+ NewId: FnMut() -> spirv::Word,
+ ShouldConvert: Fn(ast::Type, ast::Type) -> Option<ConversionKind>,
+>(
+ func: &mut Vec<Statement>,
+ instr_type: ast::Type,
+ type_check: &TypeCheck,
+ new_id: &mut NewId,
+ should_convert: ShouldConvert,
+ src: spirv::Word,
+) -> spirv::Word {
+ let src_type = type_check(src);
+ // instruction type first, operand type second
+ if let Some(conv_kind) = should_convert(instr_type, src_type) {
+ let new_src = new_id();
func.push(Statement::Converison(ImplicitConversion {
- src: new_dst,
- dst: dst,
- from: type_check(new_dst),
- to: op_type,
+ src: src,
+ dst: new_src,
+ from: src_type,
+ to: instr_type,
+ kind: conv_kind,
}));
+ new_src
+ } else {
+ src
+ }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
+/// Picks the conversion for a destination operand of type `dst_type` written by an
+/// instruction of type `instr_type`, or `None` when no conversion is selected.
+///
+/// Nothing is selected when the types are identical, when the destination is not a
+/// scalar, or when the instruction type is wider than the destination. Otherwise a byte
+/// instruction type always gets `Default`; signed gets `SignExtend` and unsigned gets
+/// `Default` unless the destination is a float; a float instruction type gets `Default`
+/// only for a float destination.
+fn should_convert_relaxed_dst(
+    dst_type: ast::Type,
+    instr_type: ast::ScalarType,
+) -> Option<ConversionKind> {
+    if dst_type == ast::Type::Scalar(instr_type) {
+        return None;
+    }
+    let dst_scalar = match dst_type {
+        ast::Type::Scalar(scalar) => scalar,
+        _ => return None,
+    };
+    if instr_type.width() > dst_scalar.width() {
+        return None;
+    }
+    let dst_is_float = dst_scalar.kind() == ScalarKind::Float;
+    match instr_type.kind() {
+        ScalarKind::Byte => Some(ConversionKind::Default),
+        ScalarKind::Signed if !dst_is_float => Some(ConversionKind::SignExtend),
+        ScalarKind::Unsigned if !dst_is_float => Some(ConversionKind::Default),
+        ScalarKind::Float if dst_is_float => Some(ConversionKind::Default),
+        ScalarKind::Signed | ScalarKind::Unsigned | ScalarKind::Float => None,
+    }
+}
+
+/// Appends `instr` to `func`, surrounded by the conversions its operands need.
+///
+/// For every id of the instruction for which `should_bitcast` holds (and whose side is
+/// enabled by `do_src_bitcast` / `do_dst_bitcast`), the id is replaced by a fresh one from
+/// `new_id`. A source conversion is pushed before the instruction; the destination
+/// conversion is pushed after it, converting the instruction's result back to the
+/// destination's own type. Instructions for which `get_type` returns `None` are pushed
+/// unchanged.
+fn insert_implicit_bitcasts<
+    TypeCheck: Fn(spirv::Word) -> ast::Type,
+    NewId: FnMut() -> spirv::Word,
+>(
+    do_src_bitcast: bool,
+    do_dst_bitcast: bool,
+    func: &mut Vec<Statement>,
+    type_check: &TypeCheck,
+    new_id: &mut NewId,
+    mut instr: ast::Instruction<spirv::Word>,
+) {
+    let mut dst_coercion = None;
+    if let Some(instr_type) = instr.get_type() {
+        instr.visit_id_mut(&mut |is_dst, id| {
+            let enabled = if is_dst { do_dst_bitcast } else { do_src_bitcast };
+            if !enabled {
+                return;
+            }
+            // Look the type up once and reuse it for both the check and the conversion.
+            let id_type = type_check(*id);
+            if !should_bitcast(instr_type, id_type) {
+                return;
+            }
+            let original_id = *id;
+            let replacement_id = new_id();
+            *id = replacement_id;
+            if is_dst {
+                // Emitted after the instruction: its result has to exist first.
+                dst_coercion = Some(ImplicitConversion {
+                    src: replacement_id,
+                    dst: original_id,
+                    from: instr_type,
+                    to: id_type,
+                    kind: ConversionKind::Default,
+                });
+            } else {
+                func.push(Statement::Converison(ImplicitConversion {
+                    src: original_id,
+                    dst: replacement_id,
+                    from: id_type,
+                    to: instr_type,
+                    kind: ConversionKind::Default,
+                }));
+            }
+        });
+    }
+    func.push(Statement::Instruction(instr));
+    if let Some(conversion) = dst_coercion {
+        func.push(Statement::Converison(conversion));
+    }
}
@@ -1678,6 +1983,12 @@ mod tests { // page 403
const FIG_19_4: &'static str = "{
+ .reg.u32 i;
+ .reg.u32 j;
+ .reg.u32 k;
+ .reg.pred p;
+ .reg.pred q;
+
mov.u32 i, 1;
mov.u32 j, 1;
mov.u32 k, 0;
@@ -1710,7 +2021,9 @@ mod tests { assert_eq!(errors.len(), 0);
let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &ast);
- let (normalized_ids, _) = normalize_identifiers(ast, &constant_ids);
+ let registers = collect_var_definitions(&[], &ast);
+ let (normalized_ids, _) =
+ normalize_identifiers(ast, &constant_ids, &mut HashMap::new(), registers);
let mut bbs = get_basic_blocks(&normalized_ids);
bbs.iter_mut().for_each(sort_pred_succ);
assert_eq!(
@@ -1857,7 +2170,9 @@ mod tests { let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &fn_ast);
assert_eq!(constant_ids.len(), 4);
- let (normalized_ids, max_id) = normalize_identifiers(fn_ast, &constant_ids);
+ let registers = collect_var_definitions(&[], &fn_ast);
+ let (normalized_ids, max_id) =
+ normalize_identifiers(fn_ast, &constant_ids, &mut HashMap::new(), registers);
let bbs = get_basic_blocks(&normalized_ids);
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
@@ -1895,21 +2210,7 @@ mod tests { .parse(&mut errors, func)
.unwrap();
assert_eq!(errors.len(), 0);
- let mut constant_ids = HashMap::new();
- collect_label_ids(&mut constant_ids, &fn_ast);
- let (mut func, unique_ids) = normalize_identifiers(fn_ast, &constant_ids);
- let bbs = get_basic_blocks(&func);
- let rpostorder = to_reverse_postorder(&bbs);
- let doms = immediate_dominators(&bbs, &rpostorder);
- let dom_fronts = dominance_frontiers(&bbs, &doms);
- let (mut ssa_phis, _) = ssa_legalize(
- &mut func,
- constant_ids.len() as u32,
- unique_ids,
- &bbs,
- &doms,
- &dom_fronts,
- );
+ let (func, _, mut ssa_phis, unique_ids) = to_ssa(&[], fn_ast);
assert_phi_dst_id(unique_ids, &ssa_phis);
assert_dst_unique(&func, &ssa_phis);
sort_phi(&mut ssa_phis);
|