Skip to content

Commit

Permalink
feat: add gobig api (#158)
Browse files Browse the repository at this point in the history
* feat: add gobig api

* feat: add gobig api

* feat: add gobig api

* feat: add gobig api

* feat: add gobig api

* feat: add gobig api

* feat: add gobig api

* feat: add gobig api

* feat: add gobig api

* feat: add gobig api

* style: fix overload and cli autocomplete

* feat: add gobig api

* feat: add gobig api

Co-authored-by: Jina Dev Bot <[email protected]>
  • Loading branch information
hanxiao and jina-bot authored Aug 9, 2022
1 parent 9d7f12d commit 00e78b7
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 2 deletions.
11 changes: 11 additions & 0 deletions FEATURES.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,17 @@ list_diffusion_models()

You can also specify the environment variable `DISCOART_MODELS_YAML` to build your list of diffusion models.

## Go Big

"Upscale" a DiscoArt image by iteratively invoking `create()` with the same configuration (but a higher `skip_rate`) on each small sliding window.
Each sliding window is diffused into a higher resolution, and all sliding windows are then stitched together to form the final image. Overlapping areas are averaged.

```python
from discoart import create, go_big
doc = create()
doc = go_big(doc)
```

## Feature changes
- DiscoArt does not support video generation and `image_prompt` (which was marked as ineffective in DD 5.4).
Expand Down
4 changes: 2 additions & 2 deletions discoart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

__version__ = '0.11.8'
__version__ = '0.12.0'

__all__ = ['create', 'cheatsheet']

Expand All @@ -17,5 +17,5 @@
'resources',
)

from .create import create
from .create import create, go_big
from .config import cheatsheet, show_config, save_config, load_config
110 changes: 110 additions & 0 deletions discoart/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from types import SimpleNamespace
from typing import overload, List, Optional, Dict, Any, Union, TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
import threading
Expand Down Expand Up @@ -238,3 +239,112 @@ def create(**kwargs) -> Optional['DocumentArray']:
and 'DISCOART_DISABLE_IPYTHON' not in os.environ
):
show_result_summary(_da, _name, _args)


def go_big(
    doc: 'Document',
    window_size: int = 256,
    upscale_factor: int = 2,
    skip_rate: float = 0.8,
    stride_size: Optional[int] = None,
    **kwargs,
) -> 'Document':
    """
    "Upscale" a DiscoArt image by iteratively applying `create()` on each small sliding window.

    The image of `doc` is cut into (padded) sliding windows. Every window is used as the
    init image of a `create()` call that renders it at `window_size * upscale_factor`
    pixels per side, skipping the first `steps * skip_rate` diffusion steps. All rendered
    windows are summed onto one canvas and divided by the number of windows covering each
    pixel, i.e. overlapping areas are averaged.

    This algorithm is coined as GoBig by the DiscoArt community.

    One should NOT use this function to upscale an image and expect high fidelity. It is
    more for creating fractal-style images (https://en.wikipedia.org/wiki/Fractal_art):
    when `skip_rate` is low, many details are added recursively to the image.

    While running, the result summary of `create()` is disabled through the environment
    variable `DISCOART_DISABLE_RESULT_SUMMARY`; the variable is removed again on exit
    (also when an exception is raised) if it was not set by the caller.

    :param doc: the resulted doc from `create()`
    :param window_size: the size of the small sliding window
    :param upscale_factor: the upscale factor, the final image size will be `original size * upscale_factor`
    :param skip_rate: skipping diffusion, high skip rate will result in a faster upscaling and less disruption to original image
    :param stride_size: the size between sliding window, if not set, it will be `window_size * 0.75`. Smaller value means high overlap and more chunks hence slower.
    :param kwargs: other kwargs will be passed to `create()`
    :return: the GoBig document where image is in URI
    """
    # NOTE(review): imported locally because `Document` is called at runtime below,
    # while the annotations only reference it as a string; this assumes the class
    # comes from `docarray` and is harmless if it is already bound at module level.
    from docarray import Document

    from .config import load_config
    from .helper import logger

    # Only undo the env var on exit if this call is the one that set it.
    recover_disabled_summary = 'DISCOART_DISABLE_RESULT_SUMMARY' not in os.environ
    if recover_disabled_summary:
        os.environ['DISCOART_DISABLE_RESULT_SUMMARY'] = '1'

    try:
        old_args = SimpleNamespace(**load_config(user_config=doc.tags))

        # Work on a copy so the caller's document is left untouched.
        d = Document(doc, copy=True)
        d.chunks.clear()

        stride_size = stride_size or int(window_size * 3 / 4)

        d.load_uri_to_image_tensor().convert_image_tensor_to_sliding_windows(
            window_shape=(window_size, window_size),
            strides=(stride_size, stride_size),
            as_chunks=True,
            padding=True,
        )

        # Side length of one rendered window on the upscaled canvas.
        big_window = window_size * upscale_factor
        n_chunks = len(d.chunks)

        # Accumulator canvas sized from the last window's location; the last axis
        # holds [sum of pixel values, number of windows covering the pixel].
        final = np.zeros(
            shape=(
                (d.chunks[-1].location[0] + window_size) * upscale_factor,
                (d.chunks[-1].location[1] + window_size) * upscale_factor,
                3,
                2,
            ),
            dtype='int',
        )

        logger.info(
            f'''
you are about to gobig from {d.tensor.shape[:2]} to {(d.tensor.shape[0] * upscale_factor, d.tensor.shape[1] * upscale_factor)}
which means running `create` iteratively over {n_chunks} chunks, this may take a while. If this takes too long, please consider:
- increasing the `window_size`, which leads to fewer chunks
- increasing the `skip_rate`, which leads to fewer diffusion steps
- decreasing the `upscale_factor`, which leads to smaller final result
    '''
        )

        for idx, c in enumerate(d.chunks):
            # Each chunk carries a copy of the parent's tags before being handed
            # to `create()` as the init document.
            c.tags = copy.deepcopy(d.tags)
            c.tensor = (
                create(
                    init_document=c.convert_image_tensor_to_uri(),
                    n_batches=1,
                    batch_size=1,
                    # Must match the destination slice below, hence
                    # `upscale_factor` rather than a hard-coded 2.
                    width_height=[big_window, big_window],
                    skip_steps=int(old_args.steps * skip_rate),
                    name_docarray=f'{old_args.name_docarray}-gobig-{idx}-{n_chunks}',
                    **kwargs,
                )[0]
                .load_uri_to_image_tensor()
                .tensor
            )

            # Channel 0 of the last axis: pixel values; channel 1: a count of one.
            patch = np.stack([c.tensor] * 2, axis=-1)
            patch[:, :, :, 1] = 1

            start_x = upscale_factor * c.location[0]
            end_x = start_x + big_window
            start_y = upscale_factor * c.location[1]
            end_y = start_y + big_window
            final[start_x:end_x, start_y:end_y, :, :] += patch

        # Crop the canvas back to the exact upscaled size of the source image.
        final = final[
            0 : d.tensor.shape[0] * upscale_factor,
            0 : d.tensor.shape[1] * upscale_factor,
            :,
            :,
        ]

        # Average overlaps; clamp the count to 1 so a pixel that no window
        # covered yields 0 instead of a division by zero.
        coverage = np.maximum(final[:, :, :, 1], 1)
        d.tensor = np.array(final[:, :, :, 0] / coverage, dtype='uint8')

        return d.convert_image_tensor_to_uri()
    finally:
        if recover_disabled_summary:
            os.environ.pop('DISCOART_DISABLE_RESULT_SUMMARY', None)

0 comments on commit 00e78b7

Please sign in to comment.