main.py
import os
import subprocess
import sys
import time
from typing import List
from pathlib import Path
from datetime import datetime, timedelta
from collections import namedtuple
import boto3

TIME_FORMAT = '%Y:%m:%d:%H:%M:%S'

Interval = namedtuple('Interval', ['name', 'max_backups'])
INTERVALS = [
    Interval('day', 7),
    Interval('week', 5),
    Interval('month', 12),
    Interval('year', 4),
]
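# each interval keeps at most `max_backups` files; once the count exceeds the
# limit, the oldest backup in that interval's folder is pruned
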
CURRENT_FOLDER = Path(__file__).parent.absolute()
DB_HOST = os.environ.get('DB_HOST')
DB_NAME = os.environ.get('DB_NAME')
DB_PORT = os.environ.get('DB_PORT')
DB_USER = os.environ.get('DB_USER')
DB_PASS = os.environ.get('DB_PASS')
DEBUG = os.environ.get('DEBUG_VALUE')
BOTO_SESSION = boto3.session.Session()
ACCESS_KEY = os.environ.get('DBBACKUP_ACCESS_KEY')
SECRET_KEY = os.environ.get('DBBACKUP_SECRET_KEY')
BUCKET_NAME = os.environ.get('DBBACKUP_BUCKET_NAME')
ENDPOINT_URL = os.environ.get('DBBACKUP_ENDPOINT_URL')
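
# example configuration (values are hypothetical, for illustration only):
#   export DB_HOST=localhost DB_NAME=app DB_PORT=5432 DB_USER=postgres
#   export DBBACKUP_BUCKET_NAME=my-db-backups
#   export DBBACKUP_ENDPOINT_URL=https://nyc3.digitaloceanspaces.com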


class Storage:
    def __init__(self, access_key, secret_key, bucket_name, endpoint_url):
        if not (access_key and secret_key and bucket_name and endpoint_url):
            # raise an error if any s3-related env var is missing
            raise Exception('S3 Bucket Is Not Configured')
        self.client = BOTO_SESSION.client(
            's3',
            region_name='nyc3',
            endpoint_url=endpoint_url,
            aws_access_key_id=access_key,
            aws_secret_access_key=secret_key
        )
        self.bucket_name = bucket_name

    def write_file(self, absolute_file_path: str, file_path_in_s3_bucket: str):
        self.client.upload_file(absolute_file_path, self.bucket_name, file_path_in_s3_bucket)

    def delete_file(self, file_path_in_s3_bucket: str):
        self.client.delete_object(
            Bucket=self.bucket_name,
            Key=file_path_in_s3_bucket,
        )
    def list_directory(self, directory_path: str) -> list:
        response = self.client.list_objects_v2(
            Bucket=self.bucket_name,
            Prefix=directory_path
        )
        if response['KeyCount'] == 0:
            return []
        # return bare file names (not full keys) so callers can join them back
        # onto the env/interval prefix without duplicating it
        return [Path(f['Key']).name for f in response['Contents']]
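
# a minimal usage sketch (bucket name and key are hypothetical):
#   storage = Storage(ACCESS_KEY, SECRET_KEY, 'my-db-backups', ENDPOINT_URL)
#   storage.write_file('/tmp/dumped.psql', 'prod/day/2024:01:01:00:00:00.psql')
#   storage.list_directory('prod/day')  # -> ['2024:01:01:00:00:00.psql']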


class BaseCommand:
    @staticmethod
    def print_error(message: str):
        print(message, file=sys.stderr)

    @staticmethod
    def print_info(message: str):
        print(message, file=sys.stdout)


class DBConnector:
    def __init__(self, host, name, port, user, password):
        self.settings = {
            'HOST': host,
            'NAME': name,
            'PORT': port,
            'USER': user,
            'PASSWORD': password,
        }

    def dump(self, output_file_path):
        # use the connection settings passed to the constructor rather than
        # reading the module-level environment globals again
        cmd = (
            f'PGPASSWORD="{self.settings["PASSWORD"]}" '
            f'runuser -u {self.settings["USER"]} -- '
            f'pg_dump -U {self.settings["USER"]} -h {self.settings["HOST"]} '
            f'{self.settings["NAME"]} > {output_file_path}'
        )
        result = subprocess.run(
            cmd, capture_output=True, shell=True, timeout=60,
        )
        # forward the error when the command fails; checking the return code
        # avoids treating pg_dump warnings on stderr as fatal
        if result.returncode != 0:
            error_message = f'Error Dumping File: {result.stderr}'
            raise Exception(error_message)


class Command(BaseCommand):
    help = 'Runs backup code'
    storage = Storage(ACCESS_KEY, SECRET_KEY, BUCKET_NAME, ENDPOINT_URL)
    db = DBConnector(DB_HOST, DB_NAME, DB_PORT, DB_USER, DB_PASS)
    env = 'prod' if DEBUG == 'False' else 'dev'
    # the dump is written here once per run and reused for every interval
    dumped_db_file_path = os.path.join(CURRENT_FOLDER, 'dumped_files', 'dumped.psql')

    @staticmethod
    def truncate_datetime(dt: datetime, interval_name: str) -> datetime:
        """
        Rounds datetime precision down to the nearest year, month, week, or day.
        """
        if interval_name == 'year':
            return datetime(year=dt.year, month=1, day=1)
        elif interval_name == 'month':
            return datetime(year=dt.year, month=dt.month, day=1)
        elif interval_name == 'week':
            # note: the start of the week may fall in the previous month
            days_after_start_of_week = dt.weekday()
            return datetime(year=dt.year, month=dt.month, day=dt.day) - timedelta(days_after_start_of_week)
        elif interval_name == 'day':
            return datetime(year=dt.year, month=dt.month, day=dt.day)
        else:
            raise Exception('Invalid Interval Name')
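    # for example (2024-05-15 is a Wednesday, so its week truncates to Monday):
    #   truncate_datetime(datetime(2024, 5, 15, 9, 30), 'week')  -> datetime(2024, 5, 13)
    #   truncate_datetime(datetime(2024, 5, 15, 9, 30), 'month') -> datetime(2024, 5, 1)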

    @staticmethod
    def remove_dumped_file(dumped_file_path):
        if os.path.exists(dumped_file_path):
            os.remove(dumped_file_path)

    def create_backup(self, s3_file_path: str) -> bool:
        # returns True when the backup was created and False otherwise
        try:
            # create the db dump only when the file doesn't already exist,
            # so a single dump is shared by every interval in this run
            if not os.path.isfile(self.dumped_db_file_path):
                os.makedirs(os.path.dirname(self.dumped_db_file_path), exist_ok=True)
                self.db.dump(self.dumped_db_file_path)
            self.storage.write_file(self.dumped_db_file_path, s3_file_path)
            return True
        except Exception as e:
            # remove the dumped file since an error happened;
            # the file is likely empty or truncated
            self.remove_dumped_file(self.dumped_db_file_path)
            self.print_error(f'Error Creating File {s3_file_path}: {e}')
            return False

    def should_save_new_file(self, interval: Interval, files: List[str]) -> bool:
        if len(files) == 0:
            return True
        most_recent_file_name = Path(files[-1]).stem
        most_recent_file_datetime = datetime.strptime(most_recent_file_name, TIME_FORMAT)
        truncated_recent_file_time = self.truncate_datetime(most_recent_file_datetime, interval.name)
        truncated_current_time = self.truncate_datetime(datetime.now(), interval.name)
        # save a new file only when the current period differs from the
        # period of the most recent backup
        return truncated_recent_file_time != truncated_current_time
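    # e.g. for the 'week' interval: if the newest backup and the current time
    # truncate to the same Monday, the week already has a backup and is skipped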

    def remove_oldest_file(self, interval: Interval, files_in_folder):
        # when the folder is empty, there is nothing to remove
        if len(files_in_folder) == 0:
            return
        # when the file count hasn't reached the limit, keep everything
        if len(files_in_folder) <= interval.max_backups:
            return
        # the timestamped names sort lexicographically in chronological order,
        # so the first entry is the oldest backup
        oldest_file_name = files_in_folder[0]
        oldest_file_path = os.path.join(self.env, interval.name, oldest_file_name)
        try:
            self.storage.delete_file(oldest_file_path)
            self.print_info(f'Deleted Old Backup For {interval.name} Named {oldest_file_name}')
        except Exception:
            self.print_error(f'Error Deleting File {oldest_file_name}')

    def job(self):
        for interval in INTERVALS:
            try:
                path = os.path.join(self.env, interval.name)
                files: list = self.storage.list_directory(directory_path=path)
                files.sort()
            except Exception:
                self.print_error('Error Getting Files')
                files = []
            if self.should_save_new_file(interval, files):
                file_name = f'{time.strftime(TIME_FORMAT)}.psql'
                file_path = os.path.join(self.env, interval.name, file_name)
                self.print_info(f'Creating Backup For {interval.name} Named {file_name}')
                creation_successful = self.create_backup(file_path)
                if creation_successful:
                    files.append(file_name)
            self.remove_oldest_file(interval, files)
        # remove the local dump so the next run produces a fresh one
        self.remove_dumped_file(self.dumped_db_file_path)
        self.print_info('Backup Script Complete')

    def run(self):
        self.print_info('Running Backup Script')
        self.job()


if __name__ == '__main__':
    c = Command()
    c.run()
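
# a typical deployment (hypothetical schedule) runs this script once a day,
# e.g. from cron, with the environment variables above exported beforehand:
#   0 3 * * * python3 /path/to/main.py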