diff --git a/passes/opt/csa_tree.cc b/passes/opt/csa_tree.cc index 2328af3d8..92f744a9f 100644 --- a/passes/opt/csa_tree.cc +++ b/passes/opt/csa_tree.cc @@ -1,355 +1,342 @@ #include "kernel/yosys.h" #include "kernel/sigtools.h" +#include + USING_YOSYS_NAMESPACE PRIVATE_NAMESPACE_BEGIN struct CsaTreeWorker { - RTLIL::Module *module; + Module *module; SigMap sigmap; - int min_operands; - dict sig_to_driver; - dict cell_fanout; - pool consumed; + dict> bit_consumers; + dict fanout; + pool all_adds; - int stat_trees = 0; - int stat_fa_cells = 0; - int stat_removed_cells = 0; + CsaTreeWorker(Module *module) : module(module), sigmap(module) {} - CsaTreeWorker(RTLIL::Module *module, int min_operands) : - module(module), sigmap(module), min_operands(min_operands) {} + struct DepthSig { + SigSpec sig; + int depth; + }; - void build_maps() + void find_adds() { - dict sig_consumers; - for (auto cell : module->cells()) - { - if (cell->type.in(ID($add), ID($sub))) - { - RTLIL::SigSpec y = sigmap(cell->getPort(ID::Y)); - for (auto bit : y) - if (bit.wire != nullptr) - sig_to_driver[bit] = cell; - } + if (cell->type == ID($add)) + all_adds.insert(cell); + } + void build_fanout_map() + { + for (auto cell : module->cells()) for (auto &conn : cell->connections()) - { if (cell->input(conn.first)) - { for (auto bit : sigmap(conn.second)) - if (bit.wire != nullptr) - sig_consumers[bit]++; - } - } - } + bit_consumers[bit].insert(cell); + + for (auto &pair : bit_consumers) + fanout[pair.first] = pair.second.size(); for (auto wire : module->wires()) if (wire->port_output) - for (auto bit : sigmap(wire)) - if (bit.wire != nullptr) - sig_consumers[bit]++; + for (auto bit : sigmap(SigSpec(wire))) + fanout[bit]++; + } - for (auto cell : module->cells()) - { - if (!cell->type.in(ID($add), ID($sub))) - continue; - int fo = 0; - for (auto bit : sigmap(cell->getPort(ID::Y))) - if (bit.wire != nullptr) - fo = std::max(fo, sig_consumers.count(bit) ? sig_consumers.at(bit) : 0); - cell_fanout[cell] = fo; + Cell*single_add_consumer(SigSpec sig) + { + Cell*consumer = nullptr; + + for (auto bit : sig) { + if (!fanout.count(bit) || fanout[bit] != 1) + return nullptr; + if (!bit_consumers.count(bit) || bit_consumers[bit].size() != 1) + return nullptr; + + Cell* c = *bit_consumers[bit].begin(); + if (!all_adds.count(c)) + return nullptr; + + if (consumer == nullptr) + consumer = c; + else if (consumer != c) + return nullptr; } + + return consumer; + } + + dict find_add_parents() + { + dict parent_of; + + for (auto cell : all_adds) { + SigSpec y = sigmap(cell->getPort(ID::Y)); + Cell* consumer = single_add_consumer(y); + if (consumer != nullptr && consumer != cell) + parent_of[cell] = consumer; + } + + return parent_of; + } + + pool collect_chain(Cell* root, const dict> &children_of) + { + pool chain; + std::queue worklist; + worklist.push(root); + + while (!worklist.empty()) { + Cell* cur = worklist.front(); + worklist.pop(); + + if (chain.count(cur)) + continue; + chain.insert(cur); + + if (children_of.count(cur)) + for (auto child : children_of.at(cur)) + worklist.push(child); + } + + return chain; + } + + bool is_chain_internal(SigSpec sig, const pool &chain_y_bits) + { + for (auto bit : sig) + if (chain_y_bits.count(bit)) + return true; + return false; + } + + pool collect_chain_outputs(const pool &chain) + { + pool bits; + for (auto cell : chain) + for (auto bit : sigmap(cell->getPort(ID::Y))) + bits.insert(bit); + return bits; } struct Operand { - RTLIL::SigSpec sig; + SigSpec sig; bool is_signed; - bool do_subtract; }; - bool can_absorb(RTLIL::Cell *cell) + std::vector collect_leaf_operands(const pool &chain, const pool &chain_y_bits) { - if (cell == nullptr) - return false; - if (!cell->type.in(ID($add), ID($sub))) - return false; - if (consumed.count(cell)) - return false; - if (cell_fanout.count(cell) ? cell_fanout.at(cell) != 1 : true) - return false; - return true; + std::vector operands; + + for (auto cell : chain) { + SigSpec a = sigmap(cell->getPort(ID::A)); + SigSpec b = sigmap(cell->getPort(ID::B)); + bool a_signed = cell->getParam(ID::A_SIGNED).as_bool(); + bool b_signed = cell->getParam(ID::B_SIGNED).as_bool(); + + if (!is_chain_internal(a, chain_y_bits)) + operands.push_back({a, a_signed}); + if (!is_chain_internal(b, chain_y_bits)) + operands.push_back({b, b_signed}); + } + + return operands; } - RTLIL::Cell *get_driver(RTLIL::SigSpec sig) + SigSpec extend_to(SigSpec sig, bool is_signed, int width) { - sig = sigmap(sig); - if (sig.empty()) - return nullptr; - - RTLIL::Cell *driver = nullptr; - for (auto bit : sig) - { - if (bit.wire == nullptr) - continue; - auto it = sig_to_driver.find(bit); - if (it == sig_to_driver.end()) - return nullptr; - if (driver == nullptr) - driver = it->second; - else if (driver != it->second) - return nullptr; // mixed + if (GetSize(sig) < width) { + SigBit pad = (is_signed && GetSize(sig) > 0) ? sig[GetSize(sig) - 1] : State::S0; + sig.append(SigSpec(pad, width - GetSize(sig))); } - return driver; + if (GetSize(sig) > width) + sig = sig.extract(0, width); + return sig; } - void collect_operands( - RTLIL::Cell *cell, - bool negate, - std::vector &operands, - std::vector &tree_cells - ) { - tree_cells.push_back(cell); - consumed.insert(cell); + std::pair emit_fa(SigSpec a, SigSpec b, SigSpec c, int width) + { + SigSpec sum = module->addWire(NEW_ID, width); + SigSpec cout = module->addWire(NEW_ID, width); - bool a_signed = cell->getParam(ID::A_SIGNED).as_bool(); - bool b_signed = cell->getParam(ID::B_SIGNED).as_bool(); - bool is_sub = (cell->type == ID($sub)); - - RTLIL::SigSpec sig_a = cell->getPort(ID::A); - RTLIL::SigSpec sig_b = cell->getPort(ID::B); - - RTLIL::Cell *driver_a = get_driver(sig_a); - if (can_absorb(driver_a)) { - collect_operands(driver_a, negate, operands, tree_cells); - } else { - operands.push_back({sig_a, a_signed, negate}); - } - - bool b_negate = negate ^ is_sub; - RTLIL::Cell *driver_b = get_driver(sig_b); - if (can_absorb(driver_b)) { - collect_operands(driver_b, b_negate, operands, tree_cells); - } else { - operands.push_back({sig_b, b_signed, b_negate}); - } - } - - void create_fa( - RTLIL::SigSpec a, - RTLIL::SigSpec b, - RTLIL::SigSpec c, - int width, - RTLIL::SigSpec &sum_out, - RTLIL::SigSpec &carry_out - ) { - RTLIL::Wire *w_sum = module->addWire(NEW_ID, width); - RTLIL::Wire *w_carry = module->addWire(NEW_ID, width); - - RTLIL::Cell *fa = module->addCell(NEW_ID, ID($fa)); + Cell* fa = module->addCell(NEW_ID, ID($fa)); fa->setParam(ID::WIDTH, width); fa->setPort(ID::A, a); fa->setPort(ID::B, b); fa->setPort(ID::C, c); - fa->setPort(ID::Y, w_sum); - fa->setPort(ID::X, w_carry); + fa->setPort(ID::X, cout); + fa->setPort(ID::Y, sum); - sum_out = w_sum; - carry_out = w_carry; - stat_fa_cells++; + SigSpec carry_shifted; + carry_shifted.append(State::S0); + carry_shifted.append(cout.extract(0, width - 1)); + + return {sum, carry_shifted}; } - RTLIL::SigSpec extend_to(RTLIL::SigSpec sig, bool is_signed, int target_width) + std::pair build_wallace_tree(std::vector &operands, int width, int &fa_count) { - if (GetSize(sig) >= target_width) - return sig.extract(0, target_width); + std::vector ops; + for (auto &s : operands) + ops.push_back({s, 0}); - RTLIL::SigSpec result = sig; - RTLIL::SigBit pad = is_signed ? sig[GetSize(sig) - 1] : RTLIL::S0; - while (GetSize(result) < target_width) - result.append(pad); - return result; - } + fa_count = 0; + int level = 0; - RTLIL::SigSpec build_csa_tree(std::vector &operands, int output_width) - { - int width = output_width; - std::vector summands; - int sub_count = 0; - - for (auto &op : operands) + while (ops.size() > 2) { - RTLIL::SigSpec sig = extend_to(op.sig, op.is_signed, width); - - if (op.do_subtract) { - sig = module->Not(NEW_ID, sig); - sub_count++; + std::vector ready, waiting; + for (auto &op : ops) { + if (op.depth <= level) + ready.push_back(op); + else + waiting.push_back(op); } - summands.push_back(sig); - } + if (ready.size() < 3) { + level++; + log_assert(level <= 100); + continue; + } - if (sub_count > 0) { - RTLIL::Const correction(sub_count, width); - summands.push_back(RTLIL::SigSpec(correction)); - } - - if (summands.empty()) - return RTLIL::SigSpec(0, width); - - if (summands.size() == 1) - return summands[0]; - - if (summands.size() == 2) { - RTLIL::Wire *result = module->addWire(NEW_ID, width); - module->addAdd(NEW_ID, summands[0], summands[1], result); - return result; - } - - while (summands.size() > 2) - { - std::vector next; - int i = 0; - - while (i + 2 < (int)summands.size()) - { - RTLIL::SigSpec a = summands[i]; - RTLIL::SigSpec b = summands[i + 1]; - RTLIL::SigSpec c = summands[i + 2]; - - RTLIL::SigSpec sum, carry; - create_fa(a, b, c, width, sum, carry); - - RTLIL::SigSpec carry_shifted; - carry_shifted.append(RTLIL::S0); - carry_shifted.append(carry.extract(0, width - 1)); - - next.push_back(sum); - next.push_back(carry_shifted); + std::vector next; + size_t i = 0; + while (i + 2 < ready.size()) { + auto [sum, carry] = emit_fa(ready[i].sig, ready[i+1].sig, ready[i+2].sig, width); + int d = std::max({ready[i].depth, ready[i+1].depth, ready[i+2].depth}) + 1; + next.push_back({sum, d}); + next.push_back({carry, d}); + fa_count++; i += 3; } - while (i < (int)summands.size()) - next.push_back(summands[i++]); + for (; i < ready.size(); i++) + next.push_back(ready[i]); - summands.swap(next); + for (auto &op : waiting) + next.push_back(op); + + ops = std::move(next); + level++; + log_assert(level <= 100); } - RTLIL::Wire *result = module->addWire(NEW_ID, width); - module->addAdd(NEW_ID, summands[0], summands[1], result); - return result; + log_assert(ops.size() == 2); + + int max_depth = std::max(ops[0].depth, ops[1].depth); + log(" Tree depth: %d FA levels + 1 final add\n", max_depth); + + return {ops[0].sig, ops[1].sig}; + } + + void emit_final_add(SigSpec a, SigSpec b, SigSpec y, int width) + { + Cell* add = module->addCell(NEW_ID, ID($add)); + add->setParam(ID::A_SIGNED, false); + add->setParam(ID::B_SIGNED, false); + add->setParam(ID::A_WIDTH, width); + add->setParam(ID::B_WIDTH, width); + add->setParam(ID::Y_WIDTH, width); + add->setPort(ID::A, a); + add->setPort(ID::B, b); + add->setPort(ID::Y, y); } void run() { - build_maps(); + find_adds(); + if (all_adds.empty()) + return; - std::vector roots; - for (auto cell : module->selected_cells()) - if (cell->type.in(ID($add), ID($sub))) - roots.push_back(cell); + build_fanout_map(); - std::sort(roots.begin(), roots.end(), - [](RTLIL::Cell *a, RTLIL::Cell *b) { - return a->name < b->name; - }); + auto parent_of = find_add_parents(); - std::sort(roots.begin(), roots.end(), - [&](RTLIL::Cell *a, RTLIL::Cell *b) { - return (cell_fanout.count(a) ? cell_fanout.at(a) : 0) > - (cell_fanout.count(b) ? cell_fanout.at(b) : 0); - }); + pool has_parent; + dict> children_of; + for (auto &pair : parent_of) { + has_parent.insert(pair.first); + children_of[pair.second].insert(pair.first); + } - for (auto root : roots) + pool processed; + + for (auto root : all_adds) { - if (consumed.count(root)) + if (has_parent.count(root)) + continue; + if (processed.count(root)) continue; - std::vector operands; - std::vector tree_cells; - - collect_operands(root, false, operands, tree_cells); - - if ((int)operands.size() < min_operands) { - for (auto c : tree_cells) - consumed.erase(c); + pool chain = collect_chain(root, children_of); + if (chain.size() < 2) continue; - } - int output_width = root->getParam(ID::Y_WIDTH).as_int(); + for (auto c : chain) + processed.insert(c); - log(" Found adder tree rooted at %s with %d operands (depth %d cells)\n", - log_id(root), (int)operands.size(), (int)tree_cells.size()); + pool chain_y_bits = collect_chain_outputs(chain); + auto operands = collect_leaf_operands(chain, chain_y_bits); - RTLIL::SigSpec new_output = build_csa_tree(operands, output_width); - RTLIL::SigSpec old_output = root->getPort(ID::Y); - module->connect(old_output, new_output); + if (operands.size() < 3) + continue; - for (auto c : tree_cells) { - module->remove(c); - stat_removed_cells++; - } + SigSpec root_y = root->getPort(ID::Y); + int width = GetSize(root_y); - stat_trees++; + std::vector extended; + for (auto &op : operands) + extended.push_back(extend_to(op.sig, op.is_signed, width)); + + int fa_count; + auto [final_a, final_b] = build_wallace_tree(extended, width, fa_count); + + log(" Replaced chain of %d $add cells with %d $fa + 1 $add (%d operands, module %s)\n", + (int)chain.size(), fa_count, (int)operands.size(), log_id(module)); + + emit_final_add(final_a, final_b, root_y, width); + + for (auto cell : chain) + module->remove(cell); } } }; -struct CsaTreePass : public Pass -{ - CsaTreePass() : Pass("csa_tree", - "convert adder chains to carry-save adder trees") {} +struct CsaTreePass : public Pass { + CsaTreePass() : Pass("csa_tree", "convert $add chains to carry-save adder trees") {} void help() override { - // |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---| + // |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---| log("\n"); - log(" csa_tree [options] [selection]\n"); + log(" csa_tree [selection]\n"); log("\n"); - log("This pass converts chains of $add/$sub cells into carry-save adder trees using\n"); - log("$fa (full adder / 3:2 compressor) cells to reduce the critical path depth of\n"); - log("multi-operand addition.\n"); + log("This pass finds chains of $add cells and replaces them with\n"); + log("carry-save adder (CSA) trees built from $fa cells, followed by\n"); + log("a single final $add for the carry-propagate step.\n"); log("\n"); - log("For N operands of width W, the critical path is reduced from\n"); - log("O(N * log W) to O(log_1.5(N) + log W).\n"); - log("\n"); - log(" -min_operands N\n"); - log(" Minimum number of operands to trigger CSA tree construction.\n"); - log(" Default: 3. Values below 3 are clamped to 3.\n"); + log("The tree uses Wallace-tree scheduling for optimal depth:\n"); + log("at each level, all ready operands are grouped into triplets\n"); + log("and compressed via full adders. This gives ceil(log_1.5(N))\n"); + log("FA levels for N input operands.\n"); log("\n"); } void execute(std::vector args, RTLIL::Design *design) override { - int min_operands = 3; - - log_header(design, "Executing CSA_TREE pass (carry-save adder tree optimization).\n"); + log_header(design, "Executing CSA_TREE pass.\n"); size_t argidx; for (argidx = 1; argidx < args.size(); argidx++) - { - if (args[argidx] == "-min_operands" && argidx + 1 < args.size()) { - min_operands = std::max(3, atoi(args[++argidx].c_str())); - continue; - } break; - } extra_args(args, argidx, design); - for (auto module : design->selected_modules()) - { - log("Processing module %s...\n", log_id(module)); - - CsaTreeWorker worker(module, min_operands); + for (auto module : design->selected_modules()) { + CsaTreeWorker worker(module); worker.run(); - - if (worker.stat_trees > 0) - log(" Converted %d adder tree(s): created %d $fa cells, " - "removed %d $add/$sub cells.\n", - worker.stat_trees, worker.stat_fa_cells, - worker.stat_removed_cells); } } } CsaTreePass;