source: framspy/evolalg/experiment.py @ 1127

Last change on this file since 1127 was 1127, checked in by Maciej Komosinski, 3 years ago

Experiment: measure running time and save checkpoints safely

File size: 2.8 KB
Line 
1import os
2from typing import List, Callable, Union
3
4from evolalg.base.step import Step
5import pickle
6import time
7
8from evolalg.base.union_step import UnionStep
9from evolalg.selection.selection import Selection
10from evolalg.utils.stable_generation import StableGeneration
11
12
class Experiment:
    """Drive an evolutionary experiment: initialize, evolve, finalize.

    The initial population is built by chaining the ``init_population``
    callables. Each generation is then produced by a ``StableGeneration``
    step (built from ``selection``, ``new_generation_steps`` and
    ``population_size``) followed by the ``generation_modification`` steps.
    After the last generation, ``end_steps`` are applied once.

    The whole object is pickled to ``checkpoint_path`` every
    ``checkpoint_interval`` generations and can be loaded back with
    :meth:`restore`; :meth:`run` then resumes after ``self.generation``.
    """

    def __init__(self, init_population: List[Callable],
                 selection: Selection,
                 new_generation_steps: List[Union[Callable, Step]],
                 generation_modification: List[Union[Callable, Step]],
                 end_steps: List[Union[Callable, Step]],
                 population_size,
                 checkpoint_path=None, checkpoint_interval=None):
        """Configure the experiment; no population is created until :meth:`init`.

        Args:
            init_population: callables applied in order to build the initial
                population; each receives the previous one's result.
            selection: selection used by the per-generation ``StableGeneration``.
            new_generation_steps: steps used by ``StableGeneration`` to produce
                a new generation.
            generation_modification: steps applied (as one ``UnionStep``) to
                each new generation.
            end_steps: steps applied (as one ``UnionStep``) after the last
                generation.
            population_size: passed to ``StableGeneration``.
            checkpoint_path: file to pickle the experiment to; ``None``
                disables checkpointing.
            checkpoint_interval: save a checkpoint every this many
                generations; ``None`` disables checkpointing.
        """
        self.init_population = init_population
        # Seconds spent evolving (checkpoint I/O excluded). Stored on the
        # object, so it is pickled with checkpoints and keeps accumulating
        # after a restore.
        self.running_time = 0
        self.step = StableGeneration(
            selection=selection,
            steps=new_generation_steps,
            population_size=population_size)
        self.generation_modification = UnionStep(generation_modification)

        self.end_steps = UnionStep(end_steps)

        self.checkpoint_path = checkpoint_path
        self.checkpoint_interval = checkpoint_interval
        # Number of the most recently produced generation (0 = none yet).
        self.generation = 0
        self.population = None

    def init(self):
        """Reset the generation counter, initialize all steps and build the
        initial population."""
        self.generation = 0
        for s in self.init_population:
            if isinstance(s, Step):
                s.init()

        self.step.init()
        self.generation_modification.init()
        self.end_steps.init()

        # Chain the initializers: each gets the previous result; the first one
        # gets the current self.population (None for a fresh experiment).
        for s in self.init_population:
            self.population = s(self.population)

    def run(self, num_generations):
        """Evolve up to generation ``num_generations`` (inclusive), then apply
        the end steps.

        Iteration starts at ``self.generation + 1``, so an experiment loaded
        with :meth:`restore` continues where its checkpoint left off. A
        checkpoint is saved after every generation divisible by
        ``checkpoint_interval`` when both checkpoint settings are given.
        """
        for i in range(self.generation + 1, num_generations + 1):
            # perf_counter is a monotonic clock meant for measuring durations,
            # unlike time.time(), which can jump when the system clock changes.
            start_time = time.perf_counter()
            self.generation = i
            self.population = self.step(self.population)
            self.population = self.generation_modification(self.population)

            # Only evolution is timed; checkpoint saving below is excluded.
            self.running_time += time.perf_counter() - start_time
            if (self.checkpoint_path is not None
                    and self.checkpoint_interval is not None
                    and i % self.checkpoint_interval == 0):
                self.save_checkpoint()

        self.population = self.end_steps(self.population)

    def save_checkpoint(self):
        """Pickle this experiment to ``self.checkpoint_path`` safely.

        The data is first written to ``<checkpoint_path>_tmp`` and only then
        moved over the real checkpoint with ``os.replace``. An incomplete
        write (e.g. not enough free space) therefore never destroys the
        previous good checkpoint.

        Raises:
            RuntimeError: if writing or replacing fails; the underlying error
                is chained as ``__cause__``.
        """
        # os.fspath accepts both str and os.PathLike (e.g. pathlib.Path) paths.
        tmp_filepath = os.fspath(self.checkpoint_path) + "_tmp"
        try:
            with open(tmp_filepath, "wb") as file:
                pickle.dump(self, file)
            # Replace only after the new file was fully written.
            os.replace(tmp_filepath, self.checkpoint_path)
        except Exception as ex:
            # The temporary file is deliberately kept: if only os.replace
            # failed, it still holds a complete checkpoint.
            raise RuntimeError("Failed to save checkpoint '%s' (because: %s). This does not prevent the experiment from continuing, but let's stop here to fix the problem with saving checkpoints." % (tmp_filepath, ex)) from ex

    @staticmethod
    def restore(path):
        """Load and return an experiment saved by :meth:`save_checkpoint`.

        Security: unpickling can execute arbitrary code, so only restore
        checkpoint files from trusted sources.
        """
        # Pickle data is binary; the file must be opened in "rb" mode
        # (text mode makes pickle.load fail with a TypeError).
        with open(path, "rb") as file:
            return pickle.load(file)
Note: See TracBrowser for help on using the repository browser.