Skip to content

Commit

Permalink
Q6. Finish the integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
lkopocinski committed Jul 3, 2024
1 parent 5fcb521 commit d83a041
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 48 deletions.
3 changes: 3 additions & 0 deletions 06-best-practices/homework/.env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
INPUT_FILE_PATTERN=
OUTPUT_FILE_PATTERN=
S3_ENDPOINT_URL=
37 changes: 26 additions & 11 deletions 06-best-practices/homework/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,7 @@
import pandas as pd


def get_s3_endpoint_url():
    """Return the custom S3 endpoint URL from the environment, or None.

    Used to point s3fs at a local S3 stand-in (e.g. localstack) in tests.
    """
    return os.getenv('S3_ENDPOINT_URL')


def get_input_path(year, month):
def get_input_path(year: int, month: int) -> str:
    """Build the input parquet path for the given year and month.

    The path pattern comes from the INPUT_FILE_PATTERN environment
    variable when set; otherwise the public NYC yellow-taxi CloudFront
    URL is used as the default.
    """
    pattern = os.getenv(
        'INPUT_FILE_PATTERN',
        'https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_{year:04d}-{month:02d}.parquet',
    )
    return pattern.format(year=year, month=month)
Expand All @@ -20,6 +15,7 @@ def get_output_path(year, month):
output_pattern = os.getenv('OUTPUT_FILE_PATTERN', default_output_pattern)
return output_pattern.format(year=year, month=month)


def read_data(filename) -> pd.DataFrame:
options = None

Expand All @@ -33,6 +29,25 @@ def read_data(filename) -> pd.DataFrame:
return pd.read_parquet(filename, storage_options=options)


def save_data(df, filename) -> None:
options = None

if s3_endpoint_url := os.getenv('S3_ENDPOINT_URL'):
options = {
'client_kwargs': {
'endpoint_url': s3_endpoint_url
}
}

df.to_parquet(
filename,
engine='pyarrow',
compression=None,
index=False,
storage_options=options
)


def prepare_data(df, categorical) -> pd.DataFrame:
df['duration'] = df.tpep_dropoff_datetime - df.tpep_pickup_datetime
df['duration'] = df.duration.dt.total_seconds() / 60
Expand All @@ -43,22 +58,21 @@ def prepare_data(df, categorical) -> pd.DataFrame:

return df


def load_model(filename):
    """Deserialize and return the pickled model artifact at *filename*.

    NOTE: pickle can execute arbitrary code on load — only use with
    trusted model files.
    """
    with open(filename, 'rb') as model_file:
        model = pickle.load(model_file)
    return model


def main(year, month):
# input_file = f'https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_{year:04d}-{month:02d}.parquet'
# output_file = f'output/yellow_tripdata_{year:04d}-{month:02d}.parquet'

input_file = get_input_path(year, month)
output_file = get_output_path(year, month)
model_file = 'model.bin'

categorical = ['PULocationID', 'DOLocationID']

df = read_data(filename=input_file, categorical=categorical)
df = read_data(filename=input_file)
df = prepare_data(df, categorical)
df['ride_id'] = f'{year:04d}/{month:02d}_' + df.index.astype('str')

dv, lr = load_model(filename=model_file)
Expand All @@ -68,12 +82,13 @@ def main(year, month):
y_pred = lr.predict(X_val)

print('predicted mean duration:', y_pred.mean())
print('predicted sum duration:', y_pred.sum())

df_result = pd.DataFrame()
df_result['ride_id'] = df['ride_id']
df_result['predicted_duration'] = y_pred

df_result.to_parquet(output_file, engine='pyarrow', index=False)
save_data(df_result, output_file)


if __name__ == '__main__':
Expand Down
37 changes: 0 additions & 37 deletions 06-best-practices/homework/integration_test.py

This file was deleted.

40 changes: 40 additions & 0 deletions 06-best-practices/homework/tests/integration_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@

import os
import pandas as pd
from datetime import datetime
from pandas.testing import assert_frame_equal
from batch import read_data, save_data, get_output_path, get_input_path

def dt(hour, minute, second=0):
    """Shorthand for a timestamp on 2023-01-01 at the given time of day."""
    return datetime(2023, 1, 1, hour=hour, minute=minute, second=second)


def test_s3_prediction():
    """End-to-end check: seed input data, run batch.py, verify predictions.

    Requires INPUT_FILE_PATTERN / OUTPUT_FILE_PATTERN / S3_ENDPOINT_URL to
    point at a reachable (e.g. localstack) S3 — see .env.example.
    """
    # Arrange: four synthetic rides on 2023-01-01; filtering in batch.py is
    # expected to keep only the first two (duration within bounds).
    input_columns = [
        'PULocationID',
        'DOLocationID',
        'tpep_pickup_datetime',
        'tpep_dropoff_datetime',
    ]
    rows = [
        (None, None, dt(1, 1), dt(1, 10)),
        (1, 1, dt(1, 2), dt(1, 10)),
        (1, None, dt(1, 2, 0), dt(1, 2, 59)),
        (3, 4, dt(1, 2, 0), dt(2, 2, 1)),
    ]
    df_input = pd.DataFrame(rows, columns=input_columns)
    save_data(df_input, get_input_path(2023, 1))

    # Act: run the batch scoring script for January 2023.
    # Check the exit status — otherwise a failing script only surfaces later
    # as a confusing read error or frame mismatch.
    exit_code = os.system("python batch.py 2023 1")
    assert exit_code == 0, f"batch.py exited with status {exit_code}"

    # Assert: the persisted predictions match the expected ride ids/durations.
    df_actual = read_data(get_output_path(2023, 1))
    df_expected = pd.DataFrame(
        [
            ('2023/01_0', 23.197149),
            ('2023/01_1', 13.080101),
        ],
        columns=['ride_id', 'predicted_duration'],
    )
    assert_frame_equal(df_actual, df_expected)

0 comments on commit d83a041

Please sign in to comment.