Evaluating the Transfer of Information in Phase Retrieval STEM Techniques
Contents
iCOM SSNR Widget
# enable interactive matplotlib
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.patches as mpatches
import ctf # import custom plotting / utils
import cmasher as cmr
import ipywidgets
from IPython.display import display
4D STEM Simulation¶
# parameters
# global simulation settings shared by every cell below
n = 96 # number of grid points along each axis of the simulated arrays
q_max = 2 # inverse Angstroms
q_probe = 1 # inverse Angstroms; cutoff of the soft probe aperture
wavelength = 0.019687 # 300kV
sampling = 1 / q_max / 2 # Angstroms
reciprocal_sampling = 2 * q_max / n # inverse Angstroms
scan_step_size = 1 # pixels
sx = sy = n//scan_step_size # number of scan positions along each axis
phi0 = 1.0 # Fourier amplitude handed to white_noise_object_2D
cmap = cmr.eclipse # colormap used for the 2D SSNR images
# line colors for the radially-averaged curves
segmented_icom_line_color = 'cornflowerblue'
pixelated_icom_line_color = 'midnightblue'
White Noise Potential¶
def white_noise_object_2D(n, phi0, rng=None):
    """Create a real-valued (n, n) array whose FFT has random phase and constant amplitude.

    Parameters
    ----------
    n : int
        Number of pixels along each axis.
    phi0 : float
        Amplitude assigned to every Fourier coefficient.
    rng : numpy.random.Generator, optional
        Source of randomness, for reproducible objects. When omitted the
        global NumPy random state (``np.random.randn``) is used, so existing
        callers and ``np.random.seed`` keep working unchanged.

    Returns
    -------
    numpy.ndarray
        Real array of shape (n, n).
    """
    is_even = n % 2 == 0

    # indices of the positive frequencies and of their matching -k partners
    pos_ind = np.arange(1, (n if is_even else n + 1) // 2)
    neg_ind = np.flip(np.arange(n // 2 + 1, n))

    # random phase, in units of 2*pi
    if rng is None:
        phase = np.random.randn(n, n)
    else:
        phase = rng.standard_normal((n, n))

    # enforce phase(-k) = -phase(k) so that the inverse FFT is real
    # top-left // bottom-right
    phase[pos_ind[:, None], pos_ind[None, :]] = -phase[neg_ind[:, None], neg_ind[None, :]]
    # bottom-left // top-right
    phase[pos_ind[:, None], neg_ind[None, :]] = -phase[neg_ind[:, None], pos_ind[None, :]]
    # kx=0
    phase[0, pos_ind] = -phase[0, neg_ind]
    # ky=0
    phase[pos_ind, 0] = -phase[neg_ind, 0]

    # components without a distinct k -> -k partner get zero *phase*
    # (their amplitude stays phi0), which keeps them real
    if is_even:
        phase[n // 2, :] = 0  # highest spatial frequency
        phase[:, n // 2] = 0  # highest spatial frequency
    phase[0, 0] = 0  # DC component

    # fourier-array with constant amplitude phi0
    fourier_array = np.exp(2j * np.pi * phase) * phi0

    # inverse FFT; taking the real part removes floating point residue
    return np.fft.ifft2(fourier_array).real
# potential
# random white-noise array, used below as the phase of a unit-amplitude complex object
potential = white_noise_object_2D(n,phi0)
complex_obj = np.exp(1j*potential)
Probe¶
# we build probe in Fourier space, using a soft aperture
# corner-centered spatial frequency grid and its magnitude
qx = qy = np.fft.fftfreq(n,sampling)
q2 = qx[:,None]**2 + qy[None,:]**2
q = np.sqrt(q2)
# scan positions in pixels, one every scan_step_size along each axis
x = y = np.arange(0.,n,scan_step_size)
xx, yy = np.meshgrid(x,y,indexing='ij')
positions = np.stack((xx.ravel(),yy.ravel()),axis=-1)
# NOTE(review): presumably row/col index the object patch seen at each scan position — confirm against ctf.return_patch_indices
row, col = ctf.return_patch_indices(positions,(n,n),(n,n))
# aperture amplitude: the clipped quantity is 1 well inside q_probe, 0 well outside,
# and ramps linearly across one reciprocal-space pixel centred on q_probe
probe_array_fourier_0 = np.sqrt(
np.clip(
(q_probe - q)/reciprocal_sampling + 0.5,
0,
1,
),
)
def simulate_intensities(defocus):
    """Simulate the 4D dataset for a probe with the given defocus.

    Parameters
    ----------
    defocus : float
        Defocus entering the aberration phase ``pi * wavelength * q**2 * defocus``.

    Returns
    -------
    intensities : numpy.ndarray
        Array reshaped to (sx, sy, n, n).
    probe_array_fourier : numpy.ndarray
        Normalized Fourier-space probe that was used.
    """
    # apply the defocus phase to the soft aperture
    defocus_phase = np.exp(-1j * np.pi * wavelength * q**2 * defocus)
    probe_array_fourier = probe_array_fourier_0 * defocus_phase

    # normalized s.t. np.sum(np.abs(probe_array_fourier)**2) = 1.0
    fourier_norm = np.sqrt(np.sum(np.abs(probe_array_fourier)**2))
    probe_array_fourier /= fourier_norm

    # the factor n keeps np.sum(np.abs(probe_array)**2) = 1.0 after the inverse FFT
    probe_array = np.fft.ifft2(probe_array_fourier) * n

    # NOTE(review): ctf.simulate_data presumably returns per-position amplitudes
    # that are squared here — confirm against the helper
    simulated = ctf.simulate_data(complex_obj, probe_array, row, col)
    intensities = simulated.reshape((sx, sy, n, n))**2 / n**2

    return intensities, probe_array_fourier
# initial simulation at zero defocus
ints, probe = simulate_intensities(defocus=0)
# single-element lists, so that simulate() can swap the contents in place
intensities = [ints]
probe_array_fourier = [probe]
# precompute sum needed for DPC
intensities_sum = [intensities[0].sum((-1,-2))]
Virtual Detectors and CoM calculation¶
def annular_segmented_detectors(
    gpts,
    sampling,
    n_angular_bins,
    rotation_offset=0,
    inner_radius=0,
    outer_radius=np.inf,
):
    """Build the masks of an annular detector split into angular segments.

    Parameters
    ----------
    gpts : tuple of int
        Number of grid points (nx, ny).
    sampling : tuple of float
        Real-space sampling (sx, sy) passed to ``np.fft.fftfreq``.
    n_angular_bins : int
        Number of angular segments.
    rotation_offset : float, optional
        Angle in radians added to the azimuth before binning.
    inner_radius, outer_radius : float, optional
        The annulus keeps spatial frequencies with
        ``inner_radius <= k < outer_radius``.

    Returns
    -------
    list of numpy.ndarray
        ``n_angular_bins`` integer 0/1 masks of shape (nx, ny), fft-shifted so
        that zero frequency sits at the array center.
    """
    nx, ny = gpts
    sx, sy = sampling

    # corner-centered frequency grid and radial selection
    k_x = np.fft.fftfreq(nx, sx)
    k_y = np.fft.fftfreq(ny, sy)
    k = np.sqrt(k_x[:, None]**2 + k_y[None, :]**2)
    radial_mask = ((inner_radius <= k) & (k < outer_radius))

    # azimuth in [0, 2*pi), then 1-based segment label
    theta = (np.arctan2(k_y[None, :], k_x[:, None]) + rotation_offset) % (2 * np.pi)
    angular_bins = np.floor(n_angular_bins * (theta / (2 * np.pi))) + 1

    # the floating-point modulo of a tiny negative angle rounds up to exactly
    # 2*pi, which would produce label n_angular_bins + 1 and drop the pixel
    # from every segment; such angles belong to the last segment
    angular_bins[angular_bins > n_angular_bins] = n_angular_bins

    # label 0 marks pixels outside the annulus
    angular_bins *= radial_mask.astype("int")

    return [
        np.fft.fftshift((angular_bins == label).astype("int"))
        for label in range(1, n_angular_bins + 1)
    ]
def compute_com_using_virtual_detectors(
    corner_centered_intensities,
    center_centered_masks,
    corner_centered_intensities_sum,
    sx, sy,
    kxa, kya,
):
    """Estimate the center of mass from a set of virtual detector segments.

    Each segment contributes its integrated, normalized signal multiplied by
    the mean spatial frequency of the pixels it covers.

    Parameters
    ----------
    corner_centered_intensities : numpy.ndarray
        4D array indexed as (scan x, scan y, k x, k y), zero frequency at the corner.
    center_centered_masks : sequence of numpy.ndarray
        Detector masks with zero frequency at the center; nonzero entries
        mark the pixels belonging to a segment.
    corner_centered_intensities_sum : numpy.ndarray
        Per-scan-position normalization, broadcastable against (sx, sy).
    sx, sy : int
        Number of scan positions along each axis.
    kxa, kya : numpy.ndarray
        Corner-centered spatial frequency coordinates, indexed like one
        diffraction pattern.

    Returns
    -------
    com_x, com_y : numpy.ndarray
        Center-of-mass components of shape (sx, sy).
    """
    # bring the masks to the same (corner-centered) convention as the data
    masks = np.fft.ifftshift(np.asarray(center_centered_masks), axes=(-1, -2))

    com_x = np.zeros((sx, sy))
    com_y = np.zeros((sx, sy))

    # use the kxa/kya the caller passed in; they were previously discarded
    # and recomputed from module-level globals
    for mask in masks:
        rows, cols = np.where(mask)

        # an empty segment collects no signal; skipping it avoids
        # np.mean([]) = nan turning the whole result into nan
        if rows.size == 0:
            continue

        fraction = (
            corner_centered_intensities[:, :, rows, cols].sum(-1)
            / corner_centered_intensities_sum
        )
        com_x += fraction * np.mean(kxa[rows, cols])
        com_y += fraction * np.mean(kya[rows, cols])

    return com_x, com_y
def integrate_com(
    com_x,
    com_y,
    kx_op,
    ky_op,
):
    """Integrate the center-of-mass components in Fourier space.

    Each component is Fourier transformed, multiplied by its operator
    (``kx_op`` for ``com_x``, ``ky_op`` for ``com_y``), and the real part of
    the inverse transform of their sum is returned.
    """
    weighted_x = np.fft.fft2(com_x) * kx_op
    weighted_y = np.fft.fft2(com_y) * ky_op
    integrated = np.fft.ifft2(weighted_x + weighted_y)
    return integrated.real
Compute CTFs and initial values¶
# Spatial frequencies
# corner-centered grid, stored as float32
kx = ky = np.fft.fftfreq(n,sampling).astype(np.float32)
kxa, kya = np.meshgrid(kx, ky, indexing='ij')
k2 = kxa**2 + kya**2
# k is taken before the DC entry of k2 is overwritten, so k[0, 0] stays 0
k = np.sqrt(k2)
# avoid dividing by zero at DC: the operators below evaluate to 0 there
k2[0, 0] = np.inf
# iCoM operators
kx_op = -1.0j * kxa / k2
ky_op = -1.0j * kya / k2
# Compute the inverse error
# frequency-dependent weight that multiplies each CTF before it is plotted as an SSNR
inverse_error = (k*np.pi/np.sqrt(2))
# Analytical CTF (probe autocorrelation)
ctf_analytic = np.real(
np.fft.ifft2(
np.abs(
np.fft.fft2(
probe_array_fourier[0]
)
)**2
)
)
# Radially-averaged CTF and SNR
q_bins_analytic_snr, I_bins_analytic_snr = ctf.radially_average_ctf(ctf_analytic*inverse_error,(sampling,sampling))
# Initial masks and CoM
# four segments between q_probe/2 and 1.05*q_probe, matching the slider defaults
virtual_masks_annular = annular_segmented_detectors(
gpts=(n,n),
sampling=(sampling,sampling),
n_angular_bins=4,
inner_radius=q_probe/2,
outer_radius=q_probe*1.05,
rotation_offset=0,
)
com_x, com_y = compute_com_using_virtual_detectors(
intensities[0],
virtual_masks_annular,
intensities_sum[0],
sx,sy,
kxa,kya,
)
# integrate the CoM and derive the segmented-detector CTF from the result
icom_annular = integrate_com(com_x,com_y,kx_op,ky_op)
ctf_annular = ctf.compute_ctf(icom_annular)
q_bins_annular_snr, I_bins_annular_snr = ctf.radially_average_ctf(
ctf_annular*inverse_error,
(sampling,sampling)
)
Visualization¶
Base Plot¶
We make the interactive plot using the initial values, and name the artists (imshow, plot) we want to modify later.
Note: I use 2-98% histogram scaling, and I normalize the values to lie within 0-1 (to avoid having to modify the clims)
# Build the base figure with interactive mode off, keeping a handle on every
# artist (imshow images, line plots) that the widget callbacks modify later.
with plt.ioff():
    dpi = 72
    fig, axs = plt.subplots(1, 4, figsize=(640/dpi, 210/dpi), dpi=dpi)

    # detector
    ax_detector = axs[0]
    im_detector = ax_detector.imshow(ctf.combined_images_rgb(virtual_masks_annular))

    # analytic SNR
    ax_snr_analytic = axs[1]
    im_snr_analytic = ax_snr_analytic.imshow(
        ctf.histogram_scaling(np.fft.fftshift((ctf_analytic*inverse_error)), normalize=True),
        cmap=cmap,
    )

    # annular SNR
    ax_snr_annular = axs[2]
    im_snr = ax_snr_annular.imshow(
        ctf.histogram_scaling(np.fft.fftshift(ctf_annular*inverse_error), normalize=True),
        cmap=cmap,
    )

    # analytic SNR radially-averaged
    ax_snr_rad = axs[3]
    plot_snr_analytic = ax_snr_rad.plot(
        q_bins_analytic_snr, I_bins_analytic_snr, color=pixelated_icom_line_color
    )[0]
    plot_snr = ax_snr_rad.plot(
        q_bins_annular_snr, I_bins_annular_snr, color=segmented_icom_line_color
    )[0]

    # remove ticks, add titles to 2D-plots
    for ax, title in zip(
        axs.T.flatten(),
        [
            "detector geometry",
            "pixelated SSNR",
            "segmented SSNR",
            "radially averaged SSNR",
        ],
    ):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)

    # one scalebar per 2D panel; the detector panel is covered by this loop,
    # so it is not given a separate (duplicate) scalebar above
    for ax in axs[:3]:
        ctf.add_scalebar(ax, length=n//4, sampling=reciprocal_sampling, units=r'$q_{\mathrm{probe}}$')

    # radially-averaged panel: axis limits, x-ticks in units of q_probe, x-label,
    # and dashed vlines marking the inner/outer collection angles
    ax_snr_rad.set_ylim([0, 1])
    ax_snr_rad.set_xlim([0, q_max])
    ax_snr_rad.vlines([q_probe/2, q_probe*1.05], 0, 2, colors='k', linestyles='--', linewidth=1,)
    ax_snr_rad.set_xticks([0, q_probe, q_max])
    ax_snr_rad.set_xticklabels([0, 1, 2])
    ax_snr_rad.set_xlabel(r"spatial frequency, $q/q_{\mathrm{probe}}$")

    # fix ipympl canvas from resizing
    fig.canvas.resizable = False
    fig.canvas.header_visible = False
    fig.canvas.footer_visible = False
    fig.canvas.toolbar_visible = True
    fig.canvas.toolbar_position = 'bottom'
    fig.canvas.layout.width = '640px'
    fig.canvas.layout.height = '235px'
    fig.tight_layout()
# fig
Interactive Updating¶
We need to update four parts of the plot each time a widget value changes:
- The virtual masks (`im_detector`)
- The annular SNR (`im_snr`)
- The radially-averaged annular SNR (`plot_snr`)
- The vlines on the radially-averaged plot
  - This one doesn’t have a single artist we can update. Instead, we remove the existing LineCollection and replot.
Widget¶
We define our sliders and toggle buttons, as well as callback functions that update the outer collection angle minimum, the inner collection angle maximum, and the meaningful rotation offset range
# shared widget options; continuous_update=False means a slider only reports
# its value when released, not while it is being dragged
style = {'description_width': 'initial'}
layout = ipywidgets.Layout(width="320px",height="30px")
kwargs = {'style':style,'layout':layout,'continuous_update':False}
# inner edge of the annular detector; default matches the initial masks
inner_collection_angle_slider = ipywidgets.FloatSlider(
value = q_probe/2,
min = 0,
max = q_probe,
step = q_probe/20,
description = r"inner collection angle [$q_{\mathrm{probe}}$]",
**kwargs
)
# outer edge of the annular detector; default matches the initial masks
outer_collection_angle_slider = ipywidgets.FloatSlider(
value = q_probe*1.05,
min = q_probe/20,
max = q_max,
step = q_probe/20,
description = r"outer collection angle [$q_{\mathrm{probe}}$]",
**kwargs
)
# number of angular segments in each ring
number_of_segments_slider = ipywidgets.IntSlider(
value = 4,
min = 3,
max = 16,
step = 1,
description = "number of annular segments",
**kwargs
)
# rotation of the segments in degrees; the initial max corresponds to 4 segments
rotation_offset_slider = ipywidgets.IntSlider(
value = 0, min = 0, max = 180/4, step = 1,
description = "rotation offset [°]",
**kwargs
)
# number of concentric rings between the inner and outer collection angles
number_of_rings_slider = ipywidgets.IntSlider(
value = 1,
min = 1,
max = 8,
step = 1,
description = "number of radial rings",
**kwargs
)
# when on, every other ring is rotated by an extra 180/n_segments degrees
rotate_half_the_rings = ipywidgets.ToggleButton(
value = False,
description = 'offset radial rings',
disabled = False,
layout=ipywidgets.Layout(width="155px",height="30px")
)
# when on, ring boundaries are spaced linearly in radius squared instead of radius
area_toggle = ipywidgets.ToggleButton(
value = False,
description = 'distribute by area',
layout=ipywidgets.Layout(width="155px",height="30px")
)
def update_outer_collection_angle(change):
    """Keep the outer slider's minimum 5% above the new inner collection angle."""
    outer_collection_angle_slider.min = change['new'] * 1.05


def update_inner_collection_angle(change):
    """Cap the inner slider's maximum at the new outer collection angle."""
    inner_collection_angle_slider.max = change['new']


# each slider constrains the range of the other one
inner_collection_angle_slider.observe(update_outer_collection_angle, names='value')
outer_collection_angle_slider.observe(update_inner_collection_angle, names='value')
# rotation offset is modulo 180/n
def update_rotation_offset_range(change):
    """Set the rotation-offset slider maximum to 180/n for the new segment count n."""
    n_segments = change['new']
    rotation_offset_slider.max = 180/n_segments


number_of_segments_slider.observe(update_rotation_offset_range, names='value')
def disable_all(boolean):
    """Set the ``disabled`` flag of every control widget to ``boolean``."""
    controls = (
        inner_collection_angle_slider,
        outer_collection_angle_slider,
        number_of_segments_slider,
        rotation_offset_slider,
        number_of_rings_slider,
        rotate_half_the_rings,
        area_toggle,
        defocus_slider,
        simulate_button,
    )
    for control in controls:
        control.disabled = boolean
    return None
# defocus control; changing it only marks the plots as stale (see defocus_wrapper),
# the new dataset is computed when the simulate button is clicked
defocus_slider = ipywidgets.IntSlider(
value = 0,
min = -n,
max = n,
step = 1,
description = r'negative defocus, $C_{1,0}$ [Å]',
**kwargs
)
def defocus_wrapper(*args):
    """Fade the SSNR artists and highlight the simulate button after a defocus change."""
    stale_artists = (im_snr, im_snr_analytic, plot_snr, plot_snr_analytic)
    for artist in stale_artists:
        artist.set_alpha(0.25)
    simulate_button.button_style = 'warning'


defocus_slider.observe(defocus_wrapper, names='value')
# button that triggers a new simulation at the current defocus (wired up via on_click below)
simulate_button = ipywidgets.Button(
description='simulate (expensive)',
layout=ipywidgets.Layout(width="320px",height="30px")
)
def simulate_wrapper(*args):
    """Run a new simulation at the current defocus with all controls locked.

    The controls are disabled for the duration of the simulation. Once it
    succeeds, the faded artists are restored to full opacity and the button
    highlight is cleared. The controls are re-enabled even if the simulation
    raises, so a failure cannot leave the widget permanently locked; the
    exception itself still propagates.
    """
    disable_all(True)
    try:
        simulate(
            defocus_slider.value,
        )
        # results are fresh again: undo the "stale" styling from defocus_wrapper
        for artist in (im_snr, im_snr_analytic, plot_snr, plot_snr_analytic):
            artist.set_alpha(1)
        simulate_button.button_style = ''
    finally:
        disable_all(False)


simulate_button.on_click(simulate_wrapper)
def update_figure(
    *args,
):
    """Rebuild the segmented detector from the widget values and refresh the plot.

    Positional arguments (the change dict supplied by ``observe``, or a dummy
    value) are ignored; all inputs are read from the widgets and from the
    current simulated dataset.
    """
    # snapshot of the widget state
    inner_angle = inner_collection_angle_slider.value
    outer_angle = outer_collection_angle_slider.value
    n_segments = number_of_segments_slider.value
    n_rings = number_of_rings_slider.value

    # ring boundaries: linear in radius squared when distributing by area, else linear in radius
    if area_toggle.value:
        ring_collection_angles = np.linspace(
            inner_angle**2,
            outer_angle**2,
            num=n_rings + 1,
        )**(1/2)
    else:
        ring_collection_angles = np.linspace(
            inner_angle,
            outer_angle,
            num=n_rings + 1,
        )

    # optional extra rotation applied to every other ring
    ring_rotation = np.deg2rad((180/n_segments)) if rotate_half_the_rings.value else 0

    # compute new datasets: one list of segment masks per ring, stacked into one array
    ring_masks = [
        annular_segmented_detectors(
            gpts=(n, n),
            sampling=(sampling, sampling),
            n_angular_bins=n_segments,
            inner_radius=ring_collection_angles[ring],
            outer_radius=ring_collection_angles[ring + 1],
            rotation_offset=np.deg2rad(rotation_offset_slider.value) + ring_rotation*(ring % 2),
        )
        for ring in range(n_rings)
    ]
    virtual_masks_annular = np.vstack(ring_masks)

    com_x, com_y = compute_com_using_virtual_detectors(
        intensities[0],
        virtual_masks_annular,
        intensities_sum[0],
        sx, sy,
        kxa, kya,
    )
    icom_annular = integrate_com(com_x, com_y, kx_op, ky_op)
    ctf_annular = ctf.compute_ctf(icom_annular)
    _, I_bins_annular_snr = ctf.radially_average_ctf(
        ctf_annular*inverse_error,
        (sampling, sampling),
    )

    # update data
    # 2D arrays
    im_detector.set_data(ctf.combined_images_rgb(virtual_masks_annular))
    im_snr.set_data(
        ctf.histogram_scaling(np.fft.fftshift(ctf_annular*inverse_error), normalize=True)
    )

    # 1D line
    plot_snr.set_ydata(I_bins_annular_snr)

    # collections (vlines): drop the old pair and draw it at the new angles
    axs[3].collections[0].remove()
    axs[3].vlines([inner_angle, outer_angle], 0, 2, colors='k', linestyles='--', linewidth=1,)

    # re-draw figure
    fig.canvas.draw_idle()
    return None
# every detector-geometry control redraws the figure when its value changes
for _control in (
    inner_collection_angle_slider,
    outer_collection_angle_slider,
    number_of_segments_slider,
    rotation_offset_slider,
    number_of_rings_slider,
    rotate_half_the_rings,
    area_toggle,
):
    _control.observe(update_figure, names='value')
def simulate(
    defocus,
):
    """Re-simulate the dataset at ``defocus`` and refresh both SSNR displays."""
    new_intensities, new_probe = simulate_intensities(defocus=defocus)

    # replace the contents of the shared single-element lists in place
    intensities[0] = new_intensities
    probe_array_fourier[0] = new_probe

    update_analytical()
    update_figure("dummy")
    return None
def update_analytical():
    """Refresh the intensity sums and the analytic SSNR image and curve."""
    intensities_sum[0] = intensities[0].sum((-1, -2))

    # Analytical CTF (probe autocorrelation)
    probe_fft = np.fft.fft2(probe_array_fourier[0])
    ctf_analytic = np.real(np.fft.ifft2(np.abs(probe_fft)**2))
    snr_analytic = ctf_analytic*inverse_error

    # Radially-averaged CTF and SNR
    _, I_bins_analytic_snr = ctf.radially_average_ctf(snr_analytic, (sampling, sampling))

    # analytic
    im_snr_analytic.set_data(
        ctf.histogram_scaling(np.fft.fftshift(snr_analytic), normalize=True)
    )

    # 1D lines
    plot_snr_analytic.set_ydata(I_bins_analytic_snr)

    fig.canvas.draw_idle()
    return None
# Annular Segmented Detectors
# layout: control rows stacked on top of the figure canvas
display(
ipywidgets.VBox(
[
ipywidgets.VBox(
[
# simulation controls
ipywidgets.HBox([defocus_slider, simulate_button]),
# horizontal rule separating simulation from detector-geometry controls
ipywidgets.HTML("<hr>",layout=ipywidgets.Layout(width="640px")),
ipywidgets.HBox([inner_collection_angle_slider,outer_collection_angle_slider]),
ipywidgets.HBox([number_of_segments_slider,rotation_offset_slider]),
ipywidgets.HBox([number_of_rings_slider,rotate_half_the_rings,area_toggle]),
]
),
fig.canvas
]
)
)
VBox(children=(VBox(children=(HBox(children=(IntSlider(value=0, continuous_update=False, description='negative…