diff --git a/passes/opt/Makefile.inc b/passes/opt/Makefile.inc index 426d9a79a..5dee824ff 100644 --- a/passes/opt/Makefile.inc +++ b/passes/opt/Makefile.inc @@ -23,6 +23,7 @@ OBJS += passes/opt/opt_lut_ins.o OBJS += passes/opt/opt_ffinv.o OBJS += passes/opt/pmux2shiftx.o OBJS += passes/opt/muxpack.o +OBJS += passes/opt/opt_balance_tree.o OBJS += passes/opt/peepopt.o GENFILES += passes/opt/peepopt_pm.h diff --git a/passes/opt/opt_balance_tree.cc b/passes/opt/opt_balance_tree.cc new file mode 100644 index 000000000..5c19e3975 --- /dev/null +++ b/passes/opt/opt_balance_tree.cc @@ -0,0 +1,378 @@ +/* + * yosys -- Yosys Open SYnthesis Suite + * + * Copyright (C) 2012 Claire Xenia Wolf + * 2019 Eddie Hung + * 2024 Akash Levy + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * + */ + +#include "kernel/yosys.h" +#include "kernel/sigtools.h" +#include + +USING_YOSYS_NAMESPACE +PRIVATE_NAMESPACE_BEGIN + + +struct OptBalanceTreeWorker { + // Module and signal map + Module *module; + SigMap sigmap; + + // Counts of each cell type that are getting balanced + dict cell_count; + + // Check if cell is of the right type and has matching input/output widths + // Only allow cells with "natural" output widths (no truncation) to prevent + // equivalence issues when rebalancing (see YosysHQ/yosys#5605) + bool is_right_type(Cell* cell, IdString cell_type) { + if (cell->type != cell_type) + return false; + + int y_width = cell->getParam(ID::Y_WIDTH).as_int(); + int a_width = cell->getParam(ID::A_WIDTH).as_int(); + int b_width = cell->getParam(ID::B_WIDTH).as_int(); + + // Calculate the "natural" output width for this operation + int natural_width; + if (cell_type == ID($add)) { + // Addition produces max(A_WIDTH, B_WIDTH) + 1 (for carry bit) + natural_width = std::max(a_width, b_width) + 1; + } else if (cell_type == ID($mul)) { + // Multiplication produces A_WIDTH + B_WIDTH + natural_width = a_width + b_width; + } else { + // Logic operations ($and/$or/$xor) produce max(A_WIDTH, B_WIDTH) + natural_width = std::max(a_width, b_width); + } + + // Only allow cells where Y_WIDTH >= natural width (no truncation) + // This prevents rebalancing chains where truncation semantics matter + return y_width >= natural_width; + } + + // Create a balanced binary tree from a vector of source signals + SigSpec create_balanced_tree(vector &sources, IdString cell_type, Cell* cell) { + // Base case: if we have no sources, return an empty signal + if (sources.size() == 0) + return SigSpec(); + + // Base case: if we have only one source, return it + if (sources.size() == 1) + return sources[0]; + + // Base case: if we have two sources, create a single cell + if (sources.size() == 2) { + // Create a new cell of the same type + Cell* new_cell = module->addCell(NEW_ID, cell_type); + + // Copy attributes from reference cell + new_cell->attributes = cell->attributes; + + // Create output wire + int out_width = cell->getParam(ID::Y_WIDTH).as_int(); + if (cell_type == ID($add)) + out_width = max(sources[0].size(), sources[1].size()) + 1; + else if (cell_type == ID($mul)) + out_width = sources[0].size() + sources[1].size(); + Wire* out_wire = module->addWire(NEW_ID, out_width); + + // Connect ports and fix up parameters + new_cell->setPort(ID::A, sources[0]); + new_cell->setPort(ID::B, sources[1]); + new_cell->setPort(ID::Y, out_wire); + new_cell->fixup_parameters(); + new_cell->setParam(ID::A_SIGNED, cell->getParam(ID::A_SIGNED)); + new_cell->setParam(ID::B_SIGNED, cell->getParam(ID::B_SIGNED)); + + // Update count and return output wire + cell_count[cell_type]++; + return out_wire; + } + + // Recursive case: split sources into two groups and create subtrees + int mid = (sources.size() + 1) / 2; + vector left_sources(sources.begin(), sources.begin() + mid); + vector right_sources(sources.begin() + mid, sources.end()); + + SigSpec left_tree = create_balanced_tree(left_sources, cell_type, cell); + SigSpec right_tree = create_balanced_tree(right_sources, cell_type, cell); + + // Create a cell to combine the two subtrees + Cell* new_cell = module->addCell(NEW_ID, cell_type); + + // Copy attributes from reference cell + new_cell->attributes = cell->attributes; + + // Create output wire + int out_width = cell->getParam(ID::Y_WIDTH).as_int(); + if (cell_type == ID($add)) + out_width = max(left_tree.size(), right_tree.size()) + 1; + else if (cell_type == ID($mul)) + out_width = left_tree.size() + right_tree.size(); + Wire* out_wire = module->addWire(NEW_ID, out_width); + + // Connect ports and fix up parameters + new_cell->setPort(ID::A, left_tree); + new_cell->setPort(ID::B, right_tree); + new_cell->setPort(ID::Y, out_wire); + new_cell->fixup_parameters(); + new_cell->setParam(ID::A_SIGNED, cell->getParam(ID::A_SIGNED)); + new_cell->setParam(ID::B_SIGNED, cell->getParam(ID::B_SIGNED)); + + // Update count and return output wire + cell_count[cell_type]++; + return out_wire; + } + + OptBalanceTreeWorker(Module *module, const vector cell_types) : module(module), sigmap(module) { + // Do for each cell type + for (auto cell_type : cell_types) { + // Index all of the nets in the module + dict sig_to_driver; + dict> sig_to_sink; + for (auto cell : module->selected_cells()) + { + for (auto &conn : cell->connections()) + { + if (cell->output(conn.first)) + sig_to_driver[sigmap(conn.second)] = cell; + + if (cell->input(conn.first)) + { + SigSpec sig = sigmap(conn.second); + if (sig_to_sink.count(sig) == 0) + sig_to_sink[sig] = pool(); + sig_to_sink[sig].insert(cell); + } + } + } + + // Need to check if any wires connect to module ports + pool input_port_sigs; + pool output_port_sigs; + for (auto wire : module->selected_wires()) + if (wire->port_input || wire->port_output) { + SigSpec sig = sigmap(wire); + for (auto bit : sig) { + if (wire->port_input) + input_port_sigs.insert(bit); + if (wire->port_output) + output_port_sigs.insert(bit); + } + } + + // Actual logic starts here + pool consumed_cells; + for (auto cell : module->selected_cells()) + { + // If consumed or not the correct type, skip + if (consumed_cells.count(cell) || !is_right_type(cell, cell_type)) + continue; + + // BFS, following all chains until they hit a cell of a different type + // Pick the longest one + auto y = sigmap(cell->getPort(ID::Y)); + pool sinks; + pool current_loads = sig_to_sink[y]; + pool next_loads; + while (!current_loads.empty()) + { + // Find each sink and see what they are + for (auto x : current_loads) + { + // If not the correct type, don't follow any further + // (but add the originating cell to the list of sinks) + if (!is_right_type(x, cell_type)) + { + sinks.insert(cell); + continue; + } + + auto xy = sigmap(x->getPort(ID::Y)); + + // If this signal drives a port, add it to the sinks + // (even though it may not be the end of a chain) + for (auto bit : xy) { + if (output_port_sigs.count(bit) && !consumed_cells.count(x)) { + sinks.insert(x); + break; + } + } + + // Search signal's fanout + auto& next = sig_to_sink[xy]; + for (auto z : next) + next_loads.insert(z); + } + + // If we couldn't find any downstream loads, stop. + // Create a reduction for each of the max-length chains we found + if (next_loads.empty()) + { + for (auto s : current_loads) + { + // Not one of our gates? Don't follow any further + if (!is_right_type(s, cell_type)) + continue; + + sinks.insert(s); + } + break; + } + + // Otherwise, continue down the chain + current_loads = next_loads; + next_loads.clear(); + } + + // We have our list of sinks, now go tree balance the chains + for (auto head_cell : sinks) + { + // Avoid duplication if we already were covered + if (consumed_cells.count(head_cell)) + continue; + + // Get sources of the chain + dict sources; + dict signeds; + int inner_cells = 0; + std::deque bfs_queue = {head_cell}; + while (bfs_queue.size()) + { + Cell* x = bfs_queue.front(); + bfs_queue.pop_front(); + + for (IdString port: {ID::A, ID::B}) { + auto sig = sigmap(x->getPort(port)); + Cell* drv = sig_to_driver[sig]; + bool drv_ok = drv && is_right_type(drv, cell_type); + for (auto bit : sig) { + if (input_port_sigs.count(bit) && !consumed_cells.count(drv)) { + drv_ok = false; + break; + } + } + if (drv_ok) { + inner_cells++; + bfs_queue.push_back(drv); + } else { + sources[sig]++; + signeds[sig] = x->getParam(port == ID::A ? ID::A_SIGNED : ID::B_SIGNED).as_bool(); + } + } + } + + if (inner_cells) + { + // Create a tree + log_debug(" Creating tree for %s with %d sources and %d inner cells...\n", log_id(head_cell), GetSize(sources), inner_cells); + + // Build a vector of all source signals + vector source_signals; + vector signed_flags; + for (auto &source : sources) { + for (int i = 0; i < source.second; i++) { + source_signals.push_back(source.first); + signed_flags.push_back(signeds[source.first]); + } + } + + // If not all signed flags are the same, do not balance + if (!std::all_of(signed_flags.begin(), signed_flags.end(), [&](bool flag) { return flag == signed_flags[0]; })) { + continue; + } + + // Create the balanced tree + SigSpec tree_output = create_balanced_tree(source_signals, cell_type, head_cell); + + // Connect the tree output to the head cell's output + SigSpec head_output = sigmap(head_cell->getPort(ID::Y)); + int connect_width = std::min(head_output.size(), tree_output.size()); + module->connect(head_output.extract(0, connect_width), tree_output.extract(0, connect_width)); + if (head_output.size() > tree_output.size()) { + SigBit sext_bit = head_cell->getParam(ID::A_SIGNED).as_bool() ? head_output[connect_width - 1] : State::S0; + module->connect(head_output.extract(connect_width, head_output.size() - connect_width), SigSpec(sext_bit, head_output.size() - connect_width)); + } + + // Mark consumed cell for removal + consumed_cells.insert(head_cell); + } + } + } + + // Remove all consumed cells, which now have been replaced by trees + for (auto cell : consumed_cells) + module->remove(cell); + } + } +}; + +struct OptBalanceTreePass : public Pass { + OptBalanceTreePass() : Pass("opt_balance_tree", "$and/$or/$xor/$add/$mul cascades to trees") { } + void help() override { + // |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---| + log("\n"); + log(" opt_balance_tree [options] [selection]\n"); + log("\n"); + log("This pass converts cascaded chains of $and/$or/$xor/$add/$mul cells into\n"); + log("trees of cells to improve timing.\n"); + log("\n"); + log(" -arith\n"); + log(" only convert arithmetic cells.\n"); + log("\n"); + log(" -logic\n"); + log(" only convert logic cells.\n"); + log("\n"); + } + void execute(std::vector args, RTLIL::Design *design) override { + log_header(design, "Executing OPT_BALANCE_TREE pass (cell cascades to trees).\n"); + + // Handle arguments + size_t argidx; + vector cell_types = {ID($and), ID($or), ID($xor), ID($add), ID($mul)}; + for (argidx = 1; argidx < args.size(); argidx++) { + if (args[argidx] == "-arith") { + cell_types = {ID($add), ID($mul)}; + continue; + } + if (args[argidx] == "-logic") { + cell_types = {ID($and), ID($or), ID($xor)}; + continue; + } + break; + } + extra_args(args, argidx, design); + + // Count of all cells that were packed + dict cell_count; + for (auto module : design->selected_modules()) { + OptBalanceTreeWorker worker(module, cell_types); + for (auto cell : worker.cell_count) { + cell_count[cell.first] += cell.second; + } + } + + // Log stats + for (auto cell_type : cell_types) + log("Converted %d %s cells into trees.\n", cell_count[cell_type], log_id(cell_type)); + + // Clean up + Yosys::run_pass("clean -purge"); + } +} OptBalanceTreePass; + +PRIVATE_NAMESPACE_END diff --git a/tests/opt/opt_balance_tree.ys b/tests/opt/opt_balance_tree.ys new file mode 100644 index 000000000..6f8b8b711 --- /dev/null +++ b/tests/opt/opt_balance_tree.ys @@ -0,0 +1,1246 @@ +# Test 1 +log -header "Simple AND chain" +log -push +design -reset +read_verilog <signed + + // Combine with unsigned offset + assign result = sum_full + $signed({6'b0, unsigned_offset}); +endmodule +EOF +check -assert + +# Check equivalence after opt_balance_tree +equiv_opt -assert opt_balance_tree +design -load postopt + +# Width reduction +equiv_opt -assert wreduce +design -load postopt + +design -reset +log -pop + + + +# Test 30 +log -header "Complex signedness with conditional sign extension" +log -push +design -reset +read_verilog <