From 02cf83ebb9254aed8184aeb77d3f66197d665570 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Fri, 13 Sep 2024 19:40:58 +0200 Subject: Add mov and call support --- ptx/src/pass/emit_llvm.rs | 158 +++++++++++++++++++++++++++------------------- 1 file changed, 93 insertions(+), 65 deletions(-) diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 44debba..6df162f 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -225,11 +225,15 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id], }); let name = CString::new(name).map_err(|_| error_unreachable())?; - let fn_type = self.function_type( + let fn_type = get_function_type( + self.context, func_decl.return_arguments.iter().map(|v| &v.v_type), func_decl.input_arguments.iter().map(|v| &v.v_type), - ); + )?; let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; + if let ast::MethodName::Func(name) = func_decl.name { + self.resolver.register(name, fn_); + } for (i, param) in func_decl.input_arguments.iter().enumerate() { let value = unsafe { LLVMGetParam(fn_, i as u32) }; let name = self.resolver.get_or_add(param.name); @@ -258,67 +262,6 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { } Ok(()) } - - fn function_type( - &self, - return_args: impl ExactSizeIterator, - input_args: impl ExactSizeIterator, - ) -> LLVMTypeRef { - if return_args.len() == 0 { - let mut input_args = input_args - .map(|type_| match type_ { - ast::Type::Scalar(scalar) => match scalar { - ast::ScalarType::Pred => { - unsafe { LLVMInt1TypeInContext(self.context) } - } - ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => { - unsafe { LLVMInt8TypeInContext(self.context) } - } - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => { - unsafe { LLVMInt16TypeInContext(self.context) } - } - ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => { - unsafe { LLVMInt32TypeInContext(self.context) } - } - ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => { - unsafe { LLVMInt64TypeInContext(self.context) } - } - ast::ScalarType::B128 => { - unsafe { LLVMInt128TypeInContext(self.context) } - } - ast::ScalarType::F16 => { - unsafe { LLVMHalfTypeInContext(self.context) } - } - ast::ScalarType::F32 => { - unsafe { LLVMFloatTypeInContext(self.context) } - } - ast::ScalarType::F64 => { - unsafe { LLVMDoubleTypeInContext(self.context) } - } - ast::ScalarType::BF16 => { - unsafe { LLVMBFloatTypeInContext(self.context) } - } - ast::ScalarType::U16x2 => todo!(), - ast::ScalarType::S16x2 => todo!(), - ast::ScalarType::F16x2 => todo!(), - ast::ScalarType::BF16x2 => todo!(), - }, - ast::Type::Vector(_, _) => todo!(), - ast::Type::Array(_, _, _) => todo!(), - ast::Type::Pointer(_, _) => todo!(), - }) - .collect::>(); - return unsafe { - LLVMFunctionType( - LLVMVoidTypeInContext(self.context), - input_args.as_mut_ptr(), - input_args.len() as u32, - 0, - ) - }; - } - todo!() - } } struct MethodEmitContext<'a, 'input> { @@ -414,7 +357,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { inst: ast::Instruction, ) -> Result<(), TranslateError> { match inst { - ast::Instruction::Mov { data, arguments } => todo!(), + ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments), ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments), ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments), ast::Instruction::St { data, arguments } => self.emit_st(data, arguments), @@ -425,7 +368,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Or { data, arguments } => todo!(), ast::Instruction::And { data, arguments } => todo!(), ast::Instruction::Bra { arguments } => todo!(), - ast::Instruction::Call { data, arguments } => todo!(), + ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments), ast::Instruction::Cvt { data, arguments } => todo!(), ast::Instruction::Shr { data, arguments } => todo!(), ast::Instruction::Shl { data, arguments } => todo!(), @@ -563,6 +506,68 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_ret(&self, _data: ptx_parser::RetData) { unsafe { LLVMBuildRetVoid(self.builder) }; } + + fn emit_call( + &mut self, + data: ptx_parser::CallDetails, + arguments: ptx_parser::CallArgs, + ) -> Result<(), TranslateError> { + if cfg!(debug_assertions) { + for (_, space) in data.return_arguments.iter() { + if *space != ast::StateSpace::Reg { + panic!() + } + } + for (_, space) in data.input_arguments.iter() { + if *space != ast::StateSpace::Reg { + panic!() + } + } + } + let name = match (&*data.return_arguments, &*arguments.return_arguments) { + ([], []) => LLVM_UNNAMED.as_ptr(), + ([(type_, _)], [dst]) => self.resolver.get_or_add_raw(*dst), + _ => todo!(), + }; + let type_ = get_function_type( + self.context, + data.return_arguments.iter().map(|(type_, space)| type_), + data.input_arguments.iter().map(|(type_, space)| type_), + )?; + let mut input_arguments = arguments + .input_arguments + .iter() + .map(|arg| self.resolver.value(*arg)) + .collect::, _>>()?; + let llvm_fn = unsafe { + LLVMBuildCall2( + self.builder, + type_, + self.resolver.value(arguments.func)?, + input_arguments.as_mut_ptr(), + input_arguments.len() as u32, + name, + ) + }; + match &*arguments.return_arguments { + [] => {} + [name] => { + self.resolver.register(*name, llvm_fn); + } + _ => todo!(), + } + Ok(()) + } + + fn emit_mov( + &mut self, + _data: ptx_parser::MovDetails, + arguments: ptx_parser::MovArgs, + ) -> Result<(), TranslateError> { + self.resolver + .register(arguments.dst, self.resolver.value(arguments.src)?); + Ok(()) + } } fn get_pointer_type<'ctx>( @@ -624,6 +629,29 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR } } +fn get_function_type<'a>( + context: LLVMContextRef, + mut return_args: impl ExactSizeIterator, + input_args: impl ExactSizeIterator, +) -> Result { + let mut input_args: Vec<*mut llvm_zluda::LLVMType> = input_args + .map(|type_| get_type(context, type_)) + .collect::, _>>()?; + let return_type = match return_args.len() { + 0 => unsafe { LLVMVoidTypeInContext(context) }, + 1 => get_type(context, return_args.next().unwrap())?, + _ => todo!(), + }; + Ok(unsafe { + LLVMFunctionType( + return_type, + input_args.as_mut_ptr(), + input_args.len() as u32, + 0, + ) + }) +} + fn get_state_space(space: ast::StateSpace) -> Result { match space { ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), -- cgit v1.2.3