3
0
Fork 0
mirror of https://github.com/YosysHQ/yosys synced 2026-05-31 06:07:47 +00:00

csa_tree: refactor

This commit is contained in:
Emil J. Tywoniak 2026-03-31 11:56:12 +02:00 committed by nella
parent c3bc2d88da
commit eb477b2d56

View file

@ -13,19 +13,32 @@ struct Operand {
bool negate; bool negate;
}; };
struct CsaTreeWorker struct Traversal
{ {
Module* module;
SigMap sigmap; SigMap sigmap;
dict<SigBit, pool<Cell*>> bit_consumers; dict<SigBit, pool<Cell*>> bit_consumers;
dict<SigBit, int> fanout; dict<SigBit, int> fanout;
Traversal(Module* module) : sigmap(module) {
for (auto cell : module->cells())
for (auto& conn : cell->connections())
if (cell->input(conn.first))
for (auto bit : sigmap(conn.second))
bit_consumers[bit].insert(cell);
pool<Cell*> addsub_cells; for (auto& pair : bit_consumers)
pool<Cell*> alu_cells; fanout[pair.first] = pair.second.size();
pool<Cell*> macc_cells;
CsaTreeWorker(Module* module) : module(module), sigmap(module) {} for (auto wire : module->wires())
if (wire->port_output)
for (auto bit : sigmap(SigSpec(wire)))
fanout[bit]++;
}
};
struct Cells {
pool<Cell*> addsub;
pool<Cell*> alu;
pool<Cell*> macc;
static bool is_addsub(Cell* cell) static bool is_addsub(Cell* cell)
{ {
@ -42,79 +55,74 @@ struct CsaTreeWorker
return cell->type == ID($macc) || cell->type == ID($macc_v2); return cell->type == ID($macc) || cell->type == ID($macc_v2);
} }
bool alu_is_subtract(Cell* cell) bool empty() {
return addsub.empty() && alu.empty() && macc.empty();
}
Cells(Module* module) {
for (auto cell : module->cells()) {
if (is_addsub(cell))
addsub.insert(cell);
else if (is_alu(cell))
alu.insert(cell);
else if (is_macc(cell))
macc.insert(cell);
}
}
};
struct AluInfo {
Cells& cells;
Traversal& traversal;
bool is_subtract(Cell* cell)
{ {
SigSpec bi = sigmap(cell->getPort(ID::BI)); SigSpec bi = traversal.sigmap(cell->getPort(ID::BI));
SigSpec ci = sigmap(cell->getPort(ID::CI)); SigSpec ci = traversal.sigmap(cell->getPort(ID::CI));
return GetSize(bi) == 1 && bi[0] == State::S1 && GetSize(ci) == 1 && ci[0] == State::S1; return GetSize(bi) == 1 && bi[0] == State::S1 && GetSize(ci) == 1 && ci[0] == State::S1;
} }
bool alu_is_add(Cell* cell) bool is_add(Cell* cell)
{ {
SigSpec bi = sigmap(cell->getPort(ID::BI)); SigSpec bi = traversal.sigmap(cell->getPort(ID::BI));
SigSpec ci = sigmap(cell->getPort(ID::CI)); SigSpec ci = traversal.sigmap(cell->getPort(ID::CI));
return GetSize(bi) == 1 && bi[0] == State::S0 && GetSize(ci) == 1 && ci[0] == State::S0; return GetSize(bi) == 1 && bi[0] == State::S0 && GetSize(ci) == 1 && ci[0] == State::S0;
} }
bool alu_is_chainable(Cell* cell)
{
if (!(alu_is_add(cell) || alu_is_subtract(cell)))
return false;
for (auto bit : sigmap(cell->getPort(ID::X)))
if (fanout.count(bit) && fanout[bit] > 0)
return false;
for (auto bit : sigmap(cell->getPort(ID::CO)))
if (fanout.count(bit) && fanout[bit] > 0)
return false;
return true;
}
bool is_chainable(Cell* cell) bool is_chainable(Cell* cell)
{ {
return is_addsub(cell) || (is_alu(cell) && alu_is_chainable(cell)); if (!(is_add(cell) || is_subtract(cell)))
return false;
for (auto bit : traversal.sigmap(cell->getPort(ID::X)))
if (traversal.fanout.count(bit) && traversal.fanout[bit] > 0)
return false;
for (auto bit : traversal.sigmap(cell->getPort(ID::CO)))
if (traversal.fanout.count(bit) && traversal.fanout[bit] > 0)
return false;
return true;
} }
};
void classify_cells() struct Rewriter
{ {
for (auto cell : module->cells()) { Module* module;
if (is_addsub(cell)) Cells& cells;
addsub_cells.insert(cell); Traversal traversal;
else if (is_alu(cell)) AluInfo alu_info;
alu_cells.insert(cell);
else if (is_macc(cell))
macc_cells.insert(cell);
}
}
void build_fanout_map() Rewriter(Module* module, Cells& cells) : module(module), cells(cells), traversal(module), alu_info{cells, traversal} {}
{
for (auto cell : module->cells())
for (auto& conn : cell->connections())
if (cell->input(conn.first))
for (auto bit : sigmap(conn.second))
bit_consumers[bit].insert(cell);
for (auto& pair : bit_consumers)
fanout[pair.first] = pair.second.size();
for (auto wire : module->wires())
if (wire->port_output)
for (auto bit : sigmap(SigSpec(wire)))
fanout[bit]++;
}
Cell* sole_chainable_consumer(SigSpec sig, const pool<Cell*>& candidates) Cell* sole_chainable_consumer(SigSpec sig, const pool<Cell*>& candidates)
{ {
Cell* consumer = nullptr; Cell* consumer = nullptr;
for (auto bit : sig) { for (auto bit : sig) {
if (!fanout.count(bit) || fanout[bit] != 1) if (!traversal.fanout.count(bit) || traversal.fanout[bit] != 1)
return nullptr; return nullptr;
if (!bit_consumers.count(bit) || bit_consumers[bit].size() != 1) if (!traversal.bit_consumers.count(bit) || traversal.bit_consumers[bit].size() != 1)
return nullptr; return nullptr;
Cell* c = *bit_consumers[bit].begin(); Cell* c = *traversal.bit_consumers[bit].begin();
if (!candidates.count(c)) if (!candidates.count(c))
return nullptr; return nullptr;
@ -131,7 +139,7 @@ struct CsaTreeWorker
dict<Cell*, Cell*> parent_of; dict<Cell*, Cell*> parent_of;
for (auto cell : candidates) { for (auto cell : candidates) {
Cell* consumer = sole_chainable_consumer( Cell* consumer = sole_chainable_consumer(
sigmap(cell->getPort(ID::Y)), candidates); traversal.sigmap(cell->getPort(ID::Y)), candidates);
if (consumer && consumer != cell) if (consumer && consumer != cell)
parent_of[cell] = consumer; parent_of[cell] = consumer;
} }
@ -144,7 +152,8 @@ struct CsaTreeWorker
std::queue<Cell*> q; std::queue<Cell*> q;
q.push(root); q.push(root);
while (!q.empty()) { while (!q.empty()) {
Cell* cur = q.front(); q.pop(); Cell* cur = q.front();
q.pop();
if (!chain.insert(cur).second) if (!chain.insert(cur).second)
continue; continue;
auto it = children_of.find(cur); auto it = children_of.find(cur);
@ -159,7 +168,7 @@ struct CsaTreeWorker
{ {
pool<SigBit> bits; pool<SigBit> bits;
for (auto cell : chain) for (auto cell : chain)
for (auto bit : sigmap(cell->getPort(ID::Y))) for (auto bit : traversal.sigmap(cell->getPort(ID::Y)))
bits.insert(bit); bits.insert(bit);
return bits; return bits;
} }
@ -177,16 +186,16 @@ struct CsaTreeWorker
bool parent_subtracts; bool parent_subtracts;
if (parent->type == ID($sub)) if (parent->type == ID($sub))
parent_subtracts = true; parent_subtracts = true;
else if (is_alu(parent)) else if (cells.is_alu(parent))
parent_subtracts = alu_is_subtract(parent); parent_subtracts = alu_info.is_subtract(parent);
else else
return false; return false;
if (!parent_subtracts) if (!parent_subtracts)
return false; return false;
SigSpec child_y = sigmap(child->getPort(ID::Y)); SigSpec child_y = traversal.sigmap(child->getPort(ID::Y));
SigSpec parent_b = sigmap(parent->getPort(ID::B)); SigSpec parent_b = traversal.sigmap(parent->getPort(ID::B));
for (auto bit : child_y) for (auto bit : child_y)
for (auto pbit : parent_b) for (auto pbit : parent_b)
if (bit == pbit) if (bit == pbit)
@ -229,11 +238,11 @@ struct CsaTreeWorker
else else
cell_neg = false; cell_neg = false;
SigSpec a = sigmap(cell->getPort(ID::A)); SigSpec a = traversal.sigmap(cell->getPort(ID::A));
SigSpec b = sigmap(cell->getPort(ID::B)); SigSpec b = traversal.sigmap(cell->getPort(ID::B));
bool a_signed = cell->getParam(ID::A_SIGNED).as_bool(); bool a_signed = cell->getParam(ID::A_SIGNED).as_bool();
bool b_signed = cell->getParam(ID::B_SIGNED).as_bool(); bool b_signed = cell->getParam(ID::B_SIGNED).as_bool();
bool b_sub = (cell->type == ID($sub)) || (is_alu(cell) && alu_is_subtract(cell)); bool b_sub = (cell->type == ID($sub)) || (cells.is_alu(cell) && alu_info.is_subtract(cell));
if (!overlaps(a, chain_bits)) { if (!overlaps(a, chain_bits)) {
bool neg = cell_neg; bool neg = cell_neg;
@ -255,6 +264,7 @@ struct CsaTreeWorker
correction = 0; correction = 0;
for (auto& term : macc.terms) { for (auto& term : macc.terms) {
// Bail on multiplication
if (GetSize(term.in_b) != 0) if (GetSize(term.in_b) != 0)
return false; return false;
operands.push_back({term.in_a, term.is_signed, term.do_subtract}); operands.push_back({term.in_a, term.is_signed, term.do_subtract});
@ -279,30 +289,12 @@ struct CsaTreeWorker
return sig; return sig;
} }
SigSpec emit_not(SigSpec sig, int width)
{
SigSpec out = module->addWire(NEW_ID, width);
Cell* inv = module->addCell(NEW_ID, ID($not));
inv->setParam(ID::A_SIGNED, false);
inv->setParam(ID::A_WIDTH, width);
inv->setParam(ID::Y_WIDTH, width);
inv->setPort(ID::A, sig);
inv->setPort(ID::Y, out);
return out;
}
std::pair<SigSpec, SigSpec> emit_fa(SigSpec a, SigSpec b, SigSpec c, int width) std::pair<SigSpec, SigSpec> emit_fa(SigSpec a, SigSpec b, SigSpec c, int width)
{ {
SigSpec sum = module->addWire(NEW_ID, width); SigSpec sum = module->addWire(NEW_ID, width);
SigSpec cout = module->addWire(NEW_ID, width); SigSpec cout = module->addWire(NEW_ID, width);
Cell* fa = module->addCell(NEW_ID, ID($fa)); module->addFa(NEW_ID, a, b, c, cout, sum);
fa->setParam(ID::WIDTH, width);
fa->setPort(ID::A, a);
fa->setPort(ID::B, b);
fa->setPort(ID::C, c);
fa->setPort(ID::X, cout);
fa->setPort(ID::Y, sum);
SigSpec carry; SigSpec carry;
carry.append(State::S0); carry.append(State::S0);
@ -310,19 +302,6 @@ struct CsaTreeWorker
return {sum, carry}; return {sum, carry};
} }
void emit_final_add(SigSpec a, SigSpec b, SigSpec y, int width)
{
Cell* add = module->addCell(NEW_ID, ID($add));
add->setParam(ID::A_SIGNED, false);
add->setParam(ID::B_SIGNED, false);
add->setParam(ID::A_WIDTH, width);
add->setParam(ID::B_WIDTH, width);
add->setParam(ID::Y_WIDTH, width);
add->setPort(ID::A, a);
add->setPort(ID::B, b);
add->setPort(ID::Y, y);
}
struct DepthSig { struct DepthSig {
SigSpec sig; SigSpec sig;
int depth; int depth;
@ -387,7 +366,7 @@ struct CsaTreeWorker
for (auto& op : operands) { for (auto& op : operands) {
SigSpec s = extend_operand(op.sig, op.is_signed, width); SigSpec s = extend_operand(op.sig, op.is_signed, width);
if (op.negate) if (op.negate)
s = emit_not(s, width); s = module->Not(NEW_ID, s);
extended.push_back(s); extended.push_back(s);
} }
@ -400,16 +379,17 @@ struct CsaTreeWorker
log(" %s -> %d $fa + 1 $add (%d operands, module %s)\n", log(" %s -> %d $fa + 1 $add (%d operands, module %s)\n",
desc, fa_count, (int)operands.size(), log_id(module)); desc, fa_count, (int)operands.size(), log_id(module));
emit_final_add(a, b, result_y, width); // Emit final add
module->addAdd(NEW_ID, a, b, result_y, false);
} }
void process_chains() void process_chains()
{ {
pool<Cell*> candidates; pool<Cell*> candidates;
for (auto cell : addsub_cells) for (auto cell : cells.addsub)
candidates.insert(cell); candidates.insert(cell);
for (auto cell : alu_cells) for (auto cell : cells.alu)
if (alu_is_chainable(cell)) if (alu_info.is_chainable(cell))
candidates.insert(cell); candidates.insert(cell);
if (candidates.empty()) if (candidates.empty())
@ -427,7 +407,7 @@ struct CsaTreeWorker
pool<Cell*> processed; pool<Cell*> processed;
for (auto root : candidates) { for (auto root : candidates) {
if (has_parent.count(root) || processed.count(root)) if (has_parent.count(root) || processed.count(root))
continue; continue; // Not a tree root
pool<Cell*> chain = collect_chain(root, children_of); pool<Cell*> chain = collect_chain(root, children_of);
if (chain.size() < 2) if (chain.size() < 2)
@ -451,7 +431,7 @@ struct CsaTreeWorker
void process_maccs() void process_maccs()
{ {
for (auto cell : macc_cells) { for (auto cell : cells.macc) {
std::vector<Operand> operands; std::vector<Operand> operands;
int correction; int correction;
if (!extract_macc_operands(cell, operands, correction)) if (!extract_macc_operands(cell, operands, correction))
@ -464,20 +444,19 @@ struct CsaTreeWorker
module->remove(cell); module->remove(cell);
} }
} }
void run()
{
classify_cells();
if (addsub_cells.empty() && alu_cells.empty() && macc_cells.empty())
return;
build_fanout_map();
process_chains();
process_maccs();
}
}; };
void run(Module* module) {
Cells cells(module);
if (cells.empty())
return;
Rewriter rewriter {module, cells};
rewriter.process_chains();
rewriter.process_maccs();
}
struct CsaTreePass : public Pass { struct CsaTreePass : public Pass {
CsaTreePass() : Pass("csa_tree", CsaTreePass() : Pass("csa_tree",
"convert add/sub/macc chains to carry-save adder trees") {} "convert add/sub/macc chains to carry-save adder trees") {}
@ -508,8 +487,7 @@ struct CsaTreePass : public Pass {
extra_args(args, argidx, design); extra_args(args, argidx, design);
for (auto module : design->selected_modules()) { for (auto module : design->selected_modules()) {
CsaTreeWorker worker(module); run(module);
worker.run();
} }
} }
} CsaTreePass; } CsaTreePass;