Skip to content

Commit

Permalink
Update _points2regions.py
Browse files Browse the repository at this point in the history
Bug fix in anndata creation
  • Loading branch information
axanderssonuu committed Jun 4, 2024
1 parent 7f92fd9 commit 7571398
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions points2regions/_points2regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(self, xy:np.ndarray, labels:np.ndarray, pixel_width:float, pixel_sm
self._num_points = None
self._cluster_centers = None
self._is_clustered = False
self.inertia = None
self._extract_features(xy, labels, pixel_width, pixel_smoothing, min_num_pts_per_pixel, datasetids)


Expand Down Expand Up @@ -338,9 +339,10 @@ def _cluster(self, num_clusters:int, seed:int=42):

# Run K-Means
n_kmeans_clusters = int(1.75 * num_clusters)
kmeans = KMeans(n_clusters=n_kmeans_clusters, n_init='auto', random_state=seed, max_iter=100, batch_size=256, max_no_improvement=10, init_size=100, reassignment_ratio=0.05)
kmeans = KMeans(n_clusters=n_kmeans_clusters, n_init='auto', random_state=seed, max_iter=100, batch_size=256, max_no_improvement=10, init_size=100, reassignment_ratio=0.005)
kmeans = kmeans.fit(self.X_train)

self.inertia = kmeans.inertia_

# Merge clusters using agglomerative clustering
clusters = _merge_clusters(kmeans, num_clusters)

Expand Down Expand Up @@ -518,7 +520,7 @@ def _get_anndata(self, cluster_key_added:str='Clusters') -> Any:
# Create an adata object
import anndata
import pandas as pd

print('Creating anndata')
# Get position of bins for each group (library id)
xy_pixel = np.vstack([
r['xy_pixel'][r['passed_threshold']] for r in self._results.values()
Expand All @@ -527,12 +529,12 @@ def _get_anndata(self, cluster_key_added:str='Clusters') -> Any:
# Get labels of bins for each group (library id)
labels_pixel = np.hstack([
r['cluster_per_pixel'][r['passed_threshold']] for r in self._results.values()
])
]).astype(str)

obs = {}
obs[cluster_key_added] = labels_pixel
if len(self._results) > 1:
obs['datasetid'] = np.hstack([[id]*len(r['cluster_per_pixel']) for id, r in self._results.items()])
obs['datasetid'] = np.hstack([[id]*len(r['cluster_per_pixel'][r['passed_threshold']]) for id, r in self._results.items()])

# Multiply back features with the norm
norms = 1.0 / np.hstack([r['norms'][r['passed_threshold']] for r in self._results.values()])
Expand All @@ -550,10 +552,12 @@ def _get_anndata(self, cluster_key_added:str='Clusters') -> Any:
adata = anndata.AnnData(
X=X,
obs=obs,
obsm=dict(spatial=xy_pixel),
var=pd.DataFrame(index=self._unique_labels)
obsm=dict(spatial=xy_pixel)
)

adata.var_names = self._unique_labels
adata.obs['datasetid'] = adata.obs['datasetid'].astype('int')

adata.obs[cluster_key_added] = adata\
.obs[cluster_key_added]\
.astype('category')
Expand Down Expand Up @@ -592,6 +596,7 @@ def _get_anndata(self, cluster_key_added:str='Clusters') -> Any:

if self._datasetids is not None:
reads['datasetid'] = self._datasetids
reads['datasetid'] = reads['datasetid'].astype('int')

# Create the dataframe
reads = pd.DataFrame(reads)
Expand Down

0 comments on commit 7571398

Please sign in to comment.