Hot-keys on this page:
r, m, x, p: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
4import os
5import unittest
7from collections import OrderedDict
8from tempfile import TemporaryDirectory
10import torch
12from ..utils.checkpointer import Checkpointer
15class TestCheckpointer(unittest.TestCase):
16 def create_model(self):
17 return torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 1))
19 def create_complex_model(self):
20 m = torch.nn.Module()
21 m.block1 = torch.nn.Module()
22 m.block1.layer1 = torch.nn.Linear(2, 3)
23 m.layer2 = torch.nn.Linear(3, 2)
24 m.res = torch.nn.Module()
25 m.res.layer2 = torch.nn.Linear(3, 2)
27 state_dict = OrderedDict()
28 state_dict["layer1.weight"] = torch.rand(3, 2)
29 state_dict["layer1.bias"] = torch.rand(3)
30 state_dict["layer2.weight"] = torch.rand(2, 3)
31 state_dict["layer2.bias"] = torch.rand(2)
32 state_dict["res.layer2.weight"] = torch.rand(2, 3)
33 state_dict["res.layer2.bias"] = torch.rand(2)
35 return m, state_dict
37 def test_from_last_checkpoint_model(self):
38 # test that loading works even if they differ by a prefix
39 trained_model = self.create_model()
40 fresh_model = self.create_model()
41 with TemporaryDirectory() as f:
42 checkpointer = Checkpointer(trained_model, path=f)
43 checkpointer.save("checkpoint_file")
45 # in the same folder
46 fresh_checkpointer = Checkpointer(fresh_model, path=f)
47 assert fresh_checkpointer.has_checkpoint()
48 assert fresh_checkpointer.last_checkpoint() == os.path.realpath(
49 os.path.join(f, "checkpoint_file.pth")
50 )
51 _ = fresh_checkpointer.load()
53 for trained_p, loaded_p in zip(
54 trained_model.parameters(), fresh_model.parameters()
55 ):
56 # different tensor references
57 assert id(trained_p) != id(loaded_p)
58 # same content
59 assert trained_p.equal(loaded_p)
61 def test_from_name_file_model(self):
62 # test that loading works even if they differ by a prefix
63 trained_model = self.create_model()
64 fresh_model = self.create_model()
65 with TemporaryDirectory() as f:
66 checkpointer = Checkpointer(trained_model, path=f)
67 checkpointer.save("checkpoint_file")
69 # on different folders
70 with TemporaryDirectory() as g:
71 fresh_checkpointer = Checkpointer(fresh_model, path=g)
72 assert not fresh_checkpointer.has_checkpoint()
73 assert fresh_checkpointer.last_checkpoint() is None
74 _ = fresh_checkpointer.load(
75 os.path.join(f, "checkpoint_file.pth")
76 )
78 for trained_p, loaded_p in zip(
79 trained_model.parameters(), fresh_model.parameters()
80 ):
81 # different tensor references
82 assert id(trained_p) != id(loaded_p)
83 # same content
84 assert trained_p.equal(loaded_p)