The DARTS cfg file is here; see also the ENAS cfg breakup for explanations of the cfg sections.
Every `rollout_type` in the DARTS configuration file is `differentiable`.
- class: `CNNSearchSpace` (in `aw_nas.common`)
- 8 cells
- `[0,0,1,0,0,1,0,0]` (1 denotes where the reduction cell is)
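
To make the cell layout concrete, here is a minimal sketch that walks the 8-cell list and reports which cells are reduction cells. It is not aw_nas code: `describe_layout`, the 32x32 input size, the 16 initial channels, and the "halve the resolution, double the channels" rule are all assumptions chosen for illustration, not values read from the cfg.

```python
# Hypothetical illustration, not part of aw_nas.
cell_layout = [0, 0, 1, 0, 0, 1, 0, 0]  # 1 marks a reduction cell


def describe_layout(layout, input_size=32, init_channels=16):
    """Return (cell_index, kind, spatial_size, channels) for every cell.

    Assumption: each reduction cell halves the spatial size and doubles
    the channel count. The defaults are illustrative only.
    """
    size, channels = input_size, init_channels
    rows = []
    for idx, is_reduce in enumerate(layout):
        if is_reduce:
            size //= 2
            channels *= 2
        rows.append((idx, "reduce" if is_reduce else "normal", size, channels))
    return rows


rows = describe_layout(cell_layout)
for row in rows:
    print(row)
```

Under these assumptions the two reduction cells sit at indices 2 and 5, splitting the network into three stages of normal cells.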
- class: `DifferentiableController` (in `aw_nas.controller.differentiable`)
- Call: `aw_nas.trainer.simple.SimpleTrainer._controller_update`
  - `rollout = controller.sample()` - sample an architecture from the search space
  - `evaluator.evaluate_rollout(rollout)` - call `evaluator.evaluate_rollout` to evaluate the architecture
  - `controller.step()` - update the controller with the evaluated rollouts. In DARTS, the controller parameters are the architecture parameters α, which are updated with gradients computed on the validation set
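
The sample / evaluate / step sequence above can be sketched with a toy, dependency-free example. Everything here is hypothetical: `ToyController`, `ToyEvaluator`, `ToyRollout`, the fixed per-operation validation losses, and the simplified `step(rollouts)` signature are assumptions for illustration and do not mirror the real aw_nas classes. The architecture parameters `alpha` are logits over three candidate operations, the rollout carries their softmax weights, and the "validation loss" is the weighted sum of the per-operation losses.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


class ToyRollout:
    def __init__(self, weights):
        self.weights = weights  # softmax(alpha), the relaxed architecture
        self.grads = None       # d(val_loss)/d(alpha), filled by the evaluator


class ToyController:
    def __init__(self, num_ops, lr=0.5):
        self.alpha = [0.0] * num_ops  # architecture parameters
        self.lr = lr

    def sample(self):
        return ToyRollout(softmax(self.alpha))

    def step(self, rollouts):
        # Plain gradient descent on alpha using validation gradients.
        for rollout in rollouts:
            self.alpha = [a - self.lr * g
                          for a, g in zip(self.alpha, rollout.grads)]


class ToyEvaluator:
    def __init__(self, val_losses):
        self.val_losses = val_losses  # assumed fixed loss of each candidate op

    def evaluate_rollout(self, rollout):
        w = rollout.weights
        loss = sum(wk * lk for wk, lk in zip(w, self.val_losses))
        # Analytic gradient of sum_k softmax(alpha)_k * l_k w.r.t. alpha_j.
        rollout.grads = [wk * (lk - loss) for wk, lk in zip(w, self.val_losses)]
        return loss


def controller_update(controller, evaluator, steps):
    losses = []
    for _ in range(steps):
        rollout = controller.sample()
        losses.append(evaluator.evaluate_rollout(rollout))
        controller.step([rollout])
    return losses


controller = ToyController(num_ops=3)
evaluator = ToyEvaluator(val_losses=[0.9, 0.2, 0.6])
losses = controller_update(controller, evaluator, steps=50)
best_op = max(range(3), key=lambda k: controller.alpha[k])
print(best_op, losses[0], losses[-1])
```

In this toy setting `alpha` drifts toward the operation with the lowest validation loss, which is the behaviour the gradient-based controller update is meant to produce.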
- class: `DiffSuperNet` (in `aw_nas.weights_manager.diff_super_net`)
- Call: the `cand_net.eval_data`, `cand_net.eval_queue`, `cand_net.gradients` calls in `aw_nas.trainer.simple.SimpleTrainer.evaluate_rollout` and `aw_nas.trainer.simple.SimpleTrainer.update_evaluator`
- Interfaces:
  - A `DiffSubCandidateNet` instance holds a reference to the architecture parameter `arch`, and its `forward` call calls `self.super_net.forward(inputs, self.arch)` (see `aw_nas.weights_manager.diff_super_net.DiffSubCandidateNet.forward`)
  - `diff_supernet.forward(inputs, arch)` - forward a sub-net/candidate net of the supernet using the specified `arch` parameter
  - `diff_supernet.assemble_candidate(rollout)` - assemble an architecture rollout into a candidate network `DiffSubCandidateNet`
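
The relationship between the supernet and its candidate nets can be illustrated with a small sketch. The classes below are hypothetical stand-ins (`ToySuperNet`, `ToyCandidateNet`), the three scalar "operations" and the dict-shaped rollout are assumptions, and the mixed output is simply the `arch`-weighted sum of the operations. Only the delegation pattern follows the description above: the candidate keeps a reference to `arch` and forwards through the supernet.

```python
class ToySuperNet:
    """Stand-in for a differentiable supernet holding all shared ops."""

    def __init__(self):
        # Three toy candidate operations on a scalar input:
        # "none" (zero), identity, and scale-by-2.
        self.ops = [lambda x: 0.0, lambda x: x, lambda x: 2.0 * x]

    def forward(self, inputs, arch):
        # Weighted mixture of all candidate ops, weighted by `arch`.
        return sum(w * op(inputs) for w, op in zip(arch, self.ops))

    def assemble_candidate(self, rollout):
        # Assumption: the rollout exposes its architecture under "arch".
        return ToyCandidateNet(self, rollout["arch"])


class ToyCandidateNet:
    """Stand-in for a candidate net: owns no ops, only references."""

    def __init__(self, super_net, arch):
        self.super_net = super_net
        self.arch = arch  # a reference, not a copy

    def forward(self, inputs):
        return self.super_net.forward(inputs, self.arch)


super_net = ToySuperNet()
rollout = {"arch": [0.2, 0.3, 0.5]}
cand_net = super_net.assemble_candidate(rollout)
mixed_output = cand_net.forward(4.0)

# Because the candidate holds a reference, an in-place change to the
# architecture parameters is visible on the next forward call.
rollout["arch"][:] = [0.0, 1.0, 0.0]
identity_output = cand_net.forward(4.0)
print(mixed_output, identity_output)
```

Keeping only a reference to `arch` is what lets a single set of shared weights serve every candidate: assembling a candidate is cheap, and no weights are copied.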
- class: `MepaEvaluator` (in `aw_nas.evaluator.mepa`). Compatible with shared-weights evaluator; please refer to the config breakup for more details
- Call: `aw_nas.trainer.simple.SimpleTrainer._evaluator_update()`
- Interfaces:
  - `evaluator.update_rollout(rollout)` - do nothing
  - `evaluator.update_evaluator(controller)` - optimize the shared weights on the training data queue (`mepa_queue` in the code)
  - `evaluator.evaluate_rollout(rollout)` - evaluate the rollout on the validation data queue (`controller_queue` in the code)
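
The three evaluator interfaces can be sketched as follows. This is a toy under stated assumptions, not the `MepaEvaluator` implementation: the single scalar shared weight, the squared-error loss, the `scale` field on the rollout, the `perf` attribute, and the `FixedController` helper are all hypothetical. Only the split between the two queues follows the description above: shared weights are trained on `mepa_queue`, and rollouts are scored on `controller_queue`.

```python
import itertools


class ToyRollout:
    def __init__(self, scale):
        self.scale = scale  # stand-in for an architecture decision
        self.perf = None    # filled in by evaluate_rollout


class FixedController:
    """Hypothetical controller that always samples the same rollout."""

    def sample(self):
        return ToyRollout(scale=1.0)


class ToyEvaluator:
    def __init__(self, train_data, valid_data, lr=0.05):
        self.weight = 0.0  # the shared weight
        self.lr = lr
        self.mepa_queue = itertools.cycle(train_data)        # training data
        self.controller_queue = itertools.cycle(valid_data)  # validation data

    def update_rollout(self, rollout):
        # Intentionally a no-op, as in the interface description.
        pass

    def update_evaluator(self, controller):
        # One SGD step on the shared weight using a training batch.
        rollout = controller.sample()
        x, y = next(self.mepa_queue)
        pred = self.weight * rollout.scale * x
        grad = 2.0 * (pred - y) * rollout.scale * x
        self.weight -= self.lr * grad

    def evaluate_rollout(self, rollout):
        # Score the rollout on a validation batch; higher perf is better.
        x, y = next(self.controller_queue)
        loss = (self.weight * rollout.scale * x - y) ** 2
        rollout.perf = -loss
        return rollout


controller = FixedController()
evaluator = ToyEvaluator(train_data=[(1.0, 3.0), (2.0, 6.0)],
                         valid_data=[(1.5, 4.5)])

before = evaluator.evaluate_rollout(controller.sample())
for _ in range(100):
    evaluator.update_evaluator(controller)
after = evaluator.evaluate_rollout(controller.sample())
print(before.perf, after.perf, evaluator.weight)
```

Scoring on a queue that is never used for weight updates keeps the signal passed back to the controller separate from the data the shared weights were fitted on.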