source: mds-and-trees/tree-genealogy.py @ 707

Last change on this file since 707 was 707, checked in by konrad, 7 years ago

Added a warning message about merging nodes with a virtual parent node.

File size: 38.0 KB
Line 
1import json
2import math
3import random
4import argparse
5import bisect
6import copy
7import time as timelib
8from PIL import Image, ImageDraw, ImageFont
9from scipy import stats
10from matplotlib import colors
11import numpy as np
12
class LoadingError(Exception):
    """Exception type for loading failures; adds nothing to ``Exception``."""
15
16class Drawer:
17
18    def __init__(self, design, config_file, w=600, h=800, w_margin=10, h_margin=20):
19        self.design = design
20        self.width = w
21        self.height = h
22        self.w_margin = w_margin
23        self.h_margin = h_margin
24        self.w_no_margs = w - 2* w_margin
25        self.h_no_margs = h - 2* h_margin
26
27        self.color_converter = colors.ColorConverter()
28
29        self.settings = {
30            'colors_of_kinds': ['red', 'green', 'blue', 'magenta', 'yellow', 'cyan', 'orange', 'purple'],
31            'dots': {
32                'color': {
33                    'meaning': 'Lifespan',
34                    'normalize_cmap': False,
35                    'cmap': {},
36                    'start': 'red',
37                    'end': 'green',
38                    'bias': 1
39                    },
40                'size': {
41                    'meaning': 'EnergyEaten',
42                    'start': 1,
43                    'end': 6,
44                    'bias': 0.5
45                    },
46                'opacity': {
47                    'meaning': 'EnergyEaten',
48                    'start': 0.2,
49                    'end': 1,
50                    'bias': 1
51                    }
52            },
53            'lines': {
54                'color': {
55                    'meaning': 'adepth',
56                    'normalize_cmap': False,
57                    'cmap': {},
58                    'start': 'black',
59                    'end': 'red',
60                    'bias': 3
61                    },
62                'width': {
63                    'meaning': 'adepth',
64                    'start': 0.1,
65                    'end': 4,
66                    'bias': 3
67                    },
68                'opacity': {
69                    'meaning': 'adepth',
70                    'start': 0.1,
71                    'end': 0.8,
72                    'bias': 5
73                    }
74            }
75        }
76
77        def merge(source, destination):
78            for key, value in source.items():
79                if isinstance(value, dict):
80                    node = destination.setdefault(key, {})
81                    merge(value, node)
82                else:
83                    destination[key] = value
84            return destination
85
86        if config_file != "":
87            with open(config_file) as config:
88                c = json.load(config)
89            self.settings = merge(c, self.settings)
90            #print(json.dumps(self.settings, indent=4, sort_keys=True))
91
92        self.compile_cmaps()
93
94    def compile_cmaps(self):
95        def normalize_and_compile_cmap(cmap):
96            for key in cmap:
97                for arr in cmap[key]:
98                    arr[0] = (arr[0] - cmap[key][0][0]) / (cmap[key][-1][0] - cmap[key][0][0])
99            return colors.LinearSegmentedColormap('Custom', cmap)
100
101        for part in ['dots', 'lines']:
102            if self.settings[part]['color']['cmap']:
103                if self.settings[part]['color']['normalize_cmap']:
104                    cmap = self.settings[part]['color']['cmap']
105                    min = self.design.props[self.settings[part]['color']['meaning'] + "_min"]
106                    max = self.design.props[self.settings[part]['color']['meaning'] + "_max"]
107
108                    for key in cmap:
109                        if cmap[key][0][0] > min:
110                            cmap[key].insert(0, cmap[key][0][:])
111                            cmap[key][0][0] = min
112                        if cmap[key][-1][0] < max:
113                            cmap[key].append(cmap[key][-1][:])
114                            cmap[key][-1][0] = max
115
116                    og_cmap = normalize_and_compile_cmap(copy.deepcopy(cmap))
117
118                    col2key = {'red':0, 'green':1, 'blue':2}
119                    for key in cmap:
120                        # for color from (r/g/b) #n's should be the same for all keys!
121                        n_min = (min - cmap[key][0][0]) / (cmap[key][-1][0] - cmap[key][0][0])
122                        n_max = (max - cmap[key][0][0]) / (cmap[key][-1][0] - cmap[key][0][0])
123
124                        min_col = og_cmap(n_min)
125                        max_col = og_cmap(n_max)
126
127                        cmap[key][0] = [min, min_col[col2key[key]], min_col[col2key[key]]]
128                        cmap[key][-1] = [max, max_col[col2key[key]], max_col[col2key[key]]]
129                print(self.settings[part]['color']['cmap'])
130                self.settings[part]['color']['cmap'] = normalize_and_compile_cmap(self.settings[part]['color']['cmap'])
131
132    def draw_dots(self, file, min_width, max_width, max_height):
133        for i in range(len(self.design.positions)):
134            node = self.design.positions[i]
135            if 'x' not in node:
136                continue
137            dot_style = self.compute_dot_style(node=i)
138            self.add_dot(file, (self.w_margin+self.w_no_margs*(node['x']-min_width)/(max_width-min_width),
139                               self.h_margin+self.h_no_margs*node['y']/max_height), dot_style)
140
141    def draw_lines(self, file, min_width, max_width, max_height):
142        for parent in range(len(self.design.positions)):
143            par_pos = self.design.positions[parent]
144            if not 'x' in par_pos:
145                continue
146            for child in self.design.tree.children[parent]:
147                chi_pos = self.design.positions[child]
148                if 'x' not in chi_pos:
149                    continue
150                line_style = self.compute_line_style(parent, child)
151                self.add_line(file, (self.w_margin+self.w_no_margs*(par_pos['x']-min_width)/(max_width-min_width),
152                                  self.h_margin+self.h_no_margs*par_pos['y']/max_height),
153                                  (self.w_margin+self.w_no_margs*(chi_pos['x']-min_width)/(max_width-min_width),
154                                  self.h_margin+self.h_no_margs*chi_pos['y']/max_height), line_style)
155
156    def draw_scale(self, file, filename):
157        self.add_text(file, "Generated from " + filename.split("\\")[-1], (5, 5), "start")
158
159        start_text = ""
160        end_text = ""
161        if self.design.TIME == "BIRTHS":
162           start_text = "Birth #0"
163           end_text = "Birth #" + str(len(self.design.positions)-1)
164        if self.design.TIME == "REAL":
165           start_text = "Time " + str(min(self.design.tree.time))
166           end_text = "Time " + str(max(self.design.tree.time))
167        if self.design.TIME == "GENERATIONAL":
168           start_text = "Depth " + str(self.design.props['adepth_min'])
169           end_text = "Depth " + str(self.design.props['adepth_max'])
170
171        self.add_dashed_line(file, (self.width*0.7, self.h_margin), (self.width, self.h_margin))
172        self.add_text(file, start_text, (self.width, self.h_margin), "end")
173        self.add_dashed_line(file, (self.width*0.7, self.height-self.h_margin), (self.width, self.height-self.h_margin))
174        self.add_text(file, end_text, (self.width, self.height-self.h_margin), "end")
175
176    def compute_property(self, part, prop, node):
177        start = self.settings[part][prop]['start']
178        end = self.settings[part][prop]['end']
179        value = (self.design.props[self.settings[part][prop]['meaning']][node]
180                 if self.settings[part][prop]['meaning'] in self.design.props else 0 )
181        bias = self.settings[part][prop]['bias']
182        if prop == "color":
183            if not self.settings[part][prop]['cmap']:
184                return self.compute_color(start, end, value, bias)
185            else:
186                return self.compute_color_from_cmap(self.settings[part][prop]['cmap'], value, bias)
187        else:
188            return self.compute_value(start, end, value, bias)
189
190    def compute_color_from_cmap(self, cmap, value, bias=1):
191        value = 1 - (1-value)**bias
192        rgba = cmap(value)
193        return (100*rgba[0], 100*rgba[1], 100*rgba[2])
194
195
196    def compute_color(self, start, end, value, bias=1):
197        if isinstance(value, str):
198            value = int(value)
199            r, g, b = self.color_converter.to_rgb(self.settings['colors_of_kinds'][value])
200        else:
201            start_color = self.color_converter.to_rgb(start)
202            end_color = self.color_converter.to_rgb(end)
203            value = 1 - (1-value)**bias
204            r = start_color[0]*(1-value)+end_color[0]*value
205            g = start_color[1]*(1-value)+end_color[1]*value
206            b = start_color[2]*(1-value)+end_color[2]*value
207        return (100*r, 100*g, 100*b)
208
209    def compute_value(self, start, end, value, bias=1):
210        value = 1 - (1-value)**bias
211        return start*(1-value) + end*value
212
213class PngDrawer(Drawer):
214
215    def scale_up(self):
216        self.width *= self.multi
217        self.height *= self.multi
218        self.w_margin *= self.multi
219        self.h_margin *= self.multi
220        self.h_no_margs *= self.multi
221        self.w_no_margs *= self.multi
222
223    def scale_down(self):
224        self.width /= self.multi
225        self.height /= self.multi
226        self.w_margin /= self.multi
227        self.h_margin /= self.multi
228        self.h_no_margs /= self.multi
229        self.w_no_margs /= self.multi
230
231    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
232        print("Drawing...")
233
234        self.multi=multi
235        self.scale_up()
236
237        back = Image.new('RGBA', (self.width, self.height), (255,255,255,0))
238
239        min_width = min([x['x'] for x in self.design.positions if 'x' in x])
240        max_width = max([x['x'] for x in self.design.positions if 'x' in x])
241        max_height = max([x['y'] for x in self.design.positions if 'y' in x])
242
243        self.draw_lines(back, min_width, max_width, max_height)
244        self.draw_dots(back, min_width, max_width, max_height)
245
246        if scale == "SIMPLE":
247            self.draw_scale(back, input_filename)
248
249        #back.show()
250        self.scale_down()
251
252        back.thumbnail((self.width, self.height), Image.ANTIALIAS)
253
254        back.save(filename)
255
256    def add_dot(self, file, pos, style):
257        x, y = int(pos[0]), int(pos[1])
258        r = style['r']*self.multi
259        offset = (int(x - r), int(y - r))
260        size = (2*int(r), 2*int(r))
261
262        c = style['color']
263
264        img = Image.new('RGBA', size)
265        ImageDraw.Draw(img).ellipse((1, 1, size[0]-1, size[1]-1),
266                                    (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity'])))
267        file.paste(img, offset, mask=img)
268
269    def add_line(self, file, from_pos, to_pos, style):
270        fx, fy, tx, ty = int(from_pos[0]), int(from_pos[1]), int(to_pos[0]), int(to_pos[1])
271        w = int(style['width'])*self.multi
272
273        offset = (min(fx-w, tx-w), min(fy-w, ty-w))
274        size = (abs(fx-tx)+2*w, abs(fy-ty)+2*w)
275        if size[0] == 0 or size[1] == 0:
276            return
277
278        c = style['color']
279
280        img = Image.new('RGBA', size)
281        ImageDraw.Draw(img).line((w, w, size[0]-w, size[1]-w) if (fx-tx)*(fy-ty)>0 else (size[0]-w, w, w, size[1]-w),
282                                  (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity'])), w)
283        file.paste(img, offset, mask=img)
284
285    def add_dashed_line(self, file, from_pos, to_pos):
286        style = {'color': (0,0,0), 'width': 1, 'opacity': 1}
287        sublines = 50
288        # TODO could be faster: compute delta and only add delta each time (but currently we do not use it often)
289        normdiv = 2*sublines-1
290        for i in range(sublines):
291            from_pos_sub = (self.compute_value(from_pos[0], to_pos[0], 2*i/normdiv, 1),
292                            self.compute_value(from_pos[1], to_pos[1], 2*i/normdiv, 1))
293            to_pos_sub = (self.compute_value(from_pos[0], to_pos[0], (2*i+1)/normdiv, 1),
294                          self.compute_value(from_pos[1], to_pos[1], (2*i+1)/normdiv, 1))
295            self.add_line(file, from_pos_sub, to_pos_sub, style)
296
297    def add_text(self, file, text, pos, anchor, style=''):
298        font = ImageFont.truetype("Vera.ttf", 16*self.multi)
299
300        img = Image.new('RGBA', (self.width, self.height))
301        draw = ImageDraw.Draw(img)
302        txtsize = draw.textsize(text, font=font)
303        pos = pos if anchor == "start" else (pos[0]-txtsize[0], pos[1])
304        draw.text(pos, text, (0,0,0), font=font)
305        file.paste(img, (0,0), mask=img)
306
307    def compute_line_style(self, parent, child):
308        return {'color': self.compute_property('lines', 'color', child),
309                'width': self.compute_property('lines', 'width', child),
310                'opacity': self.compute_property('lines', 'opacity', child)}
311
312    def compute_dot_style(self, node):
313        return {'color': self.compute_property('dots', 'color', node),
314                'r': self.compute_property('dots', 'size', node),
315                'opacity': self.compute_property('dots', 'opacity', node)}
316
317class SvgDrawer(Drawer):
318    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
319        print("Drawing...")
320        file = open(filename, "w")
321
322        min_width = min([x['x'] for x in self.design.positions if 'x' in x])
323        max_width = max([x['x'] for x in self.design.positions if 'x' in x])
324        max_height = max([x['y'] for x in self.design.positions if 'y' in x])
325
326        file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" '
327                   'xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
328                   'width="' + str(self.width) + '" height="' + str(self.height) + '">')
329
330        self.draw_lines(file, min_width, max_width, max_height)
331        self.draw_dots(file, min_width, max_width, max_height)
332
333        if scale == "SIMPLE":
334            self.draw_scale(file, input_filename)
335
336        file.write("</svg>")
337        file.close()
338
339    def add_text(self, file, text, pos, anchor, style=''):
340        style = (style if style != '' else 'style="font-family: Arial; font-size: 12; fill: #000000;"')
341        # assuming font size 12, it should be taken from the style string!
342        file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0]) + '" y="' + str(pos[1]+12) + '" >' + text + '</text>')
343
344    def add_dot(self, file, pos, style):
345        file.write('<circle ' + style + ' cx="' + str(pos[0]) + '" cy="' + str(pos[1]) + '" />')
346
347    def add_line(self, file, from_pos, to_pos, style):
348        file.write('<line ' + style + ' x1="' + str(from_pos[0]) + '" x2="' + str(to_pos[0]) +
349                       '" y1="' + str(from_pos[1]) + '" y2="' + str(to_pos[1]) + '"  fill="none"/>')
350
351    def add_dashed_line(self, file, from_pos, to_pos):
352        style = 'stroke="black" stroke-width="0.5" stroke-opacity="1" stroke-dasharray="5, 5"'
353        self.add_line(file, from_pos, to_pos, style)
354
355    def compute_line_style(self, parent, child):
356        return self.compute_stroke_color('lines', child) + ' ' \
357               + self.compute_stroke_width('lines', child) + ' ' \
358               + self.compute_stroke_opacity(child)
359
360    def compute_dot_style(self, node):
361        return self.compute_dot_size(node) + ' ' \
362               + self.compute_fill_opacity(node) + ' ' \
363               + self.compute_dot_fill(node)
364
365    def compute_stroke_color(self, part, node):
366        color = self.compute_property(part, 'color', node)
367        return 'stroke="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'
368
369    def compute_stroke_width(self, part, node):
370        return 'stroke-width="' + str(self.compute_property(part, 'width', node)) + '"'
371
372    def compute_stroke_opacity(self, node):
373        return 'stroke-opacity="' + str(self.compute_property('lines', 'opacity', node)) + '"'
374
375    def compute_fill_opacity(self, node):
376        return 'fill-opacity="' + str(self.compute_property('dots', 'opacity', node)) + '"'
377
378    def compute_dot_size(self, node):
379        return 'r="' + str(self.compute_property('dots', 'size', node)) + '"'
380
381    def compute_dot_fill(self, node):
382        color = self.compute_property('dots', 'color', node)
383        return 'fill="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'
384
385class Designer:
386
387    def __init__(self, tree, jitter=False, time="GENERATIONAL", balance="DENSITY"):
388        self.props = {}
389
390        self.tree = tree
391
392        self.TIME = time
393        self.JITTER = jitter
394
395        if balance == "RANDOM":
396            self.xmin_crowd = self.xmin_crowd_random
397        elif balance == "MIN":
398            self.xmin_crowd = self.xmin_crowd_min
399        elif balance == "DENSITY":
400            self.xmin_crowd = self.xmin_crowd_density
401        else:
402            raise ValueError("Error, the value of BALANCE does not match any expected value.")
403
404    def calculate_measures(self):
405        print("Calculating measures...")
406        self.compute_depth()
407        self.compute_maxdepth()
408        self.compute_adepth()
409        self.compute_children()
410        self.compute_kind()
411        self.compute_time()
412        self.compute_progress()
413        self.compute_custom()
414
415    def xmin_crowd_random(self, x1, x2, y):
416        return (x1 if random.randrange(2) == 0 else x2)
417
418    def xmin_crowd_min(self, x1, x2, y):
419        x1_closest = 999999
420        x2_closest = 999999
421        miny = y-3
422        maxy = y+3
423        i = bisect.bisect_left(self.y_sorted, miny)
424        while True:
425            if len(self.positions_sorted) <= i or self.positions_sorted[i]['y'] > maxy:
426                break
427            pos = self.positions_sorted[i]
428
429            x1_closest = min(x1_closest, abs(x1-pos['x']))
430            x2_closest = min(x2_closest, abs(x2-pos['x']))
431
432            i += 1
433        return (x1 if x1_closest > x2_closest else x2)
434
435    def xmin_crowd_density(self, x1, x2, y):
436        # TODO experimental - requires further work to make it less 'jumpy' and more predictable
437        CONST_LOCAL_AREA_RADIUS = 5
438        CONST_GLOBAL_AREA_RADIUS = 10
439        CONST_WINDOW_SIZE = 20000 #TODO should depend on the maxY ?
440        x1_dist_loc = 0
441        x2_dist_loc = 0
442        count_loc = 1
443        x1_dist_glob = 0
444        x2_dist_glob = 0
445        count_glob = 1
446        miny = y-CONST_WINDOW_SIZE
447        maxy = y+CONST_WINDOW_SIZE
448        i_left = bisect.bisect_left(self.y_sorted, miny)
449        i_right = bisect.bisect_right(self.y_sorted, maxy)
450        #TODO test: maxy=y should give the same results, right?
451
452        def include_pos(pos):
453            nonlocal x1_dist_loc, x2_dist_loc, x1_dist_glob, x2_dist_glob, count_loc, count_glob
454
455            dysq = (pos['y']-y)**2 + 1 #+1 so 1/dysq is at most 1
456            dx1 = math.fabs(pos['x']-x1)
457            dx2 = math.fabs(pos['x']-x2)
458
459            d = math.fabs(pos['x'] - (x1+x2)/2)
460
461            if d < CONST_LOCAL_AREA_RADIUS:
462                x1_dist_loc += math.sqrt(dx1/dysq + dx1**2)
463                x2_dist_loc += math.sqrt(dx2/dysq + dx2**2)
464                count_loc += 1
465            elif d > CONST_GLOBAL_AREA_RADIUS:
466                x1_dist_glob += math.sqrt(dx1/dysq + dx1**2)
467                x2_dist_glob += math.sqrt(dx2/dysq + dx2**2)
468                count_glob += 1
469
470        # optimized to draw from all the nodes, if less than 10 nodes in the range
471        if len(self.positions_sorted) > i_left:
472            if i_right - i_left < 10:
473                for j in range(i_left, i_right):
474                    include_pos(self.positions_sorted[j])
475            else:
476                for j in range(10):
477                    pos = self.positions_sorted[random.randrange(i_left, i_right)]
478                    include_pos(pos)
479
480        return (x1 if (x1_dist_loc-x2_dist_loc)/count_loc-(x1_dist_glob-x2_dist_glob)/count_glob > 0  else x2)
481        #return (x1 if x1_dist +random.gauss(0, 0.00001) > x2_dist +random.gauss(0, 0.00001)  else x2)
482        #print(x1_dist, x2_dist)
483        #x1_dist = x1_dist**2
484        #x2_dist = x2_dist**2
485        #return x1 if x1_dist+x2_dist==0 else (x1*x1_dist + x2*x2_dist) / (x1_dist+x2_dist) + random.gauss(0, 0.01)
486        #return (x1 if random.randint(0, int(x1_dist+x2_dist)) < x1_dist else x2)
487
488    def calculate_node_positions(self, ignore_last=0):
489        print("Calculating positions...")
490
491        def add_node(node):
492            index = bisect.bisect_left(self.y_sorted, node['y'])
493            self.y_sorted.insert(index, node['y'])
494            self.positions_sorted.insert(index, node)
495            self.positions[node['id']] = node
496
497        self.positions_sorted = [{'x':0, 'y':0, 'id':0}]
498        self.y_sorted = [0]
499        self.positions = [{} for x in range(len(self.tree.parents))]
500        self.positions[0] = {'x':0, 'y':0, 'id':0}
501
502        # order by maximum depth of the parent guarantees that co child is evaluated before its parent
503        visiting_order = [i for i in range(0, len(self.tree.parents))]
504        visiting_order = sorted(visiting_order, key=lambda q:\
505                            0 if q == 0 else self.props["maxdepth"][q])
506
507        start_time = timelib.time()
508
509        # for each child of the current node
510        for node_counter,child in enumerate(visiting_order, start=1):
511            # debug info - elapsed time
512            if node_counter % 100000 == 0:
513               print("%d%%\t%d\t%g" % (node_counter*100/len(self.tree.parents), node_counter, timelib.time()-start_time))
514               start_time = timelib.time()
515
516            # using normalized adepth
517            if self.props['adepth'][child] >= ignore_last/self.props['adepth_max']:
518
519                ypos = 0
520                if self.TIME == "BIRTHS":
521                    ypos = child
522                elif self.TIME == "GENERATIONAL":
523                    # one more than its parent (what if more than one parent?)
524                    ypos = max([self.positions[par]['y'] for par, v in self.tree.parents[child].items()])+1 \
525                        if self.tree.parents[child] else 0
526                elif self.TIME == "REAL":
527                    ypos = self.tree.time[child]
528
529                if len(self.tree.parents[child]) == 1:
530                # if current_node is the only parent
531                    parent, similarity = [(par, v) for par, v in self.tree.parents[child].items()][0]
532
533                    if self.JITTER:
534                        dissimilarity = (1-similarity) + random.gauss(0, 0.01) + 0.001
535                    else:
536                        dissimilarity = (1-similarity) + 0.001
537                    add_node({'id':child, 'y':ypos, 'x':
538                             self.xmin_crowd(self.positions[parent]['x']-dissimilarity,
539                              self.positions[parent]['x']+dissimilarity, ypos)})
540                else:
541                    # position weighted by the degree of inheritence from each parent
542                    total_inheretance = sum([v for k, v in self.tree.parents[child].items()])
543                    xpos = sum([self.positions[k]['x']*v/total_inheretance
544                               for k, v in self.tree.parents[child].items()])
545                    if self.JITTER:
546                        add_node({'id':child, 'y':ypos, 'x':xpos + random.gauss(0, 0.1)})
547                    else:
548                        add_node({'id':child, 'y':ypos, 'x':xpos})
549
550
551    def compute_custom(self):
552        for prop in self.tree.props:
553            self.props[prop] = [None for x in range(len(self.tree.children))]
554
555            for i in range(len(self.props[prop])):
556                self.props[prop][i] = self.tree.props[prop][i]
557
558            self.normalize_prop(prop)
559
560    def compute_time(self):
561        # simple rewrite from the tree
562        self.props["time"] = [0 for x in range(len(self.tree.children))]
563
564        for i in range(len(self.props['time'])):
565            self.props['time'][i] = self.tree.time[i]
566
567        self.normalize_prop('time')
568
569    def compute_kind(self):
570        # simple rewrite from the tree
571        self.props["kind"] = [0 for x in range(len(self.tree.children))]
572
573        for i in range (len(self.props['kind'])):
574            self.props['kind'][i] = str(self.tree.kind[i])
575
576    def compute_depth(self):
577        self.props["depth"] = [999999999 for x in range(len(self.tree.children))]
578        visited = [0 for x in range(len(self.tree.children))]
579
580        nodes_to_visit = [0]
581        visited[0] = 1
582        self.props["depth"][0] = 0
583        while True:
584            current_node = nodes_to_visit[0]
585
586            for child in self.tree.children[current_node]:
587                if visited[child] == 0:
588                    visited[child] = 1
589                    nodes_to_visit.append(child)
590                    self.props["depth"][child] = self.props["depth"][current_node]+1
591            nodes_to_visit = nodes_to_visit[1:]
592            if len(nodes_to_visit) == 0:
593                break
594
595        self.normalize_prop('depth')
596
597    def compute_maxdepth(self):
598        self.props["maxdepth"] = [999999999 for x in range(len(self.tree.children))]
599        visited = [0 for x in range(len(self.tree.children))]
600
601        nodes_to_visit = [0]
602        visited[0] = 1
603        self.props["maxdepth"][0] = 0
604        while True:
605            current_node = nodes_to_visit[0]
606
607            for child in self.tree.children[current_node]:
608                if visited[child] == 0:
609                    visited[child] = 1
610                    nodes_to_visit.append(child)
611                    self.props["maxdepth"][child] = self.props["maxdepth"][current_node]+1
612                elif self.props["maxdepth"][child] < self.props["maxdepth"][current_node]+1:
613                    self.props["maxdepth"][child] = self.props["maxdepth"][current_node]+1
614                    if child not in  nodes_to_visit:
615                        nodes_to_visit.append(child)
616
617            nodes_to_visit = nodes_to_visit[1:]
618            if len(nodes_to_visit) == 0:
619                break
620
621        self.normalize_prop('maxdepth')
622
623    def compute_adepth(self):
624        self.props["adepth"] = [0 for x in range(len(self.tree.children))]
625
626        # order by maximum depth of the parent guarantees that co child is evaluated before its parent
627        visiting_order = [i for i in range(0, len(self.tree.parents))]
628        visiting_order = sorted(visiting_order, key=lambda q: self.props["maxdepth"][q])[::-1]
629
630        for node in visiting_order:
631            children = self.tree.children[node]
632            if len(children) != 0:
633                # 0 by default
634                self.props["adepth"][node] = max([self.props["adepth"][child] for child in children])+1
635        self.normalize_prop('adepth')
636
637    def compute_children(self):
638        self.props["children"] = [0 for x in range(len(self.tree.children))]
639        for i in range (len(self.props['children'])):
640            self.props['children'][i] = len(self.tree.children[i])
641
642        self.normalize_prop('children')
643
644    def compute_progress(self):
645        self.props["progress"] = [0 for x in range(len(self.tree.children))]
646        for i in range(len(self.props['children'])):
647            times = sorted([self.props["time"][self.tree.children[i][j]]*100000 for j in range(len(self.tree.children[i]))])
648            if len(times) > 4:
649                times = [times[i+1] - times[i] for i in range(len(times)-1)]
650                #print(times)
651                slope, intercept, r_value, p_value, std_err = stats.linregress(range(len(times)), times)
652                self.props['progress'][i] = slope if not np.isnan(slope) and not np.isinf(slope) else 0
653
654        for i in range(0, 5):
655            self.props['progress'][self.props['progress'].index(min(self.props['progress']))] = 0
656            self.props['progress'][self.props['progress'].index(max(self.props['progress']))] = 0
657
658        mini = min(self.props['progress'])
659        maxi = max(self.props['progress'])
660        for k in range(len(self.props['progress'])):
661            if self.props['progress'][k] == 0:
662                self.props['progress'][k] = mini
663
664        #for k in range(len(self.props['progress'])):
665        #        self.props['progress'][k] = 1-self.props['progress'][k]
666
667        self.normalize_prop('progress')
668
669    def normalize_prop(self, prop):
670        noneless = [v for v in self.props[prop] if (type(v)!=str and type(v)!=list)]
671        if len(noneless) > 0:
672            max_val = max(noneless)
673            min_val = min(noneless)
674            print("%s: [%g, %g]" % (prop, min_val, max_val))
675            self.props[prop +'_max'] = max_val
676            self.props[prop +'_min'] = min_val
677            for i in range(len(self.props[prop])):
678                if self.props[prop][i] is not None:
679                    qqq = self.props[prop][i]
680                    self.props[prop][i] = 0 if max_val == min_val else (self.props[prop][i] - min_val) / (max_val - min_val)
681
class TreeData:
    """Genealogy data read from a log of ``[OFFSPRING]`` and ``[DIED]`` records.

    ``load()`` gives every individual a consecutive integer index (stored in
    ``ids``, keyed by the identifier used in the log). The per-node containers
    below are indexed by that integer:

    * ``parents[i]``     - dict: parent index -> the parent's "Inherited" value
                           from the record (1 when the record has no "Inherited"),
    * ``children[i]``    - list of indices of the children of node ``i``,
    * ``time[i]``        - the "Time" field of the record (0 when absent),
    * ``kind[i]``        - the "Kind" field of the record (0 when absent),
    * ``life_lenght[i]`` - allocated with zeros; ``load()`` does not fill it in,
    * ``props[name][i]`` - every record field outside ``_DEFAULT_PROPS``
                           (0 when absent); note that this includes "Kind".

    Parents that are referenced before they were introduced by their own
    ``[OFFSPRING]`` record are all replaced by one shared "virtual_parent" node.
    """
    simple_data = None

    # Optional first token of a line; when present it is stripped before the tag is read.
    _CLI_PREFIX = "Script.Message:"
    # Record fields that are handled explicitly and therefore not copied to ``props``.
    _DEFAULT_PROPS = ("Time", "FromIDs", "ID", "Operation", "Inherited")
    # Key under which the shared stand-in for unknown parents is registered in ``ids``.
    _VIRTUAL_PARENT = "virtual_parent"
    _JSON_ERROR = "Json format error - the line cannot be read. Breaking the loading loop."

    def __init__(self): #, simple_data=False):
        #self.simple_data = simple_data
        # Per-instance containers: class-level lists would be shared by all instances.
        self.children = []
        self.parents = []
        self.time = []
        self.kind = []
        self.life_lenght = []
        self.props = {}
        self.ids = {}

    @classmethod
    def _split_record(cls, line):
        """Return ``(tag, payload)`` of a log line, or ``(None, None)`` if it has no payload."""
        parts = line.split(' ', 1)
        if len(parts) == 2 and parts[0] == cls._CLI_PREFIX:
            parts = parts[1].split(' ', 1)
        if len(parts) != 2:
            return None, None
        return parts[0], parts[1]

    def _get_id(self, key):
        """Return the integer index of ``key``, registering the next free index if it is new."""
        if key not in self.ids:
            self.ids[key] = len(self.ids)
        return self.ids[key]

    def _store_props(self, creature, creature_id, nodes):
        """Copy every non-default field of ``creature`` into ``props`` at ``creature_id``."""
        for prop, value in creature.items():
            if prop in self._DEFAULT_PROPS:
                continue
            if prop not in self.props:
                self.props[prop] = [0] * nodes
            self.props[prop][creature_id] = value

    def load(self, filename, max_nodes=0):
        """Read the genealogy from ``filename`` into this object.

        The file is scanned twice: first to count the ``[OFFSPRING]`` records
        (to size the per-node lists), then to parse the ``[OFFSPRING]`` and
        ``[DIED]`` records. A record whose JSON payload cannot be parsed stops
        the loading; it is assumed to be the truncated last line of the file.

        :param filename: path of the text log to read.
        :param max_nodes: maximal number of ``[OFFSPRING]`` records to load;
            0 means no limit.
        :raises LoadingError: if an ``[OFFSPRING]`` record has no "FromIDs" field.
        """
        print("Loading...")

        # IDs replaced by the virtual parent, in order of first appearance (for the warning);
        # the set only speeds up the "already reported?" test.
        merged_with_virtual_parent = []
        already_merged = set()

        self.ids = {}

        # the context manager guarantees the file is closed, also on errors
        with open(filename) as log_file:
            # counting the number of expected nodes
            offspring_records = sum(
                1 for line in log_file
                if self._split_record(line)[0] == "[OFFSPRING]")

            # one additional slot is reserved for the virtual parent
            nodes = (offspring_records if max_nodes == 0 else min(offspring_records, max_nodes)) + 1
            self.parents = [{} for _ in range(nodes)]
            self.children = [[] for _ in range(nodes)]
            self.time = [0] * nodes
            self.kind = [0] * nodes
            self.life_lenght = [0] * nodes
            self.props = {}

            print("nodes: %d" % len(self.parents))

            log_file.seek(0)
            loaded_so_far = 0
            for line in log_file:
                tag, payload = self._split_record(line)

                if tag == "[OFFSPRING]":
                    try:
                        creature = json.loads(payload)
                    except ValueError:
                        print(self._JSON_ERROR)
                        # fixing arrays by removing the last element
                        # ! assuming that only the last line is broken !
                        for per_node in (self.parents, self.children, self.time,
                                         self.kind, self.life_lenght):
                            per_node.pop()
                        # the property lists were sized together with the lists above
                        for per_node in self.props.values():
                            per_node.pop()
                        break

                    if "FromIDs" not in creature:
                        raise LoadingError("[OFFSPRING] misses the 'FromIDs' field!")

                    from_ids = creature["FromIDs"]

                    # make sure that ID's of parents are lower than that of their children:
                    # the virtual parent (if needed) is registered before the child itself
                    if any(from_id not in self.ids for from_id in from_ids):
                        self._get_id(self._VIRTUAL_PARENT)

                    creature_id = self._get_id(creature["ID"])

                    # we assign to each parent its contribution to the genotype of the child
                    for i, from_id in enumerate(from_ids):
                        if from_id in self.ids:
                            parent_id = self.ids[from_id]
                        else:
                            if from_id not in already_merged:
                                already_merged.add(from_id)
                                merged_with_virtual_parent.append(from_id)
                            parent_id = self.ids[self._VIRTUAL_PARENT]
                        inherited = (creature["Inherited"][i] if 'Inherited' in creature else 1)
                        self.parents[creature_id][parent_id] = inherited

                    if "Time" in creature:
                        self.time[creature_id] = creature["Time"]

                    if "Kind" in creature:
                        self.kind[creature_id] = creature["Kind"]

                    self._store_props(creature, creature_id, nodes)
                    loaded_so_far += 1

                elif tag == "[DIED]":
                    try:
                        creature = json.loads(payload)
                    except ValueError:
                        # same policy as for a broken [OFFSPRING] line; no slot was
                        # reserved for a [DIED] record, so there is nothing to remove
                        print(self._JSON_ERROR)
                        break
                    # records about individuals that were never loaded are ignored
                    creature_id = self.ids.get(creature["ID"])
                    if creature_id is not None:
                        self._store_props(creature, creature_id, nodes)

                if loaded_so_far >= max_nodes and max_nodes != 0:
                    break

        # warn only when something was actually merged
        if merged_with_virtual_parent:
            print("WARNING: merging individuals: " + str(merged_with_virtual_parent) + " with a virtual parent node!")

        # derive the child lists from the parent dictionaries
        for child, parent_shares in enumerate(self.parents):
            for parent in parent_shares:
                self.children[parent].append(child)
818
# Module-level dictionaries, both empty when the module is loaded.
# NOTE(review): nothing in this part of the file reads or writes them - confirm they are still needed.
depth = dict()
kind = dict()
821
def main():
    """Command-line entry point: load a genealogy log and draw it as a tree.

    Steps: parse the command line, validate the --time/--balance/--scale
    values (case-insensitively), seed the ``random`` module, load the input
    file into a ``TreeData``, compute the layout with a ``Designer`` and draw
    it with ``SvgDrawer`` (output name ending with ".svg") or ``PngDrawer``
    (any other output name).

    If any of --time/--balance/--scale has an unsupported value, a message
    naming each offending parameter is printed and the function returns
    without loading anything.
    """

    parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates a SVG file) based on parent-child relationship '
                                                 'information from a text file. Supports files generated by Framsticks experiments.')
    parser.add_argument('-i', '--in', dest='input', required=True, help='input file name with structured evolutionary data')
    parser.add_argument('-o', '--out', dest='output', required=True, help='output file name for the evolutionary tree (SVG/PNG/JPG/BMP)')
    parser.add_argument('-c', '--config', dest='config', default="", help='config file name ')

    parser.add_argument('-W', '--width', default=600, type=int, dest='width', help='width of the output image (600 by default)')
    parser.add_argument('-H', '--height', default=800, type=int, dest='height', help='height of the output image (800 by default)')
    parser.add_argument('-m', '--multi', default=1, type=int, dest='multi', help='multisampling factor (applicable only for raster images)')

    parser.add_argument('-t', '--time', default='GENERATIONAL', dest='time', help='values on vertical axis (BIRTHS/GENERATIONAL(d)/REAL); '
                                                                      'BIRTHS: time measured as the number of births since the beginning; '
                                                                      'GENERATIONAL: time measured as number of ancestors; '
                                                                      'REAL: real time of the simulation')
    parser.add_argument('-b', '--balance', default='DENSITY', dest='balance', help='method of placing nodes in the tree (RANDOM/MIN/DENSITY(d))')
    # the "(d)" marker now matches the actual default, which is SIMPLE
    parser.add_argument('-s', '--scale', default='SIMPLE', dest='scale', help='type of timescale added to the tree (NONE/SIMPLE(d))')
    parser.add_argument('-j', '--jitter', dest="jitter", action='store_true', help='draw horizontal positions of children from the normal distribution')
    parser.add_argument('-p', '--skip', dest="skip", type=int, default=0, help='skip last P levels of the tree (0 by default)')
    parser.add_argument('-x', '--max-nodes', type=int, default=0, dest='max_nodes', help='maximum number of nodes drawn (starting from the first one)')
    parser.add_argument('--seed', type=int, dest='seed', help='seed for the random number generator (-1 for random)')

    parser.set_defaults(draw_tree=True)
    parser.set_defaults(draw_skeleton=False)
    parser.set_defaults(draw_spine=False)

    parser.set_defaults(seed=-1)

    args = parser.parse_args()

    TIME = args.time.upper()
    BALANCE = args.balance.upper()
    SCALE = args.scale.upper()
    JITTER = args.jitter

    # (option, given value, allowed values) - used to tell the user exactly what is wrong
    checked = (
        ('--time', TIME, ('BIRTHS', 'GENERATIONAL', 'REAL')),
        ('--balance', BALANCE, ('RANDOM', 'MIN', 'DENSITY')),
        ('--scale', SCALE, ('NONE', 'SIMPLE')),
    )
    invalid = [entry for entry in checked if entry[1] not in entry[2]]
    if invalid:
        for option, value, allowed in invalid:
            print("Incorrect value of the '%s' parameter: '%s' (allowed: %s)." % (option, value, ", ".join(allowed)))
        return

    input_path = args.input
    seed = args.seed
    if seed == -1:
        # no seed requested: pick one and print it so that the run can be reproduced
        seed = random.randint(0, 10000)
    random.seed(seed)
    print("randomseed:", seed)

    tree = TreeData()
    tree.load(input_path, max_nodes=args.max_nodes)


    designer = Designer(tree, jitter=JITTER, time=TIME, balance=BALANCE)
    designer.calculate_measures()
    designer.calculate_node_positions(ignore_last=args.skip)

    # the drawer is selected by the extension of the output file name
    if args.output.endswith(".svg"):
        drawer = SvgDrawer(designer, args.config, w=args.width, h=args.height)
    else:
        drawer = PngDrawer(designer, args.config, w=args.width, h=args.height)
    drawer.draw_design(args.output, args.input, multi=args.multi, scale=SCALE)
883
884
# Run the tool only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.