diff options
author | NyanCatTW1 <[email protected]> | 2024-04-06 01:12:59 +0800 |
---|---|---|
committer | GitHub <[email protected]> | 2024-04-05 19:12:59 +0200 |
commit | 76bae5f91bf81409b8f592e52a2658d787515fa8 (patch) | |
tree | ce360a0ece25aaa76c7bc7a1db1c1596ffa4e830 /ptx/src | |
parent | b695f44c188efc8df8e2e2c149904bb82d2dc58b (diff) | |
download | ZLUDA-76bae5f91bf81409b8f592e52a2658d787515fa8.tar.gz ZLUDA-76bae5f91bf81409b8f592e52a2658d787515fa8.zip |
Implement mad.hi.cc (#196)
Diffstat (limited to 'ptx/src')
-rw-r--r-- | ptx/src/ast.rs | 1 | ||||
-rw-r--r-- | ptx/src/emit.rs | 38 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 7 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mad_hi_cc.ll | 90 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mad_hi_cc.ptx | 41 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 1 | ||||
-rw-r--r-- | ptx/src/translate.rs | 8 |
7 files changed, 153 insertions, 33 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 0281961..93793e6 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -380,6 +380,7 @@ pub enum Instruction<P: ArgParams> { },
MadCC {
type_: ScalarType,
+ is_hi: bool,
arg: Arg4<P>,
},
Fma(ArithFloat, Arg4<P>),
diff --git a/ptx/src/emit.rs b/ptx/src/emit.rs index 94cc973..d4d6df6 100644 --- a/ptx/src/emit.rs +++ b/ptx/src/emit.rs @@ -621,8 +621,8 @@ fn emit_statement( crate::translate::Statement::MadC(MadCDetails { type_, is_hi, arg }) => { emit_inst_madc(ctx, type_, is_hi, &arg)? } - crate::translate::Statement::MadCC(MadCCDetails { type_, arg }) => { - emit_inst_madcc(ctx, type_, &arg)? + crate::translate::Statement::MadCC(MadCCDetails { type_, is_hi, arg }) => { + emit_inst_madcc(ctx, type_, is_hi, &arg)? } crate::translate::Statement::AddC(type_, arg) => emit_inst_add_c(ctx, type_, &arg)?, crate::translate::Statement::AddCC(type_, arg) => { @@ -2079,16 +2079,17 @@ fn emit_inst_mad_lo( ) } -// TODO: support mad.hi.cc fn emit_inst_madcc( ctx: &mut EmitContext, type_: ast::ScalarType, + is_hi: bool, arg: &Arg4CarryOut<ExpandedArgParams>, ) -> Result<(), TranslateError> { - let builder = ctx.builder.get(); - let src1 = ctx.names.value(arg.src1)?; - let src2 = ctx.names.value(arg.src2)?; - let mul_result = unsafe { LLVMBuildMul(builder, src1, src2, LLVM_UNNAMED) }; + let mul_result = if is_hi { + emit_inst_mul_hi_impl(ctx, type_, None, arg.src1, arg.src2)? + } else { + emit_inst_mul_low_impl(ctx, None, arg.src1, arg.src2, LLVMBuildMul)? + }; emit_inst_addsub_cc_impl( ctx, "add", @@ -2176,29 +2177,6 @@ fn emit_inst_madc( mul_result, args.src3, ) - /* - let src3 = ctx.names.value(args.src3)?; - let add_no_carry = unsafe { LLVMBuildAdd(builder, mul_result, src3, LLVM_UNNAMED) }; - let carry_flag = ctx.names.value(args.carry_in)?; - let llvm_type = get_llvm_type(ctx, &ast::Type::Scalar(type_))?; - let carry_flag = unsafe { LLVMBuildZExt(builder, carry_flag, llvm_type, LLVM_UNNAMED) }; - if let Some(carry_out) = args.carry_out { - emit_inst_addsub_cc_impl( - ctx, - "add", - type_, - args.dst, - carry_out, - add_no_carry, - carry_flag, - )?; - } else { - ctx.names.register_result(args.dst, |dst| unsafe { - LLVMBuildAdd(builder, add_no_carry, carry_flag, dst) - }); - } - Ok(()) - */ } fn emit_inst_add_c( diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index ae57575..d5c9b61 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -1516,7 +1516,12 @@ InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#extended-precision-arithmetic-instructions-mad-cc
InstMadCC: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "mad" ".lo" ".cc" <type_:IntType3264> <arg:Arg4> => ast::Instruction::MadCC{<>},
+ "mad" ".lo" ".cc" <type_:IntType3264> <arg:Arg4> => {
+ ast::Instruction::MadCC { type_, arg, is_hi: false }
+ },
+ "mad" ".hi" ".cc" <type_:IntType3264> <arg:Arg4> => {
+ ast::Instruction::MadCC { type_, arg, is_hi: true }
+ },
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#extended-precision-arithmetic-instructions-madc
diff --git a/ptx/src/test/spirv_run/mad_hi_cc.ll b/ptx/src/test/spirv_run/mad_hi_cc.ll new file mode 100644 index 0000000..a5b1595 --- /dev/null +++ b/ptx/src/test/spirv_run/mad_hi_cc.ll @@ -0,0 +1,90 @@ +target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7" +target triple = "amdgcn-amd-amdhsa" + +define protected amdgpu_kernel void @mad_hi_cc(ptr addrspace(4) byref(i64) %"61", ptr addrspace(4) byref(i64) %"62") #0 { +"78": + %"14" = alloca i1, align 1, addrspace(5) + store i1 false, ptr addrspace(5) %"14", align 1 + %"15" = alloca i1, align 1, addrspace(5) + store i1 false, ptr addrspace(5) %"15", align 1 + %"4" = alloca i64, align 8, addrspace(5) + %"5" = alloca i64, align 8, addrspace(5) + %"6" = alloca i32, align 4, addrspace(5) + %"7" = alloca i32, align 4, addrspace(5) + %"8" = alloca i32, align 4, addrspace(5) + %"9" = alloca i32, align 4, addrspace(5) + %"10" = alloca i32, align 4, addrspace(5) + %"11" = alloca i32, align 4, addrspace(5) + %"12" = alloca i32, align 4, addrspace(5) + %"13" = alloca i32, align 4, addrspace(5) + %"16" = load i64, ptr addrspace(4) %"61", align 8 + store i64 %"16", ptr addrspace(5) %"4", align 8 + %"17" = load i64, ptr addrspace(4) %"62", align 8 + store i64 %"17", ptr addrspace(5) %"5", align 8 + %"19" = load i64, ptr addrspace(5) %"4", align 8 + %"64" = inttoptr i64 %"19" to ptr + %"63" = load i32, ptr %"64", align 4 + store i32 %"63", ptr addrspace(5) %"8", align 4 + %"21" = load i64, ptr addrspace(5) %"4", align 8 + %"65" = inttoptr i64 %"21" to ptr + %"80" = getelementptr inbounds i8, ptr %"65", i64 4 + %"66" = load i32, ptr %"80", align 4 + store i32 %"66", ptr addrspace(5) %"9", align 4 + %"23" = load i64, ptr addrspace(5) %"4", align 8 + %"67" = inttoptr i64 %"23" to ptr + %"82" = getelementptr inbounds i8, ptr %"67", i64 8 + %"22" = load i32, ptr %"82", align 4 + store i32 %"22", ptr addrspace(5) %"10", align 4 + %"26" = load i32, ptr addrspace(5) %"8", align 4 + %"27" = load i32, ptr addrspace(5) %"9", align 4 + %"28" = load i32, ptr addrspace(5) %"10", align 4 + %0 = sext i32 %"26" to i64 + %1 = sext i32 %"27" to i64 + %2 = mul nsw i64 %0, %1 + %3 = lshr i64 %2, 32 + %4 = trunc i64 %3 to i32 + %5 = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %4, i32 %"28") + %"24" = extractvalue { i32, i1 } %5, 0 + %"25" = extractvalue { i32, i1 } %5, 1 + store i32 %"24", ptr addrspace(5) %"7", align 4 + store i1 %"25", ptr addrspace(5) %"14", align 1 + %6 = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 1, i32 -2) + %"29" = extractvalue { i32, i1 } %6, 0 + %"30" = extractvalue { i32, i1 } %6, 1 + store i32 %"29", ptr addrspace(5) %"6", align 4 + store i1 %"30", ptr addrspace(5) %"14", align 1 + %"32" = load i1, ptr addrspace(5) %"14", align 1 + %7 = zext i1 %"32" to i32 + %"71" = add i32 0, %7 + store i32 %"71", ptr addrspace(5) %"12", align 4 + %8 = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 1, i32 -1) + %"33" = extractvalue { i32, i1 } %8, 0 + %"34" = extractvalue { i32, i1 } %8, 1 + store i32 %"33", ptr addrspace(5) %"6", align 4 + store i1 %"34", ptr addrspace(5) %"14", align 1 + %"36" = load i1, ptr addrspace(5) %"14", align 1 + %9 = zext i1 %"36" to i32 + %"72" = add i32 0, %9 + store i32 %"72", ptr addrspace(5) %"13", align 4 + %"37" = load i64, ptr addrspace(5) %"5", align 8 + %"38" = load i32, ptr addrspace(5) %"7", align 4 + %"73" = inttoptr i64 %"37" to ptr + store i32 %"38", ptr %"73", align 4 + %"39" = load i64, ptr addrspace(5) %"5", align 8 + %"40" = load i32, ptr addrspace(5) %"12", align 4 + %"74" = inttoptr i64 %"39" to ptr + %"84" = getelementptr inbounds i8, ptr %"74", i64 4 + store i32 %"40", ptr %"84", align 4 + %"41" = load i64, ptr addrspace(5) %"5", align 8 + %"42" = load i32, ptr addrspace(5) %"13", align 4 + %"76" = inttoptr i64 %"41" to ptr + %"86" = getelementptr inbounds i8, ptr %"76", i64 8 + store i32 %"42", ptr %"86", align 4 + ret void +} + +; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn +declare { i32, i1 } @llvm.uadd.with.overflow.i32(i32, i32) #1 + +attributes #0 = { "amdgpu-unsafe-fp-atomics"="true" "denormal-fp-math"="ieee,ieee" "denormal-fp-math-f32"="ieee,ieee" "no-trapping-math"="true" "uniform-work-group-size"="true" } +attributes #1 = { nocallback nofree nosync nounwind readnone speculatable willreturn } diff --git a/ptx/src/test/spirv_run/mad_hi_cc.ptx b/ptx/src/test/spirv_run/mad_hi_cc.ptx new file mode 100644 index 0000000..4a8cac3 --- /dev/null +++ b/ptx/src/test/spirv_run/mad_hi_cc.ptx @@ -0,0 +1,41 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry mad_hi_cc( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u32 unused; + + .reg .s32 dst1; + .reg .b32 src1; + .reg .b32 src2; + .reg .b32 src3; + + .reg .b32 result_1; + .reg .b32 carry_out_1; + .reg .b32 carry_out_2; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + // test valid computational results + ld.s32 src1, [in_addr]; + ld.s32 src2, [in_addr+4]; + ld.b32 src3, [in_addr+8]; + mad.hi.cc.s32 dst1, src1, src2, src3; + + mad.hi.cc.u32 unused, 65536, 65536, 4294967294; // non-overflowing + addc.u32 carry_out_1, 0, 0; // carry_out_1 should be 0 + mad.hi.cc.u32 unused, 65536, 65536, 4294967295; // overflowing + addc.u32 carry_out_2, 0, 0; // carry_out_2 should be 1 + + st.s32 [out_addr], dst1; + st.s32 [out_addr+4], carry_out_1; + st.s32 [out_addr+8], carry_out_2; + ret; +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index a65240c..8f229c9 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -290,6 +290,7 @@ test_ptx!( [2147487519u32, 4294934539]
);
test_ptx!(madc_cc2, [0xDEADu32], [0u32, 1, 1, 2]);
+test_ptx!(mad_hi_cc, [0x26223377u32, 0x70777766u32, 0x60666633u32], [0x71272866u32, 0u32, 1u32]); // Multi-tap :)
test_ptx!(mov_vector_cast, [0x200000001u64], [2u32, 1u32]);
test_ptx!(
cvt_clamp,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 041c690..1a203bd 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1999,9 +1999,10 @@ fn insert_hardware_registers_impl<'input>( is_hi,
arg: Arg4CarryIn::new(arg, carry_out, TypedOperand::Reg(overflow_flag)),
})),
- Statement::Instruction(ast::Instruction::MadCC { type_, arg }) => {
+ Statement::Instruction(ast::Instruction::MadCC { type_, is_hi, arg }) => {
result.push(Statement::MadCC(MadCCDetails {
type_,
+ is_hi,
arg: Arg4CarryOut::new(arg, TypedOperand::Reg(overflow_flag)),
}))
}
@@ -5568,6 +5569,7 @@ impl<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>> Visitable<T, U> for MadCD pub(crate) struct MadCCDetails<P: ast::ArgParams> {
pub(crate) type_: ast::ScalarType,
+ pub(crate) is_hi: bool,
pub(crate) arg: Arg4CarryOut<P>,
}
@@ -5578,6 +5580,7 @@ impl<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>> Visitable<T, U> for MadCC ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
Ok(Statement::MadCC(MadCCDetails {
type_: self.type_,
+ is_hi: self.is_hi,
arg: self.arg.map(visitor, self.type_)?,
}))
}
@@ -6486,8 +6489,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> { carry_out,
arg: arg.map(visitor, &ast::Type::Scalar(type_), false)?,
},
- ast::Instruction::MadCC { type_, arg } => ast::Instruction::MadCC {
+ ast::Instruction::MadCC { type_, arg, is_hi } => ast::Instruction::MadCC {
type_,
+ is_hi,
arg: arg.map(visitor, &ast::Type::Scalar(type_), false)?,
},
ast::Instruction::Tex(details, arg) => {
|