-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from LSSTDESC/awo/initial-branch
Initial commit
- Loading branch information
Showing
18 changed files
with
331 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,4 +22,3 @@ python_versions: | |
- '3.10' | ||
- '3.11' | ||
- '3.12' | ||
- '3.13' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,9 @@ | ||
from .example_module import greetings, meaning | ||
from .example_classifier import ExampleClassifier | ||
from .example_feature_extractor import ExampleFeatureExtractor | ||
from .example_query_strategy import ExampleQueryStrategy | ||
|
||
__all__ = ["greetings", "meaning"] | ||
__all__ = [ | ||
"ExampleClassifier", | ||
"ExampleFeatureExtractor", | ||
"ExampleQueryStrategy", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from resspect.classifiers import ResspectClassifier | ||
|
||
|
||
class ExampleClassifier(ResspectClassifier): | ||
"""Example of an externally defined classifier for RESSPECT. The API for the | ||
subclass of ResspectClassifier itself is very simple. However, the classifier | ||
that is assigned to `self.classifier` has a more substantial expected API, | ||
based on the scikit-learn API for classifiers. | ||
The MyClassifier class shows the methods that are expected to be implemented | ||
by the classifier.""" | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
self.classifier = MyClassifier(**self.kwargs) | ||
|
||
|
||
class MyClassifier: | ||
"""Example of a classifier. Note that the expected API mirrors a portion of | ||
the scikit-learn classifier API. | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
"""It is better to define __init__ with the explicitly required input | ||
parameters instead of `**kwargs`.""" | ||
pass | ||
|
||
def fit(self, train_features: list, train_labels: list) -> None: | ||
"""Fit the classifier to the training data. Not that there is no return | ||
value, it is only expected to fit the classifier to the data. | ||
Parameters | ||
---------- | ||
train_features : array-like | ||
The features used for training, [n_samples, m_features]. | ||
train_labels : array-like | ||
The training labels, [n_samples]. | ||
""" | ||
pass | ||
|
||
def predict(self, test_features: list) -> list: | ||
"""Predict the class labels for the test data. | ||
Parameters | ||
---------- | ||
test_features : array-like | ||
The features used for testing, [n_samples, m_features]. | ||
Returns | ||
------- | ||
predictions : array-like | ||
The predicted class labels, [n_samples]. | ||
""" | ||
pass | ||
|
||
def predict_proba(self, test_features: list) -> list: | ||
"""Predict the class probabilities for the test data. | ||
Parameters | ||
---------- | ||
test_features : array-like | ||
The features used for testing, [n_samples, m_features]. | ||
Returns | ||
------- | ||
probabilities : array-like | ||
The predicted class probabilities, [n_samples, n_classes]. | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import numpy as np | ||
from resspect.feature_extractors.light_curve import LightCurve | ||
|
||
|
||
class ExampleFeatureExtractor(LightCurve): | ||
"""A minimal example of an external feature extractor class.""" | ||
|
||
# The list of feature names that will be extracted from a light curve. | ||
feature_names = [ | ||
"feature_0", | ||
"feature_1", | ||
"feature_n", | ||
# ... whatever additional features are extracted | ||
] | ||
|
||
# The name of the id column of the data. e.g. 'ID', 'obj_id', etc. | ||
id_column = "id" | ||
|
||
# The name of the label column. e.g. 'type', 'label', 'sntype', etc. | ||
# This is the column where the class label is stored. | ||
label_column = "type" | ||
|
||
# The names of classes that are NOT anomalies. e.g. ['Ia', 'Normal', etc.] | ||
non_anomaly_classes = ["Ia"] | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.num_features = len(ExampleFeatureExtractor.feature_names) | ||
self.features = None | ||
|
||
@classmethod | ||
def get_features(cls, filters: list) -> list[str]: | ||
""" | ||
A class method that returns the list of features that will be extracted | ||
from the light curve. | ||
Often the feature list is a cross product of the feature name and the filters. | ||
In this example we return only the feature name list. | ||
Returns | ||
------- | ||
feature_names: list[str] | ||
List of feature names that will be extracted from the light curve. | ||
""" | ||
return cls.feature_names | ||
|
||
@classmethod | ||
def get_metadata_columns(cls, **kwargs) -> list[str]: | ||
""" | ||
A class method that returns the metadata columns for the feature extractor. | ||
Depending on how dynamic the metadata columns are, this method can be | ||
a hard-coded list or a dynamically generated list. | ||
Returns | ||
------- | ||
metadata_columns: list[str] | ||
List of metadata columns for the feature extractor. | ||
""" | ||
|
||
# hard-coded example | ||
metadata_columns = [cls.id_column, "redshift", cls.label_column, "sncode", "sample"] | ||
|
||
# dynamic example | ||
kwargs["override_primary_columns"] = [cls.id_column, "redshift", cls.label_column, "sncode", "sample"] | ||
metadata_columns = super().get_metadata_header(**kwargs) | ||
|
||
return metadata_columns | ||
|
||
@classmethod | ||
def get_feature_header(cls, filters: list[str], **kwargs) -> list[str]: | ||
""" | ||
A class method that returns the full list column names for an output file. | ||
This includes the metadata columns and the feature columns. | ||
Returns | ||
------- | ||
header: list[str] | ||
List of column names for the output file. | ||
""" | ||
|
||
# One way to return this is to concatenate the metadata columns and the feature columns | ||
return ExampleFeatureExtractor.get_metadata_header(**kwargs) + ExampleFeatureExtractor.get_features( | ||
filters | ||
) | ||
|
||
def fit_all(self) -> np.ndarray: | ||
""" | ||
Extracts features for all light curves in the dataset. | ||
Returns | ||
------- | ||
features: np.ndarray | ||
Features extracted from the light curves. | ||
""" | ||
# Implement feature extraction here | ||
self.features = self.example_extraction_function() | ||
return self.features | ||
|
||
def _example_extraction_function(): | ||
# Just for demo purposes | ||
pass | ||
|
||
def get_features_to_write(self): | ||
""" | ||
Implement this method to return the features that will be persisted to disk. | ||
The base `LightCurve` class has a simple implementation, but you can | ||
override it here. | ||
The base `LightCurve` class implementation will return `features_list`: | ||
features_list = [ | ||
self.id, | ||
self.redshift, | ||
self.sntype, | ||
self.sncode, | ||
self.sample] | ||
features_list.extend(self.features) | ||
""" | ||
|
||
return super().get_features_to_write() |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import numpy as np | ||
from resspect.query_strategies import QueryStrategy | ||
|
||
|
||
class ExampleQueryStrategy(QueryStrategy): | ||
"""Minimal example of an external query strategy class.""" | ||
|
||
def __init__( | ||
self, | ||
queryable_ids: np.array, | ||
test_ids: np.array, | ||
batch: int = 1, | ||
query_threshold: float = 1.0, | ||
screen: bool = False, | ||
**kwargs, | ||
): | ||
"""The parameters shown are the default set that will be passed to all | ||
query strategies. If there are additional parameters required, pass them | ||
in via `**kwargs`. | ||
Parameters | ||
---------- | ||
queryable_ids : np.array | ||
Set of ids for objects available for querying. | ||
test_ids : np.array | ||
Set of ids for objects in the test sample. | ||
batch : int, optional | ||
Number of objects to be chosen in each batch query, by default 1 | ||
query_threshold : float, optional | ||
Threshold where a query is considered worth it, by default 1.0 (no limit) | ||
screen : bool, optional | ||
If True display on screen the shift in index and | ||
the difference in estimated probabilities of being Ia | ||
caused by constraints on the sample available for querying, by default False | ||
**kwargs: dict | ||
Any additional parameters required by the query strategy. | ||
""" | ||
|
||
# The call to `super().__init__` will set the instance variables as follows: | ||
# self.queryable_ids = queryable_ids | ||
# self.test_ids = test_ids | ||
# self.batch = batch | ||
# self.query_threshold = query_threshold | ||
# self.screen = screen | ||
super().__init__(queryable_ids, test_ids, batch, query_threshold, screen, **kwargs) | ||
|
||
# If there are additional parameters, they can be set here. e.g.: | ||
# self.additional_parameters = kwargs["additional_parameters"] | ||
|
||
def sample(self, probability: np.array) -> list: | ||
"""Search for the sample with highest anomaly certainty in predicted class. | ||
Parameters | ||
---------- | ||
probability : np.array | ||
Classification probability. One value per class per object. | ||
Returns | ||
------- | ||
list | ||
List of indexes identifying the objects from the test sample | ||
to be queried in decreasing order of importance. | ||
If there are less queryable objects than the required batch | ||
it will return only the available objects -- so the list of | ||
objects to query can be smaller than 'batch'. | ||
""" | ||
|
||
# Note - the following are guidelines, but are not enforced in code. | ||
# 1) After all calculations are complete, the list of returned_indexes is | ||
# expected to have length <= self.batch. | ||
# 2) The list of returned_indexes should only include values that are in | ||
# self.queryable_ids. | ||
returned_indexes = [] | ||
|
||
return returned_indexes |
Oops, something went wrong.