# -*- coding: utf-8 -*-
import cvxpy
import itertools
import numpy as np
from operfact import regularizers as regs
from operfact import measurements, operators, solvers, utils
import time


## Load data
# Abundance matrix (the same one is reused by every experiment below)
with np.load('splib_test.npz') as npzfile:
    A = npzfile['A']

# Curated spectral library; each column is one library entry (endmember)
with np.load('splib_curate.npz') as npzfile:
    data = npzfile['data']

# Synthetic test image: the first five library columns serve as endmembers,
# so the matching library coefficients form the 5x5 identity.
X = data[:, :5]
alpha = np.identity(5)


## Experiment setup and functions
# Helper function
def db_to_sigma(db, ynorm=None, shape=None):
    """Convert a target SNR in decibels into a per-entry Gaussian noise level.

    ``sigma`` is chosen so that
    ``db == 10*log10(ynorm**2 / (prod(shape) * sigma**2))``, i.e. the ratio of
    the clean signal's energy to the expected energy of i.i.d. N(0, sigma^2)
    noise added to every entry of an array of the given shape.

    Parameters
    ----------
    db : float
        Desired signal-to-noise ratio in dB.
    ynorm : float, optional
        Euclidean norm of the clean signal. Defaults to the module-level
        ``YNORM``.
    shape : tuple of int, optional
        Shape of the array the noise is added to. Defaults to the
        module-level ``SHAPE``.

    Returns
    -------
    float
        Standard deviation of the per-entry noise.
    """
    if ynorm is None:
        ynorm = YNORM
    if shape is None:
        shape = SHAPE
    return 10.0**(-db/20.0)*(ynorm/np.sqrt(np.prod(shape)))

def altminsolve_noquad(problem):
    """Run ``solvers.altminsolve`` on *problem* with ``noquad=True``.

    Result rows record ``params['opersolver'].__name__``, so this wrapper's
    name is what identifies the variant in the output.
    """
    solved = solvers.altminsolve(problem, noquad=True)
    return solved

# Total variation regularizer
class TVRegularizer(regs.Regularizer):
    """Total-variation regularizer backed by ``cvxpy.tv``.

    ``repr()`` gives ``'TV'``. Calling an instance on an array returns the
    numeric value of ``cvxpy.tv`` evaluated on it.
    """

    def __init__(self):
        # Stored on the instance: the original author notes that assigning
        # cvxpy.tv as a class attribute does not work (presumably because a
        # plain function on the class would be bound as a method).
        self.norm_mat = cvxpy.tv

    def __repr__(self):
        return 'TV'

    def __call__(self, X):
        tv_expression = self.norm_mat(X)
        return tv_expression.value

# Measurement operator for the spectral library case
class SpecLibMeasurement(measurements.Measurement):
    """Measurement operator for the spectral library problem.

    Acts on operators of shape ``(m, n, p, 1)`` (a matrix tensored with a
    vector). For a dyadic operator ``sum_r L_r (x) R_r`` the measurement is
    the Fortran-order vectorization of the ``(m*n, q)`` matrix
    ``sum_r vec(L_r) (D @ R_r)^T``, where ``D`` has shape ``(q, p)``.
    """

    def __init__(self, oper_shape, D):
        """
        Parameters
        ----------
        oper_shape : tuple of int
            Operator shape ``(m, n, p, 1)``.
        D : numpy.ndarray
            Matrix of shape ``(q, p)`` applied to every right factor.
        """
        assert oper_shape[3] == 1  # works on matrix \otimes vector
        assert D.shape[1] == oper_shape[2]  # need D to be compatible with right factor
        self.shape = oper_shape
        # Number of measurements: one length-q vector per (i, j) position.
        self.nmeas = np.prod(oper_shape[0:2])*D.shape[0]
        self.D = D

    def apply(self, oper):
        """Return the length-``nmeas`` measurement vector of *oper*.

        Raises
        ------
        NotImplementedError
            If *oper* is not an ``operators.DyadsOperator``.
        """
        if isinstance(oper, operators.DyadsOperator):
            return sum([np.outer(oper.lfactors[r].flatten(order='F'),
                                 self.D @ oper.rfactors[r])
                        for r in range(oper.nfactors)]).flatten(order='F')
        else:
            raise NotImplementedError

    def cvxapply(self, oper):
        """CVXPY-expression counterpart of ``apply``.

        Relies on ``*`` between CVXPY expressions meaning matrix
        multiplication, so ``vec(L_r) * vec(D R_r).T`` is an outer product.
        """
        return cvxpy.vec(sum([cvxpy.vec(oper.lfactors[r]) *
                              cvxpy.vec(np.matrix(self.D)*oper.rfactors[r]).T
                              for r in range(oper.nfactors)]))

    def matapply(self, mat):
        raise NotImplementedError

    def asOperator(self):
        raise NotImplementedError

    def initfromoper(self, oper):
        raise NotImplementedError

    def initfrommeas(self, meas):
        """Apply the adjoint of ``apply`` to a measurement vector.

        *meas* is reshaped (Fortran order) to ``(m, n, q)`` and contracted with
        ``D`` over its last axis:
        ``out[i, j, :, 0] = sum_k y[i, j, k] * D[k, :]``.

        Returns
        -------
        operators.ArrayOperator
            Array of shape ``self.shape``.
        """
        _y = meas.reshape((self.shape[0], self.shape[1], self.D.shape[0]), order='F')
        out = operators.ArrayOperator(np.zeros(self.shape))
        # (m, n, q) @ (q, p) -> (m, n, p): the vectorized form of the triple
        # loop over (i, j, k), without O(m*n*q) Python-level iterations.
        out[:, :, :, 0] = _y @ self.D
        return out

# test the initfrommeas (adjoint) function
def test_adjoint():
    """Check that ``initfrommeas`` is the adjoint of ``apply``.

    For random dyads ``x`` and a random measurement vector ``y`` the inner
    products ``<apply(x), y>`` and ``<x, initfrommeas(y)>`` must agree.
    """
    op_shape = (8, 8, 4, 1)
    dict_shape = (8, op_shape[2])
    D = np.random.normal(size=dict_shape)
    n_dyads = 2
    lfacs = [np.random.normal(size=op_shape[:2]) for _ in range(n_dyads)]
    rfacs = [np.random.normal(size=op_shape[2:]) for _ in range(n_dyads)]
    dyads = operators.DyadsOperator(lfacs, rfacs)
    meas = SpecLibMeasurement(op_shape, D)
    forward = meas.apply(dyads)
    y = np.random.normal(size=op_shape[0] * op_shape[1] * dict_shape[0])
    lhs = np.sum(forward * y)
    rhs = np.sum(dyads.asArrayOperator() * meas.initfrommeas(y))
    assert np.isclose(lhs, rhs)

# test that apply and cvx apply agree
def test_apply_cvxapply():
    """Check that ``apply`` and ``cvxapply`` yield the same measurement vector.

    Both results are flattened to 1-D and compared element-wise, in order,
    after an explicit length check. A permuted or mis-sized CVXPY result
    therefore fails. (Counting ``isclose`` hits could be fooled by
    broadcasting when the shapes differ.)
    """
    shape = (8, 8, 4, 1)
    D_shape = (8, shape[2])
    D = np.random.normal(size=D_shape)
    ndyads = 2
    left = [np.random.normal(size=shape[0:2]) for _ in range(ndyads)]
    right = [np.random.normal(size=shape[2:4]) for _ in range(ndyads)]
    dyop = operators.DyadsOperator(left, right)
    meas = SpecLibMeasurement(shape, D)
    expected = np.asarray(meas.apply(dyop)).ravel()
    lparams = [cvxpy.Parameter(*shape[0:2]) for _ in range(ndyads)]
    rparams = [cvxpy.Parameter(*shape[2:4]) for _ in range(ndyads)]
    for lparam, rparam, lval, rval in zip(lparams, rparams, left, right):
        lparam.value = lval
        rparam.value = rval
    dyopcvx = operators.DyadsOperator(lparams, rparams)
    # cvxpy may return an np.matrix column; normalize to a flat ndarray.
    actual = np.asarray(meas.cvxapply(dyopcvx).value).ravel()
    assert actual.shape == (meas.nmeas,)
    assert np.allclose(actual, expected)

# Constants for experiment
# The band count is taken from the library (rows of `data`) rather than
# hard-coded. The right factors X[:, i:i+1] below have data.shape[0] rows, and
# the noise in every solver is drawn at SHAPE, so the two must agree.
SHAPE = (A.shape[0], A.shape[1], data.shape[0], 1)
SHAPE_LIB = (A.shape[0], A.shape[1], data.shape[1], 1)
MEASOBJ = measurements.IdentityMeasurement(SHAPE)
MEASOBJ_LIB = SpecLibMeasurement(SHAPE_LIB, data)
# Ground truth: sum over the 5 endmembers of abundance map (x) spectrum
YDYADS = operators.DyadsOperator([A[:, :, i] for i in range(5)], [X[:, i:i+1] for i in range(5)])
YTRUE = YDYADS.asArrayOperator()
YNORM = np.linalg.norm(YTRUE.flatten(), ord=2)

def get_measobj(params):
    """Return the ``(measurement object, operator shape)`` pair for an experiment.

    Parameters
    ----------
    params : dict
        Experiment parameters. ``params['measurements']`` must be the class
        ``measurements.IdentityMeasurement`` or ``SpecLibMeasurement``
        (compared by identity).

    Returns
    -------
    tuple
        ``(MEASOBJ, SHAPE)`` or ``(MEASOBJ_LIB, SHAPE_LIB)``.

    Raises
    ------
    ValueError
        If ``params['measurements']`` is not one of the supported classes.
    """
    meas_type = params['measurements']
    if meas_type is measurements.IdentityMeasurement:
        return (MEASOBJ, SHAPE)
    if meas_type is SpecLibMeasurement:
        return (MEASOBJ_LIB, SHAPE_LIB)
    raise ValueError('unsupported measurement type: {!r}'.format(meas_type))

def solve_problem(params, seed=None, redirect=None):
    """Dispatch to the HSI solver stored in ``params['hsisolver']``.

    Parameters
    ----------
    params : dict
        Experiment parameters; ``params['hsisolver']`` is called as
        ``hsisolver(params, seed, redirect=redirect)``.
    seed : int, optional
        Forwarded unchanged to the solver.
    redirect : optional
        Forwarded unchanged to the solver (previously it was dropped and
        ``None`` was always passed).

    Returns
    -------
    Whatever the selected solver returns.
    """
    return params['hsisolver'](params, seed, redirect=redirect)

## HSI solvers
# Denoise using operator norms
def opersolve(params, seed=None, redirect=None):
    """Denoise the synthetic image with an operator-regularized solver.

    Seeds NumPy's global RNG, adds N(0, sigma^2) noise (sigma from
    ``db_to_sigma(params['snr'])``) to the measurements of ``YDYADS``, builds
    a ``solvers.Problem`` and runs ``params['opersolver']`` on it.

    Parameters
    ----------
    params : dict
        Reads ``'snr'``, ``'measurements'`` (via ``get_measobj``), ``'reg'``,
        ``'offset'``, ``'inner_iters'``, ``'relconvtol'``, ``'outer_iters'``,
        ``'rank'``, ``'opersolver'`` and ``'hsisolver'``.
    seed : int, optional
        RNG seed; defaults to ``int(time.time())``.
    redirect : optional
        Unused; accepted for interface compatibility.

    Returns
    -------
    list of dict
        One result row per output returned by the operator solver.
    """
    if seed is None:
        seed = int(time.time())
    np.random.seed(seed=seed)

    # Noise is drawn with prod(SHAPE) entries regardless of measurement type.
    sigma = db_to_sigma(params['snr'])
    noise = np.random.normal(size=(np.prod(SHAPE))) * sigma

    measobj, oper_shape = get_measobj(params)
    prob = solvers.Problem()
    prob.shape = oper_shape
    prob.measurementobj = measobj
    prob.measurementvec = prob.measurementobj.apply(YDYADS) + noise
    prob.norm = params['reg']
    # The penalty constant is always computed for SHAPE.
    prob.penconst = regs.penconst_denoise(SHAPE, sigma, prob.norm) * (2 ** params['offset'])
    prob.solveropts = dict(verbose=False, warm_start=True,
                           max_iters=params['inner_iters'], eps=1e-3)
    prob.solver = cvxpy.SCS
    prob.relconvergetol = params['relconvtol']
    prob.maxiters = params['outer_iters']
    prob.rank = params['rank']

    results = params['opersolver'](prob)
    if not utils.istol(results):
        results = [results, ]

    ytrue_vec = YTRUE.flatten(order='F')
    rows = []
    for res in results:
        recovered_vec = res.recovered.asArrayOperator().flatten(order='F')
        abs_error = np.linalg.norm(recovered_vec - ytrue_vec, ord=2)
        rel_error = abs_error/YNORM
        rows.append({'seed': seed,
                     'snr': params['snr'],
                     'measurements': type(prob.measurementobj).__name__,
                     'regularizer': str(prob.norm),
                     'penconst': prob.penconst,
                     'penconst_offset': params['offset'],
                     'solve_rank': prob.rank,
                     'hsisolver': params['hsisolver'].__name__,
                     'opersolver': params['opersolver'].__name__,
                     'max_outer_iters': res.maxiters,
                     'max_inner_iters': params['inner_iters'],
                     'relconvergetol': res.relconvtol,
                     'relchange': res.relchange,
                     'abs_error': abs_error,
                     'rel_error': rel_error,
                     'rsdr': -20.0*np.log10(rel_error),
                     'time': res.total_time,
                     'outer_iters': res.outer_iters})
    return rows

# Denoise bands individually
def indivsolve(params, seed=None, redirect=None):
    """Denoise each band of the noisy image independently as a matrix problem.

    Every slice ``Ynoisy[:, :, k, 0]`` is solved with ``solvers.matsolve``
    (SCS) as an operator of shape ``(SHAPE[0], 1, SHAPE[1], 1)``. The reported
    time covers the per-band loop only.

    Parameters
    ----------
    params : dict
        Reads ``'measurements'`` (must be ``measurements.IdentityMeasurement``),
        ``'snr'``, ``'reg'``, ``'offset'``, ``'inner_iters'`` and
        ``'hsisolver'``.
    seed : int, optional
        RNG seed; defaults to ``int(time.time())``.
    redirect : optional
        Unused; accepted for interface compatibility.

    Returns
    -------
    list of dict
        A single result row.
    """
    assert params['measurements'] is measurements.IdentityMeasurement
    if seed is None:
        seed = int(time.time())
    np.random.seed(seed=seed)

    sigma = db_to_sigma(params['snr'])
    Ynoisy = YTRUE + np.random.normal(size=SHAPE)*sigma

    band_shape = (SHAPE[0], 1, SHAPE[1], 1)
    Yrecon = np.empty(SHAPE)
    time_start = time.time()
    for band in range(SHAPE[2]):
        prob = solvers.Problem()
        prob.shape = band_shape
        prob.measurementobj = measurements.IdentityMeasurement(band_shape)
        prob.measurementvec = np.array(Ynoisy[:, :, band, 0].flatten(order='F'))
        prob.norm = params['reg']
        prob.penconst = regs.penconst_denoise(band_shape, sigma, prob.norm)*(2**params['offset'])
        prob.solveropts = {'verbose': False, 'max_iters': params['inner_iters']}
        prob.solver = cvxpy.SCS
        solution = solvers.matsolve(prob)
        Yrecon[:, :, band, 0] = solution.recovered[:, 0, :, 0]
    time_end = time.time()

    abs_error = np.linalg.norm(Yrecon.flatten(order='F') - YTRUE.flatten(order='F'), ord=2)
    rel_error = abs_error/YNORM
    summary = {'seed': seed,
               'snr': params['snr'],
               'measurements': type(prob.measurementobj).__name__,
               'regularizer': str(prob.norm),
               'penconst': prob.penconst,
               'penconst_offset': params['offset'],
               'solve_rank': None,
               'hsisolver': params['hsisolver'].__name__,
               'opersolver': 'mat_SCS',
               'max_outer_iters': None,
               'max_inner_iters': params['inner_iters'],
               'relconvergetol': None,
               'relchange': None,
               'abs_error': abs_error,
               'rel_error': rel_error,
               'rsdr': -20.0*np.log10(rel_error),
               'time': time_end - time_start,
               'outer_iters': None}
    return [summary, ]


# Denoise using SVD
def kpsvdsolve(params, seed=None, redirect=None):
    """Denoise by truncating the KP-SVD (``operators.kpsvd``) of the noisy image.

    For each rank ``k`` in ``params['rank']``, the first ``k`` noisy KP-SVD
    terms are compared (Frobenius norm) against the reconstruction from the
    clean image's KP-SVD. Only the noisy decomposition is timed.

    Parameters
    ----------
    params : dict
        Reads ``'measurements'`` (must be ``measurements.IdentityMeasurement``),
        ``'snr'``, ``'rank'`` (an iterable of ranks) and ``'hsisolver'``.
    seed : int, optional
        RNG seed; defaults to ``int(time.time())``.
    redirect : optional
        Unused; accepted for interface compatibility.

    Returns
    -------
    list of dict
        One result row per requested rank.
    """
    assert params['measurements'] is measurements.IdentityMeasurement
    if seed is None:
        seed = int(time.time())
    np.random.seed(seed=seed)

    sigma = db_to_sigma(params['snr'])
    Ynoisy = YTRUE + np.random.normal(size=SHAPE)*sigma

    Utrue, Strue, Vtrue = operators.kpsvd(YTRUE)
    tic = time.time()
    U, S, V = operators.kpsvd(Ynoisy)
    elapsed = time.time() - tic

    # The clean reconstruction does not depend on k; build it once.
    nkeep = SHAPE[2]
    reference = Utrue[:, 0:nkeep] @ np.diag(Strue) @ Vtrue[0:nkeep, :]

    rows = []
    for k in params['rank']:
        truncated = U[:, 0:k] @ np.diag(S[0:k]) @ V[0:k, :]
        abs_error = np.linalg.norm(reference - truncated, ord='fro')
        rel_error = abs_error/YNORM
        rows.append({'seed': seed,
                     'snr': params['snr'],
                     'measurements': params['measurements'].__name__,
                     'regularizer': 'KP-SVD',
                     'penconst': None,
                     'penconst_offset': None,
                     'solve_rank': k,
                     'hsisolver': params['hsisolver'].__name__,
                     'opersolver': None,
                     'max_outer_iters': None,
                     'max_inner_iters': None,
                     'relconvergetol': None,
                     'relchange': None,
                     'abs_error': abs_error,
                     'rel_error': rel_error,
                     'rsdr': -20.0*np.log10(rel_error),
                     'time': elapsed,
                     'outer_iters': None})
    return rows


# Denoise using Spa+Lr (sparse coding over a learned dictionary + low-rank)
def splrsolve(params, seed=None, redirect=None):
    """Implement the Spa+Lr HSI denoiser from Zhao & Yang.

    The image is matricized to shape ``(m*n, bands)``. Each iteration then:
    (1) sparse-codes 8x8 patches of the current estimate against a learned
    dictionary and reassembles them; (2) soft-thresholds the singular values
    of the current estimate; (3) sets the new estimate to a weighted average
    of the noisy data, the patch reconstruction and the low-rank estimate.
    Progress and metrics are printed every iteration.

    The dictionary is learned on the *clean* image. This simplifies testing
    but is not a practical denoiser.

    Parameters
    ----------
    params : dict
        Reads ``'measurements'`` (must be ``measurements.IdentityMeasurement``),
        ``'snr'``, ``'outer_iters'`` (number of iterations, >= 1) and
        ``'hsisolver'``.
    seed : int, optional
        RNG seed; defaults to ``int(time.time())``.
    redirect : optional
        Unused; accepted for interface compatibility.

    Returns
    -------
    list of dict
        A single row reporting the lowest absolute/relative error and the
        highest RSDR reached over all iterations.

    Raises
    ------
    ValueError
        If ``params['outer_iters'] < 1``.
    """
    def mean_metrics(im_true, im_test):
        """Return mean per-band PSNR and multichannel SSIM of two images.

        Both inputs are reshaped (C order) to ``YTRUE.shape`` first.
        """
        from skimage.measure import compare_psnr, compare_ssim
        im_max = np.max(im_true)
        im_min = np.min(im_true)
        print(im_max, im_min)
        dynamic_range = 1  # forced for this image (should probably do normalization for all float images)
        true_reshape = im_true.reshape(YTRUE.shape, order='C')
        test_reshape = im_test.reshape(YTRUE.shape, order='C')
        psnrs = [compare_psnr(true_reshape[:,:,i], test_reshape[:,:,i], dynamic_range=dynamic_range)
                 for i in range(YTRUE.shape[2])]
        mssim = compare_ssim(true_reshape, test_reshape, dynamic_range=dynamic_range, multichannel=True)
        return {'mpsnr': np.mean(psnrs), 'mssim': mssim}

    assert params['measurements'] is measurements.IdentityMeasurement
    k_max = params['outer_iters']
    if k_max < 1:
        # The summary row reduces over per-iteration errors, which needs at
        # least one iteration; fail before the expensive dictionary learning.
        raise ValueError('outer_iters must be at least 1')
    if seed is None:
        seed = int(time.time())
    np.random.seed(seed=seed)
    # Import from sklearn
    from sklearn.decomposition import MiniBatchDictionaryLearning
    from sklearn.feature_extraction.image import extract_patches_2d
    from sklearn.feature_extraction.image import reconstruct_from_patches_2d
    # Flatten to obtain matricized HSI: one row per (i, j) position
    Ymat = YTRUE.reshape((YTRUE.shape[0]*YTRUE.shape[1], YTRUE.shape[2]), order='C')
    Ymat_vec = Ymat.flatten(order='F')  # reused by every error evaluation
    # Add noise and report the baseline (noisy) quality
    sigma = db_to_sigma(params['snr'])
    noise = np.random.normal(size=Ymat.shape)*sigma
    Ynoisy = Ymat + noise
    noisy_abs_error = np.linalg.norm(Ynoisy.flatten(order='F') - Ymat_vec, ord=2)
    noisy_rsdr = -20.0 * np.log10(noisy_abs_error / YNORM)
    mets = mean_metrics(Ymat, Ynoisy)
    print('orignal sdr:', noisy_rsdr, 'original mpsnr:', mets['mpsnr'], 'original mssim:', mets['mssim'])
    # Set parameters
    gamma = (30.0/sigma)/(10.0**2.5)  # positing that gamma is too large
    lamb = 100
    mu = lamb*np.max(np.sqrt(Ymat.shape))*sigma/6.5
    print('gamma:', gamma, 'lambda:', lamb, 'mu:', mu, 'mu/lambda:', mu/lamb)
    # Learn the dictionary on the true image (simplifies our testing; not practical)
    print('learning dictionary')
    time_start = time.time()
    patch_size = (8, 8)
    patches = extract_patches_2d(Ymat, patch_size)
    patches = patches.reshape(patches.shape[0], -1)
    patches -= np.mean(patches, axis=0)
    # NOTE(review): training patches are standardized (divided by std) but the
    # patches coded in the loop are only centered -- confirm this is intended.
    patches /= np.std(patches, axis=0)
    Dobj = MiniBatchDictionaryLearning(n_components=128, alpha=1, n_iter=500).fit(patches)
    # Alternating optimization
    # Patch-term weight: assumes every pixel is covered by the same number of patches.
    rho = np.prod(patch_size)
    X_est = Ynoisy
    abs_errors = np.zeros((k_max,))
    rel_errors = np.zeros((k_max,))
    rsdrs = np.zeros((k_max,))
    for k in range(k_max):
        print('iteration', k+1)
        # Get patches of the current estimate
        print('extracting noisy patches')
        patches = extract_patches_2d(X_est, patch_size)
        patches = patches.reshape(patches.shape[0], -1)
        patches_mean = np.mean(patches, axis=0)
        patches -= patches_mean
        # Sparse coding
        print('sparse coding')
        alphas = Dobj.transform(patches)
        print('reconstructing patches')
        patches = np.dot(alphas, Dobj.components_)
        patches += patches_mean
        patches = patches.reshape(len(patches), *patch_size)
        print('reconstructing image')
        Xrecon = reconstruct_from_patches_2d(patches, Ymat.shape)
        # Low-rank estimate via singular value soft-thresholding
        print('SVD')
        W, S, Vt = np.linalg.svd(X_est, full_matrices=0)
        Sst = np.sign(S)*np.maximum(np.abs(S) - (mu/lamb), 0)  # soft-thresh
        U = (W * Sst) @ Vt
        # Update the estimate as a weighted average of the three terms
        print('update X')
        X_est = (gamma*Ynoisy + rho*Xrecon + lamb*U)/(gamma + rho + lamb)
        # Error of the current iterate (not the noisy baseline)
        abs_errors[k] = np.linalg.norm(X_est.flatten(order='F') - Ymat_vec, ord=2)
        rel_errors[k] = abs_errors[k] / YNORM
        rsdrs[k] = -20.0 * np.log10(rel_errors[k])
        mets = mean_metrics(Ymat, X_est)
        print('time:', time.time()-time_start, 'rsdr:', rsdrs[k], 'mpsnr:', mets['mpsnr'], 'mssim:', mets['mssim'])
        # NOTE: the paper divides mu by 10 each iteration; that did not appear
        # helpful here, so mu is kept fixed.
    time_end = time.time()
    # Create output
    row = {'seed': seed,
           'snr': params['snr'],
           'measurements': params['measurements'].__name__,
           'regularizer': None,
           'penconst': None,
           'penconst_offset': None,
           'solve_rank': None,
           'hsisolver': params['hsisolver'].__name__,
           'opersolver': None,
           'max_outer_iters': k_max,
           'max_inner_iters': None,
           'relconvergetol': None,
           'relchange': None,
           'abs_error': np.min(abs_errors),
           'rel_error': np.min(rel_errors),
           'rsdr': np.max(rsdrs),
           'time': time_end - time_start,
           'outer_iters': None}
    return [row, ]
