source: mds-and-trees/tree-genealogy.py @ 702

Last change on this file since 702 was 702, checked in by konrad, 7 years ago

Non-normalized step values for cmap can now extend outside the range of values from the data (yet min and max step values should still be the same for all color channels!)

File size: 37.6 KB
RevLine 
[562]1import json
[624]2import math
[562]3import random
4import argparse
[624]5import bisect
[702]6import copy
[624]7import time as timelib
8from PIL import Image, ImageDraw, ImageFont
[633]9from scipy import stats
[695]10from matplotlib import colors
[633]11import numpy as np
[562]12
class LoadingError(Exception):
    """Exception type for loading failures.

    No raise site is visible in this part of the file; the class adds nothing
    to ``Exception`` beyond a distinct name.
    """
[562]15
class Drawer:
    """Format-independent part of rendering a laid-out genealogy tree.

    The class walks ``design.positions`` and ``design.tree.children`` and turns
    normalised node properties into colours, sizes and opacities.  The actual
    output primitives (``add_dot``, ``add_line``, ``add_text``,
    ``add_dashed_line``) and the style builders (``compute_dot_style``,
    ``compute_line_style``) are not defined here; subclasses provide them.
    """

    def __init__(self, design, config_file, w=600, h=800, w_margin=10, h_margin=20):
        """Store the canvas geometry and build the style settings.

        design      -- object exposing ``positions``, ``props``, ``tree`` and ``TIME``
        config_file -- path of a JSON file merged over the default settings;
                       an empty value keeps the defaults
        w, h        -- canvas size
        w_margin, h_margin -- margins kept free on each side of the canvas
        """
        self.design = design
        self.width = w
        self.height = h
        self.w_margin = w_margin
        self.h_margin = h_margin
        self.w_no_margs = w - 2* w_margin
        self.h_no_margs = h - 2* h_margin

        self.color_converter = colors.ColorConverter()

        self.settings = {
            'colors_of_kinds': ['red', 'green', 'blue', 'magenta', 'yellow', 'cyan', 'orange', 'purple'],
            'dots': {
                'color': {
                    'meaning': 'Lifespan',
                    'normalize_cmap': False,
                    'cmap': {},
                    'start': 'red',
                    'end': 'green',
                    'bias': 1
                    },
                'size': {
                    'meaning': 'EnergyEaten',
                    'start': 1,
                    'end': 6,
                    'bias': 0.5
                    },
                'opacity': {
                    'meaning': 'EnergyEaten',
                    'start': 0.2,
                    'end': 1,
                    'bias': 1
                    }
            },
            'lines': {
                'color': {
                    'meaning': 'adepth',
                    'normalize_cmap': False,
                    'cmap': {},
                    'start': 'black',
                    'end': 'red',
                    'bias': 3
                    },
                'width': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 4,
                    'bias': 3
                    },
                'opacity': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 0.8,
                    'bias': 5
                    }
            }
        }

        def merge(source, destination):
            # recursive dict merge: values from `source` win, nested dicts are merged key by key
            for key, value in source.items():
                if isinstance(value, dict):
                    node = destination.setdefault(key, {})
                    merge(value, node)
                else:
                    destination[key] = value
            return destination

        # truthiness test so that None is treated like "" (no config file)
        if config_file:
            with open(config_file) as config:
                c = json.load(config)
            self.settings = merge(c, self.settings)

        self.compile_cmaps()

    @staticmethod
    def _normalize_cmap_steps(cmap):
        """Rescale, in place, the step (first) value of every row to [0, 1].

        ``cmap`` maps a channel name to a list of ``[step, value, value]`` rows.
        The first and last step are read *before* any row is rewritten;
        reading them inside the loop would pick up the already normalised
        first row and distort every later row whenever the first step is not 0.
        Returns the same ``cmap`` object.
        """
        for rows in cmap.values():
            first = rows[0][0]
            span = rows[-1][0] - first
            for row in rows:
                row[0] = (row[0] - first) / span
        return cmap

    def compile_cmaps(self):
        """Replace every non-empty 'cmap' setting with a compiled matplotlib colormap.

        When 'normalize_cmap' is set, the steps are first matched to the range
        of the data (``<meaning>_min`` .. ``<meaning>_max`` from
        ``design.props``): the map is extended if it is too short and its end
        rows are replaced with the colours sampled at the data bounds.
        """
        def normalize_and_compile_cmap(cmap):
            return colors.LinearSegmentedColormap('Custom', self._normalize_cmap_steps(cmap))

        for part in ['dots', 'lines']:
            color_settings = self.settings[part]['color']
            if not color_settings['cmap']:
                continue

            if color_settings['normalize_cmap']:
                cmap = color_settings['cmap']
                lo = self.design.props[color_settings['meaning'] + "_min"]
                hi = self.design.props[color_settings['meaning'] + "_max"]

                # make sure the map covers the whole data range by repeating the end colours
                for key in cmap:
                    if cmap[key][0][0] > lo:
                        cmap[key].insert(0, cmap[key][0][:])
                        cmap[key][0][0] = lo
                    if cmap[key][-1][0] < hi:
                        cmap[key].append(cmap[key][-1][:])
                        cmap[key][-1][0] = hi

                # compiled copy of the (extended) map, used only to sample colours at the data bounds
                og_cmap = normalize_and_compile_cmap(copy.deepcopy(cmap))

                col2key = {'red':0, 'green':1, 'blue':2}
                for key in cmap:
                    # for color from (r/g/b) #n's should be the same for all keys!
                    first = cmap[key][0][0]
                    span = cmap[key][-1][0] - first
                    n_min = (lo - first) / span
                    n_max = (hi - first) / span

                    min_col = og_cmap(n_min)
                    max_col = og_cmap(n_max)

                    # NOTE(review): interior rows whose step lies outside [lo, hi] are not removed here
                    cmap[key][0] = [lo, min_col[col2key[key]], min_col[col2key[key]]]
                    cmap[key][-1] = [hi, max_col[col2key[key]], max_col[col2key[key]]]
            print(color_settings['cmap'])
            color_settings['cmap'] = normalize_and_compile_cmap(color_settings['cmap'])

    def _to_canvas(self, pos, min_width, max_width, max_height):
        """Map a layout position (dict with 'x' and 'y') to canvas coordinates.

        A layout with zero horizontal extent is centred and one with zero
        vertical extent is put on the top margin, instead of dividing by zero.
        """
        x_span = max_width - min_width
        if x_span:
            x = self.w_margin+self.w_no_margs*(pos['x']-min_width)/x_span
        else:
            x = self.w_margin+self.w_no_margs/2
        if max_height:
            y = self.h_margin+self.h_no_margs*pos['y']/max_height
        else:
            y = self.h_margin
        return (x, y)

    def draw_dots(self, file, min_width, max_width, max_height):
        """Draw one dot for every node that has been given a position."""
        for i, node in enumerate(self.design.positions):
            if 'x' not in node:
                continue
            dot_style = self.compute_dot_style(node=i)
            self.add_dot(file, self._to_canvas(node, min_width, max_width, max_height), dot_style)

    def draw_lines(self, file, min_width, max_width, max_height):
        """Draw a line from every positioned parent to each of its positioned children."""
        for parent, par_pos in enumerate(self.design.positions):
            if 'x' not in par_pos:
                continue
            for child in self.design.tree.children[parent]:
                chi_pos = self.design.positions[child]
                if 'x' not in chi_pos:
                    continue
                line_style = self.compute_line_style(parent, child)
                self.add_line(file,
                              self._to_canvas(par_pos, min_width, max_width, max_height),
                              self._to_canvas(chi_pos, min_width, max_width, max_height),
                              line_style)

    def draw_scale(self, file, filename):
        """Draw the caption with the input file name and the top/bottom time marks."""
        # keep only the last path component, whichever separator the caller used
        base_name = filename.replace("\\", "/").split("/")[-1]
        self.add_text(file, "Generated from " + base_name, (5, 5), "start")

        start_text = ""
        end_text = ""
        time_mode = self.design.TIME
        if time_mode == "BIRTHS":
            start_text = "Birth #0"
            end_text = "Birth #" + str(len(self.design.positions)-1)
        elif time_mode == "REAL":
            start_text = "Time " + str(min(self.design.tree.time))
            end_text = "Time " + str(max(self.design.tree.time))
        elif time_mode == "GENERATIONAL":
            start_text = "Depth " + str(self.design.props['adepth_min'])
            end_text = "Depth " + str(self.design.props['adepth_max'])

        self.add_dashed_line(file, (self.width*0.7, self.h_margin), (self.width, self.h_margin))
        self.add_text(file, start_text, (self.width, self.h_margin), "end")
        self.add_dashed_line(file, (self.width*0.7, self.height-self.h_margin), (self.width, self.height-self.h_margin))
        self.add_text(file, end_text, (self.width, self.height-self.h_margin), "end")

    def compute_property(self, part, prop, node):
        """Return the visual value of `prop` ('color', 'size', ...) for `node`.

        part -- 'dots' or 'lines'; selects the settings group
        The node's value is read from ``design.props[<meaning>]``; 0 is used
        when the design has no such property.
        """
        spec = self.settings[part][prop]
        start = spec['start']
        end = spec['end']
        value = (self.design.props[spec['meaning']][node]
                 if spec['meaning'] in self.design.props else 0)
        bias = spec['bias']
        if prop == "color":
            if not spec['cmap']:
                return self.compute_color(start, end, value, bias)
            return self.compute_color_from_cmap(spec['cmap'], value, bias)
        return self.compute_value(start, end, value, bias)

    def compute_color_from_cmap(self, cmap, value, bias=1):
        """Look the biased `value` up in the callable `cmap`; returns (r, g, b) in percent."""
        value = 1 - (1-value)**bias
        rgba = cmap(value)
        return (100*rgba[0], 100*rgba[1], 100*rgba[2])

    def compute_color(self, start, end, value, bias=1):
        """Return an (r, g, b) tuple in percent for `value`.

        A string `value` is treated as a kind index into 'colors_of_kinds';
        any other value is used to blend between the `start` and `end` colours.
        """
        if isinstance(value, str):
            palette = self.settings['colors_of_kinds']
            # wrap around so that more kinds than palette entries do not raise IndexError
            r, g, b = self.color_converter.to_rgb(palette[int(value) % len(palette)])
        else:
            start_color = self.color_converter.to_rgb(start)
            end_color = self.color_converter.to_rgb(end)
            value = 1 - (1-value)**bias
            r = start_color[0]*(1-value)+end_color[0]*value
            g = start_color[1]*(1-value)+end_color[1]*value
            b = start_color[2]*(1-value)+end_color[2]*value
        return (100*r, 100*g, 100*b)

    def compute_value(self, start, end, value, bias=1):
        """Interpolate between `start` and `end` at `value` after applying `bias`."""
        value = 1 - (1-value)**bias
        return start*(1-value) + end*value
[564]212
class PngDrawer(Drawer):
    """Drawer that renders the tree into a raster image with Pillow.

    The picture is drawn ``multi`` times larger than requested and shrunk at
    the end.  Styles are dicts with 'color' (r, g, b in percent), 'opacity'
    (0..1) and 'r' (dots) or 'width' (lines).
    """

    # geometry attributes that are multiplied by `multi` while drawing
    _SCALED_ATTRS = ('width', 'height', 'w_margin', 'h_margin', 'h_no_margs', 'w_no_margs')

    def scale_up(self):
        """Multiply the canvas geometry by ``self.multi``, remembering the original values."""
        self._unscaled = {name: getattr(self, name) for name in self._SCALED_ATTRS}
        for name in self._SCALED_ATTRS:
            setattr(self, name, getattr(self, name) * self.multi)

    def scale_down(self):
        """Undo ``scale_up``.

        The values saved by ``scale_up`` are restored exactly; dividing would
        turn integer sizes into floats, which ``Image.new`` rejects on the next
        ``draw_design`` call.  Without a saved state the geometry is divided by
        ``self.multi``.
        """
        saved = getattr(self, '_unscaled', None)
        if saved is not None:
            for name, value in saved.items():
                setattr(self, name, value)
            self._unscaled = None
        else:
            for name in self._SCALED_ATTRS:
                setattr(self, name, getattr(self, name) / self.multi)

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Render the design and save it to `filename`.

        input_filename -- name shown in the caption when `scale` is "SIMPLE"
        multi          -- supersampling factor used while drawing
        scale          -- "SIMPLE" draws the caption and time marks
        """
        print("Drawing...")

        self.multi=multi
        self.scale_up()
        try:
            back = Image.new('RGBA', (int(self.width), int(self.height)), (255,255,255,0))

            min_width = min([x['x'] for x in self.design.positions if 'x' in x])
            max_width = max([x['x'] for x in self.design.positions if 'x' in x])
            max_height = max([x['y'] for x in self.design.positions if 'y' in x])

            self.draw_lines(back, min_width, max_width, max_height)
            self.draw_dots(back, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(back, input_filename)
        finally:
            # always restore the geometry so the drawer stays usable after an error
            self.scale_down()

        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
        resample = Image.LANCZOS if hasattr(Image, "LANCZOS") else Image.ANTIALIAS
        back.thumbnail((self.width, self.height), resample)

        back.save(filename)

    def add_dot(self, file, pos, style):
        """Paste a filled circle of radius ``style['r'] * multi`` centred at `pos`."""
        x, y = int(pos[0]), int(pos[1])
        r = style['r']*self.multi
        offset = (int(x - r), int(y - r))
        size = (2*int(r), 2*int(r))
        if size[0] <= 0:
            # radius below one pixel: nothing to draw, and the ellipse box would be invalid
            return

        c = style['color']

        img = Image.new('RGBA', size)
        ImageDraw.Draw(img).ellipse((1, 1, size[0]-1, size[1]-1),
                                    (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity'])))
        file.paste(img, offset, mask=img)

    def add_line(self, file, from_pos, to_pos, style):
        """Paste a straight line between two points onto `file`.

        The line is drawn on a temporary image covering its bounding box and
        then pasted with its own alpha channel as the mask.
        """
        fx, fy, tx, ty = int(from_pos[0]), int(from_pos[1]), int(to_pos[0]), int(to_pos[1])
        w = int(style['width'])*self.multi
        # at least one pixel of padding, so that horizontal/vertical lines thinner
        # than one pixel still get a non-empty canvas instead of being dropped
        pad = max(w, 1)

        offset = (min(fx, tx)-pad, min(fy, ty)-pad)
        size = (abs(fx-tx)+2*pad, abs(fy-ty)+2*pad)

        c = style['color']

        img = Image.new('RGBA', size)
        if (fx-tx)*(fy-ty) > 0:
            # both deltas have the same sign: top-left to bottom-right diagonal
            ends = (pad, pad, size[0]-pad, size[1]-pad)
        else:
            ends = (size[0]-pad, pad, pad, size[1]-pad)
        ImageDraw.Draw(img).line(ends,
                                 (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity'])), w)
        file.paste(img, offset, mask=img)

    def add_dashed_line(self, file, from_pos, to_pos):
        """Draw a black dashed line as 50 short segments separated by equal gaps."""
        style = {'color': (0,0,0), 'width': 1, 'opacity': 1}
        sublines = 50
        # TODO could be faster: compute delta and only add delta each time (but currently we do not use it often)
        normdiv = 2*sublines-1
        for i in range(sublines):
            from_pos_sub = (self.compute_value(from_pos[0], to_pos[0], 2*i/normdiv, 1),
                            self.compute_value(from_pos[1], to_pos[1], 2*i/normdiv, 1))
            to_pos_sub = (self.compute_value(from_pos[0], to_pos[0], (2*i+1)/normdiv, 1),
                          self.compute_value(from_pos[1], to_pos[1], (2*i+1)/normdiv, 1))
            self.add_line(file, from_pos_sub, to_pos_sub, style)

    def add_text(self, file, text, pos, anchor, style=''):
        """Draw black `text` at `pos`.

        anchor -- "start" puts the left edge at `pos`; any other value puts
                  the right edge there
        style  -- unused by this drawer
        """
        font = ImageFont.truetype("Vera.ttf", 16*self.multi)

        img = Image.new('RGBA', (self.width, self.height))
        draw = ImageDraw.Draw(img)
        if hasattr(draw, "textbbox"):
            # ImageDraw.textsize was removed in Pillow 10
            text_width = draw.textbbox((0, 0), text, font=font)[2]
        else:
            text_width = draw.textsize(text, font=font)[0]
        pos = pos if anchor == "start" else (pos[0]-text_width, pos[1])
        draw.text(pos, text, (0,0,0), font=font)
        file.paste(img, (0,0), mask=img)

    def compute_line_style(self, parent, child):
        """Build the style dict of the line leading to `child` (`parent` is not used)."""
        return {'color': self.compute_property('lines', 'color', child),
                'width': self.compute_property('lines', 'width', child),
                'opacity': self.compute_property('lines', 'opacity', child)}

    def compute_dot_style(self, node):
        """Build the style dict of the dot drawn for `node`."""
        return {'color': self.compute_property('dots', 'color', node),
                'r': self.compute_property('dots', 'size', node),
                'opacity': self.compute_property('dots', 'opacity', node)}
316
class SvgDrawer(Drawer):
    """Drawer that writes the tree as SVG markup.

    Here `file` is a writable text file, and styles are strings of
    ready-made SVG attributes.
    """

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Write the whole design to `filename` as an SVG document.

        input_filename -- name shown in the caption when `scale` is "SIMPLE"
        multi          -- accepted for interface parity; not used here
        scale          -- "SIMPLE" draws the caption and time marks
        """
        print("Drawing...")

        min_width = min([x['x'] for x in self.design.positions if 'x' in x])
        max_width = max([x['x'] for x in self.design.positions if 'x' in x])
        max_height = max([x['y'] for x in self.design.positions if 'y' in x])

        # the context manager closes the file even when drawing raises
        with open(filename, "w") as file:
            file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" '
                       'xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
                       'width="' + str(self.width) + '" height="' + str(self.height) + '">')

            self.draw_lines(file, min_width, max_width, max_height)
            self.draw_dots(file, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(file, input_filename)

            file.write("</svg>")

    def add_text(self, file, text, pos, anchor, style=''):
        """Write a <text> element; `anchor` becomes its text-anchor attribute."""
        from xml.sax.saxutils import escape

        style = (style if style != '' else 'style="font-family: Arial; font-size: 12; fill: #000000;"')
        # assuming font size 12, it should be taken from the style string!
        # the text is escaped because it may contain '&' or '<' (e.g. from a file name)
        file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0]) + '" y="' + str(pos[1]+12) + '" >' + escape(text) + '</text>')

    def add_dot(self, file, pos, style):
        """Write a <circle> element centred at `pos`; `style` carries the remaining attributes."""
        file.write('<circle ' + style + ' cx="' + str(pos[0]) + '" cy="' + str(pos[1]) + '" />')

    def add_line(self, file, from_pos, to_pos, style):
        """Write a <line> element from `from_pos` to `to_pos`."""
        file.write('<line ' + style + ' x1="' + str(from_pos[0]) + '" x2="' + str(to_pos[0]) +
                       '" y1="' + str(from_pos[1]) + '" y2="' + str(to_pos[1]) + '"  fill="none"/>')

    def add_dashed_line(self, file, from_pos, to_pos):
        """Write a thin black dashed line."""
        style = 'stroke="black" stroke-width="0.5" stroke-opacity="1" stroke-dasharray="5, 5"'
        self.add_line(file, from_pos, to_pos, style)

    def compute_line_style(self, parent, child):
        """Return the stroke attributes of the line leading to `child` (`parent` is not used)."""
        return ' '.join((self.compute_stroke_color('lines', child),
                         self.compute_stroke_width('lines', child),
                         self.compute_stroke_opacity(child)))

    def compute_dot_style(self, node):
        """Return the radius, fill-opacity and fill attributes of the dot for `node`."""
        return ' '.join((self.compute_dot_size(node),
                         self.compute_fill_opacity(node),
                         self.compute_dot_fill(node)))

    def compute_stroke_color(self, part, node):
        """Return the stroke colour attribute, with channels given as percentages."""
        color = self.compute_property(part, 'color', node)
        return 'stroke="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'

    def compute_stroke_width(self, part, node):
        """Return the stroke-width attribute."""
        return 'stroke-width="' + str(self.compute_property(part, 'width', node)) + '"'

    def compute_stroke_opacity(self, node):
        """Return the stroke-opacity attribute, taken from the 'lines' settings."""
        return 'stroke-opacity="' + str(self.compute_property('lines', 'opacity', node)) + '"'

    def compute_fill_opacity(self, node):
        """Return the fill-opacity attribute, taken from the 'dots' settings."""
        return 'fill-opacity="' + str(self.compute_property('dots', 'opacity', node)) + '"'

    def compute_dot_size(self, node):
        """Return the r (radius) attribute."""
        return 'r="' + str(self.compute_property('dots', 'size', node)) + '"'

    def compute_dot_fill(self, node):
        """Return the fill colour attribute, with channels given as percentages."""
        color = self.compute_property('dots', 'color', node)
        return 'fill="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'
384
class Designer:
    """Computes per-node measures and 2D positions for a genealogy tree.

    `tree` is read through these attributes: ``children`` (list of child index
    lists), ``parents`` (list of dicts mapping a parent index to its
    inheritance weight), ``time``, ``kind`` and ``props`` (custom properties).
    Node 0 is used as the root.  Results are stored in ``self.props`` (values
    normalised to [0, 1], plus ``<name>_min`` / ``<name>_max`` with the raw
    bounds) and in ``self.positions``.
    """

    def __init__(self, tree, jitter=False, time="GENERATIONAL", balance="DENSITY"):
        """Store the configuration.

        jitter  -- add random noise to horizontal positions
        time    -- meaning of the vertical axis: "BIRTHS", "GENERATIONAL" or "REAL"
        balance -- side-selection strategy: "RANDOM", "MIN" or "DENSITY"

        Raises ValueError for an unknown `balance`.
        """
        self.props = {}

        self.tree = tree

        self.TIME = time
        self.JITTER = jitter

        if balance == "RANDOM":
            self.xmin_crowd = self.xmin_crowd_random
        elif balance == "MIN":
            self.xmin_crowd = self.xmin_crowd_min
        elif balance == "DENSITY":
            self.xmin_crowd = self.xmin_crowd_density
        else:
            raise ValueError("Error, the value of BALANCE does not match any expected value.")

    def calculate_measures(self):
        """Compute all node measures; the order matters because later ones read earlier ones."""
        print("Calculating measures...")
        self.compute_depth()
        self.compute_maxdepth()
        self.compute_adepth()
        self.compute_children()
        self.compute_kind()
        self.compute_time()
        self.compute_progress()
        self.compute_custom()

    def xmin_crowd_random(self, x1, x2, y):
        """Pick `x1` or `x2` at random (`y` is ignored)."""
        return (x1 if random.randrange(2) == 0 else x2)

    def xmin_crowd_min(self, x1, x2, y):
        """Pick the candidate whose nearest placed node within 3 rows of `y` is farther away."""
        x1_closest = 999999
        x2_closest = 999999
        miny = y-3
        maxy = y+3
        i = bisect.bisect_left(self.y_sorted, miny)
        while i < len(self.positions_sorted) and self.positions_sorted[i]['y'] <= maxy:
            pos = self.positions_sorted[i]
            x1_closest = min(x1_closest, abs(x1-pos['x']))
            x2_closest = min(x2_closest, abs(x2-pos['x']))
            i += 1
        return (x1 if x1_closest > x2_closest else x2)

    def xmin_crowd_density(self, x1, x2, y):
        """Pick `x1` or `x2` by comparing distance sums to a sample of placed nodes.

        At most 10 already placed nodes from the vertical window around `y`
        are used (all of them if fewer than 10 are in range, otherwise a
        random sample drawn with replacement).
        """
        # TODO experimental - requires further work to make it less 'jumpy' and more predictable
        CONST_LOCAL_AREA_RADIUS = 5
        CONST_GLOBAL_AREA_RADIUS = 10
        CONST_WINDOW_SIZE = 20000 #TODO should depend on the maxY ?
        x1_dist_loc = 0
        x2_dist_loc = 0
        count_loc = 1
        x1_dist_glob = 0
        x2_dist_glob = 0
        count_glob = 1
        miny = y-CONST_WINDOW_SIZE
        maxy = y+CONST_WINDOW_SIZE
        i_left = bisect.bisect_left(self.y_sorted, miny)
        i_right = bisect.bisect_right(self.y_sorted, maxy)
        #TODO test: maxy=y should give the same results, right?

        def include_pos(pos):
            nonlocal x1_dist_loc, x2_dist_loc, x1_dist_glob, x2_dist_glob, count_loc, count_glob

            dysq = (pos['y']-y)**2 + 1 #+1 so 1/dysq is at most 1
            dx1 = math.fabs(pos['x']-x1)
            dx2 = math.fabs(pos['x']-x2)

            # horizontal distance of the node from the midpoint of the two candidates
            d = math.fabs(pos['x'] - (x1+x2)/2)

            if d < CONST_LOCAL_AREA_RADIUS:
                x1_dist_loc += math.sqrt(dx1/dysq + dx1**2)
                x2_dist_loc += math.sqrt(dx2/dysq + dx2**2)
                count_loc += 1
            elif d > CONST_GLOBAL_AREA_RADIUS:
                x1_dist_glob += math.sqrt(dx1/dysq + dx1**2)
                x2_dist_glob += math.sqrt(dx2/dysq + dx2**2)
                count_glob += 1

        # optimized to draw from all the nodes, if less than 10 nodes in the range
        if len(self.positions_sorted) > i_left:
            if i_right - i_left < 10:
                for j in range(i_left, i_right):
                    include_pos(self.positions_sorted[j])
            else:
                for j in range(10):
                    pos = self.positions_sorted[random.randrange(i_left, i_right)]
                    include_pos(pos)

        return (x1 if (x1_dist_loc-x2_dist_loc)/count_loc-(x1_dist_glob-x2_dist_glob)/count_glob > 0  else x2)

    def calculate_node_positions(self, ignore_last=0):
        """Assign an 'x'/'y' position to every node that is not filtered out.

        ignore_last -- nodes whose (unnormalised) adepth is below this value
                       are skipped and keep an empty entry in ``self.positions``

        Requires the measures 'maxdepth' and 'adepth' to be computed already.
        """
        print("Calculating positions...")

        def add_node(node):
            # keep positions_sorted / y_sorted ordered by y so they can be bisected
            index = bisect.bisect_left(self.y_sorted, node['y'])
            self.y_sorted.insert(index, node['y'])
            self.positions_sorted.insert(index, node)
            self.positions[node['id']] = node

        self.positions_sorted = [{'x':0, 'y':0, 'id':0}]
        self.y_sorted = [0]
        self.positions = [{} for x in range(len(self.tree.parents))]
        self.positions[0] = {'x':0, 'y':0, 'id':0}

        # ordering by maximum depth guarantees that no child is evaluated before its parent
        visiting_order = sorted(range(len(self.tree.parents)), key=lambda q:
                                0 if q == 0 else self.props["maxdepth"][q])

        # 'adepth' is stored normalised, so the threshold is normalised as well;
        # adepth_max is 0 for a tree without edges, in which case nothing is filtered
        adepth_max = self.props['adepth_max']
        adepth_threshold = ignore_last/adepth_max if adepth_max else 0

        start_time = timelib.time()

        for node_counter, child in enumerate(visiting_order, start=1):
            # debug info - elapsed time
            if node_counter % 100000 == 0:
                print("%d%%\t%d\t%g" % (node_counter*100/len(self.tree.parents), node_counter, timelib.time()-start_time))
                start_time = timelib.time()

            if self.props['adepth'][child] < adepth_threshold:
                continue

            parents = self.tree.parents[child]

            ypos = 0
            if self.TIME == "BIRTHS":
                ypos = child
            elif self.TIME == "GENERATIONAL":
                # one more than the lowest placed of its parents
                ypos = max([self.positions[par]['y'] for par in parents])+1 if parents else 0
            elif self.TIME == "REAL":
                ypos = self.tree.time[child]

            if len(parents) == 1:
                # single parent: place the child left or right of it, at a
                # distance growing with dissimilarity
                parent, similarity = next(iter(parents.items()))

                if self.JITTER:
                    dissimilarity = (1-similarity) + random.gauss(0, 0.01) + 0.001
                else:
                    dissimilarity = (1-similarity) + 0.001
                add_node({'id':child, 'y':ypos, 'x':
                         self.xmin_crowd(self.positions[parent]['x']-dissimilarity,
                          self.positions[parent]['x']+dissimilarity, ypos)})
            else:
                # position weighted by the degree of inheritance from each parent
                total_inheritance = sum(parents.values())
                xpos = sum([self.positions[k]['x']*v/total_inheritance
                           for k, v in parents.items()])
                if self.JITTER:
                    add_node({'id':child, 'y':ypos, 'x':xpos + random.gauss(0, 0.1)})
                else:
                    add_node({'id':child, 'y':ypos, 'x':xpos})

    def compute_custom(self):
        """Copy every custom property from the tree and normalise it."""
        node_count = len(self.tree.children)
        for prop in self.tree.props:
            self.props[prop] = [self.tree.props[prop][i] for i in range(node_count)]
            self.normalize_prop(prop)

    def compute_time(self):
        """Copy the 'time' values from the tree and normalise them."""
        self.props["time"] = [self.tree.time[i] for i in range(len(self.tree.children))]
        self.normalize_prop('time')

    def compute_kind(self):
        """Copy the 'kind' values from the tree as strings (they are not normalised)."""
        self.props["kind"] = [str(self.tree.kind[i]) for i in range(len(self.tree.children))]

    def compute_depth(self):
        """Compute 'depth': the number of edges on a shortest path from node 0.

        Nodes that cannot be reached from node 0 keep the value 999999999
        before normalisation.
        """
        from collections import deque

        node_count = len(self.tree.children)
        depth = [999999999] * node_count
        self.props["depth"] = depth
        visited = [False] * node_count

        visited[0] = True
        depth[0] = 0
        # deque gives O(1) removal from the front of the queue
        queue = deque([0])
        while queue:
            current_node = queue.popleft()
            for child in self.tree.children[current_node]:
                if not visited[child]:
                    visited[child] = True
                    queue.append(child)
                    depth[child] = depth[current_node]+1

        self.normalize_prop('depth')

    def compute_maxdepth(self):
        """Compute 'maxdepth': the number of edges on a longest path from node 0.

        A node whose value grows is queued again so the increase reaches its
        descendants.  Nodes that cannot be reached from node 0 keep the value
        999999999 before normalisation.
        """
        from collections import deque

        node_count = len(self.tree.children)
        maxdepth = [999999999] * node_count
        self.props["maxdepth"] = maxdepth
        visited = [False] * node_count

        visited[0] = True
        maxdepth[0] = 0
        queue = deque([0])
        queued = {0}  # mirrors the queue content for O(1) membership tests
        while queue:
            # the node stays in the queue while its children are processed
            current_node = queue[0]

            for child in self.tree.children[current_node]:
                if not visited[child]:
                    visited[child] = True
                    queue.append(child)
                    queued.add(child)
                    maxdepth[child] = maxdepth[current_node]+1
                elif maxdepth[child] < maxdepth[current_node]+1:
                    maxdepth[child] = maxdepth[current_node]+1
                    if child not in queued:
                        queue.append(child)
                        queued.add(child)

            queue.popleft()
            queued.discard(current_node)

        self.normalize_prop('maxdepth')

    def compute_adepth(self):
        """Compute 'adepth': the number of edges on the longest downward path from each node.

        Nodes without children get 0.  Requires 'maxdepth'.
        """
        adepth = [0] * len(self.tree.children)
        self.props["adepth"] = adepth

        # visiting in order of decreasing maxdepth evaluates children before their parents
        visiting_order = sorted(range(len(self.tree.parents)), key=lambda q: self.props["maxdepth"][q])[::-1]

        for node in visiting_order:
            children = self.tree.children[node]
            if len(children) != 0:
                # 0 by default
                adepth[node] = max([adepth[child] for child in children])+1
        self.normalize_prop('adepth')

    def compute_children(self):
        """Compute 'children': the number of children of every node."""
        self.props["children"] = [len(self.tree.children[i]) for i in range(len(self.tree.children))]
        self.normalize_prop('children')

    def compute_progress(self):
        """Compute 'progress' from the birth times of each node's children.

        For nodes with more than 4 children the value is the slope of a linear
        regression over the gaps between consecutive (sorted) child times.
        The 5 smallest and 5 largest values are then reset, and every value
        equal to 0 is replaced by the remaining minimum.  Requires 'time' and
        'children'.
        """
        progress = [0] * len(self.tree.children)
        self.props["progress"] = progress
        for i in range(len(self.props['children'])):
            times = sorted([self.props["time"][child]*100000 for child in self.tree.children[i]])
            if len(times) > 4:
                gaps = [times[k+1] - times[k] for k in range(len(times)-1)]
                slope, intercept, r_value, p_value, std_err = stats.linregress(range(len(gaps)), gaps)
                progress[i] = slope if not np.isnan(slope) and not np.isinf(slope) else 0

        # reset the five lowest and five highest values
        for _ in range(5):
            progress[progress.index(min(progress))] = 0
            progress[progress.index(max(progress))] = 0

        mini = min(progress)
        for k in range(len(progress)):
            if progress[k] == 0:
                progress[k] = mini

        self.normalize_prop('progress')

    def normalize_prop(self, prop):
        """Rescale the numeric values of ``self.props[prop]`` to [0, 1] in place.

        The raw bounds are stored as ``<prop>_min`` and ``<prop>_max``.  None,
        string and list values are left untouched and take no part in the
        bounds.  When all numeric values are equal they all become 0.  Nothing
        happens if the property has no numeric value.
        """
        values = self.props[prop]

        def is_numeric(v):
            return v is not None and not isinstance(v, (str, list))

        numeric = [v for v in values if is_numeric(v)]
        if len(numeric) > 0:
            max_val = max(numeric)
            min_val = min(numeric)
            print("%s: [%g, %g]" % (prop, min_val, max_val))
            self.props[prop +'_max'] = max_val
            self.props[prop +'_min'] = min_val
            for i, v in enumerate(values):
                if is_numeric(v):
                    values[i] = 0 if max_val == min_val else (v - min_val) / (max_val - min_val)
[594]681
class TreeData:
    """Parent-child data read from a text log of "[OFFSPRING]" and "[DIED]" records.

    load() assigns consecutive integer indices to creature IDs (kept in
    self.ids) and fills index-aligned lists:
      parents[i]  - dict {parent index: inherited share} of node i
      children[i] - list of indices of the children of node i
      time[i]     - value of the "Time" field of node i (0 if absent)
      kind[i]     - value of the "Kind" field of node i (0 if absent)
      props[name] - list with the value of every other field, 0 where absent
    """
    simple_data = None

    # Class-level placeholders; load() rebinds each of them on the instance.
    children = []
    parents = []
    time = []
    kind = []

    def __init__(self): #, simple_data=False):
        #self.simple_data = simple_data
        pass

    @staticmethod
    def _split_record(line, cli_prefix):
        """Split a log line into (tag, payload), dropping an optional leading cli_prefix.

        Returns (None, None) for a line without any space. The payload is an
        empty string when nothing follows the tag.
        """
        parts = line.split(' ', 1)
        if len(parts) != 2:
            return None, None
        if parts[0] == cli_prefix:
            parts = parts[1].split(' ', 1)
        return parts[0], (parts[1] if len(parts) == 2 else '')

    def _store_extra_props(self, creature, creature_id, skipped, nodes):
        """Copy every field of creature that is not listed in skipped into self.props."""
        for prop in creature:
            if prop not in skipped:
                if prop not in self.props:
                    self.props[prop] = [0 for _ in range(nodes)]
                self.props[prop][creature_id] = creature[prop]

    def load(self, filename, max_nodes=0):
        """Read the genealogy from filename and build the parent/child structures.

        Each relevant line is "<tag> <json>", optionally preceded by
        "Script.Message: ". "[OFFSPRING]" records create nodes; "[DIED]"
        records only add properties to nodes that already exist. A parent ID
        that has not been seen before is mapped to a single shared
        "virtual_parent" node.

        :param filename: path of the log file
        :param max_nodes: maximal number of [OFFSPRING] records to load
                          (0 means no limit)
        :raises LoadingError: if an [OFFSPRING] record has no 'FromIDs' field
        """
        print("Loading...")

        CLI_PREFIX = "Script.Message:"
        default_props = ["Time", "FromIDs", "ID", "Operation", "Inherited"]

        self.ids = {}
        def get_id(key, createOnError = True):
            if key not in self.ids:
                if not createOnError:
                    return None
                self.ids[key] = len(self.ids)
            return self.ids[key]

        # 'with' guarantees that the file is closed, also when loading fails
        with open(filename) as file:

            # counting the number of expected nodes
            nodes = sum(1 for line in file
                        if self._split_record(line, CLI_PREFIX)[0] == "[OFFSPRING]")

            # one extra slot is reserved (it is needed when the virtual parent is created)
            nodes = min(nodes, max_nodes if max_nodes != 0 else nodes)+1
            self.parents = [{} for x in range(nodes)]
            self.children = [[] for x in range(nodes)]
            self.time = [0] * nodes
            self.kind = [0] * nodes
            self.life_lenght = [0] * nodes
            self.props = {}

            print("nodes: %d" % len(self.parents))

            file.seek(0)
            loaded_so_far = 0
            for line in file:
                tag, payload = self._split_record(line, CLI_PREFIX)

                if tag == "[OFFSPRING]":
                    try:
                        creature = json.loads(payload)
                    except ValueError:
                        print("Json format error - the line cannot be read. Breaking the loading loop.")
                        # fixing arrays by removing the last element
                        # ! assuming that only the last line is broken !
                        self.parents.pop()
                        self.children.pop()
                        self.time.pop()
                        self.kind.pop()
                        self.life_lenght.pop()
                        nodes -= 1
                        break

                    if "FromIDs" not in creature:
                        raise LoadingError("[OFFSPRING] misses the 'FromIDs' field!")

                    # make sure that ID's of parents are lower than that of their children
                    if any(parent not in self.ids for parent in creature["FromIDs"]):
                        get_id("virtual_parent")

                    creature_id = get_id(creature["ID"])

                    # we assign to each parent its contribution to the genotype of the child
                    for i, parent in enumerate(creature["FromIDs"]):
                        parent_id = get_id(parent if parent in self.ids else "virtual_parent")
                        inherited = (creature["Inherited"][i] if 'Inherited' in creature else 1)
                        self.parents[creature_id][parent_id] = inherited

                    if "Time" in creature:
                        self.time[creature_id] = creature["Time"]

                    if "Kind" in creature:
                        self.kind[creature_id] = creature["Kind"]

                    self._store_extra_props(creature, creature_id, default_props, nodes)

                    loaded_so_far += 1

                elif tag == "[DIED]":
                    try:
                        creature = json.loads(payload)
                    except ValueError:
                        # a damaged [DIED] record only loses optional properties
                        print("Json format error - skipping a [DIED] line that cannot be read.")
                        creature = None

                    if creature is not None:
                        creature_id = get_id(creature["ID"], False)
                        if creature_id is not None:
                            self._store_extra_props(creature, creature_id, default_props, nodes)

                if loaded_so_far >= max_nodes and max_nodes != 0:
                    break

        # derive the children lists from the parent dictionaries
        for child, parent_map in enumerate(self.parents):
            for parent in parent_map:
                self.children[parent].append(child)
812
# Module-level dictionaries; nothing in the visible part of this file reads or writes them.
# NOTE(review): possibly leftovers of an earlier version - confirm before removing.
depth = {}
kind = {}
[562]815
def main():
    """Command-line entry point: load the data, lay out the tree and draw it.

    Parses the command-line arguments, validates the --time, --balance and
    --scale values (case-insensitively), seeds the random number generator,
    loads the input file into a TreeData, lets a Designer compute measures
    and node positions, and finally draws the result with an SvgDrawer (when
    the output name ends with ".svg") or a PngDrawer (otherwise).

    If any of the validated parameters has an unsupported value, a message
    naming each offending parameter is printed and the function returns
    without drawing anything.
    """
    parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates a SVG file) based on parent-child relationship '
                                                 'information from a text file. Supports files generated by Framsticks experiments.')
    parser.add_argument('-i', '--in', dest='input', required=True, help='input file name with structured evolutionary data')
    parser.add_argument('-o', '--out', dest='output', required=True, help='output file name for the evolutionary tree (SVG/PNG/JPG/BMP)')
    parser.add_argument('-c', '--config', dest='config', default="", help='config file name ')

    parser.add_argument('-W', '--width', default=600, type=int, dest='width', help='width of the output image (600 by default)')
    parser.add_argument('-H', '--height', default=800, type=int, dest='height', help='height of the output image (800 by default)')
    parser.add_argument('-m', '--multi', default=1, type=int, dest='multi', help='multisampling factor (applicable only for raster images)')

    parser.add_argument('-t', '--time', default='GENERATIONAL', dest='time', help='values on vertical axis (BIRTHS/GENERATIONAL(d)/REAL); '
                                                                      'BIRTHS: time measured as the number of births since the beginning; '
                                                                      'GENERATIONAL: time measured as number of ancestors; '
                                                                      'REAL: real time of the simulation')
    parser.add_argument('-b', '--balance', default='DENSITY', dest='balance', help='method of placing nodes in the tree (RANDOM/MIN/DENSITY(d))')
    # the "(d)" marker has to follow the actual default, which is SIMPLE
    parser.add_argument('-s', '--scale', default='SIMPLE', dest='scale', help='type of timescale added to the tree (NONE/SIMPLE(d))')
    parser.add_argument('-j', '--jitter', dest="jitter", action='store_true', help='draw horizontal positions of children from the normal distribution')
    parser.add_argument('-p', '--skip', dest="skip", type=int, default=0, help='skip last P levels of the tree (0 by default)')
    parser.add_argument('-x', '--max-nodes', type=int, default=0, dest='max_nodes', help='maximum number of nodes drawn (starting from the first one)')
    parser.add_argument('--seed', type=int, dest='seed', help='seed for the random number generator (-1 for random)')

    parser.set_defaults(draw_tree=True)
    parser.set_defaults(draw_skeleton=False)
    parser.set_defaults(draw_spine=False)

    parser.set_defaults(seed=-1)

    args = parser.parse_args()

    TIME = args.time.upper()
    BALANCE = args.balance.upper()
    SCALE = args.scale.upper()
    JITTER = args.jitter

    # report every invalid parameter by name instead of leaving the user guessing
    checked = (
        ('time', TIME, ['BIRTHS', 'GENERATIONAL', 'REAL']),
        ('balance', BALANCE, ['RANDOM', 'MIN', 'DENSITY']),
        ('scale', SCALE, ['NONE', 'SIMPLE']),
    )
    invalid = [(name, value, allowed) for name, value, allowed in checked if value not in allowed]
    if invalid:
        for name, value, allowed in invalid:
            print("Incorrect value of the '%s' parameter: '%s' (allowed values: %s)." % (name, value, "/".join(allowed)))
        return

    seed = args.seed
    if seed == -1:
        seed = random.randint(0, 10000)
    random.seed(seed)
    print("randomseed:", seed)

    tree = TreeData()
    tree.load(args.input, max_nodes=args.max_nodes)

    designer = Designer(tree, jitter=JITTER, time=TIME, balance=BALANCE)
    designer.calculate_measures()
    designer.calculate_node_positions(ignore_last=args.skip)

    if args.output.endswith(".svg"):
        drawer = SvgDrawer(designer, args.config, w=args.width, h=args.height)
    else:
        drawer = PngDrawer(designer, args.config, w=args.width, h=args.height)
    drawer.draw_design(args.output, args.input, multi=args.multi, scale=SCALE)
[562]877
878
# Run only when executed as a script, so that importing the module has no side effects.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.