/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.util;

import ai.djl.ndarray.NDManager;
import java.nio.ByteBuffer;
import java.nio.ShortBuffer;

/**
 * Conversions between IEEE 754 half-precision (binary16) values, stored as raw {@code short} bit
 * patterns, and Java {@code float}s.
 */
public final class Float16Utils {

    /** The half-precision bit pattern of {@code 1.0f} ({@code 0x3C00}). */
    public static final short ONE = Float16Utils.floatToHalf(1.0f);

    private Float16Utils() {
    }

    /**
     * Decodes the remaining bytes of {@code buffer} as half-precision values.
     *
     * <p>The bytes are read through {@link ByteBuffer#asShortBuffer()}. They are therefore
     * interpreted in {@code buffer}'s byte order, and {@code buffer}'s own position is not
     * advanced. A trailing odd byte, if any, is ignored.
     *
     * @param buffer the bytes holding consecutive half-precision values
     * @return the decoded values, one {@code float} per two bytes
     */
    public static float[] fromByteBuffer(ByteBuffer buffer) {
        return fromShortBuffer(buffer.asShortBuffer());
    }

    /**
     * Decodes the remaining elements of {@code buffer} as half-precision values.
     *
     * <p>This consumes the buffer: its position is advanced to its limit.
     *
     * @param buffer the half-precision bit patterns to decode
     * @return the decoded values, in buffer order
     */
    public static float[] fromShortBuffer(ShortBuffer buffer) {
        float[] ret = new float[buffer.remaining()];
        for (int i = 0; i < ret.length; i++) {
            ret[i] = halfToFloat(buffer.get());
        }
        return ret;
    }

    /**
     * Encodes {@code floats} as half-precision values into a buffer obtained from {@code
     * manager.allocateDirect}.
     *
     * <p>Each value is converted with {@link #floatToHalf(float)} and written with {@link
     * ByteBuffer#putShort(short)}, i.e. in the byte order of the allocated buffer. The buffer is
     * rewound before it is returned.
     *
     * @param manager the manager used to allocate the output buffer
     * @param floats the values to encode
     * @return a rewound buffer of {@code 2 * floats.length} bytes
     */
    public static ByteBuffer toByteBuffer(NDManager manager, float[] floats) {
        ByteBuffer buffer = manager.allocateDirect(floats.length * 2);
        for (float f : floats) {
            buffer.putShort(floatToHalf(f));
        }
        buffer.rewind();
        return buffer;
    }

    /**
     * Converts a {@code float} to its half-precision bit pattern.
     *
     * <p>Rounding is to nearest with ties rounded away from zero (not IEEE round-half-to-even).
     * The function has these special cases:
     *
     * <ul>
     *   <li>Finite magnitudes of at least 65536 become signed infinity.
     *   <li>Magnitudes that would only reach 65536 through rounding saturate to the largest finite
     *       half, 65504 ({@code 0x7BFF}).
     *   <li>Magnitudes that round below 2^-25 become signed zero.
     *   <li>Infinities stay infinite, and NaN stays NaN.
     * </ul>
     *
     * @param fVal the value to convert
     * @return the half-precision bit pattern
     */
    public static short floatToHalf(float fVal) {
        // floatToIntBits collapses every NaN to 0x7FC00000, so the NaN payload kept below is never
        // zero.
        int bits = Float.floatToIntBits(fVal);
        int sign = (bits >>> 16) & 0x8000;
        int abs = bits & 0x7FFFFFFF;
        // Add half of the 13 mantissa bits that will be discarded, so the final shift rounds.
        int rounded = abs + 0x1000;

        // 0x47800000 == 65536.0f, the first magnitude that does not fit in a finite half.
        if (rounded >= 0x47800000) {
            if (abs >= 0x47800000) {
                // Test the UNROUNDED magnitude: finite values within 0x1000 ulps of
                // Float.MAX_VALUE must become infinity, not fall through to the NaN encoding.
                if (abs < 0x7F800000) {
                    return (short) (sign | 0x7C00);
                }
                // Infinity or NaN: keep the top mantissa bits so NaN stays NaN.
                return (short) (sign | 0x7C00 | ((abs & 0x7FFFFF) >>> 13));
            }
            // Only rounding would push this to infinity: saturate at the largest finite half.
            return (short) (sign | 0x7BFF);
        }

        // 0x38800000 == 2^-14, the smallest normal half.
        if (rounded >= 0x38800000) {
            // Rebias the exponent from 127 to 15 (subtract 112 << 23), then drop 13 mantissa bits.
            return (short) (sign | ((rounded - 0x38000000) >>> 13));
        }

        // 0x33000000 == 2^-25, half of the smallest subnormal half.
        if (rounded < 0x33000000) {
            return (short) sign;
        }

        // Subnormal result: restore the implicit leading 1, round, then shift into place.
        int exp = abs >>> 23;
        int mantissa = (abs & 0x7FFFFF) | 0x800000;
        int roundBit = 0x800000 >>> (exp - 102);
        return (short) (sign | ((mantissa + roundBit) >>> (126 - exp)));
    }

    /**
     * Converts a half-precision bit pattern to a {@code float}.
     *
     * <p>Every half value is exactly representable as a {@code float}, so the conversion is
     * exact. Subnormal halves are normalized, and infinities and NaNs are preserved.
     *
     * @param half the half-precision bit pattern
     * @return the equivalent {@code float}
     */
    public static float halfToFloat(short half) {
        int sign = (half & 0x8000) << 16;
        int mant = half & 0x3FF;
        // The exponent is kept at its half-precision bit position (<< 10) throughout.
        int exp = half & 0x7C00;
        if (exp == 0x7C00) {
            // Infinity / NaN: use the all-ones float exponent (0xFF << 10).
            exp = 0x3FC00;
        } else if (exp != 0) {
            // Normalized value: rebias the exponent from 15 to 127 (+112 << 10).
            // This must happen regardless of the mantissa.
            exp += 0x1C000;
        } else if (mant != 0) {
            // Subnormal: start at exponent 113 (2^-14), then shift the mantissa up until the
            // implicit leading bit appears, lowering the exponent by one per shift.
            exp = 0x1C400;
            do {
                mant <<= 1;
                exp -= 0x400;
            } while ((mant & 0x400) == 0);
            mant &= 0x3FF;
        }
        // Otherwise this is +/-0: exp and mant are already 0.
        return Float.intBitsToFloat(sign | ((exp | mant) << 13));
    }
}

