Node binary (or multi-label) classification via structured readout.
Inherits From: Task
runner.NodeBinaryClassification(
    key: str = 'seed',
    units: int = 1,
    *,
    feature_name: str = tfgnn.HIDDEN_STATE,
    readout_node_set: tfgnn.NodeSetName = '_readout',
    validate: bool = True,
    name: str = 'classification_logits',
    label_fn: Optional[LabelFn] = None,
    label_feature_name: Optional[str] = None
)
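The constructor arguments can be read as a small configuration record. The plain-Python stand-in below is hypothetical: ReadoutTaskConfig is an invented name, "hidden_state" is assumed to be the value of tfgnn.HIDDEN_STATE, and the check that exactly one of label_fn and label_feature_name is set is an assumption about how the two label options relate, not something stated on this page.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class ReadoutTaskConfig:
    """Hypothetical stand-in that records the NodeBinaryClassification arguments."""

    key: str = "seed"                   # which auxiliary readout edges to follow
    units: int = 1                      # number of independent binary labels
    feature_name: str = "hidden_state"  # assumed value of tfgnn.HIDDEN_STATE
    readout_node_set: str = "_readout"  # auxiliary node set holding readout nodes
    validate: bool = True
    name: str = "classification_logits"
    label_fn: Optional[Callable] = None
    label_feature_name: Optional[str] = None

    def __post_init__(self) -> None:
        # Assumption: the two label options are alternatives, so exactly one is set.
        if (self.label_fn is None) == (self.label_feature_name is None):
            raise ValueError("set exactly one of label_fn or label_feature_name")
        if self.units < 1:
            raise ValueError("units must be at least 1")
```

For example, ReadoutTaskConfig(units=3, label_feature_name="labels") describes three independent binary labels read from a feature called labels (a hypothetical name), while ReadoutTaskConfig() raises ValueError because no label source is given.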
gather_activations(inputs: GraphTensor) -> Field
Gather activations from auxiliary node (and edge) sets.
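In structured readout, extra edges connect each seed node to a node of the auxiliary readout_node_set, and gathering copies the seed's feature_name state to that readout node. The sketch below models this with plain lists. It is hypothetical and covers node sets only; the edge set naming readout_node_set + "/" + key (for example _readout/seed) and the rule that each readout node has exactly one incoming edge are assumptions.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Vector = List[float]
# edge set name -> (source node set, source node indices, target readout indices)
EdgeSets = Dict[str, Tuple[str, Sequence[int], Sequence[int]]]


def gather_activations_sketch(
    node_states: Dict[str, Sequence[Vector]],
    edge_sets: EdgeSets,
    num_readout_nodes: int,
    key: str = "seed",
    readout_node_set: str = "_readout",
) -> List[Vector]:
    """Copy one source-node state to each readout node along auxiliary edges."""
    prefix = f"{readout_node_set}/{key}"  # assumed edge set naming convention
    gathered: List[Optional[Vector]] = [None] * num_readout_nodes
    for name, (source_set, sources, targets) in edge_sets.items():
        if name != prefix and not name.startswith(prefix + "/"):
            continue  # edge set belongs to another key or to the ordinary graph
        states = node_states[source_set]
        for src, tgt in zip(sources, targets):
            if gathered[tgt] is not None:
                raise ValueError(f"readout node {tgt} is reached more than once")
            gathered[tgt] = list(states[src])
    if any(v is None for v in gathered):
        raise ValueError("some readout node has no incoming readout edge")
    return gathered  # ordered by readout node index
```

With states = {"paper": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]} and edges = {"_readout/seed": ("paper", [2, 0], [0, 1])}, gather_activations_sketch(states, edges, 2) returns [[0.5, 0.6], [0.1, 0.2]]: readout node 0 takes paper 2 and readout node 1 takes paper 0.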
losses() -> interfaces.Losses
Returns arbitrary task-specific losses.
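This page does not list which losses are returned. For binary or multi-label logits, the standard choice is sigmoid cross-entropy applied to each label independently, so as an illustration of that loss (not of what losses() actually returns), here is a numerically stable version computed from raw logits:

```python
import math
from typing import Sequence


def sigmoid_bce_from_logits(logits: Sequence[float], labels: Sequence[float]) -> float:
    """Mean binary cross-entropy over independent labels, computed from logits.

    Uses the stable form max(z, 0) - z * y + log(1 + exp(-|z|)), which equals
    -[y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z))] without overflow.
    """
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    total = 0.0
    for z, y in zip(logits, labels):
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)
```

A logit of 0 gives a loss of log 2 for either label value, and very large logits do not overflow because exp is only ever called on a non-positive argument.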
metrics() -> interfaces.Metrics
Returns arbitrary task-specific metrics.
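The concrete metrics are not listed here either. Precision and recall are common metrics for thresholded binary predictions; the sketch below computes them straight from logits, using the fact that a logit above 0 corresponds to a sigmoid probability above 0.5. For multi-label outputs, the per-label logits and labels can be flattened into single lists (micro-averaging). It illustrates the metrics themselves, not the library's metric objects.

```python
from typing import Sequence, Tuple


def precision_recall_from_logits(
    logits: Sequence[float], labels: Sequence[int], threshold: float = 0.0
) -> Tuple[float, float]:
    """Precision and recall for binary predictions made by thresholding logits."""
    tp = fp = fn = 0
    for z, y in zip(logits, labels):
        predicted = z > threshold  # logit > 0  <=>  sigmoid(z) > 0.5
        if predicted and y == 1:
            tp += 1
        elif predicted and y == 0:
            fp += 1
        elif not predicted and y == 1:
            fn += 1
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall
```

For logits [2.0, 1.0, 0.5, -3.0] and labels [1, 0, 0, 1], three examples are predicted positive but only one is correct, and one of the two true positives is missed, so precision is 1/3 and recall is 1/2.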
predict(inputs: tfgnn.GraphTensor) -> interfaces.Predictions
Apply a linear head for classification.
Args | |
---|---|
inputs | A tfgnn.GraphTensor for classification.
Returns | |
---|---|
The classification logits. |
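A linear head maps each gathered activation x to units logits, logits[j] = sum_i x[i] * W[i][j] + b[j], one logit per binary label; no sigmoid is applied because predict(...) returns logits. The plain-Python sketch below uses hypothetical weights; in a real model W and b would be trainable parameters.

```python
from typing import List, Sequence


def linear_head(
    activation: Sequence[float],
    weights: Sequence[Sequence[float]],
    bias: Sequence[float],
) -> List[float]:
    """Map one readout activation to len(bias) logits (one per binary label)."""
    units = len(bias)
    if len(weights) != len(activation):
        raise ValueError("need one weight row per activation dimension")
    if any(len(row) != units for row in weights):
        raise ValueError("every weight row must have one entry per unit")
    # logits[j] = sum_i activation[i] * weights[i][j] + bias[j]
    return [
        sum(x * row[j] for x, row in zip(activation, weights)) + bias[j]
        for j in range(units)
    ]
```

For instance, linear_head([1.0, 2.0], [[0.5, -1.0], [0.25, 0.0]], [0.0, 1.0]) returns [1.0, 0.0]. With units=1 there is a single logit for plain binary classification; with more units each logit is an independent binary decision, which is what makes the task multi-label rather than multi-class.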
preprocess(inputs: GraphTensor) -> tuple[GraphTensor, Field]
Preprocesses a scalar (after merge_batch_to_components()) GraphTensor.
This function uses the Keras functional API to define non-trainable transformations of the symbolic input GraphTensor, which get executed during dataset preprocessing in a tf.data.Dataset.map(...) operation. It has two responsibilities:
- Splitting the training label out of the input. The label must be returned as a separate tensor or mapping of tensors.
- Optionally, transforming input features. Some advanced modeling techniques require running the same base GNN on multiple different transformations, so this function may return a single GraphTensor or a non-empty sequence of GraphTensors. The corresponding base GNN output for each GraphTensor is provided to the predict(...) method.
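The first responsibility, splitting off the label, can be illustrated with nested dicts standing in for a GraphTensor's node-set features. split_label and this dict layout are hypothetical, and reading the label from readout_node_set under label_feature_name is an assumption about how that constructor argument is used. The input is left untouched; the function returns a copy without the label feature plus the label itself, so the model never sees its own target.

```python
from typing import Any, Dict, Tuple

# Hypothetical layout: node set name -> feature name -> feature values.
Graph = Dict[str, Dict[str, Any]]


def split_label(
    graph: Graph,
    label_feature_name: str,
    readout_node_set: str = "_readout",
) -> Tuple[Graph, Any]:
    """Return (graph without the label feature, label) without mutating `graph`."""
    features = dict(graph[readout_node_set])  # shallow copy of this node set's features
    label = features.pop(label_feature_name)  # KeyError if the label is missing
    stripped = dict(graph)                    # shallow copy of the whole graph
    stripped[readout_node_set] = features
    return stripped, label
```

With g = {"paper": {"hidden_state": [[0.1], [0.2]]}, "_readout": {"labels": [[1, 0]]}}, split_label(g, "labels") returns a graph whose "_readout" features are empty, together with the label [[1, 0]]; g itself still contains the label.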
Args | |
---|---|
inputs | A symbolic Keras GraphTensor for processing.
Returns | |
---|---|
A tuple of processed GraphTensor(s) and a Field (or a mapping of Fields) to be used as labels. |