Question about adding segmentation after CenterNet #13

@NeuZhangQiang

Thank you very much for sharing the code.

I haven't managed to get the code set up yet, so I can only read it and cannot step through it in a debugger.

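I have a question. The original CenterNet has three output branches: heatmap (W*H*C), offset (W*H*2), and size (W*H*2). You added a seg_feat branch here. Could you explain how this branch is added, and where in the code that happens? What is the shape of seg_feat? How is a mask assigned to each center point? Is it predicted per location like offset and size, i.e. a seg_feat of shape W*H*W*H?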
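To put rough numbers on that last question, here is a small count. It assumes a 128x128 output map and feat_channel = 8; both values are hypothetical and not confirmed by the repository. The per-object parameter count follows the split sizes in the `torch.split` call of the loss code quoted further down.

```python
# Hypothetical sizes; neither value is confirmed by the repository.
H = W = 128   # output feature-map resolution (assumed)
C = 8         # feat_channel (assumed)

# Split sizes copied from the torch.split call in the quoted forward():
# conv1 weight, conv1 bias, conv2 weight, conv2 bias, conv3 weight, conv3 bias
split_sizes = [(C + 2) * C, C, C * C, C, C, 1]
params_per_object = sum(split_sizes)

# Option 1: every location predicts a full-resolution mask (W*H*W*H values).
dense_mask_head = H * W * H * W
# Option 2: every location predicts one small parameter vector (W*H*K values).
param_map_head = H * W * params_per_object

print(params_per_object)
print(dense_mask_head, param_map_head)
```

If this reading of the code is right, the extra head would output a W*H*K map of convolution parameters rather than a W*H*W*H map, and only the vectors at object centers would be gathered and used.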

Also, I don't quite understand how the dice loss is computed in the code:

    def forward(self, seg_feat, conv_weight, mask,ind, target):
        mask_loss=0.
        batch_size = seg_feat.size(0)
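        # Gather the predicted per-object conv parameters at each center index
        weight = _tranpose_and_gather_feat(conv_weight, ind)
        h, w = seg_feat.size(-2), seg_feat.size(-1)
        # Decode flat indices into (x, y); // keeps the result an integer on current PyTorch
        x, y = ind % w, ind // w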
        x_range = torch.arange(w).float().to(device=seg_feat.device)
        y_range = torch.arange(h).float().to(device=seg_feat.device)
        y_grid, x_grid = torch.meshgrid([y_range, x_range])
        for i in range(batch_size):
            num_obj = target[i].size(0)
            conv1w,conv1b,conv2w,conv2b,conv3w,conv3b= \
                torch.split(weight[i,:num_obj],[(self.feat_channel+2)*self.feat_channel,self.feat_channel,
                                          self.feat_channel**2,self.feat_channel,
                                          self.feat_channel,1],dim=-1)
            y_rel_coord = (y_grid[None,None] - y[i,:num_obj].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).float())/128.
            x_rel_coord = (x_grid[None,None] - x[i,:num_obj].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).float())/128.
            feat = seg_feat[i][None].repeat([num_obj,1,1,1])
            feat = torch.cat([feat,x_rel_coord, y_rel_coord],dim=1).view(1,-1,h,w)

            conv1w=conv1w.contiguous().view(-1,self.feat_channel+2,1,1)
            conv1b=conv1b.contiguous().flatten()
            feat = F.conv2d(feat,conv1w,conv1b,groups=num_obj).relu()

            conv2w=conv2w.contiguous().view(-1,self.feat_channel,1,1)
            conv2b=conv2b.contiguous().flatten()
            feat = F.conv2d(feat,conv2w,conv2b,groups=num_obj).relu()

            conv3w=conv3w.contiguous().view(-1,self.feat_channel,1,1)
            conv3b=conv3b.contiguous().flatten()
            feat = F.conv2d(feat,conv3w,conv3b,groups=num_obj).sigmoid().squeeze()

            true_mask = mask[i,:num_obj,None,None].float()
            mask_loss+=dice_loss(feat*true_mask,target[i]*true_mask)

        return mask_loss/batch_size

The loss also runs convolutions internally? Could you explain the idea behind this?
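One possible reading, not confirmed by the author, is that each gathered weight vector is used as a small stack of 1x1 convolutions for one object, and that `groups=num_obj` applies every object's kernels to its own copy of the features in a single call. The sketch below only checks that mechanism: a grouped 1x1 `F.conv2d` gives the same result as looping over objects. All sizes and tensors are made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_obj, c, h, w = 3, 4, 5, 6            # hypothetical sizes

feat = torch.randn(num_obj, c, h, w)     # one feature copy per object
weight = torch.randn(num_obj, c, c)      # one (out=c, in=c) 1x1 kernel per object
bias = torch.randn(num_obj, c)           # one bias vector per object

# Grouped form: fold the object axis into the channel axis, so that
# group k only sees the channels (and kernels) belonging to object k.
grouped = F.conv2d(
    feat.view(1, num_obj * c, h, w),
    weight.reshape(num_obj * c, c, 1, 1),
    bias.flatten(),
    groups=num_obj,
).view(num_obj, c, h, w)

# Reference form: apply each object's kernel separately.
looped = torch.stack([
    F.conv2d(feat[k:k + 1], weight[k].view(c, c, 1, 1), bias[k])[0]
    for k in range(num_obj)
])

print(torch.allclose(grouped, looped, atol=1e-5))
```

Under this reading, the three grouped convolutions in the quoted `forward` would turn the shared seg_feat, plus coordinates relative to each center, into one sigmoid mask per object, which is then compared with the target by the dice loss.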

Looking forward to your reply.
