Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

examples: add .obj sphere motion simulation #20

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/configs/spheres.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ xray:
noise: False # Whether to simulate noise
num_darks: 100
num_flats: 100
num_projections_per_image: 1 # Superimpose this many projections in one (motion blur)
every_nth_projection: 1 # Read every nth-projection
use_capillary: True # Project also capillary
source:
drift: False
num_periods: 1 # Source moves along a sine vertically, this specifies number of periods it makes
Expand Down
87 changes: 69 additions & 18 deletions examples/mesh_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Laminography data set generation with mesh geometry."""
import imageio
import itertools
import glob
import logging
import os
import time
Expand All @@ -39,7 +40,7 @@ def make_projection(shape, ps, axis, mesh, center, lamino_angle, tomo_angle, ss=
if axis == "z":
lamino_angle = lamino_angle + 90 * q.deg
tomo_angle = -tomo_angle
axis = Y_AX if axis == "y" else Z_AX
axis = X_AX if axis == "y" else Z_AX
mesh.clear_transformation()
mesh.translate(center)
mesh.rotate(lamino_angle, X_AX)
Expand All @@ -54,30 +55,60 @@ def make_projection(shape, ps, axis, mesh, center, lamino_angle, tomo_angle, ss=
if ss > 1:
projection = bin_image(projection, orig_shape, average=True)

return projection.get()
return projection.get().T


def read_mesh(filename, iterations=1, mesh_pixel_size=None):
from syris.bodies.mesh import Mesh, read_blender_obj
from syris.geometry import Trajectory

path, ext = os.path.splitext(filename)
if ext == ".obj":
tmp = read_blender_obj(filename)
else:
tmp = np.load(filename)

tri = np.copy(tmp)
tri[0, :] = tmp[1, :]
tri[1, :] = tmp[0, :]

if mesh_pixel_size:
tri = tri * mesh_pixel_size * q.nm
else:
tri = tri * q.um
tr = Trajectory([(0, 0, 0)] * q.um)

return Mesh(tri, tr, center=None, iterations=iterations)


def scan(
shape,
ps,
axis,
mesh,
mesh_filename,
angles,
prefix,
lamino_angle=45 * q.deg,
index=0,
num_devices=1,
shift_coeff=1e4,
ss=1,
mesh_pixel_size=None,
num_meshes=1,
supersampling_projection=1,
):
"""Make a scan of tomographic angles. *shift_coeff* is the coefficient multiplied by pixel size
which shifts the triangles to get rid of faulty pixels.
"""
psm = ps.simplified.magnitude
log_fmt = "{}: {:>04}/{:>04} in {:6.2f} s, angle: {:>6.2f} deg, maxima: {}"
if os.path.isfile(mesh_filename):
mesh_filenames = [mesh_filename]
else:
mesh_filenames = sorted(glob.glob(mesh_filename))

# Move to the middle of the FOV
point = (shape[1] * psm / 2, shape[0] * psm / 2, 0) * q.m
point = (0, shape[1] * psm / 2, 0) * q.m
if index == 0:
LOG.info("Mesh shift: {}".format(point.rescale(q.um)))
LOG.info("Mesh shift in pixels: {}".format((point / ps).simplified.magnitude))
Expand All @@ -94,7 +125,17 @@ def scan(
checked_indices = []
# Projections which exceed the allowed metric difference even after more iterations
bad_indices = []
i_mesh = None
for i, angle in mine:
if i * num_meshes // num_angles != i_mesh:
i_mesh = i * num_meshes // num_angles
mesh = read_mesh(
mesh_filenames[i_mesh],
iterations=supersampling_projection,
mesh_pixel_size=mesh_pixel_size,
)
with LOCK:
LOG.info("i: %d, reading mesh %d", i, i_mesh)
st = time.time()
projs = [make_projection(shape, ps, axis, mesh, point, lamino_angle, angle, ss=ss)]
max_vals = [projs[-1].max()]
Expand Down Expand Up @@ -174,21 +215,14 @@ def make_ground_truth(args, shape, mesh):

def process(args, device_index):
import syris
from syris.geometry import Trajectory
from syris.bodies.mesh import Mesh, read_blender_obj

syris.init(
device_index=device_index, logfile=args.logfile, double_precision=args.double_precision
)
path, ext = os.path.splitext(args.input)
if ext == ".obj":
tri = read_blender_obj(args.input)
else:
tri = np.load(args.input)
tri = tri * q.um

tr = Trajectory([(0, 0, 0)] * q.um)
mesh = Mesh(tri, tr, center=None, iterations=args.supersampling_projection)
mesh = read_mesh(
args.input if os.path.isfile(args.input) else sorted(glob.glob(args.input))[0],
iterations=args.supersampling_projection
)

if args.n:
n = args.n
Expand Down Expand Up @@ -219,19 +253,26 @@ def process(args, device_index):
shape,
args.pixel_size,
args.rotation_axis,
mesh,
args.input,
angles,
args.prefix,
lamino_angle=args.lamino_angle,
index=device_index,
num_devices=args.num_devices,
ss=args.supersampling,
num_meshes=args.num_meshes,
supersampling_projection=args.supersampling_projection,
mesh_pixel_size=args.mesh_pixel_size,
)


def parse_args():
parser = get_default_parser(__doc__)
parser.add_argument("input", type=str, help="Blender .obj input file name")
parser.add_argument(
"input",
type=str,
help="Input file name or file names prefix"
)
parser.add_argument("--n", type=int, help="Number of pixels")
parser.add_argument(
"--supersampling-projection",
Expand All @@ -253,6 +294,9 @@ def parse_args():
parser.add_argument(
"--pixel-size", type=float, default=[750.0], nargs="+", help="Pixel size in nm"
)
parser.add_argument(
"--mesh-pixel-size", type=float, help="Physical mesh pixel size in nm"
)
parser.add_argument(
"--rotation-angle", type=float, default=180, help="Total rotation angle in degrees"
)
Expand All @@ -270,6 +314,9 @@ def parse_args():
parser.add_argument(
"--num-devices", type=int, default=1, help="Number of compute devices to use"
)
parser.add_argument(
"--num-meshes", type=int, default=1, help="Number of meshes / rotation angle"
)
parser.add_argument(
"--supersampling",
type=int,
Expand All @@ -294,6 +341,10 @@ def parse_args():

def main():
args = parse_args()

if os.path.isfile(args.input) and args.num_meshes > 1:
raise ValueError("--num-meshes > 1 can be used only if more meshes are specified")

combinations = list(
itertools.product(
args.lamino_angle, args.pixel_size, args.rotation_axis, args.supersampling
Expand All @@ -306,7 +357,7 @@ def main():
image_directory = "projections"
file_prefix = image_directory[:-1]

file_prefix += "_{:>04}.tif"
file_prefix += "_{:>06}.tif"

devices = list(range(args.num_devices))
pool = Pool(processes=args.num_devices)
Expand Down
102 changes: 53 additions & 49 deletions examples/spheres.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,33 +240,10 @@ def get_camera_image(image, camera, xray_gain, noise=False):
)


def get_low_resolution_image(
hd_image,
supersampling,
camera,
xray_gain,
max_intensity,
noise=False,
spots_image=None
):
n = hd_image.shape[1] // supersampling
image = decimate(
hd_image,
(n, n),
sigma=fwnm_to_sigma(supersampling, n=2),
average=False
).get()
image = get_camera_image(image, camera, xray_gain, noise=noise)
if spots_image is not None:
image = np.clip(image + spots_image * max_intensity, 0, max_intensity)

return image


def create_xray_projections(common):
args = common.xray
syris.init()
projection_filenames = sorted(glob.glob(args.projections_fmt))
projection_filenames = sorted(glob.glob(args.projections_fmt))[::args.every_nth_projection]
n_hd = imageio.imread(projection_filenames[0]).shape[0]
supersampling = n_hd // common.n
# Compute grid
Expand Down Expand Up @@ -357,8 +334,10 @@ def create_xray_projections(common):
if args.spots_filename:
spots_image = imageio.imread(args.spots_filename) if args.spots_filename else None

image_acc = None
flats_done = False
max_flat = None
max_intensity = None
# Projections
for i, filename in tqdm.tqdm(enumerate(projection_filenames)):
# Flat field
Expand All @@ -369,21 +348,26 @@ def create_xray_projections(common):
t=i / num_projections * source_traj.time if args.source.drift else 0 * q.s
)) ** 2
flat_hd = flat_hd / cl_array.max(flat_hd) * args.max_absorbed_photons / supersampling ** 2
flat_ld = decimate(
flat_hd,
(common.n, common.n),
sigma=fwnm_to_sigma(supersampling, n=2),
average=False
).get()
if max_flat is None:
max_flat = cl_array.max(flat_hd).get() * xray_gain * camera.gain * supersampling ** 2
max_intensity = 1.2 * max_flat
print("Max flat value:", max_flat)
if not flats_done:
flat_ld = get_low_resolution_image(
flat_hd,
supersampling,
camera,
xray_gain,
max_flat * 1.2,
noise=args.noise,
spots_image=spots_image
)
flat_ld_save = get_camera_image(flat_ld, camera, xray_gain, noise=args.noise)
if spots_image is not None:
flat_ld_save = np.clip(
flat_ld_save + spots_image * max_intensity,
0,
max_intensity
)
if i < args.num_flats:
flats.append(flat_ld[y_cutoff:-y_cutoff])
flats.append(flat_ld_save[y_cutoff:-y_cutoff])
else:
imageio.volwrite(
os.path.join(output_directory, f"flats{args.output_suffix}.tif"),
Expand All @@ -393,23 +377,43 @@ def create_xray_projections(common):

# Sample
spheres = imageio.imread(filename)
projection = spheres * ps_hd
projection = spheres * q.m
sample = StaticBody(projection, ps_hd, material=material)
samples = [sample, capillary] if args.use_capillary else [sample]
# Propagation with a monochromatic plane incident wave
hd = propagate([capillary, sample], shape, [energy], propagation_distance, ps_hd)
proj = get_low_resolution_image(
flat_hd * hd,
supersampling,
camera,
xray_gain,
max_flat * 1.2,
noise=args.noise,
spots_image=spots_image
)
imageio.imwrite(
os.path.join(projs_dir, "projection-{:>05}.tif".format(i)),
proj.astype(np.float32)[y_cutoff:-y_cutoff]
)
hd = propagate(samples, shape, [energy], propagation_distance, ps_hd)
image = decimate(
hd,
(common.n, common.n),
sigma=fwnm_to_sigma(supersampling, n=2),
average=True
).get()
if image_acc is None:
image_acc = image
else:
image_acc += image
if (i + 1) % args.num_projections_per_image == 0:
image_acc = get_camera_image(
flat_ld * image_acc / args.num_projections_per_image,
camera,
xray_gain,
noise=args.noise
)
if spots_image is not None:
image_acc = np.clip(
image_acc + spots_image * max_intensity,
0,
max_intensity
)

imageio.imwrite(
os.path.join(
projs_dir,
"projection-{:>05}.tif".format(i // args.num_projections_per_image)
),
image_acc.astype(np.float32)[y_cutoff:-y_cutoff]
)
image_acc = None


@hydra.main(version_base=None, config_path="configs", config_name="spheres")
Expand Down
2 changes: 1 addition & 1 deletion syris/bodies/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def max_triangle_x_diff(self):
x_2 = self._current[0, 2::3]
d_0 = np.max(np.abs(x_1 - x_0))
d_1 = np.max(np.abs(x_1 - x_2))
d_2 = np.max(np.abs(x_2 - x_1))
d_2 = np.max(np.abs(x_2 - x_0))

return max(d_0, d_1, d_2) * q.um

Expand Down
Loading