Source code for pymoo.visualization.scatter

import numpy as np

from pymoo.core.plot import Plot
from pymoo.docs import parse_doc_string
from pymoo.util.misc import set_if_none


def plot_1d(sc):
    """Draw every one-dimensional data set of ``sc`` on a single axis.

    All points are placed at y = 0, so only their x position carries
    information.

    Parameters
    ----------
    sc : Scatter
        The plot object. Each entry of ``sc.to_plot`` is an ``(F, kwargs)``
        pair where ``kwargs["mode"]`` names the Axes method used to draw
        ``F`` (e.g. ``"scatter"``); the remaining kwargs are forwarded to it.

    Returns
    -------
    Scatter
        ``sc`` itself, for consistency with ``plot_2d``.
    """
    sc.init_figure()
    labels = sc.get_labels()
    ax = sc.ax

    for F, kwargs in sc.to_plot:
        # copy so the dictionary stored in ``sc.to_plot`` keeps its "mode"
        # entry and the figure can be rendered again from the same data
        kwargs = dict(kwargs)
        func = getattr(ax, kwargs.pop("mode"))
        func(F, np.zeros_like(F), **kwargs)
        ax.set_xlabel(labels[0])

    return sc


def plot_2d(sc):
    """Draw every two-dimensional data set of ``sc`` on a single axis.

    Column 0 of each data set is used as x and column 1 as y.

    Parameters
    ----------
    sc : Scatter
        The plot object. Each entry of ``sc.to_plot`` is an ``(F, kwargs)``
        pair where ``kwargs["mode"]`` names the Axes method used to draw
        ``F`` (e.g. ``"scatter"`` or ``"plot"``); the remaining kwargs are
        forwarded to it.

    Returns
    -------
    Scatter
        ``sc`` itself.
    """
    sc.init_figure()
    labels = sc.get_labels()
    ax = sc.ax

    for F, kwargs in sc.to_plot:
        # copy so the dictionary stored in ``sc.to_plot`` keeps its "mode"
        # entry and the figure can be rendered again from the same data
        kwargs = dict(kwargs)
        func = getattr(ax, kwargs.pop("mode"))
        func(F[:, 0], F[:, 1], **kwargs)
        ax.set_xlabel(labels[0])
        ax.set_ylabel(labels[1])

    return sc


def plot_3d(sc, angle):
    """Draw every three-dimensional data set of ``sc`` on one 3D axis.

    Parameters
    ----------
    sc : Scatter
        The plot object. Each entry of ``sc.to_plot`` is an ``(F, kwargs)``
        pair where ``kwargs["mode"]`` names the 3D Axes method used to draw
        ``F`` (besides ``"scatter"``/``"plot"``, ``"plot_trisurf"`` is also
        accepted); the remaining kwargs are forwarded to it.
    angle : tuple or None
        Unpacked into ``ax.view_init`` (matplotlib: elevation, azimuth).
        ``None`` leaves the default viewing angle untouched.

    Returns
    -------
    Scatter
        ``sc`` itself.
    """
    sc.init_figure(plot_3D=True)
    labels = sc.get_labels()
    ax = sc.ax

    for F, kwargs in sc.to_plot:
        # copy so the dictionary stored in ``sc.to_plot`` keeps its "mode"
        # entry and the figure can be rendered again from the same data
        kwargs = dict(kwargs)
        func = getattr(ax, kwargs.pop("mode"))
        func(F[:, 0], F[:, 1], F[:, 2], **kwargs)

        # do not fill the background panes of the 3D box
        ax.xaxis.pane.fill = False
        ax.yaxis.pane.fill = False
        ax.zaxis.pane.fill = False

        ax.set_xlabel(labels[0])
        ax.set_ylabel(labels[1])
        ax.set_zlabel(labels[2])

        # test the same value that is unpacked, so ``angle=None`` never
        # reaches ``view_init(*None)`` and an explicit angle is never ignored
        if angle is not None:
            ax.view_init(*angle)

    return sc


def plot_pairwise(sc):
    """Draw an ``n_dim`` x ``n_dim`` matrix of pairwise plots.

    Off-diagonal cell ``(i, j)`` plots column ``i`` (x) against column ``j``
    (y) of every data set. Diagonal cells carry no data and only show the
    label of their dimension.

    Parameters
    ----------
    sc : Scatter
        The plot object. Each entry of ``sc.to_plot`` is an ``(F, kwargs)``
        pair where ``kwargs["mode"]`` names the Axes method used to draw
        ``F``; the remaining kwargs are forwarded to it.

    Returns
    -------
    Scatter
        ``sc`` itself.

    Raises
    ------
    AssertionError
        If a data set has fewer than two columns.
    """
    sc.init_figure(n_rows=sc.n_dim, n_cols=sc.n_dim)
    labels = sc.get_labels()

    for k, (F, kwargs) in enumerate(sc.to_plot):

        assert F.shape[1] >= 2, "A pairwise sc plot needs at least two dimensions."

        # copy so the dictionary stored in ``sc.to_plot`` keeps its "mode"
        # entry and the figure can be rendered again from the same data
        kwargs = dict(kwargs)
        mode = kwargs.pop("mode")

        for i in range(sc.n_dim):
            for j in range(sc.n_dim):

                ax = sc.ax[i, j]

                if i != j:
                    getattr(ax, mode)(F[:, i], F[:, j], **kwargs)
                    ax.set_xlabel(labels[i])
                    ax.set_ylabel(labels[j])

                # diagonal cells are independent of the data, so draw them once.
                # The invisible point is always drawn with ``scatter`` because the
                # ``s`` keyword is not accepted by e.g. ``plot``; it anchors the
                # axis limits around (0, 0) so the centered label is visible.
                elif k == 0:
                    ax.scatter(0, 0, s=1, color="white")
                    ax.set_xticks([])
                    ax.set_yticks([])
                    ax.text(0, 0, labels[i], ha='center', va='center', fontsize=20)

    return sc


class Scatter(Plot):

    def __init__(self, plot_3d=True, angle=(45, 45), **kwargs):
        """

        Scatter Plot

        Parameters
        ----------------

        axis_style : {axis_style}
        labels : {labels}

        plot_3d : bool
            If True, data with exactly three dimensions is drawn on a single
            3D axis; otherwise it is shown as a pairwise plot matrix.

        angle : tuple or None
            Viewing angle of the 3D axis, unpacked into matplotlib's
            ``view_init`` (elevation, azimuth). None keeps the default view.

        Other Parameters
        ----------------

        figsize : {figsize}
        title : {title}
        legend : {legend}
        tight_layout : {tight_layout}

        """

        super().__init__(**kwargs)
        self.angle = angle
        self.plot_3d = plot_3d

    def _do(self):
        # Legacy ``plot_type`` values and the Axes method ("mode") each maps to.
        # Other values are passed through unchanged as the mode name.
        legacy_modes = {"line": "plot", "surface": "plot_trisurf"}

        to_plot = []
        for k, (F, v) in enumerate(self.to_plot):
            # copy so the keyword dictionary supplied by the caller is not modified
            v = dict(v)

            # default color (cycled by data-set index) and opacity;
            # set_if_none presumably only fills keys that are missing/None
            set_if_none(v, "color", self.colors[k % len(self.colors)])
            set_if_none(v, "alpha", 1.0)

            # this is added to have compatibility to an old version
            # should be removed when the documentation is updated
            if "plot_type" in v:
                name = v.pop("plot_type")
                v["mode"] = legacy_modes.get(name, name)

            set_if_none(v, "mode", "scatter")
            to_plot.append([F, v])

        self.to_plot = to_plot

        # choose the layout from the dimensionality of the data
        if self.n_dim == 1:
            plot_1d(self)
        elif self.n_dim == 2:
            plot_2d(self)
        elif self.n_dim == 3 and self.plot_3d:
            plot_3d(self, self.angle)
        else:
            plot_pairwise(self)

        return self


# Presumably substitutes the shared pymoo parameter descriptions into the
# {...} placeholders of the __init__ docstring (see pymoo.docs) — confirm.
parse_doc_string(Scatter.__init__)