# -*- coding: utf-8 -*-
import hashlib
import itertools
import os
import sys

import cvxpy
import numpy as np

from operfact import measurements, operators, solvers
from operfact import regularizers as regs


# Template for the ``params`` argument of run_experiment.
#
# run_experiment sweeps over every value listed under:
#   shape, inner_type, rank      -> one generated operator per combination
#   db                           -> one noise draw per operator
#   regularizer                  -> regularizer assigned to problem.norm
#   penconst_offset, solve_rank, inner_solver -> per-solve settings
# The remaining entries (relconvergetol, outer_iters, solveropts) are passed
# unchanged to solve_problem, which only reads solveropts.get('verbose').
# NOTE(review): relconvergetol and outer_iters are list placeholders here but
# are NOT iterated by run_experiment -- confirm whether a scalar is expected.
PARAMS = dict(
    shape=[],
    rank=[],
    inner_type=[],
    db=[],
    penconst_offset=[],
    solve_rank=[],
    relconvergetol=[],
    outer_iters=[],
    inner_solver=[],
    regularizer=[],
    solveropts={},
)


# Column names, in output order, for the row dicts built by solve_problem
# (every row carries exactly these keys).
FIELDNAMES = (
    # Problem instance: operator description, RNG seed, fingerprint, noise.
    ['shape', 'rank', 'inner_type', 'seed', 'hash', 'noise_level']
    # Regularization and solver configuration.
    + ['regularizer', 'penconst', 'penconst_offset', 'solve_rank',
       'solver', 'max_outer_iters', 'max_inner_iters', 'relconvergetol']
    # Outcome of the solve.
    + ['relchange', 'abs_error', 'rel_error', 'orig_error',
       'objval', 'time', 'outer_iters']
)


def db_to_sigma(db, shape):
    """Return the per-entry noise standard deviation for a given SNR in dB.

    Solves ``db = 10*log10(1 / (prod(shape) * sigma**2))`` for ``sigma``,
    i.e. ``sigma = 10**(-db/20) / sqrt(prod(shape))``.
    """
    amplitude = 10.0 ** (-db / 20.0)
    inv_root_entries = 1.0 / np.sqrt(np.prod(shape))
    return amplitude * inv_root_entries


def generate_operator(shape=None, rank=None, inner_type=None):
    """Draw a random DyadsOperator of the requested rank and wrap it in a Problem.

    Factors are redrawn until the matrix form of the operator has matrix rank
    exactly ``rank``; the operator is then rescaled to unit Frobenius norm.

    Args:
        shape: 4-tuple; left factors are drawn with shape ``shape[0:2]`` and
            right factors with shape ``shape[2:4]``.
        rank: number of dyads, and the required matrix rank of the operator.
        inner_type: pair of callables ``(make_left, make_right)``, each called
            with a 2-element shape to draw one random factor.

    Returns:
        A ``solvers.Problem`` with ``shape`` and ``trueOperator`` set and its
        cache holding 'X0', 'inner_type', 'oper_rank', 'trueNorm' (1.0) and
        'Xhash'.

    Raises:
        ValueError: if ``rank`` is below 1 or above
            ``min(shape[0]*shape[1], shape[2]*shape[3])``. A sum of ``rank``
            dyads can never reach a larger matrix rank, so the rejection loop
            would spin forever. With ``rank == 0`` the loop body never runs
            and ``X0`` would be undefined.
    """
    # Assumes the matrix form of the operator is
    # (shape[0]*shape[1]) x (shape[2]*shape[3]), the same size as the noise
    # matrix used in generate_noise. This bound is necessary, not sufficient:
    # structured factor types may cap the attainable rank lower.
    max_rank = min(shape[0] * shape[1], shape[2] * shape[3])
    if not 1 <= rank <= max_rank:
        raise ValueError(
            'rank must be between 1 and {0} for shape {1}, got {2}'.format(
                max_rank, shape, rank))

    problem = solvers.Problem()
    problem.shape = shape

    # Rejection-sample random factors until the operator has the target rank.
    mat_rank = 0
    tries = 0
    while mat_rank != rank:
        tries += 1
        if tries % min(shape) == 0:
            print('{0} tries taken for type {1}'.format(tries, inner_type))
        left = [inner_type[0](shape[0:2]) for r in range(rank)]
        right = [inner_type[1](shape[2:4]) for r in range(rank)]
        problem.trueOperator = operators.DyadsOperator(left, right)
        X0 = problem.trueOperator.asmatrix()
        mat_rank = np.linalg.matrix_rank(X0)

    # Normalize: scaling both factors of every dyad by 1/sqrt(norm) scales
    # the operator by 1/norm.
    scale = 1 / np.sqrt(np.linalg.norm(X0, ord='fro'))
    for r in range(rank):
        problem.trueOperator.lfactors[r] *= scale
        problem.trueOperator.rfactors[r] *= scale
    X0 = problem.trueOperator.asmatrix()
    # Process-independent fingerprint of X0. The builtin hash() of bytes is
    # salted per interpreter (PYTHONHASHSEED), so its value cannot match the
    # same operator across runs or worker processes.
    Xhash = int.from_bytes(hashlib.sha256(X0.tobytes()).digest()[:8], 'big')
    trueNorm = np.linalg.norm(X0, ord='fro')  # should be 1
    assert np.isclose(trueNorm, 1.0)

    # Update cache
    problem.cache['X0'] = X0
    problem.cache['inner_type'] = inner_type
    problem.cache['oper_rank'] = rank
    problem.cache['trueNorm'] = 1.0
    problem.cache['Xhash'] = Xhash

    return problem


def generate_noise(problem, db=None):
    """Add Gaussian noise at SNR ``db`` to the cached true operator, in place.

    Sets ``problem.measurementobj`` (an identity measurement),
    ``problem.measurementvec`` (the noisy matrix flattened column-major) and
    ``problem.rfactorsinit`` (initial right factors taken from
    ``operators.kpsvd`` of the measurement). It also records 'sigma', 'db'
    and 'orig_error' (Frobenius norm of the added noise) in
    ``problem.cache``. Reads ``problem.cache['X0']``, as set by
    generate_operator.
    """
    dims = problem.shape
    clean = problem.cache['X0']

    noise_std = db_to_sigma(db, dims)
    matrix_shape = (dims[0] * dims[1], dims[2] * dims[3])
    noisy = clean + np.random.normal(scale=noise_std, size=matrix_shape)

    problem.measurementobj = measurements.IdentityMeasurement(dims)
    problem.measurementvec = noisy.flatten(order='F')

    # Starting point for the altmin solver: one right factor per value in S.
    initial = problem.measurementobj.initfrommeas(problem.measurementvec)
    _, singular_vals, right_vecs = operators.kpsvd(initial)
    problem.rfactorsinit = [
        right_vecs[k, ].reshape(dims[2:4], order='F')
        for k in range(len(singular_vals))
    ]

    problem.cache['sigma'] = noise_std
    problem.cache['db'] = db
    problem.cache['orig_error'] = np.linalg.norm(clean - noisy, ord='fro')


def solve_problem(problem, penconst_offset=None, relconvergetol=None,
                  outer_iters=None, outer_solver=None, inner_solver=None,
                  solve_rank=None, solveropts=None):
    """Solve a prepared denoising problem and summarize every solution as a row.

    ``problem`` must already carry the cache entries written by
    generate_operator and generate_noise, plus ``cache['penconst']``,
    ``cache['seed']`` and ``problem.norm`` (as set by run_experiment).

    Args:
        problem: the prepared problem; its solver settings are overwritten.
        penconst_offset: the applied penalty is
            ``cache['penconst'] * 2**penconst_offset``.
        relconvergetol: stored as ``problem.relconvergetol``.
        outer_iters: stored as ``problem.maxiters``.
        outer_solver: key into ``solvers.SOLVERS``. The value 'altmin' also
            fills the altmin-only row fields; otherwise they are None.
        inner_solver: stored as ``problem.solver``.
        solve_rank: stored as ``problem.rank``.
        solveropts: only the 'verbose' key is read (default False). None
            means no options.

    Returns:
        ``(rows, outlist)``: one dict per solver output, keyed by the names
        in FIELDNAMES, and the list of raw solver outputs.
    """
    if solveropts is None:
        solveropts = {}
    ALTMIN = (outer_solver == 'altmin')
    INNER_ITERS = 2500  # TODO: should be different for CVXOPT vs SCS

    # Pull from cache
    penconst = problem.cache['penconst']
    inner_type = problem.cache['inner_type']
    seed = problem.cache['seed']
    Xhash = problem.cache['Xhash']
    db = problem.cache['db']
    orig_error = problem.cache['orig_error']
    trueNorm = problem.cache['trueNorm']
    oper_rank = problem.cache['oper_rank']

    # Set properties
    problem.penconst = penconst*(2.0**penconst_offset)
    problem.rank = solve_rank
    problem.relconvergetol = relconvergetol
    problem.maxiters = outer_iters
    problem.solver = inner_solver
    # passing warm_start to matsolve is not a problem
    problem.solveropts = {
        'verbose': solveropts.get('verbose', False),
        'max_iters': INNER_ITERS,
    }
    # Compare by value: cvxpy solver names are strings, and `is` on strings
    # depends on interning (it fails e.g. for names read from a config file).
    if problem.solver == cvxpy.SCS:
        problem.solveropts['warm_start'] = True

    # Call the solver; it may return a single output or a list of them.
    out = solvers.SOLVERS[outer_solver](problem)
    outlist = out if isinstance(out, list) else [out]

    # Loop invariants: the ground truth and the factor-type label.
    true_vec = problem.trueOperator.asArrayOperator().flatten(order='F')
    type_label = '{0}, {1}'.format(*[fn.__name__.replace('rand_', '').replace('mat', '') for fn in inner_type])

    rows = []
    for result in outlist:
        Xout = result.recovered.asArrayOperator().flatten(order='F')
        abs_error = np.linalg.norm(Xout - true_vec, ord=2)
        rel_error = abs_error/trueNorm
        row = {'shape': problem.shape,
               'rank': oper_rank,
               'inner_type': type_label,
               'seed': seed,
               'hash': Xhash,
               'noise_level': db,
               'regularizer': str(problem.norm),
               # FIXME: base constant, not the applied problem.penconst;
               # do we really want this?
               'penconst': penconst,
               'penconst_offset': penconst_offset,
               'solve_rank': solve_rank if ALTMIN else None,
               'solver': '{0}_{1}'.format(outer_solver, inner_solver),
               'max_outer_iters': result.maxiters if ALTMIN else None,
               'max_inner_iters': INNER_ITERS,
               'relconvergetol': result.relconvtol if ALTMIN else None,
               'relchange': result.relchange if ALTMIN else None,
               'abs_error': abs_error,
               'rel_error': rel_error,
               'orig_error': orig_error,
               'objval': result.objval,
               'time': result.total_time,
               'outer_iters': result.outer_iters if ALTMIN else None}
        rows.append(row)
    return (rows, outlist)


def run_experiment(params, seed=None, redirect=False):
    """Run every configuration in the ``params`` grid and collect result rows.

    For each (shape, inner_type, rank) combination one operator is generated.
    Each noise level in ``params['db']`` then gets a noise draw on that
    operator. Every regularizer, penalty offset, solve rank, inner solver and
    solver offered by the regularizer is run through solve_problem.

    Skipped combinations:
      * 'altmin' when the operator rank exceeds solve_rank;
      * any other solver unless solve_rank is the first entry of
        ``params['solve_rank']`` (so those solvers run once per setting).

    Args:
        params: dict with the keys shown in PARAMS.
        seed: passed to ``np.random.seed`` before any data is generated.
        redirect: if True, stdout/stderr are replaced with line-buffered
            files '<pid>.out' and '<pid>_error.out' in the current
            directory. They are not restored or closed here.

    Returns:
        List of row dicts from solve_problem.
    """
    if redirect:
        # Redirect stdout and stderr for logging
        pid = str(os.getpid())
        sys.stdout = open(pid + '.out', 'wt', buffering=1)
        sys.stderr = open(pid + '_error.out', 'wt', buffering=1)

    np.random.seed(seed=seed)
    collected = []

    operator_grid = itertools.product(params['shape'], params['inner_type'],
                                      params['rank'])
    for shape, inner_type, rank in operator_grid:
        print('type: {0} rank: {1}'.format(inner_type, rank))
        problem = generate_operator(shape=shape, rank=rank, inner_type=inner_type)
        problem.cache['seed'] = seed

        for db in params['db']:
            generate_noise(problem, db=db)

            for reg in params['regularizer']:
                problem.norm = reg
                noise_sigma = problem.cache['sigma']
                problem.cache['penconst'] = regs.penconst_denoise(shape, noise_sigma, reg)

                solve_grid = itertools.product(params['penconst_offset'],
                                               params['solve_rank'],
                                               params['inner_solver'])
                for offset, solve_rank, inner_solver in solve_grid:
                    for solver in problem.norm.available_solvers():
                        if solver == 'altmin':
                            skip = rank > solve_rank
                        else:
                            skip = solve_rank != params['solve_rank'][0]
                        if skip:
                            continue
                        rows, _ = solve_problem(problem,
                                                penconst_offset=offset,
                                                solve_rank=solve_rank,
                                                relconvergetol=params['relconvergetol'],
                                                inner_solver=inner_solver,
                                                outer_iters=params['outer_iters'],
                                                outer_solver=solver,
                                                solveropts=params['solveropts'])
                        collected.extend(rows)

    return collected
