
Network Intrusion: GNN on Traffic Flow Graphs

The average data breach costs $4.45M and takes 277 days to identify and contain. Signature-based IDS misses novel attacks. Here is how to build a GNN that detects intrusions by analyzing the structure of the communication graph.


TL;DR

  • Network intrusion detection is a graph anomaly problem. Attacks create distinctive patterns in the host communication graph: lateral movement, command-and-control (C2) channels, and data exfiltration each have structural signatures.
  • GATConv on the traffic flow graph learns which communication patterns are anomalous, using attention to weight the importance of each network flow.
  • On RelBench benchmarks, GNNs achieve 75.83 AUROC vs 62.44 for flat-table LightGBM. Graph structure captures multi-stage attack patterns that flow-level analysis misses.
  • The PyG model operates on windowed graph snapshots (1-5 minute windows), enabling near-real-time detection of evolving attack campaigns.
  • KumoRFM detects intrusion patterns with one PQL query (76.71 AUROC zero-shot), constructing the traffic graph and identifying anomalies automatically.

The business problem

IBM's Cost of a Data Breach Report 2023 puts the average breach at $4.45 million, with 277 days to identify and contain it. Signature-based intrusion detection systems (IDS) catch known attack patterns but miss novel threats, zero-day exploits, and advanced persistent threats (APTs) that use legitimate tools in unusual ways.

The key insight: attacks create anomalous patterns in the network communication graph. Lateral movement appears as a compromised host suddenly connecting to many internal hosts. Command-and-control channels create periodic, low-volume connections to external hosts. Data exfiltration shows up as unusual data volume flowing to unfamiliar destinations.
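As a rough, hand-written illustration of the exfiltration signature (a fixed rule, not the GNN), the sketch below totals each host's outbound bytes to destinations absent from its historical baseline. The `(src, dst, nbytes)` tuple layout, the `baseline` mapping, and the helper name `unfamiliar_bytes` are hypothetical.

```python
from collections import defaultdict


def unfamiliar_bytes(flows, baseline):
    """Sum outbound bytes per source host to destinations it has never contacted.

    flows:    iterable of (src, dst, nbytes) tuples for the current window
    baseline: dict mapping src -> set of destinations seen historically
    """
    totals = defaultdict(int)
    for src, dst, nbytes in flows:
        if dst not in baseline.get(src, set()):
            totals[src] += nbytes
    return dict(totals)


baseline = {
    "10.0.0.4": {"10.0.2.1", "198.51.100.7"},  # file server and its usual backup target
    "10.0.0.8": {"10.0.2.1"},
}
flows = [
    ("10.0.0.4", "198.51.100.7", 2_000_000_000),  # large but familiar: ignored
    ("10.0.0.8", "10.0.2.1", 50_000),             # familiar: ignored
    ("10.0.0.8", "192.0.2.55", 3_500_000_000),    # 3.5 GB to a never-seen external host
]
print(unfamiliar_bytes(flows, baseline))  # {'10.0.0.8': 3500000000}
```

Volume alone would flag the backup job; conditioning on whether the edge is new for that host is what separates the two.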

Why flat ML fails

  • Flow-level analysis: Flat models classify individual flows as malicious or benign. They miss multi-stage attacks where each individual flow looks normal but the pattern across flows is anomalous.
  • No communication context: The same flow from Host A to Host B means something different in context. If A just received data from a known-compromised host, the flow to B is more suspicious, but a flat model scores each flow in isolation.
  • Lateral movement blindness: An admin SSH session is normal. Twenty admin SSH sessions from the same host to different servers in 5 minutes is lateral movement. Flat models see 20 normal flows.
  • Beaconing detection: C2 beaconing creates periodic traffic patterns. The timing of a host's repeated connections to the same destination (regular intervals, consistent packet sizes) reveals beaconing that per-flow analysis misses.
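The lateral-movement and beaconing points can be made concrete with two hand-written per-host aggregates, the kind of cross-flow statistic a flow-by-flow classifier never computes. This is a sketch, not the GNN: the flow tuple layout `(timestamp_s, src, dst, dst_port, bytes)`, the helper names, and the port-22 filter are assumptions. A GNN learns comparable neighborhood aggregates instead of relying on fixed rules.

```python
from collections import defaultdict
from statistics import mean, pstdev


def fan_out(flows, port=22):
    """Count distinct destinations each source reached on `port` in this window."""
    dests = defaultdict(set)
    for _ts, src, dst, dport, _nbytes in flows:
        if dport == port:
            dests[src].add(dst)
    return {src: len(d) for src, d in dests.items()}


def beacon_cv(flows, src, dst):
    """Coefficient of variation of inter-arrival times for one src->dst pair.

    Values near 0 mean clock-like timing, typical of C2 beaconing.
    Returns None when there are too few connections to judge.
    """
    times = sorted(ts for ts, s, d, _p, _b in flows if s == src and d == dst)
    gaps = [later - earlier for earlier, later in zip(times, times[1:])]
    if len(gaps) < 2 or mean(gaps) == 0:
        return None
    return pstdev(gaps) / mean(gaps)


# 20 SSH sessions from one host to 20 different servers in under 5 minutes
flows = [(i * 15.0, "10.0.0.5", f"10.0.1.{i}", 22, 4_000) for i in range(20)]
flows.append((30.0, "10.0.0.9", "10.0.1.1", 22, 9_000))  # one ordinary admin session
# Small port-443 connection to the same external host every 60 seconds
flows += [(t * 60.0, "10.0.0.7", "203.0.113.10", 443, 512) for t in range(10)]

print(fan_out(flows))                                # {'10.0.0.5': 20, '10.0.0.9': 1}
print(beacon_cv(flows, "10.0.0.7", "203.0.113.10"))  # 0.0
```

Real beacons usually add random jitter, so a rule like this would flag low rather than exactly-zero values.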

The relational schema

schema.txt
Node types:
  Host     (id, os, role, subnet, patch_level)
  Service  (id, port, protocol, criticality)
  External (id, ip_range, geo, reputation_score)

Edge types:
  Host     --[connects_to]--> Host     (bytes, packets, duration)
  Host     --[runs]-->        Service
  Host     --[reaches]-->     External (bytes, frequency)
  Host     --[same_subnet]--> Host

Internal hosts, services, and external endpoints form the network graph. Flow attributes on edges capture traffic volume and patterns.
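A minimal, dependency-free sketch of turning flow-log rows into the typed edge lists that a PyG `HeteroData` object stores per edge type: an `edge_index` of source and destination indices plus one `edge_attr` row per edge. Repeated flows between the same pair are merged into one edge by summing bytes, packets, and duration. The row layout, the RFC 1918 internal/external rule, and the name `build_edges` are assumptions for illustration; service edges are omitted.

```python
import ipaddress
from collections import defaultdict

# Assumption: RFC 1918 space is internal, everything else is external.
INTERNAL_NETS = [ipaddress.ip_network(n)
                 for n in ("10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16")]


def is_internal(ip):
    addr = ipaddress.ip_address(ip)
    return any(addr in net for net in INTERNAL_NETS)


def build_edges(rows):
    """rows: (src_ip, dst_ip, bytes, packets, duration_s) flow-log tuples.

    Returns per-type node index maps and, per edge type, (edge_index, edge_attr)
    with edge_index = [src_indices, dst_indices].
    """
    node_ids = {"host": {}, "external": {}}

    def idx(ntype, ip):
        # Assign contiguous indices per node type on first sight.
        return node_ids[ntype].setdefault(ip, len(node_ids[ntype]))

    host_pairs = defaultdict(lambda: [0, 0, 0.0])  # bytes, packets, duration
    ext_pairs = defaultdict(lambda: [0, 0])        # bytes, frequency
    for src, dst, nbytes, packets, duration in rows:
        if not is_internal(src):
            continue  # inbound flows are ignored in this sketch
        s = idx("host", src)
        if is_internal(dst):
            attr = host_pairs[(s, idx("host", dst))]
            attr[0] += nbytes
            attr[1] += packets
            attr[2] += duration
        else:
            attr = ext_pairs[(s, idx("external", dst))]
            attr[0] += nbytes
            attr[1] += 1

    def to_coo(pairs):
        keys = list(pairs)
        return [[s for s, _ in keys], [d for _, d in keys]], [pairs[k] for k in keys]

    return node_ids, {
        ("host", "connects_to", "host"): to_coo(host_pairs),
        ("host", "reaches", "external"): to_coo(ext_pairs),
    }


rows = [
    ("10.0.0.5", "10.0.1.2", 1_200, 10, 0.5),
    ("10.0.0.5", "10.0.1.2", 800, 6, 0.25),     # same pair: merged into one edge
    ("10.0.0.5", "203.0.113.10", 512, 4, 0.1),
    ("203.0.113.10", "10.0.0.5", 512, 4, 0.1),  # inbound: skipped
]
nodes, edges = build_edges(rows)
print(edges[("host", "connects_to", "host")])  # ([[0], [1]], [[2000, 16, 0.75]])
```

With PyTorch Geometric available, each `edge_index` list would become `torch.tensor(edge_index, dtype=torch.long)` and each `edge_attr` list a float tensor, assigned to `data[edge_type].edge_index` and `data[edge_type].edge_attr`.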

PyG architecture: GATConv for anomaly scoring

intrusion_model.py
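import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, HeteroConv, Linear


def gat(hidden_dim, heads, edge_dim=None, bipartite=False):
    # heads * (hidden_dim // heads) == hidden_dim, since concat=True by default.
    # Bipartite relations must not add self-loops: node i of one type has
    # nothing to do with node i of another type.
    return GATConv(hidden_dim, hidden_dim // heads, heads=heads,
                   edge_dim=edge_dim, add_self_loops=not bipartite)


def hetero_gat_layer(hidden_dim, heads):
    return HeteroConv({
        # Internal flows; edge_attr = (bytes, packets, duration)
        ('host', 'connects_to', 'host'): gat(hidden_dim, heads, edge_dim=3),
        ('host', 'runs', 'service'): gat(hidden_dim, heads, bipartite=True),
        # External flows; edge_attr = (bytes, frequency)
        ('host', 'reaches', 'external'): gat(
            hidden_dim, heads, edge_dim=2, bipartite=True),
        # Reverse relations, as created by T.ToUndirected() (which also
        # symmetrizes host->host edges), carry service and external context
        # back to hosts. Without them the host score never sees either type.
        ('service', 'rev_runs', 'host'): gat(hidden_dim, heads, bipartite=True),
        ('external', 'rev_reaches', 'host'): gat(
            hidden_dim, heads, edge_dim=2, bipartite=True),
    }, aggr='sum')


class IntrusionGNN(torch.nn.Module):
    def __init__(self, hidden_dim=64, heads=4):
        super().__init__()
        # Lazy input size (-1): each node type keeps its own raw feature width.
        self.host_lin = Linear(-1, hidden_dim)
        self.service_lin = Linear(-1, hidden_dim)
        self.external_lin = Linear(-1, hidden_dim)

        self.conv1 = hetero_gat_layer(hidden_dim, heads)
        self.conv2 = hetero_gat_layer(hidden_dim, heads)

        self.anomaly_scorer = Linear(hidden_dim, 1)

    def forward(self, x_dict, edge_index_dict, edge_attr_dict):
        # Build a new dict rather than overwriting the caller's features.
        x_dict = {
            'host': self.host_lin(x_dict['host']),
            'service': self.service_lin(x_dict['service']),
            'external': self.external_lin(x_dict['external']),
        }

        # HeteroConv hands edge_attr_dict[edge_type] to each GATConv as
        # edge_attr; relations without features (runs, rev_runs) are simply
        # absent from edge_attr_dict.
        x_dict = self.conv1(x_dict, edge_index_dict,
                            edge_attr_dict=edge_attr_dict)
        x_dict = {k: F.elu(v) for k, v in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict,
                            edge_attr_dict=edge_attr_dict)

        # One anomaly probability per host (train with BCELoss, or drop the
        # sigmoid and use BCEWithLogitsLoss).
        return torch.sigmoid(
            self.anomaly_scorer(x_dict['host']).squeeze(-1))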

GATConv with edge features (traffic volume, frequency) learns which communication patterns are anomalous. Attention weights highlight the most suspicious flows per host.
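To make "attention weights highlight suspicious flows" concrete, here is a single-head, scalar toy of the scoring rule GATConv applies per edge: a LeakyReLU over a learned linear combination of target, source, and edge features, softmax-normalized over the target's incoming edges. The weights `a_dst`, `a_src`, `a_edge` and all feature values are hand-picked stand-ins for learned parameters. In PyG itself, calling `GATConv` with `return_attention_weights=True` returns the per-edge coefficients.

```python
import math


def leaky_relu(x, slope=0.2):
    # GATConv's default negative_slope is 0.2
    return x if x > 0 else slope * x


def gat_attention(a_dst, a_src, a_edge, h_dst, neighbors):
    """Single-head, scalar GAT-style attention of one host over its incoming flows.

    neighbors maps source id -> (h_src, e) where h_src is the source embedding
    and e an edge feature (here log10 of bytes). Scalars stand in for the
    projected vectors and learned attention vector of a real GATConv.
    """
    scores = {j: leaky_relu(a_dst * h_dst + a_src * h_j + a_edge * e_j)
              for j, (h_j, e_j) in neighbors.items()}
    top = max(scores.values())  # subtract the max for a numerically stable softmax
    exp = {j: math.exp(s - top) for j, s in scores.items()}
    total = sum(exp.values())
    return {j: v / total for j, v in exp.items()}


neighbors = {
    "A": (0.2, 3.0),  # ~1 KB flow
    "B": (0.1, 3.2),  # ~1.6 KB flow
    "C": (0.9, 6.0),  # ~1 MB flow
}
weights = gat_attention(a_dst=0.5, a_src=1.0, a_edge=1.0, h_dst=0.1,
                        neighbors=neighbors)
print({j: round(w, 2) for j, w in weights.items()})  # {'A': 0.02, 'B': 0.03, 'C': 0.95}
```

Because the weights sum to 1 per host, they rank that host's flows against each other; they are not calibrated anomaly probabilities on their own.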

Expected performance

  • Signature-based IDS: High recall for known attacks, near-zero recall for novel attacks
  • LightGBM (flow-level): 62.44 AUROC
  • GNN (GATConv traffic graph): 75.83 AUROC
  • KumoRFM (zero-shot): 76.71 AUROC

Or use KumoRFM with one query

KumoRFM PQL
PREDICT is_compromised FOR host
USING host, service, flow_log, external_connection

One PQL query. KumoRFM constructs the traffic graph from flow logs and identifies anomalous communication patterns.

Frequently asked questions

Why use GNNs for network intrusion detection?

Network attacks exploit the communication graph: lateral movement, command-and-control channels, and data exfiltration create distinctive patterns in the traffic flow graph. Flat models see individual flows; GNNs see how flows connect hosts, revealing attack campaigns that span multiple stages and hosts.
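A hand-written propagation rule, standing in for what a learned GNN does, shows why multi-hop context matters: after k rounds of aggregating along flow edges, a host's score depends on hosts up to k hops upstream, the same receptive-field property a k-layer GNN has. The seed score (for example from an alert on one host), the 0.5 decay, the max aggregation, and the host names are illustrative assumptions; a GNN learns how much each hop should count instead of using a fixed decay.

```python
def propagate(edges, seed, rounds=2, decay=0.5):
    """Spread a suspicion score downstream along directed flow edges.

    Each round, every host keeps the max of its current score and
    decay * score of any host that sent it traffic (synchronous update).
    After k rounds a host's score depends on hosts up to k hops upstream.
    """
    nodes = {n for edge in edges for n in edge}
    score = {n: seed.get(n, 0.0) for n in nodes}
    for _ in range(rounds):
        incoming = {n: 0.0 for n in nodes}
        for src, dst in edges:
            incoming[dst] = max(incoming[dst], decay * score[src])
        score = {n: max(score[n], incoming[n]) for n in nodes}
    return score


# Workstation -> jump host -> database is a three-stage chain; hr-2 -> print-1 is unrelated.
edges = [("ws-17", "jump-1"), ("jump-1", "db-3"), ("hr-2", "print-1")]
print(sorted(propagate(edges, seed={"ws-17": 1.0}).items()))
# [('db-3', 0.25), ('hr-2', 0.0), ('jump-1', 0.5), ('print-1', 0.0), ('ws-17', 1.0)]
```

A flat model scoring the jump-1 -> db-3 flow alone sees nothing unusual; the database only inherits suspicion because the second round reaches back to ws-17.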

What graph structure represents network traffic?

Hosts are nodes with features (OS, role, patch level). Network flows become edges with features (bytes, packets, duration, protocol, port). The resulting graph shows communication patterns: which hosts talk to which, how much data flows, and whether the pattern matches known attack signatures.

How do GNNs detect lateral movement?

Lateral movement creates unusual graph patterns: a compromised host suddenly communicates with many internal hosts it never contacted before, often using administrative protocols. A GNN sees this fan-out in the host's 1-hop neighborhood; a second layer adds context about the hosts it reached, such as whether they start fanning out too.

Can GNN intrusion detection run in real-time?

Yes, with windowed graph snapshots. Build a new traffic graph every 1-5 minutes from flow logs. Run the GNN on each snapshot to score hosts and flows. Pre-computed embeddings for known-good hosts make incremental updates efficient. Target: score a new snapshot in under 30 seconds.
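A minimal sketch of cutting a flow stream into snapshots, using 5-minute tumbling (non-overlapping) windows, the upper end of the range above. The `(timestamp_s, src, dst)` record layout and the name `tumbling_windows` are assumptions; each returned edge list would then be turned into one graph and scored.

```python
from collections import defaultdict


def tumbling_windows(flows, window_s=300):
    """Group (timestamp_s, src, dst) flow records into fixed, non-overlapping windows.

    Returns {window_start_s: [(src, dst), ...]} ordered by window start;
    each list holds the edges of one graph snapshot. Empty windows are absent.
    """
    snapshots = defaultdict(list)
    for ts, src, dst in flows:
        start = int(ts // window_s) * window_s
        snapshots[start].append((src, dst))
    return dict(sorted(snapshots.items()))


flows = [(12.0, "a", "b"), (299.9, "a", "c"), (300.0, "b", "c"), (905.5, "c", "a")]
print(tumbling_windows(flows))
# {0: [('a', 'b'), ('a', 'c')], 300: [('b', 'c')], 900: [('c', 'a')]}
```

Tumbling windows are the simplest choice. Overlapping (sliding) windows reduce the chance that an attack burst is split across two snapshots, at the cost of more inference runs.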

How does KumoRFM handle cybersecurity use cases?

KumoRFM can take your network flow data (hosts, flows, alerts) and predict intrusion risk per host with one PQL query. It constructs the traffic graph automatically and identifies anomalous communication patterns without manual feature engineering.
