source: mds-and-trees/tree-genealogy.py @ 700

Last change on this file since 700 was 700, checked in by konrad, 7 years ago

Colors now support matplotlib colormaps (cmaps). The 'color' property in the config file accepts two more keys: 'cmap' and 'normalize_cmap'. A cmap is defined like the segment data of https://matplotlib.org/devdocs/api/_as_gen/matplotlib.colors.LinearSegmentedColormap.html, but uses arrays instead of tuples. If 'normalize_cmap' == false, the step values (the first value in each array) must already be normalized to [0, 1]. Otherwise, if 'normalize_cmap' == true, the step values can be raw property values (from the range of values in the data).

File size: 36.1 KB
Line 
1import json
2import math
3import random
4import argparse
5import bisect
6import time as timelib
7from PIL import Image, ImageDraw, ImageFont
8from scipy import stats
9from matplotlib import colors
10import numpy as np
11
class LoadingError(Exception):
    """Error type intended for failures while loading tree data.

    NOTE(review): not raised anywhere in the visible part of this file --
    presumably meant for TreeData.load; confirm before relying on it.
    """
14
class Drawer:
    """Base class for the tree renderers (PNG and SVG back ends).

    Holds the drawing settings -- built-in defaults deep-merged with an
    optional JSON config file -- and maps per-node properties computed by
    ``design`` onto visual attributes: color, dot size / line width and
    opacity. Subclasses supply the format-specific ``add_dot``, ``add_line``,
    ``add_dashed_line``, ``add_text``, ``compute_dot_style`` and
    ``compute_line_style``.
    """

    def __init__(self, design, config_file, w=600, h=800, w_margin=10, h_margin=20):
        """Create a drawer.

        :param design: object exposing ``props``, ``positions``, ``tree`` and
            ``TIME`` (as the ``Designer`` class in this file does).
        :param config_file: path of a JSON file whose contents override the
            default settings; an empty string (or None) means "no config".
        :param w, h: canvas size in pixels.
        :param w_margin, h_margin: horizontal / vertical margins in pixels.
        """
        self.design = design
        self.width = w
        self.height = h
        self.w_margin = w_margin
        self.h_margin = h_margin
        self.w_no_margs = w - 2 * w_margin
        self.h_no_margs = h - 2 * h_margin

        self.color_converter = colors.ColorConverter()

        # Each visual attribute is driven by a node property ('meaning'); the
        # normalized property value is warped by 'bias' and interpolated from
        # 'start' to 'end' (or looked up in 'cmap' for colors, when given).
        self.settings = {
            'colors_of_kinds': ['red', 'green', 'blue', 'magenta', 'yellow', 'cyan', 'orange', 'purple'],
            'dots': {
                'color': {
                    'meaning': 'Lifespan',
                    'normalize_cmap': False,
                    'cmap': {},
                    'start': 'red',
                    'end': 'green',
                    'bias': 1
                    },
                'size': {
                    'meaning': 'EnergyEaten',
                    'start': 1,
                    'end': 6,
                    'bias': 0.5
                    },
                'opacity': {
                    'meaning': 'EnergyEaten',
                    'start': 0.2,
                    'end': 1,
                    'bias': 1
                    }
            },
            'lines': {
                'color': {
                    'meaning': 'adepth',
                    'normalize_cmap': False,
                    'cmap': {},
                    'start': 'black',
                    'end': 'red',
                    'bias': 3
                    },
                'width': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 4,
                    'bias': 3
                    },
                'opacity': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 0.8,
                    'bias': 5
                    }
            }
        }

        def merge(source, destination):
            """Recursively copy ``source`` into ``destination``.

            Nested dicts are merged key by key, any other value replaces the
            destination's. Returns ``destination``.
            """
            for key, value in source.items():
                if isinstance(value, dict):
                    node = destination.setdefault(key, {})
                    merge(value, node)
                else:
                    destination[key] = value
            return destination

        if config_file:
            with open(config_file) as config:
                overrides = json.load(config)
            self.settings = merge(overrides, self.settings)

        self.compile_cmaps()

    def compile_cmaps(self):
        """Replace the 'cmap' settings by matplotlib LinearSegmentedColormaps.

        A cmap maps a channel name to a list of [x, y0, y1] arrays, i.e. the
        segment data of ``LinearSegmentedColormap`` with arrays instead of
        tuples (JSON has no tuples). With 'normalize_cmap' the x values are
        raw property values: they are rescaled with the property's min/max
        and, when missing, entries at x=0 and x=1 are added by copying the
        first/last entry.
        """
        for part in ('dots', 'lines'):
            color = self.settings[part]['color']
            cmap = color['cmap']
            if not cmap:
                continue

            if color['normalize_cmap']:
                lowest = self.design.props[color['meaning'] + "_min"]
                highest = self.design.props[color['meaning'] + "_max"]
                # A constant property would otherwise divide by zero.
                span = (highest - lowest) or 1
                for key in cmap:
                    steps = cmap[key]
                    for arr in steps:
                        arr[0] = (arr[0] - lowest) / span
                    if steps[0][0] != 0:
                        steps.insert(0, steps[0][:])
                        steps[0][0] = 0
                    if steps[-1][0] != 1:
                        steps.append(steps[-1][:])
                        steps[-1][0] = 1

            for key in cmap:
                cmap[key] = tuple(tuple(arr) for arr in cmap[key])
            color['cmap'] = colors.LinearSegmentedColormap('Custom', cmap)

    def _to_canvas(self, pos, min_width, max_width, max_height):
        """Map a layout position (dict with 'x' and 'y') to canvas coordinates inside the margins."""
        # A single column (all x equal) or a single row (max y == 0) would
        # otherwise divide by zero; such a degenerate axis maps onto the margin.
        x_span = (max_width - min_width) or 1
        y_span = max_height or 1
        return (self.w_margin + self.w_no_margs * (pos['x'] - min_width) / x_span,
                self.h_margin + self.h_no_margs * pos['y'] / y_span)

    def draw_dots(self, file, min_width, max_width, max_height):
        """Draw one dot per node that received a position ('x' present)."""
        for i, node in enumerate(self.design.positions):
            if 'x' not in node:
                continue
            self.add_dot(file, self._to_canvas(node, min_width, max_width, max_height),
                         self.compute_dot_style(node=i))

    def draw_lines(self, file, min_width, max_width, max_height):
        """Draw a line from every positioned node to each of its positioned children."""
        for parent, par_pos in enumerate(self.design.positions):
            if 'x' not in par_pos:
                continue
            start = self._to_canvas(par_pos, min_width, max_width, max_height)
            for child in self.design.tree.children[parent]:
                chi_pos = self.design.positions[child]
                if 'x' not in chi_pos:
                    continue
                self.add_line(file, start,
                              self._to_canvas(chi_pos, min_width, max_width, max_height),
                              self.compute_line_style(parent, child))

    def draw_scale(self, file, filename):
        """Draw the caption (input file name) and the start/end markers of the vertical axis."""
        # Strip both Windows and POSIX directory parts from the shown name.
        base_name = filename.replace("\\", "/").split("/")[-1]
        self.add_text(file, "Generated from " + base_name, (5, 5), "start")

        start_text = ""
        end_text = ""
        if self.design.TIME == "BIRTHS":
            start_text = "Birth #0"
            end_text = "Birth #" + str(len(self.design.positions)-1)
        if self.design.TIME == "REAL":
            start_text = "Time " + str(min(self.design.tree.time))
            end_text = "Time " + str(max(self.design.tree.time))
        if self.design.TIME == "GENERATIONAL":
            start_text = "Depth " + str(self.design.props['adepth_min'])
            end_text = "Depth " + str(self.design.props['adepth_max'])

        self.add_dashed_line(file, (self.width*0.7, self.h_margin), (self.width, self.h_margin))
        self.add_text(file, start_text, (self.width, self.h_margin), "end")
        self.add_dashed_line(file, (self.width*0.7, self.height-self.h_margin), (self.width, self.height-self.h_margin))
        self.add_text(file, end_text, (self.width, self.height-self.h_margin), "end")

    def compute_property(self, part, prop, node):
        """Visual value of attribute ``prop`` of ``part`` ('dots'/'lines') for ``node``.

        Nodes get value 0 when the configured 'meaning' is not a known property.
        Colors are returned as (r, g, b) percentages, other attributes as numbers.
        """
        setting = self.settings[part][prop]
        meaning = setting['meaning']
        value = self.design.props[meaning][node] if meaning in self.design.props else 0
        start, end, bias = setting['start'], setting['end'], setting['bias']
        if prop != "color":
            return self.compute_value(start, end, value, bias)
        if setting['cmap']:
            return self.compute_color_from_cmap(setting['cmap'], value, bias)
        return self.compute_color(start, end, value, bias)

    def compute_color_from_cmap(self, cmap, value, bias=1):
        """Look up the bias-warped ``value`` in ``cmap``; returns (r, g, b) in percent."""
        value = 1 - (1-value)**bias
        rgba = cmap(value)
        return (100*rgba[0], 100*rgba[1], 100*rgba[2])

    def compute_color(self, start, end, value, bias=1):
        """Color for ``value`` as (r, g, b) percentages.

        A string ``value`` is an integer kind index into 'colors_of_kinds'
        (wrapping around when there are more kinds than colors). A numeric
        ``value`` (normalized to [0, 1]) is bias-warped and interpolated
        linearly between the ``start`` and ``end`` colors.
        """
        if isinstance(value, str):
            palette = self.settings['colors_of_kinds']
            r, g, b = self.color_converter.to_rgb(palette[int(value) % len(palette)])
        else:
            start_color = self.color_converter.to_rgb(start)
            end_color = self.color_converter.to_rgb(end)
            value = 1 - (1-value)**bias
            r = start_color[0]*(1-value)+end_color[0]*value
            g = start_color[1]*(1-value)+end_color[1]*value
            b = start_color[2]*(1-value)+end_color[2]*value
        return (100*r, 100*g, 100*b)

    def compute_value(self, start, end, value, bias=1):
        """Interpolate between ``start`` and ``end``.

        ``value`` in [0, 1] is first warped to ``1 - (1 - value)**bias``;
        bias > 1 pushes results towards ``end``, bias < 1 towards ``start``.
        """
        value = 1 - (1-value)**bias
        return start*(1-value) + end*value
204
class PngDrawer(Drawer):
    """Raster back end: renders the tree with PIL and saves it as an image file.

    The picture is drawn ``multi`` times larger than requested and shrunk
    back with a Lanczos filter before saving, so ``multi`` acts as a
    supersampling (antialiasing) factor.
    """

    # Geometry attributes multiplied / divided by self.multi in scale_up()/scale_down().
    _GEOMETRY = ('width', 'height', 'w_margin', 'h_margin', 'h_no_margs', 'w_no_margs')

    def scale_up(self):
        """Multiply all canvas geometry by ``self.multi``."""
        for name in self._GEOMETRY:
            setattr(self, name, getattr(self, name) * self.multi)

    def scale_down(self):
        """Divide all canvas geometry by ``self.multi`` (the results are floats)."""
        for name in self._GEOMETRY:
            setattr(self, name, getattr(self, name) / self.multi)

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Render the tree and save it to ``filename``.

        :param input_filename: name shown in the caption drawn by ``draw_scale``.
        :param multi: supersampling factor (see class docstring).
        :param scale: "SIMPLE" adds the caption and the vertical-axis markers.
        """
        print("Drawing...")

        self.multi = multi
        # Remember the exact unscaled geometry: dividing it back would leave
        # floats behind, which Pillow's Image.new() may reject on a later call.
        unscaled = {name: getattr(self, name) for name in self._GEOMETRY}
        self.scale_up()
        try:
            back = Image.new('RGBA', (self.width, self.height), (255, 255, 255, 0))

            placed = [p for p in self.design.positions if 'x' in p]
            min_width = min(p['x'] for p in placed)
            max_width = max(p['x'] for p in placed)
            max_height = max(p['y'] for p in self.design.positions if 'y' in p)

            self.draw_lines(back, min_width, max_width, max_height)
            self.draw_dots(back, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(back, input_filename)
        finally:
            # Restore the geometry even when drawing failed.
            for name, value in unscaled.items():
                setattr(self, name, value)

        # Image.ANTIALIAS (an alias of LANCZOS) was removed in Pillow 10.
        back.thumbnail((self.width, self.height), Image.LANCZOS)

        back.save(filename)

    def add_dot(self, file, pos, style):
        """Alpha-blend a filled circle of radius ``style['r']`` (scaled by multi) centered at ``pos``."""
        x, y = int(pos[0]), int(pos[1])
        r = style['r'] * self.multi
        diameter = 2 * int(r)
        if diameter <= 0:
            # Nothing visible to draw, and ellipse() rejects an empty box.
            return
        offset = (int(x - r), int(y - r))
        size = (diameter, diameter)

        c = style['color']

        img = Image.new('RGBA', size)
        ImageDraw.Draw(img).ellipse((1, 1, size[0]-1, size[1]-1),
                                    (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity'])))
        file.paste(img, offset, mask=img)

    def add_line(self, file, from_pos, to_pos, style):
        """Alpha-blend a straight line from ``from_pos`` to ``to_pos``.

        The line is drawn on a small transparent tile covering its bounding
        box and pasted with the tile's alpha as mask.
        """
        fx, fy, tx, ty = int(from_pos[0]), int(from_pos[1]), int(to_pos[0]), int(to_pos[1])
        # Scale before truncating, so fractional widths survive supersampling
        # (int(width) * multi turned e.g. 0.9 * 4 into 0).
        w = int(style['width'] * self.multi)
        # Pad by at least one pixel: with w == 0 (a hairline in PIL) a purely
        # vertical or horizontal line would otherwise get an empty tile and vanish.
        pad = max(w, 1)

        offset = (min(fx, tx) - pad, min(fy, ty) - pad)
        size = (abs(fx-tx) + 2*pad, abs(fy-ty) + 2*pad)

        c = style['color']

        if (fx-tx)*(fy-ty) > 0:
            coords = (pad, pad, size[0]-pad, size[1]-pad)  # top-left to bottom-right
        else:
            coords = (size[0]-pad, pad, pad, size[1]-pad)  # top-right to bottom-left
        img = Image.new('RGBA', size)
        ImageDraw.Draw(img).line(coords,
                                 (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity'])), w)
        file.paste(img, offset, mask=img)

    def add_dashed_line(self, file, from_pos, to_pos):
        """Draw a black dashed line as 50 dashes separated by equally long gaps."""
        style = {'color': (0,0,0), 'width': 1, 'opacity': 1}
        sublines = 50
        # TODO could be faster: compute delta and only add delta each time (but currently we do not use it often)
        normdiv = 2*sublines-1
        for i in range(sublines):
            from_pos_sub = (self.compute_value(from_pos[0], to_pos[0], 2*i/normdiv, 1),
                            self.compute_value(from_pos[1], to_pos[1], 2*i/normdiv, 1))
            to_pos_sub = (self.compute_value(from_pos[0], to_pos[0], (2*i+1)/normdiv, 1),
                          self.compute_value(from_pos[1], to_pos[1], (2*i+1)/normdiv, 1))
            self.add_line(file, from_pos_sub, to_pos_sub, style)

    def add_text(self, file, text, pos, anchor, style=''):
        """Paste black ``text`` at ``pos``.

        ``anchor`` "start" places the text's left edge at ``pos``; any other
        value places its right edge there. ``style`` exists for interface
        compatibility with the SVG back end and is ignored.
        """
        try:
            font = ImageFont.truetype("Vera.ttf", 16*self.multi)
        except OSError:
            # Vera.ttf not available: use PIL's built-in font rather than abort the render.
            font = ImageFont.load_default()

        img = Image.new('RGBA', (self.width, self.height))
        draw = ImageDraw.Draw(img)
        try:
            # Right edge of the text when drawn at (0, 0); textsize() was removed in Pillow 10.
            text_width = draw.textbbox((0, 0), text, font=font)[2]
        except (AttributeError, ValueError):
            # Pillow < 8 has no textbbox, and some older versions support it
            # only for TrueType fonts; textsize() still exists there.
            text_width = draw.textsize(text, font=font)[0]
        pos = pos if anchor == "start" else (pos[0]-text_width, pos[1])
        draw.text(pos, text, (0,0,0), font=font)
        file.paste(img, (0,0), mask=img)

    def compute_line_style(self, parent, child):
        """Style dict ('color', 'width', 'opacity') of the edge to ``child``; ``parent`` is unused."""
        return {'color': self.compute_property('lines', 'color', child),
                'width': self.compute_property('lines', 'width', child),
                'opacity': self.compute_property('lines', 'opacity', child)}

    def compute_dot_style(self, node):
        """Style dict ('color', 'r', 'opacity') of the dot drawn for ``node``."""
        return {'color': self.compute_property('dots', 'color', node),
                'r': self.compute_property('dots', 'size', node),
                'opacity': self.compute_property('dots', 'opacity', node)}
308
class SvgDrawer(Drawer):
    """Vector back end: writes the tree as an SVG document built from string fragments.

    Styles are passed around as ready-made SVG attribute strings
    (e.g. ``stroke="..." stroke-width="..."``).
    """

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Write the SVG rendering of the tree to ``filename``.

        :param input_filename: name shown in the caption drawn by ``draw_scale``.
        :param multi: accepted for interface compatibility with the PNG back end; unused.
        :param scale: "SIMPLE" adds the caption and the vertical-axis markers.
        """
        print("Drawing...")

        placed = [p for p in self.design.positions if 'x' in p]
        min_width = min(p['x'] for p in placed)
        max_width = max(p['x'] for p in placed)
        max_height = max(p['y'] for p in self.design.positions if 'y' in p)

        # `with` closes the file even if drawing fails. Without an XML
        # declaration, XML readers assume UTF-8, so do not depend on the
        # platform's default encoding (matters for non-ASCII file names).
        with open(filename, "w", encoding="utf-8") as file:
            file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" '
                       'xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
                       'width="' + str(self.width) + '" height="' + str(self.height) + '">')

            self.draw_lines(file, min_width, max_width, max_height)
            self.draw_dots(file, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(file, input_filename)

            file.write("</svg>")

    def add_text(self, file, text, pos, anchor, style=''):
        """Write a <text> element; ``anchor`` is used as the SVG text-anchor value.

        ``text`` is XML-escaped (file names may contain '&' or '<'); ``style``
        is trusted attribute markup and written as-is.
        """
        from xml.sax.saxutils import escape

        style = (style if style != '' else 'style="font-family: Arial; font-size: 12; fill: #000000;"')
        # assuming font size 12, it should be taken from the style string!
        file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0]) + '" y="' + str(pos[1]+12) + '" >' + escape(text) + '</text>')

    def add_dot(self, file, pos, style):
        """Write a <circle> centered at ``pos`` with the given attribute string."""
        file.write('<circle ' + style + ' cx="' + str(pos[0]) + '" cy="' + str(pos[1]) + '" />')

    def add_line(self, file, from_pos, to_pos, style):
        """Write a <line> from ``from_pos`` to ``to_pos`` with the given attribute string."""
        file.write('<line ' + style + ' x1="' + str(from_pos[0]) + '" x2="' + str(to_pos[0]) +
                       '" y1="' + str(from_pos[1]) + '" y2="' + str(to_pos[1]) + '"  fill="none"/>')

    def add_dashed_line(self, file, from_pos, to_pos):
        """Write a thin black dashed <line>."""
        style = 'stroke="black" stroke-width="0.5" stroke-opacity="1" stroke-dasharray="5, 5"'
        self.add_line(file, from_pos, to_pos, style)

    def compute_line_style(self, parent, child):
        """Attribute string for the edge to ``child``; ``parent`` is unused."""
        return self.compute_stroke_color('lines', child) + ' ' \
               + self.compute_stroke_width('lines', child) + ' ' \
               + self.compute_stroke_opacity(child)

    def compute_dot_style(self, node):
        """Attribute string (radius, fill opacity, fill color) for the dot of ``node``."""
        return self.compute_dot_size(node) + ' ' \
               + self.compute_fill_opacity(node) + ' ' \
               + self.compute_dot_fill(node)

    def compute_stroke_color(self, part, node):
        """``stroke`` attribute with the color of ``part`` for ``node`` as rgb percentages."""
        color = self.compute_property(part, 'color', node)
        return 'stroke="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'

    def compute_stroke_width(self, part, node):
        """``stroke-width`` attribute of ``part`` for ``node``."""
        return 'stroke-width="' + str(self.compute_property(part, 'width', node)) + '"'

    def compute_stroke_opacity(self, node, part='lines'):
        """``stroke-opacity`` attribute of ``part`` (default 'lines') for ``node``."""
        return 'stroke-opacity="' + str(self.compute_property(part, 'opacity', node)) + '"'

    def compute_fill_opacity(self, node):
        """``fill-opacity`` attribute of the dot of ``node``."""
        return 'fill-opacity="' + str(self.compute_property('dots', 'opacity', node)) + '"'

    def compute_dot_size(self, node):
        """``r`` (radius) attribute of the dot of ``node``."""
        return 'r="' + str(self.compute_property('dots', 'size', node)) + '"'

    def compute_dot_fill(self, node):
        """``fill`` attribute with the dot color of ``node`` as rgb percentages."""
        color = self.compute_property('dots', 'color', node)
        return 'fill="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'
376
class Designer:
    """Computes per-node measures and a 2-D layout for a genealogy tree.

    ``tree`` must provide, indexed by node id with node 0 as the root:
    ``children`` (lists of child ids), ``parents`` (dicts mapping a parent id
    to its inheritance weight), ``time``, ``kind`` and ``props`` (extra
    property name -> per-node values).

    Measures are stored in ``props``; numeric ones are normalized to [0, 1]
    with the raw range kept under ``<name>_min`` / ``<name>_max``. After
    ``calculate_node_positions`` the layout is in ``positions``.
    """

    def __init__(self, tree, jitter=False, time="GENERATIONAL", balance="DENSITY"):
        """
        :param tree: the loaded tree data (see class docstring).
        :param jitter: add small Gaussian noise to horizontal positions.
        :param time: vertical axis mode: "GENERATIONAL" (one more than the
            largest parent y), "BIRTHS" (node id) or "REAL" (``tree.time``).
        :param balance: strategy choosing between the two candidate x
            positions of a single-parent child: "RANDOM", "MIN" or "DENSITY".
        :raises ValueError: for an unknown ``balance``.
        """
        self.props = {}

        self.tree = tree

        self.TIME = time
        self.JITTER = jitter

        strategies = {
            "RANDOM": self.xmin_crowd_random,
            "MIN": self.xmin_crowd_min,
            "DENSITY": self.xmin_crowd_density,
        }
        if balance not in strategies:
            raise ValueError("Error, the value of BALANCE does not match any expected value.")
        self.xmin_crowd = strategies[balance]

    def calculate_measures(self):
        """Compute all per-node measures.

        Order matters: adepth uses depth (visiting order) and progress uses
        the already normalized time.
        """
        print("Calculating measures...")
        self.compute_depth()
        self.compute_adepth()
        self.compute_children()
        self.compute_kind()
        self.compute_time()
        self.compute_progress()
        self.compute_custom()

    def xmin_crowd_random(self, x1, x2, y):
        """Pick one of the two candidate x positions uniformly at random."""
        return x1 if random.randrange(2) == 0 else x2

    def xmin_crowd_min(self, x1, x2, y):
        """Pick the candidate whose nearest placed node (with y within +-3) is farther away.

        Ties, including the case of no nearby node at all, resolve to ``x2``.
        """
        x1_closest = 999999
        x2_closest = 999999
        maxy = y + 3
        i = bisect.bisect_left(self.y_sorted, y - 3)
        while i < len(self.positions_sorted) and self.positions_sorted[i]['y'] <= maxy:
            pos = self.positions_sorted[i]
            x1_closest = min(x1_closest, abs(x1 - pos['x']))
            x2_closest = min(x2_closest, abs(x2 - pos['x']))
            i += 1
        return x1 if x1_closest > x2_closest else x2

    def xmin_crowd_density(self, x1, x2, y):
        """Pick a candidate by comparing distance sums over nearby and distant nodes.

        Placed nodes with y within +-CONST_WINDOW_SIZE are considered: all of
        them if fewer than 10, otherwise 10 random picks (with replacement).
        Nodes horizontally close to the candidates' midpoint feed the "local"
        sums, far ones the "global" sums. Returns ``x1`` when the averaged
        local x1-vs-x2 difference exceeds the global one, else ``x2``.
        """
        # TODO experimental - requires further work to make it less 'jumpy' and more predictable
        CONST_LOCAL_AREA_RADIUS = 5
        CONST_GLOBAL_AREA_RADIUS = 10
        CONST_WINDOW_SIZE = 20000  # TODO should depend on the maxY ?
        x1_dist_loc = 0
        x2_dist_loc = 0
        count_loc = 1
        x1_dist_glob = 0
        x2_dist_glob = 0
        count_glob = 1
        i_left = bisect.bisect_left(self.y_sorted, y - CONST_WINDOW_SIZE)
        i_right = bisect.bisect_right(self.y_sorted, y + CONST_WINDOW_SIZE)
        # TODO test: maxy=y should give the same results, right?

        def include_pos(pos):
            nonlocal x1_dist_loc, x2_dist_loc, x1_dist_glob, x2_dist_glob, count_loc, count_glob

            dysq = (pos['y'] - y)**2 + 1  # +1 so 1/dysq is at most 1
            dx1 = math.fabs(pos['x'] - x1)
            dx2 = math.fabs(pos['x'] - x2)

            d = math.fabs(pos['x'] - (x1 + x2)/2)

            if d < CONST_LOCAL_AREA_RADIUS:
                x1_dist_loc += math.sqrt(dx1/dysq + dx1**2)
                x2_dist_loc += math.sqrt(dx2/dysq + dx2**2)
                count_loc += 1
            elif d > CONST_GLOBAL_AREA_RADIUS:
                x1_dist_glob += math.sqrt(dx1/dysq + dx1**2)
                x2_dist_glob += math.sqrt(dx2/dysq + dx2**2)
                count_glob += 1

        # optimized to draw from all the nodes, if less than 10 nodes in the range
        if len(self.positions_sorted) > i_left:
            if i_right - i_left < 10:
                for j in range(i_left, i_right):
                    include_pos(self.positions_sorted[j])
            else:
                for _ in range(10):
                    include_pos(self.positions_sorted[random.randrange(i_left, i_right)])

        local = (x1_dist_loc - x2_dist_loc) / count_loc
        distant = (x1_dist_glob - x2_dist_glob) / count_glob
        return x1 if local - distant > 0 else x2

    def _order_by_parent_depth(self):
        """All node ids, stably sorted by the largest (normalized) depth among their parents.

        The root and nodes without parents get key 0. NOTE(review): for nodes
        with several parents this does not strictly guarantee that every
        parent precedes its children.
        """
        depth = self.props["depth"]
        parents = self.tree.parents
        return sorted(range(len(parents)),
                      key=lambda q: 0 if q == 0 else max((depth[d] for d in parents[q]), default=0))

    def _y_position(self, node, parents):
        """Vertical coordinate of ``node`` according to ``self.TIME`` (0 for an unknown mode)."""
        if self.TIME == "BIRTHS":
            return node
        if self.TIME == "GENERATIONAL":
            # one more than its parent (what if more than one parent?)
            return max(self.positions[par]['y'] for par in parents) + 1 if parents else 0
        if self.TIME == "REAL":
            return self.tree.time[node]
        return 0

    def _weighted_parent_x(self, parents):
        """Parents' x averaged with their inheritance weights.

        Falls back to an unweighted mean when the weights sum to 0, and to 0
        for a node without parents.
        """
        total = sum(parents.values())
        if total:
            return sum(self.positions[k]['x'] * v / total for k, v in parents.items())
        if parents:
            return sum(self.positions[k]['x'] for k in parents) / len(parents)
        return 0

    def calculate_node_positions(self, ignore_last=0):
        """Assign an {'id', 'y', 'x'} position to nodes and store them in ``positions``.

        Nodes whose raw adepth is below ``ignore_last`` are left unplaced
        (empty dict). A single-parent child is put at one of two candidates
        left/right of its parent, at a distance growing with dissimilarity,
        chosen by the ``balance`` strategy; other nodes get the weighted mean
        x of their parents. ``positions_sorted`` / ``y_sorted`` keep the
        placed nodes ordered by y for the strategies' range queries.
        """
        print("Calculating positions...")

        def add_node(node):
            index = bisect.bisect_left(self.y_sorted, node['y'])
            self.y_sorted.insert(index, node['y'])
            self.positions_sorted.insert(index, node)
            self.positions[node['id']] = node

        node_count = len(self.tree.parents)
        self.positions_sorted = [{'x': 0, 'y': 0, 'id': 0}]
        self.y_sorted = [0]
        self.positions = [{} for _ in range(node_count)]
        self.positions[0] = {'x': 0, 'y': 0, 'id': 0}

        # 'adepth' is normalized, so ignore_last is scaled the same way. With
        # adepth_max == 0 (a tree without edges) every raw adepth is 0: keep
        # all nodes unless something is to be ignored.
        adepth_max = self.props['adepth_max']
        if adepth_max:
            threshold = ignore_last / adepth_max
        else:
            threshold = math.inf if ignore_last > 0 else 0

        visiting_order = self._order_by_parent_depth()

        start_time = timelib.time()

        for node_counter, child in enumerate(visiting_order, start=1):
            # debug info - elapsed time
            if node_counter % 100000 == 0:
                print("%d%%\t%d\t%g" % (node_counter*100/node_count, node_counter, timelib.time()-start_time))
                start_time = timelib.time()

            # The root is already placed; adding it again would duplicate it in positions_sorted.
            if child == 0 or self.props['adepth'][child] < threshold:
                continue

            parents = self.tree.parents[child]
            ypos = self._y_position(child, parents)

            if len(parents) == 1:
                [(parent, similarity)] = parents.items()
                dissimilarity = 1 - similarity
                if self.JITTER:
                    dissimilarity += random.gauss(0, 0.01)
                dissimilarity += 0.001
                parent_x = self.positions[parent]['x']
                add_node({'id': child, 'y': ypos,
                          'x': self.xmin_crowd(parent_x - dissimilarity, parent_x + dissimilarity, ypos)})
            else:
                # position weighted by the degree of inheritance from each parent
                xpos = self._weighted_parent_x(parents)
                if self.JITTER:
                    xpos += random.gauss(0, 0.1)
                add_node({'id': child, 'y': ypos, 'x': xpos})

    def compute_custom(self):
        """Copy every extra property of ``tree.props`` (one value per node) and normalize it."""
        node_count = len(self.tree.children)
        for prop, values in self.tree.props.items():
            self.props[prop] = [values[i] for i in range(node_count)]
            self.normalize_prop(prop)

    def compute_time(self):
        """'time': copy of ``tree.time``, normalized."""
        self.props["time"] = [self.tree.time[i] for i in range(len(self.tree.children))]
        self.normalize_prop('time')

    def compute_kind(self):
        """'kind': ``tree.kind`` as strings (not normalized; used as palette indices)."""
        self.props["kind"] = [str(self.tree.kind[i]) for i in range(len(self.tree.children))]

    def compute_depth(self):
        """'depth': breadth-first distance from the root; unreachable nodes keep 999999999."""
        node_count = len(self.tree.children)
        depth = [999999999] * node_count
        visited = [False] * node_count
        depth[0] = 0
        visited[0] = True

        queue = [0]
        head = 0  # next node to expand; avoids the O(n) re-slicing of the queue
        while head < len(queue):
            current = queue[head]
            head += 1
            for child in self.tree.children[current]:
                if not visited[child]:
                    visited[child] = True
                    depth[child] = depth[current] + 1
                    queue.append(child)

        self.props["depth"] = depth
        self.normalize_prop('depth')

    def compute_adepth(self):
        """'adepth': 0 for leaves, otherwise 1 + the largest adepth among the children.

        Nodes are visited in reverse parent-depth order so that children are
        (normally) computed before their parents, making this the longest
        path down to a leaf.
        """
        adepth = [0] * len(self.tree.children)
        for node in self._order_by_parent_depth()[::-1]:
            children = self.tree.children[node]
            if children:
                adepth[node] = max(adepth[child] for child in children) + 1
        self.props["adepth"] = adepth
        self.normalize_prop('adepth')

    def compute_children(self):
        """'children': number of children of each node, normalized."""
        self.props["children"] = [len(self.tree.children[i]) for i in range(len(self.tree.children))]
        self.normalize_prop('children')

    def compute_progress(self):
        """'progress': trend of the gaps between the (normalized) times of a node's children.

        For nodes with more than 4 children it is the slope of a linear
        regression over consecutive time gaps (0 if not finite); 0 otherwise.
        Then, five times, the current minimum and maximum entries are set to 0
        (crude outlier removal), and finally every 0 is replaced by the minimum.
        """
        node_times = self.props["time"]
        progress = [0] * len(self.tree.children)
        for node in range(len(self.props['children'])):
            times = sorted(node_times[c] * 100000 for c in self.tree.children[node])
            if len(times) > 4:
                gaps = [later - earlier for earlier, later in zip(times, times[1:])]
                slope = stats.linregress(range(len(gaps)), gaps)[0]
                progress[node] = slope if np.isfinite(slope) else 0

        for _ in range(5):
            progress[progress.index(min(progress))] = 0
            progress[progress.index(max(progress))] = 0

        lowest = min(progress)
        self.props['progress'] = [lowest if v == 0 else v for v in progress]

        self.normalize_prop('progress')

    def normalize_prop(self, prop):
        """Rescale the numeric values of ``props[prop]`` in place to [0, 1].

        None, strings, lists and dicts are left untouched and ignored when
        computing the range. The raw range is stored in ``<prop>_min`` /
        ``<prop>_max``; a constant property becomes all 0. Nothing happens
        when there are no numeric values.
        """
        def is_number(v):
            return v is not None and not isinstance(v, (str, list, dict))

        values = self.props[prop]
        numbers = [v for v in values if is_number(v)]
        if not numbers:
            return
        max_val = max(numbers)
        min_val = min(numbers)
        print("%s: [%g, %g]" % (prop, min_val, max_val))
        self.props[prop + '_max'] = max_val
        self.props[prop + '_min'] = min_val
        for i, v in enumerate(values):
            if is_number(v):
                values[i] = 0 if max_val == min_val else (v - min_val) / (max_val - min_val)
647
648class TreeData:
649    simple_data = None
650
651    children = []
652    parents = []
653    time = []
654    kind = []
655
656    def __init__(self): #, simple_data=False):
657        #self.simple_data = simple_data
658        pass
659
660    def load(self, filename, max_nodes=0):
661        print("Loading...")
662
663        CLI_PREFIX = "Script.Message:"
664        default_props = ["Time", "FromIDs", "ID", "Operation", "Inherited"]
665
666        self.ids = {}
667        def get_id(id, createOnError = True):
668            if createOnError:
669                if id not in self.ids:
670                    self.ids[id] = len(self.ids)
671            else:
672                if id not in self.ids:
673                    return None
674            return self.ids[id]
675
676        file = open(filename)
677
678        # counting the number of expected nodes
679        nodes = 0
680        for line in file:
681            line_arr = line.split(' ', 1)
682            if len(line_arr) == 2:
683                if line_arr[0] == CLI_PREFIX:
684                    line_arr = line_arr[1].split(' ', 1)
685                if line_arr[0] == "[OFFSPRING]":
686                    nodes += 1
687
688        nodes = min(nodes, max_nodes if max_nodes != 0 else nodes)+1
689        self.parents = [{} for x in range(nodes)]
690        self.children = [[] for x in range(nodes)]
691        self.time = [0] * nodes
692        self.kind = [0] * nodes
693        self.life_lenght = [0] * nodes
694        self.props = {}
695
696        print("nodes: %d" % len(self.parents))
697
698        file.seek(0)
699        loaded_so_far = 0
700        lasttime = timelib.time()
701        for line in file:
702            line_arr = line.split(' ', 1)
703            if len(line_arr) == 2:
704                if line_arr[0] == CLI_PREFIX:
705                    line_arr = line_arr[1].split(' ', 1)
706                if line_arr[0] == "[OFFSPRING]":
707                    try:
708                        creature = json.loads(line_arr[1])
709                    except ValueError:
710                        print("Json format error - the line cannot be read. Breaking the loading loop.")
711                        # fixing arrays by removing the last element
712                        # ! assuming that only the last line is broken !
713                        self.parents.pop()
714                        self.children.pop()
715                        self.time.pop()
716                        self.kind.pop()
717                        self.life_lenght.pop()
718                        nodes -= 1
719                        break
720
721                    if "FromIDs" in creature:
722
723                        # make sure that ID's of parents are lower than that of their children
724                        for i in range(0, len(creature["FromIDs"])):
725                            if creature["FromIDs"][i] not in self.ids:
726                                get_id("virtual_parent")
727
728                        creature_id = get_id(creature["ID"])
729
730                        # debug
731                        if loaded_so_far%1000 == 0:
732                            #print(". " + str(creature_id) + " " + str(timelib.time() - lasttime))
733                            lasttime = timelib.time()
734
735                        # we assign to each parent its contribution to the genotype of the child
736                        for i in range(0, len(creature["FromIDs"])):
737                            if creature["FromIDs"][i] in self.ids:
738                                parent_id = get_id(creature["FromIDs"][i])
739                            else:
740                                parent_id = get_id("virtual_parent")
741                            inherited = (creature["Inherited"][i] if 'Inherited' in creature else 1)
742                            self.parents[creature_id][parent_id] = inherited
743
744                        if "Time" in creature:
745                            self.time[creature_id] = creature["Time"]
746
747                        if "Kind" in creature:
748                            self.kind[creature_id] = creature["Kind"]
749
750                        for prop in creature:
751                            if prop not in default_props:
752                                if prop not in self.props:
753                                    self.props[prop] = [0 for i in range(nodes)]
754                                self.props[prop][creature_id] = creature[prop]
755
756                        loaded_so_far += 1
757                    else:
758                        raise LoadingError("[OFFSPRING] misses the 'FromIDs' field!")
759                if line_arr[0] == "[DIED]":
760                    creature = json.loads(line_arr[1])
761                    creature_id = get_id(creature["ID"], False)
762                    if creature_id is not None:
763                        for prop in creature:
764                            if prop not in default_props:
765                                if prop not in self.props:
766                                    self.props[prop] = [0 for i in range(nodes)]
767                                self.props[prop][creature_id] = creature[prop]
768
769
770            if loaded_so_far >= max_nodes and max_nodes != 0:
771                break
772
773        for k in range(len(self.parents)):
774            v = self.parents[k]
775            for val in self.parents[k]:
776                self.children[val].append(k)
777
# Two independent module-level dictionaries, created empty at import time.
# NOTE(review): nothing in this part of the file reads them -- confirm they
# are still used elsewhere before relying on (or removing) them.
depth, kind = {}, {}
780
def main():
    """Command-line entry point: load genealogy data, lay out the tree and draw it.

    Parses the command-line arguments and validates the enumerated options
    (--time, --balance, --scale). It then seeds the ``random`` module and loads
    the input file into a ``TreeData``. Node positions are computed with a
    ``Designer``. The result is rendered with ``SvgDrawer`` when the output name
    ends in ``.svg`` (case-insensitive), otherwise with ``PngDrawer``.

    Returns None. If an enumerated option has an invalid value, a message
    naming each offending option is printed, and the function returns before
    loading or drawing anything.
    """
    parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates a SVG file) based on parent-child relationship '
                                                 'information from a text file. Supports files generated by Framsticks experiments.')
    parser.add_argument('-i', '--in', dest='input', required=True, help='input file name with structured evolutionary data')
    parser.add_argument('-o', '--out', dest='output', required=True, help='output file name for the evolutionary tree (SVG/PNG/JPG/BMP)')
    parser.add_argument('-c', '--config', dest='config', default="", help='config file name ')

    parser.add_argument('-W', '--width', default=600, type=int, dest='width', help='width of the output image (600 by default)')
    parser.add_argument('-H', '--height', default=800, type=int, dest='height', help='height of the output image (800 by default)')
    parser.add_argument('-m', '--multi', default=1, type=int, dest='multi', help='multisampling factor (applicable only for raster images)')

    parser.add_argument('-t', '--time', default='GENERATIONAL', dest='time', help='values on vertical axis (BIRTHS/GENERATIONAL(d)/REAL); '
                                                                      'BIRTHS: time measured as the number of births since the beginning; '
                                                                      'GENERATIONAL: time measured as number of ancestors; '
                                                                      'REAL: real time of the simulation')
    parser.add_argument('-b', '--balance', default='DENSITY', dest='balance', help='method of placing nodes in the tree (RANDOM/MIN/DENSITY(d))')
    # The "(d)" marker must point at the real default, which is SIMPLE.
    parser.add_argument('-s', '--scale', default='SIMPLE', dest='scale', help='type of timescale added to the tree (NONE/SIMPLE(d))')
    parser.add_argument('-j', '--jitter', dest="jitter", action='store_true', help='draw horizontal positions of children from the normal distribution')
    parser.add_argument('-p', '--skip', dest="skip", type=int, default=0, help='skip last P levels of the tree (0 by default)')
    parser.add_argument('-x', '--max-nodes', type=int, default=0, dest='max_nodes', help='maximum number of nodes drawn (starting from the first one)')
    parser.add_argument('--seed', type=int, dest='seed', help='seed for the random number generator (-1 for random)')

    parser.set_defaults(draw_tree=True)
    parser.set_defaults(draw_skeleton=False)
    parser.set_defaults(draw_spine=False)

    parser.set_defaults(seed=-1)

    args = parser.parse_args()

    TIME = args.time.upper()
    BALANCE = args.balance.upper()
    SCALE = args.scale.upper()
    JITTER = args.jitter

    # Validate every enumerated option and report exactly which ones are wrong,
    # instead of leaving the user to guess.
    option_choices = (
        ('time', TIME, ['BIRTHS', 'GENERATIONAL', 'REAL']),
        ('balance', BALANCE, ['RANDOM', 'MIN', 'DENSITY']),
        ('scale', SCALE, ['NONE', 'SIMPLE']),
    )
    invalid = False
    for option_name, value, allowed in option_choices:
        if value not in allowed:
            print("Incorrect value of the %s parameter: '%s' (allowed: %s)."
                  % (option_name, value, "/".join(allowed)))
            invalid = True
    if invalid:
        return

    input_file = args.input  # avoid shadowing the builtin dir()
    seed = args.seed
    if seed == -1:
        seed = random.randint(0, 10000)
    random.seed(seed)
    print("randomseed:", seed)

    tree = TreeData()
    tree.load(input_file, max_nodes=args.max_nodes)

    designer = Designer(tree, jitter=JITTER, time=TIME, balance=BALANCE)
    designer.calculate_measures()
    designer.calculate_node_positions(ignore_last=args.skip)

    # Pick the renderer by extension case-insensitively, so "tree.SVG" is also
    # written by the SVG drawer rather than handed to the raster drawer.
    if args.output.lower().endswith(".svg"):
        drawer_class = SvgDrawer
    else:
        drawer_class = PngDrawer
    drawer = drawer_class(designer, args.config, w=args.width, h=args.height)
    drawer.draw_design(args.output, args.input, multi=args.multi, scale=SCALE)
843
# Run the CLI only when executed as a script, so importing this module
# (e.g. to reuse TreeData or the drawers) does not parse sys.argv and draw.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.