Diffstat (limited to 'src/amd/vulkan/nir/radv_nir_lower_ray_queries.c')
-rw-r--r--  src/amd/vulkan/nir/radv_nir_lower_ray_queries.c  728
1 file changed, 728 insertions, 0 deletions
diff --git a/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c b/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c
new file mode 100644
index 00000000000..15a120d8090
--- /dev/null
+++ b/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c
@@ -0,0 +1,728 @@
+/*
+ * Copyright © 2022 Konstantin Seurer
+ *
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "nir/nir.h"
+#include "nir/nir_builder.h"
+
+#include "util/hash_table.h"
+
+#include "bvh/bvh.h"
+#include "nir/radv_nir_rt_common.h"
+#include "radv_debug.h"
+#include "radv_nir.h"
+#include "radv_shader.h"
+
+/* Traversal stack size. Traversal supports backtracking so we can go deeper than this size if
+ * needed. However, we keep a large stack size to avoid it being put into registers, which hurts
+ * occupancy. */
+#define MAX_SCRATCH_STACK_ENTRY_COUNT 76
+
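+/* A lowered ray query variable. When the source rayQueryEXT is declared as an array, the backing
+ * nir_variable is an array of array_length elements and the rq_* accessors below index into it. */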
+typedef struct {
+ nir_variable *variable;
+ unsigned array_length;
+} rq_variable;
+
+static rq_variable *
+rq_variable_create(void *ctx, nir_shader *shader, unsigned array_length, const struct glsl_type *type, const char *name)
+{
+ rq_variable *result = ralloc(ctx, rq_variable);
+ result->array_length = array_length;
+
+ const struct glsl_type *variable_type = type;
+ if (array_length != 1)
+ variable_type = glsl_array_type(type, array_length, glsl_get_explicit_stride(type));
+
+ result->variable = nir_variable_create(shader, nir_var_shader_temp, variable_type, name);
+
+ return result;
+}
+
+static nir_def *
+nir_load_array(nir_builder *b, nir_variable *array, nir_def *index)
+{
+ return nir_load_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, array), index));
+}
+
+static void
+nir_store_array(nir_builder *b, nir_variable *array, nir_def *index, nir_def *value, unsigned writemask)
+{
+ nir_store_deref(b, nir_build_deref_array(b, nir_build_deref_var(b, array), index), value, writemask);
+}
+
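+/* Load/store/deref helpers that hide the difference between a single ray query and an array of
+ * ray queries: with array_length == 1 they access the variable directly, otherwise they index
+ * the per-query array element. */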
+static nir_deref_instr *
+rq_deref_var(nir_builder *b, nir_def *index, rq_variable *var)
+{
+ if (var->array_length == 1)
+ return nir_build_deref_var(b, var->variable);
+
+ return nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index);
+}
+
+static nir_def *
+rq_load_var(nir_builder *b, nir_def *index, rq_variable *var)
+{
+ if (var->array_length == 1)
+ return nir_load_var(b, var->variable);
+
+ return nir_load_array(b, var->variable, index);
+}
+
+static void
+rq_store_var(nir_builder *b, nir_def *index, rq_variable *var, nir_def *value, unsigned writemask)
+{
+ if (var->array_length == 1) {
+ nir_store_var(b, var->variable, value, writemask);
+ } else {
+ nir_store_array(b, var->variable, index, value, writemask);
+ }
+}
+
+static void
+rq_copy_var(nir_builder *b, nir_def *index, rq_variable *dst, rq_variable *src, unsigned mask)
+{
+ rq_store_var(b, index, dst, rq_load_var(b, index, src), mask);
+}
+
+static nir_def *
+rq_load_array(nir_builder *b, nir_def *index, rq_variable *var, nir_def *array_index)
+{
+ if (var->array_length == 1)
+ return nir_load_array(b, var->variable, array_index);
+
+ return nir_load_deref(
+ b, nir_build_deref_array(b, nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index), array_index));
+}
+
+static void
+rq_store_array(nir_builder *b, nir_def *index, rq_variable *var, nir_def *array_index, nir_def *value,
+ unsigned writemask)
+{
+ if (var->array_length == 1) {
+ nir_store_array(b, var->variable, array_index, value, writemask);
+ } else {
+ nir_store_deref(
+ b,
+ nir_build_deref_array(b, nir_build_deref_array(b, nir_build_deref_var(b, var->variable), index), array_index),
+ value, writemask);
+ }
+}
+
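+/* Per-query traversal state. lower_rq_proceed() exposes these as the derefs consumed by
+ * radv_build_ray_traversal(). */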
+struct ray_query_traversal_vars {
+ rq_variable *origin;
+ rq_variable *direction;
+
+ rq_variable *bvh_base;
+ rq_variable *stack;
+ rq_variable *top_stack;
+ rq_variable *stack_low_watermark;
+ rq_variable *current_node;
+ rq_variable *previous_node;
+ rq_variable *instance_top_node;
+ rq_variable *instance_bottom_node;
+};
+
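+/* State of a single (candidate or committed) intersection. */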
+struct ray_query_intersection_vars {
+ rq_variable *primitive_id;
+ rq_variable *geometry_id_and_flags;
+ rq_variable *instance_addr;
+ rq_variable *intersection_type;
+ rq_variable *opaque;
+ rq_variable *frontface;
+ rq_variable *sbt_offset_and_flags;
+ rq_variable *barycentrics;
+ rq_variable *t;
+};
+
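+/* Complete lowered state of one ray query: the parameters captured at rq_initialize, the
+ * candidate and committed ("closest") intersections, the traversal state and the traversal
+ * stack, which lives either in a scratch array ("stack" != NULL) or in shared memory starting
+ * at "shared_base". */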
+struct ray_query_vars {
+ rq_variable *root_bvh_base;
+ rq_variable *flags;
+ rq_variable *cull_mask;
+ rq_variable *origin;
+ rq_variable *tmin;
+ rq_variable *direction;
+
+ rq_variable *incomplete;
+
+ struct ray_query_intersection_vars closest;
+ struct ray_query_intersection_vars candidate;
+
+ struct ray_query_traversal_vars trav;
+
+ rq_variable *stack;
+ uint32_t shared_base;
+ uint32_t stack_entries;
+
+ nir_intrinsic_instr *initialize;
+};
+
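+/* Builds "<base_name><name>" in a fresh ralloc'ed string, e.g. a query named "rq" gets
+ * variables like "rq_flags" and "rq_closest_t". */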
+#define VAR_NAME(name) strcat(strcpy(ralloc_size(ctx, strlen(base_name) + strlen(name) + 1), base_name), name)
+
+static struct ray_query_traversal_vars
+init_ray_query_traversal_vars(void *ctx, nir_shader *shader, unsigned array_length, const char *base_name)
+{
+ struct ray_query_traversal_vars result;
+
+ const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
+
+ result.origin = rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_origin"));
+ result.direction = rq_variable_create(ctx, shader, array_length, vec3_type, VAR_NAME("_direction"));
+
+ result.bvh_base = rq_variable_create(ctx, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_bvh_base"));
+ result.stack = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_stack"));
+ result.top_stack = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_top_stack"));
+ result.stack_low_watermark =
+ rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_stack_low_watermark"));
+ result.current_node = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_current_node"));
+ result.previous_node = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_previous_node"));
+ result.instance_top_node =
+ rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_instance_top_node"));
+ result.instance_bottom_node =
+ rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_instance_bottom_node"));
+ return result;
+}
+
+static struct ray_query_intersection_vars
+init_ray_query_intersection_vars(void *ctx, nir_shader *shader, unsigned array_length, const char *base_name)
+{
+ struct ray_query_intersection_vars result;
+
+ const struct glsl_type *vec2_type = glsl_vector_type(GLSL_TYPE_FLOAT, 2);
+
+ result.primitive_id = rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_primitive_id"));
+ result.geometry_id_and_flags =
+ rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_geometry_id_and_flags"));
+ result.instance_addr =
+ rq_variable_create(ctx, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_instance_addr"));
+ result.intersection_type =
+ rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_intersection_type"));
+ result.opaque = rq_variable_create(ctx, shader, array_length, glsl_bool_type(), VAR_NAME("_opaque"));
+ result.frontface = rq_variable_create(ctx, shader, array_length, glsl_bool_type(), VAR_NAME("_frontface"));
+ result.sbt_offset_and_flags =
+ rq_variable_create(ctx, shader, array_length, glsl_uint_type(), VAR_NAME("_sbt_offset_and_flags"));
+ result.barycentrics = rq_variable_create(ctx, shader, array_length, vec2_type, VAR_NAME("_barycentrics"));
+ result.t = rq_variable_create(ctx, shader, array_length, glsl_float_type(), VAR_NAME("_t"));
+
+ return result;
+}
+
+static void
+init_ray_query_vars(nir_shader *shader, unsigned array_length, struct ray_query_vars *dst, const char *base_name,
+ uint32_t max_shared_size)
+{
+ void *ctx = dst;
+ const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
+
+ dst->root_bvh_base = rq_variable_create(dst, shader, array_length, glsl_uint64_t_type(), VAR_NAME("_root_bvh_base"));
+ dst->flags = rq_variable_create(dst, shader, array_length, glsl_uint_type(), VAR_NAME("_flags"));
+ dst->cull_mask = rq_variable_create(dst, shader, array_length, glsl_uint_type(), VAR_NAME("_cull_mask"));
+ dst->origin = rq_variable_create(dst, shader, array_length, vec3_type, VAR_NAME("_origin"));
+ dst->tmin = rq_variable_create(dst, shader, array_length, glsl_float_type(), VAR_NAME("_tmin"));
+ dst->direction = rq_variable_create(dst, shader, array_length, vec3_type, VAR_NAME("_direction"));
+
+ dst->incomplete = rq_variable_create(dst, shader, array_length, glsl_bool_type(), VAR_NAME("_incomplete"));
+
+ dst->closest = init_ray_query_intersection_vars(dst, shader, array_length, VAR_NAME("_closest"));
+ dst->candidate = init_ray_query_intersection_vars(dst, shader, array_length, VAR_NAME("_candidate"));
+
+ dst->trav = init_ray_query_traversal_vars(dst, shader, array_length, VAR_NAME("_top"));
+
+ uint32_t workgroup_size =
+ shader->info.workgroup_size[0] * shader->info.workgroup_size[1] * shader->info.workgroup_size[2];
+ uint32_t shared_stack_entries = shader->info.ray_queries == 1 ? 16 : 8;
+ uint32_t shared_stack_size = workgroup_size * shared_stack_entries * 4;
+ uint32_t shared_offset = align(shader->info.shared_size, 4);
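+ /* Use a shared memory stack only for compute shaders with a single, non-arrayed ray query and
+  * only while it fits into the shared memory budget; otherwise fall back to a scratch array. */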
+ if (shader->info.stage != MESA_SHADER_COMPUTE || array_length > 1 ||
+ shared_offset + shared_stack_size > max_shared_size) {
+ dst->stack =
+ rq_variable_create(dst, shader, array_length,
+ glsl_array_type(glsl_uint_type(), MAX_SCRATCH_STACK_ENTRY_COUNT, 0), VAR_NAME("_stack"));
+ dst->stack_entries = MAX_SCRATCH_STACK_ENTRY_COUNT;
+ } else {
+ dst->stack = NULL;
+ dst->shared_base = shared_offset;
+ dst->stack_entries = shared_stack_entries;
+
+ shader->info.shared_size = shared_offset + shared_stack_size;
+ }
+}
+
+#undef VAR_NAME
+
+static void
+lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table *ht, uint32_t max_shared_size)
+{
+ struct ray_query_vars *vars = ralloc(ht, struct ray_query_vars);
+
+ unsigned array_length = 1;
+ if (glsl_type_is_array(ray_query->type))
+ array_length = glsl_get_length(ray_query->type);
+
+ init_ray_query_vars(shader, array_length, vars, ray_query->name == NULL ? "" : ray_query->name, max_shared_size);
+
+ _mesa_hash_table_insert(ht, ray_query, vars);
+}
+
+static void
+copy_candidate_to_closest(nir_builder *b, nir_def *index, struct ray_query_vars *vars)
+{
+ rq_copy_var(b, index, vars->closest.barycentrics, vars->candidate.barycentrics, 0x3);
+ rq_copy_var(b, index, vars->closest.geometry_id_and_flags, vars->candidate.geometry_id_and_flags, 0x1);
+ rq_copy_var(b, index, vars->closest.instance_addr, vars->candidate.instance_addr, 0x1);
+ rq_copy_var(b, index, vars->closest.intersection_type, vars->candidate.intersection_type, 0x1);
+ rq_copy_var(b, index, vars->closest.opaque, vars->candidate.opaque, 0x1);
+ rq_copy_var(b, index, vars->closest.frontface, vars->candidate.frontface, 0x1);
+ rq_copy_var(b, index, vars->closest.sbt_offset_and_flags, vars->candidate.sbt_offset_and_flags, 0x1);
+ rq_copy_var(b, index, vars->closest.primitive_id, vars->candidate.primitive_id, 0x1);
+ rq_copy_var(b, index, vars->closest.t, vars->candidate.t, 0x1);
+}
+
+static void
+insert_terminate_on_first_hit(nir_builder *b, nir_def *index, struct ray_query_vars *vars,
+ const struct radv_ray_flags *ray_flags, bool break_on_terminate)
+{
+ nir_def *terminate_on_first_hit;
+ if (ray_flags)
+ terminate_on_first_hit = ray_flags->terminate_on_first_hit;
+ else
+ terminate_on_first_hit =
+ nir_test_mask(b, rq_load_var(b, index, vars->flags), SpvRayFlagsTerminateOnFirstHitKHRMask);
+ nir_push_if(b, terminate_on_first_hit);
+ {
+ rq_store_var(b, index, vars->incomplete, nir_imm_false(b), 0x1);
+ if (break_on_terminate)
+ nir_jump(b, nir_jump_break);
+ }
+ nir_pop_if(b, NULL);
+}
+
+static void
+lower_rq_confirm_intersection(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars)
+{
+ copy_candidate_to_closest(b, index, vars);
+ insert_terminate_on_first_hit(b, index, vars, NULL, false);
+}
+
+static void
+lower_rq_generate_intersection(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars)
+{
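+ /* Commit the generated hit only if its T lies within the current [tmin, committed T] range. */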
+ nir_push_if(b, nir_iand(b, nir_fge(b, rq_load_var(b, index, vars->closest.t), instr->src[1].ssa),
+ nir_fge(b, instr->src[1].ssa, rq_load_var(b, index, vars->tmin))));
+ {
+ copy_candidate_to_closest(b, index, vars);
+ insert_terminate_on_first_hit(b, index, vars, NULL, false);
+ rq_store_var(b, index, vars->closest.t, instr->src[1].ssa, 0x1);
+ }
+ nir_pop_if(b, NULL);
+}
+
+enum rq_intersection_type { intersection_type_none, intersection_type_triangle, intersection_type_aabb };
+
+static void
+lower_rq_initialize(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars,
+ struct radv_instance *instance)
+{
+ rq_store_var(b, index, vars->flags, instr->src[2].ssa, 0x1);
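+ /* The 8-bit cull mask is stored pre-shifted into the top byte, presumably so it can be tested
+  * directly against the instance mask kept in the top byte of custom_instance_and_mask. */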
+ rq_store_var(b, index, vars->cull_mask, nir_ishl_imm(b, instr->src[3].ssa, 24), 0x1);
+
+ rq_store_var(b, index, vars->origin, instr->src[4].ssa, 0x7);
+ rq_store_var(b, index, vars->trav.origin, instr->src[4].ssa, 0x7);
+
+ rq_store_var(b, index, vars->tmin, instr->src[5].ssa, 0x1);
+
+ rq_store_var(b, index, vars->direction, instr->src[6].ssa, 0x7);
+ rq_store_var(b, index, vars->trav.direction, instr->src[6].ssa, 0x7);
+
+ rq_store_var(b, index, vars->closest.t, instr->src[7].ssa, 0x1);
+ rq_store_var(b, index, vars->closest.intersection_type, nir_imm_int(b, intersection_type_none), 0x1);
+
+ nir_def *accel_struct = instr->src[1].ssa;
+
+ /* Make sure that instance data loads don't hang in case of a miss by setting a valid initial address. */
+ rq_store_var(b, index, vars->closest.instance_addr, accel_struct, 1);
+ rq_store_var(b, index, vars->candidate.instance_addr, accel_struct, 1);
+
+ nir_def *bvh_offset = nir_build_load_global(
+ b, 1, 32, nir_iadd_imm(b, accel_struct, offsetof(struct radv_accel_struct_header, bvh_offset)),
+ .access = ACCESS_NON_WRITEABLE);
+ nir_def *bvh_base = nir_iadd(b, accel_struct, nir_u2u64(b, bvh_offset));
+ bvh_base = build_addr_to_node(b, bvh_base);
+
+ rq_store_var(b, index, vars->root_bvh_base, bvh_base, 0x1);
+ rq_store_var(b, index, vars->trav.bvh_base, bvh_base, 1);
+
+ if (vars->stack) {
+ rq_store_var(b, index, vars->trav.stack, nir_imm_int(b, 0), 0x1);
+ rq_store_var(b, index, vars->trav.stack_low_watermark, nir_imm_int(b, 0), 0x1);
+ } else {
+ nir_def *base_offset = nir_imul_imm(b, nir_load_local_invocation_index(b), sizeof(uint32_t));
+ base_offset = nir_iadd_imm(b, base_offset, vars->shared_base);
+ rq_store_var(b, index, vars->trav.stack, base_offset, 0x1);
+ rq_store_var(b, index, vars->trav.stack_low_watermark, base_offset, 0x1);
+ }
+
+ rq_store_var(b, index, vars->trav.current_node, nir_imm_int(b, RADV_BVH_ROOT_NODE), 0x1);
+ rq_store_var(b, index, vars->trav.previous_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
+ rq_store_var(b, index, vars->trav.instance_top_node, nir_imm_int(b, RADV_BVH_INVALID_NODE), 0x1);
+ rq_store_var(b, index, vars->trav.instance_bottom_node, nir_imm_int(b, RADV_BVH_NO_INSTANCE_ROOT), 0x1);
+
+ rq_store_var(b, index, vars->trav.top_stack, nir_imm_int(b, -1), 1);
+
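+ /* With RADV_DEBUG_NO_RT the query starts out already complete, so rq_proceed returns false
+  * immediately and no traversal is performed. */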
+ rq_store_var(b, index, vars->incomplete, nir_imm_bool(b, !(instance->debug_flags & RADV_DEBUG_NO_RT)), 0x1);
+
+ vars->initialize = instr;
+}
+
+static nir_def *
+lower_rq_load(struct radv_device *device, nir_builder *b, nir_def *index, nir_intrinsic_instr *instr,
+ struct ray_query_vars *vars)
+{
+ bool committed = nir_intrinsic_committed(instr);
+ struct ray_query_intersection_vars *intersection = committed ? &vars->closest : &vars->candidate;
+
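+ /* "column" selects the matrix column for the object/world matrix queries and is passed through
+  * to radv_load_vertex_position() for the triangle vertex position query. */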
+ uint32_t column = nir_intrinsic_column(instr);
+
+ nir_ray_query_value value = nir_intrinsic_ray_query_value(instr);
+ switch (value) {
+ case nir_ray_query_value_flags:
+ return rq_load_var(b, index, vars->flags);
+ case nir_ray_query_value_intersection_barycentrics:
+ return rq_load_var(b, index, intersection->barycentrics);
+ case nir_ray_query_value_intersection_candidate_aabb_opaque:
+ return nir_iand(b, rq_load_var(b, index, vars->candidate.opaque),
+ nir_ieq_imm(b, rq_load_var(b, index, vars->candidate.intersection_type), intersection_type_aabb));
+ case nir_ray_query_value_intersection_front_face:
+ return rq_load_var(b, index, intersection->frontface);
+ case nir_ray_query_value_intersection_geometry_index:
+ return nir_iand_imm(b, rq_load_var(b, index, intersection->geometry_id_and_flags), 0xFFFFFF);
+ case nir_ray_query_value_intersection_instance_custom_index: {
+ nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
+ return nir_iand_imm(
+ b,
+ nir_build_load_global(
+ b, 1, 32,
+ nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, custom_instance_and_mask))),
+ 0xFFFFFF);
+ }
+ case nir_ray_query_value_intersection_instance_id: {
+ nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
+ return nir_build_load_global(
+ b, 1, 32, nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, instance_id)));
+ }
+ case nir_ray_query_value_intersection_instance_sbt_index:
+ return nir_iand_imm(b, rq_load_var(b, index, intersection->sbt_offset_and_flags), 0xFFFFFF);
+ case nir_ray_query_value_intersection_object_ray_direction: {
+ nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
+ nir_def *wto_matrix[3];
+ nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
+ return nir_build_vec3_mat_mult(b, rq_load_var(b, index, vars->direction), wto_matrix, false);
+ }
+ case nir_ray_query_value_intersection_object_ray_origin: {
+ nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
+ nir_def *wto_matrix[3];
+ nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
+ return nir_build_vec3_mat_mult(b, rq_load_var(b, index, vars->origin), wto_matrix, true);
+ }
+ case nir_ray_query_value_intersection_object_to_world: {
+ nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
+ nir_def *rows[3];
+ for (unsigned r = 0; r < 3; ++r)
+ rows[r] = nir_build_load_global(
+ b, 4, 32,
+ nir_iadd_imm(b, instance_node_addr, offsetof(struct radv_bvh_instance_node, otw_matrix) + r * 16));
+
+ return nir_vec3(b, nir_channel(b, rows[0], column), nir_channel(b, rows[1], column),
+ nir_channel(b, rows[2], column));
+ }
+ case nir_ray_query_value_intersection_primitive_index:
+ return rq_load_var(b, index, intersection->primitive_id);
+ case nir_ray_query_value_intersection_t:
+ return rq_load_var(b, index, intersection->t);
+ case nir_ray_query_value_intersection_type: {
+ nir_def *intersection_type = rq_load_var(b, index, intersection->intersection_type);
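+ /* Candidate intersection types are numbered one below the committed ones (triangle = 0,
+  * AABB = 1), so undo the offset of the internal enum. */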
+ if (!committed)
+ intersection_type = nir_iadd_imm(b, intersection_type, -1);
+
+ return intersection_type;
+ }
+ case nir_ray_query_value_intersection_world_to_object: {
+ nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
+
+ nir_def *wto_matrix[3];
+ nir_build_wto_matrix_load(b, instance_node_addr, wto_matrix);
+
+ nir_def *vals[3];
+ for (unsigned i = 0; i < 3; ++i)
+ vals[i] = nir_channel(b, wto_matrix[i], column);
+
+ return nir_vec(b, vals, 3);
+ }
+ case nir_ray_query_value_tmin:
+ return rq_load_var(b, index, vars->tmin);
+ case nir_ray_query_value_world_ray_direction:
+ return rq_load_var(b, index, vars->direction);
+ case nir_ray_query_value_world_ray_origin:
+ return rq_load_var(b, index, vars->origin);
+ case nir_ray_query_value_intersection_triangle_vertex_positions: {
+ nir_def *instance_node_addr = rq_load_var(b, index, intersection->instance_addr);
+ nir_def *primitive_id = rq_load_var(b, index, intersection->primitive_id);
+ return radv_load_vertex_position(device, b, instance_node_addr, primitive_id, nir_intrinsic_column(instr));
+ }
+ default:
+ unreachable("Invalid nir_ray_query_value!");
+ }
+
+ return NULL;
+}
+
+struct traversal_data {
+ struct ray_query_vars *vars;
+ nir_def *index;
+};
+
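+/* Traversal callbacks: record the candidate intersection and break out of the traversal so the
+ * shader can inspect it via rq_load. Opaque triangle hits are committed right away and, with the
+ * terminate-on-first-hit flag, also end the query. */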
+static void
+handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersection,
+ const struct radv_ray_traversal_args *args)
+{
+ struct traversal_data *data = args->data;
+ struct ray_query_vars *vars = data->vars;
+ nir_def *index = data->index;
+
+ rq_store_var(b, index, vars->candidate.primitive_id, intersection->primitive_id, 1);
+ rq_store_var(b, index, vars->candidate.geometry_id_and_flags, intersection->geometry_id_and_flags, 1);
+ rq_store_var(b, index, vars->candidate.opaque, intersection->opaque, 0x1);
+ rq_store_var(b, index, vars->candidate.intersection_type, nir_imm_int(b, intersection_type_aabb), 0x1);
+
+ nir_jump(b, nir_jump_break);
+}
+
+static void
+handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *intersection,
+ const struct radv_ray_traversal_args *args, const struct radv_ray_flags *ray_flags)
+{
+ struct traversal_data *data = args->data;
+ struct ray_query_vars *vars = data->vars;
+ nir_def *index = data->index;
+
+ rq_store_var(b, index, vars->candidate.barycentrics, intersection->barycentrics, 3);
+ rq_store_var(b, index, vars->candidate.primitive_id, intersection->base.primitive_id, 1);
+ rq_store_var(b, index, vars->candidate.geometry_id_and_flags, intersection->base.geometry_id_and_flags, 1);
+ rq_store_var(b, index, vars->candidate.t, intersection->t, 0x1);
+ rq_store_var(b, index, vars->candidate.opaque, intersection->base.opaque, 0x1);
+ rq_store_var(b, index, vars->candidate.frontface, intersection->frontface, 0x1);
+ rq_store_var(b, index, vars->candidate.intersection_type, nir_imm_int(b, intersection_type_triangle), 0x1);
+
+ nir_push_if(b, intersection->base.opaque);
+ {
+ copy_candidate_to_closest(b, index, vars);
+ insert_terminate_on_first_hit(b, index, vars, ray_flags, true);
+ }
+ nir_push_else(b, NULL);
+ {
+ nir_jump(b, nir_jump_break);
+ }
+ nir_pop_if(b, NULL);
+}
+
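+/* Stack callbacks for radv_build_ray_traversal(): entries go either to the per-query scratch
+ * array or to the shared memory region reserved in init_ray_query_vars(). */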
+static void
+store_stack_entry(nir_builder *b, nir_def *index, nir_def *value, const struct radv_ray_traversal_args *args)
+{
+ struct traversal_data *data = args->data;
+ if (data->vars->stack)
+ rq_store_array(b, data->index, data->vars->stack, index, value, 1);
+ else
+ nir_store_shared(b, value, index, .base = 0, .align_mul = 4);
+}
+
+static nir_def *
+load_stack_entry(nir_builder *b, nir_def *index, const struct radv_ray_traversal_args *args)
+{
+ struct traversal_data *data = args->data;
+ if (data->vars->stack)
+ return rq_load_array(b, data->index, data->vars->stack, index);
+ else
+ return nir_load_shared(b, 1, 32, index, .base = 0, .align_mul = 4);
+}
+
+static nir_def *
+lower_rq_proceed(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars,
+ struct radv_device *device)
+{
+ nir_metadata_require(nir_cf_node_get_function(&instr->instr.block->cf_node), nir_metadata_dominance);
+
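+ /* If this rq_proceed is dominated by its rq_initialize and the cull mask is the constant 0xff,
+  * the traversal can skip the per-instance cull mask test. */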
+ bool ignore_cull_mask = false;
+ if (nir_block_dominates(vars->initialize->instr.block, instr->instr.block)) {
+ nir_src cull_mask = vars->initialize->src[3];
+ if (nir_src_is_const(cull_mask) && nir_src_as_uint(cull_mask) == 0xFF)
+ ignore_cull_mask = true;
+ }
+
+ nir_variable *inv_dir = nir_local_variable_create(b->impl, glsl_vector_type(GLSL_TYPE_FLOAT, 3), "inv_dir");
+ nir_store_var(b, inv_dir, nir_frcp(b, rq_load_var(b, index, vars->trav.direction)), 0x7);
+
+ struct radv_ray_traversal_vars trav_vars = {
+ .tmax = rq_deref_var(b, index, vars->closest.t),
+ .origin = rq_deref_var(b, index, vars->trav.origin),
+ .dir = rq_deref_var(b, index, vars->trav.direction),
+ .inv_dir = nir_build_deref_var(b, inv_dir),
+ .bvh_base = rq_deref_var(b, index, vars->trav.bvh_base),
+ .stack = rq_deref_var(b, index, vars->trav.stack),
+ .top_stack = rq_deref_var(b, index, vars->trav.top_stack),
+ .stack_low_watermark = rq_deref_var(b, index, vars->trav.stack_low_watermark),
+ .current_node = rq_deref_var(b, index, vars->trav.current_node),
+ .previous_node = rq_deref_var(b, index, vars->trav.previous_node),
+ .instance_top_node = rq_deref_var(b, index, vars->trav.instance_top_node),
+ .instance_bottom_node = rq_deref_var(b, index, vars->trav.instance_bottom_node),
+ .instance_addr = rq_deref_var(b, index, vars->candidate.instance_addr),
+ .sbt_offset_and_flags = rq_deref_var(b, index, vars->candidate.sbt_offset_and_flags),
+ };
+
+ struct traversal_data data = {
+ .vars = vars,
+ .index = index,
+ };
+
+ struct radv_ray_traversal_args args = {
+ .root_bvh_base = rq_load_var(b, index, vars->root_bvh_base),
+ .flags = rq_load_var(b, index, vars->flags),
+ .cull_mask = rq_load_var(b, index, vars->cull_mask),
+ .origin = rq_load_var(b, index, vars->origin),
+ .tmin = rq_load_var(b, index, vars->tmin),
+ .dir = rq_load_var(b, index, vars->direction),
+ .vars = trav_vars,
+ .stack_entries = vars->stack_entries,
+ .ignore_cull_mask = ignore_cull_mask,
+ .stack_store_cb = store_stack_entry,
+ .stack_load_cb = load_stack_entry,
+ .aabb_cb = handle_candidate_aabb,
+ .triangle_cb = handle_candidate_triangle,
+ .data = &data,
+ };
+
+ if (vars->stack) {
+ args.stack_stride = 1;
+ args.stack_base = 0;
+ } else {
+ uint32_t workgroup_size =
+ b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] * b->shader->info.workgroup_size[2];
+ args.stack_stride = workgroup_size * 4;
+ args.stack_base = vars->shared_base;
+ }
+
+ nir_push_if(b, rq_load_var(b, index, vars->incomplete));
+ {
+ nir_def *incomplete = radv_build_ray_traversal(device, b, &args);
+ rq_store_var(b, index, vars->incomplete, nir_iand(b, rq_load_var(b, index, vars->incomplete), incomplete), 1);
+ }
+ nir_pop_if(b, NULL);
+
+ return rq_load_var(b, index, vars->incomplete);
+}
+
+static void
+lower_rq_terminate(nir_builder *b, nir_def *index, nir_intrinsic_instr *instr, struct ray_query_vars *vars)
+{
+ rq_store_var(b, index, vars->incomplete, nir_imm_false(b), 0x1);
+}
+
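+/* Replaces every rayQueryEXT variable (global and function-local) with the lowered variable set
+ * above and rewrites all ray query intrinsics in terms of it. For a (hypothetical) query such as
+ *
+ *    rayQueryEXT rq;
+ *    rayQueryInitializeEXT(rq, tlas, gl_RayFlagsOpaqueEXT, 0xff, origin, tmin, dir, tmax);
+ *    while (rayQueryProceedEXT(rq))
+ *       ...
+ *
+ * the initialize/proceed calls reach this pass as nir_intrinsic_rq_initialize/rq_proceed and are
+ * expanded by lower_rq_initialize() and lower_rq_proceed() above. */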
+bool
+radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device)
+{
+ const struct radv_physical_device *pdev = radv_device_physical(device);
+ struct radv_instance *instance = radv_physical_device_instance(pdev);
+ bool progress = false;
+ struct hash_table *query_ht = _mesa_pointer_hash_table_create(NULL);
+
+ nir_foreach_variable_in_list (var, &shader->variables) {
+ if (!var->data.ray_query)
+ continue;
+
+ lower_ray_query(shader, var, query_ht, pdev->max_shared_size);
+
+ progress = true;
+ }
+
+ nir_foreach_function (function, shader) {
+ if (!function->impl)
+ continue;
+
+ nir_builder builder = nir_builder_create(function->impl);
+
+ nir_foreach_variable_in_list (var, &function->impl->locals) {
+ if (!var->data.ray_query)
+ continue;
+
+ lower_ray_query(shader, var, query_ht, pdev->max_shared_size);
+
+ progress = true;
+ }
+
+ nir_foreach_block (block, function->impl) {
+ nir_foreach_instr_safe (instr, block) {
+ if (instr->type != nir_instr_type_intrinsic)
+ continue;
+
+ nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
+
+ if (!nir_intrinsic_is_ray_query(intrinsic->intrinsic))
+ continue;
+
+ nir_deref_instr *ray_query_deref = nir_instr_as_deref(intrinsic->src[0].ssa->parent_instr);
+ nir_def *index = NULL;
+
+ if (ray_query_deref->deref_type == nir_deref_type_array) {
+ index = ray_query_deref->arr.index.ssa;
+ ray_query_deref = nir_instr_as_deref(ray_query_deref->parent.ssa->parent_instr);
+ }
+
+ assert(ray_query_deref->deref_type == nir_deref_type_var);
+
+ struct ray_query_vars *vars =
+ (struct ray_query_vars *)_mesa_hash_table_search(query_ht, ray_query_deref->var)->data;
+
+ builder.cursor = nir_before_instr(instr);
+
+ nir_def *new_dest = NULL;
+
+ switch (intrinsic->intrinsic) {
+ case nir_intrinsic_rq_confirm_intersection:
+ lower_rq_confirm_intersection(&builder, index, intrinsic, vars);
+ break;
+ case nir_intrinsic_rq_generate_intersection:
+ lower_rq_generate_intersection(&builder, index, intrinsic, vars);
+ break;
+ case nir_intrinsic_rq_initialize:
+ lower_rq_initialize(&builder, index, intrinsic, vars, instance);
+ break;
+ case nir_intrinsic_rq_load:
+ new_dest = lower_rq_load(device, &builder, index, intrinsic, vars);
+ break;
+ case nir_intrinsic_rq_proceed:
+ new_dest = lower_rq_proceed(&builder, index, intrinsic, vars, device);
+ break;
+ case nir_intrinsic_rq_terminate:
+ lower_rq_terminate(&builder, index, intrinsic, vars);
+ break;
+ default:
+ unreachable("Unsupported ray query intrinsic!");
+ }
+
+ if (new_dest)
+ nir_def_rewrite_uses(&intrinsic->def, new_dest);
+
+ nir_instr_remove(instr);
+ nir_instr_free(instr);
+
+ progress = true;
+ }
+ }
+
+ nir_metadata_preserve(function->impl, nir_metadata_none);
+ }
+
+ ralloc_free(query_ht);
+
+ return progress;
+}