1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
|
From 91cd6c2ce7138612709b337f665bb62ca9e522cc Mon Sep 17 00:00:00 2001
From: Georg Lehmann <dadschoorse@gmail.com>
Date: Fri, 11 Apr 2025 13:13:45 +0200
Subject: [PATCH 01/13] radv/nir/lower_cmat/gfx12: fix 8bit A/B matrix layout
---
.../nir/radv_nir_lower_cooperative_matrix.c | 46 +++++++++++--------
1 file changed, 26 insertions(+), 20 deletions(-)
diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
index 9e12b0964da..645df8663f6 100644
--- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
+++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
@@ -166,13 +166,15 @@ radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower
if (params->gfx_level >= GFX12) {
base_row = nir_udiv_imm(b, local_idx, 16);
- if (desc.use == GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 64) {
+ if ((desc.use == GLSL_CMAT_USE_ACCUMULATOR || radv_nir_cmat_bits(desc) == 8) && params->wave_size == 64) {
/* Switch rows from lanes 16..31 to 32..47, offset right shift by -2
* to get implicit * 4.
*/
base_row = nir_ushr_imm(b, nir_bitfield_reverse(b, base_row), 30 - 2);
+ } else if ((desc.use == GLSL_CMAT_USE_ACCUMULATOR || radv_nir_cmat_bits(desc) == 8) && params->wave_size == 32) {
+ base_row = nir_imul_imm(b, base_row, 8);
} else {
- base_row = nir_imul_imm(b, base_row, desc.use == GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 32 ? 8 : 4);
+ base_row = nir_imul_imm(b, base_row, 4);
}
} else {
base_row = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(b, local_idx, 16) : nir_imm_int(b, 0);
@@ -181,6 +183,24 @@ radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower
return base_row;
}
+static unsigned
+radv_get_row_iter(struct glsl_cmat_description desc, const lower_cmat_params *params, unsigned i)
+{
+ if (params->gfx_level >= GFX12) {
+ /* 8bit and ACC are indexed normally, 16bit A/B is weird. */
+ if (desc.use != GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 32 && radv_nir_cmat_bits(desc) >= 16)
+ return i + (i & 4);
+ else
+ return i;
+ } else {
+ if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
+ return i;
+ else
+ return i * params->wave_size / 16;
+ }
+}
+
+
static nir_def *
convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, enum glsl_base_type dst_type)
{
@@ -311,7 +331,6 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
unsigned length = radv_nir_cmat_length(desc, &params);
unsigned mul = radv_nir_cmat_length_mul(desc, &params);
- unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
nir_def *vars[16];
if (mul > 1) {
for (unsigned i = 0; i < length; ++i)
@@ -324,16 +343,10 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
for (unsigned i = 0; i < length / mul; ++i) {
nir_def *col_offset = inner_idx;
- nir_def *row_offset;
- uint32_t row_iter;
- if (gfx_level >= GFX12) {
- row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i;
- } else {
- row_iter = i * lanes_per_iter / 16;
- }
+ uint32_t row_iter = radv_get_row_iter(desc, &params, i);
- row_offset = nir_iadd_imm(&b, base_row, row_iter);
+ nir_def *row_offset = nir_iadd_imm(&b, base_row, row_iter);
if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
nir_def *tmp = col_offset;
@@ -385,7 +398,6 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
unsigned length = radv_nir_cmat_length(desc, &params);
unsigned mul = radv_nir_cmat_length_mul(desc, &params);
- unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
nir_def *vars[16];
for (unsigned i = 0; i < length; ++i)
vars[i] = nir_channel(&b, src, i);
@@ -395,16 +407,10 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
for (unsigned i = 0; i < length / mul; ++i) {
nir_def *col_offset = inner_idx;
- nir_def *row_offset;
- uint32_t row_iter;
- if (gfx_level >= GFX12) {
- row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i;
- } else {
- row_iter = i * lanes_per_iter / 16;
- }
+ uint32_t row_iter = radv_get_row_iter(desc, &params, i);
- row_offset = nir_iadd_imm(&b, base_row, row_iter);
+ nir_def *row_offset = nir_iadd_imm(&b, base_row, row_iter);
if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
nir_def *tmp = col_offset;
--
2.49.0
|