diff --git a/README.md b/README.md index b66b70d..7feaec3 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # Dataset Alignment -This code corresponds to the paper "AlignNet: Learning dataset score alignment functions to enable better training of speech quality estimators" Jaden Pieper, Steve Voran, to appear in Proc. Interspeech 2024 and with [preprint available here](https://arxiv.org/abs/2406.10205). +This code corresponds to the paper "AlignNet: Learning dataset score alignment functions to enable better training of speech quality estimators," by Jaden Pieper, Stephen D. Voran, to appear in Proc. Interspeech 2024 and with [preprint available here](https://arxiv.org/abs/2406.10205). When training a no-reference (NR) speech quality estimator, multiple datasets provide more information and can thus lead to better training. But they often are inconsistent in the sense that they use different subjective testing scales, or the exact same scale is used differently by test subjects due to the corpus effect. AlignNet improves the training of NR speech quality estimators with multiple, independent datasets. AlignNet uses an AudioNet to generate intermediate score estimates before using the Aligner to map intermediate estimates to the appropriate score range. @@ -7,11 +7,11 @@ AlignNet is intentionally designed to be independent of the choice of AudioNet. This repository contains implementations of two different AudioNet choices: [MOSNet](https://arxiv.org/abs/1904.08352) and a simple example of a novel multi-scale convolution approach. -MOSNet demonstrates a network that takes the STFT of an audio signal as its input and the multi-scale convolution network is provided primarily as an example of a network that takes raw audio as an input. +MOSNet demonstrates a network that takes the STFT of an audio signal as its input, and the multi-scale convolution network is provided primarily as an example of a network that takes raw audio as an input. # Installation ## Dependencies -There are two included environment files. `environment.yml` has the dependencies required to train with alignnet, but does not impose version requirements. It is thus susceptible to issues in the future if packages deprecate methods or have major backwards compatibility breaks. On the otherhand `environment-paper.yml` contains the exact versions of the packages that were used for all the results reported in our paper. +There are two included environment files. `environment.yml` has the dependencies required to train with alignnet but does not impose version requirements. It is thus susceptible to issues in the future if packages deprecate methods or have major backwards compatibility breaks. On the other hand, `environment-paper.yml` contains the exact versions of the packages that were used for all the results reported in our paper. Create and activate the `alignnet` environment. ``` @@ -25,15 +25,15 @@ pip install . ``` # Preparing data for training -When training with multiple datasets some work must first be done to format them in a consistent manner so they can all be loaded in the same way. -For each dataset one must first make a csv that has subjective score in column called `MOS` and path to audio file in column called `audio_path`. +When training with multiple datasets, some work must first be done to format them in a consistent manner so they can all be loaded in the same way. 
+For each dataset, one must first make a csv that has the subjective score in a column called `MOS` and the path to the audio file in a column called `audio_path`. -If your `audio_net` model requires transformed data you can transform it prior to training with `pretransform_data.py` (see `python pretransform_data.py --help` for more information) and store paths to those transformed representation files in a column called `transform_path`. For example MOSNet uses the STFT of audio as an input. For more efficient training, pretransforming the audio into STFT representations, saving them, and including a column called `stft_path` in the csv is recommended. +If your `audio_net` model requires transformed data, you can transform it prior to training with `pretransform_data.py` (see `python pretransform_data.py --help` for more information) and store paths to those transformed representation files in a column called `transform_path`. For example, MOSNet uses the STFT of audio as an input. For more efficient training, pretransforming the audio into STFT representations, saving them, and including a column called `stft_path` in the csv is recommended. More generally, the column name must match the value of `data.pathcol`. -For examples see [MOSNet](alignnet/config/models/pretrain-MOSNet.yaml) or [MultiScaleConvolution](alignnet/config/models/pretrain-msc.yaml). +For examples, see [MOSNet](alignnet/config/models/pretrain-MOSNet.yaml) or [MultiScaleConvolution](alignnet/config/models/pretrain-msc.yaml). -For each dataset, slit the data into training, validation, and testing portions with +For each dataset, split the data into training, validation, and testing portions with ``` python split_labeled_data.py /path/to/data/file.csv --output-dir /datasetX/splits/path ``` @@ -50,16 +50,16 @@ Some basic training help can be found with python train.py --help ``` -To see an example config file and all the overrideable parameters for training MOSNet with AlignNet run +To see an example config file and all the overrideable parameters for training MOSNet with AlignNet, run ``` python train.py --config-dir alignnet/config/models --config-name=alignnet-MOSNet --cfg job ``` Here the `--cfg job` shows the configuration for this job without running the code. -If you are not training with a [clearML](https://clear.ml/) server be sure to set `logging=none`. +If you are not training with a [clearML](https://clear.ml/) server, be sure to set `logging=none`. To change the number of workers used for data loading, override the `data.num_workers` parameter, which defaults to 6. -As an example and to confirm you have appropriately overridden these parameters you could run +As an example, and to confirm you have appropriately overridden these parameters, you could run ``` python train.py logging=none data.num_workers=4 --config-dir alignnet/config/models --config-name=alignnet-MOSNet --cfg job ``` @@ -108,7 +108,7 @@ finetune.restore_file=/absolute/path/to/alignnet/trained_models/pretrained-MOSNe ## MultiScaleConvolution example Training NR speech estimators with AlignNet is intentionally designed to be agnostic to the choice of AudioNet. -To demonstrate this we include code for a rudimentary network that takes raw audio in as an input and trains separate convolutional networks on multiple time scales that are then aggregated into a single network component. 
+To demonstrate this, we include code for a rudimentary network that takes in raw audio as an input and trains separate convolutional networks on multiple time scales that are then aggregated into a single network component. This network is defined as `alignnet.MultiScaleConvolution` and can be trained via: ``` python path/to/alignnet/train.py \ @@ -123,18 +123,18 @@ Some basic help can be seen via python inference.py --help ``` -In general three overrides must be set: +In general, three overrides must be set: * `model.path` - path to a trained model * `data.data_files` - list containing absolute paths to csv files that list audio files to perform inference on. * `output.file` - path to file where inference output will be stored. -After running inference a csv will be created at `output.file` with the following columns: +After running inference, a csv will be created at `output.file` with the following columns: * `file` - filenames where audio was loaded from * `estimate` - estimate generated by the model -* `dataset` - index for which file from `data.data_files` this file belongs to. -* `AlignNet dataset index` - index for which dataset within the model the scores come from. This will be the same for every file in the csv. The default dataset will always be the reference dataset but this can be overriden via `model.dataset_index`. +* `dataset` - index indicating which file from `data.data_files` this file belongs to. +* `AlignNet dataset index` - index indicating which dataset within the model the scores come from. This will be the same for every file in the csv. The default dataset will always be the reference dataset, but this can be overridden via `model.dataset_index`. -For example, to run inference using the included AlignNet model trained on the smaller datasets one would run +For example, to run inference using the included AlignNet model trained on the smaller datasets, one would run ``` python inference.py \ data.data_files=[/absolute/path/to/inference/data1.csv,/absolute/path/to/inference/data2.csv] \ @@ -144,13 +144,13 @@ output.file=estimations.csv # Gathering datasets used in 2024 Conference Paper -Here are links and reference to help with locating the data we have used in the paper. +Here are links and references to help with locating the data we have used in the paper. * [Blizzard 2021](https://www.cstr.ed.ac.uk/projects/blizzard/data.html) * Z.-H. Ling, X. Zhou, and S. King, "The Blizzard challenge 2021," in Proc. Blizzard Challenge Workshop, 2021. * [Blizzard 2008](https://www.cstr.ed.ac.uk/projects/blizzard/data.html) * V. Karaiskos, S. King, R. A. J. Clark, and C. Mayo, "The Blizzard challenge 2008," in Proc. Blizzard Challenge Workshop, 2008. -* [FFTnet](https://gfx.cs.princeton.edu/pubs/Jin_2018_FAR/clips/) +* [FFTNet](https://gfx.cs.princeton.edu/pubs/Jin_2018_FAR/clips/) * Z. Jin, A. Finkelstein, G. J. Mysore, and J. Lu, "FFTNet: a real-time speaker-dependent neural vocoder," in Proc. IEEE International Conference on Acoustics, Speech and Signal Processing, 2018. * [NOIZEUS](https://ecs.utdallas.edu/loizou/speech/noizeus/) * Y. Hu and P. Loizou, "Subjective comparison of speech enhancement algorithms," in Proc. IEEE International Conference on Acoustics, Speech and Signal Processing, 2006. @@ -159,7 +159,7 @@ Here are links and reference to help with locating the data we have used in the * [Tencent](https://github.com/ConferencingSpeech/ConferencingSpeech2022) * G. Yi, W. Xiao, Y. Xiao, B. Naderi, S. Moller, W. Wardah, G. Mittag, R. Cutler, Z. Zhang, D. S. Williamson, F. 
Chen, F. Yang, and S. Shang, "ConferencingSpeech 2022 Challenge: Non-intrusive objective speech quality assessment challenge for online conferencing applications," in Proc. Interspeech, 2022, pp. 3308–3312. * [NISQA](https://github.com/gabrielmittag/NISQA/wiki/NISQA-Corpus) - * G. Mittag, B. Naderi, A. Chehadi, and S. M ̈oller, "NISQA: A deep CNN-self-attention model for multidimensional speech quality prediction with crowdsourced datasets,” in Proc. Interspeech, 2021, pp. 2127–2131. + * G. Mittag, B. Naderi, A. Chehadi, and S. Möller, "NISQA: A deep CNN-self-attention model for multidimensional speech quality prediction with crowdsourced datasets,” in Proc. Interspeech, 2021, pp. 2127–2131. * [Voice Conversion Challenge 2018](https://datashare.ed.ac.uk/handle/10283/3257) * J. Lorenzo-Trueba, J. Yamagishi, T. Toda, D. Saito, F. Villavicencio, T. Kinnunen, and Z. Ling, “The voice conversion challenge 2018: Promoting development of parallel and nonparallel methods,” in Proc. Speaker Odyssey, 2018. * [Indiana U. MOS](https://github.com/ConferencingSpeech/ConferencingSpeech2022) diff --git a/alignnet/config/hydra/help/train_help.yaml b/alignnet/config/hydra/help/train_help.yaml index 18ce6c3..72978b0 100644 --- a/alignnet/config/hydra/help/train_help.yaml +++ b/alignnet/config/hydra/help/train_help.yaml @@ -3,7 +3,7 @@ app_name: AlignNet header: == Training ${hydra.help.app_name} == footer: |- - Powered by Hydra (https://hyrda.cc) + Powered by Hydra (https://hydra.cc) Use --hydra-help to view Hydra specific help. template: |- diff --git a/alignnet/data.py b/alignnet/data.py index 2896833..7d080ba 100644 --- a/alignnet/data.py +++ b/alignnet/data.py @@ -170,7 +170,7 @@ def padding(self, batch): # Concatenate into one tensor audio_out = torch.stack(audio_out, dim=0) - # If a transform is defined and the transform time is at collate now is the time to apply it + # If a transform is defined and the transform time is at collate, now is the time to apply it if self.transform is not None and self.transform_time == "collate": audio_out = self.transform.transform(audio_out) audio_out = torch.unsqueeze(audio_out, dim=1) @@ -179,7 +179,7 @@ def padding(self, batch): class FeatureData(AudioData): """ - For loading pre-computed features for audio files. Only the __getitem__ method needs to change + For loading pre-computed features for audio files. Only the __getitem__ method is changed """ def __init__( @@ -229,7 +229,7 @@ def __getitem__(self, idx): audio = self.wavs[audio_path] else: fname, ext = os.path.splitext(audio_path) - # If using same split csvs as audio this may say wav and not pt + # If using same split csvs as audio, this may say wav and not pt # (coming out of pretransform_data.py will save as pt) if ext == ".wav": audio_path = fname + ".pkl" @@ -296,7 +296,7 @@ def __init__( """ super().__init__() - # If this class sees batch_size=auto it sets to default value and assumes a Tuner is being called in the main + # If this class sees batch_size=auto, it sets to default value and assumes a Tuner is being called in the main # logic to update this later if batch_size == "auto": batch_size = 32 @@ -313,11 +313,11 @@ def setup(self, stage: str): """ Load different datasubsets depending on stage. - If stage == 'fit' then train, valid, and test data are loaded. + If stage == 'fit', then train, valid, and test data are loaded. - If stage == 'test' then only test data is loaded. + If stage == 'test', then only test data is loaded. 
- If stage == 'predict' then self.data_dirs should be full paths to the specific + If stage == 'predict', then self.data_dirs should be full paths to the specific csv files to run predictions on. Parameters @@ -382,7 +382,7 @@ def find_datasubsets(self, data_paths, subset): def find_datasubset(self, data_path, subset): """ - Helper function for setup to find the different data subsets (test/train/valid) + Helper function for setup to find the different data subsets (train/valid/test) Parameters ---------- diff --git a/alignnet/model.py b/alignnet/model.py index 107905b..d2db142 100644 --- a/alignnet/model.py +++ b/alignnet/model.py @@ -98,7 +98,7 @@ def __init__( activation : nn.Module Activation to include between layers. There will always be n_layers - 1 activations in the sequence. layer_dims : list - List of layer dimensions, not including input features (these are specificed by in_features). + List of layer dimensions, not including input features (these are specified by in_features). """ super().__init__() if layer_dims is not None and n_layers != len(layer_dims): @@ -112,10 +112,10 @@ def __init__( def setup_layers(self, n_layers): """ - Setup and store layers into `output_layers` attribute. + Set up and store layers into `output_layers` attribute. If `self.layer_dims` is not None, linear layers are made that match the - dimension of that list. If it is none, layers are made such that the dimension + dimension of that list. If it is None, layers are made such that the dimension decreases by 1/2 for each layer. Parameters @@ -162,7 +162,7 @@ def forward(self, frame_scores): Returns ------- torch.Tensor - Frame based representation of audio (e.g. feature x frames tensor for each audio file). + Frame-based representation of audio (e.g., features x frames tensor for each audio file). """ for k, layer in enumerate(self.output_layers): frame_scores = layer(frame_scores) @@ -252,7 +252,7 @@ def __init__( # num_layers - number of recurrent layers num_layers=1, # bias - bool if bias weight used (defaults to True) - # batch_first - if True then input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature) + # batch_first - if True, then input and output tensors are provided as (batch, seq, feature) instead of (seq, batch, feature) batch_first=True, # dropout # bidirectional @@ -276,7 +276,7 @@ def forward(self, x): # x dim: (B, C, T, F) = (B, T, 1, 257) y = self.convolutions(x) # y dim: (B, C, T, F) = (B, 128, T, 4) - # Swap dimensions to preserve frame level time before flattening for BLSTM + # Swap dimensions to preserve frame-level time before flattening for BLSTM y = torch.movedim(y, -2, -3).flatten(start_dim=-2, end_dim=-1) # y dim: (B, T, F*C): (B, T, 512) y, _ = self.blstm(y) @@ -319,7 +319,7 @@ def __init__( dilation : int, optional Convolution dilation, by default 1 pooling_type : str, optional - Type of pooling to perform. Either "average", "blur", or "None", by default "average". + Type of pooling to perform: "average", "blur", or "None", by default "average". """ super().__init__() @@ -511,8 +511,8 @@ class MultiScaleConvolution(nn.Module): def __init__(self, path1, path2, path3, path4): """ Neural network that processes audio in up to four independent paths prior to - combining in a fully connected sequence. Each path is compressed to be the same - size regardless of audio length through simple statistical aggregations. + combining in a fully connected sequence. 
By means of simple statistical + aggregations, each path is compressed to the same size, regardless of audio length. Parameters ---------- @@ -547,7 +547,7 @@ def __init__(self, path1, path2, path3, path4): def forward(self, x): if len(x.shape) > 3 and x.shape[1] == 1: - # This may not be the best place/way to do this but should work on mono audio + # This may not be the best place/way to do this, but should work on mono audio x = torch.squeeze(x, dim=1) path_outs = [] for conv_path in self.conv_paths: @@ -576,9 +576,9 @@ def __init__(self, reference_index=0, num_datasets=0, **kwargs): Parameters ---------- reference_index : int, optional - Unused but exists to easily replace other Aligner setups, by default 0 + Unused, but exists to easily replace other Aligner setups, by default 0 num_datasets : int, optional - Unused but exists to easily replace other Aligner setups, by default 0 + Unused, but exists to easily replace other Aligner setups, by default 0 """ super().__init__() self.reference_index = reference_index @@ -613,7 +613,7 @@ def __init__( embedding_dim : int, optional Size of the dataset index embedding, by default 10 layer_dims : list, optional - Dimensions of the Aligner fully connected layers, by default [16, 16, 16, 16, 1] + Dimensions of the Aligner's fully connected layers, by default [16, 16, 16, 16, 1] """ super().__init__() self.reference_index = reference_index @@ -653,11 +653,11 @@ def __init__( audio_net : nn.Module Network component that maps audio to quality on the reference dataset scale. aligner : nn.Module - Network componenent that maps intermediate quality estimates and dataset + Network component that maps intermediate quality estimates and dataset indicators to the appropriate dataset score. aligner_corr_threshold : float, optional Correlation threshold that determines when the aligner is activated. - If None the aligner turns on immediately, by default None + If None, the aligner turns on immediately, by default None audio_net_freeze_epochs : int, optional Number of epochs to keep the audio_net frozen, by default 0 """ @@ -730,7 +730,7 @@ def __init__( loss_weights : list List of weights to compute weighted average of loss over datasets. If None, then loss is computed without respect to datasets. In the case where one dataset has significantly less data, a weighted average allows - more control to ensure it is properly learned. If loss_weights = 1, then the all datasets will get equal weight. + more control to ensure it is properly learned. If loss_weights = 1, then all the datasets will receive equal weight. """ super().__init__() # self.save_hyperparameters(ignore=["network", "loss"]) @@ -767,7 +767,7 @@ def loss_calc(self, mean_estimate, mos, dataset): torch.tensor Loss. """ - # If there are loss weights use them + # If there are loss weights, use them if self.loss_weights is not None: loss = 0 for dix in torch.unique(dataset): @@ -799,8 +799,8 @@ def _forward(self, training_batch): mos = mos.float() mean_estimate = self.network(audio, dataset) - # If audio is 2-D (e.g. 
wav2vec representation) needs to be squeezed in diminsion 1 here - # If audio is raw wav this won't do anything (dim 1 will be frames and != 1) + # If audio is 2-D (e.g., wav2vec representation), it needs to be squeezed in dimension 1 here + # If audio is raw wav, this won't do anything (dim 1 will be frames and != 1) mean_estimate = torch.squeeze(mean_estimate, dim=1) loss = self.loss_calc(mean_estimate, mos, dataset) @@ -834,8 +834,8 @@ def training_step(self, training_batch, idx): def validation_step(self, val_batch, idx): """ - Validtion step. Unlike the training and test steps, we need to store per - dataset information here. + Validation step. Unlike the training and test steps, we need to store + per-dataset information here. """ audio, mos, dataset = val_batch mos = mos.float() @@ -854,20 +854,20 @@ def validation_step(self, val_batch, idx): def on_validation_epoch_end(self) -> None: """ - At the end of validation epochs we calculate per dataset statistics. + At the end of validation epochs we calculate per-dataset statistics. """ # Concatenate stored epoch data into single tensor for each metric estimates = torch.cat(self.validation_step_info["outputs"], dim=0) targets = torch.cat(self.validation_step_info["targets"], dim=0) datasets = torch.cat(self.validation_step_info["datasets"], dim=0) - # Overall loss and correlatoin + # Overall loss and correlation loss = self.loss_calc(estimates, targets, datasets) corrcoef = self.pearsons_corr(estimates, targets) # Check if network has a use_aligner flag if hasattr(self.network, "use_aligner"): - # If aligner is off and we have passed the correlation threshold do the updates + # If aligner is off and we have passed the correlation threshold, do the updates if ( not self.network.use_aligner and corrcoef > self.network.aligner_corr_threshold @@ -900,7 +900,7 @@ def on_validation_epoch_end(self) -> None: for k, v in self.validation_step_info.items(): v.clear() - # If we aren't updating audio-net and our epoch has passed the wait time turn it on! + # If we aren't updating audio-net and our epoch has passed the wait time, turn it on! if ( not self.network.update_audio_net and self.epoch >= self.network.audio_net_freeze_epochs diff --git a/inference_configs/hydra/help/inference_help.yaml b/inference_configs/hydra/help/inference_help.yaml index 6fad38a..f15d4ed 100644 --- a/inference_configs/hydra/help/inference_help.yaml +++ b/inference_configs/hydra/help/inference_help.yaml @@ -3,7 +3,7 @@ app_name: AlignNet header: == Using ${hydra.help.app_name} at inference== footer: |- - Powered by Hydra (https://hyrda.cc) + Powered by Hydra (https://hydra.cc) Use --hydra-help to view Hydra specific help. template: |- @@ -11,10 +11,10 @@ template: |- This is the ${hydra.help.app_name} inference program! - To use a model at inference you must override three parameters: + To use a model at inference, you must override three parameters: * model.path : str pointing to the path containing a trained model (must have a `model.ckpt` and `config.yaml` file in path.) - * data.data_files : list containing paths to csv files with file paths to perform inference on. + * data.data_files : list containing paths to csv files with filepaths to perform inference on. The path name of the csv must correspond to `data.pathcol` which can be overriden. * output.file : str to filepath where outputs will be saved. @@ -26,7 +26,7 @@ template: |- == Config == This is the config generated for this run. 
- You can override everything, for example to switch to an audio input type and see all the options run: + You can override everything. For example, to switch to an audio input type and see all the options, run: ``` python inference.py input_type=audio --help @@ -42,7 +42,7 @@ template: |- $CONFIG ------- - To see the config of an example command directly without running it add + To see the config of an example command directly without running it, add `--cfg job` to your command. ${hydra.help.footer} \ No newline at end of file diff --git a/pretransform_data.py b/pretransform_data.py index 050bef9..091ca43 100644 --- a/pretransform_data.py +++ b/pretransform_data.py @@ -111,12 +111,12 @@ def transform_csv( target_fs : int Target sample rate. Audio will be resampled to this prior to transform if needed. pathcol : str, optional - Column in csv that contains audio file names, by default "filename" + Column in csv that contains audio filenames, by default "filename" """ # Load csv int dataframe df = pd.read_csv(csv_list) for ix, row in tqdm(df.iterrows(), total=len(df)): - # Get file name + # Get filename fname = row[pathcol] # Create file path fpath = os.path.join(datapath, fname) diff --git a/split_labeled_data.py b/split_labeled_data.py index 35ecda9..7edcbbe 100644 --- a/split_labeled_data.py +++ b/split_labeled_data.py @@ -35,16 +35,16 @@ def get_split_numbers(n_audio, split_fraction): def split_df_by_column(df, split_col, split_names, split_fraction): """ - Generate dictionary of indices for splitting up a DataFrame while maintaing + Generate dictionary of indices for splitting up a DataFrame while maintaining balance within splits for a single column. Split dataframe while maintaining balance of elements within a specific column. - For example if there are n conditions that are labelled within a certain column - this can ensure that the proper ratio of conditions are maintained within the - test, train, and validation datasets. E.g. 80% of the data is condition A, - 15% is condition B, and 5% is condition C, then those same ratios will be - preserved in each of the train, test, and validation datasets. + Note that if there are n conditions labelled within a certain column, this + ensures that the proper ratio of conditions is maintained within the train, validation, + and test datasets. For example, if 80% of the data is condition A, + 15% is condition B, and 5% is condition C, then those percentage ratios will + be preserved in each of the train, validation, and test datasets. Parameters @@ -89,7 +89,7 @@ def split_df(df, split_names, split_fraction): """ Generate dictionary of indices for splitting up a DataFrame. - Dictionary keys are defined by split_names and the number of items in each + Dictionary keys are defined by split_names and the number of items in each key is determined by split_fraction. 
Parameters @@ -146,7 +146,7 @@ def main(args, n=None): split_ix = split_df(score_df, args.split_names, args.split_fraction) else: split_ix = split_df_by_column( - # Split arcording to the split_column + # Split according to the split_column score_df, args.split_column, args.split_names, @@ -162,7 +162,7 @@ def main(args, n=None): if __name__ == "__main__": parser = ArgumentParser( - description="Split a label_file containing target and pathcol for audio file into train, test, and valid csvs.", + description="Split a label_file containing target and pathcol for audio file into train, valid, and test csvs.", formatter_class=ArgumentDefaultsHelpFormatter, ) parser.add_argument( @@ -198,8 +198,8 @@ def main(args, n=None): type=str, default=None, help=( - "Column for which data should be split according to split-fraction (e.g. force distributions of values in " "that column across each data set.)" + "Column for which data should be split according to split-fraction (e.g., force distributions of values in " "that column across each dataset.)" ), ) diff --git a/train.py b/train.py index 33c5516..5fe9245 100644 --- a/train.py +++ b/train.py @@ -175,7 +175,7 @@ def main(cfg: DictConfig) -> None: except MissingConfigError as E: print(f"{E}") print( - f"If you do not want to install clearML and avoid this error in the future set `logging=none` override." + f"If you do not want to install clearML and want to avoid this error in the future, set the `logging=none` override." ) task = None else: @@ -215,7 +215,7 @@ def main(cfg: DictConfig) -> None: callbacks = [checkpoint_callback] if "earlystop" in cfg: - # Earlystop needs monitor (e.g., val-loss) and mode (e.g., min). This can be added via CLI/cfg, otherwise steal the checkpoint values. + # Earlystop needs monitor (e.g., val-loss) and mode (e.g., min). This can be added via CLI/cfg. Otherwise steal the checkpoint values. stop_params = {"monitor": None, "mode": None} for k, _ in stop_params.items(): if k in cfg.earlystop: @@ -275,12 +275,11 @@ def main(cfg: DictConfig) -> None: cfg.model, network=network, loss=loss, optimizer=optimizer ) print(model) - # This is actually automatically stored in the .hydra folder... - # Save a version of the config # Add working directory to config with open_dict(cfg): cfg.project.working_dir = os.getcwd() + # Save a version of the config cfg_yaml = OmegaConf.to_yaml(cfg) cfg_out = "input_config.yaml" with open(cfg_out, "w") as file:
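The README text touched above describes AlignNet's split into an AudioNet (audio in, intermediate quality estimate out) and an Aligner (intermediate estimate plus dataset indicator in, score on that dataset's scale out), and the `model.py` docstrings mention a dataset embedding of size 10 feeding fully connected layers with dimensions [16, 16, 16, 16, 1]. The sketch below is only a toy illustration of that idea with assumed layer wiring; it is not the repository's `Aligner` class.
```
import torch
from torch import nn


class ToyAligner(nn.Module):
    """Toy mapping from (intermediate score, dataset index) to a dataset-specific score."""

    def __init__(self, num_datasets, embedding_dim=10, layer_dims=(16, 16, 16, 16, 1)):
        super().__init__()
        self.embed = nn.Embedding(num_datasets, embedding_dim)
        dims = [1 + embedding_dim, *layer_dims]
        layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers += [nn.Linear(d_in, d_out), nn.ReLU()]
        # Drop the activation after the final 1-dimensional output layer
        self.mlp = nn.Sequential(*layers[:-1])

    def forward(self, intermediate_score, dataset_idx):
        # Concatenate the scalar estimate with the learned dataset embedding
        x = torch.cat([intermediate_score.unsqueeze(-1), self.embed(dataset_idx)], dim=-1)
        return self.mlp(x).squeeze(-1)


aligner = ToyAligner(num_datasets=3)
scores = torch.rand(4)                  # intermediate estimates from some AudioNet
datasets = torch.tensor([0, 1, 2, 0])   # which dataset each clip came from
print(aligner(scores, datasets).shape)  # torch.Size([4])
```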
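The "Preparing data for training" section above requires one csv per dataset, with the subjective score in a `MOS` column and the audio path in an `audio_path` column (plus a feature-path column such as `stft_path` when audio has been pretransformed). Here is a minimal sketch of building such a csv with pandas, assuming a hypothetical `ratings.csv` with `wavfile` and `score` columns and made-up paths:
```
import os

import pandas as pd

# Hypothetical per-dataset score listing: one row per audio file
ratings = pd.read_csv("/data/datasetX/ratings.csv")  # assumed columns: wavfile, score

labels = pd.DataFrame(
    {
        # Subjective scores go in a column named MOS
        "MOS": ratings["score"],
        # Paths to the audio files go in a column named audio_path
        "audio_path": [
            os.path.join("/data/datasetX/wav", wav) for wav in ratings["wavfile"]
        ],
    }
)

# If features were precomputed with pretransform_data.py, add a column
# (e.g., stft_path) pointing at the saved feature files and set data.pathcol to match.

labels.to_csv("/data/datasetX/labels.csv", index=False)
# Then split it:
#   python split_labeled_data.py /data/datasetX/labels.csv --output-dir /data/datasetX/splits
```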
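The inference section above lists the columns written to `output.file` (`file`, `estimate`, `dataset`, `AlignNet dataset index`). A short sketch of separating the estimates back out per input csv, assuming the `estimations.csv` name used in the example command:
```
import pandas as pd

estimates = pd.read_csv("estimations.csv")

# `dataset` indexes the entry of data.data_files each row came from
for data_file_index, group in estimates.groupby("dataset"):
    print(
        f"input csv {data_file_index}: {len(group)} files, "
        f"mean estimate {group['estimate'].mean():.2f}"
    )
```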
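The `loss_weights` docstring in `alignnet/model.py` above describes averaging the loss per dataset, with optional weights, so that a small dataset is not swamped by larger ones. The function below is a rough, self-contained sketch of that idea using MSE; it is an illustration under assumed conventions, not the repository's `loss_calc`.
```
import torch
import torch.nn.functional as F


def weighted_dataset_loss(estimates, targets, dataset_idx, loss_weights=None):
    # With no weights, treat all samples as one pool regardless of dataset
    if loss_weights is None:
        return F.mse_loss(estimates, targets)
    datasets = torch.unique(dataset_idx)
    loss = 0.0
    for dix in datasets:
        mask = dataset_idx == dix
        # A scalar loss_weights (e.g., 1) gives every dataset equal weight
        w = loss_weights if isinstance(loss_weights, (int, float)) else loss_weights[int(dix)]
        loss = loss + w * F.mse_loss(estimates[mask], targets[mask])
    return loss / len(datasets)


# Example: dataset 1 contributes a single clip but still gets its own averaged term
est = torch.tensor([3.1, 4.2, 2.0, 4.8])
mos = torch.tensor([3.0, 4.0, 2.5, 5.0])
ds = torch.tensor([0, 0, 0, 1])
print(weighted_dataset_loss(est, mos, ds, loss_weights=[1.0, 2.0]))
```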