Softmax twice for SGS loss? #18

@Yu-Zhewen

Description

Dear authors, thanks for this nice work.

I wonder why the SGS loss is computed from the softmaxed gate values rather than the raw logits, given that PyTorch's CrossEntropyLoss already applies a (log-)softmax internally.

```python
g_loss = loss_fn(m.keep_gate, gate_target)
```
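For context, a minimal standalone sketch (not code from this repo) of why this matters: `CrossEntropyLoss` expects raw logits because it applies `log_softmax` internally, so feeding it values that have already been through a softmax applies softmax twice:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 5)           # raw network outputs
target = torch.randint(0, 5, (4,))   # class labels

loss_fn = torch.nn.CrossEntropyLoss()

# Intended usage: pass raw logits; log_softmax happens inside the loss.
loss_logits = loss_fn(logits, target)

# Passing softmaxed values squashes the inputs into [0, 1], so the
# internal softmax is applied a second time and the loss surface flattens.
loss_double = loss_fn(logits.softmax(dim=1), target)

print(loss_logits.item(), loss_double.item())  # the two losses differ
```

Because the double-softmaxed inputs all lie in [0, 1], the resulting loss stays close to the uniform-prediction value regardless of how confident the original logits were.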

```python
            self.keep_gate, self.print_gate, self.print_idx = gumbel_softmax(
                channel_choice, dim=1, training=self.training)
            self.channel_choice = self.print_gate, self.print_idx
        else:
            self.channel_choice = None
        return x

    def get_gate(self):
        return self.channel_choice


def gumbel_softmax(logits, tau=1, hard=False, dim=1, training=True):
    """ See `torch.nn.functional.gumbel_softmax()` """
    # if training:
    #     gumbels = -torch.empty_like(logits,
    #         memory_format=torch.legacy_contiguous_format).exponential_().log()  # ~Gumbel(0,1)
    #     gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    # else:
    #     gumbels = logits
    # y_soft = gumbels.softmax(dim)
    gumbels = -torch.empty_like(
        logits, memory_format=torch.legacy_contiguous_format
    ).exponential_().log()              # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)
    with torch.no_grad():
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(
            logits, memory_format=torch.legacy_contiguous_format
        ).scatter_(dim, index, 1.0)
        # **test**
        # index = 0
        # y_hard = torch.Tensor([1, 0, 0, 0]).repeat(logits.shape[0], 1).cuda()
    ret = y_hard - y_soft.detach() + y_soft
    return y_soft, ret, index
```
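As an aside on the quoted code, the line `ret = y_hard - y_soft.detach() + y_soft` is the straight-through estimator. A small standalone sketch (variable names here are illustrative, not taken from the repo):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(2, 4, requires_grad=True)
y_soft = logits.softmax(dim=1)
index = y_soft.max(dim=1, keepdim=True)[1]
y_hard = torch.zeros_like(logits).scatter_(1, index, 1.0)

# Forward value equals the one-hot y_hard; the gradient path is y_soft.
ret = y_hard - y_soft.detach() + y_soft

(ret * torch.arange(4.0)).sum().backward()  # gradients still reach logits
```

Since `y_hard` and `y_soft.detach()` carry no gradient, the backward pass sees only `y_soft`, while the forward pass outputs a discrete one-hot choice.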
