aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/translate.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r--ptx/src/translate.rs131
1 files changed, 128 insertions, 3 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index a41179d..e015062 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -2992,6 +2992,76 @@ fn emit_function_body_ops<'input>(
let result_type = map.get_or_add_scalar(builder, (*typ).into());
builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
}
+ ast::Instruction::Prmt { control, arg } => {
+ let control = *control as u32;
+ let components = [
+ (control >> 0) & 0b1111,
+ (control >> 4) & 0b1111,
+ (control >> 8) & 0b1111,
+ (control >> 12) & 0b1111,
+ ];
+ if components.iter().any(|&c| c > 7) {
+ return Err(TranslateError::Todo);
+ }
+ let vec4_b8_type =
+ map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4));
+ let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32);
+ let src1_vector = builder.bitcast(vec4_b8_type, None, arg.src1)?;
+ let src2_vector = builder.bitcast(vec4_b8_type, None, arg.src2)?;
+ let dst_vector = builder.vector_shuffle(
+ vec4_b8_type,
+ None,
+ src1_vector,
+ src2_vector,
+ components,
+ )?;
+ builder.bitcast(b32_type, Some(arg.dst), dst_vector)?;
+ }
+ ast::Instruction::Activemask { arg } => {
+ let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32);
+ let vec4_b32_type =
+ map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B32, 4));
+ let pred_true = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ &[1],
+ )?;
+ let dst_vector = builder.subgroup_ballot_khr(vec4_b32_type, None, pred_true)?;
+ builder.composite_extract(b32_type, Some(arg.src), dst_vector, [0])?;
+ }
+ ast::Instruction::Membar { level } => {
+ let (scope, semantics) = match level {
+ ast::MemScope::Cta => (
+ spirv::Scope::Workgroup,
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+ ast::MemScope::Gpu => (
+ spirv::Scope::Device,
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+ ast::MemScope::Sys => (
+ spirv::Scope::CrossDevice,
+ spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
+ | spirv::MemorySemantics::WORKGROUP_MEMORY
+ | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
+ ),
+ };
+ let spirv_scope = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(scope as u32),
+ )?;
+ let spirv_semantics = map.get_or_add_constant(
+ builder,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ &vec_repr(semantics),
+ )?;
+ builder.memory_barrier(spirv_scope, spirv_semantics)?;
+ }
},
Statement::LoadVar(details) => {
emit_load_var(builder, map, details)?;
@@ -4172,7 +4242,6 @@ fn normalize_identifiers<'input, 'b>(
match s {
ast::Statement::Label(id) => {
id_defs.add_def(*id, None, false);
- eprintln!("{}", id);
}
_ => (),
}
@@ -5800,7 +5869,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
let new_args = a.map(visitor, &d)?;
ast::Instruction::St(d, new_args)
}
- ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?),
+ ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, false, None)?),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
ast::Instruction::Cvta(d, a) => {
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
@@ -5942,6 +6011,21 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
arg: arg.map_non_shift(visitor, &full_type, false)?,
}
}
+ ast::Instruction::Prmt { control, arg } => ast::Instruction::Prmt {
+ control,
+ arg: arg.map_prmt(visitor)?,
+ },
+ ast::Instruction::Activemask { arg } => ast::Instruction::Activemask {
+ arg: arg.map(
+ visitor,
+ true,
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )),
+ )?,
+ },
+ ast::Instruction::Membar { level } => ast::Instruction::Membar { level },
})
}
}
@@ -6202,6 +6286,9 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Bfe { .. } => None,
ast::Instruction::Bfi { .. } => None,
ast::Instruction::Rem { .. } => None,
+ ast::Instruction::Prmt { .. } => None,
+ ast::Instruction::Activemask { .. } => None,
+ ast::Instruction::Membar { .. } => None,
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
@@ -6339,12 +6426,13 @@ impl<T: ArgParamsEx> ast::Arg1<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
+ is_dst: bool,
t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<ast::Arg1<U>, TranslateError> {
let new_src = visitor.id(
ArgumentDescriptor {
op: self.src,
- is_dst: false,
+ is_dst,
is_memory_access: false,
non_default_implicit_conversion: None,
},
@@ -6685,6 +6773,43 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
+
+ fn map_prmt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ ) -> Result<ast::Arg3<U>, TranslateError> {
+ let dst = visitor.operand(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )?;
+ Ok(ast::Arg3 { dst, src1, src2 })
+ }
}
impl<T: ArgParamsEx> ast::Arg4<T> {