diff --git a/passes/techmap/extract_fa.cc b/passes/techmap/extract_fa.cc index 1984f82f5..964ec99bc 100644 --- a/passes/techmap/extract_fa.cc +++ b/passes/techmap/extract_fa.cc @@ -19,7 +19,11 @@ #include "kernel/yosys.h" #include "kernel/sigtools.h" +#include "backends/rtlil/rtlil_backend.h" #include "kernel/consteval.h" +#include +#include +#include USING_YOSYS_NAMESPACE PRIVATE_NAMESPACE_BEGIN @@ -66,9 +70,6 @@ struct ExtractFaWorker dict, dict>> func2; dict, dict>> func3; - int count_func2; - int count_func3; - struct func2_and_info_t { bool inv_a, inv_b, inv_y; }; @@ -77,6 +78,26 @@ struct ExtractFaWorker bool inv_a, inv_b, inv_c, inv_y; }; + struct Counters { + int count_func2; + int count_func3; + }; + + struct ThreadData { + size_t start; + size_t end; + Counters counters; + ExtractFaWorker* instance; + std::stringstream log_buffer; + std::vector ports = {ID::A, ID::B, ID::C, ID::D}; + pool> tl_xorxnor2; + pool> tl_xorxnor3; + pool, int, SigBit>> tl_func_2; + pool,int, SigBit>> tl_func_3; + }; + + std::mutex consteval_mtx; + dict func2_and_info; dict func3_maj_info; @@ -153,7 +174,7 @@ struct ExtractFaWorker } } - void check_partition(SigBit root, pool &leaves) + void check_partition(SigBit root, pool &leaves, ThreadData& data) { if (config.enable_ha && GetSize(leaves) == 2) { @@ -163,35 +184,38 @@ struct ExtractFaWorker SigBit B = SigSpec(leaves)[1]; int func = 0; - for (int i = 0; i < 4; i++) { - bool a_value = (i & 1) != 0; - bool b_value = (i & 2) != 0; + std::lock_guard lock(consteval_mtx); + for (int i = 0; i < 4; i++) + { + bool a_value = (i & 1) != 0; + bool b_value = (i & 2) != 0; - ce.push(); - ce.set(A, a_value ? State::S1 : State::S0); - ce.set(B, b_value ? State::S1 : State::S0); + ce.push(); + ce.set(A, a_value ? State::S1 : State::S0); + ce.set(B, b_value ? State::S1 : State::S0); + SigSpec sig = root; - SigSpec sig = root; + if (!ce.eval(sig)) { + ce.pop(); + return; + } + + if (sig == State::S1) + func |= 1 << i; - if (!ce.eval(sig)) { ce.pop(); - return; } - - if (sig == State::S1) - func |= 1 << i; - - ce.pop(); } - // log("%04d %s %s -> %s\n", bindec(func), log_signal(A), log_signal(B), log_signal(root)); if (func == xor2_func || func == xnor2_func) - xorxnor2.insert(tuple(A, B)); + data.tl_xorxnor2.insert(tuple(A, B)); - count_func2++; - func2[tuple(A, B)][func].insert(root); + data.counters.count_func2++; + data.tl_func_2.insert( + tuple, int, SigBit>(tuple(A, B), func, root) + ); } if (config.enable_fa && GetSize(leaves) == 3) @@ -203,52 +227,70 @@ struct ExtractFaWorker SigBit C = SigSpec(leaves)[2]; int func = 0; - for (int i = 0; i < 8; i++) { - bool a_value = (i & 1) != 0; - bool b_value = (i & 2) != 0; - bool c_value = (i & 4) != 0; + std::lock_guard lock(consteval_mtx); + for (int i = 0; i < 8; i++) + { + bool a_value = (i & 1) != 0; + bool b_value = (i & 2) != 0; + bool c_value = (i & 4) != 0; - ce.push(); - ce.set(A, a_value ? State::S1 : State::S0); - ce.set(B, b_value ? State::S1 : State::S0); - ce.set(C, c_value ? State::S1 : State::S0); + ce.push(); + ce.set(A, a_value ? State::S1 : State::S0); + ce.set(B, b_value ? State::S1 : State::S0); + ce.set(C, c_value ? State::S1 : State::S0); + SigSpec sig = root; - SigSpec sig = root; + if (!ce.eval(sig)) { + ce.pop(); + return; + } + if (sig == State::S1) + func |= 1 << i; - if (!ce.eval(sig)) { ce.pop(); - return; } - - if (sig == State::S1) - func |= 1 << i; - - ce.pop(); } // log("%08d %s %s %s -> %s\n", bindec(func), log_signal(A), log_signal(B), log_signal(C), log_signal(root)); if (func == xor3_func || func == xnor3_func) - xorxnor3.insert(tuple(A, B, C)); + data.tl_xorxnor3.insert(tuple(A, B, C)); - count_func3++; - func3[tuple(A, B, C)][func].insert(root); + data.counters.count_func3++; + data.tl_func_3.insert( + tuple, int, SigBit>(tuple(A, B, C), func, root) + ); } } + void partition_log_cache(std::stringstream& stream, int depth, SigBit signal, bool format_depth=false) { + std::stringstream buf; + RTLIL_BACKEND::dump_sigspec(buf, signal, true); + if(format_depth) { + // at most, this is going to take in maxdepth spaces + 2 brackets + 2 numbers + space + \0 + std::vector spacer_buffer; + spacer_buffer.resize(config.maxdepth + 6); + snprintf(spacer_buffer.data(), config.maxdepth + 6, "%*s[%d] ", config.maxdepth-depth, "", depth); + stream << spacer_buffer.data(); + } - void find_partitions(SigBit root, pool &leaves, pool> &cache, int maxdepth, int maxbreadth) + stream << " " << buf.str(); + if(format_depth) + stream << ":"; + } + + void find_partitions(SigBit root, pool &leaves, pool> &cache, int maxdepth, int maxbreadth, ThreadData& data) { if (cache.count(leaves)) return; - // log("%*s[%d] %s:", 20-maxdepth, "", maxdepth, log_signal(root)); - // for (auto bit : leaves) - // log(" %s", log_signal(bit)); - // log("\n"); + partition_log_cache(data.log_buffer, maxdepth, root, true); + for (auto bit : leaves) + partition_log_cache(data.log_buffer, maxdepth, bit); + data.log_buffer << "\n"; cache.insert(leaves); - check_partition(root, leaves); + check_partition(root, leaves, data); if (maxdepth == 0) return; @@ -262,7 +304,8 @@ struct ExtractFaWorker pool new_leaves = leaves; new_leaves.erase(bit); - for (auto port : {ID::A, ID::B, ID::C, ID::D}) { + + for (auto port : data.ports) { if (!cell->hasPort(port)) continue; auto bit = sigmap(SigBit(cell->getPort(port))); @@ -274,7 +317,7 @@ struct ExtractFaWorker if (GetSize(new_leaves) > maxbreadth) continue; - find_partitions(root, new_leaves, cache, maxdepth-1, maxbreadth); + find_partitions(root, new_leaves, cache, maxdepth-1, maxbreadth, data); } } @@ -290,29 +333,55 @@ struct ExtractFaWorker void run() { log("Extracting full/half adders from %s:\n", log_id(module)); + const size_t num_threads = std::thread::hardware_concurrency(); + std::vector threads; + std::vector thread_data(num_threads); - for (auto it : driver) - { - if (it.second->type.in(ID($_BUF_), ID($_NOT_))) - continue; + size_t total_elements = driver.size(); + size_t thread_elements = total_elements / num_threads; + for (size_t i = 0; i < num_threads; ++i) { + thread_data[i].start = i * thread_elements; + thread_data[i].end = (i == num_threads - 1) ? total_elements : (i + 1) * thread_elements; + thread_data[i].instance = this; - SigBit root = it.first; - pool leaves = { root }; - pool> cache; + threads.emplace_back([&data = thread_data[i]]() { + auto& driver = data.instance->driver; + auto& config = data.instance->config; - if (config.verbose) - log(" checking %s\n", log_signal(it.first)); + for (size_t i = data.start; i < data.end; ++i) { + const auto& it = *driver.element(i); + if (it.second->type.in(ID($_BUF_), ID($_NOT_))) + continue; - count_func2 = 0; - count_func3 = 0; + SigBit root = it.first; + pool leaves = { root }; + pool> cache; - find_partitions(root, leaves, cache, config.maxdepth, config.maxbreadth); + if (config.verbose) + log(" checking %s\n", log_signal(it.first)); - if (config.verbose && count_func2 > 0) - log(" extracted %d two-input functions\n", count_func2); + data.instance->find_partitions(root, leaves, cache, config.maxdepth, config.maxbreadth, data); + // log("%s", log_buffer.str().c_str()); - if (config.verbose && count_func3 > 0) - log(" extracted %d three-input functions\n", count_func3); + if (config.verbose && data.counters.count_func2 > 0) + log(" extracted %d two-input functions\n", data.counters.count_func2); + + if (config.verbose && data.counters.count_func3 > 0) + log(" extracted %d three-input functions\n", data.counters.count_func3); + } + }); + } + + for (size_t i = 0; i < num_threads; ++i) { + threads[i].join(); + for(auto& x3 : thread_data[i].tl_xorxnor3) + xorxnor3.insert(x3); + for(auto& x2 : thread_data[i].tl_xorxnor2) + xorxnor2.insert(x2); + for(auto& f3 : thread_data[i].tl_func_3) + func3[get<0>(f3)][get<1>(f3)].insert(get<2>(f3)); + for(auto& f2 : thread_data[i].tl_func_2) + func2[get<0>(f2)][get<1>(f2)].insert(get<2>(f2)); } for (auto &key : xorxnor3) @@ -341,10 +410,13 @@ struct ExtractFaWorker int func = it.first; auto f3i = it.second; + int xor_cnt, xnor_cnt; + xor_cnt = func3.at(key).count(xor3_func); + xnor_cnt = func3.at(key).count(xnor3_func); if (func3.at(key).count(func) == 0) continue; - if (func3.at(key).count(xor3_func) == 0 && func3.at(key).count(xnor3_func) != 0) { + if (xor_cnt == 0 && xnor_cnt != 0) { f3i.inv_a = !f3i.inv_a; f3i.inv_b = !f3i.inv_b; f3i.inv_c = !f3i.inv_c; @@ -413,13 +485,13 @@ struct ExtractFaWorker } bool invert_y = f3i.inv_a ^ f3i.inv_b ^ f3i.inv_c; - if (func3.at(key).count(xor3_func)) { + if (xor_cnt) { SigBit YY = invert_xy ^ invert_y ? module->NotGate(NEW_ID, Y) : Y; for (auto bit : func3.at(key).at(xor3_func)) assign_new_driver(bit, YY); } - if (func3.at(key).count(xnor3_func)) { + if (xnor_cnt) { SigBit YY = invert_xy ^ invert_y ? Y : module->NotGate(NEW_ID, Y); for (auto bit : func3.at(key).at(xnor3_func)) assign_new_driver(bit, YY);