summarylogtreecommitdiffstats
path: root/35533.patch
blob: 8d5d4b1ebc7e234464f8c80292b1201bf4935503 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
From 69193926ba7bdc63fcb5f7ed3f6b62331daf7eb3 Mon Sep 17 00:00:00 2001
From: Konstantin Seurer <konstantin.seurer@gmail.com>
Date: Sat, 14 Jun 2025 12:22:06 +0200
Subject: [PATCH] radv: Optimize ray tracing position fetch

Gets rid of a lot of indirection when fetching triangle positions.
Storing the primitive address increases register pressure by a bit but
the traversal shader which should have the highest register demand
should not be affected when position fetch is not used.

Totals:
Instrs: 4021686 -> 4022435 (+0.02%); split: -0.01%, +0.03%
CodeSize: 21235812 -> 21235832 (+0.00%); split: -0.02%, +0.02%
Latency: 23402275 -> 23412110 (+0.04%); split: -0.04%, +0.09%
InvThroughput: 4352818 -> 4352206 (-0.01%); split: -0.04%, +0.02%
VClause: 101906 -> 102058 (+0.15%); split: -0.03%, +0.18%
Copies: 342210 -> 342368 (+0.05%); split: -0.09%, +0.14%
Branches: 114988 -> 114993 (+0.00%)
PreVGPRs: 26551 -> 27111 (+2.11%)
VALU: 2249366 -> 2249524 (+0.01%); split: -0.01%, +0.02%
SALU: 529828 -> 529808 (-0.00%); split: -0.01%, +0.00%
---
 src/amd/common/ac_shader_args.h               |  1 +
 src/amd/vulkan/bvh/bvh.h                      |  6 +--
 src/amd/vulkan/bvh/copy_blas_addrs_gfx12.comp |  4 +-
 src/amd/vulkan/bvh/encode.h                   |  2 -
 .../vulkan/nir/radv_nir_lower_ray_queries.c   | 12 +++--
 src/amd/vulkan/nir/radv_nir_rt_common.c       | 52 +++----------------
 src/amd/vulkan/nir/radv_nir_rt_common.h       |  3 +-
 src/amd/vulkan/nir/radv_nir_rt_shader.c       | 42 +++++++++++----
 src/amd/vulkan/radv_acceleration_structure.c  | 28 ----------
 src/amd/vulkan/radv_pipeline_rt.c             | 23 +++++---
 src/amd/vulkan/radv_pipeline_rt.h             |  1 +
 src/amd/vulkan/radv_shader.h                  |  2 +-
 src/amd/vulkan/radv_shader_args.c             |  1 +
 src/compiler/nir/nir_intrinsics.py            | 11 ++--
 14 files changed, 73 insertions(+), 115 deletions(-)

diff --git a/src/amd/common/ac_shader_args.h b/src/amd/common/ac_shader_args.h
index 97ac63772335a..4ae518da4c8d4 100644
--- a/src/amd/common/ac_shader_args.h
+++ b/src/amd/common/ac_shader_args.h
@@ -214,6 +214,7 @@ struct ac_shader_args {
       struct ac_arg accel_struct;
       struct ac_arg primitive_id;
       struct ac_arg instance_addr;
+      struct ac_arg primitive_addr;
       struct ac_arg geometry_id_and_flags;
       struct ac_arg hit_kind;
    } rt;
diff --git a/src/amd/vulkan/bvh/bvh.h b/src/amd/vulkan/bvh/bvh.h
index 07ad8190b77b9..b43bda913cee8 100644
--- a/src/amd/vulkan/bvh/bvh.h
+++ b/src/amd/vulkan/bvh/bvh.h
@@ -63,7 +63,6 @@ struct radv_accel_struct_header {
    /* Everything after this gets updated/copied from the CPU. */
    uint32_t geometry_type;
    uint32_t geometry_count;
-   uint32_t primitive_base_indices_offset;
    uint64_t instance_offset;
    uint64_t instance_count;
    uint32_t leaf_node_offsets_offset;
@@ -158,11 +157,8 @@ struct radv_gfx12_instance_node_user_data {
    uint32_t custom_instance;
    uint32_t instance_index;
    uint32_t bvh_offset;
-   uint32_t padding;
-   uint64_t blas_addr;
-   uint32_t primitive_base_indices_offset;
    uint32_t leaf_node_offsets_offset;
-   uint32_t unused[12];
+   uint32_t unused[16];
 };
 
 /* Size of the primitive header section in bits. */
diff --git a/src/amd/vulkan/bvh/copy_blas_addrs_gfx12.comp b/src/amd/vulkan/bvh/copy_blas_addrs_gfx12.comp
index 40e9625626778..3145e89fd059a 100644
--- a/src/amd/vulkan/bvh/copy_blas_addrs_gfx12.comp
+++ b/src/amd/vulkan/bvh/copy_blas_addrs_gfx12.comp
@@ -50,7 +50,8 @@ main(void)
          REF(radv_gfx12_instance_node_user_data)(instance_addr + SIZEOF(radv_gfx12_instance_node));
 
       if (args.mode == RADV_COPY_MODE_SERIALIZE) {
-         DEREF(INDEX(uint64_t, blas_addrs, i)) = DEREF(instance_data).blas_addr;
+         DEREF(INDEX(uint64_t, blas_addrs, i)) =
+            node_to_addr(DEREF(instance_node).pointer_flags_bvh_addr) - DEREF(instance_data).bvh_offset;
       } else {
          uint32_t bvh_offset = DEREF(instance_data).bvh_offset;
 
@@ -59,7 +60,6 @@ main(void)
          uint64_t blas_addr = DEREF(INDEX(uint64_t, blas_addrs, i));
          DEREF(instance_node).pointer_flags_bvh_addr =
             (pointer_flags_bvh_addr & 0xFFC0000000000000ul) | addr_to_node(blas_addr + bvh_offset);
-         DEREF(instance_data).blas_addr = blas_addr;
       }
    }
 }
diff --git a/src/amd/vulkan/bvh/encode.h b/src/amd/vulkan/bvh/encode.h
index 25abeb2ac8d91..4f6abb2567380 100644
--- a/src/amd/vulkan/bvh/encode.h
+++ b/src/amd/vulkan/bvh/encode.h
@@ -316,8 +316,6 @@ radv_encode_instance_gfx12(VOID_REF dst, vk_ir_instance_node src)
    DEREF(user_data).custom_instance = src.custom_instance_and_mask & 0xffffff;
    DEREF(user_data).instance_index = src.instance_id;
    DEREF(user_data).bvh_offset = blas_header.bvh_offset;
-   DEREF(user_data).blas_addr = src.base_ptr;
-   DEREF(user_data).primitive_base_indices_offset = blas_header.primitive_base_indices_offset;
    DEREF(user_data).leaf_node_offsets_offset = blas_header.leaf_node_offsets_offset;
 }
 
diff --git a/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c b/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c
index 2dc30e5a47425..d4847a32b4791 100644
--- a/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c
+++ b/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c
@@ -21,6 +21,7 @@
 #define MAX_SCRATCH_STACK_ENTRY_COUNT 76
 
 enum radv_ray_intersection_field {
+   radv_ray_intersection_primitive_addr,
    radv_ray_intersection_primitive_id,
    radv_ray_intersection_geometry_id_and_flags,
    radv_ray_intersection_instance_addr,
@@ -44,6 +45,7 @@ radv_get_intersection_type()
       .name = #field_name,                                                                                             \
    }
 
+   FIELD(primitive_addr, glsl_uint64_t_type());
    FIELD(primitive_id, glsl_uint_type());
    FIELD(geometry_id_and_flags, glsl_uint_type());
    FIELD(instance_addr, glsl_uint64_t_type());
@@ -186,6 +188,7 @@ copy_candidate_to_closest(nir_builder *b, nir_deref_instr *rq)
    nir_deref_instr *candidate = rq_deref(b, rq, candidate);
 
    isec_copy(b, closest, candidate, barycentrics);
+   isec_copy(b, closest, candidate, primitive_addr);
    isec_copy(b, closest, candidate, geometry_id_and_flags);
    isec_copy(b, closest, candidate, instance_addr);
    isec_copy(b, closest, candidate, intersection_type);
@@ -373,11 +376,8 @@ lower_rq_load(struct radv_device *device, nir_builder *b, nir_intrinsic_instr *i
    case nir_ray_query_value_world_ray_origin:
       return rq_load(b, rq, origin);
    case nir_ray_query_value_intersection_triangle_vertex_positions: {
-      nir_def *instance_node_addr = isec_load(b, intersection, instance_addr);
-      nir_def *primitive_id = isec_load(b, intersection, primitive_id);
-      nir_def *geometry_id = nir_iand_imm(b, isec_load(b, intersection, geometry_id_and_flags), 0xFFFFFF);
-      return radv_load_vertex_position(device, b, instance_node_addr, geometry_id, primitive_id,
-                                       nir_intrinsic_column(instr));
+      nir_def *primitive_addr = isec_load(b, intersection, primitive_addr);
+      return radv_load_vertex_position(device, b, primitive_addr, nir_intrinsic_column(instr));
    }
    default:
       unreachable("Invalid nir_ray_query_value!");
@@ -399,6 +399,7 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
 
    nir_deref_instr *candidate = rq_deref(b, data->rq, candidate);
 
+   isec_store(b, candidate, primitive_addr, intersection->node_addr);
    isec_store(b, candidate, primitive_id, intersection->primitive_id);
    isec_store(b, candidate, geometry_id_and_flags, intersection->geometry_id_and_flags);
    isec_store(b, candidate, opaque, intersection->opaque);
@@ -416,6 +417,7 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int
    nir_deref_instr *candidate = rq_deref(b, data->rq, candidate);
 
    isec_store(b, candidate, barycentrics, intersection->barycentrics);
+   isec_store(b, candidate, primitive_addr, intersection->base.node_addr);
    isec_store(b, candidate, primitive_id, intersection->base.primitive_id);
    isec_store(b, candidate, geometry_id_and_flags, intersection->base.geometry_id_and_flags);
    isec_store(b, candidate, t, intersection->t);
diff --git a/src/amd/vulkan/nir/radv_nir_rt_common.c b/src/amd/vulkan/nir/radv_nir_rt_common.c
index c41cb6258642c..a28c73fb047a2 100644
--- a/src/amd/vulkan/nir/radv_nir_rt_common.c
+++ b/src/amd/vulkan/nir/radv_nir_rt_common.c
@@ -319,42 +319,17 @@ nir_build_vec3_mat_mult(nir_builder *b, nir_def *vec, nir_def *matrix[], bool tr
 }
 
 nir_def *
-radv_load_vertex_position(struct radv_device *device, nir_builder *b, nir_def *instance_addr, nir_def *geometry_id,
-                          nir_def *primitive_id, uint32_t index)
+radv_load_vertex_position(struct radv_device *device, nir_builder *b, nir_def *primitive_addr, uint32_t index)
 {
    const struct radv_physical_device *pdev = radv_device_physical(device);
 
    if (radv_use_bvh8(pdev)) {
-      nir_def *addr_offsets =
-         nir_build_load_global(b, 4, 32,
-                               nir_iadd_imm(b, instance_addr,
-                                            sizeof(struct radv_gfx12_instance_node) +
-                                               offsetof(struct radv_gfx12_instance_node_user_data, blas_addr)));
-      nir_def *bvh_offset =
-         nir_build_load_global(b, 1, 32,
-                               nir_iadd_imm(b, instance_addr,
-                                            sizeof(struct radv_gfx12_instance_node) +
-                                               offsetof(struct radv_gfx12_instance_node_user_data, bvh_offset)));
-
-      nir_def *addr = nir_pack_64_2x32(b, nir_channels(b, addr_offsets, 0x3));
-
-      nir_def *base_index_offset =
-         nir_iadd(b, nir_channel(b, addr_offsets, 2), nir_imul_imm(b, geometry_id, sizeof(uint32_t)));
-      nir_def *base_index = nir_build_load_global(b, 1, 32, nir_iadd(b, addr, nir_u2u64(b, base_index_offset)));
-
-      nir_def *offset_offset = nir_iadd(b, nir_channel(b, addr_offsets, 3),
-                                        nir_imul_imm(b, nir_iadd(b, base_index, primitive_id), sizeof(uint32_t)));
-      nir_def *offset = nir_build_load_global(b, 1, 32, nir_iadd(b, addr, nir_u2u64(b, offset_offset)));
-      offset = nir_iadd(b, offset, bvh_offset);
-
       /* Assume that vertices are uncompressed. */
-      offset = nir_iadd_imm(b, offset,
-                            ROUND_DOWN_TO(RADV_GFX12_PRIMITIVE_NODE_HEADER_SIZE / 8, 4) + index * 3 * sizeof(float));
-
+      uint32_t offset = ROUND_DOWN_TO(RADV_GFX12_PRIMITIVE_NODE_HEADER_SIZE / 8, 4) + index * 3 * sizeof(float);
       nir_def *data[4];
       for (uint32_t i = 0; i < ARRAY_SIZE(data); i++) {
-         data[i] = nir_build_load_global(b, 1, 32, nir_iadd(b, addr, nir_u2u64(b, offset)));
-         offset = nir_iadd_imm(b, offset, 4);
+         data[i] = nir_build_load_global(b, 1, 32, nir_iadd_imm(b, primitive_addr, offset));
+         offset += 4;
       }
 
       uint32_t subdword_offset = RADV_GFX12_PRIMITIVE_NODE_HEADER_SIZE % 32;
@@ -369,23 +344,8 @@ radv_load_vertex_position(struct radv_device *device, nir_builder *b, nir_def *i
       return nir_vec3(b, vertices[0], vertices[1], vertices[2]);
    }
 
-   nir_def *bvh_addr_id =
-      nir_build_load_global(b, 1, 64, nir_iadd_imm(b, instance_addr, offsetof(struct radv_bvh_instance_node, bvh_ptr)));
-   nir_def *bvh_addr = build_node_to_addr(device, b, bvh_addr_id, true);
-
-   nir_def *bvh_offset = nir_build_load_global(
-      b, 1, 32, nir_iadd_imm(b, instance_addr, offsetof(struct radv_bvh_instance_node, bvh_offset)));
-   nir_def *accel_struct = nir_isub(b, bvh_addr, nir_u2u64(b, bvh_offset));
-   nir_def *base_indices_offset = nir_build_load_global(
-      b, 1, 32,
-      nir_iadd_imm(b, accel_struct, offsetof(struct radv_accel_struct_header, primitive_base_indices_offset)));
-   nir_def *base_index_offset = nir_iadd(b, base_indices_offset, nir_imul_imm(b, geometry_id, sizeof(uint32_t)));
-   nir_def *base_index = nir_build_load_global(b, 1, 32, nir_iadd(b, accel_struct, nir_u2u64(b, base_index_offset)));
-
-   nir_def *offset = nir_imul_imm(b, nir_iadd(b, base_index, primitive_id), sizeof(struct radv_bvh_triangle_node));
-   offset = nir_iadd_imm(b, offset, sizeof(struct radv_bvh_box32_node) + index * 3 * sizeof(float));
-
-   return nir_build_load_global(b, 3, 32, nir_iadd(b, bvh_addr, nir_u2u64(b, offset)));
+   uint32_t offset = index * 3 * sizeof(float);
+   return nir_build_load_global(b, 3, 32, nir_iadd_imm(b, primitive_addr, offset));
 }
 
 void
diff --git a/src/amd/vulkan/nir/radv_nir_rt_common.h b/src/amd/vulkan/nir/radv_nir_rt_common.h
index aaef6ca72cd15..014e025fc9451 100644
--- a/src/amd/vulkan/nir/radv_nir_rt_common.h
+++ b/src/amd/vulkan/nir/radv_nir_rt_common.h
@@ -18,8 +18,7 @@ nir_def *build_addr_to_node(struct radv_device *device, nir_builder *b, nir_def
 
 nir_def *nir_build_vec3_mat_mult(nir_builder *b, nir_def *vec, nir_def *matrix[], bool translation);
 
-nir_def *radv_load_vertex_position(struct radv_device *device, nir_builder *b, nir_def *instance_addr,
-                                   nir_def *geometry_id, nir_def *primitive_id, uint32_t index);
+nir_def *radv_load_vertex_position(struct radv_device *device, nir_builder *b, nir_def *primitive_addr, uint32_t index);
 
 void radv_load_wto_matrix(struct radv_device *device, nir_builder *b, nir_def *instance_addr, nir_def **out);
 
diff --git a/src/amd/vulkan/nir/radv_nir_rt_shader.c b/src/amd/vulkan/nir/radv_nir_rt_shader.c
index 086726052ecc2..89651e8f3e9b8 100644
--- a/src/amd/vulkan/nir/radv_nir_rt_shader.c
+++ b/src/amd/vulkan/nir/radv_nir_rt_shader.c
@@ -202,6 +202,7 @@ struct rt_variables {
    nir_variable *tmax;
 
    /* Properties of the primitive currently being visited. */
+   nir_variable *primitive_addr;
    nir_variable *primitive_id;
    nir_variable *geometry_id_and_flags;
    nir_variable *instance_addr;
@@ -253,6 +254,7 @@ create_rt_variables(nir_shader *shader, struct radv_device *device, const VkPipe
    vars.direction = nir_variable_create(shader, nir_var_shader_temp, vec3_type, "ray_direction");
    vars.tmax = nir_variable_create(shader, nir_var_shader_temp, glsl_float_type(), "ray_tmax");
 
+   vars.primitive_addr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "primitive_addr");
    vars.primitive_id = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "primitive_id");
    vars.geometry_id_and_flags =
       nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "geometry_id_and_flags");
@@ -299,6 +301,7 @@ map_rt_variables(struct hash_table *var_remap, struct rt_variables *src, const s
    _mesa_hash_table_insert(var_remap, src->direction, dst->direction);
    _mesa_hash_table_insert(var_remap, src->tmax, dst->tmax);
 
+   _mesa_hash_table_insert(var_remap, src->primitive_addr, dst->primitive_addr);
    _mesa_hash_table_insert(var_remap, src->primitive_id, dst->primitive_id);
    _mesa_hash_table_insert(var_remap, src->geometry_id_and_flags, dst->geometry_id_and_flags);
    _mesa_hash_table_insert(var_remap, src->instance_addr, dst->instance_addr);
@@ -322,6 +325,8 @@ create_inner_vars(nir_builder *b, const struct rt_variables *vars)
    inner_vars.idx = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_idx");
    inner_vars.shader_record_ptr =
       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_shader_record_ptr");
+   inner_vars.primitive_addr =
+      nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "inner_primitive_addr");
    inner_vars.primitive_id =
       nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_primitive_id");
    inner_vars.geometry_id_and_flags =
@@ -649,10 +654,11 @@ radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data)
    }
    case nir_intrinsic_execute_closest_hit_amd: {
       nir_store_var(b, vars->tmax, intr->src[1].ssa, 0x1);
-      nir_store_var(b, vars->primitive_id, intr->src[2].ssa, 0x1);
-      nir_store_var(b, vars->instance_addr, intr->src[3].ssa, 0x1);
-      nir_store_var(b, vars->geometry_id_and_flags, intr->src[4].ssa, 0x1);
-      nir_store_var(b, vars->hit_kind, intr->src[5].ssa, 0x1);
+      nir_store_var(b, vars->primitive_addr, intr->src[2].ssa, 0x1);
+      nir_store_var(b, vars->primitive_id, intr->src[3].ssa, 0x1);
+      nir_store_var(b, vars->instance_addr, intr->src[4].ssa, 0x1);
+      nir_store_var(b, vars->geometry_id_and_flags, intr->src[5].ssa, 0x1);
+      nir_store_var(b, vars->hit_kind, intr->src[6].ssa, 0x1);
       load_sbt_entry(b, vars, intr->src[0].ssa, SBT_HIT, SBT_RECURSIVE_PTR);
 
       nir_def *should_return =
@@ -689,11 +695,8 @@ radv_lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_data)
       break;
    }
    case nir_intrinsic_load_ray_triangle_vertex_positions: {
-      nir_def *instance_node_addr = nir_load_var(b, vars->instance_addr);
-      nir_def *primitive_id = nir_load_var(b, vars->primitive_id);
-      nir_def *geometry_id = nir_iand_imm(b, nir_load_var(b, vars->geometry_id_and_flags), 0xFFFFFFF);
-      ret = radv_load_vertex_position(vars->device, b, instance_node_addr, geometry_id, primitive_id,
-                                      nir_intrinsic_column(intr));
+      nir_def *primitive_addr = nir_load_var(b, vars->primitive_addr);
+      ret = radv_load_vertex_position(vars->device, b, primitive_addr, nir_intrinsic_column(intr));
       break;
    }
    default:
@@ -1385,6 +1388,7 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int
    {
       struct rt_variables inner_vars = create_inner_vars(b, data->vars);
 
+      nir_store_var(b, inner_vars.primitive_addr, intersection->base.node_addr, 1);
       nir_store_var(b, inner_vars.primitive_id, intersection->base.primitive_id, 1);
       nir_store_var(b, inner_vars.geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
       nir_store_var(b, inner_vars.tmax, intersection->t, 0x1);
@@ -1417,6 +1421,7 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int
    }
    nir_pop_if(b, NULL);
 
+   nir_store_var(b, data->vars->primitive_addr, intersection->base.node_addr, 1);
    nir_store_var(b, data->vars->primitive_id, intersection->base.primitive_id, 1);
    nir_store_var(b, data->vars->geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
    nir_store_var(b, data->vars->tmax, intersection->t, 0x1);
@@ -1449,6 +1454,7 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
     * next closest hit candidate. */
    inner_vars.hit_kind = data->vars->hit_kind;
 
+   nir_store_var(b, inner_vars.primitive_addr, intersection->node_addr, 1);
    nir_store_var(b, inner_vars.primitive_id, intersection->primitive_id, 1);
    nir_store_var(b, inner_vars.geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
    nir_store_var(b, inner_vars.tmax, nir_load_var(b, data->vars->tmax), 0x1);
@@ -1478,6 +1484,7 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
 
    nir_push_if(b, nir_load_var(b, data->vars->ahit_accept));
    {
+      nir_store_var(b, data->vars->primitive_addr, intersection->node_addr, 1);
       nir_store_var(b, data->vars->primitive_id, intersection->primitive_id, 1);
       nir_store_var(b, data->vars->geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
       nir_store_var(b, data->vars->tmax, nir_load_var(b, inner_vars.tmax), 0x1);
@@ -1658,7 +1665,14 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin
       } else {
          for (int i = 0; i < ARRAY_SIZE(hit_attribs); ++i)
             nir_store_hit_attrib_amd(b, nir_load_var(b, hit_attribs[i]), .base = i);
-         nir_execute_closest_hit_amd(b, nir_load_var(b, vars->idx), nir_load_var(b, vars->tmax),
+
+         nir_def *primitive_addr;
+         if (info->has_position_fetch)
+            primitive_addr = nir_load_var(b, vars->primitive_addr);
+         else
+            primitive_addr = nir_undef(b, 1, 64);
+
+         nir_execute_closest_hit_amd(b, nir_load_var(b, vars->idx), nir_load_var(b, vars->tmax), primitive_addr,
                                      nir_load_var(b, vars->primitive_id), nir_load_var(b, vars->instance_addr),
                                      nir_load_var(b, vars->geometry_id_and_flags), nir_load_var(b, vars->hit_kind));
       }
@@ -1905,7 +1919,8 @@ void
 radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                       const struct radv_shader_args *args, const struct radv_shader_info *info, uint32_t *stack_size,
                       bool resume_shader, struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
-                      bool monolithic, const struct radv_ray_tracing_stage_info *traversal_info)
+                      bool monolithic, bool has_position_fetch,
+                      const struct radv_ray_tracing_stage_info *traversal_info)
 {
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
 
@@ -1984,6 +1999,8 @@ radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKH
    else
       nir_store_var(&b, vars.miss_index, ac_nir_load_arg(&b, &args->ac, args->ac.rt.miss_index), 0x1);
 
+   nir_def *primitive_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.primitive_addr);
+   nir_store_var(&b, vars.primitive_addr, nir_pack_64_2x32(&b, primitive_addr), 1);
    nir_store_var(&b, vars.primitive_id, ac_nir_load_arg(&b, &args->ac, args->ac.rt.primitive_id), 1);
    nir_def *instance_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.instance_addr);
    nir_store_var(&b, vars.instance_addr, nir_pack_64_2x32(&b, instance_addr), 1);
@@ -2056,6 +2073,9 @@ radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKH
       radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_direction, nir_load_var(&b, vars.direction));
       radv_store_arg(&b, args, traversal_info, args->ac.rt.ray_tmax, nir_load_var(&b, vars.tmax));
 
+      if (has_position_fetch)
+         radv_store_arg(&b, args, traversal_info, args->ac.rt.primitive_addr, nir_load_var(&b, vars.primitive_addr));
+
       radv_store_arg(&b, args, traversal_info, args->ac.rt.primitive_id, nir_load_var(&b, vars.primitive_id));
       radv_store_arg(&b, args, traversal_info, args->ac.rt.instance_addr, nir_load_var(&b, vars.instance_addr));
       radv_store_arg(&b, args, traversal_info, args->ac.rt.geometry_id_and_flags,
diff --git a/src/amd/vulkan/radv_acceleration_structure.c b/src/amd/vulkan/radv_acceleration_structure.c
index 4a76a4a15e2f4..51b9c26b4e791 100644
--- a/src/amd/vulkan/radv_acceleration_structure.c
+++ b/src/amd/vulkan/radv_acceleration_structure.c
@@ -50,7 +50,6 @@ static const uint32_t leaf_spv[] = {
 
 struct acceleration_structure_layout {
    uint32_t geometry_info_offset;
-   uint32_t primitive_base_indices_offset;
    uint32_t leaf_node_offsets_offset;
    uint32_t bvh_offset;
    uint32_t leaf_nodes_offset;
@@ -126,11 +125,6 @@ radv_get_acceleration_structure_layout(struct radv_device *device,
       offset += sizeof(struct radv_accel_struct_geometry_info) * state->build_info->geometryCount;
    }
 
-   if (device->vk.enabled_features.rayTracingPositionFetch && geometry_type == VK_GEOMETRY_TYPE_TRIANGLES_KHR) {
-      accel_struct->primitive_base_indices_offset = offset;
-      offset += sizeof(uint32_t) * state->build_info->geometryCount;
-   }
-
    /* On GFX12, we need additional space for leaf node offsets since they do not have the same
     * order as the application provided data.
     */
@@ -665,7 +659,6 @@ radv_init_header(VkCommandBuffer commandBuffer, const struct vk_acceleration_str
    header.build_flags = state->build_info->flags;
    header.geometry_type = vk_get_as_geometry_type(state->build_info);
    header.geometry_count = state->build_info->geometryCount;
-   header.primitive_base_indices_offset = layout.primitive_base_indices_offset;
 
    radv_update_memory_cp(cmd_buffer, vk_acceleration_structure_get_va(dst) + base, (const char *)&header + base,
                          sizeof(header) - base);
@@ -690,27 +683,6 @@ radv_init_header(VkCommandBuffer commandBuffer, const struct vk_acceleration_str
 
       free(geometry_infos);
    }
-
-   VkGeometryTypeKHR geometry_type = vk_get_as_geometry_type(state->build_info);
-   if (device->vk.enabled_features.rayTracingPositionFetch && geometry_type == VK_GEOMETRY_TYPE_TRIANGLES_KHR) {
-      uint32_t base_indices_size = sizeof(uint32_t) * state->build_info->geometryCount;
-      uint32_t *base_indices = malloc(base_indices_size);
-      if (!base_indices) {
-         vk_command_buffer_set_error(&cmd_buffer->vk, VK_ERROR_OUT_OF_HOST_MEMORY);
-         return;
-      }
-
-      uint32_t base_index = 0;
-      for (uint32_t i = 0; i < state->build_info->geometryCount; i++) {
-         base_indices[i] = base_index;
-         base_index += state->build_range_infos[i].primitiveCount;
-      }
-
-      radv_CmdUpdateBuffer(commandBuffer, vk_buffer_to_handle(dst->buffer),
-                           dst->offset + layout.primitive_base_indices_offset, base_indices_size, base_indices);
-
-      free(base_indices);
-   }
 }
 
 static void
diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c
index 597b2ffd49cb0..d5586836c11b9 100644
--- a/src/amd/vulkan/radv_pipeline_rt.c
+++ b/src/amd/vulkan/radv_pipeline_rt.c
@@ -364,7 +364,7 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
                    struct radv_ray_tracing_stage_info *stage_info,
                    const struct radv_ray_tracing_stage_info *traversal_stage_info,
                    struct radv_serialized_shader_arena_block *replay_block, bool skip_shaders_cache,
-                   struct radv_shader **out_shader)
+                   bool has_position_fetch, struct radv_shader **out_shader)
 {
    struct radv_physical_device *pdev = radv_device_physical(device);
    struct radv_instance *instance = radv_physical_device_instance(pdev);
@@ -426,7 +426,7 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
       struct radv_shader_stage temp_stage = *stage;
       temp_stage.nir = shaders[i];
       radv_nir_lower_rt_abi(temp_stage.nir, pCreateInfo, &temp_stage.args, &stage->info, stack_size, i > 0, device,
-                            pipeline, monolithic, traversal_stage_info);
+                            pipeline, monolithic, has_position_fetch, traversal_stage_info);
 
       /* Info might be out-of-date after inlining in radv_nir_lower_rt_abi(). */
       nir_shader_gather_info(temp_stage.nir, nir_shader_get_entrypoint(temp_stage.nir));
@@ -547,6 +547,9 @@ radv_gather_ray_tracing_stage_info(nir_shader *nir)
             continue;
 
          nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
+         if (intr->intrinsic == nir_intrinsic_load_ray_triangle_vertex_positions)
+            info.has_position_fetch = true;
+
          if (intr->intrinsic != nir_intrinsic_trace_ray)
             continue;
 
@@ -632,10 +635,13 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca
    }
 
    bool has_callable = false;
+   /* Libraries cannot know how they are used so we need to asssume that position fetch is used. */
+   bool has_position_fetch = library;
    /* TODO: Recompile recursive raygen shaders instead. */
    bool raygen_imported = false;
    for (uint32_t i = 0; i < pipeline->stage_count; i++) {
       has_callable |= rt_stages[i].stage == MESA_SHADER_CALLABLE;
+      has_position_fetch |= rt_stages[i].info.has_position_fetch;
       monolithic &= rt_stages[i].info.can_inline;
 
       if (i >= pCreateInfo->stageCount)
@@ -691,9 +697,9 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca
 
          bool monolithic_raygen = monolithic && stage->stage == MESA_SHADER_RAYGEN;
 
-         result =
-            radv_rt_nir_to_asm(device, cache, pCreateInfo, pipeline, monolithic_raygen, stage, &stack_size,
-                               &rt_stages[idx].info, NULL, replay_block, skip_shaders_cache, &rt_stages[idx].shader);
+         result = radv_rt_nir_to_asm(device, cache, pCreateInfo, pipeline, monolithic_raygen, stage, &stack_size,
+                                     &rt_stages[idx].info, NULL, replay_block, skip_shaders_cache, has_position_fetch,
+                                     &rt_stages[idx].shader);
          if (result != VK_SUCCESS)
             goto cleanup;
 
@@ -720,6 +726,7 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca
    struct radv_ray_tracing_stage_info traversal_info = {
       .set_flags = 0xFFFFFFFF,
       .unset_flags = 0xFFFFFFFF,
+      .has_position_fetch = has_position_fetch,
    };
 
    memset(traversal_info.unused_args, 0xFF, sizeof(traversal_info.unused_args));
@@ -750,9 +757,9 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca
       .key = stage_keys[MESA_SHADER_INTERSECTION],
    };
    radv_shader_layout_init(pipeline_layout, MESA_SHADER_INTERSECTION, &traversal_stage.layout);
-   result =
-      radv_rt_nir_to_asm(device, cache, pCreateInfo, pipeline, false, &traversal_stage, NULL, NULL, &traversal_info,
-                         NULL, skip_shaders_cache, &pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]);
+   result = radv_rt_nir_to_asm(device, cache, pCreateInfo, pipeline, false, &traversal_stage, NULL, NULL,
+                               &traversal_info, NULL, skip_shaders_cache, has_position_fetch,
+                               &pipeline->base.base.shaders[MESA_SHADER_INTERSECTION]);
    ralloc_free(traversal_nir);
 
 cleanup:
diff --git a/src/amd/vulkan/radv_pipeline_rt.h b/src/amd/vulkan/radv_pipeline_rt.h
index ad40801169d93..18bd2b3edb617 100644
--- a/src/amd/vulkan/radv_pipeline_rt.h
+++ b/src/amd/vulkan/radv_pipeline_rt.h
@@ -74,6 +74,7 @@ struct radv_rt_const_arg_info {
 
 struct radv_ray_tracing_stage_info {
    bool can_inline;
+   bool has_position_fetch;
 
    BITSET_DECLARE(unused_args, AC_MAX_ARGS);
 
diff --git a/src/amd/vulkan/radv_shader.h b/src/amd/vulkan/radv_shader.h
index d74bb9d60b7c6..fa51e5ba9b806 100644
--- a/src/amd/vulkan/radv_shader.h
+++ b/src/amd/vulkan/radv_shader.h
@@ -527,7 +527,7 @@ struct radv_ray_tracing_stage_info;
 void radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                            const struct radv_shader_args *args, const struct radv_shader_info *info,
                            uint32_t *stack_size, bool resume_shader, struct radv_device *device,
-                           struct radv_ray_tracing_pipeline *pipeline, bool monolithic,
+                           struct radv_ray_tracing_pipeline *pipeline, bool monolithic, bool has_position_fetch,
                            const struct radv_ray_tracing_stage_info *traversal_info);
 
 void radv_gather_unused_args(struct radv_ray_tracing_stage_info *info, nir_shader *nir);
diff --git a/src/amd/vulkan/radv_shader_args.c b/src/amd/vulkan/radv_shader_args.c
index a8ce3c09f81f4..a63e5c67c6820 100644
--- a/src/amd/vulkan/radv_shader_args.c
+++ b/src/amd/vulkan/radv_shader_args.c
@@ -350,6 +350,7 @@ radv_declare_rt_shader_args(enum amd_gfx_level gfx_level, struct radv_shader_arg
    ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.rt.miss_index);
 
    ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_CONST_PTR, &args->ac.rt.instance_addr);
+   ac_add_arg(&args->ac, AC_ARG_VGPR, 2, AC_ARG_CONST_PTR, &args->ac.rt.primitive_addr);
    ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.rt.primitive_id);
    ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.rt.geometry_id_and_flags);
    ac_add_arg(&args->ac, AC_ARG_VGPR, 1, AC_ARG_INT, &args->ac.rt.hit_kind);
diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py
index e88f6390f9c4d..59b3b90c5e62e 100644
--- a/src/compiler/nir/nir_intrinsics.py
+++ b/src/compiler/nir/nir_intrinsics.py
@@ -1852,11 +1852,12 @@ system_value("cull_mask_and_flags_amd", 1)
 
 #   0. SBT Index
 #   1. Ray Tmax
-#   2. Primitive Id
-#   3. Instance Addr
-#   4. Geometry Id and Flags
-#   5. Hit Kind
-intrinsic("execute_closest_hit_amd", src_comp=[1, 1, 1, 1, 1, 1])
+#   2. Primitive Addr
+#   3. Primitive Id
+#   4. Instance Addr
+#   5. Geometry Id and Flags
+#   6. Hit Kind
+intrinsic("execute_closest_hit_amd", src_comp=[1, 1, 1, 1, 1, 1, 1])
 
 #   0. Ray Tmax
 intrinsic("execute_miss_amd", src_comp=[1])
-- 
GitLab