Skip to content

problems about depth2normal #213

Description

@wyf0414

Hello, thanks for open source your great works!

I've met problems about depth2normal.

I use your model metric3d_vit_giant2 with my image as input to obtain depth and normal.
I call them gt_depth and gt_normal respectively.
Then I use the code depth2normal.py with gt_depth as input to obtain "pred_normal".
Then I Calculate the cosine similarity loss between gt_normal and pred_normal.
In theory, it should be close to 0 but I got 0.1378.
I also visualized them, they are very different.
So where does the problem lie?

img_origin:

Image
gt_normal:

Image
pred_normal:

Image

my code:
The scale_factor is 1/3 because the size of gt_depth and gt_normal is 1/3 of the original image.

import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from mmcv.runner import force_fp32
from torch.cuda.amp.autocast_mode import autocast
import os
import cv2
from torch.nn import init
from collections import OrderedDict
from torch.utils.checkpoint import checkpoint

def get_surface_normalv2(xyz, patch_size=5, mask_valid=None):
"""
xyz: xyz coordinates, in [b, h, w, c]
patch: [p1, p2, p3,
p4, p5, p6,
p7, p8, p9]
surface_normal = [(p9-p1) x (p3-p7)] + [(p6-p4) - (p8-p2)]
return: normal [b, c, h, w]
"""
b, h, w, c = xyz.shape
half_patch = patch_size // 2

if mask_valid == None:
    mask_valid = xyz[:, :, :, 2] > 0 # [b, h, w]
mask_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1), device=mask_valid.device).bool()
mask_pad[:, half_patch:-half_patch, half_patch:-half_patch] = mask_valid

xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device)
xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz

xyz_left = xyz_pad[:, half_patch:half_patch + h, :w, :]  # p4
xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :]  # p6
xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :]  # p2
xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :]  # p8
xyz_horizon = xyz_left - xyz_right  # p4p6
xyz_vertical = xyz_top - xyz_bottom  # p2p8

xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w+1, :]  # p4
xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size-1:patch_size-1+w, :]  # p6
xyz_top_in = xyz_pad[:, 1:h+1, half_patch:half_patch + w, :]  # p2
xyz_bottom_in = xyz_pad[:, patch_size-1:patch_size-1+h, half_patch:half_patch + w, :]  # p8
xyz_horizon_in = xyz_left_in - xyz_right_in  # p4p6
xyz_vertical_in = xyz_top_in - xyz_bottom_in  # p2p8

n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3)
n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3)

# re-orient normals consistently
orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0
n_img_1[orient_mask] *= -1
orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0
n_img_2[orient_mask] *= -1

n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True)  + 1e-4)
n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8)

n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True)  + 1e-4)
n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8)

# average 2 norms
n_img_aver = n_img1_norm + n_img2_norm
n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True) + 1e-4)
n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8)
# re-orient normals consistently
orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0
n_img_aver_norm[orient_mask] *= -1
#n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0))  # [h, w, c, b]

# get mask for normals
mask_p4p6 = mask_pad[:, half_patch:half_patch + h, :w] & mask_pad[:, half_patch:half_patch + h, -w:]
mask_p2p8 = mask_pad[:, :h, half_patch:half_patch + w] & mask_pad[:, -h:, half_patch:half_patch + w]
mask_normal = mask_p2p8 & mask_p4p6
n_img_aver_norm[~mask_normal] = 0

return n_img_aver_norm.permute(0, 3, 1, 2).contiguous(), mask_normal[:, None, :, :] # [b, h, w, 3]

class Depth2Normal(nn.Module):
"""Layer to compute surface normal from depth map
"""
def init(self,):
super(Depth2Normal, self).init()

def init_img_coor(self, height, width):
    """
    Args:
        height (int): image height
        width (int): image width
    """
    y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device="cuda"),
                           torch.arange(0, width, dtype=torch.float32, device="cuda")], indexing='ij')
    meshgrid = torch.stack((x, y))

    # generate homogeneous pixel coordinates
    ones = torch.ones((1, 1, height * width), device="cuda")
    xy = meshgrid.reshape(2, -1).unsqueeze(0)
    xy = torch.cat([xy, ones], 1)
    
    self.register_buffer('xy', xy, persistent=False)

def back_projection(self, depth, inv_K, img_like_out=False, scale=1.0):
    """
    Args:
        depth (Nx1xHxW): depth map
        inv_K (Nx3x3): inverse camera intrinsics
        img_like_out (bool): if True, the output shape is Nx4xHxW; else Nx4x(HxW)
    Returns:
        points (Nx4x(HxW)): 3D points in homogeneous coordinates
    """
    N, C, H, W = depth.shape
    depth = depth.contiguous()
    xy = self.xy
    
    points = torch.matmul(inv_K, xy)
    points = depth.view(depth.shape[0], 1, -1) * points
    depth_descale = points[:, 2:3, :] / scale
    points = torch.cat((points[:, 0:2, :], depth_descale), dim=1)

    if img_like_out:
        points = points.reshape(depth.shape[0], 3, H, W)
    return points

def forward(self, depth, intrinsics, scale=1.0):
    """
    Args:
        depth (Nx1xHxW): depth map
        intrinsics (Nx3x3): camera intrinsics
    Returns:
        normal (Nx3xHxW): normalized surface normal
    """
    N, C, H, W = depth.shape
    if 'xy' not in self._buffers or self.xy.shape[-1] != H*W:
        self.init_img_coor(height=H, width=W)
    # Compute 3D point cloud
    inv_K = intrinsics.inverse()
    
    xyz = self.back_projection(depth, inv_K, scale=scale) # [N, 4, HxW]
    print("Point cloud stats - X:", xyz[:, 0].mean(), "Y:", xyz[:, 1].mean(), "Z:", xyz[:, 2].mean())
    print("Point cloud valid ratio:", (xyz[:, 2] > 0).float().mean())  # 检查有效深度比例

    xyz = xyz.view(depth.shape[0], 3, H, W)
    xyz = xyz[:,:3].permute(0, 2, 3, 1).contiguous() # [b, h, w, c]

    normals, normal_masks = get_surface_normalv2(xyz, mask_valid=None)
    return normals

class NormalBranchLoss(nn.Module):
def init(self,):
super(NormalBranchLoss, self).init()

def forward(self, pred_normal, gt_normal):
    gt_normal_mask = ~torch.all(gt_normal == 0, dim=1, keepdim=True)

    loss = self.forward_R(pred_normal, gt_normal, gt_normal_mask)

    return loss

def forward_R(self, pred_norm, gt_norm, gt_norm_mask):
    dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
    print("Cosine similarity - Mean:", dot.mean(), "Min:", dot.min(), "Max:", dot.max())
    reverse_mask = dot < 0  # 找到方向相反的区域
    reverse_ratio = reverse_mask.float().mean()
    print("Reverse normal ratio:", reverse_ratio.item())

    valid_mask = gt_norm_mask.squeeze(0)
    # valid_mask = gt_norm_mask[:, 0, :, :].float() \
    #                 * (dot.detach() < 0.999).float() \
    #                 * (dot.detach() > -0.999).float()
    valid_mask = valid_mask > 0.5

    dot = dot[valid_mask]
    loss = 1 - torch.mean(dot)

    return loss

if name == 'main':
gt_depth = torch.load("/home/zhudd/Metric3D/vis_nearest/pred_depth.pth").unsqueeze(0).unsqueeze(0)
gt_normal = torch.load("/home/zhudd/Metric3D/vis_nearest/pred_normal.pth").unsqueeze(0)
dst_dir = "d2n_nearest/"

intrins = [[601.5645751953125, 0.0, 963.6561781434721], [0.0, 754.5694580078125, 786.8377374531847], [0.0, 0.0, 1.0]]
intrins = torch.tensor(intrins, dtype=torch.float32).unsqueeze(0).cuda()
scale_factor = 1/3
intrins[:, :2, :] *= scale_factor
d2n = Depth2Normal()
pred_normal = d2n(gt_depth, intrins)
pred_normal = F.normalize(pred_normal, p=2, dim=1)

normal_loss_func = NormalBranchLoss()
loss_depth_metric3d = normal_loss_func(pred_normal, gt_normal)
print("loss_depth_metric3d: ", loss_depth_metric3d)

# os.makedirs(dst_dir, exist_ok=True)
# pred_normal_vis_ = pred_normal.squeeze(0).cpu().detach().numpy().transpose((1, 2, 0))
# pred_normal_vis = (pred_normal_vis_ + 1) / 2
# viz_normal = np.copy(pred_normal_vis)
# max_v, min_v = np.max(pred_normal_vis),np.min(pred_normal_vis)
# viz_normal = (pred_normal_vis - min_v) / (max_v - min_v)
# viz_normal = (viz_normal * 255).astype(np.uint8)
# normal_image = cv2.cvtColor(viz_normal, cv2.COLOR_RGB2BGR)
# cv2.imwrite(dst_dir+"pred_normal.jpg", normal_image)

# pred_normal_vis_ = gt_normal.cpu().squeeze(0).detach().numpy().transpose((1, 2, 0))
# pred_normal_vis = (pred_normal_vis_ + 1) / 2
# viz_normal = np.copy(pred_normal_vis)
# max_v, min_v = np.max(pred_normal_vis),np.min(pred_normal_vis)
# viz_normal = (pred_normal_vis - min_v) / (max_v - min_v)
# viz_normal = (viz_normal * 255).astype(np.uint8)
# normal_image = cv2.cvtColor(viz_normal, cv2.COLOR_RGB2BGR)
# cv2.imwrite(dst_dir+"gt_normal.jpg", normal_image)
# print("write ok")

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions