Skip to content

Commit

Permalink
Implement a "certainty sampler" to return most anomalous objects (#89)
Browse files Browse the repository at this point in the history
* Initial commit to implement a "certainty sampler".

* Fix spelling error.

* Updating comments and logging messages.

* Add test for certainty sampling.
  • Loading branch information
drewoldag authored Dec 9, 2024
1 parent f33b8ed commit cd86501
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 1 deletion.
121 changes: 120 additions & 1 deletion src/resspect/query_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['uncertainty_sampling',
__all__ = ['certainty_sampling',
'uncertainty_sampling',
'random_sampling',
'uncertainty_sampling_entropy',
'uncertainty_sampling_least_confident',
Expand Down Expand Up @@ -90,6 +91,42 @@ def sample(self, probability: np.array) -> list:
raise NotImplementedError


class CertaintySampling(QueryStrategy):
"""RESSPECT-specific implementation of certainty sampling."""
def __init__(self,
queryable_ids: np.array,
test_ids: np.array,
batch: int,
query_threshold: float,
screen: bool,
**kwargs):
super().__init__(queryable_ids, test_ids, batch, query_threshold, screen, **kwargs)

def sample(self, probability: np.array) -> list:
"""Search for the sample with highest anomaly certainty in predicted class.
Parameters
----------
probability : np.array
Classification probability. One value per class per object.
Returns
-------
list
List of indexes identifying the objects from the test sample
to be queried in decreasing order of importance.
If there are less queryable objects than the required batch
it will return only the available objects -- so the list of
objects to query can be smaller than 'batch'.
"""
return certainty_sampling(probability,
test_ids=self.test_ids,
queryable_ids=self.queryable_ids,
batch=self.batch,
screen=self.screen,
query_thre=self.query_threshold)


class UncSampling(QueryStrategy):
"""RESSPECT-specific implementation of uncertainty sampling."""
def __init__(self,
Expand Down Expand Up @@ -368,6 +405,7 @@ def sample(self, probability: np.array) -> list:
screen=self.screen,
query_thre=self.query_threshold)


def compute_entropy(ps: np.array):
"""
Calcualte the entropy for discrete distributoons assuming the events are
Expand Down Expand Up @@ -411,6 +449,85 @@ def compute_qbd_mi_entropy(ensemble_probs: np.array):
return entropy_avg_dist, mutual_information


def certainty_sampling(
class_prob: np.array,
test_ids: np.array,
queryable_ids: np.array,
batch=1,
screen=False,
query_thre=1.0
) -> list:
"""Search for the sample with highest anomaly certainty in predicted class.
Parameters
----------
class_prob: np.array
Classification probability. One value per class per object.
test_ids: np.array
Set of ids for objects in the test sample.
queryable_ids: np.array
Set of ids for objects available for querying.
batch: int (optional)
Number of objects to be chosen in each batch query.
Default is 1.
screen: bool (optional)
If True display on screen the shift in index and
the difference in estimated probabilities of being Ia
caused by constraints on the sample available for querying.
query_thre: float (optional)
Maximum percentile where a spectra is considered worth it.
If not queryable object is available before this threshold,
return empty query. Default is 1.0.
Returns
-------
query_indx: list
List of indexes identifying the objects from the test sample
to be queried in decreasing order of importance.
If there are less queryable objects than the required batch
it will return only the available objects -- so the list of
objects to query can be smaller than 'batch'.
"""

if class_prob.shape[0] != test_ids.shape[0]:
raise ValueError('Number of probabiblities is different ' +
'from number of objects in the test sample!')

# calculate distance to the decision boundary - only binary classification
anomaly_value = 1.0 #! Change this to be the value of anomaly (1 or 0)
dist = abs(class_prob[:, 1] - anomaly_value)

# get indexes in increasing order
order = dist.argsort()

# only allow objects in the query sample to be chosen
flag = list(pd.Series(data=test_ids[order]).isin(queryable_ids))

# check if there are queryable objects within threshold
indx = int(len(flag) * query_thre)

if sum(flag[:indx]) > 0:

# arrange queryable elements in increasing order
flag = np.array(flag)
final_order = order[flag]

if screen:
print('\n Inside Certainty Sampling: ')
print(' query_ids: ', test_ids[final_order][:batch], '\n')
print(' number of test_ids: ', test_ids.shape[0])
print(' number of queryable_ids: ', len(queryable_ids), '\n')
print(' *** Displacement caused by constraints on query****')
print(' 0 -> ', list(order).index(final_order[0]))
print(' ', class_prob[order[0]], '-- > ', class_prob[final_order[0]], '\n')

# return the index of the highest certainty objects which are queryable
return list(final_order)[:batch]

else:
return list([])


def uncertainty_sampling(class_prob: np.array, test_ids: np.array,
queryable_ids: np.array, batch=1,
screen=False, query_thre=1.0) -> list:
Expand Down Expand Up @@ -621,6 +738,7 @@ def uncertainty_sampling_entropy(class_prob: np.array, test_ids: np.array,
else:
return list([])


def uncertainty_sampling_least_confident(class_prob: np.array, test_ids: np.array,
queryable_ids: np.array, batch=1,
screen=False, query_thre=1.0) -> list:
Expand Down Expand Up @@ -687,6 +805,7 @@ def uncertainty_sampling_least_confident(class_prob: np.array, test_ids: np.arra
else:
return list([])


def uncertainty_sampling_margin(class_prob: np.array, test_ids: np.array,
queryable_ids: np.array, batch=1,
screen=False, query_thre=1.0) -> list:
Expand Down
25 changes: 25 additions & 0 deletions tests/resspect/test_query_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
qbd_entropy,
qbd_mi,
random_sampling,
certainty_sampling,
uncertainty_sampling,
uncertainty_sampling_entropy,
uncertainty_sampling_least_confident,
Expand All @@ -30,6 +31,30 @@ def test_random_sampling(batch_size, queryable):
assert np.all(np.array(sample) % 3 == 0)


def test_certainty_sampling():
"""Test the certainty sampling functionality."""
test_ids = np.arange(0, 10)
queryable_ids = np.array([0, 1, 2, 3, 4, 7, 8, 9]) # No 5 or 6
#given that most anomalous = 1.0
class1_probs = np.array([
0.01, # 0 - normal
0.50, # 1 - low certainty
0.10, # 2 - pretty normal
0.20, # 3 - pretty normal
0.65, # 4 - low certainty
0.79, # 5 - low certainty (not queryable)
0.25, # 6 - normal
0.80, # 7 - anomalous
0.40, # 8 - low certainty
0.02, # 9 - very normal
])
class_probs = np.array([class1_probs, class1_probs]).T

# Test that we generate the correct number of samples.
sample = certainty_sampling(class_probs, test_ids, queryable_ids, batch=3)
assert len(sample) == 3
assert np.array_equal(sample, [7, 4, 1])

def test_uncertainty_sampling():
"""Test the uncertainity sampling functionality."""
test_ids = np.arange(0, 10)
Expand Down

0 comments on commit cd86501

Please sign in to comment.