diff --git a/latest_requirements.txt b/latest_requirements.txt index 39e1078e..99adb22b 100644 --- a/latest_requirements.txt +++ b/latest_requirements.txt @@ -1,7 +1,7 @@ copulas==0.12.3 -numpy==2.3.4 +numpy==2.3.5 pandas==2.3.3 -plotly==6.3.1 -scikit-learn==1.7.2 +plotly==6.5.0 +scikit-learn==1.8.0 scipy==1.16.3 tqdm==4.67.1 diff --git a/sdmetrics/visualization.py b/sdmetrics/visualization.py index 7e1a4a87..459e08b6 100644 --- a/sdmetrics/visualization.py +++ b/sdmetrics/visualization.py @@ -843,12 +843,14 @@ def get_column_line_plot(real_data, synthetic_data, column_name, metadata): x_axis = metadata['sequence_index'] if 'sequence_key' in metadata: r_data = ( - r_data.groupby(x_axis, as_index=False) + r_data + .groupby(x_axis, as_index=False) .agg({x_axis: 'first', column_name: ['mean', 'min', 'max']}) .rename(columns={'mean': column_name, 'first': x_axis}) ) s_data = ( - s_data.groupby(x_axis, as_index=False) + s_data + .groupby(x_axis, as_index=False) .agg({x_axis: 'first', column_name: ['mean', 'min', 'max']}) .rename(columns={'mean': column_name, 'first': x_axis}) )