# IPython magics: load the `autoreload` extension and switch it to mode 2,
# which re-imports modules before each cell is executed.
%load_ext autoreload
%autoreload 2
# Restrict CUDA to the device with index 1.
# NOTE(review): `os` is not imported in the visible part of this file, and this
# presumably has to run before CUDA is first initialised to take effect - confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
# perf, pred_df, matches = get_sim_perf(xsim, px_to_nm(gt, cfg.evaluation.px_size_zyx), model, post_proc, micro, print_res=True)
# plt.scatter(matches['int_tar'], matches['int_pred'])
# plt.plot([0,10],[0,10])
def filt_perc(df, perc=90, return_low=True, metric='comb_sig'):
    """Filter localizations frame by frame against a percentile of `metric`.

    For every distinct value in ``df['frame_idx']`` the `perc`-th percentile of
    ``df[metric]`` within that frame is computed (and printed), and only rows
    strictly below (``return_low=True``) or strictly above
    (``return_low=False``) that threshold are kept.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``'frame_idx'`` and `metric`.
    perc : float
        Percentile in [0, 100]. Values >= 100 disable filtering.
    return_low : bool
        Keep the rows below the threshold if True, above it otherwise.
    metric : str
        Name of the column the percentile is computed on.

    Returns
    -------
    pandas.DataFrame
        `df` itself when ``perc >= 100``; otherwise the concatenation of the
        kept rows of every frame, with the original index preserved. An empty
        frame with the columns of `df` is returned when `df` has no rows.
    """
    # Local import keeps this block self-contained; DataFrame.append was
    # removed in pandas 2.0, so the per-frame pieces are joined with concat.
    import pandas as pd

    if perc >= 100:
        return df

    kept = []
    # Iterate over the frames of the *argument*; the previous version walked
    # the frames of the unrelated global `dec_df_col`.
    for f in df['frame_idx'].unique():
        frame_df = df[df['frame_idx'] == f]
        filt_val = np.percentile(frame_df[metric], perc)
        print(filt_val)
        if return_low:
            frame_df = frame_df[frame_df[metric] < filt_val]
        else:
            frame_df = frame_df[frame_df[metric] > filt_val]
        kept.append(frame_df)

    if not kept:
        # pd.concat raises on an empty list; return an empty slice instead.
        return df.iloc[:0]
    return pd.concat(kept)
def int_hist_nnnew(df, micro):
    """Compare predicted intensities with samples from the model's Gamma distribution.

    Draws two density-normalised histograms on the current matplotlib axes:
    the values in ``df['int']`` and 10000 samples from a Gamma distribution
    parameterised by ``model.int_dist`` (concentration, rate) shifted by its
    ``int_loc``. A red vertical line marks ``int_loc`` and the title shows the
    mean predicted intensity. The minimum predicted intensity is printed.

    Reads the module-level ``model``; ``micro`` is accepted but not used.
    """
    # Previously considered rescaling, kept for reference:
    # *micro.int_sig.item() + micro.int_mu.item() - micro.int_mu.item() * micro.min_fac
    pred_ints = df['int']

    int_dist = model.int_dist
    conc = int_dist.int_conc.item()
    rate = int_dist.int_rate.item()
    loc = int_dist.int_loc.item()

    gamma = torch.distributions.Gamma(conc, rate)
    prior_samples = cpu(gamma.sample([10000])) + loc
    upper = prior_samples.max()
    print(pred_ints.min())

    # Both histograms share 100 bins spanning [0, largest sample].
    bin_edges = np.linspace(0, upper, 101)
    _ = plt.hist(pred_ints, bins=bin_edges, density=True, label='Predictions')
    _ = plt.hist(prior_samples, density=True, bins=bin_edges, alpha=0.5, label='Distr.')
    plt.plot([loc, loc], [0, 1], color='red')
    plt.title(pred_ints.mean())
    plt.legend()
def psf_rmse(psf1, psf2):
    """Return the root-mean-square error between two PSF arrays.

    The difference is taken element-wise (with normal broadcasting) and the
    mean runs over all elements.
    """
    residual = psf1 - psf2
    mean_sq = np.mean(residual ** 2)
    return np.sqrt(mean_sq)
def psf_corr(psf1, psf2):
    """Return the Pearson correlation coefficient between two PSF arrays.

    Both inputs are flattened before the correlation is computed, so only
    their total number of elements has to match.
    """
    flat1 = psf1.reshape(-1)
    flat2 = psf2.reshape(-1)
    corr_matrix = np.corrcoef(flat1, flat2)
    # Off-diagonal entry of the 2x2 matrix is the cross-correlation.
    return corr_matrix[0, 1]
# Histogram of the predicted intensities in `dec_df` against the current model's distribution.
int_hist_nnnew(dec_df, micro)
import scipy.stats as stats
plt.figure(figsize=(10,6))
#Fit parameters:
# Plain floats: the earlier trailing commas turned `shape` and `scale` into
# 1-tuples, which only worked through broadcasting inside `gamma.pdf`.
shape = 3.7504156297802815
scale = 0.6406917808418713
# Spelling of `locaction` is kept because other cells may refer to this name.
locaction = 2.104038290757017
# files = sorted(glob.glob('/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/sweep_ff18/*/train.yaml'))
files = sorted(glob.glob('/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/sweep_fj3/*/train.yaml'))
# Instantiate the network once from the first config; the loop below only swaps in weights.
cfg = OmegaConf.load(files[0])
model = hydra.utils.instantiate(cfg.model)
for f in files[:]:
    # cfg = OmegaConf.load(f'/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/sweep_fd1/int_mu:{m}//train.yaml')
    cfg = OmegaConf.load(f)
    # The run's directory name is used as the legend label.
    name = str(Path(f).parent).split('/')[-1]
    cfg.microscope._target_ = 'decode_fish.engine.microscope.Microscope'
    path = Path(cfg.output.save_dir)
    model = load_model_state(model, path, 'model.pkl')
    # One step histogram per run: 10000 samples of the Gamma distribution shifted by int_loc.
    _ = plt.hist(cpu(torch.distributions.Gamma(model.int_dist.int_conc.item(), model.int_dist.int_rate.item()).sample([10000]))+model.int_dist.int_loc.item(),
                 bins=np.linspace(0,15,101),histtype='step', label=name, density=True)
# Overlay the Gamma density defined by the fit parameters above.
x = np.linspace(0,15,101)
y = stats.gamma.pdf(x, shape, loc=locaction, scale=scale)
plt.plot(x, y, label='Fit', linewidth=3, color='black')
plt.legend()
# Compare a PSF volume obtained from `load_all` with a reference PSF and with
# a Gaussian PSF built from the config.
# img_fish = round(reshape(sum(sum(sum(img_6D,1),3),5),size_img)./(factor_binning^3));
large_psf = cpu(load_tiff_image('../figures/PSF.tif')[0])
# Drop the last plane/row/column, then average over non-overlapping 3x3x3
# cubes (sum over the three size-3 axes, divided by 27) -> 10x20x20 volume.
small_psf = large_psf[:-1,:-1,:-1].reshape([10,3,20,3,20,3]).sum(1).sum(2).sum(3)/27
# Reference PSF from the path in the currently loaded config, scaled to peak 1.
gt_psf = cpu(load_tiff_image(cfg.evaluation.psf_path)[0])
gt_psf /= gt_psf.max()
from decode_fish.engine.psf import LinearInterpolatedPSF
# Alternative runs that were inspected; exactly one config is active.
# cfg = OmegaConf.load(f'/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/sweep_f4/min_fac:0.4/train.yaml')
# cfg = OmegaConf.load(f'/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/sweep_f5/int_mu:4.0//train.yaml')
# cfg = OmegaConf.load(f'/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/sweep_fb6/int_mu:2.0/train.yaml')
cfg = OmegaConf.load(f'/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/sweep_g1/start_psf:3000/train.yaml')
# cfg = OmegaConf.load(f'/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/sweep_fi3d/seed:3/train.yaml')
# cfg = OmegaConf.load(f'/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/sweep_fi7_gt/seed:4/train.yaml')
# cfg = OmegaConf.load(f'/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/sweep_fh9/seed:4/train.yaml')
# cfg = OmegaConf.load(f'/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/sweep_fi2c/norm_reg:0.01/train.yaml')
# cfg = OmegaConf.load(f'/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/psf_opt_runs/i_psf_max_norm_foci_3loc/train.yaml')
# cfg.microscope._target_ = 'decode_fish.engine.microscope.Microscope'
# Override the foci setting to 0.0 before everything is built from the config.
cfg.foci.n_foci_avg = 0.0
model, post_proc, psf, micro, img_3d, decode_dl = load_all(cfg, False)
# Gaussian PSF built from the config's extent and radii, used as the "init" baseline below.
psf_init = get_gaussian_psf(cfg.PSF.psf_extent_zyx,cfg.PSF.gauss_radii)
psf_init = cpu(psf_init.psf_volume[0])
psf_vol = cpu(psf.psf_volume[0])
# Scale both volumes to peak 1, like gt_psf above.
psf_init /= psf_init.max()
psf_vol /= psf_vol.max()
print(psf.get_com())
plot_3d_projections(gt_psf)
plot_3d_projections(psf_vol)
# Difference against psf_vol reversed along its last two axes.
# NOTE(review): the RMSE/Corr figures below use the *unflipped* psf_vol - confirm which orientation is intended.
plot_3d_projections(gt_psf-psf_vol[:,::-1,::-1])
print('Sum: ', gt_psf.sum(), psf_vol.sum(), psf_init.sum())
# 'CSum' differs from 'Sum' only in clipping psf_vol to [0, 10] before summing.
print('CSum: ', gt_psf.sum(), np.clip(psf_vol,0,10).sum(), psf_init.sum())
print('Max: ', gt_psf.max(), psf_vol.max(), psf_init.max())
print('RMSE init: ', psf_rmse(gt_psf, psf_init))
print('RMSE: ', psf_rmse(gt_psf, psf_vol))
print('Corr init: ', psf_corr(gt_psf, psf_init))
print('Corr: ', psf_corr(gt_psf, psf_vol))
# Run the model on simulated images, collect ground-truth / FISH-quant /
# model localizations and evaluate the matching.
model.cuda()
ind = 0
# Full-volume slice; narrower crops can be swapped in below.
sl = np.s_[:,:,:,:]
# sl = np.s_[:,30:,150:180,115:135]
# basedir = '/groups/turaga/home/speisera/share_TUM/FishSIM/sim_foci_fac1_1//'
# img, gt_df, fq_nog_df, fq_gmm_df = load_sim_fish(basedir, 100, 'foci', 'strong', ind)
basedir = '/groups/turaga/home/speisera/share_TUM/FishSIM/sim_density_fac1_2/'
gt_df_col = DF()
fq_df_col = DF()
dec_df_col = DF()
# The background image path is constant, so it is loaded once instead of per iteration.
bg_gt = load_tiff_image('/groups/turaga/home/speisera/share_TUM/FishSIM/bg_tifs/w1_HelaKyoto_Gapdh_2597_p01_cy3__Cell_CP_10.tif')
for ind in tqdm([0]):
    img, gt_df, fq_nog_df, fq_gmm_df = load_sim_fish(basedir, 250, 'random', 'NR', ind)
    bg_est = EstimateBackground(smoothing_filter_size=cfg.bg_estimation.smoothing_filter_size)(img)
    # The background swap is currently disabled; the raw image is fed to the model.
    img_in = img #- bg_gt + bg_est # bg_gt.mean()
    # gt_df = crop_df(gt_df, sl, px_size_zyx=[300,100,100])
    fq_gmm_df = crop_df(fq_gmm_df, sl, px_size_zyx=[300,100,100])
    with torch.no_grad():
        dec_df = shift_df(post_proc.get_df(model.tensor_to_dict(model(img_in[sl][None].cuda()))), [-100,-100,-300])
    free_mem()
    # Accumulate the localizations of every processed image.
    gt_df_col = cat_emitter_dfs([gt_df_col, gt_df])
    fq_df_col = cat_emitter_dfs([fq_df_col, fq_gmm_df])
    dec_df_col = cat_emitter_dfs([dec_df_col, dec_df])
# dec_df_col = sig_filt(dec_df_col, 80)
# First matching pass only provides `shift`, which is applied before the reported matching.
perf_df, matches, shift = matching(gt_df_col, dec_df_col, print_res=False)
dec_df_col = shift_df(dec_df_col, shift)
axes = plot_3d_projections(img_in[sl][0], 'max', size=15)
# NOTE(review): this overlays `dec_df` from the last iteration, i.e. without the shift applied above - confirm.
scat_3d_projections(axes, [dec_df, gt_df], px_size_zyx=[300,100,100])
perf_df, matches, _ = matching(gt_df_col, dec_df_col, print_res=True)
# Ratio of predicted to ground-truth localizations.
print(len(dec_df_col)/len(gt_df_col))
dec_df_filt = sig_filt(dec_df_col, 95)
perf_df, matches, shift = matching(gt_df_col, dec_df_filt, print_res=True)
plt.scatter(matches['int_tar'], matches['int_pred'])
# Slope of a least-squares line through the origin (predicted vs. target intensity).
print(np.round(np.linalg.lstsq(matches['int_tar'].values.reshape(-1,1), matches['int_pred'].values, rcond=None)[0][0],2))
plt.plot([0,10],[0,10])
eval_random_crop(decode_dl, model, post_proc, micro, projection='max', samples=3, int_threshold=1000)
basedir
# For every trained model in the sweep, count localizations inside boxes
# centred on bright peaks of simulated foci images.
from decode_fish.funcs.fit_psf import get_peaks_3d
from decode_fish.funcs.fit_psf import plot_detection
basedir = '/groups/turaga/home/speisera/share_TUM/FishSIM/sim_foci_fac1_1/'
files = sorted(glob.glob('/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/sweep_fj1/*/train.yaml'))
# One entry per model: list of per-box counts / summed probabilities, plus the run name.
dec_count_col = []
prob_sum_col = []
names = []
for f in files:
    cfg = OmegaConf.load(f)
    names.append(str(Path(f).parent).split('/')[-1])
    # model, post_proc, psf, micro, img_3d, decode_dl = load_all(cfg, False)
    # Reuses the already-instantiated `model` and only loads this run's weights.
    model = load_model_state(model, Path(cfg.output.save_dir), 'model.pkl')
    _ = model.cuda()
    box_sz = 10
    n_cells = 20
    # Reset per model; gt/fq counts are rebuilt identically on every pass.
    gt_counts = []
    fq_counts = []
    dec_counts = []
    prob_sums = []
    for i in tqdm(range(n_cells)):
        img, gt_df, fq_nog_df, fq_gmm_df = load_sim_fish(basedir, 100, 'foci', 'strong', i)
        gt_px = nm_to_px(gt_df, px_size_zyx=[300,100,100])
        fq_px = nm_to_px(fq_gmm_df, px_size_zyx=[300,100,100])
        with torch.no_grad():
            res_dict = model(img[None].cuda())
            # NOTE(review): the other evaluation cell shifts by [-100,-100,-300]; confirm the last entry here.
            dec_df = shift_df(post_proc(res_dict, 'df'), [-100,-100,-100])
        free_mem()
        dec_px = nm_to_px(dec_df, px_size_zyx=[300,100,100])
        try:
            coords_xyz = get_peaks_3d(img[0], threshold=2000, min_distance=10)
        except AssertionError:
            # Peak detection asserted on this image; the image is skipped for all models alike.
            continue
        coords_zyx = coords_xyz[:,::-1]
        # Sigmoid of the full logit volume is the same for every peak, so compute it once per image.
        probs = torch.sigmoid(res_dict['logits'])[0]
        for c in coords_zyx:
            # Clamp the lower bounds at 0: a negative slice start would wrap around
            # and silently select the wrong region for peaks close to the border.
            lo = [max(0, int(v) - box_sz) for v in c]
            hi = [int(v) + box_sz + 1 for v in c]
            sl = np.s_[:, lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]]
            gt_crop = crop_df(gt_px, sl)
            fq_crop = crop_df(fq_px, sl)
            dec_crop = crop_df(dec_px, sl)
            prob_crop = probs[sl]
            gt_counts.append(len(gt_crop))
            fq_counts.append(len(fq_crop))
            dec_counts.append(len(dec_crop))
            prob_sums.append(prob_crop.sum().item())
    dec_count_col.append(dec_counts)
    prob_sum_col.append(prob_sums)
# Scatter ground-truth counts per box against FISH-quant and model estimates,
# with the Pearson correlation shown in each legend entry.
plt.figure(figsize=(15,8))
import seaborn as sns
# Left panel: summed detection probabilities of every model.
plt.subplot(121)
corr = np.round(np.corrcoef([gt_counts, fq_counts])[0,1],3)
plt.scatter(gt_counts, fq_counts, label=f'FISH-quant. Corr: {corr}')
for i,c in enumerate(prob_sum_col):
    corr = np.round(np.corrcoef([gt_counts, c])[0,1],3)
    plt.scatter(gt_counts, c, label=f'DEC {names[i]} probs. Corr: {corr}', alpha=1.0)
plt.plot([0,100],[0,100], 'red')
plt.xlabel('Ground truth N mRNA')
plt.ylabel('Predicted N mRNA')
plt.xlim(0,40)
plt.ylim(0,40)
plt.legend()
sns.despine()
# Right panel: discrete localization counts of the model at index 2 only.
plt.subplot(122)
corr = np.round(np.corrcoef([gt_counts, fq_counts])[0,1],3)
plt.scatter(gt_counts, fq_counts, label=f'FISH-quant. Corr: {corr}')
# start=2 keeps `i` aligned with the slice, so the legend names the model
# that is actually plotted (previously it always showed names[0]).
for i,c in enumerate(dec_count_col[2:3], start=2):
    corr = np.round(np.corrcoef([gt_counts, c])[0,1],3)
    plt.scatter(gt_counts, c, label=f'DEC {names[i]} counts. Corr: {corr}', alpha=1.0)
plt.plot([0,100],[0,100], 'red')
plt.xlabel('Ground truth N mRNA')
plt.ylabel('Predicted N mRNA')
plt.xlim(0,40)
plt.ylim(0,40)
plt.legend()
sns.despine()
# print('Corr. DECODE: ', np.corrcoef([gt_counts, dec_counts])[0,1])
# Same scatter plots without axis limits or correlation values, for all models.
plt.figure(figsize=(15,8))
import seaborn as sns
# Left panel: summed detection probabilities per box.
plt.subplot(121)
plt.scatter(gt_counts, fq_counts, label='FISH-quant')
# Labels come from `names`, the list filled alongside prob_sum_col / dec_count_col;
# the previously used `mults` is not defined anywhere in the visible file.
for i,c in enumerate(prob_sum_col):
    plt.scatter(gt_counts, c, label=f'DEC {names[i]} probs')
plt.plot([0,100],[0,100], 'red')
plt.xlabel('Ground truth N mRNA')
plt.ylabel('Predicted N mRNA')
plt.legend()
sns.despine()
# Right panel: discrete localization counts per box.
plt.subplot(122)
plt.scatter(gt_counts, fq_counts, label='FISH-quant')
for i,c in enumerate(dec_count_col):
    plt.scatter(gt_counts, c, label=f'DEC {names[i]} counts')
plt.plot([0,100],[0,100], 'red')
plt.xlabel('Ground truth N mRNA')
plt.ylabel('Predicted N mRNA')
plt.legend()
sns.despine()
# print('Corr. FQ: ', np.corrcoef([gt_counts, fq_counts])[0,1])
# print('Corr. DECODE: ', np.corrcoef([gt_counts, dec_counts])[0,1])
# IPython shell escape: runs the `nbdev_build_lib` command-line tool.
!nbdev_build_lib