# IPython: automatically reload imported modules before running each cell
# (mode 2 = reload all modules), presumably so edits to the decode_fish
# library imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Load a previously saved batch of model output (path is relative to the
# notebook's working directory).
# model_out is accessed by key here and splatted via **model_out further down,
# so it is presumably a dict of tensors — TODO confirm.
model_out = torch.load('../data/model_batch_output.pt')
# Bare expressions for inspection: a notebook only displays the value of the
# last expression in a cell.
model_out.keys()
model_out['logits'].shape
from decode_fish.engine.point_process import PointProcessUniform

# Draw a small synthetic sample from a uniform point process with a constant,
# very low rate, then build target labels from GPU copies of the sample.
# NOTE(review): the meaning of the sampled outputs (locations, x/y/z offsets,
# intensities) is inferred from their names — confirm against PointProcessUniform.
batch_size = 2  # used both as the rate map's leading dim and as get_true_labels' first arg
rate_map = torch.ones([batch_size, 1, 48, 48, 48]) * 1e-4
point_process = PointProcessUniform(local_rate=rate_map, sim_iters=5)
locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d, output_shape = point_process.sample()

# locs_3d is a list; each entry is moved to the GPU individually.
locs_3d = [loc.cuda() for loc in locs_3d]
# GPU copies of the per-point attributes; the CPU originals are kept for later use.
gpu_attrs = [t.cuda() for t in (x_os_3d, y_os_3d, z_os_3d, ints_3d)]
xyzi_true, s_mask = get_true_labels(batch_size, locs_3d, *gpu_attrs)

# Quick sanity checks on the sampled labels.
print(len(locs_3d[0]))
print(s_mask)
print(s_mask.sum(-1))
# Build a PointProcessGaussian from the saved model output and call its
# log_prob on the sampled ground truth (presumably the GMM loss, per the
# variable name — TODO confirm). Location tensors are already on the GPU;
# the per-point attributes are copied to the GPU here.
gmm_loss = PointProcessGaussian(**model_out).log_prob(locs_3d, x_os_3d.cuda(), y_os_3d.cuda(), z_os_3d.cuda(), ints_3d.cuda())
# Histogram of the sampled intensities shifted by +2 (the reason for the shift
# is not visible here — TODO document). .numpy() assumes ints_3d is a CPU
# tensor; only .cuda() copies of it are made above.
plt.hist((ints_3d + 2).numpy())
# NOTE(review): meaning of the (1, 1, 0.3) arguments is not visible here; the
# result is only shown if this is the last expression of its cell.
filt_gt_filt(ints_3d, 1, 1, 0.3)
# Display the loss value.
gmm_loss
from decode_fish.funcs.utils import free_mem
free_mem()

# Repeatedly evaluate the loss, presumably to check the speed / GPU-memory
# behaviour of PointProcessGaussian.log_prob — intent not stated; TODO confirm.
# The host-to-GPU copies do not change between iterations, so they are made
# once here instead of 1000 times inside the loop.
x_os_gpu, y_os_gpu, z_os_gpu, ints_gpu = (t.cuda() for t in (x_os_3d, y_os_3d, z_os_3d, ints_3d))
for _ in range(1000):
    gmm_loss = PointProcessGaussian(**model_out).log_prob(locs_3d, x_os_gpu, y_os_gpu, z_os_gpu, ints_gpu)
# Shell escape (IPython `!`): run nbdev's export command, presumably to
# regenerate the library modules from the notebooks.
!nbdev_build_lib