Changeset 602 for mds-and-trees


Ignore:
Timestamp:
08/28/16 17:00:05 (4 years ago)
Author:
Maciej Komosinski
Message:
  • More flexible reading of distance matrices from files
  • Can also read labels for individual points and display them
File:
1 edited

Legend:

Unmodified
Added
Removed
  • mds-and-trees/mds_plot.py

    r600 r602  
    44import sys
    55import numpy as np
    6 from sklearn import manifold
     6#from sklearn import manifold #was needed for manifold MDS http://scikit-learn.org/stable/auto_examples/manifold/plot_compare_methods.html
    77
    88#to make it work in console, http://stackoverflow.com/questions/2801882/generating-a-png-with-matplotlib-when-display-is-undefined
     
    7070def read_file(fname, separator):
    7171        distances = np.genfromtxt(fname, delimiter=separator)
    72         if np.isnan(distances[0][len(distances[0])-1]):#separator after the last element in row
    73                 distances = np.array([row[:-1] for row in distances])
    74         return distances
     72        if (distances.shape[0]!=distances.shape[1]):
     73                print("Matrix is not square:",distances.shape)
     74                minsize = min(distances.shape[0],distances.shape[1])
     75                distances = np.array([row[:minsize] for row in distances]) #this can only fix matrices with more columns than rows
     76                print("Making it square:",distances.shape)
     77
     78        try: #maybe the file has more columns than rows, and the extra column has labels?
     79                labels = np.genfromtxt(fname, delimiter=separator, usecols=distances.shape[0],dtype=[('label','S10')])
     80                labels = [label[0].decode("utf-8") for label in labels]
     81        except ValueError:
     82                labels = None #no labels
     83       
     84        return distances,labels
    7585
    7686
     
    8999
    90100
    91 def plot(coordinates, dimensions, jitter=0, outname=""):
     101def plot(coordinates, labels, dimensions, jitter=0, outname=""):
    92102        fig = plt.figure()
    93103
     
    102112        y_dim = len(coordinates)
    103113
    104         ax.scatter(*[add_jitter(coordinates[:, i]) for i in range(x_dim)], alpha=0.5)
     114        points = [add_jitter(coordinates[:, i]) for i in range(x_dim)]
     115       
     116        if labels is not None and dimensions==2:
     117                ax.scatter(*points, alpha=0.1) #barely visible points, because we will show labels anyway
     118                labelconvert={'vel':'V','vpp':'P','vpa':'A'} #use this if you want to replace long names with short IDs
     119                #for point in points:
     120                #       print(point)
     121                for label, x, y in zip(labels, points[0], points[1]):
     122                        #if label not in knownlabels:
     123                        #       knownlabels.append(label)
     124                        #       colors.append('#ff0000')
     125                        for key in labelconvert:
     126                                if label.startswith(key):
     127                                        label=labelconvert[key]
     128                        plt.annotate(
     129                                label,
     130                                xy = (x, y), xytext = (0, 0),
     131                                textcoords = 'offset points', ha = 'center', va = 'center',
     132                                #bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5),
     133                                #arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0')
     134                                )
     135        else:
     136                ax.scatter(*points, alpha=0.5)
     137
    105138
    106139        plt.title('Phenotypes distances')
     
    118151def main(filename, dimensions=3, outname="", jitter=0, separator='\t'):
    119152        dimensions = int(dimensions)
    120         distances = read_file(filename, separator)
     153        distances,labels = read_file(filename, separator)
    121154        embed = compute_mds(distances, dimensions)
    122155
     
    124157                embed = np.array([np.insert(e, 0, 0, axis=0) for e in embed])
    125158       
    126         plot(embed, dimensions, jitter, outname)
     159        plot(embed, labels, dimensions, jitter, outname)
    127160
    128161
Note: See TracChangeset for help on using the changeset viewer.