summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-09-06 21:38:01 +0200
committerAndrzej Janik <[email protected]>2020-09-08 21:29:18 +0200
commit76afbeba63d29e1247d5beb00902a8bb0279f791 (patch)
treedb5a79853dcf5fcd4137a6542f9198a17e40e6fe
parentbbb3a6c5cbaff3430191ef4858aa16be8320ce77 (diff)
downloadZLUDA-76afbeba63d29e1247d5beb00902a8bb0279f791.tar.gz
ZLUDA-76afbeba63d29e1247d5beb00902a8bb0279f791.zip
Implement support for PTX call instruction
-rw-r--r--doc/NOTES.md6
-rw-r--r--ptx/src/ast.rs44
-rw-r--r--ptx/src/ptx.lalrpop64
-rw-r--r--ptx/src/test/spirv_run/add.spvtxt84
-rw-r--r--ptx/src/test/spirv_run/block.spvtxt95
-rw-r--r--ptx/src/test/spirv_run/bra.spvtxt105
-rw-r--r--ptx/src/test/spirv_run/call.ptx38
-rw-r--r--ptx/src/test/spirv_run/call.spvtxt73
-rw-r--r--ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt94
-rw-r--r--ptx/src/test/spirv_run/cvta.spvtxt91
-rw-r--r--ptx/src/test/spirv_run/ld_st.spvtxt79
-rw-r--r--ptx/src/test/spirv_run/local_align.spvtxt84
-rw-r--r--ptx/src/test/spirv_run/mod.rs23
-rw-r--r--ptx/src/test/spirv_run/mov.spvtxt88
-rw-r--r--ptx/src/test/spirv_run/mul_hi.spvtxt89
-rw-r--r--ptx/src/test/spirv_run/mul_lo.spvtxt84
-rw-r--r--ptx/src/test/spirv_run/not.spvtxt82
-rw-r--r--ptx/src/test/spirv_run/setp.spvtxt138
-rw-r--r--ptx/src/test/spirv_run/shl.spvtxt86
-rw-r--r--ptx/src/translate.rs1133
20 files changed, 1692 insertions, 888 deletions
diff --git a/doc/NOTES.md b/doc/NOTES.md
index b0f58f7..5e08b7e 100644
--- a/doc/NOTES.md
+++ b/doc/NOTES.md
@@ -75,3 +75,9 @@ CUDA <-> L0
* context ~= context (1.0+)
* graph ~= command list
* module ~= module
+
+IGC
+---
+* IGC is extremely brittle and segfaults on fairly innocent code:
+ * OpBitcast of pointer to uint
+ * OpCopyMemory of alloca'd variable
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 7550d55..cfbdad5 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -51,23 +51,34 @@ pub struct Module<'a> {
pub functions: Vec<ParsedFunction<'a>>,
}
-pub enum FunctionHeader<'a, P: ArgParams> {
- Func(Vec<Argument<P>>, P::ID),
- Kernel(&'a str),
+pub enum MethodDecl<'a, P: ArgParams> {
+ Func(Vec<FnArgument<P>>, P::ID, Vec<FnArgument<P>>),
+ Kernel(&'a str, Vec<KernelArgument<P>>),
}
pub struct Function<'a, P: ArgParams, S> {
- pub func_directive: FunctionHeader<'a, P>,
- pub args: Vec<Argument<P>>,
+ pub func_directive: MethodDecl<'a, P>,
pub body: Option<Vec<S>>,
}
pub type ParsedFunction<'a> = Function<'a, ParsedArgParams<'a>, Statement<ParsedArgParams<'a>>>;
-#[derive(Default)]
-pub struct Argument<P: ArgParams> {
+pub struct FnArgument<P: ArgParams> {
+ pub base: KernelArgument<P>,
+ pub state_space: FnArgStateSpace,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum FnArgStateSpace {
+ Reg,
+ Param,
+}
+
+#[derive(Default, Copy, Clone)]
+pub struct KernelArgument<P: ArgParams> {
pub name: P::ID,
pub a_type: ScalarType,
+ // TODO: turn length into part of type definition
pub length: u32,
}
@@ -222,28 +233,26 @@ pub enum Instruction<P: ArgParams> {
Shl(ShlType, Arg3<P>),
St(StData, Arg2St<P>),
Ret(RetData),
- Call(CallData, ArgCall<P>),
+ Call(CallInst<P>),
Abs(AbsDetails, Arg2<P>),
}
-pub struct CallData {
- pub uniform: bool,
-}
-
pub struct AbsDetails {
pub flush_to_zero: bool,
pub typ: ScalarType,
}
-pub struct ArgCall<P: ArgParams> {
+pub struct CallInst<P: ArgParams> {
+ pub uniform: bool,
pub ret_params: Vec<P::ID>,
pub func: P::ID,
- pub param_list: Vec<P::ID>,
+ pub param_list: Vec<P::CallOperand>,
}
pub trait ArgParams {
type ID;
type Operand;
+ type CallOperand;
type MovOperand;
}
@@ -254,6 +263,7 @@ pub struct ParsedArgParams<'a> {
impl<'a> ArgParams for ParsedArgParams<'a> {
type ID = &'a str;
type Operand = Operand<&'a str>;
+ type CallOperand = CallOperand<&'a str>;
type MovOperand = MovOperand<&'a str>;
}
@@ -304,6 +314,12 @@ pub enum Operand<ID> {
Imm(i128),
}
+#[derive(Copy, Clone)]
+pub enum CallOperand<ID> {
+ Reg(ID),
+ Imm(i128),
+}
+
pub enum MovOperand<ID> {
Op(Operand<ID>),
Vec(String, String),
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 7e38b78..53bb296 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -202,8 +202,7 @@ AddressSize = {
Function: ast::Function<'input, ast::ParsedArgParams<'input>, ast::Statement<ast::ParsedArgParams<'input>>> = {
LinkingDirective*
- <func_directive:FunctionHeader>
- <args:Arguments>
+ <func_directive:MethodDecl>
<body:FunctionBody> => ast::Function{<>}
};
@@ -213,24 +212,43 @@ LinkingDirective = {
".weak"
};
-FunctionHeader: ast::FunctionHeader<'input, ast::ParsedArgParams<'input>> = {
- ".entry" <name:ExtendedID> => ast::FunctionHeader::Kernel(name),
- ".func" <args:Arguments?> <name:ExtendedID> => ast::FunctionHeader::Func(args.unwrap_or_else(|| Vec::new()), name)
+MethodDecl: ast::MethodDecl<'input, ast::ParsedArgParams<'input>> = {
+ ".entry" <name:ExtendedID> <params:KernelArguments> => ast::MethodDecl::Kernel(name, params),
+ ".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
};
-Arguments: Vec<ast::Argument<ast::ParsedArgParams<'input>>> = {
- "(" <args:Comma<FunctionInput>> ")" => args
-}
+KernelArguments: Vec<ast::KernelArgument<ast::ParsedArgParams<'input>>> = {
+ "(" <args:Comma<KernelInput>> ")" => args
+};
+
+FnArguments: Vec<ast::FnArgument<ast::ParsedArgParams<'input>>> = {
+ "(" <args:Comma<FnInput>> ")" => args
+};
+
+FnInput: ast::FnArgument<ast::ParsedArgParams<'input>> = {
+ ".reg" <_type:ScalarType> <name:ExtendedID> => {
+ ast::FnArgument {
+ base: ast::KernelArgument {a_type: _type, name: name, length: 1 },
+ state_space: ast::FnArgStateSpace::Reg,
+ }
+ },
+ <p:KernelInput> => {
+ ast::FnArgument {
+ base: p,
+ state_space: ast::FnArgStateSpace::Param,
+ }
+ }
+};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
-FunctionInput: ast::Argument<ast::ParsedArgParams<'input>> = {
+KernelInput: ast::KernelArgument<ast::ParsedArgParams<'input>> = {
".param" <_type:ScalarType> <name:ExtendedID> => {
- ast::Argument {a_type: _type, name: name, length: 1 }
+ ast::KernelArgument {a_type: _type, name: name, length: 1 }
},
".param" <a_type:ScalarType> <name:ExtendedID> "[" <length:Num> "]" => {
let length = length.parse::<u32>();
let length = length.unwrap_with(errors);
- ast::Argument { a_type: a_type, name: name, length: length }
+ ast::KernelArgument { a_type: a_type, name: name, length: length }
}
};
@@ -856,7 +874,10 @@ CvtaSize: ast::CvtaSize = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call
InstCall: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "call" <u:".uni"?> <a:ArgCall> => ast::Instruction::Call(ast::CallData { uniform: u.is_some() }, a)
+ "call" <u:".uni"?> <args:ArgCall> => {
+ let (ret_params, func, param_list) = args;
+ ast::Instruction::Call(ast::CallInst { uniform: u.is_some(), ret_params, func, param_list })
+ }
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs
@@ -900,6 +921,15 @@ Operand: ast::Operand<&'input str> = {
}
};
+CallOperand: ast::CallOperand<&'input str> = {
+ <r:ExtendedID> => ast::CallOperand::Reg(r),
+ <o:Num> => {
+ let offset = o.parse::<i128>();
+ let offset = offset.unwrap_with(errors);
+ ast::CallOperand::Imm(offset)
+ }
+};
+
MovOperand: ast::MovOperand<&'input str> = {
<o:Operand> => ast::MovOperand::Op(o),
<o:VectorOperand> => {
@@ -938,10 +968,12 @@ Arg5: ast::Arg5<ast::ParsedArgParams<'input>> = {
<dst1:ExtendedID> <dst2:OptionalDst?> "," <src1:Operand> "," <src2:Operand> "," "!"? <src3:Operand> => ast::Arg5{<>}
};
-ArgCall: ast::ArgCall<ast::ParsedArgParams<'input>> = {
- "(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> "," "(" <param_list:Comma<ExtendedID>> ")" => ast::ArgCall{<>},
- <func:ExtendedID> "," "(" <param_list:Comma<ExtendedID>> ")" => ast::ArgCall{ret_params: Vec::new(), func, param_list},
- <func:ExtendedID> => ast::ArgCall{ret_params: Vec::new(), func, param_list: Vec::new()},
+ArgCall: (Vec<&'input str>, &'input str, Vec<ast::CallOperand<&'input str>>) = {
+ "(" <ret_params:Comma<ExtendedID>> ")" "," <func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => {
+ (ret_params, func, param_list)
+ },
+ <func:ExtendedID> "," "(" <param_list:Comma<CallOperand>> ")" => (Vec::new(), func, param_list),
+ <func:ExtendedID> => (Vec::new(), func, Vec::<ast::CallOperand<_>>::new()),
};
OptionalDst: &'input str = {
diff --git a/ptx/src/test/spirv_run/add.spvtxt b/ptx/src/test/spirv_run/add.spvtxt
index 465a74e..6810fec 100644
--- a/ptx/src/test/spirv_run/add.spvtxt
+++ b/ptx/src/test/spirv_run/add.spvtxt
@@ -1,38 +1,46 @@
- OpCapability GenericPointer
- OpCapability Linkage
- OpCapability Addresses
- OpCapability Kernel
- OpCapability Int64
- OpCapability Int8
- %1 = OpExtInstImport "OpenCL.std"
- OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %5 "add"
- %void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %4 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
-%_ptr_Generic_ulong = OpTypePointer Generic %ulong
- %ulong_1 = OpConstant %ulong 1
- %5 = OpFunction %void None %4
- %6 = OpFunctionParameter %ulong
- %7 = OpFunctionParameter %ulong
- %21 = OpLabel
- %8 = OpVariable %_ptr_Function_ulong Function
- %9 = OpVariable %_ptr_Function_ulong Function
- %10 = OpVariable %_ptr_Function_ulong Function
- %11 = OpVariable %_ptr_Function_ulong Function
- OpStore %8 %6
- OpStore %9 %7
- %12 = OpLoad %ulong %8
- %19 = OpConvertUToPtr %_ptr_Generic_ulong %12
- %13 = OpLoad %ulong %19
- OpStore %10 %13
- %14 = OpLoad %ulong %10
- %15 = OpIAdd %ulong %14 %ulong_1
- OpStore %11 %15
- %16 = OpLoad %ulong %9
- %17 = OpLoad %ulong %11
- %20 = OpConvertUToPtr %_ptr_Generic_ulong %16
- OpStore %20 %17
- OpReturn
- OpFunctionEnd
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %25 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "add"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %28 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %ulong_1 = OpConstant %ulong 1
+ %1 = OpFunction %void None %28
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %23 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %11 = OpLoad %ulong %2
+ %10 = OpCopyObject %ulong %11
+ OpStore %4 %10
+ %13 = OpLoad %ulong %3
+ %12 = OpCopyObject %ulong %13
+ OpStore %5 %12
+ %15 = OpLoad %ulong %4
+ %21 = OpConvertUToPtr %_ptr_Generic_ulong %15
+ %14 = OpLoad %ulong %21
+ OpStore %6 %14
+ %17 = OpLoad %ulong %6
+ %16 = OpIAdd %ulong %17 %ulong_1
+ OpStore %7 %16
+ %18 = OpLoad %ulong %5
+ %19 = OpLoad %ulong %7
+ %22 = OpConvertUToPtr %_ptr_Generic_ulong %18
+ OpStore %22 %19
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/block.spvtxt b/ptx/src/test/spirv_run/block.spvtxt
index a780cc3..534167d 100644
--- a/ptx/src/test/spirv_run/block.spvtxt
+++ b/ptx/src/test/spirv_run/block.spvtxt
@@ -1,44 +1,51 @@
- OpCapability GenericPointer
- OpCapability Linkage
- OpCapability Addresses
- OpCapability Kernel
- OpCapability Int64
- OpCapability Int8
- %1 = OpExtInstImport "OpenCL.std"
- OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %5 "block"
- %void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %4 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
-%_ptr_Generic_ulong = OpTypePointer Generic %ulong
- %ulong_1 = OpConstant %ulong 1
- %ulong_1_0 = OpConstant %ulong 1
- %5 = OpFunction %void None %4
- %6 = OpFunctionParameter %ulong
- %7 = OpFunctionParameter %ulong
- %25 = OpLabel
- %8 = OpVariable %_ptr_Function_ulong Function
- %9 = OpVariable %_ptr_Function_ulong Function
- %10 = OpVariable %_ptr_Function_ulong Function
- %11 = OpVariable %_ptr_Function_ulong Function
- OpStore %8 %6
- OpStore %9 %7
- %14 = OpLoad %ulong %8
- %23 = OpConvertUToPtr %_ptr_Generic_ulong %14
- %13 = OpLoad %ulong %23
- OpStore %10 %13
- %16 = OpLoad %ulong %10
- %15 = OpIAdd %ulong %16 %ulong_1
- OpStore %11 %15
- %12 = OpVariable %_ptr_Function_ulong Function
- %18 = OpLoad %ulong %12
- %17 = OpIAdd %ulong %18 %ulong_1_0
- OpStore %12 %17
- %19 = OpLoad %ulong %9
- %20 = OpLoad %ulong %11
- %24 = OpConvertUToPtr %_ptr_Generic_ulong %19
- OpStore %24 %20
- OpReturn
- OpFunctionEnd
- \ No newline at end of file
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %29 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "block"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %32 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %ulong_1 = OpConstant %ulong 1
+ %ulong_1_0 = OpConstant %ulong 1
+ %1 = OpFunction %void None %32
+ %9 = OpFunctionParameter %ulong
+ %10 = OpFunctionParameter %ulong
+ %27 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ %8 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %9
+ OpStore %3 %10
+ %12 = OpLoad %ulong %2
+ %11 = OpCopyObject %ulong %12
+ OpStore %4 %11
+ %14 = OpLoad %ulong %3
+ %13 = OpCopyObject %ulong %14
+ OpStore %5 %13
+ %16 = OpLoad %ulong %4
+ %25 = OpConvertUToPtr %_ptr_Generic_ulong %16
+ %15 = OpLoad %ulong %25
+ OpStore %6 %15
+ %18 = OpLoad %ulong %6
+ %17 = OpIAdd %ulong %18 %ulong_1
+ OpStore %7 %17
+ %20 = OpLoad %ulong %8
+ %19 = OpIAdd %ulong %20 %ulong_1_0
+ OpStore %8 %19
+ %21 = OpLoad %ulong %5
+ %22 = OpLoad %ulong %7
+ %26 = OpConvertUToPtr %_ptr_Generic_ulong %21
+ OpStore %26 %22
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/bra.spvtxt b/ptx/src/test/spirv_run/bra.spvtxt
index 81fedc5..f59fda5 100644
--- a/ptx/src/test/spirv_run/bra.spvtxt
+++ b/ptx/src/test/spirv_run/bra.spvtxt
@@ -1,49 +1,56 @@
- OpCapability GenericPointer
- OpCapability Linkage
- OpCapability Addresses
- OpCapability Kernel
- OpCapability Int64
- OpCapability Int8
- %1 = OpExtInstImport "OpenCL.std"
- OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %5 "bra"
- %void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %4 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
-%_ptr_Generic_ulong = OpTypePointer Generic %ulong
- %ulong_1 = OpConstant %ulong 1
- %ulong_2 = OpConstant %ulong 2
- %5 = OpFunction %void None %4
- %6 = OpFunctionParameter %ulong
- %7 = OpFunctionParameter %ulong
- %27 = OpLabel
- %11 = OpVariable %_ptr_Function_ulong Function
- %12 = OpVariable %_ptr_Function_ulong Function
- %13 = OpVariable %_ptr_Function_ulong Function
- %14 = OpVariable %_ptr_Function_ulong Function
- OpStore %11 %6
- OpStore %12 %7
- %16 = OpLoad %ulong %11
- %25 = OpConvertUToPtr %_ptr_Generic_ulong %16
- %15 = OpLoad %ulong %25
- OpStore %13 %15
- OpBranch %8
- %8 = OpLabel
- %18 = OpLoad %ulong %13
- %17 = OpIAdd %ulong %18 %ulong_1
- OpStore %14 %17
- OpBranch %10
- %30 = OpLabel
- %20 = OpLoad %ulong %13
- %19 = OpIAdd %ulong %20 %ulong_2
- OpStore %14 %19
- OpBranch %10
- %10 = OpLabel
- %21 = OpLoad %ulong %12
- %22 = OpLoad %ulong %14
- %26 = OpConvertUToPtr %_ptr_Generic_ulong %21
- OpStore %26 %22
- OpReturn
- OpFunctionEnd
- \ No newline at end of file
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %31 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "bra"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %34 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %ulong_1 = OpConstant %ulong 1
+ %ulong_2 = OpConstant %ulong 2
+ %1 = OpFunction %void None %34
+ %11 = OpFunctionParameter %ulong
+ %12 = OpFunctionParameter %ulong
+ %29 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ %8 = OpVariable %_ptr_Function_ulong Function
+ %9 = OpVariable %_ptr_Function_ulong Function
+ %10 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %11
+ OpStore %3 %12
+ %14 = OpLoad %ulong %2
+ %13 = OpCopyObject %ulong %14
+ OpStore %7 %13
+ %16 = OpLoad %ulong %3
+ %15 = OpCopyObject %ulong %16
+ OpStore %8 %15
+ %18 = OpLoad %ulong %7
+ %27 = OpConvertUToPtr %_ptr_Generic_ulong %18
+ %17 = OpLoad %ulong %27
+ OpStore %9 %17
+ OpBranch %4
+ %4 = OpLabel
+ %20 = OpLoad %ulong %9
+ %19 = OpIAdd %ulong %20 %ulong_1
+ OpStore %10 %19
+ OpBranch %6
+ %37 = OpLabel
+ %22 = OpLoad %ulong %9
+ %21 = OpIAdd %ulong %22 %ulong_2
+ OpStore %10 %21
+ OpBranch %6
+ %6 = OpLabel
+ %23 = OpLoad %ulong %8
+ %24 = OpLoad %ulong %10
+ %28 = OpConvertUToPtr %_ptr_Generic_ulong %23
+ OpStore %28 %24
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/call.ptx b/ptx/src/test/spirv_run/call.ptx
new file mode 100644
index 0000000..f2ac39c
--- /dev/null
+++ b/ptx/src/test/spirv_run/call.ptx
@@ -0,0 +1,38 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.func (.param.u64 output) incr (.param.u64 input);
+
+.visible .entry call(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.global.u64 temp, [in_addr];
+ .param.u64 incr_in;
+ .param.u64 incr_out;
+ st.param.b64 [incr_in], temp;
+ call (incr_out), incr, (incr_in);
+ ld.param.u64 temp, [incr_out];
+ st.global.u64 [out_addr], temp;
+ ret;
+}
+
+.func (.param .u64 output) incr(
+ .param .u64 input
+)
+{
+ .reg .u64 temp;
+ ld.param.u64 temp, [input];
+ add.u64 temp, temp, 1;
+ st.param.u64 [output], temp;
+ ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/call.spvtxt b/ptx/src/test/spirv_run/call.spvtxt
new file mode 100644
index 0000000..001cda3
--- /dev/null
+++ b/ptx/src/test/spirv_run/call.spvtxt
@@ -0,0 +1,73 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %45 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %4 "call"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %48 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
+ %51 = OpTypeFunction %ulong %ulong
+ %ulong_1 = OpConstant %ulong 1
+ %4 = OpFunction %void None %48
+ %12 = OpFunctionParameter %ulong
+ %13 = OpFunctionParameter %ulong
+ %30 = OpLabel
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ %8 = OpVariable %_ptr_Function_ulong Function
+ %9 = OpVariable %_ptr_Function_ulong Function
+ %10 = OpVariable %_ptr_Function_ulong Function
+ %11 = OpVariable %_ptr_Function_ulong Function
+ OpStore %5 %12
+ OpStore %6 %13
+ %15 = OpLoad %ulong %5
+ %14 = OpCopyObject %ulong %15
+ OpStore %7 %14
+ %17 = OpLoad %ulong %6
+ %16 = OpCopyObject %ulong %17
+ OpStore %8 %16
+ %19 = OpLoad %ulong %7
+ %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %19
+ %18 = OpLoad %ulong %28
+ OpStore %9 %18
+ %21 = OpLoad %ulong %9
+ %20 = OpCopyObject %ulong %21
+ OpStore %10 %20
+ %23 = OpLoad %ulong %10
+ %22 = OpFunctionCall %ulong %1 %23
+ OpStore %11 %22
+ %25 = OpLoad %ulong %11
+ %24 = OpCopyObject %ulong %25
+ OpStore %9 %24
+ %26 = OpLoad %ulong %8
+ %27 = OpLoad %ulong %9
+ %29 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %26
+ OpStore %29 %27
+ OpReturn
+ OpFunctionEnd
+ %1 = OpFunction %ulong None %51
+ %34 = OpFunctionParameter %ulong
+ %43 = OpLabel
+ %32 = OpVariable %_ptr_Function_ulong Function
+ %31 = OpVariable %_ptr_Function_ulong Function
+ %33 = OpVariable %_ptr_Function_ulong Function
+ OpStore %32 %34
+ %36 = OpLoad %ulong %32
+ %35 = OpCopyObject %ulong %36
+ OpStore %33 %35
+ %38 = OpLoad %ulong %33
+ %37 = OpIAdd %ulong %38 %ulong_1
+ OpStore %33 %37
+ %40 = OpLoad %ulong %33
+ %39 = OpCopyObject %ulong %40
+ OpStore %31 %39
+ %41 = OpLoad %ulong %31
+ OpReturnValue %41
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt
index afd2864..208c279 100644
--- a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt
+++ b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt
@@ -1,43 +1,51 @@
- OpCapability GenericPointer
- OpCapability Linkage
- OpCapability Addresses
- OpCapability Kernel
- OpCapability Int64
- OpCapability Int8
- %1 = OpExtInstImport "OpenCL.std"
- OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %5 "cvt_sat_s_u"
- %void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %4 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
- %uint = OpTypeInt 32 0
-%_ptr_Function_uint = OpTypePointer Function %uint
-%_ptr_Generic_uint = OpTypePointer Generic %uint
- %5 = OpFunction %void None %4
- %6 = OpFunctionParameter %ulong
- %7 = OpFunctionParameter %ulong
- %23 = OpLabel
- %8 = OpVariable %_ptr_Function_ulong Function
- %9 = OpVariable %_ptr_Function_ulong Function
- %10 = OpVariable %_ptr_Function_uint Function
- %11 = OpVariable %_ptr_Function_uint Function
- %12 = OpVariable %_ptr_Function_uint Function
- OpStore %8 %6
- OpStore %9 %7
- %14 = OpLoad %ulong %8
- %21 = OpConvertUToPtr %_ptr_Generic_uint %14
- %13 = OpLoad %uint %21
- OpStore %10 %13
- %16 = OpLoad %uint %10
- %15 = OpSatConvertSToU %uint %16
- OpStore %11 %15
- %18 = OpLoad %uint %11
- %17 = OpBitcast %uint %18
- OpStore %12 %17
- %19 = OpLoad %ulong %9
- %20 = OpLoad %uint %12
- %22 = OpConvertUToPtr %_ptr_Generic_uint %19
- OpStore %22 %20
- OpReturn
- OpFunctionEnd \ No newline at end of file
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %27 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "cvt_sat_s_u"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %30 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %1 = OpFunction %void None %30
+ %9 = OpFunctionParameter %ulong
+ %10 = OpFunctionParameter %ulong
+ %25 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ %7 = OpVariable %_ptr_Function_uint Function
+ %8 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %9
+ OpStore %3 %10
+ %12 = OpLoad %ulong %2
+ %11 = OpCopyObject %ulong %12
+ OpStore %4 %11
+ %14 = OpLoad %ulong %3
+ %13 = OpCopyObject %ulong %14
+ OpStore %5 %13
+ %16 = OpLoad %ulong %4
+ %23 = OpConvertUToPtr %_ptr_Generic_uint %16
+ %15 = OpLoad %uint %23
+ OpStore %6 %15
+ %18 = OpLoad %uint %6
+ %17 = OpSatConvertSToU %uint %18
+ OpStore %7 %17
+ %20 = OpLoad %uint %7
+ %19 = OpBitcast %uint %20
+ OpStore %8 %19
+ %21 = OpLoad %ulong %5
+ %22 = OpLoad %uint %8
+ %24 = OpConvertUToPtr %_ptr_Generic_uint %21
+ OpStore %24 %22
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/cvta.spvtxt b/ptx/src/test/spirv_run/cvta.spvtxt
index 1aa7425..e708613 100644
--- a/ptx/src/test/spirv_run/cvta.spvtxt
+++ b/ptx/src/test/spirv_run/cvta.spvtxt
@@ -1,42 +1,49 @@
- OpCapability GenericPointer
- OpCapability Linkage
- OpCapability Addresses
- OpCapability Kernel
- OpCapability Int64
- OpCapability Int8
- %1 = OpExtInstImport "OpenCL.std"
- OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %5 "cvta"
- %void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %4 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
- %float = OpTypeFloat 32
-%_ptr_Function_float = OpTypePointer Function %float
-%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
- %5 = OpFunction %void None %4
- %6 = OpFunctionParameter %ulong
- %7 = OpFunctionParameter %ulong
- %21 = OpLabel
- %8 = OpVariable %_ptr_Function_ulong Function
- %9 = OpVariable %_ptr_Function_ulong Function
- %10 = OpVariable %_ptr_Function_float Function
- OpStore %8 %6
- OpStore %9 %7
- %12 = OpLoad %ulong %8
- %11 = OpCopyObject %ulong %12
- OpStore %8 %11
- %14 = OpLoad %ulong %9
- %13 = OpCopyObject %ulong %14
- OpStore %9 %13
- %16 = OpLoad %ulong %8
- %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %16
- %15 = OpLoad %float %19
- OpStore %10 %15
- %17 = OpLoad %ulong %9
- %18 = OpLoad %float %10
- %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %17
- OpStore %20 %18
- OpReturn
- OpFunctionEnd
- \ No newline at end of file
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %25 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "cvta"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %28 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
+ %1 = OpFunction %void None %28
+ %7 = OpFunctionParameter %ulong
+ %8 = OpFunctionParameter %ulong
+ %23 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_float Function
+ OpStore %2 %7
+ OpStore %3 %8
+ %10 = OpLoad %ulong %2
+ %9 = OpCopyObject %ulong %10
+ OpStore %4 %9
+ %12 = OpLoad %ulong %3
+ %11 = OpCopyObject %ulong %12
+ OpStore %5 %11
+ %14 = OpLoad %ulong %4
+ %13 = OpCopyObject %ulong %14
+ OpStore %4 %13
+ %16 = OpLoad %ulong %5
+ %15 = OpCopyObject %ulong %16
+ OpStore %5 %15
+ %18 = OpLoad %ulong %4
+ %21 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18
+ %17 = OpLoad %float %21
+ OpStore %6 %17
+ %19 = OpLoad %ulong %5
+ %20 = OpLoad %float %6
+ %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19
+ OpStore %22 %20
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/ld_st.spvtxt b/ptx/src/test/spirv_run/ld_st.spvtxt
index 1cb7094..d36db57 100644
--- a/ptx/src/test/spirv_run/ld_st.spvtxt
+++ b/ptx/src/test/spirv_run/ld_st.spvtxt
@@ -1,38 +1,41 @@
-; SPIR-V
-; Version: 1.5
-; Generator: Khronos SPIR-V Tools Assembler; 0
-; Bound: 20
-; Schema: 0
- OpCapability GenericPointer
- OpCapability Linkage
- OpCapability Addresses
- OpCapability Kernel
- OpCapability Int64
- OpCapability Int8
- %1 = OpExtInstImport "OpenCL.std"
- OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %2 "ld_st"
- %void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %5 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
-%_ptr_Generic_ulong = OpTypePointer Generic %ulong
- %2 = OpFunction %void None %5
- %8 = OpFunctionParameter %ulong
- %9 = OpFunctionParameter %ulong
- %10 = OpLabel
- %11 = OpVariable %_ptr_Function_ulong Function
- %12 = OpVariable %_ptr_Function_ulong Function
- %13 = OpVariable %_ptr_Function_ulong Function
- OpStore %11 %8
- OpStore %12 %9
- %14 = OpLoad %ulong %11
- %15 = OpConvertUToPtr %_ptr_Generic_ulong %14
- %16 = OpLoad %ulong %15
- OpStore %13 %16
- %17 = OpLoad %ulong %12
- %18 = OpLoad %ulong %13
- %19 = OpConvertUToPtr %_ptr_Generic_ulong %17
- OpStore %19 %18
- OpReturn
- OpFunctionEnd \ No newline at end of file
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %21 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "ld_st"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %24 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %1 = OpFunction %void None %24
+ %7 = OpFunctionParameter %ulong
+ %8 = OpFunctionParameter %ulong
+ %19 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %7
+ OpStore %3 %8
+ %10 = OpLoad %ulong %2
+ %9 = OpCopyObject %ulong %10
+ OpStore %4 %9
+ %12 = OpLoad %ulong %3
+ %11 = OpCopyObject %ulong %12
+ OpStore %5 %11
+ %14 = OpLoad %ulong %4
+ %17 = OpConvertUToPtr %_ptr_Generic_ulong %14
+ %13 = OpLoad %ulong %17
+ OpStore %6 %13
+ %15 = OpLoad %ulong %5
+ %16 = OpLoad %ulong %6
+ %18 = OpConvertUToPtr %_ptr_Generic_ulong %15
+ OpStore %18 %16
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/local_align.spvtxt b/ptx/src/test/spirv_run/local_align.spvtxt
index beefb76..09a3f92 100644
--- a/ptx/src/test/spirv_run/local_align.spvtxt
+++ b/ptx/src/test/spirv_run/local_align.spvtxt
@@ -1,38 +1,46 @@
- OpCapability GenericPointer
- OpCapability Linkage
- OpCapability Addresses
- OpCapability Kernel
- OpCapability Int64
- OpCapability Int8
- %1 = OpExtInstImport "OpenCL.std"
- OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %5 "local_align"
- OpDecorate %8 Alignment 8
- %void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %4 = OpTypeFunction %void %ulong %ulong
- %uchar = OpTypeInt 8 0
-%_arr_uchar_8 = OpTypeArray %uchar %8
-%_ptr_Function__arr_uchar_8 = OpTypePointer Function %_arr_uchar_8
-%_ptr_Function_ulong = OpTypePointer Function %ulong
-%_ptr_Generic_ulong = OpTypePointer Generic %ulong
- %5 = OpFunction %void None %4
- %6 = OpFunctionParameter %ulong
- %7 = OpFunctionParameter %ulong
- %18 = OpLabel
- %8 = OpVariable %_ptr_Function__arr_uchar_8 Workgroup
- %9 = OpVariable %_ptr_Function_ulong Function
- %10 = OpVariable %_ptr_Function_ulong Function
- %11 = OpVariable %_ptr_Function_ulong Function
- OpStore %9 %6
- OpStore %10 %7
- %13 = OpLoad %ulong %9
- %16 = OpConvertUToPtr %_ptr_Generic_ulong %13
- %12 = OpLoad %ulong %16
- OpStore %11 %12
- %14 = OpLoad %ulong %10
- %15 = OpLoad %ulong %11
- %17 = OpConvertUToPtr %_ptr_Generic_ulong %14
- OpStore %17 %15
- OpReturn
- OpFunctionEnd
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %22 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "local_align"
+ OpDecorate %4 Alignment 8
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %25 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uchar = OpTypeInt 8 0
+%_arr_uchar_8 = OpTypeArray %uchar %8
+%_ptr_Function__arr_uchar_8 = OpTypePointer Function %_arr_uchar_8
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %1 = OpFunction %void None %25
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %20 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function__arr_uchar_8 Workgroup
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %11 = OpLoad %ulong %2
+ %10 = OpCopyObject %ulong %11
+ OpStore %5 %10
+ %13 = OpLoad %ulong %3
+ %12 = OpCopyObject %ulong %13
+ OpStore %6 %12
+ %15 = OpLoad %ulong %5
+ %18 = OpConvertUToPtr %_ptr_Generic_ulong %15
+ %14 = OpLoad %ulong %18
+ OpStore %7 %14
+ %16 = OpLoad %ulong %6
+ %17 = OpLoad %ulong %7
+ %19 = OpConvertUToPtr %_ptr_Generic_ulong %16
+ OpStore %19 %17
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 9f62292..a72c453 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -14,7 +14,7 @@ use std::fmt;
use std::fmt::{Debug, Display, Formatter};
use std::mem;
use std::slice;
-use std::{collections::HashMap, ptr, str};
+use std::{borrow::Cow, collections::HashMap, env, fs, path::PathBuf, ptr, str};
macro_rules! test_ptx {
($fn_name:ident, $input:expr, $output:expr) => {
@@ -32,8 +32,9 @@ macro_rules! test_ptx {
#[test]
fn [<$fn_name _spvtxt>]() -> Result<(), Box<dyn std::error::Error>> {
let ptx_txt = include_str!(concat!(stringify!($fn_name), ".ptx"));
+ let spirv_file_name = concat!(stringify!($fn_name), ".spvtxt");
let spirv_txt = include_bytes!(concat!(stringify!($fn_name), ".spvtxt"));
- test_spvtxt_assert(ptx_txt, spirv_txt)
+ test_spvtxt_assert(ptx_txt, spirv_txt, spirv_file_name)
}
}
};
@@ -140,6 +141,7 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
fn test_spvtxt_assert<'a>(
ptx_txt: &'a str,
spirv_txt: &'a [u8],
+ spirv_file_name: &'a str,
) -> Result<(), Box<dyn error::Error + 'a>> {
let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
@@ -191,16 +193,27 @@ fn test_spvtxt_assert<'a>(
)
};
unsafe { spirv_tools::spvContextDestroy(spv_context) };
- if result == spv_result_t::SPV_SUCCESS {
+ let spirv_text = if result == spv_result_t::SPV_SUCCESS {
let raw_text = unsafe {
std::slice::from_raw_parts((*spv_text).str_ as *const u8, (*spv_text).length)
};
let spv_from_ptx_text = unsafe { str::from_utf8_unchecked(raw_text) };
// TODO: stop leaking kernel text
- panic!(spv_from_ptx_text);
+ Cow::Borrowed(spv_from_ptx_text)
} else {
- panic!(ptx_mod.disassemble());
+ Cow::Owned(ptx_mod.disassemble())
+ };
+ if let Ok(dump_path) = env::var("NOTCUDA_TEST_SPIRV_DUMP_DIR") {
+ let mut path = PathBuf::from(dump_path);
+ if let Ok(()) = fs::create_dir_all(&path) {
+ path.push(spirv_file_name);
+ #[allow(unused_must_use)]
+ {
+ fs::write(path, spirv_text.as_bytes());
+ }
+ }
}
+ panic!(spirv_text);
}
unsafe { spirv_tools::spvContextDestroy(spv_context) };
Ok(())
diff --git a/ptx/src/test/spirv_run/mov.spvtxt b/ptx/src/test/spirv_run/mov.spvtxt
index 367a92a..d8a5029 100644
--- a/ptx/src/test/spirv_run/mov.spvtxt
+++ b/ptx/src/test/spirv_run/mov.spvtxt
@@ -1,43 +1,45 @@
-; SPIR-V
-; Version: 1.5
-; Generator: Khronos SPIR-V Tools Assembler; 0
-; Bound: 23
-; Schema: 0
- OpCapability GenericPointer
- OpCapability Linkage
- OpCapability Addresses
- OpCapability Kernel
- OpCapability Int64
- OpCapability Int8
- %1 = OpExtInstImport "OpenCL.std"
- OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %2 "mov"
- %void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %5 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
-%_ptr_Generic_ulong = OpTypePointer Generic %ulong
- %2 = OpFunction %void None %5
- %8 = OpFunctionParameter %ulong
- %9 = OpFunctionParameter %ulong
- %10 = OpLabel
- %11 = OpVariable %_ptr_Function_ulong Function
- %12 = OpVariable %_ptr_Function_ulong Function
- %13 = OpVariable %_ptr_Function_ulong Function
- %14 = OpVariable %_ptr_Function_ulong Function
- OpStore %11 %8
- OpStore %12 %9
- %15 = OpLoad %ulong %11
- %16 = OpConvertUToPtr %_ptr_Generic_ulong %15
- %17 = OpLoad %ulong %16
- OpStore %13 %17
- %18 = OpLoad %ulong %13
- %19 = OpCopyObject %ulong %18
- OpStore %14 %19
- %20 = OpLoad %ulong %12
- %21 = OpLoad %ulong %14
- %22 = OpConvertUToPtr %_ptr_Generic_ulong %20
- OpStore %22 %21
- OpReturn
- OpFunctionEnd
- \ No newline at end of file
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %24 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "mov"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %27 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %1 = OpFunction %void None %27
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %22 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %11 = OpLoad %ulong %2
+ %10 = OpCopyObject %ulong %11
+ OpStore %4 %10
+ %13 = OpLoad %ulong %3
+ %12 = OpCopyObject %ulong %13
+ OpStore %5 %12
+ %15 = OpLoad %ulong %4
+ %20 = OpConvertUToPtr %_ptr_Generic_ulong %15
+ %14 = OpLoad %ulong %20
+ OpStore %6 %14
+ %17 = OpLoad %ulong %6
+ %16 = OpCopyObject %ulong %17
+ OpStore %7 %16
+ %18 = OpLoad %ulong %5
+ %19 = OpLoad %ulong %7
+ %21 = OpConvertUToPtr %_ptr_Generic_ulong %18
+ OpStore %21 %19
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mul_hi.spvtxt b/ptx/src/test/spirv_run/mul_hi.spvtxt
index d25dd8a..bea23a9 100644
--- a/ptx/src/test/spirv_run/mul_hi.spvtxt
+++ b/ptx/src/test/spirv_run/mul_hi.spvtxt
@@ -1,43 +1,46 @@
-; SPIR-V
-; Version: 1.5
-; Generator: Khronos SPIR-V Tools Assembler; 0
-; Bound: 24
-; Schema: 0
- OpCapability GenericPointer
- OpCapability Linkage
- OpCapability Addresses
- OpCapability Kernel
- OpCapability Int64
- OpCapability Int8
- %1 = OpExtInstImport "OpenCL.std"
- OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %2 "mul_hi"
- %void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %5 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
-%_ptr_Generic_ulong = OpTypePointer Generic %ulong
- %ulong_2 = OpConstant %ulong 2
- %2 = OpFunction %void None %5
- %9 = OpFunctionParameter %ulong
- %10 = OpFunctionParameter %ulong
- %11 = OpLabel
- %12 = OpVariable %_ptr_Function_ulong Function
- %13 = OpVariable %_ptr_Function_ulong Function
- %14 = OpVariable %_ptr_Function_ulong Function
- %15 = OpVariable %_ptr_Function_ulong Function
- OpStore %12 %9
- OpStore %13 %10
- %16 = OpLoad %ulong %12
- %17 = OpConvertUToPtr %_ptr_Generic_ulong %16
- %18 = OpLoad %ulong %17
- OpStore %14 %18
- %19 = OpLoad %ulong %14
- %20 = OpExtInst %ulong %1 u_mul_hi %19 %ulong_2
- OpStore %15 %20
- %21 = OpLoad %ulong %13
- %22 = OpLoad %ulong %15
- %23 = OpConvertUToPtr %_ptr_Generic_ulong %21
- OpStore %23 %22
- OpReturn
- OpFunctionEnd
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %25 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "mul_hi"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %28 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %ulong_2 = OpConstant %ulong 2
+ %1 = OpFunction %void None %28
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %23 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %11 = OpLoad %ulong %2
+ %10 = OpCopyObject %ulong %11
+ OpStore %4 %10
+ %13 = OpLoad %ulong %3
+ %12 = OpCopyObject %ulong %13
+ OpStore %5 %12
+ %15 = OpLoad %ulong %4
+ %21 = OpConvertUToPtr %_ptr_Generic_ulong %15
+ %14 = OpLoad %ulong %21
+ OpStore %6 %14
+ %17 = OpLoad %ulong %6
+ %16 = OpExtInst %ulong %25 u_mul_hi %17 %ulong_2
+ OpStore %7 %16
+ %18 = OpLoad %ulong %5
+ %19 = OpLoad %ulong %7
+ %22 = OpConvertUToPtr %_ptr_Generic_ulong %18
+ OpStore %22 %19
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mul_lo.spvtxt b/ptx/src/test/spirv_run/mul_lo.spvtxt
index 4d7c2d8..e114374 100644
--- a/ptx/src/test/spirv_run/mul_lo.spvtxt
+++ b/ptx/src/test/spirv_run/mul_lo.spvtxt
@@ -1,38 +1,46 @@
- OpCapability GenericPointer
- OpCapability Linkage
- OpCapability Addresses
- OpCapability Kernel
- OpCapability Int64
- OpCapability Int8
- %1 = OpExtInstImport "OpenCL.std"
- OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %5 "mul_lo"
- %void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %4 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
-%_ptr_Generic_ulong = OpTypePointer Generic %ulong
- %ulong_2 = OpConstant %ulong 2
- %5 = OpFunction %void None %4
- %6 = OpFunctionParameter %ulong
- %7 = OpFunctionParameter %ulong
- %21 = OpLabel
- %8 = OpVariable %_ptr_Function_ulong Function
- %9 = OpVariable %_ptr_Function_ulong Function
- %10 = OpVariable %_ptr_Function_ulong Function
- %11 = OpVariable %_ptr_Function_ulong Function
- OpStore %8 %6
- OpStore %9 %7
- %12 = OpLoad %ulong %8
- %19 = OpConvertUToPtr %_ptr_Generic_ulong %12
- %13 = OpLoad %ulong %19
- OpStore %10 %13
- %14 = OpLoad %ulong %10
- %15 = OpIMul %ulong %14 %ulong_2
- OpStore %11 %15
- %16 = OpLoad %ulong %9
- %17 = OpLoad %ulong %11
- %20 = OpConvertUToPtr %_ptr_Generic_ulong %16
- OpStore %20 %17
- OpReturn
- OpFunctionEnd
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %25 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "mul_lo"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %28 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %ulong_2 = OpConstant %ulong 2
+ %1 = OpFunction %void None %28
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %23 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %11 = OpLoad %ulong %2
+ %10 = OpCopyObject %ulong %11
+ OpStore %4 %10
+ %13 = OpLoad %ulong %3
+ %12 = OpCopyObject %ulong %13
+ OpStore %5 %12
+ %15 = OpLoad %ulong %4
+ %21 = OpConvertUToPtr %_ptr_Generic_ulong %15
+ %14 = OpLoad %ulong %21
+ OpStore %6 %14
+ %17 = OpLoad %ulong %6
+ %16 = OpIMul %ulong %17 %ulong_2
+ OpStore %7 %16
+ %18 = OpLoad %ulong %5
+ %19 = OpLoad %ulong %7
+ %22 = OpConvertUToPtr %_ptr_Generic_ulong %18
+ OpStore %22 %19
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/not.spvtxt b/ptx/src/test/spirv_run/not.spvtxt
index 84482d9..de340ed 100644
--- a/ptx/src/test/spirv_run/not.spvtxt
+++ b/ptx/src/test/spirv_run/not.spvtxt
@@ -1,37 +1,45 @@
- OpCapability GenericPointer
- OpCapability Linkage
- OpCapability Addresses
- OpCapability Kernel
- OpCapability Int64
- OpCapability Int8
- %1 = OpExtInstImport "OpenCL.std"
- OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %5 "not"
- %void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %4 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
-%_ptr_Generic_ulong = OpTypePointer Generic %ulong
- %5 = OpFunction %void None %4
- %6 = OpFunctionParameter %ulong
- %7 = OpFunctionParameter %ulong
- %20 = OpLabel
- %8 = OpVariable %_ptr_Function_ulong Function
- %9 = OpVariable %_ptr_Function_ulong Function
- %10 = OpVariable %_ptr_Function_ulong Function
- %11 = OpVariable %_ptr_Function_ulong Function
- OpStore %8 %6
- OpStore %9 %7
- %13 = OpLoad %ulong %8
- %18 = OpConvertUToPtr %_ptr_Generic_ulong %13
- %12 = OpLoad %ulong %18
- OpStore %10 %12
- %15 = OpLoad %ulong %10
- %14 = OpNot %ulong %15
- OpStore %11 %14
- %16 = OpLoad %ulong %9
- %17 = OpLoad %ulong %11
- %19 = OpConvertUToPtr %_ptr_Generic_ulong %16
- OpStore %19 %17
- OpReturn
- OpFunctionEnd
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %24 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "not"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %27 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %1 = OpFunction %void None %27
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %22 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %11 = OpLoad %ulong %2
+ %10 = OpCopyObject %ulong %11
+ OpStore %4 %10
+ %13 = OpLoad %ulong %3
+ %12 = OpCopyObject %ulong %13
+ OpStore %5 %12
+ %15 = OpLoad %ulong %4
+ %20 = OpConvertUToPtr %_ptr_Generic_ulong %15
+ %14 = OpLoad %ulong %20
+ OpStore %6 %14
+ %17 = OpLoad %ulong %6
+ %16 = OpNot %ulong %17
+ OpStore %7 %16
+ %18 = OpLoad %ulong %5
+ %19 = OpLoad %ulong %7
+ %21 = OpConvertUToPtr %_ptr_Generic_ulong %18
+ OpStore %21 %19
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/setp.spvtxt b/ptx/src/test/spirv_run/setp.spvtxt
index 064cd97..cb87f65 100644
--- a/ptx/src/test/spirv_run/setp.spvtxt
+++ b/ptx/src/test/spirv_run/setp.spvtxt
@@ -1,65 +1,73 @@
- OpCapability GenericPointer
- OpCapability Linkage
- OpCapability Addresses
- OpCapability Kernel
- OpCapability Int64
- OpCapability Int8
- %1 = OpExtInstImport "OpenCL.std"
- OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %5 "setp"
- %void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %4 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
- %bool = OpTypeBool
-%_ptr_Function_bool = OpTypePointer Function %bool
-%_ptr_Generic_ulong = OpTypePointer Generic %ulong
- %ulong_8 = OpConstant %ulong 8
- %ulong_1 = OpConstant %ulong 1
- %ulong_2 = OpConstant %ulong 2
- %5 = OpFunction %void None %4
- %6 = OpFunctionParameter %ulong
- %7 = OpFunctionParameter %ulong
- %39 = OpLabel
- %8 = OpVariable %_ptr_Function_ulong Function
- %9 = OpVariable %_ptr_Function_ulong Function
- %10 = OpVariable %_ptr_Function_ulong Function
- %11 = OpVariable %_ptr_Function_ulong Function
- %12 = OpVariable %_ptr_Function_ulong Function
- %13 = OpVariable %_ptr_Function_bool Function
- OpStore %8 %6
- OpStore %9 %7
- %19 = OpLoad %ulong %8
- %35 = OpConvertUToPtr %_ptr_Generic_ulong %19
- %18 = OpLoad %ulong %35
- OpStore %10 %18
- %21 = OpLoad %ulong %8
- %36 = OpCopyObject %ulong %21
- %32 = OpIAdd %ulong %36 %ulong_8
- %37 = OpConvertUToPtr %_ptr_Generic_ulong %32
- %20 = OpLoad %ulong %37
- OpStore %11 %20
- %23 = OpLoad %ulong %10
- %24 = OpLoad %ulong %11
- %22 = OpULessThan %bool %23 %24
- OpStore %13 %22
- %25 = OpLoad %bool %13
- OpBranchConditional %25 %14 %15
- %14 = OpLabel
- %26 = OpCopyObject %ulong %ulong_1
- OpStore %12 %26
- OpBranch %15
- %15 = OpLabel
- %27 = OpLoad %bool %13
- OpBranchConditional %27 %17 %16
- %16 = OpLabel
- %28 = OpCopyObject %ulong %ulong_2
- OpStore %12 %28
- OpBranch %17
- %17 = OpLabel
- %29 = OpLoad %ulong %9
- %30 = OpLoad %ulong %12
- %38 = OpConvertUToPtr %_ptr_Generic_ulong %29
- OpStore %38 %30
- OpReturn
- OpFunctionEnd
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %43 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "setp"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %46 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %bool = OpTypeBool
+%_ptr_Function_bool = OpTypePointer Function %bool
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %ulong_8 = OpConstant %ulong 8
+ %ulong_1 = OpConstant %ulong 1
+ %ulong_2 = OpConstant %ulong 2
+ %1 = OpFunction %void None %46
+ %14 = OpFunctionParameter %ulong
+ %15 = OpFunctionParameter %ulong
+ %41 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ %8 = OpVariable %_ptr_Function_ulong Function
+ %9 = OpVariable %_ptr_Function_bool Function
+ OpStore %2 %14
+ OpStore %3 %15
+ %17 = OpLoad %ulong %2
+ %16 = OpCopyObject %ulong %17
+ OpStore %4 %16
+ %19 = OpLoad %ulong %3
+ %18 = OpCopyObject %ulong %19
+ OpStore %5 %18
+ %21 = OpLoad %ulong %4
+ %37 = OpConvertUToPtr %_ptr_Generic_ulong %21
+ %20 = OpLoad %ulong %37
+ OpStore %6 %20
+ %23 = OpLoad %ulong %4
+ %38 = OpCopyObject %ulong %23
+ %34 = OpIAdd %ulong %38 %ulong_8
+ %39 = OpConvertUToPtr %_ptr_Generic_ulong %34
+ %22 = OpLoad %ulong %39
+ OpStore %7 %22
+ %25 = OpLoad %ulong %6
+ %26 = OpLoad %ulong %7
+ %24 = OpULessThan %bool %25 %26
+ OpStore %9 %24
+ %27 = OpLoad %bool %9
+ OpBranchConditional %27 %10 %11
+ %10 = OpLabel
+ %28 = OpCopyObject %ulong %ulong_1
+ OpStore %8 %28
+ OpBranch %11
+ %11 = OpLabel
+ %29 = OpLoad %bool %9
+ OpBranchConditional %29 %13 %12
+ %12 = OpLabel
+ %30 = OpCopyObject %ulong %ulong_2
+ OpStore %8 %30
+ OpBranch %13
+ %13 = OpLabel
+ %31 = OpLoad %ulong %5
+ %32 = OpLoad %ulong %8
+ %40 = OpConvertUToPtr %_ptr_Generic_ulong %31
+ OpStore %40 %32
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/shl.spvtxt b/ptx/src/test/spirv_run/shl.spvtxt
index 3e57fc3..dbd2664 100644
--- a/ptx/src/test/spirv_run/shl.spvtxt
+++ b/ptx/src/test/spirv_run/shl.spvtxt
@@ -1,39 +1,47 @@
- OpCapability GenericPointer
- OpCapability Linkage
- OpCapability Addresses
- OpCapability Kernel
- OpCapability Int64
- OpCapability Int8
- %1 = OpExtInstImport "OpenCL.std"
- OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %5 "shl"
- %void = OpTypeVoid
- %ulong = OpTypeInt 64 0
- %4 = OpTypeFunction %void %ulong %ulong
-%_ptr_Function_ulong = OpTypePointer Function %ulong
-%_ptr_Generic_ulong = OpTypePointer Generic %ulong
- %uint = OpTypeInt 32 0
- %uint_2 = OpConstant %uint 2
- %5 = OpFunction %void None %4
- %6 = OpFunctionParameter %ulong
- %7 = OpFunctionParameter %ulong
- %21 = OpLabel
- %8 = OpVariable %_ptr_Function_ulong Function
- %9 = OpVariable %_ptr_Function_ulong Function
- %10 = OpVariable %_ptr_Function_ulong Function
- %11 = OpVariable %_ptr_Function_ulong Function
- OpStore %8 %6
- OpStore %9 %7
- %13 = OpLoad %ulong %8
- %19 = OpConvertUToPtr %_ptr_Generic_ulong %13
- %12 = OpLoad %ulong %19
- OpStore %10 %12
- %15 = OpLoad %ulong %10
- %14 = OpShiftLeftLogical %ulong %15 %uint_2
- OpStore %11 %14
- %16 = OpLoad %ulong %9
- %17 = OpLoad %ulong %11
- %20 = OpConvertUToPtr %_ptr_Generic_ulong %16
- OpStore %20 %17
- OpReturn
- OpFunctionEnd
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %25 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "shl"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %28 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %uint = OpTypeInt 32 0
+ %uint_2 = OpConstant %uint 2
+ %1 = OpFunction %void None %28
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %23 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %11 = OpLoad %ulong %2
+ %10 = OpCopyObject %ulong %11
+ OpStore %4 %10
+ %13 = OpLoad %ulong %3
+ %12 = OpCopyObject %ulong %13
+ OpStore %5 %12
+ %15 = OpLoad %ulong %4
+ %21 = OpConvertUToPtr %_ptr_Generic_ulong %15
+ %14 = OpLoad %ulong %21
+ OpStore %6 %14
+ %17 = OpLoad %ulong %6
+ %16 = OpShiftLeftLogical %ulong %17 %uint_2
+ OpStore %7 %16
+ %18 = OpLoad %ulong %5
+ %19 = OpLoad %ulong %7
+ %22 = OpConvertUToPtr %_ptr_Generic_ulong %18
+ OpStore %22 %19
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 34d8c12..bd37b14 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,6 +1,6 @@
use crate::ast;
use rspirv::dr;
-use std::collections::{HashMap, HashSet};
+use std::collections::{hash_map, HashMap, HashSet};
use std::{borrow::Cow, iter, mem};
use rspirv::binary::Assemble;
@@ -10,6 +10,7 @@ enum SpirvType {
Base(SpirvScalarKey),
Array(SpirvScalarKey, u32),
Pointer(Box<SpirvType>, spirv::StorageClass),
+ Func(Option<Box<SpirvType>>, Vec<SpirvType>),
}
impl SpirvType {
@@ -141,16 +142,44 @@ impl TypeWordMap {
.entry(t)
.or_insert_with(|| b.type_array(base, len))
}
+ SpirvType::Func(ref out_params, ref in_params) => {
+ let out_t = match out_params {
+ Some(p) => self.get_or_add(b, *p.clone()),
+ None => self.void(),
+ };
+ let in_t = in_params
+ .iter()
+ .map(|t| self.get_or_add(b, t.clone()))
+ .collect::<Vec<_>>();
+ *self
+ .complex
+ .entry(t)
+ .or_insert_with(|| b.type_function(out_t, in_t))
+ }
}
}
- fn get_or_add_fn<Args: Iterator<Item = SpirvType>>(
+ fn get_or_add_fn(
&mut self,
b: &mut dr::Builder,
- args: Args,
- ) -> spirv::Word {
- let params = args.map(|a| self.get_or_add(b, a)).collect::<Vec<_>>();
- b.type_function(self.void(), params)
+ mut out_params: impl ExactSizeIterator<Item = SpirvType>,
+ in_params: impl ExactSizeIterator<Item = SpirvType>,
+ ) -> (spirv::Word, spirv::Word) {
+ let (out_args, out_spirv_type) = if out_params.len() == 0 {
+ (None, self.void())
+ } else if out_params.len() == 1 {
+ let arg_as_key = out_params.next().unwrap();
+ (
+ Some(Box::new(arg_as_key.clone())),
+ self.get_or_add(b, arg_as_key),
+ )
+ } else {
+ todo!()
+ };
+ (
+ out_spirv_type,
+ self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::<Vec<_>>())),
+ )
}
}
@@ -171,29 +200,31 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, dr::Error
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
for f in ssa_functions {
- emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive, &*f.args)?;
- emit_function_args(&mut builder, &mut map, &*f.args);
- emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.body)?;
+ let f_body = match f.body {
+ Some(f) => f,
+ None => continue,
+ };
+ emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive)?;
+ emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
builder.end_function()?;
}
Ok(builder.module())
}
-fn emit_function_header(
+fn emit_function_header<'a>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- global: &GlobalStringIdResolver,
- func_directive: ast::FunctionHeader<ExpandedArgParams>,
- params: &[ast::Argument<ExpandedArgParams>],
+ global: &GlobalStringIdResolver<'a>,
+ func_directive: ast::MethodDecl<ExpandedArgParams>,
) -> Result<(), dr::Error> {
- let func_type = get_function_type(builder, map, params);
- let (fn_id, ret_type) = match func_directive {
- ast::FunctionHeader::Kernel(name) => {
+ let (ret_type, func_type) = get_function_type(builder, map, &func_directive);
+ let fn_id = match func_directive {
+ ast::MethodDecl::Kernel(name, _) => {
let fn_id = global.get_id(name);
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, &[]);
- (fn_id, map.void())
+ fn_id
}
- ast::FunctionHeader::Func(params, name) => todo!(),
+ ast::MethodDecl::Func(_, name, _) => name,
};
builder.begin_function(
ret_type,
@@ -201,6 +232,16 @@ fn emit_function_header(
spirv::FunctionControl::NONE,
func_type,
)?;
+ func_directive.visit_args(|arg| {
+ let result_type = map.get_or_add_scalar(builder, arg.a_type);
+ let inst = dr::Instruction::new(
+ spirv::Op::FunctionParameter,
+ Some(result_type),
+ Some(arg.name),
+ Vec::new(),
+ );
+ builder.function.as_mut().unwrap().parameters.push(inst);
+ });
Ok(())
}
@@ -235,50 +276,116 @@ fn to_ssa_function<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
f: ast::ParsedFunction<'a>,
) -> ExpandedFunction<'a> {
- let mut fn_resolver = FnStringIdResolver::new(id_defs, f.func_directive.name());
- let f_header = match f.func_directive {
- ast::FunctionHeader::Kernel(name) => ast::FunctionHeader::Kernel(name),
- ast::FunctionHeader::Func(ret_params, name) => {
- let name_id = fn_resolver.add_global_def(name);
- let ret_ids = expand_fn_params(&mut fn_resolver, ret_params);
- ast::FunctionHeader::Func(ret_ids, name_id)
- }
- };
- let f_args = expand_fn_params(&mut fn_resolver, f.args);
- let f_body = Some(to_ssa(fn_resolver, f.body.unwrap_or_else(|| Vec::new())));
- ExpandedFunction {
- func_directive: f_header,
- args: f_args,
- body: f_body,
- }
+ let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive);
+ to_ssa(str_resolver, fn_resolver, fn_decl, f.body)
}
-fn expand_fn_params<'a, 'b>(
+fn expand_kernel_params<'a, 'b>(
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: Vec<ast::Argument<ast::ParsedArgParams<'a>>>,
-) -> Vec<ast::Argument<ExpandedArgParams>> {
- args.into_iter()
- .map(|a| ast::Argument {
- name: fn_resolver.add_def(a.name, Some(ast::Type::Scalar(a.a_type))),
- a_type: a.a_type,
- length: a.length,
- })
- .collect()
+ args: impl Iterator<Item = &'b ast::KernelArgument<ast::ParsedArgParams<'a>>>,
+) -> Vec<ast::KernelArgument<ExpandedArgParams>> {
+ args.map(|a| ast::KernelArgument {
+ name: fn_resolver.add_def(a.name, Some(ast::Type::Scalar(a.a_type))),
+ a_type: a.a_type,
+ length: a.length,
+ })
+ .collect()
}
-fn to_ssa<'a, 'b>(
- mut id_defs: FnStringIdResolver<'a, 'b>,
- f_body: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
-) -> Vec<ExpandedStatement> {
- let normalized_ids = normalize_identifiers(&mut id_defs, f_body);
+fn expand_fn_params<'a, 'b>(
+ fn_resolver: &mut FnStringIdResolver<'a, 'b>,
+ args: impl Iterator<Item = &'b ast::FnArgument<ast::ParsedArgParams<'a>>>,
+) -> Vec<ast::FnArgument<ExpandedArgParams>> {
+ args.map(|a| ast::FnArgument {
+ state_space: a.state_space,
+ base: ast::KernelArgument {
+ name: fn_resolver.add_def(a.base.name, Some(ast::Type::Scalar(a.base.a_type))),
+ a_type: a.base.a_type,
+ length: a.base.length,
+ },
+ })
+ .collect()
+}
+
+fn to_ssa<'input, 'b>(
+ mut id_defs: FnStringIdResolver<'input, 'b>,
+ fn_defs: GlobalFnDeclResolver<'input, 'b>,
+ f_args: ast::MethodDecl<'input, ExpandedArgParams>,
+ f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
+) -> ExpandedFunction<'input> {
+ let f_body = match f_body {
+ Some(vec) => vec,
+ None => {
+ return ExpandedFunction {
+ func_directive: f_args,
+ body: None,
+ }
+ }
+ };
+ let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body);
let mut numeric_id_defs = id_defs.finish();
- let normalized_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
- let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut numeric_id_defs);
+ let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
+ let unadorned_statements = resolve_fn_calls(&fn_defs, unadorned_statements);
+ let (f_args, ssa_statements) =
+ insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args);
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs);
let expanded_statements =
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs);
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
- labeled_statements
+ let sorted_statements = normalize_variable_decls(labeled_statements);
+ ExpandedFunction {
+ func_directive: f_args,
+ body: Some(sorted_statements),
+ }
+}
+
+fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedStatement> {
+ func[1..].sort_by_key(|s| match s {
+ Statement::Variable(_) => 0,
+ _ => 1,
+ });
+ func
+}
+
+fn resolve_fn_calls(
+ fn_defs: &GlobalFnDeclResolver,
+ func: Vec<UnadornedStatement>,
+) -> Vec<UnadornedStatement> {
+ func.into_iter()
+ .map(|s| {
+ match s {
+ Statement::Instruction(ast::Instruction::Call(call)) => {
+ // TODO: error out if lengths don't match
+ let fn_def = fn_defs.get_fn_decl(call.func);
+ let ret_params = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
+ let param_list = to_resolved_fn_args(call.param_list, &*fn_def.params);
+ let resolved_call = ResolvedCall {
+ uniform: call.uniform,
+ ret_params,
+ func: call.func,
+ param_list,
+ };
+ Statement::Call(resolved_call)
+ }
+ s => s,
+ }
+ })
+ .collect()
+}
+
+fn to_resolved_fn_args<T>(
+ params: Vec<T>,
+ params_decl: &[(ast::FnArgStateSpace, ast::ScalarType)],
+) -> Vec<ArgCall<T>> {
+ params
+ .into_iter()
+ .zip(params_decl.iter())
+ .map(|(id, &(space, typ))| ArgCall {
+ id,
+ typ: ast::Type::Scalar(typ),
+ space: space,
+ })
+ .collect::<Vec<_>>()
}
fn normalize_labels(
@@ -297,9 +404,11 @@ fn normalize_labels(
labels_in_use.insert(cond.if_true);
labels_in_use.insert(cond.if_false);
}
- Statement::Variable(_, _, _, _)
+ Statement::Call(_)
+ | Statement::Variable(_)
| Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
+ | Statement::RetValue(_, _)
| Statement::Conversion(_)
| Statement::Constant(_)
| Statement::Label(_) => (),
@@ -314,14 +423,14 @@ fn normalize_labels(
}
fn normalize_predicates(
- func: Vec<ast::Statement<NormalizedArgParams>>,
+ func: Vec<NormalizedStatement>,
id_def: &mut NumericIdResolver,
-) -> Vec<NormalizedStatement> {
+) -> Vec<UnadornedStatement> {
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
- ast::Statement::Label(id) => result.push(Statement::Label(id)),
- ast::Statement::Instruction(pred, inst) => {
+ Statement::Label(id) => result.push(Statement::Label(id)),
+ Statement::Instruction((pred, inst)) => {
if let Some(pred) = pred {
let if_true = id_def.new_id(None);
let if_false = id_def.new_id(None);
@@ -347,66 +456,100 @@ fn normalize_predicates(
result.push(Statement::Instruction(inst));
}
}
- ast::Statement::Variable(var) => result.push(Statement::Variable(
- var.name, var.v_type, var.space, var.align,
- )),
+ Statement::Variable(var) => result.push(Statement::Variable(var)),
// Blocks are flattened when resolving ids
- ast::Statement::Block(_) => unreachable!(),
+ _ => unreachable!(),
}
}
result
}
-fn insert_mem_ssa_statements(
- func: Vec<NormalizedStatement>,
+fn insert_mem_ssa_statements<'a, 'b>(
+ func: Vec<UnadornedStatement>,
id_def: &mut NumericIdResolver,
-) -> Vec<NormalizedStatement> {
+ mut f_args: ast::MethodDecl<'a, ExpandedArgParams>,
+) -> (
+ ast::MethodDecl<'a, ExpandedArgParams>,
+ Vec<UnadornedStatement>,
+) {
let mut result = Vec::with_capacity(func.len());
+ let out_param = match &mut f_args {
+ ast::MethodDecl::Kernel(_, in_params) => {
+ for p in in_params.iter_mut() {
+ let typ = ast::Type::Scalar(p.a_type);
+ let new_id = id_def.new_id(Some(typ));
+ result.push(Statement::Variable(VariableDecl {
+ space: ast::StateSpace::Reg,
+ align: None,
+ v_type: typ,
+ name: p.name,
+ }));
+ result.push(Statement::StoreVar(
+ ast::Arg2St {
+ src1: p.name,
+ src2: new_id,
+ },
+ typ,
+ ));
+ p.name = new_id;
+ }
+ None
+ }
+ ast::MethodDecl::Func(out_params, _, in_params) => {
+ for p in in_params.iter_mut() {
+ let typ = ast::Type::Scalar(p.base.a_type);
+ let new_id = id_def.new_id(Some(typ));
+ result.push(Statement::Variable(VariableDecl {
+ space: ast::StateSpace::Reg,
+ align: None,
+ v_type: typ,
+ name: p.base.name,
+ }));
+ result.push(Statement::StoreVar(
+ ast::Arg2St {
+ src1: p.base.name,
+ src2: new_id,
+ },
+ typ,
+ ));
+ p.base.name = new_id;
+ }
+ match &mut **out_params {
+ [p] => {
+ result.push(Statement::Variable(VariableDecl {
+ space: ast::StateSpace::Reg,
+ align: None,
+ v_type: ast::Type::Scalar(p.base.a_type),
+ name: p.base.name,
+ }));
+ Some(p.base.name)
+ }
+ [] => None,
+ _ => todo!(),
+ }
+ }
+ };
for s in func {
match s {
+ Statement::Call(call) => insert_mem_ssa_statement_default(id_def, &mut result, call),
Statement::Instruction(inst) => match inst {
- ast::Instruction::Ld(
- ld
- @
- ast::LdData {
- state_space: ast::LdStateSpace::Param,
- ..
- },
- arg,
- ) => {
- result.push(Statement::Instruction(ast::Instruction::Ld(ld, arg)));
- }
- inst => {
- let mut post_statements = Vec::new();
- let inst = inst.visit_variable(&mut |desc| {
- let id_type = match (desc.typ, desc.is_pointer) {
- (Some(t), false) => t,
- (Some(_), true) => ast::Type::Scalar(ast::ScalarType::B64),
- (None, _) => return desc.op,
- };
- let generated_id = id_def.new_id(Some(id_type));
- if !desc.is_dst {
- result.push(Statement::LoadVar(
- Arg2 {
- dst: generated_id,
- src: desc.op,
- },
- id_type,
- ));
- } else {
- post_statements.push(Statement::StoreVar(
- Arg2St {
- src1: desc.op,
- src2: generated_id,
- },
- id_type,
- ));
- }
- generated_id
- });
- result.push(Statement::Instruction(inst));
- result.append(&mut post_statements);
+ ast::Instruction::Ret(d) => {
+ if let Some(out_param) = out_param {
+ let typ = id_def.get_type(out_param);
+ let new_id = id_def.new_id(Some(typ));
+ result.push(Statement::LoadVar(
+ ast::Arg2 {
+ dst: new_id,
+ src: out_param,
+ },
+ typ,
+ ));
+ result.push(Statement::RetValue(d, new_id));
+ } else {
+ result.push(Statement::Instruction(ast::Instruction::Ret(d)))
+ }
}
+ inst => insert_mem_ssa_statement_default(id_def, &mut result, inst),
},
Statement::Conditional(mut bra) => {
let generated_id = id_def.new_id(Some(ast::Type::ExtendedScalar(
@@ -422,65 +565,116 @@ fn insert_mem_ssa_statements(
bra.predicate = generated_id;
result.push(Statement::Conditional(bra));
}
- s @ Statement::Variable(_, _, _, _) | s @ Statement::Label(_) => result.push(s),
+ s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
Statement::LoadVar(_, _)
| Statement::StoreVar(_, _)
| Statement::Conversion(_)
+ | Statement::RetValue(_, _)
| Statement::Constant(_) => unreachable!(),
}
}
- result
+ (f_args, result)
}
-fn expand_arguments<'a, 'b, 'c>(
- func: Vec<NormalizedStatement>,
- id_def: &'c mut NumericIdResolver<'a, 'b>,
+trait VisitVariable: Sized {
+ fn visit_variable<'a, F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ self,
+ f: &mut F,
+ ) -> UnadornedStatement;
+}
+trait VisitVariableExpanded {
+ fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ self,
+ f: &mut F,
+ ) -> ExpandedStatement;
+}
+
+fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
+ id_def: &mut NumericIdResolver,
+ result: &mut Vec<UnadornedStatement>,
+ stmt: F,
+) {
+ let mut post_statements = Vec::new();
+ let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>| {
+ let id_type = match (desc.typ, desc.is_pointer) {
+ (Some(t), false) => t,
+ (Some(_), true) => ast::Type::Scalar(ast::ScalarType::B64),
+ (None, _) => return desc.op,
+ };
+ let generated_id = id_def.new_id(Some(id_type));
+ if !desc.is_dst {
+ result.push(Statement::LoadVar(
+ Arg2 {
+ dst: generated_id,
+ src: desc.op,
+ },
+ id_type,
+ ));
+ } else {
+ post_statements.push(Statement::StoreVar(
+ Arg2St {
+ src1: desc.op,
+ src2: generated_id,
+ },
+ id_type,
+ ));
+ }
+ generated_id
+ });
+ result.push(new_statement);
+ result.append(&mut post_statements);
+}
+
+fn expand_arguments<'a, 'b>(
+ func: Vec<UnadornedStatement>,
+ id_def: &'b mut NumericIdResolver<'a>,
) -> Vec<ExpandedStatement> {
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
+ Statement::Call(call) => {
+ let mut visitor = FlattenArguments::new(&mut result, id_def);
+ let new_call = call.map(&mut visitor);
+ result.push(Statement::Call(new_call));
+ }
Statement::Instruction(inst) => {
let mut visitor = FlattenArguments::new(&mut result, id_def);
let new_inst = inst.map(&mut visitor);
result.push(Statement::Instruction(new_inst));
}
- Statement::Variable(id, typ, ss, align) => {
- result.push(Statement::Variable(id, typ, ss, align))
- }
+ Statement::Variable(v_decl) => result.push(Statement::Variable(v_decl)),
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
+ Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
Statement::Conversion(_) | Statement::Constant(_) => unreachable!(),
}
}
result
}
-struct FlattenArguments<'a, 'b, 'c> {
- func: &'c mut Vec<ExpandedStatement>,
- id_def: &'c mut NumericIdResolver<'a, 'b>,
+struct FlattenArguments<'a, 'b> {
+ func: &'b mut Vec<ExpandedStatement>,
+ id_def: &'b mut NumericIdResolver<'a>,
}
-impl<'a, 'b, 'c> FlattenArguments<'a, 'b, 'c> {
- fn new(
- func: &'c mut Vec<ExpandedStatement>,
- id_def: &'c mut NumericIdResolver<'a, 'b>,
- ) -> Self {
+impl<'a, 'b> FlattenArguments<'a, 'b> {
+ fn new(func: &'b mut Vec<ExpandedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
FlattenArguments { func, id_def }
}
}
-impl<'a, 'b, 'c> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
- for FlattenArguments<'a, 'b, 'c>
+impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
+ for FlattenArguments<'a, 'b>
{
- fn dst_variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
+ fn variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
desc.op
}
- fn src_operand(&mut self, desc: ArgumentDescriptor<ast::Operand<spirv::Word>>) -> spirv::Word {
+ fn operand(&mut self, desc: ArgumentDescriptor<ast::Operand<spirv::Word>>) -> spirv::Word {
match desc.op {
- ast::Operand::Reg(r) => r,
+ ast::Operand::Reg(r) => self.variable(desc.new_op(r)),
ast::Operand::Imm(x) => {
if let Some(typ) = desc.typ {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
@@ -535,12 +729,22 @@ impl<'a, 'b, 'c> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
}
}
+ fn src_call_operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
+ ) -> spirv::Word {
+ match desc.op {
+ ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg)),
+ ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x))),
+ }
+ }
+
fn src_mov_operand(
&mut self,
desc: ArgumentDescriptor<ast::MovOperand<spirv::Word>>,
) -> spirv::Word {
match &desc.op {
- ast::MovOperand::Op(opr) => self.src_operand(desc.new_op(*opr)),
+ ast::MovOperand::Op(opr) => self.operand(desc.new_op(*opr)),
ast::MovOperand::Vec(_, _) => todo!(),
}
}
@@ -553,11 +757,12 @@ impl<'a, 'b, 'c> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
- ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
semantics are to first zext/chop/bitcast `y` as needed and then do
documented special ld/st/cvt conversion rules for destination operands
- - generic ld: for instruction `ld x, [y]`, y must be of type b64/u64/s64,
- which is bitcast to a pointer, dereferenced and then documented special
- ld/st/cvt conversion rules are applied to dst
- - generic st: for instruction `st [x], y`, x must be of type b64/u64/s64,
- which is bitcast to a pointer
+ - st.param [x] y (used as function return arguments) same rule as above applies
+ - generic/global ld: for instruction `ld x, [y]`, y must be of type
+ b64/u64/s64, which is bitcast to a pointer, dereferenced and then
+ documented special ld/st/cvt conversion rules are applied to dst
+ - generic/global st: for instruction `st [x], y`, x must be of type
+ b64/u64/s64, which is bitcast to a pointer
*/
fn insert_implicit_conversions(
func: Vec<ExpandedStatement>,
@@ -566,6 +771,7 @@ fn insert_implicit_conversions(
let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
match s {
+ Statement::Call(call) => insert_implicit_bitcasts(&mut result, id_def, call),
Statement::Instruction(inst) => match inst {
ast::Instruction::Ld(ld, mut arg) => {
arg.src = insert_implicit_conversions_ld_src(
@@ -611,9 +817,10 @@ fn insert_implicit_conversions(
s @ Statement::Conditional(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
- | s @ Statement::Variable(_, _, _, _)
+ | s @ Statement::Variable(_)
| s @ Statement::LoadVar(_, _)
- | s @ Statement::StoreVar(_, _) => result.push(s),
+ | s @ Statement::StoreVar(_, _)
+ | s @ Statement::RetValue(_, _) => result.push(s),
Statement::Conversion(_) => unreachable!(),
}
}
@@ -623,25 +830,19 @@ fn insert_implicit_conversions(
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- args: &[ast::Argument<ExpandedArgParams>],
-) -> spirv::Word {
- map.get_or_add_fn(builder, args.iter().map(|arg| SpirvType::from(arg.a_type)))
-}
-
-fn emit_function_args(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- args: &[ast::Argument<ExpandedArgParams>],
-) {
- for arg in args {
- let result_type = map.get_or_add_scalar(builder, arg.a_type);
- let inst = dr::Instruction::new(
- spirv::Op::FunctionParameter,
- Some(result_type),
- Some(arg.name),
- Vec::new(),
- );
- builder.function.as_mut().unwrap().parameters.push(inst);
+ method_decl: &ast::MethodDecl<ExpandedArgParams>,
+) -> (spirv::Word, spirv::Word) {
+ match method_decl {
+ ast::MethodDecl::Func(out_params, _, in_params) => map.get_or_add_fn(
+ builder,
+ out_params.iter().map(|p| SpirvType::from(p.base.a_type)),
+ in_params.iter().map(|p| SpirvType::from(p.base.a_type)),
+ ),
+ ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn(
+ builder,
+ iter::empty(),
+ params.iter().map(|p| SpirvType::from(p.a_type)),
+ ),
}
}
@@ -649,9 +850,9 @@ fn emit_function_body_ops(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
- func: &Option<Vec<ExpandedStatement>>,
+ func: &[ExpandedStatement],
) -> Result<(), dr::Error> {
- for s in func.as_ref().unwrap() {
+ for s in func {
match s {
Statement::Label(id) => {
if builder.block.is_some() {
@@ -667,13 +868,26 @@ fn emit_function_body_ops(
}
match s {
Statement::Label(_) => (),
- Statement::Variable(id, typ, ss, align) => {
+ Statement::Call(call) => {
+ let (result_type, result_id) = match &*call.ret_params {
+ [p] => (map.get_or_add(builder, SpirvType::from(p.typ)), p.id),
+ _ => todo!(),
+ };
+ let arg_list = call.param_list.iter().map(|p| p.id).collect::<Vec<_>>();
+ builder.function_call(result_type, Some(result_id), call.func, arg_list)?;
+ }
+ Statement::Variable(VariableDecl {
+ name: id,
+ v_type: typ,
+ space: ss,
+ align,
+ }) => {
let type_id = map.get_or_add(
builder,
SpirvType::new_pointer(*typ, spirv::StorageClass::Function),
);
let st_class = match ss {
- ast::StateSpace::Reg => spirv::StorageClass::Function,
+ ast::StateSpace::Reg | ast::StateSpace::Param => spirv::StorageClass::Function,
ast::StateSpace::Local => spirv::StorageClass::Workgroup,
_ => todo!(),
};
@@ -722,7 +936,7 @@ fn emit_function_body_ops(
}
Statement::Instruction(inst) => match inst {
ast::Instruction::Abs(_, _) => todo!(),
- ast::Instruction::Call(_, _) => todo!(),
+ ast::Instruction::Call(_) => unreachable!(),
// SPIR-V does not support marking jumps as guaranteed-converged
ast::Instruction::Bra(_, arg) => {
builder.branch(arg.src)?;
@@ -737,7 +951,8 @@ fn emit_function_body_ops(
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
}
ast::LdStateSpace::Param => {
- builder.store(arg.dst, arg.src, None, [])?;
+ let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
+ builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
_ => todo!(),
}
@@ -746,11 +961,17 @@ fn emit_function_body_ops(
if data.qualifier != ast::LdStQualifier::Weak
|| data.vector.is_some()
|| (data.state_space != ast::StStateSpace::Generic
+ && data.state_space != ast::StStateSpace::Param
&& data.state_space != ast::StStateSpace::Global)
{
todo!()
}
- builder.store(arg.src1, arg.src2, None, &[])?;
+ if data.state_space == ast::StStateSpace::Param {
+ let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
+ builder.copy_object(result_type, Some(arg.src1), arg.src2)?;
+ } else {
+ builder.store(arg.src1, arg.src2, None, &[])?;
+ }
}
// SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?,
@@ -808,6 +1029,9 @@ fn emit_function_body_ops(
Statement::StoreVar(arg, _) => {
builder.store(arg.src1, arg.src2, None, [])?;
}
+ Statement::RetValue(_, id) => {
+ builder.ret_value(*id)?;
+ }
}
}
Ok(())
@@ -1123,8 +1347,9 @@ fn emit_implicit_conversion(
// TODO: support scopes
fn normalize_identifiers<'a, 'b>(
id_defs: &mut FnStringIdResolver<'a, 'b>,
+ fn_defs: &GlobalFnDeclResolver<'a, 'b>,
func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
-) -> Vec<ast::Statement<NormalizedArgParams>> {
+) -> Vec<NormalizedStatement> {
for s in func.iter() {
match s {
ast::Statement::Label(id) => {
@@ -1135,58 +1360,63 @@ fn normalize_identifiers<'a, 'b>(
}
let mut result = Vec::new();
for s in func {
- expand_map_variables(id_defs, &mut result, s);
+ expand_map_variables(id_defs, fn_defs, &mut result, s);
}
result
}
fn expand_map_variables<'a, 'b>(
id_defs: &mut FnStringIdResolver<'a, 'b>,
- result: &mut Vec<ast::Statement<NormalizedArgParams>>,
+ fn_defs: &GlobalFnDeclResolver,
+ result: &mut Vec<NormalizedStatement>,
s: ast::Statement<ast::ParsedArgParams<'a>>,
) {
match s {
ast::Statement::Block(block) => {
id_defs.start_block();
for s in block {
- expand_map_variables(id_defs, result, s);
+ expand_map_variables(id_defs, fn_defs, result, s);
}
id_defs.end_block();
}
- ast::Statement::Label(name) => result.push(ast::Statement::Label(id_defs.get_id(name))),
- ast::Statement::Instruction(p, i) => result.push(ast::Statement::Instruction(
+ ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name))),
+ ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))),
i.map_variable(&mut |id| id_defs.get_id(id)),
- )),
+ ))),
ast::Statement::Variable(var) => match var.count {
Some(count) => {
for new_id in id_defs.add_defs(var.name, count, var.v_type) {
- result.push(ast::Statement::Variable(ast::Variable {
+ result.push(Statement::Variable(VariableDecl {
space: var.space,
align: var.align,
v_type: var.v_type,
name: new_id,
- count: None,
}))
}
}
None => {
let new_id = id_defs.add_def(var.name, Some(var.v_type));
- result.push(ast::Statement::Variable(ast::Variable {
+ result.push(Statement::Variable(VariableDecl {
space: var.space,
align: var.align,
v_type: var.v_type,
name: new_id,
- count: None,
}));
}
},
}
}
-struct GlobalStringIdResolver<'a> {
+struct GlobalStringIdResolver<'input> {
current_id: spirv::Word,
- variables: HashMap<Cow<'a, str>, spirv::Word>,
+ variables: HashMap<Cow<'input, str>, spirv::Word>,
+ fns: HashMap<spirv::Word, FnDecl>,
+}
+
+pub struct FnDecl {
+ ret_vals: Vec<(ast::FnArgStateSpace, ast::ScalarType)>,
+ params: Vec<(ast::FnArgStateSpace, ast::ScalarType)>,
}
impl<'a> GlobalStringIdResolver<'a> {
@@ -1194,14 +1424,20 @@ impl<'a> GlobalStringIdResolver<'a> {
Self {
current_id: start_id,
variables: HashMap::new(),
+ fns: HashMap::new(),
}
}
- fn add_def(&mut self, id: &'a str) -> spirv::Word {
- let numeric_id = self.current_id;
- self.variables.insert(Cow::Borrowed(id), numeric_id);
- self.current_id += 1;
- numeric_id
+ fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word {
+ match self.variables.entry(Cow::Borrowed(id)) {
+ hash_map::Entry::Occupied(e) => *(e.get()),
+ hash_map::Entry::Vacant(e) => {
+ let numeric_id = self.current_id;
+ e.insert(numeric_id);
+ self.current_id += 1;
+ numeric_id
+ }
+ }
}
fn get_id(&self, id: &str) -> spirv::Word {
@@ -1211,27 +1447,84 @@ impl<'a> GlobalStringIdResolver<'a> {
fn current_id(&self) -> spirv::Word {
self.current_id
}
+
+ fn start_fn<'b>(
+ &'b mut self,
+ header: &'b ast::MethodDecl<'a, ast::ParsedArgParams<'a>>,
+ ) -> (
+ FnStringIdResolver<'a, 'b>,
+ GlobalFnDeclResolver<'a, 'b>,
+ ast::MethodDecl<'a, ExpandedArgParams>,
+ ) {
+ // In case a function decl was inserted eearlier we want to use its id
+ let name_id = self.get_or_add_def(header.name());
+ let mut fn_resolver = FnStringIdResolver {
+ current_id: &mut self.current_id,
+ global_variables: &self.variables,
+ variables: vec![HashMap::new(); 1],
+ type_check: HashMap::new(),
+ };
+ let new_fn_decl = match header {
+ ast::MethodDecl::Kernel(name, params) => {
+ ast::MethodDecl::Kernel(name, expand_kernel_params(&mut fn_resolver, params.iter()))
+ }
+ ast::MethodDecl::Func(ret_params, _, params) => {
+ let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter());
+ let params_ids = expand_fn_params(&mut fn_resolver, params.iter());
+ self.fns.insert(
+ name_id,
+ FnDecl {
+ ret_vals: ret_params_ids
+ .iter()
+ .map(|p| (p.state_space, p.base.a_type))
+ .collect(),
+ params: params_ids
+ .iter()
+ .map(|p| (p.state_space, p.base.a_type))
+ .collect(),
+ },
+ );
+ ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
+ }
+ };
+ (
+ fn_resolver,
+ GlobalFnDeclResolver {
+ variables: &self.variables,
+ fns: &self.fns,
+ },
+ new_fn_decl,
+ )
+ }
}
-struct FnStringIdResolver<'a, 'b> {
- global: &'b mut GlobalStringIdResolver<'a>,
- variables: Vec<HashMap<Cow<'a, str>, spirv::Word>>,
- type_check: HashMap<u32, ast::Type>,
+pub struct GlobalFnDeclResolver<'input, 'a> {
+ variables: &'a HashMap<Cow<'input, str>, spirv::Word>,
+ fns: &'a HashMap<spirv::Word, FnDecl>,
}
-impl<'a, 'b> FnStringIdResolver<'a, 'b> {
- fn new(global: &'b mut GlobalStringIdResolver<'a>, f_name: &'a str) -> Self {
- global.add_def(f_name);
- Self {
- global: global,
- variables: vec![HashMap::new(); 1],
- type_check: HashMap::new(),
- }
+impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
+ fn get_fn_decl(&self, id: spirv::Word) -> &FnDecl {
+ &self.fns[&id]
+ }
+
+ fn get_fn_decl_str(&self, id: &str) -> &'a FnDecl {
+ &self.fns[&self.variables[id]]
}
+}
+
+struct FnStringIdResolver<'input, 'b> {
+ current_id: &'b mut spirv::Word,
+ global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
+ //global: &'b mut GlobalStringIdResolver<'a>,
+ variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
+ type_check: HashMap<u32, ast::Type>,
+}
- fn finish(self) -> NumericIdResolver<'a, 'b> {
+impl<'a, 'b> FnStringIdResolver<'a, 'b> {
+ fn finish(self) -> NumericIdResolver<'b> {
NumericIdResolver {
- global: self.global,
+ current_id: self.current_id,
type_check: self.type_check,
}
}
@@ -1251,15 +1544,11 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
None => continue,
}
}
- self.global.variables[id]
- }
-
- fn add_global_def(&mut self, id: &'a str) -> spirv::Word {
- self.global.add_def(id)
+ self.global_variables[id]
}
fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
- let numeric_id = self.global.current_id;
+ let numeric_id = *self.current_id;
self.variables
.last_mut()
.unwrap()
@@ -1267,7 +1556,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
if let Some(typ) = typ {
self.type_check.insert(numeric_id, typ);
}
- self.global.current_id += 1;
+ *self.current_id += 1;
numeric_id
}
@@ -1278,7 +1567,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
count: u32,
typ: ast::Type,
) -> impl Iterator<Item = spirv::Word> {
- let numeric_id = self.global.current_id;
+ let numeric_id = *self.current_id;
for i in 0..count {
self.variables
.last_mut()
@@ -1286,65 +1575,191 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
self.type_check.insert(numeric_id + i, typ);
}
- self.global.current_id += count;
+ *self.current_id += count;
(0..count).into_iter().map(move |i| i + numeric_id)
}
}
-struct NumericIdResolver<'a, 'b> {
- global: &'b mut GlobalStringIdResolver<'a>,
+struct NumericIdResolver<'b> {
+ current_id: &'b mut spirv::Word,
type_check: HashMap<u32, ast::Type>,
}
-impl<'a, 'b> NumericIdResolver<'a, 'b> {
+impl<'b> NumericIdResolver<'b> {
fn get_type(&self, id: spirv::Word) -> ast::Type {
self.type_check[&id]
}
fn new_id(&mut self, typ: Option<ast::Type>) -> spirv::Word {
- let new_id = self.global.current_id;
+ let new_id = *self.current_id;
if let Some(typ) = typ {
self.type_check.insert(new_id, typ);
}
- self.global.current_id += 1;
+ *self.current_id += 1;
new_id
}
}
-enum Statement<I> {
- Variable(spirv::Word, ast::Type, ast::StateSpace, Option<u32>),
- LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
- StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
+enum Statement<I, P: ast::ArgParams> {
Label(u32),
+ Variable(VariableDecl),
Instruction(I),
+ LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
+ StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
+ Call(ResolvedCall<P>),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
Conversion(ImplicitConversion),
Constant(ConstantDefinition),
+ RetValue(ast::RetData, spirv::Word),
+}
+
+struct VariableDecl {
+ pub space: ast::StateSpace,
+ pub align: Option<u32>,
+ pub v_type: ast::Type,
+ pub name: spirv::Word,
+}
+
+struct ResolvedCall<P: ast::ArgParams> {
+ pub uniform: bool,
+ pub ret_params: Vec<ArgCall<spirv::Word>>,
+ pub func: spirv::Word,
+ pub param_list: Vec<ArgCall<P::CallOperand>>,
+}
+
+impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
+ fn map<To: ArgParamsEx<ID = spirv::Word>, V: ArgumentMapVisitor<From, To>>(
+ self,
+ visitor: &mut V,
+ ) -> ResolvedCall<To> {
+ let ret_params = self
+ .ret_params
+ .into_iter()
+ .map(|p| {
+ let new_id = visitor.variable(ArgumentDescriptor {
+ op: p.id,
+ typ: Some(p.typ),
+ is_dst: true,
+ is_pointer: false,
+ });
+ ArgCall {
+ id: new_id,
+ typ: p.typ,
+ space: p.space,
+ }
+ })
+ .collect();
+ let func = visitor.variable(ArgumentDescriptor {
+ op: self.func,
+ typ: None,
+ is_dst: false,
+ is_pointer: false,
+ });
+ let param_list = self
+ .param_list
+ .into_iter()
+ .map(|p| {
+ let new_id = visitor.src_call_operand(ArgumentDescriptor {
+ op: p.id,
+ typ: Some(p.typ),
+ is_dst: false,
+ is_pointer: false,
+ });
+ ArgCall {
+ id: new_id,
+ typ: p.typ,
+ space: p.space,
+ }
+ })
+ .collect();
+ ResolvedCall {
+ uniform: self.uniform,
+ ret_params,
+ func,
+ param_list,
+ }
+ }
+}
+
+impl VisitVariable for ResolvedCall<NormalizedArgParams> {
+ fn visit_variable<'a, F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ self,
+ f: &mut F,
+ ) -> UnadornedStatement {
+ Statement::Call(self.map(f))
+ }
+}
+
+impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
+ fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ self,
+ f: &mut F,
+ ) -> ExpandedStatement {
+ Statement::Call(self.map(f))
+ }
+}
+
+struct ArgCall<ID> {
+ id: ID,
+ typ: ast::Type,
+ space: ast::FnArgStateSpace,
+}
+
+pub trait ArgParamsEx: ast::ArgParams {
+ fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl;
+}
+
+impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
+ fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl {
+ decl.get_fn_decl_str(id)
+ }
}
enum NormalizedArgParams {}
-type NormalizedStatement = Statement<ast::Instruction<NormalizedArgParams>>;
+type NormalizedStatement = Statement<
+ (
+ Option<ast::PredAt<spirv::Word>>,
+ ast::Instruction<NormalizedArgParams>,
+ ),
+ NormalizedArgParams,
+>;
+type UnadornedStatement = Statement<ast::Instruction<NormalizedArgParams>, NormalizedArgParams>;
impl ast::ArgParams for NormalizedArgParams {
type ID = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
+ type CallOperand = ast::CallOperand<spirv::Word>;
type MovOperand = ast::MovOperand<spirv::Word>;
}
+impl ArgParamsEx for NormalizedArgParams {
+ fn get_fn_decl<'a, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'a, 'b>) -> &'b FnDecl {
+ decl.get_fn_decl(*id)
+ }
+}
+
enum ExpandedArgParams {}
-type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>>;
+type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>;
impl ast::ArgParams for ExpandedArgParams {
type ID = spirv::Word;
type Operand = spirv::Word;
+ type CallOperand = spirv::Word;
type MovOperand = spirv::Word;
}
-trait ArgumentMapVisitor<T: ast::ArgParams, U: ast::ArgParams> {
- fn dst_variable(&mut self, desc: ArgumentDescriptor<T::ID>) -> U::ID;
- fn src_operand(&mut self, desc: ArgumentDescriptor<T::Operand>) -> U::Operand;
+impl ArgParamsEx for ExpandedArgParams {
+ fn get_fn_decl<'a, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'a, 'b>) -> &'b FnDecl {
+ decl.get_fn_decl(*id)
+ }
+}
+
+trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
+ fn variable(&mut self, desc: ArgumentDescriptor<T::ID>) -> U::ID;
+ fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>) -> U::Operand;
+ fn src_call_operand(&mut self, desc: ArgumentDescriptor<T::CallOperand>) -> U::CallOperand;
fn src_mov_operand(&mut self, desc: ArgumentDescriptor<T::MovOperand>) -> U::MovOperand;
}
@@ -1352,12 +1767,16 @@ impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
where
T: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word,
{
- fn dst_variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
+ fn variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
self(desc)
}
- fn src_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
+ fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
self(desc)
}
+ fn src_call_operand(&mut self, mut desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
+ desc.op = self(desc.new_op(desc.op));
+ desc.op
+ }
fn src_mov_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
self(desc)
}
@@ -1367,11 +1786,11 @@ impl<'a, T> ArgumentMapVisitor<ast::ParsedArgParams<'a>, NormalizedArgParams> fo
where
T: FnMut(&str) -> spirv::Word,
{
- fn dst_variable(&mut self, desc: ArgumentDescriptor<&str>) -> spirv::Word {
+ fn variable(&mut self, desc: ArgumentDescriptor<&str>) -> spirv::Word {
self(desc.op)
}
- fn src_operand(
+ fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
) -> ast::Operand<spirv::Word> {
@@ -1382,12 +1801,22 @@ where
}
}
+ fn src_call_operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::CallOperand<&str>>,
+ ) -> ast::CallOperand<spirv::Word> {
+ match desc.op {
+ ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(id)),
+ ast::CallOperand::Imm(imm) => ast::CallOperand::Imm(imm),
+ }
+ }
+
fn src_mov_operand(
&mut self,
desc: ArgumentDescriptor<ast::MovOperand<&str>>,
) -> ast::MovOperand<spirv::Word> {
match desc.op {
- ast::MovOperand::Op(op) => ast::MovOperand::Op(self.src_operand(desc.new_op(op))),
+ ast::MovOperand::Op(op) => ast::MovOperand::Op(self.operand(desc.new_op(op))),
ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2),
}
}
@@ -1411,17 +1840,21 @@ impl<T> ArgumentDescriptor<T> {
}
}
-impl<T: ast::ArgParams> ast::Instruction<T> {
- fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
+impl<T: ArgParamsEx> ast::Instruction<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
) -> ast::Instruction<U> {
match self {
ast::Instruction::Abs(_, _) => todo!(),
- ast::Instruction::Call(_, _) => todo!(),
+ ast::Instruction::Call(_) => unreachable!(),
ast::Instruction::Ld(d, a) => {
let inst_type = d.typ;
- ast::Instruction::Ld(d, a.map_ld(visitor, Some(ast::Type::Scalar(inst_type))))
+ let src_is_pointer = d.state_space != ast::LdStateSpace::Param;
+ ast::Instruction::Ld(
+ d,
+ a.map_ld(visitor, Some(ast::Type::Scalar(inst_type)), src_is_pointer),
+ )
}
ast::Instruction::Mov(d, a) => {
let inst_type = d.typ;
@@ -1467,7 +1900,11 @@ impl<T: ast::ArgParams> ast::Instruction<T> {
}
ast::Instruction::St(d, a) => {
let inst_type = d.typ;
- ast::Instruction::St(d, a.map(visitor, Some(ast::Type::Scalar(inst_type))))
+ let param_space = d.state_space == ast::StStateSpace::Param;
+ ast::Instruction::St(
+ d,
+ a.map(visitor, Some(ast::Type::Scalar(inst_type)), param_space),
+ )
}
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
@@ -1479,12 +1916,12 @@ impl<T: ast::ArgParams> ast::Instruction<T> {
}
}
-impl ast::Instruction<NormalizedArgParams> {
- fn visit_variable<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+impl VisitVariable for ast::Instruction<NormalizedArgParams> {
+ fn visit_variable<'a, F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
self,
f: &mut F,
- ) -> ast::Instruction<NormalizedArgParams> {
- self.map(f)
+ ) -> UnadornedStatement {
+ Statement::Instruction(self.map(f))
}
}
@@ -1492,11 +1929,11 @@ impl<T> ArgumentMapVisitor<NormalizedArgParams, NormalizedArgParams> for T
where
T: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word,
{
- fn dst_variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
+ fn variable(&mut self, desc: ArgumentDescriptor<spirv::Word>) -> spirv::Word {
self(desc)
}
- fn src_operand(
+ fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
) -> ast::Operand<spirv::Word> {
@@ -1507,6 +1944,16 @@ where
}
}
+ fn src_call_operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
+ ) -> ast::CallOperand<spirv::Word> {
+ match desc.op {
+ ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(desc.new_op(id))),
+ ast::CallOperand::Imm(imm) => ast::CallOperand::Imm(imm),
+ }
+ }
+
fn src_mov_operand(
&mut self,
desc: ArgumentDescriptor<ast::MovOperand<spirv::Word>>,
@@ -1515,7 +1962,7 @@ where
ast::MovOperand::Op(op) => ast::MovOperand::Op(ArgumentMapVisitor::<
NormalizedArgParams,
NormalizedArgParams,
- >::src_operand(
+ >::operand(
self, desc.new_op(op)
)),
ast::MovOperand::Vec(x1, x2) => ast::MovOperand::Vec(x1, x2),
@@ -1524,17 +1971,8 @@ where
}
impl ast::Instruction<ExpandedArgParams> {
- fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
- self,
- f: &mut F,
- ) -> Self {
- self.map(f)
- }
-
fn jump_target(&self) -> Option<spirv::Word> {
match self {
- ast::Instruction::Abs(_, _) => todo!(),
- ast::Instruction::Call(_, _) => todo!(),
ast::Instruction::Bra(_, a) => Some(a.src),
ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _)
@@ -1547,11 +1985,22 @@ impl ast::Instruction<ExpandedArgParams> {
| ast::Instruction::Cvta(_, _)
| ast::Instruction::Shl(_, _)
| ast::Instruction::St(_, _)
- | ast::Instruction::Ret(_) => None,
+ | ast::Instruction::Ret(_)
+ | ast::Instruction::Abs(_, _)
+ | ast::Instruction::Call(_) => None,
}
}
}
+impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
+ fn visit_variable_extended<F: FnMut(ArgumentDescriptor<spirv::Word>) -> spirv::Word>(
+ self,
+ f: &mut F,
+ ) -> ExpandedStatement {
+ Statement::Instruction(self.map(f))
+ }
+}
+
type Arg2 = ast::Arg2<ExpandedArgParams>;
type Arg2St = ast::Arg2St<ExpandedArgParams>;
@@ -1592,24 +2041,38 @@ impl<T> ast::PredAt<T> {
}
}
-// REMOVE
impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> {
fn map_variable<F: FnMut(&str) -> spirv::Word>(
self,
f: &mut F,
) -> ast::Instruction<NormalizedArgParams> {
- self.map(f)
+ match self {
+ ast::Instruction::Call(call) => {
+ let call_inst = ast::CallInst {
+ uniform: call.uniform,
+ ret_params: call.ret_params.into_iter().map(|p| f(p)).collect(),
+ func: f(call.func),
+ param_list: call
+ .param_list
+ .into_iter()
+ .map(|p| p.map_variable(f))
+ .collect(),
+ };
+ ast::Instruction::Call(call_inst)
+ }
+ i => i.map(f),
+ }
}
}
-impl<T: ast::ArgParams> ast::Arg1<T> {
- fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
+impl<T: ArgParamsEx> ast::Arg1<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: Option<ast::Type>,
) -> ast::Arg1<U> {
ast::Arg1 {
- src: visitor.dst_variable(ArgumentDescriptor {
+ src: visitor.variable(ArgumentDescriptor {
op: self.src,
typ: t,
is_dst: false,
@@ -1619,20 +2082,20 @@ impl<T: ast::ArgParams> ast::Arg1<T> {
}
}
-impl<T: ast::ArgParams> ast::Arg2<T> {
- fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
+impl<T: ArgParamsEx> ast::Arg2<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: Option<ast::Type>,
) -> ast::Arg2<U> {
ast::Arg2 {
- dst: visitor.dst_variable(ArgumentDescriptor {
+ dst: visitor.variable(ArgumentDescriptor {
op: self.dst,
typ: t,
is_dst: true,
is_pointer: false,
}),
- src: visitor.src_operand(ArgumentDescriptor {
+ src: visitor.operand(ArgumentDescriptor {
op: self.src,
typ: t,
is_dst: false,
@@ -1641,41 +2104,42 @@ impl<T: ast::ArgParams> ast::Arg2<T> {
}
}
- fn map_ld<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
+ fn map_ld<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: Option<ast::Type>,
+ is_src_pointer: bool,
) -> ast::Arg2<U> {
ast::Arg2 {
- dst: visitor.dst_variable(ArgumentDescriptor {
+ dst: visitor.variable(ArgumentDescriptor {
op: self.dst,
typ: t,
is_dst: true,
is_pointer: false,
}),
- src: visitor.src_operand(ArgumentDescriptor {
+ src: visitor.operand(ArgumentDescriptor {
op: self.src,
typ: t,
is_dst: false,
- is_pointer: true,
+ is_pointer: is_src_pointer,
}),
}
}
- fn map_cvt<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
+ fn map_cvt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
dst_t: ast::Type,
src_t: ast::Type,
) -> ast::Arg2<U> {
ast::Arg2 {
- dst: visitor.dst_variable(ArgumentDescriptor {
+ dst: visitor.variable(ArgumentDescriptor {
op: self.dst,
typ: Some(dst_t),
is_dst: true,
is_pointer: false,
}),
- src: visitor.src_operand(ArgumentDescriptor {
+ src: visitor.operand(ArgumentDescriptor {
op: self.src,
typ: Some(src_t),
is_dst: false,
@@ -1685,20 +2149,21 @@ impl<T: ast::ArgParams> ast::Arg2<T> {
}
}
-impl<T: ast::ArgParams> ast::Arg2St<T> {
- fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
+impl<T: ArgParamsEx> ast::Arg2St<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: Option<ast::Type>,
+ param_space: bool,
) -> ast::Arg2St<U> {
ast::Arg2St {
- src1: visitor.src_operand(ArgumentDescriptor {
+ src1: visitor.operand(ArgumentDescriptor {
op: self.src1,
typ: t,
- is_dst: false,
- is_pointer: true,
+ is_dst: param_space,
+ is_pointer: !param_space,
}),
- src2: visitor.src_operand(ArgumentDescriptor {
+ src2: visitor.operand(ArgumentDescriptor {
op: self.src2,
typ: t,
is_dst: false,
@@ -1708,14 +2173,14 @@ impl<T: ast::ArgParams> ast::Arg2St<T> {
}
}
-impl<T: ast::ArgParams> ast::Arg2Mov<T> {
- fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
+impl<T: ArgParamsEx> ast::Arg2Mov<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: Option<ast::Type>,
) -> ast::Arg2Mov<U> {
ast::Arg2Mov {
- dst: visitor.dst_variable(ArgumentDescriptor {
+ dst: visitor.variable(ArgumentDescriptor {
op: self.dst,
typ: t,
is_dst: true,
@@ -1731,26 +2196,26 @@ impl<T: ast::ArgParams> ast::Arg2Mov<T> {
}
}
-impl<T: ast::ArgParams> ast::Arg3<T> {
- fn map_non_shift<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
+impl<T: ArgParamsEx> ast::Arg3<T> {
+ fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: Option<ast::Type>,
) -> ast::Arg3<U> {
ast::Arg3 {
- dst: visitor.dst_variable(ArgumentDescriptor {
+ dst: visitor.variable(ArgumentDescriptor {
op: self.dst,
typ: t,
is_dst: true,
is_pointer: false,
}),
- src1: visitor.src_operand(ArgumentDescriptor {
+ src1: visitor.operand(ArgumentDescriptor {
op: self.src1,
typ: t,
is_dst: false,
is_pointer: false,
}),
- src2: visitor.src_operand(ArgumentDescriptor {
+ src2: visitor.operand(ArgumentDescriptor {
op: self.src2,
typ: t,
is_dst: false,
@@ -1759,25 +2224,25 @@ impl<T: ast::ArgParams> ast::Arg3<T> {
}
}
- fn map_shift<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
+ fn map_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: Option<ast::Type>,
) -> ast::Arg3<U> {
ast::Arg3 {
- dst: visitor.dst_variable(ArgumentDescriptor {
+ dst: visitor.variable(ArgumentDescriptor {
op: self.dst,
typ: t,
is_dst: true,
is_pointer: false,
}),
- src1: visitor.src_operand(ArgumentDescriptor {
+ src1: visitor.operand(ArgumentDescriptor {
op: self.src1,
typ: t,
is_dst: false,
is_pointer: false,
}),
- src2: visitor.src_operand(ArgumentDescriptor {
+ src2: visitor.operand(ArgumentDescriptor {
op: self.src2,
typ: Some(ast::Type::Scalar(ast::ScalarType::U32)),
is_dst: false,
@@ -1787,34 +2252,34 @@ impl<T: ast::ArgParams> ast::Arg3<T> {
}
}
-impl<T: ast::ArgParams> ast::Arg4<T> {
- fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
+impl<T: ArgParamsEx> ast::Arg4<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: Option<ast::Type>,
) -> ast::Arg4<U> {
ast::Arg4 {
- dst1: visitor.dst_variable(ArgumentDescriptor {
+ dst1: visitor.variable(ArgumentDescriptor {
op: self.dst1,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
is_dst: true,
is_pointer: false,
}),
dst2: self.dst2.map(|dst2| {
- visitor.dst_variable(ArgumentDescriptor {
+ visitor.variable(ArgumentDescriptor {
op: dst2,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
is_dst: true,
is_pointer: false,
})
}),
- src1: visitor.src_operand(ArgumentDescriptor {
+ src1: visitor.operand(ArgumentDescriptor {
op: self.src1,
typ: t,
is_dst: false,
is_pointer: false,
}),
- src2: visitor.src_operand(ArgumentDescriptor {
+ src2: visitor.operand(ArgumentDescriptor {
op: self.src2,
typ: t,
is_dst: false,
@@ -1824,40 +2289,40 @@ impl<T: ast::ArgParams> ast::Arg4<T> {
}
}
-impl<T: ast::ArgParams> ast::Arg5<T> {
- fn map<U: ast::ArgParams, V: ArgumentMapVisitor<T, U>>(
+impl<T: ArgParamsEx> ast::Arg5<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: Option<ast::Type>,
) -> ast::Arg5<U> {
ast::Arg5 {
- dst1: visitor.dst_variable(ArgumentDescriptor {
+ dst1: visitor.variable(ArgumentDescriptor {
op: self.dst1,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
is_dst: true,
is_pointer: false,
}),
dst2: self.dst2.map(|dst2| {
- visitor.dst_variable(ArgumentDescriptor {
+ visitor.variable(ArgumentDescriptor {
op: dst2,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
is_dst: true,
is_pointer: false,
})
}),
- src1: visitor.src_operand(ArgumentDescriptor {
+ src1: visitor.operand(ArgumentDescriptor {
op: self.src1,
typ: t,
is_dst: false,
is_pointer: false,
}),
- src2: visitor.src_operand(ArgumentDescriptor {
+ src2: visitor.operand(ArgumentDescriptor {
op: self.src2,
typ: t,
is_dst: false,
is_pointer: false,
}),
- src3: visitor.src_operand(ArgumentDescriptor {
+ src3: visitor.operand(ArgumentDescriptor {
op: self.src3,
typ: Some(ast::Type::ExtendedScalar(ast::ExtendedScalarType::Pred)),
is_dst: false,
@@ -1867,6 +2332,74 @@ impl<T: ast::ArgParams> ast::Arg5<T> {
}
}
+/*
+impl<T: ArgParamsEx> ast::ArgCall<T> {
+ fn map<'a, U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ fn_resolve: &GlobalFnDeclResolver<'a>,
+ ) -> ast::ArgCall<U> {
+ // TODO: error out if lengths don't match
+ let fn_decl = T::get_fn_decl(&self.func, fn_resolve);
+ let ret_params = self
+ .ret_params
+ .into_iter()
+ .zip(fn_decl.ret_vals.iter().copied())
+ .map(|(a, (space, typ))| {
+ visitor.variable(ArgumentDescriptor {
+ op: a,
+ typ: Some(ast::Type::Scalar(typ)),
+ is_dst: true,
+ is_pointer: if space == ast::FnArgStateSpace::Reg {
+ false
+ } else {
+ true
+ },
+ })
+ })
+ .collect();
+ let func = visitor.variable(ArgumentDescriptor {
+ op: self.func,
+ typ: None,
+ is_dst: false,
+ is_pointer: false,
+ });
+ let param_list = self
+ .param_list
+ .into_iter()
+ .zip(fn_decl.params.iter().copied())
+ .map(|(a, (space, typ))| {
+ visitor.src_call_operand(ArgumentDescriptor {
+ op: a,
+ typ: Some(ast::Type::Scalar(typ)),
+ is_dst: false,
+ is_pointer: if space == ast::FnArgStateSpace::Reg {
+ false
+ } else {
+ true
+ },
+ })
+ })
+ .collect();
+ ast::ArgCall {
+ uniform: false,
+ ret_params,
+ func: func,
+ param_list: param_list,
+ }
+ }
+}
+*/
+
+impl<T> ast::CallOperand<T> {
+ fn map_variable<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::CallOperand<U> {
+ match self {
+ ast::CallOperand::Reg(id) => ast::CallOperand::Reg(f(id)),
+ ast::CallOperand::Imm(x) => ast::CallOperand::Imm(x),
+ }
+ }
+}
+
impl ast::StStateSpace {
fn to_ld_ss(self) -> ast::LdStateSpace {
match self {
@@ -2282,10 +2815,10 @@ fn should_convert_relaxed_dst(
fn insert_implicit_bitcasts(
func: &mut Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
- instr: ast::Instruction<ExpandedArgParams>,
+ stmt: impl VisitVariableExpanded,
) {
let mut dst_coercion = None;
- let instr = instr.visit_variable_extended(&mut |mut desc| {
+ let instr = stmt.visit_variable_extended(&mut |mut desc| {
let id_type_from_instr = match desc.typ {
Some(t) => t,
None => return desc.op,
@@ -2315,17 +2848,25 @@ fn insert_implicit_bitcasts(
desc.op
}
});
- func.push(Statement::Instruction(instr));
+ func.push(instr);
if let Some(cond) = dst_coercion {
func.push(cond);
}
}
-
-impl<'a> ast::FunctionHeader<'a, ast::ParsedArgParams<'a>> {
+impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> {
fn name(&self) -> &'a str {
match self {
- ast::FunctionHeader::Kernel(name) => name,
- ast::FunctionHeader::Func(_, name) => name,
+ ast::MethodDecl::Kernel(name, _) => name,
+ ast::MethodDecl::Func(_, name, _) => name,
+ }
+ }
+}
+
+impl<'a, P: ArgParamsEx> ast::MethodDecl<'a, P> {
+ fn visit_args(&self, f: impl FnMut(&ast::KernelArgument<P>)) {
+ match self {
+ ast::MethodDecl::Kernel(_, params) => params.iter().for_each(f),
+ ast::MethodDecl::Func(_, _, params) => params.iter().map(|a| &a.base).for_each(f),
}
}
}