Skip to content

Commit

Permalink
Test that a credential helper can supply credentials for bzlmod.
Browse files Browse the repository at this point in the history
  • Loading branch information
tjgq committed May 17, 2023
1 parent 4b9505f commit 1dd94a0
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/test/py/bazel/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,16 @@ py_test(
":test_base",
],
)

py_test(
name = "bzlmod_credentials_test",
size = "large",
srcs = ["bzlmod/bzlmod_credentials_test.py"],
tags = [
"requires-network",
],
deps = [
":bzlmod_test_utils",
":test_base",
],
)
133 changes: 133 additions & 0 deletions src/test/py/bazel/bzlmod/bzlmod_credentials_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# pylint: disable=g-backslash-continuation
# Copyright 2023 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests using credentials to connect to the bzlmod registry."""

import base64
import os
import tempfile
import unittest

from src.test.py.bazel import test_base
from src.test.py.bazel.bzlmod.test_utils import BazelRegistry
from src.test.py.bazel.bzlmod.test_utils import StaticHTTPServer


class BzlmodCredentialsTest(test_base.TestBase):
"""Test class for using credentials to connect to the bzlmod registry."""

def setUp(self):
test_base.TestBase.setUp(self)
self.registries_work_dir = tempfile.mkdtemp(dir=self._test_cwd)
self.registry_root = os.path.join(self.registries_work_dir, 'main')
self.main_registry = BazelRegistry(self.registry_root)
self.main_registry.createCcModule('aaa', '1.0')

self.ScratchFile('.bazelrc', [
# In ipv6 only network, this has to be enabled.
# 'startup --host_jvm_args=-Djava.net.preferIPv6Addresses=true',
'common --experimental_enable_bzlmod',
# Disable yanked version check so we are not affected BCR changes.
'common --allow_yanked_versions=all',
])
self.ScratchFile('WORKSPACE')
# The existence of WORKSPACE.bzlmod prevents WORKSPACE prefixes or suffixes
# from being used; this allows us to test built-in modules actually work
self.ScratchFile('WORKSPACE.bzlmod')
self.ScratchFile('MODULE.bazel', [
'bazel_dep(name = "aaa", version = "1.0")',
])
self.ScratchFile('BUILD', [
'cc_binary(',
' name = "main",',
' srcs = ["main.cc"],',
' deps = ["@aaa//:lib_aaa"],',
')',
])
self.ScratchFile('main.cc', [
'#include "aaa.h"',
'int main() {',
' hello_aaa("main function");',
'}',
])
self.ScratchFile('credhelper', [
'#!/usr/bin/env python3',
'import sys',
'if "127.0.0.1" in sys.stdin.read():',
' print("""{"headers":{"Authorization":["Bearer TOKEN"]}}""")',
'else:',
' print("""{}""")',
], executable=True)
self.ScratchFile('.netrc', [
'machine 127.0.0.1',
'login foo',
'password bar',
])

def testUnauthenticated(self):
with StaticHTTPServer(self.registry_root) as static_server:
_, stdout, _ = self.RunBazel([
'run',
'--registry=' + static_server.getURL(),
'--registry=https://bcr.bazel.build',
'//:main',
])
self.assertIn('main function => [email protected]', stdout)

def testMissingCredentials(self):
with StaticHTTPServer(self.registry_root, expected_auth='Bearer TOKEN') as static_server:
_, _, stderr = self.RunBazel([
'run',
'--registry=' + static_server.getURL(),
'--registry=https://bcr.bazel.build',
'//:main',
], allow_failure=True)
self.assertIn('GET returned 401 Unauthorized', "\n".join(stderr))

def testCredentialsFromHelper(self):
with StaticHTTPServer(self.registry_root, expected_auth='Bearer TOKEN') as static_server:
_, stdout, _ = self.RunBazel([
'run',
'--experimental_credential_helper=%workspace%/credhelper',
'--registry=' + static_server.getURL(),
'--registry=https://bcr.bazel.build',
'//:main',
])
self.assertIn('main function => [email protected]', stdout)

def testCredentialsFromNetrc(self):
expected_auth='Basic ' + base64.b64encode(b'foo:bar').decode('ascii')

with StaticHTTPServer(self.registry_root, expected_auth=expected_auth) as static_server:
_, stdout, _ = self.RunBazel([
'run',
'--registry=' + static_server.getURL(),
'--registry=https://bcr.bazel.build',
'//:main',
], env_add={"NETRC": self.Path(".netrc")})
self.assertIn('main function => [email protected]', stdout)

def testCredentialsFromHelperOverrideNetrc(self):
with StaticHTTPServer(self.registry_root, expected_auth='Bearer TOKEN') as static_server:
_, stdout, _ = self.RunBazel([
'run',
'--experimental_credential_helper=%workspace%/credhelper',
'--registry=' + static_server.getURL(),
'--registry=https://bcr.bazel.build',
'//:main',
], env_add={"NETRC": self.Path(".netrc")})
self.assertIn('main function => [email protected]', stdout)

if __name__ == '__main__':
unittest.main()
52 changes: 52 additions & 0 deletions src/test/py/bazel/bzlmod/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
"""Test utils for Bzlmod."""

import base64
import functools
import hashlib
import http.server
import json
import os
import pathlib
import shutil
import threading
import urllib.request
import zipfile

Expand Down Expand Up @@ -319,3 +322,52 @@ def createLocalPathModule(self, name, version, path, deps=None):

with module_dir.joinpath('source.json').open('w') as f:
json.dump(source, f, indent=4, sort_keys=True)


class StaticHTTPServer:
"""An HTTP server serving static files, optionally with authentication."""

def __init__(self, root, expected_auth=None):
self.root = root
self.expected_auth = expected_auth

def __enter__(self):
address = ('localhost', 0) # assign random port
handler = functools.partial(
_Handler, self.expected_auth, directory=self.root)
self.httpd = http.server.HTTPServer(address, handler)
self.thread = threading.Thread(
target=self.httpd.serve_forever, daemon=True)
self.thread.start()
return self

def __exit__(self, exc_type, exc_value, traceback):
self.httpd.shutdown()
self.thread.join()

def getURL(self):
return "http://{}:{}".format(*self.httpd.server_address)


class _Handler(http.server.SimpleHTTPRequestHandler):

def __init__(self, expected_auth, *args, **kwargs):
self.expected_auth = expected_auth
super().__init__(*args, **kwargs)

def check_auth(self):
#raise Exception("check_auth {} {}".format(self.requestline, self.headers))
auth_header = self.headers.get('Authorization', None)
if auth_header != self.expected_auth:
self.send_error(http.HTTPStatus.UNAUTHORIZED)
self.send_header('WWW-Authenticate', 'Basic')
return False
return True

def do_HEAD(self):
if self.check_auth():
return super().do_HEAD()

def do_GET(self):
if self.check_auth():
return super().do_GET()

0 comments on commit 1dd94a0

Please sign in to comment.