Definitions of the classes and modules used to build our 3D UNet.
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *
# class InferenceNetwork(nn.Module):
# def __init__(self, ch_in: int =1, ch_out: int=10, final_sigmoid : bool =False, depth: int =3, inp_scale: float=1., inp_offset: float=0., order='bcr', f_maps=64, p_offset=-5.,
# int_conc=5, int_rate=1, int_loc=1):
# super().__init__()
# self.unet = UNet3D(ch_in, ch_out, final_sigmoid=final_sigmoid, num_levels=depth,
# layer_order = order, inp_scale=inp_scale, inp_offset=inp_offset, f_maps=f_maps)
# self.p_offset = p_offset
# self.int_dist = IntensityDist(int_conc, int_rate, int_loc)
# self.p_out1 = nn.Conv3d(f_maps, f_maps, kernel_size=3, padding=1)
# self.p_out2 = nn.Conv3d(f_maps, 1, kernel_size=1, padding=0)
# nn.init.constant_(self.p_out2.bias,p_offset)
# self.xyzi_out1 = nn.Conv3d(f_maps, f_maps, kernel_size=3, padding=1)
# self.xyzi_out2 = nn.Conv3d(f_maps, 4, kernel_size=1, padding=0)
# self.xyzis_out1 = nn.Conv3d(f_maps, f_maps, kernel_size=3, padding=1)
# self.xyzis_out2 = nn.Conv3d(f_maps, 4, kernel_size=1, padding=0)
# self.bg_out1 = nn.Conv3d(f_maps, f_maps, kernel_size=3, padding=1)
# self.bg_out2 = nn.Conv3d(f_maps, 1, kernel_size=1, padding=0)
# nn.init.kaiming_normal_(self.p_out1.weight, mode='fan_in', nonlinearity='relu')
# nn.init.kaiming_normal_(self.p_out2.weight, mode='fan_in', nonlinearity='linear')
# nn.init.kaiming_normal_(self.xyzi_out1.weight, mode='fan_in', nonlinearity='relu')
# nn.init.kaiming_normal_(self.xyzi_out2.weight, mode='fan_in', nonlinearity='linear')
# nn.init.kaiming_normal_(self.xyzis_out1.weight, mode='fan_in', nonlinearity='relu')
# nn.init.kaiming_normal_(self.xyzis_out2.weight, mode='fan_in', nonlinearity='linear')
# nn.init.kaiming_normal_(self.bg_out1.weight, mode='fan_in', nonlinearity='relu')
# nn.init.kaiming_normal_(self.bg_out2.weight, mode='fan_in', nonlinearity='linear')
# def forward(self, x):
# out = self.unet(x)
# logit = F.elu(self.p_out1(out))
# logit = self.p_out2(logit)
# logit = torch.clamp(logit, -15., 15)
# xyzi = F.elu(self.xyzi_out1(out))
# xyzi = self.xyzi_out2(xyzi)
# xyz_mu = torch.tanh(xyzi[:, :3])
# i_mu = F.softplus(xyzi[:, 3:]) + self.int_dist.int_loc.detach() + 0.01
# xyzi_mu = torch.cat((xyz_mu, i_mu), dim=1)
# xyzis = F.elu(self.xyzis_out1(out))
# xyzis = self.xyzis_out2(xyzis)
# xyzi_sig = F.softplus(xyzis) + 0.01
# background = F.elu(self.bg_out1(out))
# background = self.bg_out2(background)
# background = self.unet.inp_scale * F.softplus(background)
# return torch.cat([logit,xyzi_mu,xyzi_sig,background],1)
# def tensor_to_dict(self, x):
# return {'logits': x[:,0:1],
# 'xyzi_mu': x[:,1:5],
# 'xyzi_sigma': x[:,5:9],
# 'background': x[:,9:10]}
# output = model.tensor_to_dict(model(torch.randn([10,1,20,20,20])))
# for k in output.keys():
# print(k, output[k].shape)
# Smoke test: build a small UnetDecodeNoBn (order='ce' -- presumably conv + ELU
# with no batch norm, per the class name; TODO confirm), push one random input
# volume through it and report the shape of every named output head plus the
# total parameter count.
model = UnetDecodeNoBn(order='ce', f_maps=32)
# Only output shapes are inspected, so skip building the autograd graph; this
# saves memory on a 3D forward pass without changing any shapes.
# Input is (2, 1, 37, 48, 48) -- assumed (batch, channel, z, y, x); verify.
with torch.no_grad():
    output = model.tensor_to_dict(model(torch.randn([2, 1, 37, 48, 48])))
for k, v in output.items():
    print(k, v.shape)
# Print explicitly: a bare expression is only shown as a notebook cell result
# and is silently discarded when this file is executed as a script.
print(sum(p.numel() for p in model.parameters()))

# Instantiate the model described by the default config, overriding int_loc,
# inp_scale and inp_offset, and count all of its parameters.
cfg = OmegaConf.load(default_conf)
model = hydra.utils.instantiate(cfg.model, int_loc=1, inp_scale=1, inp_offset=0)
pytorch_total_params = sum(p.numel() for p in model.parameters())