diff options
author | Andrzej Janik <[email protected]> | 2024-12-04 18:57:09 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2024-12-04 18:57:09 +0100 |
commit | 320cf9396c7c14f69dbff1e561012487d9e17c5b (patch) | |
tree | 4d88887a2e9e8b80817f88874f5105d545ae11bd | |
parent | 7a6df9dcbf59edef371e7f63c16c64916ddb0c0b (diff) | |
download | ZLUDA-320cf9396c7c14f69dbff1e561012487d9e17c5b.tar.gz ZLUDA-320cf9396c7c14f69dbff1e561012487d9e17c5b.zip |
Implement abs
-rw-r--r-- | ptx/src/pass/emit_llvm.rs | 31 |
1 files changed, 25 insertions, 6 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index fa011a3..739e53d 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -96,10 +96,6 @@ impl Module { let memory_buffer = unsafe { LLVMWriteBitcodeToMemoryBuffer(self.get()) };
MemoryBuffer(memory_buffer)
}
-
- fn write_to_stderr(&self) {
- unsafe { LLVMDumpModule(self.get()) };
- }
}
impl Drop for Module {
@@ -183,7 +179,6 @@ pub(super) fn run<'input>( Directive2::Method(method) => emit_ctx.emit_method(method)?,
}
}
- module.write_to_stderr();
if let Err(err) = module.verify() {
panic!("{:?}", err);
}
@@ -529,7 +524,7 @@ impl<'a> MethodEmitContext<'a> { ast::Instruction::Shl { data, arguments } => self.emit_shl(data, arguments),
ast::Instruction::Ret { data } => Ok(self.emit_ret(data)),
ast::Instruction::Cvta { data, arguments } => self.emit_cvta(data, arguments),
- ast::Instruction::Abs { .. } => todo!(),
+ ast::Instruction::Abs { data, arguments } => self.emit_abs(data, arguments),
ast::Instruction::Mad { data, arguments } => self.emit_mad(data, arguments),
ast::Instruction::Fma { data, arguments } => self.emit_fma(data, arguments),
ast::Instruction::Sub { data, arguments } => self.emit_sub(data, arguments),
@@ -2149,6 +2144,30 @@ impl<'a> MethodEmitContext<'a> { Ok(())
}
+ fn emit_abs(
+ &mut self,
+ data: ast::TypeFtz,
+ arguments: ptx_parser::AbsArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let llvm_type = get_scalar_type(self.context, data.type_);
+ let src = self.resolver.value(arguments.src)?;
+ let (prefix, intrinsic_arguments) = if data.type_.kind() == ast::ScalarKind::Float {
+ ("llvm.fabs", vec![(src, llvm_type)])
+ } else {
+ let pred = get_scalar_type(self.context, ast::ScalarType::Pred);
+ let zero = unsafe { LLVMConstInt(pred, 0, 0) };
+ ("llvm.abs", vec![(src, llvm_type), (zero, pred)])
+ };
+ let llvm_intrinsic = format!("{}.{}\0", prefix, LLVMTypeDisplay(data.type_));
+ self.emit_intrinsic(
+ unsafe { CStr::from_bytes_with_nul_unchecked(llvm_intrinsic.as_bytes()) },
+ Some(arguments.dst),
+ &data.type_.into(),
+ intrinsic_arguments,
+ )?;
+ Ok(())
+ }
+
/*
// Currently unused, LLVM 18 (ROCm 6.2) does not support `llvm.set.rounding`
// Should be available in LLVM 19
|