From 97e9caa4fa1f874b693a9d948f48418f22babb6c Mon Sep 17 00:00:00 2001
From: Clifford Wolf <clifford@clifford.at>
Date: Sat, 20 Apr 2019 17:52:16 +0200
Subject: [PATCH] Add "onehot" pass, improve "pmux2shiftx" onehot handling

Signed-off-by: Clifford Wolf <clifford@clifford.at>
---
 passes/opt/pmux2shiftx.cc | 417 ++++++++++++++++++++++++++++++++++++--
 1 file changed, 404 insertions(+), 13 deletions(-)

diff --git a/passes/opt/pmux2shiftx.cc b/passes/opt/pmux2shiftx.cc
index 4cd061c68..5fd49a571 100644
--- a/passes/opt/pmux2shiftx.cc
+++ b/passes/opt/pmux2shiftx.cc
@@ -23,6 +23,172 @@
 USING_YOSYS_NAMESPACE
 PRIVATE_NAMESPACE_BEGIN
 
+struct OnehotDatabase
+{
+	Module *module;
+	const SigMap &sigmap;
+	bool verbose = false;
+
+	pool<SigBit> init_ones;
+	dict<SigSpec, pool<SigSpec>> sig_sources_db;
+	dict<SigSpec, bool> sig_onehot_cache;
+	pool<SigSpec> recursion_guard;
+
+	OnehotDatabase(Module *module, const SigMap &sigmap) : module(module), sigmap(sigmap)
+	{
+	}
+
+	void initialize()
+	{
+		for (auto wire : module->wires())
+		{
+			auto it = wire->attributes.find("\\init");
+			if (it == wire->attributes.end())
+				continue;
+
+			auto &val = it->second;
+			int width = std::max(GetSize(wire), GetSize(val));
+
+			for (int i = 0; i < width; i++)
+				if (val[i] == State::S1)
+					init_ones.insert(sigmap(SigBit(wire, i)));
+		}
+
+		for (auto cell : module->cells())
+		{
+			vector<SigSpec> inputs;
+			SigSpec output;
+
+			if (cell->type.in("$adff", "$dff", "$dffe", "$dlatch", "$ff"))
+			{
+				output = cell->getPort("\\Q");
+				if (cell->type == "$adff")
+					inputs.push_back(cell->getParam("\\ARST_VALUE"));
+				inputs.push_back(cell->getPort("\\D"));
+			}
+
+			if (cell->type.in("$mux", "$pmux"))
+			{
+				output = cell->getPort("\\Y");
+				inputs.push_back(cell->getPort("\\A"));
+				SigSpec B = cell->getPort("\\B");
+				for (int i = 0; i < GetSize(B); i += GetSize(output))
+					inputs.push_back(B.extract(i, GetSize(output)));
+			}
+
+			if (!output.empty())
+			{
+				output = sigmap(output);
+				auto &srcs = sig_sources_db[output];
+				for (auto src : inputs) {
+					while (!src.empty() && src[GetSize(src)-1] == State::S0)
+						src.remove(GetSize(src)-1);
+					srcs.insert(sigmap(src));
+				}
+			}
+		}
+	}
+
+	void query_worker(const SigSpec &sig, bool &retval, bool &cache, int indent)
+	{
+		if (verbose)
+			log("%*s %s\n", indent, "", log_signal(sig));
+		log_assert(retval);
+
+		if (recursion_guard.count(sig)) {
+			if (verbose)
+				log("%*s   - recursion\n", indent, "");
+			cache = false;
+			return;
+		}
+
+		auto it = sig_onehot_cache.find(sig);
+		if (it != sig_onehot_cache.end()) {
+			if (verbose)
+				log("%*s   - cached (%s)\n", indent, "", it->second ? "true" : "false");
+			if (!it->second)
+				retval = false;
+			return;
+		}
+
+		bool found_init_ones = false;
+		for (auto bit : sig) {
+			if (init_ones.count(bit)) {
+				if (found_init_ones) {
+					if (verbose)
+						log("%*s   - non-onehot init value\n", indent, "");
+					retval = false;
+					break;
+				}
+				found_init_ones = true;
+			}
+		}
+
+		if (retval)
+		{
+			if (sig.is_fully_const())
+			{
+				bool found_ones = false;
+				for (auto bit : sig) {
+					if (bit == State::S1) {
+						if (found_ones) {
+							if (verbose)
+								log("%*s   - non-onehot constant\n", indent, "");
+							retval = false;
+							break;
+						}
+						found_ones = true;
+					}
+				}
+			}
+			else
+			{
+				auto srcs = sig_sources_db.find(sig);
+				if (srcs == sig_sources_db.end()) {
+					if (verbose)
+						log("%*s   - no sources for non-const signal\n", indent, "");
+					retval = false;
+				} else {
+					for (auto &src : srcs->second) {
+						bool child_cache = true;
+						recursion_guard.insert(sig);
+						query_worker(src, retval, child_cache, indent+4);
+						recursion_guard.erase(sig);
+						if (!child_cache)
+							cache = false;
+						if (!retval)
+							break;
+					}
+				}
+			}
+		}
+
+		// it is always safe to cache a negative result
+		if (cache || !retval)
+			sig_onehot_cache[sig] = retval;
+	}
+
+	bool query(const SigSpec &sig)
+	{
+		bool retval = true;
+		bool cache = true;
+
+		if (verbose)
+			log("** ONEHOT QUERY START (%s)\n", log_signal(sig));
+
+		query_worker(sig, retval, cache, 3);
+
+		if (verbose)
+			log("** ONEHOT QUERY RESULT = %s\n", retval ? "true" : "false");
+
+		// it is always safe to cache the root result of a query
+		if (!cache)
+			sig_onehot_cache[sig] = retval;
+
+		return retval;
+	}
+};
+
 struct Pmux2ShiftxPass : public Pass {
 	Pmux2ShiftxPass() : Pass("pmux2shiftx", "transform $pmux cells to $shiftx cells") { }
 	void help() YS_OVERRIDE
@@ -33,6 +199,9 @@ struct Pmux2ShiftxPass : public Pass {
 		log("\n");
 		log("This pass transforms $pmux cells to $shiftx cells.\n");
 		log("\n");
+		log("    -v, -vv\n");
+		log("        verbose output\n");
+		log("\n");
 		log("    -min_density <percentage>\n");
 		log("        specifies the minimum density for the shifter\n");
 		log("        default: 50\n");
@@ -41,9 +210,9 @@ struct Pmux2ShiftxPass : public Pass {
 		log("        specified the minimum number of choices for a control signal\n");
 		log("        default: 3\n");
 		log("\n");
-		log("    -allow_onehot\n");
-		log("        by default, pmuxes with one-hot encoded control signals are not\n");
-		log("        converted. this option disables that check.\n");
+		log("    -onehot ignore|pmux|shiftx\n");
+		log("        select strategy for one-hot encoded control signals\n");
+		log("        default: pmux\n");
 		log("\n");
 	}
 	void execute(std::vector<std::string> args, RTLIL::Design *design) YS_OVERRIDE
@@ -51,6 +220,9 @@ struct Pmux2ShiftxPass : public Pass {
 		int min_density = 50;
 		int min_choices = 3;
 		bool allow_onehot = false;
+		bool optimize_onehot = true;
+		bool verbose = false;
+		bool verbose_onehot = false;
 
 		log_header(design, "Executing PMUX2SHIFTX pass.\n");
 
@@ -64,8 +236,31 @@ struct Pmux2ShiftxPass : public Pass {
 				min_choices = atoi(args[++argidx].c_str());
 				continue;
 			}
-			if (args[argidx] == "-allow_onehot") {
+			if (args[argidx] == "-onehot" && argidx+1 < args.size() && args[argidx+1] == "ignore") {
+				argidx++;
+				allow_onehot = false;
+				optimize_onehot = false;
+				continue;
+			}
+			if (args[argidx] == "-onehot" && argidx+1 < args.size() && args[argidx+1] == "pmux") {
+				argidx++;
+				allow_onehot = false;
+				optimize_onehot = true;
+				continue;
+			}
+			if (args[argidx] == "-onehot" && argidx+1 < args.size() && args[argidx+1] == "shiftx") {
+				argidx++;
 				allow_onehot = true;
+				optimize_onehot = false;
+				continue;
+			}
+			if (args[argidx] == "-v") {
+				verbose = true;
+				continue;
+			}
+			if (args[argidx] == "-vv") {
+				verbose = true;
+				verbose_onehot = true;
 				continue;
 			}
 			break;
@@ -75,10 +270,15 @@ struct Pmux2ShiftxPass : public Pass {
 		for (auto module : design->selected_modules())
 		{
 			SigMap sigmap(module);
+			OnehotDatabase onehot_db(module, sigmap);
+			onehot_db.verbose = verbose_onehot;
+
+			if (optimize_onehot)
+				onehot_db.initialize();
 
 			dict<SigBit, pair<SigSpec, Const>> eqdb;
 
-			for (auto cell : module->selected_cells())
+			for (auto cell : module->cells())
 			{
 				if (cell->type == "$eq")
 				{
@@ -181,6 +381,12 @@ struct Pmux2ShiftxPass : public Pass {
 
 				bool printed_pmux_header = false;
 
+				if (verbose) {
+					printed_pmux_header = true;
+					log("Inspecting $pmux cell %s/%s.\n", log_id(module), log_id(cell));
+					log("  data width: %d (next power-of-2 = %d, log2 = %d)\n", width, extwidth, width_bits);
+				}
+
 				SigSpec updated_S = cell->getPort("\\S");
 				SigSpec updated_B = cell->getPort("\\B");
 
@@ -196,7 +402,7 @@ struct Pmux2ShiftxPass : public Pass {
 					}
 
 					// find the relevant choices
-					bool is_onehot = true;
+					bool is_onehot = GetSize(sig) > 2;
 					dict<Const, int> choices;
 					for (int i : seldb.at(sig)) {
 						Const val = eqdb.at(S[i]).second;
@@ -211,14 +417,17 @@ struct Pmux2ShiftxPass : public Pass {
 
 					// TBD: also find choices that are using signals that are subsets of the bits in "sig"
 
-					if (is_onehot && !allow_onehot) {
-						seldb.erase(sig);
-						continue;
-					}
+					if (!verbose)
+					{
+						if (is_onehot && !allow_onehot && !optimize_onehot) {
+							seldb.erase(sig);
+							continue;
+						}
 
-					if (GetSize(choices) < min_choices) {
-						seldb.erase(sig);
-						continue;
+						if (GetSize(choices) < min_choices) {
+							seldb.erase(sig);
+							continue;
+						}
 					}
 
 					if (!printed_pmux_header) {
@@ -229,6 +438,65 @@ struct Pmux2ShiftxPass : public Pass {
 
 					log("  checking ctrl signal %s\n", log_signal(sig));
 
+					auto print_choices = [&]() {
+						log("    table of choices:\n");
+						for (auto &it : choices)
+							log("    %3d: %s: %s\n", it.second, log_signal(it.first),
+									log_signal(B.extract(it.second*width, width)));
+					};
+
+					if (verbose)
+					{
+						if (is_onehot && !allow_onehot && !optimize_onehot) {
+							print_choices();
+							log("    ignoring one-hot encoding.\n");
+							seldb.erase(sig);
+							continue;
+						}
+
+						if (GetSize(choices) < min_choices) {
+							print_choices();
+							log("    insufficient choices.\n");
+							seldb.erase(sig);
+							continue;
+						}
+					}
+
+					if (is_onehot && optimize_onehot)
+					{
+						print_choices();
+						if (!onehot_db.query(sig))
+						{
+							log("    failed to detect onehot driver. do not optimize.\n");
+						}
+						else
+						{
+							log("    optimizing one-hot encoding.\n");
+							for (auto &it : choices)
+							{
+								const Const &val = it.first;
+								int index = -1;
+
+								for (int i = 0; i < GetSize(val); i++)
+									if (val[i] == State::S1) {
+										log_assert(index < 0);
+										index = i;
+									}
+
+								if (index < 0) {
+									log("    %3d: zero encoding.\n", it.second);
+									continue;
+								}
+
+								SigBit new_ctrl = sig[index];
+								log("    %3d: new crtl signal is %s.\n", it.second, log_signal(new_ctrl));
+								updated_S[it.second] = new_ctrl;
+							}
+						}
+						seldb.erase(sig);
+						continue;
+					}
+
 					// find the best permutation
 					vector<int> perm_new_from_old(GetSize(sig));
 					Const perm_xormask(State::S0, GetSize(sig));
@@ -434,4 +702,127 @@ struct Pmux2ShiftxPass : public Pass {
 	}
 } Pmux2ShiftxPass;
 
+struct OnehotPass : public Pass {
+	OnehotPass() : Pass("onehot", "optimize $eq cells for onehot signals") { }
+	void help() YS_OVERRIDE
+	{
+		//   |---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|---v---|
+		log("\n");
+		log("    onehot [options] [selection]\n");
+		log("\n");
+		log("This pass optimizes $eq cells that compare one-hot signals against constants\n");
+		log("\n");
+		log("    -v, -vv\n");
+		log("        verbose output\n");
+		log("\n");
+	}
+	void execute(std::vector<std::string> args, RTLIL::Design *design) YS_OVERRIDE
+	{
+		bool verbose = false;
+		bool verbose_onehot = false;
+
+		log_header(design, "Executing ONEHOT pass.\n");
+
+		size_t argidx;
+		for (argidx = 1; argidx < args.size(); argidx++) {
+			if (args[argidx] == "-v") {
+				verbose = true;
+				continue;
+			}
+			if (args[argidx] == "-vv") {
+				verbose = true;
+				verbose_onehot = true;
+				continue;
+			}
+			break;
+		}
+		extra_args(args, argidx, design);
+
+		for (auto module : design->selected_modules())
+		{
+			SigMap sigmap(module);
+			OnehotDatabase onehot_db(module, sigmap);
+			onehot_db.verbose = verbose_onehot;
+			onehot_db.initialize();
+
+			for (auto cell : module->selected_cells())
+			{
+				if (cell->type != "$eq")
+					continue;
+
+				SigSpec A = sigmap(cell->getPort("\\A"));
+				SigSpec B = sigmap(cell->getPort("\\B"));
+
+				int a_width = cell->getParam("\\A_WIDTH").as_int();
+				int b_width = cell->getParam("\\B_WIDTH").as_int();
+
+				if (a_width < b_width) {
+					bool a_signed = cell->getParam("\\A_SIGNED").as_int();
+					A.extend_u0(b_width, a_signed);
+				}
+
+				if (b_width < a_width) {
+					bool b_signed = cell->getParam("\\B_SIGNED").as_int();
+					B.extend_u0(a_width, b_signed);
+				}
+
+				if (A.is_fully_const())
+					std::swap(A, B);
+
+				if (!B.is_fully_const())
+					continue;
+
+				if (verbose)
+					log("Checking $eq(%s, %s) cell %s/%s.\n", log_signal(A), log_signal(B), log_id(module), log_id(cell));
+
+				if (!onehot_db.query(A)) {
+					if (verbose)
+						log("  onehot driver test on %s failed.\n", log_signal(A));
+					continue;
+				}
+
+				int index = -1;
+				bool not_onehot = false;
+
+				for (int i = 0; i < GetSize(B); i++) {
+					if (B[i] != State::S1)
+						continue;
+					if (index >= 0)
+						not_onehot = true;
+					index = i;
+				}
+
+				if (index < 0) {
+					if (verbose)
+						log("  not optimizing the zero pattern.\n");
+					continue;
+				}
+
+				SigSpec Y = cell->getPort("\\Y");
+
+				if (not_onehot)
+				{
+					if (verbose)
+						log("  replacing with constant 0 driver.\n");
+					else
+						log("Replacing one-hot $eq(%s, %s) cell %s/%s with constant 0 driver.\n", log_signal(A), log_signal(B), log_id(module), log_id(cell));
+					module->connect(Y, SigSpec(1, GetSize(Y)));
+				}
+				else
+				{
+					SigSpec sig = A[index];
+					if (verbose)
+						log("  replacing with signal %s.\n", log_signal(sig));
+					else
+						log("Replacing one-hot $eq(%s, %s) cell %s/%s with signal %s.\n",log_signal(A), log_signal(B), log_id(module), log_id(cell), log_signal(sig));
+					sig.extend_u0(GetSize(Y));
+					module->connect(Y, sig);
+				}
+
+				module->remove(cell);
+			}
+		}
+	}
+} OnehotPass;
+
 PRIVATE_NAMESPACE_END