From 749d6cd8f7d5a749afc86d0f4014501363245a1a Mon Sep 17 00:00:00 2001 From: nella Date: Mon, 18 May 2026 17:21:26 +0200 Subject: [PATCH] Collapse signed*signed or combined nodes via BW. --- kernel/compressor_tree.h | 68 ++++++++--- passes/techmap/arith_tree.cc | 35 +++--- tests/arith_tree/arith_tree_fma.ys | 7 +- tests/arith_tree/arith_tree_signed_fma.ys | 135 ++++++++++++++++++++++ 4 files changed, 210 insertions(+), 35 deletions(-) create mode 100644 tests/arith_tree/arith_tree_signed_fma.ys diff --git a/kernel/compressor_tree.h b/kernel/compressor_tree.h index 1b631eb36..7785aab4d 100644 --- a/kernel/compressor_tree.h +++ b/kernel/compressor_tree.h @@ -113,10 +113,6 @@ inline SigSpec normalize_to_width(SigSpec sig, bool is_signed, int width) return sig; } -inline bool supports_signedness(bool a_signed, bool b_signed) { - return !(a_signed || b_signed); -} - /** * generate_partial_products() - Generate partial products for FMA concat * @module:The Yosys module to which the compressors will be added @@ -126,15 +122,13 @@ inline bool supports_signedness(bool a_signed, bool b_signed) { * @b_signed: Whether signal B is signed * @width: Target width * - * Return: Radix-2 partial product matrix as a set of depth-0 vectors + * Return: Partial-product matrix as a set of depth-0 vectors */ inline std::vector generate_partial_products(Module *module, SigSpec a, SigSpec b, bool a_signed, bool b_signed, int width) { - // TODO: Baugh-Wooley sign extension for mixed sign and sign*sign cases, don't bail out to non-FMA - log_assert(supports_signedness(a_signed, b_signed) && "CompressorTree::generate_partial_products: signed inputs unsupported"); - int width_a = GetSize(a); + int width_b = GetSize(b); std::vector products; - products.reserve(width_a); + products.reserve(width_a + 3); for (int i = 0; i < width_a; i++) { SigBit ai = a[i]; @@ -144,14 +138,62 @@ inline std::vector generate_partial_products(Module *module, SigSpec a b_shifted.append(b); b_shifted = normalize_to_width(b_shifted, false, width); - // product = b_shifted & replicate(a[i], width) + // row = b_shifted & replicate(a[i], width) SigSpec ai_rep = SigSpec(ai, width); - SigSpec product = module->addWire(NEW_ID, width); - module->addAnd(NEW_ID, b_shifted, ai_rep, product); + SigSpec row = module->addWire(NEW_ID, width); + module->addAnd(NEW_ID, b_shifted, ai_rep, row); - products.push_back({product, 0}); + // Apply Modified Baugh-Wooley inversions for this row + bool row_is_bottom = (i == width_a - 1); + bool any_inversion = (row_is_bottom && b_signed) || a_signed; + + if (any_inversion) { + std::vector mask(width, RTLIL::State::S0); + + for (int j = 0; j < width_b; j++) { + int col = i + j; + if (col < 0 || col >= width) + continue; + bool col_is_right = (j == width_b - 1); + // Flip masks + bool invert = (row_is_bottom && b_signed) ^ (col_is_right && a_signed); + if (invert) + mask[col] = RTLIL::State::S1; + } + + // Skip the xor entirely if the mask is all zeroes + bool nonzero = false; + for (auto s : mask) + if (s == RTLIL::State::S1) { + nonzero = true; + break; + } + if (nonzero) { + SigSpec inverted = module->addWire(NEW_ID, width); + module->addXor(NEW_ID, row, SigSpec(RTLIL::Const(mask)), inverted); + row = inverted; + } + } + + products.push_back({row, 0}); } + // Correction constants + auto push_one_at = [&](int col) { + if (col < 0 || col >= width) + return; + std::vector v(width, RTLIL::State::S0); + v[col] = RTLIL::State::S1; + products.push_back({SigSpec(RTLIL::Const(v)), 0}); + }; + + if (b_signed) + push_one_at(width_a - 1); + if (a_signed) + push_one_at(width_b - 1); + if (a_signed || b_signed) + push_one_at(width_a + width_b - 1); + return products; } diff --git a/passes/techmap/arith_tree.cc b/passes/techmap/arith_tree.cc index eef0b1c2c..aa91dab51 100644 --- a/passes/techmap/arith_tree.cc +++ b/passes/techmap/arith_tree.cc @@ -276,19 +276,16 @@ struct ArithTreeWorker { for (auto &term : macc.terms) { if (GetSize(term.in_b) != 0) { - // TODO: Baugh-Wooley sign extension for mixed sign and sign*sign cases, don't bail out to non-FMA if (!opt.fma_fusion) return false; - if (term.is_signed || !CompressorTree::supports_signedness(term.is_signed, term.is_signed)) - return false; // Preserve term as a multiplicative operand which is expanded into partial products Operand op; op.sig = term.in_a; - op.is_signed = false; + op.is_signed = term.is_signed; op.negate = term.do_subtract; op.factor_b = term.in_b; - op.factor_b_signed = false; + op.factor_b_signed = term.is_signed; operands.push_back(op); continue; } @@ -313,22 +310,22 @@ struct ArithTreeWorker { s = module->Not(NEW_ID, s); pool.push_back({s, 0}); } else { - // Multiplicative operand - // TODO: Negate product instead of factor - auto pps = - CompressorTree::generate_partial_products(module, op.sig, op.factor_b, op.is_signed, op.factor_b_signed, width); + // Multiplicative operand. + auto pps = CompressorTree::generate_partial_products(module, op.sig, op.factor_b, op.is_signed, op.factor_b_signed, width); - if (op.negate) { - for (auto &pp : pps) { - SigSpec inv = module->addWire(NEW_ID, width); - module->addNot(NEW_ID, pp.sig, inv); - pp.sig = inv; - neg_compensation++; - } + if (!op.negate) { + for (auto &pp : pps) + pool.push_back(pp); + continue; } - for (auto &pp : pps) - pool.push_back(pp); + auto [a_red, b_red] = CompressorTree::reduce_scheduled(module, pps, width, opt.strategy); + SigSpec product = module->addWire(NEW_ID, width); + module->addAdd(NEW_ID, a_red, b_red, product, false); + SigSpec neg = module->addWire(NEW_ID, width); + module->addNot(NEW_ID, product, neg); + pool.push_back({neg, 0}); + neg_compensation++; } } @@ -380,7 +377,7 @@ struct ArithTreeWorker { bool any_operand_signed(const std::vector &operands) { for (auto &op : operands) - if (op.is_signed) + if (op.is_signed || op.factor_b_signed) return true; return false; } diff --git a/tests/arith_tree/arith_tree_fma.ys b/tests/arith_tree/arith_tree_fma.ys index 16d90c528..95d9b566f 100644 --- a/tests/arith_tree/arith_tree_fma.ys +++ b/tests/arith_tree/arith_tree_fma.ys @@ -100,7 +100,7 @@ select -assert-min 1 t:$macc t:$macc_v2 %u design -reset read_verilog <