upgrade torchao to 0.17.0 #3569
Conversation
Important: Review skipped. Auto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configuration.
📝 Walkthrough
This PR updates pinned dependencies in requirements.txt, bumping torchao to 0.17.0 and mistral-common to 1.11.0.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Actionable comments posted: 1
🧹 Nitpick comments (1)
src/axolotl/utils/quantization.py (1)
103-106: Consider adding test coverage for the `weight_granularity` parameter.
The new code path that sets `weight_granularity=PerGroup(group_size=...)` when `group_size` is provided (lines 104-105) is not directly validated in the tests. Per `tests/e2e/test_quantization.py:63-83`, the `ptq_config_test_cases` entry for `(int4, int8)` uses `group_size=None`, and the test at lines 117-128 only validates the returned type via `isinstance()`. While `ptq_test_cases` does include `group_size=8` for this combination, it only validates the quantized tensor class, not the config's `weight_granularity` attribute. Consider adding a test case that verifies the `PerGroup` granularity is correctly set.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/axolotl/utils/quantization.py` around lines 103 - 106, add a unit test that covers the code path when `group_size` is provided by calling the same factory that builds the `Int8DynamicActivationIntxWeightConfig` (the code that currently sets `kwargs = {"weight_dtype": torch.int4}` and returns `Int8DynamicActivationIntxWeightConfig(**kwargs)`), pass `group_size=8` (or another value), and assert that the returned config's `weight_granularity` is an instance of `PerGroup` and that its `group_size` equals the value passed; this ensures the `weight_granularity=PerGroup(...)` branch is exercised and validated.
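A minimal sketch of such a test, under the assumption that the config factory in `src/axolotl/utils/quantization.py` is exposed as `get_ptq_config` with keyword arguments like those below (the real helper name and signature may differ):

```python
# Hypothetical test sketch: `get_ptq_config` and its keyword arguments are
# assumptions standing in for the real config factory in
# src/axolotl/utils/quantization.py; adapt names to the actual code.
import torch
from torchao.quantization import Int8DynamicActivationIntxWeightConfig
from torchao.quantization.granularity import PerGroup

from axolotl.utils.quantization import get_ptq_config  # assumed helper name


def test_int4_weight_int8_activation_per_group_granularity():
    group_size = 8
    config = get_ptq_config(
        weight_dtype=torch.int4,
        activation_dtype=torch.int8,
        group_size=group_size,
    )
    # The factory should still return the torchao config type...
    assert isinstance(config, Int8DynamicActivationIntxWeightConfig)
    # ...and the group_size argument should surface as PerGroup granularity,
    # not only change which quantized tensor subclass is produced later.
    assert isinstance(config.weight_granularity, PerGroup)
    assert config.weight_granularity.group_size == group_size
```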
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 65fbea7f-3aad-4d12-887e-953d06c75a8f
📒 Files selected for processing (4)
- requirements.txt
- src/axolotl/core/builders/base.py
- src/axolotl/utils/quantization.py
- tests/e2e/test_quantization.py
```diff
 elif self.cfg.optimizer == "ao_adamw_fp8":
-    from torchao.prototype.low_bit_optim import AdamWFp8
+    from torchao.optim.adam import AdamWFp8

     optimizer_cls = AdamWFp8
     optimizer_kwargs.update(adam_kwargs)
```
🧩 Analysis chain
🌐 Web query:
What is the import path for AdamWFp8 in torchao 0.17.0?
💡 Result:
The import path for AdamWFp8 in torchao 0.17.0 is from torchao.optim import AdamWFp8.
Citations:
- 1: https://pypi.org/project/torchao/0.17.0/
- 2: https://github.com/pytorch/ao
- 3: Not Seeing much memory savings with Fp8 optimizer suddenly pytorch/ao#1499
Fix incorrect import path for AdamWFp8 in torchao 0.17.0.
The import path should be `from torchao.optim import AdamWFp8`, not `from torchao.optim.adam import AdamWFp8`. Update line 332 to use the correct module path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/axolotl/core/builders/base.py` around lines 331 - 335, The import for
AdamWFp8 is using the wrong module path; inside the branch that checks
self.cfg.optimizer == "ao_adamw_fp8" replace the import statement so it imports
AdamWFp8 from torchao.optim (i.e., use "from torchao.optim import AdamWFp8"),
leaving the rest of the block (setting optimizer_cls = AdamWFp8 and updating
optimizer_kwargs with adam_kwargs) unchanged so optimizer_cls and
optimizer_kwargs continue to work as before.
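As a quick standalone check of the corrected path (assuming torchao 0.17.0 is installed), the optimizer can be imported and constructed directly; this only checks the import and construction, not an actual optimization step:

```python
# Standalone check that AdamWFp8 is importable from torchao.optim, the path
# the review identifies as correct for torchao 0.17.0. This only constructs
# the optimizer; it does not perform an optimizer step.
import torch
from torchao.optim import AdamWFp8

model = torch.nn.Linear(16, 16)
optimizer = AdamWFp8(model.parameters(), lr=1e-4)
print(type(optimizer).__name__)  # -> AdamWFp8
```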
Codecov Report: ❌ Patch coverage is …
```python
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

logger = logging.getLogger(__name__)
```
Use the axolotl logger and not this
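A sketch of what the suggested change could look like, assuming axolotl exposes a `get_logger` helper in `axolotl.utils.logging` (the exact module path and helper name are assumptions; use whatever logger factory the project actually provides):

```python
# Assumed project logger helper -- the import path below is a guess at where
# axolotl's logger factory lives; adjust to the real location in the codebase.
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

# replaces the direct stdlib logger from the diff above:
# logger = logging.getLogger(__name__)
```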
The upstream fixes for the ao 8-bit and 4-bit optimizers are here: pytorch/ao#4216, but we've patched the ATen ops in axolotl for now.
Summary by CodeRabbit
Release Notes
Chores
- Upgraded torchao to 0.17.0 and mistral-common to 1.11.0

Improvements

- Quantization configuration now applies per-group weight granularity when a group_size is specified
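For reference, the dependency bumps would look roughly like this in requirements.txt; the exact pin specifiers used in the repo may differ:

```text
# requirements.txt (excerpt) -- pin operators are illustrative
torchao==0.17.0
mistral-common==1.11.0
```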