runner.NodeBinaryClassification

Node binary (or multi-label) classification via structured readout.

Inherits From: Task

runner.NodeBinaryClassification(
    key: str = 'seed',
    units: int = 1,
    *,
    feature_name: str = tfgnn.HIDDEN_STATE,
    readout_node_set: tfgnn.NodeSetName = '_readout',
    validate: bool = True,
    name: str = 'classification_logits',
    label_fn: Optional[LabelFn] = None,
    label_feature_name: Optional[str] = None
)

Args

key A string key to select between possibly multiple named readouts.
units The units for the classification head. (Typically 1 for binary classification and the number of labels for multi-label classification.)
feature_name The name of the feature to read. If unset, tfgnn.HIDDEN_STATE will be read.
readout_node_set A string, defaults to "_readout". This is used as the name for the readout node set and as a name prefix for its edge sets.
validate Setting this to false disables the validity checks for the auxiliary edge sets. This is strongly discouraged unless great care is taken to run tfgnn.validate_graph_tensor_for_readout() earlier on structurally unchanged GraphTensors.
name The classification head's layer name. To control the naming of saved model outputs see the runner model exporters (e.g., KerasModelExporter).
label_fn A label extraction function. This function mutates the input GraphTensor. Mutually exclusive with label_feature_name.
label_feature_name A label feature name for readout from the auxiliary '_readout' node set. Readout does not mutate the input GraphTensor. Mutually exclusive with label_fn.
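
The argument list above says that label_fn and label_feature_name are mutually exclusive. The sketch below shows one way to collect constructor arguments and reject that combination up front. The helper node_binary_classification_kwargs and the feature names "label" and "labels" are hypothetical and are not part of the runner API. Whether the task accepts neither argument is not stated here, so the helper does not check for it.

```python
from typing import Any, Callable, Dict, Optional


def node_binary_classification_kwargs(
    *,
    key: str = "seed",
    units: int = 1,
    label_fn: Optional[Callable[[Any], Any]] = None,
    label_feature_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: gathers keyword arguments for the task.

    Only the combination documented as mutually exclusive is rejected.
    """
    if label_fn is not None and label_feature_name is not None:
        raise ValueError(
            "label_fn and label_feature_name are mutually exclusive.")
    kwargs: Dict[str, Any] = {"key": key, "units": units}
    if label_fn is not None:
        kwargs["label_fn"] = label_fn
    if label_feature_name is not None:
        kwargs["label_feature_name"] = label_feature_name
    return kwargs


# Binary classification: one unit, label read from an assumed "label" feature.
binary_kwargs = node_binary_classification_kwargs(label_feature_name="label")

# Multi-label classification: one unit per label (5 is an arbitrary example).
multi_label_kwargs = node_binary_classification_kwargs(
    units=5, label_feature_name="labels")
```

With TF-GNN installed, such a dictionary could be forwarded as runner.NodeBinaryClassification(**binary_kwargs).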

Methods

gather_activations

gather_activations(
    inputs: GraphTensor
) -> Field

Gather activations from auxiliary node (and edge) sets.

losses

losses() -> interfaces.Losses

Returns arbitrary task-specific losses.

metrics

metrics() -> interfaces.Metrics

Returns arbitrary task-specific metrics.

predict

predict(
    inputs: tfgnn.GraphTensor
) -> interfaces.Predictions

Apply a linear head for classification.

Args
inputs A tfgnn.GraphTensor for classification.
Returns
The classification logits.
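
Because predict returns logits rather than probabilities, downstream code usually has to convert them. The following is a minimal sketch, assuming the conventional element-wise sigmoid and a 0.5 decision threshold for both the binary (units=1) and multi-label cases. It uses plain Python lists in place of tensors, and the helper names are illustrative only.

```python
import math
from typing import List, Sequence


def sigmoid(logit: float) -> float:
    """Logistic function, written to avoid overflow for large |logit|."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    exp_logit = math.exp(logit)
    return exp_logit / (1.0 + exp_logit)


def logits_to_labels(
    logits: Sequence[Sequence[float]], threshold: float = 0.5
) -> List[List[int]]:
    """Maps each logit to 0 or 1 independently of the other units."""
    return [[int(sigmoid(x) >= threshold) for x in row] for row in logits]


# Shape [num_seed_nodes, units] with units=1 (binary classification).
binary_logits = [[2.0], [-1.5], [0.0]]
binary_labels = logits_to_labels(binary_logits)

# Shape [num_seed_nodes, units] with units=3 (multi-label classification).
multi_label_logits = [[3.0, -3.0, 0.5]]
multi_labels = logits_to_labels(multi_label_logits)
```

Each unit is thresholded on its own, which is what distinguishes multi-label classification from multi-class classification with a softmax.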

preprocess

preprocess(
    inputs: GraphTensor
) -> tuple[GraphTensor, Field]

Preprocesses a scalar GraphTensor (that is, one obtained after merge_batch_to_components).

This function uses the Keras functional API to define non-trainable transformations of the symbolic input GraphTensor, which get executed during dataset preprocessing in a tf.data.Dataset.map(...) operation. It has two responsibilities:

  1. Splitting the training label out of the input for training. It must be returned as a separate tensor or mapping of tensors.
  2. Optionally, transforming input features. Some advanced modeling techniques require running the same base GNN on multiple different transformations, so this function may return a single GraphTensor or a non-empty sequence of GraphTensors. The corresponding base GNN output for each GraphTensor is provided to the predict(...) method.
Args
inputs A symbolic Keras GraphTensor for processing.
Returns
A tuple of the processed GraphTensor(s) and the labels, given as a single Field or a mapping of Fields.
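
The first responsibility of preprocess, splitting the label out of the input, can be illustrated without TF-GNN. The sketch below uses a plain dictionary as a stand-in for the features of a GraphTensor. The function split_label and the feature names are hypothetical, and the code is not the library's implementation, which operates on symbolic GraphTensors.

```python
from typing import Any, Dict, Mapping, Tuple


def split_label(
    features: Mapping[str, Any], label_name: str
) -> Tuple[Dict[str, Any], Any]:
    """Returns (features without the label, the label), leaving the input as is."""
    if label_name not in features:
        raise KeyError(f"No feature named {label_name!r} to use as label.")
    remaining = {k: v for k, v in features.items() if k != label_name}
    return remaining, features[label_name]


# Stand-in for the features of one seed node; the values are made up.
seed_features = {"hidden_state": [0.1, 0.7], "label": 1}
model_inputs, label = split_label(seed_features, "label")
```

The returned pair mirrors the (inputs, labels) structure that preprocess hands to training: the model never sees the label among its input features.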