1
0
Fork 0
mirror of https://github.com/pgpainless/pgpainless.git synced 2024-12-23 03:17:58 +01:00

Reinstate integrity-protection and fix tests

Integrity Protection is now checked when reading from the stream,
not only when closing.
This commit is contained in:
Paul Schaub 2022-10-16 19:12:17 +02:00
parent 654493dfcc
commit fbcde13df3
3 changed files with 46 additions and 15 deletions

View file

@ -272,13 +272,15 @@ public class OpenPgpMessageInputStream extends DecryptionStream {
PGPPBEEncryptedData skesk = (PGPPBEEncryptedData) esk; PGPPBEEncryptedData skesk = (PGPPBEEncryptedData) esk;
InputStream decrypted = skesk.getDataStream(decryptorFactory); InputStream decrypted = skesk.getDataStream(decryptorFactory);
encryptedData.sessionKey = sessionKey; encryptedData.sessionKey = sessionKey;
nestedInputStream = new OpenPgpMessageInputStream(buffer(decrypted), options, encryptedData, policy); IntegrityProtectedInputStream integrityProtected = new IntegrityProtectedInputStream(decrypted, skesk, options);
nestedInputStream = new OpenPgpMessageInputStream(buffer(integrityProtected), options, encryptedData, policy);
return true; return true;
} else if (esk instanceof PGPPublicKeyEncryptedData) { } else if (esk instanceof PGPPublicKeyEncryptedData) {
PGPPublicKeyEncryptedData pkesk = (PGPPublicKeyEncryptedData) esk; PGPPublicKeyEncryptedData pkesk = (PGPPublicKeyEncryptedData) esk;
InputStream decrypted = pkesk.getDataStream(decryptorFactory); InputStream decrypted = pkesk.getDataStream(decryptorFactory);
encryptedData.sessionKey = sessionKey; encryptedData.sessionKey = sessionKey;
nestedInputStream = new OpenPgpMessageInputStream(buffer(decrypted), options, encryptedData, policy); IntegrityProtectedInputStream integrityProtected = new IntegrityProtectedInputStream(decrypted, pkesk, options);
nestedInputStream = new OpenPgpMessageInputStream(buffer(integrityProtected), options, encryptedData, policy);
return true; return true;
} else { } else {
throw new RuntimeException("Unknown ESK class type: " + esk.getClass().getName()); throw new RuntimeException("Unknown ESK class type: " + esk.getClass().getName());
@ -302,7 +304,8 @@ public class OpenPgpMessageInputStream extends DecryptionStream {
throwIfUnacceptable(sessionKey.getAlgorithm()); throwIfUnacceptable(sessionKey.getAlgorithm());
MessageMetadata.EncryptedData encryptedData = new MessageMetadata.EncryptedData(sessionKey.getAlgorithm()); MessageMetadata.EncryptedData encryptedData = new MessageMetadata.EncryptedData(sessionKey.getAlgorithm());
encryptedData.sessionKey = sessionKey; encryptedData.sessionKey = sessionKey;
nestedInputStream = new OpenPgpMessageInputStream(buffer(decrypted), options, encryptedData, policy); IntegrityProtectedInputStream integrityProtected = new IntegrityProtectedInputStream(decrypted, skesk, options);
nestedInputStream = new OpenPgpMessageInputStream(buffer(integrityProtected), options, encryptedData, policy);
return true; return true;
} catch (UnacceptableAlgorithmException e) { } catch (UnacceptableAlgorithmException e) {
throw e; throw e;
@ -334,7 +337,8 @@ public class OpenPgpMessageInputStream extends DecryptionStream {
SymmetricKeyAlgorithm.requireFromId(pkesk.getSymmetricAlgorithm(decryptorFactory))); SymmetricKeyAlgorithm.requireFromId(pkesk.getSymmetricAlgorithm(decryptorFactory)));
encryptedData.sessionKey = sessionKey; encryptedData.sessionKey = sessionKey;
nestedInputStream = new OpenPgpMessageInputStream(buffer(decrypted), options, encryptedData, policy); IntegrityProtectedInputStream integrityProtected = new IntegrityProtectedInputStream(decrypted, pkesk, options);
nestedInputStream = new OpenPgpMessageInputStream(buffer(integrityProtected), options, encryptedData, policy);
return true; return true;
} catch (UnacceptableAlgorithmException e) { } catch (UnacceptableAlgorithmException e) {
throw e; throw e;
@ -359,7 +363,9 @@ public class OpenPgpMessageInputStream extends DecryptionStream {
MessageMetadata.EncryptedData encryptedData = new MessageMetadata.EncryptedData( MessageMetadata.EncryptedData encryptedData = new MessageMetadata.EncryptedData(
SymmetricKeyAlgorithm.requireFromId(pkesk.getSymmetricAlgorithm(decryptorFactory))); SymmetricKeyAlgorithm.requireFromId(pkesk.getSymmetricAlgorithm(decryptorFactory)));
encryptedData.sessionKey = sessionKey; encryptedData.sessionKey = sessionKey;
nestedInputStream = new OpenPgpMessageInputStream(buffer(decrypted), options, encryptedData, policy);
IntegrityProtectedInputStream integrityProtected = new IntegrityProtectedInputStream(decrypted, pkesk, options);
nestedInputStream = new OpenPgpMessageInputStream(buffer(integrityProtected), options, encryptedData, policy);
return true; return true;
} catch (PGPException e) { } catch (PGPException e) {
// hm :/ // hm :/
@ -491,6 +497,7 @@ public class OpenPgpMessageInputStream extends DecryptionStream {
automaton.next(InputAlphabet.EndOfSequence); automaton.next(InputAlphabet.EndOfSequence);
automaton.assertValid(); automaton.assertValid();
packetInputStream.close();
closed = true; closed = true;
} }

View file

@ -96,6 +96,12 @@ public class TeeBCPGInputStream {
return markerPacket; return markerPacket;
} }
public void close() throws IOException {
this.packetInputStream.close();
this.delayedTee.close();
}
public static class DelayedTeeInputStreamInputStream extends InputStream { public static class DelayedTeeInputStreamInputStream extends InputStream {
private int last = -1; private int last = -1;
@ -112,8 +118,12 @@ public class TeeBCPGInputStream {
if (last != -1) { if (last != -1) {
outputStream.write(last); outputStream.write(last);
} }
last = inputStream.read(); try {
return last; last = inputStream.read();
return last;
} catch (IOException e) {
return -1;
}
} }
/** /**
@ -127,5 +137,11 @@ public class TeeBCPGInputStream {
} }
last = -1; last = -1;
} }
@Override
public void close() throws IOException {
inputStream.close();
outputStream.close();
}
} }
} }

View file

@ -238,8 +238,10 @@ public class ModificationDetectionTests {
); );
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
Streams.pipeAll(decryptionStream, out); assertThrows(ModificationDetectionException.class, () -> {
assertThrows(ModificationDetectionException.class, decryptionStream::close); Streams.pipeAll(decryptionStream, out);
decryptionStream.close();
});
} }
@TestTemplate @TestTemplate
@ -269,8 +271,10 @@ public class ModificationDetectionTests {
); );
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
Streams.pipeAll(decryptionStream, out); assertThrows(ModificationDetectionException.class, () -> {
assertThrows(ModificationDetectionException.class, decryptionStream::close); Streams.pipeAll(decryptionStream, out);
decryptionStream.close();
});
} }
@TestTemplate @TestTemplate
@ -313,8 +317,10 @@ public class ModificationDetectionTests {
); );
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
Streams.pipeAll(decryptionStream, out); assertThrows(ModificationDetectionException.class, () -> {
assertThrows(ModificationDetectionException.class, decryptionStream::close); Streams.pipeAll(decryptionStream, out);
decryptionStream.close();
});
} }
@TestTemplate @TestTemplate
@ -344,8 +350,10 @@ public class ModificationDetectionTests {
); );
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
Streams.pipeAll(decryptionStream, out); assertThrows(ModificationDetectionException.class, () -> {
assertThrows(ModificationDetectionException.class, decryptionStream::close); Streams.pipeAll(decryptionStream, out);
decryptionStream.close();
});
} }
@TestTemplate @TestTemplate