def display_signals(
    data: TabularDatasetData,
    *,
    figsize: tuple[int, int] = (6, 4),
    dpi: int = 150,
    colormap: str = "viridis",
    x_label: str = "Spectral bands",
    y_label: str = "",
    label_fontsize: int = 14,
    tick_params_label_size: int = 12,
    legend_fontsize: int = 10,
    legend_frameon: bool = True,
) -> None:
    """Plot the per-class mean signal with a +/- one standard deviation band.

    The rows of ``data.signals`` are grouped by the encoded target values.
    For every group the column-wise mean is drawn as a line and the
    ``mean - std`` .. ``mean + std`` range is drawn as a translucent band.
    The figure is displayed with ``plt.show()``; nothing is returned.

    Args:
        data: Dataset providing ``signals`` (one column per x position) and a
            ``ClassificationTarget`` in ``target``.
        figsize: Figure size passed to ``plt.subplots``.
        dpi: Figure resolution passed to ``plt.subplots``.
        colormap: Name of the matplotlib colormap sampled when more than two
            distinct labels are present. With one or two labels a fixed pair
            of named colors is used instead.
        x_label: Label of the x axis.
        y_label: Label of the y axis.
        label_fontsize: Font size of both axis labels.
        tick_params_label_size: Font size of the major and minor tick labels.
        legend_fontsize: Font size of the legend entries.
        legend_frameon: Whether the legend is drawn with a frame.

    Raises:
        InvalidInputError: If ``data.target`` is not a ``ClassificationTarget``.
    """
    if not isinstance(data.target, ClassificationTarget):
        raise InvalidInputError(
            input_value=data.target,
            message="The target must be an instance of ClassificationTarget.",
        )

    signals = data.signals.copy()
    target = data.target.model_copy()
    y_data_encoded = target.value
    # NOTE(review): assumes the encoding dict values are the class names
    # ordered so that the encoded label is a valid position - confirm against
    # ClassificationTarget.
    classes = list(target.encoding.to_dict().values())

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    cmap = plt.get_cmap(colormap)

    unique_labels = np.unique(y_data_encoded)
    n_labels = len(unique_labels)
    if n_labels > 2:
        palette = list(cmap(np.linspace(0, 1, n_labels)))
    else:
        palette = ["darkgoldenrod", "forestgreen"]

    x_values = list(range(len(signals.columns)))
    grouped = signals.groupby(y_data_encoded.to_numpy())
    mean_values = grouped.mean()
    std_values = grouped.std()

    # The palette holds one entry per *observed* label, so it is indexed by
    # the label's position in ``unique_labels`` rather than by the label value
    # itself. Indexing by value raised IndexError whenever the observed labels
    # were not exactly 0..n-1 (e.g. a class missing from the data).
    for position, label in enumerate(unique_labels):
        color = palette[position]
        mean = mean_values.loc[label].to_numpy()
        std = std_values.loc[label].to_numpy()
        ax.plot(x_values, mean, color=color, label=classes[label], alpha=0.6)
        ax.fill_between(
            x_values,
            mean - std,
            mean + std,
            color=color,
            alpha=0.2,
        )

    ax.set_ylabel(y_label, fontsize=label_fontsize)
    ax.set_xlabel(x_label, fontsize=label_fontsize)
    ax.tick_params(axis="both", which="both", labelsize=tick_params_label_size)

    # Hide the right and top spines by giving them zero width.
    ax.spines["right"].set_linewidth(0)
    ax.spines["top"].set_linewidth(0)

    # One tick per signal column, labelled with the column name.
    ax.set_xticks(x_values)
    ax.set_xticklabels(signals.columns, rotation=0)

    ax.legend(
        loc="upper left",
        fontsize=legend_fontsize,
        framealpha=1,
        frameon=legend_frameon,
    )
    plt.show()