3
0
Fork 0
mirror of https://github.com/YosysHQ/yosys synced 2026-03-23 04:49:15 +00:00

better balancing.

This commit is contained in:
nella 2026-03-13 11:06:48 +01:00
parent a180a0003f
commit 728403d1eb

View file

@@ -1,355 +1,342 @@
#include "kernel/yosys.h"
#include "kernel/sigtools.h"
#include <queue>
USING_YOSYS_NAMESPACE
PRIVATE_NAMESPACE_BEGIN
struct CsaTreeWorker
{
// NOTE(review): this member list carries two declarations of `module` and two
// constructors; it reads like two revisions of the struct merged together —
// confirm against the real file which set is current.
// Module the worker reads cells from and adds/removes cells in.
RTLIL::Module *module;
Module *module;
// Alias map constructed from `module` by both constructors below.
SigMap sigmap;
// Operand-count threshold compared against in run(); only the two-argument
// constructor initialises it.
int min_operands;
// Non-constant Y bit of an $add/$sub cell -> that cell (filled in build_maps).
dict<RTLIL::SigBit, RTLIL::Cell*> sig_to_driver;
// $add/$sub cell -> largest consumer count found over its Y bits.
dict<RTLIL::Cell*, int> cell_fanout;
// Cells already taken into a tree by collect_operands().
pool<RTLIL::Cell*> consumed;
// Bit -> cells recorded for that bit by build_fanout_map().
dict<SigBit, pool<Cell*>> bit_consumers;
// Bit -> size of its bit_consumers entry, incremented for output-port bits.
dict<SigBit, int> fanout;
// Every cell whose type is exactly $add.
pool<Cell*> all_adds;
// Counters reported by CsaTreePass::execute after run().
int stat_trees = 0;
int stat_fa_cells = 0;
int stat_removed_cells = 0;
// Leaves min_operands uninitialised.
CsaTreeWorker(Module *module) : module(module), sigmap(module) {}
CsaTreeWorker(RTLIL::Module *module, int min_operands) :
module(module), sigmap(module), min_operands(min_operands) {}
// Work item of build_wallace_tree(): a signal tagged with the full-adder level
// at which it becomes available.
struct DepthSig {
// The signal itself.
SigSpec sig;
// 0 for the original operands; full-adder outputs get the largest input
// depth plus one.
int depth;
};
// NOTE(review): the lines below read as two revisions of the map-building code
// merged together (build_maps() versus find_adds() + build_fanout_map()); the
// braces do not balance as shown — confirm against the real file before
// relying on the statement order.
void build_maps()
// find_adds(): records every cell of type $add in all_adds.
void find_adds()
{
dict<RTLIL::SigBit, int> sig_consumers;
for (auto cell : module->cells())
{
if (cell->type.in(ID($add), ID($sub)))
{
RTLIL::SigSpec y = sigmap(cell->getPort(ID::Y));
for (auto bit : y)
if (bit.wire != nullptr)
// Constant bits are skipped; only wire bits get a driver entry.
sig_to_driver[bit] = cell;
}
if (cell->type == ID($add))
all_adds.insert(cell);
}
// build_fanout_map(): fills bit_consumers and derives fanout from it.
void build_fanout_map()
{
for (auto cell : module->cells())
for (auto &conn : cell->connections())
{
if (cell->input(conn.first))
{
for (auto bit : sigmap(conn.second))
if (bit.wire != nullptr)
sig_consumers[bit]++;
}
}
}
bit_consumers[bit].insert(cell);
// bit_consumers holds a pool per bit, so a cell reading the same bit on
// several ports is counted once here.
for (auto &pair : bit_consumers)
fanout[pair.first] = pair.second.size();
for (auto wire : module->wires())
if (wire->port_output)
for (auto bit : sigmap(wire))
if (bit.wire != nullptr)
sig_consumers[bit]++;
// Bits of an output-port wire receive one extra use.
for (auto bit : sigmap(SigSpec(wire)))
fanout[bit]++;
}
for (auto cell : module->cells())
{
if (!cell->type.in(ID($add), ID($sub)))
continue;
int fo = 0;
// Cell fanout is the maximum consumer count over its non-constant Y bits.
for (auto bit : sigmap(cell->getPort(ID::Y)))
if (bit.wire != nullptr)
fo = std::max(fo, sig_consumers.count(bit) ? sig_consumers.at(bit) : 0);
cell_fanout[cell] = fo;
// Returns the one $add cell that is the sole reader of every bit in `sig`.
// Yields nullptr when `sig` is empty, when any bit has a fanout other than 1,
// when a bit's reader is not in all_adds, or when the bits have different
// readers.
Cell *single_add_consumer(SigSpec sig)
{
	Cell *sole_reader = nullptr;

	for (auto bit : sig) {
		auto fo = fanout.find(bit);
		if (fo == fanout.end() || fo->second != 1)
			return nullptr;

		auto readers = bit_consumers.find(bit);
		if (readers == bit_consumers.end() || readers->second.size() != 1)
			return nullptr;

		Cell *reader = *readers->second.begin();
		if (all_adds.count(reader) == 0)
			return nullptr;

		// Every bit must lead to the same cell as the bits before it.
		if (sole_reader != nullptr && sole_reader != reader)
			return nullptr;
		sole_reader = reader;
	}

	return sole_reader;
}
// Maps each $add cell to the $add cell that absorbs its result, i.e. the edges
// of the adder chains that run() later flattens into one multi-operand sum.
//
// A cell only gets a parent when flattening cannot change the computed value:
//  * single_add_consumer() finds exactly one other $add reading all Y bits;
//  * the child's Y is, bit for bit, the whole of exactly one parent input port.
//    This rejects shifted/concatenated uses and the `x + x` case, where the
//    consumer pool counts the parent only once although it adds Y twice;
//  * the parent's other input port contains no bit of the child's Y;
//  * the child's result is at least as wide as the parent's result. A narrower
//    child drops carries that a flattened sum at the root width would bring
//    back.
// Cells failing any check simply stay chain roots/leaves of their own.
dict<Cell*, Cell*> find_add_parents()
{
	dict<Cell*, Cell*> parent_of;
	for (auto cell : all_adds) {
		SigSpec y = sigmap(cell->getPort(ID::Y));
		Cell* consumer = single_add_consumer(y);
		if (consumer == nullptr || consumer == cell)
			continue;

		SigSpec port_a = sigmap(consumer->getPort(ID::A));
		SigSpec port_b = sigmap(consumer->getPort(ID::B));
		bool feeds_a = (port_a == y);
		bool feeds_b = (port_b == y);
		// Neither port is exactly Y (partial/shifted use), or both are (x + x).
		if (feeds_a == feeds_b)
			continue;

		// The port that is not Y must be completely independent of Y.
		pool<SigBit> y_bits;
		for (auto bit : y)
			y_bits.insert(bit);
		SigSpec other_port = feeds_a ? port_b : port_a;
		bool other_reads_y = false;
		for (auto bit : other_port) {
			if (y_bits.count(bit)) {
				other_reads_y = true;
				break;
			}
		}
		if (other_reads_y)
			continue;

		// A child narrower than its parent truncates before the parent adds.
		if (GetSize(y) < GetSize(consumer->getPort(ID::Y)))
			continue;

		parent_of[cell] = consumer;
	}
	return parent_of;
}
// Gathers `root` and everything reachable from it through `children_of`,
// visiting breadth-first so members are inserted level by level.
pool<Cell*> collect_chain(Cell* root, const dict<Cell*, pool<Cell*>> &children_of)
{
	pool<Cell*> members;
	std::queue<Cell*> pending;

	// The front element is popped by the loop step, after its children (if
	// any) have been appended behind it.
	for (pending.push(root); !pending.empty(); pending.pop()) {
		Cell *node = pending.front();

		// insert() reports whether the cell was new; skip repeats.
		if (!members.insert(node).second)
			continue;

		auto kids = children_of.find(node);
		if (kids == children_of.end())
			continue;
		for (auto kid : kids->second)
			pending.push(kid);
	}

	return members;
}
// Reports whether `sig` shares at least one bit with `chain_y_bits`.
// A single shared bit is enough; the remaining bits are not inspected.
bool is_chain_internal(SigSpec sig, const pool<SigBit> &chain_y_bits)
{
	bool touches_chain = false;
	for (auto bit : sig) {
		if (chain_y_bits.count(bit) != 0) {
			touches_chain = true;
			break;
		}
	}
	return touches_chain;
}
// Returns the set of sigmapped Y-port bits of all cells in `chain`.
pool<SigBit> collect_chain_outputs(const pool<Cell*> &chain)
{
	pool<SigBit> y_bits;
	for (auto member : chain) {
		SigSpec y = sigmap(member->getPort(ID::Y));
		for (auto bit : y)
			y_bits.insert(bit);
	}
	return y_bits;
}
// One addend gathered for an adder tree.
// NOTE(review): `sig` is declared twice below; this looks like two revisions
// of the same field merged together — confirm against the real file.
struct Operand {
RTLIL::SigSpec sig;
SigSpec sig;
// Copied from the owning cell's A_SIGNED or B_SIGNED parameter.
bool is_signed;
// Set by collect_operands() when the operand enters the sum negated;
// collect_leaf_operands() brace-initialises only the first two fields, so it
// is false there.
bool do_subtract;
};
// NOTE(review): two function bodies appear merged in this span — confirm
// against the real file.
// can_absorb(): true only for a non-null $add/$sub cell that is not in
// `consumed` and whose cell_fanout entry exists and equals 1.
bool can_absorb(RTLIL::Cell *cell)
// collect_leaf_operands(): for every cell in `chain`, returns the A and B
// inputs that share no bit with `chain_y_bits`, each tagged with its port's
// signedness parameter.
std::vector<Operand> collect_leaf_operands(const pool<Cell*> &chain, const pool<SigBit> &chain_y_bits)
{
if (cell == nullptr)
return false;
if (!cell->type.in(ID($add), ID($sub)))
return false;
if (consumed.count(cell))
return false;
if (cell_fanout.count(cell) ? cell_fanout.at(cell) != 1 : true)
return false;
return true;
std::vector<Operand> operands;
for (auto cell : chain) {
SigSpec a = sigmap(cell->getPort(ID::A));
SigSpec b = sigmap(cell->getPort(ID::B));
bool a_signed = cell->getParam(ID::A_SIGNED).as_bool();
bool b_signed = cell->getParam(ID::B_SIGNED).as_bool();
// is_chain_internal() is true as soon as one bit matches, so an input that
// only partly overlaps a chain output is dropped as a whole.
if (!is_chain_internal(a, chain_y_bits))
operands.push_back({a, a_signed});
if (!is_chain_internal(b, chain_y_bits))
operands.push_back({b, b_signed});
}
return operands;
}
// NOTE(review): two function bodies appear merged in this span — confirm
// against the real file.
// get_driver(): looks up the sig_to_driver entry of every non-constant bit and
// returns the cell only if all of them agree; nullptr for an empty signal, a
// bit without an entry, or bits with different drivers.
RTLIL::Cell *get_driver(RTLIL::SigSpec sig)
// extend_to(): returns `sig` resized to exactly `width` bits.
SigSpec extend_to(SigSpec sig, bool is_signed, int width)
{
sig = sigmap(sig);
if (sig.empty())
return nullptr;
RTLIL::Cell *driver = nullptr;
for (auto bit : sig)
{
if (bit.wire == nullptr)
continue;
auto it = sig_to_driver.find(bit);
if (it == sig_to_driver.end())
return nullptr;
if (driver == nullptr)
driver = it->second;
else if (driver != it->second)
return nullptr; // mixed
if (GetSize(sig) < width) {
// Pad with the top bit for a signed, non-empty signal; otherwise with 0.
SigBit pad = (is_signed && GetSize(sig) > 0) ? sig[GetSize(sig) - 1] : State::S0;
sig.append(SigSpec(pad, width - GetSize(sig)));
}
return driver;
// Wider inputs keep only their low `width` bits.
if (GetSize(sig) > width)
sig = sig.extract(0, width);
return sig;
}
// NOTE(review): collect_operands(), create_fa() and emit_fa() appear merged in
// this span — confirm against the real file.
// collect_operands(): marks `cell` as consumed, then handles port A and port B
// in turn: a port whose driver passes can_absorb() is descended into, any
// other port is appended to `operands`. `negate` is passed on unchanged for A
// and flipped for B when `cell` is a $sub.
void collect_operands(
RTLIL::Cell *cell,
bool negate,
std::vector<Operand> &operands,
std::vector<RTLIL::Cell*> &tree_cells
) {
tree_cells.push_back(cell);
consumed.insert(cell);
// emit_fa(): adds one $fa cell of the given WIDTH over a, b, c and returns
// {Y output, X output moved up by one bit position}.
std::pair<SigSpec, SigSpec> emit_fa(SigSpec a, SigSpec b, SigSpec c, int width)
{
SigSpec sum = module->addWire(NEW_ID, width);
SigSpec cout = module->addWire(NEW_ID, width);
bool a_signed = cell->getParam(ID::A_SIGNED).as_bool();
bool b_signed = cell->getParam(ID::B_SIGNED).as_bool();
bool is_sub = (cell->type == ID($sub));
RTLIL::SigSpec sig_a = cell->getPort(ID::A);
RTLIL::SigSpec sig_b = cell->getPort(ID::B);
RTLIL::Cell *driver_a = get_driver(sig_a);
if (can_absorb(driver_a)) {
collect_operands(driver_a, negate, operands, tree_cells);
} else {
operands.push_back({sig_a, a_signed, negate});
}
bool b_negate = negate ^ is_sub;
RTLIL::Cell *driver_b = get_driver(sig_b);
if (can_absorb(driver_b)) {
collect_operands(driver_b, b_negate, operands, tree_cells);
} else {
operands.push_back({sig_b, b_signed, b_negate});
}
}
// create_fa(): adds one $fa cell and hands its Y and X wires back through
// `sum_out` and `carry_out` (unshifted).
void create_fa(
RTLIL::SigSpec a,
RTLIL::SigSpec b,
RTLIL::SigSpec c,
int width,
RTLIL::SigSpec &sum_out,
RTLIL::SigSpec &carry_out
) {
RTLIL::Wire *w_sum = module->addWire(NEW_ID, width);
RTLIL::Wire *w_carry = module->addWire(NEW_ID, width);
RTLIL::Cell *fa = module->addCell(NEW_ID, ID($fa));
Cell* fa = module->addCell(NEW_ID, ID($fa));
fa->setParam(ID::WIDTH, width);
fa->setPort(ID::A, a);
fa->setPort(ID::B, b);
fa->setPort(ID::C, c);
fa->setPort(ID::Y, w_sum);
fa->setPort(ID::X, w_carry);
fa->setPort(ID::X, cout);
fa->setPort(ID::Y, sum);
sum_out = w_sum;
carry_out = w_carry;
stat_fa_cells++;
// The carry word is a constant 0 followed by the low width-1 bits of X; the
// top X bit is discarded.
// NOTE(review): width - 1 is negative when width is 0 — confirm callers never
// pass an empty width.
SigSpec carry_shifted;
carry_shifted.append(State::S0);
carry_shifted.append(cout.extract(0, width - 1));
return {sum, carry_shifted};
}
// NOTE(review): an extend_to() overload, build_csa_tree() and
// build_wallace_tree() appear merged in this span — confirm against the real
// file before relying on the statement order.
RTLIL::SigSpec extend_to(RTLIL::SigSpec sig, bool is_signed, int target_width)
// build_wallace_tree(): compresses `operands` with emit_fa() until two signals
// remain and returns them; `fa_count` receives the number of emit_fa() calls.
// The closing log_assert requires exactly two survivors, so fewer than two
// operands trips it.
std::pair<SigSpec, SigSpec> build_wallace_tree(std::vector<SigSpec> &operands, int width, int &fa_count)
{
if (GetSize(sig) >= target_width)
return sig.extract(0, target_width);
std::vector<DepthSig> ops;
for (auto &s : operands)
ops.push_back({s, 0});
RTLIL::SigSpec result = sig;
RTLIL::SigBit pad = is_signed ? sig[GetSize(sig) - 1] : RTLIL::S0;
while (GetSize(result) < target_width)
result.append(pad);
return result;
}
fa_count = 0;
int level = 0;
// build_csa_tree(): extends every operand to `output_width`, inverts the ones
// marked do_subtract, adds the number of inverted operands as a constant
// summand, reduces three summands at a time with create_fa() and finishes
// with one $add.
RTLIL::SigSpec build_csa_tree(std::vector<Operand> &operands, int output_width)
{
int width = output_width;
std::vector<RTLIL::SigSpec> summands;
int sub_count = 0;
for (auto &op : operands)
while (ops.size() > 2)
{
RTLIL::SigSpec sig = extend_to(op.sig, op.is_signed, width);
if (op.do_subtract) {
sig = module->Not(NEW_ID, sig);
sub_count++;
// Split the work list: `ready` holds signals available at the current level,
// `waiting` holds deeper ones.
std::vector<DepthSig> ready, waiting;
for (auto &op : ops) {
if (op.depth <= level)
ready.push_back(op);
else
waiting.push_back(op);
}
summands.push_back(sig);
}
// Not enough ready signals for a full adder: advance the level and retry.
if (ready.size() < 3) {
level++;
log_assert(level <= 100);
continue;
}
if (sub_count > 0) {
RTLIL::Const correction(sub_count, width);
summands.push_back(RTLIL::SigSpec(correction));
}
if (summands.empty())
return RTLIL::SigSpec(0, width);
if (summands.size() == 1)
return summands[0];
if (summands.size() == 2) {
RTLIL::Wire *result = module->addWire(NEW_ID, width);
module->addAdd(NEW_ID, summands[0], summands[1], result);
return result;
}
while (summands.size() > 2)
{
std::vector<RTLIL::SigSpec> next;
int i = 0;
while (i + 2 < (int)summands.size())
{
RTLIL::SigSpec a = summands[i];
RTLIL::SigSpec b = summands[i + 1];
RTLIL::SigSpec c = summands[i + 2];
RTLIL::SigSpec sum, carry;
create_fa(a, b, c, width, sum, carry);
RTLIL::SigSpec carry_shifted;
carry_shifted.append(RTLIL::S0);
carry_shifted.append(carry.extract(0, width - 1));
next.push_back(sum);
next.push_back(carry_shifted);
std::vector<DepthSig> next;
size_t i = 0;
// Consume ready signals three at a time; both outputs of a full adder sit
// one level below its deepest input.
while (i + 2 < ready.size()) {
auto [sum, carry] = emit_fa(ready[i].sig, ready[i+1].sig, ready[i+2].sig, width);
int d = std::max({ready[i].depth, ready[i+1].depth, ready[i+2].depth}) + 1;
next.push_back({sum, d});
next.push_back({carry, d});
fa_count++;
i += 3;
}
while (i < (int)summands.size())
next.push_back(summands[i++]);
// Leftover ready signals (at most two) and all waiting ones carry over
// unchanged.
for (; i < ready.size(); i++)
next.push_back(ready[i]);
summands.swap(next);
for (auto &op : waiting)
next.push_back(op);
ops = std::move(next);
level++;
log_assert(level <= 100);
}
RTLIL::Wire *result = module->addWire(NEW_ID, width);
module->addAdd(NEW_ID, summands[0], summands[1], result);
return result;
log_assert(ops.size() == 2);
int max_depth = std::max(ops[0].depth, ops[1].depth);
log(" Tree depth: %d FA levels + 1 final add\n", max_depth);
return {ops[0].sig, ops[1].sig};
}
// Adds the closing $add cell that sums `a` and `b` into `y`. Both inputs are
// declared unsigned, and all three ports are declared `width` bits wide.
void emit_final_add(SigSpec a, SigSpec b, SigSpec y, int width)
{
	Cell *adder = module->addCell(NEW_ID, ID($add));

	// Parameters first, in the order A_SIGNED, B_SIGNED, A_WIDTH, B_WIDTH,
	// Y_WIDTH.
	adder->setParam(ID::A_SIGNED, false);
	adder->setParam(ID::B_SIGNED, false);

	const auto set_width = [&](const auto &param) {
		adder->setParam(param, width);
	};
	set_width(ID::A_WIDTH);
	set_width(ID::B_WIDTH);
	set_width(ID::Y_WIDTH);

	// Then the connections.
	adder->setPort(ID::A, a);
	adder->setPort(ID::B, b);
	adder->setPort(ID::Y, y);
}
// Entry point of the worker: builds the lookup maps, finds adder chains and
// replaces each one with full adders plus one final $add.
// NOTE(review): two revisions of this function appear merged below (a
// roots/collect_operands flow and an all_adds/collect_chain flow) — confirm
// against the real file before relying on the statement order.
void run()
{
build_maps();
find_adds();
if (all_adds.empty())
return;
std::vector<RTLIL::Cell*> roots;
for (auto cell : module->selected_cells())
if (cell->type.in(ID($add), ID($sub)))
roots.push_back(cell);
build_fanout_map();
std::sort(roots.begin(), roots.end(),
[](RTLIL::Cell *a, RTLIL::Cell *b) {
return a->name < b->name;
});
auto parent_of = find_add_parents();
std::sort(roots.begin(), roots.end(),
[&](RTLIL::Cell *a, RTLIL::Cell *b) {
return (cell_fanout.count(a) ? cell_fanout.at(a) : 0) >
(cell_fanout.count(b) ? cell_fanout.at(b) : 0);
});
// Invert parent_of: has_parent marks non-root cells, children_of lists the
// cells feeding each parent.
pool<Cell*> has_parent;
dict<Cell*, pool<Cell*>> children_of;
for (auto &pair : parent_of) {
has_parent.insert(pair.first);
children_of[pair.second].insert(pair.first);
}
for (auto root : roots)
pool<Cell*> processed;
// NOTE(review): all_adds is filled from module->cells(), whereas `roots`
// above comes from module->selected_cells() — check whether the selection is
// still meant to be honoured.
for (auto root : all_adds)
{
if (consumed.count(root))
if (has_parent.count(root))
continue;
if (processed.count(root))
continue;
std::vector<Operand> operands;
std::vector<RTLIL::Cell*> tree_cells;
collect_operands(root, false, operands, tree_cells);
if ((int)operands.size() < min_operands) {
for (auto c : tree_cells)
consumed.erase(c);
pool<Cell*> chain = collect_chain(root, children_of);
// A lone $add has nothing to merge with.
if (chain.size() < 2)
continue;
}
int output_width = root->getParam(ID::Y_WIDTH).as_int();
for (auto c : chain)
processed.insert(c);
log(" Found adder tree rooted at %s with %d operands (depth %d cells)\n",
log_id(root), (int)operands.size(), (int)tree_cells.size());
pool<SigBit> chain_y_bits = collect_chain_outputs(chain);
auto operands = collect_leaf_operands(chain, chain_y_bits);
RTLIL::SigSpec new_output = build_csa_tree(operands, output_width);
RTLIL::SigSpec old_output = root->getPort(ID::Y);
module->connect(old_output, new_output);
// Fewer than three addends cannot feed a full adder; leave the chain as is.
if (operands.size() < 3)
continue;
for (auto c : tree_cells) {
module->remove(c);
stat_removed_cells++;
}
// Everything is computed at the width of the root's Y port, which is read
// before the root cell is removed further down.
SigSpec root_y = root->getPort(ID::Y);
int width = GetSize(root_y);
stat_trees++;
std::vector<SigSpec> extended;
for (auto &op : operands)
extended.push_back(extend_to(op.sig, op.is_signed, width));
int fa_count;
auto [final_a, final_b] = build_wallace_tree(extended, width, fa_count);
log(" Replaced chain of %d $add cells with %d $fa + 1 $add (%d operands, module %s)\n",
(int)chain.size(), fa_count, (int)operands.size(), log_id(module));
emit_final_add(final_a, final_b, root_y, width);
// The old $add cells go last; all_adds still holds their pointers, and the
// checks at the top of the loop compare those pointers without dereferencing
// them.
for (auto cell : chain)
module->remove(cell);
}
}
};
// Registers the `csa_tree` command and runs a CsaTreeWorker on every selected
// module.
// NOTE(review): two revisions of this pass appear merged below (with and
// without the -min_operands option) — confirm against the real file.
struct CsaTreePass : public Pass
{
CsaTreePass() : Pass("csa_tree",
"convert adder chains to carry-save adder trees") {}
struct CsaTreePass : public Pass {
CsaTreePass() : Pass("csa_tree", "convert $add chains to carry-save adder trees") {}
// Prints the command's help text.
void help() override
{
// |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|
// |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|
log("\n");
log(" csa_tree [options] [selection]\n");
log(" csa_tree [selection]\n");
log("\n");
log("This pass converts chains of $add/$sub cells into carry-save adder trees using\n");
log("$fa (full adder / 3:2 compressor) cells to reduce the critical path depth of\n");
log("multi-operand addition.\n");
log("This pass finds chains of $add cells and replaces them with\n");
log("carry-save adder (CSA) trees built from $fa cells, followed by\n");
log("a single final $add for the carry-propagate step.\n");
log("\n");
log("For N operands of width W, the critical path is reduced from\n");
log("O(N * log W) to O(log_1.5(N) + log W).\n");
log("\n");
log(" -min_operands N\n");
log(" Minimum number of operands to trigger CSA tree construction.\n");
log(" Default: 3. Values below 3 are clamped to 3.\n");
log("The tree uses Wallace-tree scheduling for optimal depth:\n");
log("at each level, all ready operands are grouped into triplets\n");
log("and compressed via full adders. This gives ceil(log_1.5(N))\n");
log("FA levels for N input operands.\n");
log("\n");
}
// Parses the command line, then creates and runs one worker per selected
// module.
void execute(std::vector<std::string> args, RTLIL::Design *design) override
{
int min_operands = 3;
log_header(design, "Executing CSA_TREE pass (carry-save adder tree optimization).\n");
log_header(design, "Executing CSA_TREE pass.\n");
size_t argidx;
for (argidx = 1; argidx < args.size(); argidx++)
{
// The option needs a value after it; results below 3 are raised to 3.
if (args[argidx] == "-min_operands" && argidx + 1 < args.size()) {
min_operands = std::max(3, atoi(args[++argidx].c_str()));
continue;
}
break;
}
extra_args(args, argidx, design);
for (auto module : design->selected_modules())
{
log("Processing module %s...\n", log_id(module));
CsaTreeWorker worker(module, min_operands);
for (auto module : design->selected_modules()) {
CsaTreeWorker worker(module);
worker.run();
// Summarise only when at least one tree was converted.
if (worker.stat_trees > 0)
log(" Converted %d adder tree(s): created %d $fa cells, "
"removed %d $add/$sub cells.\n",
worker.stat_trees, worker.stat_fa_cells,
worker.stat_removed_cells);
}
}
} CsaTreePass;