From a7fcfc18fa664abafe3cb2aa26c463510a846365 Mon Sep 17 00:00:00 2001 From: nella Date: Fri, 13 Mar 2026 12:54:58 +0100 Subject: [PATCH] Add sub chain support for csa trees. --- passes/opt/csa_tree.cc | 157 ++++++++++++++++++++++++++++++++--------- 1 file changed, 123 insertions(+), 34 deletions(-) diff --git a/passes/opt/csa_tree.cc b/passes/opt/csa_tree.cc index 92f744a9f..e1a4a9464 100644 --- a/passes/opt/csa_tree.cc +++ b/passes/opt/csa_tree.cc @@ -13,7 +13,7 @@ struct CsaTreeWorker dict> bit_consumers; dict fanout; - pool all_adds; + pool all_addsubs; CsaTreeWorker(Module *module) : module(module), sigmap(module) {} @@ -22,11 +22,11 @@ struct CsaTreeWorker int depth; }; - void find_adds() + void find_addsubs() { for (auto cell : module->cells()) - if (cell->type == ID($add)) - all_adds.insert(cell); + if (cell->type == ID($add) || cell->type == ID($sub)) + all_addsubs.insert(cell); } void build_fanout_map() @@ -46,9 +46,9 @@ struct CsaTreeWorker fanout[bit]++; } - Cell*single_add_consumer(SigSpec sig) + Cell* single_addsub_consumer(SigSpec sig) { - Cell*consumer = nullptr; + Cell* consumer = nullptr; for (auto bit : sig) { if (!fanout.count(bit) || fanout[bit] != 1) @@ -57,7 +57,7 @@ struct CsaTreeWorker return nullptr; Cell* c = *bit_consumers[bit].begin(); - if (!all_adds.count(c)) + if (!all_addsubs.count(c)) return nullptr; if (consumer == nullptr) @@ -69,13 +69,13 @@ struct CsaTreeWorker return consumer; } - dict find_add_parents() + dict find_addsub_parents() { dict parent_of; - for (auto cell : all_adds) { + for (auto cell : all_addsubs) { SigSpec y = sigmap(cell->getPort(ID::Y)); - Cell* consumer = single_add_consumer(y); + Cell* consumer = single_addsub_consumer(y); if (consumer != nullptr && consumer != cell) parent_of[cell] = consumer; } @@ -125,22 +125,78 @@ struct CsaTreeWorker struct Operand { SigSpec sig; bool is_signed; + bool negate; }; - std::vector collect_leaf_operands(const pool &chain, const pool &chain_y_bits) + bool is_subtracted_input(Cell* child, Cell* parent) { + if (parent->type != ID($sub)) + return false; + + SigSpec child_y = sigmap(child->getPort(ID::Y)); + SigSpec parent_b = sigmap(parent->getPort(ID::B)); + + for (auto bit : child_y) + for (auto pbit : parent_b) + if (bit == pbit) + return true; + + return false; + } + + std::vector collect_leaf_operands( + const pool &chain, + const pool &chain_y_bits, + Cell* root, + const dict &parent_of, + int &correction + ) { + dict negated; + negated[root] = false; + std::queue worklist; + worklist.push(root); + + while (!worklist.empty()) { + Cell* cur = worklist.front(); + worklist.pop(); + for (auto cell : chain) { + if (!parent_of.count(cell)) + continue; + if (parent_of.at(cell) != cur) + continue; + if (negated.count(cell)) + continue; + + bool sub_b = is_subtracted_input(cell, cur); + negated[cell] = negated[cur] ^ sub_b; + worklist.push(cell); + } + } + std::vector operands; + correction = 0; for (auto cell : chain) { + bool cell_neg = negated.count(cell) ? negated[cell] : false; SigSpec a = sigmap(cell->getPort(ID::A)); SigSpec b = sigmap(cell->getPort(ID::B)); bool a_signed = cell->getParam(ID::A_SIGNED).as_bool(); bool b_signed = cell->getParam(ID::B_SIGNED).as_bool(); + bool b_subtracted = (cell->type == ID($sub)); - if (!is_chain_internal(a, chain_y_bits)) - operands.push_back({a, a_signed}); - if (!is_chain_internal(b, chain_y_bits)) - operands.push_back({b, b_signed}); + if (!is_chain_internal(a, chain_y_bits)) { + bool neg_a = cell_neg; + operands.push_back({a, a_signed, neg_a}); + if (neg_a) + correction++; + } + + if (!is_chain_internal(b, chain_y_bits)) { + bool neg_b = cell_neg ^ b_subtracted; + operands.push_back({b, b_signed, neg_b}); + if (neg_b) + correction++; + } } return operands; @@ -157,6 +213,23 @@ struct CsaTreeWorker return sig; } + SigSpec emit_not(SigSpec sig, int width) + { + SigSpec out = module->addWire(NEW_ID, width); + Cell* inv = module->addCell(NEW_ID, ID($not)); + inv->setParam(ID::A_SIGNED, false); + inv->setParam(ID::A_WIDTH, width); + inv->setParam(ID::Y_WIDTH, width); + inv->setPort(ID::A, sig); + inv->setPort(ID::Y, out); + return out; + } + + SigSpec make_constant(int value, int width) + { + return SigSpec(value, width); + } + std::pair emit_fa(SigSpec a, SigSpec b, SigSpec c, int width) { SigSpec sum = module->addWire(NEW_ID, width); @@ -247,13 +320,13 @@ struct CsaTreeWorker void run() { - find_adds(); - if (all_adds.empty()) + find_addsubs(); + if (all_addsubs.empty()) return; build_fanout_map(); - auto parent_of = find_add_parents(); + auto parent_of = find_addsub_parents(); pool has_parent; dict> children_of; @@ -264,7 +337,7 @@ struct CsaTreeWorker pool processed; - for (auto root : all_adds) + for (auto root : all_addsubs) { if (has_parent.count(root)) continue; @@ -279,23 +352,40 @@ struct CsaTreeWorker processed.insert(c); pool chain_y_bits = collect_chain_outputs(chain); - auto operands = collect_leaf_operands(chain, chain_y_bits); + int correction = 0; + auto operands = collect_leaf_operands(chain, chain_y_bits, root, parent_of, correction); if (operands.size() < 3) continue; SigSpec root_y = root->getPort(ID::Y); int width = GetSize(root_y); - std::vector extended; - for (auto &op : operands) - extended.push_back(extend_to(op.sig, op.is_signed, width)); + + for (auto &op : operands) { + SigSpec s = extend_to(op.sig, op.is_signed, width); + if (op.negate) + s = emit_not(s, width); + extended.push_back(s); + } + + if (correction > 0) + extended.push_back(make_constant(correction, width)); int fa_count; auto [final_a, final_b] = build_wallace_tree(extended, width, fa_count); + int num_subs = 0; - log(" Replaced chain of %d $add cells with %d $fa + 1 $add (%d operands, module %s)\n", - (int)chain.size(), fa_count, (int)operands.size(), log_id(module)); + for (auto cell : chain) + if (cell->type == ID($sub)) + num_subs++; + + if (num_subs > 0) + log(" Replaced chain of %d $add/%d $sub cells with %d $fa + 1 $add (%d operands, module %s)\n", + (int)chain.size() - num_subs, num_subs, fa_count, (int)operands.size(), log_id(module)); + else + log(" Replaced chain of %d $add cells with %d $fa + 1 $add (%d operands, module %s)\n", + (int)chain.size(), fa_count, (int)operands.size(), log_id(module)); emit_final_add(final_a, final_b, root_y, width); @@ -306,22 +396,21 @@ struct CsaTreeWorker }; struct CsaTreePass : public Pass { - CsaTreePass() : Pass("csa_tree", "convert $add chains to carry-save adder trees") {} + CsaTreePass() : Pass("csa_tree", "convert $add/$sub chains to carry-save adder trees") {} void help() override { - // |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---| + // |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---| log("\n"); log(" csa_tree [selection]\n"); log("\n"); - log("This pass finds chains of $add cells and replaces them with\n"); - log("carry-save adder (CSA) trees built from $fa cells, followed by\n"); - log("a single final $add for the carry-propagate step.\n"); + log("This pass finds chains of $add and $sub cells and replaces them with carry-save\n"); + log("adder trees built from $fa cells, followed by a single final $add for the\n"); + log("carry-propagate step.\n"); log("\n"); - log("The tree uses Wallace-tree scheduling for optimal depth:\n"); - log("at each level, all ready operands are grouped into triplets\n"); - log("and compressed via full adders. This gives ceil(log_1.5(N))\n"); - log("FA levels for N input operands.\n"); + log("The tree uses Wallace-tree scheduling for optimal depth: at each level, all ready\n"); + log("operands are grouped into triplets and compressed via full adders. This\n"); + log("gives ceil(log_1.5(N)) FA levels for N input operands.\n"); log("\n"); }