Commit a406b7b

cleanup: properly format the interface
1 parent d361b50 commit a406b7b

2 files changed: +295 -1 lines changed

.github/workflows/pylint.yml (+1 -1)
@@ -4,7 +4,7 @@ on: [push]

 jobs:
   build:
-    runs-on: node20
+    runs-on: ubuntu-latest
     strategy:
       matrix:
         python-version: ["3.12"]

src/interface.py (+294)

@@ -0,0 +1,294 @@
"""
@brief: Specify the project interface.
@author: Hao Kang <[email protected]>
"""

from pathlib import Path
from abc import ABC, abstractmethod
from typing import Iterator, Literal, List, Type, Tuple
from torch import Tensor


# Define the type aliases.
EmbeddingName = Literal["BgeBase", "MiniCPM"]
DatasetName = Literal["MsMarco", "Beir"]
PartitionName = Literal["Train", "Validate"]
class Embedding(ABC):
    """
    The interface for an embedding model.

    Attributes:
        name (EmbeddingName): The name of the embedding.
        size (int): The size of the embedding.
    """

    name: EmbeddingName
    size: int

    @abstractmethod
    def __init__(self, devices: List[int]) -> None:
        """
        Initialize the embedding model.

        :type devices: List[int]
        :param devices: The devices to use for embedding.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(self, passages: List[str]) -> Tensor:
        """
        Forward pass to embed the given passages.

        :type passages: List[str]
        :param passages: The list of passages to embed.
        :rtype: torch.Tensor
        :return: The computed embeddings in a tensor of shape (N, D), where N
            is the number of passages and D is the embedding size.
        """
        raise NotImplementedError
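For reference, a minimal concrete implementation of this interface could look like the sketch below. It is not part of the commit: the HashedBagOfWords class, its placeholder name and size values, and the device handling are illustrative assumptions only.

from typing import List

import torch
from torch import Tensor


class HashedBagOfWords(Embedding):
    """A toy Embedding used only to illustrate the contract above."""

    # Placeholder values; a real model would report its actual identity and width.
    name: EmbeddingName = "BgeBase"
    size: int = 256

    def __init__(self, devices: List[int]) -> None:
        # A real implementation would load model weights onto the given devices;
        # this sketch only remembers the first device, falling back to CPU.
        self.device = torch.device(f"cuda:{devices[0]}" if devices else "cpu")

    def forward(self, passages: List[str]) -> Tensor:
        # Hash each whitespace token into one of `size` buckets and count hits,
        # yielding the required (N, D) tensor with N passages and D = self.size.
        out = torch.zeros(len(passages), self.size, device=self.device)
        for row, text in enumerate(passages):
            for token in text.split():
                out[row, hash(token) % self.size] += 1.0
        return out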
class Dataset(ABC):
    """
    The interface for a dataset.

    Attributes:
        name (DatasetName): The name of the dataset.
    """

    name: DatasetName

    @abstractmethod
    def didIter(self, batchSize: int) -> Iterator[List[str]]:
        """
        Iterate over the document IDs in batches.

        :type batchSize: int
        :param batchSize: The batch size for each iteration.
        :rtype: Iterator[List[str]]
        :return: An iterator over the document IDs. Each iteration yields a
            list of document IDs.
        """
        raise NotImplementedError

    @abstractmethod
    def docIter(self, batchSize: int) -> Iterator[List[str]]:
        """
        Iterate over the document texts in batches.

        :type batchSize: int
        :param batchSize: The batch size for each iteration.
        :rtype: Iterator[List[str]]
        :return: The iterator over the document texts. Each iteration yields a
            list of document texts.
        """
        raise NotImplementedError

    @abstractmethod
    def docEmbIter(
        self,
        embedding: Type[Embedding],
        batchSize: int,
        numWorkers: int,
        shuffle: bool,
    ) -> Iterator[Tensor]:
        """
        Iterate over the document embeddings in batches.

        :type embedding: Type[Embedding]
        :param embedding: The embedding model to use.
        :type batchSize: int
        :param batchSize: The batch size for each iteration.
        :type numWorkers: int
        :param numWorkers: The number of workers for data loading.
        :type shuffle: bool
        :param shuffle: Whether to shuffle the data during loading.
        :rtype: Iterator[Tensor]
        :return: The iterator over the document embeddings. Each tensor has
            shape (N, D), where N is the batch size, or less for the last
            batch, and D is the embedding size.
        """
        raise NotImplementedError

    @abstractmethod
    def getDocLen(self) -> int:
        """
        Get the number of documents.

        :rtype: int
        :return: The number of documents.
        """
        raise NotImplementedError

    @abstractmethod
    def qidIter(
        self, split: PartitionName, batchSize: int
    ) -> Iterator[List[str]]:
        """
        Iterate over the query IDs in batches.

        :type split: PartitionName
        :param split: Whether to use the training or validation split.
        :type batchSize: int
        :param batchSize: The batch size for each iteration.
        :rtype: Iterator[List[str]]
        :return: The iterator over the query IDs. Each iteration yields a list
            of query IDs.
        """
        raise NotImplementedError

    @abstractmethod
    def qryIter(
        self, split: PartitionName, batchSize: int
    ) -> Iterator[List[str]]:
        """
        Iterate over the query texts in batches.

        :type split: PartitionName
        :param split: Whether to use the training or validation split.
        :type batchSize: int
        :param batchSize: The batch size for each iteration.
        :rtype: Iterator[List[str]]
        :return: The iterator over the query texts. Each iteration yields a
            list of query texts.
        """
        raise NotImplementedError

    @abstractmethod
    def qryEmbIter(
        self,
        split: PartitionName,
        embedding: Type[Embedding],
        batchSize: int,
        numWorkers: int,
        shuffle: bool,
    ) -> Iterator[Tensor]:
        """
        Iterate over the query embeddings in batches.

        :type split: PartitionName
        :param split: Whether to use the training or validation split.
        :type embedding: Type[Embedding]
        :param embedding: The embedding class to use.
        :type batchSize: int
        :param batchSize: The batch size for each iteration.
        :type numWorkers: int
        :param numWorkers: The number of workers for data loading.
        :type shuffle: bool
        :param shuffle: Whether to shuffle the data.
        :rtype: Iterator[Tensor]
        :return: The iterator over the query embeddings. Each tensor has shape
            (N, D), where N is the batch size, or less for the last batch, and
            D is the embedding size.
        """
        raise NotImplementedError

    @abstractmethod
    def getQryLen(self, split: PartitionName) -> int:
        """
        Get the number of queries.

        :type split: PartitionName
        :param split: Whether to use the training or validation split.
        :rtype: int
        :return: The number of queries.
        """
        raise NotImplementedError

    @abstractmethod
    def getQryRel(self, split: PartitionName) -> Path:
        """
        Get the path to the query relevance file.

        :type split: PartitionName
        :param split: Whether to use the training or validation split.
        :rtype: Path
        :return: The path to the query relevance file.
        """
        raise NotImplementedError

    @abstractmethod
    def mixEmbIter(
        self,
        split: PartitionName,
        embedding: Type[Embedding],
        relevant: int,
        batchSize: int,
        numWorkers: int,
        shuffle: bool,
    ) -> Iterator[Tuple[Tensor, Tensor]]:
        """
        Iterate over the embeddings of each query and its retrieved documents
        in batches.

        :type split: PartitionName
        :param split: Whether to use the training or validation split.
        :type embedding: Type[Embedding]
        :param embedding: The embedding class to use.
        :type relevant: int
        :param relevant: The number of documents to include for each query.
        :type batchSize: int
        :param batchSize: The batch size for each iteration.
        :type numWorkers: int
        :param numWorkers: The number of workers for data loading.
        :type shuffle: bool
        :param shuffle: Whether to shuffle the data.
        :rtype: Iterator[Tuple[Tensor, Tensor]]
        :return: The iterator over the query and document embeddings. The
            first tensor is the query embeddings and has shape (N, D), where N
            is the batch size, or less for the last batch, and D is the
            embedding size. The second tensor is the document embeddings and
            has shape (N, K, D), where K is the number of relevant documents.
        """
        raise NotImplementedError

    @abstractmethod
    def getMixLen(self, split: PartitionName) -> int:
        """
        Get the number of query-document pairs.
        This function is equivalent to getQryLen.

        :type split: PartitionName
        :param split: Whether to use the training or validation split.
        :rtype: int
        :return: The number of query-document pairs.
        """
        raise NotImplementedError
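To make the intended data flow concrete, here is a hypothetical consumer of the Dataset interface. Nothing below is defined by the commit; the function name, the batchSize/numWorkers choices, and the cross-check against getDocLen are assumptions about how a concrete subclass would behave.

from typing import Type


def countEmbeddedDocs(dataset: Dataset, embedding: Type[Embedding]) -> int:
    # Stream document embeddings in fixed-size batches and count the rows seen.
    seen = 0
    for batch in dataset.docEmbIter(
        embedding=embedding, batchSize=64, numWorkers=4, shuffle=False
    ):
        # Each `batch` is a (N, D) tensor; the last batch may have N < 64.
        seen += batch.shape[0]
    # With shuffle=False the stream is expected to cover every document exactly
    # once, so the count should match dataset.getDocLen().
    return seen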
class SAE(ABC):
    """
    The interface for a sparse autoencoder.
    """

    def __init__(self, features: int, expandBy: int) -> None:
        """
        Initialize the sparse autoencoder.

        :type features: int
        :param features: The embedding size.
        :type expandBy: int
        :param expandBy: Expansion factor for the dictionary.
        """
        raise NotImplementedError

    def forward(self, x: Tensor, activate: int) -> Tuple[Tensor, Tensor]:
        """
        Forward pass to reconstruct the embedding.

        :type x: Tensor
        :param x: The original embedding. The tensor has shape (N, D), where N
            is the batch size and D is the embedding size.
        :type activate: int
        :param activate: The number of features to activate. This is the
            sparsity constraint. Only the top-K features are activated. The
            rest are set to zero.
        :rtype: Tuple[Tensor, Tensor]
        :return: The latent features and the reconstructed embedding. The
            latent features have shape (N, F), where F is the dictionary size.
            The reconstructed embedding has shape (N, D), where D is the
            embedding size. N is the batch size in both cases.
        """
        raise NotImplementedError
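The SAE contract above leaves the architecture open. One possible top-K sparse autoencoder is sketched below; the TopKSAE class, the ReLU encoder with linear decoder layout, and the use of torch.topk are assumptions for illustration, not part of the commit. Pairing such a model with Dataset.docEmbIter or mixEmbIter would supply the training stream described above.

from typing import Tuple

import torch
from torch import Tensor, nn


class TopKSAE(SAE, nn.Module):
    """A possible sparse autoencoder satisfying the SAE interface."""

    def __init__(self, features: int, expandBy: int) -> None:
        nn.Module.__init__(self)
        hidden = features * expandBy  # dictionary size F
        self.encoder = nn.Linear(features, hidden)
        self.decoder = nn.Linear(hidden, features)

    def forward(self, x: Tensor, activate: int) -> Tuple[Tensor, Tensor]:
        # Encode, then keep only the `activate` largest activations per row
        # and zero out the rest before decoding back to the embedding space.
        latent = torch.relu(self.encoder(x))            # (N, F)
        topk = torch.topk(latent, k=activate, dim=-1)
        sparse = torch.zeros_like(latent)
        sparse.scatter_(-1, topk.indices, topk.values)  # only top-K survive
        recon = self.decoder(sparse)                    # (N, D)
        return sparse, recon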
