author     Andrzej Janik <[email protected]>   2021-07-25 15:19:43 +0200
committer  Andrzej Janik <[email protected]>   2021-07-25 15:19:43 +0200
commit     8f68287b18afb1510ab055f0317a3f0dacce5d32 (patch)
tree       991e5b0c7f008b31cc1a83e2d0573894fd0b16a5
parent     9d4f26bd07f97e59da5556611490242a6830312a (diff)
Tune generated code, add a workaround for geekbench
-rw-r--r--  ptx/src/ast.rs                        1
-rw-r--r--  ptx/src/ptx.lalrpop                   7
-rw-r--r--  ptx/src/test/spirv_run/cos.spvtxt     2
-rw-r--r--  ptx/src/test/spirv_run/ex2.spvtxt     2
-rw-r--r--  ptx/src/test/spirv_run/fma.spvtxt     2
-rw-r--r--  ptx/src/test/spirv_run/lg2.spvtxt     2
-rw-r--r--  ptx/src/test/spirv_run/rcp.spvtxt     7
-rw-r--r--  ptx/src/test/spirv_run/sin.spvtxt     2
-rw-r--r--  ptx/src/translate.rs                101
-rw-r--r--  zluda/src/impl/device.rs              2
-rw-r--r--  zluda/src/impl/function.rs           81
-rw-r--r--  zluda/src/impl/memory.rs             26
-rw-r--r--  zluda_dump/src/debug.ptx             55
-rw-r--r--  zluda_dump/src/replay.py              2
14 files changed, 212 insertions(+), 80 deletions(-)
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 5432207..36e7191 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -261,6 +261,7 @@ pub enum Instruction<P: ArgParams> {
Call(CallInst<P>),
Abs(AbsDetails, Arg2<P>),
Mad(MulDetails, Arg4<P>),
+ Fma(ArithFloat, Arg4<P>),
Or(ScalarType, Arg3<P>),
Sub(ArithDetails, Arg3<P>),
Min(MinMaxDetails, Arg3<P>),
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 18ec4fb..b20a30a 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -743,6 +743,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstCall,
InstAbs,
InstMad,
+ InstFma,
InstOr,
InstAnd,
InstSub,
@@ -1345,7 +1346,11 @@ InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mad" <d:MulDetails> <a:Arg4> => ast::Instruction::Mad(d, a),
"mad" ".hi" ".sat" ".s32" => todo!(),
- "fma" <f:ArithFloatMustRound> <a:Arg4> => ast::Instruction::Mad(ast::MulDetails::Float(f), a),
+};
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-fma
+InstFma: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "fma" <f:ArithFloatMustRound> <a:Arg4> => ast::Instruction::Fma(f, a),
};
SignedIntType: ast::ScalarType = {
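
Note on the grammar change above: fma previously reused the Mad AST node, but the two are not interchangeable. PTX fma requires a fused multiply-add with a single rounding, while mad.f32 leaves the precision implementation-defined; this commit therefore gives fma its own Instruction::Fma variant, and the fma.spvtxt test below now expects the OpenCL fma extended instruction instead of mad. A standalone Rust sketch (not ZLUDA code) of why the single rounding matters:

// Why fma and mad must not be conflated: a fused multiply-add rounds once,
// a separate multiply followed by an add rounds twice.
fn main() {
    let x = 1.0f64 + 2f64.powi(-30);
    let xx = x * x;                  // rounded product: the 2^-60 term is lost
    let err = x.mul_add(x, -xx);     // fused: recovers that rounding error exactly
    assert_eq!(err, 2f64.powi(-60));
    println!("rounding error of x*x: {err:e}");
}
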
diff --git a/ptx/src/test/spirv_run/cos.spvtxt b/ptx/src/test/spirv_run/cos.spvtxt
index 6fafcb5..8d6a0ca 100644
--- a/ptx/src/test/spirv_run/cos.spvtxt
+++ b/ptx/src/test/spirv_run/cos.spvtxt
@@ -37,7 +37,7 @@
%11 = OpLoad %float %17 Aligned 4
OpStore %6 %11
%14 = OpLoad %float %6
- %13 = OpExtInst %float %21 cos %14
+ %13 = OpExtInst %float %21 native_cos %14
OpStore %6 %13
%15 = OpLoad %ulong %5
%16 = OpLoad %float %6
diff --git a/ptx/src/test/spirv_run/ex2.spvtxt b/ptx/src/test/spirv_run/ex2.spvtxt
index 62c44b8..3d7b58d 100644
--- a/ptx/src/test/spirv_run/ex2.spvtxt
+++ b/ptx/src/test/spirv_run/ex2.spvtxt
@@ -37,7 +37,7 @@
%11 = OpLoad %float %17 Aligned 4
OpStore %6 %11
%14 = OpLoad %float %6
- %13 = OpExtInst %float %21 exp2 %14
+ %13 = OpExtInst %float %21 native_exp2 %14
OpStore %6 %13
%15 = OpLoad %ulong %5
%16 = OpLoad %float %6
diff --git a/ptx/src/test/spirv_run/fma.spvtxt b/ptx/src/test/spirv_run/fma.spvtxt
index 8cc0e16..91a2159 100644
--- a/ptx/src/test/spirv_run/fma.spvtxt
+++ b/ptx/src/test/spirv_run/fma.spvtxt
@@ -59,7 +59,7 @@
%20 = OpLoad %float %6
%21 = OpLoad %float %7
%22 = OpLoad %float %8
- %19 = OpExtInst %float %35 mad %20 %21 %22
+ %19 = OpExtInst %float %35 fma %20 %21 %22
OpStore %6 %19
%23 = OpLoad %ulong %5
%24 = OpLoad %float %6
diff --git a/ptx/src/test/spirv_run/lg2.spvtxt b/ptx/src/test/spirv_run/lg2.spvtxt
index 3c7ca77..c30eeff 100644
--- a/ptx/src/test/spirv_run/lg2.spvtxt
+++ b/ptx/src/test/spirv_run/lg2.spvtxt
@@ -37,7 +37,7 @@
%11 = OpLoad %float %17 Aligned 4
OpStore %6 %11
%14 = OpLoad %float %6
- %13 = OpExtInst %float %21 log2 %14
+ %13 = OpExtInst %float %21 native_log2 %14
OpStore %6 %13
%15 = OpLoad %ulong %5
%16 = OpLoad %float %6
diff --git a/ptx/src/test/spirv_run/rcp.spvtxt b/ptx/src/test/spirv_run/rcp.spvtxt
index 2d56ee8..09fa0d9 100644
--- a/ptx/src/test/spirv_run/rcp.spvtxt
+++ b/ptx/src/test/spirv_run/rcp.spvtxt
@@ -10,7 +10,7 @@
%21 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "rcp"
- OpDecorate %13 FPFastMathMode AllowRecip
+ OpExecutionMode %1 ContractionOff
%void = OpTypeVoid
%ulong = OpTypeInt 64 0
%24 = OpTypeFunction %void %ulong %ulong
@@ -18,7 +18,6 @@
%float = OpTypeFloat 32
%_ptr_Function_float = OpTypePointer Function %float
%_ptr_Generic_float = OpTypePointer Generic %float
- %float_1 = OpConstant %float 1
%1 = OpFunction %void None %24
%7 = OpFunctionParameter %ulong
%8 = OpFunctionParameter %ulong
@@ -39,11 +38,11 @@
%11 = OpLoad %float %17 Aligned 4
OpStore %6 %11
%14 = OpLoad %float %6
- %13 = OpFDiv %float %float_1 %14
+ %13 = OpExtInst %float %21 native_recip %14
OpStore %6 %13
%15 = OpLoad %ulong %5
%16 = OpLoad %float %6
%18 = OpConvertUToPtr %_ptr_Generic_float %15
OpStore %18 %16 Aligned 4
OpReturn
- OpFunctionEnd
+          OpFunctionEnd
\ No newline at end of file
diff --git a/ptx/src/test/spirv_run/sin.spvtxt b/ptx/src/test/spirv_run/sin.spvtxt
index 618d5f2..02eba40 100644
--- a/ptx/src/test/spirv_run/sin.spvtxt
+++ b/ptx/src/test/spirv_run/sin.spvtxt
@@ -37,7 +37,7 @@
%11 = OpLoad %float %17 Aligned 4
OpStore %6 %11
%14 = OpLoad %float %6
- %13 = OpExtInst %float %21 sin %14
+ %13 = OpExtInst %float %21 native_sin %14
OpStore %6 %13
%15 = OpLoad %ulong %5
%16 = OpLoad %float %6
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index c236438..91e4237 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -559,25 +559,29 @@ fn emit_directives<'input>(
&directives,
kernel_info,
)?;
- for t in f.tuning.iter() {
- match *t {
- ast::TuningDirective::MaxNtid(nx, ny, nz) => {
- builder.execution_mode(
- fn_id,
- spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL,
- [nx, ny, nz],
- );
- }
- ast::TuningDirective::ReqNtid(nx, ny, nz) => {
- builder.execution_mode(
- fn_id,
- spirv_headers::ExecutionMode::LocalSize,
- [nx, ny, nz],
- );
+ if func_decl.name.is_kernel() {
+ // FP contraction happens when compiling source -> PTX and is illegal at this stage (unless you force it in cuModuleLoadDataEx)
+ builder.execution_mode(fn_id, spirv_headers::ExecutionMode::ContractionOff, []);
+ for t in f.tuning.iter() {
+ match *t {
+ ast::TuningDirective::MaxNtid(nx, ny, nz) => {
+ builder.execution_mode(
+ fn_id,
+ spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL,
+ [nx, ny, nz],
+ );
+ }
+ ast::TuningDirective::ReqNtid(nx, ny, nz) => {
+ builder.execution_mode(
+ fn_id,
+ spirv_headers::ExecutionMode::LocalSize,
+ [nx, ny, nz],
+ );
+ }
+ // Too architecture specific
+ ast::TuningDirective::MaxNReg(..)
+ | ast::TuningDirective::MinNCtaPerSm(..) => {}
}
- // Too architecture specific
- ast::TuningDirective::MaxNReg(..)
- | ast::TuningDirective::MinNCtaPerSm(..) => {}
}
}
emit_function_body_ops(builder, map, opencl_id, &f_body)?;
@@ -2772,6 +2776,7 @@ fn emit_function_body_ops(
emit_mad_float(builder, map, opencl, desc, arg)?
}
},
+ ast::Instruction::Fma(fma, arg) => emit_fma_float(builder, map, opencl, fma, arg)?,
ast::Instruction::Or(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
if *t == ast::ScalarType::Pred {
@@ -2798,7 +2803,7 @@ fn emit_function_body_ops(
emit_max(builder, map, opencl, d, a)?;
}
ast::Instruction::Rcp(d, a) => {
- emit_rcp(builder, map, d, a)?;
+ emit_rcp(builder, map, opencl, d, a)?;
}
ast::Instruction::And(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
@@ -2901,7 +2906,7 @@ fn emit_function_body_ops(
result_type,
Some(arg.dst),
opencl,
- spirv::CLOp::sin as u32,
+ spirv::CLOp::native_sin as u32,
[dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
@@ -2911,7 +2916,7 @@ fn emit_function_body_ops(
result_type,
Some(arg.dst),
opencl,
- spirv::CLOp::cos as u32,
+ spirv::CLOp::native_cos as u32,
[dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
@@ -2921,7 +2926,7 @@ fn emit_function_body_ops(
result_type,
Some(arg.dst),
opencl,
- spirv::CLOp::log2 as u32,
+ spirv::CLOp::native_log2 as u32,
[dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
@@ -2931,7 +2936,7 @@ fn emit_function_body_ops(
result_type,
Some(arg.dst),
opencl,
- spirv::CLOp::exp2 as u32,
+ spirv::CLOp::native_exp2 as u32,
[dr::Operand::IdRef(arg.src)].iter().cloned(),
)?;
}
@@ -3237,20 +3242,31 @@ fn emit_mul_float(
fn emit_rcp(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
+ opencl: spirv::Word,
desc: &ast::RcpDetails,
- a: &ast::Arg2<ExpandedArgParams>,
+ arg: &ast::Arg2<ExpandedArgParams>,
) -> Result<(), TranslateError> {
let (instr_type, constant) = if desc.is_f64 {
(ast::ScalarType::F64, vec_repr(1.0f64))
} else {
(ast::ScalarType::F32, vec_repr(1.0f32))
};
- let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?;
let result_type = map.get_or_add_scalar(builder, instr_type);
- builder.f_div(result_type, Some(a.dst), one, a.src)?;
- emit_rounding_decoration(builder, a.dst, desc.rounding);
+ if !desc.is_f64 && desc.rounding.is_none() {
+ builder.ext_inst(
+ result_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::native_recip as u32,
+ [dr::Operand::IdRef(arg.src)].iter().cloned(),
+ )?;
+ return Ok(());
+ }
+ let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?;
+ builder.f_div(result_type, Some(arg.dst), one, arg.src)?;
+ emit_rounding_decoration(builder, arg.dst, desc.rounding);
builder.decorate(
- a.dst,
+ arg.dst,
spirv::Decoration::FPFastMathMode,
[dr::Operand::FPFastMathMode(
spirv::FPFastMathMode::ALLOW_RECIP,
@@ -3372,6 +3388,30 @@ fn emit_mad_sint(
Ok(())
}
+fn emit_fma_float(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::ArithFloat,
+ arg: &ast::Arg4<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::fma as spirv::Word,
+ [
+ dr::Operand::IdRef(arg.src1),
+ dr::Operand::IdRef(arg.src2),
+ dr::Operand::IdRef(arg.src3),
+ ]
+ .iter()
+ .cloned(),
+ )?;
+ Ok(())
+}
+
fn emit_mad_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -5713,6 +5753,10 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
let is_wide = d.is_wide();
ast::Instruction::Mad(d, a.map(visitor, &inst_type, is_wide)?)
}
+ ast::Instruction::Fma(d, a) => {
+ let inst_type = ast::Type::Scalar(d.typ);
+ ast::Instruction::Fma(d, a.map(visitor, &inst_type, false)?)
+ }
ast::Instruction::Or(t, a) => ast::Instruction::Or(
t,
a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?,
@@ -6106,6 +6150,7 @@ impl ast::Instruction<ExpandedArgParams> {
| ast::Instruction::Mad(ast::MulDetails::Float(float_control), _) => float_control
.flush_to_zero
.map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())),
+ ast::Instruction::Fma(d, _) => d.flush_to_zero.map(|ftz| (ftz, d.typ.size_of())),
ast::Instruction::Setp(details, _) => details
.flush_to_zero
.map(|ftz| (ftz, details.typ.size_of())),
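
A note on the emit_rcp change above: the approximate OpenCL native_recip lowering is used only for single-precision rcp with no explicit rounding mode; rcp with a rounding mode, and all f64 cases, still emit the exact 1.0 / x division with its rounding decoration. Kernels additionally get OpExecutionMode ContractionOff, since contraction was already decided when the source was compiled to PTX and must not be re-applied here. A minimal sketch of the selection rule, with simplified stand-ins for ZLUDA's own types:

// Simplified stand-ins for ast::RcpDetails; not ZLUDA's actual definitions.
#[derive(Clone, Copy)]
enum Rounding { Rn, Rz, Rm, Rp }

struct RcpDetails {
    is_f64: bool,
    rounding: Option<Rounding>,
}

/// True when the approximate native_recip lowering is acceptable:
/// single precision and no explicit rounding mode requested by the PTX.
fn use_native_recip(d: &RcpDetails) -> bool {
    !d.is_f64 && d.rounding.is_none()
}
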
diff --git a/zluda/src/impl/device.rs b/zluda/src/impl/device.rs
index 3b43c49..e886eb9 100644
--- a/zluda/src/impl/device.rs
+++ b/zluda/src/impl/device.rs
@@ -494,7 +494,7 @@ pub fn get_attribute(
l0::sys::ze_result_t::ZE_RESULT_ERROR_UNSUPPORTED_FEATURE,
))
*/
- return Ok(());
+ 0
}
};
unsafe { *pi = value };
diff --git a/zluda/src/impl/function.rs b/zluda/src/impl/function.rs
index 05f864b..548936f 100644
--- a/zluda/src/impl/function.rs
+++ b/zluda/src/impl/function.rs
@@ -51,6 +51,37 @@ impl LegacyArguments {
}
}
+unsafe fn set_arg(
+ kernel: &ocl_core::Kernel,
+ arg_index: usize,
+ arg_size: usize,
+ arg_value: *const c_void,
+ is_mem: bool,
+) -> Result<(), CUresult> {
+ if is_mem {
+ let error = 0;
+ unsafe {
+ ocl_core::ffi::clSetKernelArgSVMPointer(
+ kernel.as_ptr(),
+ arg_index as u32,
+ *(arg_value as *const _),
+ )
+ };
+ if error != 0 {
+ panic!("clSetKernelArgSVMPointer");
+ }
+ } else {
+ unsafe {
+ ocl_core::set_kernel_arg(
+ kernel,
+ arg_index as u32,
+ ocl_core::ArgVal::from_raw(arg_size, arg_value, is_mem),
+ )?;
+ };
+ }
+ Ok(())
+}
+
pub fn launch_kernel(
f: *mut Function,
grid_dim_x: c_uint,
@@ -74,27 +105,7 @@ pub fn launch_kernel(
let func: &mut FunctionData = unsafe { &mut *f }.as_result_mut()?;
if kernel_params != ptr::null_mut() {
for (i, &(arg_size, is_mem)) in func.arg_size.iter().enumerate() {
- if is_mem {
- let error = 0;
- unsafe {
- ocl_core::ffi::clSetKernelArgSVMPointer(
- func.base.as_ptr(),
- i as u32,
- *(*kernel_params.add(i) as *const _),
- )
- };
- if error != 0 {
- panic!("clSetKernelArgSVMPointer");
- }
- } else {
- unsafe {
- ocl_core::set_kernel_arg(
- &func.base,
- i as u32,
- ocl_core::ArgVal::from_raw(arg_size, *kernel_params.add(i), is_mem),
- )?;
- };
- }
+ unsafe { set_arg(&func.base, i, arg_size, *kernel_params.add(i), is_mem)? };
}
} else {
let mut offset = 0;
@@ -126,15 +137,13 @@ pub fn launch_kernel(
for (i, &(arg_size, is_mem)) in func.arg_size.iter().enumerate() {
let buffer_offset = round_up_to_multiple(offset, arg_size);
unsafe {
- ocl_core::set_kernel_arg(
+ set_arg(
&func.base,
- i as u32,
- ocl_core::ArgVal::from_raw(
- arg_size,
- buffer_ptr.add(buffer_offset) as *const _,
- is_mem,
- ),
- )?;
+ i,
+ arg_size,
+ buffer_ptr.add(buffer_offset) as *const _,
+ is_mem,
+ )?
};
offset = buffer_offset + arg_size;
}
@@ -144,11 +153,13 @@ pub fn launch_kernel(
}
if func.use_shared_mem {
unsafe {
- ocl_core::set_kernel_arg(
+ set_arg(
&func.base,
- func.arg_size.len() as u32,
- ocl_core::ArgVal::from_raw(shared_mem_bytes as usize, ptr::null(), false),
- )?;
+ func.arg_size.len(),
+ shared_mem_bytes as usize,
+ ptr::null(),
+ false,
+ )?
};
}
let global_dims = [
@@ -192,9 +203,9 @@ pub(crate) fn get_attribute(
CUfunction_attribute::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK => {
let max_threads = GlobalState::lock_function(func, |func| {
if let ocl_core::KernelWorkGroupInfoResult::WorkGroupSize(size) =
- ocl_core::get_kernel_work_group_info::<ocl_core::DeviceId>(
+ ocl_core::get_kernel_work_group_info::<()>(
&func.base,
- unsafe { ocl_core::DeviceId::null() },
+ (),
ocl_core::KernelWorkGroupInfo::WorkGroupSize,
)?
{
diff --git a/zluda/src/impl/memory.rs b/zluda/src/impl/memory.rs
index 3e96a8c..7293ca6 100644
--- a/zluda/src/impl/memory.rs
+++ b/zluda/src/impl/memory.rs
@@ -1,16 +1,32 @@
-use super::{stream, CUresult, GlobalState};
+use super::{
+ stream::{self, CU_STREAM_LEGACY},
+ CUresult, GlobalState,
+};
use std::{
ffi::c_void,
mem::{self, size_of},
};
pub fn alloc_v2(dptr: *mut *mut c_void, bytesize: usize) -> Result<(), CUresult> {
- let ptr = GlobalState::lock_current_context(|ctx| {
- let dev = unsafe { &mut *ctx.device };
- Ok::<_, CUresult>(unsafe {
+ let ptr = GlobalState::lock_stream(CU_STREAM_LEGACY, |stream_data| {
+ let dev = unsafe { &*(*stream_data.context).device };
+ let queue = stream_data.cmd_list.as_ref().unwrap();
+ let ptr = unsafe {
dev.ocl_ext
.device_mem_alloc(&dev.ocl_context, &dev.ocl_base, bytesize, 0)?
- })
+ };
+ // CUDA does the same thing and e.g. GeekBench relies on this behavior
+ let event = unsafe {
+ dev.ocl_ext.enqueue_memfill(
+ queue,
+ ptr,
+ &0u8 as *const u8 as *const c_void,
+ 1,
+ bytesize,
+ )?
+ };
+ ocl_core::wait_for_event(&event)?;
+ Ok::<_, CUresult>(ptr)
})??;
unsafe { *dptr = ptr };
Ok(())
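
The memory.rs change above zero-fills every fresh device allocation on the legacy stream before returning it; per the in-code comment, CUDA behaves the same way and GeekBench relies on reading zeroed memory from new allocations. A minimal sketch of the pattern, using a hypothetical DeviceHeap trait instead of ZLUDA's actual device and queue types:

// Hypothetical trait standing in for the device/queue plumbing in the diff.
trait DeviceHeap {
    type Ptr: Copy;
    type Err;
    fn alloc(&mut self, bytes: usize) -> Result<Self::Ptr, Self::Err>;
    fn fill_u8(&mut self, dst: Self::Ptr, value: u8, bytes: usize) -> Result<(), Self::Err>;
}

// Allocate, then synchronously zero the new range before handing it out,
// mirroring the enqueue_memfill + wait_for_event sequence above.
fn alloc_zeroed<H: DeviceHeap>(heap: &mut H, bytes: usize) -> Result<H::Ptr, H::Err> {
    let ptr = heap.alloc(bytes)?;
    heap.fill_u8(ptr, 0, bytes)?;
    Ok(ptr)
}
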
diff --git a/zluda_dump/src/debug.ptx b/zluda_dump/src/debug.ptx
new file mode 100644
index 0000000..29104f8
--- /dev/null
+++ b/zluda_dump/src/debug.ptx
@@ -0,0 +1,55 @@
+/*
+ This collection of functions is here to assist with debugging
+ You use it by manually pasting into a module.ptx that was generated by zluda_dump
+ and inspecting content of additional debug buffer in replay.py
+*/
+
+.func debug_dump_from_thread_16(.reg.b64 debug_addr, .reg.u32 global_id_0, .reg.b16 value)
+{
+ .reg.u32 local_id;
+ mov.u32 local_id, %tid.x;
+ .reg.u32 local_size;
+ mov.u32 local_size, %ntid.x;
+ .reg.u32 group_id;
+ mov.u32 group_id, %ctaid.x;
+ .reg.b32 global_id;
+ mad.lo.u32 global_id, group_id, local_size, local_id;
+ .reg.pred should_exit;
+ setp.ne.u32 should_exit, global_id, global_id_0;
+ @should_exit bra END;
+ .reg.b32 index;
+ ld.global.u32 index, [debug_addr];
+ st.global.u32 [debug_addr], index+1;
+ .reg.u64 st_offset;
+ cvt.u64.u32 st_offset, index;
+ mad.lo.u64 st_offset, st_offset, 2, 4; // sizeof(b16), sizeof(32)
+ add.u64 debug_addr, debug_addr, st_offset;
+ st.global.u16 [debug_addr], value;
+END:
+ ret;
+}
+
+.func debug_dump_from_thread_32(.reg.b64 debug_addr, .reg.u32 global_id_0, .reg.b32 value)
+{
+ .reg.u32 local_id;
+ mov.u32 local_id, %tid.x;
+ .reg.u32 local_size;
+ mov.u32 local_size, %ntid.x;
+ .reg.u32 group_id;
+ mov.u32 group_id, %ctaid.x;
+ .reg.b32 global_id;
+ mad.lo.u32 global_id, group_id, local_size, local_id;
+ .reg.pred should_exit;
+ setp.ne.u32 should_exit, global_id, global_id_0;
+ @should_exit bra END;
+ .reg.b32 index;
+ ld.global.u32 index, [debug_addr];
+ st.global.u32 [debug_addr], index+1;
+ .reg.u64 st_offset;
+ cvt.u64.u32 st_offset, index;
+ mad.lo.u64 st_offset, st_offset, 4, 4; // sizeof(b32), sizeof(32)
+ add.u64 debug_addr, debug_addr, st_offset;
+ st.global.u32 [debug_addr], value;
+END:
+ ret;
+}
diff --git a/zluda_dump/src/replay.py b/zluda_dump/src/replay.py
index 723d954..c331d53 100644
--- a/zluda_dump/src/replay.py
+++ b/zluda_dump/src/replay.py
@@ -53,7 +53,7 @@ def parse_arguments(dump_path, prefix):
def append_debug_buffer(args, grid, block):
args = list(args)
- items = block[0] * block[1] * block[2] * block[0] * block[1] * block[2]
+ items = grid[0] * grid[1] * grid[2] * block[0] * block[1] * block[2]
debug_buff = np.zeros(items, dtype=np.uint32)
args.append((drv.InOut(debug_buff), debug_buff))
return args
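
The replay.py fix above corrects the debug buffer sizing: it should hold one 32-bit slot per launched thread, i.e. grid volume times block volume, whereas the old expression multiplied the block volume by itself. The same sizing rule, sketched in Rust rather than Python:

// One u32 slot per launched thread: product of grid and block dimensions.
fn debug_buffer_len(grid: [u32; 3], block: [u32; 3]) -> usize {
    let blocks: u32 = grid.iter().product();
    let threads_per_block: u32 = block.iter().product();
    blocks as usize * threads_per_block as usize
}
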