summaryrefslogtreecommitdiff
path: root/src/compiler/spirv/vtn_opencl.c
blob: 579c9c262487e4fb2bddd99e52da14d42366d320 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
/*
 * Copyright © 2018 Red Hat
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * Authors:
 *    Rob Clark (robdclark@gmail.com)
 */

#include "math.h"
#include "nir/nir_builtin_builder.h"

#include "vtn_private.h"
#include "OpenCL.std.h"

/*
 * Signature of the per-opcode lowering callbacks that handle_instr()
 * dispatches to.
 *
 * A callback receives the builder, the OpenCL opcode being lowered, the
 * number of sources together with their SSA defs and vtn types, and the
 * destination type.  The SSA def it returns is what handle_instr() pushes as
 * the instruction's result; NULL means no result was produced.
 */
typedef nir_ssa_def *(*nir_handler)(struct vtn_builder *b,
                                    enum OpenCLstd_Entrypoints opcode,
                                    unsigned num_srcs, nir_ssa_def **srcs,
                                    struct vtn_type **src_types,
                                    const struct vtn_type *dest_type);

/*
 * Shared front end for OpenCL extended instructions that are lowered through
 * a nir_handler callback.
 *
 * The destination type is looked up from w[1], up to three source operands
 * are collected starting at w[5], and the def returned by the handler is
 * pushed as the value of w[2].  A handler is only allowed to return NULL when
 * the destination type is void.
 */
static void
handle_instr(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
             const uint32_t *w, unsigned count, nir_handler handler)
{
   struct vtn_type *dest_type = vtn_get_type(b, w[1]);

   /* Every word from index 5 onwards is treated as a source operand. */
   unsigned num_srcs = count - 5;
   nir_ssa_def *srcs[3] = { NULL };
   struct vtn_type *src_types[3] = { NULL };
   vtn_assert(num_srcs <= ARRAY_SIZE(srcs));

   for (unsigned s = 0; s < num_srcs; s++) {
      const uint32_t id = w[s + 5];
      struct vtn_value *untyped = vtn_untyped_value(b, id);
      struct vtn_ssa_value *value = vtn_ssa_value(b, id);

      src_types[s] = untyped->type;
      srcs[s] = value->def;
   }

   nir_ssa_def *def = handler(b, opcode, num_srcs, srcs, src_types, dest_type);
   if (!def) {
      /* Nothing to push: that is only legitimate for a void result. */
      vtn_assert(dest_type->type == glsl_void_type());
      return;
   }

   vtn_push_nir_ssa(b, w[2], def);
}

static nir_op
nir_alu_op_for_opencl_opcode(struct vtn_builder *b,
                             enum OpenCLstd_Entrypoints opcode)
{
   switch (opcode) {
   case OpenCLstd_Fabs: return nir_op_fabs;
   case OpenCLstd_SAbs: return nir_op_iabs;
   case OpenCLstd_SAdd_sat: return nir_op_iadd_sat;
   case OpenCLstd_UAdd_sat: return nir_op_uadd_sat;
   case OpenCLstd_Ceil: return nir_op_fceil;
   case OpenCLstd_Cos: return nir_op_fcos;
   case OpenCLstd_Exp2: return nir_op_fexp2;
   case OpenCLstd_Log2: return nir_op_flog2;
   case OpenCLstd_Floor: return nir_op_ffloor;
   case OpenCLstd_SHadd: return nir_op_ihadd;
   case OpenCLstd_UHadd: return nir_op_uhadd;
   case OpenCLstd_Fma: return nir_op_ffma;
   case OpenCLstd_Fmax: return nir_op_fmax;
   case OpenCLstd_SMax: return nir_op_imax;
   case OpenCLstd_UMax: return nir_op_umax;
   case OpenCLstd_Fmin: return nir_op_fmin;
   case OpenCLstd_SMin: return nir_op_imin;
   case OpenCLstd_UMin: return nir_op_umin;
   case OpenCLstd_Fmod: return nir_op_fmod;
   case OpenCLstd_Mix: return nir_op_flrp;
   case OpenCLstd_Native_cos: return nir_op_fcos;
   case OpenCLstd_Native_divide: return nir_op_fdiv;
   case OpenCLstd_Native_exp2: return nir_op_fexp2;
   case OpenCLstd_Native_log2: return nir_op_flog2;
   case OpenCLstd_Native_powr: return nir_op_fpow;
   case OpenCLstd_Native_recip: return nir_op_frcp;
   case OpenCLstd_Native_rsqrt: return nir_op_frsq;
   case OpenCLstd_Native_sin: return nir_op_fsin;
   case OpenCLstd_Native_sqrt: return nir_op_fsqrt;
   case OpenCLstd_SMul_hi: return nir_op_imul_high;
   case OpenCLstd_UMul_hi: return nir_op_umul_high;
   case OpenCLstd_Popcount: return nir_op_bit_count;
   case OpenCLstd_Pow: return nir_op_fpow;
   case OpenCLstd_Remainder: return nir_op_frem;
   case OpenCLstd_SRhadd: return nir_op_irhadd;
   case OpenCLstd_URhadd: return nir_op_urhadd;
   case OpenCLstd_Rsqrt: return nir_op_frsq;
   case OpenCLstd_Sign: return nir_op_fsign;
   case OpenCLstd_Sin: return nir_op_fsin;
   case OpenCLstd_Sqrt: return nir_op_fsqrt;
   case OpenCLstd_SSub_sat: return nir_op_isub_sat;
   case OpenCLstd_USub_sat: return nir_op_usub_sat;
   case OpenCLstd_Trunc: return nir_op_ftrunc;
   case OpenCLstd_Rint: return nir_op_fround_even;
   /* uhm... */
   case OpenCLstd_UAbs: return nir_op_mov;
   default:
      vtn_fail("No NIR equivalent");
   }
}

/*
 * nir_handler for opcodes that lower to exactly one NIR ALU instruction.
 *
 * The opcode is translated by nir_alu_op_for_opencl_opcode() and built from
 * the first three source slots.  For Popcount the result is additionally
 * converted to the bit size of the destination type.
 */
static nir_ssa_def *
handle_alu(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
           unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
           const struct vtn_type *dest_type)
{
   nir_ssa_def *value =
      nir_build_alu(&b->nb, nir_alu_op_for_opencl_opcode(b, opcode),
                    srcs[0], srcs[1], srcs[2], NULL);

   if (opcode != OpenCLstd_Popcount)
      return value;

   const unsigned dest_bits = glsl_get_bit_size(dest_type->type);
   return nir_u2u(&b->nb, value, dest_bits);
}

/*
 * nir_handler for opcodes that are lowered through a NIR builder helper
 * rather than a single ALU opcode.
 *
 * Returns the def produced by the matching helper; an opcode without a case
 * here ends in vtn_fail().
 */
static nir_ssa_def *
handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
               unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
               const struct vtn_type *dest_type)
{
   nir_builder *nb = &b->nb;

   /* handle_instr() always passes a three-entry array whose unused slots are
    * NULL, so the operands can be named up front.
    */
   nir_ssa_def *x = srcs[0];
   nir_ssa_def *y = srcs[1];
   nir_ssa_def *z = srcs[2];

   switch (opcode) {
   case OpenCLstd_SAbs_diff:
      return nir_iabs_diff(nb, x, y);
   case OpenCLstd_UAbs_diff:
      return nir_uabs_diff(nb, x, y);
   case OpenCLstd_Bitselect:
      return nir_bitselect(nb, x, y, z);
   case OpenCLstd_SMad_hi:
      return nir_imad_hi(nb, x, y, z);
   case OpenCLstd_UMad_hi:
      return nir_umad_hi(nb, x, y, z);
   case OpenCLstd_SMul24:
      return nir_imul24(nb, x, y);
   case OpenCLstd_UMul24:
      return nir_umul24(nb, x, y);
   case OpenCLstd_SMad24:
      return nir_imad24(nb, x, y, z);
   case OpenCLstd_UMad24:
      return nir_umad24(nb, x, y, z);
   case OpenCLstd_FClamp:
      return nir_fclamp(nb, x, y, z);
   case OpenCLstd_SClamp:
      return nir_iclamp(nb, x, y, z);
   case OpenCLstd_UClamp:
      return nir_uclamp(nb, x, y, z);
   case OpenCLstd_Copysign:
      return nir_copysign(nb, x, y);
   case OpenCLstd_Cross:
      /* Four-component destinations use the dedicated vec4 helper. */
      return dest_type->length == 4 ? nir_cross4(nb, x, y)
                                    : nir_cross3(nb, x, y);
   case OpenCLstd_Degrees:
      return nir_degrees(nb, x);
   case OpenCLstd_Fdim:
      return nir_fdim(nb, x, y);
   case OpenCLstd_Distance:
      return nir_distance(nb, x, y);
   case OpenCLstd_Fast_distance:
      return nir_fast_distance(nb, x, y);
   case OpenCLstd_Fast_length:
      return nir_fast_length(nb, x);
   case OpenCLstd_Fast_normalize:
      return nir_fast_normalize(nb, x);
   case OpenCLstd_Length:
      return nir_length(nb, x);
   case OpenCLstd_Mad:
      return nir_fmad(nb, x, y, z);
   case OpenCLstd_Maxmag:
      return nir_maxmag(nb, x, y);
   case OpenCLstd_Minmag:
      return nir_minmag(nb, x, y);
   case OpenCLstd_Nan:
      return nir_nan(nb, x);
   case OpenCLstd_Nextafter:
      return nir_nextafter(nb, x, y);
   case OpenCLstd_Normalize:
      return nir_normalize(nb, x);
   case OpenCLstd_Radians:
      return nir_radians(nb, x);
   case OpenCLstd_Rotate:
      return nir_rotate(nb, x, y);
   case OpenCLstd_Smoothstep:
      return nir_smoothstep(nb, x, y, z);
   case OpenCLstd_Clz:
      return nir_clz_u(nb, x);
   case OpenCLstd_Select:
      return nir_select(nb, x, y, z);
   case OpenCLstd_Step:
      /* Operands are swapped: the second source is compared against the
       * first.
       */
      return nir_sge(nb, y, x);
   case OpenCLstd_S_Upsample:
   case OpenCLstd_U_Upsample:
      return nir_upsample(nb, x, y);
   case OpenCLstd_Native_exp:
      return nir_fexp(nb, x);
   case OpenCLstd_Native_exp10: {
      /* 10^x == 2^(x * log2(10)) */
      nir_ssa_def *scaled = nir_fmul_imm(nb, x, log(10) / log(2));
      return nir_fexp2(nb, scaled);
   }
   case OpenCLstd_Native_log:
      return nir_flog(nb, x);
   case OpenCLstd_Native_log10: {
      /* log10(x) == log2(x) * log10(2) */
      nir_ssa_def *lg2 = nir_flog2(nb, x);
      return nir_fmul_imm(nb, lg2, log(2) / log(10));
   }
   case OpenCLstd_Native_tan:
      return nir_ftan(nb, x);
   default:
      vtn_fail("No NIR equivalent");
      break;
   }

   return NULL;
}

static void
_handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
                     const uint32_t *w, unsigned count, bool load)
{
   struct vtn_type *type;
   if (load)
      type = vtn_get_type(b, w[1]);
   else
      type = vtn_get_value_type(b, w[5]);
   unsigned a = load ? 0 : 1;

   const struct glsl_type *dest_type = type->type;
   unsigned components = glsl_get_vector_elements(dest_type);

   nir_ssa_def *offset = vtn_get_nir_ssa(b, w[5 + a]);
   struct vtn_value *p = vtn_value(b, w[6 + a], vtn_value_type_pointer);

   struct vtn_ssa_value *comps[NIR_MAX_VEC_COMPONENTS];
   nir_ssa_def *ncomps[NIR_MAX_VEC_COMPONENTS];

   nir_ssa_def *moffset = nir_imul_imm(&b->nb, offset, components);
   nir_deref_instr *deref = vtn_pointer_to_deref(b, p->pointer);

   for (int i = 0; i < components; i++) {
      nir_ssa_def *coffset = nir_iadd_imm(&b->nb, moffset, i);
      nir_deref_instr *arr_deref = nir_build_deref_ptr_as_array(&b->nb, deref, coffset);

      if (load) {
         comps[i] = vtn_local_load(b, arr_deref, p->type->access);
         ncomps[i] = comps[i]->def;
      } else {
         struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, glsl_scalar_type(glsl_get_base_type(dest_type)));
         struct vtn_ssa_value *val = vtn_ssa_value(b, w[5]);
         ssa->def = nir_channel(&b->nb, val->def, i);
         vtn_local_store(b, ssa, arr_deref, p->type->access);
      }
   }
   if (load) {
      vtn_push_nir_ssa(b, w[2], nir_vec(&b->nb, ncomps, components));
   }
}

/* Vector load: the shared load/store helper run in load mode. */
static void
vtn_handle_opencl_vload(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
                        const uint32_t *w, unsigned count)
{
   const bool is_load = true;

   _handle_v_load_store(b, opcode, w, count, is_load);
}

/* Vector store: the shared load/store helper run in store mode. */
static void
vtn_handle_opencl_vstore(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
                         const uint32_t *w, unsigned count)
{
   const bool is_load = false;

   _handle_v_load_store(b, opcode, w, count, is_load);
}

/*
 * nir_handler for printf.
 *
 * printf is not implemented: none of the sources are looked at and the
 * result is always the integer constant -1 (presumably the OpenCL printf
 * failure code -- confirm against the spec).
 */
static nir_ssa_def *
handle_printf(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
              unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
              const struct vtn_type *dest_type)
{
   const int stub_result = -1;

   return nir_imm_int(&b->nb, stub_result);
}

/*
 * nir_handler for Round.
 *
 * The source is truncated towards zero; when the magnitude of the discarded
 * fraction is at least 0.5, the sign of the source is added to the truncated
 * value, otherwise the truncated value is returned as is.
 */
static nir_ssa_def *
handle_round(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned num_srcs,
             nir_ssa_def **srcs, struct vtn_type **src_types,
             const struct vtn_type *dest_type)
{
   nir_builder *nb = &b->nb;
   nir_ssa_def *x = srcs[0];

   nir_ssa_def *half = nir_imm_floatN_t(nb, 0.5, x->bit_size);
   nir_ssa_def *whole = nir_ftrunc(nb, x);
   nir_ssa_def *frac = nir_fsub(nb, x, whole);

   /* |frac| >= 0.5 means the value is at least halfway to the next integer. */
   nir_ssa_def *past_half = nir_fge(nb, nir_fabs(nb, frac), half);
   nir_ssa_def *stepped = nir_fadd(nb, whole, nir_fsign(nb, x));

   return nir_bcsel(nb, past_half, stepped, whole);
}

/*
 * nir_handler for Shuffle.
 *
 * Builds a vector with dest_type->length components where component i is the
 * element of srcs[0] selected by component i of the mask in srcs[1].  The
 * mask is first converted to 32 bits if needed and ANDed with
 * (input component count - 1).
 */
static nir_ssa_def *
handle_shuffle(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned num_srcs,
               nir_ssa_def **srcs, struct vtn_type **src_types,
               const struct vtn_type *dest_type)
{
   struct nir_ssa_def *vec = srcs[0];
   struct nir_ssa_def *indices = srcs[1];

   const unsigned num_out = dest_type->length;
   const unsigned num_in = vec->num_components;
   nir_ssa_def *picked[NIR_MAX_VEC_COMPONENTS];

   if (indices->bit_size != 32)
      indices = nir_u2u32(&b->nb, indices);

   /* Keep only the bits that can address a component of the input. */
   nir_ssa_def *index_bits =
      nir_imm_intN_t(&b->nb, num_in - 1, indices->bit_size);
   indices = nir_iand(&b->nb, indices, index_bits);

   for (unsigned c = 0; c < num_out; c++) {
      nir_ssa_def *index = nir_channel(&b->nb, indices, c);
      picked[c] = nir_vector_extract(&b->nb, vec, index);
   }

   return nir_vec(&b->nb, picked, num_out);
}

/*
 * nir_handler for Shuffle2.
 *
 * Like handle_shuffle(), but each mask component selects from the
 * concatenation of srcs[0] and srcs[1].  The mask is ANDed with
 * (2 * input component count - 1); indices below the input component count
 * pick from the first vector, the others from the second, in both cases
 * using the index ANDed with (input component count - 1).
 */
static nir_ssa_def *
handle_shuffle2(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, unsigned num_srcs,
                nir_ssa_def **srcs, struct vtn_type **src_types,
                const struct vtn_type *dest_type)
{
   struct nir_ssa_def *lo_vec = srcs[0];
   struct nir_ssa_def *hi_vec = srcs[1];
   struct nir_ssa_def *indices = srcs[2];

   const unsigned num_out = dest_type->length;
   const unsigned num_in = lo_vec->num_components;
   const unsigned both_bits = 2 * num_in - 1;
   const unsigned lane_bits = num_in - 1;
   nir_ssa_def *picked[NIR_MAX_VEC_COMPONENTS];

   if (indices->bit_size != 32)
      indices = nir_u2u32(&b->nb, indices);

   nir_ssa_def *both_imm =
      nir_imm_intN_t(&b->nb, both_bits, indices->bit_size);
   indices = nir_iand(&b->nb, indices, both_imm);

   for (unsigned c = 0; c < num_out; c++) {
      nir_ssa_def *index = nir_channel(&b->nb, indices, c);

      /* Position inside whichever input vector ends up being selected. */
      nir_ssa_def *lane_imm =
         nir_imm_intN_t(&b->nb, lane_bits, indices->bit_size);
      nir_ssa_def *lane = nir_iand(&b->nb, index, lane_imm);

      nir_ssa_def *from_lo = nir_vector_extract(&b->nb, lo_vec, lane);
      nir_ssa_def *from_hi = nir_vector_extract(&b->nb, hi_vec, lane);

      nir_ssa_def *split =
         nir_imm_intN_t(&b->nb, num_in, indices->bit_size);
      nir_ssa_def *use_lo = nir_ilt(&b->nb, index, split);

      picked[c] = nir_bcsel(&b->nb, use_lo, from_lo, from_hi);
   }

   return nir_vec(&b->nb, picked, num_out);
}

/*
 * Entry point for lowering one OpenCL.std extended instruction.
 *
 * The opcode is routed to the matching lowering path: single ALU opcodes,
 * builder helpers, vector load/store, shuffles, round and printf.  Prefetch
 * is accepted and ignored.  Returns true for every opcode handled here; any
 * other opcode ends in vtn_fail().
 */
bool
vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode,
                              const uint32_t *w, unsigned count)
{
   enum OpenCLstd_Entrypoints cl_opcode = (enum OpenCLstd_Entrypoints) ext_opcode;

   switch (cl_opcode) {
   /* --- one NIR ALU opcode each --- */
   /* float rounding / sign */
   case OpenCLstd_Fabs:
   case OpenCLstd_Ceil:
   case OpenCLstd_Floor:
   case OpenCLstd_Trunc:
   case OpenCLstd_Rint:
   case OpenCLstd_Sign:
   /* float math */
   case OpenCLstd_Cos:
   case OpenCLstd_Sin:
   case OpenCLstd_Exp2:
   case OpenCLstd_Log2:
   case OpenCLstd_Pow:
   case OpenCLstd_Sqrt:
   case OpenCLstd_Rsqrt:
   /* native_* float math */
   case OpenCLstd_Native_cos:
   case OpenCLstd_Native_sin:
   case OpenCLstd_Native_exp2:
   case OpenCLstd_Native_log2:
   case OpenCLstd_Native_powr:
   case OpenCLstd_Native_divide:
   case OpenCLstd_Native_recip:
   case OpenCLstd_Native_sqrt:
   case OpenCLstd_Native_rsqrt:
   /* float arithmetic */
   case OpenCLstd_Fma:
   case OpenCLstd_Fmax:
   case OpenCLstd_Fmin:
   case OpenCLstd_Fmod:
   case OpenCLstd_Remainder:
   case OpenCLstd_Mix:
   /* integer */
   case OpenCLstd_SAbs:
   case OpenCLstd_UAbs:
   case OpenCLstd_SAdd_sat:
   case OpenCLstd_UAdd_sat:
   case OpenCLstd_SSub_sat:
   case OpenCLstd_USub_sat:
   case OpenCLstd_SHadd:
   case OpenCLstd_UHadd:
   case OpenCLstd_SRhadd:
   case OpenCLstd_URhadd:
   case OpenCLstd_SMax:
   case OpenCLstd_UMax:
   case OpenCLstd_SMin:
   case OpenCLstd_UMin:
   case OpenCLstd_SMul_hi:
   case OpenCLstd_UMul_hi:
   case OpenCLstd_Popcount:
      handle_instr(b, cl_opcode, w, count, handle_alu);
      break;

   /* --- NIR builder helpers --- */
   /* integer */
   case OpenCLstd_SAbs_diff:
   case OpenCLstd_UAbs_diff:
   case OpenCLstd_SMad_hi:
   case OpenCLstd_UMad_hi:
   case OpenCLstd_SMad24:
   case OpenCLstd_UMad24:
   case OpenCLstd_SMul24:
   case OpenCLstd_UMul24:
   case OpenCLstd_SClamp:
   case OpenCLstd_UClamp:
   case OpenCLstd_Rotate:
   case OpenCLstd_S_Upsample:
   case OpenCLstd_U_Upsample:
   case OpenCLstd_Clz:
   /* selection */
   case OpenCLstd_Bitselect:
   case OpenCLstd_Select:
   /* float */
   case OpenCLstd_FClamp:
   case OpenCLstd_Copysign:
   case OpenCLstd_Degrees:
   case OpenCLstd_Radians:
   case OpenCLstd_Fdim:
   case OpenCLstd_Mad:
   case OpenCLstd_Maxmag:
   case OpenCLstd_Minmag:
   case OpenCLstd_Nan:
   case OpenCLstd_Nextafter:
   case OpenCLstd_Step:
   case OpenCLstd_Smoothstep:
   /* geometric */
   case OpenCLstd_Cross:
   case OpenCLstd_Distance:
   case OpenCLstd_Fast_distance:
   case OpenCLstd_Length:
   case OpenCLstd_Fast_length:
   case OpenCLstd_Normalize:
   case OpenCLstd_Fast_normalize:
   /* native_* float math */
   case OpenCLstd_Native_exp:
   case OpenCLstd_Native_exp10:
   case OpenCLstd_Native_log:
   case OpenCLstd_Native_log10:
   case OpenCLstd_Native_tan:
      handle_instr(b, cl_opcode, w, count, handle_special);
      break;

   /* --- memory --- */
   case OpenCLstd_Vloadn:
      vtn_handle_opencl_vload(b, cl_opcode, w, count);
      break;
   case OpenCLstd_Vstoren:
      vtn_handle_opencl_vstore(b, cl_opcode, w, count);
      break;

   /* --- opcodes with a dedicated handler --- */
   case OpenCLstd_Shuffle:
      handle_instr(b, cl_opcode, w, count, handle_shuffle);
      break;
   case OpenCLstd_Shuffle2:
      handle_instr(b, cl_opcode, w, count, handle_shuffle2);
      break;
   case OpenCLstd_Round:
      handle_instr(b, cl_opcode, w, count, handle_round);
      break;
   case OpenCLstd_Printf:
      handle_instr(b, cl_opcode, w, count, handle_printf);
      break;

   case OpenCLstd_Prefetch:
      /* TODO maybe add a nir instruction for this? */
      break;

   default:
      vtn_fail("unhandled opencl opc: %u\n", ext_opcode);
      return false;
   }

   return true;
}