Multiplying Matrices, Fast and Slow
I recently read a very interesting blog post about exposing Intel SIMD intrinsics via a fork of the Scala compiler (scala-virtualized), which reports multiplicative improvements in throughput over HotSpot JIT-compiled code. The academic paper (SIMD Intrinsics on Managed Language Runtimes), which has been accepted at CGO 2018, proposes a powerful alternative to the traditional JVM approach of pairing dumb programmers with a (hopefully) smart JIT compiler. Lightweight Modular Staging (LMS) allows the generation of an executable binary from a high-level representation: handcrafted representations of vectorised algorithms, written in a dialect of Scala, can be compiled natively and later invoked with a single JNI call. This approach bypasses C2 without incurring excessive JNI costs. The freely available benchmarks can be easily run to reproduce the results in the paper, which is an achievement in itself, but some of the Java implementations used as baselines look less efficient than they could be. This post is about improving the efficiency of the Java matrix multiplication the LMS generated code is benchmarked against. Despite hitting edge cases where autovectorisation fails, I find it is possible to get performance comparable to LMS with plain Java (and a JDK upgrade).
Two implementations of Java matrix multiplication are provided in the NGen benchmarks: JMMM.baseline, a naive but cache-unfriendly matrix multiplication, and JMMM.blocked, which is supplied as an improvement. JMMM.blocked is something of a local maximum because it does manual loop unrolling, which actually removes the trigger for autovectorisation analysis. I provide a simple and cache-efficient Java implementation (with the same asymptotic complexity; the improvement is purely technical) and benchmark these implementations separately on JDK8 and the soon-to-be-released JDK10.
public void fast(float[] a, float[] b, float[] c, int n) {
  int in = 0;
  for (int i = 0; i < n; ++i) {
    int kn = 0;
    for (int k = 0; k < n; ++k) {
      float aik = a[in + k];
      for (int j = 0; j < n; ++j) {
        c[in + j] += aik * b[kn + j];
      }
      kn += n;
    }
    in += n;
  }
}
With JDK 1.8.0_131, the “fast” implementation is only about 2x faster than the blocked algorithm; this is nowhere near enough to match LMS. In fact, LMS does a lot better than 5x blocked (6x-8x) on my Skylake laptop at 2.6GHz, and performs between 2x and 4x better than the improved implementation. Flops / Cycle is calculated as 2 * size^3 * throughput (ops/s) / CPU frequency (Hz).
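As a sanity check on that formula, here is a minimal sketch of the arithmetic, using the JDK10 baseline measurement for 512x512 matrices reported further down (2.6GHz is the nominal clock of my laptop, so treat the result as approximate):
public class FlopsPerCycle {
  public static void main(String[] args) {
    // One n x n multiplication performs 2 * n^3 floating point operations:
    // a multiply and an add per element of each inner product.
    double n = 512;
    double throughput = 3.545153; // ops/s: the JDK10 baseline score for size 512 below
    double frequencyHz = 2.6e9;   // nominal clock of my Skylake laptop
    double flopsPerInvocation = 2 * n * n * n;
    System.out.println(flopsPerInvocation * throughput / frequencyHz); // ~0.366 flops/cycle
  }
}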
====================================================
Benchmarking MMM.jMMM.fast (JVM implementation)
----------------------------------------------------
Size (N) | Flops / Cycle
----------------------------------------------------
       8 | 0.4994459272
      32 | 1.0666533335
      64 | 0.9429120397
     128 | 0.9692385519
     192 | 0.9796619688
     256 | 1.0141446247
     320 | 0.9894415771
     384 | 1.0046245750
     448 | 1.0221353392
     512 | 0.9943527764
     576 | 0.9952093603
     640 | 0.9854689714
     704 | 0.9947153752
     768 | 1.0197765248
     832 | 1.0479691069
     896 | 1.0060121097
     960 | 0.9937347412
    1024 | 0.9056494897
====================================================
====================================================
Benchmarking MMM.nMMM.blocked (LMS generated)
----------------------------------------------------
Size (N) | Flops / Cycle
----------------------------------------------------
       8 | 0.2500390686
      32 | 3.9999921875
      64 | 4.1626523901
     128 | 4.4618695374
     192 | 3.9598982956
     256 | 4.3737341517
     320 | 4.2412225389
     384 | 3.9640163416
     448 | 4.0957167537
     512 | 3.3801071278
     576 | 4.1869326167
     640 | 3.8225244883
     704 | 3.8648224140
     768 | 3.5240611589
     832 | 3.7941562681
     896 | 3.1735179981
     960 | 2.5856903789
    1024 | 1.7817152313
====================================================
====================================================
Benchmarking MMM.jMMM.blocked (JVM implementation)
----------------------------------------------------
Size (N) | Flops / Cycle
----------------------------------------------------
       8 | 0.3333854248
      32 | 0.6336670915
      64 | 0.5733484649
     128 | 0.5987433798
     192 | 0.5819900921
     256 | 0.5473562109
     320 | 0.5623263520
     384 | 0.5583823292
     448 | 0.5657882256
     512 | 0.5430879470
     576 | 0.5269635678
     640 | 0.5595204791
     704 | 0.5297557807
     768 | 0.5493631388
     832 | 0.5471832673
     896 | 0.4769554752
     960 | 0.4985080443
    1024 | 0.4014589400
====================================================
JDK10 is about to be released, so it's worth looking at the effect of recent improvements to C2, including better use of AVX2 and support for vectorised FMA. Since LMS depends on scala-virtualized, which currently only supports Scala 2.11, the LMS implementation cannot be run with a more recent JDK, so its performance on JDK10 can only be extrapolated. Since its raison d'être is to bypass C2, it can reasonably be assumed to be insulated from JVM performance improvements (or regressions). Measurements of floating point operations per cycle provide a sensible comparison in any case.
Moving away from ScalaMeter, I created a JMH benchmark to see how matrix multiplication behaves in JDK10.
import java.util.Arrays;
import java.util.concurrent.TimeUnit;

import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.infra.Blackhole;

@OutputTimeUnit(TimeUnit.SECONDS)
@State(Scope.Benchmark)
public class MMM {

  @Param({"8", "32", "64", "128", "192", "256", "320", "384", "448", "512", "576", "640", "704", "768", "832", "896", "960", "1024"})
  int size;

  private float[] a;
  private float[] b;
  private float[] c;

  @Setup(Level.Trial)
  public void init() {
    a = DataUtil.createFloatArray(size * size);
    b = DataUtil.createFloatArray(size * size);
    c = new float[size * size];
  }
  @Benchmark
  public void fast(Blackhole bh) {
    fast(a, b, c, size);
    bh.consume(c);
  }

  @Benchmark
  public void baseline(Blackhole bh) {
    baseline(a, b, c, size);
    bh.consume(c);
  }

  @Benchmark
  public void blocked(Blackhole bh) {
    blocked(a, b, c, size);
    bh.consume(c);
  }
  //
  // Baseline implementation of a Matrix-Matrix-Multiplication
  //
  public void baseline(float[] a, float[] b, float[] c, int n) {
    for (int i = 0; i < n; i += 1) {
      for (int j = 0; j < n; j += 1) {
        float sum = 0.0f;
        for (int k = 0; k < n; k += 1) {
          sum += a[i * n + k] * b[k * n + j];
        }
        c[i * n + j] = sum;
      }
    }
  }
  //
  // Blocked version of MMM, reference implementation available at:
  // http://csapp.cs.cmu.edu/2e/waside/waside-blocking.pdf
  //
  public void blocked(float[] a, float[] b, float[] c, int n) {
    int BLOCK_SIZE = 8;
    for (int kk = 0; kk < n; kk += BLOCK_SIZE) {
      for (int jj = 0; jj < n; jj += BLOCK_SIZE) {
        for (int i = 0; i < n; i++) {
          for (int j = jj; j < jj + BLOCK_SIZE; ++j) {
            float sum = c[i * n + j];
            for (int k = kk; k < kk + BLOCK_SIZE; ++k) {
              sum += a[i * n + k] * b[k * n + j];
            }
            c[i * n + j] = sum;
          }
        }
      }
    }
  }
  public void fast(float[] a, float[] b, float[] c, int n) {
    int in = 0;
    for (int i = 0; i < n; ++i) {
      int kn = 0;
      for (int k = 0; k < n; ++k) {
        float aik = a[in + k];
        for (int j = 0; j < n; ++j) {
          c[in + j] = Math.fma(aik, b[kn + j], c[in + j]);
        }
        kn += n;
      }
      in += n;
    }
  }
}
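For reference, here is a minimal sketch of how such a benchmark class can be launched from code via the JMH runner API; the iteration counts are illustrative, not necessarily the settings that produced the results below:
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

public class MMMRunner {
  public static void main(String[] args) throws RunnerException {
    // Run every benchmark method in MMM in a single fork, measuring throughput in ops/s.
    Options options = new OptionsBuilder()
        .include(MMM.class.getSimpleName())
        .warmupIterations(5)       // illustrative
        .measurementIterations(10) // matches the 10 samples reported below
        .forks(1)
        .build();
    new Runner(options).run();
  }
}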
Benchmark | Mode | Threads | Samples | Score | Score Error (99.9%) | Unit | Param: size | Ratio to blocked | Flops/Cycle |
---|---|---|---|---|---|---|---|---|---|
baseline | thrpt | 1 | 10 | 1228544.82 | 38793.17392 | ops/s | 8 | 1.061598336 | 0.483857652 |
baseline | thrpt | 1 | 10 | 22973.03402 | 1012.043446 | ops/s | 32 | 1.302266947 | 0.57906183 |
baseline | thrpt | 1 | 10 | 2943.088879 | 221.57475 | ops/s | 64 | 1.301414733 | 0.593471609 |
baseline | thrpt | 1 | 10 | 358.010135 | 9.342801 | ops/s | 128 | 1.292889618 | 0.577539747 |
baseline | thrpt | 1 | 10 | 105.758366 | 4.275503 | ops/s | 192 | 1.246415143 | 0.575804515 |
baseline | thrpt | 1 | 10 | 41.465557 | 1.112753 | ops/s | 256 | 1.430003946 | 0.535135851 |
baseline | thrpt | 1 | 10 | 20.479081 | 0.462547 | ops/s | 320 | 1.154267894 | 0.516198866 |
baseline | thrpt | 1 | 10 | 11.686685 | 0.263476 | ops/s | 384 | 1.186535349 | 0.509027985 |
baseline | thrpt | 1 | 10 | 7.344184 | 0.269656 | ops/s | 448 | 1.166421127 | 0.507965526 |
baseline | thrpt | 1 | 10 | 3.545153 | 0.108086 | ops/s | 512 | 0.81796657 | 0.366017216 |
baseline | thrpt | 1 | 10 | 3.789384 | 0.130934 | ops/s | 576 | 1.327168294 | 0.557048123 |
baseline | thrpt | 1 | 10 | 1.981957 | 0.040136 | ops/s | 640 | 1.020965271 | 0.399660104 |
baseline | thrpt | 1 | 10 | 1.76672 | 0.036386 | ops/s | 704 | 1.168272442 | 0.474179037 |
baseline | thrpt | 1 | 10 | 1.01026 | 0.049853 | ops/s | 768 | 0.845514112 | 0.352024966 |
baseline | thrpt | 1 | 10 | 1.115814 | 0.03803 | ops/s | 832 | 1.148752171 | 0.494331667 |
baseline | thrpt | 1 | 10 | 0.703561 | 0.110626 | ops/s | 896 | 0.938435436 | 0.389298235 |
baseline | thrpt | 1 | 10 | 0.629896 | 0.052448 | ops/s | 960 | 1.081741651 | 0.428685898 |
baseline | thrpt | 1 | 10 | 0.407772 | 0.019079 | ops/s | 1024 | 1.025356561 | 0.336801424 |
blocked | thrpt | 1 | 10 | 1157259.558 | 49097.48711 | ops/s | 8 | 1 | 0.455782226 |
blocked | thrpt | 1 | 10 | 17640.8025 | 1226.401298 | ops/s | 32 | 1 | 0.444656782 |
blocked | thrpt | 1 | 10 | 2261.453481 | 98.937035 | ops/s | 64 | 1 | 0.456020355 |
blocked | thrpt | 1 | 10 | 276.906961 | 22.851857 | ops/s | 128 | 1 | 0.446704605 |
blocked | thrpt | 1 | 10 | 84.850033 | 4.441454 | ops/s | 192 | 1 | 0.461968485 |
blocked | thrpt | 1 | 10 | 28.996813 | 7.585551 | ops/s | 256 | 1 | 0.374219842 |
blocked | thrpt | 1 | 10 | 17.742052 | 0.627629 | ops/s | 320 | 1 | 0.447208892 |
blocked | thrpt | 1 | 10 | 9.84942 | 0.367603 | ops/s | 384 | 1 | 0.429003641 |
blocked | thrpt | 1 | 10 | 6.29634 | 0.402846 | ops/s | 448 | 1 | 0.435490676 |
blocked | thrpt | 1 | 10 | 4.334105 | 0.384849 | ops/s | 512 | 1 | 0.447472097 |
blocked | thrpt | 1 | 10 | 2.85524 | 0.199102 | ops/s | 576 | 1 | 0.419726816 |
blocked | thrpt | 1 | 10 | 1.941258 | 0.10915 | ops/s | 640 | 1 | 0.391453182 |
blocked | thrpt | 1 | 10 | 1.51225 | 0.076621 | ops/s | 704 | 1 | 0.40588053 |
blocked | thrpt | 1 | 10 | 1.194847 | 0.063147 | ops/s | 768 | 1 | 0.416344283 |
blocked | thrpt | 1 | 10 | 0.971327 | 0.040421 | ops/s | 832 | 1 | 0.430320551 |
blocked | thrpt | 1 | 10 | 0.749717 | 0.042997 | ops/s | 896 | 1 | 0.414837526 |
blocked | thrpt | 1 | 10 | 0.582298 | 0.016725 | ops/s | 960 | 1 | 0.39629231 |
blocked | thrpt | 1 | 10 | 0.397688 | 0.043639 | ops/s | 1024 | 1 | 0.328472491 |
fast | thrpt | 1 | 10 | 1869676.345 | 76416.50848 | ops/s | 8 | 1.615606743 | 0.736364837 |
fast | thrpt | 1 | 10 | 48485.47216 | 1301.926828 | ops/s | 32 | 2.748484496 | 1.222132271 |
fast | thrpt | 1 | 10 | 6431.341657 | 153.905413 | ops/s | 64 | 2.843897392 | 1.296875098 |
fast | thrpt | 1 | 10 | 840.601821 | 45.998723 | ops/s | 128 | 3.035683242 | 1.356053685 |
fast | thrpt | 1 | 10 | 260.386996 | 13.022418 | ops/s | 192 | 3.068790745 | 1.417684611 |
fast | thrpt | 1 | 10 | 107.895708 | 6.584674 | ops/s | 256 | 3.720950575 | 1.392453537 |
fast | thrpt | 1 | 10 | 56.245336 | 2.729061 | ops/s | 320 | 3.170170846 | 1.417728592 |
fast | thrpt | 1 | 10 | 32.917996 | 2.196624 | ops/s | 384 | 3.342125323 | 1.433783932 |
fast | thrpt | 1 | 10 | 20.960189 | 2.077684 | ops/s | 448 | 3.328948087 | 1.449725854 |
fast | thrpt | 1 | 10 | 14.005186 | 0.7839 | ops/s | 512 | 3.231390564 | 1.445957112 |
fast | thrpt | 1 | 10 | 8.827584 | 0.883654 | ops/s | 576 | 3.091713481 | 1.297675056 |
fast | thrpt | 1 | 10 | 7.455607 | 0.442882 | ops/s | 640 | 3.840605937 | 1.503417416 |
fast | thrpt | 1 | 10 | 5.322894 | 0.464362 | ops/s | 704 | 3.519850554 | 1.428638807 |
fast | thrpt | 1 | 10 | 4.308522 | 0.153846 | ops/s | 768 | 3.605919419 | 1.501303934 |
fast | thrpt | 1 | 10 | 3.375274 | 0.106715 | ops/s | 832 | 3.474910097 | 1.495325228 |
fast | thrpt | 1 | 10 | 2.320152 | 0.367881 | ops/s | 896 | 3.094703735 | 1.28379924 |
fast | thrpt | 1 | 10 | 2.057478 | 0.150198 | ops/s | 960 | 3.533376381 | 1.400249889 |
fast | thrpt | 1 | 10 | 1.66255 | 0.181116 | ops/s | 1024 | 4.180538513 | 1.3731919 |
Interestingly, the blocked algorithm is now the slowest of the plain Java implementations. The code generated by C2 got a lot faster, but peaks at around 1.5 flops/cycle, which still doesn't compete with LMS. Why? Taking a look at the assembly, it's clear that the autovectoriser choked on the array offsets and produced scalar code, operating on a single float at a time, just like the Java implementations in the paper. I wasn't expecting this.
vmovss xmm5,dword ptr [rdi+rcx*4+10h]
vfmadd231ss xmm5,xmm6,xmm2
vmovss dword ptr [rdi+rcx*4+10h],xmm5
Is this the end of the story? No: with some hacks, and at the cost of an array allocation and a copy or two, autovectorisation can be tricked into working again to generate faster code:
public void fast(float[] a, float[] b, float[] c, int n) {
  float[] bBuffer = new float[n];
  float[] cBuffer = new float[n];
  int in = 0;
  for (int i = 0; i < n; ++i) {
    int kn = 0;
    for (int k = 0; k < n; ++k) {
      float aik = a[in + k];
      System.arraycopy(b, kn, bBuffer, 0, n);
      saxpy(n, aik, bBuffer, cBuffer);
      kn += n;
    }
    System.arraycopy(cBuffer, 0, c, in, n);
    Arrays.fill(cBuffer, 0f);
    in += n;
  }
}

private void saxpy(int n, float aik, float[] b, float[] c) {
  for (int i = 0; i < n; ++i) {
    c[i] += aik * b[i];
  }
}
Adding this hack into the NGen benchmark (back on JDK 1.8.0_131), I get closer to the LMS generated code, and beat it once the matrices no longer fit in the 6MB L3 cache - see the working set sketch after the tables below. LMS is still faster when both matrices fit in L3 concurrently, but by percentage points rather than by a multiple. For small matrices, the cost of the hacky array buffers gives the game away.
====================================================
Benchmarking MMM.jMMM.fast (JVM implementation)
----------------------------------------------------
Size (N) | Flops / Cycle
----------------------------------------------------
       8 | 0.2500390686
      32 | 0.7710872405
      64 | 1.1302489072
     128 | 2.5113453810
     192 | 2.9525859816
     256 | 3.1180920385
     320 | 3.1081563593
     384 | 3.1458423577
     448 | 3.0493148252
     512 | 3.0551158263
     576 | 3.1430376938
     640 | 3.2169923048
     704 | 3.1026513283
     768 | 2.4190053777
     832 | 3.3358586705
     896 | 3.0755689237
     960 | 2.9996690697
    1024 | 2.2935654309
====================================================
====================================================
Benchmarking MMM.nMMM.blocked (LMS generated)
----------------------------------------------------
Size (N) | Flops / Cycle
----------------------------------------------------
       8 | 1.0001562744
      32 | 5.3330416826
      64 | 5.8180867784
     128 | 5.1717318641
     192 | 5.1639907462
     256 | 4.3418618628
     320 | 5.2536572701
     384 | 4.0801359215
     448 | 4.1337007093
     512 | 3.2678160754
     576 | 3.7973028890
     640 | 3.3557513664
     704 | 4.0103133240
     768 | 3.4188362575
     832 | 3.2189488327
     896 | 3.2316685219
     960 | 2.9985655539
    1024 | 1.7750946796
====================================================
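As a rough check on the L3 residency claim, here is a back-of-the-envelope sketch (assuming 4-byte floats and the 6MB L3 quoted above) of when the two operand matrices stop fitting in cache together:
public class WorkingSet {
  public static void main(String[] args) {
    // Two n x n float matrices spill out of a 6MB L3 once 2 * n * n * 4 bytes > 6MB,
    // which happens between n = 832 and n = 896 on the benchmarked sizes - roughly
    // where the LMS generated code starts to fall off in the tables above.
    long l3Bytes = 6L * 1024 * 1024;
    int[] sizes = {8, 32, 64, 128, 192, 256, 320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024};
    for (int n : sizes) {
      long workingSetBytes = 2L * n * n * Float.BYTES;
      System.out.println(n + " -> " + workingSetBytes / 1024 + "KB" + (workingSetBytes > l3Bytes ? " (exceeds L3)" : ""));
    }
  }
}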
With the benchmark below, I calculate the flops/cycle achieved by the same buffered approach under JDK10's improved autovectorisation.
  @Benchmark
  public void fastBuffered(Blackhole bh) {
    fastBuffered(a, b, c, size);
    bh.consume(c);
  }

  public void fastBuffered(float[] a, float[] b, float[] c, int n) {
    float[] bBuffer = new float[n];
    float[] cBuffer = new float[n];
    int in = 0;
    for (int i = 0; i < n; ++i) {
      int kn = 0;
      for (int k = 0; k < n; ++k) {
        float aik = a[in + k];
        System.arraycopy(b, kn, bBuffer, 0, n);
        saxpy(n, aik, bBuffer, cBuffer);
        kn += n;
      }
      System.arraycopy(cBuffer, 0, c, in, n);
      Arrays.fill(cBuffer, 0f);
      in += n;
    }
  }

  private void saxpy(int n, float aik, float[] b, float[] c) {
    for (int i = 0; i < n; ++i) {
      c[i] = Math.fma(aik, b[i], c[i]);
    }
  }
Just as in the modified NGen benchmark, this starts paying off once the matrices have 64 rows and columns. Finally, though it took a JDK upgrade and a hack, I breached 4 flops per cycle:
Benchmark | Mode | Threads | Samples | Score | Score Error (99.9%) | Unit | Param: size | Flops / Cycle |
---|---|---|---|---|---|---|---|---|
fastBuffered | thrpt | 1 | 10 | 1047184.034 | 63532.95095 | ops/s | 8 | 0.412429404 |
fastBuffered | thrpt | 1 | 10 | 58373.56367 | 3239.615866 | ops/s | 32 | 1.471373026 |
fastBuffered | thrpt | 1 | 10 | 12099.41654 | 497.33988 | ops/s | 64 | 2.439838038 |
fastBuffered | thrpt | 1 | 10 | 2136.50264 | 105.038006 | ops/s | 128 | 3.446592911 |
fastBuffered | thrpt | 1 | 10 | 673.470622 | 102.577237 | ops/s | 192 | 3.666730488 |
fastBuffered | thrpt | 1 | 10 | 305.541519 | 25.959163 | ops/s | 256 | 3.943181586 |
fastBuffered | thrpt | 1 | 10 | 158.437372 | 6.708384 | ops/s | 320 | 3.993596774 |
fastBuffered | thrpt | 1 | 10 | 88.283718 | 7.58883 | ops/s | 384 | 3.845306266 |
fastBuffered | thrpt | 1 | 10 | 58.574507 | 4.248521 | ops/s | 448 | 4.051345968 |
fastBuffered | thrpt | 1 | 10 | 37.183635 | 4.360319 | ops/s | 512 | 3.839002314 |
fastBuffered | thrpt | 1 | 10 | 29.949884 | 0.63346 | ops/s | 576 | 4.40270151 |
fastBuffered | thrpt | 1 | 10 | 20.715833 | 4.175897 | ops/s | 640 | 4.177331789 |
fastBuffered | thrpt | 1 | 10 | 10.824837 | 0.902983 | ops/s | 704 | 2.905333492 |
fastBuffered | thrpt | 1 | 10 | 8.285254 | 1.438701 | ops/s | 768 | 2.886995686 |
fastBuffered | thrpt | 1 | 10 | 6.17029 | 0.746537 | ops/s | 832 | 2.733582608 |
fastBuffered | thrpt | 1 | 10 | 4.828872 | 1.316901 | ops/s | 896 | 2.671937962 |
fastBuffered | thrpt | 1 | 10 | 3.6343 | 1.293923 | ops/s | 960 | 2.473381573 |
fastBuffered | thrpt | 1 | 10 | 2.458296 | 0.171224 | ops/s | 1024 | 2.030442485 |
The code generated for the core of the loop looks better now:
vmovdqu ymm1,ymmword ptr [r13+r11*4+10h]
vfmadd231ps ymm1,ymm3,ymmword ptr [r14+r11*4+10h]
vmovdqu ymmword ptr [r13+r11*4+10h],ymm1
These benchmark results can be compared on a line chart.
Given this improvement, it would be exciting to see how LMS can profit from JDK9 or JDK10 - does LMS provide the impetus to resume maintenance of scala-virtualized? L3 cache, which the LMS generated code seems to depend on for throughput, is typically shared between cores: a single thread rarely enjoys exclusive access. I would like to see benchmarks for the LMS generated code in the presence of concurrency.
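As a sketch of the kind of experiment I have in mind, JMH makes it easy to run the Java implementations with several threads contending for the shared L3. The state class and thread count below are illustrative additions to the MMM class above, not part of the original benchmark:
// Illustrative only: each thread gets its own matrices (Scope.Thread), so the
// arithmetic is independent and only the shared L3 cache is contended.
@State(Scope.Thread)
public static class ThreadLocalMatrices {
  @Param({"512", "1024"})
  int size;
  float[] a;
  float[] b;
  float[] c;

  @Setup(Level.Trial)
  public void init() {
    a = DataUtil.createFloatArray(size * size);
    b = DataUtil.createFloatArray(size * size);
    c = new float[size * size];
  }
}

@Benchmark
@Threads(4) // arbitrary: enough threads to make the cores share L3
public void fastBufferedContended(ThreadLocalMatrices m, Blackhole bh) {
  fastBuffered(m.a, m.b, m.c, m.size);
  bh.consume(m.c);
}
Of course, this only exercises the Java implementations; the interesting comparison would be the LMS generated code under the same contention.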