# ========================================================-
# Title:         Seurat analysis of scRNA-seq data from cells with transgenes
# Description:   Code to generate Fig5, SuppFig15
# Author:        Monah Abou Alezz
# Date:          2025-03-06
# ========================================================-
suppressPackageStartupMessages(library(Seurat))
suppressPackageStartupMessages(library(tidyverse))
suppressPackageStartupMessages(library(reshape2))
suppressPackageStartupMessages(library(ggrepel))
suppressPackageStartupMessages(library(clusterProfiler))
suppressPackageStartupMessages(library(RColorBrewer))
## load seurat data
# Each object is read from its RDS file and its active identity is then set
# from the "RNA_snn_h.orig.ident_res.0.6" metadata column.

# 48h time point
h48_final <- readRDS("data/organoids_h48_final.rds")
Idents(h48_final) <- "RNA_snn_h.orig.ident_res.0.6"

# second time point (file is named D5)
d5_final <- readRDS("data/organoids_D5_final.rds")
Idents(d5_final) <- "RNA_snn_h.orig.ident_res.0.6"
# Relabel the 48h sample identifiers (Sample1/3/5/7) as AAV2/AAV9/Spk/UT.
# The old -> new lookup is spliced into recode_factor() with `!!!`.
h48_final@meta.data[["orig.ident"]] <- recode_factor(
  h48_final@meta.data[["orig.ident"]],
  !!!c("Sample1" = "AAV2",
       "Sample3" = "AAV9",
       "Sample5" = "Spk",
       "Sample7" = "UT")
)

# Map each res-0.6 cluster number of the 48h object to a cell-type label;
# several clusters share the same label and are merged into one level.
h48_final@meta.data[["RNA_snn_h.orig.ident_res.0.6"]] <- recode_factor(
  h48_final@meta.data[["RNA_snn_h.orig.ident_res.0.6"]],
  !!!c("0"  = "Astrocytes",
       "1"  = "Immature.Astrocytes",
       "2"  = "Undifferentiated",
       "3"  = "Neurons",
       "4"  = "Immature.Neurons",
       "5"  = "OPC",
       "6"  = "Astrocytes",
       "7"  = "Oligodendrocytes",
       "8"  = "Undifferentiated",
       "9"  = "Astrocytes",
       "10" = "Undifferentiated",
       "11" = "Undifferentiated")
)

# Fix the level order of the cell-type factor (this order drives plot
# colours and legend order further down).
h48_final$RNA_snn_h.orig.ident_res.0.6 <- factor(
  h48_final$RNA_snn_h.orig.ident_res.0.6,
  levels = c("Astrocytes", "Neurons", "Oligodendrocytes",
             "Immature.Astrocytes", "Immature.Neurons",
             "OPC", "Undifferentiated")
)
# Relabel the sample identifiers of the second time point (Sample2/4/6/8)
# as AAV2/AAV9/Spk/UT. The old -> new lookup is spliced in with `!!!`.
d5_final@meta.data[["orig.ident"]] <- recode_factor(
  d5_final@meta.data[["orig.ident"]],
  !!!c("Sample2" = "AAV2",
       "Sample4" = "AAV9",
       "Sample6" = "Spk",
       "Sample8" = "UT")
)

# Map each res-0.6 cluster number of this object to a cell-type label;
# several clusters share the same label and are merged into one level.
d5_final@meta.data[["RNA_snn_h.orig.ident_res.0.6"]] <- recode_factor(
  d5_final@meta.data[["RNA_snn_h.orig.ident_res.0.6"]],
  !!!c("0"  = "Astrocytes",
       "1"  = "Immature.Neurons",
       "2"  = "Undifferentiated",
       "3"  = "Neurons",
       "4"  = "Oligodendrocytes",
       "5"  = "Undifferentiated",
       "6"  = "Undifferentiated",
       "7"  = "Immature.Astrocytes",
       "8"  = "OPC",
       "9"  = "Immature.Neurons",
       "10" = "Astrocytes",
       "11" = "Undifferentiated",
       "12" = "Astrocytes",
       "13" = "Neurons")
)

# Fix the level order of the cell-type factor (same order as the 48h object).
d5_final$RNA_snn_h.orig.ident_res.0.6 <- factor(
  d5_final$RNA_snn_h.orig.ident_res.0.6,
  levels = c("Astrocytes", "Neurons", "Oligodendrocytes",
             "Immature.Astrocytes", "Immature.Neurons",
             "OPC", "Undifferentiated")
)
#' Generate `n` colours evenly spaced in hue
#'
#' Hues are spread evenly between `h[1]` and `h[2]` (degrees) and converted
#' with `hcl()` at fixed chroma 100 and luminance 65.
#'
#' @param n Number of colours to generate; a single positive number.
#' @param h Numeric vector of length 2 giving the hue range in degrees.
#' @return A character vector of hex colour codes.
ggplotColours <- function(n = 6, h = c(0, 360) + 15) {
  if (!is.numeric(n) || length(n) != 1L || is.na(n) || n <= 0) {
    stop("`n` must be a single positive number", call. = FALSE)
  }
  # If the range spans (a multiple of) a full turn, shorten it by one step so
  # the first and last hues do not coincide.
  if ((diff(h) %% 360) < 1) {
    h[2] <- h[2] - 360 / n
  }
  hcl(h = seq(h[1], h[2], length.out = n), c = 100, l = 65)
}
## umap
# UMAP of the 48h object, coloured and labelled by the res-0.6 cell-type
# column. Legend labels are given in the factor's level order; the seven
# colours come from ggplotColours().
DimPlot(object = h48_final,
        reduction = "umap.harmony.orig.ident",
        group.by = "RNA_snn_h.orig.ident_res.0.6",
        label = TRUE,
        repel = TRUE,
        pt.size = 0.2,
        label.size = 6) +
  scale_color_manual(labels = c("Astrocytes",
                                "Neurons",
                                "Oligodendrocytes",
                                "Immature Astrocytes",
                                "Immature Neurons",
                                "OPC",
                                "Undifferentiated"),
                     values = ggplotColours(n = 7))

## donut plot
# Cell-type composition of the 48h object drawn as a donut chart: one ring
# segment per cell type, sized by its share of all cells.
ggplot_object_clusters_labels <- melt(table(h48_final$RNA_snn_h.orig.ident_res.0.6))
colnames(ggplot_object_clusters_labels) <- c("Cell_type", "Count")
totalB <- ggplot_object_clusters_labels %>%
  mutate(
    percentage = Count / sum(Count) * 100,
    fraction = Count / sum(Count),
    # segment boundaries on the [0, 1] ring and the mid-point used for labels
    ymax = cumsum(fraction),
    ymin = c(0, head(ymax, n = -1)),
    labelPosition = (ymax + ymin) / 2,
    # display name: the "." in a cell-type name becomes a line break
    Samples = gsub(".", "\n", as.character(Cell_type), fixed = TRUE),
    label = paste0(Cell_type, ":\n", round(percentage, 2), "%"),
    # alternative label: names that already wrap keep the value on the same
    # line, all others put the value on a new line
    label2 = paste0(Samples,
                    ifelse(grepl("\n", Samples, fixed = TRUE), ": ", ":\n"),
                    round(percentage, 2), "%")
  )
ggplot(totalB, aes(ymax = ymax, ymin = ymin, xmax = 4, xmin = 3, fill = Cell_type)) +
  geom_rect() +
  geom_label_repel(x = 3.5, aes(y = labelPosition, label = label), size = 10) +
  coord_polar(theta = "y") +
  xlim(c(2, 4)) +
  theme_void() +
  theme(legend.position = "none")
## barplot
#' Proportion of cells expressing each gene, overall or per group
#'
#' With `group.by = "all"` the proportion is computed over every cell of
#' `object` via `calc_helper()`. Otherwise `object` is split with
#' `SplitObject()` on `group.by`, the function is applied to each piece, and
#' the per-group tables are row-bound with an extra `Feature` column holding
#' the group name.
#'
#' @param object Object passed on to `calc_helper()` / `SplitObject()`.
#' @param genes Character vector of gene names.
#' @param group.by `"all"`, or the name of the grouping variable to split on.
#' @return A data frame with columns `Markers` and `Cell_proportion`, plus
#'   `Feature` when grouped.
PrctCellExpringGene <- function(object, genes, group.by = "all") {
  if (group.by == "all") {
    prct <- unlist(lapply(genes, calc_helper, object = object))
    return(data.frame(Markers = genes, Cell_proportion = prct))
  }

  object_list <- SplitObject(object, group.by)
  results <- lapply(object_list, PrctCellExpringGene, genes = genes)
  # Tag every per-group table with the name of the group it came from
  results <- Map(
    function(res, group_name) {
      res$Feature <- group_name
      res
    },
    results,
    names(object_list)
  )
  do.call("rbind", results)
}
#' Proportion of cells with a non-zero count for one gene
#'
#' Reads the `counts` slot of the `RNA` element of `object`.
#'
#' @param object Object whose `[["RNA"]]` element has a `counts` slot
#'   (genes in rows, cells in columns).
#' @param genes A single gene name.
#' @return Number of cells with count > 0 divided by the total number of
#'   cells, or `NA` when the gene is not a row of the count matrix.
calc_helper <- function(object, genes) {
  count_matrix <- object[["RNA"]]@counts
  # Unknown gene: nothing to compute
  if (!(genes %in% rownames(count_matrix))) {
    return(NA)
  }
  n_expressing <- sum(count_matrix[genes, ] > 0)
  n_expressing / ncol(count_matrix)
}

# Percentage of GFP mRNA-positive cells per cell type in the 48h AAV9 sample
aav9 <- subset(h48_final, orig.ident == "AAV9")
gfp_percentage_cell_type <- PrctCellExpringGene(
  aav9,
  genes = "GFP",
  group.by = "RNA_snn_h.orig.ident_res.0.6"
)
gfp_percentage_cell_type$Percentage <- gfp_percentage_cell_type$Cell_proportion * 100

# Keep the percentage and the group name, selected by column name rather
# than by position
gfp_percentage_cell_type <- gfp_percentage_cell_type[, c("Percentage", "Feature")]
colnames(gfp_percentage_cell_type) <- c("Percentage", "Cell_Types")

# Bar colours per cell type; the order of this vector is also the order of
# the bars on the x axis
custom_colors <- c("Immature.Neurons" = "#00B6EB",
                   "Neurons" = "#C49A00",
                   "Immature.Astrocytes" = "#00C094",
                   "Astrocytes" = "#F8766D",
                   "OPC" = "#A58AFF",
                   "Oligodendrocytes" = "#53B400",
                   "Undifferentiated" = "#FB61D7")
gfp_percentage_cell_type$Cell_Types <- factor(gfp_percentage_cell_type$Cell_Types,
                                              levels = names(custom_colors))
# Sort rows by the factor order instead of a hard-coded row permutation, so
# the table stays correct whatever order the groups were returned in and
# however many groups are present
gfp_percentage_cell_type <- gfp_percentage_cell_type[
  order(gfp_percentage_cell_type$Cell_Types), ]

ggplot(data = gfp_percentage_cell_type,
       aes(x = Cell_Types,
           y = Percentage,
           fill = Cell_Types)) +
  geom_bar(stat = "identity",
           width = 0.7, # width of bars
           position = position_dodge()) +
  scale_fill_manual(values = custom_colors) +
  ylab("% GFP mRNA positive") +
  theme_bw() +
  theme(
    text = element_text(family = "Arial"),
    axis.title.x = element_blank(),
    axis.title.y = element_text(size = 20 * 96 / 72),
    axis.text.x = element_text(size = 16 * 96 / 72, angle = 45, vjust = 0.9, hjust = 0.9),
    axis.text.y = element_text(size = 18 * 96 / 72),
    legend.position = "none",
    aspect.ratio = 1.1)
# UMAPs restricted to the AAV9 and UT samples, one panel per sample.
# The second time point is loaded in this script as `d5_final`; no object
# called `d4_final` exists, so `d5_final` is used here while the downstream
# "d4" names are kept.
# NOTE(review): confirm which day this time point really is.
h48_aav9_ut <- subset(h48_final, orig.ident %in% c("AAV9", "UT"))
d4_aav9_ut <- subset(d5_final, orig.ident %in% c("AAV9", "UT"))

# Cells are grouped by the res-0.6 column, the one this script annotates with
# the seven cell-type labels used in the legend below. The previous
# "res.1.2" column is never annotated here.
# NOTE(review): confirm res.0.6 is the intended grouping.
DimPlot(object = h48_aav9_ut,
        reduction = "umap.harmony.orig.ident",
        group.by = "RNA_snn_h.orig.ident_res.0.6",
        split.by = "orig.ident",
        label = FALSE,
        repel = TRUE,
        pt.size = 0.2,
        label.size = 6,
        ncol = 2) +
  scale_color_manual(labels = c("Astrocytes",
                                "Neurons",
                                "Oligodendrocytes",
                                "Immature Astrocytes",
                                "Immature Neurons",
                                "OPC",
                                "Undifferentiated"),
                     values = ggplotColours(n = 7))
DimPlot(object = d4_aav9_ut,
        reduction = "umap.harmony.orig.ident",
        group.by = "RNA_snn_h.orig.ident_res.0.6",
        split.by = "orig.ident",
        label = FALSE,
        repel = TRUE,
        pt.size = 0.2,
        label.size = 6,
        ncol = 2) +
  scale_color_manual(labels = c("Astrocytes",
                                "Neurons",
                                "Oligodendrocytes",
                                "Immature Astrocytes",
                                "Immature Neurons",
                                "OPC",
                                "Undifferentiated"),
                     values = ggplotColours(n = 7))
# Donut charts of cell-type composition for each time point x sample.
# Each plot is stored as `donut_<timepoint>_<sample>` and the two samples of
# a time point are then placed side by side.
samples <- c("AAV9", "UT")
tp <- c("h48", "d4")
# Explicit label -> object lookup. The second time point is loaded as
# `d5_final`; there is no `d4_final` for get() to find.
# NOTE(review): confirm which day this time point really is.
timepoint_objects <- list(h48 = h48_final, d4 = d5_final)
for (t in tp) {
  obj_t <- timepoint_objects[[t]]
  for (i in samples) {
    obj_i <- subset(obj_t, orig.ident == i)
    # Counts per annotated cell type. res-0.6 is the column this script
    # annotates; the previous "res.1.2" column is never annotated here.
    # NOTE(review): confirm res.0.6 is the intended grouping.
    ggplot_object_clusters_labels <- melt(table(obj_i$RNA_snn_h.orig.ident_res.0.6))
    colnames(ggplot_object_clusters_labels) <- c("Cell_type", "Count")
    totalB <- ggplot_object_clusters_labels %>%
      mutate(
        percentage = Count / sum(Count) * 100,
        fraction = Count / sum(Count),
        # segment boundaries on the [0, 1] ring and label mid-point
        ymax = cumsum(fraction),
        ymin = c(0, head(ymax, n = -1)),
        labelPosition = (ymax + ymin) / 2,
        # display names are derived from the data, so the code no longer
        # assumes exactly seven cell types are present in every sample
        Samples = gsub(".", "\n", as.character(Cell_type), fixed = TRUE),
        label = paste0(Cell_type, ":\n", round(percentage, 2), "%"),
        label2 = paste0(Samples,
                        ifelse(grepl("\n", Samples, fixed = TRUE), ": ", ":\n"),
                        round(percentage, 2), "%")
      )
    donut_i <- ggplot(totalB, aes(ymax = ymax, ymin = ymin, xmax = 4, xmin = 3, fill = Cell_type)) +
      geom_rect() +
      geom_label_repel(x = 3.5, aes(y = labelPosition, label = label), size = 10) +
      coord_polar(theta = "y") +
      xlim(c(2, 4)) +
      ggtitle(i) +
      theme_void() +
      theme(legend.position = "none",
            plot.title = element_text(hjust = 0.5, size = 50, family = "Arial"),
            plot.title.position = "plot")
    assign(paste0("donut_", t, "_", i), donut_i)
  }
}
# Side-by-side composition of the two samples per time point
donut_h48_final <- (donut_h48_AAV9 | donut_h48_UT)
donut_d4_final <- (donut_d4_AAV9 | donut_d4_UT)
## markers
# For each time point, build three subsets (neuronal, astrocytic and
# oligodendroglial cell types) and store them as
# `<lineage>_<timepoint>_obj_x_markers`.
# Explicit label -> object lookup. The second time point is loaded as
# `d5_final`; there is no `d4_final` for get() to find.
# NOTE(review): confirm which day this time point really is.
timepoint_objects <- list(h48 = h48_final, d4 = d5_final)
for (t in names(timepoint_objects)) {
  obj_t <- timepoint_objects[[t]]
  # Cell types are read from the res-0.6 column, the one this script
  # annotates; the previous "res.1.2" column is never annotated here.
  # NOTE(review): confirm res.0.6 is the intended grouping.
  neurons <- subset(
    obj_t,
    RNA_snn_h.orig.ident_res.0.6 %in% c("Neurons", "Immature.Neurons"))
  astro <- subset(
    obj_t,
    RNA_snn_h.orig.ident_res.0.6 %in% c("Astrocytes", "Immature.Astrocytes"))
  oligo <- subset(
    obj_t,
    RNA_snn_h.orig.ident_res.0.6 %in% c("Oligodendrocytes", "OPC"))
  assign(paste0("neurons_", t, "_obj_x_markers"), neurons)
  assign(paste0("astro_", t, "_obj_x_markers"), astro)
  assign(paste0("oligo_", t, "_obj_x_markers"), oligo)
}
# Differential expression AAV9 vs UT (Wilcoxon) for every
# `*_obj_x_markers` object. All thresholds are set to 0 so no gene is
# pre-filtered by FindMarkers(). Results are stored as
# `<object name>_aav9_vs_ut_degs`.
# The pattern is anchored at the end of the name so that, when the script is
# re-run in the same session, the `*_obj_x_markers_aav9_vs_ut_degs` results
# of a previous run are not picked up as inputs.
for (k in grep("_obj_x_markers$", ls(), value = TRUE)) {
  cells <- get(k)
  degs <- FindMarkers(object = cells,
                      ident.1 = "AAV9",
                      ident.2 = "UT",
                      group.by = "orig.ident",
                      only.pos = FALSE,
                      min.pct = 0,
                      test.use = "wilcox",
                      logfc.threshold = 0,
                      min.cells.group = 0)
  assign(paste0(k, "_aav9_vs_ut_degs"), degs)
}
## gsea
# Pre-ranked GSEA against the gene sets in the GMT file for every
# `*_aav9_vs_ut_degs` table. Genes are ranked by avg_log2FC and each result
# is stored as the data frame `GSEA_<table name>`.
h_gmt <- read.gmt("data/h.all.v7.2.symbols.gmt")
# Select DEG tables only: the name must end with the suffix and must not be a
# `GSEA_*` result left over from a previous run in the same session.
deg_names <- grep("_aav9_vs_ut_degs$", ls(), value = TRUE)
deg_names <- deg_names[!startsWith(deg_names, "GSEA_")]
for (i in deg_names) {
  gene_res <- get(i)
  # Named vector of log2 fold changes, sorted in decreasing order
  logfc_symbol <- setNames(as.vector(gene_res$avg_log2FC), row.names(gene_res))
  logfc_symbol <- sort(logfc_symbol, decreasing = TRUE)
  contr.gsea <- GSEA(logfc_symbol,
                     TERM2GENE = h_gmt,
                     nPerm = 10000,
                     pvalueCutoff = 1) # keep every gene set; filtering happens at plotting
  contr.gsea.symb <- as.data.frame(contr.gsea)
  assign(paste0("GSEA_", i), contr.gsea.symb)
}
## GSEA heatmaps
# For each lineage, gather its `GSEA_<lineage>_*` tables and draw a heatmap
# of NES (gene sets x datasets) with significance stars from the q-values.
# Each plot is stored as `p.hallmark_<lineage>`.
type <- c("astro", "neurons", "oligo")
for (p in type) {
  # Anchored pattern: only objects whose name starts with GSEA_<lineage>_
  gsea_names <- grep(paste0("^GSEA_", p, "_"), ls(), value = TRUE)
  if (length(gsea_names) == 0) {
    warning("No GSEA results found for '", p, "'; skipping heatmap", call. = FALSE)
    next
  }
  # Stack the per-dataset tables in one go, tagging rows with their dataset
  hallmark.full <- do.call(rbind, lapply(gsea_names, function(ds) {
    ds.t.filt <- get(ds)[, c("ID", "setSize", "enrichmentScore", "NES",
                             "pvalue", "p.adjust", "qvalues"), drop = FALSE]
    ds.t.filt$Dataset <- ds
    ds.t.filt
  }))
  hallmark.full.filt <- hallmark.full[, c("ID", "NES", "Dataset", "qvalues")]
  hallmark.full.filt$ID <- gsub("HALLMARK_", "", hallmark.full.filt$ID, fixed = TRUE)
  # Significance stars: * <= 0.05, ** <= 0.01, *** <= 0.001, **** <= 0.0001;
  # a missing q-value gets no star
  hallmark.full.filt$sig <- dplyr::case_when(
    is.na(hallmark.full.filt$qvalues) ~ "",
    hallmark.full.filt$qvalues <= 0.0001 ~ "****",
    hallmark.full.filt$qvalues <= 0.001 ~ "***",
    hallmark.full.filt$qvalues <= 0.01 ~ "**",
    hallmark.full.filt$qvalues <= 0.05 ~ "*",
    TRUE ~ ""
  )
  # Order gene sets on the y axis by their summed NES across datasets
  hallmark.order <- hallmark.full.filt %>%
    group_by(ID) %>%
    summarise(Pos = sum(NES))
  hallmark.order.terms <- hallmark.order[order(hallmark.order$Pos, decreasing = FALSE), "ID", drop = FALSE]
  # Cast wide and melt back so every gene set x dataset combination has a
  # row (NES is NA where a gene set is absent from a dataset)
  hallmark.full.filt.tt <- reshape2::melt(reshape2::dcast(hallmark.full.filt,
                                                          ID ~ Dataset,
                                                          value.var = "NES"), id.vars = c("ID"))
  colnames(hallmark.full.filt.tt) <- c("ID", "Dataset", "NES")
  hallmark.full.filt.tt$ID <- factor(hallmark.full.filt.tt$ID, levels = hallmark.order.terms$ID)
  hallmark.full.filt.tt <- merge(hallmark.full.filt.tt,
                                 hallmark.full.filt[, c("ID", "Dataset", "sig")],
                                 by = c("ID", "Dataset"),
                                 all.x = TRUE)
  # Combinations without a GSEA row have no star; without this the tile
  # would show the literal text "NA"
  hallmark.full.filt.tt$sig[is.na(hallmark.full.filt.tt$sig)] <- ""
  # Axis labels keyed by dataset name (h48 -> Day2, otherwise Day4) rather
  # than by the alphabetical position of the object names
  dataset_names <- unique(as.character(hallmark.full.filt.tt$Dataset))
  labels <- setNames(
    ifelse(grepl("_h48_", dataset_names, fixed = TRUE), "Day2", "Day4"),
    dataset_names
  )
  p.hallmark <- ggplot(hallmark.full.filt.tt,
                       aes(x = Dataset, y = ID)) +
    geom_tile(aes(fill = NES), colour = "white") +
    geom_text(aes(label = sig), size = 4 * 96 / 72, fontface = "bold") +
    scale_fill_gradientn(colours = colorRampPalette(rev(brewer.pal(11, "RdBu")))(100),
                         limits = c(-3.5, 3.5),
                         na.value = "white") +
    ylab("") +
    xlab("") +
    coord_fixed(ratio = 0.6) +
    scale_x_discrete(labels = labels) +
    theme_bw(base_size = 12) +
    theme(axis.text.x = element_text(angle = 45, hjust = 1, face = "bold", size = 18 * 96 / 72),
          axis.text.y = element_text(face = "bold", size = 12 * 96 / 72),
          legend.text = element_text(face = "bold", size = 14 * 96 / 72),
          legend.title = element_text(face = "bold", size = 12 * 96 / 72))
  assign(paste0("p.hallmark_", p), p.hallmark)
}