-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
87 lines (68 loc) · 2.67 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path
from data.degrade_images import degrade_image, DegradationType
from algorithm.dictionary import Dictionary, DictionaryType
from algorithm.sparse_solver import SparseSolver
from algorithm.statistics import calculate_psnr
# Input image for the demo; expected to be a grayscale image.
INPUT_FILE = dict(
    file_path=Path("./data/sample_image/image.png"),
)

# Experiment configuration. check_config() completes the "sparse_model"
# section in place before the pipeline uses it.
CONFIG = dict(
    image_degradation=dict(
        noise_sigma=25,
        random_seed=41,
        degradation_type=DegradationType.NOISE,
    ),
    sparse_model=dict(
        patch_size=(10, 10),  # patches must be square
        initial_dict=DictionaryType.DCT,
        # False -> no dictionary learning; the predefined initial dictionary is used.
        enable_dictionary_learning=False,
        # Optional keys (example values):
        #   num_learning_iterations=30 -> must be supplied when enable_dictionary_learning
        #                                 is True (check_config only sets it to 0 otherwise)
        #   epsilon=210                -> threshold passed to SparseSolver (presumably a
        #                                 residual tolerance); when not supplied, check_config
        #                                 derives it from noise_sigma and patch_size
        verbose=True,
    ),
)
def check_config(config: dict):
    """Complete the ``"sparse_model"`` section of *config* in place.

    - ``epsilon``: when not supplied, set to
      ``sqrt(1.1) * patch_size[0] * image_degradation.noise_sigma``.
    - ``num_learning_iterations``: forced to 0 when dictionary learning is
      disabled; left untouched otherwise.

    Args:
        config: Configuration dict with ``"sparse_model"`` and
            ``"image_degradation"`` sections. It is mutated in place.
    """
    sparse_cfg = config["sparse_model"]

    # Look for a user-supplied epsilon inside the "sparse_model" section, where it
    # is stored and read. Checking the top-level dict would always miss it and
    # silently overwrite the user's value.
    if "epsilon" not in sparse_cfg:
        # Presumably the noise is i.i.d. with std noise_sigma. For a square n x n
        # patch the noise norm is then about n * sigma. The sqrt(1.1) (~1.049)
        # factor adds a small slack on top of that.
        sparse_cfg["epsilon"] = (
            np.sqrt(1.1)
            * sparse_cfg["patch_size"][0]
            * config["image_degradation"]["noise_sigma"]
        )

    if not sparse_cfg["enable_dictionary_learning"]:
        sparse_cfg["num_learning_iterations"] = 0
def _plot_comparison(original, degraded, reconstructed):
    """Show the original, degraded and reconstructed images side by side.

    The degraded and reconstructed panels are titled with their PSNR,
    computed by ``calculate_psnr`` against *original*.
    """
    panels = [
        (original, "Original image"),
        (degraded, "Degraded image PSNR={:.3f}".format(calculate_psnr(original, degraded))),
        (reconstructed, "Reconstructed image PSNR={:.3f}".format(calculate_psnr(original, reconstructed))),
    ]

    plt.figure(figsize=(14, 4))
    for position, (image, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image, "gray")
        plt.axis("off")
        plt.title(title)
    plt.show()


def run_scripts(input_file: dict, config: dict):
    """Run the demo: load, degrade, reconstruct with a sparse model, then plot.

    Args:
        input_file: Dict whose ``"file_path"`` entry points to the input image
            (assumed grayscale; no colour conversion is done here).
        config: Configuration with ``"image_degradation"`` and ``"sparse_model"``
            sections. It is completed in place by ``check_config``.
    """
    check_config(config)
    sparse_cfg = config["sparse_model"]

    # Copy the pixels into a NumPy array inside the context manager so the
    # underlying file handle opened by PIL is released afterwards.
    with Image.open(input_file["file_path"]) as pil_image:
        img = np.array(pil_image)

    degraded_img = degrade_image(img, config["image_degradation"])

    dictionary = Dictionary(
        dictionary_type=sparse_cfg["initial_dict"], patch_size=sparse_cfg["patch_size"]
    )
    sparse_solver = SparseSolver(
        enable_dictionary_learning=sparse_cfg["enable_dictionary_learning"],
        num_learning_iterations=sparse_cfg["num_learning_iterations"],
        img=degraded_img,
        dictionary=dictionary,
        epsilon=sparse_cfg["epsilon"],
        verbose=sparse_cfg["verbose"],
    )
    reconstructed_img = sparse_solver()

    if sparse_cfg["verbose"]:
        dictionary.show_dictionary()

    _plot_comparison(img, degraded_img, reconstructed_img)
if __name__ == "__main__":
    # Script entry point: run the demo with the module-level defaults.
    run_scripts(INPUT_FILE, CONFIG)