10 changes: 7 additions & 3 deletions bert/loader.py
@@ -55,7 +55,7 @@ def map_from_stock_variale_name(name, prefix="bert"):
     name = "/".join(pns + ns[1:])
     ns = name.split("/")
 
-    if ns[1] not in ["encoder", "embeddings"]:
+    if ns[1] not in ["encoder", "embeddings", "pooler"]:
         return None
     if ns[1] == "embeddings":
         if ns[2] == "LayerNorm":
@@ -67,6 +67,8 @@ def map_from_stock_variale_name(name, prefix="bert"):
             return "/".join(ns[:4] + ns[5:])
         else:
             return name
+    if ns[1] == "pooler":
+        return "/".join(ns)
     return None
@@ -81,7 +83,7 @@ def map_to_stock_variable_name(name, prefix="bert"):
     name = "/".join(["bert"] + ns[len(pns):])
     ns = name.split("/")
 
-    if ns[1] not in ["encoder", "embeddings"]:
+    if ns[1] not in ["encoder", "embeddings", "pooler"]:
         return None
     if ns[1] == "embeddings":
         if ns[2] == "LayerNorm":
@@ -99,6 +101,8 @@ def map_to_stock_variable_name(name, prefix="bert"):
             return "/".join(ns[:4] + ["dense"] + ns[4:])
         else:
             return name
+    if ns[1] == "pooler":
+        return "/".join(ns)
     return None
@@ -181,7 +185,7 @@ def _checkpoint_exists(ckpt_path):
 
 
 def bert_prefix(bert: BertModelLayer):
-    re_bert = re.compile(r'(.*)/(embeddings|encoder)/(.+):0')
+    re_bert = re.compile(r'(.*)/(embeddings|encoder|pooler)/(.+):0')
     match = re_bert.match(bert.weights[0].name)
     assert match, "Unexpected bert layer: {} weight:{}".format(bert, bert.weights[0].name)
     prefix = match.group(1)
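
Note: the new pooler branch in both mapping functions is an identity, so pooler weight names coincide between the stock TF checkpoint and the Keras layer. A minimal sketch of the round trip (the variable name below is illustrative):

    from bert.loader import map_from_stock_variale_name, map_to_stock_variable_name

    # pooler names pass through unchanged in both directions
    name = "bert/pooler/dense/kernel"
    assert map_to_stock_variable_name(name) == name
    assert map_from_stock_variale_name(name) == name
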
4 changes: 3 additions & 1 deletion bert/loader_albert.py
@@ -244,7 +244,7 @@ def map_to_tfhub_albert_variable_name(name, prefix="bert"):
     name = "/".join(["bert"] + ns[len(pns):])
     ns = name.split("/")
 
-    if ns[1] not in ["encoder", "embeddings"]:
+    if ns[1] not in ["encoder", "embeddings", "pooler"]:
         return None
     if ns[1] == "embeddings":
         if ns[2] == "LayerNorm":
@@ -256,6 +256,8 @@ def map_to_tfhub_albert_variable_name(name, prefix="bert"):
             return "/".join(ns[:4] + ["dense"] + ns[4:])
         else:
             return name
+    if ns[1] == "pooler":
+        return "/".join(ns)
     return None
 
26 changes: 21 additions & 5 deletions bert/model.py
@@ -10,6 +10,7 @@
 
 from bert.layer import Layer
 from bert.embeddings import BertEmbeddingsLayer
+from bert.pooler import BertPoolerLayer
 from bert.transformer import TransformerEncoderLayer
 
 
@@ -23,8 +24,9 @@ class BertModelLayer(Layer):
 
     """
    class Params(BertEmbeddingsLayer.Params,
-                 TransformerEncoderLayer.Params):
-        pass
+                 TransformerEncoderLayer.Params,
+                 BertPoolerLayer.Params):
+        return_pooler_output = False
 
     # noinspection PyUnusedLocal
     def _construct(self, **kwargs):
@@ -41,6 +43,11 @@ def _construct(self, **kwargs):
 
         self.support_masking = True
 
+        if self.params.return_pooler_output:
+            self.pooler_layer = BertPoolerLayer.from_params(
+                self.params,
+                name="pooler")
+
     # noinspection PyAttributeOutsideInit
     def build(self, input_shape):
         if isinstance(input_shape, list):
@@ -61,7 +68,11 @@ def compute_output_shape(self, input_shape):
             input_ids_shape = input_shape
 
         output_shape = list(input_ids_shape) + [self.params.hidden_size]
-        return output_shape
+        if self.params.return_pooler_output:
+            pooler_output_shape = [input_ids_shape[0], self.params.hidden_size]
+            return output_shape, pooler_output_shape
+        else:
+            return output_shape
 
     def apply_adapter_freeze(self):
         """ Should be called once the model has been built to freeze
@@ -77,6 +88,11 @@ def call(self, inputs, mask=None, training=None):
             mask = self.embeddings_layer.compute_mask(inputs)
 
         embedding_output = self.embeddings_layer(inputs, mask=mask, training=training)
-        output = self.encoders_layer(embedding_output, mask=mask, training=training)
-        return output  # [B, seq_len, hidden_size]
+        encoder_output = self.encoders_layer(embedding_output, mask=mask, training=training)
+
+        if self.params.return_pooler_output:
+            pooler_output = self.pooler_layer(encoder_output, mask=mask, training=training)
+            return encoder_output, pooler_output
+        else:
+            return encoder_output  # [B, seq_len, hidden_size]

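Note: with return_pooler_output=True the layer now returns a (sequence_output, pooler_output) tuple instead of a single tensor, so single-output containers like keras.models.Sequential no longer fit; the functional API does (see the updated tests below). A minimal sketch, with small illustrative parameter values:

    import tensorflow as tf
    from tensorflow import keras
    import bert

    params = bert.BertModelLayer.Params(hidden_size=32, vocab_size=67,
                                        max_position_embeddings=64,
                                        num_layers=1, num_heads=1,
                                        intermediate_size=4,
                                        return_pooler_output=True)
    l_bert = bert.BertModelLayer.from_params(params)

    input_ids = keras.layers.Input(shape=(16,), dtype=tf.int32)
    seq_out, pooled_out = l_bert(input_ids)   # [B, 16, hidden], [B, hidden]
    model = keras.Model(inputs=input_ids, outputs=pooled_out)
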
38 changes: 38 additions & 0 deletions bert/pooler.py
@@ -0,0 +1,38 @@
+# coding=utf-8
+#
+# created by mrinaald on 17.Oct.2020 at 12:33
+#
+
+from __future__ import absolute_import, division, print_function
+
+import tensorflow as tf
+import params_flow as pf
+
+from bert.layer import Layer
+
+
+class BertPoolerLayer(Layer):
+    class Params(Layer.Params):
+        hidden_size = 768
+
+    def _construct(self, **kwargs):
+        super()._construct(**kwargs)
+        self.pooler_layer = None
+
+    def build(self, input_shape):
+        # Input Shape: (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE)
+        assert len(input_shape) == 3
+
+        self.input_spec = tf.keras.layers.InputSpec(shape=input_shape)
+        self.pooler_layer = tf.keras.layers.Dense(units=self.params.hidden_size,
+                                                  activation='tanh',
+                                                  kernel_initializer=self.create_initializer(),
+                                                  name="dense")
+
+        super(BertPoolerLayer, self).build(input_shape)
+
+    def call(self, inputs, mask=None, training=None):
+        first_token_tensor = inputs[:, 0, :]
+
+        pooled_output = self.pooler_layer(first_token_tensor)
+        return pooled_output
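
Note: the pooler takes the hidden state of the first ([CLS]) token and projects it through a tanh dense layer, as in the stock BERT implementation. A standalone sketch, assuming the params_flow Params/from_params conventions used in the rest of the repo:

    import tensorflow as tf
    from bert.pooler import BertPoolerLayer

    pooler = BertPoolerLayer.from_params(BertPoolerLayer.Params(hidden_size=32))
    seq_output = tf.zeros((2, 8, 32))   # [batch_size, seq_len, hidden_size]
    pooled = pooler(seq_output)         # first token -> tanh dense -> shape (2, 32)
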
21 changes: 14 additions & 7 deletions tests/test_adapter_finetune.py
@@ -38,7 +38,7 @@ def setUp(self) -> None:
 
     def test_coverage_improve(self):
         bert_params = bert.params_from_pretrained_ckpt(self.ckpt_dir)
-        model, l_bert = self.build_model(bert_params, 1)
+        model, l_bert = self.build_model(bert_params, 1, return_pooler_output=True)
         for weight in model.weights:
             l_bert_prefix = bert.loader.bert_prefix(l_bert)
 
@@ -52,15 +52,22 @@ def test_coverage_improve(self):
             self.assertEqual(weight.name.split(":")[0], keras_name)
 
     @staticmethod
-    def build_model(bert_params, max_seq_len):
+    def build_model(bert_params, max_seq_len, return_pooler_output=False):
         # enable adapter-BERT
         bert_params.adapter_size = 2
+        bert_params.return_pooler_output = return_pooler_output
         l_bert = bert.BertModelLayer.from_params(bert_params)
-        model = keras.models.Sequential([
-            l_bert,
-            keras.layers.Lambda(lambda seq: seq[:, 0, :]),
-            keras.layers.Dense(3, name="test_cls")
-        ])
+        if return_pooler_output:
+            inp = keras.Input(shape=(max_seq_len,))
+            _, pooled_out = l_bert(inp)
+            out = keras.layers.Dense(3, name="test_cls")(pooled_out)
+            model = keras.Model(inputs=[inp], outputs=out)
+        else:
+            model = keras.models.Sequential([
+                l_bert,
+                keras.layers.Lambda(lambda seq: seq[:, 0, :]),
+                keras.layers.Dense(3, name="test_cls")
+            ])
         model.compile(optimizer=keras.optimizers.Adam(),
                       loss=keras.losses.SparseCategoricalCrossentropy(),
                       metrics=[keras.metrics.SparseCategoricalAccuracy()])
127 changes: 126 additions & 1 deletion tests/test_albert_create.py
@@ -78,6 +78,62 @@ def to_model(bert_params):
             print(weight.name, weight.shape)
         self.assertEqual(16, len(model.non_trainable_weights))
 
+    def test_albert_with_pooler(self):
+        bert_params = bert.BertModelLayer.Params(hidden_size=32,
+                                                 vocab_size=67,
+                                                 max_position_embeddings=64,
+                                                 num_layers=1,
+                                                 num_heads=1,
+                                                 intermediate_size=4,
+                                                 use_token_type=False,
+
+                                                 embedding_size=16,  # using ALBERT instead of BERT
+                                                 project_embeddings_with_bias=True,
+                                                 shared_layer=True,
+                                                 extra_tokens_vocab_size=3,
+
+                                                 return_pooler_output=True
+                                                 )
+
+
+        def to_model(bert_params):
+            l_bert = bert.BertModelLayer.from_params(bert_params)
+
+            token_ids = keras.layers.Input(shape=(21,))
+            _, pooler_out = l_bert(token_ids)
+            model = keras.Model(inputs=[token_ids], outputs=pooler_out)
+
+            model.build(input_shape=(None, 21))
+            l_bert.apply_adapter_freeze()
+
+            return model
+
+        model = to_model(bert_params)
+        model.summary()
+
+        print("trainable_weights:", len(model.trainable_weights))
+        print(bert_params.adapter_size)
+        for weight in model.trainable_weights:
+            print(weight.name, weight.shape)
+        self.assertEqual(25, len(model.trainable_weights))
+
+        # adapter-ALBERT :-)
+
+        bert_params.adapter_size = 16
+
+        model = to_model(bert_params)
+        model.summary()
+
+        print("trainable_weights:", len(model.trainable_weights))
+        for weight in model.trainable_weights:
+            print(weight.name, weight.shape)
+        self.assertEqual(15, len(model.trainable_weights))
+
+        print("non_trainable_weights:", len(model.non_trainable_weights))
+        for weight in model.non_trainable_weights:
+            print(weight.name, weight.shape)
+        self.assertEqual(18, len(model.non_trainable_weights))
+
     def test_albert_load_base_google_weights(self):  # for coverage mainly
         albert_model_name = "albert_base"
         albert_dir = bert.fetch_tfhub_albert_model(albert_model_name, ".models")
@@ -100,6 +156,22 @@ def test_albert_load_base_google_weights(self):  # for coverage mainly
 
         model.summary()
 
+        # return_pooler_output
+        model_params.return_pooler_output = True
+        l_bert = bert.BertModelLayer.from_params(model_params, name="albert_with_pooler")
+
+        inp = keras.layers.Input(shape=(8,), dtype=tf.int32, name="input_ids")
+        _, pooler_out = l_bert(inp)
+        out = keras.layers.Dense(2)(pooler_out)
+        model = keras.Model(inputs=[inp], outputs=out)
+        model.compile(optimizer=keras.optimizers.Adam(),
+                      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+                      metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")])
+
+        bert.load_albert_weights(l_bert, albert_dir)
+
+        model.summary()
+
     def test_albert_params(self):
         albert_model_name = "albert_base"
         albert_dir = bert.fetch_tfhub_albert_model(albert_model_name, ".models")
@@ -117,6 +189,14 @@ def test_albert_params(self):
         l_bert(tf.zeros((1, 128)))
         bert.load_albert_weights(l_bert, albert_dir)
 
+        # coverage: return_pooler_output
+        model_params = dir_params
+        model_params.vocab_size = model_params.vocab_size
+        model_params.return_pooler_output = True
+        l_bert = bert.BertModelLayer.from_params(model_params, name="albert_with_pooler")
+        l_bert(tf.zeros((1, 128)))
+        bert.load_albert_weights(l_bert, albert_dir)
+
     def test_albert_zh_fetch_and_load(self):
         albert_model_name = "albert_tiny"
         albert_dir = bert.fetch_brightmart_albert_model(albert_model_name, ".models")
@@ -133,4 +213,49 @@ def test_coverage(self):
         try:
             bert.fetch_google_bert_model("not-existent_bert_model", ".models")
         except:
-            pass
+            pass
+
+        albert_model_name = "albert_tiny"
+        albert_dir = bert.fetch_brightmart_albert_model(albert_model_name, ".models")
+
+        model_params = bert.params_from_pretrained_ckpt(albert_dir)
+        model_params.vocab_size = model_params.vocab_size + 2
+        l_bert = bert.BertModelLayer.from_params(model_params, name="albert")
+
+        seq_len = 128
+        input_ids_shape = (1, seq_len)
+        token_type_ids_shape = (1, seq_len)
+
+        output_shape = l_bert.compute_output_shape(input_ids_shape)
+        self.assertTrue(len(output_shape) == 3)
+        self.assertTrue(output_shape[0] == 1)
+        self.assertTrue(output_shape[1] == seq_len)
+        self.assertTrue(output_shape[2] == model_params.hidden_size)
+
+        output_shape = l_bert.compute_output_shape([input_ids_shape, token_type_ids_shape])
+        self.assertTrue(len(output_shape) == 3)
+        self.assertTrue(output_shape[0] == 1)
+        self.assertTrue(output_shape[1] == seq_len)
+        self.assertTrue(output_shape[2] == model_params.hidden_size)
+
+        # return_pooler_output
+        model_params.return_pooler_output = True
+        l_bert = bert.BertModelLayer.from_params(model_params, name="albert")
+
+        output_shape = l_bert.compute_output_shape(input_ids_shape)
+        self.assertTrue(isinstance(output_shape, tuple))
+        self.assertTrue(len(output_shape) == 2)
+        self.assertTrue(output_shape[0][0] == 1)
+        self.assertTrue(output_shape[0][1] == seq_len)
+        self.assertTrue(output_shape[0][2] == model_params.hidden_size)
+        self.assertTrue(output_shape[1][0] == 1)
+        self.assertTrue(output_shape[1][1] == model_params.hidden_size)
+
+        output_shape = l_bert.compute_output_shape([input_ids_shape, token_type_ids_shape])
+        self.assertTrue(isinstance(output_shape, tuple))
+        self.assertTrue(len(output_shape) == 2)
+        self.assertTrue(output_shape[0][0] == 1)
+        self.assertTrue(output_shape[0][1] == seq_len)
+        self.assertTrue(output_shape[0][2] == model_params.hidden_size)
+        self.assertTrue(output_shape[1][0] == 1)
+        self.assertTrue(output_shape[1][1] == model_params.hidden_size)