diff options
author | Andrzej Janik <[email protected]> | 2020-09-27 23:51:34 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-09-30 19:27:29 +0200 |
commit | 1e0b35be4bd52ea9273ba36adf224ec0a47eb7f8 (patch) | |
tree | 26161415586497ec9876198d6a55e17342b740ae | |
parent | 7c26568cbf017c55b27b72a7fcfe7761ce31e33c (diff) | |
download | ZLUDA-1e0b35be4bd52ea9273ba36adf224ec0a47eb7f8.tar.gz ZLUDA-1e0b35be4bd52ea9273ba36adf224ec0a47eb7f8.zip |
Implement vector-destructuring mov/ld/st
-rw-r--r-- | ptx/src/ast.rs | 185 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 136 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/ntid.spvtxt | 13 | ||||
-rw-r--r-- | ptx/src/translate.rs | 1179 |
4 files changed, 1041 insertions, 472 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index acefdc1..7edfa70 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -35,6 +35,19 @@ macro_rules! sub_scalar_type { } } } + + impl std::convert::TryFrom<ScalarType> for $name { + type Error = (); + + fn try_from(t: ScalarType) -> Result<Self, Self::Error> { + match t { + $( + ScalarType::$variant => Ok($name::$variant), + )+ + _ => Err(()), + } + } + } }; } @@ -159,20 +172,20 @@ pub struct Module<'a> { pub functions: Vec<ParsedFunction<'a>>, } -pub enum MethodDecl<'a, P: ArgParams> { - Func(Vec<FnArgument<P>>, P::ID, Vec<FnArgument<P>>), - Kernel(&'a str, Vec<KernelArgument<P>>), +pub enum MethodDecl<'a, ID> { + Func(Vec<FnArgument<ID>>, ID, Vec<FnArgument<ID>>), + Kernel(&'a str, Vec<KernelArgument<ID>>), } -pub type FnArgument<P> = Variable<FnArgumentType, P>; -pub type KernelArgument<P> = Variable<VariableParamType, P>; +pub type FnArgument<ID> = Variable<FnArgumentType, ID>; +pub type KernelArgument<ID> = Variable<VariableParamType, ID>; -pub struct Function<'a, P: ArgParams, S> { - pub func_directive: MethodDecl<'a, P>, +pub struct Function<'a, ID, S> { + pub func_directive: MethodDecl<'a, ID>, pub body: Option<Vec<S>>, } -pub type ParsedFunction<'a> = Function<'a, ParsedArgParams<'a>, Statement<ParsedArgParams<'a>>>; +pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>; #[derive(PartialEq, Eq, Clone, Copy)] pub enum FnArgumentType { @@ -264,21 +277,21 @@ impl Default for ScalarType { } pub enum Statement<P: ArgParams> { - Label(P::ID), - Variable(MultiVariable<P>), - Instruction(Option<PredAt<P::ID>>, Instruction<P>), + Label(P::Id), + Variable(MultiVariable<P::Id>), + Instruction(Option<PredAt<P::Id>>, Instruction<P>), Block(Vec<Statement<P>>), } -pub struct MultiVariable<P: ArgParams> { - pub var: Variable<VariableType, P>, +pub struct MultiVariable<ID> { + pub var: Variable<VariableType, ID>, pub count: Option<u32>, } -pub struct Variable<T, P: ArgParams> { +pub struct Variable<T, ID> { pub align: Option<u32>, pub v_type: T, - pub name: P::ID, + pub name: ID, } #[derive(Eq, PartialEq, Copy, Clone)] @@ -315,9 +328,8 @@ pub struct PredAt<ID> { } pub enum Instruction<P: ArgParams> { - Ld(LdData, Arg2<P>), - Mov(MovDetails, Arg2<P>), - MovVector(MovVectorDetails, Arg2Vec<P>), + Ld(LdDetails, Arg2Ld<P>), + Mov(MovDetails, Arg2Mov<P>), Mul(MulDetails, Arg3<P>), Add(AddDetails, Arg3<P>), Setp(SetpData, Arg4Setp<P>), @@ -338,11 +350,6 @@ pub enum Instruction<P: ArgParams> { pub struct MadFloatDesc {} #[derive(Copy, Clone)] -pub struct MovVectorDetails { - pub typ: MovVectorType, - pub length: u8, -} -#[derive(Copy, Clone)] pub struct AbsDetails { pub flush_to_zero: bool, pub typ: ScalarType, @@ -350,16 +357,18 @@ pub struct AbsDetails { pub struct CallInst<P: ArgParams> { pub uniform: bool, - pub ret_params: Vec<P::ID>, - pub func: P::ID, + pub ret_params: Vec<P::Id>, + pub func: P::Id, pub param_list: Vec<P::CallOperand>, } pub trait ArgParams { - type ID; + type Id; type Operand; + type IdOrVector; + type OperandOrVector; type CallOperand; - type VecOperand; + type SrcMemberOperand; } pub struct ParsedArgParams<'a> { @@ -367,57 +376,73 @@ pub struct ParsedArgParams<'a> { } impl<'a> ArgParams for ParsedArgParams<'a> { - type ID = &'a str; + type Id = &'a str; type Operand = Operand<&'a str>; type CallOperand = CallOperand<&'a str>; - type VecOperand = (&'a str, u8); + type IdOrVector = IdOrVector<&'a str>; + type OperandOrVector = OperandOrVector<&'a str>; + type SrcMemberOperand = (&'a str, u8); } pub struct Arg1<P: ArgParams> { - pub src: P::ID, // it is a jump destination, but in terms of operands it is a source operand + pub src: P::Id, // it is a jump destination, but in terms of operands it is a source operand } pub struct Arg2<P: ArgParams> { - pub dst: P::ID, + pub dst: P::Id, + pub src: P::Operand, +} +pub struct Arg2Ld<P: ArgParams> { + pub dst: P::IdOrVector, pub src: P::Operand, } pub struct Arg2St<P: ArgParams> { pub src1: P::Operand, - pub src2: P::Operand, + pub src2: P::OperandOrVector, +} + +pub enum Arg2Mov<P: ArgParams> { + Normal(Arg2MovNormal<P>), + Member(Arg2MovMember<P>), +} + +pub struct Arg2MovNormal<P: ArgParams> { + pub dst: P::IdOrVector, + pub src: P::OperandOrVector, } // We duplicate dst here because during further compilation // composite dst and composite src will receive different ids -pub enum Arg2Vec<P: ArgParams> { - Dst((P::ID, u8), P::ID, P::ID), - Src(P::ID, P::VecOperand), - Both((P::ID, u8), P::ID, P::VecOperand), +pub enum Arg2MovMember<P: ArgParams> { + Dst((P::Id, u8), P::Id, P::Id), + Src(P::Id, P::SrcMemberOperand), + Both((P::Id, u8), P::Id, P::SrcMemberOperand), } pub struct Arg3<P: ArgParams> { - pub dst: P::ID, + pub dst: P::Id, pub src1: P::Operand, pub src2: P::Operand, } pub struct Arg4<P: ArgParams> { - pub dst: P::ID, + pub dst: P::Id, pub src1: P::Operand, pub src2: P::Operand, pub src3: P::Operand, } pub struct Arg4Setp<P: ArgParams> { - pub dst1: P::ID, - pub dst2: Option<P::ID>, + pub dst1: P::Id, + pub dst2: Option<P::Id>, pub src1: P::Operand, pub src2: P::Operand, } pub struct Arg5<P: ArgParams> { - pub dst1: P::ID, - pub dst2: Option<P::ID>, + pub dst1: P::Id, + pub dst2: Option<P::Id>, pub src1: P::Operand, pub src2: P::Operand, pub src3: P::Operand, @@ -436,12 +461,34 @@ pub enum CallOperand<ID> { Imm(u32), } +pub enum IdOrVector<ID> { + Reg(ID), + Vec(Vec<ID>) +} + +pub enum OperandOrVector<ID> { + Reg(ID), + RegOffset(ID, i32), + Imm(u32), + Vec(Vec<ID>) +} + +impl<T> From<Operand<T>> for OperandOrVector<T> { + fn from(this: Operand<T>) -> Self { + match this { + Operand::Reg(r) => OperandOrVector::Reg(r), + Operand::RegOffset(r, imm) => OperandOrVector::RegOffset(r, imm), + Operand::Imm(imm) => OperandOrVector::Imm(imm), + } + } +} + pub enum VectorPrefix { V2, V4, } -pub struct LdData { +pub struct LdDetails { pub qualifier: LdStQualifier, pub state_space: LdStateSpace, pub caching: LdCacheOperator, @@ -482,45 +529,23 @@ pub enum LdCacheOperator { Uncached, } -sub_scalar_type!(MovScalarType { - B16, - B32, - B64, - U16, - U32, - U64, - S16, - S32, - S64, - F32, - F64, - Pred, -}); - -// pred vectors are illegal -sub_scalar_type!(MovVectorType { - B16, - B32, - B64, - U16, - U32, - U64, - S16, - S32, - S64, - F32, - F64, -}); - +#[derive(Copy, Clone)] pub struct MovDetails { - pub typ: MovType, + pub typ: Type, pub src_is_address: bool, -} - -sub_type! { - MovType { - Scalar(MovScalarType), - Vector(MovVectorType, u8), + // two fields below are in use by member moves + pub dst_width: u8, + pub src_width: u8, +} + +impl MovDetails { + pub fn new(typ: Type) -> Self { + MovDetails { + typ, + src_is_address: false, + dst_width: 0, + src_width: 0 + } } } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 50a6aeb..ba3fc2b 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -194,7 +194,7 @@ TargetSpecifier = { "map_f64_to_f32" }; -Directive: Option<ast::Function<'input, ast::ParsedArgParams<'input>, ast::Statement<ast::ParsedArgParams<'input>>>> = { +Directive: Option<ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>>> = { AddressSize => None, <f:Function> => Some(f), File => None, @@ -205,7 +205,7 @@ AddressSize = { ".address_size" Num }; -Function: ast::Function<'input, ast::ParsedArgParams<'input>, ast::Statement<ast::ParsedArgParams<'input>>> = { +Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = { LinkingDirective* <func_directive:MethodDecl> <body:FunctionBody> => ast::Function{<>} @@ -217,29 +217,29 @@ LinkingDirective = { ".weak" }; -MethodDecl: ast::MethodDecl<'input, ast::ParsedArgParams<'input>> = { +MethodDecl: ast::MethodDecl<'input, &'input str> = { ".entry" <name:ExtendedID> <params:KernelArguments> => ast::MethodDecl::Kernel(name, params), ".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => { ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params) } }; -KernelArguments: Vec<ast::KernelArgument<ast::ParsedArgParams<'input>>> = { +KernelArguments: Vec<ast::KernelArgument<&'input str>> = { "(" <args:Comma<KernelInput>> ")" => args }; -FnArguments: Vec<ast::FnArgument<ast::ParsedArgParams<'input>>> = { +FnArguments: Vec<ast::FnArgument<&'input str>> = { "(" <args:Comma<FnInput>> ")" => args }; -KernelInput: ast::Variable<ast::VariableParamType, ast::ParsedArgParams<'input>> = { +KernelInput: ast::Variable<ast::VariableParamType, &'input str> = { <v:ParamVariable> => { let (align, v_type, name) = v; ast::Variable{ align, v_type, name } } } -FnInput: ast::Variable<ast::FnArgumentType, ast::ParsedArgParams<'input>> = { +FnInput: ast::Variable<ast::FnArgumentType, &'input str> = { <v:RegVariable> => { let (align, v_type, name) = v; let v_type = ast::FnArgumentType::Reg(v_type); @@ -320,7 +320,7 @@ Align: u32 = { }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names -MultiVariable: ast::MultiVariable<ast::ParsedArgParams<'input>> = { +MultiVariable: ast::MultiVariable<&'input str> = { <var:Variable> <count:VariableParam?> => ast::MultiVariable{<>} } @@ -331,7 +331,7 @@ VariableParam: u32 = { } } -Variable: ast::Variable<ast::VariableType, ast::ParsedArgParams<'input>> = { +Variable: ast::Variable<ast::VariableType, &'input str> = { <v:RegVariable> => { let (align, v_type, name) = v; let v_type = ast::VariableType::Reg(v_type); @@ -356,7 +356,7 @@ RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = { } } -LocalVariable: ast::Variable<ast::VariableType, ast::ParsedArgParams<'input>> = { +LocalVariable: ast::Variable<ast::VariableType, &'input str> = { ".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> => { let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t)); ast::Variable {align, v_type, name} @@ -449,19 +449,29 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = { - "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:ExtendedID> "," <src:MemoryOperand> => { + "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:IdOrVector> "," <src:MemoryOperand> => { ast::Instruction::Ld( - ast::LdData { + ast::LdDetails { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), state_space: ss.unwrap_or(ast::LdStateSpace::Generic), caching: cop.unwrap_or(ast::LdCacheOperator::Cached), typ: t }, - ast::Arg2 { dst:dst, src:src } + ast::Arg2Ld { dst:dst, src:src } ) } }; +IdOrVector: ast::IdOrVector<&'input str> = { + <dst:ExtendedID> => ast::IdOrVector::Reg(dst), + <dst:VectorExtract> => ast::IdOrVector::Vec(dst) +} + +OperandOrVector: ast::OperandOrVector<&'input str> = { + <op:Operand> => ast::OperandOrVector::from(op), + <dst:VectorExtract> => ast::OperandOrVector::Vec(dst) +} + LdStType: ast::Type = { <v:VectorPrefix> <t:LdStScalarType> => ast::Type::Vector(t, v), <t:LdStScalarType> => ast::Type::Scalar(t), @@ -498,49 +508,58 @@ LdCacheOperator: ast::LdCacheOperator = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov InstMov: ast::Instruction<ast::ParsedArgParams<'input>> = { - "mov" <t:MovType> <a:Arg2> => { - ast::Instruction::Mov(ast::MovDetails{ src_is_address: false, typ: t }, a) - }, - "mov" <t:MovVectorType> <a:Arg2Vec> => { - ast::Instruction::MovVector(ast::MovVectorDetails{typ: t, length: 0}, a) - } + <m:MovNormal> => ast::Instruction::Mov(m.0, m.1), + <m:MovVector> => ast::Instruction::Mov(m.0, m.1), }; -#[inline] -MovType: ast::MovType = { - <t:MovScalarType> => ast::MovType::Scalar(t), - <pref:VectorPrefix> <t:MovVectorType> => ast::MovType::Vector(t, pref) + +MovNormal: (ast::MovDetails, ast::Arg2Mov<ast::ParsedArgParams<'input>>) = { + "mov" <t:MovScalarType> <dst:ExtendedID> "," <src:Operand> => {( + ast::MovDetails::new(ast::Type::Scalar(t)), + ast::Arg2Mov::Normal(ast::Arg2MovNormal{ dst: ast::IdOrVector::Reg(dst), src: src.into() }) + )}, + "mov" <pref:VectorPrefix> <t:MovVectorType> <dst:IdOrVector> "," <src:OperandOrVector> => {( + ast::MovDetails::new(ast::Type::Vector(t, pref)), + ast::Arg2Mov::Normal(ast::Arg2MovNormal{ dst: dst, src: src }) + )} +} + +MovVector: (ast::MovDetails, ast::Arg2Mov<ast::ParsedArgParams<'input>>) = { + "mov" <t:MovVectorType> <a:Arg2MovMember> => {( + ast::MovDetails::new(ast::Type::Scalar(t.into())), + ast::Arg2Mov::Member(a) + )}, } #[inline] -MovScalarType: ast::MovScalarType = { - ".b16" => ast::MovScalarType::B16, - ".b32" => ast::MovScalarType::B32, - ".b64" => ast::MovScalarType::B64, - ".u16" => ast::MovScalarType::U16, - ".u32" => ast::MovScalarType::U32, - ".u64" => ast::MovScalarType::U64, - ".s16" => ast::MovScalarType::S16, - ".s32" => ast::MovScalarType::S32, - ".s64" => ast::MovScalarType::S64, - ".f32" => ast::MovScalarType::F32, - ".f64" => ast::MovScalarType::F64, - ".pred" => ast::MovScalarType::Pred +MovScalarType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, + ".pred" => ast::ScalarType::Pred }; #[inline] -MovVectorType: ast::MovVectorType = { - ".b16" => ast::MovVectorType::B16, - ".b32" => ast::MovVectorType::B32, - ".b64" => ast::MovVectorType::B64, - ".u16" => ast::MovVectorType::U16, - ".u32" => ast::MovVectorType::U32, - ".u64" => ast::MovVectorType::U64, - ".s16" => ast::MovVectorType::S16, - ".s32" => ast::MovVectorType::S32, - ".s64" => ast::MovVectorType::S64, - ".f32" => ast::MovVectorType::F32, - ".f64" => ast::MovVectorType::F64, +MovVectorType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul @@ -902,7 +921,7 @@ ShlType: ast::ShlType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // Warning: NVIDIA documentation is incorrect, you can specify scope only once InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = { - "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> <src1:MemoryOperand> "," <src2:Operand> => { + "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> <src1:MemoryOperand> "," <src2:OperandOrVector> => { ast::Instruction::St( ast::StData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), @@ -1044,13 +1063,13 @@ Arg2: ast::Arg2<ast::ParsedArgParams<'input>> = { <dst:ExtendedID> "," <src:Operand> => ast::Arg2{<>} }; -Arg2Vec: ast::Arg2Vec<ast::ParsedArgParams<'input>> = { - <dst:VectorOperand> "," <src:ExtendedID> => ast::Arg2Vec::Dst(dst, dst.0, src), - <dst:ExtendedID> "," <src:VectorOperand> => ast::Arg2Vec::Src(dst, src), - <dst:VectorOperand> "," <src:VectorOperand> => ast::Arg2Vec::Both(dst, dst.0, src), +Arg2MovMember: ast::Arg2MovMember<ast::ParsedArgParams<'input>> = { + <dst:MemberOperand> "," <src:ExtendedID> => ast::Arg2MovMember::Dst(dst, dst.0, src), + <dst:ExtendedID> "," <src:MemberOperand> => ast::Arg2MovMember::Src(dst, src), + <dst:MemberOperand> "," <src:MemberOperand> => ast::Arg2MovMember::Both(dst, dst.0, src), }; -VectorOperand: (&'input str, u8) = { +MemberOperand: (&'input str, u8) = { <pref:ExtendedID> "." <suf:ExtendedID> =>? { let suf_idx = vector_index(suf)?; Ok((pref, suf_idx)) @@ -1061,6 +1080,15 @@ VectorOperand: (&'input str, u8) = { } }; +VectorExtract: Vec<&'input str> = { + "{" <r1:ExtendedID> "," <r2:ExtendedID> "}" => { + vec![r1, r2] + }, + "{" <r1:ExtendedID> "," <r2:ExtendedID> "," <r3:ExtendedID> "," <r4:ExtendedID> "}" => { + vec![r1, r2, r3, r4] + }, +}; + Arg3: ast::Arg3<ast::ParsedArgParams<'input>> = { <dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> => ast::Arg3{<>} }; diff --git a/ptx/src/test/spirv_run/ntid.spvtxt b/ptx/src/test/spirv_run/ntid.spvtxt index ef308f0..be16d2e 100644 --- a/ptx/src/test/spirv_run/ntid.spvtxt +++ b/ptx/src/test/spirv_run/ntid.spvtxt @@ -4,15 +4,16 @@ OpCapability Kernel OpCapability Int64 OpCapability Int8 + OpCapability Float64 %29 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "add" %GlobalSize - OpDecorate %GlobalSize BuiltIn GlobalSize + OpEntryPoint Kernel %1 "ntid" %gl_WorkGroupSize + OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize %void = OpTypeVoid %uint = OpTypeInt 32 0 - %v3uint = OpTypeVector %uint 3 -%_ptr_UniformConstant_v3uint = OpTypePointer UniformConstant %v3uint - %GlobalSize = OpVariable %_ptr_UniformConstant_v3uint UniformConstant + %v4uint = OpTypeVector %uint 4 +%_ptr_UniformConstant_v4uint = OpTypePointer UniformConstant %v4uint +%gl_WorkGroupSize = OpVariable %_ptr_UniformConstant_v4uint UniformConstant %ulong = OpTypeInt 64 0 %35 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong @@ -40,7 +41,7 @@ %25 = OpConvertUToPtr %_ptr_Generic_uint %16 %15 = OpLoad %uint %25 OpStore %6 %15 - %18 = OpLoad %v3uint %GlobalSize + %18 = OpLoad %v4uint %gl_WorkGroupSize %24 = OpCompositeExtract %uint %18 0 %17 = OpCopyObject %uint %24 OpStore %7 %17 diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index a1d4b6a..981da86 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,6 +1,7 @@ use crate::ast;
use rspirv::{binary::Disassemble, dr};
use std::collections::{hash_map, HashMap, HashSet};
+use std::convert::TryInto;
use std::{borrow::Cow, iter, mem};
use rspirv::binary::Assemble;
@@ -282,7 +283,7 @@ fn emit_function_header<'a>( builder: &mut dr::Builder,
map: &mut TypeWordMap,
global: &GlobalStringIdResolver<'a>,
- func_directive: ast::MethodDecl<ExpandedArgParams>,
+ func_directive: ast::MethodDecl<spirv::Word>,
all_args_lens: &mut HashMap<String, Vec<usize>>,
) -> Result<(), TranslateError> {
if let ast::MethodDecl::Kernel(name, args) = &func_directive {
@@ -334,8 +335,10 @@ fn emit_capabilities(builder: &mut dr::Builder) { builder.capability(spirv::Capability::Linkage);
builder.capability(spirv::Capability::Addresses);
builder.capability(spirv::Capability::Kernel);
- builder.capability(spirv::Capability::Int64);
builder.capability(spirv::Capability::Int8);
+ builder.capability(spirv::Capability::Int16);
+ builder.capability(spirv::Capability::Int64);
+ builder.capability(spirv::Capability::Float16);
builder.capability(spirv::Capability::Float64);
}
@@ -362,8 +365,8 @@ fn to_ssa_function<'a>( fn expand_kernel_params<'a, 'b>(
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: impl Iterator<Item = &'b ast::KernelArgument<ast::ParsedArgParams<'a>>>,
-) -> Vec<ast::KernelArgument<ExpandedArgParams>> {
+ args: impl Iterator<Item = &'b ast::KernelArgument<&'a str>>,
+) -> Vec<ast::KernelArgument<spirv::Word>> {
args.map(|a| ast::KernelArgument {
name: fn_resolver.add_def(a.name, Some((StateSpace::Param, ast::Type::from(a.v_type)))),
v_type: a.v_type,
@@ -374,8 +377,8 @@ fn expand_kernel_params<'a, 'b>( fn expand_fn_params<'a, 'b>(
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: impl Iterator<Item = &'b ast::FnArgument<ast::ParsedArgParams<'a>>>,
-) -> Vec<ast::FnArgument<ExpandedArgParams>> {
+ args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
+) -> Vec<ast::FnArgument<spirv::Word>> {
args.map(|a| {
let ss = match a.v_type {
ast::FnArgumentType::Reg(_) => StateSpace::Reg,
@@ -393,7 +396,7 @@ fn expand_fn_params<'a, 'b>( fn to_ssa<'input, 'b>(
mut id_defs: FnStringIdResolver<'input, 'b>,
fn_defs: GlobalFnDeclResolver<'input, 'b>,
- f_args: ast::MethodDecl<'input, ExpandedArgParams>,
+ f_args: ast::MethodDecl<'input, spirv::Word>,
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
) -> Result<Function<'input>, TranslateError> {
let f_body = match f_body {
@@ -409,11 +412,11 @@ fn to_ssa<'input, 'b>( let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?;
let mut numeric_id_defs = id_defs.finish();
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
- let unadorned_statements =
- add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?;
+ let typed_statements =
+ convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.finish();
let (f_args, ssa_statements) =
- insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args)?;
+ insert_mem_ssa_statements(typed_statements, &mut numeric_id_defs, f_args)?;
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
@@ -443,15 +446,16 @@ fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedSta func
}
-fn add_types_to_statements(
- func: Vec<UnadornedStatement>,
+fn convert_to_typed_statements(
+ func: Vec<UnconditionalStatement>,
fn_defs: &GlobalFnDeclResolver,
id_defs: &NumericIdResolver,
-) -> Result<Vec<UnadornedStatement>, TranslateError> {
- func.into_iter()
- .map(|s| {
- match s {
- Statement::Instruction(ast::Instruction::Call(call)) => {
+) -> Result<Vec<TypedStatement>, TranslateError> {
+ let mut result = Vec::<TypedStatement>::with_capacity(func.len());
+ for s in func {
+ match s {
+ Statement::Instruction(inst) => match inst {
+ ast::Instruction::Call(call) => {
// TODO: error out if lengths don't match
let fn_def = fn_defs.get_fn_decl(call.func)?;
let ret_params = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
@@ -462,7 +466,7 @@ fn add_types_to_statements( func: call.func,
param_list,
};
- Ok(Statement::Call(resolved_call))
+ result.push(Statement::Call(resolved_call));
}
// Supported ld/st:
// global: only compatible with reg b64/u64/s64 source/dest
@@ -477,25 +481,24 @@ fn add_types_to_statements( // One complication: immediate address is only allowed in local,
// It is not supported in generic ld
// ld.local foo, [1];
- Statement::Instruction(ast::Instruction::Ld(mut d, arg)) => {
+ ast::Instruction::Ld(mut d, arg) => {
match arg.src.underlying() {
- None => return Ok(Statement::Instruction(ast::Instruction::Ld(d, arg))),
+ None => {}
Some(u) => {
let (ss, _) = id_defs.get_typed(*u)?;
match (d.state_space, ss) {
(ast::LdStateSpace::Generic, StateSpace::Local) => {
d.state_space = ast::LdStateSpace::Local;
}
- _ => (),
+ _ => {}
};
}
};
-
- Ok(Statement::Instruction(ast::Instruction::Ld(d, arg)))
+ result.push(Statement::Instruction(ast::Instruction::Ld(d, arg.cast())));
}
- Statement::Instruction(ast::Instruction::St(mut d, arg)) => {
+ ast::Instruction::St(mut d, arg) => {
match arg.src1.underlying() {
- None => return Ok(Statement::Instruction(ast::Instruction::St(d, arg))),
+ None => {}
Some(u) => {
let (ss, _) = id_defs.get_typed(*u)?;
match (d.state_space, ss) {
@@ -506,39 +509,101 @@ fn add_types_to_statements( };
}
};
- Ok(Statement::Instruction(ast::Instruction::St(d, arg)))
+ result.push(Statement::Instruction(ast::Instruction::St(d, arg.cast())));
}
- Statement::Instruction(ast::Instruction::Mov(mut d, arg)) => {
- if let Some(src_id) = arg.src.underlying() {
- let (scope, _) = id_defs.get_typed(*src_id)?;
- d.src_is_address = match scope {
- StateSpace::Reg => false,
- StateSpace::Const
- | StateSpace::Global
- | StateSpace::Local
- | StateSpace::Shared
- | StateSpace::Param
- | StateSpace::ParamReg => true,
+ ast::Instruction::Mov(mut d, args) => match args {
+ ast::Arg2Mov::Normal(arg) => {
+ if let Some(src_id) = arg.src.single_underlying() {
+ let (scope, _) = id_defs.get_typed(*src_id)?;
+ d.src_is_address = match scope {
+ StateSpace::Reg => false,
+ StateSpace::Const
+ | StateSpace::Global
+ | StateSpace::Local
+ | StateSpace::Shared
+ | StateSpace::Param
+ | StateSpace::ParamReg => true,
+ };
+ }
+ result.push(Statement::Instruction(ast::Instruction::Mov(
+ d,
+ ast::Arg2Mov::Normal(arg.cast()),
+ )));
+ }
+ ast::Arg2Mov::Member(args) => {
+ if let Some(dst_typ) = args.vector_dst() {
+ match id_defs.get_typed(*dst_typ)? {
+ (_, ast::Type::Vector(_, len)) => {
+ d.dst_width = len;
+ }
+ _ => return Err(TranslateError::MismatchedType),
+ }
+ };
+ if let Some((src_typ, _)) = args.vector_src() {
+ match id_defs.get_typed(*src_typ)? {
+ (_, ast::Type::Vector(_, len)) => {
+ d.src_width = len;
+ }
+ _ => return Err(TranslateError::MismatchedType),
+ }
};
+ result.push(Statement::Instruction(ast::Instruction::Mov(
+ d,
+ ast::Arg2Mov::Member(args.cast()),
+ )));
}
- Ok(Statement::Instruction(ast::Instruction::Mov(d, arg)))
+ },
+ ast::Instruction::Mul(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Mul(d, a.cast())))
}
- Statement::Instruction(ast::Instruction::MovVector(dets, args)) => {
- let new_dets = match id_defs.get_typed(*args.dst())? {
- (_, ast::Type::Vector(_, len)) => ast::MovVectorDetails {
- length: len,
- ..dets
- },
- _ => dets,
- };
- Ok(Statement::Instruction(ast::Instruction::MovVector(
- new_dets, args,
- )))
+ ast::Instruction::Add(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Add(d, a.cast())))
}
- s => Ok(s),
- }
- })
- .collect::<Result<Vec<_>, _>>()
+ ast::Instruction::Setp(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Setp(d, a.cast())))
+ }
+ ast::Instruction::SetpBool(d, a) => result.push(Statement::Instruction(
+ ast::Instruction::SetpBool(d, a.cast()),
+ )),
+ ast::Instruction::Not(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Not(d, a.cast())))
+ }
+ ast::Instruction::Bra(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Bra(d, a.cast())))
+ }
+ ast::Instruction::Cvt(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Cvt(d, a.cast())))
+ }
+ ast::Instruction::Cvta(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Cvta(d, a.cast())))
+ }
+ ast::Instruction::Shl(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Shl(d, a.cast())))
+ }
+ ast::Instruction::Ret(d) => {
+ result.push(Statement::Instruction(ast::Instruction::Ret(d)))
+ }
+ ast::Instruction::Abs(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Abs(d, a.cast())))
+ }
+ ast::Instruction::Mad(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Mad(d, a.cast())))
+ }
+ },
+ Statement::Label(i) => result.push(Statement::Label(i)),
+ Statement::Variable(v) => result.push(Statement::Variable(v)),
+ Statement::LoadVar(a, t) => result.push(Statement::LoadVar(a, t)),
+ Statement::StoreVar(a, t) => result.push(Statement::StoreVar(a, t)),
+ Statement::Call(c) => result.push(Statement::Call(c.cast())),
+ Statement::Composite(c) => result.push(Statement::Composite(c)),
+ Statement::Conditional(c) => result.push(Statement::Conditional(c)),
+ Statement::Conversion(c) => result.push(Statement::Conversion(c)),
+ Statement::Constant(c) => result.push(Statement::Constant(c)),
+ Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
+ Statement::Undef(_, _) => return Err(TranslateError::Unreachable),
+ }
+ }
+ Ok(result)
}
fn to_resolved_fn_args<T>(
@@ -576,7 +641,8 @@ fn normalize_labels( | Statement::RetValue(_, _)
| Statement::Conversion(_)
| Statement::Constant(_)
- | Statement::Label(_) => (),
+ | Statement::Label(_)
+ | Statement::Undef(_, _) => (),
}
}
iter::once(Statement::Label(id_def.new_id(None)))
@@ -590,7 +656,7 @@ fn normalize_labels( fn normalize_predicates(
func: Vec<NormalizedStatement>,
id_def: &mut NumericIdResolver,
-) -> Vec<UnadornedStatement> {
+) -> Vec<UnconditionalStatement> {
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
@@ -630,16 +696,10 @@ fn normalize_predicates( }
fn insert_mem_ssa_statements<'a, 'b>(
- func: Vec<UnadornedStatement>,
+ func: Vec<TypedStatement>,
id_def: &mut MutableNumericIdResolver,
- mut f_args: ast::MethodDecl<'a, ExpandedArgParams>,
-) -> Result<
- (
- ast::MethodDecl<'a, ExpandedArgParams>,
- Vec<UnadornedStatement>,
- ),
- TranslateError,
-> {
+ mut f_args: ast::MethodDecl<'a, spirv::Word>,
+) -> Result<(ast::MethodDecl<'a, spirv::Word>, Vec<TypedStatement>), TranslateError> {
let mut result = Vec::with_capacity(func.len());
let out_param = match &mut f_args {
ast::MethodDecl::Kernel(_, in_params) => {
@@ -697,7 +757,9 @@ fn insert_mem_ssa_statements<'a, 'b>( };
for s in func {
match s {
- Statement::Call(call) => insert_mem_ssa_statement_default(id_def, &mut result, call)?,
+ Statement::Call(call) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, call.cast())?
+ }
Statement::Instruction(inst) => match inst {
ast::Instruction::Ret(d) => {
if let Some(out_param) = out_param {
@@ -734,7 +796,8 @@ fn insert_mem_ssa_statements<'a, 'b>( | Statement::StoreVar(_, _)
| Statement::Conversion(_)
| Statement::RetValue(_, _)
- | Statement::Constant(_) => unreachable!(),
+ | Statement::Constant(_)
+ | Statement::Undef(_, _) => {}
Statement::Composite(_) => todo!(),
}
}
@@ -751,7 +814,7 @@ trait VisitVariable: Sized { >(
self,
f: &mut F,
- ) -> Result<UnadornedStatement, TranslateError>;
+ ) -> Result<TypedStatement, TranslateError>;
}
trait VisitVariableExpanded {
fn visit_variable_extended<
@@ -767,7 +830,7 @@ trait VisitVariableExpanded { fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
id_def: &mut MutableNumericIdResolver,
- result: &mut Vec<UnadornedStatement>,
+ result: &mut Vec<TypedStatement>,
stmt: F,
) -> Result<(), TranslateError> {
let mut post_statements = Vec::new();
@@ -808,7 +871,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( }
fn expand_arguments<'a, 'b>(
- func: Vec<UnadornedStatement>,
+ func: Vec<TypedStatement>,
id_def: &'b mut MutableNumericIdResolver<'a>,
) -> Result<Vec<ExpandedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
@@ -840,9 +903,10 @@ fn expand_arguments<'a, 'b>( Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
- Statement::Composite(_) | Statement::Conversion(_) | Statement::Constant(_) => {
- unreachable!()
- }
+ Statement::Composite(_)
+ | Statement::Conversion(_)
+ | Statement::Constant(_)
+ | Statement::Undef(_, _) => unreachable!(),
}
}
Ok(result)
@@ -865,12 +929,26 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { post_stmts: Vec::new(),
}
}
-}
-impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
- for FlattenArguments<'a, 'b>
-{
- fn variable(
+ fn insert_composite_read(
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut MutableNumericIdResolver<'a>,
+ (scalar_type, vec_len): (ast::ScalarType, u8),
+ scalar_dst: Option<spirv::Word>,
+ composite_src: (spirv::Word, u8),
+ ) -> spirv::Word {
+ let new_id =
+ scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Vector(scalar_type, vec_len)));
+ func.push(Statement::Composite(CompositeRead {
+ typ: scalar_type,
+ dst: new_id,
+ src_composite: composite_src.0,
+ src_index: composite_src.1 as u32,
+ }));
+ new_id
+ }
+
+ fn reg(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
_: Option<ast::Type>,
@@ -878,112 +956,105 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams> Ok(desc.op)
}
- fn operand(
+ fn reg_offset(
&mut self,
- desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ desc: ArgumentDescriptor<(spirv::Word, i32)>,
typ: ast::Type,
) -> Result<spirv::Word, TranslateError> {
- match desc.op {
- ast::Operand::Reg(r) => self.variable(desc.new_op(r), Some(typ)),
- ast::Operand::Imm(x) => {
+ let (reg, offset) = desc.op;
+ match desc.sema {
+ ArgumentSemantics::Default => {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
scalar
} else {
todo!()
};
- let id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
+ let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
+ let result_id = self.id_def.new_id(typ);
self.func.push(Statement::Constant(ConstantDefinition {
- dst: id,
+ dst: id_constant_stmt,
typ: scalar_t,
- value: x as i64,
+ value: offset as i64,
}));
- Ok(id)
+ let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
+ self.func.push(Statement::Instruction(
+ ast::Instruction::<ExpandedArgParams>::Add(
+ ast::AddDetails::Int(ast::AddIntDesc {
+ typ: int_type,
+ saturate: false,
+ }),
+ ast::Arg3 {
+ dst: result_id,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ ),
+ ));
+ Ok(result_id)
}
- ast::Operand::RegOffset(reg, offset) => match desc.sema {
- ArgumentSemantics::Default => {
- let scalar_t = if let ast::Type::Scalar(scalar) = typ {
- scalar
- } else {
- todo!()
- };
- let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
- let result_id = self.id_def.new_id(typ);
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: scalar_t,
- value: offset as i64,
- }));
- let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- ast::AddDetails::Int(ast::AddIntDesc {
- typ: int_type,
- saturate: false,
- }),
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
- Ok(result_id)
- }
- ArgumentSemantics::PhysicalPointer => {
- let scalar_t = ast::ScalarType::U64;
- let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
- let result_id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: scalar_t,
- value: offset as i64,
- }));
- let int_type = ast::IntType::U64;
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- ast::AddDetails::Int(ast::AddIntDesc {
- typ: int_type,
- saturate: false,
- }),
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
- Ok(result_id)
- }
- ArgumentSemantics::RegisterPointer => {
- if offset == 0 {
- return Ok(reg);
- }
- todo!()
+ ArgumentSemantics::PhysicalPointer => {
+ let scalar_t = ast::ScalarType::U64;
+ let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
+ let result_id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: scalar_t,
+ value: offset as i64,
+ }));
+ let int_type = ast::IntType::U64;
+ self.func.push(Statement::Instruction(
+ ast::Instruction::<ExpandedArgParams>::Add(
+ ast::AddDetails::Int(ast::AddIntDesc {
+ typ: int_type,
+ saturate: false,
+ }),
+ ast::Arg3 {
+ dst: result_id,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ ),
+ ));
+ Ok(result_id)
+ }
+ ArgumentSemantics::RegisterPointer => {
+ if offset == 0 {
+ return Ok(reg);
}
- ArgumentSemantics::Address => todo!(),
- },
+ todo!()
+ }
+ ArgumentSemantics::Address => todo!(),
}
}
- fn src_call_operand(
+ fn immediate(
&mut self,
- desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
+ desc: ArgumentDescriptor<u32>,
typ: ast::Type,
) -> Result<spirv::Word, TranslateError> {
- match desc.op {
- ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg), Some(typ)),
- ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x)), typ),
- }
+ let scalar_t = if let ast::Type::Scalar(scalar) = typ {
+ scalar
+ } else {
+ todo!()
+ };
+ let id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id,
+ typ: scalar_t,
+ value: desc.op as i64,
+ }));
+ Ok(id)
}
- fn src_vec_operand(
+ fn member_src(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
- (scalar_type, vec_len): (ast::MovVectorType, u8),
+ (scalar_type, vec_len): (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
- let new_id = self
- .id_def
- .new_id(ast::Type::Vector(scalar_type.into(), vec_len));
+ if desc.is_dst {
+ return Err(TranslateError::Unreachable);
+ }
+ let new_id = self.id_def.new_id(ast::Type::Vector(scalar_type, vec_len));
self.func.push(Statement::Composite(CompositeRead {
typ: scalar_type,
dst: new_id,
@@ -992,6 +1063,115 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams> }));
Ok(new_id)
}
+
+ fn vector(
+ &mut self,
+ desc: ArgumentDescriptor<&Vec<spirv::Word>>,
+ typ: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
+ let (scalar_type, vec_len) = typ.get_vector()?;
+ if !desc.is_dst {
+ let mut new_id = self.id_def.new_id(typ);
+ self.func.push(Statement::Undef(typ, new_id));
+ for (idx, id) in desc.op.iter().enumerate() {
+ let newer_id = self.id_def.new_id(typ);
+ self.func.push(Statement::Instruction(ast::Instruction::Mov(
+ ast::MovDetails {
+ typ: typ,
+ src_is_address: false,
+ dst_width: 0,
+ src_width: vec_len,
+ },
+ ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
+ (newer_id, idx as u8),
+ new_id,
+ *id,
+ )),
+ )));
+ new_id = newer_id;
+ }
+ Ok(new_id)
+ } else {
+ let new_id = self.id_def.new_id(typ);
+ for (idx, id) in desc.op.iter().enumerate() {
+ Self::insert_composite_read(
+ &mut self.post_stmts,
+ self.id_def,
+ (scalar_type, vec_len),
+ Some(*id),
+ (new_id, idx as u8),
+ );
+ }
+ Ok(new_id)
+ }
+ }
+}
+
+impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenArguments<'a, 'b> {
+ fn id(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: Option<ast::Type>,
+ ) -> Result<spirv::Word, TranslateError> {
+ self.reg(desc, t)
+ }
+
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ typ: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
+ match desc.op {
+ ast::Operand::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
+ ast::Operand::Imm(x) => self.immediate(desc.new_op(x), typ),
+ ast::Operand::RegOffset(reg, offset) => {
+ self.reg_offset(desc.new_op((reg, offset)), typ)
+ }
+ }
+ }
+
+ fn src_call_operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
+ typ: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
+ match desc.op {
+ ast::CallOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)),
+ ast::CallOperand::Imm(x) => self.immediate(desc.new_op(x), typ),
+ }
+ }
+
+ fn src_member_operand(
+ &mut self,
+ desc: ArgumentDescriptor<(spirv::Word, u8)>,
+ (scalar_type, vec_len): (ast::ScalarType, u8),
+ ) -> Result<spirv::Word, TranslateError> {
+ self.member_src(desc, (scalar_type, vec_len))
+ }
+
+ fn id_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
+ typ: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
+ match desc.op {
+ ast::IdOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
+ ast::IdOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ),
+ }
+ }
+
+ fn operand_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
+ typ: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
+ match desc.op {
+ ast::OperandOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
+ ast::OperandOrVector::RegOffset(r, imm) => self.reg_offset(desc.new_op((r, imm)), typ),
+ ast::OperandOrVector::Imm(imm) => self.immediate(desc.new_op(imm), typ),
+ ast::OperandOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ),
+ }
+ }
}
/*
@@ -1078,14 +1258,13 @@ fn insert_implicit_conversions( |arg| ast::Instruction::St(st, arg),
)
}
- ast::Instruction::Mov(d, mut arg) => {
+ ast::Instruction::Mov(d, ast::Arg2Mov::Normal(mut arg)) => {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov-2
// TODO: handle the case of mixed vector/scalar implicit conversions
let inst_typ_is_bit = match d.typ {
- ast::MovType::Scalar(t) => {
- ast::ScalarType::from(t).kind() == ScalarKind::Bit
- }
- ast::MovType::Vector(_, _) => false,
+ ast::Type::Scalar(t) => ast::ScalarType::from(t).kind() == ScalarKind::Bit,
+ ast::Type::Vector(_, _) => false,
+ ast::Type::Array(_, _) => false,
};
let mut did_vector_implicit = false;
let mut post_conv = None;
@@ -1115,12 +1294,15 @@ fn insert_implicit_conversions( }
}
if did_vector_implicit {
- result.push(Statement::Instruction(ast::Instruction::Mov(d, arg)));
+ result.push(Statement::Instruction(ast::Instruction::Mov(
+ d,
+ ast::Arg2Mov::Normal(arg),
+ )));
} else {
insert_implicit_bitcasts(
&mut result,
id_def,
- ast::Instruction::Mov(d, arg),
+ ast::Instruction::Mov(d, ast::Arg2Mov::Normal(arg)),
)?;
}
if let Some(post_conv) = post_conv {
@@ -1129,13 +1311,14 @@ fn insert_implicit_conversions( }
inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst)?,
},
- s @ Statement::Composite(_)
- | s @ Statement::Conditional(_)
+ Statement::Composite(c) => insert_implicit_bitcasts(&mut result, id_def, c)?,
+ s @ Statement::Conditional(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_)
| s @ Statement::LoadVar(_, _)
| s @ Statement::StoreVar(_, _)
+ | s @ Statement::Undef(_, _)
| s @ Statement::RetValue(_, _) => result.push(s),
Statement::Conversion(_) => unreachable!(),
}
@@ -1146,7 +1329,7 @@ fn insert_implicit_conversions( fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- method_decl: &ast::MethodDecl<ExpandedArgParams>,
+ method_decl: &ast::MethodDecl<spirv::Word>,
) -> (spirv::Word, spirv::Word) {
match method_decl {
ast::MethodDecl::Func(out_params, _, in_params) => map.get_or_add_fn(
@@ -1173,7 +1356,7 @@ fn emit_function_body_ops( map: &mut TypeWordMap,
opencl: spirv::Word,
func: &[ExpandedStatement],
-) -> Result<(), dr::Error> {
+) -> Result<(), TranslateError> {
for s in func {
match s {
Statement::Label(id) => {
@@ -1305,11 +1488,34 @@ fn emit_function_body_ops( }
// SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?,
- ast::Instruction::Mov(d, arg) => {
- let result_type =
- map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ)));
- builder.copy_object(result_type, Some(arg.dst), arg.src)?;
- }
+ ast::Instruction::Mov(d, arg) => match arg {
+ ast::Arg2Mov::Normal(ast::Arg2MovNormal { dst, src })
+ | ast::Arg2Mov::Member(ast::Arg2MovMember::Src(dst, src)) => {
+ let result_type =
+ map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ)));
+ builder.copy_object(result_type, Some(*dst), *src)?;
+ }
+ ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
+ dst,
+ composite_src,
+ scalar_src,
+ ))
+ | ast::Arg2Mov::Member(ast::Arg2MovMember::Both(
+ dst,
+ composite_src,
+ scalar_src,
+ )) => {
+ let result_type = map.get_or_add(builder, SpirvType::from(d.typ));
+ let result_id = Some(dst.0);
+ builder.composite_insert(
+ result_type,
+ result_id,
+ *scalar_src,
+ *composite_src,
+ [dst.1 as u32],
+ )?;
+ }
+ },
ast::Instruction::Mul(mul, arg) => match mul {
ast::MulDetails::Int(ref ctr) => {
emit_mul_int(builder, map, opencl, ctr, arg)?;
@@ -1361,31 +1567,6 @@ fn emit_function_body_ops( builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
ast::Instruction::SetpBool(_, _) => todo!(),
- ast::Instruction::MovVector(typ, arg) => match arg {
- ast::Arg2Vec::Dst((dst, dst_index), composite_src, src)
- | ast::Arg2Vec::Both((dst, dst_index), composite_src, src) => {
- let result_type = map.get_or_add(
- builder,
- SpirvType::Vector(
- SpirvScalarKey::from(ast::ScalarType::from(typ.typ)),
- typ.length,
- ),
- );
- let result_id = Some(*dst);
- builder.composite_insert(
- result_type,
- result_id,
- *src,
- *composite_src,
- [*dst_index as u32],
- )?;
- }
- ast::Arg2Vec::Src(dst, src) => {
- let result_type =
- map.get_or_add_scalar(builder, ast::ScalarType::from(typ.typ));
- builder.copy_object(result_type, Some(*dst), *src)?;
- }
- },
ast::Instruction::Mad(mad, arg) => match mad {
ast::MulDetails::Int(ref desc) => {
emit_mad_int(builder, map, opencl, desc, arg)?
@@ -1413,6 +1594,10 @@ fn emit_function_body_ops( [c.src_index],
)?;
}
+ Statement::Undef(t, id) => {
+ let result_type = map.get_or_add(builder, SpirvType::from(*t));
+ builder.undef(result_type, Some(*id));
+ }
}
}
Ok(())
@@ -2016,11 +2201,11 @@ impl<'a> GlobalStringIdResolver<'a> { fn start_fn<'b>(
&'b mut self,
- header: &'b ast::MethodDecl<'a, ast::ParsedArgParams<'a>>,
+ header: &'b ast::MethodDecl<'a, &'a str>,
) -> (
FnStringIdResolver<'a, 'b>,
GlobalFnDeclResolver<'a, 'b>,
- ast::MethodDecl<'a, ExpandedArgParams>,
+ ast::MethodDecl<'a, spirv::Word>,
) {
// In case a function decl was inserted earlier we want to use its id
let name_id = self.get_or_add_def(header.name());
@@ -2213,7 +2398,7 @@ impl<'b> MutableNumericIdResolver<'b> { enum Statement<I, P: ast::ArgParams> {
Label(u32),
- Variable(ast::Variable<ast::VariableType, P>),
+ Variable(ast::Variable<ast::VariableType, P::Id>),
Instruction(I),
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
@@ -2224,6 +2409,7 @@ enum Statement<I, P: ast::ArgParams> { Conversion(ImplicitConversion),
Constant(ConstantDefinition),
RetValue(ast::RetData, spirv::Word),
+ Undef(ast::Type, spirv::Word),
}
struct ResolvedCall<P: ast::ArgParams> {
@@ -2233,8 +2419,19 @@ struct ResolvedCall<P: ast::ArgParams> { pub param_list: Vec<(P::CallOperand, ast::FnArgumentType)>,
}
-impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
- fn map<To: ArgParamsEx<ID = spirv::Word>, V: ArgumentMapVisitor<From, To>>(
+impl<T: ast::ArgParams> ResolvedCall<T> {
+ fn cast<U: ast::ArgParams<CallOperand = T::CallOperand>>(self) -> ResolvedCall<U> {
+ ResolvedCall {
+ uniform: self.uniform,
+ ret_params: self.ret_params,
+ func: self.func,
+ param_list: self.param_list,
+ }
+ }
+}
+
+impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
+ fn map<To: ArgParamsEx<Id = spirv::Word>, V: ArgumentMapVisitor<From, To>>(
self,
visitor: &mut V,
) -> Result<ResolvedCall<To>, TranslateError> {
@@ -2242,7 +2439,7 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> { .ret_params
.into_iter()
.map::<Result<_, TranslateError>, _>(|(id, typ)| {
- let new_id = visitor.variable(
+ let new_id = visitor.id(
ArgumentDescriptor {
op: id,
is_dst: true,
@@ -2253,7 +2450,7 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> { Ok((new_id, typ))
})
.collect::<Result<Vec<_>, _>>()?;
- let func = visitor.variable(
+ let func = visitor.id(
ArgumentDescriptor {
op: self.func,
is_dst: false,
@@ -2285,7 +2482,7 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> { }
}
-impl VisitVariable for ResolvedCall<NormalizedArgParams> {
+impl VisitVariable for ResolvedCall<TypedArgParams> {
fn visit_variable<
'a,
F: FnMut(
@@ -2295,7 +2492,7 @@ impl VisitVariable for ResolvedCall<NormalizedArgParams> { >(
self,
f: &mut F,
- ) -> Result<UnadornedStatement, TranslateError> {
+ ) -> Result<TypedStatement, TranslateError> {
Ok(Statement::Call(self.map(f)?))
}
}
@@ -2314,16 +2511,16 @@ impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> { }
}
-pub trait ArgParamsEx: ast::ArgParams {
+pub trait ArgParamsEx: ast::ArgParams + Sized {
fn get_fn_decl<'x, 'b>(
- id: &Self::ID,
+ id: &Self::Id,
decl: &'b GlobalFnDeclResolver<'x, 'b>,
) -> Result<&'b FnDecl, TranslateError>;
}
impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
fn get_fn_decl<'x, 'b>(
- id: &Self::ID,
+ id: &Self::Id,
decl: &'b GlobalFnDeclResolver<'x, 'b>,
) -> Result<&'b FnDecl, TranslateError> {
decl.get_fn_decl_str(id)
@@ -2331,6 +2528,25 @@ impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> { }
enum NormalizedArgParams {}
+
+impl ast::ArgParams for NormalizedArgParams {
+ type Id = spirv::Word;
+ type Operand = ast::Operand<spirv::Word>;
+ type CallOperand = ast::CallOperand<spirv::Word>;
+ type IdOrVector = ast::IdOrVector<spirv::Word>;
+ type OperandOrVector = ast::OperandOrVector<spirv::Word>;
+ type SrcMemberOperand = (spirv::Word, u8);
+}
+
+impl ArgParamsEx for NormalizedArgParams {
+ fn get_fn_decl<'a, 'b>(
+ id: &Self::Id,
+ decl: &'b GlobalFnDeclResolver<'a, 'b>,
+ ) -> Result<&'b FnDecl, TranslateError> {
+ decl.get_fn_decl(*id)
+ }
+}
+
type NormalizedStatement = Statement<
(
Option<ast::PredAt<spirv::Word>>,
@@ -2338,81 +2554,100 @@ type NormalizedStatement = Statement< ),
NormalizedArgParams,
>;
-type UnadornedStatement = Statement<ast::Instruction<NormalizedArgParams>, NormalizedArgParams>;
-impl ast::ArgParams for NormalizedArgParams {
- type ID = spirv::Word;
+type UnconditionalStatement = Statement<ast::Instruction<NormalizedArgParams>, NormalizedArgParams>;
+
+enum TypedArgParams {}
+
+impl ast::ArgParams for TypedArgParams {
+ type Id = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
type CallOperand = ast::CallOperand<spirv::Word>;
- type VecOperand = (spirv::Word, u8);
+ type IdOrVector = ast::IdOrVector<spirv::Word>;
+ type OperandOrVector = ast::OperandOrVector<spirv::Word>;
+ type SrcMemberOperand = (spirv::Word, u8);
}
-impl ArgParamsEx for NormalizedArgParams {
+impl ArgParamsEx for TypedArgParams {
fn get_fn_decl<'a, 'b>(
- id: &Self::ID,
+ id: &Self::Id,
decl: &'b GlobalFnDeclResolver<'a, 'b>,
) -> Result<&'b FnDecl, TranslateError> {
decl.get_fn_decl(*id)
}
}
-#[derive(Copy, Clone)]
-pub enum StateSpace {
- Reg,
- Const,
- Global,
- Local,
- Shared,
- Param,
- ParamReg,
-}
+type TypedStatement = Statement<ast::Instruction<TypedArgParams>, TypedArgParams>;
enum ExpandedArgParams {}
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
-struct Function<'input> {
- pub func_directive: ast::MethodDecl<'input, ExpandedArgParams>,
- pub globals: Vec<ExpandedStatement>,
- pub body: Option<Vec<ExpandedStatement>>,
-}
-
impl ast::ArgParams for ExpandedArgParams {
- type ID = spirv::Word;
+ type Id = spirv::Word;
type Operand = spirv::Word;
type CallOperand = spirv::Word;
- type VecOperand = spirv::Word;
+ type IdOrVector = spirv::Word;
+ type OperandOrVector = spirv::Word;
+ type SrcMemberOperand = spirv::Word;
}
impl ArgParamsEx for ExpandedArgParams {
fn get_fn_decl<'a, 'b>(
- id: &Self::ID,
+ id: &Self::Id,
decl: &'b GlobalFnDeclResolver<'a, 'b>,
) -> Result<&'b FnDecl, TranslateError> {
decl.get_fn_decl(*id)
}
}
-trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
- fn variable(
+#[derive(Copy, Clone)]
+pub enum StateSpace {
+ Reg,
+ Const,
+ Global,
+ Local,
+ Shared,
+ Param,
+ ParamReg,
+}
+
+struct Function<'input> {
+ pub func_directive: ast::MethodDecl<'input, spirv::Word>,
+ pub globals: Vec<ExpandedStatement>,
+ pub body: Option<Vec<ExpandedStatement>>,
+}
+
+pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
+ fn id(
&mut self,
- desc: ArgumentDescriptor<T::ID>,
+ desc: ArgumentDescriptor<T::Id>,
typ: Option<ast::Type>,
- ) -> Result<U::ID, TranslateError>;
+ ) -> Result<U::Id, TranslateError>;
fn operand(
&mut self,
desc: ArgumentDescriptor<T::Operand>,
typ: ast::Type,
) -> Result<U::Operand, TranslateError>;
+ fn id_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<T::IdOrVector>,
+ typ: ast::Type,
+ ) -> Result<U::IdOrVector, TranslateError>;
+ fn operand_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<T::OperandOrVector>,
+ typ: ast::Type,
+ ) -> Result<U::OperandOrVector, TranslateError>;
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<T::CallOperand>,
typ: ast::Type,
) -> Result<U::CallOperand, TranslateError>;
- fn src_vec_operand(
+ fn src_member_operand(
&mut self,
- desc: ArgumentDescriptor<T::VecOperand>,
- typ: (ast::MovVectorType, u8),
- ) -> Result<U::VecOperand, TranslateError>;
+ desc: ArgumentDescriptor<T::SrcMemberOperand>,
+ typ: (ast::ScalarType, u8),
+ ) -> Result<U::SrcMemberOperand, TranslateError>;
}
impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
@@ -2422,7 +2657,7 @@ where Option<ast::Type>,
) -> Result<spirv::Word, TranslateError>,
{
- fn variable(
+ fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
t: Option<ast::Type>,
@@ -2438,6 +2673,22 @@ where self(desc, Some(t))
}
+ fn id_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ typ: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
+ self(desc, Some(typ))
+ }
+
+ fn operand_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ typ: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
+ self(desc, Some(typ))
+ }
+
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
@@ -2446,10 +2697,10 @@ where self(desc, Some(t))
}
- fn src_vec_operand(
+ fn src_member_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- (scalar_type, vec_len): (ast::MovVectorType, u8),
+ (scalar_type, vec_len): (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
self(
desc.new_op(desc.op),
@@ -2462,7 +2713,7 @@ impl<'a, T> ArgumentMapVisitor<ast::ParsedArgParams<'a>, NormalizedArgParams> fo where
T: FnMut(&str) -> Result<spirv::Word, TranslateError>,
{
- fn variable(
+ fn id(
&mut self,
desc: ArgumentDescriptor<&str>,
_: Option<ast::Type>,
@@ -2477,8 +2728,38 @@ where ) -> Result<ast::Operand<spirv::Word>, TranslateError> {
match desc.op {
ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(id)?)),
- ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)),
ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset(self(id)?, imm)),
+ ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)),
+ }
+ }
+
+ fn id_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<ast::IdOrVector<&'a str>>,
+ _: ast::Type,
+ ) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
+ match desc.op {
+ ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(id)?)),
+ ast::IdOrVector::Vec(ids) => Ok(ast::IdOrVector::Vec(
+ ids.into_iter().map(self).collect::<Result<_, _>>()?,
+ )),
+ }
+ }
+
+ fn operand_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<ast::OperandOrVector<&'a str>>,
+ _: ast::Type,
+ ) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
+ match desc.op {
+ ast::OperandOrVector::Reg(id) => Ok(ast::OperandOrVector::Reg(self(id)?)),
+ ast::OperandOrVector::RegOffset(id, imm) => {
+ Ok(ast::OperandOrVector::RegOffset(self(id)?, imm))
+ }
+ ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)),
+ ast::OperandOrVector::Vec(ids) => Ok(ast::OperandOrVector::Vec(
+ ids.into_iter().map(self).collect::<Result<_, _>>()?,
+ )),
}
}
@@ -2493,16 +2774,16 @@ where }
}
- fn src_vec_operand(
+ fn src_member_operand(
&mut self,
desc: ArgumentDescriptor<(&str, u8)>,
- _: (ast::MovVectorType, u8),
+ _: (ast::ScalarType, u8),
) -> Result<(spirv::Word, u8), TranslateError> {
Ok((self(desc.op.0)?, desc.op.1))
}
}
-struct ArgumentDescriptor<Op> {
+pub struct ArgumentDescriptor<Op> {
op: Op,
is_dst: bool,
sema: ArgumentSemantics,
@@ -2536,22 +2817,19 @@ impl<T: ArgParamsEx> ast::Instruction<T> { visitor: &mut V,
) -> Result<ast::Instruction<U>, TranslateError> {
Ok(match self {
- ast::Instruction::MovVector(t, a) => {
- ast::Instruction::MovVector(t, a.map(visitor, (t.typ, t.length))?)
- }
ast::Instruction::Abs(d, arg) => {
ast::Instruction::Abs(d, arg.map(visitor, false, ast::Type::Scalar(d.typ))?)
}
// Call instruction is converted to a call statement early on
- ast::Instruction::Call(_) => unreachable!(),
+ ast::Instruction::Call(_) => return Err(TranslateError::Unreachable),
ast::Instruction::Ld(d, a) => {
let inst_type = d.typ;
let is_param = d.state_space == ast::LdStateSpace::Param
|| d.state_space == ast::LdStateSpace::Local;
- ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, is_param)?)
+ ast::Instruction::Ld(d, a.map(visitor, inst_type, is_param)?)
}
ast::Instruction::Mov(d, a) => {
- let mapped = a.map(visitor, d.src_is_address, d.typ.into())?;
+ let mapped = a.map(visitor, d)?;
ast::Instruction::Mov(d, mapped)
}
ast::Instruction::Mul(d, a) => {
@@ -2617,7 +2895,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> { }
}
-impl VisitVariable for ast::Instruction<NormalizedArgParams> {
+impl VisitVariable for ast::Instruction<TypedArgParams> {
fn visit_variable<
'a,
F: FnMut(
@@ -2627,19 +2905,19 @@ impl VisitVariable for ast::Instruction<NormalizedArgParams> { >(
self,
f: &mut F,
- ) -> Result<UnadornedStatement, TranslateError> {
+ ) -> Result<TypedStatement, TranslateError> {
Ok(Statement::Instruction(self.map(f)?))
}
}
-impl<T> ArgumentMapVisitor<NormalizedArgParams, NormalizedArgParams> for T
+impl<T> ArgumentMapVisitor<TypedArgParams, TypedArgParams> for T
where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
Option<ast::Type>,
) -> Result<spirv::Word, TranslateError>,
{
- fn variable(
+ fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
t: Option<ast::Type>,
@@ -2673,10 +2951,47 @@ where }
}
- fn src_vec_operand(
+ fn id_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
+ typ: ast::Type,
+ ) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
+ match desc.op {
+ ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(desc.new_op(id), Some(typ))?)),
+ ast::IdOrVector::Vec(ref ids) => Ok(ast::IdOrVector::Vec(
+ ids.iter()
+ .map(|id| self(desc.new_op(*id), Some(typ)))
+ .collect::<Result<_, _>>()?,
+ )),
+ }
+ }
+
+ fn operand_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
+ typ: ast::Type,
+ ) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
+ match desc.op {
+ ast::OperandOrVector::Reg(id) => {
+ Ok(ast::OperandOrVector::Reg(self(desc.new_op(id), Some(typ))?))
+ }
+ ast::OperandOrVector::RegOffset(id, imm) => Ok(ast::OperandOrVector::RegOffset(
+ self(desc.new_op(id), Some(typ))?,
+ imm,
+ )),
+ ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)),
+ ast::OperandOrVector::Vec(ref ids) => Ok(ast::OperandOrVector::Vec(
+ ids.iter()
+ .map(|id| self(desc.new_op(*id), Some(typ)))
+ .collect::<Result<_, _>>()?,
+ )),
+ }
+ }
+
+ fn src_member_operand(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
- (scalar_type, vector_len): (ast::MovVectorType, u8),
+ (scalar_type, vector_len): (ast::ScalarType, u8),
) -> Result<(spirv::Word, u8), TranslateError> {
Ok((
self(
@@ -2750,7 +3065,6 @@ impl ast::Instruction<ExpandedArgParams> { ast::Instruction::Bra(_, a) => Some(a.src),
ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _)
- | ast::Instruction::MovVector(_, _)
| ast::Instruction::Mul(_, _)
| ast::Instruction::Add(_, _)
| ast::Instruction::Setp(_, _)
@@ -2786,12 +3100,44 @@ type Arg2 = ast::Arg2<ExpandedArgParams>; type Arg2St = ast::Arg2St<ExpandedArgParams>;
struct CompositeRead {
- pub typ: ast::MovVectorType,
+ pub typ: ast::ScalarType,
pub dst: spirv::Word,
pub src_composite: spirv::Word,
pub src_index: u32,
}
+impl VisitVariableExpanded for CompositeRead {
+ fn visit_variable_extended<
+ F: FnMut(
+ ArgumentDescriptor<spirv_headers::Word>,
+ Option<ast::Type>,
+ ) -> Result<spirv_headers::Word, TranslateError>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<ExpandedStatement, TranslateError> {
+ Ok(Statement::Composite(CompositeRead {
+ dst: f(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(ast::Type::Scalar(self.typ)),
+ )?,
+ src_composite: f(
+ ArgumentDescriptor {
+ op: self.src_composite,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(ast::Type::Scalar(self.typ)),
+ )?,
+ ..self
+ }))
+ }
+}
+
struct ConstantDefinition {
pub dst: spirv::Word,
pub typ: ast::ScalarType,
@@ -2875,12 +3221,16 @@ impl ast::VariableParamType { }
impl<T: ArgParamsEx> ast::Arg1<T> {
+ fn cast<U: ArgParamsEx<Id = T::Id>>(self) -> ast::Arg1<U> {
+ ast::Arg1 { src: self.src }
+ }
+
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: Option<ast::Type>,
) -> Result<ast::Arg1<U>, TranslateError> {
- let new_src = visitor.variable(
+ let new_src = visitor.id(
ArgumentDescriptor {
op: self.src,
is_dst: false,
@@ -2893,13 +3243,20 @@ impl<T: ArgParamsEx> ast::Arg1<T> { }
impl<T: ArgParamsEx> ast::Arg2<T> {
+ fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg2<U> {
+ ast::Arg2 {
+ src: self.src,
+ dst: self.dst,
+ }
+ }
+
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
src_is_addr: bool,
t: ast::Type,
) -> Result<ast::Arg2<U>, TranslateError> {
- let new_dst = visitor.variable(
+ let new_dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
@@ -2925,62 +3282,82 @@ impl<T: ArgParamsEx> ast::Arg2<T> { })
}
- fn map_ld<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ fn map_cvt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
- is_param: bool,
+ dst_t: ast::Type,
+ src_t: ast::Type,
) -> Result<ast::Arg2<U>, TranslateError> {
- let dst = visitor.variable(
+ let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(t),
+ Some(dst_t),
)?;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: if is_param {
- ArgumentSemantics::RegisterPointer
- } else {
- ArgumentSemantics::PhysicalPointer
- },
+ sema: ArgumentSemantics::Default,
},
- t,
+ src_t,
)?;
Ok(ast::Arg2 { dst, src })
}
+}
- fn map_cvt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+impl<T: ArgParamsEx> ast::Arg2Ld<T> {
+ fn cast<U: ArgParamsEx<Operand = T::Operand, IdOrVector = T::IdOrVector>>(
+ self,
+ ) -> ast::Arg2Ld<U> {
+ ast::Arg2Ld {
+ dst: self.dst,
+ src: self.src,
+ }
+ }
+
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- dst_t: ast::Type,
- src_t: ast::Type,
- ) -> Result<ast::Arg2<U>, TranslateError> {
- let dst = visitor.variable(
+ t: ast::Type,
+ is_param: bool,
+ ) -> Result<ast::Arg2Ld<U>, TranslateError> {
+ let dst = visitor.id_or_vector(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(dst_t),
+ t.into(),
)?;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ sema: if is_param {
+ ArgumentSemantics::RegisterPointer
+ } else {
+ ArgumentSemantics::PhysicalPointer
+ },
},
- src_t,
+ t,
)?;
- Ok(ast::Arg2 { dst, src })
+ Ok(ast::Arg2Ld { dst, src })
}
}
impl<T: ArgParamsEx> ast::Arg2St<T> {
+ fn cast<U: ArgParamsEx<Operand = T::Operand, OperandOrVector = T::OperandOrVector>>(
+ self,
+ ) -> ast::Arg2St<U> {
+ ast::Arg2St {
+ src1: self.src1,
+ src2: self.src2,
+ }
+ }
+
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -2999,7 +3376,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> { },
t,
)?;
- let src2 = visitor.operand(
+ let src2 = visitor.operand_or_vector(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
@@ -3011,105 +3388,191 @@ impl<T: ArgParamsEx> ast::Arg2St<T> { }
}
-impl<T: ArgParamsEx> ast::Arg2Vec<T> {
- fn dst(&self) -> &T::ID {
+impl<T: ArgParamsEx> ast::Arg2Mov<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ details: ast::MovDetails,
+ ) -> Result<ast::Arg2Mov<U>, TranslateError> {
+ Ok(match self {
+ ast::Arg2Mov::Normal(arg) => ast::Arg2Mov::Normal(arg.map(visitor, details)?),
+ ast::Arg2Mov::Member(arg) => ast::Arg2Mov::Member(arg.map(visitor, details)?),
+ })
+ }
+}
+
+impl<P: ArgParamsEx> ast::Arg2MovNormal<P> {
+ fn cast<U: ArgParamsEx<IdOrVector = P::IdOrVector, OperandOrVector = P::OperandOrVector>>(
+ self,
+ ) -> ast::Arg2MovNormal<U> {
+ ast::Arg2MovNormal {
+ dst: self.dst,
+ src: self.src,
+ }
+ }
+
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<P, U>>(
+ self,
+ visitor: &mut V,
+ details: ast::MovDetails,
+ ) -> Result<ast::Arg2MovNormal<U>, TranslateError> {
+ let dst = visitor.id_or_vector(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ details.typ.into(),
+ )?;
+ let src = visitor.operand_or_vector(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ sema: if details.src_is_address {
+ ArgumentSemantics::RegisterPointer
+ } else {
+ ArgumentSemantics::PhysicalPointer
+ },
+ },
+ details.typ.into(),
+ )?;
+ Ok(ast::Arg2MovNormal { dst, src })
+ }
+}
+
+impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
+ fn cast<U: ArgParamsEx<Id = T::Id, SrcMemberOperand = T::SrcMemberOperand>>(
+ self,
+ ) -> ast::Arg2MovMember<U> {
+ match self {
+ ast::Arg2MovMember::Dst(dst, src1, src2) => ast::Arg2MovMember::Dst(dst, src1, src2),
+ ast::Arg2MovMember::Src(dst, src) => ast::Arg2MovMember::Src(dst, src),
+ ast::Arg2MovMember::Both(dst, src1, src2) => ast::Arg2MovMember::Both(dst, src1, src2),
+ }
+ }
+
+ fn vector_dst(&self) -> Option<&T::Id> {
match self {
- ast::Arg2Vec::Dst((d, _), _, _)
- | ast::Arg2Vec::Src(d, _)
- | ast::Arg2Vec::Both((d, _), _, _) => d,
+ ast::Arg2MovMember::Src(_, _) => None,
+ ast::Arg2MovMember::Dst((d, _), _, _) | ast::Arg2MovMember::Both((d, _), _, _) => {
+ Some(d)
+ }
}
}
+ fn vector_src(&self) -> Option<&T::SrcMemberOperand> {
+ match self {
+ ast::Arg2MovMember::Src(_, d) | ast::Arg2MovMember::Both(_, _, d) => Some(d),
+ ast::Arg2MovMember::Dst(_, _, _) => None,
+ }
+ }
+}
+
+impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- (scalar_type, vec_len): (ast::MovVectorType, u8),
- ) -> Result<ast::Arg2Vec<U>, TranslateError> {
+ details: ast::MovDetails,
+ ) -> Result<ast::Arg2MovMember<U>, TranslateError> {
match self {
- ast::Arg2Vec::Dst((dst, len), composite_src, scalar_src) => {
- let dst = visitor.variable(
+ ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => {
+ let dst = visitor.id(
ArgumentDescriptor {
op: dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(scalar_type.into())),
+ Some(details.typ.into()),
)?;
- let src1 = visitor.variable(
+ let src1 = visitor.id(
ArgumentDescriptor {
op: composite_src,
is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(scalar_type.into())),
+ Some(details.typ.into()),
)?;
- let src2 = visitor.variable(
+ let src2 = visitor.id(
ArgumentDescriptor {
op: scalar_src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ sema: if details.src_is_address {
+ ArgumentSemantics::Address
+ } else {
+ ArgumentSemantics::Default
+ },
},
- Some(ast::Type::Scalar(scalar_type.into())),
+ Some(details.typ.into()),
)?;
- Ok(ast::Arg2Vec::Dst((dst, len), src1, src2))
+ Ok(ast::Arg2MovMember::Dst((dst, len), src1, src2))
}
- ast::Arg2Vec::Src(dst, src) => {
- let dst = visitor.variable(
+ ast::Arg2MovMember::Src(dst, src) => {
+ let dst = visitor.id(
ArgumentDescriptor {
op: dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(scalar_type.into())),
+ Some(details.typ.into()),
)?;
- let src = visitor.src_vec_operand(
+ let scalar_typ = details.typ.get_scalar()?;
+ let src = visitor.src_member_operand(
ArgumentDescriptor {
op: src,
is_dst: false,
sema: ArgumentSemantics::Default,
},
- (scalar_type, vec_len),
+ (scalar_typ.into(), details.src_width),
)?;
- Ok(ast::Arg2Vec::Src(dst, src))
+ Ok(ast::Arg2MovMember::Src(dst, src))
}
- ast::Arg2Vec::Both((dst, len), composite_src, src) => {
- let dst = visitor.variable(
+ ast::Arg2MovMember::Both((dst, len), composite_src, src) => {
+ let dst = visitor.id(
ArgumentDescriptor {
op: dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(scalar_type.into())),
+ Some(details.typ.into()),
)?;
- let composite_src = visitor.variable(
+ let composite_src = visitor.id(
ArgumentDescriptor {
op: composite_src,
is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(scalar_type.into())),
+ Some(details.typ.into()),
)?;
- let src = visitor.src_vec_operand(
+ let scalar_typ = details.typ.get_scalar()?;
+ let src = visitor.src_member_operand(
ArgumentDescriptor {
op: src,
is_dst: false,
sema: ArgumentSemantics::Default,
},
- (scalar_type, vec_len),
+ (scalar_typ.into(), details.src_width),
)?;
- Ok(ast::Arg2Vec::Both((dst, len), composite_src, src))
+ Ok(ast::Arg2MovMember::Both((dst, len), composite_src, src))
}
}
}
}
impl<T: ArgParamsEx> ast::Arg3<T> {
+ fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg3<U> {
+ ast::Arg3 {
+ dst: self.dst,
+ src1: self.src1,
+ src2: self.src2,
+ }
+ }
+
fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::Type,
) -> Result<ast::Arg3<U>, TranslateError> {
- let dst = visitor.variable(
+ let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
@@ -3141,7 +3604,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> { visitor: &mut V,
t: ast::Type,
) -> Result<ast::Arg3<U>, TranslateError> {
- let dst = visitor.variable(
+ let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
@@ -3170,12 +3633,21 @@ impl<T: ArgParamsEx> ast::Arg3<T> { }
impl<T: ArgParamsEx> ast::Arg4<T> {
+ fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg4<U> {
+ ast::Arg4 {
+ dst: self.dst,
+ src1: self.src1,
+ src2: self.src2,
+ src3: self.src3,
+ }
+ }
+
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::Type,
) -> Result<ast::Arg4<U>, TranslateError> {
- let dst = visitor.variable(
+ let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
@@ -3217,12 +3689,21 @@ impl<T: ArgParamsEx> ast::Arg4<T> { }
impl<T: ArgParamsEx> ast::Arg4Setp<T> {
+ fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg4Setp<U> {
+ ast::Arg4Setp {
+ dst1: self.dst1,
+ dst2: self.dst2,
+ src1: self.src1,
+ src2: self.src2,
+ }
+ }
+
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::Type,
) -> Result<ast::Arg4Setp<U>, TranslateError> {
- let dst1 = visitor.variable(
+ let dst1 = visitor.id(
ArgumentDescriptor {
op: self.dst1,
is_dst: true,
@@ -3233,7 +3714,7 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> { let dst2 = self
.dst2
.map(|dst2| {
- visitor.variable(
+ visitor.id(
ArgumentDescriptor {
op: dst2,
is_dst: true,
@@ -3269,12 +3750,22 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> { }
impl<T: ArgParamsEx> ast::Arg5<T> {
+ fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg5<U> {
+ ast::Arg5 {
+ dst1: self.dst1,
+ dst2: self.dst2,
+ src1: self.src1,
+ src2: self.src2,
+ src3: self.src3,
+ }
+ }
+
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::Type,
) -> Result<ast::Arg5<U>, TranslateError> {
- let dst1 = visitor.variable(
+ let dst1 = visitor.id(
ArgumentDescriptor {
op: self.dst1,
is_dst: true,
@@ -3285,7 +3776,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> { let dst2 = self
.dst2
.map(|dst2| {
- visitor.variable(
+ visitor.id(
ArgumentDescriptor {
op: dst2,
is_dst: true,
@@ -3329,6 +3820,22 @@ impl<T: ArgParamsEx> ast::Arg5<T> { }
}
+impl ast::Type {
+ fn get_vector(self) -> Result<(ast::ScalarType, u8), TranslateError> {
+ match self {
+ ast::Type::Vector(t, len) => Ok((t, len)),
+ _ => Err(TranslateError::MismatchedType),
+ }
+ }
+
+ fn get_scalar(self) -> Result<ast::ScalarType, TranslateError> {
+ match self {
+ ast::Type::Scalar(t) => Ok(t),
+ _ => Err(TranslateError::MismatchedType),
+ }
+ }
+}
+
impl<T> ast::CallOperand<T> {
fn map_variable<U, F: FnMut(T) -> Result<U, TranslateError>>(
self,
@@ -3528,13 +4035,21 @@ impl From<ast::FnArgumentType> for ast::VariableType { impl<T> ast::Operand<T> {
fn underlying(&self) -> Option<&T> {
match self {
- ast::Operand::Reg(r) => Some(r),
- ast::Operand::RegOffset(r, _) => Some(r),
+ ast::Operand::Reg(r) | ast::Operand::RegOffset(r, _) => Some(r),
ast::Operand::Imm(_) => None,
}
}
}
+impl<T> ast::OperandOrVector<T> {
+ fn single_underlying(&self) -> Option<&T> {
+ match self {
+ ast::OperandOrVector::Reg(r) | ast::OperandOrVector::RegOffset(r, _) => Some(r),
+ ast::OperandOrVector::Imm(_) | ast::OperandOrVector::Vec(_) => None,
+ }
+ }
+}
+
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
@@ -3891,7 +4406,7 @@ fn insert_implicit_bitcasts( }
Ok(())
}
-impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> {
+impl<'a> ast::MethodDecl<'a, &'a str> {
fn name(&self) -> &'a str {
match self {
ast::MethodDecl::Kernel(name, _) => name,
@@ -3900,8 +4415,8 @@ impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> { }
}
-impl<'a, P: ArgParamsEx<ID = spirv::Word>> ast::MethodDecl<'a, P> {
- fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument<P>)) {
+impl<'a> ast::MethodDecl<'a, spirv::Word> {
+ fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument<spirv::Word>)) {
match self {
ast::MethodDecl::Func(_, _, params) => params.iter().for_each(f),
ast::MethodDecl::Kernel(_, params) => params.iter().for_each(|arg| {
|