Skip to content

Utils

flyvis.utils.activity_utils

Classes

flyvis.utils.activity_utils.CellTypeActivity

Bases: dict

Base class for attribute-style access to network activity based on cell types.

Parameters:

Name Type Description Default
keepref bool

Whether to keep a reference to the activity. This may not be desired during training to avoid memory issues.

False

Attributes:

Name Type Description
activity Union[ref, NDArray, Tensor]

Weak reference to the activity, or the activity itself when keepref=True.

keepref

Whether to keep a reference to the activity.

unique_cell_types List[str]

List of unique cell types.

input_indices NDArray

Indices of input cells.

output_indices NDArray

Indices of output cells.

Note

Activity is stored as a weakref by default for memory efficiency during training. Set keepref=True to keep a reference for analysis.

Source code in flyvis/utils/activity_utils.py
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
class CellTypeActivity(dict):
    """Base class for attribute-style access to network activity based on cell types.

    Args:
        keepref: Whether to keep a reference to the activity. This may not be desired
            during training to avoid memory issues.

    Attributes:
        activity: Weak reference to the activity, or the activity itself when
            `keepref=True`.
        keepref: Whether to keep a reference to the activity.
        unique_cell_types: List of unique cell types.
        input_indices: Indices of input cells.
        output_indices: Indices of output cells.

    Note:
        Activity is stored as a weakref by default for memory efficiency
        during training. Set keepref=True to keep a reference for analysis.
    """

    def __init__(self, keepref: bool = False):
        self.keepref = keepref
        self.activity: Union[weakref.ref, NDArray, torch.Tensor] = None
        self.unique_cell_types: List[str] = []
        self.input_indices: NDArray = np.array([])
        self.output_indices: NDArray = np.array([])

    def __dir__(self) -> List[str]:
        return list(set([*dict.__dir__(self), *dict.__iter__(self)]))

    def __len__(self) -> int:
        return len(self.unique_cell_types)

    def __iter__(self):
        yield from self.unique_cell_types

    def __repr__(self) -> str:
        return "Activity of: \n{}".format("\n".join(wrap(", ".join(list(self)))))

    def update(self, activity: Union[NDArray, torch.Tensor]) -> None:
        """Update the activity reference."""
        self.activity = activity

    def _slices(self, n: int) -> tuple:
        """Return `n` full slices so that an index can target the last axis."""
        return tuple(slice(None) for _ in range(n))

    def _resolve_activity(self):
        """Return the stored activity, dereferencing the weakref if needed.

        State is read through `__dict__` so that a partially initialized
        instance (e.g. one created via `__new__` by `copy`/`pickle`) cannot
        recurse back into `__getattr__`.

        Returns:
            The activity, or None if none is set or the weak referent is gone.
        """
        state = self.__dict__
        stored = state.get("activity")
        if stored is None or state.get("keepref", False):
            return stored
        return stored()

    def __getattr__(self, key):
        # Missing dunder lookups (copy/pickle/numpy protocol probes, hasattr)
        # must raise AttributeError rather than be treated as cell types.
        if isinstance(key, str) and key.startswith("__") and key.endswith("__"):
            raise AttributeError(key)
        activity = self._resolve_activity()
        if activity is None:
            return None
        if isinstance(key, list):
            index = np.stack([dict.__getitem__(self, name) for name in key])
            slices = self._slices(len(activity.shape) - 1)
            return activity[slices + (index,)]
        elif key == slice(None):
            return activity
        elif key in self.unique_cell_types:
            slices = self._slices(len(activity.shape) - 1)
            return activity[slices + (dict.__getitem__(self, key),)]
        elif key == "output":
            slices = self._slices(len(activity.shape) - 1)
            return activity[slices + (self.output_indices,)]
        elif key == "input":
            slices = self._slices(len(activity.shape) - 1)
            return activity[slices + (self.input_indices,)]
        elif key in self.__dict__:
            return self.__dict__[key]
        else:
            raise ValueError(f"{key}")

    def __getitem__(self, key):
        return self.__getattr__(key)

    def __setattr__(self, key, value):
        if key == "activity" and value is not None:
            # Without keepref only a weak reference is held, so the caller
            # owns the activity's lifetime.
            if self.keepref is False:
                value = weakref.ref(value)
            object.__setattr__(self, key, value)
        else:
            object.__setattr__(self, key, value)
update
update(activity)

Update the activity reference.

Source code in flyvis/utils/activity_utils.py
71
72
73
def update(self, activity: Union[NDArray, torch.Tensor]) -> None:
    """Replace the activity held by ``self``.

    The assignment goes through ``self``'s ``__setattr__``, so any custom
    storage logic defined by the owning class is applied.
    """
    setattr(self, "activity", activity)

flyvis.utils.activity_utils.CentralActivity

Bases: CellTypeActivity

Attribute-style access to central cell activity of a cell type.

Parameters:

Name Type Description Default
activity Union[NDArray, Tensor]

Activity of shape (…, n_cells).

required
connectome ConnectomeFromAvgFilters

Connectome directory with reference to required attributes.

required
keepref bool

Whether to keep a reference to the activity.

False

Attributes:

Name Type Description
activity

Activity of shape (…, n_cells).

unique_cell_types

Array of unique cell types.

index

NodeIndexer instance.

input_indices

Array of input indices.

output_indices

Array of output indices.

Source code in flyvis/utils/activity_utils.py
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
class CentralActivity(CellTypeActivity):
    """Attribute-style access to central cell activity of a cell type.

    Args:
        activity: Activity of shape (..., n_cells).
        connectome: Connectome directory with reference to required attributes.
        keepref: Whether to keep a reference to the activity.

    Attributes:
        activity: Activity of shape (..., n_cells).
        unique_cell_types: Array of unique cell types.
        index: NodeIndexer instance.
        input_indices: Array of input indices.
        output_indices: Array of output indices.

    Note:
        If an assigned activity's last axis does not match the number of cell
        types, its central cells are selected with `index.central_cells_index`.
        The selected result is then held strongly (`keepref` is switched to
        True), because nothing else references it.
    """

    def __init__(
        self,
        activity: Union[NDArray, torch.Tensor],
        connectome: ConnectomeFromAvgFilters,
        keepref: bool = False,
    ):
        super().__init__(keepref)
        self.index = nodes_edges_utils.NodeIndexer(connectome)

        unique_cell_types = connectome.unique_cell_types[:]
        input_cell_types = connectome.input_cell_types[:]
        output_cell_types = connectome.output_cell_types[:]
        self.input_indices = np.array([
            np.nonzero(unique_cell_types == t)[0] for t in input_cell_types
        ])
        self.output_indices = np.array([
            np.nonzero(unique_cell_types == t)[0] for t in output_cell_types
        ])
        self.activity = activity
        self.unique_cell_types = unique_cell_types.astype(str)

    def __getattr__(self, key):
        # Missing dunder lookups (copy/pickle/numpy protocol probes, hasattr)
        # must raise AttributeError rather than be treated as cell types.
        if isinstance(key, str) and key.startswith("__") and key.endswith("__"):
            raise AttributeError(key)
        # Read state via __dict__ so a partially initialized instance
        # (e.g. created through __new__ by copy/pickle) cannot recurse here,
        # and so an unset activity yields None instead of calling None().
        state = self.__dict__
        stored = state.get("activity")
        if stored is None or state.get("keepref", False):
            activity = stored
        else:
            activity = stored()
        if activity is None:
            return None
        if isinstance(key, list):
            index = np.stack([self.index[name] for name in key])
            slices = self._slices(len(activity.shape) - 1)
            return activity[slices + (index,)]
        elif key == slice(None):
            return activity
        elif key in self.index.unique_cell_types:
            slices = self._slices(len(activity.shape) - 1)
            return activity[slices + (self.index[key],)]
        elif key == "output":
            slices = self._slices(len(activity.shape) - 1)
            return activity[slices + (self.output_indices,)]
        elif key == "input":
            slices = self._slices(len(activity.shape) - 1)
            return activity[slices + (self.input_indices,)]
        elif key in self.__dict__:
            return self.__dict__[key]
        else:
            raise ValueError(f"{key}")

    def __setattr__(self, key, value):
        if key == "activity" and value is not None:
            if len(self.index.unique_cell_types) != value.shape[-1]:
                # Activity covers more than one cell per type: keep only the
                # central cells. The result is a new object with no other
                # referents, so it must be held strongly.
                slices = self._slices(len(value.shape) - 1)
                slices += (self.index.central_cells_index,)
                value = value[slices]
                self.keepref = True
            if self.keepref is False:
                value = weakref.ref(value)
            object.__setattr__(self, key, value)
        else:
            object.__setattr__(self, key, value)

    def __len__(self):
        return len(self.unique_cell_types)

    def __iter__(self):
        for cell_type in self.unique_cell_types:
            yield cell_type

flyvis.utils.activity_utils.LayerActivity

Bases: CellTypeActivity

Attribute-style access to hex-lattice activity (cell-type specific).

Parameters:

Name Type Description Default
activity Union[NDArray, Tensor]

Activity of shape (…, n_cells).

required
connectome ConnectomeFromAvgFilters

Connectome directory with reference to required attributes.

required
keepref bool

Whether to keep a reference to the activity.

False
use_central bool

Whether to use central activity.

True

Attributes:

Name Type Description
central

CentralActivity instance for central nodes.

activity

Activity of shape (…, n_cells).

connectome

Connectome directory.

unique_cell_types

Array of unique cell types.

input_indices

Array of input indices.

output_indices

Array of output indices.

input_cell_types

Array of input cell types.

output_cell_types

Array of output cell types.

n_nodes

Number of nodes.

Note

The name LayerActivity might change in the future as it is misleading. This is not a feedforward layer in the machine learning sense but the activity of all cells of a certain cell-type.

Example:

Central activity can be accessed by:
```python
a = LayerActivity(activity, network.connectome)
central_T4a = a.central.T4a
```

Also allows 'virtual types' that are the sum of individual cell types:
```python
a = LayerActivity(activity, network.connectome)
summed_a = a['L2+L4']
```
Source code in flyvis/utils/activity_utils.py
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
class LayerActivity(CellTypeActivity):
    """Attribute-style access to hex-lattice activity (cell-type specific).

    Args:
        activity: Activity of shape (..., n_cells).
        connectome: Connectome directory with reference to required attributes.
        keepref: Whether to keep a reference to the activity.
        use_central: Whether to use central activity.

    Attributes:
        central: CentralActivity instance for central nodes (only set when
            `use_central=True`).
        activity: Activity of shape (..., n_cells), stored as a weak reference
            unless `keepref=True`.
        connectome: Connectome directory.
        unique_cell_types: Array of unique cell types.
        input_indices: Array of input indices.
        output_indices: Array of output indices.
        input_cell_types: Array of input cell types.
        output_cell_types: Array of output cell types.
        n_nodes: Number of nodes.

    Note:
        The name `LayerActivity` might change in the future as it is misleading.
        This is not a feedforward layer in the machine learning sense but the
        activity of all cells of a certain cell-type.

    Example:

        Central activity can be accessed by:
        ```python
        a = LayerActivity(activity, network.connectome)
        central_T4a = a.central.T4a
        ```

        Also allows 'virtual types' that are the sum of individuals:
        ```python
        a = LayerActivity(activity, network.connectome)
        summed_a = a['L2+L4']
        ```
        NOTE(review): the `__getattr__` inherited from `CellTypeActivity` in
        this module has no branch for '+'-joined keys; confirm that 'virtual
        types' are supported before relying on this example.
    """

    def __init__(
        self,
        activity: Union[NDArray, torch.Tensor],
        connectome: ConnectomeFromAvgFilters,
        keepref: bool = False,
        use_central: bool = True,
    ):
        # The base initializer already stores `keepref`.
        super().__init__(keepref)

        self.use_central = use_central
        if use_central:
            self.central = CentralActivity(activity, connectome, keepref)

        self.activity = activity
        self.connectome = connectome
        self.unique_cell_types = connectome.unique_cell_types[:].astype("str")
        # Per-type indices are stored as dict items for the base __getattr__.
        for cell_type in self.unique_cell_types:
            index = connectome.nodes.layer_index[cell_type][:]
            self[cell_type] = index

        _cell_types = self.connectome.nodes.type[:]
        self.input_indices = np.array([
            np.nonzero(_cell_types == t)[0] for t in self.connectome.input_cell_types
        ])
        self.output_indices = np.array([
            np.nonzero(_cell_types == t)[0] for t in self.connectome.output_cell_types
        ])
        self.input_cell_types = self.connectome.input_cell_types[:].astype(str)
        self.output_cell_types = self.connectome.output_cell_types[:].astype(str)
        self.n_nodes = len(self.connectome.nodes.type)

    def __setattr__(self, key, value):
        if key == "activity" and value is not None:
            if self.use_central:
                # Forward the raw activity before any weakref wrapping:
                # CentralActivity reads `value.shape` to select central cells
                # and manages its own reference. A weakref.ref has no `.shape`,
                # so forwarding the wrapped value failed with keepref=False.
                self.central.__setattr__(key, value)

            if self.keepref is False:
                value = weakref.ref(value)

            object.__setattr__(self, key, value)
        else:
            object.__setattr__(self, key, value)

flyvis.utils.activity_utils.SourceCurrentView

Create views of source currents for a target type.

Parameters:

Name Type Description Default
rfs ReceptiveFields

ReceptiveFields instance.

required
currents Union[NDArray, Tensor]

Current values.

required

Attributes:

Name Type Description
target_type

Target cell type.

source_types

List of source cell types.

rfs

ReceptiveFields instance.

currents

Current values.

Source code in flyvis/utils/activity_utils.py
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
class SourceCurrentView:
    """Create views of source currents for a target type.

    Accessing a source type as an attribute or item (``view.T4a`` or
    ``view["T4a"]``) returns the currents along the last axis at the
    positions given by ``rfs[source_type].index``.

    Args:
        rfs: ReceptiveFields instance.
        currents: Current values.

    Attributes:
        target_type: Target cell type.
        source_types: List of source cell types.
        rfs: ReceptiveFields instance.
        currents: Current values.
    """

    def __init__(self, rfs: ReceptiveFields, currents: Union[NDArray, torch.Tensor]):
        self.target_type = rfs.target_type
        self.source_types = list(rfs)
        self.rfs = rfs
        self.currents = currents

    def __getattr__(self, key: str) -> Union[NDArray, torch.Tensor]:
        # Only reached when normal attribute lookup fails. State is read via
        # __dict__ so a partially initialized instance (copy/pickle) cannot
        # recurse back into __getattr__.
        state = self.__dict__
        if key in state.get("source_types", ()):
            currents = state["currents"]
            index = state["rfs"][key].index
            if isinstance(currents, torch.Tensor):
                # np.take would round-trip through NumPy, which fails for CUDA
                # tensors and tensors that require grad; index natively.
                index = torch.as_tensor(
                    np.asarray(index), dtype=torch.long, device=currents.device
                )
                return currents[..., index]
            return np.take(currents, index, axis=-1)
        # `object` defines no `__getattr__`; raise the conventional error.
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")

    def __getitem__(self, key: str) -> Union[NDArray, torch.Tensor]:
        return self.__getattr__(key)

    def update(self, currents: Union[NDArray, torch.Tensor]) -> None:
        """Update the currents."""
        self.currents = currents
update
update(currents)

Update the currents.

Source code in flyvis/utils/activity_utils.py
317
318
319
def update(self, currents: Union[NDArray, torch.Tensor]) -> None:
    """Swap in a new set of current values on ``self``."""
    setattr(self, "currents", currents)

flyvis.utils.cache_utils

Functions

flyvis.utils.cache_utils.context_aware_cache

context_aware_cache(func=None, context=lambda self: None)

Decorator to cache the result of a method based on its arguments and context.

Parameters:

Name Type Description Default
func Callable[..., T]

The function to be decorated.

None
context Callable[[Any], Any]

A function that returns the context for caching.

lambda self: None

Returns:

Type Description
Callable[..., T]

A wrapped function that implements caching based on arguments and context.

Example
class MyClass:
    def __init__(self):
        # context_aware_cache stores results in `self.cache`.
        self.cache = {}
        # The `context` callable below reads this attribute, so it must exist.
        self.some_attribute = None

    @context_aware_cache(context=lambda self: self.some_attribute)
    def my_method(self, arg1, arg2):
        # Method implementation
        pass
Source code in flyvis/utils/cache_utils.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def context_aware_cache(
    func: Callable[..., T] = None, context: Callable[[Any], Any] = lambda self: None
) -> Callable[..., T]:
    """
    Decorator to cache the result of a method based on its arguments and context.

    Can be applied bare (``@context_aware_cache``) or with arguments
    (``@context_aware_cache(context=...)``). Results are stored in the
    instance's ``self.cache`` mapping, which the decorated class must provide.

    Args:
        func: The function to be decorated. When None, a decorator is returned.
        context: A function that receives the instance and returns the context
            for caching; its result becomes part of the cache key.

    Returns:
        A wrapped function that implements caching based on arguments and context.

    Note:
        The cache key is the hashable structure produced by `make_hashable`,
        not its integer `hash()`. Hash values can collide (in CPython
        `hash(-1) == hash(-2)`), and keying on the hash would let distinct
        calls share one cache entry.

    Example:
        ```python
        class MyClass:
            def __init__(self):
                self.cache = {}
                self.some_attribute = None

            @context_aware_cache(context=lambda self: self.some_attribute)
            def my_method(self, arg1, arg2):
                # Method implementation
                pass
        ```
    """

    def decorator(f: Callable[..., T]) -> Callable[..., T]:
        @wraps(f)
        def wrapper(self: Any, *args: Any, **kwargs: Any) -> T:
            context_key = make_hashable(context(self))
            # Key on the full hashable structure so that equality, not just
            # the hash value, decides a cache hit.
            cache_key = make_hashable((f.__name__, args, kwargs, context_key))
            if cache_key in self.cache:
                return self.cache[cache_key]
            result = f(self, *args, **kwargs)
            self.cache[cache_key] = result
            return result

        return wrapper

    if func is None:
        return decorator
    return decorator(func)

flyvis.utils.cache_utils.make_hashable

make_hashable(obj)

Recursively converts an object into a hashable type.

Source code in flyvis/utils/cache_utils.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def make_hashable(obj: Any) -> Any:
    """Recursively convert an object into a hashable type.

    - int, float, str, bool and None are returned unchanged.
    - Lists and tuples become tuples, preserving element order.
    - Sets and frozensets become tuples in a canonical (sorted) order, because
      their iteration order carries no meaning.
    - Dicts become tuples of ``(key, value)`` pairs sorted by key.
    - Slices become ``(start, stop, step)``.
    - Anything else becomes the string ``"module.ClassName:str(obj)"``.
      Objects whose ``str()`` is not unique therefore map to the same value.
    """
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    elif isinstance(obj, (list, tuple)):
        # Sequences are ordered: [1, 2] and [2, 1] must not produce the same key.
        return tuple(make_hashable(e) for e in obj)
    elif isinstance(obj, (set, frozenset)):
        items = [make_hashable(e) for e in obj]
        try:
            # Try direct sorting first
            return tuple(sorted(items))
        except TypeError:
            # Fall back to sorting by hash for mutually unorderable elements
            return tuple(sorted(items, key=hash))
    elif isinstance(obj, dict):
        try:
            # Try direct sorting of keys first
            return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
        except TypeError:
            # Fall back to sorting by hash of keys
            return tuple(
                sorted(
                    ((k, make_hashable(v)) for k, v in obj.items()),
                    key=lambda x: hash(make_hashable(x[0])),
                )
            )
    elif isinstance(obj, slice):
        return (obj.start, obj.stop, obj.step)
    else:
        # For other types, try to get a consistent string representation
        return f"{obj.__class__.__module__}.{obj.__class__.__name__}:{str(obj)}"

flyvis.utils.chkpt_utils

Classes

flyvis.utils.chkpt_utils.Checkpoints dataclass

Dataclass to store checkpoint information.

Attributes:

Name Type Description
indices List[int]

List of checkpoint indices.

paths List[Path]

List of checkpoint paths.

Source code in flyvis/utils/chkpt_utils.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
@dataclass
class Checkpoints:
    """
    Container pairing checkpoint identifiers with their file locations.

    Attributes:
        indices: Numeric identifier of each checkpoint.
        paths: File path of each checkpoint, aligned with `indices`.
    """

    indices: List[int]
    paths: List[Path]

    def __repr__(self):
        # One field per line, each followed by a trailing comma.
        body = ",\n".join((f"  indices={self.indices!r}", f"  paths={self.paths!r}"))
        return "Checkpoints(\n" + body + ",\n)"

Functions

flyvis.utils.chkpt_utils.recover_network

recover_network(network, state_dict, ensemble_and_network_id=None)

Load network parameters from state dict.

Parameters:

Name Type Description Default
network Module

FlyVision network.

required
state_dict Union[Dict, Path, str]

State or path to checkpoint containing the “network” parameters.

required
ensemble_and_network_id str

Optional identifier for the network.

None

Returns:

Type Description
Module

The updated network.

Source code in flyvis/utils/chkpt_utils.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def recover_network(
    network: nn.Module,
    state_dict: Union[Dict, Path, str],
    ensemble_and_network_id: str = None,
) -> nn.Module:
    """
    Restore a network's parameters from the "network" entry of a checkpoint.

    Args:
        network: FlyVision network whose state is loaded in place.
        state_dict: State or path to checkpoint containing the "network" parameters.
        ensemble_and_network_id: Optional identifier, only used in the log message.

    Returns:
        The same network object, updated if a state was found.
    """
    network_state = get_from_state_dict(state_dict, "network")
    if network_state is None:
        logging.warning("Could not recover network state.")
        return network

    network.load_state_dict(network_state)
    id_suffix = f" {ensemble_and_network_id}." if ensemble_and_network_id else "."
    logging.info("Recovered network state%s", id_suffix)
    return network

flyvis.utils.chkpt_utils.recover_decoder

recover_decoder(decoder, state_dict, strict=True)

Recover multiple decoders from state dict.

Parameters:

Name Type Description Default
decoder Dict[str, Module]

Dictionary of decoders.

required
state_dict Union[Dict, Path]

State or path to checkpoint.

required
strict bool

Whether to strictly enforce that the keys in state_dict match.

True

Returns:

Type Description
Dict[str, Module]

The updated dictionary of decoders.

Source code in flyvis/utils/chkpt_utils.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def recover_decoder(
    decoder: Dict[str, nn.Module], state_dict: Union[Dict, Path], strict: bool = True
) -> Dict[str, nn.Module]:
    """
    Recover multiple decoders from state dict.

    Each decoder is loaded from the entry with the same key in the
    checkpoint's "decoder" section. Missing entries are logged and skipped.

    Args:
        decoder: Dictionary of decoders.
        state_dict: State or path to checkpoint.
        strict: Whether to strictly enforce that the keys in state_dict match.

    Returns:
        The updated dictionary of decoders (the same object that was passed in).
    """
    states = get_from_state_dict(state_dict, "decoder")
    if states is None:
        logging.warning("Could not recover decoder states.")
        return decoder
    for key, dec in decoder.items():
        # Read without popping. For an in-memory `state_dict`, `states` is
        # the caller's own nested dict, and popping would silently remove
        # entries needed by any later recovery from the same dict.
        state = states.get(key)
        if state is not None:
            dec.load_state_dict(state, strict=strict)
            logging.info("Recovered %s decoder state.", key)
        else:
            logging.warning("Could not recover state of %s decoder.", key)
    return decoder

flyvis.utils.chkpt_utils.recover_optimizer

recover_optimizer(optimizer, state_dict)

Recover optimizer state from state dict.

Parameters:

Name Type Description Default
optimizer Optimizer

PyTorch optimizer.

required
state_dict Union[Dict, Path]

State or path to checkpoint.

required

Returns:

Type Description
Optimizer

The updated optimizer.

Source code in flyvis/utils/chkpt_utils.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def recover_optimizer(
    optimizer: torch.optim.Optimizer, state_dict: Union[Dict, Path]
) -> torch.optim.Optimizer:
    """
    Restore an optimizer from the "optim" entry of a checkpoint.

    Args:
        optimizer: PyTorch optimizer whose state is loaded in place.
        state_dict: State or path to checkpoint.

    Returns:
        The same optimizer object, updated if a state was found.
    """
    optim_state = get_from_state_dict(state_dict, "optim")
    if optim_state is None:
        logging.warning("Could not recover optimizer state.")
        return optimizer

    optimizer.load_state_dict(optim_state)
    logging.info("Recovered optimizer state.")
    return optimizer

flyvis.utils.chkpt_utils.recover_penalty_optimizers

recover_penalty_optimizers(optimizers, state_dict)

Recover penalty optimizers from state dict.

Parameters:

Name Type Description Default
optimizers Dict[str, Optimizer]

Dictionary of penalty optimizers.

required
state_dict Union[Dict, Path]

State or path to checkpoint.

required

Returns:

Type Description
Dict[str, Optimizer]

The updated dictionary of penalty optimizers.

Source code in flyvis/utils/chkpt_utils.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
def recover_penalty_optimizers(
    optimizers: Dict[str, torch.optim.Optimizer], state_dict: Union[Dict, Path]
) -> Dict[str, torch.optim.Optimizer]:
    """
    Recover penalty optimizers from state dict.

    Each optimizer is loaded from the entry with the same key in the
    checkpoint's "penalty_optims" section. Missing entries are logged and
    skipped.

    Args:
        optimizers: Dictionary of penalty optimizers.
        state_dict: State or path to checkpoint.

    Returns:
        The updated dictionary of penalty optimizers (the same object passed in).
    """
    states = get_from_state_dict(state_dict, "penalty_optims")
    if states is None:
        logging.warning("Could not recover penalty optimizer states.")
        return optimizers
    for key, optim in optimizers.items():
        # Read without popping. For an in-memory `state_dict`, `states` is
        # the caller's own nested dict, and popping would mutate it.
        state = states.get(key)
        if state is not None:
            optim.load_state_dict(state)
            logging.info("Recovered %s optimizer state.", key)
        else:
            logging.warning("Could not recover state of %s optimizer.", key)
    return optimizers

flyvis.utils.chkpt_utils.get_from_state_dict

get_from_state_dict(state_dict, key)

Get a specific key from the state dict.

Parameters:

Name Type Description Default
state_dict Union[Dict, Path, str]

State dict or path to checkpoint.

required
key str

Key to retrieve from the state dict.

required

Returns:

Type Description
Dict

The value associated with the key in the state dict.

Raises:

Type Description
TypeError

If state_dict is not of type Path, str, or dict.

Source code in flyvis/utils/chkpt_utils.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
def get_from_state_dict(state_dict: Union[Dict, Path, str], key: str) -> Dict:
    """
    Look up one section of a checkpoint, given either the checkpoint or its path.

    Args:
        state_dict: In-memory checkpoint dict, a path to a checkpoint file, or None.
        key: Section to retrieve (e.g. "network", "optim").

    Returns:
        The value stored under `key`, or None if `state_dict` is None or lacks `key`.

    Raises:
        TypeError: If state_dict is not of type Path, str, or dict.
    """
    if state_dict is None:
        return None
    if isinstance(state_dict, dict):
        return state_dict.get(key, None)
    if not isinstance(state_dict, (Path, str)):
        raise TypeError(
            f"state_dict must be of type Path, str or dict, but is {type(state_dict)}."
        )
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(
            state_dict, map_location=flyvis.device, weights_only=False
        )
    return checkpoint.pop(key, None)

flyvis.utils.chkpt_utils.resolve_checkpoints

resolve_checkpoints(networkdir)

Resolve checkpoints from network directory.

Parameters:

Name Type Description Default
networkdir NetworkDir

FlyVision network directory.

required

Returns:

Type Description
Checkpoints

A Checkpoints object containing indices and paths of checkpoints.

Source code in flyvis/utils/chkpt_utils.py
175
176
177
178
179
180
181
182
183
184
185
186
187
188
def resolve_checkpoints(
    networkdir: "flyvis.network.NetworkDir",
) -> Checkpoints:
    """
    Collect the checkpoints stored under a network directory.

    Args:
        networkdir: FlyVision network directory; its `chkpts.path` is scanned.

    Returns:
        A Checkpoints object holding the checkpoint indices and paths.
    """
    chkpt_dir = networkdir.chkpts.path
    return Checkpoints(*checkpoint_index_to_path_map(chkpt_dir))

flyvis.utils.chkpt_utils.checkpoint_index_to_path_map

checkpoint_index_to_path_map(path, glob='chkpt_*')

Returns all numerical identifiers and paths to checkpoints stored in path.

Parameters:

Name Type Description Default
path Path

Checkpoint directory.

required
glob str

Glob pattern for checkpoint files.

'chkpt_*'

Returns:

Type Description
Tuple[List[int], List[Path]]

A tuple containing a list of indices and a list of paths to checkpoints.

Source code in flyvis/utils/chkpt_utils.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def checkpoint_index_to_path_map(
    path: Path, glob: str = "chkpt_*"
) -> Tuple[List[int], List[Path]]:
    """
    Returns all numerical identifiers and paths to checkpoints stored in path.

    The directory is created if it does not exist. A checkpoint's identifier
    is the first run of up to 10 digits in its file name. Results are ordered
    by identifier, with ties kept in lexicographic order.

    Args:
        path: Checkpoint directory.
        glob: Glob pattern for checkpoint files.

    Returns:
        A tuple containing a list of indices and a list of paths to checkpoints.
        If any matching file name contains no digits, no identifiers can be
        assigned and ``([], paths)`` is returned, with every matching path in
        lexicographic order.
    """
    import re

    path.mkdir(exist_ok=True)
    candidates = sorted(path.glob(glob))
    numbered = []
    for candidate in candidates:
        digits = re.findall(r"\d{1,10}", candidate.parts[-1])
        if not digits:
            # Keep the historical fallback, but always return a list.
            return [], candidates
        numbered.append((int(digits[0]), candidate))
    # list.sort is stable, so the result is deterministic even for equal ids.
    numbered.sort(key=lambda item: item[0])
    return [index for index, _ in numbered], [chkpt for _, chkpt in numbered]

flyvis.utils.chkpt_utils.best_checkpoint_default_fn

best_checkpoint_default_fn(path, validation_subdir='validation', loss_file_name='loss')

Find the best checkpoint based on the minimum loss.

Parameters:

Name Type Description Default
path Path

Path to the network directory.

required
validation_subdir str

Subdirectory containing validation data.

'validation'
loss_file_name str

Name of the loss file.

'loss'

Returns:

Type Description
Path

Path to the best checkpoint.

Source code in flyvis/utils/chkpt_utils.py
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
def best_checkpoint_default_fn(
    path: Path,
    validation_subdir: str = "validation",
    loss_file_name: str = "loss",
) -> Path:
    """
    Find the best checkpoint based on the minimum loss.

    Args:
        path: Path to the network directory.
        validation_subdir: Subdirectory containing validation data.
        loss_file_name: Name of the loss file.

    Returns:
        Path to the best checkpoint.
    """
    networkdir = flyvis.NetworkDir(path)
    checkpoint_dir = networkdir.chkpts.path
    indices, paths = checkpoint_index_to_path_map(checkpoint_dir, glob="chkpt_*")
    validation_dir = networkdir[validation_subdir]
    loss_file_name = check_loss_name(validation_dir, loss_file_name)
    # Position of the lowest validation loss. This assumes one loss per
    # checkpoint, in the same order as `indices`/`paths` — TODO confirm.
    best_position = int(np.argmin(validation_dir[loss_file_name][()]))
    # `indices` holds checkpoint identifiers, not list positions, so select
    # the path by position. Indexing `paths` with the identifier is only
    # correct when identifiers are exactly 0..n-1.
    return paths[best_position]

flyvis.utils.chkpt_utils.check_loss_name

check_loss_name(loss_folder, loss_file_name)

Check if the loss file name exists in the loss folder.

Parameters:

Name Type Description Default
loss_folder

The folder containing loss files.

required
loss_file_name str

The name of the loss file to check.

required

Returns:

Type Description
str

The validated loss file name.

Source code in flyvis/utils/chkpt_utils.py
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
def check_loss_name(loss_folder, loss_file_name: str) -> str:
    """
    Resolve which loss recording name to use inside `loss_folder`.

    Args:
        loss_folder: Container of loss recordings supporting `in` membership
            tests; its `path` attribute is read only for the warning message.
        loss_file_name: The preferred loss file name.

    Returns:
        `loss_file_name` when it is present in `loss_folder` (or when "loss"
        is absent too); otherwise "loss", after emitting a one-time warning.
    """
    requested_present = loss_file_name in loss_folder
    if requested_present or "loss" not in loss_folder:
        return loss_file_name
    warn_once(
        logging,
        f"{loss_file_name} not in {loss_folder.path}, but 'loss' is. "
        "Falling back to 'loss'. You can rerun the ensemble validation to make "
        "appropriate recordings of the losses.",
    )
    return "loss"

flyvis.utils.class_utils

Functions

flyvis.utils.class_utils.find_subclass

find_subclass(cls, target_subclass_name)

Recursively search for the target subclass.

Parameters:

Name Type Description Default
cls Type

The base class to start the search from.

required
target_subclass_name str

The name of the subclass to find.

required

Returns:

Type Description
Optional[Type]

The found subclass, or None if not found.

Source code in flyvis/utils/class_utils.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def find_subclass(cls: Type, target_subclass_name: str) -> Optional[Type]:
    """
    Search the subclass tree of `cls` for a class with the given qualified name.

    The tree is walked depth-first in pre-order: each direct subclass is
    checked before its own descendants, and siblings are visited in the order
    returned by `__subclasses__()`.

    Args:
        cls: The base class to start the search from (not itself matched).
        target_subclass_name: The `__qualname__` of the subclass to find.

    Returns:
        The first matching subclass, or None if there is none.
    """
    # Reversed pushes make the stack pop siblings in their original order.
    pending = list(reversed(cls.__subclasses__()))
    while pending:
        candidate = pending.pop()
        if candidate.__qualname__ == target_subclass_name:
            return candidate
        pending.extend(reversed(candidate.__subclasses__()))
    return None

flyvis.utils.class_utils.forward_subclass

forward_subclass(cls, config={}, subclass_key='type', unpack_kwargs=True)

Forward to a subclass based on the <subclass_key> key in config.

Forwards to the parent class if <subclass_key> is not in config.

Parameters:

Name Type Description Default
cls Type

The base class to forward from.

required
config Dict[str, Any]

Configuration dictionary containing subclass information.

{}
subclass_key str

Key in the config dictionary specifying the subclass.

'type'
unpack_kwargs bool

Whether to unpack kwargs when initializing the instance.

True

Returns:

Type Description
Any

An instance of the specified subclass or the base class.

Note

If the specified subclass is not found, a warning is issued and the base class is used instead.

Source code in flyvis/utils/class_utils.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def forward_subclass(
    cls: Type,
    config: Optional[Dict[str, Any]] = None,
    subclass_key: str = "type",
    unpack_kwargs: bool = True,
) -> Any:
    """
    Forward to a subclass based on the `<subclass_key>` key in `config`.

    Forwards to the parent class if `<subclass_key>` is not in `config`.

    Args:
        cls: The base class to forward from.
        config: Configuration dictionary containing subclass information. It is
            deep-copied and never modified. Defaults to an empty config.
        subclass_key: Key in the config dictionary specifying the subclass.
        unpack_kwargs: Whether to unpack kwargs when initializing the instance.
            If False, the remaining config entries are passed as a single
            plain-dict positional argument.

    Returns:
        An instance of the specified subclass or the base class.

    Note:
        If `<subclass_key>` is missing from `config`, or the named subclass is
        not found below `cls`, a warning is issued and the base class is used
        instead.
    """
    # A `None` default avoids a shared mutable default argument; deep-copying
    # keeps the `pop` below (and any constructor mutation) off the caller's data.
    config = deepcopy(config) if config is not None else {}
    target_subclass = config.pop(subclass_key, None)

    # Plain-dict copy of the remaining entries (subclass_key already popped).
    kwargs = dict(config.items())

    def init_with_kwargs(instance: Any) -> None:
        # Instances are created with object.__new__ and initialised explicitly
        # with the remaining config entries.
        if unpack_kwargs:
            instance.__init__(**kwargs)
        else:
            instance.__init__(kwargs)

    if target_subclass is not None:
        # Find the target subclass recursively
        subclass = find_subclass(cls, target_subclass)
        if subclass is not None:
            instance = object.__new__(subclass)
            init_with_kwargs(instance)
            return instance
        warn(
            f"Unrecognized {subclass_key} {target_subclass}. "
            f"Using {cls.__qualname__}.",
            stacklevel=2,
        )
    else:
        warn(f"Missing {subclass_key} in config. Using {cls.__qualname__}.", stacklevel=2)

    # Default case: create an instance of the base class
    instance = object.__new__(cls)
    init_with_kwargs(instance)
    return instance

flyvis.utils.color_utils

Classes

flyvis.utils.color_utils.cmap_iter

An iterator for colormap colors.

Attributes:

Name Type Description
i int

The current index.

cmap

The colormap to iterate over.

stop int

The number of colors in the colormap.

Source code in flyvis/utils/color_utils.py
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
class cmap_iter:
    """
    An iterator for colormap colors.

    Produces `cmap(0), cmap(1), ..., cmap(cmap.N - 1)` in order and then
    raises `StopIteration`, so it can be used with `next()` and in `for` loops.

    Attributes:
        i: The current index.
        cmap: The colormap to iterate over.
        stop: The number of colors in the colormap.
    """

    def __init__(self, cmap: Union[LinearSegmentedColormap, ListedColormap]):
        """
        Initialize the cmap_iter.

        Args:
            cmap: The colormap to iterate over.
        """
        self.i: int = 0
        self.cmap = cmap
        self.stop: int = cmap.N

    def __iter__(self) -> "cmap_iter":
        """Return the iterator itself, as required by the iterator protocol."""
        return self

    def __next__(self) -> Tuple[float, float, float, float]:
        """
        Get the next color from the colormap.

        Returns:
            The next color as an RGBA tuple.

        Raises:
            StopIteration: Once all `stop` colors have been produced.
        """
        # Previously this fell through and returned None forever once exhausted,
        # which breaks `for` loops and hides over-consumption.
        if self.i >= self.stop:
            raise StopIteration
        self.i += 1
        return self.cmap(self.i - 1)

    def _repr_html_(self) -> str:
        """
        Return the HTML representation of the colormap.

        Returns:
            The HTML representation of the colormap.
        """
        return self.cmap._repr_html_()
__init__
__init__(cmap)

Initialize the cmap_iter.

Parameters:

Name Type Description Default
cmap Union[LinearSegmentedColormap, ListedColormap]

The colormap to iterate over.

required
Source code in flyvis/utils/color_utils.py
274
275
276
277
278
279
280
281
282
283
def __init__(self, cmap: Union[LinearSegmentedColormap, ListedColormap]):
    """
    Set up iteration state for `cmap`.

    Args:
        cmap: Colormap whose `N` entries are iterated, starting at index 0.
    """
    self.i: int = 0
    self.cmap = cmap
    # Number of colors to produce before the iterator is exhausted.
    self.stop: int = self.cmap.N
__next__
__next__()

Get the next color from the colormap.

Returns:

Type Description
Tuple[float, float, float, float]

The next color as an RGBA tuple.

Source code in flyvis/utils/color_utils.py
285
286
287
288
289
290
291
292
293
294
def __next__(self) -> Tuple[float, float, float, float]:
    """
    Get the next color from the colormap.

    Returns:
        The next color as an RGBA tuple.

    Raises:
        StopIteration: Once all `stop` colors have been produced.
    """
    # Raising StopIteration (instead of implicitly returning None) is what the
    # iterator protocol requires once the colormap is exhausted.
    if self.i >= self.stop:
        raise StopIteration
    self.i += 1
    return self.cmap(self.i - 1)

Functions

flyvis.utils.color_utils.is_hex

is_hex(color)

Check if the given color is in hexadecimal format.

Parameters:

Name Type Description Default
color Union[str, Tuple[float, float, float]]

The color to check.

required

Returns:

Type Description
bool

True if the color is in hexadecimal format, False otherwise.

Source code in flyvis/utils/color_utils.py
32
33
34
35
36
37
38
39
40
41
42
def is_hex(color: Union[str, Tuple[float, float, float]]) -> bool:
    """
    Tell whether `color` looks like a hexadecimal color string.

    The check is a membership test for "#": a substring test for strings, and
    an element test for tuples (which therefore yields False).

    Args:
        color: The color to check.

    Returns:
        True if "#" is contained in `color`, False otherwise.
    """
    marker = "#"
    return marker in color

flyvis.utils.color_utils.is_integer_rgb

is_integer_rgb(color)

Check if the given color is in integer RGB format (0-255).

Parameters:

Name Type Description Default
color Union[Tuple[float, float, float], List[float]]

The color to check.

required

Returns:

Type Description
bool

True if the color is in integer RGB format, False otherwise.

Source code in flyvis/utils/color_utils.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def is_integer_rgb(color: Union[Tuple[float, float, float], List[float]]) -> bool:
    """
    Tell whether `color` appears to use the 0-255 integer RGB scale.

    A color counts as integer RGB when at least one channel exceeds 1. Inputs
    whose channels cannot be compared with numbers (e.g. strings) or that are
    not iterable yield False.

    Args:
        color: The color to check.

    Returns:
        True if the color is in integer RGB format, False otherwise.
    """
    try:
        # Evaluate every channel eagerly so that a non-comparable channel
        # anywhere in `color` is detected.
        above_one = [channel > 1 for channel in color]
    except TypeError:
        return False
    return any(above_one)

flyvis.utils.color_utils.single_color_cmap

single_color_cmap(color)

Create a single color colormap.

Parameters:

Name Type Description Default
color Union[str, Tuple[float, float, float]]

The color to use for the colormap.

required

Returns:

Type Description
ListedColormap

A ListedColormap object with the specified color.

Source code in flyvis/utils/color_utils.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def single_color_cmap(color: Union[str, Tuple[float, float, float]]) -> ListedColormap:
    """
    Build a ListedColormap consisting of a single color.

    Hex strings are converted with `to_rgba`, and 0-255 integer RGB values are
    rescaled to [0, 1]. Anything else is handed to ListedColormap unchanged.

    Args:
        color: The color to use for the colormap.

    Returns:
        A ListedColormap object with the specified color.
    """
    if is_hex(color):
        normalized = to_rgba(color)
    elif is_integer_rgb(color):
        normalized = np.array(color) / 255
    else:
        normalized = color
    return ListedColormap(normalized)

flyvis.utils.color_utils.color_to_cmap

color_to_cmap(end_color, start_color='#FFFFFF', name='custom_cmap', N=256)

Create a colormap from start and end colors.

Parameters:

Name Type Description Default
end_color str

The end color of the colormap.

required
start_color str

The start color of the colormap.

'#FFFFFF'
name str

The name of the colormap.

'custom_cmap'
N int

The number of color segments.

256

Returns:

Type Description
LinearSegmentedColormap

A LinearSegmentedColormap object.

Source code in flyvis/utils/color_utils.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def color_to_cmap(
    end_color: str,
    start_color: str = "#FFFFFF",
    name: str = "custom_cmap",
    N: int = 256,
) -> LinearSegmentedColormap:
    """
    Build a two-stop linear colormap running from `start_color` to `end_color`.

    Args:
        end_color: Color at the top of the colormap.
        start_color: Color at the bottom of the colormap (white by default).
        name: The name given to the colormap.
        N: The number of color segments.

    Returns:
        A LinearSegmentedColormap object.
    """
    endpoints = [hex2color(start_color), hex2color(end_color)]
    return LinearSegmentedColormap.from_list(name, endpoints, N=N)

flyvis.utils.color_utils.get_alpha_colormap

get_alpha_colormap(saturated_color, number_of_shades)

Create a colormap from a color and a number of shades.

Parameters:

Name Type Description Default
saturated_color Union[str, Tuple[float, float, float]]

The base color for the colormap.

required
number_of_shades int

The number of shades to create.

required

Returns:

Type Description
ListedColormap

A ListedColormap object with varying alpha values.

Source code in flyvis/utils/color_utils.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def get_alpha_colormap(
    saturated_color: Union[str, Tuple[float, float, float]], number_of_shades: int
) -> ListedColormap:
    """
    Create a colormap from a color and a number of shades.

    The result holds `number_of_shades` copies of the color's RGB channels, with
    alpha decreasing linearly from 1 down to `1 / number_of_shades`.

    Args:
        saturated_color: The base color for the colormap. Accepted forms are
            hex strings, integer RGB(A) values on the 0-255 scale, and any other
            color specification understood by `to_rgba` (e.g. float RGB tuples
            or named colors). An alpha channel in the input is ignored.
        number_of_shades: The number of shades to create; must be positive.

    Returns:
        A ListedColormap object with varying alpha values.
    """
    if is_hex(saturated_color):
        rgb = list(hex2color(saturated_color)[:3])
    elif is_integer_rgb(saturated_color):
        # Rescale 0-255 channels to [0, 1]; slicing to three channels keeps the
        # output RGBA even when an integer RGBA color is passed in.
        rgb = list(np.asarray(saturated_color, dtype=float)[:3] / 255.0)
    else:
        # Float RGB(A) tuples and named colors previously fell through both
        # branches and raised UnboundLocalError.
        rgb = list(to_rgba(saturated_color)[:3])

    alphas = np.linspace(1 / number_of_shades, 1, number_of_shades)[::-1]
    return ListedColormap([[*rgb, alpha] for alpha in alphas])

flyvis.utils.color_utils.adapt_color_alpha

adapt_color_alpha(color, alpha)

Transform a color specification to RGBA and adapt the alpha value.

Parameters:

Name Type Description Default
color Union[str, Tuple[float, float, float], Tuple[float, float, float, float]]

Color specification in various formats: hex string, RGB tuple, or RGBA tuple.

required
alpha float

New alpha value to be applied.

required

Returns:

Type Description
Tuple[float, float, float, float]

The adapted color in RGBA format.

Source code in flyvis/utils/color_utils.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def adapt_color_alpha(
    color: Union[str, Tuple[float, float, float], Tuple[float, float, float, float]],
    alpha: float,
) -> Tuple[float, float, float, float]:
    """
    Convert `color` to RGBA and replace its alpha channel with `alpha`.

    Args:
        color: Any color accepted by `to_rgba`, e.g. a hex string, an RGB tuple
            or an RGBA tuple.
        alpha: The alpha value to use in the result.

    Returns:
        The color as an (r, g, b, alpha) tuple.
    """
    red, green, blue, _original_alpha = to_rgba(color)
    return red, green, blue, alpha

flyvis.utils.color_utils.flash_response_color_labels

flash_response_color_labels(ax)

Apply color labels for ON and OFF flash responses.

Parameters:

Name Type Description Default
ax

The matplotlib axis to apply the labels to.

required

Returns:

Type Description

The modified matplotlib axis.

Source code in flyvis/utils/color_utils.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def flash_response_color_labels(ax):
    """
    Color the tick labels of ON and OFF cell types on `ax`.

    Cell types with polarity 1 are labelled with `ON_FR`, those with polarity
    -1 with `OFF_FR`.

    Args:
        ax: The matplotlib axis to apply the labels to.

    Returns:
        The same axis, after labelling.
    """
    on_types = []
    off_types = []
    for cell_type, sign in polarity.items():
        if sign == 1:
            on_types.append(cell_type)
        elif sign == -1:
            off_types.append(cell_type)
    color_labels(on_types, ON_FR, ax)
    color_labels(off_types, OFF_FR, ax)
    return ax

flyvis.utils.color_utils.truncate_colormap

truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100)

Truncate a colormap to a specific range.

Parameters:

Name Type Description Default
cmap Union[LinearSegmentedColormap, ListedColormap]

The colormap to truncate.

required
minval float

The minimum value of the new range.

0.0
maxval float

The maximum value of the new range.

1.0
n int

The number of color segments in the new colormap.

100

Returns:

Type Description
LinearSegmentedColormap

A new LinearSegmentedColormap with the truncated range.

Source code in flyvis/utils/color_utils.py
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
def truncate_colormap(
    cmap: Union[LinearSegmentedColormap, ListedColormap],
    minval: float = 0.0,
    maxval: float = 1.0,
    n: int = 100,
) -> LinearSegmentedColormap:
    """
    Restrict a colormap to the sub-range [`minval`, `maxval`].

    The original colormap is sampled at `max(n, 2)` evenly spaced points in the
    range. Those samples define the new colormap, which is resampled to the
    same number of entries.

    Args:
        cmap: The colormap to truncate.
        minval: Lower end of the range to keep, as a fraction of `cmap`.
        maxval: Upper end of the range to keep, as a fraction of `cmap`.
        n: The number of color segments in the new colormap (at least 2).

    Returns:
        A new LinearSegmentedColormap named "trunc(<name>,<min>,<max>)".
    """
    n_colors = max(n, 2)
    label = f"trunc({cmap.name},{minval:.2f},{maxval:.2f})"
    samples = cmap(np.linspace(minval, maxval, n_colors))
    return LinearSegmentedColormap.from_list(label, samples).resampled(n_colors)

flyvis.utils.compute_cloud_utils

Classes

flyvis.utils.compute_cloud_utils.ClusterManager

Bases: ABC

Abstract base class for cluster management operations.

Source code in flyvis/utils/compute_cloud_utils.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
class ClusterManager(ABC):
    """Abstract base class for cluster management operations.

    Concrete managers wrap one job scheduler and must implement every method
    below; the base implementations do nothing and return None.
    """

    @abstractmethod
    def run_job(self, command: str) -> str:
        """
        Submit `command` to the cluster.

        Args:
            command: The command to run.

        Returns:
            The job ID as a string.
        """
        ...

    @abstractmethod
    def is_running(self, job_id: str) -> bool:
        """
        Report whether the job `job_id` is still active.

        Args:
            job_id: The ID of the job to check.

        Returns:
            True if the job is running, False otherwise.
        """
        ...

    @abstractmethod
    def kill_job(self, job_id: str) -> str:
        """
        Terminate the job `job_id`.

        Args:
            job_id: The ID of the job to kill.

        Returns:
            The output of the kill command.
        """
        ...

    @abstractmethod
    def get_submit_command(
        self, job_name: str, n_cpus: int, output_file: str, gpu: str, queue: str
    ) -> str:
        """
        Build the scheduler-specific submission prefix for a job.

        Args:
            job_name: The name of the job.
            n_cpus: The number of CPUs to request.
            output_file: The file to write job output to.
            gpu: The GPU configuration.
            queue: The queue to submit the job to.

        Returns:
            The submit command as a string.
        """
        ...

    @abstractmethod
    def get_script_part(self, command: str) -> str:
        """
        Wrap `command` in the form the submission command expects.

        Args:
            command: The command to wrap.

        Returns:
            The wrapped command as a string.
        """
        ...
run_job abstractmethod
run_job(command)

Run a job on the cluster.

Parameters:

Name Type Description Default
command str

The command to run.

required

Returns:

Type Description
str

The job ID as a string.

Source code in flyvis/utils/compute_cloud_utils.py
25
26
27
28
29
30
31
32
33
34
35
36
@abstractmethod
def run_job(self, command: str) -> str:
    """
    Submit `command` to the cluster.

    Args:
        command: The command to run.

    Returns:
        The job ID as a string.
    """
    ...
is_running abstractmethod
is_running(job_id)

Check if a job is running.

Parameters:

Name Type Description Default
job_id str

The ID of the job to check.

required

Returns:

Type Description
bool

True if the job is running, False otherwise.

Source code in flyvis/utils/compute_cloud_utils.py
38
39
40
41
42
43
44
45
46
47
48
49
@abstractmethod
def is_running(self, job_id: str) -> bool:
    """
    Report whether the job `job_id` is still active.

    Args:
        job_id: The ID of the job to check.

    Returns:
        True if the job is running, False otherwise.
    """
    ...
kill_job abstractmethod
kill_job(job_id)

Kill a running job.

Parameters:

Name Type Description Default
job_id str

The ID of the job to kill.

required

Returns:

Type Description
str

The output of the kill command.

Source code in flyvis/utils/compute_cloud_utils.py
51
52
53
54
55
56
57
58
59
60
61
62
@abstractmethod
def kill_job(self, job_id: str) -> str:
    """
    Terminate the job `job_id`.

    Args:
        job_id: The ID of the job to kill.

    Returns:
        The output of the kill command.
    """
    ...
get_submit_command abstractmethod
get_submit_command(job_name, n_cpus, output_file, gpu, queue)

Get the command to submit a job to the cluster.

Parameters:

Name Type Description Default
job_name str

The name of the job.

required
n_cpus int

The number of CPUs to request.

required
output_file str

The file to write job output to.

required
gpu str

The GPU configuration.

required
queue str

The queue to submit the job to.

required

Returns:

Type Description
str

The submit command as a string.

Source code in flyvis/utils/compute_cloud_utils.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
@abstractmethod
def get_submit_command(
    self, job_name: str, n_cpus: int, output_file: str, gpu: str, queue: str
) -> str:
    """
    Build the scheduler-specific submission prefix for a job.

    Args:
        job_name: The name of the job.
        n_cpus: The number of CPUs to request.
        output_file: The file to write job output to.
        gpu: The GPU configuration.
        queue: The queue to submit the job to.

    Returns:
        The submit command as a string.
    """
    ...
get_script_part abstractmethod
get_script_part(command)

Get the script part of the command.

Parameters:

Name Type Description Default
command str

The command to wrap.

required

Returns:

Type Description
str

The wrapped command as a string.

Source code in flyvis/utils/compute_cloud_utils.py
83
84
85
86
87
88
89
90
91
92
93
94
@abstractmethod
def get_script_part(self, command: str) -> str:
    """
    Wrap `command` in the form the submission command expects.

    Args:
        command: The command to wrap.

    Returns:
        The wrapped command as a string.
    """
    ...

flyvis.utils.compute_cloud_utils.LSFManager

Bases: ClusterManager

Cluster manager for LSF (Load Sharing Facility) systems.

Source code in flyvis/utils/compute_cloud_utils.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
class LSFManager(ClusterManager):
    """Cluster manager for LSF (Load Sharing Facility) systems."""

    def run_job(self, command: str) -> str:
        """
        Run a submission command through the shell and return the job ID.

        Args:
            command: The full shell command, e.g. the output of
                `get_submit_command` followed by `get_script_part`.

        Returns:
            The job ID: the single run of digits enclosed in `<...>` in the
            command output.

        Raises:
            AssertionError: If the output does not contain exactly one such ID.
        """
        answer = subprocess.getoutput(command)
        job_id = re.findall(r"(?<=<)\d+(?=>)", answer)
        assert len(job_id) == 1, f"Expected exactly one job ID in: {answer!r}"
        return job_id[0]

    def is_running(self, job_id: str) -> bool:
        """
        Check whether `job_id` is listed by `bjobs -w`.

        Args:
            job_id: The ID of the job to check.

        Returns:
            True if some line of the `bjobs -w` output has `job_id` as its
            first column, False otherwise.
        """
        job_info = subprocess.getoutput("bjobs -w")
        return self._listed_in_bjobs(job_id, job_info)

    @staticmethod
    def _listed_in_bjobs(job_id: str, bjobs_output: str) -> bool:
        """Return True if a line of `bjobs_output` starts with the token `job_id`."""
        # Compare whole first-column tokens. A plain substring test on the whole
        # output would report job 123 as running whenever e.g. job 41234, or any
        # other text containing "123", appears in the listing.
        # NOTE(review): assumes the default bjobs column layout (JOBID first);
        # a custom LSB_BJOBS_FORMAT could reorder columns.
        target = str(job_id)
        for line in bjobs_output.splitlines():
            fields = line.split()
            if fields and fields[0] == target:
                return True
        return False

    def kill_job(self, job_id: str) -> str:
        """Kill `job_id` with `bkill` and return the command output."""
        return subprocess.getoutput(f"bkill {job_id}")

    def get_submit_command(
        self, job_name: str, n_cpus: int, output_file: str, gpu: str, queue: str
    ) -> str:
        """Build the `bsub` prefix (with a trailing space) for a job."""
        return f"bsub -J {job_name} -n {n_cpus} -o {output_file} -gpu '{gpu}' -q {queue} "

    def get_script_part(self, command: str) -> str:
        """LSF takes the command verbatim after the `bsub` options."""
        return command

flyvis.utils.compute_cloud_utils.SLURMManager

Bases: ClusterManager

Cluster manager for SLURM systems.

Warning

This is untested.

Source code in flyvis/utils/compute_cloud_utils.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
class SLURMManager(ClusterManager):
    """Cluster manager for SLURM systems.

    Warning:
        This is untested.
    """

    # sacct job states that count as "still running" for waiting purposes.
    _ACTIVE_STATES = ("PENDING", "RUNNING", "REQUEUED")

    def run_job(self, command: str) -> str:
        """
        Run an `sbatch` command through the shell and return the job ID.

        Raises:
            AssertionError: If the output does not contain exactly one number.
        """
        answer = subprocess.getoutput(command)
        job_id = re.findall(r"\d+", answer)
        assert len(job_id) == 1, f"Expected exactly one job ID in: {answer!r}"
        return job_id[0]

    def is_running(self, job_id: str) -> bool:
        """
        Check the job state reported by `sacct`.

        Returns:
            True if any reported state line is PENDING, RUNNING or REQUEUED.
        """
        cmd = f"sacct -j {job_id} --format=State --noheader -X"
        output = subprocess.getoutput(cmd)
        # sacct may print one state per line (e.g. for job arrays); the
        # previous whole-output comparison treated any multi-line result as
        # "not running".
        states = [line.strip() for line in output.splitlines()]
        return any(state in self._ACTIVE_STATES for state in states)

    def kill_job(self, job_id: str) -> str:
        """Cancel `job_id` with `scancel` and return the command output."""
        return subprocess.getoutput(f"scancel {job_id}")

    def get_submit_command(
        self, job_name: str, n_cpus: int, output_file: str, gpu: str, queue: str
    ) -> str:
        """Build the `sbatch` prefix (with a trailing space) for a job."""
        return (
            f"sbatch --job-name={job_name} "
            f"--cpus-per-task={n_cpus} "
            f"--output={output_file} "
            f"--gres=gpu:{gpu} "
            f"--partition={queue} "
        )

    def get_script_part(self, command: str) -> str:
        """
        Pass `command` to `sbatch --wrap` as a single shell word.

        `shlex.quote` is used instead of bare single quotes so that commands
        containing `'` (e.g. hydra overrides with quoted values) are not split
        or mangled by the shell.
        """
        import shlex

        return f"--wrap {shlex.quote(command)}"

Functions

flyvis.utils.compute_cloud_utils.get_cluster_manager

get_cluster_manager(dry=False)

Autodetect the cluster type and return the appropriate ClusterManager.

Parameters:

Name Type Description Default
dry bool

If True, return LSFManager even if no cluster is detected.

False

Returns:

Type Description
ClusterManager

An instance of the appropriate ClusterManager subclass.

Source code in flyvis/utils/compute_cloud_utils.py
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def get_cluster_manager(dry: bool = False) -> ClusterManager:
    """
    Pick a ClusterManager for the scheduler available on this machine.

    Detection order:
    - `bsub` found via `command -v`: LSFManager.
    - `sbatch` found via `command -v`: SLURMManager.

    If neither is found:
    - a dry run (`dry=True` or a truthy `DRYRUN_ONLY` environment variable)
      gets an LSFManager;
    - a truthy `VIRTUAL_CLUSTER` environment variable gets a
      VirtualClusterManager, with a UserWarning;
    - anything else raises.

    Args:
        dry: If True, return LSFManager even if no cluster is detected.

    Returns:
        An instance of the appropriate ClusterManager subclass.

    Raises:
        RuntimeError: If no scheduler is detected and neither a dry run nor a
            virtual cluster was requested.
    """
    truthy = ("true", "1", "yes", "on")
    use_virtual = os.environ.get("VIRTUAL_CLUSTER", "").lower() in truthy
    dry = dry or os.environ.get("DRYRUN_ONLY", "").lower() in truthy

    if subprocess.getoutput("command -v bsub"):
        return LSFManager()
    if subprocess.getoutput("command -v sbatch"):
        return SLURMManager()
    if dry:
        return LSFManager()
    if not use_virtual:
        raise RuntimeError("No cluster management system detected.")
    warnings.warn(
        "No cluster management system detected. Using VirtualClusterManager for "
        "local execution. This is not recommended for production use.",
        UserWarning,
        stacklevel=2,
    )
    return VirtualClusterManager()

flyvis.utils.compute_cloud_utils.run_job

run_job(command, dry)

Run a job on the cluster.

Parameters:

Name Type Description Default
command str

The command to run.

required
dry bool

If True, perform a dry run without actually submitting the job.

required

Returns:

Type Description
str

The job ID as a string, or “dry run” for dry runs.

Source code in flyvis/utils/compute_cloud_utils.py
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
def run_job(command: str, dry: bool) -> str:
    """
    Submit `command` via the module-level cluster manager, unless dry.

    A dry run is requested either by `dry=True` or by a truthy `DRYRUN_ONLY`
    environment variable; the command is then only logged.

    Args:
        command: The command to run.
        dry: If True, perform a dry run without actually submitting the job.

    Returns:
        The job ID as a string, or "dry run" for dry runs.
    """
    # TODO: dry handling currently not elegant but works for now
    truthy = ("true", "1", "yes", "on")
    if not (dry or os.environ.get("DRYRUN_ONLY", "").lower() in truthy):
        return CLUSTER_MANAGER.run_job(command)
    logger.info("Dry run command: %s", command)
    return "dry run"

flyvis.utils.compute_cloud_utils.is_running

is_running(job_id, dry)

Check if a job is running.

Parameters:

Name Type Description Default
job_id str

The ID of the job to check.

required
dry bool

If True, always return False.

required

Returns:

Type Description
bool

True if the job is running, False otherwise.

Source code in flyvis/utils/compute_cloud_utils.py
296
297
298
299
300
301
302
303
304
305
306
307
308
309
def is_running(job_id: str, dry: bool) -> bool:
    """
    Ask the module-level cluster manager whether `job_id` is still running.

    Args:
        job_id: The ID of the job to check.
        dry: If True, skip the query and report the job as not running.

    Returns:
        True if the job is running, False otherwise.
    """
    return False if dry else CLUSTER_MANAGER.is_running(job_id)

flyvis.utils.compute_cloud_utils.kill_job

kill_job(job_id, dry)

Kill a running job.

Parameters:

Name Type Description Default
job_id str

The ID of the job to kill.

required
dry bool

If True, return a message without actually killing the job.

required

Returns:

Type Description
str

The output of the kill command or a dry run message.

Source code in flyvis/utils/compute_cloud_utils.py
312
313
314
315
316
317
318
319
320
321
322
323
324
325
def kill_job(job_id: str, dry: bool) -> str:
    """
    Kill `job_id` via the module-level cluster manager, unless dry.

    Args:
        job_id: The ID of the job to kill.
        dry: If True, only describe the action instead of performing it.

    Returns:
        The output of the kill command, or "Would kill job <job_id>" when dry.
    """
    if not dry:
        return CLUSTER_MANAGER.kill_job(job_id)
    return f"Would kill job {job_id}"

flyvis.utils.compute_cloud_utils.wait_for_single

wait_for_single(job_id, dry=False)

Wait for a single job to finish on the cluster.

Parameters:

Name Type Description Default
job_id str

The ID of the job to wait for.

required
dry bool

If True, skip actual waiting.

False

Raises:

Type Description
KeyboardInterrupt

If the waiting is interrupted by the user.

Source code in flyvis/utils/compute_cloud_utils.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
def wait_for_single(job_id: str, dry: bool = False) -> None:
    """
    Wait for a single job to finish on the cluster.

    After an initial 60 s wait, `is_running` is polled every 60 s until it
    reports the job as finished. In dry mode nothing sleeps and `is_running`
    returns False, so the function returns immediately.

    Args:
        job_id: The ID of the job to wait for.
        dry: If True, skip actual waiting.

    Raises:
        KeyboardInterrupt: If the waiting is interrupted by the user. The job is
            killed via `kill_job` before the interrupt is re-raised.
    """
    try:
        if not dry:
            sleep(60)
        while is_running(job_id, dry):
            if not dry:
                sleep(60)
    except KeyboardInterrupt as e:
        # Do not leave an orphaned job behind when the user aborts the wait.
        logger.info("Killing job %s", kill_job(job_id, dry))
        raise KeyboardInterrupt from e

flyvis.utils.compute_cloud_utils.wait_for_many

wait_for_many(job_id_names, dry=False)

Wait for multiple jobs to finish on the cluster.

Parameters:

Name Type Description Default
job_id_names Dict[str, str]

A dictionary mapping job IDs to job names.

required
dry bool

If True, skip actual waiting.

False

Raises:

Type Description
KeyboardInterrupt

If the waiting is interrupted by the user.

Source code in flyvis/utils/compute_cloud_utils.py
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
def wait_for_many(job_id_names: Dict[str, str], dry: bool = False) -> None:
    """
    Block until none of the given jobs is running any more.

    After an initial 60 s wait, all jobs are polled every 60 s until no job
    reports as running. Progress messages are printed to stdout. In dry mode
    there is no sleeping or printing.

    Args:
        job_id_names: A dictionary mapping job IDs to job names; only the IDs
            (the keys) are used.
        dry: If True, skip actual waiting.

    Raises:
        KeyboardInterrupt: If the waiting is interrupted by the user. Every job
            is killed via `kill_job` before the interrupt is re-raised.
    """
    try:
        if not dry:
            print("Jobs launched.. waiting 60s..")
            sleep(60)
        while True:
            still_running = any(is_running(job_id, dry) for job_id in job_id_names)
            if not still_running:
                break
            if not dry:
                print("Jobs still running.. waiting 60s..")
                sleep(60)
    except KeyboardInterrupt as interrupt:
        for job_id in job_id_names:
            logger.info("Killing job %s", kill_job(job_id, dry))
        raise KeyboardInterrupt from interrupt

flyvis.utils.compute_cloud_utils.check_valid_host

check_valid_host(blacklist)

Prevent running on certain blacklisted hosts, e.g., login nodes.

Parameters:

Name Type Description Default
blacklist List[str]

A list of blacklisted hostnames or substrings.

required

Raises:

Type Description
ValueError

If the current host is in the blacklist.

Source code in flyvis/utils/compute_cloud_utils.py
376
377
378
379
380
381
382
383
384
385
386
387
388
def check_valid_host(blacklist: List[str]) -> None:
    """
    Refuse to run on blacklisted machines such as cluster login nodes.

    Each blacklist entry is matched against the current hostname by substring
    containment, so ``"login"`` matches ``"login01.cluster"``. Note that an
    empty-string entry therefore matches every host.

    Args:
        blacklist: Hostnames or hostname fragments on which running is
            forbidden.

    Raises:
        ValueError: If any blacklist entry occurs within the current hostname.
    """
    hostname = socket.gethostname()
    for forbidden in blacklist:
        if forbidden in hostname:
            raise ValueError(f"This script should not be run from {hostname}!")

flyvis.utils.compute_cloud_utils.launch_range

launch_range(start, end, ensemble_id, task_name, nP, gpu, q, script, dry, kwargs)

Launch a range of models.

Parameters:

Name Type Description Default
start int

The starting index for the range.

required
end int

The ending index for the range.

required
ensemble_id str

The ID of the ensemble.

required
task_name str

The name of the task.

required
nP int

The number of processors to use.

required
gpu str

The GPU configuration.

required
q str

The queue to submit the job to.

required
script str

The script to run.

required
dry bool

If True, perform a dry run without actually submitting jobs.

required
kwargs List[str]

A list of additional keyword arguments for the script.

required
Note

kwargs is an ordered list of strings, either in the format ["-kw", "val", ...] or following hydra syntax, i.e. ["kw=val", ...].

Source code in flyvis/utils/compute_cloud_utils.py
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
def launch_range(
    start: int,
    end: int,
    ensemble_id: str,
    task_name: str,
    nP: int,
    gpu: str,
    q: str,
    script: str,
    dry: bool,
    kwargs: List[str],
) -> None:
    """
    Launch one cluster job per network index in ``range(start, end)`` and wait.

    For every index ``i`` the job runs ``script`` with the current Python
    interpreter. ``ensemble_and_network_id=<ensemble>/<i:03>`` and
    ``task_name=<task_name>`` are appended to a copy of ``kwargs``. Each job
    logs to ``results_dir/task_name/<ensemble>/<i:04>_<script stem>.log``, and
    any existing log at that path is deleted first. After all jobs are
    submitted, this blocks in ``wait_for_many`` until they finish.

    Args:
        start: The starting index for the range (inclusive).
        end: The ending index for the range (exclusive).
        ensemble_id: The ID of the ensemble. It is formatted with ``:04``,
            which zero-pads integers. NOTE(review): the annotation says
            ``str``, but the padding suggests callers pass an int; confirm.
        task_name: The name of the task.
        nP: The number of processors to use.
        gpu: The GPU configuration.
        q: The queue to submit the job to.
        script: The script to run.
        dry: If True, perform a dry run without actually submitting jobs.
        kwargs: A list of additional keyword arguments for the script. It is
            copied per job and never mutated.

    Note:
        kwargs is an ordered list of strings, either in the format ["-kw", "val", ...]
        or following hydra syntax, i.e. ["kw=val", ...].
    """
    SCRIPT_PART = "{} {} {}"

    CLUSTER_MANAGER.set_dry(dry)

    # Filename of the script up to its first dot; used to name the log files.
    script_stem = script.split('/')[-1].split('.')[0]

    job_id_names = {}
    for i in range(start, end):
        kw = kwargs.copy()
        ensemble_and_network_id = f"{ensemble_id:04}/{i:03}"
        # NOTE(review): presumably '_' is forbidden because it separates the
        # task name from the id in the job name built below; confirm.
        assert "_" not in ensemble_and_network_id
        network_dir = results_dir / task_name / ensemble_and_network_id
        # exist_ok avoids a check-then-create race (FileExistsError) when
        # several launchers create the same ensemble directory concurrently.
        network_dir.parent.mkdir(parents=True, exist_ok=True)
        log_file = network_dir.parent / f"{i:04}_{script_stem}.log"
        # Drop a stale log left over from a previous run of this network.
        if log_file.exists():
            log_file.unlink()

        kw.append(f"ensemble_and_network_id={ensemble_and_network_id}")
        kw.append(f"task_name={task_name}")

        job_name = f"{task_name}_{ensemble_and_network_id}"
        LSF_CMD = CLUSTER_MANAGER.get_submit_command(job_name, nP, log_file, gpu, q)
        SCRIPT_CMD = SCRIPT_PART.format(sys.executable, script, " ".join(kw))
        command = LSF_CMD + CLUSTER_MANAGER.get_script_part(SCRIPT_CMD)
        logger.info("Launching command: %s", command)
        job_id = run_job(command, dry)
        job_id_names[job_id] = job_name

    wait_for_many(job_id_names, dry)

flyvis.utils.compute_cloud_utils.launch_single

launch_single(ensemble_id, task_name, nP, gpu, q, script, dry, kwargs)

Launch a single job for an ensemble.

Parameters:

Name Type Description Default
ensemble_id str

The ID of the ensemble.

required
task_name str

The name of the task.

required
nP int

The number of processors to use.

required
gpu str

The GPU configuration.

required
q str

The queue to submit the job to.

required
script str

The script to run.

required
dry bool

If True, perform a dry run without actually submitting the job.

required
kwargs List[str]

A list of additional keyword arguments for the script.

required
Note

kwargs is an ordered list of strings, either in the format ["-kw", "val", ...] or following hydra syntax, i.e. ["kw=val", ...].

Source code in flyvis/utils/compute_cloud_utils.py
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
def launch_single(
    ensemble_id: str,
    task_name: str,
    nP: int,
    gpu: str,
    q: str,
    script: str,
    dry: bool,
    kwargs: List[str],
) -> None:
    """
    Submit a single ensemble-level job to the cluster and wait for it.

    The ensemble directory ``results_dir/task_name/<ensemble_id:04>`` must
    already exist. The job's log goes to ``<script stem>.log`` inside it, and
    an existing log there is removed first. ``ensemble_id=...`` and
    ``task_name=...`` are appended to a copy of ``kwargs`` before the script
    is launched with the current Python interpreter.

    Args:
        ensemble_id: The ID of the ensemble. It is formatted with ``:04``.
            NOTE(review): this zero-pads integers, so callers likely pass an
            int despite the ``str`` annotation; confirm.
        task_name: The name of the task.
        nP: The number of processors to use.
        gpu: The GPU configuration.
        q: The queue to submit the job to.
        script: The script to run.
        dry: If True, perform a dry run without actually submitting the job.
        kwargs: A list of additional keyword arguments for the script.

    Note:
        kwargs is an ordered list of strings, either in the format ["-kw", "val", ...]
        or following hydra syntax, i.e. ["kw=val", ...].
    """
    CLUSTER_MANAGER.set_dry(dry)

    arguments = kwargs.copy()
    padded_id = f"{ensemble_id:04}"
    assert "_" not in padded_id
    ensemble_dir = results_dir / task_name / padded_id
    assert ensemble_dir.exists()

    stem = script.split('/')[-1].split('.')[0]
    log_file = ensemble_dir / f"{stem}.log"
    if log_file.exists():
        log_file.unlink()

    arguments += [f"ensemble_id={padded_id}", f"task_name={task_name}"]

    job_name = f"{task_name}_{padded_id}"
    submit_cmd = CLUSTER_MANAGER.get_submit_command(job_name, nP, log_file, gpu, q)
    script_cmd = "{} {} {}".format(sys.executable, script, " ".join(arguments))
    command = submit_cmd + CLUSTER_MANAGER.get_script_part(script_cmd)
    logger.info("Launching command: %s", command)
    job_id = run_job(command, dry)
    wait_for_many({job_id: job_name}, dry)

flyvis.utils.config_utils

Classes

flyvis.utils.config_utils.HybridArgumentParser

Bases: ArgumentParser

Hybrid argument parser that can parse unknown arguments in basic key=value style.

Attributes:

Name Type Description
hybrid_args

Dictionary of hybrid arguments with their requirements and help texts.

allow_unrecognized

Whether to allow unrecognized arguments.

drop_disjoint_from

Path to a configuration file used to drop command-line arguments that are not present in that file. This allows arguments to be passed through multiple scripts, which hydra does not support natively.

Parameters:

Name Type Description Default
hybrid_args Optional[Dict[str, Dict[str, Any]]]

Dictionary of hybrid arguments with their requirements and help texts.

None
allow_unrecognized bool

Whether to allow unrecognized arguments.

True
Source code in flyvis/utils/config_utils.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
class HybridArgumentParser(argparse.ArgumentParser):
    """
    Hybrid argument parser that can parse unknown arguments in basic key=value style.

    Attributes:
        hybrid_args: Dictionary of hybrid arguments with their requirements and
            help texts.
        allow_unrecognized: Whether to allow unrecognized arguments.
        drop_disjoint_from: Path to a configuration file used to drop parsed
            arguments that are not present in that file. This allows arguments
            to be passed through multiple scripts, which hydra does not support.

    Args:
        hybrid_args: Dictionary mapping argument names to a config dict with
            the optional keys ``'help'`` (str), ``'required'`` (bool) and
            ``'type'`` (callable used to convert the string value).
        allow_unrecognized: Whether to allow unrecognized arguments.
        drop_disjoint_from: Optional path to a config file. When set, parsed
            arguments absent from that config (and not prefixed with hydra's
            ``+``, ``++`` or ``~``) are dropped from the result. A warning is
            issued for dropped arguments that were not registered via
            ``add_argument``.
        *args, **kwargs: Forwarded to ``argparse.ArgumentParser``.
    """

    def __init__(
        self,
        *args: Any,
        hybrid_args: Optional[Dict[str, Dict[str, Any]]] = None,
        allow_unrecognized: bool = True,
        drop_disjoint_from: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.hybrid_args = hybrid_args or {}
        self.allow_unrecognized = allow_unrecognized
        self.drop_disjoint_from = drop_disjoint_from
        self._add_hybrid_args_to_help()

    def _add_hybrid_args_to_help(self) -> None:
        """Add hybrid arguments to the help message.

        Each hybrid argument is registered as an optional ``--arg`` option
        (argparse ``required=False``). Required-ness is enforced later in
        ``parse_with_hybrid_args``, because the value may also arrive as
        ``arg=value``.
        """
        if self.hybrid_args:
            hybrid_group = self.add_argument_group('Hybrid Arguments')
            for arg, config in self.hybrid_args.items():
                help_text = config.get('help', '')
                required = config.get('required', False)
                arg_type = config.get('type', None)
                arg_help = f"{arg}=value: {help_text}"
                if arg_type:
                    arg_help += f" (type: {arg_type.__name__})"
                if required:
                    arg_help += " (Required)"
                hybrid_group.add_argument(f"--{arg}", help=arg_help, required=False)

    def parse_with_hybrid_args(
        self,
        args: Optional[List[str]] = None,
        namespace: Optional[argparse.Namespace] = None,
    ) -> argparse.Namespace:
        """
        Parse arguments and set hybrid arguments as attributes in the namespace.

        Besides regular argparse options, two extra forms are accepted:

        - ``key:type=value``: ``value`` is cast with ``safe_cast`` to ``type``
          (for ``bool``, true/1/yes and false/0/no are accepted in any case).
        - ``key=value``: if ``key`` is a hybrid argument with a ``'type'``,
          ``value`` is converted with it; otherwise it is stored as a string.

        Only the first ``=`` separates key from value, so values may contain
        ``=`` or ``:`` (e.g. ``url=http://host:80/?a=b``). Any other argument
        that argparse did not recognize is treated as unrecognized.

        Args:
            args: List of arguments to parse. Defaults to ``sys.argv[1:]``.
            namespace: Namespace to populate with parsed arguments.

        Returns:
            Namespace with parsed arguments.

        Raises:
            SystemExit: Via ``self.error`` if required arguments are missing,
                a value cannot be converted, or unrecognized arguments are
                present while ``allow_unrecognized`` is False.
        """

        def _convert(arg_type: Any, value: Any) -> Any:
            # bool("false") is True, so string booleans are interpreted explicitly.
            if arg_type is bool and isinstance(value, str):
                lowered = value.lower()
                if lowered in ("true", "1", "yes", "on"):
                    return True
                if lowered in ("false", "0", "no", "off"):
                    return False
                raise ValueError(f"cannot interpret {value!r} as bool")
            return arg_type(value)

        def _type_name(arg_type: Any) -> str:
            return getattr(arg_type, "__name__", repr(arg_type))

        if args is None:
            args = sys.argv[1:]

        args_for_parser = []
        key_value_args = []

        # Separate bare key=value pairs (hydra style) from argparse-style args.
        for arg in args:
            if '=' in arg and not arg.startswith('-'):
                key_value_args.append(arg)
            else:
                args_for_parser.append(arg)

        parsed, unknown_args = self.parse_known_args(args_for_parser, namespace)

        argv = []
        for arg in key_value_args + unknown_args:
            if "=" not in arg:
                argv.append(arg)
                continue
            # Split on the first '=' only so values may contain '=' or ':'.
            keytype, value = arg.split("=", 1)
            if ":" in keytype:
                # Explicitly typed form: key:type=value.
                key, astype = keytype.split(":", 1)
                try:
                    lowered = value.lower()
                    if astype == "bool" and lowered in ("true", "1", "yes"):
                        setattr(parsed, key, True)
                    elif astype == "bool" and lowered in ("false", "0", "no"):
                        setattr(parsed, key, False)
                    else:
                        setattr(parsed, key, safe_cast(value, astype))
                except (ValueError, TypeError):
                    self.error(
                        f"Invalid type '{astype}' or value '{value}' for argument {key}"
                    )
            else:
                key = keytype
                spec = self.hybrid_args.get(key, {})
                if 'type' in spec:
                    arg_type = spec['type']
                    try:
                        setattr(parsed, key, _convert(arg_type, value))
                    except (ValueError, TypeError):
                        self.error(
                            f"Invalid {_type_name(arg_type)} value '{value}' "
                            f"for argument {key}"
                        )
                else:
                    setattr(parsed, key, value)

        # Convert values argparse parsed from ``--key value`` (still strings).
        for arg, config in self.hybrid_args.items():
            arg_type = config.get('type')
            value = getattr(parsed, arg, None)
            if not arg_type or value is None:
                continue
            if isinstance(arg_type, type) and isinstance(value, arg_type):
                continue  # already converted above (key=value form)
            try:
                setattr(parsed, arg, _convert(arg_type, value))
            except (ValueError, TypeError):
                self.error(
                    f"Invalid {_type_name(arg_type)} value '{value}' "
                    f"for argument {arg}"
                )

        # Check for required arguments
        missing_required = [
            arg
            for arg, config in self.hybrid_args.items()
            if config.get('required', False) and getattr(parsed, arg, None) is None
        ]

        if missing_required:
            self.error(
                f"The following required arguments are missing: "
                f"{', '.join(missing_required)}"
            )

        if argv and not self.allow_unrecognized:
            msg = "unrecognized arguments: %s"
            self.error(msg % " ".join(argv))

        if self.drop_disjoint_from:
            parsed = self._filter_args_based_on_config(parsed)

        return parsed

    def hydra_argv(self) -> List[str]:
        """
        Parse ``sys.argv[1:]`` and render the result as ``key=value`` strings.

        Returns:
            One ``"key=value"`` string per attribute of the parsed namespace
            whose name does not contain ``":"``.
        """
        hybrid_args = self.parse_with_hybrid_args()
        return [
            f"{key}={value}" for key, value in vars(hybrid_args).items() if ":" not in key
        ]

    def get_registered_args(self) -> List[str]:
        """
        Get a list of all argument names that were registered using add_argument.

        Returns:
            List of argument destination names (without the -- prefix),
            excluding the default help action.
        """
        return [
            action.dest
            for action in self._actions
            if action.dest != "help"  # Exclude the default help action
        ]

    def _filter_args_based_on_config(
        self, args: argparse.Namespace
    ) -> argparse.Namespace:
        """
        Filter arguments based on the Hydra config file specified in drop_disjoint_from.

        Args:
            args: Namespace containing all parsed arguments.

        Returns:
            Filtered Namespace with only arguments present in the config or with
                Hydra syntax (``+``, ``++`` or ``~`` prefix).
        """
        if not self.drop_disjoint_from:
            return args

        config = OmegaConf.create(
            get_config_from_file(self.drop_disjoint_from, resolve=False)
        )

        filtered_args = argparse.Namespace()
        registered_args = self.get_registered_args()

        for arg, value in vars(args).items():
            # '+' also covers hydra's '++' prefix.
            if self._is_in_config(arg, config) or arg.startswith(('+', '~')):
                setattr(filtered_args, arg, value)
            elif arg not in registered_args:
                warnings.warn(
                    f"{Fore.YELLOW}Argument {Style.BRIGHT}{arg}={value}"
                    f"{Style.RESET_ALL}{Fore.YELLOW} "
                    f"does not affect the hydra config because it is not present in "
                    f"the config file {Style.BRIGHT}{self.drop_disjoint_from}"
                    f"{Style.RESET_ALL}{Fore.YELLOW}. "
                    f"This may be unintended, like a typo, or intended, like a "
                    f"hydra-style argument passed through to another script. "
                    f"Check script docs and config file "
                    f"for clarification.{Style.RESET_ALL}",
                    stacklevel=2,
                )

        return filtered_args

    def _is_in_config(self, arg: str, config: Any) -> bool:
        """
        Check if an argument exists in the config, including nested structures.

        Args:
            arg: The argument to check (may be a dotted path).
            config: The configuration object or sub-object.

        Returns:
            True if the argument is found in the config (a mandatory-missing
            value counts as present), False otherwise.
        """
        try:
            return OmegaConf.select(config, arg, throw_on_missing=True) is not None
        except errors.MissingMandatoryValue:
            return True
parse_with_hybrid_args
parse_with_hybrid_args(args=None, namespace=None)

Parse arguments and set hybrid arguments as attributes in the namespace.

Parameters:

Name Type Description Default
args Optional[List[str]]

List of arguments to parse.

None
namespace Optional[Namespace]

Namespace to populate with parsed arguments.

None

Returns:

Type Description
Namespace

Namespace with parsed arguments.

Raises:

Type Description
SystemExit

If required arguments are missing or invalid values are provided.

Source code in flyvis/utils/config_utils.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
def parse_with_hybrid_args(
    self,
    args: Optional[List[str]] = None,
    namespace: Optional[argparse.Namespace] = None,
) -> argparse.Namespace:
    """
    Parse arguments and set hybrid arguments as attributes in the namespace.

    Besides regular argparse options, two extra forms are accepted:

    - ``key:type=value``: ``value`` is cast with ``safe_cast`` to ``type``
      (for ``bool``, true/1/yes and false/0/no are accepted in any case).
    - ``key=value``: if ``key`` is a hybrid argument with a ``'type'``,
      ``value`` is converted with it; otherwise it is stored as a string.

    Only the first ``=`` separates key from value, so values may contain
    ``=`` or ``:`` (e.g. ``url=http://host:80/?a=b``). Any other argument
    that argparse did not recognize is treated as unrecognized.

    Args:
        args: List of arguments to parse. Defaults to ``sys.argv[1:]``.
        namespace: Namespace to populate with parsed arguments.

    Returns:
        Namespace with parsed arguments.

    Raises:
        SystemExit: Via ``self.error`` if required arguments are missing,
            a value cannot be converted, or unrecognized arguments are
            present while ``allow_unrecognized`` is False.
    """

    def _convert(arg_type: Any, value: Any) -> Any:
        # bool("false") is True, so string booleans are interpreted explicitly.
        if arg_type is bool and isinstance(value, str):
            lowered = value.lower()
            if lowered in ("true", "1", "yes", "on"):
                return True
            if lowered in ("false", "0", "no", "off"):
                return False
            raise ValueError(f"cannot interpret {value!r} as bool")
        return arg_type(value)

    def _type_name(arg_type: Any) -> str:
        return getattr(arg_type, "__name__", repr(arg_type))

    if args is None:
        args = sys.argv[1:]

    args_for_parser = []
    key_value_args = []

    # Separate bare key=value pairs (hydra style) from argparse-style args.
    for arg in args:
        if '=' in arg and not arg.startswith('-'):
            key_value_args.append(arg)
        else:
            args_for_parser.append(arg)

    parsed, unknown_args = self.parse_known_args(args_for_parser, namespace)

    argv = []
    for arg in key_value_args + unknown_args:
        if "=" not in arg:
            argv.append(arg)
            continue
        # Split on the first '=' only so values may contain '=' or ':'.
        keytype, value = arg.split("=", 1)
        if ":" in keytype:
            # Explicitly typed form: key:type=value.
            key, astype = keytype.split(":", 1)
            try:
                lowered = value.lower()
                if astype == "bool" and lowered in ("true", "1", "yes"):
                    setattr(parsed, key, True)
                elif astype == "bool" and lowered in ("false", "0", "no"):
                    setattr(parsed, key, False)
                else:
                    setattr(parsed, key, safe_cast(value, astype))
            except (ValueError, TypeError):
                self.error(
                    f"Invalid type '{astype}' or value '{value}' for argument {key}"
                )
        else:
            key = keytype
            spec = self.hybrid_args.get(key, {})
            if 'type' in spec:
                arg_type = spec['type']
                try:
                    setattr(parsed, key, _convert(arg_type, value))
                except (ValueError, TypeError):
                    self.error(
                        f"Invalid {_type_name(arg_type)} value '{value}' "
                        f"for argument {key}"
                    )
            else:
                setattr(parsed, key, value)

    # Convert values argparse parsed from ``--key value`` (still strings).
    for arg, config in self.hybrid_args.items():
        arg_type = config.get('type')
        value = getattr(parsed, arg, None)
        if not arg_type or value is None:
            continue
        if isinstance(arg_type, type) and isinstance(value, arg_type):
            continue  # already converted above (key=value form)
        try:
            setattr(parsed, arg, _convert(arg_type, value))
        except (ValueError, TypeError):
            self.error(
                f"Invalid {_type_name(arg_type)} value '{value}' "
                f"for argument {arg}"
            )

    # Check for required arguments
    missing_required = [
        arg
        for arg, config in self.hybrid_args.items()
        if config.get('required', False) and getattr(parsed, arg, None) is None
    ]

    if missing_required:
        self.error(
            f"The following required arguments are missing: "
            f"{', '.join(missing_required)}"
        )

    if argv and not self.allow_unrecognized:
        msg = "unrecognized arguments: %s"
        self.error(msg % " ".join(argv))

    if self.drop_disjoint_from:
        parsed = self._filter_args_based_on_config(parsed)

    return parsed
get_registered_args
get_registered_args()

Get a list of all argument names that were registered using add_argument.

Returns:

Type Description
List[str]

List of argument names (without the -- prefix)

Source code in flyvis/utils/config_utils.py
244
245
246
247
248
249
250
251
252
253
254
255
def get_registered_args(self) -> List[str]:
    """
    List the destination names of all arguments added via ``add_argument``.

    The default help action is left out.

    Returns:
        Argument destination names (i.e. without the leading ``--``).
    """
    registered = []
    for action in self._actions:
        if action.dest == "help":
            continue  # skip the default help action
        registered.append(action.dest)
    return registered

Functions

flyvis.utils.config_utils.get_default_config

get_default_config(overrides, path='../../config/solver.yaml', as_namespace=True)

Get the default configuration using Hydra.

Parameters:

Name Type Description Default
overrides List[str]

List of configuration overrides.

required
path str

Path to the configuration file.

'../../config/solver.yaml'
as_namespace bool

Whether to return a namespaced configuration or the OmegaConf object.

True

Returns:

Type Description
Union[Dict[DictKeyType, Any], List[Any], None, str, Any, Namespace]

The configuration object.

Note

Expected overrides are: task_name and network_id.

Source code in flyvis/utils/config_utils.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def get_default_config(
    overrides: List[str],
    path: str = "../../config/solver.yaml",
    as_namespace: bool = True,
) -> Union[Dict[DictKeyType, Any], List[Any], None, str, Any, Namespace]:
    """
    Load the default configuration from ``path`` and apply overrides.

    The configuration is loaded via ``get_config_from_file`` with
    ``resolve=True``.

    Args:
        overrides: List of configuration overrides (``key=value`` strings).
        path: Path to the configuration file.
        as_namespace: If True, return ``namespacify(config)``; otherwise
            return the object produced by ``get_config_from_file``.

    Returns:
        The configuration object.

    Note:
        Expected overrides are ``task_name`` and ``network_id``.
    """
    resolved = get_config_from_file(path, overrides, resolve=True)
    return namespacify(resolved) if as_namespace else resolved

flyvis.utils.config_utils.parse_kwargs_to_dict

parse_kwargs_to_dict(values)

Parse a list of key-value pairs into an argparse Namespace.

Parameters:

Name Type Description Default
values List[str]

List of key-value pairs in the format “key=value”.

required

Returns:

Type Description
Namespace

Namespace object with parsed key-value pairs as attributes.

Source code in flyvis/utils/config_utils.py
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
def parse_kwargs_to_dict(values: List[str]) -> argparse.Namespace:
    """
    Parse a list of ``key=value`` strings into an ``argparse.Namespace``.

    Only the first ``=`` separates key and value, so values may themselves
    contain ``=`` (``"expr=a=b"`` yields ``expr == "a=b"``). Values are kept
    as strings.

    Args:
        values: List of key-value pairs in the format "key=value".

    Returns:
        Namespace object with parsed key-value pairs as attributes.

    Raises:
        ValueError: If an entry does not contain ``=``.
    """
    kwargs = argparse.Namespace()
    for item in values:
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"Expected 'key=value', got {item!r}")
        setattr(kwargs, key, value)
    return kwargs

flyvis.utils.config_utils.safe_cast

safe_cast(value, type_name)

Safely cast a string value to a specified type.

Parameters:

Name Type Description Default
value str

The string value to cast.

required
type_name str

The name of the type to cast to.

required

Returns:

Type Description
Union[int, float, bool, str]

The casted value.

Note

Supports casting to int, float, bool, and str.

Source code in flyvis/utils/config_utils.py
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
def safe_cast(value: str, type_name: str) -> Union[int, float, bool, str]:
    """
    Convert a string to the type named by ``type_name``.

    Args:
        value: The string to convert.
        type_name: ``'int'``, ``'float'`` or ``'bool'``. Any other name
            returns ``value`` unchanged.

    Returns:
        The converted value, or ``value`` itself for unknown type names.

    Raises:
        ValueError: If ``value`` is not a valid literal for ``int``/``float``.

    Note:
        For ``'bool'`` only ``true``, ``yes``, ``1`` and ``on`` (any case)
        map to True; every other string maps to False.
    """
    converters = {
        'int': int,
        'float': float,
        'bool': lambda text: text.lower() in ('true', 'yes', '1', 'on'),
    }
    convert = converters.get(type_name)
    return value if convert is None else convert(value)

flyvis.utils.dataset_utils

Classes

flyvis.utils.dataset_utils.CrossValIndices

Returns folds of indices for cross-validation.

Parameters:

Name Type Description Default
n_samples int

Total number of samples.

required
folds int

Total number of folds.

required
shuffle bool

Shuffles the indices.

True
seed int

Seed for shuffling.

0

Attributes:

Name Type Description
n_samples

Total number of samples.

folds

Total number of folds.

indices

Array of indices.

random

RandomState object for shuffling.

Source code in flyvis/utils/dataset_utils.py
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
class CrossValIndices:
    """Returns folds of indices for cross-validation.

    Args:
        n_samples: Total number of samples.
        folds: Total number of folds.
        shuffle: Shuffles the indices.
        seed: Seed for shuffling.

    Attributes:
        n_samples: Total number of samples.
        folds: Total number of folds.
        indices: Array of indices (a permutation of ``arange(n_samples)``).
        random: RandomState object used for shuffling, or None when
            ``shuffle`` is False.

    """

    def __init__(self, n_samples: int, folds: int, shuffle: bool = True, seed: int = 0):
        self.n_samples = n_samples
        self.folds = folds
        self.indices = np.arange(n_samples)
        # Always define the documented attribute, even without shuffling.
        self.random = None

        if shuffle:
            self.random = RandomState(seed)
            self.random.shuffle(self.indices)

    def __call__(self, fold: int) -> Tuple[np.ndarray, np.ndarray]:
        """Returns train and test indices for a fold.

        The first ``n_samples % folds`` folds get one extra sample.

        Args:
            fold: The fold number.

        Returns:
            A tuple containing train and test indices.
        """
        fold_sizes = np.full(self.folds, self.n_samples // self.folds, dtype=int)
        fold_sizes[: self.n_samples % self.folds] += 1
        current = sum(fold_sizes[:fold])
        start, stop = current, current + fold_sizes[fold]
        test_index = self.indices[start:stop]
        # NOTE: the mask is indexed by the index *values* in test_index. Since
        # self.indices is a permutation of arange(n_samples), the folds still
        # partition all samples. When shuffled, however, the returned test set
        # is self.indices[test_index] rather than test_index itself. This is
        # kept as-is so existing splits stay reproducible.
        test_mask = np.zeros_like(self.indices, dtype=bool)
        test_mask[test_index] = True
        return self.indices[np.logical_not(test_mask)], self.indices[test_mask]

    def iter(self) -> Tuple[np.ndarray, np.ndarray]:
        """Iterate over all folds.

        Yields:
            A tuple containing train and test indices for each fold.
        """
        for fold in range(self.folds):
            yield self(fold)
__call__
__call__(fold)

Returns train and test indices for a fold.

Parameters:

Name Type Description Default
fold int

The fold number.

required

Returns:

Type Description
Tuple[ndarray, ndarray]

A tuple containing train and test indices.

Source code in flyvis/utils/dataset_utils.py
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def __call__(self, fold: int) -> Tuple[np.ndarray, np.ndarray]:
    """Split ``self.indices`` into train and test indices for one fold.

    Fold sizes are ``n_samples // folds``, with the first
    ``n_samples % folds`` folds receiving one extra sample.

    Args:
        fold: The fold number.

    Returns:
        A ``(train, test)`` tuple of index arrays.
    """
    base, remainder = divmod(self.n_samples, self.folds)
    sizes = np.full(self.folds, base, dtype=int)
    sizes[:remainder] += 1
    begin = sum(sizes[:fold])
    end = begin + sizes[fold]
    # Mark the index values of this fold's slice as test positions.
    is_test = np.zeros_like(self.indices, dtype=bool)
    is_test[self.indices[begin:end]] = True
    return self.indices[~is_test], self.indices[is_test]
iter
iter()

Iterate over all folds.

Yields:

Type Description
Tuple[ndarray, ndarray]

A tuple containing train and test indices for each fold.

Source code in flyvis/utils/dataset_utils.py
176
177
178
179
180
181
182
183
def iter(self) -> Tuple[np.ndarray, np.ndarray]:
    """Yield the ``(train, test)`` split of every fold, in fold order.

    Yields:
        A tuple containing train and test indices for each fold.
    """
    yield from map(self, range(self.folds))

flyvis.utils.dataset_utils.IndexSampler

Bases: Sampler

Samples the provided indices in sequence.

Note

To be used with torch.utils.data.DataLoader.

Parameters:

Name Type Description Default
indices List[int]

List of indices to sample.

required

Attributes:

Name Type Description
indices

List of indices to sample.

Source code in flyvis/utils/dataset_utils.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
class IndexSampler(Sampler):
    """Yield a fixed sequence of indices in the given order.

    Note:
        Intended to be passed as the sampler of a torch.utils.data.DataLoader.

    Args:
        indices: Indices to emit, in order.

    Attributes:
        indices: Indices to emit, in order.
    """

    def __init__(self, indices: List[int]):
        self.indices = indices

    def __iter__(self):
        count = len(self.indices)
        return (self.indices[position] for position in range(count))

    def __len__(self) -> int:
        return len(self.indices)

Functions

flyvis.utils.dataset_utils.random_walk_of_blocks

random_walk_of_blocks(n_blocks=20, block_size=4, top_lum=0, bottom_lum=0, dataset_size=[3, 20, 64, 64], noise_mean=0.5, noise_std=0.1, step_size=4, p_random=0.6, p_center_attraction=0.3, p_edge_attraction=0.1, seed=42)

Generate a sequence dataset with blocks doing random walks.

Parameters:

Name Type Description Default
n_blocks int

Number of blocks.

20
block_size int

Size of blocks.

4
top_lum float

Luminance of the top of the block.

0
bottom_lum float

Luminance of the bottom of the block.

0
dataset_size List[int]

Size of the dataset. (n_sequences, n_frames, h, w)

[3, 20, 64, 64]
noise_mean float

Mean of the background noise.

0.5
noise_std float

Standard deviation of the background noise.

0.1
step_size int

Number of pixels to move in each step.

4
p_random float

Probability of moving randomly.

0.6
p_center_attraction float

Probability of moving towards the center.

0.3
p_edge_attraction float

Probability of moving towards the edge.

0.1
seed int

Seed for the random number generator.

42

Returns:

Type Description
ndarray

Dataset of shape (n_sequences, n_frames, h, w)

Source code in flyvis/utils/dataset_utils.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def random_walk_of_blocks(
    n_blocks: int = 20,
    block_size: int = 4,
    top_lum: float = 0,
    bottom_lum: float = 0,
    dataset_size: List[int] = [3, 20, 64, 64],
    noise_mean: float = 0.5,
    noise_std: float = 0.1,
    step_size: int = 4,
    p_random: float = 0.6,
    p_center_attraction: float = 0.3,
    p_edge_attraction: float = 0.1,
    seed: int = 42,
) -> np.ndarray:
    """Generate a sequence dataset with blocks doing random walks.

    Blocks are painted onto Gaussian background noise. Each block has a top
    half with luminance ``top_lum`` and a bottom half with ``bottom_lum``.
    At every frame after the first, each block coordinate independently steps
    by ``step_size`` pixels. The step goes towards the center, towards the
    edge, or in a random direction. Positions wrap around the frame borders.

    The three movement probabilities are relative weights: they are
    normalized to sum to one. The defaults (0.6, 0.3, 0.1) already do.

    Args:
        n_blocks: Number of blocks.
        block_size: Size of blocks.
        top_lum: Luminance of the top of the block.
        bottom_lum: Luminance of the bottom of the block.
        dataset_size: Size of the dataset. (n_sequences, n_frames, h, w)
        noise_mean: Mean of the background noise.
        noise_std: Standard deviation of the background noise.
        step_size: Number of pixels to move in each step.
        p_random: Relative weight of moving randomly.
        p_center_attraction: Relative weight of moving towards the center.
        p_edge_attraction: Relative weight of moving towards the edge.
        seed: Seed for numpy's global random number generator (it is reseeded).

    Returns:
        Dataset of shape (n_sequences, n_frames, h, w), divided by its maximum.

    Raises:
        ValueError: If a movement probability is negative or all are zero.
    """
    import math

    if min(p_random, p_center_attraction, p_edge_attraction) < 0:
        raise ValueError("Movement probabilities must be non-negative.")
    # fsum keeps the default weights summing to exactly 1.0, so the default
    # thresholds (and hence the generated data) are unchanged.
    total = math.fsum((p_random, p_center_attraction, p_edge_attraction))
    if total <= 0:
        raise ValueError("At least one movement probability must be positive.")
    p_center = p_center_attraction / total
    p_edge = p_edge_attraction / total

    np.random.seed(seed)
    sequences = np.random.normal(loc=noise_mean, scale=noise_std, size=dataset_size)
    h, w = sequences.shape[2:]
    assert h == w

    y_coordinates = np.arange(h)
    x_coordinates = np.arange(w)

    def step(coordinate: int) -> int:
        # One uniform draw selects the move: [0, p_center) -> towards center,
        # (1 - p_edge, 1) -> towards edge, otherwise a random +/- step.
        q = np.random.rand()
        if q < p_center:
            return (coordinate + np.sign(h // 2 - coordinate) * step_size) % h
        elif q > 1 - p_edge:
            return (coordinate + np.sign(coordinate - h // 2) * step_size) % h
        else:
            return (coordinate + np.random.choice([-1, 1]) * step_size) % h

    def block_at_coords(
        y: int, x: int
    ) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]:
        # Index grids (rows, cols) for the two halves of the block, wrapped
        # around the frame borders.
        mask_top = np.meshgrid(
            np.arange(y - block_size // 2, y) % h,
            np.arange(x - block_size // 2, x + block_size // 2) % w,
        )
        mask_bottom = np.meshgrid(
            np.arange(y, y + block_size // 2) % h,
            np.arange(x - block_size // 2, x + block_size // 2) % w,
        )
        return mask_bottom, mask_top

    def initial_block() -> (
        Tuple[
            int, int, Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
        ]
    ):
        initial_x = np.random.choice(x_coordinates)
        initial_y = np.random.choice(y_coordinates)
        # NOTE: x is passed as the row coordinate here and below; this is
        # consistent throughout and harmless because h == w.
        return initial_x, initial_y, block_at_coords(initial_x, initial_y)

    for _b in range(n_blocks):
        for i in range(sequences.shape[0]):
            for t in range(sequences.shape[1]):
                if t == 0:
                    x, y, (mask_bottom, mask_top) = initial_block()
                else:
                    x = step(x)
                    y = step(y)
                    mask_bottom, mask_top = block_at_coords(x, y)
                sequences[i, t, mask_bottom[0], mask_bottom[1]] = bottom_lum
                sequences[i, t, mask_top[0], mask_top[1]] = top_lum

    return sequences / sequences.max()

flyvis.utils.dataset_utils.load_moving_mnist

load_moving_mnist(delete_if_exists=False)

Return Moving MNIST dataset.

Parameters:

Name Type Description Default
delete_if_exists bool

If True, delete the dataset if it exists.

False

Returns:

Type Description
ndarray

Dataset of shape (n_sequences, n_frames, h, w)==(10000, 20, 64, 64).

Note

This dataset (0.78GB) will be downloaded if not present. The download is stored in flyvis.root_dir / “mnist_test_seq.npy”.

Source code in flyvis/utils/dataset_utils.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
def load_moving_mnist(delete_if_exists: bool = False) -> np.ndarray:
    """Return Moving MNIST dataset.

    Args:
        delete_if_exists: If True, download the dataset again even if it
            exists, overwriting the local file.

    Returns:
        Dataset of shape (n_sequences, n_frames, h, w)==(10000, 20, 64, 64).

    Raises:
        ValueError: If the file still cannot be loaded after one fresh download.

    Note:
        This dataset (0.78GB) will be downloaded if not present. The download
        is stored in flyvis.root_dir / "mnist_test_seq.npy".
    """
    moving_mnist_path = flyvis.root_dir / "mnist_test_seq.npy"
    moving_mnist_url = (
        "https://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"
    )

    # A broken file triggers at most one fresh download. Previously the
    # function recursed without bound, re-fetching 0.78GB each time.
    max_attempts = 2
    for attempt in range(max_attempts):
        if delete_if_exists or not moving_mnist_path.exists():
            download_url_to_file(moving_mnist_url, moving_mnist_path)
        try:
            sequences = np.load(moving_mnist_path)
            # Stored as (n_frames, n_sequences, h, w); swap to sequences first.
            return np.transpose(sequences, (1, 0, 2, 3)) / 255.0
        except ValueError as e:
            if attempt == max_attempts - 1:
                raise
            print(f"broken file: {e}, restarting download...")
            delete_if_exists = True

flyvis.utils.dataset_utils.get_random_data_split

get_random_data_split(fold, n_samples, n_folds, shuffle=True, seed=0)

Return indices to split the data.

Parameters:

Name Type Description Default
fold int

The fold number.

required
n_samples int

Total number of samples.

required
n_folds int

Total number of folds.

required
shuffle bool

Whether to shuffle the indices.

True
seed int

Seed for shuffling.

0

Returns:

Type Description
Tuple[ndarray, ndarray]

A tuple containing train and validation indices.

Source code in flyvis/utils/dataset_utils.py
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
def get_random_data_split(
    fold: int, n_samples: int, n_folds: int, shuffle: bool = True, seed: int = 0
) -> Tuple[np.ndarray, np.ndarray]:
    """Split ``n_samples`` indices into training and validation sets for a fold.

    Args:
        fold: The fold number.
        n_samples: Total number of samples.
        n_folds: Total number of folds.
        shuffle: Whether to shuffle the indices.
        seed: Seed for shuffling.

    Returns:
        ``(train_indices, val_indices)`` as produced by ``CrossValIndices`` for
        the requested fold.
    """
    splitter = CrossValIndices(
        n_samples=n_samples, folds=n_folds, shuffle=shuffle, seed=seed
    )
    train_indices, val_indices = splitter(fold)
    return train_indices, val_indices

flyvis.utils.df_utils

Functions

flyvis.utils.df_utils.filter_by_column_values

filter_by_column_values(dataframe, column, values)

Return subset of dataframe based on list of values to appear in a column.

Parameters:

Name Type Description Default
dataframe DataFrame

DataFrame with key as column.

required
column str

Column of the dataframe, e.g. type.

required
values Iterable

Types of neurons e.g. R1, T4a, etc.

required

Returns:

Name Type Description
DataFrame DataFrame

Subset of the input dataframe.

Example
filtered_df = filter_by_column_values(df, 'neuron_type', ['R1', 'T4a'])
Source code in flyvis/utils/df_utils.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def filter_by_column_values(
    dataframe: DataFrame, column: str, values: Iterable
) -> DataFrame:
    """
    Return subset of dataframe based on list of values to appear in a column.

    A row is kept when its entry in ``column`` equals the string form of any
    element of ``values``.

    Args:
        dataframe: DataFrame with key as column.
        column: Column of the dataframe, e.g. `type`.
        values: Types of neurons e.g. R1, T4a, etc. Any iterable (list, set,
            generator, ...) is accepted; an empty one yields an empty frame.

    Returns:
        DataFrame: Subset of the input dataframe.

    Example:
        ```python
        filtered_df = filter_by_column_values(df, 'neuron_type', ['R1', 'T4a'])
        ```
    """
    # The previous version eval()'d a hand-built expression. That allowed code
    # injection and broke on quotes, repeated values, and empty or
    # non-indexable iterables. Matching against str(value) keeps its
    # "column == '<value>'" semantics.
    wanted = [str(value) for value in values]
    return dataframe[dataframe[column].isin(wanted)]

flyvis.utils.df_utils.where_dataframe

where_dataframe(arg_df, **kwargs)

Return indices of rows in a DataFrame where conditions are met.

Conditions are passed as keyword arguments.

Parameters:

Name Type Description Default
arg_df DataFrame

Input DataFrame.

required
**kwargs

Keyword arguments representing conditions.

{}

Returns:

Name Type Description
DataFrame DataFrame

Indices of rows where conditions are met.

Example
indices = where_dataframe(df, type='T4a', u=2, v=0)
Note

The dataframe is expected to have columns matching the keyword arguments.

Source code in flyvis/utils/df_utils.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def where_dataframe(arg_df: DataFrame, **kwargs) -> DataFrame:
    """
    Return indices of rows in a DataFrame where conditions are met.

    Conditions are passed as keyword arguments and combined with a logical
    AND. String values are quoted automatically unless they already start
    with a single or double quote.

    Args:
        arg_df: Input DataFrame.
        **kwargs: Keyword arguments representing conditions.

    Returns:
        DataFrame: Indices (a pandas ``Index``) of rows where conditions are
            met. With no conditions, the full index is returned.

    Example:
        ```python
        indices = where_dataframe(df, type='T4a', u=2, v=0)
        ```

    Note:
        The dataframe is expected to have columns matching the keyword arguments.
    """

    def _format_value(value):
        # Pre-quoted strings pass through unchanged. Everything else that is a
        # string goes through repr so embedded quotes/backslashes are escaped.
        # (The old check `not a or b` re-quoted double-quoted strings.)
        if isinstance(value, str) and not value.startswith(("'", '"')):
            return repr(str(value))
        return value

    if not kwargs:
        # query("") raises; no conditions means every row matches.
        return arg_df.index

    query = " & ".join(
        f"{key}=={_format_value(value)}" for key, value in kwargs.items()
    )
    return arg_df.query(query).index

flyvis.utils.hex_utils

Classes

flyvis.utils.hex_utils.Hexal

Hexal representation containing u, v, z coordinates and value.

Attributes:

Name Type Description
u

Coordinate in u principal direction (0 degree axis).

v

Coordinate in v principal direction (60 degree axis).

z

Coordinate in z principal direction (-60 degree axis).

value

‘Hexal’ value.

u_stride

Stride in u-direction.

v_stride

Stride in v-direction.

Source code in flyvis/utils/hex_utils.py
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
class Hexal:
    """Hexal representation containing u, v, z coordinates and value.

    Attributes:
        u: Coordinate in u principal direction (0 degree axis).
        v: Coordinate in v principal direction (60 degree axis).
        z: Coordinate in z principal direction (-60 degree axis).
        value: 'Hexal' value.
        u_stride: Stride in u-direction.
        v_stride: Stride in v-direction.

    Note:
        ``z`` is computed as ``-(u + v)`` at construction only; it is not
        updated if ``u`` or ``v`` are reassigned afterwards.
    """

    def __init__(
        self, u: int, v: int, value: float = np.nan, u_stride: int = 1, v_stride: int = 1
    ):
        self.u = u
        self.v = v
        self.z = -(u + v)
        self.value = value
        self.u_stride = u_stride
        self.v_stride = v_stride

    def __repr__(self):
        return "Hexal(u={}, v={}, value={}, u_stride={}, v_stride={})".format(
            self.u, self.v, self.value, self.u_stride, self.v_stride
        )

    def __eq__(self, other):
        """Compares coordinates (not values).

        Returns a bool for a Hexal and a boolean array for an iterable of
        Hexals. For anything else it returns ``NotImplemented``, so Python
        falls back to its default comparison instead of yielding ``None``.
        """
        if isinstance(other, Hexal):
            return all((self.u == other.u, self.v == other.v))
        elif isinstance(other, Iterable):
            return np.array([self == h for h in other])
        return NotImplemented

    def __add__(self, other):
        """Adds u and v coordinates, while keeping the value of the left hexal.

        Operands that are neither a Hexal nor an iterable get
        ``NotImplemented``, so ``+`` raises ``TypeError`` instead of
        silently returning ``None``.
        """
        if isinstance(other, Hexal):
            return Hexal(self.u + other.u, self.v + other.v, self.value)
        elif isinstance(other, Iterable):
            return np.array([self + h for h in other])
        return NotImplemented

    def __mul__(self, other):
        """Multiplies values, while preserving coordinates."""
        if isinstance(other, Hexal):
            return Hexal(self.u, self.v, self.value * other.value)
        elif isinstance(other, Iterable):
            return np.array([self * h for h in other])
        else:
            return Hexal(self.u, self.v, self.value * other)

    def eq_val(self, other):
        """Compares the values, not the coordinates."""
        if isinstance(other, Hexal):
            return self.value == other.value
        elif isinstance(other, Iterable):
            return np.array([self.eq_val(h) for h in other])

    # ----- Neighbour identification
    # Neighbours are created with value 0 and default strides of 1.

    @property
    def east(self):
        return Hexal(self.u + self.u_stride, self.v, 0)

    @property
    def north_east(self):
        return Hexal(self.u, self.v + self.v_stride, 0)

    @property
    def north_west(self):
        return Hexal(self.u - self.u_stride, self.v + self.v_stride, 0)

    @property
    def west(self):
        return Hexal(self.u - self.u_stride, self.v, 0)

    @property
    def south_west(self):
        return Hexal(self.u, self.v - self.v_stride, 0)

    @property
    def south_east(self):
        return Hexal(self.u + self.u_stride, self.v - self.v_stride, 0)

    def neighbours(self):
        """Returns 6 neighbours sorted CCW, starting from east."""
        return (
            self.east,
            self.north_east,
            self.north_west,
            self.west,
            self.south_west,
            self.south_east,
        )

    def is_neighbour(self, other):
        """Evaluates if other is a neighbour.

        Returns a bool for a Hexal and a boolean array for an iterable.
        """
        neighbours = self.neighbours()
        if isinstance(other, Hexal):
            return other in neighbours
        elif isinstance(other, Iterable):
            # Recurse element-wise (previously called the angle-based
            # `neighbour` lookup by mistake).
            return np.array([self.is_neighbour(h) for h in other])

    @staticmethod
    def unit_directions():
        """Returns the six unit directions."""
        return HexArray(Hexal(0, 0, 0).neighbours())

    def neighbour(self, angle):
        """Returns the two neighbours whose angles are closest to ``angle``."""
        neighbours = np.array(self.neighbours())
        # angle() has no `signed` argument; its default already returns the
        # signed angle in [-pi, pi].
        angles = np.array([h.angle() for h in neighbours])
        # NOTE(review): the modulo is pi, not 2*pi, so opposite directions are
        # treated alike — confirm this is intended.
        distance = (angles - angle) % np.pi
        index = np.argsort(distance)
        return HexArray(neighbours[index[:2]])

    def direction(self, angle):
        neighbours = HexArray(self.neighbour(angle))
        angles = np.array([h.angle() for h in neighbours])
        distance = (angles - angle) % np.pi
        index = np.argsort(distance)
        # NOTE(review): `index` refers to positions in the two-element
        # `neighbours` array, not in unit_directions() — confirm intent.
        return HexArray(self.unit_directions()[index[:2]])

    # ----- Geometric methods

    def interp(self, other, t):
        """Interpolates towards other.

        Args:
            other (Hexal)
            t (float): interpolation step, 0<t<1.

        Returns:
            Hexal: nearest lattice hexal to the interpolated point, value 0.
        """

        def hex_round(u, v):
            # Cube-coordinate rounding: round all axes, then recompute the
            # axis with the largest rounding error from the other two.
            z = -(u + v)
            ru = round(u)
            rv = round(v)
            rz = round(z)
            u_diff = abs(ru - u)
            v_diff = abs(rv - v)
            z_diff = abs(rz - z)
            if u_diff > v_diff and u_diff > z_diff:
                ru = -rv - rz
            elif v_diff > z_diff:
                rv = -ru - rz
            return ru, rv

        uprime, vprime = (
            self.u + (other.u - self.u) * t,
            self.v + (other.v - self.v) * t,
        )
        uprime, vprime = hex_round(uprime, vprime)
        return Hexal(uprime, vprime, 0)

    def angle(self, other=None, non_negative=False):
        """
        Returns the angle to other or the origin.

        Args:
            other (Hexal)
            non_negative (bool): add 2pi if angle is negative.
                Default: False.

        Returns:
            float: angle in radians, in [-pi, pi] unless non_negative is set.
        """

        def _angle(p1, p2):
            """Counter clockwise angle from p1 to p2.

            Returns:
                float: signed angle in [-np.pi, np.pi]
            """
            dot = p1[0] * p2[0] + p1[1] * p2[1]
            det = p1[0] * p2[1] - p1[1] * p2[0]
            angle = np.arctan2(det, dot)
            return angle

        x, y = self._to_pixel(self.u, self.v)
        theta = np.arctan2(y, x)
        if other is not None:
            xother, yother = self._to_pixel(other.u, other.v)
            theta = _angle([x, y], [xother, yother])
        if non_negative:
            theta += 2 * np.pi if theta < 0 else 0
        return theta

    def distance(self, other=None):
        """Returns the columnar distance between two hexals (or to the origin)."""
        if other is not None:
            return int(
                (
                    abs(self.u - other.u)
                    + abs(self.u + self.v - other.u - other.v)
                    + abs(self.v - other.v)
                )
                / 2
            )
        return int((abs(self.u) + abs(self.u + self.v) + abs(self.v)) / 2)

    @staticmethod
    def _to_pixel(u, v, scale=1):
        """Converts to pixel coordinates."""
        return hex_to_pixel(u, v, scale)
__eq__
__eq__(other)

Compares coordinates (not values).

Source code in flyvis/utils/hex_utils.py
364
365
366
367
368
369
def __eq__(self, other):
    """Compares coordinates (not values).

    Returns a bool for a Hexal and a boolean array for an iterable of Hexals.
    For anything else it returns ``NotImplemented``, so Python falls back to
    its default comparison instead of yielding ``None``.
    """
    if isinstance(other, Hexal):
        return all((self.u == other.u, self.v == other.v))
    elif isinstance(other, Iterable):
        return np.array([self == h for h in other])
    return NotImplemented
__add__
__add__(other)

Adds u and v coordinates, while keeping the value of the left hexal.

Source code in flyvis/utils/hex_utils.py
371
372
373
374
375
376
def __add__(self, other):
    """Adds u and v coordinates, while keeping the value of the left hexal.

    Operands that are neither a Hexal nor an iterable get ``NotImplemented``,
    so ``+`` raises ``TypeError`` instead of silently returning ``None``.
    """
    if isinstance(other, Hexal):
        return Hexal(self.u + other.u, self.v + other.v, self.value)
    elif isinstance(other, Iterable):
        return np.array([self + h for h in other])
    return NotImplemented
__mul__
__mul__(other)

Multiplies values, while preserving coordinates.

Source code in flyvis/utils/hex_utils.py
378
379
380
381
382
383
384
385
def __mul__(self, other):
    """Multiply values while keeping this hexal's coordinates.

    ``other`` may be a Hexal (its value is used), an iterable (applied
    element-wise, giving an array), or a plain scalar.
    """
    if isinstance(other, Iterable) and not isinstance(other, Hexal):
        return np.array([self * item for item in other])
    factor = other.value if isinstance(other, Hexal) else other
    return Hexal(self.u, self.v, self.value * factor)
eq_val
eq_val(other)

Compares the values, not the coordinates.

Source code in flyvis/utils/hex_utils.py
387
388
389
390
391
392
def eq_val(self, other):
    """Compare values rather than coordinates.

    Returns a bool for a Hexal, a boolean array for an iterable, and ``None``
    for anything else.
    """
    if isinstance(other, Iterable) and not isinstance(other, Hexal):
        return np.array([self.eq_val(item) for item in other])
    if isinstance(other, Hexal):
        return self.value == other.value
    return None
neighbours
neighbours()

Returns 6 neighbours sorted CCW, starting from east.

Source code in flyvis/utils/hex_utils.py
420
421
422
423
424
425
426
427
428
429
def neighbours(self):
    """Return the six adjacent hexals, counter-clockwise starting from east."""
    ccw_order = (
        "east",
        "north_east",
        "north_west",
        "west",
        "south_west",
        "south_east",
    )
    return tuple(getattr(self, direction) for direction in ccw_order)
is_neighbour
is_neighbour(other)

Evaluates if other is a neighbour.

Source code in flyvis/utils/hex_utils.py
431
432
433
434
435
436
437
def is_neighbour(self, other):
    """Evaluates if other is a neighbour.

    Returns a bool for a Hexal and a boolean array for an iterable.
    """
    neighbours = self.neighbours()
    if isinstance(other, Hexal):
        return other in neighbours
    elif isinstance(other, Iterable):
        # Recurse element-wise (previously called the angle-based
        # `neighbour` lookup by mistake).
        return np.array([self.is_neighbour(h) for h in other])
unit_directions staticmethod
unit_directions()

Returns the six unit directions.

Source code in flyvis/utils/hex_utils.py
439
440
441
442
@staticmethod
def unit_directions():
    """Return the six unit steps around the origin as a HexArray."""
    origin = Hexal(0, 0, 0)
    return HexArray(origin.neighbours())
interp
interp(other, t)

Interpolates towards other.

Parameters:

Name Type Description Default
t float

interpolation step, 0<t<1.

required

Returns:

Type Description

Hexal

Source code in flyvis/utils/hex_utils.py
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
def interp(self, other, t):
    """Interpolate towards ``other`` and snap to the hex lattice.

    Args:
        other (Hexal): target hexal.
        t (float): interpolation step, 0<t<1.

    Returns:
        Hexal: the lattice hexal nearest the interpolated point, with value 0.
    """

    def hex_round(u, v):
        # Cube rounding: round every axis, then rebuild the axis with the
        # largest rounding error from the other two.
        z = -(u + v)
        ru, rv, rz = round(u), round(v), round(z)
        du, dv, dz = abs(ru - u), abs(rv - v), abs(rz - z)
        if du > dv and du > dz:
            return -rv - rz, rv
        if dv > dz:
            return ru, -ru - rz
        return ru, rv

    u_mid = self.u + (other.u - self.u) * t
    v_mid = self.v + (other.v - self.v) * t
    return Hexal(*hex_round(u_mid, v_mid), 0)
angle
angle(other=None, non_negative=False)

Returns the angle to other or the origin.

Parameters:

Name Type Description Default
non_negative bool

add 2pi if angle is negative. Default: False.

False

Returns:

Name Type Description
float

angle in radians.

Source code in flyvis/utils/hex_utils.py
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
def angle(self, other=None, non_negative=False):
    """
    Return the angle of this hexal, or the angle from it to ``other``.

    Without ``other`` this is the polar angle of the hexal's pixel position.
    With ``other`` it is the signed counter-clockwise angle from this hexal's
    pixel vector to ``other``'s.

    Args:
        other (Hexal): optional second hexal.
        non_negative (bool): add 2pi if angle is negative. Default: False.

    Returns:
        float: angle in radians, in [-pi, pi] unless ``non_negative`` is set.
    """
    x, y = self._to_pixel(self.u, self.v)
    if other is None:
        theta = np.arctan2(y, x)
    else:
        x_other, y_other = self._to_pixel(other.u, other.v)
        # arctan2(cross, dot) is the signed angle between the two vectors.
        dot = x * x_other + y * y_other
        cross = x * y_other - y * x_other
        theta = np.arctan2(cross, dot)
    if non_negative:
        theta += 2 * np.pi if theta < 0 else 0
    return theta
distance
distance(other=None)

Returns the columnar distance between two hexals.

Source code in flyvis/utils/hex_utils.py
525
526
527
528
529
530
531
532
533
534
535
536
def distance(self, other=None):
    """Return the columnar (hex-grid) distance between two hexals.

    When ``other`` is omitted, the distance to the origin is returned.
    """
    if other is None:
        other_u, other_v = 0, 0
    else:
        other_u, other_v = other.u, other.v
    summed = (
        abs(self.u - other_u)
        + abs(self.u + self.v - other_u - other_v)
        + abs(self.v - other_v)
    )
    return int(summed / 2)

flyvis.utils.hex_utils.HexArray

Bases: ndarray

Flat array holding Hexals as elements.

Can be constructed with

HexArray(hexals: Iterable, values: Optional[np.nan]) or HexArray(u: Iterable, v: Iterable, values: Optional[np.nan])

Source code in flyvis/utils/hex_utils.py
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
class HexArray(np.ndarray):
    """Flat array holding Hexals as elements.

    Can be constructed with:
        HexArray(hexals: Iterable, values: Optional[np.nan])
        HexArray(u: Iterable, v: Iterable, values: Optional[np.nan])

    Elements are ordered by ``u`` and then ``v``. Each value stays attached to
    the coordinates it was supplied with.
    """

    def __new__(cls, hexals=None, u=None, v=None, values=0):
        if isinstance(hexals, Iterable):
            u = np.array([h.u for h in hexals])
            v = np.array([h.v for h in hexals])
            values = np.array([h.value for h in hexals])
        if not isinstance(values, Iterable):
            values = np.ones_like(u) * values
        # Apply one permutation to coordinates AND values. Sorting only u and
        # v (as before) detached values from their positions.
        sort_index = np.lexsort((v, u))
        u = np.asarray(u)[sort_index]
        v = np.asarray(v)[sort_index]
        values = list(values)
        values = [values[i] for i in sort_index]
        hexals = np.array(
            [Hexal(_u, _v, _val) for _u, _v, _val in zip(u, v, values)],
            dtype=Hexal,
        ).view(cls)
        return hexals

    def __array_finalize__(self, obj):
        if obj is None:
            return

    def __eq__(self, other):
        if isinstance(other, Hexal):
            return other == self
        else:
            return super().__eq__(other)

    def __getitem__(self, key):
        if isinstance(key, HexArray):
            mask = self.where_hexarray(key)
            return self[mask]
        else:
            return super().__getitem__(key)

    def __setitem__(self, key, value):
        if isinstance(key, slice) and key == slice(None):
            # `arr[:] = x` sets values; a scalar is broadcast to every hexal
            # (previously zip() failed on a non-iterable).
            if isinstance(value, Iterable):
                self.values = value
            else:
                self.fill(value)
        elif isinstance(key, HexArray):
            mask = self.where_hexarray(key)
            super().__setitem__(mask, value)
        else:
            super().__setitem__(key, value)

    def where_hexarray(self, hexarray):
        return matrix_mask_by_sub(
            np.stack((hexarray.u, hexarray.v), axis=0).T,
            np.stack((self.u, self.v), axis=0).T,
        )

    @staticmethod
    def sort(u, v):
        """Returns ``u`` and ``v`` as arrays sorted by ``u`` first, then ``v``."""
        sort_index = np.lexsort((v, u))
        return np.asarray(u)[sort_index], np.asarray(v)[sort_index]

    @staticmethod
    def get_extent(hexals=None, u=None, v=None, center=Hexal(0, 0, 0)):
        """Returns the columnar extent."""
        from numbers import Number

        if isinstance(u, Number) and isinstance(v, Number):
            h = Hexal(u, v, 0)
            return h.distance(center)
        else:
            ha = HexArray(hexals, u, v)
            distance = max([h.distance(center) for h in ha])
            return distance

    @property
    def u(self):
        return np.array([h.u for h in self])

    @property
    def v(self):
        return np.array([h.v for h in self])

    @property
    def values(self):
        return np.array([h.value for h in self])

    @values.setter
    def values(self, values):
        for h, val in zip(self, values):
            h.value = val

    @property
    def extent(self):
        """Columnar extent measured from the origin hexal (0, 0)."""
        # super() is np.ndarray, which has no get_extent; use our own.
        return self.get_extent(hexals=self)

    def with_stride(self, u_stride=None, v_stride=None):
        """Returns a sliced instance obeying strides in u- and v-direction.

        A stride of ``None`` places no constraint on that direction.
        """
        u_stride = 1 if u_stride is None else u_stride
        v_stride = 1 if v_stride is None else v_stride
        mask = np.array(
            [u % u_stride == 0 and v % v_stride == 0 for u, v in zip(self.u, self.v)],
            dtype=bool,
        )
        return self[mask]

    def where(self, value):
        """Returns a mask of where values are equal to the given one.

        Note: value can be np.nan.
        """
        return np.isclose(self.values, value, rtol=0, atol=0, equal_nan=True)

    def fill(self, value):
        """Fills the values with the given one."""
        for h in self:
            h.value = value

    def to_pixel(self, scale=1, mode="default"):
        """Converts to pixel coordinates."""
        return hex_to_pixel(self.u, self.v, scale, mode=mode)

    def plot(self, figsize=[3, 3], fill=True):
        """Plots values in regular hexagonal lattice.

        Meant for debugging.
        """
        # cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9;
        # registered colormaps are also exposed as attributes of cm.
        return flyvis.plots.hex_scatter(
            self.u,
            self.v,
            self.values,
            fill=fill,
            cmap=cm.binary,
            edgecolor="black",
            figsize=figsize,
        )
get_extent staticmethod
get_extent(hexals=None, u=None, v=None, center=Hexal(0, 0, 0))

Returns the columnar extent.

Source code in flyvis/utils/hex_utils.py
605
606
607
608
609
610
611
612
613
614
615
616
@staticmethod
def get_extent(hexals=None, u=None, v=None, center=Hexal(0, 0, 0)):
    """Return the columnar extent: the largest distance from ``center``.

    A single hexal may be given as scalar ``u`` and ``v``. Otherwise ``hexals``
    or coordinate iterables ``u``/``v`` are used to build a HexArray.
    """
    from numbers import Number

    single_hexal = isinstance(u, Number) and isinstance(v, Number)
    if single_hexal:
        return Hexal(u, v, 0).distance(center)
    return max(h.distance(center) for h in HexArray(hexals, u, v))
with_stride
with_stride(u_stride=None, v_stride=None)

Returns a sliced instance obeying strides in u- and v-direction.

Source code in flyvis/utils/hex_utils.py
639
640
641
642
643
644
def with_stride(self, u_stride=None, v_stride=None):
    """Returns a sliced instance obeying strides in u- and v-direction.

    A stride of ``None`` (the default) places no constraint on that direction;
    previously it raised ``TypeError`` from ``u % None``.
    """
    u_stride = 1 if u_stride is None else u_stride
    v_stride = 1 if v_stride is None else v_stride
    # dtype=bool keeps an empty mask usable as a boolean index.
    mask = np.array(
        [u % u_stride == 0 and v % v_stride == 0 for u, v in zip(self.u, self.v)],
        dtype=bool,
    )
    return self[mask]
where
where(value)

Returns a mask of where values are equal to the given one.

Note: value can be np.nan.

Source code in flyvis/utils/hex_utils.py
646
647
648
649
650
651
def where(self, value):
    """Return a boolean mask marking elements whose value equals ``value``.

    Note: value can be np.nan, in which case NaN entries match.
    """
    current_values = self.values
    # Zero tolerances make isclose an exact comparison that also equates NaNs.
    return np.isclose(current_values, value, equal_nan=True, atol=0, rtol=0)
fill
fill(value)

Fills the values with the given one.

Source code in flyvis/utils/hex_utils.py
653
654
655
656
def fill(self, value):
    """Set the value of every contained hexal, in place, to ``value``."""
    for hexal in iter(self):
        setattr(hexal, "value", value)
to_pixel
to_pixel(scale=1, mode='default')

Converts to pixel coordinates.

Source code in flyvis/utils/hex_utils.py
658
659
660
def to_pixel(self, scale=1, mode="default"):
    """Convert the array's hex coordinates to pixel coordinates."""
    u_coords = self.u
    v_coords = self.v
    return hex_to_pixel(u_coords, v_coords, scale, mode=mode)
plot
plot(figsize=[3, 3], fill=True)

Plots values in regular hexagonal lattice.

Meant for debugging.

Source code in flyvis/utils/hex_utils.py
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
def plot(self, figsize=[3, 3], fill=True):
    """Plots values in regular hexagonal lattice.

    Meant for debugging.
    """
    # cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9;
    # registered colormaps are also exposed as attributes of cm.
    return flyvis.plots.hex_scatter(
        self.u,
        self.v,
        self.values,
        fill=fill,
        cmap=cm.binary,
        edgecolor="black",
        figsize=figsize,
    )

flyvis.utils.hex_utils.HexLattice

Bases: HexArray

Flat array of Hexals.

Parameters:

Name Type Description Default
extent

Extent of the regular hexagon grid.

required
hexals

Existing hexals to initialize with.

required
center

Center hexal of the lattice.

required
u_stride

Stride in u-direction.

required
v_stride

Stride in v-direction.

required
Source code in flyvis/utils/hex_utils.py
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
class HexLattice(HexArray):
    """Flat array of Hexals laid out on a regular hexagonal lattice.

    Args:
        extent: Extent of the regular hexagon grid.
        hexals: Existing hexals to initialize with. If iterable, a lattice of
            ``extent`` is built and every lattice hexal that also occurs in
            ``hexals`` takes over its value.
        center: Center hexal of the lattice.
        u_stride: Stride in u-direction.
        v_stride: Stride in v-direction.
    """

    def __new__(
        cls,
        extent=15,
        hexals=None,
        center=Hexal(0, 0, 0),
        u_stride=1,
        v_stride=1,
    ):
        if isinstance(hexals, Iterable):
            hexals = HexArray(hexals=hexals)
            # The extent spanned by the given hexals is used only when no truthy
            # extent was passed explicitly.
            extent = extent or super().get_extent(hexals, center=center)
            lattice = HexLattice(
                extent=extent,
                center=center,
                u_stride=u_stride,
                v_stride=v_stride,
            )
            # Copy values of the given hexals onto matching lattice positions.
            for h in lattice:
                if h in hexals:
                    h.value = hexals[h == hexals][0].value
        else:
            u, v = flyvis.utils.hex_utils.get_hex_coords(extent)
            u += center.u
            v += center.v
            # Keep only positions on the (u_stride, v_stride) sub-lattice;
            # every value starts out as NaN.
            lattice = [
                Hexal(_u, _v, np.nan, u_stride, v_stride)
                for _u, _v in zip(u, v)
                if _u % u_stride == 0 and _v % v_stride == 0
            ]
            lattice = np.array(lattice, dtype=Hexal).view(cls)
        return lattice

    @property
    def center(self):
        """Middle element of the flat array (the center hexal for a full lattice)."""
        return self[len(self) // 2]

    @property
    def extent(self):
        """Extent of the lattice measured from its ``center``."""
        return super().get_extent(self, center=self.center)

    # ----- Geometry

    def circle(self, radius=None, center=Hexal(0, 0, 0), as_lattice=False):
        """Draws a circle (ring of hexals at exactly ``radius``) in hex coordinates.

        Args:
            radius: Radius in columns of the circle. Falls back to
                ``self.extent`` when falsy.
            center: Center of the circle.
            as_lattice: Returns the circle on a constrained regular lattice.

        Returns:
            HexArray (or HexLattice) of the ring hexals, each with value 1.
        """
        lattice = HexLattice(extent=max(radius or 0, self.extent), center=center)
        radius = radius or self.extent
        circle = []
        for h in lattice:
            if center.distance(h) == radius:
                h.value = 1
                circle.append(h)
        if as_lattice:
            return HexLattice(hexals=circle)
        return HexArray(hexals=circle)

    @staticmethod
    def filled_circle(radius=None, center=Hexal(0, 0, 0), as_lattice=False):
        """Draws a filled circle (all hexals within ``radius``) in hex coordinates.

        Args:
            radius: Radius in columns of the circle. ``None`` is treated as 0,
                i.e. only the center hexal is returned.
            center: Center of the circle.
            as_lattice: Returns the circle on a constrained regular lattice.

        Returns:
            HexArray (or HexLattice) of the hexals within ``radius``, each
            with value 1.
        """
        # Previously `radius=None` built an extent-0 lattice and then failed on
        # `distance <= None` with a TypeError.
        radius = radius or 0
        lattice = HexLattice(extent=radius, center=center)
        circle = []
        for h in lattice:
            if center.distance(h) <= radius:
                h.value = 1
                circle.append(h)
        if as_lattice:
            return HexLattice(hexals=circle)
        return HexArray(hexals=circle)

    def hull(self):
        """Returns the hull of the regular lattice."""
        return self.circle(radius=self.extent, center=self.center)

    def _line_span(self, angle):
        """Returns two points spanning a line with given angle wrt. origin.

        Args:
            angle: In [0, np.pi]

        Returns:
            HexArray
        """
        # To offset the line by simple addition of the offset,
        # radius=2 * self.extent spans the line in ways that each valid offset
        # can be added.
        distant_hull = self.ring(radius=2 * self.extent)
        angles = np.array([h.angle(signed=True) for h in distant_hull])
        # Angular deviation modulo pi: both directions of the line count.
        distance = (angles - angle) % np.pi
        index = np.argsort(distance)
        span = distant_hull[index[0:2]]
        for h in span:
            h.value = 1
        return HexArray(hexals=span)

    def line(self, angle, center=Hexal(0, 0, 1), as_lattice=False):
        """Returns a line on a HexLattice or HexArray.

        Args:
            angle: In [0, np.pi]
            center: Midpoint of the line
            as_lattice: Returns the line on a constrained regular lattice.

        Returns:
            HexArray or constrained HexLattice
        """
        line_span = self._line_span(angle)
        distance = line_span[0].distance(line_span[1])
        # Sample distance + 1 evenly spaced points from one span end to the other.
        line = [
            line_span[0].interp(line_span[1], 1 / distance * i)
            for i in range(distance + 1)
        ]
        for h in line:
            h.value = 1
        if as_lattice:
            return HexLattice(extent=self.extent, hexals=center + line)
        return HexArray(hexals=center + line)

    def _get_neighbour_indices(self, index):
        """Indices into ``self`` of the neighbours of ``self[index]`` present here."""
        neighbours = ()
        for n in self[index].neighbours():
            valid = self == n
            if valid.any():
                neighbours += (np.where(valid)[0][0],)
        return neighbours

    def valid_neighbours(self):
        """Tuple with, per hexal, the tuple of its in-lattice neighbour indices."""
        return tuple(self._get_neighbour_indices(i) for i in range(len(self)))
circle
circle(radius=None, center=Hexal(0, 0, 0), as_lattice=False)

Draws a circle in hex coordinates.

Parameters:

Name Type Description Default
radius

Radius in columns of the circle.

None
center

Center of the circle.

Hexal(0, 0, 0)
as_lattice

Returns the circle on a constrained regular lattice.

False
Source code in flyvis/utils/hex_utils.py
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
def circle(self, radius=None, center=Hexal(0, 0, 0), as_lattice=False):
    """Collect the hexals lying exactly ``radius`` steps away from ``center``.

    Args:
        radius: Radius in columns of the circle; ``self.extent`` is used when
            it is falsy.
        center: Center of the circle.
        as_lattice: If True, wrap the result in a ``HexLattice`` instead of a
            ``HexArray``.

    Returns:
        HexArray (or HexLattice) of the ring hexals, each with value 1.
    """
    search_extent = max(radius or 0, self.extent)
    candidates = HexLattice(extent=search_extent, center=center)
    target = radius or self.extent
    ring = []
    for hexal in candidates:
        if center.distance(hexal) != target:
            continue
        hexal.value = 1
        ring.append(hexal)
    wrapper = HexLattice if as_lattice else HexArray
    return wrapper(hexals=ring)
filled_circle staticmethod
filled_circle(radius=None, center=Hexal(0, 0, 0), as_lattice=False)

Draws a filled circle in hex coordinates.

Parameters:

Name Type Description Default
radius

Radius in columns of the circle.

None
center

Center of the circle.

Hexal(0, 0, 0)
as_lattice

Returns the circle on a constrained regular lattice.

False
Source code in flyvis/utils/hex_utils.py
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
@staticmethod
def filled_circle(radius=None, center=Hexal(0, 0, 0), as_lattice=False):
    """Draws a filled circle (all hexals within ``radius``) in hex coordinates.

    Args:
        radius: Radius in columns of the circle. ``None`` is treated as 0,
            i.e. only the center hexal is returned.
        center: Center of the circle.
        as_lattice: Returns the circle on a constrained regular lattice.

    Returns:
        HexArray (or HexLattice) of the hexals within ``radius``, each with
        value 1.
    """
    # Previously `radius=None` built an extent-0 lattice and then failed on
    # `distance <= None` with a TypeError.
    radius = radius or 0
    lattice = HexLattice(extent=radius, center=center)
    circle = []
    for h in lattice:
        if center.distance(h) <= radius:
            h.value = 1
            circle.append(h)
    if as_lattice:
        return HexLattice(hexals=circle)
    return HexArray(hexals=circle)
hull
hull()

Returns the hull of the regular lattice.

Source code in flyvis/utils/hex_utils.py
777
778
779
def hull(self):
    """Return the outermost ring of the regular lattice (its hull)."""
    outer_radius = self.extent
    return self.circle(radius=outer_radius, center=self.center)
line
line(angle, center=Hexal(0, 0, 1), as_lattice=False)

Returns a line on a HexLattice or HexArray.

Parameters:

Name Type Description Default
angle

In [0, np.pi]

required
center

Midpoint of the line

Hexal(0, 0, 1)
as_lattice

Returns the line on a constrained regular lattice.

False

Returns:

Type Description

HexArray or constrained HexLattice

Source code in flyvis/utils/hex_utils.py
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
def line(self, angle, center=Hexal(0, 0, 1), as_lattice=False):
    """Return the hexals of a straight line with the given angle.

    Args:
        angle: In [0, np.pi]
        center: Added to the sampled points (``center + points``); acts as the
            midpoint of the line.
        as_lattice: Returns the line on a constrained regular lattice.

    Returns:
        HexArray or constrained HexLattice
    """
    span = self._line_span(angle)
    start, stop = span[0], span[1]
    n_steps = start.distance(stop)
    # Sample n_steps + 1 evenly spaced points from start to stop (inclusive).
    points = [start.interp(stop, 1 / n_steps * step) for step in range(n_steps + 1)]
    for point in points:
        point.value = 1
    if not as_lattice:
        return HexArray(hexals=center + points)
    return HexLattice(extent=self.extent, hexals=center + points)

flyvis.utils.hex_utils.LatticeMask

Boolean masks for lattice dimension.

Parameters:

Name Type Description Default
extent int

Extent of the hexagonal lattice.

15
u_stride int

Stride in u-direction.

1
v_stride int

Stride in v-direction.

1
Source code in flyvis/utils/hex_utils.py
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
class LatticeMask:
    """Boolean masks over the lattice dimension.

    Each property compares one hexal (the lattice center or one of its six
    neighbours) against every hexal of an internal ``HexLattice`` and returns
    the element-wise comparison.

    Args:
        extent: Extent of the hexagonal lattice.
        u_stride: Stride in u-direction.
        v_stride: Stride in v-direction.
    """

    def __init__(self, extent: int = 15, u_stride: int = 1, v_stride: int = 1):
        lattice = HexLattice(extent=extent, u_stride=u_stride, v_stride=v_stride)
        self._lattice = lattice

    def _mask_for(self, hexal):
        """Element-wise ``hexal == lattice`` over the internal lattice."""
        return hexal == self._lattice

    @property
    def center(self):
        return self._mask_for(self._lattice.center)

    @property
    def center_east(self):
        return self._mask_for(self._lattice.center.east)

    @property
    def center_north_east(self):
        return self._mask_for(self._lattice.center.north_east)

    @property
    def center_north_west(self):
        return self._mask_for(self._lattice.center.north_west)

    @property
    def center_west(self):
        return self._mask_for(self._lattice.center.west)

    @property
    def center_south_west(self):
        return self._mask_for(self._lattice.center.south_west)

    @property
    def center_south_east(self):
        return self._mask_for(self._lattice.center.south_east)

Functions

flyvis.utils.hex_utils.get_hex_coords

get_hex_coords(extent, astensor=False)

Construct hexagonal coordinates for a regular hex-lattice with extent.

Parameters:

Name Type Description Default
extent int

Integer radius of hexagonal lattice. 0 returns the single center coordinate.

required
astensor bool

If True, returns torch.Tensor, else np.array.

False

Returns:

Type Description
Tuple[NDArray, NDArray]

A tuple containing: u: Hex-coordinates in u-direction. v: Hex-coordinates in v-direction.

Note

Will return get_num_hexals(extent) coordinates.

See Also

https://www.redblobgames.com/grids/hexagons/#range-coordinate

Source code in flyvis/utils/hex_utils.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def get_hex_coords(extent: int, astensor: bool = False) -> Tuple[NDArray, NDArray]:
    """Build the axial coordinates of every hexal in a regular hex-lattice.

    Args:
        extent: Integer radius of hexagonal lattice. 0 returns the single
            center coordinate.
        astensor: If True, returns torch.Tensor, else np.array.

    Returns:
        A tuple containing:
            u: Hex-coordinates in u-direction.
            v: Hex-coordinates in v-direction.

    Note:
        Will return `get_num_hexals(extent)` coordinates, ordered by u first
        and v second.

    See Also:
        https://www.redblobgames.com/grids/hexagons/#range-coordinate
    """
    # For every column q, the admissible rows r satisfy |r| <= extent and
    # |q + r| <= extent.
    coords = [
        (q, r)
        for q in range(-extent, extent + 1)
        for r in range(max(-extent, -extent - q), min(extent, extent - q) + 1)
    ]
    u = [q for q, _ in coords]
    v = [r for _, r in coords]
    if not astensor:
        return np.array(u), np.array(v)
    return torch.tensor(u, dtype=torch.long), torch.tensor(v, dtype=torch.long)

flyvis.utils.hex_utils.hex_to_pixel

hex_to_pixel(u, v, size=1, mode='default')

Returns pixel coordinates from hex coordinates.

Parameters:

Name Type Description Default
u NDArray

Hex-coordinates in u-direction.

required
v NDArray

Hex-coordinates in v-direction.

required
size float

Size of hexagon.

1
mode Literal['default', 'flat', 'pointy']

Coordinate system convention.

'default'

Returns:

Type Description
Tuple[NDArray, NDArray]

A tuple containing: x: Pixel-coordinates in x-direction. y: Pixel-coordinates in y-direction.

See Also

https://www.redblobgames.com/grids/hexagons/#hex-to-pixel

Source code in flyvis/utils/hex_utils.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def hex_to_pixel(
    u: NDArray,
    v: NDArray,
    size: float = 1,
    mode: Literal["default", "flat", "pointy"] = "default",
) -> Tuple[NDArray, NDArray]:
    """Returns pixel coordinates from hex coordinates.

    Args:
        u: Hex-coordinates in u-direction. If both ``u`` and ``v`` are plain
            lists they are converted to numpy arrays.
        v: Hex-coordinates in v-direction.
        size: Size of hexagon; scales the output in every mode.
        mode: Coordinate system convention.

    Returns:
        A tuple containing:
            x: Pixel-coordinates in x-direction.
            y: Pixel-coordinates in y-direction.

    Raises:
        ValueError: If ``mode`` is not one of the supported conventions.

    See Also:
        https://www.redblobgames.com/grids/hexagons/#hex-to-pixel
    """
    if isinstance(u, list) and isinstance(v, list):
        u = np.array(u)
        v = np.array(v)
    if mode == "default":
        # `size` used to be ignored here, unlike in "flat" and "pointy"; for the
        # default size=1 the result is unchanged.
        return (3 / 2 * v) * size, (-np.sqrt(3) * (u + v / 2)) * size
    elif mode == "flat":
        return (3 / 2 * u) * size, (np.sqrt(3) / 2 * u + np.sqrt(3) * v) * size
    elif mode == "pointy":
        return (np.sqrt(3) * u + np.sqrt(3) / 2 * v) * size, (3 / 2 * v) * size
    else:
        raise ValueError(f"{mode} not recognized.")

flyvis.utils.hex_utils.hex_rows

hex_rows(n_rows, n_columns, eps=0.1, mode='pointy')

Return a hex grid in pixel coordinates.

Parameters:

Name Type Description Default
n_rows int

Number of rows.

required
n_columns int

Number of columns.

required
eps float

Small offset to avoid overlapping hexagons.

0.1
mode Literal['pointy', 'flat']

Orientation of hexagons.

'pointy'

Returns:

Type Description
Tuple[NDArray, NDArray]

A tuple containing: x: X-coordinates of hexagon centers. y: Y-coordinates of hexagon centers.

Source code in flyvis/utils/hex_utils.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
def hex_rows(
    n_rows: int,
    n_columns: int,
    eps: float = 0.1,
    mode: Literal["pointy", "flat"] = "pointy",
) -> Tuple[NDArray, NDArray]:
    """Lay out an ``n_rows`` x ``n_columns`` block of hexagons in pixel space.

    Args:
        n_rows: Number of rows.
        n_columns: Number of columns.
        eps: Small offset to avoid overlapping hexagons.
        mode: Orientation of hexagons.

    Returns:
        A tuple containing:
            x: X-coordinates of hexagon centers.
            y: Y-coordinates of hexagon centers.
    """
    # Column index becomes u and row index becomes v, row-major order.
    grid = [(col, row) for row in range(n_rows) for col in range(n_columns)]
    u = np.array([col for col, _ in grid])
    v = np.array([row for _, row in grid])
    x, y = hex_to_pixel(u, v, mode=mode)
    x += eps
    y += eps
    return x, y

flyvis.utils.hex_utils.pixel_to_hex

pixel_to_hex(x, y, size=1, mode='default')

Returns hex coordinates from pixel coordinates.

Parameters:

Name Type Description Default
x NDArray

Pixel-coordinates in x-direction.

required
y NDArray

Pixel-coordinates in y-direction.

required
size float

Size of hexagon.

1
mode Literal['default', 'flat', 'pointy']

Coordinate system convention.

'default'

Returns:

Type Description
Tuple[NDArray, NDArray]

A tuple containing: u: Hex-coordinates in u-direction. v: Hex-coordinates in v-direction.

See Also

https://www.redblobgames.com/grids/hexagons/#hex-to-pixel

Source code in flyvis/utils/hex_utils.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def pixel_to_hex(
    x: NDArray,
    y: NDArray,
    size: float = 1,
    mode: Literal["default", "flat", "pointy"] = "default",
) -> Tuple[NDArray, NDArray]:
    """Returns hex coordinates from pixel coordinates.

    Inverse of ``hex_to_pixel`` for the same ``size`` and ``mode``.

    Args:
        x: Pixel-coordinates in x-direction. If both ``x`` and ``y`` are plain
            lists they are converted to numpy arrays.
        y: Pixel-coordinates in y-direction.
        size: Size of hexagon; the output is divided by it in every mode.
        mode: Coordinate system convention.

    Returns:
        A tuple containing:
            u: Hex-coordinates in u-direction.
            v: Hex-coordinates in v-direction.

    Raises:
        ValueError: If ``mode`` is not one of the supported conventions.

    See Also:
        https://www.redblobgames.com/grids/hexagons/#hex-to-pixel
    """
    # Accept plain lists, matching hex_to_pixel (previously `-x` on a list raised).
    if isinstance(x, list) and isinstance(y, list):
        x = np.array(x)
        y = np.array(y)
    if mode == "default":
        # `size` used to be ignored here, unlike in "flat" and "pointy"; for the
        # default size=1 the result is unchanged.
        return (-x / 3 - y / np.sqrt(3)) / size, (2 / 3 * x) / size
    elif mode == "flat":
        return (2 / 3 * x) / size, (-1 / 3 * x + np.sqrt(3) / 3 * y) / size
    elif mode == "pointy":
        return (np.sqrt(3) / 3 * x - 1 / 3 * y) / size, (2 / 3 * y) / size
    else:
        raise ValueError(f"{mode} not recognized.")

flyvis.utils.hex_utils.pad_to_regular_hex

pad_to_regular_hex(u, v, values, extent, value=np.nan)

Pad hexals with coordinates to a regular hex lattice.

Parameters:

Name Type Description Default
u NDArray

U-coordinate of hexal.

required
v NDArray

V-coordinate of hexal.

required
values NDArray

Value of hexal with arbitrary shape but last axis must match the hexal dimension.

required
extent int

Extent of regular hex grid to pad to.

required
value float

The pad value.

nan

Returns:

Type Description
Tuple[NDArray, NDArray, NDArray]

A tuple containing: u_padded: Padded u-coordinate. v_padded: Padded v-coordinate. values_padded: Padded value.

Note

The canonical use case here is to pad a filter, receptive field, or postsynaptic current field for visualization.

Example
u = np.array([1, 0, -1, 0, 1, 2])
v = np.array([-2, -1, 0, 0, 0, 0])
values = np.array([0.05, 0.1, 0.3, 0.5, 0.7, 0.9])
hexals = pad_to_regular_hex(u, v, values, 6)
hex_scatter(*hexals, edgecolor='k', cmap=plt.cm.Blues, vmin=0, vmax=1)
Source code in flyvis/utils/hex_utils.py
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
def pad_to_regular_hex(
    u: NDArray,
    v: NDArray,
    values: NDArray,
    extent: int,
    value: float = np.nan,
) -> Tuple[NDArray, NDArray, NDArray]:
    """Embed scattered hexal values into a full regular hex lattice.

    Args:
        u: U-coordinate of hexal.
        v: V-coordinate of hexal.
        values: Value of hexal with arbitrary shape but last axis
            must match the hexal dimension.
        extent: Extent of regular hex grid to pad to.
        value: The pad value.

    Returns:
        A tuple containing:
            u_padded: Padded u-coordinate.
            v_padded: Padded v-coordinate.
            values_padded: Padded value.

    Note:
        The canonical use case here is to pad a filter, receptive field, or
        postsynaptic current field for visualization.

    Example:
        ```python
        u = np.array([1, 0, -1, 0, 1, 2])
        v = np.array([-2, -1, 0, 0, 0, 0])
        values = np.array([0.05, 0.1, 0.3, 0.5, 0.7, 0.9])
        hexals = pad_to_regular_hex(u, v, values, 6)
        hex_scatter(*hexals, edgecolor='k', cmap=plt.cm.Blues, vmin=0, vmax=1)
        ```
    """
    u_padded, v_padded = flyvis.utils.hex_utils.get_hex_coords(extent)
    # All axes except the last are carried over unchanged.
    leading_shape = values.shape[:-1] if len(values.shape) > 1 else ()
    values_padded = np.ones([*leading_shape, len(u_padded)]) * value
    # Position of each given (u, v) pair inside the regular lattice.
    index = flyvis.utils.tensor_utils.where_equal_rows(
        np.stack((u, v), axis=1), np.stack((u_padded, v_padded), axis=1)
    )
    target = (slice(None),) * len(leading_shape) + (index,)
    values_padded[target] = values
    return u_padded, v_padded, values_padded

flyvis.utils.hex_utils.max_extent_index

max_extent_index(u, v, max_extent)

Returns a mask to constrain u and v axial-hex-coordinates by max_extent.

Parameters:

Name Type Description Default
u NDArray

Hex-coordinates in u-direction.

required
v NDArray

Hex-coordinates in v-direction.

required
max_extent int

Maximal extent.

required

Returns:

Type Description
NDArray

Boolean mask.

Source code in flyvis/utils/hex_utils.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def max_extent_index(u: NDArray, v: NDArray, max_extent: int) -> NDArray:
    """Mask the axial hex coordinates that lie within ``max_extent`` of the origin.

    Args:
        u: Hex-coordinates in u-direction.
        v: Hex-coordinates in v-direction.
        max_extent: Maximal extent.

    Returns:
        Boolean mask.
    """
    w = u + v  # the third cube coordinate, up to sign
    within_u = (-max_extent <= u) & (u <= max_extent)
    within_v = (-max_extent <= v) & (v <= max_extent)
    within_w = (-max_extent <= w) & (w <= max_extent)
    return within_u & within_v & within_w

flyvis.utils.hex_utils.get_num_hexals

get_num_hexals(extent)

Returns the absolute number of hexals in a hexagonal grid with extent.

Parameters:

Name Type Description Default
extent int

Extent of hex-lattice.

required

Returns:

Type Description
int

Number of hexals.

Note

Inverse of get_hextent.

Source code in flyvis/utils/hex_utils.py
218
219
220
221
222
223
224
225
226
227
228
229
230
def get_num_hexals(extent: int) -> int:
    """Count the hexals in a regular hexagonal grid of the given extent.

    Args:
        extent: Extent of hex-lattice.

    Returns:
        Number of hexals.

    Note:
        Inverse of get_hextent.
    """
    # Ring k holds 6 * k hexals, so rings 1..extent add 6 * (1 + ... + extent).
    outer_rings = 3 * extent * (extent + 1)
    return 1 + outer_rings

flyvis.utils.hex_utils.get_hextent

get_hextent(num_hexals)

Computes the hex-lattice extent from the number of hexals.

Parameters:

Name Type Description Default
num_hexals int

Number of hexals.

required

Returns:

Type Description
int

Extent of hex-lattice.

Note

Inverse of get_num_hexals.

Source code in flyvis/utils/hex_utils.py
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def get_hextent(num_hexals: int) -> int:
    """Recover the hex-lattice extent from its total number of hexals.

    Args:
        num_hexals: Number of hexals.

    Returns:
        Extent of hex-lattice.

    Note:
        Inverse of get_num_hexals.
    """
    # n = 1 + 3e(e + 1) lies strictly between 3e^2 and 3(e + 1)^2, so
    # floor(sqrt(n / 3)) recovers e.
    root = np.sqrt(num_hexals / 3)
    return np.floor(root).astype("int")

flyvis.utils.hex_utils.sort_u_then_v

sort_u_then_v(u, v, values)

Sorts u, v, and values by u and then v.

Parameters:

Name Type Description Default
u NDArray

U-coordinate of hexal.

required
v NDArray

V-coordinate of hexal.

required
values NDArray

Value of hexal.

required

Returns:

Type Description
Tuple[NDArray, NDArray, NDArray]

A tuple containing: u: Sorted u-coordinate of hexal. v: Sorted v-coordinate of hexal. values: Sorted value of hexal.

Source code in flyvis/utils/hex_utils.py
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
def sort_u_then_v(
    u: NDArray, v: NDArray, values: NDArray
) -> Tuple[NDArray, NDArray, NDArray]:
    """Reorder u, v, and values by ascending u, breaking ties by ascending v.

    Args:
        u: U-coordinate of hexal.
        v: V-coordinate of hexal.
        values: Value of hexal.

    Returns:
        A tuple containing:
            u: Sorted u-coordinate of hexal.
            v: Sorted v-coordinate of hexal.
            values: Sorted value of hexal.
    """
    # np.lexsort treats the *last* key as primary, so u goes last.
    order = np.lexsort((v, u))
    sorted_u = u[order]
    sorted_v = v[order]
    sorted_values = values[order]
    return sorted_u, sorted_v, sorted_values

flyvis.utils.hex_utils.sort_u_then_v_index

sort_u_then_v_index(u, v)

Index to sort u, v by u and then v.

Parameters:

Name Type Description Default
u NDArray

U-coordinate of hexal.

required
v NDArray

V-coordinate of hexal.

required

Returns:

Type Description
NDArray

Index to sort u and v.

Source code in flyvis/utils/hex_utils.py
269
270
271
272
273
274
275
276
277
278
279
def sort_u_then_v_index(u: NDArray, v: NDArray) -> NDArray:
    """Return the permutation that orders (u, v) by u first, then by v.

    Args:
        u: U-coordinate of hexal.
        v: V-coordinate of hexal.

    Returns:
        Index to sort u and v.
    """
    primary, secondary = u, v
    # np.lexsort sorts by the last key first.
    return np.lexsort((secondary, primary))

flyvis.utils.hex_utils.get_extent

get_extent(u, v, astype=int)

Returns extent (integer distance to origin) of arbitrary u, v coordinates.

Parameters:

Name Type Description Default
u NDArray

U-coordinate of hexal.

required
v NDArray

V-coordinate of hexal.

required
astype type

Type to cast to.

int

Returns:

Type Description
int

Extent of hex-lattice.

Note

If u and v are arrays, returns the maximum extent.

See Also

https://www.redblobgames.com/grids/hexagons/#distances

Source code in flyvis/utils/hex_utils.py
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
def get_extent(u: NDArray, v: NDArray, astype: type = int) -> int:
    """Return the hex distance from the origin, maximized over all given coordinates.

    Args:
        u: U-coordinate of hexal.
        v: V-coordinate of hexal.
        astype: Type to cast to.

    Returns:
        Extent of hex-lattice.

    Note:
        If u and v are arrays, returns the maximum extent.

    See Also:
        https://www.redblobgames.com/grids/hexagons/#distances
    """
    if isinstance(u, Number) and isinstance(v, Number):
        u, v = np.array((u,)), np.array((v,))
    coords = np.stack((u, v), 1)
    q = coords[:, 0]
    r = coords[:, 1]
    # Cube coordinates (q, r, s = -q - r): distance to origin is half the L1 norm.
    cube_l1 = abs(0 - q) + abs(0 - q - r) + abs(0 - r)
    return np.max(cube_l1 / 2).astype(astype)

flyvis.utils.log_utils

Classes

flyvis.utils.log_utils.Status dataclass

Status object from log files of model runs.

Attributes:

Name Type Description
ensemble_name str

Name of the ensemble.

log_files List[Path]

List of all log files.

train_logs List[Path]

List of train logs.

model_id_to_train_log_file Dict[str, Path]

Mapping of model ID to log file.

status Dict[str, str]

Mapping of model ID to status.

user_input Dict[str, str]

Mapping of model ID to user input (behind LSF command).

hosts Dict[str, List[str]]

Mapping of model ID to host.

rerun_failed_runs Dict[str, List[str]]

Formatted submission commands to restart failed models.

lsf_part str

LSF command part.

Source code in flyvis/utils/log_utils.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
@dataclass
class Status:
    """Status object from log files of model runs.

    Attributes:
        ensemble_name: Name of the ensemble.
        log_files: List of all log files.
        train_logs: List of train logs.
        model_id_to_train_log_file: Mapping of model ID to log file.
        status: Mapping of model ID to status.
        user_input: Mapping of model ID to user input (behind LSF command).
        hosts: Mapping of model ID to the list of hosts found in its log.
        rerun_failed_runs: Formatted submission commands to restart failed models,
            per model ID a list ``[lsf_command, exclude_host_part, user_input]``.
        lsf_part: LSF command part.
    """

    ensemble_name: str
    log_files: List[Path]
    train_logs: List[Path]
    model_id_to_train_log_file: Dict[str, Path]
    status: Dict[str, str]
    user_input: Dict[str, str]
    hosts: Dict[str, List[str]]
    rerun_failed_runs: Dict[str, List[str]]
    lsf_part: str

    def print_for_rerun(
        self, exclude_failed_hosts: bool = True, model_ids: List[str] = None
    ) -> None:
        """Print formatted submission commands to restart failed models.

        Args:
            exclude_failed_hosts: Whether to include the host-exclusion part
                (second element of each ``rerun_failed_runs`` entry).
            model_ids: List of model IDs to rerun. If None, all failed models are
                included; an empty list prints nothing.

        Raises:
            KeyError: If a requested model ID has no entry in ``rerun_failed_runs``.
        """
        # Only None means "all": an explicit empty selection must not expand.
        if model_ids is None:
            model_ids = list(self.rerun_failed_runs)
        for model_id in model_ids:
            subcmds = self.rerun_failed_runs[model_id]
            host_part = subcmds[1] if exclude_failed_hosts else ""
            print(f"{subcmds[0]}{host_part}{subcmds[2]}")

    def get_hosts(self) -> List[str]:
        """Get hosts on which the job was executed."""
        return list(set(flatten_list(self.hosts.values())))

    def bad_hosts(self) -> List[str]:
        """Get hosts of models whose status is neither successful nor running."""
        host_lists = [
            host
            for model_id, host in self.hosts.items()
            if self.status[model_id] not in ["Successfully completed.", "running"]
        ]
        return list(set(flatten_list(host_lists)))

    def successful_runs(self) -> int:
        """Get number of successful runs."""
        return sum(1 for v in self.status.values() if v == "Successfully completed.")

    def running_runs(self) -> int:
        """Get number of running runs."""
        return sum(1 for v in self.status.values() if v == "running")

    def failed_runs(self) -> int:
        """Get number of runs whose status contains "Exited with exit code"."""
        return sum(1 for v in self.status.values() if "Exited with exit code" in v)

    def successful_model_ids(self) -> List[str]:
        """Get model IDs of successful runs."""
        return [k for k, v in self.status.items() if v == "Successfully completed."]

    def running_model_ids(self) -> List[str]:
        """Get model IDs of running runs."""
        return [k for k, v in self.status.items() if v == "running"]

    def failed_model_ids(self) -> List[str]:
        """Get model IDs of runs whose status contains "Exited with exit code"."""
        return [k for k, v in self.status.items() if "Exited with exit code" in v]

    def lookup_log(
        self, model_id: str, log_type: str = "train_single", last_n_lines: int = 20
    ) -> List[str]:
        """Lookup log for a model ID.

        The first file in ``log_files`` whose path contains ``log_type`` and whose
        name starts with ``"<model_id>_"`` is read.

        Args:
            model_id: ID of the model.
            log_type: Type of log to lookup.
            last_n_lines: Number of lines to return from the end of the log. Zero
                or a negative value returns an empty list.

        Returns:
            List of log lines.

        Raises:
            IndexError: If no matching log file exists.
        """
        matches = [
            p
            for p in self.log_files
            if log_type in str(p) and p.name.split("_")[0] == model_id
        ]
        if not matches:
            raise IndexError(
                f"No {log_type!r} log file found for model ID {model_id!r}."
            )
        # Guard explicitly: ``lines[-0:]`` would return the whole log.
        if last_n_lines <= 0:
            return []
        return matches[0].read_text().split("\n")[-last_n_lines:]

    def extract_error_trace(
        self, model_id: str, check_last_n_lines: int = 100, log_type: str = "train_single"
    ) -> str:
        """Extract the Python error message and traceback from a given log string.

        Args:
            model_id: ID of the model.
            check_last_n_lines: Number of lines to check from the end of the log.
            log_type: Type of log to extract error from.

        Returns:
            Extracted error message and traceback, or a message if no error is found.
        """
        log_string = "\n".join(
            self.lookup_log(model_id, last_n_lines=check_last_n_lines, log_type=log_type)
        )
        # From the first "Traceback" header up to the next blank line or the end.
        pattern = r"Traceback \(most recent call last\):(.+?)(?=\n\n|\Z)"
        match = re.search(pattern, log_string, re.DOTALL)
        return match.group(0).strip() if match else "No Python error found in the log."

    def extract_error_type(
        self, model_id: str, log_type: str = "train_single", check_last_n_lines: int = 100
    ) -> str:
        """Extract the type of error from a given log string.

        Args:
            model_id: ID of the model.
            log_type: Type of log to extract error from.
            check_last_n_lines: Number of lines to check from the end of the log.

        Returns:
            The first capitalised word ending in "Error" found in the inspected
            lines, or a message if none is found.
        """
        log_string = "\n".join(
            self.lookup_log(model_id, last_n_lines=check_last_n_lines, log_type=log_type)
        )
        pattern = r"\b[A-Z]\w*Error\b"
        match = re.search(pattern, log_string)
        return match.group(0) if match else "No specific error type found."

    def all_errors(
        self, check_last_n_lines: int = 100, log_type: str = "train_single"
    ) -> set:
        """Get all unique errors from failed runs.

        Args:
            check_last_n_lines: Number of lines to check from the end of the log.
            log_type: Type of log to extract errors from.

        Returns:
            Set of unique error messages.
        """
        return set(
            self.extract_error_trace(
                model_id,
                check_last_n_lines=check_last_n_lines,
                log_type=log_type,
            )
            for model_id in self.failed_model_ids()
        )

    def all_error_types(
        self, log_type: str = "train_single", check_last_n_lines: int = 100
    ) -> set:
        """Get all unique error types from failed runs.

        Args:
            log_type: Type of log to extract error types from.
            check_last_n_lines: Number of lines to check from the end of each log.

        Returns:
            Set of unique error types.
        """
        return set(
            self.extract_error_type(
                model_id, log_type=log_type, check_last_n_lines=check_last_n_lines
            )
            for model_id in self.failed_model_ids()
        )

    def print_all_errors(
        self, check_last_n_lines: int = 100, log_type: str = "train_single"
    ) -> None:
        """Print all errors and tracebacks from failed runs.

        Args:
            check_last_n_lines: Number of lines to check from the end of the log.
            log_type: Type of log to extract errors from.
        """
        for model_id in self.failed_model_ids():
            print(
                f"Model {model_id} failed with the following error message "
                "and traceback:\n"
            )
            print(
                self.extract_error_trace(
                    model_id,
                    check_last_n_lines=check_last_n_lines,
                    log_type=log_type,
                )
            )
            print("\n")

    def __getitem__(self, key: str) -> str:
        """Get status for a specific model ID.

        Raises:
            KeyError: If no status is recorded for ``key``.
        """
        try:
            return self.status[key]
        except KeyError:
            # ``object`` defines no ``__getitem__``; report a proper KeyError.
            raise KeyError(key) from None

    def __repr__(self) -> str:
        """Return a string representation of the Status object."""
        _repr = f"Status of ensemble {self.ensemble_name}."
        _repr += f"\n{len(self.status)} models."
        _repr += f"\nHosts: {','.join(self.get_hosts())}."
        _repr += f"\n  {self.successful_runs()} successful runs."
        _repr += f"\n  {self.running_runs()} running runs."
        _repr += f"\n  {self.failed_runs()} failed runs."
        if self.failed_runs() > 0:
            _repr += f"\n  Bad hosts: {','.join(self.bad_hosts())}."
            _repr += "\n  Use .print_for_rerun() to print formatted submission commands"
            _repr += " to restart failed models."
            _repr += "\nError types:"
            for error in self.all_error_types():
                _repr += f"\n  {error}"
            _repr += (
                "\n  Run .print_all_errors() to print the error messages and tracebacks."
            )
        return _repr
print_for_rerun
print_for_rerun(exclude_failed_hosts=True, model_ids=None)

Print formatted submission commands to restart failed models.

Parameters:

Name Type Description Default
exclude_failed_hosts bool

Whether to exclude failed hosts.

True
model_ids List[str]

List of model IDs to rerun. If None, all failed models are included.

None
Source code in flyvis/utils/log_utils.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def print_for_rerun(
    self, exclude_failed_hosts: bool = True, model_ids: List[str] = None
) -> None:
    """Print formatted submission commands to restart failed models.

    Args:
        exclude_failed_hosts: Whether to include the host-exclusion part
            (second element of each ``rerun_failed_runs`` entry).
        model_ids: List of model IDs to rerun. If None, all failed models are
            included; an empty list prints nothing.

    Raises:
        KeyError: If a requested model ID has no entry in ``rerun_failed_runs``.
    """
    # Only None means "all": an explicit empty selection must not expand.
    if model_ids is None:
        model_ids = list(self.rerun_failed_runs)
    for model_id in model_ids:
        subcmds = self.rerun_failed_runs[model_id]
        host_part = subcmds[1] if exclude_failed_hosts else ""
        print(f"{subcmds[0]}{host_part}{subcmds[2]}")
get_hosts
get_hosts()

Get hosts on which the job was executed.

Source code in flyvis/utils/log_utils.py
56
57
58
def get_hosts(self) -> List[str]:
    """Return every distinct host recorded in ``self.hosts`` (order unspecified)."""
    every_host = flatten_list(self.hosts.values())
    return list(set(every_host))
bad_hosts
bad_hosts()

Get hosts on which the job failed.

Source code in flyvis/utils/log_utils.py
60
61
62
63
64
65
66
67
def bad_hosts(self) -> List[str]:
    """Return the distinct hosts of models that neither succeeded nor are running."""
    healthy_states = ("Successfully completed.", "running")
    failed_host_lists = []
    for model_id, model_hosts in self.hosts.items():
        if self.status[model_id] not in healthy_states:
            failed_host_lists.append(model_hosts)
    return list(set(flatten_list(failed_host_lists)))
successful_runs
successful_runs()

Get number of successful runs.

Source code in flyvis/utils/log_utils.py
69
70
71
def successful_runs(self) -> int:
    """Count models whose status equals ``"Successfully completed."``."""
    return list(self.status.values()).count("Successfully completed.")
running_runs
running_runs()

Get number of running runs.

Source code in flyvis/utils/log_utils.py
73
74
75
def running_runs(self) -> int:
    """Count models whose status equals ``"running"``."""
    return list(self.status.values()).count("running")
failed_runs
failed_runs()

Get number of failed runs.

Source code in flyvis/utils/log_utils.py
77
78
79
def failed_runs(self) -> int:
    """Count models whose status text contains ``"Exited with exit code"``."""
    return sum("Exited with exit code" in state for state in self.status.values())
successful_model_ids
successful_model_ids()

Get model IDs of successful runs.

Source code in flyvis/utils/log_utils.py
81
82
83
def successful_model_ids(self) -> List[str]:
    """List, in ``self.status`` order, the IDs of successfully completed models."""
    return [mid for mid in self.status if self.status[mid] == "Successfully completed."]
running_model_ids
running_model_ids()

Get model IDs of running runs.

Source code in flyvis/utils/log_utils.py
85
86
87
def running_model_ids(self) -> List[str]:
    """List, in ``self.status`` order, the IDs of models still marked running."""
    return [mid for mid in self.status if self.status[mid] == "running"]
failed_model_ids
failed_model_ids()

Get model IDs of failed runs.

Source code in flyvis/utils/log_utils.py
89
90
91
def failed_model_ids(self) -> List[str]:
    """List, in ``self.status`` order, IDs whose status has an exit-code message."""
    return [mid for mid in self.status if "Exited with exit code" in self.status[mid]]
lookup_log
lookup_log(model_id, log_type='train_single', last_n_lines=20)

Lookup log for a model ID.

Parameters:

Name Type Description Default
model_id str

ID of the model.

required
log_type str

Type of log to lookup.

'train_single'
last_n_lines int

Number of lines to return from the end of the log.

20

Returns:

Type Description
List[str]

List of log lines.

Source code in flyvis/utils/log_utils.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
def lookup_log(
    self, model_id: str, log_type: str = "train_single", last_n_lines: int = 20
) -> List[str]:
    """Lookup log for a model ID.

    The first file in ``self.log_files`` whose path contains ``log_type`` and
    whose name starts with ``"<model_id>_"`` is read.

    Args:
        model_id: ID of the model.
        log_type: Type of log to lookup.
        last_n_lines: Number of lines to return from the end of the log. Zero or
            a negative value returns an empty list.

    Returns:
        List of log lines.

    Raises:
        IndexError: If no matching log file exists.
    """
    matches = [
        p
        for p in self.log_files
        if log_type in str(p) and p.name.split("_")[0] == model_id
    ]
    if not matches:
        raise IndexError(f"No {log_type!r} log file found for model ID {model_id!r}.")
    # Guard explicitly: ``lines[-0:]`` would return the whole log.
    if last_n_lines <= 0:
        return []
    return matches[0].read_text().split("\n")[-last_n_lines:]
extract_error_trace
extract_error_trace(model_id, check_last_n_lines=100, log_type='train_single')

Extract the Python error message and traceback from a given log string.

Parameters:

Name Type Description Default
model_id str

ID of the model.

required
check_last_n_lines int

Number of lines to check from the end of the log.

100
log_type str

Type of log to extract error from.

'train_single'

Returns:

Type Description
str

Extracted error message and traceback, or a message if no error is found.

Source code in flyvis/utils/log_utils.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def extract_error_trace(
    self, model_id: str, check_last_n_lines: int = 100, log_type: str = "train_single"
) -> str:
    """Return the first Python traceback found near the end of a model's log.

    Args:
        model_id: ID of the model.
        check_last_n_lines: How many trailing log lines to search.
        log_type: Which kind of log to read.

    Returns:
        The traceback text (from its header up to the next blank line or the end
        of the inspected text, stripped), or a fallback message if none exists.
    """
    tail = self.lookup_log(model_id, last_n_lines=check_last_n_lines, log_type=log_type)
    found = re.search(
        r"Traceback \(most recent call last\):(.+?)(?=\n\n|\Z)",
        "\n".join(tail),
        re.DOTALL,
    )
    if found is None:
        return "No Python error found in the log."
    return found.group(0).strip()
extract_error_type
extract_error_type(model_id, log_type='train_single', check_last_n_lines=100)

Extract the type of error from a given log string.

Parameters:

Name Type Description Default
model_id str

ID of the model.

required
log_type str

Type of log to extract error from.

'train_single'
check_last_n_lines int

Number of lines to check from the end of the log.

100

Returns:

Type Description
str

Extracted error type, or a message if no specific error type is found.

Source code in flyvis/utils/log_utils.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def extract_error_type(
    self, model_id: str, log_type: str = "train_single", check_last_n_lines: int = 100
) -> str:
    """Return the first ``...Error`` name found near the end of a model's log.

    Args:
        model_id: ID of the model.
        log_type: Which kind of log to read.
        check_last_n_lines: How many trailing log lines to search.

    Returns:
        The first capitalised word ending in "Error", or a fallback message.
    """
    tail = self.lookup_log(model_id, last_n_lines=check_last_n_lines, log_type=log_type)
    found = re.search(r"\b[A-Z]\w*Error\b", "\n".join(tail))
    if found is None:
        return "No specific error type found."
    return found.group(0)
all_errors
all_errors(check_last_n_lines=100, log_type='train_single')

Get all unique errors from failed runs.

Parameters:

Name Type Description Default
check_last_n_lines int

Number of lines to check from the end of the log.

100
log_type str

Type of log to extract errors from.

'train_single'

Returns:

Type Description
set

Set of unique error messages.

Source code in flyvis/utils/log_utils.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
def all_errors(
    self, check_last_n_lines: int = 100, log_type: str = "train_single"
) -> set:
    """Collect the distinct tracebacks extracted from every failed model's log.

    Args:
        check_last_n_lines: How many trailing log lines to search per model.
        log_type: Which kind of log to read.

    Returns:
        Set of unique error messages/tracebacks.
    """
    return {
        self.extract_error_trace(
            failed_id, check_last_n_lines=check_last_n_lines, log_type=log_type
        )
        for failed_id in self.failed_model_ids()
    }
all_error_types
all_error_types(log_type='train_single')

Get all unique error types from failed runs.

Parameters:

Name Type Description Default
log_type str

Type of log to extract error types from.

'train_single'

Returns:

Type Description
set

Set of unique error types.

Source code in flyvis/utils/log_utils.py
174
175
176
177
178
179
180
181
182
183
184
185
186
def all_error_types(
    self, log_type: str = "train_single", check_last_n_lines: int = 100
) -> set:
    """Get all unique error types from failed runs.

    Args:
        log_type: Type of log to extract error types from.
        check_last_n_lines: Number of lines to check from the end of each log,
            forwarded to ``extract_error_type`` (mirrors ``all_errors``).

    Returns:
        Set of unique error types.
    """
    return set(
        self.extract_error_type(
            model_id, log_type=log_type, check_last_n_lines=check_last_n_lines
        )
        for model_id in self.failed_model_ids()
    )
print_all_errors
print_all_errors(check_last_n_lines=100, log_type='train_single')

Print all errors and tracebacks from failed runs.

Parameters:

Name Type Description Default
check_last_n_lines int

Number of lines to check from the end of the log.

100
log_type str

Type of log to extract errors from.

'train_single'
Source code in flyvis/utils/log_utils.py
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def print_all_errors(
    self, check_last_n_lines: int = 100, log_type: str = "train_single"
) -> None:
    """Print a header and the extracted traceback for each failed model.

    Args:
        check_last_n_lines: How many trailing log lines to search per model.
        log_type: Which kind of log to read.
    """
    failed_ids = self.failed_model_ids()
    for failed_id in failed_ids:
        header = (
            f"Model {failed_id} failed with the following error message "
            "and traceback:\n"
        )
        print(header)
        trace = self.extract_error_trace(
            failed_id, check_last_n_lines=check_last_n_lines, log_type=log_type
        )
        print(trace)
        print("\n")
__getitem__
__getitem__(key)

Get status for a specific model ID.

Source code in flyvis/utils/log_utils.py
211
212
213
214
215
def __getitem__(self, key: str) -> str:
    """Get status for a specific model ID.

    Raises:
        KeyError: If no status is recorded for ``key``.
    """
    try:
        return self.status[key]
    except KeyError:
        # ``object`` defines no ``__getitem__``; report a proper KeyError.
        raise KeyError(key) from None
__repr__
__repr__()

Return a string representation of the Status object.

Source code in flyvis/utils/log_utils.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
def __repr__(self) -> str:
    """Summarise the ensemble: model count, hosts, and run outcome counts.

    When any run failed, bad hosts, the distinct error types and usage hints
    are appended.
    """
    parts = [
        f"Status of ensemble {self.ensemble_name}.",
        f"\n{len(self.status)} models.",
        f"\nHosts: {','.join(self.get_hosts())}.",
        f"\n  {self.successful_runs()} successful runs.",
        f"\n  {self.running_runs()} running runs.",
        f"\n  {self.failed_runs()} failed runs.",
    ]
    if self.failed_runs() > 0:
        parts.append(f"\n  Bad hosts: {','.join(self.bad_hosts())}.")
        parts.append("\n  Use .print_for_rerun() to print formatted submission commands")
        parts.append(" to restart failed models.")
        parts.append("\nError types:")
        parts.extend(f"\n  {error}" for error in self.all_error_types())
        parts.append(
            "\n  Run .print_all_errors() to print the error messages and tracebacks."
        )
    return "".join(parts)

Functions

flyvis.utils.log_utils.find_host

find_host(log_string)

Find the host(s) on which the job was executed.

Parameters:

Name Type Description Default
log_string str

The log string to search for host information.

required

Returns:

Type Description
List[str]

List of host names.

Source code in flyvis/utils/log_utils.py
238
239
240
241
242
243
244
245
246
247
248
def find_host(log_string: str) -> List[str]:
    """Return every host name reported after "executed on host(s)" in a log.

    Args:
        log_string: Full text of a log file.

    Returns:
        Host names in order of appearance; empty if none are reported.
    """
    # The optional non-capturing group skips a "<count>*" prefix such as the
    # "4*" in "<4*node01>", so only the host name is captured.
    host_pattern = re.compile(r"executed on host\(s\) <(?:\d*\*)?(.+?)>,")
    return host_pattern.findall(log_string)

flyvis.utils.log_utils.get_exclude_host_part

get_exclude_host_part(log_string, exclude_hosts)

Get the part of the LSF command that excludes hosts.

Parameters:

Name Type Description Default
log_string str

The log string to search for host information.

required
exclude_hosts Union[str, List[str]]

Host(s) to exclude. Can be ‘auto’, a single host name, or a list of host names.

required

Returns:

Type Description
str

The LSF command part for excluding hosts.

Source code in flyvis/utils/log_utils.py
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
def get_exclude_host_part(log_string: str, exclude_hosts: Union[str, List[str]]) -> str:
    """Get the part of the LSF command that excludes hosts.

    Args:
        log_string: The log string to search for host information.
        exclude_hosts: Host(s) to exclude. Can be 'auto' (hosts found in
            ``log_string``), a single host name, a list of host names, or None.

    Returns:
        The LSF command part for excluding hosts, e.g.
        ``-R "select[hname!='a' && hname!='b']" ``, or an empty string when
        there is nothing to exclude.
    """
    if exclude_hosts is None:
        return ""

    if isinstance(exclude_hosts, str):
        exclude_hosts = (
            find_host(log_string) if exclude_hosts == "auto" else [exclude_hosts]
        )

    # Drop duplicates (a host may be reported several times) but keep order.
    hosts = list(dict.fromkeys(exclude_hosts))
    if not hosts:
        # An empty ``select[]`` clause is not a meaningful resource requirement.
        return ""

    exclusion = " && ".join(f"hname!='{host}'" for host in hosts)
    return f'-R "select[{exclusion}]" '

flyvis.utils.log_utils.get_status

get_status(ensemble_name, nP=4, gpu='num=1', queue='gpu_l4', exclude_hosts='auto')

Get Status object for the ensemble of models with formatting for rerun.

Parameters:

Name Type Description Default
ensemble_name str

Ensemble name (e.g. “flow/&lt;id&gt;”).

required
nP int

Number of processors.

4
gpu str

GPU requirement string passed to bsub -gpu (e.g. 'num=1').

'num=1'
queue str

Queue name.

'gpu_l4'
exclude_hosts Union[str, List[str]]

Host(s) to exclude. Can be ‘auto’, a single host name, or a list of host names.

'auto'

Returns:

Type Description
Status

Status object containing information about the ensemble runs.

Source code in flyvis/utils/log_utils.py
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
def get_status(
    ensemble_name: str,
    nP: int = 4,
    gpu: str = "num=1",
    queue: str = "gpu_l4",
    exclude_hosts: Union[str, List[str]] = "auto",
) -> Status:
    """Get Status object for the ensemble of models with formatting for rerun.

    Every ``*.log`` file in the ensemble directory whose path contains
    ``"train_single"`` is parsed. A log whose third-to-last line is the job
    summary marker ("The output (if any) is above this job summary.") is treated
    as finished; its status and submitted command are read from fixed offsets
    from the end of the file. Any other log, including one too short to hold
    such a summary, is reported as ``"running"``.

    Args:
        ensemble_name: Ensemble name (e.g. "flow/<id>").
        nP: Number of processors.
        gpu: GPU requirement string passed to ``bsub -gpu`` (e.g. "num=1").
        queue: Queue name.
        exclude_hosts: Host(s) to exclude. Can be 'auto', a single host name, or a
            list of host names.

    Returns:
        Status object containing information about the ensemble runs.
    """
    # Unformatted bsub template; also stored on the Status as ``lsf_part``.
    lsf_template = "bsub -J {} -n {} -o {} -gpu '{}' -q {} "
    summary_marker = "The output (if any) is above this job summary."
    # The deepest offset read below is lines[-21]; shorter logs cannot be parsed.
    min_summary_lines = 21

    _, path = model_paths_from_parent(flyvis.results_dir / ensemble_name)
    log_files = [p for p in path.iterdir() if p.suffix == ".log"]
    train_logs = [p for p in log_files if "train_single" in str(p)]
    model_id_to_train_log_file = {p.name.split("_")[0]: p for p in train_logs}

    status = {}
    user_input = {}
    hosts = {}
    log_strings = {}
    for p in train_logs:
        model_id = p.name.split("_")[0]
        log_str = p.read_text()
        log_strings[model_id] = log_str
        lines = log_str.split("\n")
        if len(lines) >= min_summary_lines and lines[-3] == summary_marker:
            status[model_id] = lines[-18]
            user_input[model_id] = lines[-21]
        else:
            status[model_id] = "running"
            user_input[model_id] = ""

        hosts[model_id] = find_host(log_str)

    rerun_failed_runs = {}
    for model_id, stat in status.items():
        if stat in ("Successfully completed.", "running"):
            continue
        lsf_cmd = lsf_template.format(
            f"{ensemble_name}/{model_id}",
            nP,
            model_id_to_train_log_file[model_id],
            gpu,
            queue,
        )
        exclude_host_part = get_exclude_host_part(log_strings[model_id], exclude_hosts)
        rerun_failed_runs[model_id] = [
            lsf_cmd,
            exclude_host_part,
            user_input[model_id],
        ]
    return Status(
        ensemble_name,
        log_files,
        train_logs,
        model_id_to_train_log_file,
        status,
        user_input,
        hosts,
        rerun_failed_runs,
        lsf_template,
    )

flyvis.utils.log_utils.flatten_list

flatten_list(nested_list)

Flatten a nested list of lists into a single list with all elements.

Parameters:

Name Type Description Default
nested_list List

A nested list of lists to be flattened.

required

Returns:

Type Description
List

A single flattened list with all elements.

Source code in flyvis/utils/log_utils.py
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
def flatten_list(nested_list: List) -> List:
    """Recursively expand nested ``list`` objects into one flat list.

    Only instances of ``list`` are expanded; other items (tuples, strings, ...)
    are kept as single elements. The outer argument may be any iterable.

    Args:
        nested_list: Iterable whose ``list`` elements are flattened recursively.

    Returns:
        New list with the leaf elements in depth-first order.
    """

    def _walk(items):
        for element in items:
            if isinstance(element, list):
                yield from _walk(element)
            else:
                yield element

    return list(_walk(nested_list))

flyvis.utils.logging_utils

Functions

flyvis.utils.logging_utils.warn_once cached

warn_once(logger, msg)

Log a warning message only once for a given logger and message combination.

Parameters:

Name Type Description Default
logger Logger

The logger object to use for logging.

required
msg str

The warning message to log.

required
Note

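This function uses a bounded LRU cache (100 entries) so that each unique combination of logger and message is logged only once while it remains in the cache; a combination evicted from the cache will be logged again.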

Source code in flyvis/utils/logging_utils.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
@lru_cache(maxsize=100)
def warn_once(logger: logging.Logger, msg: str) -> None:
    """
    Emit ``msg`` as a warning on ``logger`` the first time this pair is seen.

    Args:
        logger: Logger that receives the warning.
        msg: Warning text.

    Note:
        Deduplication relies on a bounded LRU cache of 100 (logger, msg) pairs;
        a pair evicted from the cache will be logged again on its next call.
    """
    logger.warning(msg)

flyvis.utils.logging_utils.save_conda_environment

save_conda_environment(path)

Save the current Conda environment to a JSON file.

Parameters:

Name Type Description Default
path Path

The path where the JSON file will be saved.

required
Note

The function replaces the suffix of the provided path with ‘.json’ (e.g. ‘env.txt’ becomes ‘env.json’).

Source code in flyvis/utils/logging_utils.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def save_conda_environment(path: Path) -> None:
    """
    Save the current Conda environment to a JSON file.

    Runs ``conda list --json`` and writes the parsed result, indented, to
    ``path`` with its suffix replaced by ``.json``.

    Args:
        path: Target path. Its suffix is replaced rather than appended to, so
            ``env`` and ``env.txt`` both become ``env.json``.

    Note:
        The exit status of ``conda`` is not checked: whatever it prints to
        stdout is parsed as JSON and saved. If ``conda`` is not on ``PATH``,
        ``subprocess.run`` raises ``FileNotFoundError``.
    """
    result = subprocess.run(
        ["conda", "list", "--json"], stdout=subprocess.PIPE, text=True, check=False
    )

    installed_packages = json.loads(result.stdout)

    # Explicit encoding so the output does not depend on the platform default.
    with open(path.with_suffix(".json"), "w", encoding="utf-8") as json_file:
        json.dump(installed_packages, json_file, indent=4)

flyvis.utils.logging_utils.all_logging_disabled

all_logging_disabled(highest_level=logging.CRITICAL)

A context manager that prevents any logging messages from being processed.

Parameters:

Name Type Description Default
highest_level int

The maximum logging level to disable. Only needs to be changed if a custom level greater than CRITICAL is defined.

CRITICAL
Example
with all_logging_disabled():
    # Code here will not produce any log output
    logging.warning("This warning will not be logged")
Reference

https://gist.github.com/simon-weber/7853144

Source code in flyvis/utils/logging_utils.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
@contextmanager
def all_logging_disabled(highest_level: int = logging.CRITICAL) -> Any:
    """
    Temporarily suppress all logging calls at or below ``highest_level``.

    Args:
        highest_level: Level passed to ``logging.disable`` for the duration of
            the block. Raise it only if a custom level above CRITICAL exists.

    Example:
        ```python
        with all_logging_disabled():
            # Code here will not produce any log output
            logging.warning("This warning will not be logged")
        ```

    Reference:
        https://gist.github.com/simon-weber/7853144
    """
    # Remember the global disable threshold so it can be put back afterwards,
    # even when the body raises.
    saved_threshold = logging.root.manager.disable
    logging.disable(highest_level)
    try:
        yield
    finally:
        logging.disable(saved_threshold)

flyvis.utils.nn_utils

Classes

flyvis.utils.nn_utils.NumberOfParams dataclass

Dataclass to store the number of free and fixed parameters.

Attributes:

Name Type Description
free int

The number of trainable parameters.

fixed int

The number of non-trainable parameters.

Source code in flyvis/utils/nn_utils.py
48
49
50
51
52
53
54
55
56
57
58
59
@dataclass
class NumberOfParams:
    """
    Parameter counts of a module, split by trainability.

    Attributes:
        free: How many parameters are trainable.
        fixed: How many parameters are not trainable.
    """

    free: int  # trainable parameter count
    fixed: int  # non-trainable parameter count

Functions

flyvis.utils.nn_utils.simulation

simulation(network)

Context manager to turn off training mode and require_grad for a network.

Parameters:

Name Type Description Default
network Module

The neural network module to simulate.

required

Yields:

Type Description
None

None

Example
model = MyNeuralNetwork()
with simulation(model):
    # Perform inference or evaluation
    output = model(input_data)
Note

This context manager temporarily disables gradient computation and sets the network to evaluation mode. It restores the original state after exiting the context.

Source code in flyvis/utils/nn_utils.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
@contextmanager
def simulation(network: nn.Module) -> Generator[None, None, None]:
    """
    Context manager to turn off training mode and require_grad for a network.

    Args:
        network: The neural network module to simulate.

    Yields:
        None

    Example:
        ```python
        model = MyNeuralNetwork()
        with simulation(model):
            # Perform inference or evaluation
            output = model(input_data)
        ```

    Note:
        This context manager sets ``training = False`` on the network and on
        every submodule, and sets ``requires_grad = False`` on all parameters.
        Each module's and parameter's original flag is restored on exit, even if
        the body raises.
    """
    # Record every submodule's flag: setting only ``network.training`` would
    # leave children such as dropout or batch-norm layers in training mode.
    module_modes = [(module, module.training) for module in network.modules()]
    param_flags = [(p, p.requires_grad) for p in network.parameters()]
    try:
        for module, _ in module_modes:
            module.training = False
        for p, _ in param_flags:
            p.requires_grad = False
        yield
    finally:
        for module, was_training in module_modes:
            module.training = was_training
        for p, required_grad in param_flags:
            p.requires_grad = required_grad

flyvis.utils.nn_utils.n_params

n_params(nnmodule)

Returns the numbers of free and fixed parameters in a PyTorch module.

Parameters:

Name Type Description Default
nnmodule Module

The PyTorch module to analyze.

required

Returns:

Type Description
NumberOfParams

A NumberOfParams object containing the count of free and fixed parameters.

Example
model = MyNeuralNetwork()
param_count = n_params(model)
print(f"Free parameters: {param_count.free}")
print(f"Fixed parameters: {param_count.fixed}")
Source code in flyvis/utils/nn_utils.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def n_params(nnmodule: nn.Module) -> NumberOfParams:
    """
    Count the trainable and frozen parameter elements of a PyTorch module.

    Args:
        nnmodule: The PyTorch module to analyze.

    Returns:
        A NumberOfParams whose ``free`` field sums ``nelement()`` over
        parameters with ``requires_grad=True`` and whose ``fixed`` field sums
        it over the rest.

    Example:
        ```python
        model = MyNeuralNetwork()
        param_count = n_params(model)
        print(f"Free parameters: {param_count.free}")
        print(f"Fixed parameters: {param_count.fixed}")
        ```
    """
    totals = {True: 0, False: 0}
    for parameter in nnmodule.parameters():
        # Bucket each tensor's element count by whether it receives gradients.
        totals[parameter.requires_grad] += parameter.nelement()
    return NumberOfParams(free=totals[True], fixed=totals[False])

flyvis.utils.nodes_edges_utils

Classes

flyvis.utils.nodes_edges_utils.NodeIndexer

Bases: dict

Attribute-style accessible map from cell types to indices.

Parameters:

Name Type Description Default
connectome Optional[ConnectomeFromAvgFilters]

Connectome object. The cell types are taken from the connectome and references are created in order.

None
unique_cell_types Optional[NDArray[str]]

Array of unique cell types. Optional. To specify the mapping from cell types to indices in provided order.

None

Attributes:

Name Type Description
unique_cell_types NDArray[str]

Array of unique cell types.

central_cells_index Optional[NDArray[int]]

Array of indices of central cells.

Raises:

Type Description
ValueError

If neither connectome nor unique_cell_types is provided.

Source code in flyvis/utils/nodes_edges_utils.py
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
class NodeIndexer(dict):
    """Attribute-style accessible map from cell types to indices.

    Args:
        connectome: Connectome object. The cell types are taken from the
            connectome and references are created in order.
        unique_cell_types: Array of unique cell types. Optional.
            To specify the mapping from cell types to indices in provided order.

    Attributes:
        unique_cell_types (NDArray[str]): Array of unique cell types.
        central_cells_index (Optional[NDArray[int]]): Array of indices of central cells.

    Raises:
        ValueError: If neither connectome nor unique_cell_types is provided.

    Note:
        Item access (``indexer["T4a"]``) raises ``KeyError`` for unknown cell
        types, while attribute access (``indexer.T4a``) raises
        ``AttributeError``, so ``hasattr``, ``getattr(obj, name, default)``,
        ``copy`` and ``pickle`` work as usual. A non-string iterable key
        returns a list with one index per element.
    """

    def __init__(
        self,
        connectome: Optional["connectome.ConnectomeFromAvgFilters"] = None,
        unique_cell_types: Optional[NDArray[str]] = None,
    ):
        # if connectome is specified, the indices are taken from the connectome
        # and reference to positions in the entire list of nodes/cells
        if connectome is not None and unique_cell_types is None:
            self.unique_cell_types = connectome.unique_cell_types[:].astype("str")
            self.central_cells_index = connectome.central_cells_index[:]
        # alternatively the mapping can be specified from a list of cell types
        # and reference to positions in order of the list
        elif connectome is None and unique_cell_types is not None:
            self.unique_cell_types = unique_cell_types
            self.central_cells_index = None
        else:
            raise ValueError("either cell types or connectome must be specified")
        for index, cell_type in enumerate(self.unique_cell_types):
            super().__setitem__(cell_type, index)

    def __dir__(self):
        return list(set([*dict.__dir__(self), *dict.__iter__(self)]))

    def __len__(self):
        return len(self.unique_cell_types)

    def __iter__(self):
        for cell_type in self.unique_cell_types:
            yield cell_type

    def __getitem__(self, key):
        # A non-string iterable (list, tuple, array, ...) looks up every element.
        if not isinstance(key, str) and isinstance(key, Iterable):
            return [dict.__getitem__(self, _key) for _key in key]
        return dict.__getitem__(self, key)

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails, e.g. for cell-type
        # names or for protocol probes such as ``__deepcopy__``. The attribute
        # protocol requires AttributeError here, not the dict's KeyError.
        try:
            return self.__getitem__(key)
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute or cell type {key!r}"
            ) from None

flyvis.utils.nodes_edges_utils.CellTypeArray

Attribute-style accessible map from cell types to coordinates in array.

Parameters:

Name Type Description Default
array Union[NDArray, Tensor]

Has the dim-th axis corresponding to unique cell types in the connectome or provided cell types.

required
connectome Optional[ConnectomeFromAvgFilters]

Connectome object.

None
cell_types Optional[NDArray[str]]

Array of cell types.

None
dim int

Axis corresponding to unique cell types.

-1

Attributes:

Name Type Description
node_indexer NodeIndexer

Indexer for cell types.

array NDArray

The array of cell type data.

dim int

Dimension corresponding to cell types.

cell_types NDArray[str]

Array of unique cell types.

Source code in flyvis/utils/nodes_edges_utils.py
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
class CellTypeArray:
    """Attribute-style accessible map from cell types to coordinates in array.

    A cell type can be read as an attribute or an item (``arr.T4a``,
    ``arr["T4a"]``), and several can be read at once (``arr[["T4a", "T4b"]]``).
    The result is ``np.take`` of ``array`` along ``dim``, which keeps that axis
    (size 1 for a single cell type). ``arr[:]`` returns ``array`` itself.

    Args:
        array: Has the dim-th axis corresponding to unique cell types
            in the connectome or provided cell types.
        connectome: Connectome object.
        cell_types: Array of cell types.
        dim: Axis corresponding to unique cell types.

    Attributes:
        node_indexer (NodeIndexer): Indexer for cell types.
        array (NDArray): The array of cell type data.
        dim (int): Dimension corresponding to cell types.
        cell_types (NDArray[str]): Array of unique cell types.

    Note:
        Every attribute assignment is routed through ``__setitem__``. Assigning
        to a name that is a known cell type writes into ``array``; any other
        name becomes a regular instance attribute.
    """

    # Class-level defaults: ``__init__`` assigns ``array`` before
    # ``node_indexer``, so ``__setitem__`` first sees this ``None`` default.
    node_indexer: Optional[NodeIndexer] = None
    array: Optional[Union[NDArray, torch.Tensor]] = None
    dim: Optional[int] = None

    def __init__(
        self,
        array: Union[NDArray, torch.Tensor],
        connectome: Optional["connectome.ConnectomeFromAvgFilters"] = None,
        cell_types: Optional[NDArray[str]] = None,
        dim: int = -1,
    ):
        # Each assignment below goes through __setattr__ -> __setitem__.
        self.array = array
        self.dim = dim
        self.node_indexer = NodeIndexer(connectome, cell_types)
        self.cell_types = self.node_indexer.unique_cell_types

    def __bool__(self):
        # Truthy once an array has been assigned.
        return self.array is not None

    def __iter__(self):
        # Iterates cell-type names (not array slices), like a dict's keys.
        for cell_type in self.node_indexer.unique_cell_types:
            yield cell_type

    def __dir__(self):
        # dir() lists own attributes, the indexer's dict attributes and all
        # cell-type names.
        return list(
            set([
                *object.__dir__(self),
                *dict.__dir__(self.node_indexer),
                *dict.__iter__(self.node_indexer),
            ])
        )

    @property
    def shape(self):
        """Shape of the full array, or ``[]`` if no array is assigned."""
        if self.array is not None:
            return self.array.shape
        return []

    def __repr__(self):
        # Maps every cell type to the shape of the *whole* array.
        shape = list(self.shape)
        desc = f"Array({tuple(shape)})"
        return {k: desc for k in self}.__repr__()

    def values(self):
        """Per-cell-type selections, in cell-type order."""
        return [self[k] for k in self]

    def keys(self):
        """Cell-type names, in order."""
        return [k for k in self]

    def items(self):
        """(cell type, selection) pairs, in cell-type order."""
        return [(k, self[k]) for k in self]

    def __len__(self):
        return len(self.node_indexer.unique_cell_types)

    def __getattr__(self, key):
        # Reached for failed attribute lookups and, via __getitem__, for every
        # item access.
        if self.node_indexer is not None:
            if isinstance(key, slice) and key == slice(None):
                # ``arr[:]`` returns the full underlying array.
                return self.array
            elif isinstance(key, str) and key in self.node_indexer.unique_cell_types:
                indices = np.int_([dict.__getitem__(self.node_indexer, key)])
            elif isinstance(key, Iterable) and all([
                _key in self.node_indexer.unique_cell_types for _key in key
            ]):
                # NOTE(review): a ``str`` that is not a cell type also reaches
                # this branch and is checked character by character. An empty
                # key (``""`` or ``[]``) passes ``all([])`` and yields an
                # empty selection instead of an error.
                indices = np.int_([
                    dict.__getitem__(self.node_indexer, _key) for _key in key
                ])
            elif key in self.node_indexer.__dir__():
                # Fall back to attributes of the NodeIndexer.
                return object.__getattribute__(self.node_indexer, key)
            else:
                return object.__getattribute__(self, key)
            # Keeps the cell-type axis: one cell type gives size 1 along ``dim``.
            return np.take(self.array, indices, axis=self.dim)
        return object.__getattribute__(self, key)

    def __getitem__(self, key):
        return self.__getattr__(key)

    def __setitem__(self, key, value):
        if self.node_indexer is not None and key in self.node_indexer.unique_cell_types:
            # Give ``value`` a size-1 cell-type axis unless it seems to have one.
            # NOTE(review): this inspects the last axis even when ``dim`` is not
            # -1; confirm the intended behavior for other ``dim`` values.
            if value.shape[-1] != 1:
                value = np.expand_dims(value, self.dim)
            if self.array is None:
                # Lazily allocate a zero array sized from the first assignment.
                n_cell_types = len(self.node_indexer.unique_cell_types)
                shape = list(value.shape)
                shape[self.dim] = n_cell_types
                self.array = np.zeros(shape)
            # Position of this cell type along ``dim``.
            index = dict.__getitem__(self.node_indexer, key)
            # The index array has shape (1,) * array.ndim; put_along_axis
            # broadcasts it over the other axes and writes ``value`` at
            # position ``index`` along ``dim``.
            np.put_along_axis(
                self.array,
                np.expand_dims(np.array([index]), list(range(len(self.array.shape[1:])))),
                value,
                self.dim,
            )
        else:
            # Names that are not cell types (e.g. ``array``, ``dim`` and
            # ``node_indexer`` in __init__) become ordinary attributes.
            object.__setattr__(self, key, value)

    def __setattr__(self, key, value):
        return self.__setitem__(key, value)

    def from_cell_types(self, cell_types):
        """Return a new CellTypeArray restricted to ``cell_types`` in that order.

        NOTE(review): expects an iterable of cell-type names. A single ``str``
        would be split into characters when building the new NodeIndexer.
        """
        activity = self[cell_types]
        return CellTypeArray(
            activity,
            cell_types=cell_types,
            dim=self.dim,
        )

Functions

flyvis.utils.nodes_edges_utils.order_node_type_list

order_node_type_list(node_types, groups=['R\\d', 'L\\d', 'Lawf\\d', 'A', 'C\\d', 'CT\\d.*', 'Mi\\d{1,2}', 'T\\d{1,2}.*', 'Tm.*\\d{1,2}.*'])

Orders a list of node types by the regular expressions defined in groups.

Parameters:

Name Type Description Default
node_types List[str]

Messy list of nodes.

required
groups List[str]

Ordered list of regular expressions to sort node_types.

['R\\d', 'L\\d', 'Lawf\\d', 'A', 'C\\d', 'CT\\d.*', 'Mi\\d{1,2}', 'T\\d{1,2}.*', 'Tm.*\\d{1,2}.*']

Returns:

Type Description
Tuple[List[str], List[int]]

A tuple containing:
  • Ordered node type list
  • Corresponding sorting indices

Raises:

Type Description
AssertionError

If sorting doesn’t include all cell types.

ValueError

If sorting fails due to length mismatch.

Source code in flyvis/utils/nodes_edges_utils.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def order_node_type_list(
    node_types: List[str],
    groups: List[str] = [
        r"R\d",
        r"L\d",
        r"Lawf\d",
        r"A",
        r"C\d",
        r"CT\d.*",
        r"Mi\d{1,2}",
        r"T\d{1,2}.*",
        r"Tm.*\d{1,2}.*",
    ],
) -> Tuple[List[str], List[int]]:
    """Orders a list of node types by the regular expressions defined in groups.

    Each node type goes to the first group whose regular expression matches
    the start of its name (``re.match``). Types that match no group are placed
    at the end. Within a group, types are sorted by name, with every run of
    digits zero-padded to four places so that, e.g., ``T2`` sorts before ``T10``.

    Args:
        node_types: Messy list of nodes.
        groups: Ordered list of regular expressions to sort node_types.
            Not mutated.

    Returns:
        A tuple containing:
        - Ordered node type list
        - Corresponding sorting indices (positions in ``node_types``)
        ``(None, None)`` if ``node_types`` is None.

    Raises:
        AssertionError: If sorting doesn't include all cell types.
        ValueError: If sorting fails due to length mismatch.
    """
    if node_types is None:
        return None, None

    _len = len(node_types)

    def sort_numeric(string):
        """Sort key that zero-pads every run of digits to four places."""
        return re.sub(r"\d+", lambda match: f"{int(match.group()):04}", string)

    # One bucket per group plus a trailing bucket for unmatched types; dict
    # insertion order fixes the output order of the buckets.
    unmatched = len(groups)
    type_groups = {index: [] for index in range(len(groups) + 1)}
    for node_index, cell_type in enumerate(node_types):
        for group_index, regular_expression in enumerate(groups):
            if re.match(regular_expression, cell_type):
                # First matching group wins, so a type is never listed twice.
                type_groups[group_index].append((node_index, cell_type))
                break
        else:
            type_groups[unmatched].append((node_index, cell_type))

    ordered = [
        item
        for bucket in type_groups.values()
        for item in sorted(bucket, key=lambda z: sort_numeric(z[1]))
    ]
    index = [node_index for node_index, _ in ordered]
    nodes = [cell_type for _, cell_type in ordered]

    # Defensive consistency checks on the result.
    missing = set(node_types) - set(nodes)
    if missing:
        raise AssertionError(
            "Defined sorting through regular expressions does not include all cell"
            f" types: {missing}."
        )

    if _len != len(nodes) or _len != len(index):
        raise ValueError(
            "sorting failed because the resulting array is of different length"
        )

    return nodes, index

flyvis.utils.nodes_edges_utils.get_index_mapping_lists

get_index_mapping_lists(from_list, to_list)

Get indices to sort and filter from_list by occurrence of items in to_list.

The indices can be used to reorder or filter another list or tensor whose entries correspond, in order, to the items in from_list, so that it follows the order of the items in to_list.

Parameters:

Name Type Description Default
from_list List[str]

Original list of items.

required
to_list List[str]

Target list of items.

required

Returns:

Type Description
List[int]

List of indices for sorting.

Example
from_list = ["a", "b", "c"]
mapping_to_from_list = [1, 2, 3]
to_list = ["c", "a", "b"]
sort_index = get_index_mapping_lists(from_list, to_list)
sorted_list = [mapping_to_from_list[i] for i in sort_index]
# sorted_list will be [3, 1, 2]
Source code in flyvis/utils/nodes_edges_utils.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def get_index_mapping_lists(from_list: List[str], to_list: List[str]) -> List[int]:
    """Return, for each item of to_list, its first position in from_list.

    Applying these positions to a sequence aligned with from_list reorders it
    (and drops items absent from to_list) so it follows to_list.

    Args:
        from_list: Original list of items.
        to_list: Target list of items.

    Returns:
        List of indices for sorting.

    Raises:
        ValueError: If an item of to_list is not present in from_list.

    Example:
        ```python
        from_list = ["a", "b", "c"]
        mapping_to_from_list = [1, 2, 3]
        to_list = ["c", "a", "b"]
        sort_index = get_index_mapping_lists(from_list, to_list)
        sorted_list = [mapping_to_from_list[i] for i in sort_index]
        # sorted_list will be [3, 1, 2]
        ```
    """

    def _as_list(seq):
        # ``list.index`` is used below, so unwrap numpy arrays first.
        return seq.tolist() if isinstance(seq, np.ndarray) else seq

    source = _as_list(from_list)
    target = _as_list(to_list)
    return list(map(source.index, target))

flyvis.utils.nodes_edges_utils.sort_by_mapping_lists

sort_by_mapping_lists(from_list, to_list, tensor, axis=0)

Sort and filter a tensor along an axis indexed by from_list to match to_list.

Parameters:

Name Type Description Default
from_list List[str]

Original list of items.

required
to_list List[str]

Target list of items.

required
tensor Union[ndarray, Tensor]

Tensor to be sorted.

required
axis int

Axis along which to sort the tensor.

0

Returns:

Type Description
ndarray

Sorted numpy array.

Source code in flyvis/utils/nodes_edges_utils.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def sort_by_mapping_lists(
    from_list: List[str],
    to_list: List[str],
    tensor: Union[np.ndarray, torch.Tensor],
    axis: int = 0,
) -> np.ndarray:
    """Sort and filter a tensor along an axis indexed by from_list to match to_list.

    Args:
        from_list: Original list of items.
        to_list: Target list of items.
        tensor: Tensor to be sorted.
        axis: Axis along which to sort the tensor. Works for arrays of any
            rank; negative values count from the last axis.

    Returns:
        Sorted numpy array whose size along ``axis`` equals ``len(to_list)``.
    """
    tensor = np.array(tensor)
    sort_index = np.asarray(get_index_mapping_lists(from_list, to_list), dtype=np.intp)
    # ``np.take`` selects along any axis of any-rank input directly; a
    # transpose with two axes would only be valid for 2-D arrays.
    return np.take(tensor, sort_index, axis=axis)

flyvis.utils.nodes_edges_utils.nodes_list_sorting_on_off_unknown

nodes_list_sorting_on_off_unknown(cell_types=None)

Sort node list based on on/off/unknown polarity.

Parameters:

Name Type Description Default
cell_types Optional[List[str]]

List of cell types to sort. If None, uses all types from groundtruth_utils.polarity.

None

Returns:

Type Description
List[str]

Sorted list of cell types.

Source code in flyvis/utils/nodes_edges_utils.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
def nodes_list_sorting_on_off_unknown(
    cell_types: Optional[List[str]] = None,
) -> List[str]:
    """Order cell types by polarity: ON (1) first, then OFF (-1), then unknown (0).

    Types with equal polarity keep the order of ``groundtruth_utils.polarity``.
    Types not listed in ``groundtruth_utils.polarity`` are dropped.

    Args:
        cell_types: Cell types to sort. If None, every type in
            ``groundtruth_utils.polarity`` is used.

    Returns:
        Sorted list of cell types.
    """
    rank = {1: 1, -1: 2, 0: 3}
    polarity = groundtruth_utils.polarity
    if cell_types is None:
        cell_types = list(polarity)
    selected = [
        (name, rank[sign]) for name, sign in polarity.items() if name in cell_types
    ]
    # list.sort is stable, so ties keep their polarity-table order.
    selected.sort(key=lambda pair: pair[1])
    return [name for name, _ in selected]

flyvis.utils.tensor_utils

Classes

flyvis.utils.tensor_utils.RefTensor

A tensor with reference indices along the last dimension.

Attributes:

Name Type Description
values Tensor

The tensor values.

indices Tensor

The reference indices.

Source code in flyvis/utils/tensor_utils.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class RefTensor:
    """Value tensor paired with indices into its last dimension.

    Attributes:
        values (torch.Tensor): The tensor values.
        indices (torch.Tensor): The reference indices.
    """

    def __init__(self, values: torch.Tensor, indices: torch.Tensor) -> None:
        self.values, self.indices = values, indices

    def deref(self) -> torch.Tensor:
        """Select ``values`` entries along the last axis at ``indices``."""
        source, positions = self.values, self.indices
        return source.index_select(dim=-1, index=positions)

    def __len__(self) -> int:
        # Length of the first dimension of ``values``.
        return len(self.values)

    def __repr__(self) -> str:
        return "RefTensor(values={}, indices={})".format(self.values.data, self.indices)

    def clone(self) -> "RefTensor":
        """Copy with cloned ``values``; ``indices`` is shared, not copied."""
        return RefTensor(values=self.values.clone(), indices=self.indices)

    def detach(self) -> "RefTensor":
        """Copy with detached ``values``; ``indices`` is shared, not copied."""
        return RefTensor(values=self.values.detach(), indices=self.indices)
deref
deref()

Index the values with the given indices in the last dimension.

Source code in flyvis/utils/tensor_utils.py
23
24
25
def deref(self) -> torch.Tensor:
    """Select ``self.values`` entries along the last axis at ``self.indices``."""
    source, positions = self.values, self.indices
    return source.index_select(dim=-1, index=positions)
clone
clone()

Return a copy of the RefTensor cloning values.

Source code in flyvis/utils/tensor_utils.py
33
34
35
def clone(self) -> "RefTensor":
    """Copy this RefTensor with cloned ``values``; ``indices`` is shared."""
    values_copy = self.values.clone()
    return RefTensor(values=values_copy, indices=self.indices)
detach
detach()

Return a copy of the RefTensor detaching values.

Source code in flyvis/utils/tensor_utils.py
37
38
39
def detach(self) -> "RefTensor":
    """Copy this RefTensor with detached ``values``; ``indices`` is shared."""
    detached_values = self.values.detach()
    return RefTensor(values=detached_values, indices=self.indices)

flyvis.utils.tensor_utils.AutoDeref

Bases: dict

An auto-dereferencing namespace.

Dereferencing means that if attributes are RefTensors, `__getitem__` will call `RefTensor.deref()` to obtain the values at the given indices.

Note

Constructed at each forward call in Network. A cache speeds up processing, e.g., for when a parameter is referenced multiple times in the dynamics.

Attributes:

Name Type Description
_cache Dict[str, object]

Cache for dereferenced values.

Source code in flyvis/utils/tensor_utils.py
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
class AutoDeref(dict):
    """An auto-dereferencing namespace.

    Dereferencing means that if attributes are RefTensors,
    __getitem__ will call RefTensor.deref() to obtain the values at the
    given indices.

    Note:
        Constructed at each forward call in Network. A cache speeds up
        processing, e.g., for when a parameter is referenced multiple times in the
        dynamics.

        A missing key raises ``KeyError`` on item access and ``AttributeError``
        on attribute access, so ``hasattr``, ``getattr(obj, name, default)``,
        ``copy.deepcopy`` and ``pickle`` behave as usual.

    Attributes:
        _cache (Dict[str, object]): Cache for dereferenced values.
    """

    _cache: Dict[str, object]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Bypass the overridden __setattr__, which would store an item instead.
        object.__setattr__(self, "_cache", {})

    def __setitem__(self, key: str, value: object) -> None:
        # Invalidate any cached dereference of the previous value. The cache may
        # not exist yet when items are restored before instance state, as
        # pickle does when loading.
        cache = self.__dict__.get("_cache")
        if cache is not None:
            cache.pop(key, None)
        super().__setitem__(key, value)

    def __getitem__(self, key: str) -> Any:
        val = super().__getitem__(key)
        if isinstance(val, RefTensor):
            # Dereference once per key; reuse until the key is reassigned or
            # the cache is cleared.
            if key not in self._cache:
                self._cache[key] = val.deref()
            val = self._cache[key]
        return val

    def __setattr__(self, key: str, value: object) -> None:
        self.__setitem__(key, value)

    def __getattr__(self, key: str) -> Any:
        # Only called when regular attribute lookup fails. The attribute
        # protocol (hasattr, getattr default, copy, pickle) needs
        # AttributeError, not KeyError.
        try:
            return self.__getitem__(key)
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            ) from None

    def __repr__(self) -> str:
        def single_line_repr(elem: object) -> str:
            if isinstance(elem, list):
                return "[" + ", ".join(map(single_line_repr, elem)) + "]"
            elif isinstance(elem, AutoDeref):
                return (
                    f"{elem.__class__.__name__}("
                    + ", ".join(f"{k}={single_line_repr(v)}" for k, v in elem.items())
                    + ")"
                )
            else:
                return repr(elem).replace("\n", " ")

        def repr_in_context(elem: object, curr_col: int, indent: int) -> str:
            # Use the one-line form if it fits in 80 columns, else break
            # lists/AutoDerefs over several indented lines.
            sl_repr = single_line_repr(elem)
            if len(sl_repr) <= 80 - curr_col:
                return sl_repr
            elif isinstance(elem, list):
                return (
                    "[\n"
                    + " " * (indent + 2)
                    + (",\n" + " " * (indent + 2)).join(
                        repr_in_context(e, indent + 2, indent + 2) for e in elem
                    )
                    + "\n"
                    + " " * indent
                    + "]"
                )
            elif isinstance(elem, AutoDeref):
                return (
                    f"{elem.__class__.__name__}(\n"
                    + " " * (indent + 2)
                    + (",\n" + " " * (indent + 2)).join(
                        f"{k} = " + repr_in_context(v, indent + 5 + len(k), indent + 2)
                        for k, v in elem.items()
                    )
                    + "\n"
                    + " " * indent
                    + ")"
                )
            else:
                return repr(elem)

        return repr_in_context(self, 0, 0)

    def get_as_reftensor(self, key: str) -> "RefTensor":
        """Get the original RefTensor without dereferencing."""
        return dict.__getitem__(self, key)

    def clear_cache(self) -> "AutoDeref":
        """Clear the cache and return a cloned instance."""
        object.__setattr__(self, "_cache", {})
        return clone(self)

    def detach(self) -> "AutoDeref":
        """Return a detached copy of the AutoDeref instance."""
        return detach(self)
get_as_reftensor
get_as_reftensor(key)

Get the original RefTensor without dereferencing.

Source code in flyvis/utils/tensor_utils.py
130
131
132
def get_as_reftensor(self, key: str) -> RefTensor:
    """Return the raw stored value for ``key`` without dereferencing it."""
    # dict.__getitem__ bypasses the dereferencing __getitem__ override.
    raw_value = dict.__getitem__(self, key)
    return raw_value
clear_cache
clear_cache()

Clear the cache and return a cloned instance.

Source code in flyvis/utils/tensor_utils.py
134
135
136
137
def clear_cache(self) -> "AutoDeref":
    """Reset the dereference cache, then return a recursive clone of ``self``."""
    # object.__setattr__ makes ``_cache`` a real attribute instead of letting
    # AutoDeref.__setattr__ store it as a mapping item.
    fresh_cache = {}
    object.__setattr__(self, "_cache", fresh_cache)
    return clone(self)
detach
detach()

Return a detached copy of the AutoDeref instance.

Source code in flyvis/utils/tensor_utils.py
139
140
141
def detach(self) -> "AutoDeref":
    """Return a detached copy of this AutoDeref via the ``detach`` function."""
    detached = detach(self)
    return detached

Functions

flyvis.utils.tensor_utils.detach

detach(obj)

Recursively detach AutoDeref mappings.

Parameters:

Name Type Description Default
obj AutoDeref

The object to detach.

required

Returns:

Type Description
AutoDeref

A detached copy of the input object.

Raises:

Type Description
TypeError

If the object type is not supported.

Source code in flyvis/utils/tensor_utils.py
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def detach(obj: AutoDeref) -> AutoDeref:
    """Recursively detach AutoDeref mappings.

    ``None``, bools, numbers, strings and types are returned unchanged.
    Tensors and RefTensors are detached. Lists and tuples are detached
    element-wise and returned as lists. Mappings become AutoDeref instances.
    Any other object is detached through its ``vars()``.

    Args:
        obj: The object to detach.

    Returns:
        A detached copy of the input object.

    Raises:
        TypeError: If the object type is not supported.
    """
    if isinstance(obj, (type(None), bool, int, float, str, type)):
        return obj
    elif isinstance(obj, (RefTensor, torch.Tensor)):
        return obj.detach()
    elif isinstance(obj, (list, tuple)):
        return [detach(v) for v in obj]
    elif isinstance(obj, Mapping):
        if isinstance(obj, dict):
            # Read raw stored values so AutoDeref's RefTensors are detached
            # themselves rather than their dereferenced results.
            raw = {k: dict.__getitem__(obj, k) for k in obj}
        else:
            # Non-dict mappings do not support dict.__getitem__.
            raw = {k: obj[k] for k in obj}
        return AutoDeref({k: detach(v) for k, v in raw.items()})
    else:
        # Only vars() is guarded, so a TypeError from a nested value keeps its
        # own, more precise message.
        try:
            attributes = vars(obj)
        except TypeError as e:
            raise TypeError(f"{obj} of type {type(obj)} as {e}.") from None
        return detach(attributes)

flyvis.utils.tensor_utils.clone

clone(obj)

Recursively clone AutoDeref mappings.

Parameters:

Name Type Description Default
obj AutoDeref

The object to clone.

required

Returns:

Type Description
AutoDeref

A cloned copy of the input object.

Raises:

Type Description
TypeError

If the object type is not supported.

Source code in flyvis/utils/tensor_utils.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
def clone(obj: AutoDeref) -> AutoDeref:
    """Recursively clone AutoDeref mappings.

    ``None``, bools, numbers, strings and types are returned unchanged.
    Tensors and RefTensors are cloned. Lists and tuples are cloned
    element-wise and returned as lists. Mappings become AutoDeref instances.
    Any other object is cloned through its ``vars()``.

    Args:
        obj: The object to clone.

    Returns:
        A cloned copy of the input object.

    Raises:
        TypeError: If the object type is not supported.
    """
    if isinstance(obj, (type(None), bool, int, float, str, type)):
        return obj
    elif isinstance(obj, (RefTensor, torch.Tensor)):
        return obj.clone()
    elif isinstance(obj, (list, tuple)):
        return [clone(v) for v in obj]
    elif isinstance(obj, Mapping):
        if isinstance(obj, dict):
            # Read raw stored values so AutoDeref's RefTensors are cloned
            # themselves rather than their dereferenced results.
            raw = {k: dict.__getitem__(obj, k) for k in obj}
        else:
            # Non-dict mappings do not support dict.__getitem__.
            raw = {k: obj[k] for k in obj}
        return AutoDeref({k: clone(v) for k, v in raw.items()})
    else:
        # Only vars() is guarded, so a TypeError from a nested value keeps its
        # own, more precise message.
        try:
            attributes = vars(obj)
        except TypeError as e:
            raise TypeError(f"{obj} of type {type(obj)} as {e}.") from None
        return clone(attributes)

flyvis.utils.tensor_utils.to_numpy

to_numpy(array)

Convert array-like to numpy array.

Parameters:

Name Type Description Default
array Union[ndarray, Tensor, List]

The input array-like object.

required

Returns:

Type Description
ndarray

A numpy array.

Raises:

Type Description
ValueError

If the input type cannot be cast to a numpy array.

Source code in flyvis/utils/tensor_utils.py
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
def to_numpy(array: Union[np.ndarray, torch.Tensor, List]) -> np.ndarray:
    """Convert array-like to numpy array.

    Args:
        array: The input array-like object. A NumPy array is returned as is
            (no copy). A tensor is detached and moved to the CPU first. A list
            or tuple is converted with ``np.array``.

    Returns:
        A numpy array.

    Raises:
        ValueError: If the input type cannot be cast to a numpy array.
    """
    if isinstance(array, np.ndarray):
        return array
    if isinstance(array, torch.Tensor):
        # detach() allows tensors that require grad; cpu() allows GPU tensors.
        return array.detach().cpu().numpy()
    if isinstance(array, (list, tuple)):
        return np.array(array)
    raise ValueError(f"type {type(array)} cannot be cast to numpy array")

flyvis.utils.tensor_utils.atleast_column_vector

atleast_column_vector(array)

Convert 1d-array-like to column vector n x 1 or return the original.

Parameters:

Name Type Description Default
array ndarray

The input array.

required

Returns:

Type Description
ndarray

A column vector or the original array if it’s already 2D or higher.

Source code in flyvis/utils/tensor_utils.py
221
222
223
224
225
226
227
228
229
230
231
232
233
def atleast_column_vector(array: np.ndarray) -> np.ndarray:
    """Reshape 1-D array-like input into an ``(n, 1)`` column vector.

    The input is first converted to a new numpy array (``np.array`` copies by
    default). One-dimensional data is reshaped into a single column. Input of
    any other rank is returned as the converted array with its shape unchanged.

    Args:
        array: Array-like input.

    Returns:
        An ``(n, 1)`` array for 1-D input, otherwise the converted array.
    """
    converted = np.array(array)
    return converted.reshape(-1, 1) if converted.ndim == 1 else converted

flyvis.utils.tensor_utils.matrix_mask_by_sub

matrix_mask_by_sub(sub_matrix, matrix)

Create a mask of rows in matrix that are contained in sub_matrix.

Parameters:

Name Type Description Default
sub_matrix ndarray

Shape (n_rows1, n_columns)

required
matrix ndarray

Shape (n_rows2, n_columns)

required

Returns:

Type Description
NDArray[bool]

1D boolean array of length n_rows2

Note

n_rows1 must be less than or equal to n_rows2.

Example
sub_matrix = np.array([[1, 2, 3],
                       [4, 3, 1]])
matrix = np.array([[3, 4, 1],
                   [4, 3, 1],
                   [1, 2, 3]])
matrix_mask_by_sub(sub_matrix, matrix)
# array([False, True, True])

Indexing a tensor with integer indices is typically faster than indexing it with a boolean mask; see also where_equal_rows, which returns indices.

Source code in flyvis/utils/tensor_utils.py
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
def matrix_mask_by_sub(sub_matrix: np.ndarray, matrix: np.ndarray) -> NDArray[bool]:
    """Create a mask of rows in matrix that are contained in sub_matrix.

    Args:
        sub_matrix: Shape (n_rows1, n_columns)
        matrix: Shape (n_rows2, n_columns)

    Returns:
        1D boolean array of length n_rows2. Entry ``k`` is True when the first
        n_columns entries of row ``k`` of ``matrix`` equal some row of
        ``sub_matrix``. An empty ``sub_matrix`` yields an all-False mask.

    Raises:
        ValueError: If ``sub_matrix`` has more rows than ``matrix``.

    Note:
        n_rows1 must be less than or equal to n_rows2.

    Example:
        ```python
        sub_matrix = np.array([[1, 2, 3],
                               [4, 3, 1]])
        matrix = np.array([[3, 4, 1],
                           [4, 3, 1],
                           [1, 2, 3]])
        matrix_mask_by_sub(sub_matrix, matrix)
        # array([False, True, True])
        ```

    Typically, indexing a tensor with indices instead of booleans is
    faster. Therefore, see also where_equal_rows.
    """
    n_rows, n_columns = sub_matrix.shape
    n_rows2 = matrix.shape[0]
    if not n_rows <= n_rows2:
        raise ValueError("sub_matrix must have fewer or equal rows as matrix")

    # Accumulate with a running OR instead of reducing a list of per-row
    # masks. This avoids `reduce()` over an empty sequence (TypeError when
    # sub_matrix has no rows) and never holds n_rows masks at once.
    mask = np.zeros(n_rows2, dtype=bool)
    for sub_row in sub_matrix:
        row_match = np.ones(n_rows2, dtype=bool)
        # Column-by-column comparison: only the first n_columns columns of
        # `matrix` take part, as in the original implementation.
        for j in range(n_columns):
            row_match &= sub_row[j] == matrix[:, j]
        mask |= row_match
    return mask

flyvis.utils.tensor_utils.where_equal_rows

where_equal_rows(matrix1, matrix2, as_mask=False, astype='|S64')

Find indices where matrix1 rows are in matrix2.

Parameters:

Name Type Description Default
matrix1 ndarray

First input matrix.

required
matrix2 ndarray

Second input matrix.

required
as_mask bool

If True, return a boolean mask instead of indices.

False
astype str

Data type to use for comparison.

'|S64'

Returns:

Type Description
NDArray[int]

Array of indices or boolean mask.

Example
matrix1 = np.array([[1, 2, 3],
                    [4, 3, 1]])
matrix2 = np.array([[3, 4, 1],
                    [4, 3, 1],
                    [1, 2, 3],
                    [0, 0, 0]])
where_equal_rows(matrix1, matrix2)
# array([2, 1])
matrix2[where_equal_rows(matrix1, matrix2)]
# array([[1, 2, 3],
#        [4, 3, 1]])
See also

matrix_mask_by_sub

Source code in flyvis/utils/tensor_utils.py
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
def where_equal_rows(
    matrix1: np.ndarray,
    matrix2: np.ndarray,
    as_mask: bool = False,
    astype: str = "|S64",
) -> NDArray[int]:
    """Find indices where matrix1 rows are in matrix2.

    Both inputs are turned into column vectors if 1-D and cast to ``astype``
    before the comparison.

    Args:
        matrix1: First input matrix.
        matrix2: Second input matrix.
        as_mask: If True, return a boolean mask instead of indices.
        astype: Data type to use for comparison.

    Returns:
        If ``as_mask``, a boolean mask over the rows of ``matrix2``.
        Otherwise, an integer array of row indices into ``matrix2``. For each
        row of ``matrix1`` in order, the result holds the indices of all equal
        rows in ``matrix2``. When nothing matches, this is an empty integer
        array, so it can still be used for indexing.

    Raises:
        ValueError: If ``matrix1`` has more rows than ``matrix2``, or the
            column counts differ.

    Example:
        ```python
        matrix1 = np.array([[1, 2, 3],
                            [4, 3, 1]])
        matrix2 = np.array([[3, 4, 1],
                            [4, 3, 1],
                            [1, 2, 3],
                            [0, 0, 0]])
        where_equal_rows(matrix1, matrix2)
        # array([2, 1])
        matrix2[where_equal_rows(matrix1, matrix2)]
        # array([[1, 2, 3],
        #        [4, 3, 1]])
        ```

    See also:
        matrix_mask_by_sub
    """
    matrix1 = atleast_column_vector(matrix1).astype(astype)
    matrix2 = atleast_column_vector(matrix2).astype(astype)

    if as_mask:
        return matrix_mask_by_sub(matrix1, matrix2)

    n_rows1, n_cols1 = matrix1.shape
    n_rows2, n_cols2 = matrix2.shape

    if not n_rows1 <= n_rows2:
        raise ValueError("matrix1 must have less or equal as many rows as matrix2")
    if not n_cols1 == n_cols2:
        raise ValueError("cannot compare matrices with different number of columns")

    # For each row of matrix1 (in order), the indices of every equal row of
    # matrix2.
    matches = [np.flatnonzero((row == matrix2).all(axis=1)) for row in matrix1]
    # np.array([]) would be float64, which cannot be used as an index. Always
    # return an integer array, even when there are no matches.
    if not matches:
        return np.empty(0, dtype=np.intp)
    return np.concatenate(matches)

flyvis.utils.tensor_utils.broadcast

broadcast(src, other, dim)

Broadcast src to the shape of other along dimension dim.

Parameters:

Name Type Description Default
src Tensor

Source tensor to broadcast.

required
other Tensor

Target tensor to broadcast to.

required
dim int

Dimension along which to broadcast.

required

Returns:

Type Description
Tensor

Broadcasted tensor.

Note

From https://github.com/rusty1s/pytorch_scatter/.

Source code in flyvis/utils/tensor_utils.py
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int) -> torch.Tensor:
    """Expand ``src`` to ``other.size()``, aligned with axis ``dim`` of ``other``.

    A 1-D ``src`` first gets ``dim`` leading singleton axes, so that its
    single axis sits at position ``dim``. Trailing singleton axes are then
    appended until ``src`` has as many dimensions as ``other``. Finally the
    result is expanded to ``other``'s shape.

    Args:
        src: Tensor to broadcast.
        other: Tensor whose shape is the broadcast target.
        dim: Axis of ``other`` that ``src`` is aligned with. Negative values
            count from the end.

    Returns:
        ``src`` expanded to the shape of ``other``.

    Note:
        From https://github.com/rusty1s/pytorch_scatter/.
    """
    axis = dim if dim >= 0 else other.dim() + dim
    if src.dim() == 1:
        # Prepend `axis` singleton dimensions (no-op when axis <= 0).
        src = src[(None,) * axis]
    while src.dim() < other.dim():
        src = src.unsqueeze(-1)
    return src.expand(other.size())

flyvis.utils.tensor_utils.scatter_reduce

scatter_reduce(src, index, dim=-1, mode='mean')

Reduce along dimension dim using values in the index tensor.

Parameters:

Name Type Description Default
src Tensor

Source tensor.

required
index Tensor

Index tensor.

required
dim int

Dimension along which to reduce.

-1
mode Literal['mean', 'sum']

Reduction mode, either “mean” or “sum”.

'mean'

Returns:

Type Description
Tensor

Reduced tensor.

Note

Convenience function for torch.scatter_reduce that broadcasts index to the shape of src along dimension dim, to conform to the pytorch_scatter API.

Source code in flyvis/utils/tensor_utils.py
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
def scatter_reduce(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    mode: Literal["mean", "sum"] = "mean",
) -> torch.Tensor:
    """Reduce along dimension `dim` using values in the `index` tensor.

    Entries of ``src`` that share the same ``index`` value along ``dim`` are
    combined with ``mode``. The output has size ``index.max() + 1`` along
    ``dim`` (0 if ``index`` is empty) and the sizes of ``src`` elsewhere.
    Output positions that receive no entries are 0.

    Args:
        src: Source tensor.
        index: Index tensor.
        dim: Dimension along which to reduce.
        mode: Reduction mode, either "mean" or "sum".

    Returns:
        Reduced tensor.

    Note:
        Convenience function for `torch.scatter_reduce` that broadcasts `index` to
        the shape of `src` along dimension `dim`, to conform to the
        pytorch_scatter API. Uses the `Tensor.scatter_reduce(dim, index, src,
        reduce, include_self=...)` signature (PyTorch >= 1.12).
    """
    index = broadcast(index.long(), src, dim)
    # Default output size along `dim`, as in pytorch_scatter.
    dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
    out_shape = list(src.shape)
    out_shape[dim] = dim_size
    # Scatter into zeros, excluding those zeros from the reduction, so that
    # "mean" averages only the scattered values. Untouched slots stay 0.
    out = src.new_zeros(out_shape)
    return out.scatter_reduce(dim, index, src, reduce=mode, include_self=False)

flyvis.utils.tensor_utils.scatter_mean

scatter_mean(src, index, dim=-1)

Average along dimension dim using values in the index tensor.

Parameters:

Name Type Description Default
src Tensor

Source tensor.

required
index Tensor

Index tensor.

required
dim int

Dimension along which to average.

-1

Returns:

Type Description
Tensor

Averaged tensor.

Source code in flyvis/utils/tensor_utils.py
388
389
390
391
392
393
394
395
396
397
398
399
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Mean of the `src` entries that share an `index` value along `dim`.

    Thin wrapper around ``scatter_reduce`` with ``mode="mean"``.

    Args:
        src: Source tensor.
        index: Index tensor.
        dim: Dimension along which to average.

    Returns:
        Averaged tensor.
    """
    return scatter_reduce(src, index, dim=dim, mode="mean")

flyvis.utils.tensor_utils.scatter_add

scatter_add(src, index, dim=-1)

Sum along dimension dim using values in the index tensor.

Parameters:

Name Type Description Default
src Tensor

Source tensor.

required
index Tensor

Index tensor.

required
dim int

Dimension along which to sum.

-1

Returns:

Type Description
Tensor

Summed tensor.

Source code in flyvis/utils/tensor_utils.py
402
403
404
405
406
407
408
409
410
411
412
413
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Sum of the `src` entries that share an `index` value along `dim`.

    Thin wrapper around ``scatter_reduce`` with ``mode="sum"``.

    Args:
        src: Source tensor.
        index: Index tensor.
        dim: Dimension along which to sum.

    Returns:
        Summed tensor.
    """
    return scatter_reduce(src, index, dim=dim, mode="sum")

flyvis.utils.tensor_utils.select_along_axes

select_along_axes(array, indices, dims)

Select indices from array along dims.

Parameters:

Name Type Description Default
array ndarray

Array to take indices from.

required
indices Union[int, Iterable[int]]

Indices to take.

required
dims Union[int, Iterable[int]]

Dimensions to take indices from.

required

Returns:

Type Description
ndarray

Array with selected indices along specified dimensions.

Source code in flyvis/utils/tensor_utils.py
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
def select_along_axes(
    array: np.ndarray,
    indices: Union[int, Iterable[int]],
    dims: Union[int, Iterable[int]],
) -> np.ndarray:
    """Successively take `indices` from `array` along the paired `dims`.

    A scalar ``indices`` or ``dims`` is treated as a one-element sequence. The
    two sequences are paired with ``zip``, so the shorter one decides how many
    axes are processed. Each scalar index is wrapped in a list before
    ``ndarray.take``, so a selected axis is kept (with length 1) rather than
    dropped.

    Args:
        array: Array to take indices from.
        indices: Indices to take.
        dims: Dimensions to take indices from.

    Returns:
        Array with selected indices along specified dimensions.
    """
    index_seq = indices if isinstance(indices, Iterable) else [indices]
    dim_seq = dims if isinstance(dims, Iterable) else [dims]

    result = array
    for idx, axis in zip(index_seq, dim_seq):
        selection = idx if isinstance(idx, Iterable) else [idx]
        result = result.take(selection, axis=axis)
    return result

flyvis.utils.tensor_utils.asymmetric_weighting

asymmetric_weighting(tensor, gamma=1.0, delta=0.1)

Apply asymmetric weighting to the positive and negative elements of a tensor.

Parameters:

Name Type Description Default
tensor Union[NDArray, Tensor]

Input tensor.

required
gamma float

Positive weighting factor.

1.0
delta float

Negative weighting factor.

0.1

Returns:

Type Description
Union[NDArray, Tensor]

Weighted tensor.

Note

The function is defined as $f(x) = \gamma x$ if $x > 0$, and $f(x) = \delta x$ otherwise.

Source code in flyvis/utils/tensor_utils.py
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
def asymmetric_weighting(
    tensor: Union[NDArray, torch.Tensor], gamma: float = 1.0, delta: float = 0.1
) -> Union[NDArray, torch.Tensor]:
    """
    Apply asymmetric weighting to the positive and negative elements of a tensor.

    Args:
        tensor: Input tensor or numpy array.
        gamma: Positive weighting factor.
        delta: Negative weighting factor.

    Returns:
        Weighted values, as a numpy array for numpy input and as a torch
        tensor for tensor input.

    Note:
        The function is defined as:
        f(x) = gamma * x if x > 0 else delta * x
    """
    if isinstance(tensor, np.ndarray):
        # torch.nn.functional.relu only accepts tensors. Use the numpy
        # equivalent so that the annotated NDArray input actually works.
        return gamma * np.maximum(tensor, 0) - delta * np.maximum(-tensor, 0)
    return gamma * nn.functional.relu(tensor) - delta * nn.functional.relu(-tensor)

flyvis.utils.type_utils

Functions

flyvis.utils.type_utils.byte_to_str

byte_to_str(obj)

Cast byte elements to string types recursively.

This function recursively converts byte elements to string types in nested data structures.

Parameters:

Name Type Description Default
obj Any

The object to be processed. Can be of various types including Mapping, numpy.ndarray, list, tuple, bytes, str, or Number.

required

Returns:

Type Description
Any

The input object with all byte elements converted to strings.

Raises:

Type Description
TypeError

If the input object cannot be cast to a string type.

Note

This function will cast all byte elements in nested lists or tuples.

Examples:

>>> byte_to_str(b"hello")
'hello'
>>> byte_to_str([b"world", 42, {b"key": b"value"}])
['world', 42, {'key': 'value'}]
Source code in flyvis/utils/type_utils.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def byte_to_str(obj: Any) -> Any:
    """Cast byte elements to string types recursively.

    This function recursively converts byte elements to string types in nested
    data structures.

    Args:
        obj: The object to be processed. Can be of various types including
            Mapping, numpy.ndarray, list, tuple, bytes, str, or Number.

    Returns:
        The input object with all byte elements converted to strings. For
        mappings, values are converted recursively and ``bytes`` keys are
        decoded. Other keys are left unchanged. Byte-string numpy arrays
        (dtype ``S``) are cast to unicode arrays.

    Raises:
        TypeError: If the input object cannot be cast to a string type.

    Note:
        This function will cast all byte elements in nested lists or tuples.

    Examples:
        ```python
        >>> byte_to_str(b"hello")
        'hello'
        >>> byte_to_str([b"world", 42, {b"key": b"value"}])
        ['world', 42, {'key': 'value'}]
        ```
    """
    if isinstance(obj, Mapping):
        # Decode bytes keys as well, as the docstring example promises.
        # Non-bytes keys are kept verbatim so arbitrary hashable keys still work.
        return type(obj)(
            {
                (k.decode() if isinstance(k, bytes) else k): byte_to_str(v)
                for k, v in obj.items()
            }
        )
    elif isinstance(obj, np.ndarray):
        if np.issubdtype(obj.dtype, np.dtype("S")):
            return obj.astype("U")
        return obj
    elif isinstance(obj, list):
        return [byte_to_str(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(byte_to_str(item) for item in obj)
    elif isinstance(obj, bytes):
        return obj.decode()
    elif isinstance(obj, (str, Number)):
        return obj
    else:
        raise TypeError(f"can't cast {obj} of type {type(obj)} to str")

flyvis.utils.xarray_joblib_backend

Classes

flyvis.utils.xarray_joblib_backend.H5XArrayDatasetStoreBackend

Bases: FileSystemStoreBackend

FileSystemStoreBackend subclass for handling xarray.Dataset objects.

This class uses xarray’s to_netcdf and open_dataset methods for Dataset objects and .h5 files.

Attributes:

Name Type Description
location str

The base directory for storing items.

Source code in flyvis/utils/xarray_joblib_backend.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
class H5XArrayDatasetStoreBackend(FileSystemStoreBackend):
    """FileSystemStoreBackend subclass for handling xarray.Dataset objects.

    This class uses xarray's to_netcdf and open_dataset methods for Dataset objects and
    .h5 files. All other items are delegated to FileSystemStoreBackend.

    Attributes:
        location (str): The base directory for storing items.
    """

    def _h5_path(self, path: List[str]) -> str:
        """Return the h5 file location for the store identifier ``path``.

        If the last element of ``path`` ends with '.h5', the joined path names
        the file directly. Otherwise the file is 'output.h5' inside the item
        directory. An empty ``path`` is handled without indexing errors.
        """
        item_path = os.path.join(self.location, *path)
        if path and path[-1].endswith('.h5'):
            return item_path
        return os.path.join(item_path, 'output.h5')

    def dump_item(self, path: List[str], item: Any, *args, **kwargs) -> None:
        """Dump an item to the store.

        If the item is an xarray.Dataset or the path ends with '.h5', use
        xarray.Dataset.to_netcdf. Otherwise, use the superclass method.
        Failures while writing the h5 file are reported as a CacheWarning
        instead of being raised.

        Args:
            path: The identifier for the item in the store.
            item: The item to be stored.
            *args: Variable positional arguments passed to parent class.
            **kwargs: Variable keyword arguments passed to parent class. For
                h5 storage, ``verbose`` controls logging and ``mode``
                (default 'w') is forwarded to to_netcdf.
        """
        is_dataset = isinstance(item, xr.Dataset)
        is_h5_file = bool(path) and path[-1].endswith('.h5')

        if is_dataset or is_h5_file:
            nc_path = self._h5_path(path)

            verbose = kwargs.get('verbose', 1)
            if verbose > 10:
                logger.info('Persisting Dataset to h5 at %s', nc_path)

            try:
                self.create_location(os.path.dirname(nc_path))
                logger.info("Store item %s", nc_path)
                # Overwrite by default, but honour a caller-supplied mode.
                # Only `mode` is forwarded: kwargs may also carry `verbose`
                # (read above), which to_netcdf does not accept.
                item.to_netcdf(nc_path, mode=kwargs.get('mode', 'w'))
            except Exception as e:
                warnings.warn(
                    f"Unable to cache Dataset to h5. Exception: {e}.",
                    CacheWarning,
                    stacklevel=2,
                )
        else:
            super().dump_item(path, item, *args, **kwargs)

    def load_item(self, path: List[str], *args, **kwargs) -> Any:
        """Load an item from the store.

        If the path ends with '.h5' or the store contains a h5 file, use
        xarray.open_dataset. Otherwise, or if opening the h5 file fails
        (reported as a CacheWarning), use the superclass method.

        Args:
            path: The identifier for the item in the store.
            *args: Variable positional arguments passed to parent class.
            **kwargs: Variable keyword arguments passed to parent class;
                ``verbose`` also controls logging here.

        Returns:
            The loaded item, either an xarray.Dataset or the original object.
        """
        nc_path = self._h5_path(path)
        if self._item_exists(nc_path):
            verbose = kwargs.get('verbose', 1)
            if verbose > 1:
                logger.info('Loading Dataset from h5 at %s', nc_path)
            try:
                return xr.open_dataset(nc_path)
            except Exception as e:
                warnings.warn(
                    f"Unable to load Dataset from h5. Exception: {e}.",
                    CacheWarning,
                    stacklevel=2,
                )
        return super().load_item(path, *args, **kwargs)

    def contains_item(self, path: List[str]) -> bool:
        """Check if there is an item at the given path.

        This method checks for both h5 and pickle files.

        Args:
            path: The identifier for the item in the store.

        Returns:
            True if the item exists in either h5 or pickle format, False otherwise.
        """
        item_path = os.path.join(self.location, *path)
        super_filename = os.path.join(item_path, 'output.pkl')
        return self._item_exists(self._h5_path(path)) or super()._item_exists(
            super_filename
        )
dump_item
dump_item(path, item, *args, **kwargs)

Dump an item to the store.

If the item is an xarray.Dataset or the path ends with ‘.h5’, use xarray.Dataset.to_netcdf. Otherwise, use the superclass method.

Parameters:

Name Type Description Default
path List[str]

The identifier for the item in the store.

required
item Any

The item to be stored.

required
*args

Variable positional arguments passed to parent class or to_netcdf

()
**kwargs

Variable keyword arguments passed to parent class or to_netcdf

{}
Source code in flyvis/utils/xarray_joblib_backend.py
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def dump_item(self, path: List[str], item: Any, *args, **kwargs) -> None:
    """Dump an item to the store.

    If the item is an xarray.Dataset or the path ends with '.h5', use
    xarray.Dataset.to_netcdf. Otherwise, use the superclass method.
    Failures while writing the h5 file are reported as a CacheWarning
    instead of being raised.

    Args:
        path: The identifier for the item in the store.
        item: The item to be stored.
        *args: Variable positional arguments passed to parent class.
        **kwargs: Variable keyword arguments passed to parent class. For h5
            storage, ``verbose`` controls logging and ``mode`` (default 'w')
            is forwarded to to_netcdf.
    """
    is_dataset = isinstance(item, xr.Dataset)
    is_h5_file = path[-1].endswith('.h5') if path else False

    if is_dataset or is_h5_file:
        item_path = os.path.join(self.location, *path)
        nc_path = item_path if is_h5_file else os.path.join(item_path, 'output.h5')

        verbose = kwargs.get('verbose', 1)
        if verbose > 10:
            logger.info('Persisting Dataset to h5 at %s', nc_path)

        try:
            self.create_location(os.path.dirname(nc_path))
            logger.info("Store item %s", nc_path)
            # Overwrite by default, but honour a caller-supplied mode. Only
            # `mode` is forwarded: kwargs may also carry `verbose` (read
            # above), which to_netcdf does not accept.
            item.to_netcdf(nc_path, mode=kwargs.get('mode', 'w'))
        except Exception as e:
            warnings.warn(
                f"Unable to cache Dataset to h5. Exception: {e}.",
                CacheWarning,
                stacklevel=2,
            )
    else:
        super().dump_item(path, item, *args, **kwargs)
load_item
load_item(path, *args, **kwargs)

Load an item from the store.

If the path ends with ‘.h5’ or the store contains a h5 file, use xarray.open_dataset. Otherwise, use the superclass method.

Parameters:

Name Type Description Default
path List[str]

The identifier for the item in the store.

required
*args

Variable positional arguments passed to parent class or xr.open_dataset

()
**kwargs

Variable keyword arguments passed to parent class or xr.open_dataset

{}

Returns:

Type Description
Any

The loaded item, either an xarray.Dataset or the original object.

Source code in flyvis/utils/xarray_joblib_backend.py
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def load_item(self, path: List[str], *args, **kwargs) -> Any:
    """Load an item from the store.

    If the path ends with '.h5' or the store contains a h5 file, use
    xarray.open_dataset. Otherwise, or if opening the h5 file fails
    (reported as a CacheWarning), use the superclass method.

    Args:
        path: The identifier for the item in the store.
        *args: Variable positional arguments passed to parent class.
        **kwargs: Variable keyword arguments passed to parent class;
            ``verbose`` also controls logging here.

    Returns:
        The loaded item, either an xarray.Dataset or the original object.
    """
    item_path = os.path.join(self.location, *path)
    # Guard against an empty path, consistent with dump_item.
    nc_path = (
        item_path
        if path and path[-1].endswith('.h5')
        else os.path.join(item_path, 'output.h5')
    )
    if self._item_exists(nc_path):
        verbose = kwargs.get('verbose', 1)
        if verbose > 1:
            logger.info('Loading Dataset from h5 at %s', nc_path)
        try:
            return xr.open_dataset(nc_path)
        except Exception as e:
            warnings.warn(
                f"Unable to load Dataset from h5. Exception: {e}.",
                CacheWarning,
                stacklevel=2,
            )
    return super().load_item(path, *args, **kwargs)
contains_item
contains_item(path)

Check if there is an item at the given path.

This method checks for both h5 and pickle files.

Parameters:

Name Type Description Default
path List[str]

The identifier for the item in the store.

required

Returns:

Type Description
bool

True if the item exists in either h5 or pickle format, False otherwise.

Source code in flyvis/utils/xarray_joblib_backend.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def contains_item(self, path: List[str]) -> bool:
    """Check if there is an item at the given path.

    This method checks for both h5 and pickle files.

    Args:
        path: The identifier for the item in the store.

    Returns:
        True if the item exists in either h5 or pickle format, False otherwise.
    """
    item_path = os.path.join(self.location, *path)
    # Guard against an empty path, consistent with dump_item.
    nc_filename = (
        item_path
        if path and path[-1].endswith('.h5')
        else os.path.join(item_path, 'output.h5')
    )
    super_filename = os.path.join(item_path, 'output.pkl')

    return self._item_exists(nc_filename) or super()._item_exists(super_filename)

flyvis.utils.xarray_utils

Classes

flyvis.utils.xarray_utils.CustomAccessor

Custom accessor for xarray objects providing additional functionality.

Attributes:

Name Type Description
_obj

The xarray object being accessed.

Source code in flyvis/utils/xarray_utils.py
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
class CustomAccessor:
    """Custom accessor for xarray objects providing additional functionality.

    Attributes:
        _obj: The xarray object being accessed.
    """

    def __init__(self, xarray_obj: xr.Dataset | xr.DataArray):
        self._obj = xarray_obj

    @wraps(where_xarray)
    def where(self, **kwargs) -> xr.Dataset | xr.DataArray:
        return where_xarray(self._obj, **kwargs)

    # NOTE: functools.wraps copies the module-level plot_traces' __doc__ onto
    # this method at runtime; the docstring below documents the source only.
    @wraps(plot_traces)
    def plot_traces(
        self,
        x: str,
        key: str = "",
        legend_labels: List[str] | None = None,
        extra_legend_coords: List[str] | None = None,
        plot_kwargs: dict | None = None,
        **kwargs,
    ) -> plt.Axes:
        """Plot traces from the xarray object.

        Args:
            x: The dimension to use as the x-axis.
            key: The key of the data to plot if the dataset is a Dataset.
            legend_labels: List of coordinates to include in the legend.
                Defaults to an empty list.
            extra_legend_coords: Additional coordinates to include in the
                legend. Defaults to an empty list.
            plot_kwargs: Additional keyword arguments to pass to the plot
                function. Defaults to an empty dict.
            **kwargs: Query-like conditions on coordinates.

        Returns:
            The matplotlib axes object containing the plot.

        Example:
            Overlay stimulus and response traces:
            ```python
            fig, ax = plt.subplots()
            r.custom.plot_traces(
                key='stimulus',
                x='time',
                speed=[19, 25],
                intensity=1,
                angle=90,
                u_in=0,
                v_in=0,
                plot_kwargs=dict(ax=ax),
                time='>0,<1.0'
            )
            r.custom.plot_traces(
                key='responses',
                x='time',
                speed=[19, 25],
                cell_type='T4c',
                intensity=1,
                angle=90,
                network_id=0,
                plot_kwargs=dict(ax=ax),
                time='>0,<1.0'
            )
            ```

            Polar plot:
            ```python
            prs = peak_responses(stims_and_resps_moving_edges).custom.where(
                cell_type="T4c",
                intensity=1,
                speed=19,
            )
            prs['angle'] = np.radians(prs.angle)
            ax = plt.subplots(subplot_kw={"projection": "polar"})[1]
            prs.custom.plot_traces(
                x="angle",
                legend_labels=["network_id"],
                plot_kwargs={"add_legend": False, "ax": ax, "color": "b"},
            )
            ```
        """
        # Fresh containers per call: mutable defaults would be shared across
        # calls if the plotting function ever mutated them.
        return plot_traces(
            self._obj,
            key,
            x,
            [] if legend_labels is None else legend_labels,
            [] if extra_legend_coords is None else extra_legend_coords,
            {} if plot_kwargs is None else plot_kwargs,
            **kwargs,
        )
plot_traces
plot_traces(x, key='', legend_labels=[], extra_legend_coords=[], plot_kwargs={}, **kwargs)

Plot traces from the xarray object.

Parameters:

Name Type Description Default
x str

The dimension to use as the x-axis.

required
key str

The key of the data to plot if the dataset is a Dataset.

''
legend_labels List[str]

List of coordinates to include in the legend.

[]
extra_legend_coords List[str]

Additional coordinates to include in the legend.

[]
plot_kwargs dict

Additional keyword arguments to pass to the plot function.

{}
**kwargs

Query-like conditions on coordinates.

{}

Returns:

Type Description
Axes

The matplotlib axes object containing the plot.

Example

Overlay stimulus and response traces:

fig, ax = plt.subplots()
r.custom.plot_traces(
    key='stimulus',
    x='time',
    speed=[19, 25],
    intensity=1,
    angle=90,
    u_in=0,
    v_in=0,
    plot_kwargs=dict(ax=ax),
    time='>0,<1.0'
)
r.custom.plot_traces(
    key='responses',
    x='time',
    speed=[19, 25],
    cell_type='T4c',
    intensity=1,
    angle=90,
    network_id=0,
    plot_kwargs=dict(ax=ax),
    time='>0,<1.0'
)

Polar plot:

prs = peak_responses(stims_and_resps_moving_edges).custom.where(
    cell_type="T4c",
    intensity=1,
    speed=19,
)
prs['angle'] = np.radians(prs.angle)
ax = plt.subplots(subplot_kw={"projection": "polar"})[1]
prs.custom.plot_traces(
    x="angle",
    legend_labels=["network_id"],
    plot_kwargs={"add_legend": False, "ax": ax, "color": "b"},
)

Source code in flyvis/utils/xarray_utils.py
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
def plot_traces(
    self,
    x: str,
    key: str = "",
    legend_labels: List[str] | None = None,
    extra_legend_coords: List[str] | None = None,
    plot_kwargs: dict | None = None,
    **kwargs,
) -> plt.Axes:
    """Plot traces from the xarray object.

    Thin accessor wrapper around the module-level ``plot_traces`` function,
    applied to the wrapped xarray object. Note the argument order differs from
    the module-level function: here ``x`` comes first and ``key`` is optional.

    Args:
        x: The dimension to use as the x-axis.
        key: The key of the data to plot if the dataset is a Dataset.
        legend_labels: List of coordinates to include in the legend.
            ``None`` (default) is treated as an empty list.
        extra_legend_coords: Additional coordinates to include in the legend.
            ``None`` (default) is treated as an empty list.
        plot_kwargs: Additional keyword arguments to pass to the plot function.
            ``None`` (default) is treated as an empty dict.
        **kwargs: Query-like conditions on coordinates.

    Returns:
        The matplotlib axes object containing the plot.

    Example:
        Overlay stimulus and response traces:
        ```python
        fig, ax = plt.subplots()
        r.custom.plot_traces(
            key='stimulus',
            x='time',
            speed=[19, 25],
            intensity=1,
            angle=90,
            u_in=0,
            v_in=0,
            plot_kwargs=dict(ax=ax),
            time='>0,<1.0'
        )
        r.custom.plot_traces(
            key='responses',
            x='time',
            speed=[19, 25],
            cell_type='T4c',
            intensity=1,
            angle=90,
            network_id=0,
            plot_kwargs=dict(ax=ax),
            time='>0,<1.0'
        )
        ```

        Polar plot:
        ```python
        prs = peak_responses(stims_and_resps_moving_edges).custom.where(
            cell_type="T4c",
            intensity=1,
            speed=19,
        )
        prs['angle'] = np.radians(prs.angle)
        ax = plt.subplots(subplot_kw={"projection": "polar"})[1]
        prs.custom.plot_traces(
            x="angle",
            legend_labels=["network_id"],
            plot_kwargs={"add_legend": False, "ax": ax, "color": "b"},
        )
        ```
    """
    # No @wraps here: wrapping the module-level function would overwrite this
    # method's docstring and make introspection report that function's
    # (dataset, key, x, ...) signature, whose argument order differs.
    # Normalise the None defaults so the delegate always gets real containers.
    return plot_traces(
        self._obj,
        key,
        x,
        legend_labels or [],
        extra_legend_coords or [],
        plot_kwargs or {},
        **kwargs,
    )

Functions

flyvis.utils.xarray_utils.where_xarray

where_xarray(dataset, rtol=1e-05, atol=1e-08, **kwargs)

Return a subset of the xarray Dataset or DataArray where coordinates meet specified query-like conditions.

Parameters:

Name Type Description Default
dataset Dataset | DataArray

The dataset or data array to filter.

required
rtol float

Relative tolerance for floating point comparisons.

1e-05
atol float

Absolute tolerance for floating point comparisons.

1e-08
**kwargs

Query-like conditions on coordinates. Conditions can be specified as:
- Strings with comma-separated conditions (interpreted as AND).
- Iterables (lists, tuples) representing multiple conditions (interpreted as OR).
- Single values for equality conditions.

{}

Returns:

Type Description
Dataset | DataArray

The filtered dataset or data array.

Example
filtered_ds = where_xarray(
    ds,
    cell_type=["T4a", "T4b"],
    time="<1.0,>0",
    intensity=1.0,
    radius=6,
    width=2.4
)
Source code in flyvis/utils/xarray_utils.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def where_xarray(
    dataset: xr.Dataset | xr.DataArray,
    rtol: float = 1.0e-5,
    atol: float = 1.0e-8,
    **kwargs,
) -> xr.Dataset | xr.DataArray:
    """Return a subset of the xarray Dataset or DataArray where coordinates meet
    specified query-like conditions.

    Conditions for different coordinates are applied one after another, so they
    combine with AND. Entries that fail a condition are dropped via
    ``where(..., drop=True)``.

    Args:
        dataset: The dataset or data array to filter.
        rtol: Relative tolerance for floating point equality comparisons.
        atol: Absolute tolerance for floating point equality comparisons.
        **kwargs: Query-like conditions on coordinates. Conditions can be specified as:
            - Strings with comma-separated conditions (interpreted as AND). Each
                condition may start with one of ``>=``, ``<=``, ``==``, ``!=``,
                ``>``, ``<``; without an operator, equality is assumed.
            - Iterables (lists, tuples) representing multiple conditions
                (interpreted as OR).
            - Single values for equality conditions.

    Returns:
        The filtered dataset or data array.

    Raises:
        ValueError: If a keyword does not name a coordinate of ``dataset``.

    Note:
        Equality on floating point coordinates uses ``np.isclose`` with
        ``rtol``/``atol``. All other comparisons are exact. Targets parsed from
        strings are converted to float when possible; otherwise they stay strings.

    Example:
        ```python
        filtered_ds = where_xarray(
            ds,
            cell_type=["T4a", "T4b"],
            time="<1.0,>0",
            intensity=1.0,
            radius=6,
            width=2.4
        )
        ```
    """
    # Force evaluation of coordinates
    # Heisenbug, strangely required for the where() method to work
    # to circumvent AttributeError: 'ScipyArrayWrapper' object has no attribute 'oindex'
    for _, coord in dataset.coords.items():
        _ = coord.values.dtype

    # Define a mapping of operators from string to functions
    operators = {
        '>=': operator.ge,
        '<=': operator.le,
        '==': operator.eq,
        '!=': operator.ne,
        '>': operator.gt,
        '<': operator.lt,
    }

    # Sort operators by length in descending order to match multi-character operators
    # first
    sorted_operators = sorted(operators.keys(), key=len, reverse=True)

    def parse_condition(cond_str):
        """Parse a single condition string into (operator_function, target_value)."""
        cond_str = cond_str.strip()
        for op_str in sorted_operators:
            if cond_str.startswith(op_str):
                target = cond_str[len(op_str) :].strip()
                with contextlib.suppress(ValueError):
                    target = float(target)
                return (operators[op_str], target)
        # If no operator is found, assume equality
        try:
            target = float(cond_str)
        except ValueError:
            target = cond_str
        return (operator.eq, target)

    def compare(coord_values, op_func, target_value):
        """Evaluate ``op_func(coord_values, target_value)`` as a boolean DataArray
        aligned with ``coord_values``. Equality on floating point coordinates is
        tolerance-based (``np.isclose``)."""
        if op_func == operator.eq and np.issubdtype(coord_values.dtype, np.floating):
            mask = np.isclose(coord_values, target_value, atol=atol, rtol=rtol)
        else:
            mask = op_func(coord_values, target_value)
        return xr.DataArray(mask, dims=coord_values.dims, coords=coord_values.coords)

    filtered_dataset = dataset

    for coord_name, condition in kwargs.items():
        # Check if coord_name is a coordinate in the dataset
        if coord_name not in dataset.coords:
            raise ValueError(f"Coordinate '{coord_name}' not found in the dataset.")

        # Build the mask from the already-filtered data so it has exactly the
        # shape of the object it is applied to. A mask from the unfiltered
        # dataset only lines up through index alignment. That can fail along
        # dimensions without an index coordinate once an earlier condition has
        # dropped entries.
        coord_values = filtered_dataset.coords[coord_name]
        coord_mask = xr.ones_like(coord_values, dtype=bool)  # Initialize mask as all True

        if isinstance(condition, str):
            # String conditions: multiple conditions separated by commas (AND logic)
            for cond_str in condition.split(','):
                if cond_str.strip():
                    coord_mask &= compare(coord_values, *parse_condition(cond_str))
        elif isinstance(condition, Iterable) and not isinstance(condition, (str, bytes)):
            # Iterable conditions: each element is a separate condition (OR logic)
            any_mask = xr.zeros_like(coord_values, dtype=bool)  # all False
            for item in condition:
                if isinstance(item, str):
                    op_func, target_value = parse_condition(item)
                else:
                    # Assume equality if not a string condition
                    op_func, target_value = operator.eq, item
                any_mask |= compare(coord_values, op_func, target_value)
            coord_mask &= any_mask
        else:
            # Single non-string, non-iterable value: assume equality
            coord_mask &= compare(coord_values, operator.eq, condition)

        # Apply the combined mask
        filtered_dataset = filtered_dataset.where(coord_mask, drop=True)

    return filtered_dataset

flyvis.utils.xarray_utils.plot_traces

plot_traces(dataset, key, x, legend_labels=[], extra_legend_coords=[], plot_kwargs={}, **kwargs)

Plot traces from the dataset, optionally filtered by query-like conditions on its coordinates.

Parameters:

Name Type Description Default
dataset DataArray | Dataset

The dataset containing the responses to plot.

required
key str

The key of the data to plot if the dataset is a Dataset.

required
x str

The dimension to use as the x-axis.

required
legend_labels List[str]

List of coordinates to include in the legend.

[]
extra_legend_coords List[str]

Additional coordinates to include in the legend.

[]
plot_kwargs dict

Additional keyword arguments to pass to the plot function.

{}
**kwargs

Query-like conditions on coordinates.

{}

Returns:

Type Description
Axes

The matplotlib axes object containing the plot.

Note

Query-like conditions can be specified as:

  • Strings with comma-separated conditions (e.g., time='<0.5,>0.1')
  • Lists for equality conditions (e.g., cell_type=["T4a", "T4b"])
  • Single values for equality conditions (e.g., intensity=1.0)
Source code in flyvis/utils/xarray_utils.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
def plot_traces(
    dataset: xr.DataArray | xr.Dataset,
    key: str,
    x: str,
    legend_labels: List[str] | None = None,
    extra_legend_coords: List[str] | None = None,
    plot_kwargs: dict | None = None,
    **kwargs,
) -> plt.Axes:
    """Plot traces from the dataset, optionally filtered by query-like conditions.

    The data is first filtered with ``dataset.custom.where(**kwargs)``. Every
    dimension not spanned by ``x`` is then stacked into a single ``traces``
    dimension, and each stacked entry is drawn as one line. Its legend entry
    lists the values of ``legend_labels``.

    Args:
        dataset: The dataset containing the responses to plot.
        key: The key of the data to plot if the dataset is a Dataset. An empty
            string plots the (filtered) object itself.
        x: The dimension or coordinate to use as the x-axis.
        legend_labels: Coordinates to include in the legend. If empty or None,
            the default is: the stacked dimensions, then
            ``extra_legend_coords``, then the non-``sample`` columns of
            ``traces.sample.to_dataframe()``. That last group is omitted
            entirely if ``x`` is among them.
        extra_legend_coords: Additional coordinates to include in the legend
            when ``legend_labels`` is not given. None means no extras.
        plot_kwargs: Additional keyword arguments passed to xarray's
            ``plot.line``. None means no extra arguments.
        **kwargs: Query-like conditions on coordinates.

    Returns:
        The matplotlib axes object containing the plot: ``plot_kwargs["ax"]``
        when provided, otherwise the current axes.

    Raises:
        ValueError: If the legend coordinates are not 1D arrays of equal length.

    Warns:
        UserWarning: If more than 250 traces are stacked for plotting.

    Note:
        Query-like conditions can be specified as:

        - Strings with comma-separated conditions (e.g., time='<0.5,>0.1')
        - Lists for equality conditions (e.g., cell_type=["T4a", "T4b"])
        - Single values for equality conditions (e.g., intensity=1.0)
    """
    # None defaults avoid shared mutable default arguments.
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs
    extra_legend_coords = [] if extra_legend_coords is None else list(extra_legend_coords)

    traces = dataset.custom.where(**kwargs)

    if key:
        traces = traces[key]

    # NOTE(review): requires a `sample` coordinate; the coordinates returned
    # alongside it become default legend entries.
    arg_df = traces.sample.to_dataframe()

    # Stack all dims besides those spanned by x. Comparing against the dims of
    # x (rather than the names of its coordinates) also covers dimensions that
    # have no index coordinate.
    x_dims = set(traces[x].dims)
    stack_dims = [dim for dim in traces.dims if dim not in x_dims]
    traces = traces.stack(traces=stack_dims)

    num_stacks = traces.sizes.get('traces', 0)
    if num_stacks > 250:
        warnings.warn(
            f"The traces stack has {num_stacks} elements.",
            UserWarning,
            stacklevel=2,
        )

    original_legend_labels = [col for col in arg_df.columns if col != 'sample']
    if x in original_legend_labels:
        # cannot set legend for x-axis values
        original_legend_labels = []

    legend_labels = legend_labels or (
        list(stack_dims) + extra_legend_coords + original_legend_labels
    )
    legend_table = [np.atleast_1d(traces[col].data) for col in legend_labels]

    # Confirm all elements are 1D arrays of equal length
    try:
        legend_table = np.column_stack(legend_table)
    except ValueError as e:
        raise ValueError(
            "All elements in legend_coords must be 1D arrays of equal length. "
            "Specify legend_labels to use only a subset of the coordinates."
        ) from e

    legend_info = np.array([
        ", ".join(f"{col}: {value}" for col, value in zip(legend_labels, row))
        for row in legend_table
    ])
    traces = traces.assign_coords(legend_info=("traces", legend_info))

    traces.plot.line(x=x, hue="legend_info", **plot_kwargs)

    # xarray draws on plot_kwargs["ax"] without making it the current axes, so
    # plt.gca() could return a different axes. Only fall back to it otherwise.
    ax = plot_kwargs.get("ax")
    if ax is None:
        ax = plt.gca()

    legend = ax.get_legend()
    if legend is not None:
        legend.set_title(None)

    return ax