
Use torch.gather instead of generic int indexing for cross_entropy example #2058

Draft
AmesingFlank wants to merge 1 commit into main from AmesingFlank/stack/25

Conversation


@AmesingFlank AmesingFlank commented Apr 20, 2026

Stacked PRs:


Use torch.gather instead of generic integer indexing in the cross_entropy example. The benefit is that the gather operates on a tile, which is friendlier than an integer-indexing expression over the entire tensor. Paired with #2060, this unblocks the cross_entropy example on TPUs.
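
For readers unfamiliar with the two formulations, here is a minimal plain-PyTorch sketch of the idea. It is not the kernel code from this PR: the shapes, variable names, and the flattened-index formulation are assumptions chosen for illustration. It only shows that picking each row's target logit with torch.gather on a 2D block gives the same values as integer indexing into the flattened tensor.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, v = 8, 16                      # hypothetical batch size and vocabulary size
logits = torch.randn(n, v)
labels = torch.randint(0, v, (n,))

# Generic integer indexing: build flat offsets and index the whole tensor.
flat_idx = torch.arange(n) * v + labels
picked_by_index = logits.view(-1)[flat_idx]

# Gather: select one column per row from a 2D block of logits.
picked_by_gather = torch.gather(logits, 1, labels.unsqueeze(1)).squeeze(1)

# Cross entropy per row is logsumexp(row) - row[label]; average over rows.
lse = torch.logsumexp(logits, dim=1)
loss_index = (lse - picked_by_index).mean()
loss_gather = (lse - picked_by_gather).mean()

# Reference value from PyTorch's built-in implementation.
reference = F.cross_entropy(logits, labels)
```

In an actual tiled kernel the gather would apply to the block of rows and columns currently loaded, rather than to the full tensor as in this sketch.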

AmesingFlank added a commit that referenced this pull request Apr 20, 2026
…ample

stack-info: PR: #2058, branch: AmesingFlank/stack/25
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/25 branch from d29d610 to 341a5ff Compare April 20, 2026 19:17
@meta-cla meta-cla Bot added the CLA Signed label Apr 20, 2026
@AmesingFlank AmesingFlank marked this pull request as draft April 20, 2026 19:23
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/25 branch from 341a5ff to bcb53ba Compare April 20, 2026 19:23
@AmesingFlank AmesingFlank marked this pull request as ready for review April 20, 2026 19:23
@AmesingFlank AmesingFlank marked this pull request as draft April 20, 2026 19:25
@AmesingFlank AmesingFlank marked this pull request as ready for review April 20, 2026 19:36
AmesingFlank added a commit that referenced this pull request Apr 20, 2026
…ample

stack-info: PR: #2058, branch: AmesingFlank/stack/25
@AmesingFlank AmesingFlank marked this pull request as draft April 20, 2026 19:38
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/25 branch from bcb53ba to c160d88 Compare April 20, 2026 19:38
@AmesingFlank AmesingFlank marked this pull request as ready for review April 20, 2026 19:38
AmesingFlank added a commit that referenced this pull request Apr 20, 2026
…ample

stack-info: PR: #2058, branch: AmesingFlank/stack/25
@AmesingFlank AmesingFlank marked this pull request as draft April 20, 2026 19:39
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/25 branch from c160d88 to efbe2a5 Compare April 20, 2026 19:39
@AmesingFlank AmesingFlank marked this pull request as ready for review April 20, 2026 19:39
…ample

stack-info: PR: #2058, branch: AmesingFlank/stack/25
@AmesingFlank AmesingFlank marked this pull request as draft April 20, 2026 19:42
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/25 branch from efbe2a5 to c95b79f Compare April 20, 2026 19:42
@AmesingFlank AmesingFlank marked this pull request as ready for review April 20, 2026 19:42
@AmesingFlank AmesingFlank marked this pull request as draft April 20, 2026 21:39
@AmesingFlank AmesingFlank marked this pull request as ready for review April 20, 2026 21:39
@AmesingFlank AmesingFlank marked this pull request as draft April 20, 2026 22:00
@AmesingFlank AmesingFlank marked this pull request as ready for review April 20, 2026 22:00
@AmesingFlank AmesingFlank requested a review from norx1991 April 20, 2026 22:25
@norx1991

Is there a way to know how this affects the GPU performance of this kernel? Or is the kernel code (e.g., Triton) the same?


jansel commented Apr 21, 2026

  1. Let's benchmark this to make sure it doesn't regress GPU performance.
  2. Is it possible to make both versions work? We should at least add a test of the prior version so we can ensure it doesn't break on GPU in the future.
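
One way both requests could be approached, sketched in plain PyTorch under stated assumptions: the two helper functions below are hypothetical stand-ins for the prior (integer indexing) and new (gather) versions, not the example kernels themselves, and the problem size is arbitrary. The sketch checks that the two agree, then times each with torch.utils.benchmark.Timer.

```python
import torch
from torch.utils import benchmark


def pick_by_index(logits, labels):
    # Hypothetical stand-in for the prior version: flat integer indexing.
    n, v = logits.shape
    flat_idx = torch.arange(n, device=logits.device) * v + labels
    return logits.reshape(-1)[flat_idx]


def pick_by_gather(logits, labels):
    # Hypothetical stand-in for the new version: gather along the class dim.
    return torch.gather(logits, 1, labels.unsqueeze(1)).squeeze(1)


device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(256, 1024, device=device)
labels = torch.randint(0, 1024, (256,), device=device)

# Regression check: the prior formulation must keep matching the new one.
torch.testing.assert_close(
    pick_by_index(logits, labels), pick_by_gather(logits, labels)
)

# Timing: median seconds per call for each formulation.
timings = {}
for name, fn in [("index", pick_by_index), ("gather", pick_by_gather)]:
    timer = benchmark.Timer(
        stmt="fn(logits, labels)",
        globals={"fn": fn, "logits": logits, "labels": labels},
    )
    timings[name] = timer.timeit(20).median
```

Timings from a sketch like this only compare eager PyTorch ops; a meaningful answer for this PR would need the same comparison run on the compiled kernels.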


Labels

CLA Signed This label is managed by the Meta Open Source bot.

3 participants