# sleap_io.io.nwb
Harmonization layer for NWB I/O operations.
This module provides a unified interface for reading and writing SLEAP data to/from NWB files, automatically detecting and routing to the appropriate backend based on the data format (annotations vs predictions).
Classes:

| Name | Description |
| --- | --- |
| `NwbFormat` | NWB format types for SLEAP data. |
Functions:

| Name | Description |
| --- | --- |
| `append_nwb` | Append a SLEAP `Labels` object to an existing NWB data file. |
| `append_nwb_data` | Append data from a Labels object to an in-memory NWB file. |
| `load_nwb` | Load an NWB dataset as a SLEAP Labels object. |
| `read_nwb` | Read an NWB formatted file to a SLEAP `Labels` object. |
| `save_nwb` | Save a SLEAP dataset to NWB format. |
| `write_nwb` | Write labels to an NWB file and save it to the given `nwbfile_path`. |
## NwbFormat
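`NwbFormat` enumerates the NWB format types for SLEAP data; its members are not expanded on this page. The following is a minimal sketch, assuming a plain string-valued `Enum` whose members mirror the format strings accepted by `save_nwb` below (the actual class body may differ):

```python
from enum import Enum


class NwbFormat(Enum):
    """NWB format types for SLEAP data (sketch based on the values used by save_nwb)."""

    AUTO = "auto"  # detect annotations vs. predictions from the data
    ANNOTATIONS = "annotations"  # training annotations (PoseTraining)
    ANNOTATIONS_EXPORT = "annotations_export"  # annotations exported with video frames
    PREDICTIONS = "predictions"  # predictions (PoseEstimation)
```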
## append_nwb(labels, filename, pose_estimation_metadata=None)
Append a SLEAP `Labels` object to an existing NWB data file.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `labels` | `Labels` | A general `Labels` object. | *required* |
| `filename` | `str` | The path to the NWB file. | *required* |
| `pose_estimation_metadata` | `Optional[dict]` | Metadata for pose estimation. See `append_nwb_data` for details. | `None` |
See also: append_nwb_data
Source code in sleap_io/io/nwb_predictions.py
```python
def append_nwb(
    labels: Labels, filename: str, pose_estimation_metadata: Optional[dict] = None
):
    """Append a SLEAP `Labels` object to an existing NWB data file.

    Args:
        labels: A general `Labels` object.
        filename: The path to the NWB file.
        pose_estimation_metadata: Metadata for pose estimation. See `append_nwb_data`
            for details.

    See also: append_nwb_data
    """
    with NWBHDF5IO(filename, mode="a", load_namespaces=True) as io:
        nwb_file = io.read()
        nwb_file = append_nwb_data(
            labels, nwb_file, pose_estimation_metadata=pose_estimation_metadata
        )
        io.write(nwb_file)
```
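For orientation, a brief usage sketch (not part of the library source): the file paths and sampling rate are hypothetical, and `sio.load_file` is assumed to be the generic sleap_io loader.

```python
import sleap_io as sio
from sleap_io.io.nwb_predictions import append_nwb

# Hypothetical inputs: new predictions and an NWB file created earlier (e.g. by write_nwb).
labels = sio.load_file("predictions.slp")

# Optional timing metadata used to synchronize pose data with other modalities.
metadata = {"video_sample_rate": 30.0}  # Hz

append_nwb(labels, "session.nwb", pose_estimation_metadata=metadata)
```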
## append_nwb_data(labels, nwbfile, pose_estimation_metadata=None, skeleton_map=None)
Append data from a Labels object to an in-memory nwb file.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `labels` | `Labels` | A general labels object. | *required* |
| `nwbfile` | `NWBFile` | An in-memory NWB file to which the data will be appended. | *required* |
| `pose_estimation_metadata` | `Optional[dict]` | This argument has a dual purpose: 1) It can be used to pass timing information about the video, which is necessary for synchronizing frames in pose estimation tracking to other modalities. The timestamps can be passed with the key `video_timestamps` or the sampling rate with the key `video_sample_rate`, e.g. `pose_estimation_metadata["video_timestamps"] = np.array(timestamps)` or `pose_estimation_metadata["video_sample_rate"] = 15  # in Hz`. 2) It can also be used to overwrite sleap-io default arguments for the PoseEstimation container; see https://github.com/rly/ndx-pose for a full list of arguments. | `None` |
| `skeleton_map` | `Optional[Dict[str, Skeleton]]` | Mapping of skeleton names to NWB Skeleton objects. | `None` |
Returns:

| Type | Description |
| --- | --- |
| `NWBFile` | An in-memory NWB file with the data from the labels object appended. |
Source code in sleap_io/io/nwb_predictions.py
```python
def append_nwb_data(
    labels: Labels,
    nwbfile: NWBFile,
    pose_estimation_metadata: Optional[dict] = None,
    skeleton_map: Optional[Dict[str, Skeleton]] = None,
) -> NWBFile:
    """Append data from a Labels object to an in-memory nwb file.

    Args:
        labels: A general labels object.
        nwbfile: An in-memory nwbfile where the data is to be appended.
        pose_estimation_metadata: This argument has a dual purpose:

            1) It can be used to pass time information about the video which is
            necessary for synchronizing frames in pose estimation tracking to other
            modalities. The timestamps can be passed with the key `video_timestamps`
            or the sampling rate with the key `video_sample_rate`.

            e.g. pose_estimation_metadata["video_timestamps"] = np.array(timestamps)
            or pose_estimation_metadata["video_sample_rate"] = 15  # In Hz

            2) The other use of this dictionary is to overwrite sleap-io default
            arguments for the PoseEstimation container.
            See https://github.com/rly/ndx-pose for a full list of arguments.
        skeleton_map: Mapping of skeleton names to NWB Skeleton objects.

    Returns:
        An in-memory nwbfile with the data from the labels object appended.
    """
    pose_estimation_metadata = pose_estimation_metadata or dict()
    if skeleton_map is None:
        skeleton_map = create_skeleton_container(labels=labels, nwbfile=nwbfile)

    # Extract default metadata
    provenance = labels.provenance
    default_metadata = dict(scorer=str(provenance))
    sleap_version = provenance.get("sleap_version", None)
    default_metadata["source_software_version"] = sleap_version

    labels_data_df = convert_predictions_to_dataframe(labels)

    # For every video create a processing module
    for video_index, video in enumerate(labels.videos):
        video_path = Path(video.filename)
        processing_module_name = f"SLEAP_VIDEO_{video_index:03}_{video_path.stem}"
        nwb_processing_module = get_processing_module_for_video(
            processing_module_name, nwbfile
        )

        device_name = f"camera_{video_index}"
        if device_name in nwbfile.devices:
            device = nwbfile.devices[device_name]
        else:
            device = nwbfile.create_device(
                name=device_name,
                description=f"Camera for {video_path.name}",
                manufacturer="Unknown",
            )

        # Propagate video metadata
        default_metadata["original_videos"] = [f"{video.filename}"]  # type: ignore
        default_metadata["labeled_videos"] = [f"{video.filename}"]  # type: ignore

        # Overwrite default with the user provided metadata
        default_metadata.update(pose_estimation_metadata)

        # For every track in that video create a PoseEstimation container
        name_of_tracks_in_video = (
            labels_data_df[video.filename]
            .columns.get_level_values("track_name")
            .unique()
        )
        for track_index, track_name in enumerate(name_of_tracks_in_video):
            pose_estimation_container = build_pose_estimation_container_for_track(
                labels_data_df,
                labels,
                track_name,
                video,
                default_metadata,
                skeleton_map,
                devices=[device],
            )
            nwb_processing_module.add(pose_estimation_container)

    return nwbfile
```
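A brief usage sketch (not part of the library source), assuming hypothetical file paths and that `sio.load_file` is the generic sleap_io loader. It builds an in-memory `NWBFile`, appends the pose data, and writes it out, which mirrors what `write_nwb` does internally:

```python
import datetime
import uuid

import sleap_io as sio
from pynwb import NWBFile, NWBHDF5IO
from sleap_io.io.nwb_predictions import append_nwb_data

# Hypothetical input file.
labels = sio.load_file("predictions.slp")

# Build an in-memory NWB file with the required session metadata.
nwbfile = NWBFile(
    session_description="Processed SLEAP pose data",
    identifier=str(uuid.uuid1()),
    session_start_time=datetime.datetime.now(datetime.timezone.utc),
)

# Append pose data; when skeleton_map is omitted, a skeleton container is
# created from the labels automatically.
nwbfile = append_nwb_data(
    labels, nwbfile, pose_estimation_metadata={"video_sample_rate": 30.0}
)

with NWBHDF5IO("session.nwb", mode="w") as io:
    io.write(nwbfile)
```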
## load_nwb(filename)
Load an NWB dataset as a SLEAP Labels object.
Automatically detects whether the file contains PoseTraining (annotations) or PoseEstimation (predictions) data and uses the appropriate backend.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `filename` | `Union[str, Path]` | Path to an NWB file (`.nwb`). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Labels` | The dataset as a Labels object. |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the NWB file doesn't contain recognized pose data. |
Source code in sleap_io/io/nwb.py
```python
def load_nwb(filename: Union[str, Path]) -> Labels:
    """Load an NWB dataset as a SLEAP Labels object.

    Automatically detects whether the file contains PoseTraining (annotations)
    or PoseEstimation (predictions) data and uses the appropriate backend.

    Args:
        filename: Path to an NWB file (.nwb).

    Returns:
        The dataset as a Labels object.

    Raises:
        ValueError: If the NWB file doesn't contain recognized pose data.
    """
    from sleap_io.io import nwb_annotations, nwb_predictions

    filename = Path(filename)

    # Check what type of data is in the file
    with h5py.File(filename, "r") as f:
        # Check for behavior processing module with PoseTraining (annotations)
        if "processing" in f and "behavior" in f["processing"]:
            behavior = f["processing"]["behavior"]

            # Check for PoseTraining (annotations)
            if "PoseTraining" in behavior:
                return nwb_annotations.load_labels(filename)

            # Check for PoseEstimation in behavior module (old format)
            for key in behavior.keys():
                if key not in ["PoseTraining", "Skeletons"]:
                    if "neurodata_type" in behavior[key].attrs:
                        if behavior[key].attrs["neurodata_type"] == "PoseEstimation":
                            return nwb_predictions.read_nwb(filename)

        # Check for PoseEstimation in separate processing modules (predictions)
        if "processing" in f:
            for module_name in f["processing"].keys():
                if module_name != "behavior":  # Skip behavior module (already checked)
                    module = f["processing"][module_name]
                    # Look for PoseEstimation containers
                    for key in module.keys():
                        if "neurodata_type" in module[key].attrs:
                            if module[key].attrs["neurodata_type"] == "PoseEstimation":
                                return nwb_predictions.read_nwb(filename)

    raise ValueError(
        f"NWB file '{filename}' does not contain recognized pose data "
        "(neither PoseTraining nor PoseEstimation found in behavior module)"
    )
```
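A short usage sketch (not part of the library source); the file path is hypothetical:

```python
from sleap_io.io.nwb import load_nwb

# Hypothetical path; the annotations or predictions backend is selected
# automatically from the containers found in the file.
labels = load_nwb("session.nwb")
print(len(labels.labeled_frames), "labeled frames across", len(labels.videos), "videos")
```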
## read_nwb(path)
Read an NWB formatted file to a SLEAP `Labels` object.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `str` | Path to an NWB file (`.nwb`). | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Labels` | A `Labels` object. |
Source code in sleap_io/io/nwb_predictions.py
```python
def read_nwb(path: str) -> Labels:
    """Read an NWB formatted file to a SLEAP `Labels` object.

    Args:
        path: Path to an NWB file (`.nwb`).

    Returns:
        A `Labels` object.
    """
    with NWBHDF5IO(path, mode="r", load_namespaces=True) as io:
        read_nwbfile = io.read()
        nwb_file_processing = read_nwbfile.processing

        # Get list of videos
        video_keys: List[str] = [
            key for key in nwb_file_processing.keys() if "SLEAP_VIDEO" in key
        ]
        video_tracks = dict()

        # Get track keys from first video's processing module
        test_processing_module: ProcessingModule = nwb_file_processing[video_keys[0]]
        track_keys: List[str] = list(test_processing_module.fields["data_interfaces"])

        # Get first track's skeleton
        test_pose_estimation: PoseEstimation = test_processing_module[track_keys[0]]
        skeleton = test_pose_estimation.skeleton
        skeleton_nodes = skeleton.nodes[:]
        skeleton_edges = skeleton.edges[:]

        # Filtering out behavior module with skeletons
        pose_estimation_container_modules = [
            nwb_file_processing[key] for key in video_keys
        ]

        for processing_module in pose_estimation_container_modules:
            # Get track keys
            _track_keys: List[str] = list(processing_module.fields["data_interfaces"])
            is_tracked: bool = re.sub("[0-9]+", "", _track_keys[0]) == "track"

            # Figure out the max number of frames and the canonical timestamps
            timestamps = np.empty(())
            for track_key in _track_keys:
                pose_estimation = processing_module[track_key]
                for node_name in skeleton.nodes:
                    pose_estimation_series = pose_estimation[node_name]
                    timestamps = np.union1d(
                        timestamps, get_timestamps(pose_estimation_series)
                    )
            timestamps = np.sort(timestamps)

            # Recreate Labels numpy (same as output of Labels.numpy())
            n_tracks = len(_track_keys)
            n_frames = len(timestamps)
            n_nodes = len(skeleton.nodes)
            tracks_numpy = np.full((n_frames, n_tracks, n_nodes, 2), np.nan, np.float32)
            confidence = np.full((n_frames, n_tracks, n_nodes), np.nan, np.float32)
            for track_idx, track_key in enumerate(_track_keys):
                pose_estimation = processing_module[track_key]
                for node_idx, node_name in enumerate(skeleton.nodes):
                    pose_estimation_series = pose_estimation[node_name]
                    frame_inds = np.searchsorted(
                        timestamps, get_timestamps(pose_estimation_series)
                    )
                    tracks_numpy[frame_inds, track_idx, node_idx, :] = (
                        pose_estimation_series.data[:]
                    )
                    confidence[frame_inds, track_idx, node_idx] = (
                        pose_estimation_series.confidence[:]
                    )

            video_tracks[Path(pose_estimation.original_videos[0]).as_posix()] = (
                tracks_numpy,
                confidence,
                is_tracked,
            )

    # Create SLEAP skeleton from NWB skeleton
    sleap_skeleton = SleapSkeleton(
        nodes=skeleton_nodes,
        edges=skeleton_edges.tolist(),
    )

    # Add instances to labeled frames
    lfs = []
    for video_fn, (tracks_numpy, confidence, is_tracked) in video_tracks.items():
        video = Video(filename=video_fn)
        n_frames, n_tracks, n_nodes, _ = tracks_numpy.shape
        tracks = [Track(name=f"track{track_idx}") for track_idx in range(n_tracks)]
        for frame_idx, (frame_pts, frame_confs) in enumerate(
            zip(tracks_numpy, confidence)
        ):
            insts: List[Union[Instance, PredictedInstance]] = []
            for track, (inst_pts, inst_confs) in zip(
                tracks, zip(frame_pts, frame_confs)
            ):
                if np.isnan(inst_pts).all():
                    continue
                insts.append(
                    PredictedInstance.from_numpy(
                        points_data=np.column_stack([inst_pts, inst_confs]),  # (n_nodes, 3)
                        score=inst_confs.mean(),
                        skeleton=sleap_skeleton,
                        track=track if is_tracked else None,
                    )
                )
            if len(insts) > 0:
                lfs.append(
                    LabeledFrame(video=video, frame_idx=frame_idx, instances=insts)
                )

    labels = Labels(lfs)
    labels.provenance["filename"] = path
    return labels
```
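A short usage sketch (not part of the library source); the path is hypothetical and the printed attributes follow from how the instances are reconstructed above:

```python
from sleap_io.io.nwb_predictions import read_nwb

# Hypothetical path to a predictions-style NWB file (PoseEstimation containers).
labels = read_nwb("session.nwb")

# Instances come back as PredictedInstance objects; track identities are kept
# only if the containers were written from tracked instances.
for lf in labels.labeled_frames[:5]:
    for inst in lf.instances:
        print(lf.frame_idx, inst.track, inst.score)
```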
## save_nwb(labels, filename, nwb_format=NwbFormat.AUTO, append=False)
Save a SLEAP dataset to NWB format.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `labels` | `Labels` | A SLEAP Labels object to save. | *required* |
| `filename` | `Union[str, Path]` | Path to NWB file to save to. Must end in '.nwb'. | *required* |
| `nwb_format` | `Union[NwbFormat, str]` | Format to use for saving. Options are: "auto" (default), which automatically detects the format based on the data; "annotations", which saves training annotations (PoseTraining); "annotations_export", which exports annotations with video frames; and "predictions", which saves predictions (PoseEstimation). | `AUTO` |
| `append` | `bool` | If True, append to an existing NWB file. Only supported for the predictions format. Defaults to False. | `False` |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If an invalid format is specified. |
Source code in sleap_io/io/nwb.py
```python
def save_nwb(
    labels: Labels,
    filename: Union[str, Path],
    nwb_format: Union[NwbFormat, str] = NwbFormat.AUTO,
    append: bool = False,
) -> None:
    """Save a SLEAP dataset to NWB format.

    Args:
        labels: A SLEAP Labels object to save.
        filename: Path to NWB file to save to. Must end in '.nwb'.
        nwb_format: Format to use for saving. Options are:
            - "auto" (default): Automatically detect based on data
            - "annotations": Save training annotations (PoseTraining)
            - "annotations_export": Export annotations with video frames
            - "predictions": Save predictions (PoseEstimation)
        append: If True, append to existing NWB file. Only supported for
            predictions format. Defaults to False.

    Raises:
        ValueError: If an invalid format is specified.
    """
    from sleap_io.io import nwb_annotations, nwb_predictions

    filename = Path(filename)

    # Convert string to enum if needed
    if isinstance(nwb_format, str):
        try:
            nwb_format = NwbFormat(nwb_format)
        except ValueError:
            raise ValueError(
                f"Invalid NWB format: '{nwb_format}'. "
                f"Must be one of: {', '.join(f.value for f in NwbFormat)}"
            )

    # Auto-detect format if needed
    if nwb_format == NwbFormat.AUTO:
        # Check if there are any user instances
        has_user_instances = any(lf.has_user_instances for lf in labels.labeled_frames)
        if has_user_instances:
            nwb_format = NwbFormat.ANNOTATIONS
        else:
            nwb_format = NwbFormat.PREDICTIONS

    # Route to appropriate backend
    if nwb_format == NwbFormat.ANNOTATIONS:
        nwb_annotations.save_labels(labels, filename)
    elif nwb_format == NwbFormat.ANNOTATIONS_EXPORT:
        # Use export_labels for the export format
        output_dir = filename.parent
        nwb_filename = filename.name
        nwb_annotations.export_labels(
            labels,
            output_dir=output_dir,
            nwb_filename=nwb_filename,
            clean=True,  # Clean up intermediate files
        )
    elif nwb_format == NwbFormat.PREDICTIONS:
        if append:
            nwb_predictions.append_nwb(labels, str(filename))
        else:
            nwb_predictions.write_nwb(labels, filename)
    else:
        raise ValueError(f"Unexpected NWB format: {nwb_format}")
```
## write_nwb(labels, nwbfile_path, nwb_file_kwargs=None, pose_estimation_metadata=None)
Write labels to an nwb file and save it to the nwbfile_path given.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `labels` | `Labels` | A general `Labels` object. | *required* |
| `nwbfile_path` | `str` | The path where the NWB file is to be written. | *required* |
| `nwb_file_kwargs` | `Optional[dict]` | A dict containing metadata for the NWBFile, e.g. `{'session_description': 'your_session_description', 'identifier': 'your_session_identifier'}`. For a full list of possible values see: https://pynwb.readthedocs.io/en/stable/pynwb.file.html#pynwb.file.NWBFile. Defaults to None, in which case default values are used to generate the NWB file. | `None` |
| `pose_estimation_metadata` | `Optional[dict]` | This argument has a dual purpose: 1) It can be used to pass timing information about the video, which is necessary for synchronizing frames in pose estimation tracking to other modalities. The timestamps can be passed with the key `video_timestamps` or the sampling rate with the key `video_sample_rate`, e.g. `pose_estimation_metadata["video_timestamps"] = np.array(timestamps)` or `pose_estimation_metadata["video_sample_rate"] = 15  # in Hz`. 2) It can also be used to overwrite sleap-io default arguments for the PoseEstimation container; see https://github.com/rly/ndx-pose for a full list of arguments. | `None` |
Source code in sleap_io/io/nwb_predictions.py
```python
def write_nwb(
    labels: Labels,
    nwbfile_path: str,
    nwb_file_kwargs: Optional[dict] = None,
    pose_estimation_metadata: Optional[dict] = None,
):
    """Write labels to an nwb file and save it to the nwbfile_path given.

    Args:
        labels: A general `Labels` object.
        nwbfile_path: The path where the nwb file is to be written.
        nwb_file_kwargs: A dict containing metadata to the nwbfile. Example:

            nwb_file_kwargs = {
                'session_description': 'your_session_description',
                'identifier': 'your_session_identifier',
            }

            For a full list of possible values see:
            https://pynwb.readthedocs.io/en/stable/pynwb.file.html#pynwb.file.NWBFile

            Defaults to None and default values are used to generate the nwb file.
        pose_estimation_metadata: This argument has a dual purpose:

            1) It can be used to pass time information about the video which is
            necessary for synchronizing frames in pose estimation tracking to other
            modalities. The timestamps can be passed with the key `video_timestamps`
            or the sampling rate with the key `video_sample_rate`.

            e.g. pose_estimation_metadata["video_timestamps"] = np.array(timestamps)
            or pose_estimation_metadata["video_sample_rate"] = 15  # In Hz

            2) The other use of this dictionary is to overwrite sleap-io default
            arguments for the PoseEstimation container.
            See https://github.com/rly/ndx-pose for a full list of arguments.
    """
    nwb_file_kwargs = nwb_file_kwargs or dict()

    # Add required values for nwbfile if not present
    session_description = nwb_file_kwargs.get(
        "session_description", "Processed SLEAP pose data"
    )
    session_start_time = nwb_file_kwargs.get(
        "session_start_time", datetime.datetime.now(datetime.timezone.utc)
    )
    identifier = nwb_file_kwargs.get("identifier", str(uuid.uuid1()))
    nwb_file_kwargs.update(
        session_description=session_description,
        session_start_time=session_start_time,
        identifier=identifier,
    )
    nwbfile = NWBFile(**nwb_file_kwargs)

    # Create skeleton containers first
    skeleton_map = create_skeleton_container(labels, nwbfile)

    # Then append pose data
    nwbfile = append_nwb_data(labels, nwbfile, pose_estimation_metadata, skeleton_map)

    with NWBHDF5IO(str(nwbfile_path), "w") as io:
        io.write(nwbfile)
```
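A short usage sketch (not part of the library source); the file paths, session description, identifier, and frame count are hypothetical, and `sio.load_file` is assumed to be the generic sleap_io loader:

```python
import numpy as np
import sleap_io as sio
from sleap_io.io.nwb_predictions import write_nwb

# Hypothetical input file.
labels = sio.load_file("predictions.slp")

# Session-level metadata (any field accepted by pynwb.NWBFile).
nwb_file_kwargs = {
    "session_description": "SLEAP pose estimates for subject A",  # hypothetical
    "identifier": "subject-A_2024-01-01",  # hypothetical
}

# Timing metadata: explicit per-frame timestamps (or pass "video_sample_rate" instead).
pose_estimation_metadata = {"video_timestamps": np.arange(1000) / 30.0}

write_nwb(labels, "subject-A.nwb", nwb_file_kwargs, pose_estimation_metadata)
```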