diff --git a/.github/workflows/nightly-publish.yml b/.github/workflows/nightly-publish.yml index a3655271..e8cd8081 100644 --- a/.github/workflows/nightly-publish.yml +++ b/.github/workflows/nightly-publish.yml @@ -2,9 +2,6 @@ name: NightlyPublish on: workflow_dispatch: # Allow manual triggers - schedule: - # Runs every day at 3:07am UTC. - - cron: '7 3 * * *' jobs: check: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8dbd8bd5..b5e06077 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,14 +14,14 @@ jobs: fail-fast: false matrix: include: - - python-version: '3.7' + - python-version: '3.8' tf-version: '2.8' - - python-version: '3.7' - tf-version: '2.11' + - python-version: '3.9' + tf-version: '2.15.0' - python-version: '3.10' tf-version: '2.8' - - python-version: '3.10' - tf-version: '2.11' + - python-version: '3.11' + tf-version: '2.15.0' steps: - uses: actions/checkout@v2 @@ -36,7 +36,7 @@ jobs: - name: Install TF package run: | - pip install tensorflow==${{ matrix.tf-version }} + pip install tensorflow[and-cuda]==${{ matrix.tf-version }} # Fix proto dep issue in protobuf 4 pip install protobuf==3.20.* diff --git a/setup.py b/setup.py index 0350018e..9ab686b6 100644 --- a/setup.py +++ b/setup.py @@ -55,6 +55,7 @@ def get_version(rel_path): author_email="tf-similarity@google.com", url="https://github.com/tensorflow/similarity", license="Apache License 2.0", + python_requires=">=3.8", install_requires=[ "numpy", "pandas", @@ -88,9 +89,9 @@ def get_version(rel_path): "redis": ["redis"], "faiss": ["faiss-gpu"], "nmslib": ["nmslib"], - "tensorflow": ["tensorflow>=2.7,<=2.11"], - "tensorflow-gpu": ["tensorflow-gpu>=2.7,<=2.11"], - "tensorflow-cpu": ["tensorflow-cpu>=2.7,<=2.11"], + "tensorflow": ["tensorflow>=2.8,<=2.15"], + "tensorflow-gpu": ["tensorflow-gpu>=2.8,<=2.15"], + "tensorflow-cpu": ["tensorflow-cpu>=2.8,<=2.15"], }, classifiers=[ "Development Status :: 5 - Production/Stable", diff --git a/tensorflow_similarity/__init__.py b/tensorflow_similarity/__init__.py index 70da91b1..37bc95e3 100644 --- a/tensorflow_similarity/__init__.py +++ b/tensorflow_similarity/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.18.0.dev12" +__version__ = "0.18.0.dev13" from . import models # noqa diff --git a/tensorflow_similarity/distances/__init__.py b/tensorflow_similarity/distances/__init__.py index e2fb7925..e3fda865 100644 --- a/tensorflow_similarity/distances/__init__.py +++ b/tensorflow_similarity/distances/__init__.py @@ -110,12 +110,13 @@ def get(identifier) -> Distance: Raises: ValueError: If `identifier` cannot be interpreted. """ - if isinstance(identifier, Distance): - return identifier - elif isinstance(identifier, dict): - return deserialize(identifier) + if isinstance(identifier, dict): + identifier = deserialize(identifier) elif isinstance(identifier, str): config = {"class_name": str(identifier), "config": {}} - return deserialize(config) + identifier = deserialize(config) + + if isinstance(identifier, Distance): + return identifier else: raise ValueError("Could not interpret search identifier: {}".format(identifier)) diff --git a/tensorflow_similarity/search/__init__.py b/tensorflow_similarity/search/__init__.py index 4447e485..9fd56ef9 100644 --- a/tensorflow_similarity/search/__init__.py +++ b/tensorflow_similarity/search/__init__.py @@ -117,12 +117,13 @@ def get(identifier, **kwargs) -> Search: Raises: ValueError: If `identifier` cannot be interpreted. """ - if isinstance(identifier, Search): - return identifier - elif isinstance(identifier, dict): - return deserialize(identifier) + if isinstance(identifier, dict): + identifier = deserialize(identifier) elif isinstance(identifier, str): config = {"class_name": str(identifier), "config": kwargs} - return deserialize(config) + identifier = deserialize(config) + + if isinstance(identifier, Search): + return identifier else: raise ValueError("Could not interpret search identifier: {}".format(identifier)) diff --git a/tensorflow_similarity/stores/__init__.py b/tensorflow_similarity/stores/__init__.py index 12c1f4aa..5cf54826 100644 --- a/tensorflow_similarity/stores/__init__.py +++ b/tensorflow_similarity/stores/__init__.py @@ -109,12 +109,13 @@ def get(identifier) -> Store: Raises: ValueError: If `identifier` cannot be interpreted. """ - if isinstance(identifier, Store): - return identifier - elif isinstance(identifier, dict): - return deserialize(identifier) + if isinstance(identifier, dict): + identifier = deserialize(identifier) elif isinstance(identifier, str): config = {"class_name": str(identifier), "config": {}} - return deserialize(config) + identifier = deserialize(config) + + if isinstance(identifier, Store): + return identifier else: raise ValueError("Could not interpret Store identifier: {}".format(identifier)) diff --git a/tests/test_layers.py b/tests/test_layers.py index dc908f8a..7c96731f 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -1,6 +1,7 @@ import math import tensorflow as tf +from tensorflow.keras import layers from tensorflow_similarity.layers import ( GeneralizedMeanPooling1D, @@ -160,29 +161,6 @@ def test_metric_embedding(self): expected_result = tf.constant([[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]]) self.assertAllClose(result, expected_result, rtol=1e-06) - def test_metric_embedding_get_config(self): - me_layer = MetricEmbedding(32) - config = me_layer.get_config() - expected_config = { - "name": "metric_embedding", - "trainable": True, - "dtype": "float32", - "units": 32, - "activation": "linear", - "use_bias": True, - "kernel_initializer": { - "class_name": "GlorotUniform", - "config": {"seed": None}, - }, - "bias_initializer": {"class_name": "Zeros", "config": {}}, - "kernel_regularizer": None, - "bias_regularizer": None, - "activity_regularizer": None, - "kernel_constraint": None, - "bias_constraint": None, - } - self.assertEqual(expected_config, config) - if __name__ == "__main__": tf.test.main()