Evaluating the Transfer of Information in Phase Retrieval STEM Techniques
Contents
Pixelated Parallax
# enable interactive matplotlib
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import ctf # import custom plotting / utils
import cmasher as cmr
import ipywidgets
# --- simulation grid and optics ---------------------------------------------
n = 384  # number of samples per side of the reciprocal-space grid
q_max = 2  # inverse Angstroms
q_probe = 1  # inverse Angstroms
wavelength = 0.019687  # Angstroms (300kV)
sampling = 1 / q_max / 2  # Angstroms
reciprocal_sampling = 2 * q_max / n  # inverse Angstroms
bin_value = n // 96  # binning factor that brings the n-sample CTF down to 96 samples

# Initial aberrations, in Angstroms (the C30 slider value in µm is scaled by 1e4
# before it is passed to return_chi).
C10 = -128
C30 = 0

# --- plot styling -------------------------------------------------------------
cmap = cmr.viola
sample_cmap = 'gray'
pixelated_parallax_line_color = 'darkred'


# --- sample projected potentials ----------------------------------------------
def _load_zero_mean_potential(path):
    """Load a projected potential from ``path`` and subtract its mean in place."""
    potential = np.load(path)
    potential -= potential.mean()
    return potential


potentials = [
    _load_zero_mean_potential("data/STO_projected-potential_192x192_4qprobe.npy"),
    _load_zero_mean_potential("data/MOF_projected-potential_192x192_4qprobe.npy"),
    _load_zero_mean_potential("data/apoF_projected-potential_192x192_4qprobe.npy"),
]
sto_potential, mof_potential, apo_potential = potentials

# Real-space pixel size of each potential: field of view / number of pixels.
sto_sampling = 23.67 / sto_potential.shape[0]  # Å
mof_sampling = 4.48 / mof_potential.shape[0]  # nm
apo_sampling = 19.2 / apo_potential.shape[0]  # nm
def autocorrelation(array):
    """Return the real part of the circular 2D autocorrelation of ``array``.

    Computed through the FFT: the squared magnitude of the 2D Fourier
    transform is transformed back, and only the real part is kept.
    """
    spectrum = np.fft.fft2(array)
    power_spectrum = np.abs(spectrum) ** 2
    correlation = np.fft.ifft2(power_spectrum)
    return np.real(correlation)
def return_chi(q, wavelength, C10, C30):
    """Evaluate the aberration phase ``chi`` at spatial frequency ``q``.

    With ``alpha = q * wavelength``, the value returned is
    ``2*pi/wavelength * (C10 * alpha**2 / 2 + C30 * alpha**4 / 4)``.

    ``q`` may be a scalar or a NumPy array; the result has the same shape.
    ``C10`` and ``C30`` must use the same length unit as ``wavelength``, and
    ``q`` its inverse, so that ``alpha`` is dimensionless.
    """
    scattering_angle = q * wavelength
    defocus_term = scattering_angle**2 / 2 * C10
    spherical_term = scattering_angle**4 / 4 * C30
    return (defocus_term + spherical_term) * (2 * np.pi / wavelength)
# Reciprocal-space grid in FFT ordering, as returned by np.fft.fftfreq.
qx = qy = np.fft.fftfreq(n, sampling)
q2 = qx[:, None] ** 2 + qy[None, :] ** 2
q = np.sqrt(q2)

# Probe-forming aperture of radius q_probe: the clipped linear ramp softens the
# edge over roughly one reciprocal pixel, and the square root turns that
# intensity profile into an amplitude. Normalised to unit total intensity.
_aperture_intensity = np.clip(
    (q_probe - q) / reciprocal_sampling + 0.5,
    0,
    1,
)
probe_array_fourier_0 = np.sqrt(_aperture_intensity)
probe_array_fourier_0 /= np.sqrt(np.sum(np.abs(probe_array_fourier_0) ** 2))

# Parallax CTF for the initial aberrations: aperture autocorrelation times
# -sin(chi).
chi = return_chi(q, wavelength, C10, C30)
sin_chi = -np.sin(chi)
parallax_ctf_2D = autocorrelation(probe_array_fourier_0) * sin_chi

q_bins, I_bins = ctf.radially_average_ctf(
    parallax_ctf_2D,
    (sampling, sampling),
)

# Bin the CTF magnitude by bin_value along each axis. Taking np.abs here
# matches the initial (True) state of the "correct phase flipping" toggle.
_binned_size = n // bin_value
binned_ctf_to_96 = (
    np.abs(parallax_ctf_2D)
    .reshape((_binned_size, bin_value, _binned_size, bin_value))
    .mean((1, 3))
)

# Zero-pad the binned CTF (about its centre) up to the shape of the sample
# potentials. The pad width is derived from the array shapes rather than
# hard-coded; for the 192-pixel potentials and 96-pixel binned CTF it is 48 on
# every side. One padded CTF is shared by all three potentials, so they must
# have the same shape.
_ctf_pad_width = [
    ((target - current) // 2,) * 2
    for target, current in zip(sto_potential.shape, binned_ctf_to_96.shape)
]
zero_pad_ctf_to_4qprobe = np.fft.ifftshift(
    np.pad(np.fft.fftshift(binned_ctf_to_96), _ctf_pad_width)
)

# Apply the CTF to each potential by multiplication in Fourier space.
convolved_object_sto, convolved_object_mof, convolved_object_apo = (
    np.fft.ifft2(np.fft.fft2(potential) * zero_pad_ctf_to_4qprobe).real
    for potential in (sto_potential, mof_potential, apo_potential)
)

# Colour limits taken from the initial images; update_ctf reuses them whenever
# relative scaling is switched off.
limits = [
    [image.min(), image.max()]
    for image in (convolved_object_sto, convolved_object_mof, convolved_object_apo)
]
sto_limits, mof_limits, apo_limits = limits
# Create the figure with interactive mode switched off; the canvas is embedded
# in the widget layout assembled further down (presumably so that it is only
# displayed once -- confirm against the notebook).
with plt.ioff():
    dpi = 72
    fig, axs = plt.subplots(1, 3, figsize=(640 / dpi, 270 / dpi), dpi=dpi)

ctf_ax, radial_ax, object_ax = axs

# Panel 1: 2D CTF, shifted so that zero frequency sits at the image centre.
im_ctf = ctf_ax.imshow(
    np.fft.fftshift(parallax_ctf_2D),
    vmin=-1,
    vmax=1,
    cmap=cmap,
)
ctf.add_scalebar(
    ctf_ax,
    length=n // 4,
    sampling=reciprocal_sampling,
    units=r'$q_{\mathrm{probe}}$',
    color='black',
)
ctf_ax.set(xticks=[], yticks=[], title="contrast transfer function (CTF)")

# Panel 2: radially averaged CTF with a dashed zero line for reference.
(plot_ctf,) = radial_ax.plot(
    q_bins,
    I_bins,
    color=pixelated_parallax_line_color,
)
radial_ax.axhline(0, color='black', lw=1, linestyle='--')
radial_ax.set(
    xlim=[0, 2],
    ylim=[-1, 1],
    aspect=1,
    xticks=[0, 1, 2],
    yticks=[],
    xlabel=r"spatial frequency, $q/q_{\mathrm{probe}}$",
    title="radially averaged CTF",
)

# Panel 3: CTF-convolved sample, starting with strontium titanate.
im_obj = object_ax.imshow(
    convolved_object_sto,
    cmap=sample_cmap,
    vmin=sto_limits[0],
    vmax=sto_limits[1],
)

# One scalebar per sample, added in STO / MOF / apoferritin order; only the
# one belonging to the displayed sample is left visible.
for pixel_size, unit_label in (
    (sto_sampling, r'Å'),
    (mof_sampling, r'nm'),
    (apo_sampling, r'nm'),
):
    ctf.add_scalebar(
        object_ax,
        length=40,
        sampling=pixel_size,
        units=unit_label,
        size_vertical=2,
    )

# Relies on the three add_scalebar calls above having put exactly three
# artists on this axes.
sto_scalebar, mof_scalebar, apo_scalebar = object_ax.artists
for hidden_scalebar in (mof_scalebar, apo_scalebar):
    hidden_scalebar.set_visible(False)

object_ax.set(xticks=[], yticks=[], title="CTF-convolved weak phase object")

fig.tight_layout()

# Strip the canvas chrome and fix its size.
canvas = fig.canvas
canvas.resizable = False
canvas.header_visible = False
canvas.footer_visible = False
canvas.toolbar_visible = False
canvas.layout.height = "280px"
canvas.layout.width = '640px'
None
# Shared widget styling: half-width (320px) controls for the sliders,
# quarter-width (160px) controls for the buttons and the dropdown.
style = {'description_width': 'initial'}
layout_half = ipywidgets.Layout(width="320px", height="30px")
layout_quarter = ipywidgets.Layout(width="160px", height="30px")
kwargs = {'style': style, 'layout': layout_half}
kwargs_quarter = {'style': style, 'layout': layout_quarter}

# The sliders start from the module-level C10 / C30 that were used to render
# the initial figure, so the controls and the plots cannot drift apart if
# those constants are edited.
C10_slider = ipywidgets.FloatSlider(
    value=C10,
    min=-500,
    max=500,
    step=1,
    description=r"negative defocus, $C_{1,0}$ [Å]",
    **kwargs
)

C30_slider = ipywidgets.FloatSlider(
    value=C30 / 1e4,  # C30 is in Å; the slider is in µm (1 µm = 1e4 Å)
    min=-100,
    max=100,
    step=0.1,
    description=r"spherical aberration, $C_{3,0}$ [µm]",
    **kwargs
)

scherzer_button = ipywidgets.Button(
    description="use Scherzer defocus",
    **kwargs_quarter
)

clim_button = ipywidgets.ToggleButton(
    value=False,
    description="use relative scaling",
    **kwargs_quarter
)

# Starts enabled, matching the np.abs applied to the CTF when the initial
# object image was computed.
phase_flip_button = ipywidgets.ToggleButton(
    value=True,
    description="correct phase flipping",
    **kwargs_quarter
)

# Option values are indices into the `potentials` list.
object_dropdown = ipywidgets.Dropdown(
    options=[
        ("strontium titanate", 0),
        ("metal-organic framework", 1),
        ("apoferritin protein", 2),
    ],
    **kwargs_quarter
)
def update_ctf(*args):
    """Recompute the parallax CTF from the widget state and refresh the figure.

    Used as the observer callback of the sliders, the dropdown and the toggle
    buttons. Whatever the widget machinery passes in ``*args`` is ignored; the
    current state is always read straight from the widgets.

    Updates the 2D CTF image, the radially averaged CTF line, the
    CTF-convolved object image with its colour limits, and which scalebar is
    visible. Returns ``None``.
    """
    C10 = C10_slider.value
    C30 = C30_slider.value * 1e4  # slider is in µm; return_chi is fed Å
    object_index = object_dropdown.value
    chosen_potential = potentials[object_index]

    # Signed CTF: aperture autocorrelation times -sin(chi).
    chi = return_chi(q, wavelength, C10, C30)
    sin_chi = -np.sin(chi)
    parallax_ctf_2D = autocorrelation(probe_array_fourier_0) * sin_chi

    # The radial bin positions do not depend on the aberrations, so only the
    # averaged values are needed to update the existing line.
    _, I_bins = ctf.radially_average_ctf(
        parallax_ctf_2D,
        (sampling, sampling),
    )

    # Phase-flip correction only affects the CTF applied to the object; the
    # two CTF panels always show the signed CTF.
    if phase_flip_button.value:
        object_ctf = np.abs(parallax_ctf_2D)
    else:
        object_ctf = parallax_ctf_2D

    binned_size = n // bin_value
    binned_ctf = object_ctf.reshape(
        (binned_size, bin_value, binned_size, bin_value)
    ).mean((1, 3))

    # Zero-pad the binned CTF about its centre up to the potential's shape.
    # The pad width is derived from the shapes instead of being hard-coded
    # (48 for a 96-pixel CTF and a 192-pixel potential).
    pad_width = [
        ((target - current) // 2,) * 2
        for target, current in zip(chosen_potential.shape, binned_ctf.shape)
    ]
    padded_ctf = np.fft.ifftshift(
        np.pad(np.fft.fftshift(binned_ctf), pad_width)
    )

    convolved_object = np.fft.ifft2(
        np.fft.fft2(chosen_potential) * padded_ctf
    ).real

    im_ctf.set_data(np.fft.fftshift(parallax_ctf_2D))
    plot_ctf.set_ydata(I_bins)

    if clim_button.value:
        # Relative scaling: rescale the image and show it on a fixed [0, 1] range.
        convolved_object = ctf.histogram_scaling(convolved_object, normalize=True)
        vmin, vmax = 0, 1
    else:
        # Absolute scaling: reuse the limits stored for this sample in `limits`.
        vmin, vmax = limits[object_index][0], limits[object_index][1]
    im_obj.set_data(convolved_object)
    im_obj.set_clim(vmin=vmin, vmax=vmax)

    # Show only the scalebar that belongs to the selected sample.
    for index, scalebar in enumerate((sto_scalebar, mof_scalebar, apo_scalebar)):
        scalebar.set_visible(object_index == index)

    fig.canvas.draw_idle()
# Recompute and redraw whenever an aberration slider or the sample selection
# changes value.
for _control in (C10_slider, C30_slider, object_dropdown):
    _control.observe(update_ctf, "value")
def apply_scherzer(*args):
    """Set the defocus slider from the current spherical-aberration slider.

    The C30 slider value is scaled by 1e4 (µm to Å) and the defocus slider is
    assigned ``-sign(Cs) * sqrt(3/2 * |Cs| * wavelength)``. Any arguments
    (e.g. the clicked button) are ignored. Returns ``None``.
    """
    spherical_aberration = C30_slider.value * 1e4
    defocus_magnitude = np.sqrt(3 / 2 * np.abs(spherical_aberration) * wavelength)
    # Assigning the slider value notifies the slider's "value" observers.
    C10_slider.value = -np.sign(spherical_aberration) * defocus_magnitude
# Hook up the remaining controls: the button applies the Scherzer defocus, the
# two toggles trigger a redraw.
scherzer_button.on_click(apply_scherzer)
for _toggle in (clim_button, phase_flip_button):
    _toggle.observe(update_ctf, "value")

# Stack the controls above the figure canvas: sliders, then buttons and
# dropdown, then the plots. The VBox is the final expression of the cell.
_slider_row = ipywidgets.HBox([C10_slider, C30_slider])
_button_row = ipywidgets.HBox(
    [scherzer_button, clim_button, phase_flip_button, object_dropdown]
)
ipywidgets.VBox([_slider_row, _button_row, fig.canvas])
VBox(children=(HBox(children=(FloatSlider(value=-128.0, description='negative defocus, $C_{1,0}$ [Å]', layout=…