From 2cb66db9750fca59a3350f556858f7efded4ced2 Mon Sep 17 00:00:00 2001 From: Sergey O Date: Sun, 9 Oct 2022 19:18:01 -0400 Subject: [PATCH] bugfix for latest jax, clear_mem() was giving error --- colabdesign/shared/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/colabdesign/shared/utils.py b/colabdesign/shared/utils.py index f3f59968..fd5aa237 100644 --- a/colabdesign/shared/utils.py +++ b/colabdesign/shared/utils.py @@ -19,7 +19,10 @@ def clear_mem(): for obj_name in dir(module): obj = getattr(module, obj_name) if hasattr(obj, "cache_clear"): - obj.cache_clear() + try: + obj.cache_clear() + except: + pass gc.collect() def update_dict(D, *args, **kwargs): @@ -116,4 +119,4 @@ def copy_missing(a,b): if i not in b: b[i] = v elif isinstance(v,dict): - copy_missing(v,b[i]) \ No newline at end of file + copy_missing(v,b[i])