import os
import pickle
from typing import List, Callable, Union

from evolalg.base.step import Step
from evolalg.base.union_step import UnionStep
from evolalg.selection.selection import Selection
from evolalg.utils.stable_generation import StableGeneration


class Experiment:
    """Run an evolutionary experiment: initialize, evolve, then finalize.

    Optionally pickles the whole experiment every ``checkpoint_interval``
    generations to ``checkpoint_path``. A run can then be resumed by loading
    that file with :meth:`restore` and calling :meth:`run` again.
    """

    def __init__(self, init_population: List[Callable],
                 selection: Selection,
                 new_generation_steps: List[Union[Callable, Step]],
                 generation_modification: List[Union[Callable, Step]],
                 end_steps: List[Union[Callable, Step]],
                 population_size,
                 checkpoint_path=None, checkpoint_interval=None):
        """Configure the experiment; no steps are run until :meth:`init`.

        Args:
            init_population: callables applied in order by :meth:`init` to
                build the initial population.
            selection: selection forwarded to the ``StableGeneration`` step.
            new_generation_steps: steps forwarded to ``StableGeneration``
                as ``steps``.
            generation_modification: steps combined into a ``UnionStep``
                and applied to the population after every generation.
            end_steps: steps combined into a ``UnionStep`` and applied once
                after the last generation.
            population_size: forwarded to ``StableGeneration``.
            checkpoint_path: file the experiment is pickled to. ``None``
                disables checkpointing.
            checkpoint_interval: checkpoint every this many generations.
                ``None`` or ``0`` disables checkpointing.
        """
        self.init_population = init_population
        self.step = StableGeneration(
            selection=selection,
            steps=new_generation_steps,
            population_size=population_size)
        self.generation_modification = UnionStep(generation_modification)

        self.end_steps = UnionStep(end_steps)

        self.checkpoint_path = checkpoint_path
        self.checkpoint_interval = checkpoint_interval
        self.generation = 0  # number of the last completed generation
        self.population = None

    def init(self):
        """Reset the generation counter, initialize all steps and build the population.

        ``init()`` is called on every ``init_population`` entry that is a
        ``Step``, then on the generation step, the modification steps and the
        end steps. The ``init_population`` callables are then chained: each
        receives the previous result. The first receives the current
        ``self.population``, which is ``None`` on a fresh experiment.
        """
        self.generation = 0
        for s in self.init_population:
            if isinstance(s, Step):
                s.init()

        self.step.init()
        self.generation_modification.init()
        self.end_steps.init()

        for s in self.init_population:
            self.population = s(self.population)

    def run(self, num_generations: int) -> None:
        """Evolve up to generation ``num_generations`` (inclusive), then apply end steps.

        Iteration starts after ``self.generation``. On a restored
        experiment, this resumes after the last checkpointed generation.
        ``end_steps`` run once after the loop, even if no generation was run.
        """
        for i in range(self.generation + 1, num_generations + 1):
            self.generation = i
            self.population = self.step(self.population)
            self.population = self.generation_modification(self.population)

            # A falsy interval (None or 0) disables checkpointing; this
            # also avoids a ZeroDivisionError in the modulo below.
            if (self.checkpoint_path is not None
                    and self.checkpoint_interval
                    and i % self.checkpoint_interval == 0):
                self._save_checkpoint()

        self.population = self.end_steps(self.population)

    def _save_checkpoint(self):
        """Pickle this experiment to ``checkpoint_path`` atomically.

        The data is written to a sibling ``.tmp`` file and then moved over
        the target with ``os.replace``. If pickling raises (e.g. because a
        step is an unpicklable lambda) or the process is interrupted, the
        previous checkpoint stays intact instead of being truncated.
        """
        path = os.fsdecode(self.checkpoint_path)
        tmp_path = path + ".tmp"
        try:
            with open(tmp_path, "wb") as file:
                pickle.dump(self, file)
            os.replace(tmp_path, path)
        except BaseException:
            # Best-effort removal of the partial temp file. The original
            # error is what the caller needs to see, so it is re-raised.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    @staticmethod
    def restore(path):
        """Load and return the object pickled to ``path`` by a checkpoint.

        Warning: unpickling can execute arbitrary code. Only restore
        checkpoint files from trusted sources.
        """
        # Pickle data is binary; opening in text mode makes pickle.load fail.
        with open(path, "rb") as file:
            return pickle.load(file)
