# IPython autoreload extension; mode 2 re-imports modules before each cell runs,
# so edits to the decode_fish package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

class PointProcessGaussian[source]

PointProcessGaussian(logits:tensor, xyzi_mu:tensor, xyzi_sigma:tensor, **kwargs) :: Distribution

Distribution is the abstract base class for probability distributions.

get_sample_mask[source]

get_sample_mask(bs, locations)

get_true_labels[source]

get_true_labels(bs, locations, x_os, y_os, z_os, *args)

grp_range[source]

grp_range(counts:Tensor)

cum_count_per_group[source]

cum_count_per_group(arr)

Helper function that returns, for each element, the number of earlier elements belonging to the same group (a running count within each group).
Example:
    [0, 0, 0, 1, 2, 2, 0] --> [0, 1, 2, 0, 0, 1, 3]
# Load a saved batch of network outputs. It is unpacked below as keyword arguments
# into PointProcessGaussian(logits, xyzi_mu, xyzi_sigma, **kwargs).
model_out = torch.load('../data/model_batch_output.pt')
# Only a cell's last expression is displayed, so a bare `model_out.keys()` here
# would have no effect; print the keys explicitly instead.
print(list(model_out.keys()))
model_out['logits'].shape
torch.Size([2, 1, 48, 48, 48])
from decode_fish.engine.point_process import PointProcessUniform

# Sample ground-truth emitters on the same grid as the model output, because these
# labels are scored against `model_out` further down. Deriving the grid shape and
# batch size from the logits keeps them consistent, instead of repeating the
# literals [2, 1, 48, 48, 48] and 2.
grid_shape = model_out['logits'].shape
batch_size = grid_shape[0]
point_process = PointProcessUniform(local_rate=torch.ones(grid_shape) * .0001, sim_iters=5)
locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d, output_shape = point_process.sample()
# Move the sampled location tensors to the GPU; the offsets/intensities are moved at the call site.
locs_3d = [loc.cuda() for loc in locs_3d]
xyzi_true, s_mask = get_true_labels(batch_size, locs_3d, x_os_3d.cuda(), y_os_3d.cuda(), z_os_3d.cuda(), ints_3d.cuda())
# In the recorded run, the row sums of s_mask (16 + 13) equal len(locs_3d[0]) (29),
# i.e. s_mask appears to flag the filled label slots per batch item.
print(len(locs_3d[0]))
print(s_mask)
print(s_mask.sum(-1))
29
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.]],
       device='cuda:0')
tensor([16., 13.], device='cuda:0')
# Score the sampled ground truth under the model output. The recorded `gmm_loss`
# output shows log_prob returning a pair of per-batch-item tensors.
gmm_loss = PointProcessGaussian(**model_out).log_prob(locs_3d, x_os_3d.cuda(), y_os_3d.cuda(), z_os_3d.cuda(), ints_3d.cuda())
# Histogram of the sampled intensities (CPU tensor, hence .numpy()).
# NOTE(review): the +2 offset is unexplained here — confirm it matches how intensities are used elsewhere.
plt.hist((ints_3d + 2).numpy())
(array([6., 3., 5., 4., 2., 3., 2., 1., 0., 3.]),
 array([0.64889324, 0.9456027 , 1.2423122 , 1.5390217 , 1.8357313 ,
        2.1324408 , 2.4291503 , 2.7258596 , 3.0225692 , 3.3192787 ,
        3.6159883 ], dtype=float32),
 <BarContainer object of 10 artists>)
# Produces one boolean per sampled emitter (29 in the recorded run).
# NOTE(review): filt_gt_filt is defined outside this section — confirm what the 1, 1, 0.3 arguments mean.
filt_gt_filt(ints_3d, 1, 1, 0.3)
tensor([ True, False, False, False,  True,  True, False,  True, False,  True,
         True,  True, False,  True, False, False,  True,  True,  True,  True,
        False,  True,  True,  True, False,  True,  True,  True,  True])
# Display the pair of loss tensors returned by log_prob.
# NOTE(review): in the recorded output the first entry for batch item 0 is about -5.9e17,
# which looks like a numerical blow-up — worth investigating.
gmm_loss
(tensor([-5.8608e+17, -2.3596e+00], device='cuda:0', grad_fn=<SubBackward0>),
 tensor([-1074.2325, -1207.4500], device='cuda:0', grad_fn=<SumBackward1>))
from decode_fish.funcs.utils import free_mem
# Presumably releases (GPU) memory before the repeated-evaluation loop — confirm in decode_fish.funcs.utils.
free_mem()
# Evaluate the loss 1000 times. Presumably this is a memory/timing check, given the
# free_mem() call before it — confirm.
# The inputs are identical on every iteration, so copy them to the GPU once rather
# than 1000 times.
# Assumes log_prob does not modify its inputs in place; locs_3d is already reused
# across calls in the original loop, which suggests it does not.
x_os_gpu, y_os_gpu, z_os_gpu, ints_gpu = x_os_3d.cuda(), y_os_3d.cuda(), z_os_3d.cuda(), ints_3d.cuda()
for _ in range(1000):
    gmm_loss = PointProcessGaussian(**model_out).log_prob(locs_3d, x_os_gpu, y_os_gpu, z_os_gpu, ints_gpu)
# Regenerate the library modules from the notebooks (nbdev); the conversion log follows.
!nbdev_build_lib
Converted 00_models.ipynb.
Converted 01_psf.ipynb.
Converted 02_microscope.ipynb.
Converted 03_noise.ipynb.
Converted 04_pointsource.ipynb.
Converted 05_gmm_loss.ipynb.
Converted 06_plotting.ipynb.
Converted 07_file_io.ipynb.
Converted 08_dataset.ipynb.
Converted 09_output_trafo.ipynb.
Converted 10_evaluation.ipynb.
Converted 11_emitter_io.ipynb.
Converted 12_utils.ipynb.
Converted 13_train.ipynb.
Converted 15_fit_psf.ipynb.
Converted 16_visualization.ipynb.
Converted 17_eval_routines.ipynb.
Converted index.ipynb.