-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsplit-postgres-schema
executable file
·134 lines (121 loc) · 4.27 KB
/
split-postgres-schema
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#!/usr/bin/env python3
#
# split a postgres schema by tables, functions, etc
import os
import os.path
import re
import sys
def stderr(msg):
    """Print *msg* on standard error, prefixed with the program's basename."""
    program = os.path.basename(sys.argv[0])
    line = "%s: %s\n" % (program, msg)
    sys.stderr.write(line)
class SectionReader:
    """Split a pg_dump-style schema into one .sql file per section.

    A new section starts at every line matching one of the ``spec`` patterns
    (tried in order, anchored at the start of the line).  If the pattern's
    first capture group is non-empty the section is named ``<type>.<group1>``,
    otherwise just ``<type>``.  Lines before the first match belong to the
    ``db`` section.

    Section ``s`` is appended to ``<schema_name>/<s>.sql``.  The first time a
    section other than ``db`` is written, a ``\\i <s>.sql`` line is appended
    to ``db.sql``, so ``db.sql`` includes every other file in first-seen
    order.

    The ``--`` comment block directly above a section header, plus the blank
    lines between it and the header, move into the new section.  Blank lines
    before that comment block stay with the previous section.
    """

    def __init__(self, schema_name, f, spec):
        """
        schema_name -- output directory; created if missing, and any .sql
                       files already in it are deleted
        f           -- iterable of input lines (e.g. an open text file)
        spec        -- list of {'match': regex, 'type': name-prefix} dicts
        """
        self._schema_name = schema_name
        self._f = f
        self._linenum = 0
        self._spec = spec
        self._regex = [re.compile(s['match']) for s in spec]
        self._section = "db"  # first section is db header
        self._included = set()  # sections already \i-included from db.sql
        self._buf = []  # lines of the current section not yet written
        self._blanks = []  # blank lines not attached to a comment
        self._comment = []  # pending "--" comment block
        self._post_comment_blanks = []  # blank lines after that comment
        self._init_outdir()

    def _path(self, section):
        """Return the output file path for *section*."""
        return os.path.join(self._schema_name, "%s.sql" % section)

    def _init_outdir(self):
        """Create the output directory and delete any .sql files in it.

        Section files are opened in append mode, so leftovers from a previous
        run would otherwise be concatenated with the new output.
        """
        os.makedirs(self._schema_name, exist_ok=True)
        for filename in os.listdir(self._schema_name):
            if filename.endswith('.sql'):
                os.unlink(os.path.join(self._schema_name, filename))

    def _flush_blanks(self):
        """Move pending comment-less blank lines into the section buffer."""
        self._buf.extend(self._blanks)
        self._blanks = []

    def _flush_comment(self):
        """Move pending blanks, comment block and trailing blanks into the buffer."""
        self._buf.extend(self._blanks)
        self._buf.extend(self._comment)
        self._buf.extend(self._post_comment_blanks)
        self._blanks = []
        self._comment = []
        self._post_comment_blanks = []

    def _write(self):
        """Append the buffered lines to the current section's file.

        The first time a non-``db`` section is written, an include line for it
        is also appended to ``db.sql``.
        """
        if self._section is None:
            stderr("WARNING: flushing buf with no section at line %d" % self._linenum)
            return
        with open(self._path(self._section), 'a') as sql_f:
            sql_f.writelines(self._buf)
        self._buf = []
        if self._section != 'db' and self._section not in self._included:
            # include this file in the overall schema
            with open(self._path('db'), 'a') as sql_f:
                sql_f.write("\\i %s.sql\n" % self._section)
            self._included.add(self._section)

    def _section_for(self, line):
        """Return the name of the section *line* starts, or None if none."""
        for regex, entry in zip(self._regex, self._spec):
            m = regex.match(line)
            if m:
                # a pattern without a capture group names the section by type only
                name = m.group(1) if regex.groups else None
                if name:
                    return "%s.%s" % (entry['type'], name)
                return "%s" % (entry['type'],)
        return None

    def read(self):
        """Consume all input lines and write every section to disk."""
        for line in self._f:
            self._linenum += 1
            if line.startswith('--'):
                # a comment that followed an earlier comment + blank lines:
                # the earlier block is not attached to what comes next
                if self._comment and self._post_comment_blanks:
                    self._flush_comment()
                self._comment.append(line)
            elif line == '\n':
                if self._comment:
                    self._post_comment_blanks.append(line)
                else:
                    self._blanks.append(line)
            else:
                section = self._section_for(line)
                if section is not None:
                    # comment-less blanks close the previous section; the
                    # pending comment (flushed below) opens the new one
                    self._flush_blanks()
                    self._write()
                    self._section = section
                self._flush_comment()
                self._buf.append(line)
        self._flush_comment()
        self._write()
# Section boundaries recognised in the dump, tried in order.  A non-empty
# group 1 becomes part of the output file name (e.g. tbl.users.sql); an empty
# group means the section is named by its type alone.
SECTION_SPEC = [
    {
        "match": r"CREATE FUNCTION (\w+)",
        "type": "fn",
    },
    {
        "match": r"CREATE AGGREGATE (\w+)",
        "type": "agg",
    },
    {
        "match": r'CREATE TABLE "?(\w+)"?',
        "type": "tbl",
    },
    {
        "match": r"ALTER TABLE ONLY (\w+)",
        "type": "alt",
    },
    {
        "match": r"REVOKE ALL ON SCHEMA()",
        "type": "acl-schema",
    },
    {
        "match": r"REVOKE ALL ON TABLE (\w+)",
        "type": "acl",
    },
    {
        "match": r"SET ()",
        "type": "db",
    },
]

if __name__ == "__main__":
    if len(sys.argv) != 2:
        stderr("usage: %s <schema.sql>" % os.path.basename(sys.argv[0]))
        sys.exit(1)
    # Output directory is the input's basename without extension,
    # created relative to the current directory: /path/foo.sql -> ./foo
    # NOTE: keep this assignment at module scope -- code elsewhere in this
    # file may rely on the global `schema_name`.
    schema_name = os.path.splitext(os.path.basename(sys.argv[1]))[0]
    try:
        schema_f = open(sys.argv[1], 'r')
    except OSError as e:
        stderr("cannot open %s: %s" % (sys.argv[1], e.strerror or e))
        sys.exit(1)
    with schema_f:
        sr = SectionReader(schema_name, schema_f, SECTION_SPEC)
        sr.read()