aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-09-16 01:25:09 +0200
committerAndrzej Janik <[email protected]>2021-09-16 01:25:09 +0200
commitca0d8ec666e499ec1a71132757acba407c3ba53b (patch)
tree161817205dae53c71ab6985b7f5c888cbf796c8a
parent467782b1d00da5f519840435aa417163fcb1a128 (diff)
downloadZLUDA-ca0d8ec666e499ec1a71132757acba407c3ba53b.tar.gz
ZLUDA-ca0d8ec666e499ec1a71132757acba407c3ba53b.zip
Add missing vray instructions
-rw-r--r--ptx/src/ast.rs3
-rw-r--r--ptx/src/ptx.lalrpop43
-rw-r--r--ptx/src/test/spirv_run/activemask.ptx18
-rw-r--r--ptx/src/test/spirv_run/activemask.spvtxt45
-rw-r--r--ptx/src/test/spirv_run/membar.ptx21
-rw-r--r--ptx/src/test/spirv_run/membar.spvtxt49
-rw-r--r--ptx/src/test/spirv_run/mod.rs3
-rw-r--r--ptx/src/test/spirv_run/prmt.ptx23
-rw-r--r--ptx/src/test/spirv_run/prmt.spvtxt67
-rw-r--r--ptx/src/translate.rs131
10 files changed, 399 insertions, 4 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 36e7191..a8309b0 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -287,6 +287,9 @@ pub enum Instruction<P: ArgParams> {
Bfe { typ: ScalarType, arg: Arg4<P> },
Bfi { typ: ScalarType, arg: Arg5<P> },
Rem { typ: ScalarType, arg: Arg3<P> },
+ Prmt { control: u16, arg: Arg3<P> },
+ Activemask { arg: Arg1<P> },
+ Membar { level: MemScope },
}
#[derive(Copy, Clone)]
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 0bc7655..fa3cfec 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -70,6 +70,7 @@ match {
".func",
".ge",
".geu",
+ ".gl",
".global",
".gpu",
".gt",
@@ -142,6 +143,7 @@ match {
} else {
// IF YOU ARE ADDING A NEW TOKEN HERE ALSO ADD IT BELOW TO ExtendedID
"abs",
+ "activemask",
"add",
"and",
"atom",
@@ -165,6 +167,7 @@ match {
"mad",
"map_f64_to_f32",
"max",
+ "membar",
"min",
"mov",
"mul",
@@ -172,6 +175,7 @@ match {
"not",
"or",
"popc",
+ "prmt",
"rcp",
"rem",
"ret",
@@ -196,6 +200,7 @@ match {
ExtendedID : &'input str = {
"abs",
+ "activemask",
"add",
"and",
"atom",
@@ -219,6 +224,7 @@ ExtendedID : &'input str = {
"mad",
"map_f64_to_f32",
"max",
+ "membar",
"min",
"mov",
"mul",
@@ -226,6 +232,7 @@ ExtendedID : &'input str = {
"not",
"or",
"popc",
+ "prmt",
"rcp",
"rem",
"ret",
@@ -292,6 +299,16 @@ U8Num: u8 = {
}
}
+U16Num: u16 = {
+ <x:NumToken> =>? {
+ let (text, radix, _) = x;
+ match u16::from_str_radix(text, radix) {
+ Ok(x) => Ok(x),
+ Err(err) => Err(ParseError::User { error: ast::PtxError::from(err) })
+ }
+ }
+}
+
U32Num: u32 = {
<x:NumToken> =>? {
let (text, radix, _) = x;
@@ -761,6 +778,9 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstRem,
InstBfe,
InstBfi,
+ InstPrmt,
+ InstActivemask,
+ InstMembar,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
@@ -821,6 +841,12 @@ MemScope: ast::MemScope = {
".sys" => ast::MemScope::Sys
};
+MembarLevel: ast::MemScope = {
+ ".cta" => ast::MemScope::Cta,
+ ".gl" => ast::MemScope::Gpu,
+ ".sys" => ast::MemScope::Sys
+};
+
LdNonGlobalStateSpace: ast::StateSpace = {
".const" => ast::StateSpace::Const,
".local" => ast::StateSpace::Local,
@@ -1445,8 +1471,9 @@ SelpType: ast::ScalarType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar
InstBar: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "bar" ".sync" <a:Arg1Bar> => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a),
+ "barrier" ".sync" <a:Arg1Bar> => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a),
"barrier" ".sync" ".aligned" <a:Arg1Bar> => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a),
- "bar" ".sync" <a:Arg1Bar> => ast::Instruction::Bar(ast::BarDetails::SyncAligned, a)
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
@@ -1731,11 +1758,25 @@ InstBfi: ast::Instruction<ast::ParsedArgParams<'input>> = {
"bfi" <typ:BitType> <arg:Arg5> => ast::Instruction::Bfi{ <> }
}
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
+InstPrmt: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "prmt" ".b32" <arg:Arg3> "," <control:U16Num> => ast::Instruction::Prmt{ <> }
+}
+
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-rem
InstRem: ast::Instruction<ast::ParsedArgParams<'input>> = {
"rem" <typ:IntType> <arg:Arg3> => ast::Instruction::Rem{ <> }
}
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-activemask
+InstActivemask: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "activemask" ".b32" <arg:Arg1> => ast::Instruction::Activemask{ <> }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar
+InstMembar: ast::Instruction<ast::ParsedArgParams<'input>> = {
+ "membar" <level:MembarLevel> => ast::Instruction::Membar{ <> }
+}
NegTypeFtz: ast::ScalarType = {
".f16" => ast::ScalarType::F16,
diff --git a/ptx/src/test/spirv_run/activemask.ptx b/ptx/src/test/spirv_run/activemask.ptx
new file mode 100644
index 0000000..c352bb2
--- /dev/null
+++ b/ptx/src/test/spirv_run/activemask.ptx
@@ -0,0 +1,18 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry activemask(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 out_addr;
+ .reg .b32 temp;
+
+ ld.param.u64 out_addr, [output];
+
+ activemask.b32 temp;
+ st.u32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/activemask.spvtxt b/ptx/src/test/spirv_run/activemask.spvtxt
new file mode 100644
index 0000000..c4ad55d
--- /dev/null
+++ b/ptx/src/test/spirv_run/activemask.spvtxt
@@ -0,0 +1,45 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %16 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "activemask"
+ OpExecutionMode %1 ContractionOff
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %19 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+ %v4uint = OpTypeVector %uint 4
+ %bool = OpTypeBool
+ %true = OpConstantTrue %bool
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %1 = OpFunction %void None %19
+ %6 = OpFunctionParameter %ulong
+ %7 = OpFunctionParameter %ulong
+ %14 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %6
+ OpStore %3 %7
+ %8 = OpLoad %ulong %3 Aligned 8
+ OpStore %4 %8
+ %26 = OpSubgroupBallotKHR %v4uint %true
+ %9 = OpCompositeExtract %uint %26 0
+ OpStore %5 %9
+ %10 = OpLoad %ulong %4
+ %11 = OpLoad %uint %5
+ %12 = OpConvertUToPtr %_ptr_Generic_uint %10
+ %13 = OpCopyObject %uint %11
+ OpStore %12 %13 Aligned 4
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/membar.ptx b/ptx/src/test/spirv_run/membar.ptx
new file mode 100644
index 0000000..01aa9f2
--- /dev/null
+++ b/ptx/src/test/spirv_run/membar.ptx
@@ -0,0 +1,21 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry membar(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .s32 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u32 temp, [in_addr];
+ membar.sys;
+ st.s32 [out_addr], temp;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/membar.spvtxt b/ptx/src/test/spirv_run/membar.spvtxt
new file mode 100644
index 0000000..d808cf3
--- /dev/null
+++ b/ptx/src/test/spirv_run/membar.spvtxt
@@ -0,0 +1,49 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %20 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "membar"
+ OpExecutionMode %1 ContractionOff
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %23 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %uint_0 = OpConstant %uint 0
+ %uint_784 = OpConstant %uint 784
+ %1 = OpFunction %void None %23
+ %7 = OpFunctionParameter %ulong
+ %8 = OpFunctionParameter %ulong
+ %18 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %7
+ OpStore %3 %8
+ %9 = OpLoad %ulong %2 Aligned 8
+ OpStore %4 %9
+ %10 = OpLoad %ulong %3 Aligned 8
+ OpStore %5 %10
+ %12 = OpLoad %ulong %4
+ %16 = OpConvertUToPtr %_ptr_Generic_uint %12
+ %15 = OpLoad %uint %16 Aligned 4
+ %11 = OpCopyObject %uint %15
+ OpStore %6 %11
+ OpMemoryBarrier %uint_0 %uint_784
+ %13 = OpLoad %ulong %5
+ %14 = OpLoad %uint %6
+ %17 = OpConvertUToPtr %_ptr_Generic_uint %13
+ OpStore %17 %14 Aligned 4
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 97cfbb5..f6b556e 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -206,6 +206,9 @@ test_ptx!(stateful_neg_offset, [1237518u64], [1237518u64]);
test_ptx!(const, [0u16], [10u16, 20, 30, 40]);
test_ptx!(cvt_s16_s8, [0x139231C2u32], [0xFFFFFFC2u32]);
test_ptx!(cvt_f64_f32, [0.125f32], [0.125f64]);
+test_ptx!(prmt, [0x70c507d6u32, 0x6fbd4b5cu32], [0x6fbdd65cu32]);
+test_ptx!(activemask, [0u32], [1u32]);
+test_ptx!(membar, [152731u32], [152731u32]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/prmt.ptx b/ptx/src/test/spirv_run/prmt.ptx
new file mode 100644
index 0000000..ba339e8
--- /dev/null
+++ b/ptx/src/test/spirv_run/prmt.ptx
@@ -0,0 +1,23 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry prmt(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u32 temp1;
+ .reg .u32 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u32 temp1, [in_addr];
+ ld.u32 temp2, [in_addr+4];
+ prmt.b32 temp2, temp1, temp2, 30212;
+ st.u32 [out_addr], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/prmt.spvtxt b/ptx/src/test/spirv_run/prmt.spvtxt
new file mode 100644
index 0000000..060f534
--- /dev/null
+++ b/ptx/src/test/spirv_run/prmt.spvtxt
@@ -0,0 +1,67 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %31 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "prmt"
+ OpExecutionMode %1 ContractionOff
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %34 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
+ %v4uchar = OpTypeVector %uchar 4
+ %1 = OpFunction %void None %34
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %29 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ %7 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %10 = OpLoad %ulong %2 Aligned 8
+ OpStore %4 %10
+ %11 = OpLoad %ulong %3 Aligned 8
+ OpStore %5 %11
+ %13 = OpLoad %ulong %4
+ %23 = OpConvertUToPtr %_ptr_Generic_uint %13
+ %12 = OpLoad %uint %23 Aligned 4
+ OpStore %6 %12
+ %15 = OpLoad %ulong %4
+ %24 = OpConvertUToPtr %_ptr_Generic_uint %15
+ %41 = OpBitcast %_ptr_Generic_uchar %24
+ %42 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %41 %ulong_4
+ %22 = OpBitcast %_ptr_Generic_uint %42
+ %14 = OpLoad %uint %22 Aligned 4
+ OpStore %7 %14
+ %17 = OpLoad %uint %6
+ %18 = OpLoad %uint %7
+ %26 = OpCopyObject %uint %17
+ %27 = OpCopyObject %uint %18
+ %44 = OpBitcast %v4uchar %26
+ %45 = OpBitcast %v4uchar %27
+ %46 = OpVectorShuffle %v4uchar %44 %45 4 0 6 7
+ %25 = OpBitcast %uint %46
+ %16 = OpCopyObject %uint %25
+ OpStore %7 %16
+ %19 = OpLoad %ulong %5
+ %20 = OpLoad %uint %7
+ %28 = OpConvertUToPtr %_ptr_Generic_uint %19
+ OpStore %28 %20 Aligned 4
+ OpReturn
+ OpFunctionEnd \ No newline at end of file
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index a41179d..e015062 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -2992,6 +2992,76 @@ fn emit_function_body_ops<'input>(
let result_type = map.get_or_add_scalar(builder, (*typ).into());
builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
}
+ ast::Instruction::Prmt { control, arg } => {
+ let control = *control as u32;
+ let components = [
+ (control >> 0) & 0b1111,
+ (control >> 4) & 0b1111,
+ (control >> 8) & 0b1111,
+ (control >> 12) & 0b1111,
+ ];
+ if components.iter().any(|&c| c > 7) {
+ return Err(TranslateError::Todo);
+ }
+ let vec4_b8_type =
+ map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4));
+ let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32);
+ let src1_vector = builder.bitcast(vec4_b8_type, None, arg.src1)?;
+ let src2_vector = builder.bitcast(vec4_b8_type, None, arg.src2)?;
+ let dst_vector = builder.vector_shuffle(
+ vec4_b8_type,
+ None,
+ src1_vector,
+ src2_vector,
+ components,
+ )?;
+ builder.bitcast(b32_type, Some(arg.dst), dst_vector)?;
+ }
+ ast::Instruction::Activemask { arg } => {
+ let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32);
+ let vec4_b32_type =
+ map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B32, 4));
+ let pred_true = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ &[1],
+ )?;
+ let dst_vector = builder.subgroup_ballot_khr(vec4_b32_type, None, pred_true)?;
+ builder.composite_extract(b32_type, Some(arg.src), dst_vector, [0])?;
+ }
+ ast::Instruction::Membar { level } => {
+ let (scope, semantics) = match level {
+ ast::MemScope::Cta => (
+ spirv::Scope::Workgroup,
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+ ast::MemScope::Gpu => (
+ spirv::Scope::Device,
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+ ast::MemScope::Sys => (
+ spirv::Scope::CrossDevice,
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+ };
+ let spirv_scope = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(scope as u32),
+ )?;
+ let spirv_semantics = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(semantics),
+ )?;
+ builder.memory_barrier(spirv_scope, spirv_semantics)?;
+ }
},
Statement::LoadVar(details) => {
emit_load_var(builder, map, details)?;
@@ -4172,7 +4242,6 @@ fn normalize_identifiers<'input, 'b>(
match s {
ast::Statement::Label(id) => {
id_defs.add_def(*id, None, false);
- eprintln!("{}", id);
}
_ => (),
}
@@ -5800,7 +5869,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
let new_args = a.map(visitor, &d)?;
ast::Instruction::St(d, new_args)
}
- ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?),
+ ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, false, None)?),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
ast::Instruction::Cvta(d, a) => {
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
@@ -5942,6 +6011,21 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
arg: arg.map_non_shift(visitor, &full_type, false)?,
}
}
+ ast::Instruction::Prmt { control, arg } => ast::Instruction::Prmt {
+ control,
+ arg: arg.map_prmt(visitor)?,
+ },
+ ast::Instruction::Activemask { arg } => ast::Instruction::Activemask {
+ arg: arg.map(
+ visitor,
+ true,
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )),
+ )?,
+ },
+ ast::Instruction::Membar { level } => ast::Instruction::Membar { level },
})
}
}
@@ -6202,6 +6286,9 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Bfe { .. } => None,
ast::Instruction::Bfi { .. } => None,
ast::Instruction::Rem { .. } => None,
+ ast::Instruction::Prmt { .. } => None,
+ ast::Instruction::Activemask { .. } => None,
+ ast::Instruction::Membar { .. } => None,
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
@@ -6339,12 +6426,13 @@ impl<T: ArgParamsEx> ast::Arg1<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
+ is_dst: bool,
t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<ast::Arg1<U>, TranslateError> {
let new_src = visitor.id(
ArgumentDescriptor {
op: self.src,
- is_dst: false,
+ is_dst,
is_memory_access: false,
non_default_implicit_conversion: None,
},
@@ -6685,6 +6773,43 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
+
+ fn map_prmt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ ) -> Result<ast::Arg3<U>, TranslateError> {
+ let dst = visitor.operand(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )?;
+ Ok(ast::Arg3 { dst, src1, src2 })
+ }
}
impl<T: ArgParamsEx> ast::Arg4<T> {