source: mds-and-trees/tree-genealogy.py @ 690

Last change on this file since 690 was 690, checked in by konrad, 7 years ago

Better (but not ideal) behavior for very big trees

File size: 34.7 KB
Line 
1import json
2import math
3import random
4import argparse
5import bisect
6import time as timelib
7from PIL import Image, ImageDraw, ImageFont
8from scipy import stats
9import numpy as np
10
class LoadingError(Exception):
    """Raised by TreeData.load when an [OFFSPRING] record lacks the 'FromIDs' field."""

class Drawer:
    """Common base for renderers of a Designer's node layout.

    Holds the canvas geometry, a colour palette (channels on a 0-100 scale) and
    the style settings, optionally overridden by a JSON config file.  Subclasses
    provide the output primitives: add_dot, add_line, add_dashed_line,
    add_text, compute_dot_style and compute_line_style.
    """

    def __init__(self, design, config_file, w=600, h=800, w_margin=10, h_margin=20):
        """
        Args:
            design: the Designer whose positions/props/tree/TIME are drawn.
            config_file: path of a JSON file whose entries are merged
                (recursively) over the default settings; "" for none.
            w, h: canvas width and height.
            w_margin, h_margin: horizontal and vertical margins.
        """
        self.design = design
        self.width = w
        self.height = h
        self.w_margin = w_margin
        self.h_margin = h_margin
        self.w_no_margs = w - 2* w_margin
        self.h_no_margs = h - 2* h_margin

        self.colors = {
            'white' :   {'r':100,   'g':100,    'b':100},
            'black' :   {'r':0,     'g':0,      'b':0},
            'red' :     {'r':100,   'g':0,      'b':0},
            'green' :   {'r':0,     'g':100,    'b':0},
            'blue' :    {'r':0,     'g':0,      'b':100},
            'yellow' :  {'r':100,   'g':100,    'b':0},
            'magenta' : {'r':100,   'g':0,      'b':100},
            'cyan' :    {'r':0,     'g':100,    'b':100},
            'orange':   {'r':100,   'g':50,     'b':0},
            'purple':   {'r':50,    'g':0,      'b':100}
        }

        self.settings = {
            'colors_of_kinds': ['red', 'green', 'blue', 'magenta', 'yellow', 'cyan', 'orange', 'purple'],
            'dots': {
                'color': {
                    'meaning': 'Lifespan',
                    'start': 'red',
                    'end': 'green',
                    'bias': 1
                    },
                'size': {
                    'meaning': 'EnergyEaten',
                    'start': 1,
                    'end': 6,
                    'bias': 0.5
                    },
                'opacity': {
                    'meaning': 'EnergyEaten',
                    'start': 0.2,
                    'end': 1,
                    'bias': 1
                    }
            },
            'lines': {
                'color': {
                    'meaning': 'adepth',
                    'start': 'black',
                    'end': 'red',
                    'bias': 3
                    },
                'width': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 4,
                    'bias': 3
                    },
                'opacity': {
                    'meaning': 'adepth',
                    'start': 0.1,
                    'end': 0.8,
                    'bias': 5
                    }
            }
        }

        def merge(source, destination):
            """Recursively copy `source` into `destination`, merging nested dicts."""
            for key, value in source.items():
                if isinstance(value, dict):
                    node = destination.setdefault(key, {})
                    merge(value, node)
                else:
                    destination[key] = value

            return destination

        if config_file != "":
            with open(config_file) as config:
                c = json.load(config)
            self.settings = merge(c, self.settings)

    def _canvas_point(self, pos, min_width, max_width, max_height):
        """Map a layout position {'x', 'y'} to canvas coordinates.

        A degenerate extent (all nodes sharing one x, or max_height == 0, e.g.
        a single-node tree) used to raise ZeroDivisionError; such nodes are now
        centred horizontally / placed at the top margin instead.
        """
        x_span = max_width - min_width
        if x_span:
            x = self.w_margin + self.w_no_margs * (pos['x'] - min_width) / x_span
        else:
            x = self.w_margin + self.w_no_margs / 2
        if max_height:
            y = self.h_margin + self.h_no_margs * pos['y'] / max_height
        else:
            y = self.h_margin
        return (x, y)

    def draw_dots(self, file, min_width, max_width, max_height):
        """Draw one dot per positioned node (nodes without 'x' are skipped)."""
        for i, node in enumerate(self.design.positions):
            if 'x' not in node:
                continue
            dot_style = self.compute_dot_style(node=i)
            self.add_dot(file, self._canvas_point(node, min_width, max_width, max_height), dot_style)

    def draw_lines(self, file, min_width, max_width, max_height):
        """Draw a line for every parent->child edge whose endpoints are both positioned."""
        for parent, par_pos in enumerate(self.design.positions):
            if 'x' not in par_pos:
                continue
            start = self._canvas_point(par_pos, min_width, max_width, max_height)
            for child in self.design.tree.children[parent]:
                chi_pos = self.design.positions[child]
                if 'x' not in chi_pos:
                    continue
                line_style = self.compute_line_style(parent, child)
                self.add_line(file, start,
                              self._canvas_point(chi_pos, min_width, max_width, max_height),
                              line_style)

    def draw_scale(self, file, filename):
        """Write the source file name and a start/end label for the time axis."""
        # accept both Windows and POSIX separators when showing the bare file name
        shown_name = filename.replace("\\", "/").split("/")[-1]
        self.add_text(file, "Generated from " + shown_name, (5, 5), "start")

        start_text = ""
        end_text = ""
        if self.design.TIME == "BIRTHS":
            start_text = "Birth #0"
            end_text = "Birth #" + str(len(self.design.positions)-1)
        elif self.design.TIME == "REAL":
            start_text = "Time " + str(min(self.design.tree.time))
            end_text = "Time " + str(max(self.design.tree.time))
        elif self.design.TIME == "GENERATIONAL":
            start_text = "Depth " + str(self.design.props['adepth_min'])
            end_text = "Depth " + str(self.design.props['adepth_max'])

        self.add_dashed_line(file, (self.width*0.7, self.h_margin), (self.width, self.h_margin))
        self.add_text(file, start_text, (self.width, self.h_margin), "end")
        self.add_dashed_line(file, (self.width*0.7, self.height-self.h_margin), (self.width, self.height-self.h_margin))
        self.add_text(file, end_text, (self.width, self.height-self.h_margin), "end")

    def compute_property(self, part, prop, node):
        """Interpolate the style `prop` of `part` ('dots'/'lines') for `node`.

        The node's value of the configured 'meaning' property (0 if that
        property does not exist) drives the interpolation between 'start'
        and 'end' with the configured 'bias'.
        """
        spec = self.settings[part][prop]
        meaning = spec['meaning']
        value = self.design.props[meaning][node] if meaning in self.design.props else 0
        if prop == "color":
            return self.compute_color(spec['start'], spec['end'], value, spec['bias'])
        return self.compute_value(spec['start'], spec['end'], value, spec['bias'])

    def compute_color(self, start, end, value, bias=1):
        """Return an (r, g, b) tuple on the palette's 0-100 scale.

        A string value is a categorical index (e.g. a creature kind) into
        settings['colors_of_kinds']; indices beyond the palette wrap around
        instead of raising IndexError.  A numeric value is blended between the
        `start` and `end` palette colours.
        """
        if isinstance(value, str):
            kinds = self.settings['colors_of_kinds']
            color = self.colors[kinds[int(value) % len(kinds)]]
            return (color['r'], color['g'], color['b'])
        start_color = self.colors[start]
        end_color = self.colors[end]
        value = 1 - (1-value)**bias
        r = start_color['r']*(1-value)+end_color['r']*value
        g = start_color['g']*(1-value)+end_color['g']*value
        b = start_color['b']*(1-value)+end_color['b']*value
        return (r, g, b)

    def compute_value(self, start, end, value, bias=1):
        """Interpolate between start and end; value is first warped by 1-(1-value)**bias."""
        value = 1 - (1-value)**bias
        return start*(1-value) + end*value

class PngDrawer(Drawer):
    """Raster renderer built on Pillow.

    The picture is drawn on a canvas enlarged `multi` times and shrunk back to
    the requested size before saving (supersampling).
    """

    def scale_up(self):
        """Multiply all canvas geometry by self.multi."""
        self.width *= self.multi
        self.height *= self.multi
        self.w_margin *= self.multi
        self.h_margin *= self.multi
        self.h_no_margs *= self.multi
        self.w_no_margs *= self.multi

    def scale_down(self):
        """Divide all canvas geometry by self.multi (inverse of scale_up)."""
        self.width /= self.multi
        self.height /= self.multi
        self.w_margin /= self.multi
        self.h_margin /= self.multi
        self.h_no_margs /= self.multi
        self.w_no_margs /= self.multi

    @staticmethod
    def _resampling_filter():
        """Return Pillow's Lanczos filter (Image.ANTIALIAS was removed in Pillow 10)."""
        lanczos = getattr(Image, "LANCZOS", None)
        return lanczos if lanczos is not None else Image.ANTIALIAS

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Render the design and save it to `filename`.

        Args:
            filename: output path; Pillow picks the format from the extension.
            input_filename: shown in the caption drawn by draw_scale.
            multi: supersampling factor.
            scale: "SIMPLE" draws the caption and time-axis labels.
        """
        print("Drawing...")

        self.multi = multi
        geometry = (self.width, self.height, self.w_margin, self.h_margin,
                    self.w_no_margs, self.h_no_margs)
        self.scale_up()
        try:
            back = Image.new('RGBA', (int(self.width), int(self.height)), (255, 255, 255, 0))

            min_width = min([x['x'] for x in self.design.positions if 'x' in x])
            max_width = max([x['x'] for x in self.design.positions if 'x' in x])
            max_height = max([x['y'] for x in self.design.positions if 'y' in x])

            self.draw_lines(back, min_width, max_width, max_height)
            self.draw_dots(back, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(back, input_filename)
        finally:
            # Restore the exact original geometry even if drawing fails;
            # scale_down() divides and would leave floats behind, while Pillow
            # expects integer image sizes.
            (self.width, self.height, self.w_margin, self.h_margin,
             self.w_no_margs, self.h_no_margs) = geometry

        back.thumbnail((int(self.width), int(self.height)), self._resampling_filter())

        back.save(filename)

    def add_dot(self, file, pos, style):
        """Paste a filled circle of radius style['r'] (scaled by multi) centred at pos."""
        x, y = int(pos[0]), int(pos[1])
        r = style['r']*self.multi
        offset = (int(x - r), int(y - r))
        # at least 2x2 px: a radius below 1 would give an inverted ellipse box
        diameter = 2*max(int(r), 1)
        size = (diameter, diameter)

        c = style['color']

        img = Image.new('RGBA', size)
        ImageDraw.Draw(img).ellipse((1, 1, size[0]-1, size[1]-1),
                                    (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity'])))
        file.paste(img, offset, mask=img)

    def add_line(self, file, from_pos, to_pos, style):
        """Paste a straight segment from from_pos to to_pos in the given style."""
        fx, fy, tx, ty = int(from_pos[0]), int(from_pos[1]), int(to_pos[0]), int(to_pos[1])
        # scale first, then truncate: int(width)*multi turned widths < 1 into 0,
        # which gave vertical/horizontal segments a zero-size canvas (invisible)
        w = max(1, int(style['width']*self.multi))

        offset = (min(fx, tx) - w, min(fy, ty) - w)
        size = (abs(fx-tx)+2*w, abs(fy-ty)+2*w)

        c = style['color']

        img = Image.new('RGBA', size)
        # pick the diagonal of the bounding box that matches the segment's direction
        ImageDraw.Draw(img).line((w, w, size[0]-w, size[1]-w) if (fx-tx)*(fy-ty)>0 else (size[0]-w, w, w, size[1]-w),
                                  (int(2.55*c[0]), int(2.55*c[1]), int(2.55*c[2]), int(255*style['opacity'])), w)
        file.paste(img, offset, mask=img)

    def add_dashed_line(self, file, from_pos, to_pos):
        """Approximate a dashed line with 50 short black segments."""
        style = {'color': (0,0,0), 'width': 1, 'opacity': 1}
        sublines = 50
        # TODO could be faster: compute delta and only add delta each time (but currently we do not use it often)
        normdiv = 2*sublines-1
        for i in range(sublines):
            from_pos_sub = (self.compute_value(from_pos[0], to_pos[0], 2*i/normdiv, 1),
                            self.compute_value(from_pos[1], to_pos[1], 2*i/normdiv, 1))
            to_pos_sub = (self.compute_value(from_pos[0], to_pos[0], (2*i+1)/normdiv, 1),
                          self.compute_value(from_pos[1], to_pos[1], (2*i+1)/normdiv, 1))
            self.add_line(file, from_pos_sub, to_pos_sub, style)

    def add_text(self, file, text, pos, anchor, style=''):
        """Draw black text at pos; anchor "start" left-aligns, anything else right-aligns."""
        font = ImageFont.truetype("Vera.ttf", int(16*self.multi))

        img = Image.new('RGBA', (int(self.width), int(self.height)))
        draw = ImageDraw.Draw(img)
        # ImageDraw.textsize was removed in Pillow 10; textbbox is its replacement
        if hasattr(draw, "textbbox"):
            text_width = draw.textbbox((0, 0), text, font=font)[2]
        else:
            text_width = draw.textsize(text, font=font)[0]
        pos = pos if anchor == "start" else (pos[0]-text_width, pos[1])
        draw.text(pos, text, (0,0,0), font=font)
        file.paste(img, (0,0), mask=img)

    def compute_line_style(self, parent, child):
        """Style dict (color/width/opacity) of the edge, driven by the child's props."""
        return {'color': self.compute_property('lines', 'color', child),
                'width': self.compute_property('lines', 'width', child),
                'opacity': self.compute_property('lines', 'opacity', child)}

    def compute_dot_style(self, node):
        """Style dict (color/r/opacity) of the node's dot."""
        return {'color': self.compute_property('dots', 'color', node),
                'r': self.compute_property('dots', 'size', node),
                'opacity': self.compute_property('dots', 'opacity', node)}

class SvgDrawer(Drawer):
    """Vector renderer writing the design as a standalone SVG document."""

    def draw_design(self, filename, input_filename, multi=1, scale="SIMPLE"):
        """Write the SVG document to `filename`.

        Args:
            filename: output path.
            input_filename: shown in the caption drawn by draw_scale.
            multi: accepted for interface parity with PngDrawer; unused here.
            scale: "SIMPLE" draws the caption and time-axis labels.
        """
        print("Drawing...")

        min_width = min([x['x'] for x in self.design.positions if 'x' in x])
        max_width = max([x['x'] for x in self.design.positions if 'x' in x])
        max_height = max([x['y'] for x in self.design.positions if 'y' in x])

        # `with` guarantees the file is closed even if drawing raises
        with open(filename, "w") as file:
            file.write('<svg xmlns:svg="http://www.w3.org/2000/svg" xmlns="http://www.w3.org/2000/svg" '
                       'xmlns:xlink="http://www.w3.org/1999/xlink" version="1.0" '
                       'width="' + str(self.width) + '" height="' + str(self.height) + '">')

            self.draw_lines(file, min_width, max_width, max_height)
            self.draw_dots(file, min_width, max_width, max_height)

            if scale == "SIMPLE":
                self.draw_scale(file, input_filename)

            file.write("</svg>")

    def add_text(self, file, text, pos, anchor, style=''):
        """Write a <text> element; the text is XML-escaped (it may contain a file name)."""
        from xml.sax.saxutils import escape

        style = (style if style != '' else 'style="font-family: Arial; font-size: 12; fill: #000000;"')
        # assuming font size 12, it should be taken from the style string!
        file.write('<text ' + style + ' text-anchor="' + anchor + '" x="' + str(pos[0]) + '" y="' + str(pos[1]+12) + '" >' + escape(text) + '</text>')

    def add_dot(self, file, pos, style):
        """Write a <circle> centred at pos; `style` is a string of SVG attributes."""
        file.write('<circle ' + style + ' cx="' + str(pos[0]) + '" cy="' + str(pos[1]) + '" />')

    def add_line(self, file, from_pos, to_pos, style):
        """Write a <line> element; `style` is a string of SVG attributes."""
        file.write('<line ' + style + ' x1="' + str(from_pos[0]) + '" x2="' + str(to_pos[0]) +
                       '" y1="' + str(from_pos[1]) + '" y2="' + str(to_pos[1]) + '"  fill="none"/>')

    def add_dashed_line(self, file, from_pos, to_pos):
        """Write a thin black dashed <line>."""
        style = 'stroke="black" stroke-width="0.5" stroke-opacity="1" stroke-dasharray="5, 5"'
        self.add_line(file, from_pos, to_pos, style)

    def compute_line_style(self, parent, child):
        """SVG stroke attributes of the edge, driven by the child's props."""
        return self.compute_stroke_color('lines', child) + ' ' \
               + self.compute_stroke_width('lines', child) + ' ' \
               + self.compute_stroke_opacity(child)

    def compute_dot_style(self, node):
        """SVG radius/fill attributes of the node's dot."""
        return self.compute_dot_size(node) + ' ' \
               + self.compute_fill_opacity(node) + ' ' \
               + self.compute_dot_fill(node)

    def compute_stroke_color(self, part, node):
        color = self.compute_property(part, 'color', node)
        return 'stroke="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'

    def compute_stroke_width(self, part, node):
        return 'stroke-width="' + str(self.compute_property(part, 'width', node)) + '"'

    def compute_stroke_opacity(self, node):
        return 'stroke-opacity="' + str(self.compute_property('lines', 'opacity', node)) + '"'

    def compute_fill_opacity(self, node):
        return 'fill-opacity="' + str(self.compute_property('dots', 'opacity', node)) + '"'

    def compute_dot_size(self, node):
        return 'r="' + str(self.compute_property('dots', 'size', node)) + '"'

    def compute_dot_fill(self, node):
        color = self.compute_property('dots', 'color', node)
        return 'fill="rgb(' + str(color[0]) + '%,' + str(color[1]) + '%,' + str(color[2]) + '%)"'

class Designer:
    """Computes per-node measures and 2-D layout positions for a loaded tree.

    After calculate_measures(), self.props maps measure names to per-node
    lists (numeric ones normalized to [0, 1], with raw bounds stored under
    '<name>_min' / '<name>_max').  After calculate_node_positions(),
    self.positions holds an {'id', 'x', 'y'} dict for each placed node and an
    empty dict for nodes that were skipped.
    """

    def __init__(self, tree, jitter=False, time="GENERATIONAL", balance="DENSITY"):
        """
        Args:
            tree: a loaded TreeData.
            jitter: add Gaussian noise to x positions.
            time: y-axis meaning: "BIRTHS", "GENERATIONAL" or "REAL".
            balance: side-choosing strategy: "RANDOM", "MIN" or "DENSITY".

        Raises:
            ValueError: unknown `balance`.
        """
        self.props = {}

        self.tree = tree

        self.TIME = time
        self.JITTER = jitter

        if balance == "RANDOM":
            self.xmin_crowd = self.xmin_crowd_random
        elif balance == "MIN":
            self.xmin_crowd = self.xmin_crowd_min
        elif balance == "DENSITY":
            self.xmin_crowd = self.xmin_crowd_density
        else:
            raise ValueError("Error, the value of BALANCE does not match any expected value.")

    def calculate_measures(self):
        """Compute all per-node measures (order matters: later ones use earlier ones)."""
        print("Calculating measures...")
        self.compute_depth()
        self.compute_adepth()
        self.compute_children()
        self.compute_kind()
        self.compute_time()
        self.compute_progress()
        self.compute_custom()

    def xmin_crowd_random(self, x1, x2, y):
        """Pick x1 or x2 with equal probability."""
        return (x1 if random.randrange(2) == 0 else x2)

    def xmin_crowd_min(self, x1, x2, y):
        """Pick the candidate farther from its nearest neighbour within y +/- 3."""
        x1_closest = 999999
        x2_closest = 999999
        miny = y-3
        maxy = y+3
        i = bisect.bisect_left(self.y_sorted, miny)
        while True:
            if len(self.positions_sorted) <= i or self.positions_sorted[i]['y'] > maxy:
                break
            pos = self.positions_sorted[i]

            x1_closest = min(x1_closest, abs(x1-pos['x']))
            x2_closest = min(x2_closest, abs(x2-pos['x']))

            i += 1
        return (x1 if x1_closest > x2_closest else x2)

    def xmin_crowd_density(self, x1, x2, y):
        """Pick x1 or x2 by comparing distances to nearby and distant sampled nodes."""
        # TODO experimental - requires further work to make it less 'jumpy' and more predictable
        x1_dist_loc = 0
        x2_dist_loc = 0
        count_loc = 1
        x1_dist_glob = 0
        x2_dist_glob = 0
        count_glob = 1
        miny = y-2000
        maxy = y+2000
        i_left = bisect.bisect_left(self.y_sorted, miny)
        i_right = bisect.bisect_right(self.y_sorted, maxy)

        def include_pos(pos):
            nonlocal x1_dist_loc, x2_dist_loc, x1_dist_glob, x2_dist_glob, count_loc, count_glob
            dysq = (pos['y']-y)**2
            dx1 = pos['x']-x1
            dx2 = pos['x']-x2

            d = math.fabs(pos['x'] - (x1+x2)/2)

            if d < 10:
                x1_dist_loc += math.sqrt(dysq + dx1**2)
                x2_dist_loc += math.sqrt(dysq + dx2**2)
                count_loc += 1
            elif d > 20:
                x1_dist_glob += math.sqrt(dysq + dx1**2)
                x2_dist_glob += math.sqrt(dysq + dx2**2)
                count_glob += 1

        # use every node in the y-window if there are fewer than 10, else 10 random ones
        if len(self.positions_sorted) > i_left:
            if i_right - i_left < 10:
                for j in range(i_left, i_right):
                    include_pos(self.positions_sorted[j])
            else:
                for j in range(10):
                    pos = self.positions_sorted[random.randrange(i_left, i_right)]
                    include_pos(pos)

        return (x1 if (x1_dist_loc-x2_dist_loc)/count_loc-(x1_dist_glob-x2_dist_glob)/count_glob > 0  else x2)

    def _parent_depth_order(self):
        """Return all node ids ordered so that every parent precedes its children.

        Nodes are keyed by the largest generation (longest path from a
        parentless node) among their parents; node 0 gets key 0.  A child's key
        is >= each parent's generation, which is > that parent's key, so the
        order is topological even for multi-parent nodes -- sorting by the BFS
        depth of the parents did not guarantee that.  When every node descends
        from node 0 through single parents, generation equals BFS depth and the
        order is the same as before.
        """
        from collections import deque

        parents = self.tree.parents
        children = self.tree.children
        count = len(parents)
        generation = [0] * count
        unresolved = [len(p) for p in parents]
        queue = deque(i for i in range(count) if unresolved[i] == 0)
        while queue:
            node = queue.popleft()
            for child in children[node]:
                generation[child] = max(generation[child], generation[node] + 1)
                unresolved[child] -= 1
                if unresolved[child] == 0:
                    queue.append(child)

        def key(q):
            if q == 0:
                return 0
            # default=0: parentless nodes other than 0 (e.g. unused slots)
            return max((generation[p] for p in parents[q]), default=0)

        return sorted(range(count), key=key)

    def calculate_node_positions(self, ignore_last=0):
        """Place nodes in 2-D; requires calculate_measures() to have run.

        Args:
            ignore_last: nodes whose normalized adepth is below
                ignore_last / adepth_max are not placed.
        """
        print("Calculating positions...")

        def add_node(node):
            index = bisect.bisect_left(self.y_sorted, node['y'])
            self.y_sorted.insert(index, node['y'])
            self.positions_sorted.insert(index, node)
            self.positions[node['id']] = node

        self.positions_sorted = [{'x':0, 'y':0, 'id':0}]
        self.y_sorted = [0]
        self.positions = [{} for _ in range(len(self.tree.parents))]
        self.positions[0] = {'x':0, 'y':0, 'id':0}

        visiting_order = self._parent_depth_order()

        # adepth_max == 0 (root-only tree) used to raise ZeroDivisionError
        adepth_max = self.props['adepth_max']
        threshold = ignore_last/adepth_max if adepth_max else 0

        node_count = len(self.tree.parents)
        start_time = timelib.time()

        for node_counter, child in enumerate(visiting_order, start=1):
            # debug info - elapsed time
            if node_counter % 100000 == 0:
                print("%d%%\t%d\t%g" % (node_counter*100/node_count, node_counter, timelib.time()-start_time))
                start_time = timelib.time()

            # using normalized adepth
            if self.props['adepth'][child] < threshold:
                continue

            parents = self.tree.parents[child]

            ypos = 0
            if self.TIME == "BIRTHS":
                ypos = child
            elif self.TIME == "GENERATIONAL":
                # one more than its lowest-placed parent
                ypos = max(self.positions[par]['y'] for par in parents)+1 if parents else 0
            elif self.TIME == "REAL":
                ypos = self.tree.time[child]

            if len(parents) == 1:
                (parent, similarity), = parents.items()

                if self.JITTER:
                    dissimilarity = (1-similarity) + random.gauss(0, 0.01) + 0.001
                else:
                    dissimilarity = (1-similarity) + 0.001
                parent_x = self.positions[parent]['x']
                add_node({'id':child, 'y':ypos, 'x':
                          self.xmin_crowd(parent_x-dissimilarity, parent_x+dissimilarity, ypos)})
            else:
                # position weighted by the degree of inheritance from each parent
                total_inheritance = sum(parents.values())
                if total_inheritance:
                    xpos = sum(self.positions[k]['x']*v/total_inheritance
                               for k, v in parents.items())
                elif parents:
                    # all shares are 0: fall back to a plain average instead of dividing by 0
                    xpos = sum(self.positions[k]['x'] for k in parents) / len(parents)
                else:
                    xpos = 0
                if self.JITTER:
                    add_node({'id':child, 'y':ypos, 'x':xpos + random.gauss(0, 0.1)})
                else:
                    add_node({'id':child, 'y':ypos, 'x':xpos})

    def compute_custom(self):
        """Copy and normalize every extra property loaded from the log."""
        node_count = len(self.tree.children)
        for prop in self.tree.props:
            source = self.tree.props[prop]
            self.props[prop] = [source[i] for i in range(node_count)]
            self.normalize_prop(prop)

    def compute_time(self):
        """Copy and normalize the birth time of every node."""
        self.props["time"] = [self.tree.time[i] for i in range(len(self.tree.children))]
        self.normalize_prop('time')

    def compute_kind(self):
        """Copy node kinds as strings (strings are categorical for compute_color)."""
        self.props["kind"] = [str(self.tree.kind[i]) for i in range(len(self.tree.children))]

    def compute_depth(self):
        """BFS distance from node 0 (999999999 for unreachable nodes), normalized."""
        from collections import deque

        node_count = len(self.tree.children)
        depth = [999999999] * node_count
        visited = [False] * node_count
        depth[0] = 0
        visited[0] = True
        queue = deque([0])
        while queue:
            current_node = queue.popleft()
            for child in self.tree.children[current_node]:
                if not visited[child]:
                    visited[child] = True
                    depth[child] = depth[current_node]+1
                    queue.append(child)

        self.props["depth"] = depth
        self.normalize_prop('depth')

    def compute_adepth(self):
        """Longest path from each node down to a leaf (0 for leaves), normalized."""
        adepth = [0] * len(self.tree.children)
        self.props["adepth"] = adepth

        # reversed parent-first order: every child is finished before its parents
        for node in self._parent_depth_order()[::-1]:
            children = self.tree.children[node]
            if children:
                adepth[node] = max(adepth[child] for child in children)+1
        self.normalize_prop('adepth')

    def compute_children(self):
        """Number of direct children of every node, normalized."""
        self.props["children"] = [len(kids) for kids in self.tree.children]
        self.normalize_prop('children')

    def compute_progress(self):
        """Slope of the trend in gaps between successive children's birth times.

        Nodes with at most 4 children get 0; the 5 smallest and 5 largest
        slopes are clipped to 0 as outliers; zeros are then replaced by the
        minimum before normalizing.  Requires compute_time() to have run.
        """
        progress = [0] * len(self.tree.children)
        self.props["progress"] = progress
        time = self.props["time"]
        for i, kids in enumerate(self.tree.children):
            times = sorted(time[k]*100000 for k in kids)
            if len(times) > 4:
                gaps = [later - earlier for earlier, later in zip(times, times[1:])]
                slope = stats.linregress(range(len(gaps)), gaps)[0]
                progress[i] = slope if not np.isnan(slope) and not np.isinf(slope) else 0

        for _ in range(5):
            progress[progress.index(min(progress))] = 0
            progress[progress.index(max(progress))] = 0

        mini = min(progress)
        for k, value in enumerate(progress):
            if value == 0:
                progress[k] = mini

        self.normalize_prop('progress')

    def normalize_prop(self, prop):
        """Rescale the numeric entries of self.props[prop] to [0, 1] in place.

        Stores the raw bounds as '<prop>_max' / '<prop>_min'.  Non-numeric
        entries (None, strings, lists, dicts) are ignored and left unchanged;
        previously they reached max()/arithmetic and raised TypeError.  If all
        numeric values are equal they become 0.
        """
        values = self.props[prop]
        numeric_types = (int, float, np.number)
        numeric = [v for v in values if isinstance(v, numeric_types)]
        if not numeric:
            return
        max_val = max(numeric)
        min_val = min(numeric)
        print(prop, max_val, min_val)
        self.props[prop +'_max'] = max_val
        self.props[prop +'_min'] = min_val
        for i, v in enumerate(values):
            if isinstance(v, numeric_types):
                values[i] = 0 if max_val == min_val else (v - min_val) / (max_val - min_val)

class TreeData:
    """Genealogy read from an evolution log of [OFFSPRING] / [DIED] records.

    After load(): ids maps original creature IDs to dense indices;
    parents[i] maps parent index -> inherited share; children[i] lists child
    indices; time / kind hold per-node values; props maps every other record
    field to a per-node list (0 where unset).
    """
    simple_data = None

    def __init__(self):
        # per-instance containers (class-level lists would be shared by all instances)
        self.children = []
        self.parents = []
        self.time = []
        self.kind = []
        self.life_lenght = []
        self.props = {}
        self.ids = {}

    def load(self, filename, max_nodes=0):
        """Load the genealogy from `filename`.

        Lines may carry a "Script.Message:" prefix.  Parents unknown at the
        time a child is read are mapped to a single "virtual_parent" node.  A
        malformed [OFFSPRING] JSON line stops loading (assumed to be a
        truncated last line); a malformed [DIED] line is skipped.

        Args:
            filename: path of the log file.
            max_nodes: stop after this many [OFFSPRING] records (0 = no limit).

        Raises:
            LoadingError: an [OFFSPRING] record has no 'FromIDs' field.
        """
        print("Loading...")

        CLI_PREFIX = "Script.Message:"
        default_props = ["Time", "FromIDs", "ID", "Operation", "Inherited"]

        self.ids = {}

        def get_id(key, createOnError=True):
            # dense index of `key`; new keys are numbered in order of appearance
            if key not in self.ids:
                if not createOnError:
                    return None
                self.ids[key] = len(self.ids)
            return self.ids[key]

        def split_record(line):
            # (tag, payload) of a log line after stripping the CLI prefix;
            # payload is "" when missing, None is returned for non-records
            line_arr = line.split(' ', 1)
            if len(line_arr) != 2:
                return None
            if line_arr[0] == CLI_PREFIX:
                line_arr = line_arr[1].split(' ', 1)
            return line_arr[0], (line_arr[1] if len(line_arr) > 1 else "")

        with open(filename) as file:
            # first pass: count the expected nodes to preallocate the arrays
            nodes = 0
            for line in file:
                record = split_record(line)
                if record is not None and record[0] == "[OFFSPRING]":
                    nodes += 1

            # +1 slot for a possible "virtual_parent"
            nodes = min(nodes, max_nodes if max_nodes != 0 else nodes)+1
            self.parents = [{} for _ in range(nodes)]
            self.children = [[] for _ in range(nodes)]
            self.time = [0] * nodes
            self.kind = [0] * nodes
            self.life_lenght = [0] * nodes
            self.props = {}

            print("nodes: %d" % len(self.parents))

            def store_props(creature_id, creature):
                for prop in creature:
                    if prop not in default_props:
                        if prop not in self.props:
                            self.props[prop] = [0] * nodes
                        self.props[prop][creature_id] = creature[prop]

            file.seek(0)
            loaded_so_far = 0
            for line in file:
                record = split_record(line)
                if record is not None:
                    tag, payload = record
                    if tag == "[OFFSPRING]":
                        try:
                            creature = json.loads(payload)
                        except ValueError:
                            print("Json format error - the line cannot be read. Breaking the loading loop.")
                            # fixing arrays by removing the last element
                            # ! assuming that only the last line is broken !
                            self.parents.pop()
                            self.children.pop()
                            self.time.pop()
                            self.kind.pop()
                            self.life_lenght.pop()
                            # keep custom properties the same length as the other arrays
                            for values in self.props.values():
                                values.pop()
                            break

                        if "FromIDs" not in creature:
                            raise LoadingError("[OFFSPRING] misses the 'FromIDs' field!")
                        from_ids = creature["FromIDs"]

                        # make sure that ID's of parents are lower than that of their children
                        if any(pid not in self.ids for pid in from_ids):
                            get_id("virtual_parent")

                        creature_id = get_id(creature["ID"])

                        # we assign to each parent its contribution to the genotype of the child
                        for i, pid in enumerate(from_ids):
                            parent_id = get_id(pid) if pid in self.ids else get_id("virtual_parent")
                            inherited = (creature["Inherited"][i] if 'Inherited' in creature else 1)
                            self.parents[creature_id][parent_id] = inherited

                        if "Time" in creature:
                            self.time[creature_id] = creature["Time"]

                        if "Kind" in creature:
                            self.kind[creature_id] = creature["Kind"]

                        store_props(creature_id, creature)

                        loaded_so_far += 1
                    elif tag == "[DIED]":
                        try:
                            creature = json.loads(payload)
                        except ValueError:
                            creature = None
                            print("Json format error - a [DIED] line cannot be read. Skipping it.")
                        if creature is not None:
                            creature_id = get_id(creature["ID"], False)
                            if creature_id is not None:
                                store_props(creature_id, creature)

                if loaded_so_far >= max_nodes and max_nodes != 0:
                    break

        for child, parent_map in enumerate(self.parents):
            for parent in parent_map:
                self.children[parent].append(child)

# Module-level lookup tables, each initialised to its own empty dict.
# NOTE(review): neither name is read in this section of the file; confirm
# whether other parts of the module use them before removing.
depth, kind = {}, {}
750
def main():
    """Command-line entry point: parse options, load the genealogy data, lay out the tree and draw it.

    The output format is chosen from the output file extension: ``.svg``
    (case-insensitive) produces an SVG drawing, anything else goes through the
    raster drawer.  If one of the enumerated options (time, balance, scale) has
    an unsupported value, a diagnostic naming each offending option is printed
    and the function returns ``None`` without loading or drawing anything.
    """
    parser = argparse.ArgumentParser(description='Draws a genealogical tree (generates an SVG or raster image) based on parent-child relationship '
                                                 'information from a text file. Supports files generated by Framsticks experiments.')
    parser.add_argument('-i', '--in', dest='input', required=True, help='input file name with structured evolutionary data')
    parser.add_argument('-o', '--out', dest='output', required=True, help='output file name for the evolutionary tree (SVG/PNG/JPG/BMP)')
    parser.add_argument('-c', '--config', dest='config', default="", help='config file name ')

    parser.add_argument('-W', '--width', default=600, type=int, dest='width', help='width of the output image (600 by default)')
    parser.add_argument('-H', '--height', default=800, type=int, dest='height', help='height of the output image (800 by default)')
    parser.add_argument('-m', '--multi', default=1, type=int, dest='multi', help='multisampling factor (applicable only for raster images)')

    parser.add_argument('-t', '--time', default='GENERATIONAL', dest='time', help='values on vertical axis (BIRTHS/GENERATIONAL(d)/REAL); '
                                                                      'BIRTHS: time measured as the number of births since the beginning; '
                                                                      'GENERATIONAL: time measured as number of ancestors; '
                                                                      'REAL: real time of the simulation')
    parser.add_argument('-b', '--balance', default='DENSITY', dest='balance', help='method of placing nodes in the tree (RANDOM/MIN/DENSITY(d))')
    # The actual default is SIMPLE, so the help text marks SIMPLE (not NONE) with "(d)".
    parser.add_argument('-s', '--scale', default='SIMPLE', dest='scale', help='type of timescale added to the tree (NONE/SIMPLE(d))')
    parser.add_argument('-j', '--jitter', dest="jitter", action='store_true', help='draw horizontal positions of children from the normal distribution')
    parser.add_argument('-p', '--skip', dest="skip", type=int, default=0, help='skip last P levels of the tree (0 by default)')
    parser.add_argument('-x', '--max-nodes', type=int, default=0, dest='max_nodes', help='maximum number of nodes drawn (starting from the first one)')
    parser.add_argument('--seed', type=int, default=-1, dest='seed', help='seed for the random number generator (-1 for random)')

    args = parser.parse_args()

    TIME = args.time.upper()
    BALANCE = args.balance.upper()
    SCALE = args.scale.upper()
    JITTER = args.jitter

    # Validate every enumerated option and report each offending one by name,
    # so the user does not have to guess which parameter was wrong.
    option_checks = [
        ('time', TIME, ['BIRTHS', 'GENERATIONAL', 'REAL']),
        ('balance', BALANCE, ['RANDOM', 'MIN', 'DENSITY']),
        ('scale', SCALE, ['NONE', 'SIMPLE']),
    ]
    invalid = [(name, value, choices) for name, value, choices in option_checks if value not in choices]
    if invalid:
        for name, value, choices in invalid:
            print("Incorrect value of the '%s' parameter: %s (allowed: %s)." % (name, value, "/".join(choices)))
        return

    input_path = args.input
    seed = args.seed
    if seed == -1:
        seed = random.randint(0, 10000)
    random.seed(seed)
    # Printed so that a run with a randomly drawn seed can be reproduced via --seed.
    print("randomseed:", seed)

    tree = TreeData()
    tree.load(input_path, max_nodes=args.max_nodes)

    designer = Designer(tree, jitter=JITTER, time=TIME, balance=BALANCE)
    designer.calculate_measures()
    designer.calculate_node_positions(ignore_last=args.skip)

    # Case-insensitive so that e.g. "tree.SVG" is not routed to the raster drawer.
    if args.output.lower().endswith(".svg"):
        drawer = SvgDrawer(designer, args.config, w=args.width, h=args.height)
    else:
        drawer = PngDrawer(designer, args.config, w=args.width, h=args.height)
    drawer.draw_design(args.output, args.input, multi=args.multi, scale=SCALE)
813
# Run only when executed as a script, so the module can be imported
# (e.g. for reuse of its classes or for testing) without parsing sys.argv.
if __name__ == "__main__":
    main()
Note: See TracBrowser for help on using the repository browser.