NetworkView

flyvision.network.directories.NetworkDir

Bases: Directory

Directory for a network.

Attributes written by the solver:

loss (ArrayFile): Loss values over iterations.
activity (ArrayFile): Mean activity values over iterations.
activity_min (ArrayFile): Minimum activity values over iterations.
activity_max (ArrayFile): Maximum activity values over iterations.
loss_<task> (ArrayFile): Loss values for each task over iterations.
chkpt_index (ArrayFile): Numerical identifiers of the checkpoints.
chkpt_iter (ArrayFile): Iterations at which checkpoints were recorded.
best_chkpt_index (ArrayFile): Index of the checkpoint with minimal validation loss.
dt (ArrayFile): Current integration time step of the dataset.
time_trained (ArrayFile): Total time spent training.

Attributes written by NetworkView:

__cache__ (Directory): joblib cache.

Source code in flyvision/network/directories.py
@root(flyvision.results_dir)
class NetworkDir(Directory):
    """Directory for a network.

    Attributes: Written to by the solver.
        loss (ArrayFile): Loss values over iterations.
        activity (ArrayFile): Mean activity values over iterations.
        activity_min (ArrayFile): Minimum activity values over iterations.
        activity_max (ArrayFile): Maximum activity values over iterations.
        loss_<task> (ArrayFile): Loss values for each specific task over iterations.
        chkpt_index (ArrayFile): Numerical identifiers of checkpoints.
        chkpt_iter (ArrayFile): Iterations at which checkpoints were recorded.
        best_chkpt_index (ArrayFile): Checkpoint index with minimal validation loss.
        dt (ArrayFile): Current integration time step of the dataset.
        time_trained (ArrayFile): Total time spent training.

    Attributes: Written by NetworkView.
        __cache__ (Directory): joblib cache.
    """

flyvision.network.network_view.NetworkView

IO interface for a network.

Parameters:

network_dir (Union[str, PathLike, NetworkDir]): Directory of the network. Required.
network_class (Module): Network class. Defaults to Network.
root_dir (PathLike): Root directory. Defaults to flyvision.results_dir.
connectome_getter (Callable): Function to get the connectome. Defaults to get_avgfilt_connectome.
checkpoint_mapper (Callable): Function to map checkpoints. Defaults to resolve_checkpoints.
best_checkpoint_fn (Callable): Function to get the best checkpoint. Defaults to best_checkpoint_default_fn.
best_checkpoint_fn_kwargs (dict): Keyword arguments for best_checkpoint_fn. Defaults to {'validation_subdir': 'validation', 'loss_file_name': 'loss'}.
recover_fn (Callable): Function to recover the network. Defaults to recover_network.

Attributes:

network_class (Module): Network class.
dir (Directory): Network directory.
name (str): Network name.
root_dir (PathLike): Root directory.
connectome_getter (Callable): Function to get the connectome.
checkpoint_mapper (Callable): Function to map checkpoints.
connectome_view (ConnectomeView): Connectome view.
connectome (Directory): Connectome directory.
checkpoints: Mapped checkpoints.
memory (Memory): Joblib memory cache.
best_checkpoint_fn (Callable): Function to get the best checkpoint.
best_checkpoint_fn_kwargs (dict): Keyword arguments for best_checkpoint_fn.
recover_fn (Callable): Function to recover the network.
_network (CheckpointedNetwork): Checkpointed network instance.
decoder: Decoder instance.
_initialized (dict): Initialization status for network and decoder.
cache (FIFOCache): Cache for storing results.
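The name attribute is set by _resolve_dir from the last three components of the network directory path (os.path.sep.join(network_dir.path.parts[-3:]) in the source below). A minimal sketch of that rule, using PurePosixPath and "/" so the result does not depend on the operating system; the example path is hypothetical.

```python
from pathlib import PurePosixPath


def network_name(network_path: str, depth: int = 3) -> str:
    """Join the last `depth` path components, mirroring how NetworkView.name is built."""
    return "/".join(PurePosixPath(network_path).parts[-depth:])


print(network_name("/home/user/results/flow/0000/000"))  # flow/0000/000
```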

Source code in flyvision/network/network_view.py
class NetworkView:
    """IO interface for network.

    Args:
        network_dir: Directory of the network.
        network_class: Network class. Defaults to Network.
        root_dir: Root directory. Defaults to flyvision.results_dir.
        connectome_getter: Function to get the connectome.
            Defaults to get_avgfilt_connectome.
        checkpoint_mapper: Function to map checkpoints. Defaults to resolve_checkpoints.
        best_checkpoint_fn: Function to get the best checkpoint. Defaults to
            best_checkpoint_default_fn.
        best_checkpoint_fn_kwargs: Keyword arguments for best_checkpoint_fn. Defaults to
            {"validation_subdir": "validation", "loss_file_name": "loss"}.
        recover_fn: Function to recover the network. Defaults to recover_network.

    Attributes:
        network_class (nn.Module): Network class.
        dir (Directory): Network directory.
        name (str): Network name.
        root_dir (PathLike): Root directory.
        connectome_getter (Callable): Function to get the connectome.
        checkpoint_mapper (Callable): Function to map checkpoints.
        connectome_view (ConnectomeView): Connectome view.
        connectome (Directory): Connectome directory.
        checkpoints: Mapped checkpoints.
        memory (Memory): Joblib memory cache.
        best_checkpoint_fn (Callable): Function to get the best checkpoint.
        best_checkpoint_fn_kwargs (dict): Keyword arguments for best_checkpoint_fn.
        recover_fn (Callable): Function to recover the network.
        _network (CheckpointedNetwork): Checkpointed network instance.
        decoder: Decoder instance.
        _initialized (dict): Initialization status for network and decoder.
        cache (FIFOCache): Cache for storing results.
    """

    def __init__(
        self,
        network_dir: Union[str, PathLike, NetworkDir],
        network_class: nn.Module = Network,
        root_dir: PathLike = flyvision.results_dir,
        connectome_getter: Callable = get_avgfilt_connectome,
        checkpoint_mapper: Callable = resolve_checkpoints,
        best_checkpoint_fn: Callable = best_checkpoint_default_fn,
        best_checkpoint_fn_kwargs: dict = {
            "validation_subdir": "validation",
            "loss_file_name": "loss",
        },
        recover_fn: Callable = recover_network,
    ):
        self.network_class = network_class
        self.dir, self.name = self._resolve_dir(network_dir, root_dir)
        self.root_dir = root_dir
        self.connectome_getter = connectome_getter
        self.checkpoint_mapper = checkpoint_mapper
        self.connectome_view: ConnectomeView = connectome_getter(
            self.dir.config.network.connectome
        )
        self.connectome = self.connectome_view.dir
        self.checkpoints = checkpoint_mapper(self.dir)
        self.memory = Memory(
            location=self.dir.path / "__cache__", verbose=0, backend="xarray_dataset_h5"
        )
        self.best_checkpoint_fn = best_checkpoint_fn
        self.best_checkpoint_fn_kwargs = best_checkpoint_fn_kwargs
        self.recover_fn = recover_fn
        self._network = CheckpointedNetwork(
            self.network_class,
            self.dir.config.network.to_dict(),
            self.name,
            self.get_checkpoint("best"),
            self.recover_fn,
            network=None,
        )
        self.decoder = None
        self._initialized = {"network": None, "decoder": None}
        self.cache = FIFOCache(maxsize=3)
        logging.info("Initialized network view at %s", str(self.dir.path))

    def _clear_cache(self):
        """Clear the FIFO cache."""
        self.cache = self.cache.__class__(maxsize=self.cache.maxsize)

    def _clear_memory(self):
        """Clear the joblib memory cache."""
        self.memory.clear()

    # --- ConnectomeView API for static code analysis
    # pylint: disable=missing-function-docstring
    @wraps(ConnectomeView.connectivity_matrix)
    def connectivity_matrix(self, *args, **kwargs):
        return self.connectome_view.connectivity_matrix(*args, **kwargs)

    connectivity_matrix.__doc__ = ConnectomeView.connectivity_matrix.__doc__

    @wraps(ConnectomeView.network_layout)
    def network_layout(self, *args, **kwargs):
        return self.connectome_view.network_layout(*args, **kwargs)

    network_layout.__doc__ = ConnectomeView.network_layout.__doc__

    @wraps(ConnectomeView.hex_layout)
    def hex_layout(self, *args, **kwargs):
        return self.connectome_view.hex_layout(*args, **kwargs)

    hex_layout.__doc__ = ConnectomeView.hex_layout.__doc__

    @wraps(ConnectomeView.hex_layout_all)
    def hex_layout_all(self, *args, **kwargs):
        return self.connectome_view.hex_layout_all(*args, **kwargs)

    hex_layout_all.__doc__ = ConnectomeView.hex_layout_all.__doc__

    @wraps(ConnectomeView.get_uv)
    def get_uv(self, *args, **kwargs):
        return self.connectome_view.get_uv(*args, **kwargs)

    get_uv.__doc__ = ConnectomeView.get_uv.__doc__

    @wraps(ConnectomeView.sources_list)
    def sources_list(self, *args, **kwargs):
        return self.connectome_view.sources_list(*args, **kwargs)

    sources_list.__doc__ = ConnectomeView.sources_list.__doc__

    @wraps(ConnectomeView.targets_list)
    def targets_list(self, *args, **kwargs):
        return self.connectome_view.targets_list(*args, **kwargs)

    targets_list.__doc__ = ConnectomeView.targets_list.__doc__

    @wraps(ConnectomeView.receptive_field)
    def receptive_field(self, *args, **kwargs):
        return self.connectome_view.receptive_field(*args, **kwargs)

    receptive_field.__doc__ = ConnectomeView.receptive_field.__doc__

    @wraps(ConnectomeView.receptive_fields_grid)
    def receptive_fields_grid(self, *args, **kwargs):
        return self.connectome_view.receptive_fields_grid(*args, **kwargs)

    receptive_fields_grid.__doc__ = ConnectomeView.receptive_fields_grid.__doc__

    @wraps(ConnectomeView.projective_field)
    def projective_field(self, *args, **kwargs):
        return self.connectome_view.projective_field(*args, **kwargs)

    projective_field.__doc__ = ConnectomeView.projective_field.__doc__

    @wraps(ConnectomeView.projective_fields_grid)
    def projective_fields_grid(self, *args, **kwargs):
        return self.connectome_view.projective_fields_grid(*args, **kwargs)

    projective_fields_grid.__doc__ = ConnectomeView.projective_fields_grid.__doc__

    @wraps(ConnectomeView.receptive_fields_df)
    def receptive_fields_df(self, *args, **kwargs):
        return self.connectome_view.receptive_fields_df(*args, **kwargs)

    receptive_fields_df.__doc__ = ConnectomeView.receptive_fields_df.__doc__

    @wraps(ConnectomeView.receptive_fields_sum)
    def receptive_fields_sum(self, *args, **kwargs):
        return self.connectome_view.receptive_fields_sum(*args, **kwargs)

    receptive_fields_sum.__doc__ = ConnectomeView.receptive_fields_sum.__doc__

    @wraps(ConnectomeView.projective_fields_df)
    def projective_fields_df(self, *args, **kwargs):
        return self.connectome_view.projective_fields_df(*args, **kwargs)

    projective_fields_df.__doc__ = ConnectomeView.projective_fields_df.__doc__

    @wraps(ConnectomeView.projective_fields_sum)
    def projective_fields_sum(self, *args, **kwargs):
        return self.connectome_view.projective_fields_sum(*args, **kwargs)

    projective_fields_sum.__doc__ = ConnectomeView.projective_fields_sum.__doc__

    # --- own API

    def get_checkpoint(self, checkpoint="best"):
        """Return the best checkpoint index.

        Args:
            checkpoint: Checkpoint identifier. Defaults to "best".

        Returns:
            str: Path to the checkpoint.
        """
        try:
            if checkpoint == "best":
                return self.best_checkpoint_fn(
                    self.dir.path,
                    **self.best_checkpoint_fn_kwargs,
                )
            return self.checkpoints.paths[checkpoint]
        except FileNotFoundError:
            logger.warning("Checkpoint %s not found at %s", checkpoint, self.dir.path)
            return None

    def network(
        self, checkpoint="best", network: Optional[Any] = None, lazy=False
    ) -> CheckpointedNetwork:
        """Lazy loading of network instance.

        Args:
            checkpoint: Checkpoint identifier. Defaults to "best".
            network: Existing network instance to use. Defaults to None.
            lazy: If True, don't recover the network immediately. Defaults to False.

        Returns:
            CheckpointedNetwork: Checkpointed network instance.
        """
        self._network = CheckpointedNetwork(
            self.network_class,
            self.dir.config.network.to_dict(),
            self.name,
            self.get_checkpoint(checkpoint),
            self.recover_fn,
            network=network or self._network.network,
        )
        if self._network.network is not None and not lazy:
            self._network.recover()
        return self._network

    def init_network(self, checkpoint="best", network: Optional[Any] = None) -> Network:
        """Initialize the network.

        Args:
            checkpoint: Checkpoint identifier. Defaults to "best".
            network: Existing network instance to use. Defaults to None.

        Returns:
            Network: Initialized network instance.
        """
        checkpointed_network = self.network(checkpoint=checkpoint, network=network)

        if checkpointed_network.network is not None:
            return checkpointed_network.network
        checkpointed_network.init()
        return checkpointed_network.recover()

    def init_decoder(self, checkpoint="best", decoder=None):
        """Initialize the decoder.

        Args:
            checkpoint: Checkpoint identifier. Defaults to "best".
            decoder: Existing decoder instance to use. Defaults to None.

        Returns:
            Decoder: Initialized decoder instance.
        """
        checkpointed_network = self.network(checkpoint=checkpoint, lazy=True)
        if (
            self._initialized["decoder"] == checkpointed_network.checkpoint
            and decoder is None
        ):
            return self.decoder
        self.decoder = decoder or init_decoder(
            self.dir.config.task.decoder, self.connectome
        )
        recover_decoder(self.decoder, checkpointed_network.checkpoint)
        self._initialized["decoder"] = checkpointed_network.checkpoint
        return self.decoder

    def _resolve_dir(self, network_dir, root_dir):
        """Resolve the network directory.

        Args:
            network_dir: Network directory path or Directory instance.
            root_dir: Root directory path.

        Returns:
            tuple: (Directory, str) - Network directory and name.

        Raises:
            ValueError: If the directory is not a NetworkDir.
        """
        if isinstance(network_dir, (PathLike, str)):
            with set_root_context(root_dir):
                network_dir = Directory(network_dir)
        if not network_dir.config.type == "NetworkDir":
            raise ValueError(f"NetworkDir not found at {network_dir.path}.")
        name = os.path.sep.join(network_dir.path.parts[-3:])
        return network_dir, name

    # --- stimulus responses

    @wraps(stimulus_responses.flash_responses)
    @context_aware_cache
    def flash_responses(self, *args, **kwargs) -> xr.Dataset:
        """Generate flash responses."""
        return stimulus_responses.flash_responses(self, *args, **kwargs)

    @wraps(stimulus_responses.moving_edge_responses)
    @context_aware_cache
    def moving_edge_responses(self, *args, **kwargs) -> xr.Dataset:
        """Generate moving edge responses."""
        return stimulus_responses.moving_edge_responses(self, *args, **kwargs)

    @wraps(stimulus_responses_currents.moving_edge_currents)
    @context_aware_cache
    def moving_edge_currents(
        self, *args, **kwargs
    ) -> List[stimulus_responses_currents.ExperimentData]:
        """Generate moving edge currents."""
        return stimulus_responses_currents.moving_edge_currents(self, *args, **kwargs)

    @wraps(stimulus_responses.moving_bar_responses)
    @context_aware_cache
    def moving_bar_responses(self, *args, **kwargs) -> xr.Dataset:
        """Generate moving bar responses."""
        return stimulus_responses.moving_bar_responses(self, *args, **kwargs)

    @wraps(stimulus_responses.naturalistic_stimuli_responses)
    @context_aware_cache
    def naturalistic_stimuli_responses(self, *args, **kwargs) -> xr.Dataset:
        """Generate naturalistic stimuli responses."""
        return stimulus_responses.naturalistic_stimuli_responses(self, *args, **kwargs)

    @wraps(stimulus_responses.central_impulses_responses)
    @context_aware_cache
    def central_impulses_responses(self, *args, **kwargs) -> xr.Dataset:
        """Generate central ommatidium impulses responses."""
        return stimulus_responses.central_impulses_responses(self, *args, **kwargs)

    @wraps(stimulus_responses.spatial_impulses_responses)
    @context_aware_cache
    def spatial_impulses_responses(self, *args, **kwargs) -> xr.Dataset:
        """Generate spatial ommatidium impulses responses."""
        return stimulus_responses.spatial_impulses_responses(self, *args, **kwargs)

    @wraps(stimulus_responses.optimal_stimulus_responses)
    @context_aware_cache
    def optimal_stimulus_responses(
        self, cell_type, *args, **kwargs
    ) -> optimal_stimuli.RegularizedOptimalStimulus:
        """Generate optimal stimuli responses."""
        return stimulus_responses.optimal_stimulus_responses(
            self, cell_type, *args, **kwargs
        )
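The ConnectomeView block above forwards each call to self.connectome_view and copies the wrapped method's metadata with @wraps. Assuming wraps is functools.wraps, it already copies __doc__ (along with __name__ and __annotations__) and sets __wrapped__, so the explicit __doc__ reassignments are redundant but harmless. A self-contained sketch of the delegation pattern with hypothetical toy classes:

```python
import inspect
from functools import wraps


class Connectome:
    def degree(self, cell_type: str, scale: int = 1) -> int:
        """Return a toy degree for a cell type."""
        return len(cell_type) * scale


class View:
    def __init__(self) -> None:
        self.connectome_view = Connectome()

    # Forward the call. wraps copies __name__, __doc__ and __annotations__ and
    # sets __wrapped__, so inspect.signature reports Connectome.degree's signature.
    @wraps(Connectome.degree)
    def degree(self, *args, **kwargs):
        return self.connectome_view.degree(*args, **kwargs)


view = View()
print(view.degree("T4a", scale=2))  # 6
print(View.degree.__doc__)  # Return a toy degree for a cell type.
print(inspect.signature(View.degree))  # (self, cell_type: str, scale: int = 1) -> int
```

Because inspect.signature follows __wrapped__, IDEs and documentation tools can show the real parameters even though the forwarding method only declares *args and **kwargs.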

get_checkpoint

get_checkpoint(checkpoint='best')

Return the path to the requested checkpoint.

Parameters:

checkpoint: Checkpoint identifier, either "best" or an index into the mapped checkpoint paths. Defaults to "best".

Returns:

str: Path to the checkpoint, or None if it is not found.

Source code in flyvision/network/network_view.py
def get_checkpoint(self, checkpoint="best"):
    """Return the best checkpoint index.

    Args:
        checkpoint: Checkpoint identifier. Defaults to "best".

    Returns:
        str: Path to the checkpoint.
    """
    try:
        if checkpoint == "best":
            return self.best_checkpoint_fn(
                self.dir.path,
                **self.best_checkpoint_fn_kwargs,
            )
        return self.checkpoints.paths[checkpoint]
    except FileNotFoundError:
        logger.warning("Checkpoint %s not found at %s", checkpoint, self.dir.path)
        return None
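get_checkpoint dispatches on the identifier: "best" is delegated to best_checkpoint_fn together with best_checkpoint_fn_kwargs, any other value is looked up in the mapped checkpoint paths, and a FileNotFoundError becomes None. The sketch below reproduces that control flow as a free function; the stand-in selection functions, the plain list of paths, and the chkpt_... file names are hypothetical.

```python
from pathlib import PurePosixPath
from typing import Any, Callable, Dict, Optional, Sequence


def get_checkpoint(
    checkpoint: Any,
    paths: Sequence[PurePosixPath],
    best_fn: Callable[..., PurePosixPath],
    best_fn_kwargs: Dict[str, str],
    network_path: PurePosixPath,
) -> Optional[PurePosixPath]:
    try:
        if checkpoint == "best":
            # Delegate to the configurable selection function.
            return best_fn(network_path, **best_fn_kwargs)
        # Any other identifier indexes the mapped checkpoint paths.
        return paths[checkpoint]
    except FileNotFoundError:
        # Missing files yield None instead of an exception.
        return None


def lowest_loss(network_path, validation_subdir, loss_file_name):
    # Hypothetical stand-in for best_checkpoint_default_fn.
    return network_path / "chkpts" / "chkpt_00001"


def missing(network_path, **kwargs):
    raise FileNotFoundError(network_path)


root = PurePosixPath("/results/flow/0000/000")
paths = [root / "chkpts" / f"chkpt_{i:05d}" for i in range(3)]
kwargs = {"validation_subdir": "validation", "loss_file_name": "loss"}
print(get_checkpoint(2, paths, lowest_loss, kwargs, root))  # /results/flow/0000/000/chkpts/chkpt_00002
print(get_checkpoint("best", paths, missing, kwargs, root))  # None
```

Only FileNotFoundError is caught: in this sketch an out-of-range index raises IndexError, and whether the real checkpoints.paths behaves the same depends on the container that resolve_checkpoints returns.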

network

network(checkpoint='best', network=None, lazy=False)

Lazily load the network instance for a checkpoint.

Parameters:

checkpoint: Checkpoint identifier. Defaults to "best".
network (Optional[Any]): Existing network instance to use. Defaults to None.
lazy: If True, don't recover the network immediately. Defaults to False.

Returns:

CheckpointedNetwork: Checkpointed network instance.

Source code in flyvision/network/network_view.py
def network(
    self, checkpoint="best", network: Optional[Any] = None, lazy=False
) -> CheckpointedNetwork:
    """Lazy loading of network instance.

    Args:
        checkpoint: Checkpoint identifier. Defaults to "best".
        network: Existing network instance to use. Defaults to None.
        lazy: If True, don't recover the network immediately. Defaults to False.

    Returns:
        CheckpointedNetwork: Checkpointed network instance.
    """
    self._network = CheckpointedNetwork(
        self.network_class,
        self.dir.config.network.to_dict(),
        self.name,
        self.get_checkpoint(checkpoint),
        self.recover_fn,
        network=network or self._network.network,
    )
    if self._network.network is not None and not lazy:
        self._network.recover()
    return self._network
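network() rebuilds the CheckpointedNetwork for the requested checkpoint on every call, carries over an already constructed instance, and loads weights only when such an instance exists and lazy is False. A toy sketch of that control flow; StubHandle and StubView are hypothetical and record recoveries instead of loading weights.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple


@dataclass
class StubHandle:
    checkpoint: str
    recover_fn: Callable[[Any, str], Any]
    network: Optional[Any] = None

    def recover(self) -> Any:
        return self.recover_fn(self.network, self.checkpoint)


class StubView:
    def __init__(self) -> None:
        self.recovered: List[Tuple[Any, str]] = []  # (network, checkpoint) log
        self._network = StubHandle("best", self._recover)

    def _recover(self, network: Any, checkpoint: str) -> Any:
        self.recovered.append((network, checkpoint))
        return network

    def network(
        self, checkpoint: str = "best", network: Any = None, lazy: bool = False
    ) -> StubHandle:
        # New handle per call; reuse the attached instance if none is given.
        self._network = StubHandle(
            checkpoint, self._recover, network=network or self._network.network
        )
        # Recover only if an instance exists and eager loading was requested.
        if self._network.network is not None and not lazy:
            self._network.recover()
        return self._network


view = StubView()
view.network()  # no instance yet: nothing is loaded
model = object()
view.network("chkpt_3", network=model, lazy=True)  # attached, not loaded
view.network("chkpt_3")  # reuses model and loads it
print(view.recovered)
```

The `network or self._network.network` expression relies on truthiness; `network if network is not None else ...` would be the stricter test if an instance could ever be falsy (for example, a module that defines __len__ and is empty).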

init_network

init_network(checkpoint='best', network=None)

Initialize the network.

Parameters:

checkpoint: Checkpoint identifier. Defaults to "best".
network (Optional[Any]): Existing network instance to use. Defaults to None.

Returns:

Network: Initialized network instance.

Source code in flyvision/network/network_view.py
def init_network(self, checkpoint="best", network: Optional[Any] = None) -> Network:
    """Initialize the network.

    Args:
        checkpoint: Checkpoint identifier. Defaults to "best".
        network: Existing network instance to use. Defaults to None.

    Returns:
        Network: Initialized network instance.
    """
    checkpointed_network = self.network(checkpoint=checkpoint, network=network)

    if checkpointed_network.network is not None:
        return checkpointed_network.network
    checkpointed_network.init()
    return checkpointed_network.recover()
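init_network has two paths: if network() returned a handle that already holds an instance (which network() has just recovered, since lazy defaults to False), that instance is returned as is; otherwise the handle constructs the network with init() and then loads weights with recover(). A toy sketch with a hypothetical StubHandle that logs its calls:

```python
from typing import Any, Dict, List, Optional


class StubHandle:
    def __init__(self, network: Optional[Dict[str, Any]] = None) -> None:
        self.network = network
        self.calls: List[str] = []

    def init(self) -> Dict[str, Any]:
        self.calls.append("init")
        self.network = {"weights": None}  # freshly constructed, untrained
        return self.network

    def recover(self) -> Dict[str, Any]:
        self.calls.append("recover")
        self.network["weights"] = "loaded"  # stands in for loading a checkpoint
        return self.network


def init_network(handle: StubHandle) -> Dict[str, Any]:
    # Existing instance: return it unchanged.
    if handle.network is not None:
        return handle.network
    # No instance yet: construct it, then load the checkpoint weights.
    handle.init()
    return handle.recover()


fresh = StubHandle()
print(init_network(fresh), fresh.calls)  # {'weights': 'loaded'} ['init', 'recover']
```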

init_decoder

init_decoder(checkpoint='best', decoder=None)

Initialize the decoder.

Parameters:

checkpoint: Checkpoint identifier. Defaults to "best".
decoder: Existing decoder instance to use. Defaults to None.

Returns:

Decoder: Initialized decoder instance.

Source code in flyvision/network/network_view.py
def init_decoder(self, checkpoint="best", decoder=None):
    """Initialize the decoder.

    Args:
        checkpoint: Checkpoint identifier. Defaults to "best".
        decoder: Existing decoder instance to use. Defaults to None.

    Returns:
        Decoder: Initialized decoder instance.
    """
    checkpointed_network = self.network(checkpoint=checkpoint, lazy=True)
    if (
        self._initialized["decoder"] == checkpointed_network.checkpoint
        and decoder is None
    ):
        return self.decoder
    self.decoder = decoder or init_decoder(
        self.dir.config.task.decoder, self.connectome
    )
    recover_decoder(self.decoder, checkpointed_network.checkpoint)
    self._initialized["decoder"] = checkpointed_network.checkpoint
    return self.decoder
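init_decoder memoizes on the checkpoint: it returns the stored decoder when that decoder was last recovered for the same checkpoint and no explicit decoder is passed; otherwise it builds (or accepts) a decoder and recovers it from that checkpoint. The sketch below isolates this rule in a hypothetical DecoderCache class, where build stands in for init_decoder and a counter stands in for recover_decoder.

```python
from typing import Any, Callable, Optional


class DecoderCache:
    def __init__(self, build: Callable[[], Any]) -> None:
        self.build = build
        self.decoder: Optional[Any] = None
        self.initialized_for: Optional[str] = None  # mirrors _initialized["decoder"]
        self.recoveries = 0

    def get(self, checkpoint: Optional[str], decoder: Optional[Any] = None) -> Any:
        # Same checkpoint and no override: reuse the stored decoder.
        if self.initialized_for == checkpoint and decoder is None:
            return self.decoder
        # Otherwise build (or accept) a decoder and "recover" it.
        self.decoder = decoder or self.build()
        self.recoveries += 1  # stands in for recover_decoder(...)
        self.initialized_for = checkpoint
        return self.decoder


cache = DecoderCache(build=object)
first = cache.get("chkpt_1")
again = cache.get("chkpt_1")
other = cache.get("chkpt_2")
print(first is again, first is other, cache.recoveries)  # True False 2
```

As written in the source, _initialized["decoder"] starts as None, so if get_checkpoint returns None (checkpoint not found) on the first call, the comparison succeeds and init_decoder returns self.decoder, which is still None. The sketch reproduces this: DecoderCache(object).get(None) returns None.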

flash_responses

flash_responses(*args, **kwargs)

Generate flash responses.

Source code in flyvision/network/network_view.py
@wraps(stimulus_responses.flash_responses)
@context_aware_cache
def flash_responses(self, *args, **kwargs) -> xr.Dataset:
    """Generate flash responses."""
    return stimulus_responses.flash_responses(self, *args, **kwargs)

moving_edge_responses

moving_edge_responses(*args, **kwargs)

Generate moving edge responses.

Source code in flyvision/network/network_view.py
@wraps(stimulus_responses.moving_edge_responses)
@context_aware_cache
def moving_edge_responses(self, *args, **kwargs) -> xr.Dataset:
    """Generate moving edge responses."""
    return stimulus_responses.moving_edge_responses(self, *args, **kwargs)

moving_edge_currents

moving_edge_currents(*args, **kwargs)

Generate moving edge currents.

Source code in flyvision/network/network_view.py
@wraps(stimulus_responses_currents.moving_edge_currents)
@context_aware_cache
def moving_edge_currents(
    self, *args, **kwargs
) -> List[stimulus_responses_currents.ExperimentData]:
    """Generate moving edge currents."""
    return stimulus_responses_currents.moving_edge_currents(self, *args, **kwargs)

moving_bar_responses

moving_bar_responses(*args, **kwargs)

Generate moving bar responses.

Source code in flyvision/network/network_view.py
@wraps(stimulus_responses.moving_bar_responses)
@context_aware_cache
def moving_bar_responses(self, *args, **kwargs) -> xr.Dataset:
    """Generate moving bar responses."""
    return stimulus_responses.moving_bar_responses(self, *args, **kwargs)

naturalistic_stimuli_responses

naturalistic_stimuli_responses(*args, **kwargs)

Generate naturalistic stimuli responses.

Source code in flyvision/network/network_view.py
@wraps(stimulus_responses.naturalistic_stimuli_responses)
@context_aware_cache
def naturalistic_stimuli_responses(self, *args, **kwargs) -> xr.Dataset:
    """Generate naturalistic stimuli responses."""
    return stimulus_responses.naturalistic_stimuli_responses(self, *args, **kwargs)

central_impulses_responses

central_impulses_responses(*args, **kwargs)

Generate responses to impulses at the central ommatidium.

Source code in flyvision/network/network_view.py
@wraps(stimulus_responses.central_impulses_responses)
@context_aware_cache
def central_impulses_responses(self, *args, **kwargs) -> xr.Dataset:
    """Generate central ommatidium impulses responses."""
    return stimulus_responses.central_impulses_responses(self, *args, **kwargs)

spatial_impulses_responses

spatial_impulses_responses(*args, **kwargs)

Generate responses to spatially distributed ommatidium impulses.

Source code in flyvision/network/network_view.py
@wraps(stimulus_responses.spatial_impulses_responses)
@context_aware_cache
def spatial_impulses_responses(self, *args, **kwargs) -> xr.Dataset:
    """Generate spatial ommatidium impulses responses."""
    return stimulus_responses.spatial_impulses_responses(self, *args, **kwargs)

optimal_stimulus_responses

optimal_stimulus_responses(cell_type, *args, **kwargs)

Generate optimal stimuli responses.

Source code in flyvision/network/network_view.py
@wraps(stimulus_responses.optimal_stimulus_responses)
@context_aware_cache
def optimal_stimulus_responses(
    self, cell_type, *args, **kwargs
) -> optimal_stimuli.RegularizedOptimalStimulus:
    """Generate optimal stimuli responses."""
    return stimulus_responses.optimal_stimulus_responses(
        self, cell_type, *args, **kwargs
    )
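Every stimulus-response method above is decorated with @context_aware_cache, and NetworkView keeps a FIFOCache(maxsize=3) in self.cache. The decorator's exact behavior is not shown on this page, so the sketch below only illustrates an in-memory FIFO layer under that assumption: a hypothetical fifo_cached decorator keyed on method name and arguments, plus a stdlib FIFOCache stand-in that, like cachetools.FIFOCache, evicts the oldest inserted entry when full.

```python
from collections import OrderedDict
from functools import wraps
from typing import Any, Callable, Hashable


class FIFOCache:
    """Minimal FIFO cache: when full, evicts the oldest inserted entry."""

    def __init__(self, maxsize: int) -> None:
        self.maxsize = maxsize
        self._data = OrderedDict()

    def __contains__(self, key: Hashable) -> bool:
        return key in self._data

    def __getitem__(self, key: Hashable) -> Any:
        return self._data[key]

    def __setitem__(self, key: Hashable, value: Any) -> None:
        if key not in self._data and len(self._data) >= self.maxsize:
            self._data.popitem(last=False)  # drop the oldest entry
        self._data[key] = value


def fifo_cached(method: Callable) -> Callable:
    """Hypothetical decorator: cache results in self.cache by call signature."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        key = (method.__name__, args, tuple(sorted(kwargs.items())))
        if key in self.cache:
            return self.cache[key]
        result = method(self, *args, **kwargs)
        self.cache[key] = result
        return result

    return wrapper


class ToyView:
    def __init__(self) -> None:
        self.cache = FIFOCache(maxsize=3)
        self.calls = 0  # counts real computations

    @fifo_cached
    def flash_responses(self, radius: int = 6) -> str:
        self.calls += 1
        return f"responses(radius={radius})"


view = ToyView()
for radius in (6, 6, 1, 2, 3, 6, 2):
    view.flash_responses(radius=radius)
print(view.calls)  # 5
```

With maxsize=3, the fifth distinct call evicts radius=6, so requesting it again recomputes; a small FIFO bound keeps memory use predictable when each cached result is a large dataset.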

flyvision.network.network_view.CheckpointedNetwork dataclass

A picklable network representation paired with a checkpoint.

Attributes:

network_class (Any): Network class (e.g., flyvision.Network).
config (Dict): Configuration for the network.
name (str): Name of the network.
checkpoint (PathLike): Checkpoint path.
recover_fn (Any): Function to recover the network.
network (Optional[Network]): Network instance, kept to avoid reinitialization.
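__hash__ and __eq__ in the source below compare only network_class, the config (via make_hashable), and checkpoint. Two handles for the same checkpoint therefore compare equal whether or not a network instance is attached, and a nested config dict can take part in hashing. A sketch of that contract; this make_hashable is a hypothetical stand-in for flyvision's helper and assumes mutually comparable (for example, string) dict keys.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


def make_hashable(obj: Any) -> Any:
    """Recursively convert dicts, lists and sets into hashable equivalents."""
    if isinstance(obj, dict):
        # Sorting by key makes the result independent of insertion order.
        return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
    if isinstance(obj, (list, tuple)):
        return tuple(make_hashable(v) for v in obj)
    if isinstance(obj, set):
        return frozenset(make_hashable(v) for v in obj)
    return obj


class ToyNet:
    pass


@dataclass
class Handle:
    network_class: Any
    config: Dict[str, Any]
    checkpoint: str
    network: Optional[Any] = None

    def __hash__(self) -> int:
        # The attached network instance is deliberately left out.
        return hash((self.network_class, make_hashable(self.config), self.checkpoint))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Handle):
            return False
        return (
            self.network_class == other.network_class
            and make_hashable(self.config) == make_hashable(other.config)
            and self.checkpoint == other.checkpoint
        )


a = Handle(ToyNet, {"dt": 0.02, "layers": [1, 2]}, "chkpt_00001")
b = Handle(ToyNet, {"layers": [1, 2], "dt": 0.02}, "chkpt_00001", network=ToyNet())
print(a == b, len({a, b}))  # True 1
```

Because the class body defines both methods explicitly, @dataclass keeps them instead of generating its own field-by-field versions.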

Source code in flyvision/network/network_view.py
@dataclass
class CheckpointedNetwork:
    """A network representation with checkpoint that can be pickled.

    Attributes:
        network_class: Network class (e.g., flyvision.Network).
        config: Configuration for the network.
        name: Name of the network.
        checkpoint: Checkpoint path.
        recover_fn: Function to recover the network.
        network: Network instance to avoid reinitialization.
    """

    network_class: Any
    config: Dict
    name: str
    checkpoint: PathLike
    recover_fn: Any = recover_network
    network: Optional[Network] = None

    def init(self, eval: bool = True) -> Network:
        """Initialize the network.

        Args:
            eval: Whether to set the network in evaluation mode.

        Returns:
            The initialized network.
        """
        if self.network is None:
            self.network = self.network_class(**self.config)
        if eval:
            self.network.eval()
        return self.network

    def recover(self, checkpoint: Optional[PathLike] = None) -> Network:
        """Recover the network from the checkpoint.

        Args:
            checkpoint: Path to the checkpoint. If None, uses the default checkpoint.

        Returns:
            The recovered network.

        Note:
            Initializes the network if it hasn't been initialized yet.
        """
        if self.network is None:
            self.init()
        return self.recover_fn(self.network, checkpoint or self.checkpoint)

    def __hash__(self):
        return hash((
            self.network_class,
            make_hashable(self.config),
            self.checkpoint,
        ))

    # Equality check based on hashable elements.
    def __eq__(self, other):
        if not isinstance(other, CheckpointedNetwork):
            return False
        return (
            self.network_class == other.network_class
            and make_hashable(self.config) == make_hashable(other.config)
            and self.checkpoint == other.checkpoint
        )

    # Custom reduce method to make the object compatible with joblib's pickling.
    # This ensures the 'network' attribute is never pickled.
    # Return a tuple containing:
    # 1. A callable that will recreate the object (here, the class itself)
    # 2. The arguments required to recreate the object (excluding the network)
    # 3. The state, excluding the 'network' attribute
    def __reduce__(self):
        state = self.__dict__.copy()
        state["network"] = None  # Exclude the complex network from being pickled

        return (
            self.__class__,  # The callable (class itself)
            (
                self.network_class,
                self.config,
                self.name,
                self.checkpoint,
                self.recover_fn,
                None,
            ),  # Arguments to reconstruct the object, in field order
            state,  # State without the 'network' attribute
        )

    # Restore the object's state, but do not load the network from the state.
    def __setstate__(self, state):
        self.__dict__.update(state)
        self.network = None
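As the source comments state, __reduce__ and __setstate__ ensure the network attribute is never pickled, so only the lightweight description (class, config, name, checkpoint, recover function) travels through joblib, and weights can be restored from the checkpoint on demand. Below is a minimal sketch of the same pattern on a hypothetical dataclass; a threading.Lock, which cannot be pickled, stands in for the network instance, and the constructor arguments are passed in field order.

```python
import pickle
import threading
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class PicklableHandle:
    name: str
    config: Dict[str, Any]
    checkpoint: str
    resource: Optional[Any] = None  # never pickled

    def __reduce__(self):
        state = self.__dict__.copy()
        state["resource"] = None  # drop the unpicklable object from the state
        # (callable, constructor args in field order, state passed to __setstate__)
        return (
            self.__class__,
            (self.name, self.config, self.checkpoint, None),
            state,
        )

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.resource = None


handle = PicklableHandle("flow/0000/000", {"dt": 0.02}, "chkpt_00001", threading.Lock())
restored = pickle.loads(pickle.dumps(handle))
print(restored.name, restored.resource)  # flow/0000/000 None
```

Without the custom __reduce__, pickle.dumps(handle) would fail on the lock; with it, the original object keeps its resource while the copy starts without one.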

init

init(eval=True)

Initialize the network.

Parameters:

eval (bool): Whether to set the network in evaluation mode. Defaults to True.

Returns:

Network: The initialized network.

Source code in flyvision/network/network_view.py
def init(self, eval: bool = True) -> Network:
    """Initialize the network.

    Args:
        eval: Whether to set the network in evaluation mode.

    Returns:
        The initialized network.
    """
    if self.network is None:
        self.network = self.network_class(**self.config)
    if eval:
        self.network.eval()
    return self.network

recover

recover(checkpoint=None)

Recover the network from the checkpoint.

Parameters:

checkpoint (Optional[PathLike]): Path to the checkpoint. If None, uses the default checkpoint. Defaults to None.

Returns:

Network: The recovered network.

Note: Initializes the network if it hasn't been initialized yet.

Source code in flyvision/network/network_view.py
def recover(self, checkpoint: Optional[PathLike] = None) -> Network:
    """Recover the network from the checkpoint.

    Args:
        checkpoint: Path to the checkpoint. If None, uses the default checkpoint.

    Returns:
        The recovered network.

    Note:
        Initializes the network if it hasn't been initialized yet.
    """
    if self.network is None:
        self.init()
    return self.recover_fn(self.network, checkpoint or self.checkpoint)
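init() builds the network from config only when no instance exists and, by default, switches it to evaluation mode; recover() calls init() first if needed and then hands the instance and a checkpoint (the explicit argument, else self.checkpoint) to recover_fn. A torch-free sketch of this two-step lifecycle; ToyNet and toy_recover are hypothetical stand-ins for a network class with an nn.Module-like eval() and for recover_network.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


class ToyNet:
    def __init__(self, n_units: int) -> None:
        self.n_units = n_units
        self.training = True
        self.loaded_from: Optional[str] = None

    def eval(self) -> "ToyNet":
        self.training = False
        return self


def toy_recover(network: ToyNet, checkpoint: str) -> ToyNet:
    network.loaded_from = checkpoint  # stands in for loading weights
    return network


@dataclass
class Handle:
    network_class: Callable[..., Any]
    config: Dict[str, Any]
    checkpoint: str
    recover_fn: Callable[[Any, str], Any] = toy_recover
    network: Optional[Any] = None

    def init(self, eval: bool = True) -> Any:
        # Construct only once; later calls reuse the instance.
        if self.network is None:
            self.network = self.network_class(**self.config)
        if eval:
            self.network.eval()
        return self.network

    def recover(self, checkpoint: Optional[str] = None) -> Any:
        if self.network is None:
            self.init()
        return self.recover_fn(self.network, checkpoint or self.checkpoint)


h = Handle(ToyNet, {"n_units": 4}, "chkpt_best")
net = h.recover()
print(net.n_units, net.training, net.loaded_from)  # 4 False chkpt_best
```

Because the fallback uses `or`, any falsy checkpoint argument (such as an empty string) also falls back to self.checkpoint.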