aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/pass/normalize_identifiers2.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/pass/normalize_identifiers2.rs')
-rw-r--r--ptx/src/pass/normalize_identifiers2.rs111
1 files changed, 12 insertions, 99 deletions
diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs
index e3fb88d..beaf08b 100644
--- a/ptx/src/pass/normalize_identifiers2.rs
+++ b/ptx/src/pass/normalize_identifiers2.rs
@@ -2,21 +2,21 @@ use super::*;
use ptx_parser as ast;
use rustc_hash::FxHashMap;
-pub(crate) fn run<'input>(
- fn_defs: &mut GlobalStringIdentResolver2<'input>,
+pub(crate) fn run<'input, 'b>(
+ resolver: &mut ScopedResolver<'input, 'b>,
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
- let mut resolver = NameResolver::new(fn_defs);
+ resolver.start_scope();
let result = directives
.into_iter()
- .map(|directive| run_directive(&mut resolver, directive))
+ .map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()?;
resolver.end_scope();
Ok(result)
}
fn run_directive<'input, 'b>(
- resolver: &mut NameResolver<'input, 'b>,
+ resolver: &mut ScopedResolver<'input, 'b>,
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
) -> Result<NormalizedDirective2<'input>, TranslateError> {
Ok(match directive {
@@ -30,7 +30,7 @@ fn run_directive<'input, 'b>(
}
fn run_method<'input, 'b>(
- resolver: &mut NameResolver<'input, 'b>,
+ resolver: &mut ScopedResolver<'input, 'b>,
linkage: ast::LinkingDirective,
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<NormalizedFunction2<'input>, TranslateError> {
@@ -41,11 +41,7 @@ fn run_method<'input, 'b>(
}
};
resolver.start_scope();
- let func_decl = Rc::new(RefCell::new(run_function_decl(
- resolver,
- method.func_directive,
- name,
- )?));
+ let func_decl = run_function_decl(resolver, method.func_directive, name)?;
let body = method
.body
.map(|statements| {
@@ -66,7 +62,7 @@ fn run_method<'input, 'b>(
}
fn run_function_decl<'input, 'b>(
- resolver: &mut NameResolver<'input, 'b>,
+ resolver: &mut ScopedResolver<'input, 'b>,
func_directive: ast::MethodDeclaration<'input, &'input str>,
name: ast::MethodName<'input, SpirvWord>,
) -> Result<ast::MethodDeclaration<'input, SpirvWord>, TranslateError> {
@@ -90,7 +86,7 @@ fn run_function_decl<'input, 'b>(
}
fn run_variable<'input, 'b>(
- resolver: &mut NameResolver<'input, 'b>,
+ resolver: &mut ScopedResolver<'input, 'b>,
variable: ast::Variable<&'input str>,
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
Ok(ast::Variable {
@@ -106,7 +102,7 @@ fn run_variable<'input, 'b>(
}
fn run_statements<'input, 'b>(
- resolver: &mut NameResolver<'input, 'b>,
+ resolver: &mut ScopedResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<(), TranslateError> {
@@ -148,7 +144,7 @@ fn run_statements<'input, 'b>(
}
fn run_instruction<'input, 'b>(
- resolver: &mut NameResolver<'input, 'b>,
+ resolver: &mut ScopedResolver<'input, 'b>,
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
ast::visit_map(instruction, &mut |name: &'input str,
@@ -163,7 +159,7 @@ fn run_instruction<'input, 'b>(
}
fn run_multivariable<'input, 'b>(
- resolver: &mut NameResolver<'input, 'b>,
+ resolver: &mut ScopedResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
variable: ast::MultiVariable<&'input str>,
) -> Result<(), TranslateError> {
@@ -201,86 +197,3 @@ fn run_multivariable<'input, 'b>(
}
Ok(())
}
-
-struct NameResolver<'input, 'b> {
- flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
- scopes: Vec<ScopeStringIdentResolver<'input>>,
-}
-
-impl<'input, 'b> NameResolver<'input, 'b> {
- fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
- Self {
- flat_resolver,
- scopes: vec![ScopeStringIdentResolver::new()],
- }
- }
-
- fn start_scope(&mut self) {
- self.scopes.push(ScopeStringIdentResolver::new());
- }
-
- fn end_scope(&mut self) {
- let scope = self.scopes.pop().unwrap();
- scope.flush(self.flat_resolver);
- }
-
- fn add(
- &mut self,
- name: Cow<'input, str>,
- type_space: Option<(ast::Type, ast::StateSpace)>,
- ) -> Result<SpirvWord, TranslateError> {
- let result = self.flat_resolver.current_id;
- self.flat_resolver.current_id.0 += 1;
- let current_scope = self.scopes.last_mut().unwrap();
- if current_scope
- .name_to_ident
- .insert(name.clone(), result)
- .is_some()
- {
- return Err(error_unknown_symbol());
- }
- current_scope.ident_map.insert(
- result,
- IdentEntry {
- name: Some(name),
- type_space,
- },
- );
- Ok(result)
- }
-
- fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
- self.scopes
- .iter()
- .rev()
- .find_map(|resolver| resolver.name_to_ident.get(name).copied())
- .ok_or_else(|| error_unreachable())
- }
-
- fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
- let current_scope = self.scopes.last().unwrap();
- current_scope
- .name_to_ident
- .get(label)
- .copied()
- .ok_or_else(|| error_unreachable())
- }
-}
-
-struct ScopeStringIdentResolver<'input> {
- ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
- name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
-}
-
-impl<'input> ScopeStringIdentResolver<'input> {
- fn new() -> Self {
- Self {
- ident_map: FxHashMap::default(),
- name_to_ident: FxHashMap::default(),
- }
- }
-
- fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
- resolver.ident_map.extend(self.ident_map);
- }
-}