Autovectorised FMA in JDK10
Fused-multiply-add (FMA) allows floating point expressions of the form a * x + b to be evaluated in a single instruction, which is useful for numerical linear algebra. Despite the obvious appeal of FMA, JVM implementors are rather constrained when it comes to floating point arithmetic, because Java programs are expected to be reproducible across versions and target architectures. FMA does not produce precisely the same result as the equivalent multiplication and addition instructions, because it rounds only once where the separate instructions round twice, so its use is a change in semantics rather than an optimisation; the user must opt in. To the best of my knowledge, support for FMA was first proposed in 2000, along with reorderable floating point operations which would have been activated by a fastfp keyword, but that proposal was withdrawn. In Java 9, the intrinsic Math.fma was introduced to provide access to FMA for the first time.
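To make the change in semantics concrete, here is a small sketch (the class name and the choice of Math.PI are purely illustrative) that recovers the rounding error of a plain double multiplication via Math.fma; with separate multiply and add instructions, which each round, this residual would simply vanish.

public class FmaRounding {
  public static void main(String[] args) {
    double a = Math.PI;
    double product = a * a;                     // multiply, rounded to the nearest double
    double residual = Math.fma(a, a, -product); // exact a * a minus the rounded product, rounded once
    // For most inputs, including this one, the residual is nonzero,
    // so Math.fma(a, x, b) and a * x + b can differ.
    System.out.println(residual);
  }
}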
DAXPY Benchmark
A good use case to evaluate Math.fma is DAXPY (y := a * x + y) from the Basic Linear Algebra Subprograms (BLAS). The code below will compile with JDK9+.
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;

@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Thread)
public class DAXPY {

  // scalar coefficient, refreshed before every invocation
  double s;

  @Setup(Level.Invocation)
  public void init() {
    s = ThreadLocalRandom.current().nextDouble();
  }

  @Benchmark
  public void daxpyFMA(DoubleData state, Blackhole bh) {
    double[] a = state.data1;
    double[] b = state.data2;
    for (int i = 0; i < a.length; ++i) {
      a[i] = Math.fma(s, b[i], a[i]);
    }
    bh.consume(a);
  }

  @Benchmark
  public void daxpy(DoubleData state, Blackhole bh) {
    double[] a = state.data1;
    double[] b = state.data2;
    for (int i = 0; i < a.length; ++i) {
      a[i] += s * b[i];
    }
    bh.consume(a);
  }
}
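The DoubleData state class is not shown here; a minimal sketch of what it could look like, assuming it just holds two arrays of random doubles whose length is the size parameter appearing in the result tables below, is:

import java.util.concurrent.ThreadLocalRandom;

import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;

@State(Scope.Thread)
public class DoubleData {

  @Param({"1000", "10000", "100000", "1000000"})
  int size;

  double[] data1;
  double[] data2;

  // Refresh the arrays each iteration so the in-place updates in the
  // benchmarks do not let the values drift over the whole run.
  @Setup(Level.Iteration)
  public void init() {
    data1 = randomArray();
    data2 = randomArray();
  }

  private double[] randomArray() {
    double[] data = new double[size];
    for (int i = 0; i < data.length; ++i) {
      data[i] = ThreadLocalRandom.current().nextDouble();
    }
    return data;
  }
}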
Running this benchmark with Java 9, you may wonder why you bothered: the Math.fma version is actually slower.
Benchmark | Mode | Threads | Samples | Score | Score Error (99.9%) | Unit | Param: size |
---|---|---|---|---|---|---|---|
daxpy | thrpt | 1 | 10 | 25.011242 | 2.259007 | ops/ms | 100000 |
daxpy | thrpt | 1 | 10 | 0.706180 | 0.046146 | ops/ms | 1000000 |
daxpyFMA | thrpt | 1 | 10 | 15.334652 | 0.271946 | ops/ms | 100000 |
daxpyFMA | thrpt | 1 | 10 | 0.623838 | 0.018041 | ops/ms | 1000000 |
This is because using Math.fma disables autovectorisation. Taking a look at the PrintAssembly output, you can see that the naive daxpy routine exploits AVX2, whereas daxpyFMA reverts to scalar FMA instructions operating on one element at a time.
// daxpy routine, code taken from main vectorised loop
vmovdqu ymm1,ymmword ptr [r10+rdx*8+10h]
vmulpd  ymm1,ymm1,ymm2
vaddpd  ymm1,ymm1,ymmword ptr [r8+rdx*8+10h]
vmovdqu ymmword ptr [r8+rdx*8+10h],ymm1

// daxpyFMA routine
vmovsd      xmm2,qword ptr [rcx+r13*8+10h]
vfmadd231sd xmm2,xmm0,xmm1
vmovsd      qword ptr [rcx+r13*8+10h],xmm2
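For reference, one way to capture output like the above (a sketch; it assumes the hsdis disassembler library is installed for your JVM, since without it HotSpot disables PrintAssembly) is to append the diagnostic flags to the forked benchmark JVM:

import org.openjdk.jmh.annotations.Fork;

// Ask the forked JVM to print the JIT-compiled assembly for the benchmark methods.
@Fork(value = 1, jvmArgsAppend = {
    "-XX:+UnlockDiagnosticVMOptions",
    "-XX:+PrintAssembly"
})
public class DAXPY {
  // benchmark methods as shown earlier
}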
Not to worry: this seems to have been fixed in JDK 10. Since Java 10’s release is around the corner, there are early access builds available for all platforms. Rerunning this benchmark, FMA no longer incurs a cost, but it doesn’t bring the performance boost some people might expect either. The benefit is that there is less floating point error, because the total number of roundings is halved.
Benchmark | Mode | Threads | Samples | Score | Score Error (99.9%) | Unit | Param: size |
---|---|---|---|---|---|---|---|
daxpy | thrpt | 1 | 10 | 2582.363228 | 116.637400 | ops/ms | 1000 |
daxpy | thrpt | 1 | 10 | 405.904377 | 32.364782 | ops/ms | 10000 |
daxpy | thrpt | 1 | 10 | 25.210111 | 1.671794 | ops/ms | 100000 |
daxpy | thrpt | 1 | 10 | 0.608660 | 0.112512 | ops/ms | 1000000 |
daxpyFMA | thrpt | 1 | 10 | 2650.264580 | 211.342407 | ops/ms | 1000 |
daxpyFMA | thrpt | 1 | 10 | 389.274693 | 43.567450 | ops/ms | 10000 |
daxpyFMA | thrpt | 1 | 10 | 24.941172 | 2.393358 | ops/ms | 100000 |
daxpyFMA | thrpt | 1 | 10 | 0.671310 | 0.158470 | ops/ms | 1000000 |
// vectorised daxpyFMA routine, code taken from main loop
// (you can still see the old code in pre/post loops)
vmovdqu     ymm0,ymmword ptr [r9+r13*8+10h]
vfmadd231pd ymm0,ymm1,ymmword ptr [rbx+r13*8+10h]
vmovdqu     ymmword ptr [r9+r13*8+10h],ymm0
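If you want to reproduce these numbers yourself on a JDK 10 early access build, a minimal JMH runner for the benchmarks above might look like the sketch below (the parameter name size is assumed from the result tables).

import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

public class DAXPYRunner {
  public static void main(String[] args) throws RunnerException {
    // Run both daxpy benchmarks across the array sizes used in the tables above.
    Options options = new OptionsBuilder()
        .include(DAXPY.class.getSimpleName())
        .param("size", "1000", "10000", "100000", "1000000")
        .forks(1)
        .build();
    new Runner(options).run();
  }
}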
Paul Sandoz discussed Math.fma at Oracle Code One 2018.