summarylogtreecommitdiffstats
path: root/0002-radv-nir-lower_cmat-gfx12-fix-A-B-8bit-16bit-convers.patch
blob: 15d80aef442b19b6d2ecec98ba5a253aeeda492e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
From bbb844c19b327df3b99f387df2f1d0c3e2d6b9fe Mon Sep 17 00:00:00 2001
From: Georg Lehmann <dadschoorse@gmail.com>
Date: Fri, 11 Apr 2025 13:38:20 +0200
Subject: [PATCH 02/13] radv/nir/lower_cmat/gfx12: fix A/B 8bit <-> 16bit
 conversions

---
 .../nir/radv_nir_lower_cooperative_matrix.c   | 47 ++++++++++++++++++-
 1 file changed, 46 insertions(+), 1 deletion(-)

diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
index 645df8663f6..5aece711987 100644
--- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
+++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
@@ -200,7 +200,6 @@ radv_get_row_iter(struct glsl_cmat_description desc, const lower_cmat_params *pa
    }
 }
 
-
 static nir_def *
 convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, enum glsl_base_type dst_type)
 {
@@ -221,6 +220,44 @@ convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, en
    return nir_build_alu1(b, op, src);
 }
 
+static nir_def *
+radv_swizzle_gfx12_8bit_mat(nir_builder *b, nir_def *src, unsigned wave_size)
+{
+   assert(src->bit_size == 8);
+
+   src = nir_extract_bits(b, &src, 1, 0, src->num_components / 4, 32);
+
+   nir_def *res;
+
+   if (wave_size == 64) {
+      assert(src->num_components == 1);
+
+      nir_def *swapped = nir_rotate(b, src, nir_imm_int(b, 32), .cluster_size = 64);
+      swapped = nir_rotate(b, swapped, nir_imm_int(b, 16), .cluster_size = 32);
+
+      nir_def *cond = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 0xffffffff0000ull, 64));
+
+      res = nir_bcsel(b, cond, swapped, src);
+   } else {
+      assert(src->num_components == 2);
+
+      nir_def *src0 = nir_channel(b, src, 0);
+      nir_def *src1 = nir_channel(b, src, 1);
+
+      nir_def *swapped0 = nir_rotate(b, src0, nir_imm_int(b, 16), .cluster_size = 32);
+      nir_def *swapped1 = nir_rotate(b, src1, nir_imm_int(b, 16), .cluster_size = 32);
+
+      nir_def *cond = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 0xffff0000, 32));
+
+      nir_def *res0 = nir_bcsel(b, cond, swapped1, src0);
+      nir_def *res1 = nir_bcsel(b, cond, swapped0, src1);
+
+      res = nir_vec2(b, res0, res1);
+   }
+
+   return nir_extract_bits(b, &res, 1, 0, res->num_components * 4, 8);
+}
+
 bool
 radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size)
 {
@@ -490,8 +527,16 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
                   src = nir_vec(&b, components, src->num_components / scale);
                }
 
+               if (dst_desc.use != GLSL_CMAT_USE_ACCUMULATOR && gfx_level >= GFX12 &&
+                   radv_nir_cmat_bits(src_desc) == 8 && radv_nir_cmat_bits(dst_desc) > 8)
+                  src = radv_swizzle_gfx12_8bit_mat(&b, src, wave_size);
+
                nir_def *ret = convert_base_type(&b, src, src_element_type, dst_element_type);
 
+               if (dst_desc.use != GLSL_CMAT_USE_ACCUMULATOR && gfx_level >= GFX12 &&
+                   radv_nir_cmat_bits(dst_desc) == 8 && radv_nir_cmat_bits(src_desc) > 8)
+                  ret = radv_swizzle_gfx12_8bit_mat(&b, ret, wave_size);
+
                if (dst_mul > src_mul) {
                   nir_def *components[NIR_MAX_VEC_COMPONENTS];
                   unsigned scale = dst_mul / src_mul;
-- 
2.49.0