diff options
author | Andrzej Janik <[email protected]> | 2024-04-06 01:23:53 +0200 |
---|---|---|
committer | GitHub <[email protected]> | 2024-04-06 01:23:53 +0200 |
commit | 774f4bcb37c39f876caf80ae0d39420fa4bc1c8b (patch) | |
tree | 6c257a53205ca27669bf4fb24817f6cff886e25b | |
parent | 0d9ace247567a07554294dc4653624943334a410 (diff) | |
download | ZLUDA-774f4bcb37c39f876caf80ae0d39420fa4bc1c8b.tar.gz ZLUDA-774f4bcb37c39f876caf80ae0d39420fa4bc1c8b.zip |
Implement sad instruction (#198)
-rw-r--r-- | ptx/src/ast.rs | 1 | ||||
-rw-r--r-- | ptx/src/emit.rs | 31 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 10 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 1 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/sad.ll | 63 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/sad.ptx | 29 | ||||
-rw-r--r-- | ptx/src/translate.rs | 4 |
7 files changed, 139 insertions, 0 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 93793e6..e5b5f97 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -476,6 +476,7 @@ pub enum Instruction<P: ArgParams> { MatchAny(Arg3<P>),
Red(AtomDetails, Arg2St<P>),
Nanosleep(Arg1<P>),
+ Sad(ScalarType, Arg4<P>),
}
#[derive(Copy, Clone)]
diff --git a/ptx/src/emit.rs b/ptx/src/emit.rs index d4d6df6..9e62d5b 100644 --- a/ptx/src/emit.rs +++ b/ptx/src/emit.rs @@ -13,6 +13,7 @@ use zluda_llvm::prelude::*; use zluda_llvm::zluda::*; use zluda_llvm::*; +use crate::ast::SetpData; use crate::translate::{ self, Arg4CarryOut, ConstType, ConversionKind, DenormSummary, ExpandedArgParams, FPDenormMode, MadCCDetails, MadCDetails, TranslationModule, TypeKind, TypeParts, @@ -1137,6 +1138,7 @@ fn emit_instruction( ast::Instruction::Vshr(arg) => emit_inst_vshr(ctx, arg)?, ast::Instruction::Set(details, arg) => emit_inst_set(ctx, details, arg)?, ast::Instruction::Red(details, arg) => emit_inst_red(ctx, details, arg)?, + ast::Instruction::Sad(type_, arg) => emit_inst_sad(ctx, *type_, arg)?, // replaced by function calls or Statement variants ast::Instruction::Activemask { .. } | ast::Instruction::Bar(..) @@ -1161,6 +1163,35 @@ fn emit_instruction( }) } +fn emit_inst_sad( + ctx: &mut EmitContext, + type_: ast::ScalarType, + arg: &ast::Arg4<ExpandedArgParams>, +) -> Result<(), TranslateError> { + let builder = ctx.builder.get(); + let less_than = emit_inst_setp_int( + ctx, + &SetpData { + typ: type_, + flush_to_zero: None, + cmp_op: ast::SetpCompareOp::Greater, + }, + None, + arg.src1, + arg.src2, + )?; + let a = ctx.names.value(arg.src1)?; + let b = ctx.names.value(arg.src2)?; + let a_minus_b = unsafe { LLVMBuildSub(builder, a, b, LLVM_UNNAMED) }; + let b_minus_a = unsafe { LLVMBuildSub(builder, b, a, LLVM_UNNAMED) }; + let a_or_b = unsafe { LLVMBuildSelect(builder, less_than, a_minus_b, b_minus_a, LLVM_UNNAMED) }; + let src3 = ctx.names.value(arg.src3)?; + ctx.names.register_result(arg.dst, |dst_name| unsafe { + LLVMBuildAdd(builder, src3, a_or_b, dst_name) + }); + Ok(()) +} + fn emit_inst_red( ctx: &mut EmitContext, details: &ast::AtomDetails, diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index d5c9b61..5ec97e1 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -224,6 +224,7 @@ match { "rem",
"ret",
"rsqrt",
+ "sad",
"selp",
"set",
"setp",
@@ -305,6 +306,7 @@ ExtendedID : &'input str = { "rem",
"ret",
"rsqrt",
+ "sad",
"selp",
"set",
"setp",
@@ -839,6 +841,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { InstMatch,
InstRed,
InstNanosleep,
+ InstSad
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@@ -2377,6 +2380,13 @@ InstNanosleep: ast::Instruction<ast::ParsedArgParams<'input>> = { }
}
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sad
+InstSad: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "sad" <type_:IntType> <a:Arg4> => {
+ ast::Instruction::Sad(type_, a)
+ }
+}
+
NegTypeFtz: ast::ScalarType = {
".f16" => ast::ScalarType::F16,
".f16x2" => ast::ScalarType::F16x2,
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 1ec030b..5fb5a8b 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -364,6 +364,7 @@ test_ptx!( [1923569713u64, 1923569712],
[1923569713u64, 1923569712]
);
+test_ptx!(sad, [2147483648u32, 2, 13], [2147483659u32, 2147483663]);
test_ptx_warp!(
shfl,
diff --git a/ptx/src/test/spirv_run/sad.ll b/ptx/src/test/spirv_run/sad.ll new file mode 100644 index 0000000..c7a5726 --- /dev/null +++ b/ptx/src/test/spirv_run/sad.ll @@ -0,0 +1,63 @@ +target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7" +target triple = "amdgcn-amd-amdhsa" + +define protected amdgpu_kernel void @sad(ptr addrspace(4) byref(i64) %"38", ptr addrspace(4) byref(i64) %"39") #0 { +"56": + %"11" = alloca i1, align 1, addrspace(5) + store i1 false, ptr addrspace(5) %"11", align 1 + %"4" = alloca i64, align 8, addrspace(5) + %"5" = alloca i64, align 8, addrspace(5) + %"6" = alloca i32, align 4, addrspace(5) + %"7" = alloca i32, align 4, addrspace(5) + %"8" = alloca i32, align 4, addrspace(5) + %"9" = alloca i32, align 4, addrspace(5) + %"10" = alloca i32, align 4, addrspace(5) + %"12" = load i64, ptr addrspace(4) %"38", align 8 + store i64 %"12", ptr addrspace(5) %"4", align 8 + %"13" = load i64, ptr addrspace(4) %"39", align 8 + store i64 %"13", ptr addrspace(5) %"5", align 8 + %"15" = load i64, ptr addrspace(5) %"4", align 8 + %"41" = inttoptr i64 %"15" to ptr + %"40" = load i32, ptr %"41", align 4 + store i32 %"40", ptr addrspace(5) %"6", align 4 + %"17" = load i64, ptr addrspace(5) %"4", align 8 + %"42" = inttoptr i64 %"17" to ptr + %"58" = getelementptr inbounds i8, ptr %"42", i64 4 + %"43" = load i32, ptr %"58", align 4 + store i32 %"43", ptr addrspace(5) %"7", align 4 + %"19" = load i64, ptr addrspace(5) %"4", align 8 + %"44" = inttoptr i64 %"19" to ptr + %"60" = getelementptr inbounds i8, ptr %"44", i64 8 + %"45" = load i32, ptr %"60", align 4 + store i32 %"45", ptr addrspace(5) %"8", align 4 + %"21" = load i32, ptr addrspace(5) %"6", align 4 + %"22" = load i32, ptr addrspace(5) %"7", align 4 + %"23" = load i32, ptr addrspace(5) %"8", align 4 + %0 = icmp ugt i32 %"21", %"22" + %1 = sub i32 %"21", %"22" + %2 = sub i32 %"22", %"21" + %3 = select i1 %0, i32 %1, i32 %2 + %"46" = add i32 %"23", %3 + store i32 %"46", ptr addrspace(5) %"9", align 4 + %"25" = load i32, ptr addrspace(5) %"6", align 4 + %"26" = load i32, ptr addrspace(5) %"7", align 4 + %"27" = load i32, ptr addrspace(5) %"8", align 4 + %4 = icmp sgt i32 %"25", %"26" + %5 = sub i32 %"25", %"26" + %6 = sub i32 %"26", %"25" + %7 = select i1 %4, i32 %5, i32 %6 + %"50" = add i32 %"27", %7 + store i32 %"50", ptr addrspace(5) %"10", align 4 + %"28" = load i64, ptr addrspace(5) %"5", align 8 + %"29" = load i32, ptr addrspace(5) %"9", align 4 + %"54" = inttoptr i64 %"28" to ptr + store i32 %"29", ptr %"54", align 4 + %"30" = load i64, ptr addrspace(5) %"5", align 8 + %"31" = load i32, ptr addrspace(5) %"10", align 4 + %"55" = inttoptr i64 %"30" to ptr + %"62" = getelementptr inbounds i8, ptr %"55", i64 4 + store i32 %"31", ptr %"62", align 4 + ret void +} + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee,ieee" "denormal-fp-math-f32"="ieee,ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } diff --git a/ptx/src/test/spirv_run/sad.ptx b/ptx/src/test/spirv_run/sad.ptx new file mode 100644 index 0000000..c7ed6c6 --- /dev/null +++ b/ptx/src/test/spirv_run/sad.ptx @@ -0,0 +1,29 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.entry sad( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b32 a; + .reg .b32 b; + .reg .b32 c; + .reg .b32 result_u32; + .reg .b32 result_s32; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.u32 a, [in_addr]; + ld.u32 b, [in_addr+4]; + ld.u32 c, [in_addr+8]; + sad.u32 result_u32, a, b, c; + sad.s32 result_s32, a, b, c; + st.b32 [out_addr], result_u32; + st.b32 [out_addr+4], result_s32; + ret; +} diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 3b75ec9..61a74c9 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -6644,6 +6644,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> { ast::StateSpace::Reg,
)),
)?),
+ ast::Instruction::Sad(type_, a) => {
+ ast::Instruction::Sad(type_, a.map(visitor, &ast::Type::Scalar(type_), false)?)
+ }
})
}
}
@@ -7000,6 +7003,7 @@ impl<T: ast::ArgParams> ast::Instruction<T> { ast::Instruction::Shf(..) => None,
ast::Instruction::Vote(..) => None,
ast::Instruction::Nanosleep(..) => None,
+ ast::Instruction::Sad(_, _) => None,
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
|