about summary refs log tree commit diff homepage
path: root/ptx/src/translate.rs
diff options
context:
space:
mode:
author Andrzej Janik <[email protected]> 2021-07-25 15:19:43 +0200
committer Andrzej Janik <[email protected]> 2021-07-25 15:19:43 +0200
commit 8f68287b18afb1510ab055f0317a3f0dacce5d32 (patch)
tree 991e5b0c7f008b31cc1a83e2d0573894fd0b16a5 /ptx/src/translate.rs
parent 9d4f26bd07f97e59da5556611490242a6830312a (diff)
download ZLUDA-8f68287b18afb1510ab055f0317a3f0dacce5d32.tar.gz
ZLUDA-8f68287b18afb1510ab055f0317a3f0dacce5d32.zip
Tune generated code, add a workaround for geekbench
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r-- ptx/src/translate.rs 101
1 files changed, 73 insertions, 28 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index c236438..91e4237 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -559,25 +559,29 @@ fn emit_directives<'input>(
&directives,
kernel_info,
)?;
- for t in f.tuning.iter() {
- match *t {
- ast::TuningDirective::MaxNtid(nx, ny, nz) => {
- builder.execution_mode(
- fn_id,
- spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL,
- [nx, ny, nz],
- );
- }
- ast::TuningDirective::ReqNtid(nx, ny, nz) => {
- builder.execution_mode(
- fn_id,
- spirv_headers::ExecutionMode::LocalSize,
- [nx, ny, nz],
- );
+ if func_decl.name.is_kernel() {
+ // FP contraction happens when compiling source -> PTX and is illegal at this stage (unless you force it in cuModuleLoadDataEx)
+ builder.execution_mode(fn_id, spirv_headers::ExecutionMode::ContractionOff, []);
+ for t in f.tuning.iter() {
+ match *t {
+ ast::TuningDirective::MaxNtid(nx, ny, nz) => {
+ builder.execution_mode(
+ fn_id,
+ spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL,
+ [nx, ny, nz],
+ );
+ }
+ ast::TuningDirective::ReqNtid(nx, ny, nz) => {
+ builder.execution_mode(
+ fn_id,
+ spirv_headers::ExecutionMode::LocalSize,
+ [nx, ny, nz],
+ );
+ }
+ // Too architecture specific
+ ast::TuningDirective::MaxNReg(..)
+ | ast::TuningDirective::MinNCtaPerSm(..) => {}
}
- // Too architecture specific
- ast::TuningDirective::MaxNReg(..)
- | ast::TuningDirective::MinNCtaPerSm(..) => {}
}
}
emit_function_body_ops(builder, map, opencl_id, &f_body)?;
@@ -2772,6 +2776,7 @@ fn emit_function_body_ops(
emit_mad_float(builder, map, opencl, desc, arg)?
}
},
+ ast::Instruction::Fma(fma, arg) => emit_fma_float(builder, map, opencl, fma, arg)?,
ast::Instruction::Or(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
if *t == ast::ScalarType::Pred {
@@ -2798,7 +2803,7 @@ fn emit_function_body_ops(
emit_max(builder, map, opencl, d, a)?;
}
ast::Instruction::Rcp(d, a) => {
- emit_rcp(builder, map, d, a)?;
+ emit_rcp(builder, map, opencl, d, a)?;
}
ast::Instruction::And(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
@@ -2901,7 +2906,7 @@ fn emit_function_body_ops(
result_type,
Some(arg.dst),
opencl,
- spirv::CLOp::sin as u32,
+ spirv::CLOp::native_sin as u32,
[dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
@@ -2911,7 +2916,7 @@ fn emit_function_body_ops(
result_type,
Some(arg.dst),
opencl,
- spirv::CLOp::cos as u32,
+ spirv::CLOp::native_cos as u32,
[dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
@@ -2921,7 +2926,7 @@ fn emit_function_body_ops(
result_type,
Some(arg.dst),
opencl,
- spirv::CLOp::log2 as u32,
+ spirv::CLOp::native_log2 as u32,
[dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
@@ -2931,7 +2936,7 @@ fn emit_function_body_ops(
result_type,
Some(arg.dst),
opencl,
- spirv::CLOp::exp2 as u32,
+ spirv::CLOp::native_exp2 as u32,
[dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
@@ -3237,20 +3242,31 @@ fn emit_mul_float(
fn emit_rcp(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
+ opencl: spirv::Word,
desc: &ast::RcpDetails,
- a: &ast::Arg2<ExpandedArgParams>,
+ arg: &ast::Arg2<ExpandedArgParams>,
) -> Result<(), TranslateError> {
let (instr_type, constant) = if desc.is_f64 {
(ast::ScalarType::F64, vec_repr(1.0f64))
} else {
(ast::ScalarType::F32, vec_repr(1.0f32))
};
- let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?;
let result_type = map.get_or_add_scalar(builder, instr_type);
- builder.f_div(result_type, Some(a.dst), one, a.src)?;
- emit_rounding_decoration(builder, a.dst, desc.rounding);
+ if !desc.is_f64 && desc.rounding.is_none() {
+ builder.ext_inst(
+ result_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::native_recip as u32,
+ [dr::Operand::IdRef(arg.src)].iter().cloned(),
+ )?;
+ return Ok(());
+ }
+ let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?;
+ builder.f_div(result_type, Some(arg.dst), one, arg.src)?;
+ emit_rounding_decoration(builder, arg.dst, desc.rounding);
builder.decorate(
- a.dst,
+ arg.dst,
spirv::Decoration::FPFastMathMode,
[dr::Operand::FPFastMathMode(
spirv::FPFastMathMode::ALLOW_RECIP,
@@ -3372,6 +3388,30 @@ fn emit_mad_sint(
Ok(())
}
+fn emit_fma_float(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::ArithFloat,
+ arg: &ast::Arg4<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::fma as spirv::Word,
+ [
+ dr::Operand::IdRef(arg.src1),
+ dr::Operand::IdRef(arg.src2),
+ dr::Operand::IdRef(arg.src3),
+ ]
+ .iter()
+ .cloned(),
+ )?;
+ Ok(())
+}
+
fn emit_mad_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -5713,6 +5753,10 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
let is_wide = d.is_wide();
ast::Instruction::Mad(d, a.map(visitor, &inst_type, is_wide)?)
}
+ ast::Instruction::Fma(d, a) => {
+ let inst_type = ast::Type::Scalar(d.typ);
+ ast::Instruction::Fma(d, a.map(visitor, &inst_type, false)?)
+ }
ast::Instruction::Or(t, a) => ast::Instruction::Or(
t,
a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?,
@@ -6106,6 +6150,7 @@ impl ast::Instruction<ExpandedArgParams> {
| ast::Instruction::Mad(ast::MulDetails::Float(float_control), _) => float_control
.flush_to_zero
.map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())),
+ ast::Instruction::Fma(d, _) => d.flush_to_zero.map(|ftz| (ftz, d.typ.size_of())),
ast::Instruction::Setp(details, _) => details
.flush_to_zero
.map(|ftz| (ftz, details.typ.size_of())),