diff --git a/passes/opt/muxpack.cc b/passes/opt/muxpack.cc
index cb13a45b0..ae4b67db2 100644
--- a/passes/opt/muxpack.cc
+++ b/passes/opt/muxpack.cc
@@ -37,6 +37,7 @@ struct MuxpackWorker
 	dict<SigSpec, Cell*> sig_chain_prev;
 	pool<SigBit> sigbit_with_non_chain_users;
 	pool<Cell*> chain_start_cells;
+	pool<Cell*> candidate_cells;
 
 	void make_sig_chain_next_prev()
 	{
@@ -59,14 +60,18 @@ struct MuxpackWorker
 				if (sig_chain_next.count(a_sig))
 					for (auto a_bit : a_sig.bits())
 						sigbit_with_non_chain_users.insert(a_bit);
-				else
+				else {
 					sig_chain_next[a_sig] = cell;
+					candidate_cells.insert(cell);
+				}
 
 				if (sig_chain_next.count(b_sig))
 					for (auto b_bit : b_sig.bits())
 						sigbit_with_non_chain_users.insert(b_bit);
-				else
+				else {
 					sig_chain_next[b_sig] = cell;
+					candidate_cells.insert(cell);
+				}
 
 				sig_chain_prev[y_sig] = cell;
 				continue;
@@ -81,35 +86,34 @@ struct MuxpackWorker
 
 	void find_chain_start_cells()
 	{
-		for (auto it : sig_chain_next)
+		for (auto cell : candidate_cells)
 		{
-			SigSpec next_sig = it.second->getPort("\\A");
+			SigSpec next_sig = cell->getPort("\\A");
 			if (sig_chain_prev.count(next_sig) == 0) {
-				next_sig = it.second->getPort("\\B");
+				next_sig = cell->getPort("\\B");
 				if (sig_chain_prev.count(next_sig) == 0)
-					next_sig = SigSpec();
+					goto start_cell;
 			}
 
-			for (auto bit : next_sig.bits())
-				if (sigbit_with_non_chain_users.count(bit))
-					goto start_cell;
-
-			if (!next_sig.empty())
 			{
+				for (auto bit : next_sig.bits())
+					if (sigbit_with_non_chain_users.count(bit))
+						goto start_cell;
+
 				Cell *c1 = sig_chain_prev.at(next_sig);
-				Cell *c2 = it.second;
+				Cell *c2 = cell;
 
 				if (c1->type != c2->type)
 					goto start_cell;
 
 				if (c1->parameters != c2->parameters)
 					goto start_cell;
-
-				continue;
 			}
 
+			continue;
+
 		start_cell:
-			chain_start_cells.insert(it.second);
+			chain_start_cells.insert(cell);
 		}
 	}
 
@@ -197,6 +201,7 @@ struct MuxpackWorker
 		sig_chain_next.clear();
 		sig_chain_prev.clear();
 		chain_start_cells.clear();
+		candidate_cells.clear();
 	}
 
 	MuxpackWorker(Module *module) :