# IPython-only setup: load the autoreload extension so edits to imported project
# modules are picked up without restarting the kernel (`%autoreload 2` reloads
# modules before executing code; see the IPython autoreload docs).
%load_ext autoreload
%autoreload 2
# Restrict this process to GPU 0.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
# initialised yet in this kernel (e.g. by an earlier run of this cell) — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
# Load the experiment configuration from YAML.
cfg = OmegaConf.load('../config/experiment/i_psf_max_norm_foci.yaml')
# Optional overrides kept for experimentation:
# cfg.run_name = 'rab11_nb'
# cfg.training.start_micro = 1000
# cfg.training.start_psf = 1000
# cfg.microscope.int_conc = 2.
# Build the PSF, noise and microscope objects from the config
# (`noise` is not used anywhere further in this script).
psf, noise, micro = load_psf_noise_micro(cfg)
post_proc = hydra.utils.instantiate(cfg.post_proc_isi)
# Override the start steps to 0 — presumably so PSF/microscope optimisation begins
# immediately; verify against `train`. These assignments come after the objects
# above were built and before `get_dataloader(cfg)`, so keep this ordering if any
# of those helpers read `cfg.training.start_*`.
cfg.training.start_micro = 0
cfg.training.start_psf = 0
# Alternative warm-start checkpoints (consumed by the `model_init` branch below):
# cfg.data_path.model_init = '/groups/turaga/home/speisera/Mackebox/Artur/WorkDB/deepstorm/models/fishcod/Fig_sim_density/sweep_ff21/int_conc:2.0/sl_save'
# cfg.data_path.model_init = cfg.output.save_dir + '/sl_save'
img_3d, decode_dl = get_dataloader(cfg)
# get_simulation_statistics(decode_dl, micro, int_threshold=1500)
# Compute the network's input offset/scale from the first element of `img_3d`
# and pass them to the model as plain Python floats.
inp_offset, inp_scale = get_forward_scaling(img_3d[0])
model = hydra.utils.instantiate(cfg.model, inp_scale=float(inp_scale), inp_offset=float(inp_offset))
# Move the PSF, network and microscope to the GPU. For torch nn.Module objects
# `.to()` moves parameters in place, so the return values are not needed.
psf.to('cuda')
model.to('cuda')
micro.to('cuda')

# Separate optimizers for the U-Net weights, the PSF parameters and a subset of
# the model's `int_dist` parameters.
opt_net = hydra.utils.instantiate(cfg.training.net.opt, params=model.unet.parameters())
opt_psf = hydra.utils.instantiate(cfg.training.psf.opt, params=list(psf.parameters()))
# NOTE(review): only the first three `int_dist` parameters (in registration order)
# are optimised — confirm which parameters these are.
opt_mic = hydra.utils.instantiate(cfg.training.micro.opt, params=list(model.int_dist.parameters())[:3])

scheduler_net = hydra.utils.instantiate(cfg.training.net.sched, optimizer=opt_net)
scheduler_psf = hydra.utils.instantiate(cfg.training.psf.sched, optimizer=opt_psf)
# The microscope scheduler used to be built from `cfg.training.psf.sched` while its
# optimizer reads `cfg.training.micro.opt` (copy-paste slip). Use the microscope's
# own schedule when the config defines one, otherwise keep the PSF schedule so
# configs without `training.micro.sched` behave exactly as before.
mic_sched_cfg = cfg.training.micro.sched if 'sched' in cfg.training.micro else cfg.training.psf.sched
scheduler_mic = hydra.utils.instantiate(mic_sched_cfg, optimizer=opt_mic)
# Build the evaluation settings passed to `train`; None when no evaluation is configured.
if cfg.evaluation is not None:
    eval_dict = dict(cfg.evaluation)
    # `crop_sl` is stored in the config as a string expression; evaluate it with
    # builtins disabled and only `s_` (numpy's `np.s_`) in scope to obtain the slice.
    # NOTE(review): eval() without builtins is not a sandbox — only load configs
    # from trusted sources.
    eval_dict['crop_sl'] = eval(eval_dict['crop_sl'], {'__builtins__': None}, {'s_': np.s_})
    # Convert the config node (presumably an OmegaConf ListConfig) to a plain list.
    eval_dict['px_size_zyx'] = list(eval_dict['px_size_zyx'])
else:
    eval_dict = None
# Make sure the output directory exists.
save_dir = Path(cfg.output.save_dir)
save_dir.mkdir(exist_ok=True, parents=True)

# Optionally warm-start the network, microscope and network optimizer from a
# previous run's checkpoint directory.
if cfg.data_path.model_init is not None:
    print('loading')
    init_dir = Path(cfg.data_path.model_init)
    # NOTE(review): `opt_net` was created from `model.unet.parameters()` before this
    # reload; this assumes `load_model_state` loads weights into the same module
    # (rather than building a new one) so `opt_net` still tracks the live
    # parameters — confirm.
    model = load_model_state(model, cfg.data_path.model_init).cuda()
    micro.load_state_dict(torch.load(init_dir / 'microscope.pkl'))
    opt_net.load_state_dict(torch.load(init_dir / 'opt_net.pkl'))
    # opt_psf.load_state_dict(torch.load(init_dir / 'opt_psf.pkl'))
    # psf.load_state_dict(torch.load(init_dir / 'psf.pkl'))
# locs_3d, x_os_3d, y_os_3d, z_os_3d, ints_3d, output_shape = point_process.sample()
# plot_3d_projections(xsim[0,0])
# Open a Weights & Biases run that records the fully resolved config.
wandb_init_kwargs = dict(
    project=cfg.output.project,
    config=OmegaConf.to_container(cfg, resolve=True),
    dir=cfg.output.log_dir,
    group=cfg.output.group,
    name=cfg.run_name,
)
_ = wandb.init(**wandb_init_kwargs)

# Run training with separate optimizers/schedulers for the network, the PSF and
# the microscope, plus the post-processing step and optional evaluation settings.
train_kwargs = dict(
    cfg=cfg,
    model=model,
    dl=decode_dl,
    optim_net=opt_net,
    optim_psf=opt_psf,
    optim_mic=opt_mic,
    sched_net=scheduler_net,
    sched_psf=scheduler_psf,
    sched_mic=scheduler_mic,
    psf=psf,
    post_proc=post_proc,
    microscope=micro,
    eval_dict=eval_dict,
)
train(**train_kwargs)
# IPython shell escape, reached only after `train` above returns: presumably
# regenerates the library modules from the notebooks (nbdev v1 `nbdev_build_lib`)
# — confirm this is still wanted at the end of a training run.
!nbdev_build_lib