@@ -46,17 +46,34 @@ int main(int argc, char **argv) {
4646 //
4747 // Using 16 threads (and no hyperthreading), hits 2080 GFlops (67% of peak)
4848 // and 1310 GFlops (85% of peak) respectively.
49+ //
50+ // On Apple M3 Max, single-threaded hits ~114 GFlops (89% of peak), and
51+ // ~1270 GFlops using 16 cores.
4952
5053 const int vec = target.natural_vector_size <float >();
5154
52- // Size the inner loop tiles to fit into the number of registers available
53- // on the target, using either 12 accumulator registers or 24.
54- const int inner_tile_x = 3 * vec;
55- const int inner_tile_y = (target.has_feature (Target::AVX512 ) || target.arch != Target::X86 ) ? 8 : 4 ;
55+ // On 64-bit ARM, there are 32 NEON registers. Using inner_tile_x=4*vec
56+ // with inner_tile_y=4 leaves 10 spare NEON registers, which lets LLVM
57+ // assign an independent GP base address to each A row. This avoids the
58+ // ld1r post-increment serial dependency chain that occurs with 8 rows
59+ // (where only 2 temp registers cycle between rows), and produces balanced
60+ // load/compute throughput (4 cycles each at 4 FP units and 2 load ports).
61+ const bool is_aarch64 = target.arch == Target::ARM && target.bits == 64 ;
62+ const bool is_avx512 = target.has_feature (Target::AVX512 );
5663
57- // The shape of the outer tiling
58- const int tile_y = matrix_size / 4 ;
59- const int tile_k = matrix_size / 16 ;
64+ // Size the inner loop tiles to fit into the number of registers available
65+ // on the target.
66+ // ARM64 NEON: 4×4=16 accumulators (22/32 NEON regs).
67+ // AVX-512: 3×8=24 accumulators (27/32 ZMM regs).
68+ // AVX2 (default): 3×4=12 accumulators.
69+ const int inner_tile_x = is_aarch64 ? 4 * vec : 3 * vec;
70+ const int inner_tile_y = is_avx512 ? 8 : 4 ;
71+
72+ // The shape of the outer tiling. On ARM64, use a narrower y-tile so the
73+ // B panel (inner_tile_x × matrix_k × 4 bytes = ~62KB) fits in L1
74+ // alongside the C accumulator buffer.
75+ const int tile_y = matrix_size / (is_aarch64 ? 8 : 4 );
76+ const int tile_k = matrix_size / (is_aarch64 ? 4 : 16 );
6077
6178 Var xy (" xy" ), xi (" xi" ), yi (" yi" ), yii (" yii" );
6279
@@ -144,16 +161,7 @@ int main(int argc, char **argv) {
144161 return 1 ;
145162 }
146163
147- // Uncomment to see the generated assembly.
148- /*
149- {
150- Target t("host-no_asserts-no_runtime-no_bounds_query");
151- out.compile_to_assembly("/dev/stdout", matrix_mul.infer_arguments(), t);
152- }
153- */
154-
155164 float gflops = 2 .0f * matrix_size * matrix_size * matrix_size / 1e9f;
156-
157165 printf (" Halide: %fms, %f GFLOP/s\n\n " , t * 1e3 , (gflops / t));
158166
159167 printf (" Success!\n " );
0 commit comments