Skip to content

nwb

sleap_io.io.nwb

Harmonization layer for NWB I/O operations.

This module provides a unified interface for reading and writing SLEAP data to/from NWB files, automatically detecting and routing to the appropriate backend based on the data format (annotations vs predictions).

Classes:

Name Description
NwbFormat

NWB format types for SLEAP data.

Functions:

Name Description
append_nwb

Append a SLEAP Labels object to an existing NWB data file.

append_nwb_data

Append data from a Labels object to an in-memory nwb file.

load_nwb

Load an NWB dataset as a SLEAP Labels object.

read_nwb

Read an NWB formatted file to a SLEAP Labels object.

save_nwb

Save a SLEAP dataset to NWB format.

write_nwb

Write labels to an nwb file and save it to the nwbfile_path given.

NwbFormat

Bases: str, Enum

NWB format types for SLEAP data.

Source code in sleap_io/io/nwb.py
class NwbFormat(str, Enum):
    """Enumerate the NWB layouts that SLEAP data can be saved in.

    Members also subclass `str`, so each member compares equal to its plain string
    value (e.g. `NwbFormat.AUTO == "auto"`).
    """

    # Pick the layout from the contents of the labels when saving.
    AUTO = "auto"
    # Training annotations (PoseTraining).
    ANNOTATIONS = "annotations"
    # Training annotations exported together with video frames.
    ANNOTATIONS_EXPORT = "annotations_export"
    # Predictions (PoseEstimation).
    PREDICTIONS = "predictions"

append_nwb(labels, filename, pose_estimation_metadata=None)

Append a SLEAP Labels object to an existing NWB data file.

Parameters:

Name Type Description Default
labels Labels

A general Labels object.

required
filename str

The path to the NWB file.

required
pose_estimation_metadata Optional[dict]

Metadata for pose estimation. See append_nwb_data for details.

None

See also: append_nwb_data

Source code in sleap_io/io/nwb_predictions.py
def append_nwb(
    labels: Labels, filename: str, pose_estimation_metadata: Optional[dict] = None
):
    """Add the contents of a SLEAP `Labels` object to an NWB file on disk.

    The file is opened in append mode, read into memory, extended through
    `append_nwb_data` and written back through the same I/O handle.

    Args:
        labels: A general `Labels` object.
        filename: The path to the NWB file.
        pose_estimation_metadata: Metadata for pose estimation. See `append_nwb_data`
            for details.

    See also: append_nwb_data
    """
    with NWBHDF5IO(filename, mode="a", load_namespaces=True) as nwb_io:
        updated_file = append_nwb_data(
            labels,
            nwb_io.read(),
            pose_estimation_metadata=pose_estimation_metadata,
        )
        nwb_io.write(updated_file)

append_nwb_data(labels, nwbfile, pose_estimation_metadata=None, skeleton_map=None)

Append data from a Labels object to an in-memory nwb file.

Parameters:

Name Type Description Default
labels Labels

A general labels object

required
nwbfile NWBFile

An in-memory nwbfile where the data is to be appended.

required
pose_estimation_metadata Optional[dict]

This argument has a dual purpose:

1) It can be used to pass time information about the video, which is necessary for synchronizing frames in pose estimation tracking to other modalities. Either the video timestamps can be passed with the key video_timestamps, or the sampling rate with the key video_sample_rate.

e.g. pose_estimation_metadata["video_timestamps"] = np.array(timestamps) or pose_estimation_metadata["video_sample_rate"] = 15 # In Hz

2) The other use of this dictionary is to overwrite sleap-io default arguments for the PoseEstimation container. See https://github.com/rly/ndx-pose for a full list of arguments.

None
skeleton_map Optional[Dict[str, Skeleton]]

Mapping of skeleton names to NWB Skeleton objects.

None

Returns:

Type Description
NWBFile

An in-memory nwbfile with the data from the labels object appended.

Source code in sleap_io/io/nwb_predictions.py
def append_nwb_data(
    labels: Labels,
    nwbfile: NWBFile,
    pose_estimation_metadata: Optional[dict] = None,
    skeleton_map: Optional[Dict[str, Skeleton]] = None,
) -> NWBFile:
    """Append data from a Labels object to an in-memory nwb file.

    For every video in `labels`, a processing module named
    `SLEAP_VIDEO_<index>_<video stem>` is looked up through
    `get_processing_module_for_video`, a `camera_<index>` device is reused or
    created, and one pose estimation container per track name found for that video
    is added to the module.

    Args:
        labels: A general labels object.
        nwbfile: An in-memory nwbfile where the data is to be appended.
        pose_estimation_metadata: This argument has a dual purpose:

            1) It can be used to pass time information about the video, which is
            necessary for synchronizing frames in pose estimation tracking to other
            modalities. Either the video timestamps can be passed with the key
            `video_timestamps`, or the sampling rate with the key
            `video_sample_rate`.

            e.g. pose_estimation_metadata["video_timestamps"] = np.array(timestamps)
            or   pose_estimation_metadata["video_sample_rate"] = 15  # In Hz

            2) The other use of this dictionary is to overwrite sleap-io default
            arguments for the PoseEstimation container.
            See https://github.com/rly/ndx-pose for a full list of arguments.
        skeleton_map: Mapping of skeleton names to NWB Skeleton objects. When
            omitted, it is built with `create_skeleton_container`.

    Returns:
        An in-memory nwbfile with the data from the labels object appended.
    """
    user_metadata = pose_estimation_metadata or dict()
    if skeleton_map is None:
        skeleton_map = create_skeleton_container(labels=labels, nwbfile=nwbfile)

    # Defaults derived from the labels' provenance.
    provenance = labels.provenance
    container_metadata = dict(
        scorer=str(provenance),
        source_software_version=provenance.get("sleap_version", None),
    )

    predictions_df = convert_predictions_to_dataframe(labels)

    # One processing module per video.
    for video_index, video in enumerate(labels.videos):
        video_path = Path(video.filename)
        module_name = f"SLEAP_VIDEO_{video_index:03}_{video_path.stem}"
        video_module = get_processing_module_for_video(module_name, nwbfile)

        # Reuse the camera device for this video index if the file already has one.
        device_name = f"camera_{video_index}"
        if device_name not in nwbfile.devices:
            device = nwbfile.create_device(
                name=device_name,
                description=f"Camera for {video_path.name}",
                manufacturer="Unknown",
            )
        else:
            device = nwbfile.devices[device_name]

        # The same metadata dict is reused across videos: the per-video keys are
        # rewritten on every iteration, then the user overrides are re-applied on
        # top so they always take precedence.
        container_metadata["original_videos"] = [f"{video.filename}"]
        container_metadata["labeled_videos"] = [f"{video.filename}"]
        container_metadata.update(user_metadata)

        # One PoseEstimation container per track present in this video.
        video_columns = predictions_df[video.filename].columns
        track_names = video_columns.get_level_values("track_name").unique()
        for track_name in track_names:
            video_module.add(
                build_pose_estimation_container_for_track(
                    predictions_df,
                    labels,
                    track_name,
                    video,
                    container_metadata,
                    skeleton_map,
                    devices=[device],
                )
            )

    return nwbfile

load_nwb(filename)

Load an NWB dataset as a SLEAP Labels object.

Automatically detects whether the file contains PoseTraining (annotations) or PoseEstimation (predictions) data and uses the appropriate backend.

Parameters:

Name Type Description Default
filename Union[str, Path]

Path to an NWB file (.nwb).

required

Returns:

Type Description
Labels

The dataset as a Labels object.

Raises:

Type Description
ValueError

If the NWB file doesn't contain recognized pose data.

Source code in sleap_io/io/nwb.py
def _has_pose_estimation_child(module, skip=()) -> bool:
    """Report whether an HDF5 group has a child tagged as `PoseEstimation`.

    Args:
        module: An open HDF5 group for one NWB processing module.
        skip: Child names that are not inspected.

    Returns:
        True if any inspected child carries a `neurodata_type` attribute equal to
        `"PoseEstimation"`, False otherwise.
    """
    for key in module.keys():
        if key in skip:
            continue
        attrs = module[key].attrs
        if "neurodata_type" in attrs and attrs["neurodata_type"] == "PoseEstimation":
            return True
    return False


def _detect_pose_format(f):
    """Determine which kind of pose data an open NWB (HDF5) file contains.

    Args:
        f: An open HDF5 file object.

    Returns:
        `NwbFormat.ANNOTATIONS` if the `behavior` processing module contains
        `PoseTraining`, `NwbFormat.PREDICTIONS` if any processing module contains a
        `PoseEstimation` child, or None if neither is found.
    """
    if "processing" not in f:
        return None
    processing = f["processing"]

    if "behavior" in processing:
        behavior = processing["behavior"]

        # PoseTraining (annotations) takes precedence over everything else.
        if "PoseTraining" in behavior:
            return NwbFormat.ANNOTATIONS

        # PoseEstimation stored directly in the behavior module (old format).
        if _has_pose_estimation_child(behavior, skip=("PoseTraining", "Skeletons")):
            return NwbFormat.PREDICTIONS

    # PoseEstimation stored in separate processing modules (predictions).
    for module_name in processing.keys():
        if module_name == "behavior":  # Already checked above.
            continue
        if _has_pose_estimation_child(processing[module_name]):
            return NwbFormat.PREDICTIONS

    return None


def load_nwb(filename: Union[str, Path]) -> Labels:
    """Load an NWB dataset as a SLEAP Labels object.

    Automatically detects whether the file contains PoseTraining (annotations)
    or PoseEstimation (predictions) data and uses the appropriate backend.

    Args:
        filename: Path to an NWB file (.nwb).

    Returns:
        The dataset as a Labels object.

    Raises:
        ValueError: If the NWB file doesn't contain recognized pose data.
    """
    from sleap_io.io import nwb_annotations, nwb_predictions

    filename = Path(filename)

    # Only inspect the layout while the raw HDF5 handle is open. The handle is
    # closed before a backend is called, so the backend does not re-open a file
    # that this function is still holding open.
    with h5py.File(filename, "r") as f:
        detected_format = _detect_pose_format(f)

    if detected_format == NwbFormat.ANNOTATIONS:
        return nwb_annotations.load_labels(filename)
    if detected_format == NwbFormat.PREDICTIONS:
        return nwb_predictions.read_nwb(filename)

    raise ValueError(
        f"NWB file '{filename}' does not contain recognized pose data "
        "(neither PoseTraining nor PoseEstimation found in any processing module)"
    )

read_nwb(path)

Read an NWB formatted file to a SLEAP Labels object.

Parameters:

Name Type Description Default
path str

Path to an NWB file (.nwb).

required

Returns:

Type Description
Labels

A Labels object.

Source code in sleap_io/io/nwb_predictions.py
def read_nwb(path: str) -> Labels:
    """Read an NWB formatted file to a SLEAP `Labels` object.

    Pose data is read from every processing module whose name contains
    `"SLEAP_VIDEO"`. The skeleton is taken from the first pose estimation container
    of the first such module and is reused for all modules. The frame index of each
    labeled frame is the position of its timestamp within the sorted union of all
    timestamps found in the corresponding module.

    Args:
        path: Path to an NWB file (`.nwb`).

    Returns:
        A `Labels` object.

    Raises:
        ValueError: If the file has no processing module with `"SLEAP_VIDEO"` in its
            name.
    """
    with NWBHDF5IO(path, mode="r", load_namespaces=True) as io:
        read_nwbfile = io.read()
        nwb_file_processing = read_nwbfile.processing

        # Get list of videos
        video_keys: List[str] = [
            key for key in nwb_file_processing.keys() if "SLEAP_VIDEO" in key
        ]
        if not video_keys:
            raise ValueError(
                f"NWB file '{path}' does not contain any processing module with "
                "'SLEAP_VIDEO' in its name."
            )
        video_tracks = dict()

        # Get track keys from first video's processing module
        test_processing_module: ProcessingModule = nwb_file_processing[video_keys[0]]
        track_keys: List[str] = list(test_processing_module.fields["data_interfaces"])

        # Get first track's skeleton
        test_pose_estimation: PoseEstimation = test_processing_module[track_keys[0]]
        skeleton = test_pose_estimation.skeleton
        skeleton_nodes = skeleton.nodes[:]
        skeleton_edges = skeleton.edges[:]

        # Filtering out behavior module with skeletons
        pose_estimation_container_modules = [
            nwb_file_processing[key] for key in video_keys
        ]

        for processing_module in pose_estimation_container_modules:
            # Get track keys
            _track_keys: List[str] = list(processing_module.fields["data_interfaces"])
            # Containers named "track<N>" mark tracked data.
            is_tracked: bool = re.sub("[0-9]+", "", _track_keys[0]) == "track"

            # Figure out the max number of frames and the canonical timestamps.
            # Start from an empty 1-D array: a 0-d `np.empty(())` holds one
            # uninitialized element that `np.union1d` would flatten into the result
            # as a spurious timestamp.
            timestamps = np.empty(0)
            for track_key in _track_keys:
                pose_estimation = processing_module[track_key]
                for node_name in skeleton.nodes:
                    pose_estimation_series = pose_estimation[node_name]
                    # `np.union1d` returns sorted unique values, so no extra sort
                    # is needed afterwards.
                    timestamps = np.union1d(
                        timestamps, get_timestamps(pose_estimation_series)
                    )

            # Recreate Labels numpy (same as output of Labels.numpy())
            n_tracks = len(_track_keys)
            n_frames = len(timestamps)
            n_nodes = len(skeleton.nodes)
            tracks_numpy = np.full((n_frames, n_tracks, n_nodes, 2), np.nan, np.float32)
            confidence = np.full((n_frames, n_tracks, n_nodes), np.nan, np.float32)

            for track_idx, track_key in enumerate(_track_keys):
                pose_estimation = processing_module[track_key]
                for node_idx, node_name in enumerate(skeleton.nodes):
                    pose_estimation_series = pose_estimation[node_name]
                    # Map each sample to its row in the canonical timestamp axis.
                    frame_inds = np.searchsorted(
                        timestamps, get_timestamps(pose_estimation_series)
                    )
                    tracks_numpy[frame_inds, track_idx, node_idx, :] = (
                        pose_estimation_series.data[:]
                    )
                    confidence[frame_inds, track_idx, node_idx] = (
                        pose_estimation_series.confidence[:]
                    )

            # The video path is taken from the last container read in this module.
            video_tracks[Path(pose_estimation.original_videos[0]).as_posix()] = (
                tracks_numpy,
                confidence,
                is_tracked,
            )

    # Create SLEAP skeleton from NWB skeleton
    sleap_skeleton = SleapSkeleton(
        nodes=skeleton_nodes,
        edges=skeleton_edges.tolist(),
    )

    # Add instances to labeled frames
    lfs = []
    for video_fn, (tracks_numpy, confidence, is_tracked) in video_tracks.items():
        video = Video(filename=video_fn)
        n_frames, n_tracks, n_nodes, _ = tracks_numpy.shape
        tracks = [Track(name=f"track{track_idx}") for track_idx in range(n_tracks)]

        for frame_idx, (frame_pts, frame_confs) in enumerate(
            zip(tracks_numpy, confidence)
        ):
            insts: List[Union[Instance, PredictedInstance]] = []
            for track, (inst_pts, inst_confs) in zip(
                tracks, zip(frame_pts, frame_confs)
            ):
                # A track with no points at this frame has no instance here.
                if np.isnan(inst_pts).all():
                    continue
                insts.append(
                    PredictedInstance.from_numpy(
                        points_data=np.column_stack(
                            [inst_pts, inst_confs]
                        ),  # (n_nodes, 3)
                        score=inst_confs.mean(),  # ()
                        skeleton=sleap_skeleton,
                        track=track if is_tracked else None,
                    )
                )
            if len(insts) > 0:
                lfs.append(
                    LabeledFrame(video=video, frame_idx=frame_idx, instances=insts)
                )

    labels = Labels(lfs)
    labels.provenance["filename"] = path
    return labels

save_nwb(labels, filename, nwb_format=NwbFormat.AUTO, append=False)

Save a SLEAP dataset to NWB format.

Parameters:

Name Type Description Default
labels Labels

A SLEAP Labels object to save.

required
filename Union[str, Path]

Path to NWB file to save to. Must end in '.nwb'.

required
nwb_format Union[NwbFormat, str]

Format to use for saving. Options are: - "auto" (default): Automatically detect based on data - "annotations": Save training annotations (PoseTraining) - "annotations_export": Export annotations with video frames - "predictions": Save predictions (PoseEstimation)

AUTO
append bool

If True, append to existing NWB file. Only supported for predictions format. Defaults to False.

False

Raises:

Type Description
ValueError

If an invalid format is specified.

Source code in sleap_io/io/nwb.py
def save_nwb(
    labels: Labels,
    filename: Union[str, Path],
    nwb_format: Union[NwbFormat, str] = NwbFormat.AUTO,
    append: bool = False,
) -> None:
    """Write a SLEAP dataset to an NWB file using the requested layout.

    Args:
        labels: A SLEAP Labels object to save.
        filename: Path to NWB file to save to. Must end in '.nwb'.
        nwb_format: Format to use for saving. Options are:
            - "auto" (default): Use "annotations" if any labeled frame has user
              instances, otherwise "predictions"
            - "annotations": Save training annotations (PoseTraining)
            - "annotations_export": Export annotations with video frames
            - "predictions": Save predictions (PoseEstimation)
        append: If True, append to existing NWB file. Only consulted for the
            predictions format. Defaults to False.

    Raises:
        ValueError: If an invalid format is specified.
    """
    from sleap_io.io import nwb_annotations, nwb_predictions

    filename = Path(filename)

    # Accept plain strings by converting them to the enum.
    if isinstance(nwb_format, str):
        try:
            nwb_format = NwbFormat(nwb_format)
        except ValueError:
            valid_values = ", ".join(member.value for member in NwbFormat)
            raise ValueError(
                f"Invalid NWB format: '{nwb_format}'. Must be one of: {valid_values}"
            )

    # Resolve "auto": user instances mean annotations, otherwise predictions.
    if nwb_format == NwbFormat.AUTO:
        found_user_instances = any(
            frame.has_user_instances for frame in labels.labeled_frames
        )
        nwb_format = (
            NwbFormat.ANNOTATIONS if found_user_instances else NwbFormat.PREDICTIONS
        )

    # Hand off to the matching backend.
    if nwb_format == NwbFormat.ANNOTATIONS:
        nwb_annotations.save_labels(labels, filename)
        return

    if nwb_format == NwbFormat.ANNOTATIONS_EXPORT:
        # The export backend takes the target directory and file name separately.
        nwb_annotations.export_labels(
            labels,
            output_dir=filename.parent,
            nwb_filename=filename.name,
            clean=True,  # Clean up intermediate files
        )
        return

    if nwb_format == NwbFormat.PREDICTIONS:
        if append:
            nwb_predictions.append_nwb(labels, str(filename))
        else:
            nwb_predictions.write_nwb(labels, filename)
        return

    raise ValueError(f"Unexpected NWB format: {nwb_format}")

write_nwb(labels, nwbfile_path, nwb_file_kwargs=None, pose_estimation_metadata=None)

Write labels to an nwb file and save it to the nwbfile_path given.

Parameters:

Name Type Description Default
labels Labels

A general Labels object.

required
nwbfile_path str

The path where the nwb file is to be written.

required
nwb_file_kwargs Optional[dict]

A dict containing metadata to the nwbfile. Example: nwb_file_kwargs = { 'session_description': 'your_session_description', 'identifier': 'your session_identifier', } For a full list of possible values see: https://pynwb.readthedocs.io/en/stable/pynwb.file.html#pynwb.file.NWBFile

Defaults to None and default values are used to generate the nwb file.

None
pose_estimation_metadata Optional[dict]

This argument has a dual purpose:

1) It can be used to pass time information about the video, which is necessary for synchronizing frames in pose estimation tracking to other modalities. Either the video timestamps can be passed with the key video_timestamps, or the sampling rate with the key video_sample_rate.

e.g. pose_estimation_metadata["video_timestamps"] = np.array(timestamps) or pose_estimation_metadata["video_sample_rate"] = 15 # In Hz

2) The other use of this dictionary is to overwrite sleap-io default arguments for the PoseEstimation container. See https://github.com/rly/ndx-pose for a full list of arguments.

None
Source code in sleap_io/io/nwb_predictions.py
def write_nwb(
    labels: Labels,
    nwbfile_path: str,
    nwb_file_kwargs: Optional[dict] = None,
    pose_estimation_metadata: Optional[dict] = None,
):
    """Write labels to an nwb file and save it to the nwbfile_path given.

    Args:
        labels: A general `Labels` object.
        nwbfile_path: The path where the nwb file is to be written.
        nwb_file_kwargs: A dict containing metadata to the nwbfile. Example:
            nwb_file_kwargs = {
                'session_description': 'your_session_description',
                'identifier': 'your session_identifier',
            }
            For a full list of possible values see:
            https://pynwb.readthedocs.io/en/stable/pynwb.file.html#pynwb.file.NWBFile

            Defaults to None and default values are used to generate the nwb file.
            The dict passed in is not modified.

        pose_estimation_metadata: This argument has a dual purpose:

            1) It can be used to pass time information about the video, which is
            necessary for synchronizing frames in pose estimation tracking to other
            modalities. Either the video timestamps can be passed with the key
            `video_timestamps`, or the sampling rate with the key
            `video_sample_rate`.

            e.g. pose_estimation_metadata["video_timestamps"] = np.array(timestamps)
            or   pose_estimation_metadata["video_sample_rate"] = 15  # In Hz

            2) The other use of this dictionary is to overwrite sleap-io default
            arguments for the PoseEstimation container.
            See https://github.com/rly/ndx-pose for a full list of arguments.
    """
    # Work on a copy so the generated defaults (notably the unique identifier and
    # the start time) are not written back into the caller's dict and then reused
    # for a later file.
    nwb_file_kwargs = dict(nwb_file_kwargs or {})

    # Add required values for nwbfile if not present
    nwb_file_kwargs.setdefault("session_description", "Processed SLEAP pose data")
    nwb_file_kwargs.setdefault(
        "session_start_time", datetime.datetime.now(datetime.timezone.utc)
    )
    nwb_file_kwargs.setdefault("identifier", str(uuid.uuid1()))

    nwbfile = NWBFile(**nwb_file_kwargs)

    # Create skeleton containers first
    skeleton_map = create_skeleton_container(labels, nwbfile)

    # Then append pose data
    nwbfile = append_nwb_data(labels, nwbfile, pose_estimation_metadata, skeleton_map)

    with NWBHDF5IO(str(nwbfile_path), "w") as io:
        io.write(nwbfile)