-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #86 from ucsd-ets/mamba-scipy-ml-fixup
Add more workflow tests
- Loading branch information
Showing
6 changed files
with
301 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import numpy as np | ||
from keras.models import Sequential | ||
from keras.layers import Dense | ||
import pytest | ||
|
||
@pytest.fixture | ||
def simple_model(): | ||
model = Sequential() | ||
model.add(Dense(units=10, activation='relu', input_shape=(5,))) | ||
model.add(Dense(units=1, activation='sigmoid')) | ||
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) | ||
return model | ||
|
||
def test_model_training(simple_model): | ||
x_train = np.random.random((100, 5)) | ||
y_train = np.random.randint(2, size=(100, 1)) | ||
simple_model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=0) | ||
assert simple_model.layers[0].input_shape == (None, 5) | ||
assert simple_model.layers[1].output_shape == (None, 1) | ||
|
||
def test_model_evaluation(simple_model): | ||
x_test = np.random.random((20, 5)) | ||
y_test = np.random.randint(2, size=(20, 1)) | ||
loss, accuracy = simple_model.evaluate(x_test, y_test, verbose=0) | ||
assert loss >= 0 | ||
assert 0 <= accuracy <= 1 | ||
|
||
def test_model_prediction(simple_model): | ||
x_new = np.random.random((1, 5)) | ||
prediction = simple_model.predict(x_new) | ||
assert prediction.shape == (1, 1) | ||
assert 0 <= prediction <= 1 |
29 changes: 29 additions & 0 deletions
29
images/scipy-ml-notebook/workflow_tests/test_matplotlib.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
def create_simple_plot(x, y, title="Test Plot"): | ||
fig, ax = plt.subplots() | ||
ax.plot(x, y) | ||
ax.set_title(title) | ||
return fig, ax | ||
|
||
def test_number_of_plots_created(): | ||
x = np.arange(0, 10, 1) | ||
y = x ** 2 | ||
fig, ax = create_simple_plot(x, y) | ||
assert len(fig.axes) == 1, "There should be exactly one plot created" | ||
|
||
def test_plot_title_is_correct(): | ||
x = np.arange(0, 10, 1) | ||
y = x ** 2 | ||
title = "Test Plot" | ||
_, ax = create_simple_plot(x, y, title=title) | ||
assert ax.get_title() == title, f"The title should be '{title}'" | ||
|
||
def test_data_matches_input(): | ||
x = np.arange(0, 10, 1) | ||
y = x ** 2 | ||
_, ax = create_simple_plot(x, y) | ||
line = ax.lines[0] # Get the first (and in this case, only) line object | ||
np.testing.assert_array_equal(line.get_xdata(), x, "X data does not match input") | ||
np.testing.assert_array_equal(line.get_ydata(), y, "Y data does not match input") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import nltk | ||
import pytest | ||
|
||
def setup_module(module): | ||
nltk.download('punkt', download_dir='/tmp/nltk_data') | ||
nltk.download('maxent_ne_chunker', download_dir='/tmp/nltk_data') | ||
nltk.download('words', download_dir='/tmp/nltk_data') | ||
nltk.data.path.append('/tmp/nltk_data') | ||
|
||
def test_tokenization(): | ||
# Test sentence tokenization | ||
sentence = "This is a sample sentence. It consists of two sentences." | ||
tokenized_sentences = nltk.sent_tokenize(sentence) | ||
assert len(tokenized_sentences) == 2 | ||
assert tokenized_sentences[0] == "This is a sample sentence." | ||
assert tokenized_sentences[1] == "It consists of two sentences." | ||
|
||
# Test word tokenization | ||
sentence = "The quick brown fox jumps over the lazy dog." | ||
tokenized_words = nltk.word_tokenize(sentence) | ||
assert len(tokenized_words) == 10 | ||
assert tokenized_words == ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog", "."] | ||
|
||
def test_stemming(): | ||
# Test Porter stemmer | ||
porter_stemmer = nltk.PorterStemmer() | ||
words = ["running", "runs", "ran", "runner"] | ||
stemmed_words = [porter_stemmer.stem(word) for word in words] | ||
assert stemmed_words == ["run", "run", "ran", "runner"] | ||
|
||
# Test Lancaster stemmer | ||
lancaster_stemmer = nltk.LancasterStemmer() | ||
words = ["happiness", "happier", "happiest", "happily"] | ||
stemmed_words = [lancaster_stemmer.stem(word) for word in words] | ||
assert stemmed_words == ["happy", "happy", "happiest", "happy"] | ||
|
||
def test_named_entity_recognition(): | ||
sentence = "Barack Obama was the 44th President of the United States." | ||
tokens = nltk.word_tokenize(sentence) | ||
tags = nltk.pos_tag(tokens) | ||
ne_chunks = nltk.ne_chunk(tags) | ||
|
||
found_barack_obama = False | ||
found_united_states = False | ||
|
||
# Buffer for consecutive person tags | ||
person_buffer = [] | ||
|
||
def check_and_clear_buffer(): | ||
nonlocal found_barack_obama | ||
if person_buffer: | ||
person_name = " ".join(person_buffer) | ||
if person_name == "Barack Obama": | ||
found_barack_obama = True | ||
person_buffer.clear() | ||
|
||
for ne in ne_chunks: | ||
if isinstance(ne, nltk.tree.Tree): | ||
if ne.label() == "PERSON": | ||
person_buffer.append(" ".join(token[0] for token in ne)) | ||
else: | ||
# If we encounter a non-PERSON entity, check and clear the buffer | ||
check_and_clear_buffer() | ||
if ne.label() == "GPE" and " ".join(token[0] for token in ne) == "United States": | ||
found_united_states = True | ||
else: | ||
# For tokens not recognized as NE, clear the buffer | ||
check_and_clear_buffer() | ||
|
||
check_and_clear_buffer() | ||
|
||
#print(str(ne_chunks)) | ||
|
||
# Assert the named entities were found | ||
assert found_barack_obama, "Barack Obama as PERSON not found" | ||
assert found_united_states, "United States as GPE not found" | ||
|
||
# Assert the named entities were found | ||
assert found_barack_obama, "Barack Obama as PERSON not found" | ||
assert found_united_states, "United States as GPE not found" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import pandas as pd | ||
import numpy as np | ||
import pytest | ||
|
||
def test_dataframe_creation(): | ||
# Test creating a DataFrame from a dictionary | ||
data = {'name': ['Alice', 'Bob', 'Charlie'], | ||
'age': [25, 30, 35], | ||
'city': ['New York', 'London', 'Paris']} | ||
df = pd.DataFrame(data) | ||
|
||
assert df.shape == (3, 3) | ||
|
||
assert list(df.columns) == ['name', 'age', 'city'] | ||
|
||
assert df['name'].dtype == object | ||
assert df['age'].dtype == int | ||
assert df['city'].dtype == object | ||
|
||
def test_dataframe_indexing(): | ||
# Create a sample DataFrame | ||
data = {'A': [1, 2, 3], | ||
'B': [4, 5, 6], | ||
'C': [7, 8, 9]} | ||
df = pd.DataFrame(data) | ||
|
||
assert df['A'].tolist() == [1, 2, 3] | ||
assert df['B'].tolist() == [4, 5, 6] | ||
assert df['C'].tolist() == [7, 8, 9] | ||
|
||
assert df.iloc[0].tolist() == [1, 4, 7] | ||
assert df.iloc[1].tolist() == [2, 5, 8] | ||
assert df.iloc[2].tolist() == [3, 6, 9] | ||
|
||
def test_dataframe_merge(): | ||
# Create two sample DataFrames | ||
df1 = pd.DataFrame({'key': ['A', 'B', 'C', 'D'], | ||
'value1': [1, 2, 3, 4]}) | ||
df2 = pd.DataFrame({'key': ['B', 'D', 'E', 'F'], | ||
'value2': [5, 6, 7, 8]}) | ||
|
||
merged_df = pd.merge(df1, df2, on='key') | ||
|
||
assert merged_df.shape == (2, 3) | ||
|
||
assert merged_df['key'].tolist() == ['B', 'D'] | ||
assert merged_df['value1'].tolist() == [2, 4] | ||
assert merged_df['value2'].tolist() == [5, 6] | ||
|
||
def test_dataframe_groupby(): | ||
# Create a sample DataFrame | ||
data = {'category': ['A', 'B', 'A', 'B', 'A'], | ||
'value': [1, 2, 3, 4, 5]} | ||
df = pd.DataFrame(data) | ||
|
||
grouped_df = df.groupby('category').sum() | ||
|
||
assert grouped_df.shape == (2, 1) | ||
|
||
assert grouped_df.loc['A', 'value'] == 9 | ||
assert grouped_df.loc['B', 'value'] == 6 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
53 changes: 53 additions & 0 deletions
53
images/scipy-ml-notebook/workflow_tests/test_statsmodels.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import numpy as np | ||
import statsmodels.api as sm | ||
import pytest | ||
|
||
def test_ols_simple_fit(): | ||
# Generate synthetic data (reproducible with seed(0)) | ||
np.random.seed(0) | ||
X = np.random.rand(100, 1) | ||
X = sm.add_constant(X) # Adds a constant term for the intercept | ||
beta = [0.5, 2.0] # True coefficients | ||
y = np.dot(X, beta) + np.random.normal(size=100) | ||
|
||
# Fit the model | ||
model = sm.OLS(y, X) | ||
results = model.fit() | ||
|
||
# Check if the estimated coefficients are close to the true coefficients | ||
assert np.allclose(results.params, beta, atol=0.5), "The estimated coefficients are not as expected." | ||
|
||
def test_logistic_regression_prediction(): | ||
# Generate synthetic data | ||
np.random.seed(1) | ||
X = np.random.randn(100, 2) | ||
X = sm.add_constant(X) | ||
beta = [0.1, 0.5, -0.3] | ||
y_prob = 1 / (1 + np.exp(-np.dot(X, beta))) # Sigmoid function for true probabilities | ||
y = (y_prob > 0.5).astype(int) # Binary outcome | ||
|
||
# Fit the logistic regression model | ||
model = sm.Logit(y, X) | ||
results = model.fit(disp=0) # disp=0 suppresses the optimization output | ||
|
||
# Predict using the model | ||
predictions = results.predict(X) > 0.5 | ||
|
||
# Check if the predictions match the actual binary outcomes | ||
accuracy = np.mean(predictions == y) | ||
assert accuracy > 0.75, "The prediction accuracy should be higher than 75%." | ||
|
||
def test_ols_summary_contains_r_squared(): | ||
# Simple linear regression with synthetic data | ||
np.random.seed(2) | ||
X = np.random.rand(50, 1) | ||
y = 2 * X.squeeze() + 1 + np.random.normal(scale=0.5, size=50) | ||
X = sm.add_constant(X) | ||
|
||
model = sm.OLS(y, X) | ||
results = model.fit() | ||
|
||
summary_str = str(results.summary()) | ||
|
||
# Check if 'R-squared' is in the summary | ||
assert 'R-squared' in summary_str, "'R-squared' not found in the model summary." |