diff options
author | Andrzej Janik <[email protected]> | 2024-09-16 16:42:34 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2024-09-16 16:42:34 +0200 |
commit | 3b5efbf88b99392f12e95e084fabb5a8960ae04c (patch) | |
tree | 0c1c5fdce66d9242b75020939a0b3641a2fe701d | |
parent | 2cd7910d465914b868396d0192e8ddc31086d2d8 (diff) | |
download | ZLUDA-3b5efbf88b99392f12e95e084fabb5a8960ae04c.tar.gz ZLUDA-3b5efbf88b99392f12e95e084fabb5a8960ae04c.zip |
Refactor normalize_identifiers
-rw-r--r-- | ptx/Cargo.toml | 1 | ||||
-rw-r--r-- | ptx/src/pass/mod.rs | 33 | ||||
-rw-r--r-- | ptx/src/pass/normalize_identifiers2.rs | 292 | ||||
-rw-r--r-- | ptx_parser/src/ast.rs | 9 |
4 files changed, 335 insertions, 0 deletions
diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 2e2995f..fd86f15 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -17,6 +17,7 @@ thiserror = "1.0" bit-vec = "0.6" half ="1.6" bitflags = "1.2" +rustc-hash = "2.0.0" [dependencies.lalrpop-util] version = "0.19.12" diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 3dcbf84..409425f 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -24,6 +24,7 @@ mod fix_special_registers; mod insert_implicit_conversions;
mod insert_mem_ssa_statements;
mod normalize_identifiers;
+mod normalize_identifiers2;
mod normalize_labels;
mod normalize_predicates;
@@ -1657,3 +1658,35 @@ fn denorm_count_map_update_impl<T: Eq + Hash>( }
}
}
+
+pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> {
+ Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
+ Method(Function2<'input, Instruction, Operand>),
+}
+
+pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> {
+ pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
+ pub globals: Vec<ast::Variable<SpirvWord>>,
+ pub body: Option<Vec<Statement<Instruction, Operand>>>,
+ import_as: Option<String>,
+ tuning: Vec<ast::TuningDirective>,
+ linkage: ast::LinkingDirective,
+}
+
+type NormalizedDirective2<'input> = Directive2<
+ 'input,
+ (
+ Option<ast::PredAt<SpirvWord>>,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ),
+ ast::ParsedOperand<SpirvWord>,
+>;
+
+type NormalizedFunction2<'input> = Function2<
+ 'input,
+ (
+ Option<ast::PredAt<SpirvWord>>,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ),
+ ast::ParsedOperand<SpirvWord>,
+>;
diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs new file mode 100644 index 0000000..925feb7 --- /dev/null +++ b/ptx/src/pass/normalize_identifiers2.rs @@ -0,0 +1,292 @@ +use super::*;
+use ptx_parser as ast;
+use rustc_hash::FxHashMap;
+
+pub(crate) fn run<'input>(
+ fn_defs: &mut GlobalStringIdentResolver<'input>,
+ directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
+) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
+ let mut resolver = NameResolver::new(fn_defs);
+ let result = directives
+ .into_iter()
+ .map(|directive| remap_directive(&mut resolver, directive))
+ .collect::<Result<Vec<_>, _>>()?;
+ resolver.end_scope();
+ Ok(result)
+}
+
+fn remap_directive<'input, 'b>(
+ resolver: &mut NameResolver<'input, 'b>,
+ directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
+) -> Result<NormalizedDirective2<'input>, TranslateError> {
+ Ok(match directive {
+ ast::Directive::Variable(linking, var) => {
+ NormalizedDirective2::Variable(linking, remap_variable(resolver, var)?)
+ }
+ ast::Directive::Method(linking, directive) => {
+ NormalizedDirective2::Method(remap_method(resolver, linking, directive)?)
+ }
+ })
+}
+
+fn remap_method<'input, 'b>(
+ resolver: &mut NameResolver<'input, 'b>,
+ linkage: ast::LinkingDirective,
+ method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
+) -> Result<NormalizedFunction2<'input>, TranslateError> {
+ let name = match method.func_directive.name {
+ ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
+ ast::MethodName::Func(text) => ast::MethodName::Func(
+ resolver.add(Cow::Borrowed(method.func_directive.name.text()), None)?,
+ ),
+ };
+ resolver.start_scope();
+ let func_decl = Rc::new(RefCell::new(remap_function_decl(
+ resolver,
+ method.func_directive,
+ name,
+ )?));
+ let body = method
+ .body
+ .map(|statements| {
+ let mut result = Vec::with_capacity(statements.len());
+ remap_statements(resolver, &mut result, statements)?;
+ Ok::<_, TranslateError>(result)
+ })
+ .transpose()?;
+ resolver.end_scope();
+ Ok(Function2 {
+ func_decl,
+ globals: Vec::new(),
+ body,
+ import_as: None,
+ tuning: method.tuning,
+ linkage,
+ })
+}
+
+fn remap_function_decl<'input, 'b>(
+ resolver: &mut NameResolver<'input, 'b>,
+ func_directive: ast::MethodDeclaration<'input, &'input str>,
+ name: ast::MethodName<'input, SpirvWord>,
+) -> Result<ast::MethodDeclaration<'input, SpirvWord>, TranslateError> {
+ assert!(func_directive.shared_mem.is_none());
+ let return_arguments = func_directive
+ .return_arguments
+ .into_iter()
+ .map(|var| remap_variable(resolver, var))
+ .collect::<Result<Vec<_>, _>>()?;
+ let input_arguments = func_directive
+ .input_arguments
+ .into_iter()
+ .map(|var| remap_variable(resolver, var))
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok(ast::MethodDeclaration {
+ return_arguments,
+ name,
+ input_arguments,
+ shared_mem: None,
+ })
+}
+
+fn remap_variable<'input, 'b>(
+ resolver: &mut NameResolver<'input, 'b>,
+ variable: ast::Variable<&'input str>,
+) -> Result<ast::Variable<SpirvWord>, TranslateError> {
+ Ok(ast::Variable {
+ name: resolver.add(
+ Cow::Borrowed(variable.name),
+ Some((variable.v_type.clone(), variable.state_space)),
+ )?,
+ align: variable.align,
+ v_type: variable.v_type,
+ state_space: variable.state_space,
+ array_init: variable.array_init,
+ })
+}
+
+fn remap_statements<'input, 'b>(
+ resolver: &mut NameResolver<'input, 'b>,
+ result: &mut Vec<NormalizedStatement>,
+ statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
+) -> Result<(), TranslateError> {
+ for statement in statements.iter() {
+ match statement {
+ ast::Statement::Label(label) => {
+ resolver.add(Cow::Borrowed(*label), None)?;
+ }
+ _ => {}
+ }
+ }
+ for statement in statements {
+ match statement {
+ ast::Statement::Label(label) => {
+ result.push(Statement::Label(resolver.get_in_current_scope(label)?))
+ }
+ ast::Statement::Variable(variable) => remap_multivariable(resolver, result, variable)?,
+ ast::Statement::Instruction(predicate, instruction) => {
+ result.push(Statement::Instruction((
+ predicate
+ .map(|pred| {
+ Ok::<_, TranslateError>(ast::PredAt {
+ not: pred.not,
+ label: resolver.get(pred.label)?,
+ })
+ })
+ .transpose()?,
+ remap_instruction(resolver, instruction)?,
+ )))
+ }
+ ast::Statement::Block(block) => {
+ resolver.start_scope();
+ remap_statements(resolver, result, block)?;
+ resolver.end_scope();
+ }
+ }
+ }
+ Ok(())
+}
+
+fn remap_instruction<'input, 'b>(
+ resolver: &mut NameResolver<'input, 'b>,
+ instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
+) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
+ ast::visit_map(instruction, &mut |name: &'input str,
+ _: Option<(
+ &ast::Type,
+ ast::StateSpace,
+ )>,
+ _,
+ _| {
+ resolver.get(&name)
+ })
+}
+
+fn remap_multivariable<'input, 'b>(
+ resolver: &mut NameResolver<'input, 'b>,
+ result: &mut Vec<NormalizedStatement>,
+ variable: ast::MultiVariable<&'input str>,
+) -> Result<(), TranslateError> {
+ match variable.count {
+ Some(count) => {
+ for i in 0..count {
+ let name = Cow::Owned(format!("{}{}", variable.var.name, i));
+ let ident = resolver.add(
+ name,
+ Some((variable.var.v_type.clone(), variable.var.state_space)),
+ )?;
+ result.push(Statement::Variable(ast::Variable {
+ align: variable.var.align,
+ v_type: variable.var.v_type.clone(),
+ state_space: variable.var.state_space,
+ name: ident,
+ array_init: variable.var.array_init.clone(),
+ }));
+ }
+ }
+ None => {
+ let name = Cow::Borrowed(variable.var.name);
+ let ident = resolver.add(
+ name,
+ Some((variable.var.v_type.clone(), variable.var.state_space)),
+ )?;
+ result.push(Statement::Variable(ast::Variable {
+ align: variable.var.align,
+ v_type: variable.var.v_type.clone(),
+ state_space: variable.var.state_space,
+ name: ident,
+ array_init: variable.var.array_init.clone(),
+ }));
+ }
+ }
+ Ok(())
+}
+
+struct NameResolver<'input, 'b> {
+ flat_resolver: &'b mut GlobalStringIdentResolver<'input>,
+ scopes: Vec<ScopeStringIdentResolver<'input>>,
+}
+
+impl<'input, 'b> NameResolver<'input, 'b> {
+ fn new(flat_resolver: &'b mut GlobalStringIdentResolver<'input>) -> Self {
+ Self {
+ flat_resolver,
+ scopes: vec![ScopeStringIdentResolver::new()],
+ }
+ }
+
+ fn start_scope(&mut self) {
+ self.scopes.push(ScopeStringIdentResolver::new());
+ }
+
+ fn end_scope(&mut self) {
+ let scope = self.scopes.pop().unwrap();
+ scope.flush(self.flat_resolver);
+ }
+
+ fn add(
+ &mut self,
+ name: Cow<'input, str>,
+ type_space: Option<(ast::Type, ast::StateSpace)>,
+ ) -> Result<SpirvWord, TranslateError> {
+ let result = self.flat_resolver.current_id;
+ self.flat_resolver.current_id.0 += 1;
+ let current_scope = self.scopes.last_mut().unwrap();
+ if current_scope
+ .name_to_ident
+ .insert(name.clone(), result)
+ .is_some()
+ {
+ return Err(error_unknown_symbol());
+ }
+ current_scope
+ .ident_map
+ .insert(result, IdentEntry { name, type_space });
+ Ok(result)
+ }
+
+ fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
+ self.scopes
+ .iter()
+ .rev()
+ .find_map(|resolver| resolver.name_to_ident.get(name).copied())
+ .ok_or_else(|| error_unreachable())
+ }
+
+ fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
+ let current_scope = self.scopes.last().unwrap();
+ current_scope
+ .name_to_ident
+ .get(label)
+ .copied()
+ .ok_or_else(|| error_unreachable())
+ }
+}
+
+struct ScopeStringIdentResolver<'input> {
+ ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
+ name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
+}
+
+impl<'input> ScopeStringIdentResolver<'input> {
+ fn new() -> Self {
+ Self {
+ ident_map: FxHashMap::default(),
+ name_to_ident: FxHashMap::default(),
+ }
+ }
+
+ fn flush(self, resolver: &mut GlobalStringIdentResolver<'input>) {
+ resolver.ident_map.extend(self.ident_map);
+ }
+}
+
+struct GlobalStringIdentResolver<'input> {
+ pub(crate) current_id: SpirvWord,
+ pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
+}
+
+struct IdentEntry<'input> {
+ name: Cow<'input, str>,
+ type_space: Option<(ast::Type, ast::StateSpace)>,
+}
diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index cc5a1d0..65c624e 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1049,6 +1049,15 @@ impl<'input, ID> MethodName<'input, ID> { }
}
+impl<'input> MethodName<'input, &'input str> {
+ pub fn text(&self) -> &'input str {
+ match self {
+ MethodName::Kernel(name) => *name,
+ MethodName::Func(name) => *name,
+ }
+ }
+}
+
bitflags! {
pub struct LinkingDirective: u8 {
const NONE = 0b000;
|