diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 1311987f2..7d26d5ea6 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -19,7 +19,7 @@ from funasr.models.paraformer.search import Hypothesis
 
 
 
-
+import math
 class SinusoidalPositionEncoder(torch.nn.Module):
     """ """
 
@@ -884,7 +884,19 @@ def inference(
         for i in range(b):
             x = ctc_logits[i, : encoder_out_lens[i].item(), :]
             yseq = x.argmax(dim=-1)
+            yseq_prob = x.max(dim=-1).values
+            # Run lengths have to come from the per-frame ids, before they are collapsed.
+            run_lens = torch.unique_consecutive(yseq, dim=-1, return_counts=True)[1]
             yseq = torch.unique_consecutive(yseq, dim=-1)
+
+            # Score every collapsed token with the peak frame of its own run; indexing
+            # yseq_prob by collapsed position would read an unrelated frame.
+            # NOTE(review): math.exp assumes ctc_logits holds log-probabilities - confirm.
+            token_probs = []
+            for token, run_logp in zip(yseq.tolist(), yseq_prob.split(run_lens.tolist())):
+                if token != self.blank_id:
+                    piece = tokenizer.decode(token)
+                    token_probs.append((piece, math.exp(run_logp.max().item())))
 
             ibest_writer = None
             if kwargs.get("output_dir") is not None:
@@ -898,7 +910,7 @@ def inference(
             # Change integer-ids to tokens
             text = tokenizer.decode(token_int)
 
-            result_i = {"key": key[i], "text": text}
+            result_i = {"key": key[i], "text": text, "token_probs": token_probs}
             results.append(result_i)
 
             if ibest_writer is not None: