Update README.md
Vishwa44 authored May 12, 2024
1 parent bb61014 commit fe6628c
# PatchTST Supervised with flashattention

## Objective

Our project improves PatchTST, a Transformer model for time series forecasting, by integrating Flash Attention, an IO-aware technique that reorganizes the read/write operations performed during attention. Flash Attention is reported to speed up the attention computation by roughly 3x, which can substantially improve the overall performance of the model.
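To illustrate what Flash Attention changes, the sketch below (an illustrative NumPy reconstruction, not code from this repository) contrasts standard attention, which materializes the full N x N score matrix, with a blockwise online-softmax pass in the style of Flash Attention that only ever holds one small tile of scores at a time:

```python
import numpy as np

def naive_attention(q, k, v):
    """Standard attention: materializes the full (N, N) score matrix."""
    d = q.shape[-1]
    s = q @ k.T / np.sqrt(d)                     # (N, N) scores: O(N^2) memory
    s -= s.max(axis=-1, keepdims=True)           # numerical stability
    p = np.exp(s)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v

def flash_style_attention(q, k, v, block=4):
    """Blockwise attention with an online softmax, in the spirit of
    Flash Attention: only one (N, block) tile of scores exists at a time."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full(n, -np.inf)                      # running row maxima
    l = np.zeros(n)                              # running softmax denominators
    acc = np.zeros((n, d))                       # running weighted value sums
    for j in range(0, k.shape[0], block):
        kj, vj = k[j:j + block], v[j:j + block]
        s = q @ kj.T * scale                     # one small tile of scores
        m_new = np.maximum(m, s.max(axis=-1))
        correction = np.exp(m - m_new)           # rescale earlier partial sums
        p = np.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=-1)
        acc = acc * correction[:, None] + p @ vj
        m = m_new
    return acc / l[:, None]

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 4)) for _ in range(3))
print(np.allclose(naive_attention(q, k, v),
                  flash_style_attention(q, k, v)))  # True
```

Both paths compute the exact same softmax; the blockwise version just never allocates the quadratic score matrix, which is where the IO savings come from.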


## Milestones

- **Baseline Profiling:** Evaluate current model performance and resource usage; document the architecture, metrics, and resources.
- **Flash Attention Implementation:** Research the Flash Attention mechanism; implement it in Python, test it with sample data, and document the results.
- **Integration with Current Model:** Modify the model to include Flash Attention; retrain, evaluate, and document the changes.
- **Further Optimization:** Analyze data characteristics; experiment with techniques such as preprocessing, optimization algorithms, and hardware utilization; document the optimization strategies and insights gained.
- **Developing a Custom Triton Kernel:** To reduce computation overhead at PatchTST's smaller model sizes, we developed a custom kernel that stores the attention scores computed during the forward pass and reuses them in the backward pass.
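The caching idea behind the custom kernel can be sketched in plain NumPy (an illustrative reconstruction, not the Triton kernel itself): the forward pass saves the softmax probabilities, and the backward pass reads them from the cache instead of recomputing the scores.

```python
import numpy as np

def attention_forward(q, k, v):
    """Forward pass that caches the softmax probabilities P for backward."""
    d = q.shape[-1]
    s = q @ k.T / np.sqrt(d)
    s -= s.max(axis=-1, keepdims=True)
    p = np.exp(s)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v, (q, k, v, p)              # P is the cached attention scores

def attention_backward(d_out, cache):
    """Backward pass that reuses the cached P instead of recomputing it."""
    q, k, v, p = cache
    scale = 1.0 / np.sqrt(q.shape[-1])
    d_v = p.T @ d_out
    d_p = d_out @ v.T
    # softmax backward: dS = P * (dP - rowsum(dP * P))
    d_s = p * (d_p - (d_p * p).sum(axis=-1, keepdims=True))
    d_q = d_s @ k * scale
    d_k = d_s.T @ q * scale
    return d_q, d_k, d_v

# Finite-difference check of d_q for the scalar loss L = sum(output)
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((6, 3)) for _ in range(3))
out, cache = attention_forward(q, k, v)
d_q, _, _ = attention_backward(np.ones_like(out), cache)
eps = 1e-6
q2 = q.copy(); q2[0, 0] += eps
numeric = (attention_forward(q2, k, v)[0].sum() - out.sum()) / eps
print(abs(numeric - d_q[0, 0]) < 1e-4)  # True
```

This trades extra memory for the cached scores against skipping their recomputation, which pays off at the small context lengths PatchTST typically uses.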

The custom kernel implementation is provided as a Jupyter notebook in the `Kernel` folder.

## Results

Our kernel achieves a 1.64x speedup on the backward pass, while the forward pass runs at roughly 0.70x the baseline speed (a slowdown).
The kernel performs best on models with smaller context lengths.
Due to limitations of the library, the forward pass is written in a suboptimal way.
