Skip to content

Commit

Permalink
fix(#28): prevent infinite recursion when custom __eq__ is used on mo…
Browse files Browse the repository at this point in the history
…del.
  • Loading branch information
bradleyess committed Nov 21, 2024
1 parent 6412406 commit 3103646
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/zeal/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,19 @@ def wrapper(*args, **kwargs):
hasattr(queryset, "__zeal_patched") and queryset.__zeal_patched # type: ignore
):
return queryset

if args and args != context.get("args"):
context["args"] = args
if kwargs and kwargs != context.get("kwargs"):
context["kwargs"] = kwargs

# When comparing kwargs, we must use id() rather than == because
# __eq__ methods on model instances can trigger infinite recursion.
if kwargs:
existing_kwargs = context.get("kwargs")
if existing_kwargs is None or any(
id(v) != id(existing_kwargs.get(k)) for k, v in kwargs.items()
):
context["kwargs"] = kwargs

queryset._clone = patch_queryset_function( # type: ignore
queryset._clone, # type: ignore
parser,
Expand Down
82 changes: 82 additions & 0 deletions tests/test_patch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys

import pytest
from django.db import models
from djangoproject.social.models import User

from tests.factories import UserFactory
Expand All @@ -23,3 +24,84 @@ def test_handles_empty_querysets():
def test_handles_get_with_values():
user = UserFactory.create()
User.objects.filter(pk=user.pk).values("username").get()


def test_handles_model_eq_comparison():
"""Test that comparing model instances with custom __eq__ doesn't cause recursion"""
user1 = UserFactory.create()
user2 = UserFactory.create()
user3 = UserFactory.create()

# Set up following relationships after creation
user2.following.set([user1])
user3.following.set([user1])

# This should not cause recursion when comparing users with the same following
assert user2 != user1 # different following relationships
assert user2 == user2 # same object
assert user2 != user3 # different objects but same following


class CustomEqualityModel(models.Model):
"""Model that implements custom equality checking using related fields"""

name: models.CharField = models.CharField(max_length=100)
relation: models.ForeignKey[
"CustomEqualityModel", "CustomEqualityModel"
] = models.ForeignKey(
"self", null=True, on_delete=models.CASCADE, related_name="related"
)

def __eq__(self, other: object) -> bool:
if not isinstance(other, CustomEqualityModel):
return NotImplemented
# Explicitly access relation to trigger potential recursion
my_rel = self.relation
other_rel = other.relation
return my_rel == other_rel and self.name == other.name

class Meta:
app_label = "social"


def test_handles_custom_equality_with_relations():
"""
Ensure model equality comparisons don't cause infinite recursion
when __eq__ methods access related fields. This is important because
Django's lazy loading could trigger repeated relation lookups during
equality checks.
"""
# Create test instances
base = CustomEqualityModel.objects.create(name="base")
obj1 = CustomEqualityModel.objects.create(name="test1", relation=base)
obj2 = CustomEqualityModel.objects.create(name="test1", relation=base)
obj3 = CustomEqualityModel.objects.create(name="test2", relation=base)

assert obj1 == obj1 # Same object
assert obj1 == obj2 # Different objects, same values
assert obj1 != obj3 # Different values

result = CustomEqualityModel.objects.filter(name="test1").first()
assert result is not None
# Trigger recursion if not fixed.
_ = result.relation


def test_handles_nested_relation_equality():
"""
Ensure deep relation traversal works correctly without infinite recursion.
This is particularly important for models that compare relations in their
equality checks, as each comparison could potentially trigger a chain of
database lookups through the relationship tree.
"""
root = CustomEqualityModel.objects.create(name="root")
middle = CustomEqualityModel.objects.create(name="middle", relation=root)
leaf1 = CustomEqualityModel.objects.create(name="leaf", relation=middle)
leaf2 = CustomEqualityModel.objects.create(name="leaf", relation=middle)

assert leaf1 == leaf2 # Same name and relation
assert leaf1.relation == leaf2.relation # middle objects are the same

result = CustomEqualityModel.objects.filter(name="leaf").first()
assert result is not None
_ = result.relation.relation # Access root through middle

0 comments on commit 3103646

Please sign in to comment.