diff --git a/passes/opt/csa_tree.cc b/passes/opt/csa_tree.cc index 3be9fdc91..b58d65104 100644 --- a/passes/opt/csa_tree.cc +++ b/passes/opt/csa_tree.cc @@ -13,19 +13,32 @@ struct Operand { bool negate; }; -struct CsaTreeWorker +struct Traversal { - Module* module; SigMap sigmap; - dict> bit_consumers; dict fanout; + Traversal(Module* module) : sigmap(module) { + for (auto cell : module->cells()) + for (auto& conn : cell->connections()) + if (cell->input(conn.first)) + for (auto bit : sigmap(conn.second)) + bit_consumers[bit].insert(cell); - pool addsub_cells; - pool alu_cells; - pool macc_cells; + for (auto& pair : bit_consumers) + fanout[pair.first] = pair.second.size(); - CsaTreeWorker(Module* module) : module(module), sigmap(module) {} + for (auto wire : module->wires()) + if (wire->port_output) + for (auto bit : sigmap(SigSpec(wire))) + fanout[bit]++; + } +}; + +struct Cells { + pool addsub; + pool alu; + pool macc; static bool is_addsub(Cell* cell) { @@ -42,79 +55,74 @@ struct CsaTreeWorker return cell->type == ID($macc) || cell->type == ID($macc_v2); } - bool alu_is_subtract(Cell* cell) + bool empty() { + return addsub.empty() && alu.empty() && macc.empty(); + } + + Cells(Module* module) { + for (auto cell : module->cells()) { + if (is_addsub(cell)) + addsub.insert(cell); + else if (is_alu(cell)) + alu.insert(cell); + else if (is_macc(cell)) + macc.insert(cell); + } + } +}; + +struct AluInfo { + Cells& cells; + Traversal& traversal; + bool is_subtract(Cell* cell) { - SigSpec bi = sigmap(cell->getPort(ID::BI)); - SigSpec ci = sigmap(cell->getPort(ID::CI)); + SigSpec bi = traversal.sigmap(cell->getPort(ID::BI)); + SigSpec ci = traversal.sigmap(cell->getPort(ID::CI)); return GetSize(bi) == 1 && bi[0] == State::S1 && GetSize(ci) == 1 && ci[0] == State::S1; } - bool alu_is_add(Cell* cell) + bool is_add(Cell* cell) { - SigSpec bi = sigmap(cell->getPort(ID::BI)); - SigSpec ci = sigmap(cell->getPort(ID::CI)); + SigSpec bi = traversal.sigmap(cell->getPort(ID::BI)); + SigSpec ci = traversal.sigmap(cell->getPort(ID::CI)); return GetSize(bi) == 1 && bi[0] == State::S0 && GetSize(ci) == 1 && ci[0] == State::S0; } - bool alu_is_chainable(Cell* cell) - { - if (!(alu_is_add(cell) || alu_is_subtract(cell))) - return false; - - for (auto bit : sigmap(cell->getPort(ID::X))) - if (fanout.count(bit) && fanout[bit] > 0) - return false; - for (auto bit : sigmap(cell->getPort(ID::CO))) - if (fanout.count(bit) && fanout[bit] > 0) - return false; - - return true; - } - bool is_chainable(Cell* cell) { - return is_addsub(cell) || (is_alu(cell) && alu_is_chainable(cell)); + if (!(is_add(cell) || is_subtract(cell))) + return false; + + for (auto bit : traversal.sigmap(cell->getPort(ID::X))) + if (traversal.fanout.count(bit) && traversal.fanout[bit] > 0) + return false; + for (auto bit : traversal.sigmap(cell->getPort(ID::CO))) + if (traversal.fanout.count(bit) && traversal.fanout[bit] > 0) + return false; + + return true; } +}; - void classify_cells() - { - for (auto cell : module->cells()) { - if (is_addsub(cell)) - addsub_cells.insert(cell); - else if (is_alu(cell)) - alu_cells.insert(cell); - else if (is_macc(cell)) - macc_cells.insert(cell); - } - } +struct Rewriter +{ + Module* module; + Cells& cells; + Traversal traversal; + AluInfo alu_info; - void build_fanout_map() - { - for (auto cell : module->cells()) - for (auto& conn : cell->connections()) - if (cell->input(conn.first)) - for (auto bit : sigmap(conn.second)) - bit_consumers[bit].insert(cell); - - for (auto& pair : bit_consumers) - fanout[pair.first] = pair.second.size(); - - for (auto wire : module->wires()) - if (wire->port_output) - for (auto bit : sigmap(SigSpec(wire))) - fanout[bit]++; - } + Rewriter(Module* module, Cells& cells) : module(module), cells(cells), traversal(module), alu_info{cells, traversal} {} Cell* sole_chainable_consumer(SigSpec sig, const pool& candidates) { Cell* consumer = nullptr; for (auto bit : sig) { - if (!fanout.count(bit) || fanout[bit] != 1) + if (!traversal.fanout.count(bit) || traversal.fanout[bit] != 1) return nullptr; - if (!bit_consumers.count(bit) || bit_consumers[bit].size() != 1) + if (!traversal.bit_consumers.count(bit) || traversal.bit_consumers[bit].size() != 1) return nullptr; - Cell* c = *bit_consumers[bit].begin(); + Cell* c = *traversal.bit_consumers[bit].begin(); if (!candidates.count(c)) return nullptr; @@ -131,7 +139,7 @@ struct CsaTreeWorker dict parent_of; for (auto cell : candidates) { Cell* consumer = sole_chainable_consumer( - sigmap(cell->getPort(ID::Y)), candidates); + traversal.sigmap(cell->getPort(ID::Y)), candidates); if (consumer && consumer != cell) parent_of[cell] = consumer; } @@ -144,7 +152,8 @@ struct CsaTreeWorker std::queue q; q.push(root); while (!q.empty()) { - Cell* cur = q.front(); q.pop(); + Cell* cur = q.front(); + q.pop(); if (!chain.insert(cur).second) continue; auto it = children_of.find(cur); @@ -159,7 +168,7 @@ struct CsaTreeWorker { pool bits; for (auto cell : chain) - for (auto bit : sigmap(cell->getPort(ID::Y))) + for (auto bit : traversal.sigmap(cell->getPort(ID::Y))) bits.insert(bit); return bits; } @@ -177,16 +186,16 @@ struct CsaTreeWorker bool parent_subtracts; if (parent->type == ID($sub)) parent_subtracts = true; - else if (is_alu(parent)) - parent_subtracts = alu_is_subtract(parent); + else if (cells.is_alu(parent)) + parent_subtracts = alu_info.is_subtract(parent); else return false; if (!parent_subtracts) return false; - SigSpec child_y = sigmap(child->getPort(ID::Y)); - SigSpec parent_b = sigmap(parent->getPort(ID::B)); + SigSpec child_y = traversal.sigmap(child->getPort(ID::Y)); + SigSpec parent_b = traversal.sigmap(parent->getPort(ID::B)); for (auto bit : child_y) for (auto pbit : parent_b) if (bit == pbit) @@ -229,11 +238,11 @@ struct CsaTreeWorker else cell_neg = false; - SigSpec a = sigmap(cell->getPort(ID::A)); - SigSpec b = sigmap(cell->getPort(ID::B)); + SigSpec a = traversal.sigmap(cell->getPort(ID::A)); + SigSpec b = traversal.sigmap(cell->getPort(ID::B)); bool a_signed = cell->getParam(ID::A_SIGNED).as_bool(); bool b_signed = cell->getParam(ID::B_SIGNED).as_bool(); - bool b_sub = (cell->type == ID($sub)) || (is_alu(cell) && alu_is_subtract(cell)); + bool b_sub = (cell->type == ID($sub)) || (cells.is_alu(cell) && alu_info.is_subtract(cell)); if (!overlaps(a, chain_bits)) { bool neg = cell_neg; @@ -255,6 +264,7 @@ struct CsaTreeWorker correction = 0; for (auto& term : macc.terms) { + // Bail on multiplication if (GetSize(term.in_b) != 0) return false; operands.push_back({term.in_a, term.is_signed, term.do_subtract}); @@ -279,30 +289,12 @@ struct CsaTreeWorker return sig; } - SigSpec emit_not(SigSpec sig, int width) - { - SigSpec out = module->addWire(NEW_ID, width); - Cell* inv = module->addCell(NEW_ID, ID($not)); - inv->setParam(ID::A_SIGNED, false); - inv->setParam(ID::A_WIDTH, width); - inv->setParam(ID::Y_WIDTH, width); - inv->setPort(ID::A, sig); - inv->setPort(ID::Y, out); - return out; - } - std::pair emit_fa(SigSpec a, SigSpec b, SigSpec c, int width) { SigSpec sum = module->addWire(NEW_ID, width); SigSpec cout = module->addWire(NEW_ID, width); - Cell* fa = module->addCell(NEW_ID, ID($fa)); - fa->setParam(ID::WIDTH, width); - fa->setPort(ID::A, a); - fa->setPort(ID::B, b); - fa->setPort(ID::C, c); - fa->setPort(ID::X, cout); - fa->setPort(ID::Y, sum); + module->addFa(NEW_ID, a, b, c, cout, sum); SigSpec carry; carry.append(State::S0); @@ -310,19 +302,6 @@ struct CsaTreeWorker return {sum, carry}; } - void emit_final_add(SigSpec a, SigSpec b, SigSpec y, int width) - { - Cell* add = module->addCell(NEW_ID, ID($add)); - add->setParam(ID::A_SIGNED, false); - add->setParam(ID::B_SIGNED, false); - add->setParam(ID::A_WIDTH, width); - add->setParam(ID::B_WIDTH, width); - add->setParam(ID::Y_WIDTH, width); - add->setPort(ID::A, a); - add->setPort(ID::B, b); - add->setPort(ID::Y, y); - } - struct DepthSig { SigSpec sig; int depth; @@ -387,7 +366,7 @@ struct CsaTreeWorker for (auto& op : operands) { SigSpec s = extend_operand(op.sig, op.is_signed, width); if (op.negate) - s = emit_not(s, width); + s = module->Not(NEW_ID, s); extended.push_back(s); } @@ -400,16 +379,17 @@ struct CsaTreeWorker log(" %s -> %d $fa + 1 $add (%d operands, module %s)\n", desc, fa_count, (int)operands.size(), log_id(module)); - emit_final_add(a, b, result_y, width); + // Emit final add + module->addAdd(NEW_ID, a, b, result_y, false); } void process_chains() { pool candidates; - for (auto cell : addsub_cells) + for (auto cell : cells.addsub) candidates.insert(cell); - for (auto cell : alu_cells) - if (alu_is_chainable(cell)) + for (auto cell : cells.alu) + if (alu_info.is_chainable(cell)) candidates.insert(cell); if (candidates.empty()) @@ -427,7 +407,7 @@ struct CsaTreeWorker pool processed; for (auto root : candidates) { if (has_parent.count(root) || processed.count(root)) - continue; + continue; // Not a tree root pool chain = collect_chain(root, children_of); if (chain.size() < 2) @@ -451,7 +431,7 @@ struct CsaTreeWorker void process_maccs() { - for (auto cell : macc_cells) { + for (auto cell : cells.macc) { std::vector operands; int correction; if (!extract_macc_operands(cell, operands, correction)) @@ -464,20 +444,19 @@ struct CsaTreeWorker module->remove(cell); } } - - void run() - { - classify_cells(); - - if (addsub_cells.empty() && alu_cells.empty() && macc_cells.empty()) - return; - - build_fanout_map(); - process_chains(); - process_maccs(); - } }; +void run(Module* module) { + Cells cells(module); + + if (cells.empty()) + return; + + Rewriter rewriter {module, cells}; + rewriter.process_chains(); + rewriter.process_maccs(); +} + struct CsaTreePass : public Pass { CsaTreePass() : Pass("csa_tree", "convert add/sub/macc chains to carry-save adder trees") {} @@ -508,8 +487,7 @@ struct CsaTreePass : public Pass { extra_args(args, argidx, design); for (auto module : design->selected_modules()) { - CsaTreeWorker worker(module); - worker.run(); + run(module); } } } CsaTreePass;