Spatial-omics atlas papers often include a figure showing the neighbourhood complexity of cells in each region. Using a paper published in Cell this year as an example, let's see how to compute it. The paper is "Single-cell spatial transcriptome reveals cell-type organization in the macaque cortex" (https://linkinghub.elsevier.com/retrieve/pii/S0092867423006797), and Figure 2 contains the following panel: [figure: 20231229110148]

Panel I shows the distribution of per-cell neighbourhood complexity in the different layers. To compute a cell's complexity:
1. Take the cell as the centre.
2. Draw a circle with a radius of 200 pixels.
3. Count the number of different cell types among the other cells inside that circle (the centre cell itself is excluded).

With this definition we can write the program.

# First, import the required packages
import anndata as ad
import numpy as np
import pandas as pd
import scanpy as sc
from matplotlib.ticker import FuncFormatter
import seaborn as sns
import numpy as np
from collections import Counter
import matplotlib.pyplot as plt

Here are three implementations. Each one is faster than the previous one, but also harder to understand.

# First version: the most straightforward approach, a double loop over all cells
def calculate_neighborhood_complexity(cell_coordinates, cluster_labels, radius):
    """Count, for every cell, the distinct cell types among its neighbours (brute force).

    A neighbour is any other cell whose Euclidean distance to the centre cell
    is greater than 0 and at most ``radius``. This pure-Python double loop is
    the easiest version to follow, but it is O(n^2) and therefore slow.

    Parameters
    ----------
    cell_coordinates : array-like, one coordinate row per cell
        Cell positions.
    cluster_labels : sequence, one label per cell
        Cell-type label of each cell, in the same order as ``cell_coordinates``.
        Labels must be hashable.
    radius : float
        Neighbourhood radius, in the same units as the coordinates.

    Returns
    -------
    complexities : numpy.ndarray of int
        Number of distinct neighbour cell types for each cell.
    complexity_probabilities : dict
        Maps each observed complexity value to the fraction of cells having it.
    """
    # Accept lists as well as arrays; `cell - central_cell` needs ndarrays.
    coords = np.asarray(cell_coordinates, dtype=float)
    neighborhood_complexities = []

    for central_cell in coords:
        # Fresh per-centre set so that only *distinct* neighbour types are counted.
        seen_types = set()
        for cell, cluster_label in zip(coords, cluster_labels):
            distance = np.linalg.norm(cell - central_cell)
            # distance > 0 excludes the centre cell itself.
            if 0 < distance <= radius:
                seen_types.add(cluster_label)
        neighborhood_complexities.append(len(seen_types))

    complexity_counts = Counter(neighborhood_complexities)
    total_cells = len(neighborhood_complexities)
    complexity_probabilities = {
        complexity: count / total_cells for complexity, count in complexity_counts.items()
    }
    return np.array(neighborhood_complexities, dtype=int), complexity_probabilities

# Second version: computes the distances from each cell to all cells at once with NumPy; much faster than the first
def calculate_neighborhood_complexity(cell_coordinates, cluster_labels, radius):
    """Count, for every cell, the distinct cell types among its neighbours (NumPy).

    A neighbour is any other cell whose Euclidean distance to the centre cell
    is greater than 0 and at most ``radius``. For each centre, the distances to
    all cells are computed in one vectorised call.

    Parameters
    ----------
    cell_coordinates : array-like, one coordinate row per cell
        Cell positions.
    cluster_labels : sequence, one label per cell
        Cell-type label of each cell, in the same order as ``cell_coordinates``.
    radius : float
        Neighbourhood radius, in the same units as the coordinates.

    Returns
    -------
    complexity_counts : numpy.ndarray of int
        Number of distinct neighbour cell types for each cell.
    complexity_probabilities : numpy.ndarray of float
        ``complexity_probabilities[k]`` is the fraction of cells whose
        complexity equals ``k``.
    """
    # Accept lists as well as arrays for the coordinate arithmetic below.
    coords = np.asarray(cell_coordinates, dtype=float)
    # Map arbitrary labels to small integer codes.
    _, label_indices = np.unique(cluster_labels, return_inverse=True)
    complexity_counts = np.zeros(len(coords), dtype=int)

    for i, central_cell in enumerate(coords):
        distances = np.linalg.norm(coords - central_cell, axis=1)
        # distances > 0 excludes the centre cell itself.
        mask = (distances > 0) & (distances <= radius)
        complexity_counts[i] = np.unique(label_indices[mask]).size

    complexity_probabilities = np.bincount(complexity_counts) / len(coords)
    return complexity_counts, complexity_probabilities

# Third version: builds a KD-tree over the cell coordinates once, so each radius query does not have to compute distances to every cell; much faster than the second
from scipy.spatial import cKDTree

def calculate_neighborhood_complexity(cell_coordinates, cluster_labels, radius):
    """Count, for every cell, the distinct cell types among its neighbours (KD-tree).

    A neighbour is any other cell whose Euclidean distance to the centre cell
    is greater than 0 and at most ``radius``. This matches the
    ``distances > 0`` filter of the NumPy version. A KD-tree over the
    coordinates answers the radius queries without scanning every cell.

    Parameters
    ----------
    cell_coordinates : array-like of shape (n_cells, n_dims)
        Cell positions.
    cluster_labels : sequence, one label per cell
        Cell-type label of each cell, in the same order as ``cell_coordinates``.
    radius : float
        Neighbourhood radius, in the same units as the coordinates.

    Returns
    -------
    complexity_counts : numpy.ndarray of int
        Number of distinct neighbour cell types for each cell.
    complexity_probabilities : numpy.ndarray of float
        ``complexity_probabilities[k]`` is the fraction of cells whose
        complexity equals ``k``.
    """
    coords = np.asarray(cell_coordinates, dtype=float)
    complexity_counts = np.zeros(len(coords), dtype=int)
    if len(coords) == 0:
        # Nothing to query; also avoids building a KD-tree from an empty array.
        return complexity_counts, np.zeros(0)

    # Map arbitrary labels to small integer codes.
    _, label_indices = np.unique(cluster_labels, return_inverse=True)

    # Build a KD-tree from the cell coordinates and query every cell in one
    # call. Element i lists the indices within `radius` of cell i, and that
    # list includes cell i itself.
    kdtree = cKDTree(coords)
    neighbor_lists = kdtree.query_ball_point(coords, radius)

    for i, neighbors in enumerate(neighbor_lists):
        neighbors = np.asarray(neighbors, dtype=int)
        # Drop the centre cell (and any cell at exactly the same position) so
        # that the centre's own type is not counted as a neighbour type.
        is_other = np.any(coords[neighbors] != coords[i], axis=1)
        complexity_counts[i] = np.unique(label_indices[neighbors[is_other]]).size

    complexity_probabilities = np.bincount(complexity_counts) / len(coords)
    return complexity_counts, complexity_probabilities

How do we use it, and how do we draw the figure?

# `adata` is the spatial-omics AnnData object (it must already be loaded):
#   adata.obs['region']   - region/layer of each cell ('layer1' ... 'layer4' are used below)
#   adata.obs['celltype'] - cell-type label of each cell
#   adata.obsm['spatial'] - spatial coordinates of each cell
# Measured runtime for ~20k cells: about 15 min with the first function,
# about 5 s with the second and about 1.5 s with the third.

name = ['layer1', 'layer2', 'layer3', 'layer4']

radius_list = [200, 200, 200, 200]

# Subset adata once per layer and take both coordinates and labels from that
# subset, instead of repeating the boolean mask by hand for each list.
layer_subsets = [adata[adata.obs['region'] == layer] for layer in name]
cell_coordinates_list = [subset.obsm['spatial'] for subset in layer_subsets]
cluster_labels_list = [subset.obs['celltype'].tolist() for subset in layer_subsets]

fig, ax = plt.subplots(figsize=(6, 6))

for layer, cell_coordinates, cluster_labels, radius in zip(
    name, cell_coordinates_list, cluster_labels_list, radius_list
):
    # Only the per-cell complexities are plotted; the probabilities are not used here.
    complexities, _ = calculate_neighborhood_complexity(cell_coordinates, cluster_labels, radius)
    # One cumulative KDE curve per layer.
    # NOTE(review): complexities are integers; seaborn's `ecdfplot` would show the
    # exact (step) cumulative distribution instead of a smoothed one.
    sns.kdeplot(complexities, cumulative=True, fill=False, ax=ax, label=layer)

ax.set_xlabel('Neighborhood Complexity')
ax.set_ylabel('Probability (%)')
ax.set_title('Brain Neighborhood Complexity Distribution')
ax.legend(loc='lower right')

# Display the cumulative probability (0-1) on the y-axis as a percentage (0-100).
ax.yaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{x*100:.0f}'))
plt.tight_layout()
plt.savefig('Brain_Neighborhood_Complexity_Distribution.pdf')
plt.show()

The resulting plot reproduces the figure from the Cell paper.