From 78a9f22cf7e6c819f04991c1624578c969c1a146 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 23 Sep 2024 06:02:28 +0200 Subject: Refactor implicit conversions, explicit ld/st and global hoisting --- ptx/src/pass/emit_llvm.rs | 80 +++-- ptx/src/pass/hoist_globals.rs | 45 +++ ptx/src/pass/insert_explicit_load_store.rs | 101 +++++-- ptx/src/pass/insert_implicit_conversions2.rs | 426 +++++++++++++++++++++++++++ ptx/src/pass/mod.rs | 19 +- ptx/src/test/spirv_run/mod.rs | 2 +- 6 files changed, 627 insertions(+), 46 deletions(-) create mode 100644 ptx/src/pass/hoist_globals.rs create mode 100644 ptx/src/pass/insert_implicit_conversions2.rs diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 3060335..235ad7d 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -164,17 +164,16 @@ impl Deref for MemoryBuffer { } pub(super) fn run<'input>( - id_defs: &GlobalStringIdResolver<'input>, - call_map: MethodsCallMap<'input>, - directives: Vec>, + id_defs: GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, ) -> Result { let context = Context::new(); let module = Module::new(&context, LLVM_UNNAMED); - let mut emit_ctx = ModuleEmitContext::new(&context, &module, id_defs); + let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs); for directive in directives { match directive { - Directive::Variable(..) => todo!(), - Directive::Method(method) => emit_ctx.emit_method(method)?, + Directive2::Variable(..) => todo!(), + Directive2::Method(method) => emit_ctx.emit_method(method)?, } } module.write_to_stderr(); @@ -188,7 +187,7 @@ struct ModuleEmitContext<'a, 'input> { context: LLVMContextRef, module: LLVMModuleRef, builder: Builder, - id_defs: &'a GlobalStringIdResolver<'input>, + id_defs: &'a GlobalStringIdentResolver2<'input>, resolver: ResolveIdent, } @@ -196,7 +195,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { fn new( context: &Context, module: &Module, - id_defs: &'a GlobalStringIdResolver<'input>, + id_defs: &'a GlobalStringIdentResolver2<'input>, ) -> Self { ModuleEmitContext { context: context.get(), @@ -215,20 +214,27 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { LLVMCallConv::LLVMCCallConv as u32 } - fn emit_method(&mut self, method: Function<'input>) -> Result<(), TranslateError> { - let func_decl = method.func_decl.borrow(); + fn emit_method( + &mut self, + method: Function2<'input, ast::Instruction, SpirvWord>, + ) -> Result<(), TranslateError> { + let func_decl = method.func_decl; let name = method .import_as .as_deref() - .unwrap_or_else(|| match func_decl.name { - ast::MethodName::Kernel(name) => name, - ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id], - }); + .or_else(|| match func_decl.name { + ast::MethodName::Kernel(name) => Some(name), + ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(), + }) + .ok_or_else(|| error_unreachable())?; let name = CString::new(name).map_err(|_| error_unreachable())?; let fn_type = get_function_type( self.context, func_decl.return_arguments.iter().map(|v| &v.v_type), - func_decl.input_arguments.iter().map(|v| &v.v_type), + func_decl + .input_arguments + .iter() + .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)), )?; let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; if let ast::MethodName::Func(name) = func_decl.name { @@ -239,6 +245,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { let name = self.resolver.get_or_add(param.name); unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; self.resolver.register(param.name, value); + if func_decl.name.is_kernel() { + let attr_kind = unsafe { + LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len()) + }; + let attr = unsafe { + LLVMCreateTypeAttribute( + self.context, + attr_kind, + get_type(self.context, ¶m.v_type)?, + ) + }; + unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) }; + } } let call_conv = if func_decl.name.is_kernel() { Self::kernel_call_convention() @@ -264,12 +283,26 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { } } +fn get_input_argument_type( + context: LLVMContextRef, + v_type: &ptx_parser::Type, + state_space: ptx_parser::StateSpace, +) -> Result { + match state_space { + ptx_parser::StateSpace::ParamEntry => { + Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) }) + } + ptx_parser::StateSpace::Reg => get_type(context, v_type), + _ => return Err(error_unreachable()), + } +} + struct MethodEmitContext<'a, 'input> { context: LLVMContextRef, module: LLVMModuleRef, method: LLVMValueRef, builder: LLVMBuilderRef, - id_defs: &'a GlobalStringIdResolver<'input>, + id_defs: &'a GlobalStringIdentResolver2<'input>, variables_builder: Builder, resolver: &'a mut ResolveIdent, } @@ -533,7 +566,9 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { let type_ = get_function_type( self.context, data.return_arguments.iter().map(|(type_, space)| type_), - data.input_arguments.iter().map(|(type_, space)| type_), + data.input_arguments + .iter() + .map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)), )?; let mut input_arguments = arguments .input_arguments @@ -633,11 +668,10 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR fn get_function_type<'a>( context: LLVMContextRef, mut return_args: impl ExactSizeIterator, - input_args: impl ExactSizeIterator, + input_args: impl ExactSizeIterator>, ) -> Result { - let mut input_args: Vec<*mut llvm_zluda::LLVMType> = input_args - .map(|type_| get_type(context, type_)) - .collect::, _>>()?; + let mut input_args: Vec<*mut llvm_zluda::LLVMType> = + input_args.collect::, _>>()?; let return_type = match return_args.len() { 0 => unsafe { LLVMVoidTypeInContext(context) }, 1 => get_type(context, return_args.next().unwrap())?, @@ -658,7 +692,7 @@ fn get_state_space(space: ast::StateSpace) -> Result { ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE), ast::StateSpace::Param => Err(TranslateError::Todo), - ast::StateSpace::ParamEntry => Err(TranslateError::Todo), + ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE), ast::StateSpace::ParamFunc => Err(TranslateError::Todo), ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE), @@ -675,7 +709,7 @@ struct ResolveIdent { } impl ResolveIdent { - fn new<'input>(_id_defs: &GlobalStringIdResolver<'input>) -> Self { + fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self { ResolveIdent { words: HashMap::new(), values: HashMap::new(), diff --git a/ptx/src/pass/hoist_globals.rs b/ptx/src/pass/hoist_globals.rs new file mode 100644 index 0000000..753172a --- /dev/null +++ b/ptx/src/pass/hoist_globals.rs @@ -0,0 +1,45 @@ +use super::*; + +pub(super) fn run<'input>( + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + let mut result = Vec::with_capacity(directives.len()); + for mut directive in directives.into_iter() { + run_directive(&mut result, &mut directive); + result.push(directive); + } + Ok(result) +} + +fn run_directive<'input>( + result: &mut Vec, SpirvWord>>, + directive: &mut Directive2<'input, ptx_parser::Instruction, SpirvWord>, +) -> Result<(), TranslateError> { + match directive { + Directive2::Variable(..) => {} + Directive2::Method(function2) => run_function(result, function2), + } + Ok(()) +} + +fn run_function<'input>( + result: &mut Vec, SpirvWord>>, + function: &mut Function2<'input, ptx_parser::Instruction, SpirvWord>, +) { + function.body = function.body.take().map(|statements| { + statements + .into_iter() + .filter_map(|statement| match statement { + Statement::Variable(var @ ast::Variable { + state_space: + ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared, + .. + }) => { + result.push(Directive2::Variable(ast::LinkingDirective::NONE, var)); + None + } + s => Some(s), + }) + .collect() + }); +} diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs index e8f01cd..ec6498c 100644 --- a/ptx/src/pass/insert_explicit_load_store.rs +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -41,10 +41,9 @@ fn run_method<'a, 'input>( ) -> Result, SpirvWord>, TranslateError> { let mut func_decl = method.func_decl; for arg in func_decl.return_arguments.iter_mut() { - visitor.visit_variable(arg); + visitor.visit_variable(arg)?; } let is_kernel = func_decl.name.is_kernel(); - // let mut prelude = Vec::with_capacity(method.body.as_ref().map(Vec::len).unwrap_or(0)); if is_kernel { for arg in func_decl.input_arguments.iter_mut() { let old_name = arg.name; @@ -85,23 +84,29 @@ fn run_statement<'a, 'input>( ) -> Result<(), TranslateError> { match statement { Statement::Variable(mut var) => { - visitor.visit_variable(&mut var); + visitor.visit_variable(&mut var)?; result.push(Statement::Variable(var)); } Statement::Instruction(ast::Instruction::Ld { data, arguments }) => { let instruction = visitor.visit_ld(data, arguments)?; let instruction = ast::visit_map(instruction, visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); result.push(Statement::Instruction(instruction)); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); } - Statement::Instruction(ast::Instruction::St { - data, - mut arguments, - }) => { + Statement::Instruction(ast::Instruction::St { data, arguments }) => { let instruction = visitor.visit_st(data, arguments)?; let instruction = ast::visit_map(instruction, visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); result.push(Statement::Instruction(instruction)); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } + s => { + let new_statement = s.visit_map(visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(new_statement); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); } - s => result.push(s.visit_map(visitor)?), } Ok(()) } @@ -109,6 +114,8 @@ fn run_statement<'a, 'input>( struct InsertMemSSAVisitor<'a, 'input> { resolver: &'a mut GlobalStringIdentResolver2<'input>, variables: FxHashMap, + pre: Vec>, + post: Vec>, } impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { @@ -116,6 +123,8 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { Self { resolver, variables: FxHashMap::default(), + pre: Vec::new(), + post: Vec::new(), } } @@ -141,14 +150,20 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { fn variable( &mut self, + type_: &ast::Type, old_name: SpirvWord, new_name: SpirvWord, old_space: ast::StateSpace, ) -> Result<(), TranslateError> { Ok(match old_space { ast::StateSpace::Reg => { - self.variables - .insert(old_name, RemapAction::PreLdPostSt(new_name)); + self.variables.insert( + old_name, + RemapAction::PreLdPostSt { + name: new_name, + type_: type_.clone(), + }, + ); } ast::StateSpace::Param => { self.variables.insert( @@ -182,7 +197,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { ) -> Result, TranslateError> { if let Some(remap) = self.variables.get(&arguments.src1) { match remap { - RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()), + RemapAction::PreLdPostSt { .. } => {} RemapAction::LDStSpaceChange { old_space, new_space, @@ -206,7 +221,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { ) -> Result, TranslateError> { if let Some(remap) = self.variables.get(&arguments.src) { match remap { - RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()), + RemapAction::PreLdPostSt { .. } => {} RemapAction::LDStSpaceChange { old_space, new_space, @@ -223,7 +238,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { Ok(ast::Instruction::Ld { data, arguments }) } - fn visit_variable(&mut self, var: &mut ast::Variable) { + fn visit_variable(&mut self, var: &mut ast::Variable) -> Result<(), TranslateError> { if var.state_space != ast::StateSpace::Local { let old_name = var.name; let old_space = var.state_space; @@ -231,10 +246,11 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { let new_name = self .resolver .register_unnamed(Some((var.v_type.clone(), new_space))); - self.variable(old_name, new_name, old_space); + self.variable(&var.v_type, old_name, new_name, old_space)?; var.name = new_name; var.state_space = new_space; } + Ok(()) } } @@ -243,12 +259,58 @@ impl<'a, 'input> ast::VisitorMap { fn visit( &mut self, - args: SpirvWord, + ident: SpirvWord, type_space: Option<(&ast::Type, ast::StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result { - todo!() + if let Some(remap) = self.variables.get(&ident) { + match remap { + RemapAction::PreLdPostSt { name, type_ } => { + if is_dst { + let temp = self + .resolver + .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg))); + self.post.push(ast::Instruction::St { + data: ast::StData { + state_space: ast::StateSpace::Local, + qualifier: ast::LdStQualifier::Weak, + caching: ast::StCacheOperator::Writethrough, + typ: type_.clone(), + }, + arguments: ast::StArgs { + src1: *name, + src2: temp, + }, + }); + Ok(temp) + } else { + let temp = self + .resolver + .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg))); + self.pre.push(ast::Instruction::Ld { + data: ast::LdDetails { + state_space: ast::StateSpace::Local, + qualifier: ast::LdStQualifier::Weak, + caching: ast::LdCacheOperator::Cached, + typ: type_.clone(), + non_coherent: false, + }, + arguments: ast::LdArgs { + dst: temp, + src: *name, + }, + }); + Ok(temp) + } + } + RemapAction::LDStSpaceChange { .. } => { + return Err(error_mismatched_type()); + } + } + } else { + Ok(ident) + } } fn visit_ident( @@ -262,9 +324,12 @@ impl<'a, 'input> ast::VisitorMap } } -#[derive(Clone, Copy)] +#[derive(Clone)] enum RemapAction { - PreLdPostSt(SpirvWord), + PreLdPostSt { + name: SpirvWord, + type_: ast::Type, + }, LDStSpaceChange { old_space: ast::StateSpace, new_space: ast::StateSpace, diff --git a/ptx/src/pass/insert_implicit_conversions2.rs b/ptx/src/pass/insert_implicit_conversions2.rs new file mode 100644 index 0000000..4f738f5 --- /dev/null +++ b/ptx/src/pass/insert_implicit_conversions2.rs @@ -0,0 +1,426 @@ +use std::mem; + +use super::*; +use ptx_parser as ast; + +/* + There are several kinds of implicit conversions in PTX: + * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands + * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size + - ld.param: not documented, but for instruction `ld.param. x, [y]`, + semantics are to first zext/chop/bitcast `y` as needed and then do + documented special ld/st/cvt conversion rules for destination operands + - st.param [x] y (used as function return arguments) same rule as above applies + - generic/global ld: for instruction `ld x, [y]`, y must be of type + b64/u64/s64, which is bitcast to a pointer, dereferenced and then + documented special ld/st/cvt conversion rules are applied to dst + - generic/global st: for instruction `st [x], y`, x must be of type + b64/u64/s64, which is bitcast to a pointer +*/ +pub(super) fn run<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directive: Directive2<'input, ast::Instruction, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(mut method) => { + method.body = method + .body + .map(|statements| run_statements(resolver, statements)) + .transpose()?; + Directive2::Method(method) + } + }) +} + +fn run_statements<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + func: Vec, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(func.len()); + for s in func.into_iter() { + insert_implicit_conversions_impl(resolver, &mut result, s)?; + } + Ok(result) +} + +fn insert_implicit_conversions_impl<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + func: &mut Vec, + stmt: ExpandedStatement, +) -> Result<(), TranslateError> { + let mut post_conv = Vec::new(); + let statement = stmt.visit_map::( + &mut |operand, + type_state: Option<(&ast::Type, ast::StateSpace)>, + is_dst, + relaxed_type_check| { + let (instr_type, instruction_space) = match type_state { + None => return Ok(operand), + Some(t) => t, + }; + let (operand_type, operand_space) = resolver.get_typed(operand)?; + let conversion_fn = if relaxed_type_check { + if is_dst { + should_convert_relaxed_dst_wrapper + } else { + should_convert_relaxed_src_wrapper + } + } else { + default_implicit_conversion + }; + match conversion_fn( + (*operand_space, &operand_type), + (instruction_space, instr_type), + )? { + Some(conv_kind) => { + let conv_output = if is_dst { &mut post_conv } else { &mut *func }; + let mut from_type = instr_type.clone(); + let mut from_space = instruction_space; + let mut to_type = operand_type.clone(); + let mut to_space = *operand_space; + let mut src = + resolver.register_unnamed(Some((instr_type.clone(), instruction_space))); + let mut dst = operand; + let result = Ok::<_, TranslateError>(src); + if !is_dst { + mem::swap(&mut src, &mut dst); + mem::swap(&mut from_type, &mut to_type); + mem::swap(&mut from_space, &mut to_space); + } + conv_output.push(Statement::Conversion(ImplicitConversion { + src, + dst, + from_type, + from_space, + to_type, + to_space, + kind: conv_kind, + })); + result + } + None => Ok(operand), + } + }, + )?; + func.push(statement); + func.append(&mut post_conv); + Ok(()) +} + +pub(crate) fn default_implicit_conversion( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if instruction_space == ast::StateSpace::Reg { + if operand_space == ast::StateSpace::Reg { + if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) = + (operand_type, instruction_type) + { + if scalar.kind() == ast::ScalarKind::Bit + && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) + { + return Ok(Some(ConversionKind::Default)); + } + } + } else if is_addressable(operand_space) { + return Ok(Some(ConversionKind::AddressOf)); + } + } + if instruction_space != operand_space { + default_implicit_conversion_space( + (operand_space, operand_type), + (instruction_space, instruction_type), + ) + } else if instruction_type != operand_type { + default_implicit_conversion_type(instruction_space, operand_type, instruction_type) + } else { + Ok(None) + } +} + +fn is_addressable(this: ast::StateSpace) -> bool { + match this { + ast::StateSpace::Const + | ast::StateSpace::Generic + | ast::StateSpace::Global + | ast::StateSpace::Local + | ast::StateSpace::Shared => true, + ast::StateSpace::Param | ast::StateSpace::Reg => false, + ast::StateSpace::SharedCluster + | ast::StateSpace::SharedCta + | ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc => todo!(), + } +} + +// Space is different +fn default_implicit_conversion_space( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space)) + || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space)) + { + Ok(Some(ConversionKind::PtrToPtr)) + } else if operand_space == ast::StateSpace::Reg { + match operand_type { + ast::Type::Pointer(operand_ptr_type, operand_ptr_space) + if *operand_ptr_space == instruction_space => + { + if instruction_type != &ast::Type::Scalar(*operand_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } + // TODO: 32 bit + ast::Type::Scalar(ast::ScalarType::B64) + | ast::Type::Scalar(ast::ScalarType::U64) + | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space { + ast::StateSpace::Global + | ast::StateSpace::Generic + | ast::StateSpace::Const + | ast::StateSpace::Local + | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), + _ => Err(error_mismatched_type()), + }, + ast::Type::Scalar(ast::ScalarType::B32) + | ast::Type::Scalar(ast::ScalarType::U32) + | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { + ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { + Ok(Some(ConversionKind::BitToPtr)) + } + _ => Err(error_mismatched_type()), + }, + _ => Err(error_mismatched_type()), + } + } else if instruction_space == ast::StateSpace::Reg { + match instruction_type { + ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space) + if operand_space == *instruction_ptr_space => + { + if operand_type != &ast::Type::Scalar(*instruction_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } + _ => Err(error_mismatched_type()), + } + } else { + Err(error_mismatched_type()) + } +} + +// Space is same, but type is different +fn default_implicit_conversion_type( + space: ast::StateSpace, + operand_type: &ast::Type, + instruction_type: &ast::Type, +) -> Result, TranslateError> { + if space == ast::StateSpace::Reg { + if should_bitcast(instruction_type, operand_type) { + Ok(Some(ConversionKind::Default)) + } else { + Err(TranslateError::MismatchedType) + } + } else { + Ok(Some(ConversionKind::PtrToPtr)) + } +} + +fn coerces_to_generic(this: ast::StateSpace) -> bool { + match this { + ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::Local + | ptx_parser::StateSpace::SharedCta + | ast::StateSpace::SharedCluster + | ast::StateSpace::Shared => true, + ast::StateSpace::Reg + | ast::StateSpace::Param + | ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc + | ast::StateSpace::Generic => false, + } +} + +fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { + match (instr, operand) { + (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { + if inst.size_of() != operand.size_of() { + return false; + } + match inst.kind() { + ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit, + ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit, + ast::ScalarKind::Signed => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Unsigned + } + ast::ScalarKind::Unsigned => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Signed + } + ast::ScalarKind::Pred => false, + } + } + (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand)) + | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => { + should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand)) + } + _ => false, + } +} + +pub(crate) fn should_convert_relaxed_dst_wrapper( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if operand_space != instruction_space { + return Err(TranslateError::MismatchedType); + } + if operand_type == instruction_type { + return Ok(None); + } + match should_convert_relaxed_dst(operand_type, instruction_type) { + conv @ Some(_) => Ok(conv), + None => Err(TranslateError::MismatchedType), + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands +fn should_convert_relaxed_dst( + dst_type: &ast::Type, + instr_type: &ast::Type, +) -> Option { + if dst_type == instr_type { + return None; + } + match (dst_type, instr_type) { + (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { + ast::ScalarKind::Bit => { + if instr_type.size_of() <= dst_type.size_of() { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Signed => { + if dst_type.kind() != ast::ScalarKind::Float { + if instr_type.size_of() == dst_type.size_of() { + Some(ConversionKind::Default) + } else if instr_type.size_of() < dst_type.size_of() { + Some(ConversionKind::SignExtend) + } else { + None + } + } else { + None + } + } + ast::ScalarKind::Unsigned => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() != ast::ScalarKind::Float + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Float => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() == ast::ScalarKind::Bit + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Pred => None, + }, + (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) + | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { + should_convert_relaxed_dst( + &ast::Type::Scalar(*dst_type), + &ast::Type::Scalar(*instr_type), + ) + } + _ => None, + } +} + +pub(crate) fn should_convert_relaxed_src_wrapper( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if operand_space != instruction_space { + return Err(error_mismatched_type()); + } + if operand_type == instruction_type { + return Ok(None); + } + match should_convert_relaxed_src(operand_type, instruction_type) { + conv @ Some(_) => Ok(conv), + None => Err(error_mismatched_type()), + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands +fn should_convert_relaxed_src( + src_type: &ast::Type, + instr_type: &ast::Type, +) -> Option { + if src_type == instr_type { + return None; + } + match (src_type, instr_type) { + (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { + ast::ScalarKind::Bit => { + if instr_type.size_of() <= src_type.size_of() { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() != ast::ScalarKind::Float + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Float => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() == ast::ScalarKind::Bit + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Pred => None, + }, + (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) + | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { + should_convert_relaxed_src( + &ast::Type::Scalar(*dst_type), + &ast::Type::Scalar(*instr_type), + ) + } + _ => None, + } +} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index b82d3c5..0e233ed 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -27,8 +27,10 @@ mod expand_operands; mod extract_globals; mod fix_special_registers; mod fix_special_registers2; +mod hoist_globals; mod insert_explicit_load_store; mod insert_implicit_conversions; +mod insert_implicit_conversions2; mod insert_mem_ssa_statements; mod normalize_identifiers; mod normalize_identifiers2; @@ -67,11 +69,13 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result(ast: ast::Module<'input>) -> Result { @@ -82,10 +86,17 @@ pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result, SpirvWord>> = + expand_operands::run(&mut flat_resolver, directives)?; let directives = deparamize_functions::run(&mut flat_resolver, directives)?; let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; - todo!() + let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?; + let directives = hoist_globals::run(directives)?; + let llvm_ir = emit_llvm::run(flat_resolver, directives)?; + Ok(Module { + llvm_ir, + kernel_info: HashMap::new(), + }) } fn translate_directive<'input, 'a>( diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 69dd206..e15d6ea 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -236,7 +236,7 @@ fn test_hip_assert< output: &mut [Output], ) -> Result<(), Box> { let ast = ptx_parser::parse_module_checked(ptx_text).unwrap(); - let llvm_ir = pass::to_llvm_module(ast).unwrap(); + let llvm_ir = pass::to_llvm_module2(ast).unwrap(); let name = CString::new(name)?; let result = run_hip(name.as_c_str(), llvm_ir, input, output).map_err(|err| DisplayError { err })?; -- cgit v1.2.3