summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--ptx/src/ast.rs88
-rw-r--r--ptx/src/ptx.lalrpop169
-rw-r--r--ptx/src/test/spirv_run/max.ptx23
-rw-r--r--ptx/src/test/spirv_run/max.spvtxt57
-rw-r--r--ptx/src/test/spirv_run/min.ptx23
-rw-r--r--ptx/src/test/spirv_run/min.spvtxt57
-rw-r--r--ptx/src/test/spirv_run/mod.rs3
-rw-r--r--ptx/src/test/spirv_run/or.ptx23
-rw-r--r--ptx/src/test/spirv_run/or.spvtxt58
-rw-r--r--ptx/src/test/spirv_run/sub.ptx22
-rw-r--r--ptx/src/test/spirv_run/sub.spvtxt49
-rw-r--r--ptx/src/translate.rs427
12 files changed, 819 insertions, 180 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 8c64ebf..048d43a 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -241,6 +241,10 @@ sub_scalar_type!(IntType {
S64
});
+sub_scalar_type!(UIntType { U8, U16, U32, U64 });
+
+sub_scalar_type!(SIntType { S8, S16, S32, S64 });
+
impl IntType {
pub fn is_signed(self) -> bool {
match self {
@@ -331,7 +335,7 @@ pub enum Instruction<P: ArgParams> {
Ld(LdDetails, Arg2Ld<P>),
Mov(MovDetails, Arg2Mov<P>),
Mul(MulDetails, Arg3<P>),
- Add(AddDetails, Arg3<P>),
+ Add(ArithDetails, Arg3<P>),
Setp(SetpData, Arg4Setp<P>),
SetpBool(SetpBoolData, Arg5<P>),
Not(NotType, Arg2<P>),
@@ -346,6 +350,9 @@ pub enum Instruction<P: ArgParams> {
Abs(AbsDetails, Arg2<P>),
Mad(MulDetails, Arg4<P>),
Or(OrType, Arg3<P>),
+ Sub(ArithDetails, Arg3<P>),
+ Min(MinMaxDetails, Arg3<P>),
+ Max(MinMaxDetails, Arg3<P>),
}
#[derive(Copy, Clone)]
@@ -554,11 +561,6 @@ impl MovDetails {
}
}
-pub enum MulDetails {
- Int(MulIntDesc),
- Float(MulFloatDesc),
-}
-
#[derive(Copy, Clone)]
pub struct MulIntDesc {
pub typ: IntType,
@@ -572,14 +574,6 @@ pub enum MulIntControl {
Wide,
}
-#[derive(Copy, Clone)]
-pub struct MulFloatDesc {
- pub typ: FloatType,
- pub rounding: Option<RoundingMode>,
- pub flush_to_zero: bool,
- pub saturate: bool,
-}
-
#[derive(PartialEq, Eq, Copy, Clone)]
pub enum RoundingMode {
NearestEven,
@@ -588,23 +582,11 @@ pub enum RoundingMode {
PositiveInf,
}
-pub enum AddDetails {
- Int(AddIntDesc),
- Float(AddFloatDesc),
-}
-
pub struct AddIntDesc {
pub typ: IntType,
pub saturate: bool,
}
-pub struct AddFloatDesc {
- pub typ: FloatType,
- pub rounding: Option<RoundingMode>,
- pub flush_to_zero: bool,
- pub saturate: bool,
-}
-
pub struct SetpData {
pub typ: ScalarType,
pub flush_to_zero: bool,
@@ -810,3 +792,57 @@ sub_scalar_type!(OrType {
B32,
B64,
});
+
+#[derive(Copy, Clone)]
+pub enum MulDetails {
+ Unsigned(MulUInt),
+ Signed(MulSInt),
+ Float(ArithFloat),
+}
+
+#[derive(Copy, Clone)]
+pub struct MulUInt {
+ pub typ: UIntType,
+ pub control: MulIntControl,
+}
+
+#[derive(Copy, Clone)]
+pub struct MulSInt {
+ pub typ: SIntType,
+ pub control: MulIntControl,
+}
+
+#[derive(Copy, Clone)]
+pub enum ArithDetails {
+ Unsigned(UIntType),
+ Signed(ArithSInt),
+ Float(ArithFloat),
+}
+
+#[derive(Copy, Clone)]
+pub struct ArithSInt {
+ pub typ: SIntType,
+ pub saturate: bool,
+}
+
+#[derive(Copy, Clone)]
+pub struct ArithFloat {
+ pub typ: FloatType,
+ pub rounding: Option<RoundingMode>,
+ pub flush_to_zero: bool,
+ pub saturate: bool,
+}
+
+#[derive(Copy, Clone)]
+pub enum MinMaxDetails {
+ Signed(SIntType),
+ Unsigned(UIntType),
+ Float(MinMaxFloat),
+}
+
+#[derive(Copy, Clone)]
+pub struct MinMaxFloat {
+ pub ftz: bool,
+ pub nan: bool,
+ pub typ: FloatType,
+}
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index d2d5be8..2c0e365 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -70,6 +70,7 @@ match {
".ltu",
".lu",
".nan",
+ ".NaN",
".ne",
".neu",
".num",
@@ -124,6 +125,8 @@ match {
"ld",
"mad",
"map_f64_to_f32",
+ "max",
+ "min",
"mov",
"mul",
"not",
@@ -134,6 +137,7 @@ match {
"shr",
r"sm_[0-9]+" => ShaderModel,
"st",
+ "sub",
"texmode_independent",
"texmode_unified",
} else {
@@ -153,6 +157,8 @@ ExtendedID : &'input str = {
"ld",
"mad",
"map_f64_to_f32",
+ "max",
+ "min",
"mov",
"mul",
"not",
@@ -163,6 +169,7 @@ ExtendedID : &'input str = {
"shr",
ShaderModel,
"st",
+ "sub",
"texmode_independent",
"texmode_unified",
ID
@@ -448,7 +455,10 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstCall,
InstAbs,
InstMad,
- InstOr
+ InstOr,
+ InstSub,
+ InstMin,
+ InstMax,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@@ -570,38 +580,19 @@ MovVectorType: ast::ScalarType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
InstMul: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "mul" <d:InstMulMode> <a:Arg3> => ast::Instruction::Mul(d, a)
+ "mul" <d:MulDetails> <a:Arg3> => ast::Instruction::Mul(d, a)
};
-InstMulMode: ast::MulDetails = {
- <ctr:MulIntControl> <t:IntType> => ast::MulDetails::Int(ast::MulIntDesc {
+MulDetails: ast::MulDetails = {
+ <ctr:MulIntControl> <t:UIntType> => ast::MulDetails::Unsigned(ast::MulUInt{
typ: t,
control: ctr
}),
- <r:RoundingModeFloat?> <ftz:".ftz"?> <s:".sat"?> ".f32" => ast::MulDetails::Float(ast::MulFloatDesc {
- typ: ast::FloatType::F32,
- rounding: r,
- flush_to_zero: ftz.is_some(),
- saturate: s.is_some()
- }),
- <r:RoundingModeFloat?> ".f64" => ast::MulDetails::Float(ast::MulFloatDesc {
- typ: ast::FloatType::F64,
- rounding: r,
- flush_to_zero: false,
- saturate: false
- }),
- <r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16" => ast::MulDetails::Float(ast::MulFloatDesc {
- typ: ast::FloatType::F16,
- rounding: r.map(|_| ast::RoundingMode::NearestEven),
- flush_to_zero: ftz.is_some(),
- saturate: s.is_some()
+ <ctr:MulIntControl> <t:SIntType> => ast::MulDetails::Signed(ast::MulSInt{
+ typ: t,
+ control: ctr
}),
- <r:".rn"?> <ftz:".ftz"?> <s:".sat"?> ".f16x2" => ast::MulDetails::Float(ast::MulFloatDesc {
- typ: ast::FloatType::F16x2,
- rounding: r.map(|_| ast::RoundingMode::NearestEven),
- flush_to_zero: ftz.is_some(),
- saturate: s.is_some()
- })
+ <f:ArithFloat> => ast::MulDetails::Float(f)
};
MulIntControl: ast::MulIntControl = {
@@ -634,41 +625,23 @@ IntType : ast::IntType = {
".s64" => ast::IntType::S64,
};
+UIntType: ast::UIntType = {
+ ".u16" => ast::UIntType::U16,
+ ".u32" => ast::UIntType::U32,
+ ".u64" => ast::UIntType::U64,
+};
+
+SIntType: ast::SIntType = {
+ ".s16" => ast::SIntType::S16,
+ ".s32" => ast::SIntType::S32,
+ ".s64" => ast::SIntType::S64,
+};
+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-add
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-add
InstAdd: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "add" <d:InstAddMode> <a:Arg3> => ast::Instruction::Add(d, a)
-};
-
-InstAddMode: ast::AddDetails = {
- <t:IntType> => ast::AddDetails::Int(ast::AddIntDesc {
- typ: t,
- saturate: false,
- }),
- ".sat" ".s32" => ast::AddDetails::Int(ast::AddIntDesc {
- typ: ast::IntType::S32,
- saturate: true,
- }),
- <rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::AddDetails::Float(ast::AddFloatDesc {
- typ: ast::FloatType::F32,
- rounding: rn,
- flush_to_zero: ftz.is_some(),
- saturate: sat.is_some(),
- }),
- <rn:RoundingModeFloat?> ".f64" => ast::AddDetails::Float(ast::AddFloatDesc {
- typ: ast::FloatType::F64,
- rounding: rn,
- flush_to_zero: false,
- saturate: false,
- }),
- <rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?>".f16" => ast::AddDetails::Float(ast::AddFloatDesc {
- typ: ast::FloatType::F16,
- rounding: rn.map(|_| ast::RoundingMode::NearestEven),
- flush_to_zero: ftz.is_some(),
- saturate: sat.is_some(),
- }),
- ".rn"? ".ftz"? ".sat"? ".f16x2" => todo!()
+ "add" <d:ArithDetails> <a:Arg3> => ast::Instruction::Add(d, a)
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
@@ -1041,7 +1014,7 @@ InstAbs: ast::Instruction<ast::ParsedArgParams<'input>> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad
InstMad: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "mad" <d:InstMulMode> <a:Arg4> => ast::Instruction::Mad(d, a),
+ "mad" <d:MulDetails> <a:Arg4> => ast::Instruction::Mad(d, a),
"mad" ".hi" ".sat" ".s32" => todo!()
};
@@ -1063,6 +1036,84 @@ OrType: ast::OrType = {
".b64" => ast::OrType::B64,
}
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub
+InstSub: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "sub" <d:ArithDetails> <a:Arg3> => ast::Instruction::Sub(d, a),
+};
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-min
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-min
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-min
+InstMin: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "min" <d:MinMaxDetails> <a:Arg3> => ast::Instruction::Min(d, a),
+};
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-max
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-max
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-max
+InstMax: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "max" <d:MinMaxDetails> <a:Arg3> => ast::Instruction::Max(d, a),
+};
+
+MinMaxDetails: ast::MinMaxDetails = {
+ <t:UIntType> => ast::MinMaxDetails::Unsigned(t),
+ <t:SIntType> => ast::MinMaxDetails::Signed(t),
+ <ftz:".ftz"?> <nan:".NaN"?> ".f32" => ast::MinMaxDetails::Float(
+ ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F32 }
+ ),
+ ".f64" => ast::MinMaxDetails::Float(
+ ast::MinMaxFloat{ ftz: false, nan: false, typ: ast::FloatType::F64 }
+ ),
+ <ftz:".ftz"?> <nan:".NaN"?> ".f16" => ast::MinMaxDetails::Float(
+ ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16 }
+ ),
+ <ftz:".ftz"?> <nan:".NaN"?> ".f16x2" => ast::MinMaxDetails::Float(
+ ast::MinMaxFloat{ ftz: ftz.is_some(), nan: nan.is_some(), typ: ast::FloatType::F16x2 }
+ )
+}
+
+ArithDetails: ast::ArithDetails = {
+ <t:UIntType> => ast::ArithDetails::Unsigned(t),
+ <t:SIntType> => ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: t,
+ saturate: false,
+ }),
+ ".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::S32,
+ saturate: true,
+ }),
+ <f:ArithFloat> => ast::ArithDetails::Float(f)
+}
+
+ArithFloat: ast::ArithFloat = {
+ <rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
+ typ: ast::FloatType::F32,
+ rounding: rn,
+ flush_to_zero: ftz.is_some(),
+ saturate: sat.is_some(),
+ },
+ <rn:RoundingModeFloat?> ".f64" => ast::ArithFloat {
+ typ: ast::FloatType::F64,
+ rounding: rn,
+ flush_to_zero: false,
+ saturate: false,
+ },
+ <rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
+ typ: ast::FloatType::F16,
+ rounding: rn.map(|_| ast::RoundingMode::NearestEven),
+ flush_to_zero: ftz.is_some(),
+ saturate: sat.is_some(),
+ },
+ <rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
+ typ: ast::FloatType::F16x2,
+ rounding: rn.map(|_| ast::RoundingMode::NearestEven),
+ flush_to_zero: ftz.is_some(),
+ saturate: sat.is_some(),
+ },
+}
+
Operand: ast::Operand<&'input str> = {
<r:ExtendedID> => ast::Operand::Reg(r),
<r:ExtendedID> "+" <o:Num> => {
diff --git a/ptx/src/test/spirv_run/max.ptx b/ptx/src/test/spirv_run/max.ptx
new file mode 100644
index 0000000..8c72fe2
--- /dev/null
+++ b/ptx/src/test/spirv_run/max.ptx
@@ -0,0 +1,23 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry max(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .s32 temp1;
+ .reg .s32 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.s32 temp1, [in_addr];
+ ld.s32 temp2, [in_addr+4];
+ max.s32 temp1, temp1, temp2;
+ st.s32 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/max.spvtxt b/ptx/src/test/spirv_run/max.spvtxt
new file mode 100644
index 0000000..cab9a9a
--- /dev/null
+++ b/ptx/src/test/spirv_run/max.spvtxt
@@ -0,0 +1,57 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %30 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "max"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %33 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %ulong_4 = OpConstant %ulong 4
+ %1 = OpFunction %void None %33
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %28 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ %7 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %11 = OpLoad %ulong %2
+ %10 = OpCopyObject %ulong %11
+ OpStore %4 %10
+ %13 = OpLoad %ulong %3
+ %12 = OpCopyObject %ulong %13
+ OpStore %5 %12
+ %15 = OpLoad %ulong %4
+ %25 = OpConvertUToPtr %_ptr_Generic_uint %15
+ %14 = OpLoad %uint %25
+ OpStore %6 %14
+ %17 = OpLoad %ulong %4
+ %24 = OpIAdd %ulong %17 %ulong_4
+ %26 = OpConvertUToPtr %_ptr_Generic_uint %24
+ %16 = OpLoad %uint %26
+ OpStore %7 %16
+ %19 = OpLoad %uint %6
+ %20 = OpLoad %uint %7
+ %18 = OpExtInst %uint %30 s_max %19 %20
+ OpStore %6 %18
+ %21 = OpLoad %ulong %5
+ %22 = OpLoad %uint %6
+ %27 = OpConvertUToPtr %_ptr_Generic_uint %21
+ OpStore %27 %22
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/min.ptx b/ptx/src/test/spirv_run/min.ptx
new file mode 100644
index 0000000..0311cdb
--- /dev/null
+++ b/ptx/src/test/spirv_run/min.ptx
@@ -0,0 +1,23 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry min(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .s32 temp1;
+ .reg .s32 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.s32 temp1, [in_addr];
+ ld.s32 temp2, [in_addr+4];
+ min.s32 temp1, temp1, temp2;
+ st.s32 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/min.spvtxt b/ptx/src/test/spirv_run/min.spvtxt
new file mode 100644
index 0000000..119cd15
--- /dev/null
+++ b/ptx/src/test/spirv_run/min.spvtxt
@@ -0,0 +1,57 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %30 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "min"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %33 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %ulong_4 = OpConstant %ulong 4
+ %1 = OpFunction %void None %33
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %28 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ %7 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %11 = OpLoad %ulong %2
+ %10 = OpCopyObject %ulong %11
+ OpStore %4 %10
+ %13 = OpLoad %ulong %3
+ %12 = OpCopyObject %ulong %13
+ OpStore %5 %12
+ %15 = OpLoad %ulong %4
+ %25 = OpConvertUToPtr %_ptr_Generic_uint %15
+ %14 = OpLoad %uint %25
+ OpStore %6 %14
+ %17 = OpLoad %ulong %4
+ %24 = OpIAdd %ulong %17 %ulong_4
+ %26 = OpConvertUToPtr %_ptr_Generic_uint %24
+ %16 = OpLoad %uint %26
+ OpStore %7 %16
+ %19 = OpLoad %uint %6
+ %20 = OpLoad %uint %7
+ %18 = OpExtInst %uint %30 s_min %19 %20
+ OpStore %6 %18
+ %21 = OpLoad %ulong %5
+ %22 = OpLoad %uint %6
+ %27 = OpConvertUToPtr %_ptr_Generic_uint %21
+ OpStore %27 %22
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 99785a6..8caf540 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -70,6 +70,9 @@ test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64])
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
test_ptx!(shr, [-2i32], [-1i32]);
test_ptx!(or, [1u64, 2u64], [3u64]);
+test_ptx!(sub, [2u64], [1u64]);
+test_ptx!(min, [555i32, 444i32], [444i32]);
+test_ptx!(max, [555i32, 444i32], [555i32]);
struct DisplayError<T: Debug> {
diff --git a/ptx/src/test/spirv_run/or.ptx b/ptx/src/test/spirv_run/or.ptx
new file mode 100644
index 0000000..1deb3c8
--- /dev/null
+++ b/ptx/src/test/spirv_run/or.ptx
@@ -0,0 +1,23 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry or(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp1;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u64 temp1, [in_addr];
+ ld.u64 temp2, [in_addr+8];
+ or.b64 temp1, temp1, temp2;
+ st.u64 [out_addr], temp1;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/or.spvtxt b/ptx/src/test/spirv_run/or.spvtxt
new file mode 100644
index 0000000..fbf80c5
--- /dev/null
+++ b/ptx/src/test/spirv_run/or.spvtxt
@@ -0,0 +1,58 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %33 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "or"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %36 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %ulong_8 = OpConstant %ulong 8
+ %1 = OpFunction %void None %36
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %31 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %11 = OpLoad %ulong %2
+ %10 = OpCopyObject %ulong %11
+ OpStore %4 %10
+ %13 = OpLoad %ulong %3
+ %12 = OpCopyObject %ulong %13
+ OpStore %5 %12
+ %15 = OpLoad %ulong %4
+ %25 = OpConvertUToPtr %_ptr_Generic_ulong %15
+ %14 = OpLoad %ulong %25
+ OpStore %6 %14
+ %17 = OpLoad %ulong %4
+ %24 = OpIAdd %ulong %17 %ulong_8
+ %26 = OpConvertUToPtr %_ptr_Generic_ulong %24
+ %16 = OpLoad %ulong %26
+ OpStore %7 %16
+ %19 = OpLoad %ulong %6
+ %20 = OpLoad %ulong %7
+ %28 = OpCopyObject %ulong %19
+ %29 = OpCopyObject %ulong %20
+ %27 = OpBitwiseOr %ulong %28 %29
+ %18 = OpCopyObject %ulong %27
+ OpStore %6 %18
+ %21 = OpLoad %ulong %5
+ %22 = OpLoad %ulong %6
+ %30 = OpConvertUToPtr %_ptr_Generic_ulong %21
+ OpStore %30 %22
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/sub.ptx b/ptx/src/test/spirv_run/sub.ptx
new file mode 100644
index 0000000..6cce9dc
--- /dev/null
+++ b/ptx/src/test/spirv_run/sub.ptx
@@ -0,0 +1,22 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry sub(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u64 temp, [in_addr];
+ sub.u64 temp2, temp, 1;
+ st.u64 [out_addr], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/sub.spvtxt b/ptx/src/test/spirv_run/sub.spvtxt
new file mode 100644
index 0000000..8520168
--- /dev/null
+++ b/ptx/src/test/spirv_run/sub.spvtxt
@@ -0,0 +1,49 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %25 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "sub"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %28 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %ulong_1 = OpConstant %ulong 1
+ %1 = OpFunction %void None %28
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %23 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %11 = OpLoad %ulong %2
+ %10 = OpCopyObject %ulong %11
+ OpStore %4 %10
+ %13 = OpLoad %ulong %3
+ %12 = OpCopyObject %ulong %13
+ OpStore %5 %12
+ %15 = OpLoad %ulong %4
+ %21 = OpConvertUToPtr %_ptr_Generic_ulong %15
+ %14 = OpLoad %ulong %21
+ OpStore %6 %14
+ %17 = OpLoad %ulong %6
+ %16 = OpISub %ulong %17 %ulong_1
+ OpStore %7 %16
+ %18 = OpLoad %ulong %5
+ %19 = OpLoad %ulong %7
+ %22 = OpConvertUToPtr %_ptr_Generic_ulong %18
+ OpStore %22 %19
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index fb1b843..7c15744 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -595,6 +595,15 @@ fn convert_to_typed_statements(
ast::Instruction::Or(d, a) => {
result.push(Statement::Instruction(ast::Instruction::Or(d, a.cast())))
}
+ ast::Instruction::Sub(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Sub(d, a.cast())))
+ }
+ ast::Instruction::Min(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Min(d, a.cast())))
+ }
+ ast::Instruction::Max(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Max(d, a.cast())))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -968,62 +977,74 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn reg_offset(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, i32)>,
- typ: ast::Type,
+ mut typ: ast::Type,
) -> Result<spirv::Word, TranslateError> {
let (reg, offset) = desc.op;
match desc.sema {
- ArgumentSemantics::Default | ArgumentSemantics::DefaultRelaxed => {
- let scalar_t = if let ast::Type::Scalar(scalar) = typ {
- scalar
+ ArgumentSemantics::Default
+ | ArgumentSemantics::DefaultRelaxed
+ | ArgumentSemantics::PhysicalPointer => {
+ if desc.sema == ArgumentSemantics::PhysicalPointer {
+ typ = ast::Type::Scalar(ast::ScalarType::U64);
+ }
+ let (width, kind) = match typ {
+ ast::Type::Scalar(scalar_t) => {
+ let kind = match scalar_t.kind() {
+ kind @ ScalarKind::Bit
+ | kind @ ScalarKind::Unsigned
+ | kind @ ScalarKind::Signed => kind,
+ ScalarKind::Float => return Err(TranslateError::MismatchedType),
+ ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
+ ScalarKind::Pred => return Err(TranslateError::MismatchedType),
+ };
+ (scalar_t.width(), kind)
+ }
+ _ => return Err(TranslateError::MismatchedType),
+ };
+ let arith_detail = if kind == ScalarKind::Signed {
+ ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::from_size(width),
+ saturate: false,
+ })
} else {
- todo!()
+ ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
};
- let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
+ let id_constant_stmt = self.id_def.new_id(typ);
let result_id = self.id_def.new_id(typ);
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: scalar_t,
- value: offset as i64,
- }));
- let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- ast::AddDetails::Int(ast::AddIntDesc {
- typ: int_type,
- saturate: false,
- }),
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
- Ok(result_id)
- }
- ArgumentSemantics::PhysicalPointer => {
- let scalar_t = ast::ScalarType::U64;
- let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
- let result_id = self.id_def.new_id(ast::Type::Scalar(scalar_t));
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: scalar_t,
- value: offset as i64,
- }));
- let int_type = ast::IntType::U64;
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- ast::AddDetails::Int(ast::AddIntDesc {
- typ: int_type,
- saturate: false,
- }),
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
+ // TODO: check for edge cases around min value/max value/wrapping
+ if offset < 0 && kind != ScalarKind::Signed {
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: ast::ScalarType::from_parts(width, kind),
+ value: -(offset as i64),
+ }));
+ self.func.push(Statement::Instruction(
+ ast::Instruction::<ExpandedArgParams>::Sub(
+ arith_detail,
+ ast::Arg3 {
+ dst: result_id,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ ),
+ ));
+ } else {
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: ast::ScalarType::from_parts(width, kind),
+ value: offset as i64,
+ }));
+ self.func.push(Statement::Instruction(
+ ast::Instruction::<ExpandedArgParams>::Add(
+ arith_detail,
+ ast::Arg3 {
+ dst: result_id,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ ),
+ ));
+ }
Ok(result_id)
}
ArgumentSemantics::RegisterPointer => {
@@ -1522,14 +1543,22 @@ fn emit_function_body_ops(
}
},
ast::Instruction::Mul(mul, arg) => match mul {
- ast::MulDetails::Int(ref ctr) => {
- emit_mul_int(builder, map, opencl, ctr, arg)?;
+ ast::MulDetails::Signed(ref ctr) => {
+ emit_mul_sint(builder, map, opencl, ctr, arg)?
+ }
+ ast::MulDetails::Unsigned(ref ctr) => {
+ emit_mul_uint(builder, map, opencl, ctr, arg)?
}
ast::MulDetails::Float(_) => todo!(),
},
ast::Instruction::Add(add, arg) => match add {
- ast::AddDetails::Int(ref desc) => emit_add_int(builder, map, desc, arg)?,
- ast::AddDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?,
+ ast::ArithDetails::Signed(ref desc) => {
+ emit_add_int(builder, map, desc.typ.into(), desc.saturate, arg)?
+ }
+ ast::ArithDetails::Unsigned(ref desc) => {
+ emit_add_int(builder, map, (*desc).into(), false, arg)?
+ }
+ ast::ArithDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?,
},
ast::Instruction::Setp(setp, arg) => {
if arg.dst2.is_some() {
@@ -1581,8 +1610,11 @@ fn emit_function_body_ops(
}
ast::Instruction::SetpBool(_, _) => todo!(),
ast::Instruction::Mad(mad, arg) => match mad {
- ast::MulDetails::Int(ref desc) => {
- emit_mad_int(builder, map, opencl, desc, arg)?
+ ast::MulDetails::Signed(ref desc) => {
+ emit_mad_sint(builder, map, opencl, desc, arg)?
+ }
+ ast::MulDetails::Unsigned(ref desc) => {
+ emit_mad_uint(builder, map, opencl, desc, arg)?
}
ast::MulDetails::Float(desc) => emit_mad_float(builder, map, desc, arg)?,
},
@@ -1594,6 +1626,23 @@ fn emit_function_body_ops(
builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
}
}
+ ast::Instruction::Sub(d, arg) => match d {
+ ast::ArithDetails::Signed(desc) => {
+ emit_sub_int(builder, map, desc.typ.into(), desc.saturate, arg)?;
+ }
+ ast::ArithDetails::Unsigned(desc) => {
+ emit_sub_int(builder, map, (*desc).into(), false, arg)?;
+ }
+ ast::ArithDetails::Float(desc) => {
+ emit_sub_float(builder, map, desc, arg)?;
+ }
+ },
+ ast::Instruction::Min(d, a) => {
+ emit_min(builder, map, opencl, d, a)?;
+ }
+ ast::Instruction::Max(d, a) => {
+ emit_max(builder, map, opencl, d, a)?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(*typ));
@@ -1624,11 +1673,11 @@ fn emit_function_body_ops(
Ok(())
}
-fn emit_mad_int(
+fn emit_mad_uint(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
- desc: &ast::MulIntDesc,
+ desc: &ast::MulUInt,
arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
@@ -1638,16 +1687,38 @@ fn emit_mad_int(
builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?;
}
ast::MulIntControl::High => {
- let cl_op = if desc.typ.is_signed() {
- spirv::CLOp::s_mad_hi
- } else {
- spirv::CLOp::u_mad_hi
- };
builder.ext_inst(
inst_type,
Some(arg.dst),
opencl,
- cl_op as spirv::Word,
+ spirv::CLOp::u_mad_hi as spirv::Word,
+ [arg.src1, arg.src2, arg.src3],
+ )?;
+ }
+ ast::MulIntControl::Wide => todo!(),
+ };
+ Ok(())
+}
+
+fn emit_mad_sint(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::MulSInt,
+ arg: &ast::Arg4<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
+ match desc.control {
+ ast::MulIntControl::Low => {
+ let mul_result = builder.i_mul(inst_type, None, arg.src1, arg.src2)?;
+ builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?;
+ }
+ ast::MulIntControl::High => {
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::s_mad_hi as spirv::Word,
[arg.src1, arg.src2, arg.src3],
)?;
}
@@ -1659,7 +1730,7 @@ fn emit_mad_int(
fn emit_mad_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- desc: &ast::MulFloatDesc,
+ desc: &ast::ArithFloat,
arg: &ast::Arg4<ExpandedArgParams>,
) -> Result<(), dr::Error> {
todo!()
@@ -1668,7 +1739,7 @@ fn emit_mad_float(
fn emit_add_float(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- desc: &ast::AddFloatDesc,
+ desc: &ast::ArithFloat,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
if desc.flush_to_zero {
@@ -1680,6 +1751,67 @@ fn emit_add_float(
Ok(())
}
+fn emit_sub_float(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ desc: &ast::ArithFloat,
+ arg: &ast::Arg3<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ if desc.flush_to_zero {
+ todo!()
+ }
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
+ builder.f_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
+ emit_rounding_decoration(builder, arg.dst, desc.rounding);
+ Ok(())
+}
+
+fn emit_min(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::MinMaxDetails,
+ arg: &ast::Arg3<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ let cl_op = match desc {
+ ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min,
+ ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
+ ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
+ };
+ let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst),
+ opencl,
+ cl_op as spirv::Word,
+ [arg.src1, arg.src2],
+ )?;
+ Ok(())
+}
+
+fn emit_max(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::MinMaxDetails,
+ arg: &ast::Arg3<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ let cl_op = match desc {
+ ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max,
+ ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
+ ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
+ };
+ let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst),
+ opencl,
+ cl_op as spirv::Word,
+ [arg.src1, arg.src2],
+ )?;
+ Ok(())
+}
+
fn emit_cvt(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -1880,11 +2012,11 @@ fn emit_setp(
Ok(())
}
-fn emit_mul_int(
+fn emit_mul_sint(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
opencl: spirv::Word,
- desc: &ast::MulIntDesc,
+ desc: &ast::MulSInt,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
let instruction_type = ast::ScalarType::from(desc.typ);
@@ -1894,16 +2026,11 @@ fn emit_mul_int(
builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
}
ast::MulIntControl::High => {
- let ocl_mul_hi = if desc.typ.is_signed() {
- spirv::CLOp::s_mul_hi
- } else {
- spirv::CLOp::u_mul_hi
- };
builder.ext_inst(
inst_type,
Some(arg.dst),
opencl,
- ocl_mul_hi as spirv::Word,
+ spirv::CLOp::s_mul_hi as spirv::Word,
[arg.src1, arg.src2],
)?;
}
@@ -1913,11 +2040,54 @@ fn emit_mul_int(
SpirvScalarKey::from(instruction_type),
]);
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
- let mul = if desc.typ.is_signed() {
- builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?
- } else {
- builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?
- };
+ let mul = builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
+ let instr_width = instruction_type.width();
+ let instr_kind = instruction_type.kind();
+ let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
+ let dst_type_id = map.get_or_add_scalar(builder, dst_type);
+ struct2_bitcast_to_wide(
+ builder,
+ map,
+ SpirvScalarKey::from(instruction_type),
+ inst_type,
+ arg.dst,
+ dst_type_id,
+ mul,
+ )?;
+ }
+ }
+ Ok(())
+}
+
+fn emit_mul_uint(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ desc: &ast::MulUInt,
+ arg: &ast::Arg3<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ let instruction_type = ast::ScalarType::from(desc.typ);
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
+ match desc.control {
+ ast::MulIntControl::Low => {
+ builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
+ }
+ ast::MulIntControl::High => {
+ builder.ext_inst(
+ inst_type,
+ Some(arg.dst),
+ opencl,
+ spirv::CLOp::u_mul_hi as spirv::Word,
+ [arg.src1, arg.src2],
+ )?;
+ }
+ ast::MulIntControl::Wide => {
+ let mul_ext_type = SpirvType::Struct(vec![
+ SpirvScalarKey::from(instruction_type),
+ SpirvScalarKey::from(instruction_type),
+ ]);
+ let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
+ let mul = builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
let instr_width = instruction_type.width();
let instr_kind = instruction_type.kind();
let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
@@ -1981,14 +2151,33 @@ fn emit_abs(
fn emit_add_int(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- ctr: &ast::AddIntDesc,
+ typ: ast::ScalarType,
+ saturate: bool,
arg: &ast::Arg3<ExpandedArgParams>,
) -> Result<(), dr::Error> {
- let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(ctr.typ)));
+ if saturate {
+ todo!()
+ }
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)));
builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
Ok(())
}
+fn emit_sub_int(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ typ: ast::ScalarType,
+ saturate: bool,
+ arg: &ast::Arg3<ExpandedArgParams>,
+) -> Result<(), dr::Error> {
+ if saturate {
+ todo!()
+ }
+ let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)));
+ builder.i_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
+ Ok(())
+}
+
fn emit_implicit_conversion(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -2920,6 +3109,18 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
t,
a.map_non_shift(visitor, ast::Type::Scalar(t.into()), false)?,
),
+ ast::Instruction::Sub(d, a) => {
+ let typ = d.get_type();
+ ast::Instruction::Sub(d, a.map_non_shift(visitor, typ, false)?)
+ }
+ ast::Instruction::Min(d, a) => {
+ let typ = d.get_type();
+ ast::Instruction::Min(d, a.map_non_shift(visitor, typ, false)?)
+ }
+ ast::Instruction::Max(d, a) => {
+ let typ = d.get_type();
+ ast::Instruction::Max(d, a.map_non_shift(visitor, typ, false)?)
+ }
})
}
}
@@ -3129,6 +3330,9 @@ impl ast::Instruction<ExpandedArgParams> {
| ast::Instruction::Abs(_, _)
| ast::Instruction::Call(_)
| ast::Instruction::Or(_, _)
+ | ast::Instruction::Sub(_, _)
+ | ast::Instruction::Min(_, _)
+ | ast::Instruction::Max(_, _)
| ast::Instruction::Mad(_, _) => None,
}
}
@@ -4049,25 +4253,33 @@ impl ast::ShrType {
}
}
-impl ast::AddDetails {
+impl ast::ArithDetails {
fn get_type(&self) -> ast::Type {
- match self {
- ast::AddDetails::Int(ast::AddIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
- ast::AddDetails::Float(ast::AddFloatDesc { typ, .. }) => {
- ast::Type::Scalar((*typ).into())
- }
- }
+ ast::Type::Scalar(match self {
+ ast::ArithDetails::Unsigned(t) => (*t).into(),
+ ast::ArithDetails::Signed(d) => d.typ.into(),
+ ast::ArithDetails::Float(d) => d.typ.into(),
+ })
}
}
impl ast::MulDetails {
fn get_type(&self) -> ast::Type {
- match self {
- ast::MulDetails::Int(ast::MulIntDesc { typ, .. }) => ast::Type::Scalar((*typ).into()),
- ast::MulDetails::Float(ast::MulFloatDesc { typ, .. }) => {
- ast::Type::Scalar((*typ).into())
- }
- }
+ ast::Type::Scalar(match self {
+ ast::MulDetails::Unsigned(d) => d.typ.into(),
+ ast::MulDetails::Signed(d) => d.typ.into(),
+ ast::MulDetails::Float(d) => d.typ.into(),
+ })
+ }
+}
+
+impl ast::MinMaxDetails {
+ fn get_type(&self) -> ast::Type {
+ ast::Type::Scalar(match self {
+ ast::MinMaxDetails::Signed(t) => (*t).into(),
+ ast::MinMaxDetails::Unsigned(t) => (*t).into(),
+ ast::MinMaxDetails::Float(d) => d.typ.into(),
+ })
}
}
@@ -4085,6 +4297,30 @@ impl ast::IntType {
}
}
+impl ast::SIntType {
+ fn from_size(width: u8) -> Self {
+ match width {
+ 1 => ast::SIntType::S8,
+ 2 => ast::SIntType::S16,
+ 4 => ast::SIntType::S32,
+ 8 => ast::SIntType::S64,
+ _ => unreachable!(),
+ }
+ }
+}
+
+impl ast::UIntType {
+ fn from_size(width: u8) -> Self {
+ match width {
+ 1 => ast::UIntType::U8,
+ 2 => ast::UIntType::U16,
+ 4 => ast::UIntType::U32,
+ 8 => ast::UIntType::U64,
+ _ => unreachable!(),
+ }
+ }
+}
+
impl ast::LdStateSpace {
fn to_spirv(self) -> spirv::StorageClass {
match self {
@@ -4128,7 +4364,8 @@ impl<T> ast::OperandOrVector<T> {
impl ast::MulDetails {
fn is_wide(&self) -> bool {
match self {
- ast::MulDetails::Int(desc) => desc.control == ast::MulIntControl::Wide,
+ ast::MulDetails::Unsigned(d) => d.control == ast::MulIntControl::Wide,
+ ast::MulDetails::Signed(d) => d.control == ast::MulIntControl::Wide,
ast::MulDetails::Float(_) => false,
}
}