import numpy as np
import matplotlib.pyplot as plt

def huber(x, symmetric=True, threshold=1):
    """
    Compute the Huber loss for each residual in `x`.

    Residuals inside the threshold contribute ``x**2``. Residuals outside it
    contribute ``2 * threshold * |x| - threshold**2``. This linear branch
    equals ``threshold**2`` at ``|x| == threshold``, so the two branches join
    continuously.

    Parameters
    ----------
    x : array-like
        The residuals. Lists and scalars are accepted and converted to arrays.
    symmetric : bool, optional
        If True (default), the quadratic region is ``|x| < threshold``.
        If False, it is ``x < threshold``. All negative residuals then stay
        quadratic, and only residuals ``>= threshold`` use the linear branch.
    threshold : float, optional
        Residual value at which the loss switches from quadratic to linear.
        Default is 1.

    Returns
    -------
    numpy.ndarray
        The loss for each residual. It has the same shape as `x` and a
        floating-point dtype.
    """
    x = np.asarray(x)
    # Allocate a floating output so that integer residuals (or a non-integer
    # threshold) do not get silently truncated when the loss is stored.
    out = np.empty(x.shape, dtype=np.result_type(x, 1.0))
    if symmetric:
        mask = np.abs(x) < threshold
    else:
        mask = x < threshold
    out[mask] = x[mask]**2
    out[~mask] = 2 * threshold * np.abs(x[~mask]) - threshold**2

    return out

def truncated_quadratic(x, symmetric=True, threshold=1):
    """
    Compute the truncated-quadratic loss for each residual in `x`.

    Residuals inside the threshold contribute ``x**2``. Every residual outside
    it contributes the constant ``threshold**2``, so large residuals have a
    capped influence.

    Parameters
    ----------
    x : array-like
        The residuals. Lists and scalars are accepted and converted to arrays.
    symmetric : bool, optional
        If True (default), the quadratic region is ``|x| < threshold``.
        If False, it is ``x < threshold``. All negative residuals then stay
        quadratic, and only residuals ``>= threshold`` are capped.
    threshold : float, optional
        Residual value at which the loss becomes constant. Default is 1.

    Returns
    -------
    numpy.ndarray
        The loss for each residual. It has the same shape as `x` and a
        floating-point dtype.
    """
    x = np.asarray(x)
    # Allocate a floating output so that a non-integer threshold**2 is not
    # truncated when the input residuals are integers.
    out = np.empty(x.shape, dtype=np.result_type(x, 1.0))
    if symmetric:
        mask = np.abs(x) < threshold
    else:
        mask = x < threshold
    out[mask] = x[mask]**2
    out[~mask] = threshold**2

    return out

def indec(x, symmetric=True, threshold=1):
    """
    Compute the indec loss for each residual in `x`.

    Residuals inside the threshold contribute ``x**2``. Residuals outside it
    contribute ``threshold**3 / (2 * |x|) + threshold**2 / 2``. This branch
    equals ``threshold**2`` at ``|x| == threshold`` and decreases toward
    ``threshold**2 / 2`` as ``|x|`` grows.

    Parameters
    ----------
    x : array-like
        The residuals. Lists and scalars are accepted and converted to arrays.
    symmetric : bool, optional
        If True (default), the quadratic region is ``|x| < threshold``.
        If False, it is ``x < threshold``. All negative residuals then stay
        quadratic, and only residuals ``>= threshold`` use the decaying
        branch.
    threshold : float, optional
        Residual value at which the loss switches branches. Default is 1.

    Returns
    -------
    numpy.ndarray
        The loss for each residual. It has the same shape as `x` and a
        floating-point dtype.
    """
    x = np.asarray(x)
    # The outer branch divides by |x| and is therefore fractional; a floating
    # output prevents it from being truncated when `x` holds integers.
    out = np.empty(x.shape, dtype=np.result_type(x, 1.0))
    if symmetric:
        mask = np.abs(x) < threshold
    else:
        mask = x < threshold
    out[mask] = x[mask]**2
    out[~mask] = (threshold**3 / (2 * np.abs(x[~mask]))) + (threshold**2) / 2

    return out

# Single threshold used by every loss curve AND by the threshold markers drawn
# below, so the curves and the markers cannot drift apart if it is changed.
threshold = 1

x = np.linspace(-3, 3, 100)
y = x * x  # plain quadratic loss, drawn for reference
s_huber = huber(x, threshold=threshold)
a_huber = huber(x, False, threshold=threshold)
s_tquad = truncated_quadratic(x, threshold=threshold)
a_tquad = truncated_quadratic(x, False, threshold=threshold)
s_indec = indec(x, threshold=threshold)
a_indec = indec(x, False, threshold=threshold)

fig, (ax, ax2) = plt.subplots(
    1, 2, gridspec_kw={'hspace': 0, 'wspace': 0},
    tight_layout={'pad': 0.6, 'w_pad': 0, 'h_pad': 0}
)
# Left panel: symmetric versions of each loss.
ax.plot(
    x, y, '-',
    x, s_huber, '--',
    x, s_tquad, '-.',
    x, s_indec, ':'
)
# Right panel: asymmetric versions; the returned line handles feed the legend.
handles = ax2.plot(
    x, y, '-',
    x, a_huber, '--',
    x, a_tquad, '-.',
    x, a_indec, ':'
)

# Symmetric losses change branch at both +threshold and -threshold.
ax.axvline(threshold, ymax=0.7, color='black', linestyle=':')
ax.axvline(-threshold, ymax=0.7, color='black', linestyle=':')
# Label next to the right-hand marker; the y position (6.6) is hand-tuned
# for this data range.
ax.annotate('threshold', (threshold - 0.7, 6.6))
ax.set_title('Symmetric')
ax.annotate('residual, y - baseline', (2, -1.5), annotation_clip=False)
ax.set_ylabel('Contribution to cost function')

ax2.legend(handles, ('quadratic', 'Huber', 'truncated-quadratic', 'Indec'), frameon=False)
# Asymmetric losses only change branch at +threshold.
ax2.axvline(threshold, ymax=0.7, color='black', linestyle=':')
ax2.set_yticks([])
ax2.set_title('Asymmetric')

plt.show()