diff --git a/pgpainless-core/src/main/java/org/pgpainless/decryption_verification/OpenPgpMessageInputStream.java b/pgpainless-core/src/main/java/org/pgpainless/decryption_verification/OpenPgpMessageInputStream.java index a4b05bbe..54172e7f 100644 --- a/pgpainless-core/src/main/java/org/pgpainless/decryption_verification/OpenPgpMessageInputStream.java +++ b/pgpainless-core/src/main/java/org/pgpainless/decryption_verification/OpenPgpMessageInputStream.java @@ -272,13 +272,15 @@ public class OpenPgpMessageInputStream extends DecryptionStream { PGPPBEEncryptedData skesk = (PGPPBEEncryptedData) esk; InputStream decrypted = skesk.getDataStream(decryptorFactory); encryptedData.sessionKey = sessionKey; - nestedInputStream = new OpenPgpMessageInputStream(buffer(decrypted), options, encryptedData, policy); + IntegrityProtectedInputStream integrityProtected = new IntegrityProtectedInputStream(decrypted, skesk, options); + nestedInputStream = new OpenPgpMessageInputStream(buffer(integrityProtected), options, encryptedData, policy); return true; } else if (esk instanceof PGPPublicKeyEncryptedData) { PGPPublicKeyEncryptedData pkesk = (PGPPublicKeyEncryptedData) esk; InputStream decrypted = pkesk.getDataStream(decryptorFactory); encryptedData.sessionKey = sessionKey; - nestedInputStream = new OpenPgpMessageInputStream(buffer(decrypted), options, encryptedData, policy); + IntegrityProtectedInputStream integrityProtected = new IntegrityProtectedInputStream(decrypted, pkesk, options); + nestedInputStream = new OpenPgpMessageInputStream(buffer(integrityProtected), options, encryptedData, policy); return true; } else { throw new RuntimeException("Unknown ESK class type: " + esk.getClass().getName()); @@ -302,7 +304,8 @@ public class OpenPgpMessageInputStream extends DecryptionStream { throwIfUnacceptable(sessionKey.getAlgorithm()); MessageMetadata.EncryptedData encryptedData = new MessageMetadata.EncryptedData(sessionKey.getAlgorithm()); encryptedData.sessionKey = sessionKey; - nestedInputStream = new OpenPgpMessageInputStream(buffer(decrypted), options, encryptedData, policy); + IntegrityProtectedInputStream integrityProtected = new IntegrityProtectedInputStream(decrypted, skesk, options); + nestedInputStream = new OpenPgpMessageInputStream(buffer(integrityProtected), options, encryptedData, policy); return true; } catch (UnacceptableAlgorithmException e) { throw e; @@ -334,7 +337,8 @@ public class OpenPgpMessageInputStream extends DecryptionStream { SymmetricKeyAlgorithm.requireFromId(pkesk.getSymmetricAlgorithm(decryptorFactory))); encryptedData.sessionKey = sessionKey; - nestedInputStream = new OpenPgpMessageInputStream(buffer(decrypted), options, encryptedData, policy); + IntegrityProtectedInputStream integrityProtected = new IntegrityProtectedInputStream(decrypted, pkesk, options); + nestedInputStream = new OpenPgpMessageInputStream(buffer(integrityProtected), options, encryptedData, policy); return true; } catch (UnacceptableAlgorithmException e) { throw e; @@ -359,7 +363,9 @@ public class OpenPgpMessageInputStream extends DecryptionStream { MessageMetadata.EncryptedData encryptedData = new MessageMetadata.EncryptedData( SymmetricKeyAlgorithm.requireFromId(pkesk.getSymmetricAlgorithm(decryptorFactory))); encryptedData.sessionKey = sessionKey; - nestedInputStream = new OpenPgpMessageInputStream(buffer(decrypted), options, encryptedData, policy); + + IntegrityProtectedInputStream integrityProtected = new IntegrityProtectedInputStream(decrypted, pkesk, options); + nestedInputStream = new OpenPgpMessageInputStream(buffer(integrityProtected), options, encryptedData, policy); return true; } catch (PGPException e) { // hm :/ @@ -491,6 +497,7 @@ public class OpenPgpMessageInputStream extends DecryptionStream { automaton.next(InputAlphabet.EndOfSequence); automaton.assertValid(); + packetInputStream.close(); closed = true; } diff --git a/pgpainless-core/src/main/java/org/pgpainless/decryption_verification/TeeBCPGInputStream.java b/pgpainless-core/src/main/java/org/pgpainless/decryption_verification/TeeBCPGInputStream.java index f80793f0..2efcfc43 100644 --- a/pgpainless-core/src/main/java/org/pgpainless/decryption_verification/TeeBCPGInputStream.java +++ b/pgpainless-core/src/main/java/org/pgpainless/decryption_verification/TeeBCPGInputStream.java @@ -96,6 +96,12 @@ public class TeeBCPGInputStream { return markerPacket; } + + public void close() throws IOException { + this.packetInputStream.close(); + this.delayedTee.close(); + } + public static class DelayedTeeInputStreamInputStream extends InputStream { private int last = -1; @@ -112,8 +118,12 @@ public class TeeBCPGInputStream { if (last != -1) { outputStream.write(last); } - last = inputStream.read(); - return last; + try { + last = inputStream.read(); + return last; + } catch (IOException e) { + return -1; + } } /** @@ -127,5 +137,11 @@ public class TeeBCPGInputStream { } last = -1; } + + @Override + public void close() throws IOException { + inputStream.close(); + outputStream.close(); + } } } diff --git a/pgpainless-core/src/test/java/org/pgpainless/decryption_verification/ModificationDetectionTests.java b/pgpainless-core/src/test/java/org/pgpainless/decryption_verification/ModificationDetectionTests.java index 14a041ff..9ecaa38a 100644 --- a/pgpainless-core/src/test/java/org/pgpainless/decryption_verification/ModificationDetectionTests.java +++ b/pgpainless-core/src/test/java/org/pgpainless/decryption_verification/ModificationDetectionTests.java @@ -238,8 +238,10 @@ public class ModificationDetectionTests { ); ByteArrayOutputStream out = new ByteArrayOutputStream(); - Streams.pipeAll(decryptionStream, out); - assertThrows(ModificationDetectionException.class, decryptionStream::close); + assertThrows(ModificationDetectionException.class, () -> { + Streams.pipeAll(decryptionStream, out); + decryptionStream.close(); + }); } @TestTemplate @@ -269,8 +271,10 @@ public class ModificationDetectionTests { ); ByteArrayOutputStream out = new ByteArrayOutputStream(); - Streams.pipeAll(decryptionStream, out); - assertThrows(ModificationDetectionException.class, decryptionStream::close); + assertThrows(ModificationDetectionException.class, () -> { + Streams.pipeAll(decryptionStream, out); + decryptionStream.close(); + }); } @TestTemplate @@ -313,8 +317,10 @@ public class ModificationDetectionTests { ); ByteArrayOutputStream out = new ByteArrayOutputStream(); - Streams.pipeAll(decryptionStream, out); - assertThrows(ModificationDetectionException.class, decryptionStream::close); + assertThrows(ModificationDetectionException.class, () -> { + Streams.pipeAll(decryptionStream, out); + decryptionStream.close(); + }); } @TestTemplate @@ -344,8 +350,10 @@ public class ModificationDetectionTests { ); ByteArrayOutputStream out = new ByteArrayOutputStream(); - Streams.pipeAll(decryptionStream, out); - assertThrows(ModificationDetectionException.class, decryptionStream::close); + assertThrows(ModificationDetectionException.class, () -> { + Streams.pipeAll(decryptionStream, out); + decryptionStream.close(); + }); } @TestTemplate