import torch
from torch import nn
from torchvision.models import resnet50


class DETRdemo(nn.Module):
    """
    Demo DETR implementation.
    Demo implementation of DETR in minimal number of lines, with the
    following differences wrt DETR in the paper:
    * learned positional encoding (instead of sine)
    * positional encoding is passed at input (instead of attention)
    * fc bbox predictor (instead of MLP)
    The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.
    Only batch size 1 supported.
    """
    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()
        # Use ResNet-50 as the backbone and remove its final fc layer
        self.backbone = resnet50()
        del self.backbone.fc
        # 1x1 conv to project the 2048 backbone channels down to hidden_dim
        self.conv = nn.Conv2d(2048, hidden_dim, 1)
        # Use the transformer provided by torch (nn.Transformer)
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)
        # num_classes + 1 so the model can also predict the background ("no object") class
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)
        # Create the object queries (100 x hidden_dim)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
        # The ResNet backbone downsamples the input by a factor of 32,
        # so the feature map's W and H do not exceed 50 for typical input sizes
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
    def forward(self, inputs):
        # inputs [1, 3, 800, 1066]
        x = self.backbone.conv1(inputs)
        # [1, 64, 400, 533]
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)
        # [1, 64, 200, 267]
        x = self.backbone.layer1(x)
        # [1, 256, 200, 267]
        x = self.backbone.layer2(x)
        # [1, 512, 100, 134]
        x = self.backbone.layer3(x)
        # [1, 1024, 50, 67]
        x = self.backbone.layer4(x)
        # [1, 2048, 25, 34]
        h = self.conv(x)
        # [1, 256, 25, 34]

        # construct positional encodings
        H, W = h.shape[-2:]
        # H = 25, W = 34
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)
        # col_embed[:W]                                    [34, 128]
        # col_embed[:W].unsqueeze(0)                       [1, 34, 128]
        #   unsqueeze(): inserts a dimension of size 1 at the given position
        # col_embed[:W].unsqueeze(0).repeat(H, 1, 1)       [25, 34, 128]
        #   repeat(): tiles the tensor along each dimension by the given sizes
        # row_embed[:H]                                    [25, 128]
        # row_embed[:H].unsqueeze(1)                       [25, 1, 128]
        # row_embed[:H].unsqueeze(1).repeat(1, W, 1)       [25, 34, 128]
        # torch.cat([x, y], dim=-1)                        [25, 34, 256]
        # torch.cat([x, y], dim=-1).flatten(0, 1)          [25*34, 256] = [850, 256]
        #   flattens the 2D positional embeddings into a 1D sequence
        # torch.cat([x, y], dim=-1).flatten(0, 1).unsqueeze(1)  [850, 1, 256]
        # pos                                              [850, 1, 256]
        h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1)).transpose(0, 1)
        # pos + 0.1 * h.flatten(2).permute(2, 0, 1)
        #   the positional embeddings are added (element-wise, not concatenated)
        #   to the scaled 1D feature sequence and fed to the transformer encoder
        # h.flatten(2)                                     [1, 256, 850]
        # h.flatten(2).permute(2, 0, 1)                    [850, 1, 256]
        # query_pos.unsqueeze(1)                           [100, 1, 256]  (decoder input)
        # h (after transpose)                              [1, 100, 256]
        return {'pred_logits': self.linear_class(h),
                'pred_boxes': self.linear_bbox(h).sigmoid()}
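
# Minimal usage sketch (not part of the original demo code): instantiate the
# model, run a dummy image through it, and convert the normalized
# [cx, cy, w, h] boxes to pixel-space [x0, y0, x1, y1]. The 91 COCO classes,
# the 800x1066 input size, and the 0.7 confidence threshold are illustrative
# choices that simply mirror the shapes traced above.

def box_cxcywh_to_xyxy(boxes, img_w, img_h):
    # boxes: [N, 4] normalized (cx, cy, w, h), as produced by
    # linear_bbox(...).sigmoid() above
    cx, cy, w, h = boxes.unbind(-1)
    x0 = (cx - 0.5 * w) * img_w
    y0 = (cy - 0.5 * h) * img_h
    x1 = (cx + 0.5 * w) * img_w
    y1 = (cy + 0.5 * h) * img_h
    return torch.stack([x0, y0, x1, y1], dim=-1)

model = DETRdemo(num_classes=91).eval()
dummy = torch.rand(1, 3, 800, 1066)      # only batch size 1 is supported

with torch.no_grad():
    out = model(dummy)

print(out['pred_logits'].shape)          # torch.Size([1, 100, 92])
print(out['pred_boxes'].shape)           # torch.Size([1, 100, 4])

# Keep queries whose highest foreground-class probability (last column is the
# "no object" class) exceeds the threshold. With random weights this only
# checks shapes and plumbing; pretrained DETR demo weights are needed for
# meaningful detections.
probs = out['pred_logits'].softmax(-1)[0, :, :-1]    # [100, 91]
keep = probs.max(-1).values > 0.7
boxes = box_cxcywh_to_xyxy(out['pred_boxes'][0, keep], img_w=1066, img_h=800)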