Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/boltz/model/models/boltz2.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from boltz.model.optim.ema import EMA
from boltz.model.optim.scheduler import AlphaFoldLRScheduler
from boltz.model.modules.utils import autocast_device_type


class Boltz2(LightningModule):
Expand Down Expand Up @@ -529,7 +530,7 @@ def forward(
"token_trans_bias": token_trans_bias,
}

with torch.autocast("cuda", enabled=False):
with torch.autocast(autocast_device_type(s.device.type), enabled=False):
struct_out = self.structure_module.sample(
s_trunk=s.float(),
s_inputs=s_inputs.float(),
Expand Down Expand Up @@ -568,7 +569,7 @@ def forward(
feats["coords"] = atom_coords # (multiplicity, L, 3)
assert len(feats["coords"].shape) == 3

with torch.autocast("cuda", enabled=False):
with torch.autocast(autocast_device_type(s.device.type), enabled=False):
struct_out = self.structure_module(
s_trunk=s.float(),
s_inputs=s_inputs.float(),
Expand Down Expand Up @@ -625,7 +626,7 @@ def forward(
]
s_inputs = self.input_embedder(feats, affinity=True)

with torch.autocast("cuda", enabled=False):
with torch.autocast(autocast_device_type(s.device.type), enabled=False):
if self.affinity_ensemble:
dict_out_affinity1 = self.affinity_module1(
s_inputs=s_inputs.detach(),
Expand Down
12 changes: 12 additions & 0 deletions src/boltz/model/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@

LinearNoBias = partial(Linear, bias=False)

def autocast_device_type(device_type: str) -> str:
"""Return a device_type string accepted by ``torch.autocast``.

When autocast is used with ``enabled=False`` (to disable autocasting),
PyTorch still validates the device_type. MPS was not a valid autocast
device type until PyTorch 2.4. Since ``enabled=False`` is a no-op, we
fall back to ``"cpu"``, which is always accepted.
"""
from torch.amp.autocast_mode import is_autocast_available

return device_type if is_autocast_available(device_type) else "cpu"


def exists(v):
    """Return ``True`` when *v* is not ``None``, ``False`` otherwise."""
    return not (v is None)
Expand Down