Skip to content

Commit

Permalink
Revert "Add WrapperAware trait"
Browse files Browse the repository at this point in the history
This reverts commit 241abc4.
  • Loading branch information
ElGigi committed Jan 17, 2024
1 parent e8aa78e commit c0d8bf6
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 79 deletions.
44 changes: 42 additions & 2 deletions src/GridSearch.php
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
use Rubix\ML\Specifications\EstimatorIsCompatibleWithMetric;
use Rubix\ML\Specifications\SamplesAreCompatibleWithEstimator;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Traits\WrapperAware;

/**
* Grid Search
Expand All @@ -42,7 +41,7 @@
*/
class GridSearch implements Wrapper, Learner, Parallel, Verbose, Persistable
{
use AutotrackRevisions, Multiprocessing, LoggerAware, WrapperAware;
use AutotrackRevisions, Multiprocessing, LoggerAware;

/**
* The class name of the base estimator.
Expand Down Expand Up @@ -72,6 +71,13 @@ class GridSearch implements Wrapper, Learner, Parallel, Verbose, Persistable
*/
protected \Rubix\ML\CrossValidation\Validator $validator;

/**
* The base estimator instance.
*
* @var Learner
*/
protected \Rubix\ML\Learner $base;

/**
* The validation scores obtained from the last search.
*
Expand Down Expand Up @@ -173,6 +179,18 @@ public function __construct(
$this->backend = new Serial();
}

/**
* Return the estimator type.
*
* @internal
*
* @return EstimatorType
*/
public function type() : EstimatorType
{
return $this->base->type();
}

/**
* Return the data types that the estimator is compatible with.
*
Expand Down Expand Up @@ -214,6 +232,16 @@ public function trained() : bool
return $this->base->trained();
}

/**
* Return the base learner instance.
*
* @return Estimator
*/
public function base() : Estimator
{
return $this->base;
}

/**
* Train one estimator per combination of parameters given by the grid and
* assign the best one as the base estimator of this instance.
Expand Down Expand Up @@ -276,6 +304,18 @@ public function train(Dataset $dataset) : void
}
}

/**
* Make a prediction on a given sample dataset.
*
* @param Dataset $dataset
* @throws Exceptions\RuntimeException
* @return mixed[]
*/
public function predict(Dataset $dataset) : array
{
return $this->base->predict($dataset);
}

/**
* The callback that executes after the cross validation task.
*
Expand Down
53 changes: 51 additions & 2 deletions src/PersistentModel.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
use Rubix\ML\AnomalyDetectors\Scoring;
use Rubix\ML\Exceptions\InvalidArgumentException;
use Rubix\ML\Exceptions\RuntimeException;
use Rubix\ML\Traits\WrapperAware;

/**
* Persistent Model
Expand All @@ -24,7 +23,12 @@
*/
class PersistentModel implements Wrapper, Learner, Probabilistic, Scoring
{
use WrapperAware;
/**
* The persistable base learner.
*
* @var Learner
*/
protected \Rubix\ML\Learner $base;

/**
* The persister used to interface with the storage layer.
Expand Down Expand Up @@ -80,6 +84,30 @@ public function __construct(Learner $base, Persister $persister, ?Serializer $se
$this->serializer = $serializer ?? new RBX();
}

/**
* Return the estimator type.
*
* @internal
*
* @return EstimatorType
*/
public function type() : EstimatorType
{
return $this->base->type();
}

/**
* Return the data types that the estimator is compatible with.
*
* @internal
*
* @return list<\Rubix\ML\DataType>
*/
public function compatibility() : array
{
return $this->base->compatibility();
}

/**
* Return the settings of the hyper-parameters in an associative array.
*
Expand All @@ -106,6 +134,16 @@ public function trained() : bool
return $this->base->trained();
}

/**
* Return the base estimator instance.
*
* @return Estimator
*/
public function base() : Estimator
{
return $this->base;
}

/**
* Save the model to storage.
*/
Expand All @@ -130,6 +168,17 @@ public function train(Dataset $dataset) : void
$this->base->train($dataset);
}

/**
* Make a prediction on a given sample dataset.
*
* @param Dataset $dataset
* @return mixed[]
*/
public function predict(Dataset $dataset) : array
{
return $this->base->predict($dataset);
}

/**
* Estimate the joint probabilities for each possible outcome.
*
Expand Down
44 changes: 42 additions & 2 deletions src/Pipeline.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

use Rubix\ML\Helpers\Params;
use Rubix\ML\Datasets\Dataset;
use Rubix\ML\Traits\WrapperAware;
use Rubix\ML\Transformers\Elastic;
use Rubix\ML\Transformers\Stateful;
use Rubix\ML\Transformers\Transformer;
Expand All @@ -28,7 +27,7 @@
*/
class Pipeline implements Online, Probabilistic, Scoring, Persistable, Wrapper
{
use AutotrackRevisions, WrapperAware;
use AutotrackRevisions;

/**
* A list of transformers to be applied in series.
Expand All @@ -39,6 +38,13 @@ class Pipeline implements Online, Probabilistic, Scoring, Persistable, Wrapper
//
];

/**
* An instance of a base estimator to receive the transformed data.
*
* @var Estimator
*/
protected \Rubix\ML\Estimator $base;

/**
* Should we update the elastic transformers during partial train?
*
Expand Down Expand Up @@ -66,6 +72,30 @@ public function __construct(array $transformers, Estimator $base, bool $elastic
$this->elastic = $elastic;
}

/**
* Return the estimator type.
*
* @internal
*
* @return EstimatorType
*/
public function type() : EstimatorType
{
return $this->base->type();
}

/**
* Return the data types that the estimator is compatible with.
*
* @internal
*
* @return list<\Rubix\ML\DataType>
*/
public function compatibility() : array
{
return $this->base->compatibility();
}

/**
* Return the settings of the hyper-parameters in an associative array.
*
Expand Down Expand Up @@ -94,6 +124,16 @@ public function trained() : bool
: true;
}

/**
* Return the base estimator instance.
*
* @return Estimator
*/
public function base() : Estimator
{
return $this->base;
}

/**
* Run the training dataset through all transformers in order and use the
* transformed dataset to train the estimator.
Expand Down
73 changes: 0 additions & 73 deletions src/Traits/WrapperAware.php

This file was deleted.

0 comments on commit c0d8bf6

Please sign in to comment.