NoRepeatNGramLogitsProcessor

Updated: 2024-10-11 03:20:05

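# transformers.generation_logits_process: _calc_banned_ngram_tokens, used by NoRepeatNGramLogitsProcessor, returns for each hypothesis the tokens that would complete an n-gram already present in it, so that generation does not repeat any n-gram.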
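
To make the idea concrete before reading the library code, here is a minimal pure-Python sketch for a single sequence. The function name banned_next_tokens is hypothetical and does not exist in transformers; it only illustrates the rule that a token is banned when appending it would reproduce an n-gram that already occurs in the history.

```python
from typing import List


def banned_next_tokens(tokens: List[int], ngram_size: int) -> List[int]:
    """Return the tokens that would repeat an existing n-gram if appended to `tokens`."""
    if len(tokens) + 1 < ngram_size:
        # Too short to form even one n-gram with the next token.
        return []
    # The last (ngram_size - 1) tokens are the prefix of the n-gram being completed.
    prefix = tuple(tokens[len(tokens) - (ngram_size - 1):])
    banned = []
    # Slide a window of ngram_size tokens over the history.
    for start in range(len(tokens) - ngram_size + 1):
        window = tokens[start:start + ngram_size]
        if tuple(window[:-1]) == prefix:
            # This prefix was seen before; the token that followed it is banned.
            banned.append(window[-1])
    return banned


print(banned_next_tokens([0, 5, 6, 0, 5, 6], 3))  # [0]
print(banned_next_tokens([0, 6, 4, 3, 2, 1], 3))  # []
```

The code that follows computes the same result for a whole batch of hypotheses, using a dictionary from each (n-1)-token prefix to the tokens that followed it instead of rescanning the history.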

import torch
from typing import List, Iterable


def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    # One dict per hypothesis: (n-1)-token prefix -> list of tokens that followed it
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        # Zipping shifted copies of the sequence yields every n-gram in order
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
    return generated_ngrams


def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    # Before decoding the next token, prevent decoding of ngrams that have already appeared
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
    return banned_ngrams.get(ngram_idx, [])


def _calc_banned_ngram_tokens(
        ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
) -> List[Iterable[int]]:
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]

    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)

    banned_tokens = [
        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]
    return banned_tokens


x = _get_ngrams(3, torch.LongTensor([[0, 5, 6, 0, 5, 6], [0, 6, 4, 3, 2, 1]]), 2)
print('x', x)

y = _calc_banned_ngram_tokens(3, torch.LongTensor([[0, 5, 6, 0, 5, 6], [0, 6, 4, 3, 2, 1]]), 2, 6)
print('y', y)
 

Output:

x [{(0, 5): [6, 6], (5, 6): [0], (6, 0): [5]}, {(0, 6): [4], (6, 4): [3], (4, 3): [2], (3, 2): [1]}]
y [[0], []]  # after the sequence [0, 5, 6, 0, 5, 6], token 0 is banned: the trailing (5, 6) was already followed by 0
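
The banned-token lists are only half of the picture: they still have to be applied to the next-token scores. As far as I understand it, the processor does this by setting the score of each banned token to negative infinity, so it can never be selected. The sketch below shows that step with plain Python lists instead of tensors; mask_banned_tokens and the 7-token vocabulary are assumptions made for illustration, not part of the library.

```python
from typing import List


def mask_banned_tokens(
    scores: List[List[float]], banned_tokens: List[List[int]]
) -> List[List[float]]:
    """Return a copy of `scores` with every banned token set to -inf, per hypothesis."""
    masked = [row[:] for row in scores]  # copy so the input is left untouched
    for hypo_idx, banned in enumerate(banned_tokens):
        for token_id in banned:
            masked[hypo_idx][token_id] = float("-inf")
    return masked


# Two hypotheses over a hypothetical 7-token vocabulary, all scores equal.
scores = [[0.0] * 7 for _ in range(2)]
# Banned lists taken from the `y` output above.
masked = mask_banned_tokens(scores, [[0], []])
print(masked[0])  # token 0 is -inf for the first hypothesis
print(masked[1])  # second hypothesis is unchanged
```

After masking, a softmax or argmax over the first row can no longer produce token 0, which is what prevents the 3-gram (5, 6, 0) from appearing a second time.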

Published: 2024-02-07 05:48:57
Article link: https://www.elefans.com/category/jswz/34/1753203.html