# Source code for dashi.supervised_characterization.plot_performance
# Copyright 2024 Biomedical Data Science Lab, Universitat Politècnica de València (Spain)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Main function for multi-batch metrics exploration.
"""
from typing import Dict
import plotly.graph_objects as go
import plotly.io as pio
from plotly.colors import get_colorscale
from .arrange_metrics import arrange_performance_metrics
# Base font size applied to the figure text; axis tick labels use this value minus 2.
_FONTSIZE = 14
# [docs]
def plot_multibatch_performance(*, metrics: Dict[str, float], metric_name: str) -> go.Figure:
    """
    Plot a heatmap of one performance metric across training/test batch pairs.

    The metrics are first arranged by ``arrange_performance_metrics`` into a
    table whose index is placed on the y-axis (training batches) and whose
    columns are placed on the x-axis (test batches). Each cell is coloured
    according to the value of the selected metric.

    The red-yellow-green colour scale is oriented so that green always marks
    the better end of the metric: for error-like metrics (where a lower value
    is better) the reversed scale is used.

    The returned figure is interactive (zoom, hover, etc.) but is not shown
    by this function; call ``fig.show()`` or export it as needed.

    Parameters
    ----------
    metrics : dict
        A dictionary where keys are tuples of
        (training_batch, test_batch, dataset_type) and values hold the metrics
        for the corresponding combination. Only entries whose ``dataset_type``
        is ``'test'`` are meant to be included in the heatmap (the filtering
        is delegated to ``arrange_performance_metrics``).
        NOTE(review): the ``Dict[str, float]`` annotation is kept unchanged for
        interface compatibility, although the keys are described as tuples.
    metric_name : str
        The name of the metric to visualize.

        Regression metric names, when applicable:
            - 'MEAN_ABSOLUTE_ERROR'
            - 'MEAN_SQUARED_ERROR'
            - 'ROOT_MEAN_SQUARED_ERROR'
            - 'R_SQUARED'

        Classification metric names, when applicable:
            - 'AUC_{class_identifier}'
            - 'AUC_MACRO'
            - 'LOGLOSS'
            - 'RECALL_{class_identifier}'
            - 'PRECISION_{class_identifier}'
            - 'F1-SCORE_{class_identifier}'
            - 'ACCURACY'
            - 'RECALL_MACRO'
            - 'RECALL_MICRO'
            - 'RECALL_WEIGHTED'
            - 'PRECISION_MACRO'
            - 'PRECISION_MICRO'
            - 'PRECISION_WEIGHTED'
            - 'F1-SCORE_MACRO'
            - 'F1-SCORE_MICRO'
            - 'F1-SCORE_WEIGHTED'

    Returns
    -------
    fig : plotly.graph_objects.Figure
        A Plotly figure containing the heatmap of the specified metric.

    Raises
    ------
    TypeError
        If the `metrics` parameter is not a dictionary or if `metric_name` is
        not a string. This function performs no validation itself, so the
        error is presumably raised by ``arrange_performance_metrics``
        (TODO confirm).
    """
    # Arrange the metrics into a table (rows -> y-axis, columns -> x-axis).
    metrics_test_frame = arrange_performance_metrics(metrics=metrics, metric_name=metric_name)

    # Error-like metrics: smaller is better, so the colour scale must be flipped
    # to keep green on the "good" end.
    lower_is_better = metric_name in (
        'MEAN_ABSOLUTE_ERROR',
        'MEAN_SQUARED_ERROR',
        'ROOT_MEAN_SQUARED_ERROR',
        'LOGLOSS',
    )

    # Request the reversed scale by name ('_r' suffix) instead of slicing the
    # returned list with [::-1]: the list holds [position, colour] pairs, so
    # reversing the list order leaves every colour attached to its original
    # position (the mapping is unchanged) and yields descending positions.
    colorscale = get_colorscale('RdYlGn_r' if lower_is_better else 'RdYlGn')

    heatmap_data = go.Heatmap(
        z=metrics_test_frame.values,  # Metric values, one cell per (train, test) pair
        x=metrics_test_frame.columns,  # Columns as x-axis
        y=metrics_test_frame.index,  # Rows as y-axis
        colorscale=colorscale,
        colorbar=dict(title=metric_name),  # Colorbar label
        hovertemplate="%{y}<br>%{x}: %{z:.3f}",  # Tooltip on hover
        showscale=True,  # Display colorbar scale
    )

    layout = go.Layout(
        title=f'{metric_name.lower().capitalize()} heatmap',
        xaxis=dict(title='Test Batch', tickangle=45, tickfont=dict(size=_FONTSIZE - 2)),
        yaxis=dict(title='Training Batch', tickfont=dict(size=_FONTSIZE - 2)),
        font=dict(size=_FONTSIZE, family="serif"),
        template="plotly_white",  # Clean white background
    )

    # The renderer is intentionally left to the caller's environment.
    return go.Figure(data=[heatmap_data], layout=layout)