-
Notifications
You must be signed in to change notification settings - Fork 897
Expand file tree
/
Copy pathdrmm.py
More file actions
121 lines (103 loc) · 4.15 KB
/
drmm.py
File metadata and controls
121 lines (103 loc) · 4.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""An implementation of DRMM Model."""
import typing
import keras
import keras.backend as K
import tensorflow as tf
from matchzoo.engine.base_model import BaseModel
from matchzoo.engine.param import Param
from matchzoo.engine.param_table import ParamTable
class DRMM(BaseModel):
    """
    DRMM Model.

    Examples:
        >>> model = DRMM()
        >>> model.params['mlp_num_layers'] = 1
        >>> model.params['mlp_num_units'] = 5
        >>> model.params['mlp_num_fan_out'] = 1
        >>> model.params['mlp_activation_func'] = 'tanh'
        >>> model.guess_and_fill_missing_params(verbose=0)
        >>> model.build()
        >>> model.compile()
    """

    @classmethod
    def get_default_params(cls) -> ParamTable:
        """:return: model default parameters."""
        params = super().get_default_params(
            with_embedding=True,
            with_multi_layer_perceptron=True
        )
        params.add(Param(
            name='mask_value', value=-1,
            desc="The value to be masked from inputs."
        ))
        params['optimizer'] = 'adam'
        params['input_shapes'] = [(5,), (5, 30,)]
        return params

    def build(self):
        """Build model structure."""
        # Dimension key for the shape comments below:
        #   B = batch size
        #   D = embedding size
        #   L = `input_left` sequence length
        #   H = histogram size
        # The left input carries raw query token ids; the right input is
        # the precomputed matching histogram between the original query
        # and the original document.
        # query: shape = [B, L]
        query = keras.layers.Input(
            name='text_left',
            shape=self._params['input_shapes'][0]
        )
        # match_hist: shape = [B, L, H]
        match_hist = keras.layers.Input(
            name='match_histogram',
            shape=self._params['input_shapes'][1]
        )

        # Embed the query tokens.  shape = [B, L, D]
        embedded_query = self._make_embedding_layer()(query)

        # Mark valid (non-padding) query positions with 1.0, padding with
        # 0.0, then add a trailing axis for broadcasting.  shape = [B, L, 1]
        valid_mask = tf.cast(
            tf.not_equal(query, self._params['mask_value']), K.floatx())
        valid_mask = tf.expand_dims(valid_mask, axis=2)

        # Per-term attention weights over the query.  shape = [B, L, 1]
        attention_probs = self.attention_layer(embedded_query, valid_mask)

        # Score each query term's histogram with the MLP.  shape = [B, L, 1]
        term_scores = self._make_multi_layer_perceptron_layer()(match_hist)

        # Attention-weighted sum of per-term scores.  shape = [B, 1, 1]
        weighted_score = keras.layers.Dot(axes=[1, 1])(
            [attention_probs, term_scores])
        flat_score = keras.layers.Flatten()(weighted_score)
        x_out = self._make_output_layer()(flat_score)
        self._backend = keras.Model(inputs=[query, match_hist], outputs=x_out)

    @classmethod
    def attention_layer(cls, attention_input: typing.Any,
                        attention_mask: typing.Any = None
                        ) -> keras.layers.Layer:
        """
        Performs attention on the input.

        :param attention_input: The input tensor for attention layer.
        :param attention_mask: A tensor to mask the invalid values.
        :return: The masked output tensor.
        """
        # Project each position's vector to a scalar score.
        # shape = [B, L, 1]
        raw_scores = keras.layers.Dense(1, use_bias=False)(attention_input)
        if attention_mask is not None:
            # The mask is 1.0 at positions to attend to and 0.0 at masked
            # positions, so adding (1 - mask) * -10000 pushes masked scores
            # to a large negative value and softmax assigns them ~0 weight.
            # shape = [B, L, 1]
            raw_scores = keras.layers.Lambda(
                lambda x: x + (1.0 - attention_mask) * -10000.0,
                name="attention_mask"
            )(raw_scores)
        # Normalize the scores along the sequence axis.  shape = [B, L, 1]
        return keras.layers.Lambda(
            lambda x: tf.nn.softmax(x, axis=1),
            output_shape=lambda s: (s[0], s[1], s[2]),
            name="attention_probs"
        )(raw_scores)