Saturday, July 27, 2024

Do Not Taunt Happy Fun Branch Predictor

I’ve been writing a lot of AArch64 assembly, for reasons.

I recently came up with a “clever” idea to eliminate one jump from an inner
loop, and was surprised to find that it slowed things down. Allow me to explain
my terrible error, so that you don’t fall victim in the future.

A toy model of the relevant code looks something like this:

#include <stddef.h>

static void foo(float f, float* g);

float run(const float* data, size_t n) {
    float g = 0.0f;
    while (n) {
        n--;
        const float f = *data++;
        foo(f, &g);
    }
    return g;
}

static void foo(float f, float* g) {
    // do some stuff, modifying g
}

A simple translation into AArch64 assembly gives something like this:

// x0: const float* data
// x1: size_t n
// Returns a single float in s0

// Prelude: store frame and link registers
stp   x29, x30, [sp, #-16]!

// Initialize g = 0.0
fmov s0, #0.0

loop:
    cmp x1, #0
    b.eq exit
    sub x1, x1, #1
    ldr s1, [x0], #4

    bl foo   // call the function
    b loop   // keep looping

foo:
    // Do some work, reading from s1 and accumulating into s0
    // ...
    ret

exit: // Function exit
    ldp   x29, x30, [sp], #16
    ret

Here, foo is kinda like a naked function: it uses the same stack frame and
registers as the parent function, reads from s1, and writes to s0.

The call to foo uses the bl instruction, which is “branch and link”:
it jumps to the given label, and stores the address of the next instruction in
the link register (lr or x30).

When foo is done, the ret instruction jumps to the address in the link
register, which is the instruction following the original bl.

Looking at this code, I was struck by the fact that it does two branches,
one after the other. Surely, it would be more efficient to only branch once.

I had the clever idea to do so without changing foo:

stp   x29, x30, [sp, #-16]!
fmov s0, #0.0

bl loop // Set up x30 to point to the loop entrance
loop:
    cmp x1, #0
    b.eq exit
    sub x1, x1, #1
    ldr s1, [x0], #4

foo:
    // Do some work, accumulating into `s0`
    // ...
    ret

exit: // Function exit
    ldp   x29, x30, [sp], #16
    ret

This is a little subtle:

  • The initial bl loop stores the address of the loop block in x30
  • After checking for loop termination, we fall through into the foo function
    (without a branch!)
  • foo still ends with ret, which returns to the loop block (because
    that’s what’s in x30).

Within the body of the loop, we never change x30, so the repeated ret
instructions always return to the same place.
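To check that the two layouts really compute the same thing, here is a minimal sketch in Rust (not from the original post). It is a toy interpreter for just the handful of instructions in the loop. Insn, run, ORIGINAL and CLEVER are hypothetical names. Program indices stand in for instruction addresses, and the stack-frame prologue and epilogue are not modeled. Alongside the sum, it counts how many bl and ret instructions execute:

```rust
/// Just enough of AArch64 to model the two loop layouts. Indices into the
/// program stand in for instruction addresses.
#[derive(Clone, Copy)]
enum Insn {
    /// `cmp x1, #0; b.eq target`: jump to `target` once no elements remain.
    BranchIfDone(usize),
    /// `sub x1, x1, #1; ldr s1, [x0], #4`: load the next element into s1.
    Load,
    /// `fadd s0, s0, s1`: the body of `foo`.
    Fadd,
    /// `bl target`: store the next index in the link register, then jump.
    Bl(usize),
    /// `b target`: unconditional jump.
    B(usize),
    /// `ret`: jump to the index held in the link register.
    Ret,
    /// Function exit (the stack-frame epilogue is not modeled).
    Exit,
}

/// Original layout: `bl foo` followed by `b loop` on every iteration.
const ORIGINAL: [Insn; 7] = [
    Insn::BranchIfDone(6), // 0: loop
    Insn::Load,            // 1
    Insn::Bl(4),           // 2: bl foo
    Insn::B(0),            // 3: b loop
    Insn::Fadd,            // 4: foo
    Insn::Ret,             // 5
    Insn::Exit,            // 6: exit
];

/// "Clever" layout: a single `bl loop`, then fall through into `foo`.
const CLEVER: [Insn; 6] = [
    Insn::Bl(1),           // 0: bl loop (link register now points at 1)
    Insn::BranchIfDone(5), // 1: loop
    Insn::Load,            // 2
    Insn::Fadd,            // 3: foo
    Insn::Ret,             // 4: returns to loop
    Insn::Exit,            // 5: exit
];

/// Runs `program` over `data`, returning (sum, bl count, ret count).
fn run(program: &[Insn], data: &[f32]) -> (f32, usize, usize) {
    let (mut pc, mut lr, mut i) = (0, 0, 0);
    let (mut s0, mut s1) = (0.0f32, 0.0f32);
    let (mut bls, mut rets) = (0, 0);
    loop {
        match program[pc] {
            Insn::BranchIfDone(target) => {
                pc = if i == data.len() { target } else { pc + 1 };
            }
            Insn::Load => {
                s1 = data[i];
                i += 1;
                pc += 1;
            }
            Insn::Fadd => {
                s0 += s1;
                pc += 1;
            }
            Insn::Bl(target) => {
                lr = pc + 1;
                bls += 1;
                pc = target;
            }
            Insn::B(target) => pc = target,
            Insn::Ret => {
                rets += 1;
                pc = lr;
            }
            Insn::Exit => return (s0, bls, rets),
        }
    }
}
```

Both layouts produce the same sum. The difference is in the bookkeeping. The original executes one bl per ret, while the "clever" version executes a single bl followed by one ret per element.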

I set up a benchmark using a very simple foo:

foo:
    fadd s0, s0, s1
    ret

With this foo, the function as a whole sums the incoming array of float
values.

Benchmarking with criterion (on an M1 Max CPU), with a 1024-element array:

Program        Time
Original       969 ns
“Optimized”    3.85 µs

The “optimized” code with one jump per loop is about 4x slower
than the original version with two jumps per loop!

I found this surprising, so I asked a few colleagues about it.

Between Cliff and Dan, the consensus was that mismatched bl / ret pairs were
confusing the branch predictor.

The ARM documentation agrees:

Why do we need a special function return instruction? Functionally, BR LR
would do the same job as RET. Using RET tells the processor that this is a
function return. Most modern processors, and all Cortex-A processors, support
branch prediction. Knowing that this is a function return allows processors to
more accurately predict the branch.

Branch predictors guess the direction the program flow will take across
branches. The guess is used to decide what to load into a pipeline with
instructions waiting to be processed. If the branch predictor guesses
correctly, the pipeline has the correct instructions and the processor does
not have to wait for instructions to be loaded from memory.

More specifically, the branch predictor probably keeps an internal stack of
function return addresses (commonly called a return address stack), which is
pushed to whenever a bl is executed. When the branch predictor sees a ret coming
down the pipeline, it assumes that you’re returning to the address associated
with the most recent bl (and begins prefetching / speculative execution /
whatever), then pops that top address from its internal stack.

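This works if you’ve got matched bl / ret pairs, but the prediction will
fail if a single bl is followed by many executions of ret. After the first ret
pops the only address that bl pushed, the predictor’s stack no longer holds
the right target for the rets that follow. You’ll end up with (vague
handwaving) useless prefetching, incorrect speculative execution, and pipeline
stalls / flushes.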
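The return-stack theory is easy to play with in a toy model. The Rust sketch below is a deliberate simplification, not Apple's actual design. ReturnStack, Event, mispredictions, matched, mismatched and the addresses 0x1004 / 0x100c are all hypothetical. Each bl pushes its return address onto a bounded stack, and each ret is predicted by popping the top entry:

```rust
use std::collections::VecDeque;

/// A toy return-address stack with a fixed capacity. When it is full, a new
/// push discards the oldest entry (one plausible policy; real hardware may differ).
struct ReturnStack {
    entries: VecDeque<u64>,
    capacity: usize,
}

impl ReturnStack {
    fn new(capacity: usize) -> Self {
        ReturnStack { entries: VecDeque::with_capacity(capacity), capacity }
    }

    /// On `bl`: remember the address the matching `ret` should return to.
    fn on_call(&mut self, return_addr: u64) {
        if self.entries.len() == self.capacity {
            self.entries.pop_front();
        }
        self.entries.push_back(return_addr);
    }

    /// On `ret`: predict by popping the newest entry; true if the guess was right.
    fn on_ret(&mut self, actual_target: u64) -> bool {
        self.entries.pop_back() == Some(actual_target)
    }
}

/// Control-flow events, each carrying the return address involved.
enum Event {
    Call(u64),
    Ret(u64),
}

/// Counts mispredicted returns for a stream of events.
fn mispredictions(events: &[Event], capacity: usize) -> usize {
    let mut ras = ReturnStack::new(capacity);
    let mut misses = 0;
    for event in events {
        match *event {
            Event::Call(addr) => ras.on_call(addr),
            Event::Ret(target) => {
                if !ras.on_ret(target) {
                    misses += 1;
                }
            }
        }
    }
    misses
}

/// Matched pairs: every iteration does `bl foo`, then `ret` back to `b loop`.
fn matched(n: usize) -> Vec<Event> {
    const AFTER_BL: u64 = 0x100c; // hypothetical address of `b loop`
    (0..n)
        .flat_map(|_| [Event::Call(AFTER_BL), Event::Ret(AFTER_BL)])
        .collect()
}

/// One `bl loop` up front, then one `ret` to `loop` per element.
fn mismatched(n: usize) -> Vec<Event> {
    const LOOP: u64 = 0x1004; // hypothetical address of `loop:`
    let mut events = vec![Event::Call(LOOP)];
    events.extend((0..n).map(|_| Event::Ret(LOOP)));
    events
}
```

In this model, matched pairs never miss, while the mismatched stream misses every ret after the first. The measured rate on the M1 is lower than that (about 93%), so the real predictor clearly differs in the details. The model only captures the basic mismatch.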

Dan made the great suggestion of replacing ret with br x30 to test this
theory. Sure enough, this fixes the performance regression:

Program                Time
Matched bl / ret       969 ns
One bl, many ret       3.85 µs
One bl, many br x30    913 ns

In fact, it’s slightly faster, probably because it’s only doing one branch
per loop instead of two!

To further test the “branch predictor” theory, I opened up Instruments and
examined performance counters for the first two programs. Picking out the worst
offenders, the results seem conclusive:

Counter                               Matched bl / ret  One bl, many ret
BRANCH_RET_INDIR_MISPRED_NONSPECIFIC                92       928,644,975
FETCH_RESTART                                   61,121       987,765,276
MAP_DISPATCH_BUBBLE                          1,155,632     7,350,085,139
MAP_REWIND                                   6,412,734     2,789,499,545

These measurements are captured while summing an array of 1B elements. We see
that with mismatched bl / ret pairs, the return branch predictor fails about
93% of the time!
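The 93% figure is the misprediction count divided by the number of ret instructions executed. Here is a short Rust sketch of that arithmetic for every counter in the table. It rests on two assumptions:

- "1B" means exactly 10^9 elements.
- Each element executes one ret (the final function-exit ret is negligible).

COUNTERS, ELEMENTS, per_element and rates are hypothetical names:

```rust
/// Raw counter totals from the table: (name, matched bl/ret, one bl + many ret).
const COUNTERS: [(&str, u64, u64); 4] = [
    ("BRANCH_RET_INDIR_MISPRED_NONSPECIFIC", 92, 928_644_975),
    ("FETCH_RESTART", 61_121, 987_765_276),
    ("MAP_DISPATCH_BUBBLE", 1_155_632, 7_350_085_139),
    ("MAP_REWIND", 6_412_734, 2_789_499_545),
];

/// Assumption: "1B elements" means exactly 10^9, with one `ret` per element.
const ELEMENTS: u64 = 1_000_000_000;

/// Converts a raw counter total into events per array element.
fn per_element(count: u64) -> f64 {
    count as f64 / ELEMENTS as f64
}

/// Per-element rates for both programs: (name, matched, mismatched).
fn rates() -> Vec<(&'static str, f64, f64)> {
    COUNTERS
        .iter()
        .map(|&(name, matched, mismatched)| (name, per_element(matched), per_element(mismatched)))
        .collect()
}
```

Under those assumptions, BRANCH_RET_INDIR_MISPRED_NONSPECIFIC works out to about 0.93 per element in the mismatched program. For matched pairs, it is fewer than one in ten million. MAP_DISPATCH_BUBBLE works out to about 7.35 per element in the mismatched program.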

Apple doesn’t fully document these counters, but I’m guessing that the other
counters are downstream effects of bad branch prediction:

  • FETCH_RESTART is presumably bad prefetching
  • MAP_DISPATCH_BUBBLE probably refers to pipeline stalls
  • MAP_REWIND might be bad speculative execution that needs to be rewound

In conclusion, do not taunt happy fun branch predictor with asymmetric usage of
bl and ret instructions.


Appendix: Going Fast

Take a second look at this program:

stp   x29, x30, [sp, #-16]!
fmov s0, #0.0

loop:
    cmp x1, #0
    b.eq exit
    sub x1, x1, #1
    ldr s1, [x0], #4

    bl foo   // call the function
    b loop   // keep looping

foo:
    fadd s0, s0, s1
    ret

exit: // Function exit
    ldp   x29, x30, [sp], #16
    ret

Upon seeing this program, it’s a common reaction to ask “why is foo a
subroutine at all?”

The answer is “because this is a didactic example, not code that’s trying
to go as fast as possible”.

Still, it’s a fair question. You wanna go fast? Let’s go fast.

If we know the contents of foo when building this
function (and it’s shorter than the maximum jump distance), we can remove the
bl and ret entirely:

stp   x29, x30, [sp, #-16]!
fmov s0, #0.0

loop:
    cmp x1, #0
    b.eq exit
    sub x1, x1, #1
    ldr s1, [x0], #4

    // foo is completely inlined here
    fadd s0, s0, s1

    b loop

exit: // Function exit
    ldp   x29, x30, [sp], #16
    ret

This is a roughly 6% speedup: from 969 ns to 911 ns.

We can get faster still by trusting the compiler:

pub fn sum_slice(f: &[f32]) -> f32 {
    f.iter().sum()
}

This brings us down to 833 ns, a significant improvement!

Looking at the assembly, it’s doing some loop unrolling.
However, even when compiled with -C target-cpu=native, it’s not generating
NEON SIMD instructions.
Can we beat it?

We sure can!

stp   x29, x30, [sp, #-16]!

fmov s0, #0.0
dup v1.4s, v0.s[0]
dup v2.4s, v0.s[0]

loop:  // 1x per loop
    ands xzr, x1, #3
    b.eq simd

    sub x1, x1, #1
    ldr s3, [x0], #4

    fadd s0, s0, s3
    b loop

simd:  // 4x SIMD per loop
    ands xzr, x1, #7
    b.eq simd2

    sub x1, x1, #4
    ldp d3, d4, [x0], #16
    mov v3.d[1], v4.d[0]

    fadd v1.4s, v1.4s, v3.4s

    b simd

simd2:  // 2 x 4x SIMD per loop
    cmp x1, #0
    b.eq exit

    sub x1, x1, #8

    ldp d3, d4, [x0], #16
    mov v3.d[1], v4.d[0]
    fadd v1.4s, v1.4s, v3.4s

    ldp d5, d6, [x0], #16
    mov v5.d[1], v6.d[0]
    fadd v2.4s, v2.4s, v5.4s

    b simd2

exit: // function exit
    fadd v2.4s, v2.4s, v1.4s
    mov s1, v2.s[0]
    fadd s0, s0, s1
    mov s1, v2.s[1]
    fadd s0, s0, s1
    mov s1, v2.s[2]
    fadd s0, s0, s1
    mov s1, v2.s[3]
    fadd s0, s0, s1

    ldp   x29, x30, [sp], #16
    ret

This code includes three different loops:

  • The first loop (loop) sums individual values
    into s0 until we have a multiple of four values remaining
  • The second loop (simd) uses SIMD instructions to sum 4 values at a time
    into the vector register v1, until we have a multiple of 8 values remaining
  • The last loop (simd2) is the same as simd, but is unrolled 2x so it
    handles 8 values per loop iteration, summing into v1 and v2

At the function exit, we accumulate the values in the vector registers v1/v2
into s0, which is returned.
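To make the order of additions concrete, here is a scalar Rust model of the routine above. It is a sketch, not code from the post. simd_order_sum and add_lanes are hypothetical names, and four-element arrays stand in for the vector registers v1 and v2. Each array slot receives the same additions as the corresponding SIMD lane. So, assuming the default rounding mode, it should reproduce the assembly's result exactly for ordinary inputs, without using SIMD at all:

```rust
/// Scalar model of the hand-written SIMD sum, performing the same additions
/// in the same order. Arrays of four f32 stand in for v1 and v2.
fn simd_order_sum(data: &[f32]) -> f32 {
    let mut s0 = 0.0f32;
    let mut v1 = [0.0f32; 4];
    let mut v2 = [0.0f32; 4];
    let mut rest = data;

    // `loop`: peel single elements into s0 until the count is a multiple of 4.
    while rest.len() % 4 != 0 {
        s0 += rest[0];
        rest = &rest[1..];
    }
    // `simd`: one 4-wide add into v1 until the count is a multiple of 8
    // (runs at most once, since the count is already a multiple of 4).
    if rest.len() % 8 != 0 {
        add_lanes(&mut v1, &rest[..4]);
        rest = &rest[4..];
    }
    // `simd2`: two independent 4-wide adds per iteration, into v1 and v2.
    for chunk in rest.chunks_exact(8) {
        add_lanes(&mut v1, &chunk[..4]);
        add_lanes(&mut v2, &chunk[4..]);
    }
    // `exit`: v2 += v1 lane-wise, then add the four lanes into s0 in order.
    add_lanes(&mut v2, &v1);
    for lane in v2 {
        s0 += lane;
    }
    s0
}

/// Lane-wise accumulate, like `fadd vN.4s, vN.4s, v3.4s`.
fn add_lanes(acc: &mut [f32; 4], x: &[f32]) {
    for (a, b) in acc.iter_mut().zip(x) {
        *a += *b;
    }
}
```

A model like this can serve as a reference when testing the hand-written assembly, since the two should agree exactly rather than only approximately.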

The type punning here is particularly cute:

ldp d3, d4, [x0], #16
mov v3.d[1], v4.d[0]
fadd v1.4s, v1.4s, v3.4s

Remember, x0 holds a float*. We pretend that it’s a double* to load 128
bits (i.e. 4x float values) into d3 and d4. Then, we move the “double” in d4
to occupy the top 64 bits of the v3 vector register (of which d3 is the
lower 64 bits).

Of course, each “double” is two floats, but that doesn’t matter when shuffling
them around. When summing with fadd, we tell the processor to treat them as
four floats (the .4s suffix), and everything works out fine.
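The same shuffle can be mimicked at the bit level in Rust. This is a hypothetical model: load_as_u64_pair, as_f32x4 and fadd_4s are made-up names, and it assumes little-endian lane order, as on typical AArch64 systems. Two 64-bit words play the roles of d3 and d4, which become the two halves of v3 after the mov. Reading them back as four f32 lanes recovers the original floats unchanged:

```rust
/// Packs four f32 values into two 64-bit words, the way `ldp d3, d4` loads
/// them: on a little-endian machine the lower-addressed float of each pair
/// ends up in the low 32 bits.
fn load_as_u64_pair(data: &[f32; 4]) -> [u64; 2] {
    let word = |i: usize| data[i].to_bits() as u64;
    [word(0) | (word(1) << 32), word(2) | (word(3) << 32)]
}

/// Reads a 128-bit value, stored as two 64-bit halves, as four f32 lanes
/// (what the `.4s` arrangement does).
fn as_f32x4(halves: [u64; 2]) -> [f32; 4] {
    let mut lanes = [0.0f32; 4];
    for (i, lane) in lanes.iter_mut().enumerate() {
        // Lanes 0 and 1 come from the low half, lanes 2 and 3 from the high half.
        let bits = (halves[i / 2] >> (32 * (i % 2))) as u32;
        *lane = f32::from_bits(bits);
    }
    lanes
}

/// Lane-wise addition, like `fadd v1.4s, v1.4s, v3.4s`.
fn fadd_4s(a: [f32; 4], b: [f32; 4]) -> [f32; 4] {
    [a[0] + b[0], a[1] + b[1], a[2] + b[2], a[3] + b[3]]
}
```

The "double" view and the "four floats" view are the same 128 bits, so nothing is lost by moving them around as 64-bit chunks.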

How fast are we now?

This runs in 94 ns, or about 8.8x faster than our previous best.

Here’s a summary of performance:

Program                         Time
Matched bl / ret                969 ns
One bl, many ret                3.85 µs
One bl, many br x30             913 ns
Plain loop with b               911 ns
Rewrite it in Rust              833 ns
SIMD + manual loop unrolling    94 ns

Could we get even faster? I’m sure it’s possible; I make no claims to being
the Agner Fog of AArch64 assembly.

Still, this is a reasonable point to wrap up: we’ve demystified the initial
performance regression, and had some fun hand-writing assembly to go very
fast indeed.

The SIMD code does come with one asterisk, though: because floating-point
addition is not associative, and the SIMD code performs the summation in a
different order, it may not get the same result as straight-line code. In
retrospect, this is likely why the compiler doesn’t generate SIMD instructions
to compute the sum!
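Here is a Rust sketch of both halves of that asterisk. The names are hypothetical, and the lane-accumulator technique is swapped in here rather than measured in this post.

- sum_sequential adds left to right, like the scalar code.
- sum_lanes keeps LANES independent accumulators, which is the same kind of reassociation the SIMD code performs. Unlike the assembly, it handles leftover elements at the end rather than the start.

Writing the reordering out explicitly means the optimizer no longer has to invent it. That may let it vectorize the loop, though whether it actually does is up to the compiler:

```rust
/// Left-to-right summation, the order straight-line scalar code uses.
fn sum_sequential(data: &[f32]) -> f32 {
    let mut total = 0.0f32;
    for x in data {
        total += *x;
    }
    total
}

/// Summation with LANES independent accumulators, an explicit reassociation
/// like the one the hand-written SIMD code performs. Leftover elements that
/// don't fill a whole chunk are added at the end.
fn sum_lanes<const LANES: usize>(data: &[f32]) -> f32 {
    let mut acc = [0.0f32; LANES];
    let chunks = data.chunks_exact(LANES);
    let tail = chunks.remainder();
    for chunk in chunks {
        for (a, x) in acc.iter_mut().zip(chunk) {
            *a += *x;
        }
    }
    // Fold the lanes together, then the leftovers, in order.
    let mut total = 0.0f32;
    for a in acc {
        total += a;
    }
    for x in tail {
        total += *x;
    }
    total
}
```

For data = [1.0e8, 1.0, -1.0e8, 0.0], sum_sequential returns 0.0, because the 1.0 is absorbed into 1.0e8 before the cancellation. sum_lanes::<2> returns 1.0: same inputs, different order, different answer.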

Does this matter for your use case? Only you can know!


All of the code from this post is published to GitHub.

You can reproduce benchmarks by running cargo bench on an ARM64 machine.
