1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
|
From d6afc517e727aebc8675f9f1e4047d1c5f8bd4b2 Mon Sep 17 00:00:00 2001
From: Georg Lehmann <dadschoorse@gmail.com>
Date: Wed, 9 Apr 2025 14:17:47 +0200
Subject: [PATCH 08/13] aco: select f2e4m3fn
---
src/amd/compiler/aco_ir.cpp | 2 ++
.../instruction_selection/aco_isel_setup.cpp | 1 +
.../aco_select_nir_alu.cpp | 17 +++++++++++++++++
3 files changed, 20 insertions(+)
diff --git a/src/amd/compiler/aco_ir.cpp b/src/amd/compiler/aco_ir.cpp
index b29efb4efdf..68512146641 100644
--- a/src/amd/compiler/aco_ir.cpp
+++ b/src/amd/compiler/aco_ir.cpp
@@ -582,6 +582,8 @@ can_use_opsel(amd_gfx_level gfx_level, aco_opcode op, int idx)
case aco_opcode::v_interp_p10_rtz_f16_f32_inreg: return idx == 0 || idx == 2;
case aco_opcode::v_interp_p2_f16_f32_inreg:
case aco_opcode::v_interp_p2_rtz_f16_f32_inreg: return idx == -1 || idx == 0;
+ case aco_opcode::v_cvt_pk_fp8_f32:
+ case aco_opcode::v_cvt_pk_bf8_f32: return idx == -1;
default:
return gfx_level >= GFX11 && (get_gfx11_true16_mask(op) & BITFIELD_BIT(idx == -1 ? 3 : idx));
}
diff --git a/src/amd/compiler/instruction_selection/aco_isel_setup.cpp b/src/amd/compiler/instruction_selection/aco_isel_setup.cpp
index cc15635ddc0..9dbd27e8812 100644
--- a/src/amd/compiler/instruction_selection/aco_isel_setup.cpp
+++ b/src/amd/compiler/instruction_selection/aco_isel_setup.cpp
@@ -412,6 +412,7 @@ init_context(isel_context* ctx, nir_shader* shader)
regclasses[alu_instr->src[0].src.ssa->index].type() == RegType::vgpr)
type = RegType::vgpr;
break;
+ case nir_op_f2e4m3fn:
case nir_op_fmulz:
case nir_op_ffmaz:
case nir_op_f2f64:
diff --git a/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp b/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp
index aaeef4cb619..8d5ac777629 100644
--- a/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp
+++ b/src/amd/compiler/instruction_selection/aco_select_nir_alu.cpp
@@ -2474,6 +2474,23 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
bld.vop1(aco_opcode::v_cvt_f64_f32, Definition(dst), src);
break;
}
+ case nir_op_f2e4m3fn: {
+ Operand src0, src1;
+ if (instr->def.num_components == 2) {
+ Temp src = get_ssa_temp(ctx, instr->src[0].src.ssa);
+ RegClass rc = RegClass(src.regClass().type(), 1);
+ src0 = Operand(emit_extract_vector(ctx, src, instr->src[0].swizzle[0], rc));
+ src1 = Operand(emit_extract_vector(ctx, src, instr->src[0].swizzle[1], rc));
+ } else {
+ assert(instr->def.num_components == 1);
+ src0 = Operand(get_alu_src(ctx, instr->src[0]));
+ src1 = Operand::c32(0);
+ }
+ bld.vop3(aco_opcode::v_cvt_pk_fp8_f32, Definition(dst), src0, src1);
+ if (instr->def.num_components == 2)
+ emit_split_vector(ctx, dst, 2);
+ break;
+ }
case nir_op_i2f16: {
Temp src = get_alu_src(ctx, instr->src[0]);
const unsigned input_size = instr->src[0].src.ssa->bit_size;
--
2.49.0
|