From c397bc8d770d69fda6ad7d6271c505268dc51d4e Mon Sep 17 00:00:00 2001 From: Arad Ganir Date: Fri, 10 Jan 2025 17:48:43 -0800 Subject: [PATCH] Finished 0.4 --- minitorch/module.py | 39 ++++++++++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/minitorch/module.py b/minitorch/module.py index 0a66058c..97e86c64 100644 --- a/minitorch/module.py +++ b/minitorch/module.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple, List class Module: @@ -31,15 +31,30 @@ def modules(self) -> Sequence[Module]: def train(self) -> None: """Set the mode of this module and all descendent modules to `train`.""" + # TODO: Implement for Task 0.4. - raise NotImplementedError("Need to implement for Task 0.4") + # raise NotImplementedError("Need to implement for Task 0.4") + def _train(module): + module.training = True + for m in module.modules(): + _train(m) + + _train(self) def eval(self) -> None: """Set the mode of this module and all descendent modules to `eval`.""" + # TODO: Implement for Task 0.4. - raise NotImplementedError("Need to implement for Task 0.4") + # raise NotImplementedError("Need to implement for Task 0.4") + + def _eval(module): + module.training = False + for m in module.modules(): + _eval(m) - def named_parameters(self) -> Sequence[Tuple[str, Parameter]]: + _eval(self) + + def named_parameters(self) -> list[tuple[str, Module]]: """Collect all the parameters of this module and its descendents. Returns @@ -47,13 +62,23 @@ def named_parameters(self) -> Sequence[Tuple[str, Parameter]]: The name and `Parameter` of each ancestor parameter. """ + # TODO: Implement for Task 0.4. - raise NotImplementedError("Need to implement for Task 0.4") + #raise NotImplementedError("Need to implement for Task 0.4") + + def _named_parameters(module, prefix=""): + for name, param in module._parameters.items(): + yield prefix + name, param + for name, child in module._modules.items(): + yield from _named_parameters(module, prefix + name + ".") + + return list(_named_parameters(self)) - def parameters(self) -> Sequence[Parameter]: + def parameters(self) -> list[Module]: """Enumerate over all the parameters of this module and its descendents.""" # TODO: Implement for Task 0.4. - raise NotImplementedError("Need to implement for Task 0.4") + #raise NotImplementedError("Need to implement for Task 0.4") + return [param for _, param in self.named_parameters()] def add_parameter(self, k: str, v: Any) -> Parameter: """Manually add a parameter. Useful helper for scalar parameters.