# Source code for omnihuman.models.text.embeddings

"""Utilities for combining token embedding tables.

Exposes ``pool_embeddings``, which concatenates several
(vocabulary_size, embedding_dim) tables and mean-pools the result
down to a fixed number of features per token.
"""

# Public API of this module (names exported by ``from ... import *``).
__all__ = ["pool_embeddings"]

from typing import Iterable

import torch


def pool_embeddings(
    embeddings: Iterable[torch.Tensor],
    out_dim: int = 512,
    pre_norm: bool = True,
    post_norm: bool = True,
) -> torch.Tensor:
    """Combine multiple embedding tables into a single one by concatenating them and
    reducing the dimensionality of the result.

    The vocabulary size of all embeddings **must** be the same.

    Pooling works in three steps:

    1. Concatenate the embeddings across embed_dim, giving
       ``shape=(vocabulary_size, sum(e.shape[-1] for e in embeddings))``.
    2. Reshape the result into a 3D tensor of shape ``(vocabulary_size, out_dim, n)``
       with ``n = total_embed_size // out_dim``.
    3. Take the mean across the last dimension.

    The resulting tensor is therefore
    ``[[mean(0:n), mean(n:2n), ..., mean(-n:total_embed_size)], ...]``.

    Each output feature averages ``n`` *consecutive* concatenated features. Features
    from different tables are only guaranteed to stay apart when every table's
    embedding size is a multiple of ``n``. Otherwise a group may straddle the
    boundary between two tables.

    Args:
        embeddings (Iterable[torch.Tensor]): 2D tensors of shape
            (vocabulary_size, embedding_dim) to be combined. Any iterable is
            accepted, including a one-shot generator; it is consumed exactly once.
        out_dim (int, optional): Number of features per token in the output tensor.
            Must be positive. Defaults to 512.
        pre_norm (bool, optional): L2 normalize each input embedding table so that
            each token's vector has magnitude 1. All-zero vectors stay zero instead
            of becoming NaN. Defaults to True.
        post_norm (bool, optional): L2 normalize the output so that each token's
            vector has magnitude 1. This is also applied when no pooling is needed
            (``out_dim`` equals the concatenated size). All-zero vectors stay zero.
            Defaults to True.

    Raises:
        ValueError: If out_dim is not positive.
        ValueError: If any embedding table is not 2D.
        ValueError: If out_dim is larger than the sum of all embedding sizes
            (this includes an empty ``embeddings``).
        ValueError: If out_dim is not a divisor of the sum of all embedding sizes.
        ValueError: If the tables do not all share the same vocabulary size.

    Returns:
        torch.Tensor: The pooled float32 embedding table of shape
            (vocabulary_size, out_dim).
    """
    # Materialize once: the input is traversed several times below, and a one-shot
    # iterator (e.g. a generator) would otherwise be exhausted after the first pass.
    tables = list(embeddings)

    if out_dim <= 0:
        raise ValueError(f"{out_dim=} must be a positive integer")
    if bad_ndims := [e.ndim for e in tables if e.ndim != 2]:
        raise ValueError(
            f"All embedding tables must be 2D (vocabulary_size, embedding_dim). Got ndim={bad_ndims}"
        )

    stacked_size = sum(e.shape[-1] for e in tables)
    if stacked_size < out_dim:
        raise ValueError(f"{out_dim=} is larger than the concatenated embedding dimension {stacked_size}")
    if stacked_size % out_dim != 0:
        raise ValueError(f"{out_dim=} is not a divisor of the concatenated embedding dimension {stacked_size}")
    if len(vocab_sizes := {e.shape[-2] for e in tables}) != 1:
        raise ValueError(f"All embedding tables must have the same vocabulary size. Got {vocab_sizes=}")

    # Work in float32 regardless of the input dtype.
    parts = [e.to(torch.float32) for e in tables]
    if pre_norm:
        # `normalize` divides by max(norm, eps), so all-zero rows (e.g. padding
        # tokens) stay zero rather than turning into NaN through 0 / 0.
        parts = [torch.nn.functional.normalize(p, dim=-1) for p in parts]
    embedding = torch.cat(parts, dim=-1)

    group_size = stacked_size // out_dim
    if group_size > 1:
        # Average each run of `group_size` consecutive features into one output feature.
        # `torch.cat` returns a contiguous tensor, so `view` is valid here.
        embedding = embedding.view(embedding.shape[0], out_dim, group_size).mean(dim=-1)

    # Apply post-normalization on both the pooled and the no-pooling path,
    # as documented for `post_norm`.
    if post_norm:
        embedding = torch.nn.functional.normalize(embedding, dim=-1)
    return embedding