3
0
Fork 0
mirror of https://github.com/YosysHQ/yosys synced 2026-05-30 05:46:32 +00:00

Improve arith_tree: FMA add, elarith WIP.

This commit is contained in:
nella 2026-05-18 13:39:04 +02:00
parent e87a9bd9a7
commit d6a01d9200
4 changed files with 599 additions and 255 deletions

View file

@ -1,5 +1,5 @@
/**
* Replaces chains of $add/$sub and $macc cells with carry-save adder trees
* Replaces chains of $add/$sub/$alu and $macc cells with carry-save compression trees
*
* Terminology:
* - parent: Cells that consume another cell's output
@ -7,9 +7,9 @@
* - chain: Connected path of chainable cells
*/
#include "kernel/compressor_tree.h"
#include "kernel/macc.h"
#include "kernel/sigtools.h"
#include "kernel/wallace_tree.h"
#include "kernel/yosys.h"
#include <queue>
@ -17,49 +17,58 @@
USING_YOSYS_NAMESPACE
PRIVATE_NAMESPACE_BEGIN
struct Operand {
SigSpec sig;
bool is_signed;
bool negate;
struct ArithTreeOptions {
CompressorTree::Strategy strategy = CompressorTree::Strategy::PREFER_42;
CompressorTree::FinalMode final_mode = CompressorTree::FinalMode::AUTO;
bool fma_fusion = true;
bool elarith_macro = false;
};
struct Traversal {
struct ArithTreeWorker {
const ArithTreeOptions &opt;
Module *module;
SigMap sigmap;
dict<SigBit, pool<Cell *>> bit_consumers;
dict<SigBit, int> fanout;
Traversal(Module *module) : sigmap(module)
{
for (auto cell : module->cells())
for (auto &conn : cell->connections())
if (cell->input(conn.first))
for (auto bit : sigmap(conn.second))
bit_consumers[bit].insert(cell);
for (auto &pair : bit_consumers)
fanout[pair.first] = pair.second.size();
pool<Cell *> addsub;
pool<Cell *> alu;
pool<Cell *> macc;
struct Operand {
SigSpec sig;
bool is_signed;
bool negate;
// With FMA, when both factors are set, the operand represents a product to
// be expanded into partial products at extraction time, is_signed then
// applies to factor_a, and factor_b carries its own signedness
SigSpec factor_b; // empty for regular operands
bool factor_b_signed = false;
};
ArithTreeWorker(const ArithTreeOptions &opt, Module *module) : opt(opt), module(module), sigmap(module)
{
// Build traversal data
for (auto cell : module->cells()) {
for (auto &[name, sig] : cell->connections()) {
if (cell->input(name)) {
for (auto bit : sigmap(sig)) {
bit_consumers[bit].insert(cell);
}
}
}
}
for (auto &[sig, consumers] : bit_consumers)
fanout[sig] = consumers.size();
for (auto wire : module->wires())
if (wire->port_output)
for (auto bit : sigmap(SigSpec(wire)))
fanout[bit]++;
}
};
struct Cells {
pool<Cell *> addsub;
pool<Cell *> alu;
pool<Cell *> macc;
static bool is_addsub(Cell *cell) { return cell->type == ID($add) || cell->type == ID($sub); }
static bool is_alu(Cell *cell) { return cell->type == ID($alu); }
static bool is_macc(Cell *cell) { return cell->type == ID($macc) || cell->type == ID($macc_v2); }
bool empty() { return addsub.empty() && alu.empty() && macc.empty(); }
Cells(Module *module)
{
// Collect cell data
for (auto cell : module->cells()) {
if (is_addsub(cell))
addsub.insert(cell);
@ -69,59 +78,55 @@ struct Cells {
macc.insert(cell);
}
}
};
struct AluInfo {
Cells &cells;
Traversal &traversal;
bool is_subtract(Cell *cell)
{
SigSpec bi = traversal.sigmap(cell->getPort(ID::BI));
SigSpec ci = traversal.sigmap(cell->getPort(ID::CI));
bool is_addsub(Cell *cell) {
return cell->type == ID($add) || cell->type == ID($sub);
}
bool is_alu(Cell *cell) {
return cell->type == ID($alu);
}
bool is_macc(Cell *cell) {
return cell->type == ID($macc) || cell->type == ID($macc_v2);
}
bool is_sub(Cell *cell) {
SigSpec bi = sigmap(cell->getPort(ID::BI));
SigSpec ci = sigmap(cell->getPort(ID::CI));
return GetSize(bi) == 1 && bi[0] == State::S1 && GetSize(ci) == 1 && ci[0] == State::S1;
}
bool is_add(Cell *cell)
{
SigSpec bi = traversal.sigmap(cell->getPort(ID::BI));
SigSpec ci = traversal.sigmap(cell->getPort(ID::CI));
SigSpec bi = sigmap(cell->getPort(ID::BI));
SigSpec ci = sigmap(cell->getPort(ID::CI));
return GetSize(bi) == 1 && bi[0] == State::S0 && GetSize(ci) == 1 && ci[0] == State::S0;
}
bool is_chainable(Cell *cell)
{
if (!(is_add(cell) || is_subtract(cell)))
if (!(is_add(cell) || is_sub(cell)))
return false;
for (auto bit : traversal.sigmap(cell->getPort(ID::X)))
if (traversal.fanout.count(bit) && traversal.fanout[bit] > 0)
for (auto bit : sigmap(cell->getPort(ID::X)))
if (fanout.count(bit) && fanout[bit] > 0)
return false;
for (auto bit : traversal.sigmap(cell->getPort(ID::CO)))
if (traversal.fanout.count(bit) && traversal.fanout[bit] > 0)
for (auto bit : sigmap(cell->getPort(ID::CO)))
if (fanout.count(bit) && fanout[bit] > 0)
return false;
return true;
}
};
struct Rewriter {
Module *module;
Cells &cells;
Traversal traversal;
AluInfo alu_info;
Rewriter(Module *module, Cells &cells) : module(module), cells(cells), traversal(module), alu_info{cells, traversal} {}
Cell *sole_chainable_consumer(SigSpec sig, const pool<Cell *> &candidates)
{
Cell *consumer = nullptr;
for (auto bit : sig) {
if (!traversal.fanout.count(bit) || traversal.fanout[bit] != 1)
if (!fanout.count(bit) || fanout[bit] != 1)
return nullptr;
if (!traversal.bit_consumers.count(bit) || traversal.bit_consumers[bit].size() != 1)
if (!bit_consumers.count(bit) || bit_consumers[bit].size() != 1)
return nullptr;
Cell *c = *traversal.bit_consumers[bit].begin();
Cell *c = *bit_consumers[bit].begin();
if (!candidates.count(c))
return nullptr;
@ -137,7 +142,7 @@ struct Rewriter {
{
dict<Cell *, Cell *> parent_of;
for (auto cell : candidates) {
Cell *consumer = sole_chainable_consumer(traversal.sigmap(cell->getPort(ID::Y)), candidates);
Cell *consumer = sole_chainable_consumer(sigmap(cell->getPort(ID::Y)), candidates);
if (consumer && consumer != cell)
parent_of[cell] = consumer;
}
@ -177,12 +182,12 @@ struct Rewriter {
{
pool<SigBit> bits;
for (auto cell : chain)
for (auto bit : traversal.sigmap(cell->getPort(ID::Y)))
for (auto bit : sigmap(cell->getPort(ID::Y)))
bits.insert(bit);
return bits;
}
static bool overlaps(SigSpec sig, const pool<SigBit> &bits)
bool overlaps(SigSpec sig, const pool<SigBit> &bits)
{
for (auto bit : sig)
if (bits.count(bit))
@ -195,17 +200,16 @@ struct Rewriter {
bool parent_subtracts;
if (parent->type == ID($sub))
parent_subtracts = true;
else if (cells.is_alu(parent))
parent_subtracts = alu_info.is_subtract(parent);
else if (is_alu(parent))
parent_subtracts = is_sub(parent);
else
return false;
if (!parent_subtracts)
return false;
// Check if any bit of child's Y connects to parent's B
SigSpec child_y = traversal.sigmap(child->getPort(ID::Y));
SigSpec parent_b = traversal.sigmap(parent->getPort(ID::B));
SigSpec child_y = sigmap(child->getPort(ID::Y));
SigSpec parent_b = sigmap(parent->getPort(ID::B));
for (auto bit : child_y)
for (auto pbit : parent_b)
if (bit == pbit)
@ -244,21 +248,20 @@ struct Rewriter {
for (auto cell : chain) {
bool cell_neg = negated.count(cell) ? negated[cell] : false;
SigSpec a = traversal.sigmap(cell->getPort(ID::A));
SigSpec b = traversal.sigmap(cell->getPort(ID::B));
SigSpec a = sigmap(cell->getPort(ID::A));
SigSpec b = sigmap(cell->getPort(ID::B));
bool a_signed = cell->getParam(ID::A_SIGNED).as_bool();
bool b_signed = cell->getParam(ID::B_SIGNED).as_bool();
bool b_sub = (cell->type == ID($sub)) || (cells.is_alu(cell) && alu_info.is_subtract(cell));
bool b_sub = (cell->type == ID($sub)) || (is_alu(cell) && is_sub(cell));
// Only add operands not produced by other chain cells
if (!overlaps(a, chain_bits)) {
operands.push_back({a, a_signed, cell_neg});
operands.push_back({a, a_signed, cell_neg, SigSpec(), false});
if (cell_neg)
neg_compensation++;
}
if (!overlaps(b, chain_bits)) {
bool neg = cell_neg ^ b_sub;
operands.push_back({b, b_signed, neg});
operands.push_back({b, b_signed, neg, SigSpec(), false});
if (neg)
neg_compensation++;
}
@ -272,63 +275,123 @@ struct Rewriter {
neg_compensation = 0;
for (auto &term : macc.terms) {
// Bail on multiplication
if (GetSize(term.in_b) != 0)
return false;
operands.push_back({term.in_a, term.is_signed, term.do_subtract});
if (GetSize(term.in_b) != 0) {
// TODO: Baugh-Wooley sign extension for mixed sign and sign*sign cases, don't bail out to non-FMA
if (!opt.fma_fusion)
return false;
if (term.is_signed || !CompressorTree::supports_signedness(term.is_signed, term.is_signed))
return false;
// Preserve term as a multiplicative operand which is expanded into partial products
Operand op;
op.sig = term.in_a;
op.is_signed = false;
op.negate = term.do_subtract;
op.factor_b = term.in_b;
op.factor_b_signed = false;
operands.push_back(op);
continue;
}
operands.push_back({term.in_a, term.is_signed, term.do_subtract, SigSpec(), false});
if (term.do_subtract)
neg_compensation++;
}
return true;
}
SigSpec extend_operand(SigSpec sig, bool is_signed, int width)
std::vector<CompressorTree::DepthSig> build_operand_pool(std::vector<Operand> &operands, int width, int &neg_compensation)
{
if (GetSize(sig) < width) {
SigBit pad;
if (is_signed && GetSize(sig) > 0)
pad = sig[GetSize(sig) - 1];
else
pad = State::S0;
sig.append(SigSpec(pad, width - GetSize(sig)));
}
if (GetSize(sig) > width)
sig = sig.extract(0, width);
return sig;
}
void replace_with_carry_save_tree(std::vector<Operand> &operands, SigSpec result_y, int neg_compensation, const char *desc)
{
int width = GetSize(result_y);
std::vector<SigSpec> extended;
extended.reserve(operands.size() + 1);
// Expand operands into a flat list of signals for reduction
std::vector<CompressorTree::DepthSig> pool;
pool.reserve(operands.size() * 2);
for (auto &op : operands) {
SigSpec s = extend_operand(op.sig, op.is_signed, width);
if (op.negate)
s = module->Not(NEW_ID, s);
extended.push_back(s);
if (GetSize(op.factor_b) == 0) {
// Additive operand
SigSpec s = CompressorTree::normalize_to_width(op.sig, op.is_signed, width);
if (op.negate)
s = module->Not(NEW_ID, s);
pool.push_back({s, 0});
} else {
// Multiplicative operand
// TODO: Negate product instead of factor
auto pps =
CompressorTree::generate_partial_products(module, op.sig, op.factor_b, op.is_signed, op.factor_b_signed, width);
if (op.negate) {
for (auto &pp : pps) {
SigSpec inv = module->addWire(NEW_ID, width);
module->addNot(NEW_ID, pp.sig, inv);
pp.sig = inv;
neg_compensation++;
}
}
for (auto &pp : pps)
pool.push_back(pp);
}
}
// Add correction for negated operands (-x = ~x + 1 so 1 per negation)
if (neg_compensation > 0)
extended.push_back(SigSpec(neg_compensation, width));
pool.push_back({SigSpec(neg_compensation, width), 0});
int compressor_count;
auto [a, b] = wallace_reduce_scheduled(module, extended, width, &compressor_count);
log(" %s -> %d $fa + 1 $add (%d operands, module %s)\n", desc, compressor_count, (int)operands.size(), module);
return pool;
}
// Emit final add
module->addAdd(NEW_ID, a, b, result_y, false);
void emit_tree(std::vector<Operand> &operands, SigSpec result_y, int neg_compensation, bool any_signed, const char *desc)
{
int width = GetSize(result_y);
if (opt.elarith_macro) {
// Bypass the compressor
emit_elarith_macro(operands, result_y, neg_compensation, any_signed, desc);
return;
}
auto pool = build_operand_pool(operands, width, neg_compensation);
auto [a, b] = CompressorTree::reduce_scheduled(module, std::move(pool), width, opt.strategy);
auto final_choice = CompressorTree::pick_final_adder(width, opt.final_mode);
CompressorTree::emit_final_adder(module, a, b, result_y, final_choice, any_signed);
}
void emit_elarith_macro(std::vector<Operand> &operands, SigSpec result_y, int neg_compensation, bool any_signed, const char *desc)
{
int width = GetSize(result_y);
auto pool = build_operand_pool(operands, width, neg_compensation);
log(" arith_tree::elarith: %s -> \\AddMopCsv macro, %d operands, width %d (module %s)\n", desc, (int)pool.size(), width, log_id(module));
// Pack all operands
SigSpec flat;
for (auto &dp : pool) {
SigSpec ext = CompressorTree::normalize_to_width(dp.sig, false, width);
flat.append(ext);
}
Cell *c = module->addCell(NEW_ID, IdString("\\AddMopCsv"));
c->setParam(IdString("\\WIDTH"), width);
c->setParam(IdString("\\NUM_OPERANDS"), (int)pool.size());
c->setParam(IdString("\\SIGNED"), any_signed ? 1 : 0);
c->setParam(IdString("\\SPEED"), Const("fast"));
c->setPort(IdString("\\Operands"), flat);
c->setPort(IdString("\\Sum"), result_y);
}
bool any_operand_signed(const std::vector<Operand> &operands)
{
for (auto &op : operands)
if (op.is_signed)
return true;
return false;
}
void process_chains()
{
pool<Cell *> candidates;
for (auto cell : cells.addsub)
for (auto cell : addsub)
candidates.insert(cell);
for (auto cell : cells.alu)
if (alu_info.is_chainable(cell))
for (auto cell : alu)
if (is_chainable(cell))
candidates.insert(cell);
if (candidates.empty())
@ -354,7 +417,7 @@ struct Rewriter {
for (auto c : chain)
to_remove.insert(c);
replace_with_carry_save_tree(operands, root->getPort(ID::Y), neg_compensation, "Replaced add/sub chain");
emit_tree(operands, root->getPort(ID::Y), neg_compensation, any_operand_signed(operands), "Replaced $add/$sub chain");
}
for (auto cell : to_remove)
@ -363,48 +426,76 @@ struct Rewriter {
void process_maccs()
{
for (auto cell : cells.macc) {
pool<Cell *> to_remove;
for (auto cell : macc) {
std::vector<Operand> operands;
int neg_compensation;
if (!extract_macc_operands(cell, operands, neg_compensation))
continue;
if (operands.size() < 3)
if (operands.size() < 1)
continue;
bool has_mul = false;
for (auto &op : operands)
if (GetSize(op.factor_b) > 0) {
has_mul = true;
break;
}
if (!has_mul && operands.size() < 3)
continue;
replace_with_carry_save_tree(operands, cell->getPort(ID::Y), neg_compensation, "Replaced $macc");
module->remove(cell);
emit_tree(operands, cell->getPort(ID::Y), neg_compensation, any_operand_signed(operands), has_mul ? "Replaced $macc (FMA)" : "Replaced $macc");
to_remove.insert(cell);
}
for (auto cell : to_remove)
module->remove(cell);
}
void run()
{
if (addsub.empty() && alu.empty() && macc.empty())
return;
process_chains();
process_maccs();
}
};
void run(Module *module)
{
Cells cells(module);
if (cells.empty())
return;
Rewriter rewriter{module, cells};
rewriter.process_chains();
rewriter.process_maccs();
}
struct ArithTreePass : public Pass {
ArithTreePass() : Pass("arith_tree", "convert add/sub/macc chains to carry-save adder trees") {}
ArithTreePass() : Pass("arith_tree", "convert add/sub/macc/alu chains to carry-save adder trees") {}
void help() override
{
// |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|
log("\n");
log(" arith_tree [selection]\n");
log(" arith_tree [options] [selection]\n");
log("\n");
log("This pass replaces chains of $add/$sub cells, $alu cells (with constant\n");
log("BI/CI), and $macc/$macc_v2 cells (without multiplications) with carry-save\n");
log("adder trees using $fa cells and a single final $add.\n");
log("BI/CI), and $macc/$macc_v2 cells with carry-save adder trees \n");
log("using $fa cells and a single final adder.\n");
log("\n");
log("The tree uses Wallace-tree scheduling: at each level, ready operands are\n");
log("grouped into triplets and compressed via full adders, giving\n");
log("O(log_{1.5} N) depth for N input operands.\n");
log(" -strategy <fa|42>\n");
log(" Compressor strategy. 'fa' uses only 3:2 full-adder groupings\n");
log(" '42' (the default) prefers 4:2 compressor groupings, with\n");
log(" fallback to 3:2 compressors for residuals\n");
log("\n");
log(" -final <auto|ripple|prefix|elarith>\n");
log(" Selects the architecture used for the final two-vector add.\n");
log(" 'auto' (default) emits a ripple-style $add for narrow widths\n");
log(" (< 16 bits) and a parallel prefix hinted $add for wider ones.\n");
log(" 'elarith' emits an \\AddCfast black-box from the ELArith\n");
log(" library; the surrounding flow must provide that module.\n");
log("\n");
log(" -no-fma\n");
log(" Disable fused multiply-add expansion in $macc cells\n");
log("\n");
log(" -elarith-macro\n");
log(" Replace each detected chain with a single \\AddMopCsv black-box\n");
log(" instance instead of expanding it into $fa cells. The downstream\n");
log(" flow must provide an \\AddMopCsv implementation\n");
log("\n");
log("The default behaviour delivers 4:2 compression, FMA fusion, and a\n");
log("width-adaptive final adder\n");
log("\n");
}
@ -412,15 +503,44 @@ struct ArithTreePass : public Pass {
{
log_header(design, "Executing ARITH_TREE pass.\n");
ArithTreeOptions opt;
size_t argidx;
for (argidx = 1; argidx < args.size(); argidx++)
for (argidx = 1; argidx < args.size(); argidx++) {
const std::string &arg = args[argidx];
if (arg == "-strategy" && argidx + 1 < args.size()) {
const std::string &v = args[++argidx];
if (v == "fa") { opt.strategy = CompressorTree::Strategy::FA_ONLY; }
else if (v == "42") { opt.strategy = CompressorTree::Strategy::PREFER_42; }
else { log_cmd_error("arith_tree: unknown -strategy '%s'\n", v.c_str()); }
continue;
}
if (arg == "-final" && argidx + 1 < args.size()) {
const std::string &v = args[++argidx];
if (v == "auto") { opt.final_mode = CompressorTree::FinalMode::AUTO; }
else if (v == "ripple") { opt.final_mode = CompressorTree::FinalMode::RIPPLE; }
else if (v == "prefix") { opt.final_mode = CompressorTree::FinalMode::PREFIX; }
else if (v == "elarith") { opt.final_mode = CompressorTree::FinalMode::ELARITH; }
else { log_cmd_error("arith_tree: unknown -final '%s'\n", v.c_str()); }
continue;
}
if (arg == "-no-fma") {
opt.fma_fusion = false;
continue;
}
if (arg == "-elarith-macro") {
opt.elarith_macro = true;
continue;
}
break;
}
extra_args(args, argidx, design);
for (auto module : design->selected_modules()) {
run(module);
for (auto mod : design->selected_modules()) {
ArithTreeWorker worker(opt, mod);
worker.run();
}
}
} ArithTreePass;
PRIVATE_NAMESPACE_END
PRIVATE_NAMESPACE_END

View file

@ -58,7 +58,7 @@ synth -top my_design -booth
#include "kernel/sigtools.h"
#include "kernel/yosys.h"
#include "kernel/macc.h"
#include "kernel/wallace_tree.h"
#include "kernel/compressor_tree.h"
USING_YOSYS_NAMESPACE
PRIVATE_NAMESPACE_BEGIN
@ -386,7 +386,11 @@ struct BoothPassWorker {
// Later on yosys will clean up unused constants
// DebugDumpAlignPP(aligned_pp);
auto [wtree_a, wtree_b] = wallace_reduce_scheduled(module, aligned_pp, z_sz);
std::vector<CompressorTree::DepthSig> operands;
operands.reserve(aligned_pp.size());
for (auto &s : aligned_pp)
operands.push_back({s, 0});
auto [wtree_a, wtree_b] = CompressorTree::reduce_scheduled(module, std::move(operands), z_sz, CompressorTree::Strategy::FA_ONLY);
// Debug code: Dump out the csa trees
// DumpCSATrees(debug_csa_trees);