diff --git a/panpipes/funcs/plotting.py b/panpipes/funcs/plotting.py index d4c180eb..77496302 100644 --- a/panpipes/funcs/plotting.py +++ b/panpipes/funcs/plotting.py @@ -35,28 +35,53 @@ def scatter_one(group_choice, col_choice, plot_df, axs=None, colour="#1f77b4", t def batch_scatter_two_var(plot_df, method, batch, palette_choice=None): + """Plots facetted umaps, with each group within the method and batch highlighted as foreground. + Args: + plot_df (pd.DataFrame): pandas dataframe contianin umap coordinates plus method and batch columns. + method (str): method column in plot_df (this will be plotted one category per column) + batch (str): batch column in plot_df (this will be plotted one category per row) + palette_choice (list): List of colors to plot each rows foreground. Defaults to None. + Returns: + fig, ax : matplotlib subplots figure + """ plot_df = plot_df[['umap_1', 'umap_2', method, batch]] plot_df = plot_df.dropna() # n_vars = len(df[var_choice].unique()) - nrows = len(plot_df[batch].unique()) - ncols = len(plot_df[method].unique()) - fig, axs = plt.subplots(nrows, ncols, figsize=(4*ncols, 3*nrows), facecolor='w', edgecolor='k') - fig.subplots_adjust(hspace=.2, wspace=.2) - if np.max([nrows, ncols]) > 1: - axs = axs.ravel(order="F") + method_choices = plot_df[method].cat.categories.tolist() + logging.debug(method_choices) + group_choices = plot_df[batch].unique() + logging.debug(group_choices) + nrows = len(group_choices) + ncols = len(method_choices) + if nrows > 40: + logging.info("skipping facet plot as too many variables: %s" % batch) + fig = None + axs = None else: - axs=[axs] - if palette_choice is None: - palette_choice = ['#1f77b4']*ncols - # for j in range(ncols): - idx=0 - for j, method_choice in enumerate(plot_df[method].cat.categories.tolist()): - plot_df2 = plot_df[plot_df[method] == method_choice].copy() - for i in range(nrows): - group_choice = plot_df2[batch].unique()[i] - logging.debug(str(method_choice) + "|" + str(group_choice)) - scatter_one(group_choice, batch, plot_df2, axs[idx], colour=palette_choice[j], title=str(method_choice) + "|" + str(group_choice)) - idx+=1 + fig, axs = plt.subplots(nrows, ncols, figsize=(4*ncols, 3*nrows), facecolor='w', edgecolor='k') + fig.subplots_adjust(hspace=.2, wspace=.2) + if np.max([nrows, ncols]) > 1: + logging.debug('ravelling the matploltib as there are multipl columns') + axs = axs.ravel(order="F") + else: + axs=[axs] + if palette_choice is None: + palette_choice = ['#1f77b4']*nrows + logging.debug(len(palette_choice)) + # for j in range(ncols): + idx=0 + for i, method_choice in enumerate(method_choices): + plot_df2 = plot_df[plot_df[method] == method_choice].copy() + for j, group_choice in enumerate(group_choices): + logging.debug('plotting %i %s' % (j, group_choice)) + logging.debug(str(method_choice) + "|" + str(group_choice)) + scatter_one(group_choice, batch, plot_df2, axs[idx], + colour=palette_choice[i], + title=str(method_choice) + "|" + str(group_choice)) + idx+=1 + logging.info(idx) + return fig, axs + def facet_scatter(x, y, c, **kwargs): """Draw scatterplot with point colors from a faceted DataFrame columns.""" diff --git a/panpipes/python_scripts/plot_umaps_batch_correct.py b/panpipes/python_scripts/plot_umaps_batch_correct.py index 5c2aca59..4f2906fa 100644 --- a/panpipes/python_scripts/plot_umaps_batch_correct.py +++ b/panpipes/python_scripts/plot_umaps_batch_correct.py @@ -16,7 +16,7 @@ import sys import logging L = logging.getLogger() -L.setLevel(logging.WARNING) +L.setLevel(logging.INFO) log_handler = logging.StreamHandler(sys.stdout) formatter = logging.Formatter('%(asctime)s: %(levelname)s - %(message)s') log_handler.setFormatter(formatter) @@ -116,8 +116,9 @@ g = (g.map(sns.scatterplot, "umap_1", "umap_2", col, s=pointsize, linewidth=0)) g.add_legend() g.savefig(os.path.join(args.fig_dir, mod, "umap_method_" + str(col) + ".png")) - batch_scatter_two_var(plt_df, "method", col, palette_choice=palette_choice) - plt.savefig(os.path.join(args.fig_dir,mod, "umap_method_facet_" + str(col) + ".png"), dpi=300) + fig, ax = batch_scatter_two_var(plt_df, "method", col, palette_choice=palette_choice) + if fig is not None: + fig.savefig(os.path.join(args.fig_dir, mod, "umap_method_facet_" + str(col) + ".png"), dpi=300) plt.clf() ncats = len(plt_df['method'].unique())