diff options
author | Jesse Natalie <jenatali@microsoft.com> | 2020-08-18 07:16:32 -0700 |
---|---|---|
committer | Dave Airlie <airlied@linux.ie> | 2020-09-01 21:44:06 +0000 |
commit | 1a5516ab51a44ed3177c042b6437069df49e9276 (patch) | |
tree | 6aefaf5e0a3a53e300a6e3eb679e315515f162e4 | |
parent | 9bbe36a0a775c9346e2ff13657449cd79a8b9b73 (diff) |
vtn/opencl: Add infrastructure for calling out to libclc
This patch adds a function remap table with name mangling, which
can convert a SPIR-V OpenCL extension opcode to a call to the external
libclc shader, which will be lowered/inlined after conversion.
-rw-r--r-- | src/compiler/spirv/nir_spirv.h | 2 | ||||
-rw-r--r-- | src/compiler/spirv/vtn_opencl.c | 362 |
2 files changed, 362 insertions, 2 deletions
diff --git a/src/compiler/spirv/nir_spirv.h b/src/compiler/spirv/nir_spirv.h
index 26f5024edc6..b4d8b98b9ca 100644
--- a/src/compiler/spirv/nir_spirv.h
+++ b/src/compiler/spirv/nir_spirv.h
@@ -83,6 +83,8 @@ struct spirv_to_nir_options {
    nir_address_format temp_addr_format;
    nir_address_format constant_addr_format;
 
+   nir_shader *clc_shader;
+
    struct {
       void (*func)(void *private_data,
                    enum nir_spirv_debug_level level,
diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c
index 579c9c26248..c9bb93f516b 100644
--- a/src/compiler/spirv/vtn_opencl.c
+++ b/src/compiler/spirv/vtn_opencl.c
@@ -36,6 +36,174 @@ typedef nir_ssa_def *(*nir_handler)(struct vtn_builder *b,
                                     struct vtn_type **src_types,
                                     const struct vtn_type *dest_type);
 
+static int to_llvm_address_space(SpvStorageClass mode) /* storage class -> number emitted after "U3AS" when mangling; -1 if unmapped */
+{
+   switch (mode) {
+   case SpvStorageClassPrivate:
+   case SpvStorageClassFunction: return 0;
+   case SpvStorageClassCrossWorkgroup: return 1;
+   case SpvStorageClassUniform:
+   case SpvStorageClassUniformConstant: return 2;
+   case SpvStorageClassWorkgroup: return 3;
+   default: return -1; /* no mapping for any other storage class */
+   }
+}
+
+/* Build "_Z<strlen(in_name)><in_name>" followed by one code per argument type; *outstring receives a strdup()ed copy. */
+static void
+vtn_opencl_mangle(const char *in_name,
+                  uint32_t const_mask,
+                  int ntypes, struct vtn_type **src_types,
+                  char **outstring)
+{
+   char local_name[256] = ""; /* NOTE(review): the sprintf() calls below are not bounded by this size */
+   char *args_str = local_name + sprintf(local_name, "_Z%zu%s", strlen(in_name), in_name);
+
+   for (unsigned i = 0; i < ntypes; ++i) {
+      const struct glsl_type *type = src_types[i]->type;
+      enum vtn_base_type base_type = src_types[i]->base_type;
+      if (src_types[i]->base_type == vtn_base_type_pointer) {
+         *(args_str++) = 'P';
+         int address_space = to_llvm_address_space(src_types[i]->storage_class);
+         if (address_space > 0) /* 0 and -1 get no address-space qualifier */
+            args_str += sprintf(args_str, "U3AS%d", address_space);
+
+         type = src_types[i]->deref->type;
+         base_type = src_types[i]->deref->base_type;
+      }
+
+      if (const_mask & (1 << i)) /* bit i of const_mask requests a 'K' for argument i */
+         *(args_str++) = 'K';
+
+      unsigned num_elements = glsl_get_components(type);
+      if (num_elements > 1) {
+         /* Vectors are not treated as built-ins for mangling, so check for substitution.
+          * In theory, we'd need to know which substitution value this is. In practice,
+          * the functions we need from libclc only need one, so "S_" is always emitted.
+          */
+         bool substitution = false;
+         for (unsigned j = 0; j < i; ++j) {
+            const struct glsl_type *other_type = src_types[j]->base_type == vtn_base_type_pointer ?
+               src_types[j]->deref->type : src_types[j]->type;
+            if (type == other_type) {
+               substitution = true;
+               break;
+            }
+         }
+
+         if (substitution) {
+            args_str += sprintf(args_str, "S_");
+            continue;
+         } else
+            args_str += sprintf(args_str, "Dv%d_", num_elements);
+      }
+
+      const char *suffix = NULL;
+      switch (base_type) {
+      case vtn_base_type_sampler: suffix = "11ocl_sampler"; break;
+      case vtn_base_type_event: suffix = "9ocl_event"; break;
+      default: {
+         const char *primitives[] = {
+            [GLSL_TYPE_UINT] = "j",
+            [GLSL_TYPE_INT] = "i",
+            [GLSL_TYPE_FLOAT] = "f",
+            [GLSL_TYPE_FLOAT16] = "Dh",
+            [GLSL_TYPE_DOUBLE] = "d",
+            [GLSL_TYPE_UINT8] = "h",
+            [GLSL_TYPE_INT8] = "c",
+            [GLSL_TYPE_UINT16] = "t",
+            [GLSL_TYPE_INT16] = "s",
+            [GLSL_TYPE_UINT64] = "m",
+            [GLSL_TYPE_INT64] = "l",
+            [GLSL_TYPE_BOOL] = "b",
+            [GLSL_TYPE_ERROR] = NULL,
+         };
+         enum glsl_base_type glsl_base_type = glsl_get_base_type(type);
+         assert(glsl_base_type < ARRAY_SIZE(primitives) && primitives[glsl_base_type]);
+         suffix = primitives[glsl_base_type];
+         break;
+      }
+      }
+      args_str += sprintf(args_str, "%s", suffix);
+   }
+
+   *outstring = strdup(local_name);
+}
+/* Find the mangled function in b->shader, else in options->clc_shader (mirrored as a new declaration in b->shader). */
+static nir_function *mangle_and_find(struct vtn_builder *b,
+                                     const char *name,
+                                     uint32_t const_mask,
+                                     uint32_t num_srcs,
+                                     struct vtn_type **src_types)
+{
+   char *mname;
+   nir_function *found = NULL;
+
+   vtn_opencl_mangle(name, const_mask, num_srcs, src_types, &mname);
+   /* Try to find the function in the current shader first. */
+   nir_foreach_function(funcs, b->shader) {
+      if (!strcmp(funcs->name, mname)) {
+         found = funcs;
+         break;
+      }
+   }
+   /* If it is not found there, look in the clc shader and create a declaration mirroring it. */
+   if (!found && b->options->clc_shader && b->options->clc_shader != b->shader) {
+      nir_foreach_function(funcs, b->options->clc_shader) {
+         if (!strcmp(funcs->name, mname)) {
+            found = funcs;
+            break;
+         }
+      }
+      if (found) {
+         nir_function *decl = nir_function_create(b->shader, mname);
+         decl->num_params = found->num_params;
+         decl->params = ralloc_array(b->shader, nir_parameter, decl->num_params);
+         for (unsigned i = 0; i < decl->num_params; i++) {
+            decl->params[i] = found->params[i];
+         }
+         found = decl;
+      }
+   }
+   if (!found)
+      vtn_fail("Can't find clc function %s\n", mname); /* NOTE(review): if vtn_fail() does not return, mname is never freed - confirm */
+   free(mname);
+   return found;
+}
+/* Emit a call to the mangled function: an optional "return_tmp" deref is passed first, then srcs[0..num_srcs). */
+static bool call_mangled_function(struct vtn_builder *b,
+                                  const char *name,
+                                  uint32_t const_mask,
+                                  uint32_t num_srcs,
+                                  struct vtn_type **src_types,
+                                  const struct vtn_type *dest_type,
+                                  nir_ssa_def **srcs,
+                                  nir_deref_instr **ret_deref_ptr)
+{
+   nir_function *found = mangle_and_find(b, name, const_mask, num_srcs, src_types);
+   if (!found)
+      return false;
+
+   nir_call_instr *call = nir_call_instr_create(b->shader, found);
+
+   nir_deref_instr *ret_deref = NULL;
+   uint32_t param_idx = 0;
+   if (dest_type) {
+      nir_variable *ret_tmp = nir_local_variable_create(b->nb.impl,
+                                                        glsl_get_bare_type(dest_type->type),
+                                                        "return_tmp");
+      ret_deref = nir_build_deref_var(&b->nb, ret_tmp);
+      call->params[param_idx++] = nir_src_for_ssa(&ret_deref->dest.ssa);
+   }
+
+   for (unsigned i = 0; i < num_srcs; i++)
+      call->params[param_idx++] = nir_src_for_ssa(srcs[i]);
+   nir_builder_instr_insert(&b->nb, &call->instr);
+
+   *ret_deref_ptr = ret_deref; /* stays NULL when dest_type is NULL */
+   return true;
+}
+
 static void
 handle_instr(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
              const uint32_t *w, unsigned count, nir_handler handler)
@@ -129,6 +297,189 @@ handle_alu(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
    return ret;
 }
 
+#define REMAP(op, str) [OpenCLstd_##op] = { str }
+static const struct { /* indexed by OpenCLstd opcode; opcodes without a REMAP entry have fn == NULL */
+   const char *fn;
+} remap_table[] = {
+   REMAP(Distance, "distance"),
+   REMAP(Fast_distance, "fast_distance"),
+   REMAP(Fast_length, "fast_length"),
+   REMAP(Fast_normalize, "fast_normalize"),
+   REMAP(Half_rsqrt, "half_rsqrt"),
+   REMAP(Half_sqrt, "half_sqrt"),
+   REMAP(Length, "length"),
+   REMAP(Normalize, "normalize"),
+   REMAP(Degrees, "degrees"),
+   REMAP(Radians, "radians"),
+   REMAP(Rotate, "rotate"),
+   REMAP(Smoothstep, "smoothstep"),
+   REMAP(Step, "step"),
+
+   REMAP(Pow, "pow"),
+   REMAP(Pown, "pown"),
+   REMAP(Powr, "powr"),
+   REMAP(Rootn, "rootn"),
+   REMAP(Modf, "modf"),
+
+   REMAP(Acos, "acos"),
+   REMAP(Acosh, "acosh"),
+   REMAP(Acospi, "acospi"),
+   REMAP(Asin, "asin"),
+   REMAP(Asinh, "asinh"),
+   REMAP(Asinpi, "asinpi"),
+   REMAP(Atan, "atan"),
+   REMAP(Atan2, "atan2"),
+   REMAP(Atanh, "atanh"),
+   REMAP(Atanpi, "atanpi"),
+   REMAP(Atan2pi, "atan2pi"),
+   REMAP(Cos, "cos"),
+   REMAP(Cosh, "cosh"),
+   REMAP(Cospi, "cospi"),
+   REMAP(Sin, "sin"),
+   REMAP(Sinh, "sinh"),
+   REMAP(Sinpi, "sinpi"),
+   REMAP(Tan, "tan"),
+   REMAP(Tanh, "tanh"),
+   REMAP(Tanpi, "tanpi"),
+   REMAP(Sincos, "sincos"),
+   REMAP(Fract, "fract"),
+   REMAP(Frexp, "frexp"),
+   REMAP(Fma, "fma"),
+   REMAP(Fmod, "fmod"),
+
+   REMAP(Half_cos, "cos"),
+   REMAP(Half_exp, "exp"),
+   REMAP(Half_exp2, "exp2"),
+   REMAP(Half_exp10, "exp10"),
+   REMAP(Half_log, "log"),
+   REMAP(Half_log2, "log2"),
+   REMAP(Half_log10, "log10"),
+   REMAP(Half_powr, "powr"),
+   REMAP(Half_sin, "sin"),
+   REMAP(Half_tan, "tan"),
+
+   REMAP(Remainder, "remainder"),
+   REMAP(Remquo, "remquo"),
+   REMAP(Hypot, "hypot"),
+   REMAP(Exp, "exp"),
+   REMAP(Exp2, "exp2"),
+   REMAP(Exp10, "exp10"),
+   REMAP(Expm1, "expm1"),
+   REMAP(Ldexp, "ldexp"),
+
+   REMAP(Ilogb, "ilogb"),
+   REMAP(Log, "log"),
+   REMAP(Log2, "log2"),
+   REMAP(Log10, "log10"),
+   REMAP(Log1p, "log1p"),
+   REMAP(Logb, "logb"),
+
+   REMAP(Cbrt, "cbrt"),
+   REMAP(Erfc, "erfc"),
+   REMAP(Erf, "erf"),
+
+   REMAP(Lgamma, "lgamma"),
+   REMAP(Lgamma_r, "lgamma_r"),
+   REMAP(Tgamma, "tgamma"),
+
+   REMAP(UMad_sat, "mad_sat"),
+   REMAP(SMad_sat, "mad_sat"),
+
+   REMAP(Shuffle, "shuffle"),
+   REMAP(Shuffle2, "shuffle2"),
+};
+#undef REMAP
+/* Return the unmangled function name for opcode, or NULL if opcode is past the table end or has no entry. */
+static const char *remap_clc_opcode(enum OpenCLstd_Entrypoints opcode)
+{
+   if (opcode >= (sizeof(remap_table) / sizeof(const char *))) /* NOTE(review): correct only while each entry is exactly one pointer */
+      return NULL;
+   return remap_table[opcode].fn;
+}
+/* Wrap a scalar or vector glsl_type in a newly rzalloc'ed vtn_type. */
+static struct vtn_type *
+get_vtn_type_for_glsl_type(struct vtn_builder *b, const struct glsl_type *type)
+{
+   struct vtn_type *ret = rzalloc(b, struct vtn_type);
+   assert(glsl_type_is_vector_or_scalar(type));
+   ret->type = type;
+   ret->length = glsl_get_vector_elements(type);
+   ret->base_type = glsl_type_is_vector(type) ? vtn_base_type_vector : vtn_base_type_scalar;
+   return ret;
+}
+/* Build a newly rzalloc'ed pointer vtn_type whose deref is t and whose storage class is storage_class. */
+static struct vtn_type *
+get_pointer_type(struct vtn_builder *b, struct vtn_type *t, SpvStorageClass storage_class)
+{
+   struct vtn_type *ret = rzalloc(b, struct vtn_type);
+   ret->type = nir_address_format_to_glsl_type(
+            vtn_mode_to_address_format(
+               b, vtn_storage_class_to_mode(b, storage_class, NULL, NULL)));
+   ret->base_type = vtn_base_type_pointer;
+   ret->storage_class = storage_class;
+   ret->deref = t;
+   return ret;
+}
+/* Return a new type with t's base type passed through glsl_signed_base_type_of(); pointers are rebuilt around the converted pointee. */
+static struct vtn_type *
+get_signed_type(struct vtn_builder *b, struct vtn_type *t)
+{
+   if (t->base_type == vtn_base_type_pointer) {
+      return get_pointer_type(b, get_signed_type(b, t->deref), t->storage_class);
+   }
+   return get_vtn_type_for_glsl_type(
+      b, glsl_vector_type(glsl_signed_base_type_of(glsl_get_base_type(t->type)),
+                          glsl_get_vector_elements(t->type)));
+}
+/* Call the function remap_table maps opcode to; returns the loaded return value, or NULL if there is no mapping or no dest_type. */
+static nir_ssa_def *
+handle_clc_fn(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
+              int num_srcs,
+              nir_ssa_def **srcs,
+              struct vtn_type **src_types,
+              const struct vtn_type *dest_type)
+{
+   const char *name = remap_clc_opcode(opcode);
+   if (!name)
+      return NULL;
+
+   /* Some functions which take params end up with uint (or pointer-to-uint) being passed,
+    * which doesn't mangle correctly when the function expects int or pointer-to-int.
+    * See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers
+    */
+   int signed_param = -1; /* index into src_types of the one argument to re-type as signed, or -1 for none */
+   switch (opcode) {
+   case OpenCLstd_Frexp:
+   case OpenCLstd_Lgamma_r:
+   case OpenCLstd_Pown:
+   case OpenCLstd_Rootn:
+   case OpenCLstd_Ldexp:
+      signed_param = 1;
+      break;
+   case OpenCLstd_Remquo:
+      signed_param = 2;
+      break;
+   case OpenCLstd_SMad_sat: {
+      /* All parameters need to be converted to signed */
+      src_types[0] = src_types[1] = src_types[2] = get_signed_type(b, src_types[0]);
+      break;
+   }
+   default: break;
+   }
+
+   if (signed_param >= 0) {
+      src_types[signed_param] = get_signed_type(b, src_types[signed_param]);
+   }
+
+   nir_deref_instr *ret_deref = NULL;
+
+   if (!call_mangled_function(b, name, 0, num_srcs, src_types,
+                              dest_type, srcs, &ret_deref))
+      return NULL;
+
+   return ret_deref ? nir_load_deref(&b->nb, ret_deref) : NULL;
+}
+
 static nir_ssa_def *
 handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
                unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
@@ -138,6 +489,7 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
 
    switch (opcode) {
    case OpenCLstd_SAbs_diff:
+      /* these are easier to implement directly in NIR */
       return nir_iabs_diff(nb, srcs[0], srcs[1]);
    case OpenCLstd_UAbs_diff:
       return nir_uabs_diff(nb, srcs[0], srcs[1]);
@@ -207,6 +559,7 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
       return nir_sge(nb, srcs[1], srcs[0]);
    case OpenCLstd_S_Upsample:
    case OpenCLstd_U_Upsample:
+      /* SPIR-V and CL define upsample differently, so just implement it in NIR */
       return nir_upsample(nb, srcs[0], srcs[1]);
    case OpenCLstd_Native_exp:
       return nir_fexp(nb, srcs[0]);
@@ -219,9 +572,14 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
    case OpenCLstd_Native_tan:
       return nir_ftan(nb, srcs[0]);
    default:
-      vtn_fail("No NIR equivalent");
-      return NULL;
+      break;
    }
+
+   nir_ssa_def *ret = handle_clc_fn(b, opcode, num_srcs, srcs, src_types, dest_type);
+   if (!ret)
+      vtn_fail("No NIR equivalent");
+
+   return ret;
 }
 
 static void |