# -*- coding: utf-8 -*-
import os

import cvxpy
import matplotlib.pyplot as plt
import numpy as np
from operfact import regularizers as regs
from operfact import measurements, operators, solvers, utils


# Generate true operator
# SHAPE[:2] is the shape of each first-list factor, SHAPE[2:] of each
# second-list factor passed to DyadsOperator below.
SHAPE = (4, 4, 4, 4)
oper = operators.DyadsOperator(
    # four random sign matrices ...
    [utils.rand_signmat(SHAPE[:2], False) for _ in range(4)],
    # ... paired with four random matrices from rand_lowrankmat(..., 1)
    [utils.rand_lowrankmat(SHAPE[2:], 1) for _ in range(4)],
)

# Refactor the same operator through operators.kpsvd: column k of U and
# row k of V are each unfolded column-major into a factor matrix.
# NOTE(review): presumably a Kronecker-product SVD -- confirm in operfact.
U, S, V = operators.kpsvd(oper)
oper_svd = operators.DyadsOperator(
    [U[:, k].flatten().reshape(SHAPE[:2], order='F') for k in range(4)],
    [V[k, :].flatten().reshape(SHAPE[2:], order='F') for k in range(4)],
)


def db_to_sigma(db, true):
    """Return the noise standard deviation that yields an SNR of `db` decibels.

    The returned sigma satisfies

        db = 10*log10(||true||_F^2 / (prod(true.shape) * sigma^2)),

    i.e. it is the root-mean-square entry of `true` scaled by 10**(-db/20).

    Parameters
    ----------
    db : float
        Target signal-to-noise ratio in decibels.
    true : array
        Noise-free signal; the Frobenius norm is taken, so it must be a
        2-D array.
    """
    # RMS magnitude of one entry of the clean signal.
    rms_entry = np.linalg.norm(true, ord='fro') / np.sqrt(np.prod(true.shape))
    # Amplitude ratio corresponding to `db` decibels of attenuation.
    attenuation = 10.0 ** (-db / 20.0)
    return attenuation * rms_entry


# Set up the denoising problem: observe every entry of the operator through
# an identity measurement, corrupted by Gaussian noise scaled for 10 dB SNR.
sigma = db_to_sigma(10.0, oper.asArrayOperator().reshape((16,16)))

prob = solvers.Problem()
prob.maxiters = 10
prob.norm = regs.NucNorm_Sum(regs.norm_linf, regs.norm_s1)
prob.measurementobj = measurements.IdentityMeasurement(SHAPE)
# Clean measurements first, then the noise draw (keeps the RNG call order).
clean_measurements = prob.measurementobj.apply(oper)
noise = sigma * np.random.normal(size=(np.prod(SHAPE),))
prob.measurementvec = clean_measurements + noise
# Half of the penalty constant returned by regs.penconst_denoise.
prob.penconst = 0.5 * regs.penconst_denoise(SHAPE, sigma, prob.norm)
prob.relconvergetol = 1e-3
prob.rank = 16
prob.shape = SHAPE
out = solvers.altminsolve(prob)

# Quality metrics, both in dB relative to the norm of the true operator:
#   rsdr -- true norm over the recovery error
#   snr  -- true norm over the measurement noise
true_array = oper.asArrayOperator()
true_vec = true_array.flatten()
true_norm = np.linalg.norm(true_vec, ord=2)
recovery_error = out.recovered.asArrayOperator().flatten() - true_vec
# The measurement vector is compared against the column-major flattening.
measurement_error = true_array.flatten(order='F') - prob.measurementvec

rsdr = -20*np.log10(np.linalg.norm(recovery_error, ord=2)/true_norm)
snr = 20*np.log10(true_norm/np.linalg.norm(measurement_error, ord=2))
print("Gain:", rsdr-snr)

# Make images
def _save_factor_figure(lfactors, rfactors, ncols, path):
    """Draw the first `ncols` factor pairs as a 2-row image grid and save it.

    Row 0 shows ``lfactors[k]`` titled ``X_{k+1}``; row 1 shows
    ``rfactors[k]`` titled ``Y_{k+1}``.

    Parameters
    ----------
    lfactors, rfactors : indexable collections of 2-D arrays
        Only indices ``0 .. ncols-1`` are accessed.
    ncols : int
        Number of factor pairs (grid columns) to draw.
    path : str
        Output file path; its parent directory is created if missing.

    Returns
    -------
    The matplotlib figure that was saved.
    """
    # squeeze=False keeps the axes array 2-D even when ncols == 1.
    fig, axes = plt.subplots(2, ncols, squeeze=False)
    for col in range(ncols):
        for row, (factors, title) in enumerate(((lfactors, '$X_{0}$'),
                                                (rfactors, '$Y_{0}$'))):
            panel = axes[row, col]
            panel.imshow(factors[col], cmap='Greys', interpolation='nearest')
            panel.set_xticklabels([])
            panel.set_yticklabels([])
            panel.set_title(title.format(col + 1))
    # 3 inches per panel in each direction.
    fig.set_size_inches(3*ncols, 3*2)
    # savefig does not create directories, so make sure the target exists
    # rather than failing after the solver has already run.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(path, bbox_inches='tight', pad_inches=0)
    return fig


# Factors of the true operator.
_save_factor_figure(oper.lfactors, oper.rfactors, oper.nfactors,
                    'figures/denoise-demix-true.pdf')

# Leading factors of the recovered operator (as many as the true one has).
_save_factor_figure(out.recovered.lfactors, out.recovered.rfactors,
                    oper.nfactors, 'figures/denoise-demix-recovered.pdf')

# Baseline: leading operators.kpsvd factors of the noisy measurements, with
# columns of U / rows of V unfolded column-major as for the true operator.
oper_noisy = prob.measurementvec.reshape(SHAPE, order='F')
U, S, V = operators.kpsvd(operators.ArrayOperator(oper_noisy))
_save_factor_figure(
    [U[:, k].flatten().reshape(SHAPE[0:2], order='F')
     for k in range(oper.nfactors)],
    [V[k, :].flatten().reshape(SHAPE[2:4], order='F')
     for k in range(oper.nfactors)],
    oper.nfactors,
    'figures/denoise-demix-kpsvd.pdf',
)
