Skip to content

Commit

Permalink
margin.py: fix spot liq logic
Browse files Browse the repository at this point in the history
  • Loading branch information
0xbigz committed Dec 29, 2022
1 parent 6df2204 commit dc4336f
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 13 deletions.
87 changes: 80 additions & 7 deletions src/driftpy/clearing_house_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,63 @@ def __init__(
# get_perp_market
# get_user
# if state = cache => get cached_market else get new market
async def set_cache_last(self, CACHE=None):
"""sets the cache of the accounts to use to inspect
Args:
CACHE (dict, optional): other existing cache object - if None will pull ƒresh accounts from RPC. Defaults to None.
"""
self.cache_is_set = True

if CACHE is not None:
self.CACHE = CACHE
return

self.CACHE = {}
state = await get_state_account(self.program)
self.CACHE['state'] = state

spot_markets = []
spot_market_oracle_data = []
for i in range(state.number_of_spot_markets):
spot_market = await get_spot_market_account(
self.program, i
)
spot_markets.append(spot_market)

if i == 0:
spot_market_oracle_data.append(OracleData(
PRICE_PRECISION, 0, 1, 1, 0, True
))
else:
oracle_data = OracleData(
spot_market.historical_oracle_data.last_oracle_price, 0, 1, 1, 0, True
)
spot_market_oracle_data.append(oracle_data)

self.CACHE['spot_markets'] = spot_markets
self.CACHE['spot_market_oracles'] = spot_market_oracle_data

perp_markets = []
perp_market_oracle_data = []
for i in range(state.number_of_markets):
perp_market = await get_perp_market_account(
self.program, i
)
perp_markets.append(perp_market)

oracle_data = OracleData(
perp_market.amm.historical_oracle_data.last_oracle_price, 0, 1, 1, 0, True
)
perp_market_oracle_data.append(oracle_data)

self.CACHE['perp_markets'] = perp_markets
self.CACHE['perp_market_oracles'] = perp_market_oracle_data

user = await get_user_account(
self.program, self.authority, self.subaccount_id
)
self.CACHE['user'] = user

async def set_cache(self, CACHE=None):
"""sets the cache of the accounts to use to inspect
Expand Down Expand Up @@ -298,15 +355,22 @@ async def can_be_liquidated(self) -> bool:
return total_collateral < maintenance_req

async def get_margin_requirement(
self, margin_category: MarginCategory, liquidation_buffer: Optional[int] = 0
self, margin_category: MarginCategory, liquidation_buffer: Optional[int] = 0,
include_open_orders=True,
include_spot=True
) -> int:
perp_liability = await self.get_total_perp_liability(
margin_category, liquidation_buffer, True
margin_category, liquidation_buffer, include_open_orders
)
spot_liability = await self.get_spot_market_liability(
None, margin_category, liquidation_buffer, True
)
return perp_liability + spot_liability

result = perp_liability
if include_spot:
spot_liability = await self.get_spot_market_liability(
None, margin_category, liquidation_buffer, include_open_orders
)
result += spot_liability

return result

async def get_total_collateral(
self, margin_category: Optional[MarginCategory] = None
Expand Down Expand Up @@ -424,6 +488,15 @@ async def get_spot_market_asset_value(
spot_token_value = get_token_amount(
position.scaled_balance, spot_market, position.balance_type
)

match str(position.balance_type):
case "SpotBalanceType.Deposit()":
spot_token_value *= 1
case "SpotBalanceType.Borrow()":
spot_token_value *= -1
case _:
raise Exception(f"Invalid balance type: {position.balance_type}")

total_value += spot_token_value
continue

Expand Down Expand Up @@ -503,7 +576,7 @@ async def get_spot_liq_price(
return None

total_collateral = await self.get_total_collateral(MarginCategory.MAINTENANCE)
margin_req = await self.get_margin_requirement(MarginCategory.MAINTENANCE)
margin_req = await self.get_margin_requirement(MarginCategory.MAINTENANCE, None, True, False)
delta_liq = total_collateral - margin_req

spot_market = await self.get_spot_market(spot_market_index)
Expand Down
13 changes: 7 additions & 6 deletions src/driftpy/math/margin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def calculate_size_discount_asset_weight(
asset_weight,
):
if imf_factor == 0:
return 0
return asset_weight

size_sqrt = int((size * 10) ** 0.5) + 1
imf_num = SPOT_IMF_PRECISION + (SPOT_IMF_PRECISION / 10)
Expand Down Expand Up @@ -53,11 +53,12 @@ def calculate_asset_weight(
spot_market.initial_asset_weight,
)
case MarginCategory.MAINTENANCE:
asset_weight = calculate_size_discount_asset_weight(
size_in_amm_precision,
spot_market.imf_factor,
spot_market.maintenance_asset_weight,
)
asset_weight = spot_market.maintenance_asset_weight
# calculate_size_discount_asset_weight(
# size_in_amm_precision,
# spot_market.imf_factor,
# spot_market.maintenance_asset_weight,
# )
case None:
asset_weight = spot_market.initial_asset_weight
case _:
Expand Down

0 comments on commit dc4336f

Please sign in to comment.