Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Croptype #71

Open
wants to merge 197 commits into
base: main
Choose a base branch
from
Open

Croptype #71

wants to merge 197 commits into from

Conversation

gabrieltseng
Copy link
Collaborator

@gabrieltseng gabrieltseng commented Jun 4, 2024

presto/eval.py Outdated
learning_rate=0.05,
early_stopping_rounds=20,
l2_leaf_reg=3,
learning_rate=0.2,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are we selecting these parameters?

paper_eval.py Outdated

# argparser.add_argument("--val_samples_file", type=str, default="cropland_spatial_generalization_test_split_samples.csv")

argparser.add_argument("--presto_model_type", type=str, default="presto-ft-ct")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this argument (presto_model_type) only used to name the experiment file? if so can it be given a different name; the current name implies it will somehow affect the model

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for other names which don't affect the functionality, e.g. compositing_window.

Also I think the most up to date main hasn't been merged into this, since the 10d compositing is also supported but doesn't seem to be here.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point. addressed that here 1601dbc

paper_eval.py Outdated
model = Presto.construct(**model_kwargs)
best_model_path = None
model.to(device)
val_samples_file = f"{task_type}_{test_type}_generalization_test_split_samples.csv"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach requires the script to be run 3 times to get the full results (instead of only once to get all 3 results, which was the case previously).

Perhaps we can update this so that a single run gets all the data?

paper_eval.py Outdated
# check if finetuned model already exists
logger.info("Checking if the finetuned model exists")
if os.path.isfile(finetuned_model_path):
logger.info("Finetuned model found! Loading...")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to do this? I can imagine this introducing lots of unexpected errors since if a finetuned model exists from a previous run, it would automatically affect this run whether or not we want it to

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also would this work if there is no finetuned model? I think the script would error out since finetuned_model would never be initialized?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also also, how is the model finetuned in this case? I think thats pretty important so it would be good to capture that in this script

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still feel the need for being able to upload the finetuned model, particularly if I want to run downstream classifiers only and collect metrics. in particular, this piece only checks if the model exists, collects metrics for the uploaded model along with spatial plots and runs sklearn models.
I didn't succeed in implementing the upload in presto.py, so now this piece wouldn't run anyway. for now, I
just put a placeholder here 532ec67

presto/eval.py Outdated
batch_size: int = 64
patience: int = 10
num_workers: int = 4
batch_size: int = 2048
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How were these chosen? This batch size is probably too large for finetuning

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed, reverted batch_size to 256 here a350280
I was playing with leaning rate, but also ended up using your value.

(torch.zeros(x.shape[0])[:, None].to(device).int(), orig_indices + 1),
dim=1,
)
x, upd_mask, orig_indices = self.add_token(latlon_tokens, x, upd_mask, orig_indices)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cbutsko a hackey way to get the model to ignore latlons is to change the mask here. A mask value of 0s tells the model to include the value, and a mask value of 1s tells it to ignore the value.

New tokens get added to the front of the sequence. Concretely:

x has shape [batch_size, num_tokens, dim]
and mask has shape [batch_size, num_tokens].

The latlons just got added to the front of the sequence, so you can do

upd_mask[:, 0] = 1

right after line 489 (where self.add_token(latlon_tokens, x, upd_mask, orig_indices) is called) to update the mask so that the latlon token will be ignored entirely.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added this line as you suggested instead of filling latlons with zeros in the dataset
0cf1e60

Butsko Christina and others added 30 commits October 9, 2024 15:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants