aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
Diffstat (limited to 'ptx')
-rw-r--r--ptx/src/ast.rs3
-rw-r--r--ptx/src/ptx.lalrpop34
-rw-r--r--ptx/src/test/spirv_run/brev.ptx21
-rw-r--r--ptx/src/test/spirv_run/brev.spvtxt47
-rw-r--r--ptx/src/test/spirv_run/clz.ptx21
-rw-r--r--ptx/src/test/spirv_run/clz.spvtxt47
-rw-r--r--ptx/src/test/spirv_run/mod.rs9
-rw-r--r--ptx/src/test/spirv_run/popc.ptx21
-rw-r--r--ptx/src/test/spirv_run/popc.spvtxt47
-rw-r--r--ptx/src/translate.rs66
10 files changed, 308 insertions, 8 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 653060b..b6ac3db 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -587,6 +587,9 @@ pub enum Instruction<P: ArgParams> {
Cos { flush_to_zero: bool, arg: Arg2<P> },
Lg2 { flush_to_zero: bool, arg: Arg2<P> },
Ex2 { flush_to_zero: bool, arg: Arg2<P> },
+ Clz { typ: BitType, arg: Arg2<P> },
+ Brev { typ: BitType, arg: Arg2<P> },
+ Popc { typ: BitType, arg: Arg2<P> },
}
#[derive(Copy, Clone)]
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 31c2356..cd1c642 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -143,7 +143,9 @@ match {
"bar",
"barrier",
"bra",
+ "brev",
"call",
+ "clz",
"cos",
"cvt",
"cvta",
@@ -162,6 +164,7 @@ match {
"neg",
"not",
"or",
+ "popc",
"rcp",
"ret",
"rsqrt",
@@ -190,7 +193,9 @@ ExtendedID : &'input str = {
"bar",
"barrier",
"bra",
+ "brev",
"call",
+ "clz",
"cos",
"cvt",
"cvta",
@@ -209,6 +214,7 @@ ExtendedID : &'input str = {
"neg",
"not",
"or",
+ "popc",
"rcp",
"ret",
"rsqrt",
@@ -699,6 +705,9 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstCos,
InstLg2,
InstEx2,
+ InstClz,
+ InstBrev,
+ InstPopc,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@@ -1395,7 +1404,7 @@ InstBar: ast::Instruction<ast::ParsedArgParams<'input>> = {
// * Operation .dec requires .u32 type for instuction
// Otherwise as documented
InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> <op:AtomBitOp> <typ:AtomBitType> <a:Arg3Atom> => {
+ "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> <op:AtomBitOp> <typ:BitType> <a:Arg3Atom> => {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
@@ -1459,7 +1468,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
}
InstAtomCas: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> ".cas" <typ:AtomBitType> <a:Arg4Atom> => {
+ "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> ".cas" <typ:BitType> <a:Arg4Atom> => {
let details = ast::AtomCasDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
@@ -1501,7 +1510,7 @@ AtomSIntOp: ast::AtomSIntOp = {
".max" => ast::AtomSIntOp::Max,
}
-AtomBitType: ast::BitType = {
+BitType: ast::BitType = {
".b32" => ast::BitType::B32,
".b64" => ast::BitType::B64,
}
@@ -1640,6 +1649,21 @@ InstEx2: ast::Instruction<ast::ParsedArgParams<'input>> = {
},
}
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-clz
+InstClz: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "clz" <typ:BitType> <arg:Arg2> => ast::Instruction::Clz{ <> }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-brev
+InstBrev: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "brev" <typ:BitType> <arg:Arg2> => ast::Instruction::Brev{ <> }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-popc
+InstPopc: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "popc" <typ:BitType> <arg:Arg2> => ast::Instruction::Popc{ <> }
+}
+
NegTypeFtz: ast::ScalarType = {
".f16" => ast::ScalarType::F16,
".f16x2" => ast::ScalarType::F16x2,
@@ -1858,7 +1882,7 @@ Section = {
};
SectionDwarfLines: () = {
- BitType Comma<U32Num>,
+ AnyBitType Comma<U32Num>,
".b32" SectionLabel,
".b64" SectionLabel,
".b32" SectionLabel "+" U32Num,
@@ -1870,7 +1894,7 @@ SectionLabel = {
DotID
};
-BitType = {
+AnyBitType = {
".b8", ".b16", ".b32", ".b64"
};
diff --git a/ptx/src/test/spirv_run/brev.ptx b/ptx/src/test/spirv_run/brev.ptx
new file mode 100644
index 0000000..1d9dd75
--- /dev/null
+++ b/ptx/src/test/spirv_run/brev.ptx
@@ -0,0 +1,21 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry brev(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .b32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.b32 temp, [in_addr];
+ brev.b32 temp, temp;
+ st.b32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/brev.spvtxt b/ptx/src/test/spirv_run/brev.spvtxt
new file mode 100644
index 0000000..df5df53
--- /dev/null
+++ b/ptx/src/test/spirv_run/brev.spvtxt
@@ -0,0 +1,47 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %21 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "brev"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %24 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %1 = OpFunction %void None %24
+ %7 = OpFunctionParameter %ulong
+ %8 = OpFunctionParameter %ulong
+ %19 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %7
+ OpStore %3 %8
+ %9 = OpLoad %ulong %2
+ OpStore %4 %9
+ %10 = OpLoad %ulong %3
+ OpStore %5 %10
+ %12 = OpLoad %ulong %4
+ %17 = OpConvertUToPtr %_ptr_Generic_uint %12
+ %11 = OpLoad %uint %17
+ OpStore %6 %11
+ %14 = OpLoad %uint %6
+ %13 = OpBitReverse %uint %14
+ OpStore %6 %13
+ %15 = OpLoad %ulong %5
+ %16 = OpLoad %uint %6
+ %18 = OpConvertUToPtr %_ptr_Generic_uint %15
+ OpStore %18 %16
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/clz.ptx b/ptx/src/test/spirv_run/clz.ptx
new file mode 100644
index 0000000..b475b90
--- /dev/null
+++ b/ptx/src/test/spirv_run/clz.ptx
@@ -0,0 +1,21 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry clz(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .b32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.b32 temp, [in_addr];
+ clz.b32 temp, temp;
+ st.b32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/clz.spvtxt b/ptx/src/test/spirv_run/clz.spvtxt
new file mode 100644
index 0000000..5d1ebc8
--- /dev/null
+++ b/ptx/src/test/spirv_run/clz.spvtxt
@@ -0,0 +1,47 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %21 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "clz"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %24 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %1 = OpFunction %void None %24
+ %7 = OpFunctionParameter %ulong
+ %8 = OpFunctionParameter %ulong
+ %19 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %7
+ OpStore %3 %8
+ %9 = OpLoad %ulong %2
+ OpStore %4 %9
+ %10 = OpLoad %ulong %3
+ OpStore %5 %10
+ %12 = OpLoad %ulong %4
+ %17 = OpConvertUToPtr %_ptr_Generic_uint %12
+ %11 = OpLoad %uint %17
+ OpStore %6 %11
+ %14 = OpLoad %uint %6
+ %13 = OpExtInst %uint %21 clz %14
+ OpStore %6 %13
+ %15 = OpLoad %ulong %5
+ %16 = OpLoad %uint %6
+ %18 = OpConvertUToPtr %_ptr_Generic_uint %15
+ OpStore %18 %16
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 163caac..a7ef75b 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -104,11 +104,18 @@ test_ptx!(div_approx, [1f32, 2f32], [0.5f32]);
test_ptx!(sqrt, [0.25f32], [0.5f32]);
test_ptx!(rsqrt, [0.25f64], [2f64]);
test_ptx!(neg, [181i32], [-181i32]);
-test_ptx!(sin, [std::f32::consts::PI/2f32], [1f32]);
+test_ptx!(sin, [std::f32::consts::PI / 2f32], [1f32]);
test_ptx!(cos, [std::f32::consts::PI], [-1f32]);
test_ptx!(lg2, [512f32], [9f32]);
test_ptx!(ex2, [10f32], [1024f32]);
test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]);
+test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]);
+test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]);
+test_ptx!(
+ brev,
+ [0b11000111_01011100_10101110_11111011u32],
+ [0b11011111_01110101_00111010_11100011u32]
+);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/popc.ptx b/ptx/src/test/spirv_run/popc.ptx
new file mode 100644
index 0000000..7106422
--- /dev/null
+++ b/ptx/src/test/spirv_run/popc.ptx
@@ -0,0 +1,21 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry popc(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .b32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.b32 temp, [in_addr];
+ popc.b32 temp, temp;
+ st.b32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/popc.spvtxt b/ptx/src/test/spirv_run/popc.spvtxt
new file mode 100644
index 0000000..bb4968f
--- /dev/null
+++ b/ptx/src/test/spirv_run/popc.spvtxt
@@ -0,0 +1,47 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %21 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "popc"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %24 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %1 = OpFunction %void None %24
+ %7 = OpFunctionParameter %ulong
+ %8 = OpFunctionParameter %ulong
+ %19 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %7
+ OpStore %3 %8
+ %9 = OpLoad %ulong %2
+ OpStore %4 %9
+ %10 = OpLoad %ulong %3
+ OpStore %5 %10
+ %12 = OpLoad %ulong %4
+ %17 = OpConvertUToPtr %_ptr_Generic_uint %12
+ %11 = OpLoad %uint %17
+ OpStore %6 %11
+ %14 = OpLoad %uint %6
+ %13 = OpBitCount %uint %14
+ OpStore %6 %13
+ %15 = OpLoad %ulong %5
+ %16 = OpLoad %uint %6
+ %18 = OpConvertUToPtr %_ptr_Generic_uint %15
+ OpStore %18 %16
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 9519951..23a63be 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1573,6 +1573,24 @@ fn convert_to_typed_statements(
arg: arg.cast(),
}))
}
+ ast::Instruction::Clz { typ, arg } => {
+ result.push(Statement::Instruction(ast::Instruction::Clz {
+ typ,
+ arg: arg.cast(),
+ }))
+ }
+ ast::Instruction::Brev { typ, arg } => {
+ result.push(Statement::Instruction(ast::Instruction::Brev {
+ typ,
+ arg: arg.cast(),
+ }))
+ }
+ ast::Instruction::Popc { typ, arg } => {
+ result.push(Statement::Instruction(ast::Instruction::Popc {
+ typ,
+ arg: arg.cast(),
+ }))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -2997,6 +3015,24 @@ fn emit_function_body_ops(
[arg.src],
)?;
}
+ ast::Instruction::Clz { typ, arg } => {
+ let result_type = map.get_or_add_scalar(builder, (*typ).into());
+ builder.ext_inst(
+ result_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::clz as u32,
+ [arg.src],
+ )?;
+ }
+ ast::Instruction::Brev { typ, arg } => {
+ let result_type = map.get_or_add_scalar(builder, (*typ).into());
+ builder.bit_reverse(result_type, Some(arg.dst), arg.src)?;
+ }
+ ast::Instruction::Popc { typ, arg } => {
+ let result_type = map.get_or_add_scalar(builder, (*typ).into());
+ builder.bit_count(result_type, Some(arg.dst), arg.src)?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@@ -4881,7 +4917,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Type::Scalar(desc.src.into()),
),
};
- ast::Instruction::Cvt(d, a.map_cvt(visitor, &dst_t, &src_t)?)
+ ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?)
}
ast::Instruction::Shl(t, a) => {
ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?)
@@ -4980,6 +5016,29 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
arg: arg.map(visitor, &typ)?,
}
}
+ ast::Instruction::Clz { typ, arg } => {
+ let dst_type = ast::Type::Scalar(ast::ScalarType::B32);
+ let src_type = ast::Type::Scalar(typ.into());
+ ast::Instruction::Clz {
+ typ,
+ arg: arg.map_different_types(visitor, &dst_type, &src_type)?,
+ }
+ }
+ ast::Instruction::Brev { typ, arg } => {
+ let full_type = ast::Type::Scalar(typ.into());
+ ast::Instruction::Brev {
+ typ,
+ arg: arg.map(visitor, &full_type)?,
+ }
+ }
+ ast::Instruction::Popc { typ, arg } => {
+ let dst_type = ast::Type::Scalar(ast::ScalarType::B32);
+ let src_type = ast::Type::Scalar(typ.into());
+ ast::Instruction::Popc {
+ typ,
+ arg: arg.map_different_types(visitor, &dst_type, &src_type)?,
+ }
+ }
})
}
}
@@ -5289,6 +5348,9 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Cvt(ast::CvtDetails::FloatFromInt(_), _) => None,
ast::Instruction::Div(ast::DivDetails::Unsigned(_), _) => None,
ast::Instruction::Div(ast::DivDetails::Signed(_), _) => None,
+ ast::Instruction::Clz { .. } => None,
+ ast::Instruction::Brev { .. } => None,
+ ast::Instruction::Popc { .. } => None,
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
@@ -5567,7 +5629,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
})
}
- fn map_cvt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ fn map_different_types<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
dst_t: &ast::Type,