-
Notifications
You must be signed in to change notification settings - Fork 63
Expand file tree
/
Copy pathsave_fp8_quantized.py
More file actions
86 lines (67 loc) · 2.72 KB
/
save_fp8_quantized.py
File metadata and controls
86 lines (67 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import argparse
import os
from predict import DevPredictor, SchnellPredictor
from safetensors.torch import save_file
"""
Code to prequantize and save fp8 weights for Dev or Schnell. Pattern should work for other models.
Note - for this code to work, you'll need to tweak the config of the fp8 flux models in `predict.py` s.t. they load and quantize models.
in practice, this just means eliminating the '-fp8' suffix on the model names.
"""
def generate_dev_img(p, img_name="cool_dog_1234.png"):
    """Run one Dev prediction with fixed arguments and move the result to ``img_name``.

    Args:
        p: Object exposing ``predict``; it is called with the fixed positional
            arguments below.
        img_name: Destination path for the generated image.

    Raises:
        OSError: If ``out-0.png`` is missing or cannot be moved.
    """
    import shutil  # function-scope import keeps this block self-contained

    p.predict("a cool dog", "1:1", None, 0, 1, 28, 3, 1234, "png", 100, True, True, "1")
    # Presumably p.predict writes out-0.png into the working directory -- TODO confirm.
    # Move it without going through a shell, so names containing spaces or shell
    # metacharacters work and a failed move raises instead of being ignored.
    shutil.move("out-0.png", img_name)
def save_dev_fp8():
    """Set up a DevPredictor, run one prediction, and save the fp8 model's state dict.

    The weights are written to ``model-cache/dev-fp8/dev-fp8.safetensors``; the
    directory is created if it does not exist.
    """
    p = DevPredictor()
    p.setup()
    fp8_weights_path = "model-cache/dev-fp8"
    # exist_ok replaces the separate exists() check and its check-then-create race.
    os.makedirs(fp8_weights_path, exist_ok=True)  # noqa: PTH103

    # A prediction is run before saving; presumably this is what initializes the
    # input scales reported below -- TODO confirm.
    generate_dev_img(p)
    print(
        "scale initialized: ",
        p.fp8_model.fp8_pipe.model.double_blocks[0].img_mod.lin.input_scale_initialized,
    )

    sd = p.fp8_model.fp8_pipe.model.state_dict()
    to_trim = "_orig_mod."
    # Strip the prefix only from keys that actually carry it. Slicing every key
    # unconditionally would chop the first characters off unprefixed keys.
    sd_to_save = {
        (k[len(to_trim):] if k.startswith(to_trim) else k): v for k, v in sd.items()
    }
    save_file(sd_to_save, fp8_weights_path + "/" + "dev-fp8.safetensors")
def test_dev_fp8():
    """Set up a fresh DevPredictor and generate one image under a distinct file name."""
    predictor = DevPredictor()
    predictor.setup()
    generate_dev_img(predictor, img_name="cool_dog_1234_loaded_from_compiled.png")
def generate_schnell_img(p, img_name="fast_dog_1234.png"):
    """Run one Schnell prediction with fixed arguments and move the result to ``img_name``.

    Args:
        p: Object exposing ``predict``; it is called with the fixed positional
            arguments below.
        img_name: Destination path for the generated image.

    Raises:
        OSError: If ``out-0.png`` is missing or cannot be moved.
    """
    import shutil  # function-scope import keeps this block self-contained

    p.predict("a cool dog", "1:1", 1, 4, 1234, "png", 100, True, True, "1")
    # Presumably p.predict writes out-0.png into the working directory -- TODO confirm.
    # Move it without going through a shell, so names containing spaces or shell
    # metacharacters work and a failed move raises instead of being ignored.
    shutil.move("out-0.png", img_name)
def save_schnell_fp8():
    """Set up a SchnellPredictor, run one prediction, and save the fp8 model's state dict.

    The weights are written to ``model-cache/schnell-fp8/schnell-fp8.safetensors``;
    the directory is created if it does not exist.
    """
    p = SchnellPredictor()
    p.setup()
    fp8_weights_path = "model-cache/schnell-fp8"
    # exist_ok replaces the separate exists() check and its check-then-create race.
    os.makedirs(fp8_weights_path, exist_ok=True)  # noqa: PTH103

    # A prediction is run before saving; presumably this is what initializes the
    # input scales reported below -- TODO confirm.
    generate_schnell_img(p)
    print(
        "scale initialized: ",
        p.fp8_model.fp8_pipe.model.double_blocks[0].img_mod.lin.input_scale_initialized,
    )

    sd = p.fp8_model.fp8_pipe.model.state_dict()
    to_trim = "_orig_mod."
    # Strip the prefix only from keys that actually carry it. Slicing every key
    # unconditionally would chop the first characters off unprefixed keys.
    sd_to_save = {
        (k[len(to_trim):] if k.startswith(to_trim) else k): v for k, v in sd.items()
    }
    save_file(sd_to_save, fp8_weights_path + "/" + "schnell-fp8.safetensors")
def test_schnell_fp8():
    """Set up a fresh SchnellPredictor and generate one image under a distinct file name."""
    predictor = SchnellPredictor()
    predictor.setup()
    generate_schnell_img(predictor, img_name="fast_dog_1234_loaded_from_compiled.png")
if __name__ == "__main__":
    # Entry point: save fp8 weights for the requested model(s); any other
    # argument value falls through to the Schnell test path.
    parser = argparse.ArgumentParser(
        description="Prequantize and save fp8 weights for Flux Dev and/or Schnell"
    )
    parser.add_argument("flux_model", help="schnell, dev, or all")
    args = parser.parse_args()

    save_dev = args.flux_model in ("dev", "all")
    save_schnell = args.flux_model in ("schnell", "all")

    if save_dev:
        save_dev_fp8()
    if save_schnell:
        save_schnell_fp8()

    # Run the test path only when nothing was saved. The original `else` was
    # bound to the schnell check alone, so passing "dev" saved the Dev weights
    # and then also ran the Schnell test.
    if not (save_dev or save_schnell):
        print("testing I guess")
        # test_dev_fp8() is the Dev counterpart; only Schnell is exercised here.
        test_schnell_fp8()