import numpy as np
from pyemd import emd
from ctypes import cdll
from ctypes.util import find_library
from alignmodel import align

class DensityDistribution:
    """Voxel-density based dissimilarity between Framsticks models.

    Models (given as genotypes or as voxel arrays) are voxelized with frams ``ModelGeometry``.
    The bounding box shared by a pair of models is split into a regular grid of cells, and each
    cell contributes one sample to each model's signature. Signatures are compared with EMD
    (pyemd), L1 or L2.
    """

    libm = cdll.LoadLibrary(find_library('m'))
    EPSILON = 0.0001
    # Flag passed to fedisableexcept/feclearexcept/feenableexcept around pyemd (divide by zero).
    # NOTE(review): 0x04 is FE_DIVBYZERO on x86 glibc - confirm for other platforms.
    FE_DIVBYZERO = 0x04

    def __init__(self, FramsLib=None, density=10, steps=3, reduce=True, frequency=False, metric='emd', fixedZaxis=False, verbose=False):
        """ __init__
        Args:
            FramsLib: frams library object; its ``Model`` and ``ModelGeometry`` are used by getVoxels(). Required.
            density (int, optional): density of samplings for frams.ModelGeometry . Defaults to 10.
            steps (int, optional): Number of grid edges per axis (points of np.linspace), so each axis is
                divided into ``steps - 1`` intervals and a signature has ``(steps - 1) ** 3`` samples.
                The higher value the more accurate sampling and the longer calculations. Defaults to 3.
            reduce (bool, optional): If we should use reduction to remove blank samples. Defaults to True.
            frequency (bool, optional): If we should use frequency distribution. Defaults to False.
            metric (string, optional): The distance metric that should be used ('emd', 'l1', or 'l2'). Defaults to 'emd'.
            fixedZaxis (bool, optional): If the z axis should be fixed during alignment (passed to align()). Defaults to False.
            verbose (bool, optional): Turning on logging, works only for calculateDissimforGeno. Defaults to False.

        Raises:
            ValueError: if FramsLib is not provided.
        """
        if FramsLib is None:
            raise ValueError('Frams library not provided!')
        self.frams_lib = FramsLib

        self.density = density
        self.steps = steps
        self.verbose = verbose
        self.reduce = reduce
        self.frequency = frequency
        self.metric = metric
        self.fixedZaxis = fixedZaxis

    def calculateNeighberhood(self, array, mean_coords):
        """ Calculates number of elements for given sample and sets the center of this sample
        to the center of mass (calculated by mean of every coordinate).
        Args:
            array ([[float,float,float],...,[float,float,float]]): array of voxels that belong to given sample.
            mean_coords ([float,float,float]): default coordinates that are the
                middle of the sample (used when number of voxels in sample is equal to 0)

        Returns:
            weight [int]: number of voxels in a sample
            coordinates [float,float,float]: center of mass for a sample
        """
        weight = len(array)
        if weight == 0:
            return 0, mean_coords
        # per-column means (kept column-wise to preserve the exact numeric results)
        point = [np.mean(array[:, 0]), np.mean(array[:, 1]), np.mean(array[:, 2])]
        return weight, point

    def calculateDistPoints(self, point1, point2):
        """ Returns euclidean distance between two points
        Args (distribution):
            point1 (np.array([float,float,float])) - coordinates of first point
            point2 (np.array([float,float,float])) - coordinates of second point
        Args (frequency):
            point1 (float) - value (index) of the first sample
            point2 (float) - value (index) of the second sample

        Returns:
            [float]: euclidean distance
        """
        if self.frequency:
            return abs(point1 - point2)
        return np.sqrt(np.sum(np.square(point1 - point2)))

    def calculateDistanceMatrix(self, array1, array2):
        """ Builds the matrix of pairwise distances (calculateDistPoints) between two point sets.

        Args:
            array1: sequence of n points representing the first model
            array2: sequence of m points representing the second model (m == n for signatures)

        Returns:
            np.array(np.array(,dtype=float)): distance matrix n x m, entry [i, j] = dist(array1[i], array2[j])
        """
        dist_matrix = np.zeros((len(array1), len(array2)))
        for i, point1 in enumerate(array1):
            for j, point2 in enumerate(array2):
                dist_matrix[i, j] = self.calculateDistPoints(point1, point2)
        return dist_matrix

    def reduceSignaturesFreq(self, s1, s2):
        """Removes samples from signatures if corresponding samples for both models have value 0.
        Args:
            s1 (np.array(,dtype=np.float64)): values of samples
            s2 (np.array(,dtype=np.float64)): values of samples

        Returns:
            s1new (np.array(,dtype=np.float64)): values of samples after reduction
            s2new (np.array(,dtype=np.float64)): values of samples after reduction
        """
        s1 = np.asarray(s1)
        s2 = np.asarray(s2)
        keep = (s1 != 0) | (s2 != 0)  # drop only positions where both are 0
        return s1[keep], s2[keep]

    def reduceSignaturesDens(self, s1, s2):
        """Removes samples from signatures if corresponding samples for both models have weight 0.
        Args:
            s1 ([np.array(,dtype=np.float64),np.array(,dtype=np.float64)]): [coordinates of samples, weights]
            s2 ([np.array(,dtype=np.float64),np.array(,dtype=np.float64)]): [coordinates of samples, weights]

        Returns:
            s1new ([np.array(,dtype=np.float64),np.array(,dtype=np.float64)]): [coordinates of samples, weights] after reduction
            s2new ([np.array(,dtype=np.float64),np.array(,dtype=np.float64)]): [coordinates of samples, weights] after reduction
        """
        weights1 = np.asarray(s1[1])
        weights2 = np.asarray(s2[1])
        keep = (weights1 != 0) | (weights2 != 0)  # drop only samples empty in both models
        s1 = [np.asarray(s1[0])[keep], weights1[keep]]
        s2 = [np.asarray(s2[0])[keep], weights2[keep]]
        return s1, s2

    def getSignatures(self, array, steps_all, step_all):
        """Generates signature for array representing model. Signature is composed of list of points [x,y,z] (float) and list of weights (int).

        Each grid cell is the box (x_lo, x_hi] x (y_lo, y_hi] x (z_lo, z_hi] built from consecutive edges;
        cells are visited with x outermost and z innermost.

        Args:
            array (np.array(np.array(,dtype=float))): array with voxels representing model
            steps_all ([np.array(,dtype=float),np.array(,dtype=float),np.array(,dtype=float)]): lists with edges for each step for each axis in order x,y,z
            step_all ([float,float,float]): [size of step for x axis, size of step for y axis, size of step for z axis]

        Returns (distribution):
           signature [np.array(,dtype=np.float64),np.array(,dtype=np.float64)]: returns signature [np.array of points, np.array of weights]
        Returns (frequency):
           signature np.array(,dtype=np.float64): returns signature np.array of magnitudes of the FFT of voxel counts
        """
        x_steps, y_steps, z_steps = steps_all
        x_step, y_step, z_step = step_all
        half_x, half_y, half_z = x_step / 2, y_step / 2, z_step / 2
        feature_array = []
        weight_array = []
        for x_lo, x_hi in zip(x_steps[:-1], x_steps[1:]):
            in_x = (array[:, 0] > x_lo) & (array[:, 0] <= x_hi)
            for y_lo, y_hi in zip(y_steps[:-1], y_steps[1:]):
                in_xy = in_x & (array[:, 1] > y_lo) & (array[:, 1] <= y_hi)
                for z_lo, z_hi in zip(z_steps[:-1], z_steps[1:]):
                    rows = in_xy & (array[:, 2] > z_lo) & (array[:, 2] <= z_hi)
                    cell = array[rows]
                    if self.frequency:
                        feature_array.append(len(cell))
                    else:
                        # an empty cell is represented by its geometric center
                        weight, point = self.calculateNeighberhood(cell, [x_lo + half_x, y_lo + half_y, z_lo + half_z])
                        feature_array.append(point)
                        weight_array.append(weight)

        if self.frequency:
            samples = np.array(feature_array, dtype=np.float64)
            return abs(np.fft.fft(samples))
        return [np.array(feature_array, dtype=np.float64), np.array(weight_array, dtype=np.float64)]

    def getSignaturesForPair(self, array1, array2):
        """Generates signatures for given pair of models represented by array of voxels.
        We calculate space for given models by taking the extremes for each axis (over both models) and dividing
        the space with self.steps edges per axis. This divided space generates samples which contain points.
        Each sample will have new coordinates which are mean of all points from it and weight which equals
        to the number of points.

        Args:
            array1 (np.array(np.array(,dtype=float))): array with voxels representing model1
            array2 (np.array(np.array(,dtype=float))): array with voxels representing model2

        Returns:
            s1, s2: signatures of both models as returned by getSignatures()
                ([coordinates of samples, weights] in distribution mode, FFT magnitudes in frequency mode)
        """
        # Shared bounding box so that both signatures are built on the same grid.
        all_voxels = np.concatenate((array1, array2))
        steps_all = []
        step_all = []
        for axis in range(3):
            edges, step = np.linspace(np.min(all_voxels[:, axis]), np.max(all_voxels[:, axis]), self.steps, retstep=True)
            # EPSILON subtracted to deal with boundary voxels (one-sided open intervals in getSignatures())
            edges[0] -= self.EPSILON
            steps_all.append(edges)
            step_all.append(step)

        s1 = self.getSignatures(array1, tuple(steps_all), tuple(step_all))
        s2 = self.getSignatures(array2, tuple(steps_all), tuple(step_all))
        return s1, s2

    def getVoxels(self, geno):
        """ Generates voxels for genotype using frams.ModelGeometry (after aligning the model with align()).

        Args:
            geno (string): representation of model in one of the formats handled by frams http://www.framsticks.com/a/al_genotype.html

        Returns:
            np.array of shape (number of voxels, 3): list of voxels representing model.
        """
        model = self.frams_lib.Model.newFromString(geno)
        align(model, self.fixedZaxis)
        model_geometry = self.frams_lib.ModelGeometry.forModel(model)

        model_geometry.geom_density = self.density
        coords = [[p.x._value(), p.y._value(), p.z._value()] for p in model_geometry.voxels()]
        # reshape keeps the result 2-D even when the geometry yields no voxels
        return np.array(coords, dtype=float).reshape(-1, 3)

    def _emdWithDivByZeroAllowed(self, hist1, hist2, dist_matrix):
        """Runs pyemd's emd() with the divide-by-zero FP trap disabled, re-enabling it afterwards.

        pyemd requires division by 0 to be allowed; framsticks expects the trap enabled again.
        The trap is re-enabled in ``finally`` so an exception from emd() cannot leave it disabled.
        """
        self.libm.fedisableexcept(self.FE_DIVBYZERO)
        try:
            return emd(hist1, hist2, dist_matrix)
        finally:
            self.libm.feclearexcept(self.FE_DIVBYZERO)
            self.libm.feenableexcept(self.FE_DIVBYZERO)

    def calculateDissimforVoxels(self, voxels1, voxels2):
        """ Calculate dissimilarity (according to self.metric) for pair of voxels representing models.
        Args:
            voxels1 np.array([np.array(,dtype=float)]: list of voxels representing model1.
            voxels2 np.array([np.array(,dtype=float)]: list of voxels representing model2.

        Returns:
            float: dissim for pair of list of voxels

        Raises:
            ValueError: if a distribution signature does not account for every voxel, or if self.metric is unknown.
        """
        numvox1 = len(voxels1)
        numvox2 = len(voxels2)

        s1, s2 = self.getSignaturesForPair(voxels1, voxels2)

        # The weight-conservation check applies only to distribution signatures ([points, weights]);
        # a frequency signature is a single array of FFT magnitudes, so s[1] is a scalar there.
        if not self.frequency and (numvox1 != sum(s1[1]) or numvox2 != sum(s2[1])):
            print("Bad signature!")
            print("Base voxels fig1: ", numvox1, " fig2: ", numvox2)
            print("After reduction voxels fig1: ", sum(s1[1]), " fig2: ", sum(s2[1]))
            raise ValueError("BAd signature!")

        if self.reduce:
            reduce_fun = self.reduceSignaturesFreq if self.frequency else self.reduceSignaturesDens
            s1, s2 = reduce_fun(s1, s2)

            if not self.frequency and (numvox1 != sum(s1[1]) or numvox2 != sum(s2[1])):
                print("Voxel reduction didnt work properly")
                print("Base voxels fig1: ", numvox1, " fig2: ", numvox2)
                print("After reduction voxels fig1: ", sum(s1[1]), " fig2: ", sum(s2[1]))

        # Values compared by every metric: the whole signature in frequency mode, the weights otherwise.
        hist1, hist2 = (s1, s2) if self.frequency else (s1[1], s2[1])

        if self.metric == 'l1':
            return np.linalg.norm(hist1 - hist2, ord=1)
        if self.metric == 'l2':
            return np.linalg.norm(hist1 - hist2)
        if self.metric == 'emd':
            if self.frequency:
                # frequency samples are compared by the distance between their indices
                num_points = len(s1)
                dist_matrix = self.calculateDistanceMatrix(range(num_points), range(num_points))
            else:
                dist_matrix = self.calculateDistanceMatrix(s1[0], s2[0])
            return self._emdWithDivByZeroAllowed(hist1, hist2, dist_matrix)
        raise ValueError("Wrong metric '%s'" % self.metric)

    def calculateDissimforGeno(self, geno1, geno2):
        """ Calculate dissimilarity for pair of genos.
        Args:
            geno1 (string): representation of model1 in one of the formats handled by frams http://www.framsticks.com/a/al_genotype.html
            geno2 (string): representation of model2 in one of the formats handled by frams http://www.framsticks.com/a/al_genotype.html

        Returns:
            float: dissim for pair of strings representing models.
        """
        voxels1 = self.getVoxels(geno1)
        voxels2 = self.getVoxels(geno2)

        out = self.calculateDissimforVoxels(voxels1, voxels2)

        if self.verbose:
            print("Steps: ", self.steps)
            print("Geno1:\n", geno1)
            print("Geno2:\n", geno2)
            print("EMD:\n", out)

        return out

    def getDissimilarityMatrix(self, listOfGeno):
        """ Builds the full dissimilarity matrix for a list of genotypes.

        Args:
            listOfGeno ([string]): list of strings representing genotypes in one of the formats handled by frams http://www.framsticks.com/a/al_genotype.html

        Returns:
            np.array(np.array(,dtype=float)): dissimilarity matrix; entry [i, j] = calculateDissimforVoxels(voxels of i, voxels of j)
        """
        numOfGeno = len(listOfGeno)
        dissimMatrix = np.zeros(shape=[numOfGeno, numOfGeno])
        # voxelize every genotype once and reuse the voxels for all pairs
        listOfVoxels = [self.getVoxels(g) for g in listOfGeno]
        for i, voxels_i in enumerate(listOfVoxels):
            for j, voxels_j in enumerate(listOfVoxels):
                dissimMatrix[i, j] = self.calculateDissimforVoxels(voxels_i, voxels_j)
        return dissimMatrix
