diff --git a/pyosys/hashlib.h b/pyosys/hashlib.h index c6ca8c096..02dd75566 100644 --- a/pyosys/hashlib.h +++ b/pyosys/hashlib.h @@ -51,10 +51,11 @@ namespace pybind11 { namespace hashlib { -template -struct is_pointer { static const bool value = false; }; -template -struct is_pointer { static const bool value = true; }; +// "traits" +template struct is_pointer: std::false_type {}; +template struct is_pointer: std::true_type {}; +template struct is_optional: std::false_type {}; +template struct is_optional< std::optional >: std::true_type {}; bool is_mapping(object obj) { object mapping = module_::import("collections.abc").attr("Mapping"); @@ -270,6 +271,12 @@ void bind_set(module &m, const char *name_cstr) { .def("__iter__", [](const C &s){ return make_iterator(s.begin(), s.end()); }, keep_alive<0,1>()) + .def("__eq__", [](const C &s, const C &other) { return s == other; }) + .def("__eq__", [](const C &s, const iterable &other) { + C other_cast; + unionize(other_cast, other); + return s == other_cast; + }) .def("__repr__", [name_cstr](const iterable &s){ // repr(set(s)) where s is iterable would be more terse/robust // but are there concerns with copying? @@ -292,17 +299,17 @@ void bind_pool(module &m, const char *name_cstr) { template -void update_dict(C *target, const iterable &iterable_or_mapping) { +void update_dict(C &target, const iterable &iterable_or_mapping) { if (is_mapping(iterable_or_mapping)) { for (const auto &key: iterable_or_mapping) { - (*target)[cast(key)] = cast(iterable_or_mapping[key]); + target[cast(key)] = cast(iterable_or_mapping[key]); } } else { for (const auto &pair: iterable_or_mapping) { if (len(pair) != 2) { throw value_error(str("iterable element %s has more than two elements").format(str(pair))); } - (*target)[cast(pair[cast(0)])] = cast(pair[cast(1)]); + target[cast(pair[cast(0)])] = cast(pair[cast(1)]); } } } @@ -314,7 +321,7 @@ void bind_dict(module &m, const char *name_cstr) { .def(init()) // copy constructor .def(init([](const iterable &other){ // copy instructor from arbitrary iterables and mappings auto s = new C(); - update_dict(s, other); + update_dict(*s, other); return s; })) .def("__len__", [](const C &s){ return (size_t)s.size(); }) @@ -352,7 +359,15 @@ void bind_dict(module &m, const char *name_cstr) { return s.at(k); } }, arg("key"), arg("default") = std::nullopt) - .def("popitem", [name_cstr](args _) { throw std::runtime_error(std::string(name_cstr) + " is not an ordered dictionary"); }) + .def("popitem", [](C &s) { + auto it = s.begin(); + if (it == s.end()) { + throw key_error("dict is empty"); + } + auto copy = *it; + s.erase(it); + return copy; + }) .def("setdefault", [name_cstr](C &s, const K& k, std::optional &default_) { auto it = s.find(k); if (it != s.end()) { @@ -367,22 +382,25 @@ void bind_dict(module &m, const char *name_cstr) { s[k] = nullptr; return (V)nullptr; } - // TODO: std::optional? do we care? + if constexpr (is_optional::value) { + s[k] = std::nullopt; + return std::nullopt; + } throw type_error(std::string("the value type of ") + name_cstr + " is not nullable"); }, arg("key"), arg("default") = std::nullopt) .def("update", [](C &s, iterable iterable_or_mapping) { - update_dict(&s, iterable_or_mapping); + update_dict(s, iterable_or_mapping); }, arg("iterable_or_mapping")) .def("values", [](const C &s){ return make_value_iterator(s.begin(), s.end()); }, keep_alive<0,1>()) .def("__or__", [](const C &s, iterable iterable_or_mapping) { auto result = new C(s); - update_dict(result, iterable_or_mapping); + update_dict(*result, iterable_or_mapping); return result; }) .def("__ior__", [](C &s, iterable iterable_or_mapping) { - update_dict(&s, iterable_or_mapping); + update_dict(s, iterable_or_mapping); return s; }) .def("__bool__", [](const C &s) { return s.size() != 0; }) @@ -402,6 +420,17 @@ void bind_dict(module &m, const char *name_cstr) { return representation; }); + // K is always comparable + // Python implements `is` as a fallback to check if it's the same object + if constexpr (detail::is_comparable::value) { + cls.def("__eq__", [](const C &s, const C &other) { return s == other; }); + cls.def("__eq__", [](const C &s, const iterable &other) { + C other_cast; + update_dict(other_cast, other); + return s == other_cast; + }); + } + // Inherit from collections.abc.Mapping so update operators (and a bunch // of other things) work. auto collections_abc = module_::import("collections.abc"); diff --git a/pyosys/wrappers_tpl.cc b/pyosys/wrappers_tpl.cc index 9dee4f710..0c17be4ed 100644 --- a/pyosys/wrappers_tpl.cc +++ b/pyosys/wrappers_tpl.cc @@ -36,7 +36,7 @@ using namespace RTLIL; #include "wrappers.inc.cc" namespace YOSYS_PYTHON { - struct YosysStatics{}; + struct Globals {}; // Trampolines for Classes with Python-Overridable Virtual Methods // https://pybind11.readthedocs.io/en/stable/advanced/classes.html#overriding-virtual-functions-in-python @@ -192,7 +192,7 @@ namespace YOSYS_PYTHON { m.def("log_file_error", [](std::string_view file, int line, std::string s) { log_formatted_file_error(file, line, s); }); // Namespace to host global objects - auto global_variables = py::class_(m, "Yosys"); + auto global_variables = py::class_(m, "Globals"); // Trampoline Classes py::class_>(m, "Pass") diff --git a/tests/pyosys/test_dict.py b/tests/pyosys/test_dict.py index 916d69b92..717fed8ea 100644 --- a/tests/pyosys/test_dict.py +++ b/tests/pyosys/test_dict.py @@ -26,4 +26,19 @@ the_great_or = constructor_test_1 | constructor_test_2 | constructor_test_3 assert set(the_great_or) == {"first", "key", "tomato", "im running"} repr_test = eval(repr(the_great_or)) -assert repr_test == the_great_or + +assert repr_test == the_great_or # compare dicts +assert repr_test == {'tomato': 'tomato', 'first': 'second', 'key': 'value', 'im running': 'out of string ideas', } # compare dict with mapping + +before = len(repr_test) +print(repr_test.popitem()) +assert before - 1 == len(repr_test) + +# test noncomparable +## if ys.CellType ever gets an == operator just disable this section +uncomparable_value = ys.Globals.yosys_celltypes.cell_types[ys.IdString("$not")] + +x = ys.IdstringToCelltypeDict({ ys.IdString("\\a"): uncomparable_value}) +y = ys.IdstringToCelltypeDict({ ys.IdString("\\a"): uncomparable_value}) + +assert x != y # not comparable diff --git a/tests/pyosys/test_set.py b/tests/pyosys/test_set.py index d89c5243e..5698cffbe 100644 --- a/tests/pyosys/test_set.py +++ b/tests/pyosys/test_set.py @@ -33,10 +33,10 @@ for cls in [StringSet, StringPool]: C |= {"A", "B", "C"} D |= {"C", "D", "E"} c_symdiff_d = (C ^ D) - assert (c_symdiff_d) == {"A", "B", "D", "E"} + assert c_symdiff_d == {"A", "B", "D", "E"} # compare against iterable repr_test = eval(repr(c_symdiff_d)) - c_symdiff_d == repr_test + assert c_symdiff_d == repr_test # compare against self print("Done.")