1
0
Fork 0

add BoolFixedPointSolver

This commit is contained in:
Jacob Lifshay 2026-06-12 19:55:48 -07:00
parent 1b16118ce5
commit b0e7873a17
Signed by: programmerjake
SSH key fingerprint: SHA256:HnFTLGpSm4Q4Fj502oCFisjZSoakwEuTsJJMSke63RQ
2 changed files with 712 additions and 0 deletions

View file

@ -43,6 +43,7 @@ pub use misc::{
};
pub(crate) use misc::{InternedStrCompareAsStr, chain, copy_le_bytes_to_bitslice};
pub mod bool_fixed_point_solver;
pub(crate) mod indented_print;
pub mod job_server;
pub mod map_trait;

View file

@ -0,0 +1,711 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
// See Notices.txt for copyright information
use petgraph::unionfind::UnionFind;
use std::{collections::BTreeSet, fmt};
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Variable(usize);
impl Variable {
pub fn index(self) -> usize {
self.0
}
}
impl fmt::Debug for Variable {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(self, f)
}
}
impl fmt::Display for Variable {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "v{}", self.0)
}
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub enum Constraint {
/// `variable` is constrained to be [`!solver.unconstrained_variables_value()`](BoolFixedPointSolver::unconstrained_variables_value())
MaximallyConstrained { variable: Variable },
/// the constraint is `dest == src`
Equal { dest: Variable, src: Variable },
/// the constraint is `dest == dest & src`
And { dest: Variable, src: Variable },
/// the constraint is `dest == dest | src`
Or { dest: Variable, src: Variable },
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
/// the constraint is `dest == dest & src`
struct AndConstraint {
dest: Variable,
src: Variable,
}
impl AndConstraint {
fn from_or_constraint(or_constraint_dest: Variable, or_constraint_src: Variable) -> Self {
// `a == a | b` is equivalent to `b == b & a`
Self {
dest: or_constraint_src,
src: or_constraint_dest,
}
}
}
impl fmt::Debug for AndConstraint {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let Self { dest, src } = *self;
write!(f, "{dest} == {dest} & {src}")
}
}
#[derive(Clone)]
pub struct BoolFixedPointSolver {
variables_union_find: UnionFind<usize>,
variables_value: Vec<bool>,
maximally_constrained: Vec<bool>,
unconstrained_variables_value: bool,
solved: bool,
and_constraints: BTreeSet<AndConstraint>,
}
impl fmt::Debug for BoolFixedPointSolver {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let Self {
variables_union_find,
variables_value,
maximally_constrained,
unconstrained_variables_value,
solved,
and_constraints,
} = self;
f.debug_struct("BoolFixedPointSolver")
.field(
"variables_union_find",
&fmt::from_fn(|f| {
f.debug_map()
.entries(
(0..variables_union_find.len())
.map(|i| (Variable(i), Variable(variables_union_find.find(i)))),
)
.finish()
}),
)
.field(
"variables_value",
&fmt::from_fn(|f| {
let mut debug_map = f.debug_map();
for (i, v) in variables_value.iter().enumerate() {
if variables_union_find.find(i) == i {
debug_map.entry(&Variable(i), v);
}
}
debug_map.finish()
}),
)
.field(
"maximally_constrained",
&fmt::from_fn(|f| {
let mut debug_map = f.debug_map();
for (i, v) in maximally_constrained.iter().enumerate() {
if variables_union_find.find(i) == i {
debug_map.entry(&Variable(i), v);
}
}
debug_map.finish()
}),
)
.field(
"unconstrained_variables_value",
unconstrained_variables_value,
)
.field("solved", solved)
.field("and_constraints", and_constraints)
.finish()
}
}
impl BoolFixedPointSolver {
pub const fn new(unconstrained_variables_value: bool) -> Self {
Self {
variables_union_find: UnionFind::new_empty(),
variables_value: Vec::new(),
maximally_constrained: Vec::new(),
unconstrained_variables_value,
solved: false,
and_constraints: BTreeSet::new(),
}
}
pub fn unconstrained_variables_value(&self) -> bool {
self.unconstrained_variables_value
}
pub fn new_variable(&mut self) -> Variable {
let index = self.variables_union_find.new_set();
self.variables_value
.push(self.unconstrained_variables_value);
self.maximally_constrained.push(false);
self.solved = false;
Variable(index)
}
pub fn variable_count(&self) -> usize {
self.variables_union_find.len()
}
#[track_caller]
fn assert_variable_in_range(&self, variable: Variable) {
if variable.0 >= self.variable_count() {
panic!("invalid variable {variable:?}");
}
}
#[track_caller]
pub fn add_constraint(&mut self, constraint: Constraint) {
self.solved = false;
match constraint {
Constraint::MaximallyConstrained { variable } => {
self.assert_variable_in_range(variable);
self.maximally_constrained[self.variables_union_find.find_mut(variable.0)] = true;
return;
}
Constraint::Equal { dest, src } => {
self.assert_variable_in_range(dest);
self.assert_variable_in_range(src);
let maximally_constrained = self.maximally_constrained
[self.variables_union_find.find_mut(dest.0)]
| self.maximally_constrained[self.variables_union_find.find_mut(src.0)];
self.variables_union_find.union(dest.0, src.0);
let merged_index = self.variables_union_find.find_mut(dest.0);
self.maximally_constrained[merged_index] = maximally_constrained;
}
Constraint::And { dest, src } => {
self.assert_variable_in_range(src);
self.assert_variable_in_range(dest);
if src != dest {
self.and_constraints.insert(AndConstraint { dest, src });
}
}
Constraint::Or { dest, src } => {
self.assert_variable_in_range(src);
self.assert_variable_in_range(dest);
if src != dest {
self.and_constraints
.insert(AndConstraint::from_or_constraint(dest, src));
}
}
}
}
pub fn solve(&mut self) {
for (value, maximally_constrained) in self
.variables_value
.iter_mut()
.zip(&self.maximally_constrained)
{
*value = self.unconstrained_variables_value ^ *maximally_constrained;
}
let mut variables_to_constraints_map: Vec<Vec<AndConstraint>> =
vec![Vec::new(); self.variable_count()];
for &AndConstraint { mut dest, mut src } in &self.and_constraints {
dest.0 = self.variables_union_find.find_mut(dest.0);
src.0 = self.variables_union_find.find_mut(src.0);
if dest == src {
continue;
}
let constraint = AndConstraint { dest, src };
variables_to_constraints_map[dest.0].push(constraint);
variables_to_constraints_map[src.0].push(constraint);
}
let mut worklist: Vec<Variable> = (0..self.variable_count())
.filter(|&index| self.variables_union_find.find_mut(index) == index)
.map(Variable)
.collect();
while let Some(variable) = worklist.pop() {
for &AndConstraint { dest, src } in &variables_to_constraints_map[variable.0] {
let dest_value = self.variables_value[dest.0];
let src_value = self.variables_value[src.0];
// equivalent to `dest_value != dest_value & src_value`:
let is_unsatisfied = dest_value && !src_value;
if is_unsatisfied {
if self.unconstrained_variables_value {
self.variables_value[dest.0] = false;
worklist.push(dest);
} else {
self.variables_value[src.0] = true;
worklist.push(src);
}
}
}
}
self.solved = true;
}
#[track_caller]
pub fn value(&mut self, variable: Variable) -> bool {
#[cold]
fn solve_cold(this: &mut BoolFixedPointSolver) {
this.solve();
}
self.assert_variable_in_range(variable);
if !self.solved {
solve_cold(self);
}
self.variables_value[self.variables_union_find.find_mut(variable.0)]
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::num::NonZero;
struct TestCase<'a, C, Vars, Vals> {
variable_count: usize,
expected_values: Option<&'a [bool]>,
constraints: C,
variables: Vars,
values: Vals,
solver: BoolFixedPointSolver,
}
impl<'a, C: FnOnce(&[Variable]) -> I, I: IntoIterator<Item = Constraint>> TestCase<'a, C, (), ()> {
fn new_expected(
unconstrained_variables_value: bool,
expected_values: &'a [bool],
constraints: C,
) -> Self {
Self {
variable_count: expected_values.len(),
expected_values: Some(expected_values),
constraints,
variables: (),
values: (),
solver: BoolFixedPointSolver::new(unconstrained_variables_value),
}
}
#[track_caller]
fn get_constraints_and_variables(
self,
) -> TestCase<'a, Vec<Constraint>, Vec<Variable>, [bool; 0]> {
let Self {
variable_count,
expected_values,
constraints,
variables: (),
values: (),
mut solver,
} = self;
assert_eq!(
expected_values.map_or(variable_count, |v| v.len()),
variable_count,
);
let variables = Vec::from_iter((0..variable_count).map(|_| solver.new_variable()));
let constraints = Vec::from_iter(constraints(&variables));
TestCase {
variable_count,
expected_values,
constraints,
variables,
values: [],
solver,
}
}
}
impl<'a> TestCase<'a, Vec<Constraint>, Vec<Variable>, [bool; 0]> {
#[track_caller]
fn add_and_check_constraints(&mut self) {
if let Some(expected_values) = self.expected_values {
self.check_constraints("expected values", expected_values);
}
for &constraint in &self.constraints {
self.solver.add_constraint(constraint);
}
}
#[track_caller]
fn get_values(self) -> TestCase<'a, Vec<Constraint>, Vec<Variable>, Vec<bool>> {
let Self {
variable_count,
expected_values,
constraints,
variables,
values: [],
mut solver,
} = self;
let values = Vec::from_iter(variables.iter().map(|&v| solver.value(v)));
TestCase {
variable_count,
expected_values,
constraints,
variables,
values,
solver,
}
}
}
impl<'a> TestCase<'a, Vec<Constraint>, Vec<Variable>, Vec<bool>> {
#[track_caller]
fn check_values(&self) {
let Self {
variable_count: _,
expected_values,
constraints: _,
variables,
values,
solver: _,
} = self;
if let Some(expected_values) = expected_values {
for ((&expected_value, &variable), &value) in
expected_values.iter().zip(variables).zip(values)
{
if expected_value != value {
self.error(format_args!(
"solver output for {variable} of {value:?} doesn't \
match expected value of {expected_value:?}",
));
}
}
}
self.check_constraints("solved values", values);
}
}
impl<'a, Vals: AsRef<[bool]>> TestCase<'a, Vec<Constraint>, Vec<Variable>, Vals> {
#[track_caller]
fn check_constraints(&self, values_name: &str, values: &[bool]) {
let unconstrained_variables_value = self.solver.unconstrained_variables_value();
let v = |variable: Variable| values[variable.index()];
for &constraint in &self.constraints {
let satisfied = match constraint {
Constraint::MaximallyConstrained { variable } => {
v(variable) != unconstrained_variables_value
}
Constraint::Equal { dest, src } => v(dest) == v(src),
Constraint::And { dest, src } => v(dest) == v(dest) & v(src),
Constraint::Or { dest, src } => v(dest) == v(dest) | v(src),
};
if !satisfied {
self.error(format_args!(
"{values_name} don't satisfy constraint: {constraint:#?}"
));
}
}
}
#[track_caller]
fn error(&self, msg: fmt::Arguments<'_>) -> ! {
let Self {
variable_count,
expected_values,
ref constraints,
ref variables,
ref values,
ref solver,
} = *self;
let values = values.as_ref();
panic!(
"{msg}\n\
values={values:#?}\n\
constraints={constraints:#?}\n\
solver={solver:#?}",
values = fmt::from_fn(|f| {
let mut debug_map = f.debug_map();
for i in 0..variable_count {
debug_map.key(&variables[i]);
if let Some(value) = values.get(i) {
if let Some(expected_values) = expected_values {
debug_map.value(&format_args!(
"{value:?} (expected: {:?})",
expected_values[i],
));
} else {
debug_map.value(value);
}
} else if let Some(expected_values) = expected_values {
debug_map.value(&format_args!("(expected: {:?})", expected_values[i]));
} else {
debug_map.value(&format_args!("None"));
}
}
debug_map.finish()
}),
);
}
}
#[track_caller]
fn test_case<I: IntoIterator<Item = Constraint>>(
test_case: TestCase<'_, impl FnOnce(&[Variable]) -> I, (), ()>,
) {
let mut test_case = test_case.get_constraints_and_variables();
test_case.add_and_check_constraints();
let test_case = test_case.get_values();
test_case.check_values();
}
#[test]
fn test_bool_fixed_point_solver_simple() {
test_case(TestCase::new_expected(false, &[], |_| []));
test_case(TestCase::new_expected(true, &[], |_| []));
test_case(TestCase::new_expected(false, &[false], |_| []));
test_case(TestCase::new_expected(true, &[true], |_| []));
test_case(TestCase::new_expected(false, &[true], |v| {
[Constraint::MaximallyConstrained { variable: v[0] }]
}));
test_case(TestCase::new_expected(true, &[false], |v| {
[Constraint::MaximallyConstrained { variable: v[0] }]
}));
test_case(TestCase::new_expected(false, &[true, true], |v| {
[
Constraint::MaximallyConstrained { variable: v[0] },
Constraint::Equal {
dest: v[1],
src: v[0],
},
]
}));
test_case(TestCase::new_expected(true, &[false, false], |v| {
[
Constraint::MaximallyConstrained { variable: v[0] },
Constraint::Equal {
dest: v[1],
src: v[0],
},
]
}));
test_case(TestCase::new_expected(false, &[true, false], |v| {
[
Constraint::MaximallyConstrained { variable: v[0] },
Constraint::And {
dest: v[1],
src: v[0],
},
]
}));
test_case(TestCase::new_expected(true, &[false, false], |v| {
[
Constraint::MaximallyConstrained { variable: v[0] },
Constraint::And {
dest: v[1],
src: v[0],
},
]
}));
test_case(TestCase::new_expected(false, &[true, true], |v| {
[
Constraint::MaximallyConstrained { variable: v[0] },
Constraint::And {
dest: v[0],
src: v[1],
},
]
}));
test_case(TestCase::new_expected(true, &[false, true], |v| {
[
Constraint::MaximallyConstrained { variable: v[0] },
Constraint::And {
dest: v[0],
src: v[1],
},
]
}));
test_case(TestCase::new_expected(false, &[true, true], |v| {
[
Constraint::MaximallyConstrained { variable: v[0] },
Constraint::Or {
dest: v[1],
src: v[0],
},
]
}));
test_case(TestCase::new_expected(true, &[false, true], |v| {
[
Constraint::MaximallyConstrained { variable: v[0] },
Constraint::Or {
dest: v[1],
src: v[0],
},
]
}));
test_case(TestCase::new_expected(false, &[true, false], |v| {
[
Constraint::MaximallyConstrained { variable: v[0] },
Constraint::Or {
dest: v[0],
src: v[1],
},
]
}));
test_case(TestCase::new_expected(true, &[false, false], |v| {
[
Constraint::MaximallyConstrained { variable: v[0] },
Constraint::Or {
dest: v[0],
src: v[1],
},
]
}));
}
#[derive(Debug)]
struct Rng {
state: u64,
}
impl Rng {
fn new(test_case_index: u32) -> Self {
Self {
state: (test_case_index as u64) << 32,
}
}
fn next_u64(&mut self) -> u64 {
self.state += 1;
// 4 random primes and 4 random rotate amounts
self.state
.wrapping_mul(0xA3C7_8807_EA6D_A4F9)
.rotate_left(43)
.wrapping_mul(0x1CCA_797A_6BF8_8C63)
.rotate_left(8)
.wrapping_mul(0xCC50_AA59_7C41_946F)
.rotate_left(12)
.wrapping_mul(0xFB2A_0137_F878_C4B5)
.rotate_left(58)
}
#[track_caller]
fn next_u64_in_range(&mut self, range: std::ops::Range<u64>) -> u64 {
let Some(len) = range.end.checked_sub(range.start).and_then(NonZero::new) else {
panic!("empty range: {range:?}");
};
let max_quotient = u64::MAX / len;
loop {
let next_u64 = self.next_u64();
let quotient = next_u64 / len;
let remainder = next_u64 % len;
if quotient < max_quotient {
return remainder + range.start;
}
}
}
#[track_caller]
fn next_usize_in_range(&mut self, range: std::ops::Range<usize>) -> usize {
self.next_u64_in_range(range.start as u64..range.end as u64) as usize
}
#[track_caller]
fn next_from_slice<'a, T>(&mut self, slice: &'a [T]) -> &'a T {
assert!(!slice.is_empty());
&slice[self.next_usize_in_range(0..slice.len())]
}
fn next_bool(&mut self) -> bool {
(self.next_u64() & 1) != 0
}
}
#[track_caller]
fn test_bool_fixed_point_solver_random_case(test_case_index: u32) {
println!("test_bool_fixed_point_solver_random_case({test_case_index})");
let mut rng = Rng::new(test_case_index);
// bias towards smaller problems to make them easier to debug
let variable_count = rng
.next_u64_in_range(1..1_000_000)
.pow(2)
.div_ceil(1_000_000_000) as usize;
let constraint_count =
rng.next_usize_in_range(0..(variable_count * variable_count).clamp(0, 10000));
let solver = BoolFixedPointSolver::new(rng.next_bool());
test_case(TestCase {
variable_count,
expected_values: None,
constraints: |variables: &[Variable]| {
Vec::from_iter(
(0..constraint_count).map(|_| match rng.next_usize_in_range(0..4) {
0 => Constraint::MaximallyConstrained {
variable: *rng.next_from_slice(variables),
},
1 => Constraint::Equal {
dest: *rng.next_from_slice(variables),
src: *rng.next_from_slice(variables),
},
2 => Constraint::And {
dest: *rng.next_from_slice(variables),
src: *rng.next_from_slice(variables),
},
3 => Constraint::Or {
dest: *rng.next_from_slice(variables),
src: *rng.next_from_slice(variables),
},
4.. => unreachable!(),
}),
)
},
variables: (),
values: (),
solver,
});
}
const CASES_FULL_RANGE: std::ops::Range<u32> = 0..100_000;
fn mul_div(v: u32, factor: u32, divisor: u32) -> u32 {
((v as u64 * factor as u64) / divisor as u64) as u32
}
#[track_caller]
fn test_bool_fixed_point_solver_random_cases(split_index: u32) {
assert!(split_index < CASES_SPLIT_COUNT);
let full_range_len = CASES_FULL_RANGE.end - CASES_FULL_RANGE.start;
let start = mul_div(split_index, full_range_len, CASES_SPLIT_COUNT);
let end = mul_div(split_index + 1, full_range_len, CASES_SPLIT_COUNT);
for test_case_index in start..end {
test_bool_fixed_point_solver_random_case(test_case_index)
}
}
const CASES_SPLIT_COUNT: u32 = 10;
#[test]
fn test_bool_fixed_point_solver_random_cases_0() {
test_bool_fixed_point_solver_random_cases(0);
}
#[test]
fn test_bool_fixed_point_solver_random_cases_1() {
test_bool_fixed_point_solver_random_cases(1);
}
#[test]
fn test_bool_fixed_point_solver_random_cases_2() {
test_bool_fixed_point_solver_random_cases(2);
}
#[test]
fn test_bool_fixed_point_solver_random_cases_3() {
test_bool_fixed_point_solver_random_cases(3);
}
#[test]
fn test_bool_fixed_point_solver_random_cases_4() {
test_bool_fixed_point_solver_random_cases(4);
}
#[test]
fn test_bool_fixed_point_solver_random_cases_5() {
test_bool_fixed_point_solver_random_cases(5);
}
#[test]
fn test_bool_fixed_point_solver_random_cases_6() {
test_bool_fixed_point_solver_random_cases(6);
}
#[test]
fn test_bool_fixed_point_solver_random_cases_7() {
test_bool_fixed_point_solver_random_cases(7);
}
#[test]
fn test_bool_fixed_point_solver_random_cases_8() {
test_bool_fixed_point_solver_random_cases(8);
}
#[test]
fn test_bool_fixed_point_solver_random_cases_9() {
test_bool_fixed_point_solver_random_cases(9);
}
}