From d4e3daa9d0037e2397dc74e3862b5b41c67716e5 Mon Sep 17 00:00:00 2001
From: Jannis Harder <me@jix.one>
Date: Thu, 11 Apr 2024 13:48:25 +0200
Subject: [PATCH] ComputeGraph datatype for the upcoming functional backend

---
 kernel/functional.h       | 369 ++++++++++++++++++++++++++++++++++++++
 passes/cmds/example_dt.cc | 178 ++++++++++++++----
 2 files changed, 515 insertions(+), 32 deletions(-)
 create mode 100644 kernel/functional.h

diff --git a/kernel/functional.h b/kernel/functional.h
new file mode 100644
index 000000000..e5ee88240
--- /dev/null
+++ b/kernel/functional.h
@@ -0,0 +1,369 @@
+/*
+ *  yosys -- Yosys Open SYnthesis Suite
+ *
+ *  Copyright (C) 2024  Jannis Harder <jix@yosyshq.com> <me@jix.one>
+ *
+ *  Permission to use, copy, modify, and/or distribute this software for any
+ *  purpose with or without fee is hereby granted, provided that the above
+ *  copyright notice and this permission notice appear in all copies.
+ *
+ *  THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ *  WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ *  MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ *  ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ *  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ *  ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ *  OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ *
+ */
+
+#ifndef FUNCTIONAL_H
+#define FUNCTIONAL_H
+
+#include <tuple>
+#include "kernel/yosys.h"
+
+YOSYS_NAMESPACE_BEGIN
+
+template<
+    typename Fn, // Function type (deduplicated across whole graph)
+    typename Attr = std::tuple<>, // Call attributes (present in every node)
+    typename SparseAttr = std::tuple<>, // Sparse call attributes (optional per node)
+    typename Key = std::tuple<> // Stable keys to refer to nodes
+>
+struct ComputeGraph
+{
+    struct Ref;
+private:
+
+    // Functions are deduplicated by assigning unique ids
+    idict<Fn> functions;
+
+    struct Node {
+        int fn_index;
+        int arg_offset;
+        int arg_count;
+        Attr attr;
+
+        Node(int fn_index, Attr &&attr, int arg_offset, int arg_count = 0)
+            : fn_index(fn_index), arg_offset(arg_offset), arg_count(arg_count), attr(std::move(attr)) {}
+
+        Node(int fn_index, Attr const &attr, int arg_offset, int arg_count = 0)
+            : fn_index(fn_index), arg_offset(arg_offset), arg_count(arg_count), attr(attr) {}
+    };
+
+
+    std::vector<Node> nodes;
+    std::vector<int> args;
+    dict<Key, int> keys_;
+    dict<int, SparseAttr> sparse_attrs;
+
+public:
+    template<typename Graph>
+    struct BaseRef
+    {
+    protected:
+        friend struct ComputeGraph;
+        Graph *graph_;
+        int index_;
+        BaseRef(Graph *graph, int index) : graph_(graph), index_(index) {
+            log_assert(index_ >= 0);
+            check();
+        }
+
+        void check() const { log_assert(index_ < graph_->size()); }
+
+        Node const &deref() const { check(); return graph_->nodes[index_]; }
+
+    public:
+        ComputeGraph const &graph() const { return graph_; }
+        int index() const { return index_; }
+
+        int size() const { return deref().arg_count; }
+
+        BaseRef arg(int n) const
+        {
+            Node const &node = deref();
+            log_assert(n >= 0 && n < node.arg_count);
+            return BaseRef(graph_, graph_->args[node.arg_offset + n]);
+        }
+
+        std::vector<int>::const_iterator arg_indices_cbegin() const
+        {
+            Node const &node = deref();
+            return graph_->args.cbegin() + node.arg_offset;
+        }
+
+        std::vector<int>::const_iterator arg_indices_cend() const
+        {
+            Node const &node = deref();
+            return graph_->args.cbegin() + node.arg_offset + node.arg_count;
+        }
+
+        Fn const &function() const { return graph_->functions[deref().fn_index]; }
+        Attr const &attr() const { return deref().attr; }
+
+        bool has_sparse_attr() const { return graph_->sparse_attrs.count(index_); }
+
+        SparseAttr const &sparse_attr() const
+        {
+            auto found = graph_->sparse_attrs.find(index_);
+            log_assert(found != graph_->sparse_attrs.end());
+            return *found;
+        }
+    };
+
+    using ConstRef = BaseRef<ComputeGraph const>;
+
+    struct Ref : public BaseRef<ComputeGraph>
+    {
+    private:
+        friend struct ComputeGraph;
+        Ref(ComputeGraph *graph, int index) : BaseRef<ComputeGraph>(graph, index) {}
+        Node &deref() const { this->check(); return this->graph_->nodes[this->index_]; }
+
+    public:
+        void set_function(Fn const &function) const
+        {
+            deref().fn_index = this->graph_->functions(function);
+        }
+
+        Attr &attr() const { return deref().attr; }
+
+        void append_arg(ConstRef arg) const
+        {
+            log_assert(arg.graph_ == this->graph_);
+            append_arg(arg.index());
+        }
+
+        void append_arg(int arg) const
+        {
+            log_assert(arg >= 0 && arg < this->graph_->size());
+            Node &node = deref();
+            if (node.arg_offset + node.arg_count != GetSize(this->graph_->args))
+                move_args(node);
+            this->graph_->args.push_back(arg);
+            node.arg_count++;
+        }
+
+        operator ConstRef() const
+        {
+            return ConstRef(this->graph_, this->index_);
+        }
+
+        SparseAttr &sparse_attr() const
+        {
+            return this->graph_->sparse_attrs[this->index_];
+        }
+
+        void clear_sparse_attr() const
+        {
+            this->graph_->sparse_attrs.erase(this->index_);
+        }
+
+        void assign_key(Key const &key) const
+        {
+            this->graph_->keys_.emplace(key, this->index_);
+        }
+
+    private:
+        void move_args(Node &node) const
+        {
+            auto &args = this->graph_->args;
+            int old_offset = node.arg_offset;
+            node.arg_offset = GetSize(args);
+            for (int i = 0; i != node.arg_count; ++i)
+                args.push_back(args[old_offset + i]);
+        }
+
+    };
+
+    bool has_key(Key const &key) const
+    {
+        return keys_.count(key);
+    }
+
+    dict<Key, int> const &keys() const
+    {
+        return keys_;
+    }
+
+    ConstRef operator()(Key const &key) const
+    {
+        auto it = keys_.find(key);
+        log_assert(it != keys_.end());
+        return (*this)[it->second];
+    }
+
+    Ref operator()(Key const &key)
+    {
+        auto it = keys_.find(key);
+        log_assert(it != keys_.end());
+        return (*this)[it->second];
+    }
+
+    int size() const { return GetSize(nodes); }
+
+    ConstRef operator[](int index) const { return ConstRef(this, index); }
+    Ref operator[](int index) { return Ref(this, index); }
+
+    Ref add(Fn const &function, Attr &&attr)
+    {
+        int index = GetSize(nodes);
+        int fn_index = functions(function);
+        nodes.emplace_back(fn_index, std::move(attr), GetSize(args));
+        return Ref(this, index);
+    }
+
+    Ref add(Fn const &function, Attr const &attr)
+    {
+        int index = GetSize(nodes);
+        int fn_index = functions(function);
+        nodes.emplace_back(fn_index, attr,  GetSize(args));
+        return Ref(this, index);
+    }
+
+    template<typename T>
+    Ref add(Fn const &function, Attr const &attr, T const &args)
+    {
+        Ref added = add(function, attr);
+        for (auto arg : args)
+            added.append_arg(arg);
+        return added;
+    }
+
+    template<typename T>
+    Ref add(Fn const &function, Attr &&attr, T const &args)
+    {
+        Ref added = add(function, std::move(attr));
+        for (auto arg : args)
+            added.append_arg(arg);
+        return added;
+    }
+
+    template<typename T>
+    Ref add(Fn const &function, Attr const &attr, T begin, T end)
+    {
+        Ref added = add(function, attr);
+        for (; begin != end; ++begin)
+            added.append_arg(*begin);
+        return added;
+    }
+
+    void permute(std::vector<int> const &perm)
+    {
+        log_assert(perm.size() <= nodes.size());
+        std::vector<int> inv_perm;
+        inv_perm.resize(nodes.size(), -1);
+        for (int i = 0; i < GetSize(perm); ++i)
+        {
+            int j = perm[i];
+            log_assert(j >= 0 && j < GetSize(perm));
+            log_assert(inv_perm[j] == -1);
+            inv_perm[j] = i;
+        }
+        permute(perm, inv_perm);
+    }
+
+    void permute(std::vector<int> const &perm, std::vector<int> const &inv_perm)
+    {
+        log_assert(inv_perm.size() == nodes.size());
+        std::vector<Node> new_nodes;
+        new_nodes.reserve(perm.size());
+        dict<int, SparseAttr> new_sparse_attrs;
+        for (int i : perm)
+        {
+            int j = GetSize(new_nodes);
+            new_nodes.emplace_back(std::move(nodes[i]));
+            auto found = sparse_attrs.find(i);
+            if (found != sparse_attrs.end())
+                new_sparse_attrs.emplace(j, std::move(found->second));
+        }
+
+        std::swap(nodes, new_nodes);
+        std::swap(sparse_attrs, new_sparse_attrs);
+
+        for (int &arg : args)
+        {
+            log_assert(arg < GetSize(inv_perm));
+            arg = inv_perm[arg];
+        }
+
+        for (auto &key : keys_)
+        {
+            log_assert(key.second < GetSize(inv_perm));
+            key.second = inv_perm[key.second];
+        }
+    }
+
+    struct SccAdaptor
+    {
+    private:
+        ComputeGraph const &graph_;
+        std::vector<int> indices_;
+    public:
+        SccAdaptor(ComputeGraph const &graph) : graph_(graph)
+        {
+            indices_.resize(graph.size(), -1);
+        }
+
+
+        typedef int node_type;
+
+        struct node_enumerator {
+        private:
+            friend struct SccAdaptor;
+            int current, end;
+            node_enumerator(int current, int end) : current(current), end(end) {}
+
+        public:
+
+            bool finished() const { return current == end; }
+            node_type next() {
+                log_assert(!finished());
+                node_type result = current;
+                ++current;
+                return result;
+            }
+        };
+
+        node_enumerator enumerate_nodes() {
+            return node_enumerator(0, GetSize(indices_));
+        }
+
+
+        struct successor_enumerator {
+        private:
+            friend struct SccAdaptor;
+            std::vector<int>::const_iterator current, end;
+            successor_enumerator(std::vector<int>::const_iterator current, std::vector<int>::const_iterator end) :
+                current(current), end(end) {}
+
+        public:
+            bool finished() const { return current == end; }
+            node_type next() {
+                log_assert(!finished());
+                node_type result = *current;
+                ++current;
+                return result;
+            }
+        };
+
+        successor_enumerator enumerate_successors(int index) const {
+            auto const &ref = graph_[index];
+            return successor_enumerator(ref.arg_indices_cbegin(), ref.arg_indices_cend());
+        }
+
+        int &dfs_index(node_type const &node) { return indices_[node]; }
+
+        std::vector<int> const &dfs_indices() { return indices_; }
+    };
+
+};
+
+
+
+YOSYS_NAMESPACE_END
+
+
+#endif
diff --git a/passes/cmds/example_dt.cc b/passes/cmds/example_dt.cc
index de84fa3cd..dec554d6c 100644
--- a/passes/cmds/example_dt.cc
+++ b/passes/cmds/example_dt.cc
@@ -1,6 +1,7 @@
 #include "kernel/yosys.h"
 #include "kernel/drivertools.h"
 #include "kernel/topo_scc.h"
+#include "kernel/functional.h"
 
 USING_YOSYS_NAMESPACE
 PRIVATE_NAMESPACE_BEGIN
@@ -38,86 +39,137 @@ struct ExampleDtPass : public Pass
 			ExampleWorker worker(module);
 			DriverMap dm;
 
+			struct ExampleFn {
+				IdString name;
+				dict<IdString, Const> parameters;
+
+				ExampleFn(IdString name) : name(name) {}
+				ExampleFn(IdString name, dict<IdString, Const> parameters) : name(name), parameters(parameters) {}
+
+				bool operator==(ExampleFn const &other) const {
+					return name == other.name && parameters == other.parameters;
+				}
+
+				unsigned int hash() const {
+					return mkhash(name.hash(), parameters.hash());
+				}
+			};
+
+			typedef ComputeGraph<ExampleFn, int, IdString, IdString> ExampleGraph;
+
+			ExampleGraph compute_graph;
+
+
 			dm.add(module);
 
 			idict<DriveSpec> queue;
 			idict<Cell *> cells;
 
 			IntGraph edges;
+			std::vector<int> graph_nodes;
 
+			auto enqueue = [&](DriveSpec const &spec) {
+				int index = queue(spec);
+				if (index == GetSize(graph_nodes))
+					graph_nodes.emplace_back(compute_graph.add(ID($pending), index).index());
+				//if (index >= GetSize(graph_nodes))
+				return compute_graph[graph_nodes[index]];
+			};
 
 			for (auto cell : module->cells()) {
 				if (cell->type.in(ID($assert), ID($assume), ID($cover), ID($check)))
-					queue(DriveBitMarker(cells(cell), 0));
+					enqueue(DriveBitMarker(cells(cell), 0));
 			}
 
 			for (auto wire : module->wires()) {
 				if (!wire->port_output)
 					continue;
-				queue(DriveChunk(DriveChunkWire(wire, 0, wire->width)));
+				enqueue(DriveChunk(DriveChunkWire(wire, 0, wire->width))).assign_key(wire->name);
 			}
 
-#define emit log
-// #define emit(X...) do {} while (false)
-
 			for (int i = 0; i != GetSize(queue); ++i)
 			{
-				emit("n%d: ", i);
 				DriveSpec spec = queue[i];
+				ExampleGraph::Ref node = compute_graph[i];
+
 				if (spec.chunks().size() > 1) {
-					emit("concat %s <-\n", log_signal(spec));
+					node.set_function(ID($$concat));
+
 					for (auto const &chunk : spec.chunks()) {
-						emit("  * %s\n", log_signal(chunk));
-						edges.add_edge(i, queue(chunk));
+						node.append_arg(enqueue(chunk));
 					}
 				} else if (spec.chunks().size() == 1) {
 					DriveChunk chunk = spec.chunks()[0];
 					if (chunk.is_wire()) {
 						DriveChunkWire wire_chunk = chunk.wire();
 						if (wire_chunk.is_whole()) {
+							node.sparse_attr() = wire_chunk.wire->name;
 							if (wire_chunk.wire->port_input) {
-								emit("input %s\n", log_signal(spec));
+								node.set_function(ExampleFn(ID($$input), {{wire_chunk.wire->name, {}}}));
 							} else {
 								DriveSpec driver = dm(DriveSpec(wire_chunk));
-								edges.add_edge(i, queue(driver));
-								emit("wire driver %s <- %s\n", log_signal(spec), log_signal(driver));
+								node.set_function(ID($$buf));
+
+								node.append_arg(enqueue(driver));
 							}
 						} else {
 							DriveChunkWire whole_wire(wire_chunk.wire, 0, wire_chunk.width);
-							edges.add_edge(i, queue(whole_wire));
-							emit("wire slice %s <- %s\n", log_signal(spec), log_signal(DriveSpec(whole_wire)));
+							node.set_function(ExampleFn(ID($$slice), {{ID(offset), wire_chunk.offset}, {ID(width), wire_chunk.width}}));
+							node.append_arg(enqueue(whole_wire));
 						}
 					} else if (chunk.is_port()) {
 						DriveChunkPort port_chunk = chunk.port();
 						if (port_chunk.is_whole()) {
 							if (dm.celltypes.cell_output(port_chunk.cell->type, port_chunk.port)) {
-								int cell_marker = queue(DriveBitMarker(cells(port_chunk.cell), 0));
-								if (!port_chunk.cell->type.in(ID($dff), ID($ff)))
-									edges.add_edge(i, cell_marker);
-								emit("cell output %s %s\n", log_id(port_chunk.cell), log_id(port_chunk.port));
+								if (port_chunk.cell->type.in(ID($dff), ID($ff)))
+								{
+									Cell *cell = port_chunk.cell;
+									node.set_function(ExampleFn(ID($$state), {{cell->name, {}}}));
+									for (auto const &conn : cell->connections()) {
+										if (!dm.celltypes.cell_input(cell->type, conn.first))
+											continue;
+										enqueue(DriveChunkPort(cell, conn)).assign_key(cell->name);
+									}
+								}
+								else
+								{
+									node.set_function(ExampleFn(ID($$cell_output), {{port_chunk.port, {}}}));
+									node.append_arg(enqueue(DriveBitMarker(cells(port_chunk.cell), 0)));
+								}
 							} else {
+								node.set_function(ID($$buf));
+
 								DriveSpec driver = dm(DriveSpec(port_chunk));
-								edges.add_edge(i, queue(driver));
-								emit("cell port driver %s <- %s\n", log_signal(spec), log_signal(driver));
+								node.append_arg(enqueue(driver));
 							}
 
 						} else {
 							DriveChunkPort whole_port(port_chunk.cell, port_chunk.port, 0, GetSize(port_chunk.cell->connections().at(port_chunk.port)));
-							edges.add_edge(i, queue(whole_port));
-							emit("port slice %s <- %s\n", log_signal(spec), log_signal(DriveSpec(whole_port)));
+							node.set_function(ID($$buf));
+							node.append_arg(enqueue(whole_port));
 						}
 					} else if (chunk.is_constant()) {
-						emit("constant %s <- %s\n", log_signal(spec), log_const(chunk.constant()));
+						node.set_function(ExampleFn(ID($$const), {{ID(value), chunk.constant()}}));
+
+					} else if (chunk.is_multiple()) {
+						node.set_function(ID($$multi));
+						for (auto const &driver : chunk.multiple().multiple())
+							node.append_arg(enqueue(driver));
 					} else if (chunk.is_marker()) {
 						Cell *cell = cells[chunk.marker().marker];
-						emit("cell %s %s\n", log_id(cell->type), log_id(cell));
+
+						node.set_function(ExampleFn(cell->type, cell->parameters));
 						for (auto const &conn : cell->connections()) {
 							if (!dm.celltypes.cell_input(cell->type, conn.first))
 								continue;
-							emit("  * %s <- %s\n", log_id(conn.first), log_signal(conn.second));
-							edges.add_edge(i, queue(DriveChunkPort(cell, conn)));
+
+							node.append_arg(enqueue(DriveChunkPort(cell, conn)));
 						}
+					} else if (chunk.is_none()) {
+						node.set_function(ID($$undriven));
+
 					} else {
+						log_error("unhandled drivespec: %s\n", log_signal(chunk));
 						log_abort();
 					}
 				} else {
@@ -125,13 +177,75 @@ struct ExampleDtPass : public Pass
 				}
 			}
 
-			topo_sorted_sccs(edges, [&](int *begin, int *end) {
-				emit("scc:");
-				for (int *i = begin; i != end; ++i)
-					emit(" n%d", *i);
-				emit("\n");
-			});
 
+			// Perform topo sort and detect SCCs
+			ExampleGraph::SccAdaptor compute_graph_scc(compute_graph);
+
+
+			std::vector<int> perm;
+			topo_sorted_sccs(compute_graph_scc, [&](int *begin, int *end) {
+				perm.insert(perm.end(), begin, end);
+				if (end > begin + 1)
+				{
+					log_warning("SCC:");
+					for (int *i = begin; i != end; ++i)
+						log(" %d", *i);
+					log("\n");
+				}
+			}, /* sources_first */ true);
+			compute_graph.permute(perm);
+
+
+			// Forward $$buf unless we have a name in the sparse attribute
+			std::vector<int> alias;
+			perm.clear();
+
+			for (int i = 0; i < compute_graph.size(); ++i)
+			{
+				if (compute_graph[i].function().name == ID($$buf) && !compute_graph[i].has_sparse_attr() && compute_graph[i].arg(0).index() < i)
+				{
+
+					alias.push_back(alias[compute_graph[i].arg(0).index()]);
+				}
+				else
+				{
+					alias.push_back(GetSize(perm));
+					perm.push_back(i);
+				}
+			}
+			compute_graph.permute(perm, alias);
+
+			// Dump the compute graph
+			for (int i = 0; i < compute_graph.size(); ++i)
+			{
+				auto ref = compute_graph[i];
+				log("n%d ", i);
+				log("%s", log_id(ref.function().name));
+				for (auto const &param : ref.function().parameters)
+				{
+					if (param.second.empty())
+						log("[%s]", log_id(param.first));
+					else
+						log("[%s=%s]", log_id(param.first), log_const(param.second));
+				}
+				log("(");
+
+				for (int i = 0, end = ref.size(); i != end; ++i)
+				{
+					if (i > 0)
+						log(", ");
+					log("n%d", ref.arg(i).index());
+				}
+				log(")\n");
+				if (ref.has_sparse_attr())
+					log("// wire %s\n", log_id(ref.sparse_attr()));
+				log("// was #%d %s\n", ref.attr(), log_signal(queue[ref.attr()]));
+			}
+
+			for (auto const &key : compute_graph.keys())
+			{
+				log("return %d as %s \n", key.second, log_id(key.first));
+			}
 		}
 		log("Plugin test passed!\n");
 	}