%load_ext autoreload
%autoreload 2
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# Load the simulated cell volume from the shared FishSIM data set.
path_tif = Path('/groups/turaga/home/speisera/share_TUM/FishSIM/sim_1/mRNAlevel_200/cell3D/strong/w1_HelaKyoto_Gapdh_2597_p01_cy3__Cell_CP_14__cell3D__1.tif')
bead_vol = load_tiff_image(path_tif)

# Estimate a background volume with EstimateBackground(20, 1) and remove it.
estimate_backg = EstimateBackground(20, 1)
background = estimate_backg(bead_vol)
bead_vol = bead_vol - background
# Analysis window: index 0 of the leading axis, every plane of axis 1, rows 50:95, cols 75:160.
# NOTE(review): axis order assumed (channel, z, y, x) -- consistent with how peak offsets are mapped later.
win_sl = np.s_[0,:,50:95,75:160]
crop = bead_vol[win_sl]

fig, axes = plt.subplots(1, 3, figsize=(20,6))
panels = (
    bead_vol[0].mean(0),   # whole field, averaged along axis 0 of the first volume
    crop.max(0).values,    # window, maximum along axis 0
    crop.mean(1),          # window, averaged along axis 1 (side view)
)
for ax, img in zip(axes, panels):
    ax.imshow(img)
# Detect intensity peaks inside the analysis window (threshold and spacing are hand-tuned).
crop = bead_vol[win_sl]
coords_xyz = get_peaks_3d(crop, threshold=2000, min_distance=5)
plot_detection(crop, coords_xyz)

# Convert window-local peak coordinates to full-volume coordinates.
# Peaks are ordered (x, y, z), so the offsets come from win_sl[3], win_sl[2], win_sl[1];
# an open slice start (None) means offset 0. Computed once, outside the per-peak loop.
offset_xyz = np.array([int(win_sl[axis].start or 0) for axis in (3, 2, 1)])
abs_coords_xyz = np.array([c + offset_xyz for c in coords_xyz])

# Cut a region of interest around every detected peak and show its max projections.
rois = extract_roi(bead_vol, abs_coords_xyz, size_xy=10, size_z=10)
for roi in rois:
    plot_3d_projections(roi, projection='max', size=3)
print(rois.shape)
# Keep a hand-picked subset of ROIs.
# NOTE(review): presumably chosen by inspecting the projections; re-check if detection changes.
rois_clean = rois[[0,2,4,5]]

# Normalise every ROI so its voxel sum over the last three axes equals 100.
rois_normed = rois_clean/rois_clean.sum(-1).sum(-1).sum(-1)[:,None,None,None] * 100

# Zero-pad each axis by `extend[k]` voxels on both sides (axis 0, the ROI index, is not padded).
extend = [0,5,5,5]
rois_extended = torch.zeros([s + 2*e for s,e in zip(rois_normed.shape,extend)])
# slice(e, e + s) instead of e:-e so that a pad width of 0 still selects the whole axis.
inner = tuple(slice(e, e + s) for s, e in zip(rois_normed.shape, extend))
rois_extended[inner] = rois_normed
device='cuda'
# Fit a PSF volume on the same spatial grid as the padded ROIs.
PSF = LinearInterpolatedPSF(rois_extended.shape[-3:],1, device=device)
loss_res = fit_psf(PSF, rois_extended)

# Fitted PSF as-is, and with negative voxels clipped (display only; the parameter is not modified).
plot_3d_projections(PSF.psf_volume[0], projection='mean', size=3);
plot_3d_projections(torch.clamp_min(PSF.psf_volume[0],0), projection='mean', size=3);

# Loss curve on a log scale with a reference line at the best loss, spanning the whole x-range
# regardless of how many iterations fit_psf actually ran.
min_loss = np.min(loss_res)
plt.plot(loss_res)
plt.yscale('log')
plt.axhline(min_loss, color='C1')
plt.title(min_loss)
# Clip negative voxels of the fitted PSF in place, then write its parameters to disk.
PSF.psf_volume.data.clamp_(min=0)
torch.save(PSF.state_dict(), '../data/simfish_psf.pkl')
from scipy import ndimage

# Reference PSF: a unit impulse at the centre of the fitted PSF's grid, blurred with an
# isotropic Gaussian (sigma = 1.2 voxels) and scaled by 100.
# The impulse is built in NumPy because the fitted PSF was created with device='cuda',
# and scipy.ndimage cannot read a GPU tensor.
fitted_shape = tuple(PSF.psf_volume.shape[1:])  # drop the leading singleton axis
delta = np.zeros(fitted_shape, dtype=np.float32)
delta[tuple(s // 2 for s in fitted_shape)] = 1
gauss_3d = 100*ndimage.gaussian_filter(delta, 1.2, order=0, output=None, mode='reflect', cval=0.0, truncate=4.0)
plot_3d_projections(gauss_3d)

# Ground-truth PSF of the simulation, shown at every third voxel for comparison.
gt_psf = load_tiff_image('/groups/turaga/home/speisera/share_TUM/FishSIM/PSF.tif')
plot_3d_projections(gt_psf[0,::3,::3,::3])

device='cuda'
# Declare the PSF with the Gaussian's own grid so the saved state_dict is self-consistent,
# and store the volume with the parameter's dtype and device.
PSF = LinearInterpolatedPSF(list(gauss_3d.shape),1, device=device)
PSF.psf_volume.data = torch.as_tensor(gauss_3d[None], dtype=PSF.psf_volume.dtype, device=PSF.psf_volume.device)
torch.save(PSF.state_dict(), '../data/gaussian_psf.pkl')
# Regenerate the project's library modules from its notebooks via the nbdev CLI.
!nbdev_build_lib