3
0
Fork 0
mirror of https://github.com/YosysHQ/yosys synced 2025-04-15 13:28:59 +00:00

remove widths parameters from FunctionalIR factory methods and from functionalir.cc

This commit is contained in:
Emily Schmidt 2024-07-17 12:42:24 +01:00
parent 55c2c17853
commit 7f8f21b980
2 changed files with 216 additions and 241 deletions

View file

@ -98,87 +98,75 @@ std::string FunctionalIR::Node::to_string(std::function<std::string(Node)> np)
return visit(PrintVisitor(np)); return visit(PrintVisitor(np));
} }
template <class T, class Factory>
class CellSimplifier { class CellSimplifier {
Factory &factory; using Node = FunctionalIR::Node;
T reduce_shift_width(T b, int b_width, int y_width, int &reduced_b_width) { FunctionalIR::Factory &factory;
Node reduce_shift_width(Node b, int y_width) {
log_assert(y_width > 0); log_assert(y_width > 0);
int new_width = ceil_log2(y_width + 1); int new_width = ceil_log2(y_width + 1);
if (b_width <= new_width) { if (b.width() <= new_width) {
reduced_b_width = b_width;
return b; return b;
} else { } else {
reduced_b_width = new_width; Node lower_b = factory.slice(b, 0, new_width);
T lower_b = factory.slice(b, b_width, 0, new_width); Node overflow = factory.unsigned_greater_than(b, factory.constant(RTLIL::Const(y_width, b.width())));
T overflow = factory.unsigned_greater_than(b, factory.constant(RTLIL::Const(y_width, b_width)), b_width); return factory.mux(lower_b, factory.constant(RTLIL::Const(y_width, new_width)), overflow);
return factory.mux(lower_b, factory.constant(RTLIL::Const(y_width, new_width)), overflow, new_width);
} }
} }
T sign(T a, int a_width) { Node sign(Node a) {
return factory.slice(a, a_width, a_width - 1, 1); return factory.slice(a, a.width() - 1, 1);
} }
T neg_if(T a, int a_width, T s) { Node neg_if(Node a, Node s) {
return factory.mux(a, factory.unary_minus(a, a_width), s, a_width); return factory.mux(a, factory.unary_minus(a), s);
} }
T abs(T a, int a_width) { Node abs(Node a) {
return neg_if(a, a_width, sign(a, a_width)); return neg_if(a, sign(a));
} }
public: public:
T extend(T a, int in_width, int out_width, bool is_signed) { Node logical_shift_left(Node a, Node b) {
if(in_width == out_width) Node reduced_b = reduce_shift_width(b, a.width());
return a; return factory.logical_shift_left(a, reduced_b);
if(in_width > out_width)
return factory.slice(a, in_width, 0, out_width);
return factory.extend(a, in_width, out_width, is_signed);
} }
T logical_shift_left(T a, T b, int y_width, int b_width) { Node logical_shift_right(Node a, Node b) {
int reduced_b_width; Node reduced_b = reduce_shift_width(b, a.width());
T reduced_b = reduce_shift_width(b, b_width, y_width, reduced_b_width); return factory.logical_shift_right(a, reduced_b);
return factory.logical_shift_left(a, reduced_b, y_width, reduced_b_width);
} }
T logical_shift_right(T a, T b, int y_width, int b_width) { Node arithmetic_shift_right(Node a, Node b) {
int reduced_b_width; Node reduced_b = reduce_shift_width(b, a.width());
T reduced_b = reduce_shift_width(b, b_width, y_width, reduced_b_width); return factory.arithmetic_shift_right(a, reduced_b);
return factory.logical_shift_right(a, reduced_b, y_width, reduced_b_width);
} }
T arithmetic_shift_right(T a, T b, int y_width, int b_width) { Node bitwise_mux(Node a, Node b, Node s) {
int reduced_b_width; Node aa = factory.bitwise_and(a, factory.bitwise_not(s));
T reduced_b = reduce_shift_width(b, b_width, y_width, reduced_b_width); Node bb = factory.bitwise_and(b, s);
return factory.arithmetic_shift_right(a, reduced_b, y_width, reduced_b_width); return factory.bitwise_or(aa, bb);
} }
T bitwise_mux(T a, T b, T s, int width) { CellSimplifier(FunctionalIR::Factory &f) : factory(f) {}
T aa = factory.bitwise_and(a, factory.bitwise_not(s, width), width);
T bb = factory.bitwise_and(b, s, width);
return factory.bitwise_or(aa, bb, width);
}
CellSimplifier(Factory &f) : factory(f) {}
private: private:
T handle_pow(T a0, int a_width, T b, int b_width, int y_width, bool is_signed) { Node handle_pow(Node a0, Node b, int y_width, bool is_signed) {
T a = extend(a0, a_width, y_width, is_signed); Node a = factory.extend(a0, y_width, is_signed);
T r = factory.constant(Const(1, y_width)); Node r = factory.constant(Const(1, y_width));
for(int i = 0; i < b_width; i++) { for(int i = 0; i < b.width(); i++) {
T b_bit = factory.slice(b, b_width, i, 1); Node b_bit = factory.slice(b, i, 1);
r = factory.mux(r, factory.mul(r, a, y_width), b_bit, y_width); r = factory.mux(r, factory.mul(r, a), b_bit);
a = factory.mul(a, a, y_width); a = factory.mul(a, a);
} }
if (is_signed) { if (is_signed) {
T a_ge_1 = factory.unsigned_greater_than(abs(a0, a_width), factory.constant(Const(1, a_width)), a_width); Node a_ge_1 = factory.unsigned_greater_than(abs(a0), factory.constant(Const(1, a0.width())));
T zero_result = factory.bitwise_and(a_ge_1, sign(b, b_width), 1); Node zero_result = factory.bitwise_and(a_ge_1, sign(b));
r = factory.mux(r, factory.constant(Const(0, y_width)), zero_result, y_width); r = factory.mux(r, factory.constant(Const(0, y_width)), zero_result);
} }
return r; return r;
} }
T handle_bmux(T a, T s, int a_width, int a_offset, int width, int s_width, int sn) { Node handle_bmux(Node a, Node s, int a_offset, int width, int sn) {
if(sn < 1) if(sn < 1)
return factory.slice(a, a_width, a_offset, width); return factory.slice(a, a_offset, width);
else { else {
T y0 = handle_bmux(a, s, a_width, a_offset, width, s_width, sn - 1); Node y0 = handle_bmux(a, s, a_offset, width, sn - 1);
T y1 = handle_bmux(a, s, a_width, a_offset + (width << (sn - 1)), width, s_width, sn - 1); Node y1 = handle_bmux(a, s, a_offset + (width << (sn - 1)), width, sn - 1);
return factory.mux(y0, y1, factory.slice(s, s_width, sn - 1, 1), width); return factory.mux(y0, y1, factory.slice(s, sn - 1, 1));
} }
} }
public: public:
T handle(IdString cellType, dict<IdString, Const> parameters, dict<IdString, T> inputs) Node handle(IdString cellType, dict<IdString, Const> parameters, dict<IdString, Node> inputs)
{ {
int a_width = parameters.at(ID(A_WIDTH), Const(-1)).as_int(); int a_width = parameters.at(ID(A_WIDTH), Const(-1)).as_int();
int b_width = parameters.at(ID(B_WIDTH), Const(-1)).as_int(); int b_width = parameters.at(ID(B_WIDTH), Const(-1)).as_int();
@ -187,208 +175,202 @@ public:
bool b_signed = parameters.at(ID(B_SIGNED), Const(0)).as_bool(); bool b_signed = parameters.at(ID(B_SIGNED), Const(0)).as_bool();
if(cellType.in({ID($add), ID($sub), ID($and), ID($or), ID($xor), ID($xnor), ID($mul)})){ if(cellType.in({ID($add), ID($sub), ID($and), ID($or), ID($xor), ID($xnor), ID($mul)})){
bool is_signed = a_signed && b_signed; bool is_signed = a_signed && b_signed;
T a = extend(inputs.at(ID(A)), a_width, y_width, is_signed); Node a = factory.extend(inputs.at(ID(A)), y_width, is_signed);
T b = extend(inputs.at(ID(B)), b_width, y_width, is_signed); Node b = factory.extend(inputs.at(ID(B)), y_width, is_signed);
if(cellType == ID($add)) if(cellType == ID($add))
return factory.add(a, b, y_width); return factory.add(a, b);
else if(cellType == ID($sub)) else if(cellType == ID($sub))
return factory.sub(a, b, y_width); return factory.sub(a, b);
else if(cellType == ID($mul)) else if(cellType == ID($mul))
return factory.mul(a, b, y_width); return factory.mul(a, b);
else if(cellType == ID($and)) else if(cellType == ID($and))
return factory.bitwise_and(a, b, y_width); return factory.bitwise_and(a, b);
else if(cellType == ID($or)) else if(cellType == ID($or))
return factory.bitwise_or(a, b, y_width); return factory.bitwise_or(a, b);
else if(cellType == ID($xor)) else if(cellType == ID($xor))
return factory.bitwise_xor(a, b, y_width); return factory.bitwise_xor(a, b);
else if(cellType == ID($xnor)) else if(cellType == ID($xnor))
return factory.bitwise_not(factory.bitwise_xor(a, b, y_width), y_width); return factory.bitwise_not(factory.bitwise_xor(a, b));
else else
log_abort(); log_abort();
}else if(cellType.in({ID($eq), ID($ne), ID($eqx), ID($nex), ID($le), ID($lt), ID($ge), ID($gt)})){ }else if(cellType.in({ID($eq), ID($ne), ID($eqx), ID($nex), ID($le), ID($lt), ID($ge), ID($gt)})){
bool is_signed = a_signed && b_signed; bool is_signed = a_signed && b_signed;
int width = max(a_width, b_width); int width = max(a_width, b_width);
T a = extend(inputs.at(ID(A)), a_width, width, is_signed); Node a = factory.extend(inputs.at(ID(A)), width, is_signed);
T b = extend(inputs.at(ID(B)), b_width, width, is_signed); Node b = factory.extend(inputs.at(ID(B)), width, is_signed);
if(cellType.in({ID($eq), ID($eqx)})) if(cellType.in({ID($eq), ID($eqx)}))
return extend(factory.equal(a, b, width), 1, y_width, false); return factory.extend(factory.equal(a, b), y_width, false);
else if(cellType.in({ID($ne), ID($nex)})) else if(cellType.in({ID($ne), ID($nex)}))
return extend(factory.not_equal(a, b, width), 1, y_width, false); return factory.extend(factory.not_equal(a, b), y_width, false);
else if(cellType == ID($lt)) else if(cellType == ID($lt))
return extend(is_signed ? factory.signed_greater_than(b, a, width) : factory.unsigned_greater_than(b, a, width), 1, y_width, false); return factory.extend(is_signed ? factory.signed_greater_than(b, a) : factory.unsigned_greater_than(b, a), y_width, false);
else if(cellType == ID($le)) else if(cellType == ID($le))
return extend(is_signed ? factory.signed_greater_equal(b, a, width) : factory.unsigned_greater_equal(b, a, width), 1, y_width, false); return factory.extend(is_signed ? factory.signed_greater_equal(b, a) : factory.unsigned_greater_equal(b, a), y_width, false);
else if(cellType == ID($gt)) else if(cellType == ID($gt))
return extend(is_signed ? factory.signed_greater_than(a, b, width) : factory.unsigned_greater_than(a, b, width), 1, y_width, false); return factory.extend(is_signed ? factory.signed_greater_than(a, b) : factory.unsigned_greater_than(a, b), y_width, false);
else if(cellType == ID($ge)) else if(cellType == ID($ge))
return extend(is_signed ? factory.signed_greater_equal(a, b, width) : factory.unsigned_greater_equal(a, b, width), 1, y_width, false); return factory.extend(is_signed ? factory.signed_greater_equal(a, b) : factory.unsigned_greater_equal(a, b), y_width, false);
else else
log_abort(); log_abort();
}else if(cellType.in({ID($logic_or), ID($logic_and)})){ }else if(cellType.in({ID($logic_or), ID($logic_and)})){
T a = factory.reduce_or(inputs.at(ID(A)), a_width); Node a = factory.reduce_or(inputs.at(ID(A)));
T b = factory.reduce_or(inputs.at(ID(B)), b_width); Node b = factory.reduce_or(inputs.at(ID(B)));
T y = cellType == ID($logic_and) ? factory.bitwise_and(a, b, 1) : factory.bitwise_or(a, b, 1); Node y = cellType == ID($logic_and) ? factory.bitwise_and(a, b) : factory.bitwise_or(a, b);
return extend(y, 1, y_width, false); return factory.extend(y, y_width, false);
}else if(cellType == ID($not)){ }else if(cellType == ID($not)){
T a = extend(inputs.at(ID(A)), a_width, y_width, a_signed); Node a = factory.extend(inputs.at(ID(A)), y_width, a_signed);
return factory.bitwise_not(a, y_width); return factory.bitwise_not(a);
}else if(cellType == ID($pos)){ }else if(cellType == ID($pos)){
return extend(inputs.at(ID(A)), a_width, y_width, a_signed); return factory.extend(inputs.at(ID(A)), y_width, a_signed);
}else if(cellType == ID($neg)){ }else if(cellType == ID($neg)){
T a = extend(inputs.at(ID(A)), a_width, y_width, a_signed); Node a = factory.extend(inputs.at(ID(A)), y_width, a_signed);
return factory.unary_minus(a, y_width); return factory.unary_minus(a);
}else if(cellType == ID($logic_not)){ }else if(cellType == ID($logic_not)){
T a = factory.reduce_or(inputs.at(ID(A)), a_width); Node a = factory.reduce_or(inputs.at(ID(A)));
T y = factory.bitwise_not(a, 1); Node y = factory.bitwise_not(a);
return extend(y, 1, y_width, false); return factory.extend(y, y_width, false);
}else if(cellType.in({ID($reduce_or), ID($reduce_bool)})){ }else if(cellType.in({ID($reduce_or), ID($reduce_bool)})){
T a = factory.reduce_or(inputs.at(ID(A)), a_width); Node a = factory.reduce_or(inputs.at(ID(A)));
return extend(a, 1, y_width, false); return factory.extend(a, y_width, false);
}else if(cellType == ID($reduce_and)){ }else if(cellType == ID($reduce_and)){
T a = factory.reduce_and(inputs.at(ID(A)), a_width); Node a = factory.reduce_and(inputs.at(ID(A)));
return extend(a, 1, y_width, false); return factory.extend(a, y_width, false);
}else if(cellType.in({ID($reduce_xor), ID($reduce_xnor)})){ }else if(cellType.in({ID($reduce_xor), ID($reduce_xnor)})){
T a = factory.reduce_xor(inputs.at(ID(A)), a_width); Node a = factory.reduce_xor(inputs.at(ID(A)));
T y = cellType == ID($reduce_xnor) ? factory.bitwise_not(a, 1) : a; Node y = cellType == ID($reduce_xnor) ? factory.bitwise_not(a) : a;
return extend(y, 1, y_width, false); return factory.extend(y, y_width, false);
}else if(cellType == ID($shl) || cellType == ID($sshl)){ }else if(cellType == ID($shl) || cellType == ID($sshl)){
T a = extend(inputs.at(ID(A)), a_width, y_width, a_signed); Node a = factory.extend(inputs.at(ID(A)), y_width, a_signed);
T b = inputs.at(ID(B)); Node b = inputs.at(ID(B));
return logical_shift_left(a, b, y_width, b_width); return logical_shift_left(a, b);
}else if(cellType == ID($shr) || cellType == ID($sshr)){ }else if(cellType == ID($shr) || cellType == ID($sshr)){
int width = max(a_width, y_width); int width = max(a_width, y_width);
T a = extend(inputs.at(ID(A)), a_width, width, a_signed); Node a = factory.extend(inputs.at(ID(A)), width, a_signed);
T b = inputs.at(ID(B)); Node b = inputs.at(ID(B));
T y = a_signed && cellType == ID($sshr) ? Node y = a_signed && cellType == ID($sshr) ?
arithmetic_shift_right(a, b, width, b_width) : arithmetic_shift_right(a, b) :
logical_shift_right(a, b, width, b_width); logical_shift_right(a, b);
return extend(y, width, y_width, a_signed); return factory.extend(y, y_width, a_signed);
}else if(cellType == ID($shiftx) || cellType == ID($shift)){ }else if(cellType == ID($shiftx) || cellType == ID($shift)){
int width = max(a_width, y_width); int width = max(a_width, y_width);
T a = extend(inputs.at(ID(A)), a_width, width, cellType == ID($shift) && a_signed); Node a = factory.extend(inputs.at(ID(A)), width, cellType == ID($shift) && a_signed);
T b = inputs.at(ID(B)); Node b = inputs.at(ID(B));
T shr = logical_shift_right(a, b, width, b_width); Node shr = logical_shift_right(a, b);
if(b_signed) { if(b_signed) {
T sign_b = sign(b, b_width); Node shl = logical_shift_left(a, factory.unary_minus(b));
T shl = logical_shift_left(a, factory.unary_minus(b, b_width), width, b_width); Node y = factory.mux(shr, shl, sign(b));
T y = factory.mux(shr, shl, sign_b, width); return factory.extend(y, y_width, false);
return extend(y, width, y_width, false);
} else { } else {
return extend(shr, width, y_width, false); return factory.extend(shr, y_width, false);
} }
}else if(cellType == ID($mux)){ }else if(cellType == ID($mux)){
int width = parameters.at(ID(WIDTH)).as_int(); return factory.mux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S)));
return factory.mux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S)), width);
}else if(cellType == ID($pmux)){ }else if(cellType == ID($pmux)){
int width = parameters.at(ID(WIDTH)).as_int(); return factory.pmux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S)));
int s_width = parameters.at(ID(S_WIDTH)).as_int();
return factory.pmux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S)), width, s_width);
}else if(cellType == ID($concat)){ }else if(cellType == ID($concat)){
T a = inputs.at(ID(A)); Node a = inputs.at(ID(A));
T b = inputs.at(ID(B)); Node b = inputs.at(ID(B));
return factory.concat(a, a_width, b, b_width); return factory.concat(a, b);
}else if(cellType == ID($slice)){ }else if(cellType == ID($slice)){
int offset = parameters.at(ID(OFFSET)).as_int(); int offset = parameters.at(ID(OFFSET)).as_int();
T a = inputs.at(ID(A)); Node a = inputs.at(ID(A));
return factory.slice(a, a_width, offset, y_width); return factory.slice(a, offset, y_width);
}else if(cellType.in({ID($div), ID($mod), ID($divfloor), ID($modfloor)})) { }else if(cellType.in({ID($div), ID($mod), ID($divfloor), ID($modfloor)})) {
int width = max(a_width, b_width); int width = max(a_width, b_width);
bool is_signed = a_signed && b_signed; bool is_signed = a_signed && b_signed;
T a = extend(inputs.at(ID(A)), a_width, width, is_signed); Node a = factory.extend(inputs.at(ID(A)), width, is_signed);
T b = extend(inputs.at(ID(B)), b_width, width, is_signed); Node b = factory.extend(inputs.at(ID(B)), width, is_signed);
if(is_signed) { if(is_signed) {
if(cellType == ID($div)) { if(cellType == ID($div)) {
// divide absolute values, then flip the sign if input signs differ // divide absolute values, then flip the sign if input signs differ
// but extend the width first, to handle the case (most negative value) / (-1) // but extend the width first, to handle the case (most negative value) / (-1)
T abs_y = factory.unsigned_div(abs(a, width), abs(b, width), width); Node abs_y = factory.unsigned_div(abs(a), abs(b));
T out_sign = factory.not_equal(sign(a, width), sign(b, width), 1); Node out_sign = factory.not_equal(sign(a), sign(b));
return neg_if(extend(abs_y, width, y_width, false), y_width, out_sign); return neg_if(factory.extend(abs_y, y_width, false), out_sign);
} else if(cellType == ID($mod)) { } else if(cellType == ID($mod)) {
// similar to division but output sign == divisor sign // similar to division but output sign == divisor sign
T abs_y = factory.unsigned_mod(abs(a, width), abs(b, width), width); Node abs_y = factory.unsigned_mod(abs(a), abs(b));
return neg_if(extend(abs_y, width, y_width, false), y_width, sign(a, width)); return neg_if(factory.extend(abs_y, y_width, false), sign(a));
} else if(cellType == ID($divfloor)) { } else if(cellType == ID($divfloor)) {
// if b is negative, flip both signs so that b is positive // if b is negative, flip both signs so that b is positive
T b_sign = sign(b, width); Node b_sign = sign(b);
T a1 = neg_if(a, width, b_sign); Node a1 = neg_if(a, b_sign);
T b1 = neg_if(b, width, b_sign); Node b1 = neg_if(b, b_sign);
// if a is now negative, calculate ~((~a) / b) = -((-a - 1) / b + 1) // if a is now negative, calculate ~((~a) / b) = -((-a - 1) / b + 1)
// which equals the negative of (-a) / b with rounding up rather than down // which equals the negative of (-a) / b with rounding up rather than down
// note that to handle the case where a = most negative value properly, // note that to handle the case where a = most negative value properly,
// we have to calculate a1_sign from the original values rather than using sign(a1, width) // we have to calculate a1_sign from the original values rather than using sign(a1)
T a1_sign = factory.bitwise_and(factory.not_equal(sign(a, width), sign(b, width), 1), factory.reduce_or(a, width), 1); Node a1_sign = factory.bitwise_and(factory.not_equal(sign(a), sign(b)), factory.reduce_or(a));
T a2 = factory.mux(a1, factory.bitwise_not(a1, width), a1_sign, width); Node a2 = factory.mux(a1, factory.bitwise_not(a1), a1_sign);
T y1 = factory.unsigned_div(a2, b1, width); Node y1 = factory.unsigned_div(a2, b1);
T y2 = extend(y1, width, y_width, false); Node y2 = factory.extend(y1, y_width, false);
return factory.mux(y2, factory.bitwise_not(y2, y_width), a1_sign, y_width); return factory.mux(y2, factory.bitwise_not(y2), a1_sign);
} else if(cellType == ID($modfloor)) { } else if(cellType == ID($modfloor)) {
// calculate |a| % |b| and then subtract from |b| if input signs differ and the remainder is non-zero // calculate |a| % |b| and then subtract from |b| if input signs differ and the remainder is non-zero
T abs_b = abs(b, width); Node abs_b = abs(b);
T abs_y = factory.unsigned_mod(abs(a, width), abs_b, width); Node abs_y = factory.unsigned_mod(abs(a), abs_b);
T flip_y = factory.bitwise_and(factory.bitwise_xor(sign(a, width), sign(b, width), 1), factory.reduce_or(abs_y, width), 1); Node flip_y = factory.bitwise_and(factory.bitwise_xor(sign(a), sign(b)), factory.reduce_or(abs_y));
T y_flipped = factory.mux(abs_y, factory.sub(abs_b, abs_y, width), flip_y, width); Node y_flipped = factory.mux(abs_y, factory.sub(abs_b, abs_y), flip_y);
// since y_flipped is strictly less than |b|, the top bit is always 0 and we can just sign extend the flipped result // since y_flipped is strictly less than |b|, the top bit is always 0 and we can just sign extend the flipped result
T y = neg_if(y_flipped, width, sign(b, b_width)); Node y = neg_if(y_flipped, sign(b));
return extend(y, width, y_width, true); return factory.extend(y, y_width, true);
} else } else
log_error("unhandled cell in CellSimplifier %s\n", cellType.c_str()); log_error("unhandled cell in CellSimplifier %s\n", cellType.c_str());
} else { } else {
if(cellType.in({ID($mod), ID($modfloor)})) if(cellType.in({ID($mod), ID($modfloor)}))
return extend(factory.unsigned_mod(a, b, width), width, y_width, false); return factory.extend(factory.unsigned_mod(a, b), y_width, false);
else else
return extend(factory.unsigned_div(a, b, width), width, y_width, false); return factory.extend(factory.unsigned_div(a, b), y_width, false);
} }
} else if(cellType == ID($pow)) { } else if(cellType == ID($pow)) {
return handle_pow(inputs.at(ID(A)), a_width, inputs.at(ID(B)), b_width, y_width, a_signed && b_signed); return handle_pow(inputs.at(ID(A)), inputs.at(ID(B)), y_width, a_signed && b_signed);
} else if (cellType == ID($lut)) { } else if (cellType == ID($lut)) {
int width = parameters.at(ID(WIDTH)).as_int(); int width = parameters.at(ID(WIDTH)).as_int();
Const lut_table = parameters.at(ID(LUT)); Const lut_table = parameters.at(ID(LUT));
lut_table.extu(1 << width); lut_table.extu(1 << width);
return handle_bmux(factory.constant(lut_table), inputs.at(ID(A)), 1 << width, 0, 1, width, width); return handle_bmux(factory.constant(lut_table), inputs.at(ID(A)), 0, 1, width);
} else if (cellType == ID($bwmux)) { } else if (cellType == ID($bwmux)) {
int width = parameters.at(ID(WIDTH)).as_int(); Node a = inputs.at(ID(A));
T a = inputs.at(ID(A)); Node b = inputs.at(ID(B));
T b = inputs.at(ID(B)); Node s = inputs.at(ID(S));
T s = inputs.at(ID(S));
return factory.bitwise_or( return factory.bitwise_or(
factory.bitwise_and(a, factory.bitwise_not(s, width), width), factory.bitwise_and(a, factory.bitwise_not(s)),
factory.bitwise_and(b, s, width), width); factory.bitwise_and(b, s));
} else if (cellType == ID($bweqx)) { } else if (cellType == ID($bweqx)) {
int width = parameters.at(ID(WIDTH)).as_int(); Node a = inputs.at(ID(A));
T a = inputs.at(ID(A)); Node b = inputs.at(ID(B));
T b = inputs.at(ID(B)); return factory.bitwise_not(factory.bitwise_xor(a, b));
return factory.bitwise_not(factory.bitwise_xor(a, b, width), width);
} else if(cellType == ID($bmux)) { } else if(cellType == ID($bmux)) {
int width = parameters.at(ID(WIDTH)).as_int(); int width = parameters.at(ID(WIDTH)).as_int();
int s_width = parameters.at(ID(S_WIDTH)).as_int(); int s_width = parameters.at(ID(S_WIDTH)).as_int();
return handle_bmux(inputs.at(ID(A)), inputs.at(ID(S)), width << s_width, 0, width, s_width, s_width); return handle_bmux(inputs.at(ID(A)), inputs.at(ID(S)), 0, width, s_width);
} else if(cellType == ID($demux)) { } else if(cellType == ID($demux)) {
int width = parameters.at(ID(WIDTH)).as_int(); int width = parameters.at(ID(WIDTH)).as_int();
int s_width = parameters.at(ID(S_WIDTH)).as_int(); int s_width = parameters.at(ID(S_WIDTH)).as_int();
int y_width = width << s_width; int y_width = width << s_width;
int b_width = ceil_log2(y_width + 1); int b_width = ceil_log2(y_width + 1);
T a = extend(inputs.at(ID(A)), width, y_width, false); Node a = factory.extend(inputs.at(ID(A)), y_width, false);
T s = factory.extend(inputs.at(ID(S)), s_width, b_width, false); Node s = factory.extend(inputs.at(ID(S)), b_width, false);
T b = factory.mul(s, factory.constant(Const(width, b_width)), b_width); Node b = factory.mul(s, factory.constant(Const(width, b_width)));
return factory.logical_shift_left(a, b, y_width, b_width); return factory.logical_shift_left(a, b);
} else { } else {
log_error("unhandled cell in CellSimplifier %s\n", cellType.c_str()); log_error("unhandled cell in CellSimplifier %s\n", cellType.c_str());
} }
} }
}; };
template <class T, class Factory>
class FunctionalIRConstruction { class FunctionalIRConstruction {
using Node = FunctionalIR::Node;
std::deque<DriveSpec> queue; std::deque<DriveSpec> queue;
dict<DriveSpec, T> graph_nodes; dict<DriveSpec, Node> graph_nodes;
idict<Cell *> cells; idict<Cell *> cells;
DriverMap driver_map; DriverMap driver_map;
Factory& factory; FunctionalIR::Factory& factory;
CellSimplifier<T, Factory> simplifier; CellSimplifier simplifier;
vector<Mem> memories_vector; vector<Mem> memories_vector;
dict<Cell*, Mem*> memories; dict<Cell*, Mem*> memories;
T enqueue(DriveSpec const &spec) Node enqueue(DriveSpec const &spec)
{ {
auto it = graph_nodes.find(spec); auto it = graph_nodes.find(spec);
if(it == graph_nodes.end()){ if(it == graph_nodes.end()){
@ -400,7 +382,7 @@ class FunctionalIRConstruction {
return it->second; return it->second;
} }
public: public:
FunctionalIRConstruction(Factory &f) : factory(f), simplifier(f) {} FunctionalIRConstruction(FunctionalIR::Factory &f) : factory(f), simplifier(f) {}
void add_module(Module *module) void add_module(Module *module)
{ {
driver_map.add(module); driver_map.add(module);
@ -410,7 +392,7 @@ public:
} }
for (auto wire : module->wires()) { for (auto wire : module->wires()) {
if (wire->port_output) { if (wire->port_output) {
T node = enqueue(DriveChunk(DriveChunkWire(wire, 0, wire->width))); Node node = enqueue(DriveChunk(DriveChunkWire(wire, 0, wire->width)));
factory.declare_output(node, wire->name, wire->width); factory.declare_output(node, wire->name, wire->width);
} }
} }
@ -420,37 +402,34 @@ public:
memories[mem.cell] = &mem; memories[mem.cell] = &mem;
} }
} }
T concatenate_read_results(Mem *, vector<T> results) Node concatenate_read_results(Mem *, vector<Node> results)
{ {
/* TODO: write code to check that this is ok to do */ /* TODO: write code to check that this is ok to do */
if(results.size() == 0) if(results.size() == 0)
return factory.undriven(0); return factory.undriven(0);
T node = results[0]; Node node = results[0];
int size = results[0].width(); for(size_t i = 1; i < results.size(); i++)
for(size_t i = 1; i < results.size(); i++) { node = factory.concat(node, results[i]);
node = factory.concat(node, size, results[i], results[i].width());
size += results[i].width();
}
return node; return node;
} }
T handle_memory(Mem *mem) Node handle_memory(Mem *mem)
{ {
vector<T> read_results; vector<Node> read_results;
int addr_width = ceil_log2(mem->size); int addr_width = ceil_log2(mem->size);
int data_width = mem->width; int data_width = mem->width;
T node = factory.state_memory(mem->cell->name, addr_width, data_width); Node node = factory.state_memory(mem->cell->name, addr_width, data_width);
for (auto &rd : mem->rd_ports) { for (auto &rd : mem->rd_ports) {
log_assert(!rd.clk_enable); log_assert(!rd.clk_enable);
T addr = enqueue(driver_map(DriveSpec(rd.addr))); Node addr = enqueue(driver_map(DriveSpec(rd.addr)));
read_results.push_back(factory.memory_read(node, addr, addr_width, data_width)); read_results.push_back(factory.memory_read(node, addr));
} }
for (auto &wr : mem->wr_ports) { for (auto &wr : mem->wr_ports) {
T en = enqueue(driver_map(DriveSpec(wr.en))); Node en = enqueue(driver_map(DriveSpec(wr.en)));
T addr = enqueue(driver_map(DriveSpec(wr.addr))); Node addr = enqueue(driver_map(DriveSpec(wr.addr)));
T new_data = enqueue(driver_map(DriveSpec(wr.data))); Node new_data = enqueue(driver_map(DriveSpec(wr.data)));
T old_data = factory.memory_read(node, addr, addr_width, data_width); Node old_data = factory.memory_read(node, addr);
T wr_data = simplifier.bitwise_mux(old_data, new_data, en, data_width); Node wr_data = simplifier.bitwise_mux(old_data, new_data, en);
node = factory.memory_write(node, addr, wr_data, addr_width, data_width); node = factory.memory_write(node, addr, wr_data);
} }
factory.declare_state_memory(node, mem->cell->name, addr_width, data_width); factory.declare_state_memory(node, mem->cell->name, addr_width, data_width);
return concatenate_read_results(mem, read_results); return concatenate_read_results(mem, read_results);
@ -459,16 +438,13 @@ public:
{ {
for (; !queue.empty(); queue.pop_front()) { for (; !queue.empty(); queue.pop_front()) {
DriveSpec spec = queue.front(); DriveSpec spec = queue.front();
T pending = graph_nodes.at(spec); Node pending = graph_nodes.at(spec);
if (spec.chunks().size() > 1) { if (spec.chunks().size() > 1) {
auto chunks = spec.chunks(); auto chunks = spec.chunks();
T node = enqueue(chunks[0]); Node node = enqueue(chunks[0]);
int width = chunks[0].size(); for(size_t i = 1; i < chunks.size(); i++)
for(size_t i = 1; i < chunks.size(); i++) { node = factory.concat(node, enqueue(chunks[i]));
node = factory.concat(node, width, enqueue(chunks[i]), chunks[i].size());
width += chunks[i].size();
}
factory.update_pending(pending, node); factory.update_pending(pending, node);
} else if (spec.chunks().size() == 1) { } else if (spec.chunks().size() == 1) {
DriveChunk chunk = spec.chunks()[0]; DriveChunk chunk = spec.chunks()[0];
@ -476,18 +452,18 @@ public:
DriveChunkWire wire_chunk = chunk.wire(); DriveChunkWire wire_chunk = chunk.wire();
if (wire_chunk.is_whole()) { if (wire_chunk.is_whole()) {
if (wire_chunk.wire->port_input) { if (wire_chunk.wire->port_input) {
T node = factory.input(wire_chunk.wire->name, wire_chunk.width); Node node = factory.input(wire_chunk.wire->name, wire_chunk.width);
factory.suggest_name(node, wire_chunk.wire->name); factory.suggest_name(node, wire_chunk.wire->name);
factory.update_pending(pending, node); factory.update_pending(pending, node);
} else { } else {
DriveSpec driver = driver_map(DriveSpec(wire_chunk)); DriveSpec driver = driver_map(DriveSpec(wire_chunk));
T node = enqueue(driver); Node node = enqueue(driver);
factory.suggest_name(node, wire_chunk.wire->name); factory.suggest_name(node, wire_chunk.wire->name);
factory.update_pending(pending, node); factory.update_pending(pending, node);
} }
} else { } else {
DriveChunkWire whole_wire(wire_chunk.wire, 0, wire_chunk.wire->width); DriveChunkWire whole_wire(wire_chunk.wire, 0, wire_chunk.wire->width);
T node = factory.slice(enqueue(whole_wire), wire_chunk.wire->width, wire_chunk.offset, wire_chunk.width); Node node = factory.slice(enqueue(whole_wire), wire_chunk.offset, wire_chunk.width);
factory.update_pending(pending, node); factory.update_pending(pending, node);
} }
} else if (chunk.is_port()) { } else if (chunk.is_port()) {
@ -497,21 +473,22 @@ public:
if (port_chunk.cell->type.in(ID($dff), ID($ff))) if (port_chunk.cell->type.in(ID($dff), ID($ff)))
{ {
Cell *cell = port_chunk.cell; Cell *cell = port_chunk.cell;
T node = factory.state(cell->name, port_chunk.width); Node node = factory.state(cell->name, port_chunk.width);
factory.suggest_name(node, port_chunk.cell->name); factory.suggest_name(node, port_chunk.cell->name);
factory.update_pending(pending, node); factory.update_pending(pending, node);
for (auto const &conn : cell->connections()) { for (auto const &conn : cell->connections()) {
if (driver_map.celltypes.cell_input(cell->type, conn.first)) { if (driver_map.celltypes.cell_input(cell->type, conn.first)) {
T node = enqueue(DriveChunkPort(cell, conn)); Node node = enqueue(DriveChunkPort(cell, conn));
factory.declare_state(node, cell->name, port_chunk.width); factory.declare_state(node, cell->name, port_chunk.width);
} }
} }
} }
else else
{ {
T cell = enqueue(DriveChunkMarker(cells(port_chunk.cell), 0, port_chunk.width)); Node cell = enqueue(DriveChunkMarker(cells(port_chunk.cell), 0, port_chunk.width));
factory.suggest_name(cell, port_chunk.cell->name); factory.suggest_name(cell, port_chunk.cell->name);
T node = factory.cell_output(cell, port_chunk.cell->type, port_chunk.port, port_chunk.width); //Node node = factory.cell_output(cell, port_chunk.cell->type, port_chunk.port, port_chunk.width);
Node node = cell;
factory.suggest_name(node, port_chunk.cell->name.str() + "$" + port_chunk.port.str()); factory.suggest_name(node, port_chunk.cell->name.str() + "$" + port_chunk.port.str());
factory.update_pending(pending, node); factory.update_pending(pending, node);
} }
@ -521,37 +498,37 @@ public:
} }
} else { } else {
DriveChunkPort whole_port(port_chunk.cell, port_chunk.port, 0, GetSize(port_chunk.cell->connections().at(port_chunk.port))); DriveChunkPort whole_port(port_chunk.cell, port_chunk.port, 0, GetSize(port_chunk.cell->connections().at(port_chunk.port)));
T node = factory.slice(enqueue(whole_port), whole_port.width, port_chunk.offset, port_chunk.width); Node node = factory.slice(enqueue(whole_port), port_chunk.offset, port_chunk.width);
factory.update_pending(pending, node); factory.update_pending(pending, node);
} }
} else if (chunk.is_constant()) { } else if (chunk.is_constant()) {
T node = factory.constant(chunk.constant()); Node node = factory.constant(chunk.constant());
factory.suggest_name(node, "$const" + std::to_string(chunk.size()) + "b" + chunk.constant().as_string()); factory.suggest_name(node, "$const" + std::to_string(chunk.size()) + "b" + chunk.constant().as_string());
factory.update_pending(pending, node); factory.update_pending(pending, node);
} else if (chunk.is_multiple()) { } else if (chunk.is_multiple()) {
vector<T> args; vector<Node> args;
for (auto const &driver : chunk.multiple().multiple()) for (auto const &driver : chunk.multiple().multiple())
args.push_back(enqueue(driver)); args.push_back(enqueue(driver));
T node = factory.multiple(args, chunk.size()); Node node = factory.multiple(args, chunk.size());
factory.update_pending(pending, node); factory.update_pending(pending, node);
} else if (chunk.is_marker()) { } else if (chunk.is_marker()) {
Cell *cell = cells[chunk.marker().marker]; Cell *cell = cells[chunk.marker().marker];
if (cell->is_mem_cell()) { if (cell->is_mem_cell()) {
Mem *mem = memories.at(cell, nullptr); Mem *mem = memories.at(cell, nullptr);
log_assert(mem != nullptr); log_assert(mem != nullptr);
T node = handle_memory(mem); Node node = handle_memory(mem);
factory.update_pending(pending, node); factory.update_pending(pending, node);
} else { } else {
dict<IdString, T> connections; dict<IdString, Node> connections;
for(auto const &conn : cell->connections()) { for(auto const &conn : cell->connections()) {
if(driver_map.celltypes.cell_input(cell->type, conn.first)) if(driver_map.celltypes.cell_input(cell->type, conn.first))
connections.insert({ conn.first, enqueue(DriveChunkPort(cell, conn)) }); connections.insert({ conn.first, enqueue(DriveChunkPort(cell, conn)) });
} }
T node = simplifier.handle(cell->type, cell->parameters, connections); Node node = simplifier.handle(cell->type, cell->parameters, connections);
factory.update_pending(pending, node); factory.update_pending(pending, node);
} }
} else if (chunk.is_none()) { } else if (chunk.is_none()) {
T node = factory.undriven(chunk.size()); Node node = factory.undriven(chunk.size());
factory.update_pending(pending, node); factory.update_pending(pending, node);
} else { } else {
log_error("unhandled drivespec: %s\n", log_signal(chunk)); log_error("unhandled drivespec: %s\n", log_signal(chunk));
@ -567,7 +544,7 @@ public:
FunctionalIR FunctionalIR::from_module(Module *module) { FunctionalIR FunctionalIR::from_module(Module *module) {
FunctionalIR ir; FunctionalIR ir;
auto factory = ir.factory(); auto factory = ir.factory();
FunctionalIRConstruction<FunctionalIR::Node, FunctionalIR::Factory> ctor(factory); FunctionalIRConstruction ctor(factory);
ctor.add_module(module); ctor.add_module(module);
ctor.process_queue(); ctor.process_queue();
ir.topological_sort(); ir.topological_sort();

View file

@ -385,78 +385,79 @@ public:
void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal()); } void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal()); }
void check_unary(Node const &a) { log_assert(a.sort().is_signal()); } void check_unary(Node const &a) { log_assert(a.sort().is_signal()); }
public: public:
Node slice(Node a, int, int offset, int out_width) { Node slice(Node a, int offset, int out_width) {
log_assert(a.sort().is_signal() && offset + out_width <= a.sort().width()); log_assert(a.sort().is_signal() && offset + out_width <= a.sort().width());
if(offset == 0 && out_width == a.width()) if(offset == 0 && out_width == a.width())
return a; return a;
return add(NodeData(Fn::slice, offset), Sort(out_width), {a}); return add(NodeData(Fn::slice, offset), Sort(out_width), {a});
} }
Node extend(Node a, int, int out_width, bool is_signed) { // extend will either extend or truncate the provided value to reach the desired width
Node extend(Node a, int out_width, bool is_signed) {
int in_width = a.sort().width(); int in_width = a.sort().width();
log_assert(a.sort().is_signal()); log_assert(a.sort().is_signal());
if(in_width == out_width) if(in_width == out_width)
return a; return a;
if(in_width < out_width) if(in_width > out_width)
return slice(a, in_width, 0, out_width); return slice(a, 0, out_width);
if(is_signed) if(is_signed)
return add(Fn::sign_extend, Sort(out_width), {a}); return add(Fn::sign_extend, Sort(out_width), {a});
else else
return add(Fn::zero_extend, Sort(out_width), {a}); return add(Fn::zero_extend, Sort(out_width), {a});
} }
Node concat(Node a, int, Node b, int) { Node concat(Node a, Node b) {
log_assert(a.sort().is_signal() && b.sort().is_signal()); log_assert(a.sort().is_signal() && b.sort().is_signal());
return add(Fn::concat, Sort(a.sort().width() + b.sort().width()), {a, b}); return add(Fn::concat, Sort(a.sort().width() + b.sort().width()), {a, b});
} }
Node add(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::add, a.sort(), {a, b}); } Node add(Node a, Node b) { check_basic_binary(a, b); return add(Fn::add, a.sort(), {a, b}); }
Node sub(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::sub, a.sort(), {a, b}); } Node sub(Node a, Node b) { check_basic_binary(a, b); return add(Fn::sub, a.sort(), {a, b}); }
Node mul(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::mul, a.sort(), {a, b}); } Node mul(Node a, Node b) { check_basic_binary(a, b); return add(Fn::mul, a.sort(), {a, b}); }
Node unsigned_div(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_div, a.sort(), {a, b}); } Node unsigned_div(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_div, a.sort(), {a, b}); }
Node unsigned_mod(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_mod, a.sort(), {a, b}); } Node unsigned_mod(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_mod, a.sort(), {a, b}); }
Node bitwise_and(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_and, a.sort(), {a, b}); } Node bitwise_and(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_and, a.sort(), {a, b}); }
Node bitwise_or(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_or, a.sort(), {a, b}); } Node bitwise_or(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_or, a.sort(), {a, b}); }
Node bitwise_xor(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_xor, a.sort(), {a, b}); } Node bitwise_xor(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_xor, a.sort(), {a, b}); }
Node bitwise_not(Node a, int) { check_unary(a); return add(Fn::bitwise_not, a.sort(), {a}); } Node bitwise_not(Node a) { check_unary(a); return add(Fn::bitwise_not, a.sort(), {a}); }
Node unary_minus(Node a, int) { check_unary(a); return add(Fn::unary_minus, a.sort(), {a}); } Node unary_minus(Node a) { check_unary(a); return add(Fn::unary_minus, a.sort(), {a}); }
Node reduce_and(Node a, int) { Node reduce_and(Node a) {
check_unary(a); check_unary(a);
if(a.width() == 1) if(a.width() == 1)
return a; return a;
return add(Fn::reduce_and, Sort(1), {a}); return add(Fn::reduce_and, Sort(1), {a});
} }
Node reduce_or(Node a, int) { Node reduce_or(Node a) {
check_unary(a); check_unary(a);
if(a.width() == 1) if(a.width() == 1)
return a; return a;
return add(Fn::reduce_or, Sort(1), {a}); return add(Fn::reduce_or, Sort(1), {a});
} }
Node reduce_xor(Node a, int) { Node reduce_xor(Node a) {
check_unary(a); check_unary(a);
if(a.width() == 1) if(a.width() == 1)
return a; return a;
return add(Fn::reduce_xor, Sort(1), {a}); return add(Fn::reduce_xor, Sort(1), {a});
} }
Node equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::equal, Sort(1), {a, b}); } Node equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::equal, Sort(1), {a, b}); }
Node not_equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::not_equal, Sort(1), {a, b}); } Node not_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::not_equal, Sort(1), {a, b}); }
Node signed_greater_than(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::signed_greater_than, Sort(1), {a, b}); } Node signed_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_than, Sort(1), {a, b}); }
Node signed_greater_equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::signed_greater_equal, Sort(1), {a, b}); } Node signed_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_equal, Sort(1), {a, b}); }
Node unsigned_greater_than(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_greater_than, Sort(1), {a, b}); } Node unsigned_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_than, Sort(1), {a, b}); }
Node unsigned_greater_equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_greater_equal, Sort(1), {a, b}); } Node unsigned_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_equal, Sort(1), {a, b}); }
Node logical_shift_left(Node a, Node b, int, int) { check_shift(a, b); return add(Fn::logical_shift_left, a.sort(), {a, b}); } Node logical_shift_left(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_left, a.sort(), {a, b}); }
Node logical_shift_right(Node a, Node b, int, int) { check_shift(a, b); return add(Fn::logical_shift_right, a.sort(), {a, b}); } Node logical_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_right, a.sort(), {a, b}); }
Node arithmetic_shift_right(Node a, Node b, int, int) { check_shift(a, b); return add(Fn::arithmetic_shift_right, a.sort(), {a, b}); } Node arithmetic_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::arithmetic_shift_right, a.sort(), {a, b}); }
Node mux(Node a, Node b, Node s, int) { Node mux(Node a, Node b, Node s) {
log_assert(a.sort().is_signal() && a.sort() == b.sort() && s.sort() == Sort(1)); log_assert(a.sort().is_signal() && a.sort() == b.sort() && s.sort() == Sort(1));
return add(Fn::mux, a.sort(), {a, b, s}); return add(Fn::mux, a.sort(), {a, b, s});
} }
Node pmux(Node a, Node b, Node s, int, int) { Node pmux(Node a, Node b, Node s) {
log_assert(a.sort().is_signal() && b.sort().is_signal() && s.sort().is_signal() && a.sort().width() * s.sort().width() == b.sort().width()); log_assert(a.sort().is_signal() && b.sort().is_signal() && s.sort().is_signal() && a.sort().width() * s.sort().width() == b.sort().width());
return add(Fn::pmux, a.sort(), {a, b, s}); return add(Fn::pmux, a.sort(), {a, b, s});
} }
Node memory_read(Node mem, Node addr, int, int) { Node memory_read(Node mem, Node addr) {
log_assert(mem.sort().is_memory() && addr.sort().is_signal() && mem.sort().addr_width() == addr.sort().width()); log_assert(mem.sort().is_memory() && addr.sort().is_signal() && mem.sort().addr_width() == addr.sort().width());
return add(Fn::memory_read, Sort(mem.sort().data_width()), {mem, addr}); return add(Fn::memory_read, Sort(mem.sort().data_width()), {mem, addr});
} }
Node memory_write(Node mem, Node addr, Node data, int, int) { Node memory_write(Node mem, Node addr, Node data) {
log_assert(mem.sort().is_memory() && addr.sort().is_signal() && data.sort().is_signal() && log_assert(mem.sort().is_memory() && addr.sort().is_signal() && data.sort().is_signal() &&
mem.sort().addr_width() == addr.sort().width() && mem.sort().data_width() == data.sort().width()); mem.sort().addr_width() == addr.sort().width() && mem.sort().data_width() == data.sort().width());
return add(Fn::memory_write, mem.sort(), {mem, addr, data}); return add(Fn::memory_write, mem.sort(), {mem, addr, data});
@ -484,9 +485,6 @@ public:
_ir.add_state(name, Sort(addr_width, data_width)); _ir.add_state(name, Sort(addr_width, data_width));
return add(NodeData(Fn::state, name), Sort(addr_width, data_width), {}); return add(NodeData(Fn::state, name), Sort(addr_width, data_width), {});
} }
Node cell_output(Node node, IdString, IdString, int) {
return node;
}
Node multiple(vector<Node> args, int width) { Node multiple(vector<Node> args, int width) {
auto node = add(Fn::multiple, Sort(width), {}); auto node = add(Fn::multiple, Sort(width), {});
for(const auto &arg : args) for(const auto &arg : args)