From d1b7ceab853363a55821a15572d3d574689a096c Mon Sep 17 00:00:00 2001 From: Todd Gardner Date: Sun, 2 Apr 2017 20:10:30 -0400 Subject: [PATCH] Fix StaticIterator prerelease as well --- pex/resolver.py | 19 +++++++++++-------- tests/test_resolver.py | 25 +++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/pex/resolver.py b/pex/resolver.py index cf32d8094..0ccd2c266 100644 --- a/pex/resolver.py +++ b/pex/resolver.py @@ -35,12 +35,13 @@ class Unsatisfiable(Exception): class StaticIterator(IteratorInterface): """An iterator that iterates over a static list of packages.""" - def __init__(self, packages): + def __init__(self, packages, allow_prereleases): self._packages = packages + self._allow_prereleases = allow_prereleases def iter(self, req): for package in self._packages: - if package.satisfies(req): + if package.satisfies(req, allow_prereleases=self._allow_prereleases): yield package @@ -150,13 +151,15 @@ def filter_packages_by_interpreter(cls, packages, interpreter, platform): return [package for package in packages if package.compatible(interpreter.identity, platform)] - def __init__(self, interpreter=None, platform=None): + def __init__(self, allow_prereleases=False, interpreter=None, platform=None): self._interpreter = interpreter or PythonInterpreter.get() self._platform = platform or Platform.current() + self._allow_prereleases = allow_prereleases def package_iterator(self, resolvable, existing=None): if existing: - existing = resolvable.compatible(StaticIterator(existing)) + existing = resolvable.compatible( + StaticIterator(existing, allow_prereleases=self._allow_prereleases)) else: existing = resolvable.packages() return self.filter_packages_by_interpreter(existing, self._interpreter, self._platform) @@ -231,17 +234,16 @@ def filter_packages_by_ttl(cls, packages, ttl, now=None): return [package for package in packages if package.remote or package.local and (now - os.path.getmtime(package.local_path)) < ttl] - def __init__(self, cache, cache_ttl, allow_prereleases=False, *args, **kw): + def __init__(self, cache, cache_ttl, *args, **kw): self.__cache = cache self.__cache_ttl = cache_ttl - self.__allow_prereleases = allow_prereleases safe_mkdir(self.__cache) super(CachingResolver, self).__init__(*args, **kw) # Short-circuiting package iterator. def package_iterator(self, resolvable, existing=None): iterator = Iterator(fetchers=[Fetcher([self.__cache])], - allow_prereleases=self.__allow_prereleases) + allow_prereleases=self._allow_prereleases) packages = self.filter_packages_by_interpreter( resolvable.compatible(iterator), self._interpreter, @@ -356,6 +358,7 @@ def resolve( cache, cache_ttl, allow_prereleases=allow_prereleases, interpreter=interpreter, platform=platform) else: - resolver = Resolver(interpreter=interpreter, platform=platform) + resolver = Resolver( + allow_prereleases=allow_prereleases, interpreter=interpreter, platform=platform) return resolver.resolve(resolvables_from_iterable(requirements, builder)) diff --git a/tests/test_resolver.py b/tests/test_resolver.py index f1c590fb1..588948390 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -132,6 +132,31 @@ def assert_resolve(expected_version, **resolve_kwargs): assert_resolve('3.0.0rc3', fetchers=[]) +def test_resolve_prereleases_multiple_set(): + stable_dep = make_sdist(name='dep', version='2.0.0') + prerelease_dep1 = make_sdist(name='dep', version='3.0.0rc3') + prerelease_dep2 = make_sdist(name='dep', version='3.0.0rc4') + prerelease_dep3 = make_sdist(name='dep', version='3.0.0rc5') + + with temporary_dir() as td: + for sdist in (stable_dep, prerelease_dep1, prerelease_dep2, prerelease_dep3): + safe_copy(sdist, os.path.join(td, os.path.basename(sdist))) + fetchers = [Fetcher([td])] + + def assert_resolve(expected_version, **resolve_kwargs): + dists = resolve( + [ + 'dep>=3.0.0rc1', + 'dep==3.0.0rc4', + ], + fetchers=fetchers, **resolve_kwargs) + assert 1 == len(dists) + dist = dists[0] + assert expected_version == dist.version + + assert_resolve('3.0.0rc4', allow_prereleases=True) + + def test_resolvable_set(): builder = ResolverOptionsBuilder() rs = _ResolvableSet()