From 76afbeba63d29e1247d5beb00902a8bb0279f791 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 6 Sep 2020 21:38:01 +0200 Subject: Implement support for PTX call instruction --- doc/NOTES.md | 6 + ptx/src/ast.rs | 44 +- ptx/src/ptx.lalrpop | 64 +- ptx/src/test/spirv_run/add.spvtxt | 84 ++- ptx/src/test/spirv_run/block.spvtxt | 95 +-- ptx/src/test/spirv_run/bra.spvtxt | 105 +-- ptx/src/test/spirv_run/call.ptx | 38 + ptx/src/test/spirv_run/call.spvtxt | 73 ++ ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt | 94 +-- ptx/src/test/spirv_run/cvta.spvtxt | 91 +-- ptx/src/test/spirv_run/ld_st.spvtxt | 79 +- ptx/src/test/spirv_run/local_align.spvtxt | 84 ++- ptx/src/test/spirv_run/mod.rs | 23 +- ptx/src/test/spirv_run/mov.spvtxt | 88 +-- ptx/src/test/spirv_run/mul_hi.spvtxt | 89 +-- ptx/src/test/spirv_run/mul_lo.spvtxt | 84 ++- ptx/src/test/spirv_run/not.spvtxt | 82 ++- ptx/src/test/spirv_run/setp.spvtxt | 138 ++-- ptx/src/test/spirv_run/shl.spvtxt | 86 ++- ptx/src/translate.rs | 1133 +++++++++++++++++++++-------- 20 files changed, 1692 insertions(+), 888 deletions(-) create mode 100644 ptx/src/test/spirv_run/call.ptx create mode 100644 ptx/src/test/spirv_run/call.spvtxt diff --git a/doc/NOTES.md b/doc/NOTES.md index b0f58f7..5e08b7e 100644 --- a/doc/NOTES.md +++ b/doc/NOTES.md @@ -75,3 +75,9 @@ CUDA <-> L0 * context ~= context (1.0+) * graph ~= command list * module ~= module + +IGC +--- +* IGC is extremely brittle and segfaults on fairly innocent code: + * OpBitcast of pointer to uint + * OpCopyMemory of alloca'd variable diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 7550d55..cfbdad5 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -51,23 +51,34 @@ pub struct Module<'a> { pub functions: Vec>, } -pub enum FunctionHeader<'a, P: ArgParams> { - Func(Vec>, P::ID), - Kernel(&'a str), +pub enum MethodDecl<'a, P: ArgParams> { + Func(Vec>, P::ID, Vec>), + Kernel(&'a str, Vec>), } pub struct Function<'a, P: ArgParams, S> { - pub func_directive: FunctionHeader<'a, P>, - pub args: Vec>, + pub func_directive: MethodDecl<'a, P>, pub body: Option>, } pub type ParsedFunction<'a> = Function<'a, ParsedArgParams<'a>, Statement>>; -#[derive(Default)] -pub struct Argument { +pub struct FnArgument { + pub base: KernelArgument

, + pub state_space: FnArgStateSpace, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum FnArgStateSpace { + Reg, + Param, +} + +#[derive(Default, Copy, Clone)] +pub struct KernelArgument { pub name: P::ID, pub a_type: ScalarType, + // TODO: turn length into part of type definition pub length: u32, } @@ -222,28 +233,26 @@ pub enum Instruction { Shl(ShlType, Arg3

), St(StData, Arg2St

), Ret(RetData), - Call(CallData, ArgCall

), + Call(CallInst

), Abs(AbsDetails, Arg2

), } -pub struct CallData { - pub uniform: bool, -} - pub struct AbsDetails { pub flush_to_zero: bool, pub typ: ScalarType, } -pub struct ArgCall { +pub struct CallInst { + pub uniform: bool, pub ret_params: Vec, pub func: P::ID, - pub param_list: Vec, + pub param_list: Vec, } pub trait ArgParams { type ID; type Operand; + type CallOperand; type MovOperand; } @@ -254,6 +263,7 @@ pub struct ParsedArgParams<'a> { impl<'a> ArgParams for ParsedArgParams<'a> { type ID = &'a str; type Operand = Operand<&'a str>; + type CallOperand = CallOperand<&'a str>; type MovOperand = MovOperand<&'a str>; } @@ -304,6 +314,12 @@ pub enum Operand { Imm(i128), } +#[derive(Copy, Clone)] +pub enum CallOperand { + Reg(ID), + Imm(i128), +} + pub enum MovOperand { Op(Operand), Vec(String, String), diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 7e38b78..53bb296 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -202,8 +202,7 @@ AddressSize = { Function: ast::Function<'input, ast::ParsedArgParams<'input>, ast::Statement>> = { LinkingDirective* - - + => ast::Function{<>} }; @@ -213,24 +212,43 @@ LinkingDirective = { ".weak" }; -FunctionHeader: ast::FunctionHeader<'input, ast::ParsedArgParams<'input>> = { - ".entry" => ast::FunctionHeader::Kernel(name), - ".func" => ast::FunctionHeader::Func(args.unwrap_or_else(|| Vec::new()), name) +MethodDecl: ast::MethodDecl<'input, ast::ParsedArgParams<'input>> = { + ".entry" => ast::MethodDecl::Kernel(name, params), + ".func" => ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params) }; -Arguments: Vec>> = { - "(" > ")" => args -} +KernelArguments: Vec>> = { + "(" > ")" => args +}; + +FnArguments: Vec>> = { + "(" > ")" => args +}; + +FnInput: ast::FnArgument> = { + ".reg" <_type:ScalarType> => { + ast::FnArgument { + base: ast::KernelArgument {a_type: _type, name: name, length: 1 }, + state_space: ast::FnArgStateSpace::Reg, + } + }, + => { + ast::FnArgument { + base: p, + state_space: ast::FnArgStateSpace::Param, + } + } +}; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space -FunctionInput: ast::Argument> = { +KernelInput: ast::KernelArgument> = { ".param" <_type:ScalarType> => { - ast::Argument {a_type: _type, name: name, length: 1 } + ast::KernelArgument {a_type: _type, name: name, length: 1 } }, ".param" "[" "]" => { let length = length.parse::(); let length = length.unwrap_with(errors); - ast::Argument { a_type: a_type, name: name, length: length } + ast::KernelArgument { a_type: a_type, name: name, length: length } } }; @@ -856,7 +874,10 @@ CvtaSize: ast::CvtaSize = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call InstCall: ast::Instruction> = { - "call" => ast::Instruction::Call(ast::CallData { uniform: u.is_some() }, a) + "call" => { + let (ret_params, func, param_list) = args; + ast::Instruction::Call(ast::CallInst { uniform: u.is_some(), ret_params, func, param_list }) + } }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs @@ -900,6 +921,15 @@ Operand: ast::Operand<&'input str> = { } }; +CallOperand: ast::CallOperand<&'input str> = { + => ast::CallOperand::Reg(r), + => { + let offset = o.parse::(); + let offset = offset.unwrap_with(errors); + ast::CallOperand::Imm(offset) + } +}; + MovOperand: ast::MovOperand<&'input str> = { => ast::MovOperand::Op(o), => { @@ -938,10 +968,12 @@ Arg5: ast::Arg5> = { "," "," "," "!"? => ast::Arg5{<>} }; -ArgCall: ast::ArgCall> = { - "(" > ")" "," "," "(" > ")" => ast::ArgCall{<>}, - "," "(" > ")" => ast::ArgCall{ret_params: Vec::new(), func, param_list}, - => ast::ArgCall{ret_params: Vec::new(), func, param_list: Vec::new()}, +ArgCall: (Vec<&'input str>, &'input str, Vec>) = { + "(" > ")" "," "," "(" > ")" => { + (ret_params, func, param_list) + }, + "," "(" > ")" => (Vec::new(), func, param_list), + => (Vec::new(), func, Vec::>::new()), }; OptionalDst: &'input str = { diff --git a/ptx/src/test/spirv_run/add.spvtxt b/ptx/src/test/spirv_run/add.spvtxt index 465a74e..6810fec 100644 --- a/ptx/src/test/spirv_run/add.spvtxt +++ b/ptx/src/test/spirv_run/add.spvtxt @@ -1,38 +1,46 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %1 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %5 "add" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %4 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_1 = OpConstant %ulong 1 - %5 = OpFunction %void None %4 - %6 = OpFunctionParameter %ulong - %7 = OpFunctionParameter %ulong - %21 = OpLabel - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_ulong Function - %10 = OpVariable %_ptr_Function_ulong Function - %11 = OpVariable %_ptr_Function_ulong Function - OpStore %8 %6 - OpStore %9 %7 - %12 = OpLoad %ulong %8 - %19 = OpConvertUToPtr %_ptr_Generic_ulong %12 - %13 = OpLoad %ulong %19 - OpStore %10 %13 - %14 = OpLoad %ulong %10 - %15 = OpIAdd %ulong %14 %ulong_1 - OpStore %11 %15 - %16 = OpLoad %ulong %9 - %17 = OpLoad %ulong %11 - %20 = OpConvertUToPtr %_ptr_Generic_ulong %16 - OpStore %20 %17 - OpReturn - OpFunctionEnd + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %25 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "add" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %28 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_1 = OpConstant %ulong 1 + %1 = OpFunction %void None %28 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %23 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %21 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %16 = OpIAdd %ulong %17 %ulong_1 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %22 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/block.spvtxt b/ptx/src/test/spirv_run/block.spvtxt index a780cc3..534167d 100644 --- a/ptx/src/test/spirv_run/block.spvtxt +++ b/ptx/src/test/spirv_run/block.spvtxt @@ -1,44 +1,51 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %1 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %5 "block" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %4 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_1 = OpConstant %ulong 1 - %ulong_1_0 = OpConstant %ulong 1 - %5 = OpFunction %void None %4 - %6 = OpFunctionParameter %ulong - %7 = OpFunctionParameter %ulong - %25 = OpLabel - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_ulong Function - %10 = OpVariable %_ptr_Function_ulong Function - %11 = OpVariable %_ptr_Function_ulong Function - OpStore %8 %6 - OpStore %9 %7 - %14 = OpLoad %ulong %8 - %23 = OpConvertUToPtr %_ptr_Generic_ulong %14 - %13 = OpLoad %ulong %23 - OpStore %10 %13 - %16 = OpLoad %ulong %10 - %15 = OpIAdd %ulong %16 %ulong_1 - OpStore %11 %15 - %12 = OpVariable %_ptr_Function_ulong Function - %18 = OpLoad %ulong %12 - %17 = OpIAdd %ulong %18 %ulong_1_0 - OpStore %12 %17 - %19 = OpLoad %ulong %9 - %20 = OpLoad %ulong %11 - %24 = OpConvertUToPtr %_ptr_Generic_ulong %19 - OpStore %24 %20 - OpReturn - OpFunctionEnd - \ No newline at end of file + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %29 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "block" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %32 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_1 = OpConstant %ulong 1 + %ulong_1_0 = OpConstant %ulong 1 + %1 = OpFunction %void None %32 + %9 = OpFunctionParameter %ulong + %10 = OpFunctionParameter %ulong + %27 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + %8 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %9 + OpStore %3 %10 + %12 = OpLoad %ulong %2 + %11 = OpCopyObject %ulong %12 + OpStore %4 %11 + %14 = OpLoad %ulong %3 + %13 = OpCopyObject %ulong %14 + OpStore %5 %13 + %16 = OpLoad %ulong %4 + %25 = OpConvertUToPtr %_ptr_Generic_ulong %16 + %15 = OpLoad %ulong %25 + OpStore %6 %15 + %18 = OpLoad %ulong %6 + %17 = OpIAdd %ulong %18 %ulong_1 + OpStore %7 %17 + %20 = OpLoad %ulong %8 + %19 = OpIAdd %ulong %20 %ulong_1_0 + OpStore %8 %19 + %21 = OpLoad %ulong %5 + %22 = OpLoad %ulong %7 + %26 = OpConvertUToPtr %_ptr_Generic_ulong %21 + OpStore %26 %22 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/bra.spvtxt b/ptx/src/test/spirv_run/bra.spvtxt index 81fedc5..f59fda5 100644 --- a/ptx/src/test/spirv_run/bra.spvtxt +++ b/ptx/src/test/spirv_run/bra.spvtxt @@ -1,49 +1,56 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %1 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %5 "bra" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %4 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_1 = OpConstant %ulong 1 - %ulong_2 = OpConstant %ulong 2 - %5 = OpFunction %void None %4 - %6 = OpFunctionParameter %ulong - %7 = OpFunctionParameter %ulong - %27 = OpLabel - %11 = OpVariable %_ptr_Function_ulong Function - %12 = OpVariable %_ptr_Function_ulong Function - %13 = OpVariable %_ptr_Function_ulong Function - %14 = OpVariable %_ptr_Function_ulong Function - OpStore %11 %6 - OpStore %12 %7 - %16 = OpLoad %ulong %11 - %25 = OpConvertUToPtr %_ptr_Generic_ulong %16 - %15 = OpLoad %ulong %25 - OpStore %13 %15 - OpBranch %8 - %8 = OpLabel - %18 = OpLoad %ulong %13 - %17 = OpIAdd %ulong %18 %ulong_1 - OpStore %14 %17 - OpBranch %10 - %30 = OpLabel - %20 = OpLoad %ulong %13 - %19 = OpIAdd %ulong %20 %ulong_2 - OpStore %14 %19 - OpBranch %10 - %10 = OpLabel - %21 = OpLoad %ulong %12 - %22 = OpLoad %ulong %14 - %26 = OpConvertUToPtr %_ptr_Generic_ulong %21 - OpStore %26 %22 - OpReturn - OpFunctionEnd - \ No newline at end of file + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %31 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "bra" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %34 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_1 = OpConstant %ulong 1 + %ulong_2 = OpConstant %ulong 2 + %1 = OpFunction %void None %34 + %11 = OpFunctionParameter %ulong + %12 = OpFunctionParameter %ulong + %29 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + %8 = OpVariable %_ptr_Function_ulong Function + %9 = OpVariable %_ptr_Function_ulong Function + %10 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %11 + OpStore %3 %12 + %14 = OpLoad %ulong %2 + %13 = OpCopyObject %ulong %14 + OpStore %7 %13 + %16 = OpLoad %ulong %3 + %15 = OpCopyObject %ulong %16 + OpStore %8 %15 + %18 = OpLoad %ulong %7 + %27 = OpConvertUToPtr %_ptr_Generic_ulong %18 + %17 = OpLoad %ulong %27 + OpStore %9 %17 + OpBranch %4 + %4 = OpLabel + %20 = OpLoad %ulong %9 + %19 = OpIAdd %ulong %20 %ulong_1 + OpStore %10 %19 + OpBranch %6 + %37 = OpLabel + %22 = OpLoad %ulong %9 + %21 = OpIAdd %ulong %22 %ulong_2 + OpStore %10 %21 + OpBranch %6 + %6 = OpLabel + %23 = OpLoad %ulong %8 + %24 = OpLoad %ulong %10 + %28 = OpConvertUToPtr %_ptr_Generic_ulong %23 + OpStore %28 %24 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/call.ptx b/ptx/src/test/spirv_run/call.ptx new file mode 100644 index 0000000..f2ac39c --- /dev/null +++ b/ptx/src/test/spirv_run/call.ptx @@ -0,0 +1,38 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.func (.param.u64 output) incr (.param.u64 input); + +.visible .entry call( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.global.u64 temp, [in_addr]; + .param.u64 incr_in; + .param.u64 incr_out; + st.param.b64 [incr_in], temp; + call (incr_out), incr, (incr_in); + ld.param.u64 temp, [incr_out]; + st.global.u64 [out_addr], temp; + ret; +} + +.func (.param .u64 output) incr( + .param .u64 input +) +{ + .reg .u64 temp; + ld.param.u64 temp, [input]; + add.u64 temp, temp, 1; + st.param.u64 [output], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/call.spvtxt b/ptx/src/test/spirv_run/call.spvtxt new file mode 100644 index 0000000..001cda3 --- /dev/null +++ b/ptx/src/test/spirv_run/call.spvtxt @@ -0,0 +1,73 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %45 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %4 "call" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %48 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong + %51 = OpTypeFunction %ulong %ulong + %ulong_1 = OpConstant %ulong 1 + %4 = OpFunction %void None %48 + %12 = OpFunctionParameter %ulong + %13 = OpFunctionParameter %ulong + %30 = OpLabel + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + %8 = OpVariable %_ptr_Function_ulong Function + %9 = OpVariable %_ptr_Function_ulong Function + %10 = OpVariable %_ptr_Function_ulong Function + %11 = OpVariable %_ptr_Function_ulong Function + OpStore %5 %12 + OpStore %6 %13 + %15 = OpLoad %ulong %5 + %14 = OpCopyObject %ulong %15 + OpStore %7 %14 + %17 = OpLoad %ulong %6 + %16 = OpCopyObject %ulong %17 + OpStore %8 %16 + %19 = OpLoad %ulong %7 + %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %19 + %18 = OpLoad %ulong %28 + OpStore %9 %18 + %21 = OpLoad %ulong %9 + %20 = OpCopyObject %ulong %21 + OpStore %10 %20 + %23 = OpLoad %ulong %10 + %22 = OpFunctionCall %ulong %1 %23 + OpStore %11 %22 + %25 = OpLoad %ulong %11 + %24 = OpCopyObject %ulong %25 + OpStore %9 %24 + %26 = OpLoad %ulong %8 + %27 = OpLoad %ulong %9 + %29 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %26 + OpStore %29 %27 + OpReturn + OpFunctionEnd + %1 = OpFunction %ulong None %51 + %34 = OpFunctionParameter %ulong + %43 = OpLabel + %32 = OpVariable %_ptr_Function_ulong Function + %31 = OpVariable %_ptr_Function_ulong Function + %33 = OpVariable %_ptr_Function_ulong Function + OpStore %32 %34 + %36 = OpLoad %ulong %32 + %35 = OpCopyObject %ulong %36 + OpStore %33 %35 + %38 = OpLoad %ulong %33 + %37 = OpIAdd %ulong %38 %ulong_1 + OpStore %33 %37 + %40 = OpLoad %ulong %33 + %39 = OpCopyObject %ulong %40 + OpStore %31 %39 + %41 = OpLoad %ulong %31 + OpReturnValue %41 + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt index afd2864..208c279 100644 --- a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt +++ b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt @@ -1,43 +1,51 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %1 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %5 "cvt_sat_s_u" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %4 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %5 = OpFunction %void None %4 - %6 = OpFunctionParameter %ulong - %7 = OpFunctionParameter %ulong - %23 = OpLabel - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_ulong Function - %10 = OpVariable %_ptr_Function_uint Function - %11 = OpVariable %_ptr_Function_uint Function - %12 = OpVariable %_ptr_Function_uint Function - OpStore %8 %6 - OpStore %9 %7 - %14 = OpLoad %ulong %8 - %21 = OpConvertUToPtr %_ptr_Generic_uint %14 - %13 = OpLoad %uint %21 - OpStore %10 %13 - %16 = OpLoad %uint %10 - %15 = OpSatConvertSToU %uint %16 - OpStore %11 %15 - %18 = OpLoad %uint %11 - %17 = OpBitcast %uint %18 - OpStore %12 %17 - %19 = OpLoad %ulong %9 - %20 = OpLoad %uint %12 - %22 = OpConvertUToPtr %_ptr_Generic_uint %19 - OpStore %22 %20 - OpReturn - OpFunctionEnd \ No newline at end of file + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %27 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "cvt_sat_s_u" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %30 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %1 = OpFunction %void None %30 + %9 = OpFunctionParameter %ulong + %10 = OpFunctionParameter %ulong + %25 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + %8 = OpVariable %_ptr_Function_uint Function + OpStore %2 %9 + OpStore %3 %10 + %12 = OpLoad %ulong %2 + %11 = OpCopyObject %ulong %12 + OpStore %4 %11 + %14 = OpLoad %ulong %3 + %13 = OpCopyObject %ulong %14 + OpStore %5 %13 + %16 = OpLoad %ulong %4 + %23 = OpConvertUToPtr %_ptr_Generic_uint %16 + %15 = OpLoad %uint %23 + OpStore %6 %15 + %18 = OpLoad %uint %6 + %17 = OpSatConvertSToU %uint %18 + OpStore %7 %17 + %20 = OpLoad %uint %7 + %19 = OpBitcast %uint %20 + OpStore %8 %19 + %21 = OpLoad %ulong %5 + %22 = OpLoad %uint %8 + %24 = OpConvertUToPtr %_ptr_Generic_uint %21 + OpStore %24 %22 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvta.spvtxt b/ptx/src/test/spirv_run/cvta.spvtxt index 1aa7425..e708613 100644 --- a/ptx/src/test/spirv_run/cvta.spvtxt +++ b/ptx/src/test/spirv_run/cvta.spvtxt @@ -1,42 +1,49 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %1 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %5 "cvta" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %4 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %float = OpTypeFloat 32 -%_ptr_Function_float = OpTypePointer Function %float -%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float - %5 = OpFunction %void None %4 - %6 = OpFunctionParameter %ulong - %7 = OpFunctionParameter %ulong - %21 = OpLabel - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_ulong Function - %10 = OpVariable %_ptr_Function_float Function - OpStore %8 %6 - OpStore %9 %7 - %12 = OpLoad %ulong %8 - %11 = OpCopyObject %ulong %12 - OpStore %8 %11 - %14 = OpLoad %ulong %9 - %13 = OpCopyObject %ulong %14 - OpStore %9 %13 - %16 = OpLoad %ulong %8 - %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %16 - %15 = OpLoad %float %19 - OpStore %10 %15 - %17 = OpLoad %ulong %9 - %18 = OpLoad %float %10 - %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %17 - OpStore %20 %18 - OpReturn - OpFunctionEnd - \ No newline at end of file + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %25 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "cvta" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %28 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %float = OpTypeFloat 32 +%_ptr_Function_float = OpTypePointer Function %float +%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float + %1 = OpFunction %void None %28 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %23 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_float Function + OpStore %2 %7 + OpStore %3 %8 + %10 = OpLoad %ulong %2 + %9 = OpCopyObject %ulong %10 + OpStore %4 %9 + %12 = OpLoad %ulong %3 + %11 = OpCopyObject %ulong %12 + OpStore %5 %11 + %14 = OpLoad %ulong %4 + %13 = OpCopyObject %ulong %14 + OpStore %4 %13 + %16 = OpLoad %ulong %5 + %15 = OpCopyObject %ulong %16 + OpStore %5 %15 + %18 = OpLoad %ulong %4 + %21 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18 + %17 = OpLoad %float %21 + OpStore %6 %17 + %19 = OpLoad %ulong %5 + %20 = OpLoad %float %6 + %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19 + OpStore %22 %20 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/ld_st.spvtxt b/ptx/src/test/spirv_run/ld_st.spvtxt index 1cb7094..d36db57 100644 --- a/ptx/src/test/spirv_run/ld_st.spvtxt +++ b/ptx/src/test/spirv_run/ld_st.spvtxt @@ -1,38 +1,41 @@ -; SPIR-V -; Version: 1.5 -; Generator: Khronos SPIR-V Tools Assembler; 0 -; Bound: 20 -; Schema: 0 - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %1 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %2 "ld_st" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %5 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %2 = OpFunction %void None %5 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %10 = OpLabel - %11 = OpVariable %_ptr_Function_ulong Function - %12 = OpVariable %_ptr_Function_ulong Function - %13 = OpVariable %_ptr_Function_ulong Function - OpStore %11 %8 - OpStore %12 %9 - %14 = OpLoad %ulong %11 - %15 = OpConvertUToPtr %_ptr_Generic_ulong %14 - %16 = OpLoad %ulong %15 - OpStore %13 %16 - %17 = OpLoad %ulong %12 - %18 = OpLoad %ulong %13 - %19 = OpConvertUToPtr %_ptr_Generic_ulong %17 - OpStore %19 %18 - OpReturn - OpFunctionEnd \ No newline at end of file + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %21 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "ld_st" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %24 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %1 = OpFunction %void None %24 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %19 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %7 + OpStore %3 %8 + %10 = OpLoad %ulong %2 + %9 = OpCopyObject %ulong %10 + OpStore %4 %9 + %12 = OpLoad %ulong %3 + %11 = OpCopyObject %ulong %12 + OpStore %5 %11 + %14 = OpLoad %ulong %4 + %17 = OpConvertUToPtr %_ptr_Generic_ulong %14 + %13 = OpLoad %ulong %17 + OpStore %6 %13 + %15 = OpLoad %ulong %5 + %16 = OpLoad %ulong %6 + %18 = OpConvertUToPtr %_ptr_Generic_ulong %15 + OpStore %18 %16 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/local_align.spvtxt b/ptx/src/test/spirv_run/local_align.spvtxt index beefb76..09a3f92 100644 --- a/ptx/src/test/spirv_run/local_align.spvtxt +++ b/ptx/src/test/spirv_run/local_align.spvtxt @@ -1,38 +1,46 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %1 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %5 "local_align" - OpDecorate %8 Alignment 8 - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %4 = OpTypeFunction %void %ulong %ulong - %uchar = OpTypeInt 8 0 -%_arr_uchar_8 = OpTypeArray %uchar %8 -%_ptr_Function__arr_uchar_8 = OpTypePointer Function %_arr_uchar_8 -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %5 = OpFunction %void None %4 - %6 = OpFunctionParameter %ulong - %7 = OpFunctionParameter %ulong - %18 = OpLabel - %8 = OpVariable %_ptr_Function__arr_uchar_8 Workgroup - %9 = OpVariable %_ptr_Function_ulong Function - %10 = OpVariable %_ptr_Function_ulong Function - %11 = OpVariable %_ptr_Function_ulong Function - OpStore %9 %6 - OpStore %10 %7 - %13 = OpLoad %ulong %9 - %16 = OpConvertUToPtr %_ptr_Generic_ulong %13 - %12 = OpLoad %ulong %16 - OpStore %11 %12 - %14 = OpLoad %ulong %10 - %15 = OpLoad %ulong %11 - %17 = OpConvertUToPtr %_ptr_Generic_ulong %14 - OpStore %17 %15 - OpReturn - OpFunctionEnd + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %22 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "local_align" + OpDecorate %4 Alignment 8 + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %25 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uchar = OpTypeInt 8 0 +%_arr_uchar_8 = OpTypeArray %uchar %8 +%_ptr_Function__arr_uchar_8 = OpTypePointer Function %_arr_uchar_8 +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %1 = OpFunction %void None %25 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %20 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function__arr_uchar_8 Workgroup + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %5 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %6 %12 + %15 = OpLoad %ulong %5 + %18 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %18 + OpStore %7 %14 + %16 = OpLoad %ulong %6 + %17 = OpLoad %ulong %7 + %19 = OpConvertUToPtr %_ptr_Generic_ulong %16 + OpStore %19 %17 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 9f62292..a72c453 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -14,7 +14,7 @@ use std::fmt; use std::fmt::{Debug, Display, Formatter}; use std::mem; use std::slice; -use std::{collections::HashMap, ptr, str}; +use std::{borrow::Cow, collections::HashMap, env, fs, path::PathBuf, ptr, str}; macro_rules! test_ptx { ($fn_name:ident, $input:expr, $output:expr) => { @@ -32,8 +32,9 @@ macro_rules! test_ptx { #[test] fn [<$fn_name _spvtxt>]() -> Result<(), Box> { let ptx_txt = include_str!(concat!(stringify!($fn_name), ".ptx")); + let spirv_file_name = concat!(stringify!($fn_name), ".spvtxt"); let spirv_txt = include_bytes!(concat!(stringify!($fn_name), ".spvtxt")); - test_spvtxt_assert(ptx_txt, spirv_txt) + test_spvtxt_assert(ptx_txt, spirv_txt, spirv_file_name) } } }; @@ -140,6 +141,7 @@ fn run_spirv + ze::SafeRepr + Copy + Debug>( fn test_spvtxt_assert<'a>( ptx_txt: &'a str, spirv_txt: &'a [u8], + spirv_file_name: &'a str, ) -> Result<(), Box> { let mut errors = Vec::new(); let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?; @@ -191,16 +193,27 @@ fn test_spvtxt_assert<'a>( ) }; unsafe { spirv_tools::spvContextDestroy(spv_context) }; - if result == spv_result_t::SPV_SUCCESS { + let spirv_text = if result == spv_result_t::SPV_SUCCESS { let raw_text = unsafe { std::slice::from_raw_parts((*spv_text).str_ as *const u8, (*spv_text).length) }; let spv_from_ptx_text = unsafe { str::from_utf8_unchecked(raw_text) }; // TODO: stop leaking kernel text - panic!(spv_from_ptx_text); + Cow::Borrowed(spv_from_ptx_text) } else { - panic!(ptx_mod.disassemble()); + Cow::Owned(ptx_mod.disassemble()) + }; + if let Ok(dump_path) = env::var("NOTCUDA_TEST_SPIRV_DUMP_DIR") { + let mut path = PathBuf::from(dump_path); + if let Ok(()) = fs::create_dir_all(&path) { + path.push(spirv_file_name); + #[allow(unused_must_use)] + { + fs::write(path, spirv_text.as_bytes()); + } + } } + panic!(spirv_text); } unsafe { spirv_tools::spvContextDestroy(spv_context) }; Ok(()) diff --git a/ptx/src/test/spirv_run/mov.spvtxt b/ptx/src/test/spirv_run/mov.spvtxt index 367a92a..d8a5029 100644 --- a/ptx/src/test/spirv_run/mov.spvtxt +++ b/ptx/src/test/spirv_run/mov.spvtxt @@ -1,43 +1,45 @@ -; SPIR-V -; Version: 1.5 -; Generator: Khronos SPIR-V Tools Assembler; 0 -; Bound: 23 -; Schema: 0 - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %1 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %2 "mov" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %5 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %2 = OpFunction %void None %5 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %10 = OpLabel - %11 = OpVariable %_ptr_Function_ulong Function - %12 = OpVariable %_ptr_Function_ulong Function - %13 = OpVariable %_ptr_Function_ulong Function - %14 = OpVariable %_ptr_Function_ulong Function - OpStore %11 %8 - OpStore %12 %9 - %15 = OpLoad %ulong %11 - %16 = OpConvertUToPtr %_ptr_Generic_ulong %15 - %17 = OpLoad %ulong %16 - OpStore %13 %17 - %18 = OpLoad %ulong %13 - %19 = OpCopyObject %ulong %18 - OpStore %14 %19 - %20 = OpLoad %ulong %12 - %21 = OpLoad %ulong %14 - %22 = OpConvertUToPtr %_ptr_Generic_ulong %20 - OpStore %22 %21 - OpReturn - OpFunctionEnd - \ No newline at end of file + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %24 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "mov" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %27 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %1 = OpFunction %void None %27 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %22 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %20 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %20 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %16 = OpCopyObject %ulong %17 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %21 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %21 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mul_hi.spvtxt b/ptx/src/test/spirv_run/mul_hi.spvtxt index d25dd8a..bea23a9 100644 --- a/ptx/src/test/spirv_run/mul_hi.spvtxt +++ b/ptx/src/test/spirv_run/mul_hi.spvtxt @@ -1,43 +1,46 @@ -; SPIR-V -; Version: 1.5 -; Generator: Khronos SPIR-V Tools Assembler; 0 -; Bound: 24 -; Schema: 0 - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %1 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %2 "mul_hi" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %5 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_2 = OpConstant %ulong 2 - %2 = OpFunction %void None %5 - %9 = OpFunctionParameter %ulong - %10 = OpFunctionParameter %ulong - %11 = OpLabel - %12 = OpVariable %_ptr_Function_ulong Function - %13 = OpVariable %_ptr_Function_ulong Function - %14 = OpVariable %_ptr_Function_ulong Function - %15 = OpVariable %_ptr_Function_ulong Function - OpStore %12 %9 - OpStore %13 %10 - %16 = OpLoad %ulong %12 - %17 = OpConvertUToPtr %_ptr_Generic_ulong %16 - %18 = OpLoad %ulong %17 - OpStore %14 %18 - %19 = OpLoad %ulong %14 - %20 = OpExtInst %ulong %1 u_mul_hi %19 %ulong_2 - OpStore %15 %20 - %21 = OpLoad %ulong %13 - %22 = OpLoad %ulong %15 - %23 = OpConvertUToPtr %_ptr_Generic_ulong %21 - OpStore %23 %22 - OpReturn - OpFunctionEnd + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %25 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "mul_hi" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %28 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_2 = OpConstant %ulong 2 + %1 = OpFunction %void None %28 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %23 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %21 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %16 = OpExtInst %ulong %25 u_mul_hi %17 %ulong_2 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %22 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mul_lo.spvtxt b/ptx/src/test/spirv_run/mul_lo.spvtxt index 4d7c2d8..e114374 100644 --- a/ptx/src/test/spirv_run/mul_lo.spvtxt +++ b/ptx/src/test/spirv_run/mul_lo.spvtxt @@ -1,38 +1,46 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %1 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %5 "mul_lo" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %4 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_2 = OpConstant %ulong 2 - %5 = OpFunction %void None %4 - %6 = OpFunctionParameter %ulong - %7 = OpFunctionParameter %ulong - %21 = OpLabel - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_ulong Function - %10 = OpVariable %_ptr_Function_ulong Function - %11 = OpVariable %_ptr_Function_ulong Function - OpStore %8 %6 - OpStore %9 %7 - %12 = OpLoad %ulong %8 - %19 = OpConvertUToPtr %_ptr_Generic_ulong %12 - %13 = OpLoad %ulong %19 - OpStore %10 %13 - %14 = OpLoad %ulong %10 - %15 = OpIMul %ulong %14 %ulong_2 - OpStore %11 %15 - %16 = OpLoad %ulong %9 - %17 = OpLoad %ulong %11 - %20 = OpConvertUToPtr %_ptr_Generic_ulong %16 - OpStore %20 %17 - OpReturn - OpFunctionEnd + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %25 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "mul_lo" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %28 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_2 = OpConstant %ulong 2 + %1 = OpFunction %void None %28 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %23 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %21 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %16 = OpIMul %ulong %17 %ulong_2 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %22 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/not.spvtxt b/ptx/src/test/spirv_run/not.spvtxt index 84482d9..de340ed 100644 --- a/ptx/src/test/spirv_run/not.spvtxt +++ b/ptx/src/test/spirv_run/not.spvtxt @@ -1,37 +1,45 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %1 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %5 "not" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %4 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %5 = OpFunction %void None %4 - %6 = OpFunctionParameter %ulong - %7 = OpFunctionParameter %ulong - %20 = OpLabel - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_ulong Function - %10 = OpVariable %_ptr_Function_ulong Function - %11 = OpVariable %_ptr_Function_ulong Function - OpStore %8 %6 - OpStore %9 %7 - %13 = OpLoad %ulong %8 - %18 = OpConvertUToPtr %_ptr_Generic_ulong %13 - %12 = OpLoad %ulong %18 - OpStore %10 %12 - %15 = OpLoad %ulong %10 - %14 = OpNot %ulong %15 - OpStore %11 %14 - %16 = OpLoad %ulong %9 - %17 = OpLoad %ulong %11 - %19 = OpConvertUToPtr %_ptr_Generic_ulong %16 - OpStore %19 %17 - OpReturn - OpFunctionEnd + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %24 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "not" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %27 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %1 = OpFunction %void None %27 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %22 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %20 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %20 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %16 = OpNot %ulong %17 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %21 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %21 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/setp.spvtxt b/ptx/src/test/spirv_run/setp.spvtxt index 064cd97..cb87f65 100644 --- a/ptx/src/test/spirv_run/setp.spvtxt +++ b/ptx/src/test/spirv_run/setp.spvtxt @@ -1,65 +1,73 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %1 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %5 "setp" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %4 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %bool = OpTypeBool -%_ptr_Function_bool = OpTypePointer Function %bool -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_8 = OpConstant %ulong 8 - %ulong_1 = OpConstant %ulong 1 - %ulong_2 = OpConstant %ulong 2 - %5 = OpFunction %void None %4 - %6 = OpFunctionParameter %ulong - %7 = OpFunctionParameter %ulong - %39 = OpLabel - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_ulong Function - %10 = OpVariable %_ptr_Function_ulong Function - %11 = OpVariable %_ptr_Function_ulong Function - %12 = OpVariable %_ptr_Function_ulong Function - %13 = OpVariable %_ptr_Function_bool Function - OpStore %8 %6 - OpStore %9 %7 - %19 = OpLoad %ulong %8 - %35 = OpConvertUToPtr %_ptr_Generic_ulong %19 - %18 = OpLoad %ulong %35 - OpStore %10 %18 - %21 = OpLoad %ulong %8 - %36 = OpCopyObject %ulong %21 - %32 = OpIAdd %ulong %36 %ulong_8 - %37 = OpConvertUToPtr %_ptr_Generic_ulong %32 - %20 = OpLoad %ulong %37 - OpStore %11 %20 - %23 = OpLoad %ulong %10 - %24 = OpLoad %ulong %11 - %22 = OpULessThan %bool %23 %24 - OpStore %13 %22 - %25 = OpLoad %bool %13 - OpBranchConditional %25 %14 %15 - %14 = OpLabel - %26 = OpCopyObject %ulong %ulong_1 - OpStore %12 %26 - OpBranch %15 - %15 = OpLabel - %27 = OpLoad %bool %13 - OpBranchConditional %27 %17 %16 - %16 = OpLabel - %28 = OpCopyObject %ulong %ulong_2 - OpStore %12 %28 - OpBranch %17 - %17 = OpLabel - %29 = OpLoad %ulong %9 - %30 = OpLoad %ulong %12 - %38 = OpConvertUToPtr %_ptr_Generic_ulong %29 - OpStore %38 %30 - OpReturn - OpFunctionEnd + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %43 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "setp" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %46 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %bool = OpTypeBool +%_ptr_Function_bool = OpTypePointer Function %bool +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_8 = OpConstant %ulong 8 + %ulong_1 = OpConstant %ulong 1 + %ulong_2 = OpConstant %ulong 2 + %1 = OpFunction %void None %46 + %14 = OpFunctionParameter %ulong + %15 = OpFunctionParameter %ulong + %41 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + %8 = OpVariable %_ptr_Function_ulong Function + %9 = OpVariable %_ptr_Function_bool Function + OpStore %2 %14 + OpStore %3 %15 + %17 = OpLoad %ulong %2 + %16 = OpCopyObject %ulong %17 + OpStore %4 %16 + %19 = OpLoad %ulong %3 + %18 = OpCopyObject %ulong %19 + OpStore %5 %18 + %21 = OpLoad %ulong %4 + %37 = OpConvertUToPtr %_ptr_Generic_ulong %21 + %20 = OpLoad %ulong %37 + OpStore %6 %20 + %23 = OpLoad %ulong %4 + %38 = OpCopyObject %ulong %23 + %34 = OpIAdd %ulong %38 %ulong_8 + %39 = OpConvertUToPtr %_ptr_Generic_ulong %34 + %22 = OpLoad %ulong %39 + OpStore %7 %22 + %25 = OpLoad %ulong %6 + %26 = OpLoad %ulong %7 + %24 = OpULessThan %bool %25 %26 + OpStore %9 %24 + %27 = OpLoad %bool %9 + OpBranchConditional %27 %10 %11 + %10 = OpLabel + %28 = OpCopyObject %ulong %ulong_1 + OpStore %8 %28 + OpBranch %11 + %11 = OpLabel + %29 = OpLoad %bool %9 + OpBranchConditional %29 %13 %12 + %12 = OpLabel + %30 = OpCopyObject %ulong %ulong_2 + OpStore %8 %30 + OpBranch %13 + %13 = OpLabel + %31 = OpLoad %ulong %5 + %32 = OpLoad %ulong %8 + %40 = OpConvertUToPtr %_ptr_Generic_ulong %31 + OpStore %40 %32 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shl.spvtxt b/ptx/src/test/spirv_run/shl.spvtxt index 3e57fc3..dbd2664 100644 --- a/ptx/src/test/spirv_run/shl.spvtxt +++ b/ptx/src/test/spirv_run/shl.spvtxt @@ -1,39 +1,47 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %1 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %5 "shl" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %4 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %uint = OpTypeInt 32 0 - %uint_2 = OpConstant %uint 2 - %5 = OpFunction %void None %4 - %6 = OpFunctionParameter %ulong - %7 = OpFunctionParameter %ulong - %21 = OpLabel - %8 = OpVariable %_ptr_Function_ulong Function - %9 = OpVariable %_ptr_Function_ulong Function - %10 = OpVariable %_ptr_Function_ulong Function - %11 = OpVariable %_ptr_Function_ulong Function - OpStore %8 %6 - OpStore %9 %7 - %13 = OpLoad %ulong %8 - %19 = OpConvertUToPtr %_ptr_Generic_ulong %13 - %12 = OpLoad %ulong %19 - OpStore %10 %12 - %15 = OpLoad %ulong %10 - %14 = OpShiftLeftLogical %ulong %15 %uint_2 - OpStore %11 %14 - %16 = OpLoad %ulong %9 - %17 = OpLoad %ulong %11 - %20 = OpConvertUToPtr %_ptr_Generic_ulong %16 - OpStore %20 %17 - OpReturn - OpFunctionEnd + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %25 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "shl" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %28 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 + %1 = OpFunction %void None %28 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %23 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %21 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %16 = OpShiftLeftLogical %ulong %17 %uint_2 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %22 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 34d8c12..bd37b14 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,6 +1,6 @@ use crate::ast; use rspirv::dr; -use std::collections::{HashMap, HashSet}; +use std::collections::{hash_map, HashMap, HashSet}; use std::{borrow::Cow, iter, mem}; use rspirv::binary::Assemble; @@ -10,6 +10,7 @@ enum SpirvType { Base(SpirvScalarKey), Array(SpirvScalarKey, u32), Pointer(Box, spirv::StorageClass), + Func(Option>, Vec), } impl SpirvType { @@ -141,16 +142,44 @@ impl TypeWordMap { .entry(t) .or_insert_with(|| b.type_array(base, len)) } + SpirvType::Func(ref out_params, ref in_params) => { + let out_t = match out_params { + Some(p) => self.get_or_add(b, *p.clone()), + None => self.void(), + }; + let in_t = in_params + .iter() + .map(|t| self.get_or_add(b, t.clone())) + .collect::>(); + *self + .complex + .entry(t) + .or_insert_with(|| b.type_function(out_t, in_t)) + } } } - fn get_or_add_fn>( + fn get_or_add_fn( &mut self, b: &mut dr::Builder, - args: Args, - ) -> spirv::Word { - let params = args.map(|a| self.get_or_add(b, a)).collect::>(); - b.type_function(self.void(), params) + mut out_params: impl ExactSizeIterator, + in_params: impl ExactSizeIterator, + ) -> (spirv::Word, spirv::Word) { + let (out_args, out_spirv_type) = if out_params.len() == 0 { + (None, self.void()) + } else if out_params.len() == 1 { + let arg_as_key = out_params.next().unwrap(); + ( + Some(Box::new(arg_as_key.clone())), + self.get_or_add(b, arg_as_key), + ) + } else { + todo!() + }; + ( + out_spirv_type, + self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::>())), + ) } } @@ -171,29 +200,31 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result f, + None => continue, + }; + emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive)?; + emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?; builder.end_function()?; } Ok(builder.module()) } -fn emit_function_header( +fn emit_function_header<'a>( builder: &mut dr::Builder, map: &mut TypeWordMap, - global: &GlobalStringIdResolver, - func_directive: ast::FunctionHeader, - params: &[ast::Argument], + global: &GlobalStringIdResolver<'a>, + func_directive: ast::MethodDecl, ) -> Result<(), dr::Error> { - let func_type = get_function_type(builder, map, params); - let (fn_id, ret_type) = match func_directive { - ast::FunctionHeader::Kernel(name) => { + let (ret_type, func_type) = get_function_type(builder, map, &func_directive); + let fn_id = match func_directive { + ast::MethodDecl::Kernel(name, _) => { let fn_id = global.get_id(name); builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, &[]); - (fn_id, map.void()) + fn_id } - ast::FunctionHeader::Func(params, name) => todo!(), + ast::MethodDecl::Func(_, name, _) => name, }; builder.begin_function( ret_type, @@ -201,6 +232,16 @@ fn emit_function_header( spirv::FunctionControl::NONE, func_type, )?; + func_directive.visit_args(|arg| { + let result_type = map.get_or_add_scalar(builder, arg.a_type); + let inst = dr::Instruction::new( + spirv::Op::FunctionParameter, + Some(result_type), + Some(arg.name), + Vec::new(), + ); + builder.function.as_mut().unwrap().parameters.push(inst); + }); Ok(()) } @@ -235,50 +276,116 @@ fn to_ssa_function<'a>( id_defs: &mut GlobalStringIdResolver<'a>, f: ast::ParsedFunction<'a>, ) -> ExpandedFunction<'a> { - let mut fn_resolver = FnStringIdResolver::new(id_defs, f.func_directive.name()); - let f_header = match f.func_directive { - ast::FunctionHeader::Kernel(name) => ast::FunctionHeader::Kernel(name), - ast::FunctionHeader::Func(ret_params, name) => { - let name_id = fn_resolver.add_global_def(name); - let ret_ids = expand_fn_params(&mut fn_resolver, ret_params); - ast::FunctionHeader::Func(ret_ids, name_id) - } - }; - let f_args = expand_fn_params(&mut fn_resolver, f.args); - let f_body = Some(to_ssa(fn_resolver, f.body.unwrap_or_else(|| Vec::new()))); - ExpandedFunction { - func_directive: f_header, - args: f_args, - body: f_body, - } + let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive); + to_ssa(str_resolver, fn_resolver, fn_decl, f.body) } -fn expand_fn_params<'a, 'b>( +fn expand_kernel_params<'a, 'b>( fn_resolver: &mut FnStringIdResolver<'a, 'b>, - args: Vec>>, -) -> Vec> { - args.into_iter() - .map(|a| ast::Argument { - name: fn_resolver.add_def(a.name, Some(ast::Type::Scalar(a.a_type))), - a_type: a.a_type, - length: a.length, - }) - .collect() + args: impl Iterator>>, +) -> Vec> { + args.map(|a| ast::KernelArgument { + name: fn_resolver.add_def(a.name, Some(ast::Type::Scalar(a.a_type))), + a_type: a.a_type, + length: a.length, + }) + .collect() } -fn to_ssa<'a, 'b>( - mut id_defs: FnStringIdResolver<'a, 'b>, - f_body: Vec>>, -) -> Vec { - let normalized_ids = normalize_identifiers(&mut id_defs, f_body); +fn expand_fn_params<'a, 'b>( + fn_resolver: &mut FnStringIdResolver<'a, 'b>, + args: impl Iterator>>, +) -> Vec> { + args.map(|a| ast::FnArgument { + state_space: a.state_space, + base: ast::KernelArgument { + name: fn_resolver.add_def(a.base.name, Some(ast::Type::Scalar(a.base.a_type))), + a_type: a.base.a_type, + length: a.base.length, + }, + }) + .collect() +} + +fn to_ssa<'input, 'b>( + mut id_defs: FnStringIdResolver<'input, 'b>, + fn_defs: GlobalFnDeclResolver<'input, 'b>, + f_args: ast::MethodDecl<'input, ExpandedArgParams>, + f_body: Option>>>, +) -> ExpandedFunction<'input> { + let f_body = match f_body { + Some(vec) => vec, + None => { + return ExpandedFunction { + func_directive: f_args, + body: None, + } + } + }; + let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body); let mut numeric_id_defs = id_defs.finish(); - let normalized_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs); - let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut numeric_id_defs); + let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs); + let unadorned_statements = resolve_fn_calls(&fn_defs, unadorned_statements); + let (f_args, ssa_statements) = + insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs); let expanded_statements = insert_implicit_conversions(expanded_statements, &mut numeric_id_defs); let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); - labeled_statements + let sorted_statements = normalize_variable_decls(labeled_statements); + ExpandedFunction { + func_directive: f_args, + body: Some(sorted_statements), + } +} + +fn normalize_variable_decls(mut func: Vec) -> Vec { + func[1..].sort_by_key(|s| match s { + Statement::Variable(_) => 0, + _ => 1, + }); + func +} + +fn resolve_fn_calls( + fn_defs: &GlobalFnDeclResolver, + func: Vec, +) -> Vec { + func.into_iter() + .map(|s| { + match s { + Statement::Instruction(ast::Instruction::Call(call)) => { + // TODO: error out if lengths don't match + let fn_def = fn_defs.get_fn_decl(call.func); + let ret_params = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals); + let param_list = to_resolved_fn_args(call.param_list, &*fn_def.params); + let resolved_call = ResolvedCall { + uniform: call.uniform, + ret_params, + func: call.func, + param_list, + }; + Statement::Call(resolved_call) + } + s => s, + } + }) + .collect() +} + +fn to_resolved_fn_args( + params: Vec, + params_decl: &[(ast::FnArgStateSpace, ast::ScalarType)], +) -> Vec> { + params + .into_iter() + .zip(params_decl.iter()) + .map(|(id, &(space, typ))| ArgCall { + id, + typ: ast::Type::Scalar(typ), + space: space, + }) + .collect::>() } fn normalize_labels( @@ -297,9 +404,11 @@ fn normalize_labels( labels_in_use.insert(cond.if_true); labels_in_use.insert(cond.if_false); } - Statement::Variable(_, _, _, _) + Statement::Call(_) + | Statement::Variable(_) | Statement::LoadVar(_, _) | Statement::StoreVar(_, _) + | Statement::RetValue(_, _) | Statement::Conversion(_) | Statement::Constant(_) | Statement::Label(_) => (), @@ -314,14 +423,14 @@ fn normalize_labels( } fn normalize_predicates( - func: Vec>, + func: Vec, id_def: &mut NumericIdResolver, -) -> Vec { +) -> Vec { let mut result = Vec::with_capacity(func.len()); for s in func { match s { - ast::Statement::Label(id) => result.push(Statement::Label(id)), - ast::Statement::Instruction(pred, inst) => { + Statement::Label(id) => result.push(Statement::Label(id)), + Statement::Instruction((pred, inst)) => { if let Some(pred) = pred { let if_true = id_def.new_id(None); let if_false = id_def.new_id(None); @@ -347,66 +456,100 @@ fn normalize_predicates( result.push(Statement::Instruction(inst)); } } - ast::Statement::Variable(var) => result.push(Statement::Variable( - var.name, var.v_type, var.space, var.align, - )), + Statement::Variable(var) => result.push(Statement::Variable(var)), // Blocks are flattened when resolving ids - ast::Statement::Block(_) => unreachable!(), + _ => unreachable!(), } } result } -fn insert_mem_ssa_statements( - func: Vec, +fn insert_mem_ssa_statements<'a, 'b>( + func: Vec, id_def: &mut NumericIdResolver, -) -> Vec { + mut f_args: ast::MethodDecl<'a, ExpandedArgParams>, +) -> ( + ast::MethodDecl<'a, ExpandedArgParams>, + Vec, +) { let mut result = Vec::with_capacity(func.len()); + let out_param = match &mut f_args { + ast::MethodDecl::Kernel(_, in_params) => { + for p in in_params.iter_mut() { + let typ = ast::Type::Scalar(p.a_type); + let new_id = id_def.new_id(Some(typ)); + result.push(Statement::Variable(VariableDecl { + space: ast::StateSpace::Reg, + align: None, + v_type: typ, + name: p.name, + })); + result.push(Statement::StoreVar( + ast::Arg2St { + src1: p.name, + src2: new_id, + }, + typ, + )); + p.name = new_id; + } + None + } + ast::MethodDecl::Func(out_params, _, in_params) => { + for p in in_params.iter_mut() { + let typ = ast::Type::Scalar(p.base.a_type); + let new_id = id_def.new_id(Some(typ)); + result.push(Statement::Variable(VariableDecl { + space: ast::StateSpace::Reg, + align: None, + v_type: typ, + name: p.base.name, + })); + result.push(Statement::StoreVar( + ast::Arg2St { + src1: p.base.name, + src2: new_id, + }, + typ, + )); + p.base.name = new_id; + } + match &mut **out_params { + [p] => { + result.push(Statement::Variable(VariableDecl { + space: ast::StateSpace::Reg, + align: None, + v_type: ast::Type::Scalar(p.base.a_type), + name: p.base.name, + })); + Some(p.base.name) + } + [] => None, + _ => todo!(), + } + } + }; for s in func { match s { + Statement::Call(call) => insert_mem_ssa_statement_default(id_def, &mut result, call), Statement::Instruction(inst) => match inst { - ast::Instruction::Ld( - ld - @ - ast::LdData { - state_space: ast::LdStateSpace::Param, - .. - }, - arg, - ) => { - result.push(Statement::Instruction(ast::Instruction::Ld(ld, arg))); - } - inst => { - let mut post_statements = Vec::new(); - let inst = inst.visit_variable(&mut |desc| { - let id_type = match (desc.typ, desc.is_pointer) { - (Some(t), false) => t, - (Some(_), true) => ast::Type::Scalar(ast::ScalarType::B64), - (None, _) => return desc.op, - }; - let generated_id = id_def.new_id(Some(id_type)); - if !desc.is_dst { - result.push(Statement::LoadVar( - Arg2 { - dst: generated_id, - src: desc.op, - }, - id_type, - )); - } else { - post_statements.push(Statement::StoreVar( - Arg2St { - src1: desc.op, - src2: generated_id, - }, - id_type, - )); - } - generated_id - }); - result.push(Statement::Instruction(inst)); - result.append(&mut post_statements); + ast::Instruction::Ret(d) => { + if let Some(out_param) = out_param { + let typ = id_def.get_type(out_param); + let new_id = id_def.new_id(Some(typ)); + result.push(Statement::LoadVar( + ast::Arg2 { + dst: new_id, + src: out_param, + }, + typ, + )); + result.push(Statement::RetValue(d, new_id)); + } else { + result.push(Statement::Instruction(ast::Instruction::Ret(d))) + } } + inst => insert_mem_ssa_statement_default(id_def, &mut result, inst), }, Statement::Conditional(mut bra) => { let generated_id = id_def.new_id(Some(ast::Type::ExtendedScalar( @@ -422,65 +565,116 @@ fn insert_mem_ssa_statements( bra.predicate = generated_id; result.push(Statement::Conditional(bra)); } - s @ Statement::Variable(_, _, _, _) | s @ Statement::Label(_) => result.push(s), + s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s), Statement::LoadVar(_, _) | Statement::StoreVar(_, _) | Statement::Conversion(_) + | Statement::RetValue(_, _) | Statement::Constant(_) => unreachable!(), } } - result + (f_args, result) } -fn expand_arguments<'a, 'b, 'c>( - func: Vec, - id_def: &'c mut NumericIdResolver<'a, 'b>, +trait VisitVariable: Sized { + fn visit_variable<'a, F: FnMut(ArgumentDescriptor) -> spirv::Word>( + self, + f: &mut F, + ) -> UnadornedStatement; +} +trait VisitVariableExpanded { + fn visit_variable_extended) -> spirv::Word>( + self, + f: &mut F, + ) -> ExpandedStatement; +} + +fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( + id_def: &mut NumericIdResolver, + result: &mut Vec, + stmt: F, +) { + let mut post_statements = Vec::new(); + let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor| { + let id_type = match (desc.typ, desc.is_pointer) { + (Some(t), false) => t, + (Some(_), true) => ast::Type::Scalar(ast::ScalarType::B64), + (None, _) => return desc.op, + }; + let generated_id = id_def.new_id(Some(id_type)); + if !desc.is_dst { + result.push(Statement::LoadVar( + Arg2 { + dst: generated_id, + src: desc.op, + }, + id_type, + )); + } else { + post_statements.push(Statement::StoreVar( + Arg2St { + src1: desc.op, + src2: generated_id, + }, + id_type, + )); + } + generated_id + }); + result.push(new_statement); + result.append(&mut post_statements); +} + +fn expand_arguments<'a, 'b>( + func: Vec, + id_def: &'b mut NumericIdResolver<'a>, ) -> Vec { let mut result = Vec::with_capacity(func.len()); for s in func { match s { + Statement::Call(call) => { + let mut visitor = FlattenArguments::new(&mut result, id_def); + let new_call = call.map(&mut visitor); + result.push(Statement::Call(new_call)); + } Statement::Instruction(inst) => { let mut visitor = FlattenArguments::new(&mut result, id_def); let new_inst = inst.map(&mut visitor); result.push(Statement::Instruction(new_inst)); } - Statement::Variable(id, typ, ss, align) => { - result.push(Statement::Variable(id, typ, ss, align)) - } + Statement::Variable(v_decl) => result.push(Statement::Variable(v_decl)), Statement::Label(id) => result.push(Statement::Label(id)), Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)), Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)), + Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), Statement::Conversion(_) | Statement::Constant(_) => unreachable!(), } } result } -struct FlattenArguments<'a, 'b, 'c> { - func: &'c mut Vec, - id_def: &'c mut NumericIdResolver<'a, 'b>, +struct FlattenArguments<'a, 'b> { + func: &'b mut Vec, + id_def: &'b mut NumericIdResolver<'a>, } -impl<'a, 'b, 'c> FlattenArguments<'a, 'b, 'c> { - fn new( - func: &'c mut Vec, - id_def: &'c mut NumericIdResolver<'a, 'b>, - ) -> Self { +impl<'a, 'b> FlattenArguments<'a, 'b> { + fn new(func: &'b mut Vec, id_def: &'b mut NumericIdResolver<'a>) -> Self { FlattenArguments { func, id_def } } } -impl<'a, 'b, 'c> ArgumentMapVisitor - for FlattenArguments<'a, 'b, 'c> +impl<'a, 'b> ArgumentMapVisitor + for FlattenArguments<'a, 'b> { - fn dst_variable(&mut self, desc: ArgumentDescriptor) -> spirv::Word { + fn variable(&mut self, desc: ArgumentDescriptor) -> spirv::Word { desc.op } - fn src_operand(&mut self, desc: ArgumentDescriptor>) -> spirv::Word { + fn operand(&mut self, desc: ArgumentDescriptor>) -> spirv::Word { match desc.op { - ast::Operand::Reg(r) => r, + ast::Operand::Reg(r) => self.variable(desc.new_op(r)), ast::Operand::Imm(x) => { if let Some(typ) = desc.typ { let scalar_t = if let ast::Type::Scalar(scalar) = typ { @@ -535,12 +729,22 @@ impl<'a, 'b, 'c> ArgumentMapVisitor } } + fn src_call_operand( + &mut self, + desc: ArgumentDescriptor>, + ) -> spirv::Word { + match desc.op { + ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg)), + ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x))), + } + } + fn src_mov_operand( &mut self, desc: ArgumentDescriptor>, ) -> spirv::Word { match &desc.op { - ast::MovOperand::Op(opr) => self.src_operand(desc.new_op(*opr)), + ast::MovOperand::Op(opr) => self.operand(desc.new_op(*opr)), ast::MovOperand::Vec(_, _) => todo!(), } } @@ -553,11 +757,12 @@ impl<'a, 'b, 'c> ArgumentMapVisitor - ld.param: not documented, but for instruction `ld.param. x, [y]`, semantics are to first zext/chop/bitcast `y` as needed and then do documented special ld/st/cvt conversion rules for destination operands - - generic ld: for instruction `ld x, [y]`, y must be of type b64/u64/s64, - which is bitcast to a pointer, dereferenced and then documented special - ld/st/cvt conversion rules are applied to dst - - generic st: for instruction `st [x], y`, x must be of type b64/u64/s64, - which is bitcast to a pointer + - st.param [x] y (used as function return arguments) same rule as above applies + - generic/global ld: for instruction `ld x, [y]`, y must be of type + b64/u64/s64, which is bitcast to a pointer, dereferenced and then + documented special ld/st/cvt conversion rules are applied to dst + - generic/global st: for instruction `st [x], y`, x must be of type + b64/u64/s64, which is bitcast to a pointer */ fn insert_implicit_conversions( func: Vec, @@ -566,6 +771,7 @@ fn insert_implicit_conversions( let mut result = Vec::with_capacity(func.len()); for s in func.into_iter() { match s { + Statement::Call(call) => insert_implicit_bitcasts(&mut result, id_def, call), Statement::Instruction(inst) => match inst { ast::Instruction::Ld(ld, mut arg) => { arg.src = insert_implicit_conversions_ld_src( @@ -611,9 +817,10 @@ fn insert_implicit_conversions( s @ Statement::Conditional(_) | s @ Statement::Label(_) | s @ Statement::Constant(_) - | s @ Statement::Variable(_, _, _, _) + | s @ Statement::Variable(_) | s @ Statement::LoadVar(_, _) - | s @ Statement::StoreVar(_, _) => result.push(s), + | s @ Statement::StoreVar(_, _) + | s @ Statement::RetValue(_, _) => result.push(s), Statement::Conversion(_) => unreachable!(), } } @@ -623,25 +830,19 @@ fn insert_implicit_conversions( fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, - args: &[ast::Argument], -) -> spirv::Word { - map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::from(arg.a_type))) -} - -fn emit_function_args( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - args: &[ast::Argument], -) { - for arg in args { - let result_type = map.get_or_add_scalar(builder, arg.a_type); - let inst = dr::Instruction::new( - spirv::Op::FunctionParameter, - Some(result_type), - Some(arg.name), - Vec::new(), - ); - builder.function.as_mut().unwrap().parameters.push(inst); + method_decl: &ast::MethodDecl, +) -> (spirv::Word, spirv::Word) { + match method_decl { + ast::MethodDecl::Func(out_params, _, in_params) => map.get_or_add_fn( + builder, + out_params.iter().map(|p| SpirvType::from(p.base.a_type)), + in_params.iter().map(|p| SpirvType::from(p.base.a_type)), + ), + ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn( + builder, + iter::empty(), + params.iter().map(|p| SpirvType::from(p.a_type)), + ), } } @@ -649,9 +850,9 @@ fn emit_function_body_ops( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, - func: &Option>, + func: &[ExpandedStatement], ) -> Result<(), dr::Error> { - for s in func.as_ref().unwrap() { + for s in func { match s { Statement::Label(id) => { if builder.block.is_some() { @@ -667,13 +868,26 @@ fn emit_function_body_ops( } match s { Statement::Label(_) => (), - Statement::Variable(id, typ, ss, align) => { + Statement::Call(call) => { + let (result_type, result_id) = match &*call.ret_params { + [p] => (map.get_or_add(builder, SpirvType::from(p.typ)), p.id), + _ => todo!(), + }; + let arg_list = call.param_list.iter().map(|p| p.id).collect::>(); + builder.function_call(result_type, Some(result_id), call.func, arg_list)?; + } + Statement::Variable(VariableDecl { + name: id, + v_type: typ, + space: ss, + align, + }) => { let type_id = map.get_or_add( builder, SpirvType::new_pointer(*typ, spirv::StorageClass::Function), ); let st_class = match ss { - ast::StateSpace::Reg => spirv::StorageClass::Function, + ast::StateSpace::Reg | ast::StateSpace::Param => spirv::StorageClass::Function, ast::StateSpace::Local => spirv::StorageClass::Workgroup, _ => todo!(), }; @@ -722,7 +936,7 @@ fn emit_function_body_ops( } Statement::Instruction(inst) => match inst { ast::Instruction::Abs(_, _) => todo!(), - ast::Instruction::Call(_, _) => todo!(), + ast::Instruction::Call(_) => unreachable!(), // SPIR-V does not support marking jumps as guaranteed-converged ast::Instruction::Bra(_, arg) => { builder.branch(arg.src)?; @@ -737,7 +951,8 @@ fn emit_function_body_ops( builder.load(result_type, Some(arg.dst), arg.src, None, [])?; } ast::LdStateSpace::Param => { - builder.store(arg.dst, arg.src, None, [])?; + let result_type = map.get_or_add(builder, SpirvType::from(data.typ)); + builder.copy_object(result_type, Some(arg.dst), arg.src)?; } _ => todo!(), } @@ -746,11 +961,17 @@ fn emit_function_body_ops( if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() || (data.state_space != ast::StStateSpace::Generic + && data.state_space != ast::StStateSpace::Param && data.state_space != ast::StStateSpace::Global) { todo!() } - builder.store(arg.src1, arg.src2, None, &[])?; + if data.state_space == ast::StStateSpace::Param { + let result_type = map.get_or_add(builder, SpirvType::from(data.typ)); + builder.copy_object(result_type, Some(arg.src1), arg.src2)?; + } else { + builder.store(arg.src1, arg.src2, None, &[])?; + } } // SPIR-V does not support ret as guaranteed-converged ast::Instruction::Ret(_) => builder.ret()?, @@ -808,6 +1029,9 @@ fn emit_function_body_ops( Statement::StoreVar(arg, _) => { builder.store(arg.src1, arg.src2, None, [])?; } + Statement::RetValue(_, id) => { + builder.ret_value(*id)?; + } } } Ok(()) @@ -1123,8 +1347,9 @@ fn emit_implicit_conversion( // TODO: support scopes fn normalize_identifiers<'a, 'b>( id_defs: &mut FnStringIdResolver<'a, 'b>, + fn_defs: &GlobalFnDeclResolver<'a, 'b>, func: Vec>>, -) -> Vec> { +) -> Vec { for s in func.iter() { match s { ast::Statement::Label(id) => { @@ -1135,58 +1360,63 @@ fn normalize_identifiers<'a, 'b>( } let mut result = Vec::new(); for s in func { - expand_map_variables(id_defs, &mut result, s); + expand_map_variables(id_defs, fn_defs, &mut result, s); } result } fn expand_map_variables<'a, 'b>( id_defs: &mut FnStringIdResolver<'a, 'b>, - result: &mut Vec>, + fn_defs: &GlobalFnDeclResolver, + result: &mut Vec, s: ast::Statement>, ) { match s { ast::Statement::Block(block) => { id_defs.start_block(); for s in block { - expand_map_variables(id_defs, result, s); + expand_map_variables(id_defs, fn_defs, result, s); } id_defs.end_block(); } - ast::Statement::Label(name) => result.push(ast::Statement::Label(id_defs.get_id(name))), - ast::Statement::Instruction(p, i) => result.push(ast::Statement::Instruction( + ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name))), + ast::Statement::Instruction(p, i) => result.push(Statement::Instruction(( p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))), i.map_variable(&mut |id| id_defs.get_id(id)), - )), + ))), ast::Statement::Variable(var) => match var.count { Some(count) => { for new_id in id_defs.add_defs(var.name, count, var.v_type) { - result.push(ast::Statement::Variable(ast::Variable { + result.push(Statement::Variable(VariableDecl { space: var.space, align: var.align, v_type: var.v_type, name: new_id, - count: None, })) } } None => { let new_id = id_defs.add_def(var.name, Some(var.v_type)); - result.push(ast::Statement::Variable(ast::Variable { + result.push(Statement::Variable(VariableDecl { space: var.space, align: var.align, v_type: var.v_type, name: new_id, - count: None, })); } }, } } -struct GlobalStringIdResolver<'a> { +struct GlobalStringIdResolver<'input> { current_id: spirv::Word, - variables: HashMap, spirv::Word>, + variables: HashMap, spirv::Word>, + fns: HashMap, +} + +pub struct FnDecl { + ret_vals: Vec<(ast::FnArgStateSpace, ast::ScalarType)>, + params: Vec<(ast::FnArgStateSpace, ast::ScalarType)>, } impl<'a> GlobalStringIdResolver<'a> { @@ -1194,14 +1424,20 @@ impl<'a> GlobalStringIdResolver<'a> { Self { current_id: start_id, variables: HashMap::new(), + fns: HashMap::new(), } } - fn add_def(&mut self, id: &'a str) -> spirv::Word { - let numeric_id = self.current_id; - self.variables.insert(Cow::Borrowed(id), numeric_id); - self.current_id += 1; - numeric_id + fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word { + match self.variables.entry(Cow::Borrowed(id)) { + hash_map::Entry::Occupied(e) => *(e.get()), + hash_map::Entry::Vacant(e) => { + let numeric_id = self.current_id; + e.insert(numeric_id); + self.current_id += 1; + numeric_id + } + } } fn get_id(&self, id: &str) -> spirv::Word { @@ -1211,27 +1447,84 @@ impl<'a> GlobalStringIdResolver<'a> { fn current_id(&self) -> spirv::Word { self.current_id } + + fn start_fn<'b>( + &'b mut self, + header: &'b ast::MethodDecl<'a, ast::ParsedArgParams<'a>>, + ) -> ( + FnStringIdResolver<'a, 'b>, + GlobalFnDeclResolver<'a, 'b>, + ast::MethodDecl<'a, ExpandedArgParams>, + ) { + // In case a function decl was inserted eearlier we want to use its id + let name_id = self.get_or_add_def(header.name()); + let mut fn_resolver = FnStringIdResolver { + current_id: &mut self.current_id, + global_variables: &self.variables, + variables: vec![HashMap::new(); 1], + type_check: HashMap::new(), + }; + let new_fn_decl = match header { + ast::MethodDecl::Kernel(name, params) => { + ast::MethodDecl::Kernel(name, expand_kernel_params(&mut fn_resolver, params.iter())) + } + ast::MethodDecl::Func(ret_params, _, params) => { + let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter()); + let params_ids = expand_fn_params(&mut fn_resolver, params.iter()); + self.fns.insert( + name_id, + FnDecl { + ret_vals: ret_params_ids + .iter() + .map(|p| (p.state_space, p.base.a_type)) + .collect(), + params: params_ids + .iter() + .map(|p| (p.state_space, p.base.a_type)) + .collect(), + }, + ); + ast::MethodDecl::Func(ret_params_ids, name_id, params_ids) + } + }; + ( + fn_resolver, + GlobalFnDeclResolver { + variables: &self.variables, + fns: &self.fns, + }, + new_fn_decl, + ) + } } -struct FnStringIdResolver<'a, 'b> { - global: &'b mut GlobalStringIdResolver<'a>, - variables: Vec, spirv::Word>>, - type_check: HashMap, +pub struct GlobalFnDeclResolver<'input, 'a> { + variables: &'a HashMap, spirv::Word>, + fns: &'a HashMap, } -impl<'a, 'b> FnStringIdResolver<'a, 'b> { - fn new(global: &'b mut GlobalStringIdResolver<'a>, f_name: &'a str) -> Self { - global.add_def(f_name); - Self { - global: global, - variables: vec![HashMap::new(); 1], - type_check: HashMap::new(), - } +impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { + fn get_fn_decl(&self, id: spirv::Word) -> &FnDecl { + &self.fns[&id] + } + + fn get_fn_decl_str(&self, id: &str) -> &'a FnDecl { + &self.fns[&self.variables[id]] } +} + +struct FnStringIdResolver<'input, 'b> { + current_id: &'b mut spirv::Word, + global_variables: &'b HashMap, spirv::Word>, + //global: &'b mut GlobalStringIdResolver<'a>, + variables: Vec, spirv::Word>>, + type_check: HashMap, +} - fn finish(self) -> NumericIdResolver<'a, 'b> { +impl<'a, 'b> FnStringIdResolver<'a, 'b> { + fn finish(self) -> NumericIdResolver<'b> { NumericIdResolver { - global: self.global, + current_id: self.current_id, type_check: self.type_check, } } @@ -1251,15 +1544,11 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { None => continue, } } - self.global.variables[id] - } - - fn add_global_def(&mut self, id: &'a str) -> spirv::Word { - self.global.add_def(id) + self.global_variables[id] } fn add_def(&mut self, id: &'a str, typ: Option) -> spirv::Word { - let numeric_id = self.global.current_id; + let numeric_id = *self.current_id; self.variables .last_mut() .unwrap() @@ -1267,7 +1556,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { if let Some(typ) = typ { self.type_check.insert(numeric_id, typ); } - self.global.current_id += 1; + *self.current_id += 1; numeric_id } @@ -1278,7 +1567,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { count: u32, typ: ast::Type, ) -> impl Iterator { - let numeric_id = self.global.current_id; + let numeric_id = *self.current_id; for i in 0..count { self.variables .last_mut() @@ -1286,65 +1575,191 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i); self.type_check.insert(numeric_id + i, typ); } - self.global.current_id += count; + *self.current_id += count; (0..count).into_iter().map(move |i| i + numeric_id) } } -struct NumericIdResolver<'a, 'b> { - global: &'b mut GlobalStringIdResolver<'a>, +struct NumericIdResolver<'b> { + current_id: &'b mut spirv::Word, type_check: HashMap, } -impl<'a, 'b> NumericIdResolver<'a, 'b> { +impl<'b> NumericIdResolver<'b> { fn get_type(&self, id: spirv::Word) -> ast::Type { self.type_check[&id] } fn new_id(&mut self, typ: Option) -> spirv::Word { - let new_id = self.global.current_id; + let new_id = *self.current_id; if let Some(typ) = typ { self.type_check.insert(new_id, typ); } - self.global.current_id += 1; + *self.current_id += 1; new_id } } -enum Statement { - Variable(spirv::Word, ast::Type, ast::StateSpace, Option), - LoadVar(ast::Arg2, ast::Type), - StoreVar(ast::Arg2St, ast::Type), +enum Statement { Label(u32), + Variable(VariableDecl), Instruction(I), + LoadVar(ast::Arg2, ast::Type), + StoreVar(ast::Arg2St, ast::Type), + Call(ResolvedCall

), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), Conversion(ImplicitConversion), Constant(ConstantDefinition), + RetValue(ast::RetData, spirv::Word), +} + +struct VariableDecl { + pub space: ast::StateSpace, + pub align: Option, + pub v_type: ast::Type, + pub name: spirv::Word, +} + +struct ResolvedCall { + pub uniform: bool, + pub ret_params: Vec>, + pub func: spirv::Word, + pub param_list: Vec>, +} + +impl> ResolvedCall { + fn map, V: ArgumentMapVisitor>( + self, + visitor: &mut V, + ) -> ResolvedCall { + let ret_params = self + .ret_params + .into_iter() + .map(|p| { + let new_id = visitor.variable(ArgumentDescriptor { + op: p.id, + typ: Some(p.typ), + is_dst: true, + is_pointer: false, + }); + ArgCall { + id: new_id, + typ: p.typ, + space: p.space, + } + }) + .collect(); + let func = visitor.variable(ArgumentDescriptor { + op: self.func, + typ: None, + is_dst: false, + is_pointer: false, + }); + let param_list = self + .param_list + .into_iter() + .map(|p| { + let new_id = visitor.src_call_operand(ArgumentDescriptor { + op: p.id, + typ: Some(p.typ), + is_dst: false, + is_pointer: false, + }); + ArgCall { + id: new_id, + typ: p.typ, + space: p.space, + } + }) + .collect(); + ResolvedCall { + uniform: self.uniform, + ret_params, + func, + param_list, + } + } +} + +impl VisitVariable for ResolvedCall { + fn visit_variable<'a, F: FnMut(ArgumentDescriptor) -> spirv::Word>( + self, + f: &mut F, + ) -> UnadornedStatement { + Statement::Call(self.map(f)) + } +} + +impl VisitVariableExpanded for ResolvedCall { + fn visit_variable_extended) -> spirv::Word>( + self, + f: &mut F, + ) -> ExpandedStatement { + Statement::Call(self.map(f)) + } +} + +struct ArgCall { + id: ID, + typ: ast::Type, + space: ast::FnArgStateSpace, +} + +pub trait ArgParamsEx: ast::ArgParams { + fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl; +} + +impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> { + fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl { + decl.get_fn_decl_str(id) + } } enum NormalizedArgParams {} -type NormalizedStatement = Statement>; +type NormalizedStatement = Statement< + ( + Option>, + ast::Instruction, + ), + NormalizedArgParams, +>; +type UnadornedStatement = Statement, NormalizedArgParams>; impl ast::ArgParams for NormalizedArgParams { type ID = spirv::Word; type Operand = ast::Operand; + type CallOperand = ast::CallOperand; type MovOperand = ast::MovOperand; } +impl ArgParamsEx for NormalizedArgParams { + fn get_fn_decl<'a, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'a, 'b>) -> &'b FnDecl { + decl.get_fn_decl(*id) + } +} + enum ExpandedArgParams {} -type ExpandedStatement = Statement>; +type ExpandedStatement = Statement, ExpandedArgParams>; type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>; impl ast::ArgParams for ExpandedArgParams { type ID = spirv::Word; type Operand = spirv::Word; + type CallOperand = spirv::Word; type MovOperand = spirv::Word; } -trait ArgumentMapVisitor { - fn dst_variable(&mut self, desc: ArgumentDescriptor) -> U::ID; - fn src_operand(&mut self, desc: ArgumentDescriptor) -> U::Operand; +impl ArgParamsEx for ExpandedArgParams { + fn get_fn_decl<'a, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'a, 'b>) -> &'b FnDecl { + decl.get_fn_decl(*id) + } +} + +trait ArgumentMapVisitor { + fn variable(&mut self, desc: ArgumentDescriptor) -> U::ID; + fn operand(&mut self, desc: ArgumentDescriptor) -> U::Operand; + fn src_call_operand(&mut self, desc: ArgumentDescriptor) -> U::CallOperand; fn src_mov_operand(&mut self, desc: ArgumentDescriptor) -> U::MovOperand; } @@ -1352,12 +1767,16 @@ impl ArgumentMapVisitor for T where T: FnMut(ArgumentDescriptor) -> spirv::Word, { - fn dst_variable(&mut self, desc: ArgumentDescriptor) -> spirv::Word { + fn variable(&mut self, desc: ArgumentDescriptor) -> spirv::Word { self(desc) } - fn src_operand(&mut self, desc: ArgumentDescriptor) -> spirv::Word { + fn operand(&mut self, desc: ArgumentDescriptor) -> spirv::Word { self(desc) } + fn src_call_operand(&mut self, mut desc: ArgumentDescriptor) -> spirv::Word { + desc.op = self(desc.new_op(desc.op)); + desc.op + } fn src_mov_operand(&mut self, desc: ArgumentDescriptor) -> spirv::Word { self(desc) } @@ -1367,11 +1786,11 @@ impl<'a, T> ArgumentMapVisitor, NormalizedArgParams> fo where T: FnMut(&str) -> spirv::Word, { - fn dst_variable(&mut self, desc: ArgumentDescriptor<&str>) -> spirv::Word { + fn variable(&mut self, desc: ArgumentDescriptor<&str>) -> spirv::Word { self(desc.op) } - fn src_operand( + fn operand( &mut self, desc: ArgumentDescriptor>, ) -> ast::Operand { @@ -1382,12 +1801,22 @@ where } } + fn src_call_operand( + &mut self, + desc: ArgumentDescriptor>, + ) -> ast::CallOperand { + match desc.op { + ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(id)), + ast::CallOperand::Imm(imm) => ast::CallOperand::Imm(imm), + } + } + fn src_mov_operand( &mut self, desc: ArgumentDescriptor>, ) -> ast::MovOperand { match desc.op { - ast::MovOperand::Op(op) => ast::MovOperand::Op(self.src_operand(desc.new_op(op))), + ast::MovOperand::Op(op) => ast::MovOperand::Op(self.operand(desc.new_op(op))), ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2), } } @@ -1411,17 +1840,21 @@ impl ArgumentDescriptor { } } -impl ast::Instruction { - fn map>( +impl ast::Instruction { + fn map>( self, visitor: &mut V, ) -> ast::Instruction { match self { ast::Instruction::Abs(_, _) => todo!(), - ast::Instruction::Call(_, _) => todo!(), + ast::Instruction::Call(_) => unreachable!(), ast::Instruction::Ld(d, a) => { let inst_type = d.typ; - ast::Instruction::Ld(d, a.map_ld(visitor, Some(ast::Type::Scalar(inst_type)))) + let src_is_pointer = d.state_space != ast::LdStateSpace::Param; + ast::Instruction::Ld( + d, + a.map_ld(visitor, Some(ast::Type::Scalar(inst_type)), src_is_pointer), + ) } ast::Instruction::Mov(d, a) => { let inst_type = d.typ; @@ -1467,7 +1900,11 @@ impl ast::Instruction { } ast::Instruction::St(d, a) => { let inst_type = d.typ; - ast::Instruction::St(d, a.map(visitor, Some(ast::Type::Scalar(inst_type)))) + let param_space = d.state_space == ast::StStateSpace::Param; + ast::Instruction::St( + d, + a.map(visitor, Some(ast::Type::Scalar(inst_type)), param_space), + ) } ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)), ast::Instruction::Ret(d) => ast::Instruction::Ret(d), @@ -1479,12 +1916,12 @@ impl ast::Instruction { } } -impl ast::Instruction { - fn visit_variable) -> spirv::Word>( +impl VisitVariable for ast::Instruction { + fn visit_variable<'a, F: FnMut(ArgumentDescriptor) -> spirv::Word>( self, f: &mut F, - ) -> ast::Instruction { - self.map(f) + ) -> UnadornedStatement { + Statement::Instruction(self.map(f)) } } @@ -1492,11 +1929,11 @@ impl ArgumentMapVisitor for T where T: FnMut(ArgumentDescriptor) -> spirv::Word, { - fn dst_variable(&mut self, desc: ArgumentDescriptor) -> spirv::Word { + fn variable(&mut self, desc: ArgumentDescriptor) -> spirv::Word { self(desc) } - fn src_operand( + fn operand( &mut self, desc: ArgumentDescriptor>, ) -> ast::Operand { @@ -1507,6 +1944,16 @@ where } } + fn src_call_operand( + &mut self, + desc: ArgumentDescriptor>, + ) -> ast::CallOperand { + match desc.op { + ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(desc.new_op(id))), + ast::CallOperand::Imm(imm) => ast::CallOperand::Imm(imm), + } + } + fn src_mov_operand( &mut self, desc: ArgumentDescriptor>, @@ -1515,7 +1962,7 @@ where ast::MovOperand::Op(op) => ast::MovOperand::Op(ArgumentMapVisitor::< NormalizedArgParams, NormalizedArgParams, - >::src_operand( + >::operand( self, desc.new_op(op) )), ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2), @@ -1524,17 +1971,8 @@ where } impl ast::Instruction { - fn visit_variable_extended) -> spirv::Word>( - self, - f: &mut F, - ) -> Self { - self.map(f) - } - fn jump_target(&self) -> Option { match self { - ast::Instruction::Abs(_, _) => todo!(), - ast::Instruction::Call(_, _) => todo!(), ast::Instruction::Bra(_, a) => Some(a.src), ast::Instruction::Ld(_, _) | ast::Instruction::Mov(_, _) @@ -1547,11 +1985,22 @@ impl ast::Instruction { | ast::Instruction::Cvta(_, _) | ast::Instruction::Shl(_, _) | ast::Instruction::St(_, _) - | ast::Instruction::Ret(_) => None, + | ast::Instruction::Ret(_) + | ast::Instruction::Abs(_, _) + | ast::Instruction::Call(_) => None, } } } +impl VisitVariableExpanded for ast::Instruction { + fn visit_variable_extended) -> spirv::Word>( + self, + f: &mut F, + ) -> ExpandedStatement { + Statement::Instruction(self.map(f)) + } +} + type Arg2 = ast::Arg2; type Arg2St = ast::Arg2St; @@ -1592,24 +2041,38 @@ impl ast::PredAt { } } -// REMOVE impl<'a> ast::Instruction> { fn map_variable spirv::Word>( self, f: &mut F, ) -> ast::Instruction { - self.map(f) + match self { + ast::Instruction::Call(call) => { + let call_inst = ast::CallInst { + uniform: call.uniform, + ret_params: call.ret_params.into_iter().map(|p| f(p)).collect(), + func: f(call.func), + param_list: call + .param_list + .into_iter() + .map(|p| p.map_variable(f)) + .collect(), + }; + ast::Instruction::Call(call_inst) + } + i => i.map(f), + } } } -impl ast::Arg1 { - fn map>( +impl ast::Arg1 { + fn map>( self, visitor: &mut V, t: Option, ) -> ast::Arg1 { ast::Arg1 { - src: visitor.dst_variable(ArgumentDescriptor { + src: visitor.variable(ArgumentDescriptor { op: self.src, typ: t, is_dst: false, @@ -1619,20 +2082,20 @@ impl ast::Arg1 { } } -impl ast::Arg2 { - fn map>( +impl ast::Arg2 { + fn map>( self, visitor: &mut V, t: Option, ) -> ast::Arg2 { ast::Arg2 { - dst: visitor.dst_variable(ArgumentDescriptor { + dst: visitor.variable(ArgumentDescriptor { op: self.dst, typ: t, is_dst: true, is_pointer: false, }), - src: visitor.src_operand(ArgumentDescriptor { + src: visitor.operand(ArgumentDescriptor { op: self.src, typ: t, is_dst: false, @@ -1641,41 +2104,42 @@ impl ast::Arg2 { } } - fn map_ld>( + fn map_ld>( self, visitor: &mut V, t: Option, + is_src_pointer: bool, ) -> ast::Arg2 { ast::Arg2 { - dst: visitor.dst_variable(ArgumentDescriptor { + dst: visitor.variable(ArgumentDescriptor { op: self.dst, typ: t, is_dst: true, is_pointer: false, }), - src: visitor.src_operand(ArgumentDescriptor { + src: visitor.operand(ArgumentDescriptor { op: self.src, typ: t, is_dst: false, - is_pointer: true, + is_pointer: is_src_pointer, }), } } - fn map_cvt>( + fn map_cvt>( self, visitor: &mut V, dst_t: ast::Type, src_t: ast::Type, ) -> ast::Arg2 { ast::Arg2 { - dst: visitor.dst_variable(ArgumentDescriptor { + dst: visitor.variable(ArgumentDescriptor { op: self.dst, typ: Some(dst_t), is_dst: true, is_pointer: false, }), - src: visitor.src_operand(ArgumentDescriptor { + src: visitor.operand(ArgumentDescriptor { op: self.src, typ: Some(src_t), is_dst: false, @@ -1685,20 +2149,21 @@ impl ast::Arg2 { } } -impl ast::Arg2St { - fn map>( +impl ast::Arg2St { + fn map>( self, visitor: &mut V, t: Option, + param_space: bool, ) -> ast::Arg2St { ast::Arg2St { - src1: visitor.src_operand(ArgumentDescriptor { + src1: visitor.operand(ArgumentDescriptor { op: self.src1, typ: t, - is_dst: false, - is_pointer: true, + is_dst: param_space, + is_pointer: !param_space, }), - src2: visitor.src_operand(ArgumentDescriptor { + src2: visitor.operand(ArgumentDescriptor { op: self.src2, typ: t, is_dst: false, @@ -1708,14 +2173,14 @@ impl ast::Arg2St { } } -impl ast::Arg2Mov { - fn map>( +impl ast::Arg2Mov { + fn map>( self, visitor: &mut V, t: Option, ) -> ast::Arg2Mov { ast::Arg2Mov { - dst: visitor.dst_variable(ArgumentDescriptor { + dst: visitor.variable(ArgumentDescriptor { op: self.dst, typ: t, is_dst: true, @@ -1731,26 +2196,26 @@ impl ast::Arg2Mov { } } -impl ast::Arg3 { - fn map_non_shift>( +impl ast::Arg3 { + fn map_non_shift>( self, visitor: &mut V, t: Option, ) -> ast::Arg3 { ast::Arg3 { - dst: visitor.dst_variable(ArgumentDescriptor { + dst: visitor.variable(ArgumentDescriptor { op: self.dst, typ: t, is_dst: true, is_pointer: false, }), - src1: visitor.src_operand(ArgumentDescriptor { + src1: visitor.operand(ArgumentDescriptor { op: self.src1, typ: t, is_dst: false, is_pointer: false, }), - src2: visitor.src_operand(ArgumentDescriptor { + src2: visitor.operand(ArgumentDescriptor { op: self.src2, typ: t, is_dst: false, @@ -1759,25 +2224,25 @@ impl ast::Arg3 { } } - fn map_shift>( + fn map_shift>( self, visitor: &mut V, t: Option, ) -> ast::Arg3 { ast::Arg3 { - dst: visitor.dst_variable(ArgumentDescriptor { + dst: visitor.variable(ArgumentDescriptor { op: self.dst, typ: t, is_dst: true, is_pointer: false, }), - src1: visitor.src_operand(ArgumentDescriptor { + src1: visitor.operand(ArgumentDescriptor { op: self.src1, typ: t, is_dst: false, is_pointer: false, }), - src2: visitor.src_operand(ArgumentDescriptor { + src2: visitor.operand(ArgumentDescriptor { op: self.src2, typ: Some(ast::Type::Scalar(ast::ScalarType::U32)), is_dst: false, @@ -1787,34 +2252,34 @@ impl ast::Arg3 { } } -impl ast::Arg4 { - fn map>( +impl ast::Arg4 { + fn map>( self, visitor: &mut V, t: Option, ) -> ast::Arg4 { ast::Arg4 { - dst1: visitor.dst_variable(ArgumentDescriptor { + dst1: visitor.variable(ArgumentDescriptor { op: self.dst1, typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), is_dst: true, is_pointer: false, }), dst2: self.dst2.map(|dst2| { - visitor.dst_variable(ArgumentDescriptor { + visitor.variable(ArgumentDescriptor { op: dst2, typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), is_dst: true, is_pointer: false, }) }), - src1: visitor.src_operand(ArgumentDescriptor { + src1: visitor.operand(ArgumentDescriptor { op: self.src1, typ: t, is_dst: false, is_pointer: false, }), - src2: visitor.src_operand(ArgumentDescriptor { + src2: visitor.operand(ArgumentDescriptor { op: self.src2, typ: t, is_dst: false, @@ -1824,40 +2289,40 @@ impl ast::Arg4 { } } -impl ast::Arg5 { - fn map>( +impl ast::Arg5 { + fn map>( self, visitor: &mut V, t: Option, ) -> ast::Arg5 { ast::Arg5 { - dst1: visitor.dst_variable(ArgumentDescriptor { + dst1: visitor.variable(ArgumentDescriptor { op: self.dst1, typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), is_dst: true, is_pointer: false, }), dst2: self.dst2.map(|dst2| { - visitor.dst_variable(ArgumentDescriptor { + visitor.variable(ArgumentDescriptor { op: dst2, typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), is_dst: true, is_pointer: false, }) }), - src1: visitor.src_operand(ArgumentDescriptor { + src1: visitor.operand(ArgumentDescriptor { op: self.src1, typ: t, is_dst: false, is_pointer: false, }), - src2: visitor.src_operand(ArgumentDescriptor { + src2: visitor.operand(ArgumentDescriptor { op: self.src2, typ: t, is_dst: false, is_pointer: false, }), - src3: visitor.src_operand(ArgumentDescriptor { + src3: visitor.operand(ArgumentDescriptor { op: self.src3, typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)), is_dst: false, @@ -1867,6 +2332,74 @@ impl ast::Arg5 { } } +/* +impl ast::ArgCall { + fn map<'a, U: ArgParamsEx, V: ArgumentMapVisitor>( + self, + visitor: &mut V, + fn_resolve: &GlobalFnDeclResolver<'a>, + ) -> ast::ArgCall { + // TODO: error out if lengths don't match + let fn_decl = T::get_fn_decl(&self.func, fn_resolve); + let ret_params = self + .ret_params + .into_iter() + .zip(fn_decl.ret_vals.iter().copied()) + .map(|(a, (space, typ))| { + visitor.variable(ArgumentDescriptor { + op: a, + typ: Some(ast::Type::Scalar(typ)), + is_dst: true, + is_pointer: if space == ast::FnArgStateSpace::Reg { + false + } else { + true + }, + }) + }) + .collect(); + let func = visitor.variable(ArgumentDescriptor { + op: self.func, + typ: None, + is_dst: false, + is_pointer: false, + }); + let param_list = self + .param_list + .into_iter() + .zip(fn_decl.params.iter().copied()) + .map(|(a, (space, typ))| { + visitor.src_call_operand(ArgumentDescriptor { + op: a, + typ: Some(ast::Type::Scalar(typ)), + is_dst: false, + is_pointer: if space == ast::FnArgStateSpace::Reg { + false + } else { + true + }, + }) + }) + .collect(); + ast::ArgCall { + uniform: false, + ret_params, + func: func, + param_list: param_list, + } + } +} +*/ + +impl ast::CallOperand { + fn map_variable U>(self, f: &mut F) -> ast::CallOperand { + match self { + ast::CallOperand::Reg(id) => ast::CallOperand::Reg(f(id)), + ast::CallOperand::Imm(x) => ast::CallOperand::Imm(x), + } + } +} + impl ast::StStateSpace { fn to_ld_ss(self) -> ast::LdStateSpace { match self { @@ -2282,10 +2815,10 @@ fn should_convert_relaxed_dst( fn insert_implicit_bitcasts( func: &mut Vec, id_def: &mut NumericIdResolver, - instr: ast::Instruction, + stmt: impl VisitVariableExpanded, ) { let mut dst_coercion = None; - let instr = instr.visit_variable_extended(&mut |mut desc| { + let instr = stmt.visit_variable_extended(&mut |mut desc| { let id_type_from_instr = match desc.typ { Some(t) => t, None => return desc.op, @@ -2315,17 +2848,25 @@ fn insert_implicit_bitcasts( desc.op } }); - func.push(Statement::Instruction(instr)); + func.push(instr); if let Some(cond) = dst_coercion { func.push(cond); } } - -impl<'a> ast::FunctionHeader<'a, ast::ParsedArgParams<'a>> { +impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> { fn name(&self) -> &'a str { match self { - ast::FunctionHeader::Kernel(name) => name, - ast::FunctionHeader::Func(_, name) => name, + ast::MethodDecl::Kernel(name, _) => name, + ast::MethodDecl::Func(_, name, _) => name, + } + } +} + +impl<'a, P: ArgParamsEx> ast::MethodDecl<'a, P> { + fn visit_args(&self, f: impl FnMut(&ast::KernelArgument

)) { + match self { + ast::MethodDecl::Kernel(_, params) => params.iter().for_each(f), + ast::MethodDecl::Func(_, _, params) => params.iter().map(|a| &a.base).for_each(f), } } } -- cgit v1.2.3