Utils
flyvision.utils.activity_utils
Classes
flyvision.utils.activity_utils.CellTypeActivity
Bases: dict
Base class for attribute-style access to network activity based on cell types.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
keepref | bool | Whether to keep a reference to the activity. This may not be desired during training to avoid memory issues. | False |
Attributes:
Name | Type | Description |
---|---|---|
activity | Union[ref, NDArray, Tensor] | Reference to the activity: a weak reference by default, the activity itself if keepref=True. |
keepref | | Whether to keep a reference to the activity. |
unique_cell_types | List[str] | List of unique cell types. |
input_indices | NDArray | Indices of input cells. |
output_indices | NDArray | Indices of output cells. |
Note
Activity is stored as a weak reference by default for memory efficiency during training. Set keepref=True to keep a strong reference, e.g., for analysis.
Source code in flyvision/utils/activity_utils.py
update
update(activity)
Update the activity reference.
Source code in flyvision/utils/activity_utils.py
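The difference between the default weak reference and keepref=True can be illustrated with the standard-library weakref module. The class below is a hypothetical stand-in, not flyvision's implementation; it only assumes that the stored activity supports weak references, as NumPy arrays do.

```python
import gc
import weakref

import numpy as np


class ActivityHolder(dict):
    """Hypothetical sketch: store activity weakly (default) or strongly."""

    def __init__(self, keepref: bool = False):
        super().__init__()
        self.keepref = keepref
        self.activity = None

    def update(self, activity):
        # A weak reference does not keep the (possibly large) array alive
        self.activity = activity if self.keepref else weakref.ref(activity)

    def resolve(self):
        # Calling a weakref returns the object, or None once it has been freed
        return self.activity if self.keepref else self.activity()


weak = ActivityHolder()
strong = ActivityHolder(keepref=True)
weak.update(np.zeros((4, 10)))    # nothing else references this array
strong.update(np.zeros((4, 10)))
gc.collect()
print(weak.resolve())             # None in CPython: the array was freed
print(strong.resolve().shape)     # the strong reference kept the array alive
```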
flyvision.utils.activity_utils.CentralActivity
Bases: CellTypeActivity
Attribute-style access to central cell activity of a cell type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
activity | Union[NDArray, Tensor] | Activity of shape (…, n_cells). | required |
connectome | ConnectomeFromAvgFilters | Connectome directory with reference to required attributes. | required |
keepref | bool | Whether to keep a reference to the activity. | False |
Attributes:
Name | Type | Description |
---|---|---|
activity | | Activity of shape (…, n_cells). |
unique_cell_types | | Array of unique cell types. |
index | | NodeIndexer instance. |
input_indices | | Array of input indices. |
output_indices | | Array of output indices. |
Source code in flyvision/utils/activity_utils.py
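A minimal sketch of what "central activity" means, assuming each cell carries hexagonal lattice coordinates (u, v) with the central column at (0, 0). The node table and the helper central_activity are hypothetical; CentralActivity itself derives its indices from the connectome.

```python
import numpy as np

# Hypothetical node table: one entry per cell (cell type and hex coordinates)
cell_types = np.array(["L1", "L1", "L1", "T4a", "T4a"])
u = np.array([0, 1, -1, 0, 1])
v = np.array([0, 0, 1, 0, -1])


def central_activity(activity, cell_types, u, v):
    """Map each cell type to the activity of its central cell."""
    is_central = (u == 0) & (v == 0)
    result = {}
    for cell_type in np.unique(cell_types):
        # Index of the single central cell of this type along the last axis
        idx = np.flatnonzero((cell_types == cell_type) & is_central)[0]
        result[str(cell_type)] = activity[..., idx]
    return result


activity = np.arange(10.0).reshape(2, 5)  # (n_frames, n_cells)
central = central_activity(activity, cell_types, u, v)
print(central["T4a"])  # [3. 8.]
```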
flyvision.utils.activity_utils.LayerActivity
Bases: CellTypeActivity
Attribute-style access to hex-lattice activity (cell-type specific).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
activity | Union[NDArray, Tensor] | Activity of shape (…, n_cells). | required |
connectome | ConnectomeFromAvgFilters | Connectome directory with reference to required attributes. | required |
keepref | bool | Whether to keep a reference to the activity. | False |
use_central | bool | Whether to use central activity. | True |
Attributes:
Name | Type | Description |
---|---|---|
central | | CentralActivity instance for central nodes. |
activity | | Activity of shape (…, n_cells). |
connectome | | Connectome directory. |
unique_cell_types | | Array of unique cell types. |
input_indices | | Array of input indices. |
output_indices | | Array of output indices. |
input_cell_types | | Array of input cell types. |
output_cell_types | | Array of output cell types. |
n_nodes | | Number of nodes. |
Note
The name LayerActivity might change in the future because it is misleading: this is not a feedforward layer in the machine-learning sense but the activity of all cells of a certain cell type.
Example:
Central activity can be accessed by:
```python
a = LayerActivity(activity, network.connectome)
central_T4a = a.central.T4a
```
It also supports 'virtual types' that are the sum of individual cell types:
```python
a = LayerActivity(activity, network.connectome)
summed_a = a['L2+L4']
```
Source code in flyvision/utils/activity_utils.py
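The 'virtual type' lookup can be sketched as a dictionary that splits a key on "+" and sums the matching entries, combined with attribute-style access. This is a hypothetical illustration of the idea, not the LayerActivity source; it assumes the per-type activities are NumPy arrays of equal shape.

```python
import numpy as np


class SummingActivity(dict):
    """Hypothetical sketch: dict with attribute access and 'A+B' virtual keys."""

    def __getitem__(self, key):
        if isinstance(key, str) and "+" in key:
            # Virtual type: sum the activities of the individual cell types
            return sum(dict.__getitem__(self, part) for part in key.split("+"))
        return dict.__getitem__(self, key)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, e.g. a.L2
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err


a = SummingActivity(L2=np.array([1.0, 2.0]), L4=np.array([3.0, 4.0]))
print(a.L2)        # [1. 2.]
print(a["L2+L4"])  # [4. 6.]
```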
flyvision.utils.activity_utils.SourceCurrentView
Create views of source currents for a target type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rfs | ReceptiveFields | ReceptiveFields instance. | required |
currents | Union[NDArray, Tensor] | Current values. | required |
Attributes:
Name | Type | Description |
---|---|---|
target_type | | Target cell type. |
source_types | | List of source cell types. |
rfs | | ReceptiveFields instance. |
currents | | Current values. |
Source code in flyvision/utils/activity_utils.py
update
update(currents)
Update the currents.
Source code in flyvision/utils/activity_utils.py
flyvision.utils.cache_utils
Functions
flyvision.utils.cache_utils.context_aware_cache
context_aware_cache(func=None, context=lambda self: None)
Decorator to cache the result of a method based on its arguments and context.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
func | Callable[..., T] | The function to be decorated. | None |
context | Callable[[Any], Any] | A function that returns the context for caching. | lambda self: None |
Returns:
Type | Description |
---|---|
Callable[..., T] | A wrapped function that implements caching based on arguments and context. |
Example
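```python
from flyvision.utils.cache_utils import context_aware_cache

class MyClass:
    def __init__(self):
        self.cache = {}
        self.some_attribute = "a"  # value returned by the context function

    @context_aware_cache(context=lambda self: self.some_attribute)
    def my_method(self, arg1, arg2):
        # Method implementation
        return arg1 + arg2
```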
Source code in flyvision/utils/cache_utils.py
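The caching idea can be sketched in a few lines. This hypothetical decorator, simple_context_cache, is not flyvision's implementation: it assumes the instance has a dict attribute named cache (as in the example above) and that all arguments are already hashable, and it omits the func=None form that allows bare decorator use.

```python
import functools
from typing import Any, Callable


def simple_context_cache(context: Callable[[Any], Any] = lambda self: None):
    """Hypothetical sketch: cache method results per (method, args, kwargs, context)."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            key = (func.__name__, args, tuple(sorted(kwargs.items())), context(self))
            if key not in self.cache:
                self.cache[key] = func(self, *args, **kwargs)
            return self.cache[key]

        return wrapper

    return decorator


class Model:
    def __init__(self):
        self.cache = {}
        self.mode = "train"
        self.calls = 0

    @simple_context_cache(context=lambda self: self.mode)
    def compute(self, x):
        self.calls += 1
        return x * 2


m = Model()
m.compute(3)
m.compute(3)    # served from the cache
m.mode = "eval"
m.compute(3)    # new context, so the method runs again
print(m.calls)  # 2
```

Because the context value is part of the key, changing the attribute it reads invalidates nothing; it simply selects a different cache entry.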
flyvision.utils.cache_utils.make_hashable
make_hashable(obj)
Recursively converts an object into a hashable type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obj | Any | The object to be converted. | required |
Returns:
Type | Description |
---|---|
Any | A hashable representation of the input object. |
Note
This function handles various types including immutable types, lists, sets, dictionaries, tuples, frozensets, and slices. For complex objects, it falls back to string conversion, which may not be ideal for all use cases.
Source code in flyvision/utils/cache_utils.py
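A recursive conversion along the lines of the note above can be sketched as follows. to_hashable is a hypothetical helper, not make_hashable itself; in particular, it maps lists and tuples to the same tuple, and dictionaries to a frozenset of items so that insertion order does not matter.

```python
from typing import Any


def to_hashable(obj: Any) -> Any:
    """Hypothetical sketch: recursively convert obj into a hashable equivalent."""
    if obj is None or isinstance(obj, (bool, int, float, complex, str, bytes)):
        return obj
    if isinstance(obj, (list, tuple)):
        return tuple(to_hashable(item) for item in obj)
    if isinstance(obj, (set, frozenset)):
        return frozenset(to_hashable(item) for item in obj)
    if isinstance(obj, dict):
        # A frozenset of items makes the result independent of insertion order
        return frozenset((to_hashable(k), to_hashable(v)) for k, v in obj.items())
    if isinstance(obj, slice):
        # Slices are only hashable from Python 3.12 on
        return ("slice", obj.start, obj.stop, obj.step)
    return str(obj)  # string fallback for anything else


cache = {to_hashable({"lr": 0.1, "layers": [32, 64]}): "result"}
```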
flyvision.utils.chkpt_utils
Classes
flyvision.utils.chkpt_utils.Checkpoints (dataclass)
Dataclass to store checkpoint information.
Attributes:
Name | Type | Description |
---|---|---|
indices | List[int] | List of checkpoint indices. |
paths | List[Path] | List of checkpoint paths. |
Source code in flyvision/utils/chkpt_utils.py
Functions
flyvision.utils.chkpt_utils.recover_network
recover_network(
    network, state_dict, ensemble_and_network_id=None
)
Load network parameters from state dict.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
network | Module | FlyVision network. | required |
state_dict | Union[Dict, Path, str] | State or path to checkpoint containing the “network” parameters. | required |
ensemble_and_network_id | str | Optional identifier for the network. | None |
Returns:
Type | Description |
---|---|
Module | The updated network. |
Source code in flyvision/utils/chkpt_utils.py
flyvision.utils.chkpt_utils.recover_decoder
recover_decoder(decoder, state_dict, strict=True)
Recover multiple decoders from state dict.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
decoder | Dict[str, Module] | Dictionary of decoders. | required |
state_dict | Union[Dict, Path] | State or path to checkpoint. | required |
strict | bool | Whether to strictly enforce that the keys in state_dict match. | True |
Returns:
Type | Description |
---|---|
Dict[str, Module] | The updated dictionary of decoders. |
Source code in flyvision/utils/chkpt_utils.py
flyvision.utils.chkpt_utils.recover_optimizer
recover_optimizer(optimizer, state_dict)
Recover optimizer state from state dict.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
optimizer | Optimizer | PyTorch optimizer. | required |
state_dict | Union[Dict, Path] | State or path to checkpoint. | required |
Returns:
Type | Description |
---|---|
Optimizer | The updated optimizer. |
Source code in flyvision/utils/chkpt_utils.py
flyvision.utils.chkpt_utils.recover_penalty_optimizers
recover_penalty_optimizers(optimizers, state_dict)
Recover penalty optimizers from state dict.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
optimizers | Dict[str, Optimizer] | Dictionary of penalty optimizers. | required |
state_dict | Union[Dict, Path] | State or path to checkpoint. | required |
Returns:
Type | Description |
---|---|
Dict[str, Optimizer] | The updated dictionary of penalty optimizers. |
Source code in flyvision/utils/chkpt_utils.py
flyvision.utils.chkpt_utils.get_from_state_dict
get_from_state_dict(state_dict, key)
Get a specific key from the state dict.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
state_dict | Union[Dict, Path, str] | State dict or path to checkpoint. | required |
key | str | Key to retrieve from the state dict. | required |
Returns:
Type | Description |
---|---|
Dict | The value associated with the key in the state dict. |
Raises:
Type | Description |
---|---|
TypeError | If state_dict is not of type Path, str, or dict. |
Source code in flyvision/utils/chkpt_utils.py
flyvision.utils.chkpt_utils.resolve_checkpoints
resolve_checkpoints(networkdir)
Resolve checkpoints from network directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
networkdir | NetworkDir | FlyVision network directory. | required |
Returns:
Type | Description |
---|---|
Checkpoints | A Checkpoints object containing indices and paths of checkpoints. |
Source code in flyvision/utils/chkpt_utils.py
flyvision.utils.chkpt_utils.checkpoint_index_to_path_map
checkpoint_index_to_path_map(path, glob='chkpt_*')
Return the numerical identifiers of, and paths to, all checkpoints stored in path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Path | Checkpoint directory. | required |
glob | str | Glob pattern for checkpoint files. | 'chkpt_*' |
Returns:
Type | Description |
---|---|
Tuple[List[int], List[Path]] | A tuple containing a list of indices and a list of paths to checkpoints. |
Source code in flyvision/utils/chkpt_utils.py
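One way to build such a mapping, assuming checkpoint files end in an integer such as chkpt_12 (the naming scheme is an assumption, and index_checkpoints is a hypothetical helper). Sorting by the parsed integer rather than by file name avoids placing chkpt_10 before chkpt_2 when indices are not zero-padded.

```python
import re
import tempfile
from pathlib import Path
from typing import List, Tuple


def index_checkpoints(path: Path, glob: str = "chkpt_*") -> Tuple[List[int], List[Path]]:
    """Hypothetical sketch: return checkpoint indices and paths, sorted by index."""
    found = []
    for p in Path(path).glob(glob):
        match = re.search(r"(\d+)$", p.stem)  # trailing integer in the file name
        if match:
            found.append((int(match.group(1)), p))
    found.sort()
    return [i for i, _ in found], [p for _, p in found]


with tempfile.TemporaryDirectory() as tmp:
    for name in ["chkpt_10", "chkpt_2", "chkpt_1", "notes.txt"]:
        (Path(tmp) / name).touch()
    indices, paths = index_checkpoints(Path(tmp))
    names = [p.name for p in paths]

print(indices, names)
```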
flyvision.utils.chkpt_utils.best_checkpoint_default_fn
best_checkpoint_default_fn(
    path,
    validation_subdir="validation",
    loss_file_name="loss",
)
Find the best checkpoint based on the minimum loss.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Path | Path to the network directory. | required |
validation_subdir | str | Subdirectory containing validation data. | 'validation' |
loss_file_name | str | Name of the loss file. | 'loss' |
Returns:
Type | Description |
---|---|
Path | Path to the best checkpoint. |
Source code in flyvision/utils/chkpt_utils.py
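Selecting the best checkpoint reduces to an argmin over per-checkpoint losses. The sketch below assumes the losses are already loaded into a sequence aligned with the checkpoint paths (how flyvision reads the loss file is not shown here) and skips NaN entries; best_checkpoint is a hypothetical helper.

```python
from pathlib import Path
from typing import List, Sequence

import numpy as np


def best_checkpoint(paths: List[Path], losses: Sequence[float]) -> Path:
    """Hypothetical sketch: return the checkpoint path with the lowest non-NaN loss."""
    values = np.asarray(losses, dtype=float)
    if len(values) != len(paths):
        raise ValueError("expected one loss value per checkpoint")
    # nanargmin ignores NaN entries; it raises ValueError if all entries are NaN
    return paths[int(np.nanargmin(values))]


paths = [Path("chkpt_00000"), Path("chkpt_00001"), Path("chkpt_00002")]
print(best_checkpoint(paths, [0.9, float("nan"), 0.4]))  # chkpt_00002
```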
flyvision.utils.chkpt_utils.check_loss_name
check_loss_name(loss_folder, loss_file_name)
Check if the loss file name exists in the loss folder.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss_folder | | The folder containing loss files. | required |
loss_file_name | str | The name of the loss file to check. | required |
Returns:
Type | Description |
---|---|
str | The validated loss file name. |
Source code in flyvision/utils/chkpt_utils.py
flyvision.utils.class_utils
Functions
flyvision.utils.class_utils.find_subclass
find_subclass(cls, target_subclass_name)
Recursively search for the target subclass.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cls | Type | The base class to start the search from. | required |
target_subclass_name | str | The name of the subclass to find. | required |
Returns:
Type | Description |
---|---|
Optional[Type] | The found subclass, or None if not found. |
Source code in flyvision/utils/class_utils.py
flyvision.utils.class_utils.forward_subclass
forward_subclass(
    cls, config={}, subclass_key="type", unpack_kwargs=True
)
Forward to a subclass based on the <subclass_key> key in config. Forwards to the parent class if <subclass_key> is not in config.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cls | Type | The base class to forward from. | required |
config | Dict[str, Any] | Configuration dictionary containing subclass information. | {} |
subclass_key | str | Key in the config dictionary specifying the subclass. | 'type' |
unpack_kwargs | bool | Whether to unpack kwargs when initializing the instance. | True |
Returns:
Type | Description |
---|---|
Any | An instance of the specified subclass or the base class. |
Note
If the specified subclass is not found, a warning is issued and the base class is used instead.
Source code in flyvision/utils/class_utils.py
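The two functions above can be sketched together: a depth-first search through __subclasses__() and a factory that instantiates the match, warning and falling back to the base class otherwise. Both helpers are hypothetical; popping the subclass key before passing the config on is an assumption, not documented behavior.

```python
import warnings
from typing import Any, Dict, Optional, Type


def find_subclass_by_name(cls: Type, name: str) -> Optional[Type]:
    """Depth-first search through cls.__subclasses__() for a class called name."""
    for sub in cls.__subclasses__():
        if sub.__name__ == name:
            return sub
        found = find_subclass_by_name(sub, name)
        if found is not None:
            return found
    return None


def forward_to_subclass(cls: Type, config: Optional[Dict[str, Any]] = None,
                        subclass_key: str = "type", unpack_kwargs: bool = True) -> Any:
    """Instantiate the subclass named in config[subclass_key], else cls itself."""
    config = dict(config or {})
    target = cls
    name = config.pop(subclass_key, None)  # assumption: the key is not passed on
    if name is not None:
        found = find_subclass_by_name(cls, name)
        if found is None:
            warnings.warn(f"{name} not found; using {cls.__name__}")
        else:
            target = found
    return target(**config) if unpack_kwargs else target(config)


class Shape:
    def __init__(self, radius=1.0):
        self.radius = radius


class Circle(Shape):
    pass


obj = forward_to_subclass(Shape, {"type": "Circle", "radius": 2.0})
print(type(obj).__name__, obj.radius)  # Circle 2.0
```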
flyvision.utils.color_utils
Classes
flyvision.utils.color_utils.cmap_iter
An iterator for colormap colors.
Attributes:
Name | Type | Description |
---|---|---|
i | int | The current index. |
cmap | | The colormap to iterate over. |
stop | int | The number of colors in the colormap. |
Source code in flyvision/utils/color_utils.py
__init__
__init__(cmap)
Initialize the cmap_iter.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cmap | Union[LinearSegmentedColormap, ListedColormap] | The colormap to iterate over. | required |
Source code in flyvision/utils/color_utils.py
__next__
__next__()
Get the next color from the colormap.
Returns:
Type | Description |
---|---|
Tuple[float, float, float, float] | The next color as an RGBA tuple. |
Source code in flyvision/utils/color_utils.py
Functions
flyvision.utils.color_utils.is_hex
is_hex(color)
Check if the given color is in hexadecimal format.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
color | Union[str, Tuple[float, float, float]] | The color to check. | required |
Returns:
Type | Description |
---|---|
bool | True if the color is in hexadecimal format, False otherwise. |
Source code in flyvision/utils/color_utils.py
flyvision.utils.color_utils.is_integer_rgb
is_integer_rgb(color)
Check if the given color is in integer RGB format (0-255).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
color | Union[Tuple[float, float, float], List[float]] | The color to check. | required |
Returns:
Type | Description |
---|---|
bool | True if the color is in integer RGB format, False otherwise. |
Source code in flyvision/utils/color_utils.py
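Both format checks can be sketched with the standard library. The accepted hex forms (#RRGGBB and #RRGGBBAA) and the integer-RGB heuristic are assumptions of these hypothetical helpers: a triple such as (0, 0, 1) fits both the 0-1 float and the 0-255 integer convention, so integer RGB is only reported when some channel exceeds 1.

```python
import re
from typing import Sequence, Union

HEX_PATTERN = re.compile(r"^#(?:[0-9a-fA-F]{6}|[0-9a-fA-F]{8})$")


def looks_like_hex(color: Union[str, Sequence[float]]) -> bool:
    """True for '#RRGGBB' or '#RRGGBBAA' strings."""
    return isinstance(color, str) and HEX_PATTERN.match(color) is not None


def looks_like_integer_rgb(color: Sequence[float]) -> bool:
    """True if all three channels are whole numbers in 0..255 and one exceeds 1."""
    if len(color) != 3:
        return False
    whole = all(float(c).is_integer() and 0 <= c <= 255 for c in color)
    return whole and max(color) > 1


print(looks_like_hex("#1f77b4"), looks_like_integer_rgb((255, 128, 0)))  # True True
```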
flyvision.utils.color_utils.single_color_cmap
single_color_cmap(color)
Create a single-color colormap.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
color | Union[str, Tuple[float, float, float]] | The color to use for the colormap. | required |
Returns:
Type | Description |
---|---|
ListedColormap | A ListedColormap object with the specified color. |
Source code in flyvision/utils/color_utils.py
flyvision.utils.color_utils.color_to_cmap
color_to_cmap(
    end_color,
    start_color="#FFFFFF",
    name="custom_cmap",
    N=256,
)
Create a colormap from start and end colors.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
end_color | str | The end color of the colormap. | required |
start_color | str | The start color of the colormap. | '#FFFFFF' |
name | str | The name of the colormap. | 'custom_cmap' |
N | int | The number of color segments. | 256 |
Returns:
Type | Description |
---|---|
LinearSegmentedColormap | A LinearSegmentedColormap object. |
Source code in flyvision/utils/color_utils.py
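A two-color colormap of this kind can be built with matplotlib's LinearSegmentedColormap.from_list, which interpolates linearly between the listed colors. The wrapper two_color_cmap is a hypothetical sketch with the same parameters as color_to_cmap, not its source.

```python
from matplotlib.colors import LinearSegmentedColormap


def two_color_cmap(end_color, start_color="#FFFFFF", name="custom_cmap", N=256):
    """Hypothetical sketch: linear blend from start_color to end_color."""
    return LinearSegmentedColormap.from_list(name, [start_color, end_color], N=N)


cmap = two_color_cmap("#000000")
low = cmap(0.0)   # RGBA of the start color (white)
high = cmap(1.0)  # RGBA of the end color (black)
```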
flyvision.utils.color_utils.get_alpha_colormap
get_alpha_colormap(saturated_color, number_of_shades)
Create a colormap from a color and a number of shades.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
saturated_color | Union[str, Tuple[float, float, float]] | The base color for the colormap. | required |
number_of_shades | int | The number of shades to create. | required |
Returns:
Type | Description |
---|---|
ListedColormap | A ListedColormap object with varying alpha values. |
Source code in flyvision/utils/color_utils.py
flyvision.utils.color_utils.adapt_color_alpha
adapt_color_alpha(color, alpha)
Transform a color specification to RGBA and adapt the alpha value.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
color | Union[str, Tuple[float, float, float], Tuple[float, float, float, float]] | Color specification in various formats: hex string, RGB tuple, or RGBA tuple. | required |
alpha | float | New alpha value to be applied. | required |
Returns:
Type | Description |
---|---|
Tuple[float, float, float, float] | The adapted color in RGBA format. |
Source code in flyvision/utils/color_utils.py
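matplotlib.colors.to_rgba already accepts hex strings, RGB tuples and RGBA tuples, and its alpha argument overrides any alpha already present, so a function like this one can be sketched as a thin wrapper. with_alpha is hypothetical; whether adapt_color_alpha relies on to_rgba is not stated here.

```python
from matplotlib.colors import to_rgba


def with_alpha(color, alpha: float):
    """Hypothetical sketch: convert a matplotlib color spec to RGBA with a new alpha."""
    # The alpha argument replaces any alpha channel already present in color
    return to_rgba(color, alpha=alpha)


half_red = with_alpha("#ff0000", 0.5)
faint_blue = with_alpha((0, 0, 1, 1.0), 0.2)
```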
flyvision.utils.color_utils.flash_response_color_labels
flash_response_color_labels(ax)
Apply color labels for ON and OFF flash responses.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ax | | The matplotlib axis to apply the labels to. | required |
Returns:
Type | Description |
---|---|
| | The modified matplotlib axis. |
Source code in flyvision/utils/color_utils.py
flyvision.utils.color_utils.truncate_colormap
truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100)
Truncate a colormap to a specific range.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cmap | Union[LinearSegmentedColormap, ListedColormap] | The colormap to truncate. | required |
minval | float | The minimum value of the new range. | 0.0 |
maxval | float | The maximum value of the new range. | 1.0 |
n | int | The number of color segments in the new colormap. | 100 |
Returns:
Type | Description |
---|---|
LinearSegmentedColormap | A new LinearSegmentedColormap with the truncated range. |
Source code in flyvision/utils/color_utils.py
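A common way to truncate a colormap is to sample it at n evenly spaced points in [minval, maxval] and build a new colormap from those samples. The helper truncate below is a hypothetical sketch of that approach, not necessarily how truncate_colormap is written.

```python
import numpy as np
from matplotlib.colors import LinearSegmentedColormap


def truncate(cmap, minval=0.0, maxval=1.0, n=100):
    """Hypothetical sketch: resample cmap on [minval, maxval] into a new colormap."""
    colors = cmap(np.linspace(minval, maxval, n))  # (n, 4) array of RGBA samples
    name = f"trunc({cmap.name},{minval:.2f},{maxval:.2f})"
    return LinearSegmentedColormap.from_list(name, colors)


greys = LinearSegmentedColormap.from_list("white_black", ["white", "black"])
lighter = truncate(greys, 0.0, 0.5)  # keep only the light half
```

The new colormap spans the full 0 to 1 input range again, so lighter(1.0) now gives roughly the mid-grey that greys(0.5) gave.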
flyvision.utils.compute_cloud_utils
Classes
flyvision.utils.compute_cloud_utils.ClusterManager
Bases: ABC
Abstract base class for cluster management operations.
Source code in flyvision/utils/compute_cloud_utils.py
run_job (abstractmethod)
run_job(command)
Run a job on the cluster.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
command | str | The command to run. | required |
Returns:
Type | Description |
---|---|
str | The job ID as a string. |
Source code in flyvision/utils/compute_cloud_utils.py
is_running (abstractmethod)
is_running(job_id)
Check if a job is running.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
job_id | str | The ID of the job to check. | required |
Returns:
Type | Description |
---|---|
bool | True if the job is running, False otherwise. |
Source code in flyvision/utils/compute_cloud_utils.py
kill_job (abstractmethod)
kill_job(job_id)
Kill a running job.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
job_id | str | The ID of the job to kill. | required |
Returns:
Type | Description |
---|---|
str | The output of the kill command. |
Source code in flyvision/utils/compute_cloud_utils.py
get_submit_command (abstractmethod)
get_submit_command(
    job_name, n_cpus, output_file, gpu, queue
)
Get the command to submit a job to the cluster.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
job_name | str | The name of the job. | required |
n_cpus | int | The number of CPUs to request. | required |
output_file | str | The file to write job output to. | required |
gpu | str | The GPU configuration. | required |
queue | str | The queue to submit the job to. | required |
Returns:
Type | Description |
---|---|
str | The submit command as a string. |
Source code in flyvision/utils/compute_cloud_utils.py
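The abstract interface can be exercised without a cluster by a concrete subclass that records jobs in memory, e.g. for tests or dry runs. Both classes are hypothetical sketches that mirror the four documented methods; the submit-command string is a placeholder, not LSF or SLURM syntax.

```python
import itertools
from abc import ABC, abstractmethod


class ClusterManagerSketch(ABC):
    """Hypothetical mirror of the abstract interface documented above."""

    @abstractmethod
    def run_job(self, command: str) -> str: ...

    @abstractmethod
    def is_running(self, job_id: str) -> bool: ...

    @abstractmethod
    def kill_job(self, job_id: str) -> str: ...

    @abstractmethod
    def get_submit_command(self, job_name: str, n_cpus: int, output_file: str,
                           gpu: str, queue: str) -> str: ...


class InMemoryManager(ClusterManagerSketch):
    """Records jobs instead of submitting them."""

    def __init__(self):
        self.jobs = {}
        self._ids = itertools.count(1)  # job IDs are never reused

    def run_job(self, command: str) -> str:
        job_id = str(next(self._ids))
        self.jobs[job_id] = command
        return job_id

    def is_running(self, job_id: str) -> bool:
        return job_id in self.jobs

    def kill_job(self, job_id: str) -> str:
        self.jobs.pop(job_id, None)
        return f"killed {job_id}"

    def get_submit_command(self, job_name, n_cpus, output_file, gpu, queue) -> str:
        # Placeholder format; real schedulers each have their own flags
        return f"submit name={job_name} cpus={n_cpus} out={output_file} gpu={gpu} queue={queue}"


manager = InMemoryManager()
job_id = manager.run_job("python train.py")
```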
flyvision.utils.compute_cloud_utils.LSFManager
Bases: ClusterManager
Cluster manager for LSF (Load Sharing Facility) systems.
Source code in flyvision/utils/compute_cloud_utils.py
flyvision.utils.compute_cloud_utils.SLURMManager
Bases: ClusterManager
Cluster manager for SLURM systems.
Warning
This is untested.
Source code in flyvision/utils/compute_cloud_utils.py
Functions
flyvision.utils.compute_cloud_utils.get_cluster_manager
get_cluster_manager()
Autodetect the cluster type and return the appropriate ClusterManager.
Returns:
Type | Description |
---|---|
ClusterManager | An instance of the appropriate ClusterManager subclass. |
Raises:
Type | Description |
---|---|
RuntimeError | If neither LSF nor SLURM commands are found. |
Source code in flyvision/utils/compute_cloud_utils.py
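Autodetection of this kind can be sketched by looking for each scheduler's submit command on PATH with shutil.which. Checking specifically for bsub (LSF) and sbatch (SLURM) is an assumption, and the hypothetical detect_scheduler returns a label instead of a manager instance; the which argument is injectable so the logic can be tested on any machine.

```python
import shutil
from typing import Callable, Optional


def detect_scheduler(which: Callable[[str], Optional[str]] = shutil.which) -> str:
    """Hypothetical sketch: return 'lsf' or 'slurm' based on commands found on PATH."""
    if which("bsub") is not None:    # LSF submission command
        return "lsf"
    if which("sbatch") is not None:  # SLURM submission command
        return "slurm"
    raise RuntimeError("Neither LSF (bsub) nor SLURM (sbatch) commands were found.")
```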
flyvision.utils.compute_cloud_utils.run_job
run_job(command, dry)
Run a job on the cluster.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
command | str | The command to run. | required |
dry | bool | If True, perform a dry run without actually submitting the job. | required |
Returns:
Type | Description |
---|---|
str | The job ID as a string, or “dry run” for dry runs. |
Source code in flyvision/utils/compute_cloud_utils.py
flyvision.utils.compute_cloud_utils.is_running
is_running(job_id, dry)
Check if a job is running.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
job_id | str | The ID of the job to check. | required |
dry | bool | If True, always return False. | required |
Returns:
Type | Description |
---|---|
bool | True if the job is running, False otherwise. |
Source code in flyvision/utils/compute_cloud_utils.py
flyvision.utils.compute_cloud_utils.kill_job
kill_job(job_id, dry)
Kill a running job.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
job_id | str | The ID of the job to kill. | required |
dry | bool | If True, return a message without actually killing the job. | required |
Returns:
Type | Description |
---|---|
str | The output of the kill command or a dry run message. |
Source code in flyvision/utils/compute_cloud_utils.py
flyvision.utils.compute_cloud_utils.wait_for_single
wait_for_single(job_id, job_name, dry=False)
Wait for a single job to finish on the cluster.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
job_id | str | The ID of the job to wait for. | required |
job_name | str | The name of the job. | required |
dry | bool | If True, skip actual waiting. | False |
Raises:
Type | Description |
---|---|
KeyboardInterrupt | If the waiting is interrupted by the user. |
Source code in flyvision/utils/compute_cloud_utils.py
flyvision.utils.compute_cloud_utils.wait_for_many
wait_for_many(job_id_names, dry=False)
Wait for multiple jobs to finish on the cluster.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
job_id_names | Dict[str, str] | A dictionary mapping job IDs to job names. | required |
dry | bool | If True, skip actual waiting. | False |
Raises:
Type | Description |
---|---|
KeyboardInterrupt | If the waiting is interrupted by the user. |
Source code in flyvision/utils/compute_cloud_utils.py
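Waiting for several jobs amounts to polling each pending job until none is running. The hypothetical wait_for_jobs below takes the status check as a callable (standing in for is_running) and a poll interval of its own choosing; pressing Ctrl+C during the sleep raises KeyboardInterrupt, which it does not catch. The fake status function only exists to make the sketch runnable without a cluster.

```python
import time
from typing import Callable, Dict


def wait_for_jobs(job_id_names: Dict[str, str],
                  is_running: Callable[[str], bool],
                  poll_interval: float = 10.0) -> None:
    """Hypothetical sketch: block until none of the given jobs is running."""
    pending = dict(job_id_names)
    while pending:
        for job_id in list(pending):
            if not is_running(job_id):
                print(f"{pending.pop(job_id)} ({job_id}) finished")
        if pending:
            time.sleep(poll_interval)  # Ctrl+C here raises KeyboardInterrupt


# Fake status check: job "101" stays running for one poll, job "102" is already done
remaining_polls = {"101": 2, "102": 1}


def fake_is_running(job_id: str) -> bool:
    remaining_polls[job_id] -= 1
    return remaining_polls[job_id] > 0


wait_for_jobs({"101": "net_000", "102": "net_001"}, fake_is_running, poll_interval=0)
```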
flyvision.utils.compute_cloud_utils.check_valid_host
check_valid_host(blacklist)
Prevent running on certain blacklisted hosts, e.g., login nodes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
blacklist | List[str] | A list of blacklisted hostnames or substrings. | required |
Raises:
Type | Description |
---|---|
ValueError | If the current host is in the blacklist. |
Source code in flyvision/utils/compute_cloud_utils.py
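A substring-based host check can be sketched with socket.gethostname. The hypothetical ensure_valid_host takes an optional hostname so the behavior can be demonstrated without running on a login node; matching any blacklist entry as a substring follows the parameter description above.

```python
import socket
from typing import List, Optional


def ensure_valid_host(blacklist: List[str], hostname: Optional[str] = None) -> None:
    """Hypothetical sketch: raise ValueError if the host matches a blacklisted substring."""
    host = hostname if hostname is not None else socket.gethostname()
    for entry in blacklist:
        if entry in host:
            raise ValueError(f"Refusing to run on blacklisted host {host!r} (matched {entry!r}).")


ensure_valid_host(["login"], hostname="gpu-node-07")  # passes silently
```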
flyvision.utils.compute_cloud_utils.launch_range
launch_range(
    start,
    end,
    ensemble_id,
    task_name,
    nP,
    gpu,
    q,
    script,
    dry,
    kwargs,
)
Launch a range of models.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
start | int | The starting index for the range. | required |
end | int | The ending index for the range. | required |
ensemble_id | str | The ID of the ensemble. | required |
task_name | str | The name of the task. | required |
nP | int | The number of processors to use. | required |
gpu | str | The GPU configuration. | required |
q | str | The queue to submit the job to. | required |
script | str | The script to run. | required |
dry | bool | If True, perform a dry run without actually submitting jobs. | required |
kwargs | List[str] | A list of additional keyword arguments for the script. | required |
Note
kwargs is an ordered list of strings, either in the format [“-kw”, “val”, …] or following Hydra syntax, i.e. [“kw=val”, …].
Source code in flyvision/utils/compute_cloud_utils.py
flyvision.utils.compute_cloud_utils.launch_single
launch_single(
    ensemble_id, task_name, nP, gpu, q, script, dry, kwargs
)
Launch a single job for an ensemble.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ensemble_id | str | The ID of the ensemble. | required |
task_name | str | The name of the task. | required |
nP | int | The number of processors to use. | required |
gpu | str | The GPU configuration. | required |
q | str | The queue to submit the job to. | required |
script | str | The script to run. | required |
dry | bool | If True, perform a dry run without actually submitting the job. | required |
kwargs | List[str] | A list of additional keyword arguments for the script. | required |
Note
kwargs is an ordered list of strings, either in the format [“-kw”, “val”, …] or following Hydra syntax, i.e. [“kw=val”, …].
Source code in flyvision/utils/compute_cloud_utils.py
flyvision.utils.config_utils
Classes
flyvision.utils.config_utils.HybridArgumentParser
Bases: ArgumentParser
Hybrid argument parser that can parse unknown arguments in basic key=value style.
Attributes:
Name | Type | Description |
---|---|---|
hybrid_args | | Dictionary of hybrid arguments with their requirements and help texts. |
allow_unrecognized | | Whether to allow unrecognized arguments. |
drop_disjoint_from | | Path to a configuration file used to filter out arguments that are present in the command line arguments but not in the configuration file. This allows arguments to be passed through multiple scripts, which Hydra does not support. |
Parameters:
Name | Type | Description | Default |
---|---|---|---|
hybrid_args | Optional[Dict[str, Dict[str, Any]]] | Dictionary of hybrid arguments with their requirements and help texts. | None |
allow_unrecognized | bool | Whether to allow unrecognized arguments. | True |
Source code in flyvision/utils/config_utils.py
parse_with_hybrid_args
parse_with_hybrid_args(args=None, namespace=None)
Parse arguments and set hybrid arguments as attributes in the namespace.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
args | Optional[List[str]] | List of arguments to parse. | None |
namespace | Optional[Namespace] | Namespace to populate with parsed arguments. | None |
Returns:
Type | Description |
---|---|
Namespace | Namespace with parsed arguments. |
Raises:
Type | Description |
---|---|
ArgumentError | If required arguments are missing or invalid values are provided. |
Source code in flyvision/utils/config_utils.py
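The hybrid idea, regular flags plus free-form key=value pairs, can be sketched with argparse's parse_known_args, which returns unrecognized arguments instead of failing on them. parse_hybrid and its --dry flag are hypothetical; HybridArgumentParser additionally handles required hybrid arguments, help texts and filtering, which this sketch omits.

```python
import argparse
from typing import List


def parse_hybrid(argv: List[str]) -> argparse.Namespace:
    """Hypothetical sketch: standard flags plus free-form key=value pairs."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dry", action="store_true")
    # parse_known_args returns leftovers instead of raising an error on them
    args, unknown = parser.parse_known_args(argv)
    for item in unknown:
        key, sep, value = item.partition("=")
        if not sep:
            parser.error(f"unrecognized argument: {item}")
        setattr(args, key, value)  # values stay strings; cast them as needed
    return args


args = parse_hybrid(["--dry", "task_name=flow", "network_id=0"])
print(args.dry, args.task_name, args.network_id)  # True flow 0
```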
Functions¶
flyvision.utils.config_utils.get_default_config ¶
get_default_config(
overrides,
path="../../config/solver.yaml",
as_namespace=True,
)
Get the default configuration using Hydra.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
overrides | List[str] | List of configuration overrides. | required |
path | str | Path to the configuration file. | '../../config/solver.yaml' |
as_namespace | bool | Whether to return a namespaced configuration or the OmegaConf object. | True |
Returns:
Type | Description |
---|---|
Union[Dict[DictKeyType, Any], List[Any], None, str, Any, Namespace] | The configuration object. |
Note
Expected overrides are:
- task_name
- network_id
Source code in flyvision/utils/config_utils.py
flyvision.utils.config_utils.parse_kwargs_to_dict
parse_kwargs_to_dict(values)
Parse a list of key-value pairs into a dictionary.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
values | List[str] | List of key-value pairs in the format “key=value”. | required |
Returns:
Type | Description |
---|---|
Namespace | Namespace object with parsed key-value pairs as attributes. |
Source code in flyvision/utils/config_utils.py
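A minimal sketch of the key=value parsing described above, assuming each entry is split on its first `=` and values stay strings. `parse_kwargs_sketch` is a hypothetical name; the real function may cast values or treat malformed entries differently.

```python
from argparse import Namespace
from typing import List


def parse_kwargs_sketch(values: List[str]) -> Namespace:
    """Hypothetical sketch: turn ["a=1", "b=x"] into Namespace(a="1", b="x")."""
    parsed = {}
    for item in values:
        # Split on the first "=" only, so a value may itself contain "="
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"expected 'key=value', got {item!r}")
        parsed[key] = value
    return Namespace(**parsed)


ns = parse_kwargs_sketch(["task_name=flow", "network_id=0"])
```

Values are left as strings here; a companion such as safe_cast (below) would be the natural place to convert them.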
flyvision.utils.config_utils.safe_cast
safe_cast(value, type_name)
Safely cast a string value to a specified type.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
value | str | The string value to cast. | required |
type_name | str | The name of the type to cast to. | required |
Returns:
Type | Description |
---|---|
Union[int, float, bool, str] | The cast value. |
Note
Supports casting to int, float, bool, and str.
Source code in flyvision/utils/config_utils.py
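A hedged sketch of casting by type name. Which strings count as true or false for "bool" is an assumption here; the point it illustrates is that `bool("False")` is `True` in Python, so a plain `bool(value)` call is not a safe cast. `safe_cast_sketch` is a hypothetical name.

```python
from typing import Union


def safe_cast_sketch(value: str, type_name: str) -> Union[int, float, bool, str]:
    """Hypothetical sketch: cast a string to int, float, bool, or str by type name."""
    if type_name == "int":
        return int(value)
    if type_name == "float":
        return float(value)
    if type_name == "bool":
        # Assumption: accept a few common spellings instead of calling bool(value)
        lowered = value.strip().lower()
        if lowered in {"true", "1", "yes"}:
            return True
        if lowered in {"false", "0", "no"}:
            return False
        raise ValueError(f"cannot interpret {value!r} as bool")
    if type_name == "str":
        return value
    raise ValueError(f"unsupported type name: {type_name!r}")
```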
flyvision.utils.dataset_utils
Classes
flyvision.utils.dataset_utils.CrossValIndices
Returns folds of indices for cross-validation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_samples | int | Total number of samples. | required |
folds | int | Total number of folds. | required |
shuffle | bool | Whether to shuffle the indices. | True |
seed | int | Seed for shuffling. | 0 |
Attributes:
Name | Type | Description |
---|---|---|
n_samples | | Total number of samples. |
folds | | Total number of folds. |
indices | | Array of indices. |
random | | RandomState object for shuffling. |
Source code in flyvision/utils/dataset_utils.py
__call__
__call__(fold)
Returns train and test indices for a fold.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
fold | int | The fold number. | required |
Returns:
Type | Description |
---|---|
Tuple[ndarray, ndarray] | A tuple containing train and test indices. |
Source code in flyvision/utils/dataset_utils.py
iter
iter()
Iterate over all folds.
Yields:
Type | Description |
---|---|
Tuple[ndarray, ndarray] | A tuple containing train and test indices for each fold. |
Source code in flyvision/utils/dataset_utils.py
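The fold logic can be sketched with NumPy as below, mirroring the documented parameters and attributes. This is a hypothetical re-implementation, not the library code: in particular, using `np.array_split` (which gives the first folds one extra sample when `n_samples` is not divisible by `folds`) is an assumption about how leftover samples are distributed.

```python
from typing import Iterator, Tuple

import numpy as np


class CrossValIndicesSketch:
    """Hypothetical K-fold splitter with the documented parameters."""

    def __init__(self, n_samples: int, folds: int, shuffle: bool = True, seed: int = 0):
        self.n_samples = n_samples
        self.folds = folds
        self.indices = np.arange(n_samples)
        self.random = np.random.RandomState(seed)
        if shuffle:
            self.random.shuffle(self.indices)  # in-place, reproducible via seed

    def __call__(self, fold: int) -> Tuple[np.ndarray, np.ndarray]:
        # Split the (possibly shuffled) indices into `folds` nearly equal chunks
        chunks = np.array_split(self.indices, self.folds)
        test = chunks[fold]
        train = np.concatenate([c for i, c in enumerate(chunks) if i != fold])
        return train, test

    def iter(self) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
        for fold in range(self.folds):
            yield self(fold)


cv = CrossValIndicesSketch(n_samples=10, folds=3, seed=0)
splits = list(cv.iter())
```

Every sample lands in exactly one test fold, and train and test never overlap within a fold.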
flyvision.utils.dataset_utils.IndexSampler
Bases: Sampler
Samples the provided indices in sequence.
Note
To be used with torch.utils.data.DataLoader.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
indices | List[int] | List of indices to sample. | required |
Attributes:
Name | Type | Description |
---|---|---|
indices | | List of indices to sample. |
Source code in flyvision/utils/dataset_utils.py
Functions
flyvision.utils.dataset_utils.random_walk_of_blocks
random_walk_of_blocks(
n_blocks=20,
block_size=4,
top_lum=0,
bottom_lum=0,
dataset_size=[3, 20, 64, 64],
noise_mean=0.5,
noise_std=0.1,
step_size=4,
p_random=0.6,
p_center_attraction=0.3,
p_edge_attraction=0.1,
seed=42,
)
Generate a sequence dataset with blocks doing random walks.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_blocks | int | Number of blocks. | 20 |
block_size | int | Size of blocks. | 4 |
top_lum | float | Luminance of the top of the block. | 0 |
bottom_lum | float | Luminance of the bottom of the block. | 0 |
dataset_size | List[int] | Size of the dataset (n_sequences, n_frames, h, w). | [3, 20, 64, 64] |
noise_mean | float | Mean of the background noise. | 0.5 |
noise_std | float | Standard deviation of the background noise. | 0.1 |
step_size | int | Number of pixels to move in each step. | 4 |
p_random | float | Probability of moving randomly. | 0.6 |
p_center_attraction | float | Probability of moving towards the center. | 0.3 |
p_edge_attraction | float | Probability of moving towards the edge. | 0.1 |
seed | int | Seed for the random number generator. | 42 |
Returns:
Type | Description |
---|---|
ndarray | Dataset of shape (n_sequences, n_frames, h, w). |
Source code in flyvision/utils/dataset_utils.py
flyvision.utils.dataset_utils.load_moving_mnist
load_moving_mnist(delete_if_exists=False)
Return Moving MNIST dataset.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
delete_if_exists | bool | If True, delete the dataset if it exists. | False |
Returns:
Type | Description |
---|---|
ndarray | Dataset of shape (n_sequences, n_frames, h, w) == (10000, 20, 64, 64). |
Note
This dataset (0.78GB) will be downloaded if not present. The download is stored in flyvision.root_dir / “mnist_test_seq.npy”.
Source code in flyvision/utils/dataset_utils.py
flyvision.utils.dataset_utils.get_random_data_split
get_random_data_split(
fold, n_samples, n_folds, shuffle=True, seed=0
)
Return indices to split the data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
fold | int | The fold number. | required |
n_samples | int | Total number of samples. | required |
n_folds | int | Total number of folds. | required |
shuffle | bool | Whether to shuffle the indices. | True |
seed | int | Seed for shuffling. | 0 |
Returns:
Type | Description |
---|---|
Tuple[ndarray, ndarray] | A tuple containing train and validation indices. |
Source code in flyvision/utils/dataset_utils.py
flyvision.utils.df_utils
Functions
flyvision.utils.df_utils.filter_by_column_values
filter_by_column_values(dataframe, column, values)
Return the subset of rows of a dataframe whose value in a given column appears in a list of values.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataframe | DataFrame | DataFrame to filter. | required |
column | str | Column of the dataframe to filter on. | required |
values | Iterable | Values to keep, e.g. neuron types such as R1 or T4a. | required |
Returns:
Name | Type | Description |
---|---|---|
DataFrame | DataFrame | Subset of the input dataframe. |
Example
filtered_df = filter_by_column_values(df, 'neuron_type', ['R1', 'T4a'])
Source code in flyvision/utils/df_utils.py
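In pandas this kind of filter is typically a boolean mask built with `Series.isin`. The sketch below assumes that approach; `filter_by_column_values_sketch` and the toy dataframe are hypothetical and not taken from flyvision.

```python
from typing import Iterable

import pandas as pd


def filter_by_column_values_sketch(dataframe: pd.DataFrame, column: str, values: Iterable) -> pd.DataFrame:
    """Hypothetical sketch: keep rows whose `column` value is in `values`."""
    # list() lets generators be passed as `values`; isin returns a boolean Series
    return dataframe[dataframe[column].isin(list(values))]


df = pd.DataFrame({"type": ["R1", "T4a", "Mi1", "R1"], "u": [0, 1, 0, 2]})
subset = filter_by_column_values_sketch(df, "type", ["R1", "T4a"])
```

The original row labels are preserved in the subset, which keeps it easy to join back to the full table.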
flyvision.utils.df_utils.where_dataframe
where_dataframe(arg_df, **kwargs)
Return indices of rows in a DataFrame where conditions are met.
Conditions are passed as keyword arguments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
arg_df | DataFrame | Input DataFrame. | required |
**kwargs | | Keyword arguments representing conditions, one column=value pair each. | {} |
Returns:
Name | Type | Description |
---|---|---|
DataFrame | DataFrame | Indices of rows where conditions are met. |
Example
indices = where_dataframe(df, type='T4a', u=2, v=0)
Note
The dataframe is expected to have columns matching the keyword arguments.
Source code in flyvision/utils/df_utils.py
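A sketch of the condition matching, assuming each keyword is an equality test on the column of the same name and all conditions are combined with logical AND. It returns a `pd.Index` of row labels; that return type is an assumption and may not match the library's documented return.

```python
import numpy as np
import pandas as pd


def where_dataframe_sketch(arg_df: pd.DataFrame, **kwargs) -> pd.Index:
    """Hypothetical sketch: row labels where every column == value condition holds."""
    mask = np.ones(len(arg_df), dtype=bool)
    for column, value in kwargs.items():
        # One equality condition per keyword argument, combined with AND
        mask &= (arg_df[column] == value).to_numpy()
    return arg_df.index[mask]


df = pd.DataFrame({"type": ["T4a", "T4a", "T4b"], "u": [2, 0, 2], "v": [0, 0, 0]})
rows = where_dataframe_sketch(df, type="T4a", u=2, v=0)
```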
flyvision.utils.hex_utils
Classes
flyvision.utils.hex_utils.Hexal
Hexal representation containing u, v, z coordinates and value.
Attributes:
Name | Type | Description |
---|---|---|
u | | Coordinate in u principal direction (0 degree axis). |
v | | Coordinate in v principal direction (60 degree axis). |
z | | Coordinate in z principal direction (-60 degree axis). |
value | | Hexal value. |
u_stride | | Stride in u-direction. |
v_stride | | Stride in v-direction. |
Source code in flyvision/utils/hex_utils.py
__eq__
__eq__(other)
Compares coordinates (not values).
Source code in flyvision/utils/hex_utils.py
__add__
__add__(other)
Adds u and v coordinates, while keeping the value of the left hexal.
Source code in flyvision/utils/hex_utils.py
__mul__
__mul__(other)
Multiplies values, while preserving coordinates.
Source code in flyvision/utils/hex_utils.py
eq_val
eq_val(other)
Compares the values, not the coordinates.
Source code in flyvision/utils/hex_utils.py
neighbours
neighbours()
Returns 6 neighbours sorted CCW, starting from east.
Source code in flyvision/utils/hex_utils.py
is_neighbour
is_neighbour(other)
Evaluates if other is a neighbour.
Source code in flyvision/utils/hex_utils.py
unit_directions (staticmethod)
unit_directions()
Returns the six unit directions.
Source code in flyvision/utils/hex_utils.py
interp
interp(other, t)
Interpolates towards other.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
t | float | Interpolation step, 0 < t < 1. | required |
Returns:
Type | Description |
---|---|
Hexal |
Source code in flyvision/utils/hex_utils.py
angle
angle(other=None, non_negative=False)
Returns the angle to other or the origin.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
non_negative | bool | If True, add 2π to negative angles. | False |
Returns:
Name | Type | Description |
---|---|---|
float | | Angle in radians. |
Source code in flyvision/utils/hex_utils.py
distance
distance(other=None)
Returns the columnar distance between two hexals.
Source code in flyvision/utils/hex_utils.py
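The coordinate arithmetic behind neighbours, unit_directions and distance can be sketched on plain (u, v) tuples. This assumes the axis convention from the attribute table (u at 0°, v at 60°) and that z = -(u + v); both the neighbour order and the z relation are assumptions drawn from the descriptions above, not from the source.

```python
from typing import List, Tuple

# Assumption: u axis at 0 degrees, v axis at 60 degrees, z = -(u + v)
UNIT_DIRECTIONS: List[Tuple[int, int]] = [
    (1, 0),   # east, 0 degrees
    (0, 1),   # 60 degrees
    (-1, 1),  # 120 degrees
    (-1, 0),  # 180 degrees
    (0, -1),  # 240 degrees
    (1, -1),  # 300 degrees
]


def neighbours(u: int, v: int) -> List[Tuple[int, int]]:
    """The six adjacent hexals, counterclockwise starting from east."""
    return [(u + du, v + dv) for du, dv in UNIT_DIRECTIONS]


def distance(u1: int, v1: int, u2: int = 0, v2: int = 0) -> int:
    """Columnar (step) distance between two hexals; defaults to the origin."""
    du, dv = u1 - u2, v1 - v2
    dz = -(du + dv)
    # On a hex grid the step count is the largest absolute cube-coordinate difference
    return max(abs(du), abs(dv), abs(dz))
```

For example, (1, 1) is two steps from the origin even though both coordinates are 1, because the third axis difference |u + v| is 2.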
flyvision.utils.hex_utils.HexArray
Bases: ndarray
Flat array holding Hexals as elements.
Can be constructed with either of:
HexArray(hexals: Iterable, values: Optional[np.nan])
HexArray(u: Iterable, v: Iterable, values: Optional[np.nan])
Source code in flyvision/utils/hex_utils.py
get_extent (staticmethod)
get_extent(
hexals=None, u=None, v=None, center=Hexal(0, 0, 0)
)
Returns the columnar extent.
Source code in flyvision/utils/hex_utils.py
with_stride
with_stride(u_stride=None, v_stride=None)
Returns a sliced instance obeying strides in u- and v-direction.
Source code in flyvision/utils/hex_utils.py
where
where(value)
Returns a mask of where values are equal to the given one.
Note: value can be np.nan.
Source code in flyvision/utils/hex_utils.py
fill
fill(value)
Fills the values with the given one.
Source code in flyvision/utils/hex_utils.py
to_pixel
to_pixel(scale=1, mode='default')
Converts to pixel coordinates.
Source code in flyvision/utils/hex_utils.py
plot
plot(figsize=[3, 3], fill=True)
Plots values on a regular hexagonal lattice.
Meant for debugging.
Source code in flyvision/utils/hex_utils.py
flyvision.utils.hex_utils.HexLattice
Bases: HexArray
Flat array of Hexals.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
extent | | Extent of the regular hexagon grid. | required |
hexals | | Existing hexals to initialize with. | required |
center | | Center hexal of the lattice. | required |
u_stride | | Stride in u-direction. | required |
v_stride | | Stride in v-direction. | required |
Source code in flyvision/utils/hex_utils.py
circle
circle(
radius=None, center=Hexal(0, 0, 0), as_lattice=False
)
Draws a circle in hex coordinates.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
radius | | Radius in columns of the circle. | None |
center | | Center of the circle. | Hexal(0, 0, 0) |
as_lattice | | If True, returns the circle on a constrained regular lattice. | False |
Source code in flyvision/utils/hex_utils.py
filled_circle (staticmethod)
filled_circle(
radius=None, center=Hexal(0, 0, 0), as_lattice=False
)
Draws a filled circle in hex coordinates.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
radius | | Radius in columns of the circle. | None |
center | | Center of the circle. | Hexal(0, 0, 0) |
as_lattice | | If True, returns the circle on a constrained regular lattice. | False |
Source code in flyvision/utils/hex_utils.py
hull
hull()
Returns the hull of the regular lattice.
Source code in flyvision/utils/hex_utils.py
line
line(angle, center=Hexal(0, 0, 1), as_lattice=False)
Returns a line on a HexLattice or HexArray.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
angle | | Angle of the line in radians, in [0, np.pi]. | required |
center | | Midpoint of the line. | Hexal(0, 0, 1) |
as_lattice | | If True, returns the line on a constrained regular lattice. | False |
Returns:
Type | Description |
---|---|
HexArray or constrained HexLattice |
Source code in flyvision/utils/hex_utils.py
flyvision.utils.hex_utils.LatticeMask
Boolean masks for lattice dimension.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
extent | int | Extent of the hexagonal lattice. | 15 |
u_stride | int | Stride in u-direction. | 1 |
v_stride | int | Stride in v-direction. | 1 |
Source code in flyvision/utils/hex_utils.py
Functions
flyvision.utils.hex_utils.get_hex_coords
get_hex_coords(extent, astensor=False)
Construct hexagonal coordinates for a regular hex-lattice with extent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
extent | int | Integer radius of hexagonal lattice. 0 returns the single center coordinate. | required |
astensor | bool | If True, returns torch.Tensor, else np.array. | False |
Returns:
Type | Description |
---|---|
Tuple[NDArray, NDArray] | A tuple containing: u: Hex-coordinates in u-direction. v: Hex-coordinates in v-direction. |
Note
Will return get_num_hexals(extent) coordinates.
Source code in flyvision/utils/hex_utils.py
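A NumPy sketch of building a regular hex lattice of a given extent, assuming the axial convention from the Hexal attributes (u at 0°, v at 60°), under which a hexal lies within the lattice when max(|u|, |v|, |u + v|) ≤ extent. `get_hex_coords_sketch` is hypothetical; it returns coordinates sorted by u then v, and the library's ordering may differ.

```python
from typing import Tuple

import numpy as np


def get_hex_coords_sketch(extent: int) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical sketch: all axial (u, v) within `extent` steps of the origin."""
    r = np.arange(-extent, extent + 1)
    # indexing="ij" makes the flattened order sorted by u, then v
    u, v = np.meshgrid(r, r, indexing="ij")
    u, v = u.ravel(), v.ravel()
    # Keep a coordinate if its hex distance to the origin is within extent
    keep = np.maximum(np.maximum(np.abs(u), np.abs(v)), np.abs(u + v)) <= extent
    return u[keep], v[keep]


u, v = get_hex_coords_sketch(1)
```

The count matches get_num_hexals: for example, extent 15 yields 721 coordinates.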
flyvision.utils.hex_utils.hex_to_pixel
hex_to_pixel(u, v, size=1, mode='default')
Returns pixel coordinates from hex coordinates.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
u | NDArray | Hex-coordinates in u-direction. | required |
v | NDArray | Hex-coordinates in v-direction. | required |
size | float | Size of hexagon. | 1 |
mode | Literal['default', 'flat', 'pointy'] | Coordinate system convention. | 'default' |
Returns:
Type | Description |
---|---|
Tuple[NDArray, NDArray] | A tuple containing: x: Pixel-coordinates in x-direction. y: Pixel-coordinates in y-direction. |
Source code in flyvision/utils/hex_utils.py
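The projection from axial to planar coordinates can be sketched as a linear map of the two basis vectors. The library's 'default', 'flat' and 'pointy' modes may use other orientations or scalings; this sketch fixes one convention (u at 0°, v at 60°, neighbouring centres `size` apart), which is an assumption.

```python
from typing import Tuple

import numpy as np


def hex_to_pixel_sketch(u, v, size: float = 1.0) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical sketch: project axial (u, v) onto the plane.

    Assumes the u axis points along 0 degrees and the v axis along 60 degrees.
    """
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    x = size * (u + 0.5 * v)             # cos(60°) = 1/2
    y = size * (np.sqrt(3) / 2) * v      # sin(60°) = √3/2
    return x, y


x, y = hex_to_pixel_sketch(np.array([1, 0, -1]), np.array([0, 1, 1]))
```

The three example hexals are direct neighbours of the origin, so each lands exactly `size` away from (0, 0).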
flyvision.utils.hex_utils.hex_rows
hex_rows(n_rows, n_columns, eps=0.1, mode='pointy')
Return a hex grid in pixel coordinates.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_rows | int | Number of rows. | required |
n_columns | int | Number of columns. | required |
eps | float | Small offset to avoid overlapping hexagons. | 0.1 |
mode | Literal['pointy', 'flat'] | Orientation of hexagons. | 'pointy' |
Returns:
Type | Description |
---|---|
Tuple[NDArray, NDArray] | A tuple containing: x: X-coordinates of hexagon centers. y: Y-coordinates of hexagon centers. |
Source code in flyvision/utils/hex_utils.py
flyvision.utils.hex_utils.pixel_to_hex
pixel_to_hex(x, y, size=1, mode='default')
Returns hex coordinates from pixel coordinates.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | NDArray | Pixel-coordinates in x-direction. | required |
y | NDArray | Pixel-coordinates in y-direction. | required |
size | float | Size of hexagon. | 1 |
mode | Literal['default', 'flat', 'pointy'] | Coordinate system convention. | 'default' |
Returns:
Type | Description |
---|---|
Tuple[NDArray, NDArray] | A tuple containing: u: Hex-coordinates in u-direction. v: Hex-coordinates in v-direction. |
Source code in flyvision/utils/hex_utils.py
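Going from pixels back to hexals needs two steps: invert the linear projection to get fractional axial coordinates, then round to the nearest hexal. The sketch assumes the same convention as above (u at 0°, v at 60°) and uses the standard cube-coordinate rounding technique; whether flyvision rounds the same way is an assumption.

```python
from typing import Tuple

import numpy as np


def pixel_to_hex_sketch(x, y, size: float = 1.0) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical sketch: nearest axial hexal for each (x, y) point."""
    x = np.asarray(x, dtype=float) / size
    y = np.asarray(y, dtype=float) / size
    # Invert x = u + v/2, y = (√3/2) v to get fractional axial coordinates
    v = y * 2 / np.sqrt(3)
    u = x - v / 2
    z = -u - v
    # Cube rounding: round all three, then recompute the one with the largest error
    ru, rv, rz = np.round(u), np.round(v), np.round(z)
    du, dv, dz = np.abs(ru - u), np.abs(rv - v), np.abs(rz - z)
    fix_u = (du > dv) & (du > dz)
    fix_v = ~fix_u & (dv > dz)
    ru = np.where(fix_u, -rv - rz, ru)
    rv = np.where(fix_v, -ru - rz, rv)
    return ru.astype(int), rv.astype(int)


u, v = pixel_to_hex_sketch([0.45, 1.5], [0.8, -np.sqrt(3) / 2])
```

Rounding u and v independently can pick a hexal that is not the closest one; restoring the constraint u + v + z = 0 on the worst-rounded component avoids that.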
flyvision.utils.hex_utils.pad_to_regular_hex
pad_to_regular_hex(u, v, values, extent, value=np.nan)
Pad hexals with coordinates to a regular hex lattice.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
u | NDArray | U-coordinate of hexal. | required |
v | NDArray | V-coordinate of hexal. | required |
values | NDArray | Value of hexal, with arbitrary shape; the last axis must match the hexal dimension. | required |
extent | int | Extent of regular hex grid to pad to. | required |
value | float | The pad value. | nan |
Returns:
Type | Description |
---|---|
Tuple[NDArray, NDArray, NDArray] | A tuple containing: u_padded: Padded u-coordinate. v_padded: Padded v-coordinate. values_padded: Padded value. |
Note
The canonical use case here is to pad a filter, receptive field, or postsynaptic current field for visualization.
Example
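import numpy as np
import matplotlib.pyplot as plt
from flyvision.utils.hex_utils import pad_to_regular_hex
u = np.array([1, 0, -1, 0, 1, 2])
v = np.array([-2, -1, 0, 0, 0, 0])
values = np.array([0.05, 0.1, 0.3, 0.5, 0.7, 0.9])
hexals = pad_to_regular_hex(u, v, values, 6)
# hex_scatter is flyvision's hexagonal scatter plot; its module path depends on the installed version
hex_scatter(*hexals, edgecolor='k', cmap=plt.cm.Blues, vmin=0, vmax=1)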
Source code in flyvision/utils/hex_utils.py
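The padding step can be sketched as: build the full lattice of the requested extent, fill it with the pad value, and scatter the known values into their positions along the last axis. `pad_to_regular_hex_sketch` is hypothetical and assumes the axial convention used in the sketches above; it raises a KeyError if a hexal lies outside the extent.

```python
from typing import Tuple

import numpy as np


def pad_to_regular_hex_sketch(u, v, values, extent: int, value: float = np.nan) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Hypothetical sketch: place `values` on a full hex lattice, padding the rest."""
    u, v, values = np.asarray(u), np.asarray(v), np.asarray(values, dtype=float)
    # Full lattice of radius `extent` (u at 0°, v at 60°), sorted by u then v
    r = np.arange(-extent, extent + 1)
    uu, vv = (a.ravel() for a in np.meshgrid(r, r, indexing="ij"))
    keep = np.maximum(np.maximum(np.abs(uu), np.abs(vv)), np.abs(uu + vv)) <= extent
    uu, vv = uu[keep], vv[keep]
    # Leading axes of `values` are kept; the last axis becomes the lattice axis
    padded = np.full(values.shape[:-1] + (len(uu),), value, dtype=float)
    lookup = {(a, b): i for i, (a, b) in enumerate(zip(uu.tolist(), vv.tolist()))}
    for j, (a, b) in enumerate(zip(u.tolist(), v.tolist())):
        padded[..., lookup[(a, b)]] = values[..., j]
    return uu, vv, padded


u = np.array([1, 0, -1, 0, 1, 2])
v = np.array([-2, -1, 0, 0, 0, 0])
values = np.array([0.05, 0.1, 0.3, 0.5, 0.7, 0.9])
uu, vv, padded = pad_to_regular_hex_sketch(u, v, values, 6)
```

With extent 6 the lattice has 127 hexals, so 121 of them receive the NaN pad value.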
flyvision.utils.hex_utils.max_extent_index
max_extent_index(u, v, max_extent)
Returns a mask to constrain u and v axial-hex-coordinates by max_extent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
u | NDArray | Hex-coordinates in u-direction. | required |
v | NDArray | Hex-coordinates in v-direction. | required |
max_extent | int | Maximal extent. | required |
Returns:
Type | Description |
---|---|
NDArray | Boolean mask. |
Source code in flyvision/utils/hex_utils.py
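A one-line mask is enough for this, assuming the constraint means "hex distance to the origin is at most max_extent" with the u at 0°, v at 60° convention. `max_extent_index_sketch` is a hypothetical name.

```python
import numpy as np


def max_extent_index_sketch(u, v, max_extent: int) -> np.ndarray:
    """Hypothetical sketch: True where (u, v) lies within `max_extent` steps of the origin."""
    u, v = np.asarray(u), np.asarray(v)
    # Hex distance to the origin in axial coordinates is max(|u|, |v|, |u + v|)
    return np.maximum(np.maximum(np.abs(u), np.abs(v)), np.abs(u + v)) <= max_extent


mask = max_extent_index_sketch([0, 1, 2, 1, -3], [0, 1, -1, -3, 3], max_extent=2)
```

Applying the mask as `u[mask], v[mask]` then keeps only the hexals inside the regular lattice of radius `max_extent`.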
flyvision.utils.hex_utils.get_num_hexals
get_num_hexals(extent)
Returns the total number of hexals in a hexagonal grid with extent.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
extent | int | Extent of hex-lattice. | required |
Returns:
Type | Description |
---|---|
int | Number of hexals. |
Note
Inverse of get_hextent.
Source code in flyvision/utils/hex_utils.py
flyvision.utils.hex_utils.get_hextent
get_hextent(num_hexals)
Computes the hex-lattice extent from the number of hexals.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
num_hexals | int | Number of hexals. | required |
Returns:
Type | Description |
---|---|
int | Extent of hex-lattice. |
Note
Inverse of get_num_hexals.
Source code in flyvision/utils/hex_utils.py
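The two inverse functions follow from counting rings: a regular hex lattice of extent e has one centre hexal plus 6k hexals on ring k, so N = 1 + 6(1 + 2 + … + e) = 3e(e + 1) + 1, and solving the quadratic gives e = (√(12N − 3) − 3) / 6. The sketch below uses integer arithmetic; the names are hypothetical, and rejecting counts that do not form a regular lattice is an assumption about desired behaviour.

```python
import math


def get_num_hexals_sketch(extent: int) -> int:
    """1 centre hexal plus 6 * k hexals on each ring k = 1..extent."""
    return 1 + 3 * extent * (extent + 1)


def get_hextent_sketch(num_hexals: int) -> int:
    """Invert N = 3e(e + 1) + 1, i.e. e = (sqrt(12N - 3) - 3) / 6."""
    root = math.isqrt(12 * num_hexals - 3)  # exact integer square root
    extent = (root - 3) // 6
    if get_num_hexals_sketch(extent) != num_hexals:
        raise ValueError(f"{num_hexals} is not the size of a regular hex lattice")
    return extent
```

For example, an extent of 15 gives 721 hexals, and 721 maps back to 15.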
flyvision.utils.hex_utils.sort_u_then_v
sort_u_then_v(u, v, values)
Sorts u, v, and values by u and then v.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
u | NDArray | U-coordinate of hexal. | required |
v | NDArray | V-coordinate of hexal. | required |
values | NDArray | Value of hexal. | required |
Returns:
Type | Description |
---|---|
Tuple[NDArray, NDArray, NDArray] | A tuple containing: u: Sorted u-coordinate of hexal. v: Sorted v-coordinate of hexal. values: Sorted value of hexal. |
Source code in flyvision/utils/hex_utils.py
flyvision.utils.hex_utils.sort_u_then_v_index
sort_u_then_v_index(u, v)
Index to sort u, v by u and then v.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
u | NDArray | U-coordinate of hexal. | required |
v | NDArray | V-coordinate of hexal. | required |
Returns:
Type | Description |
---|---|
NDArray | Index to sort u and v. |
Source code in flyvision/utils/hex_utils.py
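A u-then-v ordering is a lexicographic sort, which NumPy provides as `np.lexsort`. The sketch assumes that is an acceptable stand-in for the library's index; note that `np.lexsort` treats the last key as the primary one, so v is passed before u. Applying the same index along the last axis of `values` reproduces what sort_u_then_v returns.

```python
import numpy as np


def sort_u_then_v_index_sketch(u, v) -> np.ndarray:
    """Hypothetical sketch: indices ordering hexals by u, breaking ties by v."""
    # np.lexsort sorts by the LAST key first, so u is the primary key here
    return np.lexsort((np.asarray(v), np.asarray(u)))


u = np.array([1, 0, 1, 0])
v = np.array([0, 1, -1, -1])
values = np.array([10.0, 20.0, 30.0, 40.0])
index = sort_u_then_v_index_sketch(u, v)
u_sorted, v_sorted, values_sorted = u[index], v[index], values[..., index]
```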
flyvision.utils.hex_utils.get_extent
get_extent(u, v, astype=int)
Returns extent (integer distance to origin) of arbitrary u, v coordinates.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
u | NDArray | U-coordinate of hexal. | required |
v | NDArray | V-coordinate of hexal. | required |
astype | type | Type to cast to. | int |
Returns:
Type | Description |
---|---|
int | Extent of hex-lattice. |
Note
If u and v are arrays, returns the maximum extent.
Source code in flyvision/utils/hex_utils.py
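The extent of arbitrary coordinates is the largest hex distance to the origin among them, which matches the note that arrays return the maximum extent. This sketch assumes the same axial distance as above; `get_extent_sketch` is hypothetical.

```python
import numpy as np


def get_extent_sketch(u, v, astype=int):
    """Hypothetical sketch: largest hex distance to the origin among the given coordinates."""
    u, v = np.asarray(u), np.asarray(v)
    # Per-hexal distance max(|u|, |v|, |u + v|), then the maximum over all hexals
    return astype(np.max(np.maximum(np.maximum(np.abs(u), np.abs(v)), np.abs(u + v))))
```

Scalars work too, since `np.asarray` turns them into 0-d arrays.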
flyvision.utils.log_utils
Classes
flyvision.utils.log_utils.Status (dataclass)
Status object from log files of model runs.
Attributes:
Name | Type | Description |
---|---|---|
ensemble_name | str | Name of the ensemble. |
log_files | List[Path] | List of all log files. |
train_logs | List[Path] | List of train logs. |
model_id_to_train_log_file | Dict[str, Path] | Mapping of model ID to log file. |
status | Dict[str, str] | Mapping of model ID to status. |
user_input | Dict[str, str] | Mapping of model ID to user input (behind LSF command). |
hosts | Dict[str, List[str]] | Mapping of model ID to host. |
rerun_failed_runs | Dict[str, List[str]] | Formatted submission commands to restart failed models. |
lsf_part | str | LSF command part. |
Source code in flyvision/utils/log_utils.py
print_for_rerun
print_for_rerun(exclude_failed_hosts=True, model_ids=None)
Print formatted submission commands to restart failed models.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
exclude_failed_hosts | bool | Whether to exclude failed hosts. | True |
model_ids | List[str] | List of model IDs to rerun. If None, all failed models are included. | None |
Source code in flyvision/utils/log_utils.py
get_hosts
get_hosts()
Get hosts on which the job was executed.
Source code in flyvision/utils/log_utils.py
bad_hosts
bad_hosts()
Get hosts on which the job failed.
Source code in flyvision/utils/log_utils.py
successful_runs
successful_runs()
Get number of successful runs.
Source code in flyvision/utils/log_utils.py
running_runs
running_runs()
Get number of running runs.
Source code in flyvision/utils/log_utils.py
failed_runs
failed_runs()
Get number of failed runs.
Source code in flyvision/utils/log_utils.py
successful_model_ids
successful_model_ids()
Get model IDs of successful runs.
Source code in flyvision/utils/log_utils.py
running_model_ids
running_model_ids()
Get model IDs of running runs.
Source code in flyvision/utils/log_utils.py
failed_model_ids
failed_model_ids()
Get model IDs of failed runs.
Source code in flyvision/utils/log_utils.py
lookup_log
lookup_log(
model_id, log_type="train_single", last_n_lines=20
)
Look up the log for a model ID.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_id | str | ID of the model. | required |
log_type | str | Type of log to look up. | 'train_single' |
last_n_lines | int | Number of lines to return from the end of the log. | 20 |
Returns:
Type | Description |
---|---|
List[str] | List of log lines. |
Source code in flyvision/utils/log_utils.py
extract_error_trace
extract_error_trace(
model_id,
check_last_n_lines=100,
log_type="train_single",
)
Extract the Python error message and traceback from the log of a given model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_id | str | ID of the model. | required |
check_last_n_lines | int | Number of lines to check from the end of the log. | 100 |
log_type | str | Type of log to extract error from. | 'train_single' |
Returns:
Type | Description |
---|---|
str | Extracted error message and traceback, or a message if no error is found. |
Source code in flyvision/utils/log_utils.py
extract_error_type
extract_error_type(
model_id,
log_type="train_single",
check_last_n_lines=100,
)
Extract the type of error from the log of a given model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_id | str | ID of the model. | required |
log_type | str | Type of log to extract error from. | 'train_single' |
check_last_n_lines | int | Number of lines to check from the end of the log. | 100 |
Returns:
Type | Description |
---|---|
str | Extracted error type, or a message if no specific error type is found. |
Source code in flyvision/utils/log_utils.py
all_errors
all_errors(check_last_n_lines=100, log_type='train_single')
Get all unique errors from failed runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
check_last_n_lines | int | Number of lines to check from the end of the log. | 100 |
log_type | str | Type of log to extract errors from. | 'train_single' |
Returns:
Type | Description |
---|---|
set | Set of unique error messages. |
Source code in flyvision/utils/log_utils.py
all_error_types
all_error_types(log_type='train_single')
Get all unique error types from failed runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
log_type | str | Type of log to extract error types from. | 'train_single' |
Returns:
Type | Description |
---|---|
set | Set of unique error types. |
Source code in flyvision/utils/log_utils.py
print_all_errors
print_all_errors(
check_last_n_lines=100, log_type="train_single"
)
Print all errors and tracebacks from failed runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
check_last_n_lines | int | Number of lines to check from the end of the log. | 100 |
log_type | str | Type of log to extract errors from. | 'train_single' |
Source code in flyvision/utils/log_utils.py
__getitem__
__getitem__(key)
Get status for a specific model ID.
Source code in flyvision/utils/log_utils.py
__repr__
__repr__()
Return a string representation of the Status object.
Source code in flyvision/utils/log_utils.py
Functions
flyvision.utils.log_utils.find_host
find_host(log_string)
Find the host(s) on which the job was executed.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
log_string | str | The log string to search for host information. | required |
Returns:
Type | Description |
---|---|
List[str] | List of host names. |
Source code in flyvision/utils/log_utils.py
flyvision.utils.log_utils.get_exclude_host_part
get_exclude_host_part(log_string, exclude_hosts)
Get the part of the LSF command that excludes hosts.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
log_string | str | The log string to search for host information. | required |
exclude_hosts | Union[str, List[str]] | Host(s) to exclude. Can be ‘auto’, a single host name, or a list of host names. | required |
Returns:
Type | Description |
---|---|
str
|
The LSF command part for excluding hosts. |
Source code in flyvision/utils/log_utils.py
flyvision.utils.log_utils.get_status ¶
get_status(
ensemble_name,
nP=4,
gpu="num=1",
queue="gpu_l4",
exclude_hosts="auto",
)
Get a Status object for an ensemble of models, with the job submission options below used to format rerun commands.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ensemble_name |
str
|
Ensemble name (e.g. “flow/ |
required |
nP |
int
|
Number of processors. |
4
|
gpu |
str
|
GPU request string, e.g. 'num=1' for a single GPU. |
'num=1'
|
queue |
str
|
Queue name. |
'gpu_l4'
|
exclude_hosts |
Union[str, List[str]]
|
Host(s) to exclude. Can be ‘auto’, a single host name, or a list of host names. |
'auto'
|
Returns:
Type | Description |
---|---|
Status
|
Status object containing information about the ensemble runs. |
Source code in flyvision/utils/log_utils.py
flyvision.utils.log_utils.flatten_list ¶
flatten_list(nested_list)
Flatten a nested list of lists into a single list with all elements.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
nested_list |
List
|
A nested list of lists to be flattened. |
required |
Returns:
Type | Description |
---|---|
List
|
A single flattened list with all elements. |
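Example (sketch)
The flattening described above can be illustrated with a small recursive function. This is a minimal sketch, not flyvision's code. It assumes sub-lists may be nested to any depth, whereas the library function may only flatten one level. The name flatten_list_sketch is hypothetical.

```python
from typing import Any, List


def flatten_list_sketch(nested_list: List[Any]) -> List[Any]:
    """Recursively flatten arbitrarily nested lists (hypothetical sketch)."""
    flat: List[Any] = []
    for item in nested_list:
        if isinstance(item, list):
            # Descend into sub-lists and keep their elements in order.
            flat.extend(flatten_list_sketch(item))
        else:
            flat.append(item)
    return flat


print(flatten_list_sketch([[1, 2], [3, [4, 5]], 6]))  # [1, 2, 3, 4, 5, 6]
```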
Source code in flyvision/utils/log_utils.py
flyvision.utils.logging_utils¶
Functions¶
flyvision.utils.logging_utils.warn_once (cached) ¶
warn_once(logger, msg)
Log a warning message only once for a given logger and message combination.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logger |
Logger
|
The logger object to use for logging. |
required |
msg |
str
|
The warning message to log. |
required |
Note
This function uses an LRU cache to ensure each unique combination of logger and message is only logged once.
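Example (sketch)
One way to get the once-per-combination behavior described in the Note is functools.lru_cache. Because the cache key is the (logger, msg) pair, the function body only runs on the first call for each pair. This is a hedged sketch, and warn_once_sketch is a hypothetical name, not necessarily how flyvision implements it.

```python
import functools
import logging


@functools.lru_cache(maxsize=None)
def warn_once_sketch(logger: logging.Logger, msg: str) -> None:
    # lru_cache memoizes on (logger, msg); repeated calls hit the cache
    # and never reach logger.warning again.
    logger.warning(msg)


demo_logger = logging.getLogger("demo")
warn_once_sketch(demo_logger, "deprecated option")  # logged
warn_once_sketch(demo_logger, "deprecated option")  # suppressed by the cache
```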
Source code in flyvision/utils/logging_utils.py
flyvision.utils.logging_utils.save_conda_environment ¶
save_conda_environment(path)
Save the current Conda environment to a JSON file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Path
|
The path where the JSON file will be saved. |
required |
Note
The function appends ‘.json’ to the provided path.
Source code in flyvision/utils/logging_utils.py
flyvision.utils.logging_utils.all_logging_disabled ¶
all_logging_disabled(highest_level=logging.CRITICAL)
A context manager that prevents any logging messages from being processed.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
highest_level |
int
|
The maximum logging level to disable. Only needs to be changed if a custom level greater than CRITICAL is defined. |
CRITICAL
|
Example
import logging
from flyvision.utils.logging_utils import all_logging_disabled
with all_logging_disabled():
    # Code here will not produce any log output
    logging.warning("This warning will not be logged")
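Example (sketch)
A context manager like this is usually built on the standard library's logging.disable. That call suppresses every record at or below a level, process-wide. The sketch below assumes flyvision follows a similar approach. all_logging_disabled_sketch is a hypothetical name.

```python
import logging
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def all_logging_disabled_sketch(highest_level: int = logging.CRITICAL) -> Iterator[None]:
    # Remember the current global disable threshold so it can be restored.
    previous_level = logging.root.manager.disable
    # Suppress all records at or below highest_level for every logger.
    logging.disable(highest_level)
    try:
        yield
    finally:
        # Restore the previous threshold even if the body raised.
        logging.disable(previous_level)
```

Restoring the previous threshold in a finally clause keeps nested or failing blocks from leaving logging permanently disabled.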
Source code in flyvision/utils/logging_utils.py
flyvision.utils.nn_utils¶
Classes¶
flyvision.utils.nn_utils.NumberOfParams (dataclass) ¶
Dataclass to store the number of free and fixed parameters.
Attributes:
Name | Type | Description |
---|---|---|
free |
int
|
The number of trainable parameters. |
fixed |
int
|
The number of non-trainable parameters. |
Source code in flyvision/utils/nn_utils.py
Functions¶
flyvision.utils.nn_utils.simulation ¶
simulation(network)
Context manager that turns off training mode and requires_grad for a network.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
network |
Module
|
The neural network module to simulate. |
required |
Yields:
Type | Description |
---|---|
None
|
None |
Example
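import torch
from flyvision.utils.nn_utils import simulation
model = torch.nn.Linear(4, 2)  # any torch.nn.Module
input_data = torch.randn(1, 4)
with simulation(model):
    # Perform inference or evaluation
    output = model(input_data)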
Note
This context manager temporarily disables gradient computation and sets the network to evaluation mode. It restores the original state after exiting the context.
Source code in flyvision/utils/nn_utils.py
flyvision.utils.nn_utils.n_params ¶
n_params(nnmodule)
Returns the numbers of free and fixed parameters in a PyTorch module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
nnmodule |
Module
|
The PyTorch module to analyze. |
required |
Returns:
Type | Description |
---|---|
NumberOfParams
|
A NumberOfParams object containing the count of free and fixed parameters. |
Example
import torch
from flyvision.utils.nn_utils import n_params
model = torch.nn.Linear(4, 2)  # any torch.nn.Module
param_count = n_params(model)
print(f"Free parameters: {param_count.free}")
print(f"Fixed parameters: {param_count.fixed}")
Source code in flyvision/utils/nn_utils.py
flyvision.utils.nodes_edges_utils¶
Classes¶
flyvision.utils.nodes_edges_utils.NodeIndexer ¶
Bases: dict
Attribute-style accessible map from cell types to indices.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
connectome |
Optional[ConnectomeFromAvgFilters]
|
Connectome object. The cell types are taken from the connectome and references are created in order. |
None
|
unique_cell_types |
Optional[NDArray[str]]
|
Array of unique cell types. Optional; if given, it defines the mapping from cell types to indices in the provided order. |
None
|
Attributes:
Name | Type | Description |
---|---|---|
unique_cell_types |
NDArray[str]
|
Array of unique cell types. |
central_cells_index |
Optional[NDArray[int]]
|
Array of indices of central cells. |
Raises:
Type | Description |
---|---|
ValueError
|
If neither connectome nor unique_cell_types is provided. |
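Example (sketch)
Attribute-style access of this kind is commonly implemented as a dict subclass that overrides __getattr__. The class below is a hypothetical simplification. It accepts only unique_cell_types (no connectome) and maps each type to its position in that list. flyvision's NodeIndexer may index differently, for example via central cell indices.

```python
from typing import Optional, Sequence


class NodeIndexerSketch(dict):
    """Hypothetical simplification: cell type name -> integer index."""

    def __init__(self, unique_cell_types: Optional[Sequence[str]] = None):
        if unique_cell_types is None:
            # Mirrors the documented ValueError when no cell types are available.
            raise ValueError("unique_cell_types must be provided")
        # Indices follow the order in which the cell types are given.
        super().__init__({ct: i for i, ct in enumerate(unique_cell_types)})

    def __getattr__(self, name: str) -> int:
        # Only called when normal attribute lookup fails; fall back to the dict.
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err


indexer = NodeIndexerSketch(["R1", "L1", "T4a"])
print(indexer.T4a)  # 2
```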
Source code in flyvision/utils/nodes_edges_utils.py
flyvision.utils.nodes_edges_utils.CellTypeArray ¶
Attribute-style accessible map from cell types to coordinates in array.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
array |
Union[NDArray, Tensor]
|
Array whose dim-th axis corresponds to the unique cell types of the connectome or to the provided cell_types. |
required |
connectome |
Optional[ConnectomeFromAvgFilters]
|
Connectome object. |
None
|
cell_types |
Optional[NDArray[str]]
|
Array of cell types. |
None
|
dim |
int
|
Axis corresponding to unique cell types. |
-1
|
Attributes:
Name | Type | Description |
---|---|---|
node_indexer |
NodeIndexer
|
Indexer for cell types. |
array |
NDArray
|
The array of cell type data. |
dim |
int
|
Dimension corresponding to cell types. |
cell_types |
NDArray[str]
|
Array of unique cell types. |
Source code in flyvision/utils/nodes_edges_utils.py
Functions¶
flyvision.utils.nodes_edges_utils.order_node_type_list ¶
order_node_type_list(
node_types,
groups=[
"R\\d",
"L\\d",
"Lawf\\d",
"A",
"C\\d",
"CT\\d.*",
"Mi\\d{1,2}",
"T\\d{1,2}.*",
"Tm.*\\d{1,2}.*",
],
)
Orders a list of node types by the regular expressions defined in groups.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
node_types |
List[str]
|
Unordered list of node types. |
required |
groups |
List[str]
|
Ordered list of regular expressions to sort node_types. |
['R\\d', 'L\\d', 'Lawf\\d', 'A', 'C\\d', 'CT\\d.*', 'Mi\\d{1,2}', 'T\\d{1,2}.*', 'Tm.*\\d{1,2}.*']
|
Returns:
Type | Description |
---|---|
Tuple[List[str], List[int]]
|
A tuple containing the ordered list of node types and the indices of the ordered node types in node_types. |
Raises:
Type | Description |
---|---|
AssertionError
|
If sorting doesn’t include all cell types. |
ValueError
|
If sorting fails due to length mismatch. |
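Example (sketch)
The group-wise ordering can be sketched with the standard re module. This sketch rests on three assumptions about flyvision's implementation:
- Patterns are applied with re.match, so they are anchored at the start of the name.
- Each node type goes to the first group that matches it.
- Input order is kept within a group.

Unmatched types raise ValueError here instead of the documented AssertionError. order_by_groups_sketch is a hypothetical name.

```python
import re
from typing import List, Sequence, Tuple

# Same patterns as the documented default `groups`.
DEFAULT_GROUPS = (
    r"R\d", r"L\d", r"Lawf\d", r"A", r"C\d", r"CT\d.*",
    r"Mi\d{1,2}", r"T\d{1,2}.*", r"Tm.*\d{1,2}.*",
)


def order_by_groups_sketch(
    node_types: Sequence[str], groups: Sequence[str] = DEFAULT_GROUPS
) -> Tuple[List[str], List[int]]:
    """Order node types group by group; input order is kept within a group."""
    order: List[int] = []
    assigned = set()
    for pattern in groups:
        for i, node_type in enumerate(node_types):
            # re.match anchors at the start; the first matching group wins.
            if i not in assigned and re.match(pattern, node_type):
                order.append(i)
                assigned.add(i)
    if len(order) != len(node_types):
        raise ValueError("some node types match none of the groups")
    return [node_types[i] for i in order], order


print(order_by_groups_sketch(["T4a", "R1", "Mi1", "L2", "Tm3"]))
# (['R1', 'L2', 'Mi1', 'T4a', 'Tm3'], [1, 3, 2, 0, 4])
```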
Source code in flyvision/utils/nodes_edges_utils.py
flyvision.utils.nodes_edges_utils.get_index_mapping_lists ¶
get_index_mapping_lists(from_list, to_list)
Get indices to sort and filter from_list by occurrence of items in to_list.
The indices can be used to reorder or filter another list or tensor whose entries correspond one-to-one to the items of from_list, so that it follows the order of items in to_list.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
from_list |
List[str]
|
Original list of items. |
required |
to_list |
List[str]
|
Target list of items. |
required |
Returns:
Type | Description |
---|---|
List[int]
|
List of indices for sorting. |
Example
from flyvision.utils.nodes_edges_utils import get_index_mapping_lists
from_list = ["a", "b", "c"]
mapping_to_from_list = [1, 2, 3]
to_list = ["c", "a", "b"]
sort_index = get_index_mapping_lists(from_list, to_list)
sorted_list = [mapping_to_from_list[i] for i in sort_index]
# sorted_list will be [3, 1, 2]
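Example (sketch)
The index computation itself amounts to a position lookup. This is a minimal sketch under two assumptions: the first occurrence of a duplicated item in from_list wins, and items of to_list that are absent from from_list are dropped. index_mapping_sketch is a hypothetical name, not the library function.

```python
from typing import Dict, List, Sequence


def index_mapping_sketch(from_list: Sequence[str], to_list: Sequence[str]) -> List[int]:
    # Position of every item in from_list (first occurrence wins).
    position: Dict[str, int] = {}
    for i, item in enumerate(from_list):
        position.setdefault(item, i)
    # Walk to_list in order; items missing from from_list are filtered out.
    return [position[item] for item in to_list if item in position]


print(index_mapping_sketch(["a", "b", "c"], ["c", "a", "b"]))  # [2, 0, 1]
```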
Source code in flyvision/utils/nodes_edges_utils.py
flyvision.utils.nodes_edges_utils.sort_by_mapping_lists ¶
sort_by_mapping_lists(from_list, to_list, tensor, axis=0)
Sort and filter a tensor along an axis indexed by from_list to match to_list.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
from_list |
List[str]
|
Original list of items. |
required |
to_list |
List[str]
|
Target list of items. |
required |
tensor |
Union[ndarray, Tensor]
|
Tensor to be sorted. |
required |
axis |
int
|
Axis along which to sort the tensor. |
0
|
Returns:
Type | Description |
---|---|
ndarray
|
Sorted numpy array. |
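Example (sketch)
Sorting and filtering along an axis can be expressed with numpy.take once the index mapping is known. This is a hedged NumPy sketch, and sort_by_mapping_sketch is a hypothetical name. It assumes items of to_list that are missing from from_list are skipped.

```python
from typing import Sequence

import numpy as np


def sort_by_mapping_sketch(
    from_list: Sequence[str], to_list: Sequence[str], tensor, axis: int = 0
) -> np.ndarray:
    position = {item: i for i, item in enumerate(from_list)}
    # Indices into from_list, in to_list order; unknown items are skipped.
    indices = [position[item] for item in to_list if item in position]
    # np.take gathers the selected slices along `axis`.
    return np.take(np.asarray(tensor), indices, axis=axis)


values = np.array([[10, 20, 30], [11, 21, 31]])  # columns follow ["a", "b", "c"]
print(sort_by_mapping_sketch(["a", "b", "c"], ["c", "a"], values, axis=1).tolist())
# [[30, 10], [31, 11]]
```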
Source code in flyvision/utils/nodes_edges_utils.py
flyvision.utils.nodes_edges_utils.nodes_list_sorting_on_off_unknown ¶
nodes_list_sorting_on_off_unknown(cell_types=None)
Sort node list based on on/off/unknown polarity.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cell_types |
Optional[List[str]]
|
List of cell types to sort. If None, uses all types from groundtruth_utils.polarity. |
None
|
Returns:
Type | Description |
---|---|
List[str]
|
Sorted list of cell types. |
Source code in flyvision/utils/nodes_edges_utils.py
flyvision.utils.tensor_utils¶
Classes¶
flyvision.utils.tensor_utils.RefTensor ¶
A tensor with reference indices along the last dimension.
Attributes:
Name | Type | Description |
---|---|---|
values |
Tensor
|
The tensor values. |
indices |
Tensor
|
The reference indices. |
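Example (sketch)
The core idea behind RefTensor is compact values plus indices that expand them along the last dimension. It can be illustrated with a NumPy dataclass. This is a hypothetical analogue named RefArraySketch, not flyvision's torch class. The example assumes, for illustration, one value per cell type shared by all cells of that type.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class RefArraySketch:
    """Hypothetical NumPy analogue: compact values plus reference indices."""
    values: np.ndarray
    indices: np.ndarray

    def deref(self) -> np.ndarray:
        # Expand the compact values by indexing the last dimension.
        return self.values[..., self.indices]

    def clone(self) -> "RefArraySketch":
        # Copy the values; the indices are shared, as they are not modified.
        return RefArraySketch(self.values.copy(), self.indices)


# e.g. two shared values referenced by four elements
ref = RefArraySketch(values=np.array([0.5, -1.0]), indices=np.array([0, 0, 1, 0]))
print(ref.deref().tolist())  # [0.5, 0.5, -1.0, 0.5]
```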
Source code in flyvision/utils/tensor_utils.py
deref ¶
deref()
Index the values with the given indices in the last dimension.
Source code in flyvision/utils/tensor_utils.py
clone ¶
clone()
Return a copy of the RefTensor with cloned values.
Source code in flyvision/utils/tensor_utils.py
detach ¶
detach()
Return a copy of the RefTensor with detached values.
Source code in flyvision/utils/tensor_utils.py
flyvision.utils.tensor_utils.AutoDeref ¶
Bases: dict
An auto-dereferencing namespace.
Dereferencing means that if an attribute is a RefTensor, __getitem__ calls RefTensor.deref() to obtain the values at the given indices.
Note
Constructed at each forward call of the Network. A cache speeds up processing, e.g., when a parameter is referenced multiple times in the dynamics.
Attributes:
Name | Type | Description |
---|---|---|
_cache |
Dict[str, object]
|
Cache for dereferenced values. |
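Example (sketch)
An auto-dereferencing namespace with a cache can be sketched as a dict subclass that dereferences on lookup and memoizes the result. Both classes below are hypothetical, self-contained stand-ins: Ref plays the role of RefTensor, and NumPy replaces torch. The cache is not invalidated when entries change, which matches the idea of building a fresh instance per forward call.

```python
from typing import Any, Dict

import numpy as np


class Ref:
    """Minimal hypothetical stand-in for RefTensor."""

    def __init__(self, values: np.ndarray, indices: np.ndarray):
        self.values, self.indices = values, indices

    def deref(self) -> np.ndarray:
        return self.values[..., self.indices]


class AutoDerefSketch(dict):
    """Dict that dereferences Ref entries on lookup and caches the result."""

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self._cache: Dict[str, Any] = {}

    def __getitem__(self, key: str) -> Any:
        if key in self._cache:
            return self._cache[key]  # repeated references skip re-indexing
        value = super().__getitem__(key)
        if isinstance(value, Ref):
            value = value.deref()
        self._cache[key] = value
        return value

    def __getattr__(self, key: str) -> Any:
        # Attribute-style access falls back to (dereferencing) item lookup.
        try:
            return self[key]
        except KeyError as err:
            raise AttributeError(key) from err


params = AutoDerefSketch(bias=Ref(np.array([1.0, 2.0]), np.array([1, 1, 0])), tau=0.1)
print(params.bias.tolist())  # [2.0, 2.0, 1.0]
```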
Source code in flyvision/utils/tensor_utils.py
get_as_reftensor ¶
get_as_reftensor(key)
Get the original RefTensor without dereferencing.
Source code in flyvision/utils/tensor_utils.py
clear_cache ¶
clear_cache()
Clear the cache and return a cloned instance.
Source code in flyvision/utils/tensor_utils.py
detach ¶
detach()
Return a detached copy of the AutoDeref instance.
Source code in flyvision/utils/tensor_utils.py
Functions¶
flyvision.utils.tensor_utils.detach ¶
detach(obj)
Recursively detach AutoDeref mappings.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obj |
AutoDeref
|
The object to detach. |
required |
Returns:
Type | Description |
---|---|
AutoDeref
|
A detached copy of the input object. |
Raises:
Type | Description |
---|---|
TypeError
|
If the object type is not supported. |
Source code in flyvision/utils/tensor_utils.py
flyvision.utils.tensor_utils.clone ¶
clone(obj)
Recursively clone AutoDeref mappings.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obj |
AutoDeref
|
The object to clone. |
required |
Returns:
Type | Description |
---|---|
AutoDeref
|
A cloned copy of the input object. |
Raises:
Type | Description |
---|---|
TypeError
|
If the object type is not supported. |
Source code in flyvision/utils/tensor_utils.py
flyvision.utils.tensor_utils.to_numpy ¶
to_numpy(array)
Convert array-like to numpy array.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
array |
Union[ndarray, Tensor, List]
|
The input array-like object. |
required |
Returns:
Type | Description |
---|---|
ndarray
|
A numpy array. |
Raises:
Type | Description |
---|---|
ValueError
|
If the input type cannot be cast to a numpy array. |
Source code in flyvision/utils/tensor_utils.py
flyvision.utils.tensor_utils.atleast_column_vector ¶
atleast_column_vector(array)
Convert a 1D array-like to an n x 1 column vector, or return the original if it is already 2D or higher.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
array |
ndarray
|
The input array. |
required |
Returns:
Type | Description |
---|---|
ndarray
|
A column vector or the original array if it’s already 2D or higher. |
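Example (sketch)
The behavior in the Returns table fits in a few lines of NumPy. This is a hedged sketch, and atleast_column_vector_sketch is a hypothetical name.

```python
import numpy as np


def atleast_column_vector_sketch(array) -> np.ndarray:
    array = np.asarray(array)
    if array.ndim == 1:
        return array.reshape(-1, 1)  # shape (n,) becomes (n, 1)
    return array  # 2D or higher is returned unchanged


print(atleast_column_vector_sketch([1, 2, 3]).shape)  # (3, 1)
print(atleast_column_vector_sketch(np.ones((2, 4))).shape)  # (2, 4)
```

NumPy's own np.atleast_2d differs here: it turns a 1D input into a 1 x n row vector rather than a column vector.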
Source code in flyvision/utils/tensor_utils.py
flyvision.utils.tensor_utils.matrix_mask_by_sub ¶
matrix_mask_by_sub(sub_matrix, matrix)
Create a mask of rows in matrix that are contained in sub_matrix.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sub_matrix |
ndarray
|
Shape (n_rows1, n_columns) |
required |
matrix |
ndarray
|
Shape (n_rows2, n_columns) |
required |
Returns:
Type | Description |
---|---|
NDArray[bool]
|
1D boolean array of length n_rows2 |
Note
n_rows1 must not exceed n_rows2 (n_rows1 <= n_rows2).
Example
import numpy as np
from flyvision.utils.tensor_utils import matrix_mask_by_sub
sub_matrix = np.array([[1, 2, 3],
                       [4, 3, 1]])
matrix = np.array([[3, 4, 1],
                   [4, 3, 1],
                   [1, 2, 3]])
matrix_mask_by_sub(sub_matrix, matrix)
# array([False, True, True])
Indexing a tensor with integer indices is typically faster than boolean masking; see also where_equal_rows.
Source code in flyvision/utils/tensor_utils.py
flyvision.utils.tensor_utils.where_equal_rows ¶
where_equal_rows(
matrix1, matrix2, as_mask=False, astype="|S64"
)
Find the indices of the rows in matrix2 that equal the rows of matrix1.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
matrix1 |
ndarray
|
First input matrix. |
required |
matrix2 |
ndarray
|
Second input matrix. |
required |
as_mask |
bool
|
If True, return a boolean mask instead of indices. |
False
|
astype |
str
|
Data type to use for comparison. |
'|S64'
|
Returns:
Type | Description |
---|---|
NDArray[int]
|
Array of indices or boolean mask. |
Example
import numpy as np
from flyvision.utils.tensor_utils import where_equal_rows
matrix1 = np.array([[1, 2, 3],
                    [4, 3, 1]])
matrix2 = np.array([[3, 4, 1],
                    [4, 3, 1],
                    [1, 2, 3],
                    [0, 0, 0]])
where_equal_rows(matrix1, matrix2)
# array([2, 1])
matrix2[where_equal_rows(matrix1, matrix2)]
# array([[1, 2, 3],
#        [4, 3, 1]])
See also
matrix_mask_by_sub
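Example (sketch)
The row lookup can also be illustrated with a dictionary keyed on row tuples. This hashes rows directly instead of casting them to the '|S64' byte-string dtype that the signature suggests. The as_mask option is omitted, and where_equal_rows_sketch is a hypothetical name.

```python
import numpy as np


def where_equal_rows_sketch(matrix1: np.ndarray, matrix2: np.ndarray) -> np.ndarray:
    # Map each row of matrix2 (as a hashable tuple) to its first row index.
    row_to_index = {}
    for i, row in enumerate(map(tuple, matrix2.tolist())):
        row_to_index.setdefault(row, i)
    # For each row of matrix1 found in matrix2, report where it occurs.
    return np.array(
        [row_to_index[row] for row in map(tuple, matrix1.tolist()) if row in row_to_index],
        dtype=int,
    )


m1 = np.array([[1, 2, 3], [4, 3, 1]])
m2 = np.array([[3, 4, 1], [4, 3, 1], [1, 2, 3], [0, 0, 0]])
print(where_equal_rows_sketch(m1, m2).tolist())  # [2, 1]
```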
Source code in flyvision/utils/tensor_utils.py
flyvision.utils.tensor_utils.broadcast ¶
broadcast(src, other, dim)
Broadcast src to the shape of other along dimension dim.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Tensor
|
Source tensor to broadcast. |
required |
other |
Tensor
|
Target tensor to broadcast to. |
required |
dim |
int
|
Dimension along which to broadcast. |
required |
Returns:
Type | Description |
---|---|
Tensor
|
Broadcasted tensor. |
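Example (sketch)
The helper of the same name in pytorch_scatter works in three steps:
1. Prepend singleton axes to a 1D src.
2. Append trailing singleton axes.
3. Expand to the target shape.

The NumPy translation below follows that recipe. It assumes flyvision's torch version behaves the same way. broadcast_sketch is a hypothetical name.

```python
import numpy as np


def broadcast_sketch(src: np.ndarray, other: np.ndarray, dim: int) -> np.ndarray:
    if dim < 0:
        dim = other.ndim + dim
    if src.ndim == 1:
        # Prepend singleton axes so src's only axis lines up with `dim`.
        for _ in range(dim):
            src = src[np.newaxis]
    # Append trailing singleton axes until src has as many dims as other.
    for _ in range(src.ndim, other.ndim):
        src = src[..., np.newaxis]
    return np.broadcast_to(src, other.shape)


index = np.array([0, 1, 0])
print(broadcast_sketch(index, np.zeros((2, 3)), dim=-1).tolist())
# [[0, 1, 0], [0, 1, 0]]
```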
Source code in flyvision/utils/tensor_utils.py
flyvision.utils.tensor_utils.scatter_reduce ¶
scatter_reduce(src, index, dim=-1, mode='mean')
Reduce along dimension dim using values in the index tensor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Tensor
|
Source tensor. |
required |
index |
Tensor
|
Index tensor. |
required |
dim |
int
|
Dimension along which to reduce. |
-1
|
mode |
Literal['mean', 'sum']
|
Reduction mode, either “mean” or “sum”. |
'mean'
|
Returns:
Type | Description |
---|---|
Tensor
|
Reduced tensor. |
Note
Convenience function for torch.scatter_reduce that broadcasts index to the shape of src along dimension dim, to conform to the pytorch_scatter API.
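Example (sketch)
The semantics of the "mean" and "sum" modes can be illustrated in NumPy with numpy.add.at, which accumulates repeated indices without buffering. This is an illustration, not flyvision's torch implementation, and it makes three assumptions:
- index is 1D along dim.
- The output has index.max() + 1 entries along dim.
- Empty groups produce 0.

scatter_reduce_sketch is a hypothetical name.

```python
import numpy as np


def scatter_reduce_sketch(src, index, dim: int = -1, mode: str = "mean") -> np.ndarray:
    """NumPy illustration: reduce src along `dim` into groups given by a 1D index."""
    src = np.moveaxis(np.asarray(src, dtype=float), dim, 0)  # reduced axis first
    index = np.asarray(index)
    n_groups = int(index.max()) + 1
    sums = np.zeros((n_groups,) + src.shape[1:])
    np.add.at(sums, index, src)  # unbuffered: repeated indices accumulate
    if mode == "mean":
        counts = np.bincount(index, minlength=n_groups).astype(float)
        counts = counts.reshape((n_groups,) + (1,) * (src.ndim - 1))
        # Empty groups stay 0 instead of dividing by zero.
        sums = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)
    elif mode != "sum":
        raise ValueError(f"unsupported mode: {mode}")
    return np.moveaxis(sums, 0, dim)


src = np.array([[1.0, 2.0, 3.0, 4.0]])
print(scatter_reduce_sketch(src, [0, 0, 1, 1]).tolist())              # [[1.5, 3.5]]
print(scatter_reduce_sketch(src, [0, 0, 1, 1], mode="sum").tolist())  # [[3.0, 7.0]]
```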
Source code in flyvision/utils/tensor_utils.py
flyvision.utils.tensor_utils.scatter_mean ¶
scatter_mean(src, index, dim=-1)
Average along dimension dim using values in the index tensor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Tensor
|
Source tensor. |
required |
index |
Tensor
|
Index tensor. |
required |
dim |
int
|
Dimension along which to average. |
-1
|
Returns:
Type | Description |
---|---|
Tensor
|
Averaged tensor. |
Source code in flyvision/utils/tensor_utils.py
flyvision.utils.tensor_utils.scatter_add ¶
scatter_add(src, index, dim=-1)
Sum along dimension dim using values in the index tensor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
src |
Tensor
|
Source tensor. |
required |
index |
Tensor
|
Index tensor. |
required |
dim |
int
|
Dimension along which to sum. |
-1
|
Returns:
Type | Description |
---|---|
Tensor
|
Summed tensor. |
Source code in flyvision/utils/tensor_utils.py
flyvision.utils.tensor_utils.select_along_axes ¶
select_along_axes(array, indices, dims)
Select indices from array along dims.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
array |
ndarray
|
Array to take indices from. |
required |
indices |
Union[int, Iterable[int]]
|
Indices to take. |
required |
dims |
Union[int, Iterable[int]]
|
Dimensions to take indices from. |
required |
Returns:
Type | Description |
---|---|
ndarray
|
Array with selected indices along specified dimensions. |
Source code in flyvision/utils/tensor_utils.py
flyvision.utils.tensor_utils.asymmetric_weighting ¶
asymmetric_weighting(tensor, gamma=1.0, delta=0.1)
Apply asymmetric weighting to the positive and negative elements of a tensor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tensor |
Union[NDArray, Tensor]
|
Input tensor. |
required |
gamma |
float
|
Weighting factor for positive elements. |
1.0
|
delta |
float
|
Weighting factor for non-positive elements. |
0.1
|
Returns:
Type | Description |
---|---|
Union[NDArray, Tensor]
|
Weighted tensor. |
Note
The function is defined as: f(x) = gamma * x if x > 0 else delta * x
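Example (sketch)
The piecewise definition in the Note maps directly onto numpy.where. This is a NumPy-only sketch; for torch tensors, torch.where offers the same three-argument form. asymmetric_weighting_sketch is a hypothetical name.

```python
import numpy as np


def asymmetric_weighting_sketch(x, gamma: float = 1.0, delta: float = 0.1) -> np.ndarray:
    x = np.asarray(x, dtype=float)
    # gamma scales entries with x > 0; delta scales the remaining entries (x <= 0).
    return np.where(x > 0, gamma * x, delta * x)


print(asymmetric_weighting_sketch([2.0, -1.0, 0.0]).tolist())  # [2.0, -0.1, 0.0]
```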
Source code in flyvision/utils/tensor_utils.py
flyvision.utils.type_utils¶
Functions¶
flyvision.utils.type_utils.byte_to_str ¶
byte_to_str(obj)
Cast byte elements to string types recursively.
This function recursively converts byte elements to string types in nested data structures.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
obj |
Any
|
The object to be processed. Can be of various types including Mapping, numpy.ndarray, list, tuple, bytes, str, or Number. |
required |
Returns:
Type | Description |
---|---|
Any
|
The input object with all byte elements converted to strings. |
Raises:
Type | Description |
---|---|
TypeError
|
If the input object cannot be cast to a string type. |
Note
This function will cast all byte elements in nested lists or tuples.
Examples:
>>> byte_to_str(b"hello")
'hello'
>>> byte_to_str([b"world", 42, {b"key": b"value"}])
['world', 42, {'key': 'value'}]
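Example (sketch)
The recursive conversion can be sketched with the standard library's collections.abc.Mapping and numbers.Number checks. This simplified version skips numpy.ndarray support, and it assumes UTF-8 decoding. byte_to_str_sketch is a hypothetical name.

```python
from collections.abc import Mapping
from numbers import Number
from typing import Any


def byte_to_str_sketch(obj: Any) -> Any:
    if isinstance(obj, bytes):
        return obj.decode()  # UTF-8 by default
    if isinstance(obj, (str, Number)):
        return obj
    if isinstance(obj, Mapping):
        # Keys are converted as well as values.
        return {byte_to_str_sketch(k): byte_to_str_sketch(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        converted = [byte_to_str_sketch(item) for item in obj]
        return converted if isinstance(obj, list) else tuple(converted)
    raise TypeError(f"cannot cast {type(obj).__name__} to a string type")


print(byte_to_str_sketch([b"world", 42, {b"key": b"value"}]))
# ['world', 42, {'key': 'value'}]
```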
Source code in flyvision/utils/type_utils.py
flyvision.utils.xarray_joblib_backend¶
Classes¶
flyvision.utils.xarray_joblib_backend.H5XArrayDatasetStoreBackend ¶
Bases: FileSystemStoreBackend
FileSystemStoreBackend subclass for handling xarray.Dataset objects.
This class uses xarray’s to_netcdf and open_dataset methods for Dataset objects and .h5 files.
Attributes:
Name | Type | Description |
---|---|---|
location |
str
|
The base directory for storing items. |
Source code in flyvision/utils/xarray_joblib_backend.py
dump_item ¶
dump_item(path, item, verbose=1)
Dump an item to the store.
If the item is an xarray.Dataset or the path ends with ‘.h5’, use xarray.Dataset.to_netcdf. Otherwise, use the superclass method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
List[str]
|
The identifier for the item in the store. |
required |
item |
Any
|
The item to be stored. |
required |
verbose |
int
|
Verbosity level. |
1
|
Source code in flyvision/utils/xarray_joblib_backend.py
load_item ¶
load_item(path, verbose=1, msg=None)
Load an item from the store.
If the path ends with ‘.h5’ or the store contains an h5 file, use xarray.open_dataset. Otherwise, use the superclass method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
List[str]
|
The identifier for the item in the store. |
required |
verbose |
int
|
Verbosity level. |
1
|
msg |
Optional[str]
|
Additional message for logging (not used here). |
None
|
Returns:
Type | Description |
---|---|
Any
|
The loaded item, either an xarray.Dataset or the original object. |
Source code in flyvision/utils/xarray_joblib_backend.py
contains_item ¶
contains_item(path)
Check if there is an item at the given path.
This method checks for both h5 and pickle files.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
List[str]
|
The identifier for the item in the store. |
required |
Returns:
Type | Description |
---|---|
bool
|
True if the item exists in either h5 or pickle format, False otherwise. |
Source code in flyvision/utils/xarray_joblib_backend.py
flyvision.utils.xarray_utils¶
Classes¶
flyvision.utils.xarray_utils.CustomAccessor ¶
Custom accessor for xarray objects providing additional functionality.
Attributes:
Name | Type | Description |
---|---|---|
_obj |
The xarray object being accessed. |
Source code in flyvision/utils/xarray_utils.py
plot_traces ¶
plot_traces(
x,
key="",
legend_labels=[],
extra_legend_coords=[],
plot_kwargs={},
**kwargs
)
Plot traces from the xarray object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x |
str
|
The dimension to use as the x-axis. |
required |
key |
str
|
The key of the data to plot if the dataset is a Dataset. |
''
|
legend_labels |
List[str]
|
List of coordinates to include in the legend. |
[]
|
extra_legend_coords |
List[str]
|
Additional coordinates to include in the legend. |
[]
|
plot_kwargs |
dict
|
Additional keyword arguments to pass to the plot function. |
{}
|
**kwargs |
Query-like conditions on coordinates. |
{}
|
Returns:
Type | Description |
---|---|
Axes
|
The matplotlib axes object containing the plot. |
Example
Overlay stimulus and response traces:
fig, ax = plt.subplots()
r.custom.plot_traces(
key='stimulus',
x='time',
speed=[19, 25],
intensity=1,
angle=90,
u_in=0,
v_in=0,
plot_kwargs=dict(ax=ax),
time='>0,<1.0'
)
r.custom.plot_traces(
key='responses',
x='time',
speed=[19, 25],
cell_type='T4c',
intensity=1,
angle=90,
network_id=0,
plot_kwargs=dict(ax=ax),
time='>0,<1.0'
)
Polar plot:
prs = peak_responses(stims_and_resps_moving_edges).custom.where(
cell_type="T4c",
intensity=1,
speed=19,
)
prs['angle'] = np.radians(prs.angle)
ax = plt.subplots(subplot_kw={"projection": "polar"})[1]
prs.custom.plot_traces(
x="angle",
legend_labels=["network_id"],
plot_kwargs={"add_legend": False, "ax": ax, "color": "b"},
)
Source code in flyvision/utils/xarray_utils.py
Functions¶
flyvision.utils.xarray_utils.where_xarray ¶
where_xarray(dataset, rtol=1e-05, atol=1e-08, **kwargs)
Return a subset of the xarray Dataset or DataArray where coordinates meet specified query-like conditions.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
Dataset | DataArray
|
The dataset or data array to filter. |
required |
rtol |
float
|
Relative tolerance for floating point comparisons. |
1e-05
|
atol |
float
|
Absolute tolerance for floating point comparisons. |
1e-08
|
**kwargs |
Query-like conditions on coordinates. Conditions can be specified as: - Strings with comma-separated conditions (interpreted as AND). - Iterables (lists, tuples) representing multiple conditions (interpreted as OR). - Single values for equality conditions. |
{}
|
Returns:
Type | Description |
---|---|
Dataset | DataArray
|
The filtered dataset or data array. |
Example
filtered_ds = where_xarray(
ds,
cell_type=["T4a", "T4b"],
time="<1.0,>0",
intensity=1.0,
radius=6,
width=2.4
)
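Example (sketch)
A string condition such as time="<1.0,>0" can be understood as a conjunction of comparisons. A list such as cell_type=["T4a", "T4b"] acts as a disjunction, similar to numpy.isin. The parser below is a hypothetical helper (parse_condition) that illustrates the string form on a plain NumPy coordinate array. It ignores the rtol/atol tolerances and is not flyvision's parser.

```python
import operator
from typing import Callable

import numpy as np

# Longer operators first so "<=" is not mistaken for "<".
_OPS = [("<=", operator.le), (">=", operator.ge), ("==", operator.eq),
        ("!=", operator.ne), ("<", operator.lt), (">", operator.gt)]


def parse_condition(spec: str) -> Callable[[np.ndarray], np.ndarray]:
    """Turn e.g. '<1.0,>0' into a vectorized predicate (comma means AND)."""
    checks = []
    for part in spec.split(","):
        part = part.strip()
        for symbol, op in _OPS:
            if part.startswith(symbol):
                checks.append((op, float(part[len(symbol):])))
                break
        else:
            raise ValueError(f"no comparison operator in {part!r}")

    def predicate(values: np.ndarray) -> np.ndarray:
        mask = np.ones(np.shape(values), dtype=bool)
        for op, threshold in checks:
            mask &= op(values, threshold)  # every comparison must hold
        return mask

    return predicate


time = np.array([-0.5, 0.0, 0.5, 1.0, 1.5])
print(parse_condition("<1.0,>0")(time).tolist())  # [False, False, True, False, False]
# A list condition would instead be an OR, e.g. via np.isin:
print(np.isin(np.array(["T4a", "T5a", "T4b"]), ["T4a", "T4b"]).tolist())  # [True, False, True]
```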
Source code in flyvision/utils/xarray_utils.py
flyvision.utils.xarray_utils.plot_traces ¶
plot_traces(
dataset,
key,
x,
legend_labels=[],
extra_legend_coords=[],
plot_kwargs={},
**kwargs
)
Plot the flash response traces from the dataset, optionally filtered by various parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
dataset |
DataArray | Dataset
|
The dataset containing the responses to plot. |
required |
key |
str
|
The key of the data to plot if the dataset is a Dataset. |
required |
x |
str
|
The dimension to use as the x-axis. |
required |
legend_labels |
List[str]
|
List of coordinates to include in the legend. |
[]
|
extra_legend_coords |
List[str]
|
Additional coordinates to include in the legend. |
[]
|
plot_kwargs |
dict
|
Additional keyword arguments to pass to the plot function. |
{}
|
**kwargs |
Query-like conditions on coordinates. |
{}
|
Returns:
Type | Description |
---|---|
Axes
|
The matplotlib axes object containing the plot. |
Note
Query-like conditions can be specified as:
- Strings with comma-separated conditions (e.g., time='<0.5,>0.1')
- Lists for equality conditions (e.g., cell_type=["T4a", "T4b"])
- Single values for equality conditions (e.g., intensity=1.0)
Source code in flyvision/utils/xarray_utils.py