Connectomes

Connectome protocol

Connectomes must implement the Connectome protocol to be compatible with flyvision.network.network.Network.

flyvision.connectome.connectome.Connectome

Bases: Protocol

Protocol for connectome classes compatible with flyvision.network.Network.

Note

Nodes and edges may need additional attributes, depending on the Parameter class implementations in use. For instance, when a parameter for edges is derived from synapse counts, the edges must have an n_syn attribute (ArrayFile or np.ndarray).
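As a rough illustration of this structural requirement, the sketch below builds a stand-in object exposing `nodes.index`, `edges.source_index`, and `edges.target_index`, plus an `n_syn` attribute on the edges. The class `ToyConnectome`, the helper `missing_attributes`, and the use of plain lists in place of `np.ndarray` or `ArrayFile` are assumptions made for this example; none of them are part of flyvision.

```python
from types import SimpleNamespace

# Attribute names taken from the Connectome protocol.
REQUIRED = {
    "nodes": ("index",),
    "edges": ("source_index", "target_index"),
}


class ToyConnectome:
    """Hypothetical stand-in; plain lists replace np.ndarray/ArrayFile."""

    def __init__(self):
        # Three cells, identified by their position in the node table.
        self.nodes = SimpleNamespace(index=[0, 1, 2])
        # Two edges, 0 -> 1 and 1 -> 2, with illustrative synapse counts.
        self.edges = SimpleNamespace(
            source_index=[0, 1],
            target_index=[1, 2],
            n_syn=[4.0, 2.0],
        )


def missing_attributes(connectome):
    """Return 'table.attribute' names the object lacks (hypothetical helper)."""
    missing = []
    for table, attrs in REQUIRED.items():
        obj = getattr(connectome, table, None)
        for attr in attrs:
            if obj is None or not hasattr(obj, attr):
                missing.append(f"{table}.{attr}")
    return missing


toy = ToyConnectome()
print(missing_attributes(toy))  # prints []
```

A check like this only verifies that the attributes exist; it says nothing about array dtypes or shapes, which the network may also depend on.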

Source code in flyvision/connectome/connectome.py
class Connectome(Protocol):
    """Protocol for connectome classes compatible with flyvision.network.Network.

    Note:
        Nodes and edges have additional attributes that require compatibility
        with `Parameter` class implementations. For instance, when a parameter
        for edges is derived from synapse counts, the edges have an `n_syn`
        attribute (ArrayFile or np.ndarray).
    """

    class nodes:
        index: Union[np.ndarray, ArrayFile]
        ...

    class edges:
        source_index: Union[np.ndarray, ArrayFile]
        target_index: Union[np.ndarray, ArrayFile]
        ...

Compilation from average filters

flyvision.connectome.connectome.ConnectomeFromAvgFilters

Bases: Directory

Compiles a connectome graph from average convolutional filters.

The graph consists of cells (nodes) and synapse sets (edges).

Parameters:

file: The name of a JSON connectome file. Default: connectome_file.
extent: The array radius, in columns. Default: 15.
n_syn_fill: The number of synapses to assume in data gaps. Default: 1.
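To get a feel for what `extent` means, the sketch below counts the columns in a hexagonal array of a given radius. It assumes the array is the region `|u| <= extent`, `|v| <= extent`, `|u + v| <= extent` in hex coordinates, which is the condition `ConnectomeView.hex_layout` applies when cropping. Whether `add_nodes` builds exactly this region is not shown on this page, and `hex_columns` is a hypothetical helper.

```python
def hex_columns(extent):
    """List (u, v) column coordinates within a hexagonal radius (hypothetical)."""
    return [
        (u, v)
        for u in range(-extent, extent + 1)
        for v in range(-extent, extent + 1)
        if abs(u + v) <= extent
    ]


# A hexagon of radius r holds 3 * r * (r + 1) + 1 columns.
print(len(hex_columns(1)))   # prints 7
print(len(hex_columns(15)))  # prints 721
```

Under that assumption, the default `extent=15` corresponds to 721 columns for a cell type that occupies every column.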

Attributes:

unique_cell_types (ArrayFile): Identified cell types.
input_cell_types (ArrayFile): Input cell types.
intermediate_cell_types (ArrayFile): Hidden cell types.
output_cell_types (ArrayFile): Decoded cell types.
central_cells_index (ArrayFile): Index of central cell in nodes table for each cell type in unique_cell_types.
layout (ArrayFile): Input, hidden, output definitions for visualization.
nodes (NodeDir): Table with a row for each individual node/cell.
edges (EdgeDir): Table with a row for each edge.
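The relationship between the node table and `central_cells_index` can be sketched with a toy table. The constructor source on this page selects nodes with `u == 0` and `v == 0` and also stores per-type row indices as `nodes.layer_index`; the example mirrors that logic with plain lists instead of NumPy arrays, and all values are made up.

```python
# Toy node table: one row per cell, with its type and hex coordinates.
# The values are made up for illustration.
node_type = ["R1", "R1", "R1", "L1", "L1", "L1"]
node_u = [0, 1, -1, 0, 1, -1]
node_v = [0, 0, 0, 0, 0, 0]

# Central cells sit in the home column, u == 0 and v == 0.
central_cells_index = [
    i for i, (u, v) in enumerate(zip(node_u, node_v)) if u == 0 and v == 0
]

# Row indices of all cells of each type (stored by the constructor as
# nodes.layer_index).
layer_index = {}
for i, cell_type in enumerate(node_type):
    layer_index.setdefault(cell_type, []).append(i)

print(central_cells_index)  # prints [0, 3]
print(layer_index["L1"])    # prints [3, 4, 5]
```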

Note

A connectome can be constructed from a JSON model file following this schema:

{
    "nodes": [{
        "name": string,
        "pattern": (
            ["stride", [<u_stride:int>, <v_stride:int>]]
            | ["tile", <stride:int>]
            | ["single", null]
        )
    }*],
    "edges": [{
        "src": string,
        "tar": string,
        "alpha": int,
        "offsets": [[
            [<du:int>, <dv:int>],
            <n_synapses:number>
        ]*]
    }*],
    "input_units": [string*],
    "output_units": [string*]
}

See “data/connectome/fib25-fib19_v2.2.json” for an example.
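A minimal instance of this schema might look like the following sketch. The cell-type names, stride values, `alpha`, and synapse counts are purely illustrative and not taken from the shipped file. The `input_units` and `output_units` keys are included because the constructor source on this page reads `spec["input_units"]` and `spec["output_units"]`.

```python
import json

# Hypothetical two-cell-type spec; names and numbers are illustrative only.
spec = {
    "nodes": [
        {"name": "R1", "pattern": ["stride", [1, 1]]},
        {"name": "L1", "pattern": ["stride", [1, 1]]},
    ],
    "edges": [
        {
            "src": "R1",
            "tar": "L1",
            "alpha": 1,
            # Each offset pairs a column displacement [du, dv] with a synapse count.
            "offsets": [[[0, 0], 40], [[1, 0], 2.5]],
        }
    ],
    # The constructor also reads these two keys (see its source).
    "input_units": ["R1"],
    "output_units": ["L1"],
}

# Round-trip through JSON text, as a file on disk would be.
text = json.dumps(spec, indent=4)
loaded = json.loads(text)

# Every edge endpoint should name a declared node.
names = {node["name"] for node in loaded["nodes"]}
dangling = [e for e in loaded["edges"] if e["src"] not in names or e["tar"] not in names]
print(dangling)  # prints []
```

A consistency check of this kind is cheap to run before handing a custom file to the constructor, since an edge that names an undeclared cell type would otherwise only surface during compilation.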

Example
config = Namespace(file='fib25-fib19_v2.2.json', extent=15, n_syn_fill=1)
connectome = ConnectomeFromAvgFilters(config)
Source code in flyvision/connectome/connectome.py
@register_connectome
@root(flyvision.root_dir / "connectome")
class ConnectomeFromAvgFilters(Directory):
    """Compiles a connectome graph from average convolutional filters.

    The graph consists of cells (nodes) and synapse sets (edges).

    Args:
        file: The name of a JSON connectome file.
        extent: The array radius, in columns.
        n_syn_fill: The number of synapses to assume in data gaps.

    Attributes:
        unique_cell_types (ArrayFile): Identified cell types.
        input_cell_types (ArrayFile): Input cell types.
        intermediate_cell_types (ArrayFile): Hidden cell types.
        output_cell_types (ArrayFile): Decoded cell types.
        central_cells_index (ArrayFile): Index of central cell in nodes table
            for each cell type in unique_cell_types.
        layout (ArrayFile): Input, hidden, output definitions for visualization.
        nodes (NodeDir): Table with a row for each individual node/cell.
        edges (EdgeDir): Table with a row for each edge.

    Note:
        A connectome can be constructed from a JSON model file following this schema:

        ```json
        {
            "nodes": [{
                "name": string,
                "pattern": (
                    ["stride", [<u_stride:int>, <v_stride:int>]]
                    | ["tile", <stride:int>]
                    | ["single", null]
                )
            }*],
            "edges": [{
                "src": string,
                "tar": string,
                "alpha": int,
                "offsets": [[
                    [<du:int>, <dv:int>],
                    <n_synapses:number>
                ]*]
            }*],
            "input_units": [string*],
            "output_units": [string*]
        }
        ```

        See "data/connectome/fib25-fib19_v2.2.json" for an example.

    Example:
        ```python
        config = Namespace(file='fib25-fib19_v2.2.json', extent=15, n_syn_fill=1)
        connectome = ConnectomeFromAvgFilters(config)
        ```
    """

    def __init__(self, file=flyvision.connectome_file, extent=15, n_syn_fill=1) -> None:
        if not Path(file).exists():
            file = flyvision.root_dir / "connectome" / file

        # Load the connectome spec.
        spec = json.loads(Path(file).read_text())

        # Store unique cell types and layout variables.
        self.unique_cell_types = np.bytes_([n["name"] for n in spec["nodes"]])
        self.input_cell_types = np.bytes_(spec["input_units"])
        self.output_cell_types = np.bytes_(spec["output_units"])
        intermediate_cell_types, _ = nodes_edges_utils.order_node_type_list(
            np.array(
                list(
                    set(self.unique_cell_types)
                    - set(self.input_cell_types)
                    - set(self.output_cell_types)
                )
            ).astype(str)
        )
        self.intermediate_cell_types = np.array(intermediate_cell_types).astype("S")

        layout = []
        layout.extend(
            list(
                zip(
                    self.input_cell_types,
                    [b"retina" for _ in range(len(self.input_cell_types))],
                )
            )
        )
        layout.extend(
            list(
                zip(
                    self.intermediate_cell_types,
                    [b"intermediate" for _ in range(len(self.intermediate_cell_types))],
                )
            )
        )
        layout.extend(
            list(
                zip(
                    self.output_cell_types,
                    [b"output" for _ in range(len(self.output_cell_types))],
                )
            )
        )
        self.layout = np.bytes_(layout)

        # Construct nodes and edges.
        nodes: List[Node] = []
        edges: List[Edge] = []
        add_nodes(nodes, spec["nodes"], extent)
        add_edges(edges, nodes, spec["edges"], n_syn_fill)

        # Define node roles (input, intermediate, output).
        _role = {node: "intermediate" for node in set([n.type for n in nodes])}
        _role.update({node: "input" for node in _role if node in spec["input_units"]})
        _role.update({node: "output" for node in _role if node in spec["output_units"]})

        # Store the graph.
        self.nodes = dict(  # type: ignore
            index=np.int64([n.id for n in nodes]),
            type=np.bytes_([n.type for n in nodes]),
            u=np.int32([n.u for n in nodes]),
            v=np.int32([n.v for n in nodes]),
            role=np.bytes_([_role[n.type] for n in nodes]),
        )

        self.edges = dict(  # type: ignore
            # [Essential fields]
            source_index=np.int64([e.source.id for e in edges]),
            target_index=np.int64([e.target.id for e in edges]),
            sign=np.float32([e.sign for e in edges]),
            n_syn=np.float32([e.n_syn for e in edges]),
            # [Convenience fields]
            source_type=np.bytes_([e.source.type for e in edges]),
            target_type=np.bytes_([e.target.type for e in edges]),
            source_u=np.int32([e.source.u for e in edges]),
            target_u=np.int32([e.target.u for e in edges]),
            source_v=np.int32([e.source.v for e in edges]),
            target_v=np.int32([e.target.v for e in edges]),
            du=np.int32([e.target.u - e.source.u for e in edges]),
            dv=np.int32([e.target.v - e.source.v for e in edges]),
            n_syn_certainty=np.float32([e.n_syn_certainty for e in edges]),
        )

        # Store central indices.
        self.central_cells_index = np.int64(
            np.nonzero((self.nodes.u[:] == 0) & (self.nodes.v[:] == 0))[0]
        )

        # Store layer indices.
        layer_index = {}
        for cell_type in self.unique_cell_types[:]:
            node_indices = np.nonzero(self.nodes["type"][:] == cell_type)[0]
            layer_index[cell_type.decode()] = np.int64(node_indices)
        self.nodes.layer_index = layer_index

Analysis and visualization

flyvision.connectome.connectome.ConnectomeView

Visualization of the connectome data.

Parameters:

connectome (ConnectomeFromAvgFilters): Directory of the connectome. Required.
groups (List[str]): Regular expressions to sort the nodes by. Default: ['R\\d', 'L\\d', 'Lawf\\d', 'A', 'C\\d', 'CT\\d.*', 'Mi\\d{1,2}', 'T\\d{1,2}.*', 'Tm.*\\d{1,2}.*']
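The idea of sorting cell types by a list of regular expressions can be sketched as follows. This is not the library's `order_node_type_list`, whose matching rules are not shown on this page. The helper `sort_by_groups`, the use of `re.fullmatch`, and the plain string ordering within a group are all assumptions of this example.

```python
import re


def sort_by_groups(cell_types, groups):
    """Order names by the first group pattern they match (hypothetical helper)."""

    def key(name):
        for rank, pattern in enumerate(groups):
            if re.fullmatch(pattern, name):
                return (rank, name)
        # Names matching no group go last.
        return (len(groups), name)

    return sorted(cell_types, key=key)


groups = [r"R\d", r"L\d", r"Lawf\d", r"A", r"C\d", r"CT\d.*",
          r"Mi\d{1,2}", r"T\d{1,2}.*", r"Tm.*\d{1,2}.*"]

print(sort_by_groups(["T4a", "Mi9", "L1", "R1", "Tm3", "C3"], groups))
# prints ['R1', 'L1', 'C3', 'Mi9', 'T4a', 'Tm3']
```

The position of a pattern in `groups` determines where its cell types appear, so reordering the list reorders the axes of any plot built from the sorted types.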

Attributes:

dir (ConnectomeFromAvgFilters): Connectome directory.
edges (Directory): Edge table.
nodes (Directory): Node table.
cell_types_unsorted (List[str]): Unsorted list of cell types.
cell_types_sorted (List[str]): Sorted list of cell types.
cell_types_sort_index (List[int]): Indices for sorting cell types.
layout (Dict[str, str]): Layout information for cell types.
node_indexer (NodeIndexer): Indexer for nodes.
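The aggregation behind `connectivity_matrix` can be illustrated without plotting. Following the method's source on this page, the sketch keeps only edges that target the central column (`target_u == 0` and `target_v == 0`), then either sums `sign * n_syn` per (source type, target type) pair or counts the edges. The edge table is made up, and plain lists stand in for the DataFrame and NumPy arrays the library uses.

```python
# Toy edge table; every value here is made up for illustration.
edges = [
    # (source_type, target_type, target_u, target_v, sign, n_syn)
    ("Mi9", "T4a", 0, 0, -1.0, 10.0),
    ("Mi9", "T4a", 0, 0, -1.0, 5.0),
    ("Mi1", "T4a", 0, 0, 1.0, 20.0),
    ("Mi1", "T4a", 1, 0, 1.0, 20.0),  # targets a non-central column: skipped
]
cell_types = ["Mi1", "Mi9", "T4a"]
type_index = {name: i for i, name in enumerate(cell_types)}

n = len(cell_types)
n_syn_matrix = [[0.0] * n for _ in range(n)]
count_matrix = [[0] * n for _ in range(n)]

for src, tgt, tu, tv, sign, n_syn in edges:
    if (tu, tv) != (0, 0):
        continue
    i, j = type_index[src], type_index[tgt]
    n_syn_matrix[i][j] += sign * n_syn  # corresponds to mode="n_syn"
    count_matrix[i][j] += 1             # corresponds to mode="count"

print(n_syn_matrix[type_index["Mi9"]][type_index["T4a"]])  # prints -15.0
print(count_matrix[type_index["Mi1"]][type_index["T4a"]])  # prints 1
```

Rows index the presynaptic type and columns the postsynaptic type, so a negative entry marks a net inhibitory projection onto the central cell of the target type.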

Source code in flyvision/connectome/connectome.py
class ConnectomeView:
    """Visualization of the connectome data.

    Args:
        connectome: Directory of the connectome.
        groups: Regular expressions to sort the nodes by.

    Attributes:
        dir (ConnectomeFromAvgFilters): Connectome directory.
        edges (Directory): Edge table.
        nodes (Directory): Node table.
        cell_types_unsorted (List[str]): Unsorted list of cell types.
        cell_types_sorted (List[str]): Sorted list of cell types.
        cell_types_sort_index (List[int]): Indices for sorting cell types.
        layout (Dict[str, str]): Layout information for cell types.
        node_indexer (NodeIndexer): Indexer for nodes.
    """

    def __init__(
        self,
        connectome: ConnectomeFromAvgFilters,
        groups: List[str] = [
            r"R\d",
            r"L\d",
            r"Lawf\d",
            r"A",
            r"C\d",
            r"CT\d.*",
            r"Mi\d{1,2}",
            r"T\d{1,2}.*",
            r"Tm.*\d{1,2}.*",
        ],
    ) -> None:
        self.dir = connectome

        assert "nodes" in self.dir and "edges" in self.dir

        self.edges = self.dir.edges
        self.nodes = self.dir.nodes

        self.cell_types_unsorted = self.dir.unique_cell_types[:].astype(str)

        (
            self.cell_types_sorted,
            self.cell_types_sort_index,
        ) = nodes_edges_utils.order_node_type_list(
            self.dir.unique_cell_types[:].astype(str), groups
        )

        self.layout = dict(self.dir.layout[:].astype(str))
        self.node_indexer = nodes_edges_utils.NodeIndexer(self.dir)

    def connectivity_matrix(
        self,
        mode: str = "n_syn",
        only_sign: Optional[str] = None,
        cell_types: Optional[List[str]] = None,
        no_symlog: Optional[bool] = False,
        min_number: Optional[float] = None,
        cmap: Optional[Colormap] = None,
        size_scale: Optional[float] = None,
        title: Optional[str] = None,
        cbar_label: Optional[str] = None,
        **kwargs,
    ) -> Figure:
        """Plot the connectivity matrix as counts or weights.

        Args:
            mode: 'n_syn' for number of input synapses, 'count' for number of neurons.
            only_sign: '+' for excitatory projections, '-' for inhibitory projections.
            cell_types: Subset of nodes to display.
            no_symlog: Disable symmetric log scale.
            min_number: Minimum value to display.
            cmap: Custom colormap.
            size_scale: Size of the scattered squares.
            title: Custom title for the plot.
            cbar_label: Custom colorbar label.
            **kwargs: Additional arguments passed to the heatmap plot function.

        Returns:
            Figure: Matplotlib figure object.
        """
        _kwargs = dict(
            n_syn=dict(
                symlog=1e-5,
                grid=True,
                cmap=cmap or cm.get_cmap("seismic"),
                title=title or "Connectivity between identified cell types",
                cbar_label=cbar_label or r"$\pm\sum_{pre} N_\mathrm{syn.}^{pre, post}$",
                size_scale=size_scale or 0.05,
            ),
            count=dict(
                grid=True,
                cmap=cmap or cm.get_cmap("seismic"),
                midpoint=0,
                title=title or "Number of Input Neurons",
                cbar_label=cbar_label or r"$\sum_{pre} 1$",
                size_scale=size_scale or 0.05,
            ),
        )

        kwargs.update(_kwargs[mode])
        if no_symlog:
            kwargs.update(symlog=None)
            kwargs.update(midpoint=0)

        edges = self.edges.to_df()

        # to take projections onto central nodes (home columns) into account
        edges = edges[(edges.target_u == 0) & (edges.target_v == 0)]

        # filter edges to allow providing a subset of cell types
        cell_types = cell_types or self.cell_types_sorted
        edges = df_utils.filter_by_column_values(
            df_utils.filter_by_column_values(
                edges, column="source_type", values=cell_types
            ),
            column="target_type",
            values=cell_types,
        )
        weights = self._weights()[edges.index]

        # lookup table for key -> (i, j)
        type_index = {node_typ: i for i, node_typ in enumerate(cell_types)}
        matrix = np.zeros([len(type_index), len(type_index)])

        for srctyp, tgttyp, weight in zip(
            edges.source_type.values, edges.target_type.values, weights
        ):
            if mode == "count":
                # to simply count the number of projections
                matrix[type_index[srctyp], type_index[tgttyp]] += 1
            elif mode in ["weight", "n_syn"]:
                # to sum the synapse counts
                matrix[type_index[srctyp], type_index[tgttyp]] += weight
            else:
                raise ValueError

        # to filter out all connections weaker than min_number
        if min_number is not None:
            matrix[np.abs(matrix) <= min_number] = np.nan

        # to display either only excitatory or inhibitory connections
        if only_sign == "+":
            matrix[matrix < 0] = 0
            kwargs.update(symlog=None, midpoint=0)
        elif only_sign == "-":
            matrix[matrix > 0] = 0
            kwargs.update(symlog=None, midpoint=0)
        elif only_sign is None:
            pass
        else:
            raise ValueError

        return plots.heatmap(matrix, cell_types, **kwargs)

    def _weights(self) -> NDArray:
        """Calculate weights for edges.

        Returns:
            NDArray: Array of edge weights.
        """
        return self.edges.sign[:] * self.edges.n_syn[:]

    def network_layout(
        self,
        max_extent: int = 5,
        **kwargs,
    ) -> Figure:
        """Plot retinotopic hexagonal lattice columnar organization of the network.

        Args:
            max_extent: Integer column radius to visualize.
            **kwargs: Additional arguments passed to hex_layout_all.

        Returns:
            Figure: Matplotlib figure object.
        """
        backbone = WholeNetworkFigure(self.dir)
        backbone.init_figure(figsize=[7, 3])
        return self.hex_layout_all(
            max_extent=max_extent, fig=backbone.fig, axes=backbone.axes, **kwargs
        )

    def hex_layout(
        self,
        cell_type: str,
        max_extent: int = 5,
        edgecolor: str = "none",
        edgewidth: float = 0.5,
        alpha: float = 1,
        fill: bool = False,
        cmap: Optional[Colormap] = None,
        fig: Optional[Figure] = None,
        ax: Optional[Axes] = None,
        **kwargs,
    ) -> Figure:
        """Plot retinotopic hexagonal lattice organization of a cell type.

        Args:
            cell_type: Type of cell to plot.
            max_extent: Maximum extent of the layout.
            edgecolor: Color of the hexagon edges.
            edgewidth: Width of the hexagon edges.
            alpha: Transparency of the hexagons.
            fill: Whether to fill the hexagons.
            cmap: Custom colormap.
            fig: Existing figure to plot on.
            ax: Existing axis to plot on.
            **kwargs: Additional arguments passed to hex_scatter.

        Returns:
            Figure: Matplotlib figure object.
        """
        nodes = self.nodes.to_df()
        node_condition = nodes.type == cell_type
        u, v = nodes.u[node_condition], nodes.v[node_condition]
        max_extent = hex_utils.get_extent(u, v) if max_extent is None else max_extent
        extent_condition = (
            (-max_extent <= u)
            & (u <= max_extent)
            & (-max_extent <= v)
            & (v <= max_extent)
            & (-max_extent <= u + v)
            & (u + v <= max_extent)
        )
        u, v = u[extent_condition].values, v[extent_condition].values

        label = cell_type
        if ax is not None:
            # prevent labeling twice
            label = cell_type if cell_type not in [t.get_text() for t in ax.texts] else ""

        fig, ax, _ = plots.hex_scatter(
            u,
            v,
            values=1,
            label=label,
            fig=fig,
            ax=ax,
            edgecolor=edgecolor,
            edgewidth=edgewidth,
            alpha=alpha,
            fill=fill,
            cmap=cmap or plt_utils.get_alpha_colormap("#2f3541", 1),
            cbar=False,
            **kwargs,
        )
        return fig

    def hex_layout_all(
        self,
        cell_types: Optional[List[str]] = None,
        max_extent: int = 5,
        edgecolor: str = "none",
        alpha: float = 1,
        fill: bool = False,
        cmap: Optional[Colormap] = None,
        fig: Optional[Figure] = None,
        axes: Optional[List[Axes]] = None,
        **kwargs,
    ) -> Figure:
        """Plot retinotopic hexagonal lattice organization of all cell types.

        Args:
            cell_types: List of cell types to plot.
            max_extent: Maximum extent of the layout.
            edgecolor: Color of the hexagon edges.
            alpha: Transparency of the hexagons.
            fill: Whether to fill the hexagons.
            cmap: Custom colormap.
            fig: Existing figure to plot on.
            axes: List of existing axes to plot on.
            **kwargs: Additional arguments passed to hex_layout.

        Returns:
            Figure: Matplotlib figure object.
        """
        cell_types = self.cell_types_sorted if cell_types is None else cell_types
        if fig is None or axes is None:
            fig, axes, (gw, gh) = plt_utils.get_axis_grid(self.cell_types_sorted)

        for i, cell_type in enumerate(cell_types):
            self.hex_layout(
                cell_type,
                edgecolor=edgecolor,
                edgewidth=0.1,
                alpha=alpha,
                fill=fill,
                max_extent=max_extent,
                cmap=cmap or plt_utils.get_alpha_colormap("#2f3541", 1),
                fig=fig,
                ax=axes[i],
                **kwargs,
            )
        return fig

    def get_uv(self, cell_type: str) -> Tuple[NDArray, NDArray]:
        """Get hex-coordinates of a particular cell type.

        Args:
            cell_type: Type of cell to get coordinates for.

        Returns:
            Tuple[NDArray, NDArray]: Arrays of u and v coordinates.
        """
        nodes = self.nodes.to_df()
        nodes = nodes[nodes.type == cell_type]
        u, v = nodes[["u", "v"]].values.T
        return u, v

    def sources_list(self, cell_type: str) -> NDArray:
        """Get presynaptic cell types.

        Args:
            cell_type: Type of cell to get sources for.

        Returns:
            NDArray: Array of presynaptic cell types.
        """
        edges = self.edges.to_df()
        return np.unique(edges[edges.target_type == cell_type].source_type.values)

    def targets_list(self, cell_type: str) -> NDArray:
        """Get postsynaptic cell types.

        Args:
            cell_type: Type of cell to get targets for.

        Returns:
            NDArray: Array of postsynaptic cell types.
        """
        edges = self.edges.to_df()
        return np.unique(edges[edges.source_type == cell_type].target_type.values)

    def receptive_field(
        self,
        source: str = "Mi9",
        target: str = "T4a",
        rfs: Optional["ReceptiveFields"] = None,
        max_extent: Optional[int] = None,
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        title: str = "{source} :→ {target}",
        **kwargs,
    ) -> Figure:
        """Plot the receptive field of a target cell type from a source cell type.

        Args:
            source: Source cell type.
            target: Target cell type.
            rfs: ReceptiveFields object. If None, it will be created.
            max_extent: Maximum extent of the receptive field.
            vmin: Minimum value for colormap.
            vmax: Maximum value for colormap.
            title: Title format string for the plot.
            **kwargs: Additional arguments passed to plots.kernel.

        Returns:
            Matplotlib Figure object.
        """
        if rfs is None:
            rfs = ReceptiveFields(target, self.edges.to_df())
            max_extent = max_extent or rfs.max_extent

        weights = self._weights()

        # to derive color range values taking all inputs into account
        vmin = min(
            0,
            min(weights[rfs[source].index].min() for source in rfs.source_types),
        )
        vmax = max(
            0,
            max(weights[rfs[source].index].max() for source in rfs.source_types),
        )

        weights = weights[rfs[source].index]
        label = ""

        # requires to look from the target cell, ie mirror the coordinates
        du_inv, dv_inv = -rfs[source].du.values, -rfs[source].dv.values
        fig, ax, (label_text, scalarmapper) = plots.kernel(
            du_inv,
            dv_inv,
            weights,
            label=label,
            max_extent=max_extent,
            fill=True,
            vmin=vmin,
            vmax=vmax,
            title=title.format(**locals()),
            **kwargs,
        )
        return fig

    def receptive_fields_grid(
        self,
        target: str,
        sources: Optional[Iterable[str]] = None,
        sort_alphabetically: bool = True,
        ax_titles: str = "{source} :→ {target}",
        figsize: List[int] = [20, 20],
        max_extent: Optional[int] = None,
        fig: Optional[Figure] = None,
        axes: Optional[List[Axes]] = None,
        ignore_sign_error: bool = False,
        max_figure_height_cm: float = 22,
        panel_height_cm: float = 3,
        max_figure_width_cm: float = 18,
        panel_width_cm: float = 3.6,
        **kwargs,
    ) -> Figure:
        """Plot receptive fields of a target cell type in a grid layout.

        Args:
            target: Target cell type.
            sources: Iterable of source cell types. If None, all sources are used.
            sort_alphabetically: Whether to sort source types alphabetically.
            ax_titles: Title format string for each subplot.
            figsize: Figure size in inches.
            max_extent: Maximum extent of the receptive fields.
            fig: Existing figure to plot on.
            axes: List of existing axes to plot on.
            ignore_sign_error: Whether to ignore sign errors in plotting.
            max_figure_height_cm: Maximum figure height in cm.
            panel_height_cm: Height of each panel in cm.
            max_figure_width_cm: Maximum figure width in cm.
            panel_width_cm: Width of each panel in cm.
            **kwargs: Additional arguments passed to receptive_field.

        Returns:
            Matplotlib Figure object.
        """

        rfs = ReceptiveFields(target, self.edges.to_df())
        max_extent = max_extent or rfs.max_extent
        weights = self._weights()

        # to sort in descending order by sum of inputs
        sorted_sum_of_inputs = dict(
            sorted(
                valmap(lambda v: weights[v.index].sum(), rfs).items(),
                key=lambda item: item[1],
                reverse=True,
            )
        )
        # to sort alphabetically in case sources is specified
        if sort_alphabetically:
            sources, _ = nodes_edges_utils.order_node_type_list(sources)
        sources = sources or list(sorted_sum_of_inputs.keys())

        # to derive color range values taking all inputs into account
        vmin = min(0, min(weights[rfs[source].index].min() for source in sources))
        vmax = max(0, max(weights[rfs[source].index].max() for source in sources))

        if fig is None or axes is None:
            figsize = figsize_from_n_items(
                len(rfs.source_types),
                max_figure_height_cm=max_figure_height_cm,
                panel_height_cm=panel_height_cm,
                max_figure_width_cm=max_figure_width_cm,
                panel_width_cm=panel_width_cm,
            )
            fig, axes = figsize.axis_grid(
                unmask_n=len(rfs.source_types), hspace=0.0, wspace=0
            )

        cbar = kwargs.get("cbar", False)
        for i, src in enumerate(sources):
            if i == 0 and cbar:
                cbar = True
                kwargs.update(cbar=cbar)
            else:
                cbar = False
                kwargs.update(cbar=cbar)
            try:
                self.receptive_field(
                    target=target,
                    source=src,
                    fig=fig,
                    ax=axes[i],
                    title=ax_titles,
                    vmin=vmin,
                    vmax=vmax,
                    rfs=rfs,
                    max_extent=max_extent,
                    annotate=False,
                    annotate_coords=False,
                    title_y=0.9,
                    **kwargs,
                )
            except plots.SignError as e:
                if ignore_sign_error:
                    pass
                else:
                    raise e
        return fig

    def projective_field(
        self,
        source: str = "Mi9",
        target: str = "T4a",
        title: str = "{source} →: {target}",
        prfs: Optional["ProjectiveFields"] = None,
        max_extent: Optional[int] = None,
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        **kwargs,
    ) -> Optional[Figure]:
        """Plot the projective field from a source cell type to a target cell type.

        Args:
            source: Source cell type.
            target: Target cell type.
            title: Title format string for the plot.
            prfs: ProjectiveFields object. If None, it will be created.
            max_extent: Maximum extent of the projective field.
            vmin: Minimum value for colormap.
            vmax: Maximum value for colormap.
            **kwargs: Additional arguments passed to plots.kernel.

        Returns:
            Matplotlib Figure object or None if max_extent is None.
        """
        if prfs is None:
            prfs = ProjectiveFields(source, self.edges.to_df())
            max_extent = max_extent or prfs.max_extent
        if max_extent is None:
            return None
        weights = self._weights()

        # derive the color range from the weights onto all target types
        vmin = min(
            0,
            min(weights[prfs[target].index].min() for target in prfs.target_types),
        )

        vmax = max(
            0,
            max(weights[prfs[target].index].max() for target in prfs.target_types),
        )

        weights = weights[prfs[target].index]
        label = ""
        du, dv = prfs[target].du.values, prfs[target].dv.values
        fig, ax, (label_text, scalarmapper) = plots.kernel(
            du,
            dv,
            weights,
            label=label,
            fill=True,
            max_extent=max_extent,
            vmin=vmin,
            vmax=vmax,
            title=title.format(**locals()),
            **kwargs,
        )
        return fig

    def projective_fields_grid(
        self,
        source: str,
        targets: Optional[Iterable[str]] = None,
        fig: Optional[Figure] = None,
        axes: Optional[List[Axes]] = None,
        figsize: List[int] = [20, 20],
        ax_titles: str = "{source} →: {target}",
        max_figure_height_cm: float = 22,
        panel_height_cm: float = 3,
        max_figure_width_cm: float = 18,
        panel_width_cm: float = 3.6,
        max_extent: Optional[int] = None,
        sort_alphabetically: bool = False,
        ignore_sign_error: bool = False,
        **kwargs,
    ) -> Figure:
        """Plot projective fields of a source cell type in a grid layout.

        Args:
            source: Source cell type.
            targets: Iterable of target cell types. If None, all targets are used.
            fig: Existing figure to plot on.
            axes: List of existing axes to plot on.
            figsize: Figure size in inches.
            ax_titles: Title format string for each subplot.
            max_figure_height_cm: Maximum figure height in cm.
            panel_height_cm: Height of each panel in cm.
            max_figure_width_cm: Maximum figure width in cm.
            panel_width_cm: Width of each panel in cm.
            max_extent: Maximum extent of the projective fields.
            sort_alphabetically: Whether to sort target types alphabetically.
            ignore_sign_error: Whether to ignore sign errors in plotting.
            **kwargs: Additional arguments passed to projective_field.

        Returns:
            Matplotlib Figure object.
        """
        prfs = ProjectiveFields(source, self.edges.to_df())
        max_extent = max_extent or prfs.max_extent
        weights = self._weights()
        sorted_sum_of_outputs = dict(
            sorted(
                valmap(lambda v: weights[v.index].sum(), prfs).items(),
                key=lambda item: item[1],
                reverse=True,
            )
        )

        # to sort alphabetically in case targets is specified
        if sort_alphabetically:
            targets, _ = nodes_edges_utils.order_node_type_list(targets)

        targets = targets or list(sorted_sum_of_outputs.keys())

        vmin = min(0, min(weights[prfs[target].index].min() for target in targets))
        vmax = max(0, max(weights[prfs[target].index].max() for target in targets))

        if fig is None or axes is None:
            figsize = figsize_from_n_items(
                len(prfs.target_types),
                max_figure_height_cm=max_figure_height_cm,
                panel_height_cm=panel_height_cm,
                max_figure_width_cm=max_figure_width_cm,
                panel_width_cm=panel_width_cm,
            )
            fig, axes = figsize.axis_grid(
                unmask_n=len(prfs.target_types), hspace=0.0, wspace=0
            )

        cbar = kwargs.get("cbar", False)
        for i, target in enumerate(targets):
            if i == 0 and cbar:
                cbar = True
                kwargs.update(cbar=cbar)
            else:
                cbar = False
                kwargs.update(cbar=cbar)
            try:
                self.projective_field(
                    source=source,
                    target=target,
                    fig=fig,
                    ax=axes[i],
                    title=ax_titles,
                    prfs=prfs,
                    max_extent=max_extent,
                    vmin=vmin,
                    vmax=vmax,
                    annotate_coords=False,
                    annotate=False,
                    title_y=0.9,
                    **kwargs,
                )
            except plots.SignError as e:
                if ignore_sign_error:
                    pass
                else:
                    raise e
        return fig

    def receptive_fields_df(self, target_type: str) -> "ReceptiveFields":
        """Get receptive fields for a target cell type.

        Args:
            target_type: Target cell type.

        Returns:
            ReceptiveFields object.
        """
        return ReceptiveFields(target_type, self.edges.to_df())

    def projective_fields_df(self, source_type: str) -> "ProjectiveFields":
        """Get projective fields for a source cell type.

        Args:
            source_type: Source cell type.

        Returns:
            ProjectiveFields object.
        """
        return ProjectiveFields(source_type, self.edges.to_df())

    def receptive_fields_sum(self, target_type: str) -> Dict[str, int]:
        """Get sum of synapses for each source type in the receptive field.

        Args:
            target_type: Target cell type.

        Returns:
            Dictionary mapping source types to synapse counts.
        """
        return ReceptiveFields(target_type, self.edges.to_df()).sum()

    def projective_fields_sum(self, source_type: str) -> Dict[str, int]:
        """Get sum of synapses for each target type in the projective field.

        Args:
            source_type: Source cell type.

        Returns:
            Dictionary mapping target types to synapse counts.
        """
        return ProjectiveFields(source_type, self.edges.to_df()).sum()

connectivity_matrix

connectivity_matrix(
    mode="n_syn",
    only_sign=None,
    cell_types=None,
    no_symlog=False,
    min_number=None,
    cmap=None,
    size_scale=None,
    title=None,
    cbar_label=None,
    **kwargs
)

Plot the connectivity matrix as counts or weights.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mode | str | 'n_syn' for number of input synapses, 'count' for number of neurons. | 'n_syn' |
| only_sign | Optional[str] | '+' for excitatory projections, '-' for inhibitory projections. | None |
| cell_types | Optional[List[str]] | Subset of nodes to display. | None |
| no_symlog | Optional[bool] | Disable symmetric log scale. | False |
| min_number | Optional[float] | Minimum value to display. | None |
| cmap | Optional[Colormap] | Custom colormap. | None |
| size_scale | Optional[float] | Size of the scattered squares. | None |
| title | Optional[str] | Custom title for the plot. | None |
| cbar_label | Optional[str] | Custom colorbar label. | None |
| **kwargs | | Additional arguments passed to the heatmap plot function. | {} |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| Figure | Figure | Matplotlib figure object. |

Source code in flyvision/connectome/connectome.py
def connectivity_matrix(
    self,
    mode: str = "n_syn",
    only_sign: Optional[str] = None,
    cell_types: Optional[List[str]] = None,
    no_symlog: Optional[bool] = False,
    min_number: Optional[float] = None,
    cmap: Optional[Colormap] = None,
    size_scale: Optional[float] = None,
    title: Optional[str] = None,
    cbar_label: Optional[str] = None,
    **kwargs,
) -> Figure:
    """Plot the connectivity matrix as counts or weights.

    Args:
        mode: 'n_syn' for number of input synapses, 'count' for number of neurons.
        only_sign: '+' for excitatory projections, '-' for inhibitory projections.
        cell_types: Subset of nodes to display.
        no_symlog: Disable symmetric log scale.
        min_number: Minimum value to display.
        cmap: Custom colormap.
        size_scale: Size of the scattered squares.
        title: Custom title for the plot.
        cbar_label: Custom colorbar label.
        **kwargs: Additional arguments passed to the heatmap plot function.

    Returns:
        Figure: Matplotlib figure object.
    """
    _kwargs = dict(
        n_syn=dict(
            symlog=1e-5,
            grid=True,
            cmap=cmap or cm.get_cmap("seismic"),
            title=title or "Connectivity between identified cell types",
            cbar_label=cbar_label or r"$\pm\sum_{pre} N_\mathrm{syn.}^{pre, post}$",
            size_scale=size_scale or 0.05,
        ),
        count=dict(
            grid=True,
            cmap=cmap or cm.get_cmap("seismic"),
            midpoint=0,
            title=title or "Number of Input Neurons",
            cbar_label=cbar_label or r"$\sum_{pre} 1$",
            size_scale=size_scale or 0.05,
        ),
    )

    kwargs.update(_kwargs[mode])
    if no_symlog:
        kwargs.update(symlog=None)
        kwargs.update(midpoint=0)

    edges = self.edges.to_df()

    # keep only projections onto central nodes (home columns)
    edges = edges[(edges.target_u == 0) & (edges.target_v == 0)]

    # filter edges to allow providing a subset of cell types
    cell_types = cell_types or self.cell_types_sorted
    edges = df_utils.filter_by_column_values(
        df_utils.filter_by_column_values(
            edges, column="source_type", values=cell_types
        ),
        column="target_type",
        values=cell_types,
    )
    weights = self._weights()[edges.index]

    # lookup table: cell type -> row/column index in the matrix
    type_index = {node_typ: i for i, node_typ in enumerate(cell_types)}
    matrix = np.zeros([len(type_index), len(type_index)])

    for srctyp, tgttyp, weight in zip(
        edges.source_type.values, edges.target_type.values, weights
    ):
        if mode == "count":
            # to simply count the number of projections
            matrix[type_index[srctyp], type_index[tgttyp]] += 1
        elif mode in ["weight", "n_syn"]:
            # to sum the synapse counts
            matrix[type_index[srctyp], type_index[tgttyp]] += weight
        else:
            raise ValueError

    # to filter out all connections weaker than min_number
    if min_number is not None:
        matrix[np.abs(matrix) <= min_number] = np.nan

    # to display either only excitatory or inhibitory connections
    if only_sign == "+":
        matrix[matrix < 0] = 0
        kwargs.update(symlog=None, midpoint=0)
    elif only_sign == "-":
        matrix[matrix > 0] = 0
        kwargs.update(symlog=None, midpoint=0)
    elif only_sign is None:
        pass
    else:
        raise ValueError

    return plots.heatmap(matrix, cell_types, **kwargs)
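
The accumulation step in the listing above can be sketched without flyvision or plotting. This is a minimal illustration only: the edge tuples, the helper name `type_matrix`, and the example cell types are assumptions made for this sketch, and it leaves out the restriction to central target columns and the `only_sign` filter. Rows index source types and columns index target types, as in the listing.

```python
import math
from typing import List, Optional, Sequence, Tuple

# Hypothetical edge record: (source_type, target_type, signed weight).
Edge = Tuple[str, str, float]


def type_matrix(
    edges: Sequence[Edge],
    cell_types: List[str],
    mode: str = "n_syn",
    min_number: Optional[float] = None,
) -> List[List[float]]:
    """Accumulate edges into a source-by-target matrix of cell types."""
    if mode not in ("n_syn", "count"):
        raise ValueError(f"unknown mode: {mode!r}")
    index = {cell_type: i for i, cell_type in enumerate(cell_types)}
    matrix = [[0.0] * len(cell_types) for _ in cell_types]
    for source, target, weight in edges:
        if source not in index or target not in index:
            continue  # edge lies outside the requested subset of cell types
        # 'count' adds one per edge, 'n_syn' adds the signed weight
        matrix[index[source]][index[target]] += 1.0 if mode == "count" else weight
    if min_number is not None:
        # entries at or below the threshold are masked with NaN
        matrix = [
            [math.nan if abs(value) <= min_number else value for value in row]
            for row in matrix
        ]
    return matrix


edges = [("Mi1", "T4a", 10.0), ("Mi1", "T4a", 4.0), ("Mi9", "T4a", -6.0)]
types = ["Mi1", "Mi9", "T4a"]
n_syn = type_matrix(edges, types, mode="n_syn")
count = type_matrix(edges, types, mode="count")
```

Note that a non-negative `min_number` also masks empty entries, because a zero entry is never above the threshold.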

network_layout

network_layout(max_extent=5, **kwargs)

Plot retinotopic hexagonal lattice columnar organization of the network.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| max_extent | int | Integer column radius to visualize. | 5 |
| **kwargs | | Additional arguments passed to hex_layout_all. | {} |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| Figure | Figure | Matplotlib figure object. |

Source code in flyvision/connectome/connectome.py
def network_layout(
    self,
    max_extent: int = 5,
    **kwargs,
) -> Figure:
    """Plot retinotopic hexagonal lattice columnar organization of the network.

    Args:
        max_extent: Integer column radius to visualize.
        **kwargs: Additional arguments passed to hex_layout_all.

    Returns:
        Figure: Matplotlib figure object.
    """
    backbone = WholeNetworkFigure(self.dir)
    backbone.init_figure(figsize=[7, 3])
    return self.hex_layout_all(
        max_extent=max_extent, fig=backbone.fig, axes=backbone.axes, **kwargs
    )

hex_layout

hex_layout(
    cell_type,
    max_extent=5,
    edgecolor="none",
    edgewidth=0.5,
    alpha=1,
    fill=False,
    cmap=None,
    fig=None,
    ax=None,
    **kwargs
)

Plot retinotopic hexagonal lattice organization of a cell type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cell_type | str | Type of cell to plot. | required |
| max_extent | int | Maximum extent of the layout. | 5 |
| edgecolor | str | Color of the hexagon edges. | 'none' |
| edgewidth | float | Width of the hexagon edges. | 0.5 |
| alpha | float | Transparency of the hexagons. | 1 |
| fill | bool | Whether to fill the hexagons. | False |
| cmap | Optional[Colormap] | Custom colormap. | None |
| fig | Optional[Figure] | Existing figure to plot on. | None |
| ax | Optional[Axes] | Existing axis to plot on. | None |
| **kwargs | | Additional arguments passed to hex_scatter. | {} |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| Figure | Figure | Matplotlib figure object. |

Source code in flyvision/connectome/connectome.py
def hex_layout(
    self,
    cell_type: str,
    max_extent: int = 5,
    edgecolor: str = "none",
    edgewidth: float = 0.5,
    alpha: float = 1,
    fill: bool = False,
    cmap: Optional[Colormap] = None,
    fig: Optional[Figure] = None,
    ax: Optional[Axes] = None,
    **kwargs,
) -> Figure:
    """Plot retinotopic hexagonal lattice organization of a cell type.

    Args:
        cell_type: Type of cell to plot.
        max_extent: Maximum extent of the layout.
        edgecolor: Color of the hexagon edges.
        edgewidth: Width of the hexagon edges.
        alpha: Transparency of the hexagons.
        fill: Whether to fill the hexagons.
        cmap: Custom colormap.
        fig: Existing figure to plot on.
        ax: Existing axis to plot on.
        **kwargs: Additional arguments passed to hex_scatter.

    Returns:
        Figure: Matplotlib figure object.
    """
    nodes = self.nodes.to_df()
    node_condition = nodes.type == cell_type
    u, v = nodes.u[node_condition], nodes.v[node_condition]
    max_extent = hex_utils.get_extent(u, v) if max_extent is None else max_extent
    extent_condition = (
        (-max_extent <= u)
        & (u <= max_extent)
        & (-max_extent <= v)
        & (v <= max_extent)
        & (-max_extent <= u + v)
        & (u + v <= max_extent)
    )
    u, v = u[extent_condition].values, v[extent_condition].values

    label = cell_type
    if ax is not None:
        # prevent labeling twice
        label = cell_type if cell_type not in [t.get_text() for t in ax.texts] else ""

    fig, ax, _ = plots.hex_scatter(
        u,
        v,
        values=1,
        label=label,
        fig=fig,
        ax=ax,
        edgecolor=edgecolor,
        edgewidth=edgewidth,
        alpha=alpha,
        fill=fill,
        cmap=cmap or plt_utils.get_alpha_colormap("#2f3541", 1),
        cbar=False,
        **kwargs,
    )
    return fig
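
The extent condition in the listing above keeps the columns whose hex coordinates satisfy |u|, |v| and |u + v| ≤ max_extent. A minimal sketch of that filter on a plain integer grid follows; the helper name `hex_disk` is an assumption for illustration and is not part of flyvision.

```python
from typing import List, Tuple


def hex_disk(max_extent: int) -> List[Tuple[int, int]]:
    """Return all (u, v) with |u|, |v| and |u + v| bounded by max_extent."""
    r = max_extent
    return [
        (u, v)
        for u in range(-r, r + 1)  # enforces -r <= u <= r
        for v in range(-r, r + 1)  # enforces -r <= v <= r
        if -r <= u + v <= r  # third constraint cuts the square to a hexagon
    ]


coords = hex_disk(5)
print(len(coords))  # number of columns within radius 5
```

For a radius r the filter keeps 3r(r + 1) + 1 columns, so the default `max_extent=5` corresponds to 91 columns per cell type when the lattice is fully populated.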

hex_layout_all

hex_layout_all(
    cell_types=None,
    max_extent=5,
    edgecolor="none",
    alpha=1,
    fill=False,
    cmap=None,
    fig=None,
    axes=None,
    **kwargs
)

Plot retinotopic hexagonal lattice organization of all cell types.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cell_types | Optional[List[str]] | List of cell types to plot. | None |
| max_extent | int | Maximum extent of the layout. | 5 |
| edgecolor | str | Color of the hexagon edges. | 'none' |
| alpha | float | Transparency of the hexagons. | 1 |
| fill | bool | Whether to fill the hexagons. | False |
| cmap | Optional[Colormap] | Custom colormap. | None |
| fig | Optional[Figure] | Existing figure to plot on. | None |
| axes | Optional[List[Axes]] | List of existing axes to plot on. | None |
| **kwargs | | Additional arguments passed to hex_layout. | {} |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| Figure | Figure | Matplotlib figure object. |

Source code in flyvision/connectome/connectome.py
def hex_layout_all(
    self,
    cell_types: Optional[List[str]] = None,
    max_extent: int = 5,
    edgecolor: str = "none",
    alpha: float = 1,
    fill: bool = False,
    cmap: Optional[Colormap] = None,
    fig: Optional[Figure] = None,
    axes: Optional[List[Axes]] = None,
    **kwargs,
) -> Figure:
    """Plot retinotopic hexagonal lattice organization of all cell types.

    Args:
        cell_types: List of cell types to plot.
        max_extent: Maximum extent of the layout.
        edgecolor: Color of the hexagon edges.
        alpha: Transparency of the hexagons.
        fill: Whether to fill the hexagons.
        cmap: Custom colormap.
        fig: Existing figure to plot on.
        axes: List of existing axes to plot on.
        **kwargs: Additional arguments passed to hex_layout.

    Returns:
        Figure: Matplotlib figure object.
    """
    cell_types = self.cell_types_sorted if cell_types is None else cell_types
    if fig is None or axes is None:
        fig, axes, (gw, gh) = plt_utils.get_axis_grid(self.cell_types_sorted)

    for i, cell_type in enumerate(cell_types):
        self.hex_layout(
            cell_type,
            edgecolor=edgecolor,
            edgewidth=0.1,
            alpha=alpha,
            fill=fill,
            max_extent=max_extent,
            cmap=cmap or plt_utils.get_alpha_colormap("#2f3541", 1),
            fig=fig,
            ax=axes[i],
            **kwargs,
        )
    return fig

get_uv

get_uv(cell_type)

Get hex-coordinates of a particular cell type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cell_type | str | Type of cell to get coordinates for. | required |

Returns:

| Type | Description |
| --- | --- |
| Tuple[NDArray, NDArray] | Arrays of u and v coordinates. |

Source code in flyvision/connectome/connectome.py
def get_uv(self, cell_type: str) -> Tuple[NDArray, NDArray]:
    """Get hex-coordinates of a particular cell type.

    Args:
        cell_type: Type of cell to get coordinates for.

    Returns:
        Tuple[NDArray, NDArray]: Arrays of u and v coordinates.
    """
    nodes = self.nodes.to_df()
    nodes = nodes[nodes.type == cell_type]
    u, v = nodes[["u", "v"]].values.T
    return u, v
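
The listing above filters the node table by type and then transposes the (u, v) columns into two arrays. A pure-Python sketch of the same idea, assuming hypothetical node records of the form (cell_type, u, v) and a made-up helper name `uv_of`:

```python
from typing import List, Sequence, Tuple

# Hypothetical node record: (cell_type, u, v).
Node = Tuple[str, int, int]


def uv_of(nodes: Sequence[Node], cell_type: str) -> Tuple[List[int], List[int]]:
    """Split the coordinates of all nodes of one type into u and v lists."""
    pairs = [(u, v) for node_type, u, v in nodes if node_type == cell_type]
    if not pairs:
        return [], []
    # zip(*pairs) transposes [(u0, v0), (u1, v1), ...] into (u0, u1, ...), (v0, v1, ...)
    u, v = zip(*pairs)
    return list(u), list(v)


nodes = [("T4a", 0, 0), ("Mi1", 0, 0), ("T4a", 1, -1), ("T4a", -1, 0)]
u, v = uv_of(nodes, "T4a")
```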

sources_list

sources_list(cell_type)

Get presynaptic cell types.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cell_type | str | Type of cell to get sources for. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| NDArray | NDArray | Array of presynaptic cell types. |

Source code in flyvision/connectome/connectome.py
def sources_list(self, cell_type: str) -> NDArray:
    """Get presynaptic cell types.

    Args:
        cell_type: Type of cell to get sources for.

    Returns:
        NDArray: Array of presynaptic cell types.
    """
    edges = self.edges.to_df()
    return np.unique(edges[edges.target_type == cell_type].source_type.values)

targets_list

targets_list(cell_type)

Get postsynaptic cell types.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cell_type | str | Type of cell to get targets for. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| NDArray | NDArray | Array of postsynaptic cell types. |

Source code in flyvision/connectome/connectome.py
def targets_list(self, cell_type: str) -> NDArray:
    """Get postsynaptic cell types.

    Args:
        cell_type: Type of cell to get targets for.

    Returns:
        NDArray: Array of postsynaptic cell types.
    """
    edges = self.edges.to_df()
    return np.unique(edges[edges.source_type == cell_type].target_type.values)
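
Both sources_list and targets_list reduce the edge table to the unique types on the other side of a connection, and np.unique returns them sorted. A minimal stand-alone sketch of that lookup, assuming a hypothetical list of (source_type, target_type) pairs and made-up helper names:

```python
from typing import List, Sequence, Tuple

# Hypothetical edge record: (source_type, target_type).
Edge = Tuple[str, str]


def sources_of(edges: Sequence[Edge], cell_type: str) -> List[str]:
    """Sorted unique presynaptic types of `cell_type`."""
    return sorted({source for source, target in edges if target == cell_type})


def targets_of(edges: Sequence[Edge], cell_type: str) -> List[str]:
    """Sorted unique postsynaptic types of `cell_type`."""
    return sorted({target for source, target in edges if source == cell_type})


edges = [("Mi9", "T4a"), ("Mi1", "T4a"), ("Mi1", "T4b"), ("Mi9", "T4a")]
presynaptic = sources_of(edges, "T4a")
postsynaptic = targets_of(edges, "Mi1")
```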

receptive_field

receptive_field(
    source="Mi9",
    target="T4a",
    rfs=None,
    max_extent=None,
    vmin=None,
    vmax=None,
    title="{source} :→ {target}",
    **kwargs
)

Plot the receptive field of a target cell type from a source cell type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| source | str | Source cell type. | 'Mi9' |
| target | str | Target cell type. | 'T4a' |
| rfs | Optional[ReceptiveFields] | ReceptiveFields object. If None, it will be created. | None |
| max_extent | Optional[int] | Maximum extent of the receptive field. | None |
| vmin | Optional[float] | Minimum value for colormap. | None |
| vmax | Optional[float] | Maximum value for colormap. | None |
| title | str | Title format string for the plot. | '{source} :→ {target}' |
| **kwargs | | Additional arguments passed to plots.kernel. | {} |

Returns:

| Type | Description |
| --- | --- |
| Figure | Matplotlib Figure object. |

Source code in flyvision/connectome/connectome.py
def receptive_field(
    self,
    source: str = "Mi9",
    target: str = "T4a",
    rfs: Optional["ReceptiveFields"] = None,
    max_extent: Optional[int] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    title: str = "{source} :→ {target}",
    **kwargs,
) -> Figure:
    """Plot the receptive field of a target cell type from a source cell type.

    Args:
        source: Source cell type.
        target: Target cell type.
        rfs: ReceptiveFields object. If None, it will be created.
        max_extent: Maximum extent of the receptive field.
        vmin: Minimum value for colormap.
        vmax: Maximum value for colormap.
        title: Title format string for the plot.
        **kwargs: Additional arguments passed to plots.kernel.

    Returns:
        Matplotlib Figure object.
    """
    if rfs is None:
        rfs = ReceptiveFields(target, self.edges.to_df())
        max_extent = max_extent or rfs.max_extent

    weights = self._weights()

    # to derive color range values taking all inputs into account
    vmin = min(
        0,
        min(weights[rfs[source].index].min() for source in rfs.source_types),
    )
    vmax = max(
        0,
        max(weights[rfs[source].index].max() for source in rfs.source_types),
    )

    weights = weights[rfs[source].index]
    label = ""

    # mirror the coordinates so the field is viewed from the target cell
    du_inv, dv_inv = -rfs[source].du.values, -rfs[source].dv.values
    fig, ax, (label_text, scalarmapper) = plots.kernel(
        du_inv,
        dv_inv,
        weights,
        label=label,
        max_extent=max_extent,
        fill=True,
        vmin=vmin,
        vmax=vmax,
        title=title.format(**locals()),
        **kwargs,
    )
    return fig
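
Two details of the listing above are easy to miss: the color limits are recomputed from the weights of all source types (so they always include zero, and the `vmin`/`vmax` arguments are overwritten as the code is written), and the offsets are negated so the kernel is drawn from the target cell's point of view. The sketch below reproduces both steps on a made-up data structure; the `Field` layout and helper names are assumptions for illustration only.

```python
from typing import Dict, List, Tuple

# Hypothetical receptive field: source type -> list of (du, dv, weight).
Field = Dict[str, List[Tuple[int, int, float]]]


def shared_color_range(field: Field) -> Tuple[float, float]:
    """Color limits spanning all source types and always containing zero."""
    weights = [w for entries in field.values() for _, _, w in entries]
    return min(0.0, min(weights)), max(0.0, max(weights))


def mirrored_offsets(field: Field, source: str) -> List[Tuple[int, int]]:
    """Offsets as seen from the target cell: (du, dv) -> (-du, -dv)."""
    return [(-du, -dv) for du, dv, _ in field[source]]


field = {
    "Mi1": [(0, 0, 3.0), (1, 0, 1.0)],
    "Mi9": [(-1, 1, -2.0)],
}
vmin, vmax = shared_color_range(field)
offsets = mirrored_offsets(field, "Mi9")
```

Sharing one range across source types keeps the colors comparable between panels of a grid.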

receptive_fields_grid

receptive_fields_grid(
    target,
    sources=None,
    sort_alphabetically=True,
    ax_titles="{source} :→ {target}",
    figsize=[20, 20],
    max_extent=None,
    fig=None,
    axes=None,
    ignore_sign_error=False,
    max_figure_height_cm=22,
    panel_height_cm=3,
    max_figure_width_cm=18,
    panel_width_cm=3.6,
    **kwargs
)

Plot receptive fields of a target cell type in a grid layout.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| target | str | Target cell type. | required |
| sources | Optional[Iterable[str]] | Iterable of source cell types. If None, all sources are used. | None |
| sort_alphabetically | bool | Whether to sort source types alphabetically. | True |
| ax_titles | str | Title format string for each subplot. | '{source} :→ {target}' |
| figsize | List[int] | Figure size in inches. | [20, 20] |
| max_extent | Optional[int] | Maximum extent of the receptive fields. | None |
| fig | Optional[Figure] | Existing figure to plot on. | None |
| axes | Optional[List[Axes]] | List of existing axes to plot on. | None |
| ignore_sign_error | bool | Whether to ignore sign errors in plotting. | False |
| max_figure_height_cm | float | Maximum figure height in cm. | 22 |
| panel_height_cm | float | Height of each panel in cm. | 3 |
| max_figure_width_cm | float | Maximum figure width in cm. | 18 |
| panel_width_cm | float | Width of each panel in cm. | 3.6 |
| **kwargs | | Additional arguments passed to receptive_field. | {} |

Returns:

| Type | Description |
| --- | --- |
| Figure | Matplotlib Figure object. |

Source code in flyvision/connectome/connectome.py
def receptive_fields_grid(
    self,
    target: str,
    sources: Optional[Iterable[str]] = None,
    sort_alphabetically: bool = True,
    ax_titles: str = "{source} :→ {target}",
    figsize: List[int] = [20, 20],
    max_extent: Optional[int] = None,
    fig: Optional[Figure] = None,
    axes: Optional[List[Axes]] = None,
    ignore_sign_error: bool = False,
    max_figure_height_cm: float = 22,
    panel_height_cm: float = 3,
    max_figure_width_cm: float = 18,
    panel_width_cm: float = 3.6,
    **kwargs,
) -> Figure:
    """Plot receptive fields of a target cell type in a grid layout.

    Args:
        target: Target cell type.
        sources: Iterable of source cell types. If None, all sources are used.
        sort_alphabetically: Whether to sort source types alphabetically.
        ax_titles: Title format string for each subplot.
        figsize: Figure size in inches.
        max_extent: Maximum extent of the receptive fields.
        fig: Existing figure to plot on.
        axes: List of existing axes to plot on.
        ignore_sign_error: Whether to ignore sign errors in plotting.
        max_figure_height_cm: Maximum figure height in cm.
        panel_height_cm: Height of each panel in cm.
        max_figure_width_cm: Maximum figure width in cm.
        panel_width_cm: Width of each panel in cm.
        **kwargs: Additional arguments passed to receptive_field.

    Returns:
        Matplotlib Figure object.
    """

    rfs = ReceptiveFields(target, self.edges.to_df())
    max_extent = max_extent or rfs.max_extent
    weights = self._weights()

    # to sort in descending order by sum of inputs
    sorted_sum_of_inputs = dict(
        sorted(
            valmap(lambda v: weights[v.index].sum(), rfs).items(),
            key=lambda item: item[1],
            reverse=True,
        )
    )
    # to sort alphabetically in case sources is specified
    if sort_alphabetically:
        sources, _ = nodes_edges_utils.order_node_type_list(sources)
    sources = sources or list(sorted_sum_of_inputs.keys())

    # to derive color range values taking all inputs into account
    vmin = min(0, min(weights[rfs[source].index].min() for source in sources))
    vmax = max(0, max(weights[rfs[source].index].max() for source in sources))

    if fig is None or axes is None:
        figsize = figsize_from_n_items(
            len(rfs.source_types),
            max_figure_height_cm=max_figure_height_cm,
            panel_height_cm=panel_height_cm,
            max_figure_width_cm=max_figure_width_cm,
            panel_width_cm=panel_width_cm,
        )
        fig, axes = figsize.axis_grid(
            unmask_n=len(rfs.source_types), hspace=0.0, wspace=0
        )

    cbar = kwargs.get("cbar", False)
    for i, src in enumerate(sources):
        if i == 0 and cbar:
            cbar = True
            kwargs.update(cbar=cbar)
        else:
            cbar = False
            kwargs.update(cbar=cbar)
        try:
            self.receptive_field(
                target=target,
                source=src,
                fig=fig,
                ax=axes[i],
                title=ax_titles,
                vmin=vmin,
                vmax=vmax,
                rfs=rfs,
                max_extent=max_extent,
                annotate=False,
                annotate_coords=False,
                title_y=0.9,
                **kwargs,
            )
        except plots.SignError as e:
            if ignore_sign_error:
                pass
            else:
                raise e
    return fig
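
The panel bookkeeping in the listing above can be sketched in isolation. When no sources are passed, panels follow the descending sum of input weights, and only the first panel may carry a colorbar. The helper names below are hypothetical, the per-type sums are made-up numbers, and the optional alphabetical ordering is left out of this sketch.

```python
from typing import Dict, List, Optional, Sequence


def panel_order(
    sums: Dict[str, float], requested: Optional[Sequence[str]] = None
) -> List[str]:
    """Use the requested order if given, else sort by descending summed weight."""
    if requested:
        return list(requested)
    ranked = sorted(sums.items(), key=lambda item: item[1], reverse=True)
    return [name for name, _ in ranked]


def colorbar_flags(n_panels: int, cbar: bool) -> List[bool]:
    """Only the first panel may carry a colorbar, and only if cbar was requested."""
    return [cbar and i == 0 for i in range(n_panels)]


sums = {"Mi9": -6.0, "Mi1": 14.0, "Tm3": 2.0}
order = panel_order(sums)
flags = colorbar_flags(len(order), cbar=True)
```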

projective_field

projective_field(
    source="Mi9",
    target="T4a",
    title="{source} →: {target}",
    prfs=None,
    max_extent=None,
    vmin=None,
    vmax=None,
    **kwargs
)

Plot the projective field from a source cell type to a target cell type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| source | str | Source cell type. | 'Mi9' |
| target | str | Target cell type. | 'T4a' |
| title | str | Title format string for the plot. | '{source} →: {target}' |
| prfs | Optional[ProjectiveFields] | ProjectiveFields object. If None, it will be created. | None |
| max_extent | Optional[int] | Maximum extent of the projective field. | None |
| vmin | Optional[float] | Minimum value for colormap. | None |
| vmax | Optional[float] | Maximum value for colormap. | None |
| **kwargs | | Additional arguments passed to plots.kernel. | {} |

Returns:

| Type | Description |
| --- | --- |
| Optional[Figure] | Matplotlib Figure object or None if max_extent is None. |

Source code in flyvision/connectome/connectome.py
def projective_field(
    self,
    source: str = "Mi9",
    target: str = "T4a",
    title: str = "{source} →: {target}",
    prfs: Optional["ProjectiveFields"] = None,
    max_extent: Optional[int] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    **kwargs,
) -> Optional[Figure]:
    """Plot the projective field from a source cell type to a target cell type.

    Args:
        source: Source cell type.
        target: Target cell type.
        title: Title format string for the plot.
        prfs: ProjectiveFields object. If None, it will be created.
        max_extent: Maximum extent of the projective field.
        vmin: Minimum value for colormap.
        vmax: Maximum value for colormap.
        **kwargs: Additional arguments passed to plots.kernel.

    Returns:
        Matplotlib Figure object or None if max_extent is None.
    """
    if prfs is None:
        prfs = ProjectiveFields(source, self.edges.to_df())
        max_extent = max_extent or prfs.max_extent
    if max_extent is None:
        return None
    weights = self._weights()

    # derive the color range from the weights onto all target types
    vmin = min(
        0,
        min(weights[prfs[target].index].min() for target in prfs.target_types),
    )

    vmax = max(
        0,
        max(weights[prfs[target].index].max() for target in prfs.target_types),
    )

    weights = weights[prfs[target].index]
    label = ""
    du, dv = prfs[target].du.values, prfs[target].dv.values
    fig, ax, (label_text, scalarmapper) = plots.kernel(
        du,
        dv,
        weights,
        label=label,
        fill=True,
        max_extent=max_extent,
        vmin=vmin,
        vmax=vmax,
        title=title.format(**locals()),
        **kwargs,
    )
    return fig
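
The `title` argument is a format string that the listing above fills in with `title.format(**locals())`, so `{source}` and `{target}` are replaced by the cell type names. A minimal sketch of that substitution, using a hypothetical helper `panel_title` that passes only those two names explicitly:

```python
def panel_title(template: str, source: str, target: str) -> str:
    """Fill a title template with the source and target cell type names."""
    return template.format(source=source, target=target)


# Default templates of receptive_field and projective_field, respectively.
rf_title = panel_title("{source} :→ {target}", "Mi9", "T4a")
prf_title = panel_title("{source} →: {target}", "Mi9", "T4a")
print(rf_title)
print(prf_title)
```

A custom template such as `"{target} inputs from {source}"` works the same way, as long as it only uses placeholders that are defined at the point of formatting.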

projective_fields_grid

projective_fields_grid(
    source,
    targets=None,
    fig=None,
    axes=None,
    figsize=[20, 20],
    ax_titles="{source} →: {target}",
    max_figure_height_cm=22,
    panel_height_cm=3,
    max_figure_width_cm=18,
    panel_width_cm=3.6,
    max_extent=None,
    sort_alphabetically=False,
    ignore_sign_error=False,
    **kwargs
)

Plot projective fields of a source cell type in a grid layout.

Parameters:

Name Type Description Default
source str

Source cell type.

required
targets Optional[Iterable[str]]

Iterable of target cell types. If None, all targets are used.

None
fig Optional[Figure]

Existing figure to plot on.

None
axes Optional[List[Axes]]

List of existing axes to plot on.

None
figsize List[int]

Figure size in inches.

[20, 20]
ax_titles str

Title format string for each subplot.

'{source} →: {target}'
max_figure_height_cm float

Maximum figure height in cm.

22
panel_height_cm float

Height of each panel in cm.

3
max_figure_width_cm float

Maximum figure width in cm.

18
panel_width_cm float

Width of each panel in cm.

3.6
max_extent Optional[int]

Maximum extent of the projective fields.

None
sort_alphabetically bool

Whether to sort target types alphabetically.

False
ignore_sign_error bool

Whether to ignore sign errors in plotting.

False
**kwargs

Additional arguments passed to projective_field.

{}

Returns:

Type Description
Figure

Matplotlib Figure object.

Source code in flyvision/connectome/connectome.py
def projective_fields_grid(
    self,
    source: str,
    targets: Optional[Iterable[str]] = None,
    fig: Optional[Figure] = None,
    axes: Optional[List[Axes]] = None,
    figsize: List[int] = [20, 20],
    ax_titles: str = "{source} →: {target}",
    max_figure_height_cm: float = 22,
    panel_height_cm: float = 3,
    max_figure_width_cm: float = 18,
    panel_width_cm: float = 3.6,
    max_extent: Optional[int] = None,
    sort_alphabetically: bool = False,
    ignore_sign_error: bool = False,
    **kwargs,
) -> Figure:
    """Plot projective fields of a source cell type in a grid layout.

    Args:
        source: Source cell type.
        targets: Iterable of target cell types. If None, all targets are used.
        fig: Existing figure to plot on.
        axes: List of existing axes to plot on.
        figsize: Figure size in inches.
        ax_titles: Title format string for each subplot.
        max_figure_height_cm: Maximum figure height in cm.
        panel_height_cm: Height of each panel in cm.
        max_figure_width_cm: Maximum figure width in cm.
        panel_width_cm: Width of each panel in cm.
        max_extent: Maximum extent of the projective fields.
        sort_alphabetically: Whether to sort target types alphabetically.
        ignore_sign_error: Whether to ignore sign errors in plotting.
        **kwargs: Additional arguments passed to projective_field.

    Returns:
        Matplotlib Figure object.
    """
    prfs = ProjectiveFields(source, self.edges.to_df())
    max_extent = max_extent or prfs.max_extent
    weights = self._weights()
    sorted_sum_of_outputs = dict(
        sorted(
            valmap(lambda v: weights[v.index].sum(), prfs).items(),
            key=lambda item: item[1],
            reverse=True,
        )
    )

    # to sort alphabetically in case targets is specified
    if sort_alphabetically:
        targets, _ = nodes_edges_utils.order_node_type_list(targets)

    targets = targets or list(sorted_sum_of_outputs.keys())

    vmin = min(0, min(weights[prfs[target].index].min() for target in targets))
    vmax = max(0, max(weights[prfs[target].index].max() for target in targets))

    if fig is None or axes is None:
        figsize = figsize_from_n_items(
            len(prfs.target_types),
            max_figure_height_cm=max_figure_height_cm,
            panel_height_cm=panel_height_cm,
            max_figure_width_cm=max_figure_width_cm,
            panel_width_cm=panel_width_cm,
        )
        fig, axes = figsize.axis_grid(
            unmask_n=len(prfs.target_types), hspace=0.0, wspace=0
        )

    cbar = kwargs.get("cbar", False)
    for i, target in enumerate(targets):
        if i == 0 and cbar:
            cbar = True
            kwargs.update(cbar=cbar)
        else:
            cbar = False
            kwargs.update(cbar=cbar)
        try:
            self.projective_field(
                source=source,
                target=target,
                fig=fig,
                ax=axes[i],
                title=ax_titles,
                prfs=prfs,
                max_extent=max_extent,
                vmin=vmin,
                vmax=vmax,
                annotate_coords=False,
                annotate=False,
                title_y=0.9,
                **kwargs,
            )
        except plots.SignError:
            # skip panels with inconsistent signs only if explicitly requested
            if not ignore_sign_error:
                raise
    return fig
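
The grid shares one color range across all panels and, when no `targets` are given, orders the panels by summed weight in descending order. A minimal sketch of both steps in plain Python, assuming a hypothetical mapping `weights_by_target` from target type to a list of weights (the method itself indexes a weight array with each field's dataframe index):

```python
from typing import Dict, List, Tuple


def shared_color_range(weights_by_target: Dict[str, List[float]]) -> Tuple[float, float]:
    """Return (vmin, vmax) spanning all targets and always including zero."""
    vmin = min(0, min(min(w) for w in weights_by_target.values()))
    vmax = max(0, max(max(w) for w in weights_by_target.values()))
    return vmin, vmax


def order_by_total(weights_by_target: Dict[str, List[float]]) -> List[str]:
    """Order target types by their summed weight, largest first."""
    return sorted(
        weights_by_target,
        key=lambda target: sum(weights_by_target[target]),
        reverse=True,
    )


# Hypothetical weights, chosen only for illustration.
weights_by_target = {"T4a": [0.5, 1.5], "T4b": [-2.0, -0.5], "Tm3": [0.25]}
vmin, vmax = shared_color_range(weights_by_target)
targets = order_by_total(weights_by_target)
print(vmin, vmax, targets)
```

Clamping the range so that it always contains zero keeps the colormap comparable between panels even when every weight has the same sign.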

receptive_fields_df

receptive_fields_df(target_type)

Get receptive fields for a target cell type.

Parameters:

Name Type Description Default
target_type str

Target cell type.

required

Returns:

Type Description
ReceptiveFields

ReceptiveFields object.

Source code in flyvision/connectome/connectome.py
def receptive_fields_df(self, target_type: str) -> "ReceptiveFields":
    """Get receptive fields for a target cell type.

    Args:
        target_type: Target cell type.

    Returns:
        ReceptiveFields object.
    """
    return ReceptiveFields(target_type, self.edges.to_df())

projective_fields_df

projective_fields_df(source_type)

Get projective fields for a source cell type.

Parameters:

Name Type Description Default
source_type str

Source cell type.

required

Returns:

Type Description
ProjectiveFields

ProjectiveFields object.

Source code in flyvision/connectome/connectome.py
def projective_fields_df(self, source_type: str) -> "ProjectiveFields":
    """Get projective fields for a source cell type.

    Args:
        source_type: Source cell type.

    Returns:
        ProjectiveFields object.
    """
    return ProjectiveFields(source_type, self.edges.to_df())

receptive_fields_sum

receptive_fields_sum(target_type)

Get sum of synapses for each source type in the receptive field.

Parameters:

Name Type Description Default
target_type str

Target cell type.

required

Returns:

Type Description
Dict[str, int]

Dictionary mapping source types to synapse counts.

Source code in flyvision/connectome/connectome.py
def receptive_fields_sum(self, target_type: str) -> Dict[str, int]:
    """Get sum of synapses for each source type in the receptive field.

    Args:
        target_type: Target cell type.

    Returns:
        Dictionary mapping source types to synapse counts.
    """
    return ReceptiveFields(target_type, self.edges.to_df()).sum()
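
Conceptually, this sum groups the edges that end in `target_type` by their source type and adds up the `n_syn` values of each group. A minimal sketch in plain Python, assuming edges are given as hypothetical dictionaries with `source_type`, `target_type`, and `n_syn` keys (the library operates on the edges dataframe instead):

```python
from collections import defaultdict
from typing import Dict, List


def receptive_field_sums(edges: List[dict], target_type: str) -> Dict[str, int]:
    """Sum synapse counts per source type over edges ending in target_type."""
    totals: Dict[str, int] = defaultdict(int)
    for edge in edges:
        if edge["target_type"] == target_type:
            totals[edge["source_type"]] += edge["n_syn"]
    return dict(totals)


# Hypothetical edge rows, chosen only for illustration.
edges = [
    {"source_type": "Mi1", "target_type": "T4a", "n_syn": 10},
    {"source_type": "Mi1", "target_type": "T4a", "n_syn": 5},
    {"source_type": "Tm3", "target_type": "T4a", "n_syn": 4},
    {"source_type": "Tm1", "target_type": "T5a", "n_syn": 7},
]
t4a_sums = receptive_field_sums(edges, "T4a")
print(t4a_sums)
```

The projective variant is the mirror image: filter on the source type and group by target type.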

projective_fields_sum

projective_fields_sum(source_type)

Get sum of synapses for each target type in the projective field.

Parameters:

Name Type Description Default
source_type str

Source cell type.

required

Returns:

Type Description
Dict[str, int]

Dictionary mapping target types to synapse counts.

Source code in flyvision/connectome/connectome.py
def projective_fields_sum(self, source_type: str) -> Dict[str, int]:
    """Get sum of synapses for each target type in the projective field.

    Args:
        source_type: Source cell type.

    Returns:
        Dictionary mapping target types to synapse counts.
    """
    return ProjectiveFields(source_type, self.edges.to_df()).sum()

flyvision.connectome.connectome.ReceptiveFields

Bases: Namespace

Dictionary of receptive field dataframes for a specific cell type.

Parameters:

Name Type Description Default
target_type str

Target cell type.

required
edges DataFrame

All edges of a Connectome.

required

Attributes:

Name Type Description
target_type

The target cell type.

source_types

List of source cell types.

_extents

List of extents for each source type.

Example

rf = ReceptiveFields("T4a", edges_dataframe)

Source code in flyvision/connectome/connectome.py
class ReceptiveFields(Namespace):
    """Dictionary of receptive field dataframes for a specific cell type.

    Args:
        target_type: Target cell type.
        edges: All edges of a Connectome.

    Attributes:
        target_type: The target cell type.
        source_types: List of source cell types.
        _extents: List of extents for each source type.

    Example:
        ```python
        rf = ReceptiveFields("T4a", edges_dataframe)
        ```
    """

    def __init__(self, target_type: str, edges: DataFrame, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, "_extents", [])
        _receptive_fields_edge_dfs(self, target_type, edges)

    @property
    def extents(self) -> Dict[str, int]:
        """Dictionary of extents for each source type."""
        return dict(zip(self.source_types, self._extents))

    @property
    def max_extent(self) -> Optional[int]:
        """Maximum extent across all source types."""
        return max(self._extents) if self._extents else None

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.target_type})"

    def sum(self) -> Dict[str, float]:
        """Sum of synapses for each source type."""
        return {key: self[key].n_syn.sum() for key in self}

extents property

extents

Dictionary of extents for each source type.

max_extent property

max_extent

Maximum extent across all source types.
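
The two properties are thin views over a list of per-type extents: `extents` zips the type names with that list, and `max_extent` falls back to `None` when the list is empty. A minimal sketch, assuming a hypothetical stand-in class `FieldExtents` that only mirrors this bookkeeping and is not part of flyvision:

```python
from typing import Dict, List, Optional


class FieldExtents:
    """Hypothetical stand-in for the extents bookkeeping of a field namespace."""

    def __init__(self, types: List[str], extents: List[int]):
        self.types = types
        self._extents = extents

    @property
    def extents(self) -> Dict[str, int]:
        """Map each cell type to its extent."""
        return dict(zip(self.types, self._extents))

    @property
    def max_extent(self) -> Optional[int]:
        """Largest extent, or None when there are no connections at all."""
        return max(self._extents) if self._extents else None


# Illustrative values only.
populated = FieldExtents(["Mi1", "Tm3", "Mi9"], [0, 1, 2])
empty = FieldExtents([], [])
print(populated.extents, populated.max_extent, empty.max_extent)
```

The `None` fallback matters for plotting: `projective_field` returns `None` instead of a figure when no extent can be determined.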

sum

sum()

Sum of synapses for each source type.

Source code in flyvision/connectome/connectome.py
def sum(self) -> Dict[str, float]:
    """Sum of synapses for each source type."""
    return {key: self[key].n_syn.sum() for key in self}

flyvision.connectome.connectome.ProjectiveFields

Bases: Namespace

Dictionary of projective field dataframes for a specific cell type.

Parameters:

Name Type Description Default
source_type str

Source cell type.

required
edges DataFrame

All edges of a Connectome.

required

Attributes:

Name Type Description
source_type

The source cell type.

target_types

List of target cell types.

_extents

List of extents for each target type.

Example

pf = ProjectiveFields("Mi9", edges_dataframe)

Source code in flyvision/connectome/connectome.py
class ProjectiveFields(Namespace):
    """Dictionary of projective field dataframes for a specific cell type.

    Args:
        source_type: Source cell type.
        edges: All edges of a Connectome.

    Attributes:
        source_type: The source cell type.
        target_types: List of target cell types.
        _extents: List of extents for each target type.

    Example:
        ```python
        pf = ProjectiveFields("Mi9", edges_dataframe)
        ```
    """

    def __init__(self, source_type: str, edges: DataFrame, *args, **kwargs):
        super().__init__(*args, **kwargs)
        object.__setattr__(self, "_extents", [])
        _projective_fields_edge_dfs(self, source_type, edges)

    @property
    def extents(self) -> Dict[str, int]:
        """Dictionary of extents for each target type."""
        return dict(zip(self.target_types, self._extents))

    @property
    def max_extent(self) -> Optional[int]:
        """Maximum extent across all target types."""
        return max(self._extents) if self._extents else None

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.source_type})"

    def sum(self) -> Dict[str, float]:
        """Sum of synapses for each target type."""
        return {key: self[key].n_syn.sum() for key in self}

extents property

extents

Dictionary of extents for each target type.

max_extent property

max_extent

Maximum extent across all target types.

sum

sum()

Sum of synapses for each target type.

Source code in flyvision/connectome/connectome.py
def sum(self) -> Dict[str, float]:
    """Sum of synapses for each target type."""
    return {key: self[key].n_syn.sum() for key in self}

Miscellaneous

flyvision.connectome.connectome.init_connectome

init_connectome(**kwargs)

Initialize a Connectome instance from a config dictionary.

Parameters:

Name Type Description Default
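**kwargs

Connectome configuration passed as keyword arguments. The type entry selects the connectome class; all remaining entries are forwarded to its constructor.

required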

Returns:

Type Description
Connectome

An instance of a class implementing the Connectome(Protocol).

Raises:

Type Description
KeyError

If the specified connectome type is not available.

Example

config = {
    "type": "ConnectomeFromAvgFilters",
    **config
}
connectome = init_connectome(**config)

Source code in flyvision/connectome/connectome.py
def init_connectome(**kwargs) -> Connectome:
    """Initialize a Connectome instance from a config dictionary.

    Args:
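        **kwargs: Connectome configuration passed as keyword arguments. The
            `type` entry selects the connectome class; all remaining entries
            are forwarded to its constructor.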

    Returns:
        An instance of a class implementing the Connectome(Protocol).

    Raises:
        KeyError: If the specified connectome type is not available.

    Example:
        ```python
        config = {
            "type": "ConnectomeFromAvgFilters",
            **config
        }
        connectome = init_connectome(**config)
        ```
    """
    connectome_class = AVAILABLE_CONNECTOMES[kwargs.pop("type")]

    connectome = connectome_class(**kwargs)
    is_valid, error_msg = is_connectome_protocol(connectome)
    assert is_valid, (
        f"Connectome class {connectome} does "
        f"not implement the Connectome(Protocol): {error_msg}"
    )
    return connectome
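
The function follows a registry dispatch pattern: the `type` entry is popped from the keyword arguments and looked up in a mapping of available classes, and the remaining entries go to the constructor. A minimal sketch of that pattern, assuming a hypothetical registry `AVAILABLE` and class `DummyConnectome` that stand in for the library's own and omitting the protocol check:

```python
from typing import Any, Dict


class DummyConnectome:
    """Hypothetical connectome class used only to illustrate the dispatch."""

    def __init__(self, extent: int = 15):
        self.extent = extent


# Hypothetical registry mapping type names to classes.
AVAILABLE: Dict[str, Any] = {"DummyConnectome": DummyConnectome}


def init_from_config(**kwargs: Any) -> Any:
    """Look up the class named by `type` and build it from the other entries."""
    cls = AVAILABLE[kwargs.pop("type")]  # raises KeyError for unknown types
    return cls(**kwargs)


config = {"type": "DummyConnectome", "extent": 3}
connectome = init_from_config(**config)
print(type(connectome).__name__, connectome.extent)
```

Unpacking with `**config` hands the function its own copy of the mapping, so popping `type` inside the function leaves the caller's dictionary unchanged.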

flyvision.connectome.connectome.get_avgfilt_connectome

get_avgfilt_connectome(config)

Create a ConnectomeView instance from a config for ConnectomeFromAvgFilters.

Parameters:

Name Type Description Default
config dict

Containing ConnectomeFromAvgFilters configuration.

required

Returns:

Type Description
ConnectomeView

ConnectomeView instance.

Source code in flyvision/connectome/connectome.py
def get_avgfilt_connectome(config: dict) -> ConnectomeView:
    """Create a ConnectomeView instance from a config for ConnectomeFromAvgFilters.

    Args:
        config: Containing ConnectomeFromAvgFilters configuration.

    Returns:
        ConnectomeView instance.
    """
    return ConnectomeView(ConnectomeFromAvgFilters(**config))