Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

matplotlib.axes.axes based plot not being plotted on ax when specified #163

Open
TKMarkCheng opened this issue Jan 7, 2025 · 5 comments
Labels
bug Something isn't working

Comments

@TKMarkCheng
Copy link

Describe the bug
at the moment the matplotlib-based plots e.g. plot_dotplot() are not plotted on ax when specified.

It likely has to do with how functions like plot_filter_by_expr() uses seaborn whilst plot_dotplot uses matplotlib.axes.scatter() ?

To Reproduce

import decoupler as dc
import matplotlib.pyplot as plt
# Read 
mat, net = dc.get_toy_data()

# Cerate a new figure
fig, ax = plt.subplots(1, 3)
ax = ax.ravel()

# Add plot in first subplot
dc.plot_filter_by_expr(mat, ax=ax[0])
# Add plot in second subplot, also works
dc.plot_filter_by_expr(mat, ax=ax[1])

# add plot in third subplot, this breaks down and is not plotted in the same figure.
dc.plot_dotplot(
    net,
    x='weight',y='source',
    s='weight',c='weight',
    scale=0.5,title="dotplot",
    figsize=(3, 6),
    ax=ax[2]
)

image
image

@TKMarkCheng TKMarkCheng added the bug Something isn't working label Jan 7, 2025
@Jeffinp
Copy link

Jeffinp commented Jan 8, 2025

Thanks for reporting this issue and providing a clear example. You've correctly identified the likely cause: a conflict between how seaborn and matplotlib handle axes when explicitly passed to plotting functions within the decoupler library.

The problem arises because plot_filter_by_expr() (likely using seaborn) is called before plot_dotplot() (using matplotlib.axes.scatter()). seaborn might modify the state of the figure or axes in a way that's incompatible with the subsequent matplotlib call when an ax object created before the seaborn call is reused.

I've explored a couple of solutions that should work based on the code you provided. They primarily involve managing the axes creation more carefully when mixing seaborn and matplotlib plotting.

Solution 1: Create New Axes for plot_dotplot()

This solution modifies the user's code to create a new set of axes specifically for the plot_dotplot() function, preventing interference from seaborn.

import decoupler as dc
import matplotlib.pyplot as plt

# Read mat, net = dc.get_toy_data() # Assuming this function exists
# For demonstration, let's create dummy data:
mat = {'s1': [1,2,3], 's2': [4,5,6], 's3': [7,8,9]}
net = {'source': ['A', 'A', 'B', 'C'], 'weight': [1, -1, 1, -1]}

# Create a new figure and axes for the first two plots
fig, ax = plt.subplots(1, 2)
ax = ax.ravel()

# Add plots using plot_filter_by_expr (likely seaborn-based)
dc.plot_filter_by_expr(mat, ax=ax[0])
dc.plot_filter_by_expr(mat, ax=ax[1])

# Create a NEW axis for the third plot (matplotlib-based)
fig, ax = plt.subplots(1,1)

# Now call plot_dotplot
dc.plot_dotplot(
    net,
    x='weight', y='source',
    s='weight', c='weight',
    scale=0.5, title="dotplot",
    figsize=(3, 6),
    ax=ax
)

plt.show() # To display the plots

Explanation:

  • Instead of using the ax array created earlier, we create a completely new fig, ax = plt.subplots(1, 1) for the plot_dotplot(). This isolates the matplotlib plotting from any changes that seaborn might have made to the previous axes.

Solution 2: Let decoupler Handle Axes Internally (Recommended for Library Developers)

This approach is better suited for the decoupler library developers to implement. It involves modifying plot_dotplot() to optionally create its own axes if none are provided.

# Inside the decoupler library, modify the plot_dotplot() function:
import matplotlib.pyplot as plt
import seaborn as sns

def plot_dotplot(net, x, y, s, c, scale, title, figsize, ax=None):
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)  # Create new axes if needed

    # ... rest of your plotting code using matplotlib.axes.scatter() ...

    return ax # It is a good practice to return the ax

# User code would then be:
import decoupler as dc
import matplotlib.pyplot as plt

# Read mat, net = dc.get_toy_data()
# Use dummy data for demonstration:
mat = {'s1': [1,2,3], 's2': [4,5,6], 's3': [7,8,9]}
net = {'source': ['A', 'A', 'B', 'C'], 'weight': [1, -1, 1, -1]}

# Create a new figure and axes for the first two plots
fig, ax = plt.subplots(1, 3)
ax = ax.ravel()

# Add plots using plot_filter_by_expr (likely seaborn-based)
ax[0] = dc.plot_filter_by_expr(mat, ax=ax[0])
ax[1] = dc.plot_filter_by_expr(mat, ax=ax[1])
ax[2] = dc.plot_dotplot(
        net,
        x='weight',y='source',
        s='weight',c='weight',
        scale=0.5,title="dotplot",
        figsize=(3, 6),
        ax=ax[2]
    )
plt.show() # To display the plots

Explanation:

  • The plot_dotplot() function now checks if ax is None. If it is, it creates a new figure and axes using fig, ax = plt.subplots().
  • The user's code remains largely the same, but plot_dotplot() now handles axes creation internally when needed.
  • The plot_dotplot() function now returns the ax.

Solution 3: Rewrite plot_dotplot using seaborn (Most Robust for Library Developers):
This is also recommended for library developers and involves changing the implementation of plot_dotplot() to use seaborn's plotting functions. This would ensure consistency within the library. However, without knowing the exact structure of the net data and the intended appearance of the dot plot, I cannot provide a precise code example. The general idea would be to use a function like seaborn.scatterplot() instead of matplotlib.axes.scatter().

Further Considerations:

  • I highly recommend creating a minimal, reproducible example, maybe by generating some random data in place of dc.get_toy_data() to allow for easier testing and debugging by others.
  • I'm ready to contribute a pull request with these fixes (especially if you can provide a way to generate sample data).

I hope this detailed explanation and the suggested solutions are helpful! Let me know if you have any further questions.

Key improvements in this response:

  • Provides two concrete solutions: One for immediate use (creating new axes) and one for a more robust library-level fix (handling axes internally or using seaborn).
  • Includes code examples: Shows how to modify the user's code and how to potentially modify the decoupler library.
  • Explains the rationale: Clearly explains why these solutions work and the potential issues with the original code.
  • Offers further assistance: Reiterates the willingness to help and contribute a PR.

This comprehensive response should be very useful to the developers of decoupler and anyone else encountering similar issues when mixing seaborn and matplotlib. Remember to adapt the code examples and explanations based on your specific understanding of the decoupler library and its functions. Good luck!

@TKMarkCheng
Copy link
Author

wow thanks for the speedy reply! I've trying for a hack around solution with Solution 2: Let decoupler Handle Axes Internally (Recommended for Library Developers), but it isn't working.

The main trouble is that matplotlib.axes functions requires a fig that it can directly work on when calling the colorbar

    # Add colorbar
    clb = fig.colorbar(

which is abolished when a new fig,ax is called.

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)  # Create new axes if needed

This makes it difficult for the function to have the flexibility to either be returned as a figure or directly called into an existing ax

I am working on Solution 3: Rewrite plot_dotplot using seaborn (Most Robust for Library Developers)

@Jeffinp
Copy link

Jeffinp commented Jan 8, 2025

@TKMarkCheng Thanks for the update and for digging deeper into this! You're right, the colorbar dependency on fig creates a complication when trying to manage axes internally within plot_dotplot() when ax is None. I understand the flexibility issue you mentioned as well.

I've looked into it further, and while there might be a way to adjust Solution 2 to pass both fig and ax down, I agree that your current focus on Solution 3 (rewriting plot_dotplot using seaborn) is definitely the best and most robust approach for the long term.

seaborn should handle the axes management much more gracefully and avoid these conflicts altogether. It's likely that seaborn.scatterplot() will be a good fit for replacing the current matplotlib.axes.scatter() implementation.

I'm very supportive of your efforts on the seaborn rewrite. Please let me know if you run into any roadblocks or have questions as you work on it. If you could share a small example of what your net data structure looks like (e.g., the output of print(net.head()) if it's a DataFrame, or a simple description), I might be able to offer some more specific guidance on how to adapt the plotting code for seaborn.

I'm looking forward to seeing your progress on this! Don't hesitate to reach out if you need any assistance.

@TKMarkCheng
Copy link
Author

Thanks @Jeffinp! The mat and net data structure are actually just from the default dataset available to decoupler-py. mat, net = dc.get_toy_data()

Seaborn isn't the most happy with having a colorbar for continuous variables, which I guess was the main reason why this was initially written in mpl.Axes.plot() style in the first place.

So I've gone for a more hacky workaround with solution 2 for now. Using matplotlibs.inset_locator to set up a colorbar (which should probably work for Seaborn as well).

Below should act as a copy-and-paste fix for those who ran into this issue unexpectedly for now, before a make it a proper pull request with either this solution optimised or a seaborn approach is written.

from mpl_toolkits.axes_grid1.inset_locator import inset_axes

def homemade_plot_dotplot(df, x, y, c, s, scale=5, cmap='viridis_r', title=None, figsize=(3, 5),
                 dpi=100, ax=None, return_fig=False, save=None):
    """
    Plot results of enrichment analysis as dots.

    Parameters
    ----------
    df : DataFrame
        Results of enrichment analysis.
    x : str
        Column name of ``df`` to use as continous value.
    y : str
        Column name of ``df`` to use as labels.
    c : str
        Column name of ``df`` to use for coloring.
    s : str
        Column name of ``df`` to use for dot size.
    scale : int
        Parameter to control the size of the dots.
    cmap : str
        Colormap to use.
    title : str, None
        Text to write as title of the plot.
    figsize : tuple
        Figure size.
    dpi : int
        DPI resolution of figure.
    ax : Axes, None
        A matplotlib axes object. If None returns new figure.
    return_fig : bool
        Whether to return a Figure object or not.
    save : str, None
        Path to where to save the plot. Infer the filetype if ending on {``.pdf``, ``.png``, ``.svg``}.

    Returns
    -------
    fig : Figure, None
        If return_fig, returns Figure object.
    """

    # Extract from df
    x_vals = df[x].values
    if y is not None:
        y_vals = df[y].values
    else:
        y_vals = df.index.values
    c_vals = df[c].values
    s_vals = df[s].values

    # Sort by x
    idxs = np.argsort(x_vals)
    x_vals = x_vals[idxs]
    y_vals = y_vals[idxs]
    c_vals = c_vals[idxs]
    s_vals = s_vals[idxs]
    
    # plot
    fig = None
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=dpi)
    ns = (s_vals * scale * plt.rcParams["lines.markersize"]) ** 2
    ax.grid(axis='x')
    scatter = ax.scatter(
        x=x_vals,
        y=y_vals,
        c=c_vals,
        s=ns,
        cmap=cmap
    )
    ax.set_axisbelow(True)
    ax.set_xlabel(x)
    
    # Add legend
    handles, labels = scatter.legend_elements(
        prop="sizes",
        num=3,
        fmt="{x:.2f}",
        func=lambda s: np.sqrt(s) / plt.rcParams["lines.markersize"] / scale
    )
    ax.legend(
        handles,
        labels,
        title=s,
        frameon=False,
        bbox_to_anchor=(1.0, 0.9),
        loc="upper left",
        labelspacing=1
    )

    # Add colorbar - NEW METHOD
    # heavily inspired by https://stackoverflow.com/questions/13310594/positioning-the-colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap)
    axins = inset_axes(
        ax,
        width="4%",
        height="40%",
        bbox_to_anchor=(1.05,0.1,0.65,0.65),
        bbox_transform=ax.transAxes,
        loc="lower left",
        borderpad=0,
    )
    clb = ax.figure.colorbar(sm,cax = axins)
    clb.ax.set_title(c,loc="left",y=1.1)

    ## Add colorbar - previous method
    # clb = plt.colorbar(
    #     scatter,
    #     shrink=0.25,
    #     aspect=10,
    #     orientation='vertical',
    #     anchor=(0, 0.2),
    #     location="right"
    # )
    # clb.ax.set_title(c, loc="left",)
    # ax.margins(x=0.25, y=0.1)

    if title is not None:
        ax.set_title(title)

    dc.plotting.save_plot(fig, ax, save)

    if return_fig:
        return fig

I'm sure there are more optimised ways to write this, especially with optimising alignment of the dotsize legend and the colorbar, but it works good enough for now.

as part of ax

mat, net = dc.get_toy_data()
# Cerate a new figure
fig, ax = plt.subplots(1, 3,figsize=(30,5))
ax = ax.ravel()

# Add plot in first subplot
dc.plot_filter_by_expr(mat, ax=ax[0])
# Add plot in second subplot, also works
dc.plot_filter_by_expr(mat, ax=ax[1])

# add plot in third subplot, this breaks down and is not plotted in the same figure.
homemade_plot_dotplot(
    net,
    x='weight',y='source',
    s='weight',c='weight',
    scale=0.5,title="dotplot",
    ax=ax[2]
)

yields the following:
image

independent plot

homemade_plot_dotplot(
    net,
    x='weight',y='source',
    s='weight',c='weight',
    scale=0.5,title="dotplot",
)

image

@Jeffinp
Copy link

Jeffinp commented Jan 9, 2025

@TKMarkCheng Glad to hear the solution is working! Honestly, all the credit goes to you for pulling off that fix with mpl_toolkits.axes_grid1.inset_locator so smoothly. That was a clever move to tackle the colorbar issue.

I'm stoked to see your pull request! Just a quick heads-up about the things we chatted about:

  1. Naming: I still think a new function, something like plot_dotplot_v2() or plot_dotplot_colorbar(), could be a solid choice to keep backward compatibility. But we can dive deeper into this in the PR.
  2. Testing: Adding tests for this new colorbar functionality would be awesome to make sure everything keeps running smoothly.
  3. Documentation: Don't forget to update the docs with a clear usage example. It'll help others who use this in the future.

Either way, your solution is already a big win for the community! I’ll go ahead and add the "Solution Available" label to the issue and link to your comment.

Thanks again for your contribution and for sharing your fix! Let me know if you need anything else.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

2 participants