diff --git a/src/opt/pb_sls.cpp b/src/opt/pb_sls.cpp index 3ac91b33a..06d137c09 100644 --- a/src/opt/pb_sls.cpp +++ b/src/opt/pb_sls.cpp @@ -45,8 +45,8 @@ namespace smt { rational m_penalty; // current penalty of soft constraints vector m_hard_occ, m_soft_occ; // variable occurrence svector m_assignment; // current assignment. - obj_map m_expr2var; // map expressions to Boolean variables. - ptr_vector m_var2expr; // reverse map + obj_map m_decl2var; // map declarations to Boolean variables. + ptr_vector m_var2decl; // reverse map uint_set m_hard_false; // list of hard clauses that are false. uint_set m_soft_false; // list of soft clauses that are false. unsigned m_max_flips; @@ -75,10 +75,11 @@ namespace smt { clause cls(mgr); if (compile_clause(f, cls)) { m_clauses.push_back(cls); + m_weights.push_back(w); } } - void init_value(expr* f, bool b) { + void set_value(func_decl* f, bool b) { literal lit = mk_literal(f); SASSERT(!lit.sign()); m_assignment[lit.var()] = b; @@ -95,14 +96,6 @@ namespace smt { return l_undef; } - bool get_value(expr* f) { - unsigned var; - if (m_expr2var.find(f, var)) { - return m_assignment[var]; - } - UNREACHABLE(); - return true; - } bool get_value(literal l) { return l.sign() ^ m_assignment[l.var()]; } @@ -267,6 +260,21 @@ namespace smt { return break_count; } + literal mk_literal(func_decl* f) { + SASSERT(f->get_family_id() == null_family_id); + unsigned var; + if (!m_expr2var.find(f, var)) { + var = m_hard_occ.size(); + SASSERT(m_expr2var.size() == var); + m_hard_occ.push_back(unsigned_vector()); + m_soft_occ.push_back(unsigned_vector()); + m_assignment.push_back(false); + m_expr2var.insert(f, var); + m_var2decl.push_back(f); + } + return literal(var); + } + literal mk_literal(expr* f) { literal result; bool sign = false; @@ -279,18 +287,12 @@ namespace smt { else if (m.is_false(f)) { result = false_literal; } + else if (is_uninterp_const(f)) { + result = mk_literal(to_app(f)->get_decl()); + } else { - unsigned var; - if (!m_expr2var.find(f, var)) { - var = m_hard_occ.size(); - SASSERT(m_expr2var.size() == var); - m_hard_occ.push_back(unsigned_vector()); - m_soft_occ.push_back(unsigned_vector()); - m_assignment.push_back(false); - m_expr2var.insert(f, var); - m_var2expr.push_back(f); - } - result = literal(var); + IF_VERBOSE(0, verbose_stream() << "not handled: " << mk_pp(f, m) << "\n";); + result = null_literal; } if (sign) { result.neg(); @@ -363,15 +365,12 @@ namespace smt { void pb_sls::add(expr* f, rational const& w) { m_imp->add(f, w); } - void pb_sls::init_value(expr* f, bool b) { - m_imp->init_value(f, b); + void pb_sls::set_value(func_decl* f, bool b) { + m_imp->set_value(f, b); } lbool pb_sls::operator()() { return (*m_imp)(); } - bool pb_sls::get_value(expr* f) { - return m_imp->get_value(f); - } void pb_sls::set_cancel(bool f) { m_imp->set_cancel(f); } diff --git a/src/opt/pb_sls.h b/src/opt/pb_sls.h index 01362fe31..05c7becaa 100644 --- a/src/opt/pb_sls.h +++ b/src/opt/pb_sls.h @@ -35,9 +35,8 @@ namespace smt { ~pb_sls(); void add(expr* f); void add(expr* f, rational const& w); - void init_value(expr* f, bool b); + void set_value(func_decl* f, bool b); lbool operator()(); - bool get_value(expr* f); void set_cancel(bool f); void collect_statistics(statistics& st) const; void get_model(model_ref& mdl);