aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/pass/emit_llvm.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/pass/emit_llvm.rs')
-rw-r--r--ptx/src/pass/emit_llvm.rs80
1 files changed, 57 insertions, 23 deletions
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs
index 3060335..235ad7d 100644
--- a/ptx/src/pass/emit_llvm.rs
+++ b/ptx/src/pass/emit_llvm.rs
@@ -164,17 +164,16 @@ impl Deref for MemoryBuffer {
}
pub(super) fn run<'input>(
- id_defs: &GlobalStringIdResolver<'input>,
- call_map: MethodsCallMap<'input>,
- directives: Vec<Directive<'input>>,
+ id_defs: GlobalStringIdentResolver2<'input>,
+ directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
) -> Result<MemoryBuffer, TranslateError> {
let context = Context::new();
let module = Module::new(&context, LLVM_UNNAMED);
- let mut emit_ctx = ModuleEmitContext::new(&context, &module, id_defs);
+ let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs);
for directive in directives {
match directive {
- Directive::Variable(..) => todo!(),
- Directive::Method(method) => emit_ctx.emit_method(method)?,
+ Directive2::Variable(..) => todo!(),
+ Directive2::Method(method) => emit_ctx.emit_method(method)?,
}
}
module.write_to_stderr();
@@ -188,7 +187,7 @@ struct ModuleEmitContext<'a, 'input> {
context: LLVMContextRef,
module: LLVMModuleRef,
builder: Builder,
- id_defs: &'a GlobalStringIdResolver<'input>,
+ id_defs: &'a GlobalStringIdentResolver2<'input>,
resolver: ResolveIdent,
}
@@ -196,7 +195,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
fn new(
context: &Context,
module: &Module,
- id_defs: &'a GlobalStringIdResolver<'input>,
+ id_defs: &'a GlobalStringIdentResolver2<'input>,
) -> Self {
ModuleEmitContext {
context: context.get(),
@@ -215,20 +214,27 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
LLVMCallConv::LLVMCCallConv as u32
}
- fn emit_method(&mut self, method: Function<'input>) -> Result<(), TranslateError> {
- let func_decl = method.func_decl.borrow();
+ fn emit_method(
+ &mut self,
+ method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+ ) -> Result<(), TranslateError> {
+ let func_decl = method.func_decl;
let name = method
.import_as
.as_deref()
- .unwrap_or_else(|| match func_decl.name {
- ast::MethodName::Kernel(name) => name,
- ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id],
- });
+ .or_else(|| match func_decl.name {
+ ast::MethodName::Kernel(name) => Some(name),
+ ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(),
+ })
+ .ok_or_else(|| error_unreachable())?;
let name = CString::new(name).map_err(|_| error_unreachable())?;
let fn_type = get_function_type(
self.context,
func_decl.return_arguments.iter().map(|v| &v.v_type),
- func_decl.input_arguments.iter().map(|v| &v.v_type),
+ func_decl
+ .input_arguments
+ .iter()
+ .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)),
)?;
let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) };
if let ast::MethodName::Func(name) = func_decl.name {
@@ -239,6 +245,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
let name = self.resolver.get_or_add(param.name);
unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) };
self.resolver.register(param.name, value);
+ if func_decl.name.is_kernel() {
+ let attr_kind = unsafe {
+ LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len())
+ };
+ let attr = unsafe {
+ LLVMCreateTypeAttribute(
+ self.context,
+ attr_kind,
+ get_type(self.context, &param.v_type)?,
+ )
+ };
+ unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) };
+ }
}
let call_conv = if func_decl.name.is_kernel() {
Self::kernel_call_convention()
@@ -264,12 +283,26 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> {
}
}
+fn get_input_argument_type(
+ context: LLVMContextRef,
+ v_type: &ptx_parser::Type,
+ state_space: ptx_parser::StateSpace,
+) -> Result<LLVMTypeRef, TranslateError> {
+ match state_space {
+ ptx_parser::StateSpace::ParamEntry => {
+ Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) })
+ }
+ ptx_parser::StateSpace::Reg => get_type(context, v_type),
+ _ => return Err(error_unreachable()),
+ }
+}
+
struct MethodEmitContext<'a, 'input> {
context: LLVMContextRef,
module: LLVMModuleRef,
method: LLVMValueRef,
builder: LLVMBuilderRef,
- id_defs: &'a GlobalStringIdResolver<'input>,
+ id_defs: &'a GlobalStringIdentResolver2<'input>,
variables_builder: Builder,
resolver: &'a mut ResolveIdent,
}
@@ -533,7 +566,9 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> {
let type_ = get_function_type(
self.context,
data.return_arguments.iter().map(|(type_, space)| type_),
- data.input_arguments.iter().map(|(type_, space)| type_),
+ data.input_arguments
+ .iter()
+ .map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)),
)?;
let mut input_arguments = arguments
.input_arguments
@@ -633,11 +668,10 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR
fn get_function_type<'a>(
context: LLVMContextRef,
mut return_args: impl ExactSizeIterator<Item = &'a ast::Type>,
- input_args: impl ExactSizeIterator<Item = &'a ast::Type>,
+ input_args: impl ExactSizeIterator<Item = Result<LLVMTypeRef, TranslateError>>,
) -> Result<LLVMTypeRef, TranslateError> {
- let mut input_args: Vec<*mut llvm_zluda::LLVMType> = input_args
- .map(|type_| get_type(context, type_))
- .collect::<Result<Vec<_>, _>>()?;
+ let mut input_args: Vec<*mut llvm_zluda::LLVMType> =
+ input_args.collect::<Result<Vec<_>, _>>()?;
let return_type = match return_args.len() {
0 => unsafe { LLVMVoidTypeInContext(context) },
1 => get_type(context, return_args.next().unwrap())?,
@@ -658,7 +692,7 @@ fn get_state_space(space: ast::StateSpace) -> Result<u32, TranslateError> {
ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE),
ast::StateSpace::Param => Err(TranslateError::Todo),
- ast::StateSpace::ParamEntry => Err(TranslateError::Todo),
+ ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE),
ast::StateSpace::ParamFunc => Err(TranslateError::Todo),
ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE),
ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE),
@@ -675,7 +709,7 @@ struct ResolveIdent {
}
impl ResolveIdent {
- fn new<'input>(_id_defs: &GlobalStringIdResolver<'input>) -> Self {
+ fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self {
ResolveIdent {
words: HashMap::new(),
values: HashMap::new(),