import argparse
import bisect
import collections
import copy
import json
import math
import random
import time as timelib

import numpy as np
from matplotlib import colors
from PIL import Image, ImageDraw, ImageFont
from scipy import stats

class LoadingError(Exception):
    """Exception type intended for failures while loading input data."""

class Drawer:
    """Common base for the tree renderers.

    Holds the canvas geometry and the visual settings (defaults, optionally
    deep-merged with a JSON config file). It maps the designer's normalized
    node properties onto colors, sizes, widths and opacities.

    Subclasses supply the output primitives (``add_dot``, ``add_line``,
    ``add_dashed_line``, ``add_text``) and the style builders
    (``compute_dot_style``, ``compute_line_style``).
    """

    def __init__(self, design, config_file, w=600, h=800, w_margin=10, h_margin=20):
        """
        :param design: designer object; this class reads its ``positions``,
            ``props``, ``tree`` and ``TIME`` attributes.
        :param config_file: path of a JSON file whose contents are deep-merged
            over the default settings. ``""`` (or any falsy value) means none.
        :param w: canvas width.
        :param h: canvas height.
        :param w_margin: horizontal margin on each side.
        :param h_margin: vertical margin on each side.
        """
        self.design = design
        self.width = w
        self.height = h
        self.w_margin = w_margin
        self.h_margin = h_margin
        self.w_no_margs = w - 2 * w_margin
        self.h_no_margs = h - 2 * h_margin

        self.color_converter = colors.ColorConverter()

        self.settings = {
            'colors_of_kinds': ['red', 'green', 'blue', 'magenta', 'yellow', 'cyan', 'orange', 'purple'],
            'dots': {
                'color': {
                    'meaning': 'Lifespan',
                    'normalize_cmap': False,
                    'cmap': {},
                    'start': 'red',
                    'end': 'green',
                    'bias': 1
                    },
                'size': {
                    'meaning': 'EnergyEaten',
                    'start': 1,
                    'end': 6,
                    'bias': 0.5
                    },
                'opacity': {
                    'meaning': 'EnergyEaten',
                    'start': 0.2,
                    'end': 1,
                    'bias': 1
                    }
            },
            'lines': {
                'color': {
                    'meaning': 'adepth',
                    'normalize_cmap': False,
                    'cmap': {},
                    'start': 'black',
                    'end': 'red',
                    'bias': 3
                    },
                'width': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 4,
                    'bias': 3
                    },
                'opacity': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 0.8,
                    'bias': 5
                    }
            }
        }

        def merge(source, destination):
            """Recursively merge ``source`` into ``destination`` in place and return it."""
            for key, value in source.items():
                if isinstance(value, dict):
                    node = destination.get(key)
                    if not isinstance(node, dict):
                        # A non-dict default cannot absorb nested keys; replace it.
                        node = destination[key] = {}
                    merge(value, node)
                else:
                    destination[key] = value
            return destination

        if config_file:
            with open(config_file, encoding="utf-8") as config:
                overrides = json.load(config)
            self.settings = merge(overrides, self.settings)

        self.compile_cmaps()

    @staticmethod
    def _normalize_cmap_points(cmap):
        """Rescale, in place, the x of every anchor of every channel to [0, 1].

        Each channel is rescaled using its own first and last anchor. Both
        bounds are captured before the loop, so rewriting the first anchor
        cannot change how the remaining anchors are scaled.

        :param cmap: matplotlib segment data, ``{channel: [[x, y0, y1], ...]}``.
        :returns: the same (mutated) ``cmap``.
        :raises ValueError: if a channel's first and last anchors share one x.
        """
        for key, points in cmap.items():
            low = points[0][0]
            high = points[-1][0]
            span = high - low
            if span == 0:
                raise ValueError("cmap channel %r needs anchors at two distinct positions" % key)
            for point in points:
                point[0] = (point[0] - low) / span
        return cmap

    def _fit_cmap_to_range(self, cmap, low, high):
        """Make the segment data in ``cmap`` span exactly [low, high], in place.

        Channels starting after ``low`` or ending before ``high`` are first
        extended with a constant segment. Each channel's end anchors are then
        replaced by the color the original map has at ``low`` / ``high``.
        Anchors outside the open range are dropped so x stays increasing.
        """
        for points in cmap.values():
            if points[0][0] > low:
                points.insert(0, points[0][:])
                points[0][0] = low
            if points[-1][0] < high:
                points.append(points[-1][:])
                points[-1][0] = high

        reference = colors.LinearSegmentedColormap(
            'Custom', self._normalize_cmap_points(copy.deepcopy(cmap)))

        channel_index = {'red': 0, 'green': 1, 'blue': 2, 'alpha': 3}
        for key, points in cmap.items():
            first, last = points[0][0], points[-1][0]
            index = channel_index[key]
            # The reference map normalizes every channel by its own extent,
            # so the data bounds are mapped onto this channel's axis.
            low_value = reference((low - first) / (last - first))[index]
            high_value = reference((high - first) / (last - first))[index]
            inner = [point for point in points if low < point[0] < high]
            points[:] = [[low, low_value, low_value]] + inner + [[high, high_value, high_value]]

    def compile_cmaps(self):
        """Replace each configured cmap dict with a compiled matplotlib colormap.

        Parts whose ``cmap`` is empty keep start/end interpolation. When
        ``normalize_cmap`` is set, the anchors are first fitted to the raw
        range ``[<meaning>_min, <meaning>_max]`` stored in ``design.props``.
        """
        for part in ('dots', 'lines'):
            color = self.settings[part]['color']
            if not color['cmap']:
                continue
            if color['normalize_cmap']:
                meaning = color['meaning']
                self._fit_cmap_to_range(color['cmap'],
                                        self.design.props[meaning + "_min"],
                                        self.design.props[meaning + "_max"])
            color['cmap'] = colors.LinearSegmentedColormap(
                'Custom', self._normalize_cmap_points(color['cmap']))

    def _to_canvas(self, pos, min_width, max_width, max_height):
        """Map a layout position (dict with 'x' and 'y') to canvas coordinates.

        A zero horizontal or vertical extent (e.g. a single node) is treated
        as 1 to avoid dividing by zero. The point then lands on the margin.
        """
        width_span = (max_width - min_width) or 1
        height_span = max_height or 1
        return (self.w_margin + self.w_no_margs * (pos['x'] - min_width) / width_span,
                self.h_margin + self.h_no_margs * pos['y'] / height_span)

    def draw_dots(self, file, min_width, max_width, max_height):
        """Draw one dot per positioned node (nodes without 'x' are skipped)."""
        for index, node in enumerate(self.design.positions):
            if 'x' not in node:
                continue
            dot_style = self.compute_dot_style(node=index)
            self.add_dot(file, self._to_canvas(node, min_width, max_width, max_height), dot_style)

    def draw_lines(self, file, min_width, max_width, max_height):
        """Draw a line from every positioned node to each of its positioned children."""
        positions = self.design.positions
        for parent, par_pos in enumerate(positions):
            if 'x' not in par_pos:
                continue
            start = self._to_canvas(par_pos, min_width, max_width, max_height)
            for child in self.design.tree.children[parent]:
                chi_pos = positions[child]
                if 'x' not in chi_pos:
                    continue
                line_style = self.compute_line_style(parent, child)
                self.add_line(file, start, self._to_canvas(chi_pos, min_width, max_width, max_height),
                              line_style)

    def draw_scale(self, file, filenames):
        """Draw the source caption and the dashed start/end markers of the vertical axis.

        :param filenames: list of input file names; the first one is shown.
        """
        # Accept both Windows and POSIX separators when extracting the base name.
        source = filenames[0].replace("\\", "/").split("/")[-1]
        extra = (" and " + str(len(filenames) - 1) + " more") if len(filenames) > 1 else ""
        self.add_text(file, "Generated from " + source + extra, (5, 5), "start")

        start_text = ""
        end_text = ""
        if self.design.TIME == "BIRTHS":
            start_text = "Birth #0"
            end_text = "Birth #" + str(len(self.design.positions) - 1)
        elif self.design.TIME == "REAL":
            start_text = "Time " + str(min(self.design.tree.time))
            end_text = "Time " + str(max(self.design.tree.time))
        elif self.design.TIME == "GENERATIONAL":
            start_text = "Depth " + str(self.design.props['adepth_min'])
            end_text = "Depth " + str(self.design.props['adepth_max'])

        self.add_dashed_line(file, (self.width * 0.7, self.h_margin), (self.width, self.h_margin))
        self.add_text(file, start_text, (self.width, self.h_margin), "end")
        self.add_dashed_line(file, (self.width * 0.7, self.height - self.h_margin),
                             (self.width, self.height - self.h_margin))
        self.add_text(file, end_text, (self.width, self.height - self.h_margin), "end")

    def compute_property(self, part, prop, node):
        """Compute the visual value of ``prop`` ('color', 'size', ...) of ``part`` for ``node``.

        The node's value of the configured ``meaning`` property is used (0 if
        that property is unknown) and mapped between ``start`` and ``end``.
        """
        spec = self.settings[part][prop]
        meaning = spec['meaning']
        value = self.design.props[meaning][node] if meaning in self.design.props else 0
        if prop == "color":
            if not spec['cmap']:
                return self.compute_color(spec['start'], spec['end'], value, spec['bias'])
            return self.compute_color_from_cmap(spec['cmap'], value, spec['bias'])
        return self.compute_value(spec['start'], spec['end'], value, spec['bias'])

    def compute_color_from_cmap(self, cmap, value, bias=1):
        """Look up a biased ``value`` in a compiled colormap; return (r, g, b) in percent."""
        value = 1 - (1 - value) ** bias
        rgba = cmap(value)
        return (100 * rgba[0], 100 * rgba[1], 100 * rgba[2])

    def compute_color(self, start, end, value, bias=1):
        """Return an (r, g, b) color in percent.

        A string ``value`` is read as a kind index into ``colors_of_kinds``;
        indices beyond the palette wrap around. A numeric ``value`` in [0, 1]
        blends ``start`` toward ``end`` after applying ``bias``.
        """
        if isinstance(value, str):
            palette = self.settings['colors_of_kinds']
            r, g, b = self.color_converter.to_rgb(palette[int(value) % len(palette)])
        else:
            start_color = self.color_converter.to_rgb(start)
            end_color = self.color_converter.to_rgb(end)
            value = 1 - (1 - value) ** bias
            r = start_color[0] * (1 - value) + end_color[0] * value
            g = start_color[1] * (1 - value) + end_color[1] * value
            b = start_color[2] * (1 - value) + end_color[2] * value
        return (100 * r, 100 * g, 100 * b)

    def compute_value(self, start, end, value, bias=1):
        """Interpolate between ``start`` and ``end`` at ``1 - (1 - value) ** bias``."""
        value = 1 - (1 - value) ** bias
        return start * (1 - value) + end * value

class PngDrawer(Drawer):
    """Raster renderer built on Pillow.

    Drawing happens on a canvas enlarged by ``multi``. The result is then
    downsampled to the configured size, which smooths edges
    (supersampling).
    """

    def scale_up(self):
        """Multiply all canvas dimensions by ``self.multi``."""
        self.width *= self.multi
        self.height *= self.multi
        self.w_margin *= self.multi
        self.h_margin *= self.multi
        self.h_no_margs *= self.multi
        self.w_no_margs *= self.multi

    def scale_down(self):
        """Divide all canvas dimensions by ``self.multi`` (note: yields floats)."""
        self.width /= self.multi
        self.height /= self.multi
        self.w_margin /= self.multi
        self.h_margin /= self.multi
        self.h_no_margs /= self.multi
        self.w_no_margs /= self.multi

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Render the design and save it to ``filename``.

        :param input_filename: list of source file names shown in the caption.
        :param multi: supersampling factor applied while drawing.
        :param scale: ``"SIMPLE"`` draws the caption and axis markers.
        """
        print("Drawing...")

        self.multi = multi
        # Restore the exact original geometry afterwards (even on error).
        # Dividing back by ``multi`` would turn ints into floats, and the next
        # Image.new call would then fail.
        geometry = (self.width, self.height, self.w_margin, self.h_margin,
                    self.w_no_margs, self.h_no_margs)
        self.scale_up()
        try:
            back = Image.new('RGBA', (int(self.width), int(self.height)), (255, 255, 255, 0))

            placed = [p for p in self.design.positions if 'x' in p]
            min_width = min(p['x'] for p in placed)
            max_width = max(p['x'] for p in placed)
            max_height = max([p['y'] for p in self.design.positions if 'y' in p])

            self.draw_lines(back, min_width, max_width, max_height)
            self.draw_dots(back, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(back, input_filename)
        finally:
            (self.width, self.height, self.w_margin, self.h_margin,
             self.w_no_margs, self.h_no_margs) = geometry

        # ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        back.thumbnail((self.width, self.height), Image.LANCZOS)
        back.save(filename)

    @staticmethod
    def _rgba(color, opacity):
        """Convert a percent-based (r, g, b) color and a 0..1 opacity to 8-bit RGBA."""
        return (int(2.55 * color[0]), int(2.55 * color[1]), int(2.55 * color[2]), int(255 * opacity))

    def add_dot(self, file, pos, style):
        """Composite a filled circle of radius ``style['r']`` (scaled) centered at ``pos``."""
        x, y = int(pos[0]), int(pos[1])
        r = style['r'] * self.multi
        offset = (int(x - r), int(y - r))
        # Keep the box at least 2 px so the ellipse bounds stay valid for tiny radii.
        diameter = max(2, 2 * int(r))

        img = Image.new('RGBA', (diameter, diameter))
        ImageDraw.Draw(img).ellipse((1, 1, diameter - 1, diameter - 1),
                                    self._rgba(style['color'], style['opacity']))
        file.paste(img, offset, mask=img)

    def add_line(self, file, from_pos, to_pos, style):
        """Composite a straight line between two canvas points."""
        fx, fy = int(from_pos[0]), int(from_pos[1])
        tx, ty = int(to_pos[0]), int(to_pos[1])
        # Scale before truncating so fractional widths survive supersampling.
        # At least 1 px, so thin horizontal/vertical lines are not skipped.
        w = max(1, int(style['width'] * self.multi))

        offset = (min(fx, tx) - w, min(fy, ty) - w)
        size = (abs(fx - tx) + 2 * w, abs(fy - ty) + 2 * w)

        if (fx - tx) * (fy - ty) > 0:
            coords = (w, w, size[0] - w, size[1] - w)   # top-left to bottom-right
        else:
            coords = (size[0] - w, w, w, size[1] - w)   # top-right to bottom-left

        img = Image.new('RGBA', size)
        ImageDraw.Draw(img).line(coords, self._rgba(style['color'], style['opacity']), w)
        file.paste(img, offset, mask=img)

    def add_dashed_line(self, file, from_pos, to_pos):
        """Draw a black dashed line made of 50 dashes separated by equal gaps."""
        style = {'color': (0, 0, 0), 'width': 1, 'opacity': 1}
        sublines = 50
        normdiv = 2 * sublines - 1
        for i in range(sublines):
            dash_start = 2 * i / normdiv
            dash_end = (2 * i + 1) / normdiv
            from_pos_sub = (self.compute_value(from_pos[0], to_pos[0], dash_start, 1),
                            self.compute_value(from_pos[1], to_pos[1], dash_start, 1))
            to_pos_sub = (self.compute_value(from_pos[0], to_pos[0], dash_end, 1),
                          self.compute_value(from_pos[1], to_pos[1], dash_end, 1))
            self.add_line(file, from_pos_sub, to_pos_sub, style)

    @staticmethod
    def _text_width(draw, text, font):
        """Width of ``text`` as rendered (``textsize`` was removed in Pillow 10)."""
        if hasattr(draw, "textbbox"):
            return draw.textbbox((0, 0), text, font=font)[2]
        return draw.textsize(text, font=font)[0]

    def add_text(self, file, text, pos, anchor, style=''):
        """Draw black text at ``pos``; with any anchor other than "start", ``pos`` is its right end."""
        font = ImageFont.truetype("Vera.ttf", int(16 * self.multi))

        img = Image.new('RGBA', (int(self.width), int(self.height)))
        draw = ImageDraw.Draw(img)
        if anchor != "start":
            pos = (pos[0] - self._text_width(draw, text, font), pos[1])
        draw.text(pos, text, (0, 0, 0), font=font)
        file.paste(img, (0, 0), mask=img)

    def compute_line_style(self, parent, child):
        """Style dict (color, width, opacity) of the line leading to ``child``."""
        return {'color': self.compute_property('lines', 'color', child),
                'width': self.compute_property('lines', 'width', child),
                'opacity': self.compute_property('lines', 'opacity', child)}

    def compute_dot_style(self, node):
        """Style dict (color, radius, opacity) of the dot for ``node``."""
        return {'color': self.compute_property('dots', 'color', node),
                'r': self.compute_property('dots', 'size', node),
                'opacity': self.compute_property('dots', 'opacity', node)}

class SvgDrawer(Drawer):
    """Vector renderer that writes the design as SVG markup."""

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Write the design to ``filename`` as an SVG document.

        :param input_filename: list of source file names shown in the caption.
        :param multi: accepted for parity with ``PngDrawer``; unused here.
        :param scale: ``"SIMPLE"`` draws the caption and axis markers.
        """
        print("Drawing...")
        # ``with`` closes the file even if drawing raises. UTF-8 is the SVG
        # default encoding when no XML declaration is written.
        with open(filename, "w", encoding="utf-8") as file:
            placed = [p for p in self.design.positions if 'x' in p]
            min_width = min(p['x'] for p in placed)
            max_width = max(p['x'] for p in placed)
            max_height = max([p['y'] for p in self.design.positions if 'y' in p])

            file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" '
                       'xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
                       'width="' + str(self.width) + '" height="' + str(self.height) + '">')

            self.draw_lines(file, min_width, max_width, max_height)
            self.draw_dots(file, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(file, input_filename)

            file.write("</svg>")

    def add_text(self, file, text, pos, anchor, style=''):
        """Write a <text> element; ``text`` is XML-escaped (e.g. '&' in file names)."""
        from xml.sax.saxutils import escape

        style = (style if style != '' else 'style="font-family: Arial; font-size: 12; fill: #000000;"')
        # The +12 baseline shift assumes font size 12; it should be derived from the style.
        file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0]) + '" y="'
                   + str(pos[1] + 12) + '" >' + escape(text) + '</text>')

    def add_dot(self, file, pos, style):
        """Write a <circle> centered at ``pos`` with the given attribute string."""
        file.write('<circle ' + style + ' cx="' + str(pos[0]) + '" cy="' + str(pos[1]) + '" />')

    def add_line(self, file, from_pos, to_pos, style):
        """Write a <line> between two points with the given attribute string."""
        file.write('<line ' + style + ' x1="' + str(from_pos[0]) + '" x2="' + str(to_pos[0]) +
                   '" y1="' + str(from_pos[1]) + '" y2="' + str(to_pos[1]) + '"  fill="none"/>')

    def add_dashed_line(self, file, from_pos, to_pos):
        """Write a thin black dashed line."""
        style = 'stroke="black" stroke-width="0.5" stroke-opacity="1" stroke-dasharray="5, 5"'
        self.add_line(file, from_pos, to_pos, style)

    def compute_line_style(self, parent, child):
        """Attribute string (stroke color, width, opacity) of the line to ``child``."""
        return self.compute_stroke_color('lines', child) + ' ' \
               + self.compute_stroke_width('lines', child) + ' ' \
               + self.compute_stroke_opacity(child)

    def compute_dot_style(self, node):
        """Attribute string (radius, fill opacity, fill color) of ``node``'s dot."""
        return self.compute_dot_size(node) + ' ' \
               + self.compute_fill_opacity(node) + ' ' \
               + self.compute_dot_fill(node)

    def compute_stroke_color(self, part, node):
        color = self.compute_property(part, 'color', node)
        return 'stroke="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'

    def compute_stroke_width(self, part, node):
        return 'stroke-width="' + str(self.compute_property(part, 'width', node)) + '"'

    def compute_stroke_opacity(self, node):
        return 'stroke-opacity="' + str(self.compute_property('lines', 'opacity', node)) + '"'

    def compute_fill_opacity(self, node):
        return 'fill-opacity="' + str(self.compute_property('dots', 'opacity', node)) + '"'

    def compute_dot_size(self, node):
        return 'r="' + str(self.compute_property('dots', 'size', node)) + '"'

    def compute_dot_fill(self, node):
        color = self.compute_property('dots', 'color', node)
        return 'fill="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'

class Designer:
    """Compute per-node measures and a 2-D layout for an ancestry tree.

    The tree object must expose these per-node attributes, indexed by node
    number:

    - ``children``: lists of child indices.
    - ``parents``: dicts mapping a parent index to its inheritance weight.
    - ``time`` and ``kind``.
    - ``props``: a dict of extra per-node lists.

    Node 0 is treated as the root.

    Results go to ``self.props`` and ``self.positions``. Numeric measures in
    ``self.props`` are normalized to [0, 1], with the raw bounds stored under
    ``<name>_min`` / ``<name>_max``. ``self.positions`` holds one
    ``{'id', 'x', 'y'}`` dict per placed node.
    """

    def __init__(self, tree, jitter=False, time="GENERATIONAL", balance="DENSITY"):
        """
        :param tree: loaded tree data (see class docstring).
        :param jitter: add Gaussian noise to horizontal placement.
        :param time: vertical-axis mode: "BIRTHS", "GENERATIONAL" or "REAL".
        :param balance: left/right placement strategy for single-parent
            children: "RANDOM", "MIN" or "DENSITY".
        :raises ValueError: if ``balance`` is not a known strategy.
        """
        self.props = {}

        self.tree = tree

        self.TIME = time
        self.JITTER = jitter

        if balance == "RANDOM":
            self.xmin_crowd = self.xmin_crowd_random
        elif balance == "MIN":
            self.xmin_crowd = self.xmin_crowd_min
        elif balance == "DENSITY":
            self.xmin_crowd = self.xmin_crowd_density
        else:
            raise ValueError("Error, the value of BALANCE does not match any expected value.")

    def calculate_measures(self):
        """Compute all per-node measures. The order matters: adepth needs maxdepth,
        and progress needs time."""
        print("Calculating measures...")
        self.compute_depth()
        self.compute_maxdepth()
        self.compute_adepth()
        self.compute_children()
        self.compute_kind()
        self.compute_time()
        self.compute_progress()
        self.compute_custom()

    def xmin_crowd_random(self, x1, x2, y):
        """Pick x1 or x2 uniformly at random."""
        return (x1 if random.randrange(2) == 0 else x2)

    def xmin_crowd_min(self, x1, x2, y):
        """Pick the candidate farther from its nearest placed node within y±3 (ties -> x2)."""
        x1_closest = 999999
        x2_closest = 999999
        miny = y - 3
        maxy = y + 3
        i = bisect.bisect_left(self.y_sorted, miny)
        while i < len(self.positions_sorted) and self.positions_sorted[i]['y'] <= maxy:
            pos = self.positions_sorted[i]
            x1_closest = min(x1_closest, abs(x1 - pos['x']))
            x2_closest = min(x2_closest, abs(x2 - pos['x']))
            i += 1
        return (x1 if x1_closest > x2_closest else x2)

    def xmin_crowd_density(self, x1, x2, y):
        """Pick x1 or x2 by comparing local vs. global distance sums of sampled nodes.

        Considers placed nodes within ``CONST_WINDOW_SIZE`` of ``y``. If fewer
        than 10 nodes fall in that window, all of them are used; otherwise 10
        are sampled at random. Experimental heuristic.
        """
        # TODO experimental - requires further work to make it less 'jumpy' and more predictable
        CONST_LOCAL_AREA_RADIUS = 5
        CONST_GLOBAL_AREA_RADIUS = 10
        CONST_WINDOW_SIZE = 20000  # TODO should depend on the maxY ?
        x1_dist_loc = 0
        x2_dist_loc = 0
        count_loc = 1
        x1_dist_glob = 0
        x2_dist_glob = 0
        count_glob = 1
        miny = y - CONST_WINDOW_SIZE
        maxy = y + CONST_WINDOW_SIZE
        i_left = bisect.bisect_left(self.y_sorted, miny)
        i_right = bisect.bisect_right(self.y_sorted, maxy)
        # TODO test: maxy=y should give the same results, right?

        def include_pos(pos):
            nonlocal x1_dist_loc, x2_dist_loc, x1_dist_glob, x2_dist_glob, count_loc, count_glob

            dysq = (pos['y'] - y) ** 2 + 1  # +1 so 1/dysq is at most 1
            dx1 = math.fabs(pos['x'] - x1)
            dx2 = math.fabs(pos['x'] - x2)

            d = math.fabs(pos['x'] - (x1 + x2) / 2)

            if d < CONST_LOCAL_AREA_RADIUS:
                x1_dist_loc += math.sqrt(dx1 / dysq + dx1 ** 2)
                x2_dist_loc += math.sqrt(dx2 / dysq + dx2 ** 2)
                count_loc += 1
            elif d > CONST_GLOBAL_AREA_RADIUS:
                x1_dist_glob += math.sqrt(dx1 / dysq + dx1 ** 2)
                x2_dist_glob += math.sqrt(dx2 / dysq + dx2 ** 2)
                count_glob += 1

        # use every node if fewer than 10 are in the window, otherwise sample 10
        if len(self.positions_sorted) > i_left:
            if i_right - i_left < 10:
                for j in range(i_left, i_right):
                    include_pos(self.positions_sorted[j])
            else:
                for _ in range(10):
                    include_pos(self.positions_sorted[random.randrange(i_left, i_right)])

        return (x1 if (x1_dist_loc - x2_dist_loc) / count_loc - (x1_dist_glob - x2_dist_glob) / count_glob > 0 else x2)

    def calculate_node_positions(self, ignore_last=0):
        """Place every node, filling ``self.positions`` and the y-sorted helper lists.

        :param ignore_last: nodes whose raw adepth is below this (approximately;
            compared on the normalized scale) are not placed.
        """
        print("Calculating positions...")

        def add_node(node):
            index = bisect.bisect_left(self.y_sorted, node['y'])
            self.y_sorted.insert(index, node['y'])
            self.positions_sorted.insert(index, node)
            self.positions[node['id']] = node

        node_count = len(self.tree.parents)
        self.positions_sorted = [{'x': 0, 'y': 0, 'id': 0}]
        self.y_sorted = [0]
        self.positions = [{} for _ in range(node_count)]
        self.positions[0] = {'x': 0, 'y': 0, 'id': 0}

        # adepth is normalized, so the threshold is normalized the same way.
        # A zero maximum means every raw adepth is 0 (single-node tree): keep
        # all nodes unless a positive threshold was asked for.
        adepth_max = self.props['adepth_max']
        if adepth_max:
            min_adepth = ignore_last / adepth_max
        else:
            min_adepth = 0 if ignore_last <= 0 else math.inf

        # Increasing maxdepth guarantees every parent is placed before its children.
        visiting_order = sorted(range(node_count),
                                key=lambda q: 0 if q == 0 else self.props["maxdepth"][q])

        start_time = timelib.time()

        for node_counter, child in enumerate(visiting_order, start=1):
            # debug info - elapsed time
            if node_counter % 100000 == 0:
                print("%d%%\t%d\t%g" % (node_counter * 100 / node_count, node_counter, timelib.time() - start_time))
                start_time = timelib.time()

            if self.props['adepth'][child] < min_adepth:
                continue

            parents = self.tree.parents[child]

            ypos = 0
            if self.TIME == "BIRTHS":
                ypos = child
            elif self.TIME == "GENERATIONAL":
                # one more than its deepest parent
                ypos = max([self.positions[par]['y'] for par in parents]) + 1 if parents else 0
            elif self.TIME == "REAL":
                ypos = self.tree.time[child]

            if len(parents) == 1:
                parent, similarity = next(iter(parents.items()))
                if self.JITTER:
                    dissimilarity = (1 - similarity) + random.gauss(0, 0.01) + 0.001
                else:
                    dissimilarity = (1 - similarity) + 0.001
                parent_x = self.positions[parent]['x']
                add_node({'id': child, 'y': ypos, 'x':
                          self.xmin_crowd(parent_x - dissimilarity, parent_x + dissimilarity, ypos)})
            else:
                # position weighted by the degree of inheritance from each parent
                total_inheritance = sum(parents.values())
                if total_inheritance:
                    xpos = sum([self.positions[k]['x'] * v / total_inheritance
                                for k, v in parents.items()])
                elif parents:
                    # all weights zero: fall back to the plain mean instead of dividing by zero
                    xpos = sum(self.positions[k]['x'] for k in parents) / len(parents)
                else:
                    xpos = 0
                if self.JITTER:
                    add_node({'id': child, 'y': ypos, 'x': xpos + random.gauss(0, 0.1)})
                else:
                    add_node({'id': child, 'y': ypos, 'x': xpos})

    def compute_custom(self):
        """Copy every extra per-node property from the tree and normalize it."""
        for prop in self.tree.props:
            self.props[prop] = [self.tree.props[prop][i] for i in range(len(self.tree.children))]
            self.normalize_prop(prop)

    def compute_time(self):
        """Copy per-node times from the tree and normalize them."""
        self.props["time"] = [self.tree.time[i] for i in range(len(self.tree.children))]
        self.normalize_prop('time')

    def compute_kind(self):
        """Copy per-node kinds from the tree as strings (not normalized)."""
        self.props["kind"] = [str(self.tree.kind[i]) for i in range(len(self.tree.children))]

    def compute_depth(self):
        """BFS distance from the root (unreachable nodes keep 999999999), normalized."""
        count = len(self.tree.children)
        depth = [999999999] * count
        visited = [False] * count
        depth[0] = 0
        visited[0] = True
        self.props["depth"] = depth

        queue = collections.deque([0])
        while queue:
            current_node = queue.popleft()
            for child in self.tree.children[current_node]:
                if not visited[child]:
                    visited[child] = True
                    queue.append(child)
                    depth[child] = depth[current_node] + 1

        self.normalize_prop('depth')

    def compute_maxdepth(self):
        """Longest path length from the root (unreachable nodes keep 999999999), normalized.

        BFS with relaxation: a node whose depth grows is re-queued unless it is
        already pending. Pending nodes are tracked in a set mirroring the queue.
        """
        count = len(self.tree.children)
        maxdepth = [999999999] * count
        visited = [False] * count
        maxdepth[0] = 0
        visited[0] = True
        self.props["maxdepth"] = maxdepth

        queue = collections.deque([0])
        pending = {0}
        while queue:
            # the head stays pending while its children are processed
            current_node = queue[0]

            for child in self.tree.children[current_node]:
                if not visited[child]:
                    visited[child] = True
                    queue.append(child)
                    pending.add(child)
                    maxdepth[child] = maxdepth[current_node] + 1
                elif maxdepth[child] < maxdepth[current_node] + 1:
                    maxdepth[child] = maxdepth[current_node] + 1
                    if child not in pending:
                        queue.append(child)
                        pending.add(child)

            queue.popleft()
            pending.discard(current_node)

        self.normalize_prop('maxdepth')

    def compute_adepth(self):
        """Height of each node (0 for leaves, 1 + max over children otherwise), normalized."""
        adepth = [0] * len(self.tree.children)
        self.props["adepth"] = adepth

        # Decreasing maxdepth guarantees every child is evaluated before its parents.
        visiting_order = sorted(range(len(self.tree.parents)),
                                key=lambda q: self.props["maxdepth"][q])[::-1]

        for node in visiting_order:
            children = self.tree.children[node]
            if children:
                adepth[node] = max([adepth[child] for child in children]) + 1
        self.normalize_prop('adepth')

    def compute_children(self):
        """Number of children of each node, normalized."""
        self.props["children"] = [len(self.tree.children[i]) for i in range(len(self.tree.children))]
        self.normalize_prop('children')

    def compute_progress(self):
        """Trend of the gaps between a node's children's times, normalized.

        For each node with more than four children:

        1. Sort the children's normalized times, scaled by 1e5.
        2. Turn them into successive gaps.
        3. Store the slope of a linear regression over the gaps; non-finite
           slopes become 0.

        Then the current minimum and the current maximum are zeroed five
        times, as an outlier cut. Finally every zero is replaced by the
        remaining minimum and the list is normalized.
        """
        progress = [0 for _ in range(len(self.tree.children))]
        self.props["progress"] = progress
        for node in range(len(self.props['children'])):
            times = sorted([self.props["time"][child] * 100000 for child in self.tree.children[node]])
            if len(times) > 4:
                gaps = [later - earlier for earlier, later in zip(times, times[1:])]
                slope, _intercept, _r_value, _p_value, _std_err = stats.linregress(range(len(gaps)), gaps)
                progress[node] = slope if not np.isnan(slope) and not np.isinf(slope) else 0

        for _ in range(5):
            progress[progress.index(min(progress))] = 0
            progress[progress.index(max(progress))] = 0

        floor = min(progress)
        for k, value in enumerate(progress):
            if value == 0:
                progress[k] = floor

        self.normalize_prop('progress')

    @staticmethod
    def _is_normalizable(value):
        """True for scalar values taking part in min/max normalization (not None/str/list/dict)."""
        return value is not None and not isinstance(value, (str, list, dict))

    def normalize_prop(self, prop):
        """Rescale the scalar entries of ``self.props[prop]`` to [0, 1] in place.

        Non-scalar entries (None, strings, lists, dicts) are left untouched.
        The raw bounds are stored as ``<prop>_min`` / ``<prop>_max``. When all
        scalars are equal they become 0. Does nothing if there are none.
        """
        values = self.props[prop]
        scalars = [v for v in values if self._is_normalizable(v)]
        if not scalars:
            return
        max_val = max(scalars)
        min_val = min(scalars)
        print("%s: [%g, %g]" % (prop, min_val, max_val))
        self.props[prop + '_max'] = max_val
        self.props[prop + '_min'] = min_val
        for i, value in enumerate(values):
            if self._is_normalizable(value):
                values[i] = 0 if max_val == min_val else (value - min_val) / (max_val - min_val)

class TreeData:
    """Genealogical tree of individuals read from evolution log files.

    After ``load()`` every individual has an internal node id. Id 0 is reserved
    for an artificial 'virtual parent'. The following per-node lists are indexed
    by that id:

    * ``parents``  -- dict {parent node id: value of 'Inherited' (1 if absent)}
    * ``children`` -- list of child node ids (derived from ``parents``)
    * ``time``     -- value of the 'Time' field, shifted by an offset when merging files
    * ``kind``     -- value of the 'Kind' field of [OFFSPRING] entries (0 if absent)
    * ``life_lenght`` -- allocated with zeros by ``load()``

    ``props`` maps every non-default JSON field name to a per-node list, and
    ``ids`` maps log IDs to internal node ids.
    """
    simple_data = None

    def __init__(self):
        # Per-instance containers. As class attributes these lists were shared
        # by every TreeData instance until load() rebound them.
        self.children = []
        self.parents = []
        self.time = []
        self.kind = []

    def load(self, filenames, max_nodes=0):
        """Build the tree from one or more log files.

        Recognized lines:

        * ``[BORN] {json}``
        * ``[OFFSPRING] {json}``
        * ``[DIED] {json}``

        Each may be preceded by ``Script.Message: ``; all other lines are
        ignored. Files are read in order. The times in each file after the first
        are shifted by the largest time seen so far. Individuals whose parent is
        missing or unknown are attached to the virtual parent (node 0).

        :param filenames: sequence of log file paths. It is iterated twice (once
            to size the arrays, once to load), so it must not be a one-shot
            iterator.
        :param max_nodes: stop after this many distinct individuals have been
            loaded; 0 means no limit.
        :raises LoadingError: if an [OFFSPRING] entry has no 'FromIDs' field.
        """
        print("Loading...")

        CLI_PREFIX = "Script.Message:"
        default_props = ["Time", "FromIDs", "ID", "Operation", "Inherited"]

        # Individuals for which the parent could not be found.
        merged_with_virtual_parent = []

        self.ids = {}

        def get_id(log_id, create_if_missing=True):
            """Map a log ID to an internal node id (None if unknown and not created)."""
            if log_id not in self.ids:
                if not create_if_missing:
                    return None
                self.ids[log_id] = len(self.ids)
            return self.ids[log_id]

        def try_to_load(payload):
            """Parse a JSON payload; on failure shrink the node arrays and return False."""
            try:
                return json.loads(payload)
            except ValueError:
                print("Json format error: the line cannot be read. Breaking the loading loop.")
                # ! assuming that only the last line is broken ! -> drop the
                # last slot of every per-node array
                self.parents.pop()
                self.children.pop()
                self.time.pop()
                self.kind.pop()
                self.life_lenght.pop()
                return False

        def load_creature_props(creature):
            """Store every non-default field of the entry in self.props."""
            creature_id = get_id(creature["ID"])
            for prop in creature:
                if prop not in default_props:
                    if prop not in self.props:
                        self.props[prop] = [0 for _ in range(nodes)]
                    self.props[prop][creature_id] = creature[prop]

        def load_born_props(creature):
            """Record the (offset-shifted) time of the entry and track the maximum time."""
            nonlocal max_time
            creature_id = get_id(creature["ID"])
            if "Time" in creature:
                self.time[creature_id] = creature["Time"] + time_offset
                max_time = max(self.time[creature_id], max_time)

        def load_offspring_props(creature):
            """Record parent contributions (and 'Kind') of an [OFFSPRING] entry."""
            creature_id = get_id(creature["ID"])
            if "FromIDs" not in creature:
                raise LoadingError("[OFFSPRING] misses the 'FromIDs' field!")

            # Each parent gets its contribution to the genotype of the child.
            # The virtual parent is registered before any line is read, so its
            # id is always lower than any child's.
            for i, from_id in enumerate(creature["FromIDs"]):
                if from_id in self.ids:
                    parent_id = get_id(from_id)
                else:
                    if from_id not in merged_with_virtual_parent:
                        merged_with_virtual_parent.append(from_id)
                    parent_id = get_id("virtual_parent")
                inherited = creature["Inherited"][i] if 'Inherited' in creature else 1
                self.parents[creature_id][parent_id] = inherited

            if "Kind" in creature:
                self.kind[creature_id] = creature["Kind"]

        # First pass: count entries to size the per-node arrays.
        # Assumes that either BORN or OFFSPRING, or both, are present for each
        # individual.
        nodes_born, nodes_offspring = 0, 0
        for filename in filenames:
            with open(filename) as file:
                for line in file:
                    line_arr = line.split(' ', 1)
                    if len(line_arr) == 2:
                        if line_arr[0] == CLI_PREFIX:
                            line_arr = line_arr[1].split(' ', 1)
                        if line_arr[0] == "[BORN]":
                            nodes_born += 1
                        elif line_arr[0] == "[OFFSPRING]":
                            nodes_offspring += 1
        nodes = max(nodes_born, nodes_offspring)
        # Cap at max_nodes (if set) and reserve one slot for the virtual parent.
        nodes = min(nodes, max_nodes if max_nodes != 0 else nodes) + 1

        self.parents = [{} for _ in range(nodes)]
        self.children = [[] for _ in range(nodes)]
        self.time = [0] * nodes
        self.kind = [0] * nodes
        self.life_lenght = [0] * nodes
        self.props = {}

        print("nodes: %d" % len(self.parents))

        get_id("virtual_parent")  # always internal node 0
        loaded_so_far = 0
        max_time = 0

        # Second pass: load the entries.
        for filename in filenames:
            with open(filename) as file:
                time_offset = max_time
                if max_time != 0:
                    print("NOTE: merging files, assuming cumulative time offset for '%s' to be %d" % (filename, time_offset))

                for line in file:
                    line_arr = line.split(' ', 1)
                    if len(line_arr) == 2:
                        if line_arr[0] == CLI_PREFIX:
                            line_arr = line_arr[1].split(' ', 1)
                        tag = line_arr[0]
                        if tag in ("[BORN]", "[OFFSPRING]", "[DIED]"):
                            creature = try_to_load(line_arr[1])
                            if not creature:
                                # Unreadable line: stop reading this file (the
                                # next file, if any, is still processed).
                                nodes -= 1
                                break

                            is_new = get_id(creature["ID"], False) is None
                            if tag == "[BORN]":
                                if is_new:
                                    loaded_so_far += 1
                                load_born_props(creature)
                                load_creature_props(creature)
                            elif tag == "[OFFSPRING]":
                                if is_new:
                                    loaded_so_far += 1
                                    # load time only if there was no [BORN] yet
                                    load_born_props(creature)
                                load_offspring_props(creature)
                            elif is_new:  # [DIED] for an unknown individual
                                print("NOTE: encountered [DIED] entry for individual '%s' before it was [BORN] or [OFFSPRING]" % creature["ID"])
                            else:  # [DIED] for a known individual
                                load_creature_props(creature)

                    if max_nodes != 0 and loaded_so_far >= max_nodes:
                        break
            # breaking the outer loop as well
            if max_nodes != 0 and loaded_so_far >= max_nodes:
                break

        print("NOTE: all individuals with parent not provided or missing were connected to a single 'virtual parent' node: " + str(merged_with_virtual_parent))

        virtual_parent_id = get_id("virtual_parent")
        for c_id in range(1, nodes):
            if not self.parents[c_id]:
                self.parents[c_id][virtual_parent_id] = 1

        for child_id, parent_ids in enumerate(self.parents):
            for parent_id in parent_ids:
                self.children[parent_id].append(child_id)

# Module-level dictionaries. Not referenced in this part of the file;
# presumably used elsewhere -- verify before removing.
depth, kind = dict(), dict()

def main():
    """Command-line entry point: parse arguments, load the tree, lay it out and draw it.

    Invalid values of --time/--balance/--scale are reported on stdout, naming
    each offending option, and the function returns None without drawing.
    """
    parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates a SVG file) based on parent-child relationship '
                                                 'information from a text file. Supports files generated by Framsticks experiments.')
    parser.add_argument('-i', '--in', nargs='+', dest='input', required=True, help='input file name with stuctured evolutionary data (or a list of input files)')
    parser.add_argument('-o', '--out', dest='output', required=True, help='output file name for the evolutionary tree (SVG/PNG/JPG/BMP)')
    parser.add_argument('-c', '--config', dest='config', default="", help='config file name ')

    parser.add_argument('-W', '--width', default=600, type=int, dest='width', help='width of the output image (600 by default)')
    parser.add_argument('-H', '--height', default=800, type=int, dest='height', help='height of the output image (800 by default)')
    parser.add_argument('-m', '--multi', default=1, type=int, dest='multi', help='multisampling factor (applicable only for raster images)')

    parser.add_argument('-t', '--time', default='GENERATIONAL', dest='time', help='values on vertical axis (BIRTHS/GENERATIONAL(d)/REAL); '
                                                                      'BIRTHS: time measured as the number of births since the beginning; '
                                                                      'GENERATIONAL: time measured as number of ancestors; '
                                                                      'REAL: real time of the simulation')
    parser.add_argument('-b', '--balance', default='DENSITY', dest='balance', help='method of placing nodes in the tree (RANDOM/MIN/DENSITY(d))')
    # the default is SIMPLE, so the help marks SIMPLE (not NONE) with (d)
    parser.add_argument('-s', '--scale', default='SIMPLE', dest='scale', help='type of timescale added to the tree (NONE/SIMPLE(d))')
    parser.add_argument('-j', '--jitter', dest="jitter", action='store_true', help='draw horizontal positions of children from the normal distribution')
    parser.add_argument('-p', '--skip', dest="skip", type=int, default=0, help='skip last P levels of the tree (0 by default)')
    parser.add_argument('-x', '--max-nodes', type=int, default=0, dest='max_nodes', help='maximum number of nodes drawn (starting from the first one)')
    parser.add_argument('--seed', type=int, default=-1, dest='seed', help='seed for the random number generator (-1 for random)')

    parser.set_defaults(draw_tree=True)
    parser.set_defaults(draw_skeleton=False)
    parser.set_defaults(draw_spine=False)

    args = parser.parse_args()

    TIME = args.time.upper()
    BALANCE = args.balance.upper()
    SCALE = args.scale.upper()
    JITTER = args.jitter

    # Validate each option separately so the user is told which one is wrong.
    allowed_values = (
        ('time', TIME, ['BIRTHS', 'GENERATIONAL', 'REAL']),
        ('balance', BALANCE, ['RANDOM', 'MIN', 'DENSITY']),
        ('scale', SCALE, ['NONE', 'SIMPLE']),
    )
    invalid = False
    for option, value, allowed in allowed_values:
        if value not in allowed:
            print("Incorrect value of the '%s' parameter: '%s' (allowed: %s)." % (option, value, "/".join(allowed)))
            invalid = True
    if invalid:
        return

    input_files = args.input

    seed = args.seed
    if seed == -1:
        seed = random.randint(0, 10000)
    random.seed(seed)
    print("randomseed:", seed)

    tree = TreeData()
    tree.load(input_files, max_nodes=args.max_nodes)

    designer = Designer(tree, jitter=JITTER, time=TIME, balance=BALANCE)
    designer.calculate_measures()
    designer.calculate_node_positions(ignore_last=args.skip)

    if args.output.endswith(".svg"):
        drawer = SvgDrawer(designer, args.config, w=args.width, h=args.height)
    else:
        drawer = PngDrawer(designer, args.config, w=args.width, h=args.height)
    drawer.draw_design(args.output, args.input, multi=args.multi, scale=SCALE)


# Run only when executed as a script, not on import.
if __name__ == "__main__":
    main()
