3
0
Fork 0
mirror of https://github.com/YosysHQ/yosys synced 2025-04-06 17:44:09 +00:00

remove widths parameters from FunctionalIR factory methods and from functionalir.cc

This commit is contained in:
Emily Schmidt 2024-07-17 12:42:24 +01:00
parent 55c2c17853
commit 7f8f21b980
2 changed files with 216 additions and 241 deletions

View file

@ -98,87 +98,75 @@ std::string FunctionalIR::Node::to_string(std::function<std::string(Node)> np)
return visit(PrintVisitor(np));
}
template <class T, class Factory>
class CellSimplifier {
Factory &factory;
T reduce_shift_width(T b, int b_width, int y_width, int &reduced_b_width) {
using Node = FunctionalIR::Node;
FunctionalIR::Factory &factory;
Node reduce_shift_width(Node b, int y_width) {
log_assert(y_width > 0);
int new_width = ceil_log2(y_width + 1);
if (b_width <= new_width) {
reduced_b_width = b_width;
if (b.width() <= new_width) {
return b;
} else {
reduced_b_width = new_width;
T lower_b = factory.slice(b, b_width, 0, new_width);
T overflow = factory.unsigned_greater_than(b, factory.constant(RTLIL::Const(y_width, b_width)), b_width);
return factory.mux(lower_b, factory.constant(RTLIL::Const(y_width, new_width)), overflow, new_width);
Node lower_b = factory.slice(b, 0, new_width);
Node overflow = factory.unsigned_greater_than(b, factory.constant(RTLIL::Const(y_width, b.width())));
return factory.mux(lower_b, factory.constant(RTLIL::Const(y_width, new_width)), overflow);
}
}
T sign(T a, int a_width) {
return factory.slice(a, a_width, a_width - 1, 1);
Node sign(Node a) {
return factory.slice(a, a.width() - 1, 1);
}
T neg_if(T a, int a_width, T s) {
return factory.mux(a, factory.unary_minus(a, a_width), s, a_width);
Node neg_if(Node a, Node s) {
return factory.mux(a, factory.unary_minus(a), s);
}
T abs(T a, int a_width) {
return neg_if(a, a_width, sign(a, a_width));
Node abs(Node a) {
return neg_if(a, sign(a));
}
public:
T extend(T a, int in_width, int out_width, bool is_signed) {
if(in_width == out_width)
return a;
if(in_width > out_width)
return factory.slice(a, in_width, 0, out_width);
return factory.extend(a, in_width, out_width, is_signed);
Node logical_shift_left(Node a, Node b) {
Node reduced_b = reduce_shift_width(b, a.width());
return factory.logical_shift_left(a, reduced_b);
}
T logical_shift_left(T a, T b, int y_width, int b_width) {
int reduced_b_width;
T reduced_b = reduce_shift_width(b, b_width, y_width, reduced_b_width);
return factory.logical_shift_left(a, reduced_b, y_width, reduced_b_width);
Node logical_shift_right(Node a, Node b) {
Node reduced_b = reduce_shift_width(b, a.width());
return factory.logical_shift_right(a, reduced_b);
}
T logical_shift_right(T a, T b, int y_width, int b_width) {
int reduced_b_width;
T reduced_b = reduce_shift_width(b, b_width, y_width, reduced_b_width);
return factory.logical_shift_right(a, reduced_b, y_width, reduced_b_width);
Node arithmetic_shift_right(Node a, Node b) {
Node reduced_b = reduce_shift_width(b, a.width());
return factory.arithmetic_shift_right(a, reduced_b);
}
T arithmetic_shift_right(T a, T b, int y_width, int b_width) {
int reduced_b_width;
T reduced_b = reduce_shift_width(b, b_width, y_width, reduced_b_width);
return factory.arithmetic_shift_right(a, reduced_b, y_width, reduced_b_width);
Node bitwise_mux(Node a, Node b, Node s) {
Node aa = factory.bitwise_and(a, factory.bitwise_not(s));
Node bb = factory.bitwise_and(b, s);
return factory.bitwise_or(aa, bb);
}
T bitwise_mux(T a, T b, T s, int width) {
T aa = factory.bitwise_and(a, factory.bitwise_not(s, width), width);
T bb = factory.bitwise_and(b, s, width);
return factory.bitwise_or(aa, bb, width);
}
CellSimplifier(Factory &f) : factory(f) {}
CellSimplifier(FunctionalIR::Factory &f) : factory(f) {}
private:
T handle_pow(T a0, int a_width, T b, int b_width, int y_width, bool is_signed) {
T a = extend(a0, a_width, y_width, is_signed);
T r = factory.constant(Const(1, y_width));
for(int i = 0; i < b_width; i++) {
T b_bit = factory.slice(b, b_width, i, 1);
r = factory.mux(r, factory.mul(r, a, y_width), b_bit, y_width);
a = factory.mul(a, a, y_width);
Node handle_pow(Node a0, Node b, int y_width, bool is_signed) {
Node a = factory.extend(a0, y_width, is_signed);
Node r = factory.constant(Const(1, y_width));
for(int i = 0; i < b.width(); i++) {
Node b_bit = factory.slice(b, i, 1);
r = factory.mux(r, factory.mul(r, a), b_bit);
a = factory.mul(a, a);
}
if (is_signed) {
T a_ge_1 = factory.unsigned_greater_than(abs(a0, a_width), factory.constant(Const(1, a_width)), a_width);
T zero_result = factory.bitwise_and(a_ge_1, sign(b, b_width), 1);
r = factory.mux(r, factory.constant(Const(0, y_width)), zero_result, y_width);
Node a_ge_1 = factory.unsigned_greater_than(abs(a0), factory.constant(Const(1, a0.width())));
Node zero_result = factory.bitwise_and(a_ge_1, sign(b));
r = factory.mux(r, factory.constant(Const(0, y_width)), zero_result);
}
return r;
}
T handle_bmux(T a, T s, int a_width, int a_offset, int width, int s_width, int sn) {
Node handle_bmux(Node a, Node s, int a_offset, int width, int sn) {
if(sn < 1)
return factory.slice(a, a_width, a_offset, width);
return factory.slice(a, a_offset, width);
else {
T y0 = handle_bmux(a, s, a_width, a_offset, width, s_width, sn - 1);
T y1 = handle_bmux(a, s, a_width, a_offset + (width << (sn - 1)), width, s_width, sn - 1);
return factory.mux(y0, y1, factory.slice(s, s_width, sn - 1, 1), width);
Node y0 = handle_bmux(a, s, a_offset, width, sn - 1);
Node y1 = handle_bmux(a, s, a_offset + (width << (sn - 1)), width, sn - 1);
return factory.mux(y0, y1, factory.slice(s, sn - 1, 1));
}
}
public:
T handle(IdString cellType, dict<IdString, Const> parameters, dict<IdString, T> inputs)
Node handle(IdString cellType, dict<IdString, Const> parameters, dict<IdString, Node> inputs)
{
int a_width = parameters.at(ID(A_WIDTH), Const(-1)).as_int();
int b_width = parameters.at(ID(B_WIDTH), Const(-1)).as_int();
@ -187,208 +175,202 @@ public:
bool b_signed = parameters.at(ID(B_SIGNED), Const(0)).as_bool();
if(cellType.in({ID($add), ID($sub), ID($and), ID($or), ID($xor), ID($xnor), ID($mul)})){
bool is_signed = a_signed && b_signed;
T a = extend(inputs.at(ID(A)), a_width, y_width, is_signed);
T b = extend(inputs.at(ID(B)), b_width, y_width, is_signed);
Node a = factory.extend(inputs.at(ID(A)), y_width, is_signed);
Node b = factory.extend(inputs.at(ID(B)), y_width, is_signed);
if(cellType == ID($add))
return factory.add(a, b, y_width);
return factory.add(a, b);
else if(cellType == ID($sub))
return factory.sub(a, b, y_width);
return factory.sub(a, b);
else if(cellType == ID($mul))
return factory.mul(a, b, y_width);
return factory.mul(a, b);
else if(cellType == ID($and))
return factory.bitwise_and(a, b, y_width);
return factory.bitwise_and(a, b);
else if(cellType == ID($or))
return factory.bitwise_or(a, b, y_width);
return factory.bitwise_or(a, b);
else if(cellType == ID($xor))
return factory.bitwise_xor(a, b, y_width);
return factory.bitwise_xor(a, b);
else if(cellType == ID($xnor))
return factory.bitwise_not(factory.bitwise_xor(a, b, y_width), y_width);
return factory.bitwise_not(factory.bitwise_xor(a, b));
else
log_abort();
}else if(cellType.in({ID($eq), ID($ne), ID($eqx), ID($nex), ID($le), ID($lt), ID($ge), ID($gt)})){
bool is_signed = a_signed && b_signed;
int width = max(a_width, b_width);
T a = extend(inputs.at(ID(A)), a_width, width, is_signed);
T b = extend(inputs.at(ID(B)), b_width, width, is_signed);
Node a = factory.extend(inputs.at(ID(A)), width, is_signed);
Node b = factory.extend(inputs.at(ID(B)), width, is_signed);
if(cellType.in({ID($eq), ID($eqx)}))
return extend(factory.equal(a, b, width), 1, y_width, false);
return factory.extend(factory.equal(a, b), y_width, false);
else if(cellType.in({ID($ne), ID($nex)}))
return extend(factory.not_equal(a, b, width), 1, y_width, false);
return factory.extend(factory.not_equal(a, b), y_width, false);
else if(cellType == ID($lt))
return extend(is_signed ? factory.signed_greater_than(b, a, width) : factory.unsigned_greater_than(b, a, width), 1, y_width, false);
return factory.extend(is_signed ? factory.signed_greater_than(b, a) : factory.unsigned_greater_than(b, a), y_width, false);
else if(cellType == ID($le))
return extend(is_signed ? factory.signed_greater_equal(b, a, width) : factory.unsigned_greater_equal(b, a, width), 1, y_width, false);
return factory.extend(is_signed ? factory.signed_greater_equal(b, a) : factory.unsigned_greater_equal(b, a), y_width, false);
else if(cellType == ID($gt))
return extend(is_signed ? factory.signed_greater_than(a, b, width) : factory.unsigned_greater_than(a, b, width), 1, y_width, false);
return factory.extend(is_signed ? factory.signed_greater_than(a, b) : factory.unsigned_greater_than(a, b), y_width, false);
else if(cellType == ID($ge))
return extend(is_signed ? factory.signed_greater_equal(a, b, width) : factory.unsigned_greater_equal(a, b, width), 1, y_width, false);
return factory.extend(is_signed ? factory.signed_greater_equal(a, b) : factory.unsigned_greater_equal(a, b), y_width, false);
else
log_abort();
}else if(cellType.in({ID($logic_or), ID($logic_and)})){
T a = factory.reduce_or(inputs.at(ID(A)), a_width);
T b = factory.reduce_or(inputs.at(ID(B)), b_width);
T y = cellType == ID($logic_and) ? factory.bitwise_and(a, b, 1) : factory.bitwise_or(a, b, 1);
return extend(y, 1, y_width, false);
Node a = factory.reduce_or(inputs.at(ID(A)));
Node b = factory.reduce_or(inputs.at(ID(B)));
Node y = cellType == ID($logic_and) ? factory.bitwise_and(a, b) : factory.bitwise_or(a, b);
return factory.extend(y, y_width, false);
}else if(cellType == ID($not)){
T a = extend(inputs.at(ID(A)), a_width, y_width, a_signed);
return factory.bitwise_not(a, y_width);
Node a = factory.extend(inputs.at(ID(A)), y_width, a_signed);
return factory.bitwise_not(a);
}else if(cellType == ID($pos)){
return extend(inputs.at(ID(A)), a_width, y_width, a_signed);
return factory.extend(inputs.at(ID(A)), y_width, a_signed);
}else if(cellType == ID($neg)){
T a = extend(inputs.at(ID(A)), a_width, y_width, a_signed);
return factory.unary_minus(a, y_width);
Node a = factory.extend(inputs.at(ID(A)), y_width, a_signed);
return factory.unary_minus(a);
}else if(cellType == ID($logic_not)){
T a = factory.reduce_or(inputs.at(ID(A)), a_width);
T y = factory.bitwise_not(a, 1);
return extend(y, 1, y_width, false);
Node a = factory.reduce_or(inputs.at(ID(A)));
Node y = factory.bitwise_not(a);
return factory.extend(y, y_width, false);
}else if(cellType.in({ID($reduce_or), ID($reduce_bool)})){
T a = factory.reduce_or(inputs.at(ID(A)), a_width);
return extend(a, 1, y_width, false);
Node a = factory.reduce_or(inputs.at(ID(A)));
return factory.extend(a, y_width, false);
}else if(cellType == ID($reduce_and)){
T a = factory.reduce_and(inputs.at(ID(A)), a_width);
return extend(a, 1, y_width, false);
Node a = factory.reduce_and(inputs.at(ID(A)));
return factory.extend(a, y_width, false);
}else if(cellType.in({ID($reduce_xor), ID($reduce_xnor)})){
T a = factory.reduce_xor(inputs.at(ID(A)), a_width);
T y = cellType == ID($reduce_xnor) ? factory.bitwise_not(a, 1) : a;
return extend(y, 1, y_width, false);
Node a = factory.reduce_xor(inputs.at(ID(A)));
Node y = cellType == ID($reduce_xnor) ? factory.bitwise_not(a) : a;
return factory.extend(y, y_width, false);
}else if(cellType == ID($shl) || cellType == ID($sshl)){
T a = extend(inputs.at(ID(A)), a_width, y_width, a_signed);
T b = inputs.at(ID(B));
return logical_shift_left(a, b, y_width, b_width);
Node a = factory.extend(inputs.at(ID(A)), y_width, a_signed);
Node b = inputs.at(ID(B));
return logical_shift_left(a, b);
}else if(cellType == ID($shr) || cellType == ID($sshr)){
int width = max(a_width, y_width);
T a = extend(inputs.at(ID(A)), a_width, width, a_signed);
T b = inputs.at(ID(B));
T y = a_signed && cellType == ID($sshr) ?
arithmetic_shift_right(a, b, width, b_width) :
logical_shift_right(a, b, width, b_width);
return extend(y, width, y_width, a_signed);
Node a = factory.extend(inputs.at(ID(A)), width, a_signed);
Node b = inputs.at(ID(B));
Node y = a_signed && cellType == ID($sshr) ?
arithmetic_shift_right(a, b) :
logical_shift_right(a, b);
return factory.extend(y, y_width, a_signed);
}else if(cellType == ID($shiftx) || cellType == ID($shift)){
int width = max(a_width, y_width);
T a = extend(inputs.at(ID(A)), a_width, width, cellType == ID($shift) && a_signed);
T b = inputs.at(ID(B));
T shr = logical_shift_right(a, b, width, b_width);
Node a = factory.extend(inputs.at(ID(A)), width, cellType == ID($shift) && a_signed);
Node b = inputs.at(ID(B));
Node shr = logical_shift_right(a, b);
if(b_signed) {
T sign_b = sign(b, b_width);
T shl = logical_shift_left(a, factory.unary_minus(b, b_width), width, b_width);
T y = factory.mux(shr, shl, sign_b, width);
return extend(y, width, y_width, false);
Node shl = logical_shift_left(a, factory.unary_minus(b));
Node y = factory.mux(shr, shl, sign(b));
return factory.extend(y, y_width, false);
} else {
return extend(shr, width, y_width, false);
return factory.extend(shr, y_width, false);
}
}else if(cellType == ID($mux)){
int width = parameters.at(ID(WIDTH)).as_int();
return factory.mux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S)), width);
return factory.mux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S)));
}else if(cellType == ID($pmux)){
int width = parameters.at(ID(WIDTH)).as_int();
int s_width = parameters.at(ID(S_WIDTH)).as_int();
return factory.pmux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S)), width, s_width);
return factory.pmux(inputs.at(ID(A)), inputs.at(ID(B)), inputs.at(ID(S)));
}else if(cellType == ID($concat)){
T a = inputs.at(ID(A));
T b = inputs.at(ID(B));
return factory.concat(a, a_width, b, b_width);
Node a = inputs.at(ID(A));
Node b = inputs.at(ID(B));
return factory.concat(a, b);
}else if(cellType == ID($slice)){
int offset = parameters.at(ID(OFFSET)).as_int();
T a = inputs.at(ID(A));
return factory.slice(a, a_width, offset, y_width);
Node a = inputs.at(ID(A));
return factory.slice(a, offset, y_width);
}else if(cellType.in({ID($div), ID($mod), ID($divfloor), ID($modfloor)})) {
int width = max(a_width, b_width);
bool is_signed = a_signed && b_signed;
T a = extend(inputs.at(ID(A)), a_width, width, is_signed);
T b = extend(inputs.at(ID(B)), b_width, width, is_signed);
Node a = factory.extend(inputs.at(ID(A)), width, is_signed);
Node b = factory.extend(inputs.at(ID(B)), width, is_signed);
if(is_signed) {
if(cellType == ID($div)) {
// divide absolute values, then flip the sign if input signs differ
// but extend the width first, to handle the case (most negative value) / (-1)
T abs_y = factory.unsigned_div(abs(a, width), abs(b, width), width);
T out_sign = factory.not_equal(sign(a, width), sign(b, width), 1);
return neg_if(extend(abs_y, width, y_width, false), y_width, out_sign);
Node abs_y = factory.unsigned_div(abs(a), abs(b));
Node out_sign = factory.not_equal(sign(a), sign(b));
return neg_if(factory.extend(abs_y, y_width, false), out_sign);
} else if(cellType == ID($mod)) {
// similar to division but output sign == divisor sign
T abs_y = factory.unsigned_mod(abs(a, width), abs(b, width), width);
return neg_if(extend(abs_y, width, y_width, false), y_width, sign(a, width));
Node abs_y = factory.unsigned_mod(abs(a), abs(b));
return neg_if(factory.extend(abs_y, y_width, false), sign(a));
} else if(cellType == ID($divfloor)) {
// if b is negative, flip both signs so that b is positive
T b_sign = sign(b, width);
T a1 = neg_if(a, width, b_sign);
T b1 = neg_if(b, width, b_sign);
Node b_sign = sign(b);
Node a1 = neg_if(a, b_sign);
Node b1 = neg_if(b, b_sign);
// if a is now negative, calculate ~((~a) / b) = -((-a - 1) / b + 1)
// which equals the negative of (-a) / b with rounding up rather than down
// note that to handle the case where a = most negative value properly,
// we have to calculate a1_sign from the original values rather than using sign(a1, width)
T a1_sign = factory.bitwise_and(factory.not_equal(sign(a, width), sign(b, width), 1), factory.reduce_or(a, width), 1);
T a2 = factory.mux(a1, factory.bitwise_not(a1, width), a1_sign, width);
T y1 = factory.unsigned_div(a2, b1, width);
T y2 = extend(y1, width, y_width, false);
return factory.mux(y2, factory.bitwise_not(y2, y_width), a1_sign, y_width);
// we have to calculate a1_sign from the original values rather than using sign(a1)
Node a1_sign = factory.bitwise_and(factory.not_equal(sign(a), sign(b)), factory.reduce_or(a));
Node a2 = factory.mux(a1, factory.bitwise_not(a1), a1_sign);
Node y1 = factory.unsigned_div(a2, b1);
Node y2 = factory.extend(y1, y_width, false);
return factory.mux(y2, factory.bitwise_not(y2), a1_sign);
} else if(cellType == ID($modfloor)) {
// calculate |a| % |b| and then subtract from |b| if input signs differ and the remainder is non-zero
T abs_b = abs(b, width);
T abs_y = factory.unsigned_mod(abs(a, width), abs_b, width);
T flip_y = factory.bitwise_and(factory.bitwise_xor(sign(a, width), sign(b, width), 1), factory.reduce_or(abs_y, width), 1);
T y_flipped = factory.mux(abs_y, factory.sub(abs_b, abs_y, width), flip_y, width);
Node abs_b = abs(b);
Node abs_y = factory.unsigned_mod(abs(a), abs_b);
Node flip_y = factory.bitwise_and(factory.bitwise_xor(sign(a), sign(b)), factory.reduce_or(abs_y));
Node y_flipped = factory.mux(abs_y, factory.sub(abs_b, abs_y), flip_y);
// since y_flipped is strictly less than |b|, the top bit is always 0 and we can just sign extend the flipped result
T y = neg_if(y_flipped, width, sign(b, b_width));
return extend(y, width, y_width, true);
Node y = neg_if(y_flipped, sign(b));
return factory.extend(y, y_width, true);
} else
log_error("unhandled cell in CellSimplifier %s\n", cellType.c_str());
} else {
if(cellType.in({ID($mod), ID($modfloor)}))
return extend(factory.unsigned_mod(a, b, width), width, y_width, false);
return factory.extend(factory.unsigned_mod(a, b), y_width, false);
else
return extend(factory.unsigned_div(a, b, width), width, y_width, false);
return factory.extend(factory.unsigned_div(a, b), y_width, false);
}
} else if(cellType == ID($pow)) {
return handle_pow(inputs.at(ID(A)), a_width, inputs.at(ID(B)), b_width, y_width, a_signed && b_signed);
return handle_pow(inputs.at(ID(A)), inputs.at(ID(B)), y_width, a_signed && b_signed);
} else if (cellType == ID($lut)) {
int width = parameters.at(ID(WIDTH)).as_int();
Const lut_table = parameters.at(ID(LUT));
lut_table.extu(1 << width);
return handle_bmux(factory.constant(lut_table), inputs.at(ID(A)), 1 << width, 0, 1, width, width);
return handle_bmux(factory.constant(lut_table), inputs.at(ID(A)), 0, 1, width);
} else if (cellType == ID($bwmux)) {
int width = parameters.at(ID(WIDTH)).as_int();
T a = inputs.at(ID(A));
T b = inputs.at(ID(B));
T s = inputs.at(ID(S));
Node a = inputs.at(ID(A));
Node b = inputs.at(ID(B));
Node s = inputs.at(ID(S));
return factory.bitwise_or(
factory.bitwise_and(a, factory.bitwise_not(s, width), width),
factory.bitwise_and(b, s, width), width);
factory.bitwise_and(a, factory.bitwise_not(s)),
factory.bitwise_and(b, s));
} else if (cellType == ID($bweqx)) {
int width = parameters.at(ID(WIDTH)).as_int();
T a = inputs.at(ID(A));
T b = inputs.at(ID(B));
return factory.bitwise_not(factory.bitwise_xor(a, b, width), width);
Node a = inputs.at(ID(A));
Node b = inputs.at(ID(B));
return factory.bitwise_not(factory.bitwise_xor(a, b));
} else if(cellType == ID($bmux)) {
int width = parameters.at(ID(WIDTH)).as_int();
int s_width = parameters.at(ID(S_WIDTH)).as_int();
return handle_bmux(inputs.at(ID(A)), inputs.at(ID(S)), width << s_width, 0, width, s_width, s_width);
return handle_bmux(inputs.at(ID(A)), inputs.at(ID(S)), 0, width, s_width);
} else if(cellType == ID($demux)) {
int width = parameters.at(ID(WIDTH)).as_int();
int s_width = parameters.at(ID(S_WIDTH)).as_int();
int y_width = width << s_width;
int b_width = ceil_log2(y_width + 1);
T a = extend(inputs.at(ID(A)), width, y_width, false);
T s = factory.extend(inputs.at(ID(S)), s_width, b_width, false);
T b = factory.mul(s, factory.constant(Const(width, b_width)), b_width);
return factory.logical_shift_left(a, b, y_width, b_width);
Node a = factory.extend(inputs.at(ID(A)), y_width, false);
Node s = factory.extend(inputs.at(ID(S)), b_width, false);
Node b = factory.mul(s, factory.constant(Const(width, b_width)));
return factory.logical_shift_left(a, b);
} else {
log_error("unhandled cell in CellSimplifier %s\n", cellType.c_str());
}
}
};
template <class T, class Factory>
class FunctionalIRConstruction {
using Node = FunctionalIR::Node;
std::deque<DriveSpec> queue;
dict<DriveSpec, T> graph_nodes;
dict<DriveSpec, Node> graph_nodes;
idict<Cell *> cells;
DriverMap driver_map;
Factory& factory;
CellSimplifier<T, Factory> simplifier;
FunctionalIR::Factory& factory;
CellSimplifier simplifier;
vector<Mem> memories_vector;
dict<Cell*, Mem*> memories;
T enqueue(DriveSpec const &spec)
Node enqueue(DriveSpec const &spec)
{
auto it = graph_nodes.find(spec);
if(it == graph_nodes.end()){
@ -400,7 +382,7 @@ class FunctionalIRConstruction {
return it->second;
}
public:
FunctionalIRConstruction(Factory &f) : factory(f), simplifier(f) {}
FunctionalIRConstruction(FunctionalIR::Factory &f) : factory(f), simplifier(f) {}
void add_module(Module *module)
{
driver_map.add(module);
@ -410,7 +392,7 @@ public:
}
for (auto wire : module->wires()) {
if (wire->port_output) {
T node = enqueue(DriveChunk(DriveChunkWire(wire, 0, wire->width)));
Node node = enqueue(DriveChunk(DriveChunkWire(wire, 0, wire->width)));
factory.declare_output(node, wire->name, wire->width);
}
}
@ -420,37 +402,34 @@ public:
memories[mem.cell] = &mem;
}
}
T concatenate_read_results(Mem *, vector<T> results)
Node concatenate_read_results(Mem *, vector<Node> results)
{
/* TODO: write code to check that this is ok to do */
if(results.size() == 0)
return factory.undriven(0);
T node = results[0];
int size = results[0].width();
for(size_t i = 1; i < results.size(); i++) {
node = factory.concat(node, size, results[i], results[i].width());
size += results[i].width();
}
Node node = results[0];
for(size_t i = 1; i < results.size(); i++)
node = factory.concat(node, results[i]);
return node;
}
T handle_memory(Mem *mem)
Node handle_memory(Mem *mem)
{
vector<T> read_results;
vector<Node> read_results;
int addr_width = ceil_log2(mem->size);
int data_width = mem->width;
T node = factory.state_memory(mem->cell->name, addr_width, data_width);
Node node = factory.state_memory(mem->cell->name, addr_width, data_width);
for (auto &rd : mem->rd_ports) {
log_assert(!rd.clk_enable);
T addr = enqueue(driver_map(DriveSpec(rd.addr)));
read_results.push_back(factory.memory_read(node, addr, addr_width, data_width));
Node addr = enqueue(driver_map(DriveSpec(rd.addr)));
read_results.push_back(factory.memory_read(node, addr));
}
for (auto &wr : mem->wr_ports) {
T en = enqueue(driver_map(DriveSpec(wr.en)));
T addr = enqueue(driver_map(DriveSpec(wr.addr)));
T new_data = enqueue(driver_map(DriveSpec(wr.data)));
T old_data = factory.memory_read(node, addr, addr_width, data_width);
T wr_data = simplifier.bitwise_mux(old_data, new_data, en, data_width);
node = factory.memory_write(node, addr, wr_data, addr_width, data_width);
Node en = enqueue(driver_map(DriveSpec(wr.en)));
Node addr = enqueue(driver_map(DriveSpec(wr.addr)));
Node new_data = enqueue(driver_map(DriveSpec(wr.data)));
Node old_data = factory.memory_read(node, addr);
Node wr_data = simplifier.bitwise_mux(old_data, new_data, en);
node = factory.memory_write(node, addr, wr_data);
}
factory.declare_state_memory(node, mem->cell->name, addr_width, data_width);
return concatenate_read_results(mem, read_results);
@ -459,16 +438,13 @@ public:
{
for (; !queue.empty(); queue.pop_front()) {
DriveSpec spec = queue.front();
T pending = graph_nodes.at(spec);
Node pending = graph_nodes.at(spec);
if (spec.chunks().size() > 1) {
auto chunks = spec.chunks();
T node = enqueue(chunks[0]);
int width = chunks[0].size();
for(size_t i = 1; i < chunks.size(); i++) {
node = factory.concat(node, width, enqueue(chunks[i]), chunks[i].size());
width += chunks[i].size();
}
Node node = enqueue(chunks[0]);
for(size_t i = 1; i < chunks.size(); i++)
node = factory.concat(node, enqueue(chunks[i]));
factory.update_pending(pending, node);
} else if (spec.chunks().size() == 1) {
DriveChunk chunk = spec.chunks()[0];
@ -476,18 +452,18 @@ public:
DriveChunkWire wire_chunk = chunk.wire();
if (wire_chunk.is_whole()) {
if (wire_chunk.wire->port_input) {
T node = factory.input(wire_chunk.wire->name, wire_chunk.width);
Node node = factory.input(wire_chunk.wire->name, wire_chunk.width);
factory.suggest_name(node, wire_chunk.wire->name);
factory.update_pending(pending, node);
} else {
DriveSpec driver = driver_map(DriveSpec(wire_chunk));
T node = enqueue(driver);
Node node = enqueue(driver);
factory.suggest_name(node, wire_chunk.wire->name);
factory.update_pending(pending, node);
}
} else {
DriveChunkWire whole_wire(wire_chunk.wire, 0, wire_chunk.wire->width);
T node = factory.slice(enqueue(whole_wire), wire_chunk.wire->width, wire_chunk.offset, wire_chunk.width);
Node node = factory.slice(enqueue(whole_wire), wire_chunk.offset, wire_chunk.width);
factory.update_pending(pending, node);
}
} else if (chunk.is_port()) {
@ -497,21 +473,22 @@ public:
if (port_chunk.cell->type.in(ID($dff), ID($ff)))
{
Cell *cell = port_chunk.cell;
T node = factory.state(cell->name, port_chunk.width);
Node node = factory.state(cell->name, port_chunk.width);
factory.suggest_name(node, port_chunk.cell->name);
factory.update_pending(pending, node);
for (auto const &conn : cell->connections()) {
if (driver_map.celltypes.cell_input(cell->type, conn.first)) {
T node = enqueue(DriveChunkPort(cell, conn));
Node node = enqueue(DriveChunkPort(cell, conn));
factory.declare_state(node, cell->name, port_chunk.width);
}
}
}
else
{
T cell = enqueue(DriveChunkMarker(cells(port_chunk.cell), 0, port_chunk.width));
Node cell = enqueue(DriveChunkMarker(cells(port_chunk.cell), 0, port_chunk.width));
factory.suggest_name(cell, port_chunk.cell->name);
T node = factory.cell_output(cell, port_chunk.cell->type, port_chunk.port, port_chunk.width);
//Node node = factory.cell_output(cell, port_chunk.cell->type, port_chunk.port, port_chunk.width);
Node node = cell;
factory.suggest_name(node, port_chunk.cell->name.str() + "$" + port_chunk.port.str());
factory.update_pending(pending, node);
}
@ -521,37 +498,37 @@ public:
}
} else {
DriveChunkPort whole_port(port_chunk.cell, port_chunk.port, 0, GetSize(port_chunk.cell->connections().at(port_chunk.port)));
T node = factory.slice(enqueue(whole_port), whole_port.width, port_chunk.offset, port_chunk.width);
Node node = factory.slice(enqueue(whole_port), port_chunk.offset, port_chunk.width);
factory.update_pending(pending, node);
}
} else if (chunk.is_constant()) {
T node = factory.constant(chunk.constant());
Node node = factory.constant(chunk.constant());
factory.suggest_name(node, "$const" + std::to_string(chunk.size()) + "b" + chunk.constant().as_string());
factory.update_pending(pending, node);
} else if (chunk.is_multiple()) {
vector<T> args;
vector<Node> args;
for (auto const &driver : chunk.multiple().multiple())
args.push_back(enqueue(driver));
T node = factory.multiple(args, chunk.size());
Node node = factory.multiple(args, chunk.size());
factory.update_pending(pending, node);
} else if (chunk.is_marker()) {
Cell *cell = cells[chunk.marker().marker];
if (cell->is_mem_cell()) {
Mem *mem = memories.at(cell, nullptr);
log_assert(mem != nullptr);
T node = handle_memory(mem);
Node node = handle_memory(mem);
factory.update_pending(pending, node);
} else {
dict<IdString, T> connections;
dict<IdString, Node> connections;
for(auto const &conn : cell->connections()) {
if(driver_map.celltypes.cell_input(cell->type, conn.first))
connections.insert({ conn.first, enqueue(DriveChunkPort(cell, conn)) });
}
T node = simplifier.handle(cell->type, cell->parameters, connections);
Node node = simplifier.handle(cell->type, cell->parameters, connections);
factory.update_pending(pending, node);
}
} else if (chunk.is_none()) {
T node = factory.undriven(chunk.size());
Node node = factory.undriven(chunk.size());
factory.update_pending(pending, node);
} else {
log_error("unhandled drivespec: %s\n", log_signal(chunk));
@ -567,7 +544,7 @@ public:
FunctionalIR FunctionalIR::from_module(Module *module) {
FunctionalIR ir;
auto factory = ir.factory();
FunctionalIRConstruction<FunctionalIR::Node, FunctionalIR::Factory> ctor(factory);
FunctionalIRConstruction ctor(factory);
ctor.add_module(module);
ctor.process_queue();
ir.topological_sort();

View file

@ -385,78 +385,79 @@ public:
void check_shift(Node const &a, Node const &b) { log_assert(a.sort().is_signal() && b.sort().is_signal()); }
void check_unary(Node const &a) { log_assert(a.sort().is_signal()); }
public:
Node slice(Node a, int, int offset, int out_width) {
Node slice(Node a, int offset, int out_width) {
log_assert(a.sort().is_signal() && offset + out_width <= a.sort().width());
if(offset == 0 && out_width == a.width())
return a;
return add(NodeData(Fn::slice, offset), Sort(out_width), {a});
}
Node extend(Node a, int, int out_width, bool is_signed) {
// extend will either extend or truncate the provided value to reach the desired width
Node extend(Node a, int out_width, bool is_signed) {
int in_width = a.sort().width();
log_assert(a.sort().is_signal());
if(in_width == out_width)
return a;
if(in_width < out_width)
return slice(a, in_width, 0, out_width);
if(in_width > out_width)
return slice(a, 0, out_width);
if(is_signed)
return add(Fn::sign_extend, Sort(out_width), {a});
else
return add(Fn::zero_extend, Sort(out_width), {a});
}
Node concat(Node a, int, Node b, int) {
Node concat(Node a, Node b) {
log_assert(a.sort().is_signal() && b.sort().is_signal());
return add(Fn::concat, Sort(a.sort().width() + b.sort().width()), {a, b});
}
Node add(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::add, a.sort(), {a, b}); }
Node sub(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::sub, a.sort(), {a, b}); }
Node mul(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::mul, a.sort(), {a, b}); }
Node unsigned_div(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_div, a.sort(), {a, b}); }
Node unsigned_mod(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_mod, a.sort(), {a, b}); }
Node bitwise_and(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_and, a.sort(), {a, b}); }
Node bitwise_or(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_or, a.sort(), {a, b}); }
Node bitwise_xor(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::bitwise_xor, a.sort(), {a, b}); }
Node bitwise_not(Node a, int) { check_unary(a); return add(Fn::bitwise_not, a.sort(), {a}); }
Node unary_minus(Node a, int) { check_unary(a); return add(Fn::unary_minus, a.sort(), {a}); }
Node reduce_and(Node a, int) {
Node add(Node a, Node b) { check_basic_binary(a, b); return add(Fn::add, a.sort(), {a, b}); }
Node sub(Node a, Node b) { check_basic_binary(a, b); return add(Fn::sub, a.sort(), {a, b}); }
Node mul(Node a, Node b) { check_basic_binary(a, b); return add(Fn::mul, a.sort(), {a, b}); }
Node unsigned_div(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_div, a.sort(), {a, b}); }
Node unsigned_mod(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_mod, a.sort(), {a, b}); }
Node bitwise_and(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_and, a.sort(), {a, b}); }
Node bitwise_or(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_or, a.sort(), {a, b}); }
Node bitwise_xor(Node a, Node b) { check_basic_binary(a, b); return add(Fn::bitwise_xor, a.sort(), {a, b}); }
Node bitwise_not(Node a) { check_unary(a); return add(Fn::bitwise_not, a.sort(), {a}); }
Node unary_minus(Node a) { check_unary(a); return add(Fn::unary_minus, a.sort(), {a}); }
Node reduce_and(Node a) {
check_unary(a);
if(a.width() == 1)
return a;
return add(Fn::reduce_and, Sort(1), {a});
}
Node reduce_or(Node a, int) {
Node reduce_or(Node a) {
check_unary(a);
if(a.width() == 1)
return a;
return add(Fn::reduce_or, Sort(1), {a});
}
Node reduce_xor(Node a, int) {
Node reduce_xor(Node a) {
check_unary(a);
if(a.width() == 1)
return a;
return add(Fn::reduce_xor, Sort(1), {a});
}
Node equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::equal, Sort(1), {a, b}); }
Node not_equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::not_equal, Sort(1), {a, b}); }
Node signed_greater_than(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::signed_greater_than, Sort(1), {a, b}); }
Node signed_greater_equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::signed_greater_equal, Sort(1), {a, b}); }
Node unsigned_greater_than(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_greater_than, Sort(1), {a, b}); }
Node unsigned_greater_equal(Node a, Node b, int) { check_basic_binary(a, b); return add(Fn::unsigned_greater_equal, Sort(1), {a, b}); }
Node logical_shift_left(Node a, Node b, int, int) { check_shift(a, b); return add(Fn::logical_shift_left, a.sort(), {a, b}); }
Node logical_shift_right(Node a, Node b, int, int) { check_shift(a, b); return add(Fn::logical_shift_right, a.sort(), {a, b}); }
Node arithmetic_shift_right(Node a, Node b, int, int) { check_shift(a, b); return add(Fn::arithmetic_shift_right, a.sort(), {a, b}); }
Node mux(Node a, Node b, Node s, int) {
Node equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::equal, Sort(1), {a, b}); }
Node not_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::not_equal, Sort(1), {a, b}); }
Node signed_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_than, Sort(1), {a, b}); }
Node signed_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::signed_greater_equal, Sort(1), {a, b}); }
Node unsigned_greater_than(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_than, Sort(1), {a, b}); }
Node unsigned_greater_equal(Node a, Node b) { check_basic_binary(a, b); return add(Fn::unsigned_greater_equal, Sort(1), {a, b}); }
Node logical_shift_left(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_left, a.sort(), {a, b}); }
Node logical_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::logical_shift_right, a.sort(), {a, b}); }
Node arithmetic_shift_right(Node a, Node b) { check_shift(a, b); return add(Fn::arithmetic_shift_right, a.sort(), {a, b}); }
Node mux(Node a, Node b, Node s) {
log_assert(a.sort().is_signal() && a.sort() == b.sort() && s.sort() == Sort(1));
return add(Fn::mux, a.sort(), {a, b, s});
}
Node pmux(Node a, Node b, Node s, int, int) {
Node pmux(Node a, Node b, Node s) {
log_assert(a.sort().is_signal() && b.sort().is_signal() && s.sort().is_signal() && a.sort().width() * s.sort().width() == b.sort().width());
return add(Fn::pmux, a.sort(), {a, b, s});
}
Node memory_read(Node mem, Node addr, int, int) {
Node memory_read(Node mem, Node addr) {
log_assert(mem.sort().is_memory() && addr.sort().is_signal() && mem.sort().addr_width() == addr.sort().width());
return add(Fn::memory_read, Sort(mem.sort().data_width()), {mem, addr});
}
Node memory_write(Node mem, Node addr, Node data, int, int) {
Node memory_write(Node mem, Node addr, Node data) {
log_assert(mem.sort().is_memory() && addr.sort().is_signal() && data.sort().is_signal() &&
mem.sort().addr_width() == addr.sort().width() && mem.sort().data_width() == data.sort().width());
return add(Fn::memory_write, mem.sort(), {mem, addr, data});
@ -484,9 +485,6 @@ public:
_ir.add_state(name, Sort(addr_width, data_width));
return add(NodeData(Fn::state, name), Sort(addr_width, data_width), {});
}
Node cell_output(Node node, IdString, IdString, int) {
return node;
}
Node multiple(vector<Node> args, int width) {
auto node = add(Fn::multiple, Sort(width), {});
for(const auto &arg : args)