
DTW alignment plot does not seem to work properly [BUG] #1871

Closed
AhmetZamanis opened this issue Jun 30, 2023 · 4 comments · Fixed by #1880
Labels
bug Something isn't working

Comments

@AhmetZamanis
Contributor

Describe the bug
I am using the dtw module to perform dynamic time warping on two univariate series. Everything works fine except the DTWAlignment.plot_alignment() method, which does not produce a correct DTW alignment plot. It plots the two time series correctly, but the black alignment lines collapse into a single vertical line near the start of the plot (see the image below the reproducible code).

To Reproduce
Here's the code I used. All other DTWAlignment methods seem to work fine; only .plot_alignment() is affected. The two series are univariate, and I've reproduced the same issue with essentially the same code using two other, multivariate series.

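import matplotlib.pyplot as plt
from darts.timeseries import TimeSeries
from darts.dataprocessing import dtw

# df: dataframe with MachineID, SurvivalTime and ToolWearMins columns (not shown here)
# Retrieve the two univariate series (ToolWearMins across time steps, for Machine 1 and Machine 34)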
ts_list = TimeSeries.from_group_dataframe(
  df = df.loc[df["MachineID"].isin([1, 34])],
  group_cols = "MachineID",
  time_col = "SurvivalTime",
  value_cols = "ToolWearMins"
)
ts_short = ts_list[0].with_columns_renamed("ToolWearMins", "Machine 1") # 78 time steps
ts_long = ts_list[1].with_columns_renamed("ToolWearMins", "Machine 34") # 101 time steps

# Perform DTW
alignment = dtw.dtw(ts_short, ts_long)

# Visualize the alignment (broken)
alignment.plot_alignment(series2_y_offset = 100)
_ = plt.gca().set_title("Alignment")

[Image: plot_alignment_bug]

Expected behavior
I was able to manually create the correct plot using the alignment indices from DTWAlignment.path() (see image below). I'll share the code for this in the additional context, in case it's helpful.
[Image: AlignmentPlotCorrect]
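For context on what the alignment indices from DTWAlignment.path() represent, here is a minimal textbook DTW in plain Python. It is a sketch, not darts' implementation (which offers windowing and other options). It assumes absolute difference as the local cost between univariate values, and returns the warping path as 0-based (i, j) index pairs; each pair corresponds to one black line in an alignment plot. The function name dtw_path is hypothetical.

```python
from typing import List, Sequence, Tuple


def dtw_path(a: Sequence[float], b: Sequence[float]) -> Tuple[float, List[Tuple[int, int]]]:
    """Classic O(n*m) DTW; returns (total cost, warping path of 0-based index pairs)."""
    n, m = len(a), len(b)
    inf = float("inf")
    # cost[i][j] = minimal cumulative cost of aligning a[:i] with b[:j]
    cost = [[inf] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            local = abs(a[i - 1] - b[j - 1])  # assumed local distance
            cost[i][j] = local + min(cost[i - 1][j], cost[i][j - 1], cost[i - 1][j - 1])

    # Backtrack from the last cell to the first, always stepping to the cheapest predecessor
    i, j = n, m
    path = [(i - 1, j - 1)]
    while (i, j) != (1, 1):
        _, i, j = min(
            (cost[i - 1][j - 1], i - 1, j - 1),
            (cost[i - 1][j], i - 1, j),
            (cost[i][j - 1], i, j - 1),
        )
        path.append((i - 1, j - 1))
    path.reverse()
    return cost[n][m], path
```

For example, dtw_path([0, 1, 2], [0, 1, 1, 2]) matches the repeated 1 in the second series to the single 1 in the first, giving the path [(0, 0), (1, 1), (1, 2), (2, 3)] with total cost 0.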

System:

  • Python version: 3.10.8
  • darts version: 0.24.0

Additional context
Here's the code I used to create the second, correct alignment plot. I'm sure there's a cleaner way to do it.

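import matplotlib.pyplot as plt
import pandas as pd

# Aligned index pairs (0-based) from the DTW alignment computed above
path = alignment.path()

# One small dataframe per aligned pair: rows = Machine 1 point, Machine 34 point;
# cols = SurvivalTime, ToolWearMins (+100 for Machine 34)
df_alignment = []
for pair in path:
  # path indices are 0-based, SurvivalTime starts at 1
  x1 = pair[0] + 1
  x2 = pair[1] + 1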
  y1 = df.loc[(df["MachineID"] == 1) & (df["SurvivalTime"] == x1), "ToolWearMins"].values[0]
  y2 = df.loc[(df["MachineID"] == 34) & (df["SurvivalTime"] == x2), "ToolWearMins"].values[0] + 100
  
  data = pd.DataFrame({
    "SurvivalTime": [x1, x2],
    "ToolWearMins": [y1, y2]
    }
  )
  df_alignment.append(data)

# Plot the alignment
ts_short.plot()
(ts_long + 100).plot()

for pair in df_alignment:
  _ = plt.plot("SurvivalTime", "ToolWearMins", data = pair, c = "black", lw = 0.5)

_ = plt.title("Alignment plot")
_ = plt.ylabel("ToolWearMins (+100 for Machine 34)")
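A possibly cleaner way to build the same lines, sketched under assumptions: take the x-coordinates verbatim from each series' own time index, so integer time steps are never converted to datetimes, and produce one segment per (i, j) pair from the DTW path. The helper alignment_segments is hypothetical, not a darts function.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]
Segment = Tuple[Point, Point]


def alignment_segments(
    path: Sequence[Tuple[int, int]],
    x1: Sequence[float],
    y1: Sequence[float],
    x2: Sequence[float],
    y2: Sequence[float],
    series2_y_offset: float = 0.0,
) -> List[Segment]:
    """Return one ((x1, y1), (x2, y2 + offset)) segment per DTW path pair (i, j).

    x-values are used as given, so an integer time axis stays integer.
    """
    segments: List[Segment] = []
    for i, j in path:
        start = (x1[i], y1[i])
        end = (x2[j], y2[j] + series2_y_offset)
        segments.append((start, end))
    return segments
```

With the objects from this issue, the pairs could come from alignment.path(), the x-values from ts_short.time_index and ts_long.time_index, and the y-values from ts_short.values()[:, 0] and ts_long.values()[:, 0]; the segments can then be drawn in one call with matplotlib's LineCollection, e.g. plt.gca().add_collection(LineCollection(segments, colors="black", linewidths=0.5)).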
AhmetZamanis added the bug (Something isn't working) and triage (Issue waiting for triaging) labels on Jun 30, 2023
madtoinou removed the triage label on Jun 30, 2023
@madtoinou
Collaborator

Thank you @AhmetZamanis for reporting this bug. Would you have time to open a PR to fix it?

@AhmetZamanis
Contributor Author

@madtoinou I can give it a try next week. Can't guarantee a clean implementation though.

@AhmetZamanis
Contributor Author

@madtoinou I took another look: .plot_alignment() actually works fine if the input time dimension is a pd.DatetimeIndex. The issue occurs with a pd.RangeIndex time dimension, because the method converts it into datetime values that all fall within 1970-01-01.
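A small illustration of why such a conversion would collapse the lines, assuming the integer time steps end up passed through pandas' to_datetime with its default unit (nanoseconds since the Unix epoch); the exact conversion inside darts is an assumption here. A range of 101 integer steps becomes timestamps spanning only 100 nanoseconds on 1970-01-01, so every alignment line would land at essentially the same x-coordinate, consistent with the single vertical line in the report.

```python
import pandas as pd

# Integer time steps, like the 101-step RangeIndex of the longer series (illustrative)
steps = list(range(1, 102))

# Integers are interpreted as nanoseconds since 1970-01-01 by default
as_datetimes = pd.to_datetime(steps)

span = as_datetimes.max() - as_datetimes.min()
print(as_datetimes.min(), as_datetimes.max(), span)
```

Keeping the RangeIndex values as plain integers when drawing the lines avoids this collapse.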

I've made a small fix that appears to work with both types of time dimensions, and tested it locally. Before creating a PR, I have a small question about the procedure. I followed steps 1 to 6 of the contribution guidelines (and installed the pre-commit hook). When I run ./gradlew test_all in my local clone, the lint_flake8 check fails because it seems to check my entire virtual environment (which is installed inside the clone directory). Should I test locally in a different way, or just go ahead and create the PR since the tests will be run automatically? It's my first contribution to a package, so thanks in advance for the help.

@madtoinou
Collaborator

Thank you very much for investigating the problem and identifying the bug so quickly.

Ideally, flake8 should not be installed there, but it's not a big problem: you can mimic the routine by running the unit tests with pytest (after making sure your environment picks up your fix, for example with pip install -e .) and then running pre-commit run --all-files to see whether any files belonging to darts require linting.

If pre-commit is correctly installed, linting is applied automatically when you commit to your branch. You can probably open the PR already :)
