# -*- coding: utf-8 -*-
import cvxpy
import itertools
import matplotlib.pyplot as plt
import numpy as np
import operfact
from operfact import regularizers as regs
from operfact import measurements, operators, solvers, utils
from scipy.interpolate import interp1d
from scipy.interpolate import UnivariateSpline


# Load splib06a data
# Each entry is unwrapped with .item(), i.e. it is stored as a 0-d array
# wrapping a single Python object (presumably dicts sharing the same keys --
# TODO confirm against the script that wrote the archive). Object arrays are
# pickled inside the .npz, and np.load refuses to unpickle unless
# allow_pickle=True is passed explicitly.
# NOTE(review): unpickling executes arbitrary code; only load an archive
# that was produced locally / comes from a trusted source.
with np.load('./splib06a/splib06a.npz', allow_pickle=True) as npzfile:
    labels = npzfile['labels'].item()
    names = npzfile['names'].item()
    data = npzfile['data'].item()


# Create test abundance matrix
# A is a 75x75 image with 5 abundance channels. A 5x5 grid of 5x5-pixel
# patches is filled in: patch row i mixes (i + 1) channels in equal parts,
# and each patch in that row uses a randomly drawn channel combination.
cols = [5, 20, 35, 50, 65]
rows = cols

A = np.zeros((75, 75, 5))
blk = np.ones((5,5))
for i in range(5):
    n_mixed = i + 1
    combos = list(itertools.combinations(range(5), n_mixed))
    if len(combos) == 1:
        # only one way to pick all 5 channels: repeat it so that 5 distinct
        # indices can still be drawn without replacement
        combos = combos*5
    combo_ixs = np.random.choice(len(combos), 5, replace=False)
    row_start = rows[i]
    for j, col_start in enumerate(cols):
        members = combos[combo_ixs[j]]
        for channel in members:
            A[row_start:row_start+5, col_start:col_start+5, channel] = 1.0/n_mixed

# Show the five abundance channels side by side
f, ax = plt.subplots(1, 5)
for i, panel in enumerate(ax):
    panel.imshow(A[:,:,i])
f.set_figwidth(15)

np.savez_compressed('splib_test.npz', A=A)


# Create curated spectral library
def resample_and_smooth(x, y, low, high, num):
    """Fit a smoothing spline to (x, y) and evaluate it on a uniform grid.

    Samples where either `x` or `y` is NaN are excluded from the fit. The
    inputs are not modified.

    Parameters
    ----------
    x, y : array_like
        Sample positions and sample values, of equal length.
    low, high : float
        End points (both included) of the resampling grid.
    num : int
        Number of grid points.

    Returns
    -------
    (x_resamp, y_resamp) : tuple of ndarray
        The grid ``np.linspace(low, high, num)`` and the spline evaluated on
        it. Grid points outside the fitted interval evaluate to 0
        (``ext='zeros'``).

    Raises
    ------
    Whatever `UnivariateSpline` raises for input it rejects; the exception
    is propagated unchanged.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Keep only samples where both coordinates are real numbers. NaN entries
    # are masked out here rather than overwritten, so the caller's arrays
    # (which may be views into a larger dataset) are left untouched.
    valid = ~(np.isnan(x) | np.isnan(y))
    # Cubic (k=3) smoothing spline with smoothing factor s=10.
    spline = UnivariateSpline(x[valid], y[valid], k=3, s=10, ext='zeros')
    x_resamp = np.linspace(low, high, num=num)
    y_resamp = spline(x_resamp)
    return (x_resamp, y_resamp)

# Common resampling grid: `num` evenly spaced points on [low, high]
low = 0.4
high = 2.5
num = 224

# Reserve one column per spectrum in the library. Spectra whose fit raises
# ValueError are skipped, so only the first `keep` columns end up filled and
# the unused tail is cut off after the loop.
total_spectra = sum(len(data[key]) for key in labels)
data_concat = np.zeros((num, total_spectra))
names_concat = []
keep = 0
for key in labels:
    spectra = data[key]
    for j in range(len(spectra)):
        # column 0 and column 1 of each entry are passed as x and y
        wavl = spectra[j][:,0]
        refl = spectra[j][:,1]
        try:
            wavl_resamp, refl_resamp = resample_and_smooth(wavl, refl, low, high, num)
        except ValueError:
            # best effort: drop spectra that cannot be fitted
            continue
        data_concat[:,keep] = refl_resamp
        keep += 1
        names_concat.append(names[key][j])
data_concat = data_concat[:,:keep]

# Trim
# Repeatedly locate the pair of columns of `data_concat` with the largest
# absolute correlation coefficient and delete one column of that pair: the
# one whose vector of correlations with all remaining columns has the larger
# Euclidean norm. Stops once `ncols` columns remain.
ncols = 25  # number of columns desired
cormat = np.corrcoef(data_concat, rowvar=False)
np.fill_diagonal(cormat, 0.0)  # ignore each column's correlation with itself
data_trim = data_concat.copy()
names_trim = names_concat.copy()
failsafe = 0
while data_trim.shape[1] > ncols:
    flat_ix = np.argmax(np.abs(cormat))
    max_ix = np.unravel_index(flat_ix, cormat.shape)
    sum_rows = np.linalg.norm(cormat, axis=0)
    sum_cols = np.linalg.norm(cormat, axis=1)
    assert sum_rows.shape == (cormat.shape[0],)
    if sum_rows[max_ix[0]] > sum_cols[max_ix[1]]:
        del_ix = max_ix[0]
    else:
        del_ix = max_ix[1]
    # drop the chosen column from the data, its name, and both axes of the
    # correlation matrix so that all three stay aligned
    data_trim = np.delete(data_trim, del_ix, axis=1)
    names_trim = np.delete(names_trim, del_ix)
    cormat = np.delete(np.delete(cormat, del_ix, axis=0), del_ix, axis=1)
    failsafe += 1
    if failsafe > 2000:
        # hard cap on the number of removals
        print('warn: failsafe')
        break

np.savez_compressed('splib_curate.npz', data=data_trim, names=names_trim)
