Skip to content

Commit

Permalink
Fix mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
PGijsbers committed Aug 2, 2022
1 parent e1edd75 commit 9b3d622
Show file tree
Hide file tree
Showing 11 changed files with 63 additions and 58 deletions.
38 changes: 15 additions & 23 deletions gama/configuration/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def pset_from_config(
maps Primitive name to a check for the validity of the hp configuration
"""

pset = defaultdict(list)
pset: Dict[str, List[Union[Primitive, Terminal]]] = defaultdict(list)
parameter_checks = {}

# Make sure the str-keys are evaluated first, they describe shared hyperparameters.
Expand All @@ -39,9 +39,9 @@ def pset_from_config(
# Specification of shared hyperparameters
for value in values:
pset[key].append(Terminal(value=value, output=key, identifier=key))
elif isinstance(key, object):
elif isinstance(key, type):
# Specification of operator (learner, preprocessor)
hyperparameter_types = []
hyperparameter_types: List[str] = []
for name, param_values in sorted(values.items()):
# We construct a new type for each hyperparameter, so we can specify
# it as terminal type, making sure it matches with expected
Expand All @@ -68,36 +68,28 @@ def pset_from_config(

# After registering the hyperparameter types,
# we can register the operator itself.
transformer_tags = [
"DATA_PREPROCESSING",
"FEATURE_SELECTION",
"DATA_TRANSFORMATION",
]
if issubclass(key, sklearn.base.TransformerMixin) or (
hasattr(key, "metadata")
and key.metadata.query()["primitive_family"] in transformer_tags
):
if issubclass(key, sklearn.base.TransformerMixin):
pset[DATA_TERMINAL].append(
Primitive(
input=hyperparameter_types, output=DATA_TERMINAL, identifier=key
input=tuple(hyperparameter_types),
output=DATA_TERMINAL,
identifier=key,
)
)
elif issubclass(key, sklearn.base.ClassifierMixin) or (
hasattr(key, "metadata")
and key.metadata.query()["primitive_family"] == "CLASSIFICATION"
):
elif issubclass(key, sklearn.base.ClassifierMixin):
pset["prediction"].append(
Primitive(
input=hyperparameter_types, output="prediction", identifier=key
input=tuple(hyperparameter_types),
output="prediction",
identifier=key,
)
)
elif issubclass(key, sklearn.base.RegressorMixin) or (
hasattr(key, "metadata")
and key.metadata.query()["primitive_family"] == "REGRESSION"
):
elif issubclass(key, sklearn.base.RegressorMixin):
pset["prediction"].append(
Primitive(
input=hyperparameter_types, output="prediction", identifier=key
input=tuple(hyperparameter_types),
output="prediction",
identifier=key,
)
)
else:
Expand Down
21 changes: 14 additions & 7 deletions gama/gama.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
] = "filled_in_by_child_class",
regularize_length: bool = True,
max_pipeline_length: Optional[int] = None,
config: Dict = None,
config: Dict[Union[str, object], Any] = {},
random_state: Optional[int] = None,
max_total_time: int = 3600,
max_eval_time: Optional[int] = None,
Expand Down Expand Up @@ -288,7 +288,7 @@ def __init__(
)
max_start_length = 3 if max_pipeline_length is None else max_pipeline_length
self._operator_set = OperatorSet(
mutate=partial(
mutate=partial( # type: ignore #https://github.com/python/mypy/issues/1484
random_valid_mutation_in_place,
primitive_set=self._pset,
max_length=max_pipeline_length,
Expand Down Expand Up @@ -335,7 +335,8 @@ def _prepare_for_prediction(
) -> pd.DataFrame:
if isinstance(x, np.ndarray):
x = self._np_to_matching_dataframe(x)
x = self._basic_encoding_pipeline.transform(x)
if self._basic_encoding_pipeline:
x = self._basic_encoding_pipeline.transform(x)
return x

def _predict(self, x: pd.DataFrame) -> np.ndarray:
Expand Down Expand Up @@ -625,7 +626,7 @@ def _search_phase(

def export_script(
self, file: Optional[str] = "gama_pipeline.py", raise_if_exists: bool = False
) -> Optional[str]:
) -> str:
"""Export a Python script which sets up the best found pipeline.
Can only be called after `fit`.
Expand Down Expand Up @@ -653,6 +654,10 @@ def export_script(
raise_if_exists: bool (default=False)
If True, raise an error if the file already exists.
If False, overwrite `file` if it already exists.
Returns
-------
script: str
"""
if self.model is None:
raise RuntimeError(STR_NO_OPTIMAL_PIPELINE)
Expand All @@ -670,8 +675,7 @@ def export_script(
with open(file, "w") as fh:
fh.write(script_text)
subprocess.call(["black", file])
else:
return script_text
return script_text

def _safe_outside_call(self, fn: Callable) -> None:
"""Calls fn logging and ignoring all exceptions except TimeoutException."""
Expand All @@ -686,7 +690,10 @@ def _safe_outside_call(self, fn: Callable) -> None:
# Note KeyboardInterrupts are not exceptions and get elevated to the caller.
log.warning("Exception during callback.", exc_info=True)

if self._time_manager.current_activity.exceeded_limit(margin=3.0):
if (
self._time_manager.current_activity
and self._time_manager.current_activity.exceeded_limit(margin=3.0)
):
# If time exceeds during a safe callback, the timeout exception *might*
# have been swallowed. This can result in GAMA running indefinitely.
# However in rare conditions it can be that the TimeoutException is still
Expand Down
4 changes: 2 additions & 2 deletions gama/genetic_programming/components/individual.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def __init__(
self._id = uuid.uuid4()
self._to_pipeline = to_pipeline

def __eq__(self, other: "Individual") -> bool:
def __eq__(self, other) -> bool:
return isinstance(other, Individual) and other._id == self._id

def __hash__(self) -> str:
def __hash__(self) -> int:
return hash(self._id)

def __str__(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion gama/genetic_programming/components/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class Primitive(NamedTuple):
E.g. a preprocessing or classification algorithm.
"""

input: Tuple[str]
input: Tuple[str, ...]
output: str
identifier: Callable

Expand Down
10 changes: 5 additions & 5 deletions gama/genetic_programming/components/primitive_node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import List, Union, cast
from .terminal import DATA_TERMINAL, Terminal
from .primitive import Primitive

Expand Down Expand Up @@ -51,9 +51,9 @@ def str_nonrecursive(self) -> str:

def copy(self) -> "PrimitiveNode":
"""Copies the object. Shallow for terminals, deep for data_node."""
if self._data_node == DATA_TERMINAL:
data_node_copy = DATA_TERMINAL
else:
if isinstance(self._data_node, str) and self._data_node == DATA_TERMINAL:
data_node_copy = DATA_TERMINAL # type: Union[str, PrimitiveNode]
elif isinstance(self._data_node, PrimitiveNode):
data_node_copy = self._data_node.copy()
return PrimitiveNode(
primitive=self._primitive,
Expand Down Expand Up @@ -101,7 +101,7 @@ def from_string(cls, string: str, primitive_set: dict) -> "PrimitiveNode":
raise ValueError(f"terminals {missing} for primitive {primitive}")
last_node = cls(primitive, last_node, terminals)

return last_node
return cast(PrimitiveNode, last_node)


def find_primitive(primitive_set: dict, primitive_string: str) -> Primitive:
Expand Down
13 changes: 8 additions & 5 deletions gama/genetic_programming/nsga2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ class NSGAMeta:
def __init__(self, obj: object, metrics: List[Callable]):
self.obj = obj
self.values = tuple((m(obj) for m in metrics))
self.rank = None
self.distance = 0
self.dominating = []
self.rank = 0
self.distance = 0.0
self.dominating: List["NSGAMeta"] = []
self.domination_counter = 0

def dominates(self, other: "NSGAMeta") -> bool:
Expand Down Expand Up @@ -106,11 +106,14 @@ def nsga2(
selection += fronts[i]
else:
# Only the least crowded remainder is selected
s = sorted(fronts[i], key=cmp_to_key(lambda x, y: x.crowd_compare(y)))
s = sorted(
fronts[i],
key=cmp_to_key(lambda x, y: x.crowd_compare(y)), # type: ignore
)
selection += s[: (n - len(selection))] # Fill up to n
i += 1

return selection if return_meta else [s.obj for s in selection]
return selection if return_meta else [s.obj for s in selection] # type: ignore


def fast_non_dominated_sort(P: List[NSGAMeta]) -> List[List[NSGAMeta]]:
Expand Down
13 changes: 7 additions & 6 deletions gama/genetic_programming/operator_set.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from typing import Callable, List, Optional, Tuple, Any
from typing import Callable, Dict, List, Optional, Tuple, Any
from gama.genetic_programming.components.primitive_node import PrimitiveNode

from sklearn.pipeline import Pipeline

Expand All @@ -18,24 +19,24 @@ def __init__(
mutate: Callable[[Individual], None],
mate: Callable[[Individual, Individual], Tuple[Individual, Individual]],
create_from_population: Callable[[Any], List[Individual]],
create_new: Callable[[], Individual],
create_new: Callable[[], PrimitiveNode],
compile_: Callable[[Individual], Pipeline],
eliminate: Callable[[List[Individual], int], List[Individual]],
evaluate_callback: Callable[[Evaluation], None],
max_retry: int = 50,
completed_evaluations: Optional[List[Individual]] = None,
completed_evaluations: Optional[Dict[str, Evaluation]] = None,
):
self._mutate = mutate
self._mate = mate
self._create_from_population = create_from_population
self._create_new = create_new
self._compile = compile_
self._safe_compile = None
self._safe_compile: Optional[Callable[[Individual], Pipeline]] = None
self._eliminate = eliminate
self._max_retry = max_retry
self._evaluate = None
self._evaluate_callback = evaluate_callback
self.evaluate = None
self.evaluate: Optional[Callable[..., Evaluation]] = None

self._completed_evaluations = completed_evaluations

Expand Down Expand Up @@ -91,7 +92,7 @@ def individual(self, *args, **kwargs) -> Individual:
ind.meta["origin"] = "new"
return ind

def create(self, *args, **kwargs) -> Individual:
def create(self, *args, **kwargs) -> List[Individual]:
return self._create_from_population(self, *args, **kwargs)

def eliminate(self, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion gama/logging/GamaReport.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,4 @@ def init_to_hps(init_line: str) -> Dict[str, str]:
# only supports one nested level - will do proper parsing later
for token in ["()", "(", ")", ",,"]:
all_arguments = all_arguments.replace(token, ",")
return dict(hp.split("=") for hp in all_arguments.split(",")) # type: ignore
return dict(hp.split("=") for hp in all_arguments.split(","))
2 changes: 1 addition & 1 deletion gama/logging/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ def register_stream_log(verbosity: int) -> None:
]

stdout_streamhandler = logging.StreamHandler(sys.stdout)
stdout_streamhandler.tag = "machine_set"
setattr(stdout_streamhandler, "tag", "machine_set")
stdout_streamhandler.setLevel(verbosity)
gama_log.addHandler(stdout_streamhandler)
6 changes: 4 additions & 2 deletions gama/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import logging
import os
import pickle
from typing import List, Union

from pandas.api.types import is_categorical_dtype

from gama import GamaClassifier, GamaRegressor
from gama.data_loading import X_y_from_file
from gama.gama import Gama


def make_parser():
Expand Down Expand Up @@ -127,7 +129,7 @@ def make_parser():
return parser


def main(command: str = ""):
def main(command: Union[str, List[str]] = ""):
parser = make_parser()

if isinstance(command, str):
Expand Down Expand Up @@ -172,7 +174,7 @@ def main(command: str = ""):
configuration["scoring"] = args.metric

if args.mode == "regression":
automl = GamaRegressor(**configuration)
automl: Gama = GamaRegressor(**configuration)
elif args.mode == "classification":
automl = GamaClassifier(**configuration)
else:
Expand Down
10 changes: 5 additions & 5 deletions gama/utilities/evaluation_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def __init__(
):
self.individual: Individual = individual
self.score = score
self._estimators: Optional[List] = [] if estimators is None else estimators
self._estimators: List[BaseEstimator] = [] if estimators is None else estimators
self.start_time = start_time
self.duration = duration
self.error = error
self.pid = pid
self._cache_file = None
self._cache_file = ""

if isinstance(predictions, (pd.Series, pd.DataFrame)):
predictions = predictions.values
Expand All @@ -52,7 +52,7 @@ def to_disk(self, directory: str) -> None:
def remove_from_disk(self) -> None:
"""Remove the related file from disk."""
os.remove(os.path.join(self._cache_file))
self._cache_file = None
self._cache_file = ""

@property
def estimators(self) -> List[BaseEstimator]:
Expand Down Expand Up @@ -234,7 +234,7 @@ def save_evaluation(self, evaluation: Evaluation) -> None:
self._process_predictions(evaluation)

if evaluation.error is not None:
evaluation._estimators, evaluation._predictions = None, None
evaluation._estimators, evaluation._predictions = [], None
self.other_evaluations.append(evaluation)
elif self._m is None or self._m > len(self.top_evaluations):
evaluation.to_disk(self._cache)
Expand All @@ -243,7 +243,7 @@ def save_evaluation(self, evaluation: Evaluation) -> None:
removed = heapq.heappushpop(self.top_evaluations, evaluation)
if removed == evaluation:
# new evaluation is not in heap, big memory items may be discarded
removed._predictions, removed._estimators = None, None
removed._estimators, removed._predictions = [], None
else:
# new evaluation is now on the heap, remove old from disk
evaluation.to_disk(self._cache)
Expand Down

0 comments on commit 9b3d622

Please sign in to comment.