%load_ext autoreload
%autoreload 2
from decode_fish.engine.point_process import PointProcessUniform
# Draw one sample from a PointProcessUniform built on a constant
# 1x1x40x40x40 tensor filled with 0.001, then pass the first five returned
# values to sample_to_df (the sixth, output_shape, is kept but not used here).
uniform_rate = torch.ones([1, 1, 40, 40, 40]) * 0.001
sampler = PointProcessUniform(uniform_rate)
(
    locs_3d,
    x_os_3d,
    y_os_3d,
    z_os_3d,
    ints_3d,
    output_shape,
) = sampler.sample()
sample_to_df(locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d)
class ISIPostProcess(SIPostProcess):
    """Iterative variant of the spatial-integration post-processing.

    Instead of a single integration pass, `forward` repeatedly picks local
    maxima of the probability map, integrates the probability around each of
    them, and removes the integrated mass before looking for the next set of
    maxima.

    NOTE(review): other cells of this notebook call
    `forward(model_out, ret=...)`, which this `forward(self, logits)`
    signature does not accept -- confirm which definition is the current one.
    """

    def __init__(self, m1_threshold: float = 0.1, samp_threshold=0.1, px_size_zyx=None, diag=False):
        """Forward the settings to `SIPostProcess` and disable the m2 threshold.

        Args:
            m1_threshold: Minimum probability for a voxel to be a candidate
                maximum in `forward`.
            samp_threshold: Passed through to the parent class.
            px_size_zyx: Passed through to the parent class. Defaults to
                `[100, 100, 100]`; a fresh list is created per instance so
                that instances never share one mutable default.
            diag: Passed through to the parent class.
        """
        if px_size_zyx is None:
            px_size_zyx = [100, 100, 100]
        super().__init__(
            m1_threshold=m1_threshold,
            samp_threshold=samp_threshold,
            px_size_zyx=px_size_zyx,
            diag=diag,
        )
        # This variant does not use a second threshold.
        self.m2_threshold = None

    def forward(self, logits):
        """Iteratively integrate probability mass around local maxima.

        Args:
            logits: Tensor of logits. It is fed to `F.max_pool3d` and
                `F.conv3d`, so it must have a layout those functions accept
                (presumably batch x channel x 3 spatial dims -- TODO confirm).

        Returns:
            Tensor of the same shape as `logits` that holds, at every voxel
            selected as a local maximum, the filtered probability around it
            clamped to at most 1, and 0 elsewhere.
        """
        device = logits.device
        p = torch.sigmoid(logits)
        # `self.filt` is not defined in this class; presumably the parent sets
        # it up as the integration kernel. Move it to the device once instead
        # of twice per loop iteration.
        filt = self.filt.to(device)

        with torch.no_grad():
            p_SI = 0
            tot_mask = torch.ones_like(p)
            max_mask = torch.ones_like(p)

            # Keep going until an iteration selects no maximum at all.
            while max_mask.sum():
                # Voxels with probability > threshold that were not
                # previously counted as locations are candidates.
                p_cand = torch.where(p > self.m1_threshold, p, torch.zeros_like(p)) * tot_mask

                # Localize maximum (nonzero) values within a 3x3x3 volume.
                p_cand = F.max_pool3d(p_cand, 3, 1, padding=1)
                max_mask = torch.eq(p, p_cand).float()
                # Empty voxels trivially equal an empty neighbourhood maximum.
                max_mask[p == 0] = 0

                # Add up probability values from the adjacent voxels.
                conv = F.conv3d(p, filt, padding=1)
                p_sum = max_mask * conv

                # Add the integrated probabilities to the return tensor.
                p_SI += torch.clamp_max(p_sum, 1)

                # Voxels that were added can not be added again.
                tot_mask *= (torch.ones_like(max_mask) - max_mask)

                # The probability mass that contributed to p_sum is removed.
                # 1/p_sum is inf wherever nothing was selected; zero those
                # entries and cap the factor at 1.
                p_fac = 1 / p_sum
                p_fac[torch.isinf(p_fac)] = 0
                p_fac = torch.clamp_max(p_fac, 1)
                p_proc = F.conv3d(p_fac, filt, padding=1) * p
                p = p - p_proc
                torch.clamp_min_(p, 0)

            return p_SI
from decode_fish.funcs.utils import *
# model_out = torch.load('../data/model_output.pt')
# probs_inp = torch.sigmoid(model_out['logits'])[:,:,:,250:300,200:250]
# model_out = torch.load('../data/model_batch_output.pt')
# probs_inp = torch.sigmoid(model_out['logits'])
# Load the saved network output and turn its logits into probabilities.
# The trailing full slice selects everything; narrow it to crop the volume
# (see the commented-out variant above).
model_out = torch.load('../data/model_output_t.pt')
probs_inp = torch.sigmoid(
    model_out['logits']
)[:, :, :, :, :]
# gt_df = torch.load('../data/gt_1.pt')
# len(gt_df)
from decode_fish.funcs.evaluation import *
# Two post-processors that share the m1 threshold, pixel size and diag flag:
# the base SIPostProcess (which additionally takes an m2 threshold) and the
# ISIPostProcess variant. They use different sampling thresholds.
post_proc1 = SIPostProcess(
    m1_threshold=0.03,
    m2_threshold=0.25,
    samp_threshold=0.6,
    px_size_zyx=[100, 100, 100],
    diag=True,
)
post_proc2 = ISIPostProcess(
    m1_threshold=0.03,
    samp_threshold=0.5,
    px_size_zyx=[100, 100, 100],
    diag=True,
)
# matching(px_to_nm(gt_df), post_proc1.forward(model_out, ret='df'), tolerance=500, print_res=True)
# _=matching(px_to_nm(gt_df), post_proc2.forward(model_out, ret='df'), tolerance=500, print_res=True)
# Compare the raw network output with the `get_si_resdict` results of both
# post-processors on a 2x3 subplot grid.
plt.figure(figsize=(20, 10))

# Grid position 1: network probabilities summed along the first axis.
plt.subplot(231)
probs = cpu(probs_inp[0, 0])
# `probs + 0` makes a copy, so thresholding below leaves `probs` untouched.
probsf = probs + 0
probsf[probsf < 0.01] = 0
im = plt.imshow(probs.sum(0))
# plt.scatter(gt_df['x'],gt_df['y'], color='red', s=5.)
# Title: total probability before and after dropping values below 0.01.
plt.title(f'Net output {probs.sum().item():.1f} and {probsf.sum().item():.1f}')
add_colorbar(im)

# One column per post-processor: integrated probabilities in the top row,
# samples in the bottom row.
for label, proc, probs_pos, samples_pos in (
    ('SI', post_proc1, 232, 235),
    ('ISI', post_proc2, 233, 236),
):
    recs = proc.get_si_resdict(model_out)

    plt.subplot(probs_pos)
    im = plt.imshow(cpu(recs['Probs_si'][0, 0]).max(0))
    add_colorbar(im)
    N = cpu(recs['Probs_si'][0, 0]).sum().item()
    plt.title(f'{label} Probs SI {N:.1f}')

    plt.subplot(samples_pos)
    im = plt.imshow(cpu(recs['Samples_si'][0, 0]).sum(0))
    add_colorbar(im)
    plt.title(cpu(recs['Samples_si'][0, 0]).sum().item())
# Compare the raw network output with the `forward(..., ret='dict')` results
# of both post-processors on a 2x3 subplot grid.
plt.figure(figsize=(20, 10))

# Grid position 1: network probabilities summed along the first axis.
plt.subplot(231)
probs = cpu(probs_inp[0, 0])
# `probs + 0` makes a copy, so thresholding below leaves `probs` untouched.
probsf = probs + 0
probsf[probsf < 0.01] = 0
im = plt.imshow(probs.sum(0))
# plt.scatter(gt_df['x'],gt_df['y'], color='red', s=5.)
# Title: total probability before and after dropping values below 0.01.
plt.title(f'Net output {probs.sum().item():.1f} and {probsf.sum().item():.1f}')
add_colorbar(im)

# One column per post-processor: integrated probabilities in the top row,
# samples in the bottom row.
for label, proc, probs_pos, samples_pos in (
    ('SI', post_proc1, 232, 235),
    ('ISI', post_proc2, 233, 236),
):
    recs = proc.forward(model_out, ret='dict')

    plt.subplot(probs_pos)
    im = plt.imshow(cpu(recs['Probs_si'][0, 0]).max(0))
    add_colorbar(im)
    N = cpu(recs['Probs_si'][0, 0]).sum().item()
    plt.title(f'{label} Probs SI {N:.1f}')

    plt.subplot(samples_pos)
    im = plt.imshow(cpu(recs['Samples_si'][0, 0]).sum(0))
    add_colorbar(im)
    plt.title(cpu(recs['Samples_si'][0, 0]).sum().item())
# Close-up of a small sub-volume: projections of the network probabilities
# with the ISI detections that fall inside the crop drawn on top.
sl = np.s_[:, :10, 35:45, 20:30]
# Ground-truth crop is disabled: `gt_df` is never loaded in this notebook
# (its torch.load is commented out) and `gt_sub` is only used by the
# commented-out scatter calls below.
# gt_sub = crop_df(gt_df, sl)
p_sub = crop_df(nm_to_px(post_proc2.forward(model_out, ret='df')), sl)
# `sl[1:]` drops the leading slice so the remaining three index `probs`.
axes = plot_3d_projections(probs[sl[1:]], 'max', size=15)
# print(probs[sl[1:]].sum(), len(gt_sub), len(p_sub))
# axes[0].scatter(gt_sub['x'],gt_sub['y'], color='red', s=5.)
# axes[1].scatter(gt_sub['x'],gt_sub['z'], color='red', s=5.)
# axes[2].scatter(gt_sub['y'],gt_sub['z'], color='red', s=5.)
# Column pair plotted on each of the three returned axes, in order.
for ax_idx, (col_h, col_v) in enumerate((('x', 'y'), ('x', 'z'), ('y', 'z'))):
    axes[ax_idx].scatter(p_sub[col_h], p_sub[col_v], color='red', s=15.)
# Ad-hoc inspection cells.
# NOTE(review): `p` and `p_col` are not defined anywhere in this notebook;
# presumably they were left over from an interactive session -- confirm.
p.shape

# One figure per entry of `p_col`: maximum projection of the cropped volume.
for p in p_col:
    proj = cpu(p[0, 0][sl[1:]]).max(0)
    plt.imshow(proj)
    # plt.title(cpu(p[0,0][sl[1:]]).sum())
    plt.colorbar()
    plt.show()

# Same projection once more for the last `p` of the loop.
proj = cpu(p[0, 0][sl[1:]]).max(0)
plt.imshow(proj)

# Histogram over all values of `probs`.
flat_probs = probs.reshape(-1).numpy()
plt.hist(flat_probs)
# Side-by-side maximum projections of the network probabilities and of their
# spatially integrated version.
# NOTE(review): `post_proc` is not defined in this notebook (only
# `post_proc1` and `post_proc2` are) -- confirm which one is meant.
probs_si = post_proc.spatial_integration(probs_inp)
plt.figure(figsize=(10, 5))

# Left panel: raw probabilities, titled with their total sum.
plt.subplot(121)
probs = probs_inp[0, 0].detach().cpu()
# probs[probs<0.01] = 0
raw_projection = probs.max(dim=0).values
im = plt.imshow(raw_projection)
plt.title(probs.sum().item())
add_colorbar(im)

# Right panel: integrated probabilities with the colour scale capped at 1.
plt.subplot(122)
si_projection = probs_si[0, 0].cpu().max(dim=0).values
im = plt.imshow(si_projection, vmax=1)
add_colorbar(im)
plt.title(probs_si[0].sum().item())
# Run the ISI post-processor on another saved network output and show the
# resulting table.
model_out = torch.load('../data/model_output_1.pt')
out_df = post_proc2(model_out)
out_df

# Overlay the detections on a maximum projection of the probabilities.
# NOTE(review): `probs_inp` was computed earlier from a different file than
# the `model_out` loaded just above -- confirm the overlay is meant to mix them.
plt.figure(figsize=(20, 20))
plt.subplot(121)
projection = probs_inp[0, 0].cpu().max(dim=0).values
im = plt.imshow(projection)
add_colorbar(im)
plt.title(len(out_df))
# Coordinates are divided by 100 before plotting; presumably this converts
# nm to pixels for the 100 nm pixel size used above -- TODO confirm.
x_scaled = out_df['x'] / 100
y_scaled = out_df['y'] / 100
plt.scatter(x_scaled, y_scaled, color='red', s=5.)

# Switch to the batched network output for whatever runs next.
model_out = torch.load('../data/model_batch_output.pt')
from decode_fish.engine.psf import LinearInterpolatedPSF
from decode_fish.engine.noise import sCMOS
from decode_fish.engine.point_process import PointProcessUniform
from decode_fish.funcs.plotting import plot_3d_projections
from decode_fish.engine.microscope import Microscope
# Build a microscope model from a saved PSF state and render two volumes.
# NOTE(review): the PSF path is an absolute, machine-specific location and
# the `.cuda()` calls require a GPU.
psf_state = torch.load('/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/fishcod/simfish_psf.pkl')
# The last three dimensions of the stored PSF volume set the PSF size.
_, xs, ys, zs = psf_state['psf_volume'].shape
psf = LinearInterpolatedPSF(fs_x=xs, fs_y=ys, fs_z=zs, upsample_factor=1)
psf.load_state_dict(psf_state)

noise = sCMOS()
micro = Microscope(parametric_psf=[psf], noise=noise, multipl=10000).cuda()

# Constant rate of 0.0001 over a 1x1x48x48x48 tensor.
local_rate = torch.ones([1, 1, 48, 48, 48]).cuda() * .0001
point_process = PointProcessUniform(local_rate=local_rate, min_int=0.5)
(
    locs_3d,
    x_os_3d,
    y_os_3d,
    z_os_3d,
    ints_3d,
    output_shape,
) = point_process.sample()

# Render the freshly sampled points ...
xsim = micro(locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d, output_shape)
# ... and a second set of points.
# NOTE(review): none of the `*_mod` names are defined in this notebook --
# confirm where they are supposed to come from.
xrec = micro(locs_mod, x_os_mod, y_os_mod, z_os_mod, ints_mod, output_shape_mod)

for volume in (xsim, xrec):
    plot_3d_projections(volume[0, 0])
!nbdev_build_lib