summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/compiler/nir/nir.h1
-rw-r--r--src/compiler/nir/nir_lower_subgroups.c48
2 files changed, 49 insertions, 0 deletions
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 5b28c727c80..40b10283f4f 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -2557,6 +2557,7 @@ typedef struct nir_lower_subgroups_options {
uint8_t ballot_bit_size;
bool lower_to_scalar:1;
bool lower_vote_trivial:1;
+ bool lower_vote_eq_to_ballot:1;
bool lower_subgroup_masks:1;
bool lower_shuffle:1;
bool lower_quad:1;
diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c
index f18ad00c370..0d3c83b7951 100644
--- a/src/compiler/nir/nir_lower_subgroups.c
+++ b/src/compiler/nir/nir_lower_subgroups.c
@@ -142,6 +142,51 @@ lower_vote_eq_to_scalar(nir_builder *b, nir_intrinsic_instr *intrin)
}
static nir_ssa_def *
+lower_vote_eq_to_ballot(nir_builder *b, nir_intrinsic_instr *intrin,
+ const nir_lower_subgroups_options *options)
+{
+ assert(intrin->src[0].is_ssa);
+ nir_ssa_def *value = intrin->src[0].ssa;
+
+ /* We have to implicitly lower to scalar */
+ nir_ssa_def *all_eq = NULL;
+ for (unsigned i = 0; i < intrin->num_components; i++) {
+ nir_intrinsic_instr *rfi =
+ nir_intrinsic_instr_create(b->shader,
+ nir_intrinsic_read_first_invocation);
+ nir_ssa_dest_init(&rfi->instr, &rfi->dest,
+ 1, value->bit_size, NULL);
+ rfi->num_components = 1;
+ rfi->src[0] = nir_src_for_ssa(nir_channel(b, value, i));
+ nir_builder_instr_insert(b, &rfi->instr);
+
+ nir_ssa_def *is_eq;
+ if (intrin->intrinsic == nir_intrinsic_vote_feq) {
+ is_eq = nir_feq(b, &rfi->dest.ssa, nir_channel(b, value, i));
+ } else {
+ is_eq = nir_ieq(b, &rfi->dest.ssa, nir_channel(b, value, i));
+ }
+
+ if (all_eq == NULL) {
+ all_eq = is_eq;
+ } else {
+ all_eq = nir_iand(b, all_eq, is_eq);
+ }
+ }
+
+ nir_intrinsic_instr *ballot =
+ nir_intrinsic_instr_create(b->shader, nir_intrinsic_ballot);
+ nir_ssa_dest_init(&ballot->instr, &ballot->dest,
+ 1, options->ballot_bit_size, NULL);
+ ballot->num_components = 1;
+ ballot->src[0] = nir_src_for_ssa(nir_inot(b, all_eq));
+ nir_builder_instr_insert(b, &ballot->instr);
+
+ return nir_ieq(b, &ballot->dest.ssa,
+ nir_imm_intN_t(b, 0, options->ballot_bit_size));
+}
+
+static nir_ssa_def *
lower_shuffle(nir_builder *b, nir_intrinsic_instr *intrin,
bool lower_to_scalar)
{
@@ -219,6 +264,9 @@ lower_subgroups_intrin(nir_builder *b, nir_intrinsic_instr *intrin,
if (options->lower_vote_trivial)
return nir_imm_int(b, NIR_TRUE);
+ if (options->lower_vote_eq_to_ballot)
+ return lower_vote_eq_to_ballot(b, intrin, options);
+
if (options->lower_to_scalar && intrin->num_components > 1)
return lower_vote_eq_to_scalar(b, intrin);
break;