Definitions of the classes and modules used to simulate recordings from network outputs.
%load_ext autoreload
%autoreload 2
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
from decode_fish.engine.psf import LinearInterpolatedPSF
from decode_fish.engine.noise import sCMOS
from decode_fish.engine.point_process import PointProcessUniform
from decode_fish.funcs.plotting import plot_3d_projections
from decode_fish.funcs.file_io import get_gaussian_psf
# Toy simulator on the GPU: a Gaussian PSF (extent [13, 21, 21], in the same z/y/x order as
# the psf_extent_zyx argument used further down; radii [1, 1, 1]), sCMOS noise, and a
# Microscope combining the two.
psf = get_gaussian_psf([13,21,21],[1,1,1]).cuda()
noise = sCMOS()
# NOTE(review): int_mu / int_scale / int_loc look like emitter-intensity parameters and
# psf_noise a PSF perturbation scale -- confirm against the Microscope definition.
micro = Microscope(parametric_psf=[psf], noise=noise, int_mu=2, int_scale=5, int_loc=1, psf_noise=0.0001).cuda()
# Uniform point process with a constant rate of 1e-4 over a [2, 1, 48, 48, 48] grid
# (presumably batch, channel, z, y, x -- TODO confirm).
point_process = PointProcessUniform(local_rate = torch.ones([2,1,48,48,48]).cuda()*.0001)
# Draw emitters. Going by the names: locations, per-axis offsets, intensities and the shape
# of the volume to render (verify against PointProcessUniform.sample).
locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d, output_shape = point_process.sample()
# Render and plot a single PSF; [0,0] selects the first entry of the two leading axes.
# NOTE(review): torch.tensor([1]) and torch.tensor([0]) are int64 while torch.tensor([2.5])
# is float -- check that psf.forward handles the mixed dtypes as intended.
plot_3d_projections(psf.forward(torch.tensor([1]).cuda(),torch.tensor([2.5]).cuda(),torch.tensor([0]).cuda())[0,0])
# Simulate a recording from the sampled emitters and plot its first volume.
xsim = micro(locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d, output_shape)
plot_3d_projections(xsim[0,0])
# Same import as earlier in the file; repeated, presumably so this cell can run on its own.
from decode_fish.funcs.file_io import get_gaussian_psf
# Training config from a hard-coded, machine-specific path.
# NOTE(review): only resolvable on the original cluster filesystem.
cfg = OmegaConf.load('/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/N2_352/sweep_b1/psf_noise:0xnum_iter:10000/train.yaml')
# Gaussian PSF built from the config's z/y/x extent and Gaussian radii.
psf = get_gaussian_psf(cfg.microscope.psf_extent_zyx, cfg.PSF.gauss_radii)
# Uncomment to replace the Gaussian with the PSF saved as psf.pkl in cfg.output.save_dir.
# psf.load_state_dict(torch.load(Path(cfg.output.save_dir)/'psf.pkl'))
# Noise model instantiated from the config via Hydra.
noise = hydra.utils.instantiate(cfg.noise)
# NOTE(review): psf_noise is hard-coded to 2e-3 here instead of being read from cfg.
micro = Microscope(parametric_psf=[psf], noise=noise, multipl=cfg.microscope.multipl, psf_noise=2e-3, clamp_mode=cfg.microscope.clamp_mode).cuda()
# Ratio of the summed volume of a freshly built Gaussian PSF to that of `psf`.
# While the load_state_dict line above stays commented out, both come from the same call,
# so this is presumably ~1 and only informative once a saved PSF is loaded. The value is
# only shown if this is the last expression of a notebook cell.
get_gaussian_psf(cfg.microscope.psf_extent_zyx, cfg.PSF.gauss_radii).psf_volume.sum()/psf.psf_volume.sum()
# Simulate the same sampled emitters twice and compare the totals. Presumably this shows
# run-to-run randomness in the simulation -- confirm which parts of Microscope are stochastic.
xsim = micro(locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d, output_shape)
print(xsim.sum())
plot_3d_projections(xsim[0,0])
xsim = micro(locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d, output_shape)
print(xsim.sum())
plot_3d_projections(xsim[0,0])
# Plot the first entry of psf_volume, with values below -5 clamped to -5, using the
# 'mean' projection mode.
plot_3d_projections(torch.clamp_min(psf.psf_volume[0],-5),'mean')
from decode_fish.funcs.file_io import *
from decode_fish.funcs.output_trafo import *
# Default config (default_conf presumably comes from the star imports -- confirm).
# NOTE(review): `micro` used below was built from the previously loaded cfg, not this one.
cfg = OmegaConf.load(default_conf)
# Ground-truth image from a hard-coded, machine-specific TIFF path.
gt_img = load_tiff_image('/groups/turaga/home/speisera/share_TUM/FishSIM/sim_1/mRNAlevel_200/cell3D/strong/w1_HelaKyoto_Gapdh_2597_p01_cy3__Cell_CP_14__cell3D__1.tif')
# Saved network output. torch.load unpickles its input, so only load trusted files.
model_out = torch.load('../data/model_output.pt')
# Convert the network output into Microscope inputs. threshold=0.1 is presumably a
# detection-probability cutoff -- confirm in model_output_to_micro_input.
locs_ae, x_os_ae, y_os_ae, z_os_ae, ints_ae, output_shape_ae = model_output_to_micro_input(model_out, threshold=0.1)
# Re-render an image from the decoded emitters.
ae_img = micro(locs_ae, x_os_ae, y_os_ae, z_os_ae, ints_ae, output_shape_ae)
# NOTE(review): despite its name, this holds the NEGATIVE mean log-probability (leading
# minus), i.e. a loss to minimise.
# gt_img[None,:,30:] adds a leading axis and drops the first 30 entries of the image's
# second axis -- TODO document why 30 (presumably to match the rendered volume's shape).
log_p_x_given_z = - micro.noise(ae_img,model_out['background']).log_prob(gt_img[None,:,30:].cuda()).mean()
# Backpropagate through the simulation (presumably a check that gradients flow).
log_p_x_given_z.backward()
# Bare expression: only displayed when it is the last line of a notebook cell.
log_p_x_given_z
# IPython shell escape: runs the `nbdev_build_lib` CLI (nbdev's notebook-to-library export
# step -- confirm behaviour against the installed nbdev version).
!nbdev_build_lib