HikoGUI
A low latency retained GUI
Loading...
Searching...
No Matches
float16.hpp
1// Copyright Take Vos 2020.
2// Distributed under the Boost Software License, Version 1.0.
3// (See accompanying file LICENSE_1_0.txt or copy at https://www.boost.org/LICENSE_1_0.txt)
4
5#pragma once
6
#include <emmintrin.h>
#include <immintrin.h>

#include <cstdint>
#include <functional>
#include <type_traits>

#include "rapid/numeric_array.hpp"
11
12namespace tt {
13
/** Exponent bias of an IEEE-754 half-precision float (binary16). */
inline constexpr uint32_t float16_bias = 15;

/** Exponent bias of an IEEE-754 single-precision float (binary32). */
inline constexpr uint32_t float32_bias = 127;

/** Difference between the float-32 and the half-float exponent bias. */
inline constexpr uint32_t f32_to_f16_adjustment_exponent = float32_bias - float16_bias;

/** Float-32 exponent of the lowest normal half-float (half-float exponent 0x01). */
inline constexpr uint32_t f32_to_f16_lowest_normal_exponent = 0x01 + f32_to_f16_adjustment_exponent;

/** Float-32 exponent corresponding to the half-float infinity/NaN exponent (0x1f). */
inline constexpr uint32_t f32_to_f16_infinite_exponent = 0x1f + f32_to_f16_adjustment_exponent;

/** The exponent bias adjustment, shifted into the float-32 exponent field (bits 23-30). */
inline constexpr uint32_t f32_to_f16_adjustment = f32_to_f16_adjustment_exponent << 23;

/** Float-32 bit pattern of the lowest normal half-float value, 2^-14. */
inline constexpr uint32_t f32_to_f16_lowest_normal = f32_to_f16_lowest_normal_exponent << 23;

/** Float-32 bit pattern of 2^16, the first magnitude whose half-float exponent is all ones. */
inline constexpr uint32_t f32_to_f16_infinite = f32_to_f16_infinite_exponent << 23;

static_assert(f32_to_f16_lowest_normal == 0x3880'0000, "lowest normal half-float must be 2^-14 as float-32 bits");
static_assert(f32_to_f16_infinite == 0x4780'0000, "half-float infinity threshold must be 2^16 as float-32 bits");
22
23// Test with greater or equal is slow, so test with greater than, adjust lowerst_normal.
24inline constinit u32x4 f32_to_f16_constants = u32x4{f32_to_f16_lowest_normal - 1, f32_to_f16_infinite, f32_to_f16_adjustment, 0};
25
26constexpr f32x4 f16x8_to_f32x4(i16x8 value) noexcept
27{
28 // Convert the 16 bit values to 32 bit with leading zeros.
29 auto u = bit_cast<u32x4>(i16x8::interleave_lo(value, i16x8{}));
30
31 // Extract the sign bit.
32 auto sign = (u >> 15) << 31;
33
34 // Strip the sign bit and align the exponent/mantissa boundary to a float 32.
35 u = (u << 17) >> 4;
36
37 // Adjust the bias.
38 u = u + f32_to_f16_constants.zzzz();
39
40 // Get a mask of '1' bits when the half-float would be normal or infinite.
41 auto is_normal = bit_cast<u32x4>(gt_mask(bit_cast<i32x4>(u), bit_cast<i32x4>(f32_to_f16_constants.xxxx())));
42
43 // Add the sign bit back in.
44 u = u | bit_cast<u32x4>(sign);
45
46 // Keep the value if normal, if denormal make it zero.
47 u = u & is_normal;
48
49 return bit_cast<f32x4>(u);
50}
51
52constexpr i16x8 f32x4_to_f16x8(f32x4 value) noexcept
53{
54 // Interpret the floating point number as 32 bit-field.
55 auto u = bit_cast<u32x4>(value);
56
57 // Get the sign of the floating point number as a bit mask of the upper 17 bits.
58 auto sign = (bit_cast<i32x4>(u) >> 31) << 15;
59
60 // Strip sign bit.
61 u = (u << 1) >> 1;
62
63 // Get a mask of '1' bits when the half-float would be normal or infinite.
64 auto is_normal = bit_cast<u32x4>(gt_mask(bit_cast<i32x4>(u), bit_cast<i32x4>(f32_to_f16_constants.xxxx())));
65
66 // Clamp the floating point number to where the half-float would be infinite.
67 u = min(u, f32_to_f16_constants.yyyy());
68
69 // Convert the bias from float to half-float.
70 u = u - f32_to_f16_constants.zzzz();
71
72 // Shift the float until it becomes a half-float. This truncates the mantissa.
73 u = u >> 13;
74
75 // Keep the value if normal, if denormal make it zero.
76 u = u & is_normal;
77
78 // Add the sign bit back in, also set the upper 16 bits so that saturated pack
79 // will work correctly when converting to int16.
80 u = u | bit_cast<u32x4>(sign);
81
82 // Saturate and pack the 32 bit integers to 16 bit integers.
83 auto tmp = bit_cast<i32x4>(u);
84 return i16x8{tmp, tmp};
85}
86
87class float16 {
88 uint16_t v;
89
90public:
91 float16() noexcept : v() {}
92
93 template<typename T, std::enable_if_t<std::is_arithmetic_v<T>,int> = 0>
94 float16(T const &rhs) noexcept {
95 ttlet tmp1 = f32x4{narrow_cast<float>(rhs)};
96 ttlet tmp2 = f32x4_to_f16x8(tmp1);
97 v = tmp2.x();
98 }
99
100 template<typename T, std::enable_if_t<std::is_arithmetic_v<T>,int> = 0>
101 float16 &operator=(T const &rhs) noexcept {
102 ttlet tmp1 = f32x4{narrow_cast<float>(rhs)};
103 ttlet tmp2 = f32x4_to_f16x8(tmp1);
104 v = tmp2.x();
105 return *this;
106 }
107
108 operator float () const noexcept {
109 ttlet tmp1 = i16x8{static_cast<int16_t>(v)};
110 ttlet tmp2 = f16x8_to_f32x4(tmp1);
111 return tmp2.x();
112 }
113
114 static float16 from_uint16_t(uint16_t const rhs) noexcept
115 {
116 auto r = float16{};
117 r.v = rhs;
118 return r;
119 }
120
121 [[nodiscard]] constexpr uint16_t get() const noexcept {
122 return v;
123 }
124
125 constexpr float16 &set(uint16_t rhs) noexcept {
126 v = rhs;
127 return *this;
128 }
129
130 [[nodiscard]] size_t hash() const noexcept
131 {
132 return std::hash<uint16_t>{}(v);
133 }
134
135 [[nodiscard]] friend bool operator==(float16 const &lhs, float16 const &rhs) noexcept {
136 return lhs.v == rhs.v;
137 }
138
139 [[nodiscard]] friend bool operator!=(float16 const &lhs, float16 const &rhs) noexcept {
140 return lhs.v != rhs.v;
141 }
142};
143
144}
145
146namespace std {
147
148template<>
149struct std::hash<tt::float16> {
150 size_t operator()(tt::float16 const &rhs) noexcept
151 {
152 return rhs.hash();
153 }
154};
155
156}
STL namespace.
Definition float16.hpp:87
static constexpr numeric_array interleave_lo(numeric_array a, numeric_array b) noexcept
Interleave the first words in both arrays.
Definition numeric_array.hpp:416
T min(T... args)
T operator()(T... args)