mirror of
https://github.com/YosysHQ/yosys
synced 2025-04-13 04:28:18 +00:00
Added pivoting to qwp solver
This commit is contained in:
parent
69071bbc5f
commit
ec92c89659
|
@ -255,33 +255,62 @@ struct QwpWorker
|
||||||
// (least squares fit for "A*x = y")
|
// (least squares fit for "A*x = y")
|
||||||
//
|
//
|
||||||
// Using gaussian elimination to get M := [Id x]
|
// Using gaussian elimination to get M := [Id x]
|
||||||
// (no pivoting, so let's hope for the best..)
|
|
||||||
|
|
||||||
// eliminate to upper triangular matrix
|
vector<int> pivot_cache;
|
||||||
|
vector<int> queue;
|
||||||
|
|
||||||
|
for (int i = 0; i < N; i++)
|
||||||
|
queue.push_back(i);
|
||||||
|
|
||||||
|
// gaussian elimination
|
||||||
for (int i = 0; i < N; i++)
|
for (int i = 0; i < N; i++)
|
||||||
{
|
{
|
||||||
|
// find best row
|
||||||
|
int best_row = queue.front();
|
||||||
|
int best_row_queue_idx = 0;
|
||||||
|
double best_row_absval = 0;
|
||||||
|
|
||||||
|
for (int k = 0; k < GetSize(queue); k++) {
|
||||||
|
int row = queue[k];
|
||||||
|
double absval = fabs(M[i + row*N1]);
|
||||||
|
if (absval > best_row_absval) {
|
||||||
|
best_row = row;
|
||||||
|
best_row_queue_idx = k;
|
||||||
|
best_row_absval = absval;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int row = best_row;
|
||||||
|
pivot_cache.push_back(row);
|
||||||
|
|
||||||
|
queue[best_row_queue_idx] = queue.back();
|
||||||
|
queue.pop_back();
|
||||||
|
|
||||||
// normalize row
|
// normalize row
|
||||||
for (int j = i+1; j < N+1; j++)
|
for (int k = i+1; k < N1; k++)
|
||||||
M[(N+1)*i + j] /= M[(N+1)*i + i];
|
M[k + row*N1] /= M[i + row*N1];
|
||||||
M[(N+1)*i + i] = 1.0;
|
M[i + row*N1] = 1.0;
|
||||||
|
|
||||||
// elimination
|
// elimination
|
||||||
for (int j = i+1; j < N; j++) {
|
for (int other_row : queue) {
|
||||||
double d = M[(N+1)*j + i];
|
double d = M[i + other_row*N1];
|
||||||
for (int k = 0; k < N+1; k++)
|
for (int k = i+1; k < N1; k++)
|
||||||
if (k > i)
|
M[k + other_row*N1] -= d*M[k + row*N1];
|
||||||
M[(N+1)*j + k] -= d*M[(N+1)*i + k];
|
M[i + other_row*N1] = 0.0;
|
||||||
else
|
|
||||||
M[(N+1)*j + k] = 0.0;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log_assert(queue.empty());
|
||||||
|
log_assert(GetSize(pivot_cache) == N);
|
||||||
|
|
||||||
// back substitution
|
// back substitution
|
||||||
for (int i = N-1; i >= 0; i--)
|
for (int i = N-1; i >= 0; i--)
|
||||||
for (int j = i+1; j < N; j++)
|
for (int j = i+1; j < N; j++)
|
||||||
{
|
{
|
||||||
M[(N+1)*i + N] -= M[(N+1)*i + j] * M[(N+1)*j + N];
|
int row = pivot_cache[i];
|
||||||
M[(N+1)*i + j] = 0.0;
|
int other_row = pivot_cache[j];
|
||||||
|
M[N + row*N1] -= M[j + row*N1] * M[N + other_row*N1];
|
||||||
|
M[j + row*N1] = 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef LOG_MATRICES
|
#ifdef LOG_MATRICES
|
||||||
|
|
Loading…
Reference in a new issue