Newer
Older
BouncyScrypt / src / main / java / helpers / ScryptHelper.java
package helpers;

import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.Charset;
import java.nio.charset.CharsetEncoder;
import java.nio.charset.CoderResult;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Base64;
import org.bouncycastle.crypto.generators.SCrypt;

/**
 * Wrapper for Bouncy Castle's scrypt implementation.
 *
 * Generates salted scrypt hashes in a format similar to Modular Crypt Format
 * (MCF):
 *
 * <pre>
 * $params$base64(salt)$base64(hash)$
 * </pre>
 *
 * where {@code params} is the decimal value of
 * {@code log2(N) << 16 | r << 8 | p}. Because r and p are each packed into 8
 * bits, both must be at most 255 to round-trip through this format.
 *
 * @author Mark George <mark.george@otago.ac.nz>
 */
public final class ScryptHelper {

	/** scrypt CPU/memory cost parameter N (must be a power of 2). */
	private static final int COST_N = 32768;
	/** scrypt block size parameter r. */
	private static final int BLOCK_SIZE_R = 8;
	/** scrypt parallelisation parameter p. */
	private static final int PARALLELISM_P = 1;

	/** Salt length in bytes. */
	private static final int SALT_SIZE = 64;
	/** Derived key (hash) length in bytes. */
	private static final int KEY_LENGTH = 64;

	private static final Charset UTF8 = StandardCharsets.UTF_8;

	/**
	 * Shared RNG for salts. SecureRandom instances are safe for concurrent use.
	 * The default instance is used rather than getInstanceStrong(), which on some
	 * platforms is backed by a blocking entropy source and can stall hash(); a
	 * salt only needs to be unique and unpredictable.
	 */
	private static final SecureRandom RANDOM = new SecureRandom();

	private ScryptHelper() {
	}

	/**
	 * Generates an MCF-like formatted salted scrypt hash for the password.
	 *
	 * @param password The password to hash.
	 * @return The MCF formatted hash, {@code $params$salt$hash$}.
	 * @throws IllegalArgumentException if the password is not well-formed UTF-16
	 *         (for example, it contains an unpaired surrogate).
	 */
	public static CharBuffer hash(CharSequence password) {
		byte[] salt = salt();
		byte[] hash = hash(password, salt, COST_N, BLOCK_SIZE_R, PARALLELISM_P, KEY_LENGTH);

		Base64.Encoder b64encoder = Base64.getEncoder();

		// r and p occupy 8 bits each; see the class comment.
		int costParams = log2(COST_N) << 16 | BLOCK_SIZE_R << 8 | PARALLELISM_P;
		byte[] dollar = "$".getBytes(UTF8);
		byte[] params = String.valueOf(costParams).getBytes(UTF8);
		byte[] b64salt = b64encoder.encode(salt);
		byte[] b64hash = b64encoder.encode(hash);

		int size = (4 * dollar.length) + params.length + b64salt.length + b64hash.length;

		ByteBuffer mcf = ByteBuffer.allocate(size);
		mcf.put(dollar).put(params);
		mcf.put(dollar).put(b64salt);
		mcf.put(dollar).put(b64hash);
		mcf.put(dollar);
		mcf.flip();

		CharBuffer result = UTF8.decode(mcf);

		// Scrub intermediate copies of the salt/hash; only the returned buffer remains.
		wipe(hash);
		wipe(salt);
		wipe(b64hash);
		wipe(b64salt);
		wipe(mcf.array());

		return result;
	}

	/**
	 * Checks a hash (as generated by the hash method) against a password.
	 *
	 * @param mcfHash The MCF formatted hash.
	 * @param password The password.
	 *
	 * @return True if the password matches the hash, false if not.
	 * @throws IllegalArgumentException if the hash is malformed (missing
	 *         delimiters, invalid or out-of-range cost parameters, invalid Base64,
	 *         empty hash segment) or the password is not well-formed UTF-16.
	 */
	public static boolean check(CharSequence mcfHash, CharSequence password) {

		// expected format: $params$salt$hash$
		if (mcfHash == null || mcfHash.length() == 0 || mcfHash.charAt(0) != '$') {
			throw new IllegalArgumentException("Hash is not in a recognisable format");
		}

		int[] indices = new int[4];
		int count = 0;

		// find positions of the first four delimiters
		for (int i = 0; i < mcfHash.length() && count < 4; i++) {
			if (mcfHash.charAt(i) == '$') {
				indices[count++] = i;
			}
		}

		if (count != 4) {
			throw new IllegalArgumentException("Invalid format: missing delimiters");
		}

		int costParams = parseAsInt(mcfHash.subSequence(indices[0] + 1, indices[1]));

		// parseAsInt guarantees a non-negative value, so this is the raw exponent.
		int log2N = costParams >>> 16;
		// Java masks shift distances to 5 bits: without this check an exponent of
		// 33 would silently become N = 2, and 31 would give a negative N.
		if (log2N < 1 || log2N > 30) {
			throw new IllegalArgumentException("Invalid cost parameter N");
		}
		int hN = 1 << log2N;
		int hr = (costParams >> 8) & 0xFF;
		int hp = costParams & 0xFF;

		Base64.Decoder decoder = Base64.getDecoder();

		byte[] salt = null;
		byte[] hash = null;
		byte[] cHash = null;
		try {
			salt = decoder.decode(toBytes(mcfHash.subSequence(indices[1] + 1, indices[2])));
			hash = decoder.decode(toBytes(mcfHash.subSequence(indices[2] + 1, indices[3])));

			if (hash.length == 0) {
				throw new IllegalArgumentException("Invalid format: empty hash");
			}

			// Derive a key of the same length as the stored hash so any dkLen round-trips.
			cHash = hash(password, salt, hN, hr, hp, hash.length);

			// constant-time comparison
			return MessageDigest.isEqual(cHash, hash);

		} finally {
			wipe(salt);
			wipe(hash);
			wipe(cHash);
		}
	}

	/**
	 * Parses a CharSequence as a non-negative decimal integer without creating a
	 * String.
	 *
	 * @throws NumberFormatException if the sequence is empty, contains a
	 *         non-digit, or the value does not fit in an int.
	 */
	private static int parseAsInt(CharSequence s) {
		if (s.length() == 0) {
			throw new NumberFormatException("Empty cost parameter");
		}
		int val = 0;
		for (int i = 0; i < s.length(); i++) {
			int digit = Character.digit(s.charAt(i), 10);
			if (digit < 0) {
				throw new NumberFormatException("Invalid digit: " + s.charAt(i));
			}
			// val * 10 + digit must not exceed Integer.MAX_VALUE
			if (val > (Integer.MAX_VALUE - digit) / 10) {
				throw new NumberFormatException("Cost parameter out of range");
			}
			val = val * 10 + digit;
		}
		return val;
	}

	/**
	 * Runs scrypt over the UTF-8 encoding of the password. The temporary password
	 * bytes are zeroed once Bouncy Castle has consumed them.
	 */
	static byte[] hash(CharSequence password, byte[] salt, int N, int r, int p, int dkLen) {
		byte[] passwordBytes = toBytes(password);
		try {
			return SCrypt.generate(passwordBytes, salt, N, r, p, dkLen);
		} finally {
			wipe(passwordBytes);
		}
	}

	/** Returns a fresh random salt of SALT_SIZE bytes. */
	private static byte[] salt() {
		byte[] salt = new byte[SALT_SIZE];
		RANDOM.nextBytes(salt);
		return salt;
	}

	/** Returns log2(n) for a positive power of two n. */
	private static int log2(int n) {
		if (n <= 0 || (n & (n - 1)) != 0) {
			throw new IllegalArgumentException("N must be a power of 2");
		}
		return Integer.numberOfTrailingZeros(n);
	}

	/**
	 * Encodes chars as UTF-8 without going through an intermediate String.
	 *
	 * Adapted from https://stackoverflow.com/a/9670279, with the encoder result
	 * checked: the default encoder stops at malformed input (e.g. an unpaired
	 * surrogate), which would otherwise silently truncate the output.
	 *
	 * @throws IllegalArgumentException if chars is not well-formed UTF-16.
	 */
	private static byte[] toBytes(CharSequence chars) {
		CharBuffer charBuffer = CharBuffer.wrap(chars);
		// UTF-8 needs at most 3 bytes per UTF-16 char (a surrogate pair of two
		// chars encodes to 4 bytes), so this buffer cannot overflow.
		ByteBuffer byteBuffer = ByteBuffer.allocate(charBuffer.remaining() * 3);
		try {
			CharsetEncoder encoder = UTF8.newEncoder();
			CoderResult cr = encoder.encode(charBuffer, byteBuffer, true);
			if (!cr.isUnderflow()) {
				cr.throwException();
			}
			cr = encoder.flush(byteBuffer);
			if (!cr.isUnderflow()) {
				cr.throwException();
			}

			byteBuffer.flip();
			byte[] result = new byte[byteBuffer.remaining()];
			byteBuffer.get(result);
			return result;
		} catch (CharacterCodingException ex) {
			throw new IllegalArgumentException("Input is not well-formed UTF-16 (unpaired surrogate?)", ex);
		} finally {
			wipe(byteBuffer.array());
		}
	}

	/** Zeroes a byte array if it is non-null. */
	private static void wipe(byte[] bytes) {
		if (bytes != null) {
			Arrays.fill(bytes, (byte) 0);
		}
	}

}