aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--ptx/src/ast.rs81
-rw-r--r--ptx/src/ptx.lalrpop121
-rw-r--r--ptx/src/test/spirv_run/vector.spvtxt124
-rw-r--r--ptx/src/test/spirv_run/vector_extract.spvtxt183
-rw-r--r--ptx/src/translate.rs1996
5 files changed, 993 insertions, 1512 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 367f060..aba6bda 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -557,7 +557,7 @@ pub enum Instruction<P: ArgParams> {
Mul(MulDetails, Arg3<P>),
Add(ArithDetails, Arg3<P>),
Setp(SetpData, Arg4Setp<P>),
- SetpBool(SetpBoolData, Arg5<P>),
+ SetpBool(SetpBoolData, Arg5Setp<P>),
Not(BooleanType, Arg2<P>),
Bra(BraData, Arg1<P>),
Cvt(CvtDetails, Arg2<P>),
@@ -614,16 +614,12 @@ pub struct CallInst<P: ArgParams> {
pub uniform: bool,
pub ret_params: Vec<P::Id>,
pub func: P::Id,
- pub param_list: Vec<P::CallOperand>,
+ pub param_list: Vec<P::Operand>,
}
pub trait ArgParams {
type Id;
type Operand;
- type IdOrVector;
- type OperandOrVector;
- type CallOperand;
- type SrcMemberOperand;
}
pub struct ParsedArgParams<'a> {
@@ -633,10 +629,6 @@ pub struct ParsedArgParams<'a> {
impl<'a> ArgParams for ParsedArgParams<'a> {
type Id = &'a str;
type Operand = Operand<&'a str>;
- type CallOperand = CallOperand<&'a str>;
- type IdOrVector = IdOrVector<&'a str>;
- type OperandOrVector = OperandOrVector<&'a str>;
- type SrcMemberOperand = (&'a str, u8);
}
pub struct Arg1<P: ArgParams> {
@@ -648,45 +640,32 @@ pub struct Arg1Bar<P: ArgParams> {
}
pub struct Arg2<P: ArgParams> {
- pub dst: P::Id,
+ pub dst: P::Operand,
pub src: P::Operand,
}
pub struct Arg2Ld<P: ArgParams> {
- pub dst: P::IdOrVector,
+ pub dst: P::Operand,
pub src: P::Operand,
}
pub struct Arg2St<P: ArgParams> {
pub src1: P::Operand,
- pub src2: P::OperandOrVector,
-}
-
-pub enum Arg2Mov<P: ArgParams> {
- Normal(Arg2MovNormal<P>),
- Member(Arg2MovMember<P>),
-}
-
-pub struct Arg2MovNormal<P: ArgParams> {
- pub dst: P::IdOrVector,
- pub src: P::OperandOrVector,
+ pub src2: P::Operand,
}
-// We duplicate dst here because during further compilation
-// composite dst and composite src will receive different ids
-pub enum Arg2MovMember<P: ArgParams> {
- Dst((P::Id, u8), P::Id, P::Id),
- Src(P::Id, P::SrcMemberOperand),
- Both((P::Id, u8), P::Id, P::SrcMemberOperand),
+pub struct Arg2Mov<P: ArgParams> {
+ pub dst: P::Operand,
+ pub src: P::Operand,
}
pub struct Arg3<P: ArgParams> {
- pub dst: P::Id,
+ pub dst: P::Operand,
pub src1: P::Operand,
pub src2: P::Operand,
}
pub struct Arg4<P: ArgParams> {
- pub dst: P::Id,
+ pub dst: P::Operand,
pub src1: P::Operand,
pub src2: P::Operand,
pub src3: P::Operand,
@@ -699,7 +678,7 @@ pub struct Arg4Setp<P: ArgParams> {
pub src2: P::Operand,
}
-pub struct Arg5<P: ArgParams> {
+pub struct Arg5Setp<P: ArgParams> {
pub dst1: P::Id,
pub dst2: Option<P::Id>,
pub src1: P::Operand,
@@ -715,39 +694,13 @@ pub enum ImmediateValue {
F64(f64),
}
-#[derive(Copy, Clone)]
-pub enum Operand<ID> {
- Reg(ID),
- RegOffset(ID, i32),
- Imm(ImmediateValue),
-}
-
-#[derive(Copy, Clone)]
-pub enum CallOperand<ID> {
- Reg(ID),
- Imm(ImmediateValue),
-}
-
-pub enum IdOrVector<ID> {
- Reg(ID),
- Vec(Vec<ID>),
-}
-
-pub enum OperandOrVector<ID> {
- Reg(ID),
- RegOffset(ID, i32),
+#[derive(Clone)]
+pub enum Operand<Id> {
+ Reg(Id),
+ RegOffset(Id, i32),
Imm(ImmediateValue),
- Vec(Vec<ID>),
-}
-
-impl<T> From<Operand<T>> for OperandOrVector<T> {
- fn from(this: Operand<T>) -> Self {
- match this {
- Operand::Reg(r) => OperandOrVector::Reg(r),
- Operand::RegOffset(r, imm) => OperandOrVector::RegOffset(r, imm),
- Operand::Imm(imm) => OperandOrVector::Imm(imm),
- }
- }
+ VecMember(Id, u8),
+ VecPack(Vec<Id>),
}
pub enum VectorPrefix {
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index d2c235a..fd2a3f1 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -721,7 +721,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:IdOrVector> "," <src:MemoryOperand> => {
+ "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:DstOperandVec> "," <src:MemoryOperand> => {
ast::Instruction::Ld(
ast::LdDetails {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
@@ -734,16 +734,6 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
}
};
-IdOrVector: ast::IdOrVector<&'input str> = {
- <dst:ExtendedID> => ast::IdOrVector::Reg(dst),
- <dst:VectorExtract> => ast::IdOrVector::Vec(dst)
-}
-
-OperandOrVector: ast::OperandOrVector<&'input str> = {
- <op:Operand> => ast::OperandOrVector::from(op),
- <dst:VectorExtract> => ast::OperandOrVector::Vec(dst)
-}
-
LdStType: ast::LdStType = {
<v:VectorPrefix> <t:LdStScalarType> => ast::LdStType::Vector(t, v),
<t:LdStScalarType> => ast::LdStType::Scalar(t),
@@ -780,27 +770,17 @@ LdCacheOperator: ast::LdCacheOperator = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
InstMov: ast::Instruction<ast::ParsedArgParams<'input>> = {
- <m:MovNormal> => ast::Instruction::Mov(m.0, m.1),
- <m:MovVector> => ast::Instruction::Mov(m.0, m.1),
-};
-
-
-MovNormal: (ast::MovDetails, ast::Arg2Mov<ast::ParsedArgParams<'input>>) = {
- "mov" <t:MovScalarType> <dst:ExtendedID> "," <src:Operand> => {(
- ast::MovDetails::new(ast::Type::Scalar(t)),
- ast::Arg2Mov::Normal(ast::Arg2MovNormal{ dst: ast::IdOrVector::Reg(dst), src: src.into() })
- )},
- "mov" <pref:VectorPrefix> <t:MovVectorType> <dst:IdOrVector> "," <src:OperandOrVector> => {(
- ast::MovDetails::new(ast::Type::Vector(t, pref)),
- ast::Arg2Mov::Normal(ast::Arg2MovNormal{ dst: dst, src: src })
- )}
-}
-
-MovVector: (ast::MovDetails, ast::Arg2Mov<ast::ParsedArgParams<'input>>) = {
- "mov" <t:MovVectorType> <a:Arg2MovMember> => {(
- ast::MovDetails::new(ast::Type::Scalar(t.into())),
- ast::Arg2Mov::Member(a)
- )},
+ "mov" <pref:VectorPrefix?> <t:MovScalarType> <dst:DstOperandVec> "," <src:SrcOperandVec> => {
+ let mov_type = match pref {
+ Some(vec_width) => ast::Type::Vector(t, vec_width),
+ None => ast::Type::Scalar(t)
+ };
+ let details = ast::MovDetails::new(mov_type);
+ ast::Instruction::Mov(
+ details,
+ ast::Arg2Mov { dst, src }
+ )
+ }
}
#[inline]
@@ -819,21 +799,6 @@ MovScalarType: ast::ScalarType = {
".pred" => ast::ScalarType::Pred
};
-#[inline]
-MovVectorType: ast::ScalarType = {
- ".b16" => ast::ScalarType::B16,
- ".b32" => ast::ScalarType::B32,
- ".b64" => ast::ScalarType::B64,
- ".u16" => ast::ScalarType::U16,
- ".u32" => ast::ScalarType::U32,
- ".u64" => ast::ScalarType::U64,
- ".s16" => ast::ScalarType::S16,
- ".s32" => ast::ScalarType::S32,
- ".s64" => ast::ScalarType::S64,
- ".f32" => ast::ScalarType::F32,
- ".f64" => ast::ScalarType::F64,
-};
-
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
@@ -921,7 +886,7 @@ InstAdd: ast::Instruction<ast::ParsedArgParams<'input>> = {
// TODO: support f16 setp
InstSetp: ast::Instruction<ast::ParsedArgParams<'input>> = {
"setp" <d:SetpMode> <a:Arg4Setp> => ast::Instruction::Setp(d, a),
- "setp" <d:SetpBoolMode> <a:Arg5> => ast::Instruction::SetpBool(d, a),
+ "setp" <d:SetpBoolMode> <a:Arg5Setp> => ast::Instruction::SetpBool(d, a),
};
SetpMode: ast::SetpData = {
@@ -1198,7 +1163,7 @@ ShrType: ast::ShrType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> <src1:MemoryOperand> "," <src2:OperandOrVector> => {
+ "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> <src1:MemoryOperand> "," <src2:SrcOperandVec> => {
ast::Instruction::St(
ast::StData {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
@@ -1775,9 +1740,9 @@ Operand: ast::Operand<&'input str> = {
<x:ImmediateValue> => ast::Operand::Imm(x)
};
-CallOperand: ast::CallOperand<&'input str> = {
- <r:ExtendedID> => ast::CallOperand::Reg(r),
- <x:ImmediateValue> => ast::CallOperand::Imm(x)
+CallOperand: ast::Operand<&'input str> = {
+ <r:ExtendedID> => ast::Operand::Reg(r),
+ <x:ImmediateValue> => ast::Operand::Imm(x)
};
// TODO: start parsing whole constants sub-language:
@@ -1825,13 +1790,7 @@ Arg1Bar: ast::Arg1Bar<ast::ParsedArgParams<'input>> = {
};
Arg2: ast::Arg2<ast::ParsedArgParams<'input>> = {
- <dst:ExtendedID> "," <src:Operand> => ast::Arg2{<>}
-};
-
-Arg2MovMember: ast::Arg2MovMember<ast::ParsedArgParams<'input>> = {
- <dst:MemberOperand> "," <src:ExtendedID> => ast::Arg2MovMember::Dst(dst, dst.0, src),
- <dst:ExtendedID> "," <src:MemberOperand> => ast::Arg2MovMember::Src(dst, src),
- <dst:MemberOperand> "," <src:MemberOperand> => ast::Arg2MovMember::Both(dst, dst.0, src),
+ <dst:DstOperand> "," <src:Operand> => ast::Arg2{<>}
};
MemberOperand: (&'input str, u8) = {
@@ -1855,19 +1814,19 @@ VectorExtract: Vec<&'input str> = {
};
Arg3: ast::Arg3<ast::ParsedArgParams<'input>> = {
- <dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> => ast::Arg3{<>}
+ <dst:DstOperand> "," <src1:Operand> "," <src2:Operand> => ast::Arg3{<>}
};
Arg3Atom: ast::Arg3<ast::ParsedArgParams<'input>> = {
- <dst:ExtendedID> "," "[" <src1:Operand> "]" "," <src2:Operand> => ast::Arg3{<>}
+ <dst:DstOperand> "," "[" <src1:Operand> "]" "," <src2:Operand> => ast::Arg3{<>}
};
Arg4: ast::Arg4<ast::ParsedArgParams<'input>> = {
- <dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> "," <src3:Operand> => ast::Arg4{<>}
+ <dst:DstOperand> "," <src1:Operand> "," <src2:Operand> "," <src3:Operand> => ast::Arg4{<>}
};
Arg4Atom: ast::Arg4<ast::ParsedArgParams<'input>> = {
- <dst:ExtendedID> "," "[" <src1:Operand> "]" "," <src2:Operand> "," <src3:Operand> => ast::Arg4{<>}
+ <dst:DstOperand> "," "[" <src1:Operand> "]" "," <src2:Operand> "," <src3:Operand> => ast::Arg4{<>}
};
Arg4Setp: ast::Arg4Setp<ast::ParsedArgParams<'input>> = {
@@ -1875,22 +1834,50 @@ Arg4Setp: ast::Arg4Setp<ast::ParsedArgParams<'input>> = {
};
// TODO: pass src3 negation somewhere
-Arg5: ast::Arg5<ast::ParsedArgParams<'input>> = {
- <dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> "," "!"? <src3:Operand> => ast::Arg5{<>}
+Arg5Setp: ast::Arg5Setp<ast::ParsedArgParams<'input>> = {
+ <dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> "," "!"? <src3:Operand> => ast::Arg5Setp{<>}
};
-ArgCall: (Vec<&'input str>, &'input str, Vec<ast::CallOperand<&'input str>>) = {
+ArgCall: (Vec<&'input str>, &'input str, Vec<ast::Operand<&'input str>>) = {
"(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => {
(ret_params, func, param_list)
},
<func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => (Vec::new(), func, param_list),
- <func:ExtendedID> => (Vec::new(), func, Vec::<ast::CallOperand<_>>::new()),
+ <func:ExtendedID> => (Vec::new(), func, Vec::<ast::Operand<_>>::new()),
};
OptionalDst: &'input str = {
"|" <dst2:ExtendedID> => dst2
}
+SrcOperand: ast::Operand<&'input str> = {
+ <r:ExtendedID> => ast::Operand::Reg(r),
+ <r:ExtendedID> "+" <offset:S32Num> => ast::Operand::RegOffset(r, offset),
+ <x:ImmediateValue> => ast::Operand::Imm(x),
+ <mem_op:MemberOperand> => {
+ let (reg, idx) = mem_op;
+ ast::Operand::VecMember(reg, idx)
+ }
+}
+
+SrcOperandVec: ast::Operand<&'input str> = {
+ <normal:SrcOperand> => normal,
+ <vec:VectorExtract> => ast::Operand::VecPack(vec),
+}
+
+DstOperand: ast::Operand<&'input str> = {
+ <r:ExtendedID> => ast::Operand::Reg(r),
+ <mem_op:MemberOperand> => {
+ let (reg, idx) = mem_op;
+ ast::Operand::VecMember(reg, idx)
+ }
+}
+
+DstOperandVec: ast::Operand<&'input str> = {
+ <normal:DstOperand> => normal,
+ <vec:VectorExtract> => ast::Operand::VecPack(vec),
+}
+
VectorPrefix: u8 = {
".v2" => 2,
".v4" => 4
diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt
index 535e480..a77ab7d 100644
--- a/ptx/src/test/spirv_run/vector.spvtxt
+++ b/ptx/src/test/spirv_run/vector.spvtxt
@@ -7,91 +7,93 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %57 = OpExtInstImport "OpenCL.std"
+ %51 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %31 "vector"
+ OpEntryPoint Kernel %25 "vector"
%void = OpTypeVoid
%uint = OpTypeInt 32 0
%v2uint = OpTypeVector %uint 2
- %61 = OpTypeFunction %v2uint %v2uint
+ %55 = OpTypeFunction %v2uint %v2uint
%_ptr_Function_v2uint = OpTypePointer Function %v2uint
%_ptr_Function_uint = OpTypePointer Function %uint
+ %uint_0 = OpConstant %uint 0
+ %uint_1 = OpConstant %uint 1
%ulong = OpTypeInt 64 0
- %65 = OpTypeFunction %void %ulong %ulong
+ %67 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_v2uint = OpTypePointer Generic %v2uint
- %1 = OpFunction %v2uint None %61
+ %1 = OpFunction %v2uint None %55
%7 = OpFunctionParameter %v2uint
- %30 = OpLabel
+ %24 = OpLabel
%2 = OpVariable %_ptr_Function_v2uint Function
%3 = OpVariable %_ptr_Function_v2uint Function
%4 = OpVariable %_ptr_Function_v2uint Function
%5 = OpVariable %_ptr_Function_uint Function
%6 = OpVariable %_ptr_Function_uint Function
OpStore %3 %7
- %9 = OpLoad %v2uint %3
- %27 = OpCompositeExtract %uint %9 0
- %8 = OpCopyObject %uint %27
+ %59 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_0
+ %9 = OpLoad %uint %59
+ %8 = OpCopyObject %uint %9
OpStore %5 %8
- %11 = OpLoad %v2uint %3
- %28 = OpCompositeExtract %uint %11 1
- %10 = OpCopyObject %uint %28
+ %61 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_1
+ %11 = OpLoad %uint %61
+ %10 = OpCopyObject %uint %11
OpStore %6 %10
%13 = OpLoad %uint %5
%14 = OpLoad %uint %6
%12 = OpIAdd %uint %13 %14
OpStore %6 %12
- %16 = OpLoad %v2uint %4
- %17 = OpLoad %uint %6
- %15 = OpCompositeInsert %v2uint %17 %16 0
- OpStore %4 %15
- %19 = OpLoad %v2uint %4
- %20 = OpLoad %uint %6
- %18 = OpCompositeInsert %v2uint %20 %19 1
- OpStore %4 %18
+ %16 = OpLoad %uint %6
+ %15 = OpCopyObject %uint %16
+ %62 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0
+ OpStore %62 %15
+ %18 = OpLoad %uint %6
+ %17 = OpCopyObject %uint %18
+ %63 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1
+ OpStore %63 %17
+ %64 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1
+ %20 = OpLoad %uint %64
+ %19 = OpCopyObject %uint %20
+ %65 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0
+ OpStore %65 %19
%22 = OpLoad %v2uint %4
- %23 = OpLoad %v2uint %4
- %29 = OpCompositeExtract %uint %23 1
- %21 = OpCompositeInsert %v2uint %29 %22 0
- OpStore %4 %21
- %25 = OpLoad %v2uint %4
- %24 = OpCopyObject %v2uint %25
- OpStore %2 %24
- %26 = OpLoad %v2uint %2
- OpReturnValue %26
+ %21 = OpCopyObject %v2uint %22
+ OpStore %2 %21
+ %23 = OpLoad %v2uint %2
+ OpReturnValue %23
OpFunctionEnd
- %31 = OpFunction %void None %65
- %40 = OpFunctionParameter %ulong
- %41 = OpFunctionParameter %ulong
- %55 = OpLabel
- %32 = OpVariable %_ptr_Function_ulong Function
+ %25 = OpFunction %void None %67
+ %34 = OpFunctionParameter %ulong
+ %35 = OpFunctionParameter %ulong
+ %49 = OpLabel
+ %26 = OpVariable %_ptr_Function_ulong Function
+ %27 = OpVariable %_ptr_Function_ulong Function
+ %28 = OpVariable %_ptr_Function_ulong Function
+ %29 = OpVariable %_ptr_Function_ulong Function
+ %30 = OpVariable %_ptr_Function_v2uint Function
+ %31 = OpVariable %_ptr_Function_uint Function
+ %32 = OpVariable %_ptr_Function_uint Function
%33 = OpVariable %_ptr_Function_ulong Function
- %34 = OpVariable %_ptr_Function_ulong Function
- %35 = OpVariable %_ptr_Function_ulong Function
- %36 = OpVariable %_ptr_Function_v2uint Function
- %37 = OpVariable %_ptr_Function_uint Function
- %38 = OpVariable %_ptr_Function_uint Function
- %39 = OpVariable %_ptr_Function_ulong Function
- OpStore %32 %40
- OpStore %33 %41
- %42 = OpLoad %ulong %32
- OpStore %34 %42
- %43 = OpLoad %ulong %33
- OpStore %35 %43
- %45 = OpLoad %ulong %34
- %52 = OpConvertUToPtr %_ptr_Generic_v2uint %45
- %44 = OpLoad %v2uint %52
- OpStore %36 %44
- %47 = OpLoad %v2uint %36
- %46 = OpFunctionCall %v2uint %1 %47
- OpStore %36 %46
- %49 = OpLoad %v2uint %36
- %53 = OpBitcast %ulong %49
- %48 = OpCopyObject %ulong %53
- OpStore %39 %48
- %50 = OpLoad %ulong %35
- %51 = OpLoad %v2uint %36
- %54 = OpConvertUToPtr %_ptr_Generic_v2uint %50
- OpStore %54 %51
+ OpStore %26 %34
+ OpStore %27 %35
+ %36 = OpLoad %ulong %26
+ OpStore %28 %36
+ %37 = OpLoad %ulong %27
+ OpStore %29 %37
+ %39 = OpLoad %ulong %28
+ %46 = OpConvertUToPtr %_ptr_Generic_v2uint %39
+ %38 = OpLoad %v2uint %46
+ OpStore %30 %38
+ %41 = OpLoad %v2uint %30
+ %40 = OpFunctionCall %v2uint %1 %41
+ OpStore %30 %40
+ %43 = OpLoad %v2uint %30
+ %47 = OpBitcast %ulong %43
+ %42 = OpCopyObject %ulong %47
+ OpStore %33 %42
+ %44 = OpLoad %ulong %29
+ %45 = OpLoad %v2uint %30
+ %48 = OpConvertUToPtr %_ptr_Generic_v2uint %44
+ OpStore %48 %45
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/vector_extract.spvtxt b/ptx/src/test/spirv_run/vector_extract.spvtxt
index 4943189..2037dec 100644
--- a/ptx/src/test/spirv_run/vector_extract.spvtxt
+++ b/ptx/src/test/spirv_run/vector_extract.spvtxt
@@ -7,12 +7,12 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %73 = OpExtInstImport "OpenCL.std"
+ %61 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "vector_extract"
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
- %76 = OpTypeFunction %void %ulong %ulong
+ %64 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%ushort = OpTypeInt 16 0
%_ptr_Function_ushort = OpTypePointer Function %ushort
@@ -21,10 +21,10 @@
%uchar = OpTypeInt 8 0
%v4uchar = OpTypeVector %uchar 4
%_ptr_CrossWorkgroup_v4uchar = OpTypePointer CrossWorkgroup %v4uchar
- %1 = OpFunction %void None %76
- %11 = OpFunctionParameter %ulong
- %12 = OpFunctionParameter %ulong
- %71 = OpLabel
+ %1 = OpFunction %void None %64
+ %17 = OpFunctionParameter %ulong
+ %18 = OpFunctionParameter %ulong
+ %59 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
@@ -34,89 +34,92 @@
%8 = OpVariable %_ptr_Function_ushort Function
%9 = OpVariable %_ptr_Function_ushort Function
%10 = OpVariable %_ptr_Function_v4ushort Function
- OpStore %2 %11
- OpStore %3 %12
- %13 = OpLoad %ulong %2
- OpStore %4 %13
- %14 = OpLoad %ulong %3
- OpStore %5 %14
- %19 = OpLoad %ulong %4
- %61 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %19
- %43 = OpLoad %v4uchar %61
- %62 = OpCompositeExtract %uchar %43 0
- %85 = OpBitcast %uchar %62
- %15 = OpUConvert %ushort %85
- %63 = OpCompositeExtract %uchar %43 1
- %86 = OpBitcast %uchar %63
- %16 = OpUConvert %ushort %86
- %64 = OpCompositeExtract %uchar %43 2
- %87 = OpBitcast %uchar %64
- %17 = OpUConvert %ushort %87
- %65 = OpCompositeExtract %uchar %43 3
- %88 = OpBitcast %uchar %65
- %18 = OpUConvert %ushort %88
- OpStore %6 %15
- OpStore %7 %16
- OpStore %8 %17
- OpStore %9 %18
- %21 = OpLoad %ushort %7
- %22 = OpLoad %ushort %8
- %23 = OpLoad %ushort %9
- %24 = OpLoad %ushort %6
- %44 = OpUndef %v4ushort
- %45 = OpCompositeInsert %v4ushort %21 %44 0
- %46 = OpCompositeInsert %v4ushort %22 %45 1
- %47 = OpCompositeInsert %v4ushort %23 %46 2
- %48 = OpCompositeInsert %v4ushort %24 %47 3
- %20 = OpCopyObject %v4ushort %48
- OpStore %10 %20
- %29 = OpLoad %v4ushort %10
- %49 = OpCopyObject %v4ushort %29
- %25 = OpCompositeExtract %ushort %49 0
- %26 = OpCompositeExtract %ushort %49 1
- %27 = OpCompositeExtract %ushort %49 2
- %28 = OpCompositeExtract %ushort %49 3
- OpStore %8 %25
- OpStore %9 %26
- OpStore %6 %27
- OpStore %7 %28
- %34 = OpLoad %ushort %8
- %35 = OpLoad %ushort %9
- %36 = OpLoad %ushort %6
- %37 = OpLoad %ushort %7
- %51 = OpUndef %v4ushort
- %52 = OpCompositeInsert %v4ushort %34 %51 0
- %53 = OpCompositeInsert %v4ushort %35 %52 1
- %54 = OpCompositeInsert %v4ushort %36 %53 2
- %55 = OpCompositeInsert %v4ushort %37 %54 3
- %50 = OpCopyObject %v4ushort %55
- %30 = OpCompositeExtract %ushort %50 0
- %31 = OpCompositeExtract %ushort %50 1
- %32 = OpCompositeExtract %ushort %50 2
- %33 = OpCompositeExtract %ushort %50 3
- OpStore %9 %30
- OpStore %6 %31
- OpStore %7 %32
- OpStore %8 %33
- %38 = OpLoad %ulong %5
- %39 = OpLoad %ushort %6
- %40 = OpLoad %ushort %7
- %41 = OpLoad %ushort %8
- %42 = OpLoad %ushort %9
- %56 = OpUndef %v4uchar
- %89 = OpBitcast %ushort %39
- %66 = OpUConvert %uchar %89
- %57 = OpCompositeInsert %v4uchar %66 %56 0
- %90 = OpBitcast %ushort %40
- %67 = OpUConvert %uchar %90
- %58 = OpCompositeInsert %v4uchar %67 %57 1
- %91 = OpBitcast %ushort %41
- %68 = OpUConvert %uchar %91
- %59 = OpCompositeInsert %v4uchar %68 %58 2
- %92 = OpBitcast %ushort %42
- %69 = OpUConvert %uchar %92
- %60 = OpCompositeInsert %v4uchar %69 %59 3
- %70 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %38
- OpStore %70 %60
+ OpStore %2 %17
+ OpStore %3 %18
+ %19 = OpLoad %ulong %2
+ OpStore %4 %19
+ %20 = OpLoad %ulong %3
+ OpStore %5 %20
+ %21 = OpLoad %ulong %4
+ %49 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %21
+ %11 = OpLoad %v4uchar %49
+ %50 = OpCompositeExtract %uchar %11 0
+ %51 = OpCompositeExtract %uchar %11 1
+ %52 = OpCompositeExtract %uchar %11 2
+ %53 = OpCompositeExtract %uchar %11 3
+ %73 = OpBitcast %uchar %50
+ %22 = OpUConvert %ushort %73
+ %74 = OpBitcast %uchar %51
+ %23 = OpUConvert %ushort %74
+ %75 = OpBitcast %uchar %52
+ %24 = OpUConvert %ushort %75
+ %76 = OpBitcast %uchar %53
+ %25 = OpUConvert %ushort %76
+ OpStore %6 %22
+ OpStore %7 %23
+ OpStore %8 %24
+ OpStore %9 %25
+ %26 = OpLoad %ushort %7
+ %27 = OpLoad %ushort %8
+ %28 = OpLoad %ushort %9
+ %29 = OpLoad %ushort %6
+ %77 = OpUndef %v4ushort
+ %78 = OpCompositeInsert %v4ushort %26 %77 0
+ %79 = OpCompositeInsert %v4ushort %27 %78 1
+ %80 = OpCompositeInsert %v4ushort %28 %79 2
+ %81 = OpCompositeInsert %v4ushort %29 %80 3
+ %12 = OpCopyObject %v4ushort %81
+ %30 = OpCopyObject %v4ushort %12
+ OpStore %10 %30
+ %31 = OpLoad %v4ushort %10
+ %13 = OpCopyObject %v4ushort %31
+ %32 = OpCompositeExtract %ushort %13 0
+ %33 = OpCompositeExtract %ushort %13 1
+ %34 = OpCompositeExtract %ushort %13 2
+ %35 = OpCompositeExtract %ushort %13 3
+ OpStore %8 %32
+ OpStore %9 %33
+ OpStore %6 %34
+ OpStore %7 %35
+ %36 = OpLoad %ushort %8
+ %37 = OpLoad %ushort %9
+ %38 = OpLoad %ushort %6
+ %39 = OpLoad %ushort %7
+ %82 = OpUndef %v4ushort
+ %83 = OpCompositeInsert %v4ushort %36 %82 0
+ %84 = OpCompositeInsert %v4ushort %37 %83 1
+ %85 = OpCompositeInsert %v4ushort %38 %84 2
+ %86 = OpCompositeInsert %v4ushort %39 %85 3
+ %15 = OpCopyObject %v4ushort %86
+ %14 = OpCopyObject %v4ushort %15
+ %40 = OpCompositeExtract %ushort %14 0
+ %41 = OpCompositeExtract %ushort %14 1
+ %42 = OpCompositeExtract %ushort %14 2
+ %43 = OpCompositeExtract %ushort %14 3
+ OpStore %9 %40
+ OpStore %6 %41
+ OpStore %7 %42
+ OpStore %8 %43
+ %44 = OpLoad %ushort %6
+ %45 = OpLoad %ushort %7
+ %46 = OpLoad %ushort %8
+ %47 = OpLoad %ushort %9
+ %87 = OpBitcast %ushort %44
+ %54 = OpUConvert %uchar %87
+ %88 = OpBitcast %ushort %45
+ %55 = OpUConvert %uchar %88
+ %89 = OpBitcast %ushort %46
+ %56 = OpUConvert %uchar %89
+ %90 = OpBitcast %ushort %47
+ %57 = OpUConvert %uchar %90
+ %91 = OpUndef %v4uchar
+ %92 = OpCompositeInsert %v4uchar %54 %91 0
+ %93 = OpCompositeInsert %v4uchar %55 %92 1
+ %94 = OpCompositeInsert %v4uchar %56 %93 2
+ %95 = OpCompositeInsert %v4uchar %57 %94 3
+ %16 = OpCopyObject %v4uchar %95
+ %48 = OpLoad %ulong %5
+ %58 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %48
+ OpStore %58 %16
OpReturn
OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 15211ab..20578eb 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -27,6 +27,16 @@ quick_error! {
}
}
+#[cfg(debug_assertions)]
+fn error_unreachable() -> TranslateError {
+ unreachable!()
+}
+
+#[cfg(not(debug_assertions))]
+fn error_unreachable() -> TranslateError {
+ TranslateError::Unreachable
+}
+
#[derive(PartialEq, Eq, Hash, Clone)]
enum SpirvType {
Base(SpirvScalarKey),
@@ -82,7 +92,7 @@ impl ast::Type {
ast::Type::Pointer(ast::PointerType::Scalar(t), space) => {
ast::Type::Pointer(ast::PointerType::Pointer(t, space), space)
}
- ast::Type::Pointer(_, _) => return Err(TranslateError::Unreachable),
+ ast::Type::Pointer(_, _) => return Err(error_unreachable()),
})
}
}
@@ -364,7 +374,7 @@ impl TypeWordMap {
b.constant_composite(result_type, None, &components)
}
ast::Type::Array(typ, dims) => match dims.as_slice() {
- [] => return Err(TranslateError::Unreachable),
+ [] => return Err(error_unreachable()),
[dim] => {
let result_type = self
.get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim]));
@@ -791,13 +801,14 @@ fn convert_dynamic_shared_memory_usage<'input>(
ast::PointerStateSpace::Shared,
)),
});
- let shared_var_st = ExpandedStatement::StoreVar(
- ast::Arg2St {
+ let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
+ arg: ast::Arg2St {
src1: shared_var_id,
src2: shared_id_param,
},
- ast::Type::Scalar(ast::ScalarType::B8),
- );
+ typ: ast::Type::Scalar(ast::ScalarType::B8),
+ member_index: None,
+ });
let mut new_statements = vec![shared_var, shared_var_st];
replace_uses_of_shared_memory(
&mut new_statements,
@@ -963,18 +974,17 @@ fn compute_denorm_information<'input>(
denorm_count_map_update(&mut flush_counter, width, flush);
}
}
- Statement::LoadVar(_, _) => {}
- Statement::StoreVar(_, _) => {}
+ Statement::LoadVar(..) => {}
+ Statement::StoreVar(..) => {}
Statement::Call(_) => {}
- Statement::Composite(_) => {}
Statement::Conditional(_) => {}
Statement::Conversion(_) => {}
Statement::Constant(_) => {}
Statement::RetValue(_, _) => {}
- Statement::Undef(_, _) => {}
Statement::Label(_) => {}
Statement::Variable(_) => {}
Statement::PtrAccess { .. } => {}
+ Statement::RepackVector(_) => {}
}
}
denorm_methods.insert(method_key, flush_counter);
@@ -1307,7 +1317,7 @@ fn to_ssa<'input, 'b>(
let mut numeric_id_defs = id_defs.finish();
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
- convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?;
+ convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
let typed_statements =
convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
@@ -1431,7 +1441,7 @@ fn normalize_variable_decls(directives: &mut Vec<Directive>) {
fn convert_to_typed_statements(
func: Vec<UnconditionalStatement>,
fn_defs: &GlobalFnDeclResolver,
- id_defs: &NumericIdResolver,
+ id_defs: &mut NumericIdResolver,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::<TypedStatement>::with_capacity(func.len());
for s in func {
@@ -1447,7 +1457,7 @@ fn convert_to_typed_statements(
.partition(|(_, arg_type)| arg_type.is_param());
let normalized_input_args = out_params
.into_iter()
- .map(|(id, typ)| (ast::CallOperand::Reg(id), typ))
+ .map(|(id, typ)| (ast::Operand::Reg(id), typ))
.chain(in_args.into_iter())
.collect();
let resolved_call = ResolvedCall {
@@ -1456,205 +1466,117 @@ fn convert_to_typed_statements(
func: call.func,
param_list: normalized_input_args,
};
- result.push(Statement::Call(resolved_call));
- }
- ast::Instruction::Ld(d, arg) => {
- result.push(Statement::Instruction(ast::Instruction::Ld(d, arg.cast())));
- }
- ast::Instruction::St(d, arg) => {
- result.push(Statement::Instruction(ast::Instruction::St(d, arg.cast())));
- }
- ast::Instruction::Mov(mut d, args) => match args {
- ast::Arg2Mov::Normal(arg) => {
- if let Some(src_id) = arg.src.single_underlying() {
- let (typ, _) = id_defs.get_typed(*src_id)?;
- let take_address = match typ {
- ast::Type::Scalar(_) => false,
- ast::Type::Vector(_, _) => false,
- ast::Type::Array(_, _) => true,
- ast::Type::Pointer(_, _) => true,
- };
- d.src_is_address = take_address;
- }
- result.push(Statement::Instruction(ast::Instruction::Mov(
- d,
- ast::Arg2Mov::Normal(arg.cast()),
- )));
- }
- ast::Arg2Mov::Member(args) => {
- if let Some(dst_typ) = args.vector_dst() {
- match id_defs.get_typed(*dst_typ)? {
- (ast::Type::Vector(_, len), _) => {
- d.dst_width = len;
- }
- _ => return Err(TranslateError::MismatchedType),
- }
- };
- if let Some((src_typ, _)) = args.vector_src() {
- match id_defs.get_typed(*src_typ)? {
- (ast::Type::Vector(_, len), _) => {
- d.src_width = len;
- }
- _ => return Err(TranslateError::MismatchedType),
- }
+ let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
+ let reresolved_call = resolved_call.visit(&mut visitor)?;
+ visitor.func.push(reresolved_call);
+ visitor.func.extend(visitor.post_stmts);
+ }
+ ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => {
+ if let Some(src_id) = src.underlying() {
+ let (typ, _) = id_defs.get_typed(*src_id)?;
+ let take_address = match typ {
+ ast::Type::Scalar(_) => false,
+ ast::Type::Vector(_, _) => false,
+ ast::Type::Array(_, _) => true,
+ ast::Type::Pointer(_, _) => true,
};
- result.push(Statement::Instruction(ast::Instruction::Mov(
- d,
- ast::Arg2Mov::Member(args.cast()),
- )));
+ d.src_is_address = take_address;
}
- },
- ast::Instruction::Mul(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Mul(d, a.cast())))
- }
- ast::Instruction::Add(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Add(d, a.cast())))
- }
- ast::Instruction::Setp(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Setp(d, a.cast())))
- }
- ast::Instruction::SetpBool(d, a) => result.push(Statement::Instruction(
- ast::Instruction::SetpBool(d, a.cast()),
- )),
- ast::Instruction::Not(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Not(d, a.cast())))
- }
- ast::Instruction::Bra(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Bra(d, a.cast())))
- }
- ast::Instruction::Cvt(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Cvt(d, a.cast())))
- }
- ast::Instruction::Cvta(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Cvta(d, a.cast())))
- }
- ast::Instruction::Shl(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Shl(d, a.cast())))
- }
- ast::Instruction::Ret(d) => {
- result.push(Statement::Instruction(ast::Instruction::Ret(d)))
- }
- ast::Instruction::Abs(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Abs(d, a.cast())))
- }
- ast::Instruction::Mad(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Mad(d, a.cast())))
- }
- ast::Instruction::Shr(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Shr(d, a.cast())))
- }
- ast::Instruction::Or(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Or(d, a.cast())))
- }
- ast::Instruction::Sub(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Sub(d, a.cast())))
- }
- ast::Instruction::Min(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Min(d, a.cast())))
- }
- ast::Instruction::Max(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast())))
- }
- ast::Instruction::Rcp(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Rcp(d, a.cast())))
- }
- ast::Instruction::And(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::And(d, a.cast())))
- }
- ast::Instruction::Selp(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Selp(d, a.cast())))
- }
- ast::Instruction::Bar(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Bar(d, a.cast())))
- }
- ast::Instruction::Atom(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Atom(d, a.cast())))
- }
- ast::Instruction::AtomCas(d, a) => result.push(Statement::Instruction(
- ast::Instruction::AtomCas(d, a.cast()),
- )),
- ast::Instruction::Div(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Div(d, a.cast())))
- }
- ast::Instruction::Sqrt(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Sqrt(d, a.cast())))
- }
- ast::Instruction::Rsqrt(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast())))
- }
- ast::Instruction::Neg(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Neg(d, a.cast())))
- }
- ast::Instruction::Sin { flush_to_zero, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Sin {
- flush_to_zero,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Cos { flush_to_zero, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Cos {
- flush_to_zero,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Lg2 { flush_to_zero, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Lg2 {
- flush_to_zero,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Ex2 { flush_to_zero, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Ex2 {
- flush_to_zero,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Clz { typ, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Clz {
- typ,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Brev { typ, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Brev {
- typ,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Popc { typ, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Popc {
- typ,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Xor { typ, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Xor {
- typ,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Bfe { typ, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Bfe {
- typ,
- arg: arg.cast(),
- }))
+ let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
+ let instruction = Statement::Instruction(
+ ast::Instruction::Mov(d, ast::Arg2Mov { dst, src }).map(&mut visitor)?,
+ );
+ visitor.func.push(instruction);
+ visitor.func.extend(visitor.post_stmts);
}
- ast::Instruction::Rem { typ, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Rem {
- typ,
- arg: arg.cast(),
- }))
+ inst => {
+ let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
+ let instruction = Statement::Instruction(inst.map(&mut visitor)?);
+ visitor.func.push(instruction);
+ visitor.func.extend(visitor.post_stmts);
}
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
Statement::Conditional(c) => result.push(Statement::Conditional(c)),
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
}
}
Ok(result)
}
+struct VectorRepackVisitor<'a, 'b> {
+ func: &'b mut Vec<TypedStatement>,
+ id_def: &'b mut NumericIdResolver<'a>,
+ post_stmts: Option<TypedStatement>,
+}
+
+impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
+ fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
+ VectorRepackVisitor {
+ func,
+ id_def,
+ post_stmts: None,
+ }
+ }
+
+ fn convert_vector(
+ &mut self,
+ is_dst: bool,
+ vector_sema: ArgumentSemantics,
+ typ: &ast::Type,
+ idx: Vec<spirv::Word>,
+ ) -> Result<spirv::Word, TranslateError> {
+ // mov.u32 foobar, {a,b};
+ let scalar_t = match typ {
+ ast::Type::Vector(scalar_t, _) => *scalar_t,
+ _ => return Err(TranslateError::MismatchedType),
+ };
+ let temp_vec = self.id_def.new_non_variable(Some(typ.clone()));
+ let statement = Statement::RepackVector(RepackVectorDetails {
+ is_extract: is_dst,
+ typ: scalar_t,
+ packed: temp_vec,
+ unpacked: idx,
+ vector_sema,
+ });
+ if is_dst {
+ self.post_stmts = Some(statement);
+ } else {
+ self.func.push(statement);
+ }
+ Ok(temp_vec)
+ }
+}
+
+impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
+ for VectorRepackVisitor<'a, 'b>
+{
+ fn id(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ _: Option<&ast::Type>,
+ ) -> Result<spirv::Word, TranslateError> {
+ Ok(desc.op)
+ }
+
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ typ: &ast::Type,
+ ) -> Result<TypedOperand, TranslateError> {
+ Ok(match desc.op {
+ ast::Operand::Reg(reg) => TypedOperand::Reg(reg),
+ ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
+ ast::Operand::Imm(x) => TypedOperand::Imm(x),
+ ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
+ ast::Operand::VecPack(vec) => {
+ TypedOperand::Reg(self.convert_vector(desc.is_dst, desc.sema, typ, vec)?)
+ }
+ })
+ }
+}
+
//TODO: share common code between this and to_ptx_impl_bfe_call
fn to_ptx_impl_atomic_call(
id_defs: &mut NumericIdResolver,
@@ -1872,17 +1794,16 @@ fn normalize_labels(
labels_in_use.insert(cond.if_true);
labels_in_use.insert(cond.if_false);
}
- Statement::Composite(_)
- | Statement::Call(_)
- | Statement::Variable(_)
- | Statement::LoadVar(_, _)
- | Statement::StoreVar(_, _)
- | Statement::RetValue(_, _)
- | Statement::Conversion(_)
- | Statement::Constant(_)
- | Statement::Label(_)
- | Statement::Undef(_, _)
- | Statement::PtrAccess { .. } => {}
+ Statement::Call(..)
+ | Statement::Variable(..)
+ | Statement::LoadVar(..)
+ | Statement::StoreVar(..)
+ | Statement::RetValue(..)
+ | Statement::Conversion(..)
+ | Statement::Constant(..)
+ | Statement::Label(..)
+ | Statement::PtrAccess { .. }
+ | Statement::RepackVector(..) => {}
}
}
iter::once(Statement::Label(id_def.new_non_variable(None)))
@@ -1929,7 +1850,7 @@ fn normalize_predicates(
}
Statement::Variable(var) => result.push(Statement::Variable(var)),
// Blocks are flattened when resolving ids
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
}
}
Ok(result)
@@ -1956,7 +1877,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
array_init: arg.array_init.clone(),
}));
}
- None => return Err(TranslateError::Unreachable),
+ None => return Err(error_unreachable()),
}
}
for spirv_arg in fn_decl.input.iter_mut() {
@@ -1970,13 +1891,14 @@ fn insert_mem_ssa_statements<'a, 'b>(
name: spirv_arg.name,
array_init: spirv_arg.array_init.clone(),
}));
- result.push(Statement::StoreVar(
- ast::Arg2St {
+ result.push(Statement::StoreVar(StoreVarDetails {
+ arg: ast::Arg2St {
src1: spirv_arg.name,
src2: new_id,
},
typ,
- ));
+ member_index: None,
+ }));
spirv_arg.name = new_id;
}
None => {}
@@ -1993,13 +1915,14 @@ fn insert_mem_ssa_statements<'a, 'b>(
if let &[out_param] = &fn_decl.output.as_slice() {
let (typ, _) = id_def.get_typed(out_param.name)?;
let new_id = id_def.new_non_variable(Some(typ.clone()));
- result.push(Statement::LoadVar(
- ast::Arg2 {
+ result.push(Statement::LoadVar(LoadVarDetails {
+ arg: ast::Arg2 {
dst: new_id,
src: out_param.name,
},
- typ.clone(),
- ));
+ typ: typ.clone(),
+ member_index: None,
+ }));
result.push(Statement::RetValue(d, new_id));
} else {
result.push(Statement::Instruction(ast::Instruction::Ret(d)))
@@ -2010,13 +1933,14 @@ fn insert_mem_ssa_statements<'a, 'b>(
Statement::Conditional(mut bra) => {
let generated_id =
id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
- result.push(Statement::LoadVar(
- Arg2 {
+ result.push(Statement::LoadVar(LoadVarDetails {
+ arg: Arg2 {
dst: generated_id,
src: bra.predicate,
},
- ast::Type::Scalar(ast::ScalarType::Pred),
- ));
+ typ: ast::Type::Scalar(ast::ScalarType::Pred),
+ member_index: None,
+ }));
bra.predicate = generated_id;
result.push(Statement::Conditional(bra));
}
@@ -2026,8 +1950,11 @@ fn insert_mem_ssa_statements<'a, 'b>(
Statement::PtrAccess(ptr_access) => {
insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)?
}
+ Statement::RepackVector(repack) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, repack)?
+ }
s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
}
}
Ok(result)
@@ -2059,101 +1986,156 @@ fn type_to_variable_type(
scalar_type
.clone()
.try_into()
- .map_err(|_| TranslateError::Unreachable)?,
- (*space)
- .try_into()
- .map_err(|_| TranslateError::Unreachable)?,
+ .map_err(|_| error_unreachable())?,
+ (*space).try_into().map_err(|_| error_unreachable())?,
)))
}
ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
})
}
-trait VisitVariable: Sized {
- fn visit_variable<
- 'a,
- F: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
- >(
+trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
+ fn visit(
self,
- f: &mut F,
- ) -> Result<TypedStatement, TranslateError>;
-}
-trait VisitVariableExpanded {
- fn visit_variable_extended<
- F: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
- >(
- self,
- f: &mut F,
- ) -> Result<ExpandedStatement, TranslateError>;
+ visitor: &mut impl ArgumentMapVisitor<From, To>,
+ ) -> Result<Statement<ast::Instruction<To>, To>, TranslateError>;
}
-struct VisitArgumentDescriptor<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> {
+struct VisitArgumentDescriptor<
+ 'a,
+ Ctor: FnOnce(spirv::Word) -> Statement<ast::Instruction<U>, U>,
+ U: ArgParamsEx,
+> {
desc: ArgumentDescriptor<spirv::Word>,
typ: &'a ast::Type,
stmt_ctor: Ctor,
}
-impl<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> VisitVariableExpanded
- for VisitArgumentDescriptor<'a, Ctor>
+impl<
+ 'a,
+ Ctor: FnOnce(spirv::Word) -> Statement<ast::Instruction<U>, U>,
+ T: ArgParamsEx<Id = spirv::Word>,
+ U: ArgParamsEx<Id = spirv::Word>,
+ > Visitable<T, U> for VisitArgumentDescriptor<'a, Ctor, U>
{
- fn visit_variable_extended<
- F: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
- >(
+ fn visit(
self,
- f: &mut F,
- ) -> Result<ExpandedStatement, TranslateError> {
- f(self.desc, Some(self.typ)).map(self.stmt_ctor)
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok((self.stmt_ctor)(visitor.id(self.desc, Some(self.typ))?))
}
}
-fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
- id_def: &mut NumericIdResolver,
- result: &mut Vec<TypedStatement>,
- stmt: F,
-) -> Result<(), TranslateError> {
- let mut post_statements = Vec::new();
- let new_statement = stmt.visit_variable(
- &mut |desc: ArgumentDescriptor<spirv::Word>, expected_type| {
- if expected_type.is_none() {
- return Ok(desc.op);
- };
- let (var_type, is_variable) = id_def.get_typed(desc.op)?;
- if !is_variable {
- return Ok(desc.op);
- }
- let generated_id = id_def.new_non_variable(Some(var_type.clone()));
- if !desc.is_dst {
- result.push(Statement::LoadVar(
- Arg2 {
- dst: generated_id,
- src: desc.op,
+struct InsertMemSSAVisitor<'a, 'input> {
+ id_def: &'a mut NumericIdResolver<'input>,
+ func: &'a mut Vec<TypedStatement>,
+ post_statements: Vec<TypedStatement>,
+}
+
+impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
+ fn symbol(
+ &mut self,
+ desc: ArgumentDescriptor<(spirv::Word, Option<u8>)>,
+ expected_type: Option<&ast::Type>,
+ ) -> Result<spirv::Word, TranslateError> {
+ let symbol = desc.op.0;
+ if expected_type.is_none() {
+ return Ok(symbol);
+ };
+ let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?;
+ if !is_variable {
+ return Ok(symbol);
+ };
+ let member_index = match desc.op.1 {
+ Some(idx) => {
+ let vector_width = match var_type {
+ ast::Type::Vector(scalar_t, width) => {
+ var_type = ast::Type::Scalar(scalar_t);
+ width
+ }
+ _ => return Err(TranslateError::MismatchedType),
+ };
+ Some((
+ idx,
+ if self.id_def.special_registers.contains_key(&symbol) {
+ Some(vector_width)
+ } else {
+ None
},
- var_type,
- ));
- } else {
- post_statements.push(Statement::StoreVar(
- Arg2St {
- src1: desc.op,
+ ))
+ }
+ None => None,
+ };
+ let generated_id = self.id_def.new_non_variable(Some(var_type.clone()));
+ if !desc.is_dst {
+ self.func.push(Statement::LoadVar(LoadVarDetails {
+ arg: Arg2 {
+ dst: generated_id,
+ src: symbol,
+ },
+ typ: var_type,
+ member_index,
+ }));
+ } else {
+ self.post_statements
+ .push(Statement::StoreVar(StoreVarDetails {
+ arg: Arg2St {
+ src1: symbol,
src2: generated_id,
},
- var_type,
- ));
+ typ: var_type,
+ member_index: member_index.map(|(idx, _)| idx),
+ }));
+ }
+ Ok(generated_id)
+ }
+}
+
+impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams>
+ for InsertMemSSAVisitor<'a, 'input>
+{
+ fn id(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ typ: Option<&ast::Type>,
+ ) -> Result<spirv::Word, TranslateError> {
+ self.symbol(desc.new_op((desc.op, None)), typ)
+ }
+
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<TypedOperand>,
+ typ: &ast::Type,
+ ) -> Result<TypedOperand, TranslateError> {
+ Ok(match desc.op {
+ TypedOperand::Reg(reg) => {
+ TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
}
- Ok(generated_id)
- },
- )?;
- result.push(new_statement);
- result.append(&mut post_statements);
+ TypedOperand::RegOffset(reg, offset) => {
+ TypedOperand::RegOffset(self.symbol(desc.new_op((reg, None)), Some(typ))?, offset)
+ }
+ op @ TypedOperand::Imm(..) => op,
+ TypedOperand::VecMember(symbol, index) => {
+ TypedOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
+ }
+ })
+ }
+}
+
+fn insert_mem_ssa_statement_default<'a, 'input, S: Visitable<TypedArgParams, TypedArgParams>>(
+ id_def: &'a mut NumericIdResolver<'input>,
+ func: &'a mut Vec<TypedStatement>,
+ stmt: S,
+) -> Result<(), TranslateError> {
+ let mut visitor = InsertMemSSAVisitor {
+ id_def,
+ func,
+ post_statements: Vec::new(),
+ };
+ let new_stmt = stmt.visit(&mut visitor)?;
+ visitor.func.push(new_stmt);
+ visitor.func.extend(visitor.post_statements);
Ok(())
}
@@ -2193,15 +2175,19 @@ fn expand_arguments<'a, 'b>(
result.push(Statement::PtrAccess(new_inst));
result.extend(post_stmts);
}
+ Statement::RepackVector(repack) => {
+ let mut visitor = FlattenArguments::new(&mut result, id_def);
+ let (new_inst, post_stmts) = (repack.map(&mut visitor)?, visitor.post_stmts);
+ result.push(Statement::RepackVector(new_inst));
+ result.extend(post_stmts);
+ }
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
- Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
- Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
+ Statement::LoadVar(details) => result.push(Statement::LoadVar(details)),
+ Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
- Statement::Composite(_) | Statement::Constant(_) | Statement::Undef(_, _) => {
- return Err(TranslateError::Unreachable)
- }
+ Statement::Constant(_) => return Err(error_unreachable()),
}
}
Ok(result)
@@ -2225,27 +2211,6 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
}
}
- fn insert_composite_read(
- func: &mut Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver<'a>,
- typ: (ast::ScalarType, u8),
- scalar_dst: Option<spirv::Word>,
- scalar_sema_override: Option<ArgumentSemantics>,
- composite_src: (spirv::Word, u8),
- ) -> spirv::Word {
- let new_id =
- scalar_dst.unwrap_or_else(|| id_def.new_non_variable(ast::Type::Scalar(typ.0)));
- func.push(Statement::Composite(CompositeRead {
- typ: typ.0,
- dst: new_id,
- dst_semantics_override: scalar_sema_override,
- src_composite: composite_src.0,
- src_index: composite_src.1 as u32,
- src_len: typ.1 as u32,
- }));
- new_id
- }
-
fn reg(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
@@ -2367,69 +2332,6 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
}));
Ok(id)
}
-
- fn member_src(
- &mut self,
- desc: ArgumentDescriptor<(spirv::Word, u8)>,
- typ: (ast::ScalarType, u8),
- ) -> Result<spirv::Word, TranslateError> {
- if desc.is_dst {
- return Err(TranslateError::Unreachable);
- }
- let new_id = Self::insert_composite_read(
- self.func,
- self.id_def,
- typ,
- None,
- Some(desc.sema),
- desc.op,
- );
- Ok(new_id)
- }
-
- fn vector(
- &mut self,
- desc: ArgumentDescriptor<&Vec<spirv::Word>>,
- typ: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- let (scalar_type, vec_len) = typ.get_vector()?;
- if !desc.is_dst {
- let mut new_id = self.id_def.new_non_variable(typ.clone());
- self.func.push(Statement::Undef(typ.clone(), new_id));
- for (idx, id) in desc.op.iter().enumerate() {
- let newer_id = self.id_def.new_non_variable(typ.clone());
- self.func.push(Statement::Instruction(ast::Instruction::Mov(
- ast::MovDetails {
- typ: ast::Type::Scalar(scalar_type),
- src_is_address: false,
- dst_width: vec_len,
- src_width: 0,
- relaxed_src2_conv: desc.sema == ArgumentSemantics::DefaultRelaxed,
- },
- ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
- (newer_id, idx as u8),
- new_id,
- *id,
- )),
- )));
- new_id = newer_id;
- }
- Ok(new_id)
- } else {
- let new_id = self.id_def.new_non_variable(typ.clone());
- for (idx, id) in desc.op.iter().enumerate() {
- Self::insert_composite_read(
- &mut self.post_stmts,
- self.id_def,
- (scalar_type, vec_len),
- Some(*id),
- Some(desc.sema),
- (new_id, idx as u8),
- );
- }
- Ok(new_id)
- }
- }
}
impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenArguments<'a, 'b> {
@@ -2443,58 +2345,16 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr
fn operand(
&mut self,
- desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
match desc.op {
- ast::Operand::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
- ast::Operand::Imm(x) => self.immediate(desc.new_op(x), typ),
- ast::Operand::RegOffset(reg, offset) => {
+ TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
+ TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ),
+ TypedOperand::RegOffset(reg, offset) => {
self.reg_offset(desc.new_op((reg, offset)), typ)
}
- }
- }
-
- fn src_call_operand(
- &mut self,
- desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
- typ: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- match desc.op {
- ast::CallOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)),
- ast::CallOperand::Imm(x) => self.immediate(desc.new_op(x), typ),
- }
- }
-
- fn src_member_operand(
- &mut self,
- desc: ArgumentDescriptor<(spirv::Word, u8)>,
- typ: (ast::ScalarType, u8),
- ) -> Result<spirv::Word, TranslateError> {
- self.member_src(desc, typ)
- }
-
- fn id_or_vector(
- &mut self,
- desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
- typ: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- match desc.op {
- ast::IdOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
- ast::IdOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ),
- }
- }
-
- fn operand_or_vector(
- &mut self,
- desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
- typ: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- match desc.op {
- ast::OperandOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
- ast::OperandOrVector::RegOffset(r, imm) => self.reg_offset(desc.new_op((r, imm)), typ),
- ast::OperandOrVector::Imm(imm) => self.immediate(desc.new_op(imm), typ),
- ast::OperandOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ),
+ TypedOperand::VecMember(..) => Err(error_unreachable()),
}
}
}
@@ -2543,7 +2403,7 @@ fn insert_implicit_conversions(
if let ast::Instruction::AtomCas(d, _) = &inst {
state_space = Some(d.space.to_ld_ss());
}
- if let ast::Instruction::Mov(_, ast::Arg2Mov::Normal(_)) = &inst {
+ if let ast::Instruction::Mov(..) = &inst {
default_conversion_fn = should_bitcast_packed;
}
insert_implicit_conversions_impl(
@@ -2554,13 +2414,6 @@ fn insert_implicit_conversions(
state_space,
)?;
}
- Statement::Composite(composite) => insert_implicit_conversions_impl(
- &mut result,
- id_def,
- composite,
- should_bitcast_wrapper,
- None,
- )?,
Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
@@ -2593,14 +2446,20 @@ fn insert_implicit_conversions(
Some(state_space),
)?;
}
+ Statement::RepackVector(repack) => insert_implicit_conversions_impl(
+ &mut result,
+ id_def,
+ repack,
+ should_bitcast_wrapper,
+ None,
+ )?,
s @ Statement::Conditional(_)
| s @ Statement::Conversion(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_)
- | s @ Statement::LoadVar(_, _)
- | s @ Statement::StoreVar(_, _)
- | s @ Statement::Undef(_, _)
+ | s @ Statement::LoadVar(..)
+ | s @ Statement::StoreVar(..)
| s @ Statement::RetValue(_, _) => result.push(s),
}
}
@@ -2610,7 +2469,7 @@ fn insert_implicit_conversions(
fn insert_implicit_conversions_impl(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
- stmt: impl VisitVariableExpanded,
+ stmt: impl Visitable<ExpandedArgParams, ExpandedArgParams>,
default_conversion_fn: for<'a> fn(
&'a ast::Type,
&'a ast::Type,
@@ -2619,62 +2478,64 @@ fn insert_implicit_conversions_impl(
state_space: Option<ast::LdStateSpace>,
) -> Result<(), TranslateError> {
let mut post_conv = Vec::new();
- let statement = stmt.visit_variable_extended(&mut |desc, typ| {
- let instr_type = match typ {
- None => return Ok(desc.op),
- Some(t) => t,
- };
- let operand_type = id_def.get_typed(desc.op)?;
- let mut conversion_fn = default_conversion_fn;
- match desc.sema {
- ArgumentSemantics::Default => {}
- ArgumentSemantics::DefaultRelaxed => {
- if desc.is_dst {
- conversion_fn = should_convert_relaxed_dst_wrapper;
- } else {
- conversion_fn = should_convert_relaxed_src_wrapper;
+ let statement = stmt.visit(
+ &mut |desc: ArgumentDescriptor<spirv::Word>, typ: Option<&ast::Type>| {
+ let instr_type = match typ {
+ None => return Ok(desc.op),
+ Some(t) => t,
+ };
+ let operand_type = id_def.get_typed(desc.op)?;
+ let mut conversion_fn = default_conversion_fn;
+ match desc.sema {
+ ArgumentSemantics::Default => {}
+ ArgumentSemantics::DefaultRelaxed => {
+ if desc.is_dst {
+ conversion_fn = should_convert_relaxed_dst_wrapper;
+ } else {
+ conversion_fn = should_convert_relaxed_src_wrapper;
+ }
}
- }
- ArgumentSemantics::PhysicalPointer => {
- conversion_fn = bitcast_physical_pointer;
- }
- ArgumentSemantics::RegisterPointer => {
- conversion_fn = bitcast_register_pointer;
- }
- ArgumentSemantics::Address => {
- conversion_fn = force_bitcast_ptr_to_bit;
- }
- };
- match conversion_fn(&operand_type, instr_type, state_space)? {
- Some(conv_kind) => {
- let conv_output = if desc.is_dst {
- &mut post_conv
- } else {
- &mut *func
- };
- let mut from = instr_type.clone();
- let mut to = operand_type;
- let mut src = id_def.new_non_variable(instr_type.clone());
- let mut dst = desc.op;
- let result = Ok(src);
- if !desc.is_dst {
- mem::swap(&mut src, &mut dst);
- mem::swap(&mut from, &mut to);
+ ArgumentSemantics::PhysicalPointer => {
+ conversion_fn = bitcast_physical_pointer;
}
- conv_output.push(Statement::Conversion(ImplicitConversion {
- src,
- dst,
- from,
- to,
- kind: conv_kind,
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
- }));
- result
+ ArgumentSemantics::RegisterPointer => {
+ conversion_fn = bitcast_register_pointer;
+ }
+ ArgumentSemantics::Address => {
+ conversion_fn = force_bitcast_ptr_to_bit;
+ }
+ };
+ match conversion_fn(&operand_type, instr_type, state_space)? {
+ Some(conv_kind) => {
+ let conv_output = if desc.is_dst {
+ &mut post_conv
+ } else {
+ &mut *func
+ };
+ let mut from = instr_type.clone();
+ let mut to = operand_type;
+ let mut src = id_def.new_non_variable(instr_type.clone());
+ let mut dst = desc.op;
+ let result = Ok(src);
+ if !desc.is_dst {
+ mem::swap(&mut src, &mut dst);
+ mem::swap(&mut from, &mut to);
+ }
+ conv_output.push(Statement::Conversion(ImplicitConversion {
+ src,
+ dst,
+ from,
+ to,
+ kind: conv_kind,
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: ArgumentSemantics::Default,
+ }));
+ result
+ }
+ None => Ok(desc.op),
}
- None => Ok(desc.op),
- }
- })?;
+ },
+ )?;
func.push(statement);
func.append(&mut post_conv);
Ok(())
@@ -2861,38 +2722,11 @@ fn emit_function_body_ops(
}
// SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?,
- ast::Instruction::Mov(d, arg) => match arg {
- ast::Arg2Mov::Normal(ast::Arg2MovNormal { dst, src })
- | ast::Arg2Mov::Member(ast::Arg2MovMember::Src(dst, src)) => {
- let result_type = map
- .get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone())));
- builder.copy_object(result_type, Some(*dst), *src)?;
- }
- ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
- dst,
- composite_src,
- scalar_src,
- ))
- | ast::Arg2Mov::Member(ast::Arg2MovMember::Both(
- dst,
- composite_src,
- scalar_src,
- )) => {
- let scalar_type = d.typ.get_scalar()?;
- let result_type = map.get_or_add(
- builder,
- SpirvType::from(ast::Type::Vector(scalar_type, d.dst_width)),
- );
- let result_id = Some(dst.0);
- builder.composite_insert(
- result_type,
- result_id,
- *scalar_src,
- *composite_src,
- [dst.1 as u32],
- )?;
- }
- },
+ ast::Instruction::Mov(d, arg) => {
+ let result_type =
+ map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone())));
+ builder.copy_object(result_type, Some(arg.dst), arg.src)?;
+ }
ast::Instruction::Mul(mul, arg) => match mul {
ast::MulDetails::Signed(ref ctr) => {
emit_mul_sint(builder, map, opencl, ctr, arg)?
@@ -3202,30 +3036,38 @@ fn emit_function_body_ops(
builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
}
},
- Statement::LoadVar(arg, typ) => {
- let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
- builder.load(type_id, Some(arg.dst), arg.src, None, [])?;
+ Statement::LoadVar(details) => {
+ emit_load_var(builder, map, details)?;
}
- Statement::StoreVar(arg, _) => {
- builder.store(arg.src1, arg.src2, None, [])?;
+ Statement::StoreVar(details) => {
+ let dst_ptr = match details.member_index {
+ Some(index) => {
+ let result_ptr_type = map.get_or_add(
+ builder,
+ SpirvType::new_pointer(
+ details.typ.clone(),
+ spirv::StorageClass::Function,
+ ),
+ );
+ let index_spirv = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(index as u32),
+ )?;
+ builder.in_bounds_access_chain(
+ result_ptr_type,
+ None,
+ details.arg.src1,
+ &[index_spirv],
+ )?
+ }
+ None => details.arg.src1,
+ };
+ builder.store(dst_ptr, details.arg.src2, None, [])?;
}
Statement::RetValue(_, id) => {
builder.ret_value(*id)?;
}
- Statement::Composite(c) => {
- let result_type = map.get_or_add_scalar(builder, c.typ.into());
- let result_id = Some(c.dst);
- builder.composite_extract(
- result_type,
- result_id,
- c.src_composite,
- [c.src_index],
- )?;
- }
- Statement::Undef(t, id) => {
- let result_type = map.get_or_add(builder, SpirvType::from(t.clone()));
- builder.undef(result_type, Some(*id));
- }
Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
@@ -3254,6 +3096,38 @@ fn emit_function_body_ops(
)?;
builder.bitcast(result_type, Some(*dst), temp)?;
}
+ Statement::RepackVector(repack) => {
+ if repack.is_extract {
+ let scalar_type = map.get_or_add_scalar(builder, repack.typ);
+ for (index, dst_id) in repack.unpacked.iter().enumerate() {
+ builder.composite_extract(
+ scalar_type,
+ Some(*dst_id),
+ repack.packed,
+ &[index as u32],
+ )?;
+ }
+ } else {
+ let vector_type = map.get_or_add(
+ builder,
+ SpirvType::Vector(
+ SpirvScalarKey::from(repack.typ),
+ repack.unpacked.len() as u8,
+ ),
+ );
+ let mut temp_vec = builder.undef(vector_type, None);
+ for (index, src_id) in repack.unpacked.iter().enumerate() {
+ temp_vec = builder.composite_insert(
+ vector_type,
+ None,
+ *src_id,
+ temp_vec,
+ &[index as u32],
+ )?;
+ }
+ builder.copy_object(vector_type, Some(repack.packed), temp_vec)?;
+ }
+ }
}
}
Ok(())
@@ -3271,7 +3145,7 @@ fn insert_shift_hack(
2 => map.get_or_add_scalar(builder, ast::ScalarType::B16),
8 => map.get_or_add_scalar(builder, ast::ScalarType::B64),
4 => return Ok(offset_var),
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
};
Ok(builder.u_convert(result_type, None, offset_var)?)
}
@@ -3351,7 +3225,7 @@ fn emit_atom(
let spirv_op = match op {
ast::AtomUIntOp::Add => dr::Builder::atomic_i_add,
ast::AtomUIntOp::Inc | ast::AtomUIntOp::Dec => {
- return Err(TranslateError::Unreachable);
+ return Err(error_unreachable());
}
ast::AtomUIntOp::Min => dr::Builder::atomic_u_min,
ast::AtomUIntOp::Max => dr::Builder::atomic_u_max,
@@ -4165,6 +4039,58 @@ fn emit_implicit_conversion(
Ok(())
}
+fn emit_load_var(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ details: &LoadVarDetails,
+) -> Result<(), TranslateError> {
+ let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone()));
+ match details.member_index {
+ Some((index, Some(width))) => {
+ let vector_type = match details.typ {
+ ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
+ _ => return Err(TranslateError::MismatchedType),
+ };
+ let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type));
+ let vector_temp = builder.load(vector_type_spirv, None, details.arg.src, None, [])?;
+ builder.composite_extract(
+ result_type,
+ Some(details.arg.dst),
+ vector_temp,
+ &[index as u32],
+ )?;
+ }
+ Some((index, None)) => {
+ let result_ptr_type = map.get_or_add(
+ builder,
+ SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function),
+ );
+ let index_spirv = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(index as u32),
+ )?;
+ let src = builder.in_bounds_access_chain(
+ result_ptr_type,
+ None,
+ details.arg.src,
+ &[index_spirv],
+ )?;
+ builder.load(result_type, Some(details.arg.dst), src, None, [])?;
+ }
+ None => {
+ builder.load(
+ result_type,
+ Some(details.arg.dst),
+ details.arg.src,
+ None,
+ [],
+ )?;
+ }
+ };
+ Ok(())
+}
+
fn normalize_identifiers<'a, 'b>(
id_defs: &mut FnStringIdResolver<'a, 'b>,
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
@@ -4290,9 +4216,11 @@ fn convert_to_stateful_memory_access<'a>(
},
arg,
)) => {
- if let Some(src) = arg.src.underlying() {
- if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, arg.dst) {
- stateful_markers.push((arg.dst, *src));
+ if let (TypedOperand::Reg(dst), Some(src)) =
+ (arg.dst, arg.src.upcast().underlying())
+ {
+ if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, dst) {
+ stateful_markers.push((dst, *src));
}
}
}
@@ -4320,7 +4248,9 @@ fn convert_to_stateful_memory_access<'a>(
},
arg,
)) => {
- if let (ast::IdOrVector::Reg(dst), Some(src)) = (&arg.dst, arg.src.underlying()) {
+ if let (TypedOperand::Reg(dst), Some(src)) =
+ (&arg.dst, arg.src.upcast().underlying())
+ {
if func_args_64bit.contains(src) {
multi_hash_map_append(&mut stateful_init_reg, *dst, *src);
}
@@ -4369,13 +4299,17 @@ fn convert_to_stateful_memory_access<'a>(
}),
arg,
)) => {
- if let Some(src1) = arg.src1.underlying() {
+ if let (TypedOperand::Reg(dst), Some(src1)) =
+ (arg.dst, arg.src1.upcast().underlying())
+ {
if regs_ptr_current.contains(src1) && !regs_ptr_seen.contains(src1) {
- regs_ptr_new.insert(arg.dst);
+ regs_ptr_new.insert(dst);
}
- } else if let Some(src2) = arg.src2.underlying() {
+ } else if let (TypedOperand::Reg(dst), Some(src2)) =
+ (arg.dst, arg.src2.upcast().underlying())
+ {
if regs_ptr_current.contains(src2) && !regs_ptr_seen.contains(src2) {
- regs_ptr_new.insert(arg.dst);
+ regs_ptr_new.insert(dst);
}
}
}
@@ -4426,19 +4360,20 @@ fn convert_to_stateful_memory_access<'a>(
}),
arg,
)) if is_add_ptr_direct(&remapped_ids, &arg) => {
- let (ptr, offset) = match arg.src1.underlying() {
+ let (ptr, offset) = match arg.src1.upcast().underlying() {
Some(src1) if remapped_ids.contains_key(src1) => {
(remapped_ids.get(src1).unwrap(), arg.src2)
}
Some(src2) if remapped_ids.contains_key(src2) => {
(remapped_ids.get(src2).unwrap(), arg.src1)
}
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
};
+ let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
state_space: ast::LdStateSpace::Global,
- dst: *remapped_ids.get(&arg.dst).unwrap(),
+ dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
offset_src: offset,
}))
@@ -4454,14 +4389,14 @@ fn convert_to_stateful_memory_access<'a>(
}),
arg,
)) if is_add_ptr_direct(&remapped_ids, &arg) => {
- let (ptr, offset) = match arg.src1.underlying() {
+ let (ptr, offset) = match arg.src1.upcast().underlying() {
Some(src1) if remapped_ids.contains_key(src1) => {
(remapped_ids.get(src1).unwrap(), arg.src2)
}
Some(src2) if remapped_ids.contains_key(src2) => {
(remapped_ids.get(src2).unwrap(), arg.src1)
}
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
};
let offset_neg =
id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64)));
@@ -4472,21 +4407,23 @@ fn convert_to_stateful_memory_access<'a>(
},
ast::Arg2 {
src: offset,
- dst: offset_neg,
+ dst: TypedOperand::Reg(offset_neg),
},
)));
+ let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
state_space: ast::LdStateSpace::Global,
- dst: *remapped_ids.get(&arg.dst).unwrap(),
+ dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
- offset_src: ast::Operand::Reg(offset_neg),
+ offset_src: TypedOperand::Reg(offset_neg),
}))
}
Statement::Instruction(inst) => {
let mut post_statements = Vec::new();
- let new_statement = inst.visit_variable(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>, expected_type| {
+ let new_statement = inst.visit(
+ &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
+ expected_type: Option<&ast::Type>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
@@ -4499,14 +4436,13 @@ fn convert_to_stateful_memory_access<'a>(
},
)?;
result.push(new_statement);
- for s in post_statements {
- result.push(s);
- }
+ result.extend(post_statements);
}
Statement::Call(call) => {
let mut post_statements = Vec::new();
- let new_statement = call.visit_variable(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>, expected_type| {
+ let new_statement = call.visit(
+ &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
+ expected_type: Option<&ast::Type>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
@@ -4519,11 +4455,28 @@ fn convert_to_stateful_memory_access<'a>(
},
)?;
result.push(new_statement);
- for s in post_statements {
- result.push(s);
- }
+ result.extend(post_statements);
+ }
+ Statement::RepackVector(pack) => {
+ let mut post_statements = Vec::new();
+ let new_statement = pack.visit(
+ &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
+ expected_type: Option<&ast::Type>| {
+ convert_to_stateful_memory_access_postprocess(
+ id_defs,
+ &remapped_ids,
+ &func_args_ptr,
+ &mut result,
+ &mut post_statements,
+ arg_desc,
+ expected_type,
+ )
+ },
+ )?;
+ result.push(new_statement);
+ result.extend(post_statements);
}
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
}
}
for arg in func_args.input.iter_mut() {
@@ -4588,7 +4541,7 @@ fn convert_to_stateful_memory_access_postprocess(
None => match func_args_ptr.get(&arg_desc.op) {
Some(new_id) => {
if arg_desc.is_dst {
- return Err(TranslateError::Unreachable);
+ return Err(error_unreachable());
}
// We skip conversion here to trigger PtrAcces in a later pass
let old_type = match expected_type {
@@ -4617,13 +4570,20 @@ fn convert_to_stateful_memory_access_postprocess(
}
fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {
- if !remapped_ids.contains_key(&arg.dst) {
- return false;
- }
- match arg.src1.underlying() {
- Some(src1) if remapped_ids.contains_key(src1) => true,
- Some(src2) if remapped_ids.contains_key(src2) => true,
- _ => false,
+ match arg.dst {
+ TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
+ return false
+ }
+ TypedOperand::Reg(dst) => {
+ if !remapped_ids.contains_key(&dst) {
+ return false;
+ }
+ match arg.src1.upcast().underlying() {
+ Some(src1) if remapped_ids.contains_key(src1) => true,
+ Some(src2) if remapped_ids.contains_key(src2) => true,
+ _ => false,
+ }
+ }
}
}
@@ -4962,14 +4922,13 @@ enum Statement<I, P: ast::ArgParams> {
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
Call(ResolvedCall<P>),
- LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
- StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
- Composite(CompositeRead),
+ LoadVar(LoadVarDetails),
+ StoreVar(StoreVarDetails),
Conversion(ImplicitConversion),
Constant(ConstantDefinition),
RetValue(ast::RetData, spirv::Word),
- Undef(ast::Type, spirv::Word),
PtrAccess(PtrAccess<P>),
+ RepackVector(RepackVectorDetails),
}
impl ExpandedStatement {
@@ -4981,19 +4940,19 @@ impl ExpandedStatement {
Statement::Variable(var)
}
Statement::Instruction(inst) => inst
- .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| {
+ .visit(&mut |arg: ArgumentDescriptor<_>, _: Option<&ast::Type>| {
Ok(f(arg.op, arg.is_dst))
})
.unwrap(),
- Statement::LoadVar(mut arg, typ) => {
- arg.dst = f(arg.dst, true);
- arg.src = f(arg.src, false);
- Statement::LoadVar(arg, typ)
+ Statement::LoadVar(mut details) => {
+ details.arg.dst = f(details.arg.dst, true);
+ details.arg.src = f(details.arg.src, false);
+ Statement::LoadVar(details)
}
- Statement::StoreVar(mut arg, typ) => {
- arg.src1 = f(arg.src1, false);
- arg.src2 = f(arg.src2, false);
- Statement::StoreVar(arg, typ)
+ Statement::StoreVar(mut details) => {
+ details.arg.src1 = f(details.arg.src1, false);
+ details.arg.src2 = f(details.arg.src2, false);
+ Statement::StoreVar(details)
}
Statement::Call(mut call) => {
for (id, typ) in call.ret_params.iter_mut() {
@@ -5010,11 +4969,6 @@ impl ExpandedStatement {
}
Statement::Call(call)
}
- Statement::Composite(mut composite) => {
- composite.dst = f(composite.dst, true);
- composite.src_composite = f(composite.src_composite, false);
- Statement::Composite(composite)
- }
Statement::Conditional(mut conditional) => {
conditional.predicate = f(conditional.predicate, false);
conditional.if_true = f(conditional.if_true, false);
@@ -5034,10 +4988,6 @@ impl ExpandedStatement {
let id = f(id, false);
Statement::RetValue(data, id)
}
- Statement::Undef(typ, id) => {
- let id = f(id, true);
- Statement::Undef(typ, id)
- }
Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
@@ -5056,19 +5006,100 @@ impl ExpandedStatement {
offset_src: constant_src,
})
}
+ Statement::RepackVector(_) => todo!(),
}
}
}
+struct LoadVarDetails {
+ arg: ast::Arg2<ExpandedArgParams>,
+ typ: ast::Type,
+ // (index, vector_width)
+ // HACK ALERT
+ // For some reason IGC explodes when you try to load from builtin vectors
+ // using OpInBoundsAccessChain, the one true way to do it is to
+ // OpLoad+OpCompositeExtract
+ member_index: Option<(u8, Option<u8>)>,
+}
+
+struct StoreVarDetails {
+ arg: ast::Arg2St<ExpandedArgParams>,
+ typ: ast::Type,
+ member_index: Option<u8>,
+}
+
+struct RepackVectorDetails {
+ is_extract: bool,
+ typ: ast::ScalarType,
+ packed: spirv::Word,
+ unpacked: Vec<spirv::Word>,
+ vector_sema: ArgumentSemantics,
+}
+
+impl RepackVectorDetails {
+ fn map<
+ From: ArgParamsEx<Id = spirv::Word>,
+ To: ArgParamsEx<Id = spirv::Word>,
+ V: ArgumentMapVisitor<From, To>,
+ >(
+ self,
+ visitor: &mut V,
+ ) -> Result<RepackVectorDetails, TranslateError> {
+ let scalar = visitor.id(
+ ArgumentDescriptor {
+ op: self.packed,
+ is_dst: !self.is_extract,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)),
+ )?;
+ let scalar_type = self.typ;
+ let is_extract = self.is_extract;
+ let vector_sema = self.vector_sema;
+ let vector = self
+ .unpacked
+ .into_iter()
+ .map(|id| {
+ visitor.id(
+ ArgumentDescriptor {
+ op: id,
+ is_dst: is_extract,
+ sema: vector_sema,
+ },
+ Some(&ast::Type::Scalar(scalar_type)),
+ )
+ })
+ .collect::<Result<_, _>>()?;
+ Ok(RepackVectorDetails {
+ is_extract,
+ typ: self.typ,
+ packed: scalar,
+ unpacked: vector,
+ vector_sema,
+ })
+ }
+}
+
+impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitable<T, U>
+ for RepackVectorDetails
+{
+ fn visit(
+ self,
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::RepackVector(self.map::<_, _, _>(visitor)?))
+ }
+}
+
struct ResolvedCall<P: ast::ArgParams> {
pub uniform: bool,
- pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>,
- pub func: spirv::Word,
- pub param_list: Vec<(P::CallOperand, ast::FnArgumentType)>,
+ pub ret_params: Vec<(P::Id, ast::FnArgumentType)>,
+ pub func: P::Id,
+ pub param_list: Vec<(P::Operand, ast::FnArgumentType)>,
}
impl<T: ast::ArgParams> ResolvedCall<T> {
- fn cast<U: ast::ArgParams<CallOperand = T::CallOperand>>(self) -> ResolvedCall<U> {
+ fn cast<U: ast::ArgParams<Id = T::Id, Operand = T::Operand>>(self) -> ResolvedCall<U> {
ResolvedCall {
uniform: self.uniform,
ret_params: self.ret_params,
@@ -5110,7 +5141,7 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
.param_list
.into_iter()
.map::<Result<_, TranslateError>, _>(|(id, typ)| {
- let new_id = visitor.src_call_operand(
+ let new_id = visitor.operand(
ArgumentDescriptor {
op: id,
is_dst: false,
@@ -5130,32 +5161,14 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
}
}
-impl VisitVariable for ResolvedCall<TypedArgParams> {
- fn visit_variable<
- 'a,
- F: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
- >(
- self,
- f: &mut F,
- ) -> Result<TypedStatement, TranslateError> {
- Ok(Statement::Call(self.map(f)?))
- }
-}
-
-impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
- fn visit_variable_extended<
- F: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
- >(
+impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitable<T, U>
+ for ResolvedCall<T>
+{
+ fn visit(
self,
- f: &mut F,
- ) -> Result<ExpandedStatement, TranslateError> {
- Ok(Statement::Call(self.map(f)?))
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::Call(self.map(visitor)?))
}
}
@@ -5208,18 +5221,14 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
}
}
-impl VisitVariable for PtrAccess<TypedArgParams> {
- fn visit_variable<
- 'a,
- F: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
- >(
+impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitable<T, U>
+ for PtrAccess<T>
+{
+ fn visit(
self,
- f: &mut F,
- ) -> Result<TypedStatement, TranslateError> {
- Ok(Statement::PtrAccess(self.map(f)?))
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::PtrAccess(self.map(visitor)?))
}
}
@@ -5244,10 +5253,6 @@ enum NormalizedArgParams {}
impl ast::ArgParams for NormalizedArgParams {
type Id = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
- type CallOperand = ast::CallOperand<spirv::Word>;
- type IdOrVector = ast::IdOrVector<spirv::Word>;
- type OperandOrVector = ast::OperandOrVector<spirv::Word>;
- type SrcMemberOperand = (spirv::Word, u8);
}
impl ArgParamsEx for NormalizedArgParams {
@@ -5273,11 +5278,7 @@ enum TypedArgParams {}
impl ast::ArgParams for TypedArgParams {
type Id = spirv::Word;
- type Operand = ast::Operand<spirv::Word>;
- type CallOperand = ast::CallOperand<spirv::Word>;
- type IdOrVector = ast::IdOrVector<spirv::Word>;
- type OperandOrVector = ast::OperandOrVector<spirv::Word>;
- type SrcMemberOperand = (spirv::Word, u8);
+ type Operand = TypedOperand;
}
impl ArgParamsEx for TypedArgParams {
@@ -5289,6 +5290,25 @@ impl ArgParamsEx for TypedArgParams {
}
}
+#[derive(Copy, Clone)]
+enum TypedOperand {
+ Reg(spirv::Word),
+ RegOffset(spirv::Word, i32),
+ Imm(ast::ImmediateValue),
+ VecMember(spirv::Word, u8),
+}
+
+impl TypedOperand {
+ fn upcast(self) -> ast::Operand<spirv::Word> {
+ match self {
+ TypedOperand::Reg(reg) => ast::Operand::Reg(reg),
+ TypedOperand::RegOffset(reg, idx) => ast::Operand::RegOffset(reg, idx),
+ TypedOperand::Imm(x) => ast::Operand::Imm(x),
+ TypedOperand::VecMember(vec, idx) => ast::Operand::VecMember(vec, idx),
+ }
+ }
+}
+
type TypedStatement = Statement<ast::Instruction<TypedArgParams>, TypedArgParams>;
enum ExpandedArgParams {}
@@ -5297,10 +5317,6 @@ type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, Expanded
impl ast::ArgParams for ExpandedArgParams {
type Id = spirv::Word;
type Operand = spirv::Word;
- type CallOperand = spirv::Word;
- type IdOrVector = spirv::Word;
- type OperandOrVector = spirv::Word;
- type SrcMemberOperand = spirv::Word;
}
impl ArgParamsEx for ExpandedArgParams {
@@ -5312,29 +5328,6 @@ impl ArgParamsEx for ExpandedArgParams {
}
}
-#[derive(Copy, Clone)]
-pub enum StateSpace {
- Reg,
- Const,
- Global,
- Local,
- Shared,
- Param,
-}
-
-impl From<ast::StateSpace> for StateSpace {
- fn from(ss: ast::StateSpace) -> Self {
- match ss {
- ast::StateSpace::Reg => StateSpace::Reg,
- ast::StateSpace::Const => StateSpace::Const,
- ast::StateSpace::Global => StateSpace::Global,
- ast::StateSpace::Local => StateSpace::Local,
- ast::StateSpace::Shared => StateSpace::Shared,
- ast::StateSpace::Param => StateSpace::Param,
- }
- }
-}
-
enum Directive<'input> {
Variable(ast::Variable<ast::VariableType, spirv::Word>),
Method(Function<'input>),
@@ -5359,26 +5352,6 @@ pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
desc: ArgumentDescriptor<T::Operand>,
typ: &ast::Type,
) -> Result<U::Operand, TranslateError>;
- fn id_or_vector(
- &mut self,
- desc: ArgumentDescriptor<T::IdOrVector>,
- typ: &ast::Type,
- ) -> Result<U::IdOrVector, TranslateError>;
- fn operand_or_vector(
- &mut self,
- desc: ArgumentDescriptor<T::OperandOrVector>,
- typ: &ast::Type,
- ) -> Result<U::OperandOrVector, TranslateError>;
- fn src_call_operand(
- &mut self,
- desc: ArgumentDescriptor<T::CallOperand>,
- typ: &ast::Type,
- ) -> Result<U::CallOperand, TranslateError>;
- fn src_member_operand(
- &mut self,
- desc: ArgumentDescriptor<T::SrcMemberOperand>,
- typ: (ast::ScalarType, u8),
- ) -> Result<U::SrcMemberOperand, TranslateError>;
}
impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
@@ -5399,42 +5372,10 @@ where
fn operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- self(desc, Some(t))
- }
-
- fn id_or_vector(
- &mut self,
- desc: ArgumentDescriptor<spirv::Word>,
typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
self(desc, Some(typ))
}
-
- fn operand_or_vector(
- &mut self,
- desc: ArgumentDescriptor<spirv::Word>,
- typ: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- self(desc, Some(typ))
- }
-
- fn src_call_operand(
- &mut self,
- desc: ArgumentDescriptor<spirv::Word>,
- t: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- self(desc, Some(t))
- }
-
- fn src_member_operand(
- &mut self,
- desc: ArgumentDescriptor<spirv::Word>,
- (scalar_type, _): (ast::ScalarType, u8),
- ) -> Result<spirv::Word, TranslateError> {
- self(desc.new_op(desc.op), Some(&ast::Type::Scalar(scalar_type)))
- }
}
impl<'a, T> ArgumentMapVisitor<ast::ParsedArgParams<'a>, NormalizedArgParams> for T
@@ -5452,62 +5393,19 @@ where
fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
- _: &ast::Type,
+ typ: &ast::Type,
) -> Result<ast::Operand<spirv::Word>, TranslateError> {
- match desc.op {
- ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(id)?)),
- ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset(self(id)?, imm)),
- ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)),
- }
- }
-
- fn id_or_vector(
- &mut self,
- desc: ArgumentDescriptor<ast::IdOrVector<&'a str>>,
- _: &ast::Type,
- ) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
- match desc.op {
- ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(id)?)),
- ast::IdOrVector::Vec(ids) => Ok(ast::IdOrVector::Vec(
- ids.into_iter().map(self).collect::<Result<_, _>>()?,
- )),
- }
- }
-
- fn operand_or_vector(
- &mut self,
- desc: ArgumentDescriptor<ast::OperandOrVector<&'a str>>,
- _: &ast::Type,
- ) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
- match desc.op {
- ast::OperandOrVector::Reg(id) => Ok(ast::OperandOrVector::Reg(self(id)?)),
- ast::OperandOrVector::RegOffset(id, imm) => {
- Ok(ast::OperandOrVector::RegOffset(self(id)?, imm))
- }
- ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)),
- ast::OperandOrVector::Vec(ids) => Ok(ast::OperandOrVector::Vec(
- ids.into_iter().map(self).collect::<Result<_, _>>()?,
- )),
- }
- }
-
- fn src_call_operand(
- &mut self,
- desc: ArgumentDescriptor<ast::CallOperand<&str>>,
- _: &ast::Type,
- ) -> Result<ast::CallOperand<spirv::Word>, TranslateError> {
- match desc.op {
- ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(id)?)),
- ast::CallOperand::Imm(imm) => Ok(ast::CallOperand::Imm(imm)),
- }
- }
-
- fn src_member_operand(
- &mut self,
- desc: ArgumentDescriptor<(&str, u8)>,
- _: (ast::ScalarType, u8),
- ) -> Result<(spirv::Word, u8), TranslateError> {
- Ok((self(desc.op.0)?, desc.op.1))
+ Ok(match desc.op {
+ ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?),
+ ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id)?, imm),
+ ast::Operand::Imm(imm) => ast::Operand::Imm(imm),
+ ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member),
+ ast::Operand::VecPack(ref ids) => ast::Operand::VecPack(
+ ids.into_iter()
+ .map(|id| self.id(desc.new_op(id), Some(typ)))
+ .collect::<Result<Vec<_>, _>>()?,
+ ),
+ })
}
}
@@ -5559,7 +5457,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Abs(d, arg.map(visitor, &ast::Type::Scalar(d.typ))?)
}
// Call instruction is converted to a call statement early on
- ast::Instruction::Call(_) => return Err(TranslateError::Unreachable),
+ ast::Instruction::Call(_) => return Err(error_unreachable()),
ast::Instruction::Ld(d, a) => {
let new_args = a.map(visitor, &d)?;
ast::Instruction::Ld(d, new_args)
@@ -5752,18 +5650,12 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
}
}
-impl VisitVariable for ast::Instruction<TypedArgParams> {
- fn visit_variable<
- 'a,
- F: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
- >(
+impl<T: ArgParamsEx, U: ArgParamsEx> Visitable<T, U> for ast::Instruction<T> {
+ fn visit(
self,
- f: &mut F,
- ) -> Result<TypedStatement, TranslateError> {
- Ok(Statement::Instruction(self.map(f)?))
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::Instruction(self.map(visitor)?))
}
}
@@ -5802,32 +5694,14 @@ impl ImplicitConversion {
}
}
-impl VisitVariable for ImplicitConversion {
- fn visit_variable<
- 'a,
- F: FnMut(
- ArgumentDescriptor<spirv_headers::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv_headers::Word, TranslateError>,
- >(
- self,
- f: &mut F,
- ) -> Result<TypedStatement, TranslateError> {
- self.map(f)
- }
-}
-
-impl VisitVariableExpanded for ImplicitConversion {
- fn visit_variable_extended<
- F: FnMut(
- ArgumentDescriptor<spirv_headers::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv_headers::Word, TranslateError>,
- >(
+impl<From: ArgParamsEx<Id = spirv::Word>, To: ArgParamsEx<Id = spirv::Word>> Visitable<From, To>
+ for ImplicitConversion
+{
+ fn visit(
self,
- f: &mut F,
- ) -> Result<ExpandedStatement, TranslateError> {
- self.map(f)
+ visitor: &mut impl ArgumentMapVisitor<From, To>,
+ ) -> Result<Statement<ast::Instruction<To>, To>, TranslateError> {
+ Ok(self.map(visitor)?)
}
}
@@ -5848,79 +5722,24 @@ where
fn operand(
&mut self,
- desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
- t: &ast::Type,
- ) -> Result<ast::Operand<spirv::Word>, TranslateError> {
- match desc.op {
- ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(desc.new_op(id), Some(t))?)),
- ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)),
- ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset(
- self(desc.new_op(id), Some(t))?,
- imm,
- )),
- }
- }
-
- fn src_call_operand(
- &mut self,
- desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
- t: &ast::Type,
- ) -> Result<ast::CallOperand<spirv::Word>, TranslateError> {
- match desc.op {
- ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(desc.new_op(id), Some(t))?)),
- ast::CallOperand::Imm(imm) => Ok(ast::CallOperand::Imm(imm)),
- }
- }
-
- fn id_or_vector(
- &mut self,
- desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
+ desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
- ) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
- match desc.op {
- ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(desc.new_op(id), Some(typ))?)),
- ast::IdOrVector::Vec(ref ids) => Ok(ast::IdOrVector::Vec(
- ids.iter()
- .map(|id| self(desc.new_op(*id), Some(typ)))
- .collect::<Result<_, _>>()?,
- )),
- }
- }
-
- fn operand_or_vector(
- &mut self,
- desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
- typ: &ast::Type,
- ) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
- match desc.op {
- ast::OperandOrVector::Reg(id) => {
- Ok(ast::OperandOrVector::Reg(self(desc.new_op(id), Some(typ))?))
+ ) -> Result<TypedOperand, TranslateError> {
+ Ok(match desc.op {
+ TypedOperand::Reg(id) => TypedOperand::Reg(self(desc.new_op(id), Some(typ))?),
+ TypedOperand::Imm(imm) => TypedOperand::Imm(imm),
+ TypedOperand::RegOffset(id, imm) => {
+ TypedOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm)
+ }
+ TypedOperand::VecMember(reg, index) => {
+ let scalar_type = match typ {
+ ast::Type::Scalar(scalar_t) => *scalar_t,
+ _ => return Err(error_unreachable()),
+ };
+ let vec_type = ast::Type::Vector(scalar_type, index + 1);
+ TypedOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index)
}
- ast::OperandOrVector::RegOffset(id, imm) => Ok(ast::OperandOrVector::RegOffset(
- self(desc.new_op(id), Some(typ))?,
- imm,
- )),
- ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)),
- ast::OperandOrVector::Vec(ref ids) => Ok(ast::OperandOrVector::Vec(
- ids.iter()
- .map(|id| self(desc.new_op(*id), Some(typ)))
- .collect::<Result<_, _>>()?,
- )),
- }
- }
-
- fn src_member_operand(
- &mut self,
- desc: ArgumentDescriptor<(spirv::Word, u8)>,
- (scalar_type, vector_len): (ast::ScalarType, u8),
- ) -> Result<(spirv::Word, u8), TranslateError> {
- Ok((
- self(
- desc.new_op(desc.op.0),
- Some(&ast::Type::Vector(scalar_type.into(), vector_len)),
- )?,
- desc.op.1,
- ))
+ })
}
}
@@ -5942,7 +5761,7 @@ impl ast::Type {
kind,
)))
}
- _ => Err(TranslateError::Unreachable),
+ _ => Err(error_unreachable()),
}
}
@@ -6182,67 +6001,9 @@ impl ast::Instruction<ExpandedArgParams> {
}
}
-impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
- fn visit_variable_extended<
- F: FnMut(
- ArgumentDescriptor<spirv_headers::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv_headers::Word, TranslateError>,
- >(
- self,
- f: &mut F,
- ) -> Result<ExpandedStatement, TranslateError> {
- Ok(Statement::Instruction(self.map(f)?))
- }
-}
-
type Arg2 = ast::Arg2<ExpandedArgParams>;
type Arg2St = ast::Arg2St<ExpandedArgParams>;
-struct CompositeRead {
- pub typ: ast::ScalarType,
- pub dst: spirv::Word,
- pub dst_semantics_override: Option<ArgumentSemantics>,
- pub src_composite: spirv::Word,
- pub src_index: u32,
- pub src_len: u32,
-}
-
-impl VisitVariableExpanded for CompositeRead {
- fn visit_variable_extended<
- F: FnMut(
- ArgumentDescriptor<spirv_headers::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv_headers::Word, TranslateError>,
- >(
- self,
- f: &mut F,
- ) -> Result<ExpandedStatement, TranslateError> {
- let dst_sema = self
- .dst_semantics_override
- .unwrap_or(ArgumentSemantics::Default);
- Ok(Statement::Composite(CompositeRead {
- dst: f(
- ArgumentDescriptor {
- op: self.dst,
- is_dst: true,
- sema: dst_sema,
- },
- Some(&ast::Type::Scalar(self.typ)),
- )?,
- src_composite: f(
- ArgumentDescriptor {
- op: self.src_composite,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- Some(&ast::Type::Vector(self.typ, self.src_len as u8)),
- )?,
- ..self
- }))
- }
-}
-
struct ConstantDefinition {
pub dst: spirv::Word,
pub typ: ast::ScalarType,
@@ -6330,10 +6091,6 @@ impl From<ast::KernelArgumentType> for ast::Type {
}
impl<T: ArgParamsEx> ast::Arg1<T> {
- fn cast<U: ArgParamsEx<Id = T::Id>>(self) -> ast::Arg1<U> {
- ast::Arg1 { src: self.src }
- }
-
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -6352,10 +6109,6 @@ impl<T: ArgParamsEx> ast::Arg1<T> {
}
impl<T: ArgParamsEx> ast::Arg1Bar<T> {
- fn cast<U: ArgParamsEx<Operand = T::Operand>>(self) -> ast::Arg1Bar<U> {
- ast::Arg1Bar { src: self.src }
- }
-
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -6373,25 +6126,18 @@ impl<T: ArgParamsEx> ast::Arg1Bar<T> {
}
impl<T: ArgParamsEx> ast::Arg2<T> {
- fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg2<U> {
- ast::Arg2 {
- src: self.src,
- dst: self.dst,
- }
- }
-
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: &ast::Type,
) -> Result<ast::Arg2<U>, TranslateError> {
- let new_dst = visitor.id(
+ let new_dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(t),
+ t,
)?;
let new_src = visitor.operand(
ArgumentDescriptor {
@@ -6413,13 +6159,13 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
dst_t: &ast::Type,
src_t: &ast::Type,
) -> Result<ast::Arg2<U>, TranslateError> {
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(dst_t),
+ dst_t,
)?;
let src = visitor.operand(
ArgumentDescriptor {
@@ -6434,21 +6180,12 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
}
impl<T: ArgParamsEx> ast::Arg2Ld<T> {
- fn cast<U: ArgParamsEx<Operand = T::Operand, IdOrVector = T::IdOrVector>>(
- self,
- ) -> ast::Arg2Ld<U> {
- ast::Arg2Ld {
- dst: self.dst,
- src: self.src,
- }
- }
-
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
details: &ast::LdDetails,
) -> Result<ast::Arg2Ld<U>, TranslateError> {
- let dst = visitor.id_or_vector(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
@@ -6478,15 +6215,6 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
}
impl<T: ArgParamsEx> ast::Arg2St<T> {
- fn cast<U: ArgParamsEx<Operand = T::Operand, OperandOrVector = T::OperandOrVector>>(
- self,
- ) -> ast::Arg2St<U> {
- ast::Arg2St {
- src1: self.src1,
- src2: self.src2,
- }
- }
-
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -6509,7 +6237,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
details.state_space.to_ld_ss(),
),
)?;
- let src2 = visitor.operand_or_vector(
+ let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
@@ -6527,29 +6255,7 @@ impl<T: ArgParamsEx> ast::Arg2Mov<T> {
visitor: &mut V,
details: &ast::MovDetails,
) -> Result<ast::Arg2Mov<U>, TranslateError> {
- Ok(match self {
- ast::Arg2Mov::Normal(arg) => ast::Arg2Mov::Normal(arg.map(visitor, details)?),
- ast::Arg2Mov::Member(arg) => ast::Arg2Mov::Member(arg.map(visitor, details)?),
- })
- }
-}
-
-impl<P: ArgParamsEx> ast::Arg2MovNormal<P> {
- fn cast<U: ArgParamsEx<IdOrVector = P::IdOrVector, OperandOrVector = P::OperandOrVector>>(
- self,
- ) -> ast::Arg2MovNormal<U> {
- ast::Arg2MovNormal {
- dst: self.dst,
- src: self.src,
- }
- }
-
- fn map<U: ArgParamsEx, V: ArgumentMapVisitor<P, U>>(
- self,
- visitor: &mut V,
- details: &ast::MovDetails,
- ) -> Result<ast::Arg2MovNormal<U>, TranslateError> {
- let dst = visitor.id_or_vector(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
@@ -6557,7 +6263,7 @@ impl<P: ArgParamsEx> ast::Arg2MovNormal<P> {
},
&details.typ.clone().into(),
)?;
- let src = visitor.operand_or_vector(
+ let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
@@ -6569,144 +6275,11 @@ impl<P: ArgParamsEx> ast::Arg2MovNormal<P> {
},
&details.typ.clone().into(),
)?;
- Ok(ast::Arg2MovNormal { dst, src })
- }
-}
-
-impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
- fn cast<U: ArgParamsEx<Id = T::Id, SrcMemberOperand = T::SrcMemberOperand>>(
- self,
- ) -> ast::Arg2MovMember<U> {
- match self {
- ast::Arg2MovMember::Dst(dst, src1, src2) => ast::Arg2MovMember::Dst(dst, src1, src2),
- ast::Arg2MovMember::Src(dst, src) => ast::Arg2MovMember::Src(dst, src),
- ast::Arg2MovMember::Both(dst, src1, src2) => ast::Arg2MovMember::Both(dst, src1, src2),
- }
- }
-
- fn vector_dst(&self) -> Option<&T::Id> {
- match self {
- ast::Arg2MovMember::Src(_, _) => None,
- ast::Arg2MovMember::Dst((d, _), _, _) | ast::Arg2MovMember::Both((d, _), _, _) => {
- Some(d)
- }
- }
- }
-
- fn vector_src(&self) -> Option<&T::SrcMemberOperand> {
- match self {
- ast::Arg2MovMember::Src(_, d) | ast::Arg2MovMember::Both(_, _, d) => Some(d),
- ast::Arg2MovMember::Dst(_, _, _) => None,
- }
- }
-}
-
-impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
- fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
- self,
- visitor: &mut V,
- details: &ast::MovDetails,
- ) -> Result<ast::Arg2MovMember<U>, TranslateError> {
- match self {
- ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => {
- let scalar_type = details.typ.get_scalar()?;
- let dst = visitor.id(
- ArgumentDescriptor {
- op: dst,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(&ast::Type::Vector(scalar_type, details.dst_width)),
- )?;
- let src1 = visitor.id(
- ArgumentDescriptor {
- op: composite_src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- Some(&ast::Type::Vector(scalar_type, details.dst_width)),
- )?;
- let src2 = visitor.id(
- ArgumentDescriptor {
- op: scalar_src,
- is_dst: false,
- sema: if details.src_is_address {
- ArgumentSemantics::Address
- } else if details.relaxed_src2_conv {
- ArgumentSemantics::DefaultRelaxed
- } else {
- ArgumentSemantics::Default
- },
- },
- Some(&details.typ.clone().into()),
- )?;
- Ok(ast::Arg2MovMember::Dst((dst, len), src1, src2))
- }
- ast::Arg2MovMember::Src(dst, src) => {
- let dst = visitor.id(
- ArgumentDescriptor {
- op: dst,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(&details.typ.clone().into()),
- )?;
- let scalar_typ = details.typ.get_scalar()?;
- let src = visitor.src_member_operand(
- ArgumentDescriptor {
- op: src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- (scalar_typ.into(), details.src_width),
- )?;
- Ok(ast::Arg2MovMember::Src(dst, src))
- }
- ast::Arg2MovMember::Both((dst, len), composite_src, src) => {
- let scalar_type = details.typ.get_scalar()?;
- let dst = visitor.id(
- ArgumentDescriptor {
- op: dst,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(&ast::Type::Vector(scalar_type, details.dst_width)),
- )?;
- let composite_src = visitor.id(
- ArgumentDescriptor {
- op: composite_src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- Some(&ast::Type::Vector(scalar_type, details.dst_width)),
- )?;
- let src = visitor.src_member_operand(
- ArgumentDescriptor {
- op: src,
- is_dst: false,
- sema: if details.relaxed_src2_conv {
- ArgumentSemantics::DefaultRelaxed
- } else {
- ArgumentSemantics::Default
- },
- },
- (scalar_type.into(), details.src_width),
- )?;
- Ok(ast::Arg2MovMember::Both((dst, len), composite_src, src))
- }
- }
+ Ok(ast::Arg2Mov { dst, src })
}
}
impl<T: ArgParamsEx> ast::Arg3<T> {
- fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg3<U> {
- ast::Arg3 {
- dst: self.dst,
- src1: self.src1,
- src2: self.src2,
- }
- }
-
fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -6718,13 +6291,13 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
} else {
None
};
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(wide_type.as_ref().unwrap_or(typ)),
+ wide_type.as_ref().unwrap_or(typ),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6750,13 +6323,13 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
visitor: &mut V,
t: &ast::Type,
) -> Result<ast::Arg3<U>, TranslateError> {
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(t),
+ t,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6784,13 +6357,13 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
state_space: ast::AtomSpace,
) -> Result<ast::Arg3<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(&ast::Type::Scalar(scalar_type)),
+ &ast::Type::Scalar(scalar_type),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6816,15 +6389,6 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
}
impl<T: ArgParamsEx> ast::Arg4<T> {
- fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg4<U> {
- ast::Arg4 {
- dst: self.dst,
- src1: self.src1,
- src2: self.src2,
- src3: self.src3,
- }
- }
-
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -6836,13 +6400,13 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
} else {
None
};
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(wide_type.as_ref().unwrap_or(t)),
+ wide_type.as_ref().unwrap_or(t),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6881,13 +6445,13 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
visitor: &mut V,
t: ast::SelpType,
) -> Result<ast::Arg4<U>, TranslateError> {
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(&ast::Type::Scalar(t.into())),
+ &ast::Type::Scalar(t.into()),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6928,13 +6492,13 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
state_space: ast::AtomSpace,
) -> Result<ast::Arg4<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(&ast::Type::Scalar(scalar_type)),
+ &ast::Type::Scalar(scalar_type),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6976,13 +6540,13 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
visitor: &mut V,
typ: &ast::Type,
) -> Result<ast::Arg4<U>, TranslateError> {
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(typ),
+ typ,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -7019,15 +6583,6 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
}
impl<T: ArgParamsEx> ast::Arg4Setp<T> {
- fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg4Setp<U> {
- ast::Arg4Setp {
- dst1: self.dst1,
- dst2: self.dst2,
- src1: self.src1,
- src2: self.src2,
- }
- }
-
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -7079,22 +6634,12 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
}
}
-impl<T: ArgParamsEx> ast::Arg5<T> {
- fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg5<U> {
- ast::Arg5 {
- dst1: self.dst1,
- dst2: self.dst2,
- src1: self.src1,
- src2: self.src2,
- src3: self.src3,
- }
- }
-
+impl<T: ArgParamsEx> ast::Arg5Setp<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: &ast::Type,
- ) -> Result<ast::Arg5<U>, TranslateError> {
+ ) -> Result<ast::Arg5Setp<U>, TranslateError> {
let dst1 = visitor.id(
ArgumentDescriptor {
op: self.dst1,
@@ -7140,7 +6685,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
},
&ast::Type::Scalar(ast::ScalarType::Pred),
)?;
- Ok(ast::Arg5 {
+ Ok(ast::Arg5Setp {
dst1,
dst2,
src1,
@@ -7150,30 +6695,28 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
}
}
-impl ast::Type {
- fn get_vector(&self) -> Result<(ast::ScalarType, u8), TranslateError> {
- match self {
- ast::Type::Vector(t, len) => Ok((*t, *len)),
- _ => Err(TranslateError::MismatchedType),
- }
- }
-
- fn get_scalar(&self) -> Result<ast::ScalarType, TranslateError> {
- match self {
- ast::Type::Scalar(t) => Ok(*t),
- _ => Err(TranslateError::MismatchedType),
- }
- }
-}
-
-impl<T> ast::CallOperand<T> {
+impl<T> ast::Operand<T> {
fn map_variable<U, F: FnMut(T) -> Result<U, TranslateError>>(
self,
f: &mut F,
- ) -> Result<ast::CallOperand<U>, TranslateError> {
+ ) -> Result<ast::Operand<U>, TranslateError> {
+ Ok(match self {
+ ast::Operand::Reg(reg) => ast::Operand::Reg(f(reg)?),
+ ast::Operand::RegOffset(reg, offset) => ast::Operand::RegOffset(f(reg)?, offset),
+ ast::Operand::Imm(x) => ast::Operand::Imm(x),
+ ast::Operand::VecMember(reg, idx) => ast::Operand::VecMember(f(reg)?, idx),
+ ast::Operand::VecPack(vec) => {
+ ast::Operand::VecPack(vec.into_iter().map(f).collect::<Result<_, _>>()?)
+ }
+ })
+ }
+}
+
+impl ast::Operand<spirv::Word> {
+ fn unwrap_reg(&self) -> Result<spirv::Word, TranslateError> {
match self {
- ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(f(id)?)),
- ast::CallOperand::Imm(x) => Ok(ast::CallOperand::Imm(x)),
+ ast::Operand::Reg(reg) => Ok(*reg),
+ _ => Err(error_unreachable()),
}
}
}
@@ -7394,15 +6937,8 @@ impl<T> ast::Operand<T> {
match self {
ast::Operand::Reg(r) | ast::Operand::RegOffset(r, _) => Some(r),
ast::Operand::Imm(_) => None,
- }
- }
-}
-
-impl<T> ast::OperandOrVector<T> {
- fn single_underlying(&self) -> Option<&T> {
- match self {
- ast::OperandOrVector::Reg(r) | ast::OperandOrVector::RegOffset(r, _) => Some(r),
- ast::OperandOrVector::Imm(_) | ast::OperandOrVector::Vec(_) => None,
+ ast::Operand::VecMember(reg, _) => Some(reg),
+ ast::Operand::VecPack(..) => None,
}
}
}
@@ -7500,7 +7036,7 @@ fn bitcast_physical_pointer(
if let Some(space) = ss {
Ok(Some(ConversionKind::BitToPtr(space)))
} else {
- Err(TranslateError::Unreachable)
+ Err(error_unreachable())
}
}
ast::Type::Scalar(ast::ScalarType::B32)