summaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
Diffstat (limited to 'ptx')
-rw-r--r--ptx/src/ast.rs7
-rw-r--r--ptx/src/ptx.lalrpop29
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/test/spirv_run/rcp.ptx21
-rw-r--r--ptx/src/test/spirv_run/rcp.spvtxt51
-rw-r--r--ptx/src/translate.rs64
6 files changed, 162 insertions, 11 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 1cbe721..f7cdcc3 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -510,6 +510,7 @@ pub enum Instruction<P: ArgParams> {
Sub(ArithDetails, Arg3<P>),
Min(MinMaxDetails, Arg3<P>),
Max(MinMaxDetails, Arg3<P>),
+ Rcp(RcpDetails, Arg2<P>),
}
#[derive(Copy, Clone)]
@@ -520,6 +521,12 @@ pub struct AbsDetails {
pub flush_to_zero: bool,
pub typ: ScalarType,
}
+#[derive(Copy, Clone)]
+pub struct RcpDetails {
+ pub rounding: Option<RoundingMode>,
+ pub flush_to_zero: bool,
+ pub is_f64: bool,
+}
pub struct CallInst<P: ArgParams> {
pub uniform: bool,
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index c29d16b..a132705 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -35,6 +35,7 @@ match {
".address_size",
".align",
".and",
+ ".approx",
".b16",
".b32",
".b64",
@@ -134,6 +135,7 @@ match {
"mul",
"not",
"or",
+ "rcp",
"ret",
"setp",
"shl",
@@ -166,6 +168,7 @@ ExtendedID : &'input str = {
"mul",
"not",
"or",
+ "rcp",
"ret",
"setp",
"shl",
@@ -542,6 +545,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstSub,
InstMin,
InstMax,
+ InstRcp
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@@ -1119,6 +1123,31 @@ OrType: ast::OrType = {
".b64" => ast::OrType::B64,
}
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp
+InstRcp: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "rcp" <rounding:RcpRoundingMode> <ftz:".ftz"?> ".f32" <a:Arg2> => {
+ let details = ast::RcpDetails {
+ rounding,
+ flush_to_zero: ftz.is_some(),
+ is_f64: false,
+ };
+ ast::Instruction::Rcp(details, a)
+ },
+ "rcp" <rn:RoundingModeFloat> ".f64" <a:Arg2> => {
+ let details = ast::RcpDetails {
+ rounding: Some(rn),
+ flush_to_zero: false,
+ is_f64: true,
+ };
+ ast::Instruction::Rcp(details, a)
+ }
+};
+
+RcpRoundingMode: Option<ast::RoundingMode> = {
+ ".approx" => None,
+ <r:RoundingModeFloat> => Some(r)
+};
+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 3a8acb1..b4ae149 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -80,6 +80,7 @@ test_ptx!(max, [555i32, 444i32], [555i32]);
test_ptx!(global_array, [0xDEADu32], [1u32]);
test_ptx!(extern_shared, [127u64], [127u64]);
test_ptx!(extern_shared_call, [121u64], [123u64]);
+test_ptx!(rcp, [2f32], [0.5f32]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/rcp.ptx b/ptx/src/test/spirv_run/rcp.ptx
new file mode 100644
index 0000000..eb02d7e
--- /dev/null
+++ b/ptx/src/test/spirv_run/rcp.ptx
@@ -0,0 +1,21 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry rcp(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .f32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.f32 temp, [in_addr];
+ rcp.approx.f32 temp, temp;
+ st.f32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/rcp.spvtxt b/ptx/src/test/spirv_run/rcp.spvtxt
new file mode 100644
index 0000000..08b3e6e
--- /dev/null
+++ b/ptx/src/test/spirv_run/rcp.spvtxt
@@ -0,0 +1,51 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %23 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "rcp"
+ OpDecorate %15 FPFastMathMode AllowRecip
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %26 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_Generic_float = OpTypePointer Generic %float
+ %float_1 = OpConstant %float 1
+ %1 = OpFunction %void None %26
+ %7 = OpFunctionParameter %ulong
+ %8 = OpFunctionParameter %ulong
+ %21 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_float Function
+ OpStore %2 %7
+ OpStore %3 %8
+ %10 = OpLoad %ulong %2
+ %9 = OpCopyObject %ulong %10
+ OpStore %4 %9
+ %12 = OpLoad %ulong %3
+ %11 = OpCopyObject %ulong %12
+ OpStore %5 %11
+ %14 = OpLoad %ulong %4
+ %19 = OpConvertUToPtr %_ptr_Generic_float %14
+ %13 = OpLoad %float %19
+ OpStore %6 %13
+ %16 = OpLoad %float %6
+ %15 = OpFDiv %float %float_1 %16
+ OpStore %6 %15
+ %17 = OpLoad %ulong %5
+ %18 = OpLoad %float %6
+ %20 = OpConvertUToPtr %_ptr_Generic_float %17
+ OpStore %20 %18
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index ab7187f..cccf6ad 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1144,6 +1144,9 @@ fn convert_to_typed_statements(
ast::Instruction::Max(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast())))
}
+ ast::Instruction::Rcp(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Rcp(d, a.cast())))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -2179,6 +2182,9 @@ fn emit_function_body_ops(
ast::Instruction::Max(d, a) => {
emit_max(builder, map, opencl, d, a)?;
}
+ ast::Instruction::Rcp(d, a) => {
+ emit_rcp(builder, map, d, a)?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@@ -2209,6 +2215,40 @@ fn emit_function_body_ops(
Ok(())
}
+fn emit_rcp(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ desc: &ast::RcpDetails,
+ a: &ast::Arg2<ExpandedArgParams>,
+) -> Result<(), TranslateError> {
+ if desc.flush_to_zero {
+ todo!()
+ }
+ let (instr_type, constant) = if desc.is_f64 {
+ (ast::ScalarType::F64, vec_repr(1.0f64))
+ } else {
+ (ast::ScalarType::F32, vec_repr(1.0f32))
+ };
+ let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?;
+ let result_type = map.get_or_add_scalar(builder, instr_type);
+ builder.f_div(result_type, Some(a.dst), one, a.src)?;
+ emit_rounding_decoration(builder, a.dst, desc.rounding);
+ builder.decorate(
+ a.dst,
+ spirv::Decoration::FPFastMathMode,
+ &[dr::Operand::FPFastMathMode(
+ spirv::FPFastMathMode::ALLOW_RECIP,
+ )],
+ );
+ Ok(())
+}
+
+fn vec_repr<T: Copy>(t: T) -> Vec<u8> {
+ let mut result = vec![0; mem::size_of::<T>()];
+ unsafe { std::ptr::copy_nonoverlapping(&t, result.as_mut_ptr() as *mut _, 1) };
+ result
+}
+
fn emit_variable(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -3735,7 +3775,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
) -> Result<ast::Instruction<U>, TranslateError> {
Ok(match self {
ast::Instruction::Abs(d, arg) => {
- ast::Instruction::Abs(d, arg.map(visitor, false, &ast::Type::Scalar(d.typ))?)
+ ast::Instruction::Abs(d, arg.map(visitor, &ast::Type::Scalar(d.typ))?)
}
// Call instruction is converted to a call statement early on
ast::Instruction::Call(_) => return Err(TranslateError::Unreachable),
@@ -3766,9 +3806,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
let inst_type = d.typ;
ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
}
- ast::Instruction::Not(t, a) => {
- ast::Instruction::Not(t, a.map(visitor, false, &t.to_type())?)
- }
+ ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?),
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (
@@ -3806,7 +3844,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
ast::Instruction::Cvta(d, a) => {
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
- ast::Instruction::Cvta(d, a.map(visitor, false, &inst_type)?)
+ ast::Instruction::Cvta(d, a.map(visitor, &inst_type)?)
}
ast::Instruction::Mad(d, a) => {
let inst_type = d.get_type();
@@ -3829,6 +3867,14 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
let typ = d.get_type();
ast::Instruction::Max(d, a.map_non_shift(visitor, &typ, false)?)
}
+ ast::Instruction::Rcp(d, a) => {
+ let typ = ast::Type::Scalar(if d.is_f64 {
+ ast::ScalarType::F64
+ } else {
+ ast::ScalarType::F32
+ });
+ ast::Instruction::Rcp(d, a.map(visitor, &typ)?)
+ }
})
}
}
@@ -4072,6 +4118,7 @@ impl ast::Instruction<ExpandedArgParams> {
| ast::Instruction::Sub(_, _)
| ast::Instruction::Min(_, _)
| ast::Instruction::Max(_, _)
+ | ast::Instruction::Rcp(_, _)
| ast::Instruction::Mad(_, _) => None,
}
}
@@ -4289,7 +4336,6 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- src_is_addr: bool,
t: &ast::Type,
) -> Result<ast::Arg2<U>, TranslateError> {
let new_dst = visitor.id(
@@ -4304,11 +4350,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: if src_is_addr {
- ArgumentSemantics::Address
- } else {
- ArgumentSemantics::Default
- },
+ sema: ArgumentSemantics::Default,
},
t,
)?;