aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-07-28 02:44:24 +0200
committerAndrzej Janik <[email protected]>2020-07-28 02:44:24 +0200
commit52faaab547afa1e6010d47445a3b50303291ef9c (patch)
tree1bbe3440b730ad41683e8c3e520d9a5feded76f0 /ptx
parentd514a5610a2202dbe64ce48961f93a9b0111a57b (diff)
downloadZLUDA-52faaab547afa1e6010d47445a3b50303291ef9c.tar.gz
ZLUDA-52faaab547afa1e6010d47445a3b50303291ef9c.zip
Remove the need for custom Arg types in middle-end
Diffstat (limited to 'ptx')
-rw-r--r--ptx/src/ast.rs108
-rw-r--r--ptx/src/ptx.lalrpop42
-rw-r--r--ptx/src/translate.rs323
3 files changed, 244 insertions, 229 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 6bb099a..9fab216 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -1,5 +1,5 @@
use std::convert::From;
-use std::num::ParseIntError;
+use std::{marker::PhantomData, num::ParseIntError};
quick_error! {
#[derive(Debug)]
@@ -52,7 +52,7 @@ pub struct Function<'a> {
pub kernel: bool,
pub name: &'a str,
pub args: Vec<Argument<'a>>,
- pub body: Vec<Statement<&'a str>>,
+ pub body: Vec<Statement<ParsedArgParams<'a>>>,
}
#[derive(Default)]
@@ -141,16 +141,16 @@ impl Default for ScalarType {
}
}
-pub enum Statement<ID> {
- Label(ID),
- Variable(Variable<ID>),
- Instruction(Option<PredAt<ID>>, Instruction<ID>),
+pub enum Statement<P: ArgParams> {
+ Label(P::ID),
+ Variable(Variable<P>),
+ Instruction(Option<PredAt<P::ID>>, Instruction<P>),
}
-pub struct Variable<ID> {
+pub struct Variable<P: ArgParams> {
pub space: StateSpace,
pub v_type: Type,
- pub name: ID,
+ pub name: P::ID,
pub count: Option<u32>,
}
@@ -169,59 +169,75 @@ pub struct PredAt<ID> {
pub label: ID,
}
-pub enum Instruction<ID> {
- Ld(LdData, Arg2<ID>),
- Mov(MovData, Arg2Mov<ID>),
- Mul(MulDetails, Arg3<ID>),
- Add(AddDetails, Arg3<ID>),
- Setp(SetpData, Arg4<ID>),
- SetpBool(SetpBoolData, Arg5<ID>),
- Not(NotData, Arg2<ID>),
- Bra(BraData, Arg1<ID>),
- Cvt(CvtData, Arg2<ID>),
- Shl(ShlData, Arg3<ID>),
- St(StData, Arg2St<ID>),
+pub enum Instruction<P: ArgParams> {
+ Ld(LdData, Arg2<P>),
+ Mov(MovData, Arg2Mov<P>),
+ Mul(MulDetails, Arg3<P>),
+ Add(AddDetails, Arg3<P>),
+ Setp(SetpData, Arg4<P>),
+ SetpBool(SetpBoolData, Arg5<P>),
+ Not(NotData, Arg2<P>),
+ Bra(BraData, Arg1<P>),
+ Cvt(CvtData, Arg2<P>),
+ Shl(ShlData, Arg3<P>),
+ St(StData, Arg2St<P>),
Ret(RetData),
}
-pub struct Arg1<ID> {
- pub src: ID, // it is a jump destination, but in terms of operands it is a source operand
+pub trait ArgParams {
+ type ID;
+ type Operand;
+ type MovOperand;
}
-pub struct Arg2<ID> {
- pub dst: ID,
- pub src: Operand<ID>,
+pub struct ParsedArgParams<'a> {
+ _marker: PhantomData<&'a ()>,
}
-pub struct Arg2St<ID> {
- pub src1: Operand<ID>,
- pub src2: Operand<ID>,
+impl<'a> ArgParams for ParsedArgParams<'a> {
+ type ID = &'a str;
+ type Operand = Operand<&'a str>;
+ type MovOperand = MovOperand<&'a str>;
}
-pub struct Arg2Mov<ID> {
- pub dst: ID,
- pub src: MovOperand<ID>,
+pub struct Arg1<P: ArgParams> {
+ pub src: P::ID, // it is a jump destination, but in terms of operands it is a source operand
}
-pub struct Arg3<ID> {
- pub dst: ID,
- pub src1: Operand<ID>,
- pub src2: Operand<ID>,
+pub struct Arg2<P: ArgParams> {
+ pub dst: P::ID,
+ pub src: P::Operand,
}
-pub struct Arg4<ID> {
- pub dst1: ID,
- pub dst2: Option<ID>,
- pub src1: Operand<ID>,
- pub src2: Operand<ID>,
+pub struct Arg2St<P: ArgParams> {
+ pub src1: P::Operand,
+ pub src2: P::Operand,
}
-pub struct Arg5<ID> {
- pub dst1: ID,
- pub dst2: Option<ID>,
- pub src1: Operand<ID>,
- pub src2: Operand<ID>,
- pub src3: Operand<ID>,
+pub struct Arg2Mov<P: ArgParams> {
+ pub dst: P::ID,
+ pub src: P::MovOperand,
+}
+
+pub struct Arg3<P: ArgParams> {
+ pub dst: P::ID,
+ pub src1: P::Operand,
+ pub src2: P::Operand,
+}
+
+pub struct Arg4<P: ArgParams> {
+ pub dst1: P::ID,
+ pub dst2: Option<P::ID>,
+ pub src1: P::Operand,
+ pub src2: P::Operand,
+}
+
+pub struct Arg5<P: ArgParams> {
+ pub dst1: P::ID,
+ pub dst2: Option<P::ID>,
+ pub src1: P::Operand,
+ pub src2: P::Operand,
+ pub src3: P::Operand,
}
pub enum Operand<ID> {
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index cc58cf2..af26765 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -223,7 +223,7 @@ FunctionInput: ast::Argument<'input> = {
}
};
-pub(crate) FunctionBody: Vec<ast::Statement<&'input str>> = {
+pub(crate) FunctionBody: Vec<ast::Statement<ast::ParsedArgParams<'input>>> = {
"{" <s:Statement*> "}" => { without_none(s) }
};
@@ -269,7 +269,7 @@ MemoryType: ast::ScalarType = {
".f64" => ast::ScalarType::F64,
};
-Statement: Option<ast::Statement<&'input str>> = {
+Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = {
<l:Label> => Some(ast::Statement::Label(l)),
DebugDirective => None,
<v:Variable> ";" => Some(ast::Statement::Variable(v)),
@@ -289,7 +289,7 @@ Label: &'input str = {
<id:ExtendedID> ":" => id
};
-Variable: ast::Variable<&'input str> = {
+Variable: ast::Variable<ast::ParsedArgParams<'input>> = {
<s:StateSpaceSpecifier> <t:Type> <v:VariableName> => {
let (name, count) = v;
ast::Variable { space: s, v_type: t, name: name, count: count }
@@ -310,7 +310,7 @@ VariableName: (&'input str, Option<u32>) = {
}
};
-Instruction: ast::Instruction<&'input str> = {
+Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstLd,
InstMov,
InstMul,
@@ -325,7 +325,7 @@ Instruction: ast::Instruction<&'input str> = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
-InstLd: ast::Instruction<&'input str> = {
+InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
"ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <v:VectorPrefix?> <t:MemoryType> <dst:ExtendedID> "," "[" <src:Operand> "]" => {
ast::Instruction::Ld(
ast::LdData {
@@ -370,7 +370,7 @@ LdCacheOperator: ast::LdCacheOperator = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
-InstMov: ast::Instruction<&'input str> = {
+InstMov: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mov" <t:MovType> <a:Arg2Mov> => {
ast::Instruction::Mov(ast::MovData{ typ:t }, a)
}
@@ -394,7 +394,7 @@ MovType: ast::Type = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
-InstMul: ast::Instruction<&'input str> = {
+InstMul: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mul" <d:InstMulMode> <a:Arg3> => ast::Instruction::Mul(d, a)
};
@@ -455,7 +455,7 @@ IntType : ast::IntType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add
-InstAdd: ast::Instruction<&'input str> = {
+InstAdd: ast::Instruction<ast::ParsedArgParams<'input>> = {
"add" <d:InstAddMode> <a:Arg3> => ast::Instruction::Add(d, a)
};
@@ -492,7 +492,7 @@ InstAddMode: ast::AddDetails = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp
// TODO: support f16 setp
-InstSetp: ast::Instruction<&'input str> = {
+InstSetp: ast::Instruction<ast::ParsedArgParams<'input>> = {
"setp" <d:SetpMode> <a:Arg4> => ast::Instruction::Setp(d, a),
"setp" <d:SetpBoolMode> <a:Arg5> => ast::Instruction::SetpBool(d, a),
};
@@ -556,7 +556,7 @@ SetpType: ast::ScalarType = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not
-InstNot: ast::Instruction<&'input str> = {
+InstNot: ast::Instruction<ast::ParsedArgParams<'input>> = {
"not" NotType <a:Arg2> => ast::Instruction::Not(ast::NotData{}, a)
};
@@ -571,12 +571,12 @@ PredAt: ast::PredAt<&'input str> = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra
-InstBra: ast::Instruction<&'input str> = {
+InstBra: ast::Instruction<ast::ParsedArgParams<'input>> = {
"bra" <u:".uni"?> <a:Arg1> => ast::Instruction::Bra(ast::BraData{ uniform: u.is_some() }, a)
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
-InstCvt: ast::Instruction<&'input str> = {
+InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"cvt" CvtRnd? ".ftz"? ".sat"? CvtType CvtType <a:Arg2> => {
ast::Instruction::Cvt(ast::CvtData{}, a)
}
@@ -602,7 +602,7 @@ CvtType = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl
-InstShl: ast::Instruction<&'input str> = {
+InstShl: ast::Instruction<ast::ParsedArgParams<'input>> = {
"shl" ShlType <a:Arg3> => ast::Instruction::Shl(ast::ShlData{}, a)
};
@@ -612,7 +612,7 @@ ShlType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
-InstSt: ast::Instruction<&'input str> = {
+InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <v:VectorPrefix?> <t:MemoryType> "[" <src1:Operand> "]" "," <src2:Operand> => {
ast::Instruction::St(
ast::StData {
@@ -642,7 +642,7 @@ StCacheOperator: ast::StCacheOperator = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
-InstRet: ast::Instruction<&'input str> = {
+InstRet: ast::Instruction<ast::ParsedArgParams<'input>> = {
"ret" <u:".uni"?> => ast::Instruction::Ret(ast::RetData { uniform: u.is_some() })
};
@@ -675,28 +675,28 @@ VectorOperand: (&'input str, &'input str) = {
<pref:ExtendedID> <suf:DotID> => (pref, &suf[1..]),
};
-Arg1: ast::Arg1<&'input str> = {
+Arg1: ast::Arg1<ast::ParsedArgParams<'input>> = {
<src:ExtendedID> => ast::Arg1{<>}
};
-Arg2: ast::Arg2<&'input str> = {
+Arg2: ast::Arg2<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src:Operand> => ast::Arg2{<>}
};
-Arg2Mov: ast::Arg2Mov<&'input str> = {
+Arg2Mov: ast::Arg2Mov<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src:MovOperand> => ast::Arg2Mov{<>}
};
-Arg3: ast::Arg3<&'input str> = {
+Arg3: ast::Arg3<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> => ast::Arg3{<>}
};
-Arg4: ast::Arg4<&'input str> = {
+Arg4: ast::Arg4<ast::ParsedArgParams<'input>> = {
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> => ast::Arg4{<>}
};
// TODO: pass src3 negation somewhere
-Arg5: ast::Arg5<&'input str> = {
+Arg5: ast::Arg5<ast::ParsedArgParams<'input>> = {
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> "," "!"? <src3:Operand> => ast::Arg5{<>}
};
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 3486edd..ebcb090 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -168,9 +168,9 @@ fn apply_id_offset(func_body: &mut Vec<ExpandedStatement>, id_offset: u32) {
}
}
-fn to_ssa<'a>(
- f_args: &[ast::Argument],
- f_body: Vec<ast::Statement<&'a str>>,
+fn to_ssa<'a, 'b>(
+ f_args: &'b [ast::Argument<'a>],
+ f_body: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
) -> (Vec<ExpandedStatement>, spirv::Word) {
let (normalized_ids, mut id_def) = normalize_identifiers(&f_args, f_body);
let normalized_statements = normalize_predicates(normalized_ids, &mut id_def);
@@ -214,7 +214,7 @@ fn normalize_labels(
}
fn normalize_predicates(
- func: Vec<ast::Statement<spirv::Word>>,
+ func: Vec<ast::Statement<NormalizedArgParams>>,
id_def: &mut NumericIdResolver,
) -> Vec<NormalizedStatement> {
let mut result = Vec::with_capacity(func.len());
@@ -343,51 +343,51 @@ fn expand_arguments(
fn normalize_insert_instruction(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
- instr: ast::Instruction<spirv::Word>,
-) -> Instruction {
+ instr: ast::Instruction<NormalizedArgParams>,
+) -> ast::Instruction<ExpandedArgParams> {
match instr {
ast::Instruction::Ld(d, a) => {
let arg = normalize_expand_arg2(func, id_def, &|| Some(d.typ), a);
- Instruction::Ld(d, arg)
+ ast::Instruction::Ld(d, arg)
}
ast::Instruction::Mov(d, a) => {
let arg = normalize_expand_arg2mov(func, id_def, &|| d.typ.try_as_scalar(), a);
- Instruction::Mov(d, arg)
+ ast::Instruction::Mov(d, arg)
}
ast::Instruction::Mul(d, a) => {
let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a);
- Instruction::Mul(d, arg)
+ ast::Instruction::Mul(d, arg)
}
ast::Instruction::Add(d, a) => {
let arg = normalize_expand_arg3(func, id_def, &|| d.get_type().try_as_scalar(), a);
- Instruction::Add(d, arg)
+ ast::Instruction::Add(d, arg)
}
ast::Instruction::Setp(d, a) => {
let arg = normalize_expand_arg4(func, id_def, &|| Some(d.typ), a);
- Instruction::Setp(d, arg)
+ ast::Instruction::Setp(d, arg)
}
ast::Instruction::SetpBool(d, a) => {
let arg = normalize_expand_arg5(func, id_def, &|| Some(d.typ), a);
- Instruction::SetpBool(d, arg)
+ ast::Instruction::SetpBool(d, arg)
}
ast::Instruction::Not(d, a) => {
let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a);
- Instruction::Not(d, arg)
+ ast::Instruction::Not(d, arg)
}
- ast::Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }),
+ ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, ast::Arg1 { src: a.src }),
ast::Instruction::Cvt(d, a) => {
let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a);
- Instruction::Cvt(d, arg)
+ ast::Instruction::Cvt(d, arg)
}
ast::Instruction::Shl(d, a) => {
let arg = normalize_expand_arg3(func, id_def, &|| todo!(), a);
- Instruction::Shl(d, arg)
+ ast::Instruction::Shl(d, arg)
}
ast::Instruction::St(d, a) => {
let arg = normalize_expand_arg2st(func, id_def, &|| todo!(), a);
- Instruction::St(d, arg)
+ ast::Instruction::St(d, arg)
}
- ast::Instruction::Ret(d) => Instruction::Ret(d),
+ ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
}
}
@@ -395,9 +395,9 @@ fn normalize_expand_arg2(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
- a: ast::Arg2<spirv::Word>,
-) -> Arg2 {
- Arg2 {
+ a: ast::Arg2<NormalizedArgParams>,
+) -> ast::Arg2<ExpandedArgParams> {
+ ast::Arg2 {
dst: a.dst,
src: normalize_expand_operand(func, id_def, inst_type, a.src),
}
@@ -407,9 +407,9 @@ fn normalize_expand_arg2mov(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
- a: ast::Arg2Mov<spirv::Word>,
-) -> Arg2 {
- Arg2 {
+ a: ast::Arg2Mov<NormalizedArgParams>,
+) -> ast::Arg2Mov<ExpandedArgParams> {
+ ast::Arg2Mov {
dst: a.dst,
src: normalize_expand_mov_operand(func, id_def, inst_type, a.src),
}
@@ -419,9 +419,9 @@ fn normalize_expand_arg2st(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
- a: ast::Arg2St<spirv::Word>,
-) -> Arg2St {
- Arg2St {
+ a: ast::Arg2St<NormalizedArgParams>,
+) -> ast::Arg2St<ExpandedArgParams> {
+ ast::Arg2St {
src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
src2: normalize_expand_operand(func, id_def, inst_type, a.src2),
}
@@ -431,9 +431,9 @@ fn normalize_expand_arg3(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
- a: ast::Arg3<spirv::Word>,
-) -> Arg3 {
- Arg3 {
+ a: ast::Arg3<NormalizedArgParams>,
+) -> ast::Arg3<ExpandedArgParams> {
+ ast::Arg3 {
dst: a.dst,
src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
src2: normalize_expand_operand(func, id_def, inst_type, a.src2),
@@ -444,9 +444,9 @@ fn normalize_expand_arg4(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
- a: ast::Arg4<spirv::Word>,
-) -> Arg4 {
- Arg4 {
+ a: ast::Arg4<NormalizedArgParams>,
+) -> ast::Arg4<ExpandedArgParams> {
+ ast::Arg4 {
dst1: a.dst1,
dst2: a.dst2,
src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
@@ -458,9 +458,9 @@ fn normalize_expand_arg5(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
- a: ast::Arg5<spirv::Word>,
-) -> Arg5 {
- Arg5 {
+ a: ast::Arg5<NormalizedArgParams>,
+) -> ast::Arg5<ExpandedArgParams> {
+ ast::Arg5 {
dst1: a.dst1,
dst2: a.dst2,
src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
@@ -527,7 +527,7 @@ fn insert_implicit_conversions(
for s in func.into_iter() {
match s {
Statement::Instruction(inst) => match inst {
- Instruction::Ld(ld, mut arg) => {
+ ast::Instruction::Ld(ld, mut arg) => {
arg.src = insert_implicit_conversions_ld_src(
&mut result,
ast::Type::Scalar(ld.typ),
@@ -542,10 +542,10 @@ fn insert_implicit_conversions(
should_convert_relaxed_dst,
arg,
|arg| &mut arg.dst,
- |arg| Instruction::Ld(ld, arg),
+ |arg| ast::Instruction::Ld(ld, arg),
);
}
- Instruction::St(st, mut arg) => {
+ ast::Instruction::St(st, mut arg) => {
let arg_src2_type = id_def.get_type(arg.src2);
if let Some(conv) = should_convert_relaxed_src(arg_src2_type, st.typ) {
arg.src2 = insert_conversion_src(
@@ -564,7 +564,7 @@ fn insert_implicit_conversions(
st.state_space.to_ld_ss(),
arg.src1,
);
- result.push(Statement::Instruction(Instruction::St(st, arg)));
+ result.push(Statement::Instruction(ast::Instruction::St(st, arg)));
}
inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst),
},
@@ -668,10 +668,10 @@ fn emit_function_body_ops(
}
Statement::Instruction(inst) => match inst {
// SPIR-V does not support marking jumps as guaranteed-converged
- Instruction::Bra(_, arg) => {
+ ast::Instruction::Bra(_, arg) => {
builder.branch(arg.src)?;
}
- Instruction::Ld(data, arg) => {
+ ast::Instruction::Ld(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
todo!()
}
@@ -686,7 +686,7 @@ fn emit_function_body_ops(
_ => todo!(),
}
}
- Instruction::St(data, arg) => {
+ ast::Instruction::St(data, arg) => {
if data.qualifier != ast::LdStQualifier::Weak
|| data.vector.is_some()
|| data.state_space != ast::StStateSpace::Generic
@@ -696,18 +696,18 @@ fn emit_function_body_ops(
builder.store(arg.src1, arg.src2, None, &[])?;
}
// SPIR-V does not support ret as guaranteed-converged
- Instruction::Ret(_) => builder.ret()?,
- Instruction::Mov(mov, arg) => {
+ ast::Instruction::Ret(_) => builder.ret()?,
+ ast::Instruction::Mov(mov, arg) => {
let result_type = map.get_or_add(builder, SpirvType::from(mov.typ));
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
- Instruction::Mul(mul, arg) => match mul {
+ ast::Instruction::Mul(mul, arg) => match mul {
ast::MulDetails::Int(ref ctr) => {
emit_mul_int(builder, map, opencl, ctr, arg)?;
}
ast::MulDetails::Float(_) => todo!(),
},
- Instruction::Add(add, arg) => match add {
+ ast::Instruction::Add(add, arg) => match add {
ast::AddDetails::Int(ref desc) => {
emit_add_int(builder, map, desc, arg)?;
}
@@ -732,7 +732,7 @@ fn emit_mul_int(
map: &mut TypeWordMap,
opencl: spirv::Word,
desc: &ast::MulIntDesc,
- arg: &Arg3,
+ arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::Base(desc.typ.into()));
match desc.control {
@@ -762,7 +762,7 @@ fn emit_add_int(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
ctr: &ast::AddIntDesc,
- arg: &Arg3,
+ arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::Base(ctr.typ.into()));
builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
@@ -837,10 +837,10 @@ fn emit_implicit_conversion(
}
// TODO: support scopes
-fn normalize_identifiers<'a>(
- args: &'a [ast::Argument<'a>],
- func: Vec<ast::Statement<&'a str>>,
-) -> (Vec<ast::Statement<spirv::Word>>, NumericIdResolver) {
+fn normalize_identifiers<'a, 'b>(
+ args: &'b [ast::Argument<'a>],
+ func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
+) -> (Vec<ast::Statement<NormalizedArgParams>>, NumericIdResolver) {
let mut id_defs = StringIdResolver::new();
for arg in args {
id_defs.add_def(arg.name, Some(ast::Type::Scalar(arg.a_type)));
@@ -854,8 +854,8 @@ fn normalize_identifiers<'a>(
fn expand_map_ids<'a>(
id_defs: &mut StringIdResolver<'a>,
- result: &mut Vec<ast::Statement<spirv::Word>>,
- s: ast::Statement<&'a str>,
+ result: &mut Vec<ast::Statement<NormalizedArgParams>>,
+ s: ast::Statement<ast::ParsedArgParams<'a>>,
) {
match s {
ast::Statement::Label(name) => {
@@ -979,7 +979,7 @@ enum Statement<I> {
Constant(ConstantDefinition),
}
-impl Statement<Instruction> {
+impl Statement<ast::Instruction<ExpandedArgParams>> {
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
match self {
Statement::Variable(id, _, _) => f(id),
@@ -994,25 +994,25 @@ impl Statement<Instruction> {
}
}
-type NormalizedStatement = Statement<ast::Instruction<spirv::Word>>;
-type ExpandedStatement = Statement<Instruction>;
+enum NormalizedArgParams {}
+type NormalizedStatement = Statement<ast::Instruction<NormalizedArgParams>>;
-enum Instruction {
- Ld(ast::LdData, Arg2),
- Mov(ast::MovData, Arg2),
- Mul(ast::MulDetails, Arg3),
- Add(ast::AddDetails, Arg3),
- Setp(ast::SetpData, Arg4),
- SetpBool(ast::SetpBoolData, Arg5),
- Not(ast::NotData, Arg2),
- Bra(ast::BraData, Arg1),
- Cvt(ast::CvtData, Arg2),
- Shl(ast::ShlData, Arg3),
- St(ast::StData, Arg2St),
- Ret(ast::RetData),
+impl ast::ArgParams for NormalizedArgParams {
+ type ID = spirv::Word;
+ type Operand = ast::Operand<spirv::Word>;
+ type MovOperand = ast::MovOperand<spirv::Word>;
}
-impl ast::Instruction<spirv::Word> {
+enum ExpandedArgParams {}
+type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>>;
+
+impl ast::ArgParams for ExpandedArgParams {
+ type ID = spirv::Word;
+ type Operand = spirv::Word;
+ type MovOperand = spirv::Word;
+}
+
+impl ast::Instruction<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(&mut self, f: &mut F) {
match self {
ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
@@ -1031,22 +1031,22 @@ impl ast::Instruction<spirv::Word> {
}
}
-impl Instruction {
+impl ast::Instruction<ExpandedArgParams> {
fn visit_id<F: FnMut(&mut spirv::Word)>(&mut self, f: &mut F) {
let f_visitor = &mut Self::typed_visitor(f);
match self {
- Instruction::Ld(_, a) => a.visit_id(f_visitor, None),
- Instruction::Mov(_, a) => a.visit_id(f_visitor, None),
- Instruction::Mul(_, a) => a.visit_id(f_visitor, None),
- Instruction::Add(_, a) => a.visit_id(f_visitor, None),
- Instruction::Setp(_, a) => a.visit_id(f_visitor, None),
- Instruction::SetpBool(_, a) => a.visit_id(f_visitor, None),
- Instruction::Not(_, a) => a.visit_id(f_visitor, None),
- Instruction::Cvt(_, a) => a.visit_id(f_visitor, None),
- Instruction::Shl(_, a) => a.visit_id(f_visitor, None),
- Instruction::St(_, a) => a.visit_id(f_visitor, None),
- Instruction::Bra(_, a) => a.visit_id(f_visitor, None),
- Instruction::Ret(_) => (),
+ ast::Instruction::Ld(_, a) => a.visit_id(f_visitor, None),
+ ast::Instruction::Mov(_, a) => a.visit_id(f_visitor, None),
+ ast::Instruction::Mul(_, a) => a.visit_id(f_visitor, None),
+ ast::Instruction::Add(_, a) => a.visit_id(f_visitor, None),
+ ast::Instruction::Setp(_, a) => a.visit_id(f_visitor, None),
+ ast::Instruction::SetpBool(_, a) => a.visit_id(f_visitor, None),
+ ast::Instruction::Not(_, a) => a.visit_id(f_visitor, None),
+ ast::Instruction::Cvt(_, a) => a.visit_id(f_visitor, None),
+ ast::Instruction::Shl(_, a) => a.visit_id(f_visitor, None),
+ ast::Instruction::St(_, a) => a.visit_id(f_visitor, None),
+ ast::Instruction::Bra(_, a) => a.visit_id(f_visitor, None),
+ ast::Instruction::Ret(_) => (),
}
}
@@ -1061,42 +1061,40 @@ impl Instruction {
f: &mut F,
) {
match self {
- Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
- Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)),
- Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())),
- Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())),
- Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
- Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
- Instruction::Not(_, _) => todo!(),
- Instruction::Cvt(_, _) => todo!(),
- Instruction::Shl(_, _) => todo!(),
- Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
- Instruction::Bra(_, a) => a.visit_id(f, None),
- Instruction::Ret(_) => (),
+ ast::Instruction::Ld(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
+ ast::Instruction::Mov(d, a) => a.visit_id(f, Some(d.typ)),
+ ast::Instruction::Mul(d, a) => a.visit_id(f, Some(d.get_type())),
+ ast::Instruction::Add(d, a) => a.visit_id(f, Some(d.get_type())),
+ ast::Instruction::Setp(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
+ ast::Instruction::SetpBool(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
+ ast::Instruction::Not(_, _) => todo!(),
+ ast::Instruction::Cvt(_, _) => todo!(),
+ ast::Instruction::Shl(_, _) => todo!(),
+ ast::Instruction::St(d, a) => a.visit_id(f, Some(ast::Type::Scalar(d.typ))),
+ ast::Instruction::Bra(_, a) => a.visit_id(f, None),
+ ast::Instruction::Ret(_) => (),
}
}
fn jump_target(&self) -> Option<spirv::Word> {
match self {
- Instruction::Bra(_, a) => Some(a.src),
- Instruction::Ld(_, _)
- | Instruction::Mov(_, _)
- | Instruction::Mul(_, _)
- | Instruction::Add(_, _)
- | Instruction::Setp(_, _)
- | Instruction::SetpBool(_, _)
- | Instruction::Not(_, _)
- | Instruction::Cvt(_, _)
- | Instruction::Shl(_, _)
- | Instruction::St(_, _)
- | Instruction::Ret(_) => None,
+ ast::Instruction::Bra(_, a) => Some(a.src),
+ ast::Instruction::Ld(_, _)
+ | ast::Instruction::Mov(_, _)
+ | ast::Instruction::Mul(_, _)
+ | ast::Instruction::Add(_, _)
+ | ast::Instruction::Setp(_, _)
+ | ast::Instruction::SetpBool(_, _)
+ | ast::Instruction::Not(_, _)
+ | ast::Instruction::Cvt(_, _)
+ | ast::Instruction::Shl(_, _)
+ | ast::Instruction::St(_, _)
+ | ast::Instruction::Ret(_) => None,
}
}
}
-struct Arg1 {
- pub src: spirv::Word,
-}
+type Arg1 = ast::Arg1<ExpandedArgParams>;
impl Arg1 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@@ -1108,10 +1106,7 @@ impl Arg1 {
}
}
-struct Arg2 {
- pub dst: spirv::Word,
- pub src: spirv::Word,
-}
+type Arg2 = ast::Arg2<ExpandedArgParams>;
impl Arg2 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@@ -1124,11 +1119,21 @@ impl Arg2 {
}
}
-pub struct Arg2St {
- pub src1: spirv::Word,
- pub src2: spirv::Word,
+type Arg2Mov = ast::Arg2Mov<ExpandedArgParams>;
+
+impl Arg2Mov {
+ fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
+ &mut self,
+ f: &mut F,
+ t: Option<ast::Type>,
+ ) {
+ f(true, &mut self.dst, t);
+ f(false, &mut self.src, t);
+ }
}
+type Arg2St = ast::Arg2St<ExpandedArgParams>;
+
impl Arg2St {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
@@ -1140,11 +1145,7 @@ impl Arg2St {
}
}
-struct Arg3 {
- pub dst: spirv::Word,
- pub src1: spirv::Word,
- pub src2: spirv::Word,
-}
+type Arg3 = ast::Arg3<ExpandedArgParams>;
impl Arg3 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@@ -1158,12 +1159,7 @@ impl Arg3 {
}
}
-struct Arg4 {
- pub dst1: spirv::Word,
- pub dst2: Option<spirv::Word>,
- pub src1: spirv::Word,
- pub src2: spirv::Word,
-}
+type Arg4 = ast::Arg4<ExpandedArgParams>;
impl Arg4 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@@ -1188,13 +1184,7 @@ impl Arg4 {
}
}
-struct Arg5 {
- pub dst1: spirv::Word,
- pub dst2: Option<spirv::Word>,
- pub src1: spirv::Word,
- pub src2: spirv::Word,
- pub src3: spirv::Word,
-}
+type Arg5 = ast::Arg5<ExpandedArgParams>;
impl Arg5 {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
@@ -1286,8 +1276,11 @@ impl<T> ast::PredAt<T> {
}
}
-impl<T> ast::Instruction<T> {
- fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Instruction<U> {
+impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> {
+ fn map_id<F: FnMut(&'a str) -> spirv::Word>(
+ self,
+ f: &mut F,
+ ) -> ast::Instruction<NormalizedArgParams> {
match self {
ast::Instruction::Ld(d, a) => ast::Instruction::Ld(d, a.map_id(f)),
ast::Instruction::Mov(d, a) => ast::Instruction::Mov(d, a.map_id(f)),
@@ -1305,13 +1298,13 @@ impl<T> ast::Instruction<T> {
}
}
-impl<T> ast::Arg1<T> {
- fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg1<U> {
+impl<'a> ast::Arg1<ast::ParsedArgParams<'a>> {
+ fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg1<NormalizedArgParams> {
ast::Arg1 { src: f(self.src) }
}
}
-impl ast::Arg1<spirv::Word> {
+impl ast::Arg1<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
@@ -1321,8 +1314,8 @@ impl ast::Arg1<spirv::Word> {
}
}
-impl<T> ast::Arg2<T> {
- fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2<U> {
+impl<'a> ast::Arg2<ast::ParsedArgParams<'a>> {
+ fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg2<NormalizedArgParams> {
ast::Arg2 {
dst: f(self.dst),
src: self.src.map_id(f),
@@ -1330,7 +1323,7 @@ impl<T> ast::Arg2<T> {
}
}
-impl ast::Arg2<spirv::Word> {
+impl ast::Arg2<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
@@ -1341,8 +1334,11 @@ impl ast::Arg2<spirv::Word> {
}
}
-impl<T> ast::Arg2St<T> {
- fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2St<U> {
+impl<'a> ast::Arg2St<ast::ParsedArgParams<'a>> {
+ fn map_id<F: FnMut(&'a str) -> spirv::Word>(
+ self,
+ f: &mut F,
+ ) -> ast::Arg2St<NormalizedArgParams> {
ast::Arg2St {
src1: self.src1.map_id(f),
src2: self.src2.map_id(f),
@@ -1350,7 +1346,7 @@ impl<T> ast::Arg2St<T> {
}
}
-impl ast::Arg2St<spirv::Word> {
+impl ast::Arg2St<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
@@ -1361,8 +1357,11 @@ impl ast::Arg2St<spirv::Word> {
}
}
-impl<T> ast::Arg2Mov<T> {
- fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2Mov<U> {
+impl<'a> ast::Arg2Mov<ast::ParsedArgParams<'a>> {
+ fn map_id<F: FnMut(&'a str) -> spirv::Word>(
+ self,
+ f: &mut F,
+ ) -> ast::Arg2Mov<NormalizedArgParams> {
ast::Arg2Mov {
dst: f(self.dst),
src: self.src.map_id(f),
@@ -1370,7 +1369,7 @@ impl<T> ast::Arg2Mov<T> {
}
}
-impl ast::Arg2Mov<spirv::Word> {
+impl ast::Arg2Mov<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
@@ -1381,8 +1380,8 @@ impl ast::Arg2Mov<spirv::Word> {
}
}
-impl<T> ast::Arg3<T> {
- fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg3<U> {
+impl<'a> ast::Arg3<ast::ParsedArgParams<'a>> {
+ fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg3<NormalizedArgParams> {
ast::Arg3 {
dst: f(self.dst),
src1: self.src1.map_id(f),
@@ -1391,7 +1390,7 @@ impl<T> ast::Arg3<T> {
}
}
-impl ast::Arg3<spirv::Word> {
+impl ast::Arg3<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
@@ -1403,8 +1402,8 @@ impl ast::Arg3<spirv::Word> {
}
}
-impl<T> ast::Arg4<T> {
- fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg4<U> {
+impl<'a> ast::Arg4<ast::ParsedArgParams<'a>> {
+ fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg4<NormalizedArgParams> {
ast::Arg4 {
dst1: f(self.dst1),
dst2: self.dst2.map(|i| f(i)),
@@ -1414,7 +1413,7 @@ impl<T> ast::Arg4<T> {
}
}
-impl ast::Arg4<spirv::Word> {
+impl ast::Arg4<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
@@ -1437,8 +1436,8 @@ impl ast::Arg4<spirv::Word> {
}
}
-impl<T> ast::Arg5<T> {
- fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg5<U> {
+impl<'a> ast::Arg5<ast::ParsedArgParams<'a>> {
+ fn map_id<F: FnMut(&'a str) -> spirv::Word>(self, f: &mut F) -> ast::Arg5<NormalizedArgParams> {
ast::Arg5 {
dst1: f(self.dst1),
dst2: self.dst2.map(|i| f(i)),
@@ -1449,7 +1448,7 @@ impl<T> ast::Arg5<T> {
}
}
-impl ast::Arg5<spirv::Word> {
+impl ast::Arg5<NormalizedArgParams> {
fn visit_id<F: FnMut(bool, &mut spirv::Word, Option<ast::Type>)>(
&mut self,
f: &mut F,
@@ -1779,7 +1778,7 @@ fn insert_with_implicit_conversion_dst<
T,
ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
Setter: Fn(&mut T) -> &mut spirv::Word,
- ToInstruction: FnOnce(T) -> Instruction,
+ ToInstruction: FnOnce(T) -> ast::Instruction<ExpandedArgParams>,
>(
func: &mut Vec<ExpandedStatement>,
instr_type: ast::ScalarType,
@@ -1907,7 +1906,7 @@ fn should_convert_relaxed_dst(
fn insert_implicit_bitcasts(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
- mut instr: Instruction,
+ mut instr: ast::Instruction<ExpandedArgParams>,
) {
let mut dst_coercion = None;
instr.visit_id_extended(&mut |is_dst, id, id_type| {