# -*- encoding: utf-8 -*-
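"""
Stroke Width Transform (SWT) text detector.

Based on the pipeline of Epshtein, Ofek, and Wexler, "Detecting Text in
Natural Scenes with Stroke Width Transform" (CVPR 2010): compute a per-pixel
stroke width from Canny edges and image gradients, group pixels of similar
stroke width into letter candidates, and chain candidates into words.

Forked from mypetyak/StrokeWidthTransform.
"""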
from collections import defaultdict
import hashlib
import math
import time
from urllib.request import urlopen

import cv2
import numpy as np
import scipy.spatial

t0 = time.perf_counter()
diagnostics = True
class SWTScrubber:
@classmethod
def scrub(cls, filepath):
"""
Apply Stroke-Width Transform to image.
:param filepath: relative or absolute filepath to source image
:return: numpy array representing result of transform
"""
canny, sobelx, sobely, theta = cls._create_derivative(filepath)
swt = cls._swt(theta, canny, sobelx, sobely)
shapes = cls._connect_components(swt)
swts, heights, widths, topleft_pts, images = cls._find_letters(swt, shapes)
word_images = cls._find_words(swts, heights, widths, topleft_pts, images)
final_mask = np.zeros(swt.shape)
for word in word_images:
final_mask += word
return final_mask
@classmethod
def _create_derivative(cls, filepath):
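        """
        Load the image as grayscale and compute its Canny edge map, Sobel
        derivatives, and per-pixel gradient direction.
        :param filepath: relative or absolute filepath to source image
        :return: (edges, sobelx64f, sobely64f, theta) numpy arrays
        """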
        # load as grayscale and detect edges with Canny
        img = cv2.imread(filepath, 0)
        edges = cv2.Canny(img, 175, 320, apertureSize=3)
        # create gradient maps using the Sobel operator
        # (ksize=-1 selects the 3x3 Scharr kernel)
        sobelx64f = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=-1)
        sobely64f = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=-1)
        theta = np.arctan2(sobely64f, sobelx64f)
if diagnostics:
cv2.imwrite('edges.jpg',edges)
cv2.imwrite('sobelx64f.jpg', np.absolute(sobelx64f))
cv2.imwrite('sobely64f.jpg', np.absolute(sobely64f))
            # rescale theta from [-pi, pi] to [0, 255] for visual inspection
            theta_visible = (theta + np.pi) * 255 / (2 * np.pi)
cv2.imwrite('theta.jpg', theta_visible)
return (edges, sobelx64f, sobely64f, theta)
@classmethod
    def _swt(cls, theta, edges, sobelx64f, sobely64f):
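        """
        Compute the Stroke Width Transform.

        For every Canny edge pixel, march a ray along the gradient direction
        until it hits an opposing edge; if the two edges face each other, every
        pixel on the ray receives the ray length as its stroke width (keeping
        the minimum across rays). A final pass caps each ray at its median.
        :return: numpy array of stroke widths, np.inf where none was found
        """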
        # create the stroke-width image, initialized to infinity
        swt = np.full(theta.shape, np.inf)
        rays = []
        print(time.perf_counter() - t0)
        # now iterate over pixels in image, checking Canny to see if we're on an edge;
        # if we are, follow a ray along the gradient to either the opposing edge or
        # the image border
        step_x_g = -1 * sobelx64f
        step_y_g = -1 * sobely64f
        mag_g = np.sqrt(step_x_g * step_x_g + step_y_g * step_y_g)
        # unit gradient vectors; silence divide-by-zero warnings where the
        # magnitude is zero (Canny never marks zero-gradient pixels as edges,
        # so those entries are never read)
        with np.errstate(divide='ignore', invalid='ignore'):
            grad_x_g = step_x_g / mag_g
            grad_y_g = step_y_g / mag_g
        for x in range(edges.shape[1]):
            for y in range(edges.shape[0]):
if edges[y, x] > 0:
step_x = step_x_g[y, x]
step_y = step_y_g[y, x]
mag = mag_g[y, x]
grad_x = grad_x_g[y, x]
grad_y = grad_y_g[y, x]
ray = []
ray.append((x, y))
prev_x, prev_y, i = x, y, 0
while True:
i += 1
                        cur_x = int(math.floor(x + grad_x * i))
                        cur_y = int(math.floor(y + grad_y * i))
if cur_x != prev_x or cur_y != prev_y:
# we have moved to the next pixel!
                            # negative indices would silently wrap around in
                            # numpy instead of raising IndexError, so treat
                            # them as leaving the image
                            if cur_x < 0 or cur_y < 0:
                                break
                            try:
                                if edges[cur_y, cur_x] > 0:
                                    # found an opposing edge
                                    ray.append((cur_x, cur_y))
                                    # accept the ray only when the opposing
                                    # edge's gradient points roughly opposite
                                    # ours (within pi/2); clamp the dot product
                                    # to [-1, 1] so float error cannot push it
                                    # outside acos's domain
                                    dot = grad_x * -grad_x_g[cur_y, cur_x] + grad_y * -grad_y_g[cur_y, cur_x]
                                    if math.acos(max(-1.0, min(1.0, dot))) < np.pi / 2.0:
                                        thickness = math.sqrt((cur_x - x) * (cur_x - x) + (cur_y - y) * (cur_y - y))
for (rp_x, rp_y) in ray:
swt[rp_y, rp_x] = min(thickness, swt[rp_y, rp_x])
rays.append(ray)
break
# this is positioned at end to ensure we don't add a point beyond image boundary
ray.append((cur_x, cur_y))
except IndexError:
# reached image boundary
break
prev_x = cur_x
prev_y = cur_y
        # cap each ray's pixels at the ray's median SWT; this corrects corners,
        # where perpendicular rays overestimate the stroke width
        for ray in rays:
            median = np.median([swt[y, x] for (x, y) in ray])
            for (x, y) in ray:
                swt[y, x] = min(median, swt[y, x])
        if diagnostics:
            # zero out unfilled (infinite) pixels before writing the diagnostic
            cv2.imwrite('swt.jpg', np.where(np.isinf(swt), 0, swt * 100))
return swt
@classmethod
def _connect_components(cls, swt):
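        """
        Group adjacent pixels with similar stroke widths into connected
        components via a raster-scan union-find labeling.
        :param swt: stroke-width image produced by _swt
        :return: dict mapping each component label to a binary mask layer
        """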
# STEP: Compute distinct connected components
# Implementation of disjoint-set
        class Label:
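            """
            Disjoint-set node: each label points at a parent, and the root of
            the tree is the representative for its whole connected component.
            """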
def __init__(self, value):
self.value = value
self.parent = self
self.rank = 0
            def __eq__(self, other):
                if type(other) is type(self):
                    return self.value == other.value
                return False
ld = {}
def MakeSet(x):
try:
return ld[x]
except KeyError:
item = Label(x)
ld[x] = item
return item
        def Find(item):
            if item.parent != item:
                # path compression: point every node on the walk at the root
                item.parent = Find(item.parent)
            return item.parent
        def Union(x, y):
            """
            Merge the trees containing x and y, using union by rank.
            :param x: first Label
            :param y: second Label
            :return: root node of new union tree
            """
x_root = Find(x)
y_root = Find(y)
if x_root == y_root:
return x_root
if x_root.rank < y_root.rank:
x_root.parent = y_root
return y_root
elif x_root.rank > y_root.rank:
y_root.parent = x_root
return x_root
else:
y_root.parent = x_root
x_root.rank += 1
return x_root
# apply Connected Component algorithm, comparing SWT values.
# components with a SWT ratio less extreme than 1:3 are assumed to be
# connected. Apply twice, once for each ray direction/orientation, to
# allow for dark-on-light and light-on-dark texts
trees = {}
# Assumption: we'll never have more than 65535-1 unique components
label_map = np.zeros(shape=swt.shape, dtype=np.uint16)
next_label = 1
# First Pass, raster scan-style
swt_ratio_threshold = 3.0
        for y in range(swt.shape[0]):
            for x in range(swt.shape[1]):
                sw_point = swt[y, x]
                if 0 < sw_point < np.inf:
neighbors = [(y, x-1), # west
(y-1, x-1), # northwest
(y-1, x), # north
(y-1, x+1)] # northeast
connected_neighbors = None
neighborvals = []
                    for neighbor in neighbors:
try:
sw_n = swt[neighbor]
label_n = label_map[neighbor]
except IndexError:
continue
if label_n > 0 and sw_n / sw_point < swt_ratio_threshold and sw_point / sw_n < swt_ratio_threshold:
neighborvals.append(label_n)
if connected_neighbors:
connected_neighbors = Union(connected_neighbors, MakeSet(label_n))
else:
connected_neighbors = MakeSet(label_n)
if not connected_neighbors:
# We don't see any connections to North/West
                        trees[next_label] = MakeSet(next_label)
label_map[y, x] = next_label
next_label += 1
else:
# We have at least one connection to North/West
label_map[y, x] = min(neighborvals)
# For each neighbor, make note that their respective connected_neighbors are connected
# for label in connected_neighbors. @todo: do I need to loop at all neighbor trees?
trees[connected_neighbors.value] = Union(trees[connected_neighbors.value], connected_neighbors)
        # Second pass: re-label each pixel with the representative label of its connected tree
layers = {}
contours = defaultdict(list)
        for x in range(swt.shape[1]):
            for y in range(swt.shape[0]):
if label_map[y, x] > 0:
item = ld[label_map[y, x]]
common_label = Find(item).value
label_map[y, x] = common_label
contours[common_label].append([x, y])
try:
layer = layers[common_label]
except KeyError:
layers[common_label] = np.zeros(shape=swt.shape, dtype=np.uint16)
layer = layers[common_label]
layer[y, x] = 1
return layers
@classmethod
def _find_letters(cls, swt, shapes):
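        """
        Filter connected components down to plausible letter candidates using
        size, aspect-ratio, stroke-width, and image-coverage heuristics.
        :param swt: stroke-width image
        :param shapes: component label -> binary mask, from _connect_components
        :return: (swts, heights, widths, topleft_pts, images) of the survivors,
                 with stroke widths and heights stored as log2 values
        """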
# STEP: Discard shapes that are probably not letters
swts = []
heights = []
widths = []
topleft_pts = []
images = []
        for label, layer in shapes.items():
(nz_y, nz_x) = np.nonzero(layer)
            east, west, south, north = max(nz_x), min(nz_x), max(nz_y), min(nz_y)
            width, height = east - west, south - north
            # discard tiny components
            if width < 8 or height < 8:
                continue
            # discard extremely elongated components
            if width / height > 10 or height / width > 10:
                continue
            diameter = math.sqrt(width * width + height * height)
            median_swt = np.median(swt[(nz_y, nz_x)])
            # discard components whose size is out of proportion to their stroke width
            if diameter / median_swt > 10:
                continue
            # discard components spanning too large a fraction of the image
            if width / layer.shape[1] > 0.4 or height / layer.shape[0] > 0.4:
                continue
            if diagnostics:
                print('layer {} written to image.'.format(label))
                cv2.imwrite('layer' + str(label) + '.jpg', layer * 255)
# we use log_base_2 so we can do linear distance comparison later using k-d tree
# ie, if log2(x) - log2(y) > 1, we know that x > 2*y
# Assumption: we've eliminated anything with median_swt == 1
            swts.append([math.log2(median_swt)])
            heights.append([math.log2(height)])
topleft_pts.append(np.asarray([north, west]))
widths.append(width)
images.append(layer)
return swts, heights, widths, topleft_pts, images
@classmethod
def _find_words(cls, swts, heights, widths, topleft_pts, images):
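        """
        Pair letter candidates with similar stroke width and height that lie
        close together, then chain pairs of similar orientation into words.
        :return: list of letter masks belonging to chains of more than three letters
        """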
# Find all shape pairs that have similar median stroke widths
        print('SWTS')
        print(swts)
        print('DONESWTS')
        # bail out if no letter candidates survived filtering
        if not swts:
            return []
        # pair letters whose log2 stroke widths are within 1 of each other,
        # i.e. whose median stroke widths are within a factor of two
        swt_tree = scipy.spatial.KDTree(np.asarray(swts))
        stp = swt_tree.query_pairs(1)
        # likewise pair letters whose heights are within a factor of two
        height_tree = scipy.spatial.KDTree(np.asarray(heights))
        htp = height_tree.query_pairs(1)
        # keep only the pairings valid on both criteria
        isect = htp.intersection(stp)
chains = []
pairs = []
pair_angles = []
for pair in isect:
left = pair[0]
right = pair[1]
                widest = max(widths[left], widths[right])
                distance = np.linalg.norm(topleft_pts[left] - topleft_pts[right])
                # keep pairs that sit close together relative to letter width
                if distance < widest * 3:
                    delta_yx = topleft_pts[left] - topleft_pts[right]
                    # fold the pair's orientation into [0, pi) so that
                    # direction along the line does not matter
                    angle = np.arctan2(delta_yx[0], delta_yx[1])
                    if angle < 0:
                        angle += np.pi
                    pairs.append(pair)
                    pair_angles.append(np.asarray([angle]))
        # chain together pairs that share a letter and whose orientations
        # agree to within pi/12
        if not pair_angles:
            return []
        angle_tree = scipy.spatial.KDTree(np.asarray(pair_angles))
        atp = angle_tree.query_pairs(np.pi / 12)
for pair_idx in atp:
pair_a = pairs[pair_idx[0]]
pair_b = pairs[pair_idx[1]]
left_a = pair_a[0]
right_a = pair_a[1]
left_b = pair_b[0]
right_b = pair_b[1]
# @todo - this is O(n^2) or similar, extremely naive. Use a search tree.
added = False
for chain in chains:
if left_a in chain:
chain.add(right_a)
added = True
elif right_a in chain:
chain.add(left_a)
added = True
if not added:
chains.append(set([left_a, right_a]))
added = False
for chain in chains:
if left_b in chain:
chain.add(right_b)
added = True
elif right_b in chain:
chain.add(left_b)
added = True
if not added:
chains.append(set([left_b, right_b]))
        word_images = []
        # keep only chains containing more than three letters
        for chain in [c for c in chains if len(c) > 3]:
            for idx in chain:
                word_images.append(images[idx])
        return word_images
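# Demo driver: download a sample receipt image from Wikimedia Commons and run
# the full SWT pipeline on it, writing the detected-text mask to 'final.jpg'.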
if __name__ == '__main__':
    file_url = 'http://upload.wikimedia.org/wikipedia/commons/0/0b/ReceiptSwiss.jpg'
    # name the local copy after a hash of its source URL
    local_filename = hashlib.sha224(file_url.encode('utf-8')).hexdigest()
    response = urlopen(file_url)
    try:
        with open(local_filename, 'wb') as destination:
            while True:
                # read the file in 4 kB chunks
                chunk = response.read(4096)
                if not chunk:
                    break
                destination.write(chunk)
    finally:
        response.close()
    final_mask = SWTScrubber.scrub(local_filename)
    cv2.imwrite('final.jpg', final_mask * 255)
    print(time.perf_counter() - t0)