diff options
Diffstat (limited to 'ptx')
-rw-r--r-- | ptx/src/ast.rs | 81 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 121 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/vector.spvtxt | 124 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/vector_extract.spvtxt | 183 | ||||
-rw-r--r-- | ptx/src/translate.rs | 1996 |
5 files changed, 993 insertions, 1512 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 367f060..aba6bda 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -557,7 +557,7 @@ pub enum Instruction<P: ArgParams> { Mul(MulDetails, Arg3<P>), Add(ArithDetails, Arg3<P>), Setp(SetpData, Arg4Setp<P>), - SetpBool(SetpBoolData, Arg5<P>), + SetpBool(SetpBoolData, Arg5Setp<P>), Not(BooleanType, Arg2<P>), Bra(BraData, Arg1<P>), Cvt(CvtDetails, Arg2<P>), @@ -614,16 +614,12 @@ pub struct CallInst<P: ArgParams> { pub uniform: bool, pub ret_params: Vec<P::Id>, pub func: P::Id, - pub param_list: Vec<P::CallOperand>, + pub param_list: Vec<P::Operand>, } pub trait ArgParams { type Id; type Operand; - type IdOrVector; - type OperandOrVector; - type CallOperand; - type SrcMemberOperand; } pub struct ParsedArgParams<'a> { @@ -633,10 +629,6 @@ pub struct ParsedArgParams<'a> { impl<'a> ArgParams for ParsedArgParams<'a> { type Id = &'a str; type Operand = Operand<&'a str>; - type CallOperand = CallOperand<&'a str>; - type IdOrVector = IdOrVector<&'a str>; - type OperandOrVector = OperandOrVector<&'a str>; - type SrcMemberOperand = (&'a str, u8); } pub struct Arg1<P: ArgParams> { @@ -648,45 +640,32 @@ pub struct Arg1Bar<P: ArgParams> { } pub struct Arg2<P: ArgParams> { - pub dst: P::Id, + pub dst: P::Operand, pub src: P::Operand, } pub struct Arg2Ld<P: ArgParams> { - pub dst: P::IdOrVector, + pub dst: P::Operand, pub src: P::Operand, } pub struct Arg2St<P: ArgParams> { pub src1: P::Operand, - pub src2: P::OperandOrVector, -} - -pub enum Arg2Mov<P: ArgParams> { - Normal(Arg2MovNormal<P>), - Member(Arg2MovMember<P>), -} - -pub struct Arg2MovNormal<P: ArgParams> { - pub dst: P::IdOrVector, - pub src: P::OperandOrVector, + pub src2: P::Operand, } -// We duplicate dst here because during further compilation -// composite dst and composite src will receive different ids -pub enum Arg2MovMember<P: ArgParams> { - Dst((P::Id, u8), P::Id, P::Id), - Src(P::Id, P::SrcMemberOperand), - Both((P::Id, u8), P::Id, P::SrcMemberOperand), +pub struct Arg2Mov<P: ArgParams> { + pub dst: P::Operand, + pub src: P::Operand, } pub struct Arg3<P: ArgParams> { - pub dst: P::Id, + pub dst: P::Operand, pub src1: P::Operand, pub src2: P::Operand, } pub struct Arg4<P: ArgParams> { - pub dst: P::Id, + pub dst: P::Operand, pub src1: P::Operand, pub src2: P::Operand, pub src3: P::Operand, @@ -699,7 +678,7 @@ pub struct Arg4Setp<P: ArgParams> { pub src2: P::Operand, } -pub struct Arg5<P: ArgParams> { +pub struct Arg5Setp<P: ArgParams> { pub dst1: P::Id, pub dst2: Option<P::Id>, pub src1: P::Operand, @@ -715,39 +694,13 @@ pub enum ImmediateValue { F64(f64), } -#[derive(Copy, Clone)] -pub enum Operand<ID> { - Reg(ID), - RegOffset(ID, i32), - Imm(ImmediateValue), -} - -#[derive(Copy, Clone)] -pub enum CallOperand<ID> { - Reg(ID), - Imm(ImmediateValue), -} - -pub enum IdOrVector<ID> { - Reg(ID), - Vec(Vec<ID>), -} - -pub enum OperandOrVector<ID> { - Reg(ID), - RegOffset(ID, i32), +#[derive(Clone)] +pub enum Operand<Id> { + Reg(Id), + RegOffset(Id, i32), Imm(ImmediateValue), - Vec(Vec<ID>), -} - -impl<T> From<Operand<T>> for OperandOrVector<T> { - fn from(this: Operand<T>) -> Self { - match this { - Operand::Reg(r) => OperandOrVector::Reg(r), - Operand::RegOffset(r, imm) => OperandOrVector::RegOffset(r, imm), - Operand::Imm(imm) => OperandOrVector::Imm(imm), - } - } + VecMember(Id, u8), + VecPack(Vec<Id>), } pub enum VectorPrefix { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index d2c235a..fd2a3f1 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -721,7 +721,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = { - "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:IdOrVector> "," <src:MemoryOperand> => { + "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:DstOperandVec> "," <src:MemoryOperand> => { ast::Instruction::Ld( ast::LdDetails { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), @@ -734,16 +734,6 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = { } }; -IdOrVector: ast::IdOrVector<&'input str> = { - <dst:ExtendedID> => ast::IdOrVector::Reg(dst), - <dst:VectorExtract> => ast::IdOrVector::Vec(dst) -} - -OperandOrVector: ast::OperandOrVector<&'input str> = { - <op:Operand> => ast::OperandOrVector::from(op), - <dst:VectorExtract> => ast::OperandOrVector::Vec(dst) -} - LdStType: ast::LdStType = { <v:VectorPrefix> <t:LdStScalarType> => ast::LdStType::Vector(t, v), <t:LdStScalarType> => ast::LdStType::Scalar(t), @@ -780,27 +770,17 @@ LdCacheOperator: ast::LdCacheOperator = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov InstMov: ast::Instruction<ast::ParsedArgParams<'input>> = { - <m:MovNormal> => ast::Instruction::Mov(m.0, m.1), - <m:MovVector> => ast::Instruction::Mov(m.0, m.1), -}; - - -MovNormal: (ast::MovDetails, ast::Arg2Mov<ast::ParsedArgParams<'input>>) = { - "mov" <t:MovScalarType> <dst:ExtendedID> "," <src:Operand> => {( - ast::MovDetails::new(ast::Type::Scalar(t)), - ast::Arg2Mov::Normal(ast::Arg2MovNormal{ dst: ast::IdOrVector::Reg(dst), src: src.into() }) - )}, - "mov" <pref:VectorPrefix> <t:MovVectorType> <dst:IdOrVector> "," <src:OperandOrVector> => {( - ast::MovDetails::new(ast::Type::Vector(t, pref)), - ast::Arg2Mov::Normal(ast::Arg2MovNormal{ dst: dst, src: src }) - )} -} - -MovVector: (ast::MovDetails, ast::Arg2Mov<ast::ParsedArgParams<'input>>) = { - "mov" <t:MovVectorType> <a:Arg2MovMember> => {( - ast::MovDetails::new(ast::Type::Scalar(t.into())), - ast::Arg2Mov::Member(a) - )}, + "mov" <pref:VectorPrefix?> <t:MovScalarType> <dst:DstOperandVec> "," <src:SrcOperandVec> => { + let mov_type = match pref { + Some(vec_width) => ast::Type::Vector(t, vec_width), + None => ast::Type::Scalar(t) + }; + let details = ast::MovDetails::new(mov_type); + ast::Instruction::Mov( + details, + ast::Arg2Mov { dst, src } + ) + } } #[inline] @@ -819,21 +799,6 @@ MovScalarType: ast::ScalarType = { ".pred" => ast::ScalarType::Pred }; -#[inline] -MovVectorType: ast::ScalarType = { - ".b16" => ast::ScalarType::B16, - ".b32" => ast::ScalarType::B32, - ".b64" => ast::ScalarType::B64, - ".u16" => ast::ScalarType::U16, - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => ast::ScalarType::S64, - ".f32" => ast::ScalarType::F32, - ".f64" => ast::ScalarType::F64, -}; - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul @@ -921,7 +886,7 @@ InstAdd: ast::Instruction<ast::ParsedArgParams<'input>> = { // TODO: support f16 setp InstSetp: ast::Instruction<ast::ParsedArgParams<'input>> = { "setp" <d:SetpMode> <a:Arg4Setp> => ast::Instruction::Setp(d, a), - "setp" <d:SetpBoolMode> <a:Arg5> => ast::Instruction::SetpBool(d, a), + "setp" <d:SetpBoolMode> <a:Arg5Setp> => ast::Instruction::SetpBool(d, a), }; SetpMode: ast::SetpData = { @@ -1198,7 +1163,7 @@ ShrType: ast::ShrType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // Warning: NVIDIA documentation is incorrect, you can specify scope only once InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = { - "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> <src1:MemoryOperand> "," <src2:OperandOrVector> => { + "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> <src1:MemoryOperand> "," <src2:SrcOperandVec> => { ast::Instruction::St( ast::StData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), @@ -1775,9 +1740,9 @@ Operand: ast::Operand<&'input str> = { <x:ImmediateValue> => ast::Operand::Imm(x) }; -CallOperand: ast::CallOperand<&'input str> = { - <r:ExtendedID> => ast::CallOperand::Reg(r), - <x:ImmediateValue> => ast::CallOperand::Imm(x) +CallOperand: ast::Operand<&'input str> = { + <r:ExtendedID> => ast::Operand::Reg(r), + <x:ImmediateValue> => ast::Operand::Imm(x) }; // TODO: start parsing whole constants sub-language: @@ -1825,13 +1790,7 @@ Arg1Bar: ast::Arg1Bar<ast::ParsedArgParams<'input>> = { }; Arg2: ast::Arg2<ast::ParsedArgParams<'input>> = { - <dst:ExtendedID> "," <src:Operand> => ast::Arg2{<>} -}; - -Arg2MovMember: ast::Arg2MovMember<ast::ParsedArgParams<'input>> = { - <dst:MemberOperand> "," <src:ExtendedID> => ast::Arg2MovMember::Dst(dst, dst.0, src), - <dst:ExtendedID> "," <src:MemberOperand> => ast::Arg2MovMember::Src(dst, src), - <dst:MemberOperand> "," <src:MemberOperand> => ast::Arg2MovMember::Both(dst, dst.0, src), + <dst:DstOperand> "," <src:Operand> => ast::Arg2{<>} }; MemberOperand: (&'input str, u8) = { @@ -1855,19 +1814,19 @@ VectorExtract: Vec<&'input str> = { }; Arg3: ast::Arg3<ast::ParsedArgParams<'input>> = { - <dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> => ast::Arg3{<>} + <dst:DstOperand> "," <src1:Operand> "," <src2:Operand> => ast::Arg3{<>} }; Arg3Atom: ast::Arg3<ast::ParsedArgParams<'input>> = { - <dst:ExtendedID> "," "[" <src1:Operand> "]" "," <src2:Operand> => ast::Arg3{<>} + <dst:DstOperand> "," "[" <src1:Operand> "]" "," <src2:Operand> => ast::Arg3{<>} }; Arg4: ast::Arg4<ast::ParsedArgParams<'input>> = { - <dst:ExtendedID> "," <src1:Operand> "," <src2:Operand> "," <src3:Operand> => ast::Arg4{<>} + <dst:DstOperand> "," <src1:Operand> "," <src2:Operand> "," <src3:Operand> => ast::Arg4{<>} }; Arg4Atom: ast::Arg4<ast::ParsedArgParams<'input>> = { - <dst:ExtendedID> "," "[" <src1:Operand> "]" "," <src2:Operand> "," <src3:Operand> => ast::Arg4{<>} + <dst:DstOperand> "," "[" <src1:Operand> "]" "," <src2:Operand> "," <src3:Operand> => ast::Arg4{<>} }; Arg4Setp: ast::Arg4Setp<ast::ParsedArgParams<'input>> = { @@ -1875,22 +1834,50 @@ Arg4Setp: ast::Arg4Setp<ast::ParsedArgParams<'input>> = { }; // TODO: pass src3 negation somewhere -Arg5: ast::Arg5<ast::ParsedArgParams<'input>> = { - <dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> "," "!"? <src3:Operand> => ast::Arg5{<>} +Arg5Setp: ast::Arg5Setp<ast::ParsedArgParams<'input>> = { + <dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> "," "!"? <src3:Operand> => ast::Arg5Setp{<>} }; -ArgCall: (Vec<&'input str>, &'input str, Vec<ast::CallOperand<&'input str>>) = { +ArgCall: (Vec<&'input str>, &'input str, Vec<ast::Operand<&'input str>>) = { "(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => { (ret_params, func, param_list) }, <func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => (Vec::new(), func, param_list), - <func:ExtendedID> => (Vec::new(), func, Vec::<ast::CallOperand<_>>::new()), + <func:ExtendedID> => (Vec::new(), func, Vec::<ast::Operand<_>>::new()), }; OptionalDst: &'input str = { "|" <dst2:ExtendedID> => dst2 } +SrcOperand: ast::Operand<&'input str> = { + <r:ExtendedID> => ast::Operand::Reg(r), + <r:ExtendedID> "+" <offset:S32Num> => ast::Operand::RegOffset(r, offset), + <x:ImmediateValue> => ast::Operand::Imm(x), + <mem_op:MemberOperand> => { + let (reg, idx) = mem_op; + ast::Operand::VecMember(reg, idx) + } +} + +SrcOperandVec: ast::Operand<&'input str> = { + <normal:SrcOperand> => normal, + <vec:VectorExtract> => ast::Operand::VecPack(vec), +} + +DstOperand: ast::Operand<&'input str> = { + <r:ExtendedID> => ast::Operand::Reg(r), + <mem_op:MemberOperand> => { + let (reg, idx) = mem_op; + ast::Operand::VecMember(reg, idx) + } +} + +DstOperandVec: ast::Operand<&'input str> = { + <normal:DstOperand> => normal, + <vec:VectorExtract> => ast::Operand::VecPack(vec), +} + VectorPrefix: u8 = { ".v2" => 2, ".v4" => 4 diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt index 535e480..a77ab7d 100644 --- a/ptx/src/test/spirv_run/vector.spvtxt +++ b/ptx/src/test/spirv_run/vector.spvtxt @@ -7,91 +7,93 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %57 = OpExtInstImport "OpenCL.std" + %51 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %31 "vector" + OpEntryPoint Kernel %25 "vector" %void = OpTypeVoid %uint = OpTypeInt 32 0 %v2uint = OpTypeVector %uint 2 - %61 = OpTypeFunction %v2uint %v2uint + %55 = OpTypeFunction %v2uint %v2uint %_ptr_Function_v2uint = OpTypePointer Function %v2uint %_ptr_Function_uint = OpTypePointer Function %uint + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 %ulong = OpTypeInt 64 0 - %65 = OpTypeFunction %void %ulong %ulong + %67 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Generic_v2uint = OpTypePointer Generic %v2uint - %1 = OpFunction %v2uint None %61 + %1 = OpFunction %v2uint None %55 %7 = OpFunctionParameter %v2uint - %30 = OpLabel + %24 = OpLabel %2 = OpVariable %_ptr_Function_v2uint Function %3 = OpVariable %_ptr_Function_v2uint Function %4 = OpVariable %_ptr_Function_v2uint Function %5 = OpVariable %_ptr_Function_uint Function %6 = OpVariable %_ptr_Function_uint Function OpStore %3 %7 - %9 = OpLoad %v2uint %3 - %27 = OpCompositeExtract %uint %9 0 - %8 = OpCopyObject %uint %27 + %59 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_0 + %9 = OpLoad %uint %59 + %8 = OpCopyObject %uint %9 OpStore %5 %8 - %11 = OpLoad %v2uint %3 - %28 = OpCompositeExtract %uint %11 1 - %10 = OpCopyObject %uint %28 + %61 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_1 + %11 = OpLoad %uint %61 + %10 = OpCopyObject %uint %11 OpStore %6 %10 %13 = OpLoad %uint %5 %14 = OpLoad %uint %6 %12 = OpIAdd %uint %13 %14 OpStore %6 %12 - %16 = OpLoad %v2uint %4 - %17 = OpLoad %uint %6 - %15 = OpCompositeInsert %v2uint %17 %16 0 - OpStore %4 %15 - %19 = OpLoad %v2uint %4 - %20 = OpLoad %uint %6 - %18 = OpCompositeInsert %v2uint %20 %19 1 - OpStore %4 %18 + %16 = OpLoad %uint %6 + %15 = OpCopyObject %uint %16 + %62 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0 + OpStore %62 %15 + %18 = OpLoad %uint %6 + %17 = OpCopyObject %uint %18 + %63 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1 + OpStore %63 %17 + %64 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1 + %20 = OpLoad %uint %64 + %19 = OpCopyObject %uint %20 + %65 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0 + OpStore %65 %19 %22 = OpLoad %v2uint %4 - %23 = OpLoad %v2uint %4 - %29 = OpCompositeExtract %uint %23 1 - %21 = OpCompositeInsert %v2uint %29 %22 0 - OpStore %4 %21 - %25 = OpLoad %v2uint %4 - %24 = OpCopyObject %v2uint %25 - OpStore %2 %24 - %26 = OpLoad %v2uint %2 - OpReturnValue %26 + %21 = OpCopyObject %v2uint %22 + OpStore %2 %21 + %23 = OpLoad %v2uint %2 + OpReturnValue %23 OpFunctionEnd - %31 = OpFunction %void None %65 - %40 = OpFunctionParameter %ulong - %41 = OpFunctionParameter %ulong - %55 = OpLabel - %32 = OpVariable %_ptr_Function_ulong Function + %25 = OpFunction %void None %67 + %34 = OpFunctionParameter %ulong + %35 = OpFunctionParameter %ulong + %49 = OpLabel + %26 = OpVariable %_ptr_Function_ulong Function + %27 = OpVariable %_ptr_Function_ulong Function + %28 = OpVariable %_ptr_Function_ulong Function + %29 = OpVariable %_ptr_Function_ulong Function + %30 = OpVariable %_ptr_Function_v2uint Function + %31 = OpVariable %_ptr_Function_uint Function + %32 = OpVariable %_ptr_Function_uint Function %33 = OpVariable %_ptr_Function_ulong Function - %34 = OpVariable %_ptr_Function_ulong Function - %35 = OpVariable %_ptr_Function_ulong Function - %36 = OpVariable %_ptr_Function_v2uint Function - %37 = OpVariable %_ptr_Function_uint Function - %38 = OpVariable %_ptr_Function_uint Function - %39 = OpVariable %_ptr_Function_ulong Function - OpStore %32 %40 - OpStore %33 %41 - %42 = OpLoad %ulong %32 - OpStore %34 %42 - %43 = OpLoad %ulong %33 - OpStore %35 %43 - %45 = OpLoad %ulong %34 - %52 = OpConvertUToPtr %_ptr_Generic_v2uint %45 - %44 = OpLoad %v2uint %52 - OpStore %36 %44 - %47 = OpLoad %v2uint %36 - %46 = OpFunctionCall %v2uint %1 %47 - OpStore %36 %46 - %49 = OpLoad %v2uint %36 - %53 = OpBitcast %ulong %49 - %48 = OpCopyObject %ulong %53 - OpStore %39 %48 - %50 = OpLoad %ulong %35 - %51 = OpLoad %v2uint %36 - %54 = OpConvertUToPtr %_ptr_Generic_v2uint %50 - OpStore %54 %51 + OpStore %26 %34 + OpStore %27 %35 + %36 = OpLoad %ulong %26 + OpStore %28 %36 + %37 = OpLoad %ulong %27 + OpStore %29 %37 + %39 = OpLoad %ulong %28 + %46 = OpConvertUToPtr %_ptr_Generic_v2uint %39 + %38 = OpLoad %v2uint %46 + OpStore %30 %38 + %41 = OpLoad %v2uint %30 + %40 = OpFunctionCall %v2uint %1 %41 + OpStore %30 %40 + %43 = OpLoad %v2uint %30 + %47 = OpBitcast %ulong %43 + %42 = OpCopyObject %ulong %47 + OpStore %33 %42 + %44 = OpLoad %ulong %29 + %45 = OpLoad %v2uint %30 + %48 = OpConvertUToPtr %_ptr_Generic_v2uint %44 + OpStore %48 %45 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector_extract.spvtxt b/ptx/src/test/spirv_run/vector_extract.spvtxt index 4943189..2037dec 100644 --- a/ptx/src/test/spirv_run/vector_extract.spvtxt +++ b/ptx/src/test/spirv_run/vector_extract.spvtxt @@ -7,12 +7,12 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %73 = OpExtInstImport "OpenCL.std" + %61 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "vector_extract" %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %76 = OpTypeFunction %void %ulong %ulong + %64 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %ushort = OpTypeInt 16 0 %_ptr_Function_ushort = OpTypePointer Function %ushort @@ -21,10 +21,10 @@ %uchar = OpTypeInt 8 0 %v4uchar = OpTypeVector %uchar 4 %_ptr_CrossWorkgroup_v4uchar = OpTypePointer CrossWorkgroup %v4uchar - %1 = OpFunction %void None %76 - %11 = OpFunctionParameter %ulong - %12 = OpFunctionParameter %ulong - %71 = OpLabel + %1 = OpFunction %void None %64 + %17 = OpFunctionParameter %ulong + %18 = OpFunctionParameter %ulong + %59 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -34,89 +34,92 @@ %8 = OpVariable %_ptr_Function_ushort Function %9 = OpVariable %_ptr_Function_ushort Function %10 = OpVariable %_ptr_Function_v4ushort Function - OpStore %2 %11 - OpStore %3 %12 - %13 = OpLoad %ulong %2 - OpStore %4 %13 - %14 = OpLoad %ulong %3 - OpStore %5 %14 - %19 = OpLoad %ulong %4 - %61 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %19 - %43 = OpLoad %v4uchar %61 - %62 = OpCompositeExtract %uchar %43 0 - %85 = OpBitcast %uchar %62 - %15 = OpUConvert %ushort %85 - %63 = OpCompositeExtract %uchar %43 1 - %86 = OpBitcast %uchar %63 - %16 = OpUConvert %ushort %86 - %64 = OpCompositeExtract %uchar %43 2 - %87 = OpBitcast %uchar %64 - %17 = OpUConvert %ushort %87 - %65 = OpCompositeExtract %uchar %43 3 - %88 = OpBitcast %uchar %65 - %18 = OpUConvert %ushort %88 - OpStore %6 %15 - OpStore %7 %16 - OpStore %8 %17 - OpStore %9 %18 - %21 = OpLoad %ushort %7 - %22 = OpLoad %ushort %8 - %23 = OpLoad %ushort %9 - %24 = OpLoad %ushort %6 - %44 = OpUndef %v4ushort - %45 = OpCompositeInsert %v4ushort %21 %44 0 - %46 = OpCompositeInsert %v4ushort %22 %45 1 - %47 = OpCompositeInsert %v4ushort %23 %46 2 - %48 = OpCompositeInsert %v4ushort %24 %47 3 - %20 = OpCopyObject %v4ushort %48 - OpStore %10 %20 - %29 = OpLoad %v4ushort %10 - %49 = OpCopyObject %v4ushort %29 - %25 = OpCompositeExtract %ushort %49 0 - %26 = OpCompositeExtract %ushort %49 1 - %27 = OpCompositeExtract %ushort %49 2 - %28 = OpCompositeExtract %ushort %49 3 - OpStore %8 %25 - OpStore %9 %26 - OpStore %6 %27 - OpStore %7 %28 - %34 = OpLoad %ushort %8 - %35 = OpLoad %ushort %9 - %36 = OpLoad %ushort %6 - %37 = OpLoad %ushort %7 - %51 = OpUndef %v4ushort - %52 = OpCompositeInsert %v4ushort %34 %51 0 - %53 = OpCompositeInsert %v4ushort %35 %52 1 - %54 = OpCompositeInsert %v4ushort %36 %53 2 - %55 = OpCompositeInsert %v4ushort %37 %54 3 - %50 = OpCopyObject %v4ushort %55 - %30 = OpCompositeExtract %ushort %50 0 - %31 = OpCompositeExtract %ushort %50 1 - %32 = OpCompositeExtract %ushort %50 2 - %33 = OpCompositeExtract %ushort %50 3 - OpStore %9 %30 - OpStore %6 %31 - OpStore %7 %32 - OpStore %8 %33 - %38 = OpLoad %ulong %5 - %39 = OpLoad %ushort %6 - %40 = OpLoad %ushort %7 - %41 = OpLoad %ushort %8 - %42 = OpLoad %ushort %9 - %56 = OpUndef %v4uchar - %89 = OpBitcast %ushort %39 - %66 = OpUConvert %uchar %89 - %57 = OpCompositeInsert %v4uchar %66 %56 0 - %90 = OpBitcast %ushort %40 - %67 = OpUConvert %uchar %90 - %58 = OpCompositeInsert %v4uchar %67 %57 1 - %91 = OpBitcast %ushort %41 - %68 = OpUConvert %uchar %91 - %59 = OpCompositeInsert %v4uchar %68 %58 2 - %92 = OpBitcast %ushort %42 - %69 = OpUConvert %uchar %92 - %60 = OpCompositeInsert %v4uchar %69 %59 3 - %70 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %38 - OpStore %70 %60 + OpStore %2 %17 + OpStore %3 %18 + %19 = OpLoad %ulong %2 + OpStore %4 %19 + %20 = OpLoad %ulong %3 + OpStore %5 %20 + %21 = OpLoad %ulong %4 + %49 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %21 + %11 = OpLoad %v4uchar %49 + %50 = OpCompositeExtract %uchar %11 0 + %51 = OpCompositeExtract %uchar %11 1 + %52 = OpCompositeExtract %uchar %11 2 + %53 = OpCompositeExtract %uchar %11 3 + %73 = OpBitcast %uchar %50 + %22 = OpUConvert %ushort %73 + %74 = OpBitcast %uchar %51 + %23 = OpUConvert %ushort %74 + %75 = OpBitcast %uchar %52 + %24 = OpUConvert %ushort %75 + %76 = OpBitcast %uchar %53 + %25 = OpUConvert %ushort %76 + OpStore %6 %22 + OpStore %7 %23 + OpStore %8 %24 + OpStore %9 %25 + %26 = OpLoad %ushort %7 + %27 = OpLoad %ushort %8 + %28 = OpLoad %ushort %9 + %29 = OpLoad %ushort %6 + %77 = OpUndef %v4ushort + %78 = OpCompositeInsert %v4ushort %26 %77 0 + %79 = OpCompositeInsert %v4ushort %27 %78 1 + %80 = OpCompositeInsert %v4ushort %28 %79 2 + %81 = OpCompositeInsert %v4ushort %29 %80 3 + %12 = OpCopyObject %v4ushort %81 + %30 = OpCopyObject %v4ushort %12 + OpStore %10 %30 + %31 = OpLoad %v4ushort %10 + %13 = OpCopyObject %v4ushort %31 + %32 = OpCompositeExtract %ushort %13 0 + %33 = OpCompositeExtract %ushort %13 1 + %34 = OpCompositeExtract %ushort %13 2 + %35 = OpCompositeExtract %ushort %13 3 + OpStore %8 %32 + OpStore %9 %33 + OpStore %6 %34 + OpStore %7 %35 + %36 = OpLoad %ushort %8 + %37 = OpLoad %ushort %9 + %38 = OpLoad %ushort %6 + %39 = OpLoad %ushort %7 + %82 = OpUndef %v4ushort + %83 = OpCompositeInsert %v4ushort %36 %82 0 + %84 = OpCompositeInsert %v4ushort %37 %83 1 + %85 = OpCompositeInsert %v4ushort %38 %84 2 + %86 = OpCompositeInsert %v4ushort %39 %85 3 + %15 = OpCopyObject %v4ushort %86 + %14 = OpCopyObject %v4ushort %15 + %40 = OpCompositeExtract %ushort %14 0 + %41 = OpCompositeExtract %ushort %14 1 + %42 = OpCompositeExtract %ushort %14 2 + %43 = OpCompositeExtract %ushort %14 3 + OpStore %9 %40 + OpStore %6 %41 + OpStore %7 %42 + OpStore %8 %43 + %44 = OpLoad %ushort %6 + %45 = OpLoad %ushort %7 + %46 = OpLoad %ushort %8 + %47 = OpLoad %ushort %9 + %87 = OpBitcast %ushort %44 + %54 = OpUConvert %uchar %87 + %88 = OpBitcast %ushort %45 + %55 = OpUConvert %uchar %88 + %89 = OpBitcast %ushort %46 + %56 = OpUConvert %uchar %89 + %90 = OpBitcast %ushort %47 + %57 = OpUConvert %uchar %90 + %91 = OpUndef %v4uchar + %92 = OpCompositeInsert %v4uchar %54 %91 0 + %93 = OpCompositeInsert %v4uchar %55 %92 1 + %94 = OpCompositeInsert %v4uchar %56 %93 2 + %95 = OpCompositeInsert %v4uchar %57 %94 3 + %16 = OpCopyObject %v4uchar %95 + %48 = OpLoad %ulong %5 + %58 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %48 + OpStore %58 %16 OpReturn OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 15211ab..20578eb 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -27,6 +27,16 @@ quick_error! { }
}
+#[cfg(debug_assertions)]
+fn error_unreachable() -> TranslateError {
+ unreachable!()
+}
+
+#[cfg(not(debug_assertions))]
+fn error_unreachable() -> TranslateError {
+ TranslateError::Unreachable
+}
+
#[derive(PartialEq, Eq, Hash, Clone)]
enum SpirvType {
Base(SpirvScalarKey),
@@ -82,7 +92,7 @@ impl ast::Type { ast::Type::Pointer(ast::PointerType::Scalar(t), space) => {
ast::Type::Pointer(ast::PointerType::Pointer(t, space), space)
}
- ast::Type::Pointer(_, _) => return Err(TranslateError::Unreachable),
+ ast::Type::Pointer(_, _) => return Err(error_unreachable()),
})
}
}
@@ -364,7 +374,7 @@ impl TypeWordMap { b.constant_composite(result_type, None, &components)
}
ast::Type::Array(typ, dims) => match dims.as_slice() {
- [] => return Err(TranslateError::Unreachable),
+ [] => return Err(error_unreachable()),
[dim] => {
let result_type = self
.get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim]));
@@ -791,13 +801,14 @@ fn convert_dynamic_shared_memory_usage<'input>( ast::PointerStateSpace::Shared,
)),
});
- let shared_var_st = ExpandedStatement::StoreVar(
- ast::Arg2St {
+ let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
+ arg: ast::Arg2St {
src1: shared_var_id,
src2: shared_id_param,
},
- ast::Type::Scalar(ast::ScalarType::B8),
- );
+ typ: ast::Type::Scalar(ast::ScalarType::B8),
+ member_index: None,
+ });
let mut new_statements = vec![shared_var, shared_var_st];
replace_uses_of_shared_memory(
&mut new_statements,
@@ -963,18 +974,17 @@ fn compute_denorm_information<'input>( denorm_count_map_update(&mut flush_counter, width, flush);
}
}
- Statement::LoadVar(_, _) => {}
- Statement::StoreVar(_, _) => {}
+ Statement::LoadVar(..) => {}
+ Statement::StoreVar(..) => {}
Statement::Call(_) => {}
- Statement::Composite(_) => {}
Statement::Conditional(_) => {}
Statement::Conversion(_) => {}
Statement::Constant(_) => {}
Statement::RetValue(_, _) => {}
- Statement::Undef(_, _) => {}
Statement::Label(_) => {}
Statement::Variable(_) => {}
Statement::PtrAccess { .. } => {}
+ Statement::RepackVector(_) => {}
}
}
denorm_methods.insert(method_key, flush_counter);
@@ -1307,7 +1317,7 @@ fn to_ssa<'input, 'b>( let mut numeric_id_defs = id_defs.finish();
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
- convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?;
+ convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
let typed_statements =
convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
@@ -1431,7 +1441,7 @@ fn normalize_variable_decls(directives: &mut Vec<Directive>) { fn convert_to_typed_statements(
func: Vec<UnconditionalStatement>,
fn_defs: &GlobalFnDeclResolver,
- id_defs: &NumericIdResolver,
+ id_defs: &mut NumericIdResolver,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::<TypedStatement>::with_capacity(func.len());
for s in func {
@@ -1447,7 +1457,7 @@ fn convert_to_typed_statements( .partition(|(_, arg_type)| arg_type.is_param());
let normalized_input_args = out_params
.into_iter()
- .map(|(id, typ)| (ast::CallOperand::Reg(id), typ))
+ .map(|(id, typ)| (ast::Operand::Reg(id), typ))
.chain(in_args.into_iter())
.collect();
let resolved_call = ResolvedCall {
@@ -1456,205 +1466,117 @@ fn convert_to_typed_statements( func: call.func,
param_list: normalized_input_args,
};
- result.push(Statement::Call(resolved_call));
- }
- ast::Instruction::Ld(d, arg) => {
- result.push(Statement::Instruction(ast::Instruction::Ld(d, arg.cast())));
- }
- ast::Instruction::St(d, arg) => {
- result.push(Statement::Instruction(ast::Instruction::St(d, arg.cast())));
- }
- ast::Instruction::Mov(mut d, args) => match args {
- ast::Arg2Mov::Normal(arg) => {
- if let Some(src_id) = arg.src.single_underlying() {
- let (typ, _) = id_defs.get_typed(*src_id)?;
- let take_address = match typ {
- ast::Type::Scalar(_) => false,
- ast::Type::Vector(_, _) => false,
- ast::Type::Array(_, _) => true,
- ast::Type::Pointer(_, _) => true,
- };
- d.src_is_address = take_address;
- }
- result.push(Statement::Instruction(ast::Instruction::Mov(
- d,
- ast::Arg2Mov::Normal(arg.cast()),
- )));
- }
- ast::Arg2Mov::Member(args) => {
- if let Some(dst_typ) = args.vector_dst() {
- match id_defs.get_typed(*dst_typ)? {
- (ast::Type::Vector(_, len), _) => {
- d.dst_width = len;
- }
- _ => return Err(TranslateError::MismatchedType),
- }
- };
- if let Some((src_typ, _)) = args.vector_src() {
- match id_defs.get_typed(*src_typ)? {
- (ast::Type::Vector(_, len), _) => {
- d.src_width = len;
- }
- _ => return Err(TranslateError::MismatchedType),
- }
+ let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
+ let reresolved_call = resolved_call.visit(&mut visitor)?;
+ visitor.func.push(reresolved_call);
+ visitor.func.extend(visitor.post_stmts);
+ }
+ ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => {
+ if let Some(src_id) = src.underlying() {
+ let (typ, _) = id_defs.get_typed(*src_id)?;
+ let take_address = match typ {
+ ast::Type::Scalar(_) => false,
+ ast::Type::Vector(_, _) => false,
+ ast::Type::Array(_, _) => true,
+ ast::Type::Pointer(_, _) => true,
};
- result.push(Statement::Instruction(ast::Instruction::Mov(
- d,
- ast::Arg2Mov::Member(args.cast()),
- )));
+ d.src_is_address = take_address;
}
- },
- ast::Instruction::Mul(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Mul(d, a.cast())))
- }
- ast::Instruction::Add(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Add(d, a.cast())))
- }
- ast::Instruction::Setp(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Setp(d, a.cast())))
- }
- ast::Instruction::SetpBool(d, a) => result.push(Statement::Instruction(
- ast::Instruction::SetpBool(d, a.cast()),
- )),
- ast::Instruction::Not(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Not(d, a.cast())))
- }
- ast::Instruction::Bra(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Bra(d, a.cast())))
- }
- ast::Instruction::Cvt(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Cvt(d, a.cast())))
- }
- ast::Instruction::Cvta(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Cvta(d, a.cast())))
- }
- ast::Instruction::Shl(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Shl(d, a.cast())))
- }
- ast::Instruction::Ret(d) => {
- result.push(Statement::Instruction(ast::Instruction::Ret(d)))
- }
- ast::Instruction::Abs(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Abs(d, a.cast())))
- }
- ast::Instruction::Mad(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Mad(d, a.cast())))
- }
- ast::Instruction::Shr(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Shr(d, a.cast())))
- }
- ast::Instruction::Or(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Or(d, a.cast())))
- }
- ast::Instruction::Sub(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Sub(d, a.cast())))
- }
- ast::Instruction::Min(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Min(d, a.cast())))
- }
- ast::Instruction::Max(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast())))
- }
- ast::Instruction::Rcp(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Rcp(d, a.cast())))
- }
- ast::Instruction::And(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::And(d, a.cast())))
- }
- ast::Instruction::Selp(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Selp(d, a.cast())))
- }
- ast::Instruction::Bar(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Bar(d, a.cast())))
- }
- ast::Instruction::Atom(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Atom(d, a.cast())))
- }
- ast::Instruction::AtomCas(d, a) => result.push(Statement::Instruction(
- ast::Instruction::AtomCas(d, a.cast()),
- )),
- ast::Instruction::Div(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Div(d, a.cast())))
- }
- ast::Instruction::Sqrt(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Sqrt(d, a.cast())))
- }
- ast::Instruction::Rsqrt(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast())))
- }
- ast::Instruction::Neg(d, a) => {
- result.push(Statement::Instruction(ast::Instruction::Neg(d, a.cast())))
- }
- ast::Instruction::Sin { flush_to_zero, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Sin {
- flush_to_zero,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Cos { flush_to_zero, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Cos {
- flush_to_zero,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Lg2 { flush_to_zero, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Lg2 {
- flush_to_zero,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Ex2 { flush_to_zero, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Ex2 {
- flush_to_zero,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Clz { typ, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Clz {
- typ,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Brev { typ, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Brev {
- typ,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Popc { typ, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Popc {
- typ,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Xor { typ, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Xor {
- typ,
- arg: arg.cast(),
- }))
- }
- ast::Instruction::Bfe { typ, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Bfe {
- typ,
- arg: arg.cast(),
- }))
+ let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
+ let instruction = Statement::Instruction(
+ ast::Instruction::Mov(d, ast::Arg2Mov { dst, src }).map(&mut visitor)?,
+ );
+ visitor.func.push(instruction);
+ visitor.func.extend(visitor.post_stmts);
}
- ast::Instruction::Rem { typ, arg } => {
- result.push(Statement::Instruction(ast::Instruction::Rem {
- typ,
- arg: arg.cast(),
- }))
+ inst => {
+ let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
+ let instruction = Statement::Instruction(inst.map(&mut visitor)?);
+ visitor.func.push(instruction);
+ visitor.func.extend(visitor.post_stmts);
}
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
Statement::Conditional(c) => result.push(Statement::Conditional(c)),
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
}
}
Ok(result)
}
+struct VectorRepackVisitor<'a, 'b> {
+ func: &'b mut Vec<TypedStatement>,
+ id_def: &'b mut NumericIdResolver<'a>,
+ post_stmts: Option<TypedStatement>,
+}
+
+impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
+ fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
+ VectorRepackVisitor {
+ func,
+ id_def,
+ post_stmts: None,
+ }
+ }
+
+ fn convert_vector(
+ &mut self,
+ is_dst: bool,
+ vector_sema: ArgumentSemantics,
+ typ: &ast::Type,
+ idx: Vec<spirv::Word>,
+ ) -> Result<spirv::Word, TranslateError> {
+ // mov.u32 foobar, {a,b};
+ let scalar_t = match typ {
+ ast::Type::Vector(scalar_t, _) => *scalar_t,
+ _ => return Err(TranslateError::MismatchedType),
+ };
+ let temp_vec = self.id_def.new_non_variable(Some(typ.clone()));
+ let statement = Statement::RepackVector(RepackVectorDetails {
+ is_extract: is_dst,
+ typ: scalar_t,
+ packed: temp_vec,
+ unpacked: idx,
+ vector_sema,
+ });
+ if is_dst {
+ self.post_stmts = Some(statement);
+ } else {
+ self.func.push(statement);
+ }
+ Ok(temp_vec)
+ }
+}
+
+impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
+ for VectorRepackVisitor<'a, 'b>
+{
+ fn id(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ _: Option<&ast::Type>,
+ ) -> Result<spirv::Word, TranslateError> {
+ Ok(desc.op)
+ }
+
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ typ: &ast::Type,
+ ) -> Result<TypedOperand, TranslateError> {
+ Ok(match desc.op {
+ ast::Operand::Reg(reg) => TypedOperand::Reg(reg),
+ ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
+ ast::Operand::Imm(x) => TypedOperand::Imm(x),
+ ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
+ ast::Operand::VecPack(vec) => {
+ TypedOperand::Reg(self.convert_vector(desc.is_dst, desc.sema, typ, vec)?)
+ }
+ })
+ }
+}
+
//TODO: share common code between this and to_ptx_impl_bfe_call
fn to_ptx_impl_atomic_call(
id_defs: &mut NumericIdResolver,
@@ -1872,17 +1794,16 @@ fn normalize_labels( labels_in_use.insert(cond.if_true);
labels_in_use.insert(cond.if_false);
}
- Statement::Composite(_)
- | Statement::Call(_)
- | Statement::Variable(_)
- | Statement::LoadVar(_, _)
- | Statement::StoreVar(_, _)
- | Statement::RetValue(_, _)
- | Statement::Conversion(_)
- | Statement::Constant(_)
- | Statement::Label(_)
- | Statement::Undef(_, _)
- | Statement::PtrAccess { .. } => {}
+ Statement::Call(..)
+ | Statement::Variable(..)
+ | Statement::LoadVar(..)
+ | Statement::StoreVar(..)
+ | Statement::RetValue(..)
+ | Statement::Conversion(..)
+ | Statement::Constant(..)
+ | Statement::Label(..)
+ | Statement::PtrAccess { .. }
+ | Statement::RepackVector(..) => {}
}
}
iter::once(Statement::Label(id_def.new_non_variable(None)))
@@ -1929,7 +1850,7 @@ fn normalize_predicates( }
Statement::Variable(var) => result.push(Statement::Variable(var)),
// Blocks are flattened when resolving ids
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
}
}
Ok(result)
@@ -1956,7 +1877,7 @@ fn insert_mem_ssa_statements<'a, 'b>( array_init: arg.array_init.clone(),
}));
}
- None => return Err(TranslateError::Unreachable),
+ None => return Err(error_unreachable()),
}
}
for spirv_arg in fn_decl.input.iter_mut() {
@@ -1970,13 +1891,14 @@ fn insert_mem_ssa_statements<'a, 'b>( name: spirv_arg.name,
array_init: spirv_arg.array_init.clone(),
}));
- result.push(Statement::StoreVar(
- ast::Arg2St {
+ result.push(Statement::StoreVar(StoreVarDetails {
+ arg: ast::Arg2St {
src1: spirv_arg.name,
src2: new_id,
},
typ,
- ));
+ member_index: None,
+ }));
spirv_arg.name = new_id;
}
None => {}
@@ -1993,13 +1915,14 @@ fn insert_mem_ssa_statements<'a, 'b>( if let &[out_param] = &fn_decl.output.as_slice() {
let (typ, _) = id_def.get_typed(out_param.name)?;
let new_id = id_def.new_non_variable(Some(typ.clone()));
- result.push(Statement::LoadVar(
- ast::Arg2 {
+ result.push(Statement::LoadVar(LoadVarDetails {
+ arg: ast::Arg2 {
dst: new_id,
src: out_param.name,
},
- typ.clone(),
- ));
+ typ: typ.clone(),
+ member_index: None,
+ }));
result.push(Statement::RetValue(d, new_id));
} else {
result.push(Statement::Instruction(ast::Instruction::Ret(d)))
@@ -2010,13 +1933,14 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::Conditional(mut bra) => {
let generated_id =
id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
- result.push(Statement::LoadVar(
- Arg2 {
+ result.push(Statement::LoadVar(LoadVarDetails {
+ arg: Arg2 {
dst: generated_id,
src: bra.predicate,
},
- ast::Type::Scalar(ast::ScalarType::Pred),
- ));
+ typ: ast::Type::Scalar(ast::ScalarType::Pred),
+ member_index: None,
+ }));
bra.predicate = generated_id;
result.push(Statement::Conditional(bra));
}
@@ -2026,8 +1950,11 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::PtrAccess(ptr_access) => {
insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)?
}
+ Statement::RepackVector(repack) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, repack)?
+ }
s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
}
}
Ok(result)
@@ -2059,101 +1986,156 @@ fn type_to_variable_type( scalar_type
.clone()
.try_into()
- .map_err(|_| TranslateError::Unreachable)?,
- (*space)
- .try_into()
- .map_err(|_| TranslateError::Unreachable)?,
+ .map_err(|_| error_unreachable())?,
+ (*space).try_into().map_err(|_| error_unreachable())?,
)))
}
ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
})
}
-trait VisitVariable: Sized {
- fn visit_variable<
- 'a,
- F: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
- >(
+trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
+ fn visit(
self,
- f: &mut F,
- ) -> Result<TypedStatement, TranslateError>;
-}
-trait VisitVariableExpanded {
- fn visit_variable_extended<
- F: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
- >(
- self,
- f: &mut F,
- ) -> Result<ExpandedStatement, TranslateError>;
+ visitor: &mut impl ArgumentMapVisitor<From, To>,
+ ) -> Result<Statement<ast::Instruction<To>, To>, TranslateError>;
}
-struct VisitArgumentDescriptor<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> {
+struct VisitArgumentDescriptor<
+ 'a,
+ Ctor: FnOnce(spirv::Word) -> Statement<ast::Instruction<U>, U>,
+ U: ArgParamsEx,
+> {
desc: ArgumentDescriptor<spirv::Word>,
typ: &'a ast::Type,
stmt_ctor: Ctor,
}
-impl<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> VisitVariableExpanded
- for VisitArgumentDescriptor<'a, Ctor>
+impl<
+ 'a,
+ Ctor: FnOnce(spirv::Word) -> Statement<ast::Instruction<U>, U>,
+ T: ArgParamsEx<Id = spirv::Word>,
+ U: ArgParamsEx<Id = spirv::Word>,
+ > Visitable<T, U> for VisitArgumentDescriptor<'a, Ctor, U>
{
- fn visit_variable_extended<
- F: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
- >(
+ fn visit(
self,
- f: &mut F,
- ) -> Result<ExpandedStatement, TranslateError> {
- f(self.desc, Some(self.typ)).map(self.stmt_ctor)
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok((self.stmt_ctor)(visitor.id(self.desc, Some(self.typ))?))
}
}
-fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
- id_def: &mut NumericIdResolver,
- result: &mut Vec<TypedStatement>,
- stmt: F,
-) -> Result<(), TranslateError> {
- let mut post_statements = Vec::new();
- let new_statement = stmt.visit_variable(
- &mut |desc: ArgumentDescriptor<spirv::Word>, expected_type| {
- if expected_type.is_none() {
- return Ok(desc.op);
- };
- let (var_type, is_variable) = id_def.get_typed(desc.op)?;
- if !is_variable {
- return Ok(desc.op);
- }
- let generated_id = id_def.new_non_variable(Some(var_type.clone()));
- if !desc.is_dst {
- result.push(Statement::LoadVar(
- Arg2 {
- dst: generated_id,
- src: desc.op,
+struct InsertMemSSAVisitor<'a, 'input> {
+ id_def: &'a mut NumericIdResolver<'input>,
+ func: &'a mut Vec<TypedStatement>,
+ post_statements: Vec<TypedStatement>,
+}
+
+impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
+ fn symbol(
+ &mut self,
+ desc: ArgumentDescriptor<(spirv::Word, Option<u8>)>,
+ expected_type: Option<&ast::Type>,
+ ) -> Result<spirv::Word, TranslateError> {
+ let symbol = desc.op.0;
+ if expected_type.is_none() {
+ return Ok(symbol);
+ };
+ let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?;
+ if !is_variable {
+ return Ok(symbol);
+ };
+ let member_index = match desc.op.1 {
+ Some(idx) => {
+ let vector_width = match var_type {
+ ast::Type::Vector(scalar_t, width) => {
+ var_type = ast::Type::Scalar(scalar_t);
+ width
+ }
+ _ => return Err(TranslateError::MismatchedType),
+ };
+ Some((
+ idx,
+ if self.id_def.special_registers.contains_key(&symbol) {
+ Some(vector_width)
+ } else {
+ None
},
- var_type,
- ));
- } else {
- post_statements.push(Statement::StoreVar(
- Arg2St {
- src1: desc.op,
+ ))
+ }
+ None => None,
+ };
+ let generated_id = self.id_def.new_non_variable(Some(var_type.clone()));
+ if !desc.is_dst {
+ self.func.push(Statement::LoadVar(LoadVarDetails {
+ arg: Arg2 {
+ dst: generated_id,
+ src: symbol,
+ },
+ typ: var_type,
+ member_index,
+ }));
+ } else {
+ self.post_statements
+ .push(Statement::StoreVar(StoreVarDetails {
+ arg: Arg2St {
+ src1: symbol,
src2: generated_id,
},
- var_type,
- ));
+ typ: var_type,
+ member_index: member_index.map(|(idx, _)| idx),
+ }));
+ }
+ Ok(generated_id)
+ }
+}
+
+impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams>
+ for InsertMemSSAVisitor<'a, 'input>
+{
+ fn id(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ typ: Option<&ast::Type>,
+ ) -> Result<spirv::Word, TranslateError> {
+ self.symbol(desc.new_op((desc.op, None)), typ)
+ }
+
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<TypedOperand>,
+ typ: &ast::Type,
+ ) -> Result<TypedOperand, TranslateError> {
+ Ok(match desc.op {
+ TypedOperand::Reg(reg) => {
+ TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
}
- Ok(generated_id)
- },
- )?;
- result.push(new_statement);
- result.append(&mut post_statements);
+ TypedOperand::RegOffset(reg, offset) => {
+ TypedOperand::RegOffset(self.symbol(desc.new_op((reg, None)), Some(typ))?, offset)
+ }
+ op @ TypedOperand::Imm(..) => op,
+ TypedOperand::VecMember(symbol, index) => {
+ TypedOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
+ }
+ })
+ }
+}
+
+fn insert_mem_ssa_statement_default<'a, 'input, S: Visitable<TypedArgParams, TypedArgParams>>(
+ id_def: &'a mut NumericIdResolver<'input>,
+ func: &'a mut Vec<TypedStatement>,
+ stmt: S,
+) -> Result<(), TranslateError> {
+ let mut visitor = InsertMemSSAVisitor {
+ id_def,
+ func,
+ post_statements: Vec::new(),
+ };
+ let new_stmt = stmt.visit(&mut visitor)?;
+ visitor.func.push(new_stmt);
+ visitor.func.extend(visitor.post_statements);
Ok(())
}
@@ -2193,15 +2175,19 @@ fn expand_arguments<'a, 'b>( result.push(Statement::PtrAccess(new_inst));
result.extend(post_stmts);
}
+ Statement::RepackVector(repack) => {
+ let mut visitor = FlattenArguments::new(&mut result, id_def);
+ let (new_inst, post_stmts) = (repack.map(&mut visitor)?, visitor.post_stmts);
+ result.push(Statement::RepackVector(new_inst));
+ result.extend(post_stmts);
+ }
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
- Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
- Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
+ Statement::LoadVar(details) => result.push(Statement::LoadVar(details)),
+ Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
- Statement::Composite(_) | Statement::Constant(_) | Statement::Undef(_, _) => {
- return Err(TranslateError::Unreachable)
- }
+ Statement::Constant(_) => return Err(error_unreachable()),
}
}
Ok(result)
@@ -2225,27 +2211,6 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { }
}
- fn insert_composite_read(
- func: &mut Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver<'a>,
- typ: (ast::ScalarType, u8),
- scalar_dst: Option<spirv::Word>,
- scalar_sema_override: Option<ArgumentSemantics>,
- composite_src: (spirv::Word, u8),
- ) -> spirv::Word {
- let new_id =
- scalar_dst.unwrap_or_else(|| id_def.new_non_variable(ast::Type::Scalar(typ.0)));
- func.push(Statement::Composite(CompositeRead {
- typ: typ.0,
- dst: new_id,
- dst_semantics_override: scalar_sema_override,
- src_composite: composite_src.0,
- src_index: composite_src.1 as u32,
- src_len: typ.1 as u32,
- }));
- new_id
- }
-
fn reg(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
@@ -2367,69 +2332,6 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { }));
Ok(id)
}
-
- fn member_src(
- &mut self,
- desc: ArgumentDescriptor<(spirv::Word, u8)>,
- typ: (ast::ScalarType, u8),
- ) -> Result<spirv::Word, TranslateError> {
- if desc.is_dst {
- return Err(TranslateError::Unreachable);
- }
- let new_id = Self::insert_composite_read(
- self.func,
- self.id_def,
- typ,
- None,
- Some(desc.sema),
- desc.op,
- );
- Ok(new_id)
- }
-
- fn vector(
- &mut self,
- desc: ArgumentDescriptor<&Vec<spirv::Word>>,
- typ: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- let (scalar_type, vec_len) = typ.get_vector()?;
- if !desc.is_dst {
- let mut new_id = self.id_def.new_non_variable(typ.clone());
- self.func.push(Statement::Undef(typ.clone(), new_id));
- for (idx, id) in desc.op.iter().enumerate() {
- let newer_id = self.id_def.new_non_variable(typ.clone());
- self.func.push(Statement::Instruction(ast::Instruction::Mov(
- ast::MovDetails {
- typ: ast::Type::Scalar(scalar_type),
- src_is_address: false,
- dst_width: vec_len,
- src_width: 0,
- relaxed_src2_conv: desc.sema == ArgumentSemantics::DefaultRelaxed,
- },
- ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
- (newer_id, idx as u8),
- new_id,
- *id,
- )),
- )));
- new_id = newer_id;
- }
- Ok(new_id)
- } else {
- let new_id = self.id_def.new_non_variable(typ.clone());
- for (idx, id) in desc.op.iter().enumerate() {
- Self::insert_composite_read(
- &mut self.post_stmts,
- self.id_def,
- (scalar_type, vec_len),
- Some(*id),
- Some(desc.sema),
- (new_id, idx as u8),
- );
- }
- Ok(new_id)
- }
- }
}
impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenArguments<'a, 'b> {
@@ -2443,58 +2345,16 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr fn operand(
&mut self,
- desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
match desc.op {
- ast::Operand::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
- ast::Operand::Imm(x) => self.immediate(desc.new_op(x), typ),
- ast::Operand::RegOffset(reg, offset) => {
+ TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
+ TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ),
+ TypedOperand::RegOffset(reg, offset) => {
self.reg_offset(desc.new_op((reg, offset)), typ)
}
- }
- }
-
- fn src_call_operand(
- &mut self,
- desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
- typ: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- match desc.op {
- ast::CallOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)),
- ast::CallOperand::Imm(x) => self.immediate(desc.new_op(x), typ),
- }
- }
-
- fn src_member_operand(
- &mut self,
- desc: ArgumentDescriptor<(spirv::Word, u8)>,
- typ: (ast::ScalarType, u8),
- ) -> Result<spirv::Word, TranslateError> {
- self.member_src(desc, typ)
- }
-
- fn id_or_vector(
- &mut self,
- desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
- typ: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- match desc.op {
- ast::IdOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
- ast::IdOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ),
- }
- }
-
- fn operand_or_vector(
- &mut self,
- desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
- typ: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- match desc.op {
- ast::OperandOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
- ast::OperandOrVector::RegOffset(r, imm) => self.reg_offset(desc.new_op((r, imm)), typ),
- ast::OperandOrVector::Imm(imm) => self.immediate(desc.new_op(imm), typ),
- ast::OperandOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ),
+ TypedOperand::VecMember(..) => Err(error_unreachable()),
}
}
}
@@ -2543,7 +2403,7 @@ fn insert_implicit_conversions( if let ast::Instruction::AtomCas(d, _) = &inst {
state_space = Some(d.space.to_ld_ss());
}
- if let ast::Instruction::Mov(_, ast::Arg2Mov::Normal(_)) = &inst {
+ if let ast::Instruction::Mov(..) = &inst {
default_conversion_fn = should_bitcast_packed;
}
insert_implicit_conversions_impl(
@@ -2554,13 +2414,6 @@ fn insert_implicit_conversions( state_space,
)?;
}
- Statement::Composite(composite) => insert_implicit_conversions_impl(
- &mut result,
- id_def,
- composite,
- should_bitcast_wrapper,
- None,
- )?,
Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
@@ -2593,14 +2446,20 @@ fn insert_implicit_conversions( Some(state_space),
)?;
}
+ Statement::RepackVector(repack) => insert_implicit_conversions_impl(
+ &mut result,
+ id_def,
+ repack,
+ should_bitcast_wrapper,
+ None,
+ )?,
s @ Statement::Conditional(_)
| s @ Statement::Conversion(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_)
- | s @ Statement::LoadVar(_, _)
- | s @ Statement::StoreVar(_, _)
- | s @ Statement::Undef(_, _)
+ | s @ Statement::LoadVar(..)
+ | s @ Statement::StoreVar(..)
| s @ Statement::RetValue(_, _) => result.push(s),
}
}
@@ -2610,7 +2469,7 @@ fn insert_implicit_conversions( fn insert_implicit_conversions_impl(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
- stmt: impl VisitVariableExpanded,
+ stmt: impl Visitable<ExpandedArgParams, ExpandedArgParams>,
default_conversion_fn: for<'a> fn(
&'a ast::Type,
&'a ast::Type,
@@ -2619,62 +2478,64 @@ fn insert_implicit_conversions_impl( state_space: Option<ast::LdStateSpace>,
) -> Result<(), TranslateError> {
let mut post_conv = Vec::new();
- let statement = stmt.visit_variable_extended(&mut |desc, typ| {
- let instr_type = match typ {
- None => return Ok(desc.op),
- Some(t) => t,
- };
- let operand_type = id_def.get_typed(desc.op)?;
- let mut conversion_fn = default_conversion_fn;
- match desc.sema {
- ArgumentSemantics::Default => {}
- ArgumentSemantics::DefaultRelaxed => {
- if desc.is_dst {
- conversion_fn = should_convert_relaxed_dst_wrapper;
- } else {
- conversion_fn = should_convert_relaxed_src_wrapper;
+ let statement = stmt.visit(
+ &mut |desc: ArgumentDescriptor<spirv::Word>, typ: Option<&ast::Type>| {
+ let instr_type = match typ {
+ None => return Ok(desc.op),
+ Some(t) => t,
+ };
+ let operand_type = id_def.get_typed(desc.op)?;
+ let mut conversion_fn = default_conversion_fn;
+ match desc.sema {
+ ArgumentSemantics::Default => {}
+ ArgumentSemantics::DefaultRelaxed => {
+ if desc.is_dst {
+ conversion_fn = should_convert_relaxed_dst_wrapper;
+ } else {
+ conversion_fn = should_convert_relaxed_src_wrapper;
+ }
}
- }
- ArgumentSemantics::PhysicalPointer => {
- conversion_fn = bitcast_physical_pointer;
- }
- ArgumentSemantics::RegisterPointer => {
- conversion_fn = bitcast_register_pointer;
- }
- ArgumentSemantics::Address => {
- conversion_fn = force_bitcast_ptr_to_bit;
- }
- };
- match conversion_fn(&operand_type, instr_type, state_space)? {
- Some(conv_kind) => {
- let conv_output = if desc.is_dst {
- &mut post_conv
- } else {
- &mut *func
- };
- let mut from = instr_type.clone();
- let mut to = operand_type;
- let mut src = id_def.new_non_variable(instr_type.clone());
- let mut dst = desc.op;
- let result = Ok(src);
- if !desc.is_dst {
- mem::swap(&mut src, &mut dst);
- mem::swap(&mut from, &mut to);
+ ArgumentSemantics::PhysicalPointer => {
+ conversion_fn = bitcast_physical_pointer;
}
- conv_output.push(Statement::Conversion(ImplicitConversion {
- src,
- dst,
- from,
- to,
- kind: conv_kind,
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
- }));
- result
+ ArgumentSemantics::RegisterPointer => {
+ conversion_fn = bitcast_register_pointer;
+ }
+ ArgumentSemantics::Address => {
+ conversion_fn = force_bitcast_ptr_to_bit;
+ }
+ };
+ match conversion_fn(&operand_type, instr_type, state_space)? {
+ Some(conv_kind) => {
+ let conv_output = if desc.is_dst {
+ &mut post_conv
+ } else {
+ &mut *func
+ };
+ let mut from = instr_type.clone();
+ let mut to = operand_type;
+ let mut src = id_def.new_non_variable(instr_type.clone());
+ let mut dst = desc.op;
+ let result = Ok(src);
+ if !desc.is_dst {
+ mem::swap(&mut src, &mut dst);
+ mem::swap(&mut from, &mut to);
+ }
+ conv_output.push(Statement::Conversion(ImplicitConversion {
+ src,
+ dst,
+ from,
+ to,
+ kind: conv_kind,
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: ArgumentSemantics::Default,
+ }));
+ result
+ }
+ None => Ok(desc.op),
}
- None => Ok(desc.op),
- }
- })?;
+ },
+ )?;
func.push(statement);
func.append(&mut post_conv);
Ok(())
@@ -2861,38 +2722,11 @@ fn emit_function_body_ops( }
// SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?,
- ast::Instruction::Mov(d, arg) => match arg {
- ast::Arg2Mov::Normal(ast::Arg2MovNormal { dst, src })
- | ast::Arg2Mov::Member(ast::Arg2MovMember::Src(dst, src)) => {
- let result_type = map
- .get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone())));
- builder.copy_object(result_type, Some(*dst), *src)?;
- }
- ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
- dst,
- composite_src,
- scalar_src,
- ))
- | ast::Arg2Mov::Member(ast::Arg2MovMember::Both(
- dst,
- composite_src,
- scalar_src,
- )) => {
- let scalar_type = d.typ.get_scalar()?;
- let result_type = map.get_or_add(
- builder,
- SpirvType::from(ast::Type::Vector(scalar_type, d.dst_width)),
- );
- let result_id = Some(dst.0);
- builder.composite_insert(
- result_type,
- result_id,
- *scalar_src,
- *composite_src,
- [dst.1 as u32],
- )?;
- }
- },
+ ast::Instruction::Mov(d, arg) => {
+ let result_type =
+ map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone())));
+ builder.copy_object(result_type, Some(arg.dst), arg.src)?;
+ }
ast::Instruction::Mul(mul, arg) => match mul {
ast::MulDetails::Signed(ref ctr) => {
emit_mul_sint(builder, map, opencl, ctr, arg)?
@@ -3202,30 +3036,38 @@ fn emit_function_body_ops( builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
}
},
- Statement::LoadVar(arg, typ) => {
- let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
- builder.load(type_id, Some(arg.dst), arg.src, None, [])?;
+ Statement::LoadVar(details) => {
+ emit_load_var(builder, map, details)?;
}
- Statement::StoreVar(arg, _) => {
- builder.store(arg.src1, arg.src2, None, [])?;
+ Statement::StoreVar(details) => {
+ let dst_ptr = match details.member_index {
+ Some(index) => {
+ let result_ptr_type = map.get_or_add(
+ builder,
+ SpirvType::new_pointer(
+ details.typ.clone(),
+ spirv::StorageClass::Function,
+ ),
+ );
+ let index_spirv = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(index as u32),
+ )?;
+ builder.in_bounds_access_chain(
+ result_ptr_type,
+ None,
+ details.arg.src1,
+ &[index_spirv],
+ )?
+ }
+ None => details.arg.src1,
+ };
+ builder.store(dst_ptr, details.arg.src2, None, [])?;
}
Statement::RetValue(_, id) => {
builder.ret_value(*id)?;
}
- Statement::Composite(c) => {
- let result_type = map.get_or_add_scalar(builder, c.typ.into());
- let result_id = Some(c.dst);
- builder.composite_extract(
- result_type,
- result_id,
- c.src_composite,
- [c.src_index],
- )?;
- }
- Statement::Undef(t, id) => {
- let result_type = map.get_or_add(builder, SpirvType::from(t.clone()));
- builder.undef(result_type, Some(*id));
- }
Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
@@ -3254,6 +3096,38 @@ fn emit_function_body_ops( )?;
builder.bitcast(result_type, Some(*dst), temp)?;
}
+ Statement::RepackVector(repack) => {
+ if repack.is_extract {
+ let scalar_type = map.get_or_add_scalar(builder, repack.typ);
+ for (index, dst_id) in repack.unpacked.iter().enumerate() {
+ builder.composite_extract(
+ scalar_type,
+ Some(*dst_id),
+ repack.packed,
+ &[index as u32],
+ )?;
+ }
+ } else {
+ let vector_type = map.get_or_add(
+ builder,
+ SpirvType::Vector(
+ SpirvScalarKey::from(repack.typ),
+ repack.unpacked.len() as u8,
+ ),
+ );
+ let mut temp_vec = builder.undef(vector_type, None);
+ for (index, src_id) in repack.unpacked.iter().enumerate() {
+ temp_vec = builder.composite_insert(
+ vector_type,
+ None,
+ *src_id,
+ temp_vec,
+ &[index as u32],
+ )?;
+ }
+ builder.copy_object(vector_type, Some(repack.packed), temp_vec)?;
+ }
+ }
}
}
Ok(())
@@ -3271,7 +3145,7 @@ fn insert_shift_hack( 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16),
8 => map.get_or_add_scalar(builder, ast::ScalarType::B64),
4 => return Ok(offset_var),
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
};
Ok(builder.u_convert(result_type, None, offset_var)?)
}
@@ -3351,7 +3225,7 @@ fn emit_atom( let spirv_op = match op {
ast::AtomUIntOp::Add => dr::Builder::atomic_i_add,
ast::AtomUIntOp::Inc | ast::AtomUIntOp::Dec => {
- return Err(TranslateError::Unreachable);
+ return Err(error_unreachable());
}
ast::AtomUIntOp::Min => dr::Builder::atomic_u_min,
ast::AtomUIntOp::Max => dr::Builder::atomic_u_max,
@@ -4165,6 +4039,58 @@ fn emit_implicit_conversion( Ok(())
}
+fn emit_load_var(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ details: &LoadVarDetails,
+) -> Result<(), TranslateError> {
+ let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone()));
+ match details.member_index {
+ Some((index, Some(width))) => {
+ let vector_type = match details.typ {
+ ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
+ _ => return Err(TranslateError::MismatchedType),
+ };
+ let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type));
+ let vector_temp = builder.load(vector_type_spirv, None, details.arg.src, None, [])?;
+ builder.composite_extract(
+ result_type,
+ Some(details.arg.dst),
+ vector_temp,
+ &[index as u32],
+ )?;
+ }
+ Some((index, None)) => {
+ let result_ptr_type = map.get_or_add(
+ builder,
+ SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function),
+ );
+ let index_spirv = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(index as u32),
+ )?;
+ let src = builder.in_bounds_access_chain(
+ result_ptr_type,
+ None,
+ details.arg.src,
+ &[index_spirv],
+ )?;
+ builder.load(result_type, Some(details.arg.dst), src, None, [])?;
+ }
+ None => {
+ builder.load(
+ result_type,
+ Some(details.arg.dst),
+ details.arg.src,
+ None,
+ [],
+ )?;
+ }
+ };
+ Ok(())
+}
+
fn normalize_identifiers<'a, 'b>(
id_defs: &mut FnStringIdResolver<'a, 'b>,
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
@@ -4290,9 +4216,11 @@ fn convert_to_stateful_memory_access<'a>( },
arg,
)) => {
- if let Some(src) = arg.src.underlying() {
- if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, arg.dst) {
- stateful_markers.push((arg.dst, *src));
+ if let (TypedOperand::Reg(dst), Some(src)) =
+ (arg.dst, arg.src.upcast().underlying())
+ {
+ if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, dst) {
+ stateful_markers.push((dst, *src));
}
}
}
@@ -4320,7 +4248,9 @@ fn convert_to_stateful_memory_access<'a>( },
arg,
)) => {
- if let (ast::IdOrVector::Reg(dst), Some(src)) = (&arg.dst, arg.src.underlying()) {
+ if let (TypedOperand::Reg(dst), Some(src)) =
+ (&arg.dst, arg.src.upcast().underlying())
+ {
if func_args_64bit.contains(src) {
multi_hash_map_append(&mut stateful_init_reg, *dst, *src);
}
@@ -4369,13 +4299,17 @@ fn convert_to_stateful_memory_access<'a>( }),
arg,
)) => {
- if let Some(src1) = arg.src1.underlying() {
+ if let (TypedOperand::Reg(dst), Some(src1)) =
+ (arg.dst, arg.src1.upcast().underlying())
+ {
if regs_ptr_current.contains(src1) && !regs_ptr_seen.contains(src1) {
- regs_ptr_new.insert(arg.dst);
+ regs_ptr_new.insert(dst);
}
- } else if let Some(src2) = arg.src2.underlying() {
+ } else if let (TypedOperand::Reg(dst), Some(src2)) =
+ (arg.dst, arg.src2.upcast().underlying())
+ {
if regs_ptr_current.contains(src2) && !regs_ptr_seen.contains(src2) {
- regs_ptr_new.insert(arg.dst);
+ regs_ptr_new.insert(dst);
}
}
}
@@ -4426,19 +4360,20 @@ fn convert_to_stateful_memory_access<'a>( }),
arg,
)) if is_add_ptr_direct(&remapped_ids, &arg) => {
- let (ptr, offset) = match arg.src1.underlying() {
+ let (ptr, offset) = match arg.src1.upcast().underlying() {
Some(src1) if remapped_ids.contains_key(src1) => {
(remapped_ids.get(src1).unwrap(), arg.src2)
}
Some(src2) if remapped_ids.contains_key(src2) => {
(remapped_ids.get(src2).unwrap(), arg.src1)
}
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
};
+ let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
state_space: ast::LdStateSpace::Global,
- dst: *remapped_ids.get(&arg.dst).unwrap(),
+ dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
offset_src: offset,
}))
@@ -4454,14 +4389,14 @@ fn convert_to_stateful_memory_access<'a>( }),
arg,
)) if is_add_ptr_direct(&remapped_ids, &arg) => {
- let (ptr, offset) = match arg.src1.underlying() {
+ let (ptr, offset) = match arg.src1.upcast().underlying() {
Some(src1) if remapped_ids.contains_key(src1) => {
(remapped_ids.get(src1).unwrap(), arg.src2)
}
Some(src2) if remapped_ids.contains_key(src2) => {
(remapped_ids.get(src2).unwrap(), arg.src1)
}
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
};
let offset_neg =
id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64)));
@@ -4472,21 +4407,23 @@ fn convert_to_stateful_memory_access<'a>( },
ast::Arg2 {
src: offset,
- dst: offset_neg,
+ dst: TypedOperand::Reg(offset_neg),
},
)));
+ let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
state_space: ast::LdStateSpace::Global,
- dst: *remapped_ids.get(&arg.dst).unwrap(),
+ dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
- offset_src: ast::Operand::Reg(offset_neg),
+ offset_src: TypedOperand::Reg(offset_neg),
}))
}
Statement::Instruction(inst) => {
let mut post_statements = Vec::new();
- let new_statement = inst.visit_variable(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>, expected_type| {
+ let new_statement = inst.visit(
+ &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
+ expected_type: Option<&ast::Type>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
@@ -4499,14 +4436,13 @@ fn convert_to_stateful_memory_access<'a>( },
)?;
result.push(new_statement);
- for s in post_statements {
- result.push(s);
- }
+ result.extend(post_statements);
}
Statement::Call(call) => {
let mut post_statements = Vec::new();
- let new_statement = call.visit_variable(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>, expected_type| {
+ let new_statement = call.visit(
+ &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
+ expected_type: Option<&ast::Type>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
@@ -4519,11 +4455,28 @@ fn convert_to_stateful_memory_access<'a>( },
)?;
result.push(new_statement);
- for s in post_statements {
- result.push(s);
- }
+ result.extend(post_statements);
+ }
+ Statement::RepackVector(pack) => {
+ let mut post_statements = Vec::new();
+ let new_statement = pack.visit(
+ &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
+ expected_type: Option<&ast::Type>| {
+ convert_to_stateful_memory_access_postprocess(
+ id_defs,
+ &remapped_ids,
+ &func_args_ptr,
+ &mut result,
+ &mut post_statements,
+ arg_desc,
+ expected_type,
+ )
+ },
+ )?;
+ result.push(new_statement);
+ result.extend(post_statements);
}
- _ => return Err(TranslateError::Unreachable),
+ _ => return Err(error_unreachable()),
}
}
for arg in func_args.input.iter_mut() {
@@ -4588,7 +4541,7 @@ fn convert_to_stateful_memory_access_postprocess( None => match func_args_ptr.get(&arg_desc.op) {
Some(new_id) => {
if arg_desc.is_dst {
- return Err(TranslateError::Unreachable);
+ return Err(error_unreachable());
}
// We skip conversion here to trigger PtrAcces in a later pass
let old_type = match expected_type {
@@ -4617,13 +4570,20 @@ fn convert_to_stateful_memory_access_postprocess( }
fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {
- if !remapped_ids.contains_key(&arg.dst) {
- return false;
- }
- match arg.src1.underlying() {
- Some(src1) if remapped_ids.contains_key(src1) => true,
- Some(src2) if remapped_ids.contains_key(src2) => true,
- _ => false,
+ match arg.dst {
+ TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
+ return false
+ }
+ TypedOperand::Reg(dst) => {
+ if !remapped_ids.contains_key(&dst) {
+ return false;
+ }
+ match arg.src1.upcast().underlying() {
+ Some(src1) if remapped_ids.contains_key(src1) => true,
+ Some(src2) if remapped_ids.contains_key(src2) => true,
+ _ => false,
+ }
+ }
}
}
@@ -4962,14 +4922,13 @@ enum Statement<I, P: ast::ArgParams> { // SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
Call(ResolvedCall<P>),
- LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
- StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
- Composite(CompositeRead),
+ LoadVar(LoadVarDetails),
+ StoreVar(StoreVarDetails),
Conversion(ImplicitConversion),
Constant(ConstantDefinition),
RetValue(ast::RetData, spirv::Word),
- Undef(ast::Type, spirv::Word),
PtrAccess(PtrAccess<P>),
+ RepackVector(RepackVectorDetails),
}
impl ExpandedStatement {
@@ -4981,19 +4940,19 @@ impl ExpandedStatement { Statement::Variable(var)
}
Statement::Instruction(inst) => inst
- .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| {
+ .visit(&mut |arg: ArgumentDescriptor<_>, _: Option<&ast::Type>| {
Ok(f(arg.op, arg.is_dst))
})
.unwrap(),
- Statement::LoadVar(mut arg, typ) => {
- arg.dst = f(arg.dst, true);
- arg.src = f(arg.src, false);
- Statement::LoadVar(arg, typ)
+ Statement::LoadVar(mut details) => {
+ details.arg.dst = f(details.arg.dst, true);
+ details.arg.src = f(details.arg.src, false);
+ Statement::LoadVar(details)
}
- Statement::StoreVar(mut arg, typ) => {
- arg.src1 = f(arg.src1, false);
- arg.src2 = f(arg.src2, false);
- Statement::StoreVar(arg, typ)
+ Statement::StoreVar(mut details) => {
+ details.arg.src1 = f(details.arg.src1, false);
+ details.arg.src2 = f(details.arg.src2, false);
+ Statement::StoreVar(details)
}
Statement::Call(mut call) => {
for (id, typ) in call.ret_params.iter_mut() {
@@ -5010,11 +4969,6 @@ impl ExpandedStatement { }
Statement::Call(call)
}
- Statement::Composite(mut composite) => {
- composite.dst = f(composite.dst, true);
- composite.src_composite = f(composite.src_composite, false);
- Statement::Composite(composite)
- }
Statement::Conditional(mut conditional) => {
conditional.predicate = f(conditional.predicate, false);
conditional.if_true = f(conditional.if_true, false);
@@ -5034,10 +4988,6 @@ impl ExpandedStatement { let id = f(id, false);
Statement::RetValue(data, id)
}
- Statement::Undef(typ, id) => {
- let id = f(id, true);
- Statement::Undef(typ, id)
- }
Statement::PtrAccess(PtrAccess {
underlying_type,
state_space,
@@ -5056,19 +5006,100 @@ impl ExpandedStatement { offset_src: constant_src,
})
}
+ Statement::RepackVector(_) => todo!(),
}
}
}
+struct LoadVarDetails {
+ arg: ast::Arg2<ExpandedArgParams>,
+ typ: ast::Type,
+ // (index, vector_width)
+ // HACK ALERT
+ // For some reason IGC explodes when you try to load from builtin vectors
+ // using OpInBoundsAccessChain, the one true way to do it is to
+ // OpLoad+OpCompositeExtract
+ member_index: Option<(u8, Option<u8>)>,
+}
+
+struct StoreVarDetails {
+ arg: ast::Arg2St<ExpandedArgParams>,
+ typ: ast::Type,
+ member_index: Option<u8>,
+}
+
+struct RepackVectorDetails {
+ is_extract: bool,
+ typ: ast::ScalarType,
+ packed: spirv::Word,
+ unpacked: Vec<spirv::Word>,
+ vector_sema: ArgumentSemantics,
+}
+
+impl RepackVectorDetails {
+ fn map<
+ From: ArgParamsEx<Id = spirv::Word>,
+ To: ArgParamsEx<Id = spirv::Word>,
+ V: ArgumentMapVisitor<From, To>,
+ >(
+ self,
+ visitor: &mut V,
+ ) -> Result<RepackVectorDetails, TranslateError> {
+ let scalar = visitor.id(
+ ArgumentDescriptor {
+ op: self.packed,
+ is_dst: !self.is_extract,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)),
+ )?;
+ let scalar_type = self.typ;
+ let is_extract = self.is_extract;
+ let vector_sema = self.vector_sema;
+ let vector = self
+ .unpacked
+ .into_iter()
+ .map(|id| {
+ visitor.id(
+ ArgumentDescriptor {
+ op: id,
+ is_dst: is_extract,
+ sema: vector_sema,
+ },
+ Some(&ast::Type::Scalar(scalar_type)),
+ )
+ })
+ .collect::<Result<_, _>>()?;
+ Ok(RepackVectorDetails {
+ is_extract,
+ typ: self.typ,
+ packed: scalar,
+ unpacked: vector,
+ vector_sema,
+ })
+ }
+}
+
+impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitable<T, U>
+ for RepackVectorDetails
+{
+ fn visit(
+ self,
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::RepackVector(self.map::<_, _, _>(visitor)?))
+ }
+}
+
struct ResolvedCall<P: ast::ArgParams> {
pub uniform: bool,
- pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>,
- pub func: spirv::Word,
- pub param_list: Vec<(P::CallOperand, ast::FnArgumentType)>,
+ pub ret_params: Vec<(P::Id, ast::FnArgumentType)>,
+ pub func: P::Id,
+ pub param_list: Vec<(P::Operand, ast::FnArgumentType)>,
}
impl<T: ast::ArgParams> ResolvedCall<T> {
- fn cast<U: ast::ArgParams<CallOperand = T::CallOperand>>(self) -> ResolvedCall<U> {
+ fn cast<U: ast::ArgParams<Id = T::Id, Operand = T::Operand>>(self) -> ResolvedCall<U> {
ResolvedCall {
uniform: self.uniform,
ret_params: self.ret_params,
@@ -5110,7 +5141,7 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> { .param_list
.into_iter()
.map::<Result<_, TranslateError>, _>(|(id, typ)| {
- let new_id = visitor.src_call_operand(
+ let new_id = visitor.operand(
ArgumentDescriptor {
op: id,
is_dst: false,
@@ -5130,32 +5161,14 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> { }
}
-impl VisitVariable for ResolvedCall<TypedArgParams> {
- fn visit_variable<
- 'a,
- F: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
- >(
- self,
- f: &mut F,
- ) -> Result<TypedStatement, TranslateError> {
- Ok(Statement::Call(self.map(f)?))
- }
-}
-
-impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
- fn visit_variable_extended<
- F: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
- >(
+impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitable<T, U>
+ for ResolvedCall<T>
+{
+ fn visit(
self,
- f: &mut F,
- ) -> Result<ExpandedStatement, TranslateError> {
- Ok(Statement::Call(self.map(f)?))
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::Call(self.map(visitor)?))
}
}
@@ -5208,18 +5221,14 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> { }
}
-impl VisitVariable for PtrAccess<TypedArgParams> {
- fn visit_variable<
- 'a,
- F: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
- >(
+impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitable<T, U>
+ for PtrAccess<T>
+{
+ fn visit(
self,
- f: &mut F,
- ) -> Result<TypedStatement, TranslateError> {
- Ok(Statement::PtrAccess(self.map(f)?))
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::PtrAccess(self.map(visitor)?))
}
}
@@ -5244,10 +5253,6 @@ enum NormalizedArgParams {} impl ast::ArgParams for NormalizedArgParams {
type Id = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
- type CallOperand = ast::CallOperand<spirv::Word>;
- type IdOrVector = ast::IdOrVector<spirv::Word>;
- type OperandOrVector = ast::OperandOrVector<spirv::Word>;
- type SrcMemberOperand = (spirv::Word, u8);
}
impl ArgParamsEx for NormalizedArgParams {
@@ -5273,11 +5278,7 @@ enum TypedArgParams {} impl ast::ArgParams for TypedArgParams {
type Id = spirv::Word;
- type Operand = ast::Operand<spirv::Word>;
- type CallOperand = ast::CallOperand<spirv::Word>;
- type IdOrVector = ast::IdOrVector<spirv::Word>;
- type OperandOrVector = ast::OperandOrVector<spirv::Word>;
- type SrcMemberOperand = (spirv::Word, u8);
+ type Operand = TypedOperand;
}
impl ArgParamsEx for TypedArgParams {
@@ -5289,6 +5290,25 @@ impl ArgParamsEx for TypedArgParams { }
}
+#[derive(Copy, Clone)]
+enum TypedOperand {
+ Reg(spirv::Word),
+ RegOffset(spirv::Word, i32),
+ Imm(ast::ImmediateValue),
+ VecMember(spirv::Word, u8),
+}
+
+impl TypedOperand {
+ fn upcast(self) -> ast::Operand<spirv::Word> {
+ match self {
+ TypedOperand::Reg(reg) => ast::Operand::Reg(reg),
+ TypedOperand::RegOffset(reg, idx) => ast::Operand::RegOffset(reg, idx),
+ TypedOperand::Imm(x) => ast::Operand::Imm(x),
+ TypedOperand::VecMember(vec, idx) => ast::Operand::VecMember(vec, idx),
+ }
+ }
+}
+
type TypedStatement = Statement<ast::Instruction<TypedArgParams>, TypedArgParams>;
enum ExpandedArgParams {}
@@ -5297,10 +5317,6 @@ type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, Expanded impl ast::ArgParams for ExpandedArgParams {
type Id = spirv::Word;
type Operand = spirv::Word;
- type CallOperand = spirv::Word;
- type IdOrVector = spirv::Word;
- type OperandOrVector = spirv::Word;
- type SrcMemberOperand = spirv::Word;
}
impl ArgParamsEx for ExpandedArgParams {
@@ -5312,29 +5328,6 @@ impl ArgParamsEx for ExpandedArgParams { }
}
-#[derive(Copy, Clone)]
-pub enum StateSpace {
- Reg,
- Const,
- Global,
- Local,
- Shared,
- Param,
-}
-
-impl From<ast::StateSpace> for StateSpace {
- fn from(ss: ast::StateSpace) -> Self {
- match ss {
- ast::StateSpace::Reg => StateSpace::Reg,
- ast::StateSpace::Const => StateSpace::Const,
- ast::StateSpace::Global => StateSpace::Global,
- ast::StateSpace::Local => StateSpace::Local,
- ast::StateSpace::Shared => StateSpace::Shared,
- ast::StateSpace::Param => StateSpace::Param,
- }
- }
-}
-
enum Directive<'input> {
Variable(ast::Variable<ast::VariableType, spirv::Word>),
Method(Function<'input>),
@@ -5359,26 +5352,6 @@ pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> { desc: ArgumentDescriptor<T::Operand>,
typ: &ast::Type,
) -> Result<U::Operand, TranslateError>;
- fn id_or_vector(
- &mut self,
- desc: ArgumentDescriptor<T::IdOrVector>,
- typ: &ast::Type,
- ) -> Result<U::IdOrVector, TranslateError>;
- fn operand_or_vector(
- &mut self,
- desc: ArgumentDescriptor<T::OperandOrVector>,
- typ: &ast::Type,
- ) -> Result<U::OperandOrVector, TranslateError>;
- fn src_call_operand(
- &mut self,
- desc: ArgumentDescriptor<T::CallOperand>,
- typ: &ast::Type,
- ) -> Result<U::CallOperand, TranslateError>;
- fn src_member_operand(
- &mut self,
- desc: ArgumentDescriptor<T::SrcMemberOperand>,
- typ: (ast::ScalarType, u8),
- ) -> Result<U::SrcMemberOperand, TranslateError>;
}
impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
@@ -5399,42 +5372,10 @@ where fn operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- self(desc, Some(t))
- }
-
- fn id_or_vector(
- &mut self,
- desc: ArgumentDescriptor<spirv::Word>,
typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
self(desc, Some(typ))
}
-
- fn operand_or_vector(
- &mut self,
- desc: ArgumentDescriptor<spirv::Word>,
- typ: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- self(desc, Some(typ))
- }
-
- fn src_call_operand(
- &mut self,
- desc: ArgumentDescriptor<spirv::Word>,
- t: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- self(desc, Some(t))
- }
-
- fn src_member_operand(
- &mut self,
- desc: ArgumentDescriptor<spirv::Word>,
- (scalar_type, _): (ast::ScalarType, u8),
- ) -> Result<spirv::Word, TranslateError> {
- self(desc.new_op(desc.op), Some(&ast::Type::Scalar(scalar_type)))
- }
}
impl<'a, T> ArgumentMapVisitor<ast::ParsedArgParams<'a>, NormalizedArgParams> for T
@@ -5452,62 +5393,19 @@ where fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
- _: &ast::Type,
+ typ: &ast::Type,
) -> Result<ast::Operand<spirv::Word>, TranslateError> {
- match desc.op {
- ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(id)?)),
- ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset(self(id)?, imm)),
- ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)),
- }
- }
-
- fn id_or_vector(
- &mut self,
- desc: ArgumentDescriptor<ast::IdOrVector<&'a str>>,
- _: &ast::Type,
- ) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
- match desc.op {
- ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(id)?)),
- ast::IdOrVector::Vec(ids) => Ok(ast::IdOrVector::Vec(
- ids.into_iter().map(self).collect::<Result<_, _>>()?,
- )),
- }
- }
-
- fn operand_or_vector(
- &mut self,
- desc: ArgumentDescriptor<ast::OperandOrVector<&'a str>>,
- _: &ast::Type,
- ) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
- match desc.op {
- ast::OperandOrVector::Reg(id) => Ok(ast::OperandOrVector::Reg(self(id)?)),
- ast::OperandOrVector::RegOffset(id, imm) => {
- Ok(ast::OperandOrVector::RegOffset(self(id)?, imm))
- }
- ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)),
- ast::OperandOrVector::Vec(ids) => Ok(ast::OperandOrVector::Vec(
- ids.into_iter().map(self).collect::<Result<_, _>>()?,
- )),
- }
- }
-
- fn src_call_operand(
- &mut self,
- desc: ArgumentDescriptor<ast::CallOperand<&str>>,
- _: &ast::Type,
- ) -> Result<ast::CallOperand<spirv::Word>, TranslateError> {
- match desc.op {
- ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(id)?)),
- ast::CallOperand::Imm(imm) => Ok(ast::CallOperand::Imm(imm)),
- }
- }
-
- fn src_member_operand(
- &mut self,
- desc: ArgumentDescriptor<(&str, u8)>,
- _: (ast::ScalarType, u8),
- ) -> Result<(spirv::Word, u8), TranslateError> {
- Ok((self(desc.op.0)?, desc.op.1))
+ Ok(match desc.op {
+ ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?),
+ ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id)?, imm),
+ ast::Operand::Imm(imm) => ast::Operand::Imm(imm),
+ ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member),
+ ast::Operand::VecPack(ref ids) => ast::Operand::VecPack(
+ ids.into_iter()
+ .map(|id| self.id(desc.new_op(id), Some(typ)))
+ .collect::<Result<Vec<_>, _>>()?,
+ ),
+ })
}
}
@@ -5559,7 +5457,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> { ast::Instruction::Abs(d, arg.map(visitor, &ast::Type::Scalar(d.typ))?)
}
// Call instruction is converted to a call statement early on
- ast::Instruction::Call(_) => return Err(TranslateError::Unreachable),
+ ast::Instruction::Call(_) => return Err(error_unreachable()),
ast::Instruction::Ld(d, a) => {
let new_args = a.map(visitor, &d)?;
ast::Instruction::Ld(d, new_args)
@@ -5752,18 +5650,12 @@ impl<T: ArgParamsEx> ast::Instruction<T> { }
}
-impl VisitVariable for ast::Instruction<TypedArgParams> {
- fn visit_variable<
- 'a,
- F: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
- >(
+impl<T: ArgParamsEx, U: ArgParamsEx> Visitable<T, U> for ast::Instruction<T> {
+ fn visit(
self,
- f: &mut F,
- ) -> Result<TypedStatement, TranslateError> {
- Ok(Statement::Instruction(self.map(f)?))
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::Instruction(self.map(visitor)?))
}
}
@@ -5802,32 +5694,14 @@ impl ImplicitConversion { }
}
-impl VisitVariable for ImplicitConversion {
- fn visit_variable<
- 'a,
- F: FnMut(
- ArgumentDescriptor<spirv_headers::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv_headers::Word, TranslateError>,
- >(
- self,
- f: &mut F,
- ) -> Result<TypedStatement, TranslateError> {
- self.map(f)
- }
-}
-
-impl VisitVariableExpanded for ImplicitConversion {
- fn visit_variable_extended<
- F: FnMut(
- ArgumentDescriptor<spirv_headers::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv_headers::Word, TranslateError>,
- >(
+impl<From: ArgParamsEx<Id = spirv::Word>, To: ArgParamsEx<Id = spirv::Word>> Visitable<From, To>
+ for ImplicitConversion
+{
+ fn visit(
self,
- f: &mut F,
- ) -> Result<ExpandedStatement, TranslateError> {
- self.map(f)
+ visitor: &mut impl ArgumentMapVisitor<From, To>,
+ ) -> Result<Statement<ast::Instruction<To>, To>, TranslateError> {
+ Ok(self.map(visitor)?)
}
}
@@ -5848,79 +5722,24 @@ where fn operand(
&mut self,
- desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
- t: &ast::Type,
- ) -> Result<ast::Operand<spirv::Word>, TranslateError> {
- match desc.op {
- ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(desc.new_op(id), Some(t))?)),
- ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)),
- ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset(
- self(desc.new_op(id), Some(t))?,
- imm,
- )),
- }
- }
-
- fn src_call_operand(
- &mut self,
- desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
- t: &ast::Type,
- ) -> Result<ast::CallOperand<spirv::Word>, TranslateError> {
- match desc.op {
- ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(desc.new_op(id), Some(t))?)),
- ast::CallOperand::Imm(imm) => Ok(ast::CallOperand::Imm(imm)),
- }
- }
-
- fn id_or_vector(
- &mut self,
- desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
+ desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
- ) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
- match desc.op {
- ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(desc.new_op(id), Some(typ))?)),
- ast::IdOrVector::Vec(ref ids) => Ok(ast::IdOrVector::Vec(
- ids.iter()
- .map(|id| self(desc.new_op(*id), Some(typ)))
- .collect::<Result<_, _>>()?,
- )),
- }
- }
-
- fn operand_or_vector(
- &mut self,
- desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
- typ: &ast::Type,
- ) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
- match desc.op {
- ast::OperandOrVector::Reg(id) => {
- Ok(ast::OperandOrVector::Reg(self(desc.new_op(id), Some(typ))?))
+ ) -> Result<TypedOperand, TranslateError> {
+ Ok(match desc.op {
+ TypedOperand::Reg(id) => TypedOperand::Reg(self(desc.new_op(id), Some(typ))?),
+ TypedOperand::Imm(imm) => TypedOperand::Imm(imm),
+ TypedOperand::RegOffset(id, imm) => {
+ TypedOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm)
+ }
+ TypedOperand::VecMember(reg, index) => {
+ let scalar_type = match typ {
+ ast::Type::Scalar(scalar_t) => *scalar_t,
+ _ => return Err(error_unreachable()),
+ };
+ let vec_type = ast::Type::Vector(scalar_type, index + 1);
+ TypedOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index)
}
- ast::OperandOrVector::RegOffset(id, imm) => Ok(ast::OperandOrVector::RegOffset(
- self(desc.new_op(id), Some(typ))?,
- imm,
- )),
- ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)),
- ast::OperandOrVector::Vec(ref ids) => Ok(ast::OperandOrVector::Vec(
- ids.iter()
- .map(|id| self(desc.new_op(*id), Some(typ)))
- .collect::<Result<_, _>>()?,
- )),
- }
- }
-
- fn src_member_operand(
- &mut self,
- desc: ArgumentDescriptor<(spirv::Word, u8)>,
- (scalar_type, vector_len): (ast::ScalarType, u8),
- ) -> Result<(spirv::Word, u8), TranslateError> {
- Ok((
- self(
- desc.new_op(desc.op.0),
- Some(&ast::Type::Vector(scalar_type.into(), vector_len)),
- )?,
- desc.op.1,
- ))
+ })
}
}
@@ -5942,7 +5761,7 @@ impl ast::Type { kind,
)))
}
- _ => Err(TranslateError::Unreachable),
+ _ => Err(error_unreachable()),
}
}
@@ -6182,67 +6001,9 @@ impl ast::Instruction<ExpandedArgParams> { }
}
-impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
- fn visit_variable_extended<
- F: FnMut(
- ArgumentDescriptor<spirv_headers::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv_headers::Word, TranslateError>,
- >(
- self,
- f: &mut F,
- ) -> Result<ExpandedStatement, TranslateError> {
- Ok(Statement::Instruction(self.map(f)?))
- }
-}
-
type Arg2 = ast::Arg2<ExpandedArgParams>;
type Arg2St = ast::Arg2St<ExpandedArgParams>;
-struct CompositeRead {
- pub typ: ast::ScalarType,
- pub dst: spirv::Word,
- pub dst_semantics_override: Option<ArgumentSemantics>,
- pub src_composite: spirv::Word,
- pub src_index: u32,
- pub src_len: u32,
-}
-
-impl VisitVariableExpanded for CompositeRead {
- fn visit_variable_extended<
- F: FnMut(
- ArgumentDescriptor<spirv_headers::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv_headers::Word, TranslateError>,
- >(
- self,
- f: &mut F,
- ) -> Result<ExpandedStatement, TranslateError> {
- let dst_sema = self
- .dst_semantics_override
- .unwrap_or(ArgumentSemantics::Default);
- Ok(Statement::Composite(CompositeRead {
- dst: f(
- ArgumentDescriptor {
- op: self.dst,
- is_dst: true,
- sema: dst_sema,
- },
- Some(&ast::Type::Scalar(self.typ)),
- )?,
- src_composite: f(
- ArgumentDescriptor {
- op: self.src_composite,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- Some(&ast::Type::Vector(self.typ, self.src_len as u8)),
- )?,
- ..self
- }))
- }
-}
-
struct ConstantDefinition {
pub dst: spirv::Word,
pub typ: ast::ScalarType,
@@ -6330,10 +6091,6 @@ impl From<ast::KernelArgumentType> for ast::Type { }
impl<T: ArgParamsEx> ast::Arg1<T> {
- fn cast<U: ArgParamsEx<Id = T::Id>>(self) -> ast::Arg1<U> {
- ast::Arg1 { src: self.src }
- }
-
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -6352,10 +6109,6 @@ impl<T: ArgParamsEx> ast::Arg1<T> { }
impl<T: ArgParamsEx> ast::Arg1Bar<T> {
- fn cast<U: ArgParamsEx<Operand = T::Operand>>(self) -> ast::Arg1Bar<U> {
- ast::Arg1Bar { src: self.src }
- }
-
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -6373,25 +6126,18 @@ impl<T: ArgParamsEx> ast::Arg1Bar<T> { }
impl<T: ArgParamsEx> ast::Arg2<T> {
- fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg2<U> {
- ast::Arg2 {
- src: self.src,
- dst: self.dst,
- }
- }
-
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: &ast::Type,
) -> Result<ast::Arg2<U>, TranslateError> {
- let new_dst = visitor.id(
+ let new_dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(t),
+ t,
)?;
let new_src = visitor.operand(
ArgumentDescriptor {
@@ -6413,13 +6159,13 @@ impl<T: ArgParamsEx> ast::Arg2<T> { dst_t: &ast::Type,
src_t: &ast::Type,
) -> Result<ast::Arg2<U>, TranslateError> {
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(dst_t),
+ dst_t,
)?;
let src = visitor.operand(
ArgumentDescriptor {
@@ -6434,21 +6180,12 @@ impl<T: ArgParamsEx> ast::Arg2<T> { }
impl<T: ArgParamsEx> ast::Arg2Ld<T> {
- fn cast<U: ArgParamsEx<Operand = T::Operand, IdOrVector = T::IdOrVector>>(
- self,
- ) -> ast::Arg2Ld<U> {
- ast::Arg2Ld {
- dst: self.dst,
- src: self.src,
- }
- }
-
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
details: &ast::LdDetails,
) -> Result<ast::Arg2Ld<U>, TranslateError> {
- let dst = visitor.id_or_vector(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
@@ -6478,15 +6215,6 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> { }
impl<T: ArgParamsEx> ast::Arg2St<T> {
- fn cast<U: ArgParamsEx<Operand = T::Operand, OperandOrVector = T::OperandOrVector>>(
- self,
- ) -> ast::Arg2St<U> {
- ast::Arg2St {
- src1: self.src1,
- src2: self.src2,
- }
- }
-
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -6509,7 +6237,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> { details.state_space.to_ld_ss(),
),
)?;
- let src2 = visitor.operand_or_vector(
+ let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
@@ -6527,29 +6255,7 @@ impl<T: ArgParamsEx> ast::Arg2Mov<T> { visitor: &mut V,
details: &ast::MovDetails,
) -> Result<ast::Arg2Mov<U>, TranslateError> {
- Ok(match self {
- ast::Arg2Mov::Normal(arg) => ast::Arg2Mov::Normal(arg.map(visitor, details)?),
- ast::Arg2Mov::Member(arg) => ast::Arg2Mov::Member(arg.map(visitor, details)?),
- })
- }
-}
-
-impl<P: ArgParamsEx> ast::Arg2MovNormal<P> {
- fn cast<U: ArgParamsEx<IdOrVector = P::IdOrVector, OperandOrVector = P::OperandOrVector>>(
- self,
- ) -> ast::Arg2MovNormal<U> {
- ast::Arg2MovNormal {
- dst: self.dst,
- src: self.src,
- }
- }
-
- fn map<U: ArgParamsEx, V: ArgumentMapVisitor<P, U>>(
- self,
- visitor: &mut V,
- details: &ast::MovDetails,
- ) -> Result<ast::Arg2MovNormal<U>, TranslateError> {
- let dst = visitor.id_or_vector(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
@@ -6557,7 +6263,7 @@ impl<P: ArgParamsEx> ast::Arg2MovNormal<P> { },
&details.typ.clone().into(),
)?;
- let src = visitor.operand_or_vector(
+ let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
@@ -6569,144 +6275,11 @@ impl<P: ArgParamsEx> ast::Arg2MovNormal<P> { },
&details.typ.clone().into(),
)?;
- Ok(ast::Arg2MovNormal { dst, src })
- }
-}
-
-impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
- fn cast<U: ArgParamsEx<Id = T::Id, SrcMemberOperand = T::SrcMemberOperand>>(
- self,
- ) -> ast::Arg2MovMember<U> {
- match self {
- ast::Arg2MovMember::Dst(dst, src1, src2) => ast::Arg2MovMember::Dst(dst, src1, src2),
- ast::Arg2MovMember::Src(dst, src) => ast::Arg2MovMember::Src(dst, src),
- ast::Arg2MovMember::Both(dst, src1, src2) => ast::Arg2MovMember::Both(dst, src1, src2),
- }
- }
-
- fn vector_dst(&self) -> Option<&T::Id> {
- match self {
- ast::Arg2MovMember::Src(_, _) => None,
- ast::Arg2MovMember::Dst((d, _), _, _) | ast::Arg2MovMember::Both((d, _), _, _) => {
- Some(d)
- }
- }
- }
-
- fn vector_src(&self) -> Option<&T::SrcMemberOperand> {
- match self {
- ast::Arg2MovMember::Src(_, d) | ast::Arg2MovMember::Both(_, _, d) => Some(d),
- ast::Arg2MovMember::Dst(_, _, _) => None,
- }
- }
-}
-
-impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
- fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
- self,
- visitor: &mut V,
- details: &ast::MovDetails,
- ) -> Result<ast::Arg2MovMember<U>, TranslateError> {
- match self {
- ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => {
- let scalar_type = details.typ.get_scalar()?;
- let dst = visitor.id(
- ArgumentDescriptor {
- op: dst,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(&ast::Type::Vector(scalar_type, details.dst_width)),
- )?;
- let src1 = visitor.id(
- ArgumentDescriptor {
- op: composite_src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- Some(&ast::Type::Vector(scalar_type, details.dst_width)),
- )?;
- let src2 = visitor.id(
- ArgumentDescriptor {
- op: scalar_src,
- is_dst: false,
- sema: if details.src_is_address {
- ArgumentSemantics::Address
- } else if details.relaxed_src2_conv {
- ArgumentSemantics::DefaultRelaxed
- } else {
- ArgumentSemantics::Default
- },
- },
- Some(&details.typ.clone().into()),
- )?;
- Ok(ast::Arg2MovMember::Dst((dst, len), src1, src2))
- }
- ast::Arg2MovMember::Src(dst, src) => {
- let dst = visitor.id(
- ArgumentDescriptor {
- op: dst,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(&details.typ.clone().into()),
- )?;
- let scalar_typ = details.typ.get_scalar()?;
- let src = visitor.src_member_operand(
- ArgumentDescriptor {
- op: src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- (scalar_typ.into(), details.src_width),
- )?;
- Ok(ast::Arg2MovMember::Src(dst, src))
- }
- ast::Arg2MovMember::Both((dst, len), composite_src, src) => {
- let scalar_type = details.typ.get_scalar()?;
- let dst = visitor.id(
- ArgumentDescriptor {
- op: dst,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(&ast::Type::Vector(scalar_type, details.dst_width)),
- )?;
- let composite_src = visitor.id(
- ArgumentDescriptor {
- op: composite_src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- Some(&ast::Type::Vector(scalar_type, details.dst_width)),
- )?;
- let src = visitor.src_member_operand(
- ArgumentDescriptor {
- op: src,
- is_dst: false,
- sema: if details.relaxed_src2_conv {
- ArgumentSemantics::DefaultRelaxed
- } else {
- ArgumentSemantics::Default
- },
- },
- (scalar_type.into(), details.src_width),
- )?;
- Ok(ast::Arg2MovMember::Both((dst, len), composite_src, src))
- }
- }
+ Ok(ast::Arg2Mov { dst, src })
}
}
impl<T: ArgParamsEx> ast::Arg3<T> {
- fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg3<U> {
- ast::Arg3 {
- dst: self.dst,
- src1: self.src1,
- src2: self.src2,
- }
- }
-
fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -6718,13 +6291,13 @@ impl<T: ArgParamsEx> ast::Arg3<T> { } else {
None
};
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(wide_type.as_ref().unwrap_or(typ)),
+ wide_type.as_ref().unwrap_or(typ),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6750,13 +6323,13 @@ impl<T: ArgParamsEx> ast::Arg3<T> { visitor: &mut V,
t: &ast::Type,
) -> Result<ast::Arg3<U>, TranslateError> {
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(t),
+ t,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6784,13 +6357,13 @@ impl<T: ArgParamsEx> ast::Arg3<T> { state_space: ast::AtomSpace,
) -> Result<ast::Arg3<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(&ast::Type::Scalar(scalar_type)),
+ &ast::Type::Scalar(scalar_type),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6816,15 +6389,6 @@ impl<T: ArgParamsEx> ast::Arg3<T> { }
impl<T: ArgParamsEx> ast::Arg4<T> {
- fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg4<U> {
- ast::Arg4 {
- dst: self.dst,
- src1: self.src1,
- src2: self.src2,
- src3: self.src3,
- }
- }
-
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -6836,13 +6400,13 @@ impl<T: ArgParamsEx> ast::Arg4<T> { } else {
None
};
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(wide_type.as_ref().unwrap_or(t)),
+ wide_type.as_ref().unwrap_or(t),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6881,13 +6445,13 @@ impl<T: ArgParamsEx> ast::Arg4<T> { visitor: &mut V,
t: ast::SelpType,
) -> Result<ast::Arg4<U>, TranslateError> {
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(&ast::Type::Scalar(t.into())),
+ &ast::Type::Scalar(t.into()),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6928,13 +6492,13 @@ impl<T: ArgParamsEx> ast::Arg4<T> { state_space: ast::AtomSpace,
) -> Result<ast::Arg4<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(&ast::Type::Scalar(scalar_type)),
+ &ast::Type::Scalar(scalar_type),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6976,13 +6540,13 @@ impl<T: ArgParamsEx> ast::Arg4<T> { visitor: &mut V,
typ: &ast::Type,
) -> Result<ast::Arg4<U>, TranslateError> {
- let dst = visitor.id(
+ let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(typ),
+ typ,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -7019,15 +6583,6 @@ impl<T: ArgParamsEx> ast::Arg4<T> { }
impl<T: ArgParamsEx> ast::Arg4Setp<T> {
- fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg4Setp<U> {
- ast::Arg4Setp {
- dst1: self.dst1,
- dst2: self.dst2,
- src1: self.src1,
- src2: self.src2,
- }
- }
-
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -7079,22 +6634,12 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> { }
}
-impl<T: ArgParamsEx> ast::Arg5<T> {
- fn cast<U: ArgParamsEx<Id = T::Id, Operand = T::Operand>>(self) -> ast::Arg5<U> {
- ast::Arg5 {
- dst1: self.dst1,
- dst2: self.dst2,
- src1: self.src1,
- src2: self.src2,
- src3: self.src3,
- }
- }
-
+impl<T: ArgParamsEx> ast::Arg5Setp<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: &ast::Type,
- ) -> Result<ast::Arg5<U>, TranslateError> {
+ ) -> Result<ast::Arg5Setp<U>, TranslateError> {
let dst1 = visitor.id(
ArgumentDescriptor {
op: self.dst1,
@@ -7140,7 +6685,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> { },
&ast::Type::Scalar(ast::ScalarType::Pred),
)?;
- Ok(ast::Arg5 {
+ Ok(ast::Arg5Setp {
dst1,
dst2,
src1,
@@ -7150,30 +6695,28 @@ impl<T: ArgParamsEx> ast::Arg5<T> { }
}
-impl ast::Type {
- fn get_vector(&self) -> Result<(ast::ScalarType, u8), TranslateError> {
- match self {
- ast::Type::Vector(t, len) => Ok((*t, *len)),
- _ => Err(TranslateError::MismatchedType),
- }
- }
-
- fn get_scalar(&self) -> Result<ast::ScalarType, TranslateError> {
- match self {
- ast::Type::Scalar(t) => Ok(*t),
- _ => Err(TranslateError::MismatchedType),
- }
- }
-}
-
-impl<T> ast::CallOperand<T> {
+impl<T> ast::Operand<T> {
fn map_variable<U, F: FnMut(T) -> Result<U, TranslateError>>(
self,
f: &mut F,
- ) -> Result<ast::CallOperand<U>, TranslateError> {
+ ) -> Result<ast::Operand<U>, TranslateError> {
+ Ok(match self {
+ ast::Operand::Reg(reg) => ast::Operand::Reg(f(reg)?),
+ ast::Operand::RegOffset(reg, offset) => ast::Operand::RegOffset(f(reg)?, offset),
+ ast::Operand::Imm(x) => ast::Operand::Imm(x),
+ ast::Operand::VecMember(reg, idx) => ast::Operand::VecMember(f(reg)?, idx),
+ ast::Operand::VecPack(vec) => {
+ ast::Operand::VecPack(vec.into_iter().map(f).collect::<Result<_, _>>()?)
+ }
+ })
+ }
+}
+
+impl ast::Operand<spirv::Word> {
+ fn unwrap_reg(&self) -> Result<spirv::Word, TranslateError> {
match self {
- ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(f(id)?)),
- ast::CallOperand::Imm(x) => Ok(ast::CallOperand::Imm(x)),
+ ast::Operand::Reg(reg) => Ok(*reg),
+ _ => Err(error_unreachable()),
}
}
}
@@ -7394,15 +6937,8 @@ impl<T> ast::Operand<T> { match self {
ast::Operand::Reg(r) | ast::Operand::RegOffset(r, _) => Some(r),
ast::Operand::Imm(_) => None,
- }
- }
-}
-
-impl<T> ast::OperandOrVector<T> {
- fn single_underlying(&self) -> Option<&T> {
- match self {
- ast::OperandOrVector::Reg(r) | ast::OperandOrVector::RegOffset(r, _) => Some(r),
- ast::OperandOrVector::Imm(_) | ast::OperandOrVector::Vec(_) => None,
+ ast::Operand::VecMember(reg, _) => Some(reg),
+ ast::Operand::VecPack(..) => None,
}
}
}
@@ -7500,7 +7036,7 @@ fn bitcast_physical_pointer( if let Some(space) = ss {
Ok(Some(ConversionKind::BitToPtr(space)))
} else {
- Err(TranslateError::Unreachable)
+ Err(error_unreachable())
}
}
ast::Type::Scalar(ast::ScalarType::B32)
|