diff --git a/passes/opt/opt_balance_tree.cc b/passes/opt/opt_balance_tree.cc index 90f574f87..fc48974df 100644 --- a/passes/opt/opt_balance_tree.cc +++ b/passes/opt/opt_balance_tree.cc @@ -147,6 +147,56 @@ struct OptBalanceTreeWorker { return chain; } + void wreduce(Cell *cell) { + // If cell is arithmetic, remove leading zeros from inputs, then clean up outputs + if (cell->type.in(ID($add), ID($mul))) { + // Remove leading zeros from inputs + for (auto inport : {ID::A, ID::B}) { + // Record number of bits removed + int bits_removed = 0; + IdString inport_signed = (inport == ID::A) ? ID::A_SIGNED : ID::B_SIGNED; + IdString inport_width = (inport == ID::A) ? ID::A_WIDTH : ID::B_WIDTH; + SigSpec inport_sig = sigmap(cell->getPort(inport)); + cell->unsetPort(inport); + if (cell->getParam((inport == ID::A) ? ID::A_SIGNED : ID::B_SIGNED).as_bool()) { + while (GetSize(inport_sig) > 1 && inport_sig[GetSize(inport_sig)-1] == State::S0 && inport_sig[GetSize(inport_sig)-2] == State::S0) { + inport_sig.remove(GetSize(inport_sig)-1, 1); + bits_removed++; + } + } else { + while (GetSize(inport_sig) > 0 && inport_sig[GetSize(inport_sig)-1] == State::S0) { + inport_sig.remove(GetSize(inport_sig)-1, 1); + bits_removed++; + } + } + cell->setPort(inport, inport_sig); + cell->setParam(inport_width, GetSize(inport_sig)); + log("Width reduced %s/%s by %d bits\n", log_id(cell), log_id(inport), bits_removed); + } + + // Record number of bits removed from output + int bits_removed = 0; + + // Remove unnecessary bits from output + SigSpec y_sig = sigmap(cell->getPort(ID::Y)); + cell->unsetPort(ID::Y); + int width; + if (cell->type == ID($add)) + width = std::max(cell->getParam(ID::A_WIDTH).as_int(), cell->getParam(ID::B_WIDTH).as_int()) + 1; + else if (cell->type == ID($mul)) + width = cell->getParam(ID::A_WIDTH).as_int() + cell->getParam(ID::B_WIDTH).as_int(); + else log_abort(); + for (int i = GetSize(y_sig) - 1; i >= width; i--) { + module->connect(y_sig[i], State::S0); + y_sig.remove(i, 1); + bits_removed++; + } + cell->setPort(ID::Y, y_sig); + cell->setParam(ID::Y_WIDTH, GetSize(y_sig)); + log("Width reduced %s/Y by %d bits\n", log_id(cell), bits_removed); + } + } + void process_chain(vector &chain) { // If chain size is less than 3, no balancing needed if (GetSize(chain) < 3) @@ -180,7 +230,7 @@ struct OptBalanceTreeWorker { end_cell->unsetPort(ID::Y); // Create new mid wire - Wire *mid_wire = module->addWire(NEW_ID, GetSize(mid_non_chain_sig)); + Wire *mid_wire = module->addWire(NEW_ID, GetSize(end_y_sig)); // Perform rotation mid_cell->setPort(mid_non_chain_port, mid_wire); @@ -188,11 +238,30 @@ struct OptBalanceTreeWorker { midnext_cell->setPort(midnext_chain_port, mid_non_chain_sig); end_cell->setPort(ID::Y, mid_wire); - // Recurse on subtrees + // Get subtrees vector left_chain(chain.begin(), chain.begin() + GetSize(chain) / 2); vector right_chain(chain.begin() + GetSize(chain) / 2 + 1, chain.end()); + + // Recurse on subtrees process_chain(left_chain); process_chain(right_chain); + + // Recreate sigmap + sigmap.set(module); + + // Width reduce left subtree + for (auto c : left_chain) + wreduce(c); + + // Width reduce right subtree + for (auto c : right_chain) + wreduce(c); + + // Recreate sigmap + sigmap.set(module); + + // Width reduce mid cell + wreduce(mid_cell); } void cleanup() {