@@ -69,6 +69,7 @@ def __init__(
6969 self .load_model (self .model_path )
7070
7171 self .chunk_size = self .model .chunk_size
72+ self .overlap = 0.25
7273
7374 def forward (self , x ):
7475 """Forward pass of the mixer model"""
@@ -156,27 +157,46 @@ def separate(
156157 and the model is trained on mono audio."
157158 )
158159
159- # audio has shape B, 1, N
160+ initial_length = audio . shape [ - 1 ]
160161 audio = audio .reshape (- 1 )
161- predictions = []
162- pad_length = self .chunk_size - (audio .shape [- 1 ] % self .chunk_size )
162+ pad_length = (self .chunk_size - (audio .shape [- 1 ] % self .chunk_size )) % self .chunk_size
163163 audio = torch .nn .functional .pad (audio , (0 , pad_length ))
164164
165- for i in range (0 , audio .shape [- 1 ], self .chunk_size ):
166- audio_chunk = audio [i : i + self .chunk_size ].reshape (
167- 1 , 1 , - 1
168- ) # TODO Batching
169- predictions .append (self .forward (audio_chunk ))
165+ chunk_size = audio .shape [- 1 ] // ((audio .shape [- 1 ] + self .chunk_size - 1 ) // self .chunk_size )
166+ hop_size = int (chunk_size * (1 - self .overlap ))
167+ num_chunks = (audio .shape [- 1 ] - chunk_size ) // hop_size + 1
170168
171- result = torch .cat (predictions , dim = - 1 )
172- result = result [:, :, :- pad_length ]
169+ window = torch .hann_window (chunk_size )
170+ out = torch .zeros ((2 , audio .shape [- 1 ])) # (Channels=2, Time)
171+ weight_sum = torch .zeros (audio .shape [- 1 ]) # Weight accumulation for normalization
172+
173+ # Process chunks
174+ for i in range (num_chunks ):
175+ start = i * hop_size
176+ end = start + chunk_size
177+
178+ # Extract chunk (reshape for model input)
179+ audio_chunk = audio [start :end ].reshape (1 , 1 , - 1 )
180+
181+ # Apply model separation (assumes 2-channel output)
182+ separated_chunk = self .forward (audio_chunk ).reshape (2 , - 1 ) # (2, chunk_size)
183+
184+ # Apply windowing
185+ separated_chunk *= window # Smooth transition
186+
187+ # Overlap-Add to output
188+ out [:, start :end ] += separated_chunk
189+ weight_sum [start :end ] += window # Accumulate weights
190+
191+ out /= weight_sum .unsqueeze (0 ).clamp (min = 1e-8 ) # Avoid division by zero
192+ out = out [..., :initial_length ].unsqueeze (0 ) # (1, 2, N)
173193
174194 vocal_separation = torchaudio .transforms .Resample (
175195 orig_freq = self .sample_rate , new_freq = input_sr
176- )(result [:, 0 , :])
196+ )(out [:, 0 , :])
177197 violin_separation = torchaudio .transforms .Resample (
178198 orig_freq = self .sample_rate , new_freq = input_sr
179- )(result [:, 1 , :])
199+ )(out [:, 1 , :])
180200
181201 vocal_separation = vocal_separation .detach ().cpu ().numpy ().reshape (- 1 )
182202 violin_separation = violin_separation .detach ().cpu ().numpy ().reshape (- 1 )
0 commit comments