Definitions of the classes and modules used to model point-spread functions (PSFs)
# Notebook setup: automatically reload edited Python modules before executing code
# (IPython autoreload extension; mode 2 reloads all modules except those excluded
# via %aimport), so changes to the PSF library are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Load a previously saved PSF initialisation.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load files from trusted locations.
# map_location='cpu' makes loading independent of the device the state was saved
# on; load_state_dict below copies the values into the module's own tensors anyway.
psf_state = torch.load(
    '/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/msp300_smFISH/nb_run/msp300_smFISH_3_3/psf_init.pkl',
    map_location='cpu',
)

# Spatial extent of the stored volume (its last three dimensions, named z, y, x here).
shape_zyx = tuple(psf_state['psf_volume'].shape[-3:])

# Build the PSF from the stored shape rather than a hard-coded (21, 21, 21), so the
# module always matches the saved volume it is about to receive.
psf = LinearInterpolatedPSF(shape_zyx)
# strict=False: only 'psf_volume' is supplied; other state entries are not required.
# (Assumes LinearInterpolatedPSF follows torch.nn.Module.load_state_dict semantics.)
psf.load_state_dict({'psf_volume': psf_state['psf_volume']}, strict=False)
psf.get_com()

# Evaluate the PSF at three sub-volume offsets on the GPU.
# NOTE(review): psf itself is never moved to `device`; confirm that
# LinearInterpolatedPSF handles mixed-device inputs (or call psf.to(device)).
device = torch.device('cuda')
shifts = torch.tensor([-2., 0., 2.], device=device)
# Presumably the offsets along the remaining two axes — TODO confirm argument order.
zeros_a = torch.zeros(3, device=device)
zeros_b = torch.zeros(3, device=device)
shift = psf(shifts, zeros_a, zeros_b)
# First output is the volume shifted 2 pixels to the left.
plot_3d_projections(cpu(shift)[0, 0])
shift.shape

# NOTE(review): [11, 21, 21] looks like a target (z, y, x) size matching a 21^3
# volume; if shape_zyx differs this may need updating — confirm crop_psf's contract.
cropped_psf = crop_psf(psf, [11, 21, 21])
plot_3d_projections(cropped_psf.psf_volume)
# Shell out to the nbdev CLI to regenerate the library's Python modules from the
# exported notebook cells.
!nbdev_build_lib