# -*- coding: utf-8 -*-
import cvxpy
import itertools
import matplotlib.pyplot as plt
import numpy as np
from operfact import regularizers as regs
from operfact import measurements, operators, solvers, utils
import time


# Abundance array for the test image, stored under key 'A'
with np.load('splib_test.npz') as npzfile:
    A = npzfile['A']

# Curated spectral library: the 'data' and 'names' arrays
with np.load('splib_curate.npz') as npzfile:
    names = npzfile['names']
    data = npzfile['data']

# The test image is generated from the first 5 library columns (endmembers)
X = data[:, :5]
alpha = np.identity(5)  # corresponding library coefficients


# Generate Figure 6.2
# Top row: the slices A[:, :, i]; bottom row: the matching columns X[:, i].
n_panels = A.shape[2]
f, ax = plt.subplots(2, n_panels)
for i in range(n_panels):
    map_ax, spec_ax = ax[0, i], ax[1, i]
    # Abundance image with tick labels hidden
    map_ax.imshow(A[:, :, i], cmap='Greys')
    map_ax.set_title('Abundance matrix {0}'.format(i + 1))
    map_ax.set_xticklabels([])
    map_ax.set_yticklabels([])
    # Endmember spectrum
    spec_ax.plot(X[:, i])
    spec_ax.set_title('Endmember {0}'.format(i + 1))
    spec_ax.set_xlabel('wavelength #')
f.set_size_inches(3 * n_panels * 1.5, 3 * 2 * 1.5)
f.savefig('figures/hsi-test-image.pdf', bbox_inches='tight', pad_inches=0)


## Generate Figure 6.3

# Initialization code
# Helper function
def db_to_sigma(db):
    """Convert an SNR in decibels into a per-entry noise standard deviation.

    Solves ``db = 10*log10(YNORM**2 / (prod(SHAPE) * sigma**2))`` for sigma,
    using the module-level ``YNORM`` and ``SHAPE`` (looked up at call time).
    """
    # Norm of the true signal spread evenly over all entries
    per_entry_norm = YNORM / np.sqrt(np.prod(SHAPE))
    # Amplitude factor corresponding to the requested decibel level
    attenuation = 10.0 ** (-db / 20.0)
    return attenuation * per_entry_norm

def altminsolve_noquad(problem):
    """Run ``solvers.altminsolve`` on ``problem`` with ``noquad=True``.

    Exists so the solver can be passed around as a one-argument callable
    (e.g. as ``params['opersolver']``).
    """
    result = solvers.altminsolve(problem, noquad=True)
    return result

# Total variation regularizer
class TVRegularizer(regs.Regularizer):
    """Regularizer that evaluates total variation through ``cvxpy.tv``."""

    def __init__(self):
        # Stored per instance on purpose: setting this as a class attribute
        # does not work (presumably because a plain function placed on the
        # class would be bound as a method on attribute access).
        self.norm_mat = cvxpy.tv

    def __repr__(self):
        return 'TV'

    def __call__(self, X):
        """Return the numeric ``.value`` of ``cvxpy.tv`` applied to ``X``."""
        tv_expr = self.norm_mat(X)
        return tv_expr.value


# Measurement operator for the spectral library case
class SpecLibMeasurement(measurements.Measurement):
    """Measurement operator for the spectral library problem.

    The operator acts on dyads ``L_r (x) r_r`` of shape ``oper_shape`` (a
    matrix left factor and a vector right factor).  Each right factor is
    mapped through the library matrix ``D`` before the outer product with the
    vectorized left factor is taken.

    Attributes set by ``__init__``:
        shape: the operator shape ``oper_shape``.
        nmeas: ``prod(oper_shape[0:2]) * D.shape[0]``.
        D: the library matrix.
    """

    def __init__(self, oper_shape, D):
        assert oper_shape[3] == 1, \
            'operator must be matrix (x) vector'  # works on matrix \otimes vector
        assert D.shape[1] == oper_shape[2], \
            'D must be compatible with the right factor'
        self.shape = oper_shape
        self.nmeas = np.prod(oper_shape[0:2]) * D.shape[0]
        self.D = D

    def apply(self, oper):
        """Return the measurements of a ``DyadsOperator`` as a flat vector.

        Computes ``sum_r vec(L_r) (D r_r)^T`` and flattens it column-major.
        Raises ``NotImplementedError`` for any other operator type.
        """
        if not isinstance(oper, operators.DyadsOperator):
            raise NotImplementedError
        return sum([np.outer(oper.lfactors[r].flatten(order='F'),
                             self.D @ oper.rfactors[r])
                    for r in range(oper.nfactors)]).flatten(order='F')

    def cvxapply(self, oper):
        """Build the cvxpy expression corresponding to ``apply``.

        NOTE(review): relies on ``*`` acting as matrix multiplication for
        cvxpy expressions and ``np.matrix`` -- confirm against the installed
        cvxpy version.
        """
        return cvxpy.vec(sum([cvxpy.vec(oper.lfactors[r]) *
                              cvxpy.vec(np.matrix(self.D) * oper.rfactors[r]).T
                              for r in range(oper.nfactors)]))

    def matapply(self, mat):
        raise NotImplementedError

    def asOperator(self):
        raise NotImplementedError

    def initfromoper(self, oper):
        raise NotImplementedError

    def initfrommeas(self, meas):
        """Map a measurement vector back to an ``ArrayOperator`` of ``self.shape``.

        ``meas`` is reshaped column-major to
        ``(shape[0], shape[1], D.shape[0])`` and contracted with ``D`` over
        the last axis, i.e. ``out[i, j, :, 0] = sum_k y[i, j, k] * D[k, :]``.
        """
        _y = meas.reshape((self.shape[0], self.shape[1], self.D.shape[0]), order='F')
        out = operators.ArrayOperator(np.zeros(self.shape))
        # One contraction replaces a Python loop over every (i, j, k) triple.
        # Assumes ArrayOperator supports NumPy slice indexing, as the
        # per-element ``out[i, j, :, 0] += ...`` form already required.
        out[:, :, :, 0] += np.tensordot(_y, self.D, axes=([2], [0]))
        return out


# Constants for experiment
# Operator shapes: the two image dimensions of A, the right-factor length, 1.
SHAPE = A.shape[0:2] + (224, 1)
SHAPE_LIB = A.shape[0:2] + (data.shape[1], 1)

# Measurement objects for the direct and the spectral-library formulations
MEASOBJ = measurements.IdentityMeasurement(SHAPE)
MEASOBJ_LIB = SpecLibMeasurement(SHAPE_LIB, data)

# Ground truth: 5 dyads pairing slice A[:, :, i] with column X[:, i]
YDYADS = operators.DyadsOperator(
    [A[:, :, i] for i in range(5)],
    [X[:, i:i + 1] for i in range(5)],
)
YTRUE = YDYADS.asArrayOperator()
YNORM = np.linalg.norm(YTRUE.flatten(), 2)  # l2 norm of the flattened truth

def get_measobj(params):
    """Return the ``(measurement object, operator shape)`` pair for ``params``.

    ``params['measurements']`` must be one of the measurement classes
    themselves (compared by identity):

    * ``measurements.IdentityMeasurement`` -> ``(MEASOBJ, SHAPE)``
    * ``SpecLibMeasurement`` -> ``(MEASOBJ_LIB, SHAPE_LIB)``

    Raises:
        ValueError: for any other value, rather than silently returning
            ``None`` and failing later when the caller unpacks the result.
    """
    meas_type = params['measurements']
    if meas_type is measurements.IdentityMeasurement:
        return (MEASOBJ, SHAPE)
    if meas_type is SpecLibMeasurement:
        return (MEASOBJ_LIB, SHAPE_LIB)
    raise ValueError('unsupported measurement type: {0!r}'.format(meas_type))

def solve_problem(params, seed=None, redirect=None):
    """Dispatch to the HSI solver stored in ``params['hsisolver']``.

    Args:
        params: experiment settings; ``params['hsisolver']`` is a callable
            taking ``(params, seed, redirect=...)``.
        seed: forwarded unchanged to the solver.
        redirect: forwarded unchanged to the solver (it was previously
            dropped and replaced by ``None``).

    Returns:
        Whatever the selected solver returns.
    """
    hsisolver = params['hsisolver']
    return hsisolver(params, seed, redirect=redirect)


## HSI solvers
# Denoise using operator norms
def opersolve(params, seed=None, redirect=None):
    """Denoise the test image with an operator-norm regularized solve.

    Args:
        params: dict read for the keys 'snr', 'measurements', 'reg',
            'offset', 'inner_iters', 'relconvtol', 'outer_iters', 'rank',
            'opersolver' and 'hsisolver'.
        seed: seed for the global NumPy RNG; the current time in whole
            seconds is used when it is None.
        redirect: accepted for signature compatibility; not used here.

    Returns:
        A list with one result dict per solver output.

    NOTE(review): the measurements are always generated from YDYADS and the
    error is always measured against YTRUE; whether that is shape-compatible
    with SpecLibMeasurement should be confirmed.
    """
    seed = int(time.time()) if seed is None else seed
    np.random.seed(seed=seed)

    # Gaussian noise at the requested SNR
    sigma = db_to_sigma(params['snr'])
    noise = sigma * np.random.normal(size=np.prod(SHAPE))

    # Assemble the problem description
    (measobj, shape) = get_measobj(params)
    prob = solvers.Problem()
    prob.shape = shape
    prob.measurementobj = measobj
    prob.measurementvec = prob.measurementobj.apply(YDYADS) + noise
    prob.norm = params['reg']
    # The penalty constant is always derived from SHAPE, then scaled by 2**offset
    base_penconst = regs.penconst_denoise(SHAPE, sigma, prob.norm)
    prob.penconst = base_penconst * (2 ** params['offset'])
    prob.solveropts = {'verbose': False, 'warm_start': True, 'max_iters': params['inner_iters'], 'eps': 1e-3}
    prob.solver = cvxpy.SCS
    prob.relconvergetol = params['relconvtol']
    prob.maxiters = params['outer_iters']
    prob.rank = params['rank']

    # Solve; a lone result is wrapped so both cases are handled by one loop
    results = params['opersolver'](prob)
    if not utils.istol(results):
        results = [results, ]

    rows = []
    for out in results:
        residual = (out.recovered.asArrayOperator().flatten(order='F')
                    - YTRUE.flatten(order='F'))
        abs_error = np.linalg.norm(residual, ord=2)
        rel_error = abs_error / YNORM
        rows.append({'seed': seed,
                     'snr': params['snr'],
                     'measurements': type(prob.measurementobj).__name__,
                     'regularizer': str(prob.norm),
                     'penconst': prob.penconst,
                     'penconst_offset': params['offset'],
                     'solve_rank': prob.rank,
                     'hsisolver': params['hsisolver'].__name__,
                     'opersolver': params['opersolver'].__name__,
                     'max_outer_iters': out.maxiters,
                     'max_inner_iters': params['inner_iters'],
                     'relconvergetol': out.relconvtol,
                     'relchange': out.relchange,
                     'abs_error': abs_error,
                     'rel_error': rel_error,
                     'rsdr': -20.0 * np.log10(rel_error),
                     'time': out.total_time,
                     'outer_iters': out.outer_iters,
                     'recovered': out.recovered})
    return rows


# Denoise using SVD
def kpsvdsolve(params, seed=None, redirect=None):
    """Denoise the test image by truncating its KP-SVD.

    Args:
        params: dict read for the keys 'measurements' (must be
            ``measurements.IdentityMeasurement``), 'snr', 'rank' and
            'hsisolver'.  'rank' may be a single integer (as used by
            ``opersolve``) or an iterable of integers; one row is produced
            per rank from a single decomposition.
        seed: seed for the global NumPy RNG; the current time in whole
            seconds is used when it is None.
        redirect: accepted for signature compatibility; not used here.

    Returns:
        A list of result dicts with the same keys as ``opersolve`` rows.
        'time' covers only the KP-SVD of the noisy array.
    """
    assert params['measurements'] is measurements.IdentityMeasurement
    if seed is None:
        seed = int(time.time())
    np.random.seed(seed=seed)

    # Accept a bare integer rank as well as an iterable of ranks.
    ranks = params['rank']
    if isinstance(ranks, (int, np.integer)):
        ranks = [ranks]

    # Generate noise
    sigma = db_to_sigma(params['snr'])
    noise = np.random.normal(size=SHAPE)*sigma
    Ynoisy = YTRUE + noise

    # Compute KP SVDs (only the noisy decomposition is timed)
    Utrue, Strue, Vtrue = operators.kpsvd(YTRUE)
    time_start = time.perf_counter()  # monotonic clock for interval timing
    U, S, V = operators.kpsvd(Ynoisy)
    time_end = time.perf_counter()

    # Reference reconstruction; it does not depend on k, so build it once.
    # NOTE(review): assumes kpsvd returns SHAPE[2] singular values -- confirm.
    Ytrue_mat = Utrue[:, 0:SHAPE[2]] @ np.diag(Strue) @ Vtrue[0:SHAPE[2], :]

    rows = []
    for k in ranks:
        # Leading k dyads: scaled left vectors reshaped column-major to the
        # image shape, right vectors reshaped to SHAPE[2:4].
        recovered = operators.DyadsOperator([(U[:,r]*S[r]).reshape(SHAPE[0:2], order='F') for r in range(k)],
                                            [V[r, :].T.reshape(SHAPE[2:4], order='F') for r in range(k)])
        Yk_mat = U[:, 0:k] @ np.diag(S[0:k]) @ V[0:k, :]
        abs_error = np.linalg.norm(Ytrue_mat - Yk_mat, ord='fro')
        rel_error = abs_error/YNORM
        rsdr = -20.0*np.log10(rel_error)
        row = {'seed': seed,
               'snr': params['snr'],
               'measurements': params['measurements'].__name__,
               'regularizer': 'KP-SVD',
               'penconst': None,
               'penconst_offset': None,
               'solve_rank': k,
               'hsisolver': params['hsisolver'].__name__,
               'opersolver': None,
               'max_outer_iters': None,
               'max_inner_iters': None,
               'relconvergetol': None,
               'relchange': None,
               'abs_error': abs_error,
               'rel_error': rel_error,
               'rsdr': rsdr,
               'time': time_end - time_start,
               'outer_iters': None,
               'recovered': recovered}
        rows.append(row)
    return rows

# Generate Figure 6.3[top]
# KP-SVD of the noiseless array; its leading terms are plotted as dyads.
Utrue, Strue, Vtrue = operators.kpsvd(YTRUE)
nfactors = 5
print('nfactors:', nfactors)
k = nfactors
recovered = operators.DyadsOperator(
    [(Strue[r] * Utrue[:, r]).reshape(SHAPE[:2], order='F') for r in range(k)],
    [Vtrue[r, :].T.reshape(SHAPE[2:], order='F') for r in range(k)],
)
nfactors = recovered.nfactors
print("Singular values:", Strue[0:5])

# One column per dyad: left factor as an image, right factor as a curve
f, ax = plt.subplots(2, nfactors)
for i in range(nfactors):
    img_ax, line_ax = ax[0, i], ax[1, i]
    img_ax.imshow(recovered.lfactors[i], cmap='Greys')
    img_ax.set_title('$X_{0}$'.format(i + 1))
    img_ax.set_xticklabels([])
    img_ax.set_yticklabels([])
    line_ax.plot(recovered.rfactors[i])
    line_ax.set_title('$y_{0}$'.format(i + 1))
    line_ax.set_xlabel('wavelength #')
f.set_size_inches(3 * nfactors * 1.5, 3 * 2 * 1.5)
f.savefig('figures/hsi-test-kpsvd-rank{0}.pdf'.format(nfactors), bbox_inches='tight', pad_inches=0)


# Generate Figure 6.3[middle]
def getnoisy(snr, seed):
    """Return ``YTRUE`` plus seeded Gaussian noise at SNR ``snr`` (in dB).

    Reseeds the global NumPy RNG with ``seed`` before drawing the noise, so
    the result is reproducible for a given ``(snr, seed)`` pair.
    """
    np.random.seed(seed=seed)
    noise_std = db_to_sigma(snr)
    return YTRUE + noise_std * np.random.normal(size=SHAPE)

# KP-SVD of the noisy array (SNR 10 dB, seed 1)
Unoisy, Snoisy, Vnoisy = operators.kpsvd(getnoisy(10.0, 1))
nfactors = 5
print('nfactors:', nfactors)
k = nfactors
recovered = operators.DyadsOperator(
    [(Snoisy[r] * Unoisy[:, r]).reshape(SHAPE[:2], order='F') for r in range(k)],
    [Vnoisy[r, :].T.reshape(SHAPE[2:], order='F') for r in range(k)],
)
nfactors = recovered.nfactors

# One column per dyad: left factor as an image, right factor as a curve
f, ax = plt.subplots(2, nfactors)
for i in range(nfactors):
    img_ax, line_ax = ax[0, i], ax[1, i]
    img_ax.imshow(recovered.lfactors[i], cmap='Greys')
    img_ax.set_title('$X_{0}$'.format(i + 1))
    img_ax.set_xticklabels([])
    img_ax.set_yticklabels([])
    line_ax.plot(recovered.rfactors[i])
    line_ax.set_title('$y_{0}$'.format(i + 1))
    line_ax.set_xlabel('wavelength #')
f.set_size_inches(3 * nfactors * 1.5, 3 * 2 * 1.5)
f.savefig('figures/hsi-kpsvdsolve-rank{0}.pdf'.format(nfactors), bbox_inches='tight', pad_inches=0)


# Generate Figure 6.3[bottom]
# Baseline solver settings; the regularizer, rank and offset are set below.
PARAMS = {
    'snr': 10.0,
    'reg': None,
    'offset': 0,
    'relconvtol': 1e-3,
    'outer_iters': 10,
    'inner_iters': 2500,
    'rank': 0,
    'hsisolver': opersolve,
    'opersolver': altminsolve_noquad,
    'measurements': measurements.IdentityMeasurement,
}

# Settings specific to this figure
PARAMS.update(reg=regs.NucNorm_Prod(cvxpy.tv, regs.norm_l2), rank=5, offset=-12)
row = opersolve(PARAMS, seed=1)[0]

print('reg:', row['regularizer'], 'rank:', row['solve_rank'], 'rsdr:', row['rsdr'])
recovered = row['recovered']
nfactors = recovered.nfactors

# One column per dyad: left factor as an image, right factor as a curve
f, ax = plt.subplots(2, nfactors)
for i in range(nfactors):
    img_ax, line_ax = ax[0, i], ax[1, i]
    img_ax.imshow(recovered.lfactors[i], cmap='Greys')
    img_ax.set_title('$X_{0}$'.format(i + 1))
    img_ax.set_xticklabels([])
    img_ax.set_yticklabels([])
    line_ax.plot(recovered.rfactors[i])
    line_ax.set_title('$y_{0}$'.format(i + 1))
    line_ax.set_xlabel('wavelength #')
f.set_size_inches(3 * nfactors * 1.5, 3 * 2 * 1.5)
f.savefig('figures/hsi-opersolve-tvl2-rank{0}.pdf'.format(nfactors), bbox_inches='tight', pad_inches=0)
