From 2853b322ca79044d4cb09c38ea7102ef08488c29 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 5 Nov 2013 01:30:34 -0800 Subject: [PATCH] sketch cardinality plugin module Signed-off-by: Nikolaj Bjorner --- src/ast/card_decl_plugin.cpp | 72 ++++++++++ src/ast/card_decl_plugin.h | 83 +++++++++++ src/opt/theory_card.cpp | 272 +++++++++++++++++++++++++++++++++++ src/opt/theory_card.h | 78 ++++++++++ src/opt/weighted_maxsat.cpp | 2 +- 5 files changed, 506 insertions(+), 1 deletion(-) create mode 100644 src/ast/card_decl_plugin.cpp create mode 100644 src/ast/card_decl_plugin.h create mode 100644 src/opt/theory_card.cpp create mode 100644 src/opt/theory_card.h diff --git a/src/ast/card_decl_plugin.cpp b/src/ast/card_decl_plugin.cpp new file mode 100644 index 000000000..bbf4b28bd --- /dev/null +++ b/src/ast/card_decl_plugin.cpp @@ -0,0 +1,72 @@ +/*++ +Copyright (c) 2013 Microsoft Corporation + +Module Name: + + card_decl_plugin.cpp + +Abstract: + + Cardinality Constraints plugin + +Author: + + Nikolaj Bjorner (nbjorner) 2013-05-11 + +Revision History: + +--*/ + +#include "card_decl_plugin.h" + +card_decl_plugin::card_decl_plugin(): + m_at_most_sym("at_most") +{} + +func_decl * card_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, parameter const * parameters, + unsigned arity, sort * const * domain, sort * range) { + SASSERT(m_manager); + ast_manager& m = *m_manager; + for (unsigned i = 0; i < arity; ++i) { + if (!m.is_bool(domain[i])) { + m.raise_exception("invalid non-Boolean sort applied to 'at_most_k'"); + } + } + if (num_parameters != 1 || !parameters[0].is_int() || parameters[0].get_int() < 0) { + m.raise_exception("function 'at_most_k' expects one non-negative integer parameter"); + } + func_decl_info info(m_family_id, OP_AT_MOST_K, 1, parameters); + return m.mk_func_decl(m_at_most_sym, arity, domain, m.mk_bool_sort(), info); +} + +void card_decl_plugin::get_op_names(svector & op_names, symbol const & logic) { + if (logic == symbol::null) { + op_names.push_back(builtin_name("at-most-k", OP_AT_MOST_K)); + } +} + + +app * card_util::mk_at_most_k(unsigned num_args, expr * const * args, unsigned k) { + parameter param(1); + return m.mk_app(m_fid, OP_AT_MOST_K, 1, ¶m, num_args, args, m.mk_bool_sort()); +} + +bool card_util::is_at_most_k(app *a) const { + return is_app_of(a, m_fid, OP_AT_MOST_K); +} + +bool card_util::is_at_most_k(app *a, unsigned& k) const { + if (is_at_most_k(a)) { + k = get_k(a); + return true; + } + else { + return false; + } +} + +unsigned card_util::get_k(app *a) const { + SASSERT(is_at_most_k(a)); + return static_cast(a->get_decl()->get_parameter(0).get_int()); +} + diff --git a/src/ast/card_decl_plugin.h b/src/ast/card_decl_plugin.h new file mode 100644 index 000000000..9d4e95b13 --- /dev/null +++ b/src/ast/card_decl_plugin.h @@ -0,0 +1,83 @@ +/*++ +Copyright (c) 2013 Microsoft Corporation + +Module Name: + + card_decl_plugin.h + +Abstract: + + Cardinality Constraints plugin + +Author: + + Nikolaj Bjorner (nbjorner) 2013-05-11 + +Notes: + + + (at-most-k x1 .... x_n) means x1 + ... + x_n <= k + +hence: + + (not (at-most-k x1 .... x_n)) means x1 + ... + x_n >= k + 1 + + +--*/ +#ifndef _CARD_DECL_PLUGIN_H_ +#define _CARD_DECL_PLUGIN_H_ + +#include"ast.h" + + + +enum card_op_kind { + OP_AT_MOST_K, + LAST_CARD_OP +}; + + +class card_decl_plugin : public decl_plugin { + symbol m_at_most_sym; + func_decl * mk_at_most(unsigned arity, unsigned k); +public: + card_decl_plugin(); + virtual ~card_decl_plugin() {} + + virtual sort * mk_sort(decl_kind k, unsigned num_parameters, parameter const * parameters) { + UNREACHABLE(); + return 0; + } + + virtual decl_plugin * mk_fresh() { + return alloc(card_decl_plugin); + } + + // + // Contract for func_decl: + // parameters[0] - integer (at most k elements) + // all sorts are Booleans + virtual func_decl * mk_func_decl(decl_kind k, unsigned num_parameters, parameter const * parameters, + unsigned arity, sort * const * domain, sort * range); + virtual void get_op_names(svector & op_names, symbol const & logic); + virtual void get_sort_names(svector & sort_names, symbol const & logic); + virtual expr * get_some_value(sort * s); + virtual bool is_fully_interp(sort const * s) const; +}; + + +class card_util { + ast_manager & m; + family_id m_fid; +public: + card_util(ast_manager& m):m(m), m_fid(m.mk_family_id("card")) {} + ast_manager & get_manager() const { return m; } + app * mk_at_most_k(unsigned num_args, expr * const * args, unsigned k); + bool is_at_most_k(app *a) const; + bool is_at_most_k(app *a, unsigned& k) const; + unsigned get_k(app *a) const; +}; + + +#endif /* _CARD_DECL_PLUGIN_H_ */ + diff --git a/src/opt/theory_card.cpp b/src/opt/theory_card.cpp new file mode 100644 index 000000000..f303306b3 --- /dev/null +++ b/src/opt/theory_card.cpp @@ -0,0 +1,272 @@ +/*++ +Copyright (c) 2013 Microsoft Corporation + +Module Name: + + theory_card.cpp + +Abstract: + + Cardinality theory plugin. + +Author: + + Nikolaj Bjorner (nbjorner) 2013-11-05 + +Notes: + + - count number of clauses per cardinality constraint. + - when number of conflicts exceeds n^2 or n*log(n), then create a sorting circuit. + where n is the arity of the cardinality constraint. + - extra: do clauses get re-created? keep track of gc status of created clauses. + +--*/ + +#include "theory_card.h" +#include "smt_context.h" + +namespace smt { + + theory_card::theory_card(ast_manager& m): + theory(m.mk_family_id("card")), + m_util(m) + {} + + theory_card::~theory_card() { + reset_eh(); + } + + theory * theory_card::mk_fresh(context * new_ctx) { + return alloc(theory_card, new_ctx->get_manager()); + } + + bool theory_card::internalize_atom(app * atom, bool gate_ctx) { + context& ctx = get_context(); + ast_manager& m = get_manager(); + unsigned num_args = atom->get_num_args(); + SASSERT(m_util.is_at_most_k(atom)); + unsigned k = m_util.get_k(atom); + bool_var bv; + if (ctx.b_internalized(atom)) { + return false; + } + SASSERT(!ctx.b_internalized(atom)); + bv = ctx.mk_bool_var(atom); + card* c = alloc(card, atom, bv, k); + add_card(c); + // + // TBD take repeated bv into account. + // base case: throw exception. + // refinement: adjust argument list and k for non-repeated values. + // + for (unsigned i = 0; i < num_args; ++i) { + expr* arg = atom->get_arg(i); + if (!ctx.b_internalized(arg)) { + bv = ctx.mk_bool_var(arg); + } + else { + bv = ctx.get_bool_var(arg); + } + ctx.set_var_theory(bv, get_id()); + add_watch(bv, c); + } + return true; + } + + void theory_card::add_watch(bool_var bv, card* c) { + ptr_vector* cards; + if (!m_watch.find(bv, cards)) { + cards = alloc(ptr_vector); + m_watch.insert(bv, cards); + } + cards->push_back(c); + m_watch_trail.push_back(bv); + } + + + void theory_card::reset_eh() { + + // m_watch; + u_map*>::iterator it = m_watch.begin(), end = m_watch.end(); + for (; it != end; ++it) { + dealloc(it->m_value); + } + u_map::iterator itc = m_cards.begin(), endc = m_cards.end(); + for (; itc != endc; ++itc) { + dealloc(itc->m_value); + } + m_watch.reset(); + m_cards.reset(); + m_cards_trail.reset(); + m_cards_lim.reset(); + m_watch_trail.reset(); + m_watch_lim.reset(); + } + + void theory_card::assign_eh(bool_var v, bool is_true) { + context& ctx = get_context(); + ptr_vector* cards = 0; + card* c = 0; + if (m_watch.find(v, cards)) { + for (unsigned i = 0; i < cards->size(); ++i) { + c = (*cards)[i]; + app* atm = c->m_atom; + // + // is_true && m_t + 1 > k -> force false + // !is_true && m_f + 1 >= arity - k -> force true + // + if (is_true && c->m_t >= c->m_k) { + unsigned k = c->m_k; + // force false + switch (ctx.get_assignment(c->m_bv)) { + case l_true: + case l_undef: { + literal_vector& lits = get_lits(); + lits.push_back(literal(v)); + for (unsigned i = 0; i < atm->get_num_args() && lits.size() <= k + 1; ++i) { + expr* arg = atm->get_arg(i); + if (ctx.get_assignment(arg) == l_true) { + lits.push_back(literal(ctx.get_bool_var(arg))); + } + } + SASSERT(lits.size() == k + 1); + add_clause(lits); + break; + } + default: + break; + } + } + else if (!is_true && c->m_k >= atm->get_num_args() - c->m_f) { + // forced true + switch (ctx.get_assignment(c->m_bv)) { + case l_false: + case l_undef: { + literal_vector& lits = get_lits(); + lits.push_back(~literal(v)); + for (unsigned i = 0; i < atm->get_num_args(); ++i) { + expr* arg = atm->get_arg(i); + if (ctx.get_assignment(arg) == l_false) { + lits.push_back(~literal(ctx.get_bool_var(arg))); + } + } + add_clause(lits); + break; + } + default: + break; + } + } + else if (is_true) { + ctx.push_trail(value_trail(c->m_t)); + c->m_t++; + } + else { + ctx.push_trail(value_trail(c->m_f)); + c->m_f++; + } + } + } + if (m_cards.find(v, c)) { + app* atm = to_app(ctx.bool_var2expr(v)); + SASSERT(atm->get_num_args() >= c->m_f + c->m_t); + bool_var bv; + + // at most k + // propagate false to children that are not yet assigned. + // v & t1 & ... & tk => ~l_j + if (is_true && c->m_k <= c->m_t) { + + literal_vector& lits = get_lits(); + lits.push_back(literal(v)); + bool done = false; + for (unsigned i = 0; !done && i < atm->get_num_args(); ++i) { + bv = ctx.get_bool_var(atm->get_arg(i)); + if (ctx.get_assignment(bv) == l_true) { + lits.push_back(literal(bv)); + } + if (lits.size() > c->m_k + 1) { + add_clause(lits); + done = true; + } + } + SASSERT(done || lits.size() == c->m_k + 1); + for (unsigned i = 0; !done && i < atm->get_num_args(); ++i) { + bv = ctx.get_bool_var(atm->get_arg(i)); + if (ctx.get_assignment(bv) == l_undef) { + lits.push_back(literal(bv)); + add_clause(lits); + lits.pop_back(); + } + } + } + // at least k+1: + // !v & !f1 & .. & !f_m => l_j + // for m + k + 1 = arity() + if (!is_true && atm->get_num_args() == 1 + c->m_f + c->m_k) { + literal_vector& lits = get_lits(); + lits.push_back(~literal(v)); + bool done = false; + for (unsigned i = 0; !done && i < atm->get_num_args(); ++i) { + bv = ctx.get_bool_var(atm->get_arg(i)); + if (ctx.get_assignment(bv) == l_false) { + lits.push_back(~literal(bv)); + } + if (lits.size() > c->m_k + 1) { + add_clause(lits); + done = true; + } + } + SASSERT(done || lits.size() == c->m_k + 1); + for (unsigned i = 0; !done && i < atm->get_num_args(); ++i) { + bv = ctx.get_bool_var(atm->get_arg(i)); + if (ctx.get_assignment(bv) != l_false) { + lits.push_back(~literal(bv)); + add_clause(lits); + lits.pop_back(); + } + } + } + } + } + + void theory_card::init_search_eh() { + + } + + void theory_card::push_scope_eh() { + m_watch_lim.push_back(m_watch_trail.size()); + m_cards_lim.push_back(m_cards_trail.size()); + } + + void theory_card::pop_scope_eh(unsigned num_scopes) { + unsigned sz = m_watch_lim[m_watch_lim.size()-num_scopes]; + for (unsigned i = m_watch_trail.size(); i > sz; ) { + --i; + ptr_vector* cards = 0; + VERIFY(m_watch.find(m_watch_trail[i], cards)); + SASSERT(cards && !cards->empty()); + cards->pop_back(); + } + m_watch_lim.resize(m_watch_lim.size()-num_scopes); + sz = m_cards_lim[m_cards_lim.size()-num_scopes]; + for (unsigned i = m_cards_trail.size(); i > sz; ) { + --i; + SASSERT(m_cards.contains(m_cards_trail[i])); + m_cards.remove(m_cards_trail[i]); + } + m_cards_lim.resize(m_cards_lim.size()-num_scopes); + } + + + literal_vector& theory_card::get_lits() { + m_literals.reset(); + return m_literals; + } + + void theory_card::add_clause(literal_vector const& lits) { + context& ctx = get_context(); + ctx.mk_th_axiom(get_id(), lits.size(), lits.c_ptr()); + } + +} diff --git a/src/opt/theory_card.h b/src/opt/theory_card.h new file mode 100644 index 000000000..f38084292 --- /dev/null +++ b/src/opt/theory_card.h @@ -0,0 +1,78 @@ +/*++ +Copyright (c) 2013 Microsoft Corporation + +Module Name: + + theory_card.h + +Abstract: + + Cardinality theory plugin. + +Author: + + Nikolaj Bjorner (nbjorner) 2013-11-05 + +Notes: + + This custom theory handles cardinality constraints + It performs unit propagation and switches to creating + sorting circuits if it keeps having to propagate (create new clauses). +--*/ + +#include "smt_theory.h" +#include "card_decl_plugin.h" + +namespace smt { + class theory_card : public theory { + struct card { + unsigned m_k; + bool_var m_bv; + unsigned m_t; + unsigned m_f; + app* m_atom; + card(app* a, bool_var bv, unsigned k): + m_k(k), m_bv(bv), m_atom(a), m_t(0), m_f(0) + {} + }; + + u_map*> m_watch; // use-list of literals. + u_map m_cards; // bool_var |-> card + unsigned_vector m_cards_trail; + unsigned_vector m_cards_lim; + unsigned_vector m_watch_trail; + unsigned_vector m_watch_lim; + literal_vector m_literals; + card_util m_util; + + void add_watch(bool_var bv, card* c); + + void add_card(card* c) { + m_cards.insert(c->m_bv, c); + m_cards_trail.push_back(c->m_bv); + } + void add_clause(literal_vector const& lits); + literal_vector& get_lits(); + + public: + theory_card(ast_manager& m); + + virtual ~theory_card(); + + virtual theory * mk_fresh(context * new_ctx); + virtual bool internalize_atom(app * atom, bool gate_ctx); + virtual bool internalize_term(app * term) { UNREACHABLE(); return false; } + virtual void new_eq_eh(theory_var v1, theory_var v2) { } + virtual void new_diseq_eh(theory_var v1, theory_var v2) { } + virtual bool use_diseqs() const { return false; } + virtual bool build_models() const { return false; } + virtual final_check_status final_check_eh() { return FC_DONE; } + + virtual void reset_eh(); + virtual void assign_eh(bool_var v, bool is_true); + virtual void init_search_eh(); + virtual void push_scope_eh(); + virtual void pop_scope_eh(unsigned num_scopes); + + }; +}; diff --git a/src/opt/weighted_maxsat.cpp b/src/opt/weighted_maxsat.cpp index 81aa42586..204e6d656 100644 --- a/src/opt/weighted_maxsat.cpp +++ b/src/opt/weighted_maxsat.cpp @@ -3,7 +3,7 @@ Copyright (c) 2013 Microsoft Corporation Module Name: - weighted_maxsat.h + weighted_maxsat.cpp Abstract: Weighted MAXSAT module