Module mogptk.gpr.plot

Expand source code Browse git
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

def plot_gram(K):
    fig, ax = plt.subplots(1, 1, figsize=(6,6))
    fig.suptitle('Matrix is not positive semi-definitive', fontsize=16)

    K = K.detach().cpu().numpy()
    K_real = K[~np.isnan(K) & ~np.isinf(K)]
    if len(K_real) != 0:
        vmin, vmax = np.abs(K_real).min(), np.abs(K_real).max()
        norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
        im = ax.matshow(np.where(np.isinf(K)|np.isnan(K),np.nan,K), cmap='viridis', norm=norm)

    # show Inf and NaN as blue and red respectively
    cmap = matplotlib.colors.ListedColormap(["red"])
    ax.matshow(np.where(np.isinf(K),1.0,np.nan), cmap=cmap)
    cmap = matplotlib.colors.ListedColormap(["blue"])
    ax.matshow(np.where(np.isnan(K),1.0,np.nan), cmap=cmap)

    if len(K_real) != 0:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="3%", pad=0.1)
        fig.colorbar(im, cax=cax)

    ax.set_title('Red=Inf, Blue=NaN', pad=10, fontsize=12)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    plt.show()

Functions

def plot_gram(K)
Expand source code Browse git
def plot_gram(K):
    fig, ax = plt.subplots(1, 1, figsize=(6,6))
    fig.suptitle('Matrix is not positive semi-definitive', fontsize=16)

    K = K.detach().cpu().numpy()
    K_real = K[~np.isnan(K) & ~np.isinf(K)]
    if len(K_real) != 0:
        vmin, vmax = np.abs(K_real).min(), np.abs(K_real).max()
        norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
        im = ax.matshow(np.where(np.isinf(K)|np.isnan(K),np.nan,K), cmap='viridis', norm=norm)

    # show Inf and NaN as blue and red respectively
    cmap = matplotlib.colors.ListedColormap(["red"])
    ax.matshow(np.where(np.isinf(K),1.0,np.nan), cmap=cmap)
    cmap = matplotlib.colors.ListedColormap(["blue"])
    ax.matshow(np.where(np.isnan(K),1.0,np.nan), cmap=cmap)

    if len(K_real) != 0:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="3%", pad=0.1)
        fig.colorbar(im, cax=cax)

    ax.set_title('Red=Inf, Blue=NaN', pad=10, fontsize=12)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    plt.show()