Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: auto batch size supports methods that return a dict #3626

Merged
merged 2 commits into from
Mar 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 61 additions & 15 deletions deepmd/pt/utils/auto_batch_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,28 @@


class AutoBatchSize(AutoBatchSizeBase):
"""Auto batch size.

Parameters
----------
initial_batch_size : int, default: 1024
initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE
is not set
factor : float, default: 2.
increased factor

"""

def __init__(

Check warning on line 27 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L27

Added line #L27 was not covered by tests
self,
initial_batch_size: int = 1024,
factor: float = 2.0,
):
super().__init__(

Check warning on line 32 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L32

Added line #L32 was not covered by tests
initial_batch_size=initial_batch_size,
factor=factor,
)

def is_gpu_available(self) -> bool:
"""Check if GPU is available.

Expand Down Expand Up @@ -78,26 +100,50 @@
)

index = 0
results = []
results = None
returned_dict = None

Check warning on line 104 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L103-L104

Added lines #L103 - L104 were not covered by tests
while index < total_size:
n_batch, result = self.execute(execute_with_batch_size, index, natoms)
if not isinstance(result, tuple):
result = (result,)
returned_dict = (

Check warning on line 107 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L107

Added line #L107 was not covered by tests
isinstance(result, dict) if returned_dict is None else returned_dict
)
if not returned_dict:
result = (result,) if not isinstance(result, tuple) else result

Check warning on line 111 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L110-L111

Added lines #L110 - L111 were not covered by tests
index += n_batch
if n_batch:
for rr in result:
rr.reshape((n_batch, -1))
results.append(result)
r_list = []
for r in zip(*results):

def append_to_list(res_list, res):
if n_batch:
res_list.append(res)
anyangml marked this conversation as resolved.
Show resolved Hide resolved
return res_list

Check warning on line 117 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L114-L117

Added lines #L114 - L117 were not covered by tests

if not returned_dict:
results = [] if results is None else results
results = append_to_list(results, result)

Check warning on line 121 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L119-L121

Added lines #L119 - L121 were not covered by tests
else:
results = (

Check warning on line 123 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L123

Added line #L123 was not covered by tests
{kk: [] for kk in result.keys()} if results is None else results
)
results = {

Check warning on line 126 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L126

Added line #L126 was not covered by tests
kk: append_to_list(results[kk], result[kk]) for kk in result.keys()
}
assert results is not None
assert returned_dict is not None

Check warning on line 130 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L129-L130

Added lines #L129 - L130 were not covered by tests

def concate_result(r):

Check warning on line 132 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L132

Added line #L132 was not covered by tests
if isinstance(r[0], np.ndarray):
r_list.append(np.concatenate(r, axis=0))
ret = np.concatenate(r, axis=0)

Check warning on line 134 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L134

Added line #L134 was not covered by tests
elif isinstance(r[0], torch.Tensor):
r_list.append(torch.cat(r, dim=0))
ret = torch.cat(r, dim=0)

Check warning on line 136 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L136

Added line #L136 was not covered by tests
else:
raise RuntimeError(f"Unexpected result type {type(r[0])}")
r = tuple(r_list)
if len(r) == 1:
# avoid returning tuple if callable doesn't return tuple
r = r[0]
return ret

Check warning on line 139 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L139

Added line #L139 was not covered by tests

if not returned_dict:
r_list = [concate_result(r) for r in zip(*results)]
r = tuple(r_list)
if len(r) == 1:

Check warning on line 144 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L141-L144

Added lines #L141 - L144 were not covered by tests
# avoid returning tuple if callable doesn't return tuple
r = r[0]

Check warning on line 146 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L146

Added line #L146 was not covered by tests
else:
r = {kk: concate_result(vv) for kk, vv in results.items()}

Check warning on line 148 in deepmd/pt/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/auto_batch_size.py#L148

Added line #L148 was not covered by tests
return r
37 changes: 37 additions & 0 deletions source/tests/pt/test_auto_batch_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

import numpy as np

from deepmd.pt.utils.auto_batch_size import (
AutoBatchSize,
)


class TestAutoBatchSize(unittest.TestCase):
    """Tests for AutoBatchSize.execute_all with tuple- and dict-returning callables."""

    def test_execute_all(self):
        """A callable returning a tuple is batched and concatenated back to full arrays."""
        expected_zeros = np.zeros((10000, 2, 1, 3, 4))
        expected_ones = np.ones((10000, 2, 1, 3, 4))
        batcher = AutoBatchSize(256, 2.0)

        def func(dd1):
            # Each batch yields a (zeros, ones) pair matching the input slice shape.
            return np.zeros_like(dd1), np.ones_like(dd1)

        result = batcher.execute_all(func, 10000, 2, expected_ones)
        np.testing.assert_equal(expected_zeros, result[0])
        np.testing.assert_equal(expected_ones, result[1])

    def test_execute_all_dict(self):
        """A callable returning a dict is batched and concatenated per key."""
        expected_zeros = np.zeros((10000, 2, 1, 3, 4))
        expected_ones = np.ones((10000, 2, 1, 3, 4))
        batcher = AutoBatchSize(256, 2.0)

        def func(dd1):
            # Dict results must be stitched together key-by-key.
            return {
                "foo": np.zeros_like(dd1),
                "bar": np.ones_like(dd1),
            }

        result = batcher.execute_all(func, 10000, 2, expected_ones)
        np.testing.assert_equal(expected_zeros, result["foo"])
        np.testing.assert_equal(expected_ones, result["bar"])