diff --git a/scripts/mk_util.py b/scripts/mk_util.py index 411fdd1e1..4641a7b1b 100644 --- a/scripts/mk_util.py +++ b/scripts/mk_util.py @@ -1801,7 +1801,7 @@ def def_module_params(module_name, export, params, class_name=None, description= out.write(' {}\n') out.write(' static void collect_param_descrs(param_descrs & d) {\n') for param in params: - out.write(' d.insert("%s", %s, "%s", "%s");\n' % (param[0], TYPE2CPK[param[1]], param[3], pyg_default(param))) + out.write(' d.insert("%s", %s, "%s", "%s","%s");\n' % (param[0], TYPE2CPK[param[1]], param[3], pyg_default(param), module_name)) out.write(' }\n') if export: out.write(' /*\n') diff --git a/src/muz/rel/udoc_relation.cpp b/src/muz/rel/udoc_relation.cpp index 07ca511a6..7f0a61da9 100644 --- a/src/muz/rel/udoc_relation.cpp +++ b/src/muz/rel/udoc_relation.cpp @@ -1071,6 +1071,9 @@ namespace datalog { // TBD: replace this by "join" given below. virtual relation_base* operator()(relation_base const& t1, relation_base const& t2) { +#if 0 + return join(get(t1), get(t2)); +#else udoc_relation *joined = get(m_joiner(t1, t2)); relation_base* result = 0; if (joined->fast_empty()) { @@ -1082,6 +1085,7 @@ namespace datalog { } joined->deallocate(); return result; +#endif } private: diff --git a/src/test/udoc_relation.cpp b/src/test/udoc_relation.cpp index d7f83f0aa..af80a6d5c 100644 --- a/src/test/udoc_relation.cpp +++ b/src/test/udoc_relation.cpp @@ -144,6 +144,8 @@ public: udoc_relation* t1, *t2, *t3; expr_ref fml(m); + test_join_project(); + test_filter_neg4(false); test_filter_neg4(true); test_filter_neg5(false); @@ -406,6 +408,37 @@ public: } + void test_join_project() + { + datalog::relation_signature sig; + sig.push_back(bv.mk_sort(3)); + sig.push_back(bv.mk_sort(3)); + sig.push_back(bv.mk_sort(3)); + + unsigned_vector jc1, jc2, pc; + jc1.push_back(0); + jc2.push_back(0); + pc.push_back(1); + pc.push_back(3); + pc.push_back(4); + udoc_relation* t1, *t2; + relation_base* t; + + scoped_ptr join_project_fn; + + for (unsigned i = 0; i < 20; ++i) { + t1 = mk_rand(sig); + t2 = mk_rand(sig); + join_project_fn = p.mk_join_project_fn(*t1, *t2, jc1.size(), jc1.c_ptr(), jc2.c_ptr(), pc.size(), pc.c_ptr()); + t = (*join_project_fn)(*t1, *t2); + t->display(std::cout); + cr.verify_join_project(*t1, *t2, *t, jc1, jc2, pc); + t->deallocate(); + t1->deallocate(); + t2->deallocate(); + } + } + void test_rename() { udoc_relation* t1; // rename diff --git a/src/util/params.cpp b/src/util/params.cpp index a2609f840..cbb2b2acc 100644 --- a/src/util/params.cpp +++ b/src/util/params.cpp @@ -51,31 +51,34 @@ struct param_descrs::imp { param_kind m_kind; char const * m_descr; char const * m_default; + char const * m_module; - info(param_kind k, char const * descr, char const * def): + info(param_kind k, char const * descr, char const * def, char const* module): m_kind(k), m_descr(descr), - m_default(def) { + m_default(def), + m_module(module) { } info(): m_kind(CPK_INVALID), m_descr(0), - m_default(0) { + m_default(0), + m_module(0) { } }; dictionary m_info; svector m_names; - void insert(symbol const & name, param_kind k, char const * descr, char const * def) { + void insert(symbol const & name, param_kind k, char const * descr, char const * def, char const* module) { SASSERT(!name.is_numerical()); info i; if (m_info.find(name, i)) { SASSERT(i.m_kind == k); return; } - m_info.insert(name, info(k, descr, def)); + m_info.insert(name, info(k, descr, def, module)); m_names.push_back(name); } @@ -94,6 +97,13 @@ struct param_descrs::imp { return CPK_INVALID; } + char const* get_module(symbol const& name) const { + info i; + if (m_info.find(name, i)) + return i.m_module; + return 0; + } + char const * get_descr(symbol const & name) const { info i; if (m_info.find(name, i)) @@ -162,7 +172,7 @@ struct param_descrs::imp { dictionary::iterator it = other.m_imp->m_info.begin(); dictionary::iterator end = other.m_imp->m_info.end(); for (; it != end; ++it) { - insert(it->m_key, it->m_value.m_kind, it->m_value.m_descr, it->m_value.m_default); + insert(it->m_key, it->m_value.m_kind, it->m_value.m_descr, it->m_value.m_default, it->m_value.m_module); } } @@ -180,12 +190,12 @@ void param_descrs::copy(param_descrs & other) { m_imp->copy(other); } -void param_descrs::insert(symbol const & name, param_kind k, char const * descr, char const * def) { - m_imp->insert(name, k, descr, def); +void param_descrs::insert(symbol const & name, param_kind k, char const * descr, char const * def, char const* module) { + m_imp->insert(name, k, descr, def, module); } -void param_descrs::insert(char const * name, param_kind k, char const * descr, char const * def) { - insert(symbol(name), k, descr, def); +void param_descrs::insert(char const * name, param_kind k, char const * descr, char const * def, char const* module) { + insert(symbol(name), k, descr, def, module); } bool param_descrs::contains(char const * name) const { @@ -236,6 +246,10 @@ symbol param_descrs::get_param_name(unsigned i) const { return m_imp->get_param_name(i); } +char const* param_descrs::get_module(symbol const& name) const { + return m_imp->get_module(name); +} + void param_descrs::display(std::ostream & out, unsigned indent, bool smt2_style, bool include_descr) const { return m_imp->display(out, indent, smt2_style, include_descr); } @@ -297,11 +311,35 @@ public: void reset(symbol const & k); void reset(char const * k); - void validate(param_descrs const & p) const { - svector::const_iterator it = m_entries.begin(); - svector::const_iterator end = m_entries.end(); + bool split_name(symbol const& name, symbol & prefix, symbol & suffix) { + if (name.is_numerical()) return false; + char const* str = name.bare_str(); + char const* period = strchr(str,'.'); + if (!period) return false; + svector prefix_((unsigned)(period-str), str); + prefix_.push_back(0); + prefix = symbol(prefix_.c_ptr()); + suffix = symbol(period + 1); + return true; + } + + void validate(param_descrs const & p) { + svector::iterator it = m_entries.begin(); + svector::iterator end = m_entries.end(); + symbol suffix, prefix; for (; it != end; ++it) { param_kind expected = p.get_kind(it->first); + if (expected == CPK_INVALID && split_name(it->first, prefix, suffix)) { + expected = p.get_kind(suffix); + if (expected != CPK_INVALID) { + if (symbol(p.get_module(suffix)) == prefix) { + it->first = suffix; + } + else { + expected = CPK_INVALID; + } + } + } if (expected == CPK_INVALID) { std::stringstream strm; strm << "unknown parameter '" << it->first.str() << "'\n"; @@ -490,7 +528,7 @@ void params_ref::display(std::ostream & out, symbol const & k) const { out << "default"; } -void params_ref::validate(param_descrs const & p) const { +void params_ref::validate(param_descrs const & p) { if (m_params) m_params->validate(p); } diff --git a/src/util/params.h b/src/util/params.h index 06be486bb..f11374775 100644 --- a/src/util/params.h +++ b/src/util/params.h @@ -92,7 +92,7 @@ public: void display(std::ostream & out) const; void display_smt2(std::ostream& out, char const* module, param_descrs& module_desc) const; - void validate(param_descrs const & p) const; + void validate(param_descrs const & p); /* \brief Display the value of the given parameter. @@ -115,8 +115,8 @@ public: param_descrs(); ~param_descrs(); void copy(param_descrs & other); - void insert(char const * name, param_kind k, char const * descr, char const * def = 0); - void insert(symbol const & name, param_kind k, char const * descr, char const * def = 0); + void insert(char const * name, param_kind k, char const * descr, char const * def = 0, char const* module = 0); + void insert(symbol const & name, param_kind k, char const * descr, char const * def = 0, char const* module = 0); bool contains(char const * name) const; bool contains(symbol const & name) const; void erase(char const * name); @@ -130,6 +130,7 @@ public: void display(std::ostream & out, unsigned indent = 0, bool smt2_style=false, bool include_descr=true) const; unsigned size() const; symbol get_param_name(unsigned idx) const; + char const * get_module(symbol const& name) const; }; void insert_max_memory(param_descrs & r);