aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/translate.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r--ptx/src/translate.rs249
1 files changed, 241 insertions, 8 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 23a63be..365d1e8 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1289,6 +1289,9 @@ fn extract_globals<'input, 'b>(
..
},
) => global.push(var),
+ Statement::Instruction(ast::Instruction::Bfe { typ, arg }) => {
+ local.push(to_ptx_impl_bfe_call(id_def, ptx_impl_imports, typ, arg));
+ }
Statement::Instruction(ast::Instruction::Atom(
d
@
@@ -1591,6 +1594,24 @@ fn convert_to_typed_statements(
arg: arg.cast(),
}))
}
+ ast::Instruction::Xor { typ, arg } => {
+ result.push(Statement::Instruction(ast::Instruction::Xor {
+ typ,
+ arg: arg.cast(),
+ }))
+ }
+ ast::Instruction::Bfe { typ, arg } => {
+ result.push(Statement::Instruction(ast::Instruction::Bfe {
+ typ,
+ arg: arg.cast(),
+ }))
+ }
+ ast::Instruction::Rem { typ, arg } => {
+ result.push(Statement::Instruction(ast::Instruction::Rem {
+ typ,
+ arg: arg.cast(),
+ }))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -1610,6 +1631,7 @@ fn convert_to_typed_statements(
Ok(result)
}
+//TODO: share common code between this and to_ptx_impl_bfe_call
fn to_ptx_impl_atomic_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
@@ -1705,6 +1727,100 @@ fn to_ptx_impl_atomic_call(
})
}
+fn to_ptx_impl_bfe_call(
+ id_defs: &mut NumericIdResolver,
+ ptx_impl_imports: &mut HashMap<String, Directive>,
+ typ: ast::IntType,
+ arg: ast::Arg4<ExpandedArgParams>,
+) -> ExpandedStatement {
+ let prefix = "__notcuda_ptx_impl__";
+ let suffix = match typ {
+ ast::IntType::U32 => "bfe_u32",
+ ast::IntType::U64 => "bfe_u64",
+ ast::IntType::S32 => "bfe_s32",
+ ast::IntType::S64 => "bfe_s64",
+ _ => unreachable!(),
+ };
+ let fn_name = format!("{}{}", prefix, suffix);
+ let fn_id = match ptx_impl_imports.entry(fn_name) {
+ hash_map::Entry::Vacant(entry) => {
+ let fn_id = id_defs.new_id(None);
+ let func_decl = ast::MethodDecl::Func::<spirv::Word>(
+ vec![ast::FnArgument {
+ align: None,
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ name: id_defs.new_id(None),
+ array_init: Vec::new(),
+ }],
+ fn_id,
+ vec![
+ ast::FnArgument {
+ align: None,
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ name: id_defs.new_id(None),
+ array_init: Vec::new(),
+ },
+ ast::FnArgument {
+ align: None,
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
+ ast::ScalarType::U32,
+ )),
+ name: id_defs.new_id(None),
+ array_init: Vec::new(),
+ },
+ ast::FnArgument {
+ align: None,
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
+ ast::ScalarType::U32,
+ )),
+ name: id_defs.new_id(None),
+ array_init: Vec::new(),
+ },
+ ],
+ );
+ let spirv_decl = SpirvMethodDecl::new(&func_decl);
+ let func = Function {
+ func_decl,
+ globals: Vec::new(),
+ body: None,
+ import_as: Some(entry.key().clone()),
+ spirv_decl,
+ };
+ entry.insert(Directive::Method(func));
+ fn_id
+ }
+ hash_map::Entry::Occupied(entry) => match entry.get() {
+ Directive::Method(Function {
+ func_decl: ast::MethodDecl::Func(_, name, _),
+ ..
+ }) => *name,
+ _ => unreachable!(),
+ },
+ };
+ Statement::Call(ResolvedCall {
+ uniform: false,
+ func: fn_id,
+ ret_params: vec![(
+ arg.dst,
+ ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ )],
+ param_list: vec![
+ (
+ arg.src1,
+ ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ),
+ (
+ arg.src2,
+ ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ),
+ (
+ arg.src3,
+ ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ),
+ ],
+ })
+}
+
fn to_resolved_fn_args<T>(
params: Vec<T>,
params_decl: &[ast::FnArgumentType],
@@ -2803,7 +2919,7 @@ fn emit_function_body_ops(
let result_id = Some(a.dst);
let operand = a.src;
match t {
- ast::NotType::Pred => {
+ ast::BooleanType::Pred => {
// HACK ALERT
// Temporary workaround until IGC gets its shit together
// Currently IGC carries two copies of SPIRV-LLVM translator
@@ -2854,7 +2970,7 @@ fn emit_function_body_ops(
},
ast::Instruction::Or(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
- if *t == ast::OrAndType::Pred {
+ if *t == ast::BooleanType::Pred {
builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?;
} else {
builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
@@ -2882,7 +2998,7 @@ fn emit_function_body_ops(
}
ast::Instruction::And(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
- if *t == ast::OrAndType::Pred {
+ if *t == ast::BooleanType::Pred {
builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?;
} else {
builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?;
@@ -3033,6 +3149,39 @@ fn emit_function_body_ops(
let result_type = map.get_or_add_scalar(builder, (*typ).into());
builder.bit_count(result_type, Some(arg.dst), arg.src)?;
}
+ ast::Instruction::Xor { typ, arg } => {
+ let builder_fn = match typ {
+ ast::BooleanType::Pred => emit_logical_xor_spirv,
+ _ => dr::Builder::bitwise_xor,
+ };
+ let result_type = map.get_or_add_scalar(builder, (*typ).into());
+ builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
+ }
+ ast::Instruction::Bfe { typ, arg } => {
+ let builder_fn = if typ.is_signed() {
+ dr::Builder::bit_field_s_extract
+ } else {
+ dr::Builder::bit_field_u_extract
+ };
+ let result_type = map.get_or_add_scalar(builder, (*typ).into());
+ builder_fn(
+ builder,
+ result_type,
+ Some(arg.dst),
+ arg.src1,
+ arg.src2,
+ arg.src3,
+ )?;
+ }
+ ast::Instruction::Rem { typ, arg } => {
+ let builder_fn = if typ.is_signed() {
+ dr::Builder::s_mod
+ } else {
+ dr::Builder::u_mod
+ };
+ let result_type = map.get_or_add_scalar(builder, (*typ).into());
+ builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@@ -3079,6 +3228,20 @@ fn emit_function_body_ops(
Ok(())
}
+// TODO: check what kind of assembly do we emit
+fn emit_logical_xor_spirv(
+ builder: &mut dr::Builder,
+ result_type: spirv::Word,
+ result_id: Option<spirv::Word>,
+ op1: spirv::Word,
+ op2: spirv::Word,
+) -> Result<spirv::Word, dr::Error> {
+ let temp_or = builder.logical_or(result_type, None, op1, op2)?;
+ let temp_and = builder.logical_and(result_type, None, op1, op2)?;
+ let temp_neg = builder.logical_not(result_type, None, temp_and)?;
+ builder.logical_and(result_type, result_id, temp_or, temp_neg)
+}
+
fn emit_sqrt(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -5039,6 +5202,27 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
arg: arg.map_different_types(visitor, &dst_type, &src_type)?,
}
}
+ ast::Instruction::Xor { typ, arg } => {
+ let full_type = ast::Type::Scalar(typ.into());
+ ast::Instruction::Xor {
+ typ,
+ arg: arg.map_non_shift(visitor, &full_type, false)?,
+ }
+ }
+ ast::Instruction::Bfe { typ, arg } => {
+ let full_type = ast::Type::Scalar(typ.into());
+ ast::Instruction::Bfe {
+ typ,
+ arg: arg.map_bfe(visitor, &full_type)?,
+ }
+ }
+ ast::Instruction::Rem { typ, arg } => {
+ let full_type = ast::Type::Scalar(typ.into());
+ ast::Instruction::Rem {
+ typ,
+ arg: arg.map_non_shift(visitor, &full_type, false)?,
+ }
+ }
})
}
}
@@ -5351,6 +5535,9 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Clz { .. } => None,
ast::Instruction::Brev { .. } => None,
ast::Instruction::Popc { .. } => None,
+ ast::Instruction::Xor { .. } => None,
+ ast::Instruction::Bfe { .. } => None,
+ ast::Instruction::Rem { .. } => None,
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
@@ -6192,6 +6379,52 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
src3,
})
}
+
+ fn map_bfe<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ typ: &ast::Type,
+ ) -> Result<ast::Arg4<U>, TranslateError> {
+ let dst = visitor.id(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(typ),
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ typ,
+ )?;
+ let u32_type = ast::Type::Scalar(ast::ScalarType::U32);
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &u32_type,
+ )?;
+ let src3 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src3,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ &u32_type,
+ )?;
+ Ok(ast::Arg4 {
+ dst,
+ src1,
+ src2,
+ src3,
+ })
+ }
}
impl<T: ArgParamsEx> ast::Arg4Setp<T> {
@@ -6437,13 +6670,13 @@ impl ast::ScalarType {
}
}
-impl ast::NotType {
+impl ast::BooleanType {
fn to_type(self) -> ast::Type {
match self {
- ast::NotType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
- ast::NotType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
- ast::NotType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
- ast::NotType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
+ ast::BooleanType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::BooleanType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
+ ast::BooleanType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
+ ast::BooleanType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
}
}
}