Croptype #71
base: main
Conversation
presto/eval.py
Outdated
learning_rate=0.05,
early_stopping_rounds=20,
l2_leaf_reg=3,
learning_rate=0.2,
How are we selecting these parameters?
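One option would be a small grid search; here is a minimal sketch, assuming the estimator is CatBoost (the l2_leaf_reg and early_stopping_rounds parameters suggest it), with illustrative grid values and synthetic data:

```python
# Minimal sketch of a grid search over these values, assuming CatBoost.
# The grid values and the synthetic data are illustrative only.
from catboost import CatBoostClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV

X, y = make_classification(n_samples=200, n_features=10, random_state=0)
search = GridSearchCV(
    CatBoostClassifier(iterations=100, verbose=0),
    param_grid={"learning_rate": [0.05, 0.1, 0.2], "l2_leaf_reg": [1, 3, 5]},
    cv=3,
)
search.fit(X, y)
print(search.best_params_)
```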
paper_eval.py
Outdated
# argparser.add_argument("--val_samples_file", type=str, default="cropland_spatial_generalization_test_split_samples.csv")
argparser.add_argument("--presto_model_type", type=str, default="presto-ft-ct")
Is this argument (presto_model_type) only used to name the experiment file? If so, can it be given a different name? The current name implies it will somehow affect the model.
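For illustration, a rename along these lines (--experiment_name is just a suggested name, not from the repo):

```python
# Hypothetical rename: a name that signals the argument only labels the
# experiment outputs and does not change model behaviour.
import argparse

argparser = argparse.ArgumentParser()
argparser.add_argument("--experiment_name", type=str, default="presto-ft-ct")
args = argparser.parse_args([])  # empty list so the sketch runs standalone
print(args.experiment_name)
```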
Same for other names which don't affect the functionality, e.g. compositing_window. Also, I think the most up-to-date main hasn't been merged into this branch, since 10d compositing is also supported but doesn't seem to be here.
Good point, addressed that here: 1601dbc
paper_eval.py
Outdated
model = Presto.construct(**model_kwargs)
best_model_path = None
model.to(device)
val_samples_file = f"{task_type}_{test_type}_generalization_test_split_samples.csv"
This approach requires the script to be run three times to get the full results (instead of a single run producing all three results, as was previously the case).
Perhaps we can update this so that a single run gets all the data, along the lines of the sketch below?
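For example (a rough sketch; the test_type values and the run_evaluation helper are hypothetical placeholders, not from this repo):

```python
# Rough sketch: iterate over the generalization test types in one invocation.
# The test_type values and run_evaluation() are hypothetical placeholders.
task_type = "croptype"

def run_evaluation(val_samples_file: str) -> None:
    print(f"Would evaluate with {val_samples_file}")

for test_type in ["spatial", "temporal", "random"]:
    val_samples_file = f"{task_type}_{test_type}_generalization_test_split_samples.csv"
    run_evaluation(val_samples_file)
```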
paper_eval.py
Outdated
# check if finetuned model already exists
logger.info("Checking if the finetuned model exists")
if os.path.isfile(finetuned_model_path):
    logger.info("Finetuned model found! Loading...")
Do we want to do this? I can imagine this introducing lots of unexpected errors: if a finetuned model exists from a previous run, it would automatically affect this run, whether or not we want it to.
Also, would this work if there is no finetuned model? I think the script would error out, since finetuned_model would never be initialized.
Also also, how is the model finetuned in this case? I think that's pretty important, so it would be good to capture that in this script.
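To make both points concrete, a rough sketch of the guarded branch (finetuned_model_path and the commented-out load/finetune calls are hypothetical placeholders):

```python
# Hypothetical sketch: bind finetuned_model up front and add an else branch,
# so the script neither crashes when no checkpoint exists nor silently skips
# finetuning. The path and the commented-out calls are placeholders.
import logging
import os

logger = logging.getLogger(__name__)
finetuned_model_path = "models/presto-ft-ct.pt"  # assumed path

finetuned_model = None  # ensure the name is always bound
if os.path.isfile(finetuned_model_path):
    logger.info("Finetuned model found! Loading...")
    # finetuned_model = torch.load(finetuned_model_path)
else:
    logger.info("No finetuned model found, finetuning from scratch...")
    # finetuned_model = finetune(pretrained_model, train_dataloader)
```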
I still feel the need for being able to upload the finetuned model, particularly if I want to run only the downstream classifiers and collect metrics. In particular, this piece only checks whether the model exists, collects metrics for the uploaded model along with spatial plots, and runs the sklearn models.
I didn't succeed in implementing the upload in presto.py, so this piece wouldn't run anyway; for now I just put a placeholder here: 532ec67
presto/eval.py
Outdated
batch_size: int = 64
patience: int = 10
num_workers: int = 4
batch_size: int = 2048
How were these chosen? This batch size is probably too large for finetuning.
Indeed, reverted batch_size to 256 here: a350280. I was playing with the learning rate, but also ended up using your value.
(torch.zeros(x.shape[0])[:, None].to(device).int(), orig_indices + 1),
dim=1,
)
x, upd_mask, orig_indices = self.add_token(latlon_tokens, x, upd_mask, orig_indices)
@cbutsko a hacky way to get the model to ignore latlons is to change the mask here. A mask value of 0 tells the model to include the value, and a mask value of 1 tells it to ignore the value.
New tokens get added to the front of the sequence. Concretely: x has shape [batch_size, num_tokens, dim] and mask has shape [batch_size, num_tokens].
The latlons were just added to the front of the sequence, so you can do upd_mask[:, 0] = 1 right after line 489 (where self.add_token(latlon_tokens, x, upd_mask, orig_indices) is called) to update the mask so that the latlon token will be ignored entirely.
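As a standalone illustration of the shapes involved (the sizes below are made up):

```python
# Sketch of the suggested mask change. Convention from the comment above:
# 0 = attend to the token, 1 = ignore the token.
import torch

batch_size, num_tokens, dim = 4, 13, 128
x = torch.randn(batch_size, num_tokens, dim)    # latlon token already at index 0
upd_mask = torch.zeros(batch_size, num_tokens)  # all tokens attended to by default

upd_mask[:, 0] = 1  # mask out the prepended latlon token so it is ignored entirely
```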
Added this line as you suggested, instead of filling latlons with zeros in the dataset: 0cf1e60
… now naive balancing does not repeat every smallest class and is much faster; thanks for the idea @kvantricht!
Valid month and mask debugging
Device ambiguity issue #217
cc @cbutsko