Commit d4b5142: overlap and add
1 parent bfc4191

2 files changed: 64 additions & 24 deletions


compiam/separation/music_source_separation/mixer_model/__init__.py

Lines changed: 32 additions & 12 deletions
@@ -69,6 +69,7 @@ def __init__(
         self.load_model(self.model_path)

         self.chunk_size = self.model.chunk_size
+        self.overlap = 0.25

     def forward(self, x):
         """Forward pass of the mixer model"""
@@ -156,27 +157,46 @@ def separate(
                 and the model is trained on mono audio."
             )

-        # audio has shape B, 1, N
+        initial_length = audio.shape[-1]
         audio = audio.reshape(-1)
-        predictions = []
-        pad_length = self.chunk_size - (audio.shape[-1] % self.chunk_size)
+        pad_length = (self.chunk_size - (audio.shape[-1] % self.chunk_size)) % self.chunk_size
         audio = torch.nn.functional.pad(audio, (0, pad_length))

-        for i in range(0, audio.shape[-1], self.chunk_size):
-            audio_chunk = audio[i : i + self.chunk_size].reshape(
-                1, 1, -1
-            )  # TODO Batching
-            predictions.append(self.forward(audio_chunk))
+        chunk_size = audio.shape[-1] // ((audio.shape[-1] + self.chunk_size - 1) // self.chunk_size)
+        hop_size = int(chunk_size * (1 - self.overlap))
+        num_chunks = (audio.shape[-1] - chunk_size) // hop_size + 1

-        result = torch.cat(predictions, dim=-1)
-        result = result[:, :, :-pad_length]
+        window = torch.hann_window(chunk_size)
+        out = torch.zeros((2, audio.shape[-1]))  # (Channels=2, Time)
+        weight_sum = torch.zeros(audio.shape[-1])  # Weight accumulation for normalization
+
+        # Process chunks
+        for i in range(num_chunks):
+            start = i * hop_size
+            end = start + chunk_size
+
+            # Extract chunk (reshape for model input)
+            audio_chunk = audio[start:end].reshape(1, 1, -1)
+
+            # Apply model separation (assumes 2-channel output)
+            separated_chunk = self.forward(audio_chunk).reshape(2, -1)  # (2, chunk_size)
+
+            # Apply windowing
+            separated_chunk *= window  # Smooth transition
+
+            # Overlap-add to output
+            out[:, start:end] += separated_chunk
+            weight_sum[start:end] += window  # Accumulate weights
+
+        out /= weight_sum.unsqueeze(0).clamp(min=1e-8)  # Avoid division by zero
+        out = out[..., :initial_length].unsqueeze(0)  # (1, 2, N)

         vocal_separation = torchaudio.transforms.Resample(
             orig_freq=self.sample_rate, new_freq=input_sr
-        )(result[:, 0, :])
+        )(out[:, 0, :])
         violin_separation = torchaudio.transforms.Resample(
             orig_freq=self.sample_rate, new_freq=input_sr
-        )(result[:, 1, :])
+        )(out[:, 1, :])

         vocal_separation = vocal_separation.detach().cpu().numpy().reshape(-1)
         violin_separation = violin_separation.detach().cpu().numpy().reshape(-1)
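For reference, here is a minimal, self-contained sketch of the overlap-add scheme this diff introduces, with a pass-through function standing in for `self.forward`. This is an illustration, not code from the repo: with Hann windowing and weight-sum normalization, the chunked round trip reconstructs the input wherever the accumulated window weight is nonzero, which is what removes audible seams at chunk boundaries.

import torch

def overlap_add_roundtrip(audio, chunk_size=1024, overlap=0.25):
    """Windowed overlap-add with an identity 'model' in place of self.forward."""
    hop_size = int(chunk_size * (1 - overlap))
    n = audio.shape[-1]
    # Assumes (n - chunk_size) is a multiple of hop_size so chunks tile the signal
    num_chunks = (n - chunk_size) // hop_size + 1

    window = torch.hann_window(chunk_size)
    out = torch.zeros(n)
    weight_sum = torch.zeros(n)

    for i in range(num_chunks):
        start, end = i * hop_size, i * hop_size + chunk_size
        separated_chunk = audio[start:end]           # identity stands in for the model
        out[start:end] += separated_chunk * window   # windowed overlap-add
        weight_sum[start:end] += window              # accumulate window weights

    return out / weight_sum.clamp(min=1e-8)          # normalize; clamp avoids division by zero

x = torch.randn(4096)  # (4096 - 1024) is a multiple of hop_size = 768
y = overlap_add_roundtrip(x)
print(torch.allclose(x[1:], y[1:], atol=1e-5))  # True; sample 0 has zero window weight

The normalization step is what makes the choice of window and overlap factor flexible: the same window accumulated into weight_sum is divided back out, so it cancels exactly wherever at least one chunk contributes.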

compiam/separation/singing_voice_extraction/convtdf_vocal_finetune.py

Lines changed: 32 additions & 12 deletions
@@ -68,6 +68,7 @@ def __init__(
         self.load_model(self.model_path)

         self.chunk_size = self.model.chunk_size
+        self.overlap = 0.25

     def forward(self, x):
         """Forward pass of the mixer model"""
@@ -155,25 +156,44 @@ def separate(
                 and the model is trained on mono audio."
             )

-        # audio has shape B, 1, N
+        initial_length = audio.shape[-1]
         audio = audio.reshape(-1)
-        predictions = []
-        pad_length = self.chunk_size - (audio.shape[-1] % self.chunk_size)
+        pad_length = (self.chunk_size - (audio.shape[-1] % self.chunk_size)) % self.chunk_size
         audio = torch.nn.functional.pad(audio, (0, pad_length))

-        for i in range(0, audio.shape[-1], self.chunk_size):
-            audio_chunk = audio[i : i + self.chunk_size].reshape(
-                1, 1, -1
-            )  # TODO Batching
-            predictions.append(self.forward(audio_chunk))
+        chunk_size = audio.shape[-1] // ((audio.shape[-1] + self.chunk_size - 1) // self.chunk_size)
+        hop_size = int(chunk_size * (1 - self.overlap))
+        num_chunks = (audio.shape[-1] - chunk_size) // hop_size + 1

-        result = torch.cat(predictions, dim=-1)
-        result = result[:, :, :-pad_length]
+        window = torch.hann_window(chunk_size)
+        out = torch.zeros(audio.shape[-1])  # (Time,)
+        weight_sum = torch.zeros(audio.shape[-1])  # Weight accumulation for normalization
+
+        # Process chunks
+        for i in range(num_chunks):
+            start = i * hop_size
+            end = start + chunk_size
+
+            # Extract chunk (reshape for model input)
+            audio_chunk = audio[start:end].reshape(1, 1, -1)
+
+            # Apply model separation (single-channel output)
+            separated_chunk = self.forward(audio_chunk).reshape(-1)  # (chunk_size,)
+
+            # Apply windowing
+            separated_chunk *= window  # Smooth transition
+
+            # Overlap-add to output
+            out[start:end] += separated_chunk
+            weight_sum[start:end] += window  # Accumulate weights
+
+        out /= weight_sum.clamp(min=1e-8)  # Avoid division by zero
+        out = out[:initial_length].unsqueeze(0)  # (1, N)

         vocal_separation = torchaudio.transforms.Resample(
             orig_freq=self.sample_rate, new_freq=input_sr
-        )(result)
-
+        )(out)
+
         return vocal_separation.detach().cpu().numpy().reshape(-1)

     def download_model(self, model_path=None, force_overwrite=False):
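One detail worth noting in both files is the padding fix. The old expression `self.chunk_size - (audio.shape[-1] % self.chunk_size)` pads a full extra chunk when the length is already a multiple of `chunk_size`; the new version, taken modulo `self.chunk_size`, pads zero in that case. That in turn forces the change at the end: with a possibly-zero `pad_length`, trimming via `[:-pad_length]` would slice away everything, so the diff trims to the saved `initial_length` instead. A quick illustrative check (plain Python, not from the repo):

chunk_size = 1024

for n in (4096, 5000):  # 4096 is already a multiple of chunk_size
    old_pad = chunk_size - (n % chunk_size)
    new_pad = (chunk_size - (n % chunk_size)) % chunk_size
    print(n, old_pad, new_pad)
# 4096 1024 0    <- old formula pads a full redundant chunk
# 5000 120 120   <- both agree when padding is actually needed

# Why trimming switched to initial_length: a -0 slice is empty.
x = list(range(8))
print(x[:-0])  # [] -- so result[..., :-pad_length] breaks when pad_length == 0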
