aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/translate.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r--ptx/src/translate.rs88
1 files changed, 54 insertions, 34 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 297588a..13c578b 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -543,10 +543,10 @@ fn emit_directives<'input>(
let f_body = match &f.body {
Some(f) => f,
None => {
- if f.import_as.is_some() {
- &empty_body
- } else {
+ if f.linkage == ast::LinkingDirective::NONE {
continue;
+ } else {
+ &empty_body
}
}
};
@@ -607,33 +607,38 @@ fn emit_directives<'input>(
}
}
emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?;
+ emit_function_linkage(builder, id_defs, f, fn_id);
builder.select_block(None)?;
builder.end_function()?;
- if let (
- ast::MethodDeclaration {
- name: ast::MethodName::Func(fn_id),
- ..
- },
- Some(name),
- ) = (&*func_decl, &f.import_as)
- {
- builder.decorate(
- *fn_id,
- spirv::Decoration::LinkageAttributes,
- [
- dr::Operand::LiteralString(name.clone()),
- dr::Operand::LinkageType(spirv::LinkageType::Import),
- ]
- .iter()
- .cloned(),
- );
- }
}
}
}
Ok(())
}
+fn emit_function_linkage<'input>(
+ builder: &mut dr::Builder,
+ id_defs: &GlobalStringIdResolver<'input>,
+ f: &Function,
+ fn_name: spirv::Word,
+) -> Result<(), TranslateError> {
+ if f.linkage == ast::LinkingDirective::NONE {
+ return Ok(());
+ };
+ let linking_name = f.import_as.as_deref().map_or_else(
+ || match f.func_decl.borrow().name {
+ ast::MethodName::Kernel(kernel_name) => Ok(kernel_name),
+ ast::MethodName::Func(fn_id) => match id_defs.reverse_variables.get(&fn_id) {
+ Some(fn_name) => Ok(fn_name),
+ None => Err(error_unknown_symbol()),
+ },
+ },
+ Result::Ok,
+ )?;
+ emit_linking_decoration(builder, id_defs, Some(linking_name), fn_name, f.linkage);
+ Ok(())
+}
+
fn get_kernels_call_map<'input>(
module: &[Directive<'input>],
) -> HashMap<&'input str, HashSet<spirv::Word>> {
@@ -763,6 +768,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
body: Some(statements),
import_as,
tuning,
+ linkage,
}) => {
let call_key = (*func_decl).borrow().name;
let statements = statements
@@ -786,6 +792,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
body: Some(statements),
import_as,
tuning,
+ linkage,
})
}
directive => directive,
@@ -804,6 +811,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
body: Some(statements),
import_as,
tuning,
+ linkage,
}) => {
if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) {
return Directive::Method(Function {
@@ -812,6 +820,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
body: Some(statements),
import_as,
tuning,
+ linkage,
});
}
let shared_id_param = new_id();
@@ -832,6 +841,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
body: Some(statements),
import_as,
tuning,
+ linkage,
})
}
directive => directive,
@@ -1150,8 +1160,8 @@ fn translate_directive<'input, 'a>(
array_init: var.array_init,
},
)),
- ast::Directive::Method(_, f) => {
- translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method)
+ ast::Directive::Method(linkage, f) => {
+ translate_function(id_defs, ptx_impl_imports, linkage, f)?.map(Directive::Method)
}
})
}
@@ -1159,6 +1169,7 @@ fn translate_directive<'input, 'a>(
fn translate_function<'input, 'a>(
id_defs: &'a mut GlobalStringIdResolver<'input>,
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
+ linkage: ast::LinkingDirective,
f: ast::ParsedFunction<'input>,
) -> Result<Option<Function<'input>>, TranslateError> {
let import_as = match &f.func_directive {
@@ -1178,6 +1189,7 @@ fn translate_function<'input, 'a>(
fn_decl,
f.body,
f.tuning,
+ linkage,
)?;
func.import_as = import_as;
if func.import_as.is_some() {
@@ -1213,6 +1225,7 @@ fn to_ssa<'input, 'b>(
func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
tuning: Vec<ast::TuningDirective>,
+ linkage: ast::LinkingDirective,
) -> Result<Function<'input>, TranslateError> {
//deparamize_function_decl(&func_decl)?;
let f_body = match f_body {
@@ -1224,6 +1237,7 @@ fn to_ssa<'input, 'b>(
globals: Vec::new(),
import_as: None,
tuning,
+ linkage,
})
}
};
@@ -1255,6 +1269,7 @@ fn to_ssa<'input, 'b>(
body: Some(f_body),
import_as: None,
tuning,
+ linkage,
})
}
@@ -1832,6 +1847,7 @@ fn register_external_fn_call<'a>(
body: None,
import_as: Some(entry.key().clone()),
tuning: Vec::new(),
+ linkage: ast::LinkingDirective::EXTERN,
};
entry.insert(Directive::Method(func));
Ok(fn_id)
@@ -3464,37 +3480,40 @@ fn emit_variable<'input>(
[dr::Operand::LiteralInt32(align)].iter().cloned(),
);
}
- emit_linking_decoration(builder, id_defs, var.name, linking);
+ emit_linking_decoration(builder, id_defs, None, var.name, linking);
Ok(())
}
fn emit_linking_decoration<'input>(
builder: &mut dr::Builder,
id_defs: &GlobalStringIdResolver<'input>,
+ name_override: Option<&str>,
name: spirv::Word,
linking: ast::LinkingDirective,
) {
- if linking.contains(ast::LinkingDirective::EXTERN) {
- let external_name = id_defs.reverse_variables.get(&name).unwrap();
+ if linking == ast::LinkingDirective::NONE {
+ return;
+ }
+ let string_name =
+ name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap());
+ if linking.contains(ast::LinkingDirective::VISIBLE) {
builder.decorate(
name,
spirv::Decoration::LinkageAttributes,
[
- dr::Operand::LiteralString(external_name.to_string()),
- dr::Operand::LinkageType(spirv::LinkageType::Import),
+ dr::Operand::LiteralString(string_name.to_string()),
+ dr::Operand::LinkageType(spirv::LinkageType::Export),
]
.iter()
.cloned(),
);
- }
- if linking.contains(ast::LinkingDirective::VISIBLE) {
- let external_name = id_defs.reverse_variables.get(&name).unwrap();
+ } else if linking.contains(ast::LinkingDirective::EXTERN) {
builder.decorate(
name,
spirv::Decoration::LinkageAttributes,
[
- dr::Operand::LiteralString(external_name.to_string()),
- dr::Operand::LinkageType(spirv::LinkageType::Export),
+ dr::Operand::LiteralString(string_name.to_string()),
+ dr::Operand::LinkageType(spirv::LinkageType::Import),
]
.iter()
.cloned(),
@@ -5774,6 +5793,7 @@ struct Function<'input> {
pub body: Option<Vec<ExpandedStatement>>,
import_as: Option<String>,
tuning: Vec<ast::TuningDirective>,
+ linkage: ast::LinkingDirective,
}
pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {