From a8f6b5c407ffde40d9a5d5aa8688da00f84aec9f Mon Sep 17 00:00:00 2001 From: Donghang Lu Date: Tue, 8 Nov 2022 17:13:47 -0800 Subject: [PATCH] Implement MPC-AES decryption circuit Differential Revision: D41143077 fbshipit-source-id: b42e7478d27c5d3e1ce871cbc20a909441e26cc9 --- fbpcf/mpc_std_lib/aes_circuit/AesCircuit.h | 2 +- .../mpc_std_lib/aes_circuit/AesCircuit_impl.h | 79 ++++++++++++++++- .../aes_circuit/test/AesCircuitTest.cpp | 87 +++++++++++++++++++ 3 files changed, 164 insertions(+), 4 deletions(-) diff --git a/fbpcf/mpc_std_lib/aes_circuit/AesCircuit.h b/fbpcf/mpc_std_lib/aes_circuit/AesCircuit.h index 551ec4ca..47253b45 100644 --- a/fbpcf/mpc_std_lib/aes_circuit/AesCircuit.h +++ b/fbpcf/mpc_std_lib/aes_circuit/AesCircuit.h @@ -49,7 +49,7 @@ class AesCircuit : public IAesCircuit { void inverseMixColumnsInPlace(WordType& src) const; void shiftRowInPlace(std::array& src) const; - + void inverseShiftRowInPlace(std::array& src) const; #ifdef AES_CIRCUIT_TEST_FRIENDS AES_CIRCUIT_TEST_FRIENDS; #endif diff --git a/fbpcf/mpc_std_lib/aes_circuit/AesCircuit_impl.h b/fbpcf/mpc_std_lib/aes_circuit/AesCircuit_impl.h index 3d6963ab..a4f9002c 100644 --- a/fbpcf/mpc_std_lib/aes_circuit/AesCircuit_impl.h +++ b/fbpcf/mpc_std_lib/aes_circuit/AesCircuit_impl.h @@ -78,11 +78,65 @@ std::vector AesCircuit::encrypt_impl( return convertFromWords(plaintextBlocks); } +// implementation based on https://engineering.purdue.edu/kak/compsec/NewLectures/Lecture8.pdf template std::vector AesCircuit::decrypt_impl( - const std::vector& /* ciphertext */, - const std::vector& /* expandedDecKey */) const { - throw std::runtime_error("Not implemented!"); + const std::vector& ciphertext, + const std::vector& expandedDecKey) const { + // prepare input + auto ciphertextBlocks = convertToWords(ciphertext); + auto roundKeys = convertToWords(expandedDecKey); + size_t blockNo = ciphertextBlocks.size(); + + int round = 10; + // pre-round + for (int block = 0; block < blockNo; ++block) { + for (int word = 0; word < 4; ++word) { + for (int byte = 0; byte < 4; ++byte) { + for (int bit = 0; bit < 8; ++bit) { + ciphertextBlocks[block][word][byte][bit] = + ciphertextBlocks[block][word][byte][bit] ^ + roundKeys[round][word][byte][bit]; + } + } + } + } + // rounds 1 - 10 + for (int round = 9; round >= 0; --round) { + // InverseShiftRows + for (int block = 0; block < blockNo; ++block) { + inverseShiftRowInPlace(ciphertextBlocks[block]); + } + // InverseSbox + for (int block = 0; block < blockNo; ++block) { + for (int word = 0; word < 4; ++word) { + for (int byte = 0; byte < 4; ++byte) { + inverseSBoxInPlace(ciphertextBlocks[block][word][byte]); + } + } + } + // AddRoundKey + for (int block = 0; block < blockNo; ++block) { + for (int word = 0; word < 4; ++word) { + for (int byte = 0; byte < 4; ++byte) { + for (int bit = 0; bit < 8; ++bit) { + ciphertextBlocks[block][word][byte][bit] = + ciphertextBlocks[block][word][byte][bit] ^ + roundKeys[round][word][byte][bit]; + } + } + } + } + // InverseMixColumns except for 10-th Round + if (round != 0) { + for (int block = 0; block < blockNo; ++block) { + for (int word = 0; word < 4; ++word) { + inverseMixColumnsInPlace(ciphertextBlocks[block][word]); + } + } + } + } + return convertFromWords(ciphertextBlocks); } template @@ -494,4 +548,23 @@ void AesCircuit::shiftRowInPlace(std::array& src) const { std::swap(src[1][row], src[0][row]); } +template +void AesCircuit::inverseShiftRowInPlace( + std::array& src) const { + // 1st row is not shifted, 2nd row shifted right by 1 + int row = 1; + std::swap(src[2][row], src[3][row]); + std::swap(src[1][row], src[2][row]); + std::swap(src[0][row], src[1][row]); + // 3rd row shifted right by 2 + row++; + std::swap(src[0][row], src[2][row]); + std::swap(src[1][row], src[3][row]); + // 4th row shifted right by 3 + row++; + std::swap(src[0][row], src[1][row]); + std::swap(src[1][row], src[2][row]); + std::swap(src[2][row], src[3][row]); +} + } // namespace fbpcf::mpc_std_lib::aes_circuit diff --git a/fbpcf/mpc_std_lib/aes_circuit/test/AesCircuitTest.cpp b/fbpcf/mpc_std_lib/aes_circuit/test/AesCircuitTest.cpp index 4cf80aab..14cb6192 100644 --- a/fbpcf/mpc_std_lib/aes_circuit/test/AesCircuitTest.cpp +++ b/fbpcf/mpc_std_lib/aes_circuit/test/AesCircuitTest.cpp @@ -62,6 +62,28 @@ class AesCircuitTests : public AesCircuit { } } + void testInverseShiftRowInPlace(std::vector plaintext) { + std::array, 4>, 4> block; + for (int k = 0; k < 4; ++k) { + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + block[k][i][j] = plaintext[32 * k + 8 * i + j]; + } + } + } + + AesCircuit::inverseShiftRowInPlace(block); + for (int k = 0; k < 4; ++k) { + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + EXPECT_EQ( + block[k][i][j], + plaintext[32 * ((((k - i) % 4) + 4) % 4) + 8 * i + j]); + } + } + } + } + void testWordConversion() { using ByteType = std::array; using WordType = std::array; @@ -159,6 +181,12 @@ TEST(AesCircuitTest, testShiftRowInPlace) { test.testShiftRowInPlace(plaintext); } +TEST(AesCircuitTest, testInverseShiftRowInPlace) { + auto plaintext = generateRandomPlaintext(); + AesCircuitTests test; + test.testInverseShiftRowInPlace(plaintext); +} + TEST(AesCircuitTest, testWordConversion) { AesCircuitTests test; test.testWordConversion(); @@ -352,6 +380,65 @@ TEST(AesCircuitTest, testAesCircuitEncrypt) { testAesCircuitEncrypt(std::make_unique>()); } +void testAesCircuitDecrypt( + std::shared_ptr> AesCircuitFactory) { + auto AesCircuit = AesCircuitFactory->create(); + + std::random_device rd; + std::mt19937_64 e(rd()); + std::uniform_int_distribution dist(0, 0xFF); + size_t blockNo = dist(e); + + // generate random key + __m128i key = _mm_set_epi32(dist(e), dist(e), dist(e), dist(e)); + // generate random plaintext + std::vector plaintext; + plaintext.reserve(blockNo * 16); + for (int i = 0; i < blockNo * 16; ++i) { + plaintext.push_back(dist(e)); + } + std::vector<__m128i> plaintextAES; + loadValueToLocalAes(plaintext, plaintextAES); + + // expand key + engine::util::Aes truthAes(key); + auto expandedKey = truthAes.expandEncryptionKey(key); + // extract key and plaintext + std::vector extractedKeys; + extractedKeys.reserve(176); + for (auto keyb : expandedKey) { + loadValueFromLocalAes(keyb, extractedKeys); + } + + // convert key and plaintext into bool vector + std::vector keyBits; + keyBits.reserve(1408); + int8VecToBinaryVec(extractedKeys, keyBits); + std::vector plaintextBits; + plaintextBits.reserve(blockNo * 128); + int8VecToBinaryVec(plaintext, plaintextBits); + + // encrypt in real aes + truthAes.encryptInPlace(plaintextAES); + + // extract ciphertext in real aes + std::vector ciphertextTruth; + ciphertextTruth.reserve(blockNo * 16); + for (auto b : plaintextAES) { + loadValueFromLocalAes(b, ciphertextTruth); + } + std::vector cipherextBitsTruth; + cipherextBitsTruth.reserve(blockNo * 128); + int8VecToBinaryVec(ciphertextTruth, cipherextBitsTruth); + // decrypt this ciphertext using our decrypt circuit + auto decryptionBits = AesCircuit->decrypt(cipherextBitsTruth, keyBits); + testVectorEq(decryptionBits, plaintextBits); +} + +TEST(AesCircuitTest, testAesCircuitDecrypt) { + testAesCircuitDecrypt(std::make_unique>()); +} + void testAesCircuitCtr( std::shared_ptr> AesCircuitCtrFactory) { auto AesCircuitCtr = AesCircuitCtrFactory->create();