summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-07-20 00:01:03 +0200
committerAndrzej Janik <[email protected]>2020-07-20 00:01:03 +0200
commit872d69c714e647bab9192d6ae5105fe2638b4f77 (patch)
treefd2522a931b0f92615c34714814c7081d0639c8b
parent3d6991e0ca808f05025ee84574642efcdd7ed696 (diff)
downloadZLUDA-872d69c714e647bab9192d6ae5105fe2638b4f77.tar.gz
ZLUDA-872d69c714e647bab9192d6ae5105fe2638b4f77.zip
Implement constants in translation middle-end
-rw-r--r--ptx/Cargo.toml1
-rw-r--r--ptx/src/ast.rs67
-rw-r--r--ptx/src/lib.rs1
-rw-r--r--ptx/src/ptx.lalrpop120
-rw-r--r--ptx/src/test/spirv_run/mod.rs4
-rw-r--r--ptx/src/test/spirv_run/mul_hi.ptx22
-rw-r--r--ptx/src/test/spirv_run/mul_hi.spvtxt26
-rw-r--r--ptx/src/test/spirv_run/mul_lo.ptx22
-rw-r--r--ptx/src/test/spirv_run/mul_lo.spvtxt26
-rw-r--r--ptx/src/translate.rs734
10 files changed, 845 insertions, 178 deletions
diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml
index d3c4b73..e7be3b7 100644
--- a/ptx/Cargo.toml
+++ b/ptx/Cargo.toml
@@ -14,6 +14,7 @@ spirv_headers = "1.4"
quick-error = "1.2"
bit-vec = "0.6"
paste = "0.1"
+half ="1.6"
[build-dependencies.lalrpop]
version = "0.18.1"
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index c7cb7f7..0efc37c 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -243,15 +243,74 @@ pub struct MovData {
pub typ: Type,
}
-pub struct MulData {}
+pub struct MulData {
+ pub typ: Type,
+ pub desc: MulDescriptor,
+}
+
+pub enum MulDescriptor {
+ Int(MulIntControl),
+ Float(MulFloatDesc),
+}
+
+pub enum MulIntControl {
+ Low,
+ High,
+ Wide
+}
+
+pub struct MulFloatDesc {
+ pub rounding: Option<RoundingMode>,
+ pub flush_to_zero: bool,
+ pub saturate: bool,
+}
+
+pub enum RoundingMode {
+ NearestEven,
+ Zero,
+ NegativeInf,
+ PositiveInf
+}
pub struct AddData {
pub typ: ScalarType,
}
-pub struct SetpData {}
-
-pub struct SetpBoolData {}
+pub struct SetpData {
+ pub typ: ScalarType,
+ pub flush_to_zero: bool,
+ pub cmp_op: SetpCompareOp,
+}
+
+pub enum SetpCompareOp {
+ Eq,
+ NotEq,
+ Less,
+ LessOrEq,
+ Greater,
+ GreaterOrEq,
+ NanEq,
+ NanNotEq,
+ NanLess,
+ NanLessOrEq,
+ NanGreater,
+ NanGreaterOrEq,
+ IsNotNan,
+ IsNan,
+}
+
+pub enum SetpBoolPostOp {
+ And,
+ Or,
+ Xor,
+}
+
+pub struct SetpBoolData {
+ pub typ: ScalarType,
+ pub flush_to_zero: bool,
+ pub cmp_op: SetpCompareOp,
+ pub bool_op: SetpBoolPostOp
+}
pub struct NotData {}
diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs
index 5402326..15302ff 100644
--- a/ptx/src/lib.rs
+++ b/ptx/src/lib.rs
@@ -12,6 +12,7 @@ extern crate level_zero as ze;
extern crate level_zero_sys as l0;
extern crate rspirv;
extern crate spirv_headers as spirv;
+extern crate half;
#[cfg(test)]
extern crate spirv_tools_sys as spirv_tools;
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 64d7725..b44702d 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -399,20 +399,56 @@ InstMul: ast::Instruction<&'input str> = {
};
InstMulMode: ast::MulData = {
- MulIntControl? IntType => ast::MulData{},
- RoundingMode? ".ftz"? ".sat"? ".f32" => ast::MulData{},
- RoundingMode? ".f64" => ast::MulData{},
- ".rn"? ".ftz"? ".sat"? ".f16" => ast::MulData{},
- ".rn"? ".ftz"? ".sat"? ".f16x2" => ast::MulData{}
+ <ctr:MulIntControl> <t:IntType> => ast::MulData{
+ typ: ast::Type::Scalar(t),
+ desc: ast::MulDescriptor::Int(ctr)
+ },
+ <r:RoundingMode?> <ftz:".ftz"?> <s:".sat"?> ".f32" => ast::MulData{
+ typ: ast::Type::Scalar(ast::ScalarType::F32),
+ desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
+ rounding: r,
+ flush_to_zero: ftz.is_some(),
+ saturate: s.is_some()
+ })
+ },
+ <r:RoundingMode?> ".f64" => ast::MulData{
+ typ: ast::Type::Scalar(ast::ScalarType::F64),
+ desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
+ rounding: r,
+ flush_to_zero: false,
+ saturate: false
+ })
+ },
+ <r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16" => ast::MulData{
+ typ: ast::Type::Scalar(ast::ScalarType::F16),
+ desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
+ rounding: r.map(|_| ast::RoundingMode::NearestEven),
+ flush_to_zero: ftz.is_some(),
+ saturate: s.is_some()
+ })
+ },
+ <r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16x2" => ast::MulData{
+ typ: ast::Type::ExtendedScalar(ast::ExtendedScalarType::F16x2),
+ desc: ast::MulDescriptor::Float(ast::MulFloatDesc {
+ rounding: r.map(|_| ast::RoundingMode::NearestEven),
+ flush_to_zero: ftz.is_some(),
+ saturate: s.is_some()
+ })
+ }
};
-MulIntControl = {
- ".hi", ".lo", ".wide"
+MulIntControl: ast::MulIntControl = {
+ ".hi" => ast::MulIntControl::High,
+ ".lo" => ast::MulIntControl::Low,
+ ".wide" => ast::MulIntControl::Wide
};
#[inline]
-RoundingMode = {
- ".rn", ".rz", ".rm", ".rp"
+RoundingMode : ast::RoundingMode = {
+ ".rn" => ast::RoundingMode::NearestEven,
+ ".rz" => ast::RoundingMode::Zero,
+ ".rm" => ast::RoundingMode::NegativeInf,
+ ".rp" => ast::RoundingMode::PositiveInf,
};
IntType : ast::ScalarType = {
@@ -449,27 +485,61 @@ InstSetp: ast::Instruction<&'input str> = {
};
SetpMode: ast::SetpData = {
- SetpCmpOp ".ftz"? SetpType => ast::SetpData{}
+ <cmp_op:SetpCompareOp> <ftz:".ftz"?> <t:SetpType> => ast::SetpData{
+ typ: t,
+ flush_to_zero: ftz.is_some(),
+ cmp_op: cmp_op,
+ }
};
SetpBoolMode: ast::SetpBoolData = {
- SetpCmpOp SetpBoolOp ".ftz"? SetpType => ast::SetpBoolData{}
-};
-
-SetpCmpOp = {
- ".eq", ".ne", ".lt", ".le", ".gt", ".ge", ".lo", ".ls", ".hi", ".hs",
- ".equ", ".neu", ".ltu", ".leu", ".gtu", ".geu", ".num", ".nan"
-};
-
-SetpBoolOp = {
- ".and", ".or", ".xor"
+ <cmp_op:SetpCompareOp> <bool_op:SetpBoolPostOp> <ftz:".ftz"?> <t:SetpType> => ast::SetpBoolData{
+ typ: t,
+ flush_to_zero: ftz.is_some(),
+ cmp_op: cmp_op,
+ bool_op: bool_op,
+ }
};
-SetpType = {
- ".b16", ".b32", ".b64",
- ".u16", ".u32", ".u64",
- ".s16", ".s32", ".s64",
- ".f32", ".f64"
+SetpCompareOp: ast::SetpCompareOp = {
+ ".eq" => ast::SetpCompareOp::Eq,
+ ".ne" => ast::SetpCompareOp::NotEq,
+ ".lt" => ast::SetpCompareOp::Less,
+ ".le" => ast::SetpCompareOp::LessOrEq,
+ ".gt" => ast::SetpCompareOp::Greater,
+ ".ge" => ast::SetpCompareOp::GreaterOrEq,
+ ".lo" => ast::SetpCompareOp::Less,
+ ".ls" => ast::SetpCompareOp::LessOrEq,
+ ".hi" => ast::SetpCompareOp::Greater,
+ ".hs" => ast::SetpCompareOp::GreaterOrEq,
+ ".equ" => ast::SetpCompareOp::NanEq,
+ ".neu" => ast::SetpCompareOp::NanNotEq,
+ ".ltu" => ast::SetpCompareOp::NanLess,
+ ".leu" => ast::SetpCompareOp::NanLessOrEq,
+ ".gtu" => ast::SetpCompareOp::NanGreater,
+ ".geu" => ast::SetpCompareOp::NanGreaterOrEq,
+ ".num" => ast::SetpCompareOp::IsNotNan,
+ ".nan" => ast::SetpCompareOp::IsNan,
+};
+
+SetpBoolPostOp: ast::SetpBoolPostOp = {
+ ".and" => ast::SetpBoolPostOp::And,
+ ".or" => ast::SetpBoolPostOp::Or,
+ ".xor" => ast::SetpBoolPostOp::Xor,
+};
+
+SetpType: ast::ScalarType = {
+ ".b16" => ast::ScalarType::B16,
+ ".b32" => ast::ScalarType::B32,
+ ".b64" => ast::ScalarType::B64,
+ ".u16" => ast::ScalarType::U16,
+ ".u32" => ast::ScalarType::U32,
+ ".u64" => ast::ScalarType::U64,
+ ".s16" => ast::ScalarType::S16,
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
+ ".f32" => ast::ScalarType::F32,
+ ".f64" => ast::ScalarType::F64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index b573f2c..b374324 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -37,7 +37,9 @@ macro_rules! test_ptx {
}
test_ptx!(ld_st, [1u64], [1u64]);
-test_ptx!(mov, [1u64], [1u64]);
+//test_ptx!(mov, [1u64], [1u64]);
+//test_ptx!(mul_lo, [1u64], [2u64]);
+//test_ptx!(mul_hi, [u64::max_value()], [1u64]);
struct DisplayError<T: Display + Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/mul_hi.ptx b/ptx/src/test/spirv_run/mul_hi.ptx
new file mode 100644
index 0000000..1dc1572
--- /dev/null
+++ b/ptx/src/test/spirv_run/mul_hi.ptx
@@ -0,0 +1,22 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry mul_hi(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u64 temp, [in_addr];
+ mul.hi.u64 temp2, temp, 2;
+ st.u64 [out_addr], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/mul_hi.spvtxt b/ptx/src/test/spirv_run/mul_hi.spvtxt
new file mode 100644
index 0000000..db8943f
--- /dev/null
+++ b/ptx/src/test/spirv_run/mul_hi.spvtxt
@@ -0,0 +1,26 @@
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int64
+OpCapability Int8
+%1 = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical64 OpenCL
+OpEntryPoint Kernel %5 "mul_hi"
+%2 = OpTypeVoid
+%3 = OpTypeInt 64 0
+%4 = OpTypeFunction %2 %3 %3
+%19 = OpTypePointer Generic %3
+%5 = OpFunction %2 None %4
+%6 = OpFunctionParameter %3
+%7 = OpFunctionParameter %3
+%18 = OpLabel
+%13 = OpCopyObject %3 %6
+%14 = OpCopyObject %3 %7
+%15 = OpConvertUToPtr %19 %13
+%16 = OpLoad %3 %15
+%100 = OpCopyObject %3 %16
+%17 = OpConvertUToPtr %19 %14
+OpStore %17 %100
+OpReturn
+OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mul_lo.ptx b/ptx/src/test/spirv_run/mul_lo.ptx
new file mode 100644
index 0000000..cae3b57
--- /dev/null
+++ b/ptx/src/test/spirv_run/mul_lo.ptx
@@ -0,0 +1,22 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry mul_lo(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u64 temp, [in_addr];
+ mul.lo.u64 temp2, temp, 2;
+ st.u64 [out_addr], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/mul_lo.spvtxt b/ptx/src/test/spirv_run/mul_lo.spvtxt
new file mode 100644
index 0000000..66e7bc1
--- /dev/null
+++ b/ptx/src/test/spirv_run/mul_lo.spvtxt
@@ -0,0 +1,26 @@
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int64
+OpCapability Int8
+%1 = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical64 OpenCL
+OpEntryPoint Kernel %5 "mul_lo"
+%2 = OpTypeVoid
+%3 = OpTypeInt 64 0
+%4 = OpTypeFunction %2 %3 %3
+%19 = OpTypePointer Generic %3
+%5 = OpFunction %2 None %4
+%6 = OpFunctionParameter %3
+%7 = OpFunctionParameter %3
+%18 = OpLabel
+%13 = OpCopyObject %3 %6
+%14 = OpCopyObject %3 %7
+%15 = OpConvertUToPtr %19 %13
+%16 = OpLoad %3 %15
+%100 = OpCopyObject %3 %16
+%17 = OpConvertUToPtr %19 %14
+OpStore %17 %100
+OpReturn
+OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index ee28bb7..6620666 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -18,7 +18,7 @@ impl From<ast::Type> for SpirvType {
fn from(t: ast::Type) -> Self {
match t {
ast::Type::Scalar(t) => SpirvType::Base(t),
- ast::Type::ExtendedScalar(t) => SpirvType::Extended(t)
+ ast::Type::ExtendedScalar(t) => SpirvType::Extended(t),
}
}
}
@@ -60,7 +60,11 @@ impl TypeWordMap {
})
}
- fn get_or_add_extended(&mut self, b: &mut dr::Builder, t: ast::ExtendedScalarType) -> spirv::Word {
+ fn get_or_add_extended(
+ &mut self,
+ b: &mut dr::Builder,
+ t: ast::ExtendedScalarType,
+ ) -> spirv::Word {
*self
.complex
.entry(SpirvType::Extended(t))
@@ -178,8 +182,9 @@ fn to_ssa<'a>(
let registers = collect_var_definitions(&f_args, &f_body);
let (normalized_ids, unique_ids) =
normalize_identifiers(f_body, &contant_ids, &mut type_check, registers);
+ let (normalized_stmts, unique_ids) = normalize_statements(normalized_ids, unique_ids);
let (mut func_body, unique_ids) =
- insert_implicit_conversions(normalized_ids, unique_ids, &|x| type_check[&x]);
+ insert_implicit_conversions(normalized_stmts, unique_ids, &|x| type_check[&x]);
let bbs = get_basic_blocks(&func_body);
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
@@ -195,6 +200,221 @@ fn to_ssa<'a>(
(func_body, bbs, phis, unique_ids)
}
+fn normalize_statements(
+ func: Vec<ast::Statement<spirv::Word>>,
+ unique_ids: spirv::Word,
+) -> (Vec<Statement>, spirv::Word) {
+ let mut result = Vec::with_capacity(func.len());
+ let mut id = unique_ids;
+ let new_id = &mut || {
+ let to_insert = id;
+ id += 1;
+ to_insert
+ };
+ for s in func {
+ match s {
+ ast::Statement::Label(id) => result.push(Statement::Label(id)),
+ ast::Statement::Instruction(pred, inst) => {
+ if let Some(pred) = pred {
+ let mut if_true = new_id();
+ let mut if_false = new_id();
+ if pred.not {
+ std::mem::swap(&mut if_true, &mut if_false);
+ }
+ let folded_bra = match &inst {
+ ast::Instruction::Bra(_, arg) => Some(arg.src),
+ _ => None,
+ };
+ let branch = BrachCondition {
+ predicate: pred.label,
+ if_true: folded_bra.unwrap_or(if_true),
+ if_false,
+ };
+ result.push(Statement::Conditional(branch));
+ if folded_bra.is_none() {
+ result.push(Statement::Label(if_true));
+ let instr = normalize_insert_instruction(&mut result, new_id, inst);
+ result.push(Statement::Instruction(instr));
+ }
+ result.push(Statement::Label(if_false));
+ } else {
+ let instr = normalize_insert_instruction(&mut result, new_id, inst);
+ result.push(Statement::Instruction(instr));
+ }
+ }
+ ast::Statement::Variable(_) => unreachable!(),
+ }
+ }
+ (result, id)
+}
+
+#[must_use]
+fn normalize_insert_instruction(
+ func: &mut Vec<Statement>,
+ new_id: &mut impl FnMut() -> spirv::Word,
+ instr: ast::Instruction<spirv::Word>,
+) -> Instruction {
+ match instr {
+ ast::Instruction::Ld(d, a) => {
+ let arg = normalize_expand_arg2(func, new_id, &|| Some(d.typ), a);
+ Instruction::Ld(d, arg)
+ }
+ ast::Instruction::Mov(d, a) => {
+ let arg = normalize_expand_arg2mov(func, new_id, &|| d.typ.try_as_scalar(), a);
+ Instruction::Mov(d, arg)
+ }
+ ast::Instruction::Mul(d, a) => {
+ let arg = normalize_expand_arg3(func, new_id, &|| d.typ.try_as_scalar(), a);
+ Instruction::Mul(d, arg)
+ }
+ ast::Instruction::Add(d, a) => {
+ let arg = normalize_expand_arg3(func, new_id, &|| Some(d.typ), a);
+ Instruction::Add(d, arg)
+ }
+ ast::Instruction::Setp(d, a) => {
+ let arg = normalize_expand_arg4(func, new_id, &|| Some(d.typ), a);
+ Instruction::Setp(d, arg)
+ }
+ ast::Instruction::SetpBool(d, a) => {
+ let arg = normalize_expand_arg5(func, new_id, &|| Some(d.typ), a);
+ Instruction::SetpBool(d, arg)
+ }
+ ast::Instruction::Not(d, a) => {
+ let arg = normalize_expand_arg2(func, new_id, &|| todo!(), a);
+ Instruction::Not(d, arg)
+ }
+ ast::Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }),
+ ast::Instruction::Cvt(d, a) => {
+ let arg = normalize_expand_arg2(func, new_id, &|| todo!(), a);
+ Instruction::Cvt(d, arg)
+ }
+ ast::Instruction::Shl(d, a) => {
+ let arg = normalize_expand_arg3(func, new_id, &|| todo!(), a);
+ Instruction::Shl(d, arg)
+ }
+ ast::Instruction::St(d, a) => {
+ let arg = normalize_expand_arg2st(func, new_id, &|| todo!(), a);
+ Instruction::St(d, arg)
+ }
+ ast::Instruction::Ret(d) => Instruction::Ret(d),
+ }
+}
+
+fn normalize_expand_arg2(
+ func: &mut Vec<Statement>,
+ new_id: &mut impl FnMut() -> spirv::Word,
+ inst_type: &impl Fn() -> Option<ast::ScalarType>,
+ a: ast::Arg2<spirv::Word>,
+) -> Arg2 {
+ Arg2 {
+ dst: a.dst,
+ src: normalize_expand_operand(func, new_id, inst_type, a.src),
+ }
+}
+
+fn normalize_expand_arg2mov(
+ func: &mut Vec<Statement>,
+ new_id: &mut impl FnMut() -> spirv::Word,
+ inst_type: &impl Fn() -> Option<ast::ScalarType>,
+ a: ast::Arg2Mov<spirv::Word>,
+) -> Arg2 {
+ Arg2 {
+ dst: a.dst,
+ src: normalize_expand_mov_operand(func, new_id, inst_type, a.src),
+ }
+}
+
+fn normalize_expand_arg2st(
+ func: &mut Vec<Statement>,
+ new_id: &mut impl FnMut() -> spirv::Word,
+ inst_type: &impl Fn() -> Option<ast::ScalarType>,
+ a: ast::Arg2St<spirv::Word>,
+) -> Arg2St {
+ Arg2St {
+ src1: normalize_expand_operand(func, new_id, inst_type, a.src1),
+ src2: normalize_expand_operand(func, new_id, inst_type, a.src2),
+ }
+}
+
+fn normalize_expand_arg3(
+ func: &mut Vec<Statement>,
+ new_id: &mut impl FnMut() -> spirv::Word,
+ inst_type: &impl Fn() -> Option<ast::ScalarType>,
+ a: ast::Arg3<spirv::Word>,
+) -> Arg3 {
+ Arg3 {
+ dst: a.dst,
+ src1: normalize_expand_operand(func, new_id, inst_type, a.src1),
+ src2: normalize_expand_operand(func, new_id, inst_type, a.src2),
+ }
+}
+
+fn normalize_expand_arg4(
+ func: &mut Vec<Statement>,
+ new_id: &mut impl FnMut() -> spirv::Word,
+ inst_type: &impl Fn() -> Option<ast::ScalarType>,
+ a: ast::Arg4<spirv::Word>,
+) -> Arg4 {
+ Arg4 {
+ dst1: a.dst1,
+ dst2: a.dst2,
+ src1: normalize_expand_operand(func, new_id, inst_type, a.src1),
+ src2: normalize_expand_operand(func, new_id, inst_type, a.src2),
+ }
+}
+
+fn normalize_expand_arg5(
+ func: &mut Vec<Statement>,
+ new_id: &mut impl FnMut() -> spirv::Word,
+ inst_type: &impl Fn() -> Option<ast::ScalarType>,
+ a: ast::Arg5<spirv::Word>,
+) -> Arg5 {
+ Arg5 {
+ dst1: a.dst1,
+ dst2: a.dst2,
+ src1: normalize_expand_operand(func, new_id, inst_type, a.src1),
+ src2: normalize_expand_operand(func, new_id, inst_type, a.src2),
+ src3: normalize_expand_operand(func, new_id, inst_type, a.src3),
+ }
+}
+
+fn normalize_expand_operand(
+ func: &mut Vec<Statement>,
+ new_id: &mut impl FnMut() -> spirv::Word,
+ inst_type: &impl Fn() -> Option<ast::ScalarType>,
+ opr: ast::Operand<spirv::Word>,
+) -> spirv::Word {
+ match opr {
+ ast::Operand::Reg(r) => r,
+ ast::Operand::Imm(x) => {
+ if let Some(typ) = inst_type() {
+ let id = new_id();
+ func.push(Statement::Constant(ConstantDefinition {
+ dst: id,
+ typ: typ,
+ value: x,
+ }));
+ id
+ } else {
+ todo!()
+ }
+ }
+ _ => todo!(),
+ }
+}
+
+fn normalize_expand_mov_operand(
+ func: &mut Vec<Statement>,
+ new_id: &mut impl FnMut() -> spirv::Word,
+ inst_type: &impl Fn() -> Option<ast::ScalarType>,
+ opr: ast::MovOperand<spirv::Word>,
+) -> spirv::Word {
+ match opr {
+ ast::MovOperand::Op(opr) => normalize_expand_operand(func, new_id, inst_type, opr),
+ _ => todo!(),
+ }
+}
+
fn collect_var_definitions<'a>(
args: &[ast::Argument<'a>],
body: &[ast::Statement<&'a str>],
@@ -249,17 +469,15 @@ fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
for s in normalized_ids.into_iter() {
match s {
Statement::Instruction(inst) => match inst {
- ast::Instruction::Ld(ld, mut arg) => {
- arg.src = arg.src.map_id(&mut |arg_src| {
- insert_implicit_conversions_ld_src(
- &mut result,
- ast::Type::Scalar(ld.typ),
- type_check,
- new_id,
- ld.state_space,
- arg_src,
- )
- });
+ Instruction::Ld(ld, mut arg) => {
+ arg.src = insert_implicit_conversions_ld_src(
+ &mut result,
+ ast::Type::Scalar(ld.typ),
+ type_check,
+ new_id,
+ ld.state_space,
+ arg.src,
+ );
insert_with_implicit_conversion_dst(
&mut result,
ld.typ,
@@ -268,40 +486,35 @@ fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
should_convert_relaxed_dst,
arg,
|arg| &mut arg.dst,
- |arg| ast::Instruction::Ld(ld, arg),
+ |arg| Instruction::Ld(ld, arg),
);
}
- ast::Instruction::St(st, mut arg) => {
- arg.src2 = arg.src2.map_id(&mut |arg_src| {
- let arg_src_type = type_check(arg_src);
- if let Some(conv) = should_convert_relaxed_src(arg_src_type, st.typ) {
- insert_conversion_src(
- &mut result,
- new_id,
- arg_src,
- arg_src_type,
- ast::Type::Scalar(st.typ),
- conv,
- )
- } else {
- arg_src
- }
- });
- arg.src1 = arg.src1.map_id(&mut |arg_src| {
- insert_implicit_conversions_ld_src(
+ Instruction::St(st, mut arg) => {
+ let arg_src2_type = type_check(arg.src2);
+ if let Some(conv) = should_convert_relaxed_src(arg_src2_type, st.typ) {
+ arg.src2 = insert_conversion_src(
&mut result,
- ast::Type::Scalar(st.typ),
- type_check,
new_id,
- st.state_space.to_ld_ss(),
- arg_src,
- )
- });
- result.push(Statement::Instruction(ast::Instruction::St(st, arg)));
+ arg.src2,
+ arg_src2_type,
+ ast::Type::Scalar(st.typ),
+ conv,
+ );
+ }
+ arg.src1 = insert_implicit_conversions_ld_src(
+ &mut result,
+ ast::Type::Scalar(st.typ),
+ type_check,
+ new_id,
+ st.state_space.to_ld_ss(),
+ arg.src1,
+ );
+ result.push(Statement::Instruction(Instruction::St(st, arg)));
}
inst @ _ => insert_implicit_bitcasts(&mut result, type_check, new_id, inst),
},
s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s),
+ Statement::Constant(_) => (),
Statement::Converison(_) => unreachable!(),
}
}
@@ -390,61 +603,52 @@ fn emit_function_body_ops(
// If block starts with a label it has already been emitted,
// all other labels in the block are unused
Statement::Label(_) => (),
+ Statement::Constant(_) => todo!(),
Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?,
Statement::Conditional(bra) => {
builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
}
Statement::Instruction(inst) => match inst {
// SPIR-V does not support marking jumps as guaranteed-converged
- ast::Instruction::Bra(_, arg) => {
+ Instruction::Bra(_, arg) => {
builder.branch(arg.src)?;
}
- ast::Instruction::Ld(data, arg) => {
+ Instruction::Ld(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
todo!()
}
- let src = match arg.src {
- ast::Operand::Reg(id) => id,
- _ => todo!(),
- };
let result_type = map.get_or_add_scalar(builder, data.typ);
match data.state_space {
ast::LdStateSpace::Generic => {
- builder.load(result_type, Some(arg.dst), src, None, [])?;
+ builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
}
ast::LdStateSpace::Param => {
- builder.copy_object(result_type, Some(arg.dst), src)?;
+ builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
_ => todo!(),
}
}
- ast::Instruction::St(data, arg) => {
+ Instruction::St(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak
|| data.vector.is_some()
|| data.state_space != ast::StStateSpace::Generic
{
todo!()
}
- let dst = match arg.src1 {
- ast::Operand::Reg(id) => id,
- _ => todo!(),
- };
- let src = match arg.src2 {
- ast::Operand::Reg(id) => id,
- _ => todo!(),
- };
- builder.store(dst, src, None, &[])?;
+ builder.store(arg.src1, arg.src2, None, &[])?;
}
// SPIR-V does not support ret as guaranteed-converged
- ast::Instruction::Ret(_) => builder.ret()?,
- ast::Instruction::Mov(mov, arg) => {
+ Instruction::Ret(_) => builder.ret()?,
+ Instruction::Mov(mov, arg) => {
let result_type = map.get_or_add(builder, SpirvType::from(mov.typ));
- let src = match arg.src {
- ast::MovOperand::Op(ast::Operand::Reg(id)) => id,
- _ => todo!(),
- };
- builder.copy_object(result_type, Some(arg.dst), src)?;
+ builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
+ Instruction::Mul(mul, arg) => match mul.desc {
+ ast::MulDescriptor::Int(ref ctr) => {
+ emit_mul_int(builder, map, mul.typ, ctr, arg)
+ }
+ ast::MulDescriptor::Float(_) => todo!(),
+ },
_ => todo!(),
},
}
@@ -453,6 +657,17 @@ fn emit_function_body_ops(
Ok(())
}
+fn emit_mul_int(
+ _builder: &mut dr::Builder,
+ _map: &mut TypeWordMap,
+ _typ: ast::Type,
+ _ctr: &ast::MulIntControl,
+ _arg: &Arg3,
+) {
+ //let inst_type = map.get_or_add(builder, SpirvType::from(typ));
+ //builder.i_mul(inst_type, Some(arg.dst), Some(arg.src1), Some(arg.src2));
+}
+
fn emit_implicit_conversion(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -523,12 +738,11 @@ fn normalize_identifiers<'a>(
constant_identifiers: &HashMap<&'a str, spirv::Word>, // arguments and labels can't be redefined
type_map: &mut HashMap<spirv::Word, ast::Type>,
types: HashMap<Cow<'a, str>, ast::Type>,
-) -> (Vec<Statement>, spirv::Word) {
- let mut result = Vec::with_capacity(func.len());
+) -> (Vec<ast::Statement<spirv::Word>>, spirv::Word) {
let mut id: u32 = constant_identifiers.len() as u32;
let mut remapped_ids = HashMap::new();
- let mut get_or_add = |key| match key {
- Some(key) => constant_identifiers.get(key).map_or_else(
+ let mut get_or_add = |key| {
+ constant_identifiers.get(key).map_or_else(
|| {
*remapped_ids.entry(key).or_insert_with(|| {
let to_insert = id;
@@ -537,16 +751,12 @@ fn normalize_identifiers<'a>(
})
},
|id| *id,
- ),
- None => {
- let to_insert = id;
- id += 1;
- to_insert
- }
+ )
};
- for s in func {
- Statement::from_ast(s, &mut result, &mut get_or_add);
- }
+ let result = func
+ .into_iter()
+ .filter_map(|s| Statement::from_ast(s, &mut get_or_add))
+ .collect::<Vec<_>>();
type_map.extend(
remapped_ids
.into_iter()
@@ -594,7 +804,7 @@ fn apply_ssa_renaming(
for s in get_bb_body(func, bbs, BBIndex(bb)) {
s.visit_id(&mut |is_dst, id| {
if is_dst {
- old_dst_id[bb].push(*id)
+ old_dst_id[bb].push(id)
}
});
}
@@ -787,8 +997,8 @@ fn gather_phi_sets(
let mut blocks = vec![(Vec::new(), HashSet::new()); (all_ids - constant_ids) as usize];
for bb in 0..cfg.len() {
let mut var_kill = HashSet::new();
- let mut visitor = |is_dst, id: &u32| {
- if *id >= constant_ids {
+ let mut visitor = |is_dst, id: spirv::Word| {
+ if id >= constant_ids {
let id = id - constant_ids;
if is_dst {
var_kill.insert(id);
@@ -807,8 +1017,9 @@ fn gather_phi_sets(
for s in get_bb_body(func, cfg, BBIndex(bb)) {
match s {
Statement::Instruction(inst) => inst.visit_id(&mut visitor),
- Statement::Conditional(brc) => visitor(false, &brc.predicate),
+ Statement::Conditional(brc) => visitor(false, brc.predicate),
Statement::Converison(conv) => conv.visit_id(&mut visitor),
+ Statement::Constant(cons) => cons.visit_id(&mut visitor),
// label redefinition is a compile-time error
Statement::Label(_) => (),
}
@@ -859,6 +1070,7 @@ fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
unresolved_bb_edge.push((StmtIndex(idx), bra.if_false));
unresolved_bb_edge.push((StmtIndex(idx), bra.if_true));
}
+ Statement::Constant(_) => (),
Statement::Converison(_) => (),
};
}
@@ -877,7 +1089,7 @@ fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
bb_edge.insert((StmtIndex(target.0 - 1), target));
}
}
- Statement::Converison(_) | Statement::Label(_) => {
+ Statement::Converison(_) | Statement::Constant(_) | Statement::Label(_) => {
bb_edge.insert((StmtIndex(target.0 - 1), target));
}
// This is already in `unresolved_bb_edge`
@@ -1043,10 +1255,241 @@ impl fmt::Display for BBIndex {
enum Statement {
Label(u32),
- Instruction(ast::Instruction<spirv::Word>),
+ Instruction(Instruction),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
Converison(ImplicitConversion),
+ Constant(ConstantDefinition),
+}
+
+enum Instruction {
+ Ld(ast::LdData, Arg2),
+ Mov(ast::MovData, Arg2),
+ Mul(ast::MulData, Arg3),
+ Add(ast::AddData, Arg3),
+ Setp(ast::SetpData, Arg4),
+ SetpBool(ast::SetpBoolData, Arg5),
+ Not(ast::NotData, Arg2),
+ Bra(ast::BraData, Arg1),
+ Cvt(ast::CvtData, Arg2),
+ Shl(ast::ShlData, Arg3),
+ St(ast::StData, Arg2St),
+ Ret(ast::RetData),
+}
+
+impl Instruction {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ match self {
+ Instruction::Ld(_, a) => a.visit_id(f),
+ Instruction::Mov(_, a) => a.visit_id(f),
+ Instruction::Mul(_, a) => a.visit_id(f),
+ Instruction::Add(_, a) => a.visit_id(f),
+ Instruction::Setp(_, a) => a.visit_id(f),
+ Instruction::SetpBool(_, a) => a.visit_id(f),
+ Instruction::Not(_, a) => a.visit_id(f),
+ Instruction::Cvt(_, a) => a.visit_id(f),
+ Instruction::Shl(_, a) => a.visit_id(f),
+ Instruction::St(_, a) => a.visit_id(f),
+ Instruction::Bra(_, a) => a.visit_id(f),
+ Instruction::Ret(_) => (),
+ }
+ }
+
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
+ match self {
+ Instruction::Ld(_, a) => a.visit_id_mut(f),
+ Instruction::Mov(_, a) => a.visit_id_mut(f),
+ Instruction::Mul(_, a) => a.visit_id_mut(f),
+ Instruction::Add(_, a) => a.visit_id_mut(f),
+ Instruction::Setp(_, a) => a.visit_id_mut(f),
+ Instruction::SetpBool(_, a) => a.visit_id_mut(f),
+ Instruction::Not(_, a) => a.visit_id_mut(f),
+ Instruction::Cvt(_, a) => a.visit_id_mut(f),
+ Instruction::Shl(_, a) => a.visit_id_mut(f),
+ Instruction::St(_, a) => a.visit_id_mut(f),
+ Instruction::Bra(_, a) => a.visit_id_mut(f),
+ Instruction::Ret(_) => (),
+ }
+ }
+
+ fn get_type(&self) -> Option<ast::Type> {
+ match self {
+ Instruction::Add(add, _) => Some(ast::Type::Scalar(add.typ)),
+ Instruction::Ret(_) => None,
+ Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)),
+ Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)),
+ Instruction::Mov(mov, _) => Some(mov.typ),
+ Instruction::Mul(mul, _) => Some(mul.typ),
+ _ => todo!(),
+ }
+ }
+
+ fn jump_target(&self) -> Option<spirv::Word> {
+ match self {
+ Instruction::Bra(_, a) => Some(a.src),
+ Instruction::Ld(_, _)
+ | Instruction::Mov(_, _)
+ | Instruction::Mul(_, _)
+ | Instruction::Add(_, _)
+ | Instruction::Setp(_, _)
+ | Instruction::SetpBool(_, _)
+ | Instruction::Not(_, _)
+ | Instruction::Cvt(_, _)
+ | Instruction::Shl(_, _)
+ | Instruction::St(_, _)
+ | Instruction::Ret(_) => None,
+ }
+ }
+
+ fn is_terminal(&self) -> bool {
+ match self {
+ Instruction::Ret(_) => true,
+ Instruction::Ld(_, _)
+ | Instruction::Mov(_, _)
+ | Instruction::Mul(_, _)
+ | Instruction::Add(_, _)
+ | Instruction::Setp(_, _)
+ | Instruction::SetpBool(_, _)
+ | Instruction::Not(_, _)
+ | Instruction::Cvt(_, _)
+ | Instruction::Shl(_, _)
+ | Instruction::St(_, _)
+ | Instruction::Bra(_, _) => false,
+ }
+ }
+}
+
+struct Arg1 {
+ pub src: spirv::Word,
+}
+
+impl Arg1 {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(false, self.src);
+ }
+
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
+ f(false, &mut self.src);
+ }
+}
+
+struct Arg2 {
+ pub dst: spirv::Word,
+ pub src: spirv::Word,
+}
+
+impl Arg2 {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(true, self.dst);
+ f(false, self.src);
+ }
+
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
+ f(false, &mut self.src);
+ f(true, &mut self.dst);
+ }
+}
+
+pub struct Arg2St {
+ pub src1: spirv::Word,
+ pub src2: spirv::Word,
+}
+
+impl Arg2St {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(false, self.src1);
+ f(false, self.src2);
+ }
+
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
+ f(false, &mut self.src1);
+ f(false, &mut self.src2);
+ }
+}
+
+struct Arg3 {
+ pub dst: spirv::Word,
+ pub src1: spirv::Word,
+ pub src2: spirv::Word,
+}
+
+impl Arg3 {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(true, self.dst);
+ f(false, self.src1);
+ f(false, self.src2);
+ }
+
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
+ f(false, &mut self.src1);
+ f(false, &mut self.src2);
+ f(true, &mut self.dst);
+ }
+}
+
+struct Arg4 {
+ pub dst1: spirv::Word,
+ pub dst2: Option<spirv::Word>,
+ pub src1: spirv::Word,
+ pub src2: spirv::Word,
+}
+
+impl Arg4 {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(true, self.dst1);
+ self.dst2.map(|dst2| f(true, dst2));
+ f(false, self.src1);
+ f(false, self.src2);
+ }
+
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
+ f(false, &mut self.src1);
+ f(false, &mut self.src2);
+ f(true, &mut self.dst1);
+ self.dst2.as_mut().map(|dst2| f(true, dst2));
+ }
+}
+
+struct Arg5 {
+ pub dst1: spirv::Word,
+ pub dst2: Option<spirv::Word>,
+ pub src1: spirv::Word,
+ pub src2: spirv::Word,
+ pub src3: spirv::Word,
+}
+
+impl Arg5 {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(true, self.dst1);
+ self.dst2.map(|dst2| f(true, dst2));
+ f(false, self.src1);
+ f(false, self.src2);
+ f(false, self.src3);
+ }
+
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
+ f(false, &mut self.src1);
+ f(false, &mut self.src2);
+ f(false, &mut self.src3);
+ f(true, &mut self.dst1);
+ self.dst2.as_mut().map(|dst2| f(true, dst2));
+ }
+}
+
+struct ConstantDefinition {
+ pub dst: spirv::Word,
+ pub typ: ast::ScalarType,
+ pub value: i128,
+}
+
+impl ConstantDefinition {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(true, self.dst);
+ }
+
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
+ f(true, &mut self.dst);
+ }
}
struct BrachCondition {
@@ -1056,10 +1499,10 @@ struct BrachCondition {
}
impl BrachCondition {
- fn visit_id<F: FnMut(bool, &spirv::Word)>(&self, f: &mut F) {
- f(false, &self.predicate);
- f(false, &self.if_true);
- f(false, &self.if_false);
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(false, self.predicate);
+ f(false, self.if_true);
+ f(false, self.if_false);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
@@ -1086,9 +1529,9 @@ enum ConversionKind {
}
impl ImplicitConversion {
- fn visit_id<F: FnMut(bool, &spirv::Word)>(&self, f: &mut F) {
- f(false, &self.src);
- f(true, &self.dst);
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(false, self.src);
+ f(true, self.dst);
}
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
@@ -1098,54 +1541,27 @@ impl ImplicitConversion {
}
impl Statement {
- fn from_ast<'a, F: FnMut(Option<&'a str>) -> u32>(
+ fn from_ast<'a, F: FnMut(&'a str) -> u32>(
s: ast::Statement<&'a str>,
- out: &mut Vec<Statement>,
get_id: &mut F,
- ) {
+ ) -> Option<ast::Statement<spirv::Word>> {
match s {
- ast::Statement::Label(name) => out.push(Statement::Label(get_id(Some(name)))),
- ast::Statement::Instruction(p, i) => {
- if let Some(pred) = p {
- let predicate = get_id(Some(pred.label));
- let mut if_true = get_id(None);
- let mut if_false = get_id(None);
- if pred.not {
- std::mem::swap(&mut if_true, &mut if_false);
- }
- let folded_bra = match &i {
- ast::Instruction::Bra(_, arg) => Some(get_id(Some(arg.src))),
- _ => None,
- };
- let branch = BrachCondition {
- predicate,
- if_true: folded_bra.unwrap_or(if_true),
- if_false,
- };
- out.push(Statement::Conditional(branch));
- if folded_bra.is_none() {
- out.push(Statement::Label(if_true));
- out.push(Statement::Instruction(
- i.map_id(&mut |name| get_id(Some(name))),
- ));
- }
- out.push(Statement::Label(if_false));
- } else {
- out.push(Statement::Instruction(
- i.map_id(&mut |name| get_id(Some(name))),
- ));
- }
- }
- ast::Statement::Variable(_) => (),
+ ast::Statement::Label(name) => Some(ast::Statement::Label(get_id(name))),
+ ast::Statement::Instruction(p, i) => Some(ast::Statement::Instruction(
+ p.map(|p| p.map_id(get_id)),
+ i.map_id(get_id),
+ )),
+ ast::Statement::Variable(_) => None,
}
}
- fn visit_id<F: FnMut(bool, &spirv::Word)>(&self, f: &mut F) {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
match self {
- Statement::Label(id) => f(false, id),
+ Statement::Label(id) => f(false, *id),
Statement::Instruction(inst) => inst.visit_id(f),
Statement::Conditional(bra) => bra.visit_id(f),
Statement::Converison(conv) => conv.visit_id(f),
+ Statement::Constant(cons) => cons.visit_id(f),
}
}
@@ -1157,6 +1573,16 @@ impl Statement {
Statement::Instruction(inst) => inst.visit_id_mut(f),
Statement::Conditional(bra) => bra.visit_id_mut(f),
Statement::Converison(conv) => conv.visit_id_mut(f),
+ Statement::Constant(cons) => cons.visit_id_mut(f),
+ }
+ }
+}
+
+impl<T> ast::PredAt<T> {
+ fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::PredAt<U> {
+ ast::PredAt {
+ not: self.not,
+ label: f(self.label),
}
}
}
@@ -1220,7 +1646,8 @@ impl<T> ast::Instruction<T> {
ast::Instruction::Ld(ld, _) => Some(ast::Type::Scalar(ld.typ)),
ast::Instruction::St(st, _) => Some(ast::Type::Scalar(st.typ)),
ast::Instruction::Mov(mov, _) => Some(mov.typ),
- _ => todo!()
+ ast::Instruction::Mul(mul, _) => Some(mul.typ),
+ _ => todo!(),
}
}
}
@@ -1476,6 +1903,15 @@ enum ScalarKind {
Float,
}
+impl ast::Type {
+ fn try_as_scalar(self) -> Option<ast::ScalarType> {
+ match self {
+ ast::Type::Scalar(s) => Some(s),
+ ast::Type::ExtendedScalar(_) => None,
+ }
+ }
+}
+
impl ast::ScalarType {
fn width(self) -> u8 {
match self {
@@ -1688,7 +2124,7 @@ fn insert_with_implicit_conversion_dst<
NewId: FnMut() -> spirv::Word,
ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
Setter: Fn(&mut T) -> &mut spirv::Word,
- ToInstruction: FnOnce(T) -> ast::Instruction<spirv::Word>,
+ ToInstruction: FnOnce(T) -> Instruction,
>(
func: &mut Vec<Statement>,
instr_type: ast::ScalarType,
@@ -1821,7 +2257,7 @@ fn insert_implicit_bitcasts<
func: &mut Vec<Statement>,
type_check: &TypeCheck,
new_id: &mut NewId,
- mut instr: ast::Instruction<spirv::Word>,
+ mut instr: Instruction,
) {
let mut dst_coercion = None;
if let Some(instr_type) = instr.get_type() {
@@ -1984,9 +2420,9 @@ mod tests {
fn get_basic_blocks_miniloop() {
let func = vec![
Statement::Label(12),
- Statement::Instruction(ast::Instruction::Bra(
+ Statement::Instruction(Instruction::Bra(
ast::BraData { uniform: false },
- ast::Arg1 { src: 12 },
+ Arg1 { src: 12 },
)),
];
let bbs = get_basic_blocks(&func);
@@ -2226,9 +2662,10 @@ mod tests {
let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &ast);
let registers = collect_var_definitions(&[], &ast);
- let (normalized_ids, _) =
+ let (normalized_ids, unique_ids) =
normalize_identifiers(ast, &constant_ids, &mut HashMap::new(), registers);
- let mut bbs = get_basic_blocks(&normalized_ids);
+ let (normalized_stmts, _) = normalize_statements(normalized_ids, unique_ids);
+ let mut bbs = get_basic_blocks(&normalized_stmts);
bbs.iter_mut().for_each(sort_pred_succ);
assert_eq!(
bbs,
@@ -2239,32 +2676,32 @@ mod tests {
succ: vec![BBIndex(1)],
},
BasicBlock {
- start: StmtIndex(3),
+ start: StmtIndex(6),
pred: vec![BBIndex(0), BBIndex(5)],
succ: vec![BBIndex(2), BBIndex(6)],
},
BasicBlock {
- start: StmtIndex(6),
+ start: StmtIndex(10),
pred: vec![BBIndex(1)],
succ: vec![BBIndex(3), BBIndex(4)],
},
BasicBlock {
- start: StmtIndex(9),
+ start: StmtIndex(14),
pred: vec![BBIndex(2)],
succ: vec![BBIndex(5)],
},
BasicBlock {
- start: StmtIndex(13),
+ start: StmtIndex(19),
pred: vec![BBIndex(2)],
succ: vec![BBIndex(5)],
},
BasicBlock {
- start: StmtIndex(16),
+ start: StmtIndex(23),
pred: vec![BBIndex(3), BBIndex(4)],
succ: vec![BBIndex(1)],
},
BasicBlock {
- start: StmtIndex(18),
+ start: StmtIndex(25),
pred: vec![BBIndex(1)],
succ: vec![],
},
@@ -2375,14 +2812,15 @@ mod tests {
collect_label_ids(&mut constant_ids, &fn_ast);
assert_eq!(constant_ids.len(), 4);
let registers = collect_var_definitions(&[], &fn_ast);
- let (normalized_ids, max_id) =
+ let (normalized_ids, unique_ids) =
normalize_identifiers(fn_ast, &constant_ids, &mut HashMap::new(), registers);
- let bbs = get_basic_blocks(&normalized_ids);
+ let (normalized_stmts, max_id) = normalize_statements(normalized_ids, unique_ids);
+ let bbs = get_basic_blocks(&normalized_stmts);
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
let dom_fronts = dominance_frontiers(&bbs, &doms);
let phi = gather_phi_sets(
- &normalized_ids,
+ &normalized_stmts,
constant_ids.len() as u32,
max_id,
&bbs,
@@ -2490,7 +2928,7 @@ mod tests {
for s in func {
s.visit_id(&mut |is_dst, id| {
if is_dst {
- assert!(seen.insert(*id));
+ assert!(seen.insert(id));
}
});
}
@@ -2504,7 +2942,7 @@ mod tests {
fn get_ids(s: &Statement) -> Vec<spirv::Word> {
let mut result = Vec::new();
s.visit_id(&mut |_, id| {
- result.push(*id);
+ result.push(id);
});
result
}
@@ -2533,7 +2971,7 @@ mod tests {
let mut result = None;
s.visit_id(&mut |is_dst, id| {
if is_dst {
- assert_eq!(result.replace(*id), None);
+ assert_eq!(result.replace(id), None);
}
});
result.unwrap()