// BouncyScrypt/src/main/java/helpers/ScryptHelper.java
package helpers;

import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
import java.nio.charset.CharsetEncoder;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.util.Arrays;
import org.bouncycastle.crypto.generators.SCrypt;
import org.bouncycastle.util.encoders.Base64;

/**
 * Wrapper for Bouncy Castle's scrypt implementation.
 *
 * Generates salted scrypt hashes in a format similar to Modular Crypt Format
 * (MCF): {@code $<params>$<salt>$<hash>$}, where {@code params} is the decimal
 * value of {@code log2(N) << 16 | r << 8 | p} and {@code salt} and
 * {@code hash} are Base64 encoded.
 *
 * Intermediate byte arrays that hold password or key material are overwritten
 * with zeros before the methods return, including when they fail.
 *
 * @license FreeBSD License (BSD-2-Clause) https://opensource.org/licenses/BSD-2-Clause
 *
 * @author Mark George <mark.george@otago.ac.nz>
 */
public final class ScryptHelper {

	// the standard work factors for scrypt
	private static final int DEFAULT_N = 32768;
	private static final int DEFAULT_R = 8;
	private static final int DEFAULT_P = 1;

	// sizes, in bytes, of the random salt and of the derived key
	private static final int SALT_LENGTH = 64;
	private static final int DERIVED_KEY_LENGTH = 64;

	// largest exponent for which 1 << exponent is still a positive int
	private static final int MAX_LOG2_N = 30;

	// separates (and terminates) the fields of the formatted hash
	private static final char DELIMITER = '$';

	private static final Charset UTF_8 = StandardCharsets.UTF_8;

	private static final SecureRandom RANDOM = new SecureRandom();

	// static utility class; not instantiable
	private ScryptHelper() {
	}

	/**
	 * Generates an MCF-like formatted salted scrypt hash for the password.
	 *
	 * @param password The password to hash.
	 * @return The MCF formatted hash.
	 * @throws NullPointerException If the password is null.
	 * @throws IllegalArgumentException If the password contains an unpaired
	 * surrogate and therefore cannot be encoded as UTF-8.
	 */
	public static CharBuffer hash(CharSequence password) {
		if (password == null) {
			throw new NullPointerException("password must not be null");
		}

		byte[] salt = salt();
		byte[] derived = null;
		byte[] b64salt = null;
		byte[] b64hash = null;
		ByteBuffer mcf = null;

		try {
			derived = hash(password, salt, DEFAULT_N, DEFAULT_R, DEFAULT_P, DERIVED_KEY_LENGTH);

			// pack the three work factors into a single integer
			int costParams = log2(DEFAULT_N) << 16 | DEFAULT_R << 8 | DEFAULT_P;
			byte[] params = String.valueOf(costParams).getBytes(UTF_8);
			b64salt = Base64.encode(salt);
			b64hash = Base64.encode(derived);

			// four delimiters: one before each field plus the trailing one
			int size = 4 + params.length + b64salt.length + b64hash.length;

			mcf = ByteBuffer.allocate(size);
			mcf.put((byte) DELIMITER).put(params);
			mcf.put((byte) DELIMITER).put(b64salt);
			mcf.put((byte) DELIMITER).put(b64hash);
			mcf.put((byte) DELIMITER);
			mcf.flip();

			return UTF_8.decode(mcf);
		} finally {
			// runs on failure too, so no key material is left behind
			wipe(derived);
			wipe(salt);
			wipe(b64hash);
			wipe(b64salt);
			if (mcf != null) {
				wipe(mcf.array());
			}
		}
	}

	/**
	 * Checks a hash (as generated by the hash method) against a password.
	 *
	 * The work factors are read from the hash itself, so the cost of this call
	 * is determined by the hash that is passed in. Anything following the
	 * fourth delimiter is ignored.
	 *
	 * @param mcfHash The MCF formatted hash.
	 * @param password The password.
	 *
	 * @return True if the password matches the hash, false if not.
	 * @throws IllegalArgumentException If the hash is not in the expected
	 * format, carries unusable cost parameters, or if the password cannot be
	 * encoded as UTF-8.
	 * @throws NullPointerException If the password is null.
	 */
	public static boolean check(CharSequence mcfHash, CharSequence password) {

		if (mcfHash == null || mcfHash.length() == 0 || mcfHash.charAt(0) != DELIMITER) {
			throw new IllegalArgumentException("Hash is not in a recognisable format");
		}

		// positions of the first four delimiters
		int[] indices = new int[4];
		int count = 0;

		for (int i = 0; i < mcfHash.length() && count < 4; i++) {
			if (mcfHash.charAt(i) == DELIMITER) {
				indices[count++] = i;
			}
		}

		if (count != 4) {
			throw new IllegalArgumentException("Invalid format: missing delimiters");
		}

		if (password == null) {
			throw new NullPointerException("password must not be null");
		}

		int costParams = parseAsInt(mcfHash.subSequence(indices[0] + 1, indices[1]));

		int log2N = costParams >>> 16;
		int hr = (costParams >> 8) & 255;
		int hp = costParams & 255;

		// Java masks shift distances to five bits, so an unchecked exponent
		// above 31 would silently turn into a much smaller N
		if (log2N < 1 || log2N > MAX_LOG2_N || hr < 1 || hp < 1) {
			throw new IllegalArgumentException("Invalid format: unsupported cost parameters");
		}

		int hN = 1 << log2N;

		byte[] saltB64 = null;
		byte[] hashB64 = null;
		byte[] salt = null;
		byte[] expected = null;
		byte[] actual = null;

		try {
			saltB64 = toBytes(mcfHash.subSequence(indices[1] + 1, indices[2]));
			hashB64 = toBytes(mcfHash.subSequence(indices[2] + 1, indices[3]));

			salt = Base64.decode(saltB64);
			expected = Base64.decode(hashB64);

			if (expected.length == 0) {
				throw new IllegalArgumentException("Invalid format: empty hash");
			}

			actual = hash(password, salt, hN, hr, hp, expected.length);

			// constant-time comparison
			return MessageDigest.isEqual(actual, expected);

		} finally {
			wipe(saltB64);
			wipe(hashB64);
			wipe(salt);
			wipe(expected);
			wipe(actual);
		}
	}

	/**
	 * Parses a CharSequence as a non-negative decimal integer without
	 * creating a String.
	 *
	 * @param s The characters to parse; ASCII digits only.
	 * @return The parsed value.
	 * @throws NumberFormatException If the sequence is empty, contains a
	 * character other than 0-9, or does not fit in an int.
	 */
	private static int parseAsInt(CharSequence s) {
		if (s.length() == 0) {
			throw new NumberFormatException("Empty number");
		}

		int val = 0;
		for (int i = 0; i < s.length(); i++) {
			char c = s.charAt(i);
			if (c < '0' || c > '9') {
				throw new NumberFormatException("Invalid digit: " + c);
			}
			int digit = c - '0';
			// equivalent to val * 10 + digit > Integer.MAX_VALUE, without overflowing
			if (val > (Integer.MAX_VALUE - digit) / 10) {
				throw new NumberFormatException("Number out of range");
			}
			val = val * 10 + digit;
		}
		return val;
	}

	/**
	 * Derives a key from the password with scrypt.
	 *
	 * @param password The password; encoded as UTF-8 before hashing.
	 * @param salt The salt.
	 * @param N The CPU/memory cost parameter.
	 * @param r The block size parameter.
	 * @param p The parallelisation parameter.
	 * @param dkLen The length in bytes of the key to derive.
	 * @return The derived key.
	 */
	static byte[] hash(CharSequence password, byte[] salt, int N, int r, int p, int dkLen) {
		byte[] pwd = toBytes(password);
		try {
			return SCrypt.generate(pwd, salt, N, r, p, dkLen);
		} finally {
			wipe(pwd);
		}
	}

	/**
	 * Creates a new random salt.
	 */
	private static byte[] salt() {
		byte[] salt = new byte[SALT_LENGTH];
		RANDOM.nextBytes(salt);
		return salt;
	}

	/**
	 * Returns the base 2 logarithm of a power of two.
	 *
	 * @throws IllegalArgumentException If n is not a positive power of two.
	 */
	private static int log2(int n) {
		if (n <= 0 || (n & (n - 1)) != 0) {
			throw new IllegalArgumentException("N must be a power of 2");
		}
		return Integer.numberOfTrailingZeros(n);
	}

	/**
	 * Encodes characters as UTF-8 without creating a String.
	 *
	 * @param chars The characters to encode.
	 * @return A newly allocated array holding exactly the encoded bytes.
	 * @throws IllegalArgumentException If the input cannot be encoded, e.g.
	 * because it contains an unpaired surrogate.
	 */
	private static byte[] toBytes(CharSequence chars) {
		CharBuffer input = CharBuffer.wrap(chars);
		CharsetEncoder encoder = UTF_8.newEncoder();

		// worst-case output size for the input, as reported by the encoder
		int capacity = (int) Math.ceil(input.remaining() * (double) encoder.maxBytesPerChar());
		ByteBuffer scratch = ByteBuffer.allocate(capacity);

		try {
			// The encoder stops at the first character it cannot encode; an
			// ignored result would silently hash only a prefix of the input.
			if (!encoder.encode(input, scratch, true).isUnderflow()
					|| !encoder.flush(scratch).isUnderflow()) {
				// the offending text is deliberately left out of the message
				throw new IllegalArgumentException("Input cannot be encoded as UTF-8");
			}

			scratch.flip();
			byte[] result = new byte[scratch.remaining()];
			scratch.get(result);
			return result;
		} finally {
			wipe(scratch.array());
		}
	}

	/**
	 * Overwrites the array with zeros; does nothing for null.
	 */
	private static void wipe(byte[] data) {
		if (data != null) {
			Arrays.fill(data, (byte) 0);
		}
	}

}