import torch
from torch.nn.functional import grid_sample
def back_project_sparse_type(coords, origin, voxel_size, feats, KRcam, sizeH=None, sizeW=None, only_mask=False,
                             with_proj_z=False):
    # Modified from NeuRecon.
    '''
    Unproject the image features to form a 3D (sparse) feature volume.

    :param coords: coordinates of voxels,
    dim: (num of voxels, 4) (4: batch ind, x, y, z)
    :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0))
    dim: (batch size, 3) (3: x, y, z)
    :param voxel_size: float specifying the size of a voxel
    :param feats: image features
    dim: (num of views, batch size, C, H, W)
    :param KRcam: projection matrix
    dim: (num of views, batch size, 4, 4)
    :return: feature_volume_all: 3D feature volumes
    dim: (num of voxels, num of views, C)
    :return: mask_volume_all: indicates whether the feature sampled for each voxel is valid
    dim: (num of voxels, num of views)
    '''
    n_views, bs, c, h, w = feats.shape
    device = feats.device
    if sizeH is None:
        sizeH, sizeW = h, w  # assume KRcam was built for the resolution of the current feats
    feature_volume_all = torch.zeros(coords.shape[0], n_views, c).to(device)
    mask_volume_all = torch.zeros([coords.shape[0], n_views], dtype=torch.int32).to(device)
    if with_proj_z:
        proj_z_all = torch.zeros(coords.shape[0], n_views, 1).to(device)

    for batch in range(bs):
        # gather the voxels belonging to this batch element
        batch_ind = torch.nonzero(coords[:, 0] == batch).squeeze(1)
        coords_batch = coords[batch_ind][:, 1:]
        coords_batch = coords_batch.view(-1, 3)
        origin_batch = origin[batch].unsqueeze(0)
        feats_batch = feats[:, batch]
        proj_batch = KRcam[:, batch]
        # voxel indices -> world coordinates
        grid_batch = coords_batch * voxel_size + origin_batch.float()
        rs_grid = grid_batch.unsqueeze(0).expand(n_views, -1, -1)
        rs_grid = rs_grid.permute(0, 2, 1).contiguous()  # [n_views, 3, num_pts]
        nV = rs_grid.shape[-1]
        # homogeneous coordinates for the 4x4 projection matrices
        rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).to(device)], dim=1)
        # Project grid
        im_p = proj_batch @ rs_grid  # transform world points into image (u*z, v*z, z) space
        im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2]
        im_z[im_z >= 0] = im_z[im_z >= 0].clamp(min=1e-6)  # avoid division by zero near the image plane
        im_x = im_x / im_z
        im_y = im_y / im_z
        # normalize pixel coordinates to [-1, 1] as expected by grid_sample
        im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1)
        # a sample is valid if it projects inside the image and in front of the camera
        mask = im_grid.abs() <= 1
        mask = (mask.sum(dim=-1) == 2) & (im_z > 0)
        mask = mask.view(n_views, -1)
        mask = mask.permute(1, 0).contiguous()  # [num_pts, n_views]
        mask_volume_all[batch_ind] = mask.to(torch.int32)
        if only_mask:
            continue

        feats_batch = feats_batch.view(n_views, c, h, w)
        im_grid = im_grid.view(n_views, 1, -1, 2)
        features = grid_sample(feats_batch, im_grid, padding_mode='zeros', align_corners=True)
        features = features.view(n_views, c, -1)
        features = features.permute(2, 0, 1).contiguous()  # [num_pts, n_views, c]
        feature_volume_all[batch_ind] = features

        if with_proj_z:
            proj_z_all[batch_ind] = im_z.view(n_views, 1, -1).permute(2, 0, 1).contiguous()  # [num_pts, n_views, 1]

    if only_mask:
        return mask_volume_all
    if with_proj_z:
        return feature_volume_all, mask_volume_all, proj_z_all
    return feature_volume_all, mask_volume_all
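

# A minimal usage sketch (not part of the original module): toy tensors with the
# documented shapes, and placeholder identity projection matrices chosen only for
# illustration. It exercises the input/output layout of back_project_sparse_type
# rather than a real camera setup.
def _demo_back_project_sparse_type():
    n_views, bs, c, h, w = 2, 1, 8, 60, 80
    feats = torch.rand(n_views, bs, c, h, w)
    # 10 voxels, all assigned to batch 0: columns are (batch ind, x, y, z)
    coords = torch.cat([torch.zeros(10, 1), torch.randint(0, 32, (10, 3)).float()], dim=1)
    origin = torch.zeros(bs, 3)
    KRcam = torch.eye(4).repeat(n_views, bs, 1, 1)  # placeholder projection matrices
    volume, mask = back_project_sparse_type(coords, origin, 0.05, feats, KRcam)
    print(volume.shape, mask.shape)  # torch.Size([10, 2, 8]) torch.Size([10, 2])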
def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode, sizeH=None, sizeW=None, with_depth=False):
    """Transform coordinates in the camera frame to the pixel frame.

    Args:
        cam_coords: pixel coordinates defined in the first camera coordinate system -- [B, 3, H, W]
        proj_c2p_rot: rotation matrices of cameras -- [B, 3, 3]
        proj_c2p_tr: translation vectors of cameras -- [B, 3, 1]
    Returns:
        array of [-1, 1] coordinates -- [B, H, W, 2] ([B, H, W, 3] with depth appended when with_depth is True)
    """
    b, _, h, w = cam_coords.size()
    if sizeH is None:
        sizeH = h
        sizeW = w
    cam_coords_flat = cam_coords.view(b, 3, -1)  # [B, 3, H*W]
    if proj_c2p_rot is not None:
        pcoords = proj_c2p_rot.bmm(cam_coords_flat)
    else:
        pcoords = cam_coords_flat
    if proj_c2p_tr is not None:
        pcoords = pcoords + proj_c2p_tr  # [B, 3, H*W]
    X = pcoords[:, 0]
    Y = pcoords[:, 1]
    Z = pcoords[:, 2].clamp(min=1e-3)
    X_norm = 2 * (X / Z) / (sizeW - 1) - 1  # Normalized, -1 if on extreme left,
    # 1 if on extreme right (x = w-1) [B, H*W]
    Y_norm = 2 * (Y / Z) / (sizeH - 1) - 1  # Idem [B, H*W]
    if padding_mode == 'zeros':
        X_mask = ((X_norm > 1) | (X_norm < -1)).detach()
        X_norm[X_mask] = 2  # make sure that no point in the warped image is a combination of image and gray
        Y_mask = ((Y_norm > 1) | (Y_norm < -1)).detach()
        Y_norm[Y_mask] = 2
    if with_depth:
        pixel_coords = torch.stack([X_norm, Y_norm, Z], dim=2)  # [B, H*W, 3]
        return pixel_coords.view(b, h, w, 3)
    else:
        pixel_coords = torch.stack([X_norm, Y_norm], dim=2)  # [B, H*W, 2]
        return pixel_coords.view(b, h, w, 2)
# NOTE: already checked; callers should still verify that proj_matrix matches the coordinate system and resolution of feats
def back_project_dense_type(coords, origin, voxel_size, feats, proj_matrix, sizeH=None, sizeW=None):
    '''
    Unproject the image features to form a 3D (dense) feature volume.

    :param coords: coordinates of voxels,
    dim: (batch size, num of views, 3, X, Y, Z)
    :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0))
    dim: (batch size, 3) (3: x, y, z)
    :param voxel_size: float specifying the size of a voxel
    :param feats: image features
    dim: (batch size, num of views, C, H, W)
    :param proj_matrix: projection matrix
    dim: (batch size, num of views, 4, 4)
    :return: features_volume: 3D feature volumes
    dim: (batch size, num of views, C, X, Y, Z)
    :return: counts_volume: number of times each voxel can be seen
    dim: (batch size, num of views, 1, X, Y, Z)
    '''
    batch, nviews, _, wX, wY, wZ = coords.shape
    if sizeH is None:
        sizeH, sizeW = feats.shape[-2:]
    proj_matrix = proj_matrix.view(batch * nviews, *proj_matrix.shape[2:])
    coords_wrd = coords * voxel_size + origin.view(batch, 1, 3, 1, 1, 1)
    coords_wrd = coords_wrd.view(batch * nviews, 3, wX * wY * wZ, 1)  # (b*nviews, 3, wX*wY*wZ, 1)
    pixel_grids = cam2pixel(coords_wrd, proj_matrix[:, :3, :3], proj_matrix[:, :3, 3:],
                            'zeros', sizeH=sizeH, sizeW=sizeW)  # (b*nviews, wX*wY*wZ, 1, 2)
    pixel_grids = pixel_grids.view(batch * nviews, 1, wX * wY * wZ, 2)
    feats = feats.view(batch * nviews, *feats.shape[2:])  # (b*nviews, c, h, w)
    # sampling an all-ones image gives, per voxel, a soft in-frustum visibility count
    ones = torch.ones((batch * nviews, 1, *feats.shape[2:])).to(feats.dtype).to(feats.device)
    features_volume = torch.nn.functional.grid_sample(feats, pixel_grids, padding_mode='zeros', align_corners=True)
    counts_volume = torch.nn.functional.grid_sample(ones, pixel_grids, padding_mode='zeros', align_corners=True)
    features_volume = features_volume.view(batch, nviews, -1, wX, wY, wZ)  # (batch, nviews, C, X,Y,Z)
    counts_volume = counts_volume.view(batch, nviews, -1, wX, wY, wZ)
    return features_volume, counts_volume
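

# A minimal usage sketch (shapes are illustrative assumptions): dense back-projection
# of a small 8^3 voxel grid from two views with placeholder identity projection
# matrices, just to exercise the documented tensor layout.
def _demo_back_project_dense_type():
    batch, nviews, c, h, w = 1, 2, 8, 60, 80
    X = Y = Z = 8
    grid = torch.stack(torch.meshgrid(torch.arange(X), torch.arange(Y), torch.arange(Z),
                                      indexing='ij'), dim=0).float()  # (3, X, Y, Z)
    coords = grid.expand(batch, nviews, 3, X, Y, Z)
    feats = torch.rand(batch, nviews, c, h, w)
    origin = torch.zeros(batch, 3)
    proj_matrix = torch.eye(4).repeat(batch, nviews, 1, 1)  # placeholder projections
    volume, counts = back_project_dense_type(coords, origin, 0.05, feats, proj_matrix)
    print(volume.shape, counts.shape)  # torch.Size([1, 2, 8, 8, 8, 8]) torch.Size([1, 2, 1, 8, 8, 8])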
