-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Inference Updates #405
Inference Updates #405
Changes from 8 commits
947dd46
9b9253a
ba34c0d
100db4b
4865af6
89834d7
46407d5
b8d8328
528df83
31217d5
bec4072
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from pathlib import Path | ||
|
||
import pandas as pd | ||
from lightning.pytorch.callbacks import Callback | ||
|
||
|
||
class CSVSaver(Callback):
    """Lightning callback that collects predict-loop outputs and writes them to a CSV.

    Each prediction is expected to be a ``(pred, meta)`` pair where ``pred`` is
    tabular (anything ``pd.DataFrame`` accepts) and ``meta`` is a mapping with a
    ``"filename_or_obj"`` entry (MONAI-style metadata — presumably one filename
    per row; TODO confirm against the predict step).

    Parameters
    ----------
    save_dir : str or Path
        Directory where ``predictions.csv`` is written.
    meta_keys : list, optional
        Extra metadata keys to record (kept for interface compatibility).
    """

    def __init__(self, save_dir, meta_keys=None):
        super().__init__()
        self.save_dir = Path(save_dir)
        # Avoid a shared mutable default argument; copy so callers can't
        # mutate our state through their list.
        self.meta_keys = list(meta_keys) if meta_keys is not None else []

    def on_predict_epoch_end(self, trainer, pl_module):
        # Access the list of predictions from all predict_steps
        predictions = trainer.predict_loop.predictions
        feats = []
        for pred, meta in predictions:
            batch_feats = pd.DataFrame(pred)
            batch_feats["filename"] = meta["filename_or_obj"]
            feats.append(batch_feats)
        # pd.concat raises on an empty list — nothing to save in that case.
        if not feats:
            return
        # Make sure the target directory exists before writing.
        self.save_dir.mkdir(parents=True, exist_ok=True)
        pd.concat(feats).to_csv(self.save_dir / "predictions.csv", index=False)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -220,6 +220,7 @@ def __init__( | |
use_crossmae: Optional[bool] = False, | ||
context_pixels: Optional[List[int]] = [0, 0, 0], | ||
input_channels: Optional[int] = 1, | ||
features_only: Optional[bool] = False, | ||
) -> None: | ||
""" | ||
Parameters | ||
|
@@ -247,6 +248,9 @@ def __init__( | |
context_pixels: List[int] | ||
Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. | ||
input_channels: int | ||
Number of input channels | ||
features_only: bool | ||
Only use encoder to extract features | ||
""" | ||
super().__init__() | ||
assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3" | ||
|
@@ -262,6 +266,7 @@ def __init__( | |
), "base_patch_size must be of length spatial_dims" | ||
|
||
self.mask_ratio = mask_ratio | ||
self.features_only = features_only | ||
|
||
self.encoder = MAE_Encoder( | ||
num_patches, | ||
|
@@ -290,5 +295,7 @@ def __init__( | |
|
||
def forward(self, img):
    """Encode ``img`` with the (masked) MAE encoder; optionally decode.

    Returns the encoder features alone when ``self.features_only`` is set,
    otherwise the decoder's reconstruction together with the mask.
    """
    encoded = self.encoder(img, self.mask_ratio)
    features, mask, fwd_indexes, bwd_indexes = encoded
    if self.features_only:
        return features
    reconstruction = self.decoder(features, fwd_indexes, bwd_indexes)
    return reconstruction, mask
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
create a data loader?