diff options
Diffstat (limited to 'ptx/src/pass/insert_implicit_conversions2.rs')
-rw-r--r-- | ptx/src/pass/insert_implicit_conversions2.rs | 426 |
1 files changed, 426 insertions, 0 deletions
diff --git a/ptx/src/pass/insert_implicit_conversions2.rs b/ptx/src/pass/insert_implicit_conversions2.rs new file mode 100644 index 0000000..4f738f5 --- /dev/null +++ b/ptx/src/pass/insert_implicit_conversions2.rs @@ -0,0 +1,426 @@ +use std::mem;
+
+use super::*;
+use ptx_parser as ast;
+
+/*
+ There are several kinds of implicit conversions in PTX:
+ * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
+ * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
+ - ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
+ semantics are to first zext/chop/bitcast `y` as needed and then do
+ documented special ld/st/cvt conversion rules for destination operands
+ - st.param [x] y (used as function return arguments) same rule as above applies
+ - generic/global ld: for instruction `ld x, [y]`, y must be of type
+ b64/u64/s64, which is bitcast to a pointer, dereferenced and then
+ documented special ld/st/cvt conversion rules are applied to dst
+ - generic/global st: for instruction `st [x], y`, x must be of type
+ b64/u64/s64, which is bitcast to a pointer
+*/
+pub(super) fn run<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
+) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+ directives
+ .into_iter()
+ .map(|directive| run_directive(resolver, directive))
+ .collect::<Result<Vec<_>, _>>()
+}
+
+fn run_directive<'a, 'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+ Ok(match directive {
+ var @ Directive2::Variable(..) => var,
+ Directive2::Method(mut method) => {
+ method.body = method
+ .body
+ .map(|statements| run_statements(resolver, statements))
+ .transpose()?;
+ Directive2::Method(method)
+ }
+ })
+}
+
+fn run_statements<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ func: Vec<ExpandedStatement>,
+) -> Result<Vec<ExpandedStatement>, TranslateError> {
+ let mut result = Vec::with_capacity(func.len());
+ for s in func.into_iter() {
+ insert_implicit_conversions_impl(resolver, &mut result, s)?;
+ }
+ Ok(result)
+}
+
+fn insert_implicit_conversions_impl<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ func: &mut Vec<ExpandedStatement>,
+ stmt: ExpandedStatement,
+) -> Result<(), TranslateError> {
+ let mut post_conv = Vec::new();
+ let statement = stmt.visit_map::<SpirvWord, TranslateError>(
+ &mut |operand,
+ type_state: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst,
+ relaxed_type_check| {
+ let (instr_type, instruction_space) = match type_state {
+ None => return Ok(operand),
+ Some(t) => t,
+ };
+ let (operand_type, operand_space) = resolver.get_typed(operand)?;
+ let conversion_fn = if relaxed_type_check {
+ if is_dst {
+ should_convert_relaxed_dst_wrapper
+ } else {
+ should_convert_relaxed_src_wrapper
+ }
+ } else {
+ default_implicit_conversion
+ };
+ match conversion_fn(
+ (*operand_space, &operand_type),
+ (instruction_space, instr_type),
+ )? {
+ Some(conv_kind) => {
+ let conv_output = if is_dst { &mut post_conv } else { &mut *func };
+ let mut from_type = instr_type.clone();
+ let mut from_space = instruction_space;
+ let mut to_type = operand_type.clone();
+ let mut to_space = *operand_space;
+ let mut src =
+ resolver.register_unnamed(Some((instr_type.clone(), instruction_space)));
+ let mut dst = operand;
+ let result = Ok::<_, TranslateError>(src);
+ if !is_dst {
+ mem::swap(&mut src, &mut dst);
+ mem::swap(&mut from_type, &mut to_type);
+ mem::swap(&mut from_space, &mut to_space);
+ }
+ conv_output.push(Statement::Conversion(ImplicitConversion {
+ src,
+ dst,
+ from_type,
+ from_space,
+ to_type,
+ to_space,
+ kind: conv_kind,
+ }));
+ result
+ }
+ None => Ok(operand),
+ }
+ },
+ )?;
+ func.push(statement);
+ func.append(&mut post_conv);
+ Ok(())
+}
+
+pub(crate) fn default_implicit_conversion(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if instruction_space == ast::StateSpace::Reg {
+ if operand_space == ast::StateSpace::Reg {
+ if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) =
+ (operand_type, instruction_type)
+ {
+ if scalar.kind() == ast::ScalarKind::Bit
+ && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
+ {
+ return Ok(Some(ConversionKind::Default));
+ }
+ }
+ } else if is_addressable(operand_space) {
+ return Ok(Some(ConversionKind::AddressOf));
+ }
+ }
+ if instruction_space != operand_space {
+ default_implicit_conversion_space(
+ (operand_space, operand_type),
+ (instruction_space, instruction_type),
+ )
+ } else if instruction_type != operand_type {
+ default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
+ } else {
+ Ok(None)
+ }
+}
+
+fn is_addressable(this: ast::StateSpace) -> bool {
+ match this {
+ ast::StateSpace::Const
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Global
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => true,
+ ast::StateSpace::Param | ast::StateSpace::Reg => false,
+ ast::StateSpace::SharedCluster
+ | ast::StateSpace::SharedCta
+ | ast::StateSpace::ParamEntry
+ | ast::StateSpace::ParamFunc => todo!(),
+ }
+}
+
+// Space is different
+fn default_implicit_conversion_space(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space))
+ || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space))
+ {
+ Ok(Some(ConversionKind::PtrToPtr))
+ } else if operand_space == ast::StateSpace::Reg {
+ match operand_type {
+ ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
+ if *operand_ptr_space == instruction_space =>
+ {
+ if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
+ } else {
+ Ok(None)
+ }
+ }
+ // TODO: 32 bit
+ ast::Type::Scalar(ast::ScalarType::B64)
+ | ast::Type::Scalar(ast::ScalarType::U64)
+ | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
+ ast::StateSpace::Global
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Const
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
+ _ => Err(error_mismatched_type()),
+ },
+ ast::Type::Scalar(ast::ScalarType::B32)
+ | ast::Type::Scalar(ast::ScalarType::U32)
+ | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
+ ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
+ Ok(Some(ConversionKind::BitToPtr))
+ }
+ _ => Err(error_mismatched_type()),
+ },
+ _ => Err(error_mismatched_type()),
+ }
+ } else if instruction_space == ast::StateSpace::Reg {
+ match instruction_type {
+ ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
+ if operand_space == *instruction_ptr_space =>
+ {
+ if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
+ } else {
+ Ok(None)
+ }
+ }
+ _ => Err(error_mismatched_type()),
+ }
+ } else {
+ Err(error_mismatched_type())
+ }
+}
+
+// Space is same, but type is different
+fn default_implicit_conversion_type(
+ space: ast::StateSpace,
+ operand_type: &ast::Type,
+ instruction_type: &ast::Type,
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if space == ast::StateSpace::Reg {
+ if should_bitcast(instruction_type, operand_type) {
+ Ok(Some(ConversionKind::Default))
+ } else {
+ Err(TranslateError::MismatchedType)
+ }
+ } else {
+ Ok(Some(ConversionKind::PtrToPtr))
+ }
+}
+
+fn coerces_to_generic(this: ast::StateSpace) -> bool {
+ match this {
+ ast::StateSpace::Global
+ | ast::StateSpace::Const
+ | ast::StateSpace::Local
+ | ptx_parser::StateSpace::SharedCta
+ | ast::StateSpace::SharedCluster
+ | ast::StateSpace::Shared => true,
+ ast::StateSpace::Reg
+ | ast::StateSpace::Param
+ | ast::StateSpace::ParamEntry
+ | ast::StateSpace::ParamFunc
+ | ast::StateSpace::Generic => false,
+ }
+}
+
+fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
+ match (instr, operand) {
+ (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
+ if inst.size_of() != operand.size_of() {
+ return false;
+ }
+ match inst.kind() {
+ ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
+ ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
+ ast::ScalarKind::Signed => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Unsigned
+ }
+ ast::ScalarKind::Unsigned => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Signed
+ }
+ ast::ScalarKind::Pred => false,
+ }
+ }
+ (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand))
+ | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => {
+ should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
+ }
+ _ => false,
+ }
+}
+
+pub(crate) fn should_convert_relaxed_dst_wrapper(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if operand_space != instruction_space {
+ return Err(TranslateError::MismatchedType);
+ }
+ if operand_type == instruction_type {
+ return Ok(None);
+ }
+ match should_convert_relaxed_dst(operand_type, instruction_type) {
+ conv @ Some(_) => Ok(conv),
+ None => Err(TranslateError::MismatchedType),
+ }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
+fn should_convert_relaxed_dst(
+ dst_type: &ast::Type,
+ instr_type: &ast::Type,
+) -> Option<ConversionKind> {
+ if dst_type == instr_type {
+ return None;
+ }
+ match (dst_type, instr_type) {
+ (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
+ ast::ScalarKind::Bit => {
+ if instr_type.size_of() <= dst_type.size_of() {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Signed => {
+ if dst_type.kind() != ast::ScalarKind::Float {
+ if instr_type.size_of() == dst_type.size_of() {
+ Some(ConversionKind::Default)
+ } else if instr_type.size_of() < dst_type.size_of() {
+ Some(ConversionKind::SignExtend)
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Unsigned => {
+ if instr_type.size_of() <= dst_type.size_of()
+ && dst_type.kind() != ast::ScalarKind::Float
+ {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Float => {
+ if instr_type.size_of() <= dst_type.size_of()
+ && dst_type.kind() == ast::ScalarKind::Bit
+ {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Pred => None,
+ },
+ (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
+ | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
+ should_convert_relaxed_dst(
+ &ast::Type::Scalar(*dst_type),
+ &ast::Type::Scalar(*instr_type),
+ )
+ }
+ _ => None,
+ }
+}
+
+pub(crate) fn should_convert_relaxed_src_wrapper(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
+) -> Result<Option<ConversionKind>, TranslateError> {
+ if operand_space != instruction_space {
+ return Err(error_mismatched_type());
+ }
+ if operand_type == instruction_type {
+ return Ok(None);
+ }
+ match should_convert_relaxed_src(operand_type, instruction_type) {
+ conv @ Some(_) => Ok(conv),
+ None => Err(error_mismatched_type()),
+ }
+}
+
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
+fn should_convert_relaxed_src(
+ src_type: &ast::Type,
+ instr_type: &ast::Type,
+) -> Option<ConversionKind> {
+ if src_type == instr_type {
+ return None;
+ }
+ match (src_type, instr_type) {
+ (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
+ ast::ScalarKind::Bit => {
+ if instr_type.size_of() <= src_type.size_of() {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
+ if instr_type.size_of() <= src_type.size_of()
+ && src_type.kind() != ast::ScalarKind::Float
+ {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Float => {
+ if instr_type.size_of() <= src_type.size_of()
+ && src_type.kind() == ast::ScalarKind::Bit
+ {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ }
+ }
+ ast::ScalarKind::Pred => None,
+ },
+ (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type))
+ | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => {
+ should_convert_relaxed_src(
+ &ast::Type::Scalar(*dst_type),
+ &ast::Type::Scalar(*instr_type),
+ )
+ }
+ _ => None,
+ }
+}
|