Add aten.view.dtype dispatch to optim state subclasses #4216
winglian wants to merge 3 commits into pytorch:main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4216
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 3c5acdd with merge base 6807454.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@winglian Do you have a minimal reproducible example? I'm not sure why this is needed.
@claude review please
Claude finished @andrewor14's task in 2m 16s (Code Review: PR #4216)
This PR makes three related changes:
@claude can you suggest a unit test?
Claude finished @andrewor14's task in 2m 10s (Suggesting unit tests for the aten.view.dtype dispatch)
I don't think implementing aten.view.dtype is right. It means "reinterpret the underlying bytes", which doesn't make sense for scaled low-bit formats. It would be great to have a reproducible example so we can see whether implementing this op is actually necessary. Then we can add it to the FSDP test as well.
Background: PR #3934 added an "appearance dtype" to OptimState8bit (e.g. a bf16-appearing wrapper over uint8 codes), but did not implement aten.view.dtype.
Symptom: a torch.compile'd 8-bit Adam optimizer (e.g. adamw_torch_8bit in HuggingFace transformers) crashes on PyTorch 2.9.1+.
Fix: register aten.view.dtype for all three optim state subclasses, updating only the appearance dtype (consistent with the existing aten._to_copy.default behavior).
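To illustrate the fix's approach, here is a minimal sketch of a wrapper tensor subclass that handles aten.view.dtype by changing only its appearance dtype while leaving the stored uint8 codes untouched. `AppearanceDtypeTensor` is a hypothetical stand-in for OptimState8bit, not the actual torchao implementation, which carries additional state (scale, qmap, signedness):

```python
import torch

aten = torch.ops.aten

class AppearanceDtypeTensor(torch.Tensor):
    """Hypothetical stand-in for OptimState8bit: uint8 codes that
    *appear* to callers as a float dtype (the "appearance dtype")."""

    @staticmethod
    def __new__(cls, codes, appearance_dtype):
        # The wrapper reports the appearance dtype; the real data
        # stays in `codes` as uint8.
        return torch.Tensor._make_wrapper_subclass(
            cls, codes.shape, dtype=appearance_dtype, device=codes.device
        )

    def __init__(self, codes, appearance_dtype):
        self.codes = codes

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func is aten.view.dtype:
            x, dtype = args
            # Only the appearance dtype changes; the underlying uint8
            # codes are not reinterpreted (mirrors aten._to_copy.default,
            # which also updates just the appearance dtype).
            return cls(x.codes, dtype)
        if func is aten.detach.default:
            (x,) = args
            return cls(x.codes, x.dtype)
        raise NotImplementedError(f"{func} is not supported")

codes = torch.randint(0, 256, (4,), dtype=torch.uint8)
t = AppearanceDtypeTensor(codes, torch.bfloat16)
v = t.view(torch.float32)  # dispatches to aten.view.dtype
print(v.dtype)             # appearance dtype changed
print(torch.equal(v.codes, codes))  # codes untouched
```

Note that this deliberately does not reinterpret bytes the way aten.view.dtype does for plain tensors, which is exactly the semantic concern raised in the review above.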