diff --git a/passes/opt/opt_clean.cc b/passes/opt/opt_clean.cc index c7874eeb3..e822d13b1 100644 --- a/passes/opt/opt_clean.cc +++ b/passes/opt/opt_clean.cc @@ -180,16 +180,15 @@ struct RmStats { } }; -void rmunused_module_cells(Module *module, bool verbose, RmStats &stats, keep_cache_t &keep_cache) +unsigned int hash_bit(const SigBit &bit) { + return static_cast(hash_ops::hash(bit).yield()); +} + +void rmunused_module_cells(Module *module, ParallelDispatchThreadPool::Subpool &subpool, bool verbose, RmStats &stats, keep_cache_t &keep_cache) { SigMap sigmap(module); - dict> mem2cells; - pool mem_unused; - pool queue, unused; - pool used_raw_bits; - dict> wire2driver; - dict> driver_driver_logs; - FfInitVals ffinit(&sigmap, module); + FfInitVals ffinit; + ffinit.set_parallel(&sigmap, subpool.thread_pool(), module); SigMap raw_sigmap; for (auto &it : module->connections_) { @@ -199,86 +198,209 @@ void rmunused_module_cells(Module *module, bool verbose, RmStats &stats, keep_ca } } - for (auto &it : module->memories) { - mem_unused.insert(it.first); - } + struct WireDrivers; + struct WireDriver { + using Accumulated = WireDrivers; + SigBit bit; + int driver_cell; + }; + struct WireDrivers { + WireDrivers() : driver_cell(0) {} + WireDrivers(WireDriver driver) : bit(driver.bit), driver_cell(driver.driver_cell) {} + WireDrivers(SigBit bit) : bit(bit), driver_cell(0) {} + WireDrivers(WireDrivers &&other) = default; - for (Cell *cell : module->cells()) { - if (cell->type.in(ID($memwr), ID($memwr_v2), ID($meminit), ID($meminit_v2))) { - IdString mem_id = cell->getParam(ID::MEMID).decode_string(); - mem2cells[mem_id].insert(cell); - } - } - - for (auto &it : module->cells_) { - Cell *cell = it.second; - for (auto &it2 : cell->connections()) { - if (ct_all.cell_known(cell->type) && !ct_all.cell_output(cell->type, it2.first)) - continue; - for (auto raw_bit : it2.second) { - if (raw_bit.wire == nullptr) - continue; - auto bit = sigmap(raw_bit); - if (bit.wire == nullptr && ct_all.cell_known(cell->type)) - driver_driver_logs[raw_sigmap(raw_bit)].push_back(stringf("Driver-driver conflict " - "for %s between cell %s.%s and constant %s in %s: Resolved using constant.", - log_signal(raw_bit), log_id(cell), log_id(it2.first), log_signal(bit), log_id(module))); - if (bit.wire != nullptr) - wire2driver[bit].insert(cell); - } - } - if (keep_cache.query(cell)) - queue.insert(cell); - else - unused.insert(cell); - } - - for (auto &it : module->wires_) { - Wire *wire = it.second; - if (wire->port_output || wire->get_bool_attribute(ID::keep)) { - for (auto bit : sigmap(wire)) - for (auto c : wire2driver[bit]) - queue.insert(c), unused.erase(c); - for (auto raw_bit : SigSpec(wire)) - used_raw_bits.insert(raw_sigmap(raw_bit)); - } - } - - while (!queue.empty()) - { - pool bits; - pool mems; - for (auto cell : queue) { - for (auto &it : cell->connections()) - if (!ct_all.cell_known(cell->type) || ct_all.cell_input(cell->type, it.first)) - for (auto bit : sigmap(it.second)) - bits.insert(bit); - - if (cell->type.in(ID($memrd), ID($memrd_v2))) { - IdString mem_id = cell->getParam(ID::MEMID).decode_string(); - if (mem_unused.count(mem_id)) { - mem_unused.erase(mem_id); - mems.insert(mem_id); + class const_iterator { + public: + const_iterator(const WireDrivers &drivers, bool end) + : driver_cell(drivers.driver_cell), in_extra_cells(end) { + if (drivers.extra_driver_cells) { + if (end) { + extra_it = drivers.extra_driver_cells->end(); + } else { + extra_it = drivers.extra_driver_cells->begin(); + } } } + int operator*() const { + if (in_extra_cells) + return **extra_it; + return driver_cell; + } + const_iterator& operator++() { + if (in_extra_cells) + ++*extra_it; + else + in_extra_cells = true; + return *this; + } + bool operator!=(const const_iterator &other) const { + return !(*this == other); + } + bool operator==(const const_iterator &other) const { + return in_extra_cells == other.in_extra_cells && + extra_it == other.extra_it; + } + private: + std::optional::iterator> extra_it; + int driver_cell; + bool in_extra_cells; + }; + + const_iterator begin() const { return const_iterator(*this, false); } + const_iterator end() const { return const_iterator(*this, true); } + + SigBit bit; + int driver_cell; + std::unique_ptr> extra_driver_cells; + }; + struct WireDriversKeyEquality { + bool operator()(const WireDrivers &a, const WireDrivers &b) const { + return a.bit == b.bit; } + }; + struct WireDriversCollisionHandler { + void operator()(WireDrivers &incumbent, WireDrivers &new_value) const { + log_assert(new_value.extra_driver_cells == nullptr); + if (!incumbent.extra_driver_cells) + incumbent.extra_driver_cells.reset(new pool()); + incumbent.extra_driver_cells->insert(new_value.driver_cell); + } + }; + using Wire2Drivers = ShardedHashSet; - queue.clear(); + Wire2Drivers::Builder wire2driver_builder(subpool); + ShardedVector> mem2cells_vector(subpool); + ShardedVector> driver_driver_logs(subpool); + ShardedVector keep_wires(subpool); + const RTLIL::Module *const_module = module; + int num_threads = subpool.num_threads(); + ConcurrentWorkQueue cell_queue(num_threads); + std::vector> unused(const_module->cells_size()); + subpool.run([&sigmap, &raw_sigmap, &keep_cache, const_module, &mem2cells_vector, &driver_driver_logs, &keep_wires, &cell_queue, &wire2driver_builder, &unused](const ParallelDispatchThreadPool::RunCtx &ctx) { + for (int i : ctx.item_range(const_module->cells_size())) { + Cell *cell = const_module->cell_at(i); + if (cell->type.in(ID($memwr), ID($memwr_v2), ID($meminit), ID($meminit_v2))) + mem2cells_vector.insert(ctx, {cell->getParam(ID::MEMID).decode_string(), i}); - for (auto bit : bits) - for (auto c : wire2driver[bit]) - if (unused.count(c)) - queue.insert(c), unused.erase(c); + for (auto &it2 : cell->connections()) { + if (ct_all.cell_known(cell->type) && !ct_all.cell_output(cell->type, it2.first)) + continue; + for (auto raw_bit : it2.second) { + if (raw_bit.wire == nullptr) + continue; + auto bit = sigmap(raw_bit); + if (bit.wire == nullptr && ct_all.cell_known(cell->type)) { + std::string msg = stringf("Driver-driver conflict " + "for %s between cell %s.%s and constant %s in %s: Resolved using constant.", + log_signal(raw_bit), cell->name.unescape(), it2.first.unescape(), log_signal(bit), const_module->name.unescape()); + driver_driver_logs.insert(ctx, {raw_sigmap(raw_bit), msg}); + } + if (bit.wire != nullptr) + wire2driver_builder.insert(ctx, {{bit, i}, hash_bit(bit)}); + } + } + bool keep = keep_cache.query(cell); + unused[i].store(!keep, std::memory_order_relaxed); + if (keep) + cell_queue.push(ctx, i); + } + for (int i : ctx.item_range(const_module->wires_size())) { + Wire *wire = const_module->wire_at(i); + if (wire->port_output || wire->get_bool_attribute(ID::keep)) + keep_wires.insert(ctx, wire); + } + }); + subpool.run([&wire2driver_builder](const ParallelDispatchThreadPool::RunCtx &ctx) { + wire2driver_builder.process(ctx); + }); + Wire2Drivers wire2driver(wire2driver_builder); - for (auto mem : mems) - for (auto c : mem2cells[mem]) - if (unused.count(c)) - queue.insert(c), unused.erase(c); + dict> mem2cells; + for (std::pair &mem2cell : mem2cells_vector) + mem2cells[mem2cell.first].insert(mem2cell.second); + + pool used_raw_bits; + int i = 0; + for (Wire *wire : keep_wires) { + for (auto bit : sigmap(wire)) { + const WireDrivers *drivers = wire2driver.find({{bit}, hash_bit(bit)}); + if (drivers != nullptr) + for (int cell_index : *drivers) + if (unused[cell_index].exchange(false, std::memory_order_relaxed)) { + ThreadIndex fake_thread_index = {i++ % num_threads}; + cell_queue.push(fake_thread_index, cell_index); + } + } + for (auto raw_bit : SigSpec(wire)) + used_raw_bits.insert(raw_sigmap(raw_bit)); } - unused.sort(RTLIL::sort_by_name_id()); + std::vector> mem_unused(module->memories.size()); + dict mem_indices; + for (int i = 0; i < GetSize(module->memories); ++i) { + mem_indices[module->memories.element(i)->first.str()] = i; + mem_unused[i].store(true, std::memory_order_relaxed); + } - for (auto cell : unused) { + subpool.run([const_module, &sigmap, &wire2driver, &mem2cells, &unused, &cell_queue, &mem_indices, &mem_unused](const ParallelDispatchThreadPool::RunCtx &ctx) { + pool bits; + pool mems; + while (true) { + std::vector cell_indices = cell_queue.pop_batch(ctx); + if (cell_indices.empty()) + return; + for (auto cell_index : cell_indices) { + Cell *cell = const_module->cell_at(cell_index); + for (auto &it : cell->connections()) + if (!ct_all.cell_known(cell->type) || ct_all.cell_input(cell->type, it.first)) + for (auto bit : sigmap(it.second)) + bits.insert(bit); + + if (cell->type.in(ID($memrd), ID($memrd_v2))) { + std::string mem_id = cell->getParam(ID::MEMID).decode_string(); + if (mem_indices.count(mem_id)) { + int mem_index = mem_indices[mem_id]; + if (mem_unused[mem_index].exchange(false, std::memory_order_relaxed)) + mems.insert(mem_id); + } + } + } + + for (auto bit : bits) { + const WireDrivers *drivers = wire2driver.find({{bit}, hash_bit(bit)}); + if (drivers != nullptr) + for (int cell_index : *drivers) + if (unused[cell_index].exchange(false, std::memory_order_relaxed)) + cell_queue.push(ctx, cell_index); + } + bits.clear(); + + for (auto mem : mems) { + if (mem2cells.count(mem) == 0) + continue; + for (int cell_index : mem2cells.at(mem)) + if (unused[cell_index].exchange(false, std::memory_order_relaxed)) + cell_queue.push(ctx, cell_index); + } + mems.clear(); + } + }); + + ShardedVector sharded_unused_cells(subpool); + subpool.run([const_module, &unused, &sharded_unused_cells, &wire2driver](const ParallelDispatchThreadPool::RunCtx &ctx) { + // Parallel destruction of `wire2driver` + wire2driver.clear(ctx); + for (int i : ctx.item_range(const_module->cells_size())) + if (unused[i].load(std::memory_order_relaxed)) + sharded_unused_cells.insert(ctx, i); + }); + pool unused_cells; + for (int cell_index : sharded_unused_cells) + unused_cells.insert(const_module->cell_at(cell_index)); + unused_cells.sort(RTLIL::sort_by_name_id()); + + for (auto cell : unused_cells) { if (verbose) log_debug(" removing unused `%s' cell `%s'.\n", cell->type, cell->name); module->design->scratchpad_set_bool("opt.did_something", true); @@ -288,28 +410,31 @@ void rmunused_module_cells(Module *module, bool verbose, RmStats &stats, keep_ca stats.count_rm_cells++; } - for (auto it : mem_unused) - { + for (const auto &it : mem_indices) { + if (!mem_unused[it.second].load(std::memory_order_relaxed)) + continue; + RTLIL::IdString id(it.first); if (verbose) - log_debug(" removing unused memory `%s'.\n", it); - delete module->memories.at(it); - module->memories.erase(it); + log_debug(" removing unused memory `%s'.\n", id.unescape()); + delete module->memories.at(id); + module->memories.erase(id); } - for (auto &it : module->cells_) { - Cell *cell = it.second; - for (auto &it2 : cell->connections()) { - if (ct_all.cell_known(cell->type) && !ct_all.cell_input(cell->type, it2.first)) - continue; - for (auto raw_bit : raw_sigmap(it2.second)) - used_raw_bits.insert(raw_bit); + if (!driver_driver_logs.empty()) { + // We could do this in parallel but hopefully this is rare. + for (auto &it : module->cells_) { + Cell *cell = it.second; + for (auto &it2 : cell->connections()) { + if (ct_all.cell_known(cell->type) && !ct_all.cell_input(cell->type, it2.first)) + continue; + for (auto raw_bit : raw_sigmap(it2.second)) + used_raw_bits.insert(raw_bit); + } + } + for (std::pair &it : driver_driver_logs) { + if (used_raw_bits.count(it.first)) + log_warning("%s\n", it.second); } - } - - for (auto it : driver_driver_logs) { - if (used_raw_bits.count(it.first)) - for (auto msg : it.second) - log_warning("%s\n", msg); } } @@ -760,7 +885,7 @@ void rmunused_module(RTLIL::Module *module, ParallelDispatchThreadPool &thread_p int num_worker_threads = ThreadPool::work_pool_size(0, module->cells_size(), 1000); ParallelDispatchThreadPool::Subpool subpool(thread_pool, num_worker_threads); remove_temporary_cells(module, subpool, verbose); - rmunused_module_cells(module, verbose, stats, keep_cache); + rmunused_module_cells(module, subpool, verbose, stats, keep_cache); while (rmunused_module_signals(module, purge_mode, verbose, stats)) { } if (rminit && rmunused_module_init(module, subpool, verbose))