-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathreconstruct.py
More file actions
123 lines (98 loc) · 4 KB
/
reconstruct.py
File metadata and controls
123 lines (98 loc) · 4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
Image Reconstruction code
"""
import os
import sys
sys.path.append(os.getcwd())
import torch
from omegaconf import OmegaConf
import importlib
import numpy as np
from PIL import Image
from tqdm import tqdm
# from taming.models.lfqgan import VQModel
import argparse
# Pick the compute device: prefer an NPU when the optional ``torch_npu``
# package imports cleanly and exposes ``torch.npu``; otherwise use CUDA,
# falling back to CPU.
try:
    import torch_npu  # noqa: F401  (imported for its side effect on ``torch``)
except Exception:  # deliberate best-effort: a missing/broken NPU stack must not block CUDA/CPU runs
    torch_npu = None
if torch_npu is not None and hasattr(torch, "npu"):
    DEVICE = torch.device("npu:0" if torch_npu.npu.is_available() else "cpu")
else:
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def load_vqgan_new(config, ckpt_path=None, is_gumbel=False):
    """Build the model described by ``config.model`` and optionally load weights.

    Args:
        config: Root config; ``config.model`` is passed to
            ``instantiate_from_config`` (needs ``class_path``, optional ``init_args``).
        ckpt_path: Optional checkpoint path. The loaded object must contain a
            ``"state_dict"`` entry.
        is_gumbel: Unused; kept so existing callers keep working.

    Returns:
        The instantiated model in eval mode.
    """
    model = instantiate_from_config(config.model)
    if ckpt_path is not None:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        # strict=False tolerates key mismatches, so surface them instead of
        # silently running with partially initialised weights.
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if missing:
            print("load_vqgan_new: {} missing keys: {}".format(len(missing), missing))
        if unexpected:
            print("load_vqgan_new: {} unexpected keys: {}".format(len(unexpected), unexpected))
    return model.eval()
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path like ``"pkg.module.Name"`` to the object it names.

    Args:
        string: Dotted path; everything before the last ``.`` is the module,
            the last component is the attribute looked up on it.
        reload: When True, ``importlib.reload`` the module before the lookup.

    Returns:
        The attribute named by the final path component.
    """
    print(string)  # echo the path being resolved
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    owner = importlib.import_module(module_name, package=None)
    return getattr(owner, attr_name)
def instantiate_from_config(config):
    """Call the object named by ``config["class_path"]`` with ``init_args`` kwargs.

    Args:
        config: Mapping with a ``class_path`` dotted name and an optional
            ``init_args`` mapping (defaults to no keyword arguments).

    Returns:
        Whatever the resolved callable returns (normally a new instance).

    Raises:
        KeyError: If ``class_path`` is absent.
    """
    if "class_path" not in config:
        raise KeyError("Expected key `class_path` to instantiate.")
    factory = get_obj_from_str(config["class_path"])
    kwargs = config.get("init_args", dict())
    return factory(**kwargs)
def custom_to_pil(x):
    """Convert a channels-first image tensor with values in [-1, 1] to an RGB PIL image.

    Args:
        x: Tensor laid out as (C, H, W) (the ``permute(1, 2, 0)`` below
            assumes channels-first). Values outside [-1, 1] are clipped.

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode.
    """
    x = x.detach().cpu()
    if x.dtype == torch.bfloat16:
        # NumPy has no bfloat16, so ``.numpy()`` would raise on it.
        x = x.float()
    x = torch.clamp(x, -1., 1.)
    x = (x + 1.) / 2.  # [-1, 1] -> [0, 1]
    x = x.permute(1, 2, 0).numpy()
    x = (255 * x).astype(np.uint8)
    if x.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the channel axis.
        x = x[..., 0]
    img = Image.fromarray(x)
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img
def main(args):
    """Reconstruct up to ``args.image_num`` images and save originals and reconstructions.

    PNGs are written to ``<save_dir>/<version>/original_<image_size>/`` and
    ``<save_dir>/<version>/rec_<image_size>/``, named by a running image index.

    Args:
        args: Parsed CLI namespace (see ``get_args``).
    """
    import contextlib

    configs = OmegaConf.load(args.config_file)
    configs.data.init_args.batch_size = args.batch_size  # override the batch size
    configs.data.init_args.test.params.config.size = args.image_size  # inference uses the test config
    configs.data.init_args.test.params.config.subset = args.subset  # restrict to a specific subset for comparison

    model = load_vqgan_new(configs, args.ckpt_path).to(DEVICE)

    visualize_original = os.path.join(args.save_dir, args.version, "original_{}".format(args.image_size))
    visualize_rec = os.path.join(args.save_dir, args.version, "rec_{}".format(args.image_size))
    os.makedirs(visualize_original, exist_ok=True)
    os.makedirs(visualize_rec, exist_ok=True)

    dataset = instantiate_from_config(configs.data)
    dataset.prepare_data()
    dataset.setup()

    use_ema = getattr(model, "use_ema", False)
    saved = 0
    with torch.no_grad():
        # NOTE(review): the overrides above target the ``test`` config, but
        # images come from ``_val_dataloader()`` -- confirm the intended split.
        for batch in tqdm(dataset._val_dataloader()):
            if saved >= args.image_num:
                break
            images = batch["image"].to(DEVICE)
            # Run under EMA weights when the model keeps them; otherwise run as-is
            # so ``reconstructed_images`` is always bound.
            scope = model.ema_scope() if use_ema else contextlib.nullcontext()
            with scope:
                # assumes forward returns a 3-tuple whose first item is the reconstruction
                reconstructed_images, _, _ = model(images)
            # Save every image of the batch, stopping exactly at image_num.
            for image, reconstructed_image in zip(images, reconstructed_images):
                if saved >= args.image_num:
                    break
                custom_to_pil(image).save(os.path.join(visualize_original, "{}.png".format(saved)))
                custom_to_pil(reconstructed_image).save(os.path.join(visualize_rec, "{}.png".format(saved)))
                saved += 1
def get_args():
    """Parse the command-line options for reconstruction inference.

    Returns:
        argparse.Namespace with config_file, ckpt_path, image_size, batch_size,
        image_num, subset, version and save_dir.
    """
    parser = argparse.ArgumentParser(description="inference parameters")
    # Model definition and weights (both mandatory).
    parser.add_argument("--config_file", type=str, required=True)
    parser.add_argument("--ckpt_path", type=str, required=True)
    # Data / sampling controls.
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--image_num", type=int, default=50)
    parser.add_argument("--subset", default=None)
    # Output location: images go under <save_dir>/<version>/.
    parser.add_argument("--version", type=str, required=True)
    parser.add_argument("--save_dir", type=str, required=True)
    parsed = parser.parse_args()
    return parsed
# Script entry point: parse CLI options and run reconstruction.
if __name__ == "__main__":
    main(get_args())