diff --git a/notebooks/Parametric_UMAP/08.0-torch-parametric-umap.ipynb b/notebooks/Parametric_UMAP/08.0-torch-parametric-umap.ipynb new file mode 100644 index 00000000..58fecf2f --- /dev/null +++ b/notebooks/Parametric_UMAP/08.0-torch-parametric-umap.ipynb @@ -0,0 +1,334 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Some examples of using the torch version of Parametric UMAP\n", + "#\n", + "# Borrows ideas/code from:\n", + "# * https://github.com/lmcinnes/umap/issues/580\n", + "# * https://colab.research.google.com/drive/1CYxt0GD-Y2zPMOnJIXJWsAhr0LdqI0R6\n", + "#\n", + "# Uses the MNIST dataset to demonstrate how to use the code in two examples:\n", + "# * Vanilla parametric UMAP using the default arguments\n", + "# * Parametric Umap with a custom encoder, and a decoder for inversion. " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Load the MNIST dataset (requires torchvision)\n", + "# Store dataset in variable X, where X[i] is\n", + "# the flattened pixel values for digit i (784=28^2 components).\n", + "\n", + "import torch\n", + "import torchvision\n", + "from torchvision.transforms import transforms\n", + "\n", + "download_dir = '/tmp'\n", + "\n", + "train_dataset = torchvision.datasets.MNIST(root=f'{download_dir}/.data',\n", + " train=True,\n", + " transform=transforms.ToTensor(),\n", + " download=True)\n", + "train_label = torch.tensor([example[1] for example in train_dataset])\n", + "train_tensor = torch.stack([example[0] for example in train_dataset])[:, 0][:, None, ...]\n", + "labels = [str(example[1]) for example in train_dataset]\n", + "X = train_tensor\n", + "X = X.reshape(-1,28*28)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 
0.061779022216796875, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "", + "rate": null, + "total": 7103, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": "8dd3148a55c64684a0d4f490f9e38a1b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/7103 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Visualise the results (requires matplotlib)\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "with plt.style.context('seaborn-paper'):\n", + " fig,ax = plt.subplots(figsize=(8,8))\n", + " int_labels = [int(y) for y in labels]\n", + " scatter = ax.scatter(*U.T,\n", + " c=int_labels,\n", + " s=0.5,\n", + " cmap=plt.cm.tab10)\n", + " legend = ax.legend(*scatter.legend_elements(),\n", + " loc=\"lower left\", title=\"Classes\")\n", + " ax.add_artist(legend)\n", + " plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# We can define our own Encoder, in this example we use\n", + "# a CNN to encode the MNIST digits.\n", + "\n", + "# We can also define a decoder to perform inversion and map from the 2D\n", + "# UMAP projection back to a digit. 
To do this, we use a simple CNN as a decoder.\n", + "\n", + "import torch.nn as nn\n", + "\n", + "class ConvEncoder(nn.Module):\n", + " '''Simple CNN to encode 28*28 images into 2d vectors'''\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " self.encoder = nn.Sequential(\n", + " nn.Conv2d(\n", + " in_channels=1, out_channels=64, kernel_size=3, stride=2, padding=0,\n", + " ),\n", + " nn.PReLU(),\n", + " nn.Conv2d(\n", + " in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=0,\n", + " ),\n", + " nn.Flatten(),\n", + " nn.Linear(128*6*6, 128),\n", + " nn.PReLU(),\n", + " nn.Linear(128, 64),\n", + " nn.PReLU(),\n", + " nn.Linear(64, 2)\n", + " )\n", + " def forward(self, X):\n", + " return self.encoder(X)\n", + " \n", + "class ConvDecoder(nn.Module):\n", + " ''' Simple CNN to decode 2d vectors into 28*28 images'''\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " self.decoder = nn.Sequential(\n", + " nn.Linear(2, 32),\n", + " nn.PReLU(),\n", + " nn.Linear(32, 128),\n", + " nn.PReLU(),\n", + " nn.Linear(128, 64*7*7),\n", + " nn.Unflatten(dim=-1,unflattened_size=(64,7,7)),\n", + " nn.ConvTranspose2d(64, 32, \n", + " kernel_size=3, \n", + " stride=2, \n", + " padding=1, \n", + " output_padding=1),\n", + " nn.PReLU(),\n", + " nn.ConvTranspose2d(32, 1, \n", + " kernel_size=3, \n", + " stride=2, \n", + " padding=1, \n", + " output_padding=1),\n", + " nn.Sigmoid()\n", + " )\n", + " def forward(self, X):\n", + " return self.decoder(X)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "application/json": { + "ascii": false, + "bar_format": null, + "colour": null, + "elapsed": 0.031854867935180664, + "initial": 0, + "n": 0, + "ncols": null, + "nrows": null, + "postfix": null, + "prefix": "", + "rate": null, + "total": 7103, + "unit": "it", + "unit_divisor": 1000, + "unit_scale": false + }, + "application/vnd.jupyter.widget-view+json": { + "model_id": 
"647e954df93e4de9ac768be8915204a5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/7103 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Visualise the results (requires matplotlib)\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "with plt.style.context('seaborn-paper'):\n", + " fig,ax = plt.subplots(figsize=(8,8))\n", + " int_labels = [int(y) for y in labels]\n", + " scatter = ax.scatter(*U.T,\n", + " c=int_labels,\n", + " s=0.5,\n", + " cmap=plt.cm.tab10)\n", + " legend = ax.legend(*scatter.legend_elements(),\n", + " loc=\"lower left\", title=\"Classes\")\n", + " ax.add_artist(legend)\n", + " plt.show()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# See how our decoder performs by inverting the encoding of training samples\n", + "from matplotlib import pyplot as plt \n", + "import numpy as np \n", + "\n", + "V = pumap.inverse_transform(U)\n", + "fig,axs = plt.subplots(5,2,figsize=(5,15))\n", + "axs=axs.flatten()\n", + "for i in range(5):\n", + " j = np.random.randint(len(X))\n", + " axs[2*i].imshow(X[j].reshape((28,28)))\n", + " axs[2*i+1].imshow(V[j].reshape((28,28)))\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pytorch", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/setup.py b/setup.py index 9e2a9354..7bba2acb 100644 --- a/setup.py +++ b/setup.py @@ -64,6 +64,7 @@ def readme(): "scikit-image", ], "parametric_umap": ["tensorflow >= 2.1"], + "torch_umap": ["torch"], "tbb": ["tbb >= 2019.0"], }, "ext_modules": [], diff --git a/umap/torch.py b/umap/torch.py new file mode 100644 index 00000000..1211a7b7 --- /dev/null +++ b/umap/torch.py @@ -0,0 +1,487 @@ +""" +Pytorch implimentation of ParametricUMAP +Borrows ideas/code from: + * https://github.com/lmcinnes/umap/issues/580 + * https://colab.research.google.com/drive/1CYxt0GD-Y2zPMOnJIXJWsAhr0LdqI0R6 +""" + +import numpy as np +from pynndescent import NNDescent +from sklearn.utils import check_random_state +from tqdm.auto import tqdm as tq +from warnings import warn + +from .umap_ import fuzzy_simplicial_set, find_ab_params + +try: + import torch + from torch.utils.data import Dataset, DataLoader + +except 
# torch is an optional dependency of umap: fail loudly, with install
# instructions, when it is missing.
try:
    import torch
    from torch.utils.data import Dataset, DataLoader

except ImportError:
    warn(
        """The umap.torch package requires PyTorch to be installed.
    You can install PyTorch at https://pytorch.org/


    """
    )
    raise ImportError("umap.torch requires torch") from None


def convert_distance_to_probability(distances, a=1.0, b=1.0):
    """Map embedding distances to the log of the UMAP similarity curve.

    Returns ``-log(1 + a * d ** (2 * b))`` element-wise, i.e. the logarithm
    of ``1 / (1 + a * d ** (2 * b))``. Despite the function's name the result
    is a non-positive log-value rather than a probability in ``[0, 1]``;
    ``compute_cross_entropy`` consumes it through ``logsigmoid``.

    Parameters
    ----------
    distances: torch.Tensor
        Non-negative distances between embedded points.
    a, b: float
        Curve parameters (as produced by ``find_ab_params``).

    Returns
    -------
    torch.Tensor with the same shape as ``distances``.
    """
    return -torch.log1p(a * distances ** (2 * b))


def compute_cross_entropy(
    probabilities_graph, probabilities_distance, repulsion_strength=1.0
):
    """Compute the attraction/repulsion cross entropy between edge targets
    and the values derived from embedding distances.

    ``probabilities_distance`` is passed through ``logsigmoid``; note that
    ``logsigmoid(x) - x == log(1 - sigmoid(x))``, so the two terms form a
    binary cross entropy in which ``probabilities_distance`` plays the role
    of the logit.

    Parameters
    ----------
    probabilities_graph: torch.Tensor
        Target edge membership (1 for positive edges, 0 for negative samples).
    probabilities_distance: torch.Tensor
        Output of ``convert_distance_to_probability`` for the same edges.
    repulsion_strength: float (optional, default 1.0)
        Multiplier applied to the repulsive term only.

    Returns
    -------
    (attraction_term, repellant_term, CE): tuple of torch.Tensor
        Element-wise terms; ``CE`` is their plain sum.
    """
    log_sig = torch.nn.functional.logsigmoid(probabilities_distance)

    attraction_term = -probabilities_graph * log_sig
    repellant_term = (
        -(1.0 - probabilities_graph)
        * (log_sig - probabilities_distance)
        * repulsion_strength
    )

    CE = attraction_term + repellant_term
    return attraction_term, repellant_term, CE


def umap_loss(embedding_to, embedding_from, _a, _b, batch_size, negative_sample_rate=5):
    """Mean UMAP cross-entropy loss for one batch of positive edges.

    Negative edges are produced by repeating the batch
    ``negative_sample_rate`` times and randomly permuting the ``from`` side.

    Parameters
    ----------
    embedding_to, embedding_from: torch.Tensor of shape (batch_size, n_components)
        Embeddings of the two endpoints of each positive edge.
    _a, _b: float
        Curve parameters forwarded to ``convert_distance_to_probability``.
    batch_size: int
        Number of positive edges in the batch; must equal
        ``embedding_to.shape[0]``.
    negative_sample_rate: int (optional, default 5)
        Number of negative samples drawn per positive edge.

    Returns
    -------
    torch.Tensor
        Scalar loss located on the same device as the embeddings.
    """
    # get negative samples by randomly shuffling the batch
    embedding_neg_to = embedding_to.repeat(negative_sample_rate, 1)
    repeat_neg = embedding_from.repeat(negative_sample_rate, 1)
    embedding_neg_from = repeat_neg[torch.randperm(repeat_neg.shape[0])]
    distance_embedding = torch.cat(
        (
            (embedding_to - embedding_from).norm(dim=1),
            (embedding_neg_to - embedding_neg_from).norm(dim=1),
        ),
        dim=0,
    )

    # convert distances to (log) similarities
    probabilities_distance = convert_distance_to_probability(distance_embedding, _a, _b)

    # Targets: 1 for the positive edges, 0 for the negative samples. They are
    # built directly on the embeddings' device; an unconditional ``.cuda()``
    # would make CPU-only training impossible.
    device = distance_embedding.device
    probabilities_graph = torch.cat(
        (
            torch.ones(batch_size, device=device),
            torch.zeros(batch_size * negative_sample_rate, device=device),
        ),
        dim=0,
    )

    _, _, ce_loss = compute_cross_entropy(probabilities_graph, probabilities_distance)
    return torch.mean(ce_loss)


def get_umap_graph(
    X, n_neighbors=10, metric="cosine", random_state=None, verbose=False
):
    """Build the fuzzy simplicial set (UMAP graph) of ``X``.

    Nearest neighbours are found with ``NNDescent`` on ``X`` flattened to two
    dimensions, and handed to ``fuzzy_simplicial_set``.

    Parameters
    ----------
    X: array-like of shape (n_samples, ...)
        Input data; trailing dimensions are flattened for the neighbour search.
    n_neighbors: int (optional, default 10)
        Number of neighbours per sample.
    metric: string or function (optional, default 'cosine')
        Metric used both for the neighbour search and the fuzzy set.
    random_state: None, int or RandomState (optional, default None)
        Normalised with ``sklearn.utils.check_random_state``.
    verbose: bool (optional, default False)
        Forwarded to ``NNDescent``.

    Returns
    -------
    The sparse graph returned by ``fuzzy_simplicial_set``.
    """
    # Accept None / int seeds / RandomState instances alike.
    random_state = check_random_state(random_state)
    n_samples = X.shape[0]
    # number of trees in random projection forest
    n_trees = 5 + int(round(n_samples ** 0.5 / 20.0))
    # max number of nearest neighbor iters to perform
    n_iters = max(5, int(round(np.log2(n_samples))))

    # get nearest neighbors (np.prod: np.product no longer exists in NumPy 2)
    nnd = NNDescent(
        X.reshape((len(X), np.prod(np.shape(X)[1:]))),
        n_neighbors=n_neighbors,
        metric=metric,
        n_trees=n_trees,
        n_iters=n_iters,
        max_candidates=60,
        verbose=verbose,
    )
    # get indices and distances
    knn_indices, knn_dists = nnd.neighbor_graph

    # build fuzzy_simplicial_set; only the graph itself is needed here
    umap_graph, _sigmas, _rhos = fuzzy_simplicial_set(
        X=X,
        n_neighbors=n_neighbors,
        metric=metric,
        random_state=random_state,
        knn_indices=knn_indices,
        knn_dists=knn_dists,
    )

    return umap_graph


def get_graph_elements(graph_, n_epochs):
    """Extract the edge list and per-edge sampling counts from a UMAP graph.

    Parameters
    ----------
    graph_: scipy sparse matrix
        Weighted graph; converted with ``tocoo()``.
    n_epochs: int or None
        Scale for the edge weights. When None, 50 is used for graphs with at
        most 10,000 rows and 20 otherwise.

    Returns
    -------
    (graph, epochs_per_sample, head, tail, weight, n_vertices)
        ``graph`` is the pruned COO matrix, ``epochs_per_sample`` is
        ``n_epochs * weight``, ``head``/``tail`` are the row/column indices of
        the remaining edges and ``n_vertices`` is ``graph.shape[1]``.
    """
    graph = graph_.tocoo()
    # eliminate duplicate entries by summing them together
    graph.sum_duplicates()
    # number of vertices in dataset
    n_vertices = graph.shape[1]
    # get the number of epochs based on the size of the dataset
    if n_epochs is None:
        # For smaller datasets we can use more epochs
        n_epochs = 50 if graph.shape[0] <= 10_000 else 20

    # remove elements with very low probability: such edges would be sampled
    # less than once over n_epochs
    graph.data[graph.data < (graph.data.max() / float(n_epochs))] = 0.0
    graph.eliminate_zeros()
    # get epochs per sample based upon edge probability
    epochs_per_sample = n_epochs * graph.data

    head = graph.row
    tail = graph.col
    weight = graph.data

    return graph, epochs_per_sample, head, tail, weight, n_vertices


class UMAPDataset(Dataset):
    """A dataset containing positive edges from the umap graph.

    Each edge is repeated ``int(n_epochs * weight)`` times, so that stronger
    edges are seen more often during training.

    If data is provided, returns the data vectors, otherwise returns the indices
    """

    def __init__(self, graph_, data=None, n_epochs=None):
        _, epochs_per_sample, head, tail, _, _ = get_graph_elements(graph_, n_epochs)

        # Truncate the (float) sampling counts once and reuse them for both
        # endpoints so the two index arrays stay aligned.
        repeats = epochs_per_sample.astype("int")
        self.edges_to_ix = np.repeat(head, repeats)
        self.edges_from_ix = np.repeat(tail, repeats)

        if data is not None:
            self.data = torch.Tensor(data)
        else:
            self.data = None

    def __len__(self):
        return int(self.edges_to_ix.shape[0])

    def __getitem__(self, index):
        edges_to_ix = self.edges_to_ix[index]
        edges_from_ix = self.edges_from_ix[index]

        if self.data is None:
            return edges_to_ix, edges_from_ix

        return self.data[edges_to_ix], self.data[edges_from_ix]


class Encoder(torch.nn.Module):
    """
    Default encoder for ParametricUmap class

    A three layer fully connected network:
    ``input_channels -> hidden_channels -> hidden_channels -> output_channels``
    with ``activation`` between the linear layers.
    """

    def __init__(
        self,
        input_channels,
        output_channels,
        hidden_channels=128,
        activation=torch.nn.LeakyReLU,
    ):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_channels, hidden_channels),
            activation(),
            torch.nn.Linear(hidden_channels, hidden_channels),
            activation(),
            torch.nn.Linear(hidden_channels, output_channels),
        )

    def forward(self, X):
        return self.encoder(X)


class ParametricUMAP:
    def __init__(
        self,
        n_components=2,
        n_neighbors=15,
        metric="euclidean",
        n_training_epochs=1,
        n_epochs=None,
        negative_sample_rate=5,
        lr=1e-3,
        min_dist=0.1,
        encoder=None,
        decoder=None,
        beta=1,
        batch_size=1024,
        num_workers=4,
        random_state=None,
        device=None,
        verbose=True,
    ):
        """
        Parametric UMAP implementation in PyTorch

        Parameters
        ----------
        n_components: int (optional, default 2)
            The dimension of the space to embed into. This defaults to 2 to
            provide easy visualization, but can reasonably be set to any
            integer value in the range 2 to 100.

        n_neighbors: int (optional, default 15)
            The size of local neighborhood (in terms of number of neighboring
            sample points) used for manifold approximation. Larger values
            result in more global views of the manifold, while smaller
            values result in more local data being preserved. In general
            values should be in the range 2 to 100.

        metric: string or function (optional, default 'euclidean')
            The metric to use to compute distances in high dimensional space.
            If a string is passed it must match a valid predefined metric. If
            a general metric is required a function that takes two 1d arrays and
            returns a float can be provided. For performance purposes it is
            required that this be a numba jit'd function. Valid string metrics
            include:

            * euclidean
            * manhattan
            * chebyshev
            * minkowski
            * canberra
            * braycurtis
            * mahalanobis
            * wminkowski
            * seuclidean
            * cosine
            * correlation
            * haversine
            * hamming
            * jaccard
            * dice
            * russelrao
            * kulsinski
            * ll_dirichlet
            * hellinger
            * rogerstanimoto
            * sokalmichener
            * sokalsneath
            * yule

            TODO: The torch implementation currently does not support additional
            arguments that should be passed to the metric (e.g. minkowski,
            mahalanobis etc.)

        n_training_epochs: int (optional, default 1)
            The number of training epochs to be used in optimizing the
            low dimensional embedding. Corresponds to the number of times
            we optimize over the training dataloader.

        n_epochs: int (optional, default None)
            The number of epochs used in constructing the UMAP dataset.
            The highest probability edge in the umap graph will appear in
            the train dataset n_epochs times. Edges with lower probability
            are represented proportionally. A larger value will result in
            a larger, but more accurate dataset.
            Defaults to 50 for small datasets, 20 for large.

        negative_sample_rate: int (optional, default 5)
            The number of negative samples to select per positive sample
            in the optimization process. Increasing this value will result
            in greater repulsive force being applied, greater optimization
            cost, but slightly more accuracy.

        lr: float (optional, default 1e-3)
            The learning rate for the embedding optimization.
            Passed to the torch optimizer.

        min_dist: float (optional, default 0.1)
            The effective minimum distance between embedded points. Smaller values
            will result in a more clustered/clumped embedding where nearby points
            on the manifold are drawn closer together, while larger values will
            result on a more even dispersal of points. The value should be set
            relative to the ``spread`` value, which determines the scale at which
            embedded points will be spread out (1.0 is passed to
            ``find_ab_params`` here).

        encoder: torch.nn.Module (optional, default None)
            An encoder which takes items from your data and maps them to
            vectors of size n_components. Defaults to a standard multi-layer
            encoder model (3 linear layers with LeakyReLU activation).

        decoder: torch.nn.Module (optional, default None)
            A decoder for inverting vectors of shape n_components, returning
            vectors shaped like the input data. Default is none, meaning that
            we do not train a decoder.

        beta: float (optional, default 1)
            The contribution of the decoder loss to the total loss. Total loss
            is given by umap_loss + beta * decoder_loss. Increasing/decreasing
            this will prioritise decoder loss over umap loss and vice versa.

        batch_size: int (optional, default 1024)
            Training batch size. 1024 is a sensible default for medium-large datasets.

        num_workers: int (optional, default 4)
            Number of workers used to manage the training dataloader.
            Defaults to 4, but performance may be boosted by increasing this for
            large datasets on machines with many cores.

        random_state: int or instance of RandomState (optional, default None)
            Controls the random_state which is used in creating the umap graph.
            Setting this seed does not guarantee reproducibility since it is
            not passed through to the torch modules.

        device: str, 'cpu' or 'cuda' (optional, default None)
            Controls the device on which we train the umap model. Set to 'cpu'
            for cpu training, or 'cuda' for gpu training. Default behaviour is
            to search for the active device via torch.cuda.is_available().

        verbose: bool (optional, default True)
            Controls whether we have progress bars during training.

        """
        self.n_components = n_components
        self.encoder = encoder
        self.decoder = decoder
        self.n_neighbors = n_neighbors
        self.min_dist = min_dist
        self.beta = beta
        self.metric = metric
        self.lr = lr
        self.n_training_epochs = n_training_epochs
        self.n_epochs = n_epochs
        self.negative_sample_rate = negative_sample_rate
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.random_state = random_state
        self._a, self._b = find_ab_params(1.0, self.min_dist)

        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.verbose = verbose

    def fit(self, X):
        """Train the encoder (and the decoder, when present) on ``X``.

        Parameters
        ----------
        X: numpy.ndarray or torch.Tensor
            Training data. Arrays are converted to float tensors. When no
            encoder was supplied, a default ``Encoder`` taking
            ``X.shape[-1]`` inputs is created.
        """
        if isinstance(X, np.ndarray):
            X = torch.from_numpy(X).float()

        assert isinstance(
            X, torch.Tensor
        ), "X must be a numpy.ndarray or a torch.Tensor"

        if self.encoder is None:
            self.encoder = Encoder(
                input_channels=X.shape[-1], output_channels=self.n_components
            )

        # Move encoder/decoder to correct device
        self.encoder.to(self.device)
        if self.decoder is not None:
            self.decoder.to(self.device)

        graph = get_umap_graph(
            X,
            n_neighbors=self.n_neighbors,
            metric=self.metric,
            random_state=self.random_state,
        )

        dataset = UMAPDataset(graph, data=X, n_epochs=self.n_epochs)

        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

        # Don't forget to add decoder to optimizer if it is present
        params = list(self.encoder.parameters())
        if self.decoder is not None:
            params.extend(self.decoder.parameters())
        optimizer = torch.optim.AdamW(params, lr=self.lr)

        for _ in range(self.n_training_epochs):
            # Use tqdm for nice loading bars if verbose flag set
            # otherwise run silently
            batches = enumerate(dataloader)
            if self.verbose:
                batches = tq(batches, total=len(dataloader), leave=False)

            for ib, (edges_to_exp, edges_from_exp) in batches:
                edges_to_exp = edges_to_exp.to(self.device)
                edges_from_exp = edges_from_exp.to(self.device)

                embedding_to = self.encoder(edges_to_exp)
                embedding_from = self.encoder(edges_from_exp)

                encoder_loss = umap_loss(
                    embedding_to,
                    embedding_from,
                    self._a,
                    self._b,
                    edges_to_exp.shape[0],
                    negative_sample_rate=self.negative_sample_rate,
                )

                # Built out-of-place so encoder_loss keeps its own value for
                # the progress description below.
                total_loss = encoder_loss
                if self.decoder is not None:
                    recon = self.decoder(embedding_to)
                    recon_loss = torch.nn.functional.mse_loss(recon, edges_to_exp)
                    total_loss = encoder_loss + self.beta * recon_loss

                total_loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                if self.verbose:
                    desc = f"Batch: {ib} Training loss: {total_loss.item():5.3f}"
                    if self.decoder is not None:
                        desc += f" | Umap loss: {encoder_loss.item():5.3f}"
                        desc += f" | Reconstruction loss: {recon_loss.item():5.3f}"
                    batches.set_description(desc)

    def fit_transform(self, X):
        """Fit the model on ``X`` and return the embedding of ``X``."""
        self.fit(X)
        return self.transform(X)

    @torch.no_grad()
    def transform(self, X):
        """Embed ``X`` with the encoder.

        The result is returned as a numpy array and also stored in
        ``self.embedding_``.
        """
        if isinstance(X, np.ndarray):
            X = torch.from_numpy(X).float()
        self.embedding_ = self.encoder(X.to(self.device)).detach().cpu().numpy()
        return self.embedding_

    @torch.no_grad()
    def inverse_transform(self, Z):
        """Map embedded points ``Z`` back to data space with the decoder.

        Raises an AssertionError when no decoder was supplied.
        """
        assert (
            self.decoder is not None
        ), "No inverse_transform available, decoder is None."
        if isinstance(Z, np.ndarray):
            Z = torch.from_numpy(Z).float()
        return self.decoder(Z.to(self.device)).detach().cpu().numpy()