diff options
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r-- | ptx/src/translate.rs | 78 |
1 files changed, 57 insertions, 21 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7566be8..3291ad5 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1505,6 +1505,7 @@ fn extract_globals<'input, 'b>( d,
a,
"inc",
+ ast::SizedScalarType::U32,
));
}
Statement::Instruction(ast::Instruction::Atom(
@@ -1526,6 +1527,44 @@ fn extract_globals<'input, 'b>( d,
a,
"dec",
+ ast::SizedScalarType::U32,
+ ));
+ }
+ Statement::Instruction(ast::Instruction::Atom(
+ ast::AtomDetails {
+ inner:
+ ast::AtomInnerDetails::Float {
+ op: ast::AtomFloatOp::Add,
+ typ,
+ },
+ semantics,
+ scope,
+ space,
+ },
+ a,
+ )) => {
+ let details = ast::AtomDetails {
+ inner: ast::AtomInnerDetails::Float {
+ op: ast::AtomFloatOp::Add,
+ typ,
+ },
+ semantics,
+ scope,
+ space,
+ };
+ let (op, typ) = match typ {
+ ast::FloatType::F32 => ("add_f32", ast::SizedScalarType::F32),
+ ast::FloatType::F64 => ("add_f64", ast::SizedScalarType::F64),
+ ast::FloatType::F16 => unreachable!(),
+ ast::FloatType::F16x2 => unreachable!(),
+ };
+ local.push(to_ptx_impl_atomic_call(
+ id_def,
+ ptx_impl_imports,
+ details,
+ a,
+ op,
+ typ,
));
}
s => local.push(s),
@@ -1696,6 +1735,7 @@ fn to_ptx_impl_atomic_call( details: ast::AtomDetails,
arg: ast::Arg3<ExpandedArgParams>,
op: &'static str,
+ typ: ast::SizedScalarType,
) -> ExpandedStatement {
let semantics = ptx_semantics_name(details.semantics);
let scope = ptx_scope_name(details.scope);
@@ -1710,15 +1750,14 @@ fn to_ptx_impl_atomic_call( ast::AtomSpace::Global => ast::PointerStateSpace::Global,
ast::AtomSpace::Shared => ast::PointerStateSpace::Shared,
};
+ let scalar_typ = ast::ScalarType::from(typ);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
let fn_id = id_defs.new_non_variable(None);
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
}],
@@ -1727,17 +1766,14 @@ fn to_ptx_impl_atomic_call( ast::FnArgument {
align: None,
v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
- ast::SizedScalarType::U32,
- ptr_space,
+ typ, ptr_space,
)),
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
+ v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
@@ -1768,19 +1804,16 @@ fn to_ptx_impl_atomic_call( func: fn_id,
ret_params: vec![(
arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
)],
param_list: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
- ast::SizedScalarType::U32,
- ptr_space,
- )),
+ ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(typ, ptr_space)),
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
),
],
})
@@ -1963,14 +1996,13 @@ fn to_ptx_impl_bfi_call( arg.dst,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
)],
- // Note, for some reason PTX and SPIR-V order base&insert arguments differently
param_list: vec![
(
- arg.src2,
+ arg.src1,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
),
(
- arg.src1,
+ arg.src2,
ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
),
(
@@ -3476,8 +3508,12 @@ fn emit_atom( };
(spirv_op, typ.into())
}
- // TODO: Hardware is capable of this, implement it through builtin
- ast::AtomInnerDetails::Float { .. } => todo!(),
+ ast::AtomInnerDetails::Float { op, typ } => {
+ let spirv_op: fn(&mut dr::Builder, _, _, _, _, _, _) -> _ = match op {
+ ast::AtomFloatOp::Add => dr::Builder::atomic_f_add_ext,
+ };
+ (spirv_op, typ.into())
+ }
};
let result_type = map.get_or_add_scalar(builder, typ);
let memory_const = map.get_or_add_constant(
@@ -4287,8 +4323,8 @@ fn emit_implicit_conversion( }
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => {
let result_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
- builder.s_convert(result_type , Some(cv.dst), cv.src)?;
- },
+ builder.s_convert(result_type, Some(cv.dst), cv.src)?;
+ }
(TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
| (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default)
| (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
|