import torch

from lib.models.laneatt import LaneATT
class LaneATTONNX(torch.nn.Module):
    """ONNX-exportable wrapper around a trained LaneATT model.

    Re-implements LaneATT's forward pass using only operators supported by
    ONNX opset 11: `torch.eye` is replaced by precomputed off-diagonal
    indices, and boolean-mask assignment is replaced by multiplication.
    """

    def __init__(self, model):
        """Copy the parameters and layers of a trained `LaneATT` instance.

        Args:
            model: A trained `LaneATT` model in eval mode.
        """
        super().__init__()
        # Params copied from the trained model
        self.fmap_h = model.fmap_h                              # e.g. 11
        self.fmap_w = model.fmap_w                              # e.g. 20
        self.anchor_feat_channels = model.anchor_feat_channels  # e.g. 64
        self.anchors = model.anchors
        self.cut_xs = model.cut_xs
        self.cut_ys = model.cut_ys
        self.cut_zs = model.cut_zs
        self.invalid_mask = model.invalid_mask
        # Derive the proposal count from the anchors instead of hard-coding
        # 1000, so the wrapper works for any `topk_anchors` setting.
        self.n_proposals = self.anchors.shape[0]
        # Layers reused from the trained model
        self.feature_extractor = model.feature_extractor
        self.conv1 = model.conv1
        self.cls_layer = model.cls_layer
        self.reg_layer = model.reg_layer
        self.attention_layer = model.attention_layer
        # Exporting the operator `eye` to ONNX opset version 11 is not
        # supported, so precompute the flat indices of the off-diagonal
        # entries of an (n x n) matrix for the scatter in forward().
        eye = torch.eye(self.n_proposals)
        non_diag = torch.nonzero(eye == 0., as_tuple=False)
        # Flat index = col + n * row; length is n * (n - 1).
        self.non_diag_inds = non_diag[:, 1] + self.n_proposals * non_diag[:, 0]

    def forward(self, x):
        """Compute lane proposals for a single-image batch.

        Args:
            x: Input image tensor of shape (1, 3, H, W).

        Returns:
            Tensor of shape (n_proposals, 2 + 2 + offsets): per-proposal
            class probabilities, two anchor metadata columns, and anchor
            coordinates plus regressed offsets.
        """
        batch_features = self.feature_extractor(x)
        batch_features = self.conv1(batch_features)
        # ONNX-friendly replacement for LaneATT.cut_anchor_features: gather
        # anchor features via flat indexing into the feature map.
        # NOTE(review): 20 and 12 * 20 are the feature-map width / plane size
        # for the 360x640 export input; kept hard-coded as in the original
        # (`batch_features.shape[2:4]` was deliberately avoided) — confirm if
        # exporting at a different resolution.
        flat_features = batch_features[0].flatten()
        batch_anchor_features = flat_features[self.cut_xs + 20 * self.cut_ys + 12 * 20 * self.cut_zs].\
            view(self.n_proposals, self.anchor_feat_channels, self.fmap_h, 1)
        # Zero out invalid positions by multiplication; in-place boolean-mask
        # assignment (`features[mask] = 0`) is not exportable.
        batch_anchor_features = batch_anchor_features * torch.logical_not(self.invalid_mask)
        # Join proposals from all images into a single proposals features batch
        batch_anchor_features = batch_anchor_features.view(-1, self.anchor_feat_channels * self.fmap_h)
        # Add attention features
        softmax = torch.nn.Softmax(dim=1)
        scores = self.attention_layer(batch_anchor_features)
        attention = softmax(scores)
        # Scatter the (n, n-1) attention weights into the off-diagonal of an
        # (n, n) matrix; the diagonal (self-attention) stays zero.
        attention_matrix = torch.zeros(self.n_proposals * self.n_proposals, device=x.device)
        attention_matrix[self.non_diag_inds] = attention.flatten()  # ScatterND
        attention_matrix = attention_matrix.view(self.n_proposals, self.n_proposals)
        attention_features = torch.matmul(torch.transpose(batch_anchor_features, 0, 1),
                                          torch.transpose(attention_matrix, 0, 1)).transpose(0, 1)
        batch_anchor_features = torch.cat((attention_features, batch_anchor_features), dim=1)
        # Predict
        cls_logits = self.cls_layer(batch_anchor_features)
        reg = self.reg_layer(batch_anchor_features)
        # Add offsets to anchors -> (n_proposals, 2 + 2 + offsets)
        reg_proposals = torch.cat([softmax(cls_logits), self.anchors[:, 2:4], self.anchors[:, 4:] + reg], dim=1)
        return reg_proposals
def export_onnx(onnx_file_path,
                backbone_name='resnet18',
                checkpoint_file_path='experiments/laneatt_r18_culane/models/model_0015.pt',
                anchors_freq_path='culane_anchors_freq.pt',
                topk_anchors=1000,
                input_shape=(1, 3, 360, 640)):
    """Export a trained LaneATT checkpoint to an ONNX file.

    Args:
        onnx_file_path: Destination path of the exported .onnx file.
        backbone_name: Backbone identifier passed to `LaneATT`
            (default: the laneatt_r18_culane configuration).
        checkpoint_file_path: Path of the trained checkpoint (.pt).
        anchors_freq_path: Anchor-frequency file used to pick top-k anchors.
        topk_anchors: Number of anchors kept by the model.
        input_shape: Shape of the dummy input used for tracing.
    """
    # Load the specified checkpoint; map to CPU so export also works on
    # machines without the GPU the checkpoint was saved on.
    model = LaneATT(backbone=backbone_name,
                    anchors_freq_path=anchors_freq_path,
                    topk_anchors=topk_anchors)
    checkpoint = torch.load(checkpoint_file_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model.eval()
    # Wrap with the ONNX-friendly module and trace-export at opset 11.
    onnx_model = LaneATTONNX(model)
    dummy_input = torch.randn(*input_shape)
    torch.onnx.export(onnx_model, dummy_input, onnx_file_path, opset_version=11)
if __name__ == '__main__':
    # Export the default (CULane, ResNet-18) checkpoint when run as a script.
    export_onnx('./LaneATT_test.onnx')