Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/ast/desugar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,9 @@ fn flatten_actions(actions: &Vec<Action>, desugar: &mut Desugar) -> Vec<NormActi
let added_variants = add_expr(variants.clone(), &mut res);
res.push(NormAction::Extract(added, added_variants));
}
Action::Print(expr) => {
Action::Print(expr, print_tree) => {
let added = add_expr(expr.clone(), &mut res);
res.push(NormAction::Print(added));
res.push(NormAction::Print(added, *print_tree));
}
Action::Delete(symbol, exprs) => {
let del = NormAction::Delete(NormExpr::Call(
Expand Down Expand Up @@ -636,13 +636,16 @@ pub(crate) fn desugar_command(
vec![NCommand::Check(flatten_facts(&facts, desugar))]
}
Command::CheckProof => vec![NCommand::CheckProof],
Command::GetProof(query) => desugar.desugar_get_proof(&query)?,
Command::LookupProof(expr) => match &expr {
Command::GetProof(query, print_tree) => desugar.desugar_get_proof(&query, print_tree)?,
Command::LookupProof(expr, print_tree) => match &expr {
Expr::Call(f, args) => {
if !args.is_empty() {
return Err(Error::LookupProofRequiresExpr(expr.to_string()));
}
vec![NCommand::LookupProof(NormExpr::Call(*f, vec![]))]
vec![NCommand::LookupProof(
NormExpr::Call(*f, vec![]),
print_tree,
)]
}
_ => {
return Err(Error::LookupProofRequiresExpr(expr.to_string()));
Expand Down Expand Up @@ -851,11 +854,16 @@ impl Desugar {
res
}

fn desugar_get_proof(&mut self, query: &Vec<Fact>) -> Result<Vec<NCommand>, Error> {
fn desugar_get_proof(
&mut self,
query: &Vec<Fact>,
print_tree: bool,
) -> Result<Vec<NCommand>, Error> {
let proof_ruleset = self.fresh().as_str();
let result_sort = self.fresh().as_str();
let result_func = self.fresh().as_str();
let query_str = ListDisplay(query, " ");
let tree_or_dag = if print_tree { "" } else { ":dag" };
desugar_commands(
self.parse_program(&format!(
"
Expand All @@ -867,7 +875,7 @@ impl Desugar {
(({result_func}))
:ruleset {proof_ruleset})
(run {proof_ruleset} 1)
(lookup-proof ({result_func}))
(lookup-proof ({result_func}) {tree_or_dag})
",
))
.unwrap(),
Expand Down
57 changes: 33 additions & 24 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ pub enum NCommand {
name: Symbol,
file: String,
},
GetProof(Vec<NormFact>),
LookupProof(NormExpr),
GetProof(Vec<NormFact>, bool),
LookupProof(NormExpr, bool),
}

impl NormCommand {
Expand Down Expand Up @@ -151,10 +151,11 @@ impl NCommand {
Command::Check(facts.iter().map(|fact| fact.to_fact()).collect())
}
NCommand::CheckProof => Command::CheckProof,
NCommand::GetProof(query) => {
Command::GetProof(query.iter().map(|fact| fact.to_fact()).collect::<Vec<_>>())
}
NCommand::LookupProof(expr) => Command::LookupProof(expr.to_expr()),
NCommand::GetProof(query, print_tree) => Command::GetProof(
query.iter().map(|fact| fact.to_fact()).collect::<Vec<_>>(),
*print_tree,
),
NCommand::LookupProof(expr, print_tree) => Command::LookupProof(expr.to_expr(), *print_tree),
NCommand::PrintTable(name, n) => Command::PrintTable(*name, *n),
NCommand::PrintSize(name) => Command::PrintSize(*name),
NCommand::Output { file, exprs } => Command::Output {
Expand Down Expand Up @@ -196,10 +197,11 @@ impl NCommand {
NCommand::Check(facts.iter().map(|fact| fact.map_exprs(f)).collect())
}
NCommand::CheckProof => NCommand::CheckProof,
NCommand::GetProof(query) => {
NCommand::GetProof(query.iter().map(|fact| fact.map_exprs(f)).collect())
}
NCommand::LookupProof(expr) => NCommand::LookupProof(f(expr)),
NCommand::GetProof(query, print_tree) => NCommand::GetProof(
query.iter().map(|fact| fact.map_exprs(f)).collect(),
*print_tree,
),
NCommand::LookupProof(expr, print_tree) => NCommand::LookupProof(f(expr), *print_tree),
NCommand::PrintTable(name, n) => NCommand::PrintTable(*name, *n),
NCommand::PrintSize(name) => NCommand::PrintSize(*name),
NCommand::Output { file, exprs } => NCommand::Output {
Expand Down Expand Up @@ -368,8 +370,8 @@ pub enum Command {
// TODO: this could just become an empty query
Check(Vec<Fact>),
CheckProof,
GetProof(Vec<Fact>),
LookupProof(Expr),
GetProof(Vec<Fact>, bool),
LookupProof(Expr, bool),
PrintTable(Symbol, usize),
PrintSize(Symbol),
Input {
Expand Down Expand Up @@ -410,8 +412,13 @@ impl ToSexp for Command {
Command::Extract { variants, fact } => list!("extract", ":variants", variants, fact),
Command::Check(facts) => list!("check", ++ facts),
Command::CheckProof => list!("check-proof"),
Command::GetProof(query) => list!("get-proof", ++ query),
Command::LookupProof(expr) => list!("lookup-proof", expr),
Command::GetProof(query, print_tree) => {
list!(
list!("get-proof", ++ query),
if *print_tree { "" } else { ":dag" }
)
}
Command::LookupProof(expr, print_tree) => list!("lookup-proof", expr, if *print_tree { "" } else { ":dag" }),
Command::Push(n) => list!("push", n),
Command::Pop(n) => list!("pop", n),
Command::PrintTable(name, n) => list!("print-table", name, n),
Expand Down Expand Up @@ -762,7 +769,7 @@ pub enum Action {
Union(Expr, Expr),
// What to extract and how many variants
Extract(Expr, Expr),
Print(Expr),
Print(Expr, bool),
Panic(String),
Expr(Expr),
// If(Expr, Action, Action),
Expand All @@ -774,7 +781,7 @@ pub enum NormAction {
LetVar(Symbol, Symbol),
LetLit(Symbol, Literal),
Extract(Symbol, Symbol),
Print(Symbol),
Print(Symbol, bool),
Set(NormExpr, Symbol),
Delete(NormExpr),
Union(Symbol, Symbol),
Expand All @@ -795,7 +802,7 @@ impl NormAction {
NormAction::Extract(symbol, variants) => {
Action::Extract(Expr::Var(*symbol), Expr::Var(*variants))
}
NormAction::Print(symbol) => Action::Print(Expr::Var(*symbol)),
NormAction::Print(symbol, print_tree) => Action::Print(Expr::Var(*symbol), *print_tree),
NormAction::Delete(NormExpr::Call(symbol, args)) => {
Action::Delete(*symbol, args.iter().map(|s| Expr::Var(*s)).collect())
}
Expand All @@ -811,7 +818,7 @@ impl NormAction {
NormAction::LetLit(symbol, lit) => NormAction::LetLit(*symbol, lit.clone()),
NormAction::Set(expr, other) => NormAction::Set(f(expr), *other),
NormAction::Extract(var, variants) => NormAction::Extract(*var, *variants),
NormAction::Print(var) => NormAction::Print(*var),
NormAction::Print(var, print_tree) => NormAction::Print(*var, *print_tree),
NormAction::Delete(expr) => NormAction::Delete(f(expr)),
NormAction::Union(lhs, rhs) => NormAction::Union(*lhs, *rhs),
NormAction::Panic(msg) => NormAction::Panic(msg.clone()),
Expand All @@ -834,7 +841,7 @@ impl NormAction {
NormAction::Extract(var, variants) => {
NormAction::Extract(fvar(*var, false), fvar(*variants, false))
}
NormAction::Print(var) => NormAction::Print(fvar(*var, false)),
NormAction::Print(var, print_tree) => NormAction::Print(fvar(*var, false), *print_tree),
NormAction::Delete(expr) => NormAction::Delete(expr.map_def_use(fvar, false)),
NormAction::Union(lhs, rhs) => NormAction::Union(fvar(*lhs, false), fvar(*rhs, false)),
NormAction::Panic(msg) => NormAction::Panic(msg.clone()),
Expand All @@ -850,7 +857,9 @@ impl ToSexp for Action {
Action::Union(lhs, rhs) => list!("union", lhs, rhs),
Action::Delete(lhs, args) => list!("delete", list!(lhs, ++ args)),
Action::Extract(expr, variants) => list!("extract", expr, variants),
Action::Print(expr) => list!("print", expr),
Action::Print(expr, print_tree) => {
list!("print", expr, if *print_tree { "" } else { ":dag" })
}
Action::Panic(msg) => list!("panic", format!("\"{}\"", msg.clone())),
Action::Expr(e) => e.to_sexp(),
}
Expand All @@ -868,7 +877,7 @@ impl Action {
Action::Delete(lhs, args) => Action::Delete(*lhs, args.iter().map(f).collect()),
Action::Union(lhs, rhs) => Action::Union(f(lhs), f(rhs)),
Action::Extract(expr, variants) => Action::Extract(f(expr), f(variants)),
Action::Print(expr) => Action::Print(f(expr)),
Action::Print(expr, print_tree) => Action::Print(f(expr), *print_tree),
Action::Panic(msg) => Action::Panic(msg.clone()),
Action::Expr(e) => Action::Expr(f(e)),
}
Expand All @@ -889,7 +898,7 @@ impl Action {
Action::Extract(expr, variants) => {
Action::Extract(expr.subst(canon), variants.subst(canon))
}
Action::Print(expr) => Action::Print(expr.subst(canon)),
Action::Print(expr, print_tree) => Action::Print(expr.subst(canon), *print_tree),
Action::Panic(msg) => Action::Panic(msg.clone()),
Action::Expr(e) => Action::Expr(e.subst(canon)),
}
Expand Down Expand Up @@ -1044,10 +1053,10 @@ impl NormRule {
used.insert(*variants);
head.push(Action::Extract(new_expr, new_expr2));
}
NormAction::Print(symbol) => {
NormAction::Print(symbol, print_tree) => {
let new_expr = subst.get(symbol).cloned().unwrap_or(Expr::Var(*symbol));
used.insert(*symbol);
head.push(Action::Print(new_expr));
head.push(Action::Print(new_expr, *print_tree));
}
NormAction::LetLit(symbol, lit) => {
subst.insert(*symbol, Expr::Lit(lit.clone()));
Expand Down
6 changes: 3 additions & 3 deletions src/ast/parse.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ Command: Command = {
LParen "query-extract" <variants:(":variants" <UNum>)?> <fact:Fact> RParen => Command::Extract { fact, variants: variants.unwrap_or(0) },
LParen "check" <(Fact)*> RParen => Command::Check(<>),
LParen "check-proof" RParen => Command::CheckProof,
LParen "get-proof" <(Fact)*> RParen => Command::GetProof(<>),
LParen "lookup-proof" <Expr> RParen => Command::LookupProof(<>),
LParen "get-proof" <facts:(Fact)*> <print_tree:":dag"?> RParen => Command::GetProof(facts, !print_tree.is_some()),
LParen "lookup-proof" <expr:Expr> <print_tree:":dag"?> RParen => Command::LookupProof(expr, !print_tree.is_some()),
LParen "run-schedule" <Schedule*> RParen => Command::RunSchedule(Schedule::Sequence(<>)),
LParen "push" <UNum?> RParen => Command::Push(<>.unwrap_or(1)),
LParen "pop" <UNum?> RParen => Command::Pop(<>.unwrap_or(1)),
Expand Down Expand Up @@ -113,7 +113,7 @@ NonLetAction: Action = {
LParen "panic" <msg:String> RParen => Action::Panic(msg),
LParen "extract" <expr:Expr> RParen => Action::Extract(expr, Expr::Lit(Literal::Int(0))),
LParen "extract" <expr:Expr> <variants:Expr> RParen => Action::Extract(expr, variants),
LParen "print" <expr:Expr> RParen => Action::Print(expr),
LParen "print" <expr:Expr> <print_tree:":dag"?> RParen => Action::Print(expr, !print_tree.is_some()),
<e:CallExpr> => Action::Expr(e),
}

Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ impl EGraph {
for expr in exprs {
use std::io::Write;
let res = self.extract_expr(expr, 1)?;
writeln!(f, "{}", res.termdag.to_string(&res.expr))
writeln!(f, "{}", res.termdag.to_string(&res.expr, &true))
.map_err(|e| Error::IoError(filename.clone(), e))?;
}

Expand All @@ -1010,10 +1010,10 @@ impl EGraph {
}
}

pub fn term_to_string(&mut self, term: Value) -> String {
pub fn term_to_string(&mut self, term: Value, print_tree:&bool) -> String {
let mut termdag = TermDag::default();
let (_cost, expr) = self.print(term, &mut termdag, &self.get_sort(&term).unwrap());
termdag.to_string(&expr)
termdag.to_string(&expr, print_tree)
}

// Extract an expression from the current state, returning the cost, the extracted expression and some number
Expand Down
12 changes: 6 additions & 6 deletions src/proof_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ impl<'a> ProofChecker<'a> {
}

fn string_from_term(&self, term: Term) -> String {
let with_quotes = self.termdag.to_string(&term);
let with_quotes = self.termdag.to_string(&term, &true);
assert!(with_quotes.len() >= 2);
assert!(with_quotes.starts_with('"'));
assert!(with_quotes.ends_with('"'));
Expand Down Expand Up @@ -188,7 +188,7 @@ impl<'a> ProofChecker<'a> {
assert!(args.len() == 1);
let unwrapped = self.termdag.get_term(args[0]);
let Term::App(data_head, args) = unwrapped.clone() else {
panic!("Expected a datatype wrapper. Got: {}", self.termdag.to_string(&unwrapped))
panic!("Expected a datatype wrapper. Got: {}", self.termdag.to_string(&unwrapped, &true))
};
assert!(data_head.as_str() == stripped);
(
Expand Down Expand Up @@ -252,7 +252,7 @@ impl<'a> ProofChecker<'a> {
name,
"Expected operators to match: {} != {}",
&NormExpr::Call(*op, body.clone()),
self.termdag.to_string(&current_term)
self.termdag.to_string(&current_term, &true)
);
assert_eq!(body.len(), inputs.len());
for (arg, targ) in body.iter().zip(inputs) {
Expand Down Expand Up @@ -284,8 +284,8 @@ impl<'a> ProofChecker<'a> {
self.get_term(&rule_ctx, *var_a),
self.get_term(&rule_ctx, *var_b),
"Expected terms to be equal in rule proof: {} != {} with variables {} and {}",
self.termdag.to_string(&self.get_term(&rule_ctx, *var_a)),
self.termdag.to_string(&self.get_term(&rule_ctx, *var_b)),
self.termdag.to_string(&self.get_term(&rule_ctx, *var_a), &true),
self.termdag.to_string(&self.get_term(&rule_ctx, *var_b), &true),
var_a,
var_b
);
Expand Down Expand Up @@ -406,7 +406,7 @@ impl<'a> ProofChecker<'a> {
let output = primitive.apply(&body_vals, self.egraph).unwrap_or_else(|| {
panic!(
"Proof checking failed- primitive did not return a value. Primitive term: {}",
self.termdag.to_string(&term)
self.termdag.to_string(&term, &true)
)
});

Expand Down
8 changes: 6 additions & 2 deletions src/proofs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,12 @@ impl ProofState {
}]
}
NCommand::GetProof(..) => panic!("GetProof should have been desugared"),
NCommand::LookupProof(expr) => self
.parse_actions(vec![format!("(print {})", self.get_proof(expr, None))])
NCommand::LookupProof(expr, print_tree) => self
.parse_actions(vec![format!(
"(print {} {})",
self.get_proof(expr, None),
if *print_tree { "" } else { ":dag" }
)])
.into_iter()
.map(Command::Action)
.collect(),
Expand Down
2 changes: 1 addition & 1 deletion src/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl EGraph {
} else {
let mut termdag = TermDag::default();
let term = sort.make_expr(self, *value, &mut termdag);
termdag.to_string(&term)
termdag.to_string(&term, &true)
};
egraph.nodes.insert(
node_id.clone(),
Expand Down
Loading