-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from graphcore-research/linear
Linear
- Loading branch information
Showing
21 changed files
with
845 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
FROM graphcore/pytorch:3.1.0-ubuntu-20.04 | ||
|
||
RUN apt-get update \ | ||
&& apt-get install -y sudo \ | ||
&& apt-get clean | ||
|
||
# Snippet from https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user | ||
ARG USERNAME=user-name-goes-here | ||
ARG USER_UID=1000 | ||
ARG USER_GID=$USER_UID | ||
RUN groupadd --gid $USER_GID $USERNAME \ | ||
&& useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \ | ||
&& echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \ | ||
&& chmod 0440 /etc/sudoers.d/$USERNAME | ||
USER $USERNAME | ||
|
||
ADD . /tmp/unit_scaling | ||
RUN pip install -r /tmp/unit_scaling/requirements.txt \ | ||
jupyterlab |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
FROM graphcore/pytorch: 3.1.0-ubuntu-20.04 | ||
|
||
RUN apt-get update \ | ||
&& apt-get install -y sudo \ | ||
&& apt-get clean | ||
|
||
# Snippet from https: //code.visualstudio.com/remote/advancedcontainers/add-nonroot-user | ||
ARG USERNAME=user-name-goes-here | ||
ARG USER_UID=1000 | ||
ARG USER_GID=$USER_UID | ||
RUN groupadd --gid $USER_GID $USERNAME \ | ||
&& useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \ | ||
&& echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \ | ||
&& chmod 0440 /etc/sudoers.d/$USERNAME | ||
USER $USERNAME | ||
|
||
ADD . /tmp/unit_scaling | ||
RUN pip install -r /tmp/unit_scaling/requirements.txt \ | ||
jupyterlab |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
* | ||
!requirements.txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
name: CI | ||
|
||
on: | ||
push: { branches: [ "main" ] } | ||
pull_request: | ||
workflow_dispatch: | ||
|
||
concurrency: | ||
# Run everything on main, most-recent on PR builds | ||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
ci: | ||
runs-on: [self-hosted, Linux, X64, 20.04, Ubuntu] | ||
container: graphcore/pytorch:3.2.0-ubuntu-20.04 | ||
timeout-minutes: 10 | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Update package info | ||
run: apt-get update | ||
- name: Install git | ||
run: apt-get -y install git | ||
- name: Install dependencies | ||
run: pip install -r requirements.txt | ||
- name: Run CI | ||
run: ./dev ci |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
.coverage | ||
.env | ||
.mypy_cache | ||
__pycache__ | ||
.pytest_cache | ||
.venv | ||
.venvs | ||
.vscode | ||
|
||
/build | ||
/local |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,31 @@ | ||
# unit-scaling | ||
A library for unit scaling in PyTorch | ||
# Unit Scaling | ||
|
||
A library for unit scaling in PyTorch. | ||
|
||
Based on the paper [Unit Scaling: Out-of-the-Box Low-Precision Training](https://arxiv.org/abs/2303.11257). | ||
|
||
## Development | ||
|
||
**First-time setup**: | ||
|
||
```bash | ||
python3 -m venv .venv | ||
# Add to `.venv/bin/activate`: `source /PATH_TO_POPLAR_SDK/enable` (If running on IPU) | ||
source .venv/bin/activate | ||
|
||
# pip install wheel # (If running on IPU) | ||
# pip install $POPLAR_SDK_ENABLED/../poptorch-*.whl # (If running on IPU) | ||
pip install -r requirements.txt | ||
``` | ||
|
||
**Run pre-flight checks** (or run `./dev --help` to see supported commands): | ||
|
||
```bash | ||
./dev | ||
``` | ||
|
||
IDE recommendations: | ||
- Python intepreter is set to `.venv/bin/python` | ||
- Format-on-save enabled | ||
- Consider a `.env` file for setting PYTHONPATH, e.g. `echo "PYTHONPATH=$(pwd)" > .env` | ||
(note that this will be a different path if using devcontainers) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
#!/usr/bin/env python | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. | ||
|
||
"""Dev task launcher.""" | ||
|
||
import argparse | ||
import datetime | ||
import os | ||
import subprocess | ||
import sys | ||
from pathlib import Path | ||
from typing import Any, Callable, Iterable, List, Optional, TypeVar | ||
|
||
# Utilities | ||
|
||
|
||
def run(command: Iterable[Any]) -> None: | ||
"""Run a command, terminating on failure.""" | ||
cmd = [str(arg) for arg in command if arg is not None] | ||
print("$ " + " ".join(cmd), file=sys.stderr) | ||
environ = os.environ.copy() | ||
environ["PYTHONPATH"] = f"{os.getcwd()}:{environ.get('PYTHONPATH', '')}" | ||
exit_code = subprocess.call(cmd, env=environ) | ||
if exit_code: | ||
sys.exit(exit_code) | ||
|
||
|
||
T = TypeVar("T") | ||
|
||
|
||
def cli(*args: Any, **kwargs: Any) -> Callable[[T], T]: | ||
"""Declare a CLI command / arguments for that command.""" | ||
|
||
def wrap(func: T) -> T: | ||
if not hasattr(func, "cli_args"): | ||
setattr(func, "cli_args", []) | ||
if args or kwargs: | ||
getattr(func, "cli_args").append((args, kwargs)) | ||
return func | ||
|
||
return wrap | ||
|
||
|
||
# Commands | ||
|
||
PYTHON_ROOTS = ["unit_scaling", "dev"] | ||
|
||
|
||
@cli("-k", "--filter") | ||
def tests(filter: Optional[str]) -> None: | ||
"""run Python tests""" | ||
run( | ||
[ | ||
"python", | ||
"-m", | ||
"pytest", | ||
"unit_scaling", | ||
None if filter else "--cov=unit_scaling", | ||
*(["-k", filter] if filter else []), | ||
] | ||
) | ||
|
||
|
||
@cli("command", nargs="*") | ||
def python(command: List[Any]) -> None: | ||
"""run Python with the current directory on PYTHONPATH, for development""" | ||
run(["python"] + command) | ||
|
||
|
||
@cli() | ||
def lint() -> None: | ||
"""run static analysis""" | ||
run(["python", "-m", "flake8", *PYTHON_ROOTS]) | ||
run(["python", "-m", "mypy", *PYTHON_ROOTS]) | ||
|
||
|
||
@cli("--check", action="store_true") | ||
def format(check: bool) -> None: | ||
"""autoformat all sources""" | ||
run(["python", "-m", "black", "--check" if check else None, *PYTHON_ROOTS]) | ||
run(["python", "-m", "isort", "--check" if check else None, *PYTHON_ROOTS]) | ||
|
||
|
||
@cli() | ||
def copyright() -> None: | ||
"""check for Graphcore copyright headers on relevant files""" | ||
command = ( | ||
f"find {' '.join(PYTHON_ROOTS)} -type f -not -name *.pyc" | ||
" | xargs grep -L 'Copyright (c) 202. Graphcore Ltd[.] All rights reserved[.]'" | ||
) | ||
print(f"$ {command}", file=sys.stderr) | ||
# Note: grep exit codes are not consistent between versions, so we don't use | ||
# check=True | ||
output = ( | ||
subprocess.run( | ||
command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT | ||
) | ||
.stdout.decode() | ||
.strip() | ||
) | ||
if output: | ||
print( | ||
"Error - failed copyright header check in:\n " | ||
+ output.replace("\n", "\n "), | ||
file=sys.stderr, | ||
) | ||
print("Template(s):") | ||
comment_prefixes = { | ||
{".cpp": "//"}.get(Path(f).suffix, "#") for f in output.split("\n") | ||
} | ||
for prefix in comment_prefixes: | ||
print( | ||
( | ||
f"{prefix} Copyright (c) {datetime.datetime.now().year}" | ||
" Graphcore Ltd. All rights reserved." | ||
), | ||
file=sys.stderr, | ||
) | ||
sys.exit(1) | ||
|
||
|
||
@cli( | ||
"-s", | ||
"--skip", | ||
nargs="*", | ||
default=[], | ||
choices=["tests", "lint", "format", "copyright"], | ||
help="commands to skip", | ||
) | ||
def ci(skip: List[str] = []) -> None: | ||
"""run all continuous integration tests & checks""" | ||
if "tests" not in skip: | ||
tests(filter=None) | ||
if "lint" not in skip: | ||
lint() | ||
if "format" not in skip: | ||
format(check=True) | ||
if "copyright" not in skip: | ||
copyright() | ||
|
||
|
||
# Script | ||
|
||
|
||
def _main() -> None: | ||
# Build an argparse command line by finding globals in the current module | ||
# that are marked via the @cli() decorator. Each one becomes a subcommand | ||
# running that function, usage "$ ./dev fn_name ...args". | ||
parser = argparse.ArgumentParser(description=__doc__) | ||
parser.set_defaults(command=ci) | ||
|
||
subs = parser.add_subparsers() | ||
for key, value in globals().items(): | ||
if hasattr(value, "cli_args"): | ||
sub = subs.add_parser(key.replace("_", "-"), help=value.__doc__) | ||
for args, kwargs in value.cli_args: | ||
sub.add_argument(*args, **kwargs) | ||
sub.set_defaults(command=value) | ||
|
||
cli_args = vars(parser.parse_args()) | ||
command = cli_args.pop("command") | ||
command(**cli_args) | ||
|
||
|
||
if __name__ == "__main__": | ||
_main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
black==23.1.0 | ||
docstring-parser==0.15 | ||
flake8==6.0.0 | ||
isort==5.12.0 | ||
mypy==1.0.1 | ||
numpy==1.24.2 | ||
pytest==7.2.1 | ||
pytest-cov==4.0.0 | ||
scipy==1.10.1 | ||
wandb==0.13.10 | ||
git+https://github.com/graphcore-research/poptorch-experimental-addons@14886d2285c3e45b0eadf4d719dae87d5f28b109 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
[options] | ||
packages = | ||
unit_scaling | ||
|
||
[mypy] | ||
pretty = true | ||
show_error_codes = true | ||
strict = true | ||
check_untyped_defs = true | ||
|
||
[mypy-poptorch.*] | ||
ignore_missing_imports = True | ||
|
||
[mypy-poptorch_experimental_addons.*] | ||
ignore_missing_imports = True | ||
|
||
[mypy-scipy.*] | ||
ignore_missing_imports = True | ||
|
||
[flake8] | ||
# See https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html | ||
max-line-length = 88 | ||
extend-ignore = E203,E731 | ||
|
||
[isort] | ||
profile = black | ||
|
||
[tool:pytest] | ||
addopts = --no-cov-on-fail | ||
|
||
[coverage:report] | ||
# fail_under = 100 | ||
skip_covered = true | ||
show_missing = true | ||
exclude_lines = | ||
pragma: no cover | ||
raise NotImplementedError | ||
assert False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Copyright (c) 2023 Graphcore Ltd. All rights reserved. |
Oops, something went wrong.