Skip to content

Commit 7a1a7ce

Browse files
committed
Fix Flux2 LoRA guidance conversion
1 parent 71a6fd9 commit 7a1a7ce

3 files changed

Lines changed: 74 additions & 1 deletion

File tree

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2421,6 +2421,8 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
24212421
"txt_in": "context_embedder",
24222422
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
24232423
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
2424+
"guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
2425+
"guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
24242426
"final_layer.linear": "proj_out",
24252427
"final_layer.adaLN_modulation.1": "norm_out.linear",
24262428
"single_stream_modulation.lin": "single_stream_modulation.linear",

src/diffusers/loaders/lora_pipeline.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5703,6 +5703,13 @@ def load_lora_weights(
57035703
"""
57045704
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
57055705
"""
5706+
transformer = getattr(self, self.transformer_name, None)
5707+
if transformer is None:
5708+
raise ValueError(
5709+
"Flux2 LoRA weights can only be loaded into a pipeline that defines a `transformer` component. "
5710+
"This modular sub-pipeline only exposes a subset of components."
5711+
)
5712+
57065713
if not USE_PEFT_BACKEND:
57075714
raise ValueError("PEFT backend is required for this method.")
57085715

@@ -5726,7 +5733,7 @@ def load_lora_weights(
57265733

57275734
self.load_lora_into_transformer(
57285735
state_dict,
5729-
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
5736+
transformer=transformer,
57305737
adapter_name=adapter_name,
57315738
metadata=metadata,
57325739
_pipeline=self,
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import unittest
17+
18+
import torch
19+
20+
from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_flux2_lora_to_diffusers
21+
from diffusers.modular_pipelines.flux2.encoders import Flux2KleinTextEncoderStep
22+
23+
24+
class Flux2LoraConversionTests(unittest.TestCase):
25+
def test_convert_non_diffusers_flux2_lora_maps_guidance_embedder(self):
26+
state_dict = {
27+
"diffusion_model.img_in.lora_A.weight": torch.randn(2, 2),
28+
"diffusion_model.img_in.lora_B.weight": torch.randn(2, 2),
29+
"diffusion_model.txt_in.lora_A.weight": torch.randn(2, 2),
30+
"diffusion_model.txt_in.lora_B.weight": torch.randn(2, 2),
31+
"diffusion_model.time_in.in_layer.lora_A.weight": torch.randn(2, 2),
32+
"diffusion_model.time_in.in_layer.lora_B.weight": torch.randn(2, 2),
33+
"diffusion_model.time_in.out_layer.lora_A.weight": torch.randn(2, 2),
34+
"diffusion_model.time_in.out_layer.lora_B.weight": torch.randn(2, 2),
35+
"diffusion_model.guidance_in.in_layer.lora_A.weight": torch.randn(2, 2),
36+
"diffusion_model.guidance_in.in_layer.lora_B.weight": torch.randn(2, 2),
37+
"diffusion_model.guidance_in.out_layer.lora_A.weight": torch.randn(2, 2),
38+
"diffusion_model.guidance_in.out_layer.lora_B.weight": torch.randn(2, 2),
39+
}
40+
41+
converted_state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
42+
43+
expected_keys = {
44+
"transformer.x_embedder.lora_A.weight",
45+
"transformer.x_embedder.lora_B.weight",
46+
"transformer.context_embedder.lora_A.weight",
47+
"transformer.context_embedder.lora_B.weight",
48+
"transformer.time_guidance_embed.timestep_embedder.linear_1.lora_A.weight",
49+
"transformer.time_guidance_embed.timestep_embedder.linear_1.lora_B.weight",
50+
"transformer.time_guidance_embed.timestep_embedder.linear_2.lora_A.weight",
51+
"transformer.time_guidance_embed.timestep_embedder.linear_2.lora_B.weight",
52+
"transformer.time_guidance_embed.guidance_embedder.linear_1.lora_A.weight",
53+
"transformer.time_guidance_embed.guidance_embedder.linear_1.lora_B.weight",
54+
"transformer.time_guidance_embed.guidance_embedder.linear_2.lora_A.weight",
55+
"transformer.time_guidance_embed.guidance_embedder.linear_2.lora_B.weight",
56+
}
57+
58+
self.assertEqual(set(converted_state_dict.keys()), expected_keys)
59+
60+
def test_flux2_text_subpipeline_rejects_transformer_lora_loading(self):
61+
text_pipe = Flux2KleinTextEncoderStep().init_pipeline()
62+
63+
with self.assertRaisesRegex(ValueError, "defines a `transformer` component"):
64+
text_pipe.load_lora_weights({})

Comments (0)