Skip to content

EnsembleView

flyvision.network.directories.EnsembleDir

Bases: Directory

Contains many NetworkDirs.

Source code in flyvision/network/directories.py
39
40
41
@root(flyvision.results_dir)
class EnsembleDir(Directory):
    """Directory of an ensemble; contains many NetworkDirs.

    The class adds no members of its own on top of ``Directory``.

    NOTE(review): the ``root`` decorator receives ``flyvision.results_dir``,
    presumably so that ensemble paths are resolved relative to the results
    directory -- confirm against the ``root`` implementation.
    """

flyvision.network.EnsembleView

Bases: Ensemble

A view of an ensemble of trained networks.

This class extends the Ensemble class with visualization and analysis methods.

Parameters:

Name Type Description Default
path Union[str, Path, Iterable, EnsembleDir, Ensemble]

Path to the ensemble directory or an existing Ensemble object.

required
network_class Module

The network class to use for instantiation.

Network
root_dir Path

Root directory for results.

results_dir
connectome_getter Callable

Function to get the connectome.

get_avgfilt_connectome
checkpoint_mapper Callable

Function to resolve checkpoints.

resolve_checkpoints
best_checkpoint_fn Callable

Function to select the best checkpoint.

best_checkpoint_default_fn
best_checkpoint_fn_kwargs dict

Keyword arguments for best_checkpoint_fn.

{'validation_subdir': 'validation', 'loss_file_name': 'loss'}
recover_fn Callable

Function to recover the network.

recover_network
try_sort bool

Whether to try sorting the ensemble.

False
Source code in flyvision/network/ensemble_view.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
class EnsembleView(Ensemble):
    """Ensemble of trained networks with plotting and analysis helpers.

    Construction is delegated entirely to ``Ensemble``; this subclass only adds
    methods that turn ensemble-level data into matplotlib figures.

    Args:
        path: Path to the ensemble directory, or an existing ``Ensemble``. When an
            ``Ensemble`` instance is given, its stored ``_init_args`` are
            forwarded to the base constructor and every other argument passed
            here is ignored.
        network_class: The network class to use for instantiation.
        root_dir: Root directory for results.
        connectome_getter: Function to get the connectome.
        checkpoint_mapper: Function to resolve checkpoints.
        best_checkpoint_fn: Function to select the best checkpoint.
        best_checkpoint_fn_kwargs: Keyword arguments for best_checkpoint_fn.
        recover_fn: Function to recover the network.
        try_sort: Whether to try sorting the ensemble.

    Note:
        No attributes are defined here; all attributes come from ``Ensemble``.
    """

    def __init__(
        self,
        path: Union[str, Path, Iterable, EnsembleDir, Ensemble],
        network_class: nn.Module = Network,
        root_dir: Path = flyvision.results_dir,
        connectome_getter: Callable = get_avgfilt_connectome,
        checkpoint_mapper: Callable = resolve_checkpoints,
        best_checkpoint_fn: Callable = best_checkpoint_default_fn,
        best_checkpoint_fn_kwargs: dict = {
            "validation_subdir": "validation",
            "loss_file_name": "loss",
        },
        recover_fn: Callable = recover_network,
        try_sort: bool = False,
    ):
        if isinstance(path, Ensemble):
            # Wrapping an existing ensemble: replay the arguments it stored and
            # disregard whatever else was passed to this constructor.
            forwarded_args = path._init_args
        else:
            forwarded_args = (
                path,
                network_class,
                root_dir,
                connectome_getter,
                checkpoint_mapper,
                best_checkpoint_fn,
                best_checkpoint_fn_kwargs,
                recover_fn,
                try_sort,
            )
        super().__init__(*forwarded_args)

    @wraps(plots.loss_curves)
    def training_loss(self, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
        """Draw the training loss curves of all networks in the ensemble.

        Curve colors, colormap and norm are taken from the result of
        ``self.task_error()``; a colorbar is always requested.

        Args:
            **kwargs: Extra keyword arguments forwarded to plots.loss_curves.
                Keys already set here (``cbar``, ``colors``, ``cmap``, ``norm``,
                ``xlabel``, ``ylabel``) must not be repeated.

        Returns:
            The result of plots.loss_curves, annotated as a tuple of the
            matplotlib Figure and Axes objects.
        """
        error = self.task_error()
        # Read the complete loss trace of every network into a single array.
        per_network_losses = np.array([
            network_view.dir.loss[:] for network_view in self.values()
        ])
        curve_options = {
            "cbar": True,
            "colors": error.colors,
            "cmap": error.cmap,
            "norm": error.norm,
            "xlabel": "iterations",
            "ylabel": "training loss",
        }
        return plots.loss_curves(per_network_losses, **curve_options, **kwargs)

    @wraps(plots.loss_curves)
    def validation_loss(
        self,
        validation_subdir: Optional[str] = None,
        loss_file_name: Optional[str] = None,
        **kwargs,
    ) -> Tuple[plt.Figure, plt.Axes]:
        """Draw the validation loss curves of all networks in the ensemble.

        The losses come from ``self.validation_losses``; coloring follows the
        result of ``self.task_error()`` and a colorbar is always requested.

        Args:
            validation_subdir: Subdirectory containing validation data; passed
                through to ``self.validation_losses``.
            loss_file_name: Name of the loss file; passed through to
                ``self.validation_losses``.
            **kwargs: Extra keyword arguments forwarded to plots.loss_curves.
                Keys already set here (``cbar``, ``colors``, ``cmap``, ``norm``,
                ``xlabel``, ``ylabel``) must not be repeated.

        Returns:
            The result of plots.loss_curves, annotated as a tuple of the
            matplotlib Figure and Axes objects.
        """
        error = self.task_error()
        per_network_losses = self.validation_losses(validation_subdir, loss_file_name)
        curve_options = {
            "cbar": True,
            "colors": error.colors,
            "cmap": error.cmap,
            "norm": error.norm,
            "xlabel": "checkpoints",
            "ylabel": "validation loss",
        }
        return plots.loss_curves(per_network_losses, **curve_options, **kwargs)

    @wraps(plots.histogram)
    def task_error_histogram(self, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
        """Draw a histogram over the minimum validation losses of the ensemble.

        Args:
            **kwargs: Extra keyword arguments forwarded to plots.histogram.
                ``xlabel`` and ``ylabel`` are set here and must not be repeated.

        Returns:
            The result of plots.histogram, annotated as a tuple of the
            matplotlib Figure and Axes objects.
        """
        axis_labels = {"xlabel": "task error", "ylabel": "number models"}
        return plots.histogram(self.min_validation_losses(), **axis_labels, **kwargs)

    @wraps(plots.violins)
    def node_parameters(
        self, key: str, max_per_ax: int = 34, **kwargs
    ) -> Tuple[plt.Figure, List[plt.Axes]]:
        """Draw violin plots of one node parameter across the ensemble.

        Args:
            key: Name of the node parameter; looked up as ``nodes_<key>`` in
                ``self.parameters()`` and ``self.parameter_keys()``.
            max_per_ax: Maximum number of violins per axis.
            **kwargs: Extra keyword arguments forwarded to plots.violins.

        Returns:
            The result of plots.violins, annotated as a tuple of the matplotlib
            Figure and a list of Axes objects.
        """
        entry = f"nodes_{key}"
        values = self.parameters()[entry]
        names = self.parameter_keys()[entry]
        return plots.violins(
            names, values, ylabel=key, max_per_ax=max_per_ax, **kwargs
        )

    @wraps(plots.violins)
    def edge_parameters(
        self, key: str, max_per_ax: int = 120, **kwargs
    ) -> Tuple[plt.Figure, List[plt.Axes]]:
        """Draw violin plots of one edge parameter across the ensemble.

        Args:
            key: Name of the edge parameter; looked up as ``edges_<key>`` in
                ``self.parameters()`` and ``self.parameter_keys()``.
            max_per_ax: Maximum number of violins per axis.
            **kwargs: Extra keyword arguments forwarded to plots.violins.

        Returns:
            The result of plots.violins, annotated as a tuple of the matplotlib
            Figure and a list of Axes objects.
        """
        entry = f"edges_{key}"
        values = self.parameters()[entry]
        edge_keys = self.parameter_keys()[entry]
        # Each key unpacks into a pair, rendered as a "source->target" label.
        labels = np.array([f"{pre}->{post}" for pre, post in edge_keys])
        return plots.violins(
            labels,
            variable_values=values,
            ylabel=key,
            max_per_ax=max_per_ax,
            **kwargs,
        )

    @wraps(plots.heatmap)
    def dead_or_alive(self, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
        """Draw a heatmap marking dead cells in the ensemble.

        An entry is marked when every value of the naturalistic stimuli
        responses is below zero along axes 1 and 2 of the response array. Rows
        are labeled ``0 .. len(self) - 1`` and columns with the cell types of
        the responses.

        Args:
            **kwargs: Extra keyword arguments forwarded to plots.heatmap.
                ``ylabels``, ``xlabels``, ``size_scale`` and ``cbar`` are set
                here and must not be repeated.

        Returns:
            The result of plots.heatmap, annotated as a tuple of the matplotlib
            Figure and Axes objects.
        """
        responses = self.naturalistic_stimuli_responses()
        activity = responses['responses'].values
        all_negative = (activity < 0).all(axis=(1, 2))
        heatmap_options = {
            "ylabels": np.arange(len(self)),
            "xlabels": responses.cell_type.values,
            "size_scale": 15,
            "cbar": False,
        }
        return plots.heatmap(all_negative, **heatmap_options, **kwargs)

    @wraps(plot_fris)
    def flash_response_index(
        self, cell_types: Optional[List[str]] = None, **kwargs
    ) -> Tuple[plt.Figure, plt.Axes]:
        """Draw the flash response indices of the ensemble.

        Indices are computed from ``self.flash_responses()`` with ``radius=6``.
        The network with the smallest task error is highlighted.

        Args:
            cell_types: Cell types to include. If None, the cell types found in
                the computed indices are used.
            **kwargs: Extra keyword arguments forwarded to plot_fris.

        Returns:
            The result of plot_fris, annotated as a tuple of the matplotlib
            Figure and Axes objects.
        """
        flash_data = self.flash_responses()
        indices = flash_response_index(flash_data, radius=6)
        if cell_types is None:
            cell_types = indices.cell_type.values
        else:
            indices = indices.custom.where(cell_type=cell_types)
        # Position of the network with the lowest task error.
        best = np.argmin(self.task_error().values)
        highlight = cm.get_cmap("Blues")(1.0)
        return plot_fris(
            indices.values,
            cell_types,
            scatter_best=True,
            scatter_best_index=best,
            scatter_best_color=highlight,
            **kwargs,
        )

    @wraps(dsi_violins_on_and_off)
    def direction_selectivity_index(
        self, **kwargs
    ) -> Tuple[plt.Figure, Tuple[plt.Axes, plt.Axes]]:
        """Draw the direction selectivity indices of the ensemble.

        Indices are computed from ``self.moving_edge_responses()``. The network
        with the smallest task error is highlighted.

        Args:
            **kwargs: Extra keyword arguments forwarded to
                dsi_violins_on_and_off.

        Returns:
            The result of dsi_violins_on_and_off, annotated as a tuple of the
            matplotlib Figure and a tuple of Axes objects.
        """
        edge_data = self.moving_edge_responses()
        indices = direction_selectivity_index(edge_data)
        # Position of the network with the lowest task error.
        best = np.argmin(self.task_error().values)
        return dsi_violins_on_and_off(
            indices,
            edge_data.cell_type,
            bold_output_type_labels=False,
            figsize=[10, 1.2],
            color_known_types=True,
            fontsize=6,
            scatter_best_index=best,
            scatter_best_color=cm.get_cmap("Blues")(1.0),
            **kwargs,
        )

training_loss

training_loss(**kwargs)

Plot training loss curves for the ensemble.

Parameters:

Name Type Description Default
**kwargs

Additional keyword arguments to pass to plots.loss_curves.

{}

Returns:

Type Description
Tuple[Figure, Axes]

A tuple containing the matplotlib Figure and Axes objects.

Source code in flyvision/network/ensemble_view.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
@wraps(plots.loss_curves)
def training_loss(self, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """Draw the training loss curves of all networks in the ensemble.

    Curve colors, colormap and norm are taken from the result of
    ``self.task_error()``; a colorbar is always requested.

    Args:
        **kwargs: Extra keyword arguments forwarded to plots.loss_curves.
            Keys already set here (``cbar``, ``colors``, ``cmap``, ``norm``,
            ``xlabel``, ``ylabel``) must not be repeated.

    Returns:
        The result of plots.loss_curves, annotated as a tuple of the
        matplotlib Figure and Axes objects.
    """
    error = self.task_error()
    # Read the complete loss trace of every network into a single array.
    per_network_losses = np.array([
        network_view.dir.loss[:] for network_view in self.values()
    ])
    curve_options = {
        "cbar": True,
        "colors": error.colors,
        "cmap": error.cmap,
        "norm": error.norm,
        "xlabel": "iterations",
        "ylabel": "training loss",
    }
    return plots.loss_curves(per_network_losses, **curve_options, **kwargs)

validation_loss

validation_loss(
    validation_subdir=None, loss_file_name=None, **kwargs
)

Plot validation loss curves for the ensemble.

Parameters:

Name Type Description Default
validation_subdir Optional[str]

Subdirectory containing validation data.

None
loss_file_name Optional[str]

Name of the loss file.

None
**kwargs

Additional keyword arguments to pass to plots.loss_curves.

{}

Returns:

Type Description
Tuple[Figure, Axes]

A tuple containing the matplotlib Figure and Axes objects.

Source code in flyvision/network/ensemble_view.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
@wraps(plots.loss_curves)
def validation_loss(
    self,
    validation_subdir: Optional[str] = None,
    loss_file_name: Optional[str] = None,
    **kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
    """Draw the validation loss curves of all networks in the ensemble.

    The losses come from ``self.validation_losses``; coloring follows the
    result of ``self.task_error()`` and a colorbar is always requested.

    Args:
        validation_subdir: Subdirectory containing validation data; passed
            through to ``self.validation_losses``.
        loss_file_name: Name of the loss file; passed through to
            ``self.validation_losses``.
        **kwargs: Extra keyword arguments forwarded to plots.loss_curves.
            Keys already set here (``cbar``, ``colors``, ``cmap``, ``norm``,
            ``xlabel``, ``ylabel``) must not be repeated.

    Returns:
        The result of plots.loss_curves, annotated as a tuple of the
        matplotlib Figure and Axes objects.
    """
    error = self.task_error()
    per_network_losses = self.validation_losses(validation_subdir, loss_file_name)
    curve_options = {
        "cbar": True,
        "colors": error.colors,
        "cmap": error.cmap,
        "norm": error.norm,
        "xlabel": "checkpoints",
        "ylabel": "validation loss",
    }
    return plots.loss_curves(per_network_losses, **curve_options, **kwargs)

task_error_histogram

task_error_histogram(**kwargs)

Plot a histogram of the minimum validation losses (task errors) of the networks in the ensemble.

Parameters:

Name Type Description Default
**kwargs

Additional keyword arguments to pass to plots.histogram.

{}

Returns:

Type Description
Tuple[Figure, Axes]

A tuple containing the matplotlib Figure and Axes objects.

Source code in flyvision/network/ensemble_view.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
@wraps(plots.histogram)
def task_error_histogram(self, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """Draw a histogram over the minimum validation losses of the ensemble.

    Args:
        **kwargs: Extra keyword arguments forwarded to plots.histogram.
            ``xlabel`` and ``ylabel`` are set here and must not be repeated.

    Returns:
        The result of plots.histogram, annotated as a tuple of the
        matplotlib Figure and Axes objects.
    """
    axis_labels = {"xlabel": "task error", "ylabel": "number models"}
    return plots.histogram(self.min_validation_losses(), **axis_labels, **kwargs)

node_parameters

node_parameters(key, max_per_ax=34, **kwargs)

Plot violin plots of node parameters for the ensemble.

Parameters:

Name Type Description Default
key str

The parameter key to plot.

required
max_per_ax int

Maximum number of violins per axis.

34
**kwargs

Additional keyword arguments to pass to plots.violins.

{}

Returns:

Type Description
Tuple[Figure, List[Axes]]

A tuple containing the matplotlib Figure and a list of Axes objects.

Source code in flyvision/network/ensemble_view.py
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
@wraps(plots.violins)
def node_parameters(
    self, key: str, max_per_ax: int = 34, **kwargs
) -> Tuple[plt.Figure, List[plt.Axes]]:
    """Draw violin plots of one node parameter across the ensemble.

    Args:
        key: Name of the node parameter; looked up as ``nodes_<key>`` in
            ``self.parameters()`` and ``self.parameter_keys()``.
        max_per_ax: Maximum number of violins per axis.
        **kwargs: Extra keyword arguments forwarded to plots.violins.

    Returns:
        The result of plots.violins, annotated as a tuple of the matplotlib
        Figure and a list of Axes objects.
    """
    entry = f"nodes_{key}"
    values = self.parameters()[entry]
    names = self.parameter_keys()[entry]
    return plots.violins(
        names, values, ylabel=key, max_per_ax=max_per_ax, **kwargs
    )

edge_parameters

edge_parameters(key, max_per_ax=120, **kwargs)

Plot violin plots of edge parameters for the ensemble.

Parameters:

Name Type Description Default
key str

The parameter key to plot.

required
max_per_ax int

Maximum number of violins per axis.

120
**kwargs

Additional keyword arguments to pass to plots.violins.

{}

Returns:

Type Description
Tuple[Figure, List[Axes]]

A tuple containing the matplotlib Figure and a list of Axes objects.

Source code in flyvision/network/ensemble_view.py
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
@wraps(plots.violins)
def edge_parameters(
    self, key: str, max_per_ax: int = 120, **kwargs
) -> Tuple[plt.Figure, List[plt.Axes]]:
    """Draw violin plots of one edge parameter across the ensemble.

    Args:
        key: Name of the edge parameter; looked up as ``edges_<key>`` in
            ``self.parameters()`` and ``self.parameter_keys()``.
        max_per_ax: Maximum number of violins per axis.
        **kwargs: Extra keyword arguments forwarded to plots.violins.

    Returns:
        The result of plots.violins, annotated as a tuple of the matplotlib
        Figure and a list of Axes objects.
    """
    entry = f"edges_{key}"
    values = self.parameters()[entry]
    edge_keys = self.parameter_keys()[entry]
    # Each key unpacks into a pair, rendered as a "source->target" label.
    labels = np.array([f"{pre}->{post}" for pre, post in edge_keys])
    return plots.violins(
        labels,
        variable_values=values,
        ylabel=key,
        max_per_ax=max_per_ax,
        **kwargs,
    )

dead_or_alive

dead_or_alive(**kwargs)

Plot a heatmap of dead cells in the ensemble.

Parameters:

Name Type Description Default
**kwargs

Additional keyword arguments to pass to plots.heatmap.

{}

Returns:

Type Description
Tuple[Figure, Axes]

A tuple containing the matplotlib Figure and Axes objects.

Source code in flyvision/network/ensemble_view.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
@wraps(plots.heatmap)
def dead_or_alive(self, **kwargs) -> Tuple[plt.Figure, plt.Axes]:
    """Draw a heatmap marking dead cells in the ensemble.

    An entry is marked when every value of the naturalistic stimuli
    responses is below zero along axes 1 and 2 of the response array. Rows
    are labeled ``0 .. len(self) - 1`` and columns with the cell types of
    the responses.

    Args:
        **kwargs: Extra keyword arguments forwarded to plots.heatmap.
            ``ylabels``, ``xlabels``, ``size_scale`` and ``cbar`` are set
            here and must not be repeated.

    Returns:
        The result of plots.heatmap, annotated as a tuple of the matplotlib
        Figure and Axes objects.
    """
    responses = self.naturalistic_stimuli_responses()
    activity = responses['responses'].values
    all_negative = (activity < 0).all(axis=(1, 2))
    heatmap_options = {
        "ylabels": np.arange(len(self)),
        "xlabels": responses.cell_type.values,
        "size_scale": 15,
        "cbar": False,
    }
    return plots.heatmap(all_negative, **heatmap_options, **kwargs)

flash_response_index

flash_response_index(cell_types=None, **kwargs)

Plot the flash response indices of the ensemble.

Parameters:

Name Type Description Default
cell_types Optional[List[str]]

List of cell types to include. If None, all cell types are used.

None
**kwargs

Additional keyword arguments to pass to plot_fris.

{}

Returns:

Type Description
Tuple[Figure, Axes]

A tuple containing the matplotlib Figure and Axes objects.

Source code in flyvision/network/ensemble_view.py
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
@wraps(plot_fris)
def flash_response_index(
    self, cell_types: Optional[List[str]] = None, **kwargs
) -> Tuple[plt.Figure, plt.Axes]:
    """Draw the flash response indices of the ensemble.

    Indices are computed from ``self.flash_responses()`` with ``radius=6``.
    The network with the smallest task error is highlighted.

    Args:
        cell_types: Cell types to include. If None, the cell types found in
            the computed indices are used.
        **kwargs: Extra keyword arguments forwarded to plot_fris.

    Returns:
        The result of plot_fris, annotated as a tuple of the matplotlib
        Figure and Axes objects.
    """
    flash_data = self.flash_responses()
    indices = flash_response_index(flash_data, radius=6)
    if cell_types is None:
        cell_types = indices.cell_type.values
    else:
        indices = indices.custom.where(cell_type=cell_types)
    # Position of the network with the lowest task error.
    best = np.argmin(self.task_error().values)
    highlight = cm.get_cmap("Blues")(1.0)
    return plot_fris(
        indices.values,
        cell_types,
        scatter_best=True,
        scatter_best_index=best,
        scatter_best_color=highlight,
        **kwargs,
    )

direction_selectivity_index

direction_selectivity_index(**kwargs)

Plot the direction selectivity indices of the ensemble.

Parameters:

Name Type Description Default
**kwargs

Additional keyword arguments to pass to dsi_violins_on_and_off.

{}

Returns:

Type Description
Tuple[Figure, Tuple[Axes, Axes]]

A tuple containing the matplotlib Figure and a tuple of Axes objects.

Source code in flyvision/network/ensemble_view.py
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
@wraps(dsi_violins_on_and_off)
def direction_selectivity_index(
    self, **kwargs
) -> Tuple[plt.Figure, Tuple[plt.Axes, plt.Axes]]:
    """Draw the direction selectivity indices of the ensemble.

    Indices are computed from ``self.moving_edge_responses()``. The network
    with the smallest task error is highlighted.

    Args:
        **kwargs: Extra keyword arguments forwarded to
            dsi_violins_on_and_off.

    Returns:
        The result of dsi_violins_on_and_off, annotated as a tuple of the
        matplotlib Figure and a tuple of Axes objects.
    """
    edge_data = self.moving_edge_responses()
    indices = direction_selectivity_index(edge_data)
    # Position of the network with the lowest task error.
    best = np.argmin(self.task_error().values)
    return dsi_violins_on_and_off(
        indices,
        edge_data.cell_type,
        bold_output_type_labels=False,
        figsize=[10, 1.2],
        color_known_types=True,
        fontsize=6,
        scatter_best_index=best,
        scatter_best_color=cm.get_cmap("Blues")(1.0),
        **kwargs,
    )

flyvision.network.Ensemble

Bases: dict

Dictionary holding a collection of trained networks.

Parameters:

Name Type Description Default
path Union[str, Path, Iterable, 'EnsembleDir']

Path to ensemble directory or list of paths to model directories. Can be a single string, then assumes the path is the root directory as configured by datamate.

required
network_class Module

Class to use for initializing networks.

Network
root_dir Path

Root directory for model paths.

results_dir
connectome_getter Callable

Function to get the connectome.

get_avgfilt_connectome
checkpoint_mapper Callable

Function to map checkpoints.

resolve_checkpoints
best_checkpoint_fn Callable

Function to determine best checkpoint.

best_checkpoint_default_fn
best_checkpoint_fn_kwargs dict

Kwargs for best_checkpoint_fn.

{'validation_subdir': 'validation', 'loss_file_name': 'loss'}
recover_fn Callable

Function to recover network.

recover_network
try_sort bool

Whether to try to sort the ensemble by validation error.

False

Attributes:

Name Type Description
names List[str]

List of model names.

name str

Ensemble name.

path Path

Path to ensemble directory.

model_paths List[Path]

List of paths to model directories.

dir EnsembleDir

Directory object for ensemble directory.

Note

The ensemble is a dynamic dictionary, so you can access the networks in the ensemble by name or index. For example, to access the first network simply do:

ensemble[0]
or
ensemble['flow/000/0000']
Slice to create a subset of the ensemble:
ensemble[0:2]

Source code in flyvision/network/ensemble.py
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
class Ensemble(dict):
    """Dictionary to a collection of trained networks.

    Args:
        path: Path to ensemble directory or list of paths to model directories.
            Can be a single string, then assumes the path is the root directory
            as configured by datamate.
        network_class: Class to use for initializing networks.
        root_dir: Root directory for model paths.
        connectome_getter: Function to get the connectome.
        checkpoint_mapper: Function to map checkpoints.
        best_checkpoint_fn: Function to determine best checkpoint.
        best_checkpoint_fn_kwargs: Kwargs for best_checkpoint_fn.
        recover_fn: Function to recover network.
        try_sort: Whether to try to sort the ensemble by validation error.

    Attributes:
        names (List[str]): List of model names.
        name (str): Ensemble name.
        path (Path): Path to ensemble directory.
        model_paths (List[Path]): List of paths to model directories.
        dir (EnsembleDir): Directory object for ensemble directory.

    Note:
        The ensemble is a dynamic dictionary, so you can access the networks
        in the ensemble by name or index. For example, to access the first network
        simply do:
        ```python
        ensemble[0]
        ```
        or
        ```python
        ensemble['flow/000/0000']
        ```
        Slice to create a subset of the ensemble:
        ```python
        ensemble[0:2]
        ```
    """

    def __init__(
        self,
        path: Union[str, Path, Iterable, "EnsembleDir"],
        network_class: nn.Module = Network,
        root_dir: Path = flyvision.results_dir,
        connectome_getter: Callable = get_avgfilt_connectome,
        checkpoint_mapper: Callable = resolve_checkpoints,
        best_checkpoint_fn: Callable = best_checkpoint_default_fn,
        best_checkpoint_fn_kwargs: dict = {
            "validation_subdir": "validation",
            "loss_file_name": "loss",
        },
        recover_fn: Callable = recover_network,
        try_sort: bool = False,
    ):
        """Resolve the model directories and create one NetworkView per model.

        The arguments are described in the class docstring.

        Raises:
            TypeError: If `path` is not an EnsembleDir, Path, str, or Iterable.
        """
        # NOTE: `str` is itself an Iterable, so the str branch must be checked
        # before the generic Iterable branch.
        if isinstance(path, EnsembleDir):
            # Keep the directory object itself in `self.dir`. Other methods use
            # it as an EnsembleDir (e.g. `self.dir.umap_and_clustering`), so it
            # must not be replaced by the plain filesystem path.
            self.dir = path
            path = path.path
            self.model_paths, self.path = model_paths_from_parent(path)
        elif isinstance(path, Path):
            self.model_paths, self.path = model_paths_from_parent(path)
            self.dir = EnsembleDir(self.path)
        elif isinstance(path, str):
            self.dir = EnsembleDir(path)
            self.model_paths, self.path = model_paths_from_parent(self.dir.path)
        elif isinstance(path, Iterable):
            self.model_paths, self.path = model_paths_from_names_or_paths(path, root_dir)
            self.dir = EnsembleDir(self.path)
        else:
            raise TypeError(
                f"Unsupported path type: {type(path)}. "
                "Expected EnsembleDir, str, PathLike, or Iterable."
            )

        self.names, self.name = model_path_names(self.model_paths)
        self.in_context = False

        self._names = []
        self.model_index = []
        # Initialize pointers to model directories.
        for i, name in tqdm(
            enumerate(self.names), desc="Loading ensemble", total=len(self.names)
        ):
            try:
                with all_logging_disabled():
                    self[name] = NetworkView(
                        NetworkDir(self.model_paths[i]),
                        network_class=network_class,
                        root_dir=root_dir,
                        connectome_getter=connectome_getter,
                        checkpoint_mapper=checkpoint_mapper,
                        best_checkpoint_fn=best_checkpoint_fn,
                        best_checkpoint_fn_kwargs=best_checkpoint_fn_kwargs,
                        recover_fn=recover_fn,
                    )
                    self._names.append(name)
            except AttributeError as e:
                # A model that cannot be loaded is skipped, not fatal.
                logging.warning(f"Failed to load {name}: {e}")
        # Names that were requested but could not be loaded.
        self._broken = list(set(self.names) - set(self._names))
        self.names = self._names
        self.model_index = np.arange(len(self.names))
        logging.info(f"Loaded {len(self)} networks.")

        if try_sort:
            self.sort()

        # Remember the constructor arguments (with `path` as resolved above).
        self._init_args = (
            path,
            network_class,
            root_dir,
            connectome_getter,
            checkpoint_mapper,
            best_checkpoint_fn,
            best_checkpoint_fn_kwargs,
            recover_fn,
            try_sort,
        )
        # The connectome is taken from the first loaded network.
        self.connectome = self[next(iter(self))].connectome
        self.cache = FIFOCache(maxsize=3)

    def __getitem__(
        self, key: Union[str, int, slice, NDArray, list]
    ) -> Union[NetworkView, "Ensemble"]:
        """Look up a single network or build a sub-ensemble.

        Args:
            key: An integer position or a network name returns one NetworkView.
                A slice, list, or array selects names and returns a new
                ensemble constructed from them.

        Returns:
            NetworkView or Ensemble: The requested item or subset of the ensemble.

        Raises:
            ValueError: If the key is invalid.
        """
        # Single network addressed by its position in `self.names`.
        if isinstance(key, (int, np.integer)):
            name = self.names[key]
            return dict.__getitem__(self, name)

        # Subsets are rebuilt from the selected names.
        if isinstance(key, slice):
            selected_names = self.names[key]
            return self.__class__(selected_names)
        if isinstance(key, (np.ndarray, list)):
            selected_names = np.array(self.names)[key]
            return self.__class__(selected_names)

        # Single network addressed by name.
        if key in self.names:
            return dict.__getitem__(self, key)

        raise ValueError(f"{key}")

    def __repr__(self) -> str:
        """Return `ClassName(path)` for this ensemble."""
        class_name = type(self).__name__
        return f"{class_name}({self.path})"

    def __dir__(self) -> List[str]:
        """Return the dict attributes together with the stored network names."""
        # `dict.__iter__` is used directly because `__iter__` is overridden on
        # this class to follow `self.names`.
        entries = set(dict.__dir__(self))
        entries.update(dict.__iter__(self))
        return list(entries)

    def __len__(self) -> int:
        """Return how many network names are currently active."""
        active_names = self.names
        return len(active_names)

    def __iter__(self) -> Iterator[str]:
        """Yield the currently active network names in order."""
        for name in self.names:
            yield name

    def items(self) -> Iterator[Tuple[str, NetworkView]]:
        """Return a lazy iterator of (name, NetworkView) pairs."""
        return ((name, self[name]) for name in self)

    def keys(self) -> List[str]:
        """Return the active network names as a new list."""
        return [name for name in self]

    def values(self) -> List[NetworkView]:
        """Return the NetworkViews of the active networks, in name order."""
        network_views = []
        for name in self:
            network_views.append(self[name])
        return network_views

    def _clear_cache(self) -> None:
        """Drop all cached results by installing a fresh, empty cache.

        The cache is recreated with the same type and size bound that
        `__init__` uses, so clearing does not silently turn the bounded
        FIFO cache into an unbounded plain dict.
        """
        self.cache = FIFOCache(maxsize=3)

    def _clear_memory(self) -> None:
        """Call `_clear_memory` on every NetworkView of the ensemble."""
        for network_view in self.values():
            network_view._clear_memory()

    def check_configs_match(self) -> bool:
        """Compare every network's config against the first network's config.

        A diff consisting of a single entry that mentions "network_name" is
        tolerated. The first other mismatch is logged as a warning.

        Returns:
            bool: True if all configurations match, False otherwise.
        """
        reference_view = self[0]
        reference_config = reference_view.dir.config
        for position in range(1, len(self)):
            candidate_view = self[position]
            diff = reference_config.diff(
                candidate_view.dir.config, name1="first", name2="second"
            ).first
            if not diff:
                continue
            if len(diff) == 1 and "network_name" in diff[0]:
                continue
            logging.warning(
                "%(first)s differs from %(second)s. Diff is %(diff)s.",
                {
                    "first": reference_view.name,
                    "second": candidate_view.name,
                    "diff": diff,
                },
            )
            return False
        return True

    def yield_networks(self) -> Generator[Network, None, None]:
        """Yield one initialized network per ensemble member.

        The network created for the first member is passed as `network=` to
        `init_network` of every following member.

        Yields:
            Network: Initialized network from the ensemble.
        """
        first_network = self[0].init_network()
        yield first_network
        remaining_views = self.values()[1:]
        for view in remaining_views:
            yield view.init_network(network=first_network)

    def yield_decoders(self) -> Generator[nn.Module, None, None]:
        """Yield one initialized decoder per ensemble member.

        A decoder is created from the first member and then passed as
        `decoder=` to `init_decoder` of every member, including the first.

        Yields:
            nn.Module: Initialized decoder from the ensemble.
        """
        assert self.check_configs_match(), "configurations do not match"
        shared_decoder = self[0].init_decoder()
        yield from (
            view.init_decoder(decoder=shared_decoder) for view in self.values()
        )

    def simulate(
        self, movie_input: torch.Tensor, dt: float, fade_in: bool = True
    ) -> Generator[np.ndarray, None, None]:
        """Simulate every network of the ensemble on the same movie input.

        Args:
            movie_input: Tensor with shape (batch_size, n_frames, 1, hexals).
            dt: Integration time constant. Warns if dt > 1/50.
            fade_in: Whether to use `network.fade_in_state` to compute the initial
                state. Defaults to True. If False, uses the
                `network.steady_state` after 1s of grey input.

        Yields:
            np.ndarray: Response of each individual network.

        Note:
            Simulates across batch_size in parallel, which can easily lead to OOM for
            large batch sizes.
        """
        networks = tqdm(
            self.yield_networks(),
            desc="Simulating network",
            total=len(self.names),
        )
        for network in networks:
            if fade_in:
                # Initial state is derived from the first frame of the movie.
                initial_state = network.fade_in_state(1.0, dt, movie_input[:, 0])
            else:
                initial_state = "auto"
            activity = network.simulate(movie_input, dt, initial_state=initial_state)
            yield activity.cpu().numpy()

    def simulate_from_dataset(
        self,
        dataset,
        dt: float,
        indices: Optional[Iterable[int]] = None,
        t_pre: float = 1.0,
        t_fade_in: float = 0.0,
        default_stim_key: str = "lum",
        batch_size: int = 1,
        central_cell_only: bool = True,
    ) -> Generator[np.ndarray, None, None]:
        """Simulate the ensemble activity from a dataset.

        Args:
            dataset: Dataset to simulate from.
            dt: Integration time constant.
            indices: Indices of stimuli to simulate. Defaults to None (all stimuli).
            t_pre: Time before stimulus onset. Defaults to 1.0.
            t_fade_in: Fade-in time. Defaults to 0.0.
            default_stim_key: Default stimulus key. Defaults to "lum".
            batch_size: Batch size for simulation. Defaults to 1.
            central_cell_only: Whether to return only central cells. Defaults to True.

        Yields:
            np.ndarray: Simulated responses for each network, with the leading
                axes of the stacked batches flattened into a single axis.
        """
        central_cells_index = None
        if central_cell_only:
            central_cells_index = self[0].connectome.central_cells_index[:]

        def batched_responses(network: Network):
            """Yield the (optionally cell-restricted) response of each batch."""
            for _, resp in network.stimulus_response(
                dataset,
                dt=dt,
                indices=indices,
                t_pre=t_pre,
                t_fade_in=t_fade_in,
                default_stim_key=default_stim_key,
                batch_size=batch_size,
            ):
                if central_cell_only:
                    yield resp[:, :, central_cells_index]
                else:
                    yield resp

        # The context manager closes the progress bar even if a simulation
        # raises or the consumer stops iterating early.
        with tqdm(desc="Simulating network", total=len(self.names)) as progress_bar:
            for network in self.yield_networks():
                r = np.stack(list(batched_responses(network)))
                yield r.reshape(-1, r.shape[-2], r.shape[-1])
                progress_bar.update(1)

    def decode(
        self, movie_input: torch.Tensor, dt: float
    ) -> Generator[np.ndarray, None, None]:
        """Decode the ensemble responses with the ensemble decoders.

        All networks are simulated first; afterwards the i-th decoder is
        applied to the i-th network's responses.

        Args:
            movie_input: Input movie tensor.
            dt: Integration time constant.

        Yields:
            np.ndarray: Decoded responses for each network.
        """
        # Stack into one ndarray before converting: building a tensor directly
        # from a list of ndarrays is very slow and makes PyTorch emit a warning.
        responses = torch.tensor(np.array(list(self.simulate(movie_input, dt))))
        for i, decoder in enumerate(self.yield_decoders()):
            with simulation(decoder):
                yield decoder(responses[i]).cpu().numpy()

    def validation_file(
        self,
        validation_subdir: Optional[str] = None,
        loss_file_name: Optional[str] = None,
    ) -> Tuple[str, str]:
        """Resolve the validation subdirectory and loss file name.

        Arguments left as None are filled in from the
        `best_checkpoint_fn_kwargs` of the first network in the ensemble.

        Args:
            validation_subdir: Subdirectory for validation files. Defaults to None.
            loss_file_name: Name of the loss file. Defaults to None.

        Returns:
            Tuple[str, str]: Validation subdirectory and loss file name.
        """
        fallback_kwargs = self[0].best_checkpoint_fn_kwargs
        resolved_subdir = validation_subdir
        resolved_file = loss_file_name
        if resolved_subdir is None:
            resolved_subdir = fallback_kwargs.get("validation_subdir")
        if resolved_file is None:
            resolved_file = fallback_kwargs.get("loss_file_name")
        return resolved_subdir, resolved_file

    def sort(
        self,
        validation_subdir: Optional[str] = None,
        loss_file_name: Optional[str] = None,
    ) -> None:
        """Sort the ensemble in place by ascending minimal validation loss.

        Sorting is best effort: if the losses cannot be obtained, the failure
        is logged and the current order is kept.

        Args:
            validation_subdir: Subdirectory for validation files. Defaults to None.
            loss_file_name: Name of the loss file. Defaults to None.
        """
        try:
            names = self.keys()
            # Build the name -> loss lookup once instead of once per sort key.
            loss_by_name = dict(
                zip(
                    names,
                    self.min_validation_losses(validation_subdir, loss_file_name),
                )
            )
            self.names = sorted(names, key=loss_by_name.__getitem__, reverse=False)
        except Exception as e:
            logging.info(f"sorting failed: {e}")

    def argsort(
        self,
        validation_subdir: Optional[str] = None,
        loss_file_name: Optional[str] = None,
    ) -> np.ndarray:
        """Return the indices that order the networks by minimal validation loss.

        Args:
            validation_subdir: Subdirectory for validation files. Defaults to None.
            loss_file_name: Name of the loss file. Defaults to None.

        Returns:
            np.ndarray: Indices that would sort the ensemble.
        """
        subdir, file_name = self.validation_file(validation_subdir, loss_file_name)
        min_losses = self.min_validation_losses(subdir, file_name)
        return np.argsort(min_losses)

    def zorder(
        self,
        validation_subdir: Optional[str] = None,
        loss_file_name: Optional[str] = None,
    ) -> np.ndarray:
        """Return a z-order per network derived from its validation-loss rank.

        The network with the lowest loss (rank 0) receives the largest value,
        `len(self)`.

        Args:
            validation_subdir: Subdirectory for validation files. Defaults to None.
            loss_file_name: Name of the loss file. Defaults to None.

        Returns:
            np.ndarray: Z-order of the ensemble.
        """
        n_networks = len(self)
        sort_order = self.argsort(validation_subdir, loss_file_name)
        ranks = sort_order.argsort()
        return n_networks - ranks

    @context_aware_cache(context=lambda self: (self.names))
    def validation_losses(
        self, subdir: Optional[str] = None, file: Optional[str] = None
    ) -> np.ndarray:
        """Collect the stored validation losses of every network.

        Args:
            subdir: Subdirectory for validation files. Defaults to None.
            file: Name of the loss file. Defaults to None.

        Returns:
            np.ndarray: Validation losses for each network.
        """
        subdir, file = self.validation_file(subdir, file)
        per_network_losses = []
        for network_view in self.values():
            per_network_losses.append(network_view.dir[subdir][file][()])
        return np.array(per_network_losses)

    def min_validation_losses(
        self, subdir: Optional[str] = None, file: Optional[str] = None
    ) -> np.ndarray:
        """Return the minimal validation loss of each network.

        One-dimensional losses are returned unchanged; otherwise the minimum
        is taken along axis 1.

        Args:
            subdir: Subdirectory for validation files. Defaults to None.
            file: Name of the loss file. Defaults to None.

        Returns:
            np.ndarray: Minimum validation losses for each network.
        """
        losses = self.validation_losses(subdir, file)
        already_reduced = losses.ndim == 1
        return losses if already_reduced else np.min(losses, axis=1)

    @contextmanager
    def rank_by_validation_error(self, reverse: bool = False):
        """Temporarily sort the ensemble based on validation error.

        Sorting is best effort: if the losses cannot be obtained, the failure
        is logged and the body runs with the current order. The original
        order is always restored on exit.

        Args:
            reverse: Whether to sort in descending order. Defaults to False.

        Yields:
            None
        """
        _names = list(self.names)

        try:
            names = self.keys()
            # Build the name -> loss lookup once instead of once per sort key.
            loss_by_name = dict(zip(names, self.min_validation_losses()))
            self.names = sorted(names, key=loss_by_name.__getitem__, reverse=reverse)
        except Exception as e:
            logging.info(f"sorting failed: {e}")
        try:
            yield
        finally:
            self.names = list(_names)

    @contextmanager
    def ratio(self, best: Optional[float] = None, worst: Optional[float] = None):
        """Temporarily filter the ensemble by a ratio of best or worst performing models.

        Inside the context, `self.names` holds the selected names ordered by
        validation error and `self.model_index` holds their positions in the
        original ensemble. Both are restored on exit.

        Args:
            best: Ratio of best performing models to keep. Defaults to None.
            worst: Ratio of worst performing models to keep. Defaults to None.

        Yields:
            None

        Raises:
            ValueError: If best and worst sum to more than 1.
        """
        # no-op
        if best is None and worst is None:
            yield
            return

        if best is not None and worst is not None and best + worst > 1:
            raise ValueError("best and worst must not add up to more than 1")

        _names = tuple(self.names)
        # `model_index` is only ever reassigned, never mutated in place, so
        # keeping a reference is enough to restore it unchanged.
        _model_index = self.model_index

        with self.rank_by_validation_error():
            try:
                n_models = len(self)
                ranked_names = list(self.names)
                n_best = int(best * n_models) if best is not None else 0
                n_worst = int(worst * n_models) if worst is not None else 0

                self._best_ratio = best if best is not None else 0
                self._worst_ratio = worst if worst is not None else 0

                # Slice from an explicit start index: `ranked_names[-0:]` would
                # select every model when the worst count rounds down to 0.
                in_context_names = [
                    *ranked_names[:n_best],
                    *ranked_names[n_models - n_worst :],
                ]

                if in_context_names:  # to prevent an empty index
                    selected = set(in_context_names)
                    self.model_index = np.array([
                        i for i, name in zip(_model_index, _names) if name in selected
                    ])
                self.names = in_context_names
                self.in_context = True
                yield
            finally:
                self.names = list(_names)
                self.model_index = _model_index
                self._best_ratio = 0.5
                self._worst_ratio = 0.5
                self.in_context = False

    @contextmanager
    def select_items(self, indices: List[int]):
        """Temporarily filter the ensemble by a list of indices.

        Inside the context, `self.names` holds the selected names and
        `self.model_index` their positions in the original ensemble. Both are
        restored on exit.

        Args:
            indices: Selection to apply. None leaves the ensemble untouched.
                Otherwise an int, a slice, or a list/array of integer
                positions, booleans, or network names.

        Yields:
            None

        Raises:
            ValueError: If indices are invalid.
        """
        # no-op: nothing was changed, so there is nothing to restore.
        if indices is None:
            yield
            return

        _names = tuple(self.names)
        # `model_index` is only ever reassigned, never mutated in place, so
        # keeping a reference is enough to restore it unchanged.
        _model_index = self.model_index
        # NOTE(review): kept from the original implementation; this overwrites
        # the `_names` attribute set in `__init__` - confirm it is intended.
        self._names = _names

        if isinstance(indices, (int, np.integer)):
            # Wrap the single name in a list: a bare string would be matched
            # by substring and iterated character by character.
            in_context_names = [self.names[indices]]
        elif isinstance(indices, slice):
            in_context_names = list(self.names[indices])
        elif isinstance(indices, (list, np.ndarray)):
            index_array = np.asarray(indices)
            # Compare dtype kinds, not exact dtypes: unicode dtypes differ by
            # string length and integer dtypes by width and platform.
            if index_array.dtype.kind == "U":
                in_context_names = index_array.tolist()
            elif index_array.dtype.kind in "iub":
                in_context_names = np.array(self.names)[index_array].tolist()
            else:
                raise ValueError(f"{indices}")
        else:
            raise ValueError(f"{indices}")

        selected = set(in_context_names)
        try:
            self.model_index = np.array([
                i for i, name in zip(_model_index, _names) if name in selected
            ])
            self.names = in_context_names
            self.in_context = True
            yield
        finally:
            self.names = list(_names)
            self.model_index = _model_index
            self.in_context = False

    def task_error(
        self,
        cmap: Union[str, Colormap] = "Blues_r",
        truncate: Optional[Dict[str, Union[float, int]]] = None,
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
    ) -> "TaskError":
        """Return a TaskError object for the ensemble.

        Args:
            cmap: Colormap to use. Defaults to "Blues_r".
            truncate: Dictionary to truncate the colormap. Defaults to None.
            vmin: Minimum value for normalization. Defaults to None, which uses
                the smallest validation loss.
            vmax: Maximum value for normalization. Defaults to None, which uses
                the largest validation loss.

        Returns:
            TaskError: Object containing validation losses, colors, colormap, norm,
                and scalar mapper.
        """
        error = self.min_validation_losses()

        if truncate is None:
            # truncate because the maxval would be white with the default colormap
            # which would be invisible on a white background
            truncate = {"minval": 0.0, "maxval": 0.9, "n": 256}
        cmap = cm.get_cmap(cmap) if isinstance(cmap, str) else cmap
        cmap = plots.plt_utils.truncate_colormap(cmap, **truncate)
        sm, norm = plots.plt_utils.get_scalarmapper(
            cmap=cmap,
            # Test against None explicitly so that a limit of 0.0 is honored.
            vmin=np.min(error) if vmin is None else vmin,
            vmax=np.max(error) if vmax is None else vmax,
        )
        colors = sm.to_rgba(np.array(error))

        return TaskError(error, colors, cmap, norm, sm)

    def parameters(self) -> Dict[str, np.ndarray]:
        """Return the parameters of the ensemble.

        For every network, the 'best' checkpoint is loaded and the entries of
        its "network" state are collected per key.

        Returns:
            Dict[str, np.ndarray]: Dictionary of parameter arrays, with the
                networks stacked along the first axis.
        """
        network_params = {}
        for network_view in self.values():
            # Values are converted to CPU arrays below anyway; mapping to CPU
            # at load time also lets checkpoints saved on a GPU load on
            # machines without one.
            # NOTE: torch.load unpickles the file - only use trusted checkpoints.
            chkpt_params = torch.load(
                network_view.network('best').checkpoint, map_location="cpu"
            )
            for key, val in chkpt_params["network"].items():
                network_params.setdefault(key, []).append(val.cpu().numpy())
        return {key: np.array(val) for key, val in network_params.items()}

    def parameter_keys(self) -> Dict[str, List[str]]:
        """Return the keys of the node and edge parameters of the ensemble.

        The keys are derived from the network config of the first network.
        Node parameters are stored under "nodes_<name>", edge parameters under
        "edges_<name>".

        Returns:
            Dict[str, List[str]]: Dictionary of parameter keys.
        """
        self.check_configs_match()
        network_view = self[0]
        config = network_view.dir.config.network

        def keys_per_parameter(param_configs):
            """Instantiate each configured parameter and collect its keys."""
            collected = {}
            for name, single_config in param_configs.items():
                param = forward_subclass(
                    Parameter,
                    config={
                        "type": single_config.type,
                        "param_config": single_config,
                        "connectome": network_view.connectome,
                    },
                )
                collected[name] = param.keys
            return collected

        parameter_keys = {}
        for param_name, keys in keys_per_parameter(config.node_config).items():
            parameter_keys[f"nodes_{param_name}"] = keys
        for param_name, keys in keys_per_parameter(config.edge_config).items():
            parameter_keys[f"edges_{param_name}"] = keys
        return parameter_keys

    @wraps(stimulus_responses.flash_responses)
    @context_aware_cache(context=lambda self: (self.names))
    def flash_responses(self, *args, **kwargs) -> xr.Dataset:
        """Generate flash responses."""
        # Thin wrapper: forwards all arguments to
        # `stimulus_responses.flash_responses` with this ensemble as first argument.
        return stimulus_responses.flash_responses(self, *args, **kwargs)

    @wraps(stimulus_responses.moving_edge_responses)
    @context_aware_cache(context=lambda self: (self.names))
    def moving_edge_responses(self, *args, **kwargs) -> xr.Dataset:
        """Generate moving edge responses."""
        # Thin wrapper: forwards all arguments to
        # `stimulus_responses.moving_edge_responses` with this ensemble as
        # first argument.
        return stimulus_responses.moving_edge_responses(self, *args, **kwargs)

    @wraps(stimulus_responses.moving_bar_responses)
    @context_aware_cache(context=lambda self: (self.names))
    def moving_bar_responses(self, *args, **kwargs) -> xr.Dataset:
        """Generate moving bar responses."""
        # Thin wrapper: forwards all arguments to
        # `stimulus_responses.moving_bar_responses` with this ensemble as
        # first argument.
        return stimulus_responses.moving_bar_responses(self, *args, **kwargs)

    @wraps(stimulus_responses.naturalistic_stimuli_responses)
    @context_aware_cache(context=lambda self: (self.names))
    def naturalistic_stimuli_responses(self, *args, **kwargs) -> xr.Dataset:
        """Generate naturalistic stimuli responses."""
        # Thin wrapper: forwards all arguments to
        # `stimulus_responses.naturalistic_stimuli_responses` with this ensemble
        # as first argument.
        return stimulus_responses.naturalistic_stimuli_responses(self, *args, **kwargs)

    @wraps(stimulus_responses.central_impulses_responses)
    @context_aware_cache(context=lambda self: (self.names))
    def central_impulses_responses(self, *args, **kwargs) -> xr.Dataset:
        """Generate central ommatidium impulses responses."""
        # Thin wrapper: forwards all arguments to
        # `stimulus_responses.central_impulses_responses` with this ensemble as
        # first argument.
        return stimulus_responses.central_impulses_responses(self, *args, **kwargs)

    @wraps(stimulus_responses.spatial_impulses_responses)
    @context_aware_cache(context=lambda self: (self.names))
    def spatial_impulses_responses(self, *args, **kwargs) -> xr.Dataset:
        """Generate spatial ommatidium impulses responses."""
        # Thin wrapper: forwards all arguments to
        # `stimulus_responses.spatial_impulses_responses` with this ensemble as
        # first argument.
        return stimulus_responses.spatial_impulses_responses(self, *args, **kwargs)

    @wraps(stimulus_responses_currents.moving_edge_currents)
    @context_aware_cache(context=lambda self: (self.names))
    def moving_edge_currents(
        self, *args, **kwargs
    ) -> List[stimulus_responses_currents.ExperimentData]:
        """Generate moving edge currents."""
        # Thin wrapper: forwards all arguments to
        # `stimulus_responses_currents.moving_edge_currents` with this ensemble
        # as first argument.
        return stimulus_responses_currents.moving_edge_currents(self, *args, **kwargs)

    @context_aware_cache
    def clustering(self, cell_type) -> GaussianMixtureClustering:
        """Return the clustering of the ensemble for a given cell type.

        A clustering stored in the ensemble directory is loaded from its
        pickle file; otherwise `compute_umap_and_clustering` is called.

        Args:
            cell_type: The cell type to cluster.

        Returns:
            GaussianMixtureClustering: Clustering object for the given cell type.

        Raises:
            ValueError: If clustering is not available in context.
        """
        # Refuse to run while the ensemble is temporarily filtered
        # (`in_context` is set by the `ratio` and `select_items` contexts).
        if self.in_context:
            raise ValueError("clustering is not available in context")

        # Compute when either check is falsy - presumably meaning that nothing
        # is stored for this cell type yet; confirm against the directory API.
        if (
            not self.dir.umap_and_clustering
            or not self.dir.umap_and_clustering[cell_type]
        ):
            return compute_umap_and_clustering(self, cell_type)

        path = self.dir.umap_and_clustering[f"{cell_type}.pickle"]
        # NOTE(security): pickle.load executes arbitrary code contained in the
        # file; only load clustering files from trusted ensemble directories.
        with open(path, "rb") as file:
            clustering = pickle.load(file)

        return clustering

    def cluster_indices(self, cell_type: str) -> Dict[int, NDArray[int]]:
        """Clusters from responses to naturalistic stimuli of the given cell type.

        Args:
            cell_type: The cell type to return the clusters for.

        Returns:
            Dict[int, NDArray[int]]: Keys are the cluster ids and the values are the
            model indices in the ensemble.

        Example:
            ```python
            ensemble = Ensemble("path/to/ensemble")
            cluster_indices = ensemble.cluster_indices("T4a")
            first_cluster = ensemble[cluster_indices[0]]
            ```

        Raises:
            ValueError: If stored clustering does not match ensemble.
        """
        clustering = self.clustering(cell_type)
        cluster_indices = get_cluster_to_indices(
            clustering.embedding.mask,
            clustering.labels,
            task_error=self.task_error(),
        )

        # Consistency check: the clustered models must be exactly the masked
        # ones, and there cannot be more of them than networks in the ensemble.
        _models = sorted(np.concatenate(list(cluster_indices.values())))
        n_clustered = len(_models)
        mask_mismatch = n_clustered != clustering.embedding.mask.sum()
        if mask_mismatch or n_clustered > len(self):
            raise ValueError("stored clustering does not match ensemble")

        return cluster_indices

    def responses_norm(self, rectified: bool = False) -> np.ndarray:
        """Compute the norm of responses to naturalistic stimuli.

        Args:
            rectified: Whether to rectify responses before computing norm.

        Returns:
            np.ndarray: Norm of responses for each network.
        """
        response_set = self.naturalistic_stimuli_responses()
        responses = response_set['responses'].values

        def compute_norm(X, rectified=True):
            """Compute a normalization constant for stimulus responses per
            model and cell type.

            Args:
                X: Array with four axes that are unpacked as
                    (n_models, n_samples, n_frames, n_cell_types).
            """
            if rectified:
                X = np.maximum(X, 0)
            n_models, n_samples, n_frames, n_cell_types = X.shape

            # Replace NaNs with 0 on a copy: `X` may be the array backing the
            # cached response dataset, which must not be modified in place.
            X = np.where(np.isnan(X), 0, X)

            return (
                1
                / np.sqrt(n_samples * n_frames)
                * np.linalg.norm(
                    X,
                    axis=(1, 2),
                    keepdims=True,
                )
            )

        # Restrict the result to the models that are currently selected.
        return np.take(
            compute_norm(responses, rectified=rectified), self.model_index, axis=0
        )

__getitem__

__getitem__(key)

Get item from the ensemble.

Parameters:

Name Type Description Default
key Union[str, int, slice, NDArray, list]

Key to access the ensemble. Can be a string, int, slice, NDArray, or list.

required

Returns:

Type Description
Union[NetworkView, 'Ensemble']

NetworkView or Ensemble: The requested item or subset of the ensemble.

Raises:

Type Description
ValueError

If the key is invalid.

Source code in flyvision/network/ensemble.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def __getitem__(
    self, key: Union[str, int, slice, NDArray, list]
) -> Union[NetworkView, "Ensemble"]:
    """Look up a single network or build a sub-ensemble.

    Args:
        key: An integer position or a network name returns one NetworkView.
            A slice, list, or array selects names and returns a new
            ensemble constructed from them.

    Returns:
        NetworkView or Ensemble: The requested item or subset of the ensemble.

    Raises:
        ValueError: If the key is invalid.
    """
    # Single network addressed by its position in `self.names`.
    if isinstance(key, (int, np.integer)):
        name = self.names[key]
        return dict.__getitem__(self, name)

    # Subsets are rebuilt from the selected names.
    if isinstance(key, slice):
        selected_names = self.names[key]
        return self.__class__(selected_names)
    if isinstance(key, (np.ndarray, list)):
        selected_names = np.array(self.names)[key]
        return self.__class__(selected_names)

    # Single network addressed by name.
    if key in self.names:
        return dict.__getitem__(self, key)

    raise ValueError(f"{key}")

__repr__

__repr__()

Return a string representation of the Ensemble.

Source code in flyvision/network/ensemble.py
211
212
213
def __repr__(self) -> str:
    """Return `ClassName(path)` for this ensemble."""
    class_name = type(self).__name__
    return f"{class_name}({self.path})"

__dir__

__dir__()

Return a list of attributes for the Ensemble.

Source code in flyvision/network/ensemble.py
215
216
217
def __dir__(self) -> List[str]:
    """Return the dict attributes together with the stored network names."""
    # `dict.__iter__` is used directly because `__iter__` is overridden on
    # the ensemble class to follow `self.names`.
    entries = set(dict.__dir__(self))
    entries.update(dict.__iter__(self))
    return list(entries)

__len__

__len__()

Return the number of networks in the Ensemble.

Source code in flyvision/network/ensemble.py
219
220
221
def __len__(self) -> int:
    """Return how many network names are currently active."""
    active_names = self.names
    return len(active_names)

__iter__

__iter__()

Iterate over the names of the networks in the Ensemble.

Source code in flyvision/network/ensemble.py
223
224
225
def __iter__(self) -> Iterator[str]:
    """Yield each entry of `self.names` in order."""
    for name in self.names:
        yield name

items

items()

Return an iterator over the (name, NetworkView) pairs in the Ensemble.

Source code in flyvision/network/ensemble.py
227
228
229
def items(self) -> Iterator[Tuple[str, NetworkView]]:
    """Lazily pair every key of the ensemble with the entry it indexes."""
    pairs = ((name, self[name]) for name in self)
    return pairs

keys

keys()

Return a list of network names in the Ensemble.

Source code in flyvision/network/ensemble.py
231
232
233
def keys(self) -> List[str]:
    """Collect everything produced by iterating the ensemble into a list."""
    return [name for name in self]

values

values()

Return a list of NetworkViews in the Ensemble.

Source code in flyvision/network/ensemble.py
235
236
237
def values(self) -> List[NetworkView]:
    """Collect the entry indexed by each key of the ensemble, in order."""
    entries = []
    for name in self:
        entries.append(self[name])
    return entries

check_configs_match

check_configs_match()

Check if the configurations of the networks in the ensemble match.

Returns:

Name Type Description
bool bool

True if all configurations match, False otherwise.

Source code in flyvision/network/ensemble.py
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
def check_configs_match(self) -> bool:
    """Compare every network's config against the first network's config.

    A diff with exactly one entry that contains "network_name" is tolerated.
    The first other non-empty diff is logged as a warning and ends the check.

    Returns:
        bool: True if no disqualifying difference was found, False otherwise.
    """
    reference = self[0]
    reference_config = reference.dir.config
    for index in range(1, len(self)):
        candidate = self[index]
        diff = reference_config.diff(
            candidate.dir.config, name1="first", name2="second"
        ).first
        if not diff:
            continue
        # A lone entry mentioning the network name is not counted as a mismatch.
        only_name_differs = len(diff) == 1 and "network_name" in diff[0]
        if only_name_differs:
            continue
        logging.warning(
            "%(first)s differs from %(second)s. Diff is %(diff)s.",
            {"first": reference.name, "second": candidate.name, "diff": diff},
        )
        return False
    return True

yield_networks

yield_networks()

Yield initialized networks from the ensemble.

Yields:

Name Type Description
Network Network

Initialized network from the ensemble.

Source code in flyvision/network/ensemble.py
265
266
267
268
269
270
271
272
273
274
def yield_networks(self) -> Generator[Network, None, None]:
    """Yield one initialized network per entry of the ensemble.

    The first entry's `init_network` is called without arguments; every
    following entry's `init_network` receives that first network through the
    `network` keyword.

    Yields:
        Network: Initialized network from the ensemble.
    """
    first_network = self[0].init_network()
    yield first_network
    remaining_views = self.values()[1:]
    for view in remaining_views:
        yield view.init_network(network=first_network)

yield_decoders

yield_decoders()

Yield initialized decoders from the ensemble.

Yields:

Type Description
Module

nn.Module: Initialized decoder from the ensemble.

Source code in flyvision/network/ensemble.py
276
277
278
279
280
281
282
283
284
285
def yield_decoders(self) -> Generator[nn.Module, None, None]:
    """Yield one initialized decoder per entry of the ensemble.

    The decoder obtained from the first entry is handed to every entry's
    `init_decoder` (including the first) through the `decoder` keyword.

    Yields:
        nn.Module: Initialized decoder from the ensemble.

    Raises:
        AssertionError: If `check_configs_match` reports a mismatch.
    """
    configs_match = self.check_configs_match()
    assert configs_match, "configurations do not match"
    shared_decoder = self[0].init_decoder()
    views = self.values()
    for view in views:
        yield view.init_decoder(decoder=shared_decoder)

simulate

simulate(movie_input, dt, fade_in=True)

Simulate the ensemble activity from movie input.

Parameters:

Name Type Description Default
movie_input Tensor

Tensor with shape (batch_size, n_frames, 1, hexals).

required
dt float

Integration time constant. Warns if dt > 1/50.

required
fade_in bool

Whether to use network.fade_in_state to compute the initial state. Defaults to True. If False, uses the network.steady_state after 1s of grey input.

True

Yields:

Type Description
ndarray

np.ndarray: Response of each individual network.

Note

Simulates across batch_size in parallel, which can easily lead to OOM for large batch sizes.

Source code in flyvision/network/ensemble.py
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
def simulate(
    self, movie_input: torch.Tensor, dt: float, fade_in: bool = True
) -> Generator[np.ndarray, None, None]:
    """Run every network of the ensemble on the same movie input.

    Args:
        movie_input: Tensor with shape (batch_size, n_frames, 1, hexals).
        dt: Integration time constant. Warns if dt > 1/50.
        fade_in: If True, the initial state is computed with
            `network.fade_in_state(1.0, dt, movie_input[:, 0])`. If False,
            the string "auto" is passed as initial state, which uses the
            `network.steady_state` after 1s of grey input.

    Yields:
        np.ndarray: Response of each individual network.

    Note:
        Simulates across batch_size in parallel, which can easily lead to OOM for
        large batch sizes.
    """
    networks = tqdm(
        self.yield_networks(),
        desc="Simulating network",
        total=len(self.names),
    )
    for network in networks:
        if fade_in:
            initial_state = network.fade_in_state(1.0, dt, movie_input[:, 0])
        else:
            initial_state = "auto"
        response = network.simulate(movie_input, dt, initial_state=initial_state)
        # Hand results back as NumPy arrays on the host.
        yield response.cpu().numpy()

simulate_from_dataset

simulate_from_dataset(
    dataset,
    dt,
    indices=None,
    t_pre=1.0,
    t_fade_in=0.0,
    default_stim_key="lum",
    batch_size=1,
    central_cell_only=True,
)

Simulate the ensemble activity from a dataset.

Parameters:

Name Type Description Default
dataset

Dataset to simulate from.

required
dt float

Integration time constant.

required
indices Iterable[int]

Indices of stimuli to simulate. Defaults to None (all stimuli).

None
t_pre float

Time before stimulus onset. Defaults to 1.0.

1.0
t_fade_in float

Fade-in time. Defaults to 0.0.

0.0
default_stim_key str

Default stimulus key. Defaults to “lum”.

'lum'
batch_size int

Batch size for simulation. Defaults to 1.

1
central_cell_only bool

Whether to return only central cells. Defaults to True.

True

Yields:

Type Description
ndarray

np.ndarray: Simulated responses for each network.

Source code in flyvision/network/ensemble.py
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
def simulate_from_dataset(
    self,
    dataset,
    dt: float,
    indices: Iterable[int] = None,
    t_pre: float = 1.0,
    t_fade_in: float = 0.0,
    default_stim_key: str = "lum",
    batch_size: int = 1,
    central_cell_only: bool = True,
) -> Generator[np.ndarray, None, None]:
    """Run every network of the ensemble on stimuli taken from a dataset.

    Args:
        dataset: Dataset to simulate from.
        dt: Integration time constant.
        indices: Indices of stimuli to simulate. Defaults to None (all stimuli).
        t_pre: Time before stimulus onset. Defaults to 1.0.
        t_fade_in: Fade-in time. Defaults to 0.0.
        default_stim_key: Default stimulus key. Defaults to "lum".
        batch_size: Batch size for simulation. Defaults to 1.
        central_cell_only: Whether to return only central cells. Defaults to True.

    Yields:
        np.ndarray: Simulated responses for each network, with all leading
            axes merged into one.
    """
    if central_cell_only:
        central_cells_index = self[0].connectome.central_cells_index[:]

    progress_bar = tqdm(desc="Simulating network", total=len(self.names))

    def collect_responses(network: Network):
        # Yield the per-batch responses of one network, restricted to the
        # central cells when requested.
        batches = network.stimulus_response(
            dataset,
            dt=dt,
            indices=indices,
            t_pre=t_pre,
            t_fade_in=t_fade_in,
            default_stim_key=default_stim_key,
            batch_size=batch_size,
        )
        for _, resp in batches:
            yield resp[:, :, central_cells_index] if central_cell_only else resp

    for network in self.yield_networks():
        stacked = np.stack(list(collect_responses(network)))
        # Merge the stacking axis with the leading response axes.
        yield stacked.reshape(-1, stacked.shape[-2], stacked.shape[-1])

        progress_bar.update(1)
    progress_bar.close()

decode

decode(movie_input, dt)

Decode the ensemble responses with the ensemble decoders.

Parameters:

Name Type Description Default
movie_input Tensor

Input movie tensor.

required
dt float

Integration time constant.

required

Yields:

Type Description
ndarray

np.ndarray: Decoded responses for each network.

Source code in flyvision/network/ensemble.py
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
def decode(
    self, movie_input: torch.Tensor, dt: float
) -> Generator[np.ndarray, None, None]:
    """Decode the ensemble responses with the ensemble decoders.

    All networks are simulated first and their responses collected into one
    tensor; decoder `i` is then applied to response `i`.

    Args:
        movie_input: Input movie tensor.
        dt: Integration time constant.

    Yields:
        np.ndarray: Decoded responses for each network.
    """
    # Stack into a single ndarray before handing over to torch: building a
    # tensor directly from a Python list of ndarrays is slow and makes
    # PyTorch emit a warning.
    responses = torch.tensor(np.array(list(self.simulate(movie_input, dt))))
    for i, decoder in enumerate(self.yield_decoders()):
        with simulation(decoder):
            yield decoder(responses[i]).cpu().numpy()

validation_file

validation_file(
    validation_subdir=None, loss_file_name=None
)

Return the validation subdirectory and loss file name, taking any value not given from the first network's best-checkpoint settings.

Parameters:

Name Type Description Default
validation_subdir Optional[str]

Subdirectory for validation files. Defaults to None.

None
loss_file_name Optional[str]

Name of the loss file. Defaults to None.

None

Returns:

Type Description
Tuple[str, str]

Tuple[str, str]: Validation subdirectory and loss file name.

Source code in flyvision/network/ensemble.py
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
def validation_file(
    self,
    validation_subdir: Optional[str] = None,
    loss_file_name: Optional[str] = None,
) -> Tuple[str, str]:
    """Resolve the validation subdirectory and loss file name.

    Arguments left as None are looked up in the `best_checkpoint_fn_kwargs`
    of the first entry of the ensemble (keys "validation_subdir" and
    "loss_file_name").

    Args:
        validation_subdir: Subdirectory for validation files. Defaults to None.
        loss_file_name: Name of the loss file. Defaults to None.

    Returns:
        Tuple[str, str]: Validation subdirectory and loss file name.
    """
    first_view = self[0]

    subdir = validation_subdir
    if subdir is None:
        subdir = first_view.best_checkpoint_fn_kwargs.get("validation_subdir")

    file_name = loss_file_name
    if file_name is None:
        file_name = first_view.best_checkpoint_fn_kwargs.get("loss_file_name")

    return subdir, file_name

sort

sort(validation_subdir=None, loss_file_name=None)

Sort the ensemble by validation loss.

Parameters:

Name Type Description Default
validation_subdir Optional[str]

Subdirectory for validation files. Defaults to None.

None
loss_file_name Optional[str]

Name of the loss file. Defaults to None.

None
Source code in flyvision/network/ensemble.py
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
def sort(
    self,
    validation_subdir: Optional[str] = None,
    loss_file_name: Optional[str] = None,
) -> None:
    """Sort the ensemble by validation loss, in ascending order.

    Reorders `self.names` in place. Sorting is best effort: if anything
    fails, the failure is logged at info level and `self.names` is left
    unchanged.

    Args:
        validation_subdir: Subdirectory for validation files. Defaults to None.
        loss_file_name: Name of the loss file. Defaults to None.
    """
    try:
        names = self.keys()
        # Build the name -> loss lookup once instead of once per sort key.
        loss_by_name = dict(
            zip(names, self.min_validation_losses(validation_subdir, loss_file_name))
        )
        self.names = sorted(names, key=loss_by_name.__getitem__, reverse=False)
    except Exception as e:
        # Deliberately broad: a failed sort must not break the caller.
        logging.info(f"sorting failed: {e}")

argsort

argsort(validation_subdir=None, loss_file_name=None)

Return the indices that would sort the ensemble by validation loss.

Parameters:

Name Type Description Default
validation_subdir Optional[str]

Subdirectory for validation files. Defaults to None.

None
loss_file_name Optional[str]

Name of the loss file. Defaults to None.

None

Returns:

Type Description
ndarray

np.ndarray: Indices that would sort the ensemble.

Source code in flyvision/network/ensemble.py
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
def argsort(
    self,
    validation_subdir: Optional[str] = None,
    loss_file_name: Optional[str] = None,
) -> np.ndarray:
    """Compute the ordering of the ensemble by ascending validation loss.

    Args:
        validation_subdir: Subdirectory for validation files. Defaults to None.
        loss_file_name: Name of the loss file. Defaults to None.

    Returns:
        np.ndarray: Indices that would sort the ensemble.
    """
    resolved = self.validation_file(validation_subdir, loss_file_name)
    losses = self.min_validation_losses(*resolved)
    return np.argsort(losses)

zorder

zorder(validation_subdir=None, loss_file_name=None)

Return the z-order of the ensemble based on validation loss.

Parameters:

Name Type Description Default
validation_subdir Optional[str]

Subdirectory for validation files. Defaults to None.

None
loss_file_name Optional[str]

Name of the loss file. Defaults to None.

None

Returns:

Type Description
ndarray

np.ndarray: Z-order of the ensemble.

Source code in flyvision/network/ensemble.py
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
def zorder(
    self,
    validation_subdir: Optional[str] = None,
    loss_file_name: Optional[str] = None,
) -> np.ndarray:
    """Compute a z-order for the ensemble from the validation loss ranking.

    The value for each network is `len(self)` minus its rank in the
    ascending loss ordering, so lower losses receive larger values.

    Args:
        validation_subdir: Subdirectory for validation files. Defaults to None.
        loss_file_name: Name of the loss file. Defaults to None.

    Returns:
        np.ndarray: Z-order of the ensemble.
    """
    order = self.argsort(validation_subdir, loss_file_name)
    # Applying argsort to the sorting indices turns them into ranks.
    ranks = order.argsort()
    return len(self) - ranks

validation_losses

validation_losses(subdir=None, file=None)

Return a list of validation losses for each network in the ensemble.

Parameters:

Name Type Description Default
subdir Optional[str]

Subdirectory for validation files. Defaults to None.

None
file Optional[str]

Name of the loss file. Defaults to None.

None

Returns:

Type Description
ndarray

np.ndarray: Validation losses for each network.

Source code in flyvision/network/ensemble.py
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
@context_aware_cache(context=lambda self: (self.names))
def validation_losses(
    self, subdir: Optional[str] = None, file: Optional[str] = None
) -> np.ndarray:
    """Read the validation losses of every network in the ensemble.

    The subdirectory and file name are resolved with `validation_file`, then
    each network's `dir[subdir][file]` is read in full.

    Args:
        subdir: Subdirectory for validation files. Defaults to None.
        file: Name of the loss file. Defaults to None.

    Returns:
        np.ndarray: Validation losses for each network.
    """
    subdir, file = self.validation_file(subdir, file)
    per_network = []
    for network_view in self.values():
        per_network.append(network_view.dir[subdir][file][()])
    return np.array(per_network)

min_validation_losses

min_validation_losses(subdir=None, file=None)

Return the minimum validation loss of each network in the ensemble.

Parameters:

Name Type Description Default
subdir Optional[str]

Subdirectory for validation files. Defaults to None.

None
file Optional[str]

Name of the loss file. Defaults to None.

None

Returns:

Type Description
ndarray

np.ndarray: Minimum validation losses for each network.

Source code in flyvision/network/ensemble.py
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
def min_validation_losses(
    self, subdir: Optional[str] = None, file: Optional[str] = None
) -> np.ndarray:
    """Reduce the validation losses to one minimum per network.

    One-dimensional losses are returned unchanged; otherwise the minimum is
    taken along axis 1.

    Args:
        subdir: Subdirectory for validation files. Defaults to None.
        file: Name of the loss file. Defaults to None.

    Returns:
        np.ndarray: Minimum validation losses for each network.
    """
    losses = self.validation_losses(subdir, file)
    return losses if losses.ndim == 1 else np.min(losses, axis=1)

rank_by_validation_error

rank_by_validation_error(reverse=False)

Temporarily sort the ensemble based on validation error.

Parameters:

Name Type Description Default
reverse bool

Whether to sort in descending order. Defaults to False.

False

Yields:

Type Description

None

Source code in flyvision/network/ensemble.py
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
@contextmanager
def rank_by_validation_error(self, reverse: bool = False):
    """Temporarily sort the ensemble based on validation error.

    Inside the context `self.names` is ordered by minimum validation loss.
    On exit the previous order is restored. Sorting is best effort: if it
    fails, the failure is logged at info level and the context is entered
    with the order unchanged.

    Args:
        reverse: Whether to sort in descending order. Defaults to False.

    Yields:
        None
    """
    _names = deepcopy(self.names)

    try:
        names = self.keys()
        # Build the name -> loss lookup once instead of once per sort key.
        loss_by_name = dict(zip(names, self.min_validation_losses()))
        self.names = sorted(names, key=loss_by_name.__getitem__, reverse=reverse)
    except Exception as e:
        # Deliberately broad: a failed sort must not prevent entering the context.
        logging.info(f"sorting failed: {e}")
    try:
        yield
    finally:
        self.names = list(_names)

ratio

ratio(best=None, worst=None)

Temporarily filter the ensemble by a ratio of best or worst performing models.

Parameters:

Name Type Description Default
best Optional[float]

Ratio of best performing models to keep. Defaults to None.

None
worst Optional[float]

Ratio of worst performing models to keep. Defaults to None.

None

Yields:

Type Description

None

Raises:

Type Description
ValueError

If best and worst sum to more than 1.

Source code in flyvision/network/ensemble.py
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
@contextmanager
def ratio(self, best: Optional[float] = None, worst: Optional[float] = None):
    """Temporarily filter the ensemble by a ratio of best or worst performing models.

    Inside the context `self.names` holds the selected names in ranked order
    (best first, then worst), `self.model_index` holds the matching entries
    of the previous model index, and `self.in_context` is True. On exit the
    previous names and model index are restored, the best/worst ratios are
    reset to 0.5, and `self.in_context` is set to False.

    The number of selected models is `int(ratio * len(self))`; when this
    rounds down to zero, no model is selected for that side.

    Args:
        best: Ratio of best performing models to keep. Defaults to None.
        worst: Ratio of worst performing models to keep. Defaults to None.

    Yields:
        None

    Raises:
        ValueError: If best and worst sum to more than 1.
    """
    # no-op
    if best is None and worst is None:
        yield
        return

    # Validate before touching any state or sorting.
    if best is not None and worst is not None and best + worst > 1:
        raise ValueError("best and worst must add up to at most 1")

    _names = tuple(self.names)
    _model_index = tuple(self.model_index)

    with self.rank_by_validation_error():
        n_models = len(self)
        n_best = int(best * n_models) if best is not None else 0
        n_worst = int(worst * n_models) if worst is not None else 0

        _context_best_names = list(self.names[:n_best])
        # Guard the zero case explicitly: names[-0:] would select every model.
        _context_worst_names = list(self.names[-n_worst:]) if n_worst > 0 else []

        self._best_ratio = best if best is not None else 0
        self._worst_ratio = worst if worst is not None else 0

        in_context_names = [*_context_best_names, *_context_worst_names]

        if in_context_names:  # to prevent an empty index
            self.model_index = np.array([
                i
                for i, name in zip(_model_index, _names)
                if name in in_context_names
            ])
        self.names = in_context_names
        self.in_context = True
        try:
            yield
        finally:
            self.names = list(_names)
            self.model_index = _model_index
            self._best_ratio = 0.5
            self._worst_ratio = 0.5
            self.in_context = False

select_items

select_items(indices)

Temporarily filter the ensemble by a list of indices.

Parameters:

Name Type Description Default
indices List[int]

List of indices to select.

required

Yields:

Type Description

None

Raises:

Type Description
ValueError

If indices are invalid.

Source code in flyvision/network/ensemble.py
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
@contextmanager
def select_items(self, indices: List[int]):
    """Temporarily filter the ensemble by a list of indices.

    Inside the context `self.names` holds the selected names,
    `self.model_index` holds the matching entries of the previous model
    index, and `self.in_context` is True. On exit the previous names and
    model index are restored as lists and `self.in_context` is set to False.
    A snapshot of the previous names is also stored on `self._names`.

    Args:
        indices: Selection to apply. None leaves the ensemble untouched. An
            int or slice indexes `self.names`; a list or array may contain
            either integer positions or names.

    Yields:
        None

    Raises:
        ValueError: If indices are invalid.
    """
    # no-op: nothing is modified, so there is nothing to restore afterwards
    if indices is None:
        yield
        return

    _names = tuple(self.names)
    _model_index = tuple(self.model_index)
    self._names = _names

    try:
        if isinstance(indices, (int, np.integer)):
            # Keep a one-element list; a bare string would turn the
            # membership test below into a substring test.
            in_context_names = [self.names[indices]]
        elif isinstance(indices, slice):
            in_context_names = self.names[indices]
        elif isinstance(indices, (list, np.ndarray)):
            selection_dtype = np.array(indices).dtype
            names_dtype = np.array(self.names).dtype
            # Compare string dtypes by kind: the exact dtype also encodes the
            # maximum string length, which can differ for a subset of names.
            both_strings = (
                selection_dtype.kind == names_dtype.kind
                and selection_dtype.kind in "US"
            )
            if selection_dtype == names_dtype or both_strings:
                in_context_names = indices
            elif np.issubdtype(selection_dtype, np.integer):
                in_context_names = np.array(self.names)[indices]
            else:
                raise ValueError(f"{indices}")
        else:
            raise ValueError(f"{indices}")
        self.model_index = np.array([
            i for i, name in zip(_model_index, _names) if name in in_context_names
        ])
        self.names = in_context_names
        self.in_context = True
        yield
    finally:
        self.names = list(_names)
        self.model_index = list(_model_index)
        self.in_context = False

task_error

task_error(
    cmap="Blues_r", truncate=None, vmin=None, vmax=None
)

Return a TaskError object for the ensemble.

Parameters:

Name Type Description Default
cmap Union[str, Colormap]

Colormap to use. Defaults to “Blues_r”.

'Blues_r'
truncate Optional[Dict[str, Union[float, int]]]

Dictionary to truncate the colormap. Defaults to None.

None
vmin Optional[float]

Minimum value for normalization. Defaults to None.

None
vmax Optional[float]

Maximum value for normalization. Defaults to None.

None

Returns:

Name Type Description
TaskError 'TaskError'

Object containing validation losses, colors, colormap, norm, and scalar mapper.

Source code in flyvision/network/ensemble.py
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
def task_error(
    self,
    cmap: Union[str, Colormap] = "Blues_r",
    truncate: Optional[Dict[str, Union[float, int]]] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
) -> "TaskError":
    """Return a TaskError object for the ensemble.

    The error values are the minimum validation losses of the ensemble. They
    are mapped to colors with a truncated colormap.

    Args:
        cmap: Colormap to use. Defaults to "Blues_r".
        truncate: Keyword arguments for truncating the colormap. Defaults to
            None, which uses `{"minval": 0.0, "maxval": 0.9, "n": 256}`.
        vmin: Minimum value for normalization. Defaults to None, which uses
            the smallest error.
        vmax: Maximum value for normalization. Defaults to None, which uses
            the largest error.

    Returns:
        TaskError: Object containing validation losses, colors, colormap, norm,
            and scalar mapper.
    """
    error = self.min_validation_losses()

    if truncate is None:
        # truncate because the maxval would be white with the default colormap
        # which would be invisible on a white background
        truncate = {"minval": 0.0, "maxval": 0.9, "n": 256}
    # NOTE(review): if `cm` is matplotlib.cm, `get_cmap` is deprecated in
    # recent Matplotlib releases - confirm against the pinned version.
    cmap = cm.get_cmap(cmap) if isinstance(cmap, str) else cmap
    cmap = plots.plt_utils.truncate_colormap(cmap, **truncate)
    # Test against None explicitly so that an explicit limit of 0 is kept
    # rather than being replaced by the data-derived limit.
    sm, norm = plots.plt_utils.get_scalarmapper(
        cmap=cmap,
        vmin=np.min(error) if vmin is None else vmin,
        vmax=np.max(error) if vmax is None else vmax,
    )
    colors = sm.to_rgba(np.array(error))

    return TaskError(error, colors, cmap, norm, sm)

parameters

parameters()

Return the parameters of the ensemble.

Returns:

Type Description
Dict[str, ndarray]

Dict[str, np.ndarray]: Dictionary of parameter arrays.

Source code in flyvision/network/ensemble.py
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
def parameters(self) -> Dict[str, np.ndarray]:
    """Collect the network parameters stored in each 'best' checkpoint.

    For every network the checkpoint of `network('best')` is loaded and the
    entries under its "network" key are gathered per parameter name.

    Returns:
        Dict[str, np.ndarray]: Dictionary of parameter arrays, with one entry
            per network along the first axis.
    """
    collected = {}
    for network_view in self.values():
        checkpoint = network_view.network('best').checkpoint
        chkpt_params = torch.load(checkpoint)
        for key, val in chkpt_params["network"].items():
            collected.setdefault(key, []).append(val.cpu().numpy())
    return {key: np.array(vals) for key, vals in collected.items()}

parameter_keys

parameter_keys()

Return the keys of the parameters of the ensemble.

Returns:

Type Description
Dict[str, List[str]]

Dict[str, List[str]]: Dictionary of parameter keys.

Source code in flyvision/network/ensemble.py
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
def parameter_keys(self) -> Dict[str, List[str]]:
    """Collect the keys of every node and edge parameter of the ensemble.

    The parameters are instantiated from the network config and connectome
    of the first network. Node parameters are stored under
    "nodes_<name>", edge parameters under "edges_<name>".

    Returns:
        Dict[str, List[str]]: Dictionary of parameter keys.
    """
    self.check_configs_match()
    network_view = self[0]
    config = network_view.dir.config.network

    parameter_keys = {}
    # Node parameters first, then edge parameters.
    sections = (("nodes", config.node_config), ("edges", config.edge_config))
    for prefix, section in sections:
        for param_name, param_config in section.items():
            param = forward_subclass(
                Parameter,
                config={
                    "type": param_config.type,
                    "param_config": param_config,
                    "connectome": network_view.connectome,
                },
            )
            parameter_keys[f"{prefix}_{param_name}"] = param.keys
    return parameter_keys

flash_responses

flash_responses(*args, **kwargs)

Generate flash responses.

Source code in flyvision/network/ensemble.py
728
729
730
731
732
@wraps(stimulus_responses.flash_responses)
@context_aware_cache(context=lambda self: (self.names))
def flash_responses(self, *args, **kwargs) -> xr.Dataset:
    """Forward to `stimulus_responses.flash_responses` for this ensemble.

    The ensemble is passed as first argument, followed by all positional and
    keyword arguments.
    """
    compute = stimulus_responses.flash_responses
    return compute(self, *args, **kwargs)

moving_edge_responses

moving_edge_responses(*args, **kwargs)

Generate moving edge responses.

Source code in flyvision/network/ensemble.py
734
735
736
737
738
@wraps(stimulus_responses.moving_edge_responses)
@context_aware_cache(context=lambda self: (self.names))
def moving_edge_responses(self, *args, **kwargs) -> xr.Dataset:
    """Forward to `stimulus_responses.moving_edge_responses` for this ensemble.

    The ensemble is passed as first argument, followed by all positional and
    keyword arguments.
    """
    compute = stimulus_responses.moving_edge_responses
    return compute(self, *args, **kwargs)

moving_bar_responses

moving_bar_responses(*args, **kwargs)

Generate moving bar responses.

Source code in flyvision/network/ensemble.py
740
741
742
743
744
@wraps(stimulus_responses.moving_bar_responses)
@context_aware_cache(context=lambda self: (self.names))
def moving_bar_responses(self, *args, **kwargs) -> xr.Dataset:
    """Forward to `stimulus_responses.moving_bar_responses` for this ensemble.

    The ensemble is passed as first argument, followed by all positional and
    keyword arguments.
    """
    compute = stimulus_responses.moving_bar_responses
    return compute(self, *args, **kwargs)

naturalistic_stimuli_responses

naturalistic_stimuli_responses(*args, **kwargs)

Generate naturalistic stimuli responses.

Source code in flyvision/network/ensemble.py
746
747
748
749
750
@wraps(stimulus_responses.naturalistic_stimuli_responses)
@context_aware_cache(context=lambda self: (self.names))
def naturalistic_stimuli_responses(self, *args, **kwargs) -> xr.Dataset:
    """Forward to `stimulus_responses.naturalistic_stimuli_responses`.

    The ensemble is passed as first argument, followed by all positional and
    keyword arguments.
    """
    compute = stimulus_responses.naturalistic_stimuli_responses
    return compute(self, *args, **kwargs)

central_impulses_responses

central_impulses_responses(*args, **kwargs)

Generate central ommatidium impulses responses.

Source code in flyvision/network/ensemble.py
752
753
754
755
756
@wraps(stimulus_responses.central_impulses_responses)
@context_aware_cache(context=lambda self: (self.names))
def central_impulses_responses(self, *args, **kwargs) -> xr.Dataset:
    """Forward to `stimulus_responses.central_impulses_responses`.

    The ensemble is passed as first argument, followed by all positional and
    keyword arguments.
    """
    compute = stimulus_responses.central_impulses_responses
    return compute(self, *args, **kwargs)

spatial_impulses_responses

spatial_impulses_responses(*args, **kwargs)

Generate spatial ommatidium impulses responses.

Source code in flyvision/network/ensemble.py
758
759
760
761
762
@wraps(stimulus_responses.spatial_impulses_responses)
@context_aware_cache(context=lambda self: (self.names))
def spatial_impulses_responses(self, *args, **kwargs) -> xr.Dataset:
    """Forward to `stimulus_responses.spatial_impulses_responses`.

    The ensemble is passed as first argument, followed by all positional and
    keyword arguments.
    """
    compute = stimulus_responses.spatial_impulses_responses
    return compute(self, *args, **kwargs)

moving_edge_currents

moving_edge_currents(*args, **kwargs)

Generate moving edge currents.

Source code in flyvision/network/ensemble.py
764
765
766
767
768
769
770
@wraps(stimulus_responses_currents.moving_edge_currents)
@context_aware_cache(context=lambda self: (self.names))
def moving_edge_currents(
    self, *args, **kwargs
) -> List[stimulus_responses_currents.ExperimentData]:
    """Forward to `stimulus_responses_currents.moving_edge_currents`.

    The ensemble is passed as first argument, followed by all positional and
    keyword arguments.
    """
    compute = stimulus_responses_currents.moving_edge_currents
    return compute(self, *args, **kwargs)

clustering

clustering(cell_type)

Return the clustering of the ensemble for a given cell type.

Parameters:

Name Type Description Default
cell_type

The cell type to cluster.

required

Returns:

Name Type Description
GaussianMixtureClustering GaussianMixtureClustering

Clustering object for the given cell type.

Raises:

Type Description
ValueError

If clustering is not available in context.

Source code in flyvision/network/ensemble.py
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
@context_aware_cache
def clustering(self, cell_type) -> GaussianMixtureClustering:
    """Load or compute the clustering of the ensemble for a cell type.

    If the ensemble directory has no stored clustering for the cell type,
    it is computed with `compute_umap_and_clustering`. Otherwise the pickled
    clustering at `<cell_type>.pickle` is loaded.

    Args:
        cell_type: The cell type to cluster.

    Returns:
        GaussianMixtureClustering: Clustering object for the given cell type.

    Raises:
        ValueError: If clustering is not available in context.
    """
    if self.in_context:
        raise ValueError("clustering is not available in context")

    has_stored = bool(self.dir.umap_and_clustering) and bool(
        self.dir.umap_and_clustering[cell_type]
    )
    if not has_stored:
        return compute_umap_and_clustering(self, cell_type)

    pickle_path = self.dir.umap_and_clustering[f"{cell_type}.pickle"]
    with open(pickle_path, "rb") as handle:
        return pickle.load(handle)

cluster_indices

cluster_indices(cell_type)

Clusters from responses to naturalistic stimuli of the given cell type.

Parameters:

Name Type Description Default
cell_type str

The cell type to return the clusters for.

required

Returns:

Type Description
Dict[int, NDArray[int]]

Dict[int, NDArray[int]]: Keys are the cluster ids and the values are the model indices in the ensemble.

Example
ensemble = Ensemble("path/to/ensemble")
cluster_indices = ensemble.cluster_indices("T4a")
first_cluster = ensemble[cluster_indices[0]]

Raises:

Type Description
ValueError

If stored clustering does not match ensemble.

Source code in flyvision/network/ensemble.py
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
def cluster_indices(self, cell_type: str) -> Dict[int, NDArray[int]]:
    """Group the ensemble's models by cluster for the given cell type.

    The clusters come from the clustering of responses to naturalistic
    stimuli returned by `self.clustering(cell_type)`.

    Args:
        cell_type: The cell type to return the clusters for.

    Returns:
        Dict[int, NDArray[int]]: Keys are the cluster ids and the values are the
        model indices in the ensemble.

    Example:
        ```python
        ensemble = Ensemble("path/to/ensemble")
        cluster_indices = ensemble.cluster_indices("T4a")
        first_cluster = ensemble[cluster_indices[0]]
        ```

    Raises:
        ValueError: If stored clustering does not match ensemble.
    """
    stored = self.clustering(cell_type)
    mask = stored.embedding.mask
    indices_by_cluster = get_cluster_to_indices(
        mask,
        stored.labels,
        task_error=self.task_error(),
    )

    # Sanity check: the number of clustered models has to equal the number of
    # entries selected by the embedding mask and cannot exceed the ensemble size.
    clustered_models = sorted(np.concatenate(list(indices_by_cluster.values())))
    n_clustered = len(clustered_models)
    is_consistent = n_clustered == mask.sum() and n_clustered <= len(self)
    if not is_consistent:
        raise ValueError("stored clustering does not match ensemble")

    return indices_by_cluster

responses_norm

responses_norm(rectified=False)

Compute the norm of responses to naturalistic stimuli.

Parameters:

Name Type Description Default
rectified bool

Whether to rectify responses before computing norm.

False

Returns:

Type Description
ndarray

np.ndarray: Norm of responses for each network.

Source code in flyvision/network/ensemble.py
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
def responses_norm(self, rectified: bool = False) -> np.ndarray:
    """Compute the norm of responses to naturalistic stimuli.

    The responses are read from `self.naturalistic_stimuli_responses()` and are
    expected to be a 4-D array laid out as
    (n_models, n_samples, n_frames, n_cell_types). For every model and cell
    type, the norm is taken jointly over the sample and frame axes and scaled
    by 1 / sqrt(n_samples * n_frames). NaN entries are treated as 0.

    The response array is never modified in place, so the data held by the
    response set keeps its NaN entries.

    Args:
        rectified: Whether to rectify responses (clip negative values to 0)
            before computing norm.

    Returns:
        np.ndarray: Norm of responses for each network, selected along the
        first axis by `self.model_index`. The sample and frame axes are kept
        as singleton dimensions, i.e. the shape is
        (len(model_index), 1, 1, n_cell_types).
    """
    response_set = self.naturalistic_stimuli_responses()
    responses = response_set['responses'].values

    if rectified:
        responses = np.maximum(responses, 0)

    # Fails early with a ValueError if the responses are not 4-dimensional.
    _n_models, n_samples, n_frames, _n_cell_types = responses.shape

    # Replace NaNs with 0 in a new array rather than assigning into
    # `responses`: for rectified=False that array is the one owned by the
    # response set and an in-place write would silently alter it.
    responses = np.where(np.isnan(responses), 0, responses)

    # A 2-tuple axis makes np.linalg.norm compute the Frobenius norm over the
    # (sample, frame) axes; keepdims keeps them as singleton dimensions.
    norm = (
        1
        / np.sqrt(n_samples * n_frames)
        * np.linalg.norm(responses, axis=(1, 2), keepdims=True)
    )

    return np.take(norm, self.model_index, axis=0)