/* * Copyright © 2020 Collabora Ltd. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the "Software"), * to deal in the Software without restriction, including without limitation * the rights to use, copy, modify, merge, publish, distribute, sublicense, * and/or sell copies of the Software, and to permit persons to whom the * Software is furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice (including the next * paragraph) shall be included in all copies or substantial portions of the * Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS * IN THE SOFTWARE. */ #ifndef NIR_CONVERSION_BUILDER_H #define NIR_CONVERSION_BUILDER_H #include "util/u_math.h" #include "nir_builder.h" #include "nir_builtin_builder.h" #ifdef __cplusplus extern "C" { #endif static inline nir_ssa_def * nir_round_float_to_int(nir_builder *b, nir_ssa_def *src, nir_rounding_mode round) { switch (round) { case nir_rounding_mode_ru: return nir_fceil(b, src); case nir_rounding_mode_rd: return nir_ffloor(b, src); case nir_rounding_mode_rtne: return nir_fround_even(b, src); case nir_rounding_mode_undef: case nir_rounding_mode_rtz: break; } unreachable("unexpected rounding mode"); } static inline nir_ssa_def * nir_round_float_to_float(nir_builder *b, nir_ssa_def *src, unsigned dest_bit_size, nir_rounding_mode round) { unsigned src_bit_size = src->bit_size; if (dest_bit_size > src_bit_size) return src; /* No rounding is needed for an up-convert */ nir_op low_conv = nir_type_conversion_op(nir_type_float | src_bit_size, nir_type_float | dest_bit_size, nir_rounding_mode_undef); nir_op high_conv = nir_type_conversion_op(nir_type_float | dest_bit_size, nir_type_float | src_bit_size, nir_rounding_mode_undef); switch (round) { case nir_rounding_mode_ru: { /* If lower-precision conversion results in a lower value, push it * up one ULP. */ nir_ssa_def *lower_prec = nir_build_alu(b, low_conv, src, NULL, NULL, NULL); nir_ssa_def *roundtrip = nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL); nir_ssa_def *cmp = nir_flt(b, roundtrip, src); nir_ssa_def *inf = nir_imm_floatN_t(b, INFINITY, dest_bit_size); return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, inf), lower_prec); } case nir_rounding_mode_rd: { /* If lower-precision conversion results in a higher value, push it * down one ULP. */ nir_ssa_def *lower_prec = nir_build_alu(b, low_conv, src, NULL, NULL, NULL); nir_ssa_def *roundtrip = nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL); nir_ssa_def *cmp = nir_flt(b, src, roundtrip); nir_ssa_def *neg_inf = nir_imm_floatN_t(b, -INFINITY, dest_bit_size); return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, neg_inf), lower_prec); } case nir_rounding_mode_rtz: return nir_bcsel(b, nir_flt(b, src, nir_imm_zero(b, 1, src->bit_size)), nir_round_float_to_float(b, src, dest_bit_size, nir_rounding_mode_ru), nir_round_float_to_float(b, src, dest_bit_size, nir_rounding_mode_rd)); case nir_rounding_mode_rtne: case nir_rounding_mode_undef: break; } unreachable("unexpected rounding mode"); } static inline nir_ssa_def * nir_round_int_to_float(nir_builder *b, nir_ssa_def *src, nir_alu_type src_type, unsigned dest_bit_size, nir_rounding_mode round) { /* We only care whether or not its signed */ src_type = nir_alu_type_get_base_type(src_type); unsigned mantissa_bits; switch (dest_bit_size) { case 16: mantissa_bits = 10; break; case 32: mantissa_bits = 23; break; case 64: mantissa_bits = 52; break; default: unreachable("Unsupported bit size"); } if (src->bit_size < mantissa_bits) return src; if (src_type == nir_type_int) { nir_ssa_def *sign = nir_i2b1(b, nir_ishr(b, src, nir_imm_int(b, src->bit_size - 1))); nir_ssa_def *abs = nir_iabs(b, src); nir_ssa_def *positive_rounded = nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, round); nir_ssa_def *max_positive = nir_imm_intN_t(b, (1ull << (src->bit_size - 1)) - 1, src->bit_size); switch (round) { case nir_rounding_mode_rtz: return nir_bcsel(b, sign, nir_ineg(b, positive_rounded), positive_rounded); break; case nir_rounding_mode_ru: return nir_bcsel(b, sign, nir_ineg(b, nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_rd)), nir_umin(b, positive_rounded, max_positive)); break; case nir_rounding_mode_rd: return nir_bcsel(b, sign, nir_ineg(b, nir_umin(b, max_positive, nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_ru))), positive_rounded); case nir_rounding_mode_rtne: case nir_rounding_mode_undef: break; } unreachable("unexpected rounding mode"); } else { nir_ssa_def *mantissa_bit_size = nir_imm_int(b, mantissa_bits); nir_ssa_def *msb = nir_imax(b, nir_ufind_msb(b, src), mantissa_bit_size); nir_ssa_def *bits_to_lose = nir_isub(b, msb, mantissa_bit_size); nir_ssa_def *one = nir_imm_intN_t(b, 1, src->bit_size); nir_ssa_def *adjust = nir_ishl(b, one, bits_to_lose); nir_ssa_def *mask = nir_inot(b, nir_isub(b, adjust, one)); nir_ssa_def *truncated = nir_iand(b, src, mask); switch (round) { case nir_rounding_mode_rtz: case nir_rounding_mode_rd: return truncated; break; case nir_rounding_mode_ru: return nir_bcsel(b, nir_ieq(b, src, truncated), src, nir_uadd_sat(b, truncated, adjust)); case nir_rounding_mode_rtne: case nir_rounding_mode_undef: break; } unreachable("unexpected rounding mode"); } } /** Returns true if the representable range of a contains the representable * range of b. */ static inline bool nir_alu_type_range_contains_type_range(nir_alu_type a, nir_alu_type b) { /* Split types from bit sizes */ nir_alu_type a_base_type = nir_alu_type_get_base_type(a); nir_alu_type b_base_type = nir_alu_type_get_base_type(b); unsigned a_bit_size = nir_alu_type_get_type_size(a); unsigned b_bit_size = nir_alu_type_get_type_size(b); /* This requires sized types */ assert(a_bit_size > 0 && b_bit_size > 0); if (a_base_type == b_base_type && a_bit_size >= b_bit_size) return true; if (a_base_type == nir_type_int && b_base_type == nir_type_uint && a_bit_size > b_bit_size) return true; /* 16-bit floats fit in 32-bit integers */ if (a_base_type == nir_type_int && a_bit_size >= 32 && b == nir_type_float16) return true; /* All signed or unsigned ints can fit in float or above. A uint8 can fit * in a float16. */ if (a_base_type == nir_type_float && b_base_type != nir_type_float && (a_bit_size >= 32 || b_bit_size == 8)) return true; return false; } /** * Retrieves limits used for clamping a value of the src type into * the widest representable range of the dst type via cmp + bcsel */ static inline void nir_get_clamp_limits(nir_builder *b, nir_alu_type src_type, nir_alu_type dest_type, nir_ssa_def **low, nir_ssa_def **high) { /* Split types from bit sizes */ nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type); nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type); unsigned src_bit_size = nir_alu_type_get_type_size(src_type); unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type); assert(dest_bit_size != 0 && src_bit_size != 0); *low = NULL; *high = NULL; /* limits of the destination type, expressed in the source type */ switch (dest_base_type) { case nir_type_int: { int64_t ilow, ihigh; if (dest_bit_size == 64) { ilow = INT64_MIN; ihigh = INT64_MAX; } else { ilow = -(1ll << (dest_bit_size - 1)); ihigh = (1ll << (dest_bit_size - 1)) - 1; } if (src_base_type == nir_type_int) { *low = nir_imm_intN_t(b, ilow, src_bit_size); *high = nir_imm_intN_t(b, ihigh, src_bit_size); } else if (src_base_type == nir_type_uint) { assert(src_bit_size >= dest_bit_size); *high = nir_imm_intN_t(b, ihigh, src_bit_size); } else { *low = nir_imm_floatN_t(b, ilow, src_bit_size); *high = nir_imm_floatN_t(b, ihigh, src_bit_size); } break; } case nir_type_uint: { uint64_t uhigh = dest_bit_size == 64 ? ~0ull : (1ull << dest_bit_size) - 1; if (src_base_type != nir_type_float) { *low = nir_imm_intN_t(b, 0, src_bit_size); if (src_base_type == nir_type_uint || src_bit_size > dest_bit_size) *high = nir_imm_intN_t(b, uhigh, src_bit_size); } else { *low = nir_imm_floatN_t(b, 0.0f, src_bit_size); *high = nir_imm_floatN_t(b, uhigh, src_bit_size); } break; } case nir_type_float: { double flow, fhigh; switch (dest_bit_size) { case 16: flow = -65504.0f; fhigh = 65504.0f; break; case 32: flow = -FLT_MAX; fhigh = FLT_MAX; break; case 64: flow = -DBL_MAX; fhigh = DBL_MAX; break; default: unreachable("Unhandled bit size"); } switch (src_base_type) { case nir_type_int: { int64_t src_ilow, src_ihigh; if (src_bit_size == 64) { src_ilow = INT64_MIN; src_ihigh = INT64_MAX; } else { src_ilow = -(1ll << (src_bit_size - 1)); src_ihigh = (1ll << (src_bit_size - 1)) - 1; } if (src_ilow < flow) *low = nir_imm_intN_t(b, flow, src_bit_size); if (src_ihigh > fhigh) *high = nir_imm_intN_t(b, fhigh, src_bit_size); break; } case nir_type_uint: { uint64_t src_uhigh = src_bit_size == 64 ? ~0ull : (1ull << src_bit_size) - 1; if (src_uhigh > fhigh) *high = nir_imm_intN_t(b, fhigh, src_bit_size); break; } case nir_type_float: *low = nir_imm_floatN_t(b, flow, src_bit_size); *high = nir_imm_floatN_t(b, fhigh, src_bit_size); break; default: unreachable("Clamping from unknown type"); } break; } default: unreachable("clamping to unknown type"); break; } } /** * Clamp the value into the widest representatble range of the * destination type with cmp + bcsel. * * val/val_type: The variables used for bcsel * src/src_type: The variables used for comparison * dest_type: The type which determines the range used for comparison */ static inline nir_ssa_def * nir_clamp_to_type_range(nir_builder *b, nir_ssa_def *val, nir_alu_type val_type, nir_ssa_def *src, nir_alu_type src_type, nir_alu_type dest_type) { assert(nir_alu_type_get_type_size(src_type) == 0 || nir_alu_type_get_type_size(src_type) == src->bit_size); src_type |= src->bit_size; if (nir_alu_type_range_contains_type_range(dest_type, src_type)) return val; /* limits of the destination type, expressed in the source type */ nir_ssa_def *low = NULL, *high = NULL; nir_get_clamp_limits(b, src_type, dest_type, &low, &high); nir_ssa_def *low_cond = NULL, *high_cond = NULL; switch (nir_alu_type_get_base_type(src_type)) { case nir_type_int: low_cond = low ? nir_ilt(b, src, low) : NULL; high_cond = high ? nir_ilt(b, high, src) : NULL; break; case nir_type_uint: low_cond = low ? nir_ult(b, src, low) : NULL; high_cond = high ? nir_ult(b, high, src) : NULL; break; case nir_type_float: low_cond = low ? nir_fge(b, low, src) : NULL; high_cond = high ? nir_fge(b, src, high) : NULL; break; default: unreachable("clamping from unknown type"); } nir_ssa_def *val_low = low, *val_high = high; if (val_type != src_type) { nir_get_clamp_limits(b, val_type, dest_type, &val_low, &val_high); } nir_ssa_def *res = val; if (low_cond && val_low) res = nir_bcsel(b, low_cond, val_low, res); if (high_cond && val_high) res = nir_bcsel(b, high_cond, val_high, res); return res; } static inline nir_rounding_mode nir_simplify_conversion_rounding(nir_alu_type src_type, nir_alu_type dest_type, nir_rounding_mode rounding) { nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type); nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type); unsigned src_bit_size = nir_alu_type_get_type_size(src_type); unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type); assert(src_bit_size > 0 && dest_bit_size > 0); if (rounding == nir_rounding_mode_undef) return rounding; /* Pure integer conversion doesn't have any rounding */ if (src_base_type != nir_type_float && dest_base_type != nir_type_float) return nir_rounding_mode_undef; /* Float down-casts don't round */ if (src_base_type == nir_type_float && dest_base_type == nir_type_float && dest_bit_size >= src_bit_size) return nir_rounding_mode_undef; /* Regular float to int conversions are RTZ */ if (src_base_type == nir_type_float && dest_base_type != nir_type_float && rounding == nir_rounding_mode_rtz) return nir_rounding_mode_undef; /* The CL spec requires regular conversions to float to be RTNE */ if (dest_base_type == nir_type_float && rounding == nir_rounding_mode_rtne) return nir_rounding_mode_undef; /* Couldn't simplify */ return rounding; } static inline nir_ssa_def * nir_convert_with_rounding(nir_builder *b, nir_ssa_def *src, nir_alu_type src_type, nir_alu_type dest_type, nir_rounding_mode round, bool clamp) { /* Some stuff wants sized types */ assert(nir_alu_type_get_type_size(src_type) == 0 || nir_alu_type_get_type_size(src_type) == src->bit_size); src_type |= src->bit_size; /* Split types from bit sizes */ nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type); nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type); unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type); /* Try to simplify the conversion if we can */ clamp = clamp && !nir_alu_type_range_contains_type_range(dest_type, src_type); round = nir_simplify_conversion_rounding(src_type, dest_type, round); /* For float -> int/uint conversions, we might not be able to represent * the destination range in the source float accurately. For these cases, * do the comparison in float range, but the bcsel in the destination range. */ bool clamp_after_conversion = clamp && src_base_type == nir_type_float && dest_base_type != nir_type_float; /* * If we don't care about rounding and clamping, we can just use NIR's * built-in ops. There is also a special case for SPIR-V in shaders, where * f32/f64 -> f16 conversions can have one of two rounding modes applied, * which NIR has built-in opcodes for. * * For the rest, we have our own implementation of rounding and clamping. */ bool trivial_convert; if (!clamp && round == nir_rounding_mode_undef) { trivial_convert = true; } else if (!clamp && src_type == nir_type_float32 && dest_type == nir_type_float16 && (round == nir_rounding_mode_rtne || round == nir_rounding_mode_rtz)) { trivial_convert = true; } else { trivial_convert = false; } if (trivial_convert) { nir_op op = nir_type_conversion_op(src_type, dest_type, round); return nir_build_alu(b, op, src, NULL, NULL, NULL); } nir_ssa_def *dest = src; /* clamp the result into range */ if (clamp && !clamp_after_conversion) dest = nir_clamp_to_type_range(b, src, src_type, src, src_type, dest_type); /* round with selected rounding mode */ if (!trivial_convert && round != nir_rounding_mode_undef) { if (src_base_type == nir_type_float) { if (dest_base_type == nir_type_float) { dest = nir_round_float_to_float(b, dest, dest_bit_size, round); } else { dest = nir_round_float_to_int(b, dest, round); } } else { dest = nir_round_int_to_float(b, dest, src_type, dest_bit_size, round); } round = nir_rounding_mode_undef; } /* now we can convert the value */ nir_op op = nir_type_conversion_op(src_type, dest_type, round); dest = nir_build_alu(b, op, dest, NULL, NULL, NULL); if (clamp_after_conversion) dest = nir_clamp_to_type_range(b, dest, dest_type, src, src_type, dest_type); return dest; } #ifdef __cplusplus } #endif #endif /* NIR_CONVERSION_BUILDER_H */