Coverage for /scratch/builds/bob/bob.med.tb/miniconda/conda-bld/bob.med.tb_1637571489937/_test_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placeho/lib/python3.8/site-packages/bob/med/tb/test/test_checkpointer.py: 73%

51 statements  

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import unittest
from collections import OrderedDict
from tempfile import TemporaryDirectory

import torch

from ..utils.checkpointer import Checkpointer


class TestCheckpointer(unittest.TestCase):

    def create_model(self):
        return torch.nn.Sequential(
            torch.nn.Linear(2, 3), torch.nn.Linear(3, 1)
        )

    def create_complex_model(self):
        m = torch.nn.Module()
        m.block1 = torch.nn.Module()
        m.block1.layer1 = torch.nn.Linear(2, 3)
        m.layer2 = torch.nn.Linear(3, 2)
        m.res = torch.nn.Module()
        m.res.layer2 = torch.nn.Linear(3, 2)

        state_dict = OrderedDict()
        state_dict["layer1.weight"] = torch.rand(3, 2)
        state_dict["layer1.bias"] = torch.rand(3)
        state_dict["layer2.weight"] = torch.rand(2, 3)
        state_dict["layer2.bias"] = torch.rand(2)
        state_dict["res.layer2.weight"] = torch.rand(2, 3)
        state_dict["res.layer2.bias"] = torch.rand(2)

        return m, state_dict
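    # NOTE (not in the original file): create_complex_model() and its
    # prefix-mismatched state_dict are not exercised by the tests below,
    # which likely accounts for the partial (73%) coverage reported above.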

    def test_from_last_checkpoint_model(self):
        # save a model, then reload it through the most recent checkpoint
        # recorded in the same folder
        trained_model = self.create_model()
        fresh_model = self.create_model()
        with TemporaryDirectory() as f:
            checkpointer = Checkpointer(trained_model, path=f)
            checkpointer.save("checkpoint_file")

            # in the same folder
            fresh_checkpointer = Checkpointer(fresh_model, path=f)
            assert fresh_checkpointer.has_checkpoint()
            assert fresh_checkpointer.last_checkpoint() == os.path.realpath(
                os.path.join(f, "checkpoint_file.pth")
            )
            _ = fresh_checkpointer.load()

        for trained_p, loaded_p in zip(
            trained_model.parameters(), fresh_model.parameters()
        ):
            # different tensor references
            assert id(trained_p) != id(loaded_p)
            # same content
            assert trained_p.equal(loaded_p)

62 def test_from_name_file_model(self): 

63 # test that loading works even if they differ by a prefix 

64 trained_model = self.create_model() 

65 fresh_model = self.create_model() 

66 with TemporaryDirectory() as f: 

67 checkpointer = Checkpointer(trained_model, path=f) 

68 checkpointer.save("checkpoint_file") 

69 

70 # on different folders 

71 with TemporaryDirectory() as g: 

72 fresh_checkpointer = Checkpointer(fresh_model, path=g) 

73 assert not fresh_checkpointer.has_checkpoint() 

74 assert fresh_checkpointer.last_checkpoint() == None 

75 _ = fresh_checkpointer.load( 

76 os.path.join(f, "checkpoint_file.pth") 

77 ) 

78 

79 for trained_p, loaded_p in zip( 

80 trained_model.parameters(), fresh_model.parameters() 

81 ): 

82 # different tensor references 

83 assert id(trained_p) != id(loaded_p) 

84 # same content 

85 assert trained_p.equal(loaded_p)