diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 9bdc242b2..509400a50 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -304,7 +304,7 @@ namespace arith { return lp::EQ; } - void solver::mk_eq_axiom(bool is_eq, euf::th_eq const& e) { + void solver::new_eq_eh(euf::th_eq const& e) { theory_var v1 = e.v1(); theory_var v2 = e.v2(); if (is_bool(v1)) @@ -316,23 +316,35 @@ namespace arith { if (e1->get_id() > e2->get_id()) std::swap(e1, e2); - if (is_eq && m.are_equal(e1, e2)) + if (m.are_equal(e1, e2)) return; - if (!is_eq && m.are_distinct(e1, e2)) - return; - if (is_eq) { - ++m_stats.m_assert_eq; - m_new_eq = true; - euf::enode* n1 = var2enode(v1); - euf::enode* n2 = var2enode(v2); - lpvar w1 = register_theory_var_in_lar_solver(v1); - lpvar w2 = register_theory_var_in_lar_solver(v2); - auto cs = lp().add_equality(w1, w2); - add_eq_constraint(cs.first, n1, n2); - add_eq_constraint(cs.second, n1, n2); + ++m_stats.m_assert_eq; + m_new_eq = true; + euf::enode* n1 = var2enode(v1); + euf::enode* n2 = var2enode(v2); + lpvar w1 = register_theory_var_in_lar_solver(v1); + lpvar w2 = register_theory_var_in_lar_solver(v2); + auto cs = lp().add_equality(w1, w2); + add_eq_constraint(cs.first, n1, n2); + add_eq_constraint(cs.second, n1, n2); + } + + void solver::new_diseq_eh(euf::th_eq const& e) { + m_delayed_eqs.push_back(std::make_pair(e, false)); + ctx.push(push_back_vector>>(m_delayed_eqs)); + } + + void solver::mk_diseq_axiom(euf::th_eq const& e) { + if (is_bool(e.v1())) return; - } + force_push(); + expr* e1 = var2expr(e.v1()); + expr* e2 = var2expr(e.v2()); + if (e1->get_id() > e2->get_id()) + std::swap(e1, e2); + if (m.are_distinct(e1, e2)) + return; literal le, ge; if (a.is_numeral(e1)) std::swap(e1, e2); @@ -347,9 +359,7 @@ namespace arith { expr_ref zero(a.mk_numeral(rational(0), a.is_int(e1)), m); rewrite(diff); if (a.is_numeral(diff)) { - if (is_eq && a.is_zero(diff)) - return; - if (!is_eq && !a.is_zero(diff)) + if (!a.is_zero(diff)) return; if (a.is_zero(diff)) add_unit(eq); diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 3469bd846..cfced8a6b 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -969,6 +969,9 @@ namespace arith { TRACE("arith", ctx.display(tout);); + if (!check_delayed_eqs()) + return sat::check_result::CR_CONTINUE; + switch (check_lia()) { case l_true: break; @@ -1064,6 +1067,19 @@ namespace arith { } } + bool solver::check_delayed_eqs() { + for (auto p : m_delayed_eqs) { + auto const& e = p.first; + if (p.second) + new_eq_eh(e); + else if (is_eq(e.v1(), e.v2())) { + mk_diseq_axiom(e); + return false; + } + } + return true; + } + lbool solver::check_lia() { TRACE("arith", ); if (!m.inc()) diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 8659221c9..48a44f0f9 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -164,6 +164,7 @@ namespace arith { svector m_inequalities; // asserted rows corresponding to inequality literals. svector m_equalities; // asserted rows corresponding to equalities. svector m_definitions; // asserted rows corresponding to definitions + svector> m_delayed_eqs; literal_vector m_asserted; expr* m_not_handled{ nullptr }; @@ -305,6 +306,7 @@ namespace arith { literal is_bound_implied(lp::lconstraint_kind k, rational const& value, api_bound const& b) const; void assert_bound(bool is_true, api_bound& b); void mk_eq_axiom(bool is_eq, euf::th_eq const& eq); + void mk_diseq_axiom(euf::th_eq const& eq); void assert_idiv_mod_axioms(theory_var u, theory_var v, theory_var w, rational const& r); api_bound* mk_var_bound(sat::literal lit, theory_var v, lp_api::bound_kind bk, rational const& bound); lp::lconstraint_kind bound2constraint_kind(bool is_int, lp_api::bound_kind bk, bool is_true); @@ -348,6 +350,7 @@ namespace arith { bool use_nra_model(); lbool make_feasible(); + bool check_delayed_eqs(); lbool check_lia(); lbool check_nla(); bool is_infeasible() const; @@ -423,8 +426,8 @@ namespace arith { void collect_statistics(statistics& st) const override; euf::th_solver* clone(euf::solver& ctx) override; bool use_diseqs() const override { return true; } - void new_eq_eh(euf::th_eq const& eq) override { mk_eq_axiom(true, eq); } - void new_diseq_eh(euf::th_eq const& de) override { mk_eq_axiom(false, de); } + void new_eq_eh(euf::th_eq const& eq) override; + void new_diseq_eh(euf::th_eq const& de) override; bool unit_propagate() override; void init_model() override; void finalize_model(model& mdl) override { DEBUG_CODE(dbg_finalize_model(mdl);); }