From 4b0b1d35b5602402b63fd88e6b5e0b848b85e118 Mon Sep 17 00:00:00 2001 From: nella Date: Thu, 19 Mar 2026 17:44:56 +0100 Subject: [PATCH] CSA add support for macc and alu cells. --- passes/opt/csa_tree.cc | 528 ++++++++++++++++++++++++----------------- 1 file changed, 306 insertions(+), 222 deletions(-) diff --git a/passes/opt/csa_tree.cc b/passes/opt/csa_tree.cc index e1a4a9464..de2e28901 100644 --- a/passes/opt/csa_tree.cc +++ b/passes/opt/csa_tree.cc @@ -1,43 +1,102 @@ #include "kernel/yosys.h" #include "kernel/sigtools.h" +#include "kernel/macc.h" #include USING_YOSYS_NAMESPACE PRIVATE_NAMESPACE_BEGIN +struct Operand { + SigSpec sig; + bool is_signed; + bool negate; +}; + struct CsaTreeWorker { - Module *module; + Module* module; SigMap sigmap; dict> bit_consumers; dict fanout; - pool all_addsubs; - CsaTreeWorker(Module *module) : module(module), sigmap(module) {} + pool addsub_cells; + pool alu_cells; + pool macc_cells; - struct DepthSig { - SigSpec sig; - int depth; - }; + CsaTreeWorker(Module* module) : module(module), sigmap(module) {} - void find_addsubs() + static bool is_addsub(Cell* cell) { - for (auto cell : module->cells()) - if (cell->type == ID($add) || cell->type == ID($sub)) - all_addsubs.insert(cell); + return cell->type == ID($add) || cell->type == ID($sub); + } + + static bool is_alu(Cell* cell) + { + return cell->type == ID($alu); + } + + static bool is_macc(Cell* cell) + { + return cell->type == ID($macc) || cell->type == ID($macc_v2); + } + + bool alu_is_subtract(Cell* cell) + { + SigSpec bi = sigmap(cell->getPort(ID::BI)); + SigSpec ci = sigmap(cell->getPort(ID::CI)); + return GetSize(bi) == 1 && bi[0] == State::S1 && GetSize(ci) == 1 && ci[0] == State::S1; + } + + bool alu_is_add(Cell* cell) + { + SigSpec bi = sigmap(cell->getPort(ID::BI)); + SigSpec ci = sigmap(cell->getPort(ID::CI)); + return GetSize(bi) == 1 && bi[0] == State::S0 && GetSize(ci) == 1 && ci[0] == State::S0; + } + + bool alu_is_chainable(Cell* cell) + { + if (!(alu_is_add(cell) || alu_is_subtract(cell))) + return false; + + for (auto bit : sigmap(cell->getPort(ID::X))) + if (fanout.count(bit) && fanout[bit] > 0) + return false; + for (auto bit : sigmap(cell->getPort(ID::CO))) + if (fanout.count(bit) && fanout[bit] > 0) + return false; + + return true; + } + + bool is_chainable(Cell* cell) + { + return is_addsub(cell) || (is_alu(cell) && alu_is_chainable(cell)); + } + + void classify_cells() + { + for (auto cell : module->cells()) { + if (is_addsub(cell)) + addsub_cells.insert(cell); + else if (is_alu(cell)) + alu_cells.insert(cell); + else if (is_macc(cell)) + macc_cells.insert(cell); + } } void build_fanout_map() { for (auto cell : module->cells()) - for (auto &conn : cell->connections()) + for (auto& conn : cell->connections()) if (cell->input(conn.first)) for (auto bit : sigmap(conn.second)) bit_consumers[bit].insert(cell); - for (auto &pair : bit_consumers) + for (auto& pair : bit_consumers) fanout[pair.first] = pair.second.size(); for (auto wire : module->wires()) @@ -46,10 +105,9 @@ struct CsaTreeWorker fanout[bit]++; } - Cell* single_addsub_consumer(SigSpec sig) + Cell* sole_chainable_consumer(SigSpec sig, const pool& candidates) { Cell* consumer = nullptr; - for (auto bit : sig) { if (!fanout.count(bit) || fanout[bit] != 1) return nullptr; @@ -57,7 +115,7 @@ struct CsaTreeWorker return nullptr; Cell* c = *bit_consumers[bit].begin(); - if (!all_addsubs.count(c)) + if (!candidates.count(c)) return nullptr; if (consumer == nullptr) @@ -65,55 +123,39 @@ struct CsaTreeWorker else if (consumer != c) return nullptr; } - return consumer; } - dict find_addsub_parents() + dict find_parents(const pool& candidates) { dict parent_of; - - for (auto cell : all_addsubs) { - SigSpec y = sigmap(cell->getPort(ID::Y)); - Cell* consumer = single_addsub_consumer(y); - if (consumer != nullptr && consumer != cell) + for (auto cell : candidates) { + Cell* consumer = sole_chainable_consumer( + sigmap(cell->getPort(ID::Y)), candidates); + if (consumer && consumer != cell) parent_of[cell] = consumer; } - return parent_of; } - pool collect_chain(Cell* root, const dict> &children_of) + pool collect_chain(Cell* root, const dict>& children_of) { pool chain; - std::queue worklist; - worklist.push(root); - - while (!worklist.empty()) { - Cell* cur = worklist.front(); - worklist.pop(); - - if (chain.count(cur)) + std::queue q; + q.push(root); + while (!q.empty()) { + Cell* cur = q.front(); q.pop(); + if (!chain.insert(cur).second) continue; - chain.insert(cur); - - if (children_of.count(cur)) - for (auto child : children_of.at(cur)) - worklist.push(child); + auto it = children_of.find(cur); + if (it != children_of.end()) + for (auto child : it->second) + q.push(child); } - return chain; } - bool is_chain_internal(SigSpec sig, const pool &chain_y_bits) - { - for (auto bit : sig) - if (chain_y_bits.count(bit)) - return true; - return false; - } - - pool collect_chain_outputs(const pool &chain) + pool internal_bits(const pool& chain) { pool bits; for (auto cell : chain) @@ -122,54 +164,58 @@ struct CsaTreeWorker return bits; } - struct Operand { - SigSpec sig; - bool is_signed; - bool negate; - }; - - bool is_subtracted_input(Cell* child, Cell* parent) + static bool overlaps(SigSpec sig, const pool& bits) { - if (parent->type != ID($sub)) + for (auto bit : sig) + if (bits.count(bit)) + return true; + return false; + } + + bool feeds_subtracted_port(Cell* child, Cell* parent) + { + bool parent_subtracts; + if (parent->type == ID($sub)) + parent_subtracts = true; + else if (is_alu(parent)) + parent_subtracts = alu_is_subtract(parent); + else return false; - SigSpec child_y = sigmap(child->getPort(ID::Y)); - SigSpec parent_b = sigmap(parent->getPort(ID::B)); + if (!parent_subtracts) + return false; + SigSpec child_y = sigmap(child->getPort(ID::Y)); + SigSpec parent_b = sigmap(parent->getPort(ID::B)); for (auto bit : child_y) for (auto pbit : parent_b) if (bit == pbit) return true; - return false; } - std::vector collect_leaf_operands( - const pool &chain, - const pool &chain_y_bits, + std::vector extract_chain_operands( + const pool& chain, Cell* root, - const dict &parent_of, - int &correction + const dict& parent_of, + int& correction ) { + pool chain_bits = internal_bits(chain); dict negated; negated[root] = false; - std::queue worklist; - worklist.push(root); - - while (!worklist.empty()) { - Cell* cur = worklist.front(); - worklist.pop(); - for (auto cell : chain) { - if (!parent_of.count(cell)) - continue; - if (parent_of.at(cell) != cur) - continue; - if (negated.count(cell)) - continue; - - bool sub_b = is_subtracted_input(cell, cur); - negated[cell] = negated[cur] ^ sub_b; - worklist.push(cell); + { + std::queue q; + q.push(root); + while (!q.empty()) { + Cell* cur = q.front(); q.pop(); + for (auto cell : chain) { + if (!parent_of.count(cell) || parent_of.at(cell) != cur) + continue; + if (negated.count(cell)) + continue; + negated[cell] = negated[cur] ^ feeds_subtracted_port(cell, cur); + q.push(cell); + } } } @@ -177,35 +223,55 @@ struct CsaTreeWorker correction = 0; for (auto cell : chain) { - bool cell_neg = negated.count(cell) ? negated[cell] : false; + bool cell_neg; + if (negated.count(cell)) + cell_neg = negated[cell]; + else + cell_neg = false; + SigSpec a = sigmap(cell->getPort(ID::A)); SigSpec b = sigmap(cell->getPort(ID::B)); bool a_signed = cell->getParam(ID::A_SIGNED).as_bool(); bool b_signed = cell->getParam(ID::B_SIGNED).as_bool(); - bool b_subtracted = (cell->type == ID($sub)); + bool b_sub = (cell->type == ID($sub)) || (is_alu(cell) && alu_is_subtract(cell)); - if (!is_chain_internal(a, chain_y_bits)) { - bool neg_a = cell_neg; - operands.push_back({a, a_signed, neg_a}); - if (neg_a) - correction++; + if (!overlaps(a, chain_bits)) { + bool neg = cell_neg; + operands.push_back({a, a_signed, neg}); + if (neg) correction++; } - - if (!is_chain_internal(b, chain_y_bits)) { - bool neg_b = cell_neg ^ b_subtracted; - operands.push_back({b, b_signed, neg_b}); - if (neg_b) - correction++; + if (!overlaps(b, chain_bits)) { + bool neg = cell_neg ^ b_sub; + operands.push_back({b, b_signed, neg}); + if (neg) correction++; } } - return operands; } - SigSpec extend_to(SigSpec sig, bool is_signed, int width) + bool extract_macc_operands(Cell* cell, std::vector& operands, int& correction) + { + Macc macc(cell); + correction = 0; + + for (auto& term : macc.terms) { + if (GetSize(term.in_b) != 0) + return false; + operands.push_back({term.in_a, term.is_signed, term.do_subtract}); + if (term.do_subtract) + correction++; + } + return true; + } + + SigSpec extend_operand(SigSpec sig, bool is_signed, int width) { if (GetSize(sig) < width) { - SigBit pad = (is_signed && GetSize(sig) > 0) ? sig[GetSize(sig) - 1] : State::S0; + SigBit pad; + if (is_signed && GetSize(sig) > 0) + pad = sig[GetSize(sig) - 1]; + else + pad = State::S0; sig.append(SigSpec(pad, width - GetSize(sig))); } if (GetSize(sig) > width) @@ -225,14 +291,9 @@ struct CsaTreeWorker return out; } - SigSpec make_constant(int value, int width) - { - return SigSpec(value, width); - } - std::pair emit_fa(SigSpec a, SigSpec b, SigSpec c, int width) { - SigSpec sum = module->addWire(NEW_ID, width); + SigSpec sum = module->addWire(NEW_ID, width); SigSpec cout = module->addWire(NEW_ID, width); Cell* fa = module->addCell(NEW_ID, ID($fa)); @@ -243,66 +304,10 @@ struct CsaTreeWorker fa->setPort(ID::X, cout); fa->setPort(ID::Y, sum); - SigSpec carry_shifted; - carry_shifted.append(State::S0); - carry_shifted.append(cout.extract(0, width - 1)); - - return {sum, carry_shifted}; - } - - std::pair build_wallace_tree(std::vector &operands, int width, int &fa_count) - { - std::vector ops; - for (auto &s : operands) - ops.push_back({s, 0}); - - fa_count = 0; - int level = 0; - - while (ops.size() > 2) - { - std::vector ready, waiting; - for (auto &op : ops) { - if (op.depth <= level) - ready.push_back(op); - else - waiting.push_back(op); - } - - if (ready.size() < 3) { - level++; - log_assert(level <= 100); - continue; - } - - std::vector next; - size_t i = 0; - while (i + 2 < ready.size()) { - auto [sum, carry] = emit_fa(ready[i].sig, ready[i+1].sig, ready[i+2].sig, width); - int d = std::max({ready[i].depth, ready[i+1].depth, ready[i+2].depth}) + 1; - next.push_back({sum, d}); - next.push_back({carry, d}); - fa_count++; - i += 3; - } - - for (; i < ready.size(); i++) - next.push_back(ready[i]); - - for (auto &op : waiting) - next.push_back(op); - - ops = std::move(next); - level++; - log_assert(level <= 100); - } - - log_assert(ops.size() == 2); - - int max_depth = std::max(ops[0].depth, ops[1].depth); - log(" Tree depth: %d FA levels + 1 final add\n", max_depth); - - return {ops[0].sig, ops[1].sig}; + SigSpec carry; + carry.append(State::S0); + carry.append(cout.extract(0, width - 1)); + return {sum, carry}; } void emit_final_add(SigSpec a, SigSpec b, SigSpec y, int width) @@ -318,30 +323,110 @@ struct CsaTreeWorker add->setPort(ID::Y, y); } - void run() + struct DepthSig { + SigSpec sig; + int depth; + }; + + std::pair reduce_wallace(std::vector& sigs, int width, int& fa_count) { - find_addsubs(); - if (all_addsubs.empty()) + std::vector ops; + ops.reserve(sigs.size()); + for (auto& s : sigs) + ops.push_back({s, 0}); + + fa_count = 0; + + for (int level = 0; ops.size() > 2; level++) { + log_assert(level <= 100); + + std::vector ready, waiting; + for (auto& op : ops) { + if (op.depth <= level) + ready.push_back(op); + else + waiting.push_back(op); + } + + if (ready.size() < 3) continue; + + std::vector next; + size_t i = 0; + while (i + 2 < ready.size()) { + auto [sum, carry] = emit_fa(ready[i].sig, ready[i + 1].sig, ready[i + 2].sig, width); + int d = std::max({ready[i].depth, ready[i + 1].depth,ready[i + 2].depth}) + 1; + next.push_back({sum, d}); + next.push_back({carry, d}); + fa_count++; + i += 3; + } + for (; i < ready.size(); i++) + next.push_back(ready[i]); + for (auto& op : waiting) + next.push_back(op); + + ops = std::move(next); + } + + log_assert(ops.size() == 2); + log(" Tree depth: %d FA levels + 1 final add\n", + std::max(ops[0].depth, ops[1].depth)); + return {ops[0].sig, ops[1].sig}; + } + + void replace_with_csa_tree( + std::vector& operands, + SigSpec result_y, + int correction, + const char* desc + ) { + int width = GetSize(result_y); + std::vector extended; + extended.reserve(operands.size() + 1); + + for (auto& op : operands) { + SigSpec s = extend_operand(op.sig, op.is_signed, width); + if (op.negate) + s = emit_not(s, width); + extended.push_back(s); + } + + if (correction > 0) + extended.push_back(SigSpec(correction, width)); + + int fa_count; + auto [a, b] = reduce_wallace(extended, width, fa_count); + + log(" %s → %d $fa + 1 $add (%d operands, module %s)\n", + desc, fa_count, (int)operands.size(), log_id(module)); + + emit_final_add(a, b, result_y, width); + } + + void process_chains() + { + pool candidates; + for (auto cell : addsub_cells) + candidates.insert(cell); + for (auto cell : alu_cells) + if (alu_is_chainable(cell)) + candidates.insert(cell); + + if (candidates.empty()) return; - build_fanout_map(); + auto parent_of = find_parents(candidates); - auto parent_of = find_addsub_parents(); - - pool has_parent; dict> children_of; - for (auto &pair : parent_of) { - has_parent.insert(pair.first); - children_of[pair.second].insert(pair.first); + pool has_parent; + for (auto& [child, parent] : parent_of) { + children_of[parent].insert(child); + has_parent.insert(child); } pool processed; - - for (auto root : all_addsubs) - { - if (has_parent.count(root)) - continue; - if (processed.count(root)) + for (auto root : candidates) { + if (has_parent.count(root) || processed.count(root)) continue; pool chain = collect_chain(root, children_of); @@ -351,52 +436,51 @@ struct CsaTreeWorker for (auto c : chain) processed.insert(c); - pool chain_y_bits = collect_chain_outputs(chain); - int correction = 0; - auto operands = collect_leaf_operands(chain, chain_y_bits, root, parent_of, correction); - + int correction; + auto operands = extract_chain_operands( + chain, root, parent_of, correction); if (operands.size() < 3) continue; - SigSpec root_y = root->getPort(ID::Y); - int width = GetSize(root_y); - std::vector extended; - - for (auto &op : operands) { - SigSpec s = extend_to(op.sig, op.is_signed, width); - if (op.negate) - s = emit_not(s, width); - extended.push_back(s); - } - - if (correction > 0) - extended.push_back(make_constant(correction, width)); - - int fa_count; - auto [final_a, final_b] = build_wallace_tree(extended, width, fa_count); - int num_subs = 0; - - for (auto cell : chain) - if (cell->type == ID($sub)) - num_subs++; - - if (num_subs > 0) - log(" Replaced chain of %d $add/%d $sub cells with %d $fa + 1 $add (%d operands, module %s)\n", - (int)chain.size() - num_subs, num_subs, fa_count, (int)operands.size(), log_id(module)); - else - log(" Replaced chain of %d $add cells with %d $fa + 1 $add (%d operands, module %s)\n", - (int)chain.size(), fa_count, (int)operands.size(), log_id(module)); - - emit_final_add(final_a, final_b, root_y, width); - + replace_with_csa_tree(operands, root->getPort(ID::Y), + correction, "Replaced add/sub chain"); for (auto cell : chain) module->remove(cell); } } + + void process_maccs() + { + for (auto cell : macc_cells) { + std::vector operands; + int correction; + if (!extract_macc_operands(cell, operands, correction)) + continue; + if (operands.size() < 3) + continue; + + replace_with_csa_tree(operands, cell->getPort(ID::Y), + correction, "Replaced $macc"); + module->remove(cell); + } + } + + void run() + { + classify_cells(); + + if (addsub_cells.empty() && alu_cells.empty() && macc_cells.empty()) + return; + + build_fanout_map(); + process_chains(); + process_maccs(); + } }; struct CsaTreePass : public Pass { - CsaTreePass() : Pass("csa_tree", "convert $add/$sub chains to carry-save adder trees") {} + CsaTreePass() : Pass("csa_tree", + "convert add/sub/macc chains to carry-save adder trees") {} void help() override { @@ -404,17 +488,17 @@ struct CsaTreePass : public Pass { log("\n"); log(" csa_tree [selection]\n"); log("\n"); - log("This pass finds chains of $add and $sub cells and replaces them with carry-save\n"); - log("adder trees built from $fa cells, followed by a single final $add for the\n"); - log("carry-propagate step.\n"); + log("This pass replaces chains of $add/$sub cells, $alu cells (with constant\n"); + log("BI/CI), and $macc/$macc_v2 cells (without multiplications) with carry-save\n"); + log("adder trees using $fa cells and a single final $add.\n"); log("\n"); - log("The tree uses Wallace-tree scheduling for optimal depth: at each level, all ready\n"); - log("operands are grouped into triplets and compressed via full adders. This\n"); - log("gives ceil(log_1.5(N)) FA levels for N input operands.\n"); + log("The tree uses Wallace-tree scheduling: at each level, ready operands are\n"); + log("grouped into triplets and compressed via full adders, giving\n"); + log("O(log_{1.5} N) depth for N input operands.\n"); log("\n"); } - void execute(std::vector args, RTLIL::Design *design) override + void execute(std::vector args, RTLIL::Design* design) override { log_header(design, "Executing CSA_TREE pass.\n");