diff options
author | Andrzej Janik <[email protected]> | 2020-04-28 00:02:34 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-04-28 00:02:34 +0200 |
commit | 92b5dbd6a8a3995bf12b2b606c034f9f05cbeca1 (patch) | |
tree | 40bb97dd00b4465feb03c5553fdd6d4df75c9b0e | |
parent | bce5f2784382e81cda66773d4cb0727d18b8b7ac (diff) | |
download | ZLUDA-92b5dbd6a8a3995bf12b2b606c034f9f05cbeca1.tar.gz ZLUDA-92b5dbd6a8a3995bf12b2b606c034f9f05cbeca1.zip |
Fix bugs in basic block resolution
-rw-r--r-- | ptx/src/ptx.lalrpop | 2 | ||||
-rw-r--r-- | ptx/src/translate.rs | 174 |
2 files changed, 148 insertions, 28 deletions
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index f40846d..83a0fe2 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -86,7 +86,7 @@ FunctionInput: ast::Argument<'input> = { } }; -FunctionBody: Vec<ast::Statement<&'input str>> = { +pub(crate) FunctionBody: Vec<ast::Statement<&'input str>> = { "{" <s:Statement*> "}" => { without_none(s) } }; diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 651f996..1206e22 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -251,7 +251,11 @@ fn rename_succesor_phi_src( }
}
-fn pop_stacks(ssa_state: &mut SSARewriteState, old_phi: &HashSet<spirv::Word>, old_ids: &[spirv::Word]) {
+fn pop_stacks(
+ ssa_state: &mut SSARewriteState,
+ old_phi: &HashSet<spirv::Word>,
+ old_ids: &[spirv::Word],
+) {
for id in old_phi.iter().chain(old_ids) {
ssa_state.pop(*id);
}
@@ -335,7 +339,7 @@ fn gather_phi_sets( });
}
}
- for (id, to_work ) in def_sites.iter_mut().enumerate() {
+ for (id, to_work) in def_sites.iter_mut().enumerate() {
let id = id as spirv::Word;
let (ref mut set, ref mut stack) = to_work;
loop {
@@ -358,18 +362,26 @@ fn gather_phi_sets( result
}
-fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
- let mut direct_bb_start = Vec::new();
- let mut indirect_bb_start = Vec::new();
+fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
+ // edge signify pred/succ relationship between bbs
+ let mut bb_edge = HashSet::new();
+ let mut unresolved_bb_edge = Vec::new();
+ // bb start means that a bb is starting at this statement, but there's no predecessor
+ let mut bb_start = Vec::new();
let mut labels = HashMap::new();
for (idx, s) in fun.iter().enumerate() {
match s {
- Statement::Instruction(_, i) => {
+ Statement::Instruction(pred, i) => {
if let Some(id) = i.jump_target() {
- indirect_bb_start.push((StmtIndex(idx), id));
+ unresolved_bb_edge.push((StmtIndex(idx), id));
if idx + 1 < fun.len() {
- direct_bb_start.push((StmtIndex(idx), StmtIndex(idx + 1)));
+ if pred.is_some() {
+ bb_edge.insert((StmtIndex(idx), StmtIndex(idx + 1)));
+ }
+ bb_start.push(StmtIndex(idx + 1));
}
+ } else if i.is_terminal() && idx + 1 < fun.len() {
+ bb_start.push(StmtIndex(idx + 1));
}
}
Statement::Label(id) => {
@@ -377,6 +389,25 @@ fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> { }
};
}
+ // Resolve every <jump into label> into <jump into statement index>
+ // TODO: handle jumps into nowhere
+ for (idx, id) in unresolved_bb_edge {
+ let target = labels[&id];
+ bb_edge.insert((idx, target));
+ bb_start.push(target);
+ // now check if the preceding statement forms an edge
+ if target != StmtIndex(0) {
+ match &fun[target.0 - 1] {
+ Statement::Instruction(pred, i) => {
+ if !((pred.is_none() && i.jump_target().is_some()) || i.is_terminal()) {
+ bb_edge.insert((StmtIndex(target.0 - 1), target));
+ }
+ }
+ Statement::Label(_) => (),
+ }
+ }
+ }
+ // Create list of bbs without succ/pred
let mut bbs_map = BTreeMap::new();
bbs_map.insert(
StmtIndex(0),
@@ -386,32 +417,22 @@ fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> { succ: Vec::new(),
},
);
- // TODO: handle jumps into nowhere
- let resolved_indirect_bb_start = indirect_bb_start
- .into_iter()
- .map(|(idx, id)| (idx, labels[&id]))
- .collect::<Vec<_>>();
- for (_, to) in direct_bb_start
- .iter()
- .chain(resolved_indirect_bb_start.iter())
- {
- bbs_map.entry(*to).or_insert_with(|| BasicBlock {
- start: *to,
+ for bb_first_stmt in bb_start {
+ bbs_map.entry(bb_first_stmt).or_insert_with(|| BasicBlock {
+ start: bb_first_stmt,
pred: Vec::new(),
succ: Vec::new(),
});
}
+ // Populate succ/pred
let indexed_bbs_map = bbs_map
.into_iter()
.enumerate()
.map(|(idx, (key, val))| (key, (BBIndex(idx), RefCell::new(val))))
.collect::<BTreeMap<_, _>>();
- for (from, to) in direct_bb_start
- .iter()
- .chain(resolved_indirect_bb_start.iter())
- {
- let (_, (from_idx, from_ref)) = indexed_bbs_map.range(..=*from).next_back().unwrap();
- let (to_idx, to_ref) = indexed_bbs_map.get(to).unwrap();
+ for (from, to) in bb_edge {
+ let (_, (from_idx, from_ref)) = indexed_bbs_map.range(..=from).next_back().unwrap();
+ let (to_idx, to_ref) = indexed_bbs_map.get(&to).unwrap();
{
from_ref.borrow_mut().succ.push(*to_idx);
}
@@ -527,9 +548,9 @@ struct BasicBlock { succ: Vec<BBIndex>,
}
-#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)]
+#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)]
struct StmtIndex(pub usize);
-#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)]
+#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)]
struct BBIndex(pub usize);
enum Statement {
@@ -646,6 +667,23 @@ impl<T: Copy> ast::Instruction<T> { }
}
+ fn is_terminal(&self) -> bool {
+ match self {
+ ast::Instruction::Ret(_) => true,
+ ast::Instruction::Ld(_, _)
+ | ast::Instruction::Mov(_, _)
+ | ast::Instruction::Mul(_, _)
+ | ast::Instruction::Add(_, _)
+ | ast::Instruction::Setp(_, _)
+ | ast::Instruction::SetpBool(_, _)
+ | ast::Instruction::Not(_, _)
+ | ast::Instruction::Cvt(_, _)
+ | ast::Instruction::Shl(_, _)
+ | ast::Instruction::St(_, _)
+ | ast::Instruction::Bra(_, _) => false,
+ }
+ }
+
fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
match self {
ast::Instruction::Ld(_, a) => a.for_dst_id(f),
@@ -826,6 +864,8 @@ impl<T> ast::MovOperand<T> { #[cfg(test)]
mod tests {
use super::*;
+ use crate::ast;
+ use crate::ptx;
// page 411
#[test]
@@ -1140,4 +1180,84 @@ mod tests { ]
);
}
+
+ fn sort_pred_succ(bb: &mut BasicBlock) {
+ bb.pred.sort();
+ bb.succ.sort();
+ }
+
+ // page 403
+ #[test]
+ fn gather_phi_sets_19_4() {
+ let func = "{
+ mov.u32 i, 1;
+ mov.u32 j, 1;
+ mov.u32 k, 0;
+ block_2:
+ setp.ge.u32 p, k, 100;
+ @p bra block_4;
+ block_3:
+ setp.ge.u32 q, j, 20;
+ @q bra block_6;
+ block_5:
+ mov.u32 j, i;
+ add.u32 k, k, 1;
+ bra block_7;
+ block_6:
+ mov.u32 j, k;
+ add.u32 k, k, 2;
+ block_7:
+ bra block_2;
+ block_4:
+ ret;
+ }";
+ let mut errors = Vec::new();
+ let ast = ptx::FunctionBodyParser::new()
+ .parse(&mut errors, func)
+ .unwrap();
+ assert_eq!(errors.len(), 0);
+ let (normalized_ids, _) = normalize_identifiers(ast);
+ let mut bbs = get_basic_blocks(&normalized_ids);
+ bbs.iter_mut().for_each(sort_pred_succ);
+ assert_eq!(
+ bbs,
+ vec![
+ BasicBlock {
+ start: StmtIndex(0),
+ pred: vec![],
+ succ: vec![BBIndex(1)]
+ },
+ BasicBlock {
+ start: StmtIndex(3),
+ pred: vec![BBIndex(0), BBIndex(5)],
+ succ: vec![BBIndex(2), BBIndex(6)]
+ },
+ BasicBlock {
+ start: StmtIndex(6),
+ pred: vec![BBIndex(1)],
+ succ: vec![BBIndex(3), BBIndex(4)]
+ },
+ BasicBlock {
+ start: StmtIndex(9),
+ pred: vec![BBIndex(2)],
+ succ: vec![BBIndex(5)]
+ },
+ BasicBlock {
+ start: StmtIndex(13),
+ pred: vec![BBIndex(2)],
+ succ: vec![BBIndex(5)]
+ },
+ BasicBlock {
+ start: StmtIndex(16),
+ pred: vec![BBIndex(3), BBIndex(4)],
+ succ: vec![BBIndex(1)]
+ },
+ BasicBlock {
+ start: StmtIndex(18),
+ pred: vec![BBIndex(1)],
+ succ: vec![]
+ },
+ ]
+ );
+ }
}
|