-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtasks.py
146 lines (110 loc) · 4.16 KB
/
tasks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import re
from dotenv import load_dotenv
from invoke import Context, task
# True when os.name is "nt" (Windows); the tasks in this file pass
# ``pty=not WINDOWS`` so a pseudo-terminal is only requested on other platforms.
WINDOWS = os.name == "nt"
# Used as the conda environment name and as the package directory under src/.
PROJECT_NAME = "mlops_project"
# Python version requested when the conda environment is created.
PYTHON_VERSION = "3.11"
# NOTE(review): presumably loads variables from a local .env file into the
# process environment (python-dotenv); docker_train reads WANDB_API_KEY from
# the environment and its error message points at .env — confirm.
load_dotenv()
# Setup commands
@task
def create_environment(ctx: Context) -> None:
    """Create a conda environment for the project.

    The environment is named ``PROJECT_NAME`` and is created with Python
    ``PYTHON_VERSION`` and pip; ``--yes`` keeps the command non-interactive.
    """
    conda_args = (
        "conda create",
        f"--name {PROJECT_NAME}",
        f"python={PYTHON_VERSION}",
        "pip",
        "--no-default-packages",
        "--yes",
    )
    ctx.run(" ".join(conda_args), pty=not WINDOWS, echo=True)
@task
def requirements(ctx: Context) -> None:
    """Install the project's requirements with pip.

    Runs three steps in order: upgrade the packaging tools, install
    ``requirements.txt``, then install the project itself in editable mode.
    """
    pip_steps = (
        "pip install -U pip setuptools wheel",
        "pip install -r requirements.txt",
        "pip install -e .",
    )
    for step in pip_steps:
        ctx.run(step, echo=True, pty=not WINDOWS)
@task(requirements)
def dev_requirements(ctx: Context) -> None:
    """Install the project in editable mode with its ``dev`` extra.

    ``requirements`` is passed to the ``task`` decorator, so it is run as a
    pre-task before this one.
    """
    editable_dev_install = 'pip install -e .["dev"]'
    ctx.run(editable_dev_install, pty=not WINDOWS, echo=True)
# Project commands
@task
def preprocess_data(ctx: Context, percentage: float = 1.0) -> None:
    """Run the project's data script on ``data/raw`` and ``data/processed``.

    Args:
        ctx: Invoke context used to run the shell command.
        percentage: Value forwarded to the script's ``--percentage`` option.
    """
    data_script = f"python src/{PROJECT_NAME}/data.py"
    command = f"{data_script} data/raw data/processed --percentage {percentage}"
    ctx.run(command, pty=not WINDOWS, echo=True)
@task
def train(ctx: Context) -> None:
    """Run the project's training script."""
    training_script = f"src/{PROJECT_NAME}/train.py"
    ctx.run(f"python {training_script}", pty=not WINDOWS, echo=True)
@task
def test(ctx: Context) -> None:
    """Run the tests in ``tests/`` under coverage, then print the report."""
    for command in ("coverage run -m pytest tests/", "coverage report -m"):
        ctx.run(command, echo=True, pty=not WINDOWS)
@task
def docker_build(ctx: Context, progress: str = "plain") -> None:
    """Build the training docker image and tag it ``train:latest``.

    The contents of ``default.json`` (read from the current directory) are
    passed to the build as the ``DEFAULT_JSON`` build argument, using
    ``dockerfiles/train.dockerfile``.

    Args:
        ctx: Invoke context used to run the shell command.
        progress: Value forwarded to docker's ``--progress`` option.
    """
    import shlex  # local import: only this task needs shell quoting

    with open("default.json", encoding="utf-8") as file:
        default_json = file.read()

    # Quote the whole build argument for a POSIX shell. Hand-written single
    # quotes break (or inject commands) as soon as the JSON contains a "'".
    build_arg = shlex.quote(f"DEFAULT_JSON={default_json}")
    ctx.run(
        f"docker build . --build-arg {build_arg} -f dockerfiles/train.dockerfile "
        f"-t train:latest --progress={shlex.quote(progress)}",
        echo=True,
        pty=not WINDOWS,
    )
@task
def docker_train(ctx: Context) -> None:
    """Run the ``train:latest`` docker image as container ``train1``.

    ``models`` and ``reports/figures`` under the current directory are mounted
    into the container, and ``WANDB_API_KEY`` is forwarded to it.

    Raises:
        ValueError: If ``WANDB_API_KEY`` is unset or empty in the environment.
    """
    wandb_api_key = os.getenv("WANDB_API_KEY")
    if not wandb_api_key:
        raise ValueError("WANDB_API_KEY not found in the environment. Make sure it's set in the .env file.")

    # Name the variable without a value so docker copies it from the child
    # process environment. The secret then never appears in the echoed command
    # line or in the process arguments.
    ctx.run(
        "docker run --name train1 --rm "
        "-v $(pwd)/models:/models/ "
        "-v $(pwd)/reports/figures:/reports/figures/ "
        "-e WANDB_API_KEY "
        "train:latest",
        echo=True,
        pty=not WINDOWS,
        env={"WANDB_API_KEY": wandb_api_key},
    )
@task
def wandb_sweep(ctx: Context, config_path: str = "configs/sweep.yaml") -> None:
    """Create a W&B sweep from a config file and start an agent for it.

    Runs ``wandb sweep``, extracts the ``wandb agent ...`` command from its
    output and runs that command. Errors are reported by printing a message
    rather than by raising.

    Args:
        ctx: Invoke context used to run the shell commands.
        config_path: Path to the sweep configuration passed to ``wandb sweep``.
    """
    try:
        # Create the sweep; warn=True keeps a non-zero exit from raising so
        # the captured output can still be inspected below.
        print(f"Creating W&B sweep with config: {config_path}")
        sweep_result = ctx.run(f"wandb sweep {config_path}", echo=True, pty=not WINDOWS, warn=True)

        # Without a pty the two streams are captured separately, so search
        # both; with a pty stderr is merged into stdout.
        ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
        raw_output = "\n".join((sweep_result.stdout, sweep_result.stderr))
        sweep_output = ansi_escape.sub("", raw_output.strip())

        # Entity and project names may contain "-" and ".", which a plain
        # [\w/]+ would cut short and yield a wrong agent command.
        match = re.search(r"wandb agent [\w./-]+", sweep_output)
        if not match:
            print("Error: Could not extract sweep command from W&B output.")
            return
        sweep_command = match.group(0)

        # Start the W&B agent
        ctx.run(sweep_command, echo=True, pty=not WINDOWS)
        print(f"Sweep created using command: {sweep_command}")
    except Exception as e:
        # Deliberate best-effort: report the problem without a traceback.
        print(f"An error occurred: {e}")
@task
def ruff_format(ctx: Context) -> None:
    """Lint the code base with ruff (applying fixes) and then format it."""
    ctx.run("ruff check . --fix", echo=True, pty=not WINDOWS)
    ctx.run("ruff format .", echo=True, pty=not WINDOWS)
@task
def test_coverage(ctx: Context) -> None:
    """Run pytest on ``tests/`` under coverage and show the coverage report."""
    run_options = {"echo": True, "pty": not WINDOWS}
    ctx.run("coverage run -m pytest tests/", **run_options)
    ctx.run("coverage report -m", **run_options)
# Documentation commands
@task(dev_requirements)
def build_docs(ctx: Context) -> None:
    """Build the MkDocs site from ``docs/mkdocs.yaml`` into ``build``.

    ``dev_requirements`` is passed to the ``task`` decorator, so it is run as
    a pre-task before the build.
    """
    mkdocs_config = "docs/mkdocs.yaml"
    ctx.run(
        f"mkdocs build --config-file {mkdocs_config} --site-dir build",
        pty=not WINDOWS,
        echo=True,
    )
@task(dev_requirements)
def serve_docs(ctx: Context) -> None:
    """Serve the MkDocs site described by ``docs/mkdocs.yaml``.

    ``dev_requirements`` is passed to the ``task`` decorator, so it is run as
    a pre-task before serving.
    """
    mkdocs_config = "docs/mkdocs.yaml"
    ctx.run(
        f"mkdocs serve --config-file {mkdocs_config}",
        pty=not WINDOWS,
        echo=True,
    )