summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/compiler/spirv/nir_spirv.h2
-rw-r--r--src/compiler/spirv/vtn_opencl.c362
2 files changed, 362 insertions, 2 deletions
diff --git a/src/compiler/spirv/nir_spirv.h b/src/compiler/spirv/nir_spirv.h
index 26f5024edc6..b4d8b98b9ca 100644
--- a/src/compiler/spirv/nir_spirv.h
+++ b/src/compiler/spirv/nir_spirv.h
@@ -83,6 +83,8 @@ struct spirv_to_nir_options {
nir_address_format temp_addr_format;
nir_address_format constant_addr_format;
+ nir_shader *clc_shader;
+
struct {
void (*func)(void *private_data,
enum nir_spirv_debug_level level,
diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c
index 579c9c26248..c9bb93f516b 100644
--- a/src/compiler/spirv/vtn_opencl.c
+++ b/src/compiler/spirv/vtn_opencl.c
@@ -36,6 +36,174 @@ typedef nir_ssa_def *(*nir_handler)(struct vtn_builder *b,
struct vtn_type **src_types,
const struct vtn_type *dest_type);
+static int to_llvm_address_space(SpvStorageClass mode)
+{
+ switch (mode) {
+ case SpvStorageClassPrivate:
+ case SpvStorageClassFunction: return 0;
+ case SpvStorageClassCrossWorkgroup: return 1;
+ case SpvStorageClassUniform:
+ case SpvStorageClassUniformConstant: return 2;
+ case SpvStorageClassWorkgroup: return 3;
+ default: return -1;
+ }
+}
+
+
+static void
+vtn_opencl_mangle(const char *in_name,
+ uint32_t const_mask,
+ int ntypes, struct vtn_type **src_types,
+ char **outstring)
+{
+ char local_name[256] = "";
+ char *args_str = local_name + sprintf(local_name, "_Z%zu%s", strlen(in_name), in_name);
+
+ for (unsigned i = 0; i < ntypes; ++i) {
+ const struct glsl_type *type = src_types[i]->type;
+ enum vtn_base_type base_type = src_types[i]->base_type;
+ if (src_types[i]->base_type == vtn_base_type_pointer) {
+ *(args_str++) = 'P';
+ int address_space = to_llvm_address_space(src_types[i]->storage_class);
+ if (address_space > 0)
+ args_str += sprintf(args_str, "U3AS%d", address_space);
+
+ type = src_types[i]->deref->type;
+ base_type = src_types[i]->deref->base_type;
+ }
+
+ if (const_mask & (1 << i))
+ *(args_str++) = 'K';
+
+ unsigned num_elements = glsl_get_components(type);
+ if (num_elements > 1) {
+ /* Vectors are not treated as built-ins for mangling, so check for substitution.
+ * In theory, we'd need to know which substitution value this is. In practice,
+ * the functions we need from libclc only support 1
+ */
+ bool substitution = false;
+ for (unsigned j = 0; j < i; ++j) {
+ const struct glsl_type *other_type = src_types[j]->base_type == vtn_base_type_pointer ?
+ src_types[j]->deref->type : src_types[j]->type;
+ if (type == other_type) {
+ substitution = true;
+ break;
+ }
+ }
+
+ if (substitution) {
+ args_str += sprintf(args_str, "S_");
+ continue;
+ } else
+ args_str += sprintf(args_str, "Dv%d_", num_elements);
+ }
+
+ const char *suffix = NULL;
+ switch (base_type) {
+ case vtn_base_type_sampler: suffix = "11ocl_sampler"; break;
+ case vtn_base_type_event: suffix = "9ocl_event"; break;
+ default: {
+ const char *primitives[] = {
+ [GLSL_TYPE_UINT] = "j",
+ [GLSL_TYPE_INT] = "i",
+ [GLSL_TYPE_FLOAT] = "f",
+ [GLSL_TYPE_FLOAT16] = "Dh",
+ [GLSL_TYPE_DOUBLE] = "d",
+ [GLSL_TYPE_UINT8] = "h",
+ [GLSL_TYPE_INT8] = "c",
+ [GLSL_TYPE_UINT16] = "t",
+ [GLSL_TYPE_INT16] = "s",
+ [GLSL_TYPE_UINT64] = "m",
+ [GLSL_TYPE_INT64] = "l",
+ [GLSL_TYPE_BOOL] = "b",
+ [GLSL_TYPE_ERROR] = NULL,
+ };
+ enum glsl_base_type glsl_base_type = glsl_get_base_type(type);
+ assert(glsl_base_type < ARRAY_SIZE(primitives) && primitives[glsl_base_type]);
+ suffix = primitives[glsl_base_type];
+ break;
+ }
+ }
+ args_str += sprintf(args_str, "%s", suffix);
+ }
+
+ *outstring = strdup(local_name);
+}
+
+static nir_function *mangle_and_find(struct vtn_builder *b,
+ const char *name,
+ uint32_t const_mask,
+ uint32_t num_srcs,
+ struct vtn_type **src_types)
+{
+ char *mname;
+ nir_function *found = NULL;
+
+ vtn_opencl_mangle(name, const_mask, num_srcs, src_types, &mname);
+ /* try and find in current shader first. */
+ nir_foreach_function(funcs, b->shader) {
+ if (!strcmp(funcs->name, mname)) {
+ found = funcs;
+ break;
+ }
+ }
+ /* if not found here find in clc shader and create a decl mirroring it */
+ if (!found && b->options->clc_shader && b->options->clc_shader != b->shader) {
+ nir_foreach_function(funcs, b->options->clc_shader) {
+ if (!strcmp(funcs->name, mname)) {
+ found = funcs;
+ break;
+ }
+ }
+ if (found) {
+ nir_function *decl = nir_function_create(b->shader, mname);
+ decl->num_params = found->num_params;
+ decl->params = ralloc_array(b->shader, nir_parameter, decl->num_params);
+ for (unsigned i = 0; i < decl->num_params; i++) {
+ decl->params[i] = found->params[i];
+ }
+ found = decl;
+ }
+ }
+ if (!found)
+ vtn_fail("Can't find clc function %s\n", mname);
+ free(mname);
+ return found;
+}
+
+static bool call_mangled_function(struct vtn_builder *b,
+ const char *name,
+ uint32_t const_mask,
+ uint32_t num_srcs,
+ struct vtn_type **src_types,
+ const struct vtn_type *dest_type,
+ nir_ssa_def **srcs,
+ nir_deref_instr **ret_deref_ptr)
+{
+ nir_function *found = mangle_and_find(b, name, const_mask, num_srcs, src_types);
+ if (!found)
+ return false;
+
+ nir_call_instr *call = nir_call_instr_create(b->shader, found);
+
+ nir_deref_instr *ret_deref = NULL;
+ uint32_t param_idx = 0;
+ if (dest_type) {
+ nir_variable *ret_tmp = nir_local_variable_create(b->nb.impl,
+ glsl_get_bare_type(dest_type->type),
+ "return_tmp");
+ ret_deref = nir_build_deref_var(&b->nb, ret_tmp);
+ call->params[param_idx++] = nir_src_for_ssa(&ret_deref->dest.ssa);
+ }
+
+ for (unsigned i = 0; i < num_srcs; i++)
+ call->params[param_idx++] = nir_src_for_ssa(srcs[i]);
+ nir_builder_instr_insert(&b->nb, &call->instr);
+
+ *ret_deref_ptr = ret_deref;
+ return true;
+}
+
static void
handle_instr(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
const uint32_t *w, unsigned count, nir_handler handler)
@@ -129,6 +297,189 @@ handle_alu(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
return ret;
}
+#define REMAP(op, str) [OpenCLstd_##op] = { str }
+static const struct {
+ const char *fn;
+} remap_table[] = {
+ REMAP(Distance, "distance"),
+ REMAP(Fast_distance, "fast_distance"),
+ REMAP(Fast_length, "fast_length"),
+ REMAP(Fast_normalize, "fast_normalize"),
+ REMAP(Half_rsqrt, "half_rsqrt"),
+ REMAP(Half_sqrt, "half_sqrt"),
+ REMAP(Length, "length"),
+ REMAP(Normalize, "normalize"),
+ REMAP(Degrees, "degrees"),
+ REMAP(Radians, "radians"),
+ REMAP(Rotate, "rotate"),
+ REMAP(Smoothstep, "smoothstep"),
+ REMAP(Step, "step"),
+
+ REMAP(Pow, "pow"),
+ REMAP(Pown, "pown"),
+ REMAP(Powr, "powr"),
+ REMAP(Rootn, "rootn"),
+ REMAP(Modf, "modf"),
+
+ REMAP(Acos, "acos"),
+ REMAP(Acosh, "acosh"),
+ REMAP(Acospi, "acospi"),
+ REMAP(Asin, "asin"),
+ REMAP(Asinh, "asinh"),
+ REMAP(Asinpi, "asinpi"),
+ REMAP(Atan, "atan"),
+ REMAP(Atan2, "atan2"),
+ REMAP(Atanh, "atanh"),
+ REMAP(Atanpi, "atanpi"),
+ REMAP(Atan2pi, "atan2pi"),
+ REMAP(Cos, "cos"),
+ REMAP(Cosh, "cosh"),
+ REMAP(Cospi, "cospi"),
+ REMAP(Sin, "sin"),
+ REMAP(Sinh, "sinh"),
+ REMAP(Sinpi, "sinpi"),
+ REMAP(Tan, "tan"),
+ REMAP(Tanh, "tanh"),
+ REMAP(Tanpi, "tanpi"),
+ REMAP(Sincos, "sincos"),
+ REMAP(Fract, "fract"),
+ REMAP(Frexp, "frexp"),
+ REMAP(Fma, "fma"),
+ REMAP(Fmod, "fmod"),
+
+ REMAP(Half_cos, "cos"),
+ REMAP(Half_exp, "exp"),
+ REMAP(Half_exp2, "exp2"),
+ REMAP(Half_exp10, "exp10"),
+ REMAP(Half_log, "log"),
+ REMAP(Half_log2, "log2"),
+ REMAP(Half_log10, "log10"),
+ REMAP(Half_powr, "powr"),
+ REMAP(Half_sin, "sin"),
+ REMAP(Half_tan, "tan"),
+
+ REMAP(Remainder, "remainder"),
+ REMAP(Remquo, "remquo"),
+ REMAP(Hypot, "hypot"),
+ REMAP(Exp, "exp"),
+ REMAP(Exp2, "exp2"),
+ REMAP(Exp10, "exp10"),
+ REMAP(Expm1, "expm1"),
+ REMAP(Ldexp, "ldexp"),
+
+ REMAP(Ilogb, "ilogb"),
+ REMAP(Log, "log"),
+ REMAP(Log2, "log2"),
+ REMAP(Log10, "log10"),
+ REMAP(Log1p, "log1p"),
+ REMAP(Logb, "logb"),
+
+ REMAP(Cbrt, "cbrt"),
+ REMAP(Erfc, "erfc"),
+ REMAP(Erf, "erf"),
+
+ REMAP(Lgamma, "lgamma"),
+ REMAP(Lgamma_r, "lgamma_r"),
+ REMAP(Tgamma, "tgamma"),
+
+ REMAP(UMad_sat, "mad_sat"),
+ REMAP(SMad_sat, "mad_sat"),
+
+ REMAP(Shuffle, "shuffle"),
+ REMAP(Shuffle2, "shuffle2"),
+};
+#undef REMAP
+
+static const char *remap_clc_opcode(enum OpenCLstd_Entrypoints opcode)
+{
+ if (opcode >= (sizeof(remap_table) / sizeof(const char *)))
+ return NULL;
+ return remap_table[opcode].fn;
+}
+
+static struct vtn_type *
+get_vtn_type_for_glsl_type(struct vtn_builder *b, const struct glsl_type *type)
+{
+ struct vtn_type *ret = rzalloc(b, struct vtn_type);
+ assert(glsl_type_is_vector_or_scalar(type));
+ ret->type = type;
+ ret->length = glsl_get_vector_elements(type);
+ ret->base_type = glsl_type_is_vector(type) ? vtn_base_type_vector : vtn_base_type_scalar;
+ return ret;
+}
+
+static struct vtn_type *
+get_pointer_type(struct vtn_builder *b, struct vtn_type *t, SpvStorageClass storage_class)
+{
+ struct vtn_type *ret = rzalloc(b, struct vtn_type);
+ ret->type = nir_address_format_to_glsl_type(
+ vtn_mode_to_address_format(
+ b, vtn_storage_class_to_mode(b, storage_class, NULL, NULL)));
+ ret->base_type = vtn_base_type_pointer;
+ ret->storage_class = storage_class;
+ ret->deref = t;
+ return ret;
+}
+
+static struct vtn_type *
+get_signed_type(struct vtn_builder *b, struct vtn_type *t)
+{
+ if (t->base_type == vtn_base_type_pointer) {
+ return get_pointer_type(b, get_signed_type(b, t->deref), t->storage_class);
+ }
+ return get_vtn_type_for_glsl_type(
+ b, glsl_vector_type(glsl_signed_base_type_of(glsl_get_base_type(t->type)),
+ glsl_get_vector_elements(t->type)));
+}
+
+static nir_ssa_def *
+handle_clc_fn(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
+ int num_srcs,
+ nir_ssa_def **srcs,
+ struct vtn_type **src_types,
+ const struct vtn_type *dest_type)
+{
+ const char *name = remap_clc_opcode(opcode);
+ if (!name)
+ return NULL;
+
+ /* Some functions which take params end up with uint (or pointer-to-uint) being passed,
+ * which doesn't mangle correctly when the function expects int or pointer-to-int.
+ * See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers
+ */
+ int signed_param = -1;
+ switch (opcode) {
+ case OpenCLstd_Frexp:
+ case OpenCLstd_Lgamma_r:
+ case OpenCLstd_Pown:
+ case OpenCLstd_Rootn:
+ case OpenCLstd_Ldexp:
+ signed_param = 1;
+ break;
+ case OpenCLstd_Remquo:
+ signed_param = 2;
+ break;
+ case OpenCLstd_SMad_sat: {
+ /* All parameters need to be converted to signed */
+ src_types[0] = src_types[1] = src_types[2] = get_signed_type(b, src_types[0]);
+ break;
+ }
+ default: break;
+ }
+
+ if (signed_param >= 0) {
+ src_types[signed_param] = get_signed_type(b, src_types[signed_param]);
+ }
+
+ nir_deref_instr *ret_deref = NULL;
+
+ if (!call_mangled_function(b, name, 0, num_srcs, src_types,
+ dest_type, srcs, &ret_deref))
+ return NULL;
+
+ return ret_deref ? nir_load_deref(&b->nb, ret_deref) : NULL;
+}
+
static nir_ssa_def *
handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
@@ -138,6 +489,7 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
switch (opcode) {
case OpenCLstd_SAbs_diff:
+ /* these works easier in direct NIR */
return nir_iabs_diff(nb, srcs[0], srcs[1]);
case OpenCLstd_UAbs_diff:
return nir_uabs_diff(nb, srcs[0], srcs[1]);
@@ -207,6 +559,7 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
return nir_sge(nb, srcs[1], srcs[0]);
case OpenCLstd_S_Upsample:
case OpenCLstd_U_Upsample:
+ /* SPIR-V and CL have different defs for upsample, just implement in nir */
return nir_upsample(nb, srcs[0], srcs[1]);
case OpenCLstd_Native_exp:
return nir_fexp(nb, srcs[0]);
@@ -219,9 +572,14 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
case OpenCLstd_Native_tan:
return nir_ftan(nb, srcs[0]);
default:
- vtn_fail("No NIR equivalent");
- return NULL;
+ break;
}
+
+ nir_ssa_def *ret = handle_clc_fn(b, opcode, num_srcs, srcs, src_types, dest_type);
+ if (!ret)
+ vtn_fail("No NIR equivalent");
+
+ return ret;
}
static void