Skip to content

Commit

Permalink
feat: Add convenience method to define variable for all parts
Browse files Browse the repository at this point in the history
  • Loading branch information
tkarabela committed Nov 10, 2024
1 parent a141fa2 commit 86816c7
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 3 deletions.
45 changes: 43 additions & 2 deletions ensightreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,8 @@ def reload_from_file(self) -> None:
Parse file again and update internal metadata
This is meant to be called after writing to the geometry file, to see the changes.
Be sure to flush your writes before calling this method.
"""
geofile = self.from_file_path(
file_path=self.file_path,
Expand Down Expand Up @@ -1315,6 +1317,8 @@ def reload_from_file(self) -> None:
Parse file again and update internal metadata
This is meant to be called after writing to the variable file, to see the changes.
Be sure to flush your writes before calling this method.
"""
variable = self.from_file_path(
file_path=self.file_path,
Expand Down Expand Up @@ -1600,18 +1604,21 @@ def ensure_data_for_part(
self,
fp: BinaryIO,
part_id: int,
default_value: float = 0.0
default_value: float = 0.0,
_reload_from_file: bool = True,
) -> None:
"""
Append variable definition for given part if variable is undefined currently, otherwise do nothing
This method will seek to the end of the file automatically.
See Also:
`EnsightVariableFile.ensure_data_for_all_parts()`
Args:
fp: Opened writable variable file (use `EnsightVariableFile.open_writeable()`, not mmap)
part_id: Part ID to be defined
default_value: Constant value that will be filled in if the variable is not defined
"""
if self.is_defined_for_part_id(part_id):
return
Expand All @@ -1633,6 +1640,37 @@ def ensure_data_for_part(
else:
raise NotImplementedError("unexpected variable location")

if _reload_from_file:
fp.flush()
self.reload_from_file()

def ensure_data_for_all_parts(self, fp: BinaryIO, default_value: float = 0.0) -> None:
"""
Append variable definitions for all parts that are currently undefined
This method will seek to the end of the file automatically.
Usage:
>>> from ensightreader import read_case
>>> case = read_case("sphere.case")
>>> variable = case.get_variable("RTData")
>>> with variable.open_writeable() as fp:
... variable.ensure_data_for_all_parts(fp)
See Also:
`EnsightVariableFile.ensure_data_for_part()`
Args:
fp: Opened writable variable file (use `EnsightVariableFile.open_writeable()`, not mmap)
default_value: Constant value that will be filled in if the variable is not defined
"""
for part_id in self.geometry_file.get_part_ids():
self.ensure_data_for_part(fp, part_id, default_value, _reload_from_file=False)

fp.flush()
self.reload_from_file()

def open(self) -> BinaryIO:
"""
Return the opened file in read-only binary mode (convenience method)
Expand Down Expand Up @@ -2555,6 +2593,9 @@ def define_variable(
Returns:
`EnsightVariableFile` for the new variable
See Also:
`EnsightVariableFile.ensure_data_for_all_parts()`
"""
if variable_name in self.variables:
raise ValueError(f"Variable with name {variable_name!r} is already present in case")
Expand Down
36 changes: 35 additions & 1 deletion tests/test_write_geometry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import itertools

import pytest

import ensightreader
from ensightreader import EnsightGeometryFile, GeometryPart, UnstructuredElementBlock, ElementType, read_case
from ensightreader import EnsightGeometryFile, GeometryPart, UnstructuredElementBlock, ElementType, read_case, \
VariableLocation, VariableType
import numpy as np
import tempfile
import os.path as op
Expand Down Expand Up @@ -355,3 +358,34 @@ def test_append_geometry_and_variables(tmp_path, source_case_path, dest_case_pat

dest_case2 = ensightreader.read_case(op.join(dest_dir, op.basename(dest_case_path)))
assert set(dest_case2.get_variables()) == set(dest_case_original_variables) | set(source_case.get_variables())


@pytest.mark.parametrize(
"variable_type, variable_location",
itertools.product(
[VariableType.SCALAR, VariableType.VECTOR, VariableType.TENSOR_SYMM, VariableType.TENSOR_ASYM],
[VariableLocation.PER_NODE, VariableLocation.PER_ELEMENT],
)
)
def test_ensure_data(tmp_path, variable_type: VariableType, variable_location: VariableLocation):
case_dir = tmp_path / "cavity"
shutil.copytree(op.dirname(CAVITY_CASE_PATH), case_dir)
case = ensightreader.read_case(op.join(case_dir, op.basename(CAVITY_CASE_PATH)))
my_variable = case.define_variable(
variable_location,
variable_type,
"my_variable",
"my_variable.bin"
)
with my_variable.open_writeable() as fp:
my_variable.ensure_data_for_all_parts(fp, 3.14)

with my_variable.mmap() as mm:
for part_id in case.get_geometry_model().get_part_ids():
if variable_location == VariableLocation.PER_NODE:
arr = my_variable.read_node_data(mm, part_id)
assert arr.min() == 3.14 and arr.max() == 3.14
else:
for block in case.get_geometry_model().get_part_by_id(part_id).element_blocks:
arr = my_variable.read_element_data(mm, part_id, block.element_type)
assert arr.min() == 3.14 and arr.max() == 3.14

0 comments on commit 86816c7

Please sign in to comment.