Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions src/scope/core/pipelines/wan2_1/lora/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,70 @@
from __future__ import annotations

import logging
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Any

from .manager import LoRAManager

logger = logging.getLogger(__name__)

# How long (seconds) to wait for a LoRA file that appears to be downloading.
# Civitai downloads typically complete within 60–90 s on fal workers.
# Default for _wait_for_lora_files; callers may override per call via the
# ``timeout_s`` keyword.
_LORA_WAIT_TIMEOUT_S = 120
# Poll interval while waiting for a LoRA file to appear (seconds between
# filesystem existence checks).
_LORA_WAIT_POLL_S = 2.0


def _wait_for_lora_files(
    lora_configs: list[dict[str, Any]],
    timeout_s: float = _LORA_WAIT_TIMEOUT_S,
    poll_s: float = _LORA_WAIT_POLL_S,
) -> None:
    """Block until every LoRA file in *lora_configs* exists on disk (or timeout).

    This prevents a race condition where the pipeline ``__init__`` attempts to
    load a LoRA that is still being downloaded from Civitai / HuggingFace. On
    session reinitialisation the download and the pipeline load are initiated
    concurrently; waiting here resolves the race without requiring changes to
    the download path.

    Files that already exist are skipped immediately. After *timeout_s* seconds
    a warning is logged and we proceed — the strategy loader will raise its own
    ``FileNotFoundError`` if the file is genuinely missing.

    Args:
        lora_configs: LoRA config dicts; only the ``"path"`` key is consulted.
            Entries with a missing or empty path are ignored.
        timeout_s: Maximum total seconds to wait across all pending files.
        poll_s: Seconds between filesystem existence checks.
    """
    pending = [
        cfg["path"]
        for cfg in lora_configs
        if cfg.get("path") and not Path(cfg["path"]).exists()
    ]
    if not pending:
        return

    logger.info(
        "_wait_for_lora_files: %d LoRA file(s) not yet on disk, waiting up to %ss: %s",
        len(pending),
        timeout_s,
        [Path(p).name for p in pending],
    )

    deadline = time.monotonic() + timeout_s
    while pending:
        # Clamp each sleep to the time remaining so the total wait never
        # overshoots *timeout_s* by up to a full poll interval.
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(poll_s, remaining))
        pending = [p for p in pending if not Path(p).exists()]

    if pending:
        logger.warning(
            "_wait_for_lora_files: timed out after %ss; %d file(s) still missing: %s. "
            "Proceeding — the LoRA loader will raise if files remain absent.",
            timeout_s,
            len(pending),
            [Path(p).name for p in pending],
        )
    else:
        logger.info("_wait_for_lora_files: all LoRA files are now present")

# Mode identifiers for how LoRA adapters are applied (presumably: merge the
# weights into the model permanently vs. apply them at runtime via PEFT —
# confirm against the strategy managers behind LoRAManager).
PERMANENT_MERGE_MODE = "permanent_merge"
RUNTIME_PEFT_MODE = "runtime_peft"

Expand Down Expand Up @@ -97,6 +154,13 @@ def _init_loras(self, config: Any, model) -> Any:
self.loaded_lora_adapters = []
return model

# Wait for any LoRA files that are still being downloaded (e.g. Civitai
# assets fetched concurrently during session reinitialisation). Without
# this gate, pipeline __init__ can race the download and raise
# FileNotFoundError on the first load attempt even though the file will
# be present seconds later. See daydreamlive/scope#937.
_wait_for_lora_files(list(lora_configs))

# Delegate to strategy managers via LoRAManager
self.loaded_lora_adapters = LoRAManager.load_adapters_from_list(
model=model,
Expand Down
69 changes: 67 additions & 2 deletions src/scope/core/pipelines/wan2_1/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,28 +90,93 @@ def normalize_lora_key(lora_base_key: str) -> str:
return lora_base_key


def _wait_for_lora_file(
    lora_path: str,
    timeout: float | None = None,
    poll_interval: float = 2.0,
) -> bool:
    """Poll until *lora_path* exists on disk or *timeout* is exceeded.

    Handles the race condition where a LoRA file is being downloaded from a
    remote source (e.g. Civitai) concurrently with pipeline initialisation.
    The pipeline calls ``load_lora_weights`` synchronously while the download
    runs in a separate thread; without the poll the load fails with
    ``FileNotFoundError`` even though the file will be available shortly.

    Each sleep is clamped to the time remaining so the total wait never
    exceeds *timeout* by more than a scheduler tick.

    Args:
        lora_path: Absolute path of the LoRA file to wait for.
        timeout: Maximum seconds to wait. Defaults to the value of the
            ``SCOPE_LORA_DOWNLOAD_WAIT_TIMEOUT`` environment variable
            (default: 120 s). Pass 0 to disable waiting entirely.
        poll_interval: Seconds between existence checks (default 2 s).

    Returns:
        ``True`` if the file appeared within the timeout, ``False`` otherwise.
    """
    import time

    if timeout is None:
        timeout = float(os.getenv("SCOPE_LORA_DOWNLOAD_WAIT_TIMEOUT", "120"))

    # Fast paths: waiting disabled, or the file is already on disk.
    if timeout <= 0:
        return os.path.exists(lora_path)
    if os.path.exists(lora_path):
        return True

    logger.info(
        "_wait_for_lora_file: '%s' not yet present — waiting up to %.0fs for "
        "in-flight download to complete (poll every %.1fs)",
        lora_path,
        timeout,
        poll_interval,
    )

    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        # Never sleep past the deadline: a full poll_interval sleep here
        # would let the actual wait overshoot *timeout*.
        time.sleep(min(poll_interval, remaining))
        if os.path.exists(lora_path):
            waited = timeout - (deadline - time.monotonic())
            logger.info(
                "_wait_for_lora_file: '%s' appeared after %.1fs",
                lora_path,
                waited,
            )
            return True


def load_lora_weights(lora_path: str) -> dict[str, torch.Tensor]:
    """Load LoRA weights from a ``.safetensors`` or ``.bin`` file.

    When the file is not yet on disk, this polls for up to
    ``SCOPE_LORA_DOWNLOAD_WAIT_TIMEOUT`` seconds (default: 120) so that an
    in-flight Civitai/HuggingFace download can finish before we give up.
    This avoids spurious ``FileNotFoundError`` failures during session
    reinit when a LoRA asset download races with pipeline ``__init__``.

    Args:
        lora_path: Path to LoRA file (.safetensors or .bin)

    Returns:
        Dictionary mapping parameter names to tensors

    Raises:
        FileNotFoundError: If the LoRA file does not exist (and did not appear
            within the configured wait timeout).
    """
    file_is_present = _wait_for_lora_file(lora_path)
    if not file_is_present:
        raise FileNotFoundError(
            f"load_lora_weights: LoRA file not found: {lora_path}"
        )

    # safetensors files get the native loader; anything else is treated as a
    # torch pickle and materialised on CPU.
    if lora_path.endswith(".safetensors"):
        return load_file(lora_path)
    return torch.load(lora_path, map_location="cpu")



def find_lora_pair(
lora_key: str, lora_state: dict[str, torch.Tensor]
) -> tuple[str, str, torch.Tensor, torch.Tensor] | None:
Expand Down
43 changes: 42 additions & 1 deletion src/scope/server/pipeline_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ class PipelineNotAvailableException(Exception):
pass


class PipelineNotYetRegisteredException(ValueError):
    """Raised when a pipeline ID has no registry entry *yet*.

    Typically seen during cloud session initialization: the frontend may
    request a plugin install and a pipeline load concurrently, so the load
    can arrive before the plugin has finished installing and registering
    itself. The registry lookup then returns ``None`` even though the
    pipeline ID will become valid shortly afterwards.

    This is a *transient*, retriable condition — callers should retry the
    load rather than surface it as a hard error.
    """


class PipelineStatus(Enum):
"""Pipeline loading status enumeration."""

Expand Down Expand Up @@ -336,6 +352,29 @@ def _load_pipeline_by_id_sync(
)
return True

except PipelineNotYetRegisteredException:
# Transient race condition: the pipeline plugin hasn't finished
# installing yet. Log at WARN (not ERROR) and leave the status as
# NOT_LOADED so the frontend doesn't show an error state and the
# load can be retried transparently once the plugin is registered.
self.set_loading_stage(None)
logger.warning(
f"Pipeline '{key}' is not registered — the plugin may still be "
f"installing. This is likely a transient race condition and will "
f"resolve once the plugin is installed."
)
with self._lock:
self._pipeline_statuses[key] = PipelineStatus.NOT_LOADED
if key in self._pipelines:
del self._pipelines[key]
if key in self._pipeline_load_params:
del self._pipeline_load_params[key]
if key in self._pipeline_registry_ids:
del self._pipeline_registry_ids[key]
if key in self._load_events:
self._load_events[key].set()
return False

except Exception as e:
self.set_loading_stage(None)
from .models_config import get_models_dir
Expand Down Expand Up @@ -1385,7 +1424,9 @@ def _load_pipeline_implementation(
logger.info("OpticalFlow pipeline initialized")
return pipeline
else:
raise ValueError(f"Invalid pipeline ID: {pipeline_id}")
raise PipelineNotYetRegisteredException(
f"Invalid pipeline ID: {pipeline_id}. Plugin may not be installed yet."
)

def is_loaded(self) -> bool:
"""Check if pipeline is loaded and ready (thread-safe)."""
Expand Down
131 changes: 131 additions & 0 deletions tests/test_lora_wait_for_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""Tests for the LoRA download-wait helper in lora/utils.py.

Covers the race condition where a Civitai LoRA file is still being
downloaded when LongLivePipeline.__init__ calls load_lora_weights.
See: daydreamlive/scope#937
"""

import os
import threading
import time
from pathlib import Path
from unittest.mock import patch

import pytest

from scope.core.pipelines.wan2_1.lora.utils import _wait_for_lora_file


class TestWaitForLoraFile:
    """Unit tests for _wait_for_lora_file."""

    def test_file_already_present_returns_immediately(self, tmp_path: Path):
        """If the file exists before the first check, return True right away."""
        lora_file = tmp_path / "model.safetensors"
        lora_file.touch()

        start = time.monotonic()
        result = _wait_for_lora_file(str(lora_file), timeout=10, poll_interval=0.1)
        elapsed = time.monotonic() - start

        assert result is True
        # Should not have slept at all
        assert elapsed < 0.5

    def test_file_appears_during_wait(self, tmp_path: Path):
        """File appears mid-poll; function returns True after ≤2 poll intervals."""
        lora_file = tmp_path / "late.safetensors"

        # Simulate an in-flight download completing 0.3 s after the poll starts.
        def _create_later():
            time.sleep(0.3)
            lora_file.touch()

        t = threading.Thread(target=_create_later, daemon=True)
        t.start()

        result = _wait_for_lora_file(str(lora_file), timeout=5, poll_interval=0.1)
        t.join()

        assert result is True

    def test_file_never_appears_returns_false(self, tmp_path: Path):
        """File never shows up; function returns False after timeout."""
        missing = str(tmp_path / "missing.safetensors")

        result = _wait_for_lora_file(missing, timeout=0.3, poll_interval=0.1)

        assert result is False

    def test_timeout_zero_disables_wait(self, tmp_path: Path):
        """timeout=0 means skip the poll entirely; missing file → False instantly."""
        missing = str(tmp_path / "no_wait.safetensors")

        start = time.monotonic()
        result = _wait_for_lora_file(missing, timeout=0, poll_interval=0.1)
        elapsed = time.monotonic() - start

        assert result is False
        # No poll loop should run at all, so only a single existence check.
        assert elapsed < 0.1

    def test_env_var_overrides_default_timeout(self, tmp_path: Path, monkeypatch):
        """SCOPE_LORA_DOWNLOAD_WAIT_TIMEOUT env var controls the default timeout."""
        missing = str(tmp_path / "env_timeout.safetensors")
        monkeypatch.setenv("SCOPE_LORA_DOWNLOAD_WAIT_TIMEOUT", "0.2")

        start = time.monotonic()
        # Pass timeout=None so env var is picked up
        result = _wait_for_lora_file(missing, timeout=None, poll_interval=0.05)
        elapsed = time.monotonic() - start

        assert result is False
        # Should respect the 0.2 s limit (allow generous buffer for CI)
        assert elapsed < 1.5


class TestLoadLoraWeightsWaits:
    """Integration-style tests ensuring load_lora_weights uses the poll helper."""

    def test_raises_after_timeout_when_file_never_appears(self, tmp_path: Path):
        """load_lora_weights should raise FileNotFoundError when wait times out."""
        from scope.core.pipelines.wan2_1.lora.utils import load_lora_weights

        missing = str(tmp_path / "never_there.safetensors")
        # Short timeout so the test stays fast
        with patch.dict(os.environ, {"SCOPE_LORA_DOWNLOAD_WAIT_TIMEOUT": "0.2"}):
            with pytest.raises(FileNotFoundError, match="LoRA file not found"):
                load_lora_weights(missing)

    def test_succeeds_when_file_appears_during_wait(self, tmp_path: Path):
        """load_lora_weights should succeed if the file arrives within the timeout.

        We exercise _wait_for_lora_file (already imported at module level) and
        the safetensors load in combination, keeping the test fast by using a
        short poll interval.
        """
        import torch
        from safetensors.torch import save_file

        from scope.core.pipelines.wan2_1.lora.utils import load_lora_weights

        lora_file = tmp_path / "delayed.safetensors"

        # Write a minimal safetensors file after a short delay
        def _write_later():
            time.sleep(0.3)
            tensors = {"lora_A.weight": torch.zeros(4, 4)}
            save_file(tensors, str(lora_file))

        t = threading.Thread(target=_write_later, daemon=True)
        t.start()

        # Directly exercise _wait_for_lora_file with a short poll interval, then
        # verify load_lora_weights can read the now-present file.
        appeared = _wait_for_lora_file(str(lora_file), timeout=5, poll_interval=0.1)
        t.join()

        assert appeared is True
        result = load_lora_weights(str(lora_file))
        assert "lora_A.weight" in result
Loading
Loading