Skip to content

Commit

Permalink
Merge pull request #37 from marcpinet/refactor-more-robust-model
Browse files Browse the repository at this point in the history
refactor: more robust model
  • Loading branch information
marcpinet authored Apr 24, 2024
2 parents 955e0ed + 95eac25 commit 1242803
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions neuralnetlib/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray, epochs: int, batch_size:
validation_data: Tuple of validation data and labels
callbacks: List of callback objects (e.g., EarlyStopping)
"""
x_train = np.array(x_train)
y_train = np.array(y_train)

if validation_data is not None:
x_test, y_test = validation_data
x_test = np.array(x_test)
y_test = np.array(y_test)

if callbacks:
callback_metrics = set()
Expand Down Expand Up @@ -197,11 +204,14 @@ def fit(self, x_train: np.ndarray, y_train: np.ndarray, epochs: int, batch_size:
print()

def evaluate(self, x_test: np.ndarray, y_test: np.ndarray) -> float:
x_test = np.array(x_test)
y_test = np.array(y_test)
predictions = self.forward_pass(x_test)
loss = self.loss_function(y_test, predictions)
return loss

def predict(self, X: np.ndarray) -> np.ndarray:
X = np.array(X)
return self.forward_pass(X, training=False)

def save(self, filename: str):
Expand Down

0 comments on commit 1242803

Please sign in to comment.