Skip to content

Commit

Permalink
[dynamo] Support object creation of classes with custom __new__ (pyto…
Browse files Browse the repository at this point in the history
…rch#132977)

Pull Request resolved: pytorch#132977
Approved by: https://github.com/jansel
  • Loading branch information
anijain2305 authored and pytorchmergebot committed Aug 16, 2024
1 parent a1a869f commit 8a5708b
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 3 deletions.
101 changes: 98 additions & 3 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3215,12 +3215,107 @@ def forward(self, x):
x = torch.rand(2, 2)
m = Model()

opt_m = torch.compile(backend="eager")(m)
opt_m = torch.compile(backend="eager", fullgraph=True)(m)
ref = m(x)
res = opt_m(x)
self.assertTrue(same(ref, res))
self.assertEqual(len(counters["graph_break"]), 1)
self.assertFalse("super() nn.Module.__init__" in counters["graph_break"])

def test_dunder_new_function_inlining1(self):
class Mock:
def __new__(cls):
return super().__new__(cls)

def __init__(self):
self.c = 5

def run(self, x):
return x * self.c

def fn(x):
mock = Mock()
return mock.run(x)

opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)

self.assertEqual(fn(x), opt_fn(x))

def test_dunder_new_function_inlining2(self):
class Vehicle:
def __new__(cls, *args, **kwargs):
return super(Vehicle, cls).__new__(cls)

def __init__(self, make, model, year):
self.make = make
self.model = model
self.year = year

class Car(Vehicle):
def __new__(cls, *args, **kwargs):
return super(Car, cls).__new__(cls)

def __init__(self, make, model, year, num_doors):
super(Car, self).__init__(make, model, year)
self.num_doors = num_doors

class ElectricCar(Car):
def __new__(cls, *args, **kwargs):
return super(ElectricCar, cls).__new__(cls)

def __init__(self, make, model, year, num_doors, battery_capacity):
super(ElectricCar, self).__init__(make, model, year, num_doors)
self.battery_capacity = battery_capacity

def run(self, x):
return torch.sin(x)

def fn(x):
ev = ElectricCar("Tesla", "Model S", 2022, 4, "100 kWh")
return ev.run(x)

opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

x = torch.randn(4)

self.assertEqual(fn(x), opt_fn(x))

def test_multiple_inheritance(self):
class Base1:
def __new__(cls):
return super().__new__(cls)

def __init__(self):
super().__init__()
if not hasattr(self, "base2"):
raise ValueError("Wrong MRO tracing")
self.base1 = 3

class Base2:
def __new__(cls):
return super().__new__(cls)

def __init__(self):
super().__init__()
self.base2 = 5

class Derived(Base1, Base2):
def __new__(cls):
return super().__new__(cls)

def __init__(self):
super().__init__()
self.derived = 7

def run(self, x):
return self.base1 * self.base2 * self.derived * x

def fn(x):
o = Derived()
return o.run(x)

opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.randn(4)
self.assertEqual(fn(x), opt_fn(x))

def test_class_duner_mro(self):
class ModuleA(torch.nn.Module):
Expand Down
8 changes: 8 additions & 0 deletions torch/_dynamo/polyfill.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,11 @@ def mapping_get(obj, key, value=None):
return obj.__getitem__(key)
except KeyError:
return value


def instantiate_user_defined_class_object(*args, **kwargs):
cls = args[0]
other_args = args[1:]
obj = cls.__new__(cls, *other_args, **kwargs)
obj.__init__(*other_args, **kwargs)
return obj
3 changes: 3 additions & 0 deletions torch/_dynamo/variables/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,9 @@ class TupleVariable(BaseListVariable):
def python_type(self):
return tuple

def __repr__(self) -> str:
return f"{self.__class__.__name__}(length={len(self.items)})"

def debug_repr(self):
return self.debug_repr_helper("(", ")")

Expand Down
16 changes: 16 additions & 0 deletions torch/_dynamo/variables/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,22 @@ def call_method(
).call_function(tx, [self.objvar] + args, kwargs)
else:
unimplemented("super() nn.Module.__init__")
elif self.objvar.source and inner_fn is object.__new__:
return tx.output.side_effects.track_object_new(
self.objvar.source,
self.objvar.value,
variables.UnspecializedNNModuleVariable
if issubclass(self.objvar.value, torch.nn.Module)
else UserDefinedObjectVariable,
{},
)
elif name == "__new__" and isinstance(inner_fn, types.FunctionType):
# __new__ is a staticmethod object, but accessing __new__ from the super object, as done in
# _resolved_getattr_and_source, results in a function object. If not specialized here, it will try to add
# the `self` arg and fail bind arg matching later.
return variables.UserFunctionVariable(
inner_fn, source=source
).call_function(tx, args, kwargs)
elif isinstance(inner_fn, types.FunctionType):
return variables.UserFunctionVariable(
inner_fn, source=source
Expand Down
12 changes: 12 additions & 0 deletions torch/_dynamo/variables/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,18 @@ def call_function(
seed = None
random_object = random.Random(seed)
return RandomVariable(random_object)
elif (
not self.is_standard_new()
and SideEffects.cls_supports_mutation_side_effects(self.value)
and self.source
):
return tx.inline_user_function_return(
SourcelessBuilder.create(
tx, polyfill.instantiate_user_defined_class_object
),
[self, *args],
kwargs,
)

return super().call_function(tx, args, kwargs)

Expand Down

0 comments on commit 8a5708b

Please sign in to comment.