# -*- coding: utf-8 -*-
import cvxpy
import itertools
import splib_experiment as exp
from operfact import measurements, solvers, utils
from operfact import regularizers as regs

## Generate the parameter lists
# Parameters template
PARAMS = dict(
    snr=0.0,
    reg=None,
    offset=0,
    relconvtol=0.0,
    outer_iters=0,
    inner_iters=0,
    rank=0,
    hsisolver=None,
    opersolver=None,
    measurements=None,
)

# Helper functions that build the parameter dicts for each solver
def generate_params_oper(reg, opersolver, offsets):
    """Build parameter dicts for ``exp.opersolve`` runs.

    One dict is produced for every combination of SNR, penalty offset,
    relative convergence tolerance, outer/inner iteration limit and rank.

    Args:
        reg: regularizer stored under ``'reg'``.
        opersolver: solver stored under ``'opersolver'``.
        offsets: iterable of offsets to sweep (stored under ``'offset'``).

    Returns:
        list of dicts with the same keys as ``PARAMS``; every value is a
        scalar taken from one point of the sweep grid.
    """
    snrs = [10.0, ]
    relconvtols = [1e-3, ]
    outer_iters = [10, ]
    inner_iters = [2500, ]
    ranks = [5, 10]
    meastype = measurements.IdentityMeasurement
    params = []
    # relconvtols and outer_iters are swept like the other lists so each dict
    # holds a single value rather than the whole list.
    grid = itertools.product(snrs, offsets, relconvtols, outer_iters,
                             inner_iters, ranks)
    for snr, offset, relconvtol, outer_iter, inner_iter, rank in grid:
        params.append({'snr': snr,
                       'reg': reg,
                       'offset': offset,
                       'relconvtol': relconvtol,
                       'outer_iters': outer_iter,
                       'inner_iters': inner_iter,
                       'rank': rank,
                       'hsisolver': exp.opersolve,
                       'opersolver': opersolver,
                       'measurements': meastype})
    return params


def generate_params_indiv(reg, offsets):
    """Build parameter dicts for ``exp.indivsolve`` runs.

    Sweeps SNR x offset x inner-iteration limit. The relative tolerance,
    outer-iteration limit and rank are left as ``None``, and
    ``solvers.matsolve`` is stored as the ``'opersolver'``.

    Args:
        reg: regularizer stored under ``'reg'``.
        offsets: iterable of offsets to sweep (stored under ``'offset'``).

    Returns:
        list of dicts with the same keys as ``PARAMS``.
    """
    meastype = measurements.IdentityMeasurement
    sweep = itertools.product([10.0, ], offsets, [2500, ])
    return [
        {'snr': level,
         'reg': reg,
         'offset': shift,
         'relconvtol': None,
         'outer_iters': None,
         'inner_iters': n_inner,
         'rank': None,
         'hsisolver': exp.indivsolve,
         'opersolver': solvers.matsolve,
         'measurements': meastype}
        for level, shift, n_inner in sweep
    ]


def generate_params_kpsvd():
    """Build parameter dicts for ``exp.kpsvdsolve`` runs, one per SNR.

    The whole rank list is stored under ``'rank'`` rather than being swept.
    NOTE(review): presumably ``exp.kpsvdsolve`` handles every rank in the
    list itself; confirm against its implementation.

    Returns:
        list of dicts with the same keys as ``PARAMS``.
    """
    rank_list = [5, 10]
    meastype = measurements.IdentityMeasurement
    return [
        {'snr': level,
         'reg': None,
         'offset': None,
         'relconvtol': None,
         'outer_iters': None,
         'inner_iters': None,
         'rank': rank_list,
         'hsisolver': exp.kpsvdsolve,
         'opersolver': None,
         'measurements': meastype}
        for level in [10.0, ]
    ]


def generate_params_splr():
    """Build parameter dicts for ``exp.splrsolve`` runs, one per SNR.

    Uses a fixed limit of 10 outer iterations and stores the whole rank
    list under ``'rank'``, the same way ``generate_params_kpsvd`` does.

    Returns:
        list of dicts with the same keys as ``PARAMS``.
    """
    snrs = [10.0, ]
    # ``ranks`` must be defined locally: it is not a module-level name, so
    # referencing it undefined raised NameError when this ran at import.
    # NOTE(review): mirrors generate_params_kpsvd (whole list under 'rank');
    # confirm exp.splrsolve expects a list rather than a single rank.
    ranks = [5, 10]
    meastype = measurements.IdentityMeasurement
    params = []
    for snr in snrs:
        params.append({'snr': snr,
                       'reg': None,
                       'offset': None,
                       'relconvtol': None,
                       'outer_iters': 10,
                       'inner_iters': None,
                       'rank': ranks,
                       'hsisolver': exp.splrsolve,
                       'opersolver': None,
                       'measurements': meastype})
    return params


def nnp(l, r):
    """Return ``regs.NucNorm_Prod(l, r)`` built from the two factor norms."""
    product_reg = regs.NucNorm_Prod(l, r)
    return product_reg


# Short aliases for the two solvers passed as ``opersolver`` below.
quad = solvers.altminsolve
noquad = exp.altminsolve_noquad

# opersolve sweeps, in run order. Each row is
# (left norm, right norm, opersolver, first offset, stop offset);
# the offsets swept are list(range(first, stop)).
_OPER_SWEEPS = (
    (regs.norm_l1, regs.norm_l2, quad, -5, -1),
    (regs.norm_l1, regs.norm_l2, noquad, -10, -6),
    (regs.norm_l1, cvxpy.tv, quad, -9, -5),
    (regs.norm_l1, cvxpy.tv, noquad, -13, -8),
    (regs.norm_s1, regs.norm_l2, quad, -4, -1),
    (regs.norm_s1, regs.norm_l2, noquad, -9, -6),
    (regs.norm_s1, cvxpy.tv, quad, -5, 0),
    (regs.norm_s1, cvxpy.tv, noquad, -9, -5),
    (cvxpy.tv, regs.norm_l2, quad, -7, -4),
    (cvxpy.tv, regs.norm_l2, noquad, -13, -9),
    (cvxpy.tv, cvxpy.tv, quad, -7, -4),
    (cvxpy.tv, cvxpy.tv, noquad, -13, -9),
)
# A fresh nnp regularizer is built for every row, in table order.
params_oper = [
    entry
    for left, right, solver, first, stop in _OPER_SWEEPS
    for entry in generate_params_oper(nnp(left, right), solver,
                                      list(range(first, stop)))
]

# indivsolve: every regularizer below is swept over offsets -3..1.
_INDIV_REGS = (
    regs.NucNorm(regs.norm_l1, regs.norm_l1),
    regs.NucNorm(regs.norm_l2, regs.norm_l2),
    exp.TVRegularizer(),
)
params_indiv = []
for reg in _INDIV_REGS:
    params_indiv.extend(generate_params_indiv(reg, list(range(-3, 2))))

# kpsvdsolve and splrsolve each contribute their own parameter lists.
params_kpsvd = generate_params_kpsvd()
params_splr = generate_params_splr()

# Every parameter dict handed to the experiment runner, in sweep order.
params_all = [*params_oper, *params_indiv, *params_kpsvd, *params_splr]

## Run the experiments
if __name__ == '__main__':
    NPROCS = 12
    NTRIALS = 10
    SEED = None

    # Passed to utils.experiment_csv as the field names (presumably the
    # CSV columns); the order is kept as-is.
    fieldnames = [
        'seed', 'snr', 'measurements', 'regularizer', 'penconst', 'penconst_offset',
        'solve_rank', 'hsisolver', 'opersolver', 'max_outer_iters', 'max_inner_iters',
        'relconvergetol', 'relchange', 'abs_error', 'rel_error',
        'rsdr', 'time', 'outer_iters',
    ]

    rows = utils.experiment_csv(
        'splib_identmeas.csv.gz',
        fieldnames,
        exp.solve_problem,
        params_all,
        ntrials=NTRIALS,
        nprocs=NPROCS,
        seed=SEED,
        return_rows=False,
    )
