summarylogtreecommitdiffstats
path: root/0001-radv-nir-lower_cmat-gfx12-fix-8bit-A-B-matrix-layout.patch
blob: eff46d2cc55ad14f664e4d18cf53d563855db43a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
From 91cd6c2ce7138612709b337f665bb62ca9e522cc Mon Sep 17 00:00:00 2001
From: Georg Lehmann <dadschoorse@gmail.com>
Date: Fri, 11 Apr 2025 13:13:45 +0200
Subject: [PATCH 01/13] radv/nir/lower_cmat/gfx12: fix 8bit A/B matrix layout

---
 .../nir/radv_nir_lower_cooperative_matrix.c   | 46 +++++++++++--------
 1 file changed, 26 insertions(+), 20 deletions(-)

diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
index 9e12b0964da..645df8663f6 100644
--- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
+++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
@@ -166,13 +166,15 @@ radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower
    if (params->gfx_level >= GFX12) {
       base_row = nir_udiv_imm(b, local_idx, 16);
 
-      if (desc.use == GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 64) {
+      if ((desc.use == GLSL_CMAT_USE_ACCUMULATOR || radv_nir_cmat_bits(desc) == 8) && params->wave_size == 64) {
          /* Switch rows from lanes 16..31 to 32..47, offset right shift by -2
           * to get implicit * 4.
           */
          base_row = nir_ushr_imm(b, nir_bitfield_reverse(b, base_row), 30 - 2);
+      } else if ((desc.use == GLSL_CMAT_USE_ACCUMULATOR || radv_nir_cmat_bits(desc) == 8) && params->wave_size == 32) {
+         base_row = nir_imul_imm(b, base_row, 8);
       } else {
-         base_row = nir_imul_imm(b, base_row, desc.use == GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 32 ? 8 : 4);
+         base_row = nir_imul_imm(b, base_row, 4);
       }
    } else {
       base_row = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(b, local_idx, 16) : nir_imm_int(b, 0);
@@ -181,6 +183,24 @@ radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower
    return base_row;
 }
 
+static unsigned
+radv_get_row_iter(struct glsl_cmat_description desc, const lower_cmat_params *params, unsigned i)
+{
+   if (params->gfx_level >= GFX12) {
+      /* 8bit and ACC are indexed normally, 16bit A/B is weird. */
+      if (desc.use != GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 32 && radv_nir_cmat_bits(desc) >= 16)
+         return i + (i & 4);
+      else
+         return i;
+   } else {
+      if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
+         return i;
+      else
+         return i * params->wave_size / 16;
+   }
+}
+
+
 static nir_def *
 convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, enum glsl_base_type dst_type)
 {
@@ -311,7 +331,6 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
 
                unsigned length = radv_nir_cmat_length(desc, &params);
                unsigned mul = radv_nir_cmat_length_mul(desc, &params);
-               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
                nir_def *vars[16];
                if (mul > 1) {
                   for (unsigned i = 0; i < length; ++i)
@@ -324,16 +343,10 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
 
                for (unsigned i = 0; i < length / mul; ++i) {
                   nir_def *col_offset = inner_idx;
-                  nir_def *row_offset;
-                  uint32_t row_iter;
 
-                  if (gfx_level >= GFX12) {
-                     row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i;
-                  } else {
-                     row_iter = i * lanes_per_iter / 16;
-                  }
+                  uint32_t row_iter = radv_get_row_iter(desc, &params, i);
 
-                  row_offset = nir_iadd_imm(&b, base_row, row_iter);
+                  nir_def *row_offset = nir_iadd_imm(&b, base_row, row_iter);
 
                   if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                      nir_def *tmp = col_offset;
@@ -385,7 +398,6 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
 
                unsigned length = radv_nir_cmat_length(desc, &params);
                unsigned mul = radv_nir_cmat_length_mul(desc, &params);
-               unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
                nir_def *vars[16];
                for (unsigned i = 0; i < length; ++i)
                   vars[i] = nir_channel(&b, src, i);
@@ -395,16 +407,10 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
 
                for (unsigned i = 0; i < length / mul; ++i) {
                   nir_def *col_offset = inner_idx;
-                  nir_def *row_offset;
-                  uint32_t row_iter;
 
-                  if (gfx_level >= GFX12) {
-                     row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i;
-                  } else {
-                     row_iter = i * lanes_per_iter / 16;
-                  }
+                  uint32_t row_iter = radv_get_row_iter(desc, &params, i);
 
-                  row_offset = nir_iadd_imm(&b, base_row, row_iter);
+                  nir_def *row_offset = nir_iadd_imm(&b, base_row, row_iter);
 
                   if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                      nir_def *tmp = col_offset;
-- 
2.49.0