From e61c69b50cf0f4d2720b9560e6054e17c85ea517 Mon Sep 17 00:00:00 2001 From: tnv5 Date: Tue, 5 Sep 2023 19:44:44 -0400 Subject: [PATCH 01/10] added tree flag --- src/lib.rs | 4 +-- src/proof_checker.rs | 12 ++++---- src/serialize.rs | 2 +- src/termdag.rs | 73 +++++++++++++++++++++++--------------------- src/typecheck.rs | 7 +++-- 5 files changed, 53 insertions(+), 45 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index d266ddc47..d34de242e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -993,7 +993,7 @@ impl EGraph { for expr in exprs { use std::io::Write; let res = self.extract_expr(expr, 1)?; - writeln!(f, "{}", res.termdag.to_string(&res.expr)) + writeln!(f, "{}", res.termdag.to_string(&res.expr, true)) .map_err(|e| Error::IoError(filename.clone(), e))?; } @@ -1013,7 +1013,7 @@ impl EGraph { pub fn term_to_string(&mut self, term: Value) -> String { let mut termdag = TermDag::default(); let (_cost, expr) = self.print(term, &mut termdag, &self.get_sort(&term).unwrap()); - termdag.to_string(&expr) + termdag.to_string(&expr, true) } // Extract an expression from the current state, returning the cost, the extracted expression and some number diff --git a/src/proof_checker.rs b/src/proof_checker.rs index 406ae6bfd..0c609144d 100644 --- a/src/proof_checker.rs +++ b/src/proof_checker.rs @@ -140,7 +140,7 @@ impl<'a> ProofChecker<'a> { } fn string_from_term(&self, term: Term) -> String { - let with_quotes = self.termdag.to_string(&term); + let with_quotes = self.termdag.to_string(&term, true); assert!(with_quotes.len() >= 2); assert!(with_quotes.starts_with('"')); assert!(with_quotes.ends_with('"')); @@ -188,7 +188,7 @@ impl<'a> ProofChecker<'a> { assert!(args.len() == 1); let unwrapped = self.termdag.get_term(args[0]); let Term::App(data_head, args) = unwrapped.clone() else { - panic!("Expected a datatype wrapper. Got: {}", self.termdag.to_string(&unwrapped)) + panic!("Expected a datatype wrapper. Got: {}", self.termdag.to_string(&unwrapped, true)) }; assert!(data_head.as_str() == stripped); ( @@ -252,7 +252,7 @@ impl<'a> ProofChecker<'a> { name, "Expected operators to match: {} != {}", &NormExpr::Call(*op, body.clone()), - self.termdag.to_string(¤t_term) + self.termdag.to_string(¤t_term, true) ); assert_eq!(body.len(), inputs.len()); for (arg, targ) in body.iter().zip(inputs) { @@ -284,8 +284,8 @@ impl<'a> ProofChecker<'a> { self.get_term(&rule_ctx, *var_a), self.get_term(&rule_ctx, *var_b), "Expected terms to be equal in rule proof: {} != {} with variables {} and {}", - self.termdag.to_string(&self.get_term(&rule_ctx, *var_a)), - self.termdag.to_string(&self.get_term(&rule_ctx, *var_b)), + self.termdag.to_string(&self.get_term(&rule_ctx, *var_a), true), + self.termdag.to_string(&self.get_term(&rule_ctx, *var_b), true), var_a, var_b ); @@ -406,7 +406,7 @@ impl<'a> ProofChecker<'a> { let output = primitive.apply(&body_vals, self.egraph).unwrap_or_else(|| { panic!( "Proof checking failed- primitive did not return a value. Primitive term: {}", - self.termdag.to_string(&term) + self.termdag.to_string(&term, true) ) }); diff --git a/src/serialize.rs b/src/serialize.rs index e825c50e9..50ace2cb9 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -159,7 +159,7 @@ impl EGraph { } else { let mut termdag = TermDag::default(); let term = sort.make_expr(self, *value, &mut termdag); - termdag.to_string(&term) + termdag.to_string(&term, true) }; egraph.nodes.insert( node_id.clone(), diff --git a/src/termdag.rs b/src/termdag.rs index bfdff9783..87ef6b60e 100644 --- a/src/termdag.rs +++ b/src/termdag.rs @@ -174,8 +174,8 @@ impl TermDag { New: {}\n", old.clone().unwrap(), node, - self.to_string(&old.unwrap()), - self.to_string(&node), + self.to_string(&old.unwrap(), true), + self.to_string(&node, true), ); new_id } @@ -197,40 +197,45 @@ impl TermDag { } } - pub fn to_string(&self, term: &Term) -> String { - let mut stored = HashMap::::default(); - let mut seen = HashSet::default(); - let id = self.get_id(term); - // use a stack to avoid stack overflow - let mut stack = vec![id]; - while !stack.is_empty() { - let next = stack.pop().unwrap(); - - match self.nodes.get(&next).unwrap().clone() { - Term::App(name, children) => { - if seen.contains(&next) { - let mut str = String::new(); - str.push_str(&format!("({}", name)); - for c in children.iter() { - str.push_str(&format!(" {}", stored[c])); - } - str.push(')'); - stored.insert(next, str); - } else { - seen.insert(next); - stack.push(next); - for c in children.iter().rev() { - stack.push(*c); + pub fn to_string(&self, term: &Term, tree: bool) -> String { + if tree { + let mut stored = HashMap::::default(); + let mut seen = HashSet::default(); + let id = self.get_id(term); + // use a stack to avoid stack overflow + let mut stack = vec![id]; + while !stack.is_empty() { + let next = stack.pop().unwrap(); + + match self.nodes.get(&next).unwrap().clone() { + Term::App(name, children) => { + if seen.contains(&next) { + let mut str = String::new(); + str.push_str(&format!("({}", name)); + for c in children.iter() { + str.push_str(&format!(" {}", stored[c])); + } + str.push(')'); + stored.insert(next, str); + } else { + seen.insert(next); + stack.push(next); + for c in children.iter().rev() { + stack.push(*c); + } } } - } - Term::Lit(lit) => { - stored.insert(next, format!("{}", lit)); + Term::Lit(lit) => { + stored.insert(next, format!("{}", lit)); + } } } - } - stored.get(&id).unwrap().clone() + stored.get(&id).unwrap().clone() + } else { + // TODO + String::new() + } } pub fn display_entry(&self, entry: &FunctionEntry) -> String { @@ -238,14 +243,14 @@ impl TermDag { format!( "({} {})", entry.name, - ListDisplay(entry.inputs.iter().map(|t| self.to_string(t)), " "), + ListDisplay(entry.inputs.iter().map(|t| self.to_string(t, true)), " "), ) } else { format!( "({} {}) -> {}", entry.name, - ListDisplay(entry.inputs.iter().map(|t| self.to_string(t)), " "), - self.to_string(&entry.output) + ListDisplay(entry.inputs.iter().map(|t| self.to_string(t, true)), " "), + self.to_string(&entry.output, true) ) } } diff --git a/src/typecheck.rs b/src/typecheck.rs index 6365c6e32..8088cc842 100644 --- a/src/typecheck.rs +++ b/src/typecheck.rs @@ -906,7 +906,10 @@ impl EGraph { .name_to_sort(&values[0].tag) .unwrap(), ); - log::info!("extracted with cost {cost}: {}", termdag.to_string(&expr)); + log::info!( + "extracted with cost {cost}: {}", + termdag.to_string(&expr, true) + ); } else { if variants < 0 { panic!("Cannot extract negative number of variants"); @@ -915,7 +918,7 @@ impl EGraph { self.extract_variants(values[0], variants as usize, &mut termdag); log::info!("extracted variants:"); for expr in extracted { - log::info!(" {}", termdag.to_string(&expr)); + log::info!(" {}", termdag.to_string(&expr, true)); } } From ba6253a7deb9d88e12a9d9c14b99148aa16b1a45 Mon Sep 17 00:00:00 2001 From: tnv5 Date: Wed, 6 Sep 2023 00:14:59 -0400 Subject: [PATCH 02/10] added termdag tree flag --- src/ast/desugar.rs | 4 ++-- src/ast/mod.rs | 22 ++++++++++++---------- src/ast/parse.lalrpop | 2 +- src/lib.rs | 6 +++--- src/proof_checker.rs | 12 ++++++------ src/serialize.rs | 2 +- src/termdag.rs | 16 +++++++++------- src/typecheck.rs | 18 +++++++++--------- src/typechecking.rs | 4 ++-- 9 files changed, 45 insertions(+), 41 deletions(-) diff --git a/src/ast/desugar.rs b/src/ast/desugar.rs index 6877193da..b67e53a3e 100644 --- a/src/ast/desugar.rs +++ b/src/ast/desugar.rs @@ -272,9 +272,9 @@ fn flatten_actions(actions: &Vec, desugar: &mut Desugar) -> Vec { + Action::Print(expr, print_tree) => { let added = add_expr(expr.clone(), &mut res); - res.push(NormAction::Print(added)); + res.push(NormAction::Print(added, *print_tree)); } Action::Delete(symbol, exprs) => { let del = NormAction::Delete(NormExpr::Call( diff --git a/src/ast/mod.rs b/src/ast/mod.rs index d2b561a75..2a76d42db 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -762,7 +762,7 @@ pub enum Action { Union(Expr, Expr), // What to extract and how many variants Extract(Expr, Expr), - Print(Expr), + Print(Expr, bool), Panic(String), Expr(Expr), // If(Expr, Action, Action), @@ -774,7 +774,7 @@ pub enum NormAction { LetVar(Symbol, Symbol), LetLit(Symbol, Literal), Extract(Symbol, Symbol), - Print(Symbol), + Print(Symbol, bool), Set(NormExpr, Symbol), Delete(NormExpr), Union(Symbol, Symbol), @@ -795,7 +795,7 @@ impl NormAction { NormAction::Extract(symbol, variants) => { Action::Extract(Expr::Var(*symbol), Expr::Var(*variants)) } - NormAction::Print(symbol) => Action::Print(Expr::Var(*symbol)), + NormAction::Print(symbol, print_tree) => Action::Print(Expr::Var(*symbol), *print_tree), NormAction::Delete(NormExpr::Call(symbol, args)) => { Action::Delete(*symbol, args.iter().map(|s| Expr::Var(*s)).collect()) } @@ -811,7 +811,7 @@ impl NormAction { NormAction::LetLit(symbol, lit) => NormAction::LetLit(*symbol, lit.clone()), NormAction::Set(expr, other) => NormAction::Set(f(expr), *other), NormAction::Extract(var, variants) => NormAction::Extract(*var, *variants), - NormAction::Print(var) => NormAction::Print(*var), + NormAction::Print(var, print_tree) => NormAction::Print(*var, *print_tree), NormAction::Delete(expr) => NormAction::Delete(f(expr)), NormAction::Union(lhs, rhs) => NormAction::Union(*lhs, *rhs), NormAction::Panic(msg) => NormAction::Panic(msg.clone()), @@ -834,7 +834,7 @@ impl NormAction { NormAction::Extract(var, variants) => { NormAction::Extract(fvar(*var, false), fvar(*variants, false)) } - NormAction::Print(var) => NormAction::Print(fvar(*var, false)), + NormAction::Print(var, print_tree) => NormAction::Print(fvar(*var, false), *print_tree), NormAction::Delete(expr) => NormAction::Delete(expr.map_def_use(fvar, false)), NormAction::Union(lhs, rhs) => NormAction::Union(fvar(*lhs, false), fvar(*rhs, false)), NormAction::Panic(msg) => NormAction::Panic(msg.clone()), @@ -850,7 +850,9 @@ impl ToSexp for Action { Action::Union(lhs, rhs) => list!("union", lhs, rhs), Action::Delete(lhs, args) => list!("delete", list!(lhs, ++ args)), Action::Extract(expr, variants) => list!("extract", expr, variants), - Action::Print(expr) => list!("print", expr), + Action::Print(expr, print_tree) => { + list!("print", expr, if *print_tree { "" } else { ":dag" }) + } Action::Panic(msg) => list!("panic", format!("\"{}\"", msg.clone())), Action::Expr(e) => e.to_sexp(), } @@ -868,7 +870,7 @@ impl Action { Action::Delete(lhs, args) => Action::Delete(*lhs, args.iter().map(f).collect()), Action::Union(lhs, rhs) => Action::Union(f(lhs), f(rhs)), Action::Extract(expr, variants) => Action::Extract(f(expr), f(variants)), - Action::Print(expr) => Action::Print(f(expr)), + Action::Print(expr, print_tree) => Action::Print(f(expr), *print_tree), Action::Panic(msg) => Action::Panic(msg.clone()), Action::Expr(e) => Action::Expr(f(e)), } @@ -889,7 +891,7 @@ impl Action { Action::Extract(expr, variants) => { Action::Extract(expr.subst(canon), variants.subst(canon)) } - Action::Print(expr) => Action::Print(expr.subst(canon)), + Action::Print(expr, print_tree) => Action::Print(expr.subst(canon), *print_tree), Action::Panic(msg) => Action::Panic(msg.clone()), Action::Expr(e) => Action::Expr(e.subst(canon)), } @@ -1044,10 +1046,10 @@ impl NormRule { used.insert(*variants); head.push(Action::Extract(new_expr, new_expr2)); } - NormAction::Print(symbol) => { + NormAction::Print(symbol, print_tree) => { let new_expr = subst.get(symbol).cloned().unwrap_or(Expr::Var(*symbol)); used.insert(*symbol); - head.push(Action::Print(new_expr)); + head.push(Action::Print(new_expr, *print_tree)); } NormAction::LetLit(symbol, lit) => { subst.insert(*symbol, Expr::Lit(lit.clone())); diff --git a/src/ast/parse.lalrpop b/src/ast/parse.lalrpop index 490adde03..c0dd13387 100644 --- a/src/ast/parse.lalrpop +++ b/src/ast/parse.lalrpop @@ -113,7 +113,7 @@ NonLetAction: Action = { LParen "panic" RParen => Action::Panic(msg), LParen "extract" RParen => Action::Extract(expr, Expr::Lit(Literal::Int(0))), LParen "extract" RParen => Action::Extract(expr, variants), - LParen "print" RParen => Action::Print(expr), + LParen "print" RParen => Action::Print(expr, !print_tree.is_some()), => Action::Expr(e), } diff --git a/src/lib.rs b/src/lib.rs index d34de242e..167dc5f22 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -993,7 +993,7 @@ impl EGraph { for expr in exprs { use std::io::Write; let res = self.extract_expr(expr, 1)?; - writeln!(f, "{}", res.termdag.to_string(&res.expr, true)) + writeln!(f, "{}", res.termdag.to_string(&res.expr, &true)) .map_err(|e| Error::IoError(filename.clone(), e))?; } @@ -1010,10 +1010,10 @@ impl EGraph { } } - pub fn term_to_string(&mut self, term: Value) -> String { + pub fn term_to_string(&mut self, term: Value, print_tree:&bool) -> String { let mut termdag = TermDag::default(); let (_cost, expr) = self.print(term, &mut termdag, &self.get_sort(&term).unwrap()); - termdag.to_string(&expr, true) + termdag.to_string(&expr, print_tree) } // Extract an expression from the current state, returning the cost, the extracted expression and some number diff --git a/src/proof_checker.rs b/src/proof_checker.rs index 0c609144d..ccd31e90f 100644 --- a/src/proof_checker.rs +++ b/src/proof_checker.rs @@ -140,7 +140,7 @@ impl<'a> ProofChecker<'a> { } fn string_from_term(&self, term: Term) -> String { - let with_quotes = self.termdag.to_string(&term, true); + let with_quotes = self.termdag.to_string(&term, &true); assert!(with_quotes.len() >= 2); assert!(with_quotes.starts_with('"')); assert!(with_quotes.ends_with('"')); @@ -188,7 +188,7 @@ impl<'a> ProofChecker<'a> { assert!(args.len() == 1); let unwrapped = self.termdag.get_term(args[0]); let Term::App(data_head, args) = unwrapped.clone() else { - panic!("Expected a datatype wrapper. Got: {}", self.termdag.to_string(&unwrapped, true)) + panic!("Expected a datatype wrapper. Got: {}", self.termdag.to_string(&unwrapped, &true)) }; assert!(data_head.as_str() == stripped); ( @@ -252,7 +252,7 @@ impl<'a> ProofChecker<'a> { name, "Expected operators to match: {} != {}", &NormExpr::Call(*op, body.clone()), - self.termdag.to_string(¤t_term, true) + self.termdag.to_string(¤t_term, &true) ); assert_eq!(body.len(), inputs.len()); for (arg, targ) in body.iter().zip(inputs) { @@ -284,8 +284,8 @@ impl<'a> ProofChecker<'a> { self.get_term(&rule_ctx, *var_a), self.get_term(&rule_ctx, *var_b), "Expected terms to be equal in rule proof: {} != {} with variables {} and {}", - self.termdag.to_string(&self.get_term(&rule_ctx, *var_a), true), - self.termdag.to_string(&self.get_term(&rule_ctx, *var_b), true), + self.termdag.to_string(&self.get_term(&rule_ctx, *var_a), &true), + self.termdag.to_string(&self.get_term(&rule_ctx, *var_b), &true), var_a, var_b ); @@ -406,7 +406,7 @@ impl<'a> ProofChecker<'a> { let output = primitive.apply(&body_vals, self.egraph).unwrap_or_else(|| { panic!( "Proof checking failed- primitive did not return a value. Primitive term: {}", - self.termdag.to_string(&term, true) + self.termdag.to_string(&term, &true) ) }); diff --git a/src/serialize.rs b/src/serialize.rs index 50ace2cb9..a5888e559 100644 --- a/src/serialize.rs +++ b/src/serialize.rs @@ -159,7 +159,7 @@ impl EGraph { } else { let mut termdag = TermDag::default(); let term = sort.make_expr(self, *value, &mut termdag); - termdag.to_string(&term, true) + termdag.to_string(&term, &true) }; egraph.nodes.insert( node_id.clone(), diff --git a/src/termdag.rs b/src/termdag.rs index 87ef6b60e..759dc622c 100644 --- a/src/termdag.rs +++ b/src/termdag.rs @@ -174,8 +174,8 @@ impl TermDag { New: {}\n", old.clone().unwrap(), node, - self.to_string(&old.unwrap(), true), - self.to_string(&node, true), + self.to_string(&old.unwrap(), &true), + self.to_string(&node, &true), ); new_id } @@ -197,8 +197,8 @@ impl TermDag { } } - pub fn to_string(&self, term: &Term, tree: bool) -> String { - if tree { + pub fn to_string(&self, term: &Term, print_tree: &bool) -> String { + if *print_tree { let mut stored = HashMap::::default(); let mut seen = HashSet::default(); let id = self.get_id(term); @@ -234,6 +234,8 @@ impl TermDag { stored.get(&id).unwrap().clone() } else { // TODO + let mut stored = HashMap::::default(); + String::new() } } @@ -243,14 +245,14 @@ impl TermDag { format!( "({} {})", entry.name, - ListDisplay(entry.inputs.iter().map(|t| self.to_string(t, true)), " "), + ListDisplay(entry.inputs.iter().map(|t| self.to_string(t, &true)), " "), ) } else { format!( "({} {}) -> {}", entry.name, - ListDisplay(entry.inputs.iter().map(|t| self.to_string(t, true)), " "), - self.to_string(&entry.output, true) + ListDisplay(entry.inputs.iter().map(|t| self.to_string(t, &true)), " "), + self.to_string(&entry.output, &true) ) } } diff --git a/src/typecheck.rs b/src/typecheck.rs index 8088cc842..17d001827 100644 --- a/src/typecheck.rs +++ b/src/typecheck.rs @@ -483,9 +483,9 @@ impl<'a> ActionChecker<'a> { self.instructions.push(Instruction::Extract(2)); Ok(()) } - Action::Print(expr) => { + Action::Print(expr, print_tree) => { let (_, _ty) = self.infer_expr(expr)?; - self.instructions.push(Instruction::Print); + self.instructions.push(Instruction::Print(*print_tree)); Ok(()) } Action::Delete(f, args) => { @@ -670,7 +670,7 @@ enum Instruction { Set(Symbol), Union(usize), Extract(usize), - Print, + Print(bool), Panic(String), Pop, } @@ -785,7 +785,7 @@ impl EGraph { _ => { let terms = values .iter() - .map(|v| self.term_to_string(*v)) + .map(|v| self.term_to_string(*v, &true)) .collect::>(); return Err(Error::NotFoundError(NotFoundError(Expr::Var( format!( @@ -800,7 +800,7 @@ impl EGraph { } else { let terms = values .iter() - .map(|v| self.term_to_string(*v)) + .map(|v| self.term_to_string(*v, &true)) .collect::>(); return Err(Error::NotFoundError(NotFoundError(Expr::Var( format!( @@ -882,9 +882,9 @@ impl EGraph { Instruction::Union(_arity) => { panic!("term encoding gets rid of union"); } - Instruction::Print => { + Instruction::Print(print_tree) => { let to_print = stack.pop().unwrap(); - let extracted = self.term_to_string(to_print); + let extracted = self.term_to_string(to_print, print_tree); log::info!("printing: {}", extracted); } Instruction::Extract(arity) => { @@ -908,7 +908,7 @@ impl EGraph { ); log::info!( "extracted with cost {cost}: {}", - termdag.to_string(&expr, true) + termdag.to_string(&expr, &true) ); } else { if variants < 0 { @@ -918,7 +918,7 @@ impl EGraph { self.extract_variants(values[0], variants as usize, &mut termdag); log::info!("extracted variants:"); for expr in extracted { - log::info!(" {}", termdag.to_string(&expr, true)); + log::info!(" {}", termdag.to_string(&expr, &true)); } } diff --git a/src/typechecking.rs b/src/typechecking.rs index 9f9e1c4dc..bf29b94ed 100644 --- a/src/typechecking.rs +++ b/src/typechecking.rs @@ -397,7 +397,7 @@ impl TypeInfo { assert_bound(var, let_bound); assert_bound(variants, let_bound); } - NormAction::Print(var) => { + NormAction::Print(var, print_tree) => { assert_bound(var, let_bound); } NormAction::Union(v1, v2) => { @@ -466,7 +466,7 @@ impl TypeInfo { } } NormAction::Extract(_var, _variants) => {} - NormAction::Print(_var) => {} + NormAction::Print(_var, print_tree) => {} NormAction::LetVar(var1, var2) => { let var2_type = self.lookup(ctx, *var2)?; self.introduce_binding(ctx, *var1, var2_type, is_global)?; From 5a9952c1eac6c1924feba126b465f2a93b4420a2 Mon Sep 17 00:00:00 2001 From: tnv5 Date: Wed, 6 Sep 2023 00:35:53 -0400 Subject: [PATCH 03/10] added flag to get-proof --- src/ast/desugar.rs | 2 +- src/ast/mod.rs | 25 ++++++++++++++++--------- src/ast/parse.lalrpop | 2 +- src/terms.rs | 2 +- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/ast/desugar.rs b/src/ast/desugar.rs index b67e53a3e..afc9b3221 100644 --- a/src/ast/desugar.rs +++ b/src/ast/desugar.rs @@ -636,7 +636,7 @@ pub(crate) fn desugar_command( vec![NCommand::Check(flatten_facts(&facts, desugar))] } Command::CheckProof => vec![NCommand::CheckProof], - Command::GetProof(query) => desugar.desugar_get_proof(&query)?, + Command::GetProof(query, print_tree) => desugar.desugar_get_proof(&query)?, Command::LookupProof(expr) => match &expr { Expr::Call(f, args) => { if !args.is_empty() { diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 2a76d42db..5ad5b6335 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -116,7 +116,7 @@ pub enum NCommand { name: Symbol, file: String, }, - GetProof(Vec), + GetProof(Vec, bool), LookupProof(NormExpr), } @@ -151,9 +151,10 @@ impl NCommand { Command::Check(facts.iter().map(|fact| fact.to_fact()).collect()) } NCommand::CheckProof => Command::CheckProof, - NCommand::GetProof(query) => { - Command::GetProof(query.iter().map(|fact| fact.to_fact()).collect::>()) - } + NCommand::GetProof(query, print_tree) => Command::GetProof( + query.iter().map(|fact| fact.to_fact()).collect::>(), + *print_tree, + ), NCommand::LookupProof(expr) => Command::LookupProof(expr.to_expr()), NCommand::PrintTable(name, n) => Command::PrintTable(*name, *n), NCommand::PrintSize(name) => Command::PrintSize(*name), @@ -196,9 +197,10 @@ impl NCommand { NCommand::Check(facts.iter().map(|fact| fact.map_exprs(f)).collect()) } NCommand::CheckProof => NCommand::CheckProof, - NCommand::GetProof(query) => { - NCommand::GetProof(query.iter().map(|fact| fact.map_exprs(f)).collect()) - } + NCommand::GetProof(query, print_tree) => NCommand::GetProof( + query.iter().map(|fact| fact.map_exprs(f)).collect(), + *print_tree, + ), NCommand::LookupProof(expr) => NCommand::LookupProof(f(expr)), NCommand::PrintTable(name, n) => NCommand::PrintTable(*name, *n), NCommand::PrintSize(name) => NCommand::PrintSize(*name), @@ -368,7 +370,7 @@ pub enum Command { // TODO: this could just become an empty query Check(Vec), CheckProof, - GetProof(Vec), + GetProof(Vec, bool), LookupProof(Expr), PrintTable(Symbol, usize), PrintSize(Symbol), @@ -410,7 +412,12 @@ impl ToSexp for Command { Command::Extract { variants, fact } => list!("extract", ":variants", variants, fact), Command::Check(facts) => list!("check", ++ facts), Command::CheckProof => list!("check-proof"), - Command::GetProof(query) => list!("get-proof", ++ query), + Command::GetProof(query, print_tree) => { + list!( + list!("get-proof", ++ query), + if *print_tree { "" } else { ":dag" } + ) + } Command::LookupProof(expr) => list!("lookup-proof", expr), Command::Push(n) => list!("push", n), Command::Pop(n) => list!("pop", n), diff --git a/src/ast/parse.lalrpop b/src/ast/parse.lalrpop index c0dd13387..50353d4e3 100644 --- a/src/ast/parse.lalrpop +++ b/src/ast/parse.lalrpop @@ -78,7 +78,7 @@ Command: Command = { LParen "query-extract" )?> RParen => Command::Extract { fact, variants: variants.unwrap_or(0) }, LParen "check" <(Fact)*> RParen => Command::Check(<>), LParen "check-proof" RParen => Command::CheckProof, - LParen "get-proof" <(Fact)*> RParen => Command::GetProof(<>), + LParen "get-proof" RParen => Command::GetProof(fact, !print_tree.is_some()), LParen "lookup-proof" RParen => Command::LookupProof(<>), LParen "run-schedule" RParen => Command::RunSchedule(Schedule::Sequence(<>)), LParen "push" RParen => Command::Push(<>.unwrap_or(1)), diff --git a/src/terms.rs b/src/terms.rs index c9c41eb76..a7b9a8687 100644 --- a/src/terms.rs +++ b/src/terms.rs @@ -358,7 +358,7 @@ impl ProofState { res.extend(with_term_encoding); res.push(Command::Fail(Box::new(last))); } - NCommand::GetProof(_query) => { + NCommand::GetProof(..) => { panic!("GetProof should be desugared"); } NCommand::LookupProof(..) From 5cdebcb8aa0a2973467a382ecccfaea5f75cc851 Mon Sep 17 00:00:00 2001 From: tnv5 Date: Wed, 6 Sep 2023 02:27:46 -0400 Subject: [PATCH 04/10] added flag to lookup-proof and fixed desugaring --- src/ast/desugar.rs | 11 ++++++----- src/ast/mod.rs | 10 +++++----- src/ast/parse.lalrpop | 4 ++-- src/proofs.rs | 2 +- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/ast/desugar.rs b/src/ast/desugar.rs index afc9b3221..0d0541452 100644 --- a/src/ast/desugar.rs +++ b/src/ast/desugar.rs @@ -636,13 +636,13 @@ pub(crate) fn desugar_command( vec![NCommand::Check(flatten_facts(&facts, desugar))] } Command::CheckProof => vec![NCommand::CheckProof], - Command::GetProof(query, print_tree) => desugar.desugar_get_proof(&query)?, - Command::LookupProof(expr) => match &expr { + Command::GetProof(query, print_tree) => desugar.desugar_get_proof(&query, print_tree)?, + Command::LookupProof(expr, print_tree) => match &expr { Expr::Call(f, args) => { if !args.is_empty() { return Err(Error::LookupProofRequiresExpr(expr.to_string())); } - vec![NCommand::LookupProof(NormExpr::Call(*f, vec![]))] + vec![NCommand::LookupProof(NormExpr::Call(*f, vec![]), print_tree)] } _ => { return Err(Error::LookupProofRequiresExpr(expr.to_string())); @@ -851,11 +851,12 @@ impl Desugar { res } - fn desugar_get_proof(&mut self, query: &Vec) -> Result, Error> { + fn desugar_get_proof(&mut self, query: &Vec, print_tree: bool) -> Result, Error> { let proof_ruleset = self.fresh().as_str(); let result_sort = self.fresh().as_str(); let result_func = self.fresh().as_str(); let query_str = ListDisplay(query, " "); + let tree_or_dag = if print_tree { "" } else { ":dag" }; desugar_commands( self.parse_program(&format!( " @@ -867,7 +868,7 @@ impl Desugar { (({result_func})) :ruleset {proof_ruleset}) (run {proof_ruleset} 1) - (lookup-proof ({result_func})) + (lookup-proof ({result_func} {tree_or_dag})) ", )) .unwrap(), diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 5ad5b6335..747b8bed5 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -117,7 +117,7 @@ pub enum NCommand { file: String, }, GetProof(Vec, bool), - LookupProof(NormExpr), + LookupProof(NormExpr, bool), } impl NormCommand { @@ -155,7 +155,7 @@ impl NCommand { query.iter().map(|fact| fact.to_fact()).collect::>(), *print_tree, ), - NCommand::LookupProof(expr) => Command::LookupProof(expr.to_expr()), + NCommand::LookupProof(expr, print_tree) => Command::LookupProof(expr.to_expr(), *print_tree), NCommand::PrintTable(name, n) => Command::PrintTable(*name, *n), NCommand::PrintSize(name) => Command::PrintSize(*name), NCommand::Output { file, exprs } => Command::Output { @@ -201,7 +201,7 @@ impl NCommand { query.iter().map(|fact| fact.map_exprs(f)).collect(), *print_tree, ), - NCommand::LookupProof(expr) => NCommand::LookupProof(f(expr)), + NCommand::LookupProof(expr, print_tree) => NCommand::LookupProof(f(expr), *print_tree), NCommand::PrintTable(name, n) => NCommand::PrintTable(*name, *n), NCommand::PrintSize(name) => NCommand::PrintSize(*name), NCommand::Output { file, exprs } => NCommand::Output { @@ -371,7 +371,7 @@ pub enum Command { Check(Vec), CheckProof, GetProof(Vec, bool), - LookupProof(Expr), + LookupProof(Expr, bool), PrintTable(Symbol, usize), PrintSize(Symbol), Input { @@ -418,7 +418,7 @@ impl ToSexp for Command { if *print_tree { "" } else { ":dag" } ) } - Command::LookupProof(expr) => list!("lookup-proof", expr), + Command::LookupProof(expr, print_tree) => list!("lookup-proof", expr, if *print_tree { "" } else { ":dag" }), Command::Push(n) => list!("push", n), Command::Pop(n) => list!("pop", n), Command::PrintTable(name, n) => list!("print-table", name, n), diff --git a/src/ast/parse.lalrpop b/src/ast/parse.lalrpop index 50353d4e3..5bdb501ff 100644 --- a/src/ast/parse.lalrpop +++ b/src/ast/parse.lalrpop @@ -78,8 +78,8 @@ Command: Command = { LParen "query-extract" )?> RParen => Command::Extract { fact, variants: variants.unwrap_or(0) }, LParen "check" <(Fact)*> RParen => Command::Check(<>), LParen "check-proof" RParen => Command::CheckProof, - LParen "get-proof" RParen => Command::GetProof(fact, !print_tree.is_some()), - LParen "lookup-proof" RParen => Command::LookupProof(<>), + LParen "get-proof" RParen => Command::GetProof(facts, !print_tree.is_some()), + LParen "lookup-proof" RParen => Command::LookupProof(expr, !print_tree.is_some()), LParen "run-schedule" RParen => Command::RunSchedule(Schedule::Sequence(<>)), LParen "push" RParen => Command::Push(<>.unwrap_or(1)), LParen "pop" RParen => Command::Pop(<>.unwrap_or(1)), diff --git a/src/proofs.rs b/src/proofs.rs index 68d243f8b..1466c0376 100644 --- a/src/proofs.rs +++ b/src/proofs.rs @@ -364,7 +364,7 @@ impl ProofState { }] } NCommand::GetProof(..) => panic!("GetProof should have been desugared"), - NCommand::LookupProof(expr) => self + NCommand::LookupProof(expr, print_tree) => self .parse_actions(vec![format!("(print {})", self.get_proof(expr, None))]) .into_iter() .map(Command::Action) From e82410c72396e145b1d39f2884051c567d448e9a Mon Sep 17 00:00:00 2001 From: tnv5 Date: Thu, 7 Sep 2023 15:37:40 -0400 Subject: [PATCH 05/10] finished dag printing --- src/termdag.rs | 82 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 79 insertions(+), 3 deletions(-) diff --git a/src/termdag.rs b/src/termdag.rs index 759dc622c..f3be2c6fc 100644 --- a/src/termdag.rs +++ b/src/termdag.rs @@ -197,8 +197,16 @@ impl TermDag { } } + fn term_id_to_string(&self, val: &TermId) -> String { + match val { + TermId::Value(v) => format!("Value tag: {} Value bits: {}", v.tag, v.bits), + TermId::Num(n) => format!("Term Num: {}", n), + } + } + pub fn to_string(&self, term: &Term, print_tree: &bool) -> String { if *print_tree { + // Tree output let mut stored = HashMap::::default(); let mut seen = HashSet::default(); let id = self.get_id(term); @@ -233,10 +241,78 @@ impl TermDag { stored.get(&id).unwrap().clone() } else { - // TODO - let mut stored = HashMap::::default(); + // TODO refactor some adjacency list stuff + // DAG output + let mut stored: HashMap = HashMap::default(); + let mut adj_list: HashMap> = HashMap::default(); + let mut seen = HashSet::default(); + let mut term_id_map: HashMap = HashMap::default(); + let mut term_id_insertion_order: Vec = Vec::default(); + let mut counter = 0; + let id = self.get_id(term); + let mut stack = vec![id]; - String::new() + // Construct the new IDs and construct the adjacency lis + while !stack.is_empty() { + let next: TermId = stack.pop().unwrap(); + if !seen.contains(&next) { + // Give a homemade id number to it + term_id_insertion_order.push(next); + term_id_map.insert(next, counter); + counter = counter + 1; + } + match self.nodes.get(&next).unwrap().clone() { + Term::App(name, children) => { + if !seen.contains(&next) { + // Add the children to get numbered and then revisit this node + seen.insert(next); + stack.push(next); + for c in children.iter().rev() { + stack.push(*c); + } + } else { + // Construct the string for this node + let mut str = String::new(); + let mut edges: Vec = Vec::default(); + str.push_str(&format!( + "(Term: {}, Value: ({}", + term_id_map.get(&next).unwrap().to_string().as_str(), + name + )); + for c in children.iter() { + str.push_str(&format!( + " (Term: {})", + term_id_map.get(c).unwrap().to_string().as_str() + )); + edges.push(*term_id_map.get(c).unwrap()); + } + str.push_str("))"); + adj_list.insert(next, edges); + stored.insert(next, str); + } + } + Term::Lit(lit) => { + adj_list.insert(next, Vec::default()); + stored.insert( + next, + format!( + "(Term: {}, Value: {})", + term_id_map.get(&next).unwrap().to_string().as_str(), + lit + ), + ); + } + } + } + + // Construct the string + let mut str = String::new(); + str.push('\n'); + for id in term_id_insertion_order.iter() { + str.push_str(stored.get(id).unwrap()); + str.push('\n'); + } + str } } From 941a59cc76d702f102956c4ccaec015e101a9a76 Mon Sep 17 00:00:00 2001 From: tnv5 Date: Thu, 7 Sep 2023 21:21:18 -0400 Subject: [PATCH 06/10] fixed lookup-proof desugaring with dag flag --- src/ast/desugar.rs | 13 ++++++++++--- src/proofs.rs | 6 +++++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/ast/desugar.rs b/src/ast/desugar.rs index 0d0541452..bf37d9e18 100644 --- a/src/ast/desugar.rs +++ b/src/ast/desugar.rs @@ -642,7 +642,10 @@ pub(crate) fn desugar_command( if !args.is_empty() { return Err(Error::LookupProofRequiresExpr(expr.to_string())); } - vec![NCommand::LookupProof(NormExpr::Call(*f, vec![]), print_tree)] + vec![NCommand::LookupProof( + NormExpr::Call(*f, vec![]), + print_tree, + )] } _ => { return Err(Error::LookupProofRequiresExpr(expr.to_string())); @@ -851,7 +854,11 @@ impl Desugar { res } - fn desugar_get_proof(&mut self, query: &Vec, print_tree: bool) -> Result, Error> { + fn desugar_get_proof( + &mut self, + query: &Vec, + print_tree: bool, + ) -> Result, Error> { let proof_ruleset = self.fresh().as_str(); let result_sort = self.fresh().as_str(); let result_func = self.fresh().as_str(); @@ -868,7 +875,7 @@ impl Desugar { (({result_func})) :ruleset {proof_ruleset}) (run {proof_ruleset} 1) - (lookup-proof ({result_func} {tree_or_dag})) + (lookup-proof ({result_func}) {tree_or_dag}) ", )) .unwrap(), diff --git a/src/proofs.rs b/src/proofs.rs index 1466c0376..f3cced1f9 100644 --- a/src/proofs.rs +++ b/src/proofs.rs @@ -365,7 +365,11 @@ impl ProofState { } NCommand::GetProof(..) => panic!("GetProof should have been desugared"), NCommand::LookupProof(expr, print_tree) => self - .parse_actions(vec![format!("(print {})", self.get_proof(expr, None))]) + .parse_actions(vec![format!( + "(print {} {})", + self.get_proof(expr, None), + if *print_tree { "" } else { ":dag" } + )]) .into_iter() .map(Command::Action) .collect(), From 1a46df259b587f37b947598c635ca5b9d22964ab Mon Sep 17 00:00:00 2001 From: tnv5 Date: Wed, 20 Sep 2023 16:43:08 -0400 Subject: [PATCH 07/10] correct ordering of value terms --- src/termdag.rs | 40 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/src/termdag.rs b/src/termdag.rs index f3be2c6fc..967759adc 100644 --- a/src/termdag.rs +++ b/src/termdag.rs @@ -199,8 +199,8 @@ impl TermDag { fn term_id_to_string(&self, val: &TermId) -> String { match val { - TermId::Value(v) => format!("Value tag: {} Value bits: {}", v.tag, v.bits), - TermId::Num(n) => format!("Term Num: {}", n), + TermId::Value(v) => format!("Tag: {} Bits: {}", v.tag, v.bits), + TermId::Num(n) => format!("Num: {}", n), } } @@ -241,18 +241,17 @@ impl TermDag { stored.get(&id).unwrap().clone() } else { - // TODO refactor some adjacency list stuff // DAG output let mut stored: HashMap = HashMap::default(); let mut adj_list: HashMap> = HashMap::default(); let mut seen = HashSet::default(); let mut term_id_map: HashMap = HashMap::default(); let mut term_id_insertion_order: Vec = Vec::default(); + let mut value_map = HashMap::::default(); let mut counter = 0; let id = self.get_id(term); let mut stack = vec![id]; - // Construct the new IDs and construct the adjacency lis while !stack.is_empty() { let next: TermId = stack.pop().unwrap(); if !seen.contains(&next) { @@ -260,7 +259,40 @@ impl TermDag { term_id_insertion_order.push(next); term_id_map.insert(next, counter); counter = counter + 1; + if let TermId::Value(v) = next { + value_map.insert(v.bits as i32, next); + } + } + if let Term::App(_, children) = self.nodes.get(&next).unwrap().clone() { + if !seen.contains(&next) { + // Add the children to get numbered and then revisit this node + seen.insert(next); + for c in children.iter().rev() { + stack.push(*c); + } + } } + } + + let mut values = value_map.keys().collect::>(); + let mut homemade_value_ids = value_map + .iter() + .map(|(k, v)| *term_id_map.get(v).unwrap()) + .collect::>(); + + values.sort(); + homemade_value_ids.sort(); + + for (i, v) in values.iter().enumerate() { + term_id_map.insert(value_map[*v], homemade_value_ids[i]); + } + + seen.clear(); + stack = vec![id]; + + // Construct the new IDs and construct the adjacency lis + while !stack.is_empty() { + let next: TermId = stack.pop().unwrap(); match self.nodes.get(&next).unwrap().clone() { Term::App(name, children) => { if !seen.contains(&next) { From f7c9837c47a16a3b3f7b101c8acdf02274e96522 Mon Sep 17 00:00:00 2001 From: tnv5 Date: Tue, 17 Oct 2023 16:08:12 -0400 Subject: [PATCH 08/10] Added comments --- src/termdag.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/termdag.rs b/src/termdag.rs index 967759adc..20b67ec7d 100644 --- a/src/termdag.rs +++ b/src/termdag.rs @@ -245,13 +245,16 @@ impl TermDag { let mut stored: HashMap = HashMap::default(); let mut adj_list: HashMap> = HashMap::default(); let mut seen = HashSet::default(); + // Maps term IDs to their new homemade IDs let mut term_id_map: HashMap = HashMap::default(); let mut term_id_insertion_order: Vec = Vec::default(); + // Maps homemade IDs to their original term IDs let mut value_map = HashMap::::default(); let mut counter = 0; let id = self.get_id(term); let mut stack = vec![id]; + //Initial numbering while !stack.is_empty() { let next: TermId = stack.pop().unwrap(); if !seen.contains(&next) { @@ -265,7 +268,7 @@ impl TermDag { } if let Term::App(_, children) = self.nodes.get(&next).unwrap().clone() { if !seen.contains(&next) { - // Add the children to get numbered and then revisit this node + // Add the children to get numbered seen.insert(next); for c in children.iter().rev() { stack.push(*c); @@ -274,6 +277,7 @@ impl TermDag { } } + // Renumber homemade IDs for values to have the same ordering as the original IDs let mut values = value_map.keys().collect::>(); let mut homemade_value_ids = value_map .iter() @@ -290,7 +294,7 @@ impl TermDag { seen.clear(); stack = vec![id]; - // Construct the new IDs and construct the adjacency lis + // Construct the adjacency list for the Terms with the new IDs while !stack.is_empty() { let next: TermId = stack.pop().unwrap(); match self.nodes.get(&next).unwrap().clone() { @@ -307,7 +311,7 @@ impl TermDag { let mut str = String::new(); let mut edges: Vec = Vec::default(); str.push_str(&format!( - "(Term: {}, Value: ({}", + "(Term: {}, ({}", term_id_map.get(&next).unwrap().to_string().as_str(), name )); @@ -328,7 +332,7 @@ impl TermDag { stored.insert( next, format!( - "(Term: {}, Value: {})", + "(Term: {}, {})", term_id_map.get(&next).unwrap().to_string().as_str(), lit ), From faa7b07e59e8e3366983f5d432fd85bdf321c63f Mon Sep 17 00:00:00 2001 From: tnv5 Date: Tue, 17 Oct 2023 16:40:11 -0400 Subject: [PATCH 09/10] fixed value_map comment --- src/termdag.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/termdag.rs b/src/termdag.rs index 20b67ec7d..2631484be 100644 --- a/src/termdag.rs +++ b/src/termdag.rs @@ -248,7 +248,7 @@ impl TermDag { // Maps term IDs to their new homemade IDs let mut term_id_map: HashMap = HashMap::default(); let mut term_id_insertion_order: Vec = Vec::default(); - // Maps homemade IDs to their original term IDs + // Maps values to their term IDs let mut value_map = HashMap::::default(); let mut counter = 0; let id = self.get_id(term); From a70095c5b8eade50d150b04cd1a6dc6a5665ffc2 Mon Sep 17 00:00:00 2001 From: tnv5 Date: Tue, 17 Oct 2023 16:51:00 -0400 Subject: [PATCH 10/10] removed redundant control flow --- src/termdag.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/termdag.rs b/src/termdag.rs index 2631484be..3d729297a 100644 --- a/src/termdag.rs +++ b/src/termdag.rs @@ -300,13 +300,13 @@ impl TermDag { match self.nodes.get(&next).unwrap().clone() { Term::App(name, children) => { if !seen.contains(&next) { - // Add the children to get numbered and then revisit this node + // Construct the string for the children seen.insert(next); - stack.push(next); for c in children.iter().rev() { stack.push(*c); } - } else { + + seen.insert(next); // Construct the string for this node let mut str = String::new(); let mut edges: Vec = Vec::default(); @@ -341,7 +341,7 @@ impl TermDag { } } - // Construct the string + // Construct DAG output let mut str = String::new(); str.push('\n'); for id in term_id_insertion_order.iter() {