|
1 | | -import re |
2 | | - |
3 | | -import gradio as gr |
4 | | -from fastapi import FastAPI |
5 | | - |
6 | | -import network |
7 | | -import networks |
8 | | -import lora # noqa:F401 |
9 | | -import lora_patches |
10 | | -import extra_networks_lora |
11 | | -import ui_extra_networks_lora |
12 | | -from modules import script_callbacks, ui_extra_networks, extra_networks, shared |
13 | | - |
14 | | - |
15 | | -def unload(): |
16 | | - networks.originals.undo() |
17 | | - |
18 | | - |
19 | | -def before_ui(): |
20 | | - ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) |
21 | | - |
22 | | - networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora() |
23 | | - extra_networks.register_extra_network(networks.extra_network_lora) |
24 | | - extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco") |
25 | | - |
26 | | - |
27 | | -networks.originals = lora_patches.LoraPatches() |
28 | | - |
29 | | -script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules) |
30 | | -script_callbacks.on_script_unloaded(unload) |
31 | | -script_callbacks.on_before_ui(before_ui) |
32 | | -script_callbacks.on_infotext_pasted(networks.infotext_pasted) |
33 | | - |
34 | | - |
35 | | -shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { |
36 | | - "sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks), |
37 | | - "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}), |
38 | | - "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), |
39 | | - "lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'), |
40 | | - "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), |
41 | | - "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), |
42 | | - "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), |
43 | | - "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"), |
44 | | - "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"), |
45 | | -})) |
46 | | - |
47 | | - |
48 | | -shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), { |
49 | | - "lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"), |
50 | | -})) |
51 | | - |
52 | | - |
53 | | -def create_lora_json(obj: network.NetworkOnDisk): |
54 | | - return { |
55 | | - "name": obj.name, |
56 | | - "alias": obj.alias, |
57 | | - "path": obj.filename, |
58 | | - "metadata": obj.metadata, |
59 | | - } |
60 | | - |
61 | | - |
62 | | -def api_networks(_: gr.Blocks, app: FastAPI): |
63 | | - @app.get("/sdapi/v1/loras") |
64 | | - async def get_loras(): |
65 | | - return [create_lora_json(obj) for obj in networks.available_networks.values()] |
66 | | - |
67 | | - @app.post("/sdapi/v1/refresh-loras") |
68 | | - async def refresh_loras(): |
69 | | - return networks.list_available_networks() |
70 | | - |
71 | | - |
72 | | -script_callbacks.on_app_started(api_networks) |
73 | | - |
74 | | -re_lora = re.compile("<lora:([^:]+):") |
75 | | - |
76 | | - |
77 | | -def infotext_pasted(infotext, d): |
78 | | - hashes = d.get("Lora hashes") |
79 | | - if not hashes: |
80 | | - return |
81 | | - |
82 | | - hashes = [x.strip().split(':', 1) for x in hashes.split(",")] |
83 | | - hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes} |
84 | | - |
85 | | - def network_replacement(m): |
86 | | - alias = m.group(1) |
87 | | - shorthash = hashes.get(alias) |
88 | | - if shorthash is None: |
89 | | - return m.group(0) |
90 | | - |
91 | | - network_on_disk = networks.available_network_hash_lookup.get(shorthash) |
92 | | - if network_on_disk is None: |
93 | | - return m.group(0) |
94 | | - |
95 | | - return f'<lora:{network_on_disk.get_alias()}:' |
96 | | - |
97 | | - d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"]) |
98 | | - |
99 | | - |
100 | | -script_callbacks.on_infotext_pasted(infotext_pasted) |
101 | | - |
102 | | -shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory) |
| 1 | +import re |
| 2 | + |
| 3 | +import gradio as gr |
| 4 | +from fastapi import FastAPI, HTTPException |
| 5 | + |
| 6 | +import network |
| 7 | +import networks |
| 8 | +import lora # noqa:F401 |
| 9 | +import lora_patches |
| 10 | +import extra_networks_lora |
| 11 | +import ui_extra_networks_lora |
| 12 | +from modules import script_callbacks, ui_extra_networks, extra_networks, shared |
| 13 | + |
| 14 | + |
| 15 | +def unload(): |
| 16 | + networks.originals.undo() |
| 17 | + |
| 18 | + |
| 19 | +def before_ui(): |
| 20 | + ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) |
| 21 | + |
| 22 | + networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora() |
| 23 | + extra_networks.register_extra_network(networks.extra_network_lora) |
| 24 | + extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco") |
| 25 | + |
| 26 | + |
| 27 | +networks.originals = lora_patches.LoraPatches() |
| 28 | + |
| 29 | +script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules) |
| 30 | +script_callbacks.on_script_unloaded(unload) |
| 31 | +script_callbacks.on_before_ui(before_ui) |
| 32 | +script_callbacks.on_infotext_pasted(networks.infotext_pasted) |
| 33 | + |
| 34 | + |
| 35 | +shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { |
| 36 | + "sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks), |
| 37 | + "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}), |
| 38 | + "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), |
| 39 | + "lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'), |
| 40 | + "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), |
| 41 | + "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), |
| 42 | + "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), |
| 43 | + "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"), |
| 44 | + "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"), |
| 45 | +})) |
| 46 | + |
| 47 | + |
| 48 | +shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), { |
| 49 | + "lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"), |
| 50 | +})) |
| 51 | + |
| 52 | + |
| 53 | +def create_lora_json(obj: network.NetworkOnDisk): |
| 54 | + return { |
| 55 | + "name": obj.name, |
| 56 | + "alias": obj.alias, |
| 57 | + "path": obj.filename, |
| 58 | + "metadata": obj.metadata, |
| 59 | + } |
| 60 | + |
| 61 | + |
| 62 | +def api_networks(_: gr.Blocks, app: FastAPI): |
| 63 | + from fastapi import Depends |
| 64 | + from fastapi.security import HTTPBasic, HTTPBasicCredentials |
| 65 | + from hmac import compare_digest |
| 66 | + |
| 67 | + security = HTTPBasic(auto_error=False) |
| 68 | + |
| 69 | + def check_api_auth(credentials: HTTPBasicCredentials = Depends(security)): |
| 70 | + if not shared.cmd_opts.api_auth: |
| 71 | + return |
| 72 | + if credentials is None: |
| 73 | + raise HTTPException( |
| 74 | + status_code=401, |
| 75 | + detail="Not authenticated", |
| 76 | + headers={"WWW-Authenticate": "Basic"}, |
| 77 | + ) |
| 78 | + valid_credentials = {} |
| 79 | + for auth_entry in shared.cmd_opts.api_auth.split(","): |
| 80 | + user, password = auth_entry.split(":", 1) |
| 81 | + valid_credentials[user.strip()] = password.strip() |
| 82 | + username = credentials.username |
| 83 | + if username in valid_credentials and compare_digest(credentials.password, valid_credentials[username]): |
| 84 | + return |
| 85 | + raise HTTPException( |
| 86 | + status_code=401, |
| 87 | + detail="Incorrect credentials", |
| 88 | + headers={"WWW-Authenticate": "Basic"}, |
| 89 | + ) |
| 90 | + |
| 91 | + @app.get("/sdapi/v1/loras", dependencies=[Depends(check_api_auth)]) |
| 92 | + async def get_loras(): |
| 93 | + return [create_lora_json(obj) for obj in networks.available_networks.values()] |
| 94 | + |
| 95 | + @app.post("/sdapi/v1/refresh-loras", dependencies=[Depends(check_api_auth)]) |
| 96 | + async def refresh_loras(): |
| 97 | + return networks.list_available_networks() |
| 98 | + |
| 99 | + |
| 100 | +script_callbacks.on_app_started(api_networks) |
| 101 | + |
| 102 | +re_lora = re.compile("<lora:([^:]+):") |
| 103 | + |
| 104 | + |
| 105 | +def infotext_pasted(infotext, d): |
| 106 | + hashes = d.get("Lora hashes") |
| 107 | + if not hashes: |
| 108 | + return |
| 109 | + |
| 110 | + hashes = [x.strip().split(':', 1) for x in hashes.split(",")] |
| 111 | + hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes} |
| 112 | + |
| 113 | + def network_replacement(m): |
| 114 | + alias = m.group(1) |
| 115 | + shorthash = hashes.get(alias) |
| 116 | + if shorthash is None: |
| 117 | + return m.group(0) |
| 118 | + |
| 119 | + network_on_disk = networks.available_network_hash_lookup.get(shorthash) |
| 120 | + if network_on_disk is None: |
| 121 | + return m.group(0) |
| 122 | + |
| 123 | + return f'<lora:{network_on_disk.get_alias()}:' |
| 124 | + |
| 125 | + d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"]) |
| 126 | + |
| 127 | + |
| 128 | +script_callbacks.on_infotext_pasted(infotext_pasted) |
| 129 | + |
| 130 | +shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory) |
0 commit comments