Physics Simulation With Graph Neural Networks Targeting Mobile

Predict dynamic behaviors in a physics system in a way that’s computationally efficient and adaptable to a range of scenarios.


By Máté Stodulka and Tomas Zilhao Borges

The demand for immersive, realistic graphics in mobile gaming and AR or VR is pushing the limits of mobile hardware. Achieving lifelike simulations of fluids, cloth, and other materials has historically required intensive mathematical computation. While these traditional methods yield highly accurate results, they have been too resource-heavy to run in real time on mobile. But as mobile hardware advances, Machine Learning (ML) techniques, particularly Graph Neural Networks (GNNs), are emerging as a powerful, efficient alternative for emulating physics on mobile.

GNNs are particularly suited for scenarios where real-world situations can be represented as interactions between related objects. So, by representing each particle as a node and the forces between them as edges, GNNs can be used to predict dynamic behaviors in a physics system. This enables GNNs to approximate traditional methods in a way that’s computationally efficient and adaptable to a range of scenarios, making them promising for resource-constrained mobile devices. The recent launch of TensorFlow GNN offers a streamlined way to design, build and deploy GNNs, providing “ready to wear” architectures and essential tools to define nodes, edges, and interactions.

To assess the feasibility of these simulations on mobile, we evaluate the performance of GNN-based models on today’s state-of-the-art (SOTA) hardware.

With these goals in mind, we focus on two main objectives:

  • Perform a study of the GNN architecture and the new TF-GNN API
  • Determine whether GNNs are a viable approach for implementing physics simulations

Introduction to GNNs

Fig. 1: A graphic representation of a GNN.

GNNs excel at representing data as networks of objects and their interactions. This ability makes GNNs particularly well-suited for applications where data is naturally structured as interconnected entities, for example, social networks, recommendation systems, and physics simulations.

GNNs extend the foundational ideas of Convolutional Neural Networks (CNNs) to graph data. While CNNs capture spatial locality in grid-like data (for example, images) through convolutional kernels, GNNs capture structural locality in graph data, allowing for flexible connections represented by sparse adjacency matrices. And like the sliding of convolutional kernels in CNNs, GNNs share weights across edges, most commonly through message passing, efficiently capturing patterns across a graph.

At their core, GNNs consist of:

  • Nodes representing individual entities, for example particles in a simulation or people in a social network
  • Edges representing relationships or interactions between nodes, for example forces in a physics system or friendships in a network
  • Global Context representing graph-level information, for example gravity in a physics simulation or overarching characteristics in a dataset

Graphs can be built from both static properties (fixed attributes, such as a particle's material type) and dynamic properties (time-varying quantities, such as position or velocity).

The development of GNNs has introduced several key milestones in their expressive power:

  • Node-level Predictions: Early GNNs used MLPs to predict node attributes independently
  • Neighbor Pooling: Graph connectivity allowed pooling of neighboring nodes’ information
  • Graph Convolutions: Convolution-like operations allowed better information aggregation
  • Message Passing: Iterative exchanges of information between nodes improved learning
  • Edge Representations: Added context by learning edge features alongside node features
  • Global Context: Introduced to capture broad graph information
  • Attention Mechanisms: Allowed GNNs to focus on the most relevant parts of a graph

Message passing is the core of GNNs, allowing nodes to share information. It generally involves:

  • Message Function (M): Computes messages between nodes.
  • Aggregation Function (A): Gathers messages for each node (for example, sum, mean).
  • Update Function (U): Updates nodes or edges based on messages, often with non-linear transformations.

Each message-passing step covers immediate neighbors, and additional layers capture a wider network context, enabling a comprehensive understanding of graph structures.
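
To make these three functions concrete, here is a minimal sketch of one message-passing step written with plain TensorFlow ops. The tensor names and the two MLPs are our own illustrative choices, not part of any particular API:

```python
import tensorflow as tf

def message_passing_step(node_states, edge_sources, edge_targets,
                         message_mlp, update_mlp):
    """One message-passing step: M computes messages, A aggregates, U updates.

    node_states:  [num_nodes, state_dim] float tensor.
    edge_sources: [num_edges] int tensor of sender node indices.
    edge_targets: [num_edges] int tensor of receiver node indices.
    """
    # M: build a message per edge from the states of its two endpoints.
    sender = tf.gather(node_states, edge_sources)
    receiver = tf.gather(node_states, edge_targets)
    messages = message_mlp(tf.concat([sender, receiver], axis=-1))

    # A: aggregate incoming messages per receiver node (sum; mean also works).
    num_nodes = tf.shape(node_states)[0]
    aggregated = tf.math.unsorted_segment_sum(messages, edge_targets, num_nodes)

    # U: update each node from its old state and its aggregated messages.
    return update_mlp(tf.concat([node_states, aggregated], axis=-1))
```

In practice, message_mlp and update_mlp would be small Keras models, for example tf.keras.Sequential stacks of Dense layers.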

TF-GNN overview

Fig. 2: Layers of TF-GNN.

TF-GNN was first published in 2022, with a major update in early 2024. TensorFlow (TF) itself is a mature framework, supported by a wide range of platforms and tools. TF-GNN integrates tightly into base TF, with most added data structures built from native TF operators, enabling compatibility with the popular Keras API as well as easy conversion to LiteRT (formerly TensorFlow Lite), Google's mobile-friendly format for TensorFlow models.

The TF-GNN library is made up of multiple API levels, as shown in the graphic, enabling multiple levels of fine-tuning with increasing complexity.

  • Data level – creating your own graphs: The foundational layer of the TF-GNN API is the Data level, where users create and define the graph structures themselves (see the sketch after this list)
  • Data exchange level – creating custom MP layers: The message-passing layer of TF-GNN controls data exchange across node and edge sets, allowing users to define custom data flow between entities in the graph
  • Keras layers: These layers provide flexibility and control, while also maintaining compatibility with other TensorFlow and Keras tools. For more details on available layers, check out the TF-GNN layer documentation
  • Orchestrator – high-level API: At the highest level, the TF-GNN Orchestrator (or TF-GNN Runner) enables users to experiment with complete models and datasets efficiently
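
As a taste of the data level, the snippet below sketches a tiny GraphTensor for a particle system. The node-set, edge-set, and feature names are illustrative assumptions for this post, not prescribed by the library:

```python
import tensorflow as tf
import tensorflow_gnn as tfgnn

# A minimal sketch: three particles, two proximity edges, one global feature.
graph = tfgnn.GraphTensor.from_pieces(
    node_sets={
        "particles": tfgnn.NodeSet.from_fields(
            sizes=tf.constant([3]),
            features={"velocity": tf.random.normal([3, 2])},
        )
    },
    edge_sets={
        "proximity": tfgnn.EdgeSet.from_fields(
            sizes=tf.constant([2]),
            adjacency=tfgnn.Adjacency.from_indices(
                source=("particles", tf.constant([0, 1])),
                target=("particles", tf.constant([1, 2])),
            ),
            features={"distance": tf.constant([[0.5], [0.8]])},
        )
    },
    context=tfgnn.Context.from_fields(
        features={"gravity": tf.constant([[0.0, -9.8]])}
    ),
)
```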

Physics simulation

Physics simulations have traditionally relied on solving the Navier-Stokes partial differential equations (PDEs), offering high accuracy but demanding heavy computational power, which limits real-time applications on mobile devices. As an alternative, Machine Learning (ML) has recently emerged as a faster, less compute-heavy, and adaptable solution for physics simulation.

Historically, physics simulation approaches are broadly divided into:

  1. Non-ML: Highly accurate but too slow for real-time mobile use, as these methods explicitly solve the physics equations
  2. Physics-driven ML: ML approximates physical laws, offering differentiable models but with high implementation complexity
  3. Data-driven ML: Uses ML to learn from data, achieving faster results for real-time applications but can suffer from overfitting, data quality issues and sub-optimal accuracy in edge cases
  4. Hybrid Approach: Combines ML and Non-ML methods, balancing speed and accuracy but is challenging to coordinate effectively

Orthogonally, methods can be classified by frame of reference as either Eulerian (grid-based) or Lagrangian (particle-based), which further increases the diversity of implementations.

DeepMind: Learning to simulate

DeepMind's paper Learning to Simulate presents an innovative approach to simulating complex physics scenarios using a GNN architecture. This approach successfully models various physical environments, for example fluids, rigid bodies, and deformable materials, with results that generalize well across new configurations and types of materials, in both 2D and 3D. However, the paper does not provide performance numbers.

The paper introduces novel ideas at both the data-modeling and GNN-architecture levels. More specifically, its framework is structured around an Encoder-Processor-Decoder architecture, with each stage tailored to handle particle interactions and simulate physical behavior over time.

“Learning to Simulate” also presents a novel set of datasets for various materials, for example water, sand, and viscous substances. Each material's behavior is simulated using accurate, traditional methods. Each dataset contains thousands of examples, providing both short-term and long-term interaction data, typically with around 200-500 timesteps and between 1,000 and 2,000 particles.

Adoption

We appreciate DeepMind’s contributions to the field and have adopted their theoretical approach to using GNNs for physics simulations as well as their provided datasets.

While DeepMind's original implementation uses the older TensorFlow 1 framework, which lacks compatibility with recent libraries, we adapt their architecture to TensorFlow 2, exploring the newly released TF-GNN (TensorFlow Graph Neural Networks) library.

To manage the computational costs associated with creating GraphTensors, we pre-process all data, saving intermediate GraphTensor states in TFRecord format. These pre-processed TFRecords then serve as the datasets for our implementation.
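
A sketch of that pre-processing step is shown below. The build_graph_tensor helper is hypothetical shorthand for our encoding logic; tfgnn.write_example and tfgnn.parse_single_example are the library's standard serialization entry points:

```python
import tensorflow as tf
import tensorflow_gnn as tfgnn

# Hypothetical helper that yields one GraphTensor per simulation timestep:
# graph_tensors = (build_graph_tensor(frame) for frame in trajectory)

def write_graphs(graph_tensors, path):
    """Serialize pre-built GraphTensors into a TFRecord file."""
    with tf.io.TFRecordWriter(path) as writer:
        for graph in graph_tensors:
            # tfgnn.write_example converts a GraphTensor to a tf.train.Example.
            writer.write(tfgnn.write_example(graph).SerializeToString())

def load_graphs(path, graph_spec):
    """Parse the TFRecords back, given the matching GraphTensorSpec."""
    dataset = tf.data.TFRecordDataset([path])
    return dataset.map(
        lambda serialized: tfgnn.parse_single_example(graph_spec, serialized))
```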

Model architecture and implementation details

Our implementation largely follows DeepMind’s architecture, with an Encoder-Processor-Decoder separation. The Encoder creates the graph structure from raw data, the Processor (or Core GNN) implements Message Passing, and the Decoder extracts information from the graph. We start with position windows and end up predicting normalized acceleration for each particle.

Fig. 3: Node features.

Encoder

The Encoder processes particle positions into relevant features in a GraphTensor structure and embeds them into the tfgnn.HIDDEN_STATE vector.

Windowing

We start with five previous positions for each particle and their current absolute coordinates. This provides historical context and aids in generalization by focusing on relative dynamics instead of absolute positions.

Feature and edge generation

The Encoder derives specific features:

  • Velocity history: Normalized over time to remove absolute positions, allowing the model to learn particle dynamics directly and generally.
  • Distance to bounds: A normalized metric tracking each particle's distance from the simulation boundaries, clipped so that it saturates at a constant value of one unless the particle is near a bound.
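
Below is a minimal sketch of how both features can be derived from the position window. The window length, boundary representation, and normalization constants are illustrative assumptions:

```python
import tensorflow as tf

def derive_node_features(positions, bounds_min, bounds_max, radius):
    """positions: [num_particles, window, dim] position history (oldest first)."""
    # Velocity history: finite differences of the position window, so the
    # model sees relative dynamics rather than absolute coordinates.
    velocities = positions[:, 1:, :] - positions[:, :-1, :]

    # Distance to bounds: per-axis distance to each wall, normalized by the
    # connectivity radius and clipped so it saturates away from the walls.
    current = positions[:, -1, :]
    dist = tf.concat([current - bounds_min, bounds_max - current], axis=-1)
    dist_to_bounds = tf.clip_by_value(dist / radius, -1.0, 1.0)
    return velocities, dist_to_bounds
```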

Edges are dynamically generated based on particle proximity (within a specified radius), balancing model complexity with computational efficiency. We replaced KD-Trees with a TF-native approach, achieving significant speed improvements on accelerated hardware while maintaining information fidelity; a sketch of this approach follows below. Edge features include:

  • Scalar normalized distance: Between connected nodes
  • Normalized displacement vector: Relative position between nodes

These relative features are the only position-based data available to the model; absolute positions are excluded from message-passing to support better generalization. Global context like gravity is scattered to each node as a feature, simplifying Message Passing.
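
The sketch referenced above is a brute-force TF-native neighbor search. Our production version may batch and pad differently, so treat this as an illustration of the idea:

```python
import tensorflow as tf

def radius_edges(positions, radius):
    """Connect every pair of particles closer than `radius` (brute force).

    positions: [num_particles, dim]. Returns (sources, targets, features).
    """
    # Pairwise displacement and distance matrices, computed with dense TF ops
    # so the whole step runs on accelerated hardware (no host-side KD-tree).
    disp = positions[:, None, :] - positions[None, :, :]   # [N, N, dim]
    dist = tf.norm(disp, axis=-1)                          # [N, N]

    # Keep pairs within the connectivity radius, excluding self-edges.
    mask = (dist < radius) & ~tf.eye(tf.shape(positions)[0], dtype=tf.bool)
    pairs = tf.where(mask)                                 # [num_edges, 2]
    sources, targets = pairs[:, 0], pairs[:, 1]

    # Edge features: normalized displacement vector and scalar distance.
    edge_disp = tf.gather_nd(disp, pairs) / radius
    edge_dist = tf.gather_nd(dist, pairs)[:, None] / radius
    return sources, targets, tf.concat([edge_disp, edge_dist], axis=-1)
```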

Graph building

Nodes, edges, and global context are combined into GraphTensors, and node features are embedded into a 128-dimensional tfgnn.HIDDEN_STATE vector, preparing them for message passing.
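
A sketch of that embedding step using tfgnn.keras.layers.MapFeatures is shown below; the raw feature names are the illustrative ones used earlier:

```python
import tensorflow as tf
import tensorflow_gnn as tfgnn

def node_sets_fn(node_set, *, node_set_name):
    # Concatenate the raw per-particle features and embed them into the
    # 128-dimensional hidden state used by message passing.
    features = tf.concat(
        [node_set["velocity_history"], node_set["distance_to_bounds"],
         node_set["gravity"]], axis=-1)
    return {tfgnn.HIDDEN_STATE: tf.keras.layers.Dense(128)(features)}

embed = tfgnn.keras.layers.MapFeatures(node_sets_fn=node_sets_fn)
# graph = embed(graph)  # GraphTensor in, GraphTensor with HIDDEN_STATE out.
```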

Processor (Core GNN)

This is the core GNN model, made up of multiple Message Passing (MP) layers. MP layers take the state of connected nodes (and edges, if edge features are present), pass them through a basic Multi-Layer Perceptron (MLP), and use the result to update the node state. Multiple MP steps are necessary because the default connectivity radius for particles is low (edges are expensive), so several rounds of message passing are required to propagate information to particles further away. We achieved good results with 8 steps. The MP layers are highly configurable through various hyperparameters. Our choices mostly match the paper or are smaller, given that the model we are aiming for does not have to be as widely generalizable and should be runnable on mobile. We use TF-GNN's library-provided MtAlbis GraphUpdate layer, as it is highly customizable. We set up ours with 128 hidden units and residual connections, but without attention.
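
The sketch below shows what the resulting processor stack can look like. The MtAlbisGraphUpdate arguments mirror the values described above; other arguments are left at their defaults and may differ from our exact configuration:

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import mt_albis

def processor(graph):
    # Stack of 8 message-passing steps over the particle graph.
    for _ in range(8):
        graph = mt_albis.MtAlbisGraphUpdate(
            units=128,                   # hidden state size
            message_dim=128,             # message MLP width
            receiver_tag=tfgnn.TARGET,   # messages pool to edge targets
            attention_type="none",       # no attention, as in our setup
            next_state_type="residual",  # residual connections
        )(graph)
    return graph
```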

Decoder

The Decoder mirrors the Encoder, translating the node state (tfgnn.HIDDEN_STATE) back to normalized acceleration values.

Readout

The Readout layer takes each particle’s tfgnn.HIDDEN_STATE feature, using a 2-layer MLP with ReLU activation to predict normalized acceleration values for each particle.
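
A minimal sketch of the Readout, assuming the node set is called "particles" and the simulation is 2D:

```python
import tensorflow as tf
import tensorflow_gnn as tfgnn

def decode(graph):
    # Pull each particle's final hidden state out of the graph, then decode
    # it to a normalized acceleration with a 2-layer MLP.
    hidden = tfgnn.keras.layers.Readout(
        node_set_name="particles", feature_name=tfgnn.HIDDEN_STATE)(graph)
    hidden = tf.keras.layers.Dense(128, activation="relu")(hidden)
    return tf.keras.layers.Dense(2)(hidden)  # normalized 2-D acceleration
```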

Postprocessor

The Postprocessor further processes the output of the Readout, rescaling the normalized acceleration predictions back to their physical range for the regression target.

Training

The model’s training approach aims to simulate particle dynamics accurately over extended time steps by combining both stepwise and rollout evaluation modes (explained below). Stepwise mode is used for the training process while rollout mode is ultimately used for deciding which model is best.

In addition, to assess the model’s real-world accuracy, visual evaluations are conducted using side-by-side matplotlib animations comparing predictions to ground truth.
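
A sketch of such a side-by-side animation, assuming prediction and ground-truth trajectories as arrays of per-frame particle positions:

```python
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

def animate(pred, truth, interval_ms=30):
    """pred/truth: [num_steps, num_particles, 2] position arrays."""
    fig, (ax_p, ax_t) = plt.subplots(1, 2, figsize=(8, 4))
    ax_p.set_title("Prediction")
    ax_t.set_title("Ground truth")
    sc_p = ax_p.scatter(pred[0, :, 0], pred[0, :, 1], s=2)
    sc_t = ax_t.scatter(truth[0, :, 0], truth[0, :, 1], s=2)

    def update(i):
        sc_p.set_offsets(pred[i])   # [num_particles, 2] for frame i
        sc_t.set_offsets(truth[i])
        return sc_p, sc_t

    return FuncAnimation(fig, update, frames=len(pred), interval=interval_ms)
```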

Stepwise Mode

Stepwise mode uses independent inputs to predict the next position frame-by-frame, allowing for faster training due to easier parallelization and trivial backpropagation.

Fig. 4: The Stepwise model.

Rollout Mode

Rollout mode, however, simulates real-world scenarios by feeding predictions as inputs for subsequent frames, causing accumulated error and revealing issues like boundary misinterpretation and oscillation near equilibrium (particle at rest).

Fig. 5: The Rollout model.
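
The following sketch shows the essence of rollout mode: each predicted acceleration is integrated and fed back into the input window. The semi-implicit Euler update and the helper shapes are assumptions for illustration:

```python
import tensorflow as tf

def rollout(model, initial_window, num_steps, dt=1.0):
    """Feed predictions back as inputs; errors accumulate across steps.

    initial_window: [num_particles, window, dim] position history.
    `model` wraps encode -> process -> decode and returns accelerations
    (already de-normalized).
    """
    window = initial_window
    frames = []
    for _ in range(num_steps):
        acceleration = model(window)
        # Semi-implicit Euler update from the last two positions.
        velocity = (window[:, -1] - window[:, -2]) + acceleration * dt
        position = window[:, -1] + velocity * dt
        frames.append(position)
        # Slide the window: drop the oldest frame, append the prediction.
        window = tf.concat([window[:, 1:], position[:, None]], axis=1)
    return tf.stack(frames, axis=1)  # [num_particles, num_steps, dim]
```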

Rollout-MSE metric and training supervision

A rollout-MSE metric, implemented as a callback, generates rollout trajectories during training to better capture cumulative error, although this is slower due to the strictly serial nature of rollout evaluation. Training itself remains supervised one frame at a time; the rollout metric is what exposes the error that accumulates over long rollouts.

Training strategies

  1. Noise Injection: Small amounts of random noise are added to the training inputs, so the model learns to correct the kind of drift it will encounter when its own predictions are fed back during rollout.
  2. Loss and optimization: The L2 loss function (squared error) is used to minimize prediction error in accelerations, with the Adam optimizer and learning rate decay enhancing stable convergence.
  3. Normalization and batch size: Inputs are standardized to zero mean and unit variance for training stability.
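
A condensed sketch of how these strategies combine in a single training step is shown below; the noise scale and decay schedule are illustrative values, not our tuned settings:

```python
import tensorflow as tf

lr = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-4, decay_steps=int(5e6), decay_rate=0.1)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

@tf.function
def train_step(model, inputs, target_acceleration, noise_std=3e-4):
    # Noise injection: perturb inputs so the model learns to correct the
    # small drift it will cause itself during rollout.
    noisy = inputs + tf.random.normal(tf.shape(inputs), stddev=noise_std)
    with tf.GradientTape() as tape:
        predicted = model(noisy, training=True)
        # L2 loss on (normalized) accelerations.
        loss = tf.reduce_mean(tf.square(predicted - target_acceleration))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
```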

Hyperparameter tuning

The hyperparameter settings primarily follow recommendations from the DeepMind paper. Beyond the inherent reduction in complexity from using single-material models, we made further efforts to simplify the model, exploring alternative hyperparameter options with Keras Tuner. These simplifications are expected to enhance mobile performance.

Training stats

Models were trained on an NVIDIA RTX 6000 Ada GPU with 48GB of GPU memory.

The training time depends on the hyperparameters. For example, training our model with the best tuned hyperparameters (8 message-passing layers with an embedding size of 128, around 700K parameters in total) took around 10.5 minutes per epoch. We trained for 50 epochs with early stopping, for a total of around 9 hours of training time.

Results

The GIF below is the result of passing one set of positions to the model and running for around 300 frames in rollout mode. While the simulation accumulates some error over time, the motion still reads as plausible liquid behavior.

Fig. 6: Results.

Analysis

Model limitations

The time-step is learned implicitly and depends on consistent FPS, requiring separate models for different timings. Additional mechanisms and/or data would be needed to mitigate this.

Gravity is not learned directly, as the dataset only includes negative-Y gravity, so it is set as a global input parameter.

Also, while the model theoretically supports multi-material scenarios, it has only been tested with homogeneous materials and requires further validation.

Performance on Android

Results for Pixel 9’s CPU-only inference:

Num particles | Inference time per step | Avg. time per rollout of 25 steps | Avg. time per rollout of 300 steps
32            | ~1 ms                   | ~100 ms                           | ~1 s
800           | ~25 ms                  | ~0.7 s                            | ~7 s

Despite the very modest results, there is significant room for performance improvement. The current readings use CPU-only inference with the device's default delegate (XNNPACK). In addition, the model has not yet been quantized or had its architecture optimized.
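
For reference, a conversion for on-device inference with basic post-training optimization could look like the sketch below; the paths are illustrative, and the SELECT_TF_OPS fallback is often needed for graph-style ops:

```python
import tensorflow as tf

# Convert the trained SavedModel for on-device inference; dynamic-range
# quantization is one obvious next optimization step.
converter = tf.lite.TFLiteConverter.from_saved_model("gnn_simulator/")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# GNN ops like dynamic gathers may need the TF ops fallback in the converter.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()
with open("gnn_simulator.tflite", "wb") as f:
    f.write(tflite_model)
```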

On the other hand, we have found some evidence of limitations for GPU delegate usage. Problems have been reported with GATHER, RESHAPE, and SLICE, as the TFLite GPU delegate supports only a subset of TFLite and TF ops.

Conclusions

In this blog post we introduced GNNs and showed how to build GNN models using different features of the TF-GNN API. We also shared the results of implementing a model from scratch following DeepMind's theoretical basis. Based on our results, we recommend Graph Neural Networks for physics simulation workloads.

The TF-GNN API has a steep initial learning curve at the data layer, especially if no data is already available in the GraphSchema format. For the layers above the data level, the TF-GNN API is easy to use, with close integration into Keras.

Graph Neural Networks are on the path to becoming more mainstream, with exciting opportunities for the maturation of the operations essential to message passing, such as scatters, gathers, segmented operations, and ragged and dynamic tensors.

Explore more about on-device inference with Real-time low light video enhancement using Neural Networks on mobile.

Tomas Zilhao Borges is a graduate software engineer at Arm.


