Skip to content

Commit

Permalink
Fixing 05bit#44 - add trailing underscore to positional argument name…
Browse files Browse the repository at this point in the history
…s in Manager.get(), Manager.create(), Manager.get_or_create() and Manager.create_or_get()
  • Loading branch information
rudyryk committed Nov 30, 2016
1 parent 684d63e commit 9c0bc9b
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
except ImportError:
aiomysql = None

__version__ = '0.5.5'
__version__ = '0.5.6'

__all__ = [
### High level API ###
Expand Down Expand Up @@ -132,10 +132,10 @@ def is_connected(self):
return self.database._async_conn is not None

@asyncio.coroutine
def get(self, source, *args, **kwargs):
def get(self, source_, *args, **kwargs):
"""Get the model instance.
:param source: model or base query for lookup
:param source_: model or base query for lookup
Example::
Expand All @@ -148,12 +148,12 @@ async def my_async_func():
"""
yield from self.connect()

if isinstance(source, peewee.Query):
query = source
if isinstance(source_, peewee.Query):
query = source_
model = query.model_class
else:
query = source.select()
model = source
query = source_.select()
model = source_

conditions = list(args) + [(getattr(model, k) == v)
for k, v in kwargs.items()]
Expand All @@ -168,11 +168,11 @@ async def my_async_func():
raise model.DoesNotExist

@asyncio.coroutine
def create(self, model, **data):
def create(self, model_, **data):
"""Create a new object saved to database.
"""
inst = model(**data)
query = model.insert(**dict(inst._data))
inst = model_(**data)
query = model_.insert(**dict(inst._data))

pk = yield from self.execute(query)
if pk is None:
Expand All @@ -183,19 +183,19 @@ def create(self, model, **data):
return inst

@asyncio.coroutine
def get_or_create(self, model, defaults=None, **kwargs):
def get_or_create(self, model_, defaults=None, **kwargs):
"""Try to get an object or create it with the specified defaults.
Return 2-tuple containing the model instance and a boolean
indicating whether the instance was created.
"""
try:
return (yield from self.get(model, **kwargs)), False
except model.DoesNotExist:
return (yield from self.get(model_, **kwargs)), False
except model_.DoesNotExist:
data = defaults or {}
data.update({k: v for k, v in kwargs.items()
if not '__' in k})
return (yield from self.create(model, **data)), True
return (yield from self.create(model_, **data)), True

@asyncio.coroutine
def update(self, obj, only=None):
Expand Down Expand Up @@ -243,19 +243,19 @@ def delete(self, obj, recursive=False, delete_nullable=False):
return (yield from self.execute(query))

@asyncio.coroutine
def create_or_get(self, model, **kwargs):
def create_or_get(self, model_, **kwargs):
"""Try to create new object with specified data. If object already
exists, then try to get it by unique fields.
"""
try:
return (yield from self.create(model, **kwargs)), True
return (yield from self.create(model_, **kwargs)), True
except peewee.IntegrityError:
query = []
for field_name, value in kwargs.items():
field = getattr(model, field_name)
field = getattr(model_, field_name)
if field.unique or field.primary_key:
query.append(field == value)
return (yield from self.get(model, *query)), False
return (yield from self.get(model_, *query)), False

@asyncio.coroutine
def execute(self, query):
Expand Down Expand Up @@ -773,6 +773,7 @@ def __len__(self):
def _get_result_wrapper(self, query):
"""Get result wrapper class.
"""
db = query.database
if query._tuples:
QRW = db.get_result_wrapper(RESULTS_TUPLES)
elif query._dicts:
Expand Down Expand Up @@ -807,6 +808,7 @@ class AsyncRawQueryWrapper(AsyncQueryWrapper):
def _get_result_wrapper(self, query):
"""Get raw query result wrapper class.
"""
db = query.database
if query._tuples:
QRW = db.get_result_wrapper(RESULTS_TUPLES)
elif query._dicts:
Expand Down

0 comments on commit 9c0bc9b

Please sign in to comment.