
Network

flyvision.network.network.Network

Bases: Module

A connectome-constrained network with nodes, edges, and dynamics.
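At its core, the network advances its node and edge states with elementwise forward Euler integration, one step per stimulus frame (`forward` and `_next_state` in the source below). The dynamics write a velocity, the state becomes `state + velocity * dt`, and the node activity of every step is collected. The sketch below reproduces only that loop in plain Python, for a single sample and without batching. The leaky-integrator velocity `toy_velocity` and its `tau` and `bias` values are hypothetical stand-ins, not the `PPNeuronIGRSynapses` dynamics.

```python
from typing import Callable, List

Vector = List[float]


def euler_simulate(
    stimulus: List[Vector],
    v0: Vector,
    velocity: Callable[[Vector, Vector], Vector],
    dt: float,
) -> List[Vector]:
    """Forward-Euler loop over frames, returning the activity of every step.

    stimulus has shape (n_frames, n_nodes); v0 has shape (n_nodes,).
    """
    v = list(v0)
    activity_trace = []
    for x_t in stimulus:
        dv = velocity(v, x_t)
        # Elementwise Euler step: next_state = state + velocity * dt
        v = [vi + dvi * dt for vi, dvi in zip(v, dv)]
        # Rectified activity, analogous to collecting state.nodes.activity
        activity_trace.append([max(vi, 0.0) for vi in v])
    return activity_trace


def toy_velocity(
    v: Vector, x_t: Vector, tau: float = 0.05, bias: float = 0.0
) -> Vector:
    """Hypothetical leaky integrator: tau * dv/dt = -v + bias + x_t."""
    return [(-vi + bias + xi) / tau for vi, xi in zip(v, x_t)]


# Usage: two nodes driven by constant inputs 1.0 and 0.0 for 100 frames.
dt = 1 / 100
frames = [[1.0, 0.0] for _ in range(100)]
trace = euler_simulate(frames, [0.0, 0.0], toy_velocity, dt)
print(len(trace), [round(a, 3) for a in trace[-1]])
```

The first node relaxes towards its input while the second stays at rest, which is the kind of grey-input equilibrium that `steady_state` approximates before a real stimulus is played.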

Parameters:

connectome (Dict[str, Any]): Connectome configuration.
Default: Namespace(type='ConnectomeFromAvgFilters', file='fib25-fib19_v2.2.json', extent=15, n_syn_fill=1)

dynamics (Dict[str, Any]): Network dynamics configuration.
Default: Namespace(type='PPNeuronIGRSynapses', activation=Namespace(type='relu'))

node_config (Dict[str, Any]): Node parameter configuration.
Default: Namespace(bias=Namespace(type='RestingPotential', groupby=['type'], initial_dist='Normal', mode='sample', requires_grad=True, mean=0.5, std=0.05, penalize=Namespace(activity=True), seed=0), time_const=Namespace(type='TimeConstant', groupby=['type'], initial_dist='Value', value=0.05, requires_grad=True))

edge_config (Dict[str, Any]): Edge parameter configuration.
Default: Namespace(sign=Namespace(type='SynapseSign', initial_dist='Value', requires_grad=False, groupby=['source_type', 'target_type']), syn_count=Namespace(type='SynapseCount', initial_dist='Lognormal', mode='mean', requires_grad=False, std=1.0, groupby=['source_type', 'target_type', 'dv', 'du']), syn_strength=Namespace(type='SynapseCountScaling', initial_dist='Value', requires_grad=True, scale=0.01, clamp='non_negative', groupby=['source_type', 'target_type']))

stimulus_config (Dict[str, Any]): Stimulus configuration.
Default: Namespace(type='Stimulus', init_buffer=False)
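Parameters whose config sets `clamp`, such as `syn_strength` with `clamp='non_negative'` above, are projected back into their valid range by `Network.clamp`, which runs at the start of every forward pass and only touches parameters with `requires_grad=True`. The same method then enforces symmetry constraints by replacing the values under each symmetry mask with their mean. The following plain-Python sketch mirrors that logic on lists instead of tensors; the function names `clamp_values` and `enforce_symmetry` are hypothetical and not part of flyvision.

```python
from typing import List, Optional, Sequence, Tuple, Union

ClampMode = Optional[Union[str, Tuple[float, float]]]


def clamp_values(values: Sequence[float], mode: ClampMode) -> List[float]:
    """Apply a clamp mode: None, "non_negative", or a (min, max) tuple."""
    if mode is None:
        return list(values)
    if mode == "non_negative":
        return [max(v, 0.0) for v in values]
    if isinstance(mode, (tuple, list)) and len(mode) == 2:
        low, high = mode
        return [min(max(v, low), high) for v in values]
    raise NotImplementedError(f"Clamping mode {mode} not implemented.")


def enforce_symmetry(
    values: Sequence[float], masks: Sequence[Sequence[bool]]
) -> List[float]:
    """Replace the entries selected by each boolean mask with their mean."""
    out = list(values)
    for mask in masks:
        selected = [i for i, m in enumerate(mask) if m]
        if not selected:  # skip empty masks to avoid dividing by zero
            continue
        mean = sum(out[i] for i in selected) / len(selected)
        for i in selected:
            out[i] = mean
    return out


strengths = clamp_values([-0.2, 0.01, 0.5], "non_negative")
ranged = clamp_values([-1.0, 0.5, 3.0], (0.0, 1.0))
shared = enforce_symmetry([1.0, 2.0, 3.0, 4.0], [[True, False, True, False]])
print(strengths, ranged, shared)
```

Clamping after each optimizer step in this way is a form of projected gradient descent: the update is unconstrained, and the parameters are projected back onto the allowed set afterwards.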

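Attributes:

connectome (Connectome): Connectome directory.
dynamics (NetworkDynamics): Network dynamics.
node_params (Namespace): Node parameters.
edge_params (Namespace): Edge parameters.
n_nodes (int): Number of nodes.
n_edges (int): Number of edges.
num_parameters (int): Number of parameters.
config (Namespace): Deep copy of the configs the network was built from.
_source_indices (Tensor): Index of the source node of each edge.
_target_indices (Tensor): Index of the target node of each edge.
symmetry_config (Namespace): Symmetry masks per parameter; clamp sets the values under each mask to their mean.
clamp_config (Namespace): Clamp mode per parameter (None, 'non_negative', or a (min, max) tuple), applied by clamp.
stimulus (Stimulus): Stimulus object.
_state_hooks (tuple): Registered state hooks, called in registration order in _state_api.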
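`_source_indices` and `_target_indices` hold, for every edge, the index of its presynaptic and postsynaptic node. `_source_gather` and `_target_gather` use them to read node-level values per edge (lazily, by wrapping them in a `RefTensor`), and `target_sum` scatter-adds edge-level inputs back onto their target nodes with `Tensor.scatter_add_` along the last dimension. Below is a list-based sketch of the same two operations without a batch dimension. The three-node, four-edge graph, the voltages, and the weights are made up for illustration.

```python
from typing import List, Sequence


def source_gather(
    node_values: Sequence[float], source_index: Sequence[int]
) -> List[float]:
    """Edge-level view of node values: entry e is the value of edge e's source."""
    return [node_values[i] for i in source_index]


def target_sum(
    edge_inputs: Sequence[float], target_index: Sequence[int], n_nodes: int
) -> List[float]:
    """Scatter-add edge inputs onto their target nodes (sum over incoming edges)."""
    result = [0.0] * n_nodes
    for value, target in zip(edge_inputs, target_index):
        result[target] += value
    return result


# Hypothetical graph: 3 nodes, 4 edges (0->1, 0->2, 1->2, 2->0).
source_index = [0, 0, 1, 2]
target_index = [1, 2, 2, 0]
voltages = [1.0, 2.0, 3.0]
weights = [0.5, -1.0, 2.0, 1.0]

presynaptic = source_gather(voltages, source_index)  # [1.0, 1.0, 2.0, 3.0]
currents = [w * v for w, v in zip(weights, presynaptic)]
node_input = target_sum(currents, target_index, n_nodes=3)
print(node_input)  # [3.0, 0.5, 3.0]
```

In the tensor version the scatter step corresponds to `torch.zeros(n_nodes).scatter_add_(-1, target_index, x)`. Multiplying presynaptic values by weights here only stands in for whatever edge-level computation the chosen dynamics performs.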

Source code in flyvision/network/network.py
class Network(nn.Module):
    """A connectome-constrained network with nodes, edges, and dynamics.

    Args:
        connectome: Connectome configuration.
        dynamics: Network dynamics configuration.
        node_config: Node parameter configuration.
        edge_config: Edge parameter configuration.
        stimulus_config: Stimulus configuration.

    Attributes:
        connectome (Connectome): Connectome directory.
        dynamics (NetworkDynamics): Network dynamics.
        node_params (Namespace): Node parameters.
        edge_params (Namespace): Edge parameters.
        n_nodes (int): Number of nodes.
        n_edges (int): Number of edges.
        num_parameters (int): Number of parameters.
        config (Namespace): Config namespace.
        _source_indices (Tensor): Source indices.
        _target_indices (Tensor): Target indices.
        symmetry_config (Namespace): Symmetry config.
        clamp_config (Namespace): Clamp config.
        stimulus (Stimulus): Stimulus object.
        _state_hooks (tuple): State hooks.
    """

    def __init__(
        self,
        connectome: Dict[str, Any] = Namespace(
            type="ConnectomeFromAvgFilters",
            file="fib25-fib19_v2.2.json",
            extent=15,
            n_syn_fill=1,
        ),
        dynamics: Dict[str, Any] = Namespace(
            type="PPNeuronIGRSynapses", activation=Namespace(type="relu")
        ),
        node_config: Dict[str, Any] = Namespace(
            bias=Namespace(
                type="RestingPotential",
                groupby=["type"],
                initial_dist="Normal",
                mode="sample",
                requires_grad=True,
                mean=0.5,
                std=0.05,
                penalize=Namespace(activity=True),
                seed=0,
            ),
            time_const=Namespace(
                type="TimeConstant",
                groupby=["type"],
                initial_dist="Value",
                value=0.05,
                requires_grad=True,
            ),
        ),
        edge_config: Dict[str, Any] = Namespace(
            sign=Namespace(
                type="SynapseSign",
                initial_dist="Value",
                requires_grad=False,
                groupby=["source_type", "target_type"],
            ),
            syn_count=Namespace(
                type="SynapseCount",
                initial_dist="Lognormal",
                mode="mean",
                requires_grad=False,
                std=1.0,
                groupby=["source_type", "target_type", "dv", "du"],
            ),
            syn_strength=Namespace(
                type="SynapseCountScaling",
                initial_dist="Value",
                requires_grad=True,
                scale=0.01,
                clamp="non_negative",
                groupby=["source_type", "target_type"],
            ),
        ),
        stimulus_config: Dict[str, Any] = Namespace(type="Stimulus", init_buffer=False),
    ):
        super().__init__()

        # Prepare configs.
        connectome, dynamics, node_config, edge_config, stimulus_config, self.config = (
            self.prepare_configs(
                connectome, dynamics, node_config, edge_config, stimulus_config
            )
        )

        # Store the connectome, dynamics, and parameters.
        self.connectome = init_connectome(**connectome)
        self.dynamics = forward_subclass(NetworkDynamics, dynamics)

        # Load constant indices into memory.
        # Store source/target indices.
        self._source_indices = torch.tensor(self.connectome.edges.source_index[:])
        self._target_indices = torch.tensor(self.connectome.edges.target_index[:])

        self.n_nodes = len(self.connectome.nodes.index)
        self.n_edges = len(self.connectome.edges.source_index)

        # Optional way of parameter sharing is averaging at every call across
        # precomputed masks. This can be useful for e.g. symmetric electrical
        # compartments.
        # These masks are collected from Parameters into this namespace.
        self.symmetry_config = Namespace()  # type: Dict[str, List[torch.Tensor]]
        # Clamp configuration is collected from Parameter into this Namespace
        # for projected gradient descent.
        self.clamp_config = Namespace()

        # Construct node parameter sets.
        self.node_params = Namespace()
        for param_name, param_config in node_config.items():
            param = forward_subclass(
                Parameter,
                config={
                    "type": param_config.type,
                    "param_config": param_config,
                    "connectome": self.connectome,
                },
            )

            # register parameter to module
            self.register_parameter(f"nodes_{param_name}", param.raw_values)

            # creating index to map shared parameters onto all nodes,
            # sources, or targets
            param.readers = dict(
                nodes=param.indices,
                sources=param.indices[self._source_indices],
                targets=param.indices[self._target_indices],
            )
            self.node_params[param_name] = param

            # additional map to optional boolean masks to constrain
            # parameters (called in self.clamp)
            self.symmetry_config[f"nodes_{param_name}"] = getattr(
                param, "symmetry_masks", []
            )

            # additional map to optional clamp configuration to constrain
            # parameters (called in self.clamp)
            self.clamp_config[f"nodes_{param_name}"] = getattr(
                param_config, "clamp", None
            )

        # Construct edge parameter sets.
        self.edge_params = Namespace()
        for param_name, param_config in edge_config.items():
            param = forward_subclass(
                Parameter,
                config={
                    "type": param_config.type,
                    "param_config": param_config,
                    "connectome": self.connectome,
                },
            )

            self.register_parameter(f"edges_{param_name}", param.raw_values)

            # creating index to map shared parameters onto all edges
            param.readers = dict(edges=param.indices)

            self.edge_params[param_name] = param

            self.symmetry_config[f"edges_{param_name}"] = getattr(
                param, "symmetry_masks", []
            )

            self.clamp_config[f"edges_{param_name}"] = getattr(
                param_config, "clamp", None
            )

        self.num_parameters = n_params(self)
        self._state_hooks = tuple()

        self.stimulus = init_stimulus(self.connectome, **stimulus_config)

        logger.info("Initialized network with %s parameters.", self.num_parameters)

    def __repr__(self):
        return self.config.__repr__().replace("Namespace", "Network", 1)

    def prepare_configs(
        self, connectome, dynamics, node_config, edge_config, stimulus_config
    ):
        """Prepare configs for network initialization."""
        connectome = namespacify(connectome).deepcopy()
        dynamics = namespacify(dynamics).deepcopy()
        node_config = namespacify(node_config).deepcopy()
        edge_config = namespacify(edge_config).deepcopy()
        stimulus_config = namespacify(stimulus_config).deepcopy()
        config = Namespace(
            connectome=connectome,
            dynamics=dynamics,
            node_config=node_config,
            edge_config=edge_config,
            stimulus_config=stimulus_config,
        ).deepcopy()
        return connectome, dynamics, node_config, edge_config, stimulus_config, config

    def param_api(self) -> Dict[str, Dict[str, Tensor]]:
        """Parameter API for inspection.

        Returns:
            Parameter namespace for inspection.

        Note:
            This is not the same as the parameter api passed to the dynamics. This is a
            convenience function to inspect the parameters, but does not write derived
            parameters or sources and targets states.
        """
        # Construct the base parameter namespace.
        params = Namespace(
            nodes=Namespace(),
            edges=Namespace(),
            sources=Namespace(),
            targets=Namespace(),
        )
        for param_name, parameter in {
            **self.node_params,
            **self.edge_params,
        }.items():
            values = parameter.semantic_values
            for route, indices in parameter.readers.items():
                # route one of ("nodes", "sources", "targets", "edges")
                params[route][param_name] = Namespace(parameter=values, indices=indices)
        return params

    def _param_api(self) -> AutoDeref[str, AutoDeref[str, RefTensor]]:
        """Returns params object passed to `dynamics`.

        Returns:
            Parameter namespace for dynamics.
        """
        # Construct the base parameter namespace.
        params = AutoDeref(
            nodes=AutoDeref(),
            edges=AutoDeref(),
            sources=AutoDeref(),
            targets=AutoDeref(),
        )
        for param_name, parameter in {
            **self.node_params,
            **self.edge_params,
        }.items():
            values = parameter.semantic_values
            for route, indices in parameter.readers.items():
                # route one of ("nodes", "sources", "targets", "edges")
                params[route][param_name] = RefTensor(values, indices)
        # Add derived parameters.
        self.dynamics.write_derived_params(params)
        for k, v in params.nodes.items():
            if k not in params.sources:
                params.sources[k] = self._source_gather(v)
                params.targets[k] = self._target_gather(v)

        return params

    def _source_gather(self, x: Tensor) -> RefTensor:
        """Gathers source node states across edges.

        Args:
            x: Node-level activation, e.g., voltages. Shape is (n_nodes).

        Returns:
            Edge-level representation. Shape is (n_edges).

        Note:
            For edge-level access to source node states for elementwise operations.
            Called in _param_api and _state_api.
        """
        return RefTensor(x, self._source_indices)

    def _target_gather(self, x: Tensor) -> RefTensor:
        """Gathers target node states across edges.

        Args:
            x: Node-level activation, e.g., voltages. Shape is (n_nodes).

        Returns:
            Edge-level representation. Shape is (n_edges).

        Note:
            For edge-level access to target node states for elementwise operations.
            Called in _param_api and _state_api.
        """
        return RefTensor(x, self._target_indices)

    def target_sum(self, x: Tensor) -> Tensor:
        """Scatter sum operation creating target node states from inputs.

        Args:
            x: Edge inputs to targets, e.g., currents. Shape is (batch_size, n_edges).

        Returns:
            Node-level input. Shape is (batch_size, n_nodes).
        """
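        # match dtype and device of the edge inputs so scatter_add_ is valid
        result = torch.zeros(
            (*x.shape[:-1], self.n_nodes), dtype=x.dtype, device=x.device
        )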
        # signature: tensor.scatter_add_(dim, index, other)
        result.scatter_add_(
            -1,  # nodes dim
            self._target_indices.expand(  # view of index expanded over dims of x
                *x.shape
            ),
            x,
        )
        return result

    def _initial_state(
        self, params: AutoDeref[str, AutoDeref[str, RefTensor]], batch_size: int
    ) -> AutoDeref[str, AutoDeref[str, Union[Tensor, RefTensor]]]:
        """Compute the initial state, given the parameters and batch size.

        Args:
            params: Parameter namespace.
            batch_size: Batch size.

        Returns:
            Initial state namespace of node, edge, source, and target states.
        """
        # Initialize the network.
        state = AutoDeref(nodes=AutoDeref(), edges=AutoDeref())
        self.dynamics.write_initial_state(state, params)

        # Expand over batch dimension.
        for k, v in state.nodes.items():
            state.nodes[k] = v.expand(batch_size, *v.shape)
        for k, v in state.edges.items():
            state.edges[k] = v.expand(batch_size, *v.shape)

        return self._state_api(state)

    def _next_state(
        self,
        params: AutoDeref[str, AutoDeref[str, RefTensor]],
        state: AutoDeref[str, AutoDeref[str, Union[Tensor, RefTensor]]],
        x_t: Tensor,
        dt: float,
    ) -> AutoDeref[str, AutoDeref[str, Union[Tensor, RefTensor]]]:
        """Compute the next state, given the current `state` and stimulus `x_t`.

        Args:
            params: Parameters.
            state: Current state.
            x_t: Stimulus at time t. Shape is (batch_size, n_nodes).
            dt: Time step.

        Returns:
            Next state namespace of node, edge, source, and target states.

        Note:
            Uses simple, elementwise Euler integration.
        """
        vel = AutoDeref(nodes=AutoDeref(), edges=AutoDeref())

        self.dynamics.write_state_velocity(
            vel, state, params, self.target_sum, x_t, dt=dt
        )

        next_state = AutoDeref(
            nodes=AutoDeref(**{
                k: state.nodes[k] + vel.nodes[k] * dt for k in state.nodes
            }),
            edges=AutoDeref(**{
                k: state.edges[k] + vel.edges[k] * dt for k in state.edges
            }),
        )

        return self._state_api(next_state)

    def _state_api(
        self, state: AutoDeref[str, AutoDeref[str, Union[Tensor, RefTensor]]]
    ) -> AutoDeref[str, AutoDeref[str, Union[Tensor, RefTensor]]]:
        """Populate sources and targets states from nodes states.

        Args:
            state: Current state.

        Returns:
            Updated state with populated sources and targets.

        Note:
            Optional state hooks are called here (in order of registration).
            This is returned by _initial_state and _next_state.
        """
        for hook in self._state_hooks:
            _state = hook(state)
            if _state is not None:
                state = _state

        state = AutoDeref(
            nodes=state.nodes,
            edges=state.edges,
            sources=AutoDeref(**valmap(self._source_gather, state.nodes)),
            targets=AutoDeref(**valmap(self._target_gather, state.nodes)),
        )

        return state

    def register_state_hook(self, state_hook: Callable, **kwargs) -> None:
        """Register a state hook to retrieve or modify the state.

        Args:
            state_hook: Callable to be used as a hook.
            **kwargs: Keyword arguments to pass to the callable.

        Raises:
            ValueError: If state_hook is not callable.

        Note:
            The hook is called in _state_api. Useful for targeted perturbations.
        """

        class StateHook:
            def __init__(self, hook, **kwargs):
                self.hook = hook
                self.kwargs = kwargs or {}

            def __call__(self, state):
                return self.hook(state, **self.kwargs)

        if not callable(state_hook):
            raise ValueError("state_hook must be callable")

        self._state_hooks += (StateHook(state_hook, **kwargs),)

    def clear_state_hooks(self, clear: bool = True):
        """Clear all state hooks.

        Args:
            clear: If True, clear all state hooks.
        """
        if clear:
            self._state_hooks = tuple()

    def clamp(self):
        """Clamp free parameters to their range specified in their config.

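        Valid configs are `non_negative` to clamp at zero and a tuple of the form
        (min, max) to clamp to an arbitrary range. Only parameters with
        `requires_grad=True` are clamped.

        Note:
            This function also enforces symmetry constraints by setting the values
            selected by each symmetry mask to their mean.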
        """
        # clamp parameters
        for param_name, mode in self.clamp_config.items():
            param = getattr(self, param_name)
            if param.requires_grad:
                if mode is None:
                    pass
                elif mode == "non_negative":
                    param.data.clamp_(0)
                elif isinstance(mode, Iterable) and len(mode) == 2:
                    param.data.clamp_(*mode)
                else:
                    raise NotImplementedError(f"Clamping mode {mode} not implemented.")

        # enforce symmetry constraints
        for param_name, masks in self.symmetry_config.items():
            param = getattr(self, param_name)
            if param.requires_grad:
                for symmetry in masks:
                    param.data[symmetry] = param.data[symmetry].mean()

    def forward(
        self, x: Tensor, dt: float, state: AutoDeref = None, as_states: bool = False
    ) -> Union[torch.Tensor, AutoDeref]:
        """Forward pass of the network.

        Args:
            x: Whole-network stimulus of shape (batch_size, n_frames, n_cells).
            dt: Integration time step.
            state: Initial state of the network. If not given, computed from
                NetworkDynamics.write_initial_state. steady_state and fade_in_state
                are convenience methods to compute initial steady states.
            as_states: If True, returns the states as List[AutoDeref], else stacks
                the node activity along the frame dimension and returns a tensor.

        Returns:
            Network activity or states.
        """
        # To keep the parameters within their valid domain, they get clamped.
        self.clamp()
        # Construct the parameter API.
        params = self._param_api()

        # Initialize the network state.
        if state is None:
            state = self._initial_state(params, x.shape[0])

        def handle(state):
            # loop over the temporal dimension for integration of dynamics
            for i in range(x.shape[1]):
                state = self._next_state(params, state, x[:, i], dt)
                if as_states is False:
                    yield state.nodes.activity
                else:
                    yield state

        if as_states is True:
            return list(handle(state))
        return torch.stack(list(handle(state)), dim=1)

    def steady_state(
        self,
        t_pre: float,
        dt: float,
        batch_size: int,
        value: float = 0.5,
        state: Optional[AutoDeref] = None,
        grad: bool = False,
        return_last: bool = True,
    ) -> AutoDeref:
        """Compute state after grey-scale stimulus.

        Args:
            t_pre: Duration of the grey-scale stimulus.
            dt: Integration time step.
            batch_size: Batch size.
            value: Value of the grey-scale stimulus.
            state: Initial state of the network. If not given, computed from
                NetworkDynamics.write_initial_state. steady_state and fade_in_state
                are convenience methods to compute initial steady states.
            grad: If True, the state is computed with gradient.
            return_last: If True, return only the last state.

        Returns:
            Steady state of the network after a grey-scale stimulus.
        """
        if t_pre is None or t_pre <= 0.0:
            return state

        if value is None:
            return state

        self.stimulus.zero(batch_size, int(t_pre / dt))
        self.stimulus.add_pre_stim(value)

        with self.enable_grad(grad):
            if return_last:
                return self(self.stimulus(), dt, as_states=True, state=state)[-1]
            return self(self.stimulus(), dt, as_states=True, state=state)

    def fade_in_state(
        self,
        t_fade_in: float,
        dt: float,
        initial_frames: Tensor,
        state: Optional[AutoDeref] = None,
        grad: bool = False,
    ) -> AutoDeref:
        """Compute state after fade-in stimulus of initial_frames.

        Args:
            t_fade_in: Duration of the fade-in stimulus.
            dt: Integration time step.
            initial_frames: Tensor of shape (batch_size, 1, n_input_elements).
            state: Initial state of the network. If not given, computed from
                NetworkDynamics.write_initial_state. steady_state and fade_in_state
                are convenience methods to compute initial steady states.
            grad: If True, the state is computed with gradient.

        Returns:
            State after fade-in stimulus.
        """
        if t_fade_in is None or t_fade_in <= 0.0:
            return state

        batch_size = initial_frames.shape[0]

        # replicate initial frame over int(t_fade_in/dt) frames and fade in
        # by ramping up the contrast
        self.stimulus.zero(batch_size, int(t_fade_in / dt))

        initial_frames = (
            torch.linspace(0, 1, int(t_fade_in / dt))[None, :, None]
            * (initial_frames.repeat(1, int(t_fade_in / dt), 1) - 0.5)
            + 0.5
        )
        self.stimulus.add_input(initial_frames[:, :, None])
        with self.enable_grad(grad):
            return self(self.stimulus(), dt, as_states=True, state=state)[-1]

    def simulate(
        self,
        movie_input: torch.Tensor,
        dt: float,
        initial_state: Union[AutoDeref, None, Literal["auto"]] = "auto",
        as_states: bool = False,
        as_layer_activity: bool = False,
    ) -> Union[torch.Tensor, AutoDeref, LayerActivity]:
        """Simulate the network activity from movie input.

        Args:
            movie_input: Tensor of shape (batch_size, n_frames, 1, hexals).
            dt: Integration time step. Warns if dt > 1/50.
            initial_state: Network activity at the beginning of the simulation.
                Use fade_in_state or steady_state to compute the initial state from grey
                input or from ramping up the contrast of the first movie frame.
                Defaults to "auto", which uses the steady_state after 1s of grey input.
            as_states: If True, return the states as a list of AutoDeref
                dictionaries (one per frame) instead of a tensor. Defaults to False.
            as_layer_activity: If True, return a LayerActivity object. Defaults to False.
                Currently only supported for ConnectomeFromAvgFilters.

        Returns:
            Activity tensor of shape (batch_size, n_frames, #neurons),
            or a list of AutoDeref states if `as_states` is True,
            or LayerActivity object if `as_layer_activity` is True.

        Raises:
            ValueError: If the movie_input is not four-dimensional.
            ValueError: If `as_layer_activity` is True and the connectome is not
                a ConnectomeFromAvgFilters.
            AssertionError: If, inside the simulation context, the network is in
                training mode or any parameter requires grad.

        Warns:
            IntegrationWarning: If the integration time step dt is bigger than 1/50.
        """
        if len(movie_input.shape) != 4:
            raise ValueError("requires shape (sample, frame, 1, hexals)")

        if (
            as_layer_activity
            and not self.connectome.__class__.__name__ == "ConnectomeFromAvgFilters"
        ):
            raise ValueError(
                "as_layer_activity is currently only supported for "
                "ConnectomeFromAvgFilters"
            )

        if dt > 1 / 50:
            warnings.warn(
                f"dt={dt} is very large for integration. "
                "Choose a smaller dt (<= 1/50) to avoid this warning.",
                IntegrationWarning,
                stacklevel=2,
            )

        batch_size, n_frames = movie_input.shape[:2]
        if initial_state == "auto":
            initial_state = self.steady_state(1.0, dt, batch_size)
        with simulation(self):
            assert self.training is False and all(
                not p.requires_grad for p in self.parameters()
            )
            self.stimulus.zero(batch_size, n_frames)
            self.stimulus.add_input(movie_input)
            if as_layer_activity:
                return LayerActivity(
                    self.forward(self.stimulus(), dt, initial_state, as_states).cpu(),
                    self.connectome,
                    keepref=True,
                )
            return self.forward(self.stimulus(), dt, initial_state, as_states)

    @contextmanager
    def enable_grad(self, grad: bool = True):
        """Context manager to enable or disable gradient computation.

        Args:
            grad: If True, enable gradient computation.
        """
        prev = torch.is_grad_enabled()
        torch.set_grad_enabled(grad)
        try:
            yield
        finally:
            torch.set_grad_enabled(prev)

    def stimulus_response(
        self,
        stim_dataset: SequenceDataset,
        dt: float,
        indices: Optional[Iterable[int]] = None,
        t_pre: float = 1.0,
        t_fade_in: float = 0.0,
        grad: bool = False,
        default_stim_key: Any = "lum",
        batch_size: int = 1,
    ):
        """Compute stimulus responses for a given stimulus dataset.

        Args:
            stim_dataset: Stimulus dataset.
            dt: Integration time step.
            indices: Indices of the stimuli to compute the response for.
                If not given, responses to all stimuli are computed.
            t_pre: Duration of the grey-scale stimulus.
            t_fade_in: Duration of the fade-in stimulus (slow).
            grad: If True, the state is computed with gradient.
            default_stim_key: Key of the stimulus in the dataset if it returns
                a dictionary.
            batch_size: Batch size for processing.

        Note:
            By default, applies a grey-scale stimulus for 1 second and no
            fade-in stimulus.

        Yields:
            Tuple of (stimulus, response) as numpy arrays.
        """
        stim_dataset.dt = dt
        if indices is None:
            indices = np.arange(len(stim_dataset))
        stim_loader = DataLoader(
            stim_dataset, batch_size=batch_size, sampler=IndexSampler(indices)
        )

        stimulus = self.stimulus

        # compute initial state
        initial_state = self.steady_state(t_pre, dt, batch_size=1, value=0.5)

        with self.enable_grad(grad):
            logger.info("Computing %s stimulus responses.", len(indices))
            for stim in tqdm(
                stim_loader, desc="Batch", total=len(stim_loader), leave=False
            ):
                # when datasets return dictionaries, we assume that the stimulus
                # is stored under the key `default_stim_key`
                if isinstance(stim, dict):
                    stim = stim[default_stim_key]  # (batch, frames, 1, hexals)
                else:
                    stim = stim.unsqueeze(-2)  # (batch, frames, 1, hexals)

                # fade in stimulus
                fade_in_state = self.fade_in_state(
                    t_fade_in=t_fade_in,
                    dt=dt,
                    initial_frames=stim[:, 0],
                    state=initial_state,
                )

                def handle_stim(stim, fade_in_state):
                    # reset stimulus
                    batch_size, n_frames = stim.shape[:2]
                    stimulus.zero(batch_size, n_frames)

                    # add stimulus
                    stimulus.add_input(stim)

                    # compute response
                    if grad is False:
                        return (
                            stim.cpu().numpy(),
                            self(stimulus(), dt, state=fade_in_state)
                            .detach()
                            .cpu()
                            .numpy(),
                        )
                    elif grad is True:
                        return (
                            stim.cpu().numpy(),
                            self(stimulus(), dt, state=fade_in_state),
                        )

                yield handle_stim(stim, fade_in_state)

    def current_response(
        self,
        stim_dataset: SequenceDataset,
        dt: float,
        indices: Optional[Iterable[int]] = None,
        t_pre: float = 1.0,
        t_fade_in: float = 0,
        default_stim_key: Any = "lum",
    ):
        """Compute stimulus currents and responses for a given stimulus dataset.

        Note:
            Requires Dynamics to implement `currents`.

        Args:
            stim_dataset: Stimulus dataset.
            dt: Integration time step.
            indices: Indices of the stimuli to compute the response for.
                If not given, responses to all stimuli are computed.
            t_pre: Duration of the grey-scale stimulus in seconds.
            t_fade_in: Duration of the (slow) fade-in stimulus in seconds.
            default_stim_key: Key of the stimulus in the dataset if it returns
                a dictionary.

        Yields:
            Tuple of (stimulus, activity, currents) as numpy arrays.
        """
        self.clamp()
        # Construct the parameter API.
        params = self._param_api()

        stim_dataset.dt = dt
        if indices is None:
            indices = np.arange(len(stim_dataset))
        stim_loader = DataLoader(
            stim_dataset, batch_size=1, sampler=IndexSampler(indices)
        )

        stimulus = self.stimulus
        initial_state = self.steady_state(t_pre, dt, batch_size=1, value=0.5)
        with torch.no_grad():
            logger.info("Computing %d stimulus responses.", len(indices))
            for stim in stim_loader:
                if isinstance(stim, dict):
                    stim = stim[default_stim_key].squeeze(-2)

                fade_in_state = self.fade_in_state(
                    t_fade_in=t_fade_in,
                    dt=dt,
                    initial_frames=stim[:, 0].unsqueeze(1),
                    state=initial_state,
                )

                def handle_stim(stim, fade_in_state):
                    # reset stimulus
                    batch_size, n_frames, _ = stim.shape
                    stimulus.zero(batch_size, n_frames)

                    # add stimulus
                    stimulus.add_input(stim.unsqueeze(2))

                    # compute response
                    states = self(stimulus(), dt, state=fade_in_state, as_states=True)
                    return (
                        stim.cpu().numpy().squeeze(),
                        torch.stack(
                            [s.nodes.activity.cpu() for s in states],
                            dim=1,
                        )
                        .numpy()
                        .squeeze(),
                        torch.stack(
                            [self.dynamics.currents(s, params).cpu() for s in states],
                            dim=1,
                        )
                        .numpy()
                        .squeeze(),
                    )

                # stim, activity, currents
                yield handle_stim(stim, fade_in_state)

prepare_configs

prepare_configs(
    connectome,
    dynamics,
    node_config,
    edge_config,
    stimulus_config,
)

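Prepare configs for network initialization: each configuration is converted to a namespace and deep-copied, and the copies are also bundled into a combined config namespace.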

Source code in flyvision/network/network.py
def prepare_configs(
    self, connectome, dynamics, node_config, edge_config, stimulus_config
):
    """Prepare configs for network initialization."""
    connectome = namespacify(connectome).deepcopy()
    dynamics = namespacify(dynamics).deepcopy()
    node_config = namespacify(node_config).deepcopy()
    edge_config = namespacify(edge_config).deepcopy()
    stimulus_config = namespacify(stimulus_config).deepcopy()
    config = Namespace(
        connectome=connectome,
        dynamics=dynamics,
        node_config=node_config,
        edge_config=edge_config,
        stimulus_config=stimulus_config,
    ).deepcopy()
    return connectome, dynamics, node_config, edge_config, stimulus_config, config

param_api

param_api()

Parameter API for inspection.

Returns:

Type Description
Dict[str, Dict[str, Tensor]]

Parameter namespace for inspection.

Note

This is not the same as the parameter API passed to the dynamics. It is a convenience function for inspecting the parameters and does not write derived parameters or source and target states.

Source code in flyvision/network/network.py
def param_api(self) -> Dict[str, Dict[str, Tensor]]:
    """Param api for inspection.

    Returns:
        Parameter namespace for inspection.

    Note:
        This is not the same as the parameter api passed to the dynamics. This is a
        convenience function to inspect the parameters, but does not write derived
        parameters or source and target states.
    """
    # Construct the base parameter namespace.
    params = Namespace(
        nodes=Namespace(),
        edges=Namespace(),
        sources=Namespace(),
        targets=Namespace(),
    )
    for param_name, parameter in {
        **self.node_params,
        **self.edge_params,
    }.items():
        values = parameter.semantic_values
        for route, indices in parameter.readers.items():
            # route one of ("nodes", "sources", "targets", "edges")
            params[route][param_name] = Namespace(parameter=values, indices=indices)
    return params
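The routing step in param_api files each parameter under every route that reads it, so the same values can appear under several routes with different index arrays. A standard-library sketch of that data structure, assuming a hypothetical FakeParam class that only mimics the semantic_values and readers attributes the source accesses (plain dicts stand in for flyvision's Namespace):

```python
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Any, Dict


@dataclass
class FakeParam:
    """Hypothetical stand-in for a flyvision parameter object."""

    semantic_values: Any
    readers: Dict[str, Any] = field(default_factory=dict)


def param_api_sketch(
    node_params: Dict[str, FakeParam], edge_params: Dict[str, FakeParam]
) -> Dict[str, Dict[str, SimpleNamespace]]:
    # One bucket per route, as in the source.
    params = {route: {} for route in ("nodes", "edges", "sources", "targets")}
    for name, parameter in {**node_params, **edge_params}.items():
        for route, indices in parameter.readers.items():
            # The same values are shared by every route that reads them;
            # only the index array differs per route.
            params[route][name] = SimpleNamespace(
                parameter=parameter.semantic_values, indices=indices
            )
    return params
```

For example, a node bias read both per node and per edge source would appear under nodes and sources, pointing at the same values.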

target_sum

target_sum(x)

Scatter sum operation creating target node states from inputs.

Parameters:

Name Type Description Default
x Tensor

Edge inputs to targets, e.g., currents. Shape is (batch_size, n_edges).

required

Returns:

Type Description
Tensor

Node-level input. Shape is (batch_size, n_nodes).

Source code in flyvision/network/network.py
def target_sum(self, x: Tensor) -> Tensor:
    """Scatter sum operation creating target node states from inputs.

    Args:
        x: Edge inputs to targets, e.g., currents. Shape is (batch_size, n_edges).

    Returns:
        Node-level input. Shape is (batch_size, n_nodes).
    """
    result = torch.zeros((*x.shape[:-1], self.n_nodes))
    # signature: tensor.scatter_add_(dim, index, other)
    result.scatter_add_(
        -1,  # nodes dim
        self._target_indices.expand(  # view of index expanded over dims of x
            *x.shape
        ),
        x,
    )
    return result
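The scatter-sum can be reproduced with NumPy to see how edge inputs collapse onto their target nodes. This is a minimal sketch for 2-D inputs, assuming a hypothetical target_indices array that gives the target node of each edge. np.add.at performs unbuffered accumulation, so edges that share a target are summed rather than overwritten, which matches the semantics of torch's scatter_add_.

```python
import numpy as np


def target_sum(x: np.ndarray, target_indices: np.ndarray, n_nodes: int) -> np.ndarray:
    """Sum edge-level inputs into their target nodes.

    x: (batch_size, n_edges) edge inputs, e.g. synaptic currents.
    target_indices: (n_edges,) target node of each edge.
    Returns: (batch_size, n_nodes) node-level input.
    """
    result = np.zeros((x.shape[0], n_nodes), dtype=x.dtype)
    # Broadcast the edge-to-node index over the batch dimension.
    index = np.broadcast_to(target_indices, x.shape)
    batch_index = np.arange(x.shape[0])[:, None]
    # Unbuffered scatter-add: repeated targets accumulate.
    np.add.at(result, (batch_index, index), x)
    return result
```

With four edges targeting nodes [0, 2, 2, 1], node 2 receives the sum of the second and third edge inputs, and the total input per sample is preserved.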

register_state_hook

register_state_hook(state_hook, **kwargs)

Register a state hook to retrieve or modify the state.

Parameters:

Name Type Description Default
state_hook Callable

Callable to be used as a hook.

required
**kwargs

Keyword arguments to pass to the callable.

{}

Raises:

Type Description
ValueError

If state_hook is not callable.

Note

The hook is called in _state_api with the state as its first argument, followed by the registered keyword arguments. Useful for targeted perturbations.

Source code in flyvision/network/network.py
def register_state_hook(self, state_hook: Callable, **kwargs) -> None:
    """Register a state hook to retrieve or modify the state.

    Args:
        state_hook: Callable to be used as a hook.
        **kwargs: Keyword arguments to pass to the callable.

    Raises:
        ValueError: If state_hook is not callable.

    Note:
        The hook is called in _state_api. Useful for targeted perturbations.
    """

    class StateHook:
        def __init__(self, hook, **kwargs):
            self.hook = hook
            self.kwargs = kwargs or {}

        def __call__(self, state):
            return self.hook(state, **self.kwargs)

    if not isinstance(state_hook, Callable):
        raise ValueError("state_hook must be callable")

    self._state_hooks += (StateHook(state_hook, **kwargs),)
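A standard-library sketch of the hook mechanism, for illustration only. It mirrors the StateHook wrapper above, and it assumes that each hook returns the (possibly modified) state and that the hooks are applied in registration order. The dict-based state, apply_hooks, and the silence perturbation are hypothetical stand-ins, not flyvision API.

```python
from typing import Callable, Dict, List, Tuple


class StateHook:
    """Bind keyword arguments to a hook, mirroring the wrapper in the source."""

    def __init__(self, hook: Callable, **kwargs):
        self.hook = hook
        self.kwargs = kwargs or {}

    def __call__(self, state):
        return self.hook(state, **self.kwargs)


def register_state_hook(hooks: Tuple[StateHook, ...], state_hook: Callable, **kwargs):
    if not callable(state_hook):
        raise ValueError("state_hook must be callable")
    # Tuples are immutable, so registration returns a new, extended tuple.
    return hooks + (StateHook(state_hook, **kwargs),)


def apply_hooks(hooks: Tuple[StateHook, ...], state: Dict[str, List[float]]):
    # Hypothetical: each hook's return value becomes the state for the next hook.
    for hook in hooks:
        state = hook(state)
    return state


def silence(state, node_indices):
    """Perturbation hook: set the activity of selected nodes to zero."""
    activity = list(state["activity"])
    for i in node_indices:
        activity[i] = 0.0
    return {**state, "activity": activity}
```

Clearing the hooks, as clear_state_hooks does, amounts to resetting the tuple to an empty tuple.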

clear_state_hooks

clear_state_hooks(clear=True)

Clear all state hooks.

Parameters:

Name Type Description Default
clear bool

If True, clear all state hooks.

True
Source code in flyvision/network/network.py
def clear_state_hooks(self, clear: bool = True):
    """Clear all state hooks.

    Args:
        clear: If True, clear all state hooks.
    """
    if clear:
        self._state_hooks = tuple()

clamp

clamp()

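Clamp free parameters to the range specified in their config.

Valid modes are non_negative, which clamps at zero, and a tuple of the form (min, max), which clamps to an arbitrary range. A mode of None leaves the parameter unchanged, and any other mode raises NotImplementedError. Only parameters that require gradients are clamped.

Note

This function also enforces symmetry constraints by setting each group of symmetric parameter entries to the group's mean.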

Source code in flyvision/network/network.py
def clamp(self):
    """Clamp free parameters to their range specified in their config.

    Valid configs are `non_negative` to clamp at zero and tuple of the form
    (min, max) to clamp to an arbitrary range.

    Note:
        This function also enforces symmetry constraints.
    """
    # clamp parameters
    for param_name, mode in self.clamp_config.items():
        param = getattr(self, param_name)
        if param.requires_grad:
            if mode is None:
                pass
            elif mode == "non_negative":
                param.data.clamp_(0)
            elif isinstance(mode, Iterable) and len(mode) == 2:
                param.data.clamp_(*mode)
            else:
                raise NotImplementedError(f"Clamping mode {mode} not implemented.")

    # enforce symmetry constraints
    for param_name, masks in self.symmetry_config.items():
        param = getattr(self, param_name)
        if param.requires_grad:
            for symmetry in masks:
                param.data[symmetry] = param.data[symmetry].mean()
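The two passes in clamp can be sketched with NumPy arrays standing in for the parameter tensors. This is a minimal sketch, assuming parameters are stored in a plain dict and each symmetry group is given as a boolean mask; unlike the source, it does not check requires_grad.

```python
from collections.abc import Iterable
from typing import Dict, List, Optional, Tuple, Union

import numpy as np

ClampMode = Optional[Union[str, Tuple[float, float]]]


def clamp_(params: Dict[str, np.ndarray], clamp_config: Dict[str, ClampMode]) -> None:
    """Clamp parameter arrays in place according to their mode."""
    for name, mode in clamp_config.items():
        values = params[name]
        if mode is None:
            continue
        elif mode == "non_negative":
            np.clip(values, 0, None, out=values)
        elif isinstance(mode, Iterable) and len(mode) == 2:
            np.clip(values, mode[0], mode[1], out=values)
        else:
            raise NotImplementedError(f"Clamping mode {mode} not implemented.")


def enforce_symmetry_(
    params: Dict[str, np.ndarray], symmetry_config: Dict[str, List[np.ndarray]]
) -> None:
    """Tie each group of entries (a boolean mask) to the group mean, in place."""
    for name, masks in symmetry_config.items():
        values = params[name]
        for mask in masks:
            values[mask] = values[mask].mean()
```

Running the clamp before the symmetry pass, as the source does, means the averaged values are computed from already clamped entries.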

forward

forward(x, dt, state=None, as_states=False)

Forward pass of the network.

Parameters:

Name Type Description Default
x Tensor

Whole-network stimulus of shape (batch_size, n_frames, n_cells).

required
dt float

Integration time step.

required
state AutoDeref

Initial state of the network. If not given, computed from NetworkDynamics.write_initial_state. steady_state and fade_in_state are convenience functions to compute initial steady states.

None
as_states bool

If True, returns the per-frame states as List[AutoDeref]; otherwise stacks the node activity over frames and returns a tensor of shape (batch_size, n_frames, n_nodes).

False

Returns:

Type Description
Union[Tensor, AutoDeref]

Network activity or states.

Source code in flyvision/network/network.py
def forward(
    self, x: Tensor, dt: float, state: AutoDeref = None, as_states: bool = False
) -> Union[torch.Tensor, AutoDeref]:
    """Forward pass of the network.

    Args:
        x: Whole-network stimulus of shape (batch_size, n_frames, n_cells).
        dt: Integration time step.
        state: Initial state of the network. If not given, computed from
            NetworkDynamics.write_initial_state. steady_state and fade_in_state
            are convenience functions to compute initial steady states.
        as_states: If True, returns the states as List[AutoDeref], else stacks
            the activity of the nodes over frames and returns a tensor.

    Returns:
        Network activity or states.
    """
    # To keep the parameters within their valid domain, they get clamped.
    self.clamp()
    # Construct the parameter API.
    params = self._param_api()

    # Initialize the network state.
    if state is None:
        state = self._initial_state(params, x.shape[0])

    def handle(state):
        # loop over the temporal dimension for integration of dynamics
        for i in range(x.shape[1]):
            state = self._next_state(params, state, x[:, i], dt)
            if as_states is False:
                yield state.nodes.activity
            else:
                yield state

    if as_states is True:
        return list(handle(state))
    return torch.stack(list(handle(state)), dim=1)
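The loop above advances the state once per frame and, unless as_states is True, stacks the per-frame node activity along a new time axis. A NumPy sketch of the same control flow, assuming a hypothetical leaky-integrator update v <- v + dt / tau * (-v + x) in place of the network's _next_state and a resting initial state in place of _initial_state:

```python
import numpy as np


def forward_sketch(x: np.ndarray, dt: float, tau: float = 0.05, state=None, as_states=False):
    """x: (batch_size, n_frames, n_nodes) input; returns activity or per-frame states."""
    batch_size, n_frames, n_nodes = x.shape
    if state is None:
        # Hypothetical initial state: all nodes at rest.
        state = np.zeros((batch_size, n_nodes))
    states = []
    for i in range(n_frames):
        # Explicit Euler step of a leaky integrator driven by frame i.
        state = state + dt / tau * (-state + x[:, i])
        states.append(state)
    if as_states:
        return states
    # Stack along axis 1, giving (batch_size, n_frames, n_nodes).
    return np.stack(states, axis=1)
```

Passing a state that is already at the fixed point of the input leaves the activity unchanged, which is the property steady_state and fade_in_state rely on when they precompute an initial state.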

steady_state

steady_state(
    t_pre,
    dt,
    batch_size,
    value=0.5,
    state=None,
    grad=False,
    return_last=True,
)

Compute the state after a grey-scale stimulus. If t_pre is None or not positive, or value is None, the given state is returned unchanged.

Parameters:

Name Type Description Default
t_pre float

Duration of the grey-scale stimulus in seconds.

required
dt float

Integration time step.

required
batch_size int

Batch size.

required
value float

Value of the grey-scale stimulus.

0.5
state Optional[AutoDeref]

Initial state of the network. If not given, computed from NetworkDynamics.write_initial_state. steady_state and fade_in_state are convenience functions to compute initial steady states.

None
grad bool

If True, the state is computed with gradient.

False
return_last bool

If True, return only the last state; otherwise return the list of states for all frames.

True

Returns:

Type Description
AutoDeref

Steady state of the network after a grey-scale stimulus.

Source code in flyvision/network/network.py
def steady_state(
    self,
    t_pre: float,
    dt: float,
    batch_size: int,
    value: float = 0.5,
    state: Optional[AutoDeref] = None,
    grad: bool = False,
    return_last: bool = True,
) -> AutoDeref:
    """Compute state after grey-scale stimulus.

    Args:
        t_pre: Duration of the grey-scale stimulus in seconds.
        dt: Integration time step.
        batch_size: Batch size.
        value: Value of the grey-scale stimulus.
        state: Initial state of the network. If not given, computed from
            NetworkDynamics.write_initial_state. steady_state and fade_in_state
            are convenience functions to compute initial steady states.
        grad: If True, the state is computed with gradient.
        return_last: If True, return only the last state; otherwise return
            the list of states for all frames.

    Returns:
        Steady state of the network after a grey-scale stimulus.
    """
    if t_pre is None or t_pre <= 0.0:
        return state

    if value is None:
        return state

    self.stimulus.zero(batch_size, int(t_pre / dt))
    self.stimulus.add_pre_stim(value)

    with self.enable_grad(grad):
        if return_last:
            return self(self.stimulus(), dt, as_states=True, state=state)[-1]
        return self(self.stimulus(), dt, as_states=True, state=state)
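steady_state converts its duration to a frame count with int(t_pre / dt), and fade_in_state does the same with t_fade_in. Since t / dt is evaluated in floating point and int() truncates toward zero, a duration that is nominally an exact multiple of dt can lose a frame. A small check of that arithmetic; n_frames_rounded is a hypothetical alternative shown only for comparison, not something the library does:

```python
def n_frames_truncated(t: float, dt: float) -> int:
    # Same conversion as the source: int() truncates toward zero.
    return int(t / dt)


def n_frames_rounded(t: float, dt: float) -> int:
    # Hypothetical alternative that tolerates floating-point error in t / dt.
    return round(t / dt)
```

With the default t_pre = 1.0 and dt = 1/50 both give 50 frames, but with t = 0.3 and dt = 0.1 the quotient evaluates to 2.9999999999999996, so the truncated count is 2 rather than 3.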

fade_in_state

fade_in_state(
    t_fade_in, dt, initial_frames, state=None, grad=False
)

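Compute the state after a fade-in stimulus of initial_frames, whose contrast is ramped linearly from grey (0.5) up to the full frame. If t_fade_in is None or not positive, the given state is returned unchanged.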

Parameters:

Name Type Description Default
t_fade_in float

Duration of the fade-in stimulus in seconds.

required
dt float

Integration time step.

required
initial_frames Tensor

Tensor of shape (batch_size, 1, n_input_elements).

required
state Optional[AutoDeref]

Initial state of the network. If not given, computed from NetworkDynamics.write_initial_state. steady_state and fade_in_state are convenience functions to compute initial steady states.

None
grad bool

If True, the state is computed with gradient.

False

Returns:

Type Description
AutoDeref

State after fade-in stimulus.

Source code in flyvision/network/network.py
def fade_in_state(
    self,
    t_fade_in: float,
    dt: float,
    initial_frames: Tensor,
    state: Optional[AutoDeref] = None,
    grad: bool = False,
) -> AutoDeref:
    """Compute state after fade-in stimulus of initial_frames.

    Args:
        t_fade_in: Duration of the fade-in stimulus in seconds.
        dt: Integration time step.
        initial_frames: Tensor of shape (batch_size, 1, n_input_elements).
        state: Initial state of the network. If not given, computed from
            NetworkDynamics.write_initial_state. steady_state and fade_in_state
            are convenience functions to compute initial steady states.
        grad: If True, the state is computed with gradient.

    Returns:
        State after fade-in stimulus.
    """
    if t_fade_in is None or t_fade_in <= 0.0:
        return state

    batch_size = initial_frames.shape[0]

    # replicate initial frame over int(t_fade_in/dt) frames and fade in
    # by ramping up the contrast
    self.stimulus.zero(batch_size, int(t_fade_in / dt))

    initial_frames = (
        torch.linspace(0, 1, int(t_fade_in / dt))[None, :, None]
        * (initial_frames.repeat(1, int(t_fade_in / dt), 1) - 0.5)
        + 0.5
    )
    self.stimulus.add_input(initial_frames[:, :, None])
    with self.enable_grad(grad):
        return self(self.stimulus(), dt, as_states=True, state=state)[-1]
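The fade-in stimulus repeats the first frame over int(t_fade_in / dt) frames and scales its deviation from grey by a weight that rises linearly from 0 to 1. A NumPy sketch of just that frame construction (the integration step is omitted); fade_in_frames is a hypothetical name:

```python
import numpy as np


def fade_in_frames(initial_frames: np.ndarray, t_fade_in: float, dt: float) -> np.ndarray:
    """Ramp the contrast of a frame from grey (0.5) to full over t_fade_in.

    initial_frames: (batch_size, 1, n_input_elements).
    Returns: (batch_size, n_frames, 1, n_input_elements).
    """
    n_frames = int(t_fade_in / dt)
    # Contrast weight per frame, rising linearly from 0 to 1.
    ramp = np.linspace(0, 1, n_frames)[None, :, None]
    # Repeat the frame over time, scale its deviation from grey, re-add grey.
    frames = ramp * (np.repeat(initial_frames, n_frames, axis=1) - 0.5) + 0.5
    # Insert the singleton axis expected by the stimulus buffer.
    return frames[:, :, None]
```

The first generated frame is uniformly grey and the last one equals the original frame, so the network is eased into the movie's first frame instead of jumping to it from grey.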

simulate

simulate(
    movie_input,
    dt,
    initial_state="auto",
    as_states=False,
    as_layer_activity=False,
)

Simulate the network activity from movie input.

Parameters:

Name Type Description Default
movie_input Tensor

Tensor of shape (batch_size, n_frames, 1, hexals).

required
dt float

Integration time step. Warns if dt > 1/50.

required
initial_state Union[AutoDeref, None, Literal['auto']]

Network state at the beginning of the simulation. Use fade_in_state or steady_state to compute the initial state by ramping up the contrast of the first movie frame or from grey input, respectively. Defaults to "auto", which uses the steady_state after 1 s of grey input.

'auto'
as_states bool

If True, return the per-frame states as a list of AutoDeref objects instead of a tensor. Defaults to False.

False
as_layer_activity bool

If True, return a LayerActivity object. Defaults to False. Currently only supported for ConnectomeFromAvgFilters.

False

Returns:

Type Description
Union[Tensor, AutoDeref, LayerActivity]

Activity tensor of shape (batch_size, n_frames, #neurons), or a list of AutoDeref states if as_states is True, or a LayerActivity object if as_layer_activity is True.

Raises:

Type Description
ValueError

If the movie_input is not four-dimensional.

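ValueError

If as_layer_activity is True and the connectome is not a ConnectomeFromAvgFilters.

AssertionError

If, within the simulation context, the network is still in training mode or any parameter requires grad.

Note

A time step dt > 1/50 does not raise an error; it emits an IntegrationWarning.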

Source code in flyvision/network/network.py
def simulate(
    self,
    movie_input: torch.Tensor,
    dt: float,
    initial_state: Union[AutoDeref, None, Literal["auto"]] = "auto",
    as_states: bool = False,
    as_layer_activity: bool = False,
) -> Union[torch.Tensor, AutoDeref, LayerActivity]:
    """Simulate the network activity from movie input.

    Args:
        movie_input: Tensor of shape (batch_size, n_frames, 1, hexals).
        dt: Integration time step. Warns if dt > 1/50.
        initial_state: Network state at the beginning of the simulation.
            Use fade_in_state or steady_state to compute the initial state by
            ramping up the contrast of the first movie frame or from grey input,
            respectively. Defaults to "auto", which uses the steady_state after
            1s of grey input.
        as_states: If True, return the per-frame states as a list of AutoDeref
            objects instead of a tensor. Defaults to False.
        as_layer_activity: If True, return a LayerActivity object. Defaults to False.
            Currently only supported for ConnectomeFromAvgFilters.

    Returns:
        Activity tensor of shape (batch_size, n_frames, #neurons),
        or a list of AutoDeref states if `as_states` is True,
        or a LayerActivity object if `as_layer_activity` is True.

    Raises:
        ValueError: If the movie_input is not four-dimensional.
        ValueError: If as_layer_activity is True and the connectome is not
            ConnectomeFromAvgFilters.
        AssertionError: If the network is not in evaluation mode or any
            parameters require grad.
    """
    if len(movie_input.shape) != 4:
        raise ValueError("requires shape (sample, frame, 1, hexals)")

    if (
        as_layer_activity
        and not self.connectome.__class__.__name__ == "ConnectomeFromAvgFilters"
    ):
        raise ValueError(
            "as_layer_activity is currently only supported for "
            "ConnectomeFromAvgFilters"
        )

    if dt > 1 / 50:
        warnings.warn(
            f"dt={dt} is very large for integration. "
            "Better choose a smaller dt (<= 1/50 to avoid this warning)",
            IntegrationWarning,
            stacklevel=2,
        )

    batch_size, n_frames = movie_input.shape[:2]
    if initial_state == "auto":
        initial_state = self.steady_state(1.0, dt, batch_size)
    with simulation(self):
        assert self.training is False and all(
            not p.requires_grad for p in self.parameters()
        )
        self.stimulus.zero(batch_size, n_frames)
        self.stimulus.add_input(movie_input)
        if as_layer_activity:
            return LayerActivity(
                self.forward(self.stimulus(), dt, initial_state, as_states).cpu(),
                self.connectome,
                keepref=True,
            )
        return self.forward(self.stimulus(), dt, initial_state, as_states)
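The input checks at the top of simulate can be exercised in isolation: a wrong number of dimensions is an error, while a large time step only warns. A standard-library sketch, assuming a hypothetical check_simulation_inputs helper that takes the movie shape as a tuple and emits RuntimeWarning in place of flyvision's IntegrationWarning:

```python
import warnings
from typing import Sequence


def check_simulation_inputs(shape: Sequence[int], dt: float) -> None:
    """Validate a movie shape (batch_size, n_frames, 1, hexals) and time step."""
    if len(shape) != 4:
        raise ValueError("requires shape (sample, frame, 1, hexals)")
    if dt > 1 / 50:
        # Large steps only warn; the simulation would still run.
        warnings.warn(
            f"dt={dt} is very large for integration. "
            "Better choose a smaller dt (<= 1/50 to avoid this warning)",
            RuntimeWarning,
            stacklevel=2,
        )
```

A three-dimensional movie of shape (batch_size, n_frames, hexals) needs the singleton axis inserted at position 2 before it is accepted.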

enable_grad

enable_grad(grad=True)

Context manager to enable or disable gradient computation.

Parameters:

Name Type Description Default
grad bool

If True, enable gradient computation.

True
Source code in flyvision/network/network.py
@contextmanager
def enable_grad(self, grad: bool = True):
    """Context manager to enable or disable gradient computation.

    Args:
        grad: If True, enable gradient computation.
    """
    prev = torch.is_grad_enabled()
    torch.set_grad_enabled(grad)
    try:
        yield
    finally:
        torch.set_grad_enabled(prev)
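enable_grad follows the save, set, and restore pattern, with the restore in a finally clause so the previous mode survives an exception inside the block (PyTorch's torch.set_grad_enabled can also be used directly as a context manager). A standard-library sketch of the pattern, assuming a hypothetical global flag in place of PyTorch's grad mode:

```python
from contextlib import contextmanager

# Hypothetical global flag standing in for PyTorch's grad-mode switch.
_GRAD_ENABLED = True


def is_grad_enabled() -> bool:
    return _GRAD_ENABLED


def set_grad_enabled(mode: bool) -> None:
    global _GRAD_ENABLED
    _GRAD_ENABLED = mode


@contextmanager
def enable_grad(grad: bool = True):
    prev = is_grad_enabled()
    set_grad_enabled(grad)
    try:
        yield
    finally:
        # Restore the previous mode even if the body raised.
        set_grad_enabled(prev)
```

Saving the previous mode, instead of unconditionally re-enabling it on exit, keeps nested uses such as steady_state called from inside stimulus_response from changing the caller's mode.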

stimulus_response

stimulus_response(
    stim_dataset,
    dt,
    indices=None,
    t_pre=1.0,
    t_fade_in=0.0,
    grad=False,
    default_stim_key="lum",
    batch_size=1,
)

Compute stimulus responses for a given stimulus dataset.

Parameters:

Name Type Description Default
stim_dataset SequenceDataset

Stimulus dataset.

required
dt float

Integration time step.

required
indices Optional[Iterable[int]]

Indices of the stimuli to compute the response for. If not given, responses to all stimuli are computed.

None
t_pre float

Duration of the grey-scale stimulus in seconds.

1.0
t_fade_in float

Duration of the (slow) fade-in stimulus in seconds.

0.0
grad bool

If True, the state is computed with gradient.

False
default_stim_key Any

Key of the stimulus in the dataset if it returns a dictionary.

'lum'
batch_size int

Batch size for processing.

1
Note

By default, applies a grey-scale stimulus for 1 second and no fade-in stimulus.

Yields:

Type Description

Tuple of (stimulus, response) as numpy arrays.

Source code in flyvision/network/network.py
def stimulus_response(
    self,
    stim_dataset: SequenceDataset,
    dt: float,
    indices: Optional[Iterable[int]] = None,
    t_pre: float = 1.0,
    t_fade_in: float = 0.0,
    grad: bool = False,
    default_stim_key: Any = "lum",
    batch_size: int = 1,
):
    """Compute stimulus responses for a given stimulus dataset.

    Args:
        stim_dataset: Stimulus dataset.
        dt: Integration time step.
        indices: Indices of the stimuli to compute the response for.
            If not given, responses to all stimuli are computed.
        t_pre: Duration of the grey-scale stimulus in seconds.
        t_fade_in: Duration of the (slow) fade-in stimulus in seconds.
        grad: If True, the state is computed with gradient.
        default_stim_key: Key of the stimulus in the dataset if it returns
            a dictionary.
        batch_size: Batch size for processing.

    Note:
        By default, applies a grey-scale stimulus for 1 second and no
        fade-in stimulus.

    Yields:
        Tuple of (stimulus, response) as numpy arrays.
    """
    stim_dataset.dt = dt
    if indices is None:
        indices = np.arange(len(stim_dataset))
    stim_loader = DataLoader(
        stim_dataset, batch_size=batch_size, sampler=IndexSampler(indices)
    )

    stimulus = self.stimulus

    # compute initial state
    initial_state = self.steady_state(t_pre, dt, batch_size=1, value=0.5)

    with self.enable_grad(grad):
        logger.info("Computing %s stimulus responses.", len(indices))
        for stim in tqdm(
            stim_loader, desc="Batch", total=len(stim_loader), leave=False
        ):
            # when datasets return dictionaries, we assume that the stimulus
            # is stored under the key `default_stim_key`
            if isinstance(stim, dict):
                stim = stim[default_stim_key]  # (batch, frames, 1, hexals)
            else:
                stim = stim.unsqueeze(-2)  # (batch, frames, 1, hexals)

            # fade in stimulus
            fade_in_state = self.fade_in_state(
                t_fade_in=t_fade_in,
                dt=dt,
                initial_frames=stim[:, 0],
                state=initial_state,
            )

            def handle_stim(stim, fade_in_state):
                # reset stimulus
                batch_size, n_frames = stim.shape[:2]
                stimulus.zero(batch_size, n_frames)

                # add stimulus
                stimulus.add_input(stim)

                # compute response
                if grad is False:
                    return (
                        stim.cpu().numpy(),
                        self(stimulus(), dt, state=fade_in_state)
                        .detach()
                        .cpu()
                        .numpy(),
                    )
                elif grad is True:
                    return (
                        stim.cpu().numpy(),
                        self(stimulus(), dt, state=fade_in_state),
                    )

            yield handle_stim(stim, fade_in_state)
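Before integration, stimulus_response brings every batch to the shape (batch, frames, 1, hexals): dict samples are indexed with default_stim_key and are assumed to carry the singleton axis already, while plain tensors get it inserted with unsqueeze(-2). A NumPy sketch of that normalization, with np.expand_dims standing in for unsqueeze; normalize_stimulus is a hypothetical name:

```python
import numpy as np


def normalize_stimulus(stim, default_stim_key="lum"):
    """Return a stimulus batch of shape (batch, frames, 1, hexals)."""
    if isinstance(stim, dict):
        # Dict samples are assumed to already include the singleton axis.
        return stim[default_stim_key]
    # NumPy equivalent of torch's stim.unsqueeze(-2).
    return np.expand_dims(stim, -2)
```

Because the method is a generator, each iteration yields one batch as a (stimulus, response) pair, so responses to a whole dataset are typically collected by iterating and concatenating along the batch axis.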

current_response

current_response(
    stim_dataset,
    dt,
    indices=None,
    t_pre=1.0,
    t_fade_in=0,
    default_stim_key="lum",
)

Compute stimulus currents and responses for a given stimulus dataset.

Note

Requires the network dynamics to implement a currents method.

Parameters:

Name Type Description Default
stim_dataset SequenceDataset

Stimulus dataset.

required
dt float

Integration time step.

required
indices Optional[Iterable[int]]

Indices of the stimuli to compute the response for. If not given, responses to all stimuli are computed.

None
t_pre float

Duration of the grey-scale stimulus in seconds.

1.0
t_fade_in float

Duration of the (slow) fade-in stimulus in seconds.

0
default_stim_key Any

Key of the stimulus in the dataset if it returns a dictionary.

'lum'

Yields:

Type Description

Tuple of (stimulus, activity, currents) as numpy arrays.

Source code in flyvision/network/network.py
def current_response(
    self,
    stim_dataset: SequenceDataset,
    dt: float,
    indices: Optional[Iterable[int]] = None,
    t_pre: float = 1.0,
    t_fade_in: float = 0,
    default_stim_key: Any = "lum",
):
    """Compute stimulus currents and responses for a given stimulus dataset.

    Note:
        Requires Dynamics to implement `currents`.

    Args:
        stim_dataset: Stimulus dataset.
        dt: Integration time step.
        indices: Indices of the stimuli to compute the response for.
            If not given, responses to all stimuli are computed.
        t_pre: Duration of the grey-scale stimulus in seconds.
        t_fade_in: Duration of the (slow) fade-in stimulus in seconds.
        default_stim_key: Key of the stimulus in the dataset if it returns
            a dictionary.

    Yields:
        Tuple of (stimulus, activity, currents) as numpy arrays.
    """
    self.clamp()
    # Construct the parameter API.
    params = self._param_api()

    stim_dataset.dt = dt
    if indices is None:
        indices = np.arange(len(stim_dataset))
    stim_loader = DataLoader(
        stim_dataset, batch_size=1, sampler=IndexSampler(indices)
    )

    stimulus = self.stimulus
    initial_state = self.steady_state(t_pre, dt, batch_size=1, value=0.5)
    with torch.no_grad():
        logger.info("Computing %d stimulus responses.", len(indices))
        for stim in stim_loader:
            if isinstance(stim, dict):
                stim = stim[default_stim_key].squeeze(-2)

            fade_in_state = self.fade_in_state(
                t_fade_in=t_fade_in,
                dt=dt,
                initial_frames=stim[:, 0].unsqueeze(1),
                state=initial_state,
            )

            def handle_stim(stim, fade_in_state):
                # reset stimulus
                batch_size, n_frames, _ = stim.shape
                stimulus.zero(batch_size, n_frames)

                # add stimulus
                stimulus.add_input(stim.unsqueeze(2))

                # compute response
                states = self(stimulus(), dt, state=fade_in_state, as_states=True)
                return (
                    stim.cpu().numpy().squeeze(),
                    torch.stack(
                        [s.nodes.activity.cpu() for s in states],
                        dim=1,
                    )
                    .numpy()
                    .squeeze(),
                    torch.stack(
                        [self.dynamics.currents(s, params).cpu() for s in states],
                        dim=1,
                    )
                    .numpy()
                    .squeeze(),
                )

            # stim, activity, currents
            yield handle_stim(stim, fade_in_state)
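current_response stacks the per-frame arrays along axis 1 and then calls squeeze(), which removes every axis of length 1. With the fixed batch size of 1 the batch axis disappears from the yielded arrays, and a stimulus with a single frame would lose its frame axis as well. A NumPy illustration with hypothetical sizes:

```python
import numpy as np

n_frames, n_nodes = 5, 7
# One (batch=1, n_nodes) activity array per frame, as produced per state.
per_frame = [np.full((1, n_nodes), float(t)) for t in range(n_frames)]

stacked = np.stack(per_frame, axis=1)  # (1, n_frames, n_nodes)
squeezed = stacked.squeeze()           # (n_frames, n_nodes): batch axis removed
```

Code consuming the yielded activity and currents should therefore index frames along axis 0, not axis 1.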

Stimulus

Stimuli must implement the StimulusProtocol to be compatible with flyvision.network.network.Network.

flyvision.network.stimulus.StimulusProtocol

Bases: Protocol

Protocol for the Stimulus class.

Source code in flyvision/network/stimulus.py
@runtime_checkable
class StimulusProtocol(Protocol):
    """Protocol for the Stimulus class."""

    def __call__(self) -> Tensor: ...
    def add_input(self, x: Tensor, **kwargs) -> None: ...
    def add_pre_stim(self, x: Tensor, **kwargs) -> None: ...
    def zero(self, **kwargs) -> None: ...
    def nonzero(self) -> bool: ...
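A `runtime_checkable` protocol lets `isinstance` check whether an object provides the listed members. It checks only that the attributes exist, not their signatures or return types. This is why the `Stimulus` class below passes even though it defines `nonzero` as a property rather than a method.

The sketch mirrors the protocol locally as `StimulusLike` so it runs without flyvision. It uses a hypothetical `ListStimulus` that stores its buffer as nested lists instead of a torch `Tensor`.

```python
from typing import List, Protocol, runtime_checkable


@runtime_checkable
class StimulusLike(Protocol):
    """Local mirror of the members listed in StimulusProtocol (hypothetical name)."""

    def __call__(self) -> List[List[float]]: ...
    def add_input(self, x, **kwargs) -> None: ...
    def add_pre_stim(self, x, **kwargs) -> None: ...
    def zero(self, **kwargs) -> None: ...
    def nonzero(self) -> bool: ...


class ListStimulus:
    """Hypothetical stimulus with a (n_frames, n_nodes) nested-list buffer."""

    def __init__(self, n_frames: int, n_nodes: int):
        self.n_frames, self.n_nodes = n_frames, n_nodes
        self.zero()

    def __call__(self) -> List[List[float]]:
        return self.buffer

    def zero(self, **kwargs) -> None:
        self.buffer = [[0.0] * self.n_nodes for _ in range(self.n_frames)]
        self._nonzero = False

    def add_input(self, x, **kwargs) -> None:
        # x: nested list of shape (n_frames, n_nodes), added elementwise
        for t, row in enumerate(x):
            for i, value in enumerate(row):
                self.buffer[t][i] += value
        self._nonzero = True

    def add_pre_stim(self, x, **kwargs) -> None:
        # x: a single grey value added to every frame and node
        for row in self.buffer:
            for i in range(self.n_nodes):
                row[i] += x
        self._nonzero = True

    def nonzero(self) -> bool:
        return self._nonzero


stim = ListStimulus(n_frames=2, n_nodes=3)
print(isinstance(stim, StimulusLike))  # True
stim.add_pre_stim(0.5)
print(stim())  # [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]
```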

flyvision.network.stimulus.Stimulus

Interface to control the cell-specific stimulus buffer for the network.

Creates a buffer and maps standard video input to the photoreceptors, but can also map input to any other cell, e.g., for perturbation experiments.

Parameters:

Name Type Description Default
connectome ConnectomeFromAvgFilters

Connectome directory from which the stimulus-buffer indices of the respective cell positions are retrieved.

required
n_samples int

Number of samples to initialize the buffer with.

1
n_frames int

Number of frames to initialize the buffer with.

1
init_buffer bool

If False, do not initialize the stimulus buffer.

True

Attributes:

Name Type Description
layer_index Dict[str, NDArray]

Dictionary of cell type to index array.

central_cells_index Dict[str, int]

Dictionary of cell type to central cell index.

input_index NDArray

Index array of photoreceptors.

n_frames int

Number of frames in the stimulus buffer.

n_samples int

Number of samples in the stimulus buffer.

n_nodes int

Number of nodes in the stimulus buffer.

n_input_elements int

Number of input elements.

buffer Tensor

Stimulus buffer of shape (n_samples, n_frames, n_cells).

Returns:

Name Type Description
Tensor

Stimulus of shape (n_samples, n_frames, n_cells)

Example

stim = Stimulus(network.connectome, *x.shape[:2])  # x: (n_samples, n_frames, 1, n_input_elements)
stim.add_input(x)
response = network(stim(), dt)

Source code in flyvision/network/stimulus.py
@register_stimulus
class Stimulus:
    """Interface to control the cell-specific stimulus buffer for the network.

    Creates a buffer and maps standard video input to the photoreceptors
    but can map input to any other cell as well, e.g. to do perturbation
    experiments.

    Args:
        connectome: Connectome directory to retrieve indexes for the stimulus
            buffer at the respective cell positions.
        n_samples: Number of samples to initialize the buffer with.
        n_frames: Number of frames to initialize the buffer with.
        init_buffer: If False, do not initialize the stimulus buffer.

    Attributes:
        layer_index (Dict[str, NDArray]): Dictionary of cell type to index array.
        central_cells_index (Dict[str, int]): Dictionary of cell type to
            central cell index.
        input_index (NDArray): Index array of photoreceptors.
        n_frames (int): Number of frames in the stimulus buffer.
        n_samples (int): Number of samples in the stimulus buffer.
        n_nodes (int): Number of nodes in the stimulus buffer.
        n_input_elements (int): Number of input elements.
        buffer (Tensor): Stimulus buffer of shape (n_samples, n_frames, n_cells).

    Returns:
        Tensor: Stimulus of shape (n_samples, n_frames, n_cells)

    Example:
        ```python
        stim = Stimulus(network.connectome, *x.shape[:2])
        stim.add_input(x)
        response = network(stim(), dt)
        ```
    """

    layer_index: Dict[str, NDArray]
    central_cells_index: Dict[str, int]
    input_index: NDArray
    n_frames: int
    n_samples: int
    n_nodes: int
    n_input_elements: int
    buffer: Tensor

    def __init__(
        self,
        connectome: ConnectomeFromAvgFilters,
        n_samples: int = 1,
        n_frames: int = 1,
        init_buffer: bool = True,
    ):
        self.layer_index = {
            cell_type: index[:]
            for cell_type, index in connectome.nodes.layer_index.items()
        }
        self.central_cells_index = dict(
            zip(
                connectome.unique_cell_types[:].astype(str),
                connectome.central_cells_index[:],
            )
        )
        self.input_index = np.array([
            self.layer_index[cell_type.decode()]
            for cell_type in connectome.input_cell_types[:]
        ])
        self.n_input_elements = self.input_index.shape[1]
        self.n_samples, self.n_frames, self.n_nodes = (
            n_samples,
            n_frames,
            len(connectome.nodes.type),
        )
        self.connectome = connectome
        if init_buffer:
            self.zero()

    def zero(
        self,
        n_samples: Optional[int] = None,
        n_frames: Optional[int] = None,
    ) -> None:
        """Reset the stimulus buffer to zero.

        Args:
            n_samples: Number of samples. If provided, the buffer will be resized.
            n_frames: Number of frames. If provided, the buffer will be resized.
        """
        self.n_samples = n_samples or self.n_samples
        self.n_frames = n_frames or self.n_frames
        if hasattr(self, "buffer") and self.buffer.shape[:2] == (
            self.n_samples,
            self.n_frames,
        ):
            self.buffer.zero_()
            return
        self.buffer = torch.zeros((self.n_samples, self.n_frames, self.n_nodes))
        self._nonzero = False

    @property
    def nonzero(self) -> bool:
        """Check if elements have been added to the stimulus buffer.

        Returns:
            bool: True if elements have been added, even if those elements were all zero.
        """
        return self._nonzero

    def add_input(
        self,
        x: torch.Tensor,
        start: Optional[int] = None,
        stop: Optional[int] = None,
        n_frames_buffer: Optional[int] = None,
        cumulate: bool = False,
    ) -> None:
        """Add input to the input/photoreceptor cells.

        Args:
            x: Input video of shape (n_samples, n_frames, 1, n_input_elements).
            start: Temporal start index of the stimulus.
            stop: Temporal stop index of the stimulus.
            n_frames_buffer: Number of frames to resize the buffer to.
            cumulate: If True, add input to the existing buffer.

        Raises:
            ValueError: If input shape is incorrect.
            RuntimeError: If input shape doesn't match buffer shape.
        """
        shape = x.shape
        if len(shape) != 4:
            raise ValueError(
                f"input has shape {x.shape} but must have "
                "(n_samples, n_frames, 1, n_input_elements)"
            )
        n_samples, n_frames_input = shape[:2]

        if not hasattr(self, "buffer") or not cumulate and self.nonzero:
            self.zero(n_samples, n_frames_buffer or n_frames_input)

        try:
            self.buffer[:, slice(start, stop), self.input_index] += x.to(
                self.buffer.device
            )
        except RuntimeError as e:
            raise RuntimeError(
                f"input has shape {x.shape} but buffer has shape {self.buffer.shape}"
            ) from e
        self._nonzero = True

    def add_pre_stim(
        self,
        x: torch.Tensor,
        start: Optional[int] = None,
        stop: Optional[int] = None,
        n_frames_buffer: Optional[int] = None,
    ) -> None:
        """Add a constant or sequence of constants to the input/photoreceptor cells.

        Args:
            x: Grey value(s). If Tensor, must have length `n_frames` or `stop - start`.
            start: Start index in time.
            stop: Stop index in time.
            n_frames_buffer: Number of frames to resize the buffer to.

        Raises:
            RuntimeError: If input shape doesn't match buffer shape.
        """
        if not hasattr(self, "buffer") or self.nonzero:
            self.zero(None, n_frames_buffer)

        try:
            if isinstance(x, torch.Tensor) and x.ndim == 1:
                self.buffer[:, slice(start, stop), self.input_index] += x.view(
                    1, len(x), 1, 1
                )
            else:
                self.buffer[:, slice(start, stop), self.input_index] += x
        except RuntimeError as e:
            raise RuntimeError(
                f"input has shape {x.shape} but buffer has shape {self.buffer.shape}"
            ) from e
        self._nonzero = True

    def __call__(self) -> torch.Tensor:
        """Return the stimulus tensor.

        Returns:
            torch.Tensor: The stimulus buffer.
        """
        return self.buffer
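In `add_input`, the expression `self.buffer[:, slice(start, stop), self.input_index]` selects one block per input cell type. This works because `input_index` is a 2D array of shape `(n_input_types, n_input_elements)`. The singleton axis of `x` then broadcasts the same video to every input type.

The NumPy sketch below illustrates this with toy sizes and a hypothetical `input_index`. It assumes PyTorch advanced indexing broadcasts the same way for this pattern.

```python
import numpy as np

# Hypothetical toy sizes; the real values come from the connectome.
n_samples, n_frames, n_nodes = 2, 4, 20
# Two input cell types with three elements (columns) each.
input_index = np.array([[0, 1, 2], [5, 6, 7]])  # (n_input_types, n_input_elements)
n_input_elements = input_index.shape[1]

buffer = np.zeros((n_samples, n_frames, n_nodes))

# x has shape (n_samples, n_frames, 1, n_input_elements), as add_input requires.
x = np.ones((n_samples, n_frames, 1, n_input_elements))

# The indexed view has shape (n_samples, n_frames, n_input_types, n_input_elements);
# the singleton axis of x broadcasts over the input cell types.
buffer[:, slice(None, None), input_index] += x

# start/stop restrict the write to a temporal window (frames 1 and 2 here).
buffer[:, slice(1, 3), input_index] += 2.0 * np.ones((n_samples, 2, 1, n_input_elements))

print(buffer[0, :, 0].tolist())  # [1.0, 3.0, 3.0, 1.0]
```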

nonzero property

nonzero

Check if elements have been added to the stimulus buffer.

Returns:

Name Type Description
bool bool

True if elements have been added, even if those elements were all zero.

zero

zero(n_samples=None, n_frames=None)

Reset the stimulus buffer to zero.

Parameters:

Name Type Description Default
n_samples Optional[int]

Number of samples. If provided, the buffer will be resized.

None
n_frames Optional[int]

Number of frames. If provided, the buffer will be resized.

None
Source code in flyvision/network/stimulus.py
def zero(
    self,
    n_samples: Optional[int] = None,
    n_frames: Optional[int] = None,
) -> None:
    """Reset the stimulus buffer to zero.

    Args:
        n_samples: Number of samples. If provided, the buffer will be resized.
        n_frames: Number of frames. If provided, the buffer will be resized.
    """
    self.n_samples = n_samples or self.n_samples
    self.n_frames = n_frames or self.n_frames
    if hasattr(self, "buffer") and self.buffer.shape[:2] == (
        self.n_samples,
        self.n_frames,
    ):
        self.buffer.zero_()
        return
    self.buffer = torch.zeros((self.n_samples, self.n_frames, self.n_nodes))
    self._nonzero = False
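`zero` either clears the existing buffer in place, when the requested `(n_samples, n_frames)` matches the current shape, or allocates a new one. Because it uses `n_samples or self.n_samples`, both `None` and `0` mean "keep the current size". In the listing above, the in-place branch returns before `self._nonzero = False` runs, so `nonzero` keeps its previous value after zeroing a same-shaped buffer.

The sketch below reproduces the reuse-or-reallocate logic. It uses a hypothetical `ToyBuffer` class and NumPy in place of torch. Unlike the listing, it resets the flag on both paths.

```python
import numpy as np
from typing import Optional


class ToyBuffer:
    """Hypothetical stand-in for Stimulus.zero's resize-or-reuse logic."""

    def __init__(self, n_samples: int, n_frames: int, n_nodes: int):
        self.n_samples, self.n_frames, self.n_nodes = n_samples, n_frames, n_nodes
        self.buffer = np.zeros((n_samples, n_frames, n_nodes))
        self._nonzero = False

    def zero(self, n_samples: Optional[int] = None, n_frames: Optional[int] = None) -> None:
        # `or` treats None (and 0) as "keep the current size"
        self.n_samples = n_samples or self.n_samples
        self.n_frames = n_frames or self.n_frames
        self._nonzero = False  # this sketch resets the flag on both paths
        if self.buffer.shape[:2] == (self.n_samples, self.n_frames):
            self.buffer[...] = 0.0  # same shape: clear in place, keep the allocation
            return
        self.buffer = np.zeros((self.n_samples, self.n_frames, self.n_nodes))


buf = ToyBuffer(1, 4, 5)
first = buf.buffer
buf.buffer += 1.0
buf._nonzero = True
buf.zero()  # same shape: cleared in place
print(buf.buffer is first, buf.buffer.sum())  # True 0.0
buf.zero(n_frames=8)  # new frame count: reallocated
print(buf.buffer is first, buf.buffer.shape)  # False (1, 8, 5)
```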

add_input

add_input(
    x,
    start=None,
    stop=None,
    n_frames_buffer=None,
    cumulate=False,
)

Add input to the input/photoreceptor cells.

Parameters:

Name Type Description Default
x Tensor

Input video of shape (n_samples, n_frames, 1, n_input_elements).

required
start Optional[int]

Temporal start index of the stimulus.

None
stop Optional[int]

Temporal stop index of the stimulus.

None
n_frames_buffer Optional[int]

Number of frames to resize the buffer to.

None
cumulate bool

If True, add input to the existing buffer.

False

Raises:

Type Description
ValueError

If input shape is incorrect.

RuntimeError

If input shape doesn’t match buffer shape.

Source code in flyvision/network/stimulus.py
def add_input(
    self,
    x: torch.Tensor,
    start: Optional[int] = None,
    stop: Optional[int] = None,
    n_frames_buffer: Optional[int] = None,
    cumulate: bool = False,
) -> None:
    """Add input to the input/photoreceptor cells.

    Args:
        x: Input video of shape (n_samples, n_frames, 1, n_input_elements).
        start: Temporal start index of the stimulus.
        stop: Temporal stop index of the stimulus.
        n_frames_buffer: Number of frames to resize the buffer to.
        cumulate: If True, add input to the existing buffer.

    Raises:
        ValueError: If input shape is incorrect.
        RuntimeError: If input shape doesn't match buffer shape.
    """
    shape = x.shape
    if len(shape) != 4:
        raise ValueError(
            f"input has shape {x.shape} but must have "
            "(n_samples, n_frames, 1, n_input_elements)"
        )
    n_samples, n_frames_input = shape[:2]

    if not hasattr(self, "buffer") or not cumulate and self.nonzero:
        self.zero(n_samples, n_frames_buffer or n_frames_input)

    try:
        self.buffer[:, slice(start, stop), self.input_index] += x.to(
            self.buffer.device
        )
    except RuntimeError as e:
        raise RuntimeError(
            f"input has shape {x.shape} but buffer has shape {self.buffer.shape}"
        ) from e
    self._nonzero = True
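The reset condition `not hasattr(self, "buffer") or not cumulate and self.nonzero` relies on Python precedence: `and` binds tighter than `or`. It therefore reads as "no buffer yet, or (not cumulating and the buffer already holds input)". The small pure-Python check below mirrors that expression with a hypothetical `should_reset` helper and enumerates every case.

```python
from itertools import product


def should_reset(has_buffer: bool, cumulate: bool, nonzero: bool) -> bool:
    # Same expression as in add_input; `and` binds tighter than `or`,
    # so this is: (not has_buffer) or ((not cumulate) and nonzero).
    return not has_buffer or not cumulate and nonzero


for has_buffer, cumulate, nonzero in product([True, False], repeat=3):
    print(
        f"has_buffer={has_buffer!s:5} cumulate={cumulate!s:5} "
        f"nonzero={nonzero!s:5} -> reset={should_reset(has_buffer, cumulate, nonzero)}"
    )
```

Under this reading, a buffer that exists but holds no input yet (`nonzero` is False) is not resized by `add_input`. This is presumably why the class example sizes the buffer from `x.shape[:2]` at construction.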

add_pre_stim

add_pre_stim(
    x, start=None, stop=None, n_frames_buffer=None
)

Add a constant or sequence of constants to the input/photoreceptor cells.

Parameters:

Name Type Description Default
x Tensor

Grey value(s). If Tensor, must have length n_frames or stop - start.

required
start Optional[int]

Start index in time.

None
stop Optional[int]

Stop index in time.

None
n_frames_buffer Optional[int]

Number of frames to resize the buffer to.

None

Raises:

Type Description
RuntimeError

If input shape doesn’t match buffer shape.

Source code in flyvision/network/stimulus.py
def add_pre_stim(
    self,
    x: torch.Tensor,
    start: Optional[int] = None,
    stop: Optional[int] = None,
    n_frames_buffer: Optional[int] = None,
) -> None:
    """Add a constant or sequence of constants to the input/photoreceptor cells.

    Args:
        x: Grey value(s). If Tensor, must have length `n_frames` or `stop - start`.
        start: Start index in time.
        stop: Stop index in time.
        n_frames_buffer: Number of frames to resize the buffer to.

    Raises:
        RuntimeError: If input shape doesn't match buffer shape.
    """
    if not hasattr(self, "buffer") or self.nonzero:
        self.zero(None, n_frames_buffer)

    try:
        if isinstance(x, torch.Tensor) and x.ndim == 1:
            self.buffer[:, slice(start, stop), self.input_index] += x.view(
                1, len(x), 1, 1
            )
        else:
            self.buffer[:, slice(start, stop), self.input_index] += x
    except RuntimeError as e:
        raise RuntimeError(
            f"input has shape {x.shape} but buffer has shape {self.buffer.shape}"
        ) from e
    self._nonzero = True
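When `x` is a 1D tensor, `add_pre_stim` reshapes it with `x.view(1, len(x), 1, 1)`. This gives one grey value per frame, broadcast across samples, input cell types, and input elements. A plain number takes the `else` branch and is added to the whole selected window.

The sketch below shows both cases with NumPy (`reshape` standing in for `view`) and a hypothetical `input_index`.

```python
import numpy as np

n_samples, n_frames, n_nodes = 2, 3, 10
input_index = np.array([[0, 1], [4, 5]])  # hypothetical: 2 input types x 2 elements
buffer = np.zeros((n_samples, n_frames, n_nodes))

# One grey value per frame; same reshape as x.view(1, len(x), 1, 1).
grey = np.array([0.0, 0.5, 1.0])
buffer[:, :, input_index] += grey.reshape(1, len(grey), 1, 1)
print(buffer[0, :, 0].tolist())  # [0.0, 0.5, 1.0]

# A scalar (the else branch) is added to every input cell in the window.
buffer[:, slice(0, 1), input_index] += 0.25
print(buffer[0, :, 0].tolist())  # [0.25, 0.5, 1.0]
```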

__call__

__call__()

Return the stimulus tensor.

Returns:

Type Description
Tensor

The stimulus buffer.

Source code in flyvision/network/stimulus.py
def __call__(self) -> torch.Tensor:
    """Return the stimulus tensor.

    Returns:
        torch.Tensor: The stimulus buffer.
    """
    return self.buffer

flyvision.network.stimulus.register_stimulus

register_stimulus(cls=None)

Register a stimulus class.

Parameters:

Name Type Description Default
cls Optional[Type[StimulusProtocol]]

The stimulus class to register (optional when used as a decorator).

None

Returns:

Type Description
Union[Callable[[Type[StimulusProtocol]], Type[StimulusProtocol]], Type[StimulusProtocol]]

Registered class or decorator function.

Example

As a standalone function:

register_stimulus(CustomStimulus)

As a decorator:

@register_stimulus
class CustomStimulus(StimulusProtocol): ...

Source code in flyvision/network/stimulus.py
def register_stimulus(
    cls: Optional[Type[StimulusProtocol]] = None,
) -> Union[
    Callable[[Type[StimulusProtocol]], Type[StimulusProtocol]], Type[StimulusProtocol]
]:
    """Register a stimulus class.

    Args:
        cls: The stimulus class to register (optional when used as a decorator).

    Returns:
        Registered class or decorator function.

    Example:
        As a standalone function:
        ```python
        register_stimulus(CustomStimulus)
        ```

        As a decorator:
        ```python
        @register_stimulus
        class CustomStimulus(StimulusProtocol): ...
        ```
    """

    def decorator(cls: Type[StimulusProtocol]) -> Type[StimulusProtocol]:
        AVAILABLE_STIMULI[cls.__name__] = cls
        return cls

    if cls is None:
        return decorator
    else:
        return decorator(cls)
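`register_stimulus` accepts the class directly, so it works both as a bare decorator and as a plain function call. Because `cls` defaults to `None`, it also works when called with empty parentheses.

The sketch below reproduces the pattern with a hypothetical `REGISTRY` dict standing in for `AVAILABLE_STIMULI`. Entries are keyed by `cls.__name__`, so registering two classes with the same name keeps only the last one.

```python
from typing import Callable, Dict, Optional, Union

REGISTRY: Dict[str, type] = {}  # hypothetical stand-in for AVAILABLE_STIMULI


def register(cls: Optional[type] = None) -> Union[Callable[[type], type], type]:
    def decorator(cls: type) -> type:
        REGISTRY[cls.__name__] = cls  # keyed by class name
        return cls

    # Bare @register passes the class; @register() passes nothing (cls is None).
    return decorator if cls is None else decorator(cls)


@register
class BareDecorated: ...


@register()
class CalledDecorated: ...


class Plain: ...


register(Plain)

print(sorted(REGISTRY))  # ['BareDecorated', 'CalledDecorated', 'Plain']
```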

flyvision.network.stimulus.init_stimulus

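init_stimulus(connectome, **kwargs)

Instantiate the registered stimulus class named by the type keyword argument, passing connectome and the remaining keyword arguments to its constructor. Returns None if no type is given.
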
Source code in flyvision/network/stimulus.py
def init_stimulus(connectome: ConnectomeFromAvgFilters, **kwargs) -> StimulusProtocol:
    if "type" not in kwargs:
        return None
    stimulus_class = AVAILABLE_STIMULI[kwargs.pop("type")]
    return stimulus_class(connectome, **kwargs)
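`init_stimulus` dispatches on the `type` key: it pops the class name, looks it up in the registry, and forwards the remaining keyword arguments to the constructor. Because `**kwargs` collects a fresh dict on every call, popping `type` does not modify a config dict that the caller unpacked.

The sketch below mirrors this with a hypothetical `ToyStimulus`, registry, and `init_from_config` function. It does not use flyvision's own classes.

```python
from typing import Any, Dict, Optional


class ToyStimulus:
    """Hypothetical stimulus class accepting the connectome plus keyword options."""

    def __init__(self, connectome: Any, n_samples: int = 1, n_frames: int = 1):
        self.connectome = connectome
        self.n_samples, self.n_frames = n_samples, n_frames


REGISTRY: Dict[str, type] = {"ToyStimulus": ToyStimulus}  # hypothetical registry


def init_from_config(connectome: Any, **kwargs) -> Optional[Any]:
    if "type" not in kwargs:
        return None  # no stimulus configured
    stimulus_class = REGISTRY[kwargs.pop("type")]  # remaining kwargs go to __init__
    return stimulus_class(connectome, **kwargs)


config = {"type": "ToyStimulus", "n_frames": 10}
stim = init_from_config("connectome-placeholder", **config)
print(type(stim).__name__, stim.n_frames)  # ToyStimulus 10
print(init_from_config("connectome-placeholder"))  # None
```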

Dynamics

flyvision.network.dynamics.NetworkDynamics

Defines the initialization and behavior of a Network during simulation.

This class serves as an extension point for implementing custom network dynamics models. Subclasses must implement the following methods:

  • write_derived_params
  • write_initial_state
  • write_state_velocity
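The extension pattern is to subclass and override the three hooks, each writing into the namespaces it receives. The sketch below keeps it self-contained with two substitutions:

- `types.SimpleNamespace` stands in for the `AutoDeref` directories.
- A hypothetical `DynamicsBase` stands in for `NetworkDynamics`.

The `LeakyNodes` dynamics (uncoupled nodes relaxing to their bias) are illustrative only.

```python
from types import SimpleNamespace
from typing import Callable


class DynamicsBase:
    """Hypothetical stand-in for NetworkDynamics: no-op hooks to override."""

    def write_derived_params(self, params, **kwargs) -> None:
        pass

    def write_initial_state(self, state, params, **kwargs) -> None:
        pass

    def write_state_velocity(self, vel, state, params, target_sum: Callable, **kwargs) -> None:
        pass


class LeakyNodes(DynamicsBase):
    """Hypothetical dynamics: dV/dt = (bias - V) / tau, no synaptic input."""

    def write_derived_params(self, params, **kwargs) -> None:
        # Derived once per forward pass rather than at every time step.
        params.nodes.rate = [1.0 / tau for tau in params.nodes.time_const]

    def write_initial_state(self, state, params, **kwargs) -> None:
        state.nodes.activity = list(params.nodes.bias)

    def write_state_velocity(self, vel, state, params, target_sum, **kwargs) -> None:
        vel.nodes.activity = [
            r * (b - a)
            for r, b, a in zip(params.nodes.rate, params.nodes.bias, state.nodes.activity)
        ]


params = SimpleNamespace(
    nodes=SimpleNamespace(bias=[0.5, -0.2], time_const=[0.05, 0.1]),
    edges=SimpleNamespace(),
)
state = SimpleNamespace(nodes=SimpleNamespace(), edges=SimpleNamespace())
vel = SimpleNamespace(nodes=SimpleNamespace(), edges=SimpleNamespace())

dyn = LeakyNodes()
dyn.write_derived_params(params)
dyn.write_initial_state(state, params)
state.nodes.activity = [1.0, 0.0]  # perturb away from the bias
dyn.write_state_velocity(vel, state, params, target_sum=sum)
print(vel.nodes.activity)
```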

Attributes:

Name Type Description
activation Module

The activation function for the network.

Parameters:

Name Type Description Default
activation dict

A dictionary specifying the activation function type and its parameters.

{'type': 'relu'}
Source code in flyvision/network/dynamics.py
class NetworkDynamics:
    """
    Defines the initialization and behavior of a Network during simulation.

    This class serves as an extension point for implementing custom network dynamics
    models. Subclasses must implement the following methods:

    - write_derived_params
    - write_initial_state
    - write_state_velocity

    Attributes:
        activation (nn.Module): The activation function for the network.

    Args:
        activation (dict): A dictionary specifying the activation function type and
            its parameters.
    """

    def __init__(self, activation: Dict[str, str] = {"type": "relu"}):
        self.activation = activation_fns[activation.pop("type")](**activation)

    def write_derived_params(
        self, params: AutoDeref[str, AutoDeref[str, RefTensor]], **kwargs
    ) -> None:
        """
        Augment `params`, called once at every forward pass.

        Args:
            params: A directory containing two subdirectories: `nodes` and
                `edges`, containing node and edges parameters, respectively.
            **kwargs: Additional keyword arguments.

        Note:
            This is called once per forward pass at the beginning. It's required
            after parameters have been updated by an optimizer but not at every
            timestep. Called by Network._param_api.
        """
        pass

    def write_initial_state(
        self,
        state: AutoDeref[str, AutoDeref[str, RefTensor]],
        params: AutoDeref[str, AutoDeref[str, RefTensor]],
        **kwargs,
    ) -> None:
        """
        Initialize a network's state variables from its network parameters.

        Args:
            state: A directory containing two subdirectories: `nodes` and
                `edges`. Write initial node and edge state variable values, as
                1D tensors, into them, respectively.
            params: A directory containing four subdirectories: `nodes`,
                `edges`, `sources`, and `targets`. `nodes` and `edges` contain
                node and edges parameters, respectively. `sources` and
                `targets` provide access to the node parameters associated with
                the source node and target node of each edge, respectively.
            **kwargs: Additional keyword arguments.

        Note:
            Called by Network._initial_state.
        """
        pass

    def write_state_velocity(
        self,
        vel: AutoDeref[str, AutoDeref[str, RefTensor]],
        state: AutoDeref[str, AutoDeref[str, RefTensor]],
        params: AutoDeref[str, AutoDeref[str, RefTensor]],
        target_sum: Callable,
        **kwargs,
    ) -> None:
        """
        Compute dx/dt for each state variable.

        Args:
            vel: A directory containing two subdirectories: `nodes` and
                `edges`. Write dx/dt for node and edge state variables
                into them, respectively.
            state: A directory containing two subdirectories: `nodes` and
                `edges`, containing node and edge state variable values,
                respectively.
            params: A directory containing four subdirectories: `nodes`,
                `edges`, `sources`, and `targets`. `nodes` and `edges` contain
                node and edges parameters, respectively. `sources` and
                `targets` provide access to the node parameters associated with
                the source node and target node of each edge, respectively.
            target_sum: Sums the entries in a `len(edges)` tensor corresponding
                to edges with the same target node, yielding a `len(nodes)`
                tensor.
            **kwargs: Additional keyword arguments.

        Note:
            Called by Network._next_state.
        """
        pass

    def currents(
        self,
        state: AutoDeref[str, AutoDeref[str, RefTensor]],
        params: AutoDeref[str, AutoDeref[str, RefTensor]],
    ) -> torch.Tensor:
        """
        Compute the current flowing through each edge.

        Args:
            state: A directory containing two subdirectories: `nodes` and
                `edges`, containing node and edge state variable values,
                respectively.
            params: A directory containing four subdirectories: `nodes`,
                `edges`, `sources`, and `targets`. `nodes` and `edges` contain
                node and edges parameters, respectively. `sources` and
                `targets` provide access to the node parameters associated with
                the source node and target node of each edge, respectively.

        Returns:
            A tensor of currents flowing through each edge.

        Note:
            Called by Network.current_response.
        """
        pass
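The `__init__` in the listing above uses the dict literal `{"type": "relu"}` as a default and then calls `activation.pop("type")` on it. Python evaluates default values once, at function definition, so the pop empties that shared dict. Read literally, a second `NetworkDynamics()` call without arguments would raise `KeyError`. A plain dict passed in by the caller is also modified in place.

The sketch below reproduces the pattern with a hypothetical `make_relu` factory standing in for `activation_fns`. It places the pattern alongside a copy-first variant, which is one common way to avoid the issue and not the library's code.

```python
from typing import Callable, Dict, Optional


def make_relu() -> Callable[[float], float]:
    """Hypothetical activation factory standing in for activation_fns['relu']."""
    return lambda v: max(v, 0.0)


ACTIVATIONS: Dict[str, Callable[..., Callable[[float], float]]] = {"relu": make_relu}


class Unsafe:
    def __init__(self, activation: Dict[str, str] = {"type": "relu"}):
        # pop() mutates the single default dict shared by every call
        self.activation = ACTIVATIONS[activation.pop("type")](**activation)


class Safe:
    def __init__(self, activation: Optional[Dict[str, str]] = None):
        config = dict(activation or {"type": "relu"})  # private copy per call
        self.activation = ACTIVATIONS[config.pop("type")](**config)


Unsafe()
try:
    Unsafe()
    second_ok = True
except KeyError:
    second_ok = False
print(second_ok)  # False

Safe()
print(Safe().activation(-1.0))  # 0.0
```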

write_derived_params

write_derived_params(params, **kwargs)

Augment params with derived parameters; called once at the beginning of every forward pass.

Parameters:

Name Type Description Default
params AutoDeref[str, AutoDeref[str, RefTensor]]

A directory containing two subdirectories: nodes and edges, containing node and edge parameters, respectively.

required
**kwargs

Additional keyword arguments.

{}
Note

This is called once at the beginning of each forward pass: derived parameters must be recomputed after an optimizer has updated the parameters, but not at every timestep. Called by Network._param_api.

Source code in flyvision/network/dynamics.py
def write_derived_params(
    self, params: AutoDeref[str, AutoDeref[str, RefTensor]], **kwargs
) -> None:
    """
    Augment `params`, called once at every forward pass.

    Args:
        params: A directory containing two subdirectories: `nodes` and
            `edges`, containing node and edges parameters, respectively.
        **kwargs: Additional keyword arguments.

    Note:
        This is called once per forward pass at the beginning. It's required
        after parameters have been updated by an optimizer but not at every
        timestep. Called by Network._param_api.
    """
    pass

write_initial_state

write_initial_state(state, params, **kwargs)

Initialize a network’s state variables from its network parameters.

Parameters:

Name Type Description Default
state AutoDeref[str, AutoDeref[str, RefTensor]]

A directory containing two subdirectories: nodes and edges. Write initial node and edge state variable values, as 1D tensors, into them, respectively.

required
params AutoDeref[str, AutoDeref[str, RefTensor]]

A directory containing four subdirectories: nodes, edges, sources, and targets. nodes and edges contain node and edge parameters, respectively. sources and targets provide access to the node parameters associated with the source node and target node of each edge, respectively.

required
**kwargs

Additional keyword arguments.

{}
Note

Called by Network._initial_state.

Source code in flyvision/network/dynamics.py
def write_initial_state(
    self,
    state: AutoDeref[str, AutoDeref[str, RefTensor]],
    params: AutoDeref[str, AutoDeref[str, RefTensor]],
    **kwargs,
) -> None:
    """
    Initialize a network's state variables from its network parameters.

    Args:
        state: A directory containing two subdirectories: `nodes` and
            `edges`. Write initial node and edge state variable values, as
            1D tensors, into them, respectively.
        params: A directory containing four subdirectories: `nodes`,
            `edges`, `sources`, and `targets`. `nodes` and `edges` contain
            node and edges parameters, respectively. `sources` and
            `targets` provide access to the node parameters associated with
            the source node and target node of each edge, respectively.
        **kwargs: Additional keyword arguments.

    Note:
        Called by Network._initial_state.
    """
    pass

write_state_velocity

write_state_velocity(
    vel, state, params, target_sum, **kwargs
)

Compute dx/dt for each state variable.

Parameters:

Name Type Description Default
vel AutoDeref[str, AutoDeref[str, RefTensor]]

A directory containing two subdirectories: nodes and edges. Write dx/dt for node and edge state variables into them, respectively.

required
state AutoDeref[str, AutoDeref[str, RefTensor]]

A directory containing two subdirectories: nodes and edges, containing node and edge state variable values, respectively.

required
params AutoDeref[str, AutoDeref[str, RefTensor]]

A directory containing four subdirectories: nodes, edges, sources, and targets. nodes and edges contain node and edge parameters, respectively. sources and targets provide access to the node parameters associated with the source node and target node of each edge, respectively.

required
target_sum Callable

Sums the entries in a len(edges) tensor corresponding to edges with the same target node, yielding a len(nodes) tensor.

required
**kwargs

Additional keyword arguments.

{}
Note

Called by Network._next_state.

Source code in flyvision/network/dynamics.py
def write_state_velocity(
    self,
    vel: AutoDeref[str, AutoDeref[str, RefTensor]],
    state: AutoDeref[str, AutoDeref[str, RefTensor]],
    params: AutoDeref[str, AutoDeref[str, RefTensor]],
    target_sum: Callable,
    **kwargs,
) -> None:
    """
    Compute dx/dt for each state variable.

    Args:
        vel: A directory containing two subdirectories: `nodes` and
            `edges`. Write dx/dt for node and edge state variables
            into them, respectively.
        state: A directory containing two subdirectories: `nodes` and
            `edges`, containing node and edge state variable values,
            respectively.
        params: A directory containing four subdirectories: `nodes`,
            `edges`, `sources`, and `targets`. `nodes` and `edges` contain
            node and edges parameters, respectively. `sources` and
            `targets` provide access to the node parameters associated with
            the source node and target node of each edge, respectively.
        target_sum: Sums the entries in a `len(edges)` tensor corresponding
            to edges with the same target node, yielding a `len(nodes)`
            tensor.
        **kwargs: Additional keyword arguments.

    Note:
        Called by Network._next_state.
    """
    pass
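`target_sum` reduces a per-edge tensor to a per-node tensor by summing over the edges that share a target node. The NumPy sketch below shows one way to build such a function. It assumes an integer array mapping each edge to its target node (the network keeps a `_target_indices` tensor, but the sketch does not depend on it). `make_target_sum` and the toy graph are hypothetical.

```python
import numpy as np


def make_target_sum(target_indices: np.ndarray, n_nodes: int):
    """Hypothetical builder: returns a function mapping len(edges) -> len(nodes)."""

    def target_sum(edge_values: np.ndarray) -> np.ndarray:
        # Add each edge's value into its target node; nodes without inputs get 0.
        return np.bincount(target_indices, weights=edge_values, minlength=n_nodes)

    return target_sum


# Toy graph: 4 edges into 3 nodes; node 2 receives no edges.
target_indices = np.array([0, 1, 1, 0])
target_sum = make_target_sum(target_indices, n_nodes=3)
print(target_sum(np.array([1.0, 2.0, 3.0, 4.0])).tolist())  # [5.0, 5.0, 0.0]
```

`np.bincount` handles only 1D inputs. For batched tensors in PyTorch, `Tensor.index_add_` along the last dimension performs the same scatter-sum.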

currents

currents(state, params)

Compute the current flowing through each edge.

Parameters:

Name Type Description Default
state AutoDeref[str, AutoDeref[str, RefTensor]]

A directory containing two subdirectories: nodes and edges, containing node and edge state variable values, respectively.

required
params AutoDeref[str, AutoDeref[str, RefTensor]]

A directory containing four subdirectories: nodes, edges, sources, and targets. nodes and edges contain node and edge parameters, respectively. sources and targets provide access to the node parameters associated with the source node and target node of each edge, respectively.

required

Returns:

Type Description
Tensor

A tensor of currents flowing through each edge.

Note

Called by Network.current_response.

Source code in flyvision/network/dynamics.py
def currents(
    self,
    state: AutoDeref[str, AutoDeref[str, RefTensor]],
    params: AutoDeref[str, AutoDeref[str, RefTensor]],
) -> torch.Tensor:
    """
    Compute the current flowing through each edge.

    Args:
        state: A directory containing two subdirectories: `nodes` and
            `edges`, containing node and edge state variable values,
            respectively.
        params: A directory containing four subdirectories: `nodes`,
            `edges`, `sources`, and `targets`. `nodes` and `edges` contain
            node and edges parameters, respectively. `sources` and
            `targets` provide access to the node parameters associated with
            the source node and target node of each edge, respectively.

    Returns:
        A tensor of currents flowing through each edge.

    Note:
        Called by Network.current_response.
    """
    pass

flyvision.network.dynamics.PPNeuronIGRSynapses

Bases: NetworkDynamics

Passive point neurons with instantaneous graded release synapses.
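Read directly from the methods of this class, the model can be summarized as follows. The symbols map to the code as:

- $s_e$, $n_e$, $\alpha_e$: `sign`, `syn_count`, `syn_strength` of edge $e$.
- $b_i$, $\tau_i$: `bias`, `time_const` of node $i$.
- $f$: the configured activation (ReLU by default).
- $x_i(t)$: the external input `x_t`.
- $\Delta t$: the time step `dt`.
- $I_e$: the per-edge current returned by `currents`.

```latex
\begin{aligned}
w_e &= s_e \, n_e \, \alpha_e, \\
V_i(0) &= b_i, \\
\frac{dV_i}{dt} &= \frac{1}{\max(\tau_i, \Delta t)}
  \Big( -V_i + b_i + \sum_{e:\, \mathrm{tgt}(e) = i} w_e \, f\big(V_{\mathrm{src}(e)}\big) + x_i(t) \Big), \\
I_e &= w_e \, f\big(V_{\mathrm{src}(e)}\big).
\end{aligned}
```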

Source code in flyvision/network/dynamics.py
class PPNeuronIGRSynapses(NetworkDynamics):
    """Passive point neurons with instantaneous graded release synapses."""

    def write_derived_params(
        self, params: AutoDeref[str, AutoDeref[str, RefTensor]], **kwargs
    ) -> None:
        """
        Calculate weights as the product of sign, synapse count, and strength.

        Args:
            params: A directory containing edge parameters.
            **kwargs: Additional keyword arguments.
        """
        params.edges.weight = (
            params.edges.sign * params.edges.syn_count * params.edges.syn_strength
        )

    def write_initial_state(
        self,
        state: AutoDeref[str, AutoDeref[str, RefTensor]],
        params: AutoDeref[str, AutoDeref[str, RefTensor]],
        **kwargs,
    ) -> None:
        """
        Set the initial state to the bias.

        Args:
            state: A directory to write the initial state.
            params: A directory containing node parameters.
            **kwargs: Additional keyword arguments.
        """
        state.nodes.activity = params.nodes.bias

    def write_state_velocity(
        self,
        vel: AutoDeref[str, AutoDeref[str, RefTensor]],
        state: AutoDeref[str, AutoDeref[str, RefTensor]],
        params: AutoDeref[str, AutoDeref[str, RefTensor]],
        target_sum: Callable,
        x_t: torch.Tensor,
        dt: float,
        **kwargs,
    ) -> None:
        """
        Calculate velocity as bias plus sum of weighted rectified inputs.

        Args:
            vel: A directory to write the calculated velocity.
            state: A directory containing current state values.
            params: A directory containing node and edge parameters.
            target_sum: Function to sum edge values for each target node.
            x_t: External input at time t.
            dt: Time step.
            **kwargs: Additional keyword arguments.
        """
        vel.nodes.activity = (
            1
            / torch.max(params.nodes.time_const, torch.tensor(dt).float())
            * (
                -state.nodes.activity
                + params.nodes.bias
                + target_sum(
                    params.edges.weight * self.activation(state.sources.activity)
                )  # internal chemical current
                + x_t
            )
        )

    def currents(
        self,
        state: AutoDeref[str, AutoDeref[str, RefTensor]],
        params: AutoDeref[str, AutoDeref[str, RefTensor]],
    ) -> torch.Tensor:
        """
        Calculate the internal chemical current.

        Args:
            state: A directory containing current state values.
            params: A directory containing edge parameters.

        Returns:
            torch.Tensor: The calculated internal chemical current.
        """
        return params.edges.weight * self.activation(state.sources.activity)

write_derived_params

write_derived_params(params, **kwargs)

Calculate weights as the product of sign, synapse count, and strength.

Parameters:

Name Type Description Default
params AutoDeref[str, AutoDeref[str, RefTensor]]

A directory containing edge parameters.

required
**kwargs

Additional keyword arguments.

{}
Source code in flyvision/network/dynamics.py
def write_derived_params(
    self, params: AutoDeref[str, AutoDeref[str, RefTensor]], **kwargs
) -> None:
    """
    Calculate weights as the product of sign, synapse count, and strength.

    Args:
        params: A directory containing edge parameters.
        **kwargs: Additional keyword arguments.
    """
    params.edges.weight = (
        params.edges.sign * params.edges.syn_count * params.edges.syn_strength
    )
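
A minimal plain-Python sketch of the weight formula. It assumes that sign, synapse count and strength have already been expanded to one value per edge. edge_weights is a hypothetical helper and not part of flyvision, which operates on tensors:

```python
def edge_weights(sign, syn_count, syn_strength):
    """Return sign * syn_count * syn_strength for every edge."""
    if not len(sign) == len(syn_count) == len(syn_strength):
        raise ValueError("per-edge inputs must have the same length")
    return [s * n * a for s, n, a in zip(sign, syn_count, syn_strength)]


# Two edges: an excitatory one with 4 synapses, an inhibitory one with 2 synapses.
weights = edge_weights(sign=[1.0, -1.0], syn_count=[4.0, 2.0], syn_strength=[0.25, 0.5])
print(weights)  # [1.0, -1.0]
```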

write_initial_state

write_initial_state(state, params, **kwargs)

Set the initial state to the bias.

Parameters:

Name Type Description Default
state AutoDeref[str, AutoDeref[str, RefTensor]]

A directory to write the initial state.

required
params AutoDeref[str, AutoDeref[str, RefTensor]]

A directory containing node parameters.

required
**kwargs

Additional keyword arguments.

{}
Source code in flyvision/network/dynamics.py
def write_initial_state(
    self,
    state: AutoDeref[str, AutoDeref[str, RefTensor]],
    params: AutoDeref[str, AutoDeref[str, RefTensor]],
    **kwargs,
) -> None:
    """
    Set the initial state to the bias.

    Args:
        state: A directory to write the initial state.
        params: A directory containing node parameters.
        **kwargs: Additional keyword arguments.
    """
    state.nodes.activity = params.nodes.bias

write_state_velocity

write_state_velocity(
    vel, state, params, target_sum, x_t, dt, **kwargs
)

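Calculate the velocity of the node activity. It is the sum of three terms: the leak toward the bias, the weighted presynaptic inputs passed through the activation function (a ReLU in the default configuration), and the external input x_t. This sum is divided by the time constant, which is bounded below by dt.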

Parameters:

Name Type Description Default
vel AutoDeref[str, AutoDeref[str, RefTensor]]

A directory to write the calculated velocity.

required
state AutoDeref[str, AutoDeref[str, RefTensor]]

A directory containing current state values.

required
params AutoDeref[str, AutoDeref[str, RefTensor]]

A directory containing node and edge parameters.

required
target_sum Callable

Function to sum edge values for each target node.

required
x_t Tensor

External input at time t.

required
dt float

Time step.

required
**kwargs

Additional keyword arguments.

{}
Source code in flyvision/network/dynamics.py
def write_state_velocity(
    self,
    vel: AutoDeref[str, AutoDeref[str, RefTensor]],
    state: AutoDeref[str, AutoDeref[str, RefTensor]],
    params: AutoDeref[str, AutoDeref[str, RefTensor]],
    target_sum: Callable,
    x_t: torch.Tensor,
    dt: float,
    **kwargs,
) -> None:
    """
    Calculate velocity as bias plus sum of weighted rectified inputs.

    Args:
        vel: A directory to write the calculated velocity.
        state: A directory containing current state values.
        params: A directory containing node and edge parameters.
        target_sum: Function to sum edge values for each target node.
        x_t: External input at time t.
        dt: Time step.
        **kwargs: Additional keyword arguments.
    """
    vel.nodes.activity = (
        1
        / torch.max(params.nodes.time_const, torch.tensor(dt).float())
        * (
            -state.nodes.activity
            + params.nodes.bias
            + target_sum(
                params.edges.weight * self.activation(state.sources.activity)
            )  # internal chemical current
            + x_t
        )
    )
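
The update can be sketched in plain Python as one explicit Euler step. This illustrates the equation under several assumptions and is not flyvision's integrator. The activation is taken to be a ReLU, which is the default in the configuration at the top of this page. Edges are given as hypothetical source_idx/target_idx lists, target_sum is hand-rolled as a scatter-add, and the update activity + dt * velocity is assumed:

```python
def relu(x):
    return max(0.0, x)


def target_sum(edge_values, target_idx, n_nodes):
    """Scatter-add per-edge values onto their postsynaptic (target) nodes."""
    out = [0.0] * n_nodes
    for value, target in zip(edge_values, target_idx):
        out[target] += value
    return out


def state_velocity(activity, bias, time_const, weight, source_idx, target_idx, x_t, dt):
    """(-V + bias + sum_j w_ij * relu(V_j) + x_t) / max(tau, dt), per node."""
    # Per-edge chemical currents, analogous to `currents`.
    currents = [w * relu(activity[s]) for w, s in zip(weight, source_idx)]
    synaptic_input = target_sum(currents, target_idx, len(activity))
    return [
        (-v + b + syn + x) / max(tau, dt)
        for v, b, syn, x, tau in zip(activity, bias, synaptic_input, x_t, time_const)
    ]


def euler_step(activity, velocity, dt):
    """Assumed explicit Euler update: V <- V + dt * dV/dt."""
    return [v + dt * dv for v, dv in zip(activity, velocity)]


# Two nodes and one excitatory edge 0 -> 1 with weight 2.0.
bias = [0.5, 0.0]
tau = [0.05, 0.05]
dt = 0.01
activity = list(bias)  # as in write_initial_state: start at the bias
vel = state_velocity(activity, bias, tau, weight=[2.0], source_idx=[0],
                     target_idx=[1], x_t=[0.0, 0.0], dt=dt)
activity = euler_step(activity, vel, dt)
# Node 0 sits at its bias (zero velocity); node 1 is driven by node 0.
```

Because the velocity is divided by max(tau, dt), a time constant smaller than the step size behaves like tau = dt. This keeps a single Euler step from overshooting the fixed point.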

currents

currents(state, params)

Calculate the internal chemical current.

Parameters:

Name Type Description Default
state AutoDeref[str, AutoDeref[str, RefTensor]]

A directory containing current state values.

required
params AutoDeref[str, AutoDeref[str, RefTensor]]

A directory containing edge parameters.

required

Returns:

Type Description
Tensor

torch.Tensor: The calculated internal chemical current.

Source code in flyvision/network/dynamics.py
def currents(
    self,
    state: AutoDeref[str, AutoDeref[str, RefTensor]],
    params: AutoDeref[str, AutoDeref[str, RefTensor]],
) -> torch.Tensor:
    """
    Calculate the internal chemical current.

    Args:
        state: A directory containing current state values.
        params: A directory containing edge parameters.

    Returns:
        torch.Tensor: The calculated internal chemical current.
    """
    return params.edges.weight * self.activation(state.sources.activity)

Initialization

flyvision.network.initialization

The parameters that networks can be initialized with. Each parameter is its own type because different parameters are shared differently across nodes and edges. These types handle the initialization of the indices used to perform gather and scatter operations. Parameter types can be initialized from a range of initial distribution types.

InitialDistribution

Initial distribution base class.

Attributes:

Name Type Description
raw_values Tensor

Initial parameters must store raw_values as an attribute in their __init__.

readers Dict[str, Tensor]

Readers will be written by the network during initialization.

Note

To add a new initial distribution type, subclass this class and implement the __init__ method. The __init__ method should take the param_config as its first argument and should store the attribute raw_values as a torch.nn.Parameter.

Example

An example of a viable param_config is:

param_config = Namespace(
    requires_grad=True,
    initial_dist="Normal",
    mean=0,
    std=1,
    mode="sample",
)

Source code in flyvision/network/initialization.py
class InitialDistribution:
    """Initial distribution base class.

    Attributes:
        raw_values (Tensor): Initial parameters must store raw_values as attribute in
            their __init__.
        readers (Dict[str, Tensor]): Readers will be written by the network during
            initialization.

    Note:
        To add a new initial distribution type, subclass this class and implement the
        __init__ method. The __init__ method should take the param_config as its first
        argument, and should store the attribute raw_values as a torch.nn.Parameter.

    Example:
        An example of a viable param_config is:
        ```python
        param_config = Namespace(
            requires_grad=True,
            initial_dist="Normal",
            mean=0,
            std=1,
            mode="sample",
        )
        ```
    """

    raw_values: Tensor
    readers: Dict[str, Tensor]

    @property
    def semantic_values(self):
        """Optional reparametrization of raw values invoked for computation."""
        return self.raw_values

    def __repr__(self):
        return f"{self.__class__.__name__} (semantic values): \n{self.semantic_values}"

    def __len__(self):
        return len(self.raw_values)

    def clamp(self, values, mode):
        """To clamp the raw_values of the parameters at initialization.

        Note, mild clash with raw_values/semantic_values reparametrization.
        Parameters that use reparametrization in terms of semantic_values
        should not use clamp.
        """
        if mode == "non_negative":
            values.clamp_(min=0)
        elif isinstance(mode, Iterable) and len(mode) == 2:
            values.clamp_(*mode)
        elif mode in [False, None]:
            return values
        else:
            raise ParameterConfigError(f"{mode} not a valid argument for clamp")
        return values
semantic_values property
semantic_values

Optional reparametrization of raw values invoked for computation.

clamp
clamp(values, mode)

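Clamp the values of a parameter in place at initialization, before they are stored as raw_values. The supported modes are:

- "non_negative": a minimum of 0.
- A (min, max) pair.
- False or None: no clamping.

Any other mode raises ParameterConfigError.

Note: this clashes mildly with the raw_values/semantic_values reparametrization because clamping acts on the raw values. Parameters that are reparametrized in terms of semantic_values should not use clamp.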

Source code in flyvision/network/initialization.py
def clamp(self, values, mode):
    """To clamp the raw_values of the parameters at initialization.

    Note, mild clash with raw_values/semantic_values reparametrization.
    Parameters that use reparametrization in terms of semantic_values
    should not use clamp.
    """
    if mode == "non_negative":
        values.clamp_(min=0)
    elif isinstance(mode, Iterable) and len(mode) == 2:
        values.clamp_(*mode)
    elif mode in [False, None]:
        return values
    else:
        raise ParameterConfigError(f"{mode} not a valid argument for clamp")
    return values
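
The accepted clamp modes can be mirrored in plain Python. clamp_values is a hypothetical helper that works on lists instead of tensors, and it differs from the original in three ways. It returns a new list rather than clamping in place. It rejects strings as (min, max) pairs. It raises ValueError where the library raises ParameterConfigError:

```python
from collections.abc import Iterable


def clamp_values(values, mode):
    """Return a clamped copy of `values` for the supported clamp modes."""
    if mode == "non_negative":
        return [max(v, 0.0) for v in values]
    if isinstance(mode, Iterable) and not isinstance(mode, str) and len(mode) == 2:
        low, high = mode
        return [min(max(v, low), high) for v in values]
    if mode in (False, None):
        return list(values)  # no clamping
    raise ValueError(f"{mode} not a valid argument for clamp")


print(clamp_values([-0.2, 0.3], "non_negative"))  # [0.0, 0.3]
print(clamp_values([-1.0, 0.5, 3.0], (0.0, 1.0)))  # [0.0, 0.5, 1.0]
```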

Value

Bases: InitialDistribution

Initializes parameters with a single value.

Parameters:

Name Type Description Default
value

The value to initialize the parameter with.

required
requires_grad bool

Whether the parameter requires gradients.

required
clamp str | Iterable | bool

Clamp mode passed to clamp: "non_negative", a (min, max) pair, or False/None for no clamping. Defaults to False.

False
**kwargs

Additional keyword arguments.

{}
Example
param_config = Namespace(
    requires_grad=True,
    initial_dist="Value",
    value=0,
)
Source code in flyvision/network/initialization.py
class Value(InitialDistribution):
    """Initializes parameters with a single value.

    Args:
        value: The value to initialize the parameter with.
        requires_grad (bool): Whether the parameter requires gradients.
        clamp (bool, optional): Whether to clamp the values. Defaults to False.
        **kwargs: Additional keyword arguments.

    Example:
        ```python
        param_config = Namespace(
            requires_grad=True,
            initial_dist="Value",
            value=0,
        )
        ```
    """

    def __init__(self, value, requires_grad, clamp=False, **kwargs) -> None:
        _values = torch.tensor(value).float()
        _values = self.clamp(_values, clamp)
        self.raw_values = nn.Parameter(_values, requires_grad=requires_grad)

Normal

Bases: InitialDistribution

Initializes parameters independently from normal distributions.

Parameters:

Name Type Description Default
mean

The mean of the normal distribution.

required
std

The standard deviation of the normal distribution.

required
requires_grad bool

Whether the parameter requires gradients.

required
mode str

The initialization mode. "mean" sets the values to the mean, and "sample" draws them from the normal distribution. Defaults to "sample".

'sample'
clamp str | Iterable | bool

Clamp mode passed to clamp: "non_negative", a (min, max) pair, or False/None for no clamping. Defaults to False.

False
seed int

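Random seed for reproducibility. Sampling uses its own torch.Generator, so the global RNG is not seeded. With None, the generator is seeded nondeterministically. Defaults to None.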

None
**kwargs

Additional keyword arguments.

{}
Example
param_config = Namespace(
    requires_grad=True,
    initial_dist="Normal",
    mean=0,
    std=1,
    mode="sample",
)
Source code in flyvision/network/initialization.py
class Normal(InitialDistribution):
    """Initializes parameters independently from normal distributions.

    Args:
        mean: The mean of the normal distribution.
        std: The standard deviation of the normal distribution.
        requires_grad (bool): Whether the parameter requires gradients.
        mode (str, optional): The initialization mode. Defaults to "sample".
        clamp (bool, optional): Whether to clamp the values. Defaults to False.
        seed (int, optional): Random seed for reproducibility. Defaults to None.
        **kwargs: Additional keyword arguments.

    Example:
        ```python
        param_config = Namespace(
            requires_grad=True,
            initial_dist="Normal",
            mean=0,
            std=1,
            mode="sample",
        )
        ```
    """

    def __init__(
        self, mean, std, requires_grad, mode="sample", clamp=False, seed=None, **kwargs
    ) -> None:
        if mode == "mean":
            _values = torch.tensor(mean).float()
        elif mode == "sample":
            # set seed for reproducibility and avoid seeding the global RNG
            generator = torch.Generator(device=device)
            if seed is not None:
                generator.manual_seed(seed)
            else:
                generator.seed()
            try:
                _values = torch.normal(
                    torch.tensor(mean).float(),
                    torch.tensor(std).float(),
                    generator=generator,
                )
            except RuntimeError as e:
                raise RuntimeError(
                    f"Failed to sample from normal with mean {mean} and std {std}"
                ) from e
        else:
            raise ValueError("Mode must be either mean or sample.")
        _values = self.clamp(_values, clamp)
        self.raw_values = nn.Parameter(_values, requires_grad=requires_grad)
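
The two modes, and the use of a private generator instead of the global RNG, can be sketched with Python's random module. init_normal is a hypothetical helper, and random.Random stands in for the torch.Generator used above. The mean and std arguments are assumed to be lists with one entry per parameter group, as RestingPotential prepares them:

```python
import random


def init_normal(mean, std, mode="sample", seed=None):
    """One initial value per (mean, std) pair, following the two modes above."""
    if mode == "mean":
        return [float(m) for m in mean]
    if mode == "sample":
        # A private generator: seeding it does not touch the global `random` state.
        # With seed=None, random.Random seeds itself from system entropy or time.
        rng = random.Random(seed)
        return [rng.gauss(m, s) for m, s in zip(mean, std)]
    raise ValueError("Mode must be either mean or sample.")


# Hypothetical per-type biases: a scalar mean/std repeated once per cell type.
means, stds = [0.5] * 3, [0.05] * 3
a = init_normal(means, stds, seed=0)
b = init_normal(means, stds, seed=0)
print(a == b)  # True: same seed, same draws
print(init_normal(means, stds, mode="mean"))  # [0.5, 0.5, 0.5]
```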

Lognormal

Bases: Normal

Initializes parameters independently from lognormal distributions.

Note

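The lognormal distribution reparametrizes a normal through semantic values. raw_values are initialized as in Normal (in log space), and semantic_values returns their exponential. In "sample" mode, the values used for computation are therefore lognormally distributed.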

Example
param_config = Namespace(
    requires_grad=True,
    initial_dist="Lognormal",
    mean=0,
    std=1,
    mode="sample",
)
Source code in flyvision/network/initialization.py
class Lognormal(Normal):
    """Initializes parameters independently from lognormal distributions.

    Note:
        The lognormal distribution reparametrizes a normal through semantic values.

    Example:
        ```python
        param_config = Namespace(
            requires_grad=True,
            initial_dist="Lognormal",
            mean=0,
            std=1,
            mode="sample",
        )
        ```
    """

    @property
    def semantic_values(self):
        """n_syn ~ self._values.exp()."""
        return self.raw_values.exp()
semantic_values property
semantic_values

n_syn ~ self.raw_values.exp(), i.e., the semantic values are the exponential of the raw values.
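
The round trip between raw and semantic values can be illustrated with the math module. The sketch assumes the raw values hold log mean synapse counts, as SynapseCount sets them up (mean=np.log(n_syn) with mode="mean"). The counts are made-up numbers:

```python
import math

# Made-up mean synapse counts for three edge groups.
mean_counts = [1.0, 4.0, 12.5]

# SynapseCount stores log(<N>) as raw values (mode="mean").
raw_values = [math.log(n) for n in mean_counts]

# Lognormal.semantic_values exponentiates them back for computation.
semantic_values = [math.exp(r) for r in raw_values]

# Any real raw value maps to a positive semantic value.
print(math.exp(-3.0) > 0)  # True
```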

Parameter

Base class for all parameters that are shared across nodes or edges.

Parameters:

Name Type Description Default
param_config Namespace

Namespace containing parameter configuration.

required
connectome Connectome

Connectome object.

required

Attributes:

Name Type Description
parameter InitialDistribution

InitialDistribution object.

indices Tensor

Indices for parameter sharing.

keys List[Any]

Keys to access individual parameter values associated with certain identifiers.

symmetry_masks List[Tensor]

Symmetry masks that can be configured optionally to apply further symmetry constraints to the parameter values.

Note

Subclasses must implement __init__(self, param_config, connectome) with the following requirements:

  1. Configure all attributes defined in the base class.
  2. Decorate __init__ with @deepcopy_config if it updates param_config to prevent mutations in the outer scope.
  3. Update param_config with key-value pairs informed by connectome and matching the desired InitialDistribution.
  4. Store parameter from InitialDistribution(param_config), which constructs and holds the nn.Parameter.
  5. Store indices for parameter sharing using get_scatter_indices(dataframe, grouped_dataframe, groupby).
  6. Store keys to access individual parameter values associated with certain identifiers.
  7. Store symmetry_masks (optional) to apply further symmetry constraints to the parameter values.

Example implementation structure:

@deepcopy_config
def __init__(self, param_config: Namespace, connectome: Connectome):
    # Update param_config based on connectome data
    # ...

    # Initialize parameter
    self.parameter = InitialDistribution(param_config)

    # Set up indices, keys, and symmetry masks
    self.indices = get_scatter_indices(...)
    self.keys = ...
    self.symmetry_masks = symmetry_masks(...)
Source code in flyvision/network/initialization.py
class Parameter:
    """Base class for all parameters to share across nodes or edges.

    Args:
        param_config (Namespace): Namespace containing parameter configuration.
        connectome (Connectome): Connectome object.

    Attributes:
        parameter (InitialDistribution): InitialDistribution object.
        indices (torch.Tensor): Indices for parameter sharing.
        keys (List[Any]): Keys to access individual parameter values associated with
            certain identifiers.
        symmetry_masks (List[torch.Tensor]): Symmetry masks that can be configured
            optionally to apply further symmetry constraints to the parameter values.

    Note:
        Subclasses must implement `__init__(self, param_config, connectome_dir)` with the
        following requirements:

        1. Configure all attributes defined in the base class.
        2. Decorate `__init__` with `@deepcopy_config` if it updates `param_config` to
           prevent mutations in the outer scope.
        3. Update `param_config` with key-value pairs informed by `connectome` and
           matching the desired `InitialDistribution`.
        4. Store `parameter` from `InitialDistribution(param_config)`, which constructs
           and holds the `nn.Parameter`.
        5. Store `indices` for parameter sharing using
           `get_scatter_indices(dataframe, grouped_dataframe, groupby)`.
        6. Store `keys` to access individual parameter values associated with certain
           identifiers.
        7. Store `symmetry_masks` (optional) to apply further symmetry constraints to
           the parameter values.

        Example implementation structure:

        ```python
        @deepcopy_config
        def __init__(self, param_config: Namespace, connectome: Connectome):
            # Update param_config based on connectome data
            # ...

            # Initialize parameter
            self.parameter = InitialDistribution(param_config)

            # Set up indices, keys, and symmetry masks
            self.indices = get_scatter_indices(...)
            self.keys = ...
            self.symmetry_masks = symmetry_masks(...)
        ```
    """

    parameter: InitialDistribution
    indices: torch.Tensor
    symmetry_masks: List[torch.Tensor]
    keys: List[Any]

    @deepcopy_config
    def __init__(self, param_config: Namespace, connectome: ConnectomeFromAvgFilters):
        pass

    def __repr__(self):
        """Return a string representation of the Parameter object."""
        init_arg_names = list(self.__init__.__annotations__.keys())
        dir_type = self.__init__.__annotations__[init_arg_names[1]].__name__
        return f"{self.__class__.__name__}({self.config}, {dir_type})"

    def __getitem__(self, key):
        """Get parameter value for a given key."""
        if key in self.keys:
            if self.parameter.raw_values.dim() == 0:
                return self.parameter.raw_values
            return self.parameter.raw_values[self.keys.index(key)]
        else:
            raise ValueError(key)

    def __len__(self):
        """Return the length of raw_values."""
        return len(self.raw_values)

    @property
    def raw_values(self) -> torch.Tensor:
        """Get raw parameter values."""
        return self.parameter.raw_values

    @property
    def semantic_values(self) -> torch.Tensor:
        """Get semantic parameter values."""
        return self.parameter.semantic_values

    @property
    def readers(self) -> Dict[str, torch.Tensor]:
        """Get parameter readers."""
        return self.parameter.readers

    @readers.setter
    def readers(self, value) -> None:
        """Set parameter readers."""
        self.parameter.readers = value

    def _symmetry(self):
        """Return symmetry constraints from symmetry masks for debugging."""
        keys = np.array(self.keys)
        return [keys[mask.cpu()] for mask in self.symmetry_masks]
raw_values property
raw_values

Get raw parameter values.

semantic_values property
semantic_values

Get semantic parameter values.

readers property writable
readers

Get parameter readers.

__repr__
__repr__()

Return a string representation of the Parameter object.

Source code in flyvision/network/initialization.py
def __repr__(self):
    """Return a string representation of the Parameter object."""
    init_arg_names = list(self.__init__.__annotations__.keys())
    dir_type = self.__init__.__annotations__[init_arg_names[1]].__name__
    return f"{self.__class__.__name__}({self.config}, {dir_type})"
__getitem__
__getitem__(key)

Get parameter value for a given key.

Source code in flyvision/network/initialization.py
def __getitem__(self, key):
    """Get parameter value for a given key."""
    if key in self.keys:
        if self.parameter.raw_values.dim() == 0:
            return self.parameter.raw_values
        return self.parameter.raw_values[self.keys.index(key)]
    else:
        raise ValueError(key)
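
Key lookup can be sketched with a hypothetical ToyParameter that holds only keys and raw values, using plain floats and lists instead of tensors. It mirrors the two branches above: a scalar raw value is shared by every key, and otherwise the value at the key's position is returned. Node parameters use cell-type keys, while edge parameters such as SynapseSign use (source_type, target_type) tuples. The cell-type names and values are illustrative only:

```python
class ToyParameter:
    """Hypothetical stand-in holding only `keys` and `raw_values`."""

    def __init__(self, keys, raw_values):
        self.keys = list(keys)
        self.raw_values = raw_values  # a float (shared scalar) or a list

    def __getitem__(self, key):
        if key not in self.keys:
            raise ValueError(key)
        if not isinstance(self.raw_values, list):
            return self.raw_values  # scalar: shared by every key
        return self.raw_values[self.keys.index(key)]


bias = ToyParameter(["R1", "L1", "Mi1"], [0.5, 0.4, 0.3])
time_const = ToyParameter(["R1", "L1", "Mi1"], 0.05)
sign = ToyParameter([("R1", "L1"), ("L1", "Mi1")], [-1.0, 1.0])

print(bias["L1"], time_const["Mi1"], sign[("R1", "L1")])  # 0.4 0.05 -1.0
```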
__len__
__len__()

Return the length of raw_values.

Source code in flyvision/network/initialization.py
def __len__(self):
    """Return the length of raw_values."""
    return len(self.raw_values)

RestingPotential

Bases: Parameter

Initialize resting potentials a.k.a. biases for cell types.

Source code in flyvision/network/initialization.py
class RestingPotential(Parameter):
    """Initialize resting potentials a.k.a. biases for cell types."""

    @deepcopy_config
    def __init__(self, param_config: Namespace, connectome: ConnectomeFromAvgFilters):
        nodes_dir = connectome.nodes

        nodes = pd.DataFrame({
            k: byte_to_str(nodes_dir[k][:]) for k in param_config.groupby
        })
        grouped_nodes = nodes.groupby(
            param_config.groupby, as_index=False, sort=False
        ).first()

        param_config["type"] = grouped_nodes["type"].values
        param_config["mean"] = np.repeat(param_config["mean"], len(grouped_nodes))
        param_config["std"] = np.repeat(param_config["std"], len(grouped_nodes))

        self.parameter = forward_subclass(
            InitialDistribution, param_config, subclass_key="initial_dist"
        )
        self.indices = get_scatter_indices(nodes, grouped_nodes, param_config.groupby)
        self.keys = param_config["type"].tolist()
        self.symmetry_masks = symmetry_masks(param_config.get("symmetric", []), self.keys)

TimeConstant

Bases: Parameter

Initialize time constants for cell types.

Source code in flyvision/network/initialization.py
class TimeConstant(Parameter):
    """Initialize time constants for cell types."""

    @deepcopy_config
    def __init__(self, param_config: Namespace, connectome: ConnectomeFromAvgFilters):
        nodes_dir = connectome.nodes

        nodes = pd.DataFrame({
            k: byte_to_str(nodes_dir[k][:]) for k in param_config.groupby
        })
        grouped_nodes = nodes.groupby(
            param_config.groupby, as_index=False, sort=False
        ).first()

        param_config["type"] = grouped_nodes["type"].values
        param_config["value"] = np.repeat(param_config["value"], len(grouped_nodes))

        self.indices = get_scatter_indices(nodes, grouped_nodes, param_config.groupby)
        self.parameter = forward_subclass(
            InitialDistribution, param_config, subclass_key="initial_dist"
        )
        self.keys = param_config["type"].tolist()
        self.symmetry_masks = symmetry_masks(param_config.get("symmetric", []), self.keys)

SynapseSign

Bases: Parameter

Initialize synapse signs for edge types.

Source code in flyvision/network/initialization.py
class SynapseSign(Parameter):
    """Initialize synapse signs for edge types."""

    @deepcopy_config
    def __init__(
        self, param_config: Namespace, connectome: ConnectomeFromAvgFilters
    ) -> None:
        edges_dir = connectome.edges

        edges = pd.DataFrame({
            k: byte_to_str(edges_dir[k][:]) for k in [*param_config.groupby, "sign"]
        })
        grouped_edges = edges.groupby(
            param_config.groupby, as_index=False, sort=False
        ).first()

        param_config.source_type = grouped_edges.source_type.values
        param_config.target_type = grouped_edges.target_type.values
        param_config.value = grouped_edges.sign.values

        self.indices = get_scatter_indices(edges, grouped_edges, param_config.groupby)
        self.parameter = forward_subclass(
            InitialDistribution, param_config, subclass_key="initial_dist"
        )
        self.keys = list(
            zip(
                param_config.source_type.tolist(),
                param_config.target_type.tolist(),
            )
        )
        self.symmetry_masks = symmetry_masks(param_config.get("symmetric", []), self.keys)

SynapseCount

Bases: Parameter

Initialize synapse counts for edge types.

Source code in flyvision/network/initialization.py
class SynapseCount(Parameter):
    """Initialize synapse counts for edge types."""

    @deepcopy_config
    def __init__(
        self, param_config: Namespace, connectome: ConnectomeFromAvgFilters
    ) -> None:
        mode = param_config.get("mode", "")
        if mode != "mean":
            raise NotImplementedError(
                f"SynapseCount does not implement {mode}. Implement "
                "a custom Parameter subclass."
            )

        edges_dir = connectome.edges

        edges = pd.DataFrame({
            k: byte_to_str(edges_dir[k][:]) for k in [*param_config.groupby, "n_syn"]
        })
        grouped_edges = edges.groupby(
            param_config.groupby, as_index=False, sort=False
        ).mean()

        param_config.source_type = grouped_edges.source_type.values
        param_config.target_type = grouped_edges.target_type.values
        param_config.du = grouped_edges.du.values
        param_config.dv = grouped_edges.dv.values

        param_config.mode = "mean"
        param_config.mean = np.log(grouped_edges.n_syn.values)

        self.indices = get_scatter_indices(edges, grouped_edges, param_config.groupby)
        self.parameter = forward_subclass(
            InitialDistribution, param_config, subclass_key="initial_dist"
        )
        self.keys = list(
            zip(
                param_config.source_type.tolist(),
                param_config.target_type.tolist(),
                param_config.du.tolist(),
                param_config.dv.tolist(),
            )
        )
        self.symmetry_masks = symmetry_masks(param_config.get("symmetric", []), self.keys)

SynapseCountScaling

Bases: Parameter

Initialize synapse count scaling for edge types.

This class initializes synapse strengths from the average synapse count of each edge type.

The initialization follows this equation:

\[\alpha_{t_it_j} =\frac{\rho}{\langle N \rangle_{t_it_j}}\]

where:

  1. \(\alpha_{t_it_j}\) is the synapse strength between neurons \(i\) and \(j\), shared by all edges between cell types \(t_i\) and \(t_j\).
  2. \(\langle N \rangle_{t_it_j}\) is the average synapse count for the edge type across columnar offsets \(u_i-u_j\) and \(v_i-v_j\).
  3. \(\rho\) is a scaling factor (default: 0.01).
Source code in flyvision/network/initialization.py
class SynapseCountScaling(Parameter):
    """Initialize synapse count scaling for edge types.

    This class initializes synapse strengths based on the average synapse count for each
    edge type, scaling them differently for chemical and electrical synapses.

    The initialization follows this equation:

    $$\\alpha_{t_it_j} =\\frac{\\rho}{\\langle N \\rangle_{t_it_j}}$$

    where:

    1. $\\alpha_{t_it_j}$ is the synapse strength between neurons $i$ and $j$.
    2. $\\langle N \\rangle_{t_it_j}$ is the average synapse count for the edge type
        across columnar offsets $u_i-u_j$ and $v_i-v_j$
    3. $\\rho$ is a scaling factor (default: 0.01)

    """

    @deepcopy_config
    def __init__(
        self, param_config: Namespace, connectome: ConnectomeFromAvgFilters
    ) -> None:
        edges_dir = connectome.edges

        edges = pd.DataFrame({
            k: byte_to_str(edges_dir[k][:]) for k in [*param_config.groupby, "n_syn"]
        })
        grouped_edges = edges.groupby(
            param_config.groupby, as_index=False, sort=False
        ).mean()

        # to initialize synapse strengths with scale/<N>_rf
        syn_strength = param_config.get("scale", 0.01) / grouped_edges.n_syn.values

        param_config.target_type = grouped_edges.target_type.values
        param_config.source_type = grouped_edges.source_type.values
        param_config.value = syn_strength

        self.indices = get_scatter_indices(edges, grouped_edges, param_config.groupby)
        self.parameter = forward_subclass(
            InitialDistribution, param_config, subclass_key="initial_dist"
        )
        self.keys = list(
            zip(
                param_config.source_type.tolist(),
                param_config.target_type.tolist(),
            )
        )
        self.symmetry_masks = symmetry_masks(param_config.get("symmetric", []), self.keys)
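
The initialization rule can be worked through in plain Python. syn_strength_init is a hypothetical helper that groups made-up edges by (source_type, target_type), which is the groupby used for syn_strength in the default configuration, and then applies alpha = rho / <N>. Combined with the weight formula of PPNeuronIGRSynapses (sign * syn_count * syn_strength), the mean weight magnitude of each edge type equals rho for this toy data:

```python
from collections import defaultdict


def syn_strength_init(edges, scale=0.01):
    """alpha per (source_type, target_type) = scale / mean synapse count."""
    totals, counts = defaultdict(float), defaultdict(int)
    for source_type, target_type, n_syn in edges:
        totals[(source_type, target_type)] += n_syn
        counts[(source_type, target_type)] += 1
    return {key: scale / (totals[key] / counts[key]) for key in totals}


# Made-up edges (source_type, target_type, n_syn) at different columnar offsets.
edges = [("L1", "Mi1", 2.0), ("L1", "Mi1", 6.0), ("R1", "L1", 40.0)]
alpha = syn_strength_init(edges)  # L1->Mi1: 0.01 / 4.0, R1->L1: 0.01 / 40.0

# Weight magnitude n_syn * alpha per edge; its mean over one edge type is the scale.
l1_mi1 = [n * alpha[(s, t)] for s, t, n in edges if (s, t) == ("L1", "Mi1")]
print(sum(l1_mi1) / len(l1_mi1))  # approximately 0.01
```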

deepcopy_config

deepcopy_config(f)

Decorator to deepcopy the parameter configuration.

Note

This decorator is necessary because the __init__ method of parameter classes often modifies the param_config object. By creating a deep copy, we ensure that these modifications don’t affect the original param_config object in the outer scope. This prevents unintended side effects and maintains the integrity of the original configuration.

Source code in flyvision/network/initialization.py
def deepcopy_config(f):
    """Decorator to deepcopy the parameter configuration.

    Note:
        This decorator is necessary because the `__init__` method of parameter classes
        often modifies the `param_config` object. By creating a deep copy, we ensure
        that these modifications don't affect the original `param_config` object in the
        outer scope. This prevents unintended side effects and maintains the integrity
        of the original configuration.
    """

    @functools.wraps(f)
    def wrapper(cls, param_config, connectome):
        cls.config = deepcopy(param_config)
        return f(cls, deepcopy(param_config), connectome)

    return wrapper
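
The effect of the decorator can be checked with a toy parameter class. This sketch re-implements the same wrapping logic so that it runs on its own; it assumes a plain `dict` in place of the `Namespace` config, and `ToyParameter` is a hypothetical class, not part of flyvision.

```python
import copy
import functools


def deepcopy_config(f):
    """Same idea as above: hand __init__ a private copy of the config."""

    @functools.wraps(f)
    def wrapper(self, param_config, connectome):
        # Keep an untouched snapshot on the instance ...
        self.config = copy.deepcopy(param_config)
        # ... and let __init__ mutate a separate copy.
        return f(self, copy.deepcopy(param_config), connectome)

    return wrapper


class ToyParameter:
    """Hypothetical parameter class whose __init__ mutates its config."""

    @deepcopy_config
    def __init__(self, param_config, connectome):
        # Mutations like SynapseCountScaling setting .value or .source_type.
        param_config["value"] = [0.0004, 0.002]
        param_config["groupby"].append("du")
        self.value = param_config["value"]


config = {"initial_dist": "Value", "groupby": ["source_type", "target_type"]}
param = ToyParameter(config, connectome=None)
print(config)        # unchanged: no "value" key, groupby still has two entries
print(param.config)  # snapshot taken before __init__ ran
```

A shallow copy would not be enough here: appending to the nested `groupby` list would still leak into the caller's config.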

get_scatter_indices

get_scatter_indices(dataframe, grouped_dataframe, groupby)

Get indices for scattering operations to share parameters.

Maps each node/edge from the complete computational graph to a parameter index.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataframe | DataFrame | Dataframe of nodes or edges of the graph. | required |
| grouped_dataframe | DataFrame | Aggregated version of the same dataframe. | required |
| groupby | List[str] | The same columns from which the grouped_dataframe was constructed. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Tensor of indices for scattering operations. |

Note

For N elements that are grouped into M groups, this function returns N indices from 0 to M-1 that can be used to scatter the parameters of the M groups to the N elements.

Example
elements = ["A", "A", "A", "B", "B", "C", "D", "D", "E"]
groups = ["A", "B", "C", "D", "E"]
parameter = [1, 2, 3, 4, 5]
# get_scatter_indices would return
scatter_indices = [0, 0, 0, 1, 1, 2, 3, 3, 4]
scattered_parameters = [parameter[idx] for idx in scatter_indices]
scattered_parameters == [1, 1, 1, 2, 2, 3, 4, 4, 5]
Source code in flyvision/network/initialization.py
def get_scatter_indices(
    dataframe: pd.DataFrame, grouped_dataframe: pd.DataFrame, groupby: List[str]
) -> Tensor:
    """Get indices for scattering operations to share parameters.

    Maps each node/edge from the complete computational graph to a parameter index.

    Args:
        dataframe: Dataframe of nodes or edges of the graph.
        grouped_dataframe: Aggregated version of the same dataframe.
        groupby: The same columns from which the grouped_dataframe was constructed.

    Returns:
        Tensor of indices for scattering operations.

    Note:
        For N elements that are grouped into M groups, this function returns N indices
        from 0 to M-1 that can be used to scatter the parameters of the M groups to the
        N elements.

    Example:
        ```python
        elements = ["A", "A", "A", "B", "B", "C", "D", "D", "E"]
        groups = ["A", "B", "C", "D", "E"]
        parameter = [1, 2, 3, 4, 5]
        # get_scatter_indices would return
        scatter_indices = [0, 0, 0, 1, 1, 2, 3, 3, 4]
        scattered_parameters = [parameter[idx] for idx in scatter_indices]
        scattered_parameters == [1, 1, 1, 2, 2, 3, 4, 4, 5]
        ```
    """
    ungrouped_elements = zip(*[dataframe[k][:] for k in groupby])
    grouped_elements = zip(*[grouped_dataframe[k][:] for k in groupby])
    to_index = {k: i for i, k in enumerate(grouped_elements)}
    return torch.tensor([to_index[k] for k in ungrouped_elements])
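
The same mapping can be reproduced without pandas or torch. This is a minimal sketch that assumes group keys are ordered by first appearance, as `groupby(..., sort=False)` produces them in `SynapseCountScaling`; `unique_in_order` and `scatter_indices` are hypothetical helpers.

```python
from typing import Hashable, List, Sequence


def unique_in_order(keys: Sequence[Hashable]) -> List[Hashable]:
    """Group keys in order of first appearance (dicts keep insertion order)."""
    return list(dict.fromkeys(keys))


def scatter_indices(keys: Sequence[Hashable], groups: Sequence[Hashable]) -> List[int]:
    """Map every element's key to the position of its group."""
    to_index = {key: i for i, key in enumerate(groups)}
    return [to_index[key] for key in keys]


# One (source_type, target_type) key per edge of a toy graph.
edge_keys = [("A", "B"), ("A", "B"), ("B", "C"), ("A", "B"), ("C", "A")]
groups = unique_in_order(edge_keys)            # [("A", "B"), ("B", "C"), ("C", "A")]
indices = scatter_indices(edge_keys, groups)   # [0, 0, 1, 0, 2]

shared = [0.1, 0.2, 0.3]                       # one parameter value per group
per_edge = [shared[i] for i in indices]        # [0.1, 0.1, 0.2, 0.1, 0.3]
print(indices, per_edge)
```

With torch tensors, the gather is simply `shared[indices]`; during backpropagation the gradients of all elements in a group are summed into their single shared entry, which is what makes the parameter shared.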

symmetry_masks

symmetry_masks(symmetric, keys, as_mask=False)

Create masks for subsets of parameters for joint constraints.

Parameters:

Name Type Description Default
symmetric List[Any]

Contains subsets of keys that point to the subsets of parameters to be indexed.

required
keys List[Any]

List of keys that point to individual parameter values.

required
as_mask bool

If True, returns a boolean mask, otherwise integer indices.

False

Returns:

Type Description
List[Tensor]

List of masks (List[torch.BoolTensor]).

Note

This is experimental for configuration-based fine-grained shared parameter optimization, e.g. for models including multi-compartment cells or gap junctions.

Example
# For node type parameters with individual node types as keys:
symmetric = [["T4a", "T4b", "T4c", "T4d"], ["T5a", "T5b", "T5c", "T5d"]]
# This would constrain the parameter values of all T4 subtypes to their joint
# mean and the parameter values of all T5 subtypes to their joint mean.

# For edge type parameters with individual edge types as keys:
symmetric = [[("CT1(M10)", "CT1(Lo1)"), ("CT1(Lo1)", "CT1(M10)")]]
# This would constrain the edge parameter of the directed edge from CT1(M10) to
# CT1(Lo1) and the directed edge from CT1(Lo1) to CT1(M10) to their joint mean.
Source code in flyvision/network/initialization.py
def symmetry_masks(
    symmetric: List[Any], keys: List[Any], as_mask: bool = False
) -> List[torch.Tensor]:
    """Create masks for subsets of parameters for joint constraints.

    Args:
        symmetric: Contains subsets of keys that point to the subsets of parameters
            to be indexed.
        keys: List of keys that point to individual parameter values.
        as_mask: If True, returns a boolean mask, otherwise integer indices.

    Returns:
        List with one tensor per subset in `symmetric`: a boolean mask if
        `as_mask` is True, otherwise integer indices.

    Note:
        This is experimental for configuration-based fine-grained shared parameter
        optimization, e.g. for models including multi-compartment cells or gap
        junctions.

    Example:
        ```python
        # For node type parameters with individual node types as keys:
        symmetric = [["T4a", "T4b", "T4c", "T4d"], ["T5a", "T5b", "T5c", "T5d"]]
        # This would constrain the parameter values of all T4 subtypes to their joint
        # mean and the parameter values of all T5 subtypes to their joint mean.

        # For edge type parameters with individual edge types as keys:
        symmetric = [[("CT1(M10)", "CT1(Lo1)"), ("CT1(Lo1)", "CT1(M10)")]]
        # This would constrain the edge parameter of the directed edge from CT1(M10) to
        # CT1(Lo1) and the directed edge from CT1(Lo1) to CT1(M10) to their joint mean.
        ```
    """
    if not symmetric:
        return []
    symmetry_masks = []  # type: List[torch.Tensor]
    keys = atleast_column_vector(keys)
    for identifiers in symmetric:
        identifiers = atleast_column_vector(identifiers)
        # to allow identifiers like [None, "A", None, 0]
        # for parameters that have tuples as keys: only compare the columns
        # in which none of the identifiers is None
        not_none = np.vectorize(lambda x: x is not None, otypes=[bool])(identifiers)
        columns = np.flatnonzero(not_none.all(axis=0))
        try:
            symmetry_masks.append(
                torch.tensor(
                    where_equal_rows(
                        identifiers[:, columns], keys[:, columns], as_mask=as_mask
                    )
                )
            )
        except Exception as e:
            raise ValueError(
                f"{identifiers} cannot be a symmetry constraint "
                f"for parameter with keys {keys}: {e}"
            ) from e
    return symmetry_masks
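
The docstring's Example describes constraining each subset of parameters to its joint mean. How flyvision consumes the masks is not shown in this section, so the following is only a sketch of that description: it builds index lists for exact key matches (ignoring the `None` wildcard support of `symmetry_masks`) and averages each subset. `index_subsets` and `apply_joint_mean` are hypothetical names, and the parameter values are toy numbers.

```python
from statistics import mean
from typing import Hashable, List, Sequence


def index_subsets(
    symmetric: Sequence[Sequence[Hashable]], keys: Sequence[Hashable]
) -> List[List[int]]:
    """Hypothetical: positions of the keys in each subset (exact matches only)."""
    subsets = []
    for subset in symmetric:
        wanted = set(subset)  # works for string keys and tuple (edge type) keys
        subsets.append([i for i, key in enumerate(keys) if key in wanted])
    return subsets


def apply_joint_mean(
    values: Sequence[float], subsets: Sequence[Sequence[int]]
) -> List[float]:
    """Hypothetical: replace the values in each subset by the subset's mean."""
    constrained = list(values)
    for idx in subsets:
        if not idx:  # no key matched, nothing to constrain
            continue
        joint = mean(constrained[i] for i in idx)
        for i in idx:
            constrained[i] = joint
    return constrained


keys = ["T4a", "T4b", "T4c", "T4d", "T5a", "T5b", "T5c", "T5d", "Mi1"]
values = [1.0, 2.0, 3.0, 4.0, 10.0, 10.0, 20.0, 20.0, 7.0]  # toy values
symmetric = [["T4a", "T4b", "T4c", "T4d"], ["T5a", "T5b", "T5c", "T5d"]]

subsets = index_subsets(symmetric, keys)   # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(apply_joint_mean(values, subsets))   # T4 subtypes -> 2.5, T5 -> 15.0, Mi1 stays 7.0
```

Because keys only need to be hashable, the same helpers cover edge-type keys such as `("CT1(M10)", "CT1(Lo1)")`.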