Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions include/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,27 @@ extern "C" {
int n_past,
int n_threads);

// Same as whisper_decode_with_state, but saves alignment head cross-attention data.
// Requires context created with dtw_token_timestamps=true and flash_attn=false.
// Retrieve the saved data afterwards with whisper_state_get_aheads_cross_qks().
// Returns 0 on success, non-zero on failure.
WHISPER_API int whisper_decode_with_state_and_aheads(
struct whisper_context * ctx,
struct whisper_state * state,
const whisper_token * tokens,
int n_tokens,
int n_past,
int n_threads);

// Get cross-attention data from alignment heads after a decode call with aheads enabled.
// Returns pointer to float array of shape [n_tokens x n_audio_ctx x n_heads].
// Copies data from GPU/backend to CPU on each call.
// Returns NULL if DTW is not enabled or no attention data is available.
// The pointer is valid until the next call to this function or whisper_free_state.
// Any of the output dimension pointers (n_tokens, n_audio_ctx, n_heads) may be
// NULL if the caller does not need that value.
WHISPER_API const float * whisper_state_get_aheads_cross_qks(
struct whisper_state * state,
int * n_tokens,
int * n_audio_ctx,
int * n_heads);

// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
Expand Down
32 changes: 32 additions & 0 deletions src/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3954,6 +3954,38 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
return whisper_decode_with_state(ctx, ctx->state, tokens, n_tokens, n_past, n_threads);
}

int whisper_decode_with_state_and_aheads(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
    // Stage the tokens in the state's batch and drop self-attention KV cache
    // entries at or beyond n_past so the decoder re-evaluates from that position.
    whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0);
    whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1);

    // Run the decoder. The 'true' argument (absent from the plain decode path)
    // presumably enables saving of alignment-head cross-attention data — consistent
    // with the header doc; retrieved via whisper_state_get_aheads_cross_qks().
    const bool ok = whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr);
    if (ok) {
        return 0;
    }

    WHISPER_LOG_ERROR("%s: failed to eval\n", __func__);
    return 1;
}

const float * whisper_state_get_aheads_cross_qks(struct whisper_state * state, int * out_n_tokens, int * out_n_audio_ctx, int * out_n_heads) {
    // No cross-attention data captured (DTW not enabled, or no aheads-enabled
    // decode has run yet).
    if (state->aheads_cross_QKs == nullptr) {
        return nullptr;
    }

    // Tensor layout: [n_tokens x n_audio_ctx x n_heads].
    // NOTE(review): ggml tensor dims (ne) are 64-bit — narrow explicitly for the
    // int-based public API, and compute the element count in size_t so the
    // product cannot overflow int (the original int*int*int product is signed-
    // overflow UB for sufficiently large tensors).
    const int n_tokens    = (int) state->aheads_cross_QKs->ne[0];
    const int n_audio_ctx = (int) state->aheads_cross_QKs->ne[1];
    const int n_heads     = (int) state->aheads_cross_QKs->ne[2];

    const size_t n_elems = (size_t) n_tokens * (size_t) n_audio_ctx * (size_t) n_heads;

    // Copy from the backend (possibly GPU) into a CPU-side buffer owned by the
    // state; the returned pointer stays valid until the next call to this
    // function or whisper_free_state.
    auto & data = state->aheads_cross_QKs_data;
    data.resize(n_elems);
    ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_elems);

    if (out_n_tokens)    { *out_n_tokens    = n_tokens;    }
    if (out_n_audio_ctx) { *out_n_audio_ctx = n_audio_ctx; }
    if (out_n_heads)     { *out_n_heads     = n_heads;     }

    return data.data();
}

int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
const auto res = tokenize(ctx->vocab, text);

Expand Down