diff --git a/src/smt/smt_case_split_queue.cpp b/src/smt/smt_case_split_queue.cpp index 06004e3b8..35cdcb6fe 100644 --- a/src/smt/smt_case_split_queue.cpp +++ b/src/smt/smt_case_split_queue.cpp @@ -22,9 +22,13 @@ Revision History: #include"stopwatch.h" #include"for_each_expr.h" #include"ast_pp.h" +#include"map.h" +#include"hashtable.h" namespace smt { + typedef map > theory_var_priority_map; + struct bool_var_act_lt { svector const & m_activity; bool_var_act_lt(svector const & a):m_activity(a) {} @@ -35,6 +39,27 @@ namespace smt { typedef heap bool_var_act_queue; + struct theory_aware_act_lt { + svector const & m_activity; + theory_var_priority_map const & m_theory_var_priority; + theory_aware_act_lt(svector const & act, theory_var_priority_map const & a):m_activity(act),m_theory_var_priority(a) {} + bool operator()(bool_var v1, bool_var v2) const { + double p_v1, p_v2; + if (!m_theory_var_priority.find(v1, p_v1)) { + p_v1 = 0.0; + } + if (!m_theory_var_priority.find(v2, p_v2)) { + p_v2 = 0.0; + } + // add clause activity + p_v1 += m_activity[v1]; + p_v2 += m_activity[v2]; + return p_v1 > p_v2; + } + }; + + typedef heap theory_aware_act_queue; + /** \brief Case split queue based on activity and random splits. */ @@ -1086,32 +1111,151 @@ namespace smt { m_params.m_qi_eager_threshold += start_gen; } }; - + + class theory_aware_branching_queue : public case_split_queue { + protected: + context & m_context; + smt_params & m_params; + theory_var_priority_map m_theory_var_priority; + theory_aware_act_queue m_queue; + + int_hashtable > m_theory_vars; + map > m_theory_var_phase; + public: + theory_aware_branching_queue(context & ctx, smt_params & p): + m_context(ctx), + m_params(p), + m_theory_var_priority(), + m_queue(1024, theory_aware_act_lt(ctx.get_activity_vector(), m_theory_var_priority)) { + } + + virtual void activity_increased_eh(bool_var v) { + if (m_queue.contains(v)) + m_queue.decreased(v); + } + + virtual void mk_var_eh(bool_var v) { + m_queue.reserve(v+1); + m_queue.insert(v); + } + + virtual void del_var_eh(bool_var v) { + if (m_queue.contains(v)) + m_queue.erase(v); + } + + virtual void unassign_var_eh(bool_var v) { + if (!m_queue.contains(v)) + m_queue.insert(v); + } + + virtual void relevant_eh(expr * n) {} + + virtual void init_search_eh() {} + + virtual void end_search_eh() {} + + virtual void reset() { + m_queue.reset(); + } + + virtual void push_scope() {} + + virtual void pop_scope(unsigned num_scopes) {} + + virtual void next_case_split(bool_var & next, lbool & phase) { + phase = l_undef; + + if (m_context.get_random_value() < static_cast(m_params.m_random_var_freq * random_gen::max_value())) { + next = m_context.get_random_value() % m_context.get_num_b_internalized(); + TRACE("random_split", tout << "next: " << next << " get_assignment(next): " << m_context.get_assignment(next) << "\n";); + if (m_context.get_assignment(next) == l_undef) + return; + } + + while (!m_queue.empty()) { + next = m_queue.erase_min(); + if (m_context.get_assignment(next) == l_undef) + return; + } + + next = null_bool_var; + if (m_theory_vars.contains(next)) { + if (!m_theory_var_phase.find(next, phase)) { + phase = l_undef; + } + } + } + + virtual void add_theory_aware_branching_info(bool_var v, double priority, lbool phase) { + TRACE("theory_aware_branching", tout << "Add theory-aware branching information for l#" << v << ": priority=" << priority << std::endl;); + m_theory_vars.insert(v); + m_theory_var_phase.insert(v, phase); + m_theory_var_priority.insert(v, priority); + if (m_queue.contains(v)) { + if (priority > 0.0) { + m_queue.decreased(v); + } else { + m_queue.increased(v); + } + } + // m_theory_queue.reserve(v+1); + // m_theory_queue.insert(v); + } + + virtual void display(std::ostream & out) { + bool first = true; + bool_var_act_queue::const_iterator it = m_queue.begin(); + bool_var_act_queue::const_iterator end = m_queue.end(); + for (; it != end ; ++it) { + unsigned v = *it; + if (m_context.get_assignment(v) == l_undef) { + if (first) { + out << "remaining case-splits:\n"; + first = false; + } + out << "#" << m_context.bool_var2expr(v)->get_id() << " "; + } + } + if (!first) + out << "\n"; + + } + + virtual ~theory_aware_branching_queue() {}; + }; + case_split_queue * mk_case_split_queue(context & ctx, smt_params & p) { if (p.m_relevancy_lvl < 2 && (p.m_case_split_strategy == CS_RELEVANCY || p.m_case_split_strategy == CS_RELEVANCY_ACTIVITY || - p.m_case_split_strategy == CS_RELEVANCY_GOAL)) { + p.m_case_split_strategy == CS_RELEVANCY_GOAL)) { warning_msg("relevancy must be enabled to use option CASE_SPLIT=3, 4 or 5"); p.m_case_split_strategy = CS_ACTIVITY; } if (p.m_auto_config && (p.m_case_split_strategy == CS_RELEVANCY || p.m_case_split_strategy == CS_RELEVANCY_ACTIVITY || - p.m_case_split_strategy == CS_RELEVANCY_GOAL)) { + p.m_case_split_strategy == CS_RELEVANCY_GOAL)) { warning_msg("auto configuration (option AUTO_CONFIG) must be disabled to use option CASE_SPLIT=3, 4 or 5"); p.m_case_split_strategy = CS_ACTIVITY; } - switch (p.m_case_split_strategy) { - case CS_ACTIVITY_DELAY_NEW: - return alloc(dact_case_split_queue, ctx, p); - case CS_ACTIVITY_WITH_CACHE: - return alloc(cact_case_split_queue, ctx, p); - case CS_RELEVANCY: - return alloc(rel_case_split_queue, ctx, p); - case CS_RELEVANCY_ACTIVITY: - return alloc(rel_act_case_split_queue, ctx, p); - case CS_RELEVANCY_GOAL: - return alloc(rel_goal_case_split_queue, ctx, p); - default: - return alloc(act_case_split_queue, ctx, p); + + if (p.m_theory_aware_branching) { + // override + return alloc(theory_aware_branching_queue, ctx, p); + } else { + switch (p.m_case_split_strategy) { + case CS_ACTIVITY_DELAY_NEW: + return alloc(dact_case_split_queue, ctx, p); + case CS_ACTIVITY_WITH_CACHE: + return alloc(cact_case_split_queue, ctx, p); + case CS_RELEVANCY: + return alloc(rel_case_split_queue, ctx, p); + case CS_RELEVANCY_ACTIVITY: + return alloc(rel_act_case_split_queue, ctx, p); + case CS_RELEVANCY_GOAL: + return alloc(rel_goal_case_split_queue, ctx, p); + default: + return alloc(act_case_split_queue, ctx, p); + } } } diff --git a/src/smt/smt_case_split_queue.h b/src/smt/smt_case_split_queue.h index e6b217a22..9a3a93cc6 100644 --- a/src/smt/smt_case_split_queue.h +++ b/src/smt/smt_case_split_queue.h @@ -46,6 +46,9 @@ namespace smt { virtual void next_case_split(bool_var & next, lbool & phase) = 0; virtual void display(std::ostream & out) = 0; virtual ~case_split_queue() {} + + // theory-aware branching hint + virtual void add_theory_aware_branching_info(bool_var v, double priority, lbool phase) {} }; case_split_queue * mk_case_split_queue(context & ctx, smt_params & p); diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index abac9ef82..a6cef548f 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -3022,6 +3022,10 @@ namespace smt { } } + void context::add_theory_aware_branching_info(bool_var v, double priority, lbool phase) { + m_case_split_queue->add_theory_aware_branching_info(v, priority, phase); + } + bool context::propagate_th_case_split(unsigned qhead) { if (m_all_th_case_split_literals.empty()) return true; diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index ca2429be8..9c70f5999 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -838,12 +838,20 @@ namespace smt { */ void mk_th_case_split(unsigned num_lits, literal * lits); + + /* + * Provide a hint to the branching heuristic about the priority of a "theory-aware literal". + * Literals marked in this way will always be branched on before unmarked literals, + * starting with the literal having the highest priority. + */ + void add_theory_aware_branching_info(bool_var v, double priority, lbool phase); + public: // helper function for trail void undo_th_case_split(literal l); - bool propagate_th_case_split(unsigned qhead); + bool propagate_th_case_split(unsigned qhead); bool_var mk_bool_var(expr * n);