aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-09-27 23:51:34 +0200
committerAndrzej Janik <[email protected]>2020-09-30 19:27:29 +0200
commit1e0b35be4bd52ea9273ba36adf224ec0a47eb7f8 (patch)
tree26161415586497ec9876198d6a55e17342b740ae
parent7c26568cbf017c55b27b72a7fcfe7761ce31e33c (diff)
downloadZLUDA-1e0b35be4bd52ea9273ba36adf224ec0a47eb7f8.tar.gz
ZLUDA-1e0b35be4bd52ea9273ba36adf224ec0a47eb7f8.zip
Implement vector-destructuring mov/ld/st
-rw-r--r--ptx/src/ast.rs185
-rw-r--r--ptx/src/ptx.lalrpop136
-rw-r--r--ptx/src/test/spirv_run/ntid.spvtxt13
-rw-r--r--ptx/src/translate.rs1179
4 files changed, 1041 insertions, 472 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index acefdc1..7edfa70 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -35,6 +35,19 @@ macro_rules! sub_scalar_type {
}
}
}
+
+ impl std::convert::TryFrom<ScalarType> for $name {
+ type Error = ();
+
+ fn try_from(t: ScalarType) -> Result<Self, Self::Error> {
+ match t {
+ $(
+ ScalarType::$variant => Ok($name::$variant),
+ )+
+ _ => Err(()),
+ }
+ }
+ }
};
}
@@ -159,20 +172,20 @@ pub struct Module<'a> {
pub functions: Vec<ParsedFunction<'a>>,
}
-pub enum MethodDecl<'a, P: ArgParams> {
- Func(Vec<FnArgument<P>>, P::ID, Vec<FnArgument<P>>),
- Kernel(&'a str, Vec<KernelArgument<P>>),
+pub enum MethodDecl<'a, ID> {
+ Func(Vec<FnArgument<ID>>, ID, Vec<FnArgument<ID>>),
+ Kernel(&'a str, Vec<KernelArgument<ID>>),
}
-pub type FnArgument<P> = Variable<FnArgumentType, P>;
-pub type KernelArgument<P> = Variable<VariableParamType, P>;
+pub type FnArgument<ID> = Variable<FnArgumentType, ID>;
+pub type KernelArgument<ID> = Variable<VariableParamType, ID>;
-pub struct Function<'a, P: ArgParams, S> {
- pub func_directive: MethodDecl<'a, P>,
+pub struct Function<'a, ID, S> {
+ pub func_directive: MethodDecl<'a, ID>,
pub body: Option<Vec<S>>,
}
-pub type ParsedFunction<'a> = Function<'a, ParsedArgParams<'a>, Statement<ParsedArgParams<'a>>>;
+pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>;
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum FnArgumentType {
@@ -264,21 +277,21 @@ impl Default for ScalarType {
}
pub enum Statement<P: ArgParams> {
- Label(P::ID),
- Variable(MultiVariable<P>),
- Instruction(Option<PredAt<P::ID>>, Instruction<P>),
+ Label(P::Id),
+ Variable(MultiVariable<P::Id>),
+ Instruction(Option<PredAt<P::Id>>, Instruction<P>),
Block(Vec<Statement<P>>),
}
-pub struct MultiVariable<P: ArgParams> {
- pub var: Variable<VariableType, P>,
+pub struct MultiVariable<ID> {
+ pub var: Variable<VariableType, ID>,
pub count: Option<u32>,
}
-pub struct Variable<T, P: ArgParams> {
+pub struct Variable<T, ID> {
pub align: Option<u32>,
pub v_type: T,
- pub name: P::ID,
+ pub name: ID,
}
#[derive(Eq, PartialEq, Copy, Clone)]
@@ -315,9 +328,8 @@ pub struct PredAt<ID> {
}
pub enum Instruction<P: ArgParams> {
- Ld(LdData, Arg2<P>),
- Mov(MovDetails, Arg2<P>),
- MovVector(MovVectorDetails, Arg2Vec<P>),
+ Ld(LdDetails, Arg2Ld<P>),
+ Mov(MovDetails, Arg2Mov<P>),
Mul(MulDetails, Arg3<P>),
Add(AddDetails, Arg3<P>),
Setp(SetpData, Arg4Setp<P>),
@@ -338,11 +350,6 @@ pub enum Instruction<P: ArgParams> {
pub struct MadFloatDesc {}
#[derive(Copy, Clone)]
-pub struct MovVectorDetails {
- pub typ: MovVectorType,
- pub length: u8,
-}
-#[derive(Copy, Clone)]
pub struct AbsDetails {
pub flush_to_zero: bool,
pub typ: ScalarType,
@@ -350,16 +357,18 @@ pub struct AbsDetails {
pub struct CallInst<P: ArgParams> {
pub uniform: bool,
- pub ret_params: Vec<P::ID>,
- pub func: P::ID,
+ pub ret_params: Vec<P::Id>,
+ pub func: P::Id,
pub param_list: Vec<P::CallOperand>,
}
pub trait ArgParams {
- type ID;
+ type Id;
type Operand;
+ type IdOrVector;
+ type OperandOrVector;
type CallOperand;
- type VecOperand;
+ type SrcMemberOperand;
}
pub struct ParsedArgParams<'a> {
@@ -367,57 +376,73 @@ pub struct ParsedArgParams<'a> {
}
impl<'a> ArgParams for ParsedArgParams<'a> {
- type ID = &'a str;
+ type Id = &'a str;
type Operand = Operand<&'a str>;
type CallOperand = CallOperand<&'a str>;
- type VecOperand = (&'a str, u8);
+ type IdOrVector = IdOrVector<&'a str>;
+ type OperandOrVector = OperandOrVector<&'a str>;
+ type SrcMemberOperand = (&'a str, u8);
}
pub struct Arg1<P: ArgParams> {
- pub src: P::ID, // it is a jump destination, but in terms of operands it is a source operand
+ pub src: P::Id, // it is a jump destination, but in terms of operands it is a source operand
}
pub struct Arg2<P: ArgParams> {
- pub dst: P::ID,
+ pub dst: P::Id,
+ pub src: P::Operand,
+}
+pub struct Arg2Ld<P: ArgParams> {
+ pub dst: P::IdOrVector,
pub src: P::Operand,
}
pub struct Arg2St<P: ArgParams> {
pub src1: P::Operand,
- pub src2: P::Operand,
+ pub src2: P::OperandOrVector,
+}
+
+pub enum Arg2Mov<P: ArgParams> {
+ Normal(Arg2MovNormal<P>),
+ Member(Arg2MovMember<P>),
+}
+
+pub struct Arg2MovNormal<P: ArgParams> {
+ pub dst: P::IdOrVector,
+ pub src: P::OperandOrVector,
}
// We duplicate dst here because during further compilation
// composite dst and composite src will receive different ids
-pub enum Arg2Vec<P: ArgParams> {
- Dst((P::ID, u8), P::ID, P::ID),
- Src(P::ID, P::VecOperand),
- Both((P::ID, u8), P::ID, P::VecOperand),
+pub enum Arg2MovMember<P: ArgParams> {
+ Dst((P::Id, u8), P::Id, P::Id),
+ Src(P::Id, P::SrcMemberOperand),
+ Both((P::Id, u8), P::Id, P::SrcMemberOperand),
}
pub struct Arg3<P: ArgParams> {
- pub dst: P::ID,
+ pub dst: P::Id,
pub src1: P::Operand,
pub src2: P::Operand,
}
pub struct Arg4<P: ArgParams> {
- pub dst: P::ID,
+ pub dst: P::Id,
pub src1: P::Operand,
pub src2: P::Operand,
pub src3: P::Operand,
}
pub struct Arg4Setp<P: ArgParams> {
- pub dst1: P::ID,
- pub dst2: Option<P::ID>,
+ pub dst1: P::Id,
+ pub dst2: Option<P::Id>,
pub src1: P::Operand,
pub src2: P::Operand,
}
pub struct Arg5<P: ArgParams> {
- pub dst1: P::ID,
- pub dst2: Option<P::ID>,
+ pub dst1: P::Id,
+ pub dst2: Option<P::Id>,
pub src1: P::Operand,
pub src2: P::Operand,
pub src3: P::Operand,
@@ -436,12 +461,34 @@ pub enum CallOperand<ID> {
Imm(u32),
}
+pub enum IdOrVector<ID> {
+ Reg(ID),
+ Vec(Vec<ID>)
+}
+
+pub enum OperandOrVector<ID> {
+ Reg(ID),
+ RegOffset(ID, i32),
+ Imm(u32),
+ Vec(Vec<ID>)
+}
+
+impl<T> From<Operand<T>> for OperandOrVector<T> {
+ fn from(this: Operand<T>) -> Self {
+ match this {
+ Operand::Reg(r) => OperandOrVector::Reg(r),
+ Operand::RegOffset(r, imm) => OperandOrVector::RegOffset(r, imm),
+ Operand::Imm(imm) => OperandOrVector::Imm(imm),
+ }
+ }
+}
+
pub enum VectorPrefix {
V2,
V4,
}
-pub struct LdData {
+pub struct LdDetails {
pub qualifier: LdStQualifier,
pub state_space: LdStateSpace,
pub caching: LdCacheOperator,
@@ -482,45 +529,23 @@ pub enum LdCacheOperator {
Uncached,
}
-sub_scalar_type!(MovScalarType {
- B16,
- B32,
- B64,
- U16,
- U32,
- U64,
- S16,
- S32,
- S64,
- F32,
- F64,
- Pred,
-});
-
-// pred vectors are illegal
-sub_scalar_type!(MovVectorType {
- B16,
- B32,
- B64,
- U16,
- U32,
- U64,
- S16,
- S32,
- S64,
- F32,
- F64,
-});
-
+#[derive(Copy, Clone)]
pub struct MovDetails {
- pub typ: MovType,
+ pub typ: Type,
pub src_is_address: bool,
-}
-
-sub_type! {
- MovType {
- Scalar(MovScalarType),
- Vector(MovVectorType, u8),
+ // two fields below are in use by member moves
+ pub dst_width: u8,
+ pub src_width: u8,
+}
+
+impl MovDetails {
+ pub fn new(typ: Type) -> Self {
+ MovDetails {
+ typ,
+ src_is_address: false,
+ dst_width: 0,
+ src_width: 0
+ }
}
}
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 50a6aeb..ba3fc2b 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -194,7 +194,7 @@ TargetSpecifier = {
"map_f64_to_f32"
};
-Directive: Option<ast::Function<'input, ast::ParsedArgParams<'input>, ast::Statement<ast::ParsedArgParams<'input>>>> = {
+Directive: Option<ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>>> = {
AddressSize => None,
<f:Function> => Some(f),
File => None,
@@ -205,7 +205,7 @@ AddressSize = {
".address_size" Num
};
-Function: ast::Function<'input, ast::ParsedArgParams<'input>, ast::Statement<ast::ParsedArgParams<'input>>> = {
+Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = {
LinkingDirective*
<func_directive:MethodDecl>
<body:FunctionBody> => ast::Function{<>}
@@ -217,29 +217,29 @@ LinkingDirective = {
".weak"
};
-MethodDecl: ast::MethodDecl<'input, ast::ParsedArgParams<'input>> = {
+MethodDecl: ast::MethodDecl<'input, &'input str> = {
".entry" <name:ExtendedID> <params:KernelArguments> => ast::MethodDecl::Kernel(name, params),
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => {
ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
}
};
-KernelArguments: Vec<ast::KernelArgument<ast::ParsedArgParams<'input>>> = {
+KernelArguments: Vec<ast::KernelArgument<&'input str>> = {
"(" <args:Comma<KernelInput>> ")" => args
};
-FnArguments: Vec<ast::FnArgument<ast::ParsedArgParams<'input>>> = {
+FnArguments: Vec<ast::FnArgument<&'input str>> = {
"(" <args:Comma<FnInput>> ")" => args
};
-KernelInput: ast::Variable<ast::VariableParamType, ast::ParsedArgParams<'input>> = {
+KernelInput: ast::Variable<ast::VariableParamType, &'input str> = {
<v:ParamVariable> => {
let (align, v_type, name) = v;
ast::Variable{ align, v_type, name }
}
}
-FnInput: ast::Variable<ast::FnArgumentType, ast::ParsedArgParams<'input>> = {
+FnInput: ast::Variable<ast::FnArgumentType, &'input str> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
let v_type = ast::FnArgumentType::Reg(v_type);
@@ -320,7 +320,7 @@ Align: u32 = {
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
-MultiVariable: ast::MultiVariable<ast::ParsedArgParams<'input>> = {
+MultiVariable: ast::MultiVariable<&'input str> = {
<var:Variable> <count:VariableParam?> => ast::MultiVariable{<>}
}
@@ -331,7 +331,7 @@ VariableParam: u32 = {
}
}
-Variable: ast::Variable<ast::VariableType, ast::ParsedArgParams<'input>> = {
+Variable: ast::Variable<ast::VariableType, &'input str> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
let v_type = ast::VariableType::Reg(v_type);
@@ -356,7 +356,7 @@ RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
}
}
-LocalVariable: ast::Variable<ast::VariableType, ast::ParsedArgParams<'input>> = {
+LocalVariable: ast::Variable<ast::VariableType, &'input str> = {
".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> => {
let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t));
ast::Variable {align, v_type, name}
@@ -449,19 +449,29 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:ExtendedID> "," <src:MemoryOperand> => {
+ "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:IdOrVector> "," <src:MemoryOperand> => {
ast::Instruction::Ld(
- ast::LdData {
+ ast::LdDetails {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
state_space: ss.unwrap_or(ast::LdStateSpace::Generic),
caching: cop.unwrap_or(ast::LdCacheOperator::Cached),
typ: t
},
- ast::Arg2 { dst:dst, src:src }
+ ast::Arg2Ld { dst:dst, src:src }
)
}
};
+IdOrVector: ast::IdOrVector<&'input str> = {
+ <dst:ExtendedID> => ast::IdOrVector::Reg(dst),
+ <dst:VectorExtract> => ast::IdOrVector::Vec(dst)
+}
+
+OperandOrVector: ast::OperandOrVector<&'input str> = {
+ <op:Operand> => ast::OperandOrVector::from(op),
+ <dst:VectorExtract> => ast::OperandOrVector::Vec(dst)
+}
+
LdStType: ast::Type = {
<v:VectorPrefix> <t:LdStScalarType> => ast::Type::Vector(t, v),
<t:LdStScalarType> => ast::Type::Scalar(t),
@@ -498,49 +508,58 @@ LdCacheOperator: ast::LdCacheOperator = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
InstMov: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "mov" <t:MovType> <a:Arg2> => {
- ast::Instruction::Mov(ast::MovDetails{ src_is_address: false, typ: t }, a)
- },
- "mov" <t:MovVectorType> <a:Arg2Vec> => {
- ast::Instruction::MovVector(ast::MovVectorDetails{typ: t, length: 0}, a)
- }
+ <m:MovNormal> => ast::Instruction::Mov(m.0, m.1),
+ <m:MovVector> => ast::Instruction::Mov(m.0, m.1),
};
-#[inline]
-MovType: ast::MovType = {
- <t:MovScalarType> => ast::MovType::Scalar(t),
- <pref:VectorPrefix> <t:MovVectorType> => ast::MovType::Vector(t, pref)
+
+MovNormal: (ast::MovDetails, ast::Arg2Mov<ast::ParsedArgParams<'input>>) = {
+ "mov" <t:MovScalarType> <dst:ExtendedID> "," <src:Operand> => {(
+ ast::MovDetails::new(ast::Type::Scalar(t)),
+ ast::Arg2Mov::Normal(ast::Arg2MovNormal{ dst: ast::IdOrVector::Reg(dst), src: src.into() })
+ )},
+ "mov" <pref:VectorPrefix> <t:MovVectorType> <dst:IdOrVector> "," <src:OperandOrVector> => {(
+ ast::MovDetails::new(ast::Type::Vector(t, pref)),
+ ast::Arg2Mov::Normal(ast::Arg2MovNormal{ dst: dst, src: src })
+ )}
+}
+
+MovVector: (ast::MovDetails, ast::Arg2Mov<ast::ParsedArgParams<'input>>) = {
+ "mov" <t:MovVectorType> <a:Arg2MovMember> => {(
+ ast::MovDetails::new(ast::Type::Scalar(t.into())),
+ ast::Arg2Mov::Member(a)
+ )},
}
#[inline]
-MovScalarType: ast::MovScalarType = {
- ".b16" => ast::MovScalarType::B16,
- ".b32" => ast::MovScalarType::B32,
- ".b64" => ast::MovScalarType::B64,
- ".u16" => ast::MovScalarType::U16,
- ".u32" => ast::MovScalarType::U32,
- ".u64" => ast::MovScalarType::U64,
- ".s16" => ast::MovScalarType::S16,
- ".s32" => ast::MovScalarType::S32,
- ".s64" => ast::MovScalarType::S64,
- ".f32" => ast::MovScalarType::F32,
- ".f64" => ast::MovScalarType::F64,
- ".pred" => ast::MovScalarType::Pred
+MovScalarType: ast::ScalarType = {
+ ".b16" => ast::ScalarType::B16,
+ ".b32" => ast::ScalarType::B32,
+ ".b64" => ast::ScalarType::B64,
+ ".u16" => ast::ScalarType::U16,
+ ".u32" => ast::ScalarType::U32,
+ ".u64" => ast::ScalarType::U64,
+ ".s16" => ast::ScalarType::S16,
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
+ ".f32" => ast::ScalarType::F32,
+ ".f64" => ast::ScalarType::F64,
+ ".pred" => ast::ScalarType::Pred
};
#[inline]
-MovVectorType: ast::MovVectorType = {
- ".b16" => ast::MovVectorType::B16,
- ".b32" => ast::MovVectorType::B32,
- ".b64" => ast::MovVectorType::B64,
- ".u16" => ast::MovVectorType::U16,
- ".u32" => ast::MovVectorType::U32,
- ".u64" => ast::MovVectorType::U64,
- ".s16" => ast::MovVectorType::S16,
- ".s32" => ast::MovVectorType::S32,
- ".s64" => ast::MovVectorType::S64,
- ".f32" => ast::MovVectorType::F32,
- ".f64" => ast::MovVectorType::F64,
+MovVectorType: ast::ScalarType = {
+ ".b16" => ast::ScalarType::B16,
+ ".b32" => ast::ScalarType::B32,
+ ".b64" => ast::ScalarType::B64,
+ ".u16" => ast::ScalarType::U16,
+ ".u32" => ast::ScalarType::U32,
+ ".u64" => ast::ScalarType::U64,
+ ".s16" => ast::ScalarType::S16,
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
+ ".f32" => ast::ScalarType::F32,
+ ".f64" => ast::ScalarType::F64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
@@ -902,7 +921,7 @@ ShlType: ast::ShlType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> <src1:MemoryOperand> "," <src2:Operand> => {
+ "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> <src1:MemoryOperand> "," <src2:OperandOrVector> => {
ast::Instruction::St(
ast::StData {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
@@ -1044,13 +1063,13 @@ Arg2: ast::Arg2<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src:Operand> => ast::Arg2{<>}
};
-Arg2Vec: ast::Arg2Vec<ast::ParsedArgParams<'input>> = {
- <dst:VectorOperand> "," <src:ExtendedID> => ast::Arg2Vec::Dst(dst, dst.0, src),
- <dst:ExtendedID> "," <src:VectorOperand> => ast::Arg2Vec::Src(dst, src),
- <dst:VectorOperand> "," <src:VectorOperand> => ast::Arg2Vec::Both(dst, dst.0, src),
+Arg2MovMember: ast::Arg2MovMember<ast::ParsedArgParams<'input>> = {
+ <dst:MemberOperand> "," <src:ExtendedID> => ast::Arg2MovMember::Dst(dst, dst.0, src),
+ <dst:ExtendedID> "," <src:MemberOperand> => ast::Arg2MovMember::Src(dst, src),
+ <dst:MemberOperand> "," <src:MemberOperand> => ast::Arg2MovMember::Both(dst, dst.0, src),
};
-VectorOperand: (&'input str, u8) = {
+MemberOperand: (&'input str, u8) = {
<pref:ExtendedID> "." <suf:ExtendedID> =>? {
let suf_idx = vector_index(suf)?;
Ok((pref, suf_idx))
@@ -1061,6 +1080,15 @@ VectorOperand: (&'input str, u8) = {
}
};
+VectorExtract: Vec<&'input str> = {
+ "{" <r1:ExtendedID> "," <r2:ExtendedID> "}" => {
+ vec![r1, r2]
+ },
+ "{" <r1:ExtendedID> "," <r2:ExtendedID> "," <r3:ExtendedID> "," <r4:ExtendedID> "}" => {
+ vec![r1, r2, r3, r4]
+ },
+};
+
Arg3: ast::Arg3<ast::ParsedArgParams<'input>> = {
<dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> => ast::Arg3{<>}
};
diff --git a/ptx/src/test/spirv_run/ntid.spvtxt b/ptx/src/test/spirv_run/ntid.spvtxt
index ef308f0..be16d2e 100644
--- a/ptx/src/test/spirv_run/ntid.spvtxt
+++ b/ptx/src/test/spirv_run/ntid.spvtxt
@@ -4,15 +4,16 @@
OpCapability Kernel
OpCapability Int64
OpCapability Int8
+ OpCapability Float64
%29 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %1 "add" %GlobalSize
- OpDecorate %GlobalSize BuiltIn GlobalSize
+ OpEntryPoint Kernel %1 "ntid" %gl_WorkGroupSize
+ OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize
%void = OpTypeVoid
%uint = OpTypeInt 32 0
- %v3uint = OpTypeVector %uint 3
-%_ptr_UniformConstant_v3uint = OpTypePointer UniformConstant %v3uint
- %GlobalSize = OpVariable %_ptr_UniformConstant_v3uint UniformConstant
+ %v4uint = OpTypeVector %uint 4
+%_ptr_UniformConstant_v4uint = OpTypePointer UniformConstant %v4uint
+%gl_WorkGroupSize = OpVariable %_ptr_UniformConstant_v4uint UniformConstant
%ulong = OpTypeInt 64 0
%35 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
@@ -40,7 +41,7 @@
%25 = OpConvertUToPtr %_ptr_Generic_uint %16
%15 = OpLoad %uint %25
OpStore %6 %15
- %18 = OpLoad %v3uint %GlobalSize
+ %18 = OpLoad %v4uint %gl_WorkGroupSize
%24 = OpCompositeExtract %uint %18 0
%17 = OpCopyObject %uint %24
OpStore %7 %17
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index a1d4b6a..981da86 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,6 +1,7 @@
use crate::ast;
use rspirv::{binary::Disassemble, dr};
use std::collections::{hash_map, HashMap, HashSet};
+use std::convert::TryInto;
use std::{borrow::Cow, iter, mem};
use rspirv::binary::Assemble;
@@ -282,7 +283,7 @@ fn emit_function_header<'a>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
global: &GlobalStringIdResolver<'a>,
- func_directive: ast::MethodDecl<ExpandedArgParams>,
+ func_directive: ast::MethodDecl<spirv::Word>,
all_args_lens: &mut HashMap<String, Vec<usize>>,
) -> Result<(), TranslateError> {
if let ast::MethodDecl::Kernel(name, args) = &func_directive {
@@ -334,8 +335,10 @@ fn emit_capabilities(builder: &mut dr::Builder) {
builder.capability(spirv::Capability::Linkage);
builder.capability(spirv::Capability::Addresses);
builder.capability(spirv::Capability::Kernel);
- builder.capability(spirv::Capability::Int64);
builder.capability(spirv::Capability::Int8);
+ builder.capability(spirv::Capability::Int16);
+ builder.capability(spirv::Capability::Int64);
+ builder.capability(spirv::Capability::Float16);
builder.capability(spirv::Capability::Float64);
}
@@ -362,8 +365,8 @@ fn to_ssa_function<'a>(
fn expand_kernel_params<'a, 'b>(
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: impl Iterator<Item = &'b ast::KernelArgument<ast::ParsedArgParams<'a>>>,
-) -> Vec<ast::KernelArgument<ExpandedArgParams>> {
+ args: impl Iterator<Item = &'b ast::KernelArgument<&'a str>>,
+) -> Vec<ast::KernelArgument<spirv::Word>> {
args.map(|a| ast::KernelArgument {
name: fn_resolver.add_def(a.name, Some((StateSpace::Param, ast::Type::from(a.v_type)))),
v_type: a.v_type,
@@ -374,8 +377,8 @@ fn expand_kernel_params<'a, 'b>(
fn expand_fn_params<'a, 'b>(
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: impl Iterator<Item = &'b ast::FnArgument<ast::ParsedArgParams<'a>>>,
-) -> Vec<ast::FnArgument<ExpandedArgParams>> {
+ args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
+) -> Vec<ast::FnArgument<spirv::Word>> {
args.map(|a| {
let ss = match a.v_type {
ast::FnArgumentType::Reg(_) => StateSpace::Reg,
@@ -393,7 +396,7 @@ fn expand_fn_params<'a, 'b>(
fn to_ssa<'input, 'b>(
mut id_defs: FnStringIdResolver<'input, 'b>,
fn_defs: GlobalFnDeclResolver<'input, 'b>,
- f_args: ast::MethodDecl<'input, ExpandedArgParams>,
+ f_args: ast::MethodDecl<'input, spirv::Word>,
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
) -> Result<Function<'input>, TranslateError> {
let f_body = match f_body {
@@ -409,11 +412,11 @@ fn to_ssa<'input, 'b>(
let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?;
let mut numeric_id_defs = id_defs.finish();
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
- let unadorned_statements =
- add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?;
+ let typed_statements =
+ convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.finish();
let (f_args, ssa_statements) =
- insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args)?;
+ insert_mem_ssa_statements(typed_statements, &mut numeric_id_defs, f_args)?;
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
@@ -443,15 +446,16 @@ fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedSta
func
}
-fn add_types_to_statements(
- func: Vec<UnadornedStatement>,
+fn convert_to_typed_statements(
+ func: Vec<UnconditionalStatement>,
fn_defs: &GlobalFnDeclResolver,
id_defs: &NumericIdResolver,
-) -> Result<Vec<UnadornedStatement>, TranslateError> {
- func.into_iter()
- .map(|s| {
- match s {
- Statement::Instruction(ast::Instruction::Call(call)) => {
+) -> Result<Vec<TypedStatement>, TranslateError> {
+ let mut result = Vec::<TypedStatement>::with_capacity(func.len());
+ for s in func {
+ match s {
+ Statement::Instruction(inst) => match inst {
+ ast::Instruction::Call(call) => {
// TODO: error out if lengths don't match
let fn_def = fn_defs.get_fn_decl(call.func)?;
let ret_params = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
@@ -462,7 +466,7 @@ fn add_types_to_statements(
func: call.func,
param_list,
};
- Ok(Statement::Call(resolved_call))
+ result.push(Statement::Call(resolved_call));
}
// Supported ld/st:
// global: only compatible with reg b64/u64/s64 source/dest
@@ -477,25 +481,24 @@ fn add_types_to_statements(
// One complication: immediate address is only allowed in local,
// It is not supported in generic ld
// ld.local foo, [1];
- Statement::Instruction(ast::Instruction::Ld(mut d, arg)) => {
+ ast::Instruction::Ld(mut d, arg) => {
match arg.src.underlying() {
- None => return Ok(Statement::Instruction(ast::Instruction::Ld(d, arg))),
+ None => {}
Some(u) => {
let (ss, _) = id_defs.get_typed(*u)?;
match (d.state_space, ss) {
(ast::LdStateSpace::Generic, StateSpace::Local) => {
d.state_space = ast::LdStateSpace::Local;
}
- _ => (),
+ _ => {}
};
}
};
-
- Ok(Statement::Instruction(ast::Instruction::Ld(d, arg)))
+ result.push(Statement::Instruction(ast::Instruction::Ld(d, arg.cast())));
}
- Statement::Instruction(ast::Instruction::St(mut d, arg)) => {
+ ast::Instruction::St(mut d, arg) => {
match arg.src1.underlying() {
- None => return Ok(Statement::Instruction(ast::Instruction::St(d, arg))),
+ None => {}
Some(u) => {
let (ss, _) = id_defs.get_typed(*u)?;
match (d.state_space, ss) {
@@ -506,39 +509,101 @@ fn add_types_to_statements(
};
}
};
- Ok(Statement::Instruction(ast::Instruction::St(d, arg)))
+ result.push(Statement::Instruction(ast::Instruction::St(d, arg.cast())));
}
- Statement::Instruction(ast::Instruction::Mov(mut d, arg)) => {
- if let Some(src_id) = arg.src.underlying() {
- let (scope, _) = id_defs.get_typed(*src_id)?;
- d.src_is_address = match scope {
- StateSpace::Reg => false,
- StateSpace::Const
- | StateSpace::Global
- | StateSpace::Local
- | StateSpace::Shared
- | StateSpace::Param
- | StateSpace::ParamReg => true,
+ ast::Instruction::Mov(mut d, args) => match args {
+ ast::Arg2Mov::Normal(arg) => {
+ if let Some(src_id) = arg.src.single_underlying() {
+ let (scope, _) = id_defs.get_typed(*src_id)?;
+ d.src_is_address = match scope {
+ StateSpace::Reg => false,
+ StateSpace::Const
+ | StateSpace::Global
+ | StateSpace::Local
+ | StateSpace::Shared
+ | StateSpace::Param
+ | StateSpace::ParamReg => true,
+ };
+ }
+ result.push(Statement::Instruction(ast::Instruction::Mov(
+ d,
+ ast::Arg2Mov::Normal(arg.cast()),
+ )));
+ }
+ ast::Arg2Mov::Member(args) => {
+ if let Some(dst_typ) = args.vector_dst() {
+ match id_defs.get_typed(*dst_typ)? {
+ (_, ast::Type::Vector(_, len)) => {
+ d.dst_width = len;
+ }
+ _ => return Err(TranslateError::MismatchedType),
+ }
+ };
+ if let Some((src_typ, _)) = args.vector_src() {
+ match id_defs.get_typed(*src_typ)? {
+ (_, ast::Type::Vector(_, len)) => {
+ d.src_width = len;
+ }
+ _ => return Err(TranslateError::MismatchedType),
+ }
};
+ result.push(Statement::Instruction(ast::Instruction::Mov(
+ d,
+ ast::Arg2Mov::Member(args.cast()),
+ )));
}
- Ok(Statement::Instruction(ast::Instruction::Mov(d, arg)))
+ },
+ ast::Instruction::Mul(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Mul(d, a.cast())))
}
- Statement::Instruction(ast::Instruction::MovVector(dets, args)) => {
- let new_dets = match id_defs.get_typed(*args.dst())? {
- (_, ast::Type::Vector(_, len)) => ast::MovVectorDetails {
- length: len,
- ..dets
- },
- _ => dets,
- };
- Ok(Statement::Instruction(ast::Instruction::MovVector(
- new_dets, args,
- )))
+ ast::Instruction::Add(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Add(d, a.cast())))
}
- s => Ok(s),
- }
- })
- .collect::<Result<Vec<_>, _>>()
+ ast::Instruction::Setp(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Setp(d, a.cast())))
+ }
+ ast::Instruction::SetpBool(d, a) => result.push(Statement::Instruction(
+ ast::Instruction::SetpBool(d, a.cast()),
+ )),
+ ast::Instruction::Not(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Not(d, a.cast())))
+ }
+ ast::Instruction::Bra(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Bra(d, a.cast())))
+ }
+ ast::Instruction::Cvt(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Cvt(d, a.cast())))
+ }
+ ast::Instruction::Cvta(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Cvta(d, a.cast())))
+ }
+ ast::Instruction::Shl(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Shl(d, a.cast())))
+ }
+ ast::Instruction::Ret(d) => {
+ result.push(Statement::Instruction(ast::Instruction::Ret(d)))
+ }
+ ast::Instruction::Abs(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Abs(d, a.cast())))
+ }
+ ast::Instruction::Mad(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Mad(d, a.cast())))
+ }
+ },
+ Statement::Label(i) => result.push(Statement::Label(i)),
+ Statement::Variable(v) => result.push(Statement::Variable(v)),
+ Statement::LoadVar(a, t) => result.push(Statement::LoadVar(a, t)),
+ Statement::StoreVar(a, t) => result.push(Statement::StoreVar(a, t)),
+ Statement::Call(c) => result.push(Statement::Call(c.cast())),
+ Statement::Composite(c) => result.push(Statement::Composite(c)),
+ Statement::Conditional(c) => result.push(Statement::Conditional(c)),
+ Statement::Conversion(c) => result.push(Statement::Conversion(c)),
+ Statement::Constant(c) => result.push(Statement::Constant(c)),
+ Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
+ Statement::Undef(_, _) => return Err(TranslateError::Unreachable),
+ }
+ }
+ Ok(result)
}
fn to_resolved_fn_args<T>(
@@ -576,7 +641,8 @@ fn normalize_labels(
| Statement::RetValue(_, _)
| Statement::Conversion(_)
| Statement::Constant(_)
- | Statement::Label(_) => (),
+ | Statement::Label(_)
+ | Statement::Undef(_, _) => (),
}
}
iter::once(Statement::Label(id_def.new_id(None)))
@@ -590,7 +656,7 @@ fn normalize_labels(
fn normalize_predicates(
func: Vec<NormalizedStatement>,
id_def: &mut NumericIdResolver,
-) -> Vec<UnadornedStatement> {
+) -> Vec<UnconditionalStatement> {
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
@@ -630,16 +696,10 @@ fn normalize_predicates(
}
fn insert_mem_ssa_statements<'a, 'b>(
- func: Vec<UnadornedStatement>,
+ func: Vec<TypedStatement>,
id_def: &mut MutableNumericIdResolver,
- mut f_args: ast::MethodDecl<'a, ExpandedArgParams>,
-) -> Result<
- (
- ast::MethodDecl<'a, ExpandedArgParams>,
- Vec<UnadornedStatement>,
- ),
- TranslateError,
-> {
+ mut f_args: ast::MethodDecl<'a, spirv::Word>,
+) -> Result<(ast::MethodDecl<'a, spirv::Word>, Vec<TypedStatement>), TranslateError> {
let mut result = Vec::with_capacity(func.len());
let out_param = match &mut f_args {
ast::MethodDecl::Kernel(_, in_params) => {
@@ -697,7 +757,9 @@ fn insert_mem_ssa_statements<'a, 'b>(
};
for s in func {
match s {
- Statement::Call(call) => insert_mem_ssa_statement_default(id_def, &mut result, call)?,
+ Statement::Call(call) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, call.cast())?
+ }
Statement::Instruction(inst) => match inst {
ast::Instruction::Ret(d) => {
if let Some(out_param) = out_param {
@@ -734,7 +796,8 @@ fn insert_mem_ssa_statements<'a, 'b>(
| Statement::StoreVar(_, _)
| Statement::Conversion(_)
| Statement::RetValue(_, _)
- | Statement::Constant(_) => unreachable!(),
+ | Statement::Constant(_)
+ | Statement::Undef(_, _) => {}
Statement::Composite(_) => todo!(),
}
}
@@ -751,7 +814,7 @@ trait VisitVariable: Sized {
>(
self,
f: &mut F,
- ) -> Result<UnadornedStatement, TranslateError>;
+ ) -> Result<TypedStatement, TranslateError>;
}
trait VisitVariableExpanded {
fn visit_variable_extended<
@@ -767,7 +830,7 @@ trait VisitVariableExpanded {
fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
id_def: &mut MutableNumericIdResolver,
- result: &mut Vec<UnadornedStatement>,
+ result: &mut Vec<TypedStatement>,
stmt: F,
) -> Result<(), TranslateError> {
let mut post_statements = Vec::new();
@@ -808,7 +871,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
}
fn expand_arguments<'a, 'b>(
- func: Vec<UnadornedStatement>,
+ func: Vec<TypedStatement>,
id_def: &'b mut MutableNumericIdResolver<'a>,
) -> Result<Vec<ExpandedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
@@ -840,9 +903,10 @@ fn expand_arguments<'a, 'b>(
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
- Statement::Composite(_) | Statement::Conversion(_) | Statement::Constant(_) => {
- unreachable!()
- }
+ Statement::Composite(_)
+ | Statement::Conversion(_)
+ | Statement::Constant(_)
+ | Statement::Undef(_, _) => unreachable!(),
}
}
Ok(result)
@@ -865,12 +929,26 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
post_stmts: Vec::new(),
}
}
-}
-impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
- for FlattenArguments<'a, 'b>
-{
- fn variable(
+ fn insert_composite_read(
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut MutableNumericIdResolver<'a>,
+ (scalar_type, vec_len): (ast::ScalarType, u8),
+ scalar_dst: Option<spirv::Word>,
+ composite_src: (spirv::Word, u8),
+ ) -> spirv::Word {
+ let new_id =
+ scalar_dst.unwrap_or_else(|| id_def.new_id(ast::Type::Vector(scalar_type, vec_len)));
+ func.push(Statement::Composite(CompositeRead {
+ typ: scalar_type,
+ dst: new_id,
+ src_composite: composite_src.0,
+ src_index: composite_src.1 as u32,
+ }));
+ new_id
+ }
+
+ fn reg(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
_: Option<ast::Type>,
@@ -878,112 +956,105 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
Ok(desc.op)
}
- fn operand(
+ fn reg_offset(
&mut self,
- desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ desc: ArgumentDescriptor<(spirv::Word, i32)>,
typ: ast::Type,
) -> Result<spirv::Word, TranslateError> {
- match desc.op {
- ast::Operand::Reg(r) => self.variable(desc.new_op(r), Some(typ)),
- ast::Operand::Imm(x) => {
+ let (reg, offset) = desc.op;
+ match desc.sema {
+ ArgumentSemantics::Default => {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
scalar
} else {
todo!()
};
- let id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
+ let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
+ let result_id = self.id_def.new_id(typ);
self.func.push(Statement::Constant(ConstantDefinition {
- dst: id,
+ dst: id_constant_stmt,
typ: scalar_t,
- value: x as i64,
+ value: offset as i64,
}));
- Ok(id)
+ let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
+ self.func.push(Statement::Instruction(
+ ast::Instruction::<ExpandedArgParams>::Add(
+ ast::AddDetails::Int(ast::AddIntDesc {
+ typ: int_type,
+ saturate: false,
+ }),
+ ast::Arg3 {
+ dst: result_id,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ ),
+ ));
+ Ok(result_id)
}
- ast::Operand::RegOffset(reg, offset) => match desc.sema {
- ArgumentSemantics::Default => {
- let scalar_t = if let ast::Type::Scalar(scalar) = typ {
- scalar
- } else {
- todo!()
- };
- let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
- let result_id = self.id_def.new_id(typ);
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: scalar_t,
- value: offset as i64,
- }));
- let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- ast::AddDetails::Int(ast::AddIntDesc {
- typ: int_type,
- saturate: false,
- }),
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
- Ok(result_id)
- }
- ArgumentSemantics::PhysicalPointer => {
- let scalar_t = ast::ScalarType::U64;
- let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
- let result_id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: scalar_t,
- value: offset as i64,
- }));
- let int_type = ast::IntType::U64;
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- ast::AddDetails::Int(ast::AddIntDesc {
- typ: int_type,
- saturate: false,
- }),
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
- Ok(result_id)
- }
- ArgumentSemantics::RegisterPointer => {
- if offset == 0 {
- return Ok(reg);
- }
- todo!()
+ ArgumentSemantics::PhysicalPointer => {
+ let scalar_t = ast::ScalarType::U64;
+ let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
+ let result_id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: scalar_t,
+ value: offset as i64,
+ }));
+ let int_type = ast::IntType::U64;
+ self.func.push(Statement::Instruction(
+ ast::Instruction::<ExpandedArgParams>::Add(
+ ast::AddDetails::Int(ast::AddIntDesc {
+ typ: int_type,
+ saturate: false,
+ }),
+ ast::Arg3 {
+ dst: result_id,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ ),
+ ));
+ Ok(result_id)
+ }
+ ArgumentSemantics::RegisterPointer => {
+ if offset == 0 {
+ return Ok(reg);
}
- ArgumentSemantics::Address => todo!(),
- },
+ todo!()
+ }
+ ArgumentSemantics::Address => todo!(),
}
}
- fn src_call_operand(
+ fn immediate(
&mut self,
- desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
+ desc: ArgumentDescriptor<u32>,
typ: ast::Type,
) -> Result<spirv::Word, TranslateError> {
- match desc.op {
- ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg), Some(typ)),
- ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x)), typ),
- }
+ let scalar_t = if let ast::Type::Scalar(scalar) = typ {
+ scalar
+ } else {
+ todo!()
+ };
+ let id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id,
+ typ: scalar_t,
+ value: desc.op as i64,
+ }));
+ Ok(id)
}
- fn src_vec_operand(
+ fn member_src(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
- (scalar_type, vec_len): (ast::MovVectorType, u8),
+ (scalar_type, vec_len): (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
- let new_id = self
- .id_def
- .new_id(ast::Type::Vector(scalar_type.into(), vec_len));
+ if desc.is_dst {
+ return Err(TranslateError::Unreachable);
+ }
+ let new_id = self.id_def.new_id(ast::Type::Vector(scalar_type, vec_len));
self.func.push(Statement::Composite(CompositeRead {
typ: scalar_type,
dst: new_id,
@@ -992,6 +1063,115 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
}));
Ok(new_id)
}
+
+ fn vector(
+ &mut self,
+ desc: ArgumentDescriptor<&Vec<spirv::Word>>,
+ typ: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
+ let (scalar_type, vec_len) = typ.get_vector()?;
+ if !desc.is_dst {
+ let mut new_id = self.id_def.new_id(typ);
+ self.func.push(Statement::Undef(typ, new_id));
+ for (idx, id) in desc.op.iter().enumerate() {
+ let newer_id = self.id_def.new_id(typ);
+ self.func.push(Statement::Instruction(ast::Instruction::Mov(
+ ast::MovDetails {
+ typ: typ,
+ src_is_address: false,
+ dst_width: 0,
+ src_width: vec_len,
+ },
+ ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
+ (newer_id, idx as u8),
+ new_id,
+ *id,
+ )),
+ )));
+ new_id = newer_id;
+ }
+ Ok(new_id)
+ } else {
+ let new_id = self.id_def.new_id(typ);
+ for (idx, id) in desc.op.iter().enumerate() {
+ Self::insert_composite_read(
+ &mut self.post_stmts,
+ self.id_def,
+ (scalar_type, vec_len),
+ Some(*id),
+ (new_id, idx as u8),
+ );
+ }
+ Ok(new_id)
+ }
+ }
+}
+
+impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenArguments<'a, 'b> {
+ fn id(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: Option<ast::Type>,
+ ) -> Result<spirv::Word, TranslateError> {
+ self.reg(desc, t)
+ }
+
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ typ: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
+ match desc.op {
+ ast::Operand::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
+ ast::Operand::Imm(x) => self.immediate(desc.new_op(x), typ),
+ ast::Operand::RegOffset(reg, offset) => {
+ self.reg_offset(desc.new_op((reg, offset)), typ)
+ }
+ }
+ }
+
+ fn src_call_operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
+ typ: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
+ match desc.op {
+ ast::CallOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)),
+ ast::CallOperand::Imm(x) => self.immediate(desc.new_op(x), typ),
+ }
+ }
+
+ fn src_member_operand(
+ &mut self,
+ desc: ArgumentDescriptor<(spirv::Word, u8)>,
+ (scalar_type, vec_len): (ast::ScalarType, u8),
+ ) -> Result<spirv::Word, TranslateError> {
+ self.member_src(desc, (scalar_type, vec_len))
+ }
+
+ fn id_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
+ typ: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
+ match desc.op {
+ ast::IdOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
+ ast::IdOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ),
+ }
+ }
+
+ fn operand_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
+ typ: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
+ match desc.op {
+ ast::OperandOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
+ ast::OperandOrVector::RegOffset(r, imm) => self.reg_offset(desc.new_op((r, imm)), typ),
+ ast::OperandOrVector::Imm(imm) => self.immediate(desc.new_op(imm), typ),
+ ast::OperandOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ),
+ }
+ }
}
/*
@@ -1078,14 +1258,13 @@ fn insert_implicit_conversions(
|arg| ast::Instruction::St(st, arg),
)
}
- ast::Instruction::Mov(d, mut arg) => {
+ ast::Instruction::Mov(d, ast::Arg2Mov::Normal(mut arg)) => {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov-2
// TODO: handle the case of mixed vector/scalar implicit conversions
let inst_typ_is_bit = match d.typ {
- ast::MovType::Scalar(t) => {
- ast::ScalarType::from(t).kind() == ScalarKind::Bit
- }
- ast::MovType::Vector(_, _) => false,
+ ast::Type::Scalar(t) => ast::ScalarType::from(t).kind() == ScalarKind::Bit,
+ ast::Type::Vector(_, _) => false,
+ ast::Type::Array(_, _) => false,
};
let mut did_vector_implicit = false;
let mut post_conv = None;
@@ -1115,12 +1294,15 @@ fn insert_implicit_conversions(
}
}
if did_vector_implicit {
- result.push(Statement::Instruction(ast::Instruction::Mov(d, arg)));
+ result.push(Statement::Instruction(ast::Instruction::Mov(
+ d,
+ ast::Arg2Mov::Normal(arg),
+ )));
} else {
insert_implicit_bitcasts(
&mut result,
id_def,
- ast::Instruction::Mov(d, arg),
+ ast::Instruction::Mov(d, ast::Arg2Mov::Normal(arg)),
)?;
}
if let Some(post_conv) = post_conv {
@@ -1129,13 +1311,14 @@ fn insert_implicit_conversions(
}
inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst)?,
},
- s @ Statement::Composite(_)
- | s @ Statement::Conditional(_)
+ Statement::Composite(c) => insert_implicit_bitcasts(&mut result, id_def, c)?,
+ s @ Statement::Conditional(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_)
| s @ Statement::LoadVar(_, _)
| s @ Statement::StoreVar(_, _)
+ | s @ Statement::Undef(_, _)
| s @ Statement::RetValue(_, _) => result.push(s),
Statement::Conversion(_) => unreachable!(),
}
@@ -1146,7 +1329,7 @@ fn insert_implicit_conversions(
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- method_decl: &ast::MethodDecl<ExpandedArgParams>,
+ method_decl: &ast::MethodDecl<spirv::Word>,
) -> (spirv::Word, spirv::Word) {
match method_decl {
ast::MethodDecl::Func(out_params, _, in_params) => map.get_or_add_fn(
@@ -1173,7 +1356,7 @@ fn emit_function_body_ops(
map: &mut TypeWordMap,
opencl: spirv::Word,
func: &[ExpandedStatement],
-) -> Result<(), dr::Error> {
+) -> Result<(), TranslateError> {
for s in func {
match s {
Statement::Label(id) => {
@@ -1305,11 +1488,34 @@ fn emit_function_body_ops(
}
// SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?,
- ast::Instruction::Mov(d, arg) => {
- let result_type =
- map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ)));
- builder.copy_object(result_type, Some(arg.dst), arg.src)?;
- }
+ ast::Instruction::Mov(d, arg) => match arg {
+ ast::Arg2Mov::Normal(ast::Arg2MovNormal { dst, src })
+ | ast::Arg2Mov::Member(ast::Arg2MovMember::Src(dst, src)) => {
+ let result_type =
+ map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ)));
+ builder.copy_object(result_type, Some(*dst), *src)?;
+ }
+ ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
+ dst,
+ composite_src,
+ scalar_src,
+ ))
+ | ast::Arg2Mov::Member(ast::Arg2MovMember::Both(
+ dst,
+ composite_src,
+ scalar_src,
+ )) => {
+ let result_type = map.get_or_add(builder, SpirvType::from(d.typ));
+ let result_id = Some(dst.0);
+ builder.composite_insert(
+ result_type,
+ result_id,
+ *scalar_src,
+ *composite_src,
+ [dst.1 as u32],
+ )?;
+ }
+ },
ast::Instruction::Mul(mul, arg) => match mul {
ast::MulDetails::Int(ref ctr) => {
emit_mul_int(builder, map, opencl, ctr, arg)?;
@@ -1361,31 +1567,6 @@ fn emit_function_body_ops(
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
ast::Instruction::SetpBool(_, _) => todo!(),
- ast::Instruction::MovVector(typ, arg) => match arg {
- ast::Arg2Vec::Dst((dst, dst_index), composite_src, src)
- | ast::Arg2Vec::Both((dst, dst_index), composite_src, src) => {
- let result_type = map.get_or_add(
- builder,
- SpirvType::Vector(
- SpirvScalarKey::from(ast::ScalarType::from(typ.typ)),
- typ.length,
- ),
- );
- let result_id = Some(*dst);
- builder.composite_insert(
- result_type,
- result_id,
- *src,
- *composite_src,
- [*dst_index as u32],
- )?;
- }
- ast::Arg2Vec::Src(dst, src) => {
- let result_type =
- map.get_or_add_scalar(builder, ast::ScalarType::from(typ.typ));
- builder.copy_object(result_type, Some(*dst), *src)?;
- }
- },
ast::Instruction::Mad(mad, arg) => match mad {
ast::MulDetails::Int(ref desc) => {
emit_mad_int(builder, map, opencl, desc, arg)?
@@ -1413,6 +1594,10 @@ fn emit_function_body_ops(
[c.src_index],
)?;
}
+ Statement::Undef(t, id) => {
+ let result_type = map.get_or_add(builder, SpirvType::from(*t));
+ builder.undef(result_type, Some(*id));
+ }
}
}
Ok(())
@@ -2016,11 +2201,11 @@ impl<'a> GlobalStringIdResolver<'a> {
fn start_fn<'b>(
&'b mut self,
- header: &'b ast::MethodDecl<'a, ast::ParsedArgParams<'a>>,
+ header: &'b ast::MethodDecl<'a, &'a str>,
) -> (
FnStringIdResolver<'a, 'b>,
GlobalFnDeclResolver<'a, 'b>,
- ast::MethodDecl<'a, ExpandedArgParams>,
+ ast::MethodDecl<'a, spirv::Word>,
) {
// In case a function decl was inserted earlier we want to use its id
let name_id = self.get_or_add_def(header.name());
@@ -2213,7 +2398,7 @@ impl<'b> MutableNumericIdResolver<'b> {
enum Statement<I, P: ast::ArgParams> {
Label(u32),
- Variable(ast::Variable<ast::VariableType, P>),
+ Variable(ast::Variable<ast::VariableType, P::Id>),
Instruction(I),
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
@@ -2224,6 +2409,7 @@ enum Statement<I, P: ast::ArgParams> {
Conversion(ImplicitConversion),
Constant(ConstantDefinition),
RetValue(ast::RetData, spirv::Word),
+ Undef(ast::Type, spirv::Word),
}
struct ResolvedCall<P: ast::ArgParams> {
@@ -2233,8 +2419,19 @@ struct ResolvedCall<P: ast::ArgParams> {
pub param_list: Vec<(P::CallOperand, ast::FnArgumentType)>,
}
-impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
- fn map<To: ArgParamsEx<ID = spirv::Word>, V: ArgumentMapVisitor<From, To>>(
+impl<T: ast::ArgParams> ResolvedCall<T> {
+ fn cast<U: ast::ArgParams<CallOperand = T::CallOperand>>(self) -> ResolvedCall<U> {
+ ResolvedCall {
+ uniform: self.uniform,
+ ret_params: self.ret_params,
+ func: self.func,
+ param_list: self.param_list,
+ }
+ }
+}
+
+impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
+ fn map<To: ArgParamsEx<Id = spirv::Word>, V: ArgumentMapVisitor<From, To>>(
self,
visitor: &mut V,
) -> Result<ResolvedCall<To>, TranslateError> {
@@ -2242,7 +2439,7 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
.ret_params
.into_iter()
.map::<Result<_, TranslateError>, _>(|(id, typ)| {
- let new_id = visitor.variable(
+ let new_id = visitor.id(
ArgumentDescriptor {
op: id,
is_dst: true,
@@ -2253,7 +2450,7 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
Ok((new_id, typ))
})
.collect::<Result<Vec<_>, _>>()?;
- let func = visitor.variable(
+ let func = visitor.id(
ArgumentDescriptor {
op: self.func,
is_dst: false,
@@ -2285,7 +2482,7 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
}
}
-impl VisitVariable for ResolvedCall<NormalizedArgParams> {
+impl VisitVariable for ResolvedCall<TypedArgParams> {
fn visit_variable<
'a,
F: FnMut(
@@ -2295,7 +2492,7 @@ impl VisitVariable for ResolvedCall<NormalizedArgParams> {
>(
self,
f: &mut F,
- ) -> Result<UnadornedStatement, TranslateError> {
+ ) -> Result<TypedStatement, TranslateError> {
Ok(Statement::Call(self.map(f)?))
}
}
@@ -2314,16 +2511,16 @@ impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
}
}
-pub trait ArgParamsEx: ast::ArgParams {
+pub trait ArgParamsEx: ast::ArgParams + Sized {
fn get_fn_decl<'x, 'b>(
- id: &Self::ID,
+ id: &Self::Id,
decl: &'b GlobalFnDeclResolver<'x, 'b>,
) -> Result<&'b FnDecl, TranslateError>;
}
impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
fn get_fn_decl<'x, 'b>(
- id: &Self::ID,
+ id: &Self::Id,
decl: &'b GlobalFnDeclResolver<'x, 'b>,
) -> Result<&'b FnDecl, TranslateError> {
decl.get_fn_decl_str(id)
@@ -2331,6 +2528,25 @@ impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
}
enum NormalizedArgParams {}
+
+impl ast::ArgParams for NormalizedArgParams {
+ type Id = spirv::Word;
+ type Operand = ast::Operand<spirv::Word>;
+ type CallOperand = ast::CallOperand<spirv::Word>;
+ type IdOrVector = ast::IdOrVector<spirv::Word>;
+ type OperandOrVector = ast::OperandOrVector<spirv::Word>;
+ type SrcMemberOperand = (spirv::Word, u8);
+}
+
+impl ArgParamsEx for NormalizedArgParams {
+ fn get_fn_decl<'a, 'b>(
+ id: &Self::Id,
+ decl: &'b GlobalFnDeclResolver<'a, 'b>,
+ ) -> Result<&'b FnDecl, TranslateError> {
+ decl.get_fn_decl(*id)
+ }
+}
+
type NormalizedStatement = Statement<
(
Option<ast::PredAt<spirv::Word>>,
@@ -2338,81 +2554,100 @@ type NormalizedStatement = Statement<
),
NormalizedArgParams,
>;
-type UnadornedStatement = Statement<ast::Instruction<NormalizedArgParams>, NormalizedArgParams>;
-impl ast::ArgParams for NormalizedArgParams {
- type ID = spirv::Word;
+type UnconditionalStatement = Statement<ast::Instruction<NormalizedArgParams>, NormalizedArgParams>;
+
+enum TypedArgParams {}
+
+impl ast::ArgParams for TypedArgParams {
+ type Id = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
type CallOperand = ast::CallOperand<spirv::Word>;
- type VecOperand = (spirv::Word, u8);
+ type IdOrVector = ast::IdOrVector<spirv::Word>;
+ type OperandOrVector = ast::OperandOrVector<spirv::Word>;
+ type SrcMemberOperand = (spirv::Word, u8);
}
-impl ArgParamsEx for NormalizedArgParams {
+impl ArgParamsEx for TypedArgParams {
fn get_fn_decl<'a, 'b>(
- id: &Self::ID,
+ id: &Self::Id,
decl: &'b GlobalFnDeclResolver<'a, 'b>,
) -> Result<&'b FnDecl, TranslateError> {
decl.get_fn_decl(*id)
}
}
-#[derive(Copy, Clone)]
-pub enum StateSpace {
- Reg,
- Const,
- Global,
- Local,
- Shared,
- Param,
- ParamReg,
-}
+type TypedStatement = Statement<ast::Instruction<TypedArgParams>, TypedArgParams>;
enum ExpandedArgParams {}
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
-struct Function<'input> {
- pub func_directive: ast::MethodDecl<'input, ExpandedArgParams>,
- pub globals: Vec<ExpandedStatement>,
- pub body: Option<Vec<ExpandedStatement>>,
-}
-
impl ast::ArgParams for ExpandedArgParams {
- type ID = spirv::Word;
+ type Id = spirv::Word;
type Operand = spirv::Word;
type CallOperand = spirv::Word;
- type VecOperand = spirv::Word;
+ type IdOrVector = spirv::Word;
+ type OperandOrVector = spirv::Word;
+ type SrcMemberOperand = spirv::Word;
}
impl ArgParamsEx for ExpandedArgParams {
fn get_fn_decl<'a, 'b>(
- id: &Self::ID,
+ id: &Self::Id,
decl: &'b GlobalFnDeclResolver<'a, 'b>,
) -> Result<&'b FnDecl, TranslateError> {
decl.get_fn_decl(*id)
}
}
-trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
- fn variable(
+#[derive(Copy, Clone)]
+pub enum StateSpace {
+ Reg,
+ Const,
+ Global,
+ Local,
+ Shared,
+ Param,
+ ParamReg,
+}
+
+struct Function<'input> {
+ pub func_directive: ast::MethodDecl<'input, spirv::Word>,
+ pub globals: Vec<ExpandedStatement>,
+ pub body: Option<Vec<ExpandedStatement>>,
+}
+
+pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
+ fn id(
&mut self,
- desc: ArgumentDescriptor<T::ID>,
+ desc: ArgumentDescriptor<T::Id>,
typ: Option<ast::Type>,
- ) -> Result<U::ID, TranslateError>;
+ ) -> Result<U::Id, TranslateError>;
fn operand(
&mut self,
desc: ArgumentDescriptor<T::Operand>,
typ: ast::Type,
) -> Result<U::Operand, TranslateError>;
+ fn id_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<T::IdOrVector>,
+ typ: ast::Type,
+ ) -> Result<U::IdOrVector, TranslateError>;
+ fn operand_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<T::OperandOrVector>,
+ typ: ast::Type,
+ ) -> Result<U::OperandOrVector, TranslateError>;
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<T::CallOperand>,
typ: ast::Type,
) -> Result<U::CallOperand, TranslateError>;
- fn src_vec_operand(
+ fn src_member_operand(
&mut self,
- desc: ArgumentDescriptor<T::VecOperand>,
- typ: (ast::MovVectorType, u8),
- ) -> Result<U::VecOperand, TranslateError>;
+ desc: ArgumentDescriptor<T::SrcMemberOperand>,
+ typ: (ast::ScalarType, u8),
+ ) -> Result<U::SrcMemberOperand, TranslateError>;
}
impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
@@ -2422,7 +2657,7 @@ where
Option<ast::Type>,
) -> Result<spirv::Word, TranslateError>,
{
- fn variable(
+ fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
t: Option<ast::Type>,
@@ -2438,6 +2673,22 @@ where
self(desc, Some(t))
}
+ fn id_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ typ: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
+ self(desc, Some(typ))
+ }
+
+ fn operand_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ typ: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
+ self(desc, Some(typ))
+ }
+
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
@@ -2446,10 +2697,10 @@ where
self(desc, Some(t))
}
- fn src_vec_operand(
+ fn src_member_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- (scalar_type, vec_len): (ast::MovVectorType, u8),
+ (scalar_type, vec_len): (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
self(
desc.new_op(desc.op),
@@ -2462,7 +2713,7 @@ impl<'a, T> ArgumentMapVisitor<ast::ParsedArgParams<'a>, NormalizedArgParams> fo
where
T: FnMut(&str) -> Result<spirv::Word, TranslateError>,
{
- fn variable(
+ fn id(
&mut self,
desc: ArgumentDescriptor<&str>,
_: Option<ast::Type>,
@@ -2477,8 +2728,38 @@ where
) -> Result<ast::Operand<spirv::Word>, TranslateError> {
match desc.op {
ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(id)?)),
- ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)),
ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset(self(id)?, imm)),
+ ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)),
+ }
+ }
+
+ fn id_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<ast::IdOrVector<&'a str>>,
+ _: ast::Type,
+ ) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
+ match desc.op {
+ ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(id)?)),
+ ast::IdOrVector::Vec(ids) => Ok(ast::IdOrVector::Vec(
+ ids.into_iter().map(self).collect::<Result<_, _>>()?,
+ )),
+ }
+ }
+
+ fn operand_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<ast::OperandOrVector<&'a str>>,
+ _: ast::Type,
+ ) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
+ match desc.op {
+ ast::OperandOrVector::Reg(id) => Ok(ast::OperandOrVector::Reg(self(id)?)),
+ ast::OperandOrVector::RegOffset(id, imm) => {
+ Ok(ast::OperandOrVector::RegOffset(self(id)?, imm))
+ }
+ ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)),
+ ast::OperandOrVector::Vec(ids) => Ok(ast::OperandOrVector::Vec(
+ ids.into_iter().map(self).collect::<Result<_, _>>()?,
+ )),
}
}
@@ -2493,16 +2774,16 @@ where
}
}
- fn src_vec_operand(
+ fn src_member_operand(
&mut self,
desc: ArgumentDescriptor<(&str, u8)>,
- _: (ast::MovVectorType, u8),
+ _: (ast::ScalarType, u8),
) -> Result<(spirv::Word, u8), TranslateError> {
Ok((self(desc.op.0)?, desc.op.1))
}
}
-struct ArgumentDescriptor<Op> {
+pub struct ArgumentDescriptor<Op> {
op: Op,
is_dst: bool,
sema: ArgumentSemantics,
@@ -2536,22 +2817,19 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
visitor: &mut V,
) -> Result<ast::Instruction<U>, TranslateError> {
Ok(match self {
- ast::Instruction::MovVector(t, a) => {
- ast::Instruction::MovVector(t, a.map(visitor, (t.typ, t.length))?)
- }
ast::Instruction::Abs(d, arg) => {
ast::Instruction::Abs(d, arg.map(visitor, false, ast::Type::Scalar(d.typ))?)
}
// Call instruction is converted to a call statement early on
- ast::Instruction::Call(_) => unreachable!(),
+ ast::Instruction::Call(_) => return Err(TranslateError::Unreachable),
ast::Instruction::Ld(d, a) => {
let inst_type = d.typ;
let is_param = d.state_space == ast::LdStateSpace::Param
|| d.state_space == ast::LdStateSpace::Local;
- ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, is_param)?)
+ ast::Instruction::Ld(d, a.map(visitor, inst_type, is_param)?)
}
ast::Instruction::Mov(d, a) => {
- let mapped = a.map(visitor, d.src_is_address, d.typ.into())?;
+ let mapped = a.map(visitor, d)?;
ast::Instruction::Mov(d, mapped)
}
ast::Instruction::Mul(d, a) => {
@@ -2617,7 +2895,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
}
}
-impl VisitVariable for ast::Instruction<NormalizedArgParams> {
+impl VisitVariable for ast::Instruction<TypedArgParams> {
fn visit_variable<
'a,
F: FnMut(
@@ -2627,19 +2905,19 @@ impl VisitVariable for ast::Instruction<NormalizedArgParams> {
>(
self,
f: &mut F,
- ) -> Result<UnadornedStatement, TranslateError> {
+ ) -> Result<TypedStatement, TranslateError> {
Ok(Statement::Instruction(self.map(f)?))
}
}
-impl<T> ArgumentMapVisitor<NormalizedArgParams, NormalizedArgParams> for T
+impl<T> ArgumentMapVisitor<TypedArgParams, TypedArgParams> for T
where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
Option<ast::Type>,
) -> Result<spirv::Word, TranslateError>,
{
- fn variable(
+ fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
t: Option<ast::Type>,
@@ -2673,10 +2951,47 @@ where
}
}
- fn src_vec_operand(
+ fn id_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
+ typ: ast::Type,
+ ) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
+ match desc.op {
+ ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(desc.new_op(id), Some(typ))?)),
+ ast::IdOrVector::Vec(ref ids) => Ok(ast::IdOrVector::Vec(
+ ids.iter()
+ .map(|id| self(desc.new_op(*id), Some(typ)))
+ .collect::<Result<_, _>>()?,
+ )),
+ }
+ }
+
+ fn operand_or_vector(
+ &mut self,
+ desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
+ typ: ast::Type,
+ ) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
+ match desc.op {
+ ast::OperandOrVector::Reg(id) => {
+ Ok(ast::OperandOrVector::Reg(self(desc.new_op(id), Some(typ))?))
+ }
+ ast::OperandOrVector::RegOffset(id, imm) => Ok(ast::OperandOrVector::RegOffset(
+ self(desc.new_op(id), Some(typ))?,
+ imm,
+ )),
+ ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)),
+ ast::OperandOrVector::Vec(ref ids) => Ok(ast::OperandOrVector::Vec(
+ ids.iter()
+ .map(|id| self(desc.new_op(*id), Some(typ)))
+ .collect::<Result<_, _>>()?,
+ )),
+ }
+ }
+
+ fn src_member_operand(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
- (scalar_type, vector_len): (ast::MovVectorType, u8),
+ (scalar_type, vector_len): (ast::ScalarType, u8),
) -> Result<(spirv::Word, u8), TranslateError> {
Ok((
self(
@@ -2750,7 +3065,6 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Bra(_, a) => Some(a.src),
ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _)
- | ast::Instruction::MovVector(_, _)
| ast::Instruction::Mul(_, _)
| ast::Instruction::Add(_, _)
| ast::Instruction::Setp(_, _)
@@ -2786,12 +3100,44 @@ type Arg2 = ast::Arg2<ExpandedArgParams>;
type Arg2St = ast::Arg2St<ExpandedArgParams>;
struct CompositeRead {
- pub typ: ast::MovVectorType,
+ pub typ: ast::ScalarType,
pub dst: spirv::Word,
pub src_composite: spirv::Word,
pub src_index: u32,
}
+impl VisitVariableExpanded for CompositeRead {
+ fn visit_variable_extended<
+ F: FnMut(
+ ArgumentDescriptor<spirv_headers::Word>,
+ Option<ast::Type>,
+ ) -> Result<spirv_headers::Word, TranslateError>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<ExpandedStatement, TranslateError> {
+ Ok(Statement::Composite(CompositeRead {
+ dst: f(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(ast::Type::Scalar(self.typ)),
+ )?,
+ src_composite: f(
+ ArgumentDescriptor {
+ op: self.src_composite,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(ast::Type::Scalar(self.typ)),
+ )?,
+ ..self
+ }))
+ }
+}
+
struct ConstantDefinition {
pub dst: spirv::Word,
pub typ: ast::ScalarType,
@@ -2875,12 +3221,16 @@ impl ast::VariableParamType {
}
impl<T: ArgParamsEx> ast::Arg1<T> {
+ fn cast<U: ArgParamsEx<Id = T::Id>>(self) -> ast::Arg1<U> {
+ ast::Arg1 { src: self.src }
+ }
+
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: Option<ast::Type>,
) -> Result<ast::Arg1<U>, TranslateError> {
- let new_src = visitor.variable(
+ let new_src = visitor.id(
ArgumentDescriptor {
op: self.src,
is_dst: false,
@@ -2893,13 +3243,20 @@ impl<T: ArgParamsEx> ast::Arg1<T> {
}
impl<T: ArgParamsEx> ast::Arg2<T> {
+ fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg2<U> {
+ ast::Arg2 {
+ src: self.src,
+ dst: self.dst,
+ }
+ }
+
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
src_is_addr: bool,
t: ast::Type,
) -> Result<ast::Arg2<U>, TranslateError> {
- let new_dst = visitor.variable(
+ let new_dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
@@ -2925,62 +3282,82 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
})
}
- fn map_ld<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ fn map_cvt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
- is_param: bool,
+ dst_t: ast::Type,
+ src_t: ast::Type,
) -> Result<ast::Arg2<U>, TranslateError> {
- let dst = visitor.variable(
+ let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(t),
+ Some(dst_t),
)?;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: if is_param {
- ArgumentSemantics::RegisterPointer
- } else {
- ArgumentSemantics::PhysicalPointer
- },
+ sema: ArgumentSemantics::Default,
},
- t,
+ src_t,
)?;
Ok(ast::Arg2 { dst, src })
}
+}
- fn map_cvt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+impl<T: ArgParamsEx> ast::Arg2Ld<T> {
+ fn cast<U: ArgParamsEx<Operand = T::Operand, IdOrVector = T::IdOrVector>>(
+ self,
+ ) -> ast::Arg2Ld<U> {
+ ast::Arg2Ld {
+ dst: self.dst,
+ src: self.src,
+ }
+ }
+
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- dst_t: ast::Type,
- src_t: ast::Type,
- ) -> Result<ast::Arg2<U>, TranslateError> {
- let dst = visitor.variable(
+ t: ast::Type,
+ is_param: bool,
+ ) -> Result<ast::Arg2Ld<U>, TranslateError> {
+ let dst = visitor.id_or_vector(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(dst_t),
+ t.into(),
)?;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ sema: if is_param {
+ ArgumentSemantics::RegisterPointer
+ } else {
+ ArgumentSemantics::PhysicalPointer
+ },
},
- src_t,
+ t,
)?;
- Ok(ast::Arg2 { dst, src })
+ Ok(ast::Arg2Ld { dst, src })
}
}
impl<T: ArgParamsEx> ast::Arg2St<T> {
+ fn cast<U: ArgParamsEx<Operand = T::Operand, OperandOrVector = T::OperandOrVector>>(
+ self,
+ ) -> ast::Arg2St<U> {
+ ast::Arg2St {
+ src1: self.src1,
+ src2: self.src2,
+ }
+ }
+
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -2999,7 +3376,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
},
t,
)?;
- let src2 = visitor.operand(
+ let src2 = visitor.operand_or_vector(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
@@ -3011,105 +3388,191 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
}
}
-impl<T: ArgParamsEx> ast::Arg2Vec<T> {
- fn dst(&self) -> &T::ID {
+impl<T: ArgParamsEx> ast::Arg2Mov<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ details: ast::MovDetails,
+ ) -> Result<ast::Arg2Mov<U>, TranslateError> {
+ Ok(match self {
+ ast::Arg2Mov::Normal(arg) => ast::Arg2Mov::Normal(arg.map(visitor, details)?),
+ ast::Arg2Mov::Member(arg) => ast::Arg2Mov::Member(arg.map(visitor, details)?),
+ })
+ }
+}
+
+impl<P: ArgParamsEx> ast::Arg2MovNormal<P> {
+ fn cast<U: ArgParamsEx<IdOrVector = P::IdOrVector, OperandOrVector = P::OperandOrVector>>(
+ self,
+ ) -> ast::Arg2MovNormal<U> {
+ ast::Arg2MovNormal {
+ dst: self.dst,
+ src: self.src,
+ }
+ }
+
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<P, U>>(
+ self,
+ visitor: &mut V,
+ details: ast::MovDetails,
+ ) -> Result<ast::Arg2MovNormal<U>, TranslateError> {
+ let dst = visitor.id_or_vector(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ details.typ.into(),
+ )?;
+ let src = visitor.operand_or_vector(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ sema: if details.src_is_address {
+ ArgumentSemantics::RegisterPointer
+ } else {
+ ArgumentSemantics::PhysicalPointer
+ },
+ },
+ details.typ.into(),
+ )?;
+ Ok(ast::Arg2MovNormal { dst, src })
+ }
+}
+
+impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
+ fn cast<U: ArgParamsEx<Id = T::Id, SrcMemberOperand = T::SrcMemberOperand>>(
+ self,
+ ) -> ast::Arg2MovMember<U> {
+ match self {
+ ast::Arg2MovMember::Dst(dst, src1, src2) => ast::Arg2MovMember::Dst(dst, src1, src2),
+ ast::Arg2MovMember::Src(dst, src) => ast::Arg2MovMember::Src(dst, src),
+ ast::Arg2MovMember::Both(dst, src1, src2) => ast::Arg2MovMember::Both(dst, src1, src2),
+ }
+ }
+
+ fn vector_dst(&self) -> Option<&T::Id> {
match self {
- ast::Arg2Vec::Dst((d, _), _, _)
- | ast::Arg2Vec::Src(d, _)
- | ast::Arg2Vec::Both((d, _), _, _) => d,
+ ast::Arg2MovMember::Src(_, _) => None,
+ ast::Arg2MovMember::Dst((d, _), _, _) | ast::Arg2MovMember::Both((d, _), _, _) => {
+ Some(d)
+ }
}
}
+ fn vector_src(&self) -> Option<&T::SrcMemberOperand> {
+ match self {
+ ast::Arg2MovMember::Src(_, d) | ast::Arg2MovMember::Both(_, _, d) => Some(d),
+ ast::Arg2MovMember::Dst(_, _, _) => None,
+ }
+ }
+}
+
+impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- (scalar_type, vec_len): (ast::MovVectorType, u8),
- ) -> Result<ast::Arg2Vec<U>, TranslateError> {
+ details: ast::MovDetails,
+ ) -> Result<ast::Arg2MovMember<U>, TranslateError> {
match self {
- ast::Arg2Vec::Dst((dst, len), composite_src, scalar_src) => {
- let dst = visitor.variable(
+ ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => {
+ let dst = visitor.id(
ArgumentDescriptor {
op: dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(scalar_type.into())),
+ Some(details.typ.into()),
)?;
- let src1 = visitor.variable(
+ let src1 = visitor.id(
ArgumentDescriptor {
op: composite_src,
is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(scalar_type.into())),
+ Some(details.typ.into()),
)?;
- let src2 = visitor.variable(
+ let src2 = visitor.id(
ArgumentDescriptor {
op: scalar_src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ sema: if details.src_is_address {
+ ArgumentSemantics::Address
+ } else {
+ ArgumentSemantics::Default
+ },
},
- Some(ast::Type::Scalar(scalar_type.into())),
+ Some(details.typ.into()),
)?;
- Ok(ast::Arg2Vec::Dst((dst, len), src1, src2))
+ Ok(ast::Arg2MovMember::Dst((dst, len), src1, src2))
}
- ast::Arg2Vec::Src(dst, src) => {
- let dst = visitor.variable(
+ ast::Arg2MovMember::Src(dst, src) => {
+ let dst = visitor.id(
ArgumentDescriptor {
op: dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(scalar_type.into())),
+ Some(details.typ.into()),
)?;
- let src = visitor.src_vec_operand(
+ let scalar_typ = details.typ.get_scalar()?;
+ let src = visitor.src_member_operand(
ArgumentDescriptor {
op: src,
is_dst: false,
sema: ArgumentSemantics::Default,
},
- (scalar_type, vec_len),
+ (scalar_typ.into(), details.src_width),
)?;
- Ok(ast::Arg2Vec::Src(dst, src))
+ Ok(ast::Arg2MovMember::Src(dst, src))
}
- ast::Arg2Vec::Both((dst, len), composite_src, src) => {
- let dst = visitor.variable(
+ ast::Arg2MovMember::Both((dst, len), composite_src, src) => {
+ let dst = visitor.id(
ArgumentDescriptor {
op: dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(scalar_type.into())),
+ Some(details.typ.into()),
)?;
- let composite_src = visitor.variable(
+ let composite_src = visitor.id(
ArgumentDescriptor {
op: composite_src,
is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(scalar_type.into())),
+ Some(details.typ.into()),
)?;
- let src = visitor.src_vec_operand(
+ let scalar_typ = details.typ.get_scalar()?;
+ let src = visitor.src_member_operand(
ArgumentDescriptor {
op: src,
is_dst: false,
sema: ArgumentSemantics::Default,
},
- (scalar_type, vec_len),
+ (scalar_typ.into(), details.src_width),
)?;
- Ok(ast::Arg2Vec::Both((dst, len), composite_src, src))
+ Ok(ast::Arg2MovMember::Both((dst, len), composite_src, src))
}
}
}
}
impl<T: ArgParamsEx> ast::Arg3<T> {
+ fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg3<U> {
+ ast::Arg3 {
+ dst: self.dst,
+ src1: self.src1,
+ src2: self.src2,
+ }
+ }
+
fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::Type,
) -> Result<ast::Arg3<U>, TranslateError> {
- let dst = visitor.variable(
+ let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
@@ -3141,7 +3604,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
visitor: &mut V,
t: ast::Type,
) -> Result<ast::Arg3<U>, TranslateError> {
- let dst = visitor.variable(
+ let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
@@ -3170,12 +3633,21 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
}
impl<T: ArgParamsEx> ast::Arg4<T> {
+ fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg4<U> {
+ ast::Arg4 {
+ dst: self.dst,
+ src1: self.src1,
+ src2: self.src2,
+ src3: self.src3,
+ }
+ }
+
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::Type,
) -> Result<ast::Arg4<U>, TranslateError> {
- let dst = visitor.variable(
+ let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
@@ -3217,12 +3689,21 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
}
impl<T: ArgParamsEx> ast::Arg4Setp<T> {
+ fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg4Setp<U> {
+ ast::Arg4Setp {
+ dst1: self.dst1,
+ dst2: self.dst2,
+ src1: self.src1,
+ src2: self.src2,
+ }
+ }
+
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::Type,
) -> Result<ast::Arg4Setp<U>, TranslateError> {
- let dst1 = visitor.variable(
+ let dst1 = visitor.id(
ArgumentDescriptor {
op: self.dst1,
is_dst: true,
@@ -3233,7 +3714,7 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
let dst2 = self
.dst2
.map(|dst2| {
- visitor.variable(
+ visitor.id(
ArgumentDescriptor {
op: dst2,
is_dst: true,
@@ -3269,12 +3750,22 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
}
impl<T: ArgParamsEx> ast::Arg5<T> {
+ fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg5<U> {
+ ast::Arg5 {
+ dst1: self.dst1,
+ dst2: self.dst2,
+ src1: self.src1,
+ src2: self.src2,
+ src3: self.src3,
+ }
+ }
+
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::Type,
) -> Result<ast::Arg5<U>, TranslateError> {
- let dst1 = visitor.variable(
+ let dst1 = visitor.id(
ArgumentDescriptor {
op: self.dst1,
is_dst: true,
@@ -3285,7 +3776,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
let dst2 = self
.dst2
.map(|dst2| {
- visitor.variable(
+ visitor.id(
ArgumentDescriptor {
op: dst2,
is_dst: true,
@@ -3329,6 +3820,22 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
}
}
+impl ast::Type {
+ fn get_vector(self) -> Result<(ast::ScalarType, u8), TranslateError> {
+ match self {
+ ast::Type::Vector(t, len) => Ok((t, len)),
+ _ => Err(TranslateError::MismatchedType),
+ }
+ }
+
+ fn get_scalar(self) -> Result<ast::ScalarType, TranslateError> {
+ match self {
+ ast::Type::Scalar(t) => Ok(t),
+ _ => Err(TranslateError::MismatchedType),
+ }
+ }
+}
+
impl<T> ast::CallOperand<T> {
fn map_variable<U, F: FnMut(T) -> Result<U, TranslateError>>(
self,
@@ -3528,13 +4035,21 @@ impl From<ast::FnArgumentType> for ast::VariableType {
impl<T> ast::Operand<T> {
fn underlying(&self) -> Option<&T> {
match self {
- ast::Operand::Reg(r) => Some(r),
- ast::Operand::RegOffset(r, _) => Some(r),
+ ast::Operand::Reg(r) | ast::Operand::RegOffset(r, _) => Some(r),
ast::Operand::Imm(_) => None,
}
}
}
+impl<T> ast::OperandOrVector<T> {
+ fn single_underlying(&self) -> Option<&T> {
+ match self {
+ ast::OperandOrVector::Reg(r) | ast::OperandOrVector::RegOffset(r, _) => Some(r),
+ ast::OperandOrVector::Imm(_) | ast::OperandOrVector::Vec(_) => None,
+ }
+ }
+}
+
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
@@ -3891,7 +4406,7 @@ fn insert_implicit_bitcasts(
}
Ok(())
}
-impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> {
+impl<'a> ast::MethodDecl<'a, &'a str> {
fn name(&self) -> &'a str {
match self {
ast::MethodDecl::Kernel(name, _) => name,
@@ -3900,8 +4415,8 @@ impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> {
}
}
-impl<'a, P: ArgParamsEx<ID = spirv::Word>> ast::MethodDecl<'a, P> {
- fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument<P>)) {
+impl<'a> ast::MethodDecl<'a, spirv::Word> {
+ fn visit_args(&self, f: &mut impl FnMut(&ast::FnArgument<spirv::Word>)) {
match self {
ast::MethodDecl::Func(_, _, params) => params.iter().for_each(f),
ast::MethodDecl::Kernel(_, params) => params.iter().for_each(|arg| {