Evaluating the Transfer of Information in Phase Retrieval STEM Techniques

Segmented SSB

# enable interactive matplotlib
%matplotlib widget 

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.patches as mpatches
import ctf # import custom plotting / utils
import cmasher as cmr
from tqdm.notebook import tqdm

import ipywidgets
from IPython.display import display

4D STEM Simulation

# parameters
n = 96  # number of pixels along each axis of the simulation grids
q_max = 2 # inverse Angstroms
q_probe = 1 # inverse Angstroms
wavelength = 0.019687 # 300kV
sampling = 1 / q_max / 2 # Angstroms
reciprocal_sampling = 2 * q_max / n # inverse Angstroms

scan_step_size = 1 # pixels
sx = sy = n//scan_step_size  # number of scan positions along each axis
phi0 = 1.0  # amplitude passed to white_noise_object_2D below

# colormaps and line colors shared by the plots further down
cmap = cmr.eclipse
sample_cmap = 'gray'
icom_line_color = 'cornflowerblue'
iter_ptycho_line_color = 'mediumvioletred'

pixelated_ssb_line_color = 'darkgreen'
segmented_ssb_line_color = 'yellowgreen'

White Noise Potential

def white_noise_object_2D(n, phi0):
    """Return a real (n, n) array whose 2D FFT has amplitude `phi0` and random phase.

    A random phase map is drawn with `np.random.randn` and made odd under
    k -> -k, so that the Fourier array `phi0 * exp(2j*pi*phase)` is Hermitian
    and its inverse FFT is real. Frequencies without a distinct -k partner
    (DC and, for even `n`, the highest frequency) get zero phase.
    """
    is_even = n % 2 == 0
    half = n // 2

    # FFT index k pairs with n-k; `negative` is ordered to line up with `positive`
    positive = np.arange(1, half if is_even else (n + 1) // 2)
    negative = np.arange(half + 1, n)[::-1]

    phase = np.random.randn(n, n)

    rows_pos, cols_pos = positive[:, None], positive[None, :]
    rows_neg, cols_neg = negative[:, None], negative[None, :]

    # enforce phase(-k) = -phase(k), quadrant by quadrant, then along both axes
    phase[rows_pos, cols_pos] = -phase[rows_neg, cols_neg]
    phase[rows_pos, cols_neg] = -phase[rows_neg, cols_pos]
    phase[0, positive] = -phase[0, negative]
    phase[positive, 0] = -phase[negative, 0]

    # self-conjugate frequencies must carry a real value, i.e. zero phase
    if is_even:
        phase[half, :] = 0
        phase[:, half] = 0
    phase[0, 0] = 0

    spectrum = np.exp(2j * np.pi * phase) * phi0

    # the imaginary part is only floating-point residue
    return np.fft.ifft2(spectrum).real

# potential
potential = white_noise_object_2D(n,phi0)
# pure phase object: unit modulus, phase given by the potential
complex_obj = np.exp(1j*potential)

# scan dimensions follow the potential's shape (rebinds the sx, sy set with the parameters)
sx, sy = potential.shape

Import sample potentials

# Load each sample potential from disk and subtract its mean in place.
sto_potential = np.load("data/STO_projected-potential_192x192_4qprobe.npy")
sto_potential -= sto_potential.mean()
mof_potential = np.load("data/MOF_projected-potential_192x192_4qprobe.npy")
mof_potential -= mof_potential.mean()
apo_potential = np.load("data/apoF_projected-potential_192x192_4qprobe.npy")
apo_potential -= apo_potential.mean()

Probe

def soft_aperture(q,q_probe,reciprocal_sampling):
    """Aperture amplitude with a linear one-pixel-wide roll-off in intensity at `q_probe`.

    The squared return value ramps linearly from 1 to 0 over one
    `reciprocal_sampling` step centred on `q_probe`; the square root of that
    ramp is returned.
    """
    pixels_inside_edge = (q_probe - q) / reciprocal_sampling + 0.5
    coverage = np.clip(pixels_inside_edge, 0, 1)
    return np.sqrt(coverage)

def hard_aperture(q,q_probe,reciprocal_sampling):
    """Binary aperture: 1.0 where `q` is strictly below `q_probe`, else 0.0.

    `reciprocal_sampling` is unused; it is accepted so the signature matches
    `soft_aperture`.
    """
    inside = q < q_probe
    return inside.astype(np.float64)

# corner-centered (FFT-ordered) spatial frequency axes and their radial magnitude
qx = qy = np.fft.fftfreq(n,sampling)
q = np.sqrt(qx[:,None]**2 + qy[None,:]**2)

# aliases of the same axes, used in the reconstructions for the K-indexed arrays
Kx = qx
Ky = qy

K = np.sqrt(Kx[:,None]**2 + Ky[None,:]**2)

# aliases of the same axes, indexed per scan frequency (ind_x, ind_y) in the reconstructions
Qx = qx
Qy = qy

# scan positions in pixels, flattened to an (n_positions, 2) array
x = y = np.arange(0.,n,scan_step_size)
xx, yy = np.meshgrid(x,y,indexing='ij')
positions = np.stack((xx.ravel(),yy.ravel()),axis=-1)
# per-position index arrays handed to ctf.simulate_data (semantics defined in the ctf module)
row, col = ctf.return_patch_indices(positions,(n,n),(n,n))
def simulate_intensities(defocus, use_soft_aperture, batch_size=n**2, pbar=None):
    """Simulate the 4D dataset for the module-level `complex_obj` in batches.

    Parameters
    ----------
    defocus
        Defocus entering the probe phase `exp(-1j*pi*wavelength*defocus*q**2)`.
    use_soft_aperture
        If truthy use `soft_aperture`, otherwise `hard_aperture`.
    batch_size
        Number of scan positions passed to `ctf.simulate_data` per call.
        It need not divide `n**2`; the last batch is simply shorter.
    pbar
        Optional tqdm-like progress bar; reset to the number of batches and
        advanced once per batch.

    Returns
    -------
    Array of shape `(sx, sy, n, n)`.

    Raises
    ------
    ValueError
        If `batch_size` is smaller than 1.
    """
    m = n**2
    batch_size = int(batch_size)
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # contiguous batches; the final one may be shorter than batch_size
    batch_starts = range(0, m, batch_size)
    n_batch = len(batch_starts)
    intensities = np.zeros((m,n,n))

    if pbar is not None:
        pbar.reset(n_batch)
        pbar.colour = None
        pbar.refresh()

    aperture = soft_aperture if use_soft_aperture else hard_aperture
    probe_array_fourier = aperture(q,q_probe,reciprocal_sampling) * np.exp(-1j*np.pi*wavelength*defocus*q**2)
    # normalize the Fourier-space probe to unit total intensity
    probe_array_fourier /= np.sqrt(np.sum(np.abs(probe_array_fourier)**2))
    probe_array = np.fft.ifft2(probe_array_fourier) * n

    for start in batch_starts:
        batch_order = np.arange(start, min(start + batch_size, m))
        intensities[batch_order] = ctf.simulate_data(
            complex_obj,
            probe_array,
            row[batch_order],
            col[batch_order],
        )
        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.colour = 'green'

    # NOTE(review): the simulated values are squared here, so ctf.simulate_data
    # presumably returns amplitudes rather than intensities -- confirm.
    return intensities.reshape((sx,sy,n,n))**2 / n**2

# Initial in-focus, hard-aperture dataset. The arrays are kept in one-element
# lists so the widget callbacks below can replace them in place.
intensities = [
    simulate_intensities(
        defocus=0,
        use_soft_aperture=False,
        batch_size=1024,
        pbar=None,
    )
]
# FFT over the two scan axes only
intensities_FFT = [np.fft.fft2(intensities[0],axes=(0,1))]
def annular_segmented_detectors(
    gpts,
    sampling,
    n_angular_bins,
    rotation_offset = 0,
    inner_radius = 0,
    outer_radius = np.inf,
):
    """Build the masks of one annular detector ring split into angular segments.

    Parameters
    ----------
    gpts
        `(nx, ny)` number of grid points.
    sampling
        `(sx, sy)` sample spacings passed to `np.fft.fftfreq`.
    n_angular_bins
        Number of equal-angle segments.
    rotation_offset
        Angle in radians added to the polar angle before binning.
    inner_radius, outer_radius
        Radial limits in the units of the `fftfreq` axes; a pixel belongs to
        the ring when `inner_radius <= k < outer_radius`.

    Returns
    -------
    List of `n_angular_bins` integer arrays of shape `(nx, ny)` holding 0/1,
    each `fftshift`-ed so that zero frequency sits at the array centre.
    """
    nx,ny = gpts
    sx,sy = sampling

    k_x = np.fft.fftfreq(nx,sx)
    k_y = np.fft.fftfreq(ny,sy)

    k = np.sqrt(k_x[:,None]**2 + k_y[None,:]**2)
    radial_mask = ((inner_radius <= k) & (k < outer_radius))

    theta = (np.arctan2(k_y[None,:], k_x[:,None]) + rotation_offset) % (2 * np.pi)
    bin_index = np.floor(n_angular_bins * (theta / (2 * np.pi)))
    # For a tiny negative angle the floating-point modulo returns exactly 2*pi,
    # which would yield bin index n_angular_bins and drop the pixel from every
    # segment; such an angle belongs to the last segment.
    bin_index = np.minimum(bin_index, n_angular_bins - 1)

    # labels run 1..n_angular_bins inside the ring, 0 outside it
    labels = (bin_index + 1) * radial_mask.astype("int")

    return [
        np.fft.fftshift((labels == i).astype("int"))
        for i in range(1,n_angular_bins+1)
    ]

def mask_intensities_using_virtual_detectors(
    corner_centered_intensities,
    corner_centered_masked_intensities,
    center_centered_masks,
): 
    """Replace every detector region by its mean value, writing into the output array.

    Parameters
    ----------
    corner_centered_intensities
        Input array whose last two axes are the (corner-centered) detector axes.
    corner_centered_masked_intensities
        Output array of the same shape, modified in place: pixels of each
        detector receive that detector's mean, pixels covered by no detector
        are set to 0.
    center_centered_masks
        Stack of center-centered detector masks (anything `np.asarray` turns
        into a `(n_masks, nx, ny)` array); nonzero entries mark detector pixels.

    Returns
    -------
    None

    Notes
    -----
    Empty masks are skipped (they would otherwise produce a 0/0 warning), and
    a pixel covered by more than one mask keeps the value of the last such
    mask instead of being zeroed.
    """

    masks = np.fft.ifftshift(np.asarray(center_centered_masks).astype(np.bool_),axes=(-1,-2))

    for mask in masks:
        n_pixels = np.sum(mask)
        if n_pixels == 0:
            continue
        val = np.sum(corner_centered_intensities * mask,axis=(-1,-2)) / n_pixels
        corner_centered_masked_intensities[...,mask] = val[...,None]

    # pixels that belong to no detector at all
    uncovered = ~masks.any(axis=0)
    corner_centered_masked_intensities[...,uncovered] = 0.0

    return None

def mask_gamma_using_virtual_detectors(
    corner_centered_gamma,
    center_centered_masks,
): 
    """ """
    
    masks = np.fft.ifftshift(np.asarray(center_centered_masks).astype(np.bool_),axes=(-1,-2))
    inverse_mask = (1-masks.sum(0)).astype(np.bool_)

    for mask in masks:
        val = np.sum(corner_centered_gamma * mask) / np.sum(mask)
        corner_centered_gamma[mask] = val
    corner_centered_gamma[inverse_mask] = 0.0 

    return None
# Placeholder state, held in one-element lists so callbacks can replace the
# contents in place. Both entries are overwritten by
# update_virtual_and_pixelated_trotters() before they are used for reconstruction.
virtual_masks_annular = [np.zeros((n,n))]
virtual_masks_annular[0][0,0] = 1

masked_intensities_FFT = [np.zeros_like(intensities_FFT[0])]
masked_intensities_FFT[0][0,0] = 1
def ptychography_reconstruction(
    masked_intensities_FFT,
    virtual_masks_annular,
    defocus,
    use_soft_aperture,
    use_OBF_weighting,
    intensities_FFT=None,
    pbar=None
): 
    """Reconstruct the object from detector-averaged data, one scan frequency at a time.

    For every scan frequency (ind_x, ind_y) the function builds
    `gamma = A*(K) A(K-Q) - A(K) A*(K+Q)` from the aperture function,
    averages it over each virtual detector (in place), and sums
    `G * conj(gamma) / normalization` over the pixels where `|gamma|`
    exceeds the threshold. The (0, 0) component is set to `sum(|G|)`.

    Relies on the module-level `sx`, `sy`, `K`, `Kx`, `Ky`, `Qx`, `Qy`,
    `q_probe`, `reciprocal_sampling`, `wavelength` and `complex_obj`.

    Parameters
    ----------
    masked_intensities_FFT
        Array indexed as `[ind_x, ind_y]` yielding the detector-averaged
        data for that scan frequency.
    virtual_masks_annular
        Center-centered detector masks, passed to
        `mask_gamma_using_virtual_detectors`.
    defocus
        Defocus entering the aperture phase `exp(-1j*pi*wavelength*defocus*K**2)`.
    use_soft_aperture
        Selects `soft_aperture` (threshold 1e-3) or `hard_aperture` (threshold 0.0).
    use_OBF_weighting
        If truthy, normalize with the detector-averaged probe intensity
        instead of `|gamma|` alone.
    intensities_FFT
        Optional un-averaged data; when given, a second reconstruction is
        computed from it with the un-averaged `gamma`.
    pbar
        Optional tqdm-like progress bar, advanced once per scan frequency.

    Returns
    -------
    `(recon, recon_0)`, the inverse FFTs of the two accumulated spectra;
    `recon_0` is None when `intensities_FFT` is None.
    """
    aperture = soft_aperture if use_soft_aperture else hard_aperture
    threshold = 1e-3 if use_soft_aperture else 0.0
    psi = np.empty_like(complex_obj)
    
    if intensities_FFT is not None:
        psi_0 = np.empty_like(complex_obj)
    
    A_q = aperture(K,q_probe,reciprocal_sampling) * np.exp(-1j*np.pi*wavelength*defocus*K**2)
    A_q_conj = A_q.conj()

    if use_OBF_weighting:
        # normalized probe intensity, averaged over each detector (in place)
        probe_normalization = np.abs(A_q)**2
        probe_normalization /= probe_normalization.sum()

        mask_gamma_using_virtual_detectors(
            probe_normalization,
            virtual_masks_annular,
        )

    if pbar is not None:
        pbar.reset(sx*sy)
        pbar.colour = None
        pbar.refresh()
    
    for ind_x in range(sx):
        for ind_y in range(sy):
            G = masked_intensities_FFT[ind_x,ind_y]
            if intensities_FFT is not None:
                G_0 = intensities_FFT[ind_x,ind_y]
            if ind_x == 0 and ind_y == 0 :
                # zero scan frequency: no gamma is formed, use the summed magnitude
                psi[ind_x,ind_y] = np.abs(G).sum()
                if intensities_FFT is not None:
                    psi_0[ind_x,ind_y] = np.abs(G_0).sum()          
            else:
                # aperture function shifted by +Q and -Q
                q_plus_Q = np.sqrt((Kx[:,None]+Qx[ind_x])**2 + (Ky[None,:]+Qy[ind_y])**2)
                A_q_plus_Q = aperture(q_plus_Q,q_probe,reciprocal_sampling) * np.exp(-1j*np.pi*wavelength*defocus*q_plus_Q**2)
        
                q_minus_Q = np.sqrt((Kx[:,None]-Qx[ind_x])**2 + (Ky[None,:]-Qy[ind_y])**2)
                A_q_minus_Q = aperture(q_minus_Q,q_probe,reciprocal_sampling) * np.exp(-1j*np.pi*wavelength*defocus*q_minus_Q**2)
    
                gamma = A_q_conj * A_q_minus_Q - A_q * A_q_plus_Q.conj()
                if intensities_FFT is not None:
                    # must run before gamma is averaged in place just below
                    gamma_abs = np.abs(gamma)
                    gamma_ind = gamma_abs > threshold
                    psi_0[ind_x,ind_y] = (G_0[gamma_ind] *  np.conj(gamma[gamma_ind])/gamma_abs[gamma_ind]).sum()
                    
                # average gamma over each detector; modifies gamma in place
                mask_gamma_using_virtual_detectors(
                    gamma,
                    virtual_masks_annular,
                )
                gamma_abs = np.abs(gamma)
                gamma_ind = gamma_abs > threshold
                
                normalization = gamma_abs[gamma_ind]
                
                if use_OBF_weighting:
                    d = probe_normalization[gamma_ind]
                    normalization = d * np.sqrt(np.sum(normalization**2 / d))
                    
                psi[ind_x,ind_y] = (G[gamma_ind] *  np.conj(gamma[gamma_ind])/normalization).sum()

            if pbar is not None:
                pbar.update(1)

    if pbar is not None:
        pbar.colour = 'green'

    return_val = (np.fft.ifft2(psi),None) if intensities_FFT is None else (np.fft.ifft2(psi),np.fft.ifft2(psi_0))
    return return_val

def ptychography_reconstruction_pixelated(
    intensities_FFT,
    defocus,
    use_soft_aperture,
    use_OBF_weighting,
    pbar=None,
): 
    """Reconstruct the object from the full pixelated data, one scan frequency at a time.

    For every scan frequency (ind_x, ind_y) the function builds
    `gamma = A*(K) A(K-Q) - A(K) A*(K+Q)` from the aperture function and sums
    `G * conj(gamma) / normalization` over the pixels where `|gamma|` exceeds
    the threshold (1e-3 for the soft aperture, 0.0 for the hard one). The
    (0, 0) component is set to `sum(|G|)`. With `use_OBF_weighting` the
    normalization uses the normalized probe intensity instead of `|gamma|`.

    Relies on the module-level `sx`, `sy`, `K`, `Kx`, `Ky`, `Qx`, `Qy`,
    `q_probe`, `reciprocal_sampling`, `wavelength` and `complex_obj`.
    `pbar`, if given, is a tqdm-like bar advanced once per scan frequency.

    Returns the inverse 2D FFT of the accumulated spectrum.
    """
    if use_soft_aperture:
        aperture, threshold = soft_aperture, 1e-3
    else:
        aperture, threshold = hard_aperture, 0.0

    def aberrated_aperture(q_mag):
        # aperture amplitude times the defocus phase at radial frequency q_mag
        return aperture(q_mag,q_probe,reciprocal_sampling) * np.exp(-1j*np.pi*wavelength*defocus*q_mag**2)

    psi = np.empty_like(complex_obj)

    A_q = aberrated_aperture(K)
    A_q_conj = A_q.conj()

    if use_OBF_weighting:
        probe_normalization = np.abs(A_q)**2
        probe_normalization /= probe_normalization.sum()

    if pbar is not None:
        pbar.reset(sx*sy)
        pbar.colour = None
        pbar.refresh()

    # np.ndindex walks (ind_x, ind_y) in the same row-major order as nested loops
    for ind_x, ind_y in np.ndindex(sx, sy):
        G = intensities_FFT[ind_x,ind_y]
        if ind_x == 0 and ind_y == 0:
            psi[ind_x,ind_y] = np.abs(G).sum()
        else:
            shifted_x_plus = Kx[:,None]+Qx[ind_x]
            shifted_y_plus = Ky[None,:]+Qy[ind_y]
            A_q_plus_Q = aberrated_aperture(np.sqrt(shifted_x_plus**2 + shifted_y_plus**2))

            shifted_x_minus = Kx[:,None]-Qx[ind_x]
            shifted_y_minus = Ky[None,:]-Qy[ind_y]
            A_q_minus_Q = aberrated_aperture(np.sqrt(shifted_x_minus**2 + shifted_y_minus**2))

            gamma = A_q_conj * A_q_minus_Q - A_q * A_q_plus_Q.conj()
            gamma_abs = np.abs(gamma)
            selected = gamma_abs > threshold
            normalization = gamma_abs[selected]

            if use_OBF_weighting:
                d = probe_normalization[selected]
                normalization = d * np.sqrt(np.sum(normalization**2 /d))

            psi[ind_x,ind_y] = (G[selected] *  np.conj(gamma[selected])/normalization).sum()

        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.colour = 'green'

    return np.fft.ifft2(psi)
# Initial in-focus pixelated reconstruction, used for the reference curve in
# the radially-averaged CTF plot.
recon_0 = ptychography_reconstruction_pixelated(
    intensities_FFT=intensities_FFT[0],
    defocus=0,
    use_soft_aperture=False,
    use_OBF_weighting=False,
)

# numeric CTF: half the FFT magnitude of the reconstructed phase, with the
# zero-frequency term zeroed
numeric_ctf_0 = np.abs(np.fft.fft2(np.angle(recon_0))) / 2
numeric_ctf_0[0,0] = 0.0

q_bins_pixelated, I_bins_pixelated = ctf.radially_average_ctf(numeric_ctf_0,(sampling,sampling))
def mask_opacities(virtual_masks):
    """Return an opacity map that shades each virtual detector differently.

    Masks are assigned the alternating shading values 0.25 and 0.375. With an
    odd number of masks the last one gets 0.125 instead, so that it differs
    from the first mask. The result is `1 - sum_i(mask_i * value_i)`;
    pixels covered by no mask stay at 1.

    Parameters
    ----------
    virtual_masks
        Sequence (or array) of equally shaped 0/1 masks.

    Returns
    -------
    Array with the shape of a single mask.
    """
    n_masks = len(virtual_masks)
    vals = np.tile([0.25,0.375],n_masks)[:n_masks]
    if n_masks % 2 == 1:
        # `ndarray + [0.125]` would broadcast-add 0.125 to every entry rather
        # than append a third value, so set the last entry explicitly.
        vals[-1] = 0.125

    opacities = 1-np.tensordot(
        np.array(virtual_masks),
        vals,
        axes=(0,0)
    )
    return opacities
# Build the 2x4 figure with interactive mode off so it is only shown when
# fig.canvas is displayed with the widgets.
with plt.ioff():
    dpi=72
    fig, axs = plt.subplots(2,4,figsize=(640/dpi,400/dpi),dpi=dpi)

# placeholder image shown until the first reconstruction
empty = np.zeros((n,n))
empty[0,0] = 1

# top row: the artists created here are updated in place by the callbacks
ax_trotter_pixelated = axs[0,0]
im_trotter_pixelated = ax_trotter_pixelated.imshow(virtual_masks_annular[0])

ax_trotter_annular = axs[0,1]
im_trotter_annular = ax_trotter_annular.imshow(virtual_masks_annular[0])

ax_ctf = axs[0,2]
im_ctf = ax_ctf.imshow(empty,cmap=cmap)

ax_ctf_rad = axs[0,3]
plot_ctf_pixelated = ax_ctf_rad.plot(q_bins_pixelated,I_bins_pixelated,color=pixelated_ssb_line_color,label='pixelated SSB')[0]
plot_ctf = ax_ctf_rad.plot(np.linspace(0,q_max,n//2 + 1),np.zeros(n//2 + 1),color=segmented_ssb_line_color,label='segmented SSB')[0]

# titles are listed in axs.flatten() order (row-major)
for ax, title in zip(
    axs.flatten(),
    [
        "pixelated trotter",
        "segmented trotter",
        "segmented CTF",
        "radially-averaged CTF",
        "white noise object",
        "strontium titanate",
        "metal-organic framework",
        "apoferritin protein",
    ]
):
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    
for ax in axs[0,:3]:
    ctf.add_scalebar(ax,length=n//4,sampling=reciprocal_sampling,units=r'$q_{\mathrm{probe}}$')

# remove y-ticks, add x-label, add vlines to radial avg. plot
ax_ctf_rad.set_ylim([0,1])
ax_ctf_rad.set_xlim([0,q_max])
# these dashed lines are the axes' first collection; update_ctfs removes and redraws them
ax_ctf_rad.vlines([q_probe/2,q_probe],0,2,colors='k',linestyles='--',linewidth=1,)
ax_ctf_rad.set_xticks([0,q_probe,q_max])
ax_ctf_rad.set_xticklabels([0,1,2])
ax_ctf_rad.set_yticks([])
ax_ctf_rad.set_aspect(2)
ax_ctf_rad.set_xlabel(r"spatial frequency, $q/q_{\mathrm{probe}}$")
ax_ctf_rad.legend()


# bottom row: one image per sample, blank until the first reconstruction
ax_white_noise_obj = axs[1,0]
im_white_noise_obj = ax_white_noise_obj.imshow(
    np.zeros((n,n)),vmin=0,vmax=1,
    cmap=sample_cmap
)
ctf.add_scalebar(ax_white_noise_obj,length=n//5,sampling=sampling,units=r'Å')

ax_sto_obj = axs[1,1]
im_sto_obj = ax_sto_obj.imshow(
    np.zeros((n,n)),vmin=0,vmax=1,
    cmap=sample_cmap
)
sto_sampling = 23.67 / n  # Å
ctf.add_scalebar(ax_sto_obj,length=n//5,sampling=sto_sampling,units=r'Å')

ax_mof_obj = axs[1,2]
im_mof_obj = ax_mof_obj.imshow(
    np.zeros((n,n)),vmin=0,vmax=1,
    cmap=sample_cmap
)
mof_sampling = 4.48 / n  # nm
ctf.add_scalebar(ax_mof_obj,length=n//5,sampling=mof_sampling,units=r'nm')

ax_apo_obj = axs[1,3]
im_apo_obj = ax_apo_obj.imshow(
    np.zeros((n,n)),vmin=0,vmax=1,
    cmap=sample_cmap
)
apo_sampling = 19.2 / n  # nm
ctf.add_scalebar(ax_apo_obj,length=n//5,sampling=apo_sampling,units=r'nm')


# strip the interactive canvas chrome and fix the widget size
fig.tight_layout()
fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = False
# fig.canvas.toolbar_visible = True
# fig.canvas.toolbar_position = 'bottom'
fig.canvas.layout.width = '640px'
fig.canvas.layout.height = '420px'
fig.tight_layout()
# bare None as the last expression keeps the notebook cell from echoing a value
None
# shared widget styling; continuous_update=False fires observers only when a
# slider is released
style = {'description_width': 'initial'}
layout = ipywidgets.Layout(width="320px",height="30px")
layout_half = ipywidgets.Layout(width="160px",height="30px")

kwargs = {'style':style,'layout':layout,'continuous_update':False}
kwargs_half = {'style':style,'layout':layout_half,'continuous_update':False}

# detector geometry controls
inner_collection_angle_slider = ipywidgets.FloatSlider(
    value = q_probe/2,
    min = 0,
    max = q_probe, 
    step = q_probe/20,
    description = r"inner collection angle [$q_{\mathrm{probe}}$]",
    **kwargs
)

outer_collection_angle_slider = ipywidgets.FloatSlider(
    value = q_probe, 
    min = q_probe/20, 
    max = q_max, 
    step = q_probe/20,
    description = r"outer collection angle [$q_{\mathrm{probe}}$]",
    **kwargs
)

number_of_segments_slider = ipywidgets.IntSlider(
    value = 4, 
    min = 3, 
    max = 16, 
    step = 1,
    description = "number of annular segments",
    **kwargs
)

# the initial max of 180/4 matches the default of 4 segments
rotation_offset_slider = ipywidgets.IntSlider(
    value = 0, min = 0, max = 180/4, step = 1,
    description = "rotation offset [°]",
    **kwargs
)

number_of_rings_slider = ipywidgets.IntSlider(
    value = 1, 
    min = 1, 
    max = 8, 
    step = 1,
    description = "number of radial rings",
    **kwargs
)

# when on, every second ring is rotated by half a segment
rotate_half_the_rings = ipywidgets.ToggleButton(
    value = False,
    description = 'offset radial rings',
    disabled = False,
    layout=ipywidgets.Layout(width="155px",height="30px")
)

# when on, ring edges are spaced linearly in radius squared instead of radius
area_toggle = ipywidgets.ToggleButton(
    value = False,
    description = 'distribute by area',
    layout=ipywidgets.Layout(width="155px",height="30px")
)

def update_outer_collection_angle(change):
    """Keep the outer slider's minimum 5% above the new inner collection angle."""
    outer_collection_angle_slider.min = change['new']*1.05

def update_inner_collection_angle(change):
    """Cap the inner slider's maximum at the new outer collection angle."""
    inner_collection_angle_slider.max = change['new']

# couple the two sliders: each one's value limits the other's range
inner_collection_angle_slider.observe(update_outer_collection_angle, names='value')
outer_collection_angle_slider.observe(update_inner_collection_angle, names='value')

# rotation offset is modulo 180/n
def update_rotation_offset_range(change):
    """Limit the rotation-offset slider to 180 degrees divided by the new segment count."""
    rotation_offset_slider.max = 180/change['new']

number_of_segments_slider.observe(update_rotation_offset_range, names='value') 

# probe / reconstruction controls
defocus_slider = ipywidgets.IntSlider(
    value = 0, min = -n, max = n, step = 1,
    description = "negative defocus, $C_{1,0}$ [Å]",
    **kwargs
)

# selects which scan frequency the trotter panels show (n//12 when on, n//8 when off)
frequencies_toggle = ipywidgets.ToggleButton(
    description='show low freq trotter',
    value=True,
    **kwargs_half
)

OBF_toggle = ipywidgets.ToggleButton(
    value=False,
    description='use OBF weights',
    **kwargs_half
)

simulate_button = ipywidgets.Button(
    description='simulate (expensive)',
    layout=ipywidgets.Layout(width="160px",height="30px")
)

# progress bars are created hidden; only the first two children of each tqdm
# container are embedded in the layout
simulation_pbar = tqdm(total=9,display=False)
simulation_pbar_wrapper = ipywidgets.HBox(simulation_pbar.container.children[:2],layout=ipywidgets.Layout(width="160px"))

reconstruct_button = ipywidgets.Button(
    description='reconstruct (expensive)',
    layout=ipywidgets.Layout(width="160px",height="30px")
)

reconstruction_pbar = tqdm(total=9,display=False)
reconstruction_pbar_wrapper = ipywidgets.HBox(reconstruction_pbar.container.children[:2],layout=ipywidgets.Layout(width="160px"))

def disable_all(boolean):
    """Set the `disabled` flag of every interactive control to `boolean`."""
    controls = (
        inner_collection_angle_slider,
        outer_collection_angle_slider,
        number_of_segments_slider,
        rotation_offset_slider,
        number_of_rings_slider,
        rotate_half_the_rings,
        area_toggle,
        defocus_slider,
        frequencies_toggle,
        OBF_toggle,
        simulate_button,
        reconstruct_button,
    )
    for control in controls:
        control.disabled = boolean
    return None

def defocus_aperture_wrapper(*args):
    """Fade every panel and flag the simulate button after a defocus change."""
    stale_artists = (
        im_trotter_annular,
        im_trotter_pixelated,
        im_ctf,
        im_white_noise_obj,
        im_sto_obj,
        im_mof_obj,
        im_apo_obj,
        plot_ctf,
        plot_ctf_pixelated,
    )
    for artist in stale_artists:
        artist.set_alpha(0.25)
    simulate_button.button_style = 'warning'
    simulation_pbar.reset()
# a new defocus value marks the displayed results as out of date
defocus_slider.observe(defocus_aperture_wrapper,names='value')

def simulate_wrapper(*args):
    """Button callback: re-simulate the dataset at the current defocus.

    All controls are disabled while the simulation runs and are re-enabled
    afterwards, even if the simulation raises. On success the reconstruct
    button is flagged as pending and the simulate button's flag is cleared.
    """
    disable_all(True)
    try:
        simulate_and_update_trotters(
            defocus_slider.value,
            False,
            OBF_toggle.value,
            pbar=simulation_pbar,
        )
        reconstruct_button.button_style = 'warning'
        reconstruction_pbar.reset()
        simulate_button.button_style = ''
    finally:
        # never leave the UI locked, whatever happened above
        disable_all(False)
# run the simulation when the button is clicked
simulate_button.on_click(simulate_wrapper)

def reconstruct_wrapper(*args):
    """Button callback: run the segmented reconstruction and refresh the CTF panels.

    All controls are disabled while the reconstruction runs and are
    re-enabled afterwards, even if the reconstruction raises.
    """
    disable_all(True)
    try:
        update_ctfs(
            defocus_slider.value,
            False,
            OBF_toggle.value,
            pbar=reconstruction_pbar,
        )
    finally:
        # never leave the UI locked, whatever happened above
        disable_all(False)
# run the reconstruction when the button is clicked
reconstruct_button.on_click(reconstruct_wrapper)
def simulate_and_update_trotters(
    defocus,
    use_soft_aperture,
    use_OBF_weighting,
    pbar=None,
):
    """Re-simulate the dataset, then refresh the trotter panels and the pixelated CTF curve.

    The new intensities and their scan-axis FFT replace the contents of the
    module-level `intensities` and `intensities_FFT` lists. `pbar`, if given,
    is forwarded to `simulate_intensities`.
    """
    intensities[0] = simulate_intensities(
        defocus=defocus,
        use_soft_aperture=use_soft_aperture,
        pbar=pbar,
        batch_size=1024,
    )
    intensities_FFT[0] = np.fft.fft2(intensities[0],axes=(0,1))

    update_virtual_and_pixelated_trotters()

    update_pixelated_ctf(
        defocus=defocus,
        use_soft_aperture=use_soft_aperture,
        use_OBF_weighting=use_OBF_weighting,
        intensities_FFT = intensities_FFT[0],
    )

    return None

def update_pixelated_ctf(
    defocus,
    use_soft_aperture,
    use_OBF_weighting,
    intensities_FFT,
):
    """Recompute the pixelated reconstruction and redraw its radially-averaged CTF curve."""
    pixelated_recon = ptychography_reconstruction_pixelated(
        intensities_FFT=intensities_FFT,
        defocus=defocus,
        use_soft_aperture=use_soft_aperture,
        use_OBF_weighting=use_OBF_weighting,
    )

    pixelated_ctf = ctf.compute_ctf(np.angle(pixelated_recon)) / 2
    pixelated_ctf[0,0] = 0.0

    # only the averaged values are plotted; the bin positions do not change
    _, radial_profile = ctf.radially_average_ctf(pixelated_ctf,(sampling,sampling))
    plot_ctf_pixelated.set_ydata(radial_profile)
    plot_ctf_pixelated.set_alpha(1)
    fig.canvas.draw()
    return None
def update_ctfs(
    defocus,
    use_soft_aperture,
    use_OBF_weighting,
    intensities_FFT=None,
    pbar=None,
):
    """Run the segmented reconstruction and refresh every CTF-dependent panel.

    Updates the 2D CTF image, the four CTF-filtered sample images, the
    radially-averaged curve(s) and the dashed collection-angle markers, then
    restores full opacity and clears the reconstruct button's pending flag.

    Parameters
    ----------
    defocus, use_soft_aperture, use_OBF_weighting
        Forwarded to `ptychography_reconstruction`.
    intensities_FFT
        Optional un-averaged data; when given, the pixelated curve is
        refreshed from the second reconstruction as well.
    pbar
        Optional tqdm-like progress bar forwarded to the reconstruction.
    """

    recon,recon_0 = ptychography_reconstruction(
        masked_intensities_FFT[0],
        virtual_masks_annular[0],
        defocus=defocus,
        use_soft_aperture=use_soft_aperture,
        use_OBF_weighting=use_OBF_weighting,
        intensities_FFT=intensities_FFT,
        pbar=pbar,
    )

    numeric_ctf = ctf.compute_ctf(np.angle(recon)) / 2
    numeric_ctf[0,0] = 0.0

    im_ctf.set_data(
        ctf.histogram_scaling(
            np.fft.fftshift(numeric_ctf),
            normalize=True
        )
    )

    def ctf_filtered(sample_potential, transfer_function):
        # multiply the potential's spectrum by the CTF and rescale for display
        return ctf.histogram_scaling(
            np.fft.ifft2(
                np.fft.fft2(sample_potential) * transfer_function
            ).real,
            normalize=True
        )

    # real space samples
    im_white_noise_obj.set_data(ctf_filtered(potential, numeric_ctf))

    # Zero-pad the centered CTF by n//2 on every side (n -> 2n pixels) before
    # applying it to the imported potentials.
    # NOTE(review): assumes those potentials are 2n x 2n, as their file names
    # suggest -- confirm.
    zero_pad_ctf_to_4qprobe = np.fft.ifftshift(np.pad(np.fft.fftshift(numeric_ctf),n//2))

    for image, sample_potential in (
        (im_sto_obj, sto_potential),
        (im_mof_obj, mof_potential),
        (im_apo_obj, apo_potential),
    ):
        image.set_data(ctf_filtered(sample_potential, zero_pad_ctf_to_4qprobe))

    _, I_bins = ctf.radially_average_ctf(numeric_ctf,(sampling,sampling))
    plot_ctf.set_ydata(I_bins)

    if recon_0 is not None:
        numeric_ctf_0 = ctf.compute_ctf(np.angle(recon_0)) / 2
        numeric_ctf_0[0,0] = 0.0

        _, I_bins_pixelated = ctf.radially_average_ctf(numeric_ctf_0,(sampling,sampling))
        plot_ctf_pixelated.set_ydata(I_bins_pixelated)

    # the dashed markers are the only collection on this axes: replace them
    ax_ctf_rad.collections[0].remove()
    ax_ctf_rad.vlines(
        [
            inner_collection_angle_slider.value,
            outer_collection_angle_slider.value
        ],0,2,
        colors='k',linestyles='--',linewidth=1,
    )

    for artist in (
        im_ctf,
        im_white_noise_obj,
        im_sto_obj,
        im_mof_obj,
        im_apo_obj,
        plot_ctf,
        plot_ctf_pixelated,
    ):
        artist.set_alpha(1)
    reconstruct_button.button_style = ''

    fig.canvas.draw()
    return None
def update_virtual_and_pixelated_trotters(
    *args,
):
    """Rebuild the virtual detectors from the widget state and refresh both trotter panels.

    The new masks and the detector-averaged data replace the contents of the
    module-level `virtual_masks_annular` and `masked_intensities_FFT` lists.
    The CTF-dependent panels are faded and the reconstruct button is flagged
    as pending. Controls are disabled while this runs and re-enabled
    afterwards, even if an error occurs.
    """
    disable_all(True)
    try:
        n_rings = number_of_rings_slider.value
        n_segments = number_of_segments_slider.value
        inner = inner_collection_angle_slider.value
        outer = outer_collection_angle_slider.value

        # ring edges: linear in radius squared when distributing by area,
        # otherwise linear in radius
        if area_toggle.value:
            ring_collection_angles = np.linspace(
                inner**2,
                outer**2,
                num=n_rings + 1
            )**(1/2)
        else:
            ring_collection_angles = np.linspace(
                inner,
                outer,
                num=n_rings + 1
            )

        # optionally rotate every second ring by half a segment
        if rotate_half_the_rings.value:
            ring_rotation = np.deg2rad((180/n_segments))
        else:
            ring_rotation = 0
        base_rotation = np.deg2rad(rotation_offset_slider.value)

        _virtual_masks_annular = [
            annular_segmented_detectors(
                gpts=(n,n),
                sampling=(sampling,sampling),
                n_angular_bins=n_segments,
                inner_radius=ring_collection_angles[j],
                outer_radius=ring_collection_angles[j+1],
                rotation_offset=base_rotation + ring_rotation*(j%2),
            )
            for j in range(n_rings)
        ]

        virtual_masks_annular[0] = np.vstack(_virtual_masks_annular)

        # Detector averaging is linear, so it is applied to the scan-axis FFT
        # of the intensities directly instead of masking first and
        # transforming afterwards.
        mask_intensities_using_virtual_detectors(
            intensities_FFT[0],
            masked_intensities_FFT[0],
            virtual_masks_annular[0]
        )

        ind = n//12 if frequencies_toggle.value else n//8

        im_trotter_pixelated.set_data(
            np.dstack(
                (
                    ctf.complex_to_rgb(
                        np.fft.fftshift(intensities_FFT[0][ind,2*ind])
                    ),
                    mask_opacities(virtual_masks_annular[0])
                )
            )
        )

        im_trotter_annular.set_data(
            ctf.complex_to_rgb(
                np.fft.fftshift(masked_intensities_FFT[0][ind,2*ind])
            )
        )

        im_trotter_annular.set_alpha(1)
        im_trotter_pixelated.set_alpha(1)
        # everything derived from the previous detectors is now out of date
        for stale_artist in (
            im_ctf,
            im_white_noise_obj,
            im_sto_obj,
            im_mof_obj,
            im_apo_obj,
            plot_ctf,
        ):
            stale_artist.set_alpha(0.25)
        reconstruct_button.button_style = 'warning'
        reconstruction_pbar.reset()

        fig.canvas.draw()
    finally:
        # never leave the UI locked, whatever happened above
        disable_all(False)
    return None

# any change to the detector geometry rebuilds the masks and the trotter panels
inner_collection_angle_slider.observe(update_virtual_and_pixelated_trotters,names='value')
outer_collection_angle_slider.observe(update_virtual_and_pixelated_trotters,names='value')
number_of_segments_slider.observe(update_virtual_and_pixelated_trotters,names='value')
rotation_offset_slider.observe(update_virtual_and_pixelated_trotters,names='value')
number_of_rings_slider.observe(update_virtual_and_pixelated_trotters,names='value')
rotate_half_the_rings.observe(update_virtual_and_pixelated_trotters,names='value')
area_toggle.observe(update_virtual_and_pixelated_trotters,names='value')
def update_trotters_frequency(*args):
    """Redraw both trotter panels at the scan frequency chosen by the toggle."""
    freq_index = n//12 if frequencies_toggle.value else n//8
    selection = (freq_index, 2*freq_index)

    # pixelated trotter as RGB, with the detector layout as the alpha channel
    pixelated_rgb = ctf.complex_to_rgb(
        np.fft.fftshift(intensities_FFT[0][selection])
    )
    detector_alpha = mask_opacities(virtual_masks_annular[0])
    im_trotter_pixelated.set_data(np.dstack((pixelated_rgb, detector_alpha)))

    segmented_rgb = ctf.complex_to_rgb(
        np.fft.fftshift(masked_intensities_FFT[0][selection])
    )
    im_trotter_annular.set_data(segmented_rgb)

    fig.canvas.draw()
    return None

# switching the displayed frequency only needs a redraw, not a new simulation
frequencies_toggle.observe(update_trotters_frequency,'value')
def update_weights(change):
    """Toggle callback: recompute the pixelated CTF curve with the new weighting.

    `change.new` is the new state of the OBF toggle. The segmented results
    are faded and the reconstruct button is flagged as pending, because they
    were computed with the previous weighting. Controls are disabled while
    this runs and re-enabled afterwards, even if an error occurs.
    """
    disable_all(True)
    try:
        recon_0 = ptychography_reconstruction_pixelated(
            intensities_FFT=intensities_FFT[0],
            defocus=defocus_slider.value,
            use_soft_aperture=False,
            use_OBF_weighting= change.new,
        )

        numeric_ctf_0 = ctf.compute_ctf(np.angle(recon_0)) / 2
        numeric_ctf_0[0,0] = 0.0

        _, I_bins_pixelated = ctf.radially_average_ctf(numeric_ctf_0,(sampling,sampling))
        plot_ctf_pixelated.set_ydata(I_bins_pixelated)
        plot_ctf_pixelated.set_alpha(1)
        for stale_artist in (
            im_ctf,
            im_white_noise_obj,
            im_sto_obj,
            im_mof_obj,
            im_apo_obj,
            plot_ctf,
        ):
            stale_artist.set_alpha(0.25)
        fig.canvas.draw()
        reconstruct_button.button_style = 'warning'
        reconstruction_pbar.reset()
    finally:
        # never leave the UI locked, whatever happened above
        disable_all(False)
    return None

OBF_toggle.observe(update_weights,'value')
# build the initial detectors and trotter panels from the default widget values
update_virtual_and_pixelated_trotters()
# lay out the controls above the figure canvas
display(
    ipywidgets.VBox(
        [
            ipywidgets.VBox(
                [
                    ipywidgets.HBox([defocus_slider,simulate_button,simulation_pbar_wrapper]),
                    ipywidgets.HBox([frequencies_toggle,OBF_toggle,reconstruct_button,reconstruction_pbar_wrapper]),
                    ipywidgets.HTML("<hr>",layout=ipywidgets.Layout(width="640px")),
                    ipywidgets.HBox([inner_collection_angle_slider,outer_collection_angle_slider]),
                    ipywidgets.HBox([number_of_segments_slider,rotation_offset_slider]),
                    ipywidgets.HBox([number_of_rings_slider,rotate_half_the_rings,area_toggle]),
                ]
            ),
            fig.canvas
        ]
    )
)