aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-04-22 00:55:49 +0200
committerAndrzej Janik <[email protected]>2020-04-22 00:55:49 +0200
commit7b2bc69330f2043791db01f96a4daf8198116503 (patch)
treefe5c104c285bbf7f87af13c35eb1b6a0039080b3
parent0c71826bc773612f08a8787241a7b564a2b0cfd2 (diff)
downloadZLUDA-7b2bc69330f2043791db01f96a4daf8198116503.tar.gz
ZLUDA-7b2bc69330f2043791db01f96a4daf8198116503.zip
Start doing SSA conversion
-rw-r--r--ptx/src/ast.rs3
-rw-r--r--ptx/src/lib.rs1
-rw-r--r--ptx/src/test/mod.rs3
-rw-r--r--ptx/src/translate.rs186
4 files changed, 155 insertions, 38 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index a7bbe1f..9089c01 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -1,7 +1,4 @@
use std::convert::From;
-use std::convert::Into;
-use std::error::Error;
-use std::mem;
use std::num::ParseIntError;
quick_error! {
diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs
index 61c3444..f8bb7fd 100644
--- a/ptx/src/lib.rs
+++ b/ptx/src/lib.rs
@@ -9,6 +9,7 @@ extern crate spirv_headers as spirv;
lalrpop_mod!(ptx);
+#[cfg(test)]
mod test;
mod translate;
pub mod ast;
diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs
index e12097a..15876ad 100644
--- a/ptx/src/test/mod.rs
+++ b/ptx/src/test/mod.rs
@@ -2,7 +2,7 @@ use super::ptx;
fn parse_and_assert(s: &str) {
let mut errors = Vec::new();
- let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
+ ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
assert!(errors.len() == 0);
}
@@ -12,6 +12,7 @@ fn empty() {
}
#[test]
+#[allow(non_snake_case)]
fn vectorAdd_kernel64_ptx() {
let vector_add = include_str!("vectorAdd_kernel64.ptx");
parse_and_assert(vector_add);
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 259bcd2..5584af5 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,8 +1,8 @@
use crate::ast;
use bit_vec::BitVec;
use rspirv::dr;
+use std::cell::RefCell;
use std::collections::{BTreeMap, HashMap, HashSet};
-use std::{cell::RefCell, ptr};
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum SpirvType {
@@ -57,23 +57,8 @@ impl TypeWordMap {
}
}
-struct IdWordMap<'a>(HashMap<&'a str, spirv::Word>);
-
-impl<'a> IdWordMap<'a> {
- fn new() -> Self {
- IdWordMap(HashMap::new())
- }
-}
-
-impl<'a> IdWordMap<'a> {
- fn get_or_add(&mut self, b: &mut dr::Builder, id: &'a str) -> spirv::Word {
- *self.0.entry(id).or_insert_with(|| b.id())
- }
-}
-
pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
let mut builder = dr::Builder::new();
- let mut ids = IdWordMap::new();
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
builder.set_version(1, 0);
emit_capabilities(&mut builder);
@@ -82,7 +67,7 @@ pub fn to_spirv(ast: ast::Module) -> Result<Vec<u32>, rspirv::dr::Error> {
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
for f in ast.functions {
- emit_function(&mut builder, &mut map, &mut ids, f)?;
+ emit_function(&mut builder, &mut map, f)?;
}
Ok(vec![])
}
@@ -111,9 +96,8 @@ fn emit_memory_model(builder: &mut dr::Builder) {
fn emit_function<'a>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- ids: &mut IdWordMap<'a>,
f: ast::Function<'a>,
-) -> Result<(), rspirv::dr::Error> {
+) -> Result<spirv::Word, rspirv::dr::Error> {
let func_id = builder.begin_function(
map.void(),
None,
@@ -128,15 +112,19 @@ fn emit_function<'a>(
let bbs = get_basic_blocks(&normalized_ids);
let rpostorder = to_reverse_postorder(&bbs);
let dom_fronts = dominance_frontiers(&bbs, &rpostorder);
- let ssa = ssa_legalize(normalized_ids, dom_fronts);
- emit_function_body_ops(ssa, builder);
+ let (ops, phis) = ssa_legalize(normalized_ids, bbs, &dom_fronts);
+ emit_function_body_ops(builder, ops, phis);
builder.ret()?;
builder.end_function()?;
- Ok(())
+ Ok(func_id)
}
-fn emit_function_body_ops(ssa: Vec<Statement>, builder: &mut dr::Builder) {
- unimplemented!()
+fn emit_function_body_ops(
+ builder: &mut dr::Builder,
+ ops: Vec<Statement>,
+ phis: Vec<RefCell<PhiBasicBlock>>,
+) {
+ todo!()
}
// TODO: support scopes
@@ -158,8 +146,47 @@ fn normalize_identifiers<'a>(func: Vec<ast::Statement<&'a str>>) -> Vec<Statemen
result
}
-fn ssa_legalize(func: Vec<Statement>, dom_fronts: Vec<HashSet<BBIndex>>) -> Vec<Statement> {
- unimplemented!()
+fn ssa_legalize(
+ func: Vec<Statement>,
+ bbs: Vec<BasicBlock>,
+ dom_fronts: &Vec<HashSet<BBIndex>>,
+) -> (Vec<Statement>, Vec<RefCell<PhiBasicBlock>>) {
+ let mut phis = gather_phi_sets(&func, &bbs, dom_fronts);
+ trim_singleton_phi_sets(&mut phis);
+ todo!()
+}
+
+fn gather_phi_sets(
+ func: &Vec<Statement>,
+ bbs: &Vec<BasicBlock>,
+ dom_fronts: &Vec<HashSet<BBIndex>>,
+) -> Vec<HashMap<spirv::Word, HashSet<BBIndex>>> {
+ let mut phis = vec![HashMap::new(); bbs.len()];
+ for (bb_idx, bb) in bbs.iter().enumerate() {
+ let StmtIndex(start) = bb.start;
+ let end = if bb_idx == bbs.len() - 1 {
+ bbs.len()
+ } else {
+ bbs[bb_idx + 1].start.0
+ };
+ for s in func[start..end].iter() {
+ s.for_dst_id(&mut |id| {
+ for BBIndex(phi_target) in dom_fronts[bb_idx].iter() {
+ phis[*phi_target]
+ .entry(id)
+ .or_insert_with(|| HashSet::new())
+ .insert(BBIndex(bb_idx));
+ }
+ });
+ }
+ }
+ phis
+}
+
+fn trim_singleton_phi_sets(phis: &mut Vec<HashMap<spirv::Word, HashSet<BBIndex>>>) {
+ for phi_map in phis.iter_mut() {
+ phi_map.retain(|_, set| set.len() > 1);
+ }
}
fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
@@ -179,7 +206,6 @@ fn get_basic_blocks(fun: &Vec<Statement>) -> Vec<BasicBlock> {
Statement::Label(id) => {
labels.insert(id, StmtIndex(idx));
}
- Statement::Phi(_) => (),
};
}
let mut bbs_map = BTreeMap::new();
@@ -322,10 +348,10 @@ fn to_reverse_postorder(input: &Vec<BasicBlock>) -> Vec<BBIndex> {
result
}
-#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)]
-struct StmtIndex(pub usize);
-#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)]
-struct BBIndex(pub usize);
+struct PhiBasicBlock {
+ bb: BasicBlock,
+ phi: Vec<(spirv::Word, Vec<(spirv::Word, BBIndex)>)>,
+}
#[derive(Eq, PartialEq, Debug, Clone)]
struct BasicBlock {
@@ -334,10 +360,17 @@ struct BasicBlock {
succ: Vec<BBIndex>,
}
+#[derive(Eq, PartialEq, Debug, Copy, Clone, Ord, PartialOrd)]
+struct StmtIndex(pub usize);
+#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Hash)]
+struct BBIndex(pub usize);
+
enum Statement {
Label(u32),
- Instruction(Option<ast::PredAt<u32>>, ast::Instruction<u32>),
- Phi(Vec<spirv::Word>),
+ Instruction(
+ Option<ast::PredAt<spirv::Word>>,
+ ast::Instruction<spirv::Word>,
+ ),
}
impl Statement {
@@ -353,6 +386,16 @@ impl Statement {
ast::Statement::Variable(_) => None,
}
}
+
+ fn for_dst_id<F: FnMut(spirv::Word)>(&self, f: &mut F) {
+ match self {
+ Statement::Label(id) => f(*id),
+ Statement::Instruction(pred, inst) => {
+ pred.as_ref().map(|p| p.for_dst_id(f));
+ inst.for_dst_id(f);
+ }
+ }
+ }
}
impl<T> ast::PredAt<T> {
@@ -364,6 +407,10 @@ impl<T> ast::PredAt<T> {
}
}
+impl<T: Copy> ast::PredAt<T> {
+ fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {}
+}
+
impl<T> ast::Instruction<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Instruction<U> {
match self {
@@ -387,7 +434,7 @@ impl<T> ast::Instruction<T> {
impl<T: Copy> ast::Instruction<T> {
fn jump_target(&self) -> Option<T> {
match self {
- ast::Instruction::Bra(d, a) => Some(a.dst),
+ ast::Instruction::Bra(_, a) => Some(a.dst),
ast::Instruction::Ld(_, _)
| ast::Instruction::Mov(_, _)
| ast::Instruction::Mul(_, _)
@@ -402,6 +449,24 @@ impl<T: Copy> ast::Instruction<T> {
| ast::Instruction::Ret(_) => None,
}
}
+
+ fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
+ match self {
+ ast::Instruction::Bra(_, a) => a.for_dst_id(f),
+ ast::Instruction::Ld(_, a) => a.for_dst_id(f),
+ ast::Instruction::Mov(_, a) => a.for_dst_id(f),
+ ast::Instruction::Mul(_, a) => a.for_dst_id(f),
+ ast::Instruction::Add(_, a) => a.for_dst_id(f),
+ ast::Instruction::Setp(_, a) => a.for_dst_id(f),
+ ast::Instruction::SetpBool(_, a) => a.for_dst_id(f),
+ ast::Instruction::Not(_, a) => a.for_dst_id(f),
+ ast::Instruction::Cvt(_, a) => a.for_dst_id(f),
+ ast::Instruction::Shl(_, a) => a.for_dst_id(f),
+ ast::Instruction::St(_, a) => a.for_dst_id(f),
+ ast::Instruction::At(_, a) => a.for_dst_id(f),
+ ast::Instruction::Ret(_) => (),
+ }
+ }
}
impl<T> ast::Arg1<T> {
@@ -410,6 +475,12 @@ impl<T> ast::Arg1<T> {
}
}
+impl<T: Copy> ast::Arg1<T> {
+ fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
+ f(self.dst)
+ }
+}
+
impl<T> ast::Arg2<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2<U> {
ast::Arg2 {
@@ -419,6 +490,12 @@ impl<T> ast::Arg2<T> {
}
}
+impl<T: Copy> ast::Arg2<T> {
+ fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
+ f(self.dst);
+ }
+}
+
impl<T> ast::Arg2Mov<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg2Mov<U> {
ast::Arg2Mov {
@@ -428,6 +505,12 @@ impl<T> ast::Arg2Mov<T> {
}
}
+impl<T: Copy> ast::Arg2Mov<T> {
+ fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
+ f(self.dst);
+ }
+}
+
impl<T> ast::Arg3<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg3<U> {
ast::Arg3 {
@@ -438,6 +521,12 @@ impl<T> ast::Arg3<T> {
}
}
+impl<T: Copy> ast::Arg3<T> {
+ fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
+ f(self.dst);
+ }
+}
+
impl<T> ast::Arg4<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg4<U> {
ast::Arg4 {
@@ -449,6 +538,13 @@ impl<T> ast::Arg4<T> {
}
}
+impl<T: Copy> ast::Arg4<T> {
+ fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
+ f(self.dst1);
+ self.dst2.map(|t| f(t));
+ }
+}
+
impl<T> ast::Arg5<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg5<U> {
ast::Arg5 {
@@ -461,6 +557,13 @@ impl<T> ast::Arg5<T> {
}
}
+impl<T: Copy> ast::Arg5<T> {
+ fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
+ f(self.dst1);
+ self.dst2.map(|t| f(t));
+ }
+}
+
impl<T> ast::Operand<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Operand<U> {
match self {
@@ -471,6 +574,12 @@ impl<T> ast::Operand<T> {
}
}
+impl<T: Copy> ast::Operand<T> {
+ fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
+ unreachable!()
+ }
+}
+
impl<T> ast::MovOperand<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::MovOperand<U> {
match self {
@@ -480,6 +589,15 @@ impl<T> ast::MovOperand<T> {
}
}
+impl<T: Copy> ast::MovOperand<T> {
+ fn for_dst_id<F: FnMut(T)>(&self, f: &mut F) {
+ match self {
+ ast::MovOperand::Op(o) => o.for_dst_id(f),
+ ast::MovOperand::Vec(_, _) => (),
+ }
+ }
+}
+
// CFGs below taken from "Modern Compiler Implementation in Java"
#[cfg(test)]
mod tests {