import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsparse.tensor import PointTensor, SparseTensor
import torchsparse.nn as spnn

from tsparse.modules import SparseCostRegNet
from tsparse.torchsparse_utils import sparse_to_dense_channel
from ops.grid_sampler import grid_sample_3d, tricubic_sample_3d

from ops.back_project import back_project_sparse_type
from ops.generate_grids import generate_grid

from inplace_abn import InPlaceABN

from models.embedder import Embedding
from models.featurenet import ConvBnReLU

import pdb
import random

# Disable the TorchScript profiling executor so @torch.jit.script functions are
# compiled once instead of being re-profiled over the first few invocations.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)


@torch.jit.script
def fused_mean_variance(x, weight):
    # weighted first and second central moments along the view dimension (dim=1)
    mean = torch.sum(x * weight, dim=1, keepdim=True)
    var = torch.sum(weight * (x - mean) ** 2, dim=1, keepdim=True)
    return mean, var
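# Illustrative only (not part of the pipeline above): for per-voxel features
# x of shape [N_pts, N_views, C] and weights of shape [N_pts, N_views, 1] that
# sum to 1 over views, fused_mean_variance returns mean/var of shape
# [N_pts, 1, C], e.g.:
#   w = masks.unsqueeze(-1) / masks.sum(dim=1, keepdim=True).unsqueeze(-1)
#   mean, var = fused_mean_variance(feats, w)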


class LatentSDFLayer(nn.Module):
    def __init__(self,
                 d_in=3,
                 d_out=129,
                 d_hidden=128,
                 n_layers=4,
                 skip_in=(4,),
                 multires=0,
                 bias=0.5,
                 geometric_init=True,
                 weight_norm=True,
                 activation='softplus',
                 d_conditional_feature=16):
        super(LatentSDFLayer, self).__init__()

        self.d_conditional_feature = d_conditional_feature

        # the conditional latent code is concatenated to the input of every
        # layer except the first and the last one
        dims_in = [d_in] + [d_hidden + d_conditional_feature for _ in range(n_layers - 2)] + [d_hidden]
        dims_out = [d_hidden for _ in range(n_layers - 1)] + [d_out]

        self.embed_fn_fine = None

        if multires > 0:
            embed_fn = Embedding(in_channels=d_in, N_freqs=multires)
            self.embed_fn_fine = embed_fn
            dims_in[0] = embed_fn.out_channels

        self.num_layers = n_layers
        self.skip_in = skip_in

        for l in range(0, self.num_layers - 1):
            if l in self.skip_in:
                in_dim = dims_in[l] + dims_in[0]
            else:
                in_dim = dims_in[l]

            out_dim = dims_out[l]
            lin = nn.Linear(in_dim, out_dim)

            if geometric_init:
                if l == self.num_layers - 2:
                    # SAL/IDR-style geometric initialization: bias the output
                    # towards the sdf of a sphere of radius `bias`
                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(in_dim), std=0.0001)
                    torch.nn.init.constant_(lin.bias, -bias)
                    # zero out the channels that receive the latent code
                    torch.nn.init.constant_(lin.weight[:, -d_conditional_feature:], 0.0)
                    torch.nn.init.constant_(lin.bias[-d_conditional_feature:], 0.0)

                elif multires > 0 and l == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    # zero out the positional-encoding channels ...
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    # ... and initialize the raw xyz channels (3 channels) normally
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    # zero out the skip-connected encoding and latent channels
                    torch.nn.init.constant_(lin.weight[:, -(dims_in[0] - 3 + d_conditional_feature):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    # zero out the channels that receive the latent code
                    torch.nn.init.constant_(lin.weight[:, -d_conditional_feature:], 0.0)

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        if activation == 'softplus':
            self.activation = nn.Softplus(beta=100)
        else:
            assert activation == 'relu'
            self.activation = nn.ReLU()

    def forward(self, inputs, latent):
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        # allow a lod1 network to reuse pretrained lod0 weights: if the latent
        # is narrower than expected, duplicate it along the channel dimension
        if latent.shape[1] != self.d_conditional_feature:
            latent = torch.cat([latent, latent], dim=1)

        x = inputs
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, inputs], 1) / np.sqrt(2)

            # concatenate the conditional latent to every hidden-layer input
            if 0 < l < self.num_layers - 1:
                x = torch.cat([x, latent], 1)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)

        return x
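    # Illustrative shape check (assumed settings, not executed): with d_in=3,
    # d_hidden=128, n_layers=4 and d_conditional_feature=16, the layer maps
    # points [N, 3] plus a latent [N, 16] to [N, 128]; column 0 is read as the
    # sdf and the remaining 127 columns as the geometry feature:
    #   out = layer(pts, latent); sdf, feat = out[:, :1], out[:, 1:]
    # Note the stack has n_layers-1 linear layers, so the last one outputs
    # d_hidden channels and d_out appears effectively unused here.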


class SparseSdfNetwork(nn.Module):
    '''
    Coarse-to-fine sparse cost regularization network;
    returns sparse volume features for extracting the sdf.
    '''

    def __init__(self, lod, ch_in, voxel_size, vol_dims,
                 hidden_dim=128, activation='softplus',
                 cost_type='variance_mean',
                 d_pyramid_feature_compress=16,
                 regnet_d_out=8, num_sdf_layers=4,
                 multires=6,
                 ):
        super(SparseSdfNetwork, self).__init__()

        self.lod = lod  # current level of detail in the coarse-to-fine scheme
        self.ch_in = ch_in
        self.voxel_size = voxel_size
        self.vol_dims = torch.tensor(vol_dims)

        self.selected_views_num = 2
        self.hidden_dim = hidden_dim
        self.activation = activation
        self.cost_type = cost_type
        self.d_pyramid_feature_compress = d_pyramid_feature_compress
        self.gru_fusion = None

        self.regnet_d_out = regnet_d_out
        self.multires = multires

        self.pos_embedder = Embedding(3, self.multires)

        self.compress_layer = ConvBnReLU(
            self.ch_in, self.d_pyramid_feature_compress, 3, 1, 1,
            norm_act=InPlaceABN)
        sparse_ch_in = self.d_pyramid_feature_compress * 2

        # lod > 0 additionally consumes the upsampled features of the prior lod
        sparse_ch_in = sparse_ch_in + 16 if self.lod > 0 else sparse_ch_in
        self.sparse_costreg_net = SparseCostRegNet(
            d_in=sparse_ch_in, d_out=self.regnet_d_out)

        if activation == 'softplus':
            self.activation = nn.Softplus(beta=100)
        else:
            assert activation == 'relu'
            self.activation = nn.ReLU()

        self.sdf_layer = LatentSDFLayer(d_in=3,
                                        d_out=self.hidden_dim + 1,
                                        d_hidden=self.hidden_dim,
                                        n_layers=num_sdf_layers,
                                        multires=multires,
                                        geometric_init=True,
                                        weight_norm=True,
                                        activation=activation,
                                        d_conditional_feature=16
                                        )

    def upsample(self, pre_feat, pre_coords, interval, num=8):
        '''
        :param pre_feat: (Tensor), features from last level, (N, C)
        :param pre_coords: (Tensor), coordinates from last level, (N, 4) (4 : batch ind, x, y, z)
        :param interval: interval of voxels, interval = scale ** 2
        :param num: 1 -> 8
        :return: up_feat : (Tensor), upsampled features, (N*8, C)
        :return: up_coords: (N*8, 4), upsampled coordinates, (4 : batch ind, x, y, z)
        '''
        with torch.no_grad():
            # offsets of the 7 siblings of each voxel (the voxel itself is the 8th)
            pos_list = [1, 2, 3, [1, 2], [1, 3], [2, 3], [1, 2, 3]]
            n, c = pre_feat.shape
            up_feat = pre_feat.unsqueeze(1).expand(-1, num, -1).contiguous()
            up_coords = pre_coords.unsqueeze(1).repeat(1, num, 1).contiguous()
            for i in range(num - 1):
                up_coords[:, i + 1, pos_list[i]] += interval

            up_feat = up_feat.view(-1, c)
            up_coords = up_coords.view(-1, 4)

        return up_feat, up_coords
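    # Illustrative only: each parent voxel at integer coords (x, y, z) expands
    # into the 2x2x2 block of children
    #   (x, y, z), (x+i, y, z), (x, y+i, z), (x, y, z+i),
    #   (x+i, y+i, z), (x+i, y, z+i), (x, y+i, z+i), (x+i, y+i, z+i)
    # with i = interval; the parent feature is copied to all eight children.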

    def aggregate_multiview_features(self, multiview_features, multiview_masks):
        """
        aggregate multi-view features by computing their cost variance
        :param multiview_features: (num of voxels, num_of_views, c)
        :param multiview_masks: (num of voxels, num_of_views)
        :return:
        """
        num_pts, n_views, C = multiview_features.shape

        counts = torch.sum(multiview_masks, dim=1, keepdim=False)  # [num_pts]

        assert torch.all(counts > 0)  # every voxel should be visible in at least one view

        volume_sum = torch.sum(multiview_features, dim=1, keepdim=False)  # [num_pts, C]
        volume_sq_sum = torch.sum(multiview_features ** 2, dim=1, keepdim=False)

        if volume_sum.isnan().sum() > 0:
            import ipdb; ipdb.set_trace()  # debug trap: NaNs in the aggregated features

        del multiview_features

        counts = 1. / (counts + 1e-5)
        costvar = volume_sq_sum * counts[:, None] - (volume_sum * counts[:, None]) ** 2

        costvar_mean = torch.cat([costvar, volume_sum * counts[:, None]], dim=1)
        del volume_sum, volume_sq_sum, counts

        return costvar_mean
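    # aggregate_multiview_features applies the one-pass variance identity
    #   Var[f] = E[f^2] - (E[f])^2
    # per channel over the visible views and returns [variance, mean] stacked
    # along the channel axis, matching cost_type='variance_mean'.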

    def sparse_to_dense_volume(self, coords, feature, vol_dims, interval, device=None):
        """
        convert the sparse volume into a dense volume to enable trilinear sampling
        (the volume is kept sparse everywhere else to save GPU memory)
        :param coords: [num_pts, 3]
        :param feature: [num_pts, C]
        :param vol_dims: [3]  dX, dY, dZ
        :param interval:
        :return:
        """
        if device is None:
            device = feature.device

        coords_int = (coords / interval).to(torch.int64)
        vol_dims = (vol_dims / interval).to(torch.int64)

        dense_volume = sparse_to_dense_channel(
            coords_int.to(device), feature.to(device), vol_dims.to(device),
            feature.shape[1], 0, device)  # [dX, dY, dZ, C]

        valid_mask_volume = sparse_to_dense_channel(
            coords_int.to(device),
            torch.ones([feature.shape[0], 1]).to(feature.device),
            vol_dims.to(device),
            1, 0, device)  # [dX, dY, dZ, 1]

        dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0)  # [1, C, dX, dY, dZ]
        valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0)  # [1, 1, dX, dY, dZ]

        return dense_volume, valid_mask_volume
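    # Note: empty voxels are densified with the default value 0 and flagged by
    # valid_mask_volume, so trilinear sampling near the sparse/dense boundary
    # mixes in zeros; the mask volume is what lets callers discount those samples.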

    def get_conditional_volume(self, feature_maps, partial_vol_origin, proj_mats, sizeH=None, sizeW=None, lod=0,
                               pre_coords=None, pre_feats=None,
                               ):
        """
        :param feature_maps: fused pyramid features, (B,V,C0+C1+C2,H,W)
        :param partial_vol_origin: [B, 3]  the world coordinates of the volume origin (0,0,0)
        :param proj_mats: projection matrices transforming world pts into image space, [B,V,4,4], suited for the original image size
        :param sizeH: the H of the original image size
        :param sizeW: the W of the original image size
        :param pre_coords: the coordinates of the sparse volume from the prior lod
        :param pre_feats: the features of the sparse volume from the prior lod
        :return:
        """
        device = proj_mats.device
        bs = feature_maps.shape[0]
        N_views = feature_maps.shape[1]
        minimum_visible_views = np.min([1, N_views - 1])

        outputs = {}
        pts_samples = []

        # compress the fused pyramid features before cost aggregation
        if self.compress_layer is not None:
            feats = self.compress_layer(feature_maps[0])
        else:
            feats = feature_maps[0]
        feats = feats[:, None, :, :, :]  # [V, 1, C, H, W]
        KRcam = proj_mats.permute(1, 0, 2, 3).contiguous()  # [V, B, 4, 4]
        interval = 1

        if self.lod == 0:
            # ---- generate new coords ----
            coords = generate_grid(self.vol_dims, 1)[0]
            coords = coords.view(3, -1).to(device)  # [3, num_pts]
            up_coords = []
            for b in range(bs):
                up_coords.append(torch.cat([torch.ones(1, coords.shape[-1]).to(coords.device) * b, coords]))
            up_coords = torch.cat(up_coords, dim=1).permute(1, 0).contiguous()

            # keep only voxels that project into more than minimum_visible_views views
            frustum_mask = back_project_sparse_type(
                up_coords, partial_vol_origin, self.voxel_size,
                feats, KRcam, sizeH=sizeH, sizeW=sizeW, only_mask=True)  # [num_pts, n_views]
            frustum_mask = torch.sum(frustum_mask, dim=-1) > minimum_visible_views
            up_coords = up_coords[frustum_mask]  # [num_pts_valid, 4]

        else:
            # ---- upsample coords from the prior lod ----
            assert pre_feats is not None
            assert pre_coords is not None
            up_feat, up_coords = self.upsample(pre_feats, pre_coords, 1)

        # ---- back-project image features onto the sparse voxels ----
        multiview_features, multiview_masks = back_project_sparse_type(
            up_coords, partial_vol_origin, self.voxel_size, feats,
            KRcam, sizeH=sizeH, sizeW=sizeW)  # (num_pts, n_views, c), (num_pts, n_views)

        if self.lod > 0:
            # stricter filtering at finer lods: keep voxels visible in more than one view
            frustum_mask = torch.sum(multiview_masks, dim=-1) > 1
            up_feat = up_feat[frustum_mask]
            up_coords = up_coords[frustum_mask]
            multiview_features = multiview_features[frustum_mask]
            multiview_masks = multiview_masks[frustum_mask]

        volume = self.aggregate_multiview_features(multiview_features, multiview_masks)

        del multiview_features, multiview_masks

        # ---- concat features from the prior stage ----
        if self.lod != 0:
            feat = torch.cat([volume, up_feat], dim=1)
        else:
            feat = volume

        # ---- sparse 3d conv cost regularization ----
        r_coords = up_coords[:, [1, 2, 3, 0]]  # reorder to (x, y, z, batch) as expected by torchsparse
        sparse_feat = SparseTensor(feat, r_coords.to(
            torch.int32))
        feat = self.sparse_costreg_net(sparse_feat)

        dense_volume, valid_mask_volume = self.sparse_to_dense_volume(up_coords[:, 1:], feat, self.vol_dims, interval,
                                                                      device=None)  # [1, C/1, X, Y, Z]

        outputs['dense_volume_scale%d' % self.lod] = dense_volume
        outputs['valid_mask_volume_scale%d' % self.lod] = valid_mask_volume
        outputs['visible_mask_scale%d' % self.lod] = valid_mask_volume
        outputs['coords_scale%d' % self.lod] = generate_grid(self.vol_dims, interval).to(device)

        return outputs
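    # Rough usage sketch (hypothetical variable names; shapes follow the
    # docstring above):
    #   out0 = net_lod0.get_conditional_volume(fmaps, vol_origin, projs, H, W)
    #   dense0 = out0['dense_volume_scale0']      # [1, regnet_d_out, dX, dY, dZ]
    #   mask0 = out0['valid_mask_volume_scale0']  # [1, 1, dX, dY, dZ]
    # A lod-1 instance would additionally be fed pre_coords/pre_feats obtained
    # from the (pruned) lod-0 sparse volume via upsample().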

    def sdf(self, pts, conditional_volume, lod):
        num_pts = pts.shape[0]
        device = pts.device
        pts_ = pts.clone()
        pts = pts.view(1, 1, 1, num_pts, 3)  # pts should be normalized to (-1, 1)

        pts = torch.flip(pts, dims=[-1])  # (x, y, z) -> (z, y, x) for grid sampling

        sampled_feature = grid_sample_3d(conditional_volume, pts)  # [1, c, 1, 1, num_pts]
        sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous().to(device)

        sdf_pts = self.sdf_layer(pts_, sampled_feature)

        outputs = {}
        outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1]
        outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:]
        outputs['sampled_latent_scale%d' % lod] = sampled_feature

        return outputs
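    # Illustrative query (hypothetical names; pts assumed normalized to
    # [-1, 1]^3 volume coordinates):
    #   out = net.sdf(pts, outputs['dense_volume_scale0'], lod=0)
    #   sdf_vals = out['sdf_pts_scale0']  # [N, 1]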

    @torch.no_grad()
    def sdf_from_sdfvolume(self, pts, sdf_volume, lod=0):
        num_pts = pts.shape[0]
        device = pts.device
        pts = pts.view(1, 1, 1, num_pts, 3)

        pts = torch.flip(pts, dims=[-1])  # (x, y, z) -> (z, y, x)

        sdf = torch.nn.functional.grid_sample(sdf_volume, pts, mode='bilinear', align_corners=True,
                                              padding_mode='border')
        sdf = sdf.view(-1, num_pts).permute(1, 0).contiguous().to(device)

        outputs = {}
        outputs['sdf_pts_scale%d' % lod] = sdf

        return outputs

    @torch.no_grad()
    def get_sdf_volume(self, conditional_volume, mask_volume, coords_volume, partial_origin):
        """
        evaluate the sdf network on every valid voxel to bake out a dense sdf volume
        :param conditional_volume: [1,C, dX,dY,dZ]
        :param mask_volume: [1,1, dX,dY,dZ]
        :param coords_volume: [1,3, dX,dY,dZ]
        :return:
        """
        device = conditional_volume.device
        chunk_size = 10240

        _, C, dX, dY, dZ = conditional_volume.shape
        conditional_volume = conditional_volume.view(C, dX * dY * dZ).permute(1, 0).contiguous()
        mask_volume = mask_volume.view(-1)
        coords_volume = coords_volume.view(3, dX * dY * dZ).permute(1, 0).contiguous()

        pts = coords_volume * self.voxel_size + partial_origin  # [dX*dY*dZ, 3]

        sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(device)

        conditional_volume = conditional_volume[mask_volume > 0]
        pts = pts[mask_volume > 0]
        # evaluate in chunks to bound peak memory
        conditional_volume = conditional_volume.split(chunk_size)
        pts = pts.split(chunk_size)

        sdf_all = []
        for pts_part, feature_part in zip(pts, conditional_volume):
            sdf_part = self.sdf_layer(pts_part, feature_part)[:, :1]
            sdf_all.append(sdf_part)

        sdf_all = torch.cat(sdf_all, dim=0)
        sdf_volume[mask_volume > 0] = sdf_all
        sdf_volume = sdf_volume.view(1, 1, dX, dY, dZ)
        return sdf_volume
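    # Voxels outside mask_volume keep the initialization value 1.0 ("far
    # outside"), so later iso-surface extraction should consult mask_volume to
    # ignore the unobserved region rather than trust those placeholder values.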

    def gradient(self, x, conditional_volume, lod):
        """
        return the gradient of a specific lod
        :param x:
        :param lod:
        :return:
        """
        x.requires_grad_(True)

        with torch.enable_grad():
            output = self.sdf(x, conditional_volume, lod)
            y = output['sdf_pts_scale%d' % lod]

        d_output = torch.ones_like(y, requires_grad=False, device=y.device)

        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients.unsqueeze(1)
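    # Sanity-check idea (not part of training): the analytic gradient can be
    # compared against central finite differences,
    #   (sdf(x + eps * e_i) - sdf(x - eps * e_i)) / (2 * eps),
    # and for a well-conditioned sdf its norm should stay close to 1 (the
    # eikonal property such models typically enforce through a loss).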


def sparse_to_dense_volume(coords, feature, vol_dims, interval, device=None):
    """
    convert the sparse volume into a dense volume to enable trilinear sampling
    (the volume is kept sparse everywhere else to save GPU memory);
    module-level twin of SparseSdfNetwork.sparse_to_dense_volume, used by the
    finetuning classes below
    :param coords: [num_pts, 3]
    :param feature: [num_pts, C]
    :param vol_dims: [3]  dX, dY, dZ
    :param interval:
    :return:
    """
    if device is None:
        device = feature.device

    coords_int = (coords / interval).to(torch.int64)
    vol_dims = (vol_dims / interval).to(torch.int64)

    dense_volume = sparse_to_dense_channel(
        coords_int.to(device), feature.to(device), vol_dims.to(device),
        feature.shape[1], 0, device)  # [dX, dY, dZ, C]

    valid_mask_volume = sparse_to_dense_channel(
        coords_int.to(device),
        torch.ones([feature.shape[0], 1]).to(feature.device),
        vol_dims.to(device),
        1, 0, device)  # [dX, dY, dZ, 1]

    dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0)  # [1, C, dX, dY, dZ]
    valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0)  # [1, 1, dX, dY, dZ]

    return dense_volume, valid_mask_volume


class SdfVolume(nn.Module):
    """Thin nn.Module wrapper exposing a (sparse or dense) feature volume as a
    trainable parameter for per-scene finetuning."""

    def __init__(self, volume, coords=None, type='dense'):
        super(SdfVolume, self).__init__()
        self.volume = torch.nn.Parameter(volume, requires_grad=True)
        self.coords = coords
        self.type = type

    def forward(self):
        return self.volume
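# Since the volume is exposed as an nn.Parameter, per-scene finetuning can
# optimize it directly, e.g. (illustrative):
#   optimizer = torch.optim.Adam(sdf_volume_module.parameters(), lr=1e-3)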


class FinetuneOctreeSdfNetwork(nn.Module):
    '''
    After obtaining the conditional volume from the generalized network,
    directly optimize that volume per scene.
    The conditional volume is still stored sparsely.
    '''

    def __init__(self, voxel_size, vol_dims,
                 origin=[-1., -1., -1.],
                 hidden_dim=128, activation='softplus',
                 regnet_d_out=8,
                 multires=6,
                 if_fitted_rendering=True,
                 num_sdf_layers=4,
                 ):
        super(FinetuneOctreeSdfNetwork, self).__init__()

        self.voxel_size = voxel_size
        self.vol_dims = torch.tensor(vol_dims)

        self.origin = torch.tensor(origin).to(torch.float32)

        self.hidden_dim = hidden_dim
        self.activation = activation

        self.regnet_d_out = regnet_d_out

        self.if_fitted_rendering = if_fitted_rendering
        self.multires = multires

        # the trainable sparse volume and its (fixed) coordinates are set up
        # later by initialize_conditional_volumes()
        self.sparse_volume_lod0 = None
        self.sparse_coords_lod0 = None

        if activation == 'softplus':
            self.activation = nn.Softplus(beta=100)
        else:
            assert activation == 'relu'
            self.activation = nn.ReLU()

        self.sdf_layer = LatentSDFLayer(d_in=3,
                                        d_out=self.hidden_dim + 1,
                                        d_hidden=self.hidden_dim,
                                        n_layers=num_sdf_layers,
                                        multires=multires,
                                        geometric_init=True,
                                        weight_norm=True,
                                        activation=activation,
                                        d_conditional_feature=16
                                        )

        self.renderer = None

        d_in_renderer = 3 + self.regnet_d_out + 3 + 3  # xyz + latent + normal + view_dir
        self.renderer = BlendingRenderingNetwork(
            d_feature=self.hidden_dim - 1,
            mode='idr',
            d_in=d_in_renderer,
            d_out=50,  # one blending logit per candidate source view
            d_hidden=self.hidden_dim,
            n_layers=3,
            weight_norm=True,
            multires_view=4,
            squeeze_out=True,
        )

    def initialize_conditional_volumes(self, dense_volume_lod0, dense_volume_mask_lod0,
                                       sparse_volume_lod0=None, sparse_coords_lod0=None):
        """
        :param dense_volume_lod0: [1,C,dX,dY,dZ]
        :param dense_volume_mask_lod0: [1,1,dX,dY,dZ]
        :param sparse_volume_lod0: optional pre-extracted sparse features
        :param sparse_coords_lod0: optional pre-extracted sparse coordinates
        :return:
        """
        if sparse_volume_lod0 is None:
            device = dense_volume_lod0.device
            _, C, dX, dY, dZ = dense_volume_lod0.shape

            dense_volume_lod0 = dense_volume_lod0.view(C, dX * dY * dZ).permute(1, 0).contiguous()
            mask_lod0 = dense_volume_mask_lod0.view(dX * dY * dZ) > 0

            self.sparse_volume_lod0 = SdfVolume(dense_volume_lod0[mask_lod0], type='sparse')

            coords = generate_grid(self.vol_dims, 1)[0]
            coords = coords.view(3, dX * dY * dZ).permute(1, 0).to(device)
            self.sparse_coords_lod0 = torch.nn.Parameter(coords[mask_lod0], requires_grad=False)
        else:
            self.sparse_volume_lod0 = SdfVolume(sparse_volume_lod0, type='sparse')
            self.sparse_coords_lod0 = torch.nn.Parameter(sparse_coords_lod0, requires_grad=False)

    def get_conditional_volume(self):
        dense_volume, valid_mask_volume = sparse_to_dense_volume(
            self.sparse_coords_lod0,
            self.sparse_volume_lod0(), self.vol_dims, interval=1,
            device=None)  # [1, C/1, X, Y, Z]

        outputs = {}
        outputs['dense_volume_scale%d' % 0] = dense_volume
        outputs['valid_mask_volume_scale%d' % 0] = valid_mask_volume

        return outputs

    def tv_regularizer(self):
        dense_volume, valid_mask_volume = sparse_to_dense_volume(
            self.sparse_coords_lod0,
            self.sparse_volume_lod0(), self.vol_dims, interval=1,
            device=None)  # [1, C/1, X, Y, Z]

        dx = (dense_volume[:, :, 1:, :, :] - dense_volume[:, :, :-1, :, :]) ** 2  # [1, C, X-1, Y, Z]
        dy = (dense_volume[:, :, :, 1:, :] - dense_volume[:, :, :, :-1, :]) ** 2  # [1, C, X, Y-1, Z]
        dz = (dense_volume[:, :, :, :, 1:] - dense_volume[:, :, :, :, :-1]) ** 2  # [1, C, X, Y, Z-1]

        tv = dx[:, :, :, :-1, :-1] + dy[:, :, :-1, :, :-1] + dz[:, :, :-1, :-1, :]  # [1, C, X-1, Y-1, Z-1]

        # a cell contributes only if itself and its three forward neighbors are valid
        mask = valid_mask_volume[:, :, :-1, :-1, :-1] * valid_mask_volume[:, :, 1:, :-1, :-1] * \
               valid_mask_volume[:, :, :-1, 1:, :-1] * valid_mask_volume[:, :, :-1, :-1, 1:]

        tv = torch.sqrt(tv + 1e-6).mean(dim=1, keepdim=True) * mask

        assert torch.all(~torch.isnan(tv))

        return torch.mean(tv)
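    # This is the isotropic total-variation penalty
    #   TV(V) = mean over valid cells of sqrt(|dV/dx|^2 + |dV/dy|^2 + |dV/dz|^2)
    # (computed per channel, then averaged over channels), which discourages
    # high-frequency noise in the finetuned volume.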

    def sdf(self, pts, conditional_volume, lod):
        outputs = {}

        num_pts = pts.shape[0]
        device = pts.device
        pts_ = pts.clone()
        pts = pts.view(1, 1, 1, num_pts, 3)  # pts should be normalized to (-1, 1)

        pts = torch.flip(pts, dims=[-1])  # (x, y, z) -> (z, y, x)

        sampled_feature = grid_sample_3d(conditional_volume, pts)  # [1, c, 1, 1, num_pts]
        sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous()
        outputs['sampled_latent_scale%d' % lod] = sampled_feature

        sdf_pts = self.sdf_layer(pts_, sampled_feature)

        lod = 0
        outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1]
        outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:]

        return outputs

    def color_blend(self, pts, position, normals, view_dirs, feature_vectors, img_index,
                    pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None):
        return self.renderer(torch.cat([pts, position], dim=-1), normals, view_dirs, feature_vectors,
                             img_index, pts_pixel_color, pts_pixel_mask,
                             pts_patch_color=pts_patch_color, pts_patch_mask=pts_patch_mask)

    def gradient(self, x, conditional_volume, lod):
        """
        return the gradient of a specific lod
        :param x:
        :param lod:
        :return:
        """
        x.requires_grad_(True)
        output = self.sdf(x, conditional_volume, lod)
        y = output['sdf_pts_scale%d' % 0]

        d_output = torch.ones_like(y, requires_grad=False, device=y.device)

        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients.unsqueeze(1)

    @torch.no_grad()
    def prune_dense_mask(self, threshold=0.02):
        """
        gradually prune the dense-volume mask to reduce the number of sdf
        network inferences
        NOTE: relies on self.vol_dims_lod0, self.voxel_size_lod0 and
        self.dense_volume_mask_lod0 being set up by the surrounding pipeline
        :return:
        """
        chunk_size = 10240
        coords = generate_grid(self.vol_dims_lod0, 1)[0]

        _, dX, dY, dZ = coords.shape

        pts = coords.view(3, -1).permute(1, 0).contiguous() * self.voxel_size_lod0 + self.origin[None, :]

        dense_volume, _ = sparse_to_dense_volume(
            self.sparse_coords_lod0,
            self.sparse_volume_lod0(), self.vol_dims_lod0, interval=1,
            device=None)  # [1, C/1, X, Y, Z]

        sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(dense_volume.device) * 100

        mask = self.dense_volume_mask_lod0.view(-1) > 0

        pts_valid = pts[mask].to(dense_volume.device)
        feature_valid = dense_volume.view(self.regnet_d_out, -1).permute(1, 0).contiguous()[mask]

        pts_valid = pts_valid.split(chunk_size)
        feature_valid = feature_valid.split(chunk_size)

        sdf_list = []

        for pts_part, feature_part in zip(pts_valid, feature_valid):
            sdf_part = self.sdf_layer(pts_part, feature_part)[:, :1]
            sdf_list.append(sdf_part)

        sdf_list = torch.cat(sdf_list, dim=0)

        sdf_volume[mask] = sdf_list

        occupancy_mask = torch.abs(sdf_volume) < threshold

        # dilate the near-surface band so neighbors of surface voxels survive
        occupancy_mask = occupancy_mask.float()
        occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ)
        occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3)
        occupancy_mask = occupancy_mask > 0

        self.dense_volume_mask_lod0 = torch.logical_and(self.dense_volume_mask_lod0,
                                                        occupancy_mask).float()
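    # With kernel_size=7 / padding=3, thresholding avg_pool3d at > 0 acts as a
    # binary dilation with a 7x7x7 structuring element: any voxel within 3
    # cells of the |sdf| < threshold band stays active for the next round.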


class BlendingRenderingNetwork(nn.Module):
    def __init__(
            self,
            d_feature,
            mode,
            d_in,
            d_out,
            d_hidden,
            n_layers,
            weight_norm=True,
            multires_view=0,
            squeeze_out=True,
    ):
        super(BlendingRenderingNetwork, self).__init__()

        self.mode = mode
        self.squeeze_out = squeeze_out
        dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embedder = None
        if multires_view > 0:
            self.embedder = Embedding(3, multires_view)
            dims[0] += (self.embedder.out_channels - 3)

        self.num_layers = len(dims)

        for l in range(0, self.num_layers - 1):
            out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.relu = nn.ReLU()

        self.color_volume = None

        self.softmax = nn.Softmax(dim=1)

        self.type = 'blending'

    def sample_pts_from_colorVolume(self, pts):
        device = pts.device
        num_pts = pts.shape[0]
        pts = pts.view(1, 1, 1, num_pts, 3)

        pts = torch.flip(pts, dims=[-1])  # (x, y, z) -> (z, y, x)

        sampled_color = grid_sample_3d(self.color_volume, pts)  # [1, c, 1, 1, num_pts]
        sampled_color = sampled_color.view(-1, num_pts).permute(1, 0).contiguous().to(device)

        return sampled_color

    def forward(self, position, normals, view_dirs, feature_vectors, img_index,
                pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None):
        """
        :param position: can be a 3d coord or an interpolated volume latent
        :param normals:
        :param view_dirs:
        :param feature_vectors:
        :param img_index: [N_views], used to extract the corresponding blending weights
        :param pts_pixel_color: [N_pts, N_views, 3]
        :param pts_pixel_mask: [N_pts, N_views]
        :param pts_patch_color: [N_pts, N_views, Npx, 3]
        :param pts_patch_mask: [N_pts, N_views, Npx]
        :return:
        """
        if self.embedder is not None:
            view_dirs = self.embedder(view_dirs)

        rendering_input = None

        if self.mode == 'idr':
            rendering_input = torch.cat([position, view_dirs, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_view_dir':
            rendering_input = torch.cat([position, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_normal':
            rendering_input = torch.cat([position, view_dirs, feature_vectors], dim=-1)
        elif self.mode == 'no_points':
            rendering_input = torch.cat([view_dirs, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_points_no_view_dir':
            rendering_input = torch.cat([normals, feature_vectors], dim=-1)

        x = rendering_input

        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.relu(x)

        # x: [N_pts, d_out] blending logits; pick the channels of the source
        # views actually used
        x_extracted = torch.index_select(x, 1, img_index.long())

        weights_pixel = self.softmax(x_extracted)  # [N_pts, N_views]
        weights_pixel = weights_pixel * pts_pixel_mask
        weights_pixel = weights_pixel / (
                torch.sum(weights_pixel.float(), dim=1, keepdim=True) + 1e-8)  # re-normalize after masking
        final_pixel_color = torch.sum(pts_pixel_color * weights_pixel[:, :, None], dim=1,
                                      keepdim=False)  # [N_pts, 3]

        final_pixel_mask = torch.sum(pts_pixel_mask.float(), dim=1, keepdim=True) > 0  # [N_pts, 1]

        final_patch_color, final_patch_mask = None, None

        # the same blending weights can be reused for patch-wise colors
        if pts_patch_color is not None:
            N_pts, N_views, Npx, _ = pts_patch_color.shape
            # a view counts only if its whole patch is valid
            patch_mask = torch.sum(pts_patch_mask, dim=-1, keepdim=False) > Npx - 1  # [N_pts, N_views]

            weights_patch = self.softmax(x_extracted)
            weights_patch = weights_patch * patch_mask
            weights_patch = weights_patch / (
                    torch.sum(weights_patch.float(), dim=1, keepdim=True) + 1e-8)

            final_patch_color = torch.sum(pts_patch_color * weights_patch[:, :, None, None], dim=1,
                                          keepdim=False)  # [N_pts, Npx, 3]
            final_patch_mask = torch.sum(patch_mask, dim=1, keepdim=True) > 0  # [N_pts, 1]

        return final_pixel_color, final_pixel_mask, final_patch_color, final_patch_mask
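    # Blending summary: softmax over the selected views' logits produces convex
    # weights; masking plus re-normalization restricts them to the visible
    # views, and the final color is the weighted sum of the per-view colors.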