# gen_wts.py (1018 B)

  1. import os, sys
  2. import torch
  3. import struct
  4. # TODO: YOLOP_BASE_DIR is the root of YOLOP
  5. print("[WARN] Please download/clone YOLOP, then set YOLOP_BASE_DIR to the root of YOLOP")
  6. #YOLOP_BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
  7. YOLOP_BASE_DIR = "/home/catarc/YOLOP"
  8. sys.path.append(YOLOP_BASE_DIR)
  9. from lib.models import get_net
  10. from lib.config import cfg
  11. # Initialize
  12. device = torch.device('cpu')
  13. # Load model
  14. model = get_net(cfg)
  15. checkpoint = torch.load(YOLOP_BASE_DIR + '/weights/End-to-end.pth', map_location=device)
  16. model.load_state_dict(checkpoint['state_dict'])
  17. # load to FP32
  18. model.float()
  19. model.to(device).eval()
  20. f = open('yolop.wts', 'w')
  21. f.write('{}\n'.format(len(model.state_dict().keys())))
  22. for k, v in model.state_dict().items():
  23. vr = v.reshape(-1).cpu().numpy()
  24. f.write('{} {} '.format(k, len(vr)))
  25. for vv in vr:
  26. f.write(' ')
  27. f.write(struct.pack('>f',float(vv)).hex())
  28. f.write('\n')
  29. f.close()
  30. print("save as yolop.wts")