diff options
-rw-r--r-- | src/compiler/nir/nir.h | 1 | ||||
-rw-r--r-- | src/compiler/nir/nir_lower_int64.c | 62 |
2 files changed, 63 insertions, 0 deletions
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 7e4352049c3..f84ebbec85e 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -3071,6 +3071,7 @@ typedef enum { nir_lower_bit_count64 = (1 << 15), nir_lower_subgroup_shuffle64 = (1 << 16), nir_lower_scan_reduce_bitwise64 = (1 << 17), + nir_lower_scan_reduce_iadd64 = (1 << 18), } nir_lower_int64_options; typedef enum { diff --git a/src/compiler/nir/nir_lower_int64.c b/src/compiler/nir/nir_lower_int64.c index 6c5a2aea0a7..dbe6d1e315b 100644 --- a/src/compiler/nir/nir_lower_int64.c +++ b/src/compiler/nir/nir_lower_int64.c @@ -1110,6 +1110,64 @@ split_64bit_subgroup_op(nir_builder *b, const nir_intrinsic_instr *intrin) return nir_pack_64_2x32_split(b, res[0], res[1]); } +static nir_ssa_def * +build_scan_intrinsic(nir_builder *b, nir_intrinsic_op scan_op, + nir_op reduction_op, unsigned cluster_size, + nir_ssa_def *val) +{ + nir_intrinsic_instr *scan = + nir_intrinsic_instr_create(b->shader, scan_op); + scan->num_components = val->num_components; + scan->src[0] = nir_src_for_ssa(val); + nir_intrinsic_set_reduction_op(scan, reduction_op); + if (scan_op == nir_intrinsic_reduce) + nir_intrinsic_set_cluster_size(scan, cluster_size); + nir_ssa_dest_init(&scan->instr, &scan->dest, + val->num_components, val->bit_size, NULL); + nir_builder_instr_insert(b, &scan->instr); + return &scan->dest.ssa; +} + +static nir_ssa_def * +lower_scan_iadd64(nir_builder *b, const nir_intrinsic_instr *intrin) +{ + unsigned cluster_size = + intrin->intrinsic == nir_intrinsic_reduce ? + nir_intrinsic_cluster_size(intrin) : 0; + + /* Split it into three chunks of no more than 24 bits each. With 8 bits + * of headroom, we're guaranteed that there will never be overflow in the + * individual subgroup operations. (Assuming, of course, a subgroup size + * no larger than 256 which seems reasonable.) We can then scan on each of + * the chunks and add them back together at the end. + */ + assert(intrin->src[0].is_ssa); + nir_ssa_def *x = intrin->src[0].ssa; + nir_ssa_def *x_low = + nir_u2u32(b, nir_iand_imm(b, x, 0xffffff)); + nir_ssa_def *x_mid = + nir_u2u32(b, nir_iand_imm(b, nir_ushr(b, x, nir_imm_int(b, 24)), + 0xffffff)); + nir_ssa_def *x_hi = + nir_u2u32(b, nir_ushr(b, x, nir_imm_int(b, 48))); + + nir_ssa_def *scan_low = + build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd, + cluster_size, x_low); + nir_ssa_def *scan_mid = + build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd, + cluster_size, x_mid); + nir_ssa_def *scan_hi = + build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd, + cluster_size, x_hi); + + scan_low = nir_u2u64(b, scan_low); + scan_mid = nir_ishl(b, nir_u2u64(b, scan_mid), nir_imm_int(b, 24)); + scan_hi = nir_ishl(b, nir_u2u64(b, scan_hi), nir_imm_int(b, 48)); + + return nir_iadd(b, scan_hi, nir_iadd(b, scan_mid, scan_low)); +} + static bool should_lower_int64_intrinsic(const nir_intrinsic_instr *intrin, const nir_shader_compiler_options *options) @@ -1137,6 +1195,8 @@ should_lower_int64_intrinsic(const nir_intrinsic_instr *intrin, return false; switch (nir_intrinsic_reduction_op(intrin)) { + case nir_op_iadd: + return options->lower_int64_options & nir_lower_scan_reduce_iadd64; case nir_op_iand: case nir_op_ior: case nir_op_ixor: @@ -1171,6 +1231,8 @@ lower_int64_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin) case nir_intrinsic_inclusive_scan: case nir_intrinsic_exclusive_scan: switch (nir_intrinsic_reduction_op(intrin)) { + case nir_op_iadd: + return lower_scan_iadd64(b, intrin); case nir_op_iand: case nir_op_ior: case nir_op_ixor: |