diff --git a/crud.py b/crud.py index c28848b..182658a 100644 --- a/crud.py +++ b/crud.py @@ -116,6 +116,12 @@ async def increment_withdraw_link(link: WithdrawLink) -> None: open_time=link.wait_time + int(datetime.now().timestamp()), ) +async def unincrement_withdraw_link(link: WithdrawLink) -> None: + await update_withdraw_link( + link.id, + used=link.used - 1, + open_time=link.wait_time + int(datetime.now().timestamp()), + ) async def update_withdraw_link(link_id: str, **kwargs) -> Optional[WithdrawLink]: if "is_unique" in kwargs: diff --git a/helpers.py b/helpers.py deleted file mode 100644 index 7a019fb..0000000 --- a/helpers.py +++ /dev/null @@ -1,49 +0,0 @@ -from threading import Lock -from typing import Dict - - -class CounterLock: - def __init__(self): - self.counter = 0 - self.lock = Lock() - - def acquire(self) -> bool: - self.counter += 1 - return self.lock.acquire() - - def release(self) -> None: - self.counter -= 1 - return self.lock.release() - - @property - def no_more_waiters(self) -> bool: - return self.counter == 0 - - -class NamedLock: - _lock = Lock() - _locks: Dict[str, CounterLock] = {} - - def acquire(self, name: str) -> bool: - self._lock.acquire() - - if name not in self._locks: - self._locks[name] = CounterLock() - - self._lock.release() - - return self._locks[name].acquire() - - - - def release(self, name: str): - self._lock.acquire() - - if name not in self._locks: - return self._lock.release() - - self._locks[name].release() - if self._locks[name].no_more_waiters: - del self._locks[name] - - return self._lock.release() diff --git a/lnurl.py b/lnurl.py index dafdc06..739f54e 100644 --- a/lnurl.py +++ b/lnurl.py @@ -16,12 +16,11 @@ from .crud import ( get_withdraw_link_by_hash, increment_withdraw_link, + unincrement_withdraw_link, remove_unique_withdraw_link, ) from .models import WithdrawLink -from .helpers import NamedLock -withdraw_lock = NamedLock() @withdraw_ext.get( "/api/v1/lnurl/{unique_hash}", @@ -83,31 +82,6 @@ async def api_lnurl_callback( pr: str = Query(...), id_unique_hash=None, ): - link = await _check_withdraw_link_safe(unique_hash, k1, id_unique_hash) - - try: - payment_hash = await pay_invoice( - wallet_id=link.wallet, - payment_request=pr, - max_sat=link.max_withdrawable, - extra={"tag": "withdraw", "withdrawal_link_id": link.id}, - ) - if link.webhook_url: - await dispatch_webhook(link, payment_hash, pr) - return {"status": "OK"} - except Exception as e: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, detail=f"withdraw not working. {str(e)}" - ) - -async def _check_withdraw_link_safe(unique_hash, k1, id_unique_hash) -> WithdrawLink: - try: - withdraw_lock.acquire(unique_hash) - return await _check_withdraw_link(unique_hash, k1, id_unique_hash) - finally: - withdraw_lock.release(unique_hash) - -async def _check_withdraw_link(unique_hash, k1, id_unique_hash) -> WithdrawLink: link = await get_withdraw_link_by_hash(unique_hash) now = int(datetime.now().timestamp()) if not link: @@ -119,11 +93,9 @@ async def _check_withdraw_link(unique_hash, k1, id_unique_hash) -> WithdrawLink: raise HTTPException( status_code=HTTPStatus.METHOD_NOT_ALLOWED, detail="withdraw is spent." ) - + await increment_withdraw_link(link) if link.k1 != k1: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, detail="k1 is wrong." - ) + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail="k1 is wrong.") if now < link.open_time: raise HTTPException( @@ -138,13 +110,22 @@ async def _check_withdraw_link(unique_hash, k1, id_unique_hash) -> WithdrawLink: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail="withdraw not found." ) - await increment_withdraw_link(link) - - return link - - - + try: + payment_hash = await pay_invoice( + wallet_id=link.wallet, + payment_request=pr, + max_sat=link.max_withdrawable, + extra={"tag": "withdraw", "withdrawal_link_id": link.id}, + ) + if link.webhook_url: + await dispatch_webhook(link, payment_hash, pr) + return {"status": "OK"} + except Exception as e: + await unincrement_withdraw_link(link) + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, detail=f"withdraw not working. {str(e)}" + ) def check_unique_link(link: WithdrawLink, unique_hash: str) -> bool: