Skip to content

Commit fb5a3ff

Browse files
fix: administrative api endpoints /sdapi/v1/loras an... in lora_scrip...
Administrative API endpoints /sdapi/v1/loras and /sdapi/v1/refresh-loras lack authorization validation beyond basic authentication
1 parent 82a973c commit fb5a3ff

1 file changed

Lines changed: 130 additions & 102 deletions

File tree

Lines changed: 130 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,102 +1,130 @@
1-
import re
2-
3-
import gradio as gr
4-
from fastapi import FastAPI
5-
6-
import network
7-
import networks
8-
import lora # noqa:F401
9-
import lora_patches
10-
import extra_networks_lora
11-
import ui_extra_networks_lora
12-
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
13-
14-
15-
def unload():
16-
networks.originals.undo()
17-
18-
19-
def before_ui():
20-
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
21-
22-
networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
23-
extra_networks.register_extra_network(networks.extra_network_lora)
24-
extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
25-
26-
27-
networks.originals = lora_patches.LoraPatches()
28-
29-
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
30-
script_callbacks.on_script_unloaded(unload)
31-
script_callbacks.on_before_ui(before_ui)
32-
script_callbacks.on_infotext_pasted(networks.infotext_pasted)
33-
34-
35-
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
36-
"sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks),
37-
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
38-
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
39-
"lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'),
40-
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
41-
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
42-
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
43-
"lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
44-
"lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
45-
}))
46-
47-
48-
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
49-
"lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
50-
}))
51-
52-
53-
def create_lora_json(obj: network.NetworkOnDisk):
54-
return {
55-
"name": obj.name,
56-
"alias": obj.alias,
57-
"path": obj.filename,
58-
"metadata": obj.metadata,
59-
}
60-
61-
62-
def api_networks(_: gr.Blocks, app: FastAPI):
63-
@app.get("/sdapi/v1/loras")
64-
async def get_loras():
65-
return [create_lora_json(obj) for obj in networks.available_networks.values()]
66-
67-
@app.post("/sdapi/v1/refresh-loras")
68-
async def refresh_loras():
69-
return networks.list_available_networks()
70-
71-
72-
script_callbacks.on_app_started(api_networks)
73-
74-
re_lora = re.compile("<lora:([^:]+):")
75-
76-
77-
def infotext_pasted(infotext, d):
78-
hashes = d.get("Lora hashes")
79-
if not hashes:
80-
return
81-
82-
hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
83-
hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
84-
85-
def network_replacement(m):
86-
alias = m.group(1)
87-
shorthash = hashes.get(alias)
88-
if shorthash is None:
89-
return m.group(0)
90-
91-
network_on_disk = networks.available_network_hash_lookup.get(shorthash)
92-
if network_on_disk is None:
93-
return m.group(0)
94-
95-
return f'<lora:{network_on_disk.get_alias()}:'
96-
97-
d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
98-
99-
100-
script_callbacks.on_infotext_pasted(infotext_pasted)
101-
102-
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
1+
import re
2+
3+
import gradio as gr
4+
from fastapi import FastAPI, HTTPException
5+
6+
import network
7+
import networks
8+
import lora # noqa:F401
9+
import lora_patches
10+
import extra_networks_lora
11+
import ui_extra_networks_lora
12+
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
13+
14+
15+
def unload():
16+
networks.originals.undo()
17+
18+
19+
def before_ui():
20+
ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
21+
22+
networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
23+
extra_networks.register_extra_network(networks.extra_network_lora)
24+
extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
25+
26+
27+
networks.originals = lora_patches.LoraPatches()
28+
29+
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
30+
script_callbacks.on_script_unloaded(unload)
31+
script_callbacks.on_before_ui(before_ui)
32+
script_callbacks.on_infotext_pasted(networks.infotext_pasted)
33+
34+
35+
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
36+
"sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks),
37+
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
38+
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
39+
"lora_bundled_ti_to_infotext": shared.OptionInfo(True, "Add Lora name as TI hashes for bundled Textual Inversion").info('"Add Textual Inversion hashes to infotext" needs to be enabled'),
40+
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
41+
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
42+
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
43+
"lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
44+
"lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
45+
}))
46+
47+
48+
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
49+
"lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
50+
}))
51+
52+
53+
def create_lora_json(obj: network.NetworkOnDisk):
54+
return {
55+
"name": obj.name,
56+
"alias": obj.alias,
57+
"path": obj.filename,
58+
"metadata": obj.metadata,
59+
}
60+
61+
62+
def api_networks(_: gr.Blocks, app: FastAPI):
63+
from fastapi import Depends
64+
from fastapi.security import HTTPBasic, HTTPBasicCredentials
65+
from hmac import compare_digest
66+
67+
security = HTTPBasic(auto_error=False)
68+
69+
def check_api_auth(credentials: HTTPBasicCredentials = Depends(security)):
70+
if not shared.cmd_opts.api_auth:
71+
return
72+
if credentials is None:
73+
raise HTTPException(
74+
status_code=401,
75+
detail="Not authenticated",
76+
headers={"WWW-Authenticate": "Basic"},
77+
)
78+
valid_credentials = {}
79+
for auth_entry in shared.cmd_opts.api_auth.split(","):
80+
user, password = auth_entry.split(":", 1)
81+
valid_credentials[user.strip()] = password.strip()
82+
username = credentials.username
83+
if username in valid_credentials and compare_digest(credentials.password, valid_credentials[username]):
84+
return
85+
raise HTTPException(
86+
status_code=401,
87+
detail="Incorrect credentials",
88+
headers={"WWW-Authenticate": "Basic"},
89+
)
90+
91+
@app.get("/sdapi/v1/loras", dependencies=[Depends(check_api_auth)])
92+
async def get_loras():
93+
return [create_lora_json(obj) for obj in networks.available_networks.values()]
94+
95+
@app.post("/sdapi/v1/refresh-loras", dependencies=[Depends(check_api_auth)])
96+
async def refresh_loras():
97+
return networks.list_available_networks()
98+
99+
100+
script_callbacks.on_app_started(api_networks)
101+
102+
re_lora = re.compile("<lora:([^:]+):")
103+
104+
105+
def infotext_pasted(infotext, d):
106+
hashes = d.get("Lora hashes")
107+
if not hashes:
108+
return
109+
110+
hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
111+
hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}
112+
113+
def network_replacement(m):
114+
alias = m.group(1)
115+
shorthash = hashes.get(alias)
116+
if shorthash is None:
117+
return m.group(0)
118+
119+
network_on_disk = networks.available_network_hash_lookup.get(shorthash)
120+
if network_on_disk is None:
121+
return m.group(0)
122+
123+
return f'<lora:{network_on_disk.get_alias()}:'
124+
125+
d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])
126+
127+
128+
script_callbacks.on_infotext_pasted(infotext_pasted)
129+
130+
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)

0 commit comments

Comments
 (0)