OKLabPaletteMapper.java

/* ========================================================================
 * PlantUML : a free UML diagram generator
 * ========================================================================
 *
 * (C) Copyright 2009-2024, Arnaud Roques
 *
 * Project Info:  https://plantuml.com
 * 
 * If you like this project or if you find it useful, you can support us at:
 * 
 * https://plantuml.com/patreon (only 1$ per month!)
 * https://plantuml.com/paypal
 * 
 * This file is part of PlantUML.
 *
 * PlantUML is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PlantUML distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
 * License for more details.
 *
 * You should have received a copy of the GNU General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301,
 * USA.
 *
 *
 * Original Author:  Arnaud Roques
 * With assistance from ChatGPT (OpenAI)
 *
 */
package net.sourceforge.plantuml.png.quantx;

import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.awt.image.IndexColorModel;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public final class OKLabPaletteMapper {

    /** Number of entries in the generated 8-bit indexed color model. */
    private static final int MAX_COLORS = 256;

    /**
     * Pixels whose alpha is strictly below this value are sent straight to the palette's
     * transparent entry (when the palette has one) instead of going through the color search.
     */
    private static final int TRANSPARENT_ALPHA_THRESHOLD = 8;

    /**
     * Weight of the squared alpha difference (alpha in 0..255) that is added to the squared
     * OKLab distance. Tune if needed; 0 makes the search ignore alpha.
     */
    private static final float ALPHA_WEIGHT = 0.015f;

    // Floyd–Steinberg weights, named relative to the scan direction (dir = +1 or -1) of the row.
    private static final float FS_AHEAD = 7f / 16f; // (x + dir, y)
    private static final float FS_BEHIND_BELOW = 3f / 16f; // (x - dir, y + 1)
    private static final float FS_BELOW = 5f / 16f; // (x, y + 1)
    private static final float FS_AHEAD_BELOW = 1f / 16f; // (x + dir, y + 1)

    /**
     * {@code srgbToLinear(v / 255.0)} for every 8-bit channel value {@code v}, computed once so
     * the per-pixel conversion does not call {@code Math.pow}.
     */
    private static final double[] SRGB_TO_LINEAR = buildSrgbToLinearTable();

    /**
     * Maps every pixel of {@code src} to the nearest palette color and returns the result as an
     * 8-bit indexed image.
     *
     * <p>Colors are compared in OKLab space. The squared alpha difference, weighted by
     * {@code ALPHA_WEIGHT}, is added to that distance. If the palette contains a fully
     * transparent entry, pixels with alpha below {@code TRANSPARENT_ALPHA_THRESHOLD} are mapped
     * directly to the first such entry.
     *
     * @param src            image to quantize (read through {@link BufferedImage#getRGB})
     * @param palette        palette colors as ARGB keys; the map values are ignored. At most the
     *                       first 256 keys, in the map's iteration order, are used, and that order
     *                       defines the output indices.
     * @param floydSteinberg {@code true} to apply serpentine Floyd–Steinberg error diffusion in
     *                       OKLab space; {@code false} for plain nearest-color mapping
     * @return a {@link BufferedImage#TYPE_BYTE_INDEXED} image with the same size as {@code src}.
     *         Its 256-entry color model holds the palette, padded with the last palette color
     *         (or opaque black when the palette is empty). An empty palette yields an image where
     *         every pixel is index 0.
     */
    public static BufferedImage quantize(BufferedImage src, Map<Integer, Integer> palette, boolean floydSteinberg) {
        final int w = src.getWidth();
        final int h = src.getHeight();

        // Keep at most MAX_COLORS keys, in iteration order; their position is their output index.
        final int paletteSize = Math.min(MAX_COLORS, palette.size());
        final int[] paletteInts = new int[paletteSize];
        int filled = 0;
        for (Integer c : palette.keySet()) {
            if (filled == paletteSize) {
                break;
            }
            paletteInts[filled++] = c;
        }

        final BufferedImage dst = new BufferedImage(w, h, BufferedImage.TYPE_BYTE_INDEXED,
                buildColorModel(paletteInts));
        final byte[] out = ((DataBufferByte) dst.getRaster().getDataBuffer()).getData();

        if (paletteSize == 0) {
            // Nothing to match against: every pixel uses index 0.
            Arrays.fill(out, (byte) 0);
            return dst;
        }

        final LabPalette pal = new LabPalette(paletteInts);
        final int[] srcPixels = src.getRGB(0, 0, w, h, null, 0, w);
        if (floydSteinberg) {
            mapDithered(srcPixels, w, h, out, pal);
        } else {
            mapNearest(srcPixels, out, pal);
        }
        return dst;
    }

    /**
     * Builds the 256-entry color model. Palette entries come first; the remaining slots are
     * padded with the last palette color, or with opaque black if the palette is empty.
     */
    private static IndexColorModel buildColorModel(int[] paletteInts) {
        final byte[] reds = new byte[MAX_COLORS];
        final byte[] greens = new byte[MAX_COLORS];
        final byte[] blues = new byte[MAX_COLORS];
        final byte[] alphas = new byte[MAX_COLORS];
        final int fillArgb = (paletteInts.length == 0) ? 0xFF000000 : paletteInts[paletteInts.length - 1];
        for (int i = 0; i < MAX_COLORS; i++) {
            final int argb = (i < paletteInts.length) ? paletteInts[i] : fillArgb;
            alphas[i] = (byte) ((argb >>> 24) & 0xFF);
            reds[i] = (byte) ((argb >>> 16) & 0xFF);
            greens[i] = (byte) ((argb >>> 8) & 0xFF);
            blues[i] = (byte) (argb & 0xFF);
        }
        return new IndexColorModel(8, MAX_COLORS, reds, greens, blues, alphas);
    }

    /**
     * Plain nearest-color mapping with no dithering. Results are memoized per distinct ARGB
     * value, which helps flat-color images.
     */
    private static void mapNearest(int[] srcPixels, byte[] out, LabPalette pal) {
        final Map<Integer, Byte> cache = new HashMap<>(4096);
        for (int pos = 0; pos < srcPixels.length; pos++) {
            final int argb = srcPixels[pos];
            Byte index = cache.get(argb);
            if (index == null) {
                index = (byte) nearestIndex(argb, pal);
                cache.put(argb, index);
            }
            out[pos] = index;
        }
    }

    /**
     * Returns the palette index for one ARGB pixel, with no error diffusion. Nearly transparent
     * pixels go straight to the transparent entry when the palette has one.
     */
    private static int nearestIndex(int argb, LabPalette pal) {
        final int alpha = (argb >>> 24) & 0xFF;
        if (pal.isTransparentTarget(alpha)) {
            return pal.transparentIndex;
        }
        final float[] lab = srgbToOKLab((argb >>> 16) & 0xFF, (argb >>> 8) & 0xFF, argb & 0xFF);
        return pal.nearest(lab[0], lab[1], lab[2], alpha);
    }

    /**
     * Serpentine Floyd–Steinberg mapping with the error diffused in OKLab (L, a, b). Even rows
     * are scanned left to right and odd rows right to left; the kernel is mirrored with the
     * scan direction.
     */
    private static void mapDithered(int[] srcPixels, int w, int h, byte[] out, LabPalette pal) {
        // Accumulated error for the current row and for the row below it.
        float[] errL = new float[w];
        float[] errA = new float[w];
        float[] errB = new float[w];
        float[] nextL = new float[w];
        float[] nextA = new float[w];
        float[] nextB = new float[w];

        for (int y = 0; y < h; y++) {
            final int dir = ((y & 1) == 0) ? 1 : -1;
            final int rowOff = y * w;

            Arrays.fill(nextL, 0f);
            Arrays.fill(nextA, 0f);
            Arrays.fill(nextB, 0f);

            int x = (dir > 0) ? 0 : w - 1;
            for (int step = 0; step < w; step++, x += dir) {
                final int pos = rowOff + x;
                final int argb = srcPixels[pos];
                final int alpha = (argb >>> 24) & 0xFF;

                // Nearly transparent pixels neither consume nor propagate error.
                if (pal.isTransparentTarget(alpha)) {
                    out[pos] = (byte) pal.transparentIndex;
                    continue;
                }

                final float[] lab = srgbToOKLab((argb >>> 16) & 0xFF, (argb >>> 8) & 0xFF, argb & 0xFF);
                // Bound the error-adjusted color to fixed ranges before the palette search.
                final float l = clamp(lab[0] + errL[x], -0.5f, 1.5f);
                final float ca = clamp(lab[1] + errA[x], -1.5f, 1.5f);
                final float cb = clamp(lab[2] + errB[x], -1.5f, 1.5f);

                final int best = pal.nearest(l, ca, cb, alpha);
                out[pos] = (byte) best;

                final float eL = l - pal.okL[best];
                final float eA = ca - pal.okA[best];
                final float eB = cb - pal.okB[best];

                final int ahead = x + dir;
                final int behind = x - dir;
                final boolean hasAhead = ahead >= 0 && ahead < w;
                final boolean hasBehind = behind >= 0 && behind < w;
                if (hasAhead) {
                    addError(errL, errA, errB, ahead, eL, eA, eB, FS_AHEAD);
                }
                if (hasBehind) {
                    addError(nextL, nextA, nextB, behind, eL, eA, eB, FS_BEHIND_BELOW);
                }
                addError(nextL, nextA, nextB, x, eL, eA, eB, FS_BELOW);
                if (hasAhead) {
                    addError(nextL, nextA, nextB, ahead, eL, eA, eB, FS_AHEAD_BELOW);
                }
            }

            // The row below becomes the current row; the old buffers are reused (cleared above).
            float[] tmp;
            tmp = errL;
            errL = nextL;
            nextL = tmp;
            tmp = errA;
            errA = nextA;
            nextA = tmp;
            tmp = errB;
            errB = nextB;
            nextB = tmp;
        }
    }

    /** Adds {@code weight} times the (eL, eA, eB) error to column {@code x} of the given row buffers. */
    private static void addError(float[] rowL, float[] rowA, float[] rowB, int x,
                                 float eL, float eA, float eB, float weight) {
        rowL[x] += eL * weight;
        rowA[x] += eA * weight;
        rowB[x] += eB * weight;
    }

    /** Palette converted once to OKLab, together with each entry's alpha, for nearest-color queries. */
    private static final class LabPalette {
        final float[] okL;
        final float[] okA;
        final float[] okB;
        final int[] alpha;
        /** Index of the first entry with alpha 0, or -1 if there is none. */
        final int transparentIndex;

        LabPalette(int[] argbs) {
            final int n = argbs.length;
            okL = new float[n];
            okA = new float[n];
            okB = new float[n];
            alpha = new int[n];
            int transparent = -1;
            for (int i = 0; i < n; i++) {
                final int argb = argbs[i];
                final float[] lab = srgbToOKLab((argb >>> 16) & 0xFF, (argb >>> 8) & 0xFF, argb & 0xFF);
                okL[i] = lab[0];
                okA[i] = lab[1];
                okB[i] = lab[2];
                alpha[i] = (argb >>> 24) & 0xFF;
                // Simple detection: the first fully transparent entry wins.
                if (alpha[i] == 0 && transparent < 0) {
                    transparent = i;
                }
            }
            transparentIndex = transparent;
        }

        /** True when a pixel with this alpha should bypass the search and use the transparent entry. */
        boolean isTransparentTarget(int pixelAlpha) {
            return transparentIndex >= 0 && pixelAlpha < TRANSPARENT_ALPHA_THRESHOLD;
        }

        /**
         * Returns the index minimizing squared OKLab distance plus
         * {@code ALPHA_WEIGHT * (alpha difference)^2}. Ties keep the lowest index, and an exact
         * (zero-distance) match stops the scan early.
         */
        int nearest(float l0, float a0, float b0, int pixelAlpha) {
            int bestIndex = 0;
            double bestDist = Double.POSITIVE_INFINITY;
            for (int i = 0; i < okL.length; i++) {
                final float dL = l0 - okL[i];
                final float dA = a0 - okA[i];
                final float dB = b0 - okB[i];
                double dist = dL * dL + dA * dA + dB * dB;
                final float dAlpha = pixelAlpha - alpha[i];
                dist += ALPHA_WEIGHT * (dAlpha * dAlpha);
                if (dist < bestDist) {
                    bestDist = dist;
                    bestIndex = i;
                    if (bestDist == 0.0) {
                        break;
                    }
                }
            }
            return bestIndex;
        }
    }

    // ----------------------------------------------------------------------
    // Conversions sRGB -> OKLab
    // ----------------------------------------------------------------------

    /** Converts 8-bit sRGB channels (each 0..255) to OKLab {L, a, b}. */
    private static float[] srgbToOKLab(int red, int green, int blue) {
        // 1) sRGB -> linear RGB (table lookup, same values as srgbToLinear(v / 255.0))
        final double r = SRGB_TO_LINEAR[red];
        final double g = SRGB_TO_LINEAR[green];
        final double b = SRGB_TO_LINEAR[blue];

        // 2) linear RGB -> LMS (OKLab reference matrix)
        final double l = 0.4122214708 * r + 0.5363325363 * g + 0.0514459929 * b;
        final double m = 0.2119034982 * r + 0.6806995451 * g + 0.1073969566 * b;
        final double s = 0.0883024619 * r + 0.2817188376 * g + 0.6299787005 * b;

        // 3) cube roots
        final double lc = Math.cbrt(l);
        final double mc = Math.cbrt(m);
        final double sc = Math.cbrt(s);

        // 4) LMS' -> OKLab
        final double okL = 0.2104542553 * lc + 0.7936177850 * mc - 0.0040720468 * sc;
        final double okA = 1.9779984951 * lc - 2.4285922050 * mc + 0.4505937099 * sc;
        final double okB = 0.0259040371 * lc + 0.7827717662 * mc - 0.8086757660 * sc;

        return new float[] { (float) okL, (float) okA, (float) okB };
    }

    /** Precomputes the sRGB-to-linear transfer function for all 256 channel values. */
    private static double[] buildSrgbToLinearTable() {
        final double[] table = new double[256];
        for (int v = 0; v < table.length; v++) {
            table[v] = srgbToLinear(v / 255.0);
        }
        return table;
    }

    /** sRGB inverse gamma (electro-optical transfer function) for a channel in [0, 1]. */
    private static double srgbToLinear(double c) {
        return (c <= 0.04045) ? (c / 12.92) : Math.pow((c + 0.055) / 1.055, 2.4);
    }

    private static float clamp(float v, float lo, float hi) {
        return (v < lo) ? lo : (v > hi) ? hi : v;
    }
}