diff options
-rw-r--r-- | ptx/src/ast.rs | 3 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 34 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/brev.ptx | 21 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/brev.spvtxt | 47 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/clz.ptx | 21 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/clz.spvtxt | 47 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 9 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/popc.ptx | 21 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/popc.spvtxt | 47 | ||||
-rw-r--r-- | ptx/src/translate.rs | 66 |
10 files changed, 308 insertions, 8 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 653060b..b6ac3db 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -587,6 +587,9 @@ pub enum Instruction<P: ArgParams> { Cos { flush_to_zero: bool, arg: Arg2<P> }, Lg2 { flush_to_zero: bool, arg: Arg2<P> }, Ex2 { flush_to_zero: bool, arg: Arg2<P> }, + Clz { typ: BitType, arg: Arg2<P> }, + Brev { typ: BitType, arg: Arg2<P> }, + Popc { typ: BitType, arg: Arg2<P> }, } #[derive(Copy, Clone)] diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 31c2356..cd1c642 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -143,7 +143,9 @@ match { "bar", "barrier", "bra", + "brev", "call", + "clz", "cos", "cvt", "cvta", @@ -162,6 +164,7 @@ match { "neg", "not", "or", + "popc", "rcp", "ret", "rsqrt", @@ -190,7 +193,9 @@ ExtendedID : &'input str = { "bar", "barrier", "bra", + "brev", "call", + "clz", "cos", "cvt", "cvta", @@ -209,6 +214,7 @@ ExtendedID : &'input str = { "neg", "not", "or", + "popc", "rcp", "ret", "rsqrt", @@ -699,6 +705,9 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { InstCos, InstLg2, InstEx2, + InstClz, + InstBrev, + InstPopc, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -1395,7 +1404,7 @@ InstBar: ast::Instruction<ast::ParsedArgParams<'input>> = { // * Operation .dec requires .u32 type for instuction // Otherwise as documented InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = { - "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> <op:AtomBitOp> <typ:AtomBitType> <a:Arg3Atom> => { + "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> <op:AtomBitOp> <typ:BitType> <a:Arg3Atom> => { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), @@ -1459,7 +1468,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = { } InstAtomCas: ast::Instruction<ast::ParsedArgParams<'input>> = { - "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> ".cas" <typ:AtomBitType> <a:Arg4Atom> => { + "atom" <sema:AtomSemantics?> <scope:MemScope?> <space:AtomSpace?> ".cas" <typ:BitType> <a:Arg4Atom> => { let details = ast::AtomCasDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), @@ -1501,7 +1510,7 @@ AtomSIntOp: ast::AtomSIntOp = { ".max" => ast::AtomSIntOp::Max, } -AtomBitType: ast::BitType = { +BitType: ast::BitType = { ".b32" => ast::BitType::B32, ".b64" => ast::BitType::B64, } @@ -1640,6 +1649,21 @@ InstEx2: ast::Instruction<ast::ParsedArgParams<'input>> = { }, } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-clz +InstClz: ast::Instruction<ast::ParsedArgParams<'input>> = { + "clz" <typ:BitType> <arg:Arg2> => ast::Instruction::Clz{ <> } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-brev +InstBrev: ast::Instruction<ast::ParsedArgParams<'input>> = { + "brev" <typ:BitType> <arg:Arg2> => ast::Instruction::Brev{ <> } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-popc +InstPopc: ast::Instruction<ast::ParsedArgParams<'input>> = { + "popc" <typ:BitType> <arg:Arg2> => ast::Instruction::Popc{ <> } +} + NegTypeFtz: ast::ScalarType = { ".f16" => ast::ScalarType::F16, ".f16x2" => ast::ScalarType::F16x2, @@ -1858,7 +1882,7 @@ Section = { }; SectionDwarfLines: () = { - BitType Comma<U32Num>, + AnyBitType Comma<U32Num>, ".b32" SectionLabel, ".b64" SectionLabel, ".b32" SectionLabel "+" U32Num, @@ -1870,7 +1894,7 @@ SectionLabel = { DotID }; -BitType = { +AnyBitType = { ".b8", ".b16", ".b32", ".b64" }; diff --git a/ptx/src/test/spirv_run/brev.ptx b/ptx/src/test/spirv_run/brev.ptx new file mode 100644 index 0000000..1d9dd75 --- /dev/null +++ b/ptx/src/test/spirv_run/brev.ptx @@ -0,0 +1,21 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry brev(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .b32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.b32 temp, [in_addr];
+ brev.b32 temp, temp;
+ st.b32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/brev.spvtxt b/ptx/src/test/spirv_run/brev.spvtxt new file mode 100644 index 0000000..df5df53 --- /dev/null +++ b/ptx/src/test/spirv_run/brev.spvtxt @@ -0,0 +1,47 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %21 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "brev" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %24 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %1 = OpFunction %void None %24 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %19 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + OpStore %2 %7 + OpStore %3 %8 + %9 = OpLoad %ulong %2 + OpStore %4 %9 + %10 = OpLoad %ulong %3 + OpStore %5 %10 + %12 = OpLoad %ulong %4 + %17 = OpConvertUToPtr %_ptr_Generic_uint %12 + %11 = OpLoad %uint %17 + OpStore %6 %11 + %14 = OpLoad %uint %6 + %13 = OpBitReverse %uint %14 + OpStore %6 %13 + %15 = OpLoad %ulong %5 + %16 = OpLoad %uint %6 + %18 = OpConvertUToPtr %_ptr_Generic_uint %15 + OpStore %18 %16 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/clz.ptx b/ptx/src/test/spirv_run/clz.ptx new file mode 100644 index 0000000..b475b90 --- /dev/null +++ b/ptx/src/test/spirv_run/clz.ptx @@ -0,0 +1,21 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry clz(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .b32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.b32 temp, [in_addr];
+ clz.b32 temp, temp;
+ st.b32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/clz.spvtxt b/ptx/src/test/spirv_run/clz.spvtxt new file mode 100644 index 0000000..5d1ebc8 --- /dev/null +++ b/ptx/src/test/spirv_run/clz.spvtxt @@ -0,0 +1,47 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %21 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "clz" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %24 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %1 = OpFunction %void None %24 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %19 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + OpStore %2 %7 + OpStore %3 %8 + %9 = OpLoad %ulong %2 + OpStore %4 %9 + %10 = OpLoad %ulong %3 + OpStore %5 %10 + %12 = OpLoad %ulong %4 + %17 = OpConvertUToPtr %_ptr_Generic_uint %12 + %11 = OpLoad %uint %17 + OpStore %6 %11 + %14 = OpLoad %uint %6 + %13 = OpExtInst %uint %21 clz %14 + OpStore %6 %13 + %15 = OpLoad %ulong %5 + %16 = OpLoad %uint %6 + %18 = OpConvertUToPtr %_ptr_Generic_uint %15 + OpStore %18 %16 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 163caac..a7ef75b 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -104,11 +104,18 @@ test_ptx!(div_approx, [1f32, 2f32], [0.5f32]); test_ptx!(sqrt, [0.25f32], [0.5f32]);
test_ptx!(rsqrt, [0.25f64], [2f64]);
test_ptx!(neg, [181i32], [-181i32]);
-test_ptx!(sin, [std::f32::consts::PI/2f32], [1f32]);
+test_ptx!(sin, [std::f32::consts::PI / 2f32], [1f32]);
test_ptx!(cos, [std::f32::consts::PI], [-1f32]);
test_ptx!(lg2, [512f32], [9f32]);
test_ptx!(ex2, [10f32], [1024f32]);
test_ptx!(cvt_rni, [9.5f32, 10.5f32], [10f32, 10f32]);
+test_ptx!(clz, [0b00000101_00101101_00010011_10101011u32], [5u32]);
+test_ptx!(popc, [0b10111100_10010010_01001001_10001010u32], [14u32]);
+test_ptx!(
+ brev,
+ [0b11000111_01011100_10101110_11111011u32],
+ [0b11011111_01110101_00111010_11100011u32]
+);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/popc.ptx b/ptx/src/test/spirv_run/popc.ptx new file mode 100644 index 0000000..7106422 --- /dev/null +++ b/ptx/src/test/spirv_run/popc.ptx @@ -0,0 +1,21 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry popc(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .b32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.b32 temp, [in_addr];
+ popc.b32 temp, temp;
+ st.b32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/popc.spvtxt b/ptx/src/test/spirv_run/popc.spvtxt new file mode 100644 index 0000000..bb4968f --- /dev/null +++ b/ptx/src/test/spirv_run/popc.spvtxt @@ -0,0 +1,47 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %21 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "popc" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %24 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %1 = OpFunction %void None %24 + %7 = OpFunctionParameter %ulong + %8 = OpFunctionParameter %ulong + %19 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + OpStore %2 %7 + OpStore %3 %8 + %9 = OpLoad %ulong %2 + OpStore %4 %9 + %10 = OpLoad %ulong %3 + OpStore %5 %10 + %12 = OpLoad %ulong %4 + %17 = OpConvertUToPtr %_ptr_Generic_uint %12 + %11 = OpLoad %uint %17 + OpStore %6 %11 + %14 = OpLoad %uint %6 + %13 = OpBitCount %uint %14 + OpStore %6 %13 + %15 = OpLoad %ulong %5 + %16 = OpLoad %uint %6 + %18 = OpConvertUToPtr %_ptr_Generic_uint %15 + OpStore %18 %16 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 9519951..23a63be 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1573,6 +1573,24 @@ fn convert_to_typed_statements( arg: arg.cast(),
}))
}
+ ast::Instruction::Clz { typ, arg } => {
+ result.push(Statement::Instruction(ast::Instruction::Clz {
+ typ,
+ arg: arg.cast(),
+ }))
+ }
+ ast::Instruction::Brev { typ, arg } => {
+ result.push(Statement::Instruction(ast::Instruction::Brev {
+ typ,
+ arg: arg.cast(),
+ }))
+ }
+ ast::Instruction::Popc { typ, arg } => {
+ result.push(Statement::Instruction(ast::Instruction::Popc {
+ typ,
+ arg: arg.cast(),
+ }))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -2997,6 +3015,24 @@ fn emit_function_body_ops( [arg.src],
)?;
}
+ ast::Instruction::Clz { typ, arg } => {
+ let result_type = map.get_or_add_scalar(builder, (*typ).into());
+ builder.ext_inst(
+ result_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::clz as u32,
+ [arg.src],
+ )?;
+ }
+ ast::Instruction::Brev { typ, arg } => {
+ let result_type = map.get_or_add_scalar(builder, (*typ).into());
+ builder.bit_reverse(result_type, Some(arg.dst), arg.src)?;
+ }
+ ast::Instruction::Popc { typ, arg } => {
+ let result_type = map.get_or_add_scalar(builder, (*typ).into());
+ builder.bit_count(result_type, Some(arg.dst), arg.src)?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@@ -4881,7 +4917,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> { ast::Type::Scalar(desc.src.into()),
),
};
- ast::Instruction::Cvt(d, a.map_cvt(visitor, &dst_t, &src_t)?)
+ ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?)
}
ast::Instruction::Shl(t, a) => {
ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?)
@@ -4980,6 +5016,29 @@ impl<T: ArgParamsEx> ast::Instruction<T> { arg: arg.map(visitor, &typ)?,
}
}
+ ast::Instruction::Clz { typ, arg } => {
+ let dst_type = ast::Type::Scalar(ast::ScalarType::B32);
+ let src_type = ast::Type::Scalar(typ.into());
+ ast::Instruction::Clz {
+ typ,
+ arg: arg.map_different_types(visitor, &dst_type, &src_type)?,
+ }
+ }
+ ast::Instruction::Brev { typ, arg } => {
+ let full_type = ast::Type::Scalar(typ.into());
+ ast::Instruction::Brev {
+ typ,
+ arg: arg.map(visitor, &full_type)?,
+ }
+ }
+ ast::Instruction::Popc { typ, arg } => {
+ let dst_type = ast::Type::Scalar(ast::ScalarType::B32);
+ let src_type = ast::Type::Scalar(typ.into());
+ ast::Instruction::Popc {
+ typ,
+ arg: arg.map_different_types(visitor, &dst_type, &src_type)?,
+ }
+ }
})
}
}
@@ -5289,6 +5348,9 @@ impl ast::Instruction<ExpandedArgParams> { ast::Instruction::Cvt(ast::CvtDetails::FloatFromInt(_), _) => None,
ast::Instruction::Div(ast::DivDetails::Unsigned(_), _) => None,
ast::Instruction::Div(ast::DivDetails::Signed(_), _) => None,
+ ast::Instruction::Clz { .. } => None,
+ ast::Instruction::Brev { .. } => None,
+ ast::Instruction::Popc { .. } => None,
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
@@ -5567,7 +5629,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> { })
}
- fn map_cvt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ fn map_different_types<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
dst_t: &ast::Type,
|