3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-05 17:14:07 +00:00

a round of cleanup

This commit is contained in:
Nikolaj Bjorner 2022-12-04 06:07:45 -08:00
parent d218083145
commit 3d7bd40a87
3 changed files with 150 additions and 189 deletions

View file

@ -3,7 +3,7 @@ Copyright (c) 2006 Microsoft Corporation
Module Name:
demodulator.cpp
demodulator_rewriter.cpp
Abstract:
@ -17,6 +17,7 @@ Revision History:
Christoph M. Wintersteiger (cwinter) 2010-04-21: Implementation
Christoph M. Wintersteiger (cwinter) 2012-10-24: Moved from demodulator.h to ufbv_rewriter.h
Nikolaj Bjorner (nbjorner) 2022-12-4: Moved to demodulator_rewriter.h
--*/
@ -186,48 +187,45 @@ public:
unsigned get_max() { return m_max_var_id; }
};
unsigned demodulator_rewriter::max_var_id(expr * e)
{
unsigned demodulator_rewriter::max_var_id(expr_ref_vector const& es) {
max_var_id_proc proc;
for_each_expr(proc, e);
for (expr* e : es)
for_each_expr(proc, e);
return proc.get_max();
}
void demodulator_rewriter::insert_fwd_idx(expr * large, expr * small, quantifier * demodulator) {
SASSERT(large->get_kind() == AST_APP);
void demodulator_rewriter::insert_fwd_idx(app * large, expr * small, quantifier * demodulator) {
SASSERT(demodulator);
SASSERT(large && small);
TRACE("demodulator_fwd", tout << "INSERT: " << mk_pp(demodulator, m) << std::endl; );
func_decl * fd = to_app(large)->get_decl();
fwd_idx_map::iterator it = m_fwd_idx.find_iterator(fd);
if (it == m_fwd_idx.end()) {
quantifier_set * qs = alloc(quantifier_set, 1);
quantifier_set * qs;
if (!m_fwd_idx.find(fd, qs)) {
qs = alloc(quantifier_set, 1);
m_fwd_idx.insert(fd, qs);
it = m_fwd_idx.find_iterator(fd);
}
SASSERT(it->m_value);
it->m_value->insert(demodulator);
SASSERT(qs);
qs->insert(demodulator);
m.inc_ref(demodulator);
m.inc_ref(large);
m.inc_ref(small);
m_demodulator2lhs_rhs.insert(demodulator, expr_pair(large, small));
m_demodulator2lhs_rhs.insert(demodulator, app_expr_pair(large, small));
}
void demodulator_rewriter::remove_fwd_idx(func_decl * f, quantifier * demodulator) {
TRACE("demodulator_fwd", tout << "REMOVE: " << std::hex << (size_t)demodulator << std::endl; );
fwd_idx_map::iterator it = m_fwd_idx.find_iterator(f);
if (it != m_fwd_idx.end()) {
demodulator2lhs_rhs::iterator fit = m_demodulator2lhs_rhs.find_iterator(demodulator);
expr_pair p = fit->m_value;
quantifier_set* qs;
if (m_fwd_idx.find(f, qs)) {
auto [lhs, rhs] = m_demodulator2lhs_rhs[demodulator];
m_demodulator2lhs_rhs.erase(demodulator);
it->m_value->erase(demodulator);
m.dec_ref(p.first);
m.dec_ref(p.second);
qs->erase(demodulator);
m.dec_ref(lhs);
m.dec_ref(rhs);
m.dec_ref(demodulator);
} else {
SASSERT(m_demodulator2lhs_rhs.contains(demodulator));
@ -235,59 +233,50 @@ void demodulator_rewriter::remove_fwd_idx(func_decl * f, quantifier * demodulato
}
bool demodulator_rewriter::check_fwd_idx_consistency() {
for (auto & kv : m_fwd_idx) {
quantifier_set * set = kv.m_value;
for (auto & [k, set] : m_fwd_idx) {
SASSERT(set);
for (auto e : *set) {
for (auto e : *set)
if (!m_demodulator2lhs_rhs.contains(e))
return false;
}
}
return true;
}
void demodulator_rewriter::show_fwd_idx(std::ostream & out) {
for (auto & kv : m_fwd_idx) {
quantifier_set * set = kv.m_value;
SASSERT(!set);
out << kv.m_key->get_name() << ": " << std::endl;
for (auto e : *set) {
out << std::hex << (size_t)e << std::endl;
}
for (auto & [k, set] : m_fwd_idx) {
out << k->get_name() << ": " << std::endl;
if (set)
for (auto e : *set)
out << std::hex << (size_t)e << std::endl;
}
out << "D2LR: " << std::endl;
for (auto & kv : m_demodulator2lhs_rhs) {
out << (size_t) kv.m_key << std::endl;
for (auto & [k, v] : m_demodulator2lhs_rhs) {
out << (size_t) k << std::endl;
}
}
bool demodulator_rewriter::rewrite1(func_decl * f, expr_ref_vector & m_new_args, expr_ref & np) {
fwd_idx_map::iterator it = m_fwd_idx.find_iterator(f);
if (it != m_fwd_idx.end()) {
TRACE("demodulator_bug", tout << "trying to rewrite: " << f->get_name() << " args:\n";
tout << m_new_args << "\n";);
for (quantifier* d : *it->m_value) {
bool demodulator_rewriter::rewrite1(func_decl * f, expr_ref_vector const & args, expr_ref & np) {
quantifier_set* set;
if (!m_fwd_idx.find(f, set))
return false;
TRACE("demodulator_bug", tout << "trying to rewrite: " << f->get_name() << " args:\n";
tout << m_new_args << "\n";);
SASSERT(m_demodulator2lhs_rhs.contains(d));
expr_pair l_s;
m_demodulator2lhs_rhs.find(d, l_s);
app * large = to_app(l_s.first);
for (quantifier* d : *set) {
if (large->get_num_args() != m_new_args.size())
continue;
TRACE("demodulator_bug", tout << "Matching with demodulator: " << mk_pp(d, m) << std::endl; );
SASSERT(large->get_decl() == f);
if (m_match_subst(large, l_s.second, m_new_args.data(), np)) {
TRACE("demodulator_bug", tout << "succeeded...\n" << mk_pp(l_s.second, m) << "\n===>\n" << mk_pp(np, m) << "\n";);
return true;
}
auto const& [large, rhs] = m_demodulator2lhs_rhs[d];
if (large->get_num_args() != args.size())
continue;
TRACE("demodulator_bug", tout << "Matching with demodulator: " << mk_pp(d, m) << std::endl; );
SASSERT(large->get_decl() == f);
if (m_match_subst(large, rhs, args.data(), np)) {
TRACE("demodulator_bug", tout << "succeeded...\n" << mk_pp(rhs, m) << "\n===>\n" << mk_pp(np, m) << "\n";);
return true;
}
}
@ -449,21 +438,19 @@ public:
void operator()(var * n) {}
void operator()(quantifier * n) {}
void operator()(app * n) {
// We track only uninterpreted and constant functions.
if (n->get_num_args()==0) return;
// We track only uninterpreted functions.
if (n->get_num_args() == 0)
return;
SASSERT(m_expr && m_expr != (expr*) 0x00000003);
func_decl * d=n->get_decl();
if (d->get_family_id() == null_family_id) {
back_idx_map::iterator it = m_back_idx.find_iterator(d);
if (it != m_back_idx.end()) {
SASSERT(it->m_value);
it->m_value->insert(m_expr);
} else {
expr_set * e = alloc(expr_set);
e->insert(m_expr);
m_back_idx.insert(d, e);
}
func_decl * d = n->get_decl();
if (d->get_family_id() != null_family_id)
return;
expr_set* set = nullptr;
if (!m_back_idx.find(d, set)) {
set = alloc(expr_set);
m_back_idx.insert(d, set);
}
set->insert(m_expr);
}
};
@ -475,39 +462,48 @@ public:
void operator()(var * n) {}
void operator()(quantifier * n) {}
void operator()(app * n) {
// We track only uninterpreted and constant functions.
if (n->get_num_args()==0) return;
func_decl * d=n->get_decl();
if (d->get_family_id() == null_family_id) {
back_idx_map::iterator it = m_back_idx.find_iterator(d);
if (it != m_back_idx.end()) {
SASSERT(it->m_value);
it->m_value->remove(m_expr);
}
}
// We track only uninterpreted functions.
if (n->get_num_args() == 0)
return;
func_decl * d = n->get_decl();
if (d->get_family_id() != null_family_id)
return;
expr_set* set = nullptr;
if (m_back_idx.find(d, set))
set->remove(m_expr);
}
};
void demodulator_rewriter::insert_bwd_idx(expr* e) {
add_back_idx_proc proc(m_back_idx, e);
for_each_expr(proc, e);
}
void demodulator_rewriter::remove_bwd_idx(expr* e) {
remove_back_idx_proc proc(m_back_idx, e);
for_each_expr(proc, e);
}
void demodulator_rewriter::reschedule_processed(func_decl * f) {
//use m_back_idx to find all formulas p in m_processed that contains f {
back_idx_map::iterator it = m_back_idx.find_iterator(f);
if (it != m_back_idx.end()) {
SASSERT(it->m_value);
expr_set temp;
expr_set* set = nullptr;
if (!m_back_idx.find(f, set))
return;
SASSERT(set);
expr_set temp;
for (expr* p : *it->m_value) {
if (m_processed.contains(p))
for (expr* p : *set)
if (m_processed.contains(p))
temp.insert(p);
}
for (expr * p : temp) {
// remove p from m_processed and m_back_idx
m_processed.remove(p);
remove_back_idx_proc proc(m_back_idx, p); // this could change it->m_value, thus we need the `temp' set.
for_each_expr(proc, p);
// insert p into m_todo
m_todo.push_back(p);
}
for (expr * p : temp) {
// remove p from m_processed and m_back_idx
m_processed.remove(p);
// this could change `set', thus we need the `temp' set.
remove_bwd_idx(p);
// insert p into m_todo
m_todo.push_back(p);
}
}
@ -545,20 +541,10 @@ bool demodulator_rewriter::can_rewrite(expr * n, expr * lhs) {
break;
case AST_QUANTIFIER:
if (!for_each_expr_args(stack, visited, to_quantifier(curr)->get_num_patterns(),
to_quantifier(curr)->get_patterns())) {
break;
}
if (!for_each_expr_args(stack, visited, to_quantifier(curr)->get_num_no_patterns(),
to_quantifier(curr)->get_no_patterns())) {
break;
}
if (!visited.is_marked(to_quantifier(curr)->get_expr())) {
if (visited.is_marked(to_quantifier(curr)->get_expr()))
stack.pop_back();
else
stack.push_back(to_quantifier(curr)->get_expr());
break;
}
stack.pop_back();
break;
default:
UNREACHABLE();
@ -571,66 +557,55 @@ bool demodulator_rewriter::can_rewrite(expr * n, expr * lhs) {
void demodulator_rewriter::reschedule_demodulators(func_decl * f, expr * lhs) {
// use m_back_idx to find all demodulators d in m_fwd_idx that contains f {
//ptr_vector<expr> to_remove;
back_idx_map::iterator it = m_back_idx.find_iterator(f);
if (it != m_back_idx.end()) {
SASSERT(it->m_value);
expr_set all_occurrences;
expr_ref l(m);
expr_set* set = nullptr;
if (!m_back_idx.find(f, set))
return;
SASSERT(set);
expr_set all_occurrences;
app_ref l(m);
for (auto s : *it->m_value)
all_occurrences.insert(s);
for (auto s : *set)
all_occurrences.insert(s);
// Run over all f-demodulators
for (expr* occ : all_occurrences) {
if (!is_quantifier(occ))
continue;
quantifier* qe = to_quantifier(occ);
// Use the fwd idx to find out whether this is a demodulator.
app_expr_pair p;
if (!m_demodulator2lhs_rhs.find(qe, p))
continue;
// Run over all f-demodulators
for (expr* occ : all_occurrences) {
l = p.first;
quantifier_ref d(qe, m);
func_decl_ref df(l->get_decl(), m);
// Now we know there is an occurrence of f in d
if (!can_rewrite(d, lhs))
continue;
if (!is_quantifier(occ))
continue;
TRACE("demodulator", tout << "Rescheduling: " << std::endl << mk_pp(d, m) << std::endl);
// Use the fwd idx to find out whether this is a demodulator.
demodulator2lhs_rhs::iterator d2lr_it = m_demodulator2lhs_rhs.find_iterator(to_quantifier(occ));
if (d2lr_it != m_demodulator2lhs_rhs.end()) {
l = d2lr_it->m_value.first;
quantifier_ref d(m);
func_decl_ref df(m);
d = to_quantifier(occ);
df = to_app(l)->get_decl();
// Now we know there is an occurrence of f in d
// if n' can rewrite d {
if (can_rewrite(d, lhs)) {
TRACE("demodulator", tout << "Rescheduling: " << std::endl << mk_pp(d, m) << std::endl; );
// remove d from m_fwd_idx
remove_fwd_idx(df, d);
// remove d from m_back_idx
// just remember it here, because otherwise it and/or esit might become invalid?
// to_remove.insert(d);
remove_back_idx_proc proc(m_back_idx, d);
for_each_expr(proc, d);
// insert d into m_todo
m_todo.push_back(d);
}
}
}
remove_fwd_idx(df, d);
remove_bwd_idx(d);
m_todo.push_back(d);
}
}
void demodulator_rewriter::operator()(unsigned n, expr * const * exprs,
expr_ref_vector & new_exprs) {
void demodulator_rewriter::operator()(expr_ref_vector const& exprs,
expr_ref_vector & new_exprs) {
TRACE("demodulator", tout << "before demodulator:\n";
for ( unsigned i = 0 ; i < n ; i++ )
tout << mk_pp(exprs[i], m) << std::endl; );
TRACE("demodulator", tout << "before demodulator:\n" << exprs);
// Initially, m_todo contains all formulas. That is, it contains the argument exprs. m_fwd_idx, m_processed, m_back_idx are empty.
unsigned max_vid = 0;
for ( unsigned i = 0 ; i < n ; i++ ) {
m_todo.push_back(exprs[i]);
max_vid = std::max(max_vid, max_var_id(exprs[i]));
}
for (expr* e : exprs)
m_todo.push_back(e);
m_match_subst.reserve(max_vid);
m_match_subst.reserve(max_var_id(exprs));
while (!m_todo.empty()) {
// let n be the next formula in m_todo.
@ -644,7 +619,6 @@ void demodulator_rewriter::operator()(unsigned n, expr * const * exprs,
// unless there is a demodulator cycle
// SASSERT(rewrite(np)==np);
// if (n' is not a demodulator) {
app_ref large(m);
expr_ref small(m);
if (!is_demodulator(np, large, small)) {
@ -652,22 +626,11 @@ void demodulator_rewriter::operator()(unsigned n, expr * const * exprs,
m_processed.insert(np);
m_in_processed.push_back(np);
// update m_back_idx (traverse n' and for each uninterpreted function declaration f in n' add the entry f->n' to m_back_idx)
add_back_idx_proc proc(m_back_idx, np);
for_each_expr(proc, np);
} else {
insert_bwd_idx(np);
}
else {
// np is a demodulator that allows us to replace 'large' with 'small'.
TRACE("demodulator", tout << "Found demodulator: " << std::endl;
tout << mk_pp(large.get(), m) << std::endl << " ---> " <<
std::endl << mk_pp(small.get(), m) << std::endl; );
TRACE("demodulator_s", tout << "Found demodulator: " << std::endl;
tout << to_app(large)->get_decl()->get_name() <<
"[" << to_app(large)->get_depth() << "]" << " ---> ";
if (is_app(small))
tout << to_app(small)->get_decl()->get_name() <<
"[" << to_app(small)->get_depth() << "]" << std::endl;
else
tout << mk_pp(small.get(), m) << std::endl; );
TRACE("demodulator", tout << "Found demodulator:\n" << large << "\n ---> " << small << "\n");
// let f be the top symbol of n'
func_decl * f = large->get_decl();
@ -679,8 +642,7 @@ void demodulator_rewriter::operator()(unsigned n, expr * const * exprs,
insert_fwd_idx(large, small, to_quantifier(np));
// update m_back_idx
add_back_idx_proc proc(m_back_idx, np);
for_each_expr(proc, np);
insert_bwd_idx(np);
}
}
@ -690,12 +652,11 @@ void demodulator_rewriter::operator()(unsigned n, expr * const * exprs,
TRACE("demodulator", tout << mk_pp(e, m) << std::endl; );
}
for (auto const& kv : m_fwd_idx) {
if (kv.m_value) {
for (expr* e : *kv.m_value) {
for (auto const& [k, set] : m_fwd_idx) {
if (set) {
for (expr* e : *set)
new_exprs.push_back(e);
TRACE("demodulator", tout << mk_pp(e, m) << std::endl; );
}
TRACE("demodulator", for (expr* e : *set) tout << mk_pp(e, m) << std::endl; );
}
}

View file

@ -111,11 +111,12 @@ class demodulator_rewriter final {
typedef array_map<expr*, expr_bool_pair, plugin> expr_map;
typedef std::pair<expr *, expr *> expr_pair;
typedef std::pair<app *, expr* > app_expr_pair;
typedef obj_hashtable<expr> expr_set;
typedef obj_map<func_decl, expr_set *> back_idx_map;
typedef obj_hashtable<quantifier> quantifier_set;
typedef obj_map<func_decl, quantifier_set *> fwd_idx_map;
typedef obj_map<quantifier, expr_pair> demodulator2lhs_rhs;
typedef obj_map<quantifier, app_expr_pair> demodulator2lhs_rhs;
typedef expr_map rewrite_cache_map;
/**
@ -172,20 +173,22 @@ class demodulator_rewriter final {
rewrite_cache_map m_rewrite_cache;
expr_ref_buffer m_new_exprs;
void insert_fwd_idx(expr * large, expr * small, quantifier * demodulator);
void insert_fwd_idx(app * large, expr * small, quantifier * demodulator);
void remove_fwd_idx(func_decl * f, quantifier * demodulator);
void insert_bwd_idx(expr* q);
void remove_bwd_idx(expr* q);
bool check_fwd_idx_consistency();
void show_fwd_idx(std::ostream & out);
bool is_demodulator(expr * e, app_ref & large, expr_ref & small) const;
bool can_rewrite(expr * n, expr * lhs);
expr * rewrite(expr * n);
bool rewrite1(func_decl * f, expr_ref_vector & m_new_args, expr_ref & np);
bool rewrite1(func_decl * f, expr_ref_vector const & args, expr_ref & np);
bool rewrite_visit_children(app * a);
void rewrite_cache(expr * e, expr * new_e, bool done);
void reschedule_processed(func_decl * f);
void reschedule_demodulators(func_decl * f, expr * np);
unsigned max_var_id(expr * e);
unsigned max_var_id(expr_ref_vector const& es);
// is_smaller returns -1 for e1<e2, 0 for e1==e2 and +1 for e1>e2.
int is_smaller(expr * e1, expr * e2) const;
@ -197,7 +200,7 @@ public:
demodulator_rewriter(ast_manager & m);
~demodulator_rewriter();
void operator()(unsigned n, expr * const * exprs, expr_ref_vector & new_exprs);
void operator()(expr_ref_vector const& exprs, expr_ref_vector & new_exprs);
/**
Given a demodulator (aka rewrite rule) of the form

View file

@ -55,15 +55,12 @@ public:
demodulator_rewriter dem(m_manager);
expr_ref_vector forms(m_manager), new_forms(m_manager);
proof_ref_vector proofs(m_manager), new_proofs(m_manager);
unsigned size = g->size();
for (unsigned i = 0; i < size; i++) {
for (unsigned i = 0; i < size; i++)
forms.push_back(g->form(i));
proofs.push_back(g->pr(i));
}
dem(forms.size(), forms.data(), new_forms);
dem(forms, new_forms);
g->reset();
for (unsigned i = 0; i < new_forms.size(); i++)