diff options
Diffstat (limited to 'ptx/src/pass/emit_llvm.rs')
-rw-r--r-- | ptx/src/pass/emit_llvm.rs | 224 |
1 files changed, 143 insertions, 81 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 44debba..235ad7d 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -164,17 +164,16 @@ impl Deref for MemoryBuffer { }
pub(super) fn run<'input>(
- id_defs: &GlobalStringIdResolver<'input>,
- call_map: MethodsCallMap<'input>,
- directives: Vec<Directive<'input>>,
+ id_defs: GlobalStringIdentResolver2<'input>,
+ directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<MemoryBuffer, TranslateError> {
let context = Context::new();
let module = Module::new(&context, LLVM_UNNAMED);
- let mut emit_ctx = ModuleEmitContext::new(&context, &module, id_defs);
+ let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
for directive in directives {
match directive {
- Directive::Variable(..) => todo!(),
- Directive::Method(method) => emit_ctx.emit_method(method)?,
+ Directive2::Variable(..) => todo!(),
+ Directive2::Method(method) => emit_ctx.emit_method(method)?,
}
}
module.write_to_stderr();
@@ -188,7 +187,7 @@ struct ModuleEmitContext<'a, 'input> { context: LLVMContextRef,
module: LLVMModuleRef,
builder: Builder,
- id_defs: &'a GlobalStringIdResolver<'input>,
+ id_defs: &'a GlobalStringIdentResolver2<'input>,
resolver: ResolveIdent,
}
@@ -196,7 +195,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { fn new(
context: &Context,
module: &Module,
- id_defs: &'a GlobalStringIdResolver<'input>,
+ id_defs: &'a GlobalStringIdentResolver2<'input>,
) -> Self {
ModuleEmitContext {
context: context.get(),
@@ -215,26 +214,50 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { LLVMCallConv::LLVMCCallConv as u32
}
- fn emit_method(&mut self, method: Function<'input>) -> Result<(), TranslateError> {
- let func_decl = method.func_decl.borrow();
+ fn emit_method(
+ &mut self,
+ method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let func_decl = method.func_decl;
let name = method
.import_as
.as_deref()
- .unwrap_or_else(|| match func_decl.name {
- ast::MethodName::Kernel(name) => name,
- ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id],
- });
+ .or_else(|| match func_decl.name {
+ ast::MethodName::Kernel(name) => Some(name),
+ ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(),
+ })
+ .ok_or_else(|| error_unreachable())?;
let name = CString::new(name).map_err(|_| error_unreachable())?;
- let fn_type = self.function_type(
+ let fn_type = get_function_type(
+ self.context,
func_decl.return_arguments.iter().map(|v| &v.v_type),
- func_decl.input_arguments.iter().map(|v| &v.v_type),
- );
+ func_decl
+ .input_arguments
+ .iter()
+ .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
+ )?;
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
+ if let ast::MethodName::Func(name) = func_decl.name {
+ self.resolver.register(name, fn_);
+ }
for (i, param) in func_decl.input_arguments.iter().enumerate() {
let value = unsafe { LLVMGetParam(fn_, i as u32) };
let name = self.resolver.get_or_add(param.name);
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
self.resolver.register(param.name, value);
+ if func_decl.name.is_kernel() {
+ let attr_kind = unsafe {
+ LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len())
+ };
+ let attr = unsafe {
+ LLVMCreateTypeAttribute(
+ self.context,
+ attr_kind,
+ get_type(self.context, ¶m.v_type)?,
+ )
+ };
+ unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
+ }
}
let call_conv = if func_decl.name.is_kernel() {
Self::kernel_call_convention()
@@ -258,66 +281,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { }
Ok(())
}
+}
- fn function_type(
- &self,
- return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
- input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
- ) -> LLVMTypeRef {
- if return_args.len() == 0 {
- let mut input_args = input_args
- .map(|type_| match type_ {
- ast::Type::Scalar(scalar) => match scalar {
- ast::ScalarType::Pred => {
- unsafe { LLVMInt1TypeInContext(self.context) }
- }
- ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => {
- unsafe { LLVMInt8TypeInContext(self.context) }
- }
- ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
- unsafe { LLVMInt16TypeInContext(self.context) }
- }
- ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => {
- unsafe { LLVMInt32TypeInContext(self.context) }
- }
- ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => {
- unsafe { LLVMInt64TypeInContext(self.context) }
- }
- ast::ScalarType::B128 => {
- unsafe { LLVMInt128TypeInContext(self.context) }
- }
- ast::ScalarType::F16 => {
- unsafe { LLVMHalfTypeInContext(self.context) }
- }
- ast::ScalarType::F32 => {
- unsafe { LLVMFloatTypeInContext(self.context) }
- }
- ast::ScalarType::F64 => {
- unsafe { LLVMDoubleTypeInContext(self.context) }
- }
- ast::ScalarType::BF16 => {
- unsafe { LLVMBFloatTypeInContext(self.context) }
- }
- ast::ScalarType::U16x2 => todo!(),
- ast::ScalarType::S16x2 => todo!(),
- ast::ScalarType::F16x2 => todo!(),
- ast::ScalarType::BF16x2 => todo!(),
- },
- ast::Type::Vector(_, _) => todo!(),
- ast::Type::Array(_, _, _) => todo!(),
- ast::Type::Pointer(_, _) => todo!(),
- })
- .collect::<Vec<_>>();
- return unsafe {
- LLVMFunctionType(
- LLVMVoidTypeInContext(self.context),
- input_args.as_mut_ptr(),
- input_args.len() as u32,
- 0,
- )
- };
+fn get_input_argument_type(
+ context: LLVMContextRef,
+ v_type: &ptx_parser::Type,
+ state_space: ptx_parser::StateSpace,
+) -> Result<LLVMTypeRef, TranslateError> {
+ match state_space {
+ ptx_parser::StateSpace::ParamEntry => {
+ Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) })
}
- todo!()
+ ptx_parser::StateSpace::Reg => get_type(context, v_type),
+ _ => return Err(error_unreachable()),
}
}
@@ -326,7 +302,7 @@ struct MethodEmitContext<'a, 'input> { module: LLVMModuleRef,
method: LLVMValueRef,
builder: LLVMBuilderRef,
- id_defs: &'a GlobalStringIdResolver<'input>,
+ id_defs: &'a GlobalStringIdentResolver2<'input>,
variables_builder: Builder,
resolver: &'a mut ResolveIdent,
}
@@ -365,6 +341,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Statement::PtrAccess(_) => todo!(),
Statement::RepackVector(_) => todo!(),
Statement::FunctionPointer(_) => todo!(),
+ Statement::VectorAccess(_) => todo!(),
})
}
@@ -414,7 +391,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { inst: ast::Instruction<SpirvWord>,
) -> Result<(), TranslateError> {
match inst {
- ast::Instruction::Mov { data, arguments } => todo!(),
+ ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments),
ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments),
ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments),
ast::Instruction::St { data, arguments } => self.emit_st(data, arguments),
@@ -425,7 +402,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Or { data, arguments } => todo!(),
ast::Instruction::And { data, arguments } => todo!(),
ast::Instruction::Bra { arguments } => todo!(),
- ast::Instruction::Call { data, arguments } => todo!(),
+ ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments),
ast::Instruction::Cvt { data, arguments } => todo!(),
ast::Instruction::Shr { data, arguments } => todo!(),
ast::Instruction::Shl { data, arguments } => todo!(),
@@ -563,6 +540,70 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_ret(&self, _data: ptx_parser::RetData) {
unsafe { LLVMBuildRetVoid(self.builder) };
}
+
+ fn emit_call(
+ &mut self,
+ data: ptx_parser::CallDetails,
+ arguments: ptx_parser::CallArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ if cfg!(debug_assertions) {
+ for (_, space) in data.return_arguments.iter() {
+ if *space != ast::StateSpace::Reg {
+ panic!()
+ }
+ }
+ for (_, space) in data.input_arguments.iter() {
+ if *space != ast::StateSpace::Reg {
+ panic!()
+ }
+ }
+ }
+ let name = match (&*data.return_arguments, &*arguments.return_arguments) {
+ ([], []) => LLVM_UNNAMED.as_ptr(),
+ ([(type_, _)], [dst]) => self.resolver.get_or_add_raw(*dst),
+ _ => todo!(),
+ };
+ let type_ = get_function_type(
+ self.context,
+ data.return_arguments.iter().map(|(type_, space)| type_),
+ data.input_arguments
+ .iter()
+ .map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)),
+ )?;
+ let mut input_arguments = arguments
+ .input_arguments
+ .iter()
+ .map(|arg| self.resolver.value(*arg))
+ .collect::<Result<Vec<_>, _>>()?;
+ let llvm_fn = unsafe {
+ LLVMBuildCall2(
+ self.builder,
+ type_,
+ self.resolver.value(arguments.func)?,
+ input_arguments.as_mut_ptr(),
+ input_arguments.len() as u32,
+ name,
+ )
+ };
+ match &*arguments.return_arguments {
+ [] => {}
+ [name] => {
+ self.resolver.register(*name, llvm_fn);
+ }
+ _ => todo!(),
+ }
+ Ok(())
+ }
+
+ fn emit_mov(
+ &mut self,
+ _data: ptx_parser::MovDetails,
+ arguments: ptx_parser::MovArgs<SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ self.resolver
+ .register(arguments.dst, self.resolver.value(arguments.src)?);
+ Ok(())
+ }
}
fn get_pointer_type<'ctx>(
@@ -624,13 +665,34 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR }
}
+fn get_function_type<'a>(
+ context: LLVMContextRef,
+ mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
+ input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
+) -> Result<LLVMTypeRef, TranslateError> {
+ let mut input_args: Vec<*mut llvm_zluda::LLVMType> =
+ input_args.collect::<Result<Vec<_>, _>>()?;
+ let return_type = match return_args.len() {
+ 0 => unsafe { LLVMVoidTypeInContext(context) },
+ 1 => get_type(context, return_args.next().unwrap())?,
+ _ => todo!(),
+ };
+ Ok(unsafe {
+ LLVMFunctionType(
+ return_type,
+ input_args.as_mut_ptr(),
+ input_args.len() as u32,
+ 0,
+ )
+ })
+}
+
fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
match space {
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
- ast::StateSpace::Sreg => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Param => Err(TranslateError::Todo),
- ast::StateSpace::ParamEntry => Err(TranslateError::Todo),
+ ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE),
ast::StateSpace::ParamFunc => Err(TranslateError::Todo),
ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
@@ -647,7 +709,7 @@ struct ResolveIdent { }
impl ResolveIdent {
- fn new<'input>(_id_defs: &GlobalStringIdResolver<'input>) -> Self {
+ fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self {
ResolveIdent {
words: HashMap::new(),
values: HashMap::new(),
|