QuantizerWuOKLab.java

/* ========================================================================
 * PlantUML : a free UML diagram generator
 * ========================================================================
 *
 * (C) Copyright 2009-2024, Arnaud Roques
 *
 * Project Info:  https://plantuml.com
 * 
 * If you like this project or if you find it useful, you can support us at:
 * 
 * https://plantuml.com/patreon (only 1$ per month!)
 * https://plantuml.com/paypal
 * 
 * This file is part of PlantUML.
 *
 * PlantUML is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * PlantUML is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
 * License for more details.
 *
 * You should have received a copy of the GNU General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301,
 * USA.
 *
 *
 * Original Author:  Arnaud Roques
 * With assistance from ChatGPT (OpenAI)
 *
 */
package net.sourceforge.plantuml.png.quantx;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;


/**
 * Wu color quantization performed in the OKLab color space (L, a, b) instead of RGB.
 *
 * <p>Processing steps:
 * <ul>
 *   <li>build a 3D histogram over L, a and b (cell indices 1..32 on each axis; index 0 is
 *       kept empty so that the cumulative sums have a zero border, hence
 *       {@code INDEX_COUNT = 33});</li>
 *   <li>turn the histogram into 3D cumulative sums of the weight, of L, a, b and of
 *       (L^2 + a^2 + b^2);</li>
 *   <li>recursively split the box with the largest variance;</li>
 *   <li>convert the mean OKLab color of each resulting box back to 8-bit sRGB.</li>
 * </ul>
 *
 * <p>Notes:
 * <ul>
 *   <li>L is clamped to [0, 1]; a and b are clamped to [-0.5, 0.5] before binning.</li>
 *   <li>The L/a/b moments are stored as {@code double}; weights remain {@code int}.</li>
 *   <li>The alpha channel of the input is ignored and every produced color is opaque.</li>
 *   <li>Instances keep their working arrays in mutable fields, so a single instance must not
 *       be used by several threads at the same time.</li>
 * </ul>
 */
public final class QuantizerWuOKLab {

  // ---- OKLab quantization parameters ---------------------------------------

  /** Number of index bits per axis: 5 bits -> 32 bins per dimension. */
  private static final int INDEX_BITS = 5;
  /** Cells per axis in the cumulative arrays: (1 << INDEX_BITS) + 1 = 32 + 1. */
  private static final int INDEX_COUNT = 33;
  /** Total number of cells of the 3D arrays: 33^3 = 35937. */
  private static final int TOTAL_SIZE = INDEX_COUNT * INDEX_COUNT * INDEX_COUNT;
  /** Largest zero-based bin number: 31. */
  private static final int INDEX_MAX = (1 << INDEX_BITS) - 1;

  // Clamping bounds used before binning. The a/b bounds are meant to be safe for sRGB.
  private static final float L_MIN = 0f, L_MAX = 1f;
  private static final float A_MIN = -0.5f, A_MAX = 0.5f;
  private static final float B_MIN = -0.5f, B_MAX = 0.5f;

  // ---- Histograms / moments (3D compact indexing with INDEX_COUNT = 33) ------

  int[] weights;         // sum of weights
  double[] momentsL;     // sum of L * count
  double[] momentsA;     // sum of a * count
  double[] momentsB;     // sum of b * count
  double[] momentsSS;    // sum of (L^2 + a^2 + b^2) * count

  /** Boxes produced by {@link #createBoxes(int)}. */
  Box[] cubes;

  /**
   * Reduces the given pixels to a palette of at most {@code colorCount} colors.
   *
   * @param pixels     input pixels, forwarded unchanged to {@code QuantizerMap}
   * @param colorCount maximum number of palette colors requested
   * @return an insertion-ordered map whose keys are the palette colors (opaque ARGB) and
   *         whose values are all {@code 0}
   */
  public Map<Integer, Integer> quantize(int[] pixels, int colorCount) {
    // NOTE(review): QuantizerMap is presumably a "color -> occurrence count" histogram of the
    // pixels; that is at least how its result is consumed by constructHistogram().
    final Map<Integer, Integer> histogram = new QuantizerMap().quantize(pixels, colorCount);
    constructHistogram(histogram);
    createMoments();
    final CreateBoxesResult boxes = createBoxes(colorCount);

    final Map<Integer, Integer> palette = new LinkedHashMap<>();
    for (int color : createResult(boxes.resultCount)) {
      palette.put(color, 0);
    }
    return palette;
  }

  // ---- Histogram construction in OKLab -------------------------------------

  /**
   * Flattens a 3D cell coordinate into an index of the moment arrays.
   *
   * @return {@code l * 33^2 + a * 33 + b}
   */
  static int getIndex(int l, int a, int b) {
    // index = l*33^2 + a*33 + b, optimized using shifts (since 33 = 32 + 1)
    return (l << (INDEX_BITS * 2)) + (l << (INDEX_BITS + 1)) + l
         + (a << INDEX_BITS) + a + b;
  }

  /**
   * Allocates fresh moment arrays and accumulates every entry of {@code pixels} into the
   * histogram cell matching its OKLab coordinates.
   *
   * @param pixels map whose keys are read as (A)RGB colors and whose values are read as the
   *               number of occurrences of that color; the alpha byte of the key is ignored
   */
  void constructHistogram(Map<Integer, Integer> pixels) {
    weights   = new int[TOTAL_SIZE];
    momentsL  = new double[TOTAL_SIZE];
    momentsA  = new double[TOTAL_SIZE];
    momentsB  = new double[TOTAL_SIZE];
    momentsSS = new double[TOTAL_SIZE];

    for (Map.Entry<Integer, Integer> entry : pixels.entrySet()) {
      final int argb = entry.getKey();
      final int count = entry.getValue();

      final int r8 = (argb >> 16) & 0xFF;
      final int g8 = (argb >>  8) & 0xFF;
      final int b8 = (argb      ) & 0xFF;

      // Convert to OKLab (L, a, b)
      final float[] lab = srgb8ToOKLab(r8, g8, b8);

      // Quantize L/a/b into indices 1..32 (index 0 is the empty border of the integrals)
      final int index = getIndex(toIndexL(lab[0]), toIndexA(lab[1]), toIndexB(lab[2]));

      // Widen to double before any arithmetic so that the squared sum is computed with the
      // same precision as the first-order moments it is compared against in variance().
      final double labL = lab[0];
      final double labA = lab[1];
      final double labB = lab[2];

      weights[index]   += count;
      momentsL[index]  += labL * count;
      momentsA[index]  += labA * count;
      momentsB[index]  += labB * count;
      momentsSS[index] += count * (labL * labL + labA * labA + labB * labB);
    }
  }

  // ---- 3D cumulative integrals (moments) -----------------------------------

  /**
   * Converts the per-cell histogram values into 3D cumulative sums, in place.
   *
   * <p>After this call, the entry at {@code (l, a, b)} of each array holds the sum of the
   * original values of all cells {@code (l', a', b')} with {@code l' <= l}, {@code a' <= a}
   * and {@code b' <= b}. Cells with a zero coordinate are never written and stay at zero.
   */
  void createMoments() {
    for (int l = 1; l < INDEX_COUNT; ++l) {
      // area*[b]: sum over a' <= a and b' <= b within the current l slice
      final int[] areaW = new int[INDEX_COUNT];
      final double[] areaL = new double[INDEX_COUNT];
      final double[] areaA = new double[INDEX_COUNT];
      final double[] areaB = new double[INDEX_COUNT];
      final double[] areaSS = new double[INDEX_COUNT];

      for (int a = 1; a < INDEX_COUNT; ++a) {
        // line*: sum over b' <= b within the current (l, a) row
        int lineW = 0;
        double lineL = 0.0;
        double lineA = 0.0;
        double lineB = 0.0;
        double lineSS = 0.0;

        for (int b = 1; b < INDEX_COUNT; ++b) {
          final int index = getIndex(l, a, b);

          lineW  += weights[index];
          lineL  += momentsL[index];
          lineA  += momentsA[index];
          lineB  += momentsB[index];
          lineSS += momentsSS[index];

          areaW[b]  += lineW;
          areaL[b]  += lineL;
          areaA[b]  += lineA;
          areaB[b]  += lineB;
          areaSS[b] += lineSS;

          // The previous slice is already cumulative, so adding it completes the 3D sum.
          final int prev = getIndex(l - 1, a, b);
          weights[index]   = weights[prev]   + areaW[b];
          momentsL[index]  = momentsL[prev]  + areaL[b];
          momentsA[index]  = momentsA[prev]  + areaA[b];
          momentsB[index]  = momentsB[prev]  + areaB[b];
          momentsSS[index] = momentsSS[prev] + areaSS[b];
        }
      }
    }
  }

  // ---- Splitting (partitioning) --------------------------------------------

  /**
   * Partitions the histogram into at most {@code maxColorCount} boxes, always splitting the
   * box that currently has the largest variance.
   *
   * @param maxColorCount maximum number of boxes to create; values below 1 yield no box
   * @return the requested count together with the number of boxes actually usable
   */
  CreateBoxesResult createBoxes(int maxColorCount) {
    if (maxColorCount < 1) {
      // Nothing was requested: avoid indexing into an empty (or negative-sized) array.
      cubes = new Box[0];
      return new CreateBoxesResult(maxColorCount, 0);
    }

    cubes = new Box[maxColorCount];
    for (int i = 0; i < maxColorCount; i++) {
      cubes[i] = new Box();
    }

    final double[] volumeVariance = new double[maxColorCount];

    // The first box spans the whole index space (lower bounds are exclusive and start at 0).
    final Box first = cubes[0];
    first.l1 = INDEX_COUNT - 1;
    first.a1 = INDEX_COUNT - 1;
    first.b1 = INDEX_COUNT - 1;

    int generatedColorCount = maxColorCount;
    int next = 0;
    for (int i = 1; i < maxColorCount; i++) {
      if (cut(cubes[next], cubes[i])) {
        // A box made of a single cell cannot be split any further.
        volumeVariance[next] = (cubes[next].vol > 1) ? variance(cubes[next]) : 0.0;
        volumeVariance[i]    = (cubes[i].vol    > 1) ? variance(cubes[i])    : 0.0;
      } else {
        // This box cannot be split: never pick it again and retry slot i with another box.
        volumeVariance[next] = 0.0;
        i--;
      }

      // Select the box with the largest variance as the next one to split.
      next = 0;
      double best = volumeVariance[0];
      for (int j = 1; j <= i; j++) {
        if (volumeVariance[j] > best) {
          best = volumeVariance[j];
          next = j;
        }
      }
      if (best <= 0.0) {
        // No remaining box is worth splitting: stop early.
        generatedColorCount = i + 1;
        break;
      }
    }
    return new CreateBoxesResult(maxColorCount, generatedColorCount);
  }

  /**
   * Computes the palette: the mean OKLab color of each non-empty box, converted to sRGB.
   *
   * @param colorCount number of leading entries of {@link #cubes} to consider
   * @return the opaque ARGB colors of the non-empty boxes, in box order
   */
  List<Integer> createResult(int colorCount) {
    final List<Integer> colors = new ArrayList<>();
    for (int i = 0; i < colorCount; ++i) {
      final Box cube = cubes[i];
      final int w = volume(cube, weights);
      if (w > 0) {
        final double meanL = volume(cube, momentsL) / w;
        final double meanA = volume(cube, momentsA) / w;
        final double meanB = volume(cube, momentsB) / w;
        colors.add(oklabToSrgb8Clamp((float) meanL, (float) meanA, (float) meanB));
      }
    }
    return colors;
  }

  /**
   * Returns the weighted variance of a box: the sum of squares inside the box minus the
   * squared norm of its first-order moments divided by its weight.
   */
  double variance(Box cube) {
    final double dL = volume(cube, momentsL);
    final double dA = volume(cube, momentsA);
    final double dB = volume(cube, momentsB);

    final double xx = volume(cube, momentsSS);
    final double hyp = dL * dL + dA * dA + dB * dB;
    final int volW = volume(cube, weights);
    return xx - hyp / (double) volW;
  }

  /**
   * Splits box {@code one} in two along the best direction. On success {@code one} keeps the
   * lower part, {@code two} receives the upper part, and the {@code vol} of both is updated.
   *
   * @return {@code true} if the box was split, {@code false} if no valid cut exists (in which
   *         case neither box is modified)
   */
  boolean cut(Box one, Box two) {
    final double wholeL = volume(one, momentsL);
    final double wholeA = volume(one, momentsA);
    final double wholeB = volume(one, momentsB);
    final int    wholeW = volume(one, weights);

    final MaximizeResult maxL = maximize(one, Direction.L, one.l0 + 1, one.l1, wholeL, wholeA, wholeB, wholeW);
    final MaximizeResult maxA = maximize(one, Direction.A, one.a0 + 1, one.a1, wholeL, wholeA, wholeB, wholeW);
    final MaximizeResult maxB = maximize(one, Direction.B, one.b0 + 1, one.b1, wholeL, wholeA, wholeB, wholeW);

    // On ties the later direction wins (L, then A, then B).
    MaximizeResult best = maxL;
    Direction dir = Direction.L;
    if (maxA.maximum >= best.maximum) { best = maxA; dir = Direction.A; }
    if (maxB.maximum >= best.maximum) { best = maxB; dir = Direction.B; }

    if (best.cutLocation < 0) {
      return false;
    }

    two.l1 = one.l1;
    two.a1 = one.a1;
    two.b1 = one.b1;

    switch (dir) {
      case L:
        one.l1 = best.cutLocation;
        two.l0 = one.l1;
        two.a0 = one.a0;
        two.b0 = one.b0;
        break;
      case A:
        one.a1 = best.cutLocation;
        two.l0 = one.l0;
        two.a0 = one.a1;
        two.b0 = one.b0;
        break;
      case B:
        one.b1 = best.cutLocation;
        two.l0 = one.l0;
        two.a0 = one.a0;
        two.b0 = one.b1;
        break;
    }

    one.vol = (one.l1 - one.l0) * (one.a1 - one.a0) * (one.b1 - one.b0);
    two.vol = (two.l1 - two.l0) * (two.a1 - two.a0) * (two.b1 - two.b0);
    return true;
  }

  /**
   * Searches the cut position along {@code direction} that maximizes
   * {@code |lower|^2 / lowerWeight + |upper|^2 / upperWeight}, where {@code lower} and
   * {@code upper} are the first-order moment vectors of the two halves.
   *
   * @param cube      box to split
   * @param direction axis along which the cut is searched
   * @param first     first candidate position (inclusive)
   * @param last      end of the candidate positions (exclusive)
   * @param wholeL    L moment of the whole box
   * @param wholeA    a moment of the whole box
   * @param wholeB    b moment of the whole box
   * @param wholeW    weight of the whole box
   * @return the best position and its score; the position is {@code -1} when no candidate
   *         leaves both halves non-empty with a strictly positive score
   */
  MaximizeResult maximize(
      Box cube,
      Direction direction,
      int first,
      int last,
      double wholeL,
      double wholeA,
      double wholeB,
      int wholeW) {

    // Part of the lower half's integral that does not depend on the cut position.
    final double bottomL = bottom(cube, direction, momentsL);
    final double bottomA = bottom(cube, direction, momentsA);
    final double bottomB = bottom(cube, direction, momentsB);
    final int    bottomW = bottom(cube, direction, weights);

    double max = 0.0;
    int cut = -1;

    for (int i = first; i < last; i++) {
      final double halfL = bottomL + top(cube, direction, i, momentsL);
      final double halfA = bottomA + top(cube, direction, i, momentsA);
      final double halfB = bottomB + top(cube, direction, i, momentsB);
      final int    halfW = bottomW + top(cube, direction, i, weights);

      if (halfW == 0) {
        continue; // empty lower half: not a real split
      }

      final double rL = wholeL - halfL;
      final double rA = wholeA - halfA;
      final double rB = wholeB - halfB;
      final int    rW = wholeW - halfW;

      if (rW == 0) {
        continue; // empty upper half: not a real split
      }

      double t = (halfL * halfL + halfA * halfA + halfB * halfB) / (double) halfW;
      t += (rL * rL + rA * rA + rB * rB) / (double) rW;
      if (t > max) {
        max = t;
        cut = i;
      }
    }
    return new MaximizeResult(cut, max);
  }

  // ---- Integrals: volume/top/bottom helpers --------------------------------

  /** Sums the cumulative array {@code m} over box {@code c} by inclusion/exclusion. */
  static int volume(Box c, int[] m) {
    return  m[getIndex(c.l1, c.a1, c.b1)]
          - m[getIndex(c.l1, c.a1, c.b0)]
          - m[getIndex(c.l1, c.a0, c.b1)]
          + m[getIndex(c.l1, c.a0, c.b0)]
          - m[getIndex(c.l0, c.a1, c.b1)]
          + m[getIndex(c.l0, c.a1, c.b0)]
          + m[getIndex(c.l0, c.a0, c.b1)]
          - m[getIndex(c.l0, c.a0, c.b0)];
  }

  /** Sums the cumulative array {@code m} over box {@code c} by inclusion/exclusion. */
  static double volume(Box c, double[] m) {
    return  m[getIndex(c.l1, c.a1, c.b1)]
          - m[getIndex(c.l1, c.a1, c.b0)]
          - m[getIndex(c.l1, c.a0, c.b1)]
          + m[getIndex(c.l1, c.a0, c.b0)]
          - m[getIndex(c.l0, c.a1, c.b1)]
          + m[getIndex(c.l0, c.a1, c.b0)]
          + m[getIndex(c.l0, c.a0, c.b1)]
          - m[getIndex(c.l0, c.a0, c.b0)];
  }

  /**
   * Returns the inclusion/exclusion terms of {@link #volume(Box, int[])} taken at the lower
   * bound of {@code c} along direction {@code d}. Adding {@link #top(Box, Direction, int, int[])}
   * for a position gives the sum over the part of the box below that position.
   */
  static int bottom(Box c, Direction d, int[] m) {
    switch (d) {
      case L:
        return -m[getIndex(c.l0, c.a1, c.b1)] + m[getIndex(c.l0, c.a1, c.b0)]
             +  m[getIndex(c.l0, c.a0, c.b1)] - m[getIndex(c.l0, c.a0, c.b0)];
      case A:
        return -m[getIndex(c.l1, c.a0, c.b1)] + m[getIndex(c.l1, c.a0, c.b0)]
             +  m[getIndex(c.l0, c.a0, c.b1)] - m[getIndex(c.l0, c.a0, c.b0)];
      case B:
        return -m[getIndex(c.l1, c.a1, c.b0)] + m[getIndex(c.l1, c.a0, c.b0)]
             +  m[getIndex(c.l0, c.a1, c.b0)] - m[getIndex(c.l0, c.a0, c.b0)];
    }
    throw new IllegalArgumentException("unexpected direction " + d);
  }

  /** Same as {@link #bottom(Box, Direction, int[])} for a {@code double} array. */
  static double bottom(Box c, Direction d, double[] m) {
    switch (d) {
      case L:
        return -m[getIndex(c.l0, c.a1, c.b1)] + m[getIndex(c.l0, c.a1, c.b0)]
             +  m[getIndex(c.l0, c.a0, c.b1)] - m[getIndex(c.l0, c.a0, c.b0)];
      case A:
        return -m[getIndex(c.l1, c.a0, c.b1)] + m[getIndex(c.l1, c.a0, c.b0)]
             +  m[getIndex(c.l0, c.a0, c.b1)] - m[getIndex(c.l0, c.a0, c.b0)];
      case B:
        return -m[getIndex(c.l1, c.a1, c.b0)] + m[getIndex(c.l1, c.a0, c.b0)]
             +  m[getIndex(c.l0, c.a1, c.b0)] - m[getIndex(c.l0, c.a0, c.b0)];
    }
    throw new IllegalArgumentException("unexpected direction " + d);
  }

  /**
   * Returns the inclusion/exclusion terms of {@link #volume(Box, int[])} taken at position
   * {@code pos} along direction {@code d}, i.e. as if {@code pos} were the upper bound of
   * {@code c} on that axis.
   */
  static int top(Box c, Direction d, int pos, int[] m) {
    switch (d) {
      case L:
        return  m[getIndex(pos, c.a1, c.b1)] - m[getIndex(pos, c.a1, c.b0)]
              - m[getIndex(pos, c.a0, c.b1)] + m[getIndex(pos, c.a0, c.b0)];
      case A:
        return  m[getIndex(c.l1, pos, c.b1)] - m[getIndex(c.l1, pos, c.b0)]
              - m[getIndex(c.l0, pos, c.b1)] + m[getIndex(c.l0, pos, c.b0)];
      case B:
        return  m[getIndex(c.l1, c.a1, pos)] - m[getIndex(c.l1, c.a0, pos)]
              - m[getIndex(c.l0, c.a1, pos)] + m[getIndex(c.l0, c.a0, pos)];
    }
    throw new IllegalArgumentException("unexpected direction " + d);
  }

  /** Same as {@link #top(Box, Direction, int, int[])} for a {@code double} array. */
  static double top(Box c, Direction d, int pos, double[] m) {
    switch (d) {
      case L:
        return  m[getIndex(pos, c.a1, c.b1)] - m[getIndex(pos, c.a1, c.b0)]
              - m[getIndex(pos, c.a0, c.b1)] + m[getIndex(pos, c.a0, c.b0)];
      case A:
        return  m[getIndex(c.l1, pos, c.b1)] - m[getIndex(c.l1, pos, c.b0)]
              - m[getIndex(c.l0, pos, c.b1)] + m[getIndex(c.l0, pos, c.b0)];
      case B:
        return  m[getIndex(c.l1, c.a1, pos)] - m[getIndex(c.l1, c.a0, pos)]
              - m[getIndex(c.l0, c.a1, pos)] + m[getIndex(c.l0, c.a0, pos)];
    }
    throw new IllegalArgumentException("unexpected direction " + d);
  }

  /** Axis along which a box can be split. */
  private enum Direction { L, A, B }

  /** Outcome of {@link #maximize}: best cut position and its score. */
  private static final class MaximizeResult {
    final int cutLocation;   // < 0 when no cut is possible
    final double maximum;

    MaximizeResult(int cut, double max) {
      this.cutLocation = cut;
      this.maximum = max;
    }
  }

  /** Outcome of {@link #createBoxes(int)}. */
  private static final class CreateBoxesResult {
    final int requestedCount;
    final int resultCount;

    CreateBoxesResult(int requestedCount, int resultCount) {
      this.requestedCount = requestedCount;
      this.resultCount = resultCount;
    }
  }

  /**
   * Axis-aligned box in index space. Lower bounds ({@code l0, a0, b0}) are exclusive and
   * upper bounds ({@code l1, a1, b1}) are inclusive, as required by the cumulative sums.
   */
  private static final class Box {
    int l0 = 0, l1 = 0;
    int a0 = 0, a1 = 0;
    int b0 = 0, b1 = 0;
    int vol = 0;
  }

  // ---- OKLab -> index quantization (1..32) ---------------------------------

  /** Maps an OKLab lightness, clamped to [0, 1], to a histogram index in 1..32. */
  private static int toIndexL(float L) {
    return toIndex(L, L_MIN, L_MAX);
  }

  /** Maps an OKLab {@code a} value, clamped to [A_MIN, A_MAX], to a histogram index in 1..32. */
  private static int toIndexA(float a) {
    return toIndex(a, A_MIN, A_MAX);
  }

  /** Maps an OKLab {@code b} value, clamped to [B_MIN, B_MAX], to a histogram index in 1..32. */
  private static int toIndexB(float b) {
    return toIndex(b, B_MIN, B_MAX);
  }

  /**
   * Clamps {@code value} to {@code [min, max]}, rescales it to 0..INDEX_MAX and shifts the
   * result by one so that index 0 stays free for the cumulative sums.
   *
   * <p>The scale factor is INDEX_MAX (31), so index 32 is only reached by values at the
   * upper bound.
   */
  private static int toIndex(float value, float min, float max) {
    if (value < min) {
      value = min;
    } else if (value > max) {
      value = max;
    }
    final float norm = (value - min) / (max - min); // 0..1
    int i = (int) (norm * INDEX_MAX);
    if (i < 0) {
      i = 0;
    } else if (i > INDEX_MAX) {
      i = INDEX_MAX;
    }
    return i + 1;
  }

  // ---- Conversions sRGB <-> OKLab ------------------------------------------

  /**
   * Converts an 8-bit sRGB color to OKLab.
   *
   * @return a new array {@code {L, a, b}}
   */
  private static float[] srgb8ToOKLab(int r8, int g8, int b8) {
    // 8-bit -> [0,1]
    final double r = r8 / 255.0;
    final double g = g8 / 255.0;
    final double b = b8 / 255.0;

    // sRGB -> linear
    final double rl = srgbToLinear(r);
    final double gl = srgbToLinear(g);
    final double bl = srgbToLinear(b);

    // linear sRGB -> LMS (matrix recommended by Bjorn Ottosson)
    final double l = 0.4122214708 * rl + 0.5363325363 * gl + 0.0514459929 * bl;
    final double m = 0.2119034982 * rl + 0.6806995451 * gl + 0.1073969566 * bl;
    final double s = 0.0883024619 * rl + 0.2817188376 * gl + 0.6299787005 * bl;

    // cube root
    final double l_ = Math.cbrt(l);
    final double m_ = Math.cbrt(m);
    final double s_ = Math.cbrt(s);

    final double L = 0.2104542553 * l_ + 0.7936177850 * m_ - 0.0040720468 * s_;
    final double A = 1.9779984951 * l_ - 2.4285922050 * m_ + 0.4505937099 * s_;
    final double B = 0.0259040371 * l_ + 0.7827717662 * m_ - 0.8086757660 * s_;

    return new float[]{ (float) L, (float) A, (float) B };
  }

  /**
   * Converts an OKLab color to an opaque 8-bit sRGB color, clamping every channel to 0..255.
   *
   * @return the color packed as {@code 0xFFRRGGBB}
   */
  private static int oklabToSrgb8Clamp(float L, float A, float B) {
    // OKLab -> LMS^
    final double l_ = L + 0.3963377774 * A + 0.2158037573 * B;
    final double m_ = L - 0.1055613458 * A - 0.0638541728 * B;
    final double s_ = L - 0.0894841775 * A - 1.2914855480 * B;

    // ^3
    final double l = l_ * l_ * l_;
    final double m = m_ * m_ * m_;
    final double s = s_ * s_ * s_;

    // LMS -> linear sRGB
    final double rl =  4.0767416621 * l - 3.3077115913 * m + 0.2309699292 * s;
    final double gl = -1.2684380046 * l + 2.6097574011 * m - 0.3413193965 * s;
    final double bl = -0.0041960863 * l - 0.7034186147 * m + 1.7076147010 * s;

    // linear -> sRGB (bounded to 0..1)
    final double r = linearToSrgb(rl);
    final double g = linearToSrgb(gl);
    final double b = linearToSrgb(bl);

    final int r8 = clamp8((int) Math.round(r * 255.0));
    final int g8 = clamp8((int) Math.round(g * 255.0));
    final int b8 = clamp8((int) Math.round(b * 255.0));

    return (0xFF << 24) | (r8 << 16) | (g8 << 8) | b8;
  }

  /** Decodes an sRGB-encoded channel value in [0, 1] to linear light. */
  private static double srgbToLinear(double c) {
    return (c <= 0.04045) ? (c / 12.92) : Math.pow((c + 0.055) / 1.055, 2.4);
  }

  /** Encodes a linear-light channel value to sRGB; the result is bounded to [0, 1]. */
  private static double linearToSrgb(double x) {
    if (x <= 0.0) {
      return 0.0;
    }
    if (x >= 1.0) {
      return 1.0;
    }
    return (x <= 0.0031308) ? (12.92 * x) : (1.055 * Math.pow(x, 1.0 / 2.4) - 0.055);
  }

  /** Clamps {@code v} to the 0..255 range. */
  private static int clamp8(int v) {
    if (v < 0) {
      return 0;
    }
    if (v > 255) {
      return 255;
    }
    return v;
  }
}