From 49a65d948357cfcf0e94b0683467b98858eabcac Mon Sep 17 00:00:00 2001 From: "Alexander V. Hopp" Date: Wed, 6 Nov 2024 09:06:25 +0100 Subject: [PATCH 1/7] Remove as_tensor argument of set_tensors_from_ndarray_1d --- botorch/optim/closures/core.py | 9 ++------- botorch/optim/utils/numpy_utils.py | 8 ++++++-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/botorch/optim/closures/core.py b/botorch/optim/closures/core.py index 694289d7f5..77e20d5ad0 100644 --- a/botorch/optim/closures/core.py +++ b/botorch/optim/closures/core.py @@ -85,7 +85,6 @@ def __init__( closure: Callable[[], tuple[Tensor, Sequence[Tensor | None]]], parameters: dict[str, Tensor], as_array: Callable[[Tensor], npt.NDArray] = None, # pyre-ignore [9] - as_tensor: Callable[[npt.NDArray], Tensor] = torch.as_tensor, get_state: Callable[[], npt.NDArray] = None, # pyre-ignore [9] set_state: Callable[[npt.NDArray], None] = None, # pyre-ignore [9] fill_value: float = 0.0, @@ -99,14 +98,13 @@ def __init__( Expected to correspond with the first `len(parameters)` optional gradient tensors returned by `closure`. as_array: Callable used to convert tensors to ndarrays. - as_tensor: Callable used to convert ndarrays to tensors. get_state: Callable that returns the closure's state as an ndarray. When passed as `None`, defaults to calling `get_tensors_as_ndarray_1d` on `closure.parameters` while passing `as_array` (if given by the user). set_state: Callable that takes a 1-dimensional ndarray and sets the closure's state. When passed as `None`, `set_state` defaults to calling `set_tensors_from_ndarray_1d` with `closure.parameters` and - a given ndarray while passing `as_tensor`. + a given ndarray. fill_value: Fill value for parameters whose gradients are None. In most cases, `fill_value` should either be zero or NaN. persistent: Boolean specifying whether an ndarray should be retained @@ -128,15 +126,12 @@ def __init__( as_array = partial(as_ndarray, dtype=np_float64) if set_state is None: - set_state = partial( - set_tensors_from_ndarray_1d, parameters, as_tensor=as_tensor - ) + set_state = partial(set_tensors_from_ndarray_1d, parameters) self.closure = closure self.parameters = parameters self.as_array = as_ndarray - self.as_tensor = as_tensor self._get_state = get_state self._set_state = set_state diff --git a/botorch/optim/utils/numpy_utils.py b/botorch/optim/utils/numpy_utils.py index bb5d6b9093..bc8fadfe35 100644 --- a/botorch/optim/utils/numpy_utils.py +++ b/botorch/optim/utils/numpy_utils.py @@ -113,7 +113,6 @@ def get_tensors_as_ndarray_1d( def set_tensors_from_ndarray_1d( tensors: Iterator[Tensor] | dict[str, Tensor], array: npt.NDArray, - as_tensor: Callable[[npt.NDArray], Tensor] = torch.as_tensor, ) -> None: r"""Sets the values of one more tensors based off of a vector of assignments.""" named_tensors_iter = ( @@ -125,7 +124,12 @@ def set_tensors_from_ndarray_1d( try: size = tnsr.numel() vals = array[index : index + size] if tnsr.ndim else array[index] - tnsr.copy_(as_tensor(vals).to(tnsr).view(tnsr.shape).to(tnsr)) + tnsr.copy_( + torch.as_tensor(vals, device=tnsr.device, dtype=tnsr.dtype) + .to(tnsr) + .view(tnsr.shape) + .to(tnsr) + ) index += size except Exception as e: raise RuntimeError( From 29b81cfce36c35def087cb91871bbf356770001c Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 6 Nov 2024 09:41:41 -0500 Subject: [PATCH 2/7] Update botorch/optim/utils/numpy_utils.py --- botorch/optim/utils/numpy_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/botorch/optim/utils/numpy_utils.py b/botorch/optim/utils/numpy_utils.py index bc8fadfe35..ee4791b3e4 100644 --- a/botorch/optim/utils/numpy_utils.py +++ b/botorch/optim/utils/numpy_utils.py @@ -126,9 +126,7 @@ def set_tensors_from_ndarray_1d( vals = array[index : index + size] if tnsr.ndim else array[index] tnsr.copy_( torch.as_tensor(vals, device=tnsr.device, dtype=tnsr.dtype) - .to(tnsr) .view(tnsr.shape) - .to(tnsr) ) index += size except Exception as e: From 9cc9ffe055221bb6790bf8d41d2bd2452df80add Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 6 Nov 2024 09:44:26 -0500 Subject: [PATCH 3/7] Update botorch/optim/utils/numpy_utils.py --- botorch/optim/utils/numpy_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/botorch/optim/utils/numpy_utils.py b/botorch/optim/utils/numpy_utils.py index ee4791b3e4..73f09effc9 100644 --- a/botorch/optim/utils/numpy_utils.py +++ b/botorch/optim/utils/numpy_utils.py @@ -125,8 +125,9 @@ def set_tensors_from_ndarray_1d( size = tnsr.numel() vals = array[index : index + size] if tnsr.ndim else array[index] tnsr.copy_( - torch.as_tensor(vals, device=tnsr.device, dtype=tnsr.dtype) - .view(tnsr.shape) + torch.as_tensor(vals, device=tnsr.device, dtype=tnsr.dtype).view( + tnsr.shape + ) ) index += size except Exception as e: From 3ea83850504883eb2e4849e31046f904247c1f5a Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 6 Nov 2024 10:00:41 -0500 Subject: [PATCH 4/7] Update botorch/optim/utils/numpy_utils.py --- botorch/optim/utils/numpy_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/botorch/optim/utils/numpy_utils.py b/botorch/optim/utils/numpy_utils.py index 73f09effc9..7a945bff8a 100644 --- a/botorch/optim/utils/numpy_utils.py +++ b/botorch/optim/utils/numpy_utils.py @@ -125,9 +125,7 @@ def set_tensors_from_ndarray_1d( size = tnsr.numel() vals = array[index : index + size] if tnsr.ndim else array[index] tnsr.copy_( - torch.as_tensor(vals, device=tnsr.device, dtype=tnsr.dtype).view( - tnsr.shape - ) + torch.from_numpy(vals.reshape(tnsr.shape)).to(tnsr) ) index += size except Exception as e: From f6f11f90697bb742544b04587f9db73cce1ea47e Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 6 Nov 2024 10:02:46 -0500 Subject: [PATCH 5/7] Update botorch/optim/utils/numpy_utils.py --- botorch/optim/utils/numpy_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/botorch/optim/utils/numpy_utils.py b/botorch/optim/utils/numpy_utils.py index 7a945bff8a..3e32d6eec4 100644 --- a/botorch/optim/utils/numpy_utils.py +++ b/botorch/optim/utils/numpy_utils.py @@ -124,9 +124,7 @@ def set_tensors_from_ndarray_1d( try: size = tnsr.numel() vals = array[index : index + size] if tnsr.ndim else array[index] - tnsr.copy_( - torch.from_numpy(vals.reshape(tnsr.shape)).to(tnsr) - ) + tnsr.copy_(torch.from_numpy(vals.reshape(tnsr.shape)).to(tnsr)) index += size except Exception as e: raise RuntimeError( From 3557da309522e398a14943c4f196cfc9400e02bb Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 6 Nov 2024 10:11:30 -0500 Subject: [PATCH 6/7] Update botorch/optim/utils/numpy_utils.py --- botorch/optim/utils/numpy_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/botorch/optim/utils/numpy_utils.py b/botorch/optim/utils/numpy_utils.py index 3e32d6eec4..035cfc9e53 100644 --- a/botorch/optim/utils/numpy_utils.py +++ b/botorch/optim/utils/numpy_utils.py @@ -124,7 +124,9 @@ def set_tensors_from_ndarray_1d( try: size = tnsr.numel() vals = array[index : index + size] if tnsr.ndim else array[index] - tnsr.copy_(torch.from_numpy(vals.reshape(tnsr.shape)).to(tnsr)) + torch.as_tensor(vals, device=tnsr.device, dtype=tnsr.dtype).view( + tnsr.shape + ) index += size except Exception as e: raise RuntimeError( From c49de8f3373d44389891259868eb0795dca0e733 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 6 Nov 2024 10:23:15 -0500 Subject: [PATCH 7/7] fix missing copy_ --- botorch/optim/utils/numpy_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/botorch/optim/utils/numpy_utils.py b/botorch/optim/utils/numpy_utils.py index 035cfc9e53..3c5b84dd30 100644 --- a/botorch/optim/utils/numpy_utils.py +++ b/botorch/optim/utils/numpy_utils.py @@ -9,7 +9,6 @@ from __future__ import annotations from collections.abc import Callable, Iterator - from itertools import tee import numpy as np @@ -124,8 +123,10 @@ def set_tensors_from_ndarray_1d( try: size = tnsr.numel() vals = array[index : index + size] if tnsr.ndim else array[index] - torch.as_tensor(vals, device=tnsr.device, dtype=tnsr.dtype).view( - tnsr.shape + tnsr.copy_( + torch.as_tensor(vals, device=tnsr.device, dtype=tnsr.dtype).view( + tnsr.shape + ) ) index += size except Exception as e: