# laneatt_to_onnx.py: export a trained LaneATT model to ONNX.

  1. import torch
  2. from lib.models.laneatt import LaneATT
  3. class LaneATTONNX(torch.nn.Module):
  4. def __init__(self, model):
  5. super(LaneATTONNX, self).__init__()
  6. # Params
  7. self.fmap_h = model.fmap_h # 11
  8. self.fmap_w = model.fmap_w # 20
  9. self.anchor_feat_channels = model.anchor_feat_channels # 64
  10. self.anchors = model.anchors
  11. self.cut_xs = model.cut_xs
  12. self.cut_ys = model.cut_ys
  13. self.cut_zs = model.cut_zs
  14. self.invalid_mask = model.invalid_mask
  15. # Layers
  16. self.feature_extractor = model.feature_extractor
  17. self.conv1 = model.conv1
  18. self.cls_layer = model.cls_layer
  19. self.reg_layer = model.reg_layer
  20. self.attention_layer = model.attention_layer
  21. # Exporting the operator eye to ONNX opset version 11 is not supported
  22. attention_matrix = torch.eye(1000)
  23. self.non_diag_inds = torch.nonzero(attention_matrix == 0., as_tuple=False)
  24. self.non_diag_inds = self.non_diag_inds[:, 1] + 1000 * self.non_diag_inds[:, 0] # 999000
  25. def forward(self, x):
  26. batch_features = self.feature_extractor(x)
  27. batch_features = self.conv1(batch_features)
  28. # batch_anchor_features = self.cut_anchor_features(batch_features)
  29. batch_anchor_features = batch_features[0].flatten()
  30. # h, w = batch_features.shape[2:4] # 12, 20
  31. batch_anchor_features = batch_anchor_features[self.cut_xs + 20 * self.cut_ys + 12 * 20 * self.cut_zs].\
  32. view(1000, self.anchor_feat_channels, self.fmap_h, 1)
  33. # batch_anchor_features[self.invalid_mask] = 0
  34. batch_anchor_features = batch_anchor_features * torch.logical_not(self.invalid_mask)
  35. # Join proposals from all images into a single proposals features batch
  36. batch_anchor_features = batch_anchor_features.view(-1, self.anchor_feat_channels * self.fmap_h)
  37. # Add attention features
  38. softmax = torch.nn.Softmax(dim=1)
  39. scores = self.attention_layer(batch_anchor_features)
  40. attention = softmax(scores)
  41. attention_matrix = torch.zeros(1000 * 1000, device=x.device)
  42. attention_matrix[self.non_diag_inds] = attention.flatten() # ScatterND
  43. attention_matrix = attention_matrix.view(1000, 1000)
  44. attention_features = torch.matmul(torch.transpose(batch_anchor_features, 0, 1),
  45. torch.transpose(attention_matrix, 0, 1)).transpose(0, 1)
  46. batch_anchor_features = torch.cat((attention_features, batch_anchor_features), dim=1)
  47. # Predict
  48. cls_logits = self.cls_layer(batch_anchor_features)
  49. reg = self.reg_layer(batch_anchor_features)
  50. # Add offsets to anchors (1000, 2+2+73)
  51. reg_proposals = torch.cat([softmax(cls_logits), self.anchors[:, 2:4], self.anchors[:, 4:] + reg], dim=1)
  52. return reg_proposals
  53. def export_onnx(onnx_file_path):
  54. # e.g. laneatt_r18_culane
  55. backbone_name = 'resnet18'
  56. checkpoint_file_path = 'experiments/laneatt_r18_culane/models/model_0015.pt'
  57. anchors_freq_path = 'culane_anchors_freq.pt'
  58. # Load specified checkpoint
  59. model = LaneATT(backbone=backbone_name, anchors_freq_path=anchors_freq_path, topk_anchors=1000)
  60. checkpoint = torch.load(checkpoint_file_path)
  61. model.load_state_dict(checkpoint['model'])
  62. model.eval()
  63. # Export to ONNX
  64. onnx_model = LaneATTONNX(model)
  65. dummy_input = torch.randn(1, 3, 360, 640)
  66. torch.onnx.export(onnx_model, dummy_input, onnx_file_path, opset_version=11)
  67. if __name__ == '__main__':
  68. export_onnx('./LaneATT_test.onnx')