Source code for omnihuman.utils.io

"""
IO Module
=========

This module provides utility functions for reading and writing files.

Functions:
----------
- read_frames: Read frames from image or video file as 4D tensor (n_frames, n_channels, height, width).
- fetch_pretrained_weights: Downloads & reads specific tensors from a Hugging Face Hub repository.
"""

__all__ = ["read_frames", "fetch_pretrained_weights"]

from mimetypes import guess_type
from typing import Dict

import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torchvision.io import read_image, read_video


def read_frames(path: str) -> torch.Tensor:
    """Read frames from image or video file as 4D tensor (n_frames, n_channels, height, width).

    The file kind is inferred from the file name via ``mimetypes.guess_type``;
    the file content is not inspected for that decision.

    Args:
        path (str): Where the image or video file is located.

    Raises:
        ValueError: If the file type is neither image nor video (including
            names whose type cannot be guessed at all).

    Returns:
        torch.Tensor: Frames as 4d torch tensor of shape
            (n_frames, n_channels, height, width). An image yields a single
            frame (n_frames == 1).
    """
    # guess_type returns (None, None) for unknown names; fall back to "" so the
    # startswith checks below stay valid and the error message stays readable.
    file_type = guess_type(path)[0] or ""

    if file_type.startswith("image"):
        # read_image yields a single (n_channels, height, width) tensor;
        # prepend a frame axis so images and videos share one layout.
        return read_image(path).unsqueeze(0)

    if file_type.startswith("video"):
        # read_video defaults to channels-last ("THWC"); request "TCHW"
        # explicitly so the result matches the documented layout and the image
        # branch. Audio frames and metadata are discarded.
        frames, *_ = read_video(path, output_format="TCHW")
        return frames

    raise ValueError(f"Unsupported file type: '{file_type}' of '{path}'")
def fetch_pretrained_weights(
    repo_id: str,
    weight_name_to_file_name: Dict[str, str],
) -> Dict[str, torch.Tensor]:
    """Downloads specific files from a Hugging Face Hub repository and loads specific tensors from them.

    Each distinct shard file is resolved through ``hf_hub_download`` and opened
    only once, no matter how many of the requested weights it contains.

    Args:
        repo_id (str): Model repository name on Hugging Face Hub
            (e.g. "organization/their-awesome-model").
        weight_name_to_file_name (Dict[str, str]): mapping of layer name to the
            shard file that contains its weights (Explore the
            `model.safetensors.index.json` in the HF model repo for names)

    Returns:
        Dict[str, torch.Tensor]: mapping from layer name to its weights tensor,
            in the same key order as ``weight_name_to_file_name``. Tensors are
            loaded on the CPU.
    """
    # Group the requested weights by shard so that every file is fetched and
    # opened a single time instead of once per weight.
    weight_names_by_file = {}
    for weight_name, filename in weight_name_to_file_name.items():
        weight_names_by_file.setdefault(filename, []).append(weight_name)

    loaded = {}
    for filename, weight_names in weight_names_by_file.items():
        file_path = hf_hub_download(repo_id=repo_id, filename=filename)
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for weight_name in weight_names:
                loaded[weight_name] = f.get_tensor(weight_name)

    # Grouping reorders the weights by file; restore the caller's key order.
    return {weight_name: loaded[weight_name] for weight_name in weight_name_to_file_name}