diff options
-rw-r--r-- | level_zero/src/ze.rs | 32 | ||||
-rw-r--r-- | notcuda/src/impl/module.rs | 17 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 9 | ||||
-rw-r--r-- | ptx/src/translate.rs | 43 |
4 files changed, 88 insertions, 13 deletions
diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index 321e492..253ba4b 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -257,9 +257,15 @@ impl Module { ctx: &mut Context,
d: &Device,
binaries: &[&'a [u8]],
+ opts: Option<&CStr>,
) -> (Result<Self>, Option<BuildLog>) {
- let ocl_program = match Self::build_link_spirv_impl(binaries) {
- Err(_) => return (Err(sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN), None),
+ let ocl_program = match Self::build_link_spirv_impl(binaries, opts) {
+ Err(_) => {
+ return (
+ Err(sys::ze_result_t::ZE_RESULT_ERROR_MODULE_LINK_FAILURE),
+ None,
+ )
+ }
Ok(prog) => prog,
};
match ocl_core::get_program_info(&ocl_program, ocl_core::ProgramInfo::Binaries) {
@@ -271,7 +277,10 @@ impl Module { }
}
- fn build_link_spirv_impl<'a>(binaries: &[&'a [u8]]) -> ocl_core::Result<ocl_core::Program> {
+ fn build_link_spirv_impl<'a>(
+ binaries: &[&'a [u8]],
+ opts: Option<&CStr>,
+ ) -> ocl_core::Result<ocl_core::Program> {
let platforms = ocl_core::get_platform_ids()?;
let (platform, device) = platforms
.iter()
@@ -305,7 +314,22 @@ impl Module { for binary in binaries {
programs.push(ocl_core::create_program_with_il(&ocl_ctx, binary, None)?);
}
- let options = CString::default();
+ let options = match opts {
+ Some(o) => o.to_owned(),
+ None => CString::default(),
+ };
+ for program in programs.iter() {
+ ocl_core::compile_program(
+ program,
+ Some(&[device]),
+ &options,
+ &[],
+ &[],
+ None,
+ None,
+ None,
+ )?;
+ }
ocl_core::link_program::<ocl_core::DeviceId, _>(
&ocl_ctx,
Some(&[device]),
diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs index fa46bf4..cba030e 100644 --- a/notcuda/src/impl/module.rs +++ b/notcuda/src/impl/module.rs @@ -83,8 +83,21 @@ impl SpirvModule { self.binaries.len() * mem::size_of::<u32>(), ) }; - let l0_module = l0::Module::build_spirv(ctx, dev, byte_il, None).0?; - Ok(l0_module) + let l0_module = match self.should_link_ptx_impl { + None => { + l0::Module::build_spirv(ctx, dev, byte_il, Some(self.build_options.as_c_str())).0 + } + Some(ptx_impl) => { + l0::Module::build_link_spirv( + ctx, + &dev, + &[ptx_impl, byte_il], + Some(self.build_options.as_c_str()), + ) + .0 + } + }; + Ok(l0_module?) } } diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 027e891..c70ab5c 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -52,6 +52,7 @@ test_ptx!(setp, [10u64, 11u64], [1u64, 0u64]); test_ptx!(bra, [10u64], [11u64]);
test_ptx!(not, [0u64], [u64::max_value()]);
test_ptx!(shl, [11u64], [44u64]);
+test_ptx!(shl_link_hack, [11u64], [44u64]);
test_ptx!(cvt_sat_s_u, [-1i32], [0i32]);
test_ptx!(cvta, [3.0f32], [3.0f32]);
test_ptx!(block, [1u64], [2u64]);
@@ -202,7 +203,12 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>( let dev = devices.drain(0..1).next().unwrap();
let queue = ze::CommandQueue::new(&mut ctx, &dev)?;
let (module, maybe_log) = match module.should_link_ptx_impl {
- Some(ptx_impl) => ze::Module::build_link_spirv(&mut ctx, &dev, &[ptx_impl, byte_il]),
+ Some(ptx_impl) => ze::Module::build_link_spirv(
+ &mut ctx,
+ &dev,
+ &[ptx_impl, byte_il],
+ Some(module.build_options.as_c_str()),
+ ),
None => {
let (module, log) = ze::Module::build_spirv(
&mut ctx,
@@ -262,7 +268,6 @@ fn test_spvtxt_assert<'a>( let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?;
assert!(errors.len() == 0);
let spirv_module = translate::to_spirv_module(ast)?;
- eprintln!("{}", rspirv::binary::Disassemble::disassemble(&spirv_module.spirv));
let spv_context =
unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) };
assert!(spv_context != ptr::null_mut());
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 20c3edb..2b14bd7 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -2920,15 +2920,31 @@ fn emit_function_body_ops( }?;
}
ast::Instruction::Shl(t, a) => {
- let result_type = map.get_or_add(builder, SpirvType::from(t.to_type()));
- builder.shift_left_logical(result_type, Some(a.dst), a.src1, a.src2)?;
+ let full_type = t.to_type();
+ let size_of = full_type.size_of();
+ let result_type = map.get_or_add(builder, SpirvType::from(full_type));
+ let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?;
+ builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?;
}
ast::Instruction::Shr(t, a) => {
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
+ let full_type = ast::ScalarType::from(*t);
+ let size_of = full_type.size_of();
+ let result_type = map.get_or_add_scalar(builder, full_type);
+ let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?;
if t.signed() {
- builder.shift_right_arithmetic(result_type, Some(a.dst), a.src1, a.src2)?;
+ builder.shift_right_arithmetic(
+ result_type,
+ Some(a.dst),
+ a.src1,
+ offset_src,
+ )?;
} else {
- builder.shift_right_logical(result_type, Some(a.dst), a.src1, a.src2)?;
+ builder.shift_right_logical(
+ result_type,
+ Some(a.dst),
+ a.src1,
+ offset_src,
+ )?;
}
}
ast::Instruction::Cvt(dets, arg) => {
@@ -3225,6 +3241,23 @@ fn emit_function_body_ops( Ok(())
}
+// HACK ALERT
+// For some reason IGC fails linking if the value and shift size are of different type
+fn insert_shift_hack(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ offset_var: spirv::Word,
+ size_of: usize,
+) -> Result<spirv::Word, TranslateError> {
+ let result_type = match size_of {
+ 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16),
+ 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64),
+ 4 => return Ok(offset_var),
+ _ => return Err(TranslateError::Unreachable),
+ };
+ Ok(builder.u_convert(result_type, None, offset_var)?)
+}
+
// TODO: check what kind of assembly do we emit
fn emit_logical_xor_spirv(
builder: &mut dr::Builder,
|