aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/translate.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r--ptx/src/translate.rs201
1 files changed, 159 insertions, 42 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 6b07c0f..c351ccd 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,8 +1,11 @@
use crate::ast;
use half::f16;
use rspirv::{binary::Disassemble, dr};
-use std::collections::{hash_map, HashMap, HashSet};
use std::{borrow::Cow, hash::Hash, iter, mem};
+use std::{
+ collections::{hash_map, HashMap, HashSet},
+ convert::TryInto,
+};
use rspirv::binary::Assemble;
@@ -1499,6 +1502,15 @@ fn convert_to_typed_statements(
ast::Instruction::AtomCas(d, a) => result.push(Statement::Instruction(
ast::Instruction::AtomCas(d, a.cast()),
)),
+ ast::Instruction::Div(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Div(d, a.cast())))
+ }
+ ast::Instruction::Sqrt(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Sqrt(d, a.cast())))
+ }
+ ast::Instruction::Rsqrt(d, a) => {
+ result.push(Statement::Instruction(ast::Instruction::Rsqrt(d, a.cast())))
+ }
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
@@ -1982,7 +1994,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
| ArgumentSemantics::DefaultRelaxed
| ArgumentSemantics::PhysicalPointer => {
if desc.sema == ArgumentSemantics::PhysicalPointer {
- typ = ast::Type::Scalar(ast::ScalarType::U64);
+ typ = self.id_def.get_typed(reg)?;
}
let (width, kind) = match typ {
ast::Type::Scalar(scalar_t) => {
@@ -2013,7 +2025,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::from_parts(width, kind),
- value: ast::ImmediateValue::S64(-(offset as i64)),
+ value: ast::ImmediateValue::U64(-(offset as i64) as u64),
}));
self.func.push(Statement::Instruction(
ast::Instruction::<ExpandedArgParams>::Sub(
@@ -2765,6 +2777,34 @@ fn emit_function_body_ops(
arg.src2,
)?;
}
+ ast::Instruction::Div(details, arg) => match details {
+ ast::DivDetails::Unsigned(t) => {
+ let result_type = map.get_or_add_scalar(builder, (*t).into());
+ builder.u_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
+ }
+ ast::DivDetails::Signed(t) => {
+ let result_type = map.get_or_add_scalar(builder, (*t).into());
+ builder.s_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
+ }
+ ast::DivDetails::Float(t) => {
+ let result_type = map.get_or_add_scalar(builder, t.typ.into());
+ builder.f_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
+ emit_float_div_decoration(builder, arg.dst, t.kind);
+ }
+ },
+ ast::Instruction::Sqrt(details, a) => {
+ emit_sqrt(builder, map, opencl, details, a)?;
+ }
+ ast::Instruction::Rsqrt(details, a) => {
+ let result_type = map.get_or_add_scalar(builder, details.typ.into());
+ builder.ext_inst(
+ result_type,
+ Some(a.dst),
+ opencl,
+ spirv::CLOp::native_rsqrt as spirv::Word,
+ &[a.src],
+ )?;
+ }
},
Statement::LoadVar(arg, typ) => {
let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
@@ -2795,6 +2835,47 @@ fn emit_function_body_ops(
Ok(())
}
+fn emit_sqrt(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ opencl: spirv::Word,
+ details: &ast::SqrtDetails,
+ a: &ast::Arg2<ExpandedArgParams>,
+) -> Result<(), TranslateError> {
+ let result_type = map.get_or_add_scalar(builder, details.typ.into());
+ let (ocl_op, rounding) = match details.kind {
+ ast::SqrtKind::Approx => (spirv::CLOp::native_sqrt, None),
+ ast::SqrtKind::Rounding(rnd) => (spirv::CLOp::sqrt, Some(rnd)),
+ };
+ builder.ext_inst(
+ result_type,
+ Some(a.dst),
+ opencl,
+ ocl_op as spirv::Word,
+ &[a.src],
+ )?;
+ emit_rounding_decoration(builder, a.dst, rounding);
+ Ok(())
+}
+
+fn emit_float_div_decoration(builder: &mut dr::Builder, dst: spirv::Word, kind: ast::DivFloatKind) {
+ match kind {
+ ast::DivFloatKind::Approx => {
+ builder.decorate(
+ dst,
+ spirv::Decoration::FPFastMathMode,
+ &[dr::Operand::FPFastMathMode(
+ spirv::FPFastMathMode::ALLOW_RECIP,
+ )],
+ );
+ }
+ ast::DivFloatKind::Rounding(rnd) => {
+ emit_rounding_decoration(builder, dst, Some(rnd));
+ }
+ ast::DivFloatKind::Full => {}
+ }
+}
+
fn emit_atom(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -3307,7 +3388,25 @@ fn emit_setp(
(ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
- _ => todo!(),
+ (ast::SetpCompareOp::NanEq, _) => {
+ builder.f_unord_equal(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanNotEq, _) => {
+ builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanLess, _) => {
+ builder.f_unord_less_than(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanLessOrEq, _) => {
+ builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanGreater, _) => {
+ builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2)
+ }
+ (ast::SetpCompareOp::NanGreaterOrEq, _) => {
+ builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2)
+ }
+ _ => todo!()
}?;
Ok(())
}
@@ -3486,8 +3585,8 @@ fn emit_implicit_conversion(
let from_parts = cv.from.to_parts();
let to_parts = cv.to.to_parts();
match (from_parts.kind, to_parts.kind, cv.kind) {
- (_, _, ConversionKind::PtrToBit) => {
- let dst_type = map.get_or_add_scalar(builder, ast::ScalarType::B64);
+ (_, _, ConversionKind::PtrToBit(typ)) => {
+ let dst_type = map.get_or_add_scalar(builder, typ.into());
builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
}
(_, _, ConversionKind::BitToPtr(_)) => {
@@ -4570,6 +4669,15 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::AtomCas(d, a) => {
ast::Instruction::AtomCas(d, a.map_atom(visitor, d.typ, d.space)?)
}
+ ast::Instruction::Div(d, a) => {
+ ast::Instruction::Div(d, a.map_non_shift(visitor, &d.get_type(), false)?)
+ }
+ ast::Instruction::Sqrt(d, a) => {
+ ast::Instruction::Sqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
+ }
+ ast::Instruction::Rsqrt(d, a) => {
+ ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
+ }
})
}
}
@@ -4794,32 +4902,7 @@ impl ast::Instruction<ExpandedArgParams> {
fn jump_target(&self) -> Option<spirv::Word> {
match self {
ast::Instruction::Bra(_, a) => Some(a.src),
- ast::Instruction::Ld(_, _)
- | ast::Instruction::Mov(_, _)
- | ast::Instruction::Mul(_, _)
- | ast::Instruction::Add(_, _)
- | ast::Instruction::Setp(_, _)
- | ast::Instruction::SetpBool(_, _)
- | ast::Instruction::Not(_, _)
- | ast::Instruction::Cvt(_, _)
- | ast::Instruction::Cvta(_, _)
- | ast::Instruction::Shl(_, _)
- | ast::Instruction::Shr(_, _)
- | ast::Instruction::St(_, _)
- | ast::Instruction::Ret(_)
- | ast::Instruction::Abs(_, _)
- | ast::Instruction::Call(_)
- | ast::Instruction::Or(_, _)
- | ast::Instruction::Sub(_, _)
- | ast::Instruction::Min(_, _)
- | ast::Instruction::Max(_, _)
- | ast::Instruction::Rcp(_, _)
- | ast::Instruction::And(_, _)
- | ast::Instruction::Selp(_, _)
- | ast::Instruction::Bar(_, _)
- | ast::Instruction::Atom(_, _)
- | ast::Instruction::AtomCas(_, _)
- | ast::Instruction::Mad(_, _) => None,
+ _ => None,
}
}
@@ -4856,6 +4939,9 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Max(ast::MinMaxDetails::Signed(_), _) => None,
ast::Instruction::Max(ast::MinMaxDetails::Unsigned(_), _) => None,
ast::Instruction::Cvt(ast::CvtDetails::IntFromInt(_), _) => None,
+ ast::Instruction::Cvt(ast::CvtDetails::FloatFromInt(_), _) => None,
+ ast::Instruction::Div(ast::DivDetails::Unsigned(_), _) => None,
+ ast::Instruction::Div(ast::DivDetails::Signed(_), _) => None,
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
@@ -4885,13 +4971,19 @@ impl ast::Instruction<ExpandedArgParams> {
_,
)
| ast::Instruction::Cvt(
- ast::CvtDetails::FloatFromInt(ast::CvtDesc { flush_to_zero, .. }),
- _,
- )
- | ast::Instruction::Cvt(
ast::CvtDetails::IntFromFloat(ast::CvtDesc { flush_to_zero, .. }),
_,
) => flush_to_zero.map(|ftz| (ftz, 4)),
+ ast::Instruction::Div(ast::DivDetails::Float(details), _) => details
+ .flush_to_zero
+ .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())),
+ ast::Instruction::Sqrt(details, _) => details
+ .flush_to_zero
+ .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())),
+ ast::Instruction::Rsqrt(details, _) => Some((
+ details.flush_to_zero,
+ ast::ScalarType::from(details.typ).size_of(),
+ )),
}
}
}
@@ -4978,13 +5070,13 @@ struct ImplicitConversion {
kind: ConversionKind,
}
-#[derive(Debug, PartialEq, Copy, Clone)]
+#[derive(PartialEq, Copy, Clone)]
enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
SignExtend,
BitToPtr(ast::LdStateSpace),
- PtrToBit,
+ PtrToBit(ast::UIntType),
PtrToPtr { spirv_ptr: bool },
}
@@ -6027,6 +6119,16 @@ impl ast::MinMaxDetails {
}
}
+impl ast::DivDetails {
+ fn get_type(&self) -> ast::Type {
+ ast::Type::Scalar(match self {
+ ast::DivDetails::Unsigned(t) => (*t).into(),
+ ast::DivDetails::Signed(t) => (*t).into(),
+ ast::DivDetails::Float(d) => d.typ.into(),
+ })
+ }
+}
+
impl ast::AtomInnerDetails {
fn get_type(&self) -> ast::ScalarType {
match self {
@@ -6193,6 +6295,15 @@ fn bitcast_physical_pointer(
Err(TranslateError::Unreachable)
}
}
+ ast::Type::Scalar(ast::ScalarType::B32)
+ | ast::Type::Scalar(ast::ScalarType::U32)
+ | ast::Type::Scalar(ast::ScalarType::S32) => {
+ if let Some(ast::LdStateSpace::Shared) = ss {
+ Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared)))
+ } else {
+ Err(TranslateError::MismatchedType)
+ }
+ }
ast::Type::Pointer(op_scalar_t, op_space) => {
if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
if op_space == instr_space {
@@ -6220,10 +6331,16 @@ fn bitcast_physical_pointer(
fn force_bitcast_ptr_to_bit(
_: &ast::Type,
- _: &ast::Type,
+ instr_type: &ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
- Ok(Some(ConversionKind::PtrToBit))
+ // TODO: verify this on f32, u16 and the like
+ if let ast::Type::Scalar(scalar_t) = instr_type {
+ if let Ok(int_type) = (*scalar_t).try_into() {
+ return Ok(Some(ConversionKind::PtrToBit(int_type)));
+ }
+ }
+ Err(TranslateError::MismatchedType)
}
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
@@ -6542,9 +6659,9 @@ mod tests {
&ast::Type::Scalar(*instr_type),
);
if instr_idx == op_idx {
- assert_eq!(conversion, None);
+ assert!(conversion == None);
} else {
- assert_eq!(conversion, conv_table[instr_idx][op_idx]);
+ assert!(conversion == conv_table[instr_idx][op_idx]);
}
}
}