Skip to content

Commit a05bc37

Browse files
authored
fix dynamic ntk device (#4483)
1 parent 1ba95a9 commit a05bc37

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

lmdeploy/pytorch/backends/default/rotary_embedding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,10 @@ def __init__(self, dim: int, base: int = 10000, scaling_factor: float = 1.0, max
     def _ntk_inv_freq(self, seq_len: torch.Tensor):
         """ntk_inv_freq."""
+        device = seq_len.device
         base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -
                             (self.scaling_factor - 1))**(self.dim / (self.dim - 2))
-        inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
+        inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim))
         return inv_freq

     def forward(self, x: torch.Tensor, position_ids: torch.Tensor):

Comments (0)