From bb7dfd496542398bcf38c11e2b9a83da6f22a003 Mon Sep 17 00:00:00 2001 From: Chris Heaney Date: Thu, 16 Nov 2023 17:58:13 -0500 Subject: [PATCH 1/7] update dependencies and add compute units/priority fee --- poetry.lock | 200 ++++++++++----------- pyproject.toml | 18 +- src/driftpy/accounts/cache/drift_client.py | 4 +- src/driftpy/accounts/cache/user.py | 4 +- src/driftpy/accounts/get_accounts.py | 18 +- src/driftpy/accounts/oracle.py | 14 +- src/driftpy/accounts/types.py | 4 +- src/driftpy/addresses.py | 72 ++++---- src/driftpy/admin.py | 36 ++-- src/driftpy/constants/banks.py | 26 +-- src/driftpy/constants/config.py | 20 +-- src/driftpy/constants/markets.py | 48 ++--- src/driftpy/drift_client.py | 124 +++++++------ src/driftpy/sdk_types.py | 4 +- src/driftpy/setup/helpers.py | 94 +++++----- src/driftpy/types.py | 85 ++++----- tests/test.py | 28 +-- 17 files changed, 407 insertions(+), 392 deletions(-) diff --git a/poetry.lock b/poetry.lock index 4b69ed53..6ff0858e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -141,31 +141,34 @@ frozenlist = ">=1.1.0" [[package]] name = "anchorpy" -version = "0.10.0" +version = "0.17.1" description = "The Python Anchor client." category = "main" optional = false python-versions = ">=3.9,<4.0" files = [ - {file = "anchorpy-0.10.0-py3-none-any.whl", hash = "sha256:cedc32b4d7fbb0697534208eda2aaee2e89cb2eae52e7d3e78e84b37b2bb9d57"}, - {file = "anchorpy-0.10.0.tar.gz", hash = "sha256:6923342c8871331a3731ec885a72155bb6ead2cd3db8d6b39f9cdc9af6f27655"}, + {file = "anchorpy-0.17.1-py3-none-any.whl", hash = "sha256:e95c53113a18e873228ef7673db8333d61df8d7cc961df7252764b5a039591b5"}, + {file = "anchorpy-0.17.1.tar.gz", hash = "sha256:d8e6fbe68d7808433e7d18dccb75d119b396f6f9f75e9cd5f567f6df1f6f53e7"}, ] [package.dependencies] -apischema = ">=0.17.5,<0.18.0" +anchorpy-core = ">=0.1.2,<0.2.0" +based58 = ">=0.1.1,<0.2.0" borsh-construct = ">=0.1.0,<0.2.0" construct-typing = ">=0.5.1,<0.6.0" jsonrpcclient = ">=4.0.1,<5.0.0" more-itertools = ">=8.11.0,<9.0.0" +py = ">=1.11.0,<2.0.0" pyheck = ">=0.1.4,<0.2.0" -pytest = ">=6.2.5,<7.0.0" -pytest-asyncio = ">=0.17.2,<0.18.0" +pytest = ">=7.2.0,<8.0.0" +pytest-asyncio = ">=0.21.0,<0.22.0" pytest-xprocess = ">=0.18.1,<0.19.0" -solana = ">=0.25.0,<0.26.0" -sumtypes = ">=0.1a6,<0.2" +solana = ">=0.30.1,<0.31.0" +solders = ">=0.17.0,<0.18.0" +toml = ">=0.10.2,<0.11.0" toolz = ">=0.11.2,<0.12.0" -websockets = ">=10.0,<11.0" -zstandard = ">=0.17.0,<0.18.0" +websockets = ">=9.0,<11.0" +zstandard = ">=0.18.0,<0.19.0" [package.extras] cli = ["autoflake (>=1.4,<2.0)", "black (>=22.3.0,<23.0.0)", "genpy (>=2021.1,<2022.0)", "ipython (>=8.0.1,<9.0.0)", "typer (==0.4.1)"] @@ -277,17 +280,6 @@ files = [ {file = "async_timeout-4.0.2-py3-none-any.whl", hash = "sha256:8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c"}, ] -[[package]] -name = "atomicwrites" -version = "1.4.1" -description = "Atomic file writes." -category = "main" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -files = [ - {file = "atomicwrites-1.4.1.tar.gz", hash = "sha256:81b2c9071a49367a7f770170e5eec8cb66567cfbbc8c73d20ce5ca4a8d71cf11"}, -] - [[package]] name = "attrs" version = "22.1.0" @@ -1602,46 +1594,45 @@ files = [ [[package]] name = "pytest" -version = "6.2.5" +version = "7.4.3" description = "pytest: simple powerful testing with Python" category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" files = [ - {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, - {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, + {file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"}, + {file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"}, ] [package.dependencies] -atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} -attrs = ">=19.2.0" colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" pluggy = ">=0.12,<2.0" -py = ">=1.8.2" -toml = "*" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] -testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] [[package]] name = "pytest-asyncio" -version = "0.17.2" +version = "0.21.1" description = "Pytest support for asyncio" category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "pytest-asyncio-0.17.2.tar.gz", hash = "sha256:6d895b02432c028e6957d25fc936494e78c6305736e785d9fee408b1efbc7ff4"}, - {file = "pytest_asyncio-0.17.2-py3-none-any.whl", hash = "sha256:e0fe5dbea40516b661ef1bcfe0bd9461c2847c4ef4bb40012324f2454fb7d56d"}, + {file = "pytest-asyncio-0.21.1.tar.gz", hash = "sha256:40a7eae6dded22c7b604986855ea48400ab15b069ae38116e8c01238e9eeb64d"}, + {file = "pytest_asyncio-0.21.1-py3-none-any.whl", hash = "sha256:8666c1c8ac02631d7c51ba282e0c69a8a452b211ffedf2599099845da5c5c37b"}, ] [package.dependencies] -pytest = ">=6.1.0" +pytest = ">=7.0.0" [package.extras] -testing = ["coverage (==6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (==0.931)"] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] [[package]] name = "pytest-xprocess" @@ -1855,53 +1846,48 @@ files = [ [[package]] name = "solana" -version = "0.25.1" +version = "0.30.1" description = "Solana Python API" category = "main" optional = false python-versions = ">=3.7,<4.0" files = [ - {file = "solana-0.25.1-py3-none-any.whl", hash = "sha256:24d0fb17096e63d1c6eb95bb605802a3eed422626899dc3758d40194c9674da1"}, - {file = "solana-0.25.1.tar.gz", hash = "sha256:f11fdb5c4dfe0e1e1267e30d3c1a580d0570bc8f1bf8113078ce4746e8488447"}, + {file = "solana-0.30.1-py3-none-any.whl", hash = "sha256:b5f4964ec568d118e31384c45f1ecc3e29ed09c07925944070240de3a4e85480"}, + {file = "solana-0.30.1.tar.gz", hash = "sha256:cd72d57278772d41def5ab68c54bccea40c9df40852882b938f3279d815ab51a"}, ] [package.dependencies] -apischema = ">=0.17.5,<0.18.0" -based58 = ">=0.1.0,<0.2.0" cachetools = ">=4.2.2,<5.0.0" construct-typing = ">=0.5.2,<0.6.0" httpx = ">=0.23.0,<0.24.0" -jsonrpcclient = ">=4.0.1,<5.0.0" -jsonrpcserver = ">=5.0.7,<6.0.0" -requests = ">=2.24,<3.0" -solders = ">=0.2.0,<0.3.0" +solders = ">=0.17.0,<0.18.0" types-cachetools = ">=4.2.4,<5.0.0" -typing-extensions = ">=3.10.0" -websockets = ">=10.3,<11.0" +typing-extensions = ">=4.2.0" +websockets = ">=9.0,<12.0" [[package]] name = "solders" -version = "0.2.0" -description = "Python binding to the Solana Rust SDK" +version = "0.17.0" +description = "Python bindings for Solana Rust tools" category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "solders-0.2.0-cp37-abi3-macosx_10_7_x86_64.whl", hash = "sha256:6b0a862b4aebf33c39fe68b4242737dc212d5eaa891db1fec466ce7f322027fd"}, - {file = "solders-0.2.0-cp37-abi3-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0872f8bccc14d67245d6849df957cbf7ae9fb6afcc61bcdf2076d0ba809c71a2"}, - {file = "solders-0.2.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39ebee6545660b393178d0ef5187b71d743f15cd15214cb7b4dd4177991f0aa4"}, - {file = "solders-0.2.0-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6ddd3da0921fb5a99a7e58899cd1c2f330e086d28b5feb4b367fc25550b5a3b6"}, - {file = "solders-0.2.0-cp37-abi3-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:60ad094b1137e0efbe5e03d5808a596c455cbaf20956d44a7faae2e6da1a269c"}, - {file = "solders-0.2.0-cp37-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4327a6a37a226f572a3d7c5bd5dcb86e11de510dd5f0e0bede9519b08df223a4"}, - {file = "solders-0.2.0-cp37-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d63efde2713379325b547e45db481b002219e005cd7d71c3bb2ed722a03afefa"}, - {file = "solders-0.2.0-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0117bd3559cb47a89dfcd669da9b2dd9d71d385588936bdaeadf7e1a53290e1c"}, - {file = "solders-0.2.0-cp37-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:841881445d8fe7964c55454e00a8d1ef3c5fe7476656f231eccc8bebe453b650"}, - {file = "solders-0.2.0-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:88ecad7525fe74513d9305a29f9a61330af3d1460a09322b7d65a187ffc74aa2"}, - {file = "solders-0.2.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f696c6fcdbe37c4357865aaa5ab843881ead5ea1255cb58ea8bde3a6bd2af0fb"}, - {file = "solders-0.2.0-cp37-abi3-win_amd64.whl", hash = "sha256:4d699d6f2f1854d85c7f1c4dc1b8cba7fd2c4a1fda6c6783b611ac88c99e0f97"}, - {file = "solders-0.2.0.tar.gz", hash = "sha256:ac952fa3d9d71001f1e98489336288d6b8a560c134d7aba98ce521699c844b21"}, + {file = "solders-0.17.0-cp37-abi3-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:2222c20300a3b339d0c30cc85a38cb76b9341529c943f8f6e84e37b6c4cf7c3f"}, + {file = "solders-0.17.0-cp37-abi3-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:af069840d7aaafecfa819ab05ba13df1d1fb718650dab0e989bbb7c34b63b839"}, + {file = "solders-0.17.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7ba319d05bdd5802b781326d9cbe70b9af05f935714d38b6595a51154d070b2"}, + {file = "solders-0.17.0-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eab2565dc01ac1474bcb83d43a65e4e38c16cbe762650f77d8d075fae75ee832"}, + {file = "solders-0.17.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5002c1785ad301815d49a771ebd540bba15a8b30efbcef2f43456731e36a7df1"}, + {file = "solders-0.17.0-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:d3d1559134569bbf69a8fe86509e8bf94e6a73c8de796d8de4b29aaae78f0a7a"}, + {file = "solders-0.17.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:da0d8ff4e21b3acebb4c5b3ca097659c54358cb97df087fa9fdf717710425759"}, + {file = "solders-0.17.0-cp37-abi3-win_amd64.whl", hash = "sha256:da64d20c0eb14b1b1e39bbe82b878a03e71c076f71345ff2faf5ca87f11fa431"}, + {file = "solders-0.17.0.tar.gz", hash = "sha256:a898804637769ec22518f2667b961230a68ff3264d6ccf90d9c94942c12a0ea1"}, ] +[package.dependencies] +jsonalias = "0.1.1" +typing-extensions = ">=4.2.0" + [[package]] name = "sumtypes" version = "0.1a6" @@ -2245,56 +2231,56 @@ multidict = ">=4.0" [[package]] name = "zstandard" -version = "0.17.0" +version = "0.18.0" description = "Zstandard bindings for Python" category = "main" optional = false python-versions = ">=3.6" files = [ - {file = "zstandard-0.17.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a1991cdf2e81e643b53fb8d272931d2bdf5f4e70d56a457e1ef95bde147ae627"}, - {file = "zstandard-0.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4768449d8d1b0785309ace288e017cc5fa42e11a52bf08c90d9c3eb3a7a73cc6"}, - {file = "zstandard-0.17.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1ad6d2952b41d9a0ea702a474cc08c05210c6289e29dd496935c9ca3c7fb45c"}, - {file = "zstandard-0.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90a9ba3a9c16b86afcb785b3c9418af39ccfb238fd5f6e429166e3ca8542b01f"}, - {file = "zstandard-0.17.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:9cf18c156b3a108197a8bf90b37d03c31c8ef35a7c18807b321d96b74e12c301"}, - {file = "zstandard-0.17.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c81fd9386449df0ebf1ab3e01187bb30d61122c74df53ba4880a2454d866e55d"}, - {file = "zstandard-0.17.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:787efc741e61e00ffe5e65dac99b0dc5c88b9421012a207a91b869a8b1164921"}, - {file = "zstandard-0.17.0-cp310-cp310-win32.whl", hash = "sha256:49cd09ccbd1e3c0e2690dd62ebf95064d84aa42b9db381867e0b138631f969f2"}, - {file = "zstandard-0.17.0-cp310-cp310-win_amd64.whl", hash = "sha256:d78aac2ffc4e88ab1cbcad844669924c24e24c7c255de9628a18f14d832007c5"}, - {file = "zstandard-0.17.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:c19d1e06569c277dcc872d80cbadf14a29e8199e013ff2a176d169f461439a40"}, - {file = "zstandard-0.17.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d916018289d2f9a882e90d2e3bd41652861ce11b5ecd8515fa07ad31d97d56e5"}, - {file = "zstandard-0.17.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0c87f097d6867833a839b086eb8d03676bb87c2efa067a131099f04aa790683"}, - {file = "zstandard-0.17.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:60943f71e3117583655a1eb76188a7cc78a25267ef09cc74be4d25a0b0c8b947"}, - {file = "zstandard-0.17.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:208fa6bead577b2607205640078ee452e81fe20fe96321623c632bad9ebd7148"}, - {file = "zstandard-0.17.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:42f3c02c7021073cafbc6cd152b288c56a25e585518861589bb08b063b6d2ad2"}, - {file = "zstandard-0.17.0-cp36-cp36m-win32.whl", hash = "sha256:2a2ac752162ba5cbc869c60c4a4e54e890b2ee2ffb57d3ff159feab1ae4518db"}, - {file = "zstandard-0.17.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d1405caa964ba11b2396bd9fd19940440217345752e192c936d084ba5fe67dcb"}, - {file = "zstandard-0.17.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ef62eb3bcfd6d786f439828bb544ebd3936432db669403e0b8f48e424f1d55f1"}, - {file = "zstandard-0.17.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:477f172807a9fa83467b30d7c58876af1410d20177c554c27525211edf535bae"}, - {file = "zstandard-0.17.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de1aa618306a741e0497878b7f845fd6c397e52dd096fb76ed791e7268887176"}, - {file = "zstandard-0.17.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:a827b9c464ee966524f8e82ec1aabb4a77ff9514cae041667fa81ae2ec8bd3e9"}, - {file = "zstandard-0.17.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cf96ace804945e53bc3e5294097e5fa32a2d43bc52416c632b414b870ee0a21"}, - {file = "zstandard-0.17.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:802109f67328c5b822d4fdac28e1cf65a24de2e2e99d76cdbeee9121cedb1b6c"}, - {file = "zstandard-0.17.0-cp37-cp37m-win32.whl", hash = "sha256:a628f20d019feb0f3a171c7a55cc4f75681f3b8c1bd7a5009165a487314887cd"}, - {file = "zstandard-0.17.0-cp37-cp37m-win_amd64.whl", hash = "sha256:7d2e7abac41d2b4b18f03575aca860d2cb647c343e13c23d6c769106a3db2f6f"}, - {file = "zstandard-0.17.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f502fe79757434292174b04db114f9e25c767b2d5ca9e759d118b22a66f445f8"}, - {file = "zstandard-0.17.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e37c4e21f696d6bcdbbc7caf98dffa505d04c0053909b9db0a6e8ca3b935eb07"}, - {file = "zstandard-0.17.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8fd386d0ec1f9343f1776391d9e60d4eedced0a0b0e625bb89b91f6d05f70e83"}, - {file = "zstandard-0.17.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91a228a077fc7cd8486c273788d4a006a37d060cb4293f471eb0325c3113af68"}, - {file = "zstandard-0.17.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:59eadb9f347d40e8f7ef77caffd0c04a31e82c1df82fe2d2a688032429d750ac"}, - {file = "zstandard-0.17.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a71809ec062c5b7acf286ba6d4484e6fe8130fc2b93c25e596bb34e7810c79b2"}, - {file = "zstandard-0.17.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8aedd38d357f6d5e2facd88ce62b4976afdc29db57216a23f14a0cd0ca05a8a3"}, - {file = "zstandard-0.17.0-cp38-cp38-win32.whl", hash = "sha256:bd842ae3dbb7cba88beb022161c819fa80ca7d0c5a4ddd209e7daae85d904e49"}, - {file = "zstandard-0.17.0-cp38-cp38-win_amd64.whl", hash = "sha256:d0e9fec68e304fb35c559c44530213adbc7d5918bdab906a45a0f40cd56c4de2"}, - {file = "zstandard-0.17.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9ec62a4c2dbb0a86ee5138c16ef133e59a23ac108f8d7ac97aeb61d410ce6857"}, - {file = "zstandard-0.17.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d5373a56b90052f171c8634fedc53a6ac371e6c742606e9825772a394bdbd4b0"}, - {file = "zstandard-0.17.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2e3ea5e4d5ecf3faefd4a5294acb6af1f0578b0cdd75d6b4529c45deaa54d6f"}, - {file = "zstandard-0.17.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a3a1aa9528087f6f4c47f4ece2d5e6a160527821263fb8174ff36429233e093"}, - {file = "zstandard-0.17.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:bdf691a205bc492956e6daef7a06fb38f8cbe8b2c1cb0386f35f4412c360c9e9"}, - {file = "zstandard-0.17.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:db993a56e21d903893933887984ca9b0d274f2b1db7b3cf21ba129783953864f"}, - {file = "zstandard-0.17.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a7756a9446f83c81101f6c0a48c3bfd8d387a249933c57b0d095ca8b20541337"}, - {file = "zstandard-0.17.0-cp39-cp39-win32.whl", hash = "sha256:37e50501baaa935f13a1820ab2114f74313b5cb4cfff8146acb8c5b18cdced2a"}, - {file = "zstandard-0.17.0-cp39-cp39-win_amd64.whl", hash = "sha256:b4e671c4c0804cdf752be26f260058bb858fbdaaef1340af170635913ecca01e"}, - {file = "zstandard-0.17.0.tar.gz", hash = "sha256:fa9194cb91441df7242aa3ddc4cb184be38876cb10dd973674887f334bafbfb6"}, + {file = "zstandard-0.18.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ef7e8a200e4c8ac9102ed3c90ed2aa379f6b880f63032200909c1be21951f556"}, + {file = "zstandard-0.18.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2dc466207016564805e56d28375f4f533b525ff50d6776946980dff5465566ac"}, + {file = "zstandard-0.18.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a2ee1d4f98447f3e5183ecfce5626f983504a4a0c005fbe92e60fa8e5d547ec"}, + {file = "zstandard-0.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d956e2f03c7200d7e61345e0880c292783ec26618d0d921dcad470cb195bbce2"}, + {file = "zstandard-0.18.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:ce6f59cba9854fd14da5bfe34217a1501143057313966637b7291d1b0267bd1e"}, + {file = "zstandard-0.18.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7fa67cba473623848b6e88acf8d799b1906178fd883fb3a1da24561c779593b"}, + {file = "zstandard-0.18.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cdb44d7284c8c5dd1b66dfb86dda7f4560fa94bfbbc1d2da749ba44831335e32"}, + {file = "zstandard-0.18.0-cp310-cp310-win32.whl", hash = "sha256:63694a376cde0aa8b1971d06ca28e8f8b5f492779cb6ee1cc46bbc3f019a42a5"}, + {file = "zstandard-0.18.0-cp310-cp310-win_amd64.whl", hash = "sha256:702a8324cd90c74d9c8780d02bf55e79da3193c870c9665ad3a11647e3ad1435"}, + {file = "zstandard-0.18.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:46f679bc5dfd938db4fb058218d9dc4db1336ffaf1ea774ff152ecadabd40805"}, + {file = "zstandard-0.18.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc2a4de9f363b3247d472362a65041fe4c0f59e01a2846b15d13046be866a885"}, + {file = "zstandard-0.18.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd3220d7627fd4d26397211cb3b560ec7cc4a94b75cfce89e847e8ce7fabe32d"}, + {file = "zstandard-0.18.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:39e98cf4773234bd9cebf9f9db730e451dfcfe435e220f8921242afda8321887"}, + {file = "zstandard-0.18.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5228e596eb1554598c872a337bbe4e5afe41cd1f8b1b15f2e35b50d061e35244"}, + {file = "zstandard-0.18.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d4a8fd45746a6c31e729f35196e80b8f1e9987c59f5ccb8859d7c6a6fbeb9c63"}, + {file = "zstandard-0.18.0-cp36-cp36m-win32.whl", hash = "sha256:4cbb85f29a990c2fdbf7bc63246567061a362ddca886d7fae6f780267c0a9e67"}, + {file = "zstandard-0.18.0-cp36-cp36m-win_amd64.whl", hash = "sha256:bfa6c8549fa18e6497a738b7033c49f94a8e2e30c5fbe2d14d0b5aa8bbc1695d"}, + {file = "zstandard-0.18.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e02043297c1832f2666cd2204f381bef43b10d56929e13c42c10c732c6e3b4ed"}, + {file = "zstandard-0.18.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7231543d38d2b7e02ef7cc78ef7ffd86419437e1114ff08709fe25a160e24bd6"}, + {file = "zstandard-0.18.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c86befac87445927488f5c8f205d11566f64c11519db223e9d282b945fa60dab"}, + {file = "zstandard-0.18.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:999a4e1768f219826ba3fa2064fab1c86dd72fdd47a42536235478c3bb3ca3e2"}, + {file = "zstandard-0.18.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9df59cd1cf3c62075ee2a4da767089d19d874ac3ad42b04a71a167e91b384722"}, + {file = "zstandard-0.18.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:1be31e9e3f7607ee0cdd60915410a5968b205d3e7aa83b7fcf3dd76dbbdb39e0"}, + {file = "zstandard-0.18.0-cp37-cp37m-win32.whl", hash = "sha256:490d11b705b8ae9dc845431bacc8dd1cef2408aede176620a5cd0cd411027936"}, + {file = "zstandard-0.18.0-cp37-cp37m-win_amd64.whl", hash = "sha256:266aba27fa9cc5e9091d3d325ebab1fa260f64e83e42516d5e73947c70216a5b"}, + {file = "zstandard-0.18.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8b2260c4e07dd0723eadb586de7718b61acca4083a490dda69c5719d79bc715c"}, + {file = "zstandard-0.18.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3af8c2383d02feb6650e9255491ec7d0824f6e6dd2bbe3e521c469c985f31fb1"}, + {file = "zstandard-0.18.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28723a1d2e4df778573b76b321ebe9f3469ac98988104c2af116dd344802c3f8"}, + {file = "zstandard-0.18.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19cac7108ff2c342317fad6dc97604b47a41f403c8f19d0bfc396dfadc3638b8"}, + {file = "zstandard-0.18.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:76725d1ee83a8915100a310bbad5d9c1fc6397410259c94033b8318d548d9990"}, + {file = "zstandard-0.18.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d716a7694ce1fa60b20bc10f35c4a22be446ef7f514c8dbc8f858b61976de2fb"}, + {file = "zstandard-0.18.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:49685bf9a55d1ab34bd8423ea22db836ba43a181ac6b045ac4272093d5cb874e"}, + {file = "zstandard-0.18.0-cp38-cp38-win32.whl", hash = "sha256:1af1268a7dc870eb27515fb8db1f3e6c5a555d2b7bcc476fc3bab8886c7265ab"}, + {file = "zstandard-0.18.0-cp38-cp38-win_amd64.whl", hash = "sha256:1dc2d3809e763055a1a6c1a73f2b677320cc9a5aa1a7c6cfb35aee59bddc42d9"}, + {file = "zstandard-0.18.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eea18c1e7442f2aa9aff1bb84550dbb6a1f711faf6e48e7319de8f2b2e923c2a"}, + {file = "zstandard-0.18.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8677ffc6a6096cccbd892e558471c901fd821aba12b7fbc63833c7346f549224"}, + {file = "zstandard-0.18.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:083dc08abf03807af9beeb2b6a91c23ad78add2499f828176a3c7b742c44df02"}, + {file = "zstandard-0.18.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c990063664c08169c84474acecc9251ee035871589025cac47c060ff4ec4bc1a"}, + {file = "zstandard-0.18.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:533db8a6fac6248b2cb2c935e7b92f994efbdeb72e1ffa0b354432e087bb5a3e"}, + {file = "zstandard-0.18.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dbb3cb8a082d62b8a73af42291569d266b05605e017a3d8a06a0e5c30b5f10f0"}, + {file = "zstandard-0.18.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:d6c85ca5162049ede475b7ec98e87f9390501d44a3d6776ddd504e872464ec25"}, + {file = "zstandard-0.18.0-cp39-cp39-win32.whl", hash = "sha256:75479e7c2b3eebf402c59fbe57d21bc400cefa145ca356ee053b0a08908c5784"}, + {file = "zstandard-0.18.0-cp39-cp39-win_amd64.whl", hash = "sha256:d85bfabad444812133a92fc6fbe463e1d07581dba72f041f07a360e63808b23c"}, + {file = "zstandard-0.18.0.tar.gz", hash = "sha256:0ac0357a0d985b4ff31a854744040d7b5754385d1f98f7145c30e02c6865cb6f"}, ] [package.dependencies] @@ -2306,4 +2292,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "6d83015b97ae3e161aa2002417f145c09160a3dfa77badd5496e83b2928be3db" +content-hash = "fc879d686ae6b2d2eedcde15487706afc24e62d3c5fd2a2ca0181867f18f7593" diff --git a/pyproject.toml b/pyproject.toml index 26374b85..a0b26c83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,8 +10,8 @@ documentation = "https://drift-labs.github.io/driftpy/" [tool.poetry.dependencies] python = "^3.10" -anchorpy = "0.10.0" -solana = "0.25.1" +anchorpy = "0.17.1" +solana = "0.30.1" requests = "2.28.1" types-requests = "^2.28.9" mkdocs = "^1.3.0" @@ -62,12 +62,12 @@ pycparser = "2.21" pyflakes = "3.0.1" pyheck = "0.1.5" pyrsistent = "0.19.2" -pytest = "6.2.5" -pytest-asyncio = "0.17.2" -pytest-xprocess = "0.18.1" +pytest = "^7.2.0" +pytest-xprocess = "^0.18.1" +pytest-asyncio = "^0.21.0" rfc3986 = "1.5.0" sniffio = "1.3.0" -solders = "0.2.0" +solders = "0.17.0" sumtypes = "0.1a6" toml = "0.10.2" tomli = "2.0.1" @@ -77,15 +77,15 @@ typing-extensions = "4.4.0" urllib3 = "1.26.13" websockets = "10.4" yarl = "1.8.2" -zstandard = "0.17.0" +zstandard = "0.18.0" jinja2 = "<3.1" [tool.poetry.dev-dependencies] -pytest = "^6.2.5" +pytest = "^7.2.0" flake8 = "6.0.0" mypy = "^0.931" black = "^23.3.0" -pytest-asyncio = "^0.17.2" +pytest-asyncio = "^0.21.0" mkdocs = "^1.3.0" mkdocstrings = "^0.17.0" mkdocs-material = "^8.1.8" diff --git a/src/driftpy/accounts/cache/drift_client.py b/src/driftpy/accounts/cache/drift_client.py index c2463922..37db66e3 100644 --- a/src/driftpy/accounts/cache/drift_client.py +++ b/src/driftpy/accounts/cache/drift_client.py @@ -1,5 +1,5 @@ from anchorpy import Program -from solana.publickey import PublicKey +from solders.pubkey import Pubkey from solana.rpc.commitment import Commitment from driftpy.accounts import get_state_account_and_slot, get_spot_market_account_and_slot, \ @@ -69,7 +69,7 @@ async def get_spot_market_and_slot(self, market_index: int) -> Optional[DataAndS await self.cache_if_needed() return self.cache["spot_markets"][market_index] - async def get_oracle_data_and_slot(self, oracle: PublicKey) -> Optional[DataAndSlot[OraclePriceData]]: + async def get_oracle_data_and_slot(self, oracle: Pubkey) -> Optional[DataAndSlot[OraclePriceData]]: await self.cache_if_needed() return self.cache["oracle_price_data"][str(oracle)] diff --git a/src/driftpy/accounts/cache/user.py b/src/driftpy/accounts/cache/user.py index ffcd1b4f..357f393f 100644 --- a/src/driftpy/accounts/cache/user.py +++ b/src/driftpy/accounts/cache/user.py @@ -1,7 +1,7 @@ from typing import Optional from anchorpy import Program -from solana.publickey import PublicKey +from solders.pubkey import Pubkey from solana.rpc.commitment import Commitment from driftpy.accounts import get_user_account_and_slot @@ -10,7 +10,7 @@ class CachedUserAccountSubscriber(UserAccountSubscriber): - def __init__(self, user_pubkey: PublicKey, program: Program, commitment: Commitment = "confirmed"): + def __init__(self, user_pubkey: Pubkey, program: Program, commitment: Commitment = "confirmed"): self.program = program self.commitment = commitment self.user_pubkey = user_pubkey diff --git a/src/driftpy/accounts/get_accounts.py b/src/driftpy/accounts/get_accounts.py index edc1e0a7..9a56dfcb 100644 --- a/src/driftpy/accounts/get_accounts.py +++ b/src/driftpy/accounts/get_accounts.py @@ -1,6 +1,6 @@ import base64 from typing import cast -from solana.publickey import PublicKey +from solders.pubkey import Pubkey from anchorpy import Program, ProgramAccount from solana.rpc.commitment import Commitment @@ -9,7 +9,7 @@ from .types import DataAndSlot, T -async def get_account_data_and_slot(address: PublicKey, program: Program, commitment: Commitment = "processed") -> Optional[ +async def get_account_data_and_slot(address: Pubkey, program: Program, commitment: Commitment = "processed") -> Optional[ DataAndSlot[T]]: account_info = await program.provider.connection.get_account_info( address, @@ -17,11 +17,11 @@ async def get_account_data_and_slot(address: PublicKey, program: Program, commit commitment=commitment, ) - if not account_info["result"]["value"]: + if not account_info.value: return None - slot = account_info["result"]["context"]["slot"] - data = base64.b64decode(account_info["result"]["value"]["data"][0]) + slot = account_info.context.slot + data = account_info.value.data decoded_data = program.coder.accounts.decode(data) @@ -38,7 +38,7 @@ async def get_state_account(program: Program) -> State: async def get_if_stake_account( - program: Program, authority: PublicKey, spot_market_index: int + program: Program, authority: Pubkey, spot_market_index: int ) -> InsuranceFundStake: if_stake_pk = get_insurance_fund_stake_public_key( program.program_id, authority, spot_market_index @@ -49,7 +49,7 @@ async def get_if_stake_account( async def get_user_stats_account( program: Program, - authority: PublicKey, + authority: Pubkey, ) -> UserStats: user_stats_public_key = get_user_stats_account_public_key( program.program_id, @@ -60,13 +60,13 @@ async def get_user_stats_account( async def get_user_account_and_slot( program: Program, - user_public_key: PublicKey, + user_public_key: Pubkey, ) -> DataAndSlot[User]: return await get_account_data_and_slot(user_public_key, program) async def get_user_account( program: Program, - user_public_key: PublicKey, + user_public_key: Pubkey, ) -> User: return (await get_user_account_and_slot(program, user_public_key)).data diff --git a/src/driftpy/accounts/oracle.py b/src/driftpy/accounts/oracle.py index 127f4f6f..f6c8c373 100644 --- a/src/driftpy/accounts/oracle.py +++ b/src/driftpy/accounts/oracle.py @@ -1,10 +1,10 @@ -from solana.rpc.types import RPCResponse +from solders.rpc.responses import GetAccountInfoResp from .types import DataAndSlot from driftpy.constants.numeric_constants import * from driftpy.types import OracleSource, OraclePriceData -from solana.publickey import PublicKey +from solders.pubkey import Pubkey from pythclient.pythaccounts import PythPriceInfo, _ACCOUNT_HEADER_BYTES, EmaType from solana.rpc.async_api import AsyncClient import base64 @@ -13,11 +13,11 @@ def convert_pyth_price(price, scale=1): return int(price * PRICE_PRECISION * scale) -async def get_oracle_price_data_and_slot(connection: AsyncClient, address: PublicKey, oracle_source=OracleSource.PYTH()) -> DataAndSlot[ +async def get_oracle_price_data_and_slot(connection: AsyncClient, address: Pubkey, oracle_source=OracleSource.PYTH()) -> DataAndSlot[ OraclePriceData]: if 'Pyth' in str(oracle_source): rpc_reponse = await connection.get_account_info(address) - rpc_response_slot = rpc_reponse['result']['context']['slot'] + rpc_response_slot = rpc_reponse.context.slot (pyth_price_info, last_slot, twac, twap) = await _parse_pyth_price_info(rpc_reponse) scale = 1 @@ -41,10 +41,8 @@ async def get_oracle_price_data_and_slot(connection: AsyncClient, address: Publi else: raise NotImplementedError('Unsupported Oracle Source', str(oracle_source)) -async def _parse_pyth_price_info(resp: RPCResponse) -> (PythPriceInfo, int, int, int): - value = resp["result"].get("value") - data_base64, data_format = value["data"] - buffer = base64.b64decode(data_base64) +async def _parse_pyth_price_info(resp: GetAccountInfoResp) -> (PythPriceInfo, int, int, int): + buffer = resp.value.data offset = _ACCOUNT_HEADER_BYTES _, exponent, _ = struct.unpack_from(" Optional[DataAndS pass @abstractmethod - async def get_oracle_data_and_slot(self, oracle: PublicKey) -> Optional[DataAndSlot[OraclePriceData]]: + async def get_oracle_data_and_slot(self, oracle: Pubkey) -> Optional[DataAndSlot[OraclePriceData]]: pass class UserAccountSubscriber: diff --git a/src/driftpy/addresses.py b/src/driftpy/addresses.py index 613620df..048450bf 100644 --- a/src/driftpy/addresses.py +++ b/src/driftpy/addresses.py @@ -1,4 +1,4 @@ -from solana.publickey import PublicKey +from solders.pubkey import Pubkey def int_to_le_bytes(a: int): @@ -6,94 +6,94 @@ def int_to_le_bytes(a: int): def get_perp_market_public_key( - program_id: PublicKey, + program_id: Pubkey, market_index: int, -) -> PublicKey: - return PublicKey.find_program_address( +) -> Pubkey: + return Pubkey.find_program_address( [b"perp_market", int_to_le_bytes(market_index)], program_id )[0] def get_insurance_fund_vault_public_key( - program_id: PublicKey, + program_id: Pubkey, spot_market_index: int, -) -> PublicKey: - return PublicKey.find_program_address( +) -> Pubkey: + return Pubkey.find_program_address( [b"insurance_fund_vault", int_to_le_bytes(spot_market_index)], program_id )[0] def get_insurance_fund_stake_public_key( - program_id: PublicKey, - authority: PublicKey, + program_id: Pubkey, + authority: Pubkey, spot_market_index: int, -) -> PublicKey: - return PublicKey.find_program_address( +) -> Pubkey: + return Pubkey.find_program_address( [b"insurance_fund_stake", bytes(authority), int_to_le_bytes(spot_market_index)], program_id, )[0] def get_spot_market_public_key( - program_id: PublicKey, + program_id: Pubkey, spot_market_index: int, -) -> PublicKey: - return PublicKey.find_program_address( +) -> Pubkey: + return Pubkey.find_program_address( [b"spot_market", int_to_le_bytes(spot_market_index)], program_id )[0] def get_spot_market_vault_public_key( - program_id: PublicKey, + program_id: Pubkey, spot_market_index: int, -) -> PublicKey: - return PublicKey.find_program_address( +) -> Pubkey: + return Pubkey.find_program_address( [b"spot_market_vault", int_to_le_bytes(spot_market_index)], program_id )[0] def get_spot_market_vault_authority_public_key( - program_id: PublicKey, + program_id: Pubkey, spot_market_index: int, -) -> PublicKey: - return PublicKey.find_program_address( +) -> Pubkey: + return Pubkey.find_program_address( [b"spot_market_vault_authority", int_to_le_bytes(spot_market_index)], program_id )[0] def get_state_public_key( - program_id: PublicKey, -) -> PublicKey: - return PublicKey.find_program_address([b"drift_state"], program_id)[0] + program_id: Pubkey, +) -> Pubkey: + return Pubkey.find_program_address([b"drift_state"], program_id)[0] def get_drift_client_signer_public_key( - program_id: PublicKey, -) -> PublicKey: - return PublicKey.find_program_address([b"drift_signer"], program_id)[0] + program_id: Pubkey, +) -> Pubkey: + return Pubkey.find_program_address([b"drift_signer"], program_id)[0] def get_user_stats_account_public_key( - program_id: PublicKey, - authority: PublicKey, -) -> PublicKey: - return PublicKey.find_program_address( + program_id: Pubkey, + authority: Pubkey, +) -> Pubkey: + return Pubkey.find_program_address( [b"user_stats", bytes(authority)], program_id )[0] def get_user_account_public_key( - program_id: PublicKey, - authority: PublicKey, + program_id: Pubkey, + authority: Pubkey, user_id=0, -) -> PublicKey: - return PublicKey.find_program_address( +) -> Pubkey: + return Pubkey.find_program_address( [b"user", bytes(authority), int_to_le_bytes(user_id)], program_id )[0] -# program = PublicKey("9jwr5nC2f9yAraXrg4UzHXmCX3vi9FQkjD6p9e8bRqNa") -# auth = PublicKey("D78cqss3dbU1aJAs5qeuhLi8Rqa2CL4Kzkr3VzdgN5F6") +# program = Pubkey("9jwr5nC2f9yAraXrg4UzHXmCX3vi9FQkjD6p9e8bRqNa") +# auth = Pubkey("D78cqss3dbU1aJAs5qeuhLi8Rqa2CL4Kzkr3VzdgN5F6") # == EjQ8rFmR4hd9faX1TYLkqCTsAkyjJ4qUKBuagtmVG3cP # get_user_account_public_key( # program, diff --git a/src/driftpy/admin.py b/src/driftpy/admin.py index d6c2be8e..9c8c634b 100644 --- a/src/driftpy/admin.py +++ b/src/driftpy/admin.py @@ -1,9 +1,9 @@ -from solana.publickey import PublicKey -from solana.transaction import TransactionSignature -from solana.keypair import Keypair -from solana.system_program import SYS_PROGRAM_ID -from solana.sysvar import SYSVAR_RENT_PUBKEY +from solders.pubkey import Pubkey +from solders.signature import Signature +from solders.keypair import Keypair +from solders.system_program import ID +from solders.sysvar import RENT from spl.token.constants import TOKEN_PROGRAM_ID from anchorpy import Program, Provider, Context @@ -55,15 +55,15 @@ def from_config( async def initialize( self, - usdc_mint: PublicKey, + usdc_mint: Pubkey, admin_controls_prices: bool, - ) -> tuple[TransactionSignature, TransactionSignature]: + ) -> tuple[Signature, Signature]: state_account_rpc_response = ( await self.program.provider.connection.get_account_info( get_state_public_key(self.program_id) ) ) - if state_account_rpc_response["result"]["value"] is not None: + if state_account_rpc_response.value is not None: raise RuntimeError("Drift Client already initialized") state_public_key = get_state_public_key(self.program_id) @@ -77,8 +77,8 @@ async def initialize( "drift_signer": get_drift_client_signer_public_key( self.program_id ), - "rent": SYSVAR_RENT_PUBKEY, - "system_program": SYS_PROGRAM_ID, + "rent": RENT, + "system_program": ID, "token_program": TOKEN_PROGRAM_ID, }, ), @@ -89,7 +89,7 @@ async def initialize( async def initialize_perp_market( self, market_index: int, - price_oracle: PublicKey, + price_oracle: Pubkey, base_asset_reserve: int, quote_asset_reserve: int, periodicity: int, @@ -100,7 +100,7 @@ async def initialize_perp_market( liquidation_fee: int = 0, active_status: bool = True, name: list = [0] * 32, - ) -> TransactionSignature: + ) -> Signature: state_public_key = get_state_public_key(self.program.program_id) state = await get_state_account(self.program) market_pubkey = get_perp_market_public_key( @@ -126,19 +126,19 @@ async def initialize_perp_market( "state": state_public_key, "oracle": price_oracle, "perp_market": market_pubkey, - "rent": SYSVAR_RENT_PUBKEY, - "system_program": SYS_PROGRAM_ID, + "rent": RENT, + "system_program": ID, } ), ) async def initialize_spot_market( self, - mint: PublicKey, + mint: Pubkey, optimal_utilization: int = SPOT_RATE_PRECISION // 2, optimal_rate: int = SPOT_RATE_PRECISION, max_rate: int = SPOT_RATE_PRECISION, - oracle: PublicKey = PublicKey([0] * PublicKey.LENGTH), + oracle: Pubkey = Pubkey([0] * Pubkey.LENGTH), oracle_source: OracleSource = OracleSource.QUOTE_ASSET(), initial_asset_weight: int = SPOT_WEIGHT_PRECISION, maintenance_asset_weight: int = SPOT_WEIGHT_PRECISION, @@ -186,8 +186,8 @@ async def initialize_spot_market( ), "spot_market_mint": mint, "oracle": oracle, - "rent": SYSVAR_RENT_PUBKEY, - "system_program": SYS_PROGRAM_ID, + "rent": RENT, + "system_program": ID, "token_program": TOKEN_PROGRAM_ID, } ), diff --git a/src/driftpy/constants/banks.py b/src/driftpy/constants/banks.py index c399152c..f0a4c013 100644 --- a/src/driftpy/constants/banks.py +++ b/src/driftpy/constants/banks.py @@ -1,38 +1,38 @@ from dataclasses import dataclass from driftpy.types import OracleSource -from solana.publickey import PublicKey +from solders.pubkey import Pubkey @dataclass class Bank: symbol: str bank_index: int - oracle: PublicKey + oracle: Pubkey oracle_source: OracleSource - mint: PublicKey + mint: Pubkey devnet_banks: list[Bank] = [ Bank( symbol="USDC", bank_index=0, - oracle=PublicKey(0), + oracle=Pubkey.default(), oracle_source=OracleSource.QUOTE_ASSET, - mint=PublicKey("8zGuJQqwhZafTah7Uc7Z4tXRnguqkn5KLFAP8oV6PHe2"), + mint=Pubkey.from_string("8zGuJQqwhZafTah7Uc7Z4tXRnguqkn5KLFAP8oV6PHe2"), ), Bank( symbol="SOL", bank_index=1, - oracle=PublicKey("J83w4HKfqxwcq3BEMMkPFSppX3gqekLyLJBexebFVkix"), + oracle=Pubkey.from_string("J83w4HKfqxwcq3BEMMkPFSppX3gqekLyLJBexebFVkix"), oracle_source=OracleSource.PYTH, - mint=PublicKey("So11111111111111111111111111111111111111112"), + mint=Pubkey.from_string("So11111111111111111111111111111111111111112"), ), Bank( symbol="BTC", bank_index=2, - oracle=PublicKey("HovQMDrbAgAYPCmHVSrezcSmkMtXSSUsLDFANExrZh2J"), + oracle=Pubkey.from_string("HovQMDrbAgAYPCmHVSrezcSmkMtXSSUsLDFANExrZh2J"), oracle_source=OracleSource.PYTH, - mint=PublicKey("3BZPwbcqB5kKScF3TEXxwNfx5ipV13kbRVDvfVp5c6fv"), + mint=Pubkey.from_string("3BZPwbcqB5kKScF3TEXxwNfx5ipV13kbRVDvfVp5c6fv"), ), ] @@ -40,15 +40,15 @@ class Bank: Bank( symbol="USDC", bank_index=0, - oracle=PublicKey(0), + oracle=Pubkey.default(), oracle_source=OracleSource.QUOTE_ASSET, - mint=PublicKey("8zGuJQqwhZafTah7Uc7Z4tXRnguqkn5KLFAP8oV6PHe2"), + mint=Pubkey.from_string("8zGuJQqwhZafTah7Uc7Z4tXRnguqkn5KLFAP8oV6PHe2"), ), Bank( symbol="SOL", bank_index=1, - oracle=PublicKey("H6ARHf6YXhGYeQfUzQNGk6rDNnLBQKrenN712K4AQJEG"), + oracle=Pubkey.from_string("H6ARHf6YXhGYeQfUzQNGk6rDNnLBQKrenN712K4AQJEG"), oracle_source=OracleSource.PYTH, - mint=PublicKey("So11111111111111111111111111111111111111112"), + mint=Pubkey.from_string("So11111111111111111111111111111111111111112"), ), ] diff --git a/src/driftpy/constants/config.py b/src/driftpy/constants/config.py index 2262d04e..3ac44389 100644 --- a/src/driftpy/constants/config.py +++ b/src/driftpy/constants/config.py @@ -1,15 +1,15 @@ from driftpy.constants.banks import devnet_banks, mainnet_banks, Bank from driftpy.constants.markets import devnet_markets, mainnet_markets, Market from dataclasses import dataclass -from solana.publickey import PublicKey +from solders.pubkey import Pubkey @dataclass class Config: env: str - pyth_oracle_mapping_address: PublicKey - drift_client_program_id: PublicKey - usdc_mint_address: PublicKey + pyth_oracle_mapping_address: Pubkey + drift_client_program_id: Pubkey + usdc_mint_address: Pubkey default_http: str default_ws: str markets: list[Market] @@ -19,13 +19,13 @@ class Config: configs = { "devnet": Config( env="devnet", - pyth_oracle_mapping_address=PublicKey( + pyth_oracle_mapping_address=Pubkey.from_string( "BmA9Z6FjioHJPpjT39QazZyhDRUdZy2ezwx4GiDdE2u2" ), - drift_client_program_id=PublicKey( + drift_client_program_id=Pubkey.from_string( "dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH" ), - usdc_mint_address=PublicKey("8zGuJQqwhZafTah7Uc7Z4tXRnguqkn5KLFAP8oV6PHe2"), + usdc_mint_address=Pubkey.from_string("8zGuJQqwhZafTah7Uc7Z4tXRnguqkn5KLFAP8oV6PHe2"), default_http="https://api.devnet.solana.com", default_ws="wss://api.devnet.solana.com", markets=devnet_markets, @@ -33,13 +33,13 @@ class Config: ), "mainnet": Config( env="mainnet", - pyth_oracle_mapping_address=PublicKey( + pyth_oracle_mapping_address=Pubkey.from_string( "AHtgzX45WTKfkPG53L6WYhGEXwQkN1BVknET3sVsLL8J" ), - drift_client_program_id=PublicKey( + drift_client_program_id=Pubkey.from_string( "dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH" ), - usdc_mint_address=PublicKey("EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v"), + usdc_mint_address=Pubkey.from_string("EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v"), default_http="https://api.mainnet-beta.solana.com", default_ws="wss://api.mainnet-beta.solana.com", markets=mainnet_markets, diff --git a/src/driftpy/constants/markets.py b/src/driftpy/constants/markets.py index b35b394d..0e03dd51 100644 --- a/src/driftpy/constants/markets.py +++ b/src/driftpy/constants/markets.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from solana.publickey import PublicKey +from solders.pubkey import Pubkey @dataclass @@ -7,7 +7,7 @@ class Market: symbol: str base_asset_symbol: str market_index: int - pyth_oracle: PublicKey + pyth_oracle: Pubkey devnet_markets: list[Market] = [ @@ -15,67 +15,67 @@ class Market: base_asset_symbol="SOL", symbol="SOL-PERP", market_index=0, - pyth_oracle=PublicKey("J83w4HKfqxwcq3BEMMkPFSppX3gqekLyLJBexebFVkix"), + pyth_oracle=Pubkey.from_string("J83w4HKfqxwcq3BEMMkPFSppX3gqekLyLJBexebFVkix"), ), Market( base_asset_symbol="BTC", symbol="BTC-PERP", market_index=1, - pyth_oracle=PublicKey("HovQMDrbAgAYPCmHVSrezcSmkMtXSSUsLDFANExrZh2J"), + pyth_oracle=Pubkey.from_string("HovQMDrbAgAYPCmHVSrezcSmkMtXSSUsLDFANExrZh2J"), ), Market( base_asset_symbol="ETH", symbol="ETH-PERP", market_index=2, - pyth_oracle=PublicKey("EdVCmQ9FSPcVe5YySXDPCRmc8aDQLKJ9xvYBMZPie1Vw"), + pyth_oracle=Pubkey.from_string("EdVCmQ9FSPcVe5YySXDPCRmc8aDQLKJ9xvYBMZPie1Vw"), ), Market( base_asset_symbol="APT", symbol="APT-PERP", market_index=3, - pyth_oracle=PublicKey("5d2QJ6u2NveZufmJ4noHja5EHs3Bv1DUMPLG5xfasSVs"), + pyth_oracle=Pubkey.from_string("5d2QJ6u2NveZufmJ4noHja5EHs3Bv1DUMPLG5xfasSVs"), ), Market( symbol="1MBONK-PERP", base_asset_symbol="1MBONK", market_index=4, - pyth_oracle=PublicKey("6bquU99ktV1VRiHDr8gMhDFt3kMfhCQo5nfNrg2Urvsn"), + pyth_oracle=Pubkey.from_string("6bquU99ktV1VRiHDr8gMhDFt3kMfhCQo5nfNrg2Urvsn"), ), Market( symbol="MATIC-PERP", base_asset_symbol="MATIC", market_index=5, - pyth_oracle=PublicKey("FBirwuDFuRAu4iSGc7RGxN5koHB7EJM1wbCmyPuQoGur"), + pyth_oracle=Pubkey.from_string("FBirwuDFuRAu4iSGc7RGxN5koHB7EJM1wbCmyPuQoGur"), ), Market( symbol="ARB-PERP", base_asset_symbol="ARB", market_index=6, - pyth_oracle=PublicKey("4mRGHzjGerQNWKXyQAmr9kWqb9saPPHKqo1xziXGQ5Dh"), + pyth_oracle=Pubkey.from_string("4mRGHzjGerQNWKXyQAmr9kWqb9saPPHKqo1xziXGQ5Dh"), ), Market( symbol="DOGE-PERP", base_asset_symbol="DOGE", market_index=7, - pyth_oracle=PublicKey("4L6YhY8VvUgmqG5MvJkUJATtzB2rFqdrJwQCmFLv4Jzy"), + pyth_oracle=Pubkey.from_string("4L6YhY8VvUgmqG5MvJkUJATtzB2rFqdrJwQCmFLv4Jzy"), ), Market( symbol="BNB-PERP", base_asset_symbol="BNB", market_index=8, - pyth_oracle=PublicKey("GwzBgrXb4PG59zjce24SF2b9JXbLEjJJTBkmytuEZj1b"), + pyth_oracle=Pubkey.from_string("GwzBgrXb4PG59zjce24SF2b9JXbLEjJJTBkmytuEZj1b"), ), Market( symbol="SUI-PERP", base_asset_symbol="SUI", market_index=9, - pyth_oracle=PublicKey("6SK9vS8eMSSj3LUX2dPku93CrNv8xLCp9ng39F39h7A5"), + pyth_oracle=Pubkey.from_string("6SK9vS8eMSSj3LUX2dPku93CrNv8xLCp9ng39F39h7A5"), ), Market( symbol="1MPEPE-PERP", base_asset_symbol="1MPEPE", market_index=10, - pyth_oracle=PublicKey("Gz9RfgDeAFSsH7BHDGyNTgCik74rjNwsodJpsCizzmkj"), + pyth_oracle=Pubkey.from_string("Gz9RfgDeAFSsH7BHDGyNTgCik74rjNwsodJpsCizzmkj"), ), ] @@ -84,66 +84,66 @@ class Market: symbol="SOL-PERP", base_asset_symbol="SOL", market_index=0, - pyth_oracle=PublicKey("H6ARHf6YXhGYeQfUzQNGk6rDNnLBQKrenN712K4AQJEG"), + pyth_oracle=Pubkey.from_string("H6ARHf6YXhGYeQfUzQNGk6rDNnLBQKrenN712K4AQJEG"), ), Market( symbol="BTC-PERP", base_asset_symbol="BTC", market_index=1, - pyth_oracle=PublicKey("GVXRSBjFk6e6J3NbVPXohDJetcTjaeeuykUpbQF8UoMU"), + pyth_oracle=Pubkey.from_string("GVXRSBjFk6e6J3NbVPXohDJetcTjaeeuykUpbQF8UoMU"), ), Market( symbol="ETH-PERP", base_asset_symbol="ETH", market_index=2, - pyth_oracle=PublicKey("JBu1AL4obBcCMqKBBxhpWCNUt136ijcuMZLFvTP7iWdB"), + pyth_oracle=Pubkey.from_string("JBu1AL4obBcCMqKBBxhpWCNUt136ijcuMZLFvTP7iWdB"), ), Market( symbol="APT-PERP", base_asset_symbol="APT", market_index=3, - pyth_oracle=PublicKey("FNNvb1AFDnDVPkocEri8mWbJ1952HQZtFLuwPiUjSJQ"), + pyth_oracle=Pubkey.from_string("FNNvb1AFDnDVPkocEri8mWbJ1952HQZtFLuwPiUjSJQ"), ), Market( symbol="1MBONK", base_asset_symbol="1MBONK", market_index=4, - pyth_oracle=PublicKey("8ihFLu5FimgTQ1Unh4dVyEHUGodJ5gJQCrQf4KUVB9bN"), + pyth_oracle=Pubkey.from_string("8ihFLu5FimgTQ1Unh4dVyEHUGodJ5gJQCrQf4KUVB9bN"), ), Market( symbol="MATIC-PERP", base_asset_symbol="MATIC", market_index=5, - pyth_oracle=PublicKey("7KVswB9vkCgeM3SHP7aGDijvdRAHK8P5wi9JXViCrtYh"), + pyth_oracle=Pubkey.from_string("7KVswB9vkCgeM3SHP7aGDijvdRAHK8P5wi9JXViCrtYh"), ), Market( symbol="ARB-PERP", base_asset_symbol="ARB", market_index=6, - pyth_oracle=PublicKey("5HRrdmghsnU3i2u5StaKaydS7eq3vnKVKwXMzCNKsc4C"), + pyth_oracle=Pubkey.from_string("5HRrdmghsnU3i2u5StaKaydS7eq3vnKVKwXMzCNKsc4C"), ), Market( symbol="DOGE-PERP", base_asset_symbol="DOGE", market_index=7, - pyth_oracle=PublicKey("FsSM3s38PX9K7Dn6eGzuE29S2Dsk1Sss1baytTQdCaQj"), + pyth_oracle=Pubkey.from_string("FsSM3s38PX9K7Dn6eGzuE29S2Dsk1Sss1baytTQdCaQj"), ), Market( symbol="BNB-PERP", base_asset_symbol="BNB", market_index=8, - pyth_oracle=PublicKey("4CkQJBxhU8EZ2UjhigbtdaPbpTe6mqf811fipYBFbSYN"), + pyth_oracle=Pubkey.from_string("4CkQJBxhU8EZ2UjhigbtdaPbpTe6mqf811fipYBFbSYN"), ), Market( symbol="SUI-PERP", base_asset_symbol="SUI", market_index=9, - pyth_oracle=PublicKey("3Qub3HaAJaa2xNY7SUqPKd3vVwTqDfDDkEUMPjXD2c1q"), + pyth_oracle=Pubkey.from_string("3Qub3HaAJaa2xNY7SUqPKd3vVwTqDfDDkEUMPjXD2c1q"), ), Market( symbol="1MPEPE-PERP", base_asset_symbol="1MPEPE", market_index=10, - pyth_oracle=PublicKey("FSfxunDmjjbDV2QxpyxFCAPKmYJHSLnLuvQXDLkMzLBm"), + pyth_oracle=Pubkey.from_string("FSfxunDmjjbDV2QxpyxFCAPKmYJHSLnLuvQXDLkMzLBm"), ), ] diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 56ccd00f..f86b0702 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -1,12 +1,13 @@ -from solana.publickey import PublicKey +from solders.pubkey import Pubkey import json from typing import Optional -from solana.publickey import PublicKey -from solana.keypair import Keypair -from solana.transaction import Transaction, TransactionInstruction -from solana.system_program import SYS_PROGRAM_ID -from solana.sysvar import SYSVAR_RENT_PUBKEY +from solders.keypair import Keypair +from solana.transaction import Transaction +from solders.instruction import Instruction +from solders.system_program import ID +from solders.sysvar import RENT from solana.transaction import AccountMeta +from solders.compute_budget import set_compute_unit_limit, set_compute_unit_price from spl.token.constants import TOKEN_PROGRAM_ID from anchorpy import Program, Context, Idl, Provider from struct import pack_into @@ -29,14 +30,12 @@ DEFAULT_USER_NAME = "Main Account" -DEFAULT_PUBKEY = PublicKey("11111111111111111111111111111111") - class DriftClient: """This class is the main way to interact with Drift Protocol including depositing, opening new positions, closing positions, placing orders, etc. """ - def __init__(self, program: Program, signer: Keypair = None, authority: PublicKey = None, account_subscriber: Optional[DriftClientAccountSubscriber] = None): + def __init__(self, program: Program, signer: Keypair = None, authority: Pubkey = None, account_subscriber: Optional[DriftClientAccountSubscriber] = None, tx_params: Optional[TxParams] = None): """Initializes the drift client object -- likely want to use the .from_config method instead of this one Args: @@ -51,7 +50,7 @@ def __init__(self, program: Program, signer: Keypair = None, authority: PublicKe signer = program.provider.wallet.payer if authority is None: - authority = signer.public_key + authority = signer.pubkey() self.signer = signer self.authority = authority @@ -65,6 +64,11 @@ def __init__(self, program: Program, signer: Keypair = None, authority: PublicKe self.account_subscriber = account_subscriber + if tx_params is None: + tx_params = TxParams(600_000, 0) + + self.tx_params = tx_params + @staticmethod def from_config(config: Config, provider: Provider, authority: Keypair = None): """Initializes the drift client object from a Config @@ -99,7 +103,7 @@ def from_config(config: Config, provider: Provider, authority: Keypair = None): return drift_client - def get_user_account_public_key(self, user_id=0) -> PublicKey: + def get_user_account_public_key(self, user_id=0) -> Pubkey: return get_user_account_public_key(self.program_id, self.authority, user_id) async def get_user(self, user_id=0) -> User: @@ -123,26 +127,38 @@ async def get_spot_market(self, market_index: int) -> Optional[SpotMarket]: spot_market_and_slot = await self.account_subscriber.get_spot_market_and_slot(market_index) return getattr(spot_market_and_slot, 'data', None) - async def get_oracle_price_data(self, oracle: PublicKey) -> Optional[OraclePriceData]: + async def get_oracle_price_data(self, oracle: Pubkey) -> Optional[OraclePriceData]: oracle_price_data_and_slot = await self.account_subscriber.get_oracle_data_and_slot(oracle) return getattr(oracle_price_data_and_slot, 'data', None) async def send_ixs( self, - ixs: Union[TransactionInstruction, list[TransactionInstruction]], + ixs: Union[Instruction, list[Instruction]], signers=None, ): - if isinstance(ixs, TransactionInstruction): + if isinstance(ixs, Instruction): ixs = [ixs] tx = Transaction() - for ix in ixs: - tx.add(ix) - if signers is None: - signers = self.signers + if self.tx_params.compute_units is not None: + tx.add(set_compute_unit_limit(self.tx_params.compute_units)) + + if self.tx_params.compute_units_price is not None: + tx.add(set_compute_unit_price(self.tx_params.compute_units_price)) + + [tx.add(ix) for ix in ixs] + + latest_blockhash = (await self.program.provider.connection.get_latest_blockhash()).value.blockhash + tx.recent_blockhash = latest_blockhash + tx.fee_payer = self.signer.pubkey() + + tx.sign_partial(self.signer) + + if signers is not None: + [tx.sign_partial(signer) for signer in signers] - return await self.program.provider.send(tx, signers=signers) + return await self.program.provider.send(tx) async def intialize_user(self, user_id: int = 0): """intializes a drift user @@ -173,15 +189,15 @@ def get_initialize_user_stats( "state": state_public_key, "authority": self.authority, "payer": self.authority, - "rent": SYSVAR_RENT_PUBKEY, - "system_program": SYS_PROGRAM_ID, + "rent": RENT, + "system_program": ID, }, ), ) def get_initialize_user_instructions( self, user_id: int = 0, name: str = DEFAULT_USER_NAME - ) -> TransactionInstruction: + ) -> Instruction: user_public_key = self.get_user_account_public_key(user_id) state_public_key = self.get_state_public_key() user_stats_public_key = self.get_user_stats_public_key() @@ -211,8 +227,8 @@ def get_initialize_user_instructions( "state": state_public_key, "authority": self.authority, "payer": self.authority, - "rent": SYSVAR_RENT_PUBKEY, - "system_program": SYS_PROGRAM_ID, + "rent": RENT, + "system_program": ID, }, ), ) @@ -226,11 +242,11 @@ async def get_remaining_accounts( user_id=[0], include_oracles: bool = True, include_spot_markets: bool = True, - authority: Optional[Union[PublicKey, Sequence[PublicKey]]] = None, + authority: Optional[Union[Pubkey, Sequence[Pubkey]]] = None, ): if authority is None: authority = [self.authority] - elif isinstance(authority, PublicKey): + elif isinstance(authority, Pubkey): authority = [authority] if isinstance(user_id, int): @@ -259,7 +275,7 @@ async def track_market(market_index, is_writable): spot_market = await self.get_spot_market( perp_market.quote_spot_market_index ) - if spot_market.oracle != DEFAULT_PUBKEY: + if spot_market.oracle != Pubkey.default(): oracle_map[str(spot_market.oracle)] = AccountMeta( pubkey=spot_market.oracle, is_signer=False, is_writable=False ) @@ -275,7 +291,7 @@ async def track_spot_market(spot_market_index, is_writable): is_writable=is_writable, ) - if include_oracles and spot_market.oracle != DEFAULT_PUBKEY: + if include_oracles and spot_market.oracle != Pubkey.default(): oracle_map[str(spot_market.pubkey)] = AccountMeta( pubkey=spot_market.oracle, is_signer=False, is_writable=False ) @@ -326,7 +342,7 @@ async def withdraw( self, amount: int, spot_market_index: int, - user_token_account: PublicKey, + user_token_account: Pubkey, reduce_only: bool = False, user_id: int = 0, ): @@ -335,7 +351,7 @@ async def withdraw( Args: amount (int): amount to withdraw spot_market_index (int): - user_token_account (PublicKey): ata of the account to withdraw to + user_token_account (Pubkey): ata of the account to withdraw to reduce_only (bool, optional): if True will only withdraw existing funds else if False will allow taking out borrows. Defaults to False. user_id (int, optional): subaccount. Defaults to 0. @@ -354,7 +370,7 @@ async def get_withdraw_collateral_ix( self, amount: int, spot_market_index: int, - user_token_account: PublicKey, + user_token_account: Pubkey, reduce_only: bool = False, user_id: int = 0, ): @@ -390,7 +406,7 @@ async def deposit( self, amount: int, spot_market_index: int, - user_token_account: PublicKey, + user_token_account: Pubkey, user_id: int = 0, reduce_only=False, user_initialized=True, @@ -400,7 +416,7 @@ async def deposit( Args: amount (int): amount to deposit spot_market_index (int): - user_token_account (PublicKey): + user_token_account (Pubkey): user_id (int, optional): subaccount to deposit into. Defaults to 0. reduce_only (bool, optional): paying back borrow vs depositing new assets. Defaults to False. user_initialized (bool, optional): if need to initialize user account too set this to False. Defaults to True. @@ -425,11 +441,11 @@ async def get_deposit_collateral_ix( self, amount: int, spot_market_index: int, - user_token_account: PublicKey, + user_token_account: Pubkey, user_id: int = 0, reduce_only=False, user_initialized=True, - ) -> TransactionInstruction: + ) -> Instruction: if user_initialized: remaining_accounts = await self.get_remaining_accounts( writable_spot_market_index=spot_market_index, user_id=user_id @@ -640,8 +656,8 @@ async def get_open_position_ix( ix = await self.get_place_and_take_ix(order, subaccount_id=user_id) return ix - def get_increase_compute_ix(self) -> TransactionInstruction: - program_id = PublicKey("ComputeBudget111111111111111111111111111111") + def get_increase_compute_ix(self) -> Instruction: + program_id = Pubkey("ComputeBudget111111111111111111111111111111") name_bytes = bytearray(1 + 4 + 4) pack_into("B", name_bytes, 0, 0) @@ -649,7 +665,7 @@ def get_increase_compute_ix(self) -> TransactionInstruction: pack_into("I", name_bytes, 5, 0) data = bytes(name_bytes) - compute_ix = TransactionInstruction([], program_id, data) + compute_ix = Instruction(program_id, data, []) return compute_ix @@ -864,7 +880,7 @@ async def get_place_and_take_ix( async def settle_lp( self, - settlee_authority: PublicKey, + settlee_authority: Pubkey, market_index: int, user_id: int = 0, ): @@ -874,7 +890,7 @@ async def settle_lp( ) async def get_settle_lp_ix( - self, settlee_authority: PublicKey, market_index: int, user_id: int = 0 + self, settlee_authority: Pubkey, market_index: int, user_id: int = 0 ): remaining_accounts = await self.get_remaining_accounts( writable_market_index=market_index, @@ -990,7 +1006,7 @@ def default_order_params( async def liquidate_spot( self, - user_authority: PublicKey, + user_authority: Pubkey, asset_market_index: int, liability_market_index: int, max_liability_transfer: int, @@ -1012,7 +1028,7 @@ async def liquidate_spot( async def get_liquidate_spot_ix( self, - user_authority: PublicKey, + user_authority: Pubkey, asset_market_index: int, liability_market_index: int, max_liability_transfer: int, @@ -1057,7 +1073,7 @@ async def get_liquidate_spot_ix( async def liquidate_perp( self, - user_authority: PublicKey, + user_authority: Pubkey, market_index: int, max_base_asset_amount: int, limit_price: Optional[int] = None, @@ -1079,7 +1095,7 @@ async def liquidate_perp( async def get_liquidate_perp_ix( self, - user_authority: PublicKey, + user_authority: Pubkey, market_index: int, max_base_asset_amount: int, limit_price: Optional[int] = None, @@ -1122,7 +1138,7 @@ async def get_liquidate_perp_ix( async def liquidate_perp_pnl_for_deposit( self, - user_authority: PublicKey, + user_authority: Pubkey, perp_market_index: int, spot_market_index: int, max_pnl_transfer: int, @@ -1142,7 +1158,7 @@ async def liquidate_perp_pnl_for_deposit( async def get_liquidate_perp_pnl_for_deposit_ix( self, - user_authority: PublicKey, + user_authority: Pubkey, perp_market_index: int, spot_market_index: int, max_pnl_transfer: int, @@ -1189,7 +1205,7 @@ async def get_liquidate_perp_pnl_for_deposit_ix( async def settle_pnl( self, - user_authority: PublicKey, + user_authority: Pubkey, market_index: int, user_id: int = 0, ): @@ -1199,7 +1215,7 @@ async def settle_pnl( async def get_settle_pnl_ix( self, - user_authority: PublicKey, + user_authority: Pubkey, market_index: int, user_id: int = 0, ): @@ -1232,7 +1248,7 @@ async def get_settle_pnl_ix( async def resolve_spot_bankruptcy( self, - user_authority: PublicKey, + user_authority: Pubkey, spot_market_index: int, user_subaccount_id: int = 0, liq_subaccount_id: int = 0, @@ -1250,7 +1266,7 @@ async def resolve_spot_bankruptcy( async def get_resolve_spot_bankruptcy_ix( self, - user_authority: PublicKey, + user_authority: Pubkey, spot_market_index: int, user_subaccount_id: int = 0, liq_subaccount_id: int = 0, @@ -1301,7 +1317,7 @@ async def get_resolve_spot_bankruptcy_ix( async def resolve_perp_bankruptcy( self, - user_authority: PublicKey, + user_authority: Pubkey, market_index: int, user_subaccount_id: int = 0, liq_subaccount_id: int = 0, @@ -1319,7 +1335,7 @@ async def resolve_perp_bankruptcy( async def get_resolve_perp_bankruptcy_ix( self, - user_authority: PublicKey, + user_authority: Pubkey, market_index: int, user_subaccount_id: int = 0, liq_subaccount_id: int = 0, @@ -1626,8 +1642,8 @@ def get_initialize_insurance_fund_stake_ix( "state": get_state_public_key(self.program_id), "authority": self.authority, "payer": self.authority, - "rent": SYSVAR_RENT_PUBKEY, - "system_program": SYS_PROGRAM_ID, + "rent": RENT, + "system_program": ID, } ), ) diff --git a/src/driftpy/sdk_types.py b/src/driftpy/sdk_types.py index aa748367..1c0af8c8 100644 --- a/src/driftpy/sdk_types.py +++ b/src/driftpy/sdk_types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from solana.publickey import PublicKey +from solders.pubkey import Pubkey from borsh_construct.enum import _rust_enum from sumtypes import constructor from driftpy.types import * @@ -13,5 +13,5 @@ class AssetType: @dataclass class MakerInfo: - maker: PublicKey + maker: Pubkey order: Order diff --git a/src/driftpy/setup/helpers.py b/src/driftpy/setup/helpers.py index 262f66fd..6d9c07a4 100644 --- a/src/driftpy/setup/helpers.py +++ b/src/driftpy/setup/helpers.py @@ -2,20 +2,17 @@ from dataclasses import dataclass from typing import Optional from construct import Int32sl, Int64ul -from solana.keypair import Keypair -from solana.publickey import PublicKey -from solana.system_program import create_account, CreateAccountParams +from solders.keypair import Keypair +from solders.pubkey import Pubkey +from solders.system_program import create_account, CreateAccountParams from anchorpy import Program, Context, Provider -from solana.keypair import Keypair from solana.transaction import Transaction -from solana.system_program import create_account, CreateAccountParams from spl.token.constants import TOKEN_PROGRAM_ID from spl.token._layouts import MINT_LAYOUT from spl.token.async_client import AsyncToken from spl.token.instructions import initialize_mint, InitializeMintParams import math -from solana.system_program import create_account, CreateAccountParams from spl.token.async_client import AsyncToken from spl.token._layouts import ACCOUNT_LAYOUT from spl.token.constants import TOKEN_PROGRAM_ID @@ -25,7 +22,7 @@ mint_to, MintToParams, ) -from solana.transaction import TransactionSignature +from solana.transaction import Signature from driftpy.sdk_types import AssetType from driftpy.types import * @@ -63,13 +60,13 @@ async def adjust_oracle_pretrade( async def _airdrop_user( provider: Provider, user: Optional[Keypair] = None -) -> tuple[Keypair, TransactionSignature]: +) -> tuple[Keypair, Signature]: if user is None: user = Keypair() resp = await provider.connection.request_airdrop( - user.public_key, 100_0 * 1000000000 + user.pubkey(), 100_0 * 1000000000 ) - tx_sig = resp["result"] + tx_sig = resp.value return user, tx_sig @@ -77,42 +74,50 @@ async def _create_mint(provider: Provider) -> Keypair: fake_create_mint = Keypair() params = CreateAccountParams( from_pubkey=provider.wallet.public_key, - new_account_pubkey=fake_create_mint.public_key, + to_pubkey=fake_create_mint.pubkey(), lamports=await AsyncToken.get_min_balance_rent_for_exempt_for_mint( provider.connection ), space=MINT_LAYOUT.sizeof(), - program_id=TOKEN_PROGRAM_ID, + owner=TOKEN_PROGRAM_ID, ) create_create_mint_account_ix = create_account(params) init_collateral_mint_ix = initialize_mint( InitializeMintParams( decimals=6, program_id=TOKEN_PROGRAM_ID, - mint=fake_create_mint.public_key, + mint=fake_create_mint.pubkey(), mint_authority=provider.wallet.public_key, freeze_authority=None, ) ) - fake_tx = Transaction().add(create_create_mint_account_ix, init_collateral_mint_ix) - await provider.send(fake_tx, [fake_create_mint]) + + fake_tx = Transaction( + instructions=[create_create_mint_account_ix, init_collateral_mint_ix], + recent_blockhash=(await provider.connection.get_latest_blockhash()).value.blockhash, + fee_payer=provider.wallet.public_key, + ) + + fake_tx.sign_partial(fake_create_mint) + provider.wallet.sign_transaction(fake_tx) + await provider.send(fake_tx) return fake_create_mint async def _create_user_ata_tx( - account: Keypair, provider: Provider, mint: Keypair, owner: PublicKey + account: Keypair, provider: Provider, mint: Keypair, owner: Pubkey ) -> Transaction: fake_tx = Transaction() create_token_account_ix = create_account( CreateAccountParams( from_pubkey=provider.wallet.public_key, - new_account_pubkey=account.public_key, + to_pubkey=account.pubkey(), lamports=await AsyncToken.get_min_balance_rent_for_exempt_for_account( provider.connection ), space=ACCOUNT_LAYOUT.sizeof(), - program_id=TOKEN_PROGRAM_ID, + owner=TOKEN_PROGRAM_ID, ) ) fake_tx.add(create_token_account_ix) @@ -120,8 +125,8 @@ async def _create_user_ata_tx( init_token_account_ix = initialize_account( InitializeAccountParams( program_id=TOKEN_PROGRAM_ID, - account=account.public_key, - mint=mint.public_key, + account=account.pubkey(), + mint=mint.pubkey(), owner=owner, ) ) @@ -131,10 +136,10 @@ async def _create_user_ata_tx( def mint_ix( - usdc_mint: PublicKey, - mint_auth: PublicKey, + usdc_mint: Pubkey, + mint_auth: Pubkey, usdc_amount: int, - ata_account: PublicKey, + ata_account: Pubkey, ) -> Transaction: mint_to_user_account_tx = mint_to( MintToParams( @@ -153,14 +158,14 @@ def _mint_usdc_tx( usdc_mint: Keypair, provider: Provider, usdc_amount: int, - ata_account: PublicKey, + ata_account: Pubkey, ) -> Transaction: fake_usdc_tx = Transaction() mint_to_user_account_tx = mint_to( MintToParams( program_id=TOKEN_PROGRAM_ID, - mint=usdc_mint.public_key, + mint=usdc_mint.pubkey(), dest=ata_account, mint_authority=provider.wallet.public_key, signers=[], @@ -173,7 +178,7 @@ def _mint_usdc_tx( async def _create_and_mint_user_usdc( - usdc_mint: Keypair, provider: Provider, usdc_amount: int, owner: PublicKey + usdc_mint: Keypair, provider: Provider, usdc_amount: int, owner: Pubkey ) -> Keypair: usdc_account = Keypair() @@ -184,20 +189,26 @@ async def _create_and_mint_user_usdc( owner, ) mint_tx: Transaction = _mint_usdc_tx( - usdc_mint, provider, usdc_amount, usdc_account.public_key + usdc_mint, provider, usdc_amount, usdc_account.pubkey() ) for ix in mint_tx.instructions: ata_tx.add(ix) - await provider.send(ata_tx, [provider.wallet.payer, usdc_account]) + ata_tx.recent_blockhash = (await provider.connection.get_latest_blockhash()).value.blockhash + ata_tx.fee_payer = provider.wallet.payer.pubkey() + + ata_tx.sign_partial(usdc_account) + ata_tx.sign(provider.wallet.payer) + + await provider.send(ata_tx) return usdc_account async def set_price_feed( oracle_program: Program, - oracle_public_key: PublicKey, + oracle_public_key: Pubkey, price: float, ): data = await get_feed_data(oracle_program, oracle_public_key) @@ -210,7 +221,7 @@ async def set_price_feed( async def set_price_feed_detailed( oracle_program: Program, - oracle_public_key: PublicKey, + oracle_public_key: Pubkey, price: float, conf: float, slot: int, @@ -226,7 +237,7 @@ async def set_price_feed_detailed( async def get_set_price_feed_detailed_ix( oracle_program: Program, - oracle_public_key: PublicKey, + oracle_public_key: Pubkey, price: float, conf: float, slot: int, @@ -245,7 +256,7 @@ async def create_price_feed( init_price: int, confidence: Optional[int] = None, expo: int = -4, -) -> PublicKey: +) -> Pubkey: conf = int((init_price / 10) * 10**-expo) if confidence is None else confidence collateral_token_feed = Keypair() space = 3312 @@ -254,28 +265,28 @@ async def create_price_feed( space ) ) - lamports = mbre_resp["result"] + lamports = mbre_resp.value await oracle_program.rpc["initialize"]( int(init_price * 10**-expo), expo, conf, ctx=Context( - accounts={"price": collateral_token_feed.public_key}, + accounts={"price": collateral_token_feed.pubkey()}, signers=[collateral_token_feed], pre_instructions=[ create_account( CreateAccountParams( from_pubkey=oracle_program.provider.wallet.public_key, - new_account_pubkey=collateral_token_feed.public_key, + to_pubkey=collateral_token_feed.pubkey(), space=space, lamports=lamports, - program_id=oracle_program.program_id, + owner=oracle_program.program_id, ) ), ], ), ) - return collateral_token_feed.public_key + return collateral_token_feed.pubkey() @dataclass @@ -291,10 +302,9 @@ def parse_price_data(data: bytes) -> PriceData: return PriceData(exponent, price) -async def get_feed_data(oracle_program: Program, price_feed: PublicKey) -> PriceData: +async def get_feed_data(oracle_program: Program, price_feed: Pubkey) -> PriceData: info_resp = await oracle_program.provider.connection.get_account_info(price_feed) - raw_bytes = b64decode(info_resp["result"]["value"]["data"][0]) - return parse_price_data(raw_bytes) + return parse_price_data(info_resp.value.data) from solana.rpc.async_api import AsyncClient @@ -302,7 +312,7 @@ async def get_feed_data(oracle_program: Program, price_feed: PublicKey) -> Price async def get_oracle_data( connection: AsyncClient, - oracle_addr: PublicKey, + oracle_addr: Pubkey, ): info_resp = await connection.get_account_info(oracle_addr) return parse_price_data(b64decode(info_resp["result"]["value"]["data"][0])) @@ -310,7 +320,7 @@ async def get_oracle_data( async def mock_oracle( pyth_program: Program, price: int = int(50 * 10e7), expo=-7 -) -> PublicKey: +) -> Pubkey: price_feed_address = await create_price_feed( oracle_program=pyth_program, init_price=price, expo=expo ) diff --git a/src/driftpy/types.py b/src/driftpy/types.py index b41c9fba..8b037de7 100644 --- a/src/driftpy/types.py +++ b/src/driftpy/types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from solana.publickey import PublicKey +from solders.pubkey import Pubkey from borsh_construct.enum import _rust_enum from sumtypes import constructor from typing import Optional @@ -342,7 +342,7 @@ class PoolBalance: @dataclass class AMM: - oracle: PublicKey + oracle: Pubkey historical_oracle_data: HistoricalOracleData base_asset_amount_per_lp: int quote_asset_amount_per_lp: int @@ -505,12 +505,12 @@ class Order: @dataclass class PhoenixV1FulfillmentConfig: - pubkey: PublicKey - phoenix_program_id: PublicKey - phoenix_log_authority: PublicKey - phoenix_market: PublicKey - phoenix_base_vault: PublicKey - phoenix_quote_vault: PublicKey + pubkey: Pubkey + phoenix_program_id: Pubkey + phoenix_log_authority: Pubkey + phoenix_market: Pubkey + phoenix_base_vault: Pubkey + phoenix_quote_vault: Pubkey market_index: int fulfillment_type: SpotFulfillmentType status: SpotFulfillmentConfigStatus @@ -518,16 +518,16 @@ class PhoenixV1FulfillmentConfig: @dataclass class SerumV3FulfillmentConfig: - pubkey: PublicKey - serum_program_id: PublicKey - serum_market: PublicKey - serum_request_queue: PublicKey - serum_event_queue: PublicKey - serum_bids: PublicKey - serum_asks: PublicKey - serum_base_vault: PublicKey - serum_quote_vault: PublicKey - serum_open_orders: PublicKey + pubkey: Pubkey + serum_program_id: Pubkey + serum_market: Pubkey + serum_request_queue: Pubkey + serum_event_queue: Pubkey + serum_bids: Pubkey + serum_asks: Pubkey + serum_base_vault: Pubkey + serum_quote_vault: Pubkey + serum_open_orders: Pubkey serum_signer_nonce: int market_index: int fulfillment_type: SpotFulfillmentType @@ -544,7 +544,7 @@ class InsuranceClaim: @dataclass class PerpMarket: - pubkey: PublicKey + pubkey: Pubkey amm: AMM pnl_pool: PoolBalance name: list[int] @@ -584,7 +584,7 @@ class HistoricalIndexData: @dataclass class InsuranceFund: - vault: PublicKey + vault: Pubkey total_shares: int user_shares: int shares_base: int @@ -596,10 +596,10 @@ class InsuranceFund: @dataclass class SpotMarket: - pubkey: PublicKey - oracle: PublicKey - mint: PublicKey - vault: PublicKey + pubkey: Pubkey + oracle: Pubkey + mint: Pubkey + vault: Pubkey name: list[int] historical_oracle_data: HistoricalOracleData historical_index_data: HistoricalIndexData @@ -652,11 +652,11 @@ class SpotMarket: @dataclass class State: - admin: PublicKey - whitelist_mint: PublicKey - discount_mint: PublicKey - signer: PublicKey - srm_vault: PublicKey + admin: Pubkey + whitelist_mint: Pubkey + discount_mint: Pubkey + signer: Pubkey + srm_vault: Pubkey perp_fee_structure: FeeStructure spot_fee_structure: FeeStructure oracle_guard_rails: OracleGuardRails @@ -696,8 +696,8 @@ class PerpPosition: @dataclass class User: - authority: PublicKey - delegate: PublicKey + authority: Pubkey + delegate: Pubkey name: list[int] spot_positions: list[SpotPosition] perp_positions: list[PerpPosition] @@ -735,8 +735,8 @@ class UserFees: @dataclass class UserStats: - authority: PublicKey - referrer: PublicKey + authority: Pubkey + referrer: Pubkey fees: UserFees next_epoch_ts: int maker_volume30d: int @@ -798,7 +798,7 @@ class PerpBankruptcyRecord: market_index: int pnl: int if_payment: int - clawback_user: Optional[PublicKey] + clawback_user: Optional[Pubkey] clawback_user_payment: Optional[int] cumulative_funding_rate_delta: int @@ -811,7 +811,7 @@ class SpotBankruptcyRecord: @dataclass class InsuranceFundStake: - authority: PublicKey + authority: Pubkey if_shares: int last_withdraw_request_shares: int if_base: int @@ -824,7 +824,7 @@ class InsuranceFundStake: @dataclass class ProtocolIfSharesTransferConfig: - whitelisted_signers: list[PublicKey] + whitelisted_signers: list[Pubkey] max_transfer_per_epoch: int current_epoch_transfer: int next_epoch_ts: int @@ -832,9 +832,9 @@ class ProtocolIfSharesTransferConfig: @dataclass class ReferrerName: - authority: PublicKey - user: PublicKey - user_stats: PublicKey + authority: Pubkey + user: Pubkey + user_stats: Pubkey name: list[int] @dataclass @@ -844,4 +844,9 @@ class OraclePriceData: confidence: int twap: int twap_confidence: int - has_sufficient_number_of_datapoints: bool \ No newline at end of file + has_sufficient_number_of_datapoints: bool + +@dataclass +class TxParams: + compute_units: Optional[int] + compute_units_price: Optional[int] \ No newline at end of file diff --git a/tests/test.py b/tests/test.py index e72be1cf..58c9bfee 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,7 +1,7 @@ from pytest import fixture, mark from pytest_asyncio import fixture as async_fixture -from solana.keypair import Keypair -from solana.publickey import PublicKey +from solders.keypair import Keypair +from solders.pubkey import Pubkey from anchorpy import Program, Provider, WorkspaceType, workspace_fixture from driftpy.admin import Admin from driftpy.constants.numeric_constants import ( @@ -80,7 +80,7 @@ def provider(program: Program) -> Provider: @async_fixture(scope="session") async def drift_client(program: Program, usdc_mint: Keypair) -> Admin: admin = Admin(program) - await admin.initialize(usdc_mint.public_key, admin_controls_prices=True) + await admin.initialize(usdc_mint.pubkey(), admin_controls_prices=True) return admin @@ -89,7 +89,7 @@ async def initialized_spot_market( drift_client: Admin, usdc_mint: Keypair, ): - await drift_client.initialize_spot_market(usdc_mint.public_key) + await drift_client.initialize_spot_market(usdc_mint.pubkey()) @mark.asyncio @@ -114,7 +114,7 @@ async def test_initialized_spot_market_2( main_liab_weight = int(SPOT_WEIGHT_PRECISION * 11 / 10) await admin_drift_client.initialize_spot_market( - mint.public_key, + mint.pubkey(), oracle=oracle, optimal_utilization=optimal_util, optimal_rate=optimal_weight, @@ -134,7 +134,7 @@ async def test_initialized_spot_market_2( @async_fixture(scope="session") async def initialized_market( drift_client: Admin, workspace: WorkspaceType -) -> PublicKey: +) -> Pubkey: pyth_program = workspace["pyth"] sol_usd = await mock_oracle(pyth_program=pyth_program, price=1) perp_market_index = 0 @@ -152,7 +152,7 @@ async def initialized_market( @mark.asyncio async def test_spot( drift_client: Admin, - initialized_spot_market: PublicKey, + initialized_spot_market: Pubkey, ): program = drift_client.program spot_market = await get_spot_market_account(program, 0) @@ -162,7 +162,7 @@ async def test_spot( @mark.asyncio async def test_market( drift_client: Admin, - initialized_market: PublicKey, + initialized_market: Pubkey, ): program = drift_client.program market_oracle_public_key = initialized_market @@ -190,9 +190,9 @@ async def test_usdc_deposit( ): usdc_spot_market = await get_spot_market_account(drift_client.program, 0) assert(usdc_spot_market.market_index == 0) - drift_client.spot_market_atas[0] = user_usdc_account.public_key + drift_client.spot_market_atas[0] = user_usdc_account.pubkey() await drift_client.deposit( - USDC_AMOUNT, 0, user_usdc_account.public_key, user_initialized=True + USDC_AMOUNT, 0, user_usdc_account.pubkey(), user_initialized=True ) user_account = await drift_client.get_user(0) assert ( @@ -215,7 +215,7 @@ async def test_update_curve( from driftpy.setup.helpers import set_price_feed_detailed pyth_program = workspace["pyth"] - slot = (await drift_client.program.provider.connection.get_slot())["result"] + slot = (await drift_client.program.provider.connection.get_slot()).value await set_price_feed_detailed(pyth_program, market.amm.oracle, 1.07, 0, slot) new_peg = int(market.amm.peg_multiplier * 1.05) @@ -306,7 +306,7 @@ async def test_stake_if( user_usdc_account: Keypair, ): # important - drift_client.usdc_ata = user_usdc_account.public_key + drift_client.usdc_ata = user_usdc_account.pubkey() await drift_client.update_update_insurance_fund_unstaking_period(0, 0) @@ -340,13 +340,13 @@ async def test_liq_perp( liq, _ = await _airdrop_user(drift_client.program.provider) liq_drift_client = DriftClient(drift_client.program, liq) usdc_acc = await _create_and_mint_user_usdc( - usdc_mint, drift_client.program.provider, USDC_AMOUNT, liq.public_key + usdc_mint, drift_client.program.provider, USDC_AMOUNT, liq.pubkey() ) await liq_drift_client.intialize_user() await liq_drift_client.deposit( USDC_AMOUNT, 0, - usdc_acc.public_key, + usdc_acc.pubkey(), ) from driftpy.constants.numeric_constants import AMM_RESERVE_PRECISION From 71190ee0e78403768088c89078f18cd2c7af7301 Mon Sep 17 00:00:00 2001 From: Chris Heaney Date: Thu, 16 Nov 2023 17:59:15 -0500 Subject: [PATCH 2/7] update submodule --- protocol-v2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/protocol-v2 b/protocol-v2 index b9e533e9..c1ae3038 160000 --- a/protocol-v2 +++ b/protocol-v2 @@ -1 +1 @@ -Subproject commit b9e533e96e3788e6f810797bc0e948e4da49f6a7 +Subproject commit c1ae303829f670b5030834b7fdbaf46c51fb636c From 99708f3b9bbc7bd1f459360bc5d6030a2b6c3d95 Mon Sep 17 00:00:00 2001 From: Chris Heaney Date: Thu, 16 Nov 2023 19:38:59 -0500 Subject: [PATCH 3/7] add versioned tx --- src/driftpy/drift_client.py | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index f86b0702..f364ef6c 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -3,6 +3,9 @@ from typing import Optional from solders.keypair import Keypair from solana.transaction import Transaction +from solders.transaction import VersionedTransaction +from solders.transaction import TransactionVersion, Legacy +from solders.message import MessageV0 from solders.instruction import Instruction from solders.system_program import ID from solders.sysvar import RENT @@ -35,7 +38,7 @@ class DriftClient: depositing, opening new positions, closing positions, placing orders, etc. """ - def __init__(self, program: Program, signer: Keypair = None, authority: Pubkey = None, account_subscriber: Optional[DriftClientAccountSubscriber] = None, tx_params: Optional[TxParams] = None): + def __init__(self, program: Program, signer: Keypair = None, authority: Pubkey = None, account_subscriber: Optional[DriftClientAccountSubscriber] = None, tx_params: Optional[TxParams] = None, tx_version: Optional[TransactionVersion] = None): """Initializes the drift client object -- likely want to use the .from_config method instead of this one Args: @@ -69,6 +72,8 @@ def __init__(self, program: Program, signer: Keypair = None, authority: Pubkey = self.tx_params = tx_params + self.tx_version = tx_version if tx_version is not None else Legacy + @staticmethod def from_config(config: Config, provider: Provider, authority: Keypair = None): """Initializes the drift client object from a Config @@ -139,24 +144,32 @@ async def send_ixs( if isinstance(ixs, Instruction): ixs = [ixs] - tx = Transaction() - if self.tx_params.compute_units is not None: - tx.add(set_compute_unit_limit(self.tx_params.compute_units)) + ixs.insert(0, set_compute_unit_limit(self.tx_params.compute_units)) if self.tx_params.compute_units_price is not None: - tx.add(set_compute_unit_price(self.tx_params.compute_units_price)) - - [tx.add(ix) for ix in ixs] + ixs.insert(1, set_compute_unit_price(self.tx_params.compute_units_price)) latest_blockhash = (await self.program.provider.connection.get_latest_blockhash()).value.blockhash - tx.recent_blockhash = latest_blockhash - tx.fee_payer = self.signer.pubkey() - tx.sign_partial(self.signer) + if self.tx_version == Legacy: + tx = Transaction( + instructions=ixs, + recent_blockhash=latest_blockhash, + fee_payer=self.signer.pubkey() + ) - if signers is not None: - [tx.sign_partial(signer) for signer in signers] + tx.sign_partial(self.signer) + + if signers is not None: + [tx.sign_partial(signer) for signer in signers] + elif self.tx_version == 0: + msg = MessageV0.try_compile( + self.signer.pubkey(), ixs, [], latest_blockhash + ) + tx = VersionedTransaction(msg, [self.signer]) + else: + raise NotImplementedError("unknown tx version", self.tx_version) return await self.program.provider.send(tx) From 465c79664d13523505dbf7b5f466ba100b2f3458 Mon Sep 17 00:00:00 2001 From: Chris Heaney Date: Thu, 16 Nov 2023 20:09:34 -0500 Subject: [PATCH 4/7] add black to pre-commit --- .pre-commit-config.yaml | 6 + Makefile | 5 +- poetry.lock | 83 +++++--- pyproject.toml | 4 +- src/driftpy/_types.py | 21 +- src/driftpy/accounts/__init__.py | 2 +- src/driftpy/accounts/cache/__init__.py | 2 +- src/driftpy/accounts/cache/drift_client.py | 40 ++-- src/driftpy/accounts/cache/user.py | 7 +- src/driftpy/accounts/get_accounts.py | 33 +-- src/driftpy/accounts/oracle.py | 51 +++-- src/driftpy/accounts/types.py | 25 ++- src/driftpy/addresses.py | 4 +- src/driftpy/admin.py | 9 +- src/driftpy/constants/config.py | 8 +- src/driftpy/drift_client.py | 65 +++--- src/driftpy/drift_user.py | 25 ++- src/driftpy/math/amm.py | 2 +- src/driftpy/math/market.py | 3 +- src/driftpy/math/repeg.py | 5 +- src/driftpy/setup/helpers.py | 16 +- src/driftpy/types.py | 227 ++++++++++++++------- tests/test.py | 24 +-- 23 files changed, 426 insertions(+), 241 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..745ff771 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: +- repo: https://github.com/psf/black + rev: 23.11.0 + hooks: + - id: black + language_version: python3.10 \ No newline at end of file diff --git a/Makefile b/Makefile index 5c58e85e..b82c4419 100644 --- a/Makefile +++ b/Makefile @@ -4,5 +4,8 @@ test: lint: poetry run black --check --diff src tests - poetry run flake8 src tests poetry run mypy src tests + +lintfix: + poetry run black src tests + diff --git a/poetry.lock b/poetry.lock index 6ff0858e..b385e42d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -298,6 +298,22 @@ docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"] tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"] tests-no-zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] +[[package]] +name = "autopep8" +version = "2.0.4" +description = "A tool that automatically formats Python code to conform to the PEP 8 style guide" +category = "dev" +optional = false +python-versions = ">=3.6" +files = [ + {file = "autopep8-2.0.4-py2.py3-none-any.whl", hash = "sha256:067959ca4a07b24dbd5345efa8325f5f58da4298dab0dde0443d5ed765de80cb"}, + {file = "autopep8-2.0.4.tar.gz", hash = "sha256:2913064abd97b3419d1cc83ea71f042cb821f87e45b9c88cad5ad3c4ea87fe0c"}, +] + +[package.dependencies] +pycodestyle = ">=2.10.0" +tomli = {version = "*", markers = "python_version < \"3.11\""} + [[package]] name = "backoff" version = "2.2.1" @@ -1237,48 +1253,57 @@ files = [ [[package]] name = "mypy" -version = "0.931" +version = "1.7.0" description = "Optional static typing for Python" -category = "dev" +category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "mypy-0.931-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3c5b42d0815e15518b1f0990cff7a705805961613e701db60387e6fb663fe78a"}, - {file = "mypy-0.931-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c89702cac5b302f0c5d33b172d2b55b5df2bede3344a2fbed99ff96bddb2cf00"}, - {file = "mypy-0.931-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:300717a07ad09525401a508ef5d105e6b56646f7942eb92715a1c8d610149714"}, - {file = "mypy-0.931-cp310-cp310-win_amd64.whl", hash = "sha256:7b3f6f557ba4afc7f2ce6d3215d5db279bcf120b3cfd0add20a5d4f4abdae5bc"}, - {file = "mypy-0.931-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:1bf752559797c897cdd2c65f7b60c2b6969ffe458417b8d947b8340cc9cec08d"}, - {file = "mypy-0.931-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:4365c60266b95a3f216a3047f1d8e3f895da6c7402e9e1ddfab96393122cc58d"}, - {file = "mypy-0.931-cp36-cp36m-win_amd64.whl", hash = "sha256:1b65714dc296a7991000b6ee59a35b3f550e0073411ac9d3202f6516621ba66c"}, - {file = "mypy-0.931-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e839191b8da5b4e5d805f940537efcaa13ea5dd98418f06dc585d2891d228cf0"}, - {file = "mypy-0.931-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:50c7346a46dc76a4ed88f3277d4959de8a2bd0a0fa47fa87a4cde36fe247ac05"}, - {file = "mypy-0.931-cp37-cp37m-win_amd64.whl", hash = "sha256:d8f1ff62f7a879c9fe5917b3f9eb93a79b78aad47b533911b853a757223f72e7"}, - {file = "mypy-0.931-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f9fe20d0872b26c4bba1c1be02c5340de1019530302cf2dcc85c7f9fc3252ae0"}, - {file = "mypy-0.931-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1b06268df7eb53a8feea99cbfff77a6e2b205e70bf31743e786678ef87ee8069"}, - {file = "mypy-0.931-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8c11003aaeaf7cc2d0f1bc101c1cc9454ec4cc9cb825aef3cafff8a5fdf4c799"}, - {file = "mypy-0.931-cp38-cp38-win_amd64.whl", hash = "sha256:d9d2b84b2007cea426e327d2483238f040c49405a6bf4074f605f0156c91a47a"}, - {file = "mypy-0.931-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ff3bf387c14c805ab1388185dd22d6b210824e164d4bb324b195ff34e322d166"}, - {file = "mypy-0.931-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5b56154f8c09427bae082b32275a21f500b24d93c88d69a5e82f3978018a0266"}, - {file = "mypy-0.931-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8ca7f8c4b1584d63c9a0f827c37ba7a47226c19a23a753d52e5b5eddb201afcd"}, - {file = "mypy-0.931-cp39-cp39-win_amd64.whl", hash = "sha256:74f7eccbfd436abe9c352ad9fb65872cc0f1f0a868e9d9c44db0893440f0c697"}, - {file = "mypy-0.931-py3-none-any.whl", hash = "sha256:1171f2e0859cfff2d366da2c7092b06130f232c636a3f7301e3feb8b41f6377d"}, - {file = "mypy-0.931.tar.gz", hash = "sha256:0038b21890867793581e4cb0d810829f5fd4441aa75796b53033af3aa30430ce"}, + {file = "mypy-1.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5da84d7bf257fd8f66b4f759a904fd2c5a765f70d8b52dde62b521972a0a2357"}, + {file = "mypy-1.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a3637c03f4025f6405737570d6cbfa4f1400eb3c649317634d273687a09ffc2f"}, + {file = "mypy-1.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b633f188fc5ae1b6edca39dae566974d7ef4e9aaaae00bc36efe1f855e5173ac"}, + {file = "mypy-1.7.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d6ed9a3997b90c6f891138e3f83fb8f475c74db4ccaa942a1c7bf99e83a989a1"}, + {file = "mypy-1.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:1fe46e96ae319df21359c8db77e1aecac8e5949da4773c0274c0ef3d8d1268a9"}, + {file = "mypy-1.7.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:df67fbeb666ee8828f675fee724cc2cbd2e4828cc3df56703e02fe6a421b7401"}, + {file = "mypy-1.7.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a79cdc12a02eb526d808a32a934c6fe6df07b05f3573d210e41808020aed8b5d"}, + {file = "mypy-1.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f65f385a6f43211effe8c682e8ec3f55d79391f70a201575def73d08db68ead1"}, + {file = "mypy-1.7.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0e81ffd120ee24959b449b647c4b2fbfcf8acf3465e082b8d58fd6c4c2b27e46"}, + {file = "mypy-1.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:f29386804c3577c83d76520abf18cfcd7d68264c7e431c5907d250ab502658ee"}, + {file = "mypy-1.7.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:87c076c174e2c7ef8ab416c4e252d94c08cd4980a10967754f91571070bf5fbe"}, + {file = "mypy-1.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6cb8d5f6d0fcd9e708bb190b224089e45902cacef6f6915481806b0c77f7786d"}, + {file = "mypy-1.7.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93e76c2256aa50d9c82a88e2f569232e9862c9982095f6d54e13509f01222fc"}, + {file = "mypy-1.7.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cddee95dea7990e2215576fae95f6b78a8c12f4c089d7e4367564704e99118d3"}, + {file = "mypy-1.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:d01921dbd691c4061a3e2ecdbfbfad029410c5c2b1ee88946bf45c62c6c91210"}, + {file = "mypy-1.7.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:185cff9b9a7fec1f9f7d8352dff8a4c713b2e3eea9c6c4b5ff7f0edf46b91e41"}, + {file = "mypy-1.7.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7a7b1e399c47b18feb6f8ad4a3eef3813e28c1e871ea7d4ea5d444b2ac03c418"}, + {file = "mypy-1.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc9fe455ad58a20ec68599139ed1113b21f977b536a91b42bef3ffed5cce7391"}, + {file = "mypy-1.7.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d0fa29919d2e720c8dbaf07d5578f93d7b313c3e9954c8ec05b6d83da592e5d9"}, + {file = "mypy-1.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:2b53655a295c1ed1af9e96b462a736bf083adba7b314ae775563e3fb4e6795f5"}, + {file = "mypy-1.7.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c1b06b4b109e342f7dccc9efda965fc3970a604db70f8560ddfdee7ef19afb05"}, + {file = "mypy-1.7.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:bf7a2f0a6907f231d5e41adba1a82d7d88cf1f61a70335889412dec99feeb0f8"}, + {file = "mypy-1.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:551d4a0cdcbd1d2cccdcc7cb516bb4ae888794929f5b040bb51aae1846062901"}, + {file = "mypy-1.7.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:55d28d7963bef00c330cb6461db80b0b72afe2f3c4e2963c99517cf06454e665"}, + {file = "mypy-1.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:870bd1ffc8a5862e593185a4c169804f2744112b4a7c55b93eb50f48e7a77010"}, + {file = "mypy-1.7.0-py3-none-any.whl", hash = "sha256:96650d9a4c651bc2a4991cf46f100973f656d69edc7faf91844e87fe627f7e96"}, + {file = "mypy-1.7.0.tar.gz", hash = "sha256:1e280b5697202efa698372d2f39e9a6713a0395a756b1c6bd48995f8d72690dc"}, ] [package.dependencies] -mypy-extensions = ">=0.4.3" -tomli = ">=1.1.0" -typing-extensions = ">=3.10" +mypy-extensions = ">=1.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = ">=4.1.0" [package.extras] dmypy = ["psutil (>=4.0)"] -python2 = ["typed-ast (>=1.4.0,<2)"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] [[package]] name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." -category = "dev" +category = "main" optional = false python-versions = ">=3.5" files = [ @@ -2292,4 +2317,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "fc879d686ae6b2d2eedcde15487706afc24e62d3c5fd2a2ca0181867f18f7593" +content-hash = "5fc096eea3c4b4688a49b7e41849119f07b949b6f941b6f65dc8579ec1b289e8" diff --git a/pyproject.toml b/pyproject.toml index a0b26c83..71449c0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,17 +79,19 @@ websockets = "10.4" yarl = "1.8.2" zstandard = "0.18.0" jinja2 = "<3.1" +mypy = "^1.7.0" [tool.poetry.dev-dependencies] pytest = "^7.2.0" flake8 = "6.0.0" -mypy = "^0.931" black = "^23.3.0" pytest-asyncio = "^0.21.0" mkdocs = "^1.3.0" mkdocstrings = "^0.17.0" mkdocs-material = "^8.1.8" bump2version = "^1.0.1" +autopep8 = "^2.0.4" +mypy = "^1.7.0" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/src/driftpy/_types.py b/src/driftpy/_types.py index 1e11393d..531311e6 100644 --- a/src/driftpy/_types.py +++ b/src/driftpy/_types.py @@ -1,3 +1,5 @@ +from driftpy.constants.numeric_constants import SPOT_RATE_PRECISION +from driftpy.types import OracleSource from typing import Optional, Any from dataclasses import dataclass from sumtypes import constructor # type: ignore @@ -287,9 +289,10 @@ class MarketPosition: padding3: int padding4: int - ## dw why this doesnt register :( + # dw why this doesnt register :( # def is_available(self): - # return self.base_asset_amount == 0 and self.open_orders == 0 and self.lp_shares == 0 + # return self.base_asset_amount == 0 and self.open_orders == 0 and + # self.lp_shares == 0 @dataclass @@ -377,14 +380,14 @@ class AMM: quote_asset_amount_long: int = 0 quote_asset_amount_short: int = 0 - ## lp stuff + # lp stuff cumulative_funding_payment_per_lp: int = 0 cumulative_fee_per_lp: int = 0 cumulative_base_asset_amount_with_amm_per_lp: int = 0 lp_cooldown_time: int = 0 user_lp_shares: int = 0 - ## funding + # funding last_funding_rate: int = 0 last_funding_rate_ts: int = 0 funding_period: int = 0 @@ -397,11 +400,11 @@ class AMM: last_mark_price_twap: int = 0 last_mark_price_twap_ts: int = 0 - ## trade constraints + # trade constraints minimum_quote_asset_trade_size: int = 0 base_asset_amount_step_size: int = 0 - ## market making + # market making base_spread: int = 0 long_spread: int = 0 short_spread: int = 0 @@ -420,7 +423,7 @@ class AMM: short_intensity_volume: int = 0 curve_update_intensity: int = 0 - ## fee tracking + # fee tracking total_fee: int = 0 total_mm_fee: int = 0 total_exchange_fee: int = 0 @@ -462,10 +465,6 @@ class Market: padding4: int = 0 -from driftpy.types import OracleSource -from driftpy.constants.numeric_constants import SPOT_RATE_PRECISION - - @dataclass class SpotMarket: mint: PublicKey # this diff --git a/src/driftpy/accounts/__init__.py b/src/driftpy/accounts/__init__.py index 4c8643a9..a7722442 100644 --- a/src/driftpy/accounts/__init__.py +++ b/src/driftpy/accounts/__init__.py @@ -1,2 +1,2 @@ from .get_accounts import * -from .types import * \ No newline at end of file +from .types import * diff --git a/src/driftpy/accounts/cache/__init__.py b/src/driftpy/accounts/cache/__init__.py index f31516fc..58e298b4 100644 --- a/src/driftpy/accounts/cache/__init__.py +++ b/src/driftpy/accounts/cache/__init__.py @@ -1,2 +1,2 @@ from .drift_client import * -from .user import * \ No newline at end of file +from .user import * diff --git a/src/driftpy/accounts/cache/drift_client.py b/src/driftpy/accounts/cache/drift_client.py index 37db66e3..29c20eb9 100644 --- a/src/driftpy/accounts/cache/drift_client.py +++ b/src/driftpy/accounts/cache/drift_client.py @@ -2,8 +2,11 @@ from solders.pubkey import Pubkey from solana.rpc.commitment import Commitment -from driftpy.accounts import get_state_account_and_slot, get_spot_market_account_and_slot, \ - get_perp_market_account_and_slot +from driftpy.accounts import ( + get_state_account_and_slot, + get_spot_market_account_and_slot, + get_perp_market_account_and_slot, +) from driftpy.accounts.oracle import get_oracle_price_data_and_slot from driftpy.accounts.types import DriftClientAccountSubscriber, DataAndSlot from typing import Optional @@ -28,30 +31,37 @@ async def update_cache(self): spot_markets = [] for i in range(state_and_slot.data.number_of_spot_markets): - spot_market_and_slot = await get_spot_market_account_and_slot(self.program, i) + spot_market_and_slot = await get_spot_market_account_and_slot( + self.program, i + ) spot_markets.append(spot_market_and_slot) oracle_price_data_and_slot = await get_oracle_price_data_and_slot( self.program.provider.connection, spot_market_and_slot.data.oracle, - spot_market_and_slot.data.oracle_source - + spot_market_and_slot.data.oracle_source, ) - oracle_data[str(spot_market_and_slot.data.oracle)] = oracle_price_data_and_slot + oracle_data[ + str(spot_market_and_slot.data.oracle) + ] = oracle_price_data_and_slot self.cache["spot_markets"] = spot_markets perp_markets = [] for i in range(state_and_slot.data.number_of_markets): - perp_market_and_slot = await get_perp_market_account_and_slot(self.program, i) + perp_market_and_slot = await get_perp_market_account_and_slot( + self.program, i + ) perp_markets.append(perp_market_and_slot) oracle_price_data_and_slot = await get_oracle_price_data_and_slot( self.program.provider.connection, perp_market_and_slot.data.amm.oracle, - perp_market_and_slot.data.amm.oracle_source + perp_market_and_slot.data.amm.oracle_source, ) - oracle_data[str(perp_market_and_slot.data.amm.oracle)] = oracle_price_data_and_slot + oracle_data[ + str(perp_market_and_slot.data.amm.oracle) + ] = oracle_price_data_and_slot self.cache["perp_markets"] = perp_markets @@ -61,15 +71,21 @@ async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]: await self.cache_if_needed() return self.cache["state"] - async def get_perp_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[PerpMarket]]: + async def get_perp_market_and_slot( + self, market_index: int + ) -> Optional[DataAndSlot[PerpMarket]]: await self.cache_if_needed() return self.cache["perp_markets"][market_index] - async def get_spot_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[SpotMarket]]: + async def get_spot_market_and_slot( + self, market_index: int + ) -> Optional[DataAndSlot[SpotMarket]]: await self.cache_if_needed() return self.cache["spot_markets"][market_index] - async def get_oracle_data_and_slot(self, oracle: Pubkey) -> Optional[DataAndSlot[OraclePriceData]]: + async def get_oracle_data_and_slot( + self, oracle: Pubkey + ) -> Optional[DataAndSlot[OraclePriceData]]: await self.cache_if_needed() return self.cache["oracle_price_data"][str(oracle)] diff --git a/src/driftpy/accounts/cache/user.py b/src/driftpy/accounts/cache/user.py index 357f393f..18a3c96f 100644 --- a/src/driftpy/accounts/cache/user.py +++ b/src/driftpy/accounts/cache/user.py @@ -10,7 +10,12 @@ class CachedUserAccountSubscriber(UserAccountSubscriber): - def __init__(self, user_pubkey: Pubkey, program: Program, commitment: Commitment = "confirmed"): + def __init__( + self, + user_pubkey: Pubkey, + program: Program, + commitment: Commitment = "confirmed", + ): self.program = program self.commitment = commitment self.user_pubkey = user_pubkey diff --git a/src/driftpy/accounts/get_accounts.py b/src/driftpy/accounts/get_accounts.py index 9a56dfcb..31fc8837 100644 --- a/src/driftpy/accounts/get_accounts.py +++ b/src/driftpy/accounts/get_accounts.py @@ -9,8 +9,9 @@ from .types import DataAndSlot, T -async def get_account_data_and_slot(address: Pubkey, program: Program, commitment: Commitment = "processed") -> Optional[ - DataAndSlot[T]]: +async def get_account_data_and_slot( + address: Pubkey, program: Program, commitment: Commitment = "processed" +) -> Optional[DataAndSlot[T]]: account_info = await program.provider.connection.get_account_info( address, encoding="base64", @@ -38,7 +39,7 @@ async def get_state_account(program: Program) -> State: async def get_if_stake_account( - program: Program, authority: Pubkey, spot_market_index: int + program: Program, authority: Pubkey, spot_market_index: int ) -> InsuranceFundStake: if_stake_pk = get_insurance_fund_stake_public_key( program.program_id, authority, spot_market_index @@ -48,8 +49,8 @@ async def get_if_stake_account( async def get_user_stats_account( - program: Program, - authority: Pubkey, + program: Program, + authority: Pubkey, ) -> UserStats: user_stats_public_key = get_user_stats_account_public_key( program.program_id, @@ -58,21 +59,27 @@ async def get_user_stats_account( response = await program.account["UserStats"].fetch(user_stats_public_key) return cast(UserStats, response) + async def get_user_account_and_slot( - program: Program, - user_public_key: Pubkey, + program: Program, + user_public_key: Pubkey, ) -> DataAndSlot[User]: return await get_account_data_and_slot(user_public_key, program) + async def get_user_account( - program: Program, - user_public_key: Pubkey, + program: Program, + user_public_key: Pubkey, ) -> User: return (await get_user_account_and_slot(program, user_public_key)).data -async def get_perp_market_account_and_slot(program: Program, market_index: int) -> Optional[DataAndSlot[PerpMarket]]: - perp_market_public_key = get_perp_market_public_key(program.program_id, market_index) +async def get_perp_market_account_and_slot( + program: Program, market_index: int +) -> Optional[DataAndSlot[PerpMarket]]: + perp_market_public_key = get_perp_market_public_key( + program.program_id, market_index + ) return await get_account_data_and_slot(perp_market_public_key, program) @@ -85,7 +92,7 @@ async def get_all_perp_market_accounts(program: Program) -> list[ProgramAccount] async def get_spot_market_account_and_slot( - program: Program, spot_market_index: int + program: Program, spot_market_index: int ) -> DataAndSlot[SpotMarket]: spot_market_public_key = get_spot_market_public_key( program.program_id, spot_market_index @@ -94,7 +101,7 @@ async def get_spot_market_account_and_slot( async def get_spot_market_account( - program: Program, spot_market_index: int + program: Program, spot_market_index: int ) -> SpotMarket: return (await get_spot_market_account_and_slot(program, spot_market_index)).data diff --git a/src/driftpy/accounts/oracle.py b/src/driftpy/accounts/oracle.py index f6c8c373..5e2a28c8 100644 --- a/src/driftpy/accounts/oracle.py +++ b/src/driftpy/accounts/oracle.py @@ -10,20 +10,25 @@ import base64 import struct + def convert_pyth_price(price, scale=1): return int(price * PRICE_PRECISION * scale) -async def get_oracle_price_data_and_slot(connection: AsyncClient, address: Pubkey, oracle_source=OracleSource.PYTH()) -> DataAndSlot[ - OraclePriceData]: - if 'Pyth' in str(oracle_source): + +async def get_oracle_price_data_and_slot( + connection: AsyncClient, address: Pubkey, oracle_source=OracleSource.PYTH() +) -> DataAndSlot[OraclePriceData]: + if "Pyth" in str(oracle_source): rpc_reponse = await connection.get_account_info(address) rpc_response_slot = rpc_reponse.context.slot - (pyth_price_info, last_slot, twac, twap) = await _parse_pyth_price_info(rpc_reponse) + (pyth_price_info, last_slot, twac, twap) = await _parse_pyth_price_info( + rpc_reponse + ) scale = 1 - if '1K' in str(oracle_source): + if "1K" in str(oracle_source): scale = 1e3 - elif '1M' in str(oracle_source): + elif "1M" in str(oracle_source): scale = 1e6 oracle_data = OraclePriceData( @@ -36,29 +41,41 @@ async def get_oracle_price_data_and_slot(connection: AsyncClient, address: Pubke ) return DataAndSlot(data=oracle_data, slot=rpc_response_slot) - elif 'Quote' in str(oracle_source): - return DataAndSlot(data=OraclePriceData(PRICE_PRECISION, 0, 1, 1, 0, True), slot=0) + elif "Quote" in str(oracle_source): + return DataAndSlot( + data=OraclePriceData(PRICE_PRECISION, 0, 1, 1, 0, True), slot=0 + ) else: - raise NotImplementedError('Unsupported Oracle Source', str(oracle_source)) + raise NotImplementedError("Unsupported Oracle Source", str(oracle_source)) -async def _parse_pyth_price_info(resp: GetAccountInfoResp) -> (PythPriceInfo, int, int, int): + +async def _parse_pyth_price_info( + resp: GetAccountInfoResp, +) -> (PythPriceInfo, int, int, int): buffer = resp.value.data offset = _ACCOUNT_HEADER_BYTES _, exponent, _ = struct.unpack_from(" Optional[DataAndSlot[State]]: pass @abstractmethod - async def get_perp_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[PerpMarket]]: + async def get_perp_market_and_slot( + self, market_index: int + ) -> Optional[DataAndSlot[PerpMarket]]: pass @abstractmethod - async def get_spot_market_and_slot(self, market_index: int) -> Optional[DataAndSlot[SpotMarket]]: + async def get_spot_market_and_slot( + self, market_index: int + ) -> Optional[DataAndSlot[SpotMarket]]: pass @abstractmethod - async def get_oracle_data_and_slot(self, oracle: Pubkey) -> Optional[DataAndSlot[OraclePriceData]]: + async def get_oracle_data_and_slot( + self, oracle: Pubkey + ) -> Optional[DataAndSlot[OraclePriceData]]: pass + class UserAccountSubscriber: @abstractmethod async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]: - pass \ No newline at end of file + pass diff --git a/src/driftpy/addresses.py b/src/driftpy/addresses.py index 048450bf..8aa58cf1 100644 --- a/src/driftpy/addresses.py +++ b/src/driftpy/addresses.py @@ -77,9 +77,7 @@ def get_user_stats_account_public_key( program_id: Pubkey, authority: Pubkey, ) -> Pubkey: - return Pubkey.find_program_address( - [b"user_stats", bytes(authority)], program_id - )[0] + return Pubkey.find_program_address([b"user_stats", bytes(authority)], program_id)[0] def get_user_account_public_key( diff --git a/src/driftpy/admin.py b/src/driftpy/admin.py index 9c8c634b..ddd4bb4e 100644 --- a/src/driftpy/admin.py +++ b/src/driftpy/admin.py @@ -1,4 +1,3 @@ - from solders.pubkey import Pubkey from solders.signature import Signature from solders.keypair import Keypair @@ -74,9 +73,7 @@ async def initialize( "admin": self.authority, "state": state_public_key, "quote_asset_mint": usdc_mint, - "drift_signer": get_drift_client_signer_public_key( - self.program_id - ), + "drift_signer": get_drift_client_signer_public_key(self.program_id), "rent": RENT, "system_program": ID, "token_program": TOKEN_PROGRAM_ID, @@ -181,9 +178,7 @@ async def initialize_spot_market( "spot_market": spot_public_key, "spot_market_vault": spot_vault_public_key, "insurance_fund_vault": insurance_vault_public_key, - "drift_signer": get_drift_client_signer_public_key( - self.program_id - ), + "drift_signer": get_drift_client_signer_public_key(self.program_id), "spot_market_mint": mint, "oracle": oracle, "rent": RENT, diff --git a/src/driftpy/constants/config.py b/src/driftpy/constants/config.py index 3ac44389..0329373d 100644 --- a/src/driftpy/constants/config.py +++ b/src/driftpy/constants/config.py @@ -25,7 +25,9 @@ class Config: drift_client_program_id=Pubkey.from_string( "dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH" ), - usdc_mint_address=Pubkey.from_string("8zGuJQqwhZafTah7Uc7Z4tXRnguqkn5KLFAP8oV6PHe2"), + usdc_mint_address=Pubkey.from_string( + "8zGuJQqwhZafTah7Uc7Z4tXRnguqkn5KLFAP8oV6PHe2" + ), default_http="https://api.devnet.solana.com", default_ws="wss://api.devnet.solana.com", markets=devnet_markets, @@ -39,7 +41,9 @@ class Config: drift_client_program_id=Pubkey.from_string( "dRiftyHA39MWEi3m9aunc5MzRF1JYuBsbn6VPcn33UH" ), - usdc_mint_address=Pubkey.from_string("EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v"), + usdc_mint_address=Pubkey.from_string( + "EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v" + ), default_http="https://api.mainnet-beta.solana.com", default_ws="wss://api.mainnet-beta.solana.com", markets=mainnet_markets, diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index f364ef6c..d96e232e 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -33,12 +33,21 @@ DEFAULT_USER_NAME = "Main Account" + class DriftClient: """This class is the main way to interact with Drift Protocol including depositing, opening new positions, closing positions, placing orders, etc. """ - def __init__(self, program: Program, signer: Keypair = None, authority: Pubkey = None, account_subscriber: Optional[DriftClientAccountSubscriber] = None, tx_params: Optional[TxParams] = None, tx_version: Optional[TransactionVersion] = None): + def __init__( + self, + program: Program, + signer: Keypair = None, + authority: Pubkey = None, + account_subscriber: Optional[DriftClientAccountSubscriber] = None, + tx_params: Optional[TxParams] = None, + tx_version: Optional[TransactionVersion] = None, + ): """Initializes the drift client object -- likely want to use the .from_config method instead of this one Args: @@ -112,7 +121,9 @@ def get_user_account_public_key(self, user_id=0) -> Pubkey: return get_user_account_public_key(self.program_id, self.authority, user_id) async def get_user(self, user_id=0) -> User: - return await get_user_account(self.program, self.get_user_account_public_key(user_id)) + return await get_user_account( + self.program, self.get_user_account_public_key(user_id) + ) def get_state_public_key(self): return get_state_public_key(self.program_id) @@ -122,19 +133,25 @@ def get_user_stats_public_key(self): async def get_state(self) -> Optional[State]: state_and_slot = await self.account_subscriber.get_state_account_and_slot() - return getattr(state_and_slot, 'data', None) + return getattr(state_and_slot, "data", None) async def get_perp_market(self, market_index: int) -> Optional[PerpMarket]: - perp_market_and_slot = await self.account_subscriber.get_perp_market_and_slot(market_index) - return getattr(perp_market_and_slot, 'data', None) + perp_market_and_slot = await self.account_subscriber.get_perp_market_and_slot( + market_index + ) + return getattr(perp_market_and_slot, "data", None) async def get_spot_market(self, market_index: int) -> Optional[SpotMarket]: - spot_market_and_slot = await self.account_subscriber.get_spot_market_and_slot(market_index) - return getattr(spot_market_and_slot, 'data', None) + spot_market_and_slot = await self.account_subscriber.get_spot_market_and_slot( + market_index + ) + return getattr(spot_market_and_slot, "data", None) async def get_oracle_price_data(self, oracle: Pubkey) -> Optional[OraclePriceData]: - oracle_price_data_and_slot = await self.account_subscriber.get_oracle_data_and_slot(oracle) - return getattr(oracle_price_data_and_slot, 'data', None) + oracle_price_data_and_slot = ( + await self.account_subscriber.get_oracle_data_and_slot(oracle) + ) + return getattr(oracle_price_data_and_slot, "data", None) async def send_ixs( self, @@ -150,13 +167,15 @@ async def send_ixs( if self.tx_params.compute_units_price is not None: ixs.insert(1, set_compute_unit_price(self.tx_params.compute_units_price)) - latest_blockhash = (await self.program.provider.connection.get_latest_blockhash()).value.blockhash + latest_blockhash = ( + await self.program.provider.connection.get_latest_blockhash() + ).value.blockhash if self.tx_version == Legacy: tx = Transaction( instructions=ixs, recent_blockhash=latest_blockhash, - fee_payer=self.signer.pubkey() + fee_payer=self.signer.pubkey(), ) tx.sign_partial(self.signer) @@ -164,9 +183,7 @@ async def send_ixs( if signers is not None: [tx.sign_partial(signer) for signer in signers] elif self.tx_version == 0: - msg = MessageV0.try_compile( - self.signer.pubkey(), ixs, [], latest_blockhash - ) + msg = MessageV0.try_compile(self.signer.pubkey(), ixs, [], latest_blockhash) tx = VersionedTransaction(msg, [self.signer]) else: raise NotImplementedError("unknown tx version", self.tx_version) @@ -268,7 +285,9 @@ async def get_remaining_accounts( accounts = [] for pk, id in zip(authority, user_id): - user_public_key = get_user_account_public_key(self.program.program_id, pk, id) + user_public_key = get_user_account_public_key( + self.program.program_id, pk, id + ) user_account = await get_user_account(self.program, user_public_key) accounts.append(user_account) @@ -1528,9 +1547,7 @@ async def get_cancel_request_remove_insurance_fund_stake_ix( "insurance_fund_vault": get_insurance_fund_vault_public_key( self.program_id, spot_market_index ), - "drift_signer": get_drift_client_signer_public_key( - self.program_id - ), + "drift_signer": get_drift_client_signer_public_key(self.program_id), "user_token_account": self.spot_market_atas[spot_market_index], "token_program": TOKEN_PROGRAM_ID, }, @@ -1567,9 +1584,7 @@ async def get_remove_insurance_fund_stake_ix(self, spot_market_index: int): "insurance_fund_vault": get_insurance_fund_vault_public_key( self.program_id, spot_market_index ), - "drift_signer": get_drift_client_signer_public_key( - self.program_id - ), + "drift_signer": get_drift_client_signer_public_key(self.program_id), "user_token_account": self.spot_market_atas[spot_market_index], "token_program": TOKEN_PROGRAM_ID, }, @@ -1617,9 +1632,7 @@ async def get_add_insurance_fund_stake_ix( "insurance_fund_vault": get_insurance_fund_vault_public_key( self.program_id, spot_market_index ), - "drift_signer": get_drift_client_signer_public_key( - self.program_id - ), + "drift_signer": get_drift_client_signer_public_key(self.program_id), "user_token_account": self.spot_market_atas[spot_market_index], "token_program": TOKEN_PROGRAM_ID, }, @@ -1715,9 +1728,7 @@ async def settle_revenue_to_insurance_fund(self, spot_market_index: int): "spot_market_vault": get_spot_market_vault_public_key( self.program_id, spot_market_index ), - "drift_signer": get_drift_client_signer_public_key( - self.program_id - ), + "drift_signer": get_drift_client_signer_public_key(self.program_id), "insurance_fund_vault": get_insurance_fund_vault_public_key( self.program_id, spot_market_index, diff --git a/src/driftpy/drift_user.py b/src/driftpy/drift_user.py index c9934636..041a64f2 100644 --- a/src/driftpy/drift_user.py +++ b/src/driftpy/drift_user.py @@ -35,18 +35,25 @@ def __init__( self.connection = self.program.provider.connection self.subaccount_id = subaccount_id - self.user_public_key = get_user_account_public_key(self.program.program_id, self.authority, self.subaccount_id) + self.user_public_key = get_user_account_public_key( + self.program.program_id, self.authority, self.subaccount_id + ) if account_subscriber is None: - account_subscriber = CachedUserAccountSubscriber(self.user_public_key, self.program) + account_subscriber = CachedUserAccountSubscriber( + self.user_public_key, self.program + ) self.account_subscriber = account_subscriber - - async def get_spot_oracle_data(self, spot_market: SpotMarket) -> Optional[OraclePriceData]: + async def get_spot_oracle_data( + self, spot_market: SpotMarket + ) -> Optional[OraclePriceData]: return await self.drift_client.get_oracle_price_data(spot_market.oracle) - async def get_perp_oracle_data(self, perp_market: PerpMarket) -> Optional[OraclePriceData]: + async def get_perp_oracle_data( + self, perp_market: PerpMarket + ) -> Optional[OraclePriceData]: return await self.drift_client.get_oracle_price_data(perp_market.amm.oracle) async def get_state(self) -> State: @@ -356,11 +363,11 @@ async def get_spot_market_asset_value( if not include_open_orders: token_amount = get_token_amount( - position.scaled_balance, spot_market, position.balance_type - ) + position.scaled_balance, spot_market, position.balance_type + ) spot_token_value = get_spot_asset_value( - token_amount, oracle_data, spot_market, margin_category - ) + token_amount, oracle_data, spot_market, margin_category + ) match str(position.balance_type): case "SpotBalanceType.Deposit()": spot_token_value *= 1 diff --git a/src/driftpy/math/amm.py b/src/driftpy/math/amm.py index 7470143f..ccb6be43 100644 --- a/src/driftpy/math/amm.py +++ b/src/driftpy/math/amm.py @@ -169,7 +169,7 @@ def get_swap_direction( def calculate_budgeted_repeg(amm, cost, target_px=None, pay_only=False): - if target_px == None: + if target_px is None: target_px = amm.last_oracle_price # / 1e10 assert amm.last_oracle_price != 0 diff --git a/src/driftpy/math/market.py b/src/driftpy/math/market.py index c793bbe9..410d800d 100644 --- a/src/driftpy/math/market.py +++ b/src/driftpy/math/market.py @@ -90,7 +90,8 @@ def calculate_candidate_amm(market, oracle_price=None): base_scale = 1 quote_scale = 1 - budget_cost = None # max(0, (market.amm.total_fee_minus_distributions/1e6)/2) + # max(0, (market.amm.total_fee_minus_distributions/1e6)/2) + budget_cost = None fee_pool = (market.amm.total_fee_minus_distributions / QUOTE_PRECISION) - ( market.amm.total_fee / QUOTE_PRECISION ) / 2 diff --git a/src/driftpy/math/repeg.py b/src/driftpy/math/repeg.py index 1539017c..dbf8c4ce 100644 --- a/src/driftpy/math/repeg.py +++ b/src/driftpy/math/repeg.py @@ -1,3 +1,4 @@ +from driftpy.constants.numeric_constants import * from driftpy.math.amm import calculate_terminal_price, calculate_budgeted_repeg from driftpy.math.positions import calculate_base_asset_value, calculate_position_pnl from driftpy.types import PerpPosition @@ -216,10 +217,6 @@ def calculate_buyout_cost(market, market_index, new_peg, sqrt_k): return cost / 1e6, marketNewK -from driftpy.types import AMM -from driftpy.constants.numeric_constants import * - - def calculate_repeg_cost(amm: AMM, new_peg: int) -> int: dqar = amm.quote_asset_reserve - amm.terminal_quote_asset_reserve cost = ( diff --git a/src/driftpy/setup/helpers.py b/src/driftpy/setup/helpers.py index 6d9c07a4..b66402db 100644 --- a/src/driftpy/setup/helpers.py +++ b/src/driftpy/setup/helpers.py @@ -1,3 +1,4 @@ +from solana.rpc.async_api import AsyncClient from base64 import b64decode from dataclasses import dataclass from typing import Optional @@ -63,9 +64,7 @@ async def _airdrop_user( ) -> tuple[Keypair, Signature]: if user is None: user = Keypair() - resp = await provider.connection.request_airdrop( - user.pubkey(), 100_0 * 1000000000 - ) + resp = await provider.connection.request_airdrop(user.pubkey(), 100_0 * 1000000000) tx_sig = resp.value return user, tx_sig @@ -94,7 +93,9 @@ async def _create_mint(provider: Provider) -> Keypair: fake_tx = Transaction( instructions=[create_create_mint_account_ix, init_collateral_mint_ix], - recent_blockhash=(await provider.connection.get_latest_blockhash()).value.blockhash, + recent_blockhash=( + await provider.connection.get_latest_blockhash() + ).value.blockhash, fee_payer=provider.wallet.public_key, ) @@ -195,7 +196,9 @@ async def _create_and_mint_user_usdc( for ix in mint_tx.instructions: ata_tx.add(ix) - ata_tx.recent_blockhash = (await provider.connection.get_latest_blockhash()).value.blockhash + ata_tx.recent_blockhash = ( + await provider.connection.get_latest_blockhash() + ).value.blockhash ata_tx.fee_payer = provider.wallet.payer.pubkey() ata_tx.sign_partial(usdc_account) @@ -307,9 +310,6 @@ async def get_feed_data(oracle_program: Program, price_feed: Pubkey) -> PriceDat return parse_price_data(info_resp.value.data) -from solana.rpc.async_api import AsyncClient - - async def get_oracle_data( connection: AsyncClient, oracle_addr: Pubkey, diff --git a/src/driftpy/types.py b/src/driftpy/types.py index 8b037de7..210bb822 100644 --- a/src/driftpy/types.py +++ b/src/driftpy/types.py @@ -4,48 +4,57 @@ from sumtypes import constructor from typing import Optional + @_rust_enum class SwapDirection: ADD = constructor() REMOVE = constructor() - + + @_rust_enum class ModifyOrderId: USER_ORDER_ID = constructor() ORDER_ID = constructor() - + + @_rust_enum class PositionDirection: LONG = constructor() SHORT = constructor() - + + @_rust_enum class SpotFulfillmentType: SERUM_V3 = constructor() MATCH = constructor() PHOENIX_V1 = constructor() - + + @_rust_enum class SwapReduceOnly: IN = constructor() OUT = constructor() - + + @_rust_enum class TwapPeriod: FUNDING_PERIOD = constructor() FIVE_MIN = constructor() - + + @_rust_enum class LiquidationMultiplierType: DISCOUNT = constructor() PREMIUM = constructor() - + + @_rust_enum class MarginRequirementType: INITIAL = constructor() FILL = constructor() MAINTENANCE = constructor() - + + @_rust_enum class OracleValidity: INVALID = constructor() @@ -55,7 +64,8 @@ class OracleValidity: INSUFFICIENT_DATA_POINTS = constructor() STALE_FOR_A_M_M = constructor() VALID = constructor() - + + @_rust_enum class DriftAction: UPDATE_FUNDING = constructor() @@ -67,7 +77,8 @@ class DriftAction: MARGIN_CALC = constructor() UPDATE_TWAP = constructor() UPDATE_A_M_M_CURVE = constructor() - + + @_rust_enum class PositionUpdateType: OPEN = constructor() @@ -75,17 +86,20 @@ class PositionUpdateType: REDUCE = constructor() CLOSE = constructor() FLIP = constructor() - + + @_rust_enum class DepositExplanation: NONE = constructor() TRANSFER = constructor() - + + @_rust_enum class DepositDirection: DEPOSIT = constructor() WITHDRAW = constructor() - + + @_rust_enum class OrderAction: PLACE = constructor() @@ -93,7 +107,8 @@ class OrderAction: FILL = constructor() TRIGGER = constructor() EXPIRE = constructor() - + + @_rust_enum class OrderActionExplanation: NONE = constructor() @@ -114,13 +129,15 @@ class OrderActionExplanation: ORDER_FILL_WITH_PHOENIX = constructor() ORDER_FILLED_WITH_A_M_M_JIT_L_P_SPLIT = constructor() ORDER_FILLED_WITH_L_P_JIT = constructor() - + + @_rust_enum class LPAction: ADD_LIQUIDITY = constructor() REMOVE_LIQUIDITY = constructor() SETTLE_LIQUIDITY = constructor() - + + @_rust_enum class LiquidationType: LIQUIDATE_PERP = constructor() @@ -129,12 +146,14 @@ class LiquidationType: LIQUIDATE_PERP_PNL_FOR_DEPOSIT = constructor() PERP_BANKRUPTCY = constructor() SPOT_BANKRUPTCY = constructor() - + + @_rust_enum class SettlePnlExplanation: NONE = constructor() EXPIRED_POSITION = constructor() - + + @_rust_enum class StakeAction: STAKE = constructor() @@ -143,28 +162,33 @@ class StakeAction: UNSTAKE = constructor() UNSTAKE_TRANSFER = constructor() STAKE_TRANSFER = constructor() - + + @_rust_enum class FillMode: FILL = constructor() PLACE_AND_MAKE = constructor() PLACE_AND_TAKE = constructor() - + + @_rust_enum class PerpFulfillmentMethod: A_M_M = constructor() MATCH = constructor() - + + @_rust_enum class SpotFulfillmentMethod: EXTERNAL_MARKET = constructor() MATCH = constructor() - + + @_rust_enum class MarginCalculationMode: STANDARD = constructor() LIQUIDATION = constructor() - + + @_rust_enum class OracleSource: PYTH = constructor() @@ -173,19 +197,22 @@ class OracleSource: PYTH1_K = constructor() PYTH1_M = constructor() PYTH_STABLE_COIN = constructor() - + + @_rust_enum class PostOnlyParam: NONE = constructor() MUST_POST_ONLY = constructor() TRY_POST_ONLY = constructor() SLIDE = constructor() - + + @_rust_enum class ModifyOrderPolicy: TRY_MODIFY = constructor() MUST_MODIFY = constructor() - + + @_rust_enum class MarketStatus: INITIALIZED = constructor() @@ -197,12 +224,14 @@ class MarketStatus: REDUCE_ONLY = constructor() SETTLEMENT = constructor() DELISTED = constructor() - + + @_rust_enum class ContractType: PERPETUAL = constructor() FUTURE = constructor() - + + @_rust_enum class ContractTier: A = constructor() @@ -210,23 +239,27 @@ class ContractTier: C = constructor() SPECULATIVE = constructor() ISOLATED = constructor() - + + @_rust_enum class AMMLiquiditySplit: PROTOCOL_OWNED = constructor() L_P_OWNED = constructor() SHARED = constructor() - + + @_rust_enum class SpotBalanceType: DEPOSIT = constructor() BORROW = constructor() - + + @_rust_enum class SpotFulfillmentConfigStatus: ENABLED = constructor() DISABLED = constructor() - + + @_rust_enum class AssetTier: COLLATERAL = constructor() @@ -234,7 +267,8 @@ class AssetTier: CROSS = constructor() ISOLATED = constructor() UNLISTED = constructor() - + + @_rust_enum class ExchangeStatus: DEPOSIT_PAUSED = constructor() @@ -244,25 +278,29 @@ class ExchangeStatus: LIQ_PAUSED = constructor() FUNDING_PAUSED = constructor() SETTLE_PNL_PAUSED = constructor() - + + @_rust_enum class UserStatus: BEING_LIQUIDATED = constructor() BANKRUPT = constructor() REDUCE_ONLY = constructor() - + + @_rust_enum class AssetType: BASE = constructor() QUOTE = constructor() - + + @_rust_enum class OrderStatus: INIT = constructor() OPEN = constructor() FILLED = constructor() CANCELED = constructor() - + + @_rust_enum class OrderType: MARKET = constructor() @@ -270,24 +308,28 @@ class OrderType: TRIGGER_MARKET = constructor() TRIGGER_LIMIT = constructor() ORACLE = constructor() - + + @_rust_enum class OrderTriggerCondition: ABOVE = constructor() BELOW = constructor() TRIGGERED_ABOVE = constructor() TRIGGERED_BELOW = constructor() - + + @_rust_enum class MarketType: SPOT = constructor() PERP = constructor() - + + @dataclass class MarketIdentifier: market_type: MarketType market_index: int - + + @dataclass class OrderParams: order_type: OrderType @@ -307,7 +349,8 @@ class OrderParams: auction_duration: Optional[int] auction_start_price: Optional[int] auction_end_price: Optional[int] - + + @dataclass class ModifyOrderParams: direction: Optional[PositionDirection] @@ -324,7 +367,8 @@ class ModifyOrderParams: auction_start_price: Optional[int] auction_end_price: Optional[int] policy: Optional[ModifyOrderPolicy] - + + @dataclass class HistoricalOracleData: last_oracle_price: int @@ -333,13 +377,15 @@ class HistoricalOracleData: last_oracle_price_twap: int last_oracle_price_twap5min: int last_oracle_price_twap_ts: int - + + @dataclass class PoolBalance: scaled_balance: int market_index: int padding: list[int] - + + @dataclass class AMM: oracle: Pubkey @@ -423,24 +469,28 @@ class AMM: padding2: int total_fee_earned_per_lp: int padding: list[int] - + + @dataclass class PriceDivergenceGuardRails: mark_oracle_percent_divergence: int oracle_twap5min_percent_divergence: int - + + @dataclass class ValidityGuardRails: slots_before_stale_for_amm: int slots_before_stale_for_margin: int confidence_interval_max_size: int too_volatile_ratio: int - + + @dataclass class OracleGuardRails: price_divergence: PriceDivergenceGuardRails validity: ValidityGuardRails - + + @dataclass class FeeTier: fee_numerator: int @@ -451,20 +501,23 @@ class FeeTier: referrer_reward_denominator: int referee_fee_numerator: int referee_fee_denominator: int - + + @dataclass class OrderFillerRewardStructure: reward_numerator: int reward_denominator: int time_based_reward_lower_bound: int - + + @dataclass class FeeStructure: fee_tiers: list[FeeTier] filler_reward_structure: OrderFillerRewardStructure referrer_reward_epoch_upper_bound: int flat_filler_fee: int - + + @dataclass class SpotPosition: scaled_balance: int @@ -475,7 +528,8 @@ class SpotPosition: balance_type: SpotBalanceType open_orders: int padding: list[int] - + + @dataclass class Order: slot: int @@ -502,7 +556,8 @@ class Order: trigger_condition: OrderTriggerCondition auction_duration: int padding: list[int] - + + @dataclass class PhoenixV1FulfillmentConfig: pubkey: Pubkey @@ -515,7 +570,8 @@ class PhoenixV1FulfillmentConfig: fulfillment_type: SpotFulfillmentType status: SpotFulfillmentConfigStatus padding: list[int] - + + @dataclass class SerumV3FulfillmentConfig: pubkey: Pubkey @@ -533,7 +589,8 @@ class SerumV3FulfillmentConfig: fulfillment_type: SpotFulfillmentType status: SpotFulfillmentConfigStatus padding: list[int] - + + @dataclass class InsuranceClaim: revenue_withdraw_since_last_settle: int @@ -541,7 +598,8 @@ class InsuranceClaim: quote_max_insurance: int quote_settled_insurance: int last_revenue_withdraw_ts: int - + + @dataclass class PerpMarket: pubkey: Pubkey @@ -573,7 +631,8 @@ class PerpMarket: quote_spot_market_index: int fee_adjustment: int padding: list[int] - + + @dataclass class HistoricalIndexData: last_index_bid_price: int @@ -581,7 +640,8 @@ class HistoricalIndexData: last_index_price_twap: int last_index_price_twap5min: int last_index_price_twap_ts: int - + + @dataclass class InsuranceFund: vault: Pubkey @@ -593,7 +653,8 @@ class InsuranceFund: revenue_settle_period: int total_factor: int user_factor: int - + + @dataclass class SpotMarket: pubkey: Pubkey @@ -649,7 +710,8 @@ class SpotMarket: total_swap_fee: int scale_initial_asset_weight_start: int padding: list[int] - + + @dataclass class State: admin: Pubkey @@ -675,7 +737,8 @@ class State: liquidation_duration: int initial_pct_to_liquidate: int padding: list[int] - + + @dataclass class PerpPosition: last_cumulative_funding_rate: int @@ -693,7 +756,8 @@ class PerpPosition: market_index: int open_orders: int per_lp_base: int - + + @dataclass class User: authority: Pubkey @@ -723,7 +787,8 @@ class User: open_auctions: int has_open_auction: bool padding: list[int] - + + @dataclass class UserFees: total_fee_paid: int @@ -732,7 +797,8 @@ class UserFees: total_referee_discount: int total_referrer_reward: int current_epoch_referrer_reward: int - + + @dataclass class UserStats: authority: Pubkey @@ -751,7 +817,8 @@ class UserStats: is_referrer: bool disable_update_perp_bid_ask_twap: bool padding: list[int] - + + @dataclass class LiquidatePerpRecord: market_index: int @@ -764,7 +831,8 @@ class LiquidatePerpRecord: liquidator_order_id: int liquidator_fee: int if_fee: int - + + @dataclass class LiquidateSpotRecord: asset_market_index: int @@ -774,7 +842,8 @@ class LiquidateSpotRecord: liability_price: int liability_transfer: int if_fee: int - + + @dataclass class LiquidateBorrowForPerpPnlRecord: perp_market_index: int @@ -783,7 +852,8 @@ class LiquidateBorrowForPerpPnlRecord: liability_market_index: int liability_price: int liability_transfer: int - + + @dataclass class LiquidatePerpPnlForDepositRecord: perp_market_index: int @@ -792,7 +862,8 @@ class LiquidatePerpPnlForDepositRecord: asset_market_index: int asset_price: int asset_transfer: int - + + @dataclass class PerpBankruptcyRecord: market_index: int @@ -801,14 +872,16 @@ class PerpBankruptcyRecord: clawback_user: Optional[Pubkey] clawback_user_payment: Optional[int] cumulative_funding_rate_delta: int - + + @dataclass class SpotBankruptcyRecord: market_index: int borrow_amount: int if_payment: int cumulative_deposit_interest_delta: int - + + @dataclass class InsuranceFundStake: authority: Pubkey @@ -821,7 +894,8 @@ class InsuranceFundStake: cost_basis: int market_index: int padding: list[int] - + + @dataclass class ProtocolIfSharesTransferConfig: whitelisted_signers: list[Pubkey] @@ -829,7 +903,8 @@ class ProtocolIfSharesTransferConfig: current_epoch_transfer: int next_epoch_ts: int padding: list[int] - + + @dataclass class ReferrerName: authority: Pubkey @@ -837,6 +912,7 @@ class ReferrerName: user_stats: Pubkey name: list[int] + @dataclass class OraclePriceData: price: int @@ -846,7 +922,8 @@ class OraclePriceData: twap_confidence: int has_sufficient_number_of_datapoints: bool + @dataclass class TxParams: compute_units: Optional[int] - compute_units_price: Optional[int] \ No newline at end of file + compute_units_price: Optional[int] diff --git a/tests/test.py b/tests/test.py index 58c9bfee..da38453c 100644 --- a/tests/test.py +++ b/tests/test.py @@ -132,9 +132,7 @@ async def test_initialized_spot_market_2( @async_fixture(scope="session") -async def initialized_market( - drift_client: Admin, workspace: WorkspaceType -) -> Pubkey: +async def initialized_market(drift_client: Admin, workspace: WorkspaceType) -> Pubkey: pyth_program = workspace["pyth"] sol_usd = await mock_oracle(pyth_program=pyth_program, price=1) perp_market_index = 0 @@ -176,10 +174,10 @@ async def test_init_user( drift_client: Admin, ): await drift_client.intialize_user() - user_public_key = get_user_account_public_key(drift_client.program.program_id, drift_client.authority, 0) - user: User = await get_user_account( - drift_client.program, user_public_key + user_public_key = get_user_account_public_key( + drift_client.program.program_id, drift_client.authority, 0 ) + user: User = await get_user_account(drift_client.program, user_public_key) assert user.authority == drift_client.authority @@ -189,7 +187,7 @@ async def test_usdc_deposit( user_usdc_account: Keypair, ): usdc_spot_market = await get_spot_market_account(drift_client.program, 0) - assert(usdc_spot_market.market_index == 0) + assert usdc_spot_market.market_index == 0 drift_client.spot_market_atas[0] = user_usdc_account.pubkey() await drift_client.deposit( USDC_AMOUNT, 0, user_usdc_account.pubkey(), user_initialized=True @@ -311,21 +309,23 @@ async def test_stake_if( await drift_client.update_update_insurance_fund_unstaking_period(0, 0) await drift_client.initialize_insurance_fund_stake(0) - if_acc = await get_if_stake_account( - drift_client.program, drift_client.authority, 0 - ) + if_acc = await get_if_stake_account(drift_client.program, drift_client.authority, 0) assert if_acc.market_index == 0 await drift_client.add_insurance_fund_stake(0, 1 * QUOTE_PRECISION) - user_stats = await get_user_stats_account(drift_client.program, drift_client.authority) + user_stats = await get_user_stats_account( + drift_client.program, drift_client.authority + ) assert user_stats.if_staked_quote_asset_amount == 1 * QUOTE_PRECISION await drift_client.request_remove_insurance_fund_stake(0, 1 * QUOTE_PRECISION) await drift_client.remove_insurance_fund_stake(0) - user_stats = await get_user_stats_account(drift_client.program, drift_client.authority) + user_stats = await get_user_stats_account( + drift_client.program, drift_client.authority + ) assert user_stats.if_staked_quote_asset_amount == 0 From 2486f8268d27c24235a10f57b76f7e61be3d0886 Mon Sep 17 00:00:00 2001 From: Chris Heaney Date: Fri, 17 Nov 2023 16:16:36 -0500 Subject: [PATCH 5/7] ws init --- src/driftpy/accounts/ws/__init__.py | 1 + src/driftpy/accounts/ws/user.py | 71 +++++++++++++++++++++++++++++ src/driftpy/drift_client.py | 7 ++- 3 files changed, 75 insertions(+), 4 deletions(-) create mode 100644 src/driftpy/accounts/ws/__init__.py create mode 100644 src/driftpy/accounts/ws/user.py diff --git a/src/driftpy/accounts/ws/__init__.py b/src/driftpy/accounts/ws/__init__.py new file mode 100644 index 00000000..82da278c --- /dev/null +++ b/src/driftpy/accounts/ws/__init__.py @@ -0,0 +1 @@ +from .user import * \ No newline at end of file diff --git a/src/driftpy/accounts/ws/user.py b/src/driftpy/accounts/ws/user.py new file mode 100644 index 00000000..c2baa1d3 --- /dev/null +++ b/src/driftpy/accounts/ws/user.py @@ -0,0 +1,71 @@ +import asyncio +from typing import Optional + +from anchorpy import Program +from solders.pubkey import Pubkey +from solana.rpc.commitment import Commitment + +from driftpy.accounts import get_user_account_and_slot +from driftpy.accounts import UserAccountSubscriber, DataAndSlot +from driftpy.types import User + +import websockets +import websockets.exceptions # force eager imports +from solana.rpc.websocket_api import connect + +from typing import cast + + +class WebsocketUserAccountSubscriber(UserAccountSubscriber): + def __init__( + self, + user_pubkey: Pubkey, + program: Program, + commitment: Commitment = "confirmed", + ): + self.program = program + self.commitment = commitment + self.user_pubkey = user_pubkey + self.user_and_slot = None + + self.task = None + self.ws = None + self.subscription_id = None + + async def subscribe(self): + await self._subscribe() + + async def _subscribe(self): + print('here9') + ws_endpoint = self.program.provider.connection._provider.endpoint_uri.replace("https", "wss") + async for ws in connect(ws_endpoint): + try: + await ws.account_subscribe(# type: ignore + self.user_pubkey, + commitment=self.commitment, + encoding="base64", + ) + first_resp = await ws.recv() + subscription_id = cast(int, first_resp[0].result) # type: ignore + print(f"Subscription id: {subscription_id}") + async for msg in ws: + try: + slot = int(msg[0].result.context.slot) # type: ignore + account_bytes = cast(bytes, msg[0].result.value.data) # type: ignore + decoded_data = self.program.coder.accounts.decode(account_bytes) + self.user_and_slot = DataAndSlot(slot, decoded_data) + print("here") + except Exception: + print(f"Error processing account data") + break + await ws.account_unsubscribe(subscription_id) # type: ignore + except websockets.exceptions.ConnectionClosed: + print("Websocket closed, reconnecting...") + continue + + async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]: + return self.user_and_slot + + async def unsubscribe(self): + self.task.cancel() + self.task = None diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index d96e232e..af5657d8 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -100,8 +100,8 @@ def from_config(config: Config, provider: Provider, authority: Keypair = None): file = Path(str(driftpy.__path__[0]) + "/idl/drift.json") print(file) with file.open() as f: - idl_dict = json.load(f) - idl = Idl.from_json(idl_dict) + raw = file.read_text() + idl = Idl.from_json(raw) # create the program program = Program( @@ -110,8 +110,7 @@ def from_config(config: Config, provider: Provider, authority: Keypair = None): provider, ) - drift_client = DriftClient - (program, authority) + drift_client = DriftClient(program, authority) drift_client.config = config drift_client.idl = idl From cef1b112f76823fe98f66c4b19d0da60ff82efa5 Mon Sep 17 00:00:00 2001 From: Chris Heaney Date: Fri, 17 Nov 2023 19:58:45 -0500 Subject: [PATCH 6/7] user and drift client ws subscription --- src/driftpy/accounts/cache/drift_client.py | 2 +- src/driftpy/accounts/get_accounts.py | 11 +- src/driftpy/accounts/oracle.py | 49 ++++---- src/driftpy/accounts/types.py | 2 +- src/driftpy/accounts/ws/__init__.py | 3 +- src/driftpy/accounts/ws/account_subscriber.py | 92 +++++++++++++++ src/driftpy/accounts/ws/drift_client.py | 105 ++++++++++++++++++ src/driftpy/accounts/ws/user.py | 71 ++---------- src/driftpy/drift_user.py | 4 +- 9 files changed, 240 insertions(+), 99 deletions(-) create mode 100644 src/driftpy/accounts/ws/account_subscriber.py create mode 100644 src/driftpy/accounts/ws/drift_client.py diff --git a/src/driftpy/accounts/cache/drift_client.py b/src/driftpy/accounts/cache/drift_client.py index 29c20eb9..ee05003a 100644 --- a/src/driftpy/accounts/cache/drift_client.py +++ b/src/driftpy/accounts/cache/drift_client.py @@ -83,7 +83,7 @@ async def get_spot_market_and_slot( await self.cache_if_needed() return self.cache["spot_markets"][market_index] - async def get_oracle_data_and_slot( + async def get_oracle_price_data_and_slot( self, oracle: Pubkey ) -> Optional[DataAndSlot[OraclePriceData]]: await self.cache_if_needed() diff --git a/src/driftpy/accounts/get_accounts.py b/src/driftpy/accounts/get_accounts.py index 31fc8837..812fb687 100644 --- a/src/driftpy/accounts/get_accounts.py +++ b/src/driftpy/accounts/get_accounts.py @@ -1,5 +1,5 @@ import base64 -from typing import cast +from typing import cast, Optional, Callable from solders.pubkey import Pubkey from anchorpy import Program, ProgramAccount from solana.rpc.commitment import Commitment @@ -10,7 +10,10 @@ async def get_account_data_and_slot( - address: Pubkey, program: Program, commitment: Commitment = "processed" + address: Pubkey, + program: Program, + commitment: Commitment = "processed", + decode: Optional[Callable[[bytes], T]] = None, ) -> Optional[DataAndSlot[T]]: account_info = await program.provider.connection.get_account_info( address, @@ -24,7 +27,9 @@ async def get_account_data_and_slot( slot = account_info.context.slot data = account_info.value.data - decoded_data = program.coder.accounts.decode(data) + decoded_data = ( + decode(data) if decode is not None else program.coder.accounts.decode(data) + ) return DataAndSlot(slot, decoded_data) diff --git a/src/driftpy/accounts/oracle.py b/src/driftpy/accounts/oracle.py index 5e2a28c8..149287e9 100644 --- a/src/driftpy/accounts/oracle.py +++ b/src/driftpy/accounts/oracle.py @@ -21,26 +21,12 @@ async def get_oracle_price_data_and_slot( if "Pyth" in str(oracle_source): rpc_reponse = await connection.get_account_info(address) rpc_response_slot = rpc_reponse.context.slot - (pyth_price_info, last_slot, twac, twap) = await _parse_pyth_price_info( - rpc_reponse - ) - scale = 1 - if "1K" in str(oracle_source): - scale = 1e3 - elif "1M" in str(oracle_source): - scale = 1e6 - - oracle_data = OraclePriceData( - price=convert_pyth_price(pyth_price_info.price, scale), - slot=pyth_price_info.pub_slot, - confidence=convert_pyth_price(pyth_price_info.confidence_interval, scale), - twap=convert_pyth_price(twap, scale), - twap_confidence=convert_pyth_price(twac, scale), - has_sufficient_number_of_datapoints=True, + oracle_price_data = decode_pyth_price_info( + rpc_reponse.value.data, oracle_source ) - return DataAndSlot(data=oracle_data, slot=rpc_response_slot) + return DataAndSlot(data=oracle_price_data, slot=rpc_response_slot) elif "Quote" in str(oracle_source): return DataAndSlot( data=OraclePriceData(PRICE_PRECISION, 0, 1, 1, 0, True), slot=0 @@ -49,11 +35,10 @@ async def get_oracle_price_data_and_slot( raise NotImplementedError("Unsupported Oracle Source", str(oracle_source)) -async def _parse_pyth_price_info( - resp: GetAccountInfoResp, -) -> (PythPriceInfo, int, int, int): - buffer = resp.value.data - +def decode_pyth_price_info( + buffer: bytes, + oracle_source=OracleSource.PYTH(), +) -> OraclePriceData: offset = _ACCOUNT_HEADER_BYTES _, exponent, _ = struct.unpack_from(" Optional[DataAndSlot[OraclePriceData]]: pass diff --git a/src/driftpy/accounts/ws/__init__.py b/src/driftpy/accounts/ws/__init__.py index 82da278c..58e298b4 100644 --- a/src/driftpy/accounts/ws/__init__.py +++ b/src/driftpy/accounts/ws/__init__.py @@ -1 +1,2 @@ -from .user import * \ No newline at end of file +from .drift_client import * +from .user import * diff --git a/src/driftpy/accounts/ws/account_subscriber.py b/src/driftpy/accounts/ws/account_subscriber.py new file mode 100644 index 00000000..d8862f48 --- /dev/null +++ b/src/driftpy/accounts/ws/account_subscriber.py @@ -0,0 +1,92 @@ +import asyncio +from typing import Optional + +from anchorpy import Program +from solders.pubkey import Pubkey +from solana.rpc.commitment import Commitment + +from driftpy.accounts import get_account_data_and_slot +from driftpy.accounts import UserAccountSubscriber, DataAndSlot + +import websockets +import websockets.exceptions # force eager imports +from solana.rpc.websocket_api import connect + +from typing import cast, Generic, TypeVar, Callable + +T = TypeVar("T") + + +class WebsocketAccountSubscriber(UserAccountSubscriber, Generic[T]): + def __init__( + self, + pubkey: Pubkey, + program: Program, + commitment: Commitment = "confirmed", + decode: Optional[Callable[[bytes], T]] = None, + ): + self.program = program + self.commitment = commitment + self.pubkey = pubkey + self.data_and_slot = None + self.task = None + self.decode = ( + decode if decode is not None else self.program.coder.accounts.decode + ) + + async def subscribe(self): + if self.data_and_slot is None: + await self.fetch() + + self.task = asyncio.create_task(self.subscribe_ws()) + return self.task + + async def subscribe_ws(self): + ws_endpoint = self.program.provider.connection._provider.endpoint_uri.replace( + "https", "wss" + ).replace("http", "ws") + async for ws in connect(ws_endpoint): + try: + await ws.account_subscribe( # type: ignore + self.pubkey, + commitment=self.commitment, + encoding="base64", + ) + first_resp = await ws.recv() + subscription_id = cast(int, first_resp[0].result) # type: ignore + + async for msg in ws: + try: + slot = int(msg[0].result.context.slot) # type: ignore + + if msg[0].result.value is None: + continue + + account_bytes = cast(bytes, msg[0].result.value.data) # type: ignore + decoded_data = self.decode(account_bytes) + self._update_data(DataAndSlot(slot, decoded_data)) + except Exception: + print(f"Error processing account data") + break + await ws.account_unsubscribe(subscription_id) # type: ignore + except websockets.exceptions.ConnectionClosed: + print("Websocket closed, reconnecting...") + continue + + async def fetch(self): + new_data = await get_account_data_and_slot( + self.pubkey, self.program, self.commitment, self.decode + ) + + self._update_data(new_data) + + def _update_data(self, new_data: Optional[DataAndSlot[T]]): + if new_data is None: + return + + if self.data_and_slot is None or new_data.slot > self.data_and_slot.slot: + self.data_and_slot = new_data + + def unsubscribe(self): + self.task.cancel() + self.task = None diff --git a/src/driftpy/accounts/ws/drift_client.py b/src/driftpy/accounts/ws/drift_client.py new file mode 100644 index 00000000..ee5fc140 --- /dev/null +++ b/src/driftpy/accounts/ws/drift_client.py @@ -0,0 +1,105 @@ +from anchorpy import Program +from solders.pubkey import Pubkey +from solana.rpc.commitment import Commitment + +from driftpy.accounts import ( + get_state_account_and_slot, + get_spot_market_account_and_slot, + get_perp_market_account_and_slot, +) +from driftpy.accounts.oracle import get_oracle_price_data_and_slot +from driftpy.accounts.types import DriftClientAccountSubscriber, DataAndSlot +from typing import Optional + +from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber +from driftpy.types import PerpMarket, SpotMarket, OraclePriceData, State + +from driftpy.addresses import * + +from driftpy.types import OracleSource + +from driftpy.accounts.oracle import decode_pyth_price_info + + +class WebsocketDriftClientAccountSubscriber: + def __init__(self, program: Program, commitment: Commitment = "confirmed"): + self.program = program + self.commitment = commitment + self.state_subscriber = None + self.spot_market_subscribers = [] + self.perp_market_subscribers = [] + self.oracle_subscribers = {} + + async def subscribe(self): + state_public_key = get_state_public_key(self.program.program_id) + self.state_subscriber = WebsocketAccountSubscriber[State]( + state_public_key, self.program, self.commitment + ) + await self.state_subscriber.subscribe() + + for i in range(self.state_subscriber.data_and_slot.data.number_of_spot_markets): + spot_market_public_key = get_spot_market_public_key( + self.program.program_id, i + ) + spot_market_subscriber = WebsocketAccountSubscriber[SpotMarket]( + spot_market_public_key, self.program, self.commitment + ) + await spot_market_subscriber.subscribe() + self.spot_market_subscribers.append(spot_market_subscriber) + + spot_market = spot_market_subscriber.data_and_slot.data + oracle = spot_market.oracle + if oracle != Pubkey.default(): + oracle_subscriber = WebsocketAccountSubscriber[OraclePriceData]( + oracle, + self.program, + self.commitment, + self._get_oracle_decode_fn(spot_market.oracle_source), + ) + await oracle_subscriber.subscribe() + self.oracle_subscribers[str(oracle)] = oracle_subscriber + + for i in range(self.state_subscriber.data_and_slot.data.number_of_markets): + perp_market_public_key = get_perp_market_public_key( + self.program.program_id, i + ) + perp_market_subscriber = WebsocketAccountSubscriber[PerpMarket]( + perp_market_public_key, self.program, self.commitment + ) + await perp_market_subscriber.subscribe() + self.perp_market_subscribers.append(perp_market_subscriber) + + perp_market = perp_market_subscriber.data_and_slot.data + oracle = perp_market.amm.oracle + oracle_subscriber = WebsocketAccountSubscriber[OraclePriceData]( + oracle, + self.program, + self.commitment, + self._get_oracle_decode_fn(perp_market.amm.oracle_source), + ) + await oracle_subscriber.subscribe() + self.oracle_subscribers[str(oracle)] = oracle_subscriber + + def _get_oracle_decode_fn(self, oracle_source: OracleSource): + if "Pyth" in str(oracle_source): + return lambda data: decode_pyth_price_info(data, oracle_source) + else: + raise Exception("Unknown oracle source") + + async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]: + return self.state_subscriber.data_and_slot + + async def get_perp_market_and_slot( + self, market_index: int + ) -> Optional[DataAndSlot[PerpMarket]]: + return self.perp_market_subscribers[market_index].data_and_slot + + async def get_spot_market_and_slot( + self, market_index: int + ) -> Optional[DataAndSlot[SpotMarket]]: + return self.spot_market_subscribers[market_index].data_and_slot + + async def get_oracle_price_data_and_slot( + self, oracle: Pubkey + ) -> Optional[DataAndSlot[OraclePriceData]]: + return self.oracle_subscribers[str(oracle)].data_and_slot diff --git a/src/driftpy/accounts/ws/user.py b/src/driftpy/accounts/ws/user.py index c2baa1d3..425958ed 100644 --- a/src/driftpy/accounts/ws/user.py +++ b/src/driftpy/accounts/ws/user.py @@ -1,71 +1,14 @@ -import asyncio from typing import Optional -from anchorpy import Program -from solders.pubkey import Pubkey -from solana.rpc.commitment import Commitment - -from driftpy.accounts import get_user_account_and_slot -from driftpy.accounts import UserAccountSubscriber, DataAndSlot +from driftpy.accounts import DataAndSlot from driftpy.types import User -import websockets -import websockets.exceptions # force eager imports -from solana.rpc.websocket_api import connect - -from typing import cast - - -class WebsocketUserAccountSubscriber(UserAccountSubscriber): - def __init__( - self, - user_pubkey: Pubkey, - program: Program, - commitment: Commitment = "confirmed", - ): - self.program = program - self.commitment = commitment - self.user_pubkey = user_pubkey - self.user_and_slot = None +from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber +from driftpy.accounts.types import UserAccountSubscriber - self.task = None - self.ws = None - self.subscription_id = None - - async def subscribe(self): - await self._subscribe() - - async def _subscribe(self): - print('here9') - ws_endpoint = self.program.provider.connection._provider.endpoint_uri.replace("https", "wss") - async for ws in connect(ws_endpoint): - try: - await ws.account_subscribe(# type: ignore - self.user_pubkey, - commitment=self.commitment, - encoding="base64", - ) - first_resp = await ws.recv() - subscription_id = cast(int, first_resp[0].result) # type: ignore - print(f"Subscription id: {subscription_id}") - async for msg in ws: - try: - slot = int(msg[0].result.context.slot) # type: ignore - account_bytes = cast(bytes, msg[0].result.value.data) # type: ignore - decoded_data = self.program.coder.accounts.decode(account_bytes) - self.user_and_slot = DataAndSlot(slot, decoded_data) - print("here") - except Exception: - print(f"Error processing account data") - break - await ws.account_unsubscribe(subscription_id) # type: ignore - except websockets.exceptions.ConnectionClosed: - print("Websocket closed, reconnecting...") - continue +class WebsocketUserAccountSubscriber( + WebsocketAccountSubscriber[User], UserAccountSubscriber +): async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]: - return self.user_and_slot - - async def unsubscribe(self): - self.task.cancel() - self.task = None + return self.data_and_slot diff --git a/src/driftpy/drift_user.py b/src/driftpy/drift_user.py index 041a64f2..badadb5a 100644 --- a/src/driftpy/drift_user.py +++ b/src/driftpy/drift_user.py @@ -1,5 +1,5 @@ from driftpy.accounts import UserAccountSubscriber -from driftpy.accounts.cache import CachedUserAccountSubscriber +from driftpy.accounts.cache import WebsocketUserAccountSubscriber from driftpy.drift_client import DriftClient from driftpy.math.positions import * from driftpy.math.margin import * @@ -40,7 +40,7 @@ def __init__( ) if account_subscriber is None: - account_subscriber = CachedUserAccountSubscriber( + account_subscriber = WebsocketUserAccountSubscriber( self.user_public_key, self.program ) From a61579d720d2b80d6600d382fda8c9ce153328ab Mon Sep 17 00:00:00 2001 From: Chris Heaney Date: Fri, 17 Nov 2023 20:58:04 -0500 Subject: [PATCH 7/7] tests working --- src/driftpy/accounts/cache/drift_client.py | 6 ++ src/driftpy/accounts/cache/user.py | 6 ++ src/driftpy/accounts/types.py | 8 ++ src/driftpy/accounts/ws/drift_client.py | 112 ++++++++++++--------- src/driftpy/admin.py | 8 +- src/driftpy/drift_client.py | 43 ++++---- src/driftpy/drift_user.py | 19 ++-- tests/test.py | 17 +++- 8 files changed, 137 insertions(+), 82 deletions(-) diff --git a/src/driftpy/accounts/cache/drift_client.py b/src/driftpy/accounts/cache/drift_client.py index ee05003a..75905249 100644 --- a/src/driftpy/accounts/cache/drift_client.py +++ b/src/driftpy/accounts/cache/drift_client.py @@ -20,6 +20,9 @@ def __init__(self, program: Program, commitment: Commitment = "confirmed"): self.commitment = commitment self.cache = None + async def subscribe(self): + await self.cache_if_needed() + async def update_cache(self): if self.cache is None: self.cache = {} @@ -92,3 +95,6 @@ async def get_oracle_price_data_and_slot( async def cache_if_needed(self): if self.cache is None: await self.update_cache() + + def unsubscribe(self): + self.cache = None diff --git a/src/driftpy/accounts/cache/user.py b/src/driftpy/accounts/cache/user.py index 6c79f3ab..068cdbe7 100644 --- a/src/driftpy/accounts/cache/user.py +++ b/src/driftpy/accounts/cache/user.py @@ -21,6 +21,9 @@ def __init__( self.user_pubkey = user_pubkey self.user_and_slot = None + async def subscribe(self): + await self.cache_if_needed() + async def update_cache(self): user_and_slot = await get_user_account_and_slot(self.program, self.user_pubkey) self.user_and_slot = user_and_slot @@ -32,3 +35,6 @@ async def get_user_account_and_slot(self) -> Optional[DataAndSlot[User]]: async def cache_if_needed(self): if self.user_and_slot is None: await self.update_cache() + + def unsubscribe(self): + self.user_and_slot = None diff --git a/src/driftpy/accounts/types.py b/src/driftpy/accounts/types.py index 598c3c9e..9f7f4d6b 100644 --- a/src/driftpy/accounts/types.py +++ b/src/driftpy/accounts/types.py @@ -23,6 +23,14 @@ class DataAndSlot(Generic[T]): class DriftClientAccountSubscriber: + @abstractmethod + async def subscribe(self): + pass + + @abstractmethod + def unsubscribe(self): + pass + @abstractmethod async def get_state_account_and_slot(self) -> Optional[DataAndSlot[State]]: pass diff --git a/src/driftpy/accounts/ws/drift_client.py b/src/driftpy/accounts/ws/drift_client.py index ee5fc140..2eaa5572 100644 --- a/src/driftpy/accounts/ws/drift_client.py +++ b/src/driftpy/accounts/ws/drift_client.py @@ -2,12 +2,6 @@ from solders.pubkey import Pubkey from solana.rpc.commitment import Commitment -from driftpy.accounts import ( - get_state_account_and_slot, - get_spot_market_account_and_slot, - get_perp_market_account_and_slot, -) -from driftpy.accounts.oracle import get_oracle_price_data_and_slot from driftpy.accounts.types import DriftClientAccountSubscriber, DataAndSlot from typing import Optional @@ -21,13 +15,13 @@ from driftpy.accounts.oracle import decode_pyth_price_info -class WebsocketDriftClientAccountSubscriber: +class WebsocketDriftClientAccountSubscriber(DriftClientAccountSubscriber): def __init__(self, program: Program, commitment: Commitment = "confirmed"): self.program = program self.commitment = commitment self.state_subscriber = None - self.spot_market_subscribers = [] - self.perp_market_subscribers = [] + self.spot_market_subscribers = {} + self.perp_market_subscribers = {} self.oracle_subscribers = {} async def subscribe(self): @@ -38,47 +32,60 @@ async def subscribe(self): await self.state_subscriber.subscribe() for i in range(self.state_subscriber.data_and_slot.data.number_of_spot_markets): - spot_market_public_key = get_spot_market_public_key( - self.program.program_id, i - ) - spot_market_subscriber = WebsocketAccountSubscriber[SpotMarket]( - spot_market_public_key, self.program, self.commitment - ) - await spot_market_subscriber.subscribe() - self.spot_market_subscribers.append(spot_market_subscriber) - - spot_market = spot_market_subscriber.data_and_slot.data - oracle = spot_market.oracle - if oracle != Pubkey.default(): - oracle_subscriber = WebsocketAccountSubscriber[OraclePriceData]( - oracle, - self.program, - self.commitment, - self._get_oracle_decode_fn(spot_market.oracle_source), - ) - await oracle_subscriber.subscribe() - self.oracle_subscribers[str(oracle)] = oracle_subscriber + await self.subscribe_to_spot_market(i) for i in range(self.state_subscriber.data_and_slot.data.number_of_markets): - perp_market_public_key = get_perp_market_public_key( - self.program.program_id, i - ) - perp_market_subscriber = WebsocketAccountSubscriber[PerpMarket]( - perp_market_public_key, self.program, self.commitment - ) - await perp_market_subscriber.subscribe() - self.perp_market_subscribers.append(perp_market_subscriber) - - perp_market = perp_market_subscriber.data_and_slot.data - oracle = perp_market.amm.oracle - oracle_subscriber = WebsocketAccountSubscriber[OraclePriceData]( - oracle, - self.program, - self.commitment, - self._get_oracle_decode_fn(perp_market.amm.oracle_source), - ) - await oracle_subscriber.subscribe() - self.oracle_subscribers[str(oracle)] = oracle_subscriber + await self.subscribe_to_perp_market(i) + + async def subscribe_to_spot_market(self, market_index: int): + if market_index in self.spot_market_subscribers: + return + + spot_market_public_key = get_spot_market_public_key( + self.program.program_id, market_index + ) + spot_market_subscriber = WebsocketAccountSubscriber[SpotMarket]( + spot_market_public_key, self.program, self.commitment + ) + await spot_market_subscriber.subscribe() + self.spot_market_subscribers[market_index] = spot_market_subscriber + + spot_market = spot_market_subscriber.data_and_slot.data + await self.subscribe_to_oracle(spot_market.oracle, spot_market.oracle_source) + + async def subscribe_to_perp_market(self, market_index: int): + if market_index in self.perp_market_subscribers: + return + + perp_market_public_key = get_perp_market_public_key( + self.program.program_id, market_index + ) + perp_market_subscriber = WebsocketAccountSubscriber[PerpMarket]( + perp_market_public_key, self.program, self.commitment + ) + await perp_market_subscriber.subscribe() + self.perp_market_subscribers[market_index] = perp_market_subscriber + + perp_market = perp_market_subscriber.data_and_slot.data + await self.subscribe_to_oracle( + perp_market.amm.oracle, perp_market.amm.oracle_source + ) + + async def subscribe_to_oracle(self, oracle: Pubkey, oracle_source: OracleSource): + if oracle == Pubkey.default(): + return + + if str(oracle) in self.oracle_subscribers: + return + + oracle_subscriber = WebsocketAccountSubscriber[OraclePriceData]( + oracle, + self.program, + self.commitment, + self._get_oracle_decode_fn(oracle_source), + ) + await oracle_subscriber.subscribe() + self.oracle_subscribers[str(oracle)] = oracle_subscriber def _get_oracle_decode_fn(self, oracle_source: OracleSource): if "Pyth" in str(oracle_source): @@ -103,3 +110,12 @@ async def get_oracle_price_data_and_slot( self, oracle: Pubkey ) -> Optional[DataAndSlot[OraclePriceData]]: return self.oracle_subscribers[str(oracle)].data_and_slot + + def unsubscribe(self): + self.state_subscriber.unsubscribe() + for spot_market_subscriber in self.spot_market_subscribers.values(): + spot_market_subscriber.unsubscribe() + for perp_market_subscriber in self.perp_market_subscribers.values(): + perp_market_subscriber.unsubscribe() + for oracle_subscriber in self.oracle_subscribers.values(): + oracle_subscriber.unsubscribe() diff --git a/src/driftpy/admin.py b/src/driftpy/admin.py index ddd4bb4e..0811bbe2 100644 --- a/src/driftpy/admin.py +++ b/src/driftpy/admin.py @@ -148,14 +148,14 @@ async def initialize_spot_market( ): state_public_key = get_state_public_key(self.program_id) state = await get_state_account(self.program) - spot_index = state.number_of_spot_markets + spot_market_index = state.number_of_spot_markets - spot_public_key = get_spot_market_public_key(self.program_id, spot_index) + spot_public_key = get_spot_market_public_key(self.program_id, spot_market_index) spot_vault_public_key = get_spot_market_vault_public_key( - self.program_id, spot_index + self.program_id, spot_market_index ) insurance_vault_public_key = get_insurance_fund_vault_public_key( - self.program_id, spot_index + self.program_id, spot_market_index ) return await self.program.rpc["initialize_spot_market"]( diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 1b00cf59..d9f6d432 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -29,7 +29,7 @@ from driftpy.math.positions import is_available, is_spot_position_available from driftpy.accounts import DriftClientAccountSubscriber -from driftpy.accounts.cache import CachedDriftClientAccountSubscriber +from driftpy.accounts.ws import WebsocketDriftClientAccountSubscriber DEFAULT_USER_NAME = "Main Account" @@ -72,7 +72,7 @@ def __init__( self.subaccounts = [0] if account_subscriber is None: - account_subscriber = CachedDriftClientAccountSubscriber(self.program) + account_subscriber = WebsocketDriftClientAccountSubscriber(self.program) self.account_subscriber = account_subscriber @@ -116,6 +116,12 @@ def from_config(config: Config, provider: Provider, authority: Keypair = None): return drift_client + async def subscribe(self): + await self.account_subscriber.subscribe() + + def unsubscribe(self): + self.account_subscriber.unsubscribe() + def get_user_account_public_key(self, user_id=0) -> Pubkey: return get_user_account_public_key(self.program_id, self.authority, user_id) @@ -787,11 +793,10 @@ async def place_perp_order( user_id: int = 0, ): return await self.send_ixs( - [ + [ self.get_increase_compute_ix(), - (await self.get_place_perp_order_ix(order_params, user_id))[-1] + (await self.get_place_perp_order_ix(order_params, user_id))[-1], ] - ) async def get_place_perp_order_ix( @@ -805,24 +810,21 @@ async def get_place_perp_order_ix( ) ix = self.program.instruction["place_perp_order"]( - order_params, - ctx=Context( - accounts={ - "state": self.get_state_public_key(), - "user": user_account_public_key, - "authority": self.signer.pubkey(), - }, - remaining_accounts=remaining_accounts, - ), - ) + order_params, + ctx=Context( + accounts={ + "state": self.get_state_public_key(), + "user": user_account_public_key, + "authority": self.signer.pubkey(), + }, + remaining_accounts=remaining_accounts, + ), + ) return ix async def get_place_perp_orders_ix( - self, - order_params: List[OrderParams], - user_id: int = 0, - cancel_all=True + self, order_params: List[OrderParams], user_id: int = 0, cancel_all=True ): user_account_public_key = self.get_user_account_public_key(user_id) writeable_market_indexes = list(set([x.market_index for x in order_params])) @@ -844,7 +846,8 @@ async def get_place_perp_orders_ix( }, remaining_accounts=remaining_accounts, ), - )) + ) + ) for order_param in order_params: ix = self.program.instruction["place_perp_order"]( order_param, diff --git a/src/driftpy/drift_user.py b/src/driftpy/drift_user.py index b9762d87..7a854ef3 100644 --- a/src/driftpy/drift_user.py +++ b/src/driftpy/drift_user.py @@ -46,6 +46,12 @@ def __init__( self.account_subscriber = account_subscriber + async def subscribe(self): + await self.account_subscriber.subscribe() + + def unsubscribe(self): + self.account_subscriber.unsubscribe() + async def get_spot_oracle_data( self, spot_market: SpotMarket ) -> Optional[OraclePriceData]: @@ -68,16 +74,15 @@ async def get_perp_market(self, market_index: int) -> PerpMarket: async def get_user(self) -> User: return (await self.account_subscriber.get_user_account_and_slot()).data - - async def get_open_orders(self, - # market_type: MarketType, - # market_index: int, - # position_direction: PositionDirection - ): + async def get_open_orders( + self, + # market_type: MarketType, + # market_index: int, + # position_direction: PositionDirection + ): user: User = await self.get_user() return user.orders - async def get_spot_market_liability( self, market_index=None, diff --git a/tests/test.py b/tests/test.py index 4fea2c76..1a5c91ef 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,3 +1,5 @@ +import asyncio + from pytest import fixture, mark from pytest_asyncio import fixture as async_fixture from solders.keypair import Keypair @@ -12,6 +14,7 @@ SPOT_BALANCE_PRECISION, SPOT_WEIGHT_PRECISION, ) +from driftpy.accounts.cache import CachedUserAccountSubscriber, CachedDriftClientAccountSubscriber from math import sqrt from driftpy.drift_user import DriftUser @@ -83,8 +86,9 @@ def provider(program: Program) -> Provider: @async_fixture(scope="session") async def drift_client(program: Program, usdc_mint: Keypair) -> Admin: - admin = Admin(program) + admin = Admin(program, account_subscriber=CachedDriftClientAccountSubscriber(program)) await admin.initialize(usdc_mint.pubkey(), admin_controls_prices=True) + await admin.subscribe() return admin @@ -130,6 +134,8 @@ async def test_initialized_spot_market_2( maintenance_liability_weight=main_liab_weight, ) + await drift_client.account_subscriber.update_cache() + spot_market = await get_spot_market_account(admin_drift_client.program, 1) assert spot_market.market_index == 1 print(spot_market.market_index) @@ -148,6 +154,9 @@ async def initialized_market(drift_client: Admin, workspace: WorkspaceType) -> P PERIODICITY, ) + + await drift_client.account_subscriber.update_cache() + return sol_usd @@ -207,7 +216,8 @@ async def test_open_orders( drift_client: Admin, ): - drift_user = DriftUser(drift_client) + drift_user = DriftUser(drift_client, account_subscriber=CachedUserAccountSubscriber(drift_client.get_user_account_public_key(), drift_client.program)) + await drift_user.subscribe() user_account = await drift_client.get_user(0) assert(len(user_account.orders)==32) @@ -224,6 +234,7 @@ async def test_open_orders( ixs = await drift_client.get_place_perp_orders_ix([order_params]) await drift_client.send_ixs(ixs) await drift_user.account_subscriber.update_cache() + await asyncio.sleep(1) open_orders_after = await drift_user.get_open_orders() assert(open_orders_after[0].base_asset_amount == BASE_PRECISION) assert(open_orders_after[0].order_id == 1) @@ -376,7 +387,7 @@ async def test_liq_perp( user_account = await drift_client.get_user(0) liq, _ = await _airdrop_user(drift_client.program.provider) - liq_drift_client = DriftClient(drift_client.program, liq) + liq_drift_client = DriftClient(drift_client.program, liq, account_subscriber=CachedDriftClientAccountSubscriber(drift_client.program)) usdc_acc = await _create_and_mint_user_usdc( usdc_mint, drift_client.program.provider, USDC_AMOUNT, liq.pubkey() )