diff options
Diffstat (limited to 'ptx/src/pass/emit_llvm.rs')
-rw-r--r-- | ptx/src/pass/emit_llvm.rs | 158 |
1 files changed, 93 insertions, 65 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 44debba..6df162f 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -225,11 +225,15 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id],
});
let name = CString::new(name).map_err(|_| error_unreachable())?;
- let fn_type = self.function_type(
+ let fn_type = get_function_type(
+ self.context,
func_decl.return_arguments.iter().map(|v| &v.v_type),
func_decl.input_arguments.iter().map(|v| &v.v_type),
- );
+ )?;
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
+ if let ast::MethodName::Func(name) = func_decl.name {
+ self.resolver.register(name, fn_);
+ }
for (i, param) in func_decl.input_arguments.iter().enumerate() {
let value = unsafe { LLVMGetParam(fn_, i as u32) };
let name = self.resolver.get_or_add(param.name);
@@ -258,67 +262,6 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { }
Ok(())
}
-
- fn function_type(
- &self,
- return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
- input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
- ) -> LLVMTypeRef {
- if return_args.len() == 0 {
- let mut input_args = input_args
- .map(|type_| match type_ {
- ast::Type::Scalar(scalar) => match scalar {
- ast::ScalarType::Pred => {
- unsafe { LLVMInt1TypeInContext(self.context) }
- }
- ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => {
- unsafe { LLVMInt8TypeInContext(self.context) }
- }
- ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
- unsafe { LLVMInt16TypeInContext(self.context) }
- }
- ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => {
- unsafe { LLVMInt32TypeInContext(self.context) }
- }
- ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => {
- unsafe { LLVMInt64TypeInContext(self.context) }
- }
- ast::ScalarType::B128 => {
- unsafe { LLVMInt128TypeInContext(self.context) }
- }
- ast::ScalarType::F16 => {
- unsafe { LLVMHalfTypeInContext(self.context) }
- }
- ast::ScalarType::F32 => {
- unsafe { LLVMFloatTypeInContext(self.context) }
- }
- ast::ScalarType::F64 => {
- unsafe { LLVMDoubleTypeInContext(self.context) }
- }
- ast::ScalarType::BF16 => {
- unsafe { LLVMBFloatTypeInContext(self.context) }
- }
- ast::ScalarType::U16x2 => todo!(),
- ast::ScalarType::S16x2 => todo!(),
- ast::ScalarType::F16x2 => todo!(),
- ast::ScalarType::BF16x2 => todo!(),
- },
- ast::Type::Vector(_, _) => todo!(),
- ast::Type::Array(_, _, _) => todo!(),
- ast::Type::Pointer(_, _) => todo!(),
- })
- .collect::<Vec<_>>();
- return unsafe {
- LLVMFunctionType(
- LLVMVoidTypeInContext(self.context),
- input_args.as_mut_ptr(),
- input_args.len() as u32,
- 0,
- )
- };
- }
- todo!()
- }
}
struct MethodEmitContext<'a, 'input> {
@@ -414,7 +357,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { inst: ast::Instruction<SpirvWord>,
) -> Result<(), TranslateError> {
match inst {
- ast::Instruction::Mov { data, arguments } => todo!(),
+ ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments),
ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
@@ -425,7 +368,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Or { data, arguments } => todo!(),
ast::Instruction::And { data, arguments } => todo!(),
ast::Instruction::Bra { arguments } => todo!(),
- ast::Instruction::Call { data, arguments } => todo!(),
+ ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
ast::Instruction::Cvt { data, arguments } => todo!(),
ast::Instruction::Shr { data, arguments } => todo!(),
ast::Instruction::Shl { data, arguments } => todo!(),
@@ -563,6 +506,68 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_ret(&self, _data: ptx_parser::RetData) {
unsafe { LLVMBuildRetVoid(self.builder) };
}
+
+ fn emit_call(
+ &mut self,
+ data: ptx_parser::CallDetails,
+ arguments: ptx_parser::CallArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ if cfg!(debug_assertions) {
+ for (_, space) in data.return_arguments.iter() {
+ if *space != ast::StateSpace::Reg {
+ panic!()
+ }
+ }
+ for (_, space) in data.input_arguments.iter() {
+ if *space != ast::StateSpace::Reg {
+ panic!()
+ }
+ }
+ }
+ let name = match (&*data.return_arguments, &*arguments.return_arguments) {
+ ([], []) => LLVM_UNNAMED.as_ptr(),
+ ([(type_, _)], [dst]) => self.resolver.get_or_add_raw(*dst),
+ _ => todo!(),
+ };
+ let type_ = get_function_type(
+ self.context,
+ data.return_arguments.iter().map(|(type_, space)| type_),
+ data.input_arguments.iter().map(|(type_, space)| type_),
+ )?;
+ let mut input_arguments = arguments
+ .input_arguments
+ .iter()
+ .map(|arg| self.resolver.value(*arg))
+ .collect::<Result<Vec<_>, _>>()?;
+ let llvm_fn = unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ type_,
+ self.resolver.value(arguments.func)?,
+ input_arguments.as_mut_ptr(),
+ input_arguments.len() as u32,
+ name,
+ )
+ };
+ match &*arguments.return_arguments {
+ [] => {}
+ [name] => {
+ self.resolver.register(*name, llvm_fn);
+ }
+ _ => todo!(),
+ }
+ Ok(())
+ }
+
+ fn emit_mov(
+ &mut self,
+ _data: ptx_parser::MovDetails,
+ arguments: ptx_parser::MovArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ self.resolver
+ .register(arguments.dst, self.resolver.value(arguments.src)?);
+ Ok(())
+ }
}
fn get_pointer_type<'ctx>(
@@ -624,6 +629,29 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR }
}
+fn get_function_type<'a>(
+ context: LLVMContextRef,
+ mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
+ input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
+) -> Result<LLVMTypeRef, TranslateError> {
+ let mut input_args: Vec<*mut llvm_zluda::LLVMType> = input_args
+ .map(|type_| get_type(context, type_))
+ .collect::<Result<Vec<_>, _>>()?;
+ let return_type = match return_args.len() {
+ 0 => unsafe { LLVMVoidTypeInContext(context) },
+ 1 => get_type(context, return_args.next().unwrap())?,
+ _ => todo!(),
+ };
+ Ok(unsafe {
+ LLVMFunctionType(
+ return_type,
+ input_args.as_mut_ptr(),
+ input_args.len() as u32,
+ 0,
+ )
+ })
+}
+
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
match space {
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
|