DETR Code

승무_ 2022. 9. 2. 17:32
import torch
from torch import nn
from torchvision.models import resnet50


class DETRdemo(nn.Module):
    """
    Demo DETR implementation.

    Demo implementation of DETR in minimal number of lines, with the
    following differences wrt DETR in the paper:
    * learned positional encoding (instead of sine)
    * positional encoding is passed at input (instead of attention)
    * fc bbox predictor (instead of MLP)
    The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.
    Only batch size 1 supported.
    """
    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()

        # use ResNet-50 as the backbone and drop its fully connected classification head
        self.backbone = resnet50()
        del self.backbone.fc

        # 1x1 conv to project the 2048 backbone channels down to hidden_dim
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # use the Transformer implementation that ships with PyTorch
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # num_classes + 1 so the model can also predict the 'no object' (background) class
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # learned object queries (100 x hidden_dim): at most 100 objects are predicted
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

        # the ResNet backbone downsamples the input by a factor of 32, so the feature map's H and W stay under 50
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        # inputs [1, 3, 800, 1066]
        x = self.backbone.conv1(inputs)
        # [1, 64, 400, 533]
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        # [1, 64, 200, 267]

        x = self.backbone.layer1(x)
        # [1, 256, 200, 267]
        x = self.backbone.layer2(x)
        # [1, 512, 100, 134]
        x = self.backbone.layer3(x)
        # [1, 1024, 50, 67]
        x = self.backbone.layer4(x)
        # [1, 2048, 25, 34]

        h = self.conv(x)
        # [1, 256, 25, 34]

        # construct positional encodings
        H, W = h.shape[-2:]
        # 25, 34
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)
        # col_embed[:W] [34, 128]
        # col_embed[:W].unsqueeze(0) [1, 34, 128]
        # unsqueeze(): inserts a dimension of size 1 at the given position
        # col_embed[:W].unsqueeze(0).repeat(H, 1, 1) [25, 34, 128]
        # repeat(): tiles the tensor along each dimension the given number of times
        # row_embed[:H] [25, 128]
        # row_embed[:H].unsqueeze(1) [25, 1, 128]
        # row_embed[:H].unsqueeze(1).repeat(1, W, 1) [25, 34, 128]
        # torch.cat([x,y],dim=-1) [25, 34, 256]
        # torch.cat([x,y],dim=-1).flatten(0, 1) [25x34, 256] [850, 256]
        # flattens the 2D grid of positional embeddings into a 1D sequence
        # torch.cat([x,y],dim=-1).flatten(0, 1).unsqueeze(1) [850, 1, 256]

        h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1)).transpose(0, 1)
        # pos + 0.1 * h.flatten(2).permute(2, 0, 1)
        # the flattened feature sequence (scaled by 0.1) is added element-wise to the positional
        # embeddings and fed to the encoder; the object queries are the decoder input
        # h.flatten(2) [1, 256, 850]
        # h.flatten(2).permute(2, 0, 1) [850, 1, 256]
        # query_pos.unsqueeze(1) [100, 1, 256]
        # h [1, 100, 256]

        return {'pred_logits': self.linear_class(h),
                'pred_boxes': self.linear_bbox(h).sigmoid()}

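The predicted boxes come out of the sigmoid as normalized (center_x, center_y, width, height) relative to the input image, so for visualization they still need to be converted to pixel-space corner coordinates and filtered by class confidence. A rough sketch of that step under the same conventions; the 0.7 threshold and the helper box_cxcywh_to_xyxy below are illustrative (the official DETR repo provides an equivalent helper in util/box_ops.py).

def box_cxcywh_to_xyxy(boxes):
    # (cx, cy, w, h) -> (x0, y0, x1, y1), still in normalized coordinates
    cx, cy, w, h = boxes.unbind(-1)
    return torch.stack([cx - 0.5 * w, cy - 0.5 * h,
                        cx + 0.5 * w, cy + 0.5 * h], dim=-1)

# 'out' and the 800x1066 input size come from the example above
probs = out['pred_logits'].softmax(-1)[0, :, :-1]  # per-query class probabilities, background dropped
keep = probs.max(-1).values > 0.7                  # keep confident predictions (assumed threshold)

img_h, img_w = 800, 1066
boxes = box_cxcywh_to_xyxy(out['pred_boxes'][0, keep])
boxes = boxes * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
# boxes: [num_kept, 4] in pixel (x0, y0, x1, y1) coordinates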