diff --git a/.DS_Store b/.DS_Store index 704ecca..67c23bd 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/.gitignore b/.gitignore index ca8b4e9..490e7d7 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ venv .mypy_cache /__pycache__/ /src/__pycache__/ -/src/sections/__pycache__ \ No newline at end of file +/src/sections/__pycache__ +/src/.DS_Store \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..8c59ef8 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,10 @@ +{ + "python.linting.enabled": true, + "python.linting.mypyEnabled": true, + "python.linting.mypyArgs": [ + "--ignore-missing-imports", + "--strict" + ], + "python.linting.pylintEnabled": true, // Optional: Enables pylint + "python.linting.flake8Enabled": true // Optional: Enables flake8 +} \ No newline at end of file diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..a662b23 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +check_untyped_defs = True \ No newline at end of file diff --git a/pickles/.DS_Store b/pickles/.DS_Store deleted file mode 100644 index 5008ddf..0000000 Binary files a/pickles/.DS_Store and /dev/null differ diff --git a/pickles/perp_274342434.pkl b/pickles/perp_274342434.pkl deleted file mode 100644 index b66ed8f..0000000 Binary files a/pickles/perp_274342434.pkl and /dev/null differ diff --git a/pickles/perp_279408293.pkl b/pickles/perp_279408293.pkl new file mode 100644 index 0000000..1994d1a Binary files /dev/null and b/pickles/perp_279408293.pkl differ diff --git a/pickles/perporacles_274342436.pkl b/pickles/perporacles_274342436.pkl deleted file mode 100644 index 10fefd3..0000000 Binary files a/pickles/perporacles_274342436.pkl and /dev/null differ diff --git a/pickles/perporacles_279408296.pkl b/pickles/perporacles_279408296.pkl new file mode 100644 index 0000000..35cb71c Binary files /dev/null and b/pickles/perporacles_279408296.pkl differ diff --git a/pickles/spot_274342432.pkl b/pickles/spot_274342432.pkl deleted file mode 100644 index 036c367..0000000 Binary files a/pickles/spot_274342432.pkl and /dev/null differ diff --git a/pickles/spot_279408289.pkl b/pickles/spot_279408289.pkl new file mode 100644 index 0000000..759adf3 Binary files /dev/null and b/pickles/spot_279408289.pkl differ diff --git a/pickles/spotoracles_274342436.pkl b/pickles/spotoracles_274342436.pkl deleted file mode 100644 index 916fce8..0000000 Binary files a/pickles/spotoracles_274342436.pkl and /dev/null differ diff --git a/pickles/spotoracles_279408296.pkl b/pickles/spotoracles_279408296.pkl new file mode 100644 index 0000000..f1d4dc2 Binary files /dev/null and b/pickles/spotoracles_279408296.pkl differ diff --git a/pickles/usermap_274342298.pkl b/pickles/usermap_279408223.pkl similarity index 62% rename from pickles/usermap_274342298.pkl rename to pickles/usermap_279408223.pkl index 91bea55..894aaf5 100644 Binary files a/pickles/usermap_274342298.pkl and b/pickles/usermap_279408223.pkl differ diff --git a/pickles/userstats_0.pkl b/pickles/userstats_0.pkl new file mode 100644 index 0000000..92c3c88 --- /dev/null +++ b/pickles/userstats_0.pkl @@ -0,0 +1 @@ +€]”. 
\ No newline at end of file diff --git a/pickles/userstats_274342347.pkl b/pickles/userstats_274342347.pkl deleted file mode 100644 index cb2c561..0000000 Binary files a/pickles/userstats_274342347.pkl and /dev/null differ diff --git a/requirements.txt b/requirements.txt index 30f8541..193f430 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ aiodns==3.0.0 aiohttp==3.8.3 aiosignal==1.3.1 -altair==5.3.0 anchorpy==0.20.1 anchorpy-core==0.2.0 anyio==3.6.2 @@ -11,7 +10,7 @@ attrs==22.1.0 backoff==2.2.1 base58==2.1.1 based58==0.1.1 -blinker==1.8.2 +black==24.4.2 borsh-construct==0.1.0 cachetools==4.2.4 certifi==2022.12.7 @@ -22,86 +21,68 @@ construct==2.10.68 construct-typing==0.5.3 Deprecated==1.2.14 dnspython==2.2.1 -driftpy==0.7.59 +driftpy==0.7.73 Events==0.5 exceptiongroup==1.0.4 flake8==6.0.0 frozenlist==1.3.3 ghp-import==2.1.0 -gitdb==4.0.11 -GitPython==3.1.43 -grpcio==1.64.1 +grpcio==1.65.1 h11==0.14.0 httpcore==0.16.3 httpx==0.23.1 idna==3.4 iniconfig==1.1.1 -Jinja2==3.0.3 +Jinja2==3.1.4 jito_searcher_client==0.1.4 jsonalias==0.1.1 -jsonrpcclient==4.0.2 +jsonrpcclient==4.0.3 jsonrpcserver==5.0.9 jsonschema==4.17.3 loguru==0.6.0 Markdown==3.6 -markdown-it-py==3.0.0 MarkupSafe==2.1.5 mccabe==0.7.0 -mdurl==0.1.2 mergedeep==1.3.4 mkdocs==1.6.0 mkdocs-get-deps==0.2.0 more-itertools==8.14.0 multidict==6.0.3 -mypy==1.10.0 +mypy==1.11.0 mypy-extensions==1.0.0 numpy==1.26.4 OSlash==0.6.3 -packaging==22.0 -pandas==2.2.2 +packaging==23.1 pathspec==0.12.1 -pillow==10.3.0 platformdirs==4.2.2 -plotly==5.22.0 pluggy==1.0.0 -protobuf==4.25.3 +protobuf==4.25.4 psutil==5.9.4 py==1.11.0 -pyarrow==16.1.0 pycares==4.3.0 pycodestyle==2.10.0 pycparser==2.21 -pydeck==0.9.1 pyflakes==3.0.1 -Pygments==2.18.0 pyheck==0.1.5 pyrsistent==0.19.2 pythclient==0.1.4 python-dateutil==2.9.0.post0 -pytz==2024.1 PyYAML==6.0.1 pyyaml_env_tag==0.1 requests==2.32.3 rfc3986==1.5.0 -rich==13.7.1 -rpds-py==0.18.1 six==1.16.0 -smmap==5.0.1 sniffio==1.3.0 solana==0.34.0 solders==0.21.0 -streamlit==1.35.0 sumtypes==0.1a6 -tenacity==8.3.0 toml==0.10.2 tomli==2.0.1 toolz==0.11.2 -tornado==6.4.1 types-cachetools==4.2.10 types-requests==2.31.0.6 types-urllib3==1.26.25.14 typing_extensions==4.12.2 -tzdata==2024.1 urllib3==1.26.13 watchdog==4.0.1 websockets==10.4 diff --git a/src/cache.py b/src/cache.py new file mode 100644 index 0000000..b99f11a --- /dev/null +++ b/src/cache.py @@ -0,0 +1,41 @@ +from asyncio import AbstractEventLoop +from scenario import get_usermap_df +import time +import streamlit as st +from driftpy.drift_client import DriftClient +from driftpy.pickle.vat import Vat +import pandas as pd + + +def _load_asset_liab_dfs( + dc: DriftClient, vat: Vat, loop: AbstractEventLoop +) -> tuple[tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame], list[str]]: + start = time.time() + oracle_distort = 0 + (levs_none, levs_init, levs_maint), user_keys = loop.run_until_complete( + get_usermap_df( + dc, + vat.users, + "margins", + oracle_distort, + None, + "ignore stables", + n_scenarios=0, + all_fields=True, + ) + ) + print(f"Loaded asset/liability data in {time.time() - start:.2f} seconds") + return (levs_none, levs_init, levs_maint), user_keys + + +@st.cache_data +def _cache_dataframes( + dfs: tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame], user_keys: list[str] +) -> tuple[tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame], list[str]]: + return dfs, user_keys + + +def get_cached_asset_liab_dfs( + dc: DriftClient, vat: Vat, loop: AbstractEventLoop +) -> tuple[tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame], list[str]]: + 
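+    # Note on the pattern here: st.cache_data memoizes on the hash of its
+    # arguments, so _load_asset_liab_dfs always recomputes while the identity
+    # function _cache_dataframes, via its @st.cache_data decorator, lets reruns
+    # with unchanged frames reuse the cached tuple instead of rebuilding it.
+    # (Hashing large DataFrames is not free; this is a trade-off.)
+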
return _cache_dataframes(*_load_asset_liab_dfs(dc, vat, loop)) diff --git a/src/health_utils.py b/src/health_utils.py index 059f368..de94665 100644 --- a/src/health_utils.py +++ b/src/health_utils.py @@ -24,11 +24,18 @@ ) from driftpy.types import is_variant from driftpy.pickle.vat import Vat -from driftpy.constants.spot_markets import mainnet_spot_market_configs, devnet_spot_market_configs -from driftpy.constants.perp_markets import mainnet_perp_market_configs, devnet_perp_market_configs +from driftpy.constants.spot_markets import ( + mainnet_spot_market_configs, + devnet_spot_market_configs, +) +from driftpy.constants.perp_markets import ( + mainnet_perp_market_configs, + devnet_perp_market_configs, +) from utils import load_newest_files, load_vat, to_financial + def get_largest_perp_positions(vat: Vat): top_positions: list[Any] = [] @@ -293,4 +300,4 @@ def get_most_levered_spot_borrows_above_1m(vat: Vat): "Public Key": [pos[1] for pos in borrows], } - return data \ No newline at end of file + return data diff --git a/src/main.py b/src/main.py index 3957605..8974f78 100644 --- a/src/main.py +++ b/src/main.py @@ -1,69 +1,129 @@ import asyncio -import heapq +import io import time import os +import aiohttp +import msgpack +import zipfile -from asyncio import AbstractEventLoop -import plotly.express as px # type: ignore import pandas as pd # type: ignore +import streamlit as st from typing import Any +import datetime as dt +from asyncio import AbstractEventLoop from solana.rpc.async_api import AsyncClient from anchorpy import Wallet -import streamlit as st -from driftpy.drift_user import DriftUser from driftpy.drift_client import DriftClient from driftpy.account_subscription_config import AccountSubscriptionConfig -from driftpy.constants.numeric_constants import ( - BASE_PRECISION, - SPOT_BALANCE_PRECISION, - PRICE_PRECISION, -) -from driftpy.types import is_variant -from driftpy.pickle.vat import Vat -from driftpy.constants.spot_markets import mainnet_spot_market_configs, devnet_spot_market_configs -from driftpy.constants.perp_markets import mainnet_perp_market_configs, devnet_perp_market_configs - -from utils import load_newest_files, load_vat, to_financial -from sections.asset_liab_matrix import asset_liab_matrix_page + +from health_utils import * +from cache import get_cached_asset_liab_dfs +from utils import load_newest_files, load_vat, clear_local_pickles +from sections.asset_liab_matrix import asset_liab_matrix_page, get_matrix from sections.ob import ob_cmp_page from sections.scenario import plot_page from sections.liquidation_curves import plot_liquidation_curve +from sections.margin_model import margin_model + + +SERVER_URL = "http://54.74.185.225:8080" + + +async def fetch_context(session: aiohttp.ClientSession, req: str) -> dict[str, Any]: + async with session.get(req) as response: + return msgpack.unpackb(await response.read(), strict_map_key=False) + + +async def fetch_pickles(session: aiohttp.ClientSession, req: str) -> dict[str, Any]: + async with session.get(req) as response: + content = await response.read() + with zipfile.ZipFile(io.BytesIO(content)) as zip_ref: + zip_ref.extractall(os.getcwd() + "/pickles") + + +async def setup_context(dc: DriftClient, loop: AbstractEventLoop, env): + start_dashboard_ready = time.time() + async with aiohttp.ClientSession() as session: + print("fetching context") + start = time.time() + + tasks = [ + fetch_pickles(session, f"{SERVER_URL}/pickles"), + fetch_context(session, f"{SERVER_URL}/{env}_context"), + ] + _, context_data = await 
asyncio.gather(*tasks) + print("context fetched in ", time.time() - start) + + filepath = os.getcwd() + "/pickles" + newest_snapshot = load_newest_files(filepath) + start_load_vat = time.time() + vat = await load_vat(dc, newest_snapshot, loop, env) + clear_local_pickles(filepath) + st.session_state["vat"] = vat + print(f"loaded vat in {time.time() - start_load_vat}") + + levs = [ + context_data["levs_none"], + context_data["levs_init"], + context_data["levs_maint"], + ] + user_keys = context_data["user_keys"] + margin = [pd.DataFrame(context_data["res"]), pd.DataFrame(context_data["df"])] + + st.session_state["margin"] = tuple(margin) + st.session_state["asset_liab_data"] = tuple(levs), user_keys + print(f"dashboard ready in: {time.time() - start_dashboard_ready}") + +def setup_context_local(dc: DriftClient, loop: AbstractEventLoop, env): + vat: Vat + if "vat" not in st.session_state: + newest_snapshot = load_newest_files(os.getcwd() + "/pickles") + + start_load_vat = time.time() + vat = loop.run_until_complete(load_vat(dc, newest_snapshot, loop, env)) + st.session_state["vat"] = vat + print(f"loaded vat in {time.time() - start_load_vat}") + else: + vat = st.session_state["vat"] -from health_utils import * + if "asset_liab_data" not in st.session_state: + st.session_state["asset_liab_data"] = get_cached_asset_liab_dfs(dc, vat, loop) -@st.cache(allow_output_mutation=True) -def cached_load_vat(dc: DriftClient): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - newest_snapshot = load_newest_files(os.getcwd() + "/pickles") - vat = loop.run_until_complete(load_vat(dc, newest_snapshot)) - loop.close() - return vat - -def get_vat(dc: DriftClient): - start_load_vat = time.time() - vat = cached_load_vat(dc) - print(f"loaded vat in {time.time() - start_load_vat}") - return vat + if "margin" not in st.session_state: + start = time.time() + st.session_state["margin"] = get_matrix(loop, vat, dc) + print(f"loaded matrix in {time.time() - start}") + + st.session_state["context"] = True def main(): + if "context" not in st.session_state: + st.session_state["context"] = False st.set_page_config(layout="wide") - url = os.getenv("RPC_URL", "🤫") - env = st.sidebar.radio('env', ('mainnet-beta', 'devnet')) - rpc = st.sidebar.text_input("RPC URL", value=url) - if env == 'mainnet-beta' and (rpc == '🤫' or rpc == ''): - rpc = os.environ['ANCHOR_PROVIDER_URL'] - query_index = 0 + + env = st.sidebar.radio("Environment:", ["prod", "dev"]) + + source = st.sidebar.radio("Source:" , ["local", "remote"]) + def query_string_callback(): - st.query_params['tab'] = st.session_state.query_key - query_tab = st.query_params.get('tab', ['Welcome'])[0] - tab_options = ('Welcome', 'Health', 'Price-Shock', 'Asset-Liab-Matrix', 'Orderbook', 'Liquidations') + st.query_params["tab"] = st.session_state.query_key + + query_tab = st.query_params.get("tab", ["Welcome"])[0] + tab_options = ( + "Welcome", + "Health", + "Price-Shock", + "Asset-Liab-Matrix", + "Orderbook", + "Liquidations", + "Margin-Model", + ) for idx, x in enumerate(tab_options): if x.lower() == query_tab.lower(): query_index = idx @@ -73,64 +133,104 @@ def query_string_callback(): tab_options, query_index, on_change=query_string_callback, - key='query_key' + key="query_key", + ) + + if tab is None: + tab = "Welcome" + + drift_client = DriftClient( + AsyncClient("https://api.mainnet-beta.solana.com/"), + Wallet.dummy(), + account_subscription=AccountSubscriptionConfig("cached"), + ) + + def func(): + if ( + tab.lower() + in [ + "welcome", + "health", + 
"price-shock", + "asset-liab-matrix", + "liquidations", + "margin-model", + ] + and "vat" not in st.session_state + ): + md = st.empty() + md.markdown("`Loading dashboard, do not leave this page`") + if st.session_state["context"] == False: + if source.lower() == "local": + setup_context_local(drift_client, loop, env) + else: + loop.run_until_complete(setup_context(drift_client, loop, env)) + st.session_state["context"] = True + md.markdown("`Dashboard ready!`") + + st.sidebar.button("Start Dashboard", on_click=func) + + loop: AbstractEventLoop = asyncio.new_event_loop() + + if tab.lower() == "welcome": + st.header("Welcome to the Drift v2 Risk Analytics Dashboard!") + st.metric( + "protocol has been live for:", + str( + int( + (dt.datetime.now() - pd.to_datetime("2022-11-05")).total_seconds() + / (60 * 60 * 24) + ) + ) + + " days", ) - - if rpc == "🤫" or rpc == "": - st.warning("Please enter a Solana RPC URL") - else: - drift_client = DriftClient( - AsyncClient(rpc), - Wallet.dummy(), - account_subscription=AccountSubscriptionConfig("cached"), + st.write( + "Click `Start Dashboard` to load the dashboard on the selected `Environment`" ) - loop: AbstractEventLoop = asyncio.new_event_loop() - if tab.lower() in ['health', 'price-shock', 'asset-liab-matrix', 'liquidations'] and 'vat' not in st.session_state: - # start_sub = time.time() - # loop.run_until_complete(dc.subscribe()) - # print(f"subscribed in {time.time() - start_sub}") - - newest_snapshot = load_newest_files(os.getcwd() + "/pickles") - - start_load_vat = time.time() - vat = loop.run_until_complete(load_vat(drift_client, newest_snapshot)) - st.session_state["vat"] = vat - print(f"loaded vat in {time.time() - start_load_vat}") - elif tab.lower() in ['health', 'price-shock', 'asset-liab-matrix', 'liquidations']: - vat = st.session_state["vat"] - - if tab.lower() == 'health': - health_distribution = get_account_health_distribution(vat) - - with st.container(): - st.plotly_chart(health_distribution, use_container_width=True) - - perp_col, spot_col = st.columns([1, 1]) - - with perp_col: - largest_perp_positions = get_largest_perp_positions(vat) - st.markdown("### **Largest perp positions:**") - st.table(largest_perp_positions) - most_levered_positions = get_most_levered_perp_positions_above_1m(vat) - st.markdown("### **Most levered perp positions > $1m:**") - st.table(most_levered_positions) - - with spot_col: - largest_spot_borrows = get_largest_spot_borrows(vat) - st.markdown("### **Largest spot borrows:**") - st.table(largest_spot_borrows) - most_levered_borrows = get_most_levered_spot_borrows_above_1m(vat) - st.markdown("### **Most levered spot borrows > $750k:**") - st.table(most_levered_borrows) - - elif tab.lower() == 'price-shock': - plot_page(loop, vat, drift_client) - elif tab.lower() == 'asset-liab-matrix': - asset_liab_matrix_page(loop, vat, drift_client) - elif tab.lower() == 'orderbook': - ob_cmp_page() - elif tab.lower() == 'liquidations': - plot_liquidation_curve(vat) - -main() + if tab.lower() in [ + "health", + "price-shock", + "asset-liab-matrix", + "liquidations", + "margin model", + ]: + vat = st.session_state["vat"] + + if tab.lower() == "health": + health_distribution = get_account_health_distribution(vat) + + with st.container(): + st.plotly_chart(health_distribution, use_container_width=True) + + perp_col, spot_col = st.columns([1, 1]) + + with perp_col: + largest_perp_positions = get_largest_perp_positions(vat) + st.markdown("### **Largest perp positions:**") + st.table(largest_perp_positions) + 
most_levered_positions = get_most_levered_perp_positions_above_1m(vat) + st.markdown("### **Most levered perp positions > $1m:**") + st.table(most_levered_positions) + + with spot_col: + largest_spot_borrows = get_largest_spot_borrows(vat) + st.markdown("### **Largest spot borrows:**") + st.table(largest_spot_borrows) + most_levered_borrows = get_most_levered_spot_borrows_above_1m(vat) + st.markdown("### **Most levered spot borrows > $750k:**") + st.table(most_levered_borrows) + + elif tab.lower() == "price-shock": + plot_page(loop, vat, drift_client) + elif tab.lower() == "asset-liab-matrix": + asset_liab_matrix_page(loop, vat, drift_client) + elif tab.lower() == "orderbook": + ob_cmp_page() + elif tab.lower() == "liquidations": + plot_liquidation_curve(vat) + elif tab.lower() == "margin-model": + margin_model(loop, drift_client) + + +main() \ No newline at end of file diff --git a/src/scenario.py b/src/scenario.py index 8054eb8..688b12c 100644 --- a/src/scenario.py +++ b/src/scenario.py @@ -1,10 +1,11 @@ import sys from tokenize import tabsize import driftpy -import pandas as pd -import numpy as np +import pandas as pd +import numpy as np import copy import plotly.express as px + pd.options.plotting.backend = "plotly" # from datafetch.transaction_fetch import load_token_balance # from driftpy.constants.config import configs @@ -13,9 +14,15 @@ from solana.rpc.async_api import AsyncClient from solana.rpc.types import MemcmpOpts from driftpy.drift_client import DriftClient -from driftpy.accounts import get_perp_market_account, get_spot_market_account, get_user_account, get_state_account -from driftpy.constants.numeric_constants import * +from driftpy.accounts import ( + get_perp_market_account, + get_spot_market_account, + get_user_account, + get_state_account, +) +from driftpy.constants.numeric_constants import * from driftpy.drift_user import DriftUser, get_token_amount + # from datafetch.transaction_fetch import transaction_history_for_account, load_token_balance import pickle @@ -29,49 +36,90 @@ from driftpy.constants.perp_markets import devnet_perp_market_configs, PerpMarketConfig from dataclasses import dataclass from solders.pubkey import Pubkey + # from helpers import serialize_perp_market_2, serialize_spot_market, all_user_stats, DRIFT_WHALE_LIST_SNAP from anchorpy import EventParser import asyncio from driftpy.math.margin import MarginCategory import requests from driftpy.types import InsuranceFundStakeAccount, SpotMarketAccount, OraclePriceData -from driftpy.addresses import * +from driftpy.addresses import * import time from driftpy.market_map.market_map_config import WebsocketConfig from driftpy.user_map.user_map import UserMap, UserMapConfig, PollingConfig import datetime import csv -from utils import get_init_health -NUMBER_OF_SPOT = 20 -NUMBER_OF_PERP = 33 +NUMBER_OF_SPOT = len(mainnet_spot_market_configs) +NUMBER_OF_PERP = len(mainnet_perp_market_configs) + + +def get_init_health(user: DriftUser): + if user.is_being_liquidated(): + return 0 + + total_collateral = user.get_total_collateral(MarginCategory.INITIAL) + maintenance_margin_req = user.get_margin_requirement(MarginCategory.INITIAL) + + if maintenance_margin_req == 0 and total_collateral >= 0: + return 100 + elif total_collateral <= 0: + return 0 + else: + return round( + min(100, max(0, (1 - maintenance_margin_req / total_collateral) * 100)) + ) + def comb_asset_liab(a_l_tup): return a_l_tup[0] - a_l_tup[1] + def get_collateral_composition(x: DriftUser, margin_category, n): # ua = x.get_user_account() - net_v = {i: 
comb_asset_liab(x.get_spot_market_asset_and_liability_value(i, margin_category))/QUOTE_PRECISION for i in range(n)} - return net_v + net_v = { + i: comb_asset_liab( + x.get_spot_market_asset_and_liability_value(i, margin_category) + ) + / QUOTE_PRECISION + for i in range(n) + } + return net_v + def get_perp_liab_composition(x: DriftUser, margin_category, n): # ua = x.get_user_account() - net_p = {i: x.get_perp_market_liability(i, margin_category, signed=True)/QUOTE_PRECISION for i in range(n)} - return net_p + net_p = { + i: x.get_perp_market_liability(i, margin_category, signed=True) + / QUOTE_PRECISION + for i in range(n) + } + return net_p + def get_perp_lp_share_composition(x: DriftUser, n): # ua = x.get_user_account() def get_lp_shares(x, i): res = x.get_perp_position(i) if res is not None: - return res.lp_shares/1e9 + return res.lp_shares / 1e9 else: return 0 + net_p = {i: get_lp_shares(x, i) for i in range(n)} - return net_p + return net_p + -async def get_usermap_df(_drift_client: DriftClient, user_map: UserMap, mode: str, oracle_distor=.1, - only_one_index=None, cov_matrix=None, n_scenarios=5, all_fields=False): +async def get_usermap_df( + _drift_client: DriftClient, + user_map: UserMap, + mode: str, + oracle_distor=0.1, + only_one_index=None, + cov_matrix=None, + n_scenarios=5, + all_fields=False, +): perp_n = NUMBER_OF_PERP spot_n = NUMBER_OF_SPOT @@ -86,73 +134,113 @@ def do_dict(x: DriftUser, margin_category: MarginCategory, oracle_cache=None): # user_account = x.get_user_account() levs0 = { - # 'tokens': [x.get_token_amount(i) for i in range(spot_n)], - 'user_key': x.user_public_key, - 'leverage': x.get_leverage() / MARGIN_PRECISION, - 'health': health_func(x), - 'perp_liability': x.get_perp_market_liability(None, margin_category) / QUOTE_PRECISION, - 'spot_asset': x.get_spot_market_asset_value(None, margin_category) / QUOTE_PRECISION, - 'spot_liability': x.get_spot_market_liability_value(None, margin_category) / QUOTE_PRECISION, - 'upnl': x.get_unrealized_pnl(True) / QUOTE_PRECISION, - # 'funding_upnl': x.get_unrealized_funding_pnl() / QUOTE_PRECISION, - # 'total_collateral': x.get_total_collateral(margin_category or MarginCategory.INITIAL) / QUOTE_PRECISION, - # 'margin_req': x.get_margin_requirement(margin_category or MarginCategory.INITIAL) / QUOTE_PRECISION, - # 'net_v': get_collateral_composition(x, margin_category, spot_n), - # 'net_p': get_perp_liab_composition(x, margin_category, perp_n), - # 'net_lp': get_perp_lp_share_composition(x, perp_n), - # 'last_active_slot': user_account.last_active_slot, - # 'cumulative_perp_funding': user_account.cumulative_perp_funding/QUOTE_PRECISION, - # 'settled_perp_pnl': user_account.settled_perp_pnl/QUOTE_PRECISION, - # 'name': bytes(user_account.name).decode('utf-8', errors='ignore').strip(), - # 'authority': str(user_account.authority), - # 'has_open_order': user_account.has_open_order, - # 'sub_account_id': user_account.sub_account_id, - # 'next_liquidation_id': user_account.next_liquidation_id, - # 'cumulative_spot_fees': user_account.cumulative_spot_fees, - # 'total_deposits': user_account.total_deposits, - # 'total_withdraws': user_account.total_withdraws, - # 'total_social_loss': user_account.total_social_loss, - # 'unsettled_pnl_perp_x': x.get_unrealized_pnl(True, market_index=24) / QUOTE_PRECISION, + # 'tokens': [x.get_token_amount(i) for i in range(spot_n)], + "user_key": x.user_public_key, + "leverage": x.get_leverage() / MARGIN_PRECISION, + "health": health_func(x), + "perp_liability": x.get_perp_market_liability(None, 
margin_category)
+        / QUOTE_PRECISION,
+        "spot_asset": x.get_spot_market_asset_value(None, margin_category)
+        / QUOTE_PRECISION,
+        "spot_liability": x.get_spot_market_liability_value(None, margin_category)
+        / QUOTE_PRECISION,
+        "upnl": x.get_unrealized_pnl(True) / QUOTE_PRECISION,
+        "net_usd_value": (
+            x.get_net_spot_market_value(None) + x.get_unrealized_pnl(True)
+        )
+        / QUOTE_PRECISION,
+        # 'funding_upnl': x.get_unrealized_funding_pnl() / QUOTE_PRECISION,
+        # 'total_collateral': x.get_total_collateral(margin_category or MarginCategory.INITIAL) / QUOTE_PRECISION,
+        # 'margin_req': x.get_margin_requirement(margin_category or MarginCategory.INITIAL) / QUOTE_PRECISION,
+        # 'net_v': get_collateral_composition(x, margin_category, spot_n),
+        # 'net_p': get_perp_liab_composition(x, margin_category, perp_n),
+        # 'net_lp': get_perp_lp_share_composition(x, perp_n),
+        # 'last_active_slot': user_account.last_active_slot,
+        # 'cumulative_perp_funding': user_account.cumulative_perp_funding/QUOTE_PRECISION,
+        # 'settled_perp_pnl': user_account.settled_perp_pnl/QUOTE_PRECISION,
+        # 'name': bytes(user_account.name).decode('utf-8', errors='ignore').strip(),
+        # 'authority': str(user_account.authority),
+        # 'has_open_order': user_account.has_open_order,
+        # 'sub_account_id': user_account.sub_account_id,
+        # 'next_liquidation_id': user_account.next_liquidation_id,
+        # 'cumulative_spot_fees': user_account.cumulative_spot_fees,
+        # 'total_deposits': user_account.total_deposits,
+        # 'total_withdraws': user_account.total_withdraws,
+        # 'total_social_loss': user_account.total_social_loss,
+        # 'unsettled_pnl_perp_x': x.get_unrealized_pnl(True, market_index=24) / QUOTE_PRECISION,
     }
-    levs0['net_usd_value'] = levs0['spot_asset'] + levs0['upnl'] - levs0['spot_liability']
+    # levs0['net_usd_value'] = levs0['spot_asset'] + levs0['upnl'] - levs0['spot_liability']
     if all_fields:
-        levs0['net_v'] = get_collateral_composition(x, margin_category, spot_n)
-        levs0['net_p'] = get_perp_liab_composition(x, margin_category, spot_n)
+        levs0["net_v"] = get_collateral_composition(x, margin_category, spot_n)
+        # pass perp_n here: net_p spans perp markets, not spot markets
+        levs0["net_p"] = get_perp_liab_composition(x, margin_category, perp_n)
     return levs0
+
    user_map_result: UserMap = user_map
-
+
    user_keys = list(user_map_result.user_map.keys())
    user_vals = list(user_map_result.values())

-    if cov_matrix == 'ignore stables':
-        skipped_oracles = [str(x.oracle) for x in mainnet_spot_market_configs if 'USD' in x.symbol]
-    elif cov_matrix == 'sol + lst only':
-        skipped_oracles = [str(x.oracle) for x in mainnet_spot_market_configs if 'SOL' not in x.symbol]
-    elif cov_matrix == 'sol lst only':
-        skipped_oracles = [str(x.oracle) for x in mainnet_spot_market_configs if x.symbol not in ['mSOL', 'jitoSOL', 'bSOL']]
-    elif cov_matrix == 'sol ecosystem only':
-        skipped_oracles = [str(x.oracle) for x in mainnet_spot_market_configs if x.symbol not in ['PYTH', 'JTO', 'WIF', 'JUP', 'TNSR', 'DRIFT']]
-    elif cov_matrix == 'meme':
-        skipped_oracles = [str(x.oracle) for x in mainnet_spot_market_configs if x.symbol not in ['WIF']]
-    elif cov_matrix == 'wrapped only':
-        skipped_oracles = [str(x.oracle) for x in mainnet_spot_market_configs if x.symbol not in ['wBTC', 'wETH']]
-    elif cov_matrix == 'stables only':
-        skipped_oracles = [str(x.oracle) for x in mainnet_spot_market_configs if 'USD' not in x.symbol]
+    if cov_matrix == "ignore stables":
+        skipped_oracles = [
+            str(x.oracle) for x in mainnet_spot_market_configs if "USD" in x.symbol
+        ]
+    elif cov_matrix == "sol + lst only":
+        skipped_oracles = [
+            str(x.oracle)
for x in mainnet_spot_market_configs if "SOL" not in x.symbol + ] + elif cov_matrix == "sol lst only": + skipped_oracles = [ + str(x.oracle) + for x in mainnet_spot_market_configs + if x.symbol not in ["mSOL", "jitoSOL", "bSOL"] + ] + elif cov_matrix == "sol ecosystem only": + skipped_oracles = [ + str(x.oracle) + for x in mainnet_spot_market_configs + if x.symbol not in ["PYTH", "JTO", "WIF", "JUP", "TNSR", "DRIFT"] + ] + elif cov_matrix == "meme": + skipped_oracles = [ + str(x.oracle) + for x in mainnet_spot_market_configs + if x.symbol not in ["WIF"] + ] + elif cov_matrix == "wrapped only": + skipped_oracles = [ + str(x.oracle) + for x in mainnet_spot_market_configs + if x.symbol not in ["wBTC", "wETH"] + ] + elif cov_matrix == "stables only": + skipped_oracles = [ + str(x.oracle) for x in mainnet_spot_market_configs if "USD" not in x.symbol + ] if only_one_index is None or len(only_one_index) > 12: only_one_index_key = only_one_index else: - only_one_index_key = ([str(x.oracle) for x in mainnet_perp_market_configs if x.base_asset_symbol == only_one_index] \ - +[str(x.oracle) for x in mainnet_spot_market_configs if x.symbol == only_one_index])[0] + only_one_index_key = ( + [ + str(x.oracle) + for x in mainnet_perp_market_configs + if x.base_asset_symbol == only_one_index + ] + + [ + str(x.oracle) + for x in mainnet_spot_market_configs + if x.symbol == only_one_index + ] + )[0] - if mode == 'margins': + if mode == "margins": levs_none = list(do_dict(x, None) for x in user_vals) levs_init = list(do_dict(x, MarginCategory.INITIAL) for x in user_vals) levs_maint = list(do_dict(x, MarginCategory.MAINTENANCE) for x in user_vals) return (levs_none, levs_init, levs_maint), user_keys else: - num_entrs = n_scenarios # increment to get more steps + num_entrs = n_scenarios # increment to get more steps new_oracles_dat_up = [] new_oracles_dat_down = [] @@ -160,13 +248,14 @@ def do_dict(x: DriftUser, margin_category: MarginCategory, oracle_cache=None): new_oracles_dat_up.append({}) new_oracles_dat_down.append({}) - - assert(len(new_oracles_dat_down) == num_entrs) - print('skipped oracles:', skipped_oracles) + assert len(new_oracles_dat_down) == num_entrs + print("skipped oracles:", skipped_oracles) distorted_oracles = [] cache_up = copy.deepcopy(_drift_client.account_subscriber.cache) cache_down = copy.deepcopy(_drift_client.account_subscriber.cache) - for i,(key, val) in enumerate(_drift_client.account_subscriber.cache['oracle_price_data'].items()): + for i, (key, val) in enumerate( + _drift_client.account_subscriber.cache["oracle_price_data"].items() + ): for i in range(num_entrs): new_oracles_dat_up[i][key] = copy.deepcopy(val) new_oracles_dat_down[i][key] = copy.deepcopy(val) @@ -175,8 +264,8 @@ def do_dict(x: DriftUser, margin_category: MarginCategory, oracle_cache=None): if only_one_index is None or only_one_index_key == key: distorted_oracles.append(key) for i in range(num_entrs): - oracle_distort_up = max(1 + oracle_distor * (i+1), 1) - oracle_distort_down = max(1 - oracle_distor * (i+1), 0) + oracle_distort_up = max(1 + oracle_distor * (i + 1), 1) + oracle_distort_down = max(1 - oracle_distor * (i + 1), 0) # weird pickle artifact if isinstance(new_oracles_dat_up[i][key], OraclePriceData): @@ -191,17 +280,20 @@ def do_dict(x: DriftUser, margin_category: MarginCategory, oracle_cache=None): levs_down = [] for i in range(num_entrs): - cache_up['oracle_price_data'] = new_oracles_dat_up[i] - cache_down['oracle_price_data'] = new_oracles_dat_down[i] + cache_up["oracle_price_data"] = 
new_oracles_dat_up[i]
+        cache_down["oracle_price_data"] = new_oracles_dat_down[i]

         levs_up_i = list(do_dict(x, None, cache_up) for x in user_vals)
         levs_down_i = list(do_dict(x, None, cache_down) for x in user_vals)
         levs_up.append(levs_up_i)
         levs_down.append(levs_down_i)

-        return (levs_none, tuple(levs_up), tuple(levs_down)), user_keys, distorted_oracles
-
+        return (
+            (levs_none, tuple(levs_up), tuple(levs_down)),
+            user_keys,
+            distorted_oracles,
+        )


 async def get_new_ff(usermap):
     await usermap.sync()
-    usermap.dump()
\ No newline at end of file
+    usermap.dump()
diff --git a/src/sections/asset_liab_matrix.py b/src/sections/asset_liab_matrix.py
index 9bc16b4..b4f2ea8 100644
--- a/src/sections/asset_liab_matrix.py
+++ b/src/sections/asset_liab_matrix.py
@@ -8,53 +8,102 @@
 from driftpy.constants.spot_markets import mainnet_spot_market_configs
 from driftpy.constants.perp_markets import mainnet_perp_market_configs

-from scenario import get_usermap_df
+from scenario import NUMBER_OF_SPOT  # type: ignore
+from cache import get_cached_asset_liab_dfs  # type: ignore

 options = [0, 1, 2, 3]
-labels = ["none", "liq within 50% of oracle", "maint. health < 10%", "init. health < 10%"]
-
-def get_matrix(loop: AbstractEventLoop, vat: Vat, drift_client: DriftClient, env='mainnet', mode=0, perp_market_inspect=0):
-    NUMBER_OF_SPOT = 20
-    NUMBER_OF_PERP = 33
-
-    oracle_distort = 0
-    if "margin" not in st.session_state:
-        (levs_none, levs_init, levs_maint), user_keys = loop.run_until_complete(get_usermap_df(drift_client, vat.users,
-                                                            'margins', oracle_distort,
-                                                            None, 'ignore stables', n_scenarios=0, all_fields=True))
-        levs_maint = [x for x in levs_maint if int(x['health']) <= 10]
-        levs_init = [x for x in levs_init if int(x['health']) <= 10]
-        st.session_state["margin"] = (levs_none, levs_init, levs_maint), user_keys
-    else:
-        (levs_none, levs_init, levs_maint), user_keys = st.session_state["margin"]
-
+labels = [
+    "none",
+    "liq within 50% of oracle",
+    "maint. health < specified %",
+    "init. health < specified %",
+]
+
+
+def get_matrix(
+    loop: AbstractEventLoop,
+    vat: Vat,
+    drift_client: DriftClient,
+    env="mainnet",
+    mode=0,
+    perp_market_inspect=0,
+    health=50,
+):
+    if "asset_liab_data" not in st.session_state:
+        st.session_state["asset_liab_data"] = get_cached_asset_liab_dfs(
+            drift_client, vat, loop
+        )
+    # cached tuple is ((levs_none, levs_init, levs_maint), user_keys)
+    (levs_none, levs_init, levs_maint) = st.session_state["asset_liab_data"][0]
+    user_keys = st.session_state["asset_liab_data"][1]
+
+    if mode == 2:
+        levs_maint = [x for x in levs_maint if int(x["health"]) <= int(health)]
+    if mode == 3:
+        levs_init = [x for x in levs_init if int(x["health"]) <= int(health)]

     df: pd.DataFrame
     match mode:
-        case 0: # nothing
+        case 0:  # nothing
             df = pd.DataFrame(levs_none, index=user_keys)
-        case 1: # liq within 50% of oracle
+        case 1:  # liq within 50% of oracle
             df = pd.DataFrame(levs_none, index=user_keys)
-        case 2: # maint. health < 10%
-            user_keys = [x['user_key'] for x in levs_init]
-            df = pd.DataFrame(levs_init, index=user_keys)
+        case 2:  # maint. health < specified % (filtered above)
+            user_keys = [x["user_key"] for x in levs_maint]
+            df = pd.DataFrame(levs_maint, index=user_keys)
-        case 3: # init. health < 10%
-            user_keys = [x['user_key'] for x in levs_maint]
-            df = pd.DataFrame(levs_maint, index=user_keys)
+        case 3:  # init. health < specified % (filtered above)
+            user_keys = [x["user_key"] for x in levs_init]
+            df = pd.DataFrame(levs_init, index=user_keys)
-
+
     def get_rattt(row):
         calculations = [
-            ('all_assets', lambda v: v if v > 0 else 0),  # Simplified from v / row['spot_asset'] * row['spot_asset']
-            ('all', lambda v: v / row['spot_asset'] * (row['perp_liability'] + row['spot_liability']) if v > 0 else 0),
-            ('all_perp', lambda v: v / row['spot_asset'] * row['perp_liability'] if v > 0 else 0),
-            ('all_spot', lambda v: v / row['spot_asset'] * row['spot_liability'] if v > 0 else 0),
-            (f'perp_{perp_market_inspect}_long', lambda v: v / row['spot_asset'] * row['net_p'][perp_market_inspect] if v > 0 and row['net_p'][0] > 0 else 0),
-            (f'perp_{perp_market_inspect}_short', lambda v: v / row['spot_asset'] * row['net_p'][perp_market_inspect] if v > 0 and row['net_p'][perp_market_inspect] < 0 else 0),
+            (
+                "all_assets",
+                lambda v: v if v > 0 else 0,
+            ),
+            (
+                "all",
+                lambda v: (
+                    v
+                    / row["spot_asset"]
+                    * (row["perp_liability"] + row["spot_liability"])
+                    if v > 0
+                    else 0
+                ),
+            ),
+            ("leverage", lambda v: row["leverage"]),
+            (
+                "all_perp",
+                lambda v: v / row["spot_asset"] * row["perp_liability"] if v > 0 else 0,
+            ),
+            (
+                "all_spot",
+                lambda v: v / row["spot_asset"] * row["spot_liability"] if v > 0 else 0,
+            ),
+            (
+                f"perp_{perp_market_inspect}_long",
+                lambda v: (
+                    v / row["spot_asset"] * row["net_p"][perp_market_inspect]
+                    # check the inspected market, not market 0, for long exposure
+                    if v > 0 and row["net_p"][perp_market_inspect] > 0
+                    else 0
+                ),
+            ),
+            (
+                f"perp_{perp_market_inspect}_short",
+                lambda v: (
+                    v / row["spot_asset"] * row["net_p"][perp_market_inspect]
+                    if v > 0 and row["net_p"][perp_market_inspect] < 0
+                    else 0
+                ),
+            ),
         ]

         series_list = []
         for suffix, calc_func in calculations:
-            series = pd.Series([calc_func(val) for key, val in row['net_v'].items()])
-            series.index = [f'spot_{x}_{suffix}' for x in series.index]
+            series = pd.Series([calc_func(val) for key, val in row["net_v"].items()])
+            series.index = [f"spot_{x}_{suffix}" for x in series.index]
             series_list.append(series)

         return pd.concat(series_list)
@@ -62,84 +111,136 @@ def get_rattt(row):
     df = pd.concat([df, df.apply(get_rattt, axis=1)], axis=1)

     def calculate_effective_leverage(group):
-        assets = group['all_assets']
-        liabilities = group['all_liabilities']
+        assets = group["all_assets"]
+        liabilities = group["all_liabilities"]
         return liabilities / assets if assets != 0 else 0

     def format_with_checkmark(value, condition, mode, financial=False):
         if financial:
-            formatted_value = f"{value:,.2f}"
+            formatted_value = f"${value:,.2f}"
         else:
             formatted_value = f"{value:.2f}"
-
+
         if condition and mode > 0:
             return f"{formatted_value} ✅"
         return formatted_value

-    res = pd.DataFrame({
-        ('spot' + str(i)): (
-            df[f"spot_{i}_all_assets"].sum(),
-            format_with_checkmark(
-                df[f"spot_{i}_all"].sum(),
-                0 < df[f"spot_{i}_all"].sum() < 1_000_000,
-                mode,
-                financial=True
-            ),
-            format_with_checkmark(
-                calculate_effective_leverage({
-                    'all_assets': df[f"spot_{i}_all_assets"].sum(),
-                    'all_liabilities': df[f"spot_{i}_all"].sum()
-                }),
-                0 < calculate_effective_leverage({
-                    'all_assets': df[f"spot_{i}_all_assets"].sum(),
-                    'all_liabilities': df[f"spot_{i}_all"].sum()
-                }) < 2,
-                mode
-            ),
-            df[f"spot_{i}_all_spot"].sum(),
-            df[f"spot_{i}_all_perp"].sum(),
-            df[f"spot_{i}_perp_{perp_market_inspect}_long"].sum(),
-            df[f"spot_{i}_perp_{perp_market_inspect}_short"].sum()
-        ) for i in range(NUMBER_OF_SPOT)
-    }, index=['all_assets', 'all_liabilities', 'effective_leverage', 'all_spot', 'all_perp',
-              f'perp_{perp_market_inspect}_long',
-
f'perp_{perp_market_inspect}_short']).T - - res['all_liabilities'] = res['all_liabilities'].astype(str) - res['effective_leverage'] = res['effective_leverage'].astype(str) - - if env == 'mainnet': #mainnet_spot_market_configs + res = pd.DataFrame( + { + ("spot" + str(i)): ( + df[f"spot_{i}_all_assets"].sum(), + format_with_checkmark( + df[f"spot_{i}_all"].sum(), + 0 < df[f"spot_{i}_all"].sum() < 1_000_000, + mode, + financial=True, + ), + format_with_checkmark( + calculate_effective_leverage( + { + "all_assets": df[f"spot_{i}_all_assets"].sum(), + "all_liabilities": df[f"spot_{i}_all"].sum(), + } + ), + 0 + < calculate_effective_leverage( + { + "all_assets": df[f"spot_{i}_all_assets"].sum(), + "all_liabilities": df[f"spot_{i}_all"].sum(), + } + ) + < 2, + mode, + ), + df[f"spot_{i}_all_spot"].sum(), + df[f"spot_{i}_all_perp"].sum(), + df[f"spot_{i}_perp_{perp_market_inspect}_long"].sum(), + df[f"spot_{i}_perp_{perp_market_inspect}_short"].sum(), + ) + for i in range(NUMBER_OF_SPOT) + }, + index=[ + "all_assets", + "all_liabilities", + "effective_leverage", + "all_spot", + "all_perp", + f"perp_{perp_market_inspect}_long", + f"perp_{perp_market_inspect}_short", + ], + ).T + + res["all_liabilities"] = res["all_liabilities"].astype(str) + res["effective_leverage"] = res["effective_leverage"].astype(str) + + if env == "mainnet": # mainnet_spot_market_configs res.index = [x.symbol for x in mainnet_spot_market_configs] - res.index.name = 'spot assets' # type: ignore + res.index.name = "spot assets" # type: ignore return res, df -def asset_liab_matrix_page(loop: AbstractEventLoop, vat: Vat, drift_client: DriftClient, env='mainnet'): - mode = st.selectbox("Options", options, format_func=lambda x: labels[x]) + +def asset_liab_matrix_page( + loop: AbstractEventLoop, vat: Vat, drift_client: DriftClient, env="mainnet" +): + print(f"[ASSET-LIAB-MATRIX] context set?: {st.session_state['context']}") + if st.session_state["context"] == False: + st.write("Please load dashboard before viewing this page") + return + + mode = st.selectbox("Mode", options, format_func=lambda x: labels[x]) if mode is None: mode = 0 - perp_market_inspect = st.selectbox("Market index", [x.market_index for x in mainnet_perp_market_configs]) + perp_market_inspect = st.selectbox( + "Market index", [x.market_index for x in mainnet_perp_market_configs] + ) if perp_market_inspect is None: perp_market_inspect = 0 - - res, df = get_matrix(loop, vat, drift_client, env, mode, perp_market_inspect) - st.write(f"{df.shape[0]} users for scenario") + slider_disabled = mode not in [2, 3] + + health = st.slider( + "Upper health cutoff (only for init. health and maint. 
health modes)", + min_value=1, + max_value=100, + value=50, + step=1, + disabled=slider_disabled, + ) + + if mode not in [1, 2]: + if "margin" not in st.session_state: + st.session_state["margin"] = get_matrix(loop, vat, drift_client) + res = st.session_state["margin"][0] + df = st.session_state["margin"][1] + else: + res = st.session_state["margin"][0] + df = st.session_state["margin"][1] + else: + import time + start = time.time() + res, df = get_matrix( + loop, vat, drift_client, env, mode, perp_market_inspect, health + ) + st.write(f"matrix in {time.time() - start:.2f} seconds") + + st.write(f"{df.shape[0]} users for scenario") st.write(res) - tabs = st.tabs(['FULL'] + [x.symbol for x in mainnet_spot_market_configs]) + tabs = st.tabs(["FULL"] + [x.symbol for x in mainnet_spot_market_configs]) tabs[0].dataframe(df) for idx, tab in enumerate(tabs[1:]): - important_cols = [x for x in df.columns if 'spot_'+str(idx) in x] - toshow = df[['spot_asset', 'net_usd_value']+important_cols] - toshow = toshow[toshow[important_cols].abs().sum(axis=1)!=0].sort_values(by="spot_"+str(idx)+'_all', ascending=False) - tab.write(f'{ len(toshow)} users with this asset to cover liabilities') + important_cols = [x for x in df.columns if "spot_" + str(idx) in x] + toshow = df[["spot_asset", "net_usd_value"] + important_cols] + toshow = toshow[toshow[important_cols].abs().sum(axis=1) != 0].sort_values( + by="spot_" + str(idx) + "_all", ascending=False + ) + tab.write(f"{ len(toshow)} users with this asset to cover liabilities") tab.dataframe(toshow) - diff --git a/src/sections/health.py b/src/sections/health.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/sections/liquidation_curves.py b/src/sections/liquidation_curves.py index 2de971b..079fbc8 100644 --- a/src/sections/liquidation_curves.py +++ b/src/sections/liquidation_curves.py @@ -3,36 +3,76 @@ from driftpy.constants.numeric_constants import ( BASE_PRECISION, PRICE_PRECISION, + QUOTE_PRECISION, ) +from driftpy.drift_user import DriftUser +from driftpy.math.margin import MarginCategory import numpy as np -import plotly.graph_objects as go # type: ignore +import plotly.graph_objects as go # type: ignore import streamlit as st +from solders.pubkey import Pubkey # type: ignore +import pandas as pd +from driftpy.constants.perp_markets import mainnet_perp_market_configs -options = [0, 1, 2] -labels = ["SOL-PERP", "BTC-PERP", "ETH-PERP"] +# options = [0, 1, 2] +# labels = ["SOL-PERP", "BTC-PERP", "ETH-PERP"] -def get_liquidation_curve(vat: Vat, market_index: int): - liquidations_long: list[tuple[float, float]] = [] - liquidations_short: list[tuple[float, float]] = [] +options = {market.symbol: market.market_index for market in mainnet_perp_market_configs} + + +def get_liquidation_curve(vat: Vat, market_index: int, use_liq_buffer=False): + liquidations_long: list[tuple[float, float, Pubkey]] = [] + liquidations_short: list[tuple[float, float, Pubkey]] = [] market_price = vat.perp_oracles.get(market_index) - market_price_ui = market_price.price / PRICE_PRECISION # type: ignore + market_price_ui = market_price.price / PRICE_PRECISION # type: ignore for user in vat.users.user_map.values(): + user: DriftUser = user perp_position = user.get_perp_position(market_index) if perp_position is not None: - liquidation_price = user.get_perp_liq_price(market_index) + liquidation_price = user.get_perp_liq_price(market_index) + if liquidation_price is not None: liquidation_price_ui = liquidation_price / PRICE_PRECISION - position_size = 
abs(perp_position.base_asset_amount) / BASE_PRECISION - position_notional = position_size * market_price_ui + + if use_liq_buffer: + perp_market = user.drift_client.get_perp_market_account( + market_index + ) + oracle_str = str(perp_market.amm.oracle) + prev_price = user.drift_client.account_subscriber.cache[ + "oracle_price_data" + ][oracle_str].price + user.drift_client.account_subscriber.cache["oracle_price_data"][ + oracle_str + ].price = liquidation_price + + position_notional = ( + user.get_margin_requirement(MarginCategory.MAINTENANCE, 100) + - user.get_total_collateral() + ) / QUOTE_PRECISION + user.drift_client.account_subscriber.cache["oracle_price_data"][ + oracle_str + ].price = prev_price + + else: + position_size = ( + abs(perp_position.base_asset_amount) / BASE_PRECISION + ) + position_notional = position_size * market_price_ui + is_zero = round(position_notional) == 0 is_short = perp_position.base_asset_amount < 0 is_long = perp_position.base_asset_amount > 0 if is_zero: continue if is_short and liquidation_price_ui > market_price_ui: - liquidations_short.append((liquidation_price_ui, position_notional)) + liquidations_short.append( + (liquidation_price_ui, position_notional, user.user_public_key) + ) elif is_long and liquidation_price_ui < market_price_ui: - liquidations_long.append((liquidation_price_ui, position_notional)) + liquidations_long.append( + (liquidation_price_ui, position_notional, user.user_public_key) + ) else: pass # print(f"liquidation price for user {user.user_public_key} is {liquidation_price_ui} and market price is {market_price_ui} - is_short: {is_short} - size {position_size} - notional {position_notional}") @@ -46,13 +86,23 @@ def get_liquidation_curve(vat: Vat, market_index: int): # for (price, size) in liquidations_short: # print(f"Short liquidation for {size} @ {price}") - return plot_liquidation_curves(liquidations_long, liquidations_short, market_price_ui) - + return plot_liquidation_curves( + liquidations_long, liquidations_short, market_price_ui + ), (liquidations_long, liquidations_short) + + def plot_liquidation_curves(liquidations_long, liquidations_short, market_price_ui): - def filter_outliers(liquidations, upper_bound_multiplier=2.0, lower_bound_multiplier=0.5): + def filter_outliers( + liquidations, upper_bound_multiplier=2.0, lower_bound_multiplier=0.5 + ): """Filter out liquidations based on a range multiplier of the market price.""" - return [(price, notional) for price, notional in liquidations - if lower_bound_multiplier * market_price_ui <= price <= upper_bound_multiplier * market_price_ui] + return [ + (price, notional) + for price, notional, _ in liquidations + if lower_bound_multiplier * market_price_ui + <= price + <= upper_bound_multiplier * market_price_ui + ] def aggregate_liquidations(liquidations): """Aggregate liquidations to calculate cumulative notional amounts.""" @@ -64,27 +114,36 @@ def aggregate_liquidations(liquidations): def prepare_data_for_plot(aggregated_data, reverse=False): """Prepare and sort data for plotting, optionally reversing the cumulative sum for descending plots.""" sorted_prices = sorted(aggregated_data.keys(), reverse=reverse) - cumulative_notional = np.cumsum([aggregated_data[price] for price in sorted_prices]) + cumulative_notional = np.cumsum( + [aggregated_data[price] for price in sorted_prices] + ) # if reverse: # cumulative_notional = cumulative_notional[::-1] # Reverse cumulative sum for descending plots return sorted_prices, cumulative_notional # Filter outliers based on defined criteria 
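+    # e.g. with market_price_ui = 150, filter_outliers(liqs, 2.0, 0.5) keeps only
+    # entries whose liquidation price falls in [75, 300]; anything outside that
+    # band is treated as an outlier and excluded from the curve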
- liquidations_long = filter_outliers(liquidations_long, 2, 0.2) # Example multipliers for long positions - liquidations_short = filter_outliers(liquidations_short, 5, 0.5) # Example multipliers for short positions + liquidations_long = filter_outliers( + liquidations_long, 2, 0.5 + ) # Example multipliers for long positions + liquidations_short = filter_outliers( + liquidations_short, 3, 0.5 + ) # Example multipliers for short positions # Aggregate and prepare data aggregated_long = aggregate_liquidations(liquidations_long) aggregated_short = aggregate_liquidations(liquidations_short) - long_prices, long_cum_notional = prepare_data_for_plot(aggregated_long, reverse=True) + long_prices, long_cum_notional = prepare_data_for_plot( + aggregated_long, reverse=True + ) short_prices, short_cum_notional = prepare_data_for_plot(aggregated_short) - print(sum(long_cum_notional)) - print(sum(short_cum_notional)) + st.write( + "long/short cum notional:", sum(long_cum_notional), sum(short_cum_notional) + ) if not long_prices or not short_prices: - print("No data available for plotting.") + st.warning("No data available for plotting.") return None # Create Plotly figures @@ -92,46 +151,100 @@ def prepare_data_for_plot(aggregated_data, reverse=False): short_fig = go.Figure() # Add traces for long and short positions - long_fig.add_trace(go.Scatter(x=long_prices, y=long_cum_notional, mode='lines', name='Long Positions', - line=dict(color='purple', width=2))) - short_fig.add_trace(go.Scatter(x=short_prices, y=short_cum_notional, mode='lines', name='Short Positions', - line=dict(color='turquoise', width=2))) + long_fig.add_trace( + go.Scatter( + x=long_prices, + y=long_cum_notional, + mode="lines", + name="Long Positions", + line=dict(color="purple", width=2), + ) + ) + short_fig.add_trace( + go.Scatter( + x=short_prices, + y=short_cum_notional, + mode="lines", + name="Short Positions", + line=dict(color="turquoise", width=2), + ) + ) # Update layout with axis titles and grid settings - long_fig.update_layout(title='Long Liquidation Curve', - xaxis_title='Asset Price', - yaxis_title='Liquidations (Notional)', - xaxis=dict(showgrid=True), - yaxis=dict(showgrid=True)) - - short_fig.update_layout(title='Short Liquidation Curve', - xaxis_title='Asset Price', - yaxis_title='Liquidations (Notional)', - xaxis=dict(showgrid=True), - yaxis=dict(showgrid=True)) + long_fig.update_layout( + title="Long Liquidation Curve", + xaxis_title="Asset Price", + yaxis_title="Liquidations (Notional)", + xaxis=dict(showgrid=True), + yaxis=dict(showgrid=True), + ) + + short_fig.update_layout( + title="Short Liquidation Curve", + xaxis_title="Asset Price", + yaxis_title="Liquidations (Notional)", + xaxis=dict(showgrid=True), + yaxis=dict(showgrid=True), + ) return long_fig, short_fig - + def plot_liquidation_curve(vat: Vat): + print(f"[LIQUIDATIONS] context set?: {st.session_state['context']}") + if st.session_state["context"] == False: + st.write("Please load dashboard before viewing this page") + return + st.write("Liquidation Curves") - market_index = st.selectbox( - "Market", - options, - format_func=lambda x: labels[x], - ) + default_index = list(options.keys()).index("SOL-PERP") + market = st.selectbox("Market", options=list(options.keys()), index=default_index) + + market_index = options[market] # type: ignore if market_index is None: market_index = 0 - (long_fig, short_fig) = get_liquidation_curve(vat, market_index) + (long_fig, short_fig), ( + liquidations_long, + liquidations_short, + ) = get_liquidation_curve(vat, 
int(market_index), True) + (long_fig2, short_fig2), ( + liquidations_long2, + liquidations_short2, + ) = get_liquidation_curve(vat, int(market_index), False) long_col, short_col = st.columns([1, 1]) + use_liq_buffer = st.radio( + "use liq buffer in details:", [True, False], horizontal=True + ) + with long_col: + st.header("liq notional") st.plotly_chart(long_fig, use_container_width=True) + st.header("position notional") + st.plotly_chart(long_fig2, use_container_width=True) + with st.expander("user details"): + st.dataframe( + pd.DataFrame( + liquidations_long2 if use_liq_buffer else liquidations_long, + columns=["liq_price", "notional", "user_pubkey"], + ).sort_values("notional", ascending=False) + ) + with short_col: + st.header("liq notional") st.plotly_chart(short_fig, use_container_width=True) + st.header("position notional") + st.plotly_chart(short_fig2, use_container_width=True) + with st.expander("user details"): + st.dataframe( + pd.DataFrame( + liquidations_short2 if use_liq_buffer else liquidations_short, + columns=["liq_price", "notional", "user_pubkey"], + ).sort_values("notional", ascending=False) + ) diff --git a/src/sections/margin_model.py b/src/sections/margin_model.py new file mode 100644 index 0000000..6882b03 --- /dev/null +++ b/src/sections/margin_model.py @@ -0,0 +1,815 @@ +import copy +import streamlit as st +import pandas as pd # type: ignore + +from dataclasses import dataclass +from asyncio import AbstractEventLoop +from typing import Iterator + + +from driftpy.drift_client import DriftClient +from driftpy.drift_user import DriftUser +from driftpy.pickle.vat import Vat +from driftpy.accounts.types import DataAndSlot +from driftpy.types import SpotMarketAccount +from driftpy.constants.spot_markets import mainnet_spot_market_configs +from driftpy.math.margin import MarginCategory +from driftpy.constants.numeric_constants import ( + SPOT_BALANCE_PRECISION, + PERCENTAGE_PRECISION, + QUOTE_PRECISION, + PRICE_PRECISION, + BASE_PRECISION, +) + +from sections.asset_liab_matrix import get_matrix, NUMBER_OF_SPOT # type: ignore +from utils import aggregate_perps, drift_client_deep_copy, vat_deep_copy, drift_user_deep_copy + + +@dataclass +class LiquidationInfo: + spot_market_index: int + user_public_key: str + notional_liquidated: float + spot_asset_scaled_balance: int + + +spot_fields = [ + "deposit_balance", + "borrow_balance", + "initial_asset_weight", + "maintenance_asset_weight", + "initial_liability_weight", + "maintenance_liability_weight", + "optimal_utilization", + "optimal_borrow_rate", + "max_borrow_rate", + "market_index", + "scale_initial_asset_weight_start", +] + +margin_scalars = { + "USDC": 1.3, + "USDT": 1.3, + "SOL": 1.1, + "mSOL": 1.1, + "jitoSOL": 1.1, + "bSOL": 1.1, + "INF": 1.1, + "dSOL": 1.1, +} + +DEFAULT_MARGIN_SCALAR = 0.5 + +stables = [0, 5, 18] +sol_and_lst = [1, 2, 6, 8, 16, 17] +sol_eco = [7, 9, 11, 12, 13, 14, 15, 19] +wrapped = [3, 4] + +spot_market_indexes = [market.market_index for market in mainnet_spot_market_configs] + +index_options = {} +index_options["All"] = spot_market_indexes +index_options["Stables"] = stables +index_options["Solana and LST"] = sol_and_lst +index_options["Solana Ecosystem"] = sol_eco +index_options["Wrapped"] = wrapped +index_options.update( + {market.symbol: [market.market_index] for market in mainnet_spot_market_configs} +) + + +def margin_model(loop: AbstractEventLoop, dc: DriftClient): + print(f"[MARGIN-MODEL] context set?: {st.session_state['context']}") + if st.session_state["context"] == False: + 
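+        # "context" is flipped to True by the Start Dashboard flow in main.py
+        # once the vat snapshot and margin data are in st.session_state; the
+        # section pages check it before rendering
+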
st.write("Please load dashboard before viewing this page") + return + + st.header("Margin Model") + if "vat" not in st.session_state: + st.write("No Vat loaded.") + return + + vat: Vat = st.session_state["vat"] + + open_interest = 0 + for perp_market in vat.perp_markets.values(): + oracle_price = ( + vat.perp_oracles.get(perp_market.data.market_index).price / PRICE_PRECISION + ) + oi_long = perp_market.data.amm.base_asset_amount_long / BASE_PRECISION + oi_short = abs(perp_market.data.amm.base_asset_amount_short) / BASE_PRECISION + oi = max(oi_long, oi_short) * oracle_price + open_interest += oi + print(open_interest) + + import time + start = time.time() + aggregated_users = aggregate_perps(vat, loop) + print(f"aggregated users in {time.time() - start}") + # aggregated_users: list[DriftUser] + # if "agg_perps" not in st.session_state: + # aggregated_users = aggregate_perps(vat, loop) + # st.session_state["agg_perps"] = aggregated_users + # else: + # aggregated_users = st.session_state["agg_perps"] + + # spot_df: pd.DataFrame + # if "spot_df" not in st.session_state: + spot_df = get_spot_df(vat.spot_markets.values(), vat) + # st.session_state["spot_df"] = spot_df + # else: + # spot_df = st.session_state["spot_df"] + + stable_df = spot_df[spot_df.index.isin(stables)] + + col1, col2, col3 = st.columns([1, 1, 1]) + total_deposits = spot_df["deposit_balance"].sum() + total_stable_deposits_notional = stable_df["deposit_balance"].sum() + total_stable_borrows_notional = stable_df["borrow_balance"].sum() + total_stable_utilization = ( + total_stable_borrows_notional / total_stable_deposits_notional + ) + total_margin_available = spot_df["max_margin_extended"].sum() + maximum_exchange_leverage = total_margin_available / total_deposits + + with col1: + st.markdown( + f"##### Total Stable Collateral: `${total_stable_deposits_notional:,.2f}` #####" + ) + + st.markdown( + f"##### Total Stable Liabilities: `${total_stable_borrows_notional:,.2f}` #####" + ) + + st.markdown( + f"##### Total Stable Utilization: `{(total_stable_utilization * 100):,.2f}%` #####" + ) + + margin_df: pd.DataFrame + res: pd.DataFrame + if "margin" not in st.session_state: + st.session_state["margin"] = get_matrix(loop, vat, dc) + margin_df = st.session_state["margin"][1] + res = st.session_state["margin"][0] + else: + margin_df = st.session_state["margin"][1] + res = st.session_state["margin"][0] + + spot_df["all_liabilities"] = spot_df["symbol"].map(res["all_liabilities"]) + spot_df["all_liabilities"] = ( + spot_df["all_liabilities"].str.replace(r"[$,]", "", regex=True).astype(float) + ) + + spot_df.insert(3, "all_liabilities", spot_df.pop("all_liabilities")) + + spot_df["perp_leverage"] = spot_df["all_liabilities"] / spot_df["deposit_balance"] + + spot_df.insert(6, "perp_leverage", spot_df.pop("perp_leverage")) + + spot_df.rename(columns={"leverage": "actual_utilization"}, inplace=True) + + spot_df.insert(4, "optimal_utilization", spot_df.pop("optimal_utilization")) + total_margin_utilized_notional = sum( + margin_df[f"spot_{i}_all"].sum() for i in range(NUMBER_OF_SPOT) + ) + total_margin_utilized = total_margin_utilized_notional / total_margin_available + + # sanity check + assert float(total_margin_utilized_notional) > float(open_interest) + + actual_exchange_leverage = total_margin_utilized_notional / total_deposits + + with col2: + st.markdown( + f"##### Total Margin Extended: `${total_margin_available:,.2f}` #####" + ) + + st.markdown( + f"##### Total Margin Utilized: `${total_margin_utilized_notional:,.2f}` #####" + ) + + 
st.markdown( + f"##### Total Margin Utilized: `{(total_margin_utilized * 100):,.2f}%` #####" + ) + + with col3: + st.markdown(f"##### Total Collateral: `${total_deposits:,.2f}` #####") + + st.markdown( + f"##### Maximum Exchange Leverage: `{maximum_exchange_leverage:,.2f}` #####" + ) + + st.markdown( + f"##### Actual Exchange Leverage: `{actual_exchange_leverage:,.2f}` #####" + ) + + st.markdown("#### Spot Market Overview ####") + + col1, col2 = st.columns([1, 1]) + + default_index = list(index_options.keys()).index("All") + with col1: + selected_option = st.selectbox( + "Select Markets", options=list(index_options.keys()), index=default_index + ) + + with col2: + oracle_warp = st.text_input("Oracle Warp (%, 0-100)", 50) + + liqs_long, liqs_short = get_liquidations( + aggregated_users, vat, int(oracle_warp), loop + ) + + totals_long = {} + users_long = {} + + for liquidation_info in liqs_long: + print( + f"notional liquidated long: {liquidation_info.notional_liquidated} user public key: {liquidation_info.user_public_key}" + ) + totals_long[liquidation_info.spot_market_index] = ( + totals_long.get(liquidation_info.spot_market_index, 0) + + liquidation_info.notional_liquidated + ) + users_long[liquidation_info.user_public_key] = ( + liquidation_info.notional_liquidated, + liquidation_info.spot_asset_scaled_balance / SPOT_BALANCE_PRECISION, + ) + + totals_short = {} + users_short = {} + print("\n\n\n") + + for liquidation_info in liqs_short: + print( + f"notional liquidated short: {liquidation_info.notional_liquidated} user public key: {liquidation_info.user_public_key}" + ) + totals_short[liquidation_info.spot_market_index] = ( + totals_short.get(liquidation_info.spot_market_index, 0) + + liquidation_info.notional_liquidated + ) + users_short[liquidation_info.user_public_key] = ( + liquidation_info.notional_liquidated, + liquidation_info.spot_asset_scaled_balance / SPOT_BALANCE_PRECISION, + ) + + # this is where we will hit ec2 server for liquidation numbers + # users_long, users_short = make_request(oracle_warp) + + long_users_list = pd.DataFrame.from_dict(users_long, orient="index") + short_users_list = pd.DataFrame.from_dict(users_short, orient="index") + + long_users_list.columns = ["liquidation_notional", "scaled_balance"] + short_users_list.columns = ["liquidation_notional", "scaled_balance"] + + long_liq_notional = pd.Series(totals_long, name="long_liq_notional") + short_liq_notional = pd.Series(totals_short, name="short_liq_notional") + + long_liq_notional = long_liq_notional.reindex(spot_df.index, fill_value=0) + short_liq_notional = short_liq_notional.reindex(spot_df.index, fill_value=0) + + spot_df = spot_df.join(long_liq_notional) + spot_df = spot_df.join(short_liq_notional) + + spot_df.insert(6, "long_liq_notional", spot_df.pop("long_liq_notional")) + spot_df.insert(7, "short_liq_notional", spot_df.pop("short_liq_notional")) + + selected_indexes = index_options[selected_option] # type: ignore + + filtered_df = spot_df.loc[spot_df.index.isin(selected_indexes)] + + st.dataframe(display_formatted_df(filtered_df)) + + st.markdown("#### Spot Market Analysis ####") + + index_options_list = list(index_options.values()) + + all_analytics_df = get_analytics_df(index_options_list[0], spot_df) + sol_and_lst_basis_df = get_analytics_df(index_options_list[2], spot_df) + wrapped_basis_df = get_analytics_df(index_options_list[4], spot_df) + wif_df = get_analytics_df([10], spot_df) + + total_margin_available = all_analytics_df["target_margin_extended"].sum() + + st.markdown( + f"##### Total 
Target Margin Extended: `${total_margin_available:,.2f}` #####" + ) + + tabs = st.tabs(list(index_options.keys())) + + def get_recommendations(analytics_df: pd.DataFrame, selected_indexes: list[int]): + st.write(selected_indexes) + + for idx, tab in enumerate(tabs): + with tab: + if idx == 0: + current_analytics_df = all_analytics_df + st.dataframe(display_formatted_df(current_analytics_df)) + elif idx == 2: + st.dataframe(display_formatted_df(sol_and_lst_basis_df)) + elif idx == 4: + st.dataframe(display_formatted_df(wrapped_basis_df)) + elif set(index_options_list[idx]).issubset(set(sol_and_lst)): + filtered_df = sol_and_lst_basis_df[ + sol_and_lst_basis_df.index == index_options_list[idx][0] + ] + st.dataframe(display_formatted_df(filtered_df)) + elif set(index_options_list[idx]).issubset(wrapped): + filtered_df = wrapped_basis_df[ + wrapped_basis_df.index == index_options_list[idx][0] + ] + st.dataframe(display_formatted_df(filtered_df)) + elif set(index_options_list[idx]).issubset(set([10])): # WIF + st.dataframe(display_formatted_df(wif_df)) + else: + analytics_df = get_analytics_df(index_options_list[idx], spot_df) + st.dataframe(display_formatted_df(analytics_df)) + + # if len(index_options_list[idx]) == 1: + # get_recommendations(analytics_df, index_options_list[idx]) + + (levs_none, _, _) = st.session_state["asset_liab_data"][0] + user_keys = st.session_state["asset_liab_data"][1] + + df = pd.DataFrame(levs_none, index=user_keys) + + lev, size = get_size_and_lev(df, selected_indexes) + + col1, col2 = st.columns([1, 1]) + + with col1: + with st.expander("users by size (selected overview category)"): + st.dataframe(size) + + with col2: + with st.expander("users by lev (selected overview category)"): + st.dataframe(lev) + + col1, col2 = st.columns([1, 1]) + + with col1: + with st.expander("users long liquidation (offset)"): + st.dataframe(display_formatted_df(long_users_list)) + + with col2: + with st.expander("users short liquidation (offset)"): + st.dataframe(display_formatted_df(short_users_list)) + + +def get_size_and_lev(df: pd.DataFrame, market_indexes: list[int]): + def has_target(net_v, market_indexes): + if isinstance(net_v, dict): + return any(net_v.get(idx, 0) != 0 for idx in market_indexes) + return False + + lev = df.sort_values(by="leverage", ascending=False) + + lev = lev[lev.apply(lambda row: has_target(row["net_v"], market_indexes), axis=1)] + + lev["selected_assets"] = lev["net_v"].apply( + lambda net_v: { + idx: net_v.get(idx, 0) for idx in market_indexes if net_v.get(idx, 0) != 0 + } + ) + + size = df.sort_values(by="spot_asset", ascending=False) + size = size[ + size.apply(lambda row: has_target(row["net_v"], market_indexes), axis=1) + ] + + size["selected_assets"] = size["net_v"].apply( + lambda net_v: { + idx: net_v.get(idx, 0) for idx in market_indexes if net_v.get(idx, 0) != 0 + } + ) + + lev.pop("user_key") + size.pop("user_key") + + lev.insert(2, "selected_assets", lev.pop("selected_assets")) + size.insert(2, "selected_assets", size.pop("selected_assets")) + + return lev, size + + +def get_analytics_df(market_indexes: list[int], spot_df: pd.DataFrame): + (levs_none, _, _) = st.session_state["asset_liab_data"][0] + user_keys = st.session_state["asset_liab_data"][1] + + df = pd.DataFrame(levs_none, index=user_keys) + + analytics_df = pd.DataFrame() + analytics_df.index = spot_df.index + columns = [ + "symbol", + "deposit_balance", + "borrow_balance", + "all_liabilities", + "perp_leverage", + "max_margin_extended", + ] + analytics_df[columns] = 
spot_df[columns] + + def is_basis(market_indexes): + sol_lst_set = set(sol_and_lst) + wrapped_set = set(wrapped) + wif_set = set([10]) + + is_sol_lst = market_indexes.issubset(sol_lst_set) + is_wrapped = market_indexes.issubset(wrapped_set) + is_wif = market_indexes.issubset(wif_set) + + return is_sol_lst or is_wrapped or is_wif + + if is_basis(set(market_indexes)): + analytics_df["basis_short"] = spot_df.apply( + lambda row: get_basis_trade_notional(row, df), axis=1 + ) + + def calculate_target_margin(row): + current_max_margin = row["max_margin_extended"] + deposit_balance = row["deposit_balance"] + all_liabilities = row["all_liabilities"] + + target_leverage = 1.0 + + if deposit_balance > current_max_margin: + new_target_margin = deposit_balance + elif all_liabilities > deposit_balance: + new_target_margin = all_liabilities / target_leverage + else: + new_target_margin = min(deposit_balance, all_liabilities / target_leverage) + + return new_target_margin + + analytics_df["target_margin_extended"] = analytics_df.apply( + calculate_target_margin, axis=1 + ) + + safety_factor = 1.1 + analytics_df["target_margin_extended"] = ( + analytics_df["target_margin_extended"] * safety_factor + ) + + return analytics_df.loc[analytics_df.index.isin(market_indexes)] + + +def get_perp_short(df, market_index, basis_index): + new_column_name = f"perp_{basis_index}_short" + + def calculate_perp_short(row): + net_v = row["net_v"][market_index] + net_p = row["net_p"][basis_index] + spot_asset = row["spot_asset"] + + condition = net_v > 0 and net_p < 0 + value = net_v / spot_asset * net_p if condition else 0 + + return value + + df[new_column_name] = df.apply(calculate_perp_short, axis=1) + + return df[new_column_name] + + +def get_basis_trade_notional(row, df): + basis_index = -1 + market_index = row.name + if market_index in sol_and_lst and market_index != 0: + basis_index = 0 + elif market_index == 3: + basis_index = 1 + elif market_index == 4: + basis_index = 2 + + if basis_index == -1: + return 0 + + perp_short_series: pd.Series + if f"perp_short_series_{market_index}" in st.session_state: + perp_short_series = st.session_state[f"perp_short_series_{market_index}"] + else: + perp_short_series = get_perp_short(df, market_index, basis_index) + st.session_state[f"perp_short_series_{market_index}"] = perp_short_series + + total_perp_short = perp_short_series.sum() + print( + f"Total perp short for market_index {market_index} basis_index {basis_index}: {total_perp_short}" + ) + return abs(total_perp_short) + + +def get_spot_df(accounts: Iterator[DataAndSlot[SpotMarketAccount]], vat: Vat): + transformations = { + "deposit_balance": lambda x: x / SPOT_BALANCE_PRECISION, + "borrow_balance": lambda x: x / SPOT_BALANCE_PRECISION, + "initial_asset_weight": lambda x: x / PERCENTAGE_PRECISION * 100, + "maintenance_asset_weight": lambda x: x / PERCENTAGE_PRECISION * 100, + "initial_liability_weight": lambda x: x / PERCENTAGE_PRECISION * 100, + "maintenance_liability_weight": lambda x: x / PERCENTAGE_PRECISION * 100, + "optimal_utilization": lambda x: x / PERCENTAGE_PRECISION, + "optimal_borrow_rate": lambda x: x / PERCENTAGE_PRECISION, + "max_borrow_rate": lambda x: x / PERCENTAGE_PRECISION, + "scale_initial_asset_weight_start": lambda x: x / QUOTE_PRECISION, + } + + data = [ + {field: getattr(account.data, field) for field in spot_fields} + for account in accounts + ] + + df = pd.DataFrame(data) + + if "market_index" in spot_fields: + df.set_index("market_index", inplace=True) + + for column, transformation in 
transformations.items(): + if column in df.columns: + df[column] = df[column].apply(transformation) + + df["leverage"] = df["borrow_balance"] / df["deposit_balance"] + + df["symbol"] = df.index.map(lambda idx: mainnet_spot_market_configs[idx].symbol) + + def notional(row, balance_type): + market_price = vat.spot_oracles.get(row.name).price / PRICE_PRECISION # type: ignore + size = row[balance_type] + notional_value = size * market_price + return notional_value + + df["deposit_balance"] = df.apply(notional, balance_type="deposit_balance", axis=1) + df["borrow_balance"] = df.apply(notional, balance_type="borrow_balance", axis=1) + + df.insert(0, "symbol", df.pop("symbol")) + df.insert(1, "deposit_balance", df.pop("deposit_balance")) + df.insert(2, "borrow_balance", df.pop("borrow_balance")) + df.insert(4, "leverage", df.pop("leverage")) + df.insert( + 5, + "scale_initial_asset_weight_start", + df.pop("scale_initial_asset_weight_start"), + ) + + df.rename( + columns={"scale_initial_asset_weight_start": "max_margin_extended"}, + inplace=True, + ) + df = df.sort_index() + + return df + + +def display_formatted_df(df): + format_dict = { + "deposit_balance": "${:,.2f}", + "borrow_balance": "${:,.2f}", + "perp_liabilities": "${:,.2f}", + "initial_asset_weight": "{:.2%}", + "maintenance_asset_weight": "{:.2%}", + "initial_liability_weight": "{:.2%}", + "maintenance_liability_weight": "{:.2%}", + "optimal_utilization": "{:.2%}", + "actual_utilization": "{:.2%}", + "optimal_borrow_rate": "{:.2%}", + "max_borrow_rate": "{:.2%}", + "market_index": "{:}", + "max_margin_extended": "${:,.2f}", + "perp_leverage": "{:.2f}", + "target_margin_extended": "${:,.2f}", + "basis_short": "${:,.2f}", + "leverage": "{:.2f}", + "health": "{:.2%}%", + "short_liq_notional": "${:,.2f}", + "long_liq_notional": "${:,.2f}", + "liquidation_notional": "${:,.2f}", + "scaled_balance": "{:,.2f}", + } + + df.rename(columns={"all_liabilities": "perp_liabilities"}, inplace=True) + + styled_df = df.style.format(format_dict) + + return styled_df + + +def get_liquidations( + aggregated_users: list[DriftUser], + vat: Vat, + oracle_warp: int, + loop: AbstractEventLoop, +) -> tuple[list[LiquidationInfo], list[LiquidationInfo]]: + long_liquidations: list[LiquidationInfo] = [] + short_liquidations: list[LiquidationInfo] = [] + import time + print("starting liquidations") + start = time.time() + + drift_client = drift_client_deep_copy(vat.drift_client) + user_copies = [drift_user_deep_copy(user, drift_client) for user in aggregated_users] + print(f"deep copied users in {time.time() - start}") + vat_copy = vat_deep_copy(vat) + + current_sol_perp_price = vat_copy.perp_oracles.get(0).price / PRICE_PRECISION # type: ignore + sol_perp_oracle = vat_copy.perp_markets.get(0).data.amm.oracle # type: ignore + for user in user_copies: + user_total_spot_value = user.get_spot_market_asset_value() + # print(f"user public key {user.user_public_key}") + # print(f"user total spot value {user_total_spot_value}") + # ignore users that are bankrupt, of which there should not be many + if user_total_spot_value <= 0: + continue + + # print(user.get_user_account().spot_positions) + # print(user.get_user_account().perp_positions) + + # save the current user account state + saved_user_account = copy.deepcopy(user.get_user_account()) + + print("\n\n") + for i, spot_position in enumerate(saved_user_account.spot_positions): + # ignore borrows + from driftpy.types import is_variant + + if is_variant(spot_position.balance_type, "Borrow"): + continue + spot_market_index 
= spot_position.market_index + precision = vat_copy.spot_markets.get(spot_market_index).data.decimals # type: ignore + + # create a copy to force isolated margin upon + fake_user_account = copy.deepcopy(user.get_user_account()) + + # save the current oracle price, s.t. we can reset it after our calculations + spot_oracle_pubkey = vat_copy.spot_markets.get(spot_position.market_index).data.oracle # type: ignore + saved_price_data = vat_copy.spot_oracles.get(spot_position.market_index) + saved_price = saved_price_data.price + try: + # figure out what proportion of the user's collateral is in this spot market + spot_position_token_amount = user.get_token_amount( + spot_position.market_index + ) + # print(f"spot position scaled balance {spot_position.scaled_balance}") + # print(f"spot position token amount {spot_position_token_amount}") + # print( + # f"spot pos normalized {spot_position_token_amount / (10 ** precision)}" + # ) + collateral_in_spot_asset_usd = ( + spot_position_token_amount / (10**precision) + ) * (saved_price / PRICE_PRECISION) + proportion_of_net_collateral = collateral_in_spot_asset_usd / ( + user_total_spot_value / QUOTE_PRECISION + ) + + p = [ + pos + for pos in fake_user_account.perp_positions + if pos.market_index == 0 + ][0] + perp_position = copy.deepcopy(p) + + # this shouldn't ever happen, but if it does, we'll skip this user + if proportion_of_net_collateral > 1: + print("proportion of net collateral > 1") + continue + + # anything less than 1% of their collateral is dust relative to the rest of the account, so it's negligibly small + if proportion_of_net_collateral < 0.01: + print(f"proportion of net collateral < 0.01") + continue + + # scale the perp position size by the proportion of net collateral to mock isolated margin + perp_position.base_asset_amount = int( + perp_position.base_asset_amount * proportion_of_net_collateral + ) + perp_position.quote_asset_amount = int( + perp_position.quote_asset_amount * proportion_of_net_collateral + ) + + # if the position is so small that it's proportionally less than 1e-7 units of the asset, it's dust & negligible + if abs(perp_position.base_asset_amount) < 100: + print(f"perp position base asset amount < 100") + continue + + # replace the user's UserAccount with the mocked isolated margin account + fake_user_account.spot_positions = [copy.deepcopy(spot_position)] + fake_user_account.perp_positions = [copy.deepcopy(perp_position)] + user.account_subscriber.user_and_slot.data = fake_user_account + + # print(user.get_user_account().spot_positions) + # print(user.get_user_account().perp_positions) + + # set the oracle price to the price after an oracle_warp percent decrease + # it doesn't make sense to increase the collateral price, because nobody would ever get liquidated if their collateral went up in value + # our short / long numbers are evaluated based on the type of the perp position, which is... 
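+
+                # A quick sanity check of the repricing math below (hypothetical
+                # numbers): with saved_price = 140 * PRICE_PRECISION (a $140.00
+                # collateral oracle) and oracle_warp = 50,
+                #     max(saved_price * (1 - 50 / 100), 0)
+                # yields a $70.00 shocked price; oracle_warp = 100 floors the
+                # shocked price at zero, so prices never go negative.
+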
+                shocked_price = max(saved_price * (1 - (int(oracle_warp) / 100)), 0)
+
+                user.drift_client.account_subscriber.cache["oracle_price_data"][
+                    str(spot_oracle_pubkey)
+                ].price = shocked_price
+
+                # the long / short buckets are decided by the sign of the
+                # aggregate perp position, evaluated here
+                is_short = perp_position.base_asset_amount < 0
+
+                # get the notional value that we would liquidate at the shocked
+                # price; users are liquidated to 100 "margin ratio units" above
+                # their maintenance margin requirements
+                shocked_margin_requirement = user.get_margin_requirement(
+                    MarginCategory.MAINTENANCE, 100
+                )
+
+                shocked_spot_asset_value = user.get_spot_market_asset_value(
+                    MarginCategory.MAINTENANCE, include_open_orders=True, strict=False
+                )
+                shocked_upnl = user.get_unrealized_pnl(
+                    True, MarginCategory.MAINTENANCE, strict=False
+                )
+
+                # print(f"STREAMLIT spot asset value: {shocked_spot_asset_value}")
+                # print(f"STREAMLIT upnl: {shocked_upnl}")
+
+                shocked_total_collateral = user.get_total_collateral(
+                    MarginCategory.MAINTENANCE, False
+                )
+
+                # if the user has more collateral than margin required, the
+                # position by definition cannot be in liquidation
+                if shocked_total_collateral >= shocked_margin_requirement:
+                    continue
+
+                shocked_notional = (
+                    shocked_margin_requirement - shocked_total_collateral
+                ) / QUOTE_PRECISION
+
+                # get the liquidation price of the weighted & aggregated SOL-PERP
+                # position after the collateral price shock
+                shocked_liquidation_price = user.get_perp_liq_price(0) / PRICE_PRECISION
+
+                # some forced isolated accounts will have such a tiny position
+                # that their liquidation price is a tiny negative number; we do
+                # not care, because the position size is negligible and that
+                # liquidation price will never be hit
+                if shocked_liquidation_price < 0:
+                    continue
+
+                print(f"total collateral: {shocked_total_collateral}")
+                print(f"spot market index: {spot_position.market_index}")
+                print(f"notional: {shocked_notional}")
+                print(f"shocked price: {shocked_price}")
+                print(f"sol perp price {current_sol_perp_price}")
+                print(f"liquidation price {shocked_liquidation_price}")
+                print(f"in liquidation: {user.can_be_liquidated()}")
+                print(f"proportion of net collateral {proportion_of_net_collateral}")
+
+                print("\n\n")
+
+                if is_short:
+                    # short aggregate position: record as a short liquidation
+                    if user.can_be_liquidated():
+                        short_liquidations.append(
+                            LiquidationInfo(
+                                spot_market_index=spot_market_index,
+                                user_public_key=user.user_public_key,
+                                notional_liquidated=shocked_notional,
+                                spot_asset_scaled_balance=spot_position.scaled_balance,
+                            )
+                        )
+                else:
+                    # long aggregate position: record as a long liquidation
+                    if user.can_be_liquidated():
+                        long_liquidations.append(
+                            LiquidationInfo(
+                                spot_market_index=spot_market_index,
+                                user_public_key=user.user_public_key,
+                                notional_liquidated=shocked_notional,
+                                spot_asset_scaled_balance=spot_position.scaled_balance,
+                            )
+                        )
+            finally:
+                # reset the user object to the original state
+                user.drift_client.account_subscriber.cache["oracle_price_data"][
+ str(spot_oracle_pubkey) + ].price = saved_price + user.account_subscriber.user_and_slot.data = saved_user_account + vat_copy.spot_oracles[spot_position.market_index] = saved_price_data + + print(len(long_liquidations)) + print(len(short_liquidations)) + + original_spot_oracles = copy.deepcopy(vat.spot_oracles) + current_spot_oracles = copy.deepcopy(vat_copy.spot_oracles) + + for key, val in current_spot_oracles.items(): + print(f"[COPY] market index: {key}, price: {val.price}") + original_oracle = original_spot_oracles.get(key, None) + print(f"[ORIGINAL] market index: {key}, price: {original_oracle.price}") + if val.price != original_oracle.price: + assert False, f"[DOES NOT MATCH] market index: {key}, price: {val.price}, original price: {original_oracle.price}" + else: + print(f"[MATCHES] market index: {key}, price: {val.price}, original price: {original_oracle.price}") + + return long_liquidations, short_liquidations diff --git a/src/sections/ob.py b/src/sections/ob.py index 781d3c7..4b80502 100644 --- a/src/sections/ob.py +++ b/src/sections/ob.py @@ -1,4 +1,3 @@ - import asyncio import heapq import time @@ -13,8 +12,14 @@ from driftpy.drift_client import DriftClient from driftpy.pickle.vat import Vat -from driftpy.constants.spot_markets import mainnet_spot_market_configs, devnet_spot_market_configs -from driftpy.constants.perp_markets import mainnet_perp_market_configs, devnet_perp_market_configs +from driftpy.constants.spot_markets import ( + mainnet_spot_market_configs, + devnet_spot_market_configs, +) +from driftpy.constants.perp_markets import ( + mainnet_perp_market_configs, + devnet_perp_market_configs, +) from scenario import get_usermap_df import requests @@ -23,21 +28,14 @@ def fetch_ob_data(coin, size): # Define the POST request details post_url = "https://api.hyperliquid.xyz/info" - payload = { - 'type': 'metaAndAssetCtxs' - } - payload2 = { - "type": 'l2Book', - "coin": coin - } + payload = {"type": "metaAndAssetCtxs"} + payload2 = {"type": "l2Book", "coin": coin} - post_headers = { - "Content-Type": "application/json" - } + post_headers = {"Content-Type": "application/json"} results = {} - for nom, pay in [('hl_cxt', payload), ('hl_book', payload2)]: + for nom, pay in [("hl_cxt", payload), ("hl_book", payload2)]: # Send the POST request post_response = requests.post(post_url, json=pay, headers=post_headers) # Print the POST request response @@ -52,10 +50,10 @@ def fetch_ob_data(coin, size): # Define the GET request URL get_url = "https://dlob.drift.trade/l2" get_params = { - "marketName": coin+"-PERP", + "marketName": coin + "-PERP", "depth": 5, "includeOracle": "true", - "includeVamm": "true" + "includeVamm": "true", } # Send the GET request @@ -69,27 +67,26 @@ def fetch_ob_data(coin, size): else: print("Error:", get_response.text) - - results['dr_book'] = get_response.json() + results["dr_book"] = get_response.json() def calculate_average_fill_price_dr(order_book, volume): # Adjust volume to match size precision (1e9) volume = volume - - bids = order_book['bids'] - asks = order_book['asks'] + + bids = order_book["bids"] + asks = order_book["asks"] print(f'{float(bids[0]["price"])/1e6}/{float(asks[0]["price"])/1e6}') - + def average_price(levels, volume, is_buy): total_volume = 0 total_cost = 0.0 - + for level in levels: # Price is in 1e6 precision, size is in 1e9 precision - price = float(level['price']) / 1e6 - size = float(level['size']) / 1e9 - + price = float(level["price"]) / 1e6 + size = float(level["size"]) / 1e9 + if total_volume + size >= volume: # Only take 
the remaining required volume at this level remaining_volume = volume - total_volume @@ -100,35 +97,33 @@ def average_price(levels, volume, is_buy): # Take the whole size at this level total_cost += size * price total_volume += size - + if total_volume < volume: - raise ValueError("Insufficient volume in the order book to fill the order") - + raise ValueError( + "Insufficient volume in the order book to fill the order" + ) + return total_cost / volume - + try: buy_price = average_price(asks, volume, is_buy=True) sell_price = average_price(bids, volume, is_buy=False) except ValueError as e: return str(e) - - return { - "average_buy_price": buy_price, - "average_sell_price": sell_price - } + return {"average_buy_price": buy_price, "average_sell_price": sell_price} def calculate_average_fill_price_hl(order_book, volume): - buy_levels = order_book['levels'][1] # Bids (lower prices first) - sell_levels = order_book['levels'][0] # Asks (higher prices first) + buy_levels = order_book["levels"][1] # Bids (lower prices first) + sell_levels = order_book["levels"][0] # Asks (higher prices first) def average_price(levels, volume): total_volume = 0 total_cost = 0.0 for level in levels: - px = float(level['px']) - sz = float(level['sz']) + px = float(level["px"]) + sz = float(level["sz"]) if total_volume + sz >= volume: # Only take the remaining required volume at this level @@ -142,7 +137,9 @@ def average_price(levels, volume): total_volume += sz if total_volume < volume: - raise ValueError("Insufficient volume in the order book to fill the order") + raise ValueError( + "Insufficient volume in the order book to fill the order" + ) return total_cost / volume @@ -152,34 +149,30 @@ def average_price(levels, volume): except ValueError as e: return str(e) - return { - "average_buy_price": buy_price, - "average_sell_price": sell_price - } + return {"average_buy_price": buy_price, "average_sell_price": sell_price} + r = calculate_average_fill_price_hl(results["hl_book"], size) + d = calculate_average_fill_price_dr(results["dr_book"], size) + return (r, d, results["dr_book"]["oracle"] / 1e6, results["hl_cxt"]) - r = calculate_average_fill_price_hl(results['hl_book'], size) - d = calculate_average_fill_price_dr(results['dr_book'], size) - return (r,d, results['dr_book']['oracle']/1e6, results['hl_cxt']) def ob_cmp_page(): - if st.button('Refresh'): + if st.button("Refresh"): st.cache_data.clear() s1, s2 = st.columns(2) - coin = s1.selectbox('coin:', ['SOL','BTC','ETH']) - size = s2.number_input('size:', min_value=.1, value=1.0, help='in base units') + coin = s1.selectbox("coin:", ["SOL", "BTC", "ETH"]) + size = s2.number_input("size:", min_value=0.1, value=1.0, help="in base units") hl, dr, dr_oracle, hl_ctx = fetch_ob_data(coin, size) - - uni_id = [i for (i,x) in enumerate(hl_ctx[0]['universe']) if coin==x['name']] + uni_id = [i for (i, x) in enumerate(hl_ctx[0]["universe"]) if coin == x["name"]] # hl_oracle = hl_ctx o1, o2 = st.columns(2) - o1.header('hyperliquid') - o1.write(float(hl_ctx[1][uni_id[0]]['oraclePx'])) + o1.header("hyperliquid") + o1.write(float(hl_ctx[1][uni_id[0]]["oraclePx"])) o1.write(hl) - o2.header('drift') + o2.header("drift") o2.write(dr_oracle) - o2.write(dr) \ No newline at end of file + o2.write(dr) diff --git a/src/sections/scenario.py b/src/sections/scenario.py index 08d57c2..60479ee 100644 --- a/src/sections/scenario.py +++ b/src/sections/scenario.py @@ -15,70 +15,106 @@ from scenario import get_usermap_df + def price_shock_plot(price_scenario_users: list[Any], oracle_distort: 
float):
     levs = price_scenario_users
-    dfs = [pd.DataFrame(levs[2][i]) for i in range(len(levs[2]))] \
-        + [pd.DataFrame(levs[0])] \
-        + [pd.DataFrame(levs[1][i]) for i in range(len(levs[1]))]
-
+    dfs = (
+        [pd.DataFrame(levs[2][i]) for i in range(len(levs[2]))]
+        + [pd.DataFrame(levs[0])]
+        + [pd.DataFrame(levs[1][i]) for i in range(len(levs[1]))]
+    )
+
     spot_bankrs = []
     for df in dfs:
-        spot_b_t1 = -(df[(df['spot_asset']= 0:
-            return 100
-        elif total_collateral <= 0:
-            return 0
-        else:
-            return round(
-                min(100, max(0, (1 - maintenance_margin_req / total_collateral) * 100))
-            )
 
 def to_financial(num):
     num_str = str(num)
@@ -74,7 +93,12 @@ def load_newest_files(directory: Optional[str] = None) -> dict[str, str]:
 
 # function assumes that you have already subscribed
 # the use of websocket configs in here doesn't matter because the maps are never subscribed to
-async def load_vat(dc: DriftClient, pickle_map: dict[str, str]) -> Vat:
+async def load_vat(
+    dc: DriftClient,
+    pickle_map: dict[str, str],
+    loop: AbstractEventLoop,
+    env: str = "prod",
+) -> Vat:
     perp = MarketMap(
         MarketMapConfig(
             dc.program, MarketType.Perp(), MarketMapWebsocketConfig(), dc.connection
@@ -109,4 +133,211 @@ async def load_vat(dc: DriftClient, pickle_map: dict[str, str]) -> Vat:
         perp_oracles_filename,
     )
 
+    if env == "dev":
+        users = []
+        for user in vat.users.values():
+            value = user.get_net_spot_market_value(None) + user.get_unrealized_pnl(
+                True
+            )
+            users.append(
+                (value, user.user_public_key, user.get_user_account_and_slot())
+            )
+        users.sort(key=lambda x: x[0], reverse=True)
+        vat.users.clear()
+        for user in users[:100]:
+            await vat.users.add_pubkey(user[1], user[2])
+
+        print(vat.users.values())
+
     return vat
+
+
+def clear_local_pickles(directory: str):
+    for filename in os.listdir(directory):
+        os.remove(directory + "/" + filename)
+
+
+def aggregate_perps(vat: Vat, loop: AbstractEventLoop):
+    print("aggregating perps")
+
+    def aggregate_perp(user: DriftUser) -> Optional[DriftUser]:
+        agg_perp = PerpPosition(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
+        sol_price = vat.perp_oracles.get(0).price
+        user_account = user.get_user_account()
+        sol_market = vat.perp_markets.get(0).data
+        for perp_position in user_account.perp_positions:
+            if perp_position.base_asset_amount == 0:
+                continue
+            asset_price = vat.perp_oracles.get(perp_position.market_index).price  # type: ignore
+
+            market = vat.perp_markets.get(perp_position.market_index)
+            # margin ratio transform
+            sol_margin_ratio = calculate_market_margin_ratio(
+                sol_market, agg_perp.base_asset_amount, MarginCategory.INITIAL
+            )
+            margin_ratio = calculate_market_margin_ratio(
+                market.data, perp_position.base_asset_amount, MarginCategory.INITIAL
+            )
+            sol_margin_scalar = 1 / (sol_margin_ratio / MARGIN_PRECISION)
+            curr_margin_scalar = 1 / (margin_ratio / MARGIN_PRECISION)
+
+            # simple price conversion
+            exchange_rate = sol_price / asset_price
+            exchange_rate_normalized = exchange_rate / PRICE_PRECISION
+            new_baa = perp_position.base_asset_amount * exchange_rate_normalized
+
+            # apply the margin ratio transform
+            new_baa_adjusted = new_baa * (sol_margin_scalar / curr_margin_scalar)
+
+            # aggregate
+            agg_perp.base_asset_amount += new_baa_adjusted
+            agg_perp.quote_asset_amount += perp_position.quote_asset_amount * (
+                sol_margin_scalar / curr_margin_scalar
+            )
+
+        if agg_perp.base_asset_amount == 0:
+            return None
+
+        # force use this new fake user account for all sdk functions
+        user_account.perp_positions = [agg_perp]
+        ds = user.account_subscriber.user_and_slot
+        ds.data = user_account
+        user.account_subscriber.user_and_slot = ds
+        return user
+
+    users_list = list(vat.users.values())
+
+    import copy
+
+    # deep copy the usermap; without this, aggregation would mutate vat.users
+    # and corrupt every other page that reads the shared Vat
+    usermap = UserMap(UserMapConfig(vat.drift_client, UserMapWebsocketConfig()))
+    for user in users_list:
+        loop.run_until_complete(
+            usermap.add_pubkey(
+                copy.deepcopy(user.user_public_key),
+                copy.deepcopy(user.get_user_account_and_slot()),
+            )
+        )
+
+    aggregated_users = [
+        user
+        for user in (aggregate_perp(user) for user in usermap.values())
+        if user is not None
+    ]
+
+    aggregated_users = sorted(
+        aggregated_users,
+        key=lambda x: x.get_total_collateral(MarginCategory.MAINTENANCE)
+        + x.get_total_perp_position_value(MarginCategory.MAINTENANCE),
+    )
+
+    return aggregated_users
+
+
+import copy
+from driftpy.account_subscription_config import AccountSubscriptionConfig
+from driftpy.accounts.types import DataAndSlot
+
+
+def drift_client_deep_copy(dc: DriftClient) -> DriftClient:
+    from solana.rpc.async_api import AsyncClient
+
+    perp_markets = []
+    spot_markets = []
+    oracle_price_data = {}
+    state_account = None
+    for market in dc.get_perp_market_accounts():
+        perp_markets.append(DataAndSlot(0, copy.deepcopy(market)))
+
+    for market in dc.get_spot_market_accounts():
+        spot_markets.append(DataAndSlot(0, copy.deepcopy(market)))
+
+    for pubkey, oracle in dc.account_subscriber.cache["oracle_price_data"].items():
+        oracle_price_data[copy.deepcopy(pubkey)] = copy.deepcopy(oracle)
+
+    if dc.get_state_account() is not None:
+        state_account = copy.deepcopy(
+            dc.account_subscriber.get_state_account_and_slot()
+        )
+
+    new_wallet = copy.deepcopy(dc.wallet)
+    new_connection = AsyncClient(copy.deepcopy(dc.connection._provider.endpoint_uri))
+
+    new_drift_client = DriftClient(
+        new_connection,
+        new_wallet,
+        account_subscription=AccountSubscriptionConfig("cached"),
+    )
+
+    new_drift_client.account_subscriber.cache["perp_markets"] = sorted(
+        perp_markets, key=lambda x: x.data.market_index
+    )
+    new_drift_client.account_subscriber.cache["spot_markets"] = sorted(
+        spot_markets, key=lambda x: x.data.market_index
+    )
+    new_drift_client.account_subscriber.cache["oracle_price_data"] = oracle_price_data
+    new_drift_client.account_subscriber.cache["state_account"] = state_account
+
+    return new_drift_client
+
+
+def drift_user_deep_copy(user: DriftUser, drift_client: DriftClient) -> DriftUser:
+    user_account_and_slot = copy.deepcopy(user.get_user_account_and_slot())
+    new_user = DriftUser(
+        drift_client,
+        copy.deepcopy(user.user_public_key),
+        account_subscription=AccountSubscriptionConfig("cached"),
+    )
+    new_user.account_subscriber.user_and_slot = user_account_and_slot
+    return new_user
+
+
+def drift_user_stats_deep_copy(
+    stats: DriftUserStats, drift_client: DriftClient
+) -> DriftUserStats:
+    new_user_stats = DriftUserStats(
+        drift_client,
+        copy.deepcopy(stats.user_public_key),
+        config=UserStatsSubscriptionConfig(
+            initial_data=copy.deepcopy(stats.get_account_and_slot())
+        ),
+    )
+
+    return new_user_stats
+
+
+def vat_deep_copy(vat: Vat) -> Vat:
+    import time
+
+    start = time.time()
+    new_drift_client = drift_client_deep_copy(vat.drift_client)
+    print(f"copied drift client in {time.time() - start:.2f}s")
+
+    new_user_map = UserMap(UserMapConfig(new_drift_client, UserMapWebsocketConfig()))
+    new_spot_map = MarketMap(
+        MarketMapConfig(
+            new_drift_client.program,
+            MarketType.Spot(),
+            MarketMapWebsocketConfig(),
+            new_drift_client.connection,
+        )
+    )
+    new_perp_map = MarketMap(
+        MarketMapConfig(
+            new_drift_client.program,
+            MarketType.Perp(),
+            MarketMapWebsocketConfig(),
+            new_drift_client.connection,
+        )
+    )
+    new_perp_oracles = {}
+    new_spot_oracles = {}
+
+    start = time.time()
+    for market_index, oracle in vat.perp_oracles.items():
+        new_perp_oracles[copy.deepcopy(market_index)] = copy.deepcopy(oracle)
+    print(f"copied perp oracles in {time.time() - start:.2f}s")
+
+    start = time.time()
+    for market_index, oracle in vat.spot_oracles.items():
+        new_spot_oracles[copy.deepcopy(market_index)] = copy.deepcopy(oracle)
+    print(f"copied spot oracles in {time.time() - start:.2f}s")
+
+    start = time.time()
+    for pubkey, market in vat.perp_markets.market_map.items():
+        new_perp_map.market_map[copy.deepcopy(pubkey)] = copy.deepcopy(market)
+    print(f"copied perp markets in {time.time() - start:.2f}s")
+
+    start = time.time()
+    for pubkey, market in vat.spot_markets.market_map.items():
+        new_spot_map.market_map[copy.deepcopy(pubkey)] = copy.deepcopy(market)
+    print(f"copied spot markets in {time.time() - start:.2f}s")
+
+    start = time.time()
+    for pubkey, user in vat.users.user_map.items():
+        new_user_map.user_map[str(copy.deepcopy(pubkey))] = drift_user_deep_copy(
+            user, new_drift_client
+        )
+    print(f"copied users in {time.time() - start:.2f}s")
+
+    new_vat = Vat(
+        new_drift_client, new_user_map, vat.user_stats, new_spot_map, new_perp_map
+    )
+    new_vat.perp_oracles = new_perp_oracles
+    new_vat.spot_oracles = new_spot_oracles
+
+    return new_vat
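+
+
+# Hypothetical usage sketch (nothing below is wired up yet): the deep-copy
+# helpers above exist so a price shock can be applied to a throwaway Vat
+# without mutating the live one the dashboard pages share, e.g.:
+#
+#   vat_copy = vat_deep_copy(vat)
+#   sol_oracle = vat_copy.spot_oracles.get(0)
+#   sol_oracle.price = int(sol_oracle.price * 0.5)  # 50% shock, copy only
+#   assert vat.spot_oracles.get(0).price != sol_oracle.price  # original intact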