3
0
Fork 0
mirror of https://github.com/YosysHQ/yosys synced 2025-06-06 14:13:23 +00:00

functional backend: add different types of input/output/state variables

This commit is contained in:
Emily Schmidt 2024-08-06 09:56:28 +01:00
parent 79a1b691ea
commit 50047d25b3
5 changed files with 237 additions and 120 deletions

View file

@ -152,11 +152,60 @@ namespace Functional {
bool operator==(Sort const& other) const { return _v == other._v; }
unsigned int hash() const { return mkhash(_v); }
};
class IR;
class Factory;
class Node;
class IRInput {
friend class Factory;
public:
IdString name;
IdString type;
Sort sort;
private:
IRInput(IR &, IdString name, IdString type, Sort sort)
: name(name), type(type), sort(std::move(sort)) {}
};
class IROutput {
friend class Factory;
IR &_ir;
public:
IdString name;
IdString type;
Sort sort;
private:
IROutput(IR &ir, IdString name, IdString type, Sort sort)
: _ir(ir), name(name), type(type), sort(std::move(sort)) {}
public:
Node value() const;
bool has_value() const;
void set_value(Node value);
};
class IRState {
friend class Factory;
IR &_ir;
public:
IdString name;
IdString type;
Sort sort;
private:
std::variant<RTLIL::Const, MemContents> _initial;
IRState(IR &ir, IdString name, IdString type, Sort sort)
: _ir(ir), name(name), type(type), sort(std::move(sort)) {}
public:
Node next_value() const;
bool has_next_value() const;
RTLIL::Const const& initial_value_signal() const { return std::get<RTLIL::Const>(_initial); }
MemContents const& initial_value_memory() const { return std::get<MemContents>(_initial); }
void set_next_value(Node value);
void set_initial_value(RTLIL::Const value) { value.extu(sort.width()); _initial = std::move(value); }
void set_initial_value(MemContents value) { log_assert(Sort(value.addr_width(), value.data_width()) == sort); _initial = std::move(value); }
};
class IR {
friend class Factory;
friend class Node;
friend class IRInput;
friend class IROutput;
friend class IRState;
// one NodeData is stored per Node, containing the function and non-node arguments
// note that NodeData is deduplicated by ComputeGraph
class NodeData {
@ -164,7 +213,7 @@ namespace Functional {
std::variant<
std::monostate,
RTLIL::Const,
IdString,
std::pair<IdString, IdString>,
int
> _extra;
public:
@ -173,7 +222,7 @@ namespace Functional {
template<class T> NodeData(Fn fn, T &&extra) : _fn(fn), _extra(std::forward<T>(extra)) {}
Fn fn() const { return _fn; }
const RTLIL::Const &as_const() const { return std::get<RTLIL::Const>(_extra); }
IdString as_idstring() const { return std::get<IdString>(_extra); }
std::pair<IdString, IdString> as_idstring_pair() const { return std::get<std::pair<IdString, IdString>>(_extra); }
int as_int() const { return std::get<int>(_extra); }
int hash() const {
return mkhash((unsigned int) _fn, mkhash(_extra));
@ -190,13 +239,12 @@ namespace Functional {
// the sparse_attr IdString stores a naming suggestion, retrieved with name()
// the key is currently used to identify the nodes that represent output and next state values
// the bool is true for next state values
using Graph = ComputeGraph<NodeData, Attr, IdString, std::pair<IdString, bool>>;
using Graph = ComputeGraph<NodeData, Attr, IdString, std::tuple<IdString, IdString, bool>>;
Graph _graph;
dict<IdString, Sort> _input_sorts;
dict<IdString, Sort> _output_sorts;
dict<IdString, Sort> _state_sorts;
dict<IdString, RTLIL::Const> _initial_state_signal;
dict<IdString, MemContents> _initial_state_memory;
dict<std::pair<IdString, IdString>, IRInput> _inputs;
dict<std::pair<IdString, IdString>, IROutput> _outputs;
dict<std::pair<IdString, IdString>, IRState> _states;
IR::Graph::Ref mutate(Node n);
public:
static IR from_module(Module *module);
Factory factory();
@ -204,13 +252,24 @@ namespace Functional {
Node operator[](int i);
void topological_sort();
void forward_buf();
dict<IdString, Sort> inputs() const { return _input_sorts; }
dict<IdString, Sort> outputs() const { return _output_sorts; }
dict<IdString, Sort> state() const { return _state_sorts; }
RTLIL::Const const &get_initial_state_signal(IdString name) { return _initial_state_signal.at(name); }
MemContents const &get_initial_state_memory(IdString name) { return _initial_state_memory.at(name); }
Node get_output_node(IdString name);
Node get_state_next_node(IdString name);
IRInput const& input(IdString name, IdString type) const { return _inputs.at({name, type}); }
IRInput const& input(IdString name) const { return input(name, ID($input)); }
IROutput const& output(IdString name, IdString type) const { return _outputs.at({name, type}); }
IROutput const& output(IdString name) const { return output(name, ID($output)); }
IRState const& state(IdString name, IdString type) const { return _states.at({name, type}); }
IRState const& state(IdString name) const { return state(name, ID($state)); }
bool has_input(IdString name, IdString type) const { return _inputs.count({name, type}); }
bool has_output(IdString name, IdString type) const { return _outputs.count({name, type}); }
bool has_state(IdString name, IdString type) const { return _states.count({name, type}); }
vector<IRInput const*> inputs(IdString type) const;
vector<IRInput const*> inputs() const { return inputs(ID($input)); }
vector<IROutput const*> outputs(IdString type) const;
vector<IROutput const*> outputs() const { return outputs(ID($output)); }
vector<IRState const*> states(IdString type) const;
vector<IRState const*> states() const { return states(ID($state)); }
vector<IRInput const*> all_inputs() const;
vector<IROutput const*> all_outputs() const;
vector<IRState const*> all_states() const;
class iterator {
friend class IR;
IR *_ir;
@ -236,6 +295,9 @@ namespace Functional {
class Node {
friend class Factory;
friend class IR;
friend class IRInput;
friend class IROutput;
friend class IRState;
IR::Graph::ConstRef _ref;
explicit Node(IR::Graph::ConstRef ref) : _ref(ref) { }
explicit operator IR::Graph::ConstRef() { return _ref; }
@ -290,8 +352,8 @@ namespace Functional {
case Fn::arithmetic_shift_right: return v.arithmetic_shift_right(*this, arg(0), arg(1)); break;
case Fn::mux: return v.mux(*this, arg(0), arg(1), arg(2)); break;
case Fn::constant: return v.constant(*this, _ref.function().as_const()); break;
case Fn::input: return v.input(*this, _ref.function().as_idstring()); break;
case Fn::state: return v.state(*this, _ref.function().as_idstring()); break;
case Fn::input: return v.input(*this, _ref.function().as_idstring_pair().first, _ref.function().as_idstring_pair().second); break;
case Fn::state: return v.state(*this, _ref.function().as_idstring_pair().first, _ref.function().as_idstring_pair().second); break;
case Fn::memory_read: return v.memory_read(*this, arg(0), arg(1)); break;
case Fn::memory_write: return v.memory_write(*this, arg(0), arg(1), arg(2)); break;
}
@ -300,9 +362,14 @@ namespace Functional {
std::string to_string();
std::string to_string(std::function<std::string(Node)>);
};
inline IR::Graph::Ref IR::mutate(Node n) { return _graph[n._ref.index()]; }
inline Node IR::operator[](int i) { return Node(_graph[i]); }
inline Node IR::get_output_node(IdString name) { return Node(_graph({name, false})); }
inline Node IR::get_state_next_node(IdString name) { return Node(_graph({name, true})); }
inline Node IROutput::value() const { return Node(_ir._graph({name, type, false})); }
inline bool IROutput::has_value() const { return _ir._graph.has_key({name, type, false}); }
inline void IROutput::set_value(Node value) { log_assert(sort == value.sort()); _ir.mutate(value).assign_key({name, type, false}); }
inline Node IRState::next_value() const { return Node(_ir._graph({name, type, true})); }
inline bool IRState::has_next_value() const { return _ir._graph.has_key({name, type, true}); }
inline void IRState::set_next_value(Node value) { log_assert(sort == value.sort()); _ir.mutate(value).assign_key({name, type, true}); }
inline Node IR::iterator::operator*() { return Node(_ir->_graph[_index]); }
inline arrow_proxy<Node> IR::iterator::operator->() { return arrow_proxy<Node>(**this); }
// AbstractVisitor provides an abstract base class for visitors
@ -336,8 +403,8 @@ namespace Functional {
virtual T arithmetic_shift_right(Node self, Node a, Node b) = 0;
virtual T mux(Node self, Node a, Node b, Node s) = 0;
virtual T constant(Node self, RTLIL::Const const & value) = 0;
virtual T input(Node self, IdString name) = 0;
virtual T state(Node self, IdString name) = 0;
virtual T input(Node self, IdString name, IdString type) = 0;
virtual T state(Node self, IdString name, IdString type) = 0;
virtual T memory_read(Node self, Node mem, Node addr) = 0;
virtual T memory_write(Node self, Node mem, Node addr, Node data) = 0;
};
@ -373,8 +440,8 @@ namespace Functional {
T arithmetic_shift_right(Node self, Node, Node) override { return default_handler(self); }
T mux(Node self, Node, Node, Node) override { return default_handler(self); }
T constant(Node self, RTLIL::Const const &) override { return default_handler(self); }
T input(Node self, IdString) override { return default_handler(self); }
T state(Node self, IdString) override { return default_handler(self); }
T input(Node self, IdString, IdString) override { return default_handler(self); }
T state(Node self, IdString, IdString) override { return default_handler(self); }
T memory_read(Node self, Node, Node) override { return default_handler(self); }
T memory_write(Node self, Node, Node, Node) override { return default_handler(self); }
};
@ -383,7 +450,7 @@ namespace Functional {
friend class IR;
IR &_ir;
explicit Factory(IR &ir) : _ir(ir) {}
Node add(IR::NodeData &&fn, Sort &&sort, std::initializer_list<Node> args) {
Node add(IR::NodeData &&fn, Sort const &sort, std::initializer_list<Node> args) {
log_assert(!sort.is_signal() || sort.width() > 0);
log_assert(!sort.is_memory() || (sort.addr_width() > 0 && sort.data_width() > 0));
IR::Graph::Ref ref = _ir._graph.add(std::move(fn), {std::move(sort)});
@ -391,13 +458,11 @@ namespace Functional {
ref.append_arg(IR::Graph::ConstRef(arg));
return Node(ref);
}
IR::Graph::Ref mutate(Node n) {
return _ir._graph[n._ref.index()];
}
void check_basic_binary(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && a.sort() == b.sort()); }
void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal() && b.width() == ceil_log2(a.width())); }
void check_unary(Node const &a) { log_assert(a.sort().is_signal()); }
public:
IR &ir() { return _ir; }
Node slice(Node a, int offset, int out_width) {
log_assert(a.sort().is_signal() && offset + out_width <= a.sort().width());
if(offset == 0 && out_width == a.width())
@ -480,45 +545,31 @@ namespace Functional {
void update_pending(Node node, Node value) {
log_assert(node._ref.function() == Fn::buf && node._ref.size() == 0);
log_assert(node.sort() == value.sort());
mutate(node).append_arg(value._ref);
_ir.mutate(node).append_arg(value._ref);
}
void add_input(IdString name, int width) {
auto [it, inserted] = _ir._input_sorts.emplace(name, Sort(width));
IRInput &add_input(IdString name, IdString type, Sort sort) {
auto [it, inserted] = _ir._inputs.emplace({name, type}, IRInput(_ir, name, type, std::move(sort)));
if (!inserted) log_error("input `%s` was re-defined", name.c_str());
return it->second;
}
void add_output(IdString name, int width) {
auto [it, inserted] = _ir._output_sorts.emplace(name, Sort(width));
IROutput &add_output(IdString name, IdString type, Sort sort) {
auto [it, inserted] = _ir._outputs.emplace({name, type}, IROutput(_ir, name, type, std::move(sort)));
if (!inserted) log_error("output `%s` was re-defined", name.c_str());
return it->second;
}
void add_state(IdString name, Sort sort) {
auto [it, inserted] = _ir._state_sorts.emplace(name, sort);
IRState &add_state(IdString name, IdString type, Sort sort) {
auto [it, inserted] = _ir._states.emplace({name, type}, IRState(_ir, name, type, std::move(sort)));
if (!inserted) log_error("state `%s` was re-defined", name.c_str());
return it->second;
}
Node input(IdString name) {
return add(IR::NodeData(Fn::input, name), Sort(_ir._input_sorts.at(name)), {});
Node value(IRInput const& input) {
return add(IR::NodeData(Fn::input, std::pair(input.name, input.type)), input.sort, {});
}
Node current_state(IdString name) {
return add(IR::NodeData(Fn::state, name), Sort(_ir._state_sorts.at(name)), {});
}
void set_output(IdString output, Node value) {
log_assert(_ir._output_sorts.at(output) == value.sort());
mutate(value).assign_key({output, false});
}
void set_initial_state(IdString state, RTLIL::Const value) {
Sort &sort = _ir._state_sorts.at(state);
value.extu(sort.width());
_ir._initial_state_signal.emplace(state, std::move(value));
}
void set_initial_state(IdString state, MemContents value) {
log_assert(Sort(value.addr_width(), value.data_width()) == _ir._state_sorts.at(state));
_ir._initial_state_memory.emplace(state, std::move(value));
}
void set_next_state(IdString state, Node value) {
log_assert(_ir._state_sorts.at(state) == value.sort());
mutate(value).assign_key({state, true});
Node value(IRState const& state) {
return add(IR::NodeData(Fn::state, std::pair(state.name, state.type)), state.sort, {});
}
void suggest_name(Node node, IdString name) {
mutate(node).sparse_attr() = name;
_ir.mutate(node).sparse_attr() = name;
}
};
inline Factory IR::factory() { return Factory(*this); }