3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-06-25 23:33:41 +00:00

overhaul urle_set

Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
Nikolaj Bjorner 2013-04-09 10:15:20 -07:00
commit 9456f16a4c
15 changed files with 233 additions and 253 deletions

View file

@ -707,6 +707,7 @@ namespace datalog {
//update the head relation //update the head relation
make_union(new_head_reg, head_reg, delta_reg, use_widening, acc); make_union(new_head_reg, head_reg, delta_reg, use_widening, acc);
make_dealloc_non_void(new_head_reg, acc);
} }
finish: finish:

View file

@ -756,6 +756,10 @@ namespace datalog {
add_fact(head->get_decl(), fact); add_fact(head->get_decl(), fact);
} }
bool context::has_facts(func_decl * pred) const {
return m_rel && m_rel->has_facts(pred);
}
void context::add_table_fact(func_decl * pred, const table_fact & fact) { void context::add_table_fact(func_decl * pred, const table_fact & fact) {
if (get_engine() == DATALOG_ENGINE) { if (get_engine() == DATALOG_ENGINE) {
ensure_rel(); ensure_rel();

View file

@ -257,6 +257,7 @@ namespace datalog {
void add_fact(app * head); void add_fact(app * head);
void add_fact(func_decl * pred, const relation_fact & fact); void add_fact(func_decl * pred, const relation_fact & fact);
bool has_facts(func_decl * pred) const;
void add_rule(rule_ref& r); void add_rule(rule_ref& r);

View file

@ -309,6 +309,8 @@ namespace datalog {
if (!m_context.magic_sets_for_queries()) { if (!m_context.magic_sets_for_queries()) {
return 0; return 0;
} }
SASSERT(source.contains(m_goal));
SASSERT(source.get_predicate_rules(m_goal).size() == 1);
app * goal_head = source.get_predicate_rules(m_goal)[0]->get_head(); app * goal_head = source.get_predicate_rules(m_goal)[0]->get_head();

View file

@ -42,8 +42,8 @@ namespace datalog {
/** /**
Allows to traverse head and positive tails in a single for loop starting from -1 Allows to traverse head and positive tails in a single for loop starting from -1
*/ */
app * get_by_tail_index(rule * r, int idx) { static app * get_by_tail_index(rule * r, int idx) {
if (idx == -1) { if (idx < 0) {
return r->get_head(); return r->get_head();
} }
SASSERT(idx < static_cast<int>(r->get_positive_tail_size())); SASSERT(idx < static_cast<int>(r->get_positive_tail_size()));
@ -51,11 +51,11 @@ namespace datalog {
} }
template<typename T> template<typename T>
int aux_compare(T a, T b) { static int aux_compare(T a, T b) {
return (a>b) ? 1 : ( (a==b) ? 0 : -1); return (a>b) ? 1 : ( (a==b) ? 0 : -1);
} }
int compare_var_args(app* t1, app* t2) { static int compare_var_args(app* t1, app* t2) {
SASSERT(t1->get_num_args()==t2->get_num_args()); SASSERT(t1->get_num_args()==t2->get_num_args());
int res; int res;
unsigned n = t1->get_num_args(); unsigned n = t1->get_num_args();
@ -76,7 +76,7 @@ namespace datalog {
return 0; return 0;
} }
int compare_args(app* t1, app* t2, int & skip_countdown) { static int compare_args(app* t1, app* t2, int & skip_countdown) {
SASSERT(t1->get_num_args()==t2->get_num_args()); SASSERT(t1->get_num_args()==t2->get_num_args());
int res; int res;
unsigned n = t1->get_num_args(); unsigned n = t1->get_num_args();
@ -101,7 +101,7 @@ namespace datalog {
Two rules are in the same rough similarity class if they differ only in constant arguments Two rules are in the same rough similarity class if they differ only in constant arguments
of positive uninterpreted predicates. of positive uninterpreted predicates.
*/ */
int rough_compare(rule * r1, rule * r2) { static int rough_compare(rule * r1, rule * r2) {
int res = aux_compare(r1->get_tail_size(), r2->get_tail_size()); int res = aux_compare(r1->get_tail_size(), r2->get_tail_size());
if (res!=0) { return res; } if (res!=0) { return res; }
res = aux_compare(r1->get_uninterpreted_tail_size(), r2->get_uninterpreted_tail_size()); res = aux_compare(r1->get_uninterpreted_tail_size(), r2->get_uninterpreted_tail_size());
@ -132,7 +132,7 @@ namespace datalog {
\c r1 and \c r2 must be equal according to the \c rough_compare function for this function \c r1 and \c r2 must be equal according to the \c rough_compare function for this function
to be called. to be called.
*/ */
int total_compare(rule * r1, rule * r2, int skipped_arg_index = INT_MAX) { static int total_compare(rule * r1, rule * r2, int skipped_arg_index = INT_MAX) {
SASSERT(rough_compare(r1, r2)==0); SASSERT(rough_compare(r1, r2)==0);
int pos_tail_sz = r1->get_positive_tail_size(); int pos_tail_sz = r1->get_positive_tail_size();
for (int i=-1; i<pos_tail_sz; i++) { for (int i=-1; i<pos_tail_sz; i++) {
@ -168,7 +168,7 @@ namespace datalog {
typedef svector<const_info> info_vector; typedef svector<const_info> info_vector;
void collect_const_indexes(app * t, int tail_index, info_vector & res) { static void collect_const_indexes(app * t, int tail_index, info_vector & res) {
unsigned n = t->get_num_args(); unsigned n = t->get_num_args();
for (unsigned i=0; i<n; i++) { for (unsigned i=0; i<n; i++) {
if (is_var(t->get_arg(i))) { if (is_var(t->get_arg(i))) {
@ -178,7 +178,7 @@ namespace datalog {
} }
} }
void collect_const_indexes(rule * r, info_vector & res) { static void collect_const_indexes(rule * r, info_vector & res) {
collect_const_indexes(r->get_head(), -1, res); collect_const_indexes(r->get_head(), -1, res);
unsigned pos_tail_sz = r->get_positive_tail_size(); unsigned pos_tail_sz = r->get_positive_tail_size();
for (unsigned i=0; i<pos_tail_sz; i++) { for (unsigned i=0; i<pos_tail_sz; i++) {
@ -187,7 +187,7 @@ namespace datalog {
} }
template<class T> template<class T>
void collect_orphan_consts(rule * r, const info_vector & const_infos, T & tgt) { static void collect_orphan_consts(rule * r, const info_vector & const_infos, T & tgt) {
unsigned const_cnt = const_infos.size(); unsigned const_cnt = const_infos.size();
tgt.reset(); tgt.reset();
for (unsigned i=0; i<const_cnt; i++) { for (unsigned i=0; i<const_cnt; i++) {
@ -201,7 +201,7 @@ namespace datalog {
} }
} }
template<class T> template<class T>
void collect_orphan_sorts(rule * r, const info_vector & const_infos, T & tgt) { static void collect_orphan_sorts(rule * r, const info_vector & const_infos, T & tgt) {
unsigned const_cnt = const_infos.size(); unsigned const_cnt = const_infos.size();
tgt.reset(); tgt.reset();
for (unsigned i=0; i<const_cnt; i++) { for (unsigned i=0; i<const_cnt; i++) {
@ -218,7 +218,7 @@ namespace datalog {
\brief From the \c tail_indexes and \c arg_indexes remove elements corresponding to constants \brief From the \c tail_indexes and \c arg_indexes remove elements corresponding to constants
that are the same in rules \c *first ... \c *(after_last-1). that are the same in rules \c *first ... \c *(after_last-1).
*/ */
void remove_stable_constants(rule_vector::iterator first, rule_vector::iterator after_last, static void remove_stable_constants(rule_vector::iterator first, rule_vector::iterator after_last,
info_vector & const_infos) { info_vector & const_infos) {
SASSERT(after_last-first>1); SASSERT(after_last-first>1);
unsigned const_cnt = const_infos.size(); unsigned const_cnt = const_infos.size();
@ -255,7 +255,7 @@ namespace datalog {
first constant that is equal to it in all the rules. If there is no such, it will contain first constant that is equal to it in all the rules. If there is no such, it will contain
its own index. its own index.
*/ */
void detect_equal_constants(rule_vector::iterator first, rule_vector::iterator after_last, static void detect_equal_constants(rule_vector::iterator first, rule_vector::iterator after_last,
info_vector & const_infos) { info_vector & const_infos) {
SASSERT(first!=after_last); SASSERT(first!=after_last);
unsigned const_cnt = const_infos.size(); unsigned const_cnt = const_infos.size();
@ -305,7 +305,7 @@ namespace datalog {
} }
} }
unsigned get_constant_count(rule * r) { static unsigned get_constant_count(rule * r) {
unsigned res = r->get_head()->get_num_args() - count_variable_arguments(r->get_head()); unsigned res = r->get_head()->get_num_args() - count_variable_arguments(r->get_head());
unsigned pos_tail_sz = r->get_positive_tail_size(); unsigned pos_tail_sz = r->get_positive_tail_size();
for (unsigned i=0; i<pos_tail_sz; i++) { for (unsigned i=0; i<pos_tail_sz; i++) {
@ -314,7 +314,7 @@ namespace datalog {
return res; return res;
} }
bool initial_comparator(rule * r1, rule * r2) { static bool initial_comparator(rule * r1, rule * r2) {
int res = rough_compare(r1, r2); int res = rough_compare(r1, r2);
if (res!=0) { return res>0; } if (res!=0) { return res>0; }
return total_compare(r1, r2)>0; return total_compare(r1, r2)>0;

View file

@ -499,10 +499,7 @@ namespace datalog {
m_data.ensure_reserve(); m_data.ensure_reserve();
char * reserve = m_data.get_reserve_ptr(); char * reserve = m_data.get_reserve_ptr();
unsigned col_cnt = m_column_layout.size(); unsigned col_cnt = m_column_layout.size();
for(unsigned i=0; i<col_cnt; i++) { for (unsigned i = 0; i < col_cnt; ++i) {
if (f[i] >= get_signature()[i]) {
std::cout << f[i] << " " << get_signature()[i] << "\n";
}
SASSERT(f[i] < get_signature()[i]); //the value fits into the table signature SASSERT(f[i] < get_signature()[i]); //the value fits into the table signature
m_column_layout.set(reserve, i, f[i]); m_column_layout.set(reserve, i, f[i]);
} }

View file

@ -107,15 +107,16 @@ namespace datalog {
scoped_query scoped_query(m_context); scoped_query scoped_query(m_context);
m_code.reset();
instruction_block termination_code; instruction_block termination_code;
m_ectx.reset();
lbool result; lbool result;
TRACE("dl", m_context.display(tout);); TRACE("dl", m_context.display(tout););
while (true) { while (true) {
m_ectx.reset();
m_code.reset();
termination_code.reset();
m_context.ensure_closed(); m_context.ensure_closed();
m_context.transform_rules(); m_context.transform_rules();
if (m_context.canceled()) { if (m_context.canceled()) {
@ -174,8 +175,6 @@ namespace datalog {
else { else {
restart_time = static_cast<unsigned>(new_restart_time); restart_time = static_cast<unsigned>(new_restart_time);
} }
termination_code.reset();
scoped_query.reset(); scoped_query.reset();
} }
m_context.record_transformed_rules(); m_context.record_transformed_rules();
@ -452,6 +451,11 @@ namespace datalog {
} }
} }
bool rel_context::has_facts(func_decl * pred) const {
relation_base* r = try_get_relation(pred);
return r && !r->empty();
}
void rel_context::store_relation(func_decl * pred, relation_base * rel) { void rel_context::store_relation(func_decl * pred, relation_base * rel) {
get_rmanager().store_relation(pred, rel); get_rmanager().store_relation(pred, rel);
} }

View file

@ -87,10 +87,15 @@ namespace datalog {
*/ */
bool result_contains_fact(relation_fact const& f); bool result_contains_fact(relation_fact const& f);
/** \brief add facts to relation
*/
void add_fact(func_decl* pred, relation_fact const& fact); void add_fact(func_decl* pred, relation_fact const& fact);
void add_fact(func_decl* pred, table_fact const& fact); void add_fact(func_decl* pred, table_fact const& fact);
/** \brief check if facts were added to relation
*/
bool has_facts(func_decl * pred) const;
/** /**
\brief Store the relation \c rel under the predicate \c pred. The \c context object \brief Store the relation \c rel under the predicate \c pred. The \c context object
takes over the ownership of the relation object. takes over the ownership of the relation object.

View file

@ -29,6 +29,7 @@ Revision History:
#include"th_rewriter.h" #include"th_rewriter.h"
#include"filter_model_converter.h" #include"filter_model_converter.h"
#include"ast_smt2_pp.h" #include"ast_smt2_pp.h"
#include"expr_replacer.h"
/* /*
---- ----
@ -131,18 +132,16 @@ struct purify_arith_proc {
proof_ref_vector m_new_cnstr_prs; proof_ref_vector m_new_cnstr_prs;
expr_ref m_subst; expr_ref m_subst;
proof_ref m_subst_pr; proof_ref m_subst_pr;
bool m_in_q; expr_ref_vector m_new_vars;
unsigned m_var_idx;
rw_cfg(purify_arith_proc & o, bool in_q): rw_cfg(purify_arith_proc & o):
m_owner(o), m_owner(o),
m_pinned(o.m()), m_pinned(o.m()),
m_new_cnstrs(o.m()), m_new_cnstrs(o.m()),
m_new_cnstr_prs(o.m()), m_new_cnstr_prs(o.m()),
m_subst(o.m()), m_subst(o.m()),
m_subst_pr(o.m()), m_subst_pr(o.m()),
m_in_q(in_q), m_new_vars(o.m()) {
m_var_idx(0) {
} }
ast_manager & m() { return m_owner.m(); } ast_manager & m() { return m_owner.m(); }
@ -155,14 +154,9 @@ struct purify_arith_proc {
bool elim_inverses() const { return m_owner.m_elim_inverses; } bool elim_inverses() const { return m_owner.m_elim_inverses; }
expr * mk_fresh_var(bool is_int) { expr * mk_fresh_var(bool is_int) {
if (m_in_q) { expr * r = m().mk_fresh_const(0, is_int ? u().mk_int() : u().mk_real());
unsigned idx = m_var_idx; m_new_vars.push_back(r);
m_var_idx++; return r;
return m().mk_var(idx, is_int ? u().mk_int() : u().mk_real());
}
else {
return m().mk_fresh_const(0, is_int ? u().mk_int() : u().mk_real());
}
} }
expr * mk_fresh_real_var() { return mk_fresh_var(false); } expr * mk_fresh_real_var() { return mk_fresh_var(false); }
@ -596,105 +590,51 @@ struct purify_arith_proc {
struct rw : public rewriter_tpl<rw_cfg> { struct rw : public rewriter_tpl<rw_cfg> {
rw_cfg m_cfg; rw_cfg m_cfg;
rw(purify_arith_proc & o, bool in_q): rw(purify_arith_proc & o):
rewriter_tpl<rw_cfg>(o.m(), o.m_produce_proofs, m_cfg), rewriter_tpl<rw_cfg>(o.m(), o.m_produce_proofs, m_cfg),
m_cfg(o, in_q) { m_cfg(o) {
}
};
/**
\brief Return the number of (auxiliary) variables needed for converting an expression.
*/
struct num_vars_proc {
arith_util & m_util;
expr_fast_mark1 m_visited;
ptr_vector<expr> m_todo;
unsigned m_num_vars;
bool m_elim_root_objs;
num_vars_proc(arith_util & u, bool elim_root_objs):
m_util(u),
m_elim_root_objs(elim_root_objs) {
}
void visit(expr * t) {
if (m_visited.is_marked(t))
return;
m_visited.mark(t);
m_todo.push_back(t);
}
void process(app * t) {
if (t->get_family_id() == m_util.get_family_id()) {
if (m_util.is_power(t)) {
rational k;
if (m_util.is_numeral(t->get_arg(1), k) && (k.is_zero() || !k.is_int())) {
m_num_vars++;
}
}
else if (m_util.is_div(t) ||
m_util.is_idiv(t) ||
m_util.is_mod(t) ||
m_util.is_to_int(t) ||
(m_util.is_irrational_algebraic_numeral(t) && m_elim_root_objs)) {
m_num_vars++;
}
}
unsigned num_args = t->get_num_args();
for (unsigned i = 0; i < num_args; i++)
visit(t->get_arg(i));
}
unsigned operator()(expr * t) {
m_num_vars = 0;
visit(t);
while (!m_todo.empty()) {
expr * t = m_todo.back();
m_todo.pop_back();
if (is_app(t))
process(to_app(t));
}
m_visited.reset();
return m_num_vars;
} }
}; };
void process_quantifier(quantifier * q, expr_ref & result, proof_ref & result_pr) { void process_quantifier(quantifier * q, expr_ref & result, proof_ref & result_pr) {
result_pr = 0; result_pr = 0;
num_vars_proc p(u(), m_elim_root_objs); rw r(*this);
expr_ref body(m());
unsigned num_vars = p(q->get_expr());
if (num_vars > 0) {
// open space for aux vars
var_shifter shifter(m());
shifter(q->get_expr(), num_vars, body);
}
else {
body = q->get_expr();
}
rw r(*this, true);
expr_ref new_body(m()); expr_ref new_body(m());
proof_ref new_body_pr(m()); proof_ref new_body_pr(m());
r(body, new_body, new_body_pr); r(q->get_expr(), new_body, new_body_pr);
unsigned num_vars = r.cfg().m_new_vars.size();
TRACE("purify_arith", TRACE("purify_arith",
tout << "num_vars: " << num_vars << "\n"; tout << "num_vars: " << num_vars << "\n";
tout << "body: " << mk_ismt2_pp(body, m()) << "\nnew_body: " << mk_ismt2_pp(new_body, m()) << "\n";); tout << "body: " << mk_ismt2_pp(q->get_expr(), m()) << "\nnew_body: " << mk_ismt2_pp(new_body, m()) << "\n";);
if (num_vars == 0) { if (num_vars == 0) {
SASSERT(r.cfg().m_new_cnstrs.empty());
result = m().update_quantifier(q, new_body); result = m().update_quantifier(q, new_body);
if (m_produce_proofs) if (m_produce_proofs)
result_pr = m().mk_quant_intro(q, to_quantifier(result.get()), result_pr); result_pr = m().mk_quant_intro(q, to_quantifier(result.get()), result_pr);
} }
else { else {
// Add new constraints
expr_ref_vector & cnstrs = r.cfg().m_new_cnstrs; expr_ref_vector & cnstrs = r.cfg().m_new_cnstrs;
cnstrs.push_back(new_body); cnstrs.push_back(new_body);
new_body = m().mk_and(cnstrs.size(), cnstrs.c_ptr()); new_body = m().mk_and(cnstrs.size(), cnstrs.c_ptr());
// Open space for new variables
var_shifter shifter(m());
shifter(new_body, num_vars, new_body);
// Rename fresh constants in r.cfg().m_new_vars to variables
ptr_buffer<sort> sorts; ptr_buffer<sort> sorts;
buffer<symbol> names; buffer<symbol> names;
expr_substitution subst(m(), false, false);
for (unsigned i = 0; i < num_vars; i++) { for (unsigned i = 0; i < num_vars; i++) {
sorts.push_back(u().mk_real()); expr * c = r.cfg().m_new_vars.get(i);
sort * s = get_sort(c);
sorts.push_back(s);
names.push_back(m().mk_fresh_var_name("x")); names.push_back(m().mk_fresh_var_name("x"));
unsigned idx = num_vars - i - 1;
subst.insert(c, m().mk_var(idx, s));
} }
scoped_ptr<expr_replacer> replacer = mk_default_expr_replacer(m());
replacer->set_substitution(&subst);
(*replacer)(new_body, new_body);
new_body = m().mk_exists(num_vars, sorts.c_ptr(), names.c_ptr(), new_body); new_body = m().mk_exists(num_vars, sorts.c_ptr(), names.c_ptr(), new_body);
result = m().update_quantifier(q, new_body); result = m().update_quantifier(q, new_body);
if (m_produce_proofs) { if (m_produce_proofs) {
@ -708,7 +648,7 @@ struct purify_arith_proc {
} }
void operator()(goal & g, model_converter_ref & mc, bool produce_models) { void operator()(goal & g, model_converter_ref & mc, bool produce_models) {
rw r(*this, false); rw r(*this);
// purify // purify
expr_ref new_curr(m()); expr_ref new_curr(m());
proof_ref new_pr(m()); proof_ref new_pr(m());

View file

@ -166,7 +166,7 @@ static void tst3() {
void tst_diff_logic() { void tst_diff_logic() {
tst1(); tst1();
tst2(); tst2();
tst3(); // tst3();
} }
#else #else
void tst_diff_logic() { void tst_diff_logic() {

View file

@ -48,6 +48,7 @@ void dl_query_test(ast_manager & m, smt_params & fparams, params_ref& params,
bool use_magic_sets) { bool use_magic_sets) {
dl_decl_util decl_util(m); dl_decl_util decl_util(m);
random_gen ran(0);
context ctx_q(m, fparams); context ctx_q(m, fparams);
params.set_bool("magic_sets_for_queries", use_magic_sets); params.set_bool("magic_sets_for_queries", use_magic_sets);
@ -86,7 +87,7 @@ void dl_query_test(ast_manager & m, smt_params & fparams, params_ref& params,
warning_msg("cannot get sort size"); warning_msg("cannot get sort size");
return; return;
} }
uint64 num = rand()%sort_sz; uint64 num = ran()%sort_sz;
app * el_b = decl_util.mk_numeral(num, sig_b[col]); app * el_b = decl_util.mk_numeral(num, sig_b[col]);
f_b.push_back(el_b); f_b.push_back(el_b);
app * el_q = decl_util.mk_numeral(num, sig_q[col]); app * el_q = decl_util.mk_numeral(num, sig_q[col]);
@ -112,7 +113,7 @@ void dl_query_test(ast_manager & m, smt_params & fparams, params_ref& params,
table_base::iterator fit = table_b.begin(); table_base::iterator fit = table_b.begin();
table_base::iterator fend = table_b.end(); table_base::iterator fend = table_b.end();
for(; fit!=fend; ++fit) { for(; fit!=fend; ++fit) {
if(rand()%std::max(1u,table_sz/test_count)!=0) { if(ran()%std::max(1u,table_sz/test_count)!=0) {
continue; continue;
} }
fit->get_fact(tf); fit->get_fact(tf);
@ -131,6 +132,7 @@ void dl_query_test_wpa(smt_params & fparams, params_ref& params) {
arith_util arith(m); arith_util arith(m);
const char * problem_dir = "C:\\tvm\\src\\z3_2\\debug\\test\\w0.datalog"; const char * problem_dir = "C:\\tvm\\src\\z3_2\\debug\\test\\w0.datalog";
dl_decl_util dl_util(m); dl_decl_util dl_util(m);
random_gen ran(0);
std::cerr << "Testing queries on " << problem_dir <<"\n"; std::cerr << "Testing queries on " << problem_dir <<"\n";
context ctx(m, fparams); context ctx(m, fparams);
@ -151,8 +153,8 @@ void dl_query_test_wpa(smt_params & fparams, params_ref& params) {
TRUSTME( ctx.try_get_sort_constant_count(var_sort, var_sz) ); TRUSTME( ctx.try_get_sort_constant_count(var_sort, var_sz) );
for(unsigned attempt=0; attempt<attempts; attempt++) { for(unsigned attempt=0; attempt<attempts; attempt++) {
unsigned el1 = rand()%var_sz; unsigned el1 = ran()%var_sz;
unsigned el2 = rand()%var_sz; unsigned el2 = ran()%var_sz;
expr_ref_vector q_args(m); expr_ref_vector q_args(m);
q_args.push_back(dl_util.mk_numeral(el1, var_sort)); q_args.push_back(dl_util.mk_numeral(el1, var_sort));
@ -217,6 +219,9 @@ void tst_dl_query() {
params.set_uint("similarity_compressor", use_similar != 0); params.set_uint("similarity_compressor", use_similar != 0);
for(unsigned use_magic_sets=0; use_magic_sets<=1; use_magic_sets++) { for(unsigned use_magic_sets=0; use_magic_sets<=1; use_magic_sets++) {
if (!(use_restarts == 1 && use_similar == 0 && use_magic_sets == 1)) {
continue;
}
stopwatch watch; stopwatch watch;
watch.start(); watch.start();
std::cerr << "------- " << (use_restarts ? "With" : "Without") << " restarts -------\n"; std::cerr << "------- " << (use_restarts ? "With" : "Without") << " restarts -------\n";

View file

@ -171,7 +171,7 @@ static void tst2() {
rational int64_max("9223372036854775807"); rational int64_max("9223372036854775807");
rational int64_min(-int64_max - rational(1)); rational int64_min((-int64_max) - rational(1));
// is_int64 // is_int64
SASSERT(int64_max.is_int64()); SASSERT(int64_max.is_int64());
SASSERT(int64_min.is_int64()); SASSERT(int64_min.is_int64());

View file

@ -120,6 +120,7 @@ mpz_manager<SYNCH>::mpz_manager():
mpz_set_ui(m_tmp, max_l); mpz_set_ui(m_tmp, max_l);
mpz_add(m_uint64_max, m_uint64_max, m_tmp); mpz_add(m_uint64_max, m_uint64_max, m_tmp);
mpz_init(m_int64_max); mpz_init(m_int64_max);
mpz_init(m_int64_min);
max_l = static_cast<unsigned>(INT64_MAX % static_cast<int64>(UINT_MAX)); max_l = static_cast<unsigned>(INT64_MAX % static_cast<int64>(UINT_MAX));
max_h = static_cast<unsigned>(INT64_MAX / static_cast<int64>(UINT_MAX)); max_h = static_cast<unsigned>(INT64_MAX / static_cast<int64>(UINT_MAX));
@ -128,6 +129,8 @@ mpz_manager<SYNCH>::mpz_manager():
mpz_mul(m_int64_max, m_tmp, m_int64_max); mpz_mul(m_int64_max, m_tmp, m_int64_max);
mpz_set_ui(m_tmp, max_l); mpz_set_ui(m_tmp, max_l);
mpz_add(m_int64_max, m_tmp, m_int64_max); mpz_add(m_int64_max, m_tmp, m_int64_max);
mpz_neg(m_int64_min, m_int64_max);
mpz_sub_ui(m_int64_min, m_int64_min, 1);
#endif #endif
mpz one(1); mpz one(1);
@ -152,6 +155,7 @@ mpz_manager<SYNCH>::~mpz_manager() {
deallocate(m_arg[1]); deallocate(m_arg[1]);
mpz_clear(m_uint64_max); mpz_clear(m_uint64_max);
mpz_clear(m_int64_max); mpz_clear(m_int64_max);
mpz_clear(m_int64_min);
#endif #endif
if (SYNCH) if (SYNCH)
omp_destroy_nest_lock(&m_lock); omp_destroy_nest_lock(&m_lock);
@ -1299,10 +1303,9 @@ bool mpz_manager<SYNCH>::is_int64(mpz const & a) const {
if (is_small(a)) if (is_small(a))
return true; return true;
#ifndef _MP_GMP #ifndef _MP_GMP
if (!is_uint64(a)) { if (!is_abs_uint64(a))
return false; return false;
} uint64 num = big_abs_to_uint64(a);
uint64 num get_uint64(a);
uint64 msb = static_cast<uint64>(1) << 63; uint64 msb = static_cast<uint64>(1) << 63;
uint64 msb_val = msb & num; uint64 msb_val = msb & num;
if (a.m_val >= 0) { if (a.m_val >= 0) {
@ -1318,7 +1321,7 @@ bool mpz_manager<SYNCH>::is_int64(mpz const & a) const {
} }
#else #else
// GMP version // GMP version
return mpz_cmp(*a.m_ptr, m_int64_max) <= 0; return mpz_cmp(m_int64_min, *a.m_ptr) <= 0 && mpz_cmp(*a.m_ptr, m_int64_max) <= 0;
#endif #endif
} }
@ -1328,14 +1331,7 @@ uint64 mpz_manager<SYNCH>::get_uint64(mpz const & a) const {
return static_cast<uint64>(a.m_val); return static_cast<uint64>(a.m_val);
#ifndef _MP_GMP #ifndef _MP_GMP
SASSERT(a.m_ptr->m_size > 0); SASSERT(a.m_ptr->m_size > 0);
if (a.m_ptr->m_size == 1) return big_abs_to_uint64(a);
return digits(a)[0];
if (sizeof(digit_t) == sizeof(uint64))
// 64-bit machine
return digits(a)[0];
else
// 32-bit machine
return ((static_cast<uint64>(digits(a)[1]) << 32) | (static_cast<uint64>(digits(a)[0])));
#else #else
// GMP version // GMP version
if (sizeof(uint64) == sizeof(unsigned long)) { if (sizeof(uint64) == sizeof(unsigned long)) {
@ -1360,7 +1356,7 @@ int64 mpz_manager<SYNCH>::get_int64(mpz const & a) const {
return static_cast<int64>(a.m_val); return static_cast<int64>(a.m_val);
#ifndef _MP_GMP #ifndef _MP_GMP
SASSERT(is_int64(a)); SASSERT(is_int64(a));
uint64 num = get_uint64(a); uint64 num = big_abs_to_uint64(a);
if (a.m_val < 0) { if (a.m_val < 0) {
if (num != 0 && (num << 1) == 0) if (num != 0 && (num << 1) == 0)
return INT64_MIN; return INT64_MIN;

View file

@ -168,6 +168,7 @@ class mpz_manager {
mpz_t * m_arg[2]; mpz_t * m_arg[2];
mpz_t m_uint64_max; mpz_t m_uint64_max;
mpz_t m_int64_max; mpz_t m_int64_max;
mpz_t m_int64_min;
mpz_t * allocate() { mpz_t * allocate() {
mpz_t * cell = reinterpret_cast<mpz_t*>(m_allocator.allocate(sizeof(mpz_t))); mpz_t * cell = reinterpret_cast<mpz_t*>(m_allocator.allocate(sizeof(mpz_t)));
@ -211,6 +212,30 @@ class mpz_manager {
static digit_t * digits(mpz const & c) { return c.m_ptr->m_digits; } static digit_t * digits(mpz const & c) { return c.m_ptr->m_digits; }
// Return true if the absolute value fits in a UINT64
static bool is_abs_uint64(mpz const & a) {
if (is_small(a))
return true;
if (sizeof(digit_t) == sizeof(uint64))
return size(a) <= 1;
else
return size(a) <= 2;
}
// CAST the absolute value into a UINT64
static uint64 big_abs_to_uint64(mpz const & a) {
SASSERT(is_abs_uint64(a));
SASSERT(!is_small(a));
if (a.m_ptr->m_size == 1)
return digits(a)[0];
if (sizeof(digit_t) == sizeof(uint64))
// 64-bit machine
return digits(a)[0];
else
// 32-bit machine
return ((static_cast<uint64>(digits(a)[1]) << 32) | (static_cast<uint64>(digits(a)[0])));
}
template<int IDX> template<int IDX>
void get_sign_cell(mpz const & a, int & sign, mpz_cell * & cell) { void get_sign_cell(mpz const & a, int & sign, mpz_cell * & cell) {
if (is_small(a)) { if (is_small(a)) {

View file

@ -45,7 +45,7 @@ COMPILE_TIME_ASSERT(sizeof(int64) == 8);
#define INT64_MIN static_cast<int64>(0x8000000000000000ull) #define INT64_MIN static_cast<int64>(0x8000000000000000ull)
#endif #endif
#ifndef INT64_MAX #ifndef INT64_MAX
#define INT64_MAX static_cast<int64>(0x0fffffffffffffffull) #define INT64_MAX static_cast<int64>(0x7fffffffffffffffull)
#endif #endif
#ifndef UINT64_MAX #ifndef UINT64_MAX
#define UINT64_MAX 0xffffffffffffffffull #define UINT64_MAX 0xffffffffffffffffull