-
-
Notifications
You must be signed in to change notification settings - Fork 553
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Trac #33324: pycodestyle cleaning in discrete Fourier transforms
also fixing bugs in discrete Sine and Cosine transforms URL: https://trac.sagemath.org/33324 Reported by: chapoton Ticket author(s): Frédéric Chapoton Reviewer(s): Travis Scrimshaw
- Loading branch information
Showing
1 changed file
with
95 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,22 +64,19 @@ | |
- William Stein (2006-11) -- fix many bugs | ||
""" | ||
|
||
########################################################################## | ||
# Copyright (C) 2006 David Joyner <[email protected]> | ||
# | ||
# Distributed under the terms of the GNU General Public License (GPL): | ||
# | ||
# http://www.gnu.org/licenses/ | ||
# https://www.gnu.org/licenses/ | ||
########################################################################## | ||
|
||
from sage.rings.number_field.number_field import CyclotomicField | ||
from sage.plot.all import polygon, line, text | ||
from sage.groups.abelian_gps.abelian_group import AbelianGroup | ||
from sage.groups.perm_gps.permgroup_element import is_PermutationGroupElement | ||
from sage.rings.integer_ring import ZZ | ||
from sage.rings.integer import Integer | ||
from sage.arith.all import factor | ||
from sage.rings.rational_field import QQ | ||
from sage.rings.real_mpfr import RR | ||
from sage.functions.all import sin, cos | ||
|
@@ -89,6 +86,7 @@ | |
from sage.structure.sage_object import SageObject | ||
from sage.structure.sequence import Sequence | ||
|
||
|
||
class IndexedSequence(SageObject): | ||
""" | ||
An indexed sequence. | ||
|
@@ -225,9 +223,9 @@ def _repr_(self): | |
Indexed sequence: [0, 1, 1] | ||
indexed by Finite Field of size 3 | ||
""" | ||
return "Indexed sequence: "+str(self.list())+"\n indexed by "+str(self.index_object()) | ||
return "Indexed sequence: " + str(self.list()) + "\n indexed by " + str(self.index_object()) | ||
|
||
def plot_histogram(self, clr=(0,0,1), eps = 0.4): | ||
def plot_histogram(self, clr=(0, 0, 1), eps=0.4): | ||
r""" | ||
Plot the histogram plot of the sequence. | ||
|
@@ -249,8 +247,13 @@ def plot_histogram(self, clr=(0,0,1), eps = 0.4): | |
I = self.index_object() | ||
N = len(I) | ||
S = self.list() | ||
P = [polygon([[RR(I[i])-eps,0],[RR(I[i])-eps,RR(S[i])],[RR(I[i])+eps,RR(S[i])],[RR(I[i])+eps,0],[RR(I[i]),0]], rgbcolor=clr) for i in range(N)] | ||
T = [text(str(I[i]),(RR(I[i]),-0.8),fontsize=15,rgbcolor=(1,0,0)) for i in range(N)] | ||
P = [polygon([[RR(I[i]) - eps, 0], | ||
[RR(I[i]) - eps, RR(S[i])], | ||
[RR(I[i]) + eps, RR(S[i])], | ||
[RR(I[i]) + eps, 0], | ||
[RR(I[i]), 0]], rgbcolor=clr) for i in range(N)] | ||
T = [text(str(I[i]), (RR(I[i]), -0.8), fontsize=15, rgbcolor=(1, 0, 0)) | ||
for i in range(N)] | ||
return sum(P) + sum(T) | ||
|
||
def plot(self): | ||
|
@@ -271,9 +274,9 @@ def plot(self): | |
# elements must be coercible into RR | ||
I = self.index_object() | ||
S = self.list() | ||
return line([[RR(I[i]),RR(S[i])] for i in range(len(I)-1)]) | ||
return line([[RR(I[i]), RR(S[i])] for i in range(len(I) - 1)]) | ||
|
||
def dft(self, chi = lambda x: x): | ||
def dft(self, chi=lambda x: x): | ||
r""" | ||
A discrete Fourier transform "over `\QQ`" using exact | ||
`N`-th roots of unity. | ||
|
@@ -322,34 +325,34 @@ def dft(self, chi = lambda x: x): | |
implemented Group (permutation, matrix), call .characters() | ||
and test if the index list is the set of conjugacy classes. | ||
""" | ||
J = self.index_object() ## index set of length N | ||
J = self.index_object() # index set of length N | ||
N = len(J) | ||
S = self.list() | ||
F = self.base_ring() ## elements must be coercible into QQ(zeta_N) | ||
F = self.base_ring() # elements must be coercible into QQ(zeta_N) | ||
if not(J[0] in ZZ): | ||
G = J[0].parent() ## if J is not a range it is a group G | ||
if J[0] in ZZ and F.base_ring().fraction_field()==QQ: | ||
## assumes J is range(N) | ||
G = J[0].parent() # if J is not a range it is a group G | ||
if J[0] in ZZ and F.base_ring().fraction_field() == QQ: | ||
# assumes J is range(N) | ||
zeta = CyclotomicField(N).gen() | ||
FT = [sum([S[i]*chi(zeta**(i*j)) for i in J]) for j in J] | ||
elif not(J[0] in ZZ) and G.is_abelian() and F == ZZ or (F.is_field() and F.base_ring()==QQ): | ||
FT = [sum([S[i] * chi(zeta**(i * j)) for i in J]) for j in J] | ||
elif (J[0] not in ZZ) and G.is_abelian() and F == ZZ or (F.is_field() and F.base_ring() == QQ): | ||
if is_PermutationGroupElement(J[0]): | ||
## J is a CyclicPermGp | ||
# J is a CyclicPermGp | ||
n = G.order() | ||
a = list(factor(n)) | ||
a = list(n.factor()) | ||
invs = [x[0]**x[1] for x in a] | ||
G = AbelianGroup(len(a),invs) | ||
## assumes J is AbelianGroup(...) | ||
G = AbelianGroup(len(a), invs) | ||
# assumes J is AbelianGroup(...) | ||
Gd = G.dual_group() | ||
FT = [sum([S[i]*chid(G.list()[i]) for i in range(N)]) | ||
FT = [sum([S[i] * chid(G.list()[i]) for i in range(N)]) | ||
for chid in Gd] | ||
elif not(J[0] in ZZ) and G.is_finite() and F == ZZ or (F.is_field() and F.base_ring()==QQ): | ||
## assumes J is the list of conj class representatives of a | ||
## PermutationGroup(...) or Matrixgroup(...) | ||
elif (J[0] not in ZZ) and G.is_finite() and F == ZZ or (F.is_field() and F.base_ring() == QQ): | ||
# assumes J is the list of conj class representatives of a | ||
# PermutationGroup(...) or Matrixgroup(...) | ||
chi = G.character_table() | ||
FT = [sum([S[i]*chi[i,j] for i in range(N)]) for j in range(N)] | ||
FT = [sum([S[i] * chi[i, j] for i in range(N)]) for j in range(N)] | ||
else: | ||
raise ValueError("list elements must be in QQ(zeta_"+str(N)+")") | ||
raise ValueError(f"list elements must be in QQ(zeta_{N})") | ||
return IndexedSequence(FT, J) | ||
|
||
def idft(self): | ||
|
@@ -370,15 +373,15 @@ def idft(self): | |
sage: it == s | ||
True | ||
""" | ||
F = self.base_ring() ## elements must be coercible into QQ(zeta_N) | ||
J = self.index_object() ## must be = range(N) | ||
F = self.base_ring() # elements must be coercible into QQ(zeta_N) | ||
J = self.index_object() # must be = range(N) | ||
N = len(J) | ||
S = self.list() | ||
zeta = CyclotomicField(N).gen() | ||
iFT = [sum([S[i]*zeta**(-i*j) for i in J]) for j in J] | ||
if not(J[0] in ZZ) or F.base_ring().fraction_field() != QQ: | ||
iFT = [sum([S[i] * zeta**(-i * j) for i in J]) for j in J] | ||
if (J[0] not in ZZ) or F.base_ring().fraction_field() != QQ: | ||
raise NotImplementedError("Sorry this type of idft is not implemented yet.") | ||
return IndexedSequence(iFT,J)*(Integer(1)/N) | ||
return IndexedSequence(iFT, J) * (Integer(1) / N) | ||
|
||
def dct(self): | ||
""" | ||
|
@@ -390,17 +393,17 @@ def dct(self): | |
sage: A = [exp(-2*pi*i*I/5) for i in J] | ||
sage: s = IndexedSequence(A,J) | ||
sage: s.dct() | ||
Indexed sequence: [1/16*(sqrt(5) + I*sqrt(-2*sqrt(5) + 10) + ... | ||
Indexed sequence: [0, 1/16*(sqrt(5) + I*sqrt(-2*sqrt(5) + 10) + ... | ||
indexed by [0, 1, 2, 3, 4] | ||
""" | ||
from sage.symbolic.constants import pi | ||
F = self.base_ring() ## elements must be coercible into RR | ||
J = self.index_object() ## must be = range(N) | ||
F = self.base_ring() # elements must be coercible into RR | ||
J = self.index_object() # must be = range(N) | ||
N = len(J) | ||
S = self.list() | ||
PI = F(pi) | ||
FT = [sum([S[i]*cos(2*PI*i/N) for i in J]) for j in J] | ||
return IndexedSequence(FT,J) | ||
PI = 2 * F(pi) / N | ||
FT = [sum([S[i] * cos(PI * i * j) for i in J]) for j in J] | ||
return IndexedSequence(FT, J) | ||
|
||
def dst(self): | ||
""" | ||
|
@@ -414,17 +417,17 @@ def dst(self): | |
sage: s = IndexedSequence(A,J) | ||
sage: s.dst() # discrete sine | ||
Indexed sequence: [1.11022302462516e-16 - 2.50000000000000*I, 1.11022302462516e-16 - 2.50000000000000*I, 1.11022302462516e-16 - 2.50000000000000*I, 1.11022302462516e-16 - 2.50000000000000*I, 1.11022302462516e-16 - 2.50000000000000*I] | ||
indexed by [0, 1, 2, 3, 4] | ||
Indexed sequence: [0.000000000000000, 1.11022302462516e-16 - 2.50000000000000*I, ...] | ||
indexed by [0, 1, 2, 3, 4] | ||
""" | ||
from sage.symbolic.constants import pi | ||
F = self.base_ring() ## elements must be coercible into RR | ||
J = self.index_object() ## must be = range(N) | ||
F = self.base_ring() # elements must be coercible into RR | ||
J = self.index_object() # must be = range(N) | ||
N = len(J) | ||
S = self.list() | ||
PI = F(pi) | ||
FT = [sum([S[i]*sin(2*PI*i/N) for i in J]) for j in J] | ||
return IndexedSequence(FT,J) | ||
PI = 2 * F(pi) / N | ||
FT = [sum([S[i] * sin(PI * i * j) for i in J]) for j in J] | ||
return IndexedSequence(FT, J) | ||
|
||
def convolution(self, other): | ||
r""" | ||
|
@@ -471,19 +474,18 @@ def convolution(self, other): | |
raise TypeError("IndexedSequences must have same index set") | ||
M = len(S) | ||
N = len(T) | ||
if M < N: ## first, extend by 0 if necessary | ||
a = [S[i] for i in range(M)]+[F(0) for i in range(2*N)] | ||
b = T+[E(0) for i in range(2*M)] | ||
if M > N: ## python trick - a[-j] is really j from the *right* | ||
b = [T[i] for i in range(N)]+[E(0) for i in range(2*M)] | ||
a = S+[F(0) for i in range(2*M)] | ||
if M==N: ## so need only extend by 0 to the *right* | ||
a = S+[F(0) for i in range(2*M)] | ||
b = T+[E(0) for i in range(2*M)] | ||
N = max(M,N) | ||
c = [sum([a[i]*b[j-i] for i in range(N)]) for j in range(2*N-1)] | ||
#print([[b[j-i] for i in range(N)] for j in range(N)]) | ||
return c | ||
if M < N: # first, extend by 0 if necessary | ||
a = [S[i] for i in range(M)] + [F(0) for i in range(2 * N)] | ||
b = T + [E(0) for i in range(2 * M)] | ||
if M > N: # python trick - a[-j] is really j from the *right* | ||
b = [T[i] for i in range(N)] + [E(0) for i in range(2 * M)] | ||
a = S + [F(0) for i in range(2 * M)] | ||
if M == N: # so need only extend by 0 to the *right* | ||
a = S + [F(0) for i in range(2 * M)] | ||
b = T + [E(0) for i in range(2 * M)] | ||
N = max(M, N) | ||
return [sum([a[i] * b[j - i] for i in range(N)]) | ||
for j in range(2 * N - 1)] | ||
|
||
def convolution_periodic(self, other): | ||
r""" | ||
|
@@ -531,17 +533,17 @@ def convolution_periodic(self, other): | |
M = len(S) | ||
N = len(T) | ||
if M < N: # first, extend by 0 if necessary | ||
a = [S[i] for i in range(M)]+[F(0) for i in range(N-M)] | ||
a = [S[i] for i in range(M)] + [F(0) for i in range(N - M)] | ||
b = other | ||
if M > N: | ||
b = [T[i] for i in range(N)]+[E(0) for i in range(M-N)] | ||
b = [T[i] for i in range(N)] + [E(0) for i in range(M - N)] | ||
a = self | ||
if M == N: | ||
a = S | ||
b = T | ||
N = max(M, N) | ||
c = [sum([a[i]*b[(j-i)%N] for i in range(N)]) for j in range(2*N-1)] | ||
return c | ||
return [sum([a[i] * b[(j - i) % N] for i in range(N)]) | ||
for j in range(2 * N - 1)] | ||
|
||
def __mul__(self, other): | ||
""" | ||
|
@@ -563,7 +565,7 @@ def __mul__(self, other): | |
S1 = [S[i] * other for i in range(len(self.index_object()))] | ||
return IndexedSequence(S1, self.index_object()) | ||
|
||
def __eq__(self,other): | ||
def __eq__(self, other): | ||
""" | ||
Implements boolean equals. | ||
|
@@ -587,16 +589,17 @@ def __eq__(self,other): | |
T = other.list() | ||
I = self.index_object() | ||
J = other.index_object() | ||
if I!=J: | ||
if I != J: | ||
return False | ||
for i in I: | ||
try: | ||
if abs(S[i]-T[i]) > 10**(-8): ## tests if they differ as reals -- WHY 10^(-8)??? | ||
if abs(S[i] - T[i]) > 10**(-8): | ||
# tests if they differ as reals -- WHY 10^(-8)??? | ||
return False | ||
except TypeError: | ||
pass | ||
#if F!=E: ## omitted this test since it | ||
# return 0 ## doesn't take into account coercions -- WHY??? | ||
# if F != E: # omitted this test since it | ||
# return 0 # doesn't take into account coercions -- WHY??? | ||
return True | ||
|
||
def fft(self): | ||
|
@@ -623,14 +626,14 @@ def fft(self): | |
I = CC.gen() | ||
|
||
# elements must be coercible into RR | ||
J = self.index_object() ## must be = range(N) | ||
J = self.index_object() # must be = range(N) | ||
N = len(J) | ||
S = self.list() | ||
a = FastFourierTransform(N) | ||
for i in range(N): | ||
a[i] = S[i] | ||
a.forward_transform() | ||
return IndexedSequence([a[j][0]+I*a[j][1] for j in J],J) | ||
return IndexedSequence([a[j][0] + I * a[j][1] for j in J], J) | ||
|
||
def ifft(self): | ||
""" | ||
|
@@ -660,16 +663,16 @@ def ifft(self): | |
I = CC.gen() | ||
|
||
# elements must be coercible into RR | ||
J = self.index_object() ## must be = range(N) | ||
J = self.index_object() # must be = range(N) | ||
N = len(J) | ||
S = self.list() | ||
a = FastFourierTransform(N) | ||
for i in range(N): | ||
a[i] = S[i] | ||
a.inverse_transform() | ||
return IndexedSequence([a[j][0]+I*a[j][1] for j in J],J) | ||
return IndexedSequence([a[j][0] + I * a[j][1] for j in J], J) | ||
|
||
def dwt(self,other="haar",wavelet_k=2): | ||
def dwt(self, other="haar", wavelet_k=2): | ||
r""" | ||
Wraps the gsl ``WaveletTransform.forward`` in :mod:`~sage.calculus.transforms.dwt` | ||
(written by Joshua Kantor). Assumes the length of the sample is a | ||
|
@@ -709,28 +712,28 @@ def dwt(self,other="haar",wavelet_k=2): | |
indexed by [0, 1, 2, 3, 4, 5, 6, 7] | ||
""" | ||
# elements must be coercible into RR | ||
J = self.index_object() ## must be = range(N) | ||
N = len(J) ## must be 1 minus a power of 2 | ||
J = self.index_object() # must be = range(N) | ||
N = len(J) # must be 1 minus a power of 2 | ||
S = self.list() | ||
if other == "haar" or other == "haar_centered": | ||
if wavelet_k in [2]: | ||
a = WaveletTransform(N,other,wavelet_k) | ||
a = WaveletTransform(N, other, wavelet_k) | ||
else: | ||
raise ValueError("wavelet_k must be = 2") | ||
if other == "debauchies" or other == "debauchies_centered": | ||
if wavelet_k in [4,6,8,10,12,14,16,18,20]: | ||
a = WaveletTransform(N,other,wavelet_k) | ||
if other == "daubechies" or other == "daubechies_centered": | ||
if wavelet_k in [4, 6, 8, 10, 12, 14, 16, 18, 20]: | ||
a = WaveletTransform(N, other, wavelet_k) | ||
else: | ||
raise ValueError("wavelet_k must be in {4,6,8,10,12,14,16,18,20}") | ||
if other == "bspline" or other == "bspline_centered": | ||
if wavelet_k in [103,105,202,204,206,208,301,305,307,309]: | ||
a = WaveletTransform(N,other,103) | ||
if wavelet_k in [103, 105, 202, 204, 206, 208, 301, 305, 307, 309]: | ||
a = WaveletTransform(N, other, 103) | ||
else: | ||
raise ValueError("wavelet_k must be in {103,105,202,204,206,208,301,305,307,309}") | ||
for i in range(N): | ||
a[i] = S[i] | ||
a.forward_transform() | ||
return IndexedSequence([RR(a[j]) for j in J],J) | ||
return IndexedSequence([RR(a[j]) for j in J], J) | ||
|
||
def idwt(self, other="haar", wavelet_k=2): | ||
r""" | ||
|
@@ -786,26 +789,26 @@ def idwt(self, other="haar", wavelet_k=2): | |
True | ||
""" | ||
# elements must be coercible into RR | ||
J = self.index_object() ## must be = range(N) | ||
N = len(J) ## must be 1 minus a power of 2 | ||
J = self.index_object() # must be = range(N) | ||
N = len(J) # must be 1 minus a power of 2 | ||
S = self.list() | ||
k = wavelet_k | ||
if other=="haar" or other=="haar_centered": | ||
if other == "haar" or other == "haar_centered": | ||
if k in [2]: | ||
a = WaveletTransform(N,other,wavelet_k) | ||
a = WaveletTransform(N, other, wavelet_k) | ||
else: | ||
raise ValueError("wavelet_k must be = 2") | ||
if other=="debauchies" or other=="debauchies_centered": | ||
if k in [4,6,8,10,12,14,16,18,20]: | ||
a = WaveletTransform(N,other,wavelet_k) | ||
if other == "daubechies" or other == "daubechies_centered": | ||
if k in [4, 6, 8, 10, 12, 14, 16, 18, 20]: | ||
a = WaveletTransform(N, other, wavelet_k) | ||
else: | ||
raise ValueError("wavelet_k must be in {4,6,8,10,12,14,16,18,20}") | ||
if other=="bspline" or other=="bspline_centered": | ||
if k in [103,105,202,204,206,208,301,305,307,309]: | ||
a = WaveletTransform(N,other,103) | ||
if other == "bspline" or other == "bspline_centered": | ||
if k in [103, 105, 202, 204, 206, 208, 301, 305, 307, 309]: | ||
a = WaveletTransform(N, other, 103) | ||
else: | ||
raise ValueError("wavelet_k must be in {103,105,202,204,206,208,301,305,307,309}") | ||
for i in range(N): | ||
a[i] = S[i] | ||
a.backward_transform() | ||
return IndexedSequence([RR(a[j]) for j in J],J) | ||
return IndexedSequence([RR(a[j]) for j in J], J) |