-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_vae.py
31 lines (24 loc) · 970 Bytes
/
test_vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
import jax
import jax.experimental.compilation_cache.compilation_cache
import jax.numpy as jnp
from jax import Array
from datasets import load_dataset
from PIL import Image
from loguru import logger

# Enable JAX's persistent compilation cache so the jitted VAE decode is not
# recompiled from scratch on every run. The compilation_cache submodule is
# imported explicitly above *before* it is used: `import jax` on its own does
# not guarantee `jax.experimental.compilation_cache` is loaded, so calling it
# first and importing it afterwards can fail with AttributeError.
# NOTE(review): newer JAX docs also expose this as
# jax.config.update("jax_compilation_cache_dir", ...) — consider switching.
jax.experimental.compilation_cache.compilation_cache.set_cache_dir("jit_cache")

# Imported after the cache directory is configured (original ordering kept),
# presumably so anything this module compiles at import time hits the cache.
from vae.vae_flax import load_pretrained_vae
# Load the pretrained Flax VAE and its parameters.
# NOTE(review): the meaning of the positional `True` argument is defined in
# vae.vae_flax.load_pretrained_vae — confirm there and consider a keyword.
vae, params = load_pretrained_vae("pcuenq/sd-vae-ft-mse-flax", True)
sample_size = vae.config.sample_size


@jax.jit
def _decode(variables, latents: Array):
    """Run the VAE's ``decode`` method under ``jax.jit``.

    The parameters are passed in as an argument rather than closed over, so
    the weights are not embedded into the traced program as constants. That
    keeps compilation fast and the persistent-cache entries small.
    """
    return vae.apply(variables, latents, method="decode")


def step(sample: Array):
    """Decode ``sample`` with the pretrained VAE.

    Args:
        sample: Latent array passed to the VAE's ``decode`` method.

    Returns:
        The result of ``vae.apply(params, sample, method="decode")``.
    """
    return _decode(params, sample)
# Latent layout fed to the decoder: (batch, channels, height, width).
# NOTE(review): 4 x 32 x 32 is hard-coded from the original script and
# presumably matches how "roborovski/imagenet-int8-flax" stores `vae_output`
# — confirm against the dataset card.
LATENT_SHAPE = (1, 4, 32, 32)

# Decode the first training example's stored latent back into an image.
dataset = load_dataset("roborovski/imagenet-int8-flax")
first_sample = next(iter(dataset["train"]))  # type: ignore
sample_tensor = jnp.array(first_sample["vae_output"]).reshape(LATENT_SHAPE)

out = step(sample_tensor)
out_np = np.array(out[0])

# Convert the CHW float output to an HWC uint8 image. Clip to [0, 1] first:
# casting out-of-range floats straight to uint8 typically wraps around
# (e.g. -0.1 * 255 -> 231), which shows up as speckled pixels.
# NOTE(review): Stable-Diffusion-style VAE decoders usually emit values in
# [-1, 1]; if vae.vae_flax does not rescale, map with (x + 1) / 2 before
# clipping — confirm.
pixels = np.clip(out_np.transpose(1, 2, 0), 0.0, 1.0) * 255
img = Image.fromarray(pixels.astype("uint8"))
img.save("test.png")
logger.info("Saved image to test.png")