Newer
Older
BouncyScrypt / src / main / java / helpers / ScryptHelper.java
package helpers;

import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Base64;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.bouncycastle.crypto.generators.SCrypt;

/**
 * Wrapper for Bouncy Castle's scrypt implementation.
 *
 * Generates salted scrypt hashes in a format similar to Modular Crypt Format
 * (MCF).
 *
 * @author Mark George <mark.george@otago.ac.nz>
 */
public final class ScryptHelper {

	// the standard work factors for scrypt
	private static final int N = 32768;
	private static final int r = 8;
	private static final int p = 1;

	private static final int saltSize = 64;
	private static final int dkLen = 64;

	private static final Charset utf8 = StandardCharsets.UTF_8;

	private ScryptHelper() {
	}

	/**
	 * Generates an MCF-like formatted salted scrypt hash for the password.
	 *
	 * @param password The password to hash.
	 * @return The MCF formatted hash.
	 */
	public static CharBuffer hash(CharSequence password) {
		byte[] salt = salt();
		byte[] hash = hash(password, salt, N, r, p, dkLen);

		Base64.Encoder b64encoder = Base64.getEncoder();

		int costParams = (log2(N)) << 16 | r << 8 | p;
		byte[] dollar = "$".getBytes(utf8);
		byte[] params = String.valueOf(costParams).getBytes(utf8);
		byte[] b64salt = b64encoder.encode(salt);
		byte[] b64hash = b64encoder.encode(hash);

		int size = (4 * dollar.length) + params.length + b64salt.length + b64hash.length;

		ByteBuffer mcf = ByteBuffer.allocate(size);
		mcf.put(dollar).put(params);
		mcf.put(dollar).put(b64salt);
		mcf.put(dollar).put(b64hash);
		mcf.put(dollar);
		mcf.flip();

		CharBuffer result = utf8.decode(mcf);

		Arrays.fill(hash, (byte) 0);
		Arrays.fill(salt, (byte) 0);
		Arrays.fill(b64hash, (byte) 0);
		Arrays.fill(b64salt, (byte) 0);
		Arrays.fill(mcf.array(), (byte) 0);

		return result;
	}

	/**
	 * Checks a hash (as generated by the hash method) against a password.
	 *
	 * @param mcfHash The MCF formatted hash.
	 * @param password The password.
	 *
	 * @return True if the password matches the hash, false if not.
	 */
	public static boolean check(CharSequence mcfHash, CharSequence password) {

		Pattern regex = Pattern.compile("\\$(\\d+?)\\$(.+?)\\$(.+?)\\$");

		Matcher matcher = regex.matcher(mcfHash);

		if (matcher.matches()) {
			int costParams = Integer.parseInt(matcher.group(1));

			Base64.Decoder decoder = Base64.getDecoder();

			byte[] salt = decoder.decode(matcher.group(2));
			byte[] hash = decoder.decode(matcher.group(3));

			int hN = 1 << (costParams >> 16);
			int hr = costParams >> 8 & 255;
			int hp = costParams & 255;

			byte[] cHash = hash(password, salt, hN, hr, hp, hash.length);

			for (int i = 0; i < cHash.length; i++) {
				if (cHash[i] != hash[i]) {
					return false;
				}
			}

			return true;

		} else {
			throw new IllegalArgumentException("Hash is not in a recognisable format!");
		}

	}

	static byte[] hash(CharSequence password, byte[] salt, int N, int r, int p, int dkLen) {
		return SCrypt.generate(toBytes(password), salt, N, r, p, dkLen);
	}

	private static byte[] salt() {
		try {
			byte[] salt = new byte[saltSize];
			SecureRandom.getInstance("SHA1PRNG").nextBytes(salt);
			return salt;
		} catch (NoSuchAlgorithmException ex) {
			throw new IllegalStateException("SHA1PRNG not supported on this JVM!", ex);
		}
	}

	private static int log2(int operand) {
		double log2 = Math.log(operand) / Math.log(2);
		if (log2 % 1 != 0) {
			throw new IllegalArgumentException("N must be a power of 2.");
		} else {
			return Math.toIntExact(Math.round(log2));
		}
	}

	/*
	 * Source: https://stackoverflow.com/a/9670279
	 */
	private static byte[] toBytes(CharSequence chars) {
		CharBuffer charBuffer = CharBuffer.wrap(chars);
		ByteBuffer byteBuffer = utf8.encode(charBuffer);
		byte[] bytes = Arrays.copyOfRange(byteBuffer.array(), byteBuffer.position(), byteBuffer.limit());
		Arrays.fill(byteBuffer.array(), (byte) 0);
		return bytes;
	}

}