Skip to content

Commit

Permalink
refactor: more robust model
Browse files Browse the repository at this point in the history
  • Loading branch information
marcpinet committed Apr 24, 2024
1 parent 955e0ed commit c4e5094
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions neuralnetlib/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray, epochs: int, batch_size:
validation_data: Tuple of validation data and labels
callbacks: List of callback objects (e.g., EarlyStopping)
"""
x_train = np.array(x_train)
x_test = np.array(x_test)
y_train = np.array(y_train)
y_test = np.array(y_test)

if callbacks:
callback_metrics = set()
Expand Down Expand Up @@ -197,11 +201,14 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray, epochs: int, batch_size:
print()

def evaluate(self, x_test: np.ndarray, y_test: np.ndarray) -> float:
x_test = np.array(x_test)
y_test = np.array(y_test)
predictions = self.forward_pass(x_test)
loss = self.loss_function(y_test, predictions)
return loss

def predict(self, X: np.ndarray) -> np.ndarray:
X = np.array(X)
return self.forward_pass(X, training=False)

def save(self, filename: str):
Expand Down

0 comments on commit c4e5094

Please sign in to comment.