import os
import sys
import json

import imageio
import numpy as np
import tensorflow as tf


def get_similar_k(pose, pose_set, img_set, top_size=None, num_from_top=1, k=2):
    """Pick poses (and their images) whose column 3 points most nearly the same way as ``pose``'s.

    Similarity is the cosine similarity between ``pose[:, 3]`` and each
    ``pose_set[i, :, 3]`` (NOTE(review): presumably the camera position column of
    a camera-to-world matrix — confirm).

    Args:
        pose: a single pose matrix; only column 3 is used.
        pose_set: stack of pose matrices indexed along axis 0.
        img_set: images aligned with ``pose_set`` along axis 0.
        top_size: if None, return the ``k`` most similar entries. Otherwise draw
            ``num_from_top`` distinct similarity ranks uniformly from
            ``[1, top_size)``; rank 0 (the single most similar entry) is never drawn.
        num_from_top: number of entries drawn when ``top_size`` is given.
        k: number of entries returned when ``top_size`` is None.

    Returns:
        ``(poses, images)``: poses gathered with ``tf.gather`` and images taken
        with ``np.take`` (so ``img_set`` stays on the host).
    """
    vp = pose[:, 3]
    vp_set = pose_set[:, :, 3]
    vp_set_norm = tf.norm(vp_set, axis=-1)[..., None]
    vp_norm = tf.norm(vp, axis=-1)
    # Cosine similarity between the query column and every candidate column.
    simil = tf.reduce_sum((vp / vp_norm) * (vp_set / vp_set_norm), -1)
    sorted_inds = tf.argsort(simil, direction='DESCENDING')
    if top_size is None:
        return (tf.gather(pose_set, sorted_inds[:k]),
                np.take(img_set, sorted_inds[:k], axis=0))
    # Rank 0 is skipped (presumably the query itself when it is in pose_set — confirm).
    rand_idxs = np.random.choice(np.arange(1, top_size), num_from_top, replace=False)
    sorted_inds = np.take(sorted_inds, rand_idxs, axis=0)
    # use np.take so that img_set is not all copied to GPU
    return tf.gather(pose_set, sorted_inds), np.take(img_set, sorted_inds, axis=0)


# Misc utils

def img2mse(x, y):
    """Mean squared error between two tensors."""
    return tf.reduce_mean(tf.square(x - y))


def mse2psnr(x):
    """Convert an MSE value to PSNR in dB, assuming a peak signal value of 1.

    Computes ``-10 * log10(x)``. Uses ``tf.math.log`` because the top-level
    ``tf.log`` alias is not available in TensorFlow 2.x.
    """
    return -10. * tf.math.log(x) / tf.math.log(10.)


def to8b(x):
    """Clip ``x`` to [0, 1], scale by 255 and cast (truncating) to uint8."""
    return (255 * np.clip(x, 0, 1)).astype(np.uint8)


def load_intrinsic(filename):
    """Build a 3x3 pinhole intrinsic matrix from a whitespace-separated text file.

    The first value is used for both focal entries (``[0, 0]`` and ``[1, 1]``),
    the next two fill the principal point ``[:2, 2]``. Any further values in the
    file (possibly image height/width — TODO confirm) are not used here.
    """
    with open(filename) as f:
        nums = [float(x) for x in f.read().split()]
    intrinsic = np.zeros((3, 3))
    intrinsic[0, 0] = nums[0]
    intrinsic[1, 1] = nums[0]
    intrinsic[:2, 2] = nums[1:3]
    intrinsic[2, 2] = 1
    return intrinsic


def parse_attributes(logdir):
    """Read ``near far H W`` from ``<logdir>/scene_attributes.txt``.

    Returns:
        ``(near, far, H, W)`` as ``(float, float, int, int)``.
    """
    with open(os.path.join(logdir, "scene_attributes.txt"), 'r') as f:
        # split() on any whitespace tolerates newlines / repeated spaces.
        near, far, H, W = f.read().split()
    return float(near), float(far), int(H), int(W)


def preprocess_images(im):
    """Convert images to float32 in [0, 1] by dividing by 255.

    If the last axis has 4 channels, the RGB part is alpha-composited over a
    white background (``rgb * alpha + (1 - alpha)``) and alpha is dropped.
    """
    im = tf.cast(im, tf.float32) / 255.
    if im.shape[-1] == 4:
        alpha = im[..., 3, None]
        im = im[..., :3] * alpha + (1. - alpha)
    return im


# Positional encoding

class Embedder:
    """NeRF-style positional encoding: concatenates ``x`` (optionally) with
    ``p_fn(x * freq)`` for every frequency band and every periodic function.

    Expected kwargs: ``input_dims``, ``include_input``, ``max_freq_log2``,
    ``num_freqs``, ``log_sampling``, ``periodic_fns``.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build ``self.embed_fns`` and the resulting feature size ``self.out_dim``."""
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            # Frequencies 2**0 ... 2**max_freq, evenly spaced in log2.
            freq_bands = 2. ** tf.linspace(0., max_freq, N_freqs)
        else:
            freq_bands = tf.linspace(2. ** 0., 2. ** max_freq, N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                # Default args bind the current p_fn/freq (avoids late binding).
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        """Apply every embedding function and concatenate along the last axis."""
        return tf.concat([fn(inputs) for fn in self.embed_fns], -1)


def get_embedder(multires, i=0, input_dims=3):
    """Return ``(embed_fn, out_dim)`` for a sin/cos positional encoding.

    Args:
        multires: number of frequency bands; frequencies are 2**0 ... 2**(multires-1).
        i: -1 disables the encoding (identity, output dim equals ``input_dims``).
        input_dims: size of the last axis of the inputs.
    """
    if i == -1:
        # Identity keeps the input size, whatever it is.
        return tf.identity, input_dims

    embed_kwargs = {
        'include_input': True,
        'input_dims': input_dims,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [tf.math.sin, tf.math.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)

    def embed(x, eo=embedder_obj):
        return eo.embed(x)

    return embed, embedder_obj.out_dim


# Ray helpers

def get_rays(H, W, focal, c2w):
    """Get ray origins, directions from a pinhole camera.

    Pixel directions use +x right, +y up (image rows flipped) and -1 along z,
    then are rotated by ``c2w[:3, :3]``; origins are ``c2w[:3, -1]`` broadcast to
    every pixel. Both outputs have shape ``(H, W, 3)``.
    """
    i, j = tf.meshgrid(tf.range(W, dtype=tf.float32),
                       tf.range(H, dtype=tf.float32), indexing='xy')
    dirs = tf.stack([(i - W * .5) / focal, -(j - H * .5) / focal, -tf.ones_like(i)], -1)
    rays_d = tf.reduce_sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
    rays_o = tf.broadcast_to(c2w[:3, -1], tf.shape(rays_d))
    return rays_o, rays_d


def get_random_ray_direction(H, W, focal, c2w):
    """Return the rotated ray direction (shape ``(3,)``) through one random pixel.

    The pixel column is drawn from ``[0, W)`` and the row from ``[0, H)`` with
    ``np.random``; the camera convention matches ``get_rays``.
    """
    i, j = np.random.randint(W), np.random.randint(H)
    direction = np.array([(i - W * .5) / focal, -(j - H * .5) / focal, -1])
    ray_d = tf.reduce_sum(direction[..., np.newaxis, :] * c2w[:3, :3], -1)
    return ray_d


def get_rays_np(H, W, focal, c2w):
    """Get ray origins, directions from a pinhole camera.

    NumPy counterpart of ``get_rays``; both outputs have shape ``(H, W, 3)``.
    """
    i, j = np.meshgrid(np.arange(W, dtype=np.float32),
                       np.arange(H, dtype=np.float32), indexing='xy')
    dirs = np.stack([(i - W * .5) / focal, -(j - H * .5) / focal, -np.ones_like(i)], -1)
    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
    rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d))
    return rays_o, rays_d


def ndc_rays(H, W, focal, near, rays_o, rays_d):
    """Normalized device coordinate rays.

    Space such that the canvas is a cube with sides [-1, 1] in each axis.

    Args:
      H: int. Height in pixels.
      W: int. Width in pixels.
      focal: float. Focal length of pinhole camera.
      near: float or array of shape[batch_size]. Near depth bound for the scene.
      rays_o: array of shape [batch_size, 3]. Camera origin.
      rays_d: array of shape [batch_size, 3]. Ray direction.

    Returns:
      rays_o: array of shape [batch_size, 3]. Camera origin in NDC.
      rays_d: array of shape [batch_size, 3]. Ray direction in NDC.
    """
    # Shift ray origins to near plane
    t = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d

    # Projection
    o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
    o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1. + 2. * near / rays_o[..., 2]

    d0 = -1. / (W / (2. * focal)) * \
        (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
    d1 = -1. / (H / (2. * focal)) * \
        (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
    d2 = -2. * near / rays_o[..., 2]

    rays_o = tf.stack([o0, o1, o2], -1)
    rays_d = tf.stack([d0, d1, d2], -1)

    return rays_o, rays_d


# Hierarchical sampling helper

def sample_pdf(bins, weights, N_samples, det=False):
    """Draw ``N_samples`` values per row by inverting the CDF of ``weights``.

    The CDF is piecewise linear over ``bins``; ``bins`` must therefore have one
    more entry than ``weights`` along the last axis (indices up to
    ``weights.shape[-1]`` are gathered from it).

    Args:
        bins: bin edges, ``[..., M + 1]``.
        weights: non-negative per-bin weights, ``[..., M]``. Not modified.
        N_samples: number of samples per row.
        det: if True use evenly spaced ``u`` in [0, 1]; otherwise uniform random.

    Returns:
        Samples of shape ``[..., N_samples]``.
    """
    # Get pdf. Rebind instead of `+=` so a NumPy `weights` argument is not
    # modified in place.
    weights = weights + 1e-5  # prevent nans
    pdf = weights / tf.reduce_sum(weights, -1, keepdims=True)
    cdf = tf.cumsum(pdf, -1)
    cdf = tf.concat([tf.zeros_like(cdf[..., :1]), cdf], -1)

    # Take uniform samples
    if det:
        u = tf.linspace(0., 1., N_samples)
        u = tf.broadcast_to(u, list(cdf.shape[:-1]) + [N_samples])
    else:
        u = tf.random.uniform(list(cdf.shape[:-1]) + [N_samples])

    # Invert CDF
    inds = tf.searchsorted(cdf, u, side='right')
    below = tf.maximum(0, inds - 1)
    above = tf.minimum(cdf.shape[-1] - 1, inds)
    inds_g = tf.stack([below, above], -1)
    cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape) - 2)
    bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape) - 2)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    # Near-empty CDF steps would divide by ~0; fall back to 1.
    denom = tf.where(denom < 1e-5, tf.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples


def config_parser():
    """Build the ``configargparse`` parser holding every experiment option."""
    import configargparse
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True,
                        help='config file path')
    parser.add_argument("--expname", type=str,
                        help='experiment name')
    parser.add_argument("--basedir", type=str, default='./logs/',
                        help='where to store ckpts and logs')
    parser.add_argument("--datadir", type=str, default='./data/llff/fern',
                        help='input data directory')

    # training options
    parser.add_argument("--netdepth", type=int, default=8,
                        help='layers in network')
    parser.add_argument("--netwidth", type=int, default=256,
                        help='channels per layer')
    parser.add_argument("--netdepth_fine", type=int, default=8,
                        help='layers in fine network')
    parser.add_argument("--netwidth_fine", type=int, default=256,
                        help='channels per layer in fine network')
    parser.add_argument("--N_rand", type=int, default=32*32*4,
                        help='batch size (number of random rays per gradient step)')
    parser.add_argument("--lrate", type=float, default=5e-4,
                        help='learning rate')
    parser.add_argument("--lrate_decay", type=int, default=250,
                        help='exponential learning rate decay (in 1000s)')
    parser.add_argument("--chunk", type=int, default=1024*32,
                        help='number of rays processed in parallel, decrease if running out of memory')
    parser.add_argument("--netchunk", type=int, default=1024*64,
                        help='number of pts sent through network in parallel, decrease if running out of memory')
    parser.add_argument("--no_batching", action='store_true',
                        help='only take random rays from 1 image at a time')
    parser.add_argument("--no_reload", action='store_true',
                        help='do not reload weights from saved ckpt')
    parser.add_argument("--ft_path", type=str, default=None,
                        help='specific weights npy file to reload for coarse network')
    parser.add_argument("--random_seed", type=int, default=None,
                        help='fix random seed for repeatability')

    # pre-crop options
    parser.add_argument("--precrop_iters", type=int, default=0,
                        help='number of steps to train on central crops')
    parser.add_argument("--precrop_frac", type=float, default=.5,
                        help='fraction of img taken for central crops')

    # rendering options
    parser.add_argument("--N_samples", type=int, default=64,
                        help='number of coarse samples per ray')
    parser.add_argument("--N_importance", type=int, default=0,
                        help='number of additional fine samples per ray')
    parser.add_argument("--perturb", type=float, default=1.,
                        help='set to 0. for no jitter, 1. for jitter')
    parser.add_argument("--use_viewdirs", action='store_true',
                        help='use full 5D input instead of 3D')
    parser.add_argument("--i_embed", type=int, default=0,
                        help='set 0 for default positional encoding, -1 for none')
    parser.add_argument("--multires", type=int, default=10,
                        help='log2 of max freq for positional encoding (3D location)')
    parser.add_argument("--multires_views", type=int, default=4,
                        help='log2 of max freq for positional encoding (2D direction)')
    parser.add_argument("--raw_noise_std", type=float, default=0.,
                        help='std dev of noise added to regularize sigma_a output, 1e0 recommended')

    parser.add_argument("--render_only", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')
    parser.add_argument("--render_test", action='store_true',
                        help='render the test set instead of render_poses path')
    parser.add_argument("--render_factor", type=int, default=0,
                        help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')

    # dataset options
    parser.add_argument("--dataset_type", type=str, default='llff',
                        help='options: llff / blender / deepvoxels / shapenet')
    parser.add_argument("--testskip", type=int, default=8,
                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')

    # deepvoxels flags
    parser.add_argument("--shape", type=str, default='greek',
                        help='options : armchair / cube / greek / vase')

    # blender flags
    parser.add_argument("--white_bkgd", action='store_true',
                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')
    parser.add_argument("--half_res", action='store_true',
                        help='load blender synthetic data at 400x400 instead of 800x800')

    # llff flags
    parser.add_argument("--factor", type=int, default=8,
                        help='downsample factor for LLFF images')
    parser.add_argument("--no_ndc", action='store_true',
                        help='do not use normalized device coordinates (set for non-forward facing scenes)')
    parser.add_argument("--lindisp", action='store_true',
                        help='sampling linearly in disparity rather than depth')
    parser.add_argument("--spherify", action='store_true',
                        help='set for spherical 360 scenes')
    parser.add_argument("--llffhold", type=int, default=8,
                        help='will take every 1/N images as LLFF test set, paper uses 8')

    # logging/saving options
    parser.add_argument("--i_print", type=int, default=100,
                        help='frequency of console printout and metric loggin')
    parser.add_argument("--i_img", type=int, default=500,
                        help='frequency of tensorboard image logging')
    parser.add_argument("--i_weights", type=int, default=10000,
                        help='frequency of weight ckpt saving')
    parser.add_argument("--i_testset", type=int, default=50000,
                        help='frequency of testset saving')
    parser.add_argument("--i_video", type=int, default=50000,
                        help='frequency of render_poses video saving')

    # flags for attention model
    parser.add_argument("--attention_direction_multires", type=int, default=10,
                        help='frequency of embedding for value')
    parser.add_argument("--attention_view_multires", type=int, default=4,
                        help='frequency of embedding for direction')
    parser.add_argument("--render_per_scene", type=int, default=-1,
                        help="number of views to render per scene, -1 for all")
    parser.add_argument("--training_recon", action='store_true',
                        help="perform testing with training set as input, if not passed, perform one two shot")
    parser.add_argument("--use_quaternion", action='store_true',
                        help="use quaternion")
    # NOTE(review): the help texts of the next three flags look inverted or
    # copy-pasted relative to the flag names (e.g. --use_attsets repeats the
    # --no_render_pose text) — confirm the intended wording.
    parser.add_argument("--no_globl", action='store_true',
                        help="use global features in unet")
    parser.add_argument("--no_render_pose", action='store_true',
                        help="use rendered pose in unet")
    parser.add_argument("--use_attsets", action='store_true',
                        help="use rendered pose in unet")
    parser.add_argument("--from_scene", type=int, default=0,
                        help="scene to start rendering from in sorted order")
    parser.add_argument("--to_scene", type=int, default=-1,
                        help="scene to end rendering from in sorted order")

    return parser


def invert(mat):
    """Invert a pose or extrinsic matrix ``[R | t]`` as ``[R^T | -R^T t]``.

    Only the top 3x4 part of ``mat`` is read and the result has shape
    ``[..., 3, 4]``. Using the transpose as the inverse is valid only when the
    upper-left 3x3 block is a rotation (orthonormal).

    Any number of leading batch dimensions is supported, including none.
    """
    rot = mat[..., :3, :3]
    trans = mat[..., :3, 3, None]
    # matrix_transpose swaps only the last two axes, unlike a fixed [0, 2, 1]
    # permutation that required exactly one batch dimension.
    rot_t = tf.linalg.matrix_transpose(rot)
    trans_t = -1 * rot_t @ trans
    return tf.concat([rot_t, trans_t], -1)


@tf.function
def ndc2world(H, W, focal, near, ndc_pts):
    """Map NDC points back to camera-space points (inverse of ``ndc_rays``' point mapping).

    Recovers depth from the NDC z value via ``z = 2 * near / (z_ndc - 1)`` and
    un-projects x/y with the same ``W / (2 * focal)`` / ``H / (2 * focal)`` scales.

    Args:
        H, W: image size in pixels.
        focal: pinhole focal length.
        near: near plane used for the NDC transform.
        ndc_pts: float32 points ``[..., 3]`` in NDC.

    Returns:
        Points of shape ``[..., 3]``.
    """
    # z_ndc == 1 corresponds to infinite depth (division by zero below), so
    # clamp just under 1.
    w = tf.math.minimum(ndc_pts[..., 2][..., None], 0.99999997)
    z = 2 * near / (w - 1.)
    inv_proj = tf.constant([
        [W / 2 / focal, 0, 0, 0],
        [0, H / 2 / focal, 0, 0],
        [0, 0, 0, -1],
        [0, 0, 1 / (-2 * near), 1 / (2 * near)],
    ], dtype=tf.float32)
    # Rebuild homogeneous clip coordinates (x, y, z, w) * w_clip with w_clip = -z.
    unclipped = ndc_pts * (-1. * z)
    unclipped = tf.concat([unclipped, -1. * z], -1)
    return (unclipped @ tf.transpose(inv_proj))[..., :3]