diff --git a/passes/opt/Makefile.inc b/passes/opt/Makefile.inc index 5dee824ff..42c17d6c0 100644 --- a/passes/opt/Makefile.inc +++ b/passes/opt/Makefile.inc @@ -24,6 +24,7 @@ OBJS += passes/opt/opt_ffinv.o OBJS += passes/opt/pmux2shiftx.o OBJS += passes/opt/muxpack.o OBJS += passes/opt/opt_balance_tree.o +OBJS += passes/opt/csa_tree.o OBJS += passes/opt/peepopt.o GENFILES += passes/opt/peepopt_pm.h diff --git a/passes/opt/csa_tree.cc b/passes/opt/csa_tree.cc new file mode 100644 index 000000000..2328af3d8 --- /dev/null +++ b/passes/opt/csa_tree.cc @@ -0,0 +1,357 @@ +#include "kernel/yosys.h" +#include "kernel/sigtools.h" + +USING_YOSYS_NAMESPACE +PRIVATE_NAMESPACE_BEGIN + +struct CsaTreeWorker +{ + RTLIL::Module *module; + SigMap sigmap; + int min_operands; + + dict sig_to_driver; + dict cell_fanout; + pool consumed; + + int stat_trees = 0; + int stat_fa_cells = 0; + int stat_removed_cells = 0; + + CsaTreeWorker(RTLIL::Module *module, int min_operands) : + module(module), sigmap(module), min_operands(min_operands) {} + + void build_maps() + { + dict sig_consumers; + + for (auto cell : module->cells()) + { + if (cell->type.in(ID($add), ID($sub))) + { + RTLIL::SigSpec y = sigmap(cell->getPort(ID::Y)); + for (auto bit : y) + if (bit.wire != nullptr) + sig_to_driver[bit] = cell; + } + + for (auto &conn : cell->connections()) + { + if (cell->input(conn.first)) + { + for (auto bit : sigmap(conn.second)) + if (bit.wire != nullptr) + sig_consumers[bit]++; + } + } + } + + for (auto wire : module->wires()) + if (wire->port_output) + for (auto bit : sigmap(wire)) + if (bit.wire != nullptr) + sig_consumers[bit]++; + + for (auto cell : module->cells()) + { + if (!cell->type.in(ID($add), ID($sub))) + continue; + int fo = 0; + for (auto bit : sigmap(cell->getPort(ID::Y))) + if (bit.wire != nullptr) + fo = std::max(fo, sig_consumers.count(bit) ? sig_consumers.at(bit) : 0); + cell_fanout[cell] = fo; + } + } + + struct Operand { + RTLIL::SigSpec sig; + bool is_signed; + bool do_subtract; + }; + + bool can_absorb(RTLIL::Cell *cell) + { + if (cell == nullptr) + return false; + if (!cell->type.in(ID($add), ID($sub))) + return false; + if (consumed.count(cell)) + return false; + if (cell_fanout.count(cell) ? cell_fanout.at(cell) != 1 : true) + return false; + return true; + } + + RTLIL::Cell *get_driver(RTLIL::SigSpec sig) + { + sig = sigmap(sig); + if (sig.empty()) + return nullptr; + + RTLIL::Cell *driver = nullptr; + for (auto bit : sig) + { + if (bit.wire == nullptr) + continue; + auto it = sig_to_driver.find(bit); + if (it == sig_to_driver.end()) + return nullptr; + if (driver == nullptr) + driver = it->second; + else if (driver != it->second) + return nullptr; // mixed + } + return driver; + } + + void collect_operands( + RTLIL::Cell *cell, + bool negate, + std::vector &operands, + std::vector &tree_cells + ) { + tree_cells.push_back(cell); + consumed.insert(cell); + + bool a_signed = cell->getParam(ID::A_SIGNED).as_bool(); + bool b_signed = cell->getParam(ID::B_SIGNED).as_bool(); + bool is_sub = (cell->type == ID($sub)); + + RTLIL::SigSpec sig_a = cell->getPort(ID::A); + RTLIL::SigSpec sig_b = cell->getPort(ID::B); + + RTLIL::Cell *driver_a = get_driver(sig_a); + if (can_absorb(driver_a)) { + collect_operands(driver_a, negate, operands, tree_cells); + } else { + operands.push_back({sig_a, a_signed, negate}); + } + + bool b_negate = negate ^ is_sub; + RTLIL::Cell *driver_b = get_driver(sig_b); + if (can_absorb(driver_b)) { + collect_operands(driver_b, b_negate, operands, tree_cells); + } else { + operands.push_back({sig_b, b_signed, b_negate}); + } + } + + void create_fa( + RTLIL::SigSpec a, + RTLIL::SigSpec b, + RTLIL::SigSpec c, + int width, + RTLIL::SigSpec &sum_out, + RTLIL::SigSpec &carry_out + ) { + RTLIL::Wire *w_sum = module->addWire(NEW_ID, width); + RTLIL::Wire *w_carry = module->addWire(NEW_ID, width); + + RTLIL::Cell *fa = module->addCell(NEW_ID, ID($fa)); + fa->setParam(ID::WIDTH, width); + fa->setPort(ID::A, a); + fa->setPort(ID::B, b); + fa->setPort(ID::C, c); + fa->setPort(ID::Y, w_sum); + fa->setPort(ID::X, w_carry); + + sum_out = w_sum; + carry_out = w_carry; + stat_fa_cells++; + } + + RTLIL::SigSpec extend_to(RTLIL::SigSpec sig, bool is_signed, int target_width) + { + if (GetSize(sig) >= target_width) + return sig.extract(0, target_width); + + RTLIL::SigSpec result = sig; + RTLIL::SigBit pad = is_signed ? sig[GetSize(sig) - 1] : RTLIL::S0; + while (GetSize(result) < target_width) + result.append(pad); + return result; + } + + RTLIL::SigSpec build_csa_tree(std::vector &operands, int output_width) + { + int width = output_width; + std::vector summands; + int sub_count = 0; + + for (auto &op : operands) + { + RTLIL::SigSpec sig = extend_to(op.sig, op.is_signed, width); + + if (op.do_subtract) { + sig = module->Not(NEW_ID, sig); + sub_count++; + } + + summands.push_back(sig); + } + + if (sub_count > 0) { + RTLIL::Const correction(sub_count, width); + summands.push_back(RTLIL::SigSpec(correction)); + } + + if (summands.empty()) + return RTLIL::SigSpec(0, width); + + if (summands.size() == 1) + return summands[0]; + + if (summands.size() == 2) { + RTLIL::Wire *result = module->addWire(NEW_ID, width); + module->addAdd(NEW_ID, summands[0], summands[1], result); + return result; + } + + while (summands.size() > 2) + { + std::vector next; + int i = 0; + + while (i + 2 < (int)summands.size()) + { + RTLIL::SigSpec a = summands[i]; + RTLIL::SigSpec b = summands[i + 1]; + RTLIL::SigSpec c = summands[i + 2]; + + RTLIL::SigSpec sum, carry; + create_fa(a, b, c, width, sum, carry); + + RTLIL::SigSpec carry_shifted; + carry_shifted.append(RTLIL::S0); + carry_shifted.append(carry.extract(0, width - 1)); + + next.push_back(sum); + next.push_back(carry_shifted); + i += 3; + } + + while (i < (int)summands.size()) + next.push_back(summands[i++]); + + summands.swap(next); + } + + RTLIL::Wire *result = module->addWire(NEW_ID, width); + module->addAdd(NEW_ID, summands[0], summands[1], result); + return result; + } + + void run() + { + build_maps(); + + std::vector roots; + for (auto cell : module->selected_cells()) + if (cell->type.in(ID($add), ID($sub))) + roots.push_back(cell); + + std::sort(roots.begin(), roots.end(), + [](RTLIL::Cell *a, RTLIL::Cell *b) { + return a->name < b->name; + }); + + std::sort(roots.begin(), roots.end(), + [&](RTLIL::Cell *a, RTLIL::Cell *b) { + return (cell_fanout.count(a) ? cell_fanout.at(a) : 0) > + (cell_fanout.count(b) ? cell_fanout.at(b) : 0); + }); + + for (auto root : roots) + { + if (consumed.count(root)) + continue; + + std::vector operands; + std::vector tree_cells; + + collect_operands(root, false, operands, tree_cells); + + if ((int)operands.size() < min_operands) { + for (auto c : tree_cells) + consumed.erase(c); + continue; + } + + int output_width = root->getParam(ID::Y_WIDTH).as_int(); + + log(" Found adder tree rooted at %s with %d operands (depth %d cells)\n", + log_id(root), (int)operands.size(), (int)tree_cells.size()); + + RTLIL::SigSpec new_output = build_csa_tree(operands, output_width); + RTLIL::SigSpec old_output = root->getPort(ID::Y); + module->connect(old_output, new_output); + + for (auto c : tree_cells) { + module->remove(c); + stat_removed_cells++; + } + + stat_trees++; + } + } +}; + +struct CsaTreePass : public Pass +{ + CsaTreePass() : Pass("csa_tree", + "convert adder chains to carry-save adder trees") {} + + void help() override + { + // |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---| + log("\n"); + log(" csa_tree [options] [selection]\n"); + log("\n"); + log("This pass converts chains of $add/$sub cells into carry-save adder trees using\n"); + log("$fa (full adder / 3:2 compressor) cells to reduce the critical path depth of\n"); + log("multi-operand addition.\n"); + log("\n"); + log("For N operands of width W, the critical path is reduced from\n"); + log("O(N * log W) to O(log_1.5(N) + log W).\n"); + log("\n"); + log(" -min_operands N\n"); + log(" Minimum number of operands to trigger CSA tree construction.\n"); + log(" Default: 3. Values below 3 are clamped to 3.\n"); + log("\n"); + } + + void execute(std::vector args, RTLIL::Design *design) override + { + int min_operands = 3; + + log_header(design, "Executing CSA_TREE pass (carry-save adder tree optimization).\n"); + + size_t argidx; + for (argidx = 1; argidx < args.size(); argidx++) + { + if (args[argidx] == "-min_operands" && argidx + 1 < args.size()) { + min_operands = std::max(3, atoi(args[++argidx].c_str())); + continue; + } + break; + } + extra_args(args, argidx, design); + + for (auto module : design->selected_modules()) + { + log("Processing module %s...\n", log_id(module)); + + CsaTreeWorker worker(module, min_operands); + worker.run(); + + if (worker.stat_trees > 0) + log(" Converted %d adder tree(s): created %d $fa cells, " + "removed %d $add/$sub cells.\n", + worker.stat_trees, worker.stat_fa_cells, + worker.stat_removed_cells); + } + } +} CsaTreePass; + +PRIVATE_NAMESPACE_END