diff --git a/src/Geralt.Tests/XChaCha20Tests.cs b/src/Geralt.Tests/XChaCha20Tests.cs index 19653b1..d936869 100644 --- a/src/Geralt.Tests/XChaCha20Tests.cs +++ b/src/Geralt.Tests/XChaCha20Tests.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Security.Cryptography; using Microsoft.VisualStudio.TestTools.UnitTesting; namespace Geralt.Tests; @@ -36,7 +37,7 @@ public static IEnumerable DraftXChaChaEncryptTestVectors() "5468652064686f6c65202870726f6e6f756e6365642022646f6c65222920697320616c736f206b6e6f776e2061732074686520417369617469632077696c6420646f672c2072656420646f672c20616e642077686973746c696e6720646f672e2049742069732061626f7574207468652073697a65206f662061204765726d616e20736865706865726420627574206c6f6f6b73206d6f7265206c696b652061206c6f6e672d6c656767656420666f782e205468697320686967686c7920656c757369766520616e6420736b696c6c6564206a756d70657220697320636c6173736966696564207769746820776f6c7665732c20636f796f7465732c206a61636b616c732c20616e6420666f78657320696e20746865207461786f6e6f6d69632066616d696c792043616e696461652e", "404142434445464748494a4b4c4d4e4f5051525354555658", "808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f", - (uint)0 + (ulong)0 }; yield return new object[] { @@ -44,18 +45,19 @@ public static IEnumerable DraftXChaChaEncryptTestVectors() "5468652064686f6c65202870726f6e6f756e6365642022646f6c65222920697320616c736f206b6e6f776e2061732074686520417369617469632077696c6420646f672c2072656420646f672c20616e642077686973746c696e6720646f672e2049742069732061626f7574207468652073697a65206f662061204765726d616e20736865706865726420627574206c6f6f6b73206d6f7265206c696b652061206c6f6e672d6c656767656420666f782e205468697320686967686c7920656c757369766520616e6420736b696c6c6564206a756d70657220697320636c6173736966696564207769746820776f6c7665732c20636f796f7465732c206a61636b616c732c20616e6420666f78657320696e20746865207461786f6e6f6d69632066616d696c792043616e696461652e", "404142434445464748494a4b4c4d4e4f5051525354555658", "808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f", - (uint)1 + (ulong)1 }; } public static IEnumerable EncryptInvalidParameterSizes() { - yield return new object[] { XChaCha20.BlockSize, XChaCha20.BlockSize + 1, XChaCha20.NonceSize, XChaCha20.KeySize, (uint)0 }; - yield return new object[] { XChaCha20.BlockSize, XChaCha20.BlockSize - 1, XChaCha20.NonceSize, XChaCha20.KeySize, (uint)0 }; - yield return new object[] { XChaCha20.BlockSize, XChaCha20.BlockSize, XChaCha20.NonceSize + 1, XChaCha20.KeySize, (uint)0 }; - yield return new object[] { XChaCha20.BlockSize, XChaCha20.BlockSize, XChaCha20.NonceSize - 1, XChaCha20.KeySize, (uint)0 }; - yield return new object[] { XChaCha20.BlockSize, XChaCha20.BlockSize, XChaCha20.NonceSize, XChaCha20.KeySize + 1, (uint)0 }; - yield return new object[] { XChaCha20.BlockSize, XChaCha20.BlockSize, XChaCha20.NonceSize, XChaCha20.KeySize - 1, (uint)0 }; + yield return new object[] { XChaCha20.BlockSize, XChaCha20.BlockSize + 1, XChaCha20.NonceSize, XChaCha20.KeySize, (ulong)0 }; + yield return new object[] { XChaCha20.BlockSize, XChaCha20.BlockSize - 1, XChaCha20.NonceSize, XChaCha20.KeySize, (ulong)0 }; + yield return new object[] { XChaCha20.BlockSize, XChaCha20.BlockSize, XChaCha20.NonceSize + 1, XChaCha20.KeySize, (ulong)0 }; + yield return new object[] { XChaCha20.BlockSize, XChaCha20.BlockSize, XChaCha20.NonceSize - 1, XChaCha20.KeySize, (ulong)0 }; + yield return new object[] { XChaCha20.BlockSize, XChaCha20.BlockSize, XChaCha20.NonceSize, XChaCha20.KeySize + 1, (ulong)0 }; + yield return new object[] { XChaCha20.BlockSize, XChaCha20.BlockSize, XChaCha20.NonceSize, XChaCha20.KeySize - 1, (ulong)0 }; + yield return new object[] { XChaCha20.BlockSize, XChaCha20.BlockSize, XChaCha20.NonceSize, XChaCha20.KeySize, ulong.MaxValue }; } [TestMethod] @@ -92,7 +94,7 @@ public void Fill_Invalid(int bufferSize, int nonceSize, int keySize) [TestMethod] [DynamicData(nameof(DraftXChaChaEncryptTestVectors), DynamicDataSourceType.Method)] - public void Encrypt_Valid(string ciphertext, string plaintext, string nonce, string key, uint counter) + public void Encrypt_Valid(string ciphertext, string plaintext, string nonce, string key, ulong counter) { Span c = stackalloc byte[ciphertext.Length / 2]; Span p = Convert.FromHexString(plaintext); @@ -106,19 +108,24 @@ public void Encrypt_Valid(string ciphertext, string plaintext, string nonce, str [TestMethod] [DynamicData(nameof(EncryptInvalidParameterSizes), DynamicDataSourceType.Method)] - public void Encrypt_Invalid(int ciphertextSize, int plaintextSize, int nonceSize, int keySize, uint counter) + public void Encrypt_Invalid(int ciphertextSize, int plaintextSize, int nonceSize, int keySize, ulong counter) { var c = new byte[ciphertextSize]; var p = new byte[plaintextSize]; var n = new byte[nonceSize]; var k = new byte[keySize]; - Assert.ThrowsException(() => XChaCha20.Encrypt(c, p, n, k, counter)); + if (counter < ulong.MaxValue) { + Assert.ThrowsException(() => XChaCha20.Encrypt(c, p, n, k, counter)); + } + else { + Assert.ThrowsException(() => XChaCha20.Encrypt(c, p, n, k, counter)); + } } [TestMethod] [DynamicData(nameof(DraftXChaChaEncryptTestVectors), DynamicDataSourceType.Method)] - public void Decrypt_Valid(string ciphertext, string plaintext, string nonce, string key, uint counter) + public void Decrypt_Valid(string ciphertext, string plaintext, string nonce, string key, ulong counter) { Span p = stackalloc byte[plaintext.Length / 2]; Span c = Convert.FromHexString(ciphertext); @@ -132,13 +139,18 @@ public void Decrypt_Valid(string ciphertext, string plaintext, string nonce, str [TestMethod] [DynamicData(nameof(EncryptInvalidParameterSizes), DynamicDataSourceType.Method)] - public void Decrypt_Invalid(int ciphertextSize, int plaintextSize, int nonceSize, int keySize, uint counter) + public void Decrypt_Invalid(int ciphertextSize, int plaintextSize, int nonceSize, int keySize, ulong counter) { var p = new byte[plaintextSize]; var c = new byte[ciphertextSize]; var n = new byte[nonceSize]; var k = new byte[keySize]; - Assert.ThrowsException(() => XChaCha20.Decrypt(p, c, n, k, counter)); + if (counter < ulong.MaxValue) { + Assert.ThrowsException(() => XChaCha20.Decrypt(p, c, n, k, counter)); + } + else { + Assert.ThrowsException(() => XChaCha20.Decrypt(p, c, n, k, counter)); + } } } diff --git a/src/Geralt/Crypto/ChaCha20.cs b/src/Geralt/Crypto/ChaCha20.cs index d053526..8d87cfa 100644 --- a/src/Geralt/Crypto/ChaCha20.cs +++ b/src/Geralt/Crypto/ChaCha20.cs @@ -27,7 +27,7 @@ public static unsafe void Encrypt(Span ciphertext, ReadOnlySpan plai Validation.EqualToSize(nameof(ciphertext), ciphertext.Length, plaintext.Length); Validation.EqualToSize(nameof(nonce), nonce.Length, NonceSize); Validation.EqualToSize(nameof(key), key.Length, KeySize); - CounterOverflow(plaintext.Length, counter); + ThrowIfCounterOverflow(plaintext.Length, counter); Sodium.Initialize(); fixed (byte* c = ciphertext, p = plaintext, n = nonce, k = key) { @@ -41,7 +41,7 @@ public static unsafe void Decrypt(Span plaintext, ReadOnlySpan ciphe Validation.EqualToSize(nameof(plaintext), plaintext.Length, ciphertext.Length); Validation.EqualToSize(nameof(nonce), nonce.Length, NonceSize); Validation.EqualToSize(nameof(key), key.Length, KeySize); - CounterOverflow(ciphertext.Length, counter); + ThrowIfCounterOverflow(ciphertext.Length, counter); Sodium.Initialize(); fixed (byte* p = plaintext, c = ciphertext, n = nonce, k = key) { @@ -50,7 +50,7 @@ public static unsafe void Decrypt(Span plaintext, ReadOnlySpan ciphe } } - private static void CounterOverflow(int messageSize, uint counter) + private static void ThrowIfCounterOverflow(int messageSize, uint counter) { long blockCount = (-1L + messageSize + BlockSize) / BlockSize; if (counter + blockCount > uint.MaxValue) diff --git a/src/Geralt/Crypto/XChaCha20.cs b/src/Geralt/Crypto/XChaCha20.cs index 71725f0..bbd4b03 100644 --- a/src/Geralt/Crypto/XChaCha20.cs +++ b/src/Geralt/Crypto/XChaCha20.cs @@ -27,6 +27,7 @@ public static unsafe void Encrypt(Span ciphertext, ReadOnlySpan plai Validation.EqualToSize(nameof(ciphertext), ciphertext.Length, plaintext.Length); Validation.EqualToSize(nameof(nonce), nonce.Length, NonceSize); Validation.EqualToSize(nameof(key), key.Length, KeySize); + ThrowIfCounterOverflow(plaintext.Length, counter); Sodium.Initialize(); fixed (byte* c = ciphertext, p = plaintext, n = nonce, k = key) { @@ -40,6 +41,7 @@ public static unsafe void Decrypt(Span plaintext, ReadOnlySpan ciphe Validation.EqualToSize(nameof(plaintext), plaintext.Length, ciphertext.Length); Validation.EqualToSize(nameof(nonce), nonce.Length, NonceSize); Validation.EqualToSize(nameof(key), key.Length, KeySize); + ThrowIfCounterOverflow(ciphertext.Length, counter); Sodium.Initialize(); fixed (byte* p = plaintext, c = ciphertext, n = nonce, k = key) { @@ -47,4 +49,11 @@ public static unsafe void Decrypt(Span plaintext, ReadOnlySpan ciphe if (ret != 0) { throw new CryptographicException("Error decrypting ciphertext."); } } } + + private static void ThrowIfCounterOverflow(int messageSize, ulong counter) + { + long blockCount = (-1L + messageSize + BlockSize) / BlockSize; + if (ulong.MaxValue - (ulong)blockCount < counter) + throw new CryptographicException("Counter overflow prevented."); + } }