1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4import os
5import unittest
6
7from collections import OrderedDict
8from tempfile import TemporaryDirectory
9
10import torch
11
12from ..utils.checkpointer import Checkpointer
13
14
class TestCheckpointer(unittest.TestCase):
    """Unit tests for the project's Checkpointer save/load helpers."""

    def create_model(self):
        """Return a tiny two-layer sequential model (2 -> 3 -> 1)."""
        return torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 1))

    def create_complex_model(self):
        """Return a nested module together with a hand-built state dict.

        The state dict's keys only partially line up with the module's
        parameter names (``layer1.*`` has no direct match on the module).
        """
        model = torch.nn.Module()
        model.block1 = torch.nn.Module()
        model.block1.layer1 = torch.nn.Linear(2, 3)
        model.layer2 = torch.nn.Linear(3, 2)
        model.res = torch.nn.Module()
        model.res.layer2 = torch.nn.Linear(3, 2)

        weights = OrderedDict(
            [
                ("layer1.weight", torch.rand(3, 2)),
                ("layer1.bias", torch.rand(3)),
                ("layer2.weight", torch.rand(2, 3)),
                ("layer2.bias", torch.rand(2)),
                ("res.layer2.weight", torch.rand(2, 3)),
                ("res.layer2.bias", torch.rand(2)),
            ]
        )

        return model, weights

    def test_from_last_checkpoint_model(self):
        # saving and then loading from the same folder must restore the
        # trained weights into a freshly-constructed model
        source = self.create_model()
        target = self.create_model()
        with TemporaryDirectory() as folder:
            Checkpointer(source, path=folder).save("checkpoint_file")

            # a checkpointer pointed at the same folder sees the file
            reloader = Checkpointer(target, path=folder)
            self.assertTrue(reloader.has_checkpoint())
            expected_path = os.path.realpath(
                os.path.join(folder, "checkpoint_file.pth")
            )
            self.assertEqual(reloader.last_checkpoint(), expected_path)
            _ = reloader.load()

            for saved_p, restored_p in zip(
                source.parameters(), target.parameters()
            ):
                # distinct tensor objects...
                self.assertIsNot(saved_p, restored_p)
                # ...carrying identical values
                self.assertTrue(saved_p.equal(restored_p))

    def test_from_name_file_model(self):
        # loading must also work from an explicit file path, even when the
        # fresh checkpointer's own folder holds no checkpoint
        source = self.create_model()
        target = self.create_model()
        with TemporaryDirectory() as folder:
            Checkpointer(source, path=folder).save("checkpoint_file")

            # a checkpointer on a different (empty) folder
            with TemporaryDirectory() as other_folder:
                reloader = Checkpointer(target, path=other_folder)
                self.assertFalse(reloader.has_checkpoint())
                self.assertIsNone(reloader.last_checkpoint())
                _ = reloader.load(os.path.join(folder, "checkpoint_file.pth"))

                for saved_p, restored_p in zip(
                    source.parameters(), target.parameters()
                ):
                    # distinct tensor objects...
                    self.assertIsNot(saved_p, restored_p)
                    # ...carrying identical values
                    self.assertTrue(saved_p.equal(restored_p))

    def test_checkpointer_process(self):
        from ...binseg.configs.models.lwnet import model
        from ...binseg.engine.trainer import checkpointer_process

        with TemporaryDirectory() as folder:
            checkpointer = Checkpointer(model, path=folder)
            checkpoint_period = 0
            valid_loss = 1.5
            arguments = {"epoch": 0, "max_epoch": 10}
            epoch = 1
            max_epoch = 10

            # validation loss worse than the current best: best is kept
            best_loss = checkpointer_process(
                checkpointer,
                checkpoint_period,
                valid_loss,
                0.001,
                arguments,
                epoch,
                max_epoch,
            )
            self.assertEqual(best_loss, 0.001)

            # validation loss better than the current best: best is updated
            best_loss = checkpointer_process(
                checkpointer,
                checkpoint_period,
                valid_loss,
                1000,
                arguments,
                epoch,
                max_epoch,
            )
            self.assertEqual(best_loss, 1.5)