summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRhys Perry <pendingchaos02@gmail.com>2019-12-05 14:12:39 +0000
committerMarge Bot <eric+marge@anholt.net>2020-10-29 18:08:31 +0000
commit176137948150d153c7756505fc78dcfb13511f83 (patch)
tree0e90977023c848f208ff2fbc614bd0dbfd80f204
parentecc5b59a7069ab080a892e3f6a413ef62d3afee2 (diff)
aco: handle SDWA in the optimizer
Apply SGPRs/modifiers when possible and try not to break when SDWA instructions are encountered. No shader-db changes. Signed-off-by: Rhys Perry <pendingchaos02@gmail.com> Reviewed-by: Daniel Schürmann <daniel@schuermann.dev> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/7349>
-rw-r--r--src/amd/compiler/aco_optimizer.cpp91
1 files changed, 69 insertions, 22 deletions
diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp
index 48cf5e345b5..9a8d6d9fe2f 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -619,8 +619,10 @@ bool can_use_VOP3(opt_ctx& ctx, const aco_ptr<Instruction>& instr)
instr->opcode != aco_opcode::v_readfirstlane_b32;
}
-bool can_apply_sgprs(aco_ptr<Instruction>& instr)
+bool can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
+ if (instr->isSDWA() && ctx.program->chip_class < GFX9)
+ return false;
return instr->opcode != aco_opcode::v_readfirstlane_b32 &&
instr->opcode != aco_opcode::v_readlane_b32 &&
instr->opcode != aco_opcode::v_readlane_b32_e64 &&
@@ -891,7 +893,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
info = ctx.info[info.temp.id()];
}
/* applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier */
- if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(instr) && instr->operands.size() == 1) {
+ if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(ctx, instr) && instr->operands.size() == 1) {
instr->operands[i].setTemp(info.temp);
info = ctx.info[info.temp.id()];
}
@@ -900,12 +902,19 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
unsigned can_use_mod = instr->opcode != aco_opcode::v_cndmask_b32 || instr->operands[i].getTemp().bytes() == 4;
can_use_mod = can_use_mod && instr_info.can_use_input_modifiers[(int)instr->opcode];
- if (info.is_abs() && (can_use_VOP3(ctx, instr) || instr->isDPP()) && can_use_mod) {
- if (!instr->isDPP())
+ if (instr->isSDWA())
+ can_use_mod = can_use_mod && (static_cast<SDWA_instruction*>(instr.get())->sel[i] & sdwa_asuint) == sdwa_udword;
+ else
+ can_use_mod = can_use_mod && (instr->isDPP() || can_use_VOP3(ctx, instr));
+
+ if (info.is_abs() && can_use_mod) {
+ if (!instr->isDPP() && !instr->isSDWA())
to_VOP3(ctx, instr);
instr->operands[i] = Operand(info.temp);
if (instr->isDPP())
static_cast<DPP_instruction*>(instr.get())->abs[i] = true;
+ else if (instr->isSDWA())
+ static_cast<SDWA_instruction*>(instr.get())->abs[i] = true;
else
static_cast<VOP3A_instruction*>(instr.get())->abs[i] = true;
}
@@ -917,12 +926,14 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
instr->opcode = i ? aco_opcode::v_sub_f16 : aco_opcode::v_subrev_f16;
instr->operands[i].setTemp(info.temp);
continue;
- } else if (info.is_neg() && (can_use_VOP3(ctx, instr) || instr->isDPP()) && can_use_mod) {
- if (!instr->isDPP())
+ } else if (info.is_neg() && can_use_mod) {
+ if (!instr->isDPP() && !instr->isSDWA())
to_VOP3(ctx, instr);
instr->operands[i].setTemp(info.temp);
if (instr->isDPP())
static_cast<DPP_instruction*>(instr.get())->neg[i] = true;
+ else if (instr->isSDWA())
+ static_cast<SDWA_instruction*>(instr.get())->neg[i] = true;
else
static_cast<VOP3A_instruction*>(instr.get())->neg[i] = true;
continue;
@@ -932,7 +943,8 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
(!instr->isSDWA() || ctx.program->chip_class >= GFX9)) {
Operand op = get_constant_op(ctx, info, bits);
perfwarn(ctx.program, instr->opcode == aco_opcode::v_cndmask_b32 && i == 2, "v_cndmask_b32 with a constant selector", instr.get());
- if (i == 0 || instr->opcode == aco_opcode::v_readlane_b32 || instr->opcode == aco_opcode::v_writelane_b32) {
+ if (i == 0 || instr->isSDWA() || instr->opcode == aco_opcode::v_readlane_b32 ||
+ instr->opcode == aco_opcode::v_writelane_b32) {
instr->operands[i] = op;
continue;
} else if (!instr->isVOP3() && can_swap_operands(instr)) {
@@ -1641,6 +1653,8 @@ bool combine_ordering_test(opt_ctx &ctx, aco_ptr<Instruction>& instr)
neg[i] = vop3->neg[0];
abs[i] = vop3->abs[0];
opsel |= (vop3->opsel & 1) << i;
+ } else if (op_instr[i]->isSDWA()) {
+ return false;
}
Temp op0 = op_instr[i]->operands[0].getTemp();
@@ -1715,6 +1729,8 @@ bool combine_comparison_ordering(opt_ctx &ctx, aco_ptr<Instruction>& instr)
Instruction *cmp = follow_operand(ctx, instr->operands[1], true);
if (!nan_test || !cmp)
return false;
+ if (nan_test->isSDWA() || cmp->isSDWA())
+ return false;
if (get_f32_cmp(cmp->opcode) == expected_nan_test)
std::swap(nan_test, cmp);
@@ -1785,6 +1801,8 @@ bool combine_constant_comparison_ordering(opt_ctx &ctx, aco_ptr<Instruction>& in
if (!nan_test || !cmp)
return false;
+ if (nan_test->isSDWA() || cmp->isSDWA())
+ return false;
aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
if (get_f32_cmp(cmp->opcode) == expected_nan_test)
@@ -1906,6 +1924,18 @@ bool combine_inverse_comparison(opt_ctx &ctx, aco_ptr<Instruction>& instr)
new_vop3->omod = cmp_vop3->omod;
new_vop3->opsel = cmp_vop3->opsel;
new_instr = new_vop3;
+ } else if (cmp->isSDWA()) {
+ SDWA_instruction *new_sdwa = create_instruction<SDWA_instruction>(
+ new_opcode, (Format)((uint16_t)Format::SDWA | (uint16_t)Format::VOPC), 2, 1);
+ SDWA_instruction *cmp_sdwa = static_cast<SDWA_instruction*>(cmp);
+ memcpy(new_sdwa->abs, cmp_sdwa->abs, sizeof(new_sdwa->abs));
+ memcpy(new_sdwa->sel, cmp_sdwa->sel, sizeof(new_sdwa->sel));
+ memcpy(new_sdwa->neg, cmp_sdwa->neg, sizeof(new_sdwa->neg));
+ new_sdwa->dst_sel = cmp_sdwa->dst_sel;
+ new_sdwa->dst_preserve = cmp_sdwa->dst_preserve;
+ new_sdwa->clamp = cmp_sdwa->clamp;
+ new_sdwa->omod = cmp_sdwa->omod;
+ new_instr = new_sdwa;
} else {
new_instr = create_instruction<VOPC_instruction>(new_opcode, Format::VOPC, 2, 1);
}
@@ -1942,6 +1972,9 @@ bool match_op3_for_vop3(opt_ctx &ctx, aco_opcode op1, aco_opcode op2,
VOP3A_instruction *op1_vop3 = op1_instr->isVOP3() ? static_cast<VOP3A_instruction *>(op1_instr) : NULL;
VOP3A_instruction *op2_vop3 = op2_instr->isVOP3() ? static_cast<VOP3A_instruction *>(op2_instr) : NULL;
+ if (op1_instr->isSDWA() || op2_instr->isSDWA())
+ return false;
+
/* don't support inbetween clamp/omod */
if (op2_vop3 && (op2_vop3->clamp || op2_vop3->omod))
return false;
@@ -2431,7 +2464,7 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr<Instruction>& instr)
/* Applying two sgprs require making it VOP3, so don't do it unless it's
* definitively beneficial.
* TODO: this is too conservative because later the use count could be reduced to 1 */
- if (num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3())
+ if (num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3() && !instr->isSDWA())
break;
Temp sgpr = ctx.info[sgpr_info_id].temp;
@@ -2439,7 +2472,7 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr<Instruction>& instr)
if (new_sgpr && num_sgprs >= max_sgprs)
continue;
- if (sgpr_idx == 0 || instr->isVOP3()) {
+ if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA()) {
instr->operands[sgpr_idx] = Operand(sgpr);
} else if (can_swap_operands(instr)) {
instr->operands[sgpr_idx] = instr->operands[0];
@@ -2461,22 +2494,20 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr<Instruction>& instr)
}
}
-bool apply_omod_clamp_helper(opt_ctx &ctx, aco_ptr<Instruction>& instr, ssa_info& def_info)
+template <typename T>
+bool apply_omod_clamp_helper(opt_ctx &ctx, T *instr, ssa_info& def_info)
{
- to_VOP3(ctx, instr);
-
- if (!def_info.is_clamp() && (static_cast<VOP3A_instruction*>(instr.get())->clamp ||
- static_cast<VOP3A_instruction*>(instr.get())->omod))
+ if (!def_info.is_clamp() && (instr->clamp || instr->omod))
return false;
if (def_info.is_omod2())
- static_cast<VOP3A_instruction*>(instr.get())->omod = 1;
+ instr->omod = 1;
else if (def_info.is_omod4())
- static_cast<VOP3A_instruction*>(instr.get())->omod = 2;
+ instr->omod = 2;
else if (def_info.is_omod5())
- static_cast<VOP3A_instruction*>(instr.get())->omod = 3;
+ instr->omod = 3;
else if (def_info.is_clamp())
- static_cast<VOP3A_instruction*>(instr.get())->clamp = true;
+ instr->clamp = true;
return true;
}
@@ -2488,11 +2519,14 @@ bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
!instr_info.can_use_output_modifiers[(int)instr->opcode])
return false;
- if (!can_use_VOP3(ctx, instr))
+ bool can_vop3 = can_use_VOP3(ctx, instr);
+ if (!instr->isSDWA() && !can_vop3)
return false;
/* omod has no effect if denormals are enabled */
bool can_use_omod = (instr->definitions[0].bytes() == 4 ? block.fp_mode.denorm32 : block.fp_mode.denorm16_64) == 0;
+ can_use_omod = can_use_omod && (can_vop3 || ctx.program->chip_class >= GFX9); /* SDWA omod is GFX9+ */
+
ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5;
@@ -2506,8 +2540,14 @@ bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
/* MADs/FMAs are created later, so we don't have to update the original add */
assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
- if (!apply_omod_clamp_helper(ctx, instr, def_info))
- return false;
+ if (instr->isSDWA()) {
+ if (!apply_omod_clamp_helper(ctx, static_cast<SDWA_instruction *>(instr.get()), def_info))
+ return false;
+ } else {
+ to_VOP3(ctx, instr);
+ if (!apply_omod_clamp_helper(ctx, static_cast<VOP3A_instruction *>(instr.get()), def_info))
+ return false;
+ }
std::swap(instr->definitions[0], def_info.instr->definitions[0]);
ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
@@ -2525,7 +2565,7 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
return;
if (instr->isVALU()) {
- if (can_apply_sgprs(instr))
+ if (can_apply_sgprs(ctx, instr))
apply_sgprs(ctx, instr);
while (apply_omod_clamp(ctx, block, instr)) ;
}
@@ -2534,6 +2574,9 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
instr->definitions[0].setHint(vcc);
}
+ if (instr->isSDWA())
+ return;
+
/* TODO: There are still some peephole optimizations that could be done:
* - abs(a - b) -> s_absdiff_i32
* - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32
@@ -2557,6 +2600,8 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
return;
if (mul_instr->isVOP3() && static_cast<VOP3A_instruction*>(mul_instr)->clamp)
return;
+ if (mul_instr->isSDWA())
+ return;
/* convert to mul(neg(a), b) */
ctx.uses[mul_instr->definitions[0].tempId()]--;
@@ -2639,6 +2684,8 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
// TODO: would be better to check this before selecting a mul instr?
if (!check_vop3_operands(ctx, 3, op))
return;
+ if (mul_instr->isSDWA())
+ return;
if (mul_instr->isVOP3()) {
VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*> (mul_instr);