import numpy as np
import matplotlib.pyplot as plt
from pybaselines.utils import gaussian
from pybaselines import Baseline


def create_data():
    """Build five synthetic noisy signals together with their true baselines.

    Each signal is a sum of Gaussian peaks on top of one of five baseline
    shapes, plus normally distributed noise drawn from a fixed-seed generator
    so the output is identical on every run.

    Returns
    -------
    x : numpy.ndarray
        500 evenly spaced points from 1 to 1000.
    data : tuple of numpy.ndarray
        The five measured signals ``(y1, y2, y3, y4, y5)``.
    baselines : tuple of numpy.ndarray
        The true baseline underlying each signal, in the same order as `data`.

    """
    x = np.linspace(1, 1000, 500)
    # NOTE(review): the positional arguments to ``gaussian`` are presumably
    # (x, height, center, sigma) -- confirm against pybaselines.utils.gaussian.
    signal = (
        gaussian(x, 6, 180, 5)
        + gaussian(x, 8, 350, 10)
        + gaussian(x, 6, 550, 5)
        + gaussian(x, 9, 800, 10)
    )
    signal_2 = (
        gaussian(x, 9, 100, 12)
        + gaussian(x, 15, 400, 8)
        + gaussian(x, 13, 700, 12)
        + gaussian(x, 9, 880, 8)
    )
    signal_3 = (
        gaussian(x, 8, 150, 10)
        + gaussian(x, 20, 120, 12)
        + gaussian(x, 16, 300, 20)
        + gaussian(x, 12, 550, 5)
        + gaussian(x, 20, 750, 12)
        + gaussian(x, 18, 800, 18)
        + gaussian(x, 15, 830, 12)
    )
    # Seeded generator so the noise (mean 0, standard deviation 0.2) is
    # reproducible between runs.
    noise = np.random.default_rng(1).normal(0, 0.2, x.size)
    linear_baseline = 3 + 0.01 * x
    exponential_baseline = 5 + 15 * np.exp(-x / 400)
    gaussian_baseline = 5 + gaussian(x, 20, 500, 500)

    baseline_1 = linear_baseline
    baseline_2 = gaussian_baseline
    baseline_3 = exponential_baseline
    # Slowly decreasing line with a broad bump added near x = 850.
    baseline_4 = 10 - 0.005 * x + gaussian(x, 5, 850, 200)
    # Same slope as baseline_1, shifted up by 20.
    baseline_5 = linear_baseline + 20

    # y1 uses 5x the base noise level; y4 uses half of it.
    y1 = signal * 2 + baseline_1 + 5 * noise
    y2 = signal + signal_2 + signal_3 + baseline_2 + noise
    y3 = signal + signal_2 + baseline_3 + noise
    y4 = signal + signal_2 + baseline_4 + noise * 0.5
    # signal_2 is subtracted, so y5 contains both positive and negative peaks.
    y5 = signal * 2 - signal_2 + baseline_5 + noise

    baselines = (baseline_1, baseline_2, baseline_3, baseline_4, baseline_5)
    data = (y1, y2, y3, y4, y5)

    return x, data, baselines


def create_plots(data=None, baselines=None):
    """Create a 3x2 grid of tick-free axes and optionally plot data and baselines.

    The first ``len(axes) - 1`` axes receive one dataset/baseline each. The
    last axis is reserved for the legend: an empty dashed green line is added
    there so the estimated-baseline entry has a handle.

    Parameters
    ----------
    data : Sequence of array-like, optional
        Signals to plot, one per axis. Entries beyond the number of available
        (non-legend) axes are ignored. If None, no data is plotted.
    baselines : Sequence of array-like, optional
        True baselines to plot, one per axis, drawn with a thicker line.
        Entries beyond the number of available axes are ignored. If None,
        no baselines are plotted.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure.
    axes : numpy.ndarray
        The flattened array of the six axes.
    legend_handles : list
        Positional ``[data_handle, baseline_handle, fit_handle]``; an entry is
        None when the corresponding lines were not plotted. The legend is only
        drawn when none of the entries is None.

    """
    fig, axes = plt.subplots(
        3, 2, tight_layout={'pad': 0.1, 'w_pad': 0, 'h_pad': 0},
        gridspec_kw={'wspace': 0, 'hspace': 0}
    )
    axes = axes.ravel()
    # The final axis holds only the legend, so never plot a dataset on it.
    num_plot_axes = len(axes) - 1

    data_handle = None
    baseline_handle = None
    for i, axis in enumerate(axes):
        axis.set_xticks([])
        axis.set_yticks([])
        axis.tick_params(
            which='both', labelbottom=False, labelleft=False,
            labeltop=False, labelright=False
        )
        if i >= num_plot_axes:
            continue
        # Guard on length so shorter inputs do not raise IndexError.
        if data is not None and i < len(data):
            data_handle = axis.plot(data[i])[0]
        if baselines is not None and i < len(baselines):
            baseline_handle = axis.plot(baselines[i], lw=2.5)[0]
    # Empty line whose only purpose is to provide a legend handle.
    fit_handle = axes[-1].plot((), (), 'g--')[0]

    # Keep each handle in its own slot so callers can index positionally.
    legend_handles = [data_handle, baseline_handle, fit_handle]
    if None not in legend_handles:
        axes[-1].legend(
            legend_handles,
            ('data', 'real baseline', 'estimated baseline'),
            loc='center', frameon=False
        )

    return fig, axes, legend_handles


x, data, baselines = create_data()
baseline_fitter = Baseline(x, check_finite=False)

figure, axes, handles = create_plots(data, baselines)
for i, (ax, y) in enumerate(zip(axes, data)):
    # Smoothness per dataset index; anything not listed falls back to 1e5.
    # Datasets 0 and 4 sit on purely linear baselines, so they presumably
    # tolerate the stiffest fit.
    lam = {0: 5e8, 1: 5e6, 4: 5e8}.get(i, 1e5)
    # Dataset 4 has both positive and negative peaks (signal_2 is subtracted
    # in create_data), so it is fit with the symmetric option and p=0.5.
    symmetric = i == 4
    p = 0.5 if symmetric else 0.01
    baseline, params = baseline_fitter.mixture_model(
        y, lam=lam, p=p, symmetric=symmetric
    )
    ax.plot(baseline, 'g--')