Skip to content

Commit b7cf127

Browse files
Added plotly parallel coordinate plot (#175)
2 parents 0f55e77 + acecc34 commit b7cf127

File tree

2 files changed

+145
-0
lines changed

2 files changed

+145
-0
lines changed

Python_Engine/Python/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ case-converter
99
numpy==1.26.4
1010
pandas==2.2.1
1111
matplotlib==3.8.3
12+
plotly==5.18.1
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import plotly.graph_objects as go
2+
import pandas as pd
3+
from matplotlib.colors import Colormap
4+
import decimal as d
5+
from typing import List, Dict
6+
import math
7+
8+
from ..bhom.analytics import bhom_analytics
9+
10+
def set_dimensions(df: pd.DataFrame, tick_mark_count: int, dp: int) -> List[Dict[str, object]]:
    """Set the dimensions for a parallel coordinate plot, based on column datatypes and unique values.

    Each column of the DataFrame becomes one dimension dictionary with the keys
    'label', 'values', 'tickvals' and 'ticktext' (plus 'range' for numeric columns).

    Args:
        df (pd.DataFrame):
            The pandas DataFrame to plot. It is not modified.
        tick_mark_count (int):
            The maximum number of tick marks to show on each numeric axis.
        dp (int):
            The number of decimal places to show on the tick marks.

    Returns:
        list[dict[str, object]]:
            A list of dimensions to plot, one per column, in column order.
    """

    dimensions = []

    for column in df.columns:
        series = df[column]
        dim = {'label': str(column)}

        # Text-like columns (object, string or category dtype) cannot be placed on a
        # numeric axis directly, so they are encoded as integer category codes.
        is_categorical = (
            isinstance(series.dtype, pd.CategoricalDtype)
            or pd.api.types.is_object_dtype(series.dtype)
            or pd.api.types.is_string_dtype(series.dtype)
        )

        if is_categorical:
            codes = series.astype("category").cat.codes

            dim['values'] = codes
            # Both unique() calls return values in order of first appearance, so
            # each code lines up with the label it was derived from.
            dim['tickvals'] = codes.unique()
            dim['ticktext'] = series.unique()

            dimensions.append(dim)
            continue

        low, high = series.min(), series.max()

        dim['values'] = series
        dim['range'] = [low, high]

        if series.nunique() < tick_mark_count:
            # Few enough distinct values: put a tick on every one of them.
            dim['tickvals'] = dim['ticktext'] = series.unique()
        else:
            # Reduce the number of tick marks if the column has a large number of
            # unique values. The count is kept in a local so that a narrow column
            # does not shrink the tick count of the columns that follow it.
            span = high - low
            n_ticks = tick_mark_count
            if (span + 1) < n_ticks:
                n_ticks = math.ceil(span) + 1

            if n_ticks > 1:
                dim['tickvals'] = [low + i * span / (n_ticks - 1) for i in range(n_ticks)]
            else:
                # A single tick has no spacing to divide by; pin it to the minimum.
                dim['tickvals'] = [low] * max(n_ticks, 0)
            dim['ticktext'] = [round(value, dp) for value in dim['tickvals']]

        dimensions.append(dim)

    return dimensions
68+
69+
def _colormap_to_colorscale(cmap, samples: int = 256) -> list:
    """Sample a matplotlib Colormap into a Plotly colorscale list.

    Args:
        cmap (Colormap):
            The matplotlib colormap to sample. It is called with floats in [0, 1].
        samples (int, optional):
            The number of evenly spaced samples to take. Defaults to 256.

    Returns:
        list:
            A list of [position, 'rgb(r,g,b)'] pairs with positions running from 0 to 1.
    """

    last = samples - 1
    colorscale = []
    for i in range(samples):
        position = i / last
        # The colormap returns RGBA floats in [0, 1]; alpha is dropped.
        red, green, blue = (int(round(255 * channel)) for channel in cmap(position)[:3])
        colorscale.append([position, f"rgb({red},{green},{blue})"])
    return colorscale


@bhom_analytics()
def parallel_coordinate_plot(
    df: pd.DataFrame = pd.DataFrame(),
    variables_to_show: list = None,
    decimal_places: int = 0,
    tick_mark_count: int = 11,
    colour_key: str = None,
    cmap: Colormap = "viridis",
    dimensions: List[dict] = None,
    plot_title: str = "",
    plot_bgcolour: str = 'black',
    paper_bgcolour: str = 'black',
    font_colour: str = 'white',
    **kwargs,
) -> go.Figure:
    """Create a parallel coordinate plot of a pandas DataFrame.

    Args:
        df (pd.DataFrame):
            The pandas DataFrame to plot.
        variables_to_show (list, optional):
            The variables to show on the parallel coordinate plot. Must be a subset of df.columns.
        decimal_places (int, optional):
            The number of decimal places to show on the tick marks. Defaults to 0.
        tick_mark_count (int, optional):
            The number of tick marks to show on the parallel coordinate plot. Defaults to 11.
        colour_key (str, optional):
            The column of df to use as the colour key. It does not have to be one of
            variables_to_show. Defaults to None, in which case the last plotted column
            is used; if there is no data to plot, the lines are left uncoloured.
        cmap (Colormap or str, optional):
            The colormap to use for the colour key. Can be a matplotlib Colormap or a string representing a Plotly colorscale. Defaults to "viridis".
        dimensions (list[dict], optional):
            A list of dimensions to plot. If None, dimensions will be automatically generated based on the DataFrame. Defaults to None.
        plot_title (str, optional):
            The title of the plot. Defaults to an empty string.
        plot_bgcolour (str, optional):
            The background color of the plot. Defaults to 'black'.
        paper_bgcolour (str, optional):
            The background color of the paper. Defaults to 'black'.
        font_colour (str, optional):
            The color of the font used in the plot. Defaults to 'white'.
        **kwargs:
            Additional keyword arguments to pass to go.Parcoords().

    Returns:
        go.Figure:
            The populated go.Figure object.

    Raises:
        KeyError:
            If colour_key is given but is not a column of df.
    """

    # Keep the unfiltered frame so the colour column can still be looked up when
    # it is not among the variables being shown.
    source_df = df

    if variables_to_show is not None:
        df = df[variables_to_show]

    if dimensions is None:
        dimensions = set_dimensions(df, tick_mark_count, decimal_places)

    if colour_key is None and not df.empty:
        colour_key = df.columns[-1]

    # With no colour column available (empty DataFrame) the line styling is left
    # unset rather than indexing the frame with None.
    line = None
    if colour_key is not None:
        # Plotly colorscales are names or [position, colour] lists, so a
        # matplotlib Colormap object is sampled into that form first.
        colorscale = _colormap_to_colorscale(cmap) if isinstance(cmap, Colormap) else cmap
        line = dict(color=source_df[colour_key], colorscale=colorscale)

    fig = go.Figure(
        data=go.Parcoords(
            line=line,
            dimensions=dimensions,
            **kwargs
        )
    )

    fig.update_layout(
        title=plot_title,
        plot_bgcolor=plot_bgcolour,
        paper_bgcolor=paper_bgcolour,
        font_color=font_colour
    )

    return fig

0 commit comments

Comments
 (0)