Dear authors, thanks for this nice work.
I wonder why the SGS loss is computed from the softmaxed probabilities rather than the raw logits, given that PyTorch's `CrossEntropyLoss` already applies a log-softmax internally.
|
g_loss = loss_fn(m.keep_gate, gate_target) |
|
self.keep_gate, self.print_gate, self.print_idx = gumbel_softmax(channel_choice, dim=1, training=self.training) |
|
self.channel_choice = self.print_gate, self.print_idx |
|
else: |
|
self.channel_choice = None |
|
|
|
return x |
|
|
|
def get_gate(self):
    """Return the most recent gating decision stored by the forward pass.

    Returns `self.channel_choice`, which the forward pass sets to the tuple
    `(print_gate, print_idx)` when gating is active, or `None` otherwise
    (see the assignments visible above this method).
    """
    return self.channel_choice
|
|
|
|
|
def gumbel_softmax(logits, tau=1, hard=False, dim=1, training=True):
    """Straight-through Gumbel-softmax sampling from categorical logits.

    Reimplementation of `torch.nn.functional.gumbel_softmax()` that returns
    the soft probabilities and the argmax index in addition to the
    straight-through one-hot sample.

    Args:
        logits: unnormalized log-probabilities; sampling happens along `dim`.
        tau: softmax temperature; lower values sharpen the soft sample.
        hard: accepted for API compatibility but currently ignored — the
            straight-through (hard) sample is always computed and returned.
        training: accepted for API compatibility but currently ignored —
            Gumbel noise is added in eval mode as well. NOTE(review): a
            previously commented-out draft skipped the noise when not
            training; confirm the intended eval-time behavior.
        dim: dimension along which softmax / sampling is applied.

    Returns:
        Tuple `(y_soft, ret, index)`:
        - y_soft: softmax of the noise-perturbed logits (the soft sample).
        - ret: straight-through sample — numerically equal to the one-hot
          hard sample, but its gradient flows through `y_soft`.
        - index: argmax of `y_soft` along `dim` (keeps `dim` with size 1).
    """
    # Gumbel(0, 1) noise via -log(Exponential(1)).
    gumbels = -torch.empty_like(
        logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    gumbels = (logits + gumbels) / tau  # ~ Gumbel(logits, tau)
    y_soft = gumbels.softmax(dim)

    # Hard one-hot sample; computed under no_grad so the non-differentiable
    # argmax/scatter ops stay out of the autograd graph.
    with torch.no_grad():
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(
            logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)

    # Straight-through estimator: forward value equals y_hard, while the
    # gradient is taken from y_soft.
    ret = y_hard - y_soft.detach() + y_soft
    return y_soft, ret, index
Dear authors, thanks for this nice work.
I wonder why the SGS loss is computed from the softmaxed probabilities rather than the raw logits, given that PyTorch's `CrossEntropyLoss` already applies a log-softmax internally.
DS-Net/dyn_slim/apis/train_slim_gate.py
Line 98 in 15cd303
DS-Net/dyn_slim/models/dyn_slim_blocks.py
Lines 324 to 355 in 15cd303