#!/usr/bin/env python
import sys
from optparse import OptionParser
import os
import os.path

import numpy as np
import random
import PyUtils


# Option-parser strings: the usage line and the text printed by --version.
usage_str, version_str = "usage: %prog [options]", "%prog 0.9"

# Structure-generation bounds, used as inclusive random.randint limits in
# get_engineered_structure: helix length (in base pairs) and unpaired-run
# length (in bases).
helix_range, unpaired_range = [8, 14], [0, 8]

# Validation thresholds used by is_valid.
#   guaranteed_continuity: number of bases on each side of a strand break
#       that must form one continuous stacked helix when that end is paired.
#   guaranteed_connection: two strands count as linked only when they share
#       more than this many base pairs.
guaranteed_continuity, guaranteed_connection = 6, 8


def _ensure_dir(path):
    """Create directory ``path`` (including parents) if it does not exist.

    An already-existing directory is not an error; any other failure to
    create it (e.g. permission denied, or a regular file in the way) is
    re-raised instead of being silently ignored.
    """
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise


def main(argv=None):
    """Command-line entry point: write random engineered structures to disk.

    For every length in ``--lengths``, creates ``strucs/<length>/`` and writes
    ``--ntrials`` files ``strucs/<length>/<trial>.fold`` (both numbers
    zero-padded to 4 digits).  Each file holds a structure from
    ``get_engineered_structure(length, nstrands)`` on the first line and a
    constraint line of ``length`` 'N' characters on the second.

    Args:
        argv: full argument vector including the program name; defaults to
            ``sys.argv``.

    Returns:
        2 if positional arguments were given, otherwise None (exit status 0).
        Malformed ``--lengths`` or ``--nstrands < 1`` abort via
        ``parser.error`` (usage message, exit status 2).
    """
    if argv is None:
        argv = sys.argv
    parser = OptionParser(usage=usage_str, version=version_str)

    parser.add_option('-n', '--ntrials', dest='ntrials',
            help='Number of trials per length', default=60, type='int')
    parser.add_option('-l', '--lengths', dest='lengths',
            help='comma-separated list of lengths', default='100,200,400,800')
    parser.add_option('--nstrands', dest='nstrands',
            help='the number of strands', default=2, type='int')

    (options, args) = parser.parse_args(argv[1:])

    # The script is driven entirely by options; positional args are an error.
    if len(args) != 0:
        sys.stderr.write("No arguments\n")
        parser.print_help()
        return 2

    n_trials = options.ntrials
    n_strands = options.nstrands
    try:
        lengths = [int(s) for s in options.lengths.split(',')]
    except ValueError:
        # parser.error prints usage plus this message and exits with status 2.
        parser.error('--lengths must be a comma-separated list of integers')
    if n_strands < 1:
        # Guards the division by n_strands in get_engineered_structure.
        parser.error('--nstrands must be at least 1')

    path = 'strucs'
    _ensure_dir(path)

    dir_format = os.path.join(path, '%04i')
    name_format = os.path.join(path, '%04i', '%04i.fold')

    for length in lengths:
        length_dir = dir_format % length
        print(length_dir)
        _ensure_dir(length_dir)

        for trial in range(n_trials):
            name = name_format % (length, trial)

            struc = get_engineered_structure(length, n_strands)
            constraint = 'N' * length

            with open(name, 'w') as f:
                f.write('%s\n%s\n' % (struc, constraint))

class Loop:
    """A loop node in the generated structure tree.

    A loop is an ordered list of child components; its string form is the
    concatenation of its children's strings, so the loop itself contributes
    no characters.
    """

    # Relative selection weights indexed by ``len(children) // 2``.  Loops
    # whose bucket lies past the end of this table get weight 0.0 (they are
    # never selected) instead of raising IndexError.
    _WEIGHTS = (0.1, 1.0, 0.25, 0.05)

    def __init__(self):
        self.children = []
        self.type = 'loop'

    def add_child(self, comp):
        """Append ``comp`` (anything convertible with ``str``) as the last child."""
        self.children.append(comp)

    def __str__(self):
        return ''.join([str(ch) for ch in self.children])

    def get_weight(self):
        """Return this loop's relative weight for random selection.

        The weight depends only on how many children the loop already has,
        bucketed in pairs: 0-1 children -> 0.1, 2-3 -> 1.0, 4-5 -> 0.25,
        6-7 -> 0.05, 8 or more -> 0.0.
        """
        bucket = len(self.children) // 2
        if bucket < len(self._WEIGHTS):
            return self._WEIGHTS[bucket]
        return 0.0

class Helix:
    """A helix of ``length`` stacked base pairs enclosing ``child``.

    Rendered as ``length`` opening parens, the child's string, then
    ``length`` closing parens.
    """

    def __init__(self, length, child):
        self.type = 'helix'
        self.length = length
        self.child = child

    def __str__(self):
        return ''.join(['(' * self.length, str(self.child), ')' * self.length])

class Break:
    """Strand-break marker; renders as a single ``'+'``.

    Like the other leaf components it has no child.
    """

    def __init__(self):
        self.type, self.child = 'break', None

    def __str__(self):
        marker = '+'
        return marker

class Unpaired:
    """A run of ``length`` unpaired bases; renders as that many dots.

    A leaf component, so ``child`` is always None.
    """

    def __init__(self, length):
        self.length = length
        self.type = 'unpaired'
        self.child = None

    def __str__(self):
        return self.length * '.'

def pick_loop(loops):
    """Pick one loop at random with probability proportional to its weight.

    Args:
        loops: sequence of objects exposing ``get_weight()`` that returns a
            non-negative number.

    Returns:
        The selected element of ``loops``.

    Raises:
        ValueError: if ``loops`` is empty or no loop has a positive weight.
    """
    weights = [l.get_weight() for l in loops]
    total = sum(weights)
    if not total > 0:
        raise ValueError('pick_loop: no loop has a positive weight')

    threshold = random.random() * total
    csum = np.cumsum(weights)
    hits = np.flatnonzero(csum > threshold)
    if len(hits):
        return loops[hits[0]]

    # Floating-point rounding can push ``threshold`` up to the final
    # cumulative sum, leaving no strict match; fall back to the last loop
    # that can actually be chosen.
    positive = [i for i, w in enumerate(weights) if w > 0]
    return loops[positive[-1]]

def dpp_to_plist(struc):
    """Convert a dot-parens-plus structure string into a pair list.

    Base indices count only '(' ')' '.' characters; '+' strand breaks do not
    occupy an index.

    Args:
        struc: string over the alphabet '(', ')', '.', '+'.

    Returns:
        (plist, breaks) where ``plist[i]`` is the index of base i's partner
        or -1 if base i is unpaired, and ``breaks`` lists, for each '+', the
        index of the first base after it.

    Raises:
        ValueError: on unbalanced parentheses or any other character.
    """
    stack = []
    breaks = []
    plist = []
    for pos, c in enumerate(struc):
        j = len(plist)  # index the next base would get
        if c == '(':
            stack.append(j)
            plist.append(None)  # filled in when the matching ')' is seen
        elif c == ')':
            if not stack:
                raise ValueError('unmatched ")" at position %d in %r' % (pos, struc))
            k = stack.pop()
            plist[k] = j
            plist.append(k)
        elif c == '.':
            plist.append(-1)
        elif c == '+':
            breaks.append(j)
        else:
            raise ValueError('unexpected character %r at position %d in %r'
                             % (c, pos, struc))

    if stack:
        raise ValueError('%d unmatched "(" in %r' % (len(stack), struc))

    return plist, breaks

def is_valid(struc):
    """Return True if ``struc`` satisfies every design constraint.

    ``struc`` is a dot-parens-plus string (see ``dpp_to_plist``).  It is
    valid when all of the following hold:

    1. Continuity: at every strand break, each adjacent strand end that is
       paired sits in a continuous stacked helix spanning
       ``guaranteed_continuity`` bases of that strand.
    2. Connectivity: all strands form a single complex, where two strands
       are linked only if they share more than ``guaranteed_connection``
       base pairs.
    3. Hairpins: no hairpin loop has fewer than 3 unpaired bases.
    """
    plist, breaks = dpp_to_plist(struc)
    return (_ends_are_continuous(plist, breaks)
            and _strands_connected(plist, breaks)
            and not _has_short_hairpin(struc))


def _ends_are_continuous(plist, breaks):
    """Check the continuity constraint of ``is_valid`` on a pair list.

    For the strand starting at a break, bases b, b+1, ... must pair with
    p, p-1, ...; for the strand ending at a break, bases b-1, b-2, ... must
    pair with q, q+1, ....  The window must lie inside its own strand, so a
    paired end on a strand shorter than ``guaranteed_continuity`` (or an
    empty strand) is rejected rather than indexing out of range or wrapping
    around to the end of ``plist``.
    """
    n = guaranteed_continuity
    bounds = [0] + list(breaks) + [len(plist)]
    for k, b in enumerate(breaks):
        start, end = bounds[k], bounds[k + 2]
        if b <= start or b >= end:
            # Empty strand on one side of this break (e.g. '++', or a
            # leading/trailing '+').
            return False

        # Strand beginning at the break.
        if plist[b] != -1:
            if b + n > end:
                return False
            for i in range(b, b + n - 1):
                if plist[i] != plist[i + 1] + 1:
                    return False

        # Strand ending just before the break (walked backwards).
        if plist[b - 1] != -1:
            if b - n < start:
                return False
            for i in range(b - 1, b - n, -1):
                if plist[i] != plist[i - 1] - 1:
                    return False
    return True


def _strands_connected(plist, breaks):
    """Check the connectivity constraint of ``is_valid`` on a pair list.

    Uses union-find over strands so that any chain of sufficiently linked
    strands counts as connected, regardless of strand numbering.
    """
    n_strands = len(breaks) + 1

    # identity[i] is the strand index of base i.
    identity = []
    lastm = 0
    for i, m in enumerate(breaks):
        identity += [i] * (m - lastm)
        lastm = m
    identity += [len(breaks)] * (len(plist) - lastm)

    # Each pair is seen from both ends, so the matrix is symmetric and an
    # off-diagonal entry equals the number of pairs between the two strands.
    connection_count = np.zeros((n_strands, n_strands))
    for i, j in enumerate(plist):
        if j >= 0:
            connection_count[identity[i], identity[j]] += 1

    parent = list(range(n_strands))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    for a in range(n_strands):
        for b in range(a + 1, n_strands):
            if connection_count[a, b] > guaranteed_connection:
                parent[find(a)] = find(b)

    root = find(0)
    return all(find(s) == root for s in range(n_strands))


def _has_short_hairpin(struc):
    """Return True if some '(' + dots + ')' hairpin has fewer than 3 dots."""
    run = -1  # dots since the last '(', or -1 when not in a candidate hairpin
    for c in struc:
        if c == '(':
            run = 0
        elif run >= 0 and c == '.':
            run += 1
        elif c == ')' and 0 <= run < 3:
            return True
        else:
            run = -1
    return False

def get_engineered_structure(length, n_strands):
    """Randomly generate a structure of ``length`` bases on ``n_strands`` strands.

    Repeatedly grows a random unbroken structure (see ``_grow_structure``),
    cuts it into ``n_strands`` consecutive pieces of
    ceil(length / n_strands) characters (the last piece may be shorter),
    joins them with '+', and retries until ``is_valid`` accepts the result.
    The accepted structure and its string length (including '+' characters)
    are printed before returning it.

    NOTE: there is no limit on retries; inputs for which no valid structure
    can be generated will loop forever.
    """
    # ceil(length / n_strands); '//' keeps this an integer on Python 2 and 3.
    strand_len = length // n_strands
    if length % n_strands > 0:
        strand_len += 1

    while True:
        pre_struc = _grow_structure(length)
        pieces = [pre_struc[i * strand_len:(i + 1) * strand_len]
                  for i in range(n_strands)]
        broken_struc = '+'.join(pieces)
        if is_valid(broken_struc):
            break

    print('%s %d' % (broken_struc, len(broken_struc)))
    return broken_struc


def _grow_structure(length):
    """Grow one random unbroken dot-parens structure and return its string.

    Starts from a root loop holding a random unpaired run, then repeatedly
    picks a loop (``pick_loop``) and either adds a helix enclosing a new
    loop plus unpaired runs, or -- when too little room remains for another
    helix -- pads that loop with unpaired bases.  The result has exactly
    ``length`` characters unless the initial unpaired run is already longer.
    """
    root = Loop()
    root.add_child(Unpaired(random.randint(unpaired_range[0], unpaired_range[1])))
    loops = [root]
    c_len = len(str(root))

    while c_len < length:
        c_loop = pick_loop(loops)
        max_len = length - c_len

        if (max_len - 3) // 2 < helix_range[0]:
            # No room for a minimum-length helix plus a 3-base hairpin loop.
            c_loop.add_child(Unpaired(max_len))
            c_len += max_len
            continue

        max_h_len = min((max_len - 3) // 2, helix_range[1])
        helix_length = random.randint(helix_range[0], max_h_len)
        inner = Loop()
        c_loop.add_child(Helix(helix_length, inner))

        # Unpaired run inside the new helix.
        allowed = length - (c_len + helix_length * 2)
        u_len = random.randint(min(allowed, unpaired_range[0]),
                               min(allowed, unpaired_range[1]))
        inner.add_child(Unpaired(u_len))

        # Unpaired run after the helix, in the enclosing loop.
        allowed_2 = allowed - u_len
        u_len_2 = random.randint(min(allowed_2, 0),
                                 min(allowed_2, unpaired_range[1]))
        c_loop.add_child(Unpaired(u_len_2))

        c_len += helix_length * 2 + u_len + u_len_2
        loops.append(inner)

    return str(root)


if __name__ == '__main__':
    # Hand main()'s return value to the interpreter as the exit status.
    raise SystemExit(main())
