Skip to content

Commit

Permalink
fix subscription permission and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Yifan Zhang authored and yifan committed Jan 11, 2023
1 parent a7b9469 commit 2c66e9d
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 67 deletions.
15 changes: 14 additions & 1 deletion apihub/subscription/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from sqlalchemy.orm import relationship

from ..common.db_session import Base
from .schemas import SubscriptionTier
from .schemas import SubscriptionTier, ApplicationCreate, SubscriptionPricingCreate


class Application(Base):
Expand All @@ -34,6 +34,19 @@ class Application(Base):
def __str__(self):
return f"{self.name} || {self.url}"

def to_schema(self, with_pricing=False) -> ApplicationCreate:
return ApplicationCreate(
name=self.name,
url=self.url,
description=self.description,
pricing=[
SubscriptionPricingCreate(
tier=pricing.tier, price=pricing.price, credit=pricing.credit, application=self.name,
)
for pricing in self.subscriptions_pricing
] if with_pricing else [],
)


class SubscriptionPricing(Base):
"""
Expand Down
27 changes: 6 additions & 21 deletions apihub/subscription/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,34 +76,17 @@ def get_application(self, name: str) -> ApplicationCreate:
"""
try:
application = self.get_query().filter(Application.name == name).one()
application_create = ApplicationCreate(
name=application.name,
url=application.url,
description=application.description,
pricing=[],
)
for pricing in application.subscriptions_pricing:
application_create.pricing.append(
SubscriptionPricingBase(
tier=pricing.tier,
price=pricing.price,
credit=pricing.credit,
)
)
return application_create
return application.to_schema(with_pricing=True)
except NoResultFound:
raise ApplicationException(f"Application {name} not found.")

def get_applications(self) -> List[ApplicationCreate]:
def get_applications(self, username=None) -> List[ApplicationCreate]:
"""
List applications.
:return: List of applications.
"""

return [
self.get_application(application.name)
for application in self.get_query().all()
]
applications = map(lambda x: x.to_schema(), self.get_query().all())
return list(applications)


class SubscriptionPricingQuery(BaseQuery):
Expand Down Expand Up @@ -285,6 +268,8 @@ def get_active_subscriptions(self, username: str) -> List[SubscriptionDetails]:
balance=subscription.balance,
expires_at=subscription.expires_at,
recurring=subscription.recurring,
created_by=subscription.created_by,
created_at=subscription.created_at,
)
for subscription in subscriptions
]
Expand Down
22 changes: 13 additions & 9 deletions apihub/subscription/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,17 @@ def create_application(
@router.get("/application", response_model=List[ApplicationCreate])
def get_applications(
session: Session = Depends(create_session),
username: str = Depends(require_admin),
user: UserBase = Depends(require_token),
):
"""
List all applications.
"""

try:
"""
List all applications.
"""
return ApplicationQuery(session).get_applications()
except ApplicationException:
raise HTTPException(400, "Error while retrieving applications")
applications = ApplicationQuery(session).get_applications()
return applications
except ApplicationException as e:
raise HTTPException(400, detail=str(e))


@router.get("/application/{application}", response_model=ApplicationCreate)
Expand Down Expand Up @@ -163,10 +165,12 @@ def get_active_subscriptions(
return []

return [
SubscriptionTokenResponse(
SubscriptionIn(
username=subscription.username,
application=subscription.application,
expires_time=subscription.expires_at,
tier=subscription.tier,
expires_at=subscription.expires_at,
recurring=subscription.recurring,
)
for subscription in subscriptions
]
Expand Down
60 changes: 24 additions & 36 deletions tests/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,24 @@ def api_function_2(
UserFactory._meta.sqlalchemy_session_persistence = "commit"
UserFactory(username="tester", role=UserType.USER)

ApplicationFactory._meta.sqlalchemy_session = db_session
ApplicationFactory._meta.sqlalchemy_session_persistence = "commit"
application = ApplicationFactory(name="test", url="/test")

SubscriptionPricingFactory._meta.sqlalchemy_session = db_session
SubscriptionPricingFactory._meta.sqlalchemy_session_persistence = "commit"
pricing = SubscriptionPricingFactory(
tier=SubscriptionTier.TRIAL,
price=100,
credit=100,
application="test",
)

SubscriptionFactory._meta.sqlalchemy_session = db_session
SubscriptionFactory._meta.sqlalchemy_session_persistence = "commit"

SubscriptionFactory(username="tester", application="test", credit=100)

yield TestClient(app)


Expand All @@ -138,7 +156,7 @@ def _require_user_token():
class TestApplication:
def test_create_application(self, client):
new_application = ApplicationCreate(
name="test",
name="app",
url="/test",
description="test",
pricing=[
Expand All @@ -160,53 +178,21 @@ def test_create_application(self, client):
assert response.status_code == 200

response = client.get(
"/application/test",
"/application/app",
)

response_json = response.json()
assert len(response_json["pricing"]) == 3

def test_list_application(self, client, db_session):
ApplicationFactory._meta.sqlalchemy_session = db_session
ApplicationFactory._meta.sqlalchemy_session_persistence = "commit"
ApplicationFactory(name="application", url="/test")
ApplicationFactory(name="application2", url="/test2")

SubscriptionPricingFactory._meta.sqlalchemy_session = db_session
SubscriptionPricingFactory._meta.sqlalchemy_session_persistence = "commit"
SubscriptionPricingFactory(
tier=SubscriptionTier.TRIAL,
price=100,
credit=100,
application="application",
)
SubscriptionPricingFactory(
tier=SubscriptionTier.TRIAL,
price=200,
credit=200,
application="application2",
)
response = client.get("/application")
assert response.status_code == 200
response_json = response.json()
assert len(response_json) == 2
assert len(response_json) == 1

def test_get_application(self, client, db_session):
ApplicationFactory._meta.sqlalchemy_session = db_session
ApplicationFactory._meta.sqlalchemy_session_persistence = "commit"
ApplicationFactory(name="application", url="/test")

SubscriptionPricingFactory._meta.sqlalchemy_session = db_session
SubscriptionPricingFactory._meta.sqlalchemy_session_persistence = "commit"
SubscriptionPricingFactory(
tier=SubscriptionTier.TRIAL,
price=100,
credit=100,
application="application",
)

response = client.get(
"/application/application",
"/application/test",
)
assert response.status_code == 200
response_json = response.json()
Expand Down Expand Up @@ -257,6 +243,8 @@ def test_get_all_subscriptions(self, client):
"/subscription",
)
assert response.status_code == 200
response_json = response.json()
assert len(response_json) == 1

def test_create_subscription_not_existing_user(self, client):
new_subscription = SubscriptionIn(
Expand Down

0 comments on commit 2c66e9d

Please sign in to comment.