#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Model BITWIDTH-bit integer operations over an option type, where `none`
# stands for "no result" (either operand is none, or, in the `_nsw` variants,
# the operation overflows as a signed integer), and use Z3 to check whether
# one operation sequence refines another.

from z3 import *

BITWIDTH = 32

# Opt = none | some(value: BitVec(BITWIDTH))
Opt = Datatype('Opt')
Opt.declare('none')
Opt.declare('some', ('value', BitVecSort(BITWIDTH)))
Opt = Opt.create()
none = Opt.none
some = Opt.some
value = Opt.value

# Free inputs used by every refinement query.
(t, r) = Consts('t r', Opt)

# Solver for refinement queries; each query runs in its own push/pop scope.
s = Solver()


def check_refined(f, g):
    """Print "sat" if f is refined by g, else print "unsat".

    Equivalently, you can say "f can be implemented by calling g": for all
    inputs t, r, whenever f(t, r) is not none, g(t, r) must equal f(t, r).

    Rather than asserting that property under a ForAll, this asks the solver
    for a counterexample over the free constants t and r, which keeps the
    query quantifier-free. No counterexample (unsat) means f is refined by g,
    reported as "sat"; a counterexample (sat) is reported as "unsat". If the
    solver answers "unknown", that is printed unchanged.

    Returns the CheckSatResult that was printed.
    """
    s.push()
    try:
        s.add(Not(Implies(f(t, r) != none, f(t, r) == g(t, r))))
        counterexample = s.check()
    finally:
        # Always drop the query, even if building or checking it fails.
        s.pop()
    if counterexample == unsat:
        result = sat
    elif counterexample == sat:
        result = unsat
    else:
        result = counterexample
    print(result)
    return result


# Signed bounds, plus the same bounds sign-extended to 2 * BITWIDTH bits so
# exact sums/products of two BITWIDTH-bit values can be compared without
# wrapping.
INT_MIN = BitVecVal(1 << (BITWIDTH - 1), BITWIDTH)
INT_MAX = BitVecVal((1 << (BITWIDTH - 1)) - 1, BITWIDTH)
INT_MIN_AS_LONG = SignExt(BITWIDTH, INT_MIN)
INT_MAX_AS_LONG = SignExt(BITWIDTH, INT_MAX)


def _outside_signed_range(rlong):
    """Bool that holds iff the 2*BITWIDTH-bit rlong is outside [INT_MIN, INT_MAX]."""
    # In z3py, `>` and `<` on bit-vectors are signed comparisons.
    return Or(rlong > INT_MAX_AS_LONG, rlong < INT_MIN_AS_LONG)


def signed_add_overflow(y, z):
    """Bool that holds iff y + z overflows as signed BITWIDTH-bit integers."""
    ylong = SignExt(BITWIDTH, y)
    zlong = SignExt(BITWIDTH, z)
    return _outside_signed_range(ylong + zlong)


def signed_mul_overflow(y, z):
    """Bool that holds iff y * z overflows as signed BITWIDTH-bit integers."""
    # 2 * BITWIDTH bits hold any exact product of two BITWIDTH-bit values.
    ylong = SignExt(BITWIDTH, y)
    zlong = SignExt(BITWIDTH, z)
    return _outside_signed_range(ylong * zlong)


def _lift_binop(t, r, op, overflows=None):
    """Lift a bit-vector binary op to Opt.

    Result is none if t or r is none, or (when `overflows` is given) if
    overflows(value(t), value(r)) holds; otherwise some(op(value(t), value(r))).
    """
    x, y = value(t), value(r)
    result = some(op(x, y))
    if overflows is not None:
        result = If(overflows(x, y), none, result)
    return If(Or(t == none, r == none), none, result)


def add(t, r):
    """Wrapping BITWIDTH-bit addition on Opt; none if either operand is none."""
    return _lift_binop(t, r, lambda x, y: x + y)


def add_nsw(t, r):
    """Like add, but also none when the addition overflows as signed."""
    return _lift_binop(t, r, lambda x, y: x + y, signed_add_overflow)


def mul(t, r):
    """Wrapping BITWIDTH-bit multiplication on Opt; none if either operand is none."""
    return _lift_binop(t, r, lambda x, y: x * y)


def mul_nsw(t, r):
    """Like mul, but also none when the multiplication overflows as signed."""
    return _lift_binop(t, r, lambda x, y: x * y, signed_mul_overflow)


(a, b) = Consts('a b', Opt)


def slow(a, b):
    """a + a + a with no-signed-wrap additions (b is unused)."""
    return add_nsw(a, add_nsw(a, a))


def fast1(a, b):
    """a * 3 with wrapping multiplication (b is unused)."""
    return mul(a, some(3))


check_refined(slow, fast1)
check_refined(fast1, slow)


def fast2(a, b):
    """a * 3 with no-signed-wrap multiplication (b is unused)."""
    return mul_nsw(a, some(3))


check_refined(slow, fast2)
check_refined(fast2, slow)

prove(slow(a, b) == fast1(a, b))
# prove(slow(a, b) == fast2(a, b))
# The prove above is disabled: Z3 did not come back at larger bit widths.
def opt_str(o):
    """Render a ground Opt term as "none" or as its payload's signed value.

    Returns the string "none", or a Python int holding the payload
    interpreted as a signed BITWIDTH-bit integer.

    Raises ValueError if o does not simplify to a concrete Opt value, for
    example because it still mentions a free constant.
    """
    p = simplify(o)
    # Use the datatype recognizer rather than is_const(): is_const() is also
    # true for a free Opt constant such as `a`, which is not known to be none.
    if is_true(simplify(Opt.is_none(p))):
        return "none"
    payload = simplify(value(p))
    if not is_bv_value(payload):
        raise ValueError("opt_str expects a concrete Opt value, got %s" % p)
    return payload.as_signed_long()


# Set to True to print the bit-vector bounds as plain signed integers.
SHOW_BOUNDS = False

if SHOW_BOUNDS:
    print("")
    ONE = BitVecVal(1, BITWIDTH)
    TWO = BitVecVal(2, BITWIDTH)
    print(simplify(ONE).as_signed_long())
    print(simplify(TWO).as_signed_long())
    print(simplify(INT_MAX).as_signed_long())
    print(simplify(INT_MIN).as_signed_long())
    print(simplify(INT_MAX_AS_LONG).as_signed_long())
    print(simplify(INT_MIN_AS_LONG).as_signed_long())

ONE = Opt.some(1)
TWO = Opt.some(2)
NEG_ONE = Opt.some(-1)
NEG_TWO = Opt.some(-2)
# Opt-wrapped bounds, under separate names so INT_MIN / INT_MAX keep
# referring to the BITWIDTH-bit bit-vector constants.
OPT_INT_MIN = Opt.some(INT_MIN)
OPT_INT_MAX = Opt.some(INT_MAX)

# Set to True to run the sanity checks below; each trailing comment is the
# expected output (INT_MIN / INT_MAX are the signed BITWIDTH-bit bounds).
RUN_SELF_TESTS = False

if RUN_SELF_TESTS:
    print(opt_str(mul_nsw(ONE, ONE)))                  # 1
    print(opt_str(mul_nsw(OPT_INT_MAX, ONE)))          # INT_MAX
    print(opt_str(mul_nsw(OPT_INT_MIN, ONE)))          # INT_MIN
    print(opt_str(mul_nsw(OPT_INT_MAX, NEG_ONE)))      # -INT_MAX
    print(opt_str(mul_nsw(OPT_INT_MIN, NEG_ONE)))      # none
    print(opt_str(mul_nsw(OPT_INT_MAX, TWO)))          # none
    print(opt_str(mul_nsw(OPT_INT_MIN, TWO)))          # none
    print(opt_str(mul_nsw(OPT_INT_MAX, NEG_TWO)))      # none
    print(opt_str(mul_nsw(OPT_INT_MIN, NEG_TWO)))      # none
    print(opt_str(add_nsw(ONE, ONE)))                  # 2
    print(opt_str(add_nsw(OPT_INT_MAX, ONE)))          # none
    print(opt_str(add_nsw(OPT_INT_MIN, ONE)))          # INT_MIN + 1
    print(opt_str(add_nsw(OPT_INT_MAX, NEG_ONE)))      # INT_MAX - 1
    print(opt_str(add_nsw(OPT_INT_MIN, NEG_ONE)))      # none
    # NOTE(review): the original indentation was lost; these refinement
    # sanity checks are assumed to belong to this disabled section.
    check_refined(add, add)          # sat
    check_refined(add, add_nsw)      # unsat
    check_refined(add_nsw, add)      # sat
    check_refined(add_nsw, add_nsw)  # sat