sync_attention_wrapper.py
from tensorflow.python.ops import array_ops
from tensorflow.contrib import rnn
from tensorflow.contrib import seq2seq
from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import _compute_attention


class SyncAttentionWrapper(seq2seq.AttentionWrapper):
    """AttentionWrapper variant that computes attention from the previous
    cell state *before* running the wrapped RNN cell, so the attention
    context and the cell update are synchronized within one time step."""

    def __init__(self,
                 cell,
                 attention_mechanism,
                 attention_layer_size=None,
                 alignment_history=False,
                 cell_input_fn=None,
                 output_attention=True,
                 initial_cell_state=None,
                 name=None):
        if not isinstance(cell, (rnn.LSTMCell, rnn.GRUCell)):
            raise ValueError('SyncAttentionWrapper only supports LSTMCell and GRUCell, '
                             'Got: {}'.format(cell))
        super(SyncAttentionWrapper, self).__init__(
            cell,
            attention_mechanism,
            attention_layer_size=attention_layer_size,
            alignment_history=alignment_history,
            cell_input_fn=cell_input_fn,
            output_attention=output_attention,
            initial_cell_state=initial_cell_state,
            name=name
        )

    def call(self, inputs, state):
        if not isinstance(state, seq2seq.AttentionWrapperState):
            raise TypeError("Expected state to be instance of AttentionWrapperState. "
                            "Received type %s instead." % type(state))

        if self._is_multi:
            previous_alignments = state.alignments
            previous_alignment_history = state.alignment_history
        else:
            previous_alignments = [state.alignments]
            previous_alignment_history = [state.alignment_history]

        all_alignments = []
        all_attentions = []
        all_attention_states = []
        all_histories = []
        for i, attention_mechanism in enumerate(self._attention_mechanisms):
            # Query attention with the previous cell state (the hidden state
            # `h` for LSTM, the full state for GRU) rather than the cell
            # output, as the base AttentionWrapper does.
            if isinstance(self._cell, rnn.LSTMCell):
                rnn_cell_state = state.cell_state.h
            else:
                rnn_cell_state = state.cell_state
            attention, alignments, next_attention_state = _compute_attention(
                attention_mechanism, rnn_cell_state, previous_alignments[i],
                self._attention_layers[i] if self._attention_layers else None)
            alignment_history = previous_alignment_history[i].write(
                state.time, alignments) if self._alignment_history else ()

            all_attention_states.append(next_attention_state)
            all_alignments.append(alignments)
            all_histories.append(alignment_history)
            all_attentions.append(attention)

        # Feed the freshly computed attention into the cell for this step.
        attention = array_ops.concat(all_attentions, 1)
        cell_inputs = self._cell_input_fn(inputs, attention)
        cell_output, next_cell_state = self._cell(cell_inputs, state.cell_state)

        next_state = seq2seq.AttentionWrapperState(
            time=state.time + 1,
            cell_state=next_cell_state,
            attention=attention,
            attention_state=self._item_or_tuple(all_attention_states),
            alignments=self._item_or_tuple(all_alignments),
            alignment_history=self._item_or_tuple(all_histories))

        if self._output_attention:
            return attention, next_state
        else:
            return cell_output, next_state
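

# --- Usage sketch (not part of the original module) ---
# A minimal, illustrative example of how this wrapper might be wired up:
# a GRUCell combined with Bahdanau attention via SyncAttentionWrapper,
# plus the decoder's initial state. The tensor names (`encoder_outputs`,
# `source_lengths`) and the unit sizes are hypothetical assumptions, and
# the snippet presumes a TF 1.x environment where tf.contrib is available.
if __name__ == '__main__':
    import tensorflow as tf

    # Hypothetical encoder outputs: [batch, time, depth].
    encoder_outputs = tf.placeholder(tf.float32, [None, None, 256])
    source_lengths = tf.placeholder(tf.int32, [None])

    attention_mechanism = seq2seq.BahdanauAttention(
        num_units=256,
        memory=encoder_outputs,
        memory_sequence_length=source_lengths)

    decoder_cell = SyncAttentionWrapper(
        rnn.GRUCell(256),
        attention_mechanism,
        attention_layer_size=256,
        alignment_history=True)

    # Initial AttentionWrapperState for the decoder.
    initial_state = decoder_cell.zero_state(
        batch_size=tf.shape(encoder_outputs)[0], dtype=tf.float32)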