1
0
Fork 0
mirror of https://github.com/vanitasvitae/Smack.git synced 2024-06-16 08:34:50 +02:00
Smack/smack-core/src/main/java/org/jivesoftware/smack/sasl/core/SCRAMSHA1Mechanism.java
Florian Schmaus 403ecff2b2 Add SCRAM-SHA1 support
Thanks to Stefan Karlsson for helping with the implementation.

Also add SASLMechanism.checkIfSuccessfulOrThrow(), to increase the
security by verifying the mechanisms state at the end of SASL
authentication.

SASLMechanism now has a SASLPrep StringTransformer.

Refactor SHA1 functions out of StringUtils into SHA1 utility class.

Add MAC utility class.

Make DummyConnection getSentpacket() methods use generics to make unit
testing SCRAM-SHA1 easier.

Fixes SMACK-398
2014-10-21 15:03:48 +02:00

347 lines
12 KiB
Java

/**
*
* Copyright 2014 Florian Schmaus
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jivesoftware.smack.sasl.core;
import java.security.InvalidKeyException;
import java.security.SecureRandom;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import javax.security.auth.callback.CallbackHandler;
import org.jivesoftware.smack.SmackException;
import org.jivesoftware.smack.sasl.SASLMechanism;
import org.jivesoftware.smack.util.ByteUtils;
import org.jivesoftware.smack.util.MAC;
import org.jivesoftware.smack.util.SHA1;
import org.jivesoftware.smack.util.stringencoder.Base64;
import org.jxmpp.util.cache.Cache;
import org.jxmpp.util.cache.LruCache;
public class SCRAMSHA1Mechanism extends SASLMechanism {
public static final String NAME = "SCRAM-SHA-1";
private static final int RANDOM_ASCII_BYTE_COUNT = 32;
private static final String DEFAULT_GS2_HEADER = "n,,";
private static final byte[] CLIENT_KEY_BYTES = toBytes("Client Key");
private static final byte[] SERVER_KEY_BYTES = toBytes("Server Key");
private static final byte[] ONE = new byte[] { 0, 0, 0, 1 };
private static final SecureRandom RANDOM = new SecureRandom();
private static final Cache<String, Keys> CACHE = new LruCache<String, Keys>(10);
private enum State {
INITIAL,
AUTH_TEXT_SENT,
RESPONSE_SENT,
VALID_SERVER_RESPONSE,
}
/**
* The state of the this instance of SASL SCRAM-SHA1 authentication.
*/
private State state = State.INITIAL;
/**
* The client's random ASCII which is used as nonce
*/
private String clientRandomAscii;
private String clientFirstMessageBare;
private byte[] serverSignature;
@Override
protected void authenticateInternal(CallbackHandler cbh) throws SmackException {
throw new UnsupportedOperationException("CallbackHandler not (yet) supported");
}
@Override
protected byte[] getAuthenticationText() throws SmackException {
clientRandomAscii = getRandomAscii();
String saslPrepedAuthcId = saslPrep(authenticationId);
clientFirstMessageBare = "n=" + escape(saslPrepedAuthcId) + ",r=" + clientRandomAscii;
String clientFirstMessage = DEFAULT_GS2_HEADER + clientFirstMessageBare;
state = State.AUTH_TEXT_SENT;
return toBytes(clientFirstMessage);
}
@Override
public String getName() {
return NAME;
}
@Override
public int getPriority() {
return 110;
}
@Override
public SCRAMSHA1Mechanism newInstance() {
return new SCRAMSHA1Mechanism();
}
@Override
public void checkIfSuccessfulOrThrow() throws SmackException {
if (state != State.VALID_SERVER_RESPONSE) {
throw new SmackException("SCRAM-SHA1 is missing valid server response");
}
}
@Override
protected byte[] evaluateChallenge(byte[] challenge) throws SmackException {
final String challengeString = new String(challenge);
switch (state) {
case AUTH_TEXT_SENT:
final String serverFirstMessage = challengeString;
Map<Character, String> attributes = parseAttributes(challengeString);
// Handle server random ASCII (nonce)
String rvalue = attributes.get('r');
if (rvalue == null) {
throw new SmackException("Server random ASCII is null");
}
if (rvalue.length() <= clientRandomAscii.length()) {
throw new SmackException("Server random ASCII is shorter then client random ASCII");
}
String receivedClientRandomAscii = rvalue.substring(0, clientRandomAscii.length());
if (!receivedClientRandomAscii.equals(clientRandomAscii)) {
throw new SmackException("Received client random ASCII does not match client random ASCII");
}
// Handle iterations
int iterations;
String iterationsString = attributes.get('i');
if (iterationsString == null) {
throw new SmackException("Iterations attribute not set");
}
try {
iterations = Integer.parseInt(iterationsString);
}
catch (NumberFormatException e) {
throw new SmackException("Exception parsing iterations", e);
}
// Handle salt
String salt = attributes.get('s');
if (salt == null) {
throw new SmackException("SALT not send");
}
// Parsing and error checking is done, we can now begin to calculate the values
// First the client-final-message-without-proof
String clientFinalMessageWithoutProof = "c=" + Base64.encode(DEFAULT_GS2_HEADER) + ",r=" + rvalue;
// AuthMessage := client-first-message-bare + "," + server-first-message + "," +
// client-final-message-without-proof
byte[] authMessage = toBytes(clientFirstMessageBare + ',' + serverFirstMessage + ','
+ clientFinalMessageWithoutProof);
// RFC 5802 § 5.1 "Note that a client implementation MAY cache ClientKey&ServerKey … for later reauthentication …
// as it is likely that the server is going to advertise the same salt value upon reauthentication."
final String cacheKey = password + ',' + salt;
byte[] serverKey, clientKey;
Keys keys = CACHE.get(cacheKey);
if (keys == null) {
// SaltedPassword := Hi(Normalize(password), salt, i)
byte[] saltedPassword = hi(saslPrep(password), Base64.decode(salt), iterations);
// ServerKey := HMAC(SaltedPassword, "Server Key")
serverKey = hmac(saltedPassword, SERVER_KEY_BYTES);
// ServerSignature := HMAC(ServerKey, AuthMessage)
serverSignature = hmac(serverKey, authMessage);
// ClientKey := HMAC(SaltedPassword, "Client Key")
clientKey = hmac(saltedPassword, CLIENT_KEY_BYTES);
keys = new Keys(clientKey, serverKey);
CACHE.put(cacheKey, keys);
}
else {
serverKey = keys.serverKey;
clientKey = keys.clientKey;
}
// StoredKey := H(ClientKey)
byte[] storedKey = SHA1.bytes(clientKey);
// ClientSignature := HMAC(StoredKey, AuthMessage)
byte[] clientSignature = hmac(storedKey, authMessage);
// ClientProof := ClientKey XOR ClientSignature
byte[] clientProof = new byte[clientKey.length];
for (int i = 0; i < clientProof.length; i++) {
clientProof[i] = (byte) (clientKey[i] ^ clientSignature[i]);
}
String clientFinalMessage = clientFinalMessageWithoutProof + ",p=" + Base64.encodeToString(clientProof);
state = State.RESPONSE_SENT;
return toBytes(clientFinalMessage);
case RESPONSE_SENT:
String clientCalculatedServerFinalMessage = "v=" + Base64.encodeToString(serverSignature);
if (!clientCalculatedServerFinalMessage.equals(challengeString)) {
throw new SmackException("Server final message does not match calculated one");
}
state = State.VALID_SERVER_RESPONSE;
break;
default:
throw new SmackException("Invalid state");
}
return null;
}
private static Map<Character, String> parseAttributes(String string) throws SmackException {
if (string.length() == 0) {
return Collections.emptyMap();
}
String[] keyValuePairs = string.split(",");
Map<Character, String> res = new HashMap<Character, String>(keyValuePairs.length, 1);
for (String keyValuePair : keyValuePairs) {
if (keyValuePair.length() < 3) {
throw new SmackException("Invalid Key-Value pair: " + keyValuePair);
}
char key = keyValuePair.charAt(0);
if (keyValuePair.charAt(1) != '=') {
throw new SmackException("Invalid Key-Value pair: " + keyValuePair);
}
String value = keyValuePair.substring(2);
res.put(key, value);
}
return res;
}
/**
* Generate random ASCII.
* <p>
* This method is non-static and package-private for unit testing purposes.
* </p>
* @return A String of 32 random printable ASCII characters.
*/
String getRandomAscii() {
int count = 0;
char[] randomAscii = new char[RANDOM_ASCII_BYTE_COUNT];
while (count < RANDOM_ASCII_BYTE_COUNT) {
int r = RANDOM.nextInt(128);
char c = (char) r;
// RFC 5802 § 5.1 specifies 'r:' to exclude the ',' character and to be only printable ASCII characters
if (!isPrintableNonCommaAsciiChar(c)) {
continue;
}
randomAscii[count++] = c;
}
return new String(randomAscii);
}
private static boolean isPrintableNonCommaAsciiChar(char c) {
if (c == ',') {
return false;
}
return c >= 32 && c < 127;
}
/**
* Escapes usernames or passwords for SASL SCRAM-SHA1.
* <p>
* According to RFC 5802 § 5.1 'n:'
* "The characters ',' or '=' in usernames are sent as '=2C' and '=3D' respectively."
* </p>
*
* @param string
* @return the escaped string
*/
private static String escape(String string) {
StringBuilder sb = new StringBuilder((int) (string.length() * 1.1));
for (int i = 0; i < string.length(); i++) {
char c = string.charAt(i);
switch (c) {
case ',':
sb.append("=2C");
break;
case '=':
sb.append("=3D");
break;
default:
sb.append(c);
break;
}
}
return sb.toString();
}
/**
* RFC 5802 § 2.2 HMAC(key, str)
*
* @param key
* @param str
* @return
* @throws SmackException
*/
private static byte[] hmac(byte[] key, byte[] str) throws SmackException {
try {
return MAC.hmacsha1(key, str);
}
catch (InvalidKeyException e) {
throw new SmackException(NAME + " HMAC-SHA1 Exception", e);
}
}
/**
* RFC 5802 § 2.2 Hi(str, salt, i)
* <p>
* Hi() is, essentially, PBKDF2 [RFC2898] with HMAC() as the pseudorandom function
* (PRF) and with dkLen == output length of HMAC() == output length of H().
* </p>
*
* @param str
* @param salt
* @param iterations
* @return
* @throws SmackException
*/
private static byte[] hi(String str, byte[] salt, int iterations) throws SmackException {
byte[] key = str.getBytes();
// U1 := HMAC(str, salt + INT(1))
byte[] u = hmac(key, ByteUtils.concact(salt, ONE));
byte[] res = u.clone();
for (int i = 1; i < iterations; i++) {
u = hmac(key, u);
for (int j = 0; j < u.length; j++) {
res[j] ^= u[j];
}
}
return res;
}
private static class Keys {
private final byte[] clientKey;
private final byte[] serverKey;
public Keys(byte[] clientKey, byte[] serverKey) {
this.clientKey = clientKey;
this.serverKey = serverKey;
}
}
}