
The loss of hand axis angle pose will make the effect worse #7

Open
youngstu opened this issue Jun 17, 2021 · 2 comments

youngstu commented Jun 17, 2021

I reproduced the hand training module and found that the axis-angle pose loss may make results worse.
The data pipeline has been verified as correct, yet after the axis-angle loss is added, the predicted hand often flips forward and backward.
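One plausible explanation (my speculation, not something confirmed in this thread): axis-angle is a many-to-one parameterization of rotation, so two vectors that are far apart in L2 can encode exactly the same pose, and an MSE loss on the raw vectors then penalizes geometrically correct predictions. A minimal NumPy sketch of the ambiguity:

```python
import numpy as np

def rodrigues(v):
    """Convert a single axis-angle vector (3,) to a 3x3 rotation matrix."""
    theta = np.linalg.norm(v)
    if theta < 1e-8:
        return np.eye(3)
    k = v / theta
    # skew-symmetric cross-product matrix of the unit axis
    K = np.array([[0.0, -k[2], k[1]],
                  [k[2], 0.0, -k[0]],
                  [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)

theta = 0.3
axis = np.array([0.0, 0.0, 1.0])
v1 = theta * axis
v2 = (theta - 2.0 * np.pi) * axis  # same rotation, different parameterization

R1, R2 = rodrigues(v1), rodrigues(v2)
print(np.allclose(R1, R2))      # True: identical rotations
print(np.mean((v1 - v2) ** 2))  # large MSE despite identical pose
```

A rotation by θ about an axis equals a rotation by θ − 2π about the same axis, so `v1` and `v2` are the same hand pose, yet the pose MSE between them is on the order of (2π)²/3.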

import torch
import torch.nn as nn

class ManoLoss:
    def __init__(
            self,
            lambda_pose=100.0,
            lambda_shape=100.0,
            lambda_joint3d=1.0,
            lambda_kp2d=1.0,
    ):
        self.lambda_pose = lambda_pose
        self.lambda_shape = lambda_shape
        self.lambda_joint3d = lambda_joint3d
        self.lambda_kp2d = lambda_kp2d

        self.criterion_pose = nn.MSELoss().cuda()
        self.criterion_shape = nn.MSELoss().cuda()
        self.criterion_joint3d = nn.MSELoss().cuda()
        self.criterion_kp2d = nn.MSELoss().cuda()

    def compute_loss(self, preds, targs, infos):

        inp_res = infos['inp_res']
        root_id = infos['root_id']
        batch_size = infos['batch_size']
        flag = targs['flag_3d']
        batch_3d_size = flag.sum()

        flag = flag.bool()

        total_loss = torch.Tensor([0]).cuda()
        mano_losses = {}

        gt_pose = targs['pose']
        gt_shape = targs['shape'].float()
        gt_kp2d = targs['kp2d'].float()
        gt_joint3d = targs['joint'] * 1000.0  # metres -> millimetres
        gt_joint3d = gt_joint3d - gt_joint3d[:, root_id:root_id+1, :]  # root-relative joints
   
        for idx, pred in enumerate(preds):

            pred_pose = pred['pose']
            pred_shape = pred['shape']
            pred_kp2d = pred['kp2d']
            pred_joint3d = pred['joint']
            pred_joint3d = pred_joint3d - pred_joint3d[:, root_id:root_id + 1, :]

            # accumulate into total_loss across all predictions; do not reset it here
            if self.lambda_pose:
                pose_loss = self.criterion_pose(pred_pose, gt_pose) * self.lambda_pose
                mano_losses['pose_%d' % idx] = pose_loss
                total_loss += pose_loss

            if self.lambda_shape:
                shape_loss = self.criterion_shape(pred_shape, gt_shape) * self.lambda_shape
                #shape_loss = self.criterion_shape(pred_shape, torch.zeros_like(pred_shape)) * self.lambda_shape
                mano_losses['shape_%d' % idx] = shape_loss
                total_loss += shape_loss

            if self.lambda_joint3d:
                joint3d_loss = self.criterion_joint3d(pred_joint3d, gt_joint3d) * self.lambda_joint3d
                mano_losses['joint3d_%d' % idx] = joint3d_loss
                total_loss += joint3d_loss

            if self.lambda_kp2d:
                kp2d_loss = self.criterion_kp2d(pred_kp2d, gt_kp2d) * self.lambda_kp2d
                mano_losses['kp2d_%d' % idx] = kp2d_loss
                total_loss += kp2d_loss

        mano_losses["total"] = total_loss

        return total_loss, mano_losses, batch_3d_size
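A common mitigation, sketched here as an assumption rather than this repo's actual approach: convert both predicted and ground-truth axis-angle poses to rotation matrices with a batched Rodrigues formula and compute the MSE there, so equivalent parameterizations incur no penalty. `batch_rodrigues` and `rotmat_pose_loss` are hypothetical helper names, not functions from this codebase.

```python
import torch

def batch_rodrigues(aa):
    """Convert (N, 3) axis-angle vectors to (N, 3, 3) rotation matrices."""
    angle = aa.norm(dim=1, keepdim=True).clamp(min=1e-8)  # (N, 1)
    axis = aa / angle
    x, y, z = axis[:, 0], axis[:, 1], axis[:, 2]
    zeros = torch.zeros_like(x)
    # skew-symmetric cross-product matrix K for each axis
    K = torch.stack([zeros, -z, y,
                     z, zeros, -x,
                     -y, x, zeros], dim=1).view(-1, 3, 3)
    eye = torch.eye(3, dtype=aa.dtype, device=aa.device).unsqueeze(0)
    sin = angle.sin().view(-1, 1, 1)
    cos = angle.cos().view(-1, 1, 1)
    # Rodrigues: R = I + sin(t) K + (1 - cos(t)) K^2
    return eye + sin * K + (1.0 - cos) * (K @ K)

def rotmat_pose_loss(pred_pose, gt_pose):
    """MSE between rotation matrices of per-joint axis-angle poses.

    pred_pose / gt_pose: (B, J*3) flattened MANO-style pose vectors.
    """
    r_pred = batch_rodrigues(pred_pose.reshape(-1, 3))
    r_gt = batch_rodrigues(gt_pose.reshape(-1, 3))
    return torch.nn.functional.mse_loss(r_pred, r_gt)
```

Under this formulation the equivalent vectors θ·u and (θ − 2π)·u give a near-zero loss, while `nn.MSELoss` on the raw axis-angle vectors stays large; a geodesic distance between rotation matrices would be another option.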

loginfo: (1000/1018) d: 0.03s | b: 0.31s | s: 72.8770745 | p: 66.2028805 | j: 255.2186516 | k: 137.7673337 | t: 406.3150635 |
loginfo: (1001/1018) d: 0.03s | b: 0.31s | s: 72.8693185 | p: 66.2040951 | j: 255.2534850 | k: 137.7815070 | t: 574.5736694 |
loginfo: (1002/1018) d: 0.03s | b: 0.31s | s: 72.8742214 | p: 66.2062452 | j: 255.1808178 | k: 137.7479424 | t: 432.7313232 |
loginfo: (1003/1018) d: 0.03s | b: 0.31s | s: 72.8717776 | p: 66.2096950 | j: 255.2022860 | k: 137.7690050 | t: 575.6766357 |
loginfo: (1004/1018) d: 0.03s | b: 0.31s | s: 72.8674182 | p: 66.2144529 | j: 255.2273901 | k: 137.7747007 | t: 563.3758545 |
loginfo: (1005/1018) d: 0.03s | b: 0.31s | s: 72.8567293 | p: 66.2034177 | j: 255.2001193 | k: 137.7657703 | t: 473.8689270 |
loginfo: (1006/1018) d: 0.03s | b: 0.31s | s: 72.8619979 | p: 66.2114522 | j: 255.1318979 | k: 137.7335525 | t: 444.3671875 |
loginfo: (1007/1018) d: 0.03s | b: 0.31s | s: 72.8579450 | p: 66.2035864 | j: 255.2208746 | k: 137.7956344 | t: 672.0527344 |
loginfo: (1008/1018) d: 0.03s | b: 0.31s | s: 72.8569219 | p: 66.2097032 | j: 255.2702999 | k: 137.8076349 | t: 599.1296997 |
loginfo: (1009/1018) d: 0.03s | b: 0.31s | s: 72.8681490 | p: 66.2060013 | j: 255.2751612 | k: 137.7991216 | t: 536.0526733 |
loginfo: (1010/1018) d: 0.03s | b: 0.31s | s: 72.8743189 | p: 66.2180742 | j: 255.2236392 | k: 137.7840679 | t: 483.3320618 |
loginfo: (1011/1018) d: 0.03s | b: 0.31s | s: 72.8830146 | p: 66.2242202 | j: 255.2460789 | k: 137.8080345 | t: 594.0219727 |
loginfo: (1012/1018) d: 0.03s | b: 0.31s | s: 72.8847830 | p: 66.2241230 | j: 255.2646785 | k: 137.7912989 | t: 535.7387695 |
loginfo: (1013/1018) d: 0.03s | b: 0.31s | s: 72.8771316 | p: 66.2204524 | j: 255.2239141 | k: 137.7670858 | t: 454.8735657 |
loginfo: (1014/1018) d: 0.03s | b: 0.31s | s: 72.8792146 | p: 66.2185865 | j: 255.1972219 | k: 137.7473550 | t: 485.2358704 |
loginfo: (1015/1018) d: 0.03s | b: 0.31s | s: 72.8804192 | p: 66.2231351 | j: 255.2193977 | k: 137.7601394 | t: 573.3665161 |
loginfo: (1016/1018) d: 0.03s | b: 0.31s | s: 72.8631759 | p: 66.2301331 | j: 255.1524250 | k: 137.7305885 | t: 423.6058655 |
loginfo: (1017/1018) d: 0.03s | b: 0.31s | s: 72.8613670 | p: 66.2278808 | j: 255.2338836 | k: 137.7320333 | t: 612.1587524 |
loginfo: (1018/1018) d: 0.03s | b: 0.31s | s: 72.8622015 | p: 66.2263912 | j: 255.2558645 | k: 137.7465916 | t: 605.0795898 |

@Rookienovice

Could I get your reproduced code to study? Thanks a lot! @youngstu


lvZic commented Aug 3, 2022

Hi, could you share your reproduced training code? Thanks!
