"""
originally created by Oliver Hahn
in PlotDaubechies.ipynb
"""
from math import pi
from pathlib import Path
from typing import List

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.lines import Line2D

from utils import figsize_from_page_fraction

# # The Cascade Algorithm to Compute the Wavelet and the Scaling Function
# Algorithm adapted from [this webpage](https://cnx.org/contents/0nnvPGYf@7/Computing-the-Scaling-Function-The-Cascade-Algorithm). The iterations are defined by
# $$ \varphi^{(k+1)}(t)=\sqrt{2} \;\sum_{n=0}^{N-1} h[n]\, \varphi^{(k)} (2t-n) $$
# for $k = 0, 1, 2, \ldots$, where an initial guess $\varphi^{(0)}(t)$ must be given.
#
# The idea of the cascade algorithm is to re-write this as an ordinary convolution, yielding
# $$ \varphi^{(k+1)}_{1/2}[t] := \varphi^{(k+1)} \left( \frac{t}{2} \right) = \sqrt{2} \; \sum_{n=0}^{N-1} h[n]\, \varphi^{(k)} (t-n) $$
# Consider the next iteration (and let $p:=2n$)
# $$
# \begin{align}
# \varphi^{(k+2)} \left( \frac{t}{4} \right) &= \sqrt{2} \; \sum_{n=0}^{N-1} h[n]\, \varphi^{(k+1)} \left(\frac{t}{2}-n\right) \\
# & = \sqrt{2} \; \sum_{p \;{\rm even}} h\left[\frac{p}{2}\right]\, \varphi^{(k+1)} \left(\frac{t-p}{2}\right) \\
# & = \sqrt{2} \; \sum_{p=0}^{2(N-1)} h_{\uparrow 2}\left[p\right]\, \varphi^{(k+1)}_{1/2} \left(t-p\right) \\
# \end{align}
# $$
# which defines the iterations of the cascade algorithm as simple convolutions $\varphi \to \sqrt{2} \,\varphi \ast h$ followed by upsampling of $h \to h_{\uparrow 2}$.
#
# And upsampling is defined in the usual way as ([see also here](https://cnx.org/contents/xsppCgXj@8.18:H_wA16rf@16/Upsampling))
# $$
# x_{\uparrow L}[n] :=
# \left\{ \begin{array}{ll}
# x[n/L] & \textrm{if } \frac{n}{L} \in \mathbb{Z} \\
# 0 & \textrm{otherwise}
# \end{array}\right. .
# $$
#
# Finally, to obtain the wavelet function, we also need the wavelet-scaling equation
# $$
# \psi(t) = \sqrt{2} \sum_{n=0}^{N-1} g[n] \, \varphi(2t-n)
# $$
# which can be applied in the cascade algorithm in the last iteration to obtain the wavelet rather than the scaling function.
#
# two-fold upsampling -- https://cnx.org/contents/xsppCgXj@8.18:H_wA16rf@16/Upsampling
def upsample(sig):
    # insert a zero between every two samples: x_{up2}[2n] = x[n], zero otherwise
    signew = np.zeros(len(sig) * 2 - 1)
    signew[::2] = sig
    return signew
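
# Quick sanity check of the upsampling convention defined above (a minimal added
# example, not from the original script): two-fold upsampling of [1, 2, 3] gives
# [1, 0, 2, 0, 3].
assert np.allclose(upsample(np.array([1.0, 2.0, 3.0])), [1.0, 0.0, 2.0, 0.0, 3.0])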

# use the cascade algorithm to compute the scaling function and the wavelet
# -- https://cnx.org/contents/0nnvPGYf@7/Computing-the-Scaling-Function-The-Cascade-Algorithm
def cascade_algorithm(h, g, maxit):
    x = np.arange(len(h), dtype=float)
    phi = np.ones(len(h))
    psi = np.ones(len(h))
    phi_it = np.copy(phi)
    psi_it = np.copy(psi)
    h_it = np.copy(h)
    g_it = np.copy(g)
    for it in range(maxit):
        # perform repeated convolutions
        phi_it = np.sqrt(2) * np.convolve(h_it, phi_it, mode="full")
        if it != maxit - 1:
            psi_it = np.sqrt(2) * np.convolve(h_it, psi_it, mode="full")
        else:
            # in the last iteration, apply the wavelet-scaling equation (filter g)
            # to obtain the wavelet instead of the scaling function
            psi_it = np.sqrt(2) * np.convolve(g_it, psi_it, mode="full")
        # upsample the coefficients
        h_it = upsample(h_it)
        g_it = upsample(g_it)
        # get the new mid-point positions
        xnew = np.zeros(len(x) + len(x) - 1)
        xnew[::2] = x
        xnew[1::2] = x[:-1] + 2 ** -(it + 1)
        x = xnew
    return x, phi_it / len(h), psi_it / len(h)
# # Wavelet Coefficients for DB2 to DB8
# (wavelets are labelled here by their number of filter taps, so DB4 corresponds to
# PyWavelets' "db2", DB8 to "db4", and so on)
maxit = 10
# Haar wavelet
h_Haar = np.array([0.7071067811865476, 0.7071067811865476])
g_Haar = np.array([0.7071067811865476, -0.7071067811865476])
xhaar, phihaar, psihaar = cascade_algorithm(h_Haar, g_Haar, maxit)
# DB4 -- http://wavelets.pybytes.com/wavelet/db2/
h_DB4 = np.array([0.4829629131, 0.8365163037, 0.2241438680, -0.1294095226])
g_DB4 = np.array([-0.1294095226, -0.2241438680, 0.8365163037, -0.4829629131])
xdb4, phidb4, psidb4 = cascade_algorithm(h_DB4, g_DB4, maxit)
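
# Optional cross-check of the cascade output above (an added sketch, not part of the
# original script): PyWavelets computes the same functions via Wavelet.wavefun().
# Note that pywt names wavelets by their number of vanishing moments ("db2" has four
# filter taps) and uses a different normalisation and x-grid than cascade_algorithm(),
# so only the overall shape is expected to agree with (xdb4, phidb4, psidb4).
try:
    import pywt

    phi_ref, psi_ref, x_ref = pywt.Wavelet("db2").wavefun(level=maxit)
except ImportError:
    pass  # pywt is not a dependency of this script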
# DB3 -- http://wavelets.pybytes.com/wavelet/db3/
# h_DB3 = np.array(
# [
# 0.3326705529509569,
# 0.8068915093133388,
# 0.4598775021193313,
# -0.13501102001039084,
# -0.08544127388224149,
# 0.035226291882100656,
# ]
# )
# g_DB3 = np.array(
# [
# 0.035226291882100656,
# 0.08544127388224149,
# -0.13501102001039084,
# -0.4598775021193313,
# 0.8068915093133388,
# -0.3326705529509569,
# ]
# )
#
# xdb3, phidb3, psidb3 = cascade_algorithm(h_DB3, g_DB3, maxit)
# DB8 -- http://wavelets.pybytes.com/wavelet/db4/
h_DB8 = np.array(
    [
        0.23037781330885523,
        0.7148465705525415,
        0.6308807679295904,
        -0.02798376941698385,
        -0.18703481171888114,
        0.030841381835986965,
        0.032883011666982945,
        -0.010597401784997278,
    ]
)
g_DB8 = np.array(
    [
        -0.010597401784997278,
        -0.032883011666982945,
        0.030841381835986965,
        0.18703481171888114,
        -0.02798376941698385,
        -0.6308807679295904,
        0.7148465705525415,
        -0.23037781330885523,
    ]
)
xdb8, phidb8, psidb8 = cascade_algorithm(h_DB8, g_DB8, maxit)
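
# Minimal added sanity check (not from the original script): each orthonormal
# quadrature-mirror filter pair should satisfy sum(h) = sqrt(2), sum(g) = 0 and
# sum(h**2) = 1.
for _h, _g in [(h_Haar, g_Haar), (h_DB4, g_DB4), (h_DB8, g_DB8)]:
    assert np.isclose(_h.sum(), np.sqrt(2))
    assert np.isclose(_g.sum(), 0.0)
    assert np.isclose((_h ** 2).sum(), 1.0)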
# # DB8 -- http://wavelets.pybytes.com/wavelet/db8/
# h_DB8 = np.array(
# [
# 0.05441584224308161,
# 0.3128715909144659,
# 0.6756307362980128,
# 0.5853546836548691,
# -0.015829105256023893,
# -0.2840155429624281,
# 0.00047248457399797254,
# 0.128747426620186,
# -0.01736930100202211,
# -0.04408825393106472,
# 0.013981027917015516,
# 0.008746094047015655,
# -0.00487035299301066,
# -0.0003917403729959771,
# 0.0006754494059985568,
# -0.00011747678400228192,
# ]
# )
# g_DB8 = np.array(
# [
# -0.00011747678400228192,
# -0.0006754494059985568,
# -0.0003917403729959771,
# 0.00487035299301066,
# 0.008746094047015655,
# -0.013981027917015516,
# -0.04408825393106472,
# 0.01736930100202211,
# 0.128747426620186,
# -0.00047248457399797254,
# -0.2840155429624281,
# 0.015829105256023893,
# 0.5853546836548691,
# -0.6756307362980128,
# 0.3128715909144659,
# -0.05441584224308161,
# ]
# )
#
# xdb8, phidb8, psidb8 = cascade_algorithm(h_DB8, g_DB8, maxit)
# # DB16 --
# h_DB16 = np.array(
# [
# 0.0031892209253436892,
# 0.03490771432362905,
# 0.1650642834886438,
# 0.43031272284545874,
# 0.6373563320829833,
# 0.44029025688580486,
# -0.08975108940236352,
# -0.3270633105274758,
# -0.02791820813292813,
# 0.21119069394696974,
# 0.027340263752899923,
# -0.13238830556335474,
# -0.006239722752156254,
# 0.07592423604445779,
# -0.007588974368642594,
# -0.036888397691556774,
# 0.010297659641009963,
# 0.013993768859843242,
# -0.006990014563390751,
# -0.0036442796214883506,
# 0.00312802338120381,
# 0.00040789698084934395,
# -0.0009410217493585433,
# 0.00011424152003843815,
# 0.00017478724522506327,
# -6.103596621404321e-05,
# -1.394566898819319e-05,
# 1.133660866126152e-05,
# -1.0435713423102517e-06,
# -7.363656785441815e-07,
# 2.3087840868545578e-07,
# -2.1093396300980412e-08,
# ]
# )
# g_DB16 = np.array(
# [
# -2.1093396300980412e-08,
# -2.3087840868545578e-07,
# -7.363656785441815e-07,
# 1.0435713423102517e-06,
# 1.133660866126152e-05,
# 1.394566898819319e-05,
# -6.103596621404321e-05,
# -0.00017478724522506327,
# 0.00011424152003843815,
# 0.0009410217493585433,
# 0.00040789698084934395,
# -0.00312802338120381,
# -0.0036442796214883506,
# 0.006990014563390751,
# 0.013993768859843242,
# -0.010297659641009963,
# -0.036888397691556774,
# 0.007588974368642594,
# 0.07592423604445779,
# 0.006239722752156254,
# -0.13238830556335474,
# -0.027340263752899923,
# 0.21119069394696974,
# 0.02791820813292813,
# -0.3270633105274758,
# 0.08975108940236352,
# 0.44029025688580486,
# -0.6373563320829833,
# 0.43031272284545874,
# -0.1650642834886438,
# 0.03490771432362905,
# -0.0031892209253436892,
# ]
# )
#
# xdb16, phidb16, psidb16 = cascade_algorithm(h_DB16, g_DB16, maxit)
###################################
fig: Figure
fig, ax = plt.subplots(
    3,
    2,
    figsize=figsize_from_page_fraction(),
    # sharex="all", sharey="all"
)
fig.subplots_adjust(hspace=0, wspace=0)
labels = ["Haar (DB2)", "DB4", "DB8", "DB16"]
prange = [[0.0, 1.0], [0.0, 3.0], [0.0, 7.0]]
for p in prange:
    p[0] -= 0.25
    p[1] += 0.25
ax[0, 0].set_title("scaling functions $\\varphi$")
ax[0, 1].set_title("wavelets $\\psi$")
ax[0, 0].plot(xhaar, phihaar, lw=1)
ax[0, 1].plot(xhaar, psihaar, lw=1)
ax[1, 0].plot(xdb4, phidb4, lw=1)
ax[1, 1].plot(xdb4, psidb4, lw=1)
ax[2, 0].plot(xdb8, phidb8, lw=1)
ax[2, 1].plot(xdb8, psidb8, lw=1)
for i, a in enumerate(ax.flatten()):
    if i in [4, 5]:
        # only the bottom row gets an x-axis label
        a.set_xlabel("t")
    a.plot(prange[i // 2], [0.0, 0.0], "k:", lw=1)
    a.set_xlim(prange[i // 2])
    a.set_xticks(np.arange(prange[i // 2][1] - 0.25 + 1))
    a.set_yticks([-1.0, 0.0, 1.0])
def inset_label(ax: Axes, text: str, top=False):
    ax.text(
        0.98,
        0.90 if top else 0.05,
        text,
        horizontalalignment="right",
        verticalalignment="top" if top else "bottom",
        transform=ax.transAxes,
    )
for a, label in zip(ax[:, 0], labels):
    text = r"$\varphi_{\textrm{LABEL}}$".replace("LABEL", label)
    inset_label(a, text)
    a.set_ylim([-1.0, 1.499])

for a, label in zip(ax[:, 1], labels):
    text = r"$\psi_{\textrm{LABEL}}$".replace("LABEL", label)
    inset_label(a, text, top=True)
    # a.set_ylabel('$\\psi_{\\rm ' + i + '}[t]$')
    a.set_ylim([-1.499, 1.999])
fig.tight_layout()
fig.savefig(Path("~/tmp/wavelets.pdf").expanduser())
# # Spectral Response of Scaling Functions and Wavelets
fig2: Figure = plt.figure(figsize=figsize_from_page_fraction())
ax: Axes = fig2.gca()
def fourier_wavelet(h, g, n):
    k = np.linspace(0, np.pi, n)
    fphi = np.zeros(n, dtype=complex)
    fpsi = np.zeros(n, dtype=complex)
    N = len(h)
    for i in range(N):
        fphi += np.exp(1j * k * i) * h[i] / np.sqrt(2)
        fpsi += np.exp(1j * k * i) * g[i] / np.sqrt(2)
    return k, fphi, fpsi
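
# Small added check (not in the original script): at k = 0 the scaling-function
# (low-pass) response is 1 and the wavelet (high-pass) response is 0, since
# sum(h) = sqrt(2) and sum(g) = 0.
_k0, _fphi0, _fpsi0 = fourier_wavelet(h_Haar, g_Haar, 4)
assert np.isclose(np.abs(_fphi0[0]) ** 2, 1.0)
assert np.isclose(np.abs(_fpsi0[0]) ** 2, 0.0)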
# ax.plot([0, np.pi], [0., 0.], 'k:')
# ax.plot([0, np.pi], [1., 1.], 'k:')
kh, fphih, fpsih = fourier_wavelet(h_Haar, g_Haar, 256)
ax.plot(kh, np.abs(fphih) ** 2, label=r"$\hat\varphi_\textrm{Haar (DB2)}$", c="C0")
ax.plot(
    kh,
    np.abs(fpsih) ** 2,
    label=r"$\hat\psi_\textrm{Haar (DB2)}$",
    c="C0",
    linestyle="dashed",
)
kdb4, fphidb4, fpsidb4 = fourier_wavelet(h_DB4, g_DB4, 256)
ax.plot(kdb4, np.abs(fphidb4) ** 2, label=r"$\hat\varphi_\textrm{DB4}$", c="C1")
ax.plot(
    kdb4,
    np.abs(fpsidb4) ** 2,
    label=r"$\hat\psi_\textrm{DB4}$",
    c="C1",
    linestyle="dashed",
)
kdb8, fphidb8, fpsidb8 = fourier_wavelet(h_DB8, g_DB8, 256)
ax.plot(kdb8, np.abs(fphidb8) ** 2, label=r"$\hat\varphi_\textrm{DB8}$", c="C2")
ax.plot(
    kdb8,
    np.abs(fpsidb8) ** 2,
    label=r"$\hat\psi_\textrm{DB8}$",
    c="C2",
    linestyle="dashed",
)
# all k* are np.linspace(0, np.pi, 256), so we can also use them for shannon
def shannon(k):
    # ideal (brick-wall) half-band response: 1 above k = pi/2, 0 below
    y = np.zeros_like(k)
    y[k > pi / 2] = 1
    return y
ax.plot(kdb8, 1 - shannon(kdb8), label=r"$\hat\varphi_\textrm{shannon}$", c="C4")
ax.plot(
    kdb8,
    shannon(kdb8),
    label=r"$\hat\psi_\textrm{shannon}$",
    c="C4",
    linestyle="dashed",
)
# ax.plot(kdb8, np.abs(fpsidb8) ** 2, label='$\\hat\\psi_{DB8}$', c="C3", linestyle="dashed")
# kdb16, fphidb16, fpsidb16 = fourier_wavelet(h_DB16, g_DB16, 256)
# ax.plot(kdb16, np.abs(fphidb16) ** 2, label='$\\hat\\varphi_{DB16}$', c="C4")
# ax.plot(kdb16, np.abs(fpsidb16) ** 2, label='$\\hat\\psi_{DB16}$', c="C4", linestyle="dashed")
lines: List[Line2D] = ax.get_lines()
philines = []
psilines = []
for line in lines:
    if "phi" in line.get_label():
        philines.append(line)
    else:
        psilines.append(line)
leg1 = ax.legend(frameon=False, handles=philines, loc="center left")
leg2 = ax.legend(frameon=False, handles=psilines, loc="center right")
ax.add_artist(leg1)
ax.add_artist(leg2)
ax.set_xlabel("k")
ax.set_ylabel("P(k)")
ax.set_xticks([0, pi / 2, pi])
ax.set_xticklabels(
    ["0", r"$k_\textrm{coarse}^\textrm{ny}$", r"$k_\textrm{fine}^\textrm{ny}$"]
)
# plt.semilogy()
# plt.ylim([1e-4,2.0])
fig2.tight_layout()
fig2.savefig(Path("~/tmp/wavelet_power.pdf").expanduser())
plt.show()
f"C{i}"