from nupack import energy, to_thepairs, to_intseq, to_strseq, to_dpp, can_pair

from copy import copy
import scipy as sp
import numpy as np

from scipy.sparse import csr_matrix, dia_matrix
from scipy.sparse.linalg import expm_multiply, expm, eigsh
from scipy.integrate import odeint

from datrie import BaseTrie
import sys

# This is a poorly written library that calculates the exact
# master-equation solution for single-stranded kinetics.
# It served its purpose, but should be replaced.

def find_neighbors(seq, struc):
    """
    Enumerate the structures that differ from ``struc`` by one base pair.

    seq: sequence in integer format; its entries are used to index
        ``can_pair``
    struc: source structure in thepairs format, i.e. ``struc[i]`` is the
        index paired with ``i`` or -1 when ``i`` is unpaired

    Returns a list of copies of ``struc``: first every structure obtained
    by removing one existing pair, then every structure obtained by adding
    one pair between two unpaired positions reached by the walk below.
    ``struc`` itself is never modified.
    """
    size = len(struc)
    neighbors = []

    # Pair-breaking moves: each pair is visited once, from its lower index.
    for left, right in enumerate(struc):
        if right > left:
            opened = copy(struc)
            opened[left] = -1
            opened[right] = -1
            neighbors.append(opened)

    # Pair-forming moves: from every unpaired position walk backwards
    # (cyclically) until we come back to where we started.
    for pos, partner in enumerate(struc):
        if partner != -1:
            continue

        cursor = (pos - 1) % size
        while cursor != pos:
            if struc[cursor] == -1:
                base_a, base_b = seq[pos], seq[cursor]
                # Only emit the pair from its lower index so that each new
                # pair is produced exactly once.
                if pos < cursor and can_pair[base_a, base_b]:
                    closed = copy(struc)
                    closed[pos], closed[cursor] = cursor, pos
                    neighbors.append(closed)
                cursor -= 1
            else:
                # Paired position: jump to just before its partner instead
                # of stepping through the positions in between.
                cursor = struc[cursor] - 1

            # Stepping or jumping past index 0 wraps to the last index.
            if cursor == -1:
                cursor = size - 1

    return neighbors

# def get_state_energy(n_strands, mem, ords, thepairs, n_nucs):
#     n_comps = max(mem)
#     comps = []
#     for i in range(n_strands):


class SimulationGraph:
    """
    State graph of structures linked by single base-pair moves, together
    with a master-equation integrator over that graph.

    Species are numbered in order of discovery; species 0 is the starting
    structure.  A species becomes "active" once its neighbors have been
    enumerated and its outgoing transitions recorded.  Transitions are
    stored as three parallel lists: ``rates[n]`` is the rate of going from
    species ``tj[n]`` to species ``ti[n]``.
    """

    # A newly found structure is only added to the graph when its energy is
    # below this value.  NOTE(review): presumably a sentinel for invalid
    # structures returned by nupack's ``energy`` -- confirm.
    _ENERGY_CUTOFF = 1000

    def __init__(self, seq, start_struc, endtime=2.50,
            material='rna1995',
            temperature=(273.15 + 23), sodium=1.0, magnesium=0.0):
        """
        seq: sequence, converted with ``to_intseq``
        start_struc: starting structure, converted with ``to_thepairs``
        endtime: default integration time used by ``integrate`` and
            ``integrate_2`` when none is given
        material, temperature, sodium, magnesium: forwarded unchanged as
            keyword arguments to every ``energy`` call
        """
        self.start_struc = to_thepairs(start_struc)
        self.sequence = to_intseq(seq)
        self.endtime = endtime

        self.physical_opts = {
                'material': material,
                'temperature': temperature,
                'sodium': sodium,
                'magnesium': magnesium
                }
        # Maps the dot-parens-plus text of a structure to its species index.
        self.trie = BaseTrie(u'.()+')
        self.trie[self._trie_key(self.start_struc)] = 0
        # Per-species flags, indexed by species number.
        self.active = [False]
        self.miniprune = [False]
        self.energies = [energy(self.sequence, self.start_struc,
            **self.physical_opts)]
        # Transition lists: rates[n] is the rate from tj[n] to ti[n].
        self.ti = []
        self.tj = []
        self.rates = []
        # One row per species; grown on demand by _reserve().
        self.strucs = np.array([self.start_struc], dtype=np.int32)
        self.k_uni = 10
        # NOTE(review): 0.00198717 looks like the gas constant in
        # kcal/(mol K) -- confirm units against nupack's energies.
        self.kBT = self.physical_opts['temperature'] * 0.00198717
        self.accuracy = 0
        self.method = 'kawasaki'
        # self.method = 'metropolis'

    @staticmethod
    def _trie_key(struc):
        """Return the dot-parens-plus text of ``struc`` as a text trie key."""
        dpp = to_dpp(struc)
        try:
            # Python 2: convert to unicode, as the original code did.
            return unicode(dpp)
        except NameError:
            # Python 3 has no ``unicode`` builtin; ``str`` is already text.
            return str(dpp)

    def compute_rate(self, G1, G2):
        """
        Compute rate to go from state with free energy G1 to state with
        free energy G2.

        With ``self.method == 'kawasaki'`` the rate is
        ``k_uni * exp(-(G2 - G1) / (2 kBT))``.  Any other value of
        ``self.method`` selects the Metropolis rule: ``k_uni`` for moves
        that do not raise the energy, ``k_uni * exp(-(G2 - G1) / kBT)``
        otherwise.
        """
        if self.method == 'kawasaki':
            return self.k_uni * np.exp(-(G2 - G1) / (2 * self.kBT))
        # Everything else is treated as 'metropolis'.
        if G1 < G2:
            return self.k_uni * np.exp(-(G2 - G1) / (self.kBT))
        return self.k_uni

    def get_rate_matrix(self, prune=False, miniprune=False):
        """
        Build the sparse rate matrix of the recorded transitions.

        prune: keep only active species
        miniprune: keep every species not flagged in ``self.miniprune``
            (takes precedence over ``prune`` when both are set)

        When either flag is set the kept species are renumbered compactly
        and one extra sink state is appended as the last index; transitions
        from a kept species to a dropped one are redirected to the sink, and
        transitions leaving a dropped species are omitted.

        Returns ``(rate_mat, ident_mat)``.  ``rate_mat`` is
        ``(n_vals, n_vals)`` with entry ``[dst, src]`` holding the rate
        ``src -> dst`` and the diagonal holding minus the total outgoing
        rate.  ``ident_mat`` is ``(n_species, n_vals)`` and maps a vector
        over matrix states back onto the original species numbering.
        """
        n_species = len(self.active)

        if prune or miniprune:
            if miniprune:
                # Logical negation; subtracting boolean arrays is rejected
                # by NumPy.
                keep = ~np.asarray(self.miniprune, dtype=np.bool_)
            else:
                keep = np.asarray(self.active, dtype=np.bool_)

            n_kept = int(keep.sum())
            n_vals = n_kept + 1
            sink = n_vals - 1
            # compact[s] is the new index of kept species s.
            compact = np.cumsum(keep) - 1

            rows, cols, rates = [], [], []
            for dst, src, rate in zip(self.ti, self.tj, self.rates):
                if not keep[src]:
                    continue
                rows.append(compact[dst] if keep[dst] else sink)
                cols.append(compact[src])
                rates.append(rate)

            ident_rows = np.flatnonzero(keep)
            ident_cols = np.arange(n_kept)
        else:
            n_vals = n_species
            rows, cols, rates = self.ti, self.tj, self.rates
            ident_rows = ident_cols = np.arange(n_species)

        # Explicit dtypes keep empty transition lists usable as indices.
        rates = np.asarray(rates, dtype=np.float64)
        rows = np.asarray(rows, dtype=np.intp)
        cols = np.asarray(cols, dtype=np.intp)

        rate_mat = csr_matrix((rates, (rows, cols)), (n_vals, n_vals))
        # Duplicate (col, col) entries are summed, giving each source its
        # total outgoing rate on the diagonal.
        outrates = csr_matrix((-rates, (cols, cols)), (n_vals, n_vals))

        ident_mat = csr_matrix(
                (np.ones(len(ident_rows)), (ident_rows, ident_cols)),
                (n_species, n_vals))

        rate_mat = rate_mat + outrates
        print("Rate matrix size: ", n_vals, " out of ", len(self.active))

        return rate_mat, ident_mat

    def integrate_2(self, endtime=None, prune=False, miniprune=False, n_points=1):
        """
        Propagate the master equation from time 0 to ``endtime``.

        endtime: final time; defaults to ``self.endtime``
        prune, miniprune: forwarded to ``get_rate_matrix``
        n_points: number of intervals; ``n_points + 1`` evenly spaced time
            points (including 0 and ``endtime``) are evaluated

        Returns an array of shape ``(n_species, n_points + 1)`` holding the
        probability of every species at every time point.
        """
        if endtime is None:
            endtime = self.endtime

        rate_mat, resmap = self.get_rate_matrix(prune=prune, miniprune=miniprune)
        # All probability starts in matrix state 0.
        ## TODO fix remapping of species back to original species
        B = np.zeros((rate_mat.shape[0],))
        B[0] = 1

        res = expm_multiply(rate_mat, B, start=0, stop=endtime, num=n_points + 1)

        # Map from (possibly pruned) matrix states back to species.
        return resmap.dot(res.T)

    def integrate(self, endtime=None, prune=False, miniprune=False):
        """
        Return the species probabilities at ``endtime`` only.

        Same arguments as ``integrate_2``; the result is a 1-D array of
        length ``n_species``.
        """
        trajectory = self.integrate_2(endtime=endtime, prune=prune,
                miniprune=miniprune, n_points=1)
        return trajectory[:, -1]

    def get_n_active(self):
        """Return the number of active species."""
        return sum(self.active)

    def get_n_minipruned(self):
        """Return the number of species flagged as minipruned."""
        return sum(self.miniprune)

    def get_n_species(self):
        """Return the total number of species discovered so far."""
        return len(self.active)

    def _reserve(self, needed):
        """Make sure ``self.strucs`` has room for at least ``needed`` rows."""
        n_rows, width = self.strucs.shape
        if needed > n_rows:
            # Reallocate instead of ndarray.resize(), which refuses to work
            # while any other reference to the array exists.
            grown = np.zeros((needed * 2, width), dtype=self.strucs.dtype)
            grown[:n_rows] = self.strucs
            self.strucs = grown

    def _activate(self, index, echo_new=False):
        """
        Mark species ``index`` active and record its outgoing transitions.

        Neighbors not seen before are added as new inactive species when
        their energy is below ``_ENERGY_CUTOFF``; with ``echo_new`` their
        dot-parens-plus text is printed.
        """
        neighbors = find_neighbors(self.sequence, self.strucs[index])
        self.active[index] = True
        self._reserve(len(self.active) + len(neighbors))

        source_energy = self.energies[index]
        for neighbor in neighbors:
            key = self._trie_key(neighbor)
            if key in self.trie:
                target = self.trie[key]
            else:
                ene = energy(self.sequence, neighbor, **self.physical_opts)
                if not ene < self._ENERGY_CUTOFF:
                    continue
                target = len(self.active)
                self.trie[key] = target
                self.active.append(False)
                self.miniprune.append(False)
                self.energies.append(ene)
                self.strucs[target] = neighbor
                if echo_new:
                    print(key)

            if (source_energy < self._ENERGY_CUTOFF
                    and self.energies[target] < self._ENERGY_CUTOFF):
                self.rates.append(
                        self.compute_rate(source_energy, self.energies[target]))
                self.ti.append(target)
                self.tj.append(index)

    def make_all_species(self):
        """
        Activate every species reachable from the starting structure.

        Species are processed in index order; newly discovered ones are
        appended and processed in turn until none are left.
        """
        i = 0
        while i < len(self.active):
            print("i:", i, "N:", len(self.active))
            self._activate(i)
            i += 1

    def get_consistent_species(self, maxtime=100):
        """
        Grow the active set over 20 log-spaced timescales from 1e-2 to
        ``maxtime``, calling ``expand`` at each timescale until the number
        of active species stops changing.
        """
        timescales = np.logspace(-2, np.log10(maxtime), 20)

        for t in timescales:
            print("Starting timescale:", t)
            print('Active species:', self.get_n_active())
            print('Miniprune species:', self.get_n_minipruned())
            print('Total species:', self.get_n_species())
            n_active = self.get_n_active()
            self.expand(maxtime=t)
            while n_active != self.get_n_active():
                n_active = self.get_n_active()
                self.expand(maxtime=t)

    def expand(self, maxtime=1):
        """
        Integrate to ``maxtime`` and activate the inactive species that
        hold the most probability.

        Nothing changes unless the total probability on inactive species
        exceeds ``self.accuracy``.  Otherwise inactive species are activated
        in order of decreasing probability until the activated share reaches
        ``inactive_frac - 0.98 * accuracy``, and the lowest-probability
        inactive species are flagged as minipruned while their cumulative
        probability stays below ``accuracy * 1e-6``.
        """
        print("Integrating timescale: ", maxtime)
        fval = self.integrate(maxtime, miniprune=True)
        print("Expanding set")

        inactive = 1 - np.array(self.active, dtype=np.int32)
        inactive_frac = np.dot(fval, inactive)

        print("Inactive:", inactive_frac)
        if not inactive_frac > self.accuracy:
            return

        # (probability, species) pairs in ascending order of probability.
        ranked = sorted(zip(fval, range(len(fval))))

        to_include = []
        added_frac = 0.0
        i = len(fval) - 1
        while i >= 0 and added_frac < (inactive_frac - (self.accuracy * 0.98)):
            prob, species = ranked[i]
            if not self.active[species]:
                to_include.append(species)
                added_frac += prob
            i -= 1
        print("Added frac",added_frac)

        # Only ranks below the last one examined above are candidates.
        j = 0
        forgotten_frac = 0.0
        while j < i and forgotten_frac + ranked[j][0] < (self.accuracy * 1e-6):
            prob, species = ranked[j]
            if not self.active[species]:
                forgotten_frac += prob
                self.miniprune[species] = True
            j += 1

        for species in to_include:
            self._activate(species, echo_new=True)

    def get_eigs(self):
        """
        Return ``(eigenvalues, eigenvectors)`` for 10 eigenpairs of the
        pruned rate matrix, as computed by ``scipy.sparse.linalg.eigs`` in
        shift-invert mode around 0.
        """
        # Imported here because the module only imports ``eigsh``.
        from scipy.sparse.linalg import eigs

        rate_mat, _ = self.get_rate_matrix(prune=True)

        # NOTE(review): sigma=0 makes eigs factorize the matrix itself; the
        # pruned matrix has an all-zero sink column, so this may fail as
        # singular -- confirm, and consider a small non-zero shift.
        evals_small, evecs_small = eigs(rate_mat, 10, which='LM',
                sigma=0)

        return (evals_small, evecs_small)


def test_neighbors():
    """
    Build the full state graph for a fixed example sequence, integrate it
    and pickle the resulting trajectory to ``test_<timescale>.pickle`` in
    the current directory.
    """
    # Alternative example input:
    # struc = '(((((....))))).'
    # seq =   'GGGGGAAAACCCCCC'
    import pickle
    seq =   'GCGUCGCGUCGCUAUGC'
    struc = '.....((((....))))'
    graph = SimulationGraph(seq, struc)

    graph.make_all_species()
    for timescale in [500]:
        print(timescale)
        sys.stdout.flush()
        res = graph.integrate_2(endtime=timescale, n_points=300)
        # The context manager closes the file even if pickling fails.
        with open('test_{}.pickle'.format(timescale), 'wb') as f:
            pickle.dump(res, f)

# Run the example above when this file is executed as a script.
if __name__ == '__main__':
    test_neighbors()

