File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -596,7 +596,8 @@ def forward(
596596 input_features = audio_values ,
597597 feature_lens = audio_feature_lengths ,
598598 )
599- inputs_embeds = inputs_embeds .masked_scatter (audio_mask .unsqueeze (- 1 ), audio_embeds )
599+ expanded_audio_mask = audio_mask .unsqueeze (- 1 ).expand_as (inputs_embeds )
600+ inputs_embeds = inputs_embeds .masked_scatter (expanded_audio_mask , audio_embeds )
600601
601602 hidden_states = self .language_model (
602603 input_ids = input_ids ,
Original file line number Diff line number Diff line change 1010from .base import MediaIO
1111
1212
13- class AudioMediaIO (MediaIO [tuple [npt .NDArray , float ]]):
13+ class AudioMediaIO (MediaIO [tuple [npt .NDArray , int ]]):
1414
1515 def __init__ (self , ** kwargs ) -> None :
1616 super ().__init__ ()
@@ -35,17 +35,17 @@ def __init__(self, **kwargs) -> None:
3535 # for potential custom arguments from --media-io-kwargs
3636 self .kwargs = kwargs
3737
38- def load_bytes (self , data : bytes ) -> tuple [npt .NDArray , float ]:
38+ def load_bytes (self , data : bytes ) -> tuple [npt .NDArray , int ]:
3939 return self ._librosa .load (BytesIO (data ), sr = self .sampling_rate )
4040
4141 def load_base64 (
4242 self ,
4343 media_type : str ,
4444 data : str ,
45- ) -> tuple [npt .NDArray , float ]:
45+ ) -> tuple [npt .NDArray , int ]:
4646 return self .load_bytes (base64 .b64decode (data ))
4747
48- def load_file (self , filepath : Path ) -> tuple [npt .NDArray , float ]:
48+ def load_file (self , filepath : Path ) -> tuple [npt .NDArray , int ]:
4949 return self ._librosa .load (filepath , sr = self .sampling_rate )
5050
5151 def encode_base64 (
You can’t perform that action at this time.
0 commit comments