From adb23bffea26eeaa33fd51a85aee569e62c8cd25 Mon Sep 17 00:00:00 2001
From: Andrew Tritt
Date: Thu, 10 Aug 2023 23:05:48 -0700
Subject: [PATCH] Skip sequences that are less than minimum sequence length (#44)

---
 src/gtnet/sequence.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/gtnet/sequence.py b/src/gtnet/sequence.py
index f3d15f1..5489486 100644
--- a/src/gtnet/sequence.py
+++ b/src/gtnet/sequence.py
@@ -34,6 +34,8 @@ def __init__(self, window, step, vocab=None, padval=None, min_seq_len=100, devic
         self.device = device
 
     def encode(self, seq):
+        if len(seq) < self.min_seq_len:
+            raise ValueError(f"Minimum sequence length is {self.min_seq_len} - got {len(seq)}")
         if seq.dtype == np.dtype('S1'):
             seq = seq.view(np.uint8)
         elif seq.dtype == np.dtype('U1'):
@@ -143,9 +145,14 @@ def readfiles(cls, encoder, fastas):
         for fa in fastas:
             logging.debug(f'loading {fa}')
             for seqid, values in cls.readfile(fa):
-                batches = encoder.encode(values)
-                val = (fa, seqid, len(values), batches)
-                yield val
+                if len(values) < encoder.min_seq_len:
+                    logging.warning((f"Skipping {seqid} from {fa} - length less than "
+                                     f"minimum sequence length {encoder.min_seq_len}"))
+                    yield (fa, seqid, len(values), torch.zeros((0, 0, 0), dtype=torch.uint8))
+                else:
+                    batches = encoder.encode(values)
+                    val = (fa, seqid, len(values), batches)
+                    yield val
 
 
 class SerialLoader(Loader):