From 770a37945259d020de0f003ada9d590ae2ac5232 Mon Sep 17 00:00:00 2001 From: vosen Date: Wed, 9 Dec 2020 00:20:06 +0100 Subject: Refactor how vectors are handled (#20) The current code has a problem with handling vector members such as "b.x" in "mov.u32 a, b.x". This functionality was tacked on and has several annoying issues: * vector member support is limited to the source operand of movs (so "add.u32 a.x, b.x, c.y" will not work) * the width of "b" in "b.x" is not known, which led to some "interesting" workarounds * passes can convert either all member accesses to other member accesses or all of them to temporaries; there is no way to convert only some member accesses to temporaries (which we need for an important fix) This commit solves all of these issues. --- ptx/src/ast.rs | 81 +- ptx/src/ptx.lalrpop | 121 +- ptx/src/test/spirv_run/vector.spvtxt | 124 +- ptx/src/test/spirv_run/vector_extract.spvtxt | 183 +-- ptx/src/translate.rs | 2000 ++++++++++---------------- 5 files changed, 995 insertions(+), 1514 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 367f060..aba6bda 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -557,7 +557,7 @@ pub enum Instruction { Mul(MulDetails, Arg3

), Add(ArithDetails, Arg3

), Setp(SetpData, Arg4Setp

), - SetpBool(SetpBoolData, Arg5

), + SetpBool(SetpBoolData, Arg5Setp

), Not(BooleanType, Arg2

), Bra(BraData, Arg1

), Cvt(CvtDetails, Arg2

), @@ -614,16 +614,12 @@ pub struct CallInst { pub uniform: bool, pub ret_params: Vec, pub func: P::Id, - pub param_list: Vec, + pub param_list: Vec, } pub trait ArgParams { type Id; type Operand; - type IdOrVector; - type OperandOrVector; - type CallOperand; - type SrcMemberOperand; } pub struct ParsedArgParams<'a> { @@ -633,10 +629,6 @@ pub struct ParsedArgParams<'a> { impl<'a> ArgParams for ParsedArgParams<'a> { type Id = &'a str; type Operand = Operand<&'a str>; - type CallOperand = CallOperand<&'a str>; - type IdOrVector = IdOrVector<&'a str>; - type OperandOrVector = OperandOrVector<&'a str>; - type SrcMemberOperand = (&'a str, u8); } pub struct Arg1 { @@ -648,45 +640,32 @@ pub struct Arg1Bar { } pub struct Arg2 { - pub dst: P::Id, + pub dst: P::Operand, pub src: P::Operand, } pub struct Arg2Ld { - pub dst: P::IdOrVector, + pub dst: P::Operand, pub src: P::Operand, } pub struct Arg2St { pub src1: P::Operand, - pub src2: P::OperandOrVector, -} - -pub enum Arg2Mov { - Normal(Arg2MovNormal

), - Member(Arg2MovMember

), -} - -pub struct Arg2MovNormal { - pub dst: P::IdOrVector, - pub src: P::OperandOrVector, + pub src2: P::Operand, } -// We duplicate dst here because during further compilation -// composite dst and composite src will receive different ids -pub enum Arg2MovMember { - Dst((P::Id, u8), P::Id, P::Id), - Src(P::Id, P::SrcMemberOperand), - Both((P::Id, u8), P::Id, P::SrcMemberOperand), +pub struct Arg2Mov { + pub dst: P::Operand, + pub src: P::Operand, } pub struct Arg3 { - pub dst: P::Id, + pub dst: P::Operand, pub src1: P::Operand, pub src2: P::Operand, } pub struct Arg4 { - pub dst: P::Id, + pub dst: P::Operand, pub src1: P::Operand, pub src2: P::Operand, pub src3: P::Operand, @@ -699,7 +678,7 @@ pub struct Arg4Setp { pub src2: P::Operand, } -pub struct Arg5 { +pub struct Arg5Setp { pub dst1: P::Id, pub dst2: Option, pub src1: P::Operand, @@ -715,39 +694,13 @@ pub enum ImmediateValue { F64(f64), } -#[derive(Copy, Clone)] -pub enum Operand { - Reg(ID), - RegOffset(ID, i32), - Imm(ImmediateValue), -} - -#[derive(Copy, Clone)] -pub enum CallOperand { - Reg(ID), - Imm(ImmediateValue), -} - -pub enum IdOrVector { - Reg(ID), - Vec(Vec), -} - -pub enum OperandOrVector { - Reg(ID), - RegOffset(ID, i32), +#[derive(Clone)] +pub enum Operand { + Reg(Id), + RegOffset(Id, i32), Imm(ImmediateValue), - Vec(Vec), -} - -impl From> for OperandOrVector { - fn from(this: Operand) -> Self { - match this { - Operand::Reg(r) => OperandOrVector::Reg(r), - Operand::RegOffset(r, imm) => OperandOrVector::RegOffset(r, imm), - Operand::Imm(imm) => OperandOrVector::Imm(imm), - } - } + VecMember(Id, u8), + VecPack(Vec), } pub enum VectorPrefix { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index d2c235a..fd2a3f1 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -721,7 +721,7 @@ Instruction: ast::Instruction> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld InstLd: ast::Instruction> = { - "ld" "," => { + 
"ld" "," => { ast::Instruction::Ld( ast::LdDetails { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), @@ -734,16 +734,6 @@ InstLd: ast::Instruction> = { } }; -IdOrVector: ast::IdOrVector<&'input str> = { - => ast::IdOrVector::Reg(dst), - => ast::IdOrVector::Vec(dst) -} - -OperandOrVector: ast::OperandOrVector<&'input str> = { - => ast::OperandOrVector::from(op), - => ast::OperandOrVector::Vec(dst) -} - LdStType: ast::LdStType = { => ast::LdStType::Vector(t, v), => ast::LdStType::Scalar(t), @@ -780,27 +770,17 @@ LdCacheOperator: ast::LdCacheOperator = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov InstMov: ast::Instruction> = { - => ast::Instruction::Mov(m.0, m.1), - => ast::Instruction::Mov(m.0, m.1), -}; - - -MovNormal: (ast::MovDetails, ast::Arg2Mov>) = { - "mov" "," => {( - ast::MovDetails::new(ast::Type::Scalar(t)), - ast::Arg2Mov::Normal(ast::Arg2MovNormal{ dst: ast::IdOrVector::Reg(dst), src: src.into() }) - )}, - "mov" "," => {( - ast::MovDetails::new(ast::Type::Vector(t, pref)), - ast::Arg2Mov::Normal(ast::Arg2MovNormal{ dst: dst, src: src }) - )} -} - -MovVector: (ast::MovDetails, ast::Arg2Mov>) = { - "mov" => {( - ast::MovDetails::new(ast::Type::Scalar(t.into())), - ast::Arg2Mov::Member(a) - )}, + "mov" "," => { + let mov_type = match pref { + Some(vec_width) => ast::Type::Vector(t, vec_width), + None => ast::Type::Scalar(t) + }; + let details = ast::MovDetails::new(mov_type); + ast::Instruction::Mov( + details, + ast::Arg2Mov { dst, src } + ) + } } #[inline] @@ -819,21 +799,6 @@ MovScalarType: ast::ScalarType = { ".pred" => ast::ScalarType::Pred }; -#[inline] -MovVectorType: ast::ScalarType = { - ".b16" => ast::ScalarType::B16, - ".b32" => ast::ScalarType::B32, - ".b64" => ast::ScalarType::B64, - ".u16" => ast::ScalarType::U16, - ".u32" => ast::ScalarType::U32, - ".u64" => ast::ScalarType::U64, - ".s16" => ast::ScalarType::S16, - ".s32" => ast::ScalarType::S32, - ".s64" => 
ast::ScalarType::S64, - ".f32" => ast::ScalarType::F32, - ".f64" => ast::ScalarType::F64, -}; - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul @@ -921,7 +886,7 @@ InstAdd: ast::Instruction> = { // TODO: support f16 setp InstSetp: ast::Instruction> = { "setp" => ast::Instruction::Setp(d, a), - "setp" => ast::Instruction::SetpBool(d, a), + "setp" => ast::Instruction::SetpBool(d, a), }; SetpMode: ast::SetpData = { @@ -1198,7 +1163,7 @@ ShrType: ast::ShrType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st // Warning: NVIDIA documentation is incorrect, you can specify scope only once InstSt: ast::Instruction> = { - "st" "," => { + "st" "," => { ast::Instruction::St( ast::StData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), @@ -1775,9 +1740,9 @@ Operand: ast::Operand<&'input str> = { => ast::Operand::Imm(x) }; -CallOperand: ast::CallOperand<&'input str> = { - => ast::CallOperand::Reg(r), - => ast::CallOperand::Imm(x) +CallOperand: ast::Operand<&'input str> = { + => ast::Operand::Reg(r), + => ast::Operand::Imm(x) }; // TODO: start parsing whole constants sub-language: @@ -1825,13 +1790,7 @@ Arg1Bar: ast::Arg1Bar> = { }; Arg2: ast::Arg2> = { - "," => ast::Arg2{<>} -}; - -Arg2MovMember: ast::Arg2MovMember> = { - "," => ast::Arg2MovMember::Dst(dst, dst.0, src), - "," => ast::Arg2MovMember::Src(dst, src), - "," => ast::Arg2MovMember::Both(dst, dst.0, src), + "," => ast::Arg2{<>} }; MemberOperand: (&'input str, u8) = { @@ -1855,19 +1814,19 @@ VectorExtract: Vec<&'input str> = { }; Arg3: ast::Arg3> = { - "," "," => ast::Arg3{<>} + "," "," => ast::Arg3{<>} }; Arg3Atom: ast::Arg3> = { - "," "[" "]" "," => ast::Arg3{<>} + "," "[" "]" 
"," => ast::Arg3{<>} }; Arg4: ast::Arg4> = { - "," "," "," => ast::Arg4{<>} + "," "," "," => ast::Arg4{<>} }; Arg4Atom: ast::Arg4> = { - "," "[" "]" "," "," => ast::Arg4{<>} + "," "[" "]" "," "," => ast::Arg4{<>} }; Arg4Setp: ast::Arg4Setp> = { @@ -1875,22 +1834,50 @@ Arg4Setp: ast::Arg4Setp> = { }; // TODO: pass src3 negation somewhere -Arg5: ast::Arg5> = { - "," "," "," "!"? => ast::Arg5{<>} +Arg5Setp: ast::Arg5Setp> = { + "," "," "," "!"? => ast::Arg5Setp{<>} }; -ArgCall: (Vec<&'input str>, &'input str, Vec>) = { +ArgCall: (Vec<&'input str>, &'input str, Vec>) = { "(" > ")" "," "," "(" > ")" => { (ret_params, func, param_list) }, "," "(" > ")" => (Vec::new(), func, param_list), - => (Vec::new(), func, Vec::>::new()), + => (Vec::new(), func, Vec::>::new()), }; OptionalDst: &'input str = { "|" => dst2 } +SrcOperand: ast::Operand<&'input str> = { + => ast::Operand::Reg(r), + "+" => ast::Operand::RegOffset(r, offset), + => ast::Operand::Imm(x), + => { + let (reg, idx) = mem_op; + ast::Operand::VecMember(reg, idx) + } +} + +SrcOperandVec: ast::Operand<&'input str> = { + => normal, + => ast::Operand::VecPack(vec), +} + +DstOperand: ast::Operand<&'input str> = { + => ast::Operand::Reg(r), + => { + let (reg, idx) = mem_op; + ast::Operand::VecMember(reg, idx) + } +} + +DstOperandVec: ast::Operand<&'input str> = { + => normal, + => ast::Operand::VecPack(vec), +} + VectorPrefix: u8 = { ".v2" => 2, ".v4" => 4 diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt index 535e480..a77ab7d 100644 --- a/ptx/src/test/spirv_run/vector.spvtxt +++ b/ptx/src/test/spirv_run/vector.spvtxt @@ -7,91 +7,93 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %57 = OpExtInstImport "OpenCL.std" + %51 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %31 "vector" + OpEntryPoint Kernel %25 "vector" %void = OpTypeVoid %uint = OpTypeInt 32 0 %v2uint = OpTypeVector %uint 2 - %61 = OpTypeFunction %v2uint %v2uint + 
%55 = OpTypeFunction %v2uint %v2uint %_ptr_Function_v2uint = OpTypePointer Function %v2uint %_ptr_Function_uint = OpTypePointer Function %uint + %uint_0 = OpConstant %uint 0 + %uint_1 = OpConstant %uint 1 %ulong = OpTypeInt 64 0 - %65 = OpTypeFunction %void %ulong %ulong + %67 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Generic_v2uint = OpTypePointer Generic %v2uint - %1 = OpFunction %v2uint None %61 + %1 = OpFunction %v2uint None %55 %7 = OpFunctionParameter %v2uint - %30 = OpLabel + %24 = OpLabel %2 = OpVariable %_ptr_Function_v2uint Function %3 = OpVariable %_ptr_Function_v2uint Function %4 = OpVariable %_ptr_Function_v2uint Function %5 = OpVariable %_ptr_Function_uint Function %6 = OpVariable %_ptr_Function_uint Function OpStore %3 %7 - %9 = OpLoad %v2uint %3 - %27 = OpCompositeExtract %uint %9 0 - %8 = OpCopyObject %uint %27 + %59 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_0 + %9 = OpLoad %uint %59 + %8 = OpCopyObject %uint %9 OpStore %5 %8 - %11 = OpLoad %v2uint %3 - %28 = OpCompositeExtract %uint %11 1 - %10 = OpCopyObject %uint %28 + %61 = OpInBoundsAccessChain %_ptr_Function_uint %3 %uint_1 + %11 = OpLoad %uint %61 + %10 = OpCopyObject %uint %11 OpStore %6 %10 %13 = OpLoad %uint %5 %14 = OpLoad %uint %6 %12 = OpIAdd %uint %13 %14 OpStore %6 %12 - %16 = OpLoad %v2uint %4 - %17 = OpLoad %uint %6 - %15 = OpCompositeInsert %v2uint %17 %16 0 - OpStore %4 %15 - %19 = OpLoad %v2uint %4 - %20 = OpLoad %uint %6 - %18 = OpCompositeInsert %v2uint %20 %19 1 - OpStore %4 %18 + %16 = OpLoad %uint %6 + %15 = OpCopyObject %uint %16 + %62 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_0 + OpStore %62 %15 + %18 = OpLoad %uint %6 + %17 = OpCopyObject %uint %18 + %63 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1 + OpStore %63 %17 + %64 = OpInBoundsAccessChain %_ptr_Function_uint %4 %uint_1 + %20 = OpLoad %uint %64 + %19 = OpCopyObject %uint %20 + %65 = OpInBoundsAccessChain %_ptr_Function_uint %4 
%uint_0 + OpStore %65 %19 %22 = OpLoad %v2uint %4 - %23 = OpLoad %v2uint %4 - %29 = OpCompositeExtract %uint %23 1 - %21 = OpCompositeInsert %v2uint %29 %22 0 - OpStore %4 %21 - %25 = OpLoad %v2uint %4 - %24 = OpCopyObject %v2uint %25 - OpStore %2 %24 - %26 = OpLoad %v2uint %2 - OpReturnValue %26 + %21 = OpCopyObject %v2uint %22 + OpStore %2 %21 + %23 = OpLoad %v2uint %2 + OpReturnValue %23 OpFunctionEnd - %31 = OpFunction %void None %65 - %40 = OpFunctionParameter %ulong - %41 = OpFunctionParameter %ulong - %55 = OpLabel - %32 = OpVariable %_ptr_Function_ulong Function + %25 = OpFunction %void None %67 + %34 = OpFunctionParameter %ulong + %35 = OpFunctionParameter %ulong + %49 = OpLabel + %26 = OpVariable %_ptr_Function_ulong Function + %27 = OpVariable %_ptr_Function_ulong Function + %28 = OpVariable %_ptr_Function_ulong Function + %29 = OpVariable %_ptr_Function_ulong Function + %30 = OpVariable %_ptr_Function_v2uint Function + %31 = OpVariable %_ptr_Function_uint Function + %32 = OpVariable %_ptr_Function_uint Function %33 = OpVariable %_ptr_Function_ulong Function - %34 = OpVariable %_ptr_Function_ulong Function - %35 = OpVariable %_ptr_Function_ulong Function - %36 = OpVariable %_ptr_Function_v2uint Function - %37 = OpVariable %_ptr_Function_uint Function - %38 = OpVariable %_ptr_Function_uint Function - %39 = OpVariable %_ptr_Function_ulong Function - OpStore %32 %40 - OpStore %33 %41 - %42 = OpLoad %ulong %32 - OpStore %34 %42 - %43 = OpLoad %ulong %33 - OpStore %35 %43 - %45 = OpLoad %ulong %34 - %52 = OpConvertUToPtr %_ptr_Generic_v2uint %45 - %44 = OpLoad %v2uint %52 - OpStore %36 %44 - %47 = OpLoad %v2uint %36 - %46 = OpFunctionCall %v2uint %1 %47 - OpStore %36 %46 - %49 = OpLoad %v2uint %36 - %53 = OpBitcast %ulong %49 - %48 = OpCopyObject %ulong %53 - OpStore %39 %48 - %50 = OpLoad %ulong %35 - %51 = OpLoad %v2uint %36 - %54 = OpConvertUToPtr %_ptr_Generic_v2uint %50 - OpStore %54 %51 + OpStore %26 %34 + OpStore %27 %35 + %36 = OpLoad %ulong %26 + 
OpStore %28 %36 + %37 = OpLoad %ulong %27 + OpStore %29 %37 + %39 = OpLoad %ulong %28 + %46 = OpConvertUToPtr %_ptr_Generic_v2uint %39 + %38 = OpLoad %v2uint %46 + OpStore %30 %38 + %41 = OpLoad %v2uint %30 + %40 = OpFunctionCall %v2uint %1 %41 + OpStore %30 %40 + %43 = OpLoad %v2uint %30 + %47 = OpBitcast %ulong %43 + %42 = OpCopyObject %ulong %47 + OpStore %33 %42 + %44 = OpLoad %ulong %29 + %45 = OpLoad %v2uint %30 + %48 = OpConvertUToPtr %_ptr_Generic_v2uint %44 + OpStore %48 %45 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector_extract.spvtxt b/ptx/src/test/spirv_run/vector_extract.spvtxt index 4943189..2037dec 100644 --- a/ptx/src/test/spirv_run/vector_extract.spvtxt +++ b/ptx/src/test/spirv_run/vector_extract.spvtxt @@ -7,12 +7,12 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %73 = OpExtInstImport "OpenCL.std" + %61 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "vector_extract" %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %76 = OpTypeFunction %void %ulong %ulong + %64 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %ushort = OpTypeInt 16 0 %_ptr_Function_ushort = OpTypePointer Function %ushort @@ -21,10 +21,10 @@ %uchar = OpTypeInt 8 0 %v4uchar = OpTypeVector %uchar 4 %_ptr_CrossWorkgroup_v4uchar = OpTypePointer CrossWorkgroup %v4uchar - %1 = OpFunction %void None %76 - %11 = OpFunctionParameter %ulong - %12 = OpFunctionParameter %ulong - %71 = OpLabel + %1 = OpFunction %void None %64 + %17 = OpFunctionParameter %ulong + %18 = OpFunctionParameter %ulong + %59 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -34,89 +34,92 @@ %8 = OpVariable %_ptr_Function_ushort Function %9 = OpVariable %_ptr_Function_ushort Function %10 = OpVariable %_ptr_Function_v4ushort Function - OpStore %2 %11 - OpStore %3 %12 - %13 = OpLoad %ulong %2 - OpStore %4 %13 - 
%14 = OpLoad %ulong %3 - OpStore %5 %14 - %19 = OpLoad %ulong %4 - %61 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %19 - %43 = OpLoad %v4uchar %61 - %62 = OpCompositeExtract %uchar %43 0 - %85 = OpBitcast %uchar %62 - %15 = OpUConvert %ushort %85 - %63 = OpCompositeExtract %uchar %43 1 - %86 = OpBitcast %uchar %63 - %16 = OpUConvert %ushort %86 - %64 = OpCompositeExtract %uchar %43 2 - %87 = OpBitcast %uchar %64 - %17 = OpUConvert %ushort %87 - %65 = OpCompositeExtract %uchar %43 3 - %88 = OpBitcast %uchar %65 - %18 = OpUConvert %ushort %88 - OpStore %6 %15 - OpStore %7 %16 - OpStore %8 %17 - OpStore %9 %18 - %21 = OpLoad %ushort %7 - %22 = OpLoad %ushort %8 - %23 = OpLoad %ushort %9 - %24 = OpLoad %ushort %6 - %44 = OpUndef %v4ushort - %45 = OpCompositeInsert %v4ushort %21 %44 0 - %46 = OpCompositeInsert %v4ushort %22 %45 1 - %47 = OpCompositeInsert %v4ushort %23 %46 2 - %48 = OpCompositeInsert %v4ushort %24 %47 3 - %20 = OpCopyObject %v4ushort %48 - OpStore %10 %20 - %29 = OpLoad %v4ushort %10 - %49 = OpCopyObject %v4ushort %29 - %25 = OpCompositeExtract %ushort %49 0 - %26 = OpCompositeExtract %ushort %49 1 - %27 = OpCompositeExtract %ushort %49 2 - %28 = OpCompositeExtract %ushort %49 3 - OpStore %8 %25 - OpStore %9 %26 - OpStore %6 %27 - OpStore %7 %28 - %34 = OpLoad %ushort %8 - %35 = OpLoad %ushort %9 - %36 = OpLoad %ushort %6 - %37 = OpLoad %ushort %7 - %51 = OpUndef %v4ushort - %52 = OpCompositeInsert %v4ushort %34 %51 0 - %53 = OpCompositeInsert %v4ushort %35 %52 1 - %54 = OpCompositeInsert %v4ushort %36 %53 2 - %55 = OpCompositeInsert %v4ushort %37 %54 3 - %50 = OpCopyObject %v4ushort %55 - %30 = OpCompositeExtract %ushort %50 0 - %31 = OpCompositeExtract %ushort %50 1 - %32 = OpCompositeExtract %ushort %50 2 - %33 = OpCompositeExtract %ushort %50 3 - OpStore %9 %30 - OpStore %6 %31 - OpStore %7 %32 - OpStore %8 %33 - %38 = OpLoad %ulong %5 - %39 = OpLoad %ushort %6 - %40 = OpLoad %ushort %7 - %41 = OpLoad %ushort %8 - %42 = OpLoad %ushort %9 - %56 = 
OpUndef %v4uchar - %89 = OpBitcast %ushort %39 - %66 = OpUConvert %uchar %89 - %57 = OpCompositeInsert %v4uchar %66 %56 0 - %90 = OpBitcast %ushort %40 - %67 = OpUConvert %uchar %90 - %58 = OpCompositeInsert %v4uchar %67 %57 1 - %91 = OpBitcast %ushort %41 - %68 = OpUConvert %uchar %91 - %59 = OpCompositeInsert %v4uchar %68 %58 2 - %92 = OpBitcast %ushort %42 - %69 = OpUConvert %uchar %92 - %60 = OpCompositeInsert %v4uchar %69 %59 3 - %70 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %38 - OpStore %70 %60 + OpStore %2 %17 + OpStore %3 %18 + %19 = OpLoad %ulong %2 + OpStore %4 %19 + %20 = OpLoad %ulong %3 + OpStore %5 %20 + %21 = OpLoad %ulong %4 + %49 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %21 + %11 = OpLoad %v4uchar %49 + %50 = OpCompositeExtract %uchar %11 0 + %51 = OpCompositeExtract %uchar %11 1 + %52 = OpCompositeExtract %uchar %11 2 + %53 = OpCompositeExtract %uchar %11 3 + %73 = OpBitcast %uchar %50 + %22 = OpUConvert %ushort %73 + %74 = OpBitcast %uchar %51 + %23 = OpUConvert %ushort %74 + %75 = OpBitcast %uchar %52 + %24 = OpUConvert %ushort %75 + %76 = OpBitcast %uchar %53 + %25 = OpUConvert %ushort %76 + OpStore %6 %22 + OpStore %7 %23 + OpStore %8 %24 + OpStore %9 %25 + %26 = OpLoad %ushort %7 + %27 = OpLoad %ushort %8 + %28 = OpLoad %ushort %9 + %29 = OpLoad %ushort %6 + %77 = OpUndef %v4ushort + %78 = OpCompositeInsert %v4ushort %26 %77 0 + %79 = OpCompositeInsert %v4ushort %27 %78 1 + %80 = OpCompositeInsert %v4ushort %28 %79 2 + %81 = OpCompositeInsert %v4ushort %29 %80 3 + %12 = OpCopyObject %v4ushort %81 + %30 = OpCopyObject %v4ushort %12 + OpStore %10 %30 + %31 = OpLoad %v4ushort %10 + %13 = OpCopyObject %v4ushort %31 + %32 = OpCompositeExtract %ushort %13 0 + %33 = OpCompositeExtract %ushort %13 1 + %34 = OpCompositeExtract %ushort %13 2 + %35 = OpCompositeExtract %ushort %13 3 + OpStore %8 %32 + OpStore %9 %33 + OpStore %6 %34 + OpStore %7 %35 + %36 = OpLoad %ushort %8 + %37 = OpLoad %ushort %9 + %38 = OpLoad %ushort %6 + %39 = OpLoad 
%ushort %7 + %82 = OpUndef %v4ushort + %83 = OpCompositeInsert %v4ushort %36 %82 0 + %84 = OpCompositeInsert %v4ushort %37 %83 1 + %85 = OpCompositeInsert %v4ushort %38 %84 2 + %86 = OpCompositeInsert %v4ushort %39 %85 3 + %15 = OpCopyObject %v4ushort %86 + %14 = OpCopyObject %v4ushort %15 + %40 = OpCompositeExtract %ushort %14 0 + %41 = OpCompositeExtract %ushort %14 1 + %42 = OpCompositeExtract %ushort %14 2 + %43 = OpCompositeExtract %ushort %14 3 + OpStore %9 %40 + OpStore %6 %41 + OpStore %7 %42 + OpStore %8 %43 + %44 = OpLoad %ushort %6 + %45 = OpLoad %ushort %7 + %46 = OpLoad %ushort %8 + %47 = OpLoad %ushort %9 + %87 = OpBitcast %ushort %44 + %54 = OpUConvert %uchar %87 + %88 = OpBitcast %ushort %45 + %55 = OpUConvert %uchar %88 + %89 = OpBitcast %ushort %46 + %56 = OpUConvert %uchar %89 + %90 = OpBitcast %ushort %47 + %57 = OpUConvert %uchar %90 + %91 = OpUndef %v4uchar + %92 = OpCompositeInsert %v4uchar %54 %91 0 + %93 = OpCompositeInsert %v4uchar %55 %92 1 + %94 = OpCompositeInsert %v4uchar %56 %93 2 + %95 = OpCompositeInsert %v4uchar %57 %94 3 + %16 = OpCopyObject %v4uchar %95 + %48 = OpLoad %ulong %5 + %58 = OpConvertUToPtr %_ptr_CrossWorkgroup_v4uchar %48 + OpStore %58 %16 OpReturn OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 15211ab..20578eb 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -27,6 +27,16 @@ quick_error! 
{ } } +#[cfg(debug_assertions)] +fn error_unreachable() -> TranslateError { + unreachable!() +} + +#[cfg(not(debug_assertions))] +fn error_unreachable() -> TranslateError { + TranslateError::Unreachable +} + #[derive(PartialEq, Eq, Hash, Clone)] enum SpirvType { Base(SpirvScalarKey), @@ -82,7 +92,7 @@ impl ast::Type { ast::Type::Pointer(ast::PointerType::Scalar(t), space) => { ast::Type::Pointer(ast::PointerType::Pointer(t, space), space) } - ast::Type::Pointer(_, _) => return Err(TranslateError::Unreachable), + ast::Type::Pointer(_, _) => return Err(error_unreachable()), }) } } @@ -364,7 +374,7 @@ impl TypeWordMap { b.constant_composite(result_type, None, &components) } ast::Type::Array(typ, dims) => match dims.as_slice() { - [] => return Err(TranslateError::Unreachable), + [] => return Err(error_unreachable()), [dim] => { let result_type = self .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim])); @@ -791,13 +801,14 @@ fn convert_dynamic_shared_memory_usage<'input>( ast::PointerStateSpace::Shared, )), }); - let shared_var_st = ExpandedStatement::StoreVar( - ast::Arg2St { + let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails { + arg: ast::Arg2St { src1: shared_var_id, src2: shared_id_param, }, - ast::Type::Scalar(ast::ScalarType::B8), - ); + typ: ast::Type::Scalar(ast::ScalarType::B8), + member_index: None, + }); let mut new_statements = vec![shared_var, shared_var_st]; replace_uses_of_shared_memory( &mut new_statements, @@ -963,18 +974,17 @@ fn compute_denorm_information<'input>( denorm_count_map_update(&mut flush_counter, width, flush); } } - Statement::LoadVar(_, _) => {} - Statement::StoreVar(_, _) => {} + Statement::LoadVar(..) => {} + Statement::StoreVar(..) 
=> {} Statement::Call(_) => {} - Statement::Composite(_) => {} Statement::Conditional(_) => {} Statement::Conversion(_) => {} Statement::Constant(_) => {} Statement::RetValue(_, _) => {} - Statement::Undef(_, _) => {} Statement::Label(_) => {} Statement::Variable(_) => {} Statement::PtrAccess { .. } => {} + Statement::RepackVector(_) => {} } } denorm_methods.insert(method_key, flush_counter); @@ -1307,7 +1317,7 @@ fn to_ssa<'input, 'b>( let mut numeric_id_defs = id_defs.finish(); let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = - convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?; + convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; let typed_statements = convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; let ssa_statements = insert_mem_ssa_statements( @@ -1431,7 +1441,7 @@ fn normalize_variable_decls(directives: &mut Vec) { fn convert_to_typed_statements( func: Vec, fn_defs: &GlobalFnDeclResolver, - id_defs: &NumericIdResolver, + id_defs: &mut NumericIdResolver, ) -> Result, TranslateError> { let mut result = Vec::::with_capacity(func.len()); for s in func { @@ -1447,7 +1457,7 @@ fn convert_to_typed_statements( .partition(|(_, arg_type)| arg_type.is_param()); let normalized_input_args = out_params .into_iter() - .map(|(id, typ)| (ast::CallOperand::Reg(id), typ)) + .map(|(id, typ)| (ast::Operand::Reg(id), typ)) .chain(in_args.into_iter()) .collect(); let resolved_call = ResolvedCall { @@ -1456,205 +1466,117 @@ fn convert_to_typed_statements( func: call.func, param_list: normalized_input_args, }; - result.push(Statement::Call(resolved_call)); - } - ast::Instruction::Ld(d, arg) => { - result.push(Statement::Instruction(ast::Instruction::Ld(d, arg.cast()))); - } - ast::Instruction::St(d, arg) => { - result.push(Statement::Instruction(ast::Instruction::St(d, arg.cast()))); - } - ast::Instruction::Mov(mut d, 
args) => match args { - ast::Arg2Mov::Normal(arg) => { - if let Some(src_id) = arg.src.single_underlying() { - let (typ, _) = id_defs.get_typed(*src_id)?; - let take_address = match typ { - ast::Type::Scalar(_) => false, - ast::Type::Vector(_, _) => false, - ast::Type::Array(_, _) => true, - ast::Type::Pointer(_, _) => true, - }; - d.src_is_address = take_address; - } - result.push(Statement::Instruction(ast::Instruction::Mov( - d, - ast::Arg2Mov::Normal(arg.cast()), - ))); - } - ast::Arg2Mov::Member(args) => { - if let Some(dst_typ) = args.vector_dst() { - match id_defs.get_typed(*dst_typ)? { - (ast::Type::Vector(_, len), _) => { - d.dst_width = len; - } - _ => return Err(TranslateError::MismatchedType), - } - }; - if let Some((src_typ, _)) = args.vector_src() { - match id_defs.get_typed(*src_typ)? { - (ast::Type::Vector(_, len), _) => { - d.src_width = len; - } - _ => return Err(TranslateError::MismatchedType), - } + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); + let reresolved_call = resolved_call.visit(&mut visitor)?; + visitor.func.push(reresolved_call); + visitor.func.extend(visitor.post_stmts); + } + ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => { + if let Some(src_id) = src.underlying() { + let (typ, _) = id_defs.get_typed(*src_id)?; + let take_address = match typ { + ast::Type::Scalar(_) => false, + ast::Type::Vector(_, _) => false, + ast::Type::Array(_, _) => true, + ast::Type::Pointer(_, _) => true, }; - result.push(Statement::Instruction(ast::Instruction::Mov( - d, - ast::Arg2Mov::Member(args.cast()), - ))); + d.src_is_address = take_address; } - }, - ast::Instruction::Mul(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Mul(d, a.cast()))) - } - ast::Instruction::Add(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Add(d, a.cast()))) - } - ast::Instruction::Setp(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Setp(d, a.cast()))) - } - ast::Instruction::SetpBool(d, a) 
=> result.push(Statement::Instruction( - ast::Instruction::SetpBool(d, a.cast()), - )), - ast::Instruction::Not(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Not(d, a.cast()))) - } - ast::Instruction::Bra(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Bra(d, a.cast()))) - } - ast::Instruction::Cvt(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Cvt(d, a.cast()))) - } - ast::Instruction::Cvta(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Cvta(d, a.cast()))) - } - ast::Instruction::Shl(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Shl(d, a.cast()))) - } - ast::Instruction::Ret(d) => { - result.push(Statement::Instruction(ast::Instruction::Ret(d))) - } - ast::Instruction::Abs(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Abs(d, a.cast()))) - } - ast::Instruction::Mad(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Mad(d, a.cast()))) - } - ast::Instruction::Shr(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Shr(d, a.cast()))) - } - ast::Instruction::Or(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Or(d, a.cast()))) - } - ast::Instruction::Sub(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Sub(d, a.cast()))) - } - ast::Instruction::Min(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Min(d, a.cast()))) - } - ast::Instruction::Max(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast()))) - } - ast::Instruction::Rcp(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Rcp(d, a.cast()))) - } - ast::Instruction::And(d, a) => { - result.push(Statement::Instruction(ast::Instruction::And(d, a.cast()))) - } - ast::Instruction::Selp(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Selp(d, a.cast()))) - } - ast::Instruction::Bar(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Bar(d, a.cast()))) - } - 
ast::Instruction::Atom(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Atom(d, a.cast()))) - } - ast::Instruction::AtomCas(d, a) => result.push(Statement::Instruction( - ast::Instruction::AtomCas(d, a.cast()), - )), - ast::Instruction::Div(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Div(d, a.cast()))) - } - ast::Instruction::Sqrt(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Sqrt(d, a.cast()))) - } - ast::Instruction::Rsqrt(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast()))) - } - ast::Instruction::Neg(d, a) => { - result.push(Statement::Instruction(ast::Instruction::Neg(d, a.cast()))) - } - ast::Instruction::Sin { flush_to_zero, arg } => { - result.push(Statement::Instruction(ast::Instruction::Sin { - flush_to_zero, - arg: arg.cast(), - })) - } - ast::Instruction::Cos { flush_to_zero, arg } => { - result.push(Statement::Instruction(ast::Instruction::Cos { - flush_to_zero, - arg: arg.cast(), - })) - } - ast::Instruction::Lg2 { flush_to_zero, arg } => { - result.push(Statement::Instruction(ast::Instruction::Lg2 { - flush_to_zero, - arg: arg.cast(), - })) - } - ast::Instruction::Ex2 { flush_to_zero, arg } => { - result.push(Statement::Instruction(ast::Instruction::Ex2 { - flush_to_zero, - arg: arg.cast(), - })) - } - ast::Instruction::Clz { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Clz { - typ, - arg: arg.cast(), - })) - } - ast::Instruction::Brev { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Brev { - typ, - arg: arg.cast(), - })) - } - ast::Instruction::Popc { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Popc { - typ, - arg: arg.cast(), - })) - } - ast::Instruction::Xor { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Xor { - typ, - arg: arg.cast(), - })) - } - ast::Instruction::Bfe { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Bfe { - typ, - 
arg: arg.cast(), - })) + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); + let instruction = Statement::Instruction( + ast::Instruction::Mov(d, ast::Arg2Mov { dst, src }).map(&mut visitor)?, + ); + visitor.func.push(instruction); + visitor.func.extend(visitor.post_stmts); } - ast::Instruction::Rem { typ, arg } => { - result.push(Statement::Instruction(ast::Instruction::Rem { - typ, - arg: arg.cast(), - })) + inst => { + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); + let instruction = Statement::Instruction(inst.map(&mut visitor)?); + visitor.func.push(instruction); + visitor.func.extend(visitor.post_stmts); } }, Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), Statement::Conditional(c) => result.push(Statement::Conditional(c)), - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), } } Ok(result) } +struct VectorRepackVisitor<'a, 'b> { + func: &'b mut Vec, + id_def: &'b mut NumericIdResolver<'a>, + post_stmts: Option, +} + +impl<'a, 'b> VectorRepackVisitor<'a, 'b> { + fn new(func: &'b mut Vec, id_def: &'b mut NumericIdResolver<'a>) -> Self { + VectorRepackVisitor { + func, + id_def, + post_stmts: None, + } + } + + fn convert_vector( + &mut self, + is_dst: bool, + vector_sema: ArgumentSemantics, + typ: &ast::Type, + idx: Vec, + ) -> Result { + // mov.u32 foobar, {a,b}; + let scalar_t = match typ { + ast::Type::Vector(scalar_t, _) => *scalar_t, + _ => return Err(TranslateError::MismatchedType), + }; + let temp_vec = self.id_def.new_non_variable(Some(typ.clone())); + let statement = Statement::RepackVector(RepackVectorDetails { + is_extract: is_dst, + typ: scalar_t, + packed: temp_vec, + unpacked: idx, + vector_sema, + }); + if is_dst { + self.post_stmts = Some(statement); + } else { + self.func.push(statement); + } + Ok(temp_vec) + } +} + +impl<'a, 'b> ArgumentMapVisitor + for VectorRepackVisitor<'a, 'b> +{ + fn id( + &mut 
self, + desc: ArgumentDescriptor, + _: Option<&ast::Type>, + ) -> Result { + Ok(desc.op) + } + + fn operand( + &mut self, + desc: ArgumentDescriptor>, + typ: &ast::Type, + ) -> Result { + Ok(match desc.op { + ast::Operand::Reg(reg) => TypedOperand::Reg(reg), + ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), + ast::Operand::Imm(x) => TypedOperand::Imm(x), + ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), + ast::Operand::VecPack(vec) => { + TypedOperand::Reg(self.convert_vector(desc.is_dst, desc.sema, typ, vec)?) + } + }) + } +} + //TODO: share common code between this and to_ptx_impl_bfe_call fn to_ptx_impl_atomic_call( id_defs: &mut NumericIdResolver, @@ -1872,17 +1794,16 @@ fn normalize_labels( labels_in_use.insert(cond.if_true); labels_in_use.insert(cond.if_false); } - Statement::Composite(_) - | Statement::Call(_) - | Statement::Variable(_) - | Statement::LoadVar(_, _) - | Statement::StoreVar(_, _) - | Statement::RetValue(_, _) - | Statement::Conversion(_) - | Statement::Constant(_) - | Statement::Label(_) - | Statement::Undef(_, _) - | Statement::PtrAccess { .. } => {} + Statement::Call(..) + | Statement::Variable(..) + | Statement::LoadVar(..) + | Statement::StoreVar(..) + | Statement::RetValue(..) + | Statement::Conversion(..) + | Statement::Constant(..) + | Statement::Label(..) + | Statement::PtrAccess { .. } + | Statement::RepackVector(..) 
=> {} } } iter::once(Statement::Label(id_def.new_non_variable(None))) @@ -1929,7 +1850,7 @@ fn normalize_predicates( } Statement::Variable(var) => result.push(Statement::Variable(var)), // Blocks are flattened when resolving ids - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), } } Ok(result) @@ -1956,7 +1877,7 @@ fn insert_mem_ssa_statements<'a, 'b>( array_init: arg.array_init.clone(), })); } - None => return Err(TranslateError::Unreachable), + None => return Err(error_unreachable()), } } for spirv_arg in fn_decl.input.iter_mut() { @@ -1970,13 +1891,14 @@ fn insert_mem_ssa_statements<'a, 'b>( name: spirv_arg.name, array_init: spirv_arg.array_init.clone(), })); - result.push(Statement::StoreVar( - ast::Arg2St { + result.push(Statement::StoreVar(StoreVarDetails { + arg: ast::Arg2St { src1: spirv_arg.name, src2: new_id, }, typ, - )); + member_index: None, + })); spirv_arg.name = new_id; } None => {} @@ -1993,13 +1915,14 @@ fn insert_mem_ssa_statements<'a, 'b>( if let &[out_param] = &fn_decl.output.as_slice() { let (typ, _) = id_def.get_typed(out_param.name)?; let new_id = id_def.new_non_variable(Some(typ.clone())); - result.push(Statement::LoadVar( - ast::Arg2 { + result.push(Statement::LoadVar(LoadVarDetails { + arg: ast::Arg2 { dst: new_id, src: out_param.name, }, - typ.clone(), - )); + typ: typ.clone(), + member_index: None, + })); result.push(Statement::RetValue(d, new_id)); } else { result.push(Statement::Instruction(ast::Instruction::Ret(d))) @@ -2010,13 +1933,14 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::Conditional(mut bra) => { let generated_id = id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred))); - result.push(Statement::LoadVar( - Arg2 { + result.push(Statement::LoadVar(LoadVarDetails { + arg: Arg2 { dst: generated_id, src: bra.predicate, }, - ast::Type::Scalar(ast::ScalarType::Pred), - )); + typ: ast::Type::Scalar(ast::ScalarType::Pred), + member_index: None, + })); bra.predicate = 
generated_id; result.push(Statement::Conditional(bra)); } @@ -2026,8 +1950,11 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::PtrAccess(ptr_access) => { insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)? } + Statement::RepackVector(repack) => { + insert_mem_ssa_statement_default(id_def, &mut result, repack)? + } s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s), - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), } } Ok(result) @@ -2059,101 +1986,156 @@ fn type_to_variable_type( scalar_type .clone() .try_into() - .map_err(|_| TranslateError::Unreachable)?, - (*space) - .try_into() - .map_err(|_| TranslateError::Unreachable)?, + .map_err(|_| error_unreachable())?, + (*space).try_into().map_err(|_| error_unreachable())?, ))) } ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None, - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), }) } -trait VisitVariable: Sized { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result; -} -trait VisitVariableExpanded { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +trait Visitable: Sized { + fn visit( self, - f: &mut F, - ) -> Result; + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, To>, TranslateError>; } -struct VisitArgumentDescriptor<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> { +struct VisitArgumentDescriptor< + 'a, + Ctor: FnOnce(spirv::Word) -> Statement, U>, + U: ArgParamsEx, +> { desc: ArgumentDescriptor, typ: &'a ast::Type, stmt_ctor: Ctor, } -impl<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> VisitVariableExpanded - for VisitArgumentDescriptor<'a, Ctor> +impl< + 'a, + Ctor: FnOnce(spirv::Word) -> Statement, U>, + T: ArgParamsEx, + U: ArgParamsEx, + > Visitable for VisitArgumentDescriptor<'a, Ctor, U> { - fn visit_variable_extended< - F: FnMut( - 
ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( + fn visit( self, - f: &mut F, - ) -> Result { - f(self.desc, Some(self.typ)).map(self.stmt_ctor) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok((self.stmt_ctor)(visitor.id(self.desc, Some(self.typ))?)) } } -fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( - id_def: &mut NumericIdResolver, - result: &mut Vec, - stmt: F, -) -> Result<(), TranslateError> { - let mut post_statements = Vec::new(); - let new_statement = stmt.visit_variable( - &mut |desc: ArgumentDescriptor, expected_type| { - if expected_type.is_none() { - return Ok(desc.op); - }; - let (var_type, is_variable) = id_def.get_typed(desc.op)?; - if !is_variable { - return Ok(desc.op); - } - let generated_id = id_def.new_non_variable(Some(var_type.clone())); - if !desc.is_dst { - result.push(Statement::LoadVar( - Arg2 { - dst: generated_id, - src: desc.op, +struct InsertMemSSAVisitor<'a, 'input> { + id_def: &'a mut NumericIdResolver<'input>, + func: &'a mut Vec, + post_statements: Vec, +} + +impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { + fn symbol( + &mut self, + desc: ArgumentDescriptor<(spirv::Word, Option)>, + expected_type: Option<&ast::Type>, + ) -> Result { + let symbol = desc.op.0; + if expected_type.is_none() { + return Ok(symbol); + }; + let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?; + if !is_variable { + return Ok(symbol); + }; + let member_index = match desc.op.1 { + Some(idx) => { + let vector_width = match var_type { + ast::Type::Vector(scalar_t, width) => { + var_type = ast::Type::Scalar(scalar_t); + width + } + _ => return Err(TranslateError::MismatchedType), + }; + Some(( + idx, + if self.id_def.special_registers.contains_key(&symbol) { + Some(vector_width) + } else { + None }, - var_type, - )); - } else { - post_statements.push(Statement::StoreVar( - Arg2St { - src1: desc.op, + )) + } + None => None, + }; + let generated_id = 
self.id_def.new_non_variable(Some(var_type.clone())); + if !desc.is_dst { + self.func.push(Statement::LoadVar(LoadVarDetails { + arg: Arg2 { + dst: generated_id, + src: symbol, + }, + typ: var_type, + member_index, + })); + } else { + self.post_statements + .push(Statement::StoreVar(StoreVarDetails { + arg: Arg2St { + src1: symbol, src2: generated_id, }, - var_type, - )); + typ: var_type, + member_index: member_index.map(|(idx, _)| idx), + })); + } + Ok(generated_id) + } +} + +impl<'a, 'input> ArgumentMapVisitor + for InsertMemSSAVisitor<'a, 'input> +{ + fn id( + &mut self, + desc: ArgumentDescriptor, + typ: Option<&ast::Type>, + ) -> Result { + self.symbol(desc.new_op((desc.op, None)), typ) + } + + fn operand( + &mut self, + desc: ArgumentDescriptor, + typ: &ast::Type, + ) -> Result { + Ok(match desc.op { + TypedOperand::Reg(reg) => { + TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) } - Ok(generated_id) - }, - )?; - result.push(new_statement); - result.append(&mut post_statements); + TypedOperand::RegOffset(reg, offset) => { + TypedOperand::RegOffset(self.symbol(desc.new_op((reg, None)), Some(typ))?, offset) + } + op @ TypedOperand::Imm(..) => op, + TypedOperand::VecMember(symbol, index) => { + TypedOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) 
+ } + }) + } +} + +fn insert_mem_ssa_statement_default<'a, 'input, S: Visitable>( + id_def: &'a mut NumericIdResolver<'input>, + func: &'a mut Vec, + stmt: S, +) -> Result<(), TranslateError> { + let mut visitor = InsertMemSSAVisitor { + id_def, + func, + post_statements: Vec::new(), + }; + let new_stmt = stmt.visit(&mut visitor)?; + visitor.func.push(new_stmt); + visitor.func.extend(visitor.post_statements); Ok(()) } @@ -2193,15 +2175,19 @@ fn expand_arguments<'a, 'b>( result.push(Statement::PtrAccess(new_inst)); result.extend(post_stmts); } + Statement::RepackVector(repack) => { + let mut visitor = FlattenArguments::new(&mut result, id_def); + let (new_inst, post_stmts) = (repack.map(&mut visitor)?, visitor.post_stmts); + result.push(Statement::RepackVector(new_inst)); + result.extend(post_stmts); + } Statement::Label(id) => result.push(Statement::Label(id)), Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), - Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), - Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)), + Statement::LoadVar(details) => result.push(Statement::LoadVar(details)), + Statement::StoreVar(details) => result.push(Statement::StoreVar(details)), Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), - Statement::Composite(_) | Statement::Constant(_) | Statement::Undef(_, _) => { - return Err(TranslateError::Unreachable) - } + Statement::Constant(_) => return Err(error_unreachable()), } } Ok(result) @@ -2225,27 +2211,6 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { } } - fn insert_composite_read( - func: &mut Vec, - id_def: &mut MutableNumericIdResolver<'a>, - typ: (ast::ScalarType, u8), - scalar_dst: Option, - scalar_sema_override: Option, - composite_src: (spirv::Word, u8), - ) -> spirv::Word { - let new_id = - scalar_dst.unwrap_or_else(|| id_def.new_non_variable(ast::Type::Scalar(typ.0))); 
- func.push(Statement::Composite(CompositeRead { - typ: typ.0, - dst: new_id, - dst_semantics_override: scalar_sema_override, - src_composite: composite_src.0, - src_index: composite_src.1 as u32, - src_len: typ.1 as u32, - })); - new_id - } - fn reg( &mut self, desc: ArgumentDescriptor, @@ -2367,69 +2332,6 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { })); Ok(id) } - - fn member_src( - &mut self, - desc: ArgumentDescriptor<(spirv::Word, u8)>, - typ: (ast::ScalarType, u8), - ) -> Result { - if desc.is_dst { - return Err(TranslateError::Unreachable); - } - let new_id = Self::insert_composite_read( - self.func, - self.id_def, - typ, - None, - Some(desc.sema), - desc.op, - ); - Ok(new_id) - } - - fn vector( - &mut self, - desc: ArgumentDescriptor<&Vec>, - typ: &ast::Type, - ) -> Result { - let (scalar_type, vec_len) = typ.get_vector()?; - if !desc.is_dst { - let mut new_id = self.id_def.new_non_variable(typ.clone()); - self.func.push(Statement::Undef(typ.clone(), new_id)); - for (idx, id) in desc.op.iter().enumerate() { - let newer_id = self.id_def.new_non_variable(typ.clone()); - self.func.push(Statement::Instruction(ast::Instruction::Mov( - ast::MovDetails { - typ: ast::Type::Scalar(scalar_type), - src_is_address: false, - dst_width: vec_len, - src_width: 0, - relaxed_src2_conv: desc.sema == ArgumentSemantics::DefaultRelaxed, - }, - ast::Arg2Mov::Member(ast::Arg2MovMember::Dst( - (newer_id, idx as u8), - new_id, - *id, - )), - ))); - new_id = newer_id; - } - Ok(new_id) - } else { - let new_id = self.id_def.new_non_variable(typ.clone()); - for (idx, id) in desc.op.iter().enumerate() { - Self::insert_composite_read( - &mut self.post_stmts, - self.id_def, - (scalar_type, vec_len), - Some(*id), - Some(desc.sema), - (new_id, idx as u8), - ); - } - Ok(new_id) - } - } } impl<'a, 'b> ArgumentMapVisitor for FlattenArguments<'a, 'b> { @@ -2443,58 +2345,16 @@ impl<'a, 'b> ArgumentMapVisitor for FlattenAr fn operand( &mut self, - desc: ArgumentDescriptor>, + desc: 
ArgumentDescriptor, typ: &ast::Type, ) -> Result { match desc.op { - ast::Operand::Reg(r) => self.reg(desc.new_op(r), Some(typ)), - ast::Operand::Imm(x) => self.immediate(desc.new_op(x), typ), - ast::Operand::RegOffset(reg, offset) => { + TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)), + TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ), + TypedOperand::RegOffset(reg, offset) => { self.reg_offset(desc.new_op((reg, offset)), typ) } - } - } - - fn src_call_operand( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result { - match desc.op { - ast::CallOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)), - ast::CallOperand::Imm(x) => self.immediate(desc.new_op(x), typ), - } - } - - fn src_member_operand( - &mut self, - desc: ArgumentDescriptor<(spirv::Word, u8)>, - typ: (ast::ScalarType, u8), - ) -> Result { - self.member_src(desc, typ) - } - - fn id_or_vector( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result { - match desc.op { - ast::IdOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)), - ast::IdOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ), - } - } - - fn operand_or_vector( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result { - match desc.op { - ast::OperandOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)), - ast::OperandOrVector::RegOffset(r, imm) => self.reg_offset(desc.new_op((r, imm)), typ), - ast::OperandOrVector::Imm(imm) => self.immediate(desc.new_op(imm), typ), - ast::OperandOrVector::Vec(ref v) => self.vector(desc.new_op(v), typ), + TypedOperand::VecMember(..) => Err(error_unreachable()), } } } @@ -2543,7 +2403,7 @@ fn insert_implicit_conversions( if let ast::Instruction::AtomCas(d, _) = &inst { state_space = Some(d.space.to_ld_ss()); } - if let ast::Instruction::Mov(_, ast::Arg2Mov::Normal(_)) = &inst { + if let ast::Instruction::Mov(..) 
= &inst { default_conversion_fn = should_bitcast_packed; } insert_implicit_conversions_impl( @@ -2554,13 +2414,6 @@ fn insert_implicit_conversions( state_space, )?; } - Statement::Composite(composite) => insert_implicit_conversions_impl( - &mut result, - id_def, - composite, - should_bitcast_wrapper, - None, - )?, Statement::PtrAccess(PtrAccess { underlying_type, state_space, @@ -2593,14 +2446,20 @@ fn insert_implicit_conversions( Some(state_space), )?; } + Statement::RepackVector(repack) => insert_implicit_conversions_impl( + &mut result, + id_def, + repack, + should_bitcast_wrapper, + None, + )?, s @ Statement::Conditional(_) | s @ Statement::Conversion(_) | s @ Statement::Label(_) | s @ Statement::Constant(_) | s @ Statement::Variable(_) - | s @ Statement::LoadVar(_, _) - | s @ Statement::StoreVar(_, _) - | s @ Statement::Undef(_, _) + | s @ Statement::LoadVar(..) + | s @ Statement::StoreVar(..) | s @ Statement::RetValue(_, _) => result.push(s), } } @@ -2610,7 +2469,7 @@ fn insert_implicit_conversions( fn insert_implicit_conversions_impl( func: &mut Vec, id_def: &mut MutableNumericIdResolver, - stmt: impl VisitVariableExpanded, + stmt: impl Visitable, default_conversion_fn: for<'a> fn( &'a ast::Type, &'a ast::Type, @@ -2619,63 +2478,65 @@ fn insert_implicit_conversions_impl( state_space: Option, ) -> Result<(), TranslateError> { let mut post_conv = Vec::new(); - let statement = stmt.visit_variable_extended(&mut |desc, typ| { - let instr_type = match typ { - None => return Ok(desc.op), - Some(t) => t, - }; - let operand_type = id_def.get_typed(desc.op)?; - let mut conversion_fn = default_conversion_fn; - match desc.sema { - ArgumentSemantics::Default => {} - ArgumentSemantics::DefaultRelaxed => { - if desc.is_dst { - conversion_fn = should_convert_relaxed_dst_wrapper; - } else { - conversion_fn = should_convert_relaxed_src_wrapper; + let statement = stmt.visit( + &mut |desc: ArgumentDescriptor, typ: Option<&ast::Type>| { + let instr_type = match typ { + None => 
return Ok(desc.op), + Some(t) => t, + }; + let operand_type = id_def.get_typed(desc.op)?; + let mut conversion_fn = default_conversion_fn; + match desc.sema { + ArgumentSemantics::Default => {} + ArgumentSemantics::DefaultRelaxed => { + if desc.is_dst { + conversion_fn = should_convert_relaxed_dst_wrapper; + } else { + conversion_fn = should_convert_relaxed_src_wrapper; + } } - } - ArgumentSemantics::PhysicalPointer => { - conversion_fn = bitcast_physical_pointer; - } - ArgumentSemantics::RegisterPointer => { - conversion_fn = bitcast_register_pointer; - } - ArgumentSemantics::Address => { - conversion_fn = force_bitcast_ptr_to_bit; - } - }; - match conversion_fn(&operand_type, instr_type, state_space)? { - Some(conv_kind) => { - let conv_output = if desc.is_dst { - &mut post_conv - } else { - &mut *func - }; - let mut from = instr_type.clone(); - let mut to = operand_type; - let mut src = id_def.new_non_variable(instr_type.clone()); - let mut dst = desc.op; - let result = Ok(src); - if !desc.is_dst { - mem::swap(&mut src, &mut dst); - mem::swap(&mut from, &mut to); + ArgumentSemantics::PhysicalPointer => { + conversion_fn = bitcast_physical_pointer; } - conv_output.push(Statement::Conversion(ImplicitConversion { - src, - dst, - from, - to, - kind: conv_kind, - src_sema: ArgumentSemantics::Default, - dst_sema: ArgumentSemantics::Default, - })); - result - } - None => Ok(desc.op), - } - })?; - func.push(statement); + ArgumentSemantics::RegisterPointer => { + conversion_fn = bitcast_register_pointer; + } + ArgumentSemantics::Address => { + conversion_fn = force_bitcast_ptr_to_bit; + } + }; + match conversion_fn(&operand_type, instr_type, state_space)? 
{ + Some(conv_kind) => { + let conv_output = if desc.is_dst { + &mut post_conv + } else { + &mut *func + }; + let mut from = instr_type.clone(); + let mut to = operand_type; + let mut src = id_def.new_non_variable(instr_type.clone()); + let mut dst = desc.op; + let result = Ok(src); + if !desc.is_dst { + mem::swap(&mut src, &mut dst); + mem::swap(&mut from, &mut to); + } + conv_output.push(Statement::Conversion(ImplicitConversion { + src, + dst, + from, + to, + kind: conv_kind, + src_sema: ArgumentSemantics::Default, + dst_sema: ArgumentSemantics::Default, + })); + result + } + None => Ok(desc.op), + } + }, + )?; + func.push(statement); func.append(&mut post_conv); Ok(()) } @@ -2861,38 +2722,11 @@ fn emit_function_body_ops( } // SPIR-V does not support ret as guaranteed-converged ast::Instruction::Ret(_) => builder.ret()?, - ast::Instruction::Mov(d, arg) => match arg { - ast::Arg2Mov::Normal(ast::Arg2MovNormal { dst, src }) - | ast::Arg2Mov::Member(ast::Arg2MovMember::Src(dst, src)) => { - let result_type = map - .get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone()))); - builder.copy_object(result_type, Some(*dst), *src)?; - } - ast::Arg2Mov::Member(ast::Arg2MovMember::Dst( - dst, - composite_src, - scalar_src, - )) - | ast::Arg2Mov::Member(ast::Arg2MovMember::Both( - dst, - composite_src, - scalar_src, - )) => { - let scalar_type = d.typ.get_scalar()?; - let result_type = map.get_or_add( - builder, - SpirvType::from(ast::Type::Vector(scalar_type, d.dst_width)), - ); - let result_id = Some(dst.0); - builder.composite_insert( - result_type, - result_id, - *scalar_src, - *composite_src, - [dst.1 as u32], - )?; - } - }, + ast::Instruction::Mov(d, arg) => { + let result_type = + map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone()))); + builder.copy_object(result_type, Some(arg.dst), arg.src)?; + } ast::Instruction::Mul(mul, arg) => match mul { ast::MulDetails::Signed(ref ctr) => { emit_mul_sint(builder, map, opencl, ctr, arg)? 
@@ -3202,30 +3036,38 @@ fn emit_function_body_ops( builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; } }, - Statement::LoadVar(arg, typ) => { - let type_id = map.get_or_add(builder, SpirvType::from(typ.clone())); - builder.load(type_id, Some(arg.dst), arg.src, None, [])?; + Statement::LoadVar(details) => { + emit_load_var(builder, map, details)?; } - Statement::StoreVar(arg, _) => { - builder.store(arg.src1, arg.src2, None, [])?; + Statement::StoreVar(details) => { + let dst_ptr = match details.member_index { + Some(index) => { + let result_ptr_type = map.get_or_add( + builder, + SpirvType::new_pointer( + details.typ.clone(), + spirv::StorageClass::Function, + ), + ); + let index_spirv = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(index as u32), + )?; + builder.in_bounds_access_chain( + result_ptr_type, + None, + details.arg.src1, + &[index_spirv], + )? + } + None => details.arg.src1, + }; + builder.store(dst_ptr, details.arg.src2, None, [])?; } Statement::RetValue(_, id) => { builder.ret_value(*id)?; } - Statement::Composite(c) => { - let result_type = map.get_or_add_scalar(builder, c.typ.into()); - let result_id = Some(c.dst); - builder.composite_extract( - result_type, - result_id, - c.src_composite, - [c.src_index], - )?; - } - Statement::Undef(t, id) => { - let result_type = map.get_or_add(builder, SpirvType::from(t.clone())); - builder.undef(result_type, Some(*id)); - } Statement::PtrAccess(PtrAccess { underlying_type, state_space, @@ -3254,6 +3096,38 @@ fn emit_function_body_ops( )?; builder.bitcast(result_type, Some(*dst), temp)?; } + Statement::RepackVector(repack) => { + if repack.is_extract { + let scalar_type = map.get_or_add_scalar(builder, repack.typ); + for (index, dst_id) in repack.unpacked.iter().enumerate() { + builder.composite_extract( + scalar_type, + Some(*dst_id), + repack.packed, + &[index as u32], + )?; + } + } else { + let vector_type = map.get_or_add( + builder, + 
SpirvType::Vector( + SpirvScalarKey::from(repack.typ), + repack.unpacked.len() as u8, + ), + ); + let mut temp_vec = builder.undef(vector_type, None); + for (index, src_id) in repack.unpacked.iter().enumerate() { + temp_vec = builder.composite_insert( + vector_type, + None, + *src_id, + temp_vec, + &[index as u32], + )?; + } + builder.copy_object(vector_type, Some(repack.packed), temp_vec)?; + } + } } } Ok(()) @@ -3271,7 +3145,7 @@ fn insert_shift_hack( 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16), 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64), 4 => return Ok(offset_var), - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), }; Ok(builder.u_convert(result_type, None, offset_var)?) } @@ -3351,7 +3225,7 @@ fn emit_atom( let spirv_op = match op { ast::AtomUIntOp::Add => dr::Builder::atomic_i_add, ast::AtomUIntOp::Inc | ast::AtomUIntOp::Dec => { - return Err(TranslateError::Unreachable); + return Err(error_unreachable()); } ast::AtomUIntOp::Min => dr::Builder::atomic_u_min, ast::AtomUIntOp::Max => dr::Builder::atomic_u_max, @@ -4165,6 +4039,58 @@ fn emit_implicit_conversion( Ok(()) } +fn emit_load_var( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + details: &LoadVarDetails, +) -> Result<(), TranslateError> { + let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone())); + match details.member_index { + Some((index, Some(width))) => { + let vector_type = match details.typ { + ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), + _ => return Err(TranslateError::MismatchedType), + }; + let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type)); + let vector_temp = builder.load(vector_type_spirv, None, details.arg.src, None, [])?; + builder.composite_extract( + result_type, + Some(details.arg.dst), + vector_temp, + &[index as u32], + )?; + } + Some((index, None)) => { + let result_ptr_type = map.get_or_add( + builder, + 
SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function), + ); + let index_spirv = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(index as u32), + )?; + let src = builder.in_bounds_access_chain( + result_ptr_type, + None, + details.arg.src, + &[index_spirv], + )?; + builder.load(result_type, Some(details.arg.dst), src, None, [])?; + } + None => { + builder.load( + result_type, + Some(details.arg.dst), + details.arg.src, + None, + [], + )?; + } + }; + Ok(()) +} + fn normalize_identifiers<'a, 'b>( id_defs: &mut FnStringIdResolver<'a, 'b>, fn_defs: &GlobalFnDeclResolver<'a, 'b>, @@ -4290,9 +4216,11 @@ fn convert_to_stateful_memory_access<'a>( }, arg, )) => { - if let Some(src) = arg.src.underlying() { - if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, arg.dst) { - stateful_markers.push((arg.dst, *src)); + if let (TypedOperand::Reg(dst), Some(src)) = + (arg.dst, arg.src.upcast().underlying()) + { + if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, dst) { + stateful_markers.push((dst, *src)); } } } @@ -4320,7 +4248,9 @@ fn convert_to_stateful_memory_access<'a>( }, arg, )) => { - if let (ast::IdOrVector::Reg(dst), Some(src)) = (&arg.dst, arg.src.underlying()) { + if let (TypedOperand::Reg(dst), Some(src)) = + (&arg.dst, arg.src.upcast().underlying()) + { if func_args_64bit.contains(src) { multi_hash_map_append(&mut stateful_init_reg, *dst, *src); } @@ -4369,13 +4299,17 @@ fn convert_to_stateful_memory_access<'a>( }), arg, )) => { - if let Some(src1) = arg.src1.underlying() { + if let (TypedOperand::Reg(dst), Some(src1)) = + (arg.dst, arg.src1.upcast().underlying()) + { if regs_ptr_current.contains(src1) && !regs_ptr_seen.contains(src1) { - regs_ptr_new.insert(arg.dst); + regs_ptr_new.insert(dst); } - } else if let Some(src2) = arg.src2.underlying() { + } else if let (TypedOperand::Reg(dst), Some(src2)) = + (arg.dst, arg.src2.upcast().underlying()) + { if 
regs_ptr_current.contains(src2) && !regs_ptr_seen.contains(src2) { - regs_ptr_new.insert(arg.dst); + regs_ptr_new.insert(dst); } } } @@ -4426,19 +4360,20 @@ fn convert_to_stateful_memory_access<'a>( }), arg, )) if is_add_ptr_direct(&remapped_ids, &arg) => { - let (ptr, offset) = match arg.src1.underlying() { + let (ptr, offset) = match arg.src1.upcast().underlying() { Some(src1) if remapped_ids.contains_key(src1) => { (remapped_ids.get(src1).unwrap(), arg.src2) } Some(src2) if remapped_ids.contains_key(src2) => { (remapped_ids.get(src2).unwrap(), arg.src1) } - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), }; + let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), state_space: ast::LdStateSpace::Global, - dst: *remapped_ids.get(&arg.dst).unwrap(), + dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, offset_src: offset, })) @@ -4454,14 +4389,14 @@ fn convert_to_stateful_memory_access<'a>( }), arg, )) if is_add_ptr_direct(&remapped_ids, &arg) => { - let (ptr, offset) = match arg.src1.underlying() { + let (ptr, offset) = match arg.src1.upcast().underlying() { Some(src1) if remapped_ids.contains_key(src1) => { (remapped_ids.get(src1).unwrap(), arg.src2) } Some(src2) if remapped_ids.contains_key(src2) => { (remapped_ids.get(src2).unwrap(), arg.src1) } - _ => return Err(TranslateError::Unreachable), + _ => return Err(error_unreachable()), }; let offset_neg = id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64))); @@ -4472,21 +4407,23 @@ fn convert_to_stateful_memory_access<'a>( }, ast::Arg2 { src: offset, - dst: offset_neg, + dst: TypedOperand::Reg(offset_neg), }, ))); + let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), state_space: ast::LdStateSpace::Global, - dst: *remapped_ids.get(&arg.dst).unwrap(), + dst: 
*remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, - offset_src: ast::Operand::Reg(offset_neg), + offset_src: TypedOperand::Reg(offset_neg), })) } Statement::Instruction(inst) => { let mut post_statements = Vec::new(); - let new_statement = inst.visit_variable( - &mut |arg_desc: ArgumentDescriptor, expected_type| { + let new_statement = inst.visit( + &mut |arg_desc: ArgumentDescriptor, + expected_type: Option<&ast::Type>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, @@ -4499,14 +4436,13 @@ fn convert_to_stateful_memory_access<'a>( }, )?; result.push(new_statement); - for s in post_statements { - result.push(s); - } + result.extend(post_statements); } Statement::Call(call) => { let mut post_statements = Vec::new(); - let new_statement = call.visit_variable( - &mut |arg_desc: ArgumentDescriptor, expected_type| { + let new_statement = call.visit( + &mut |arg_desc: ArgumentDescriptor, + expected_type: Option<&ast::Type>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, @@ -4519,11 +4455,28 @@ fn convert_to_stateful_memory_access<'a>( }, )?; result.push(new_statement); - for s in post_statements { - result.push(s); - } + result.extend(post_statements); } - _ => return Err(TranslateError::Unreachable), + Statement::RepackVector(pack) => { + let mut post_statements = Vec::new(); + let new_statement = pack.visit( + &mut |arg_desc: ArgumentDescriptor, + expected_type: Option<&ast::Type>| { + convert_to_stateful_memory_access_postprocess( + id_defs, + &remapped_ids, + &func_args_ptr, + &mut result, + &mut post_statements, + arg_desc, + expected_type, + ) + }, + )?; + result.push(new_statement); + result.extend(post_statements); + } + _ => return Err(error_unreachable()), } } for arg in func_args.input.iter_mut() { @@ -4588,7 +4541,7 @@ fn convert_to_stateful_memory_access_postprocess( None => match func_args_ptr.get(&arg_desc.op) { Some(new_id) => { if arg_desc.is_dst { - return Err(TranslateError::Unreachable); + return 
Err(error_unreachable()); } // We skip conversion here to trigger PtrAcces in a later pass let old_type = match expected_type { @@ -4617,13 +4570,20 @@ fn convert_to_stateful_memory_access_postprocess( } fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { - if !remapped_ids.contains_key(&arg.dst) { - return false; - } - match arg.src1.underlying() { - Some(src1) if remapped_ids.contains_key(src1) => true, - Some(src2) if remapped_ids.contains_key(src2) => true, - _ => false, + match arg.dst { + TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => { + return false + } + TypedOperand::Reg(dst) => { + if !remapped_ids.contains_key(&dst) { + return false; + } + match arg.src1.upcast().underlying() { + Some(src1) if remapped_ids.contains_key(src1) => true, + Some(src2) if remapped_ids.contains_key(src2) => true, + _ => false, + } + } } } @@ -4962,14 +4922,13 @@ enum Statement { // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), Call(ResolvedCall

), - LoadVar(ast::Arg2, ast::Type), - StoreVar(ast::Arg2St, ast::Type), - Composite(CompositeRead), + LoadVar(LoadVarDetails), + StoreVar(StoreVarDetails), Conversion(ImplicitConversion), Constant(ConstantDefinition), RetValue(ast::RetData, spirv::Word), - Undef(ast::Type, spirv::Word), PtrAccess(PtrAccess

), + RepackVector(RepackVectorDetails), } impl ExpandedStatement { @@ -4981,19 +4940,19 @@ impl ExpandedStatement { Statement::Variable(var) } Statement::Instruction(inst) => inst - .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| { + .visit(&mut |arg: ArgumentDescriptor<_>, _: Option<&ast::Type>| { Ok(f(arg.op, arg.is_dst)) }) .unwrap(), - Statement::LoadVar(mut arg, typ) => { - arg.dst = f(arg.dst, true); - arg.src = f(arg.src, false); - Statement::LoadVar(arg, typ) + Statement::LoadVar(mut details) => { + details.arg.dst = f(details.arg.dst, true); + details.arg.src = f(details.arg.src, false); + Statement::LoadVar(details) } - Statement::StoreVar(mut arg, typ) => { - arg.src1 = f(arg.src1, false); - arg.src2 = f(arg.src2, false); - Statement::StoreVar(arg, typ) + Statement::StoreVar(mut details) => { + details.arg.src1 = f(details.arg.src1, false); + details.arg.src2 = f(details.arg.src2, false); + Statement::StoreVar(details) } Statement::Call(mut call) => { for (id, typ) in call.ret_params.iter_mut() { @@ -5010,11 +4969,6 @@ impl ExpandedStatement { } Statement::Call(call) } - Statement::Composite(mut composite) => { - composite.dst = f(composite.dst, true); - composite.src_composite = f(composite.src_composite, false); - Statement::Composite(composite) - } Statement::Conditional(mut conditional) => { conditional.predicate = f(conditional.predicate, false); conditional.if_true = f(conditional.if_true, false); @@ -5034,10 +4988,6 @@ impl ExpandedStatement { let id = f(id, false); Statement::RetValue(data, id) } - Statement::Undef(typ, id) => { - let id = f(id, true); - Statement::Undef(typ, id) - } Statement::PtrAccess(PtrAccess { underlying_type, state_space, @@ -5056,19 +5006,100 @@ impl ExpandedStatement { offset_src: constant_src, }) } + Statement::RepackVector(_) => todo!(), } } } +struct LoadVarDetails { + arg: ast::Arg2, + typ: ast::Type, + // (index, vector_width) + // HACK ALERT + // For some reason IGC explodes when you try to load from 
builtin vectors + // using OpInBoundsAccessChain, the one true way to do it is to + // OpLoad+OpCompositeExtract + member_index: Option<(u8, Option)>, +} + +struct StoreVarDetails { + arg: ast::Arg2St, + typ: ast::Type, + member_index: Option, +} + +struct RepackVectorDetails { + is_extract: bool, + typ: ast::ScalarType, + packed: spirv::Word, + unpacked: Vec, + vector_sema: ArgumentSemantics, +} + +impl RepackVectorDetails { + fn map< + From: ArgParamsEx, + To: ArgParamsEx, + V: ArgumentMapVisitor, + >( + self, + visitor: &mut V, + ) -> Result { + let scalar = visitor.id( + ArgumentDescriptor { + op: self.packed, + is_dst: !self.is_extract, + sema: ArgumentSemantics::Default, + }, + Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)), + )?; + let scalar_type = self.typ; + let is_extract = self.is_extract; + let vector_sema = self.vector_sema; + let vector = self + .unpacked + .into_iter() + .map(|id| { + visitor.id( + ArgumentDescriptor { + op: id, + is_dst: is_extract, + sema: vector_sema, + }, + Some(&ast::Type::Scalar(scalar_type)), + ) + }) + .collect::>()?; + Ok(RepackVectorDetails { + is_extract, + typ: self.typ, + packed: scalar, + unpacked: vector, + vector_sema, + }) + } +} + +impl, U: ArgParamsEx> Visitable + for RepackVectorDetails +{ + fn visit( + self, + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok(Statement::RepackVector(self.map::<_, _, _>(visitor)?)) + } +} + struct ResolvedCall { pub uniform: bool, - pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>, - pub func: spirv::Word, - pub param_list: Vec<(P::CallOperand, ast::FnArgumentType)>, + pub ret_params: Vec<(P::Id, ast::FnArgumentType)>, + pub func: P::Id, + pub param_list: Vec<(P::Operand, ast::FnArgumentType)>, } impl ResolvedCall { - fn cast>(self) -> ResolvedCall { + fn cast>(self) -> ResolvedCall { ResolvedCall { uniform: self.uniform, ret_params: self.ret_params, @@ -5110,7 +5141,7 @@ impl> ResolvedCall { .param_list .into_iter() .map::, 
_>(|(id, typ)| { - let new_id = visitor.src_call_operand( + let new_id = visitor.operand( ArgumentDescriptor { op: id, is_dst: false, @@ -5130,32 +5161,14 @@ impl> ResolvedCall { } } -impl VisitVariable for ResolvedCall { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result { - Ok(Statement::Call(self.map(f)?)) - } -} - -impl VisitVariableExpanded for ResolvedCall { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +impl, U: ArgParamsEx> Visitable + for ResolvedCall +{ + fn visit( self, - f: &mut F, - ) -> Result { - Ok(Statement::Call(self.map(f)?)) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok(Statement::Call(self.map(visitor)?)) } } @@ -5208,18 +5221,14 @@ impl> PtrAccess

{ } } -impl VisitVariable for PtrAccess { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +impl, U: ArgParamsEx> Visitable + for PtrAccess +{ + fn visit( self, - f: &mut F, - ) -> Result { - Ok(Statement::PtrAccess(self.map(f)?)) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok(Statement::PtrAccess(self.map(visitor)?)) } } @@ -5244,10 +5253,6 @@ enum NormalizedArgParams {} impl ast::ArgParams for NormalizedArgParams { type Id = spirv::Word; type Operand = ast::Operand; - type CallOperand = ast::CallOperand; - type IdOrVector = ast::IdOrVector; - type OperandOrVector = ast::OperandOrVector; - type SrcMemberOperand = (spirv::Word, u8); } impl ArgParamsEx for NormalizedArgParams { @@ -5273,11 +5278,7 @@ enum TypedArgParams {} impl ast::ArgParams for TypedArgParams { type Id = spirv::Word; - type Operand = ast::Operand; - type CallOperand = ast::CallOperand; - type IdOrVector = ast::IdOrVector; - type OperandOrVector = ast::OperandOrVector; - type SrcMemberOperand = (spirv::Word, u8); + type Operand = TypedOperand; } impl ArgParamsEx for TypedArgParams { @@ -5289,6 +5290,25 @@ impl ArgParamsEx for TypedArgParams { } } +#[derive(Copy, Clone)] +enum TypedOperand { + Reg(spirv::Word), + RegOffset(spirv::Word, i32), + Imm(ast::ImmediateValue), + VecMember(spirv::Word, u8), +} + +impl TypedOperand { + fn upcast(self) -> ast::Operand { + match self { + TypedOperand::Reg(reg) => ast::Operand::Reg(reg), + TypedOperand::RegOffset(reg, idx) => ast::Operand::RegOffset(reg, idx), + TypedOperand::Imm(x) => ast::Operand::Imm(x), + TypedOperand::VecMember(vec, idx) => ast::Operand::VecMember(vec, idx), + } + } +} + type TypedStatement = Statement, TypedArgParams>; enum ExpandedArgParams {} @@ -5297,10 +5317,6 @@ type ExpandedStatement = Statement, Expanded impl ast::ArgParams for ExpandedArgParams { type Id = spirv::Word; type Operand = spirv::Word; - type CallOperand = spirv::Word; - type 
IdOrVector = spirv::Word; - type OperandOrVector = spirv::Word; - type SrcMemberOperand = spirv::Word; } impl ArgParamsEx for ExpandedArgParams { @@ -5312,29 +5328,6 @@ impl ArgParamsEx for ExpandedArgParams { } } -#[derive(Copy, Clone)] -pub enum StateSpace { - Reg, - Const, - Global, - Local, - Shared, - Param, -} - -impl From for StateSpace { - fn from(ss: ast::StateSpace) -> Self { - match ss { - ast::StateSpace::Reg => StateSpace::Reg, - ast::StateSpace::Const => StateSpace::Const, - ast::StateSpace::Global => StateSpace::Global, - ast::StateSpace::Local => StateSpace::Local, - ast::StateSpace::Shared => StateSpace::Shared, - ast::StateSpace::Param => StateSpace::Param, - } - } -} - enum Directive<'input> { Variable(ast::Variable), Method(Function<'input>), @@ -5359,26 +5352,6 @@ pub trait ArgumentMapVisitor { desc: ArgumentDescriptor, typ: &ast::Type, ) -> Result; - fn id_or_vector( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - ) -> Result; - fn operand_or_vector( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - ) -> Result; - fn src_call_operand( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - ) -> Result; - fn src_member_operand( - &mut self, - desc: ArgumentDescriptor, - typ: (ast::ScalarType, u8), - ) -> Result; } impl ArgumentMapVisitor for T @@ -5397,44 +5370,12 @@ where } fn operand( - &mut self, - desc: ArgumentDescriptor, - t: &ast::Type, - ) -> Result { - self(desc, Some(t)) - } - - fn id_or_vector( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, ) -> Result { self(desc, Some(typ)) } - - fn operand_or_vector( - &mut self, - desc: ArgumentDescriptor, - typ: &ast::Type, - ) -> Result { - self(desc, Some(typ)) - } - - fn src_call_operand( - &mut self, - desc: ArgumentDescriptor, - t: &ast::Type, - ) -> Result { - self(desc, Some(t)) - } - - fn src_member_operand( - &mut self, - desc: ArgumentDescriptor, - (scalar_type, _): (ast::ScalarType, u8), - ) -> Result { - self(desc.new_op(desc.op), 
Some(&ast::Type::Scalar(scalar_type))) - } } impl<'a, T> ArgumentMapVisitor, NormalizedArgParams> for T @@ -5452,62 +5393,19 @@ where fn operand( &mut self, desc: ArgumentDescriptor>, - _: &ast::Type, + typ: &ast::Type, ) -> Result, TranslateError> { - match desc.op { - ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(id)?)), - ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset(self(id)?, imm)), - ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)), - } - } - - fn id_or_vector( - &mut self, - desc: ArgumentDescriptor>, - _: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(id)?)), - ast::IdOrVector::Vec(ids) => Ok(ast::IdOrVector::Vec( - ids.into_iter().map(self).collect::>()?, - )), - } - } - - fn operand_or_vector( - &mut self, - desc: ArgumentDescriptor>, - _: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::OperandOrVector::Reg(id) => Ok(ast::OperandOrVector::Reg(self(id)?)), - ast::OperandOrVector::RegOffset(id, imm) => { - Ok(ast::OperandOrVector::RegOffset(self(id)?, imm)) - } - ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)), - ast::OperandOrVector::Vec(ids) => Ok(ast::OperandOrVector::Vec( - ids.into_iter().map(self).collect::>()?, - )), - } - } - - fn src_call_operand( - &mut self, - desc: ArgumentDescriptor>, - _: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(id)?)), - ast::CallOperand::Imm(imm) => Ok(ast::CallOperand::Imm(imm)), - } - } - - fn src_member_operand( - &mut self, - desc: ArgumentDescriptor<(&str, u8)>, - _: (ast::ScalarType, u8), - ) -> Result<(spirv::Word, u8), TranslateError> { - Ok((self(desc.op.0)?, desc.op.1)) + Ok(match desc.op { + ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?), + ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id)?, imm), + ast::Operand::Imm(imm) => ast::Operand::Imm(imm), + 
ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member), + ast::Operand::VecPack(ref ids) => ast::Operand::VecPack( + ids.into_iter() + .map(|id| self.id(desc.new_op(id), Some(typ))) + .collect::, _>>()?, + ), + }) } } @@ -5559,7 +5457,7 @@ impl ast::Instruction { ast::Instruction::Abs(d, arg.map(visitor, &ast::Type::Scalar(d.typ))?) } // Call instruction is converted to a call statement early on - ast::Instruction::Call(_) => return Err(TranslateError::Unreachable), + ast::Instruction::Call(_) => return Err(error_unreachable()), ast::Instruction::Ld(d, a) => { let new_args = a.map(visitor, &d)?; ast::Instruction::Ld(d, new_args) @@ -5752,18 +5650,12 @@ impl ast::Instruction { } } -impl VisitVariable for ast::Instruction { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +impl Visitable for ast::Instruction { + fn visit( self, - f: &mut F, - ) -> Result { - Ok(Statement::Instruction(self.map(f)?)) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, U>, TranslateError> { + Ok(Statement::Instruction(self.map(visitor)?)) } } @@ -5802,32 +5694,14 @@ impl ImplicitConversion { } } -impl VisitVariable for ImplicitConversion { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result { - self.map(f) - } -} - -impl VisitVariableExpanded for ImplicitConversion { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( +impl, To: ArgParamsEx> Visitable + for ImplicitConversion +{ + fn visit( self, - f: &mut F, - ) -> Result { - self.map(f) + visitor: &mut impl ArgumentMapVisitor, + ) -> Result, To>, TranslateError> { + Ok(self.map(visitor)?) 
} } @@ -5848,79 +5722,24 @@ where fn operand( &mut self, - desc: ArgumentDescriptor>, - t: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(desc.new_op(id), Some(t))?)), - ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)), - ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset( - self(desc.new_op(id), Some(t))?, - imm, - )), - } - } - - fn src_call_operand( - &mut self, - desc: ArgumentDescriptor>, - t: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(desc.new_op(id), Some(t))?)), - ast::CallOperand::Imm(imm) => Ok(ast::CallOperand::Imm(imm)), - } - } - - fn id_or_vector( - &mut self, - desc: ArgumentDescriptor>, - typ: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(desc.new_op(id), Some(typ))?)), - ast::IdOrVector::Vec(ref ids) => Ok(ast::IdOrVector::Vec( - ids.iter() - .map(|id| self(desc.new_op(*id), Some(typ))) - .collect::>()?, - )), - } - } - - fn operand_or_vector( - &mut self, - desc: ArgumentDescriptor>, + desc: ArgumentDescriptor, typ: &ast::Type, - ) -> Result, TranslateError> { - match desc.op { - ast::OperandOrVector::Reg(id) => { - Ok(ast::OperandOrVector::Reg(self(desc.new_op(id), Some(typ))?)) + ) -> Result { + Ok(match desc.op { + TypedOperand::Reg(id) => TypedOperand::Reg(self(desc.new_op(id), Some(typ))?), + TypedOperand::Imm(imm) => TypedOperand::Imm(imm), + TypedOperand::RegOffset(id, imm) => { + TypedOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm) + } + TypedOperand::VecMember(reg, index) => { + let scalar_type = match typ { + ast::Type::Scalar(scalar_t) => *scalar_t, + _ => return Err(error_unreachable()), + }; + let vec_type = ast::Type::Vector(scalar_type, index + 1); + TypedOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) } - ast::OperandOrVector::RegOffset(id, imm) => 
Ok(ast::OperandOrVector::RegOffset( - self(desc.new_op(id), Some(typ))?, - imm, - )), - ast::OperandOrVector::Imm(imm) => Ok(ast::OperandOrVector::Imm(imm)), - ast::OperandOrVector::Vec(ref ids) => Ok(ast::OperandOrVector::Vec( - ids.iter() - .map(|id| self(desc.new_op(*id), Some(typ))) - .collect::>()?, - )), - } - } - - fn src_member_operand( - &mut self, - desc: ArgumentDescriptor<(spirv::Word, u8)>, - (scalar_type, vector_len): (ast::ScalarType, u8), - ) -> Result<(spirv::Word, u8), TranslateError> { - Ok(( - self( - desc.new_op(desc.op.0), - Some(&ast::Type::Vector(scalar_type.into(), vector_len)), - )?, - desc.op.1, - )) + }) } } @@ -5942,7 +5761,7 @@ impl ast::Type { kind, ))) } - _ => Err(TranslateError::Unreachable), + _ => Err(error_unreachable()), } } @@ -6182,67 +6001,9 @@ impl ast::Instruction { } } -impl VisitVariableExpanded for ast::Instruction { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result { - Ok(Statement::Instruction(self.map(f)?)) - } -} - type Arg2 = ast::Arg2; type Arg2St = ast::Arg2St; -struct CompositeRead { - pub typ: ast::ScalarType, - pub dst: spirv::Word, - pub dst_semantics_override: Option, - pub src_composite: spirv::Word, - pub src_index: u32, - pub src_len: u32, -} - -impl VisitVariableExpanded for CompositeRead { - fn visit_variable_extended< - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, - >( - self, - f: &mut F, - ) -> Result { - let dst_sema = self - .dst_semantics_override - .unwrap_or(ArgumentSemantics::Default); - Ok(Statement::Composite(CompositeRead { - dst: f( - ArgumentDescriptor { - op: self.dst, - is_dst: true, - sema: dst_sema, - }, - Some(&ast::Type::Scalar(self.typ)), - )?, - src_composite: f( - ArgumentDescriptor { - op: self.src_composite, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(self.typ, self.src_len as u8)), - )?, - ..self - })) - } -} - struct 
ConstantDefinition { pub dst: spirv::Word, pub typ: ast::ScalarType, @@ -6330,10 +6091,6 @@ impl From for ast::Type { } impl ast::Arg1 { - fn cast>(self) -> ast::Arg1 { - ast::Arg1 { src: self.src } - } - fn map>( self, visitor: &mut V, @@ -6352,10 +6109,6 @@ impl ast::Arg1 { } impl ast::Arg1Bar { - fn cast>(self) -> ast::Arg1Bar { - ast::Arg1Bar { src: self.src } - } - fn map>( self, visitor: &mut V, @@ -6373,25 +6126,18 @@ impl ast::Arg1Bar { } impl ast::Arg2 { - fn cast>(self) -> ast::Arg2 { - ast::Arg2 { - src: self.src, - dst: self.dst, - } - } - fn map>( self, visitor: &mut V, t: &ast::Type, ) -> Result, TranslateError> { - let new_dst = visitor.id( + let new_dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(t), + t, )?; let new_src = visitor.operand( ArgumentDescriptor { @@ -6413,13 +6159,13 @@ impl ast::Arg2 { dst_t: &ast::Type, src_t: &ast::Type, ) -> Result, TranslateError> { - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(dst_t), + dst_t, )?; let src = visitor.operand( ArgumentDescriptor { @@ -6434,21 +6180,12 @@ impl ast::Arg2 { } impl ast::Arg2Ld { - fn cast>( - self, - ) -> ast::Arg2Ld { - ast::Arg2Ld { - dst: self.dst, - src: self.src, - } - } - fn map>( self, visitor: &mut V, details: &ast::LdDetails, ) -> Result, TranslateError> { - let dst = visitor.id_or_vector( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6478,15 +6215,6 @@ impl ast::Arg2Ld { } impl ast::Arg2St { - fn cast>( - self, - ) -> ast::Arg2St { - ast::Arg2St { - src1: self.src1, - src2: self.src2, - } - } - fn map>( self, visitor: &mut V, @@ -6509,7 +6237,7 @@ impl ast::Arg2St { details.state_space.to_ld_ss(), ), )?; - let src2 = visitor.operand_or_vector( + let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, @@ -6527,29 +6255,7 @@ impl ast::Arg2Mov { visitor: 
&mut V, details: &ast::MovDetails, ) -> Result, TranslateError> { - Ok(match self { - ast::Arg2Mov::Normal(arg) => ast::Arg2Mov::Normal(arg.map(visitor, details)?), - ast::Arg2Mov::Member(arg) => ast::Arg2Mov::Member(arg.map(visitor, details)?), - }) - } -} - -impl ast::Arg2MovNormal

{ - fn cast>( - self, - ) -> ast::Arg2MovNormal { - ast::Arg2MovNormal { - dst: self.dst, - src: self.src, - } - } - - fn map>( - self, - visitor: &mut V, - details: &ast::MovDetails, - ) -> Result, TranslateError> { - let dst = visitor.id_or_vector( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, @@ -6557,7 +6263,7 @@ impl ast::Arg2MovNormal

{ }, &details.typ.clone().into(), )?; - let src = visitor.operand_or_vector( + let src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, @@ -6569,144 +6275,11 @@ impl ast::Arg2MovNormal

{ }, &details.typ.clone().into(), )?; - Ok(ast::Arg2MovNormal { dst, src }) - } -} - -impl ast::Arg2MovMember { - fn cast>( - self, - ) -> ast::Arg2MovMember { - match self { - ast::Arg2MovMember::Dst(dst, src1, src2) => ast::Arg2MovMember::Dst(dst, src1, src2), - ast::Arg2MovMember::Src(dst, src) => ast::Arg2MovMember::Src(dst, src), - ast::Arg2MovMember::Both(dst, src1, src2) => ast::Arg2MovMember::Both(dst, src1, src2), - } - } - - fn vector_dst(&self) -> Option<&T::Id> { - match self { - ast::Arg2MovMember::Src(_, _) => None, - ast::Arg2MovMember::Dst((d, _), _, _) | ast::Arg2MovMember::Both((d, _), _, _) => { - Some(d) - } - } - } - - fn vector_src(&self) -> Option<&T::SrcMemberOperand> { - match self { - ast::Arg2MovMember::Src(_, d) | ast::Arg2MovMember::Both(_, _, d) => Some(d), - ast::Arg2MovMember::Dst(_, _, _) => None, - } - } -} - -impl ast::Arg2MovMember { - fn map>( - self, - visitor: &mut V, - details: &ast::MovDetails, - ) -> Result, TranslateError> { - match self { - ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => { - let scalar_type = details.typ.get_scalar()?; - let dst = visitor.id( - ArgumentDescriptor { - op: dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(scalar_type, details.dst_width)), - )?; - let src1 = visitor.id( - ArgumentDescriptor { - op: composite_src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(scalar_type, details.dst_width)), - )?; - let src2 = visitor.id( - ArgumentDescriptor { - op: scalar_src, - is_dst: false, - sema: if details.src_is_address { - ArgumentSemantics::Address - } else if details.relaxed_src2_conv { - ArgumentSemantics::DefaultRelaxed - } else { - ArgumentSemantics::Default - }, - }, - Some(&details.typ.clone().into()), - )?; - Ok(ast::Arg2MovMember::Dst((dst, len), src1, src2)) - } - ast::Arg2MovMember::Src(dst, src) => { - let dst = visitor.id( - ArgumentDescriptor { - op: dst, - is_dst: true, - sema: 
ArgumentSemantics::Default, - }, - Some(&details.typ.clone().into()), - )?; - let scalar_typ = details.typ.get_scalar()?; - let src = visitor.src_member_operand( - ArgumentDescriptor { - op: src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - (scalar_typ.into(), details.src_width), - )?; - Ok(ast::Arg2MovMember::Src(dst, src)) - } - ast::Arg2MovMember::Both((dst, len), composite_src, src) => { - let scalar_type = details.typ.get_scalar()?; - let dst = visitor.id( - ArgumentDescriptor { - op: dst, - is_dst: true, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(scalar_type, details.dst_width)), - )?; - let composite_src = visitor.id( - ArgumentDescriptor { - op: composite_src, - is_dst: false, - sema: ArgumentSemantics::Default, - }, - Some(&ast::Type::Vector(scalar_type, details.dst_width)), - )?; - let src = visitor.src_member_operand( - ArgumentDescriptor { - op: src, - is_dst: false, - sema: if details.relaxed_src2_conv { - ArgumentSemantics::DefaultRelaxed - } else { - ArgumentSemantics::Default - }, - }, - (scalar_type.into(), details.src_width), - )?; - Ok(ast::Arg2MovMember::Both((dst, len), composite_src, src)) - } - } + Ok(ast::Arg2Mov { dst, src }) } } impl ast::Arg3 { - fn cast>(self) -> ast::Arg3 { - ast::Arg3 { - dst: self.dst, - src1: self.src1, - src2: self.src2, - } - } - fn map_non_shift>( self, visitor: &mut V, @@ -6718,13 +6291,13 @@ impl ast::Arg3 { } else { None }; - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(wide_type.as_ref().unwrap_or(typ)), + wide_type.as_ref().unwrap_or(typ), )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6750,13 +6323,13 @@ impl ast::Arg3 { visitor: &mut V, t: &ast::Type, ) -> Result, TranslateError> { - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(t), + t, )?; let src1 = 
visitor.operand( ArgumentDescriptor { @@ -6784,13 +6357,13 @@ impl ast::Arg3 { state_space: ast::AtomSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(scalar_type)), + &ast::Type::Scalar(scalar_type), )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6816,15 +6389,6 @@ impl ast::Arg3 { } impl ast::Arg4 { - fn cast>(self) -> ast::Arg4 { - ast::Arg4 { - dst: self.dst, - src1: self.src1, - src2: self.src2, - src3: self.src3, - } - } - fn map>( self, visitor: &mut V, @@ -6836,13 +6400,13 @@ impl ast::Arg4 { } else { None }; - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(wide_type.as_ref().unwrap_or(t)), + wide_type.as_ref().unwrap_or(t), )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6881,13 +6445,13 @@ impl ast::Arg4 { visitor: &mut V, t: ast::SelpType, ) -> Result, TranslateError> { - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(t.into())), + &ast::Type::Scalar(t.into()), )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6928,13 +6492,13 @@ impl ast::Arg4 { state_space: ast::AtomSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, - Some(&ast::Type::Scalar(scalar_type)), + &ast::Type::Scalar(scalar_type), )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -6976,13 +6540,13 @@ impl ast::Arg4 { visitor: &mut V, typ: &ast::Type, ) -> Result, TranslateError> { - let dst = visitor.id( + let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: 
ArgumentSemantics::Default, }, - Some(typ), + typ, )?; let src1 = visitor.operand( ArgumentDescriptor { @@ -7019,15 +6583,6 @@ impl ast::Arg4 { } impl ast::Arg4Setp { - fn cast>(self) -> ast::Arg4Setp { - ast::Arg4Setp { - dst1: self.dst1, - dst2: self.dst2, - src1: self.src1, - src2: self.src2, - } - } - fn map>( self, visitor: &mut V, @@ -7079,22 +6634,12 @@ impl ast::Arg4Setp { } } -impl ast::Arg5 { - fn cast>(self) -> ast::Arg5 { - ast::Arg5 { - dst1: self.dst1, - dst2: self.dst2, - src1: self.src1, - src2: self.src2, - src3: self.src3, - } - } - +impl ast::Arg5Setp { fn map>( self, visitor: &mut V, t: &ast::Type, - ) -> Result, TranslateError> { + ) -> Result, TranslateError> { let dst1 = visitor.id( ArgumentDescriptor { op: self.dst1, @@ -7140,7 +6685,7 @@ impl ast::Arg5 { }, &ast::Type::Scalar(ast::ScalarType::Pred), )?; - Ok(ast::Arg5 { + Ok(ast::Arg5Setp { dst1, dst2, src1, @@ -7150,30 +6695,28 @@ impl ast::Arg5 { } } -impl ast::Type { - fn get_vector(&self) -> Result<(ast::ScalarType, u8), TranslateError> { - match self { - ast::Type::Vector(t, len) => Ok((*t, *len)), - _ => Err(TranslateError::MismatchedType), - } - } - - fn get_scalar(&self) -> Result { - match self { - ast::Type::Scalar(t) => Ok(*t), - _ => Err(TranslateError::MismatchedType), - } - } -} - -impl ast::CallOperand { +impl ast::Operand { fn map_variable Result>( self, f: &mut F, - ) -> Result, TranslateError> { + ) -> Result, TranslateError> { + Ok(match self { + ast::Operand::Reg(reg) => ast::Operand::Reg(f(reg)?), + ast::Operand::RegOffset(reg, offset) => ast::Operand::RegOffset(f(reg)?, offset), + ast::Operand::Imm(x) => ast::Operand::Imm(x), + ast::Operand::VecMember(reg, idx) => ast::Operand::VecMember(f(reg)?, idx), + ast::Operand::VecPack(vec) => { + ast::Operand::VecPack(vec.into_iter().map(f).collect::>()?) 
+ } + }) + } +} + +impl ast::Operand { + fn unwrap_reg(&self) -> Result { match self { - ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(f(id)?)), - ast::CallOperand::Imm(x) => Ok(ast::CallOperand::Imm(x)), + ast::Operand::Reg(reg) => Ok(*reg), + _ => Err(error_unreachable()), } } } @@ -7394,15 +6937,8 @@ impl ast::Operand { match self { ast::Operand::Reg(r) | ast::Operand::RegOffset(r, _) => Some(r), ast::Operand::Imm(_) => None, - } - } -} - -impl ast::OperandOrVector { - fn single_underlying(&self) -> Option<&T> { - match self { - ast::OperandOrVector::Reg(r) | ast::OperandOrVector::RegOffset(r, _) => Some(r), - ast::OperandOrVector::Imm(_) | ast::OperandOrVector::Vec(_) => None, + ast::Operand::VecMember(reg, _) => Some(reg), + ast::Operand::VecPack(..) => None, } } } @@ -7500,7 +7036,7 @@ fn bitcast_physical_pointer( if let Some(space) = ss { Ok(Some(ConversionKind::BitToPtr(space))) } else { - Err(TranslateError::Unreachable) + Err(error_unreachable()) } } ast::Type::Scalar(ast::ScalarType::B32) -- cgit v1.2.3