diff --git a/src/smt/params/smt_params.cpp b/src/smt/params/smt_params.cpp index a5b3e4867..f295e260b 100644 --- a/src/smt/params/smt_params.cpp +++ b/src/smt/params/smt_params.cpp @@ -32,6 +32,7 @@ void smt_params::updt_local_params(params_ref const & _p) { m_restart_factor = p.restart_factor(); m_case_split_strategy = static_cast(p.case_split()); m_theory_case_split = p.theory_case_split(); + m_theory_aware_branching = p.theory_aware_branching(); m_delay_units = p.delay_units(); m_delay_units_threshold = p.delay_units_threshold(); m_preprocess = _p.get_bool("preprocess", true); // hidden parameter diff --git a/src/smt/params/smt_params.h b/src/smt/params/smt_params.h index 55346d34f..a0c90a525 100644 --- a/src/smt/params/smt_params.h +++ b/src/smt/params/smt_params.h @@ -112,6 +112,7 @@ struct smt_params : public preprocessor_params, unsigned m_rel_case_split_order; bool m_lookahead_diseq; bool m_theory_case_split; + bool m_theory_aware_branching; // ----------------------------------- // @@ -243,6 +244,7 @@ struct smt_params : public preprocessor_params, m_rel_case_split_order(0), m_lookahead_diseq(false), m_theory_case_split(false), + m_theory_aware_branching(false), m_delay_units(false), m_delay_units_threshold(32), m_theory_resolve(false), diff --git a/src/smt/params/smt_params_helper.pyg b/src/smt/params/smt_params_helper.pyg index 4e3bec57d..8e8e52987 100644 --- a/src/smt/params/smt_params_helper.pyg +++ b/src/smt/params/smt_params_helper.pyg @@ -71,5 +71,6 @@ def_module_params(module_name='smt', ('str.string_constant_cache', BOOL, True, 'cache all generated string constants generated from anywhere in theory_str'), ('str.use_binary_search', BOOL, False, 'use a binary search heuristic for finding concrete length values for free variables in theory_str (set to False to use linear search)'), ('str.binary_search_start', UINT, 64, 'initial upper bound for theory_str binary search'), - ('theory_case_split', BOOL, False, 'Allow the context to use heuristics involving theory case splits, which are a set of literals of which exactly one can be assigned True. If this option is false, the context will generate extra axioms to enforce this instead.') + ('theory_case_split', BOOL, False, 'Allow the context to use heuristics involving theory case splits, which are a set of literals of which exactly one can be assigned True. If this option is false, the context will generate extra axioms to enforce this instead.'), + ('theory_aware_branching', BOOL, False, 'Allow the context to use extra information from theory solvers regarding literal branching prioritization.') )) diff --git a/src/smt/smt_case_split_queue.cpp b/src/smt/smt_case_split_queue.cpp index 06004e3b8..8b02dd6a9 100644 --- a/src/smt/smt_case_split_queue.cpp +++ b/src/smt/smt_case_split_queue.cpp @@ -22,9 +22,13 @@ Revision History: #include"stopwatch.h" #include"for_each_expr.h" #include"ast_pp.h" +#include"map.h" +#include"hashtable.h" namespace smt { + typedef map > theory_var_priority_map; + struct bool_var_act_lt { svector const & m_activity; bool_var_act_lt(svector const & a):m_activity(a) {} @@ -35,6 +39,25 @@ namespace smt { typedef heap bool_var_act_queue; + struct theory_aware_act_lt { + // only take into account theory var priority for now + theory_var_priority_map const & m_theory_var_priority; + theory_aware_act_lt(theory_var_priority_map const & a):m_theory_var_priority(a) {} + bool operator()(bool_var v1, bool_var v2) const { + double p_v1, p_v2; + // safety -- use a large negative number if some var isn't in the map + if (!m_theory_var_priority.find(v1, p_v1)) { + p_v1 = -1000.0; + } + if (!m_theory_var_priority.find(v2, p_v2)) { + p_v2 = -1000.0; + } + return p_v1 > p_v2; + } + }; + + typedef heap theory_aware_act_queue; + /** \brief Case split queue based on activity and random splits. */ @@ -1087,6 +1110,118 @@ namespace smt { } }; + class theory_aware_branching_queue : public case_split_queue { + protected: + context & m_context; + smt_params & m_params; + + theory_var_priority_map m_theory_var_priority; + theory_aware_act_queue m_theory_queue; + case_split_queue * m_base_queue; + int_hashtable > m_theory_vars; + map > m_theory_var_phase; + public: + theory_aware_branching_queue(context & ctx, smt_params & p, case_split_queue * base_queue) : + m_context(ctx), + m_params(p), + m_theory_var_priority(), + m_theory_queue(1024, theory_aware_act_lt(m_theory_var_priority)), + m_base_queue(base_queue) { + } + + virtual void activity_increased_eh(bool_var v) { + if (m_theory_queue.contains(v)) { + m_theory_queue.decreased(v); + } + m_base_queue->activity_increased_eh(v); + } + + virtual void mk_var_eh(bool_var v) { + // do nothing. we only "react" if/when we learn this is an important theory literal + m_base_queue->mk_var_eh(v); + } + + virtual void del_var_eh(bool_var v) { + if (m_theory_queue.contains(v)) { + m_theory_queue.erase(v); + } + m_base_queue->del_var_eh(v); + } + + virtual void assign_lit_eh(literal l) { + m_base_queue->assign_lit_eh(l); + } + + virtual void unassign_var_eh(bool_var v) { + if (m_theory_vars.contains(v) && !m_theory_queue.contains(v)) { + m_theory_queue.insert(v); + } + m_base_queue->unassign_var_eh(v); + } + + virtual void relevant_eh(expr * n) { + m_base_queue->relevant_eh(n); + } + + virtual void init_search_eh() { + m_base_queue->init_search_eh(); + } + + virtual void end_search_eh() { + m_base_queue->end_search_eh(); + } + + virtual void internalize_instance_eh(expr * e, unsigned gen) { + m_base_queue->internalize_instance_eh(e, gen); + } + + virtual void reset() { + m_theory_queue.reset(); + m_theory_vars.reset(); + m_theory_var_phase.reset(); + m_theory_var_priority.reset(); + m_base_queue->reset(); + } + + virtual void push_scope() { + m_base_queue->push_scope(); + } + + virtual void pop_scope(unsigned num_scopes) { + m_base_queue->pop_scope(num_scopes); + } + + virtual void next_case_split(bool_var & next, lbool & phase) { + while (!m_theory_queue.empty()) { + next = m_theory_queue.erase_min(); + // if this literal is unassigned, it is the theory literal with the highest priority, + // so case split on this + if (m_context.get_assignment(next) == l_undef) { + TRACE("theory_aware_branching", tout << "Theory-aware branch on l#" << next << std::endl;); + if (!m_theory_var_phase.find(next, phase)) { + phase = l_undef; + } + return; + } + } + // if we reach this point, the theory literal queue is empty, + // so fall back to the base queue + m_base_queue->next_case_split(next, phase); + } + + virtual void add_theory_aware_branching_info(bool_var v, double priority, lbool phase) { + TRACE("theory_aware_branching", tout << "Add theory-aware branching information for l#" << v << ": priority=" << priority << std::endl;); + m_theory_vars.insert(v); + m_theory_var_phase.insert(v, phase); + m_theory_var_priority.insert(v, priority); + m_theory_queue.insert(v); + } + + virtual void display(std::ostream & out) { + // TODO + m_base_queue->display(out); + } + }; case_split_queue * mk_case_split_queue(context & ctx, smt_params & p) { if (p.m_relevancy_lvl < 2 && (p.m_case_split_strategy == CS_RELEVANCY || p.m_case_split_strategy == CS_RELEVANCY_ACTIVITY || @@ -1099,19 +1234,36 @@ namespace smt { warning_msg("auto configuration (option AUTO_CONFIG) must be disabled to use option CASE_SPLIT=3, 4 or 5"); p.m_case_split_strategy = CS_ACTIVITY; } + + case_split_queue * baseQueue; + switch (p.m_case_split_strategy) { case CS_ACTIVITY_DELAY_NEW: - return alloc(dact_case_split_queue, ctx, p); + baseQueue = alloc(dact_case_split_queue, ctx, p); + break; case CS_ACTIVITY_WITH_CACHE: - return alloc(cact_case_split_queue, ctx, p); + baseQueue = alloc(cact_case_split_queue, ctx, p); + break; case CS_RELEVANCY: - return alloc(rel_case_split_queue, ctx, p); + baseQueue = alloc(rel_case_split_queue, ctx, p); + break; case CS_RELEVANCY_ACTIVITY: - return alloc(rel_act_case_split_queue, ctx, p); + baseQueue = alloc(rel_act_case_split_queue, ctx, p); + break; case CS_RELEVANCY_GOAL: - return alloc(rel_goal_case_split_queue, ctx, p); + baseQueue = alloc(rel_goal_case_split_queue, ctx, p); + break; default: - return alloc(act_case_split_queue, ctx, p); + baseQueue = alloc(act_case_split_queue, ctx, p); + break; + } + + if (p.m_theory_aware_branching) { + TRACE("theory_aware_branching", tout << "Allocating and returning theory-aware branching queue." << std::endl;); + case_split_queue * theory_aware_queue = alloc(theory_aware_branching_queue, ctx, p, baseQueue); + return theory_aware_queue; + } else { + return baseQueue; } } diff --git a/src/smt/smt_case_split_queue.h b/src/smt/smt_case_split_queue.h index e6b217a22..9a3a93cc6 100644 --- a/src/smt/smt_case_split_queue.h +++ b/src/smt/smt_case_split_queue.h @@ -46,6 +46,9 @@ namespace smt { virtual void next_case_split(bool_var & next, lbool & phase) = 0; virtual void display(std::ostream & out) = 0; virtual ~case_split_queue() {} + + // theory-aware branching hint + virtual void add_theory_aware_branching_info(bool_var v, double priority, lbool phase) {} }; case_split_queue * mk_case_split_queue(context & ctx, smt_params & p);