summaryrefslogtreecommitdiff
path: root/src/intel/compiler/brw_nir_rt.c
blob: 5943c5283a0e8f774db8d8ba69ebf8c8336c748a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
/*
 * Copyright © 2020 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "brw_nir_rt.h"
#include "brw_nir_rt_builder.h"

/* Forces the SSA destination of `deref` to `num_components` components of
 * `bit_size` bits each.
 *
 * When the bit size changes on an array or ptr_as_array deref, the array
 * index source is rewritten to a value of the new bit size as well (a fresh
 * immediate for constant indices, an i2i conversion otherwise).
 *
 * Returns true if the deref was modified, false if it already had the
 * requested size.
 */
static bool
resize_deref(nir_builder *b, nir_deref_instr *deref,
             unsigned num_components, unsigned bit_size)
{
   assert(deref->dest.is_ssa);

   nir_ssa_def *def = &deref->dest.ssa;
   const bool bit_size_changes = def->bit_size != bit_size;

   if (!bit_size_changes && def->num_components == num_components)
      return false;

   const bool has_array_index =
      deref->deref_type == nir_deref_type_array ||
      deref->deref_type == nir_deref_type_ptr_as_array;

   /* NIR requires that array indices match the bit size of the deref */
   if (bit_size_changes && has_array_index) {
      b->cursor = nir_before_instr(&deref->instr);
      assert(deref->arr.index.is_ssa);

      nir_ssa_def *new_index =
         nir_src_is_const(deref->arr.index) ?
         nir_imm_intN_t(b, nir_src_as_int(deref->arr.index), bit_size) :
         nir_i2i(b, deref->arr.index.ssa, bit_size);

      nir_instr_rewrite_src(&deref->instr, &deref->arr.index,
                            nir_src_for_ssa(new_index));
   }

   def->num_components = num_components;
   def->bit_size = bit_size;

   return true;
}

static bool
lower_rt_io_derefs(nir_shader *shader)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   bool progress = false;

   unsigned num_shader_call_vars = 0;
   nir_foreach_variable_with_modes(var, shader, nir_var_shader_call_data)
      num_shader_call_vars++;

   unsigned num_ray_hit_attrib_vars = 0;
   nir_foreach_variable_with_modes(var, shader, nir_var_ray_hit_attrib)
      num_ray_hit_attrib_vars++;

   /* At most one payload is allowed because it's an input.  Technically, this
    * is also true for hit attribute variables.  However, after we inline an
    * any-hit shader into an intersection shader, we can end up with multiple
    * hit attribute variables.  They'll end up mapping to a cast from the same
    * base pointer so this is fine.
    */
   assert(num_shader_call_vars <= 1);

   nir_builder b;
   nir_builder_init(&b, impl);

   b.cursor = nir_before_cf_list(&impl->body);
   nir_ssa_def *call_data_addr = NULL;
   if (num_shader_call_vars > 0) {
      assert(shader->scratch_size >= BRW_BTD_STACK_CALLEE_DATA_SIZE);
      call_data_addr =
         brw_nir_rt_load_scratch(&b, BRW_BTD_STACK_CALL_DATA_PTR_OFFSET, 8,
                                 1, 64);
      progress = true;
   }

   gl_shader_stage stage = shader->info.stage;
   nir_ssa_def *hit_attrib_addr = NULL;
   if (num_ray_hit_attrib_vars > 0) {
      assert(stage == MESA_SHADER_ANY_HIT ||
             stage == MESA_SHADER_CLOSEST_HIT ||
             stage == MESA_SHADER_INTERSECTION);
      nir_ssa_def *hit_addr =
         brw_nir_rt_mem_hit_addr(&b, stage == MESA_SHADER_CLOSEST_HIT);
      /* The vec2 barycentrics are in 2nd and 3rd dwords of MemHit */
      nir_ssa_def *bary_addr = nir_iadd_imm(&b, hit_addr, 4);
      hit_attrib_addr = nir_bcsel(&b, nir_load_leaf_procedural_intel(&b),
                                      brw_nir_rt_hit_attrib_data_addr(&b),
                                      bary_addr);
      progress = true;
   }

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_deref)
            continue;

         nir_deref_instr *deref = nir_instr_as_deref(instr);
         if (nir_deref_mode_is(deref, nir_var_shader_call_data)) {
            deref->modes = nir_var_function_temp;
            if (deref->deref_type == nir_deref_type_var) {
               b.cursor = nir_before_instr(&deref->instr);
               nir_deref_instr *cast =
                  nir_build_deref_cast(&b, call_data_addr,
                                       nir_var_function_temp,
                                       deref->var->type, 0);
               nir_ssa_def_rewrite_uses(&deref->dest.ssa,
                                        &cast->dest.ssa);
               nir_instr_remove(&deref->instr);
               progress = true;
            }
         } else if (nir_deref_mode_is(deref, nir_var_ray_hit_attrib)) {
            deref->modes = nir_var_function_temp;
            if (deref->deref_type == nir_deref_type_var) {
               b.cursor = nir_before_instr(&deref->instr);
               nir_deref_instr *cast =
                  nir_build_deref_cast(&b, hit_attrib_addr,
                                       nir_var_function_temp,
                                       deref->type, 0);
               nir_ssa_def_rewrite_uses(&deref->dest.ssa,
                                        &cast->dest.ssa);
               nir_instr_remove(&deref->instr);
               progress = true;
            }
         }

         /* We're going to lower all function_temp memory to scratch using
          * 64-bit addresses.  We need to resize all our derefs first or else
          * nir_lower_explicit_io will have a fit.
          */
         if (nir_deref_mode_is(deref, nir_var_function_temp) &&
             resize_deref(&b, deref, 1, 64))
            progress = true;
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}

/** Lowers ray-tracing shader I/O and scratch access
 *
 * SPV_KHR_ray_tracing introduces three new kinds of I/O, each needing its
 * own bit of special care:
 *
 *  - Shader payload data:  The IncomingCallableData and IncomingRayPayload
 *    storage classes, both of which map to nir_var_shader_call_data in NIR.
 *    There is at most one of these per shader and it holds the payload data
 *    passed down the stack from the parent shader when it calls
 *    executeCallable() or traceRay().  In our implementation, the actual
 *    storage lives in the calling shader's scratch space and we're passed a
 *    pointer to it.
 *
 *  - Hit attribute data:  The HitAttribute storage class in SPIR-V, which
 *    maps to nir_var_ray_hit_attrib in NIR.  For triangle geometry, it's
 *    supposed to contain two floats which are the barycentric coordinates.
 *    For AABB/procedural geometry, it contains the hit data written out by
 *    the intersection shader.  In our implementation, it's a 64-bit pointer
 *    which points either to the u/v area of the relevant MemHit data
 *    structure or to the space right after the HW ray stack entry.
 *
 *  - Shader record buffer data:  Read-only access to the data stored in the
 *    SBT right after the bindless shader handles.  It's effectively a UBO
 *    with a magic address.  Coming out of spirv_to_nir, we get a
 *    nir_intrinsic_load_shader_record_ptr which is cast to a
 *    nir_var_mem_global deref and all access happens through that.  The
 *    shader_record_ptr system value is handled in
 *    brw_nir_lower_rt_intrinsics and we assume nir_lower_explicit_io is
 *    called elsewhere thanks to VK_KHR_buffer_device_address so there's
 *    really nothing to do here.
 *
 * Any remaining function_temp variables are lowered to scratch here as well.
 * This gets rid of any remaining arrays and also takes care of the sending
 * side of ray payloads, where we pass pointers to a function_temp variable
 * down the call stack.
 */
static void
lower_rt_io_and_scratch(nir_shader *nir)
{
   const nir_variable_mode explicit_type_modes =
      nir_var_function_temp |
      nir_var_shader_call_data |
      nir_var_ray_hit_attrib;

   const nir_variable_mode explicit_io_modes =
      nir_var_function_temp |
      nir_var_mem_constant |
      nir_var_ray_hit_attrib;

   /* Step 1: make sure all the I/O variables have explicit types.  These
    * variables are shader-internal and don't come in from outside, so they
    * have no explicit memory layout yet and we have to assign them one.
    */
   NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
              explicit_type_modes,
              glsl_get_natural_size_align_bytes);

   /* Step 2: patch any derefs to I/O vars */
   NIR_PASS_V(nir, lower_rt_io_derefs);

   /* Step 3: lower any remaining function_temp, mem_constant, or
    * ray_hit_attrib access to 64-bit global memory access.
    */
   NIR_PASS_V(nir, nir_lower_explicit_io,
              explicit_io_modes,
              nir_address_format_64bit_global);
}

/* Emits the code that ends the ray walk at the builder's cursor.
 *
 * If the ray flags have BRW_RT_RAY_FLAG_SKIP_CLOSEST_HIT_SHADER set, the
 * shader simply does a BTD return.  Otherwise the hit is committed and the
 * closest-hit shader is spawned.  Both paths end in a halt jump.
 */
static void
build_terminate_ray(nir_builder *b)
{
   nir_ssa_def *ray_flags = nir_load_ray_flags(b);
   nir_ssa_def *skip_bit =
      nir_iand_imm(b, ray_flags, BRW_RT_RAY_FLAG_SKIP_CLOSEST_HIT_SHADER);
   nir_ssa_def *skip_closest_hit = nir_i2b(b, skip_bit);

   nir_push_if(b, skip_closest_hit);
   {
      /* The shader that calls traceRay() can only observe ray hit
       * information that shaders invoked during the trace explicitly wrote
       * into the ray payload.  Without a closest-hit shader, accepting the
       * hit has no observable effect; it would just be extra memory traffic
       * for no reason.
       */
      brw_nir_btd_return(b);
      nir_jump(b, nir_jump_halt);
   }
   nir_push_else(b, NULL);
   {
      /* The closest hit shader is in the same shader group as the any-hit
       * shader that we're currently in.  We can get the address for its SBT
       * handle by looking at the shader record pointer and subtracting the
       * size of a SBT handle.  The BINDLESS_SHADER_RECORD for a closest hit
       * shader is the first one in the SBT handle.
       */
      nir_ssa_def *record_ptr = nir_load_shader_record_ptr(b);
      nir_ssa_def *closest_hit =
         nir_iadd_imm(b, record_ptr, -BRW_RT_SBT_HANDLE_SIZE);

      brw_nir_rt_commit_hit(b);
      brw_nir_btd_spawn(b, closest_hit);
      nir_jump(b, nir_jump_halt);
   }
   nir_pop_if(b, NULL);
}

/** Lowers away ray walk intrinsics
 *
 * Replaces the terminate_ray, ignore_ray_intersection, and NIR-specific
 * accept_ray_intersection intrinsics found in the shader entrypoint with the
 * appropriate Intel-specific intrinsics.  Only valid for any-hit and
 * intersection shaders.
 *
 * Returns true if at least one intrinsic was lowered.
 */
static bool
lower_ray_walk_intrinsics(nir_shader *shader,
                          const struct intel_device_info *devinfo)
{
   assert(shader->info.stage == MESA_SHADER_ANY_HIT ||
          shader->info.stage == MESA_SHADER_INTERSECTION);

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);

   nir_builder b;
   nir_builder_init(&b, impl);

   bool progress = false;
   nir_foreach_block_safe(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         const nir_intrinsic_op op = nir_instr_as_intrinsic(instr)->intrinsic;

         if (op == nir_intrinsic_ignore_ray_intersection) {
            b.cursor = nir_instr_remove(instr);

            /* The new code contains a jump instruction, so it is wrapped in
             * a dummy always-taken if to avoid dealing with that mess here.
             * Our control-flow optimization passes will clean it up.
             */
            nir_ssa_def *always = nir_imm_true(&b);
            nir_push_if(&b, always);
            nir_trace_ray_continue_intel(&b);
            nir_jump(&b, nir_jump_halt);
            nir_pop_if(&b, NULL);

            progress = true;
         } else if (op == nir_intrinsic_accept_ray_intersection) {
            b.cursor = nir_instr_remove(instr);

            nir_ssa_def *ray_flags = nir_load_ray_flags(&b);
            nir_ssa_def *terminate_bit =
               nir_iand_imm(&b, ray_flags,
                            BRW_RT_RAY_FLAG_TERMINATE_ON_FIRST_HIT);
            nir_ssa_def *terminate = nir_i2b(&b, terminate_bit);

            nir_push_if(&b, terminate);
            {
               build_terminate_ray(&b);
            }
            nir_push_else(&b, NULL);
            {
               nir_trace_ray_commit_intel(&b);
               nir_jump(&b, nir_jump_halt);
            }
            nir_pop_if(&b, NULL);

            progress = true;
         } else if (op == nir_intrinsic_terminate_ray) {
            b.cursor = nir_instr_remove(instr);
            build_terminate_ray(&b);

            progress = true;
         }
      }
   }

   /* New control flow was inserted if anything was lowered, so no metadata
    * survives in that case.
    */
   nir_metadata_preserve(impl, progress ? nir_metadata_none
                                        : nir_metadata_all);

   return progress;
}

/** Lowers a ray-generation shader
 *
 * Asserts that the shader stage is MESA_SHADER_RAYGEN, then runs
 * brw_nir_lower_shader_returns followed by lower_rt_io_and_scratch().
 */
void
brw_nir_lower_raygen(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_RAYGEN);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   lower_rt_io_and_scratch(nir);
}

/** Lowers an any-hit shader
 *
 * Asserts that the shader stage is MESA_SHADER_ANY_HIT, then runs
 * brw_nir_lower_shader_returns, lowers the ray walk intrinsics
 * (lower_ray_walk_intrinsics) and finally calls lower_rt_io_and_scratch().
 */
void
brw_nir_lower_any_hit(nir_shader *nir, const struct intel_device_info *devinfo)
{
   assert(nir->info.stage == MESA_SHADER_ANY_HIT);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   NIR_PASS_V(nir, lower_ray_walk_intrinsics, devinfo);
   lower_rt_io_and_scratch(nir);
}

/** Lowers a closest-hit shader
 *
 * Asserts that the shader stage is MESA_SHADER_CLOSEST_HIT, then runs
 * brw_nir_lower_shader_returns followed by lower_rt_io_and_scratch().
 */
void
brw_nir_lower_closest_hit(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_CLOSEST_HIT);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   lower_rt_io_and_scratch(nir);
}

/** Lowers a miss shader
 *
 * Asserts that the shader stage is MESA_SHADER_MISS, then runs
 * brw_nir_lower_shader_returns followed by lower_rt_io_and_scratch().
 */
void
brw_nir_lower_miss(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_MISS);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   lower_rt_io_and_scratch(nir);
}

/** Lowers a callable shader
 *
 * Asserts that the shader stage is MESA_SHADER_CALLABLE, then runs
 * brw_nir_lower_shader_returns followed by lower_rt_io_and_scratch().
 */
void
brw_nir_lower_callable(nir_shader *nir)
{
   assert(nir->info.stage == MESA_SHADER_CALLABLE);
   NIR_PASS_V(nir, brw_nir_lower_shader_returns);
   lower_rt_io_and_scratch(nir);
}

/** Lowers an intersection shader together with an optional any-hit shader
 *
 * `intersection` must be an intersection shader and is modified in place.
 * `any_hit` may be NULL; otherwise it must be an any-hit shader.  It is
 * handed, along with `devinfo`, to brw_nir_lower_intersection_shader.
 *
 * The passes run in this order: brw_nir_lower_shader_returns,
 * brw_nir_lower_intersection_shader, lower_ray_walk_intrinsics, and finally
 * lower_rt_io_and_scratch().
 */
void
brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
                                            const nir_shader *any_hit,
                                            const struct intel_device_info *devinfo)
{
   assert(intersection->info.stage == MESA_SHADER_INTERSECTION);
   assert(any_hit == NULL || any_hit->info.stage == MESA_SHADER_ANY_HIT);
   NIR_PASS_V(intersection, brw_nir_lower_shader_returns);
   NIR_PASS_V(intersection, brw_nir_lower_intersection_shader,
              any_hit, devinfo);
   NIR_PASS_V(intersection, lower_ray_walk_intrinsics, devinfo);
   lower_rt_io_and_scratch(intersection);
}

/* Emits a load_uniform intrinsic producing `num_components` components of
 * `bit_size` bits each.
 *
 * `offset` is used as the intrinsic's base, the offset source is a constant
 * zero, and the range is the size of the loaded value in bytes.
 */
static nir_ssa_def *
build_load_uniform(nir_builder *b, unsigned offset,
                   unsigned num_components, unsigned bit_size)
{
   const unsigned size_in_bytes = num_components * bit_size / 8;
   nir_ssa_def *zero_offset = nir_imm_int(b, 0);

   return nir_load_uniform(b, num_components, bit_size, zero_offset,
                           .base = offset,
                           .range = size_in_bytes);
}

/* Loads the field `name` of struct brw_rt_raygen_trampoline_params as a
 * uniform, using the field's offsetof() within the struct as the uniform
 * offset.  `num_components` and `bit_size` are forwarded unchanged to
 * build_load_uniform().
 */
#define load_trampoline_param(b, name, num_components, bit_size) \
   build_load_uniform((b), offsetof(struct brw_rt_raygen_trampoline_params, name), \
                      (num_components), (bit_size))

/** Builds the compute shader used as the ray-generation trampoline
 *
 * The shader computes a launch ID for each invocation from the workgroup ID
 * and the subgroup invocation index.  Invocations whose launch ID is inside
 * the launch size write their initial stack pointer and launch ID to the SW
 * hotzone and BTD-spawn the raygen shader; all other invocations do a BTD
 * retire.
 *
 * \param compiler  Supplies the device info and the compute-stage NIR
 *                  options, and is passed on to brw_preprocess_nir() and
 *                  brw_nir_optimize().
 * \param mem_ctx   ralloc context the new shader is handed to via
 *                  ralloc_steal().
 *
 * \return the preprocessed, lowered and optimized NIR shader.
 */
nir_shader *
brw_nir_create_raygen_trampoline(const struct brw_compiler *compiler,
                                 void *mem_ctx)
{
   const struct intel_device_info *devinfo = compiler->devinfo;
   const nir_shader_compiler_options *nir_options =
      compiler->glsl_compiler_options[MESA_SHADER_COMPUTE].NirOptions;

   /* num_uniforms below is hard-coded to this size */
   STATIC_ASSERT(sizeof(struct brw_rt_raygen_trampoline_params) == 32);

   nir_builder b = nir_builder_init_simple_shader(MESA_SHADER_COMPUTE,
                                                  nir_options,
                                                  "RT Ray-Gen Trampoline");
   ralloc_steal(mem_ctx, b.shader);

   b.shader->info.workgroup_size_variable = true;

   /* The RT global data and raygen BINDLESS_SHADER_RECORD addresses are
    * passed in as push constants in the first register.  We deal with the
    * raygen BSR address here; the global data we'll deal with later.
    */
   b.shader->num_uniforms = 32;
   nir_ssa_def *raygen_bsr_addr =
      load_trampoline_param(&b, raygen_bsr_addr, 1, 64);
   /* Three 8-bit log2 local group sizes, zero-extended to 32 bits */
   nir_ssa_def *local_shift =
      nir_u2u32(&b, load_trampoline_param(&b, local_group_size_log2, 3, 8));

   nir_ssa_def *global_id = nir_load_workgroup_id(&b, 32);
   nir_ssa_def *simd_channel = nir_load_subgroup_invocation(&b);
   /* Carve the per-dimension local IDs out of the subgroup invocation index
    * as consecutive bitfields: x starts at bit 0, y at bit shift.x and z at
    * bit shift.x + shift.y, each as wide as the matching shift component.
    */
   nir_ssa_def *local_x =
      nir_ubfe(&b, simd_channel, nir_imm_int(&b, 0),
                  nir_channel(&b, local_shift, 0));
   nir_ssa_def *local_y =
      nir_ubfe(&b, simd_channel, nir_channel(&b, local_shift, 0),
                  nir_channel(&b, local_shift, 1));
   nir_ssa_def *local_z =
      nir_ubfe(&b, simd_channel,
                  nir_iadd(&b, nir_channel(&b, local_shift, 0),
                              nir_channel(&b, local_shift, 1)),
                  nir_channel(&b, local_shift, 2));
   /* launch_id = (workgroup_id << local_shift) + local_id, per component */
   nir_ssa_def *launch_id =
      nir_iadd(&b, nir_ishl(&b, global_id, local_shift),
                  nir_vec3(&b, local_x, local_y, local_z));

   /* Only invocations whose launch ID is in range in every dimension spawn
    * the raygen shader.
    */
   nir_ssa_def *launch_size = nir_load_ray_launch_size(&b);
   nir_push_if(&b, nir_ball(&b, nir_ult(&b, launch_id, launch_size)));
   {
      nir_store_global(&b, brw_nir_rt_sw_hotzone_addr(&b, devinfo), 16,
                       nir_vec4(&b, nir_imm_int(&b, 0), /* Stack ptr */
                                    nir_channel(&b, launch_id, 0),
                                    nir_channel(&b, launch_id, 1),
                                    nir_channel(&b, launch_id, 2)),
                       0xf /* write mask */);

      brw_nir_btd_spawn(&b, raygen_bsr_addr);
   }
   nir_push_else(&b, NULL);
   {
      /* Even though these invocations aren't being used for anything, the
       * hardware allocated stack IDs for them.  They need to retire them.
       */
      brw_nir_btd_retire(&b);
   }
   nir_pop_if(&b, NULL);

   nir_shader *nir = b.shader;
   /* Replaces the name given to nir_builder_init_simple_shader() above */
   nir->info.name = ralloc_strdup(nir, "RT: TraceRay trampoline");
   nir_validate_shader(nir, "in brw_nir_create_raygen_trampoline");
   brw_preprocess_nir(compiler, nir, NULL);

   NIR_PASS_V(nir, brw_nir_lower_rt_intrinsics, devinfo);

   /* brw_nir_lower_rt_intrinsics will leave us with a btd_global_arg_addr
    * intrinsic which doesn't exist in compute shaders.  We also created one
    * above when we generated the BTD spawn intrinsic.  Now we go through and
    * replace them with a uniform load.
    */
   nir_foreach_block(block, b.impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_intrinsic)
            continue;

         nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
         if (intrin->intrinsic != nir_intrinsic_load_btd_global_arg_addr_intel)
            continue;

         /* Load rt_disp_globals_addr from the trampoline parameters right
          * before the intrinsic and use it in place of the intrinsic result.
          */
         b.cursor = nir_before_instr(&intrin->instr);
         nir_ssa_def *global_arg_addr =
            load_trampoline_param(&b, rt_disp_globals_addr, 1, 64);
         assert(intrin->dest.is_ssa);
         nir_ssa_def_rewrite_uses(&intrin->dest.ssa,
                                  global_arg_addr);
         nir_instr_remove(instr);
      }
   }

   NIR_PASS_V(nir, brw_nir_lower_cs_intrinsics);

   brw_nir_optimize(nir, compiler, true, false);

   return nir;
}