
Add aten.view.dtype dispatch to optim state subclasses#4216

Open
winglian wants to merge 3 commits into pytorch:main from winglian:aten-dtype-optim

Conversation


@winglian winglian commented Apr 1, 2026

PR #3934 added an "appearance dtype" to OptimState8bit (e.g. bf16 wrapper over uint8 codes), but didn't implement aten.view.dtype.

A torch.compile'd 8-bit Adam optimizer (e.g. adamw_torch_8bit in HuggingFace transformers) crashes on PyTorch 2.9.1+.

Fix: register aten.view.dtype for all three optim state subclasses, updating only the appearance dtype (consistent with the existing aten._to_copy.default behavior).


pytorch-bot bot commented Apr 1, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4216

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 3c5acdd with merge base 6807454:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 1, 2026
@gau-nernst
Collaborator

@winglian Do you have a minimal reproducible example? I'm not sure why this is needed.

Also, aten.view.dtype doesn't make sense to me for a low-bit tensor subclass...

@andrewor14
Contributor

@claude review please

@andrewor14 andrewor14 added the module: training quantize_ api training flow label Apr 2, 2026
@claude

claude bot commented Apr 2, 2026

Claude finished @andrewor14's task in 2m 16s


Code Review: PR #4216


This PR makes three related changes:

  1. Registers aten.view.dtype for OptimState4bit, OptimState8bit, and OptimStateFp8
  2. Adds .clone() calls in aten._to_copy.default for all three subclasses
  3. Propagates dtype=x.dtype through aten.view.default which was previously dropped

The aten.view.dtype dispatch

The implementation is straightforward and consistent across all three subclasses. It only updates the appearance dtype, which is the correct behavior — matching what aten._to_copy.default does for dtype changes. This makes sense for torch.compile's fake tensor infrastructure, which may use view.dtype to create fake tensors with the correct appearance dtype.

That said, I agree with @gau-nernst's request for a minimal reproducible example. It would strengthen the PR to include the specific traceback or a small script that demonstrates the crash on PyTorch 2.9.1+. This helps reviewers confirm the root cause and ensures the fix is actually exercised.

The .clone() addition in _to_copy

The addition of .clone() after .to(device=device) in _to_copy is a separate behavioral change from what the PR title suggests. This ensures the result has independent storage rather than being a view of the original when device is unchanged. The comment explains the motivation (torch.compile fake-tensor metadata checks), but again a concrete example would help.

One concern: .clone() after .to(device=device) when device is a different device will do an unnecessary copy — .to(device=...) already returns a new tensor when the device changes. This is a minor inefficiency, not a correctness issue. A possible improvement would be to only clone when the device is the same:

codes = args[0].codes.to(device=device)
# .to(device=...) returns self when the device is unchanged, so only
# clone when the result still aliases the original storage
if device is None or codes.data_ptr() == args[0].codes.data_ptr():
    codes = codes.clone()

But this is a micro-optimization and probably not worth the complexity for optimizer state tensors.

Bug: Missing dtype propagation in pre-existing ops

The PR fixes dtype propagation in view.default (good catch), but there are pre-existing call sites that still don't propagate dtype:

aten.slice.Tensor — all three subclasses:

  • subclass_8bit.py:254: return OptimState8bit(x.codes[start:end], ..., x.signed) — missing dtype=x.dtype
  • subclass_4bit.py:291: return OptimState4bit(codes, scale, ..., x.signed, shape) — missing dtype=x.dtype
  • subclass_fp8.py:223: return OptimStateFp8(x.codes[start:end], ...) — missing dtype=x.dtype

c10d ops — all three subclasses:

  • subclass_8bit.py:210: return OptimState8bit(..., x.signed) — missing dtype=x.dtype
  • subclass_4bit.py:244: return OptimState4bit(..., x.signed, shape) — missing dtype=x.dtype
  • subclass_fp8.py:185: return OptimStateFp8(...) — missing dtype=x.dtype

These will silently reset the appearance dtype to None whenever these ops are used. Since this PR already fixes dtype propagation in view.default, it would be good to fix these too for completeness.
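The fix pattern is the same at every call site: thread the wrapper's appearance dtype through the reconstructed subclass. A minimal sketch with a hypothetical stand-in class (FakeOptimState is illustrative only, not torchao code):

```python
class FakeOptimState:
    """Hypothetical stand-in for the constructors named above."""

    def __init__(self, codes, signed=True, dtype=None):
        self.codes = codes
        self.signed = signed
        self.dtype = dtype  # appearance dtype; None means "unset"


def slice_before_fix(x, start, end):
    # pre-existing behavior: dtype is dropped, appearance dtype resets to None
    return FakeOptimState(x.codes[start:end], x.signed)


def slice_after_fix(x, start, end):
    # proposed fix: thread dtype=x.dtype through the reconstruction
    return FakeOptimState(x.codes[start:end], x.signed, dtype=x.dtype)


x = FakeOptimState(list(range(8)), dtype="bfloat16")
print(slice_before_fix(x, 0, 4).dtype)  # None
print(slice_after_fix(x, 0, 4).dtype)   # bfloat16
```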

Tests

There are no tests for the new aten.view.dtype dispatch or for the _to_copy .clone() change. Even a basic unit test would help prevent regressions:

def test_view_dtype():
    state = OptimState8bit.zeros((64,), dtype=torch.bfloat16)
    viewed = torch.ops.aten.view.dtype(state, torch.float32)
    assert viewed.dtype == torch.float32
    assert viewed.shape == state.shape

Summary

The core aten.view.dtype change is reasonable and consistent with existing patterns. The .clone() addition in _to_copy is a separate fix that should ideally be in its own commit with its own justification. The main gap is the missing dtype propagation in slice.Tensor and c10d ops (pre-existing but worth fixing in this PR), and the lack of a minimal reproducer or tests.


@andrewor14
Contributor

@claude can you suggest a unit test

@claude

claude bot commented Apr 2, 2026

Claude finished @andrewor14's task in 2m 10s


Suggesting unit tests for aten.view.dtype dispatch


Here's a suggested unit test that fits the existing patterns in test/test_low_bit_optim.py. It would go in the TestQuantize class alongside the existing test_subclass_appearance_dtype test:

@parametrize("subclass", [OptimState4bit, OptimState8bit, OptimStateFp8])
@parametrize("device", _DEVICES)
def test_view_dtype(self, subclass, device):
    """Test aten.view.dtype dispatch used by torch.compile fake tensor creation."""
    if subclass == OptimStateFp8:
        if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
            pytest.skip("FP8 CUDA requires compute capability >= 8.9")

    shape = (1024,)
    tensor = subclass.zeros(shape, device=device, dtype=torch.bfloat16)
    self.assertEqual(tensor.dtype, torch.bfloat16)

    # aten.view.dtype should change only the appearance dtype
    viewed = torch.ops.aten.view.dtype(tensor, torch.float32)
    self.assertIsInstance(viewed, subclass)
    self.assertEqual(viewed.dtype, torch.float32)
    self.assertEqual(viewed.shape, tensor.shape)

    # internal data should be shared (view, not copy)
    self.assertTrue(viewed.codes.data_ptr() == tensor.codes.data_ptr())
    self.assertTrue(viewed.scale.data_ptr() == tensor.scale.data_ptr())

    # dequantized values should be the same
    torch.testing.assert_close(
        viewed.dequantize(),
        tensor.dequantize(),
    )

This test covers:

  1. Dtype change: verifies the appearance dtype is updated to the target dtype.
  2. Type preservation: verifies the result is still the same subclass.
  3. Shape preservation: verifies the shape is unchanged.
  4. View semantics: verifies internal storage is shared (not copied), which matches the implementation — the dispatch creates a new subclass wrapping the same codes and scale tensors.
  5. Value preservation: verifies the underlying quantized data is unchanged.

It follows the existing conventions: parametrized across all three subclasses and devices, with the FP8 compute capability skip guard.



@gau-nernst
Collaborator

I don't think implementing aten.view.dtype is right. It means "reinterpret the underlying bytes as another dtype", which doesn't make sense for a scaled low-bit format. It would be great to have a reproducible example to see whether implementing this op is actually necessary. Then we can add it to the FSDP test as well.
