// Ed25519.java
/* ========================================================================
* PlantUML : a free UML diagram generator
* ========================================================================
*
* (C) Copyright 2009-2024, Arnaud Roques
*
* Project Info: https://plantuml.com
*
* If you like this project or if you find it useful, you can support us at:
*
* https://plantuml.com/patreon (only 1$ per month!)
* https://plantuml.com/paypal
*
* This file is part of PlantUML.
*
* PlantUML is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* PlantUML distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
* or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
* License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301,
* USA.
*
*
* Original Author: Arnaud Roques
* With assistance from ChatGPT (OpenAI)
*
*/
package net.sourceforge.plantuml.licensing;
import java.math.BigInteger;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.Arrays;
/**
 * Minimal Ed25519 signature scheme (RFC 8032, pure EdDSA over edwards25519) in pure Java.
 *
 * <p>All values are handled in their wire form: a private key is the 32-byte seed, a public key
 * is the 32-byte compressed point, and a signature is the 64-byte concatenation {@code R || S}.
 *
 * <p>The arithmetic is done with {@link BigInteger} on affine coordinates. It is simple to audit
 * but it is neither fast nor constant-time: timing depends on the scalar bits, and secret
 * material ends up in immutable {@code BigInteger} objects that cannot be wiped. Do not sign with
 * this class in a context where an attacker can observe timing.
 *
 * <p>The class holds no mutable state; every method is a pure function of its arguments (apart
 * from the randomness drawn in {@link #generateKeyPair(SecureRandom)}).
 */
public final class Ed25519 {

    private static final int KEY_LENGTH = 32;
    private static final int SIGNATURE_LENGTH = 64;

    /** Field prime p = 2^255 - 19. */
    private static final BigInteger P = BigInteger.ONE.shiftLeft(255).subtract(BigInteger.valueOf(19));

    /** Order of the subgroup generated by the base point: L = 2^252 + 0x14def9dea2f79cd65812631a5cf5d3ed. */
    private static final BigInteger L = BigInteger.ONE.shiftLeft(252)
            .add(new BigInteger("14def9dea2f79cd65812631a5cf5d3ed", 16));

    /** Curve constant d = -121665/121666 mod p (twisted Edwards curve with a = -1). */
    private static final BigInteger D = BigInteger.valueOf(-121665)
            .multiply(BigInteger.valueOf(121666).modInverse(P)).mod(P);

    /** A square root of -1 mod p, computed as 2^((p-1)/4) mod p. */
    private static final BigInteger SQRT_M1 =
            BigInteger.valueOf(2).modPow(P.subtract(BigInteger.ONE).shiftRight(2), P);

    /** Base point B (x, y): the standard Ed25519 generator. */
    private static final Point B = new Point(
            new BigInteger("15112221349535400772501151409588531511454012693041857206046113283949847762202"),
            new BigInteger("46316835694926478169428394003475163141307993866256225615783033603165251855960"));

    /** Neutral element (0, 1) of the curve group. */
    private static final Point ID = new Point(BigInteger.ZERO, BigInteger.ONE);

    private Ed25519() {
        // static utility class, not instantiable
    }

    // --- key API ---

    /**
     * Generates a fresh key pair.
     *
     * @param rnd source of randomness used to draw the 32-byte seed
     * @return a pair holding the private seed (32 bytes) and the matching public key (32 bytes)
     * @throws NullPointerException if {@code rnd} is {@code null}
     */
    public static KeyPair generateKeyPair(SecureRandom rnd) {
        if (rnd == null) {
            throw new NullPointerException("rnd must not be null");
        }
        byte[] seed = new byte[KEY_LENGTH];
        rnd.nextBytes(seed);
        return new KeyPair(seed, publicKeyFromSeed(seed));
    }

    /**
     * Derives the public key from a private seed.
     *
     * @param seed32 the 32-byte private seed
     * @return the 32-byte encoded public key
     * @throws IllegalArgumentException if {@code seed32} is {@code null} or not 32 bytes long
     */
    public static byte[] publicKeyFromSeed(byte[] seed32) {
        requireSeed(seed32);
        byte[] h = sha512(seed32);
        BigInteger a = leToInt(clamp(Arrays.copyOfRange(h, 0, 32)));
        return encodePoint(scalarMulBase(a));
    }

    // --- signature ---

    /**
     * Signs a message with a private seed. The public key that goes into the signature hash is
     * the one derived from the seed.
     *
     * @param seed32  the 32-byte private seed
     * @param message the message to sign; {@code null} is treated as the empty message
     * @return the 64-byte signature {@code R || S}
     * @throws IllegalArgumentException if {@code seed32} is {@code null} or not 32 bytes long
     */
    public static byte[] sign(byte[] seed32, byte[] message) {
        requireSeed(seed32);
        byte[] msg = message == null ? new byte[0] : message;

        // h = SHA-512(seed): low half -> clamped secret scalar a, high half -> nonce prefix
        byte[] h = sha512(seed32);
        BigInteger a = leToInt(clamp(Arrays.copyOfRange(h, 0, 32)));
        byte[] prefix = Arrays.copyOfRange(h, 32, 64);
        byte[] aEnc = encodePoint(scalarMulBase(a));

        // r = SHA-512(prefix || M) mod L, R = [r]B
        BigInteger r = leToInt(sha512(prefix, msg)).mod(L);
        byte[] rEnc = encodePoint(scalarMulBase(r));

        // k = SHA-512(R || A || M) mod L, S = (r + k*a) mod L
        BigInteger k = leToInt(sha512(rEnc, aEnc, msg)).mod(L);
        BigInteger s = r.add(k.multiply(a)).mod(L);

        byte[] sig = new byte[SIGNATURE_LENGTH];
        System.arraycopy(rEnc, 0, sig, 0, 32);
        System.arraycopy(intToLe(s, 32), 0, sig, 32, 32);
        return sig;
    }

    /**
     * Verifies a signature.
     *
     * <p>Malformed input never throws: wrong lengths, {@code S >= L}, and point encodings that
     * are off-curve or non-canonical all yield {@code false}.
     *
     * @param publicKey32 the 32-byte encoded public key
     * @param message     the signed message; {@code null} is treated as the empty message
     * @param signature64 the 64-byte signature {@code R || S}
     * @return {@code true} if and only if {@code [S]B == R + [k]A}
     */
    public static boolean verify(byte[] publicKey32, byte[] message, byte[] signature64) {
        if (publicKey32 == null || publicKey32.length != KEY_LENGTH) {
            return false;
        }
        if (signature64 == null || signature64.length != SIGNATURE_LENGTH) {
            return false;
        }
        byte[] msg = message == null ? new byte[0] : message;
        try {
            byte[] rEnc = Arrays.copyOfRange(signature64, 0, 32);
            // leToInt never returns a negative value, so only the upper bound needs checking
            BigInteger s = leToInt(Arrays.copyOfRange(signature64, 32, 64));
            if (s.compareTo(L) >= 0) {
                return false;
            }
            Point a = decodePoint(publicKey32);
            Point r = decodePoint(rEnc);
            BigInteger k = leToInt(sha512(rEnc, publicKey32, msg)).mod(L);
            return scalarMulBase(s).equals(add(r, scalarMul(a, k)));
        } catch (IllegalArgumentException malformedPoint) {
            // decodePoint rejected the public key or R
            return false;
        } catch (ArithmeticException notInvertible) {
            // a denominator had no inverse mod p; treat as an invalid signature
            return false;
        }
    }

    // --- structures & utilities ---

    /** Immutable holder for a private seed and its public key; arrays are copied in and out. */
    public static final class KeyPair {
        private final byte[] seed32; // private (seed)
        private final byte[] public32; // public key (encoded)

        /**
         * @param seed32   the 32-byte private seed
         * @param public32 the 32-byte encoded public key
         * @throws IllegalArgumentException if either array is {@code null} or not 32 bytes long
         */
        public KeyPair(byte[] seed32, byte[] public32) {
            // reject wrong sizes instead of silently padding/truncating into a corrupt key
            if (seed32 == null || seed32.length != 32) {
                throw new IllegalArgumentException("seed must be 32 bytes");
            }
            if (public32 == null || public32.length != 32) {
                throw new IllegalArgumentException("public key must be 32 bytes");
            }
            this.seed32 = seed32.clone();
            this.public32 = public32.clone();
        }

        /** @return a copy of the 32-byte private seed */
        public byte[] getPrivateSeed() {
            return seed32.clone();
        }

        /** @return a copy of the 32-byte encoded public key */
        public byte[] getPublicKey() {
            return public32.clone();
        }
    }

    /** Validates the length of a private seed. */
    private static void requireSeed(byte[] seed32) {
        if (seed32 == null || seed32.length != KEY_LENGTH) {
            throw new IllegalArgumentException("seed must be 32 bytes");
        }
    }

    /** Returns SHA-512 of the concatenation of {@code parts}. */
    private static byte[] sha512(byte[]... parts) {
        MessageDigest md;
        try {
            md = MessageDigest.getInstance("SHA-512");
        } catch (java.security.NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
        for (byte[] part : parts) {
            md.update(part);
        }
        return md.digest();
    }

    /** Clamps a 32-byte little-endian scalar in place (clear bits 0-2 and 255, set bit 254) and returns it. */
    private static byte[] clamp(byte[] a) {
        a[0] &= (byte) 0xF8;
        a[31] &= (byte) 0x7F;
        a[31] |= (byte) 0x40;
        return a;
    }

    /**
     * Encodes a non-negative integer as exactly {@code len} little-endian bytes.
     *
     * @throws IllegalArgumentException if {@code x} is negative or needs more than {@code len} bytes
     */
    private static byte[] intToLe(BigInteger x, int len) {
        if (x.signum() < 0 || x.bitLength() > len * 8) {
            throw new IllegalArgumentException("value does not fit in " + len + " bytes");
        }
        byte[] be = x.toByteArray(); // big-endian, possibly with one leading 0x00 sign byte
        byte[] out = new byte[len];
        int n = Math.min(be.length, len); // skips the sign byte when be.length == len + 1
        for (int i = 0; i < n; i++) {
            out[i] = be[be.length - 1 - i];
        }
        return out;
    }

    /** Decodes little-endian bytes as a non-negative integer. */
    private static BigInteger leToInt(byte[] le) {
        byte[] be = new byte[le.length];
        for (int i = 0; i < le.length; i++) {
            be[i] = le[le.length - 1 - i];
        }
        return new BigInteger(1, be);
    }

    /** Multiplicative inverse mod p. */
    private static BigInteger inv(BigInteger x) {
        return x.modInverse(P);
    }

    /** Point on the curve in affine coordinates, both reduced mod p. */
    private static final class Point {
        final BigInteger x;
        final BigInteger y;

        Point(BigInteger x, BigInteger y) {
            this.x = x.mod(P);
            this.y = y.mod(P);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof Point)) {
                return false;
            }
            Point other = (Point) o;
            return x.equals(other.x) && y.equals(other.y);
        }

        @Override
        public int hashCode() {
            return 31 * x.hashCode() + y.hashCode();
        }
    }

    /**
     * Twisted Edwards addition for a = -1:
     * x3 = (x1*y2 + y1*x2) / (1 + d*x1*x2*y1*y2), y3 = (y1*y2 + x1*x2) / (1 - d*x1*x2*y1*y2).
     */
    private static Point add(Point p1, Point p2) {
        BigInteger x1x2 = p1.x.multiply(p2.x).mod(P);
        BigInteger y1y2 = p1.y.multiply(p2.y).mod(P);
        BigInteger x1y2 = p1.x.multiply(p2.y).mod(P);
        BigInteger y1x2 = p1.y.multiply(p2.x).mod(P);
        BigInteger t = D.multiply(x1x2).mod(P).multiply(y1y2).mod(P);
        BigInteger denX = BigInteger.ONE.add(t);
        BigInteger denY = BigInteger.ONE.subtract(t);
        // One inversion serves both coordinates: 1/denX = denY/(denX*denY), 1/denY = denX/(denX*denY).
        BigInteger invBoth = inv(denX.multiply(denY).mod(P));
        BigInteger x3 = x1y2.add(y1x2).multiply(denY).mod(P).multiply(invBoth).mod(P);
        BigInteger y3 = y1y2.add(x1x2).multiply(denX).mod(P).multiply(invBoth).mod(P);
        return new Point(x3, y3);
    }

    /** Point doubling, expressed through the general addition law. */
    private static Point dbl(Point point) {
        return add(point, point);
    }

    /** Computes [k]point by left-to-right double-and-add. Running time depends on the bits of k. */
    private static Point scalarMul(Point point, BigInteger k) {
        Point acc = ID;
        for (int i = k.bitLength() - 1; i >= 0; i--) {
            acc = dbl(acc);
            if (k.testBit(i)) {
                acc = add(acc, point);
            }
        }
        return acc;
    }

    /** Computes [k]B for the base point B. */
    private static Point scalarMulBase(BigInteger k) {
        return scalarMul(B, k);
    }

    /** Encodes a point as 32 bytes: y in little-endian (255 bits) with the parity of x in bit 255. */
    private static byte[] encodePoint(Point point) {
        byte[] enc = intToLe(point.y, 32);
        if (point.x.testBit(0)) {
            enc[31] |= (byte) 0x80;
        } else {
            enc[31] &= (byte) 0x7F;
        }
        return enc;
    }

    /**
     * Decodes a 32-byte point encoding (y little-endian, parity of x in bit 255).
     *
     * @throws IllegalArgumentException if the input is not 32 bytes, if y is not reduced mod p,
     *         if no x exists for this y, or if x is 0 while the sign bit is set
     */
    private static Point decodePoint(byte[] enc) {
        if (enc == null || enc.length != 32) {
            throw new IllegalArgumentException("point encoding must be 32 bytes");
        }
        byte[] yLe = Arrays.copyOf(enc, 32);
        boolean xIsOdd = (yLe[31] & 0x80) != 0;
        yLe[31] &= 0x7F; // strip the sign bit
        BigInteger y = leToInt(yLe);
        // Only the canonical representative of y is accepted; otherwise two different byte
        // strings would decode to the same point.
        if (y.compareTo(P) >= 0) {
            throw new IllegalArgumentException("invalid point encoding");
        }

        // x^2 = (y^2 - 1) / (d*y^2 + 1)
        BigInteger y2 = y.multiply(y).mod(P);
        BigInteger u = y2.subtract(BigInteger.ONE).mod(P);
        BigInteger v = D.multiply(y2).add(BigInteger.ONE).mod(P);
        BigInteger x2 = u.multiply(inv(v)).mod(P);

        // square root mod p (p = 5 mod 8): candidate x2^((p+3)/8), fixed up by sqrt(-1) if needed
        BigInteger x = x2.modPow(P.add(BigInteger.valueOf(3)).shiftRight(3), P);
        if (!x.multiply(x).mod(P).equals(x2)) {
            x = x.multiply(SQRT_M1).mod(P);
        }
        if (!x.multiply(x).mod(P).equals(x2)) {
            throw new IllegalArgumentException("invalid point encoding");
        }
        // x = 0 has no odd representative, so a set sign bit is a non-canonical encoding
        if (x.signum() == 0 && xIsOdd) {
            throw new IllegalArgumentException("invalid point encoding");
        }
        if (x.testBit(0) != xIsOdd) {
            x = P.subtract(x);
        }
        return new Point(x, y);
    }
}