Merge branch 'dev' into main
Brad Balderson committed Jun 28, 2023
2 parents 7e34de4 + b0cb409 commit 19fba01
Showing 6 changed files with 307 additions and 91 deletions.
2 changes: 1 addition & 1 deletion cytocipher/__init__.py
@@ -1,4 +1,4 @@
-"""Top-level package for stLearn."""
+"""Top-level package for Cytocipher."""
 
 __author__ = """Brad Balderson"""
 __email__ = "[email protected]"
12 changes: 11 additions & 1 deletion cytocipher/cluster/go.py
@@ -1,9 +1,19 @@
 """ Wrapper for performing the LR GO analysis.
+To get this to work, needed to make the following environment (did NOT work with python 3.8.12):
+conda create -n rpy2_env python=3.9
+conda activate rpy2_env
+pip install rpy2
 """
 
 import os
 #from ..utils.r_helpers import rpy2_setup, ro, localconverter, pandas2ri
-import neurotools.utils.r_helpers as rhs
+#import neurotools.utils.r_helpers as rhs
+#import cytocipher.utils.r_helpers as rhs
+#from ..utils import r_helpers as rhs
+import r_helpers as rhs
 
 
 def run_GO(genes, bg_genes, species, r_path, p_cutoff=0.01, q_cutoff=0.5, onts="BP"):
     """Running GO term analysis."""
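A note on what the `r_helpers` shim has to do: rpy2 needs `R_HOME` set before it is imported, and pandas objects must be converted at the Python/R boundary. A minimal sketch of that pattern, with a placeholder R path and toy data (illustrative only, not the commit's actual `r_helpers` code):

```python
import os

# Assumption: path to a local R installation (run_GO() receives this as
# `r_path`); rpy2 reads R_HOME at import time, so set it before importing.
os.environ["R_HOME"] = "/usr/lib/R"  # placeholder path

import pandas as pd
import rpy2.robjects as ro
from rpy2.robjects import pandas2ri
from rpy2.robjects.conversion import localconverter

# Evaluate R code from Python; ro.r() returns an R vector.
n_genes = int(ro.r("length(c('BRCA1', 'TP53'))")[0])

# Convert a pandas DataFrame to an R data.frame inside a conversion context.
df = pd.DataFrame({"gene": ["BRCA1", "TP53"], "pval": [0.01, 0.20]})
with localconverter(ro.default_converter + pandas2ri.converter):
    r_df = ro.conversion.py2rpy(df)
```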
127 changes: 89 additions & 38 deletions cytocipher/score_and_merge/_neighbors.py
@@ -24,63 +24,79 @@ def general_neighbours(data: sc.AnnData,
     #knn_adj_matrix = data.uns[neigh_key]['connectivities'].toarray() > 0
     knn_adj_matrix = data.obsp['connectivities'].toarray() > 0
 
-    neighbours, dists, clust_dists = get_neighs_FAST(labels, label_set,
-                                                     knn_adj_matrix,
-                                                     mnn_frac_cutoff)
-    return list(neighbours), list(dists), \
+    #### Calculating the proportion of MNNs between clusters
+    label_set_dict = {label: i for i, label in enumerate(label_set)}
+    labels_as_indices = np.array([label_set_dict[label_] for label_ in labels],
+                                 dtype=np.int64)
+    label_lens = np.array(
+        [len(np.where(labels==label)[0]) for label in label_set],
+        dtype=np.int64)
+    clust_dists = get_clust_dists(labels_as_indices, len(label_set),
+                                  label_lens,
+                                  knn_adj_matrix, mnn_frac_cutoff)
+
+    #### Calculating the neighbours that are closer than mnn_frac_cutoff
+    neighbours, dists = get_neighs_FAST(label_set, mnn_frac_cutoff,
+                                        clust_dists)
+
+    return neighbours, dists, \
            pd.DataFrame(clust_dists, index=label_set, columns=label_set)
 
-#@jit(parallel=True, forceobj=True, nopython=False)
-#@njit #(parallel=True)
-#@jit(parallel=False, forceobj=True, nopython=False)
-def get_neighs_FAST(labels: np.array, label_set: np.array,
-                    knn_adj_matrix: np.ndarray,
-                    mnn_frac_cutoff: float):
-    """ Get's the neighbourhoods using method described in doc-string
-    of general_neighbours, VERY quickly.
+@njit(parallel=False)
+def get_clust_dists(labels_as_indices: np.array, n_labels: int,
+                    label_lens: np.array,
+                    knn_adj_matrix: np.ndarray, mnn_frac_cutoff: float):
+    """ Gets the distance between the clusters as the proportion of MNNs.
     """
+    clust_dists = np.zeros((n_labels, n_labels), dtype=np.float64)
 
-    ### Counting the MNNs for each cluster ###
-    clust_dists = np.zeros((len(label_set), len(label_set)), dtype=np.float64)
-    for i in prange( len(label_set) ):
-        labeli = label_set[i]
-
-        #labeli_indices = get_indices(labels, labeli)
-        labeli_indices = np.where(labels==labeli)[0]
+    #### Getting the cells which are MNNs
+    knn_indices = np.where( knn_adj_matrix )
+    for pairi in range( len(knn_indices[0]) ):
+        i, j = knn_indices[0][pairi], knn_indices[1][pairi]
 
-        labeli_knns = knn_adj_matrix[labeli_indices, :]
+        labeli, labelj = labels_as_indices[i], labels_as_indices[j]
+        clust_dists[labeli, labelj] += 1
+        clust_dists[labelj, labeli] += 1
 
-        for j in range((i + 1), len(label_set)):
-            labelj = label_set[j]
+    #### Getting totals to get proportions...
+    mnn_indices = np.where( clust_dists > 0 )
+    for mnni in prange( len(mnn_indices[0]) ):
+        i, j = mnn_indices[0][mnni], mnn_indices[1][mnni]
+        if i==j:
+            clust_dists[i, j] = 0
+            continue
 
-            #labelj_indices = get_indices(labels, labelj)
-            labelj_indices = np.where(labels == labelj)[0]
+        total = label_lens[i] + label_lens[j]
 
-            labelj_knns = knn_adj_matrix[labelj_indices, :]
+        clust_dists[i,j] = clust_dists[i,j] / total
 
-            nn_ij = labeli_knns[:, labelj_indices]
-            nn_ji = labelj_knns[:, labeli_indices].transpose()
-            mnn_bool = np.logical_and(nn_ij, nn_ji)
+    return clust_dists
 
-            #n_total = np.logical_or(labeli_bool, labelj_bool).sum()
-            n_total = len(labeli_indices) + len(labelj_indices)
-            mnn_dist = mnn_bool.sum() / n_total
+#@jit(parallel=True, forceobj=True, nopython=False)
+#@njit #(parallel=True)
+@jit(parallel=False, forceobj=True, nopython=False)
+def get_neighs_FAST(label_set: np.array, mnn_frac_cutoff: float,
+                    clust_dists: np.ndarray
+                    ):
+    """ Get's the neighbourhoods using method described in doc-string
+    of general_neighbours, VERY quickly.
+    """
 
-            clust_dists[i, j] = mnn_dist
-            clust_dists[j, i] = mnn_dist
+    neigh_bools = clust_dists > mnn_frac_cutoff
 
     ##### Now converting this into neighbourhood information....
-    neighbours = [] #List()
-    dists = [] #List()
+    neighbours = []
+    dists = []
     for i, label in enumerate(label_set):
-        neigh_bool = clust_dists[i, :] > mnn_frac_cutoff
+        #neigh_bool = clust_dists[i, :] > mnn_frac_cutoff
         #neigh_indices = get_true_indices( neigh_bool )
-        neigh_indices = np.where(neigh_bool)[0]
+        neigh_indices = neigh_bools[i, :] #np.where(neigh_bool)[0]
 
         neighbours.append( label_set[neigh_indices] )
         dists.append( clust_dists[i,:][neigh_indices] )
 
-    return neighbours, dists, clust_dists
+    return neighbours, dists
 
 ################################################################################
 # The below are old cluster neighbourhood determining functions
@@ -140,3 +156,38 @@ def all_neighbours(label_set: np.array):
         dists.append([np.nan] * (len(label_set) - 1))
 
     return neighbours, dists
+
+@njit(parallel=False)
+def get_clust_dists_OLD(labels: np.array, label_set: np.array,
+                        knn_adj_matrix: np.ndarray, mnn_frac_cutoff: float):
+    """ Gets the distance between the clusters as the proportion of MNNs.
+    """
+    clust_dists = np.zeros((len(label_set), len(label_set)), dtype=np.float64)
+    for i in prange(len(label_set)):
+        labeli = label_set[i]
+
+        #labeli_indices = get_indices(labels, labeli)
+        labeli_indices = np.where(labels == labeli)[0]
+
+        labeli_knns = knn_adj_matrix[labeli_indices, :]
+
+        for j in range((i + 1), len(label_set)):
+            labelj = label_set[j]
+
+            #labelj_indices = get_indices(labels, labelj)
+            labelj_indices = np.where(labels == labelj)[0]
+
+            labelj_knns = knn_adj_matrix[labelj_indices, :]
+
+            nn_ij = labeli_knns[:, labelj_indices]
+            nn_ji = labelj_knns[:, labeli_indices].transpose()
+            mnn_bool = np.logical_and(nn_ij, nn_ji)
+
+            # n_total = np.logical_or(labeli_bool, labelj_bool).sum()
+            n_total = len(labeli_indices) + len(labelj_indices)
+            mnn_dist = mnn_bool.sum() / n_total
+
+            clust_dists[i, j] = mnn_dist
+            clust_dists[j, i] = mnn_dist
+
+    return clust_dists
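To make the refactor concrete: both implementations define the distance between two clusters as the number of mutual nearest-neighbour (MNN) cell pairs between them divided by their summed sizes, but `get_clust_dists_OLD` slices the adjacency matrix once per cluster pair, whereas the new `get_clust_dists` sweeps the kNN edges once and accumulates counts. A NumPy-only toy sketch of the pairwise definition (hypothetical data, no numba):

```python
import numpy as np

# Toy boolean kNN adjacency (cells x cells) and cluster labels; illustrative only.
knn = np.array([[0, 1, 1, 0],
                [1, 0, 0, 1],
                [1, 0, 0, 1],
                [0, 1, 1, 0]], dtype=bool)
labels = np.array(["A", "A", "B", "B"])
label_set = np.array(["A", "B"])
mnn_frac_cutoff = 0.3

# MNN pairs are edges present in both directions of the kNN graph.
mnn = np.logical_and(knn, knn.T)

# Count MNN pairs between each pair of clusters, then divide by summed sizes,
# mirroring the proportion computed by get_clust_dists_OLD().
n = len(label_set)
clust_dists = np.zeros((n, n))
for i in range(n):
    for j in range(i + 1, n):
        cells_i = np.where(labels == label_set[i])[0]
        cells_j = np.where(labels == label_set[j])[0]
        count = mnn[np.ix_(cells_i, cells_j)].sum()
        clust_dists[i, j] = clust_dists[j, i] = count / (len(cells_i) + len(cells_j))

# Clusters whose MNN proportion exceeds the cutoff are treated as neighbours.
neighbours = [label_set[clust_dists[i] > mnn_frac_cutoff] for i in range(n)]
```

With this toy graph, clusters A and B share two MNN pairs across four cells, so their distance is 0.5 and they are called neighbours at `mnn_frac_cutoff = 0.3`.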
57 changes: 56 additions & 1 deletion cytocipher/score_and_merge/cluster_merge.py
@@ -13,6 +13,7 @@
 
 from collections import defaultdict
 from numba.typed import List
+from numba import jit
 
 from .cluster_score import giotto_page_enrich, code_enrich, coexpr_enrich, \
                            get_markers
@@ -31,7 +32,7 @@ def average(expr: pd.DataFrame, labels: np.array, label_set: np.array):
 
     return avg_data
 
-def get_merge_groups(label_pairs: list):
+def get_merge_groups_SLOW(label_pairs: list):
     """Examines the pairs to be merged, and groups them into large groups of
     of clusters to be merged. This implementation will merge cluster pairs
     if there exists a mutual cluster they are both non-significantly
@@ -82,6 +83,59 @@
 
     return merge_groups
 
+@jit(parallel=False, forceobj=True, nopython=False)
+def get_merge_groups(label_pairs: list):
+    """Examines the pairs to be merged, and groups them into large groups of
+    of clusters to be merged. This implementation will merge cluster pairs
+    if there exists a mutual cluster they are both non-significantly
+    different from. Can be mediated by filtering the pairs based on the
+    overlap of clusters they are both non-significantly different from
+    (which is performed in a separate function).
+    """
+    #### Using a syncing strategy with a dictionary.
+    clust_groups = defaultdict(set)
+    #all_match_bool = [False] * len(label_pairs)
+    for pairi, pair in enumerate(label_pairs):
+        # NOTE we only need to do it for one clust of pair,
+        # since below syncs for other clust
+        clust_groups[pair[0]] = clust_groups[pair[0]].union(pair)
+
+        # Pull in the clusts from each other clust.
+        for clust in clust_groups[pair[0]]: # Syncing across clusters.
+            clust_groups[pair[0]] = clust_groups[pair[0]].union(
+                                                          clust_groups[clust] )
+
+        # Update each other clust with this clusters clusts to merge
+        for clust in clust_groups[pair[0]]: # Syncing across clusters.
+            clust_groups[clust] = clust_groups[clust].union(
+                                                        clust_groups[pair[0]] )
+
+        # Checking to make sure they now all represent the same thing....
+        # clusts = clust_groups[pair[0]]
+        # match_bool = [False] * len(clusts)
+        # for i, clust in enumerate(clusts):
+        #     match_bool[i] = np.all(
+        #                    np.array(clust_groups[clust]) == np.array(clusts))
+        #
+        # all_match_bool[pairi] = np.all(match_bool)
+
+    # Just for testing purposes...
+    #print(np.all(all_match_bool))
+
+    # Getting the merge groups now.
+    merge_groups = [] #np.unique([tuple(group) for group in clust_groups.values()])
+    merge_groups_str = []
+    all_groups = list(clust_groups.values())
+    for group in all_groups:
+        group = list(group)
+        group.sort()
+        group_str = '_'.join(group)
+        if group_str not in merge_groups_str:
+            merge_groups.append( group )
+            merge_groups_str.append( group_str )
+
+    return merge_groups
+
 ##### Merging the clusters....
 def merge_neighbours_v2(cluster_labels: np.array,
                         label_pairs: list):
@@ -293,6 +347,7 @@ def merge_clusters_single(data: sc.AnnData, groupby: str, key_added: str,
     ps_dict = {}
     for i, labeli in enumerate(label_set):
         for j, labelj in enumerate(label_set):
+            ### Only compare mutual neighbours ###
             if labelj in neighbours[i] and labeli in neighbours[j]:
 
                 labeli_labelj_scores = label_scores[j][labels == labeli]
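The jitted `get_merge_groups` added above is, in effect, a connected-components computation over the graph whose edges are the non-significantly-different cluster pairs: each new pair first pulls in the merge sets already accumulated by its members, then pushes the combined set back out so every member stays in sync. A condensed, non-jitted sketch with hypothetical pairs:

```python
from collections import defaultdict

# Hypothetical non-significant cluster pairs, as produced upstream; any chain
# of shared clusters should end up in a single merge group.
label_pairs = [("0", "3"), ("3", "5"), ("1", "2"), ("4", "6")]

clust_groups = defaultdict(set)
for pair in label_pairs:
    clust_groups[pair[0]] = clust_groups[pair[0]].union(pair)
    # Pull in the sets already accumulated by every member...
    for clust in list(clust_groups[pair[0]]):
        clust_groups[pair[0]] |= clust_groups[clust]
    # ...then push the combined set back out so all members stay in sync.
    for clust in list(clust_groups[pair[0]]):
        clust_groups[clust] = set(clust_groups[pair[0]])

# Deduplicate: each component appears once, sorted for stable output.
merge_groups = sorted({tuple(sorted(g)) for g in clust_groups.values()})
print(merge_groups)  # [('0', '3', '5'), ('1', '2'), ('4', '6')]
```

Because each pair is fully synced before the next is processed, every member of a component ends up holding the same set, so deduplicating the dictionary's values yields one merge group per component.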