%load_ext autoreload
%autoreload 2
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # pin this run to GPU 0

eval_logger[source]

eval_logger(pred_df, target_df, iteration, data_str='Sim. ')

load_from_eval_dict[source]

load_from_eval_dict(eval_dict)

save_train_state[source]

save_train_state(save_dir, model, microscope, optim_net, psf, optim_psf)
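The body of save_train_state is not shown here, but the loading cell further down reads back files named microscope.pkl, opt_net.pkl, opt_psf.pkl and psf.pkl, so a plausible sketch looks like the following (file names inferred from that cell; model.pkl is an assumption):

from pathlib import Path
import torch

def save_train_state(save_dir, model, microscope, optim_net, psf, optim_psf):
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)
    torch.save(model.state_dict(),      save_dir / 'model.pkl')       # assumed name
    torch.save(microscope.state_dict(), save_dir / 'microscope.pkl')
    torch.save(optim_net.state_dict(),  save_dir / 'opt_net.pkl')
    torch.save(psf.state_dict(),        save_dir / 'psf.pkl')
    torch.save(optim_psf.state_dict(),  save_dir / 'opt_psf.pkl')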

train[source]

train(cfg, model, dl, optim_net, optim_psf, optim_mic, sched_net, sched_psf, sched_mic, microscope, psf, post_proc, eval_dict=None)

Training loop for autoencoder learning. Alternates between a simulator training step to train the inference network
and an autoencoder step to train the PSF (and microscope) parameters.

Args:
    model (torch.nn.Module): DECODE 3D UNet.
    num_iter (int): Total number of training iterations (batches), including the initial iterations of pure simulator-based learning.
    dl (torch.utils.data.dataloader.DataLoader): Dataloader that returns a random sub-volume of the real volume, an estimated emitter density and a background estimate.
    optim_net (torch.optim.Optimizer): Optimizer for the network parameters.
    optim_psf (torch.optim.Optimizer): Optimizer for the PSF parameters.
    sched_net (torch.optim.lr_scheduler): LR scheduler for the network parameters.
    sched_psf (torch.optim.lr_scheduler): LR scheduler for the PSF parameters.
    min_int (float): Minimal fraction of the max intensity used when sampling emitters.
    microscope (torch.nn.Module): Microscope class that transforms emitter locations into simulated images.
    log_interval (int): Number of iterations between performance evaluations.
    save_dir (str, PosixPath): Output path where the trained model is stored.
    log_dir (str, PosixPath, optional): Output path where log files for TensorBoard are stored.
    psf (torch.nn.Module): Parametric PSF.
    bl_loss_scale (float): The background loss is scaled by this factor when added to the GMM loss.
    grad_clip (float): Gradient clipping threshold.
    eval_dict (dict, optional): Dictionary with evaluation parameters.
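The alternation described in the docstring can be condensed into a minimal runnable skeleton. This is a schematic with dummy parameters and losses, not the actual implementation; num_iters and start_psf mirror the cfg.training fields set below, and the loop structure follows the traceback at the end of this section.

import torch

net_param = torch.nn.Parameter(torch.zeros(1))   # stands in for the UNet weights
psf_param = torch.nn.Parameter(torch.zeros(1))   # stands in for the PSF parameters
optim_net = torch.optim.Adam([net_param], lr=1e-3)
optim_psf = torch.optim.Adam([psf_param], lr=1e-3)
num_iters, start_psf = 100, 50                   # mirrors cfg.training.num_iters / start_psf

for batch_idx in range(num_iters):
    # 1) Simulator step: train the inference network on data simulated with the
    #    current PSF (the real loss is the GMM loss plus scaled background loss).
    optim_net.zero_grad()
    loss_net = (net_param ** 2).sum()            # dummy stand-in loss
    loss_net.backward()
    optim_net.step()

    # 2) Autoencoder step, enabled after `start_psf` iterations: re-render the
    #    network's predictions and fit the PSF (and microscope) parameters.
    if batch_idx >= start_psf:
        optim_psf.zero_grad()
        loss_ae = ((psf_param - 1.0) ** 2).sum() # dummy reconstruction loss
        loss_ae.backward()
        optim_psf.step()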
cfg = OmegaConf.load('../config/experiment/i_psf_max_norm_foci.yaml')
# cfg.run_name = 'rab11_nb'
# cfg.training.start_micro = 1000
# cfg.training.start_psf = 1000
# cfg.microscope.int_conc = 2.

psf, noise, micro = load_psf_noise_micro(cfg)
post_proc = hydra.utils.instantiate(cfg.post_proc_isi)
cfg.training.start_micro = 0
cfg.training.start_psf = 0

# cfg.data_path.model_init = '/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/sweep_ff21/int_conc:2.0/sl_save'
# cfg.data_path.model_init = cfg.output.save_dir + '/sl_save'
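hydra.utils.instantiate, used throughout this notebook, builds an object from a config node whose _target_ field names a class; the remaining fields become constructor kwargs, and extra keyword arguments are merged in. A self-contained illustration (the config node below is hypothetical, not the actual post_proc_isi entry):

from omegaconf import OmegaConf
import hydra
import torch

node = OmegaConf.create({'_target_': 'torch.optim.SGD', 'lr': 0.1})
# fields of the node become constructor kwargs; extra kwargs are merged in
opt = hydra.utils.instantiate(node, params=[torch.nn.Parameter(torch.zeros(1))])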
img_3d, decode_dl = get_dataloader(cfg)
20 volumes
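A quick sanity check on what the dataloader yields; the three-tuple unpacking matches the dataset's __getitem__ visible in the traceback at the end of this section (the exact shapes depend on the config):

x, local_rate, background = next(iter(decode_dl))
print(x.shape, local_rate.shape, background.shape)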
 
# get_simulation_statistics(decode_dl, micro, int_threshold=1500)
inp_offset, inp_scale = get_forward_scaling(img_3d[0])
model = hydra.utils.instantiate(cfg.model, inp_scale=float(inp_scale), inp_offset=float(inp_offset))
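get_forward_scaling is defined elsewhere; judging by how its outputs are used (inp_offset and inp_scale normalize the network input), a rough stand-in could look like the sketch below. This is an assumption about its intent, not its actual implementation:

import torch

def forward_scaling_sketch(vol):
    # an offset/scale pair so that (vol - offset) / scale spans roughly [0, 1]
    offset = vol.min()
    scale = vol.max() - vol.min()
    return offset, scale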

psf  .to('cuda')
model.to('cuda')
micro.to('cuda')
Microscope(
  (noise): sCMOS()
)
opt_net = hydra.utils.instantiate(cfg.training.net.opt, params=model.unet.parameters())
opt_psf = hydra.utils.instantiate(cfg.training.psf.opt, params=list(psf.parameters()))
opt_mic = hydra.utils.instantiate(cfg.training.micro.opt, params=list(model.int_dist.parameters())[:3])

scheduler_net = hydra.utils.instantiate(cfg.training.net.sched, optimizer=opt_net)
scheduler_psf = hydra.utils.instantiate(cfg.training.psf.sched, optimizer=opt_psf)
scheduler_mic = hydra.utils.instantiate(cfg.training.micro.sched, optimizer=opt_mic)
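Note the [:3] in the microscope optimizer above: passing an explicit parameter list restricts an optimizer to a subset of a module's parameters (here, presumably the first three parameters of the intensity distribution). A generic illustration of the pattern (module and slice are hypothetical):

import torch

lin = torch.nn.Linear(4, 4)
# optimize only the weight, leave the bias untouched
opt_sub = torch.optim.Adam(list(lin.parameters())[:1], lr=1e-3)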

if cfg.evaluation is not None:
    eval_dict = dict(cfg.evaluation)
    eval_dict['crop_sl'] = eval(eval_dict['crop_sl'], {'__builtins__': None}, {'s_': np.s_})
    eval_dict['px_size_zyx'] = list(eval_dict['px_size_zyx'])
else:
    eval_dict = None
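The restricted eval above turns a string from the YAML config into a numpy slice object while exposing only np.s_ and no builtins. For example, with a hypothetical crop_sl value:

import numpy as np

crop_str = 's_[:, 4:-4, 4:-4, 4:-4]'   # hypothetical crop_sl value
crop = eval(crop_str, {'__builtins__': None}, {'s_': np.s_})
np.zeros((10, 48, 48, 48))[crop].shape  # -> (10, 40, 40, 40)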
    
save_dir = Path(cfg.output.save_dir)
save_dir.mkdir(exist_ok=True, parents=True)
if cfg.data_path.model_init is not None:
    print('Loading initial state from', cfg.data_path.model_init)
    model = load_model_state(model, cfg.data_path.model_init).cuda()
    micro.load_state_dict(torch.load(Path(cfg.data_path.model_init)/'microscope.pkl'))
    opt_net.load_state_dict(torch.load(Path(cfg.data_path.model_init)/'opt_net.pkl'))
#     opt_psf.load_state_dict(torch.load(Path(cfg.data_path.model_init)/'opt_psf.pkl'))
#     psf.load_state_dict(torch.load(Path(cfg.data_path.model_init)/'psf.pkl'))
# locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d, output_shape = point_process.sample()
# plot_3d_projections(xsim[0,0])
_ = wandb.init(project=cfg.output.project,
               config=OmegaConf.to_container(cfg, resolve=True),
               dir=cfg.output.log_dir,
               group=cfg.output.group,
               name=cfg.run_name)
Tracking run with wandb version 0.10.23
Syncing run i_psf_max_norm_foci to Weights & Biases (Documentation).
Project page: https://wandb.ai/aspeiser/Fig_sim_density
Run page: https://wandb.ai/aspeiser/Fig_sim_density/runs/rdmdeawm
Run data is saved locally in /groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/decode_fish/runs/wandb/run-20210517_035146-rdmdeawm
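Inside train, metrics are then presumably reported to this run via the standard wandb logging call, along these lines (the metric names are hypothetical):

wandb.log({'train/gmm_loss': 0.0, 'train/bg_loss': 0.0}, step=0)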

train(cfg=cfg,
     model=model, 
     dl=decode_dl, 
     optim_net=opt_net, 
     optim_psf=opt_psf, 
     optim_mic=opt_mic, 
     sched_net=scheduler_net, 
     sched_psf=scheduler_psf, 
     sched_mic=scheduler_mic, 
     psf=psf,
     post_proc=post_proc,
     microscope=micro, 
     eval_dict=eval_dict)
torch.Size([2, 10, 48, 48, 48])
0
torch.Size([1, 10, 69, 100, 81])
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-28-8f5666bd41d3> in <module>
     11      post_proc=post_proc,
     12      microscope=micro,
---> 13      eval_dict=eval_dict)

<ipython-input-25-2430012adf23> in train(cfg, model, dl, optim_net, optim_psf, optim_mic, sched_net, sched_psf, sched_mic, microscope, psf, post_proc, eval_dict)
     49     for batch_idx in range(cfg.training.num_iters):
     50 
---> 51         x, local_rate, background = next(iter(dl))
     52 
     53         optim_net.zero_grad()

~/anaconda3/envs/decode2_dev/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    433         if self._sampler_iter is None:
    434             self._reset()
--> 435         data = self._next_data()
    436         self._num_yielded += 1
    437         if self._dataset_kind == _DatasetKind.Iterable and \

~/anaconda3/envs/decode2_dev/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    473     def _next_data(self):
    474         index = self._next_index()  # may raise StopIteration
--> 475         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    476         if self._pin_memory:
    477             data = _utils.pin_memory.pin_memory(data)

~/anaconda3/envs/decode2_dev/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~/anaconda3/envs/decode2_dev/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~/Dropbox (mackelab)/Artur/WorkDB/deepstorm/decode_fish/decode_fish/funcs/dataset.py in __getitem__(self, _)
     52         x = self.volumes[i]
     53         x = self._compose(x, self.dataset_tfms)
---> 54         local_rate = self._compose(x, self.rate_tfms, ind = i)
     55         background = self.bg_transform(x)
     56         return x.to(self.device), local_rate.to(self.device), background.to(self.device)

~/Dropbox (mackelab)/Artur/WorkDB/deepstorm/decode_fish/decode_fish/funcs/dataset.py in _compose(x, list_func, **kwargs)
     68         if not list_func: list_func.append(lambda x: x)
     69         for func in list_func:
---> 70             x = func(x, **kwargs)
     71         return x
     72 

~/Dropbox (mackelab)/Artur/WorkDB/deepstorm/decode_fish/decode_fish/funcs/dataset.py in __call__(self, x, **kwargs)
    241 
    242         prob = self.n_foci_avg/torch.numel(x[0])
--> 243         locations = torch.distributions.Bernoulli(torch.ones_like(x)*prob).sample()
    244         xwf = x + 0
    245 

~/anaconda3/envs/decode2_dev/lib/python3.7/site-packages/torch/distributions/bernoulli.py in sample(self, sample_shape)
     87         shape = self._extended_shape(sample_shape)
     88         with torch.no_grad():
---> 89             return torch.bernoulli(self.probs.expand(shape))
     90 
     91     def log_prob(self, value):

KeyboardInterrupt: 
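After a manual interrupt like this, the current state can still be checkpointed with save_train_state (documented above), using the objects already in scope:

save_train_state(save_dir, model, micro, opt_net, psf, opt_psf)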
!nbdev_build_lib
Converted 00_models.ipynb.
Converted 01_psf.ipynb.
Converted 02_microscope.ipynb.
Converted 03_noise.ipynb.
Converted 04_pointsource.ipynb.
Converted 05_gmm_loss.ipynb.
Converted 06_plotting.ipynb.
Converted 07_file_io.ipynb.
Converted 08_dataset.ipynb.
Converted 09_output_trafo.ipynb.
Converted 10_evaluation.ipynb.
Converted 11_emitter_io.ipynb.
Converted 12_utils.ipynb.
Converted 13_train.ipynb.
Converted 15_fit_psf.ipynb.
Converted 16_visualization.ipynb.
Converted 17_eval_routines.ipynb.
Converted 18_predict_funcs.ipynb.
Converted index.ipynb.