@@ -48,24 +48,16 @@ def cross_entropy(
     n, v = logits.shape
     losses = torch.zeros([n], dtype=logits.dtype, device=logits.device)
 
-    # Flatten logits once at the beginning
-    logits_flat = logits.view(-1)
-
     for tile_n in hl.tile(n):
         # Get data for this tile
         labels_tile = labels[tile_n]  # [tile_size]
-        base_indices_tile = tile_n.index * v  # [tile_size]
-
-        # Compute the actual flat indices by adding the label offset
-        flat_indices = base_indices_tile + labels_tile
-
-        # Load the logits at the target indices
-        logits_at_target = hl.load(logits_flat, [flat_indices])
 
         # Compute log_softmax for numerical stability
         # Load the full rows for this tile
         logits_rows = logits[tile_n, :]  # [tile_size, V]
 
+        logits_at_target = logits_rows.gather(1, labels_tile.unsqueeze(1)).squeeze(1)
+
         # Compute log-sum-exp
         max_logits = torch.amax(logits_rows, dim=-1, keepdim=True)
         shifted = logits_rows - max_logits
@@ -89,7 +81,7 @@ def main() -> None:
     """
     Main entry point that runs the cross entropy kernel verification.
     """
-    batch_size, seq_len, vocab_size = 8, 2048, 131072
+    batch_size, seq_len, vocab_size = 8, 2048, 2048
     n = batch_size * seq_len
     logits = torch.randn(n, vocab_size, device=DEVICE, dtype=torch.float32)
     labels = torch.randint(0, vocab_size, (n,), device=DEVICE, dtype=LONG_INT_TYPE)
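
For context, here is a minimal plain-PyTorch sketch of the computation the kernel hunk above performs, per row i: loss_i = logsumexp(row_i) - row_i[label_i]. It is not the Helion kernel itself (no tiling; cross_entropy_reference and the shapes are illustrative, not from this commit), but it uses the same gather-based target selection and max-shifted log-sum-exp as the new kernel body, and cross-checks against torch.nn.functional.cross_entropy:

import torch

def cross_entropy_reference(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Same trick as the new kernel body: pull each row's target logit with
    # gather instead of flattening and computing manual offsets.
    logits_at_target = logits.gather(1, labels.unsqueeze(1)).squeeze(1)  # [N]

    # Numerically stable log-sum-exp: shift by the row max before exponentiating.
    max_logits = torch.amax(logits, dim=-1, keepdim=True)  # [N, 1]
    shifted = logits - max_logits  # [N, V]
    logsumexp = max_logits.squeeze(-1) + torch.log(torch.exp(shifted).sum(dim=-1))  # [N]

    # Per-sample loss: logsumexp over the row minus the logit at the target class.
    return logsumexp - logits_at_target

n, v = 16, 2048
logits = torch.randn(n, v)
labels = torch.randint(0, v, (n,))
torch.testing.assert_close(
    cross_entropy_reference(logits, labels),
    torch.nn.functional.cross_entropy(logits, labels, reduction="none"),
)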