diff options
-rw-r--r-- | ptx/src/ast.rs | 45 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 94 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/div_approx.ptx | 23 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/div_approx.spvtxt | 65 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 4 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/rsqrt.ptx | 21 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/rsqrt.spvtxt | 56 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/shared_ptr_32.ptx | 29 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/shared_ptr_32.spvtxt | 74 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/sqrt.ptx | 21 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/sqrt.spvtxt | 56 | ||||
-rw-r--r-- | ptx/src/translate.rs | 201 |
12 files changed, 645 insertions, 44 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index ad8e87d..f00ddce 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -539,6 +539,9 @@ pub enum Instruction<P: ArgParams> { Bar(BarDetails, Arg1Bar<P>), Atom(AtomDetails, Arg3<P>), AtomCas(AtomCasDetails, Arg4<P>), + Div(DivDetails, Arg3<P>), + Sqrt(SqrtDetails, Arg2<P>), + Rsqrt(RsqrtDetails, Arg2<P>), } #[derive(Copy, Clone)] @@ -1132,7 +1135,28 @@ pub struct AtomCasDetails { pub semantics: AtomSemantics, pub scope: MemScope, pub space: AtomSpace, - pub typ: BitType + pub typ: BitType, +} + +#[derive(Copy, Clone)] +pub enum DivDetails { + Unsigned(UIntType), + Signed(SIntType), + Float(DivFloatDetails), +} + +#[derive(Copy, Clone)] +pub struct DivFloatDetails { + pub typ: FloatType, + pub flush_to_zero: Option<bool>, + pub kind: DivFloatKind, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum DivFloatKind { + Approx, + Full, + Rounding(RoundingMode), } pub enum NumsOrArrays<'a> { @@ -1140,6 +1164,25 @@ pub enum NumsOrArrays<'a> { Arrays(Vec<NumsOrArrays<'a>>), } +#[derive(Copy, Clone)] +pub struct SqrtDetails { + pub typ: FloatType, + pub flush_to_zero: Option<bool>, + pub kind: SqrtKind, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum SqrtKind { + Approx, + Rounding(RoundingMode), +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub struct RsqrtDetails { + pub typ: FloatType, + pub flush_to_zero: bool, +} + impl<'a> NumsOrArrays<'a> { pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> { self.normalize_dimensions(dimensions)?; diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 806a3fc..4cf4255 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -66,6 +66,7 @@ match { ".f64", ".file", ".ftz", + ".full", ".func", ".ge", ".geu", @@ -94,6 +95,7 @@ match { ".num", ".or", ".param", + ".pragma", ".pred", ".reg", ".relaxed", @@ -145,6 +147,7 @@ match { "cvt", "cvta", "debug", + "div", "fma", "ld", "mad", @@ -157,11 +160,13 @@ match { "or", "rcp", "ret", + "rsqrt", "selp", "setp", "shl", "shr", r"sm_[0-9]+" => ShaderModel, + "sqrt", "st", "sub", "texmode_independent", @@ -184,6 +189,7 @@ ExtendedID : &'input str = { "cvt", "cvta", "debug", + "div", "fma", "ld", "mad", @@ -196,11 +202,13 @@ ExtendedID : &'input str = { "or", "rcp", "ret", + "rsqrt", "selp", "setp", "shl", "shr", ShaderModel, + "sqrt", "st", "sub", "texmode_independent", @@ -415,9 +423,14 @@ Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = { DebugDirective => None, <v:MultiVariable> ";" => Some(ast::Statement::Variable(v)), <p:PredAt?> <i:Instruction> ";" => Some(ast::Statement::Instruction(p, i)), + PragmaStatement => None, "{" <s:Statement*> "}" => Some(ast::Statement::Block(without_none(s))) }; +PragmaStatement: () = { + ".pragma" String ";" +} + DebugDirective: () = { DebugLocation }; @@ -667,7 +680,10 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { InstSelp, InstBar, InstAtom, - InstAtomCas + InstAtomCas, + InstDiv, + InstSqrt, + InstRsqrt, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -1485,6 +1501,82 @@ AtomSIntType: ast::SIntType = { ".s64" => ast::SIntType::S64, } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-div +InstDiv: ast::Instruction<ast::ParsedArgParams<'input>> = { + "div" <t:UIntType> <a:Arg3> => ast::Instruction::Div(ast::DivDetails::Unsigned(t), a), + "div" <t:SIntType> <a:Arg3> => ast::Instruction::Div(ast::DivDetails::Signed(t), a), + "div" <kind:DivFloatKind> <ftz:".ftz"?> ".f32" <a:Arg3> => { + let inner = ast::DivFloatDetails { + typ: ast::FloatType::F32, + flush_to_zero: Some(ftz.is_some()), + kind + }; + ast::Instruction::Div(ast::DivDetails::Float(inner), a) + }, + "div" <rnd:RoundingModeFloat> ".f64" <a:Arg3> => { + let inner = ast::DivFloatDetails { + typ: ast::FloatType::F64, + flush_to_zero: None, + kind: ast::DivFloatKind::Rounding(rnd) + }; + ast::Instruction::Div(ast::DivDetails::Float(inner), a) + }, +} + +DivFloatKind: ast::DivFloatKind = { + ".approx" => ast::DivFloatKind::Approx, + ".full" => ast::DivFloatKind::Full, + <rnd:RoundingModeFloat> => ast::DivFloatKind::Rounding(rnd), +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sqrt +InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = { + "sqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => { + let details = ast::SqrtDetails { + typ: ast::FloatType::F32, + flush_to_zero: Some(ftz.is_some()), + kind: ast::SqrtKind::Approx, + }; + ast::Instruction::Sqrt(details, a) + }, + "sqrt" <rnd:RoundingModeFloat> <ftz:".ftz"?> ".f32" <a:Arg2> => { + let details = ast::SqrtDetails { + typ: ast::FloatType::F32, + flush_to_zero: Some(ftz.is_some()), + kind: ast::SqrtKind::Rounding(rnd), + }; + ast::Instruction::Sqrt(details, a) + }, + "sqrt" <rnd:RoundingModeFloat> ".f64" <a:Arg2> => { + let details = ast::SqrtDetails { + typ: ast::FloatType::F64, + flush_to_zero: None, + kind: ast::SqrtKind::Rounding(rnd), + }; + ast::Instruction::Sqrt(details, a) + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64 +InstRsqrt: ast::Instruction<ast::ParsedArgParams<'input>> = { + "rsqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => { + let details = ast::RsqrtDetails { + typ: ast::FloatType::F32, + flush_to_zero: ftz.is_some(), + }; + ast::Instruction::Rsqrt(details, a) + }, + "rsqrt" ".approx" <ftz:".ftz"?> ".f64" <a:Arg2> => { + let details = ast::RsqrtDetails { + typ: ast::FloatType::F64, + flush_to_zero: ftz.is_some(), + }; + ast::Instruction::Rsqrt(details, a) + }, +} + ArithDetails: ast::ArithDetails = { <t:UIntType> => ast::ArithDetails::Unsigned(t), <t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt { diff --git a/ptx/src/test/spirv_run/div_approx.ptx b/ptx/src/test/spirv_run/div_approx.ptx new file mode 100644 index 0000000..b25e320 --- /dev/null +++ b/ptx/src/test/spirv_run/div_approx.ptx @@ -0,0 +1,23 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry div_approx(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .f32 temp1;
+ .reg .f32 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.f32 temp1, [in_addr];
+ ld.f32 temp2, [in_addr+4];
+ div.approx.f32 temp1, temp1, temp2;
+ st.f32 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/div_approx.spvtxt b/ptx/src/test/spirv_run/div_approx.spvtxt new file mode 100644 index 0000000..40cc152 --- /dev/null +++ b/ptx/src/test/spirv_run/div_approx.spvtxt @@ -0,0 +1,65 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 38 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" +%30 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "div_approx" +OpDecorate %1 FunctionDenormModeINTEL 32 Preserve +OpDecorate %18 FPFastMathMode AllowRecip +%31 = OpTypeVoid +%32 = OpTypeInt 64 0 +%33 = OpTypeFunction %31 %32 %32 +%34 = OpTypePointer Function %32 +%35 = OpTypeFloat 32 +%36 = OpTypePointer Function %35 +%37 = OpTypePointer Generic %35 +%23 = OpConstant %32 4 +%1 = OpFunction %31 None %33 +%8 = OpFunctionParameter %32 +%9 = OpFunctionParameter %32 +%28 = OpLabel +%2 = OpVariable %34 Function +%3 = OpVariable %34 Function +%4 = OpVariable %34 Function +%5 = OpVariable %34 Function +%6 = OpVariable %36 Function +%7 = OpVariable %36 Function +OpStore %2 %8 +OpStore %3 %9 +%11 = OpLoad %32 %2 +%10 = OpCopyObject %32 %11 +OpStore %4 %10 +%13 = OpLoad %32 %3 +%12 = OpCopyObject %32 %13 +OpStore %5 %12 +%15 = OpLoad %32 %4 +%25 = OpConvertUToPtr %37 %15 +%14 = OpLoad %35 %25 +OpStore %6 %14 +%17 = OpLoad %32 %4 +%24 = OpIAdd %32 %17 %23 +%26 = OpConvertUToPtr %37 %24 +%16 = OpLoad %35 %26 +OpStore %7 %16 +%19 = OpLoad %35 %6 +%20 = OpLoad %35 %7 +%18 = OpFDiv %35 %19 %20 +OpStore %6 %18 +%21 = OpLoad %32 %5 +%22 = OpLoad %35 %6 +%27 = OpConvertUToPtr %37 %21 +OpStore %27 %22 +OpReturn +OpFunctionEnd
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 40a9d64..4e9d39f 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -97,9 +97,13 @@ test_ptx!(and, [6u32, 3u32], [2u32]); test_ptx!(selp, [100u16, 200u16], [200u16]);
test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]);
test_ptx!(shared_variable, [513u64], [513u64]);
+test_ptx!(shared_ptr_32, [513u64], [513u64]);
test_ptx!(atom_cas, [91u32, 91u32], [91u32, 100u32]);
test_ptx!(atom_inc, [100u32], [100u32, 101u32, 0u32]);
test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]);
+test_ptx!(div_approx, [1f32, 2f32], [0.5f32]);
+test_ptx!(sqrt, [0.25f32], [0.5f32]);
+test_ptx!(rsqrt, [0.25f64], [2f64]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/rsqrt.ptx b/ptx/src/test/spirv_run/rsqrt.ptx new file mode 100644 index 0000000..5821501 --- /dev/null +++ b/ptx/src/test/spirv_run/rsqrt.ptx @@ -0,0 +1,21 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry rsqrt(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .f64 temp1;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.f64 temp1, [in_addr];
+ rsqrt.approx.f64 temp1, temp1;
+ st.f64 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/rsqrt.spvtxt b/ptx/src/test/spirv_run/rsqrt.spvtxt new file mode 100644 index 0000000..5c3ba97 --- /dev/null +++ b/ptx/src/test/spirv_run/rsqrt.spvtxt @@ -0,0 +1,56 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 31 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" +%23 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "rsqrt" +OpDecorate %1 FunctionDenormModeINTEL 64 Preserve +%24 = OpTypeVoid +%25 = OpTypeInt 64 0 +%26 = OpTypeFunction %24 %25 %25 +%27 = OpTypePointer Function %25 +%28 = OpTypeFloat 64 +%29 = OpTypePointer Function %28 +%30 = OpTypePointer Generic %28 +%1 = OpFunction %24 None %26 +%7 = OpFunctionParameter %25 +%8 = OpFunctionParameter %25 +%21 = OpLabel +%2 = OpVariable %27 Function +%3 = OpVariable %27 Function +%4 = OpVariable %27 Function +%5 = OpVariable %27 Function +%6 = OpVariable %29 Function +OpStore %2 %7 +OpStore %3 %8 +%10 = OpLoad %25 %2 +%9 = OpCopyObject %25 %10 +OpStore %4 %9 +%12 = OpLoad %25 %3 +%11 = OpCopyObject %25 %12 +OpStore %5 %11 +%14 = OpLoad %25 %4 +%19 = OpConvertUToPtr %30 %14 +%13 = OpLoad %28 %19 +OpStore %6 %13 +%16 = OpLoad %28 %6 +%15 = OpExtInst %28 %23 native_rsqrt %16 +OpStore %6 %15 +%17 = OpLoad %25 %5 +%18 = OpLoad %28 %6 +%20 = OpConvertUToPtr %30 %17 +OpStore %20 %18 +OpReturn +OpFunctionEnd
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/shared_ptr_32.ptx b/ptx/src/test/spirv_run/shared_ptr_32.ptx new file mode 100644 index 0000000..0334aa0 --- /dev/null +++ b/ptx/src/test/spirv_run/shared_ptr_32.ptx @@ -0,0 +1,29 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+
+.visible .entry shared_ptr_32(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .shared .align 4 .b8 shared_mem1[128];
+
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u32 shared_addr;
+
+ .reg .u64 temp1;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+ mov.u32 shared_addr, shared_mem1;
+
+ ld.global.u64 temp1, [in_addr];
+ st.shared.u64 [shared_addr], temp1;
+ ld.shared.u64 temp2, [shared_addr+0];
+ st.global.u64 [out_addr], temp2;
+ ret;
+}
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt new file mode 100644 index 0000000..609cc0e --- /dev/null +++ b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt @@ -0,0 +1,74 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 47 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" +%34 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "shared_ptr_32" %4 +OpDecorate %4 Alignment 4 +%35 = OpTypeVoid +%36 = OpTypeInt 32 0 +%37 = OpTypeInt 8 0 +%38 = OpConstant %36 128 +%39 = OpTypeArray %37 %38 +%40 = OpTypePointer Workgroup %39 +%4 = OpVariable %40 Workgroup +%41 = OpTypeInt 64 0 +%42 = OpTypeFunction %35 %41 %41 +%43 = OpTypePointer Function %41 +%44 = OpTypePointer Function %36 +%45 = OpTypePointer CrossWorkgroup %41 +%46 = OpTypePointer Workgroup %41 +%25 = OpConstant %36 0 +%1 = OpFunction %35 None %42 +%10 = OpFunctionParameter %41 +%11 = OpFunctionParameter %41 +%32 = OpLabel +%2 = OpVariable %43 Function +%3 = OpVariable %43 Function +%5 = OpVariable %43 Function +%6 = OpVariable %43 Function +%7 = OpVariable %44 Function +%8 = OpVariable %43 Function +%9 = OpVariable %43 Function +OpStore %2 %10 +OpStore %3 %11 +%13 = OpLoad %41 %2 +%12 = OpCopyObject %41 %13 +OpStore %5 %12 +%15 = OpLoad %41 %3 +%14 = OpCopyObject %41 %15 +OpStore %6 %14 +%27 = OpConvertPtrToU %36 %4 +%16 = OpCopyObject %36 %27 +OpStore %7 %16 +%18 = OpLoad %41 %5 +%28 = OpConvertUToPtr %45 %18 +%17 = OpLoad %41 %28 +OpStore %8 %17 +%19 = OpLoad %36 %7 +%20 = OpLoad %41 %8 +%29 = OpConvertUToPtr %46 %19 +OpStore %29 %20 +%22 = OpLoad %36 %7 +%26 = OpIAdd %36 %22 %25 +%30 = OpConvertUToPtr %46 %26 +%21 = OpLoad %41 %30 +OpStore %9 %21 +%23 = OpLoad %41 %6 +%24 = OpLoad %41 %9 +%31 = OpConvertUToPtr %45 %23 +OpStore %31 %24 +OpReturn +OpFunctionEnd
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/sqrt.ptx b/ptx/src/test/spirv_run/sqrt.ptx new file mode 100644 index 0000000..8b42f34 --- /dev/null +++ b/ptx/src/test/spirv_run/sqrt.ptx @@ -0,0 +1,21 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry sqrt(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .f32 temp1;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.f32 temp1, [in_addr];
+ sqrt.approx.f32 temp1, temp1;
+ st.f32 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/sqrt.spvtxt b/ptx/src/test/spirv_run/sqrt.spvtxt new file mode 100644 index 0000000..d2c5b20 --- /dev/null +++ b/ptx/src/test/spirv_run/sqrt.spvtxt @@ -0,0 +1,56 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 31 +OpCapability GenericPointer +OpCapability Linkage +OpCapability Addresses +OpCapability Kernel +OpCapability Int8 +OpCapability Int16 +OpCapability Int64 +OpCapability Float16 +OpCapability Float64 +; OpCapability FunctionFloatControlINTEL +; OpExtension "SPV_INTEL_float_controls2" +%23 = OpExtInstImport "OpenCL.std" +OpMemoryModel Physical64 OpenCL +OpEntryPoint Kernel %1 "sqrt" +OpDecorate %1 FunctionDenormModeINTEL 32 Preserve +%24 = OpTypeVoid +%25 = OpTypeInt 64 0 +%26 = OpTypeFunction %24 %25 %25 +%27 = OpTypePointer Function %25 +%28 = OpTypeFloat 32 +%29 = OpTypePointer Function %28 +%30 = OpTypePointer Generic %28 +%1 = OpFunction %24 None %26 +%7 = OpFunctionParameter %25 +%8 = OpFunctionParameter %25 +%21 = OpLabel +%2 = OpVariable %27 Function +%3 = OpVariable %27 Function +%4 = OpVariable %27 Function +%5 = OpVariable %27 Function +%6 = OpVariable %29 Function +OpStore %2 %7 +OpStore %3 %8 +%10 = OpLoad %25 %2 +%9 = OpCopyObject %25 %10 +OpStore %4 %9 +%12 = OpLoad %25 %3 +%11 = OpCopyObject %25 %12 +OpStore %5 %11 +%14 = OpLoad %25 %4 +%19 = OpConvertUToPtr %30 %14 +%13 = OpLoad %28 %19 +OpStore %6 %13 +%16 = OpLoad %28 %6 +%15 = OpExtInst %28 %23 native_sqrt %16 +OpStore %6 %15 +%17 = OpLoad %25 %5 +%18 = OpLoad %28 %6 +%20 = OpConvertUToPtr %30 %17 +OpStore %20 %18 +OpReturn +OpFunctionEnd
\ No newline at end of file diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 6b07c0f..c351ccd 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,8 +1,11 @@ use crate::ast;
use half::f16;
use rspirv::{binary::Disassemble, dr};
-use std::collections::{hash_map, HashMap, HashSet};
use std::{borrow::Cow, hash::Hash, iter, mem};
+use std::{
+ collections::{hash_map, HashMap, HashSet},
+ convert::TryInto,
+};
use rspirv::binary::Assemble;
@@ -1499,6 +1502,15 @@ fn convert_to_typed_statements( ast::Instruction::AtomCas(d, a) => result.push(Statement::Instruction(
ast::Instruction::AtomCas(d, a.cast()),
)),
+ ast::Instruction::Div(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Div(d, a.cast())))
+ }
+ ast::Instruction::Sqrt(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Sqrt(d, a.cast())))
+ }
+ ast::Instruction::Rsqrt(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast())))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -1982,7 +1994,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { | ArgumentSemantics::DefaultRelaxed
| ArgumentSemantics::PhysicalPointer => {
if desc.sema == ArgumentSemantics::PhysicalPointer {
- typ = ast::Type::Scalar(ast::ScalarType::U64);
+ typ = self.id_def.get_typed(reg)?;
}
let (width, kind) = match typ {
ast::Type::Scalar(scalar_t) => {
@@ -2013,7 +2025,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::from_parts(width, kind),
- value: ast::ImmediateValue::S64(-(offset as i64)),
+ value: ast::ImmediateValue::U64(-(offset as i64) as u64),
}));
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Sub(
@@ -2765,6 +2777,34 @@ fn emit_function_body_ops( arg.src2,
)?;
}
+ ast::Instruction::Div(details, arg) => match details {
+ ast::DivDetails::Unsigned(t) => {
+ let result_type = map.get_or_add_scalar(builder, (*t).into());
+ builder.u_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
+ }
+ ast::DivDetails::Signed(t) => {
+ let result_type = map.get_or_add_scalar(builder, (*t).into());
+ builder.s_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
+ }
+ ast::DivDetails::Float(t) => {
+ let result_type = map.get_or_add_scalar(builder, t.typ.into());
+ builder.f_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
+ emit_float_div_decoration(builder, arg.dst, t.kind);
+ }
+ },
+ ast::Instruction::Sqrt(details, a) => {
+ emit_sqrt(builder, map, opencl, details, a)?;
+ }
+ ast::Instruction::Rsqrt(details, a) => {
+ let result_type = map.get_or_add_scalar(builder, details.typ.into());
+ builder.ext_inst(
+ result_type,
+ Some(a.dst),
+ opencl,
+ spirv::CLOp::native_rsqrt as spirv::Word,
+ &[a.src],
+ )?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@@ -2795,6 +2835,47 @@ fn emit_function_body_ops( Ok(())
}
+fn emit_sqrt(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ details: &ast::SqrtDetails,
+ a: &ast::Arg2<ExpandedArgParams>,
+) -> Result<(), TranslateError> {
+ let result_type = map.get_or_add_scalar(builder, details.typ.into());
+ let (ocl_op, rounding) = match details.kind {
+ ast::SqrtKind::Approx => (spirv::CLOp::native_sqrt, None),
+ ast::SqrtKind::Rounding(rnd) => (spirv::CLOp::sqrt, Some(rnd)),
+ };
+ builder.ext_inst(
+ result_type,
+ Some(a.dst),
+ opencl,
+ ocl_op as spirv::Word,
+ &[a.src],
+ )?;
+ emit_rounding_decoration(builder, a.dst, rounding);
+ Ok(())
+}
+
+fn emit_float_div_decoration(builder: &mut dr::Builder, dst: spirv::Word, kind: ast::DivFloatKind) {
+ match kind {
+ ast::DivFloatKind::Approx => {
+ builder.decorate(
+ dst,
+ spirv::Decoration::FPFastMathMode,
+ &[dr::Operand::FPFastMathMode(
+ spirv::FPFastMathMode::ALLOW_RECIP,
+ )],
+ );
+ }
+ ast::DivFloatKind::Rounding(rnd) => {
+ emit_rounding_decoration(builder, dst, Some(rnd));
+ }
+ ast::DivFloatKind::Full => {}
+ }
+}
+
fn emit_atom(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -3307,7 +3388,25 @@ fn emit_setp( (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
- _ => todo!(),
+ (ast::SetpCompareOp::NanEq, _) => {
+ builder.f_unord_equal(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanNotEq, _) => {
+ builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanLess, _) => {
+ builder.f_unord_less_than(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanLessOrEq, _) => {
+ builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanGreater, _) => {
+ builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanGreaterOrEq, _) => {
+ builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2)
+ }
+ _ => todo!()
}?;
Ok(())
}
@@ -3486,8 +3585,8 @@ fn emit_implicit_conversion( let from_parts = cv.from.to_parts();
let to_parts = cv.to.to_parts();
match (from_parts.kind, to_parts.kind, cv.kind) {
- (_, _, ConversionKind::PtrToBit) => {
- let dst_type = map.get_or_add_scalar(builder, ast::ScalarType::B64);
+ (_, _, ConversionKind::PtrToBit(typ)) => {
+ let dst_type = map.get_or_add_scalar(builder, typ.into());
builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
}
(_, _, ConversionKind::BitToPtr(_)) => {
@@ -4570,6 +4669,15 @@ impl<T: ArgParamsEx> ast::Instruction<T> { ast::Instruction::AtomCas(d, a) => {
ast::Instruction::AtomCas(d, a.map_atom(visitor, d.typ, d.space)?)
}
+ ast::Instruction::Div(d, a) => {
+ ast::Instruction::Div(d, a.map_non_shift(visitor, &d.get_type(), false)?)
+ }
+ ast::Instruction::Sqrt(d, a) => {
+ ast::Instruction::Sqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
+ }
+ ast::Instruction::Rsqrt(d, a) => {
+ ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
+ }
})
}
}
@@ -4794,32 +4902,7 @@ impl ast::Instruction<ExpandedArgParams> { fn jump_target(&self) -> Option<spirv::Word> {
match self {
ast::Instruction::Bra(_, a) => Some(a.src),
- ast::Instruction::Ld(_, _)
- | ast::Instruction::Mov(_, _)
- | ast::Instruction::Mul(_, _)
- | ast::Instruction::Add(_, _)
- | ast::Instruction::Setp(_, _)
- | ast::Instruction::SetpBool(_, _)
- | ast::Instruction::Not(_, _)
- | ast::Instruction::Cvt(_, _)
- | ast::Instruction::Cvta(_, _)
- | ast::Instruction::Shl(_, _)
- | ast::Instruction::Shr(_, _)
- | ast::Instruction::St(_, _)
- | ast::Instruction::Ret(_)
- | ast::Instruction::Abs(_, _)
- | ast::Instruction::Call(_)
- | ast::Instruction::Or(_, _)
- | ast::Instruction::Sub(_, _)
- | ast::Instruction::Min(_, _)
- | ast::Instruction::Max(_, _)
- | ast::Instruction::Rcp(_, _)
- | ast::Instruction::And(_, _)
- | ast::Instruction::Selp(_, _)
- | ast::Instruction::Bar(_, _)
- | ast::Instruction::Atom(_, _)
- | ast::Instruction::AtomCas(_, _)
- | ast::Instruction::Mad(_, _) => None,
+ _ => None,
}
}
@@ -4856,6 +4939,9 @@ impl ast::Instruction<ExpandedArgParams> { ast::Instruction::Max(ast::MinMaxDetails::Signed(_), _) => None,
ast::Instruction::Max(ast::MinMaxDetails::Unsigned(_), _) => None,
ast::Instruction::Cvt(ast::CvtDetails::IntFromInt(_), _) => None,
+ ast::Instruction::Cvt(ast::CvtDetails::FloatFromInt(_), _) => None,
+ ast::Instruction::Div(ast::DivDetails::Unsigned(_), _) => None,
+ ast::Instruction::Div(ast::DivDetails::Signed(_), _) => None,
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
@@ -4885,13 +4971,19 @@ impl ast::Instruction<ExpandedArgParams> { _,
)
| ast::Instruction::Cvt(
- ast::CvtDetails::FloatFromInt(ast::CvtDesc { flush_to_zero, .. }),
- _,
- )
- | ast::Instruction::Cvt(
ast::CvtDetails::IntFromFloat(ast::CvtDesc { flush_to_zero, .. }),
_,
) => flush_to_zero.map(|ftz| (ftz, 4)),
+ ast::Instruction::Div(ast::DivDetails::Float(details), _) => details
+ .flush_to_zero
+ .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())),
+ ast::Instruction::Sqrt(details, _) => details
+ .flush_to_zero
+ .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())),
+ ast::Instruction::Rsqrt(details, _) => Some((
+ details.flush_to_zero,
+ ast::ScalarType::from(details.typ).size_of(),
+ )),
}
}
}
@@ -4978,13 +5070,13 @@ struct ImplicitConversion { kind: ConversionKind,
}
-#[derive(Debug, PartialEq, Copy, Clone)]
+#[derive(PartialEq, Copy, Clone)]
enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
SignExtend,
BitToPtr(ast::LdStateSpace),
- PtrToBit,
+ PtrToBit(ast::UIntType),
PtrToPtr { spirv_ptr: bool },
}
@@ -6027,6 +6119,16 @@ impl ast::MinMaxDetails { }
}
+impl ast::DivDetails {
+ fn get_type(&self) -> ast::Type {
+ ast::Type::Scalar(match self {
+ ast::DivDetails::Unsigned(t) => (*t).into(),
+ ast::DivDetails::Signed(t) => (*t).into(),
+ ast::DivDetails::Float(d) => d.typ.into(),
+ })
+ }
+}
+
impl ast::AtomInnerDetails {
fn get_type(&self) -> ast::ScalarType {
match self {
@@ -6193,6 +6295,15 @@ fn bitcast_physical_pointer( Err(TranslateError::Unreachable)
}
}
+ ast::Type::Scalar(ast::ScalarType::B32)
+ | ast::Type::Scalar(ast::ScalarType::U32)
+ | ast::Type::Scalar(ast::ScalarType::S32) => {
+ if let Some(ast::LdStateSpace::Shared) = ss {
+ Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared)))
+ } else {
+ Err(TranslateError::MismatchedType)
+ }
+ }
ast::Type::Pointer(op_scalar_t, op_space) => {
if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
if op_space == instr_space {
@@ -6220,10 +6331,16 @@ fn bitcast_physical_pointer( fn force_bitcast_ptr_to_bit(
_: &ast::Type,
- _: &ast::Type,
+ instr_type: &ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
- Ok(Some(ConversionKind::PtrToBit))
+ // TODO: verify this on f32, u16 and the like
+ if let ast::Type::Scalar(scalar_t) = instr_type {
+ if let Ok(int_type) = (*scalar_t).try_into() {
+ return Ok(Some(ConversionKind::PtrToBit(int_type)));
+ }
+ }
+ Err(TranslateError::MismatchedType)
}
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
@@ -6542,9 +6659,9 @@ mod tests { &ast::Type::Scalar(*instr_type),
);
if instr_idx == op_idx {
- assert_eq!(conversion, None);
+ assert!(conversion == None);
} else {
- assert_eq!(conversion, conv_table[instr_idx][op_idx]);
+ assert!(conversion == conv_table[instr_idx][op_idx]);
}
}
}
|