diff --git a/passes/opt/opt_balance_tree.cc b/passes/opt/opt_balance_tree.cc index 6ea43ec30..e723587c5 100644 --- a/passes/opt/opt_balance_tree.cc +++ b/passes/opt/opt_balance_tree.cc @@ -19,30 +19,35 @@ * */ -#include "kernel/yosys.h" #include "kernel/sigtools.h" +#include "kernel/yosys.h" USING_YOSYS_NAMESPACE PRIVATE_NAMESPACE_BEGIN - struct OptBalanceTreeWorker { // Module and signal map Design *design; Module *module; SigMap sigmap; bool allow_off_chain; + int limit = -1; // Counts of each cell type that are getting balanced dict cell_count; + // Driver data + dict> bit_drivers_db; + // Load data + dict>> bit_users_db; // Signal chain data structures - dict sig_chain_next; - dict sig_chain_prev; + dict sig_chain_next; + dict sig_chain_prev; pool sigbit_with_non_chain_users; - pool chain_start_cells; - pool candidate_cells; + pool chain_start_cells; + pool candidate_cells; - void make_sig_chain_next_prev(IdString cell_type) { + void make_sig_chain_next_prev(IdString cell_type) + { // Mark all wires with keep attribute as having non-chain users for (auto wire : module->wires()) { if (wire->get_bool_attribute(ID::keep)) { @@ -59,14 +64,16 @@ struct OptBalanceTreeWorker { SigSpec a_sig = sigmap(cell->getPort(ID::A)); SigSpec b_sig = sigmap(cell->getPort(ID::B)); SigSpec y_sig = sigmap(cell->getPort(ID::Y)); - - // If a_sig already has a chain user, mark its bits as having non-chain users + + // If a_sig already has a chain user, mark its bits as having non-chain users if (sig_chain_next.count(a_sig)) for (auto a_bit : a_sig.bits()) sigbit_with_non_chain_users.insert(a_bit); // Otherwise, mark cell as the next in the chain relative to a_sig else { - sig_chain_next[a_sig] = cell; + if (fanout_in_range(y_sig)) { + sig_chain_next[a_sig] = cell; + } } if (!b_sig.empty()) { @@ -76,15 +83,19 @@ struct OptBalanceTreeWorker { sigbit_with_non_chain_users.insert(b_bit); // Otherwise, mark cell as the next in the chain relative to b_sig else { - sig_chain_next[b_sig] = cell; + if (fanout_in_range(y_sig)) { + sig_chain_next[b_sig] = cell; + } } } - - // Add cell as candidate - candidate_cells.insert(cell); - // Mark cell as the previous in the chain relative to y_sig - sig_chain_prev[y_sig] = cell; + if (fanout_in_range(y_sig)) { + // Add cell as candidate + candidate_cells.insert(cell); + + // Mark cell as the previous in the chain relative to y_sig + sig_chain_prev[y_sig] = cell; + } } // If cell is not matching type, mark all cell input signals as being non-chain users else { @@ -96,7 +107,8 @@ struct OptBalanceTreeWorker { } } - void find_chain_start_cells() { + void find_chain_start_cells() + { for (auto cell : candidate_cells) { // Log candidate cell log_debug("Considering %s (%s)\n", log_id(cell), log_id(cell->type)); @@ -105,7 +117,7 @@ struct OptBalanceTreeWorker { SigSpec a_sig = sigmap(cell->getPort(ID::A)); SigSpec b_sig = sigmap(cell->getPort(ID::B)); SigSpec prev_sig = sig_chain_prev.count(a_sig) ? a_sig : b_sig; - + // This is a start cell if there was no previous cell in the chain for a_sig or b_sig if (sig_chain_prev.count(a_sig) + sig_chain_prev.count(b_sig) != 1) { chain_start_cells.insert(cell); @@ -121,9 +133,10 @@ struct OptBalanceTreeWorker { } } - vector create_chain(Cell *start_cell) { + vector create_chain(Cell *start_cell) + { // Chain of cells - vector chain; + vector chain; // Current cell Cell *c = start_cell; @@ -146,7 +159,8 @@ struct OptBalanceTreeWorker { return chain; } - void wreduce(Cell *cell) { + void wreduce(Cell *cell) + { // If cell is arithmetic, remove leading zeros from inputs, then clean up outputs if (cell->type.in(ID($add), ID($mul))) { // Remove leading zeros from inputs @@ -158,13 +172,14 @@ struct OptBalanceTreeWorker { SigSpec inport_sig = sigmap(cell->getPort(inport)); cell->unsetPort(inport); if (cell->getParam((inport == ID::A) ? ID::A_SIGNED : ID::B_SIGNED).as_bool()) { - while (GetSize(inport_sig) > 1 && inport_sig[GetSize(inport_sig)-1] == State::S0 && inport_sig[GetSize(inport_sig)-2] == State::S0) { - inport_sig.remove(GetSize(inport_sig)-1, 1); + while (GetSize(inport_sig) > 1 && inport_sig[GetSize(inport_sig) - 1] == State::S0 && + inport_sig[GetSize(inport_sig) - 2] == State::S0) { + inport_sig.remove(GetSize(inport_sig) - 1, 1); bits_removed++; } } else { - while (GetSize(inport_sig) > 0 && inport_sig[GetSize(inport_sig)-1] == State::S0) { - inport_sig.remove(GetSize(inport_sig)-1, 1); + while (GetSize(inport_sig) > 0 && inport_sig[GetSize(inport_sig) - 1] == State::S0) { + inport_sig.remove(GetSize(inport_sig) - 1, 1); bits_removed++; } } @@ -184,7 +199,8 @@ struct OptBalanceTreeWorker { width = std::max(cell->getParam(ID::A_WIDTH).as_int(), cell->getParam(ID::B_WIDTH).as_int()) + 1; else if (cell->type == ID($mul)) width = cell->getParam(ID::A_WIDTH).as_int() + cell->getParam(ID::B_WIDTH).as_int(); - else log_abort(); + else + log_abort(); for (int i = GetSize(y_sig) - 1; i >= width; i--) { module->connect(y_sig[i], State::S0); y_sig.remove(i, 1); @@ -198,7 +214,8 @@ struct OptBalanceTreeWorker { cell->fixup_parameters(); } - void process_chain(vector &chain) { + void process_chain(vector &chain) + { // If chain size is less than 3, no balancing needed if (GetSize(chain) < 3) return; @@ -208,8 +225,8 @@ struct OptBalanceTreeWorker { Cell *cell = mid_cell; // SILIMATE: Set cell to mid_cell for better naming Cell *midnext_cell = chain[GetSize(chain) / 2 + 1]; Cell *end_cell = chain.back(); - log_debug("Balancing chain of %d cells: mid=%s, midnext=%s, endcell=%s\n", - GetSize(chain), log_id(mid_cell), log_id(midnext_cell), log_id(end_cell)); + log_debug("Balancing chain of %d cells: mid=%s, midnext=%s, endcell=%s\n", GetSize(chain), log_id(mid_cell), log_id(midnext_cell), + log_id(end_cell)); // Get mid signals SigSpec mid_a_sig = sigmap(mid_cell->getPort(ID::A)); @@ -238,17 +255,17 @@ struct OptBalanceTreeWorker { sigmap.set(module); // Get subtrees - vector left_chain(chain.begin(), chain.begin() + GetSize(chain) / 2); - vector right_chain(chain.begin() + GetSize(chain) / 2 + 1, chain.end()); + vector left_chain(chain.begin(), chain.begin() + GetSize(chain) / 2); + vector right_chain(chain.begin() + GetSize(chain) / 2 + 1, chain.end()); // Recurse on subtrees process_chain(left_chain); process_chain(right_chain); - + // Width reduce left subtree for (auto c : left_chain) wreduce(c); - + // Width reduce right subtree for (auto c : right_chain) wreduce(c); @@ -260,7 +277,8 @@ struct OptBalanceTreeWorker { wreduce(mid_cell); } - void cleanup() { + void cleanup() + { // Fix ports module->fixup_ports(); @@ -272,10 +290,74 @@ struct OptBalanceTreeWorker { candidate_cells.clear(); } - OptBalanceTreeWorker(Design* design, Module *module, const vector cell_types, bool allow_off_chain) : - design(design), module(module), sigmap(module), allow_off_chain(allow_off_chain) { + bool fanout_in_range(SigSpec outsig) + { + // Check if output signal is "bit-split", skip if so + // This is a lookahead for the splitfanout pass that has this limitation + auto bit_users = bit_users_db[outsig[0]]; + for (int i = 0; i < GetSize(outsig); i++) { + if (bit_users_db[outsig[i]] != bit_users) { + return false; + } + } + + // Skip if fanout is above limit + if (limit != -1 && GetSize(bit_users) > limit) { + return false; + } + return true; + } + + OptBalanceTreeWorker(Design *design, Module *module, const vector cell_types, bool allow_off_chain, int limit) + : design(design), module(module), sigmap(module), allow_off_chain(allow_off_chain), limit(limit) + { if (allow_off_chain) { + + // Build bit_drivers_db + log("Building bit_drivers_db...\n"); + for (auto cell : module->cells()) { + for (auto conn : cell->connections()) { + if (!cell->output(conn.first)) + continue; + for (int i = 0; i < GetSize(conn.second); i++) { + SigBit bit(sigmap(conn.second[i])); + bit_drivers_db[bit] = tuple(cell->name, conn.first, i); + } + } + } + + // Build bit_users_db + log("Building bit_users_db...\n"); + for (auto cell : module->cells()) { + for (auto conn : cell->connections()) { + if (!cell->input(conn.first)) + continue; + for (int i = 0; i < GetSize(conn.second); i++) { + SigBit bit(sigmap(conn.second[i])); + if (!bit_drivers_db.count(bit)) + continue; + bit_users_db[bit].insert( + tuple(cell->name, conn.first, i - std::get<2>(bit_drivers_db[bit]))); + } + } + } + + // Build bit_users_db for output ports + log("Building bit_users_db for output ports...\n"); + for (auto wire : module->wires()) { + if (!wire->port_output) + continue; + SigSpec sig(sigmap(wire)); + for (int i = 0; i < GetSize(sig); i++) { + SigBit bit(sig[i]); + if (!bit_drivers_db.count(bit)) + continue; + bit_users_db[bit].insert( + tuple(wire->name, IdString(), i - std::get<2>(bit_drivers_db[bit]))); + } + } + // Deselect all cells Pass::call(design, "select -none"); // Do for each cell type @@ -308,7 +390,7 @@ struct OptBalanceTreeWorker { sigmap.set(module); } - // Do for each cell type + // Do for each cell type for (auto cell_type : cell_types) { // Find chains of ops make_sig_chain_next_prev(cell_type); @@ -316,7 +398,7 @@ struct OptBalanceTreeWorker { // For each chain, if len >= 3, convert to tree via "rotation" and recurse on subtrees for (auto c : chain_start_cells) { - vector chain = create_chain(c); + vector chain = create_chain(c); process_chain(chain); cell_count[cell_type] += GetSize(chain); } @@ -328,8 +410,9 @@ struct OptBalanceTreeWorker { }; struct OptBalanceTreePass : public Pass { - OptBalanceTreePass() : Pass("opt_balance_tree", "$and/$or/$xor/$xnor/$add/$mul cascades to trees") { } - void help() override { + OptBalanceTreePass() : Pass("opt_balance_tree", "$and/$or/$xor/$xnor/$add/$mul cascades to trees") {} + void help() override + { // |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---| log("\n"); log(" opt_balance_tree [options] [selection]\n"); @@ -341,20 +424,26 @@ struct OptBalanceTreePass : public Pass { log(" Allows matching of cells that have loads outside the chain. These cells\n"); log(" will be replicated and balanced into a tree, but the original\n"); log(" cell will remain, driving its original loads.\n"); + log(" -fanout_limit n\n"); + log(" max fanout to split.\n"); log("\n"); } - void execute(std::vector args, RTLIL::Design *design) override { + void execute(std::vector args, RTLIL::Design *design) override + { log_header(design, "Executing OPT_BALANCE_TREE pass (cell cascades to trees).\n"); bool allow_off_chain = false; size_t argidx; - for (argidx = 1; argidx < args.size(); argidx++) - { - if (args[argidx] == "-allow-off-chain") - { + int limit = -1; + for (argidx = 1; argidx < args.size(); argidx++) { + if (args[argidx] == "-allow-off-chain") { allow_off_chain = true; continue; } + if (args[argidx] == "-fanout_limit" && argidx + 1 < args.size()) { + limit = std::stoi(args[++argidx]); + continue; + } break; } extra_args(args, argidx, design); @@ -363,7 +452,7 @@ struct OptBalanceTreePass : public Pass { dict cell_count; const vector cell_types = {ID($and), ID($or), ID($xor), ID($xnor), ID($add), ID($mul)}; for (auto module : design->selected_modules()) { - OptBalanceTreeWorker worker(design, module, cell_types, allow_off_chain); + OptBalanceTreeWorker worker(design, module, cell_types, allow_off_chain, limit); for (auto cell : worker.cell_count) { cell_count[cell.first] += cell.second; }