aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/translate.rs
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-09-14 21:45:56 +0200
committerAndrzej Janik <[email protected]>2020-09-14 21:45:56 +0200
commitbb5025c9b17e3fc46e454ca8faab1e85e0361ba8 (patch)
tree07df096e1ad16e8c9464aac17c99194e7257937e /ptx/src/translate.rs
parent48dac435400117935624aed244d1442982c874e2 (diff)
downloadZLUDA-bb5025c9b17e3fc46e454ca8faab1e85e0361ba8.tar.gz
ZLUDA-bb5025c9b17e3fc46e454ca8faab1e85e0361ba8.zip
Refactor implicit conversions and start implementing vector extract/insert
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r--ptx/src/translate.rs1270
1 files changed, 849 insertions, 421 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 7591722..57d3485 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,5 +1,5 @@
use crate::ast;
-use rspirv::dr;
+use rspirv::{binary::Disassemble, dr};
use std::collections::{hash_map, HashMap, HashSet};
use std::{borrow::Cow, iter, mem};
@@ -398,7 +398,8 @@ fn normalize_labels(
labels_in_use.insert(cond.if_true);
labels_in_use.insert(cond.if_false);
}
- Statement::Call(_)
+ Statement::Composite(_)
+ | Statement::Call(_)
| Statement::Variable(_)
| Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
@@ -528,13 +529,13 @@ fn insert_mem_ssa_statements<'a, 'b>(
ast::Instruction::Ret(d) => {
if let Some(out_param) = out_param {
let typ = id_def.get_type(out_param);
- let new_id = id_def.new_id(Some(typ));
+ let new_id = id_def.new_id(typ);
result.push(Statement::LoadVar(
ast::Arg2 {
dst: new_id,
src: out_param,
},
- typ,
+ typ.unwrap(),
));
result.push(Statement::RetValue(d, new_id));
} else {
@@ -561,19 +562,25 @@ fn insert_mem_ssa_statements<'a, 'b>(
| Statement::Conversion(_)
| Statement::RetValue(_, _)
| Statement::Constant(_) => unreachable!(),
+ Statement::Composite(_) => todo!(),
}
}
(f_args, result)
}
trait VisitVariable: Sized {
- fn visit_variable<'a, F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable<
+ 'a,
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> UnadornedStatement;
}
trait VisitVariableExpanded {
- fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable_extended<
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> ExpandedStatement;
@@ -585,8 +592,8 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
stmt: F,
) {
let mut post_statements = Vec::new();
- let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>| {
- let id_type = match (desc.typ, desc.is_pointer) {
+ let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>, _| {
+ let id_type = match (id_def.get_type(desc.op), desc.is_pointer) {
(Some(t), false) => t,
(Some(_), true) => ast::Type::Scalar(ast::ScalarType::B64),
(None, _) => return desc.op,
@@ -624,13 +631,15 @@ fn expand_arguments<'a, 'b>(
match s {
Statement::Call(call) => {
let mut visitor = FlattenArguments::new(&mut result, id_def);
- let new_call = call.map(&mut visitor);
+ let (new_call, post_stmts) = (call.map(&mut visitor), visitor.post_stmts);
result.push(Statement::Call(new_call));
+ result.extend(post_stmts);
}
Statement::Instruction(inst) => {
let mut visitor = FlattenArguments::new(&mut result, id_def);
- let new_inst = inst.map(&mut visitor);
+ let (new_inst, post_stmts) = (inst.map(&mut visitor), visitor.post_stmts);
result.push(Statement::Instruction(new_inst));
+ result.extend(post_stmts);
}
Statement::Variable(ast::Variable {
align,
@@ -646,7 +655,9 @@ fn expand_arguments<'a, 'b>(
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
- Statement::Conversion(_) | Statement::Constant(_) => unreachable!(),
+ Statement::Composite(_) | Statement::Conversion(_) | Statement::Constant(_) => {
+ unreachable!()
+ }
}
}
result
@@ -655,74 +666,79 @@ fn expand_arguments<'a, 'b>(
struct FlattenArguments<'a, 'b> {
func: &'b mut Vec<ExpandedStatement>,
id_def: &'b mut NumericIdResolver<'a>,
+ post_stmts: Vec<ExpandedStatement>,
}
impl<'a, 'b> FlattenArguments<'a, 'b> {
fn new(func: &'b mut Vec<ExpandedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
- FlattenArguments { func, id_def }
+ FlattenArguments {
+ func,
+ id_def,
+ post_stmts: Vec::new(),
+ }
}
}
impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
for FlattenArguments<'a, 'b>
{
- fn variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
+ fn variable(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ typ: Option<ast::Type>,
+ ) -> spirv::Word {
desc.op
}
- fn operand(&mut self, desc: ArgumentDescriptor<ast::Operand<spirv::Word>>) -> spirv::Word {
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ typ: ast::Type,
+ ) -> spirv::Word {
match desc.op {
- ast::Operand::Reg(r) => self.variable(desc.new_op(r)),
+ ast::Operand::Reg(r) => self.variable(desc.new_op(r), Some(typ)),
ast::Operand::Imm(x) => {
- if let Some(typ) = desc.typ {
- let scalar_t = if let ast::Type::Scalar(scalar) = typ {
- scalar
- } else {
- todo!()
- };
- let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id,
- typ: scalar_t,
- value: x,
- }));
- id
+ let scalar_t = if let ast::Type::Scalar(scalar) = typ {
+ scalar
} else {
todo!()
- }
+ };
+ let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id,
+ typ: scalar_t,
+ value: x,
+ }));
+ id
}
ast::Operand::RegOffset(reg, offset) => {
- if let Some(typ) = desc.typ {
- let scalar_t = if let ast::Type::Scalar(scalar) = typ {
- scalar
- } else {
- todo!()
- };
- let id_constant_stmt = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: scalar_t,
- value: offset as i128,
- }));
- let result_id = self.id_def.new_id(desc.typ);
- let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- ast::AddDetails::Int(ast::AddIntDesc {
- typ: int_type,
- saturate: false,
- }),
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
- result_id
+ let scalar_t = if let ast::Type::Scalar(scalar) = typ {
+ scalar
} else {
todo!()
- }
+ };
+ let id_constant_stmt = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: scalar_t,
+ value: offset as i128,
+ }));
+ let result_id = self.id_def.new_id(Some(typ));
+ let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
+ self.func.push(Statement::Instruction(
+ ast::Instruction::<ExpandedArgParams>::Add(
+ ast::AddDetails::Int(ast::AddIntDesc {
+ typ: int_type,
+ saturate: false,
+ }),
+ ast::Arg3 {
+ dst: result_id,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ ),
+ ));
+ result_id
}
}
}
@@ -730,18 +746,45 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
+ typ: ast::Type,
) -> spirv::Word {
match desc.op {
- ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg)),
- ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x))),
+ ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg), Some(typ)),
+ ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x)), typ),
}
}
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
- ) -> (spirv::Word, u8) {
- (self.variable(desc.new_op(desc.op.0)), desc.op.1)
+ typ: ast::MovVectorType,
+ ) -> spirv::Word {
+ let (vector_id, index) = desc.op;
+ let new_id = self.id_def.new_id(Some(ast::Type::Scalar(typ.into())));
+ let composite = if desc.is_dst {
+ Statement::Composite(CompositeAccess {
+ typ: typ,
+ dst: new_id,
+ src: vector_id,
+ index: index as u32,
+ is_write: true
+ })
+ } else {
+ Statement::Composite(CompositeAccess {
+ typ: typ,
+ dst: new_id,
+ src: vector_id,
+ index: index as u32,
+ is_write: false
+ })
+ };
+ if desc.is_dst {
+ self.post_stmts.push(composite);
+ new_id
+ } else {
+ self.func.push(composite);
+ new_id
+ }
}
}
@@ -768,48 +811,63 @@ fn insert_implicit_conversions(
match s {
Statement::Call(call) => insert_implicit_bitcasts(&mut result, id_def, call),
Statement::Instruction(inst) => match inst {
- ast::Instruction::Ld(ld, mut arg) => {
- arg.src = insert_implicit_conversions_ld_src(
- &mut result,
- ast::Type::Scalar(ld.typ),
+ ast::Instruction::Ld(ld, arg) => {
+ let pre_conv =
+ get_implicit_conversions_ld_src(id_def, ld.typ, ld.state_space, arg.src);
+ let post_conv = get_implicit_conversions_ld_dst(
id_def,
- ld.state_space,
- arg.src,
+ ld.typ,
+ arg.dst,
+ should_convert_relaxed_dst,
+ false,
);
- insert_with_implicit_conversion_dst(
+ insert_with_conversions(
&mut result,
- ld.typ,
id_def,
- should_convert_relaxed_dst,
arg,
+ pre_conv.into_iter(),
+ iter::empty(),
+ post_conv.into_iter().collect(),
+ |arg| &mut arg.src,
|arg| &mut arg.dst,
|arg| ast::Instruction::Ld(ld, arg),
- );
+ )
}
- ast::Instruction::St(st, mut arg) => {
- let arg_src2_type = id_def.get_type(arg.src2);
- if let Some(conv) = should_convert_relaxed_src(arg_src2_type, st.typ) {
- arg.src2 = insert_conversion_src(
- &mut result,
- id_def,
- arg.src2,
- arg_src2_type,
- ast::Type::Scalar(st.typ),
- conv,
- );
- }
- arg.src1 = insert_implicit_conversions_ld_src(
- &mut result,
- ast::Type::Scalar(st.typ),
+ ast::Instruction::St(st, arg) => {
+ let pre_conv = get_implicit_conversions_ld_dst(
id_def,
+ st.typ,
+ arg.src2,
+ should_convert_relaxed_src,
+ true,
+ );
+ let post_conv = get_implicit_conversions_ld_src(
+ id_def,
+ st.typ,
st.state_space.to_ld_ss(),
arg.src1,
);
- result.push(Statement::Instruction(ast::Instruction::St(st, arg)));
+ let (pre_conv_dest, post_conv) = if st.state_space == ast::StStateSpace::Param {
+ (Vec::new(), post_conv)
+ } else {
+ (post_conv, Vec::new())
+ };
+ insert_with_conversions(
+ &mut result,
+ id_def,
+ arg,
+ pre_conv.into_iter(),
+ pre_conv_dest.into_iter(),
+ post_conv,
+ |arg| &mut arg.src2,
+ |arg| &mut arg.src1,
+ |arg| ast::Instruction::St(st, arg),
+ )
}
inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst),
},
- s @ Statement::Conditional(_)
+ s @ Statement::Composite(_)
+ | s @ Statement::Conditional(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_)
@@ -950,10 +1008,10 @@ fn emit_function_body_ops(
builder.branch(arg.src)?;
}
ast::Instruction::Ld(data, arg) => {
- if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
+ if data.qualifier != ast::LdStQualifier::Weak {
todo!()
}
- let result_type = map.get_or_add_scalar(builder, data.typ);
+ let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
match data.state_space {
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
@@ -967,7 +1025,6 @@ fn emit_function_body_ops(
}
ast::Instruction::St(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak
- || data.vector.is_some()
|| (data.state_space != ast::StStateSpace::Generic
&& data.state_space != ast::StStateSpace::Param
&& data.state_space != ast::StStateSpace::Global)
@@ -1030,7 +1087,10 @@ fn emit_function_body_ops(
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
ast::Instruction::SetpBool(_, _) => todo!(),
- ast::Instruction::MovVector(_, _) => todo!(),
+ ast::Instruction::MovVector(t, arg) => {
+ let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
+ builder.copy_object(result_type, Some(arg.dst()), arg.src())?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(*typ));
@@ -1042,6 +1102,19 @@ fn emit_function_body_ops(
Statement::RetValue(_, id) => {
builder.ret_value(*id)?;
}
+ Statement::Composite(c) => {
+ let result_type = map.get_or_add_scalar(builder, c.typ.into());
+ let result_id = Some(c.dst);
+ let indexes = [c.index];
+ if c.is_write {
+ let object = c.src;
+ let composite = c.dst;
+ builder.composite_insert(result_type, result_id, object, composite, indexes)?;
+ } else {
+ let composite = c.src;
+ builder.composite_extract(result_type, result_id, composite, indexes)?;
+ }
+ }
}
}
Ok(())
@@ -1188,7 +1261,7 @@ fn emit_setp(
match (setp.cmp_op, setp.typ.kind()) {
(ast::SetpCompareOp::Eq, ScalarKind::Signed)
| (ast::SetpCompareOp::Eq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Eq, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::Eq, ScalarKind::Bit) => {
builder.i_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Eq, ScalarKind::Float) => {
@@ -1196,14 +1269,14 @@ fn emit_setp(
}
(ast::SetpCompareOp::NotEq, ScalarKind::Signed)
| (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::NotEq, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => {
builder.i_not_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NotEq, ScalarKind::Float) => {
builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Less, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Less, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::Less, ScalarKind::Bit) => {
builder.u_less_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Less, ScalarKind::Signed) => {
@@ -1213,7 +1286,7 @@ fn emit_setp(
builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::LessOrEq, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => {
builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => {
@@ -1223,7 +1296,7 @@ fn emit_setp(
builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Greater, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Greater, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::Greater, ScalarKind::Bit) => {
builder.u_greater_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::Greater, ScalarKind::Signed) => {
@@ -1233,7 +1306,7 @@ fn emit_setp(
builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Byte) => {
+ | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => {
builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => {
@@ -1294,54 +1367,56 @@ fn emit_implicit_conversion(
map: &mut TypeWordMap,
cv: &ImplicitConversion,
) -> Result<(), dr::Error> {
- let (from_type, to_type) = match (cv.from, cv.to) {
- (ast::Type::Scalar(from), ast::Type::Scalar(to)) => (from, to),
- _ => todo!(),
- };
+ let from_parts = cv.from.to_parts();
+ let to_parts = cv.to.to_parts();
match cv.kind {
ConversionKind::Ptr(space) => {
let dst_type = map.get_or_add(
builder,
- SpirvType::Pointer(
- Box::new(SpirvType::Base(SpirvScalarKey::from(to_type))),
- space.to_spirv(),
- ),
+ SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()),
);
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
ConversionKind::Default => {
- if from_type.width() == to_type.width() {
- let dst_type = map.get_or_add_scalar(builder, to_type);
- if from_type.kind() != ScalarKind::Float && to_type.kind() != ScalarKind::Float {
+ if from_parts.width == to_parts.width {
+ let dst_type = map.get_or_add(builder, SpirvType::from(cv.from));
+ if from_parts.scalar_kind != ScalarKind::Float
+ && to_parts.scalar_kind != ScalarKind::Float
+ {
// It is noop, but another instruction expects result of this conversion
builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
} else {
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
}
} else {
- let as_unsigned_type = map.get_or_add_scalar(
- builder,
- ast::ScalarType::from_parts(from_type.width(), ScalarKind::Unsigned),
- );
- let as_unsigned = builder.bitcast(as_unsigned_type, None, cv.src)?;
- let as_unsigned_wide_type =
- ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned);
- let as_unsigned_wide_spirv = map.get_or_add_scalar(
+ // This block is safe because it's illegal to implictly convert between floating point instructions
+ let same_width_bit_type = map.get_or_add(
builder,
- ast::ScalarType::from_parts(to_type.width(), ScalarKind::Unsigned),
+ SpirvType::from(ast::Type::from_parts(TypeParts {
+ scalar_kind: ScalarKind::Bit,
+ ..from_parts
+ })),
);
- if to_type.kind() == ScalarKind::Unsigned || to_type.kind() == ScalarKind::Byte {
- builder.u_convert(as_unsigned_wide_spirv, Some(cv.dst), as_unsigned)?;
+ let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?;
+ let wide_bit_type = ast::Type::from_parts(TypeParts {
+ scalar_kind: ScalarKind::Bit,
+ ..to_parts
+ });
+ let wide_bit_type_spirv = map.get_or_add(builder, SpirvType::from(wide_bit_type));
+ if to_parts.scalar_kind == ScalarKind::Unsigned
+ || to_parts.scalar_kind == ScalarKind::Bit
+ {
+ builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?;
} else {
- let as_unsigned_wide =
- builder.u_convert(as_unsigned_wide_spirv, None, as_unsigned)?;
+ let wide_bit_value =
+ builder.u_convert(wide_bit_type_spirv, None, same_width_bit_value)?;
emit_implicit_conversion(
builder,
map,
&ImplicitConversion {
- src: as_unsigned_wide,
+ src: wide_bit_value,
dst: cv.dst,
- from: ast::Type::Scalar(as_unsigned_wide_type),
+ from: wide_bit_type,
to: cv.to,
kind: ConversionKind::Default,
},
@@ -1627,8 +1702,8 @@ struct NumericIdResolver<'b> {
}
impl<'b> NumericIdResolver<'b> {
- fn get_type(&self, id: spirv::Word) -> ast::Type {
- self.type_check[&id]
+ fn get_type(&self, id: spirv::Word) -> Option<ast::Type> {
+ self.type_check.get(&id).map(|x| *x)
}
fn new_id(&mut self, typ: Option<ast::Type>) -> spirv::Word {
@@ -1648,6 +1723,7 @@ enum Statement<I, P: ast::ArgParams> {
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
Call(ResolvedCall<P>),
+ Composite(CompositeAccess),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
Conversion(ImplicitConversion),
@@ -1671,31 +1747,37 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
.ret_params
.into_iter()
.map(|(id, typ)| {
- let new_id = visitor.variable(ArgumentDescriptor {
- op: id,
- typ: Some(typ.into()),
- is_dst: true,
- is_pointer: false,
- });
+ let new_id = visitor.variable(
+ ArgumentDescriptor {
+ op: id,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(typ.into()),
+ );
(new_id, typ)
})
.collect();
- let func = visitor.variable(ArgumentDescriptor {
- op: self.func,
- typ: None,
- is_dst: false,
- is_pointer: false,
- });
+ let func = visitor.variable(
+ ArgumentDescriptor {
+ op: self.func,
+ is_dst: false,
+ is_pointer: false,
+ },
+ None,
+ );
let param_list = self
.param_list
.into_iter()
.map(|(id, typ)| {
- let new_id = visitor.src_call_operand(ArgumentDescriptor {
- op: id,
- typ: Some(typ.into()),
- is_dst: false,
- is_pointer: false,
- });
+ let new_id = visitor.src_call_operand(
+ ArgumentDescriptor {
+ op: id,
+ is_dst: false,
+ is_pointer: false,
+ },
+ typ.into(),
+ );
(new_id, typ)
})
.collect();
@@ -1709,7 +1791,10 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
}
impl VisitVariable for ResolvedCall<NormalizedArgParams> {
- fn visit_variable<'a, F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable<
+ 'a,
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> UnadornedStatement {
@@ -1718,7 +1803,9 @@ impl VisitVariable for ResolvedCall<NormalizedArgParams> {
}
impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
- fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable_extended<
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> ExpandedStatement {
@@ -1750,6 +1837,7 @@ impl ast::ArgParams for NormalizedArgParams {
type ID = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
type CallOperand = ast::CallOperand<spirv::Word>;
+ type VecOperand = (spirv::Word, u8);
}
impl ArgParamsEx for NormalizedArgParams {
@@ -1766,6 +1854,7 @@ impl ast::ArgParams for ExpandedArgParams {
type ID = spirv::Word;
type Operand = spirv::Word;
type CallOperand = spirv::Word;
+ type VecOperand = spirv::Word;
}
impl ArgParamsEx for ExpandedArgParams {
@@ -1775,30 +1864,47 @@ impl ArgParamsEx for ExpandedArgParams {
}
trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
- fn variable(&mut self, desc: ArgumentDescriptor<T::ID>) -> U::ID;
- fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>) -> U::Operand;
- fn src_call_operand(&mut self, desc: ArgumentDescriptor<T::CallOperand>) -> U::CallOperand;
- fn src_vec_operand(&mut self, desc: ArgumentDescriptor<(T::ID, u8)>) -> (U::ID, u8);
+ fn variable(&mut self, desc: ArgumentDescriptor<T::ID>, typ: Option<ast::Type>) -> U::ID;
+ fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>, typ: ast::Type) -> U::Operand;
+ fn src_call_operand(
+ &mut self,
+ desc: ArgumentDescriptor<T::CallOperand>,
+ typ: ast::Type,
+ ) -> U::CallOperand;
+ fn src_vec_operand(
+ &mut self,
+ desc: ArgumentDescriptor<T::VecOperand>,
+ typ: ast::MovVectorType,
+ ) -> U::VecOperand;
}
impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
where
- T: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word,
+ T: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
{
- fn variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
- self(desc)
+ fn variable(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: Option<ast::Type>,
+ ) -> spirv::Word {
+ self(desc, t)
}
- fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
- self(desc)
+ fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>, t: ast::Type) -> spirv::Word {
+ self(desc, Some(t))
}
- fn src_call_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
- self(desc.new_op(desc.op))
+ fn src_call_operand(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: ast::Type,
+ ) -> spirv::Word {
+ self(desc, Some(t))
}
fn src_vec_operand(
&mut self,
- desc: ArgumentDescriptor<(spirv::Word, u8)>,
- ) -> (spirv::Word, u8) {
- (self(desc.new_op(desc.op.0)), desc.op.1)
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: ast::MovVectorType,
+ ) -> spirv::Word {
+ self(desc, Some(ast::Type::Scalar(t.into())))
}
}
@@ -1806,13 +1912,14 @@ impl<'a, T> ArgumentMapVisitor<ast::ParsedArgParams<'a>, NormalizedArgParams> fo
where
T: FnMut(&str) -> spirv::Word,
{
- fn variable(&mut self, desc: ArgumentDescriptor<&str>) -> spirv::Word {
+ fn variable(&mut self, desc: ArgumentDescriptor<&str>, _: Option<ast::Type>) -> spirv::Word {
self(desc.op)
}
fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
+ _: ast::Type,
) -> ast::Operand<spirv::Word> {
match desc.op {
ast::Operand::Reg(id) => ast::Operand::Reg(self(id)),
@@ -1824,6 +1931,7 @@ where
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<&str>>,
+ _: ast::Type,
) -> ast::CallOperand<spirv::Word> {
match desc.op {
ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(id)),
@@ -1831,15 +1939,18 @@ where
}
}
- fn src_vec_operand(&mut self, desc: ArgumentDescriptor<(&str, u8)>) -> (spirv::Word, u8) {
+ fn src_vec_operand(
+ &mut self,
+ desc: ArgumentDescriptor<(&str, u8)>,
+ _: ast::MovVectorType,
+ ) -> (spirv::Word, u8) {
(self(desc.op.0), desc.op.1)
}
}
-struct ArgumentDescriptor<T> {
- op: T,
+struct ArgumentDescriptor<Op> {
+ op: Op,
is_dst: bool,
- typ: Option<ast::Type>,
is_pointer: bool,
}
@@ -1848,7 +1959,6 @@ impl<T> ArgumentDescriptor<T> {
ArgumentDescriptor {
op: u,
is_dst: self.is_dst,
- typ: self.typ,
is_pointer: self.is_pointer,
}
}
@@ -1860,39 +1970,35 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
visitor: &mut V,
) -> ast::Instruction<U> {
match self {
- ast::Instruction::MovVector(_, _) => todo!(),
+ ast::Instruction::MovVector(t, a) => ast::Instruction::MovVector(t, a.map(visitor, t)),
ast::Instruction::Abs(_, _) => todo!(),
+ // Call instruction is converted to a call statement early on
ast::Instruction::Call(_) => unreachable!(),
ast::Instruction::Ld(d, a) => {
let inst_type = d.typ;
let src_is_pointer = d.state_space != ast::LdStateSpace::Param;
- ast::Instruction::Ld(
- d,
- a.map_ld(visitor, Some(ast::Type::Scalar(inst_type)), src_is_pointer),
- )
+ ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, src_is_pointer))
}
ast::Instruction::Mov(mov_type, a) => {
- ast::Instruction::Mov(mov_type, a.map(visitor, Some(mov_type.into())))
+ ast::Instruction::Mov(mov_type, a.map(visitor, mov_type.into()))
}
ast::Instruction::Mul(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Mul(d, a.map_non_shift(visitor, Some(inst_type)))
+ ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type))
}
ast::Instruction::Add(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Add(d, a.map_non_shift(visitor, Some(inst_type)))
+ ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type))
}
ast::Instruction::Setp(d, a) => {
let inst_type = d.typ;
- ast::Instruction::Setp(d, a.map(visitor, Some(ast::Type::Scalar(inst_type))))
+ ast::Instruction::Setp(d, a.map(visitor, ast::Type::Scalar(inst_type)))
}
ast::Instruction::SetpBool(d, a) => {
let inst_type = d.typ;
- ast::Instruction::SetpBool(d, a.map(visitor, Some(ast::Type::Scalar(inst_type))))
- }
- ast::Instruction::Not(t, a) => {
- ast::Instruction::Not(t, a.map(visitor, Some(t.to_type())))
+ ast::Instruction::SetpBool(d, a.map(visitor, ast::Type::Scalar(inst_type)))
}
+ ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, t.to_type())),
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (
@@ -1915,28 +2021,28 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t))
}
ast::Instruction::Shl(t, a) => {
- ast::Instruction::Shl(t, a.map_shift(visitor, Some(t.to_type())))
+ ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type()))
}
ast::Instruction::St(d, a) => {
let inst_type = d.typ;
let param_space = d.state_space == ast::StStateSpace::Param;
- ast::Instruction::St(
- d,
- a.map(visitor, Some(ast::Type::Scalar(inst_type)), param_space),
- )
+ ast::Instruction::St(d, a.map(visitor, inst_type, param_space))
}
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
ast::Instruction::Cvta(d, a) => {
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
- ast::Instruction::Cvta(d, a.map(visitor, Some(inst_type)))
+ ast::Instruction::Cvta(d, a.map(visitor, inst_type))
}
}
}
}
impl VisitVariable for ast::Instruction<NormalizedArgParams> {
- fn visit_variable<'a, F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable<
+ 'a,
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> UnadornedStatement {
@@ -1946,29 +2052,37 @@ impl VisitVariable for ast::Instruction<NormalizedArgParams> {
impl<T> ArgumentMapVisitor<NormalizedArgParams, NormalizedArgParams> for T
where
- T: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word,
+ T: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
{
- fn variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
- self(desc)
+ fn variable(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: Option<ast::Type>,
+ ) -> spirv::Word {
+ self(desc, t)
}
fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ t: ast::Type,
) -> ast::Operand<spirv::Word> {
match desc.op {
- ast::Operand::Reg(id) => ast::Operand::Reg(self(desc.new_op(id))),
+ ast::Operand::Reg(id) => ast::Operand::Reg(self(desc.new_op(id), Some(t))),
ast::Operand::Imm(imm) => ast::Operand::Imm(imm),
- ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(desc.new_op(id)), imm),
+ ast::Operand::RegOffset(id, imm) => {
+ ast::Operand::RegOffset(self(desc.new_op(id), Some(t)), imm)
+ }
}
}
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
+ t: ast::Type,
) -> ast::CallOperand<spirv::Word> {
match desc.op {
- ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(desc.new_op(id))),
+ ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(desc.new_op(id), Some(t))),
ast::CallOperand::Imm(imm) => ast::CallOperand::Imm(imm),
}
}
@@ -1976,11 +2090,74 @@ where
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
+ t: ast::MovVectorType,
) -> (spirv::Word, u8) {
- (self(desc.new_op(desc.op.0)), desc.op.1)
+ (
+ self(
+ desc.new_op(desc.op.0),
+ Some(ast::Type::Vector(t.into(), desc.op.1)),
+ ),
+ desc.op.1,
+ )
+ }
+}
+
+impl ast::Type {
+ fn to_parts(self) -> TypeParts {
+ match self {
+ ast::Type::Scalar(scalar) => TypeParts {
+ kind: TypeKind::Scalar,
+ scalar_kind: scalar.kind(),
+ width: scalar.width(),
+ components: 0,
+ },
+ ast::Type::Vector(scalar, components) => TypeParts {
+ kind: TypeKind::Vector,
+ scalar_kind: scalar.kind(),
+ width: scalar.width(),
+ components: components as u32,
+ },
+ ast::Type::Array(scalar, components) => TypeParts {
+ kind: TypeKind::Array,
+ scalar_kind: scalar.kind(),
+ width: scalar.width(),
+ components: components,
+ },
+ }
+ }
+
+ fn from_parts(t: TypeParts) -> Self {
+ match t.kind {
+ TypeKind::Scalar => {
+ ast::Type::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind))
+ }
+ TypeKind::Vector => ast::Type::Vector(
+ ast::ScalarType::from_parts(t.width, t.scalar_kind),
+ t.components as u8,
+ ),
+ TypeKind::Array => ast::Type::Array(
+ ast::ScalarType::from_parts(t.width, t.scalar_kind),
+ t.components,
+ ),
+ }
}
}
+#[derive(Eq, PartialEq, Copy, Clone)]
+struct TypeParts {
+ kind: TypeKind,
+ scalar_kind: ScalarKind,
+ width: u8,
+ components: u32,
+}
+
+#[derive(Eq, PartialEq, Copy, Clone)]
+enum TypeKind {
+ Scalar,
+ Vector,
+ Array,
+}
+
impl ast::Instruction<ExpandedArgParams> {
fn jump_target(&self) -> Option<spirv::Word> {
match self {
@@ -2005,7 +2182,9 @@ impl ast::Instruction<ExpandedArgParams> {
}
impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
- fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ fn visit_variable_extended<
+ F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ >(
self,
f: &mut F,
) -> ExpandedStatement {
@@ -2016,6 +2195,29 @@ impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
type Arg2 = ast::Arg2<ExpandedArgParams>;
type Arg2St = ast::Arg2St<ExpandedArgParams>;
+struct CompositeAccess {
+ pub typ: ast::MovVectorType,
+ pub dst: spirv::Word,
+ pub src: spirv::Word,
+ pub index: u32,
+ pub is_write: bool
+}
+
+struct CompositeWrite {
+ pub typ: ast::MovVectorType,
+ pub dst: spirv::Word,
+ pub src_composite: spirv::Word,
+ pub src_scalar: spirv::Word,
+ pub index: u32,
+}
+
+struct CompositeRead {
+ pub typ: ast::MovVectorType,
+ pub dst: spirv::Word,
+ pub src: spirv::Word,
+ pub index: u32,
+}
+
struct ConstantDefinition {
pub dst: spirv::Word,
pub typ: ast::ScalarType,
@@ -2028,6 +2230,7 @@ struct BrachCondition {
if_false: spirv::Word,
}
+#[derive(Copy, Clone)]
struct ImplicitConversion {
src: spirv::Word,
dst: spirv::Word,
@@ -2036,7 +2239,7 @@ struct ImplicitConversion {
kind: ConversionKind,
}
-#[derive(Debug, PartialEq)]
+#[derive(Debug, PartialEq, Copy, Clone)]
enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
@@ -2084,12 +2287,14 @@ impl<T: ArgParamsEx> ast::Arg1<T> {
t: Option<ast::Type>,
) -> ast::Arg1<U> {
ast::Arg1 {
- src: visitor.variable(ArgumentDescriptor {
- op: self.src,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
+ src: visitor.variable(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
}
@@ -2098,43 +2303,51 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg2<U> {
ast::Arg2 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: t,
- is_dst: true,
- is_pointer: false,
- }),
- src: visitor.operand(ArgumentDescriptor {
- op: self.src,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(t),
+ ),
+ src: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
fn map_ld<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
is_src_pointer: bool,
) -> ast::Arg2<U> {
ast::Arg2 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: t,
- is_dst: true,
- is_pointer: false,
- }),
- src: visitor.operand(ArgumentDescriptor {
- op: self.src,
- typ: t,
- is_dst: false,
- is_pointer: is_src_pointer,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(t),
+ ),
+ src: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_pointer: is_src_pointer,
+ },
+ t,
+ ),
}
}
@@ -2145,18 +2358,22 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
src_t: ast::Type,
) -> ast::Arg2<U> {
ast::Arg2 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: Some(dst_t),
- is_dst: true,
- is_pointer: false,
- }),
- src: visitor.operand(ArgumentDescriptor {
- op: self.src,
- typ: Some(src_t),
- is_dst: false,
- is_pointer: false,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(dst_t),
+ ),
+ src: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ src_t,
+ ),
}
}
}
@@ -2165,22 +2382,26 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
param_space: bool,
) -> ast::Arg2St<U> {
ast::Arg2St {
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: param_space,
- is_pointer: !param_space,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: param_space,
+ is_pointer: !param_space,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
}
@@ -2189,107 +2410,149 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: ast::MovVectorType,
) -> ast::Arg2Vec<U> {
match self {
ast::Arg2Vec::Dst(dst, src) => ast::Arg2Vec::Dst(
- visitor.src_vec_operand(ArgumentDescriptor {
- op: dst,
- typ: Some(t),
- is_dst: true,
- is_pointer: false,
- }),
- visitor.variable(ArgumentDescriptor {
- op: src,
- typ: Some(t),
- is_dst: false,
- is_pointer: false,
- }),
+ visitor.src_vec_operand(
+ ArgumentDescriptor {
+ op: dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ t,
+ ),
+ visitor.variable(
+ ArgumentDescriptor {
+ op: src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ Some(ast::Type::Scalar(t.into())),
+ ),
),
- ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src (
- visitor.variable(ArgumentDescriptor {
- op: dst,
- typ: Some(t),
- is_dst: true,
- is_pointer: false,
- }),
- visitor.src_vec_operand(ArgumentDescriptor {
- op: src,
- typ: Some(t),
- is_dst: false,
- is_pointer: false,
- }),
+ ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src(
+ visitor.variable(
+ ArgumentDescriptor {
+ op: dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(ast::Type::Scalar(t.into())),
+ ),
+ visitor.src_vec_operand(
+ ArgumentDescriptor {
+ op: src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
),
- ast::Arg2Vec::Both(dst, src) => ast::Arg2Vec::Both (
- visitor.src_vec_operand(ArgumentDescriptor {
- op: dst,
- typ: Some(t),
- is_dst: true,
- is_pointer: false,
- }),
- visitor.src_vec_operand(ArgumentDescriptor {
- op: src,
- typ: Some(t),
- is_dst: false,
- is_pointer: false,
- }),
+ ast::Arg2Vec::Both(dst, src) => ast::Arg2Vec::Both(
+ visitor.src_vec_operand(
+ ArgumentDescriptor {
+ op: dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ t,
+ ),
+ visitor.src_vec_operand(
+ ArgumentDescriptor {
+ op: src,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
),
}
}
}
+impl ast::Arg2Vec<ExpandedArgParams> {
+ fn dst(&self) -> spirv::Word {
+ match self {
+ ast::Arg2Vec::Dst(dst, _) | ast::Arg2Vec::Src(dst, _) | ast::Arg2Vec::Both(dst, _) => {
+ *dst
+ }
+ }
+ }
+
+ fn src(&self) -> spirv::Word {
+ match self {
+ ast::Arg2Vec::Dst(_, src) | ast::Arg2Vec::Src(_, src) | ast::Arg2Vec::Both(_, src) => {
+ *src
+ }
+ }
+ }
+}
+
impl<T: ArgParamsEx> ast::Arg3<T> {
fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg3<U> {
ast::Arg3 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: t,
- is_dst: true,
- is_pointer: false,
- }),
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(t),
+ ),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
fn map_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg3<U> {
ast::Arg3 {
- dst: visitor.variable(ArgumentDescriptor {
- op: self.dst,
- typ: t,
- is_dst: true,
- is_pointer: false,
- }),
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: Some(ast::Type::Scalar(ast::ScalarType::U32)),
- is_dst: false,
- is_pointer: false,
- }),
+ dst: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(t),
+ ),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ),
}
}
}
@@ -2298,35 +2561,43 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg4<U> {
ast::Arg4 {
- dst1: visitor.variable(ArgumentDescriptor {
- op: self.dst1,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
- is_dst: true,
- is_pointer: false,
- }),
- dst2: self.dst2.map(|dst2| {
- visitor.variable(ArgumentDescriptor {
- op: dst2,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ dst1: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst1,
is_dst: true,
is_pointer: false,
- })
- }),
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: t,
- is_dst: false,
- is_pointer: false,
+ },
+ Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ ),
+ dst2: self.dst2.map(|dst2| {
+ visitor.variable(
+ ArgumentDescriptor {
+ op: dst2,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ )
}),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
}
}
}
@@ -2335,41 +2606,51 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: ast::Type,
) -> ast::Arg5<U> {
ast::Arg5 {
- dst1: visitor.variable(ArgumentDescriptor {
- op: self.dst1,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
- is_dst: true,
- is_pointer: false,
- }),
- dst2: self.dst2.map(|dst2| {
- visitor.variable(ArgumentDescriptor {
- op: dst2,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ dst1: visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst1,
is_dst: true,
is_pointer: false,
- })
- }),
- src1: visitor.operand(ArgumentDescriptor {
- op: self.src1,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src2: visitor.operand(ArgumentDescriptor {
- op: self.src2,
- typ: t,
- is_dst: false,
- is_pointer: false,
- }),
- src3: visitor.operand(ArgumentDescriptor {
- op: self.src3,
- typ: Some(ast::Type::Scalar(ast::ScalarType::Pred)),
- is_dst: false,
- is_pointer: false,
+ },
+ Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ ),
+ dst2: self.dst2.map(|dst2| {
+ visitor.variable(
+ ArgumentDescriptor {
+ op: dst2,
+ is_dst: true,
+ is_pointer: false,
+ },
+ Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ )
}),
+ src1: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src2: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_pointer: false,
+ },
+ t,
+ ),
+ src3: visitor.operand(
+ ArgumentDescriptor {
+ op: self.src3,
+ is_dst: false,
+ is_pointer: false,
+ },
+ ast::Type::Scalar(ast::ScalarType::Pred),
+ ),
}
}
}
@@ -2395,9 +2676,9 @@ impl ast::StStateSpace {
}
}
-#[derive(Clone, Copy, PartialEq)]
+#[derive(Clone, Copy, PartialEq, Eq)]
enum ScalarKind {
- Byte,
+ Bit,
Unsigned,
Signed,
Float,
@@ -2438,10 +2719,10 @@ impl ast::ScalarType {
ast::ScalarType::S16 => ScalarKind::Signed,
ast::ScalarType::S32 => ScalarKind::Signed,
ast::ScalarType::S64 => ScalarKind::Signed,
- ast::ScalarType::B8 => ScalarKind::Byte,
- ast::ScalarType::B16 => ScalarKind::Byte,
- ast::ScalarType::B32 => ScalarKind::Byte,
- ast::ScalarType::B64 => ScalarKind::Byte,
+ ast::ScalarType::B8 => ScalarKind::Bit,
+ ast::ScalarType::B16 => ScalarKind::Bit,
+ ast::ScalarType::B32 => ScalarKind::Bit,
+ ast::ScalarType::B64 => ScalarKind::Bit,
ast::ScalarType::F16 => ScalarKind::Float,
ast::ScalarType::F32 => ScalarKind::Float,
ast::ScalarType::F64 => ScalarKind::Float,
@@ -2458,7 +2739,7 @@ impl ast::ScalarType {
8 => ast::ScalarType::F64,
_ => unreachable!(),
},
- ScalarKind::Byte => match width {
+ ScalarKind::Bit => match width {
1 => ast::ScalarType::B8,
2 => ast::ScalarType::B16,
4 => ast::ScalarType::B32,
@@ -2574,22 +2855,159 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
return false;
}
match inst.kind() {
- ScalarKind::Byte => operand.kind() != ScalarKind::Byte,
- ScalarKind::Float => operand.kind() == ScalarKind::Byte,
+ ScalarKind::Bit => operand.kind() != ScalarKind::Bit,
+ ScalarKind::Float => operand.kind() == ScalarKind::Bit,
ScalarKind::Signed => {
- operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Unsigned
+ operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned
}
ScalarKind::Unsigned => {
- operand.kind() == ScalarKind::Byte || operand.kind() == ScalarKind::Signed
+ operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed
}
- ScalarKind::Float2 => todo!(),
+ ScalarKind::Float2 => false,
ScalarKind::Pred => false,
}
}
+ (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
+ | (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => {
+ should_bitcast(ast::Type::Scalar(inst), ast::Type::Scalar(operand))
+ }
_ => false,
}
}
+fn insert_with_conversions<T, ToInstruction: FnOnce(T) -> ast::Instruction<ExpandedArgParams>>(
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
+ mut instr: T,
+ pre_conv_src: impl ExactSizeIterator<Item = ImplicitConversion>,
+ pre_conv_dst: impl ExactSizeIterator<Item = ImplicitConversion>,
+ mut post_conv: Vec<ImplicitConversion>,
+ mut src: impl FnMut(&mut T) -> &mut spirv::Word,
+ mut dst: impl FnMut(&mut T) -> &mut spirv::Word,
+ to_inst: ToInstruction,
+) {
+ insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_src, &mut src);
+ insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_dst, &mut dst);
+ if post_conv.len() > 0 {
+ let new_id = id_def.new_id(Some(post_conv[0].from));
+ post_conv[0].src = new_id;
+ post_conv.last_mut().unwrap().dst = *dst(&mut instr);
+ *dst(&mut instr) = new_id;
+ }
+ func.push(Statement::Instruction(to_inst(instr)));
+ for conv in post_conv {
+ func.push(Statement::Conversion(conv));
+ }
+}
+
+fn insert_with_conversions_pre_conv<T>(
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
+ mut instr: &mut T,
+ pre_conv: impl ExactSizeIterator<Item = ImplicitConversion>,
+ src: &mut impl FnMut(&mut T) -> &mut spirv::Word,
+) {
+ let pre_conv_len = pre_conv.len();
+ for (i, mut conv) in pre_conv.enumerate() {
+ let original_src = src(&mut instr);
+ if i == 0 {
+ conv.src = *original_src;
+ }
+ if i == pre_conv_len - 1 {
+ let new_id = id_def.new_id(Some(conv.to));
+ conv.dst = new_id;
+ *original_src = new_id;
+ }
+ func.push(Statement::Conversion(conv));
+ }
+}
+
+fn get_implicit_conversions_ld_dst<
+ ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
+>(
+ id_def: &mut NumericIdResolver,
+ instr_type: ast::Type,
+ dst: spirv::Word,
+ should_convert: ShouldConvert,
+ in_reverse: bool,
+) -> Option<ImplicitConversion> {
+ let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!());
+ if let Some(conv) = should_convert(dst_type, instr_type) {
+ Some(ImplicitConversion {
+ src: u32::max_value(),
+ dst: u32::max_value(),
+ from: if !in_reverse { dst_type } else { instr_type },
+ to: if !in_reverse { instr_type } else { dst_type },
+ kind: conv,
+ })
+ } else {
+ None
+ }
+}
+
+fn get_implicit_conversions_ld_src(
+ id_def: &mut NumericIdResolver,
+ instr_type: ast::Type,
+ state_space: ast::LdStateSpace,
+ src: spirv::Word,
+) -> Vec<ImplicitConversion> {
+ let src_type = id_def.get_type(src).unwrap_or_else(|| todo!());
+ match state_space {
+ ast::LdStateSpace::Param => {
+ if src_type != instr_type {
+ vec![
+ ImplicitConversion {
+ src: u32::max_value(),
+ dst: u32::max_value(),
+ from: src_type,
+ to: instr_type,
+ kind: ConversionKind::Default,
+ };
+ 1
+ ]
+ } else {
+ Vec::new()
+ }
+ }
+ ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
+ let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
+ mem::size_of::<usize>() as u8,
+ ScalarKind::Bit,
+ ));
+ let mut result = Vec::new();
+ // HACK ALERT
+ // IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an
+ // additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier
+ // TODO: error out if the src is not B64/U64/S64
+ if let ast::Type::Scalar(scalar_src_type) = src_type {
+ if scalar_src_type.kind() == ScalarKind::Signed {
+ result.push(ImplicitConversion {
+ src: u32::max_value(),
+ dst: u32::max_value(),
+ from: src_type,
+ to: new_src_type,
+ kind: ConversionKind::Default,
+ });
+ }
+ }
+ result.push(ImplicitConversion {
+ src: u32::max_value(),
+ dst: u32::max_value(),
+ from: src_type,
+ to: instr_type,
+ kind: ConversionKind::Ptr(state_space),
+ });
+ if result.len() == 2 {
+ let new_id = id_def.new_id(Some(new_src_type));
+ result[0].dst = new_id;
+ result[1].src = new_id;
+ result[1].from = new_src_type;
+ }
+ result
+ }
+ _ => todo!(),
+ }
+}
fn insert_implicit_conversions_ld_src(
func: &mut Vec<ExpandedStatement>,
instr_type: ast::Type,
@@ -2608,7 +3026,7 @@ fn insert_implicit_conversions_ld_src(
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
mem::size_of::<usize>() as u8,
- ScalarKind::Byte,
+ ScalarKind::Bit,
));
let new_src = insert_implicit_conversions_ld_src_impl(
func,
@@ -2640,8 +3058,8 @@ fn insert_implicit_conversions_ld_src_impl<
should_convert: ShouldConvert,
) -> spirv::Word {
let src_type = id_def.get_type(src);
- if let Some(conv) = should_convert(src_type, instr_type) {
- insert_conversion_src(func, id_def, src, src_type, instr_type, conv)
+ if let Some(conv) = should_convert(src_type.unwrap(), instr_type) {
+ insert_conversion_src(func, id_def, src, src_type.unwrap(), instr_type, conv)
} else {
src
}
@@ -2692,14 +3110,15 @@ fn insert_conversion_src(
temp_src
}
+/*
fn insert_with_implicit_conversion_dst<
T,
- ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
+ ShouldConvert: FnOnce(ast::StateSpace, ast::Type, ast::Type) -> Option<ConversionKind>,
Setter: Fn(&mut T) -> &mut spirv::Word,
ToInstruction: FnOnce(T) -> ast::Instruction<ExpandedArgParams>,
>(
func: &mut Vec<ExpandedStatement>,
- instr_type: ast::ScalarType,
+ instr_type: ast::Type,
id_def: &mut NumericIdResolver,
should_convert: ShouldConvert,
mut t: T,
@@ -2708,13 +3127,14 @@ fn insert_with_implicit_conversion_dst<
) {
let dst = setter(&mut t);
let dst_type = id_def.get_type(*dst);
- let dst_coercion = should_convert(dst_type, instr_type)
- .map(|conv| get_conversion_dst(id_def, dst, ast::Type::Scalar(instr_type), dst_type, conv));
+ let dst_coercion = should_convert(dst_type.unwrap(), instr_type)
+ .map(|conv| get_conversion_dst(id_def, dst, instr_type, dst_type.unwrap(), conv));
func.push(Statement::Instruction(to_inst(t)));
if let Some(conv) = dst_coercion {
func.push(conv);
}
}
+*/
#[must_use]
fn get_conversion_dst(
@@ -2739,14 +3159,14 @@ fn get_conversion_dst(
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
fn should_convert_relaxed_src(
src_type: ast::Type,
- instr_type: ast::ScalarType,
+ instr_type: ast::Type,
) -> Option<ConversionKind> {
- if src_type == ast::Type::Scalar(instr_type) {
+ if src_type == instr_type {
return None;
}
- match src_type {
- ast::Type::Scalar(src_type) => match instr_type.kind() {
- ScalarKind::Byte => {
+ match (src_type, instr_type) {
+ (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
+ ScalarKind::Bit => {
if instr_type.width() <= src_type.width() {
Some(ConversionKind::Default)
} else {
@@ -2761,7 +3181,7 @@ fn should_convert_relaxed_src(
}
}
ScalarKind::Float => {
- if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Byte {
+ if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Bit {
Some(ConversionKind::Default)
} else {
None
@@ -2770,6 +3190,10 @@ fn should_convert_relaxed_src(
ScalarKind::Float2 => todo!(),
ScalarKind::Pred => None,
},
+ (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
+ | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
+ should_convert_relaxed_src(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type))
+ }
_ => None,
}
}
@@ -2777,14 +3201,14 @@ fn should_convert_relaxed_src(
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
dst_type: ast::Type,
- instr_type: ast::ScalarType,
+ instr_type: ast::Type,
) -> Option<ConversionKind> {
- if dst_type == ast::Type::Scalar(instr_type) {
+ if dst_type == instr_type {
return None;
}
- match dst_type {
- ast::Type::Scalar(dst_type) => match instr_type.kind() {
- ScalarKind::Byte => {
+ match (dst_type, instr_type) {
+ (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
+ ScalarKind::Bit => {
if instr_type.width() <= dst_type.width() {
Some(ConversionKind::Default)
} else {
@@ -2812,7 +3236,7 @@ fn should_convert_relaxed_dst(
}
}
ScalarKind::Float => {
- if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Byte {
+ if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Bit {
Some(ConversionKind::Default)
} else {
None
@@ -2821,6 +3245,10 @@ fn should_convert_relaxed_dst(
ScalarKind::Float2 => todo!(),
ScalarKind::Pred => None,
},
+ (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
+ | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
+ should_convert_relaxed_dst(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type))
+ }
_ => None,
}
}
@@ -2831,13 +3259,13 @@ fn insert_implicit_bitcasts(
stmt: impl VisitVariableExpanded,
) {
let mut dst_coercion = None;
- let instr = stmt.visit_variable_extended(&mut |mut desc| {
- let id_type_from_instr = match desc.typ {
+ let instr = stmt.visit_variable_extended(&mut |mut desc, typ| {
+ let id_type_from_instr = match typ {
Some(t) => t,
None => return desc.op,
};
- let id_actual_type = id_def.get_type(desc.op);
- if should_bitcast(id_type_from_instr, id_def.get_type(desc.op)) {
+ let id_actual_type = id_def.get_type(desc.op).unwrap();
+ if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap()) {
if desc.is_dst {
dst_coercion = Some(get_conversion_dst(
id_def,
@@ -2970,14 +3398,14 @@ mod tests {
.collect::<Vec<_>>()
}
- fn assert_conversion_table<F: Fn(ast::Type, ast::ScalarType) -> Option<ConversionKind>>(
+ fn assert_conversion_table<F: Fn(ast::Type, ast::Type) -> Option<ConversionKind>>(
table: &'static str,
f: F,
) {
let conv_table = parse_conversion_table(table);
for (instr_idx, instr_type) in SCALAR_TYPES.iter().enumerate() {
for (op_idx, op_type) in SCALAR_TYPES.iter().enumerate() {
- let conversion = f(ast::Type::Scalar(*op_type), *instr_type);
+ let conversion = f(ast::Type::Scalar(*op_type), ast::Type::Scalar(*instr_type));
if instr_idx == op_idx {
assert_eq!(conversion, None);
} else {