# -*- coding: utf-8 -*-
"""Blind self-calibration experiment driver."""

import cvxpy
import itertools
import numpy as np
import scipy.fftpack
import time
from operfact.utils import *
from operfact import measurements, operators, solvers, utils
from operfact import regularizers as regs


class SelfCalibMeasurement(measurements.Measurement):
    """Computes the measurements as in Ling and Strohmer (2015).

    For an operator given as dyads ``sum_r h_r (x) X_r`` (left factor ``h_r`` of
    size k x 1, right factor ``X_r`` of size N x M), the measurements are

        y = vec( sum_r diag(B h_r) A X_r )        if signal_is_2d is False
        y =      sum_r diag(B h_r) A vec(X_r)     if signal_is_2d is True

    where ``B`` is the first k columns of an L x L orthonormal DCT matrix and
    ``A`` is a random matrix from ``rand_gaussianmat``. Every ``vec`` /
    flattening in this class is column-major (Fortran order).
    """
    def __init__(self, L, k, N, M, signal_is_2d=False):
        """
        :param L: Number of measurements
        :param k: Length of the parameter vector
        :param N: Ambient dimension of the signal
        :param M: Number of snapshots
        :param signal_is_2d: If True, treat the N*M signal as 2D data. If False,
        treat the N*M signal as M snapshots of a 1-dimensional length-N signal.
        """
        self.dim_amb = L  # number of rows of B and A
        self.shape = (k, 1, N, M)  # operator shape: left factor k x 1, right factor N x M
        # Total length of the measurement vector produced by apply/cvxapply.
        if signal_is_2d:
            self.nmeas = L
        else:
            self.nmeas = L * M
        self.signal_is_2d = signal_is_2d
        # cvxpy doesn't handle complex variables natively:
        # self.B = (1/np.sqrt(L))*np.fft.fft(np.eye(L))[:, 0:k]
        # Using a partial DCT matrix instead:
        self.B = scipy.fftpack.dct(np.eye(L), norm='ortho')[:, 0:k]
        # matcompat mirrors the precondition asserted in matapply.
        if signal_is_2d:
            self.A = rand_gaussianmat((L, N*M), False)
            self.matcompat = True
        else:
            self.A = rand_gaussianmat((L, N), False)
            self.matcompat = (M == 1)

    def apply(self, oper):
        """Apply the measurement map to a ``DyadsOperator`` with numpy factors.

        :param oper: ``operators.DyadsOperator``.
        :returns: The column-major flattening of the measurements, with ``nmeas``
            entries. In the 1-D-signal case entry ``j*L + i`` is measurement
            ``i`` of snapshot ``j``.
        """
        assert isinstance(oper, operators.DyadsOperator)
        # NB: broadcasting elementwise multiplication to do diag(vec) @ matrix
        if self.signal_is_2d:
            # Column-major vec of each N x M right factor, to match cvxpy.vec in cvxapply.
            temp = sum([(self.B @ oper.lfactors[r]) *
                        (self.A @ oper.rfactors[r].reshape((self.shape[2]*self.shape[3], 1), order='F'))
                        for r in range(oper.nfactors)])
        else:
            temp = sum([(self.B @ oper.lfactors[r]) * (self.A @ oper.rfactors[r])
                        for r in range(oper.nfactors)])
        return temp.flatten(order='F')

    def cvxapply(self, oper):
        """cvxpy counterpart of ``apply``.

        :param oper: ``operators.DyadsOperator`` whose factors are cvxpy
            expressions (e.g. ``cvxpy.Parameter``).
        :returns: cvxpy expression ``vec(...)``, ordered like the output of ``apply``.
        """
        assert isinstance(oper, operators.DyadsOperator)
        if self.signal_is_2d:
            temp = sum([cvxpy.diag(self.B * oper.lfactors[r]) * (self.A * cvxpy.vec(oper.rfactors[r]))
                        for r in range(oper.nfactors)])
        else:
            temp = sum([cvxpy.diag(self.B * oper.lfactors[r]) * (self.A * oper.rfactors[r])
                        for r in range(oper.nfactors)])
        return cvxpy.vec(temp)

    def matapply(self, mat):
        """Apply the measurement map to the matrix form of the operator.

        Measurement ``l`` is ``B[l, :] @ mat @ A[l, :].T``. This requires a 2-D
        signal or a single snapshot (M == 1), which is asserted.

        :param mat: numpy array/matrix or cvxpy expression. It must have k rows
            and as many columns as ``A`` (N, or N*M for a 2-D signal).
        :returns: Vertical stack of the ``nmeas`` scalar measurements. This is a
            cvxpy expression when ``mat`` is one, otherwise a numpy stack.
        """
        assert self.signal_is_2d or (self.shape[3] == 1)
        if isinstance(mat, cvxpy.expressions.expression.Expression):
            return cvxpy.vstack(*[np.matrix(self.B[l,:])*(mat*np.matrix(self.A[l,:]).T) for l in range(self.nmeas)])
        else:
            return np.vstack([np.matrix(self.B[l,:])*(mat*np.matrix(self.A[l,:]).T) for l in range(self.nmeas)])

    def asOperator(self):
        """Not supported for this measurement model."""
        raise NotImplementedError

    def initfromoper(self, oper):
        """Not supported for this measurement model."""
        raise NotImplementedError

    def initfrommeas(self, meas):
        """Adjoint of ``apply``: map a measurement vector back to operator space.

        Accumulates ``meas``-weighted rank-one operators ``b_i (x) Y``. Here
        ``b_i`` is row i of ``B``, and ``Y`` is row i of ``A`` reshaped to N x M
        (2-D signal) or placed in snapshot column j (1-D signal).

        :param meas: Vector of ``nmeas`` measurements, ordered as produced by ``apply``.
        :returns: ``operators.ArrayOperator`` of shape ``self.shape``.
        """
        out = operators.ArrayOperator(np.zeros(self.shape))
        for i in range(self.dim_amb):
            b_i = self.B[i:i+1, :].T  # force return 2D array
            a_i = self.A[i:i+1, :].T  # force return 2D array
            if self.signal_is_2d:
                # Undo the column-major vec used by apply in the 2-D case.
                Yi0 = a_i.reshape(self.shape[2:4], order='F')
                Ai0 = operators.DyadsOperator([b_i, ], [Yi0, ])
                out += meas[i] * Ai0.asArrayOperator()
            else:
                for j in range(self.shape[3]):
                    Yij = np.zeros(self.shape[2:4])
                    Yij[:, j:j+1] = a_i  # a_i is 2D array
                    Aij = operators.DyadsOperator([b_i, ], [Yij, ])
                    # Column-major index of measurement i in snapshot j.
                    out += meas[j*self.dim_amb + i] * Aij.asArrayOperator()
        return out


def test_adjoint(signal_is_2d=False):
    """Check that initfrommeas is the adjoint of apply.

    For a random DyadsOperator X and a random measurement vector y, verifies
    <apply(X), y> == <X, initfrommeas(y)> up to floating-point tolerance.
    """
    n_rows = 16
    oper_shape = (8, 1, 4, 2)
    rank = 2
    op = operators.RandomDyadsOperator(oper_shape, nfactors=rank)
    meas_map = SelfCalibMeasurement(n_rows, oper_shape[0], oper_shape[2], oper_shape[3],
                                    signal_is_2d=signal_is_2d)
    forward = meas_map.apply(op)
    probe = np.random.normal(size=meas_map.nmeas)
    lhs = np.sum(forward * probe)
    back = meas_map.initfrommeas(probe)
    rhs = np.sum(op.asArrayOperator() * back)
    assert np.isclose(lhs, rhs)


def test_apply_cvxapply(signal_is_2d=False):
    """Test agreement between apply and cvxapply.

    Loads the factors of a random DyadsOperator into cvxpy Parameters and
    checks that evaluating cvxapply gives the same values as apply on the
    original operator. The number of entries is checked explicitly.
    """
    L = 16
    shape = (8, 1, 4, 2)
    ndyads = 2
    dyop = operators.RandomDyadsOperator(shape, nfactors=ndyads)
    meas = SelfCalibMeasurement(L, shape[0], shape[2], shape[3], signal_is_2d=signal_is_2d)
    mv = np.asarray(meas.apply(dyop)).ravel()
    lparams = [cvxpy.Parameter(*shape[0:2]) for _ in range(ndyads)]
    rparams = [cvxpy.Parameter(*shape[2:4]) for _ in range(ndyads)]
    for i in range(ndyads):
        lparams[i].value = dyop.lfactors[i]
        rparams[i].value = dyop.rfactors[i]
    dyopcvx = operators.DyadsOperator(lparams, rparams)
    # cvxpy may return an (n, 1) np.matrix. Flatten both sides so the comparison
    # is elementwise instead of silently broadcasting (n, 1) vs (n,) to (n, n).
    mvcvx = np.asarray(meas.cvxapply(dyopcvx).value).ravel()
    assert mv.shape == (meas.nmeas,)
    assert mvcvx.shape == (meas.nmeas,)
    assert np.allclose(mv, mvcvx)


def test_apply_matapply(signal_is_2d=False):
    """Test agreement between apply and matapply.

    ``matapply`` only supports a 2-D signal or a single snapshot (M == 1). The
    1-D case therefore uses one snapshot; with M > 1 the precondition assert
    inside ``matapply`` would fail before anything is compared.
    """
    L = 16
    shape = (8, 1, 4, 2) if signal_is_2d else (8, 1, 4, 1)
    ndyads = 2
    dyop = operators.RandomDyadsOperator(shape, nfactors=ndyads)
    meas = SelfCalibMeasurement(L, shape[0], shape[2], shape[3], signal_is_2d=signal_is_2d)
    opapply = np.asarray(meas.apply(dyop)).ravel()
    # matapply returns an (n, 1) stack. Flatten it so we compare elementwise
    # instead of broadcasting an (n,) array against an (n, 1) one to (n, n).
    matapply = np.asarray(meas.matapply(dyop.asmatrix())).ravel()
    assert opapply.shape == (meas.nmeas,)
    assert matapply.shape == (meas.nmeas,)
    assert np.allclose(opapply, matapply)


def test_cvxapply_matapply(signal_is_2d=False):
    """Test agreement between cvxapply and matapply.

    ``matapply`` only supports a 2-D signal or a single snapshot (M == 1). The
    1-D case therefore uses one snapshot; with M > 1 the precondition assert
    inside ``matapply`` would always fail.
    """
    L = 16
    shape = (8, 1, 4, 2) if signal_is_2d else (8, 1, 4, 1)
    ndyads = 2
    dyop = operators.RandomDyadsOperator(shape, nfactors=ndyads)
    mat = dyop.asmatrix()
    meas = SelfCalibMeasurement(L, shape[0], shape[2], shape[3], signal_is_2d=signal_is_2d)
    lparams = [cvxpy.Parameter(*shape[0:2]) for _ in range(ndyads)]
    rparams = [cvxpy.Parameter(*shape[2:4]) for _ in range(ndyads)]
    for i in range(ndyads):
        lparams[i].value = dyop.lfactors[i]
        rparams[i].value = dyop.rfactors[i]
    dyopcvx = operators.DyadsOperator(lparams, rparams)
    cvxapply = meas.cvxapply(dyopcvx)
    cvxmat = cvxpy.Parameter(np.prod(shape[0:2]), np.prod(shape[2:4]))
    cvxmat.value = mat
    cvxmatapply = meas.matapply(cvxmat)
    # Flatten both values so shapes are checked explicitly and no broadcasting occurs.
    lhs = np.asarray(cvxapply.value).ravel()
    rhs = np.asarray(cvxmatapply.value).ravel()
    assert lhs.shape == (meas.nmeas,)
    assert rhs.shape == (meas.nmeas,)
    assert np.allclose(lhs, rhs)


def db_to_sigma(db, signal_norm, shape):
    """Return the noise standard deviation that yields a target SNR.

    The returned ``sigma`` satisfies
    ``db = 10*log10(signal_norm**2 / (prod(shape) * sigma**2))``.

    :param db: Target SNR in decibels (an infinite SNR gives ``sigma == 0``).
    :param signal_norm: Norm of the clean signal.
    :param shape: Shape whose total number of entries normalizes the signal energy.
    :returns: The noise standard deviation ``sigma``.
    """
    per_entry_rms = signal_norm / np.sqrt(np.prod(shape))
    attenuation = 10.0 ** (-db / 20.0)
    return attenuation * per_entry_rms


def sparse_signal(oper_shape, ndyads, nnz):
    """Build a random ``DyadsOperator`` whose right factors are sparse matrices.

    :param oper_shape: Operator shape; entries 0:2 size each left factor and
        entries 2:4 size each right factor.
    :param ndyads: Number of dyads to generate.
    :param nnz: Desired nonzero count per right factor. It is passed to
        ``utils.rand_sparsegaussianmat`` as the density ``nnz / prod(oper_shape[2:4])``.
    :returns: ``operators.DyadsOperator`` with standard-normal left factors.
    """
    right_shape = oper_shape[2:4]
    n_entries = np.prod(right_shape)
    assert nnz <= n_entries
    density = nnz / n_entries
    left_list, right_list = [], []
    for _ in range(ndyads):
        left_list.append(np.random.normal(size=oper_shape[0:2]))
        right_list.append(utils.rand_sparsegaussianmat(right_shape, density))
    return operators.DyadsOperator(left_list, right_list)


def sparse_identical_entries(oper_shape, ndyads, nnz):
    """Generate identical sparse signals.

    Each right factor repeats one random sparse column (of length
    ``oper_shape[2]``) across all ``oper_shape[3]`` columns, so every snapshot
    holds the same sparse signal.

    :param oper_shape: Operator shape; entries 0:2 size the left factors and
        entries 2:4 size the right factors.
    :param ndyads: Number of dyads to generate.
    :param nnz: Desired nonzero count of the shared column. It is passed to
        ``utils.rand_sparsegaussianmat`` as the density ``nnz / oper_shape[2]``.
    :returns: ``operators.DyadsOperator`` with standard-normal left factors.
    """
    n_rows = oper_shape[2]
    assert nnz <= n_rows
    lefts, rights = [], []
    for _ in range(ndyads):
        lefts.append(np.random.normal(size=oper_shape[0:2]))
        column = utils.rand_sparsegaussianmat((n_rows, 1), nnz / n_rows)
        tiled = np.zeros(oper_shape[2:4])
        tiled[:, :] = column  # broadcast the (N, 1) column into every snapshot
        rights.append(tiled)
    return operators.DyadsOperator(lefts, rights)


def sparse_identical_locs(oper_shape, ndyads, nnz):
    """Generate simultaneously sparse signals.

    Each right factor has exactly ``nnz`` nonzero rows chosen at random, so all
    snapshots (columns) share one support. The nonzero rows are filled with
    independent standard-normal entries.

    :param oper_shape: Operator shape; entries 0:2 size the left factors and
        entries 2:4 size the right factors.
    :param ndyads: Number of dyads to generate.
    :param nnz: Number of nonzero rows per right factor.
    :returns: ``operators.DyadsOperator`` with standard-normal left factors.
    """
    assert nnz <= oper_shape[2]
    lefts, rights = [], []
    for _ in range(ndyads):
        lefts.append(np.random.normal(size=oper_shape[0:2]))
        rmat = np.zeros(oper_shape[2:4])
        support = np.random.choice(oper_shape[2], nnz, replace=False)
        for row in support:
            rmat[row, :] = np.random.normal(size=(oper_shape[3],))
        rights.append(rmat)
    return operators.DyadsOperator(lefts, rights)


def sparse_nonidentical(oper_shape, ndyads, nnz):
    """Generate sparse signals with same number of nonzeros per column.

    Every column of each right factor gets its own random support of exactly
    ``nnz`` rows, filled with standard-normal values.

    :param oper_shape: Operator shape; entries 0:2 size the left factors and
        entries 2:4 size the right factors.
    :param ndyads: Number of dyads to generate.
    :param nnz: Number of nonzeros in each column of each right factor.
    :returns: ``operators.DyadsOperator`` with standard-normal left factors.
    """
    assert nnz <= oper_shape[2]
    lefts, rights = [], []
    for _ in range(ndyads):
        lefts.append(np.random.normal(size=oper_shape[0:2]))
        rmat = np.zeros(oper_shape[2:4])
        for col in range(oper_shape[3]):
            support = np.random.choice(oper_shape[2], nnz, replace=False)
            rmat[support, col] = np.random.normal(size=(nnz,))
        rights.append(rmat)
    return operators.DyadsOperator(lefts, rights)


def lowrank_signal(oper_shape, ndyads, rank):
    """Generate a ``DyadsOperator`` whose right factors are low-rank matrices.

    :param oper_shape: Operator shape; entries 0:2 size the left factors and
        entries 2:4 size the right factors.
    :param ndyads: Number of dyads to generate.
    :param rank: Rank passed to ``utils.rand_lowrankmat`` for each right factor.
        It must not exceed ``min(oper_shape[2:4])``.
    :returns: ``operators.DyadsOperator`` with standard-normal left factors.
    """
    right_shape = oper_shape[2:4]
    assert rank <= np.min(right_shape)
    lefts, rights = [], []
    for _ in range(ndyads):
        lefts.append(np.random.normal(size=oper_shape[0:2]))
        rights.append(utils.rand_lowrankmat(right_shape, rank))
    return operators.DyadsOperator(lefts, rights)

# Experiment parameters consumed by solve_problem. Only 'L' and 'N' have
# defaults; every other entry is a None placeholder.
PARAMS = {
    'L': 128,
    'N': 256,
    **dict.fromkeys(
        ('k',
         'n',
         'M',
         'signal_is_2d',
         'ndyads',
         'generate_fn',
         'snr',
         'reg',
         'penconst_offset',  # may be a list
         'solve_rank',  # may be a list
         'inner_solver',
         'relconvergetol',
         'max_outer_iters'),
        None),
}

# Template for one output row: the parameters above followed by per-run results.
ROW = dict(PARAMS, **dict.fromkeys(
    ('penconst', 'seed', 'abs_error', 'rel_error', 'rsdr', 'time',
     'outer_iters', 'relchange', 'outer_solver'),
    None))


def solve_problem(params, seed=None, redirect=None):
    """Run one blind self-calibration trial and collect recovery-error rows.

    Steps:

    1. Generate a ground-truth operator with ``params['generate_fn']``.
    2. Measure it with a new ``SelfCalibMeasurement``.
    3. Add Gaussian noise at ``params['snr']`` dB.
    4. For every (penalty offset, solve rank) pair, run each solver that
       ``params['reg'].available_solvers()`` reports.

    :param params: Dictionary with the keys of ``PARAMS``. ``penconst_offset``
        and ``solve_rank`` may each be a single value or a tuple/list.
    :param seed: RNG seed.

        - ``None`` seeds from the current time and records that seed.
        - ``np.inf`` reseeds NumPy's global RNG without an explicit seed.
        - Any other value is passed to ``np.random.seed``.
    :param redirect: Unused; kept for interface compatibility.
    :returns: List of row dictionaries (keys of ``ROW``), one per solver output.
    """
    rows = []
    # Set the seed (np.inf rather than np.Inf, which NumPy 2.0 removed).
    if seed is None:
        seed = int(time.time())
        np.random.seed(seed=seed)
    elif seed == np.inf:
        np.random.seed()
    else:
        np.random.seed(seed=seed)
    # Generate the operator
    oper_shape = (params['k'], 1, params['N'], params['M'])
    oper_true = params['generate_fn'](oper_shape, params['ndyads'], params['n'])
    oper_flattened = oper_true.asArrayOperator().flatten()
    oper_norm = np.linalg.norm(oper_flattened)
    # Generate the measurements
    measobj = SelfCalibMeasurement(params['L'], params['k'], params['N'], params['M'], params['signal_is_2d'])
    measvec = measobj.apply(oper_true)
    # Generate the noise
    sigma = db_to_sigma(params['snr'], oper_norm, oper_shape)
    noisevec = sigma*np.random.normal(size=measobj.nmeas)
    measvec += noisevec
    penconst = regs.penconst_denoise(oper_shape, sigma, params['reg'])
    # Create problem
    prob = solvers.Problem()
    prob.shape = oper_shape
    prob.measurementobj = measobj
    prob.measurementvec = measvec
    prob.norm = params['reg']
    prob.maxiters = params['max_outer_iters']
    prob.relconvergetol = params['relconvergetol']
    prob.solver = params['inner_solver']
    prob.solveropts = {'verbose': False, 'warm_start': True}
    # Create output template
    rowtemp = ROW.copy()
    rowtemp.update(params)  # populate with current params
    rowtemp['generate_fn'] = params['generate_fn'].__name__
    rowtemp['reg'] = str(prob.norm)
    rowtemp['penconst'] = penconst
    rowtemp['seed'] = seed
    # Solve
    offsets = params['penconst_offset']
    offsets = offsets if utils.istol(offsets) else (offsets, )
    solve_ranks = params['solve_rank']
    solve_ranks = solve_ranks if utils.istol(solve_ranks) else (solve_ranks, )
    for (offset, solve_rank) in itertools.product(offsets, solve_ranks):
        prob.penconst = penconst*(2**offset)
        prob.rank = solve_rank
        outer_solvers = prob.norm.available_solvers()
        for solver in outer_solvers:
            # Compare by value: `is` against a string literal tests identity.
            is_altmin = (solver == 'altmin')
            # Solvers other than 'altmin' record no solve rank, so run them only
            # for the first rank. They also need a matapply-compatible
            # measurement. Use truthiness so a numpy bool False is caught too.
            if not is_altmin and ((solve_rank != solve_ranks[0]) or
                                  not measobj.matcompat):
                continue
            outlist = solvers.SOLVERS[solver](prob, eqconstraint=(sigma == 0.0))
            outlist = outlist if utils.istol(outlist) else [outlist, ]
            for out in outlist:
                # Compute error
                # NB: unless this is a rank-1 altmin solution, we project the
                # solution to a rank-1 operator (leading term of operators.kpsvd).
                if (solve_rank != 1) or not is_altmin:
                    [U, S, V] = operators.kpsvd(out.recovered)
                    left = S[0] * U[:, 0].reshape(oper_shape[0:2], order='F')
                    right = V.T[:, 0].reshape(oper_shape[2:4], order='F')
                    Xhat_dyads = operators.DyadsOperator([left, ], [right, ])
                    Xhat = Xhat_dyads.asArrayOperator().flatten()
                else:
                    Xhat = out.recovered.asArrayOperator().flatten()

                abs_error = np.linalg.norm(Xhat - oper_flattened, ord=2)
                rel_error = abs_error/oper_norm
                # Create output row
                row = rowtemp.copy()
                row['penconst_offset'] = offset
                row['abs_error'] = abs_error
                row['rel_error'] = rel_error
                row['rsdr'] = -20.0*np.log10(rel_error)
                row['time'] = out.total_time
                row['outer_solver'] = solver
                if is_altmin:
                    row['solve_rank'] = solve_rank
                    row['max_outer_iters'] = out.maxiters
                    row['outer_iters'] = out.outer_iters
                    row['relchange'] = out.relchange
                    row['relconvergetol'] = out.relconvtol
                else:
                    row['solve_rank'] = None
                    row['max_outer_iters'] = None
                    row['relconvergetol'] = None
                rows.append(row)
    return rows
