source: mds-and-trees/tree-genealogy.py @ 701

Last change on this file since 701 was 701, checked in by konrad, 7 years ago

Fixed a bug that could occur for trees with a lot of crossover (node positions were assigned in the wrong order)

File size: 37.0 KB
Line 
1import json
2import math
3import random
4import argparse
5import bisect
6import time as timelib
7from PIL import Image, ImageDraw, ImageFont
8from scipy import stats
9from matplotlib import colors
10import numpy as np
11
class LoadingError(Exception):
    """Exception type declared by this module for loading failures."""
14
class Drawer:
    """Format-independent part of rendering a laid-out tree.

    Stores the canvas geometry and the style settings and walks the node
    positions of ``design``.  The actual output is produced by methods that
    this class calls but does not define (``add_dot``, ``add_line``,
    ``add_dashed_line``, ``add_text``, ``compute_dot_style`` and
    ``compute_line_style``); subclasses are expected to provide them.
    """

    def __init__(self, design, config_file, w=600, h=800, w_margin=10, h_margin=20):
        """Store geometry, build the default settings and apply the config file.

        design      -- object exposing ``positions``, ``props``, ``tree`` and ``TIME``
        config_file -- path to a JSON file whose values override the defaults;
                       an empty string means "use the defaults only"
        w, h        -- canvas size
        w_margin, h_margin -- margins kept free on each side of the canvas
        """
        self.design = design
        self.width = w
        self.height = h
        self.w_margin = w_margin
        self.h_margin = h_margin
        self.w_no_margs = w - 2 * w_margin
        self.h_no_margs = h - 2 * h_margin

        self.color_converter = colors.ColorConverter()

        # For every visual attribute: 'meaning' names the design property that
        # drives it, 'start'/'end' are the values at property 0 and 1, and
        # 'bias' bends the interpolation (see compute_value).
        self.settings = {
            'colors_of_kinds': ['red', 'green', 'blue', 'magenta', 'yellow', 'cyan', 'orange', 'purple'],
            'dots': {
                'color': {
                    'meaning': 'Lifespan',
                    'normalize_cmap': False,
                    'cmap': {},
                    'start': 'red',
                    'end': 'green',
                    'bias': 1
                    },
                'size': {
                    'meaning': 'EnergyEaten',
                    'start': 1,
                    'end': 6,
                    'bias': 0.5
                    },
                'opacity': {
                    'meaning': 'EnergyEaten',
                    'start': 0.2,
                    'end': 1,
                    'bias': 1
                    }
            },
            'lines': {
                'color': {
                    'meaning': 'adepth',
                    'normalize_cmap': False,
                    'cmap': {},
                    'start': 'black',
                    'end': 'red',
                    'bias': 3
                    },
                'width': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 4,
                    'bias': 3
                    },
                'opacity': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 0.8,
                    'bias': 5
                    }
            }
        }

        def merge(source, destination):
            # Recursively copy 'source' into 'destination'; nested dicts are
            # merged key by key, every other value overwrites the default.
            for key, value in source.items():
                if isinstance(value, dict):
                    node = destination.setdefault(key, {})
                    merge(value, node)
                else:
                    destination[key] = value
            return destination

        if config_file != "":
            with open(config_file) as config:
                overrides = json.load(config)
            self.settings = merge(overrides, self.settings)

        self.compile_cmaps()

    def compile_cmaps(self):
        """Turn every configured 'cmap' dict into a LinearSegmentedColormap.

        When 'normalize_cmap' is set, the breakpoints of the map are first
        rescaled from the raw range of the driving property (read from
        ``design.props[<meaning>_min/_max]``) to [0, 1], and the map is padded
        so that it starts at 0 and ends at 1.
        """
        for part in ('dots', 'lines'):
            color_settings = self.settings[part]['color']
            cmap = color_settings['cmap']
            if not cmap:
                continue

            if color_settings['normalize_cmap']:
                meaning = color_settings['meaning']
                low = self.design.props[meaning + "_min"]
                high = self.design.props[meaning + "_max"]
                span = high - low
                for key in cmap:
                    for arr in cmap[key]:
                        # An empty range would divide by zero; collapse all
                        # breakpoints to 0 in that case.
                        arr[0] = (arr[0] - low) / span if span != 0 else 0
                    if cmap[key][0][0] != 0:
                        cmap[key].insert(0, cmap[key][0][:])
                        cmap[key][0][0] = 0
                    if cmap[key][-1][0] != 1:
                        cmap[key].append(cmap[key][-1][:])
                        cmap[key][-1][0] = 1

            # The segment data is converted from lists to tuples before it is
            # handed to matplotlib.
            for key in cmap:
                cmap[key] = tuple(tuple(arr) for arr in cmap[key])
            color_settings['cmap'] = colors.LinearSegmentedColormap('Custom', cmap)

    def _canvas_pos(self, x, y, min_width, max_width, max_height):
        """Map layout coordinates to canvas coordinates inside the margins.

        A layout with no horizontal extent (all nodes share one x) is centred
        horizontally, and a layout with ``max_height == 0`` is drawn on the top
        margin, instead of dividing by zero.
        """
        x_span = max_width - min_width
        if x_span != 0:
            canvas_x = self.w_margin + self.w_no_margs * (x - min_width) / x_span
        else:
            canvas_x = self.w_margin + self.w_no_margs / 2
        if max_height != 0:
            canvas_y = self.h_margin + self.h_no_margs * y / max_height
        else:
            canvas_y = self.h_margin
        return (canvas_x, canvas_y)

    def draw_dots(self, file, min_width, max_width, max_height):
        """Draw one dot for every node that has been assigned a position."""
        for index, node in enumerate(self.design.positions):
            if 'x' not in node:
                # node was left out of the layout
                continue
            dot_style = self.compute_dot_style(node=index)
            self.add_dot(file,
                         self._canvas_pos(node['x'], node['y'], min_width, max_width, max_height),
                         dot_style)

    def draw_lines(self, file, min_width, max_width, max_height):
        """Draw a line from every positioned parent to each positioned child."""
        positions = self.design.positions
        for parent, par_pos in enumerate(positions):
            if 'x' not in par_pos:
                continue
            for child in self.design.tree.children[parent]:
                chi_pos = positions[child]
                if 'x' not in chi_pos:
                    continue
                line_style = self.compute_line_style(parent, child)
                self.add_line(file,
                              self._canvas_pos(par_pos['x'], par_pos['y'], min_width, max_width, max_height),
                              self._canvas_pos(chi_pos['x'], chi_pos['y'], min_width, max_width, max_height),
                              line_style)

    def draw_scale(self, file, filename):
        """Draw the caption and the dashed top/bottom markers of the time axis."""
        # Show only the file name; both '\\' and '/' are accepted as directory
        # separators so that the caption is the same on every platform.
        base_name = filename.replace("\\", "/").split("/")[-1]
        self.add_text(file, "Generated from " + base_name, (5, 5), "start")

        start_text = ""
        end_text = ""
        if self.design.TIME == "BIRTHS":
            start_text = "Birth #0"
            end_text = "Birth #" + str(len(self.design.positions)-1)
        elif self.design.TIME == "REAL":
            start_text = "Time " + str(min(self.design.tree.time))
            end_text = "Time " + str(max(self.design.tree.time))
        elif self.design.TIME == "GENERATIONAL":
            start_text = "Depth " + str(self.design.props['adepth_min'])
            end_text = "Depth " + str(self.design.props['adepth_max'])

        top = self.h_margin
        bottom = self.height - self.h_margin
        self.add_dashed_line(file, (self.width*0.7, top), (self.width, top))
        self.add_text(file, start_text, (self.width, top), "end")
        self.add_dashed_line(file, (self.width*0.7, bottom), (self.width, bottom))
        self.add_text(file, end_text, (self.width, bottom), "end")

    def compute_property(self, part, prop, node):
        """Return the value of visual attribute ``prop`` of ``part`` for ``node``.

        part -- 'dots' or 'lines'
        prop -- key of the attribute inside the settings of ``part``
        node -- index of the node in the design property lists

        The driving value is looked up in ``design.props``; a property that is
        not present there counts as 0.
        """
        spec = self.settings[part][prop]
        meaning = spec['meaning']
        value = self.design.props[meaning][node] if meaning in self.design.props else 0
        if value is None:
            # NOTE(review): property lists may hold None entries (normalize_prop
            # skips them); they are treated like a missing property - confirm.
            value = 0

        if prop != "color":
            return self.compute_value(spec['start'], spec['end'], value, spec['bias'])
        if spec['cmap']:
            return self.compute_color_from_cmap(spec['cmap'], value, spec['bias'])
        return self.compute_color(spec['start'], spec['end'], value, spec['bias'])

    def compute_color_from_cmap(self, cmap, value, bias=1):
        """Look ``value`` up in ``cmap`` and return (r, g, b) scaled to 0..100."""
        value = 1 - (1-value)**bias
        rgba = cmap(value)
        return (100*rgba[0], 100*rgba[1], 100*rgba[2])

    def compute_color(self, start, end, value, bias=1):
        """Return an (r, g, b) tuple scaled to 0..100.

        A string ``value`` is parsed as an integer index into
        settings['colors_of_kinds']; the index wraps around the palette so that
        more kinds than configured colors do not raise IndexError.  Any other
        value interpolates between the colors ``start`` and ``end``.
        """
        if isinstance(value, str):
            palette = self.settings['colors_of_kinds']
            r, g, b = self.color_converter.to_rgb(palette[int(value) % len(palette)])
        else:
            start_color = self.color_converter.to_rgb(start)
            end_color = self.color_converter.to_rgb(end)
            value = 1 - (1-value)**bias
            r = start_color[0]*(1-value)+end_color[0]*value
            g = start_color[1]*(1-value)+end_color[1]*value
            b = start_color[2]*(1-value)+end_color[2]*value
        return (100*r, 100*g, 100*b)

    def compute_value(self, start, end, value, bias=1):
        """Interpolate between ``start`` and ``end``.

        ``value`` is first reshaped to ``1 - (1 - value)**bias``; with bias 1
        the interpolation is linear.
        """
        value = 1 - (1-value)**bias
        return start*(1-value) + end*value
204
class PngDrawer(Drawer):
    """Drawer that renders the tree into a raster image with PIL.

    The picture is painted on a canvas enlarged ``multi`` times and then
    shrunk back to the requested size.
    """

    # geometry attributes that are multiplied by ``multi`` while painting
    _SCALED_ATTRS = ('width', 'height', 'w_margin', 'h_margin', 'h_no_margs', 'w_no_margs')

    def scale_up(self):
        """Multiply the canvas geometry by ``self.multi``.

        The unscaled values are remembered so that scale_down() can restore
        them exactly.
        """
        self._unscaled = {name: getattr(self, name) for name in self._SCALED_ATTRS}
        for name in self._SCALED_ATTRS:
            setattr(self, name, getattr(self, name) * self.multi)

    def scale_down(self):
        """Undo scale_up().

        The remembered values are restored as they were.  Dividing by
        ``multi`` turned integer sizes into floats, which Image.new() rejects
        on the next draw_design() call; division is only used as a fallback
        when scale_up() was not called first.
        """
        saved = getattr(self, '_unscaled', None)
        if saved:
            for name, value in saved.items():
                setattr(self, name, value)
            self._unscaled = None
        else:
            for name in self._SCALED_ATTRS:
                setattr(self, name, getattr(self, name) / self.multi)

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Render the design and save it to ``filename``.

        filename       -- output image path
        input_filename -- name shown in the "Generated from" caption
        multi          -- oversampling factor of the working canvas
        scale          -- "SIMPLE" draws the time scale, any other value omits it
        """
        print("Drawing...")

        self.multi = multi
        self.scale_up()
        try:
            back = Image.new('RGBA', (self.width, self.height), (255, 255, 255, 0))

            positions = self.design.positions
            xs = [pos['x'] for pos in positions if 'x' in pos]
            min_width = min(xs)
            max_width = max(xs)
            max_height = max([pos['y'] for pos in positions if 'y' in pos])

            self.draw_lines(back, min_width, max_width, max_height)
            self.draw_dots(back, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(back, input_filename)
        finally:
            # leave the drawer in its unscaled state even when drawing fails
            self.scale_down()

        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        resample = getattr(Image, 'LANCZOS', None)
        if resample is None:
            resample = Image.ANTIALIAS
        back.thumbnail((self.width, self.height), resample)

        back.save(filename)

    @staticmethod
    def _rgba(style):
        """Convert a style with a 0..100 color and 0..1 opacity to 0..255 RGBA."""
        c = style['color']
        return (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity']))

    def add_dot(self, file, pos, style):
        """Paste a filled circle of radius ``style['r'] * multi`` centred at ``pos``."""
        x, y = int(pos[0]), int(pos[1])
        r = style['r']*self.multi
        diameter = 2*int(r)
        if diameter < 2:
            # a radius below one pixel would need an empty image and an
            # inverted bounding box; there is nothing to paint
            return
        offset = (int(x - r), int(y - r))

        img = Image.new('RGBA', (diameter, diameter))
        ImageDraw.Draw(img).ellipse((1, 1, diameter-1, diameter-1), self._rgba(style))
        file.paste(img, offset, mask=img)

    def add_line(self, file, from_pos, to_pos, style):
        """Paste a straight line from ``from_pos`` to ``to_pos`` onto ``file``."""
        fx, fy = int(from_pos[0]), int(from_pos[1])
        tx, ty = int(to_pos[0]), int(to_pos[1])
        # Widths below 1 truncate to 0.  PIL draws such a line one pixel wide
        # anyway, but the zero padding produced an empty temporary image for
        # horizontal and vertical lines, which were then silently skipped.
        w = max(1, int(style['width'])*self.multi)

        # the temporary image is the bounding box of the line padded by w
        offset = (min(fx, tx)-w, min(fy, ty)-w)
        size = (abs(fx-tx)+2*w, abs(fy-ty)+2*w)

        if (fx-tx)*(fy-ty) > 0:
            # both deltas have the same sign: top-left to bottom-right diagonal
            coords = (w, w, size[0]-w, size[1]-w)
        else:
            coords = (size[0]-w, w, w, size[1]-w)

        img = Image.new('RGBA', size)
        ImageDraw.Draw(img).line(coords, self._rgba(style), w)
        file.paste(img, offset, mask=img)

    def add_dashed_line(self, file, from_pos, to_pos):
        """Draw a black dashed line made of 50 dashes separated by equal gaps."""
        style = {'color': (0, 0, 0), 'width': 1, 'opacity': 1}
        sublines = 50
        # TODO could be faster: compute delta and only add delta each time (but currently we do not use it often)
        # the segment is cut into 2*sublines-1 pieces; the even ones are drawn
        normdiv = 2*sublines-1
        for i in range(sublines):
            dash_start = 2*i/normdiv
            dash_end = (2*i+1)/normdiv
            from_pos_sub = (self.compute_value(from_pos[0], to_pos[0], dash_start, 1),
                            self.compute_value(from_pos[1], to_pos[1], dash_start, 1))
            to_pos_sub = (self.compute_value(from_pos[0], to_pos[0], dash_end, 1),
                          self.compute_value(from_pos[1], to_pos[1], dash_end, 1))
            self.add_line(file, from_pos_sub, to_pos_sub, style)

    def add_text(self, file, text, pos, anchor, style=''):
        """Paste black ``text`` at ``pos``.

        anchor -- "start" puts the left edge of the text at ``pos``; any other
                  value right-aligns the text to ``pos``
        style  -- accepted but not used by this drawer
        """
        try:
            font = ImageFont.truetype("Vera.ttf", 16*self.multi)
        except OSError:
            # font file not found: use PIL's built-in font instead of
            # aborting the whole drawing
            font = ImageFont.load_default()

        img = Image.new('RGBA', (self.width, self.height))
        draw = ImageDraw.Draw(img)
        if hasattr(draw, 'textsize'):
            text_width = draw.textsize(text, font=font)[0]
        else:
            # textsize() was removed in Pillow 10; the right edge of the
            # bounding box measured from the origin replaces it
            text_width = draw.textbbox((0, 0), text, font=font)[2]
        pos = pos if anchor == "start" else (pos[0]-text_width, pos[1])
        draw.text(pos, text, (0, 0, 0), font=font)
        file.paste(img, (0, 0), mask=img)

    def compute_line_style(self, parent, child):
        """Return the style dict of the line leading to ``child``.

        All three attributes are computed from the child node; ``parent`` is
        not used.
        """
        return {'color': self.compute_property('lines', 'color', child),
                'width': self.compute_property('lines', 'width', child),
                'opacity': self.compute_property('lines', 'opacity', child)}

    def compute_dot_style(self, node):
        """Return the style dict (color, radius 'r', opacity) of the dot of ``node``."""
        return {'color': self.compute_property('dots', 'color', node),
                'r': self.compute_property('dots', 'size', node),
                'opacity': self.compute_property('dots', 'opacity', node)}
308
class SvgDrawer(Drawer):
    """Drawer that writes the tree as an SVG document.

    Styles are passed around as ready-made strings of SVG attributes.
    """

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Write the design to ``filename`` as SVG.

        filename       -- output path
        input_filename -- name shown in the "Generated from" caption
        multi          -- accepted for interface parity with other drawers; unused
        scale          -- "SIMPLE" draws the time scale, any other value omits it
        """
        print("Drawing...")

        positions = self.design.positions
        xs = [pos['x'] for pos in positions if 'x' in pos]
        min_width = min(xs)
        max_width = max(xs)
        max_height = max([pos['y'] for pos in positions if 'y' in pos])

        # the context manager closes the file even when drawing raises
        with open(filename, "w") as file:
            file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" '
                       'xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
                       'width="' + str(self.width) + '" height="' + str(self.height) + '">')

            self.draw_lines(file, min_width, max_width, max_height)
            self.draw_dots(file, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(file, input_filename)

            file.write("</svg>")

    def add_text(self, file, text, pos, anchor, style=''):
        """Write a <text> element; ``anchor`` becomes its text-anchor attribute."""
        from xml.sax.saxutils import escape

        style = (style if style != '' else 'style="font-family: Arial; font-size: 12; fill: #000000;"')
        # assuming font size 12, it should be taken from the style string!
        # The text (e.g. a file name) is escaped so that '&', '<' and '>'
        # cannot break the document.
        file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0])
                   + '" y="' + str(pos[1]+12) + '" >' + escape(text) + '</text>')

    def add_dot(self, file, pos, style):
        """Write a <circle> centred at ``pos`` carrying the attributes in ``style``."""
        file.write('<circle ' + style + ' cx="' + str(pos[0]) + '" cy="' + str(pos[1]) + '" />')

    def add_line(self, file, from_pos, to_pos, style):
        """Write a <line> from ``from_pos`` to ``to_pos`` carrying the attributes in ``style``."""
        file.write('<line ' + style + ' x1="' + str(from_pos[0]) + '" x2="' + str(to_pos[0]) +
                   '" y1="' + str(from_pos[1]) + '" y2="' + str(to_pos[1]) + '"  fill="none"/>')

    def add_dashed_line(self, file, from_pos, to_pos):
        """Write a thin black line dashed through stroke-dasharray."""
        style = 'stroke="black" stroke-width="0.5" stroke-opacity="1" stroke-dasharray="5, 5"'
        self.add_line(file, from_pos, to_pos, style)

    @staticmethod
    def _rgb_percent(color):
        """Format a color with 0..100 components as an SVG rgb(..%) value."""
        return 'rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)'

    def compute_line_style(self, parent, child):
        """Return the stroke attributes of the line leading to ``child``.

        All attributes are computed from the child node; ``parent`` is not used.
        """
        return ' '.join([self.compute_stroke_color('lines', child),
                         self.compute_stroke_width('lines', child),
                         self.compute_stroke_opacity(child)])

    def compute_dot_style(self, node):
        """Return the radius, fill-opacity and fill attributes of the dot of ``node``."""
        return ' '.join([self.compute_dot_size(node),
                         self.compute_fill_opacity(node),
                         self.compute_dot_fill(node)])

    def compute_stroke_color(self, part, node):
        """Return the stroke attribute built from the 'color' setting of ``part``."""
        color = self.compute_property(part, 'color', node)
        return 'stroke="' + self._rgb_percent(color) + '"'

    def compute_stroke_width(self, part, node):
        """Return the stroke-width attribute built from the 'width' setting of ``part``."""
        return 'stroke-width="' + str(self.compute_property(part, 'width', node)) + '"'

    def compute_stroke_opacity(self, node):
        """Return the stroke-opacity attribute of the line leading to ``node``."""
        return 'stroke-opacity="' + str(self.compute_property('lines', 'opacity', node)) + '"'

    def compute_fill_opacity(self, node):
        """Return the fill-opacity attribute of the dot of ``node``."""
        return 'fill-opacity="' + str(self.compute_property('dots', 'opacity', node)) + '"'

    def compute_dot_size(self, node):
        """Return the r (radius) attribute of the dot of ``node``."""
        return 'r="' + str(self.compute_property('dots', 'size', node)) + '"'

    def compute_dot_fill(self, node):
        """Return the fill attribute of the dot of ``node``."""
        color = self.compute_property('dots', 'color', node)
        return 'fill="' + self._rgb_percent(color) + '"'
376
class Designer:
    """Computes per-node measures and 2D positions for a genealogy tree.

    The tree object is expected to expose ``children`` (list of child-index
    lists), ``parents`` (list of ``{parent index: weight}`` dicts), ``time``,
    ``kind`` and ``props``.  Node 0 is used as the root of every traversal.
    Measures are stored in ``self.props``; positions in ``self.positions``.
    """

    def __init__(self, tree, jitter=False, time="GENERATIONAL", balance="DENSITY"):
        """Store the tree and choose the horizontal balancing strategy.

        tree    -- tree data object (see the class description)
        jitter  -- when true, random noise is added to horizontal positions
        time    -- meaning of the vertical axis: "BIRTHS", "GENERATIONAL" or "REAL"
        balance -- "RANDOM", "MIN" or "DENSITY"; selects the function that picks
                   the side of the parent on which a child is placed

        Raises ValueError for an unknown ``balance`` value.
        """
        self.props = {}

        self.tree = tree

        self.TIME = time
        self.JITTER = jitter

        if balance == "RANDOM":
            self.xmin_crowd = self.xmin_crowd_random
        elif balance == "MIN":
            self.xmin_crowd = self.xmin_crowd_min
        elif balance == "DENSITY":
            self.xmin_crowd = self.xmin_crowd_density
        else:
            raise ValueError("Error, the value of BALANCE does not match any expected value.")

    def calculate_measures(self):
        """Compute all node measures; the order matters (later ones reuse earlier ones)."""
        print("Calculating measures...")
        self.compute_depth()
        self.compute_maxdepth()
        self.compute_adepth()
        self.compute_children()
        self.compute_kind()
        self.compute_time()
        self.compute_progress()
        self.compute_custom()

    def xmin_crowd_random(self, x1, x2, y):
        """Pick ``x1`` or ``x2`` with equal probability; ``y`` is ignored."""
        return (x1 if random.randrange(2) == 0 else x2)

    def xmin_crowd_min(self, x1, x2, y):
        """Pick the candidate whose nearest already placed node is farther away.

        Only nodes with a vertical position within 3 units of ``y`` are
        considered; ties go to ``x2``.
        """
        x1_closest = 999999
        x2_closest = 999999
        maxy = y+3
        first = bisect.bisect_left(self.y_sorted, y-3)
        for i in range(first, len(self.positions_sorted)):
            pos = self.positions_sorted[i]
            if pos['y'] > maxy:
                break
            x1_closest = min(x1_closest, abs(x1-pos['x']))
            x2_closest = min(x2_closest, abs(x2-pos['x']))
        return (x1 if x1_closest > x2_closest else x2)

    def xmin_crowd_density(self, x1, x2, y):
        """Pick ``x1`` or ``x2`` by comparing local and global crowding.

        Distances to already placed nodes are accumulated separately for nodes
        close to the midpoint of the candidates and for nodes far from it; at
        most 10 nodes (randomly sampled when more are in range) are inspected.
        """
        # TODO experimental - requires further work to make it less 'jumpy' and more predictable
        CONST_LOCAL_AREA_RADIUS = 5
        CONST_GLOBAL_AREA_RADIUS = 10
        CONST_WINDOW_SIZE = 20000 #TODO should depend on the maxY ?
        x1_dist_loc = 0
        x2_dist_loc = 0
        count_loc = 1
        x1_dist_glob = 0
        x2_dist_glob = 0
        count_glob = 1
        i_left = bisect.bisect_left(self.y_sorted, y-CONST_WINDOW_SIZE)
        i_right = bisect.bisect_right(self.y_sorted, y+CONST_WINDOW_SIZE)
        #TODO test: maxy=y should give the same results, right?

        def include_pos(pos):
            nonlocal x1_dist_loc, x2_dist_loc, x1_dist_glob, x2_dist_glob, count_loc, count_glob

            dysq = (pos['y']-y)**2 + 1 #+1 so 1/dysq is at most 1
            dx1 = math.fabs(pos['x']-x1)
            dx2 = math.fabs(pos['x']-x2)

            # distance of the node from the midpoint of the two candidates
            d = math.fabs(pos['x'] - (x1+x2)/2)

            if d < CONST_LOCAL_AREA_RADIUS:
                x1_dist_loc += math.sqrt(dx1/dysq + dx1**2)
                x2_dist_loc += math.sqrt(dx2/dysq + dx2**2)
                count_loc += 1
            elif d > CONST_GLOBAL_AREA_RADIUS:
                x1_dist_glob += math.sqrt(dx1/dysq + dx1**2)
                x2_dist_glob += math.sqrt(dx2/dysq + dx2**2)
                count_glob += 1

        # optimized to draw from all the nodes, if less than 10 nodes in the range
        if len(self.positions_sorted) > i_left:
            if i_right - i_left < 10:
                for j in range(i_left, i_right):
                    include_pos(self.positions_sorted[j])
            else:
                for _ in range(10):
                    include_pos(self.positions_sorted[random.randrange(i_left, i_right)])

        local_balance = (x1_dist_loc-x2_dist_loc)/count_loc
        global_balance = (x1_dist_glob-x2_dist_glob)/count_glob
        return (x1 if local_balance-global_balance > 0 else x2)

    def calculate_node_positions(self, ignore_last=0):
        """Assign an x/y position to every node that is drawn.

        ignore_last -- nodes whose normalized 'adepth' is below
                       ``ignore_last / adepth_max`` get no position

        Requires the 'maxdepth' and 'adepth' measures to be computed.  Fills
        ``self.positions`` (indexed by node; skipped nodes keep an empty dict)
        and the y-sorted index ``self.positions_sorted`` / ``self.y_sorted``
        used by the balancing functions.
        """
        print("Calculating positions...")

        def add_node(node):
            # keep positions_sorted ordered by y so that the balancing
            # functions can bisect into it
            index = bisect.bisect_left(self.y_sorted, node['y'])
            self.y_sorted.insert(index, node['y'])
            self.positions_sorted.insert(index, node)
            self.positions[node['id']] = node

        node_count = len(self.tree.parents)
        # NOTE: the root is seeded here and added again by the loop below, so
        # it appears twice in the sorted index.
        self.positions_sorted = [{'x':0, 'y':0, 'id':0}]
        self.y_sorted = [0]
        self.positions = [{} for _ in range(node_count)]
        self.positions[0] = {'x':0, 'y':0, 'id':0}

        # ordering by maximum depth guarantees that no child is evaluated before its parent
        visiting_order = sorted(range(node_count), key=lambda q:
                                0 if q == 0 else self.props["maxdepth"][q])

        # 'adepth' is normalized, hence the threshold is divided by the raw
        # maximum.  A tree without any edge has adepth_max == 0; every node is
        # kept then instead of dividing by zero.
        adepth_max = self.props['adepth_max']
        threshold = ignore_last/adepth_max if adepth_max != 0 else 0

        start_time = timelib.time()

        for node_counter, child in enumerate(visiting_order, start=1):
            # debug info - elapsed time
            if node_counter % 100000 == 0:
                print("%d%%\t%d\t%g" % (node_counter*100/node_count, node_counter, timelib.time()-start_time))
                start_time = timelib.time()

            if self.props['adepth'][child] < threshold:
                continue

            parents = self.tree.parents[child]

            ypos = 0
            if self.TIME == "BIRTHS":
                ypos = child
            elif self.TIME == "GENERATIONAL":
                # one more than the lowest placed of its parents
                ypos = max([self.positions[par]['y'] for par in parents])+1 if parents else 0
            elif self.TIME == "REAL":
                ypos = self.tree.time[child]

            if len(parents) == 1:
                # single parent: step aside from it by the dissimilarity
                parent, similarity = next(iter(parents.items()))

                if self.JITTER:
                    dissimilarity = (1-similarity) + random.gauss(0, 0.01) + 0.001
                else:
                    dissimilarity = (1-similarity) + 0.001
                parent_x = self.positions[parent]['x']
                xpos = self.xmin_crowd(parent_x-dissimilarity, parent_x+dissimilarity, ypos)
            else:
                # position weighted by the degree of inheritance from each parent
                total_inheritance = sum(parents.values())
                if total_inheritance != 0:
                    xpos = sum([self.positions[k]['x']*v/total_inheritance
                                for k, v in parents.items()])
                elif parents:
                    # all weights are zero: fall back to the plain average
                    xpos = sum([self.positions[k]['x'] for k in parents])/len(parents)
                else:
                    xpos = 0
                if self.JITTER:
                    xpos = xpos + random.gauss(0, 0.1)

            add_node({'id':child, 'y':ypos, 'x':xpos})

    def compute_custom(self):
        """Copy every property of the tree into ``self.props`` and normalize it."""
        node_count = len(self.tree.children)
        for prop in self.tree.props:
            self.props[prop] = [self.tree.props[prop][i] for i in range(node_count)]
            self.normalize_prop(prop)

    def compute_time(self):
        """Copy the node times from the tree and normalize them."""
        # simple rewrite from the tree
        self.props["time"] = [self.tree.time[i] for i in range(len(self.tree.children))]
        self.normalize_prop('time')

    def compute_kind(self):
        """Copy the node kinds from the tree as strings (not normalized)."""
        # simple rewrite from the tree
        self.props["kind"] = [str(self.tree.kind[i]) for i in range(len(self.tree.children))]

    def compute_depth(self):
        """Compute 'depth': the breadth-first distance of every node from node 0.

        Nodes that cannot be reached keep the placeholder 999999999 before the
        list is normalized.
        """
        node_count = len(self.tree.children)
        depth = [999999999] * node_count
        visited = [False] * node_count
        self.props["depth"] = depth

        # FIFO queue realised as a list plus a read index, so that taking the
        # next node does not copy the list
        queue = [0]
        visited[0] = True
        depth[0] = 0
        head = 0
        while head < len(queue):
            current_node = queue[head]
            head += 1
            for child in self.tree.children[current_node]:
                if not visited[child]:
                    visited[child] = True
                    queue.append(child)
                    depth[child] = depth[current_node]+1

        self.normalize_prop('depth')

    def compute_maxdepth(self):
        """Compute 'maxdepth': the length of the longest path from node 0.

        A node whose value grows after it was processed is queued again so
        that the larger value reaches its descendants.
        """
        node_count = len(self.tree.children)
        maxdepth = [999999999] * node_count
        visited = [False] * node_count
        # queued[i] is true while node i waits in the queue or is being processed
        queued = [False] * node_count
        self.props["maxdepth"] = maxdepth

        queue = [0]
        visited[0] = True
        queued[0] = True
        maxdepth[0] = 0
        head = 0
        while head < len(queue):
            current_node = queue[head]
            candidate = maxdepth[current_node]+1

            for child in self.tree.children[current_node]:
                if not visited[child]:
                    visited[child] = True
                    queued[child] = True
                    queue.append(child)
                    maxdepth[child] = candidate
                elif maxdepth[child] < candidate:
                    maxdepth[child] = candidate
                    if not queued[child]:
                        queued[child] = True
                        queue.append(child)

            head += 1
            queued[current_node] = False

        self.normalize_prop('maxdepth')

    def compute_adepth(self):
        """Compute 'adepth': 0 for childless nodes, else 1 + the largest child value.

        Requires 'maxdepth' to be computed first.
        """
        self.props["adepth"] = [0 for _ in range(len(self.tree.children))]

        # visiting by decreasing maximum depth guarantees that no parent is evaluated before its children
        by_maxdepth = sorted(range(len(self.tree.parents)), key=lambda q: self.props["maxdepth"][q])

        for node in reversed(by_maxdepth):
            children = self.tree.children[node]
            if len(children) != 0:
                # 0 by default
                self.props["adepth"][node] = max([self.props["adepth"][child] for child in children])+1
        self.normalize_prop('adepth')

    def compute_children(self):
        """Compute 'children': the number of children of every node."""
        self.props["children"] = [len(kids) for kids in self.tree.children]
        self.normalize_prop('children')

    def compute_progress(self):
        """Compute 'progress' from the timing of the births of a node's children.

        For nodes with more than 4 children the value is the slope of a linear
        regression through the gaps between consecutive (scaled) child times.
        Requires 'time' to be computed first.
        """
        progress = [0 for _ in range(len(self.tree.children))]
        self.props["progress"] = progress

        for node, kids in enumerate(self.tree.children):
            times = sorted([self.props["time"][kid]*100000 for kid in kids])
            if len(times) > 4:
                gaps = [times[k+1] - times[k] for k in range(len(times)-1)]
                slope = stats.linregress(range(len(gaps)), gaps)[0]
                progress[node] = slope if not np.isnan(slope) and not np.isinf(slope) else 0

        # reset the most extreme values: five times the current minimum and
        # the current maximum are set to 0
        for _ in range(5):
            progress[progress.index(min(progress))] = 0
            progress[progress.index(max(progress))] = 0

        # every zero entry is replaced with the remaining minimum
        mini = min(progress)
        for k in range(len(progress)):
            if progress[k] == 0:
                progress[k] = mini

        self.normalize_prop('progress')

    def normalize_prop(self, prop):
        """Rescale the numeric values of ``self.props[prop]`` to [0, 1] in place.

        The raw extremes are stored as ``prop + '_min'`` and ``prop + '_max'``.
        None, string and list entries are neither used for the range nor
        modified (None used to reach max() and raise TypeError).  When all
        values are equal they all become 0.  A property without any numeric
        value is left untouched.
        """
        values = self.props[prop]

        def is_numeric(v):
            return v is not None and not isinstance(v, (str, list))

        numeric = [v for v in values if is_numeric(v)]
        if not numeric:
            return

        max_val = max(numeric)
        min_val = min(numeric)
        print("%s: [%g, %g]" % (prop, min_val, max_val))
        self.props[prop +'_max'] = max_val
        self.props[prop +'_min'] = min_val
        for i, v in enumerate(values):
            if is_numeric(v):
                values[i] = 0 if max_val == min_val else (v - min_val) / (max_val - min_val)
673
class TreeData:
    """Genealogy read from a text log made of [OFFSPRING] and [DIED] records.

    Every creature ID found in the file is mapped to a consecutive integer
    node index; all per-node lists are indexed by it. Attributes (re)built by
    load():

    ids         -- dict: creature ID from the file -> node index
    parents     -- list of dicts: parent node index -> contribution of that
                   parent to the genotype of the child
    children    -- list of lists of child node indices
    time        -- value of the "Time" field of each node (0 when absent)
    kind        -- value of the "Kind" field of each node (0 when absent)
    life_lenght -- list of zeros, one per node (load() never writes to it)
    props       -- dict: field name -> list of per-node values (0 when absent)
    """

    simple_data = None

    # Lines may start with this token; it is stripped before the record tag is read.
    _CLI_PREFIX = "Script.Message:"
    # Fields that are never copied into `props`.
    _DEFAULT_PROPS = ("Time", "FromIDs", "ID", "Operation", "Inherited")

    def __init__(self):
        # Containers are created per instance so that two TreeData objects
        # never share (and accidentally mutate) the same lists.
        self.ids = {}
        self.children = []
        self.parents = []
        self.time = []
        self.kind = []
        self.life_lenght = []
        self.props = {}

    @staticmethod
    def _split_record(line):
        """Split a log line into (tag, payload); return (None, None) if it has no payload."""
        parts = line.split(' ', 1)
        if len(parts) == 2 and parts[0] == TreeData._CLI_PREFIX:
            parts = parts[1].split(' ', 1)
        if len(parts) != 2:
            return None, None
        return parts[0], parts[1]

    def _get_id(self, name, create=True):
        """Return the node index of `name`; register it first if `create`, else return None when unknown."""
        if name not in self.ids:
            if not create:
                return None
            self.ids[name] = len(self.ids)
        return self.ids[name]

    def _store_props(self, creature_id, creature):
        """Copy every non-default field of `creature` into the per-node `props` lists."""
        for prop in creature:
            if prop in self._DEFAULT_PROPS:
                continue
            if prop not in self.props:
                self.props[prop] = [0] * len(self.parents)
            self.props[prop][creature_id] = creature[prop]

    def _add_offspring(self, creature):
        """Register one parsed [OFFSPRING] record (which must contain "FromIDs")."""
        from_ids = creature["FromIDs"]

        # make sure that IDs of parents are lower than those of their children:
        # an unknown parent is represented by "virtual_parent", which has to be
        # registered before the child itself
        if any(from_id not in self.ids for from_id in from_ids):
            self._get_id("virtual_parent")

        creature_id = self._get_id(creature["ID"])

        # we assign to each parent its contribution to the genotype of the child
        for i, from_id in enumerate(from_ids):
            parent_id = self._get_id(from_id if from_id in self.ids else "virtual_parent")
            inherited = creature["Inherited"][i] if "Inherited" in creature else 1
            self.parents[creature_id][parent_id] = inherited

        if "Time" in creature:
            self.time[creature_id] = creature["Time"]

        if "Kind" in creature:
            self.kind[creature_id] = creature["Kind"]

        self._store_props(creature_id, creature)

    def _add_death(self, payload):
        """Merge the fields of one [DIED] record into `props`; unreadable or unknown records are ignored."""
        try:
            creature = json.loads(payload)
        except ValueError:
            # e.g. the log was cut in the middle of the line; no node slot was
            # reserved for a [DIED] record, so there is nothing to undo
            print("Json format error - the [DIED] line cannot be read. Skipping it.")
            return
        creature_id = self._get_id(creature["ID"], False)
        if creature_id is not None:
            self._store_props(creature_id, creature)

    def _drop_last_node(self):
        """Remove the last slot of every per-node list (including the `props` lists)."""
        self.parents.pop()
        self.children.pop()
        self.time.pop()
        self.kind.pop()
        self.life_lenght.pop()
        for values in self.props.values():
            values.pop()

    def load(self, filename, max_nodes=0):
        """Read the genealogy from `filename`, replacing any previously loaded data.

        filename  -- path of the text file; only lines of the form
                     "[OFFSPRING] <json>" and "[DIED] <json>" (optionally
                     preceded by "Script.Message: ") are used
        max_nodes -- stop after this many [OFFSPRING] records (0 = no limit)

        Raises LoadingError when an [OFFSPRING] record has no "FromIDs" field.
        An [OFFSPRING] line with unreadable JSON ends the loading (it is assumed
        to be the last, truncated line of the file).
        """
        print("Loading...")

        self.ids = {}

        # `with` guarantees the file is closed, also when loading fails
        with open(filename) as file:
            # counting the number of expected nodes
            nodes = sum(1 for line in file if self._split_record(line)[0] == "[OFFSPRING]")
            if max_nodes != 0:
                nodes = min(nodes, max_nodes)
            # one extra slot: the only node that does not come from its own
            # [OFFSPRING] record is "virtual_parent"
            nodes += 1

            self.parents = [{} for _ in range(nodes)]
            self.children = [[] for _ in range(nodes)]
            self.time = [0] * nodes
            self.kind = [0] * nodes
            self.life_lenght = [0] * nodes
            self.props = {}

            print("nodes: %d" % len(self.parents))

            file.seek(0)
            loaded_so_far = 0
            for line in file:
                tag, payload = self._split_record(line)
                if tag == "[OFFSPRING]":
                    try:
                        creature = json.loads(payload)
                    except ValueError:
                        print("Json format error - the line cannot be read. Breaking the loading loop.")
                        # fixing arrays by removing the last element
                        # ! assuming that only the last line is broken !
                        self._drop_last_node()
                        break

                    if "FromIDs" not in creature:
                        raise LoadingError("[OFFSPRING] misses the 'FromIDs' field!")

                    self._add_offspring(creature)
                    loaded_so_far += 1
                elif tag == "[DIED]":
                    self._add_death(payload)

                if loaded_so_far >= max_nodes and max_nodes != 0:
                    break

        # derive the child lists from the parent dictionaries
        for child, parent_map in enumerate(self.parents):
            for parent in parent_map:
                self.children[parent].append(child)
804
# Module-level dictionaries, both created empty when the module is loaded.
# NOTE(review): neither TreeData nor main() touches these globals - confirm
# that some other part of the file still needs them.
depth = dict()
kind = dict()
807
def main():
    """Command-line entry point: load a genealogy, lay it out and draw it.

    Reads the options from sys.argv, validates the --time/--balance/--scale
    values (case-insensitively), seeds the random number generator, loads the
    input file into a TreeData, lets a Designer compute the node positions and
    writes the picture with SvgDrawer (output name ending in ".svg") or
    PngDrawer (any other output name).

    Returns None. When one of --time/--balance/--scale has an unsupported
    value, the offending option(s) are reported on stdout and nothing is drawn.
    """

    parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates a SVG file) based on parent-child relationship '
                                                 'information from a text file. Supports files generated by Framsticks experiments.')
    parser.add_argument('-i', '--in', dest='input', required=True, help='input file name with stuctured evolutionary data')
    parser.add_argument('-o', '--out', dest='output', required=True, help='output file name for the evolutionary tree (SVG/PNG/JPG/BMP)')
    parser.add_argument('-c', '--config', dest='config', default="", help='config file name ')

    parser.add_argument('-W', '--width', default=600, type=int, dest='width', help='width of the output image (600 by default)')
    parser.add_argument('-H', '--height', default=800, type=int, dest='height', help='height of the output image (800 by default)')
    parser.add_argument('-m', '--multi', default=1, type=int, dest='multi', help='multisampling factor (applicable only for raster images)')

    parser.add_argument('-t', '--time', default='GENERATIONAL', dest='time', help='values on vertical axis (BIRTHS/GENERATIONAL(d)/REAL); '
                                                                      'BIRTHS: time measured as the number of births since the beginning; '
                                                                      'GENERATIONAL: time measured as number of ancestors; '
                                                                      'REAL: real time of the simulation')
    parser.add_argument('-b', '--balance', default='DENSITY', dest='balance', help='method of placing nodes in the tree (RANDOM/MIN/DENSITY(d))')
    # the "(d)" marker is placed on SIMPLE because that is the actual default value
    parser.add_argument('-s', '--scale', default='SIMPLE', dest='scale', help='type of timescale added to the tree (NONE/SIMPLE(d))')
    parser.add_argument('-j', '--jitter', dest="jitter", action='store_true', help='draw horizontal positions of children from the normal distribution')
    parser.add_argument('-p', '--skip', dest="skip", type=int, default=0, help='skip last P levels of the tree (0 by default)')
    parser.add_argument('-x', '--max-nodes', type=int, default=0, dest='max_nodes', help='maximum number of nodes drawn (starting from the first one)')
    parser.add_argument('--seed', type=int, dest='seed', help='seed for the random number generator (-1 for random)')

    parser.set_defaults(draw_tree=True)
    parser.set_defaults(draw_skeleton=False)
    parser.set_defaults(draw_spine=False)

    parser.set_defaults(seed=-1)

    args = parser.parse_args()

    TIME = args.time.upper()
    BALANCE = args.balance.upper()
    SCALE = args.scale.upper()
    JITTER = args.jitter

    # (option name, given value, accepted values)
    checked = (
        ('time', TIME, ('BIRTHS', 'GENERATIONAL', 'REAL')),
        ('balance', BALANCE, ('RANDOM', 'MIN', 'DENSITY')),
        ('scale', SCALE, ('NONE', 'SIMPLE')),
    )
    invalid = [entry for entry in checked if entry[1] not in entry[2]]
    if invalid:
        print("Incorrect value of one of the parameters! (time or balance or scale).")
        # tell the user exactly which option is wrong and what is accepted
        for name, value, accepted in invalid:
            print("  --%s: '%s' is not one of %s" % (name, value, '/'.join(accepted)))
        return

    input_path = args.input
    seed = args.seed
    if seed == -1:
        # no seed requested: draw one and print it so that the run can be repeated
        seed = random.randint(0, 10000)
    random.seed(seed)
    print("randomseed:", seed)

    tree = TreeData()
    tree.load(input_path, max_nodes=args.max_nodes)

    designer = Designer(tree, jitter=JITTER, time=TIME, balance=BALANCE)
    designer.calculate_measures()
    designer.calculate_node_positions(ignore_last=args.skip)

    if args.output.endswith(".svg"):
        drawer = SvgDrawer(designer, args.config, w=args.width, h=args.height)
    else:
        drawer = PngDrawer(designer, args.config, w=args.width, h=args.height)
    drawer.draw_design(args.output, args.input, multi=args.multi, scale=SCALE)
869
870
# Start the tool only when this file is executed as a script; importing it
# (e.g. to reuse TreeData) no longer parses sys.argv or draws anything.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.