Commit efbe2a5
committed

Use torch.gather instead of generic int indexing for cross_entropy example

stack-info: PR: #2058, branch: AmesingFlank/stack/25

1 parent: 584027c

1 file changed: examples/cross_entropy.py (3 additions, 11 deletions)
@@ -48,24 +48,16 @@ def cross_entropy(
     n, v = logits.shape
     losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)
 
-    # Flatten logits once at the beginning
-    logits_flat = logits.view(-1)
-
     for tile_n in hl.tile(n):
         # Get data for this tile
         labels_tile = labels[tile_n]  # [tile_size]
-        base_indices_tile = tile_n.index * v  # [tile_size]
-
-        # Compute the actual flat indices by adding the label offset
-        flat_indices = base_indices_tile + labels_tile
-
-        # Load the logits at the target indices
-        logits_at_target = hl.load(logits_flat, [flat_indices])
 
         # Compute log_softmax for numerical stability
         # Load the full rows for this tile
         logits_rows = logits[tile_n, :]  # [tile_size, V]
 
+        logits_at_target = logits_rows.gather(1, labels_tile.unsqueeze(1)).squeeze(1)
+
         # Compute log-sum-exp
         max_logits = torch.amax(logits_rows, dim=-1, keepdim=True)
         shifted = logits_rows - max_logits
@@ -89,7 +81,7 @@ def main() -> None:
     """
     Main entry point that runs the cross entropy kernel verification.
     """
-    batch_size, seq_len, vocab_size = 8, 2048, 131072
+    batch_size, seq_len, vocab_size = 8, 2048, 2048
     n = batch_size * seq_len
     logits = torch.randn(n, vocab_size, device=DEVICE, dtype=torch.float32)
     labels = torch.randint(0, vocab_size, (n,), device=DEVICE, dtype=LONG_INT_TYPE)
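
For reference, here is a minimal standalone sketch of the indexing change in plain PyTorch (no Helion hl.tile loop; the tensor names mirror the diff, but the snippet itself is illustrative, not part of the commit):

import torch

# Illustrative check (plain PyTorch, no Helion tiling): both approaches
# select logits[i, labels[i]] for every row i.
n, v = 4, 8
logits = torch.randn(n, v)
labels = torch.randint(0, v, (n,))

# Old approach from the deleted lines: flatten, then build flat offsets by hand.
logits_flat = logits.view(-1)
logits_at_target_old = logits_flat[torch.arange(n) * v + labels]

# New approach from the added line: torch.gather along dim 1 using the labels.
logits_at_target_new = logits.gather(1, labels.unsqueeze(1)).squeeze(1)

assert torch.equal(logits_at_target_old, logits_at_target_new)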
