Skip to content

jabs

sleap_io.io.jabs

This module handles direct I/O operations for working with JABS files.

convert_labels(all_labels, video)

Convert a Labels object into JABS-formatted annotations.

Parameters:

Name Type Description Default
all_labels Labels

SLEAP Labels to be converted to JABS format.

required
video Video

name of video to be converted

required

Returns:

Type Description
dict

Dictionary of JABS data of the Labels data.

Source code in sleap_io/io/jabs.py
def convert_labels(all_labels: Labels, video: Video) -> dict:
    """Convert a `Labels` object into JABS-formatted annotations.

    Only the labeled frames found for `video` are converted. Instances whose skeleton
    is named "Mouse" are stored as poses; instances with any other skeleton name are
    stored as static objects keyed by their skeleton name.

    Args:
        all_labels: SLEAP `Labels` to be converted to JABS format.
        video: `Video` whose labeled frames should be converted.

    Returns:
        Dictionary of JABS data of the `Labels` data with the keys:
            "keypoints": uint16 array (frames, mice, keypoints, 2) in y,x order.
            "confidence": float32 array (frames, mice, keypoints); 1.0 where a
                keypoint was present and 0.0 where it was missing.
            "identity": uint32 array (frames, mice) of 1-indexed identities.
            "num_identities": uint16 array (frames,) counting stored poses per frame.
            "static_objects": dict mapping skeleton name to its keypoint array.

    Raises:
        IndexError: If `all_labels` has no skeleton named "Mouse".
    """
    labels = all_labels.find(video=video)

    # Determine shape of output
    # Low estimate of last frame labeled (0 frames when nothing is labeled)
    num_frames = max((x.frame_idx for x in labels), default=-1) + 1
    # If there is metadata available for the video, use that
    if video.shape:
        num_frames = max(num_frames, video.shape[0])
    num_keypoints = [len(x.nodes) for x in all_labels.skeletons if x.name == "Mouse"][0]
    num_mice = get_max_ids_in_video(labels, key="Mouse")
    # Note that this 1-indexes identities
    track_2_idx = {track: idx + 1 for idx, track in enumerate(all_labels.tracks)}
    # NOTE(review): unassigned IDs start at num_mice, which may overlap with the
    # 1-indexed track IDs above when tracks exist -- confirm this is intended.
    last_unassigned_id = num_mice

    keypoint_mat = np.zeros([num_frames, num_mice, num_keypoints, 2], dtype=np.uint16)
    confidence_mat = np.zeros([num_frames, num_mice, num_keypoints], dtype=np.float32)
    identity_mat = np.zeros([num_frames, num_mice], dtype=np.uint32)
    instance_vector = np.zeros([num_frames], dtype=np.uint8)
    static_objects = {}

    # Populate the matrices with data
    for label in labels:
        assigned_instances = 0
        for instance_idx, instance in enumerate(label.instances):
            # Static objects just get added to the object dict
            # This will clobber data if more than one frame is annotated
            if instance.skeleton.name != "Mouse":
                static_objects[instance.skeleton.name] = instance.numpy()
                continue
            pose = instance.numpy()
            if pose.shape[0] != len(JABS_DEFAULT_KEYPOINTS):
                warnings.warn(
                    f"JABS format only supports 12 keypoints for mice. Skipping storage of instance on frame {label.frame_idx} with {len(instance.points)} keypoints."
                )
                continue
            # Mice are packed into consecutive slots. Indexing by the position in
            # `label.instances` would leave gaps (or run past `num_mice`) whenever a
            # static object or skipped instance precedes a mouse in the same frame.
            slot = assigned_instances
            missing_points = np.isnan(pose[:, 0])
            pose[np.isnan(pose)] = 0
            # JABS stores y,x for poses
            pose = np.flip(pose.astype(np.uint16), axis=-1)
            keypoint_mat[label.frame_idx, slot, :, :] = pose
            confidence_mat[label.frame_idx, slot, ~missing_points] = 1.0
            if instance.track:
                identity_mat[label.frame_idx, slot] = track_2_idx[instance.track]
            else:
                warnings.warn(
                    f"Pose with unassigned track found on {label.video.filename} frame {label.frame_idx} instance {instance_idx}. Assigning ID {last_unassigned_id}."
                )
                identity_mat[label.frame_idx, slot] = last_unassigned_id
                last_unassigned_id += 1
            assigned_instances += 1
        instance_vector[label.frame_idx] = assigned_instances

    # Return the data as a dict
    return {
        "keypoints": keypoint_mat.astype(np.uint16),
        "confidence": confidence_mat.astype(np.float32),
        "identity": identity_mat.astype(np.uint32),
        "num_identities": instance_vector.astype(np.uint16),
        "static_objects": static_objects,
    }

get_max_ids_in_video(labels, key='Mouse')

Determine the maximum number of identities that exist at the same time.

Parameters:

Name Type Description Default
labels List[Labels]

SLEAP Labels to count

required
key str

Name of the skeleton to select for identities

'Mouse'

Returns:

Type Description
int

Count of the maximum concurrent identities in a single frame

Source code in sleap_io/io/jabs.py
def get_max_ids_in_video(labels: List[Labels], key: str = "Mouse") -> int:
    """Find the largest number of `key` instances present in any single frame.

    Args:
        labels: Labeled frames to inspect; each must expose `instances`.
        key: Skeleton name an instance must have to be counted.

    Returns:
        The highest per-frame count of instances whose skeleton is named `key`,
        or 0 when there are no frames.
    """
    per_frame_counts = (
        sum(1 for inst in frame.instances if inst.skeleton.name == key)
        for frame in labels
    )
    return max(per_frame_counts, default=0)

make_simple_skeleton(name, num_points)

Create a Skeleton with a requested number of nodes attached in a line.

Parameters:

Name Type Description Default
name str

name of the skeleton and prefix to nodes

required
num_points int

number of points to use in the skeleton

required

Returns:

Type Description
Skeleton

Generated Skeleton.

Source code in sleap_io/io/jabs.py
def make_simple_skeleton(name: str, num_points: int) -> Skeleton:
    """Build a chain-shaped `Skeleton` with `num_points` nodes.

    Nodes are named `<name>_kp<i>` and each node is linked to the next one.

    Args:
        name: Skeleton name, also used as the node name prefix.
        num_points: How many nodes the skeleton should contain.

    Returns:
        The generated `Skeleton`.
    """
    chain = []
    for idx in range(num_points):
        chain.append(Node(name + f"_kp{idx}"))
    # Pair every node with its successor to form the line of edges.
    links = [Edge(src, dst) for src, dst in zip(chain[:-1], chain[1:])]
    return Skeleton(chain, links, name=name)

prediction_to_instance(data, confidence, skeleton, track=None)

Create an Instance from prediction data.

Parameters:

Name Type Description Default
data Union[ndarray[uint16], ndarray[float32]]

keypoint locations

required
confidence ndarray[float32]

confidence for keypoints

required
skeleton Skeleton

Skeleton to use for Instance

required
track Track

Track to assign to Instance

None

Returns:

Type Description
Instance

Parsed Instance.

Source code in sleap_io/io/jabs.py
def prediction_to_instance(
    data: Union[np.ndarray[np.uint16], np.ndarray[np.float32]],
    confidence: np.ndarray[np.float32],
    skeleton: Skeleton,
    track: Track = None,
) -> Instance:
    """Build an `Instance` out of keypoint predictions.

    Args:
        data: Keypoint locations, one row per skeleton node.
        confidence: Per-keypoint confidence; values of 0 mark absent keypoints.
        skeleton: `Skeleton` the resulting `Instance` should use.
        track: `Track` to attach to the `Instance`.

    Returns:
        The parsed `Instance`, or `None` when no keypoint has confidence above 0.
    """
    num_nodes = len(skeleton.nodes)
    num_keypoints = data.shape[0]
    assert (
        num_nodes == num_keypoints
    ), f"Skeleton ({num_nodes}) does not match number of keypoints ({num_keypoints})"

    # Only keypoints that were actually predicted (confidence > 0) become points.
    visible_points = {
        node: Point(data[idx, 0], data[idx, 1], visible=True)
        for idx, node in enumerate(skeleton.nodes)
        if confidence[idx] > 0
    }

    if visible_points:
        return Instance(visible_points, skeleton=skeleton, track=track)
    return None

read_labels(labels_path, skeleton=JABS_DEFAULT_SKELETON)

Read JABS style pose from a file and return a Labels object.

TODO: Attributes are ignored, including the px_to_cm field. TODO: Segmentation data is ignored in v6, but pose is still read. TODO: Lixit static objects are currently stored as a single (n_lixit, 2) object; they should be converted to multiple objects.

Parameters:

Name Type Description Default
labels_path str

Path to the JABS pose file.

required
skeleton Optional[Skeleton]

An optional Skeleton object. Defaults to JABS pose version 2-6.

JABS_DEFAULT_SKELETON

Returns:

Type Description
Labels

Parsed labels as a Labels instance.

Source code in sleap_io/io/jabs.py
def read_labels(
    labels_path: str, skeleton: Optional[Skeleton] = JABS_DEFAULT_SKELETON
) -> Labels:
    """Read JABS style pose from a file and return a `Labels` object.

    One `LabeledFrame` is produced for every frame in the pose file. Static objects
    (pose version 5 and up) are attached to the first frame only.

    TODO: Attributes are ignored, including px_to_cm field.
    TODO: Segmentation data ignored in v6, but will read in pose.
    TODO: Lixit static objects currently stored as n_lixit,2 (eg 1 object). Should be converted to multiple objects

    Args:
        labels_path: Path to the JABS pose file.
        skeleton: An optional `Skeleton` object. Defaults to JABS pose version 2-6.

    Returns:
        Parsed labels as a `Labels` instance.

    Raises:
        FileNotFoundError: If `labels_path` does not exist.
        PermissionError: If `labels_path` cannot be read.
        AssertionError: If the file has no version attribute and its pose data is
            not shaped like single mouse data.
    """
    frames: List[LabeledFrame] = []
    # Video name is the pose file minus the suffix
    video_name = re.sub(r"(_pose_est_v[2-6])?\.h5", ".avi", labels_path)
    video = Video.from_filename(video_name)
    if not skeleton:
        skeleton = JABS_DEFAULT_SKELETON
    tracks = {}

    if not os.access(labels_path, os.F_OK):
        raise FileNotFoundError(f"{labels_path} doesn't exist.")
    if not os.access(labels_path, os.R_OK):
        raise PermissionError(f"{labels_path} cannot be accessed.")

    with h5py.File(labels_path, "r") as pose_file:
        num_frames = pose_file["poseest/points"].shape[0]
        try:
            pose_version = pose_file["poseest"].attrs["version"][0]
        except (KeyError, IndexError):
            # Files without a version attribute are treated as single mouse (v2)
            pose_version = 2
            data_shape = pose_file["poseest/points"].shape
            assert (
                len(data_shape) == 3
            ), f"Pose version not present and shape does not match single mouse: shape of {data_shape} for {labels_path}"
        if pose_version == 2:
            tracks[1] = Track("1")
        # Change field name for newer pose formats
        if pose_version == 3:
            id_key = "instance_track_id"
        elif pose_version > 3:
            id_key = "instance_embed_id"
            max_ids = pose_file["poseest/points"].shape[1]

        for frame_idx in range(num_frames):
            instances = []
            pose_data = pose_file["poseest/points"][frame_idx, ...]
            # JABS stores y,x for poses
            pose_data = np.flip(pose_data, axis=-1)
            pose_conf = pose_file["poseest/confidence"][frame_idx, ...]
            # single animal case
            if pose_version == 2:
                new_instance = prediction_to_instance(
                    pose_data, pose_conf, skeleton, tracks[1]
                )
                # No instance is created when every keypoint has 0 confidence;
                # skip it rather than storing None in the frame.
                if new_instance is not None:
                    instances.append(new_instance)
            # multi-animal case
            if pose_version > 2:
                pose_ids = pose_file["poseest/" + id_key][frame_idx, ...]
                # pose_v3 uses another field to describe the number of valid poses
                if pose_version == 3:
                    max_ids = pose_file["poseest/instance_count"][frame_idx]
                for cur_id in range(max_ids):
                    track_key = pose_ids[cur_id]
                    # v4+ uses reserved values for invalid/unused poses
                    # Note: ignores 'poseest/id_mask' to keep predictions that were not assigned an id
                    if pose_version > 3 and track_key <= 0:
                        continue
                    if track_key not in tracks:
                        tracks[track_key] = Track(str(track_key))
                    new_instance = prediction_to_instance(
                        pose_data[cur_id],
                        pose_conf[cur_id],
                        skeleton,
                        tracks[track_key],
                    )
                    if new_instance is not None:
                        instances.append(new_instance)
            # Static objects
            if (
                frame_idx == 0
                and pose_version >= 5
                and "static_objects" in pose_file
            ):
                for cur_object in pose_file["static_objects"].keys():
                    object_keypoints = pose_file["static_objects/" + cur_object][:]
                    object_skeleton = make_simple_skeleton(
                        cur_object, object_keypoints.shape[0]
                    )
                    # Static objects carry no confidence, so every keypoint is kept
                    new_instance = prediction_to_instance(
                        object_keypoints,
                        np.ones(object_keypoints.shape[:-1]),
                        object_skeleton,
                    )
                    if new_instance is not None:
                        instances.append(new_instance)
            frame_label = LabeledFrame(video, frame_idx, instances)
            frames.append(frame_label)
    labels = Labels(frames)
    labels.provenance["filename"] = labels_path
    return labels

tracklets_to_v3(tracklet_matrix)

Changes identity tracklets to the v3 format specifications.

v3 specifications require

(a) tracklets are 0-indexed, (b) tracklets appear in ascending order, (c) tracklets exist for continuous blocks of time

Parameters:

Name Type Description Default
tracklet_matrix ndarray

Numpy array of shape (frame, n_animals) that contains identity values. Identities are assumed to be 1-indexed.

required

Returns:

Type Description
ndarray

A corrected numpy array of the same shape as input

Source code in sleap_io/io/jabs.py
def tracklets_to_v3(tracklet_matrix: np.ndarray) -> np.ndarray:
    """Rewrite identity values so they follow the v3 tracklet rules.

    v3 specifications require:
        (a) tracklets are 0-indexed
        (b) tracklets appear in ascending order
        (c) tracklets exist for continuous blocks of time

    Args:
        tracklet_matrix: Numpy array of shape (frame, n_animals) that contains
            identity values. Identities are assumed to be 1-indexed, with 0
            treated as "no identity".

    Returns:
        A corrected numpy array of the same shape as input
    """
    assert tracklet_matrix.ndim == 2

    # Maps (first frame, first column) of each continuous run to the columns it
    # occupies, so that sorting the keys yields the ascending tracklet order.
    fragments = {}
    identities = [ident for ident in np.unique(tracklet_matrix) if ident != 0]
    for ident in identities:
        rows, cols = np.nonzero(tracklet_matrix == ident)
        # A new fragment starts wherever consecutive frames are not exactly 1 apart
        breaks = np.flatnonzero(np.diff(rows) != 1) + 1
        row_runs = np.split(rows, breaks)
        col_runs = np.split(cols, breaks)
        for row_run, col_run in zip(row_runs, col_runs):
            fragments[(row_run[0], col_run[0])] = col_run

    renumbered = np.zeros_like(tracklet_matrix)
    for new_id, start in enumerate(sorted(fragments)):
        first_frame = start[0]
        run_columns = fragments[start]
        run_frames = np.arange(first_frame, first_frame + len(run_columns))
        renumbered[run_frames, run_columns] = new_id

    return renumbered

write_jabs_v2(data, filename)

Write JABS pose file v2 data to file.

Writes single mouse pose data.

Parameters:

Name Type Description Default
data dict

Dictionary of JABS data generated from convert_labels

required
filename str

Filename to write data to

required
Source code in sleap_io/io/jabs.py
def write_jabs_v2(data: dict, filename: str):
    """Save single mouse pose data as a JABS v2 pose file.

    Args:
        data: Dictionary of JABS data generated from convert_labels
        filename: Path of the file to create (overwritten if present)
    """
    keypoints = data["keypoints"]
    # v2 stores exactly one animal, so the animal axis must have length 1
    assert keypoints.shape[1] == 1
    single_points = np.squeeze(keypoints, axis=1)
    single_confidence = np.squeeze(data["confidence"], axis=1)

    with h5py.File(filename, "w") as out_file:
        group = out_file.require_group("poseest")
        group.attrs.update({"version": [2, 0]})
        for field, array in (
            ("points", single_points),
            ("confidence", single_confidence),
        ):
            group.require_dataset(field, array.shape, array.dtype, data=array)

write_jabs_v3(data, filename)

Write JABS pose file v3 data to file.

Writes multi-mouse pose data.

Parameters:

Name Type Description Default
data dict

Dictionary of JABS data generated from convert_labels

required
filename str

Filename to write data to

required
Source code in sleap_io/io/jabs.py
def write_jabs_v3(data: dict, filename: str):
    """Save multi-mouse pose data as a JABS v3 pose file.

    Args:
        data: Dictionary of JABS data generated from convert_labels
        filename: Path of the file to create (overwritten if present)
    """
    track_ids = tracklets_to_v3(data["identity"])
    with h5py.File(filename, "w") as out_file:
        group = out_file.require_group("poseest")
        group.attrs.update({"version": [3, 0]})
        fields = (
            ("points", data["keypoints"]),
            ("confidence", data["confidence"]),
            ("instance_track_id", track_ids),
            ("instance_count", data["num_identities"]),
            # No embedding data is available, so store zeros shaped like confidence
            ("instance_embedding", np.zeros_like(data["confidence"])),
        )
        for field, array in fields:
            group.require_dataset(field, array.shape, array.dtype, data=array)

write_jabs_v4(data, filename)

Write JABS pose file v4 data to file.

Writes multi-mouse pose and longterm identity object data.

Parameters:

Name Type Description Default
data dict

Dictionary of JABS data generated from convert_labels

required
filename str

Filename to write data to

required
Source code in sleap_io/io/jabs.py
def write_jabs_v4(data: dict, filename: str):
    """Save multi-mouse pose and longterm identity data as a JABS v4 pose file.

    Args:
        data: Dictionary of JABS data generated from convert_labels
        filename: Path of the file to create (overwritten if present)
    """
    # A v4 file is a v3 file plus extra identity fields
    write_jabs_v3(data, filename)

    with h5py.File(filename, "a") as out_file:
        group = out_file.require_group("poseest")
        group.attrs.update({"version": [4, 0]})

        # A pose is masked when none of its keypoints has any confidence
        id_mask = np.all(data["confidence"] == 0, axis=-1).astype(bool)
        # There is no identity embedding data, so zero-width placeholders are stored.
        # Anything that relies on the identity embeddings will not work on this file.
        id_embeds = np.zeros((*id_mask.shape, 0), dtype=np.float32)
        id_centers = np.zeros(id_embeds.shape[1:], dtype=np.float32)

        fields = (
            ("id_mask", id_mask),
            ("identity_embeds", id_embeds),
            ("instance_id_center", id_centers),
            # v4 uses an id field that is 1-indexed
            ("instance_embed_id", data["identity"]),
        )
        for field, array in fields:
            group.require_dataset(field, array.shape, array.dtype, data=array)

write_jabs_v5(data, filename)

Write JABS pose file v5 data to file.

Writes multi-mouse pose, longterm identity, and static object data.

Parameters:

Name Type Description Default
data dict

Dictionary of JABS data generated from convert_labels

required
filename str

Filename to write data to

required
Source code in sleap_io/io/jabs.py
def write_jabs_v5(data: dict, filename: str):
    """Save pose, longterm identity, and static object data as a JABS v5 pose file.

    Args:
        data: Dictionary of JABS data generated from convert_labels
        filename: Path of the file to create (overwritten if present)
    """
    # A v5 file is a v4 file plus an optional group of static objects
    write_jabs_v4(data, filename)

    with h5py.File(filename, "a") as out_file:
        out_file.require_group("poseest").attrs.update({"version": [5, 0]})
        if "static_objects" not in data:
            return
        objects_group = out_file.require_group("static_objects")
        for name, keypoints in data["static_objects"].items():
            # Static object keypoints are stored as uint16
            stored = keypoints.astype(np.uint16)
            objects_group.require_dataset(
                name, keypoints.shape, np.uint16, data=stored
            )

write_labels(labels, pose_version, root_folder)

Convert and save a SLEAP Labels object to a JABS pose file.

Only supports pose version 2 (single mouse) and 3-5 (multi mouse).

Parameters:

Name Type Description Default
labels Labels

SLEAP Labels to be converted to JABS pose format.

required
pose_version int

JABS pose version to use when writing data.

required
root_folder str

Root folder where the jabs files should be written

required
Source code in sleap_io/io/jabs.py
def write_labels(labels: Labels, pose_version: int, root_folder: str):
    """Convert and save a SLEAP `Labels` object to a JABS pose file.

    One pose file is written per video in `labels`, named after the video file with
    its extension replaced by `_pose_est_v<pose_version>.h5`. Files that already
    exist are skipped with a warning.

    Only supports pose version 2 (single mouse) and 3-5 (multi mouse).

    Args:
        labels: SLEAP `Labels` to be converted to JABS pose format.
        pose_version: JABS pose version to use when writing data.
        root_folder: Root folder where the jabs files should be written. When empty,
            files are written relative to the video filename.

    Raises:
        NotImplementedError: If `pose_version` is not one of 2, 3, 4 or 5.
    """
    writers = {
        2: write_jabs_v2,
        3: write_jabs_v3,
        4: write_jabs_v4,
        5: write_jabs_v5,
    }
    for video in labels.videos:
        out_filename = (
            os.path.splitext(video.filename)[0] + f"_pose_est_v{pose_version}.h5"
        )
        if root_folder:
            out_filename = os.path.join(root_folder, out_filename)
        # Check for an existing file first so skipped videos are not converted
        if os.path.exists(out_filename):
            warnings.warn(f"Skipping {out_filename} because it already exists.")
            continue
        if pose_version not in writers:
            raise NotImplementedError(f"Pose format {pose_version} not supported.")
        # The directory part is empty for a bare filename, and os.makedirs("")
        # raises FileNotFoundError, so only create it when there is one.
        out_folder = os.path.dirname(out_filename)
        if out_folder:
            os.makedirs(out_folder, exist_ok=True)
        converted_labels = convert_labels(labels, video)
        writers[pose_version](converted_labels, out_filename)