From fb22d7cdc411ec52672cb7f13364651c564872db Mon Sep 17 00:00:00 2001
From: Miodrag Milanovic <mmicko@gmail.com>
Date: Tue, 15 Feb 2022 09:30:42 +0100
Subject: [PATCH] Add support for various ff/latch cells simulation

---
 kernel/fstdata.cc | 145 ++++++++-----------------
 kernel/fstdata.h  |  22 ++--
 passes/sat/sim.cc | 264 +++++++++++++++++++++++++++++++++++-----------
 3 files changed, 260 insertions(+), 171 deletions(-)

diff --git a/kernel/fstdata.cc b/kernel/fstdata.cc
index 330c4d189..1386a3300 100644
--- a/kernel/fstdata.cc
+++ b/kernel/fstdata.cc
@@ -109,8 +109,7 @@ void FstData::extractVarNames()
 				}
 				if (clean_name[0]=='\\')
 					clean_name = clean_name.substr(1);
-				//log("adding %s.%s\n",var.scope.c_str(), clean_name.c_str());
-				
+
 				name_to_handle[var.scope+"."+clean_name] = h->u.var.handle;
 				break;
 			}
@@ -118,48 +117,6 @@ void FstData::extractVarNames()
 	}
 }
 
-static void reconstruct_edges_varlen(void *user_data, uint64_t pnt_time, fstHandle pnt_facidx, const unsigned char *pnt_value, uint32_t plen)
-{
-	FstData *ptr = (FstData*)user_data;
-	ptr->reconstruct_edges_callback(pnt_time, pnt_facidx, pnt_value, plen);
-}
-
-static void reconstruct_edges(void *user_data, uint64_t pnt_time, fstHandle pnt_facidx, const unsigned char *pnt_value)
-{
-	FstData *ptr = (FstData*)user_data;
-	uint32_t plen = (pnt_value) ?  strlen((const char *)pnt_value) : 0;
-	ptr->reconstruct_edges_callback(pnt_time, pnt_facidx, pnt_value, plen);
-}
-
-void FstData::reconstruct_edges_callback(uint64_t pnt_time, fstHandle pnt_facidx, const unsigned char *pnt_value, uint32_t /* plen */)
-{
-	std::string val = std::string((const char *)pnt_value);
-	std::string prev = last_data[pnt_facidx];
-	if (pnt_time>=start_time) {
-		if (prev!="1" && val=="1")
-			edges.push_back(pnt_time);
-		if (prev!="0" && val=="0")
-			edges.push_back(pnt_time);
-	}
-	last_data[pnt_facidx] = val;
-}
-
-std::vector<uint64_t> FstData::getAllEdges(std::vector<fstHandle> &signal, uint64_t start, uint64_t end)
-{
-	start_time = start;
-	end_time = end;
-	last_data.clear();
-	for(auto &s : signal) {
-		last_data[s] = "x";
-	}
-	edges.clear();
-	fstReaderSetLimitTimeRange(ctx, start_time, end_time);
-	fstReaderClrFacProcessMaskAll(ctx);
-	for(const auto sig : signal)
-		fstReaderSetFacProcessMask(ctx,sig);
-	fstReaderIterBlocks2(ctx, reconstruct_edges, reconstruct_edges_varlen, this, nullptr);
-	return edges;
-}
 
 static void reconstruct_clb_varlen_attimes(void *user_data, uint64_t pnt_time, fstHandle pnt_facidx, const unsigned char *pnt_value, uint32_t plen)
 {
@@ -176,77 +133,65 @@ static void reconstruct_clb_attimes(void *user_data, uint64_t pnt_time, fstHandl
 
 void FstData::reconstruct_callback_attimes(uint64_t pnt_time, fstHandle pnt_facidx, const unsigned char *pnt_value, uint32_t /* plen */)
 {
-	if (sample_times_ndx >= sample_times.size()) return;
-
-	uint64_t time = sample_times[sample_times_ndx];
+	if (pnt_time > end_time) return;
 	// if we are past the timestamp
-	if (pnt_time > time) {
-		for (auto const& c : last_data)
-		{
-			handle_to_data[c.first].push_back(std::make_pair(time,c.second));
-			size_t index = handle_to_data[c.first].size() - 1;
-			time_to_index[c.first][time] = index;
+	bool is_clock = false;
+	if (!all_samples) {
+		for(auto &s : clk_signals) {
+			if (s==pnt_facidx)  { 
+				is_clock=true;
+				break;
+			}
+		}
+	}
+
+	if (pnt_time > past_time) {
+		past_data = last_data;
+		past_time = pnt_time;
+	}
+
+	if (pnt_time > last_time) {
+		if (all_samples) {
+			callback(last_time);
+			last_time = pnt_time;
+		} else {
+			if (is_clock) {
+				std::string val = std::string((const char *)pnt_value);
+				std::string prev = past_data[pnt_facidx];
+				if ((prev!="1" && val=="1") || (prev!="0" && val=="0")) {
+					callback(last_time);
+					last_time = pnt_time;
+				}
+			}
 		}
-		sample_times_ndx++;
 	}
 	// always update last_data
 	last_data[pnt_facidx] =  std::string((const char *)pnt_value);
 }
 
-void FstData::reconstructAtTimes(std::vector<fstHandle> &signal, std::vector<uint64_t> time)
+void FstData::reconstructAllAtTimes(std::vector<fstHandle> &signal, uint64_t start, uint64_t end, CallbackFunction cb)
 {
-	handle_to_data.clear();
-	time_to_index.clear();
+	clk_signals = signal;
+	callback = cb;
+	start_time = start;
+	end_time = end;
 	last_data.clear();
-	sample_times_ndx = 0;
-	sample_times = time;
-	fstReaderSetUnlimitedTimeRange(ctx);
-	fstReaderClrFacProcessMaskAll(ctx);
-	for(const auto sig : signal)
-		fstReaderSetFacProcessMask(ctx,sig);
-	fstReaderIterBlocks2(ctx, reconstruct_clb_attimes, reconstruct_clb_varlen_attimes, this, nullptr);
-
-	if (time_to_index[signal.back()].count(time.back())==0) {
-		for (auto const& c : last_data)
-		{
-			handle_to_data[c.first].push_back(std::make_pair(time.back(),c.second));
-			size_t index = handle_to_data[c.first].size() - 1;
-			time_to_index[c.first][time.back()] = index;
-		}
-	}
-}
-
-void FstData::reconstructAllAtTimes(std::vector<uint64_t> time)
-{
-	handle_to_data.clear();
-	time_to_index.clear();
-	last_data.clear();
-	sample_times_ndx = 0;
-	sample_times = time;
+	last_time = start_time;
+	past_data.clear();
+	past_time = start_time;
+	all_samples = clk_signals.empty();
 
 	fstReaderSetUnlimitedTimeRange(ctx);
 	fstReaderSetFacProcessMaskAll(ctx);
 	fstReaderIterBlocks2(ctx, reconstruct_clb_attimes, reconstruct_clb_varlen_attimes, this, nullptr);
-
-	if (time_to_index[1].count(time.back())==0) {
-		for (auto const& c : last_data)
-		{
-			handle_to_data[c.first].push_back(std::make_pair(time.back(),c.second));
-			size_t index = handle_to_data[c.first].size() - 1;
-			time_to_index[c.first][time.back()] = index;
-		}
-	}
+	callback(last_time);
+	if (last_time!=end_time)
+		callback(end_time);
 }
 
-std::string FstData::valueAt(fstHandle signal, uint64_t time)
+std::string FstData::valueOf(fstHandle signal)
 {
-	if (handle_to_data.find(signal) == handle_to_data.end())
+	if (past_data.find(signal) == past_data.end())
 		log_error("Signal id %d not found\n", (int)signal);
-	auto &data = handle_to_data[signal];
-	if (time_to_index[signal].count(time)!=0) {
-		size_t index = time_to_index[signal][time];
-		return data.at(index).second;
-	} else {
-		log_error("No data for signal %d at time %d\n", (int)signal, (int)time);
-	}
+	return past_data[signal];
 }
diff --git a/kernel/fstdata.h b/kernel/fstdata.h
index c069ff5e5..707d1b64e 100644
--- a/kernel/fstdata.h
+++ b/kernel/fstdata.h
@@ -25,6 +25,9 @@
 
 YOSYS_NAMESPACE_BEGIN
 
+typedef std::function<void(uint64_t)> CallbackFunction;
+struct fst_end_of_data_exception { };
+
 struct FstVar
 {
 	fstHandle id;
@@ -45,14 +48,10 @@ class FstData
 
 	std::vector<FstVar>& getVars() { return vars; };
 
-	void reconstruct_edges_callback(uint64_t pnt_time, fstHandle pnt_facidx, const unsigned char *pnt_value, uint32_t plen);
-	std::vector<uint64_t> getAllEdges(std::vector<fstHandle> &signal, uint64_t start_time, uint64_t end_time);
-
 	void reconstruct_callback_attimes(uint64_t pnt_time, fstHandle pnt_facidx, const unsigned char *pnt_value, uint32_t plen);
-	void reconstructAtTimes(std::vector<fstHandle> &signal,std::vector<uint64_t> time);
-	void reconstructAllAtTimes(std::vector<uint64_t> time);
+	void reconstructAllAtTimes(std::vector<fstHandle> &signal, uint64_t start_time, uint64_t end_time, CallbackFunction cb);
 
-	std::string valueAt(fstHandle signal, uint64_t time);
+	std::string valueOf(fstHandle signal);
 	fstHandle getHandle(std::string name);
 	double getTimescale() { return timescale; }
 	const char *getTimescaleString() { return timescale_str.c_str(); }
@@ -64,16 +63,17 @@ private:
 	std::vector<FstVar> vars;
 	std::map<fstHandle, FstVar> handle_to_var;
 	std::map<std::string, fstHandle> name_to_handle;
-	std::map<fstHandle, std::vector<std::pair<uint64_t, std::string>>> handle_to_data;
 	std::map<fstHandle, std::string> last_data;
-	std::map<fstHandle, std::map<uint64_t, size_t>> time_to_index;
-	std::vector<uint64_t> sample_times;
-	size_t sample_times_ndx;
+	uint64_t last_time;
+	std::map<fstHandle, std::string> past_data;
+	uint64_t past_time;
 	double timescale;
 	std::string timescale_str;
 	uint64_t start_time;
 	uint64_t end_time;
-	std::vector<uint64_t> edges;
+	CallbackFunction callback;
+	std::vector<fstHandle> clk_signals;
+	bool all_samples;
 };
 
 YOSYS_NAMESPACE_END
diff --git a/passes/sat/sim.cc b/passes/sat/sim.cc
index a7c109374..47f48e67d 100644
--- a/passes/sat/sim.cc
+++ b/passes/sat/sim.cc
@@ -22,6 +22,7 @@
 #include "kernel/celltypes.h"
 #include "kernel/mem.h"
 #include "kernel/fstdata.h"
+#include "kernel/ff.h"
 
 #include <ctime>
 
@@ -76,6 +77,7 @@ struct SimShared
 	double stop_time = -1;
 	SimulationMode sim_mode = SimulationMode::sim;
 	bool cycles_set = false;
+	const pool<IdString> ff_types = RTLIL::builtin_ff_cell_types();
 };
 
 void zinit(State &v)
@@ -113,8 +115,13 @@ struct SimInstance
 
 	struct ff_state_t
 	{
-		State past_clock;
 		Const past_d;
+		Const past_ad;
+		SigSpec past_clk;
+		SigSpec past_ce;
+		SigSpec past_srst;
+		
+		FfData data;
 	};
 
 	struct mem_state_t
@@ -209,10 +216,15 @@ struct SimInstance
 					}
 			}
 
-			if (cell->type.in(ID($dff))) {
+			if (shared->ff_types.count(cell->type)) {
+				FfData ff_data(nullptr, cell);
 				ff_state_t ff;
-				ff.past_clock = State::Sx;
-				ff.past_d = Const(State::Sx, cell->getParam(ID::WIDTH).as_int());
+				ff.past_d = Const(State::Sx, ff_data.width);
+				ff.past_ad = Const(State::Sx, ff_data.width);
+				ff.past_clk = State::Sx;
+				ff.past_ce = State::Sx;
+				ff.past_srst = State::Sx;
+				ff.data = ff_data;
 				ff_database[cell] = ff;
 			}
 
@@ -229,11 +241,10 @@ struct SimInstance
 		{
 			for (auto &it : ff_database)
 			{
-				Cell *cell = it.first;
 				ff_state_t &ff = it.second;
 				zinit(ff.past_d);
 
-				SigSpec qsig = cell->getPort(ID::Q);
+				SigSpec qsig = it.second.data.sig_q;
 				Const qdata = get_state(qsig);
 				zinit(qdata);
 				set_state(qsig, qdata);
@@ -466,20 +477,138 @@ struct SimInstance
 
 		for (auto &it : ff_database)
 		{
-			Cell *cell = it.first;
 			ff_state_t &ff = it.second;
+			FfData ff_data = ff.data;
 
-			if (cell->type.in(ID($dff)))
-			{
-				bool clkpol = cell->getParam(ID::CLK_POLARITY).as_bool();
-				State current_clock = get_state(cell->getPort(ID::CLK))[0];
+			if (ff_data.has_clk) {
+				// flip-flops
+				State current_clk = get_state(ff_data.sig_clk)[0];
 
-				if (clkpol ? (ff.past_clock == State::S1 || current_clock != State::S1) :
-						(ff.past_clock == State::S0 || current_clock != State::S0))
-					continue;
+				// handle set/reset
+				if (ff.data.has_sr) {
+					Const current_q = get_state(ff.data.sig_q);
+					Const current_clr = get_state(ff.data.sig_clr);
+					Const current_set = get_state(ff.data.sig_set);
 
-				if (set_state(cell->getPort(ID::Q), ff.past_d))
-					did_something = true;
+					for(int i=0;i<ff.past_d.size();i++) {
+						
+						if (current_clr[i] == (ff_data.pol_clr ? State::S1 : State::S0)) {
+							current_q[i] = State::S0;
+						}
+						else if (current_set[i] == (ff_data.pol_set ? State::S1 : State::S0)) {
+							current_q[i] = State::S1;
+						} else {
+							// all below is in sync with clk
+							if (ff_data.pol_clk ? (ff.past_clk == State::S1 || current_clk != State::S1) :
+									(ff.past_clk == State::S0 || current_clk != State::S0))
+								continue;
+
+							if (ff_data.has_ce) {
+								if (ff.past_ce == (ff_data.pol_ce ? State::S1 : State::S0))
+									current_q[i] = ff.past_d[i];
+							} else {
+								current_q[i] = ff.past_d[i];
+							}
+						}
+					}
+					if (set_state(ff_data.sig_q, current_q))
+						did_something = true;
+				} else {
+					// async reset
+					if (ff_data.has_arst) {
+						State current_arst = get_state(ff_data.sig_arst)[0];
+						if (current_arst == (ff_data.pol_arst ? State::S1 : State::S0)) {
+							if (set_state(ff_data.sig_q, ff_data.val_arst))
+								did_something = true;
+							continue;
+						}
+					}
+					// async load
+					if (ff_data.has_aload) {
+						State current_aload = get_state(ff_data.sig_aload)[0];
+						if (current_aload == (ff_data.pol_aload ? State::S1 : State::S0)) {
+							if (set_state(ff_data.sig_q, ff.past_ad))
+								did_something = true;
+							continue;
+						}
+					}
+
+					// all below is in sync with clk
+					if (ff_data.pol_clk ? (ff.past_clk == State::S1 || current_clk != State::S1) :
+							(ff.past_clk == State::S0 || current_clk != State::S0))
+						continue;
+
+					// chip enable priority over reset
+					if (ff_data.ce_over_srst && ff_data.has_ce) {
+						if (ff.past_ce != (ff_data.pol_ce ? State::S1 : State::S0))
+							continue;
+					}
+
+					// handle sync reset
+					if (ff_data.has_srst) {
+						if (ff.past_srst == (ff_data.pol_srst ? State::S1 : State::S0)) {
+							if (set_state(ff_data.sig_q, ff_data.val_srst))
+								did_something = true;
+							continue;
+						}
+					}
+
+					// reset had priority over chip enable
+					if (!ff_data.ce_over_srst && ff_data.has_ce) {
+						if (ff.past_ce != (ff_data.pol_ce ? State::S1 : State::S0))
+							continue;
+					}
+					if (set_state(ff_data.sig_q, ff.past_d))
+						did_something = true;
+				}
+			} else {
+				// handle set/reset
+				if (ff.data.has_sr) {
+					Const current_q = get_state(ff.data.sig_q);
+					Const current_clr = get_state(ff.data.sig_clr);
+					Const current_set = get_state(ff.data.sig_set);
+
+					for(int i=0;i<current_q.size();i++) {
+						if (current_clr[i] == (ff_data.pol_clr ? State::S1 : State::S0)) {
+							current_q[i] = State::S0;
+						}
+						else if (current_set[i] == (ff_data.pol_set ? State::S1 : State::S0)) {
+							current_q[i] = State::S1;
+						} else {
+							if (ff_data.has_aload) {
+								Const current_ad = get_state(ff.data.sig_ad);
+								State current_aload = get_state(ff_data.sig_aload)[0];
+								if (current_aload == (ff_data.pol_aload ? State::S1 : State::S0)) {
+									current_q[i] = current_ad[i];
+								}
+							}
+						}
+					}
+					if (set_state(ff_data.sig_q, current_q))
+						did_something = true;
+				}
+				// async load is true for all latches
+				else if (ff_data.has_aload) {
+					// async reset
+					if (ff_data.has_arst) {
+						State current_arst = get_state(ff_data.sig_arst)[0];
+						if (current_arst == (ff_data.pol_arst ? State::S1 : State::S0)) {
+							if (set_state(ff_data.sig_q, ff_data.val_arst))
+								did_something = true;
+							continue;
+						}
+					}
+
+					State current_aload = get_state(ff_data.sig_aload)[0];
+					if (current_aload == (ff_data.pol_aload ? State::S1 : State::S0)) {
+						if (set_state(ff_data.sig_q, get_state(ff.data.sig_ad)))
+							did_something = true;
+					}
+				} else if (ff_data.has_gclk) {
+					// $ff
+					if (set_state(ff_data.sig_q, ff.past_d))
+						did_something = true;
+				}
 			}
 		}
 
@@ -538,13 +667,22 @@ struct SimInstance
 	{
 		for (auto &it : ff_database)
 		{
-			Cell *cell = it.first;
 			ff_state_t &ff = it.second;
 
-			if (cell->type.in(ID($dff))) {
-				ff.past_clock = get_state(cell->getPort(ID::CLK))[0];
-				ff.past_d = get_state(cell->getPort(ID::D));
-			}
+			if (ff.data.has_aload)
+				ff.past_ad = get_state(ff.data.sig_ad);
+
+			if (ff.data.has_clk || ff.data.has_gclk)
+				ff.past_d = get_state(ff.data.sig_d);
+
+			if (ff.data.has_clk)
+				ff.past_clk = get_state(ff.data.sig_clk)[0];
+
+			if (ff.data.has_ce)
+				ff.past_ce = get_state(ff.data.sig_ce)[0];
+
+			if (ff.data.has_srst)
+				ff.past_srst = get_state(ff.data.sig_srst)[0];
 		}
 
 		for (auto &it : mem_database)
@@ -595,8 +733,7 @@ struct SimInstance
 
 		for (auto &it : ff_database)
 		{
-			Cell *cell = it.first;
-			SigSpec sig_q = cell->getPort(ID::Q);
+			SigSpec sig_q = it.second.data.sig_q;
 			Const initval = get_state(sig_q);
 
 			for (int i = 0; i < GetSize(sig_q); i++)
@@ -722,34 +859,32 @@ struct SimInstance
 			child.second->write_fst_step(f);
 	}
 
-	void setInitState(uint64_t time)
+	void setInitState()
 	{
 		for (auto &it : ff_database)
 		{
-			Cell *cell = it.first;
-			
-			SigSpec qsig = cell->getPort(ID::Q);
+			SigSpec qsig = it.second.data.sig_q;
 			if (qsig.is_wire()) {
 				IdString name = qsig.as_wire()->name;
 				fstHandle id = shared->fst->getHandle(scope + "." + RTLIL::unescape_id(name));
 				if (id==0 && name.isPublic())
 					log_warning("Unable to found wire %s in input file.\n", (scope + "." + RTLIL::unescape_id(name)).c_str());
 				if (id!=0) {
-					Const fst_val = Const::from_string(shared->fst->valueAt(id, time));
+					Const fst_val = Const::from_string(shared->fst->valueOf(id));
 					set_state(qsig, fst_val);
 				}
 			}
 		}
 		for (auto child : children)
-			child.second->setInitState(time);
+			child.second->setInitState();
 	}
 
-	bool checkSignals(uint64_t time)
+	bool checkSignals()
 	{
 		bool retVal = false;
 		for(auto &item : fst_handles) {
 			if (item.second==0) continue; // Ignore signals not found
-			Const fst_val = Const::from_string(shared->fst->valueAt(item.second, time));
+			Const fst_val = Const::from_string(shared->fst->valueOf(item.second));
 			Const sim_val = get_state(item.first);
 			if (sim_val.size()!=fst_val.size())
 				log_error("Signal '%s' size is different in gold and gate.\n", log_id(item.first));
@@ -779,7 +914,7 @@ struct SimInstance
 			}
 		}
 		for (auto child : children)
-			retVal |= child.second->checkSignals(time);
+			retVal |= child.second->checkSignals();
 		return retVal;
 	}
 };
@@ -998,8 +1133,6 @@ struct SimWorker : SimShared
 				log_error("Can't find port %s.%s in FST.\n", scope.c_str(), log_id(portname));
 			fst_clock.push_back(id);
 		}
-		if (fst_clock.size()==0)
-			log_error("No clock signals defined for input file\n");
 
 		SigMap sigmap(topmod);
 		std::map<Wire*,fstHandle> inputs;
@@ -1044,37 +1177,48 @@ struct SimWorker : SimShared
 		if (stopCount<startCount) {
 			log_error("Stop time is before start time\n");
 		}
-		auto samples = fst->getAllEdges(fst_clock, startCount, stopCount);
 
-		// Limit to number of cycles if provided
-		if (cycles_set && ((size_t)(numcycles *2) < samples.size()))
-			samples.erase(samples.begin() + (numcycles*2), samples.end());
-
-		// Add setup time (start time)
-		if (samples.empty() || samples.front()!=startCount)
-			samples.insert(samples.begin(), startCount);
-
-		fst->reconstructAllAtTimes(samples);
 		bool initial = true;
 		int cycle = 0;
-		log("Co-simulation from %lu%s to %lu%s\n", (unsigned long)startCount, fst->getTimescaleString(), (unsigned long)stopCount, fst->getTimescaleString());
-		for(auto &time : samples) {
-			log("Co-simulating cycle %d [%lu%s].\n", cycle, (unsigned long)time, fst->getTimescaleString());
-			for(auto &item : inputs) {
-				std::string v = fst->valueAt(item.second, time);
-				top->set_state(item.first, Const::from_string(v));
-			}
-			if (initial) {
-				top->setInitState(time);
-				initial = false;
-			}
-			update();
+		log("Co-simulation from %lu%s to %lu%s", (unsigned long)startCount, fst->getTimescaleString(), (unsigned long)stopCount, fst->getTimescaleString());
+		if (cycles_set) 
+			log(" for %d clock cycle(s)",numcycles);
+		log("\n");
+		bool all_samples = fst_clock.empty();
 
-			bool status = top->checkSignals(time);
-			if (status)
-				log_error("Signal difference\n");
-			cycle++;
+		try {
+			fst->reconstructAllAtTimes(fst_clock, startCount, stopCount, [&](uint64_t time) {
+				log("Co-simulating %s %d [%lu%s].\n", (all_samples ? "sample" : "cycle"), cycle, (unsigned long)time, fst->getTimescaleString());
+				for(auto &item : inputs) {
+					std::string v = fst->valueOf(item.second);
+					top->set_state(item.first, Const::from_string(v));
+				}
+
+				if (initial) {
+					top->setInitState();
+					write_output_header();
+					initial = false;
+				}
+				update();
+				write_output_step(5*cycle);
+
+				bool status = top->checkSignals();
+				if (status)
+					log_error("Signal difference\n");
+				cycle++;
+
+				// Limit to number of cycles if provided
+				if (cycles_set && cycle > numcycles *2)
+					throw fst_end_of_data_exception();
+				if (time==stopCount)
+					throw fst_end_of_data_exception();
+			});
+		} catch(fst_end_of_data_exception) {
+			// end of data detected
 		}
+		write_output_step(5*(cycle-1)+2);
+		write_output_end();
+
 		if (writeback) {
 			pool<Module*> wbmods;
 			top->writeback(wbmods);