Learn about the beam search algorithm

Beam search is a standard decoding algorithm for the decoder of seq2seq models. It generally produces better sequences than greedy decoding, at the cost of more computation and longer decoding time. Below is a walkthrough of the beam search routine from a Transformer model.
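Before looking at that implementation, here is a minimal, framework-free sketch of the idea (purely illustrative; the function and variable names below are not taken from the code analyzed later). At every step, each of the W surviving hypotheses is extended by candidate tokens, and only the W highest-scoring extensions are kept; greedy decoding is simply the special case W = 1.

def simple_beam_search(score_fn, bos, eos, width=2, max_len=20):
    # score_fn(prefix) -> {token: log_prob} stands in for a trained model.
    beams = [([bos], 0.0)]                      # (token sequence, cumulative log-probability)
    for _ in range(max_len):
        candidates = []
        for seq, logp in beams:
            if seq[-1] == eos:                  # finished hypotheses are carried over unchanged
                candidates.append((seq, logp))
                continue
            for tok, tok_logp in score_fn(seq).items():
                candidates.append((seq + [tok], logp + tok_logp))
        # keep only the `width` best-scoring hypotheses
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:width]
        if all(seq[-1] == eos for seq, _ in beams):
            break
    return beams[0][0]                          # tokens of the best hypothesis

Each step costs roughly W forward passes instead of one, which is where the extra computation mentioned above comes from. The Transformer implementation below does the same bookkeeping, but keeps all B * W hypotheses in flat tensors so a whole batch can be decoded at once.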

def beam_search(self, encoding, mask_src=None, mask_trg=None, width=2, alpha=1.0):
    # Note: this is a method of the decoder class; math, F (torch.nn.functional),
    # Variable, log_softmax, positional_encodings_like and INF are assumed to be
    # imported or defined elsewhere in the same model file.
    #
    # W, width : beam size;  alpha : exponent for sentence-length normalization
    # B : batch size
    # T : max sentence length in the src and trg language
    # C : hidden state size of the decoder
    # encoding : a list; each element is the hidden state output by the
    #            corresponding encoder layer, and encoding[0] is the input embedding
    # encoding[i] size (after expansion) : B*W x T x C
    # mask_src size (after expansion)    : B*W x T
    # outs size            : B x W x (T+1)
    # logps size           : B x W
    # hiddens[l] size      : B x W x T x C
    # eos_yet size         : B x W
    # topk2_logps size     : B x W x Vocab
    # topk_beam_inds size  : B x W
    # topk_token_inds size : B x W
    # embedW  : weight matrix used to map target-side token indices to word embeddings
    # eos_yet : indicates whether each hypothesis has already ended; 1 means finished

    encoding = encoding[1:]
    W = width
    B, T, C = encoding[0].size()

    # expanding
    for i in range(len(encoding)):
        encoding[i] = encoding[i][:, None, :].expand(
            B, W, T, C).contiguous().view(B * W, T, C)
    mask_src = mask_src[:, None, :].expand(B, W, T).contiguous().view(B * W, T)
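    # (Layout note, added for clarity: after the view, row b*W + k of encoding[i]
    #  and of mask_src holds the source copy used by beam k of sentence b, so all
    #  B*W hypotheses can be decoded together as one ordinary batch.)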

    T *= self.length_ratio  # maximum target length, as a multiple of the source length
    outs = Variable(encoding[0].data.new(B, W, T + 1).long().fill_(
        self.field.vocab.stoi['<pad>']))
    outs[:, :, 0] = self.field.vocab.stoi['<init>']

    logps = Variable(encoding[0].data.new(B, W).float().fill_(0))   # beam scores (cumulative log-probs)
    hiddens = [Variable(encoding[0].data.new(B, W, T, C).zero_())   # decoder states: batch x beamsize x len x h
               for l in range(len(self.layers) + 1)]
    embedW = self.out.weight * math.sqrt(self.d_model)
    hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0])
    eos_yet = encoding[0].data.new(B, W).byte().zero_()  # batch x beamsize, no sentence is finished yet
    eos_mask = eos_yet.float().new(B, W, W).fill_(INF)
    eos_mask[:, :, 0] = 0  # batch x beam x beam
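    # (Note on eos_mask, added for clarity: inside the loop, a finished beam's W
    #  candidate scores are multiplied by this mask.  Because log-probabilities are
    #  negative, the INF entries push candidates 1..W-1 down to -INF, while the 0 in
    #  slot 0 zeroes out the new token's score, so a finished hypothesis survives as
    #  exactly one candidate with its cumulative score left unchanged.)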

    for t in range(T):
        hiddens[0][:, :, t] = self.dropout(
            hiddens[0][:, :, t] + F.embedding(outs[:, :, t], embedW))
        for l in range(len(self.layers)):
            x = hiddens[l][:, :, :t + 1].contiguous().view(B * W, -1, C)
            x = self.layers[l].selfattn(x[:, -1:, :], x, x)
            hiddens[l + 1][:, :, t] = self.layers[l].feedforward(
                self.layers[l].attention(x, encoding[l], encoding[l], mask_src)).view(
                    B, W, C)
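        # (Note on the layer loop above, added for clarity: decoding is incremental --
        #  only the newest position t serves as the self-attention query (x[:, -1:, :]),
        #  while the keys/values cover all positions 0..t computed so far, and the
        #  cross-attention attends to the expanded encoder states of the same layer.)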

        # topk2_logps: scores, topk2_inds: top word index at each beam, batch x beam x beam
        topk2_logps = log_softmax(self.out(hiddens[-1][:, :, t]))  # B x W x Vocab
        topk2_logps[:, :, self.field.vocab.stoi['<pad>']] = -INF
        topk2_logps, topk2_inds = topk2_logps.topk(W, dim=-1)

        topk2_logps = topk2_logps * Variable(
            eos_yet[:, :, None].float() * eos_mask + 1 - eos_yet[:, :, None].float())
        topk2_logps = topk2_logps + logps[:, :, None]

        if t == 0:
            logps, topk_inds = topk2_logps[:, 0].topk(W, dim=-1)
        else:
            logps, topk_inds = topk2_logps.view(B, W * W).topk(W, dim=-1)  # use topk_inds to get eos_yet
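        # (Selection note, added for clarity: at t == 0 every beam holds the same
        #  <init> prefix, so only beam 0 is expanded to avoid W duplicates.  For
        #  t > 0 the W*W candidates are flattened, so each entry of topk_inds encodes
        #  parent_beam * W + rank: the integer division below recovers the parent beam,
        #  and gathering topk2_inds with the flat index recovers the chosen token.)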

        topk_beam_inds = topk_inds.div(W)
        topk_token_inds = topk2_inds.view(B, W * W).gather(1, topk_inds)  # get the next token's vocab index
        eos_yet = eos_yet.gather(1, topk_beam_inds.data)

        # logps = logps * (1 - Variable(eos_yet.float()) * 1 / (t + 2)).pow(alpha)
        logps = logps * (1 + Variable(eos_yet.float()) * 1 / (t + 1)).pow(alpha)
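        # (Length-normalization note, added for clarity: for beams that have already
        #  emitted <eos>, the factor (1 + 1/(t+1))**alpha is greater than 1; since the
        #  cumulative log-probabilities are negative, this lowers the score of finished
        #  (shorter) hypotheses a little at every remaining step, presumably to offset
        #  beam search's usual preference for short outputs.  The commented-out line
        #  above is an alternative normalization kept by the original author.)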

        outs = outs.gather(1, topk_beam_inds[:, :, None].expand_as(outs))
        outs[:, :, t + 1] = topk_token_inds

        topk_beam_inds_h = topk_beam_inds[:, :, None, None].expand_as(hiddens[0])
        for i in range(len(hiddens)):
            hiddens[i] = hiddens[i].gather(1, topk_beam_inds_h)

        eos_yet = eos_yet | (topk_token_inds.data == self.field.vocab.stoi['<eos>'])
        if eos_yet.all():
            return outs[:, 0, 1:]
    return outs[:, 0, 1:]
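For context, a call site would look roughly like the sketch below. The names model, src and mask_src, and the way the encoder is invoked, are assumptions for illustration only; the actual driver code in the original repository may differ. Only the beam_search signature itself comes from the code above.

# Hypothetical usage sketch -- only beam_search's signature is taken from the code above.
encoding = model.encoder(src, mask_src)   # assumed: returns a list of per-layer states, encoding[0] = embeddings
best = model.decoder.beam_search(encoding, mask_src=mask_src, width=4, alpha=0.6)
# 'best' holds the highest-scoring hypothesis for each sentence:
# a B x T tensor of target-vocabulary indices, starting right after <init>.

Note that beam search multiplies the per-step decoding cost by roughly W compared with greedy decoding, which is the trade-off mentioned at the top of this post.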