Skip to content

Commit

Permalink
Finished 0.4
Browse files Browse the repository at this point in the history
  • Loading branch information
Arad Ganir authored and Arad Ganir committed Jan 11, 2025
1 parent a4c18ca commit c397bc8
Showing 1 changed file with 32 additions and 7 deletions.
39 changes: 32 additions & 7 deletions minitorch/module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple, List


class Module:
Expand Down Expand Up @@ -31,29 +31,54 @@ def modules(self) -> Sequence[Module]:

def train(self) -> None:
"""Set the mode of this module and all descendent modules to `train`."""

# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")
# raise NotImplementedError("Need to implement for Task 0.4")
def _train(module):
module.training = True
for m in module.modules():
_train(m)

_train(self)

def eval(self) -> None:
"""Set the mode of this module and all descendent modules to `eval`."""

# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")
# raise NotImplementedError("Need to implement for Task 0.4")

def _eval(module):
module.training = False
for m in module.modules():
_eval(m)

def named_parameters(self) -> Sequence[Tuple[str, Parameter]]:
_eval(self)

def named_parameters(self) -> list[tuple[str, Module]]:
"""Collect all the parameters of this module and its descendents.
Returns
-------
The name and `Parameter` of each ancestor parameter.
"""

# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")
#raise NotImplementedError("Need to implement for Task 0.4")

def _named_parameters(module, prefix=""):
for name, param in module._parameters.items():
yield prefix + name, param
for name, child in module._modules.items():
yield from _named_parameters(module, prefix + name + ".")

return list(_named_parameters(self))

def parameters(self) -> Sequence[Parameter]:
def parameters(self) -> list[Module]:
"""Enumerate over all the parameters of this module and its descendents."""
# TODO: Implement for Task 0.4.
raise NotImplementedError("Need to implement for Task 0.4")
#raise NotImplementedError("Need to implement for Task 0.4")
return [param for _, param in self.named_parameters()]

def add_parameter(self, k: str, v: Any) -> Parameter:
"""Manually add a parameter. Useful helper for scalar parameters.
Expand Down

0 comments on commit c397bc8

Please sign in to comment.