From 850b3a6c29b8b1ffd76ed24493d85ffa721f8109 Mon Sep 17 00:00:00 2001
From: Emily Schmidt <emily@yosyshq.com>
Date: Thu, 25 Jul 2024 12:10:59 +0100
Subject: [PATCH] convert class FunctionalIR to a namespace Functional, rename
 functionalir.h to functional.h, rename functional.h to compute_graph.h

---
 Makefile                                  |   2 +-
 backends/functional/cxx.cc                |  20 +-
 backends/functional/smtlib.cc             |  22 +-
 backends/functional/test_generic.cc       |   4 +-
 kernel/compute_graph.h                    | 403 ++++++++++
 kernel/{functionalir.cc => functional.cc} | 111 ++-
 kernel/functional.h                       | 935 +++++++++++++---------
 kernel/functionalir.h                     | 575 -------------
 kernel/mem.h                              |  11 +-
 kernel/utils.h                            |   9 +
 passes/cmds/example_dt.cc                 |   2 +-
 11 files changed, 1055 insertions(+), 1039 deletions(-)
 create mode 100644 kernel/compute_graph.h
 rename kernel/{functionalir.cc => functional.cc} (90%)
 delete mode 100644 kernel/functionalir.h

diff --git a/Makefile b/Makefile
index 68e6fda4a..e61948cb4 100644
--- a/Makefile
+++ b/Makefile
@@ -640,7 +640,7 @@ $(eval $(call add_include_file,backends/rtlil/rtlil_backend.h))
 OBJS += kernel/driver.o kernel/register.o kernel/rtlil.o kernel/log.o kernel/calc.o kernel/yosys.o
 OBJS += kernel/binding.o
 OBJS += kernel/cellaigs.o kernel/celledges.o kernel/cost.o kernel/satgen.o kernel/scopeinfo.o kernel/qcsat.o kernel/mem.o kernel/ffmerge.o kernel/ff.o kernel/yw.o kernel/json.o kernel/fmt.o kernel/sexpr.o
-OBJS += kernel/drivertools.o kernel/functionalir.o
+OBJS += kernel/drivertools.o kernel/functional.o
 ifeq ($(ENABLE_ZLIB),1)
 OBJS += kernel/fstdata.o
 endif
diff --git a/backends/functional/cxx.cc b/backends/functional/cxx.cc
index 8d53c9e03..a4755e144 100644
--- a/backends/functional/cxx.cc
+++ b/backends/functional/cxx.cc
@@ -18,7 +18,7 @@
  */
 
 #include "kernel/yosys.h"
-#include "kernel/functionalir.h"
+#include "kernel/functional.h"
 #include <ctype.h>
 
 USING_YOSYS_NAMESPACE
@@ -42,7 +42,7 @@ const char *reserved_keywords[] = {
 	nullptr
 };
 
-template<typename Id> struct CxxScope : public FunctionalTools::Scope<Id> {
+template<typename Id> struct CxxScope : public Functional::Scope<Id> {
 	CxxScope() {
 		for(const char **p = reserved_keywords; *p != nullptr; p++)
 			this->reserve(*p);
@@ -53,8 +53,8 @@ template<typename Id> struct CxxScope : public FunctionalTools::Scope<Id> {
 };
 
 struct CxxType {
-	FunctionalIR::Sort sort;
-	CxxType(FunctionalIR::Sort sort) : sort(sort) {}
+	Functional::Sort sort;
+	CxxType(Functional::Sort sort) : sort(sort) {}
 	std::string to_string() const {
 		if(sort.is_memory()) {
 			return stringf("Memory<%d, %d>", sort.addr_width(), sort.data_width());
@@ -66,7 +66,7 @@ struct CxxType {
 	}
 };
 
-using CxxWriter = FunctionalTools::Writer;
+using CxxWriter = Functional::Writer;
 
 struct CxxStruct {
 	std::string name;
@@ -111,8 +111,8 @@ std::string cxx_const(RTLIL::Const const &value) {
 	return ss.str();
 }
 
-template<class NodePrinter> struct CxxPrintVisitor : public FunctionalIR::AbstractVisitor<void> {
-	using Node = FunctionalIR::Node;
+template<class NodePrinter> struct CxxPrintVisitor : public Functional::AbstractVisitor<void> {
+	using Node = Functional::Node;
 	CxxWriter &f;
 	NodePrinter np;
 	CxxStruct &input_struct;
@@ -165,12 +165,12 @@ bool equal_def(RTLIL::Const const &a, RTLIL::Const const &b) {
 }
 
 struct CxxModule {
-	FunctionalIR ir;
+	Functional::IR ir;
 	CxxStruct input_struct, output_struct, state_struct;
 	std::string module_name;
 
 	explicit CxxModule(Module *module) :
-		ir(FunctionalIR::from_module(module)),
+		ir(Functional::IR::from_module(module)),
 		input_struct("Inputs"),
 		output_struct("Outputs"),
 		state_struct("State")
@@ -222,7 +222,7 @@ struct CxxModule {
 		locals.reserve("output");
 		locals.reserve("current_state");
 		locals.reserve("next_state");
-		auto node_name = [&](FunctionalIR::Node n) { return locals(n.id(), n.name()); };
+		auto node_name = [&](Functional::Node n) { return locals(n.id(), n.name()); };
 		CxxPrintVisitor printVisitor(f, node_name, input_struct, state_struct);
 		for (auto node : ir) {
 			f.print("\t{} {} = ", CxxType(node.sort()).to_string(), node_name(node));
diff --git a/backends/functional/smtlib.cc b/backends/functional/smtlib.cc
index 0d2763d32..7fd6fe564 100644
--- a/backends/functional/smtlib.cc
+++ b/backends/functional/smtlib.cc
@@ -17,7 +17,7 @@
  *
  */
 
-#include "kernel/functionalir.h"
+#include "kernel/functional.h"
 #include "kernel/yosys.h"
 #include "kernel/sexpr.h"
 #include <ctype.h>
@@ -42,7 +42,7 @@ const char *reserved_keywords[] = {
 	nullptr
 };
 
-struct SmtScope : public FunctionalTools::Scope<int> {
+struct SmtScope : public Functional::Scope<int> {
 	SmtScope() {
 		for(const char **p = reserved_keywords; *p != nullptr; p++)
 			reserve(*p);
@@ -53,8 +53,8 @@ struct SmtScope : public FunctionalTools::Scope<int> {
 };
 
 struct SmtSort {
-	FunctionalIR::Sort sort;
-	SmtSort(FunctionalIR::Sort sort) : sort(sort) {}
+	Functional::Sort sort;
+	SmtSort(Functional::Sort sort) : sort(sort) {}
 	SExpr to_sexpr() const {
 		if(sort.is_memory()) {
 			return list("Array", list("_", "BitVec", sort.addr_width()), list("_", "BitVec", sort.data_width()));
@@ -116,8 +116,8 @@ std::string smt_const(RTLIL::Const const &c) {
 	return s;
 }
 
-struct SmtPrintVisitor : public FunctionalIR::AbstractVisitor<SExpr> {
-	using Node = FunctionalIR::Node;
+struct SmtPrintVisitor : public Functional::AbstractVisitor<SExpr> {
+	using Node = Functional::Node;
 	std::function<SExpr(Node)> n;
 	SmtStruct &input_struct;
 	SmtStruct &state_struct;
@@ -183,7 +183,7 @@ struct SmtPrintVisitor : public FunctionalIR::AbstractVisitor<SExpr> {
 };
 
 struct SmtModule {
-	FunctionalIR ir;
+	Functional::IR ir;
 	SmtScope scope;
 	std::string name;
 	
@@ -192,7 +192,7 @@ struct SmtModule {
 	SmtStruct state_struct;
 
 	SmtModule(Module *module)
-		: ir(FunctionalIR::from_module(module))
+		: ir(Functional::IR::from_module(module))
 		, scope()
 		, name(scope.unique_name(module->name))
 		, input_struct(scope.unique_name(module->name.str() + "_Inputs"), scope)
@@ -215,11 +215,11 @@ struct SmtModule {
 			list(list("inputs", input_struct.name),
 			     list("state", state_struct.name)),
 			list("Pair", output_struct.name, state_struct.name)));
-		auto inlined = [&](FunctionalIR::Node n) {
-			return n.fn() == FunctionalIR::Fn::constant;
+		auto inlined = [&](Functional::Node n) {
+			return n.fn() == Functional::Fn::constant;
 		};
 		SmtPrintVisitor visitor(input_struct, state_struct);
-		auto node_to_sexpr = [&](FunctionalIR::Node n) -> SExpr {
+		auto node_to_sexpr = [&](Functional::Node n) -> SExpr {
 			if(inlined(n))
 				return n.visit(visitor);
 			else
diff --git a/backends/functional/test_generic.cc b/backends/functional/test_generic.cc
index 5d9349276..83ea09d8d 100644
--- a/backends/functional/test_generic.cc
+++ b/backends/functional/test_generic.cc
@@ -18,7 +18,7 @@
  */
 
 #include "kernel/yosys.h"
-#include "kernel/functionalir.h"
+#include "kernel/functional.h"
 #include <random>
 
 USING_YOSYS_NAMESPACE
@@ -139,7 +139,7 @@ struct FunctionalTestGeneric : public Pass
 
 		for (auto module : design->selected_modules()) {
             log("Dumping module `%s'.\n", module->name.c_str());
-			auto fir = FunctionalIR::from_module(module);
+			auto fir = Functional::IR::from_module(module);
 			for(auto node : fir)
 				std::cout << RTLIL::unescape_id(node.name()) << " = " << node.to_string([](auto n) { return RTLIL::unescape_id(n.name()); }) << "\n";
 			for(auto [name, sort] : fir.outputs())
diff --git a/kernel/compute_graph.h b/kernel/compute_graph.h
new file mode 100644
index 000000000..aeba17f8c
--- /dev/null
+++ b/kernel/compute_graph.h
@@ -0,0 +1,403 @@
+/*
+ *  yosys -- Yosys Open SYnthesis Suite
+ *
+ *  Copyright (C) 2024  Jannis Harder <jix@yosyshq.com> <me@jix.one>
+ *
+ *  Permission to use, copy, modify, and/or distribute this software for any
+ *  purpose with or without fee is hereby granted, provided that the above
+ *  copyright notice and this permission notice appear in all copies.
+ *
+ *  THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ *  WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ *  MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ *  ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ *  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ *  ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ *  OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ *
+ */
+
+#ifndef COMPUTE_GRAPH_H
+#define COMPUTE_GRAPH_H
+
+#include <tuple>
+#include "kernel/yosys.h"
+
+YOSYS_NAMESPACE_BEGIN
+
+template<
+    typename Fn, // Function type (deduplicated across whole graph)
+    typename Attr = std::tuple<>, // Call attributes (present in every node)
+    typename SparseAttr = std::tuple<>, // Sparse call attributes (optional per node)
+    typename Key = std::tuple<> // Stable keys to refer to nodes
+>
+struct ComputeGraph
+{
+    struct Ref;
+private:
+
+    // Functions are deduplicated by assigning unique ids
+    idict<Fn> functions;
+
+    struct Node {
+        int fn_index;
+        int arg_offset;
+        int arg_count;
+        Attr attr;
+
+        Node(int fn_index, Attr &&attr, int arg_offset, int arg_count = 0)
+            : fn_index(fn_index), arg_offset(arg_offset), arg_count(arg_count), attr(std::move(attr)) {}
+
+        Node(int fn_index, Attr const &attr, int arg_offset, int arg_count = 0)
+            : fn_index(fn_index), arg_offset(arg_offset), arg_count(arg_count), attr(attr) {}
+    };
+
+
+    std::vector<Node> nodes;
+    std::vector<int> args;
+    dict<Key, int> keys_;
+    dict<int, SparseAttr> sparse_attrs;
+
+public:
+    template<typename Graph>
+    struct BaseRef
+    {
+    protected:
+        friend struct ComputeGraph;
+        Graph *graph_;
+        int index_;
+        BaseRef(Graph *graph, int index) : graph_(graph), index_(index) {
+            log_assert(index_ >= 0);
+            check();
+        }
+
+        void check() const { log_assert(index_ < graph_->size()); }
+
+        Node const &deref() const { check(); return graph_->nodes[index_]; }
+
+    public:
+        ComputeGraph const &graph() const { return graph_; }
+        int index() const { return index_; }
+
+        int size() const { return deref().arg_count; }
+
+        BaseRef arg(int n) const
+        {
+            Node const &node = deref();
+            log_assert(n >= 0 && n < node.arg_count);
+            return BaseRef(graph_, graph_->args[node.arg_offset + n]);
+        }
+
+        std::vector<int>::const_iterator arg_indices_cbegin() const
+        {
+            Node const &node = deref();
+            return graph_->args.cbegin() + node.arg_offset;
+        }
+
+        std::vector<int>::const_iterator arg_indices_cend() const
+        {
+            Node const &node = deref();
+            return graph_->args.cbegin() + node.arg_offset + node.arg_count;
+        }
+
+        Fn const &function() const { return graph_->functions[deref().fn_index]; }
+        Attr const &attr() const { return deref().attr; }
+
+        bool has_sparse_attr() const { return graph_->sparse_attrs.count(index_); }
+
+        SparseAttr const &sparse_attr() const
+        {
+            auto found = graph_->sparse_attrs.find(index_);
+            log_assert(found != graph_->sparse_attrs.end());
+            return found->second;
+        }
+    };
+
+    using ConstRef = BaseRef<ComputeGraph const>;
+
+    struct Ref : public BaseRef<ComputeGraph>
+    {
+    private:
+        friend struct ComputeGraph;
+        Ref(ComputeGraph *graph, int index) : BaseRef<ComputeGraph>(graph, index) {}
+        Node &deref() const { this->check(); return this->graph_->nodes[this->index_]; }
+
+    public:
+        Ref(BaseRef<ComputeGraph> ref) : Ref(ref.graph_, ref.index_) {}
+
+        void set_function(Fn const &function) const
+        {
+            deref().fn_index = this->graph_->functions(function);
+        }
+
+        Attr &attr() const { return deref().attr; }
+
+        void append_arg(ConstRef arg) const
+        {
+            log_assert(arg.graph_ == this->graph_);
+            append_arg(arg.index());
+        }
+
+        void append_arg(int arg) const
+        {
+            log_assert(arg >= 0 && arg < this->graph_->size());
+            Node &node = deref();
+            if (node.arg_offset + node.arg_count != GetSize(this->graph_->args))
+                move_args(node);
+            this->graph_->args.push_back(arg);
+            node.arg_count++;
+        }
+
+        operator ConstRef() const
+        {
+            return ConstRef(this->graph_, this->index_);
+        }
+
+        SparseAttr &sparse_attr() const
+        {
+            return this->graph_->sparse_attrs[this->index_];
+        }
+
+        void clear_sparse_attr() const
+        {
+            this->graph_->sparse_attrs.erase(this->index_);
+        }
+
+        void assign_key(Key const &key) const
+        {
+            this->graph_->keys_.emplace(key, this->index_);
+        }
+
+    private:
+        void move_args(Node &node) const
+        {
+            auto &args = this->graph_->args;
+            int old_offset = node.arg_offset;
+            node.arg_offset = GetSize(args);
+            for (int i = 0; i != node.arg_count; ++i)
+                args.push_back(args[old_offset + i]);
+        }
+
+    };
+
+    bool has_key(Key const &key) const
+    {
+        return keys_.count(key);
+    }
+
+    dict<Key, int> const &keys() const
+    {
+        return keys_;
+    }
+
+    ConstRef operator()(Key const &key) const
+    {
+        auto it = keys_.find(key);
+        log_assert(it != keys_.end());
+        return (*this)[it->second];
+    }
+
+    Ref operator()(Key const &key)
+    {
+        auto it = keys_.find(key);
+        log_assert(it != keys_.end());
+        return (*this)[it->second];
+    }
+
+    int size() const { return GetSize(nodes); }
+
+    ConstRef operator[](int index) const { return ConstRef(this, index); }
+    Ref operator[](int index) { return Ref(this, index); }
+
+    Ref add(Fn const &function, Attr &&attr)
+    {
+        int index = GetSize(nodes);
+        int fn_index = functions(function);
+        nodes.emplace_back(fn_index, std::move(attr), GetSize(args));
+        return Ref(this, index);
+    }
+
+    Ref add(Fn const &function, Attr const &attr)
+    {
+        int index = GetSize(nodes);
+        int fn_index = functions(function);
+        nodes.emplace_back(fn_index, attr,  GetSize(args));
+        return Ref(this, index);
+    }
+
+    template<typename T>
+    Ref add(Fn const &function, Attr const &attr, T &&args)
+    {
+        Ref added = add(function, attr);
+        for (auto arg : args)
+            added.append_arg(arg);
+        return added;
+    }
+
+    template<typename T>
+    Ref add(Fn const &function, Attr &&attr, T &&args)
+    {
+        Ref added = add(function, std::move(attr));
+        for (auto arg : args)
+            added.append_arg(arg);
+        return added;
+    }
+
+    Ref add(Fn const &function, Attr const &attr, std::initializer_list<Ref> args)
+    {
+        Ref added = add(function, attr);
+        for (auto arg : args)
+            added.append_arg(arg);
+        return added;
+    }
+
+    Ref add(Fn const &function, Attr &&attr, std::initializer_list<Ref> args)
+    {
+        Ref added = add(function, std::move(attr));
+        for (auto arg : args)
+            added.append_arg(arg);
+        return added;
+    }
+
+    template<typename T>
+    Ref add(Fn const &function, Attr const &attr, T begin, T end)
+    {
+        Ref added = add(function, attr);
+        for (; begin != end; ++begin)
+            added.append_arg(*begin);
+        return added;
+    }
+
+    void compact_args()
+    {
+        std::vector<int> new_args;
+        for (auto &node : nodes)
+        {
+            int new_offset = GetSize(new_args);
+            for (int i = 0; i < node.arg_count; i++)
+                new_args.push_back(args[node.arg_offset + i]);
+            node.arg_offset = new_offset;
+        }
+        std::swap(args, new_args);
+    }
+
+    void permute(std::vector<int> const &perm)
+    {
+        log_assert(perm.size() <= nodes.size());
+        std::vector<int> inv_perm;
+        inv_perm.resize(nodes.size(), -1);
+        for (int i = 0; i < GetSize(perm); ++i)
+        {
+            int j = perm[i];
+            log_assert(j >= 0 && j < GetSize(nodes));
+            log_assert(inv_perm[j] == -1);
+            inv_perm[j] = i;
+        }
+        permute(perm, inv_perm);
+    }
+
+    void permute(std::vector<int> const &perm, std::vector<int> const &inv_perm)
+    {
+        log_assert(inv_perm.size() == nodes.size());
+        std::vector<Node> new_nodes;
+        new_nodes.reserve(perm.size());
+        dict<int, SparseAttr> new_sparse_attrs;
+        for (int i : perm)
+        {
+            int j = GetSize(new_nodes);
+            new_nodes.emplace_back(std::move(nodes[i]));
+            auto found = sparse_attrs.find(i);
+            if (found != sparse_attrs.end())
+                new_sparse_attrs.emplace(j, std::move(found->second));
+        }
+
+        std::swap(nodes, new_nodes);
+        std::swap(sparse_attrs, new_sparse_attrs);
+
+        compact_args();
+        for (int &arg : args)
+        {
+            log_assert(arg < GetSize(inv_perm));
+            log_assert(inv_perm[arg] >= 0);
+            arg = inv_perm[arg];
+        }
+
+        for (auto &key : keys_)
+        {
+            log_assert(key.second < GetSize(inv_perm));
+            log_assert(inv_perm[key.second] >= 0);
+            key.second = inv_perm[key.second];
+        }
+    }
+
+    struct SccAdaptor
+    {
+    private:
+        ComputeGraph const &graph_;
+        std::vector<int> indices_;
+    public:
+        SccAdaptor(ComputeGraph const &graph) : graph_(graph)
+        {
+            indices_.resize(graph.size(), -1);
+        }
+
+
+        typedef int node_type;
+
+        struct node_enumerator {
+        private:
+            friend struct SccAdaptor;
+            int current, end;
+            node_enumerator(int current, int end) : current(current), end(end) {}
+
+        public:
+
+            bool finished() const { return current == end; }
+            node_type next() {
+                log_assert(!finished());
+                node_type result = current;
+                ++current;
+                return result;
+            }
+        };
+
+        node_enumerator enumerate_nodes() {
+            return node_enumerator(0, GetSize(indices_));
+        }
+
+
+        struct successor_enumerator {
+        private:
+            friend struct SccAdaptor;
+            std::vector<int>::const_iterator current, end;
+            successor_enumerator(std::vector<int>::const_iterator current, std::vector<int>::const_iterator end) :
+                current(current), end(end) {}
+
+        public:
+            bool finished() const { return current == end; }
+            node_type next() {
+                log_assert(!finished());
+                node_type result = *current;
+                ++current;
+                return result;
+            }
+        };
+
+        successor_enumerator enumerate_successors(int index) const {
+            auto const &ref = graph_[index];
+            return successor_enumerator(ref.arg_indices_cbegin(), ref.arg_indices_cend());
+        }
+
+        int &dfs_index(node_type const &node) { return indices_[node]; }
+
+        std::vector<int> const &dfs_indices() { return indices_; }
+    };
+
+};
+
+
+
+YOSYS_NAMESPACE_END
+
+
+#endif
diff --git a/kernel/functionalir.cc b/kernel/functional.cc
similarity index 90%
rename from kernel/functionalir.cc
rename to kernel/functional.cc
index 223fdaa91..ad507187d 100644
--- a/kernel/functionalir.cc
+++ b/kernel/functional.cc
@@ -17,55 +17,55 @@
  *
  */
 
-#include "kernel/functionalir.h"
-#include <optional>
+#include "kernel/functional.h"
+#include "kernel/topo_scc.h"
 #include "ff.h"
 #include "ffinit.h"
 
 YOSYS_NAMESPACE_BEGIN
+namespace Functional {
 
-const char *FunctionalIR::fn_to_string(FunctionalIR::Fn fn) {
+const char *fn_to_string(Fn fn) {
 	switch(fn) {
-	case FunctionalIR::Fn::invalid: return "invalid";
-	case FunctionalIR::Fn::buf: return "buf";
-	case FunctionalIR::Fn::slice: return "slice";
-	case FunctionalIR::Fn::zero_extend: return "zero_extend";
-	case FunctionalIR::Fn::sign_extend: return "sign_extend";
-	case FunctionalIR::Fn::concat: return "concat";
-	case FunctionalIR::Fn::add: return "add";
-	case FunctionalIR::Fn::sub: return "sub";
-	case FunctionalIR::Fn::mul: return "mul";
-	case FunctionalIR::Fn::unsigned_div: return "unsigned_div";
-	case FunctionalIR::Fn::unsigned_mod: return "unsigned_mod";
-	case FunctionalIR::Fn::bitwise_and: return "bitwise_and";
-	case FunctionalIR::Fn::bitwise_or: return "bitwise_or";
-	case FunctionalIR::Fn::bitwise_xor: return "bitwise_xor";
-	case FunctionalIR::Fn::bitwise_not: return "bitwise_not";
-	case FunctionalIR::Fn::reduce_and: return "reduce_and";
-	case FunctionalIR::Fn::reduce_or: return "reduce_or";
-	case FunctionalIR::Fn::reduce_xor: return "reduce_xor";
-	case FunctionalIR::Fn::unary_minus: return "unary_minus";
-	case FunctionalIR::Fn::equal: return "equal";
-	case FunctionalIR::Fn::not_equal: return "not_equal";
-	case FunctionalIR::Fn::signed_greater_than: return "signed_greater_than";
-	case FunctionalIR::Fn::signed_greater_equal: return "signed_greater_equal";
-	case FunctionalIR::Fn::unsigned_greater_than: return "unsigned_greater_than";
-	case FunctionalIR::Fn::unsigned_greater_equal: return "unsigned_greater_equal";
-	case FunctionalIR::Fn::logical_shift_left: return "logical_shift_left";
-	case FunctionalIR::Fn::logical_shift_right: return "logical_shift_right";
-	case FunctionalIR::Fn::arithmetic_shift_right: return "arithmetic_shift_right";
-	case FunctionalIR::Fn::mux: return "mux";
-	case FunctionalIR::Fn::constant: return "constant";
-	case FunctionalIR::Fn::input: return "input";
-	case FunctionalIR::Fn::state: return "state";
-	case FunctionalIR::Fn::memory_read: return "memory_read";
-	case FunctionalIR::Fn::memory_write: return "memory_write";
+	case Fn::invalid: return "invalid";
+	case Fn::buf: return "buf";
+	case Fn::slice: return "slice";
+	case Fn::zero_extend: return "zero_extend";
+	case Fn::sign_extend: return "sign_extend";
+	case Fn::concat: return "concat";
+	case Fn::add: return "add";
+	case Fn::sub: return "sub";
+	case Fn::mul: return "mul";
+	case Fn::unsigned_div: return "unsigned_div";
+	case Fn::unsigned_mod: return "unsigned_mod";
+	case Fn::bitwise_and: return "bitwise_and";
+	case Fn::bitwise_or: return "bitwise_or";
+	case Fn::bitwise_xor: return "bitwise_xor";
+	case Fn::bitwise_not: return "bitwise_not";
+	case Fn::reduce_and: return "reduce_and";
+	case Fn::reduce_or: return "reduce_or";
+	case Fn::reduce_xor: return "reduce_xor";
+	case Fn::unary_minus: return "unary_minus";
+	case Fn::equal: return "equal";
+	case Fn::not_equal: return "not_equal";
+	case Fn::signed_greater_than: return "signed_greater_than";
+	case Fn::signed_greater_equal: return "signed_greater_equal";
+	case Fn::unsigned_greater_than: return "unsigned_greater_than";
+	case Fn::unsigned_greater_equal: return "unsigned_greater_equal";
+	case Fn::logical_shift_left: return "logical_shift_left";
+	case Fn::logical_shift_right: return "logical_shift_right";
+	case Fn::arithmetic_shift_right: return "arithmetic_shift_right";
+	case Fn::mux: return "mux";
+	case Fn::constant: return "constant";
+	case Fn::input: return "input";
+	case Fn::state: return "state";
+	case Fn::memory_read: return "memory_read";
+	case Fn::memory_write: return "memory_write";
 	}
-	log_error("fn_to_string: unknown FunctionalIR::Fn value %d", (int)fn);
+	log_error("fn_to_string: unknown Functional::Fn value %d", (int)fn);
 }
 
-struct PrintVisitor : FunctionalIR::DefaultVisitor<std::string> {
-	using Node = FunctionalIR::Node;
+struct PrintVisitor : DefaultVisitor<std::string> {
 	std::function<std::string(Node)> np;
 	PrintVisitor(std::function<std::string(Node)> np) : np(np) { }
 	// as a general rule the default handler is good enough iff the only arguments are of type Node
@@ -76,7 +76,7 @@ struct PrintVisitor : FunctionalIR::DefaultVisitor<std::string> {
 	std::string input(Node, IdString name) override { return "input(" + name.str() + ")"; }
 	std::string state(Node, IdString name) override { return "state(" + name.str() + ")"; }
 	std::string default_handler(Node self) override {
-		std::string ret = FunctionalIR::fn_to_string(self.fn());
+		std::string ret = fn_to_string(self.fn());
 		ret += "(";
 		for(size_t i = 0; i < self.arg_count(); i++) {
 			if(i > 0) ret += ", ";
@@ -87,19 +87,18 @@ struct PrintVisitor : FunctionalIR::DefaultVisitor<std::string> {
 	}
 };
 
-std::string FunctionalIR::Node::to_string()
+std::string Node::to_string()
 {
 	return to_string([](Node n) { return RTLIL::unescape_id(n.name()); });
 }
 
-std::string FunctionalIR::Node::to_string(std::function<std::string(Node)> np)
+std::string Node::to_string(std::function<std::string(Node)> np)
 {
 	return visit(PrintVisitor(np));
 }
 
 class CellSimplifier {
-	using Node = FunctionalIR::Node;
-	FunctionalIR::Factory &factory;
+	Factory &factory;
 	Node sign(Node a) {
 		return factory.slice(a, a.width() - 1, 1);
 	}
@@ -138,7 +137,7 @@ public:
 		Node bb = factory.bitwise_and(b, s);
 		return factory.bitwise_or(aa, bb);
 	}
-	CellSimplifier(FunctionalIR::Factory &f) : factory(f) {}
+	CellSimplifier(Factory &f) : factory(f) {}
 private:
 	Node handle_pow(Node a0, Node b, int y_width, bool is_signed) {
 		Node a = factory.extend(a0, y_width, is_signed);
@@ -400,12 +399,11 @@ public:
 };
 
 class FunctionalIRConstruction {
-	using Node = FunctionalIR::Node;
 	std::deque<std::variant<DriveSpec, Cell *>> queue;
 	dict<DriveSpec, Node> graph_nodes;
 	dict<std::pair<Cell *, IdString>, Node> cell_outputs;
 	DriverMap driver_map;
-	FunctionalIR::Factory& factory;
+	Factory& factory;
 	CellSimplifier simplifier;
 	vector<Mem> memories_vector;
 	dict<Cell*, Mem*> memories;
@@ -442,7 +440,7 @@ class FunctionalIRConstruction {
 			return it->second;
 	}
 public:
-	FunctionalIRConstruction(Module *module, FunctionalIR::Factory &f)
+	FunctionalIRConstruction(Module *module, Factory &f)
 		: factory(f)
 		, simplifier(f)
 		, sig_map(module)
@@ -497,7 +495,7 @@ public:
 		// - Since wr port j can only have priority over wr port i if j > i, if we do writes in
 		//   ascending index order the result will obey the priorty relation.
 		vector<Node> read_results;
-		factory.add_state(mem->cell->name, FunctionalIR::Sort(ceil_log2(mem->size), mem->width));
+		factory.add_state(mem->cell->name, Sort(ceil_log2(mem->size), mem->width));
 		factory.set_initial_state(mem->cell->name, MemContents(mem));
 		Node node = factory.current_state(mem->cell->name);
 		for (size_t i = 0; i < mem->wr_ports.size(); i++) {
@@ -542,7 +540,7 @@ public:
 			if (!ff.has_gclk)
 				log_error("The design contains a %s flip-flop at %s. This is not supported by the functional backend. "
 					"Call async2sync or clk2fflogic to avoid this error.\n", log_id(cell->type), log_id(cell));
-			factory.add_state(ff.name, FunctionalIR::Sort(ff.width));
+			factory.add_state(ff.name, Sort(ff.width));
 			Node q_value = factory.current_state(ff.name);
 			factory.suggest_name(q_value, ff.name);
 			factory.update_pending(cell_outputs.at({cell, ID(Q)}), q_value);
@@ -643,8 +641,8 @@ public:
 	}
 };
 
-FunctionalIR FunctionalIR::from_module(Module *module) {
-    FunctionalIR ir;
+IR IR::from_module(Module *module) {
+	IR ir;
     auto factory = ir.factory();
     FunctionalIRConstruction ctor(module, factory);
     ctor.process_queue();
@@ -653,7 +651,7 @@ FunctionalIR FunctionalIR::from_module(Module *module) {
     return ir;
 }
 
-void FunctionalIR::topological_sort() {
+void IR::topological_sort() {
     Graph::SccAdaptor compute_graph_scc(_graph);
     bool scc = false;
     std::vector<int> perm;
@@ -687,7 +685,7 @@ static IdString merge_name(IdString a, IdString b) {
 		return a;
 }
 
-void FunctionalIR::forward_buf() {
+void IR::forward_buf() {
     std::vector<int> perm, alias;
     perm.clear();
 
@@ -734,7 +732,7 @@ static std::string quote_fmt(const char *fmt)
 	return r;
 }
 
-void FunctionalTools::Writer::print_impl(const char *fmt, vector<std::function<void()>> &fns)
+void Writer::print_impl(const char *fmt, vector<std::function<void()>> &fns)
 {
 	size_t next_index = 0;
 	for(const char *p = fmt; *p != 0; p++)
@@ -770,4 +768,5 @@ void FunctionalTools::Writer::print_impl(const char *fmt, vector<std::function<v
 		}
 }
 
+}
 YOSYS_NAMESPACE_END
diff --git a/kernel/functional.h b/kernel/functional.h
index 77e965f64..e0592259c 100644
--- a/kernel/functional.h
+++ b/kernel/functional.h
@@ -1,7 +1,7 @@
 /*
  *  yosys -- Yosys Open SYnthesis Suite
  *
- *  Copyright (C) 2024  Jannis Harder <jix@yosyshq.com> <me@jix.one>
+ *  Copyright (C) 2024  Emily Schmidt <emily@yosyshq.com>
  *
  *  Permission to use, copy, modify, and/or distribute this software for any
  *  purpose with or without fee is hereby granted, provided that the above
@@ -20,384 +20,571 @@
 #ifndef FUNCTIONAL_H
 #define FUNCTIONAL_H
 
-#include <tuple>
 #include "kernel/yosys.h"
+#include "kernel/compute_graph.h"
+#include "kernel/drivertools.h"
+#include "kernel/mem.h"
+#include "kernel/utils.h"
 
+USING_YOSYS_NAMESPACE
 YOSYS_NAMESPACE_BEGIN
 
-template<
-    typename Fn, // Function type (deduplicated across whole graph)
-    typename Attr = std::tuple<>, // Call attributes (present in every node)
-    typename SparseAttr = std::tuple<>, // Sparse call attributes (optional per node)
-    typename Key = std::tuple<> // Stable keys to refer to nodes
->
-struct ComputeGraph
-{
-    struct Ref;
-private:
-
-    // Functions are deduplicated by assigning unique ids
-    idict<Fn> functions;
-
-    struct Node {
-        int fn_index;
-        int arg_offset;
-        int arg_count;
-        Attr attr;
-
-        Node(int fn_index, Attr &&attr, int arg_offset, int arg_count = 0)
-            : fn_index(fn_index), arg_offset(arg_offset), arg_count(arg_count), attr(std::move(attr)) {}
-
-        Node(int fn_index, Attr const &attr, int arg_offset, int arg_count = 0)
-            : fn_index(fn_index), arg_offset(arg_offset), arg_count(arg_count), attr(attr) {}
-    };
-
-
-    std::vector<Node> nodes;
-    std::vector<int> args;
-    dict<Key, int> keys_;
-    dict<int, SparseAttr> sparse_attrs;
-
-public:
-    template<typename Graph>
-    struct BaseRef
-    {
-    protected:
-        friend struct ComputeGraph;
-        Graph *graph_;
-        int index_;
-        BaseRef(Graph *graph, int index) : graph_(graph), index_(index) {
-            log_assert(index_ >= 0);
-            check();
-        }
-
-        void check() const { log_assert(index_ < graph_->size()); }
-
-        Node const &deref() const { check(); return graph_->nodes[index_]; }
-
-    public:
-        ComputeGraph const &graph() const { return graph_; }
-        int index() const { return index_; }
-
-        int size() const { return deref().arg_count; }
-
-        BaseRef arg(int n) const
-        {
-            Node const &node = deref();
-            log_assert(n >= 0 && n < node.arg_count);
-            return BaseRef(graph_, graph_->args[node.arg_offset + n]);
-        }
-
-        std::vector<int>::const_iterator arg_indices_cbegin() const
-        {
-            Node const &node = deref();
-            return graph_->args.cbegin() + node.arg_offset;
-        }
-
-        std::vector<int>::const_iterator arg_indices_cend() const
-        {
-            Node const &node = deref();
-            return graph_->args.cbegin() + node.arg_offset + node.arg_count;
-        }
-
-        Fn const &function() const { return graph_->functions[deref().fn_index]; }
-        Attr const &attr() const { return deref().attr; }
-
-        bool has_sparse_attr() const { return graph_->sparse_attrs.count(index_); }
-
-        SparseAttr const &sparse_attr() const
-        {
-            auto found = graph_->sparse_attrs.find(index_);
-            log_assert(found != graph_->sparse_attrs.end());
-            return found->second;
-        }
-    };
-
-    using ConstRef = BaseRef<ComputeGraph const>;
-
-    struct Ref : public BaseRef<ComputeGraph>
-    {
-    private:
-        friend struct ComputeGraph;
-        Ref(ComputeGraph *graph, int index) : BaseRef<ComputeGraph>(graph, index) {}
-        Node &deref() const { this->check(); return this->graph_->nodes[this->index_]; }
-
-    public:
-        Ref(BaseRef<ComputeGraph> ref) : Ref(ref.graph_, ref.index_) {}
-
-        void set_function(Fn const &function) const
-        {
-            deref().fn_index = this->graph_->functions(function);
-        }
-
-        Attr &attr() const { return deref().attr; }
-
-        void append_arg(ConstRef arg) const
-        {
-            log_assert(arg.graph_ == this->graph_);
-            append_arg(arg.index());
-        }
-
-        void append_arg(int arg) const
-        {
-            log_assert(arg >= 0 && arg < this->graph_->size());
-            Node &node = deref();
-            if (node.arg_offset + node.arg_count != GetSize(this->graph_->args))
-                move_args(node);
-            this->graph_->args.push_back(arg);
-            node.arg_count++;
-        }
-
-        operator ConstRef() const
-        {
-            return ConstRef(this->graph_, this->index_);
-        }
-
-        SparseAttr &sparse_attr() const
-        {
-            return this->graph_->sparse_attrs[this->index_];
-        }
-
-        void clear_sparse_attr() const
-        {
-            this->graph_->sparse_attrs.erase(this->index_);
-        }
-
-        void assign_key(Key const &key) const
-        {
-            this->graph_->keys_.emplace(key, this->index_);
-        }
-
-    private:
-        void move_args(Node &node) const
-        {
-            auto &args = this->graph_->args;
-            int old_offset = node.arg_offset;
-            node.arg_offset = GetSize(args);
-            for (int i = 0; i != node.arg_count; ++i)
-                args.push_back(args[old_offset + i]);
-        }
-
-    };
-
-    bool has_key(Key const &key) const
-    {
-        return keys_.count(key);
-    }
-
-    dict<Key, int> const &keys() const
-    {
-        return keys_;
-    }
-
-    ConstRef operator()(Key const &key) const
-    {
-        auto it = keys_.find(key);
-        log_assert(it != keys_.end());
-        return (*this)[it->second];
-    }
-
-    Ref operator()(Key const &key)
-    {
-        auto it = keys_.find(key);
-        log_assert(it != keys_.end());
-        return (*this)[it->second];
-    }
-
-    int size() const { return GetSize(nodes); }
-
-    ConstRef operator[](int index) const { return ConstRef(this, index); }
-    Ref operator[](int index) { return Ref(this, index); }
-
-    Ref add(Fn const &function, Attr &&attr)
-    {
-        int index = GetSize(nodes);
-        int fn_index = functions(function);
-        nodes.emplace_back(fn_index, std::move(attr), GetSize(args));
-        return Ref(this, index);
-    }
-
-    Ref add(Fn const &function, Attr const &attr)
-    {
-        int index = GetSize(nodes);
-        int fn_index = functions(function);
-        nodes.emplace_back(fn_index, attr,  GetSize(args));
-        return Ref(this, index);
-    }
-
-    template<typename T>
-    Ref add(Fn const &function, Attr const &attr, T &&args)
-    {
-        Ref added = add(function, attr);
-        for (auto arg : args)
-            added.append_arg(arg);
-        return added;
-    }
-
-    template<typename T>
-    Ref add(Fn const &function, Attr &&attr, T &&args)
-    {
-        Ref added = add(function, std::move(attr));
-        for (auto arg : args)
-            added.append_arg(arg);
-        return added;
-    }
-
-    Ref add(Fn const &function, Attr const &attr, std::initializer_list<Ref> args)
-    {
-        Ref added = add(function, attr);
-        for (auto arg : args)
-            added.append_arg(arg);
-        return added;
-    }
-
-    Ref add(Fn const &function, Attr &&attr, std::initializer_list<Ref> args)
-    {
-        Ref added = add(function, std::move(attr));
-        for (auto arg : args)
-            added.append_arg(arg);
-        return added;
-    }
-
-    template<typename T>
-    Ref add(Fn const &function, Attr const &attr, T begin, T end)
-    {
-        Ref added = add(function, attr);
-        for (; begin != end; ++begin)
-            added.append_arg(*begin);
-        return added;
-    }
-
-    void compact_args()
-    {
-        std::vector<int> new_args;
-        for (auto &node : nodes)
-        {
-            int new_offset = GetSize(new_args);
-            for (int i = 0; i < node.arg_count; i++)
-                new_args.push_back(args[node.arg_offset + i]);
-            node.arg_offset = new_offset;
-        }
-        std::swap(args, new_args);
-    }
-
-    void permute(std::vector<int> const &perm)
-    {
-        log_assert(perm.size() <= nodes.size());
-        std::vector<int> inv_perm;
-        inv_perm.resize(nodes.size(), -1);
-        for (int i = 0; i < GetSize(perm); ++i)
-        {
-            int j = perm[i];
-            log_assert(j >= 0 && j < GetSize(nodes));
-            log_assert(inv_perm[j] == -1);
-            inv_perm[j] = i;
-        }
-        permute(perm, inv_perm);
-    }
-
-    void permute(std::vector<int> const &perm, std::vector<int> const &inv_perm)
-    {
-        log_assert(inv_perm.size() == nodes.size());
-        std::vector<Node> new_nodes;
-        new_nodes.reserve(perm.size());
-        dict<int, SparseAttr> new_sparse_attrs;
-        for (int i : perm)
-        {
-            int j = GetSize(new_nodes);
-            new_nodes.emplace_back(std::move(nodes[i]));
-            auto found = sparse_attrs.find(i);
-            if (found != sparse_attrs.end())
-                new_sparse_attrs.emplace(j, std::move(found->second));
-        }
-
-        std::swap(nodes, new_nodes);
-        std::swap(sparse_attrs, new_sparse_attrs);
-
-        compact_args();
-        for (int &arg : args)
-        {
-            log_assert(arg < GetSize(inv_perm));
-            log_assert(inv_perm[arg] >= 0);
-            arg = inv_perm[arg];
-        }
-
-        for (auto &key : keys_)
-        {
-            log_assert(key.second < GetSize(inv_perm));
-            log_assert(inv_perm[key.second] >= 0);
-            key.second = inv_perm[key.second];
-        }
-    }
-
-    struct SccAdaptor
-    {
-    private:
-        ComputeGraph const &graph_;
-        std::vector<int> indices_;
-    public:
-        SccAdaptor(ComputeGraph const &graph) : graph_(graph)
-        {
-            indices_.resize(graph.size(), -1);
-        }
-
-
-        typedef int node_type;
-
-        struct node_enumerator {
-        private:
-            friend struct SccAdaptor;
-            int current, end;
-            node_enumerator(int current, int end) : current(current), end(end) {}
-
-        public:
-
-            bool finished() const { return current == end; }
-            node_type next() {
-                log_assert(!finished());
-                node_type result = current;
-                ++current;
-                return result;
-            }
-        };
-
-        node_enumerator enumerate_nodes() {
-            return node_enumerator(0, GetSize(indices_));
-        }
-
-
-        struct successor_enumerator {
-        private:
-            friend struct SccAdaptor;
-            std::vector<int>::const_iterator current, end;
-            successor_enumerator(std::vector<int>::const_iterator current, std::vector<int>::const_iterator end) :
-                current(current), end(end) {}
-
-        public:
-            bool finished() const { return current == end; }
-            node_type next() {
-                log_assert(!finished());
-                node_type result = *current;
-                ++current;
-                return result;
-            }
-        };
-
-        successor_enumerator enumerate_successors(int index) const {
-            auto const &ref = graph_[index];
-            return successor_enumerator(ref.arg_indices_cbegin(), ref.arg_indices_cend());
-        }
-
-        int &dfs_index(node_type const &node) { return indices_[node]; }
-
-        std::vector<int> const &dfs_indices() { return indices_; }
-    };
-
-};
-
-
+namespace Functional {
+	// each function is documented with a short pseudocode declaration or definition
+	// standard C/Verilog operators are used to describe the result
+	// 
+	// the types used in this are:
+	// - bit[N]: a bitvector of N bits
+	//   bit[N] can be indicated as signed or unsigned. this is not tracked by the functional backend
+	//   but is meant to indicate how the value is interpreted
+	//   if a bit[N] is marked as neither signed nor unsigned, this means the result should be valid with *either* interpretation
+	// - memory[N, M]: a memory with N address and M data bits
+	// - int: C++ int
+	// - Const[N]: yosys RTLIL::Const (with size() == N)
+	// - IdString: yosys IdString
+	// - any: used in documentation to indicate that the type is unconstrained
+	//
+	// nodes in the functional backend are either of type bit[N] or memory[N,M] (for some N, M: int)
+	// additionally, they can carry a constant of type int, Const[N] or IdString
+	// each node has a 'sort' field that stores the type of the node
+	// slice, zero_extend, sign_extend use the type field to store out_width
+	enum class Fn {
+		// invalid() = known-invalid/shouldn't happen value
+		// TODO: maybe remove this and use e.g. std::optional instead?
+		invalid,
+		// buf(a: any): any = a
+		// no-op operation
+		// when constructing the compute graph we generate invalid buf() nodes as a placeholder
+		// and later insert the argument
+		buf,
+		// slice(a: bit[in_width], offset: int, out_width: int): bit[out_width] = a[offset +: out_width]
+		// required: offset + out_width <= in_width
+		slice,
+		// zero_extend(a: unsigned bit[in_width], out_width: int): unsigned bit[out_width] = a (zero extended)
+		// required: out_width > in_width
+		zero_extend,
+		// sign_extend(a: signed bit[in_width], out_width: int): signed bit[out_width] = a (sign extended)
+		// required: out_width > in_width
+		sign_extend,
+		// concat(a: bit[N], b: bit[M]): bit[N+M] = {b, a} (verilog syntax)
+		// concatenates two bitvectors, with a in the least significant position and b in the more significant position
+		concat,
+		// add(a: bit[N], b: bit[N]): bit[N] = a + b
+		add,
+		// sub(a: bit[N], b: bit[N]): bit[N] = a - b
+		sub,
+		// mul(a: bit[N], b: bit[N]): bit[N] = a * b
+		mul,
+		// unsigned_div(a: unsigned bit[N], b: unsigned bit[N]): bit[N] = a / b
+		unsigned_div,
+		// unsigned_mod(a: signed bit[N], b: signed bit[N]): bit[N] = a % b
+		unsigned_mod,
+		// bitwise_and(a: bit[N], b: bit[N]): bit[N] = a & b
+		bitwise_and,
+		// bitwise_or(a: bit[N], b: bit[N]): bit[N] = a | b
+		bitwise_or,
+		// bitwise_xor(a: bit[N], b: bit[N]): bit[N] = a ^ b
+		bitwise_xor,
+		// bitwise_not(a: bit[N]): bit[N] = ~a
+		bitwise_not,
+		// reduce_and(a: bit[N]): bit[1] = &a
+		reduce_and,
+		// reduce_or(a: bit[N]): bit[1] = |a
+		reduce_or,
+		// reduce_xor(a: bit[N]): bit[1] = ^a
+		reduce_xor,
+		// unary_minus(a: bit[N]): bit[N] = -a
+		unary_minus,
+		// equal(a: bit[N], b: bit[N]): bit[1] = (a == b)
+		equal,
+		// not_equal(a: bit[N], b: bit[N]): bit[1] = (a != b)
+		not_equal,
+		// signed_greater_than(a: signed bit[N], b: signed bit[N]): bit[1] = (a > b)
+		signed_greater_than,
+		// signed_greater_equal(a: signed bit[N], b: signed bit[N]): bit[1] = (a >= b)
+		signed_greater_equal,
+		// unsigned_greater_than(a: unsigned bit[N], b: unsigned bit[N]): bit[1] = (a > b)
+		unsigned_greater_than,
+		// unsigned_greater_equal(a: unsigned bit[N], b: unsigned bit[N]): bit[1] = (a >= b)
+		unsigned_greater_equal,
+		// logical_shift_left(a: bit[N], b: unsigned bit[M]): bit[N] = a << b
+		// required: M == clog2(N)
+		logical_shift_left,
+		// logical_shift_right(a: unsigned bit[N], b: unsigned bit[M]): unsigned bit[N] = a >> b
+		// required: M == clog2(N)
+		logical_shift_right,
+		// arithmetic_shift_right(a: signed bit[N], b: unsigned bit[M]): signed bit[N] = a >> b
+		// required: M == clog2(N)
+		arithmetic_shift_right,
+		// mux(a: bit[N], b: bit[N], s: bit[1]): bit[N] = s ? b : a
+		mux,
+		// constant(a: Const[N]): bit[N] = a
+		constant,
+		// input(a: IdString): any
+		// returns the current value of the input with the specified name
+		input,
+		// state(a: IdString): any
+		// returns the current value of the state variable with the specified name
+		state,
+		// memory_read(memory: memory[addr_width, data_width], addr: bit[addr_width]): bit[data_width] = memory[addr]
+		memory_read,
+		// memory_write(memory: memory[addr_width, data_width], addr: bit[addr_width], data: bit[data_width]): memory[addr_width, data_width]
+		// returns a copy of `memory` but with the value at `addr` changed to `data`
+		memory_write
+	};
+	// returns the name of a Fn value, as a string literal
+	const char *fn_to_string(Fn);
+	// Sort represents the sort or type of a node
+	// currently the only two types are signal/bit and memory
+	class Sort {
+		std::variant<int, std::pair<int, int>> _v;
+	public:
+		explicit Sort(int width) : _v(width) { }
+		Sort(int addr_width, int data_width) : _v(std::make_pair(addr_width, data_width)) { }
+		bool is_signal() const { return _v.index() == 0; }
+		bool is_memory() const { return _v.index() == 1; }
+		// returns the width of a bitvector type, errors out for other types
+		int width() const { return std::get<0>(_v); }
+		// returns the address width of a bitvector type, errors out for other types
+		int addr_width() const { return std::get<1>(_v).first; }
+		// returns the data width of a bitvector type, errors out for other types
+		int data_width() const { return std::get<1>(_v).second; }
+		bool operator==(Sort const& other) const { return _v == other._v; }
+		unsigned int hash() const { return mkhash(_v); }
+	};
+	class Factory;
+	class Node;
+	class IR {
+		friend class Factory;
+		friend class Node;
+		// one NodeData is stored per Node, containing the function and non-node arguments
+		// note that NodeData is deduplicated by ComputeGraph
+		class NodeData {
+			Fn _fn;
+			std::variant<
+				std::monostate,
+				RTLIL::Const,
+				IdString,
+				int
+			> _extra;
+		public:
+			NodeData() : _fn(Fn::invalid) {}
+			NodeData(Fn fn) : _fn(fn) {}
+			template<class T> NodeData(Fn fn, T &&extra) : _fn(fn), _extra(std::forward<T>(extra)) {}
+			Fn fn() const { return _fn; }
+			const RTLIL::Const &as_const() const { return std::get<RTLIL::Const>(_extra); }
+			IdString as_idstring() const { return std::get<IdString>(_extra); }
+			int as_int() const { return std::get<int>(_extra); }
+			int hash() const {
+				return mkhash((unsigned int) _fn, mkhash(_extra));
+			}
+			bool operator==(NodeData const &other) const {
+				return _fn == other._fn && _extra == other._extra;
+			}
+		};
+		// Attr contains all the information about a note that should not be deduplicated
+		struct Attr {
+			Sort sort;
+		};
+		// our specialised version of ComputeGraph
+		// the sparse_attr IdString stores a naming suggestion, retrieved with name()
+		// the key is currently used to identify the nodes that represent output and next state values
+		// the bool is true for next state values
+		using Graph = ComputeGraph<NodeData, Attr, IdString, std::pair<IdString, bool>>;
+		Graph _graph;
+		dict<IdString, Sort> _input_sorts;
+		dict<IdString, Sort> _output_sorts;
+		dict<IdString, Sort> _state_sorts;
+		dict<IdString, RTLIL::Const> _initial_state_signal;
+		dict<IdString, MemContents> _initial_state_memory;
+	public:
+		static IR from_module(Module *module);
+		Factory factory();
+		int size() const { return _graph.size(); }
+		Node operator[](int i);
+		void topological_sort();
+		void forward_buf();
+		dict<IdString, Sort> inputs() const { return _input_sorts; }
+		dict<IdString, Sort> outputs() const { return _output_sorts; }
+		dict<IdString, Sort> state() const { return _state_sorts; }
+		RTLIL::Const  const &get_initial_state_signal(IdString name) { return _initial_state_signal.at(name); }
+		MemContents const &get_initial_state_memory(IdString name) { return _initial_state_memory.at(name); }
+		Node get_output_node(IdString name);
+		Node get_state_next_node(IdString name);
+		class iterator {
+			friend class IR;
+			IR *_ir;
+			int _index;
+			iterator(IR *ir, int index) : _ir(ir), _index(index) {}
+		public:
+			using iterator_category = std::input_iterator_tag;
+			using value_type = Node;
+			using pointer = arrow_proxy<Node>;
+			using reference = Node;
+			using difference_type = ptrdiff_t;
+			Node operator*();
+			iterator &operator++() { _index++; return *this; }
+			bool operator!=(iterator const &other) const { return _ir != other._ir || _index != other._index; }
+			bool operator==(iterator const &other) const { return !(*this != other); }
+			pointer operator->();
+			// TODO: implement operator-> using the arrow_proxy class currently in mem.h
+		};
+		iterator begin() { return iterator(this, 0); }
+		iterator end() { return iterator(this, _graph.size()); }
+	};
+	// Node is an immutable reference to a FunctionalIR node
+	class Node {
+		friend class Factory;
+		friend class IR;
+		IR::Graph::ConstRef _ref;
+		explicit Node(IR::Graph::ConstRef ref) : _ref(ref) { }
+		explicit operator IR::Graph::ConstRef() { return _ref; }
+	public:
+		// the node's index. may change if nodes are added or removed
+		int id() const { return _ref.index(); }
+		// a name suggestion for the node, which need not be unique
+		IdString name() const {
+			if(_ref.has_sparse_attr())
+				return _ref.sparse_attr();
+			else
+				return std::string("\\n") + std::to_string(id());
+		}
+		Fn fn() const { return _ref.function().fn(); }
+		Sort sort() const { return _ref.attr().sort; }
+		// returns the width of a bitvector node, errors out for other nodes
+		int width() const { return sort().width(); }
+		size_t arg_count() const { return _ref.size(); }
+		Node arg(int n) const { return Node(_ref.arg(n)); }
+		// visit calls the appropriate visitor method depending on the type of the node
+		template<class Visitor> auto visit(Visitor v) const
+		{
+			// currently templated but could be switched to AbstractVisitor &
+			switch(_ref.function().fn()) {
+			case Fn::invalid: log_error("invalid node in visit"); break;
+			case Fn::buf: return v.buf(*this, arg(0)); break;
+			case Fn::slice: return v.slice(*this, arg(0), _ref.function().as_int(), sort().width()); break;
+			case Fn::zero_extend: return v.zero_extend(*this, arg(0), width()); break;
+			case Fn::sign_extend: return v.sign_extend(*this, arg(0), width()); break;
+			case Fn::concat: return v.concat(*this, arg(0), arg(1)); break;
+			case Fn::add: return v.add(*this, arg(0), arg(1)); break;
+			case Fn::sub: return v.sub(*this, arg(0), arg(1)); break;
+			case Fn::mul: return v.mul(*this, arg(0), arg(1)); break;
+			case Fn::unsigned_div: return v.unsigned_div(*this, arg(0), arg(1)); break;
+			case Fn::unsigned_mod: return v.unsigned_mod(*this, arg(0), arg(1)); break;
+			case Fn::bitwise_and: return v.bitwise_and(*this, arg(0), arg(1)); break;
+			case Fn::bitwise_or: return v.bitwise_or(*this, arg(0), arg(1)); break;
+			case Fn::bitwise_xor: return v.bitwise_xor(*this, arg(0), arg(1)); break;
+			case Fn::bitwise_not: return v.bitwise_not(*this, arg(0)); break;
+			case Fn::unary_minus: return v.unary_minus(*this, arg(0)); break;
+			case Fn::reduce_and: return v.reduce_and(*this, arg(0)); break;
+			case Fn::reduce_or: return v.reduce_or(*this, arg(0)); break;
+			case Fn::reduce_xor: return v.reduce_xor(*this, arg(0)); break;
+			case Fn::equal: return v.equal(*this, arg(0), arg(1)); break;
+			case Fn::not_equal: return v.not_equal(*this, arg(0), arg(1)); break;
+			case Fn::signed_greater_than: return v.signed_greater_than(*this, arg(0), arg(1)); break; 
+			case Fn::signed_greater_equal: return v.signed_greater_equal(*this, arg(0), arg(1)); break;
+			case Fn::unsigned_greater_than: return v.unsigned_greater_than(*this, arg(0), arg(1)); break; 
+			case Fn::unsigned_greater_equal: return v.unsigned_greater_equal(*this, arg(0), arg(1)); break;
+			case Fn::logical_shift_left: return v.logical_shift_left(*this, arg(0), arg(1)); break;
+			case Fn::logical_shift_right: return v.logical_shift_right(*this, arg(0), arg(1)); break;
+			case Fn::arithmetic_shift_right: return v.arithmetic_shift_right(*this, arg(0), arg(1)); break;
+			case Fn::mux: return v.mux(*this, arg(0), arg(1), arg(2)); break;
+			case Fn::constant: return v.constant(*this, _ref.function().as_const()); break;
+			case Fn::input: return v.input(*this, _ref.function().as_idstring()); break;
+			case Fn::state: return v.state(*this, _ref.function().as_idstring()); break;
+			case Fn::memory_read: return v.memory_read(*this, arg(0), arg(1)); break;
+			case Fn::memory_write: return v.memory_write(*this, arg(0), arg(1), arg(2)); break;
+			}
+		}
+		std::string to_string();
+		std::string to_string(std::function<std::string(Node)>);
+	};
+	inline Node IR::operator[](int i) { return Node(_graph[i]); }
+	inline Node IR::get_output_node(IdString name) { return Node(_graph({name, false})); }
+	inline Node IR::get_state_next_node(IdString name) { return Node(_graph({name, true})); }
+	inline Node IR::iterator::operator*() { return Node(_ir->_graph[_index]); }
+	inline arrow_proxy<Node> IR::iterator::operator->() { return arrow_proxy<Node>(**this); }
+	// AbstractVisitor provides an abstract base class for visitors
+	template<class T> struct AbstractVisitor {
+		virtual T buf(Node self, Node n) = 0;
+		virtual T slice(Node self, Node a, int offset, int out_width) = 0;
+		virtual T zero_extend(Node self, Node a, int out_width) = 0;
+		virtual T sign_extend(Node self, Node a, int out_width) = 0;
+		virtual T concat(Node self, Node a, Node b) = 0;
+		virtual T add(Node self, Node a, Node b) = 0;
+		virtual T sub(Node self, Node a, Node b) = 0;
+		virtual T mul(Node self, Node a, Node b) = 0;
+		virtual T unsigned_div(Node self, Node a, Node b) = 0;
+		virtual T unsigned_mod(Node self, Node a, Node b) = 0;
+		virtual T bitwise_and(Node self, Node a, Node b) = 0;
+		virtual T bitwise_or(Node self, Node a, Node b) = 0;
+		virtual T bitwise_xor(Node self, Node a, Node b) = 0;
+		virtual T bitwise_not(Node self, Node a) = 0;
+		virtual T unary_minus(Node self, Node a) = 0;
+		virtual T reduce_and(Node self, Node a) = 0;
+		virtual T reduce_or(Node self, Node a) = 0;
+		virtual T reduce_xor(Node self, Node a) = 0;
+		virtual T equal(Node self, Node a, Node b) = 0;
+		virtual T not_equal(Node self, Node a, Node b) = 0;
+		virtual T signed_greater_than(Node self, Node a, Node b) = 0;
+		virtual T signed_greater_equal(Node self, Node a, Node b) = 0;
+		virtual T unsigned_greater_than(Node self, Node a, Node b) = 0;
+		virtual T unsigned_greater_equal(Node self, Node a, Node b) = 0;
+		virtual T logical_shift_left(Node self, Node a, Node b) = 0;
+		virtual T logical_shift_right(Node self, Node a, Node b) = 0;
+		virtual T arithmetic_shift_right(Node self, Node a, Node b) = 0;
+		virtual T mux(Node self, Node a, Node b, Node s) = 0;
+		virtual T constant(Node self, RTLIL::Const const & value) = 0;
+		virtual T input(Node self, IdString name) = 0;
+		virtual T state(Node self, IdString name) = 0;
+		virtual T memory_read(Node self, Node mem, Node addr) = 0;
+		virtual T memory_write(Node self, Node mem, Node addr, Node data) = 0;
+	};
+	// DefaultVisitor provides defaults for all visitor methods which just calls default_handler
+	template<class T> struct DefaultVisitor : public AbstractVisitor<T> {
+		virtual T default_handler(Node self) = 0;
+		T buf(Node self, Node) override { return default_handler(self); }
+		T slice(Node self, Node, int, int) override { return default_handler(self); }
+		T zero_extend(Node self, Node, int) override { return default_handler(self); }
+		T sign_extend(Node self, Node, int) override { return default_handler(self); }
+		T concat(Node self, Node, Node) override { return default_handler(self); }
+		T add(Node self, Node, Node) override { return default_handler(self); }
+		T sub(Node self, Node, Node) override { return default_handler(self); }
+		T mul(Node self, Node, Node) override { return default_handler(self); }
+		T unsigned_div(Node self, Node, Node) override { return default_handler(self); }
+		T unsigned_mod(Node self, Node, Node) override { return default_handler(self); }
+		T bitwise_and(Node self, Node, Node) override { return default_handler(self); }
+		T bitwise_or(Node self, Node, Node) override { return default_handler(self); }
+		T bitwise_xor(Node self, Node, Node) override { return default_handler(self); }
+		T bitwise_not(Node self, Node) override { return default_handler(self); }
+		T unary_minus(Node self, Node) override { return default_handler(self); }
+		T reduce_and(Node self, Node) override { return default_handler(self); }
+		T reduce_or(Node self, Node) override { return default_handler(self); }
+		T reduce_xor(Node self, Node) override { return default_handler(self); }
+		T equal(Node self, Node, Node) override { return default_handler(self); }
+		T not_equal(Node self, Node, Node) override { return default_handler(self); }
+		T signed_greater_than(Node self, Node, Node) override { return default_handler(self); }
+		T signed_greater_equal(Node self, Node, Node) override { return default_handler(self); }
+		T unsigned_greater_than(Node self, Node, Node) override { return default_handler(self); }
+		T unsigned_greater_equal(Node self, Node, Node) override { return default_handler(self); }
+		T logical_shift_left(Node self, Node, Node) override { return default_handler(self); }
+		T logical_shift_right(Node self, Node, Node) override { return default_handler(self); }
+		T arithmetic_shift_right(Node self, Node, Node) override { return default_handler(self); }
+		T mux(Node self, Node, Node, Node) override { return default_handler(self); }
+		T constant(Node self, RTLIL::Const const &) override { return default_handler(self); }
+		T input(Node self, IdString) override { return default_handler(self); }
+		T state(Node self, IdString) override { return default_handler(self); }
+		T memory_read(Node self, Node, Node) override { return default_handler(self); }
+		T memory_write(Node self, Node, Node, Node) override { return default_handler(self); }
+	};
+	// a factory is used to modify a FunctionalIR. it creates new nodes and allows for some modification of existing nodes.
+	class Factory {
+		friend class IR;
+		IR &_ir;
+		explicit Factory(IR &ir) : _ir(ir) {}
+		Node add(IR::NodeData &&fn, Sort &&sort, std::initializer_list<Node> args) {
+			log_assert(!sort.is_signal() || sort.width() > 0);
+			log_assert(!sort.is_memory() || sort.addr_width() > 0 && sort.data_width() > 0);
+			IR::Graph::Ref ref = _ir._graph.add(std::move(fn), {std::move(sort)});
+			for (auto arg : args)
+				ref.append_arg(IR::Graph::ConstRef(arg));
+			return Node(ref);
+		}
+		IR::Graph::Ref mutate(Node n) {
+			return _ir._graph[n._ref.index()];
+		}
+		void check_basic_binary(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && a.sort() == b.sort()); }
+		void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal() && b.width() == ceil_log2(a.width())); }
+		void check_unary(Node const &a) { log_assert(a.sort().is_signal()); }
+	public:
+		Node slice(Node a, int offset, int out_width) {
+			log_assert(a.sort().is_signal() && offset + out_width <= a.sort().width());
+			if(offset == 0 && out_width == a.width())
+				return a;
+			return add(IR::NodeData(Fn::slice, offset), Sort(out_width), {a});
+		}
+		// extend will either extend or truncate the provided value to reach the desired width
+		Node extend(Node a, int out_width, bool is_signed) {
+			int in_width = a.sort().width();
+			log_assert(a.sort().is_signal());
+			if(in_width == out_width)
+				return a;
+			if(in_width > out_width)
+				return slice(a, 0, out_width);
+			if(is_signed)
+				return add(Fn::sign_extend, Sort(out_width), {a});
+			else
+				return add(Fn::zero_extend, Sort(out_width), {a});
+		}
+		Node concat(Node a, Node b) {
+			log_assert(a.sort().is_signal() && b.sort().is_signal());
+			return add(Fn::concat, Sort(a.sort().width() + b.sort().width()), {a, b});
+		}
+		Node add(Node a, Node b) { check_basic_binary(a, b); return add(Fn::add, a.sort(), {a, b}); }
+		Node sub(Node a, Node b) { check_basic_binary(a, b); return add(Fn::sub, a.sort(), {a, b}); }
+		Node mul(Node a, Node b) { check_basic_binary(a, b); return add(Fn::mul, a.sort(), {a, b}); }
+		Node unsigned_div(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_div, a.sort(), {a, b}); }
+		Node unsigned_mod(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_mod, a.sort(), {a, b}); }
+		Node bitwise_and(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_and, a.sort(), {a, b}); }
+		Node bitwise_or(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_or, a.sort(), {a, b}); }
+		Node bitwise_xor(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_xor, a.sort(), {a, b}); }
+		Node bitwise_not(Node a) { check_unary(a); return add(Fn::bitwise_not, a.sort(), {a}); }
+		Node unary_minus(Node a) { check_unary(a); return add(Fn::unary_minus, a.sort(), {a}); }
+		Node reduce_and(Node a) {
+			check_unary(a);
+			if(a.width() == 1)
+				return a;
+			return add(Fn::reduce_and, Sort(1), {a});
+		}
+		Node reduce_or(Node a) {
+			check_unary(a);
+			if(a.width() == 1)
+				return a;
+			return add(Fn::reduce_or, Sort(1), {a});
+		}
+		Node reduce_xor(Node a) { 
+			check_unary(a);
+			if(a.width() == 1)
+				return a;
+			return add(Fn::reduce_xor, Sort(1), {a});
+		}
+		Node equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::equal, Sort(1), {a, b}); }
+		Node not_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::not_equal, Sort(1), {a, b}); }
+		Node signed_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_than, Sort(1), {a, b}); }
+		Node signed_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_equal, Sort(1), {a, b}); }
+		Node unsigned_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_than, Sort(1), {a, b}); }
+		Node unsigned_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_equal, Sort(1), {a, b}); }
+		Node logical_shift_left(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_left, a.sort(), {a, b}); }
+		Node logical_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_right, a.sort(), {a, b}); }
+		Node arithmetic_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::arithmetic_shift_right, a.sort(), {a, b}); }
+		Node mux(Node a, Node b, Node s) {
+			log_assert(a.sort().is_signal() && a.sort() == b.sort() && s.sort() == Sort(1));
+			return add(Fn::mux, a.sort(), {a, b, s});
+		}
+		Node memory_read(Node mem, Node addr) {
+			log_assert(mem.sort().is_memory() && addr.sort().is_signal() && mem.sort().addr_width() == addr.sort().width());
+			return add(Fn::memory_read, Sort(mem.sort().data_width()), {mem, addr});
+		}
+		Node memory_write(Node mem, Node addr, Node data) {
+			log_assert(mem.sort().is_memory() && addr.sort().is_signal() && data.sort().is_signal() &&
+				mem.sort().addr_width() == addr.sort().width() && mem.sort().data_width() == data.sort().width());
+			return add(Fn::memory_write, mem.sort(), {mem, addr, data});
+		}
+		Node constant(RTLIL::Const value) {
+			return add(IR::NodeData(Fn::constant, std::move(value)), Sort(value.size()), {});
+		}
+		Node create_pending(int width) {
+			return add(Fn::buf, Sort(width), {});
+		}
+		void update_pending(Node node, Node value) {
+			log_assert(node._ref.function() == Fn::buf && node._ref.size() == 0);
+			log_assert(node.sort() == value.sort());
+			mutate(node).append_arg(value._ref);
+		}
+		void add_input(IdString name, int width) {
+			auto [it, inserted] = _ir._input_sorts.emplace(name, Sort(width));
+			if (!inserted) log_error("input `%s` was re-defined", name.c_str());
+		}
+		void add_output(IdString name, int width) {
+			auto [it, inserted] = _ir._output_sorts.emplace(name, Sort(width));
+			if (!inserted) log_error("output `%s` was re-defined", name.c_str());
+		}
+		void add_state(IdString name, Sort sort) {
+			auto [it, inserted] = _ir._state_sorts.emplace(name, sort);
+			if (!inserted) log_error("state `%s` was re-defined", name.c_str());
+		}
+		Node input(IdString name) {
+			return add(IR::NodeData(Fn::input, name), Sort(_ir._input_sorts.at(name)), {});
+		}
+		Node current_state(IdString name) {
+			return add(IR::NodeData(Fn::state, name), Sort(_ir._state_sorts.at(name)), {});
+		}
+		void set_output(IdString output, Node value) {
+			log_assert(_ir._output_sorts.at(output) == value.sort());
+			mutate(value).assign_key({output, false});
+		}
+		void set_initial_state(IdString state, RTLIL::Const value) {
+			Sort &sort = _ir._state_sorts.at(state);
+			value.extu(sort.width());
+			_ir._initial_state_signal.emplace(state, std::move(value));
+		}
+		void set_initial_state(IdString state, MemContents value) {
+			log_assert(Sort(value.addr_width(), value.data_width()) == _ir._state_sorts.at(state));
+			_ir._initial_state_memory.emplace(state, std::move(value));
+		}
+		void set_next_state(IdString state, Node value) {
+			log_assert(_ir._state_sorts.at(state) == value.sort());
+			mutate(value).assign_key({state, true});
+		}
+		void suggest_name(Node node, IdString name) {
+			mutate(node).sparse_attr() = name;
+		}
+	};
+	inline Factory IR::factory() { return Factory(*this); }
+	template<class Id> class Scope {
+	protected:
+		char substitution_character = '_';
+		virtual bool is_character_legal(char) = 0;
+	private:
+		pool<std::string> _used_names;
+		dict<Id, std::string> _by_id;
+	public:
+		void reserve(std::string name) {
+			_used_names.insert(std::move(name));
+		}
+		std::string unique_name(IdString suggestion) {
+			std::string str = RTLIL::unescape_id(suggestion);
+			for(size_t i = 0; i < str.size(); i++)
+				if(!is_character_legal(str[i]))
+					str[i] = substitution_character;
+			if(_used_names.count(str) == 0) {
+				_used_names.insert(str);
+				return str;
+			}
+			for (int idx = 0 ; ; idx++){
+				std::string suffixed = str + "_" + std::to_string(idx);
+				if(_used_names.count(suffixed) == 0) {
+					_used_names.insert(suffixed);
+					return suffixed;
+				}
+			}
+		}
+		std::string operator()(Id id, IdString suggestion) {
+			auto it = _by_id.find(id);
+			if(it != _by_id.end())
+				return it->second;
+			std::string str = unique_name(suggestion);
+			_by_id.insert({id, str});
+			return str;
+		}
+	};
+	class Writer {
+		std::ostream *os;
+		void print_impl(const char *fmt, vector<std::function<void()>>& fns);
+	public:
+		Writer(std::ostream &os) : os(&os) {}
+		template<class T> Writer& operator <<(T&& arg) { *os << std::forward<T>(arg); return *this; }
+		template<typename... Args>
+		void print(const char *fmt, Args&&... args)
+		{
+			vector<std::function<void()>> fns { [&]() { *this << args; }... };
+			print_impl(fmt, fns);
+		}
+		template<typename Fn, typename... Args>
+		void print_with(Fn fn, const char *fmt, Args&&... args)
+		{
+			vector<std::function<void()>> fns { [&]() {
+				if constexpr (std::is_invocable_v<Fn, Args>)
+					*this << fn(args);
+				else
+					*this << args; }...
+			};
+			print_impl(fmt, fns);
+		}
+	};
+
+}
 
 YOSYS_NAMESPACE_END
 
-
 #endif
diff --git a/kernel/functionalir.h b/kernel/functionalir.h
deleted file mode 100644
index fdbdcbde3..000000000
--- a/kernel/functionalir.h
+++ /dev/null
@@ -1,575 +0,0 @@
-/*
- *  yosys -- Yosys Open SYnthesis Suite
- *
- *  Copyright (C) 2024  Emily Schmidt <emily@yosyshq.com>
- *
- *  Permission to use, copy, modify, and/or distribute this software for any
- *  purpose with or without fee is hereby granted, provided that the above
- *  copyright notice and this permission notice appear in all copies.
- *
- *  THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
- *  WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
- *  MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
- *  ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
- *  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
- *  ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
- *  OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
- *
- */
-
-#ifndef FUNCTIONALIR_H
-#define FUNCTIONALIR_H
-
-#include "kernel/yosys.h"
-#include "kernel/functional.h"
-#include "kernel/drivertools.h"
-#include "kernel/mem.h"
-#include "kernel/topo_scc.h"
-
-USING_YOSYS_NAMESPACE
-YOSYS_NAMESPACE_BEGIN
-
-class FunctionalIR {
-public:
-	// each function is documented with a short pseudocode declaration or definition
-	// standard C/Verilog operators are used to describe the result
-	// 
-	// the types used in this are:
-	// - bit[N]: a bitvector of N bits
-	//   bit[N] can be indicated as signed or unsigned. this is not tracked by the functional backend
-	//   but is meant to indicate how the value is interpreted
-	//   if a bit[N] is marked as neither signed nor unsigned, this means the result should be valid with *either* interpretation
-	// - memory[N, M]: a memory with N address and M data bits
-	// - int: C++ int
-	// - Const[N]: yosys RTLIL::Const (with size() == N)
-	// - IdString: yosys IdString
-	// - any: used in documentation to indicate that the type is unconstrained
-	//
-	// nodes in the functional backend are either of type bit[N] or memory[N,M] (for some N, M: int)
-	// additionally, they can carry a constant of type int, Const[N] or IdString
-	// each node has a 'sort' field that stores the type of the node
-	// slice, zero_extend, sign_extend use the type field to store out_width
-	enum class Fn {
-		// invalid() = known-invalid/shouldn't happen value
-		// TODO: maybe remove this and use e.g. std::optional instead?
-		invalid,
-		// buf(a: any): any = a
-		// no-op operation
-		// when constructing the compute graph we generate invalid buf() nodes as a placeholder
-		// and later insert the argument
-		buf,
-		// slice(a: bit[in_width], offset: int, out_width: int): bit[out_width] = a[offset +: out_width]
-		// required: offset + out_width <= in_width
-		slice,
-		// zero_extend(a: unsigned bit[in_width], out_width: int): unsigned bit[out_width] = a (zero extended)
-		// required: out_width > in_width
-		zero_extend,
-		// sign_extend(a: signed bit[in_width], out_width: int): signed bit[out_width] = a (sign extended)
-		// required: out_width > in_width
-		sign_extend,
-		// concat(a: bit[N], b: bit[M]): bit[N+M] = {b, a} (verilog syntax)
-		// concatenates two bitvectors, with a in the least significant position and b in the more significant position
-		concat,
-		// add(a: bit[N], b: bit[N]): bit[N] = a + b
-		add,
-		// sub(a: bit[N], b: bit[N]): bit[N] = a - b
-		sub,
-		// mul(a: bit[N], b: bit[N]): bit[N] = a * b
-		mul,
-		// unsigned_div(a: unsigned bit[N], b: unsigned bit[N]): bit[N] = a / b
-		unsigned_div,
-		// unsigned_mod(a: signed bit[N], b: signed bit[N]): bit[N] = a % b
-		unsigned_mod,
-		// bitwise_and(a: bit[N], b: bit[N]): bit[N] = a & b
-		bitwise_and,
-		// bitwise_or(a: bit[N], b: bit[N]): bit[N] = a | b
-		bitwise_or,
-		// bitwise_xor(a: bit[N], b: bit[N]): bit[N] = a ^ b
-		bitwise_xor,
-		// bitwise_not(a: bit[N]): bit[N] = ~a
-		bitwise_not,
-		// reduce_and(a: bit[N]): bit[1] = &a
-		reduce_and,
-		// reduce_or(a: bit[N]): bit[1] = |a
-		reduce_or,
-		// reduce_xor(a: bit[N]): bit[1] = ^a
-		reduce_xor,
-		// unary_minus(a: bit[N]): bit[N] = -a
-		unary_minus,
-		// equal(a: bit[N], b: bit[N]): bit[1] = (a == b)
-		equal,
-		// not_equal(a: bit[N], b: bit[N]): bit[1] = (a != b)
-		not_equal,
-		// signed_greater_than(a: signed bit[N], b: signed bit[N]): bit[1] = (a > b)
-		signed_greater_than,
-		// signed_greater_equal(a: signed bit[N], b: signed bit[N]): bit[1] = (a >= b)
-		signed_greater_equal,
-		// unsigned_greater_than(a: unsigned bit[N], b: unsigned bit[N]): bit[1] = (a > b)
-		unsigned_greater_than,
-		// unsigned_greater_equal(a: unsigned bit[N], b: unsigned bit[N]): bit[1] = (a >= b)
-		unsigned_greater_equal,
-		// logical_shift_left(a: bit[N], b: unsigned bit[M]): bit[N] = a << b
-		// required: M == clog2(N)
-		logical_shift_left,
-		// logical_shift_right(a: unsigned bit[N], b: unsigned bit[M]): unsigned bit[N] = a >> b
-		// required: M == clog2(N)
-		logical_shift_right,
-		// arithmetic_shift_right(a: signed bit[N], b: unsigned bit[M]): signed bit[N] = a >> b
-		// required: M == clog2(N)
-		arithmetic_shift_right,
-		// mux(a: bit[N], b: bit[N], s: bit[1]): bit[N] = s ? b : a
-		mux,
-		// constant(a: Const[N]): bit[N] = a
-		constant,
-		// input(a: IdString): any
-		// returns the current value of the input with the specified name
-		input,
-		// state(a: IdString): any
-		// returns the current value of the state variable with the specified name
-		state,
-		// memory_read(memory: memory[addr_width, data_width], addr: bit[addr_width]): bit[data_width] = memory[addr]
-		memory_read,
-		// memory_write(memory: memory[addr_width, data_width], addr: bit[addr_width], data: bit[data_width]): memory[addr_width, data_width]
-		// returns a copy of `memory` but with the value at `addr` changed to `data`
-		memory_write
-	};
-	// returns the name of a FunctionalIR::Fn value, as a string literal
-	static const char *fn_to_string(Fn);
-	// FunctionalIR::Sort represents the sort or type of a node
-	// currently the only two types are signal/bit and memory
-	class Sort {
-		std::variant<int, std::pair<int, int>> _v;
-	public:
-		explicit Sort(int width) : _v(width) { }
-		Sort(int addr_width, int data_width) : _v(std::make_pair(addr_width, data_width)) { }
-		bool is_signal() const { return _v.index() == 0; }
-		bool is_memory() const { return _v.index() == 1; }
-		// returns the width of a bitvector type, errors out for other types
-		int width() const { return std::get<0>(_v); }
-		// returns the address width of a bitvector type, errors out for other types
-		int addr_width() const { return std::get<1>(_v).first; }
-		// returns the data width of a bitvector type, errors out for other types
-		int data_width() const { return std::get<1>(_v).second; }
-		bool operator==(Sort const& other) const { return _v == other._v; }
-		unsigned int hash() const { return mkhash(_v); }
-	};
-private:
-	// one NodeData is stored per Node, containing the function and non-node arguments
-	// note that NodeData is deduplicated by ComputeGraph
-	class NodeData {
-		Fn _fn;
-		std::variant<
-			std::monostate,
-			RTLIL::Const,
-			IdString,
-			int
-		> _extra;
-	public:
-		NodeData() : _fn(Fn::invalid) {}
-		NodeData(Fn fn) : _fn(fn) {}
-		template<class T> NodeData(Fn fn, T &&extra) : _fn(fn), _extra(std::forward<T>(extra)) {}
-		Fn fn() const { return _fn; }
-		const RTLIL::Const &as_const() const { return std::get<RTLIL::Const>(_extra); }
-		IdString as_idstring() const { return std::get<IdString>(_extra); }
-		int as_int() const { return std::get<int>(_extra); }
-		int hash() const {
-			return mkhash((unsigned int) _fn, mkhash(_extra));
-		}
-		bool operator==(NodeData const &other) const {
-			return _fn == other._fn && _extra == other._extra;
-		}
-	};
-	// Attr contains all the information about a note that should not be deduplicated
-	struct Attr {
-		Sort sort;
-	};
-	// our specialised version of ComputeGraph
-	// the sparse_attr IdString stores a naming suggestion, retrieved with name()
-	// the key is currently used to identify the nodes that represent output and next state values
-	// the bool is true for next state values
-	using Graph = ComputeGraph<NodeData, Attr, IdString, std::pair<IdString, bool>>;
-	Graph _graph;
-	dict<IdString, Sort> _input_sorts;
-	dict<IdString, Sort> _output_sorts;
-	dict<IdString, Sort> _state_sorts;
-	dict<IdString, RTLIL::Const> _initial_state_signal;
-	dict<IdString, MemContents> _initial_state_memory;
-public:
-	class Factory;
-	// Node is an immutable reference to a FunctionalIR node
-	class Node {
-		friend class Factory;
-		friend class FunctionalIR;
-		Graph::ConstRef _ref;
-		explicit Node(Graph::ConstRef ref) : _ref(ref) { }
-		explicit operator Graph::ConstRef() { return _ref; }
-	public:
-		// the node's index. may change if nodes are added or removed
-		int id() const { return _ref.index(); }
-		// a name suggestion for the node, which need not be unique
-		IdString name() const {
-			if(_ref.has_sparse_attr())
-				return _ref.sparse_attr();
-			else
-				return std::string("\\n") + std::to_string(id());
-		}
-		Fn fn() const { return _ref.function().fn(); }
-		Sort sort() const { return _ref.attr().sort; }
-		// returns the width of a bitvector node, errors out for other nodes
-		int width() const { return sort().width(); }
-		size_t arg_count() const { return _ref.size(); }
-		Node arg(int n) const { return Node(_ref.arg(n)); }
-		// visit calls the appropriate visitor method depending on the type of the node
-		template<class Visitor> auto visit(Visitor v) const
-		{
-			// currently templated but could be switched to AbstractVisitor &
-			switch(_ref.function().fn()) {
-			case Fn::invalid: log_error("invalid node in visit"); break;
-			case Fn::buf: return v.buf(*this, arg(0)); break;
-			case Fn::slice: return v.slice(*this, arg(0), _ref.function().as_int(), sort().width()); break;
-			case Fn::zero_extend: return v.zero_extend(*this, arg(0), width()); break;
-			case Fn::sign_extend: return v.sign_extend(*this, arg(0), width()); break;
-			case Fn::concat: return v.concat(*this, arg(0), arg(1)); break;
-			case Fn::add: return v.add(*this, arg(0), arg(1)); break;
-			case Fn::sub: return v.sub(*this, arg(0), arg(1)); break;
-			case Fn::mul: return v.mul(*this, arg(0), arg(1)); break;
-			case Fn::unsigned_div: return v.unsigned_div(*this, arg(0), arg(1)); break;
-			case Fn::unsigned_mod: return v.unsigned_mod(*this, arg(0), arg(1)); break;
-			case Fn::bitwise_and: return v.bitwise_and(*this, arg(0), arg(1)); break;
-			case Fn::bitwise_or: return v.bitwise_or(*this, arg(0), arg(1)); break;
-			case Fn::bitwise_xor: return v.bitwise_xor(*this, arg(0), arg(1)); break;
-			case Fn::bitwise_not: return v.bitwise_not(*this, arg(0)); break;
-			case Fn::unary_minus: return v.unary_minus(*this, arg(0)); break;
-			case Fn::reduce_and: return v.reduce_and(*this, arg(0)); break;
-			case Fn::reduce_or: return v.reduce_or(*this, arg(0)); break;
-			case Fn::reduce_xor: return v.reduce_xor(*this, arg(0)); break;
-			case Fn::equal: return v.equal(*this, arg(0), arg(1)); break;
-			case Fn::not_equal: return v.not_equal(*this, arg(0), arg(1)); break;
-			case Fn::signed_greater_than: return v.signed_greater_than(*this, arg(0), arg(1)); break; 
-			case Fn::signed_greater_equal: return v.signed_greater_equal(*this, arg(0), arg(1)); break;
-			case Fn::unsigned_greater_than: return v.unsigned_greater_than(*this, arg(0), arg(1)); break; 
-			case Fn::unsigned_greater_equal: return v.unsigned_greater_equal(*this, arg(0), arg(1)); break;
-			case Fn::logical_shift_left: return v.logical_shift_left(*this, arg(0), arg(1)); break;
-			case Fn::logical_shift_right: return v.logical_shift_right(*this, arg(0), arg(1)); break;
-			case Fn::arithmetic_shift_right: return v.arithmetic_shift_right(*this, arg(0), arg(1)); break;
-			case Fn::mux: return v.mux(*this, arg(0), arg(1), arg(2)); break;
-			case Fn::constant: return v.constant(*this, _ref.function().as_const()); break;
-			case Fn::input: return v.input(*this, _ref.function().as_idstring()); break;
-			case Fn::state: return v.state(*this, _ref.function().as_idstring()); break;
-			case Fn::memory_read: return v.memory_read(*this, arg(0), arg(1)); break;
-			case Fn::memory_write: return v.memory_write(*this, arg(0), arg(1), arg(2)); break;
-			}
-		}
-		std::string to_string();
-		std::string to_string(std::function<std::string(Node)>);
-	};
-	// AbstractVisitor provides an abstract base class for visitors
-	template<class T> struct AbstractVisitor {
-		virtual T buf(Node self, Node n) = 0;
-		virtual T slice(Node self, Node a, int offset, int out_width) = 0;
-		virtual T zero_extend(Node self, Node a, int out_width) = 0;
-		virtual T sign_extend(Node self, Node a, int out_width) = 0;
-		virtual T concat(Node self, Node a, Node b) = 0;
-		virtual T add(Node self, Node a, Node b) = 0;
-		virtual T sub(Node self, Node a, Node b) = 0;
-		virtual T mul(Node self, Node a, Node b) = 0;
-		virtual T unsigned_div(Node self, Node a, Node b) = 0;
-		virtual T unsigned_mod(Node self, Node a, Node b) = 0;
-		virtual T bitwise_and(Node self, Node a, Node b) = 0;
-		virtual T bitwise_or(Node self, Node a, Node b) = 0;
-		virtual T bitwise_xor(Node self, Node a, Node b) = 0;
-		virtual T bitwise_not(Node self, Node a) = 0;
-		virtual T unary_minus(Node self, Node a) = 0;
-		virtual T reduce_and(Node self, Node a) = 0;
-		virtual T reduce_or(Node self, Node a) = 0;
-		virtual T reduce_xor(Node self, Node a) = 0;
-		virtual T equal(Node self, Node a, Node b) = 0;
-		virtual T not_equal(Node self, Node a, Node b) = 0;
-		virtual T signed_greater_than(Node self, Node a, Node b) = 0;
-		virtual T signed_greater_equal(Node self, Node a, Node b) = 0;
-		virtual T unsigned_greater_than(Node self, Node a, Node b) = 0;
-		virtual T unsigned_greater_equal(Node self, Node a, Node b) = 0;
-		virtual T logical_shift_left(Node self, Node a, Node b) = 0;
-		virtual T logical_shift_right(Node self, Node a, Node b) = 0;
-		virtual T arithmetic_shift_right(Node self, Node a, Node b) = 0;
-		virtual T mux(Node self, Node a, Node b, Node s) = 0;
-		virtual T constant(Node self, RTLIL::Const const & value) = 0;
-		virtual T input(Node self, IdString name) = 0;
-		virtual T state(Node self, IdString name) = 0;
-		virtual T memory_read(Node self, Node mem, Node addr) = 0;
-		virtual T memory_write(Node self, Node mem, Node addr, Node data) = 0;
-	};
-	// DefaultVisitor provides defaults for all visitor methods which just calls default_handler
-	template<class T> struct DefaultVisitor : public AbstractVisitor<T> {
-		virtual T default_handler(Node self) = 0;
-		T buf(Node self, Node) override { return default_handler(self); }
-		T slice(Node self, Node, int, int) override { return default_handler(self); }
-		T zero_extend(Node self, Node, int) override { return default_handler(self); }
-		T sign_extend(Node self, Node, int) override { return default_handler(self); }
-		T concat(Node self, Node, Node) override { return default_handler(self); }
-		T add(Node self, Node, Node) override { return default_handler(self); }
-		T sub(Node self, Node, Node) override { return default_handler(self); }
-		T mul(Node self, Node, Node) override { return default_handler(self); }
-		T unsigned_div(Node self, Node, Node) override { return default_handler(self); }
-		T unsigned_mod(Node self, Node, Node) override { return default_handler(self); }
-		T bitwise_and(Node self, Node, Node) override { return default_handler(self); }
-		T bitwise_or(Node self, Node, Node) override { return default_handler(self); }
-		T bitwise_xor(Node self, Node, Node) override { return default_handler(self); }
-		T bitwise_not(Node self, Node) override { return default_handler(self); }
-		T unary_minus(Node self, Node) override { return default_handler(self); }
-		T reduce_and(Node self, Node) override { return default_handler(self); }
-		T reduce_or(Node self, Node) override { return default_handler(self); }
-		T reduce_xor(Node self, Node) override { return default_handler(self); }
-		T equal(Node self, Node, Node) override { return default_handler(self); }
-		T not_equal(Node self, Node, Node) override { return default_handler(self); }
-		T signed_greater_than(Node self, Node, Node) override { return default_handler(self); }
-		T signed_greater_equal(Node self, Node, Node) override { return default_handler(self); }
-		T unsigned_greater_than(Node self, Node, Node) override { return default_handler(self); }
-		T unsigned_greater_equal(Node self, Node, Node) override { return default_handler(self); }
-		T logical_shift_left(Node self, Node, Node) override { return default_handler(self); }
-		T logical_shift_right(Node self, Node, Node) override { return default_handler(self); }
-		T arithmetic_shift_right(Node self, Node, Node) override { return default_handler(self); }
-		T mux(Node self, Node, Node, Node) override { return default_handler(self); }
-		T constant(Node self, RTLIL::Const const &) override { return default_handler(self); }
-		T input(Node self, IdString) override { return default_handler(self); }
-		T state(Node self, IdString) override { return default_handler(self); }
-		T memory_read(Node self, Node, Node) override { return default_handler(self); }
-		T memory_write(Node self, Node, Node, Node) override { return default_handler(self); }
-	};
-	// a factory is used to modify a FunctionalIR. it creates new nodes and allows for some modification of existing nodes.
-	class Factory {
-		FunctionalIR &_ir;
-		friend class FunctionalIR;
-		explicit Factory(FunctionalIR &ir) : _ir(ir) {}
-		Node add(NodeData &&fn, Sort &&sort, std::initializer_list<Node> args) {
-			log_assert(!sort.is_signal() || sort.width() > 0);
-			log_assert(!sort.is_memory() || sort.addr_width() > 0 && sort.data_width() > 0);
-			Graph::Ref ref = _ir._graph.add(std::move(fn), {std::move(sort)});
-			for (auto arg : args)
-				ref.append_arg(Graph::ConstRef(arg));
-			return Node(ref);
-		}
-		Graph::Ref mutate(Node n) {
-			return _ir._graph[n._ref.index()];
-		}
-		void check_basic_binary(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && a.sort() == b.sort()); }
-		void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal() && b.width() == ceil_log2(a.width())); }
-		void check_unary(Node const &a) { log_assert(a.sort().is_signal()); }
-	public:
-		Node slice(Node a, int offset, int out_width) {
-			log_assert(a.sort().is_signal() && offset + out_width <= a.sort().width());
-			if(offset == 0 && out_width == a.width())
-				return a;
-			return add(NodeData(Fn::slice, offset), Sort(out_width), {a});
-		}
-		// extend will either extend or truncate the provided value to reach the desired width
-		Node extend(Node a, int out_width, bool is_signed) {
-			int in_width = a.sort().width();
-			log_assert(a.sort().is_signal());
-			if(in_width == out_width)
-				return a;
-			if(in_width > out_width)
-				return slice(a, 0, out_width);
-			if(is_signed)
-				return add(Fn::sign_extend, Sort(out_width), {a});
-			else
-				return add(Fn::zero_extend, Sort(out_width), {a});
-		}
-		Node concat(Node a, Node b) {
-			log_assert(a.sort().is_signal() && b.sort().is_signal());
-			return add(Fn::concat, Sort(a.sort().width() + b.sort().width()), {a, b});
-		}
-		Node add(Node a, Node b) { check_basic_binary(a, b); return add(Fn::add, a.sort(), {a, b}); }
-		Node sub(Node a, Node b) { check_basic_binary(a, b); return add(Fn::sub, a.sort(), {a, b}); }
-		Node mul(Node a, Node b) { check_basic_binary(a, b); return add(Fn::mul, a.sort(), {a, b}); }
-		Node unsigned_div(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_div, a.sort(), {a, b}); }
-		Node unsigned_mod(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_mod, a.sort(), {a, b}); }
-		Node bitwise_and(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_and, a.sort(), {a, b}); }
-		Node bitwise_or(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_or, a.sort(), {a, b}); }
-		Node bitwise_xor(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_xor, a.sort(), {a, b}); }
-		Node bitwise_not(Node a) { check_unary(a); return add(Fn::bitwise_not, a.sort(), {a}); }
-		Node unary_minus(Node a) { check_unary(a); return add(Fn::unary_minus, a.sort(), {a}); }
-		Node reduce_and(Node a) {
-			check_unary(a);
-			if(a.width() == 1)
-				return a;
-			return add(Fn::reduce_and, Sort(1), {a});
-		}
-		Node reduce_or(Node a) {
-			check_unary(a);
-			if(a.width() == 1)
-				return a;
-			return add(Fn::reduce_or, Sort(1), {a});
-		}
-		Node reduce_xor(Node a) { 
-			check_unary(a);
-			if(a.width() == 1)
-				return a;
-			return add(Fn::reduce_xor, Sort(1), {a});
-		}
-		Node equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::equal, Sort(1), {a, b}); }
-		Node not_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::not_equal, Sort(1), {a, b}); }
-		Node signed_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_than, Sort(1), {a, b}); }
-		Node signed_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_equal, Sort(1), {a, b}); }
-		Node unsigned_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_than, Sort(1), {a, b}); }
-		Node unsigned_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_equal, Sort(1), {a, b}); }
-		Node logical_shift_left(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_left, a.sort(), {a, b}); }
-		Node logical_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_right, a.sort(), {a, b}); }
-		Node arithmetic_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::arithmetic_shift_right, a.sort(), {a, b}); }
-		Node mux(Node a, Node b, Node s) {
-			log_assert(a.sort().is_signal() && a.sort() == b.sort() && s.sort() == Sort(1));
-			return add(Fn::mux, a.sort(), {a, b, s});
-		}
-		Node memory_read(Node mem, Node addr) {
-			log_assert(mem.sort().is_memory() && addr.sort().is_signal() && mem.sort().addr_width() == addr.sort().width());
-			return add(Fn::memory_read, Sort(mem.sort().data_width()), {mem, addr});
-		}
-		Node memory_write(Node mem, Node addr, Node data) {
-			log_assert(mem.sort().is_memory() && addr.sort().is_signal() && data.sort().is_signal() &&
-				mem.sort().addr_width() == addr.sort().width() && mem.sort().data_width() == data.sort().width());
-			return add(Fn::memory_write, mem.sort(), {mem, addr, data});
-		}
-		Node constant(RTLIL::Const value) {
-			return add(NodeData(Fn::constant, std::move(value)), Sort(value.size()), {});
-		}
-		Node create_pending(int width) {
-			return add(Fn::buf, Sort(width), {});
-		}
-		void update_pending(Node node, Node value) {
-			log_assert(node._ref.function() == Fn::buf && node._ref.size() == 0);
-			log_assert(node.sort() == value.sort());
-			mutate(node).append_arg(value._ref);
-		}
-		void add_input(IdString name, int width) {
-			auto [it, inserted] = _ir._input_sorts.emplace(name, Sort(width));
-			if (!inserted) log_error("input `%s` was re-defined", name.c_str());
-		}
-		void add_output(IdString name, int width) {
-			auto [it, inserted] = _ir._output_sorts.emplace(name, Sort(width));
-			if (!inserted) log_error("output `%s` was re-defined", name.c_str());
-		}
-		void add_state(IdString name, Sort sort) {
-			auto [it, inserted] = _ir._state_sorts.emplace(name, sort);
-			if (!inserted) log_error("state `%s` was re-defined", name.c_str());
-		}
-		Node input(IdString name) {
-			return add(NodeData(Fn::input, name), Sort(_ir._input_sorts.at(name)), {});
-		}
-		Node current_state(IdString name) {
-			return add(NodeData(Fn::state, name), Sort(_ir._state_sorts.at(name)), {});
-		}
-		void set_output(IdString output, Node value) {
-			log_assert(_ir._output_sorts.at(output) == value.sort());
-			mutate(value).assign_key({output, false});
-		}
-		void set_initial_state(IdString state, RTLIL::Const value) {
-			Sort &sort = _ir._state_sorts.at(state);
-			value.extu(sort.width());
-			_ir._initial_state_signal.emplace(state, std::move(value));
-		}
-		void set_initial_state(IdString state, MemContents value) {
-			log_assert(Sort(value.addr_width(), value.data_width()) == _ir._state_sorts.at(state));
-			_ir._initial_state_memory.emplace(state, std::move(value));
-		}
-		void set_next_state(IdString state, Node value) {
-			log_assert(_ir._state_sorts.at(state) == value.sort());
-			mutate(value).assign_key({state, true});
-		}
-		void suggest_name(Node node, IdString name) {
-			mutate(node).sparse_attr() = name;
-		}
-	};
-	static FunctionalIR from_module(Module *module);
-	Factory factory() { return Factory(*this); }
-	int size() const { return _graph.size(); }
-	Node operator[](int i) { return Node(_graph[i]); }
-	void topological_sort();
-	void forward_buf();
-	dict<IdString, Sort> inputs() const { return _input_sorts; }
-	dict<IdString, Sort> outputs() const { return _output_sorts; }
-	dict<IdString, Sort> state() const { return _state_sorts; }
-	RTLIL::Const  const &get_initial_state_signal(IdString name) { return _initial_state_signal.at(name); }
-	MemContents const &get_initial_state_memory(IdString name) { return _initial_state_memory.at(name); }
-	Node get_output_node(IdString name) { return Node(_graph({name, false})); }
-	Node get_state_next_node(IdString name) { return Node(_graph({name, true})); }
-	class Iterator {
-		friend class FunctionalIR;
-		FunctionalIR *_ir;
-		int _index;
-		Iterator(FunctionalIR *ir, int index) : _ir(ir), _index(index) {}
-	public:
-		Node operator*() { return Node(_ir->_graph[_index]); }
-		Iterator &operator++() { _index++; return *this; }
-		bool operator!=(Iterator const &other) const { return _index != other._index; }
-	};
-	Iterator begin() { return Iterator(this, 0); }
-	Iterator end() { return Iterator(this, _graph.size()); }
-};
-
-namespace FunctionalTools {
-	template<class Id> class Scope {
-	protected:
-		char substitution_character = '_';
-		virtual bool is_character_legal(char) = 0;
-	private:
-		pool<std::string> _used_names;
-		dict<Id, std::string> _by_id;
-	public:
-		void reserve(std::string name) {
-			_used_names.insert(std::move(name));
-		}
-		std::string unique_name(IdString suggestion) {
-			std::string str = RTLIL::unescape_id(suggestion);
-			for(size_t i = 0; i < str.size(); i++)
-				if(!is_character_legal(str[i]))
-					str[i] = substitution_character;
-			if(_used_names.count(str) == 0) {
-				_used_names.insert(str);
-				return str;
-			}
-			for (int idx = 0 ; ; idx++){
-				std::string suffixed = str + "_" + std::to_string(idx);
-				if(_used_names.count(suffixed) == 0) {
-					_used_names.insert(suffixed);
-					return suffixed;
-				}
-			}
-		}
-		std::string operator()(Id id, IdString suggestion) {
-			auto it = _by_id.find(id);
-			if(it != _by_id.end())
-				return it->second;
-			std::string str = unique_name(suggestion);
-			_by_id.insert({id, str});
-			return str;
-		}
-	};
-	class Writer {
-		std::ostream *os;
-		void print_impl(const char *fmt, vector<std::function<void()>>& fns);
-	public:
-		Writer(std::ostream &os) : os(&os) {}
-		template<class T> Writer& operator <<(T&& arg) { *os << std::forward<T>(arg); return *this; }
-		template<typename... Args>
-		void print(const char *fmt, Args&&... args)
-		{
-			vector<std::function<void()>> fns { [&]() { *this << args; }... };
-			print_impl(fmt, fns);
-		}
-		template<typename Fn, typename... Args>
-		void print_with(Fn fn, const char *fmt, Args&&... args)
-		{
-			vector<std::function<void()>> fns { [&]() {
-				if constexpr (std::is_invocable_v<Fn, Args>)
-					*this << fn(args);
-				else
-					*this << args; }...
-			};
-			print_impl(fmt, fns);
-		}
-	};
-}
-
-YOSYS_NAMESPACE_END
-
-#endif
diff --git a/kernel/mem.h b/kernel/mem.h
index 4be4b6864..8c935adc1 100644
--- a/kernel/mem.h
+++ b/kernel/mem.h
@@ -22,6 +22,7 @@
 
 #include "kernel/yosys.h"
 #include "kernel/ffinit.h"
+#include "kernel/utils.h"
 
 YOSYS_NAMESPACE_BEGIN
 
@@ -224,15 +225,6 @@ struct Mem : RTLIL::AttrObject {
 	Mem(Module *module, IdString memid, int width, int start_offset, int size) : module(module), memid(memid), packed(false), mem(nullptr), cell(nullptr), width(width), start_offset(start_offset), size(size) {}
 };
 
-// this class is used for implementing operator-> on iterators that return values rather than references
-// it's necessary because in C++ operator-> is called recursively until a raw pointer is obtained
-template<class T>
-struct arrow_proxy {
-	T v;
-	explicit arrow_proxy(T const & v) : v(v) {}
-	T* operator->() { return &v; }
-};
-
 // MemContents efficiently represents the contents of a potentially sparse memory by storing only those segments that are actually defined
 class MemContents {
 public:
@@ -303,6 +295,7 @@ public:
 		reference operator *() const { return range(_memory->_data_width, _addr, _memory->_values.at(_addr)); }
 		pointer operator->() const { return arrow_proxy<range>(**this); }
 		bool operator !=(iterator const &other) const { return _memory != other._memory || _addr != other._addr; }
+		bool operator ==(iterator const &other) const { return !(*this != other); }
 		iterator &operator++();
 	};
 	MemContents(int addr_width, int data_width, RTLIL::Const default_value)
diff --git a/kernel/utils.h b/kernel/utils.h
index 3216c5eb5..99f327db4 100644
--- a/kernel/utils.h
+++ b/kernel/utils.h
@@ -253,6 +253,15 @@ template <typename T, typename C = std::less<T>, typename OPS = hash_ops<T>> cla
 	}
 };
 
+// this class is used for implementing operator-> on iterators that return values rather than references
+// it's necessary because in C++ operator-> is called recursively until a raw pointer is obtained
+template<class T>
+struct arrow_proxy {
+	T v;
+	explicit arrow_proxy(T const & v) : v(v) {}
+	T* operator->() { return &v; }
+};
+
 YOSYS_NAMESPACE_END
 
 #endif
diff --git a/passes/cmds/example_dt.cc b/passes/cmds/example_dt.cc
index 4b836d75b..aaf07dadd 100644
--- a/passes/cmds/example_dt.cc
+++ b/passes/cmds/example_dt.cc
@@ -1,7 +1,7 @@
 #include "kernel/yosys.h"
 #include "kernel/drivertools.h"
 #include "kernel/topo_scc.h"
-#include "kernel/functional.h"
+#include "kernel/compute_graph.h"
 
 USING_YOSYS_NAMESPACE
 PRIVATE_NAMESPACE_BEGIN