From dcacc801e69a2afdf5f4e072a07a6c102ed3e9f8 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi Date: Tue, 7 Feb 2023 13:08:31 -0800 Subject: [PATCH] Creates non-breaking changes where necessary in preparation for switching all of Keras to new serialization format. PiperOrigin-RevId: 507864605 --- edward2/tensorflow/constraints.py | 7 ++++--- edward2/tensorflow/initializers.py | 7 ++++--- edward2/tensorflow/layers/gaussian_process.py | 9 +++++---- edward2/tensorflow/regularizers.py | 7 ++++--- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/edward2/tensorflow/constraints.py b/edward2/tensorflow/constraints.py index 5d9fa6ea..f72428a9 100644 --- a/edward2/tensorflow/constraints.py +++ b/edward2/tensorflow/constraints.py @@ -78,15 +78,16 @@ def get_config(self): def serialize(initializer): - return tf.keras.utils.serialize_keras_object(initializer) + return tf.keras.utils.legacy.serialize_keras_object(initializer) def deserialize(config, custom_objects=None): - return tf.keras.utils.deserialize_keras_object( + return tf.keras.utils.legacy.deserialize_keras_object( config, module_objects=globals(), custom_objects=custom_objects, - printable_module_name='constraints') + printable_module_name='constraints', + ) def get(identifier, value=None): diff --git a/edward2/tensorflow/initializers.py b/edward2/tensorflow/initializers.py index f562df7a..7da3bec6 100644 --- a/edward2/tensorflow/initializers.py +++ b/edward2/tensorflow/initializers.py @@ -851,15 +851,16 @@ def get_config(self): def serialize(initializer): - return tf.keras.utils.serialize_keras_object(initializer) + return tf.keras.utils.legacy.serialize_keras_object(initializer) def deserialize(config, custom_objects=None): - return tf.keras.utils.deserialize_keras_object( + return tf.keras.utils.legacy.deserialize_keras_object( config, module_objects=globals(), custom_objects=custom_objects, - printable_module_name='initializers') + printable_module_name='initializers', + ) def get(identifier, value=None): diff --git a/edward2/tensorflow/layers/gaussian_process.py b/edward2/tensorflow/layers/gaussian_process.py index 6603be11..05fbcf91 100644 --- a/edward2/tensorflow/layers/gaussian_process.py +++ b/edward2/tensorflow/layers/gaussian_process.py @@ -104,7 +104,7 @@ def get_config(self): return { 'variance': self.variance, 'bias': self.bias, - 'encoder': tf.keras.utils.serialize_keras_object(self.encoder), + 'encoder': tf.keras.utils.legacy.serialize_keras_object(self.encoder), } @@ -250,9 +250,10 @@ def compute_output_shape(self, input_shape): def get_config(self): config = { 'units': self.units, - 'mean_fn': tf.keras.utils.serialize_keras_object(self.mean_fn), - 'covariance_fn': tf.keras.utils.serialize_keras_object( - self.covariance_fn), + 'mean_fn': tf.keras.utils.legacy.serialize_keras_object(self.mean_fn), + 'covariance_fn': tf.keras.utils.legacy.serialize_keras_object( + self.covariance_fn + ), 'conditional_inputs': None, # don't serialize as it can be large 'conditional_outputs': None, # don't serialize as it can be large } diff --git a/edward2/tensorflow/regularizers.py b/edward2/tensorflow/regularizers.py index 1a404464..02e2bc27 100644 --- a/edward2/tensorflow/regularizers.py +++ b/edward2/tensorflow/regularizers.py @@ -388,15 +388,16 @@ def get_config(self): def serialize(initializer): - return tf.keras.utils.serialize_keras_object(initializer) + return tf.keras.utils.legacy.serialize_keras_object(initializer) def deserialize(config, custom_objects=None): - return tf.keras.utils.deserialize_keras_object( + return tf.keras.utils.legacy.deserialize_keras_object( config, module_objects=globals(), custom_objects=custom_objects, - printable_module_name='regularizers') + printable_module_name='regularizers', + ) def get(identifier, value=None):