From 536f5af0c57c963f201283e1a9343518e8995565 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matev=C5=BE=20Jekovec?= Date: Fri, 28 Jun 2024 13:24:16 +0200 Subject: [PATCH] backend: On-chain SIWE message parsing --- backend/contracts/DateTime.sol | 218 ++++++++++++++++++++++++++ backend/contracts/MessageBox.sol | 51 +----- backend/contracts/SiweAuth.sol | 257 +++++++++++++++++++++++++++++++ backend/hardhat.config.ts | 3 +- backend/package.json | 1 + backend/test/MessageBox.ts | 26 +++- 6 files changed, 505 insertions(+), 51 deletions(-) create mode 100644 backend/contracts/DateTime.sol create mode 100644 backend/contracts/SiweAuth.sol diff --git a/backend/contracts/DateTime.sol b/backend/contracts/DateTime.sol new file mode 100644 index 0000000..55aab86 --- /dev/null +++ b/backend/contracts/DateTime.sol @@ -0,0 +1,218 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +// https://github.com/pipermerriam/ethereum-datetime/blob/master/contracts/DateTime.sol +contract DateTime { + /* + * Date and Time utilities for ethereum contracts + * + */ + struct _DateTime { + uint16 year; + uint8 month; + uint8 day; + uint8 hour; + uint8 minute; + uint8 second; + uint8 weekday; + } + + uint constant DAY_IN_SECONDS = 86400; + uint constant YEAR_IN_SECONDS = 31536000; + uint constant LEAP_YEAR_IN_SECONDS = 31622400; + + uint constant HOUR_IN_SECONDS = 3600; + uint constant MINUTE_IN_SECONDS = 60; + + uint16 constant ORIGIN_YEAR = 1970; + + function isLeapYear(uint16 year) public pure returns (bool) { + if (year % 4 != 0) { + return false; + } + if (year % 100 != 0) { + return true; + } + if (year % 400 != 0) { + return false; + } + return true; + } + + function leapYearsBefore(uint year) public pure returns (uint) { + year -= 1; + return year / 4 - year / 100 + year / 400; + } + + function getDaysInMonth(uint8 month, uint16 year) public pure returns (uint8) { + if (month == 1 || month == 3 || month == 5 || month == 7 || month == 8 || month == 10 || month == 12) { + return 31; + } + else if (month == 4 || month == 6 || month == 9 || month == 11) { + return 30; + } + else if (isLeapYear(year)) { + return 29; + } + else { + return 28; + } + } + + function parseTimestamp(uint timestamp) internal pure returns (_DateTime memory dt) { + uint secondsAccountedFor = 0; + uint buf; + uint8 i; + + // Year + dt.year = getYear(timestamp); + buf = leapYearsBefore(dt.year) - leapYearsBefore(ORIGIN_YEAR); + + secondsAccountedFor += LEAP_YEAR_IN_SECONDS * buf; + secondsAccountedFor += YEAR_IN_SECONDS * (dt.year - ORIGIN_YEAR - buf); + + // Month + uint secondsInMonth; + for (i = 1; i <= 12; i++) { + secondsInMonth = DAY_IN_SECONDS * getDaysInMonth(i, dt.year); + if (secondsInMonth + secondsAccountedFor > timestamp) { + dt.month = i; + break; + } + secondsAccountedFor += secondsInMonth; + } + + // Day + for (i = 1; i <= getDaysInMonth(dt.month, dt.year); i++) { + if (DAY_IN_SECONDS + secondsAccountedFor > timestamp) { + dt.day = i; + break; + } + secondsAccountedFor += DAY_IN_SECONDS; + } + + // Hour + dt.hour = getHour(timestamp); + + // Minute + dt.minute = getMinute(timestamp); + + // Second + dt.second = getSecond(timestamp); + + // Day of week. + dt.weekday = getWeekday(timestamp); + } + + function getYear(uint timestamp) public pure returns (uint16) { + uint secondsAccountedFor = 0; + uint16 year; + uint numLeapYears; + + // Year + year = uint16(ORIGIN_YEAR + timestamp / YEAR_IN_SECONDS); + numLeapYears = leapYearsBefore(year) - leapYearsBefore(ORIGIN_YEAR); + + secondsAccountedFor += LEAP_YEAR_IN_SECONDS * numLeapYears; + secondsAccountedFor += YEAR_IN_SECONDS * (year - ORIGIN_YEAR - numLeapYears); + + while (secondsAccountedFor > timestamp) { + if (isLeapYear(uint16(year - 1))) { + secondsAccountedFor -= LEAP_YEAR_IN_SECONDS; + } + else { + secondsAccountedFor -= YEAR_IN_SECONDS; + } + year -= 1; + } + return year; + } + + function getMonth(uint timestamp) public pure returns (uint8) { + return parseTimestamp(timestamp).month; + } + + function getDay(uint timestamp) public pure returns (uint8) { + return parseTimestamp(timestamp).day; + } + + function getHour(uint timestamp) public pure returns (uint8) { + return uint8((timestamp / 60 / 60) % 24); + } + + function getMinute(uint timestamp) public pure returns (uint8) { + return uint8((timestamp / 60) % 60); + } + + function getSecond(uint timestamp) public pure returns (uint8) { + return uint8(timestamp % 60); + } + + function getWeekday(uint timestamp) public pure returns (uint8) { + return uint8((timestamp / DAY_IN_SECONDS + 4) % 7); + } + + function toTimestamp(uint16 year, uint8 month, uint8 day) public pure returns (uint timestamp) { + return toTimestamp(year, month, day, 0, 0, 0); + } + + function toTimestamp(uint16 year, uint8 month, uint8 day, uint8 hour) public pure returns (uint timestamp) { + return toTimestamp(year, month, day, hour, 0, 0); + } + + function toTimestamp(uint16 year, uint8 month, uint8 day, uint8 hour, uint8 minute) public pure returns (uint timestamp) { + return toTimestamp(year, month, day, hour, minute, 0); + } + + function toTimestamp(uint16 year, uint8 month, uint8 day, uint8 hour, uint8 minute, uint8 second) public pure returns (uint timestamp) { + uint16 i; + + // Year + for (i = ORIGIN_YEAR; i < year; i++) { + if (isLeapYear(i)) { + timestamp += LEAP_YEAR_IN_SECONDS; + } + else { + timestamp += YEAR_IN_SECONDS; + } + } + + // Month + uint8[12] memory monthDayCounts; + monthDayCounts[0] = 31; + if (isLeapYear(year)) { + monthDayCounts[1] = 29; + } + else { + monthDayCounts[1] = 28; + } + monthDayCounts[2] = 31; + monthDayCounts[3] = 30; + monthDayCounts[4] = 31; + monthDayCounts[5] = 30; + monthDayCounts[6] = 31; + monthDayCounts[7] = 31; + monthDayCounts[8] = 30; + monthDayCounts[9] = 31; + monthDayCounts[10] = 30; + monthDayCounts[11] = 31; + + for (i = 1; i < month; i++) { + timestamp += DAY_IN_SECONDS * monthDayCounts[i - 1]; + } + + // Day + timestamp += DAY_IN_SECONDS * (day - 1); + + // Hour + timestamp += HOUR_IN_SECONDS * (hour); + + // Minute + timestamp += MINUTE_IN_SECONDS * (minute); + + // Second + timestamp += second; + + return timestamp; + } +} diff --git a/backend/contracts/MessageBox.sol b/backend/contracts/MessageBox.sol index 92af130..e045df3 100644 --- a/backend/contracts/MessageBox.sol +++ b/backend/contracts/MessageBox.sol @@ -1,63 +1,28 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.0; -import "hardhat/console.sol"; -struct Sig { - uint8 v; - bytes32 r; - bytes32 s; -} +import "./SiweAuth.sol"; -contract MessageBox { +contract MessageBox is SiweAuth { string private _message; address public author; - function toAsciiStringAddr(address x) internal pure returns (string memory) { - bytes memory s = new bytes(40); - for (uint i = 0; i < 20; i++) { - bytes1 b = bytes1(uint8(uint(uint160(x)) / (2**(8*(19 - i))))); - bytes1 hi = bytes1(uint8(b) / 16); - bytes1 lo = bytes1(uint8(b) - 16 * uint8(hi)); - s[2*i] = char(hi); - s[2*i+1] = char(lo); - } - return string(s); - } - - function char(bytes1 b) internal pure returns (bytes1 c) { - if (uint8(b) < 10) return bytes1(uint8(b) + 0x30); - else return bytes1(uint8(b) + 0x57); - } - - - function getSiweMsg() external view returns (bytes memory) { - string memory domain="demo-starter"; - string memory uri="http://localhost:5173"; - string memory version="1"; - string memory chainId="0x5afd"; - string memory nonce="1"; - string memory issuedAt="2021-09-30T16:25:24Z"; - - // TODO: contract address needs to be hex case-sensitive checksummed. - bytes memory siweMsg = abi.encodePacked(domain, " wants you to sign in with your Ethereum account:\n0x", toAsciiStringAddr(address(this)), "\n\n\n\nURI: ", uri, "\nVersion: ",version,"\nChain ID: ", chainId, "\nNonce: ", nonce, "\nIssued At: ", issuedAt); - return siweMsg; - } - - modifier _authorOnly(Sig calldata auth) { - bytes memory eip191msg = abi.encodePacked("\x19Ethereum Signed Message:\n", "203", this.getSiweMsg()); - address addr = ecrecover(keccak256(eip191msg), auth.v, auth.r, auth.s); - if (addr != author) { + modifier _authorOnly(bytes calldata bearer) { + if (authMsgSender(bearer) != author) { revert("not allowed"); } _; } + constructor(string memory domain) SiweAuth(domain) { + } + function setMessage(string calldata in_message) external { _message = in_message; author = msg.sender; } - function message(Sig calldata auth) external view _authorOnly(auth) returns (string memory) { + function message(bytes calldata bearer) external view _authorOnly(bearer) returns (string memory) { return _message; } } diff --git a/backend/contracts/SiweAuth.sol b/backend/contracts/SiweAuth.sol new file mode 100644 index 0000000..567db3a --- /dev/null +++ b/backend/contracts/SiweAuth.sol @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.0; + +import "@oasisprotocol/sapphire-contracts/contracts/Sapphire.sol"; +import "./DateTime.sol"; + +struct Sig { + uint8 v; + bytes32 r; + bytes32 s; +} + +struct Bearer { + string domain; + address userAddr; + uint256 validUntil; // in Unix timestamp. +} + +contract SiweAuth { + string _domain; + bytes32 _bearerEncKey; + address _authMsgSender; + DateTime _dateTime; + + uint constant DEFAULT_VALIDITY=24*3600; // in seconds. + + struct ParsedSiweMessage { + bytes schemeDomain; + address addr; + bytes statement; + bytes uri; + bytes version; + uint chainId; + bytes nonce; + bytes issuedAt; + bytes expirationTime; + bytes notBefore; + bytes requestId; + bytes[] resources; + } + + // Converts string containing hex address without 0x prefix to solidity address object. + function _hexStringToAddress(bytes memory s) private pure returns (address) { + require(s.length == 40, "Invalid address length"); + bytes memory r = new bytes(s.length/2); + for (uint i=0; i= bytes1('0') && bytes1(c) <= bytes1('9')) { + return c - uint8(bytes1('0')); + } + if (bytes1(c) >= bytes1('a') && bytes1(c) <= bytes1('f')) { + return 10 + c - uint8(bytes1('a')); + } + if (bytes1(c) >= bytes1('A') && bytes1(c) <= bytes1('F')) { + return 10 + c - uint8(bytes1('A')); + } + return 0; + } + + constructor(string memory in_domain) { + _bearerEncKey = bytes32(Sapphire.randomBytes(32, "")); + _domain = in_domain; + _dateTime = new DateTime(); + } + + // Substring. + function _substr(bytes memory str, uint startIndex, uint endIndex) private pure returns (bytes memory) { + bytes memory result = new bytes(endIndex-startIndex); + for(uint i = startIndex; i < endIndex && i _timestampFromIso(p.notBefore), "not before not reached yet"); + } + require(block.timestamp < _timestampFromIso(p.expirationTime), "expired"); + + if (p.expirationTime.length!=0) { + // Compute expected block number at expiration time. + b.validUntil = _timestampFromIso(p.expirationTime); + } else { + // Otherwise, just take the default validity. + b.validUntil = block.timestamp + DEFAULT_VALIDITY; + } + + bytes memory encB = Sapphire.encrypt(_bearerEncKey, 0, abi.encode(b), ""); + return encB; + } + + // Returns the domain associated with the dApp. + function domain() public view returns (string memory) { + return _domain; + } + + // Validates the bearer token and returns authenticated msg.sender. + function authMsgSender(bytes calldata bearer) internal view returns (address) { + bytes memory bearerEncoded = Sapphire.decrypt(_bearerEncKey, 0, bearer, ""); + Bearer memory b = abi.decode(bearerEncoded, (Bearer)); + require(keccak256(bytes(b.domain))==keccak256(bytes(_domain)), "invalid domain"); + require(b.validUntil>=block.timestamp, "expired"); + return b.userAddr; + } +} diff --git a/backend/hardhat.config.ts b/backend/hardhat.config.ts index c8b0134..dad3a98 100644 --- a/backend/hardhat.config.ts +++ b/backend/hardhat.config.ts @@ -39,13 +39,14 @@ task(TASK_EXPORT_ABIS, async (_args, hre) => { // Unencrypted contract deployment. task('deploy') + .addPositionalParam('domain', 'dApp domain which Metamask will be allowed for signing-in') .setAction(async (args, hre) => { await hre.run('compile'); // For deployment unwrap the provider to enable contract verification. const uwProvider = new JsonRpcProvider(hre.network.config.url); const MessageBox = await hre.ethers.getContractFactory('MessageBox', new hre.ethers.Wallet(accounts[0], uwProvider)); - const messageBox = await MessageBox.deploy(); + const messageBox = await MessageBox.deploy(args.domain); await messageBox.waitForDeployment(); console.log(`MessageBox address: ${await messageBox.getAddress()}`); diff --git a/backend/package.json b/backend/package.json index 1f4fda4..2a76e19 100644 --- a/backend/package.json +++ b/backend/package.json @@ -57,6 +57,7 @@ "npm-run-all": "^4.1.5", "prettier": "^2.8.4", "prettier-plugin-solidity": "1.1.2", + "siwe": "^2.3.2", "solhint": "^3.4.0", "solidity-coverage": "^0.8.2", "ts-node": "^10.9.1", diff --git a/backend/test/MessageBox.ts b/backend/test/MessageBox.ts index 2ea24fc..0f6e531 100644 --- a/backend/test/MessageBox.ts +++ b/backend/test/MessageBox.ts @@ -1,16 +1,28 @@ import { expect } from "chai"; import { config, ethers } from "hardhat"; +import {SiweMessage} from "siwe"; import "@nomicfoundation/hardhat-chai-matchers"; import {MessagePrefix, hashMessage, concat, toUtf8Bytes, toUtf8String} from "ethers"; describe("MessageBox", function () { async function deployMessageBox() { const MessageBox_factory = await ethers.getContractFactory("MessageBox"); - const messageBox = await MessageBox_factory.deploy(); + const messageBox = await MessageBox_factory.deploy("localhost"); await messageBox.waitForDeployment(); return { messageBox }; } + async function siweMsg(): Promise { + return new SiweMessage({ + domain: "localhost", + address: await (await ethers.provider.getSigner(0)).getAddress(), + statement: "I accept the ExampleOrg Terms of Service: http://localhost/tos", + uri: "http://localhost:5173", + version: "1", + chainId: config.networks.hardhat.chainId, + }).toMessage(); + } + it("Should set message", async function () { const {messageBox} = await deployMessageBox(); @@ -19,18 +31,18 @@ describe("MessageBox", function () { // Check, if author is correctly set. expect(await messageBox.author()).to.equal(await (await ethers.provider.getSigner(0)).getAddress()); - const siweMsg = toUtf8String(await messageBox.getSiweMsg()); - // Author should read a message. const accounts = config.networks.hardhat.accounts; const acc = ethers.HDNodeWallet.fromMnemonic(ethers.Mnemonic.fromPhrase(accounts.mnemonic), accounts.path+'/0'); - const auth = ethers.Signature.from(await acc.signMessage(siweMsg)); - expect(await messageBox.message(auth)).to.equal("hello world"); + const sig = ethers.Signature.from(await acc.signMessage(siweMsg())); + const bearer = messageBox.login(await siweMsg(), sig); + expect(await messageBox.message(bearer)).to.equal("hello world"); // Anyone else trying to read the message should fail. const acc2 = ethers.HDNodeWallet.fromMnemonic(ethers.Mnemonic.fromPhrase(accounts.mnemonic), accounts.path+'/1'); - const auth2 = ethers.Signature.from(await acc2.signMessage(siweMsg)) - await expect(messageBox.message(auth2)).to.be.reverted; + const sig2 = ethers.Signature.from(await acc2.signMessage(siweMsg())) + const bearer2 = messageBox.login(await siweMsg(), sig2); + await expect(messageBox.message(bearer2)).to.be.reverted; }); });