3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-07 01:54:08 +00:00
z3/src/ast/euf/euf_egraph.h
Nikolaj Bjorner ecddaeae66 na
Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
2020-09-01 07:15:13 -07:00

186 lines
6.8 KiB
C++

/*++
Copyright (c) 2020 Microsoft Corporation
Module Name:
euf_egraph.h
Abstract:
E-graph layer
Author:
Nikolaj Bjorner (nbjorner) 2020-08-23
Notes:
It relies on
- data structures from the (legacy) SMT solver.
- it still uses eager path compression.
- delayed congruence table reconstruction from egg.
- it does not deduplicate parents.
--*/
#pragma once
#include "util/statistics.h"
#include "util/trail.h"
#include "ast/euf/euf_enode.h"
#include "ast/euf/euf_etable.h"
namespace euf {
/**
   \brief Record describing one entry of the egraph's equality stack.
   Entries are created by egraph::push_eq and carry the same three values
   that egraph::undo_eq takes as arguments.
*/
struct add_eq_record {
    enode*   r1;              // NOTE(review): presumably a root that was merged - confirm in euf_egraph.cpp
    enode*   n1;              // NOTE(review): presumably the node whose merge caused the record - confirm
    unsigned r2_num_parents;  // NOTE(review): presumably the parent count of the other root before the merge - confirm

    add_eq_record(enode* root, enode* node, unsigned num_parents):
        r1(root),
        n1(node),
        r2_num_parents(num_parents)
    {}
};
/***
\brief store derived theory equalities.
Theory 'id' is notified with the equality of theory variables v1, v2
that are merged into the common root of child and root. Their roots may
have been updated since the equality was derived, but the explanation for
v1 == v2 is provided by explaining the equality child == root.
*/
struct th_eq {
    theory_id  m_id;     // theory that is notified of the equality
    theory_var m_v1;     // first theory variable of the derived equality
    theory_var m_v2;     // second theory variable of the derived equality
    enode*     m_child;  // explaining m_child == m_root explains m_v1 == m_v2
    enode*     m_root;

    th_eq(theory_id id, theory_var v1, theory_var v2, enode* child, enode* root) :
        m_id(id),
        m_v1(v1),
        m_v2(v2),
        m_child(child),
        m_root(root)
    {}
};
class egraph {
typedef ptr_vector<trail<egraph> > trail_stack;
struct scope {
bool m_inconsistent;
unsigned m_num_eqs;
unsigned m_num_nodes;
unsigned m_trail_sz;
};
struct stats {
unsigned m_num_merge;
unsigned m_num_lits;
unsigned m_num_eqs;
unsigned m_num_conflicts;
stats() { reset(); }
void reset() { memset(this, 0, sizeof(*this)); }
};
ast_manager& m;
trail_stack m_trail;
region m_region;
enode_vector m_worklist;
etable m_table;
svector<add_eq_record> m_eqs;
svector<scope> m_scopes;
enode_vector m_expr2enode;
enode_vector m_nodes;
expr_ref_vector m_exprs;
unsigned m_num_scopes { 0 };
bool m_inconsistent { false };
enode *m_n1 { nullptr };
enode *m_n2 { nullptr };
justification m_justification;
enode_vector m_new_eqs;
enode_vector m_new_lits;
svector<th_eq> m_new_th_eqs;
enode_vector m_todo;
stats m_stats;
std::function<void(expr*,expr*,expr*)> m_used_eq;
std::function<void(app*,app*)> m_used_cc;
void push_eq(enode* r1, enode* n1, unsigned r2_num_parents) {
m_eqs.push_back(add_eq_record(r1, n1, r2_num_parents));
}
void undo_eq(enode* r1, enode* n1, unsigned r2_num_parents);
enode* mk_enode(expr* f, unsigned num_args, enode * const* args);
void reinsert(enode* n);
void force_push();
void set_conflict(enode* n1, enode* n2, justification j);
void merge(enode* n1, enode* n2, justification j);
void merge_th_eq(enode* n, enode* root);
void merge_justification(enode* n1, enode* n2, justification j);
void unmerge_justification(enode* n1);
void dedup_equalities();
void reinsert_equality(enode* p);
void update_children(enode* n);
void push_lca(enode* a, enode* b);
enode* find_lca(enode* a, enode* b);
void push_to_lca(enode* a, enode* lca);
void push_congruence(enode* n1, enode* n2, bool commutative);
void push_todo(enode* n);
template <typename T>
void explain_eq(ptr_vector<T>& justifications, enode* a, enode* b, justification const& j) {
if (j.is_external())
justifications.push_back(j.ext<T>());
else if (j.is_congruence())
push_congruence(a, b, j.is_commutative());
}
template <typename T>
void explain_todo(ptr_vector<T>& justifications);
std::ostream& display(std::ostream& out, unsigned max_args, enode* n) const;
public:
egraph(ast_manager& m): m(m), m_table(m), m_exprs(m) {}
~egraph();
enode* find(expr* f) { return m_expr2enode.get(f->get_id(), nullptr); }
enode* mk(expr* f, unsigned n, enode *const* args);
void push() { ++m_num_scopes; }
void pop(unsigned num_scopes);
bool is_equality(enode* n) const;
/**
\brief merge nodes, all effects are deferred to the propagation step.
*/
void merge(enode* n1, enode* n2, void* reason) { merge(n1, n2, justification::external(reason)); }
/**
\brief propagate set of merges.
This call may detect an inconsistency. Then inconsistent() is true.
Use then explain() to extract an explanation for the conflict.
It may also infer new implied equalities, when the roots of the
equated nodes are merged. Use then new_eqs() to extract the vector
of new equalities.
*/
void propagate();
bool inconsistent() const { return m_inconsistent; }
enode_vector const& new_eqs() const { return m_new_eqs; }
enode_vector const& new_lits() const { return m_new_lits; }
svector<th_eq> const& new_th_eqs() const { return m_new_th_eqs; }
void add_th_var(enode* n, theory_var v, theory_id id);
void set_used_eq(std::function<void(expr*,expr*,expr*)>& used_eq) { m_used_eq = used_eq; }
void set_used_cc(std::function<void(app*,app*)>& used_cc) { m_used_cc = used_cc; }
template <typename T>
void explain(ptr_vector<T>& justifications);
template <typename T>
void explain_eq(ptr_vector<T>& justifications, enode* a, enode* b);
enode_vector const& nodes() const { return m_nodes; }
void invariant();
void copy_from(egraph const& src, std::function<void*(void*)>& copy_justification);
struct e_pp {
egraph const& g;
enode* n;
e_pp(egraph const& g, enode* n) : g(g), n(n) {}
std::ostream& display(std::ostream& out) const { return g.display(out, 0, n); }
};
e_pp pp(enode* n) const { return e_pp(*this, n); }
std::ostream& display(std::ostream& out) const;
void collect_statistics(statistics& st) const;
};
inline std::ostream& operator<<(std::ostream& out, egraph const& g) { return g.display(out); }
inline std::ostream& operator<<(std::ostream& out, egraph::e_pp const& p) { return p.display(out); }
}