aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
Diffstat (limited to 'ptx')
-rw-r--r--ptx/src/ast.rs45
-rw-r--r--ptx/src/ptx.lalrpop94
-rw-r--r--ptx/src/test/spirv_run/div_approx.ptx23
-rw-r--r--ptx/src/test/spirv_run/div_approx.spvtxt65
-rw-r--r--ptx/src/test/spirv_run/mod.rs4
-rw-r--r--ptx/src/test/spirv_run/rsqrt.ptx21
-rw-r--r--ptx/src/test/spirv_run/rsqrt.spvtxt56
-rw-r--r--ptx/src/test/spirv_run/shared_ptr_32.ptx29
-rw-r--r--ptx/src/test/spirv_run/shared_ptr_32.spvtxt74
-rw-r--r--ptx/src/test/spirv_run/sqrt.ptx21
-rw-r--r--ptx/src/test/spirv_run/sqrt.spvtxt56
-rw-r--r--ptx/src/translate.rs201
12 files changed, 645 insertions, 44 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index ad8e87d..f00ddce 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -539,6 +539,9 @@ pub enum Instruction<P: ArgParams> {
Bar(BarDetails, Arg1Bar<P>),
Atom(AtomDetails, Arg3<P>),
AtomCas(AtomCasDetails, Arg4<P>),
+ Div(DivDetails, Arg3<P>),
+ Sqrt(SqrtDetails, Arg2<P>),
+ Rsqrt(RsqrtDetails, Arg2<P>),
}
#[derive(Copy, Clone)]
@@ -1132,7 +1135,28 @@ pub struct AtomCasDetails {
pub semantics: AtomSemantics,
pub scope: MemScope,
pub space: AtomSpace,
- pub typ: BitType
+ pub typ: BitType,
+}
+
+#[derive(Copy, Clone)]
+pub enum DivDetails {
+ Unsigned(UIntType),
+ Signed(SIntType),
+ Float(DivFloatDetails),
+}
+
+#[derive(Copy, Clone)]
+pub struct DivFloatDetails {
+ pub typ: FloatType,
+ pub flush_to_zero: Option<bool>,
+ pub kind: DivFloatKind,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum DivFloatKind {
+ Approx,
+ Full,
+ Rounding(RoundingMode),
}
pub enum NumsOrArrays<'a> {
@@ -1140,6 +1164,25 @@ pub enum NumsOrArrays<'a> {
Arrays(Vec<NumsOrArrays<'a>>),
}
+#[derive(Copy, Clone)]
+pub struct SqrtDetails {
+ pub typ: FloatType,
+ pub flush_to_zero: Option<bool>,
+ pub kind: SqrtKind,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum SqrtKind {
+ Approx,
+ Rounding(RoundingMode),
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub struct RsqrtDetails {
+ pub typ: FloatType,
+ pub flush_to_zero: bool,
+}
+
impl<'a> NumsOrArrays<'a> {
pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
self.normalize_dimensions(dimensions)?;
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 806a3fc..4cf4255 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -66,6 +66,7 @@ match {
".f64",
".file",
".ftz",
+ ".full",
".func",
".ge",
".geu",
@@ -94,6 +95,7 @@ match {
".num",
".or",
".param",
+ ".pragma",
".pred",
".reg",
".relaxed",
@@ -145,6 +147,7 @@ match {
"cvt",
"cvta",
"debug",
+ "div",
"fma",
"ld",
"mad",
@@ -157,11 +160,13 @@ match {
"or",
"rcp",
"ret",
+ "rsqrt",
"selp",
"setp",
"shl",
"shr",
r"sm_[0-9]+" => ShaderModel,
+ "sqrt",
"st",
"sub",
"texmode_independent",
@@ -184,6 +189,7 @@ ExtendedID : &'input str = {
"cvt",
"cvta",
"debug",
+ "div",
"fma",
"ld",
"mad",
@@ -196,11 +202,13 @@ ExtendedID : &'input str = {
"or",
"rcp",
"ret",
+ "rsqrt",
"selp",
"setp",
"shl",
"shr",
ShaderModel,
+ "sqrt",
"st",
"sub",
"texmode_independent",
@@ -415,9 +423,14 @@ Statement: Option<ast::Statement<ast::ParsedArgParams<'input>>> = {
DebugDirective => None,
<v:MultiVariable> ";" => Some(ast::Statement::Variable(v)),
<p:PredAt?> <i:Instruction> ";" => Some(ast::Statement::Instruction(p, i)),
+ PragmaStatement => None,
"{" <s:Statement*> "}" => Some(ast::Statement::Block(without_none(s)))
};
+PragmaStatement: () = {
+ ".pragma" String ";"
+}
+
DebugDirective: () = {
DebugLocation
};
@@ -667,7 +680,10 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstSelp,
InstBar,
InstAtom,
- InstAtomCas
+ InstAtomCas,
+ InstDiv,
+ InstSqrt,
+ InstRsqrt,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@@ -1485,6 +1501,82 @@ AtomSIntType: ast::SIntType = {
".s64" => ast::SIntType::S64,
}
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-div
+InstDiv: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "div" <t:UIntType> <a:Arg3> => ast::Instruction::Div(ast::DivDetails::Unsigned(t), a),
+ "div" <t:SIntType> <a:Arg3> => ast::Instruction::Div(ast::DivDetails::Signed(t), a),
+ "div" <kind:DivFloatKind> <ftz:".ftz"?> ".f32" <a:Arg3> => {
+ let inner = ast::DivFloatDetails {
+ typ: ast::FloatType::F32,
+ flush_to_zero: Some(ftz.is_some()),
+ kind
+ };
+ ast::Instruction::Div(ast::DivDetails::Float(inner), a)
+ },
+ "div" <rnd:RoundingModeFloat> ".f64" <a:Arg3> => {
+ let inner = ast::DivFloatDetails {
+ typ: ast::FloatType::F64,
+ flush_to_zero: None,
+ kind: ast::DivFloatKind::Rounding(rnd)
+ };
+ ast::Instruction::Div(ast::DivDetails::Float(inner), a)
+ },
+}
+
+DivFloatKind: ast::DivFloatKind = {
+ ".approx" => ast::DivFloatKind::Approx,
+ ".full" => ast::DivFloatKind::Full,
+ <rnd:RoundingModeFloat> => ast::DivFloatKind::Rounding(rnd),
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sqrt
+InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "sqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => {
+ let details = ast::SqrtDetails {
+ typ: ast::FloatType::F32,
+ flush_to_zero: Some(ftz.is_some()),
+ kind: ast::SqrtKind::Approx,
+ };
+ ast::Instruction::Sqrt(details, a)
+ },
+ "sqrt" <rnd:RoundingModeFloat> <ftz:".ftz"?> ".f32" <a:Arg2> => {
+ let details = ast::SqrtDetails {
+ typ: ast::FloatType::F32,
+ flush_to_zero: Some(ftz.is_some()),
+ kind: ast::SqrtKind::Rounding(rnd),
+ };
+ ast::Instruction::Sqrt(details, a)
+ },
+ "sqrt" <rnd:RoundingModeFloat> ".f64" <a:Arg2> => {
+ let details = ast::SqrtDetails {
+ typ: ast::FloatType::F64,
+ flush_to_zero: None,
+ kind: ast::SqrtKind::Rounding(rnd),
+ };
+ ast::Instruction::Sqrt(details, a)
+ }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64
+InstRsqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "rsqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => {
+ let details = ast::RsqrtDetails {
+ typ: ast::FloatType::F32,
+ flush_to_zero: ftz.is_some(),
+ };
+ ast::Instruction::Rsqrt(details, a)
+ },
+ "rsqrt" ".approx" <ftz:".ftz"?> ".f64" <a:Arg2> => {
+ let details = ast::RsqrtDetails {
+ typ: ast::FloatType::F64,
+ flush_to_zero: ftz.is_some(),
+ };
+ ast::Instruction::Rsqrt(details, a)
+ },
+}
+
ArithDetails: ast::ArithDetails = {
<t:UIntType> => ast::ArithDetails::Unsigned(t),
<t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt {
diff --git a/ptx/src/test/spirv_run/div_approx.ptx b/ptx/src/test/spirv_run/div_approx.ptx
new file mode 100644
index 0000000..b25e320
--- /dev/null
+++ b/ptx/src/test/spirv_run/div_approx.ptx
@@ -0,0 +1,23 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry div_approx(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .f32 temp1;
+ .reg .f32 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.f32 temp1, [in_addr];
+ ld.f32 temp2, [in_addr+4];
+ div.approx.f32 temp1, temp1, temp2;
+ st.f32 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/div_approx.spvtxt b/ptx/src/test/spirv_run/div_approx.spvtxt
new file mode 100644
index 0000000..40cc152
--- /dev/null
+++ b/ptx/src/test/spirv_run/div_approx.spvtxt
@@ -0,0 +1,65 @@
+; SPIR-V
+; Version: 1.3
+; Generator: rspirv
+; Bound: 38
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int8
+OpCapability Int16
+OpCapability Int64
+OpCapability Float16
+OpCapability Float64
+; OpCapability FunctionFloatControlINTEL
+; OpExtension "SPV_INTEL_float_controls2"
+%30 = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical64 OpenCL
+OpEntryPoint Kernel %1 "div_approx"
+OpDecorate %1 FunctionDenormModeINTEL 32 Preserve
+OpDecorate %18 FPFastMathMode AllowRecip
+%31 = OpTypeVoid
+%32 = OpTypeInt 64 0
+%33 = OpTypeFunction %31 %32 %32
+%34 = OpTypePointer Function %32
+%35 = OpTypeFloat 32
+%36 = OpTypePointer Function %35
+%37 = OpTypePointer Generic %35
+%23 = OpConstant %32 4
+%1 = OpFunction %31 None %33
+%8 = OpFunctionParameter %32
+%9 = OpFunctionParameter %32
+%28 = OpLabel
+%2 = OpVariable %34 Function
+%3 = OpVariable %34 Function
+%4 = OpVariable %34 Function
+%5 = OpVariable %34 Function
+%6 = OpVariable %36 Function
+%7 = OpVariable %36 Function
+OpStore %2 %8
+OpStore %3 %9
+%11 = OpLoad %32 %2
+%10 = OpCopyObject %32 %11
+OpStore %4 %10
+%13 = OpLoad %32 %3
+%12 = OpCopyObject %32 %13
+OpStore %5 %12
+%15 = OpLoad %32 %4
+%25 = OpConvertUToPtr %37 %15
+%14 = OpLoad %35 %25
+OpStore %6 %14
+%17 = OpLoad %32 %4
+%24 = OpIAdd %32 %17 %23
+%26 = OpConvertUToPtr %37 %24
+%16 = OpLoad %35 %26
+OpStore %7 %16
+%19 = OpLoad %35 %6
+%20 = OpLoad %35 %7
+%18 = OpFDiv %35 %19 %20
+OpStore %6 %18
+%21 = OpLoad %32 %5
+%22 = OpLoad %35 %6
+%27 = OpConvertUToPtr %37 %21
+OpStore %27 %22
+OpReturn
+OpFunctionEnd \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 40a9d64..4e9d39f 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -97,9 +97,13 @@ test_ptx!(and, [6u32, 3u32], [2u32]);
test_ptx!(selp, [100u16, 200u16], [200u16]);
test_ptx!(fma, [2f32, 3f32, 5f32], [11f32]);
test_ptx!(shared_variable, [513u64], [513u64]);
+test_ptx!(shared_ptr_32, [513u64], [513u64]);
test_ptx!(atom_cas, [91u32, 91u32], [91u32, 100u32]);
test_ptx!(atom_inc, [100u32], [100u32, 101u32, 0u32]);
test_ptx!(atom_add, [2u32, 4u32], [2u32, 6u32]);
+test_ptx!(div_approx, [1f32, 2f32], [0.5f32]);
+test_ptx!(sqrt, [0.25f32], [0.5f32]);
+test_ptx!(rsqrt, [0.25f64], [2f64]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/rsqrt.ptx b/ptx/src/test/spirv_run/rsqrt.ptx
new file mode 100644
index 0000000..5821501
--- /dev/null
+++ b/ptx/src/test/spirv_run/rsqrt.ptx
@@ -0,0 +1,21 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry rsqrt(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .f64 temp1;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.f64 temp1, [in_addr];
+ rsqrt.approx.f64 temp1, temp1;
+ st.f64 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/rsqrt.spvtxt b/ptx/src/test/spirv_run/rsqrt.spvtxt
new file mode 100644
index 0000000..5c3ba97
--- /dev/null
+++ b/ptx/src/test/spirv_run/rsqrt.spvtxt
@@ -0,0 +1,56 @@
+; SPIR-V
+; Version: 1.3
+; Generator: rspirv
+; Bound: 31
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int8
+OpCapability Int16
+OpCapability Int64
+OpCapability Float16
+OpCapability Float64
+; OpCapability FunctionFloatControlINTEL
+; OpExtension "SPV_INTEL_float_controls2"
+%23 = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical64 OpenCL
+OpEntryPoint Kernel %1 "rsqrt"
+OpDecorate %1 FunctionDenormModeINTEL 64 Preserve
+%24 = OpTypeVoid
+%25 = OpTypeInt 64 0
+%26 = OpTypeFunction %24 %25 %25
+%27 = OpTypePointer Function %25
+%28 = OpTypeFloat 64
+%29 = OpTypePointer Function %28
+%30 = OpTypePointer Generic %28
+%1 = OpFunction %24 None %26
+%7 = OpFunctionParameter %25
+%8 = OpFunctionParameter %25
+%21 = OpLabel
+%2 = OpVariable %27 Function
+%3 = OpVariable %27 Function
+%4 = OpVariable %27 Function
+%5 = OpVariable %27 Function
+%6 = OpVariable %29 Function
+OpStore %2 %7
+OpStore %3 %8
+%10 = OpLoad %25 %2
+%9 = OpCopyObject %25 %10
+OpStore %4 %9
+%12 = OpLoad %25 %3
+%11 = OpCopyObject %25 %12
+OpStore %5 %11
+%14 = OpLoad %25 %4
+%19 = OpConvertUToPtr %30 %14
+%13 = OpLoad %28 %19
+OpStore %6 %13
+%16 = OpLoad %28 %6
+%15 = OpExtInst %28 %23 native_rsqrt %16
+OpStore %6 %15
+%17 = OpLoad %25 %5
+%18 = OpLoad %28 %6
+%20 = OpConvertUToPtr %30 %17
+OpStore %20 %18
+OpReturn
+OpFunctionEnd \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/shared_ptr_32.ptx b/ptx/src/test/spirv_run/shared_ptr_32.ptx
new file mode 100644
index 0000000..0334aa0
--- /dev/null
+++ b/ptx/src/test/spirv_run/shared_ptr_32.ptx
@@ -0,0 +1,29 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+
+.visible .entry shared_ptr_32(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .shared .align 4 .b8 shared_mem1[128];
+
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u32 shared_addr;
+
+ .reg .u64 temp1;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+ mov.u32 shared_addr, shared_mem1;
+
+ ld.global.u64 temp1, [in_addr];
+ st.shared.u64 [shared_addr], temp1;
+ ld.shared.u64 temp2, [shared_addr+0];
+ st.global.u64 [out_addr], temp2;
+ ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt
new file mode 100644
index 0000000..609cc0e
--- /dev/null
+++ b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt
@@ -0,0 +1,74 @@
+; SPIR-V
+; Version: 1.3
+; Generator: rspirv
+; Bound: 47
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int8
+OpCapability Int16
+OpCapability Int64
+OpCapability Float16
+OpCapability Float64
+; OpCapability FunctionFloatControlINTEL
+; OpExtension "SPV_INTEL_float_controls2"
+%34 = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical64 OpenCL
+OpEntryPoint Kernel %1 "shared_ptr_32" %4
+OpDecorate %4 Alignment 4
+%35 = OpTypeVoid
+%36 = OpTypeInt 32 0
+%37 = OpTypeInt 8 0
+%38 = OpConstant %36 128
+%39 = OpTypeArray %37 %38
+%40 = OpTypePointer Workgroup %39
+%4 = OpVariable %40 Workgroup
+%41 = OpTypeInt 64 0
+%42 = OpTypeFunction %35 %41 %41
+%43 = OpTypePointer Function %41
+%44 = OpTypePointer Function %36
+%45 = OpTypePointer CrossWorkgroup %41
+%46 = OpTypePointer Workgroup %41
+%25 = OpConstant %36 0
+%1 = OpFunction %35 None %42
+%10 = OpFunctionParameter %41
+%11 = OpFunctionParameter %41
+%32 = OpLabel
+%2 = OpVariable %43 Function
+%3 = OpVariable %43 Function
+%5 = OpVariable %43 Function
+%6 = OpVariable %43 Function
+%7 = OpVariable %44 Function
+%8 = OpVariable %43 Function
+%9 = OpVariable %43 Function
+OpStore %2 %10
+OpStore %3 %11
+%13 = OpLoad %41 %2
+%12 = OpCopyObject %41 %13
+OpStore %5 %12
+%15 = OpLoad %41 %3
+%14 = OpCopyObject %41 %15
+OpStore %6 %14
+%27 = OpConvertPtrToU %36 %4
+%16 = OpCopyObject %36 %27
+OpStore %7 %16
+%18 = OpLoad %41 %5
+%28 = OpConvertUToPtr %45 %18
+%17 = OpLoad %41 %28
+OpStore %8 %17
+%19 = OpLoad %36 %7
+%20 = OpLoad %41 %8
+%29 = OpConvertUToPtr %46 %19
+OpStore %29 %20
+%22 = OpLoad %36 %7
+%26 = OpIAdd %36 %22 %25
+%30 = OpConvertUToPtr %46 %26
+%21 = OpLoad %41 %30
+OpStore %9 %21
+%23 = OpLoad %41 %6
+%24 = OpLoad %41 %9
+%31 = OpConvertUToPtr %45 %23
+OpStore %31 %24
+OpReturn
+OpFunctionEnd \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/sqrt.ptx b/ptx/src/test/spirv_run/sqrt.ptx
new file mode 100644
index 0000000..8b42f34
--- /dev/null
+++ b/ptx/src/test/spirv_run/sqrt.ptx
@@ -0,0 +1,21 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry sqrt(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .f32 temp1;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.f32 temp1, [in_addr];
+ sqrt.approx.f32 temp1, temp1;
+ st.f32 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/sqrt.spvtxt b/ptx/src/test/spirv_run/sqrt.spvtxt
new file mode 100644
index 0000000..d2c5b20
--- /dev/null
+++ b/ptx/src/test/spirv_run/sqrt.spvtxt
@@ -0,0 +1,56 @@
+; SPIR-V
+; Version: 1.3
+; Generator: rspirv
+; Bound: 31
+OpCapability GenericPointer
+OpCapability Linkage
+OpCapability Addresses
+OpCapability Kernel
+OpCapability Int8
+OpCapability Int16
+OpCapability Int64
+OpCapability Float16
+OpCapability Float64
+; OpCapability FunctionFloatControlINTEL
+; OpExtension "SPV_INTEL_float_controls2"
+%23 = OpExtInstImport "OpenCL.std"
+OpMemoryModel Physical64 OpenCL
+OpEntryPoint Kernel %1 "sqrt"
+OpDecorate %1 FunctionDenormModeINTEL 32 Preserve
+%24 = OpTypeVoid
+%25 = OpTypeInt 64 0
+%26 = OpTypeFunction %24 %25 %25
+%27 = OpTypePointer Function %25
+%28 = OpTypeFloat 32
+%29 = OpTypePointer Function %28
+%30 = OpTypePointer Generic %28
+%1 = OpFunction %24 None %26
+%7 = OpFunctionParameter %25
+%8 = OpFunctionParameter %25
+%21 = OpLabel
+%2 = OpVariable %27 Function
+%3 = OpVariable %27 Function
+%4 = OpVariable %27 Function
+%5 = OpVariable %27 Function
+%6 = OpVariable %29 Function
+OpStore %2 %7
+OpStore %3 %8
+%10 = OpLoad %25 %2
+%9 = OpCopyObject %25 %10
+OpStore %4 %9
+%12 = OpLoad %25 %3
+%11 = OpCopyObject %25 %12
+OpStore %5 %11
+%14 = OpLoad %25 %4
+%19 = OpConvertUToPtr %30 %14
+%13 = OpLoad %28 %19
+OpStore %6 %13
+%16 = OpLoad %28 %6
+%15 = OpExtInst %28 %23 native_sqrt %16
+OpStore %6 %15
+%17 = OpLoad %25 %5
+%18 = OpLoad %28 %6
+%20 = OpConvertUToPtr %30 %17
+OpStore %20 %18
+OpReturn
+OpFunctionEnd \ No newline at end of file
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 6b07c0f..c351ccd 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,8 +1,11 @@
use crate::ast;
use half::f16;
use rspirv::{binary::Disassemble, dr};
-use std::collections::{hash_map, HashMap, HashSet};
use std::{borrow::Cow, hash::Hash, iter, mem};
+use std::{
+ collections::{hash_map, HashMap, HashSet},
+ convert::TryInto,
+};
use rspirv::binary::Assemble;
@@ -1499,6 +1502,15 @@ fn convert_to_typed_statements(
ast::Instruction::AtomCas(d, a) => result.push(Statement::Instruction(
ast::Instruction::AtomCas(d, a.cast()),
)),
+ ast::Instruction::Div(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Div(d, a.cast())))
+ }
+ ast::Instruction::Sqrt(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Sqrt(d, a.cast())))
+ }
+ ast::Instruction::Rsqrt(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast())))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -1982,7 +1994,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
| ArgumentSemantics::DefaultRelaxed
| ArgumentSemantics::PhysicalPointer => {
if desc.sema == ArgumentSemantics::PhysicalPointer {
- typ = ast::Type::Scalar(ast::ScalarType::U64);
+ typ = self.id_def.get_typed(reg)?;
}
let (width, kind) = match typ {
ast::Type::Scalar(scalar_t) => {
@@ -2013,7 +2025,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::from_parts(width, kind),
- value: ast::ImmediateValue::S64(-(offset as i64)),
+ value: ast::ImmediateValue::U64(-(offset as i64) as u64),
}));
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Sub(
@@ -2765,6 +2777,34 @@ fn emit_function_body_ops(
arg.src2,
)?;
}
+ ast::Instruction::Div(details, arg) => match details {
+ ast::DivDetails::Unsigned(t) => {
+ let result_type = map.get_or_add_scalar(builder, (*t).into());
+ builder.u_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
+ }
+ ast::DivDetails::Signed(t) => {
+ let result_type = map.get_or_add_scalar(builder, (*t).into());
+ builder.s_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
+ }
+ ast::DivDetails::Float(t) => {
+ let result_type = map.get_or_add_scalar(builder, t.typ.into());
+ builder.f_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
+ emit_float_div_decoration(builder, arg.dst, t.kind);
+ }
+ },
+ ast::Instruction::Sqrt(details, a) => {
+ emit_sqrt(builder, map, opencl, details, a)?;
+ }
+ ast::Instruction::Rsqrt(details, a) => {
+ let result_type = map.get_or_add_scalar(builder, details.typ.into());
+ builder.ext_inst(
+ result_type,
+ Some(a.dst),
+ opencl,
+ spirv::CLOp::native_rsqrt as spirv::Word,
+ &[a.src],
+ )?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@@ -2795,6 +2835,47 @@ fn emit_function_body_ops(
Ok(())
}
+fn emit_sqrt(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ details: &ast::SqrtDetails,
+ a: &ast::Arg2<ExpandedArgParams>,
+) -> Result<(), TranslateError> {
+ let result_type = map.get_or_add_scalar(builder, details.typ.into());
+ let (ocl_op, rounding) = match details.kind {
+ ast::SqrtKind::Approx => (spirv::CLOp::native_sqrt, None),
+ ast::SqrtKind::Rounding(rnd) => (spirv::CLOp::sqrt, Some(rnd)),
+ };
+ builder.ext_inst(
+ result_type,
+ Some(a.dst),
+ opencl,
+ ocl_op as spirv::Word,
+ &[a.src],
+ )?;
+ emit_rounding_decoration(builder, a.dst, rounding);
+ Ok(())
+}
+
+fn emit_float_div_decoration(builder: &mut dr::Builder, dst: spirv::Word, kind: ast::DivFloatKind) {
+ match kind {
+ ast::DivFloatKind::Approx => {
+ builder.decorate(
+ dst,
+ spirv::Decoration::FPFastMathMode,
+ &[dr::Operand::FPFastMathMode(
+ spirv::FPFastMathMode::ALLOW_RECIP,
+ )],
+ );
+ }
+ ast::DivFloatKind::Rounding(rnd) => {
+ emit_rounding_decoration(builder, dst, Some(rnd));
+ }
+ ast::DivFloatKind::Full => {}
+ }
+}
+
fn emit_atom(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -3307,7 +3388,25 @@ fn emit_setp(
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
- _ => todo!(),
+ (ast::SetpCompareOp::NanEq, _) => {
+ builder.f_unord_equal(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanNotEq, _) => {
+ builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanLess, _) => {
+ builder.f_unord_less_than(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanLessOrEq, _) => {
+ builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanGreater, _) => {
+ builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanGreaterOrEq, _) => {
+ builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2)
+ }
+ _ => todo!()
}?;
Ok(())
}
@@ -3486,8 +3585,8 @@ fn emit_implicit_conversion(
let from_parts = cv.from.to_parts();
let to_parts = cv.to.to_parts();
match (from_parts.kind, to_parts.kind, cv.kind) {
- (_, _, ConversionKind::PtrToBit) => {
- let dst_type = map.get_or_add_scalar(builder, ast::ScalarType::B64);
+ (_, _, ConversionKind::PtrToBit(typ)) => {
+ let dst_type = map.get_or_add_scalar(builder, typ.into());
builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
}
(_, _, ConversionKind::BitToPtr(_)) => {
@@ -4570,6 +4669,15 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::AtomCas(d, a) => {
ast::Instruction::AtomCas(d, a.map_atom(visitor, d.typ, d.space)?)
}
+ ast::Instruction::Div(d, a) => {
+ ast::Instruction::Div(d, a.map_non_shift(visitor, &d.get_type(), false)?)
+ }
+ ast::Instruction::Sqrt(d, a) => {
+ ast::Instruction::Sqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
+ }
+ ast::Instruction::Rsqrt(d, a) => {
+ ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
+ }
})
}
}
@@ -4794,32 +4902,7 @@ impl ast::Instruction<ExpandedArgParams> {
fn jump_target(&self) -> Option<spirv::Word> {
match self {
ast::Instruction::Bra(_, a) => Some(a.src),
- ast::Instruction::Ld(_, _)
- | ast::Instruction::Mov(_, _)
- | ast::Instruction::Mul(_, _)
- | ast::Instruction::Add(_, _)
- | ast::Instruction::Setp(_, _)
- | ast::Instruction::SetpBool(_, _)
- | ast::Instruction::Not(_, _)
- | ast::Instruction::Cvt(_, _)
- | ast::Instruction::Cvta(_, _)
- | ast::Instruction::Shl(_, _)
- | ast::Instruction::Shr(_, _)
- | ast::Instruction::St(_, _)
- | ast::Instruction::Ret(_)
- | ast::Instruction::Abs(_, _)
- | ast::Instruction::Call(_)
- | ast::Instruction::Or(_, _)
- | ast::Instruction::Sub(_, _)
- | ast::Instruction::Min(_, _)
- | ast::Instruction::Max(_, _)
- | ast::Instruction::Rcp(_, _)
- | ast::Instruction::And(_, _)
- | ast::Instruction::Selp(_, _)
- | ast::Instruction::Bar(_, _)
- | ast::Instruction::Atom(_, _)
- | ast::Instruction::AtomCas(_, _)
- | ast::Instruction::Mad(_, _) => None,
+ _ => None,
}
}
@@ -4856,6 +4939,9 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Max(ast::MinMaxDetails::Signed(_), _) => None,
ast::Instruction::Max(ast::MinMaxDetails::Unsigned(_), _) => None,
ast::Instruction::Cvt(ast::CvtDetails::IntFromInt(_), _) => None,
+ ast::Instruction::Cvt(ast::CvtDetails::FloatFromInt(_), _) => None,
+ ast::Instruction::Div(ast::DivDetails::Unsigned(_), _) => None,
+ ast::Instruction::Div(ast::DivDetails::Signed(_), _) => None,
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
@@ -4885,13 +4971,19 @@ impl ast::Instruction<ExpandedArgParams> {
_,
)
| ast::Instruction::Cvt(
- ast::CvtDetails::FloatFromInt(ast::CvtDesc { flush_to_zero, .. }),
- _,
- )
- | ast::Instruction::Cvt(
ast::CvtDetails::IntFromFloat(ast::CvtDesc { flush_to_zero, .. }),
_,
) => flush_to_zero.map(|ftz| (ftz, 4)),
+ ast::Instruction::Div(ast::DivDetails::Float(details), _) => details
+ .flush_to_zero
+ .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())),
+ ast::Instruction::Sqrt(details, _) => details
+ .flush_to_zero
+ .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())),
+ ast::Instruction::Rsqrt(details, _) => Some((
+ details.flush_to_zero,
+ ast::ScalarType::from(details.typ).size_of(),
+ )),
}
}
}
@@ -4978,13 +5070,13 @@ struct ImplicitConversion {
kind: ConversionKind,
}
-#[derive(Debug, PartialEq, Copy, Clone)]
+#[derive(PartialEq, Copy, Clone)]
enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
SignExtend,
BitToPtr(ast::LdStateSpace),
- PtrToBit,
+ PtrToBit(ast::UIntType),
PtrToPtr { spirv_ptr: bool },
}
@@ -6027,6 +6119,16 @@ impl ast::MinMaxDetails {
}
}
+impl ast::DivDetails {
+ fn get_type(&self) -> ast::Type {
+ ast::Type::Scalar(match self {
+ ast::DivDetails::Unsigned(t) => (*t).into(),
+ ast::DivDetails::Signed(t) => (*t).into(),
+ ast::DivDetails::Float(d) => d.typ.into(),
+ })
+ }
+}
+
impl ast::AtomInnerDetails {
fn get_type(&self) -> ast::ScalarType {
match self {
@@ -6193,6 +6295,15 @@ fn bitcast_physical_pointer(
Err(TranslateError::Unreachable)
}
}
+ ast::Type::Scalar(ast::ScalarType::B32)
+ | ast::Type::Scalar(ast::ScalarType::U32)
+ | ast::Type::Scalar(ast::ScalarType::S32) => {
+ if let Some(ast::LdStateSpace::Shared) = ss {
+ Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared)))
+ } else {
+ Err(TranslateError::MismatchedType)
+ }
+ }
ast::Type::Pointer(op_scalar_t, op_space) => {
if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
if op_space == instr_space {
@@ -6220,10 +6331,16 @@ fn bitcast_physical_pointer(
fn force_bitcast_ptr_to_bit(
_: &ast::Type,
- _: &ast::Type,
+ instr_type: &ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
- Ok(Some(ConversionKind::PtrToBit))
+ // TODO: verify this on f32, u16 and the like
+ if let ast::Type::Scalar(scalar_t) = instr_type {
+ if let Ok(int_type) = (*scalar_t).try_into() {
+ return Ok(Some(ConversionKind::PtrToBit(int_type)));
+ }
+ }
+ Err(TranslateError::MismatchedType)
}
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
@@ -6542,9 +6659,9 @@ mod tests {
&ast::Type::Scalar(*instr_type),
);
if instr_idx == op_idx {
- assert_eq!(conversion, None);
+ assert!(conversion == None);
} else {
- assert_eq!(conversion, conv_table[instr_idx][op_idx]);
+ assert!(conversion == conv_table[instr_idx][op_idx]);
}
}
}