Automated Reasoning About LLVM Optimizations and Undefined Behavior


Following up on last week’s toy use of Z3 and my LLVM superoptimizer post from a few weeks ago, I wanted to try to prove things about optimizations that involve undefined behavior.

We’ll be working with the following example:

int triple (int x)
{
  return x + x + x;
}

Clang’s -O0 output for this code is cluttered, but after running mem2reg we get a nice literal translation of the C code:

define i32 @triple(i32 %x) #0 {
  %1 = add nsw i32 %x, %x
  %2 = add nsw i32 %1, %x
  ret i32 %2
}

Next we can run instcombine, with this result:

define i32 @triple(i32 %x) #0 {
  %1 = mul i32 %x, 3
  ret i32 %1
}

The optimization itself (replacing x+x+x with x*3) is uninteresting and proving that it is correct is trivial in Z3 or any similar tool that supports reasoning about bitvector operations. To see this, go to the Z3 site and paste this code into the Python window on the right:

from z3 import *

x = BitVec('x', 32)
slow = x+x+x
fast = x*3
prove (slow == fast)
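The Z3 query ranges over every 32-bit x, so it can also help to check the same identity exhaustively at a smaller width. The sketch below is plain Python rather than Z3 and assumes an 8-bit width. It masks every intermediate result so that the arithmetic behaves like two's complement wraparound.

```python
BITWIDTH = 8
MASK = (1 << BITWIDTH) - 1  # keep the low BITWIDTH bits, i.e. wrap modulo 2**BITWIDTH

def wrapping_identity_holds():
    """Return True if x+x+x == x*3 (mod 2**BITWIDTH) for every BITWIDTH-bit x."""
    for x in range(1 << BITWIDTH):
        slow = (((x + x) & MASK) + x) & MASK  # two wrapping adds
        fast = (x * 3) & MASK                  # one wrapping multiply
        if slow != fast:
            return False
    return True

print(wrapping_identity_holds())  # True
```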

The interesting aspect of this optimization is that the original C code and LLVM code are not using two’s complement addition. Rather, signed + in C and add nsw in LLVM are only defined when the addition doesn’t overflow. On the other hand, the multiplication emitted by instcombine is not qualified by nsw: it wraps in the expected fashion when the multiplication overflows (see the docs for more detail). My goal was to use Z3 to automatically prove that instcombine dropped the ball here. It could have emitted mul nsw, which has weaker semantics and is therefore more desirable: compared to the unqualified mul, it gives subsequent optimizations more freedom.
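The difference is easier to see in a minimal 8-bit model written in plain Python. This is an illustration, not LLVM's implementation. Here add wraps, while add_nsw returns None whenever the exact sum does not fit, with None standing in for an undefined result.

```python
BITWIDTH = 8
MASK = (1 << BITWIDTH) - 1
INT_MIN = -(1 << (BITWIDTH - 1))     # -128
INT_MAX = (1 << (BITWIDTH - 1)) - 1  # 127

def to_signed(v):
    """Interpret the low BITWIDTH bits of v as a two's complement number."""
    v &= MASK
    return v - (1 << BITWIDTH) if v > INT_MAX else v

def add(x, y):
    """Plain add: always defined, wraps on overflow."""
    return to_signed(x + y)

def add_nsw(x, y):
    """add nsw: None models the undefined result of a signed overflow."""
    exact = x + y  # Python ints are unbounded, so this is the true sum
    if exact < INT_MIN or exact > INT_MAX:
        return None
    return exact

print(add(INT_MAX, 1))      # -128
print(add_nsw(INT_MAX, 1))  # None
print(add_nsw(100, 20))     # 120
```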

When we talk informally about compiler optimizations, we often say that the compiler should replace code with (faster or smaller) code that is functionally equivalent. Actually, I just did this above while showing that x+x+x and x*3 are equivalent. But in fact, a compiler optimization does not need to replace code with functionally equivalent code. Rather, the new code only needs to refine the old code: every behavior the new code can exhibit must be one that the old code permits. Refinement occurs many times during a typical compilation. For example, this C code:

foo (bar(), baz());

can be refined to either this:

tmp1 = bar();
tmp2 = baz();
foo (tmp1, tmp2);

or this:

tmp2 = baz();
tmp1 = bar();
foo (tmp1, tmp2);

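These refinements are not (in general) equivalent to each other, nor are they equivalent to the original non-deterministic C code, which permits either evaluation order. Analogously, the LLVM instruction mul %x, 3 refines mul nsw %x, 3 but not the other way around. The two agree whenever mul nsw is defined. On signed overflow, where mul nsw is undefined, mul is free to wrap.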
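Refinement can be sketched as set inclusion if we assume, purely for illustration, that a program is described by the set of behaviors it may exhibit. Here those behaviors are call orders. The variable names below are hypothetical.

```python
from itertools import permutations

# foo(bar(), baz()) leaves the argument evaluation order unspecified,
# so both call orders are allowed behaviors of the original code.
original = set(permutations(["bar", "baz"]))

# Each compiled version commits to exactly one order.
left_to_right = {("bar", "baz")}
right_to_left = {("baz", "bar")}

def refines(impl, spec):
    """impl refines spec if every behavior of impl is allowed by spec."""
    return impl <= spec

print(refines(left_to_right, original))  # True
print(refines(right_to_left, original))  # True
print(refines(original, left_to_right))  # False
print(left_to_right == right_to_left)    # False
```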

So how do we teach Z3 to do refinement proofs? I was a bit worried by some wording in the Z3 documentation that says “in Z3 all functions are total,” and then I got stuck. Alexey Solovyev, a formal methods postdoc at Utah, helped me out with the following solution. The idea is to use Z3’s ML-style option types to turn functions with undefined behavior into total functions:

from z3 import *

BITWIDTH = 32

Opt = Datatype('Opt')
Opt.declare('none')
Opt.declare('some', ('value', BitVecSort(BITWIDTH)))

Opt = Opt.create()
none = Opt.none
some = Opt.some
value = Opt.value

(t, r) = Consts('t r', Opt)

s = Solver()

# print "sat" if f is refined by g, else print "unsat"
# equivalently, you can say "f can be implemented by calling g"
def check_refined(f, g):
    s.push()
    s.add(ForAll([t, r], Implies(f(t, r) != none, f(t, r) == g(t, r))))
    print(s.check())
    s.pop()

In other words, f() is refined by g() iff g() returns the same value as f() for every member of f()’s domain. For inputs not in the domain of f(), the behavior of g() is a don’t-care.
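When the bit width is small, the same definition can also be checked by brute force. The following sketch is a stand-in for the Z3 query, not the method used in this post. It assumes an 8-bit width and uses Python's None in place of Opt's none. It returns True/False where Z3 would report sat/unsat.

```python
BITWIDTH = 8
MASK = (1 << BITWIDTH) - 1
INT_MIN = -(1 << (BITWIDTH - 1))
INT_MAX = (1 << (BITWIDTH - 1)) - 1

# Every possible input: None (the "none" case of Opt) plus every signed value.
VALUES = [None] + list(range(INT_MIN, INT_MAX + 1))

def to_signed(v):
    """Interpret the low BITWIDTH bits of v as a two's complement number."""
    v &= MASK
    return v - (1 << BITWIDTH) if v > INT_MAX else v

def add(t, r):
    if t is None or r is None:
        return None
    return to_signed(t + r)

def add_nsw(t, r):
    if t is None or r is None:
        return None
    exact = t + r
    return exact if INT_MIN <= exact <= INT_MAX else None

def check_refined(f, g):
    """True iff g agrees with f on every input where f is defined."""
    for t in VALUES:
        for r in VALUES:
            ft = f(t, r)
            if ft is not None and ft != g(t, r):
                return False
    return True

print(check_refined(add, add))          # True
print(check_refined(add, add_nsw))      # False
print(check_refined(add_nsw, add))      # True
print(check_refined(add_nsw, add_nsw))  # True
```

Enumeration is only feasible because 257 × 257 input pairs is tiny. At 32 bits, the solver is the practical option.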

Here are some LLVM operations that let us reason about the example code above:

INT_MIN = BitVecVal(1 << BITWIDTH - 1, BITWIDTH)
INT_MAX = BitVecVal((1 << BITWIDTH - 1) - 1, BITWIDTH)
INT_MIN_AS_LONG = SignExt(BITWIDTH, INT_MIN)
INT_MAX_AS_LONG = SignExt(BITWIDTH, INT_MAX)

def signed_add_overflow(y, z):
    ylong = SignExt(BITWIDTH, y)
    zlong = SignExt(BITWIDTH, z)
    rlong = ylong + zlong
    return Or(rlong > INT_MAX_AS_LONG, rlong < INT_MIN_AS_LONG)

def signed_mul_overflow(y, z):
    ylong = SignExt(BITWIDTH, y)
    zlong = SignExt(BITWIDTH, z)
    rlong = ylong * zlong
    return Or(rlong > INT_MAX_AS_LONG, rlong < INT_MIN_AS_LONG)

def add(t, r):
    return If(Or(t == none, r == none), 
              none, 
              some(value(t) + value(r)))

def add_nsw(t, r):
    return If(Or(t == none, r == none), 
              none,
              If(signed_add_overflow(value(t), value(r)), 
                 none,
                 some(value(t) + value(r))))

def mul(t, r):
    return If(Or(t == none, r == none), 
              none, 
              some(value(t) * value(r)))

def mul_nsw(t, r):
    return If(Or(t == none, r == none), 
              none,
              If(signed_mul_overflow(value(t), value(r)), 
                 none,
                 some(value(t) * value(r))))

Hopefully this code is pretty clear: add_nsw and its friend mul_nsw act like add and mul as long as there’s no overflow. If an overflow occurs, they map their arguments to none.
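The overflow predicates depend on one fact worth spelling out. After SignExt(BITWIDTH, ...), the operands occupy 2*BITWIDTH bits. An exact sum or product of two n-bit signed values always fits in that width, because the largest product magnitude is 2^(2n-2), which lies inside the 2n-bit range [-2^(2n-1), 2^(2n-1) - 1]. So comparing against INT_MIN_AS_LONG and INT_MAX_AS_LONG is exact. This plain-Python sketch assumes an 8-bit width and confirms that no such sum or product wraps at 16 bits.

```python
BITWIDTH = 8
INT_MIN = -(1 << (BITWIDTH - 1))
INT_MAX = (1 << (BITWIDTH - 1)) - 1

WIDE = 2 * BITWIDTH  # width after SignExt(BITWIDTH, ...)
WIDE_MAX = (1 << (WIDE - 1)) - 1

def wide_wrap(v):
    """Reduce v to a WIDE-bit two's complement value, as bitvector ops would."""
    v &= (1 << WIDE) - 1
    return v - (1 << WIDE) if v > WIDE_MAX else v

def widening_is_exact():
    """Check that WIDE-bit sums and products of BITWIDTH-bit values never wrap."""
    for y in range(INT_MIN, INT_MAX + 1):
        for z in range(INT_MIN, INT_MAX + 1):
            if wide_wrap(y + z) != y + z or wide_wrap(y * z) != y * z:
                return False
    return True

print(widening_is_exact())  # True
```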

At this point I might as well mention a dirty little secret: formal methods code isn’t any easier to get right than regular code, and often it is considerably more difficult. The solution is testing. This is more than a little bit ironic, since working around inadequacies of testing is the point of using formal methods in the first place. So here’s some test code. To make it easier to understand, I’ve set BITWIDTH to 8 (so INT_MIN is -128 and INT_MAX is 127) and also put the results of the print statements in comments.

def opt_str(o):
    p = simplify(o)
    if is_const(p):
        return "none"
    else:
        return simplify(value(p)).as_signed_long()
        
ONE = Opt.some(1)
TWO = Opt.some(2)
NEG_ONE = Opt.some(-1)
NEG_TWO = Opt.some(-2)
INT_MIN = Opt.some(INT_MIN)
INT_MAX = Opt.some(INT_MAX)

print(opt_str(mul_nsw(ONE, ONE)))          #    1
print(opt_str(mul_nsw(INT_MAX, ONE)))      #  127
print(opt_str(mul_nsw(INT_MIN, ONE)))      # -128
print(opt_str(mul_nsw(INT_MAX, NEG_ONE)))  # -127
print(opt_str(mul_nsw(INT_MIN, NEG_ONE)))  # none
print(opt_str(mul_nsw(INT_MAX, TWO)))      # none
print(opt_str(mul_nsw(INT_MIN, TWO)))      # none
print(opt_str(mul_nsw(INT_MAX, NEG_TWO)))  # none
print(opt_str(mul_nsw(INT_MIN, NEG_TWO)))  # none

print(opt_str(add_nsw(ONE, ONE)))          #    2
print(opt_str(add_nsw(INT_MAX, ONE)))      # none
print(opt_str(add_nsw(INT_MIN, ONE)))      # -127
print(opt_str(add_nsw(INT_MAX, NEG_ONE)))  #  126
print(opt_str(add_nsw(INT_MIN, NEG_ONE)))  # none

check_refined(add, add)                    # sat
check_refined(add, add_nsw)                # unsat

check_refined(add_nsw, add)                # sat
check_refined(add_nsw, add_nsw)            # sat

The formalization passes this quick smoke test. I’ve done some more testing and won’t bore you with it. None of this is conclusive of course.

Finally, we return to the example that motivated this post. Recall that LLVM’s instcombine pass translates two add nsw instructions into a single mul. Let’s make sure that was correct. Again, for convenience, I’ll put outputs in comments.

(a, b) = Consts('a b', Opt)

def slow(a, b):
    return add_nsw(a, add_nsw(a, a))

def fast1(a, b):
    return mul(a, Opt.some(3))

check_refined(slow, fast1)           # sat
check_refined(fast1, slow)           # unsat

As expected, LLVM’s transformation is correct: x+x+x where both + operators have undefined overflow can be refined by x*3 where the * operator has two’s complement wraparound. The converse is not true.
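The same conclusion can be reproduced by exhaustive enumeration at 8 bits. This sketch uses plain Python, not Z3, and None models an undefined value. Here slow and fast1 take a single argument, because the post's unused b parameter does not affect the result.

```python
BITWIDTH = 8
MASK = (1 << BITWIDTH) - 1
INT_MIN = -(1 << (BITWIDTH - 1))
INT_MAX = (1 << (BITWIDTH - 1)) - 1
# All possible inputs: None (Opt's none) plus every signed 8-bit value.
VALUES = [None] + list(range(INT_MIN, INT_MAX + 1))

def to_signed(v):
    """Interpret the low BITWIDTH bits of v as a two's complement number."""
    v &= MASK
    return v - (1 << BITWIDTH) if v > INT_MAX else v

def add_nsw(t, r):
    """add nsw: None if either input is None or the sum overflows."""
    if t is None or r is None:
        return None
    exact = t + r
    return exact if INT_MIN <= exact <= INT_MAX else None

def mul(t, r):
    """Plain mul: defined for any pair of values, wraps on overflow."""
    if t is None or r is None:
        return None
    return to_signed(t * r)

def slow(a):
    return add_nsw(a, add_nsw(a, a))

def fast1(a):
    return mul(a, 3)

def check_refined(f, g):
    """True iff g(a) == f(a) for every input a where f(a) is defined."""
    return all(f(a) is None or f(a) == g(a) for a in VALUES)

print(check_refined(slow, fast1))  # True
print(check_refined(fast1, slow))  # False
```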

Next, let’s explore the idea that instcombine could emit a mul nsw instruction.

def fast2(a, b):
    return mul_nsw(a, Opt.some(3))

check_refined(slow, fast2)           # sat
check_refined(fast2, slow)           # sat

Not only does this optimization work, but the refinement property also holds for the converse, indicating that the two formulations are equivalent.
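An exhaustive 8-bit check in plain Python confirms this. The sketch uses None to model an undefined value. It also shows where slow is defined: tripling stays in range only for inputs from -42 to 42, and mul nsw is defined on exactly the same inputs.

```python
BITWIDTH = 8
INT_MIN = -(1 << (BITWIDTH - 1))
INT_MAX = (1 << (BITWIDTH - 1)) - 1
VALUES = [None] + list(range(INT_MIN, INT_MAX + 1))

def add_nsw(t, r):
    if t is None or r is None:
        return None
    exact = t + r
    return exact if INT_MIN <= exact <= INT_MAX else None

def mul_nsw(t, r):
    if t is None or r is None:
        return None
    exact = t * r
    return exact if INT_MIN <= exact <= INT_MAX else None

def slow(a):
    return add_nsw(a, add_nsw(a, a))

def fast2(a):
    return mul_nsw(a, 3)

def check_refined(f, g):
    """True iff g(a) == f(a) for every input a where f(a) is defined."""
    return all(f(a) is None or f(a) == g(a) for a in VALUES)

# Inputs on which slow produces a defined result.
defined = [a for a in VALUES if slow(a) is not None]

print(check_refined(slow, fast2), check_refined(fast2, slow))  # True True
print(min(defined), max(defined))                              # -42 42
```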

We can also ask Z3 to prove equivalence of these different formulations:

prove (slow(a, b) == fast1(a, b))    # counterexample [a = some(137)]
prove (slow(a, b) == fast2(a, b))    # proved

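As expected, fast1 is not equivalent to the original code, and Z3 has additionally given us a value for a that demonstrates this (for the 8-bit case). Z3 prints bitvectors as unsigned numbers, so 137 is -119 as a signed 8-bit value. Tripling -119 overflows, so slow is undefined there while fast1 simply wraps. The second equivalence check passes.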
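The counterexample can also be examined from the other direction. The plain-Python sketch below assumes 8-bit values and uses None for an undefined result. It evaluates both versions at the input Z3 found, then counts every input on which each candidate disagrees with slow.

```python
BITWIDTH = 8
MASK = (1 << BITWIDTH) - 1
INT_MIN = -(1 << (BITWIDTH - 1))
INT_MAX = (1 << (BITWIDTH - 1)) - 1
VALUES = [None] + list(range(INT_MIN, INT_MAX + 1))

def to_signed(v):
    """Interpret the low BITWIDTH bits of v as a two's complement number."""
    v &= MASK
    return v - (1 << BITWIDTH) if v > INT_MAX else v

def in_range(v):
    return INT_MIN <= v <= INT_MAX

def slow(a):
    # add_nsw(a, add_nsw(a, a)): undefined (None) if either add overflows
    if a is None or not in_range(a + a) or not in_range(a + a + a):
        return None
    return a + a + a

def fast1(a):
    # mul a, 3: always defined, wraps on overflow
    return None if a is None else to_signed(a * 3)

def fast2(a):
    # mul nsw a, 3: undefined (None) on signed overflow
    return None if a is None or not in_range(a * 3) else a * 3

a = to_signed(137)            # Z3 printed the bit pattern 137 as unsigned
print(a, slow(a), fast1(a))   # -119 None -101

print(len([x for x in VALUES if slow(x) != fast1(x)]))  # 171
print([x for x in VALUES if slow(x) != fast2(x)])       # []
```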

What have we learned here? First, automated proofs about compiler optimizations in the presence of integer undefined behaviors are not too difficult. This is nice because manual reasoning about undefined behavior is error-prone. Second, there’s room for at least a little bit of improvement in LLVM’s instcombine pass (I’m sure the LLVM people know this). One idea that I quite like is using a solver to improve LLVM’s peephole optimizations by automatically figuring out when qualifiers like “nsw”, “nuw”, and “exact” can be used correctly.
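The following hypothetical sketch illustrates that idea, using brute-force enumeration at 8 bits in place of a solver. It models mul with optional nsw and nuw flags. The names make_mul and valid_flags are illustrative, not LLVM APIs. It keeps every flag combination under which mul a, 3 still refines x+x+x, and then prefers the combination with the most flags, since more flags mean weaker semantics. The exact flag is not modeled.

```python
from itertools import product

BITWIDTH = 8
MASK = (1 << BITWIDTH) - 1
INT_MIN = -(1 << (BITWIDTH - 1))
INT_MAX = (1 << (BITWIDTH - 1)) - 1
VALUES = [None] + list(range(INT_MIN, INT_MAX + 1))

def to_signed(v):
    v &= MASK
    return v - (1 << BITWIDTH) if v > INT_MAX else v

def slow(a):
    # add_nsw(a, add_nsw(a, a)), with None standing for undefined
    if a is None:
        return None
    for partial in (a + a, a + a + a):
        if not INT_MIN <= partial <= INT_MAX:
            return None
    return a + a + a

def make_mul(nsw, nuw):
    """Hypothetical model of mul with optional nsw/nuw flags."""
    def op(t, r):
        if t is None or r is None:
            return None
        if nsw and not INT_MIN <= t * r <= INT_MAX:
            return None  # signed overflow is undefined under nsw
        if nuw and (t & MASK) * (r & MASK) > MASK:
            return None  # unsigned overflow is undefined under nuw
        return to_signed(t * r)
    return op

def valid_flags(spec, const):
    """All (nsw, nuw) choices for which 'mul a, const' refines spec."""
    valid = []
    for nsw, nuw in product([False, True], repeat=2):
        op = make_mul(nsw, nuw)
        if all(spec(a) is None or spec(a) == op(a, const) for a in VALUES):
            valid.append((nsw, nuw))
    return valid

flags = valid_flags(slow, 3)
print(flags)                # [(False, False), (True, False)]
print(max(flags, key=sum))  # (True, False), i.e. mul nsw
```

For this example the search rejects nuw. The original code is defined for x = -1, but an unsigned multiply of 255 by 3 overflows. The search keeps nsw, which matches the Z3 result for fast2.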

Here’s the Python code all in one piece. You can run it locally after installing Z3 or else you can paste it into the Z3 website’s Python interface.


12 responses to “Automated Reasoning About LLVM Optimizations and Undefined Behavior”

  1. Nice post. Theorem provers (and particularly z3’s python bindings) have made work on instcombine much more enjoyable, and I hope it’ll help to reduce the bug count in that pass. While modeling undefined behavior on integer arithmetic isn’t too hard, there’s one major roadblock left: floating point. Proving that an FP peephole is safe by hand is extremely time-consuming and I think LLVM still has a lot of deficiencies there. It sounds like a prime target for a theorem prover but so far they were of no help.

    Z3 has some support for floating point now but I never got it to work in the way I needed to. To make things harder we have more than just plain IEEE754 single/double to support in instcombine. There’s x86’s extended precision (which is rather closely related to the single/double formats) and plainly weird formats such as PPC’s “double double”. I often just excluded those formats because I couldn’t verify that a transformation was legal for them.

    Then there’s -ffast-math which has a lovely loose definition of “allowing things that break IEEE754”. Really hard to formalize something like that 🙁

  2. Hi Benjamin, yeah… I haven’t messed with the FP support in any solver, my assumptions have been that (1) they aren’t ready yet and (2) most interesting transformations are illegal anyway due to associativity problems.

    If you or others have done some work on LLVM+solvers, can you send me a pointer? I’d love to learn more.

  3. It’s funny that FP came up. One of the first major real-world applications of a theorem prover (ACL2) was to prove the correctness of one of the AMD FP div instructions.

  4. I gave a talk last year at LLVM’s developers conference on using SMT solvers (and in particular Z3) to prove some LLVM optimizations correct. The examples I gave included instcombine.
    The slides are available here: http://web.ist.utl.pt/nuno.lopes/pres/llvm-devconf13.pptx

    The way I took advantage of undefined behavior was to prove equivalence of expressions just for the defined behavior, since the optimized expression may do whatever it wants for the undef cases. So basically I assert that signed operations don’t overflow, for example.

    We have been using this technique already. I’m sure you’ll find this LLVM bug report very interesting: http://llvm.org/PR17827

    Regarding floating point, Z3 has a fairly complete implementation for it. It works by transforming floating-point operations into bit-vectors, and then bit-blasts to SAT.
    There are other SMT solvers that already support FPs, like MathSAT, IIRC.

  5. Nuno, I had seen the slides but not this bug report, thanks!

    Something that would be excellent is a library of Z3 or SMTLIB functions implementing the LLVM instruction set. I will work on this. But if you have code to share, Benjamin and Nuno, that would be useful too.

  6. May I ask what is so hard about using SMT Solvers to prove statements about floats? As far as I can tell you only have to model the respective floating point data type precisely. Is there any fundamental difficulty or has it just not been done yet?

  7. tobi, I assume it’s the mix of continuous and discrete mathematics that makes floats difficult. Of course the FP operations can be viewed as circuits but at that point proofs involving real numbers would seem to become difficult.

  8. How hard would it be to re-write some of the optimization passes in LLVM to be generated from a source that can also be verified by Z3 or one of its ilk?

    The last time I took a look at this sort of thing, I ran into endless issues due to the combinatorial explosion the domain I was working in generates.

  9. Hi bcs, of course this is hard for a random LLVM pass but I think it’s totally doable for the instruction combiner, which is a workhorse but also it’s kind of nasty — it was one of the most fruitful sources of bugs for Csmith, for example.