From 1971efe5ba7c2e3fdeacfe6b019073c237baf55b Mon Sep 17 00:00:00 2001 From: Erik Faye-Lund Date: Wed, 7 Apr 2021 16:50:22 +0200 Subject: zink: support emitting 16-bit float types This prepares us for being able to support using 16-bit float types in shaders, which might help performance in some cases. Reviewed-By: Mike Blumenkrantz Part-of: --- .../drivers/zink/nir_to_spirv/nir_to_spirv.c | 33 +++++++++++++++++----- .../drivers/zink/nir_to_spirv/spirv_builder.c | 13 ++++++--- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c index 75d75bab541..b55e026a950 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c +++ b/src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c @@ -220,7 +220,7 @@ get_atomic_op(nir_intrinsic_op op) static SpvId emit_float_const(struct ntv_context *ctx, int bit_size, double value) { - assert(bit_size == 32 || bit_size == 64); + assert(bit_size == 16 || bit_size == 32 || bit_size == 64); return spirv_builder_const_float(&ctx->builder, bit_size, value); } @@ -241,7 +241,7 @@ emit_int_const(struct ntv_context *ctx, int bit_size, int64_t value) static SpvId get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components) { - assert(bit_size == 32 || bit_size == 64); + assert(bit_size == 16 || bit_size == 32 || bit_size == 64); SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size); if (num_components > 1) @@ -312,6 +312,9 @@ get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type) case GLSL_TYPE_BOOL: return spirv_builder_type_bool(&ctx->builder); + case GLSL_TYPE_FLOAT16: + return spirv_builder_type_float(&ctx->builder, 16); + case GLSL_TYPE_FLOAT: return spirv_builder_type_float(&ctx->builder, 32); @@ -1389,7 +1392,7 @@ static SpvId get_fvec_constant(struct ntv_context *ctx, unsigned bit_size, unsigned num_components, double value) { - assert(bit_size == 32 || bit_size == 64); + assert(bit_size == 16 || bit_size == 32 || bit_size == 64); SpvId result = emit_float_const(ctx, bit_size, value); if (num_components == 1) @@ -1578,12 +1581,15 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu) UNOP(nir_op_f2u16, SpvOpConvertFToU) UNOP(nir_op_f2i32, SpvOpConvertFToS) UNOP(nir_op_f2u32, SpvOpConvertFToU) + UNOP(nir_op_i2f16, SpvOpConvertSToF) UNOP(nir_op_i2f32, SpvOpConvertSToF) + UNOP(nir_op_u2f16, SpvOpConvertUToF) UNOP(nir_op_u2f32, SpvOpConvertUToF) UNOP(nir_op_i2i16, SpvOpSConvert) UNOP(nir_op_i2i32, SpvOpSConvert) UNOP(nir_op_u2u16, SpvOpUConvert) UNOP(nir_op_u2u32, SpvOpUConvert) + UNOP(nir_op_f2f16, SpvOpFConvert) UNOP(nir_op_f2f32, SpvOpFConvert) UNOP(nir_op_f2i64, SpvOpConvertFToS) UNOP(nir_op_f2u64, SpvOpConvertFToU) @@ -1612,6 +1618,7 @@ emit_alu(struct ntv_context *ctx, nir_alu_instr *alu) get_ivec_constant(ctx, bit_size, num_components, 0)); break; + case nir_op_b2f16: case nir_op_b2f32: case nir_op_b2f64: assert(nir_op_infos[alu->op].num_inputs == 1); @@ -1809,10 +1816,19 @@ emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const) load_const->value[i].b); } else { - for (int i = 0; i < num_components; i++) - components[i] = emit_uint_const(ctx, bit_size, - bit_size == 64 ? load_const->value[i].u64 : load_const->value[i].u32); - + for (int i = 0; i < num_components; i++) { + if (bit_size == 16) + components[i] = emit_uint_const(ctx, bit_size, + load_const->value[i].u16); + else if (bit_size == 32) + components[i] = emit_uint_const(ctx, bit_size, + load_const->value[i].u32); + else if (bit_size == 64) + components[i] = emit_uint_const(ctx, bit_size, + load_const->value[i].u64); + else + unreachable("unhandled constant bit size!"); + } } constant = spirv_builder_const_composite(&ctx->builder, type, components, num_components); @@ -3587,6 +3603,9 @@ nir_to_spirv(struct nir_shader *s, const struct zink_so_info *so_info, bool spir spirv_builder_emit_cap(&ctx.builder, SpvCapabilityInt16); if (s->info.bit_sizes_int & 64) spirv_builder_emit_cap(&ctx.builder, SpvCapabilityInt64); + + if (s->info.bit_sizes_float & 16) + spirv_builder_emit_cap(&ctx.builder, SpvCapabilityFloat16); if (s->info.bit_sizes_float & 64) spirv_builder_emit_cap(&ctx.builder, SpvCapabilityFloat64); diff --git a/src/gallium/drivers/zink/nir_to_spirv/spirv_builder.c b/src/gallium/drivers/zink/nir_to_spirv/spirv_builder.c index a6fe423c0f9..40898d560a7 100644 --- a/src/gallium/drivers/zink/nir_to_spirv/spirv_builder.c +++ b/src/gallium/drivers/zink/nir_to_spirv/spirv_builder.c @@ -28,6 +28,7 @@ #include "util/ralloc.h" #include "util/u_bitcast.h" #include "util/u_memory.h" +#include "util/half_float.h" #include "util/hash_table.h" #define XXH_INLINE_ALL #include "util/xxhash.h" @@ -1408,7 +1409,7 @@ spirv_builder_const_bool(struct spirv_builder *b, bool val) SpvId spirv_builder_const_int(struct spirv_builder *b, int width, int64_t val) { - assert(width >= 32); + assert(width >= 16); SpvId type = spirv_builder_type_int(b, width); if (width <= 32) return emit_constant_32(b, type, val); @@ -1437,12 +1438,16 @@ spirv_builder_spec_const_uint(struct spirv_builder *b, int width) SpvId spirv_builder_const_float(struct spirv_builder *b, int width, double val) { - assert(width >= 32); + assert(width >= 16); SpvId type = spirv_builder_type_float(b, width); - if (width <= 32) + if (width == 16) + return emit_constant_32(b, type, _mesa_float_to_half(val)); + else if (width == 32) return emit_constant_32(b, type, u_bitcast_f2u(val)); - else + else if (width == 64) return emit_constant_64(b, type, u_bitcast_d2u(val)); + + unreachable("unhandled float-width"); } SpvId -- cgit v1.2.3