-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathq1b_gibbs.py
111 lines (90 loc) · 2.87 KB
/
q1b_gibbs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import matplotlib.pyplot as plt
import random
def initializeDist():
    """Build the table of pairwise weights used by the sampler.

    Keys are ``(state, state)`` tuples of variable-state labels such as
    ``("a0", "b1")``; values are the non-negative weights attached to that
    pair.  Only pairs involving ``b1`` are listed for B.

    NOTE(review): these look like the edge potentials of a Markov network
    over A, B, C, D with B observed as b1 -- confirm against the assignment.

    Returns:
        dict: a freshly created mapping from state pairs to weights.
    """
    return {
        # A-B pairs (only b1 appears)
        ("a0", "b1"): 5,
        ("a1", "b1"): 10,
        # B-C pairs (only b1 appears)
        ("b1", "c0"): 1,
        ("b1", "c1"): 100,
        # C-D pairs
        ("c0", "d0"): 1,
        ("c0", "d1"): 100,
        ("c1", "d0"): 100,
        ("c1", "d1"): 1,
        # A-D pairs
        ("a0", "d0"): 100,
        ("a0", "d1"): 1,
        ("a1", "d0"): 1,
        ("a1", "d1"): 100,
    }
def normalize(probT, probF):
    """Scale two non-normalized weights so that they sum to one.

    Args:
        probT: weight of the "true" outcome.
        probF: weight of the "false" outcome.

    Returns:
        tuple: ``(probT / total, probF / total)`` as floats, where
        ``total = probT + probF``.

    Raises:
        ZeroDivisionError: if the two weights sum to zero.
    """
    denominator = float(probT + probF)
    shareT = float(probT) / denominator
    shareF = float(probF) / denominator
    return shareT, shareF
def getProb(numTrues, numSamples):
    """Return the fraction ``numTrues / numSamples`` as a float.

    Args:
        numTrues: count of samples in which the event held.
        numSamples: total number of samples drawn.

    Raises:
        ZeroDivisionError: if ``numSamples`` is zero.
    """
    hits = float(numTrues)
    draws = float(numSamples)
    return hits / draws
def sample(n):
    """Run ``n`` single-variable Gibbs updates and track the estimate of A = a1.

    B is held fixed at ``"b1"``.  The remaining variables start at
    ``a1, c1, d1`` and are resampled one per iteration in the fixed order
    A, C, D, A, C, D, ...  Each update draws the variable from its
    conditional given the current values of the variables it shares a table
    entry with (the product of the two relevant weights, normalized).

    After every update the current value of A is counted, so ``results[i]``
    is the fraction of the first ``i + 1`` iterations in which A was ``a1``.
    No burn-in samples are discarded.

    Args:
        n: number of update steps to perform.

    Returns:
        list: ``n`` running estimates, one per iteration (empty if
        ``n <= 0``).
    """
    B = "b1"  # evidence: never resampled
    a, c, d = "a1", "c1", "d1"
    pTable = initializeDist()
    results = []
    numTrueA = 0

    for i in range(n):
        # Exactly one variable is resampled per iteration, cycling A -> C -> D.
        turn = i % 3
        if turn == 0:
            weightTrue = pTable["a1", B] * pTable["a1", d]
            weightFalse = pTable["a0", B] * pTable["a0", d]
        elif turn == 1:
            weightTrue = pTable[B, "c1"] * pTable["c1", d]
            weightFalse = pTable[B, "c0"] * pTable["c0", d]
        else:
            weightTrue = pTable[a, "d1"] * pTable[c, "d1"]
            weightFalse = pTable[a, "d0"] * pTable[c, "d0"]

        _, probFalse = normalize(weightTrue, weightFalse)

        # One uniform draw per iteration; the "false" state is chosen when
        # the draw lands in [0, probFalse].
        draw = random.random()
        pickedFalse = draw <= probFalse
        if turn == 0:
            a = "a0" if pickedFalse else "a1"
        elif turn == 1:
            c = "c0" if pickedFalse else "c1"
        else:
            d = "d0" if pickedFalse else "d1"

        # Every iteration contributes one sample of the full state, so A's
        # current value is counted even when C or D was the one resampled.
        if a == "a1":
            numTrueA += 1
        results.append(getProb(numTrueA, i + 1))

    return results
def plotResults(data):
    """Plot the running Gibbs estimate against the reference value.

    Draws ``data`` as a curve, a horizontal red line at the hard-coded
    reference value labelled "Variable Elimination", and prints both the
    reference value and the final Gibbs estimate just past the right-hand
    end of the curve.  Opens the figure with ``plt.show()``.

    The input sequence is not modified.

    Args:
        data: sequence of running estimates, one per sampling iteration.
            May be empty, in which case only the reference line is drawn.
    """
    varElimA1 = 0.0566  # reference value shown as "Variable Elimination"

    # Anchor the value labels a few samples past the end of the curve so
    # they follow the amount of data instead of assuming 100000 points.
    labelX = len(data) + 5

    plt.plot(data, label="Gibbs Sampling")
    plt.axhline(y=varElimA1, color="red", label="Variable Elimination")
    plt.text(labelX, varElimA1 + 0.006, str(varElimA1), color="red")

    if len(data) > 0:
        # Read the final estimate without popping it off the caller's list.
        gibbsA1 = data[-1]
        plt.text(labelX, gibbsA1 - 0.007, str(gibbsA1), color="blue")

    plt.ylim(-0.05, 0.5)
    plt.xlabel("# of samples")
    plt.ylabel("P(a|b)")
    plt.title("Gibbs Sampling Results")
    plt.legend()
    plt.show()
if __name__ == "__main__":
    # Only run the experiment when executed as a script, so importing this
    # module does not trigger 100000 sampling steps and open a plot window.
    results = sample(100000)
    plotResults(results)