-
Notifications
You must be signed in to change notification settings - Fork 384
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
1,145 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#!/usr/bin/env python3 | ||
|
||
"""Extracts the configuration file from a slim inference checkpoint.""" | ||
|
||
import argparse | ||
import json | ||
from pathlib import Path | ||
import sys | ||
|
||
import k_diffusion as K | ||
import safetensors.torch as safetorch | ||
|
||
|
||
def main():
    """Extract the configuration stored in a slim inference checkpoint.

    Parses the command line, reads the checkpoint's safetensors metadata
    and writes the string stored under its ``"config"`` key to a file.
    The destination is ``--output`` when given, otherwise the checkpoint
    path with its suffix replaced by ``.json``. Progress messages go to
    stderr.

    Raises:
        ValueError: if the checkpoint metadata has no ``"config"`` entry
            (including a checkpoint that carries no metadata at all).
    """
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument("checkpoint", type=Path,
                   help="the inference checkpoint to extract the configuration from")
    p.add_argument("--output", "-o", type=Path,
                   help="the output configuration file")
    args = p.parse_args()

    print(f"Loading inference checkpoint {args.checkpoint}...", file=sys.stderr)
    # NOTE(review): the helper presumably returns None for a safetensors file
    # written without a metadata header; fall back to an empty mapping so the
    # check below raises the intended ValueError instead of a TypeError.
    metadata = K.utils.get_safetensors_metadata(args.checkpoint) or {}
    if "config" not in metadata:
        raise ValueError("No configuration found in checkpoint")

    output_path = args.output or args.checkpoint.with_suffix(".json")

    print(f"Saving configuration to {output_path}...", file=sys.stderr)
    output_path.write_text(metadata["config"])


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
{ | ||
"model": { | ||
"type": "image_transformer_v2", | ||
"input_channels": 3, | ||
"input_size": [256, 256], | ||
"patch_size": [4, 4], | ||
"depths": [2, 2, 4], | ||
"widths": [128, 256, 512], | ||
"self_attns": [ | ||
{"type": "neighborhood", "d_head": 64, "kernel_size": 7}, | ||
{"type": "neighborhood", "d_head": 64, "kernel_size": 7}, | ||
{"type": "global", "d_head": 64} | ||
], | ||
"loss_config": "karras", | ||
"loss_weighting": "soft-min-snr", | ||
"dropout_rate": [0.0, 0.0, 0.1], | ||
"mapping_dropout_rate": 0.0, | ||
"augment_prob": 0.0, | ||
"sigma_data": 0.5, | ||
"sigma_min": 1e-2, | ||
"sigma_max": 160, | ||
"sigma_sample_density": { | ||
"type": "cosine-interpolated" | ||
} | ||
}, | ||
"dataset": { | ||
"type": "huggingface", | ||
"location": "nelorth/oxford-flowers", | ||
"image_key": "image" | ||
}, | ||
"optimizer": { | ||
"type": "adamw", | ||
"lr": 5e-4, | ||
"betas": [0.9, 0.95], | ||
"eps": 1e-8, | ||
"weight_decay": 1e-3 | ||
}, | ||
"lr_sched": { | ||
"type": "constant", | ||
"warmup": 0.0 | ||
}, | ||
"ema_sched": { | ||
"type": "inverse", | ||
"power": 0.75, | ||
"max_value": 0.9999 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
{ | ||
"model": { | ||
"type": "image_transformer_v2", | ||
"input_channels": 3, | ||
"input_size": [256, 256], | ||
"patch_size": [4, 4], | ||
"depths": [2, 2, 4], | ||
"widths": [128, 256, 512], | ||
"self_attns": [ | ||
{"type": "shifted-window", "d_head": 64, "window_size": 8}, | ||
{"type": "shifted-window", "d_head": 64, "window_size": 8}, | ||
{"type": "global", "d_head": 64} | ||
], | ||
"loss_config": "karras", | ||
"loss_weighting": "soft-min-snr", | ||
"dropout_rate": [0.0, 0.0, 0.1], | ||
"mapping_dropout_rate": 0.0, | ||
"augment_prob": 0.0, | ||
"sigma_data": 0.5, | ||
"sigma_min": 1e-2, | ||
"sigma_max": 160, | ||
"sigma_sample_density": { | ||
"type": "cosine-interpolated" | ||
} | ||
}, | ||
"dataset": { | ||
"type": "huggingface", | ||
"location": "nelorth/oxford-flowers", | ||
"image_key": "image" | ||
}, | ||
"optimizer": { | ||
"type": "adamw", | ||
"lr": 5e-4, | ||
"betas": [0.9, 0.95], | ||
"eps": 1e-8, | ||
"weight_decay": 1e-3 | ||
}, | ||
"lr_sched": { | ||
"type": "constant", | ||
"warmup": 0.0 | ||
}, | ||
"ema_sched": { | ||
"type": "inverse", | ||
"power": 0.75, | ||
"max_value": 0.9999 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
from . import flops | ||
from .flags import checkpointing, get_checkpointing | ||
from .image_v1 import ImageDenoiserModelV1 | ||
from .image_transformer_v1 import ImageTransformerDenoiserModelV1 | ||
from .image_transformer_v2 import ImageTransformerDenoiserModelV2 |
Oops, something went wrong.