mirror of
				https://github.com/Z3Prover/z3
				synced 2025-11-04 05:19:11 +00:00 
			
		
		
		
	Generalize refine_disequal_lin
This commit is contained in:
		
							parent
							
								
									f80eb6237d
								
							
						
					
					
						commit
						15854301b2
					
				
					 2 changed files with 68 additions and 71 deletions
				
			
		| 
						 | 
				
			
			@ -277,84 +277,71 @@ namespace polysat {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    bool viable::refine_disequal_lin(pvar v, rational const& val) {
 | 
			
		||||
        // LOG_H2("refine-disequal-lin with v" << v << ", val = " << val);
 | 
			
		||||
        auto* e = m_diseq_lin[v];
 | 
			
		||||
        if (!e)
 | 
			
		||||
            return true;
 | 
			
		||||
        entry* first = e;
 | 
			
		||||
        rational const& max_value = s.var2pdd(v).max_value();
 | 
			
		||||
        rational mod_value = max_value + 1;
 | 
			
		||||
        rational const mod_value = max_value + 1;
 | 
			
		||||
 | 
			
		||||
        do {
 | 
			
		||||
            LOG("refine-disequal-lin for src: " << e->src);
 | 
			
		||||
            // We have:
 | 
			
		||||
            //      a1*v + b1 >  a2*v + b2  if e->src.is_positive()
 | 
			
		||||
            //      a1*v + b1 >= a2*v + b2  if e->src.is_negative()
 | 
			
		||||
            // We compute an interval if the concrete value 'val' violates the constraint:
 | 
			
		||||
            //      p*val + q >  r*val + s  if e->src.is_positive()
 | 
			
		||||
            //      p*val + q >= r*val + s  if e->src.is_negative()
 | 
			
		||||
            // Note that e->interval is meaningless in this case,
 | 
			
		||||
            // we just use it to transport the values a1,b1,a2,b2.
 | 
			
		||||
            rational const& a1 = e->interval.lo_val();
 | 
			
		||||
            rational const& b1 = e->interval.lo().val();
 | 
			
		||||
            rational const& a2 = e->interval.hi_val();
 | 
			
		||||
            rational const& b2 = e->interval.hi().val();
 | 
			
		||||
            SASSERT(a1 != a2 && a1 != 0 && a2 != 0);
 | 
			
		||||
            // we just use it to transport the values p,q,r,s
 | 
			
		||||
            rational const& p = e->interval.lo_val();
 | 
			
		||||
            rational const& q_ = e->interval.lo().val();
 | 
			
		||||
            rational const& r = e->interval.hi_val();
 | 
			
		||||
            rational const& s_ = e->interval.hi().val();
 | 
			
		||||
            SASSERT(p != r && p != 0 && r != 0);
 | 
			
		||||
 | 
			
		||||
            rational lhs = mod(a1 * val + b1, mod_value);
 | 
			
		||||
            rational rhs = mod(a2 * val + b2, mod_value);
 | 
			
		||||
            rational const a = mod(p * val + q_, mod_value);
 | 
			
		||||
            rational const b = mod(r * val + s_, mod_value);
 | 
			
		||||
            rational const np = mod_value - p;
 | 
			
		||||
            rational const nr = mod_value - r;
 | 
			
		||||
            int const corr = e->src.is_negative() ? 1 : 0;
 | 
			
		||||
 | 
			
		||||
            auto delta_l = [&](rational const& val) {
 | 
			
		||||
                rational m1 = ceil((rhs + 1) / a2);
 | 
			
		||||
                int corr = e->src.is_negative() ? 1 : 0;
 | 
			
		||||
                rational m3 = (lhs - rhs + corr) / (a1 - a2);
 | 
			
		||||
                if (m3 <= 0)
 | 
			
		||||
                    m3 = m1;  // remove m3 from the minimum
 | 
			
		||||
                else
 | 
			
		||||
                    m3 = ceil(m3);
 | 
			
		||||
 | 
			
		||||
                // return std::min(m1, m3) - 1;
 | 
			
		||||
                return std::min(val, std::min(m1, m3) - 1);
 | 
			
		||||
                rational num = a - b + corr;
 | 
			
		||||
                rational l1 = floor(b / r);
 | 
			
		||||
                rational l2 = val;
 | 
			
		||||
                if (p > r)
 | 
			
		||||
                    l2 = ceil(num / (p - r)) - 1;
 | 
			
		||||
                rational l3 = ceil(num / (p + nr)) - 1;
 | 
			
		||||
                rational l4 = ceil((mod_value - a) / np) - 1;
 | 
			
		||||
                rational d1 = l3;
 | 
			
		||||
                rational d2 = std::min(l1, l2);
 | 
			
		||||
                rational d3 = std::min(l1, l4);
 | 
			
		||||
                rational d4 = std::min(l2, l4);
 | 
			
		||||
                rational dmax = std::max(std::max(d1, d2), std::max(d3, d4));
 | 
			
		||||
                return std::min(val, dmax);
 | 
			
		||||
            };
 | 
			
		||||
            auto delta_u = [&](rational const& val) {
 | 
			
		||||
                rational m1 = ceil((mod_value - lhs) / a1);
 | 
			
		||||
                rational m2 = mod_value - val;
 | 
			
		||||
                int corr = e->src.is_negative() ? 1 : 0;
 | 
			
		||||
                rational m3 = (lhs - rhs + corr) / (a2 - a1);
 | 
			
		||||
                if (m3 <= 0)
 | 
			
		||||
                    m3 = m2;  // remove m3 from the minimum
 | 
			
		||||
                else
 | 
			
		||||
                    m3 = ceil(m3);
 | 
			
		||||
 | 
			
		||||
                return std::min(m1, std::min(m2, m3)) - 1;
 | 
			
		||||
                rational num = a - b + corr;
 | 
			
		||||
                rational h1 = floor(b / nr);
 | 
			
		||||
                rational h2 = max_value - val;
 | 
			
		||||
                if (r > p)
 | 
			
		||||
                    h2 = ceil(num / (r - p)) - 1;
 | 
			
		||||
                rational h3 = ceil(num / (np + r)) - 1;
 | 
			
		||||
                rational h4 = ceil((mod_value - a) / p) - 1;
 | 
			
		||||
                rational d1 = h3;
 | 
			
		||||
                rational d2 = std::min(h1, h2);
 | 
			
		||||
                rational d3 = std::min(h1, h4);
 | 
			
		||||
                rational d4 = std::min(h2, h4);
 | 
			
		||||
                rational dmax = std::max(std::max(d1, d2), std::max(d3, d4));
 | 
			
		||||
                return std::min(max_value - val, dmax);
 | 
			
		||||
            };
 | 
			
		||||
 | 
			
		||||
            if (lhs > rhs || (e->src.is_negative() && lhs == rhs)) {
 | 
			
		||||
                rational lo;
 | 
			
		||||
                rational hi;
 | 
			
		||||
 | 
			
		||||
                // TODO: extract into separate function
 | 
			
		||||
                if (e->src.is_negative() && a2.is_one() && b1.is_zero() && b2.is_zero()) {
 | 
			
		||||
                    // special case:  v > -a*v for some numeral a
 | 
			
		||||
                    rational const& a = mod(-a1, mod_value);
 | 
			
		||||
                    if (val.is_zero()) {
 | 
			
		||||
                        lo = 0;
 | 
			
		||||
                        hi = ceil( (mod_value + 1) / (a + 1) );
 | 
			
		||||
                    } else {
 | 
			
		||||
                        rational const y = mod(-a * val, mod_value);
 | 
			
		||||
                        lo = ceil( val + (y - max_value) / a );
 | 
			
		||||
                        hi = ceil( (y + a*val + 1) / (a + 1) );
 | 
			
		||||
                        // can always extend to 0
 | 
			
		||||
                        if (lo.is_one())
 | 
			
		||||
                            lo = 0;
 | 
			
		||||
                    }
 | 
			
		||||
                } else {
 | 
			
		||||
                    // general case
 | 
			
		||||
                    lo = val - delta_l(val);
 | 
			
		||||
                    hi = val + delta_u(val) + 1;
 | 
			
		||||
                    // TODO: increase interval
 | 
			
		||||
                }
 | 
			
		||||
            if (a > b || (e->src.is_negative() && a == b)) {
 | 
			
		||||
                rational lo = val - delta_l(val);
 | 
			
		||||
                rational hi = val + delta_u(val) + 1;
 | 
			
		||||
 | 
			
		||||
                LOG("refine-disequal-lin: " << " [" << lo << ", " << hi << "[");
 | 
			
		||||
 | 
			
		||||
                SASSERT(0 <= lo && lo <= val);
 | 
			
		||||
                // SASSERT(val <= hi && hi <= max_value);
 | 
			
		||||
                SASSERT(val <= hi && hi <= mod_value);
 | 
			
		||||
                if (hi == mod_value) hi = 0;
 | 
			
		||||
                pdd lop = s.var2pdd(v).mk_val(lo);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1195,23 +1195,33 @@ class test_fi {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
public:
 | 
			
		||||
    static void exhaustive(unsigned bw = 3) {
 | 
			
		||||
        rational const m = rational::power_of_two(bw);
 | 
			
		||||
        for (rational a1(1); a1 < m; ++a1) {
 | 
			
		||||
            for (rational a2(1); a2 < m; ++a2) {
 | 
			
		||||
                // TODO: remove this to test other cases
 | 
			
		||||
                if (a1 == a2)
 | 
			
		||||
                    continue;
 | 
			
		||||
                for (rational b1(0); b1 < m; ++b1)
 | 
			
		||||
                    for (rational b2(0); b2 < m; ++b2)
 | 
			
		||||
                        for (rational val(0); val < m; ++val)
 | 
			
		||||
                            for (bool negated : {true, false})
 | 
			
		||||
                                check_one(a1, b1, a2, b2, val, negated, bw);
 | 
			
		||||
    static void exhaustive(unsigned bw = 0) {
 | 
			
		||||
        if (bw == 0) {
 | 
			
		||||
            exhaustive(1);
 | 
			
		||||
            exhaustive(2);
 | 
			
		||||
            exhaustive(3);
 | 
			
		||||
            exhaustive(4);
 | 
			
		||||
            exhaustive(5);
 | 
			
		||||
        } else {
 | 
			
		||||
            std::cout << "test_fi::exhaustive for bw=" << bw << std::endl;
 | 
			
		||||
            rational const m = rational::power_of_two(bw);
 | 
			
		||||
            for (rational p(1); p < m; ++p) {
 | 
			
		||||
                for (rational r(1); r < m; ++r) {
 | 
			
		||||
                    // TODO: remove this condition to test the cases other than disequal_lin! (also start p,q from 0)
 | 
			
		||||
                    if (p == r)
 | 
			
		||||
                        continue;
 | 
			
		||||
                    for (rational q(0); q < m; ++q)
 | 
			
		||||
                        for (rational s(0); s < m; ++s)
 | 
			
		||||
                            for (rational v(0); v < m; ++v)
 | 
			
		||||
                                for (bool negated : {true, false})
 | 
			
		||||
                                    check_one(p, q, r, s, v, negated, bw);
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static void randomized(unsigned num_rounds = 10'000, unsigned bw = 16) {
 | 
			
		||||
    static void randomized(unsigned num_rounds = 100'000, unsigned bw = 16) {
 | 
			
		||||
        std::cout << "test_fi::randomized for bw=" << bw << " (" << num_rounds << " rounds)" << std::endl;
 | 
			
		||||
        rational const m = rational::power_of_two(bw);
 | 
			
		||||
        VERIFY(bw <= 32 && "random_gen generates 32-bit numbers");
 | 
			
		||||
        random_gen rng;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue