
Graph Neural Networks with PyG

Graph Machine Learning

Machine learning (ML) on a graph fundamentally differs from ML on two-dimensional data (such as images or structured tables) primarily due to the irregular structure and relational nature of graph data, which necessitates specialized models and data representations.

The shape of the data matters

FreeSurfer 8 is transitioning many of its core algorithms to convolutional neural nets (CNNs), but not necessarily to graph networks. New tools such as SynthMorph and SynthSeg are not based on graph machine learning methods; instead, they rely on CNN architectures widely adopted in medical image analysis.

This is not meant to diminish the power or utility of using Graph ML models, but only to point out that they need to be used deliberately. Just because data can be characterized as a graph does not mean the added complexity and assumptions of the model will benefit the analysis.

To further use the new FreeSurfer 8 tools as examples:

SynthMorph (Joint Registration)

SynthMorph is a deep learning (DL) tool developed for fast, symmetric, diffeomorphic, end-to-end affine and deformable brain registration.

  1. Overall Approach: SynthMorph uses DL methods that learn a function mapping an image pair to an output transform. The general framework employs convolutional neural networks (CNNs).

  2. Affine Component: The affine model hθ implemented in SynthMorph uses a modified Detector architecture. This architecture relies on a series of convolutions to predict spatial feature maps. These feature maps are used to calculate corresponding moving and fixed point clouds, and a weighted least-squares (WLS) solution then provides the affine transform. Alternative affine architectures considered in the related work include the Encoder, which uses a convolutional encoder combined with a fully connected (FC) layer, and Decomposer, a fully convolutional network. Some affine DL strategies utilize vision transformers instead of convolutional layers, but SynthMorph ultimately favors the Detector architecture.

  3. Deformable Component: The deformable module uses a U-Net architecture (a convolutional neural network) for predicting a stationary velocity field (SVF). The model also employs a hypernetwork Γξ, which is described as a simple feed-forward network with four ReLU-activated hidden FC layers, used to parameterize the weights of the deformable task network gη based on a regularization weight λ.
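To make the hypernetwork idea concrete, here is a hedged sketch in PyTorch of a feed-forward network with four ReLU-activated hidden FC layers that maps a regularization weight λ to a flattened weight vector for a downstream task network. This is an illustration of the general mechanism, not SynthMorph's actual implementation; all layer sizes are assumptions.

```python
import torch
import torch.nn as nn

class HyperNet(nn.Module):
    """Sketch of a hypernetwork: a scalar regularization weight lambda
    goes in, a (flattened) weight vector for a task network comes out.
    Hidden width and output size are illustrative assumptions."""

    def __init__(self, hidden=64, n_task_weights=1024):
        super().__init__()
        layers = []
        in_dim = 1  # input: the scalar regularization weight lambda
        for _ in range(4):  # four ReLU-activated hidden FC layers
            layers += [nn.Linear(in_dim, hidden), nn.ReLU()]
            in_dim = hidden
        layers.append(nn.Linear(in_dim, n_task_weights))
        self.net = nn.Sequential(*layers)

    def forward(self, lam):
        # lam: tensor of shape (batch, 1)
        return self.net(lam)

hyper = HyperNet()
weights = hyper(torch.tensor([[0.5]]))
print(weights.shape)  # torch.Size([1, 1024])
```

In the real system the predicted weights would parameterize the deformable task network gη; here the output is just a flat vector to keep the sketch short.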

SynthSeg (Segmentation)

SynthSeg is a convolutional neural network (CNN) designed for the segmentation of brain MRI scans of any contrast and resolution without requiring retraining or fine-tuning.

  1. Architecture: The core segmentation network used in SynthSeg (and the related variant SynthSeg+) is based on a 3D UNet architecture.

  2. Components: The UNet comprises five levels, with operations including two convolutions per level, batch normalization, max-pooling (contracting path), upsampling (expanding path), and skip connections. The modules used in the robust variant, SynthSeg+ (such as segmenters S1, S2, S3, and the denoiser D), are also implemented as CNNs.
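A single contracting-path level of such a 3D UNet could be sketched as follows in PyTorch. Channel counts and kernel sizes are assumptions for illustration, not SynthSeg's actual configuration.

```python
import torch
import torch.nn as nn

class UNetLevel3D(nn.Module):
    """One contracting-path level: two convolutions with batch
    normalization, then max-pooling. The pre-pool activation is
    returned as well, to be reused later via a skip connection."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(),
            nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(),
        )
        self.pool = nn.MaxPool3d(2)

    def forward(self, x):
        skip = self.block(x)         # kept for the skip connection
        return self.pool(skip), skip

level = UNetLevel3D(1, 16)
down, skip = level(torch.zeros(1, 1, 32, 32, 32))
print(down.shape)  # torch.Size([1, 16, 16, 16, 16])
```

Stacking five such levels, mirroring them with upsampling levels, and concatenating the saved skip tensors yields the overall UNet shape described above.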

The descriptions of both SynthMorph and SynthSeg emphasize architectures based on standard convolutional networks (U-Nets, encoders, detectors) for image processing, not methods explicitly utilizing graph structures or Graph Neural Networks (GNNs).

Here is a breakdown of the differences based on data characteristics and the resulting ML pipeline requirements:

1. Data Structure and Relationship

The core distinction lies in the organization of the input data:

2. Traditional ML Feature Engineering

In traditional ML pipelines for graphs, significant effort must be placed on creating features that capture the topological context, a step often less critical or handled differently for standard tabular or image data:

3. Deep Learning Approaches (GNNs vs. CNNs)

Graph Neural Networks (GNNs) are designed to handle the irregular structure of graphs and overcome the need for non-trivial human feature engineering.

Wait, what’s a graph?

The basic features of a graph, how data is mapped to them, and the resulting advantages stem from the graph’s ability to natively represent entities and their complex, irregular relationships.

1. Basic Features of a Graph Structure

A graph is a general language for describing entities, their relations, and interactions. When storing graph data for machine learning, the structure requires specific components:

Core Data Components (PyTorch Geometric Context):

  1. Node Features (X or data.x): A matrix containing the features or attributes of every node. These features must generally be of a numerical type (float, double, or integer).

  2. Graph Connectivity (edge_index): This special format defines the connections (edges) between nodes. It is typically a matrix with two rows, where the first row lists the source node of each edge and the second row lists the destination node.

  3. Edge Features (edge_attr): Optional features that describe the edges themselves, such as weights or multiple attributes.

  4. Target (Y or data.y): The target values used for training, which can be flexible depending on whether the task involves classification for every node, every edge, or the entire graph.

Structural Features (Derived Properties): In traditional machine learning on graphs, good predictive performance relies on engineering structural features that describe the network topology. These features capture how a node is positioned and what its local structure is:
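Classic hand-engineered structural features of this kind, such as node degree and the clustering coefficient, can be computed with networkx; the toy graph below is an assumption for illustration.

```python
import networkx as nx

# Toy graph: a triangle (0-1-2) plus one pendant node (3)
G = nx.Graph([(0, 1), (1, 2), (0, 2), (2, 3)])

# Node degree: how many neighbors each node has
degrees = dict(G.degree())

# Clustering coefficient: how tightly a node's neighborhood is connected
# (nodes 0 and 1 sit inside a closed triangle, so their value is 1.0)
clustering = nx.clustering(G)

# A simple per-node structural feature vector
features = {n: [degrees[n], clustering[n]] for n in G.nodes}
print(features[0])  # [2, 1.0]
```

In a traditional pipeline these vectors would then be fed to a standard classifier; a GNN instead learns comparable structural signals directly from the graph.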

2. How Tabular Data Can Be Mapped to Graph Properties

Tabular or relational data can be mapped to graph features by identifying entities as nodes and relationships as edges.

Example: Social Network Data Mapping: If tabular data describes people in a social network, the mapping could be structured as follows:

| Tabular Data Component | Graph Property | Example Content |
| --- | --- | --- |
| Individual entity (person) | Node | Person 1, Person 2, Person 3, Person 4 |
| Individual attributes | Node Features (X) | Age (e.g., 42 years) and income level (e.g., 1200 currency units) |
| Relationship/interaction | Edge (edge_index) | An edge exists between two people if they know each other |
| Relationship strength | Edge Attributes | The edge weight could be the number of years the two people have known each other |
| Prediction target | Target Label (Y) | The hours the person currently works |

Example: Relational Databases (RDL): More complex structured data, like an entire SQL database, can be treated as a graph for Relational Deep Learning (RDL). In this scenario:

3. What is Gained by Performing this Mapping

The primary gain of mapping data onto a graph structure is the ability to incorporate and leverage the relational structure inherent in the data, moving beyond traditional models that assume data regularity or independence.

Key Advantages:

  1. Capturing Relational Structure: The central challenge and benefit of ML on graphs is explicitly incorporating the information about the graph structure (the relational structure) into the model. This relational structure is crucial for obtaining good predictive performance.

  2. Addressing Irregularity: Graphs serve as a powerful representation because they can model interactions and relations that do not fit into the regular structures assumed by traditional deep learning (like the grid structures used for images by CNNs).

  3. Advanced Representation Learning: Graph Neural Networks (GNNs) utilize the graph structure to learn effective representations for nodes, links, or entire graphs. This process, often involving message passing (aggregating neighbor features), overcomes the need for complex, hand-designed features that were required in older, traditional ML pipelines on graphs.

  4. Enabling Relational Deep Learning (RDL): Mapping relational data to a graph allows for modern deep learning models (GNNs) to learn directly on raw relational databases, enabling breakthroughs in practical domains.

  5. Augmenting Language Models (Graph RAG): Graphs can be used in Retrieval Augmented Generation (RAG) pipelines for Large Language Models (LLMs) by providing crucial relational and topological information. Using GNNs to encode subgraphs and feed this structural context to an LLM can significantly enhance accuracy (in one scenario, leading to a 2x increase in accuracy compared to the LLM alone).

  6. Predicting Links and Roles: Graph features are essential for specific relational tasks, such as predicting new links between nodes (link prediction) or predicting a particular node’s role in the network (e.g., predicting protein function) using structure-based features like Graphlet Degree Vectors.
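The message-passing idea in point 3 can be sketched without any GNN library: one round of mean aggregation, where each node replaces its feature with the average of its neighbors' features. The graph and features below are illustrative.

```python
import numpy as np

# Adjacency matrix of a 4-node undirected path graph 0-1-2-3 (illustrative)
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)

# One scalar feature per node
X = np.array([[1.0], [2.0], [3.0], [4.0]])

# One round of message passing: each node averages its neighbors' features
deg = A.sum(axis=1, keepdims=True)   # number of neighbors per node
H = (A @ X) / deg
print(H.ravel())  # [2. 2. 3. 3.]
```

Real GNN layers add learnable weight matrices and nonlinearities around this aggregation, and stacking layers lets information propagate over multi-hop neighborhoods.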

Mapping Brain Data With a Graph Network

Structural MRI

Modeling structural brain data, such as a 3D volumetric MRI, with a graph learning approach involves transforming the regular, spatial data structure into an irregular, relational structure that captures the entities (brain regions) and their interactions (connectivity).

Graphs provide a general language for describing entities, relations, and interactions. Since the fundamental challenge in machine learning (ML) on graphs is finding the right way to incorporate the relational structure, mapping brain data to a graph explicitly enables the ML model to utilize these connectivity patterns.

Here is how structural brain data could be mapped onto the basic features of a graph:

1. Defining Nodes (Entities)

The entities in a structural brain graph would be the different functional or anatomical regions of the brain.

2. Defining Node Features (X)

Node features are numerical attributes attached to each entity. For structural MRI data, these features would be the metrics describing the structural properties of the defined brain region.

3. Defining Edges and Connectivity (edge_index)

The edges describe the relationships or connections between the brain regions. In neuroscience, these often represent structural connections.

4. Defining the Prediction Target (Y)

The target (Y or data.y) depends on the specific machine learning task.
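Putting the four steps together, here is a hedged numpy sketch; the region names, morphometric values, and connectivity matrix are all hypothetical placeholders, not real data.

```python
import numpy as np

# Hypothetical parcellation: each brain region becomes a node
regions = ["hippocampus_L", "hippocampus_R", "precuneus", "thalamus"]

# Node features X: e.g. [volume, mean thickness] per region (made-up)
X = np.array([[4.1, 2.5],
              [4.0, 2.6],
              [9.8, 2.3],
              [7.2, 2.1]])

# Hypothetical structural connectivity (e.g. tract counts), symmetric
C = np.array([[0, 8, 1, 5],
              [8, 0, 1, 6],
              [1, 1, 0, 4],
              [5, 6, 4, 0]], dtype=float)

# Keep connections above a threshold -> edge_index in source/dest form
src, dst = np.nonzero(C > 3)
edge_index = np.stack([src, dst])

# Target y: e.g. a subject-level diagnostic label (a graph-level task)
y = 1

print(edge_index.shape)  # (2, 8)
```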

Gains of the Graph Mapping Approach

While 3D volumetric MRI data inherently has a regular structure similar to 2D image data (which is typically handled by Convolutional Neural Networks, or CNNs), modeling it as a graph provides critical advantages for biomedical analysis:

  1. Capturing Relational Structure: The most significant gain is explicitly incorporating the information about the graph structure—the complex relational structure of brain connectivity—into the model. Traditional models designed for regular grids struggle to natively encode these non-local, sparse, and intricate topological relationships.

  2. Specialized Deep Learning: The graph representation enables the use of specialized models like Graph Neural Networks (GNNs), which use message passing (aggregating features from connected neighbors) to learn representations that naturally capture both the attributes of a region and its relational context within the network.

  3. Handling Complexity: The approach aligns with applications in other complex scientific domains, such as mapping out complex molecular structures and protein-protein interaction networks in biomedicine. The structure of the brain network itself is the key to obtaining good predictive performance in these relational learning tasks.

Working with functional data (fMRI, EEG/MEG, etc.) as a graph

fMRI and EEG recordings are essentially time series (temporal) data, and can be modeled using specialized temporal graph networks and their supporting frameworks.

Graph networks serve as a general language for analyzing entities and the relationships and interactions between them. When dealing with data that changes over time, like time series, the structure can be captured using temporal graphs.

Here is how time-varying data can be conceptualized and managed within a graph network framework:

1. Modeling Dynamic Relationships

Modeling time series data requires capturing the evolving structure of the relationships, which can be done through temporal graphs (graphs that change over time).

2. Handling Temporal Information and Sampling

To process a time series dataset modeled as a temporal graph, techniques focusing on preserving the time order are crucial:
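One simple way to preserve the time order is a chronological split of timestamped edges: train on the past, evaluate on the future, never the reverse. The events below are made up for illustration.

```python
# Timestamped edges (source, destination, time) of a toy temporal graph
events = [(0, 1, 1.0), (1, 2, 2.0), (0, 2, 3.0), (2, 3, 4.0), (1, 3, 5.0)]

# Sort by time, then split chronologically so no future edge leaks
# into the training portion
events.sort(key=lambda e: e[2])
cut = int(0.6 * len(events))
train, test = events[:cut], events[cut:]
print(len(train), len(test))  # 3 2
```

A random split would mix past and future events, letting the model "see" information it could not have at prediction time.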

3. Graph Structure and Feature Assignment

When designing the graph network for time series data, the entities (like sensors or measurement points) and their interactions would be defined:

Updated context of Graph Signal Processing

Time series data from neuroimaging modalities, such as fMRI (functional Magnetic Resonance Imaging) BOLD or EEG (Electroencephalography), can be effectively modeled as a graph structure to leverage sophisticated machine learning algorithms like Graph Neural Networks (GNNs) and Graph Transformers.

This modeling approach is particularly valuable because the complex data structure inherent in these imaging modalities aligns well with the networked organizational structure of the human brain.

The general strategy involves three key steps: defining the nodes, determining the node features (signals), and inferring the edges (connectivity).

1. Defining Nodes and Node Features

In the context of brain imaging, the nodes in the graph typically represent regions or entities of interest.

2. Inferring Edges and Connectivity

Edges represent the relationships, interactions, or connectivity patterns between the brain regions (nodes). The method for inferring these connections transforms the raw time series data into a relational structure:
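A common way to infer such edges is to correlate the regional time series and threshold the resulting functional-connectivity matrix. The sketch below uses synthetic signals in which two regions share a common driver; the threshold of 0.3 is an arbitrary assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
n_regions, n_timepoints = 5, 200

# Synthetic regional time series; regions 0 and 1 share a common driver
ts = rng.standard_normal((n_regions, n_timepoints))
driver = rng.standard_normal(n_timepoints)
ts[0] += driver
ts[1] += driver

# Functional connectivity: pairwise Pearson correlation between regions
fc = np.corrcoef(ts)

# Zero the diagonal and threshold to obtain edges (edge_index format)
np.fill_diagonal(fc, 0.0)
src, dst = np.nonzero(np.abs(fc) > 0.3)
edge_index = np.stack([src, dst])
print(fc[0, 1] > 0.3)  # the driven pair ends up connected
```

Other edge definitions (partial correlation, coherence, mutual information) drop into the same place: compute a pairwise association matrix, then threshold or sparsify it.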

3. Application in Machine Learning

Once the time series data is structured as a graph, it can be utilized by specialized machine learning models:

Graph learning is also incorporated into multimodal modeling frameworks, allowing the integration of networks derived from different brain imaging modalities (like functional and structural networks).

Applied description

Graph machine learning methods are highly effective tools for observing and quantifying changes in functional networks, including transitions between clinical or physiological states. This capacity stems from the fundamental alignment between graph structures and the networked organizational structure of the human brain. Graph learning approaches are specifically designed to extract important information from graphs and model the interactions of multiple brain regions.

1. Identifying Alterations Associated with Clinical Disorders

Graph machine learning (GML) frameworks, particularly those utilizing Graph Neural Networks (GNNs), have been applied directly to brain imaging data to classify disorders and investigate corresponding network alterations:

2. Observing Pharmacologically Induced State Transitions

Network analysis, which provides the foundation for GML approaches by defining graph structure and features, can reveal distinct transitions in brain activity associated with altered states of consciousness, such as those induced by anesthetics: