# Niche reconstruction and spatial domain detection

Authors: Francesca Drummer, Marco Varrone

In this notebook we will cover:

1. Graph construction and analysis of spatial transcriptomics data using Squidpy
2. CellCharter
3. BANKSY

In [None]:
# Data analysis and ML imports
import pandas as pd
import matplotlib.pyplot as plt

# single-cell imports
import squidpy as sq
import scanpy as sc

from pathlib import Path
import os

import warnings
warnings.filterwarnings("ignore")

## Dataset

We will use the Xenium AD dataset from the previous notebooks here.

As a reminder, the dataset consists of 6 coronal mouse brain slices from 2 different conditions (wildtype - ctrl vs TgCRND8 - AD) across 3 timepoints. In this practical, we additionally have cell type information available in `adata.obs['cell_types']`. Please note that these annotations are not perfect. For example, quite a few cells could not be assigned to a cell type (NaN or "unknown"). These annotations are based on Leiden clustering and the marker genes reported in [this](https://pages.10xgenomics.com/rs/446-PBO-704/images/10x_LIT000210_App-Note_Xenium-In-Situ_Letter_Digital.pdf) document.

In this practical we aim to understand the differences of the mouse brain between the two conditions and across the timepoints using niches and spatial domains.

In [None]:
PATH = "/data/spatial_workshop/day3/practical_4"

In [None]:
# load adata
adata = sc.read_h5ad(Path(PATH, 'xenium_mouse_ad_annotated_rotated.h5ad'))
adata

In [None]:
# Count the cells for each combination of 'condition', 'time' and 'batch_key'
df = adata.obs[['condition', 'time', 'batch_key']]
value_counts = pd.DataFrame(df.values, columns=df.columns).value_counts()
print(value_counts)

In [None]:
# Add "Unknown" as a category
adata.obs["cell_types"] = adata.obs["cell_types"].cat.add_categories("Unknown")

# Fill NaN values with "Unknown"
adata.obs["cell_types"] = adata.obs["cell_types"].fillna("Unknown")

## 1. Cell neighborhood detection via graph construction

Spatial transcriptomics data can be represented as graphs, with cells as nodes and edges connecting spatially neighboring cells. Depending on the technology (imaging-based or sequencing-based), different assumptions can be made about the graph structure.
Here we will explore graph construction and cell neighborhood analysis on imaging-based data using squidpy's `sq.gr.spatial_neighbors` function.

*Information on graph construction for sequencing-based data can be found [here](https://squidpy.readthedocs.io/en/stable/notebooks/examples/graph/compute_spatial_neighbors.html).*

For image-based technologies we construct a graph with `coord_type=generic`, meaning that the nodes / cells will preserve their spatial location and will not be re-arranged in a spatial grid. For generic graph approaches we can choose to set a fixed radius `radius` or number of neighbors to connect to (`n_neighs`).
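To make the difference between the two options concrete, here is a minimal sketch on toy coordinates. It uses scikit-learn's `kneighbors_graph` and `radius_neighbors_graph` rather than squidpy, and the coordinates, `n_neighbors=2` and `radius=0.2` are arbitrary choices for illustration only.

```python
import numpy as np
from sklearn.neighbors import kneighbors_graph, radius_neighbors_graph

# Toy 2D coordinates: a tight group of four cells plus one isolated cell
coords = np.array([
    [0.0, 0.0],
    [0.1, 0.0],
    [0.0, 0.1],
    [0.1, 0.1],
    [2.0, 2.0],  # far away from all other cells
])

# Fixed number of neighbors: every cell gets exactly n_neighbors edges,
# so the isolated cell is still connected (through long edges)
knn_adj = kneighbors_graph(coords, n_neighbors=2, mode="connectivity", include_self=False)

# Fixed radius: only cells closer than `radius` are connected,
# so the isolated cell ends up without any edge
radius_adj = radius_neighbors_graph(coords, radius=0.2, mode="connectivity", include_self=False)

print("edges per cell (kNN):   ", np.asarray(knn_adj.sum(axis=1)).ravel())
print("edges per cell (radius):", np.asarray(radius_adj.sum(axis=1)).ravel())
```

Note that the kNN adjacency built this way is directed (row = cell, columns = its nearest neighbors), so it is not necessarily symmetric.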

Below we try both graph construction methods and give them a different `key` to add to the `adata` object.

In [None]:
if 'cell_types_colors' in adata.uns:
    del adata.uns['cell_types_colors']
    
sq.gr.spatial_neighbors(adata, n_neighs=5, coord_type="generic", key_added = 'neighs_based_spatial')
sq.pl.spatial_scatter(
    adata,
    shape=None,
    library_key = 'sample',
    color=["cell_types"],
    connectivity_key="neighs_based_spatial_connectivities",
    title=adata.obs['sample'].cat.categories,
    ncols=3,
    size = 10
)

In [None]:
sq.gr.spatial_neighbors(adata, radius=0.2, coord_type="generic", key_added = 'radius_based_spatial')
sq.pl.spatial_scatter(
    adata,
    shape=None,
    library_key = 'sample', 
    color="cell_types",
    connectivity_key="radius_based_spatial_connectivities",
    title=adata.obs['sample'].cat.categories,
    ncols=3,
    size=10,
)

Some first differences between the graphs can be observed by visual inspection: e.g., cells located further away from the main brain structures are connected in the closest-neighbor approach but unconnected in the radius-based approach. Other differences are harder to see, such as the connectivity patterns within densely connected regions. For this, Squidpy provides a number of statistics to better characterize the differences in connectivity between cells.

Using the cell type information, we can formulate some hypotheses from the data.

To perform downstream analyses, you need a graph structure (number of neighbors or radius) that connects the cells of interest. For example, if you are interested in the communication between microglia and oligodendrocytes, you need to choose a radius such that connections between these two cell types exist.

<span style="color: red;">**Task 1:** Compute the [interaction matrix](https://squidpy.readthedocs.io/en/stable/notebooks/examples/graph/compute_interaction_matrix.html) to understand differences between neighborhood and radius based graphs. Find an appropriate number of neighbors or radius that include sufficient connections between your cell types of interest.</span>

Example questions to answer could be: What is the average number of connections per cell type? Which cell types tend to cluster together?
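Conceptually, an interaction matrix counts, for every pair of cell types, how many graph edges connect a cell of the first type to a cell of the second. The sketch below shows this idea with numpy/scipy on a toy graph; `count_interactions` is a hypothetical helper written for illustration, not a squidpy function. Dividing the row sums by the number of cells per type gives the average number of connections per cell type.

```python
import numpy as np
from scipy.sparse import csr_matrix

def count_interactions(adj, labels):
    """Hypothetical helper: count edges between each pair of cell types.

    Returns the sorted cell types, a (types x types) matrix of edge counts
    (rows: source type, columns: target type) and the number of cells per type.
    """
    labels = np.asarray(labels)
    categories = np.unique(labels)
    # One-hot encode the labels: cells x cell types
    onehot = (labels[:, None] == categories[None, :]).astype(float)
    # (types x cells) @ (cells x cells) @ (cells x types) -> types x types
    counts = onehot.T @ (adj @ onehot)
    return categories, counts, onehot.sum(axis=0)

# Toy graph: a path of 4 cells A - A - B - B (symmetric adjacency)
rows = [0, 1, 1, 2, 2, 3]
cols = [1, 0, 2, 1, 3, 2]
adj = csr_matrix((np.ones(len(rows)), (rows, cols)), shape=(4, 4))

categories, counts, n_cells = count_interactions(adj, ["A", "A", "B", "B"])
avg_degree = counts.sum(axis=1) / n_cells  # average number of connections per cell type
print(categories)
print(counts)
print(avg_degree)
```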

In [None]:
sq.gr.interaction_matrix(adata, cluster_key="cell_types", connectivity_key="neighs_based_spatial")
sq.pl.interaction_matrix(adata, cluster_key="cell_types", connectivity_key="neighs_based_spatial")

In [None]:
#TODO: Interaction matrix for radius based graph

<span style="color: red;">**Task 2:** Try out some analyses from Squidpy to identify distinctions between the graphs, e.g. using the [centrality score](https://squidpy.readthedocs.io/en/stable/notebooks/examples/graph/compute_centrality_scores.html) and [neighborhood enrichment](https://squidpy.readthedocs.io/en/stable/notebooks/examples/graph/compute_nhood_enrichment.html). </span>
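Neighborhood enrichment is typically scored with a permutation test: the observed number of edges between two cell types is compared with the counts obtained after randomly shuffling the cell type labels on the same graph, which yields a z-score. The sketch below illustrates that idea with numpy; it is a simplified version written for this notebook (the helpers `edge_counts` and `nhood_zscore`, the number of permutations and the seed are our own choices), not squidpy's implementation.

```python
import numpy as np
from scipy.sparse import csr_matrix

def edge_counts(adj, codes, n_types):
    """Directed edge counts between cell types for integer-coded labels."""
    onehot = np.eye(n_types)[codes]          # cells x types
    return onehot.T @ (adj @ onehot)         # types x types

def nhood_zscore(adj, labels, n_perms=200, seed=0):
    """Hypothetical sketch: permutation z-scores of cell type co-occurrence along edges."""
    categories, codes = np.unique(np.asarray(labels), return_inverse=True)
    k = len(categories)
    observed = edge_counts(adj, codes, k)
    rng = np.random.default_rng(seed)
    # Null distribution: same graph, shuffled labels
    null = np.stack([edge_counts(adj, rng.permutation(codes), k) for _ in range(n_perms)])
    mean, std = null.mean(axis=0), null.std(axis=0)
    std[std == 0] = 1.0                      # avoid division by zero
    return categories, (observed - mean) / std

# Toy graph: two separate groups of 5 fully connected cells, one group per cell type
adj = csr_matrix(np.kron(np.eye(2), np.ones((5, 5))) - np.eye(10))
labels = ["A"] * 5 + ["B"] * 5
categories, z = nhood_zscore(adj, labels)
print(categories)
print(np.round(z, 1))
```

Positive values indicate that two cell types are neighbors more often than expected by chance, negative values that they avoid each other.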

In [None]:
sq.gr.centrality_scores(adata, cluster_key = "cell_types", connectivity_key = "neighs_based_spatial")
sq.pl.centrality_scores(adata, cluster_key = "cell_types")

In [None]:
sq.gr.nhood_enrichment(adata, cluster_key="cell_types", library_key = 'condition', connectivity_key = "radius_based_spatial")
sq.pl.nhood_enrichment(
    adata, cluster_key="cell_types", method="average", figsize=(5, 5)
) 

## 2. Spatial Domain detection with CellCharter

In [None]:
import scvi
import scanpy as sc
from pathlib import Path
import matplotlib.pyplot as plt
import squidpy as sq
import numpy as np
import cellcharter as cc
import os
import logging
logger = logging.getLogger('pytorch_lightning.utilities.rank_zero')
logger.setLevel(logging.ERROR)

In [None]:
scvi.settings.seed = 12345
scvi.settings.num_threads = 2

There may be cells with very low counts. We will filter them out.

In [None]:
sc.pp.filter_cells(adata, min_counts=3)

### 2.1 Dimensionality reduction

First, we need to run the dimensionality reduction. We will use [scVI](https://docs.scvi-tools.org/en/latest/api/reference/scvi.model.SCVI.html) for this.

Make sure that `adata.X` contains raw count data, as required by scVI. <br>
In some tutorials you will see the counts stored in an AnnData layer called `counts`. Here we use `adata.X` directly.
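A quick sanity check is to verify that the matrix only contains non-negative integers. The helper `looks_like_counts` below is a hypothetical function written for this notebook (not part of scVI); it works on both dense numpy arrays and SciPy sparse matrices, e.g. `looks_like_counts(adata.X)`.

```python
import numpy as np
from scipy import sparse

def looks_like_counts(X):
    """Hypothetical helper: True if all stored values of X are non-negative integers."""
    # For sparse matrices, only the explicitly stored values need to be checked
    values = X.data if sparse.issparse(X) else np.asarray(X)
    return bool(np.all(values >= 0) and np.all(np.mod(values, 1) == 0))

print(looks_like_counts(np.array([[0, 1], [2, 3]])))   # integer counts
print(looks_like_counts(np.array([[0.5, 1.0]])))       # e.g. normalized data
```

If this returns `False`, `adata.X` has probably been normalized or log-transformed, and the raw counts should be restored before training scVI.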

In [None]:
scvi.model.SCVI.setup_anndata(adata)

In [None]:
LOAD_MODEL = True

We set the parameters of the neural network. Here we use 1 layer and an embedding size of 10.<br>
These are very common parameters, but you can play with them to see how they affect the results.

In [None]:
if LOAD_MODEL:
    model = scvi.model.SCVI.load(os.path.join(PATH, 'scvi_model'), adata=adata)
else:
    model = scvi.model.SCVI(
        adata,
        n_layers=1,
        n_latent=10,
        use_layer_norm="both",
        use_batch_norm="none",
    )
    model.train(early_stopping=True, enable_progress_bar=True, max_epochs=10)

To make sure that the training has converged, we can plot the training history.<br>
Here it is important to focus on the (validation) reconstruction loss, which shows how well the model is able to reconstruct the original data.

In [None]:
plt.figure(figsize=(5, 5))
plt.plot(
    model.history["reconstruction_loss_train"],
    label="train",
    color="darkgreen",
    linewidth=1.25
)
plt.plot(
    model.history["reconstruction_loss_validation"],
    label="validation",
    color="firebrick",
    linewidth=1.25
)
plt.legend()
plt.title("reconstruction_loss")
plt.tight_layout()

We extract the latent representation of the data.

In [None]:
adata.obsm['X_scVI'] = model.get_latent_representation(adata).astype(np.float32)

### 2.2 Neighborhood aggregation

So far, the dimensionality reduction has ignored the spatial information. <br>
Now we will use the spatial information to perform the clustering.

First, we need to create a network where cells are connected if they are close to each other using the Delaunay triangulation.

In [None]:
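# delaunay=True is needed: with coord_type='generic', squidpy otherwise builds a kNN graph
sq.gr.spatial_neighbors(adata, library_key='sample', coord_type='generic', delaunay=True)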

In [None]:
sq.pl.spatial_scatter(
    adata, 
    shape=None, 
    library_key='sample',
    library_id=adata.obs['sample'].cat.categories[0],
    color="sample", 
    size=1, 
    figsize=(10,10),
    connectivity_key="spatial_connectivities",
    ncols=1
)

As you can see, the Delaunay triangulation generates very long edges for a few cells. <br>
The most appropriate approach would be to estimate the most biologically relevant distance between cells and use it to remove the edges longer than that distance.

A quick alternative, since those long edges are essentially outliers, is to measure the 99th percentile of the edge lengths and remove the edges longer than that distance. <br>
This process will leave some cells isolated, but that's not a problem.
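The idea of pruning a Delaunay graph at a percentile of the edge lengths can be sketched with SciPy alone. This is a simplified illustration in the spirit of `percentile=99` in `sq.gr.spatial_neighbors`; the helpers `delaunay_edges` and `prune_long_edges` are hypothetical, and we have not checked that squidpy applies exactly the same thresholding rule.

```python
import numpy as np
from scipy.spatial import Delaunay

def delaunay_edges(coords):
    """Unique undirected edges (i < j) of the Delaunay triangulation of 2D points."""
    tri = Delaunay(coords)
    edges = set()
    for simplex in tri.simplices:          # each simplex is a triangle (3 vertex indices)
        for i in range(3):
            a, b = sorted((int(simplex[i]), int(simplex[(i + 1) % 3])))
            edges.add((a, b))
    return np.array(sorted(edges))

def prune_long_edges(coords, edges, percentile=99.0):
    """Keep only edges whose length is at most the given percentile of all edge lengths."""
    lengths = np.linalg.norm(coords[edges[:, 0]] - coords[edges[:, 1]], axis=1)
    threshold = np.percentile(lengths, percentile)
    return edges[lengths <= threshold], threshold

# Toy data: 50 cells in a unit square plus one far-away outlier cell
rng = np.random.default_rng(0)
coords = np.vstack([rng.uniform(0, 1, size=(50, 2)), [[10.0, 10.0]]])
edges = delaunay_edges(coords)
kept, threshold = prune_long_edges(coords, edges, percentile=99)
print(f"{len(edges)} edges before pruning, {len(kept)} after (threshold {threshold:.3f})")
```

The longest edges, here the ones reaching the outlier, are the first to be removed.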


In [None]:
sq.gr.spatial_neighbors(adata, library_key='sample', coord_type='generic', delaunay=True, percentile=99)

In [None]:
sq.pl.spatial_scatter(
    adata, 
    shape=None, 
    library_key='sample',
    library_id=adata.obs['sample'].cat.categories[0],
    color="sample", 
    size=1, 
    figsize=(10,10),
    connectivity_key="spatial_connectivities",
    ncols=1
)

We construct the neighborhood-aggregated representation by combining every cell's features with those of its first 3 layers of neighbors.<br>
We use the scVI latent representation as features, and the new features are stored in `adata.obsm['X_cellcharter_temp']`.
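To build intuition for "layers of neighbors", here is a simplified sketch. It assumes that layer *l* contains the cells at exactly *l* hops in the spatial graph and that each layer is summarized by the mean of its features; CellCharter's actual implementation (layer definition, aggregation functions, sparse computations) may differ. `aggregate_neighbors_sketch` is a hypothetical helper and computes all pairwise hop distances, so it is only meant for small toy graphs.

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import shortest_path

def aggregate_neighbors_sketch(adj, X, n_layers=3):
    """Hypothetical sketch: concatenate each cell's features with the mean features
    of the cells at exactly 1, 2, ..., n_layers hops."""
    hops = shortest_path(adj, method="D", directed=False, unweighted=True)
    blocks = [X]                                  # layer 0: the cell itself
    for layer in range(1, n_layers + 1):
        ring = (hops == layer).astype(float)      # cells at exactly `layer` hops
        n_ring = ring.sum(axis=1, keepdims=True)
        mean = np.divide(ring @ X, n_ring, out=np.zeros_like(X, dtype=float), where=n_ring > 0)
        blocks.append(mean)                       # cells with an empty ring get zeros
    return np.hstack(blocks)

# Toy example: a path of 5 cells (0-1-2-3-4) with one feature equal to the cell index
rows, cols = [0, 1, 2, 3], [1, 2, 3, 4]
adj = csr_matrix((np.ones(4), (rows, cols)), shape=(5, 5))
X = np.arange(5, dtype=float).reshape(5, 1)
out = aggregate_neighbors_sketch(adj, X, n_layers=2)
print(out)
```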

In [None]:
cc.gr.aggregate_neighbors(adata, n_layers=3, use_rep='X_scVI', out_key='X_cellcharter_temp', sample_key='sample')

<span style="color: red;">**Task 3:** Given that `X_scVI` contains 10 features, how many features does `X_cellcharter_temp` contain?<br>
Guess the answer and then write the code to check it.
</span>

### 2.3 Clustering

Finally, let's cluster the cells and find their spatial domains.<br>
We will use 18 clusters; we will see later why this is a good choice.

In [None]:
gmm = cc.tl.Cluster(n_clusters=18, random_state=12345)
gmm.fit(adata, use_rep='X_cellcharter_temp')
adata.obs['spatial_domain_temp'] = gmm.predict(adata, use_rep='X_cellcharter_temp')

We can now plot domains and cell types side by side.

In [None]:
if 'spatial_domain_temp_colors' in adata.uns:
    del adata.uns['spatial_domain_temp_colors']

sq.pl.spatial_scatter(
    adata, 
    shape=None, 
    library_key='sample', 
    color=["spatial_domain_temp", "cell_types"], 
    size=1,
    figsize=(10,10),
    title=np.repeat(adata.obs['sample'].cat.categories, 2),
    ncols=2
)

### 2.4 Downstream analyses

The previous section showed how CellCharter can be used to obtain spatial domains.

However, to have an easier discussion and interpretation of the results, we want everyone to have the same results.

Therefore, we will use the spatial domains computed in advance using the `ClusterAutoK` model that also runs the stability analysis to find the optimal number of clusters.
The model has been computed using 3 layers of neighbors (it took 25 minutes on a GPU).

We have to:
1. compute the latent features with the pretrained scVI model, so that they match the features used to fit `ClusterAutoK`
2. aggregate the neighborhood features using the same number of layers as used by `ClusterAutoK` (3 layers of neighbors)
3. load the `ClusterAutoK` model
4. plot the stability curve
5. look for the peak(s) to identify the optimal number of clusters.
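`ClusterAutoK` runs its own stability analysis, which we do not reproduce here. As a simplified illustration of the general idea only, the sketch below assumes that "stability" is the mean pairwise adjusted Rand index between clusterings obtained with different random seeds, and uses scikit-learn's `GaussianMixture` as a stand-in clustering model. `stability_by_k` is a hypothetical helper; a number of clusters that is reproduced consistently across runs scores close to 1.

```python
from itertools import combinations

import numpy as np
from sklearn.metrics import adjusted_rand_score
from sklearn.mixture import GaussianMixture

def stability_by_k(X, k_values, n_runs=5):
    """Hypothetical sketch: mean pairwise ARI across seeds for each number of clusters."""
    scores = {}
    for k in k_values:
        labels = [
            GaussianMixture(n_components=k, random_state=seed).fit_predict(X)
            for seed in range(n_runs)
        ]
        pair_scores = [adjusted_rand_score(a, b) for a, b in combinations(labels, 2)]
        scores[k] = float(np.mean(pair_scores))
    return scores

# Toy data: three well-separated blobs of 50 points each
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [5.0, 5.0], [10.0, 0.0]])
X = np.vstack([c + rng.normal(scale=0.1, size=(50, 2)) for c in centers])
scores = stability_by_k(X, [2, 3, 4], n_runs=4)
print(scores)
```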

In [None]:
model = scvi.model.SCVI.load(os.path.join(PATH, 'scvi_model'), adata=adata)
adata.obsm['X_scVI'] = model.get_latent_representation(adata).astype(np.float32)
cc.gr.aggregate_neighbors(adata, n_layers=3, use_rep='X_scVI', out_key='X_cellcharter', sample_key='sample')

autok = cc.tl.ClusterAutoK.load(Path(PATH, 'autok_l3'))
cc.pl.autok_stability(autok)

# If this step also takes too long, you can load the precomputed domain labels from the data folder
# adata.obs['spatial_domain_18'] = pd.read_csv(Path(PATH, 'spatial_domains_cellcharter.csv'), index_col=0)['spatial_domain_18']
# adata.obs['spatial_domain_18'] = adata.obs['spatial_domain_18'].astype('category')

We don't see a clear single peak, but we can see that the highest stability is for 18 clusters.

In [None]:
adata.obs['spatial_domain_18'] = autok.predict(adata, use_rep='X_cellcharter', k=18)

In [None]:
if 'spatial_domain_18_colors' in adata.uns:
    del adata.uns['spatial_domain_18_colors']

sq.pl.spatial_scatter(
    adata, 
    shape=None, 
    library_key='sample', 
    color=["spatial_domain_18", "cell_types"], 
    size=1,
    figsize=(10,10),
    title=np.repeat(adata.obs['sample'].cat.categories, 2),
    ncols=2
)

Now that we have the spatial domains, we can look at the cell type enrichment in each domain.

`cc.gr.enrichment` measures how much more (or less) often a cell type is found in a domain than expected by chance.
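One common way to define such an enrichment is the log2 ratio between the observed number of cells of a type in a domain and the number expected if domain and cell type were independent. The sketch below implements that definition with pandas; `enrichment_log2` and the small pseudocount are hypothetical choices, and we assume, without having verified it, that this is close to what `cc.gr.enrichment` reports.

```python
import numpy as np
import pandas as pd

def enrichment_log2(domains, cell_types, pseudocount=1e-9):
    """Hypothetical sketch: log2(observed / expected) cell counts per domain and cell type."""
    observed = pd.crosstab(np.asarray(domains), np.asarray(cell_types))
    total = observed.values.sum()
    # Expected counts under independence: row total * column total / grand total
    expected = np.outer(observed.sum(axis=1), observed.sum(axis=0)) / total
    return np.log2((observed + pseudocount) / (expected + pseudocount))

# Toy example: domain A contains only type x, domain B only type y
res = enrichment_log2(["A", "A", "B", "B"], ["x", "x", "y", "y"])
print(res.round(2))
```

On the real data this could be called as `enrichment_log2(adata.obs['spatial_domain_18'], adata.obs['cell_types'])`.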

In [None]:
cc.gr.enrichment(
    adata,
    group_key='spatial_domain_18',
    label_key='cell_types',
)
cc.pl.enrichment(
    adata,
    group_key='spatial_domain_18',
    label_key='cell_types',
    dot_scale=8
)

<span style="color: red;">**Task 4:**
The [Allen Brain Atlas](https://atlas.brain-map.org/atlas?atlas=1&plate=100960076#atlas=1&plate=100960520&resolution=6.98&x=5512.001546223959&y=3967.997233072917&zoom=-2) provides images and annotation of mouse brain samples at multiple depths.<br><br>
Choose one of the six samples, go to the atlas and try to find the image that best matches the regions shown by CellCharter.</span>

<span style="color: red;">1. Which of the 132 images is the closest one?</span><br>
<span style="color: red;">2. After identifying the closest image, try to match the domains obtained with the annotated regions in the atlas.</span><br>
<span style="color: red;">3. Do you find anatomical differences between the two conditions? Are there differences in spatial domain composition? Are there domains unique to one of the two conditions (Alzheimer vs wildtype)?</span><br>
<span style="color: red;">&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- Focus in particular on comparing samples from mice of the same age.</span><br>
<span style="color: red;">&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- If you find differences between the two samples, what do you think is the cause?</span><br>


<span style="color: red;">**Task 5:** Now, save the spatial domains corresponding to 9 and 23 clusters into the `adata.obs` columns `spatial_domain_9` and `spatial_domain_23`.</span>


In [None]:
sq.pl.spatial_scatter(
    adata, 
    shape=None, 
    library_key='sample', 
    color=["spatial_domain_9", "spatial_domain_18", "spatial_domain_23"], 
    size=1,
    figsize=(10,10),
    title=np.repeat(adata.obs['sample'].cat.categories, 3),
    ncols=3
)

The three levels of clustering form a roughly hierarchical structure.<br>
Domains from the 9-cluster solution tend to split into subdomains in the 18-cluster solution, and so on.
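One way to quantify this nesting is to check, for every fine-grained domain, which fraction of its cells falls into its most common coarse domain: values close to 1 mean the fine domain is almost entirely contained in a single coarse domain. `nesting_purity` is a hypothetical helper written for illustration; once Task 5 is done it could be called as `nesting_purity(adata.obs['spatial_domain_9'], adata.obs['spatial_domain_18'])`.

```python
import numpy as np
import pandas as pd

def nesting_purity(coarse, fine):
    """Hypothetical helper: fraction of each fine cluster's cells in its dominant coarse cluster."""
    # Plain arrays avoid listing unobserved categories of categorical columns
    table = pd.crosstab(np.asarray(fine), np.asarray(coarse))
    return table.max(axis=1) / table.sum(axis=1)

# Perfectly nested toy example: fine clusters 0, 1 lie in coarse 0; fine 2, 3 in coarse 1
print(nesting_purity(coarse=[0, 0, 0, 1, 1, 1], fine=[0, 0, 1, 2, 2, 3]))
```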

<span style="color: red;">**Task 6:** Go back to the Allen Brain Atlas and look if some of the hierarchical structure is reflected in the atlas.</span>


#### Shape characterization

Now, we are going to look at the shapes of the domains we obtained.

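We first split each domain into its local components, i.e. the connected components of the spatial graph restricted to cells of the same domain.
With `min_cells=100`, components smaller than 100 cells are discarded, which is why the `component` column contains NaNs.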
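The idea can be sketched with SciPy: keep only the edges whose two endpoints share the same domain label, then compute connected components of the remaining graph. `domain_components` is a hypothetical helper, not CellCharter's implementation; it marks components smaller than `min_cells` with `-1` instead of NaN.

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

def domain_components(adj, labels, min_cells=1):
    """Hypothetical helper: label spatially connected components within each domain.

    Components with fewer than `min_cells` cells get the label -1.
    """
    coo = csr_matrix(adj).tocoo()
    labels = np.asarray(labels)
    # Keep only edges whose two endpoints belong to the same domain
    same = labels[coo.row] == labels[coo.col]
    sub = csr_matrix((coo.data[same], (coo.row[same], coo.col[same])), shape=coo.shape)
    n_comp, comp = connected_components(sub, directed=False)
    sizes = np.bincount(comp, minlength=n_comp)
    return np.where(sizes[comp] >= min_cells, comp, -1)

# Toy example: a path of 5 cells whose middle cell belongs to a different domain
rows, cols = [0, 1, 2, 3], [1, 2, 3, 4]
adj = csr_matrix((np.ones(4), (rows, cols)), shape=(5, 5))
comp = domain_components(adj, ["a", "a", "b", "a", "a"], min_cells=2)
print(comp)
```

Although cells 0-1 and 3-4 belong to the same domain, they end up in two different components because the only path between them goes through a cell of another domain.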

In [None]:
cc.gr.connected_components(adata, cluster_key='spatial_domain_18', min_cells=100)

In [None]:
# Workaround: squidpy's spatial_scatter has issues with categorical data containing NaNs, so we plot a string copy.
adata.obs['component_tmp'] = adata.obs['component'].astype('str')

In [None]:
if 'component_tmp_colors' in adata.uns:
    del adata.uns['component_tmp_colors']

sq.pl.spatial_scatter(
    adata[(adata.obs['sample'].isin(['TgCRND8_17_9', 'TgCRND8_2_5']))], 
    shape=None, 
    library_key='sample', 
    color=["component_tmp", "spatial_domain_18"], 
    size=1,
    figsize=(10,10),
    title=np.repeat(['TgCRND8_17_9', 'TgCRND8_2_5'], 2),
    ncols=2
)

Then, we draw a boundary around every component.

In [None]:
cc.tl.boundaries(adata, alpha_start=10)

In [None]:
from cellcharter_utils import plot_boundaries, plot_shape_metrics

In [None]:
plot_boundaries(adata, sample='wildtype_5_7', show_cells=True, cells_radius=10)

Finally, we compute some shape metrics that we will use to compare the domains.

In [None]:
cc.tl.curl(adata)
cc.tl.linearity(adata)

Let's look at the shape metrics of one of the cortex layers, domain 8.

In [None]:
plot_shape_metrics(adata, cluster_key='spatial_domain_18', figsize=(6,3), cluster_id=8, metrics=['curl', 'linearity'])

<span style="color: red;">**Task 7:** Look at the spatial domain 0.</span><br>
<span style="color: red;">&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- Is it going to have lower, same or higher linearity than domain 8?</span><br>
<span style="color: red;">&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- What about curl?</span>

<span style="color: red;">After you have answered, plot the shape metrics for domain 0 and check if it matches your expectations.</span>

Tip: you can pass multiple cluster ids to `plot_shape_metrics` in the form of a list to plot multiple domains at once.

<span style="color: red;">**Task 8:** Look at spatial domain 12.</span><br>
<span style="color: red;">&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- What do you expect to see in terms of curl and linearity?</span><br>
<span style="color: red;">&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- What about elongation?</span>

<span style="color: red;">After answering, compute the elongation and compare the shape metrics for domains 8 and 12.</span>

You can check [CellCharter's documentation](https://cellcharter.readthedocs.io/en/latest/tools.html) for the elongation metric.


## 3. BANKSY

In this notebook we are using the Python implementation of [BANKSY](https://github.com/prabhakarlab/Banksy_py).

Let's reload the data to restore the original version.

In [None]:
adata = sc.read_h5ad(Path(PATH, 'xenium_mouse_ad_annotated_rotated.h5ad'))

We run BANKSY on an example section; feel free to select a section that you find interesting.

In [None]:
# .copy() turns the view into an independent AnnData, so .obs can be modified safely below
adata_section = adata[(adata.obs['time'] == '5_7') & (adata.obs['condition'] == 'wildtype')].copy()
adata_section

In [None]:
adata_section.obsm['spatial'][:,0]

In [None]:
## add x and y coordinate to .obs (needed for plot later)
adata_section.obs['x'] = adata_section.obsm['spatial'][:, 0]
adata_section.obs['y'] = adata_section.obsm['spatial'][:, 1]

In [None]:
from banksy_utils.filter_utils import normalize_total

# Normalizes the AnnData object
adata_section = normalize_total(adata_section)

BANKSY unifies spatially informed cell typing and domain segmentation. In this notebook we will focus on the domain segmentation part.

![BANKSY overview (Fig. 1A)](../figures/BANKSY_fig1A.png)

The idea is to create a representation of a cell using 1) its own transcriptomic profile (purple) and 2) local microenvironment (red + light pink). The local microenvironment is represented through a pair of spatial kernels that represent the mean gene expression of the local microenvironment (red) and its gradient calculated with the azimuthal Gabor filter (AGF) (light pink). 

The relative contribution of the microenvironment is captured by $\lambda$. Smaller values of $\lambda$ decrease the influence of the cells in the microenvironment ($\lambda = 0$ reduces to non-spatial clustering). $G(r)$ is a radially symmetric Gaussian kernel that decays from a magnitude of 1 at distance $r = 0$.

The main BANKSY algorithm requires:

1. Creating a kNN graph by setting the number of spatial neighbors `num_neighbors` (`k_geom`) parameter.
2. Assigning weights to the edges of the connected spatial graph. By default, we use the `gaussian decay` option, where weights decay as a function of the distance to the index cell with $\sigma$ `= sigma`.
3. Defining whether to use the azimuthal Gabor filter kernel (`max_m = 1`) or just the mean expression (`max_m = 0`).

First, we set the required parameters.

In [None]:
coord_keys = ('x', 'y', 'spatial')

# set parameters 
plot_graph_weights = True
k_geom = 15 # number of neighbors
max_m = 1 # azimuthal transform up to m-th order
nbr_weight_decay = "scaled_gaussian" # can also be "reciprocal", "uniform" or "ranked"

### Construct the $k_{geom}$-NN graph

In [None]:
from banksy.main import median_dist_to_nearest_neighbour

# Find median distance to closest neighbours, the median distance will be `sigma`
nbrs = median_dist_to_nearest_neighbour(adata_section, key = coord_keys[2])

### Generate spatial weights from distance

Here, we generate the spatial weights using the gaussian decay function from the median distance to the k-th nearest neighbours as specified earlier.

In [None]:
from banksy.initialize_banksy import initialize_banksy

plt.style.use('default')

banksy_dict = initialize_banksy(
    adata_section,
    coord_keys,
    k_geom,
    nbr_weight_decay=nbr_weight_decay,
    max_m=max_m,
    plt_edge_hist=True,
    plt_nbr_weights=True,
    plt_agf_angles=False, # takes long time to plot
    plt_theta=True,
)

### Generate BANKSY matrix

The BANKSY matrix combines a cell's transcriptomic profile with its local microenvironment (Figure 1).

As mentioned before, $\lambda$ is a mixing parameter that controls the relative importance of a cell's own expression and its neighborhood expression. It takes values from 0 (spatial information is not used in the clustering) to 1 (maximum importance given to the neighborhood expression).

To generate the BANKSY matrix, we proceed as follows:

1. Matrix-multiply the sparse CSR weights matrix with the cell-gene matrix to get the **neighbourhood matrix**, and the **AGF matrix** if `max_m >= 1`

2. Z-score both matrices along **genes**

3. Multiply each matrix by a weighting factor derived from $\lambda$ (called lambda in the BANKSY manuscript and code)

4. Concatenate the matrices along the genes dimension in the form -> `horizontal_concat(cell_mat, nbr_mat, agf_mat)`
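The steps above can be illustrated with a simplified numpy/scikit-learn sketch. It omits the AGF (gradient) term, which corresponds to `max_m = 0`, and rests on our own assumptions: a single global Gaussian bandwidth equal to the median kNN distance, row-normalized weights, and scaling factors $\sqrt{1-\lambda}$ (own expression) and $\sqrt{\lambda}$ (neighbourhood), chosen so that squared Euclidean distances mix the two parts as $(1-\lambda)$ and $\lambda$. The exact weighting used by the `banksy` package may differ; `banksy_like_matrix` is a hypothetical helper, not the package's API.

```python
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.neighbors import NearestNeighbors

def zscore_genes(M):
    """Z-score each gene (column); constant genes are left at 0."""
    mu = M.mean(axis=0)
    sd = M.std(axis=0)
    sd[sd == 0] = 1.0
    return (M - mu) / sd

def banksy_like_matrix(X, coords, k_geom=3, lam=0.2):
    """Hypothetical sketch: own expression + Gaussian-weighted neighbourhood mean (no AGF)."""
    n = X.shape[0]
    nn = NearestNeighbors(n_neighbors=k_geom + 1).fit(coords)
    dist, idx = nn.kneighbors(coords)
    dist, idx = dist[:, 1:], idx[:, 1:]       # drop each cell itself (assumes no duplicate coordinates)
    sigma = np.median(dist)                   # assumption: one global bandwidth
    w = np.exp(-dist**2 / (2 * sigma**2))     # Gaussian decay with distance
    w /= w.sum(axis=1, keepdims=True)         # rows sum to 1 -> weighted mean
    rows = np.repeat(np.arange(n), k_geom)
    W = csr_matrix((w.ravel(), (rows, idx.ravel())), shape=(n, n))
    nbr = W @ X                               # neighbourhood (mean) expression matrix
    own_part = np.sqrt(1 - lam) * zscore_genes(X)
    nbr_part = np.sqrt(lam) * zscore_genes(nbr)
    return np.hstack([own_part, nbr_part]), W

# Toy data: 20 cells with random positions and Poisson counts for 5 genes
rng = np.random.default_rng(0)
coords = rng.uniform(0, 10, size=(20, 2))
X = rng.poisson(3, size=(20, 5)).astype(float)
B, W = banksy_like_matrix(X, coords, k_geom=3, lam=0.2)
print(B.shape)
```

Setting `lam=0` makes the neighbourhood block vanish, which matches the description of $\lambda = 0$ as non-spatial clustering.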

Here, we save all the results in the dictionary (banksy_dict), which contains the results from the subsequent operations for BANKSY.

In [None]:
from banksy.embed_banksy import generate_banksy_matrix

# The following are the main hyperparameters for BANKSY
lambda_list = [0.6] # list of lambda parameters

banksy_dict, banksy_matrix = generate_banksy_matrix(adata_section, banksy_dict, lambda_list, max_m)

In [None]:
from banksy.main import concatenate_all

banksy_dict["nonspatial"] = {
    # Here we simply append the nonspatial matrix (adata.X) to obtain the nonspatial clustering results
    0.0: {"adata": concatenate_all([adata_section.X], 0, adata=adata_section), }
}

print(banksy_dict['nonspatial'][0.0]['adata'])

### Reduce dimensions of each data matrix

We utilize two common methods for dimensionality reduction:

1. PCA (using `scikit-learn`): we reduce the size of the matrix from $3 \times N_{genes}$ to `pca_dims`. By default, we reduce to 20 dimensions.

2. UMAP (`UMAP` package), which we use to visualize the clusters in a 2-D embedding.

In [None]:
## Define hyperparameters

resolutions = [0.1] # resolution for Leiden clustering
pca_dims = [20] # number of PCA dimensions to keep

In [None]:
from banksy_utils.umap_pca import pca_umap

pca_umap(banksy_dict,
         pca_dims = pca_dims,
         add_umap = True,
         plt_remaining_var = False,
         )

### Cluster cells using a partition algorithm

We then cluster cells using the **Leiden** partition algorithm. Other clustering algorithms include *Louvain* (another resolution-based clustering algorithm) and *mclust* (clustering based on Gaussian mixture models).

In [None]:
from banksy.cluster_methods import run_Leiden_partition
seed = 0
results_df, max_num_labels = run_Leiden_partition(
    banksy_dict,
    resolutions,
    num_nn = 50,
    num_iterations = -1,
    partition_seed = seed,
    match_labels = True,
)

### Plot results

In [None]:
from banksy.plot_banksy import plot_results

c_map = 'tab20' # specify color map
weights_graph = banksy_dict['scaled_gaussian']['weights'][0]

In [None]:
banksy_path = './outputs/banksy_output/'

plot_results(
    results_df,
    weights_graph,
    c_map,
    match_labels = True,
    coord_keys = coord_keys,
    max_num_labels  =  max_num_labels, 
    save_path = os.path.join(banksy_path, 'tmp_png'),
    save_fig = True, # save the spatial map of all clusters
    save_seperate_fig = True, # save each cluster in a separate figure (parameter name is spelled this way in BANKSY)
)


### Investigate Cell Type composition in each Banksy-defined Spatial Domain

In [None]:
def plot_sd_vs_cell_type_composition(res_df, idx):
    """
    Plot the cell type composition (in %) of each spatial domain (SD) as a stacked bar plot.

    Parameters:
    - res_df: AnnData object (e.g. results_df.loc[idx]['adata']) whose `.obs` contains
      the spatial domain labels and a 'cell_types' column.
    - idx: name of the `.obs` column with the spatial domain labels,
      e.g. 'labels_scaled_gaussian_pc20_nc0.60_r0.10'.

    Returns:
    - None
    """
    # Step 1: Copy the relevant columns (so res_df.obs is not modified) and add a
    # 'Count' column to facilitate pivoting (each row contributes a count of 1)
    obs = res_df.obs[[idx, 'cell_types']].copy()
    obs['Count'] = 1

    # Step 2: Create a pivot table with SD as the index, cell types as columns, and the sum of counts as values
    pivot_df = obs.pivot_table(
        index=idx,  # Group by SD
        columns='cell_types',  # Columns represent cell types
        values='Count',  # Aggregate the 'Count' column
        aggfunc='sum',  # Sum up counts for each combination
        fill_value=0  # Fill missing combinations with 0
    )

    # Ensure SD values are numeric
    pivot_df.index = pivot_df.index.astype(float)

    # Step 3: Convert counts to percentages for each SD
    # Divide each row by the row sum to get percentages, then multiply by 100
    pivot_df = pivot_df.div(pivot_df.sum(axis=1), axis=0) * 100

    # Step 4: Set up the plot
    fig, ax = plt.subplots(figsize=(10, 6))  # Define figure size

    # Plot stacked bars
    bottom = None  # Keeps track of the cumulative height of the bars
    for cell_type in pivot_df.columns:  # Loop through each cell type
        ax.bar(
            pivot_df.index,  # X-axis: SD values
            pivot_df[cell_type],  # Y-axis: Percentages for this cell type
            label=cell_type,  # Legend label
            bottom=bottom  # Stack on top of previous bars
        )
        # Update 'bottom' to include the current cell type's values
        bottom = pivot_df[cell_type] if bottom is None else bottom + pivot_df[cell_type]

    # Step 5: Add labels and title
    ax.set_xlabel('SD')  # Label for the x-axis
    ax.set_ylabel('Cell Type Composition (%)')  # Label for the y-axis
    ax.set_title('SD vs. Cell Type Composition')  # Title of the plot
    ax.set_ylim(0, 100)  # Set y-axis limits to [0, 100] to represent percentages

    # Add legend
    plt.legend(
        title="Cell Type",  # Title of the legend
        bbox_to_anchor=(1.05, 1),  # Position the legend outside the plot
        loc='upper left'  # Align the legend at the upper left corner
    )

    # Adjust layout to prevent overlap
    plt.tight_layout()
    # Step 6: Show the plot
    plt.show()

In [None]:
results_df

In [None]:
idx = 'scaled_gaussian_pc20_nc0.60_r0.10'
results_df.loc[idx]['adata']

In [None]:
idx='scaled_gaussian_pc20_nc0.60_r0.10'
results_df_lambda06 = results_df.loc[idx]['adata']
label_idx = f'labels_{idx}'
plot_sd_vs_cell_type_composition(results_df_lambda06, label_idx)

<span style="color: red;">**Task 9:** What do you expect would happen when we increase k_geom?</span><br>

<span style="color: red;">After answering, rerun BANKSY with k_geom set to 80.</span>

<span style="color: red;">**Task 10:** What do you expect would happen when we set lambda = 0.99?</span><br>

<span style="color: red;">What do you expect the spatial cluster plot to look like?</span>

<span style="color: red;">**Task 11:** How does the cell type composition of each spatial domain change as lambda increases? Set lambda=0 and compare.</span><br>

## References

[1] Varrone, M., Tavernari, D., Santamaria-Martínez, A., Walsh, L. A. & Ciriello, G. CellCharter reveals spatial cell niches associated with tissue remodeling and cell plasticity. Nat Genet 56, 74–84 (2024).

[2] Singhal, V. et al. BANKSY unifies cell typing and tissue domain segmentation for scalable spatial omics data analysis. Nat Genet 56, 431–441 (2024).

[3] https://github.com/BrainOmicsCourse/BrainOmics2024/tree/main/3_Day3. Last access: 18.12.2024

[4] https://github.com/NBISweden/workshop-spatial/blob/main/labs/07b_spatial_domains.ipynb. Last access: 14.01.2025