diff --git a/src/genomicranges/GenomicRanges.py b/src/genomicranges/GenomicRanges.py index 15a2659..55bc44a 100644 --- a/src/genomicranges/GenomicRanges.py +++ b/src/genomicranges/GenomicRanges.py @@ -2532,7 +2532,7 @@ def tile_by_range( ) seqnames.extend([val.seqnames[0]] * len(all_intervals)) - strand.extend([val.strand[0]] * len(all_intervals)) + strand.extend([int(val.strand[0])] * len(all_intervals)) starts.extend([x[0] for x in all_intervals]) widths.extend(x[1] for x in all_intervals) @@ -2594,7 +2594,7 @@ def tile( ) seqnames.extend([val.seqnames[0]] * len(all_intervals)) - strand.extend([val.strand[0]] * len(all_intervals)) + strand.extend([int(val.strand[0])] * len(all_intervals)) starts.extend([x[0] for x in all_intervals]) widths.extend(x[1] for x in all_intervals) @@ -2636,7 +2636,7 @@ def sliding_windows(self, width: int, step: int = 1) -> "GenomicRanges": ) seqnames.extend([val.seqnames[0]] * len(all_intervals)) - strand.extend([val.strand[0]] * len(all_intervals)) + strand.extend([int(val.strand[0])] * len(all_intervals)) starts.extend([x[0] for x in all_intervals]) widths.extend(x[1] for x in all_intervals) diff --git a/src/genomicranges/utils.py b/src/genomicranges/utils.py index 8918f74..b2c87bb 100644 --- a/src/genomicranges/utils.py +++ b/src/genomicranges/utils.py @@ -1,4 +1,4 @@ -from typing import List, Sequence, Union, Optional, Tuple +from typing import List, Optional, Sequence, Tuple, Union import biocutils as ut import numpy as np @@ -41,18 +41,23 @@ def sanitize_strand_vector( raise ValueError( "'strand' must only contain values 1 (forward strand), -1 (reverse strand) or 0 (reverse strand)." ) - return strand + return strand.astype(np.int8) if ut.is_list_of_type(strand, str): if not set(strand).issubset(["+", "-", "*"]): raise ValueError("Values in 'strand' must be either +, - or *.") - return np.array([STRAND_MAP[x] for x in strand]) - elif ut.is_list_of_type(strand, int): - return np.array(strand) - else: - TypeError( - "'strand' must be either a numpy vector, a list of integers or strings representing strand." - ) + return np.array([STRAND_MAP[x] for x in strand], dtype=np.int8) + + if ut.is_list_of_type(strand, (int, float, np.int_)): + if not set(strand).issubset([1, 0, -1]): + raise ValueError( + "'strand' must only contain values 1 (forward strand), -1 (reverse strand) or 0 (reverse strand)." + ) + return np.array(strand, dtype=np.int8) + + raise TypeError( + "'strand' must be either a numpy vector, a list of integers or strings representing strand." + ) def _sanitize_strand_search_ops(query_strand, subject_strand):