source: mds-and-trees/mds_plot.py @ 597

Last change on this file since 597 was 597, checked in by Maciej Komosinski, 8 years ago

No longer ignores "jitter" and "outname" parameters

File size: 2.9 KB
Line 
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4import sys
5import numpy as np
6from sklearn import manifold
7
8#to make it work in console, http://stackoverflow.com/questions/2801882/generating-a-png-with-matplotlib-when-display-is-undefined
9#import matplotlib
10#matplotlib.use('Agg')
11
12import matplotlib.pyplot as plt
13from mpl_toolkits.mplot3d import Axes3D
14from matplotlib import cm
15import argparse
16
17
18
19
def rand_jitter(arr):
    """Return *arr* plus zero-mean Gaussian noise from numpy's global RNG.

    The noise standard deviation is twice one hundredth of ``arr.max()``
    (its sign is irrelevant because the noise is symmetric), so an all-zero
    array comes back unchanged.

    Args:
        arr: numpy array of coordinates; assumed 1-D (plot() passes single
            columns), since the noise vector has ``len(arr)`` entries.

    Returns:
        A new array: ``arr`` with the noise added element-wise.
    """
    scale = 2 * (arr.max() / 100.)
    return arr + scale * np.random.randn(len(arr))
23
24
def read_file(fname, separator):
    """Load a dissimilarity matrix from a delimited text file.

    Args:
        fname: path or file-like object accepted by ``np.genfromtxt``.
        separator: column delimiter (the script's default is a tab).

    Returns:
        A 2-D float array with one row per input line. When every row ends
        with a separator, ``np.genfromtxt`` produces an extra all-NaN last
        column; that column is dropped.

    Raises:
        ValueError: if the file contains no data.
    """
    # genfromtxt squeezes a one-row file to 1-D (or 0-D); keep a matrix so
    # the column handling below works for every input size.
    distances = np.atleast_2d(np.genfromtxt(fname, delimiter=separator))
    if distances.size == 0:
        raise ValueError("no data in dissimilarity matrix file %r" % (fname,))
    # A separator after the last element of each row yields an empty (NaN)
    # trailing column; strip it only when the whole column is empty.
    if np.isnan(distances[:, -1]).all():
        distances = distances[:, :-1]
    return distances
30
31
def compute_mds(distance_matrix, dim):
    """Embed a precomputed dissimilarity matrix with metric MDS.

    Args:
        distance_matrix: matrix of pairwise dissimilarities, passed to
            sklearn's MDS with ``dissimilarity="precomputed"``.
        dim: target number of dimensions; anything ``int()`` accepts
            (e.g. the string given on the command line).

    Returns:
        The embedding computed by ``MDS.fit_transform`` (one row per point).
    """
    # A fixed seed keeps repeated runs on the same input reproducible.
    rng = np.random.RandomState(seed=3)
    model = manifold.MDS(
        n_components=int(dim),
        metric=True,
        max_iter=3000,
        eps=1e-9,
        random_state=rng,
        dissimilarity="precomputed",
    )
    return model.fit_transform(distance_matrix)
37
38
def compute_variances(embed):
    """Return the cumulative fraction of total variance per embedding dimension.

    Element ``i`` is the summed variance of columns ``0..i`` of *embed*
    divided by the summed variance of all columns. The values are therefore
    fractions in [0, 1], non-decreasing, and the last one is exactly 1.0.
    When every column has zero variance the result is NaN, and numpy emits a
    RuntimeWarning.

    Args:
        embed: 2-D array, one row per point and one column per dimension.

    Returns:
        A list with one numpy float per column of *embed* (empty if *embed*
        has no columns).
    """
    # One vectorised variance call plus a running sum replaces the per-column
    # Python loop and the O(d^2) re-summing of every prefix.
    variances = np.var(np.asarray(embed), axis=0)
    cumulative = np.cumsum(variances)
    if cumulative.size == 0:
        return []
    # cumulative[-1] is the same left-to-right total the old sum() produced,
    # so the final entry is exactly 1.0.
    return list(cumulative / cumulative[-1])
45
46
def plot(coordinates, dimensions, jitter=0, outname=""):
    """Scatter-plot embedded coordinates in 2-D or 3-D.

    Args:
        coordinates: 2-D numpy array, one row per point and one column per
            dimension.
        dimensions: number of MDS dimensions. Values below 3 give a 2-D plot
            of the first two columns. Values of 3 or more give a 3-D plot of
            the first three columns, since a scatter plot cannot show more.
        jitter: when numerically equal to 1 (e.g. ``1`` or the string ``"1"``
            delivered by argparse), random jitter from rand_jitter() is added
            to every plotted column.
        outname: output file name without extension. When empty or None the
            plot is shown interactively; otherwise it is saved to
            ``outname + ".pdf"`` and the figure is closed.
    """
    fig = plt.figure()

    if dimensions < 3:
        ax = fig.add_subplot(111)
    else:
        ax = fig.add_subplot(111, projection='3d')

    # jitter may arrive as a command-line string, so compare numerically;
    # anything that is not a number leaves jitter off, as before.
    try:
        use_jitter = float(jitter) == 1
    except (TypeError, ValueError):
        use_jitter = False

    # scatter() takes x, y (and z for 3-D) positionally; extra columns would
    # be misread as its next parameters (zdir / marker size) and crash.
    n_axes = min(len(coordinates[0]), 2 if dimensions < 3 else 3)
    columns = [coordinates[:, i] for i in range(n_axes)]
    if use_jitter:
        columns = [rand_jitter(column) for column in columns]

    ax.scatter(*columns, alpha=0.5)

    plt.title('Phenotypes distances')
    plt.tight_layout()
    plt.axis('tight')

    if not outname:
        plt.show()
    else:
        plt.savefig(outname + ".pdf")
        # Release the figure; otherwise repeated calls accumulate open figures.
        plt.close(fig)
71
72
def main(filename, dimensions=3, outname="", jitter=0, separator='\t'):
    """Load a dissimilarity matrix, embed it with MDS, report and plot it.

    Prints, for each dimension, the cumulative share of variance returned by
    compute_variances(), then draws the embedding with plot().

    Args:
        filename: file holding the dissimilarity matrix.
        dimensions: number of MDS dimensions (int or numeric string).
        outname: output file name without extension, forwarded to plot().
        jitter: jitter flag, forwarded unchanged to plot().
        separator: column delimiter of the input file.
    """
    matrix = read_file(filename, separator)
    coords = compute_mds(matrix, dimensions)

    for dim_no, share in enumerate(compute_variances(coords), start=1):
        print(dim_no, "dimension:", share)

    n_dims = int(dimensions)
    if n_dims == 1:
        # Prepend a zero column so a 1-D embedding can be drawn on 2-D axes.
        coords = np.insert(coords, 0, 0, axis=1)

    plot(coords, n_dims, jitter, outname)
86
87
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--in', dest='input', required=True,
                        help='input file with dissimilarity matrix')
    parser.add_argument('--out', dest='output', required=False, default="",
                        help='output file name without extension')
    # --dim and --j are parsed as integers so the rest of the script receives
    # numbers, not strings (the string "1" does not compare equal to 1).
    parser.add_argument('--dim', required=False, type=int, default=3,
                        help='number of dimensions of the new space')
    parser.add_argument('--sep', required=False, default="\t",
                        help='separator of the source file')
    parser.add_argument('--j', required=False, type=int, default=0,
                        help='for j=1 random jitter is added to the plot')

    args = parser.parse_args()
    main(args.input, args.dim, args.output, args.j, args.sep)
Note: See TracBrowser for help on using the repository browser.