Skip to content

Commit

Permalink
Implement MPC-AES decryption circuit
Browse files Browse the repository at this point in the history
Differential Revision: D41143077

fbshipit-source-id: b42e7478d27c5d3e1ce871cbc20a909441e26cc9
  • Loading branch information
Donghang Lu authored and facebook-github-bot committed Nov 9, 2022
1 parent 156a113 commit a8f6b5c
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 4 deletions.
2 changes: 1 addition & 1 deletion fbpcf/mpc_std_lib/aes_circuit/AesCircuit.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class AesCircuit : public IAesCircuit<BitType> {
void inverseMixColumnsInPlace(WordType& src) const;

void shiftRowInPlace(std::array<WordType, 4>& src) const;

void inverseShiftRowInPlace(std::array<WordType, 4>& src) const;
#ifdef AES_CIRCUIT_TEST_FRIENDS
AES_CIRCUIT_TEST_FRIENDS;
#endif
Expand Down
79 changes: 76 additions & 3 deletions fbpcf/mpc_std_lib/aes_circuit/AesCircuit_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,65 @@ std::vector<BitType> AesCircuit<BitType>::encrypt_impl(
return convertFromWords(plaintextBlocks);
}

// implementation based on https://engineering.purdue.edu/kak/compsec/NewLectures/Lecture8.pdf
template <typename BitType>
std::vector<BitType> AesCircuit<BitType>::decrypt_impl(
const std::vector<BitType>& /* ciphertext */,
const std::vector<BitType>& /* expandedDecKey */) const {
throw std::runtime_error("Not implemented!");
const std::vector<BitType>& ciphertext,
const std::vector<BitType>& expandedDecKey) const {
// prepare input
auto ciphertextBlocks = convertToWords(ciphertext);
auto roundKeys = convertToWords(expandedDecKey);
size_t blockNo = ciphertextBlocks.size();

int round = 10;
// pre-round
for (int block = 0; block < blockNo; ++block) {
for (int word = 0; word < 4; ++word) {
for (int byte = 0; byte < 4; ++byte) {
for (int bit = 0; bit < 8; ++bit) {
ciphertextBlocks[block][word][byte][bit] =
ciphertextBlocks[block][word][byte][bit] ^
roundKeys[round][word][byte][bit];
}
}
}
}
// rounds 1 - 10
for (int round = 9; round >= 0; --round) {
// InverseShiftRows
for (int block = 0; block < blockNo; ++block) {
inverseShiftRowInPlace(ciphertextBlocks[block]);
}
// InverseSbox
for (int block = 0; block < blockNo; ++block) {
for (int word = 0; word < 4; ++word) {
for (int byte = 0; byte < 4; ++byte) {
inverseSBoxInPlace(ciphertextBlocks[block][word][byte]);
}
}
}
// AddRoundKey
for (int block = 0; block < blockNo; ++block) {
for (int word = 0; word < 4; ++word) {
for (int byte = 0; byte < 4; ++byte) {
for (int bit = 0; bit < 8; ++bit) {
ciphertextBlocks[block][word][byte][bit] =
ciphertextBlocks[block][word][byte][bit] ^
roundKeys[round][word][byte][bit];
}
}
}
}
// InverseMixColumns except for 10-th Round
if (round != 0) {
for (int block = 0; block < blockNo; ++block) {
for (int word = 0; word < 4; ++word) {
inverseMixColumnsInPlace(ciphertextBlocks[block][word]);
}
}
}
}
return convertFromWords(ciphertextBlocks);
}

template <typename BitType>
Expand Down Expand Up @@ -494,4 +548,23 @@ void AesCircuit<BitType>::shiftRowInPlace(std::array<WordType, 4>& src) const {
std::swap(src[1][row], src[0][row]);
}

template <typename BitType>
void AesCircuit<BitType>::inverseShiftRowInPlace(
std::array<WordType, 4>& src) const {
// 1st row is not shifted, 2nd row shifted right by 1
int row = 1;
std::swap(src[2][row], src[3][row]);
std::swap(src[1][row], src[2][row]);
std::swap(src[0][row], src[1][row]);
// 3rd row shifted right by 2
row++;
std::swap(src[0][row], src[2][row]);
std::swap(src[1][row], src[3][row]);
// 4th row shifted right by 3
row++;
std::swap(src[0][row], src[1][row]);
std::swap(src[1][row], src[2][row]);
std::swap(src[2][row], src[3][row]);
}

} // namespace fbpcf::mpc_std_lib::aes_circuit
87 changes: 87 additions & 0 deletions fbpcf/mpc_std_lib/aes_circuit/test/AesCircuitTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,28 @@ class AesCircuitTests : public AesCircuit<BitType> {
}
}

void testInverseShiftRowInPlace(std::vector<bool> plaintext) {
std::array<std::array<std::array<bool, 8>, 4>, 4> block;
for (int k = 0; k < 4; ++k) {
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 8; j++) {
block[k][i][j] = plaintext[32 * k + 8 * i + j];
}
}
}

AesCircuit<bool>::inverseShiftRowInPlace(block);
for (int k = 0; k < 4; ++k) {
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 8; j++) {
EXPECT_EQ(
block[k][i][j],
plaintext[32 * ((((k - i) % 4) + 4) % 4) + 8 * i + j]);
}
}
}
}

void testWordConversion() {
using ByteType = std::array<bool, 8>;
using WordType = std::array<ByteType, 4>;
Expand Down Expand Up @@ -159,6 +181,12 @@ TEST(AesCircuitTest, testShiftRowInPlace) {
test.testShiftRowInPlace(plaintext);
}

TEST(AesCircuitTest, testInverseShiftRowInPlace) {
auto plaintext = generateRandomPlaintext();
AesCircuitTests<bool> test;
test.testInverseShiftRowInPlace(plaintext);
}

TEST(AesCircuitTest, testWordConversion) {
AesCircuitTests<bool> test;
test.testWordConversion();
Expand Down Expand Up @@ -352,6 +380,65 @@ TEST(AesCircuitTest, testAesCircuitEncrypt) {
testAesCircuitEncrypt(std::make_unique<AesCircuitFactory<bool>>());
}

void testAesCircuitDecrypt(
std::shared_ptr<AesCircuitFactory<bool>> AesCircuitFactory) {
auto AesCircuit = AesCircuitFactory->create();

std::random_device rd;
std::mt19937_64 e(rd());
std::uniform_int_distribution<uint8_t> dist(0, 0xFF);
size_t blockNo = dist(e);

// generate random key
__m128i key = _mm_set_epi32(dist(e), dist(e), dist(e), dist(e));
// generate random plaintext
std::vector<uint8_t> plaintext;
plaintext.reserve(blockNo * 16);
for (int i = 0; i < blockNo * 16; ++i) {
plaintext.push_back(dist(e));
}
std::vector<__m128i> plaintextAES;
loadValueToLocalAes(plaintext, plaintextAES);

// expand key
engine::util::Aes truthAes(key);
auto expandedKey = truthAes.expandEncryptionKey(key);
// extract key and plaintext
std::vector<uint8_t> extractedKeys;
extractedKeys.reserve(176);
for (auto keyb : expandedKey) {
loadValueFromLocalAes(keyb, extractedKeys);
}

// convert key and plaintext into bool vector
std::vector<bool> keyBits;
keyBits.reserve(1408);
int8VecToBinaryVec(extractedKeys, keyBits);
std::vector<bool> plaintextBits;
plaintextBits.reserve(blockNo * 128);
int8VecToBinaryVec(plaintext, plaintextBits);

// encrypt in real aes
truthAes.encryptInPlace(plaintextAES);

// extract ciphertext in real aes
std::vector<uint8_t> ciphertextTruth;
ciphertextTruth.reserve(blockNo * 16);
for (auto b : plaintextAES) {
loadValueFromLocalAes(b, ciphertextTruth);
}
std::vector<bool> cipherextBitsTruth;
cipherextBitsTruth.reserve(blockNo * 128);
int8VecToBinaryVec(ciphertextTruth, cipherextBitsTruth);
// decrypt this ciphertext using our decrypt circuit
auto decryptionBits = AesCircuit->decrypt(cipherextBitsTruth, keyBits);
testVectorEq(decryptionBits, plaintextBits);
}

TEST(AesCircuitTest, testAesCircuitDecrypt) {
testAesCircuitDecrypt(std::make_unique<AesCircuitFactory<bool>>());
}

void testAesCircuitCtr(
std::shared_ptr<AesCircuitCtrFactory<bool>> AesCircuitCtrFactory) {
auto AesCircuitCtr = AesCircuitCtrFactory->create();
Expand Down

0 comments on commit a8f6b5c

Please sign in to comment.