[cpu-offloading] Implement prefetching for cpu offloading pass#3166
mlazos wants to merge 6 commits into gh/mlazos/2/base
[ghstack-poisoned]
@claude can you review this and ensure that the per-layer prefetching does not miss any edge cases and that the tests do not mask any behavior?
Claude finished @mlazos's task in 2m 44s
Review: CPU Offload Prefetching Implementation
Prefetching Edge Cases
1.
@claude can you review to see whether we are missing any overlap opportunities or edge cases?
Claude finished @mlazos's task in 3m 41s
Review: Overlap Opportunities & Edge Cases
The core design is sound: decoupling …
Missed Overlap Opportunity
Highest offloaded layer gets zero prefetch benefit (
…pass" [ghstack-poisoned]
continue

layer_id = _get_reload_layer(node)
if layer_id != _NOT_IN_LAYERS:
I thought harder about this, and we can do a global optimization.
emb, lm_head and loss are in _NOT_IN_LAYERS,
and the full graph is actually
emb, layer0 ... layerN, lm_head, loss, loss_bw, (layer_bwN..0), emb_bw
so we can prefetch/defer beyond the layer.* boundary.
Concretely, we can do it like this:
- build a list of anchors globally (including the non-layer regions)
- prefetch == move to an earlier anchor, defer == move to a later anchor.
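To make the anchor idea concrete, here is a minimal sketch in plain Python. It is not the PR's implementation: the function names (`build_anchors`, `prefetch_target`, `defer_target`), the region labels, and the choice of "first node of each region" as the anchor are all assumptions made for illustration. It models the graph as a flat list of region labels in execution order, with non-layer regions treated like any other region.

```python
from bisect import bisect_right

def build_anchors(regions):
    """Return the start index of every contiguous region, in graph order.

    `regions` holds one (hypothetical) region label per node. Non-layer
    regions such as emb, lm_head and loss get anchors like any layer.
    """
    anchors = []
    prev = object()  # sentinel that never equals a real label
    for i, label in enumerate(regions):
        if label != prev:
            anchors.append(i)
            prev = label
    return anchors

def _region_index(anchors, pos):
    # Index of the anchor that starts the region containing node `pos`.
    return bisect_right(anchors, pos) - 1

def prefetch_target(anchors, pos, distance=1):
    """Anchor `distance` regions earlier than the region containing `pos`."""
    k = _region_index(anchors, pos)
    return anchors[max(k - distance, 0)]

def defer_target(anchors, pos, distance=1):
    """Anchor `distance` regions later than the region containing `pos`."""
    k = _region_index(anchors, pos)
    return anchors[min(k + distance, len(anchors) - 1)]

# Toy graph with two layers; two nodes per layer, one per non-layer region.
regions = [
    "emb",
    "layer0", "layer0", "layer1", "layer1",
    "lm_head", "loss", "loss_bw",
    "layer_bw1", "layer_bw1", "layer_bw0", "layer_bw0",
    "emb_bw",
]
anchors = build_anchors(regions)

# A reload needed at the start of layer_bw1 (node 8) can move to the
# loss_bw anchor (node 7), which lies outside the layer.* regions.
reload_at = prefetch_target(anchors, pos=8)

# An offload issued at the end of layer1 (node 4) can be deferred to the
# lm_head anchor (node 5).
offload_at = defer_target(anchors, pos=4)
```

In this toy model the reload for the highest layer's backward region lands on the `loss_bw` anchor, which is only possible because the anchor list covers the non-layer regions as well.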
This is a kind of optimization that is only possible with the full graph, with the loss in the graph.
Stack from ghstack (oldest at bottom):