summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-11-06 00:56:45 +0100
committerAndrzej Janik <[email protected]>2020-11-06 00:56:45 +0100
commitac6265f257654180f6661c406a025313190448c4 (patch)
treedf84117141e484c0c9da03a94aae4b5018c24607
parentd7bf1acf84faa8f6cb1d5edb6c4d9eb0f05a5ae0 (diff)
downloadZLUDA-ac6265f257654180f6661c406a025313190448c4.tar.gz
ZLUDA-ac6265f257654180f6661c406a025313190448c4.zip
Implement instructions bfe, rem, xor
-rw-r--r--ptx/lib/notcuda_ptx_impl.cl22
-rw-r--r--ptx/lib/notcuda_ptx_impl.spvbin48348 -> 49396 bytes
-rw-r--r--ptx/src/ast.rs19
-rw-r--r--ptx/src/ptx.lalrpop63
-rw-r--r--ptx/src/test/spirv_run/bfe.ptx23
-rw-r--r--ptx/src/test/spirv_run/bfe.spvtxt70
-rw-r--r--ptx/src/test/spirv_run/mod.rs14
-rw-r--r--ptx/src/test/spirv_run/rem.ptx23
-rw-r--r--ptx/src/test/spirv_run/rem.spvtxt55
-rw-r--r--ptx/src/test/spirv_run/xor.ptx23
-rw-r--r--ptx/src/test/spirv_run/xor.spvtxt55
-rw-r--r--ptx/src/translate.rs249
12 files changed, 576 insertions, 40 deletions
diff --git a/ptx/lib/notcuda_ptx_impl.cl b/ptx/lib/notcuda_ptx_impl.cl
index a0d487b..4249f2b 100644
--- a/ptx/lib/notcuda_ptx_impl.cl
+++ b/ptx/lib/notcuda_ptx_impl.cl
@@ -1,5 +1,5 @@
// Every time this file changes it must be rebuilt:
-// ocloc -file notcuda_ptx_impl.cl -64 -options "-cl-std=CL2.0" -out_dir . -device kbl -output_no_suffix -spv_only
+// ocloc -file notcuda_ptx_impl.cl -64 -options "-cl-std=CL2.0 -Dcl_intel_bit_instructions" -out_dir . -device kbl -output_no_suffix -spv_only
// Additionally you should strip names:
// spirv-opt --strip-debug notcuda_ptx_impl.spv -o notcuda_ptx_impl.spv
@@ -119,3 +119,23 @@ atomic_dec(atom_relaxed_sys_shared_dec, memory_order_relaxed, memory_order_relax
atomic_dec(atom_acquire_sys_shared_dec, memory_order_acquire, memory_order_acquire, memory_scope_device, __local);
atomic_dec(atom_release_sys_shared_dec, memory_order_release, memory_order_acquire, memory_scope_device, __local);
atomic_dec(atom_acq_rel_sys_shared_dec, memory_order_acq_rel, memory_order_acquire, memory_scope_device, __local);
+
+uint FUNC(bfe_u32)(uint base, uint pos, uint len)
+{
+ return intel_ubfe(base, pos, len);
+}
+
+ulong FUNC(bfe_u64)(ulong base, uint pos, uint len)
+{
+ return intel_ubfe(base, pos, len);
+}
+
+int FUNC(bfe_s32)(int base, uint pos, uint len)
+{
+ return intel_sbfe(base, pos, len);
+}
+
+long FUNC(bfe_s64)(long base, uint pos, uint len)
+{
+ return intel_sbfe(base, pos, len);
+} \ No newline at end of file
diff --git a/ptx/lib/notcuda_ptx_impl.spv b/ptx/lib/notcuda_ptx_impl.spv
index 36f37bb..1ef470f 100644
--- a/ptx/lib/notcuda_ptx_impl.spv
+++ b/ptx/lib/notcuda_ptx_impl.spv
Binary files differ
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index b6ac3db..5a5f6be 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -558,7 +558,7 @@ pub enum Instruction<P: ArgParams> {
Add(ArithDetails, Arg3<P>),
Setp(SetpData, Arg4Setp<P>),
SetpBool(SetpBoolData, Arg5<P>),
- Not(NotType, Arg2<P>),
+ Not(BooleanType, Arg2<P>),
Bra(BraData, Arg1<P>),
Cvt(CvtDetails, Arg2<P>),
Cvta(CvtaDetails, Arg2<P>),
@@ -569,12 +569,12 @@ pub enum Instruction<P: ArgParams> {
Call(CallInst<P>),
Abs(AbsDetails, Arg2<P>),
Mad(MulDetails, Arg4<P>),
- Or(OrAndType, Arg3<P>),
+ Or(BooleanType, Arg3<P>),
Sub(ArithDetails, Arg3<P>),
Min(MinMaxDetails, Arg3<P>),
Max(MinMaxDetails, Arg3<P>),
Rcp(RcpDetails, Arg2<P>),
- And(OrAndType, Arg3<P>),
+ And(BooleanType, Arg3<P>),
Selp(SelpType, Arg4<P>),
Bar(BarDetails, Arg1Bar<P>),
Atom(AtomDetails, Arg3<P>),
@@ -590,6 +590,9 @@ pub enum Instruction<P: ArgParams> {
Clz { typ: BitType, arg: Arg2<P> },
Brev { typ: BitType, arg: Arg2<P> },
Popc { typ: BitType, arg: Arg2<P> },
+ Xor { typ: BooleanType, arg: Arg3<P> },
+ Bfe { typ: IntType, arg: Arg4<P> },
+ Rem { typ: IntType, arg: Arg3<P> },
}
#[derive(Copy, Clone)]
@@ -896,14 +899,6 @@ pub struct SetpBoolData {
pub bool_op: SetpBoolPostOp,
}
-#[derive(PartialEq, Eq, Copy, Clone)]
-pub enum NotType {
- Pred,
- B16,
- B32,
- B64,
-}
-
pub struct BraData {
pub uniform: bool,
}
@@ -1058,7 +1053,7 @@ pub struct RetData {
pub uniform: bool,
}
-sub_enum!(OrAndType {
+sub_enum!(BooleanType {
Pred,
B16,
B32,
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index cd1c642..6c231b2 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -142,6 +142,7 @@ match {
"atom",
"bar",
"barrier",
+ "bfe",
"bra",
"brev",
"call",
@@ -166,6 +167,7 @@ match {
"or",
"popc",
"rcp",
+ "rem",
"ret",
"rsqrt",
"selp",
@@ -179,6 +181,7 @@ match {
"sub",
"texmode_independent",
"texmode_unified",
+ "xor",
} else {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#identifiers
r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+" => ID,
@@ -192,6 +195,7 @@ ExtendedID : &'input str = {
"atom",
"bar",
"barrier",
+ "bfe",
"bra",
"brev",
"call",
@@ -216,6 +220,7 @@ ExtendedID : &'input str = {
"or",
"popc",
"rcp",
+ "rem",
"ret",
"rsqrt",
"selp",
@@ -229,6 +234,7 @@ ExtendedID : &'input str = {
"sub",
"texmode_independent",
"texmode_unified",
+ "xor",
ID
}
@@ -708,6 +714,9 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstClz,
InstBrev,
InstPopc,
+ InstXor,
+ InstRem,
+ InstBfe,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@@ -874,6 +883,13 @@ IntType : ast::IntType = {
".s64" => ast::IntType::S64,
};
+IntType3264: ast::IntType = {
+ ".u32" => ast::IntType::U32,
+ ".u64" => ast::IntType::U64,
+ ".s32" => ast::IntType::S32,
+ ".s64" => ast::IntType::S64,
+}
+
UIntType: ast::UIntType = {
".u16" => ast::UIntType::U16,
".u32" => ast::UIntType::U32,
@@ -979,14 +995,14 @@ SetpTypeNoF32: ast::ScalarType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not
InstNot: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "not" <t:NotType> <a:Arg2> => ast::Instruction::Not(t, a)
+ "not" <t:BooleanType> <a:Arg2> => ast::Instruction::Not(t, a)
};
-NotType: ast::NotType = {
- ".pred" => ast::NotType::Pred,
- ".b16" => ast::NotType::B16,
- ".b32" => ast::NotType::B32,
- ".b64" => ast::NotType::B64,
+BooleanType: ast::BooleanType = {
+ ".pred" => ast::BooleanType::Pred,
+ ".b16" => ast::BooleanType::B16,
+ ".b32" => ast::BooleanType::B32,
+ ".b64" => ast::BooleanType::B64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at
@@ -1294,19 +1310,12 @@ SignedIntType: ast::ScalarType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-or
InstOr: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "or" <d:OrAndType> <a:Arg3> => ast::Instruction::Or(d, a),
+ "or" <d:BooleanType> <a:Arg3> => ast::Instruction::Or(d, a),
};
-OrAndType: ast::OrAndType = {
- ".pred" => ast::OrAndType::Pred,
- ".b16" => ast::OrAndType::B16,
- ".b32" => ast::OrAndType::B32,
- ".b64" => ast::OrAndType::B64,
-}
-
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-and
InstAnd: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "and" <d:OrAndType> <a:Arg3> => ast::Instruction::And(d, a),
+ "and" <d:BooleanType> <a:Arg3> => ast::Instruction::And(d, a),
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp
@@ -1447,7 +1456,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
};
ast::Instruction::Atom(details,a)
},
- "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> <op: AtomUIntOp> <typ:AtomUIntType> <a:Arg3Atom> => {
+ "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> <op: AtomUIntOp> <typ:UIntType3264> <a:Arg3Atom> => {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
@@ -1456,7 +1465,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
};
ast::Instruction::Atom(details,a)
},
- "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> <op: AtomSIntOp> <typ:AtomSIntType> <a:Arg3Atom> => {
+ "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> <op: AtomSIntOp> <typ:SIntType3264> <a:Arg3Atom> => {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
@@ -1515,12 +1524,12 @@ BitType: ast::BitType = {
".b64" => ast::BitType::B64,
}
-AtomUIntType: ast::UIntType = {
+UIntType3264: ast::UIntType = {
".u32" => ast::UIntType::U32,
".u64" => ast::UIntType::U64,
}
-AtomSIntType: ast::SIntType = {
+SIntType3264: ast::SIntType = {
".s32" => ast::SIntType::S32,
".s64" => ast::SIntType::S64,
}
@@ -1664,6 +1673,22 @@ InstPopc: ast::Instruction<ast::ParsedArgParams<'input>> = {
"popc" <typ:BitType> <arg:Arg2> => ast::Instruction::Popc{ <> }
}
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-xor
+InstXor: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "xor" <typ:BooleanType> <arg:Arg3> => ast::Instruction::Xor{ <> }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-bfe
+InstBfe: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "bfe" <typ:IntType3264> <arg:Arg4> => ast::Instruction::Bfe{ <> }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-rem
+InstRem: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "rem" <typ:IntType> <arg:Arg3> => ast::Instruction::Rem{ <> }
+}
+
+
NegTypeFtz: ast::ScalarType = {
".f16" => ast::ScalarType::F16,
".f16x2" => ast::ScalarType::F16x2,
diff --git a/ptx/src/test/spirv_run/bfe.ptx b/ptx/src/test/spirv_run/bfe.ptx
new file mode 100644
index 0000000..60ee8a6
--- /dev/null
+++ b/ptx/src/test/spirv_run/bfe.ptx
@@ -0,0 +1,23 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry bfe(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u32 temp<3>;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u32 temp0, [in_addr];
+ ld.u32 temp1, [in_addr+4];
+ ld.u32 temp2, [in_addr+8];
+ bfe.u32 temp0, temp0, temp1, temp2;
+ st.u32 [out_addr], temp0;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/bfe.spvtxt b/ptx/src/test/spirv_run/bfe.spvtxt
new file mode 100644
index 0000000..edcf138
--- /dev/null
+++ b/ptx/src/test/spirv_run/bfe.spvtxt
@@ -0,0 +1,70 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %40 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "bfe"
+ OpDecorate %34 LinkageAttributes "__notcuda_ptx_impl__bfe_u32" Import
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %43 = OpTypeFunction %uint %uint %uint %uint
+ %ulong = OpTypeInt 64 0
+ %45 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %ulong_4 = OpConstant %ulong 4
+ %ulong_8 = OpConstant %ulong 8
+ %34 = OpFunction %uint None %43
+ %36 = OpFunctionParameter %uint
+ %37 = OpFunctionParameter %uint
+ %38 = OpFunctionParameter %uint
+ OpFunctionEnd
+ %1 = OpFunction %void None %45
+ %9 = OpFunctionParameter %ulong
+ %10 = OpFunctionParameter %ulong
+ %33 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ %7 = OpVariable %_ptr_Function_uint Function
+ %8 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %9
+ OpStore %3 %10
+ %11 = OpLoad %ulong %2
+ OpStore %4 %11
+ %12 = OpLoad %ulong %3
+ OpStore %5 %12
+ %14 = OpLoad %ulong %4
+ %29 = OpConvertUToPtr %_ptr_Generic_uint %14
+ %13 = OpLoad %uint %29
+ OpStore %6 %13
+ %16 = OpLoad %ulong %4
+ %26 = OpIAdd %ulong %16 %ulong_4
+ %30 = OpConvertUToPtr %_ptr_Generic_uint %26
+ %15 = OpLoad %uint %30
+ OpStore %7 %15
+ %18 = OpLoad %ulong %4
+ %28 = OpIAdd %ulong %18 %ulong_8
+ %31 = OpConvertUToPtr %_ptr_Generic_uint %28
+ %17 = OpLoad %uint %31
+ OpStore %8 %17
+ %20 = OpLoad %uint %6
+ %21 = OpLoad %uint %7
+ %22 = OpLoad %uint %8
+ %19 = OpFunctionCall %uint %34 %20 %21 %22
+ OpStore %6 %19
+ %23 = OpLoad %ulong %5
+ %24 = OpLoad %uint %6
+ %32 = OpConvertUToPtr %_ptr_Generic_uint %23
+ OpStore %32 %24
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index a7ef75b..5bbe45a 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -116,6 +116,20 @@ test_ptx!(
[0b11000111_01011100_10101110_11111011u32],
[0b11011111_01110101_00111010_11100011u32]
);
+test_ptx!(
+ xor,
+ [
+ 0b01010010_00011010_01000000_00001101u32,
+ 0b11100110_10011011_00001100_00100011u32
+ ],
+ [0b10110100100000010100110000101110u32]
+);
+test_ptx!(rem, [21692i32, 13i32], [8i32]);
+test_ptx!(
+ bfe,
+ [0b11111000_11000001_00100010_10100000u32, 16u32, 8u32],
+ [0b11000001u32]
+);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/rem.ptx b/ptx/src/test/spirv_run/rem.ptx
new file mode 100644
index 0000000..2ac482d
--- /dev/null
+++ b/ptx/src/test/spirv_run/rem.ptx
@@ -0,0 +1,23 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry rem(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .s32 temp1;
+ .reg .s32 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.s32 temp1, [in_addr];
+ ld.s32 temp2, [in_addr+4];
+ rem.s32 temp1, temp1, temp2;
+ st.s32 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/rem.spvtxt b/ptx/src/test/spirv_run/rem.spvtxt
new file mode 100644
index 0000000..72d0965
--- /dev/null
+++ b/ptx/src/test/spirv_run/rem.spvtxt
@@ -0,0 +1,55 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %28 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "rem"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %31 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %ulong_4 = OpConstant %ulong 4
+ %1 = OpFunction %void None %31
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %26 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ %7 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %10 = OpLoad %ulong %2
+ OpStore %4 %10
+ %11 = OpLoad %ulong %3
+ OpStore %5 %11
+ %13 = OpLoad %ulong %4
+ %23 = OpConvertUToPtr %_ptr_Generic_uint %13
+ %12 = OpLoad %uint %23
+ OpStore %6 %12
+ %15 = OpLoad %ulong %4
+ %22 = OpIAdd %ulong %15 %ulong_4
+ %24 = OpConvertUToPtr %_ptr_Generic_uint %22
+ %14 = OpLoad %uint %24
+ OpStore %7 %14
+ %17 = OpLoad %uint %6
+ %18 = OpLoad %uint %7
+ %16 = OpSMod %uint %17 %18
+ OpStore %6 %16
+ %19 = OpLoad %ulong %5
+ %20 = OpLoad %uint %6
+ %25 = OpConvertUToPtr %_ptr_Generic_uint %19
+ OpStore %25 %20
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/xor.ptx b/ptx/src/test/spirv_run/xor.ptx
new file mode 100644
index 0000000..a28b321
--- /dev/null
+++ b/ptx/src/test/spirv_run/xor.ptx
@@ -0,0 +1,23 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry xor(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .b32 temp1;
+ .reg .b32 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.b32 temp1, [in_addr];
+ ld.b32 temp2, [in_addr+4];
+ xor.b32 temp1, temp1, temp2;
+ st.b32 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/xor.spvtxt b/ptx/src/test/spirv_run/xor.spvtxt
new file mode 100644
index 0000000..ee09898
--- /dev/null
+++ b/ptx/src/test/spirv_run/xor.spvtxt
@@ -0,0 +1,55 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %28 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "xor"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %31 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %ulong_4 = OpConstant %ulong 4
+ %1 = OpFunction %void None %31
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %26 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ %7 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %10 = OpLoad %ulong %2
+ OpStore %4 %10
+ %11 = OpLoad %ulong %3
+ OpStore %5 %11
+ %13 = OpLoad %ulong %4
+ %23 = OpConvertUToPtr %_ptr_Generic_uint %13
+ %12 = OpLoad %uint %23
+ OpStore %6 %12
+ %15 = OpLoad %ulong %4
+ %22 = OpIAdd %ulong %15 %ulong_4
+ %24 = OpConvertUToPtr %_ptr_Generic_uint %22
+ %14 = OpLoad %uint %24
+ OpStore %7 %14
+ %17 = OpLoad %uint %6
+ %18 = OpLoad %uint %7
+ %16 = OpBitwiseXor %uint %17 %18
+ OpStore %6 %16
+ %19 = OpLoad %ulong %5
+ %20 = OpLoad %uint %6
+ %25 = OpConvertUToPtr %_ptr_Generic_uint %19
+ OpStore %25 %20
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 23a63be..365d1e8 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1289,6 +1289,9 @@ fn extract_globals<'input, 'b>(
..
},
) => global.push(var),
+ Statement::Instruction(ast::Instruction::Bfe { typ, arg }) => {
+ local.push(to_ptx_impl_bfe_call(id_def, ptx_impl_imports, typ, arg));
+ }
Statement::Instruction(ast::Instruction::Atom(
d
@
@@ -1591,6 +1594,24 @@ fn convert_to_typed_statements(
arg: arg.cast(),
}))
}
+ ast::Instruction::Xor { typ, arg } => {
+ result.push(Statement::Instruction(ast::Instruction::Xor {
+ typ,
+ arg: arg.cast(),
+ }))
+ }
+ ast::Instruction::Bfe { typ, arg } => {
+ result.push(Statement::Instruction(ast::Instruction::Bfe {
+ typ,
+ arg: arg.cast(),
+ }))
+ }
+ ast::Instruction::Rem { typ, arg } => {
+ result.push(Statement::Instruction(ast::Instruction::Rem {
+ typ,
+ arg: arg.cast(),
+ }))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -1610,6 +1631,7 @@ fn convert_to_typed_statements(
Ok(result)
}
+//TODO: share common code between this and to_ptx_impl_bfe_call
fn to_ptx_impl_atomic_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
@@ -1705,6 +1727,100 @@ fn to_ptx_impl_atomic_call(
})
}
+fn to_ptx_impl_bfe_call(
+ id_defs: &mut NumericIdResolver,
+ ptx_impl_imports: &mut HashMap<String, Directive>,
+ typ: ast::IntType,
+ arg: ast::Arg4<ExpandedArgParams>,
+) -> ExpandedStatement {
+ let prefix = "__notcuda_ptx_impl__";
+ let suffix = match typ {
+ ast::IntType::U32 => "bfe_u32",
+ ast::IntType::U64 => "bfe_u64",
+ ast::IntType::S32 => "bfe_s32",
+ ast::IntType::S64 => "bfe_s64",
+ _ => unreachable!(),
+ };
+ let fn_name = format!("{}{}", prefix, suffix);
+ let fn_id = match ptx_impl_imports.entry(fn_name) {
+ hash_map::Entry::Vacant(entry) => {
+ let fn_id = id_defs.new_id(None);
+ let func_decl = ast::MethodDecl::Func::<spirv::Word>(
+ vec![ast::FnArgument {
+ align: None,
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ name: id_defs.new_id(None),
+ array_init: Vec::new(),
+ }],
+ fn_id,
+ vec![
+ ast::FnArgument {
+ align: None,
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ name: id_defs.new_id(None),
+ array_init: Vec::new(),
+ },
+ ast::FnArgument {
+ align: None,
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
+ ast::ScalarType::U32,
+ )),
+ name: id_defs.new_id(None),
+ array_init: Vec::new(),
+ },
+ ast::FnArgument {
+ align: None,
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
+ ast::ScalarType::U32,
+ )),
+ name: id_defs.new_id(None),
+ array_init: Vec::new(),
+ },
+ ],
+ );
+ let spirv_decl = SpirvMethodDecl::new(&func_decl);
+ let func = Function {
+ func_decl,
+ globals: Vec::new(),
+ body: None,
+ import_as: Some(entry.key().clone()),
+ spirv_decl,
+ };
+ entry.insert(Directive::Method(func));
+ fn_id
+ }
+ hash_map::Entry::Occupied(entry) => match entry.get() {
+ Directive::Method(Function {
+ func_decl: ast::MethodDecl::Func(_, name, _),
+ ..
+ }) => *name,
+ _ => unreachable!(),
+ },
+ };
+ Statement::Call(ResolvedCall {
+ uniform: false,
+ func: fn_id,
+ ret_params: vec![(
+ arg.dst,
+ ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ )],
+ param_list: vec![
+ (
+ arg.src1,
+ ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ),
+ (
+ arg.src2,
+ ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ),
+ (
+ arg.src3,
+ ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ),
+ ],
+ })
+}
+
fn to_resolved_fn_args<T>(
params: Vec<T>,
params_decl: &[ast::FnArgumentType],
@@ -2803,7 +2919,7 @@ fn emit_function_body_ops(
let result_id = Some(a.dst);
let operand = a.src;
match t {
- ast::NotType::Pred => {
+ ast::BooleanType::Pred => {
// HACK ALERT
// Temporary workaround until IGC gets its shit together
// Currently IGC carries two copies of SPIRV-LLVM translator
@@ -2854,7 +2970,7 @@ fn emit_function_body_ops(
},
ast::Instruction::Or(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
- if *t == ast::OrAndType::Pred {
+ if *t == ast::BooleanType::Pred {
builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?;
} else {
builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
@@ -2882,7 +2998,7 @@ fn emit_function_body_ops(
}
ast::Instruction::And(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
- if *t == ast::OrAndType::Pred {
+ if *t == ast::BooleanType::Pred {
builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?;
} else {
builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?;
@@ -3033,6 +3149,39 @@ fn emit_function_body_ops(
let result_type = map.get_or_add_scalar(builder, (*typ).into());
builder.bit_count(result_type, Some(arg.dst), arg.src)?;
}
+ ast::Instruction::Xor { typ, arg } => {
+ let builder_fn = match typ {
+ ast::BooleanType::Pred => emit_logical_xor_spirv,
+ _ => dr::Builder::bitwise_xor,
+ };
+ let result_type = map.get_or_add_scalar(builder, (*typ).into());
+ builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
+ }
+ ast::Instruction::Bfe { typ, arg } => {
+ let builder_fn = if typ.is_signed() {
+ dr::Builder::bit_field_s_extract
+ } else {
+ dr::Builder::bit_field_u_extract
+ };
+ let result_type = map.get_or_add_scalar(builder, (*typ).into());
+ builder_fn(
+ builder,
+ result_type,
+ Some(arg.dst),
+ arg.src1,
+ arg.src2,
+ arg.src3,
+ )?;
+ }
+ ast::Instruction::Rem { typ, arg } => {
+ let builder_fn = if typ.is_signed() {
+ dr::Builder::s_mod
+ } else {
+ dr::Builder::u_mod
+ };
+ let result_type = map.get_or_add_scalar(builder, (*typ).into());
+ builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@@ -3079,6 +3228,20 @@ fn emit_function_body_ops(
Ok(())
}
+// TODO: check what kind of assembly we emit
+fn emit_logical_xor_spirv(
+ builder: &mut dr::Builder,
+ result_type: spirv::Word,
+ result_id: Option<spirv::Word>,
+ op1: spirv::Word,
+ op2: spirv::Word,
+) -> Result<spirv::Word, dr::Error> {
+ let temp_or = builder.logical_or(result_type, None, op1, op2)?;
+ let temp_and = builder.logical_and(result_type, None, op1, op2)?;
+ let temp_neg = builder.logical_not(result_type, None, temp_and)?;
+ builder.logical_and(result_type, result_id, temp_or, temp_neg)
+}
+
fn emit_sqrt(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -5039,6 +5202,27 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
arg: arg.map_different_types(visitor, &dst_type, &src_type)?,
}
}
+ ast::Instruction::Xor { typ, arg } => {
+ let full_type = ast::Type::Scalar(typ.into());
+ ast::Instruction::Xor {
+ typ,
+ arg: arg.map_non_shift(visitor, &full_type, false)?,
+ }
+ }
+ ast::Instruction::Bfe { typ, arg } => {
+ let full_type = ast::Type::Scalar(typ.into());
+ ast::Instruction::Bfe {
+ typ,
+ arg: arg.map_bfe(visitor, &full_type)?,
+ }
+ }
+ ast::Instruction::Rem { typ, arg } => {
+ let full_type = ast::Type::Scalar(typ.into());
+ ast::Instruction::Rem {
+ typ,
+ arg: arg.map_non_shift(visitor, &full_type, false)?,
+ }
+ }
})
}
}
@@ -5351,6 +5535,9 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Clz { .. } => None,
ast::Instruction::Brev { .. } => None,
ast::Instruction::Popc { .. } => None,
+ ast::Instruction::Xor { .. } => None,
+ ast::Instruction::Bfe { .. } => None,
+ ast::Instruction::Rem { .. } => None,
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
@@ -6192,6 +6379,52 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
src3,
})
}
+
+ fn map_bfe<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ typ: &ast::Type,
+ ) -> Result<ast::Arg4<U>, TranslateError> {
+ let dst = visitor.id(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(typ),
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ typ,
+ )?;
+ let u32_type = ast::Type::Scalar(ast::ScalarType::U32);
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &u32_type,
+ )?;
+ let src3 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src3,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &u32_type,
+ )?;
+ Ok(ast::Arg4 {
+ dst,
+ src1,
+ src2,
+ src3,
+ })
+ }
}
impl<T: ArgParamsEx> ast::Arg4Setp<T> {
@@ -6437,13 +6670,13 @@ impl ast::ScalarType {
}
}
-impl ast::NotType {
+impl ast::BooleanType {
fn to_type(self) -> ast::Type {
match self {
- ast::NotType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
- ast::NotType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
- ast::NotType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
- ast::NotType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
+ ast::BooleanType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::BooleanType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
+ ast::BooleanType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
+ ast::BooleanType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
}
}
}