%matplotlib inline
%load_ext autoreload
%autoreload 2
%load_ext line_profiler
import random
import warnings

import numpy as np
import scanpy as sc
import torch
from line_profiler import LineProfiler
from scipy import sparse

from unicoord import scu
from unicoord.visualization import draw_loss_curves
# Global scanpy configuration for this notebook.
# Verbosity scale: 0 = errors, 1 = warnings, 2 = info, 3 = hints (most verbose used here).
sc.settings.verbosity = 3
sc.logging.print_header()
# Only the vector_friendly flag is changed; the dpi/facecolor preset is not applied.
sc.settings.set_figure_params(vector_friendly=False)
# Load the preprocessed adult-lung dataset and plot the UMAP / clustering that
# is already stored in the file, before any re-embedding.
adata = sc.read_h5ad('h5ad/Lung.Adult.pp.h5ad')
sc.pl.embedding(adata, 'X_umap', legend_loc='on data', legend_fontsize=10,
                color=['leiden', 'donor_id', 'seq_tech', 'cell_type'], ncols=2)

# Continue from the matrix stored in `.raw`. Fail with a clear message rather
# than an AttributeError on `None.to_adata()` when the file has no `.raw`.
if adata.raw is None:
    raise ValueError(
        "h5ad/Lung.Adult.pp.h5ad has no .raw matrix; it is needed as the input "
        "to normalize_total/log1p below."
    )
adata = adata.raw.to_adata()

# normalize_total + log1p below expect raw counts. If `.raw` was stored after
# log-normalisation (TODO confirm how this file was produced), the data would
# be normalised twice, so warn when a sample of the values is not integer.
_raw_values = adata.X.data if sparse.issparse(adata.X) else np.asarray(adata.X).ravel()
_raw_values = _raw_values[:100_000]  # a prefix is enough to spot non-counts cheaply
if _raw_values.size and not np.allclose(_raw_values, np.round(_raw_values)):
    warnings.warn(
        "adata.raw.X contains non-integer values; it may already be "
        "normalised, in which case normalize_total/log1p would be applied twice."
    )
del _raw_values

sc.pp.normalize_total(adata, target_sum=1e4, exclude_highly_expressed=True)
sc.pp.log1p(adata)

# Bare expressions are discarded in a script and only the last one in a
# notebook cell is rendered, so print the summaries explicitly.
print(adata)
print(adata.obs.donor_id.value_counts())
print(adata.obs.seq_tech.value_counts())
def _fit_unicoord_and_embed(adata, *, n_cont, obs_fitting, epochs=100,
                            resolution=0.5, loss_fig_path=None,
                            color=('leiden', 'donor_id', 'seq_tech', 'cell_type')):
    """Fit a UniCoord model on ``adata``, re-cluster, and plot the result.

    Steps: build the model, train it, plot the training-loss curves read from
    ``adata.uns['unc_stuffs']['loss']``, write the embedding into ``adata``,
    rebuild the kNN graph on ``use_rep='unicoord'``, run Leiden and UMAP, and
    plot the UMAP coloured by ``color``.

    Leiden and UMAP write to their default keys (``obs['leiden']``,
    ``obsm['X_umap']``), so each call replaces the previous call's results.
    The UniCoord outputs are presumably overwritten the same way — confirm
    against ``scu``.

    Args:
        adata: AnnData object, modified in place.
        n_cont: Passed unchanged to ``scu.model_unicoord_in_adata``.
        obs_fitting: ``obs`` column names passed as ``obs_fitting`` to
            ``scu.model_unicoord_in_adata``.
        epochs: Training epochs for ``scu.train_unicoord_in_adata``.
        resolution: Leiden clustering resolution.
        loss_fig_path: If given, the loss-curve figure is also saved here.
        color: ``obs`` keys to colour the final UMAP by.

    Returns:
        The loss-curve figure returned by ``draw_loss_curves``.
    """
    scu.model_unicoord_in_adata(adata, n_cont=n_cont, n_diff=0, n_clus=[],
                                obs_fitting=list(obs_fitting))
    scu.train_unicoord_in_adata(adata, epochs=epochs, chunk_size=20000, slot="cur")

    fig = draw_loss_curves(adata.uns['unc_stuffs']['loss'])
    if loss_fig_path is not None:
        fig.savefig(loss_fig_path)
    fig.show()

    scu.embed_unicoord_in_adata(adata, chunk_size=5000)
    sc.pp.neighbors(adata, use_rep='unicoord')
    sc.tl.leiden(adata, resolution=resolution)
    sc.tl.umap(adata)
    sc.pl.embedding(adata, 'X_umap', legend_loc='on data', legend_fontsize=10,
                    color=list(color), ncols=2)
    return fig


# Run 1: 30 continuous dims, with `seq_tech` passed as obs_fitting.
fig = _fit_unicoord_and_embed(adata, n_cont=30, obs_fitting=['seq_tech'])

# Run 2: 100 continuous dims, no obs_fitting columns. This replaces Run 1's
# Leiden clusters and UMAP in `adata`; Run 1's plots above remain for comparison.
fig = _fit_unicoord_and_embed(adata, n_cont=100, obs_fitting=[])