diff options
author | Andrzej Janik <[email protected]> | 2020-10-02 00:11:28 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-10-02 00:11:28 +0200 |
commit | 9a65dd32f5898eb9dd3edf7cdddb1513a7a754ed (patch) | |
tree | a360a387ff99ceb42a75883047d776bd6bafa67f | |
parent | bd3d440dba9a913e2214de89a151f9c2c34984fe (diff) | |
download | ZLUDA-9a65dd32f5898eb9dd3edf7cdddb1513a7a754ed.tar.gz ZLUDA-9a65dd32f5898eb9dd3edf7cdddb1513a7a754ed.zip |
Add sub, min, max
-rw-r--r-- | ptx/src/ast.rs | 88 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 169 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/max.ptx | 23 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/max.spvtxt | 57 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/min.ptx | 23 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/min.spvtxt | 57 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 3 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/or.ptx | 23 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/or.spvtxt | 58 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/sub.ptx | 22 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/sub.spvtxt | 49 | ||||
-rw-r--r-- | ptx/src/translate.rs | 427 |
12 files changed, 819 insertions, 180 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 8c64ebf..048d43a 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -241,6 +241,10 @@ sub_scalar_type!(IntType { S64 }); +sub_scalar_type!(UIntType { U8, U16, U32, U64 }); + +sub_scalar_type!(SIntType { S8, S16, S32, S64 }); + impl IntType { pub fn is_signed(self) -> bool { match self { @@ -331,7 +335,7 @@ pub enum Instruction<P: ArgParams> { Ld(LdDetails, Arg2Ld<P>), Mov(MovDetails, Arg2Mov<P>), Mul(MulDetails, Arg3<P>), - Add(AddDetails, Arg3<P>), + Add(ArithDetails, Arg3<P>), Setp(SetpData, Arg4Setp<P>), SetpBool(SetpBoolData, Arg5<P>), Not(NotType, Arg2<P>), @@ -346,6 +350,9 @@ pub enum Instruction<P: ArgParams> { Abs(AbsDetails, Arg2<P>), Mad(MulDetails, Arg4<P>), Or(OrType, Arg3<P>), + Sub(ArithDetails, Arg3<P>), + Min(MinMaxDetails, Arg3<P>), + Max(MinMaxDetails, Arg3<P>), } #[derive(Copy, Clone)] @@ -554,11 +561,6 @@ impl MovDetails { } } -pub enum MulDetails { - Int(MulIntDesc), - Float(MulFloatDesc), -} - #[derive(Copy, Clone)] pub struct MulIntDesc { pub typ: IntType, @@ -572,14 +574,6 @@ pub enum MulIntControl { Wide, } -#[derive(Copy, Clone)] -pub struct MulFloatDesc { - pub typ: FloatType, - pub rounding: Option<RoundingMode>, - pub flush_to_zero: bool, - pub saturate: bool, -} - #[derive(PartialEq, Eq, Copy, Clone)] pub enum RoundingMode { NearestEven, @@ -588,23 +582,11 @@ pub enum RoundingMode { PositiveInf, } -pub enum AddDetails { - Int(AddIntDesc), - Float(AddFloatDesc), -} - pub struct AddIntDesc { pub typ: IntType, pub saturate: bool, } -pub struct AddFloatDesc { - pub typ: FloatType, - pub rounding: Option<RoundingMode>, - pub flush_to_zero: bool, - pub saturate: bool, -} - pub struct SetpData { pub typ: ScalarType, pub flush_to_zero: bool, @@ -810,3 +792,57 @@ sub_scalar_type!(OrType { B32, B64, }); + +#[derive(Copy, Clone)] +pub enum MulDetails { + Unsigned(MulUInt), + Signed(MulSInt), + Float(ArithFloat), +} + +#[derive(Copy, Clone)] +pub struct MulUInt { + pub typ: UIntType, + pub control: MulIntControl, +} + +#[derive(Copy, Clone)] +pub struct MulSInt { + pub typ: SIntType, + pub control: MulIntControl, +} + +#[derive(Copy, Clone)] +pub enum ArithDetails { + Unsigned(UIntType), + Signed(ArithSInt), + Float(ArithFloat), +} + +#[derive(Copy, Clone)] +pub struct ArithSInt { + pub typ: SIntType, + pub saturate: bool, +} + +#[derive(Copy, Clone)] +pub struct ArithFloat { + pub typ: FloatType, + pub rounding: Option<RoundingMode>, + pub flush_to_zero: bool, + pub saturate: bool, +} + +#[derive(Copy, Clone)] +pub enum MinMaxDetails { + Signed(SIntType), + Unsigned(UIntType), + Float(MinMaxFloat), +} + +#[derive(Copy, Clone)] +pub struct MinMaxFloat { + pub ftz: bool, + pub nan: bool, + pub typ: FloatType, +} diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index d2d5be8..2c0e365 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -70,6 +70,7 @@ match { ".ltu", ".lu", ".nan", + ".NaN", ".ne", ".neu", ".num", @@ -124,6 +125,8 @@ match { "ld", "mad", "map_f64_to_f32", + "max", + "min", "mov", "mul", "not", @@ -134,6 +137,7 @@ match { "shr", r"sm_[0-9]+" => ShaderModel, "st", + "sub", "texmode_independent", "texmode_unified", } else { @@ -153,6 +157,8 @@ ExtendedID : &'input str = { "ld", "mad", "map_f64_to_f32", + "max", + "min", "mov", "mul", "not", @@ -163,6 +169,7 @@ ExtendedID : &'input str = { "shr", ShaderModel, "st", + "sub", "texmode_independent", "texmode_unified", ID @@ -448,7 +455,10 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { InstCall, InstAbs, InstMad, - InstOr + InstOr, + InstSub, + InstMin, + InstMax, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld @@ -570,38 +580,19 @@ MovVectorType: ast::ScalarType = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul InstMul: ast::Instruction<ast::ParsedArgParams<'input>> = { - "mul" <d:InstMulMode> <a:Arg3> => ast::Instruction::Mul(d, a) + "mul" <d:MulDetails> <a:Arg3> => ast::Instruction::Mul(d, a) }; -InstMulMode: ast::MulDetails = { - <ctr:MulIntControl> <t:IntType> => ast::MulDetails::Int(ast::MulIntDesc { +MulDetails: ast::MulDetails = { + <ctr:MulIntControl> <t:UIntType> => ast::MulDetails::Unsigned(ast::MulUInt{ typ: t, control: ctr }), - <r:RoundingModeFloat?> <ftz:".ftz"?> <s:".sat"?> ".f32" => ast::MulDetails::Float(ast::MulFloatDesc { - typ: ast::FloatType::F32, - rounding: r, - flush_to_zero: ftz.is_some(), - saturate: s.is_some() - }), - <r:RoundingModeFloat?> ".f64" => ast::MulDetails::Float(ast::MulFloatDesc { - typ: ast::FloatType::F64, - rounding: r, - flush_to_zero: false, - saturate: false - }), - <r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16" => ast::MulDetails::Float(ast::MulFloatDesc { - typ: ast::FloatType::F16, - rounding: r.map(|_| ast::RoundingMode::NearestEven), - flush_to_zero: ftz.is_some(), - saturate: s.is_some() + <ctr:MulIntControl> <t:SIntType> => ast::MulDetails::Signed(ast::MulSInt{ + typ: t, + control: ctr }), - <r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16x2" => ast::MulDetails::Float(ast::MulFloatDesc { - typ: ast::FloatType::F16x2, - rounding: r.map(|_| ast::RoundingMode::NearestEven), - flush_to_zero: ftz.is_some(), - saturate: s.is_some() - }) + <f:ArithFloat> => ast::MulDetails::Float(f) }; MulIntControl: ast::MulIntControl = { @@ -634,41 +625,23 @@ IntType : ast::IntType = { ".s64" => ast::IntType::S64, }; +UIntType: ast::UIntType = { + ".u16" => ast::UIntType::U16, + ".u32" => ast::UIntType::U32, + ".u64" => ast::UIntType::U64, +}; + +SIntType: ast::SIntType = { + ".s16" => ast::SIntType::S16, + ".s32" => ast::SIntType::S32, + ".s64" => ast::SIntType::S64, +}; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add InstAdd: ast::Instruction<ast::ParsedArgParams<'input>> = { - "add" <d:InstAddMode> <a:Arg3> => ast::Instruction::Add(d, a) -}; - -InstAddMode: ast::AddDetails = { - <t:IntType> => ast::AddDetails::Int(ast::AddIntDesc { - typ: t, - saturate: false, - }), - ".sat" ".s32" => ast::AddDetails::Int(ast::AddIntDesc { - typ: ast::IntType::S32, - saturate: true, - }), - <rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::AddDetails::Float(ast::AddFloatDesc { - typ: ast::FloatType::F32, - rounding: rn, - flush_to_zero: ftz.is_some(), - saturate: sat.is_some(), - }), - <rn:RoundingModeFloat?> ".f64" => ast::AddDetails::Float(ast::AddFloatDesc { - typ: ast::FloatType::F64, - rounding: rn, - flush_to_zero: false, - saturate: false, - }), - <rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?>".f16" => ast::AddDetails::Float(ast::AddFloatDesc { - typ: ast::FloatType::F16, - rounding: rn.map(|_| ast::RoundingMode::NearestEven), - flush_to_zero: ftz.is_some(), - saturate: sat.is_some(), - }), - ".rn"? ".ftz"? ".sat"? ".f16x2" => todo!() + "add" <d:ArithDetails> <a:Arg3> => ast::Instruction::Add(d, a) }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp @@ -1041,7 +1014,7 @@ InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = { // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = { - "mad" <d:InstMulMode> <a:Arg4> => ast::Instruction::Mad(d, a), + "mad" <d:MulDetails> <a:Arg4> => ast::Instruction::Mad(d, a), "mad" ".hi" ".sat" ".s32" => todo!() }; @@ -1063,6 +1036,84 @@ OrType: ast::OrType = { ".b64" => ast::OrType::B64, } +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub +InstSub: ast::Instruction<ast::ParsedArgParams<'input>> = { + "sub" <d:ArithDetails> <a:Arg3> => ast::Instruction::Sub(d, a), +}; + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-min +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-min +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-min +InstMin: ast::Instruction<ast::ParsedArgParams<'input>> = { + "min" <d:MinMaxDetails> <a:Arg3> => ast::Instruction::Min(d, a), +}; + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-max +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-max +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-max +InstMax: ast::Instruction<ast::ParsedArgParams<'input>> = { + "max" <d:MinMaxDetails> <a:Arg3> => ast::Instruction::Max(d, a), +}; + +MinMaxDetails: ast::MinMaxDetails = { + <t:UIntType> => ast::MinMaxDetails::Unsigned(t), + <t:SIntType> => ast::MinMaxDetails::Signed(t), + <ftz:".ftz"?> <nan:".NaN"?> ".f32" => ast::MinMaxDetails::Float( + ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F32 } + ), + ".f64" => ast::MinMaxDetails::Float( + ast::MinMaxFloat{ ftz: false, nan: false, typ: ast::FloatType::F64 } + ), + <ftz:".ftz"?> <nan:".NaN"?> ".f16" => ast::MinMaxDetails::Float( + ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16 } + ), + <ftz:".ftz"?> <nan:".NaN"?> ".f16x2" => ast::MinMaxDetails::Float( + ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16x2 } + ) +} + +ArithDetails: ast::ArithDetails = { + <t:UIntType> => ast::ArithDetails::Unsigned(t), + <t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt { + typ: t, + saturate: false, + }), + ".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt { + typ: ast::SIntType::S32, + saturate: true, + }), + <f:ArithFloat> => ast::ArithDetails::Float(f) +} + +ArithFloat: ast::ArithFloat = { + <rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat { + typ: ast::FloatType::F32, + rounding: rn, + flush_to_zero: ftz.is_some(), + saturate: sat.is_some(), + }, + <rn:RoundingModeFloat?> ".f64" => ast::ArithFloat { + typ: ast::FloatType::F64, + rounding: rn, + flush_to_zero: false, + saturate: false, + }, + <rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat { + typ: ast::FloatType::F16, + rounding: rn.map(|_| ast::RoundingMode::NearestEven), + flush_to_zero: ftz.is_some(), + saturate: sat.is_some(), + }, + <rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat { + typ: ast::FloatType::F16x2, + rounding: rn.map(|_| ast::RoundingMode::NearestEven), + flush_to_zero: ftz.is_some(), + saturate: sat.is_some(), + }, +} + Operand: ast::Operand<&'input str> = { <r:ExtendedID> => ast::Operand::Reg(r), <r:ExtendedID> "+" <o:Num> => { diff --git a/ptx/src/test/spirv_run/max.ptx b/ptx/src/test/spirv_run/max.ptx new file mode 100644 index 0000000..8c72fe2 --- /dev/null +++ b/ptx/src/test/spirv_run/max.ptx @@ -0,0 +1,23 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry max(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .s32 temp1;
+ .reg .s32 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.s32 temp1, [in_addr];
+ ld.s32 temp2, [in_addr+4];
+ max.s32 temp1, temp1, temp2;
+ st.s32 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/max.spvtxt b/ptx/src/test/spirv_run/max.spvtxt new file mode 100644 index 0000000..cab9a9a --- /dev/null +++ b/ptx/src/test/spirv_run/max.spvtxt @@ -0,0 +1,57 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %30 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "max" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %33 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %ulong_4 = OpConstant %ulong 4 + %1 = OpFunction %void None %33 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %28 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %25 = OpConvertUToPtr %_ptr_Generic_uint %15 + %14 = OpLoad %uint %25 + OpStore %6 %14 + %17 = OpLoad %ulong %4 + %24 = OpIAdd %ulong %17 %ulong_4 + %26 = OpConvertUToPtr %_ptr_Generic_uint %24 + %16 = OpLoad %uint %26 + OpStore %7 %16 + %19 = OpLoad %uint %6 + %20 = OpLoad %uint %7 + %18 = OpExtInst %uint %30 s_max %19 %20 + OpStore %6 %18 + %21 = OpLoad %ulong %5 + %22 = OpLoad %uint %6 + %27 = OpConvertUToPtr %_ptr_Generic_uint %21 + OpStore %27 %22 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/min.ptx b/ptx/src/test/spirv_run/min.ptx new file mode 100644 index 0000000..0311cdb --- /dev/null +++ b/ptx/src/test/spirv_run/min.ptx @@ -0,0 +1,23 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry min(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .s32 temp1;
+ .reg .s32 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.s32 temp1, [in_addr];
+ ld.s32 temp2, [in_addr+4];
+ min.s32 temp1, temp1, temp2;
+ st.s32 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/min.spvtxt b/ptx/src/test/spirv_run/min.spvtxt new file mode 100644 index 0000000..119cd15 --- /dev/null +++ b/ptx/src/test/spirv_run/min.spvtxt @@ -0,0 +1,57 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %30 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "min" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %33 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %ulong_4 = OpConstant %ulong 4 + %1 = OpFunction %void None %33 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %28 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %25 = OpConvertUToPtr %_ptr_Generic_uint %15 + %14 = OpLoad %uint %25 + OpStore %6 %14 + %17 = OpLoad %ulong %4 + %24 = OpIAdd %ulong %17 %ulong_4 + %26 = OpConvertUToPtr %_ptr_Generic_uint %24 + %16 = OpLoad %uint %26 + OpStore %7 %16 + %19 = OpLoad %uint %6 + %20 = OpLoad %uint %7 + %18 = OpExtInst %uint %30 s_min %19 %20 + OpStore %6 %18 + %21 = OpLoad %ulong %5 + %22 = OpLoad %uint %6 + %27 = OpConvertUToPtr %_ptr_Generic_uint %21 + OpStore %27 %22 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 99785a6..8caf540 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -70,6 +70,9 @@ test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]) test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
test_ptx!(shr, [-2i32], [-1i32]);
test_ptx!(or, [1u64, 2u64], [3u64]);
+test_ptx!(sub, [2u64], [1u64]);
+test_ptx!(min, [555i32, 444i32], [444i32]);
+test_ptx!(max, [555i32, 444i32], [555i32]);
struct DisplayError<T: Debug> {
diff --git a/ptx/src/test/spirv_run/or.ptx b/ptx/src/test/spirv_run/or.ptx new file mode 100644 index 0000000..1deb3c8 --- /dev/null +++ b/ptx/src/test/spirv_run/or.ptx @@ -0,0 +1,23 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry or(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp1;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u64 temp1, [in_addr];
+ ld.u64 temp2, [in_addr+8];
+ or.b64 temp1, temp1, temp2;
+ st.u64 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/or.spvtxt b/ptx/src/test/spirv_run/or.spvtxt new file mode 100644 index 0000000..fbf80c5 --- /dev/null +++ b/ptx/src/test/spirv_run/or.spvtxt @@ -0,0 +1,58 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %33 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "or" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %36 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_8 = OpConstant %ulong 8 + %1 = OpFunction %void None %36 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %31 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %25 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %25 + OpStore %6 %14 + %17 = OpLoad %ulong %4 + %24 = OpIAdd %ulong %17 %ulong_8 + %26 = OpConvertUToPtr %_ptr_Generic_ulong %24 + %16 = OpLoad %ulong %26 + OpStore %7 %16 + %19 = OpLoad %ulong %6 + %20 = OpLoad %ulong %7 + %28 = OpCopyObject %ulong %19 + %29 = OpCopyObject %ulong %20 + %27 = OpBitwiseOr %ulong %28 %29 + %18 = OpCopyObject %ulong %27 + OpStore %6 %18 + %21 = OpLoad %ulong %5 + %22 = OpLoad %ulong %6 + %30 = OpConvertUToPtr %_ptr_Generic_ulong %21 + OpStore %30 %22 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/sub.ptx b/ptx/src/test/spirv_run/sub.ptx new file mode 100644 index 0000000..6cce9dc --- /dev/null +++ b/ptx/src/test/spirv_run/sub.ptx @@ -0,0 +1,22 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry sub(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u64 temp, [in_addr];
+ sub.u64 temp2, temp, 1;
+ st.u64 [out_addr], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/sub.spvtxt b/ptx/src/test/spirv_run/sub.spvtxt new file mode 100644 index 0000000..8520168 --- /dev/null +++ b/ptx/src/test/spirv_run/sub.spvtxt @@ -0,0 +1,49 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %25 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "sub" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %28 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_1 = OpConstant %ulong 1 + %1 = OpFunction %void None %28 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %23 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %21 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %16 = OpISub %ulong %17 %ulong_1 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %22 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index fb1b843..7c15744 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -595,6 +595,15 @@ fn convert_to_typed_statements( ast::Instruction::Or(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Or(d, a.cast())))
}
+ ast::Instruction::Sub(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Sub(d, a.cast())))
+ }
+ ast::Instruction::Min(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Min(d, a.cast())))
+ }
+ ast::Instruction::Max(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast())))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -968,62 +977,74 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn reg_offset(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, i32)>,
- typ: ast::Type,
+ mut typ: ast::Type,
) -> Result<spirv::Word, TranslateError> {
let (reg, offset) = desc.op;
match desc.sema {
- ArgumentSemantics::Default | ArgumentSemantics::DefaultRelaxed => {
- let scalar_t = if let ast::Type::Scalar(scalar) = typ {
- scalar
+ ArgumentSemantics::Default
+ | ArgumentSemantics::DefaultRelaxed
+ | ArgumentSemantics::PhysicalPointer => {
+ if desc.sema == ArgumentSemantics::PhysicalPointer {
+ typ = ast::Type::Scalar(ast::ScalarType::U64);
+ }
+ let (width, kind) = match typ {
+ ast::Type::Scalar(scalar_t) => {
+ let kind = match scalar_t.kind() {
+ kind @ ScalarKind::Bit
+ | kind @ ScalarKind::Unsigned
+ | kind @ ScalarKind::Signed => kind,
+ ScalarKind::Float => return Err(TranslateError::MismatchedType),
+ ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
+ ScalarKind::Pred => return Err(TranslateError::MismatchedType),
+ };
+ (scalar_t.width(), kind)
+ }
+ _ => return Err(TranslateError::MismatchedType),
+ };
+ let arith_detail = if kind == ScalarKind::Signed {
+ ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::from_size(width),
+ saturate: false,
+ })
} else {
- todo!()
+ ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
};
- let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
+ let id_constant_stmt = self.id_def.new_id(typ);
let result_id = self.id_def.new_id(typ);
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: scalar_t,
- value: offset as i64,
- }));
- let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- ast::AddDetails::Int(ast::AddIntDesc {
- typ: int_type,
- saturate: false,
- }),
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
- Ok(result_id)
- }
- ArgumentSemantics::PhysicalPointer => {
- let scalar_t = ast::ScalarType::U64;
- let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
- let result_id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: scalar_t,
- value: offset as i64,
- }));
- let int_type = ast::IntType::U64;
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- ast::AddDetails::Int(ast::AddIntDesc {
- typ: int_type,
- saturate: false,
- }),
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
+ // TODO: check for edge cases around min value/max value/wrapping
+ if offset < 0 && kind != ScalarKind::Signed {
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: ast::ScalarType::from_parts(width, kind),
+ value: -(offset as i64),
+ }));
+ self.func.push(Statement::Instruction(
+ ast::Instruction::<ExpandedArgParams>::Sub(
+ arith_detail,
+ ast::Arg3 {
+ dst: result_id,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ ),
+ ));
+ } else {
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: ast::ScalarType::from_parts(width, kind),
+ value: offset as i64,
+ }));
+ self.func.push(Statement::Instruction(
+ ast::Instruction::<ExpandedArgParams>::Add(
+ arith_detail,
+ ast::Arg3 {
+ dst: result_id,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ ),
+ ));
+ }
Ok(result_id)
}
ArgumentSemantics::RegisterPointer => {
@@ -1522,14 +1543,22 @@ fn emit_function_body_ops( }
},
ast::Instruction::Mul(mul, arg) => match mul {
- ast::MulDetails::Int(ref ctr) => {
- emit_mul_int(builder, map, opencl, ctr, arg)?;
+ ast::MulDetails::Signed(ref ctr) => {
+ emit_mul_sint(builder, map, opencl, ctr, arg)?
+ }
+ ast::MulDetails::Unsigned(ref ctr) => {
+ emit_mul_uint(builder, map, opencl, ctr, arg)?
}
ast::MulDetails::Float(_) => todo!(),
},
ast::Instruction::Add(add, arg) => match add {
- ast::AddDetails::Int(ref desc) => emit_add_int(builder, map, desc, arg)?,
- ast::AddDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?,
+ ast::ArithDetails::Signed(ref desc) => {
+ emit_add_int(builder, map, desc.typ.into(), desc.saturate, arg)?
+ }
+ ast::ArithDetails::Unsigned(ref desc) => {
+ emit_add_int(builder, map, (*desc).into(), false, arg)?
+ }
+ ast::ArithDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?,
},
ast::Instruction::Setp(setp, arg) => {
if arg.dst2.is_some() {
@@ -1581,8 +1610,11 @@ fn emit_function_body_ops( }
ast::Instruction::SetpBool(_, _) => todo!(),
ast::Instruction::Mad(mad, arg) => match mad {
- ast::MulDetails::Int(ref desc) => {
- emit_mad_int(builder, map, opencl, desc, arg)?
+ ast::MulDetails::Signed(ref desc) => {
+ emit_mad_sint(builder, map, opencl, desc, arg)?
+ }
+ ast::MulDetails::Unsigned(ref desc) => {
+ emit_mad_uint(builder, map, opencl, desc, arg)?
}
ast::MulDetails::Float(desc) => emit_mad_float(builder, map, desc, arg)?,
},
@@ -1594,6 +1626,23 @@ fn emit_function_body_ops( builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
}
}
+ ast::Instruction::Sub(d, arg) => match d {
+ ast::ArithDetails::Signed(desc) => {
+ emit_sub_int(builder, map, desc.typ.into(), desc.saturate, arg)?;
+ }
+ ast::ArithDetails::Unsigned(desc) => {
+ emit_sub_int(builder, map, (*desc).into(), false, arg)?;
+ }
+ ast::ArithDetails::Float(desc) => {
+ emit_sub_float(builder, map, desc, arg)?;
+ }
+ },
+ ast::Instruction::Min(d, a) => {
+ emit_min(builder, map, opencl, d, a)?;
+ }
+ ast::Instruction::Max(d, a) => {
+ emit_max(builder, map, opencl, d, a)?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(*typ));
@@ -1624,11 +1673,11 @@ fn emit_function_body_ops( Ok(())
}
-fn emit_mad_int(
+fn emit_mad_uint(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
- desc: &ast::MulIntDesc,
+ desc: &ast::MulUInt,
arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
@@ -1638,16 +1687,38 @@ fn emit_mad_int( builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?;
}
ast::MulIntControl::High => {
- let cl_op = if desc.typ.is_signed() {
- spirv::CLOp::s_mad_hi
- } else {
- spirv::CLOp::u_mad_hi
- };
builder.ext_inst(
inst_type,
Some(arg.dst),
opencl,
- cl_op as spirv::Word,
+ spirv::CLOp::u_mad_hi as spirv::Word,
+ [arg.src1, arg.src2, arg.src3],
+ )?;
+ }
+ ast::MulIntControl::Wide => todo!(),
+ };
+ Ok(())
+}
+
+fn emit_mad_sint(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::MulSInt,
+ arg: &ast::Arg4<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
+ match desc.control {
+ ast::MulIntControl::Low => {
+ let mul_result = builder.i_mul(inst_type, None, arg.src1, arg.src2)?;
+ builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?;
+ }
+ ast::MulIntControl::High => {
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::s_mad_hi as spirv::Word,
[arg.src1, arg.src2, arg.src3],
)?;
}
@@ -1659,7 +1730,7 @@ fn emit_mad_int( fn emit_mad_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- desc: &ast::MulFloatDesc,
+ desc: &ast::ArithFloat,
arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> {
todo!()
@@ -1668,7 +1739,7 @@ fn emit_mad_float( fn emit_add_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- desc: &ast::AddFloatDesc,
+ desc: &ast::ArithFloat,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
if desc.flush_to_zero {
@@ -1680,6 +1751,67 @@ fn emit_add_float( Ok(())
}
+fn emit_sub_float(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ desc: &ast::ArithFloat,
+ arg: &ast::Arg3<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ if desc.flush_to_zero {
+ todo!()
+ }
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
+ builder.f_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
+ emit_rounding_decoration(builder, arg.dst, desc.rounding);
+ Ok(())
+}
+
+fn emit_min(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::MinMaxDetails,
+ arg: &ast::Arg3<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ let cl_op = match desc {
+ ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min,
+ ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
+ ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
+ };
+ let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst),
+ opencl,
+ cl_op as spirv::Word,
+ [arg.src1, arg.src2],
+ )?;
+ Ok(())
+}
+
+fn emit_max(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::MinMaxDetails,
+ arg: &ast::Arg3<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ let cl_op = match desc {
+ ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max,
+ ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
+ ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
+ };
+ let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst),
+ opencl,
+ cl_op as spirv::Word,
+ [arg.src1, arg.src2],
+ )?;
+ Ok(())
+}
+
fn emit_cvt(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -1880,11 +2012,11 @@ fn emit_setp( Ok(())
}
-fn emit_mul_int(
+fn emit_mul_sint(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
- desc: &ast::MulIntDesc,
+ desc: &ast::MulSInt,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let instruction_type = ast::ScalarType::from(desc.typ);
@@ -1894,16 +2026,11 @@ fn emit_mul_int( builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
}
ast::MulIntControl::High => {
- let ocl_mul_hi = if desc.typ.is_signed() {
- spirv::CLOp::s_mul_hi
- } else {
- spirv::CLOp::u_mul_hi
- };
builder.ext_inst(
inst_type,
Some(arg.dst),
opencl,
- ocl_mul_hi as spirv::Word,
+ spirv::CLOp::s_mul_hi as spirv::Word,
[arg.src1, arg.src2],
)?;
}
@@ -1913,11 +2040,54 @@ fn emit_mul_int( SpirvScalarKey::from(instruction_type),
]);
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
- let mul = if desc.typ.is_signed() {
- builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?
- } else {
- builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?
- };
+ let mul = builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
+ let instr_width = instruction_type.width();
+ let instr_kind = instruction_type.kind();
+ let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
+ let dst_type_id = map.get_or_add_scalar(builder, dst_type);
+ struct2_bitcast_to_wide(
+ builder,
+ map,
+ SpirvScalarKey::from(instruction_type),
+ inst_type,
+ arg.dst,
+ dst_type_id,
+ mul,
+ )?;
+ }
+ }
+ Ok(())
+}
+
+fn emit_mul_uint(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::MulUInt,
+ arg: &ast::Arg3<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ let instruction_type = ast::ScalarType::from(desc.typ);
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
+ match desc.control {
+ ast::MulIntControl::Low => {
+ builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
+ }
+ ast::MulIntControl::High => {
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::u_mul_hi as spirv::Word,
+ [arg.src1, arg.src2],
+ )?;
+ }
+ ast::MulIntControl::Wide => {
+ let mul_ext_type = SpirvType::Struct(vec![
+ SpirvScalarKey::from(instruction_type),
+ SpirvScalarKey::from(instruction_type),
+ ]);
+ let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
+ let mul = builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
let instr_width = instruction_type.width();
let instr_kind = instruction_type.kind();
let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
@@ -1981,14 +2151,33 @@ fn emit_abs( fn emit_add_int(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- ctr: &ast::AddIntDesc,
+ typ: ast::ScalarType,
+ saturate: bool,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
- let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(ctr.typ)));
+ if saturate {
+ todo!()
+ }
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)));
builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
Ok(())
}
+fn emit_sub_int(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ typ: ast::ScalarType,
+ saturate: bool,
+ arg: &ast::Arg3<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ if saturate {
+ todo!()
+ }
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)));
+ builder.i_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
+ Ok(())
+}
+
fn emit_implicit_conversion(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -2920,6 +3109,18 @@ impl<T: ArgParamsEx> ast::Instruction<T> { t,
a.map_non_shift(visitor, ast::Type::Scalar(t.into()), false)?,
),
+ ast::Instruction::Sub(d, a) => {
+ let typ = d.get_type();
+ ast::Instruction::Sub(d, a.map_non_shift(visitor, typ, false)?)
+ }
+ ast::Instruction::Min(d, a) => {
+ let typ = d.get_type();
+ ast::Instruction::Min(d, a.map_non_shift(visitor, typ, false)?)
+ }
+ ast::Instruction::Max(d, a) => {
+ let typ = d.get_type();
+ ast::Instruction::Max(d, a.map_non_shift(visitor, typ, false)?)
+ }
})
}
}
@@ -3129,6 +3330,9 @@ impl ast::Instruction<ExpandedArgParams> { | ast::Instruction::Abs(_, _)
| ast::Instruction::Call(_)
| ast::Instruction::Or(_, _)
+ | ast::Instruction::Sub(_, _)
+ | ast::Instruction::Min(_, _)
+ | ast::Instruction::Max(_, _)
| ast::Instruction::Mad(_, _) => None,
}
}
@@ -4049,25 +4253,33 @@ impl ast::ShrType { }
}
-impl ast::AddDetails {
+impl ast::ArithDetails {
fn get_type(&self) -> ast::Type {
- match self {
- ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
- ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => {
- ast::Type::Scalar((*typ).into())
- }
- }
+ ast::Type::Scalar(match self {
+ ast::ArithDetails::Unsigned(t) => (*t).into(),
+ ast::ArithDetails::Signed(d) => d.typ.into(),
+ ast::ArithDetails::Float(d) => d.typ.into(),
+ })
}
}
impl ast::MulDetails {
fn get_type(&self) -> ast::Type {
- match self {
- ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
- ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => {
- ast::Type::Scalar((*typ).into())
- }
- }
+ ast::Type::Scalar(match self {
+ ast::MulDetails::Unsigned(d) => d.typ.into(),
+ ast::MulDetails::Signed(d) => d.typ.into(),
+ ast::MulDetails::Float(d) => d.typ.into(),
+ })
+ }
+}
+
+impl ast::MinMaxDetails {
+ fn get_type(&self) -> ast::Type {
+ ast::Type::Scalar(match self {
+ ast::MinMaxDetails::Signed(t) => (*t).into(),
+ ast::MinMaxDetails::Unsigned(t) => (*t).into(),
+ ast::MinMaxDetails::Float(d) => d.typ.into(),
+ })
}
}
@@ -4085,6 +4297,30 @@ impl ast::IntType { }
}
+impl ast::SIntType {
+ fn from_size(width: u8) -> Self {
+ match width {
+ 1 => ast::SIntType::S8,
+ 2 => ast::SIntType::S16,
+ 4 => ast::SIntType::S32,
+ 8 => ast::SIntType::S64,
+ _ => unreachable!(),
+ }
+ }
+}
+
+impl ast::UIntType {
+ fn from_size(width: u8) -> Self {
+ match width {
+ 1 => ast::UIntType::U8,
+ 2 => ast::UIntType::U16,
+ 4 => ast::UIntType::U32,
+ 8 => ast::UIntType::U64,
+ _ => unreachable!(),
+ }
+ }
+}
+
impl ast::LdStateSpace {
fn to_spirv(self) -> spirv::StorageClass {
match self {
@@ -4128,7 +4364,8 @@ impl<T> ast::OperandOrVector<T> { impl ast::MulDetails {
fn is_wide(&self) -> bool {
match self {
- ast::MulDetails::Int(desc) => desc.control == ast::MulIntControl::Wide,
+ ast::MulDetails::Unsigned(d) => d.control == ast::MulIntControl::Wide,
+ ast::MulDetails::Signed(d) => d.control == ast::MulIntControl::Wide,
ast::MulDetails::Float(_) => false,
}
}
|