Apache Beam RunInference: ML Inference Patterns
Apache Beam’s RunInference API provides a scalable way to run machine learning inference in your data pipelines. This page demonstrates ML patterns for various frameworks and use cases.

PyTorch Image Classification

Run inference on images using PyTorch models. Based on sdks/python/apache_beam/examples/inference/pytorch_image_classification.py:96-166
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference, KeyedModelHandler
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor
import torch
from torchvision import models, transforms
from PIL import Image

# Preprocessing function
def preprocess_image(image_path):
    """Load an image and convert it to a normalized 224x224 tensor.

    Args:
        image_path: Filesystem path to the image file.

    Returns:
        A torch.Tensor of shape (3, 224, 224), normalized with the
        standard ImageNet mean/std expected by torchvision models.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])
    # Use a context manager so the underlying file handle is closed
    # promptly; the original left it open until garbage collection.
    with Image.open(image_path) as image:
        return transform(image.convert('RGB'))

# Create the model handler. KeyedModelHandler wraps the base handler so
# each input is a (key, tensor) pair and the key is carried through to
# the prediction, letting us trace results back to their image paths.
model_handler = KeyedModelHandler(
    PytorchModelHandlerTensor(
        # Weights are loaded from a state_dict; model_class is the
        # constructor and model_params its keyword arguments.
        state_dict_path='path/to/model.pth',
        model_class=models.mobilenet_v2,
        model_params={'num_classes': 1000},
        device='CPU',
        # Buffer elements into batches of 10-100 per forward pass.
        min_batch_size=10,
        max_batch_size=100
    )
)

# Build pipeline
with beam.Pipeline() as p:
    predictions = (
        p
        # One image path per line of the input file.
        | 'ReadImagePaths' >> beam.io.ReadFromText('image_list.txt')
        # Key each preprocessed tensor by its path for traceability.
        | 'PreprocessImages' >> beam.Map(
            lambda path: (path, preprocess_image(path))
        )
        | 'RunInference' >> RunInference(model_handler)
        # x is (path, PredictionResult); argmax picks the top class id.
        | 'ExtractPredictions' >> beam.Map(
            lambda x: f"{x[0]}: {torch.argmax(x[1].inference)}"
        )
        | 'WriteResults' >> beam.io.WriteToText('predictions.txt')
    )
Key Features:
  • Automatic batching for efficient inference
  • Support for CPU and GPU devices
  • Pre/post processing hooks
  • Keyed inputs for tracking

Scikit-learn Classification

Run inference with scikit-learn models. Based on sdks/python/apache_beam/examples/inference/sklearn_mnist_classification.py:88-133
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference, KeyedModelHandler
from apache_beam.ml.inference.sklearn_inference import (
    SklearnModelHandlerNumpy,
    ModelFileType
)
from apache_beam.options.pipeline_options import PipelineOptions

def process_input(row: str) -> tuple:
    """Split a CSV row into an (int label, pixel-value list) pair."""
    label_field, *pixel_fields = row.split(',')
    return int(label_field), [int(p) for p in pixel_fields]

class PostProcessor(beam.DoFn):
    """Format keyed prediction results as '<label>,<prediction>' lines."""

    def process(self, element):
        # element is (label, PredictionResult) from the keyed handler.
        label, result = element
        prediction = result.inference
        yield f'{label},{prediction}'

# Create the model handler. KeyedModelHandler keeps the example label
# paired with its prediction through the inference step.
model_handler = KeyedModelHandler(
    SklearnModelHandlerNumpy(
        model_file_type=ModelFileType.PICKLE,
        model_uri='gs://bucket/model.pkl',
        # Set large_model=True to reduce concurrent in-memory copies
        # when the model causes memory pressure on workers.
        large_model=False
    )
)

# Build pipeline
with beam.Pipeline(options=PipelineOptions()) as p:
    predictions = (
        p
        # Each CSV line is "label,pixel0,pixel1,...".
        | 'ReadInput' >> beam.io.ReadFromText('data.csv')
        | 'ParseInput' >> beam.Map(process_input)
        # Inputs are (label, features); outputs (label, PredictionResult).
        | 'RunInference' >> RunInference(model_handler)
        | 'PostProcess' >> beam.ParDo(PostProcessor())
        | 'WriteOutput' >> beam.io.WriteToText('predictions.csv')
    )
Supported Model Types:
  • ModelFileType.PICKLE - Pickled models
  • ModelFileType.JOBLIB - Joblib-serialized models
Large Model Optimization:
  • Set large_model=True for models that cause memory pressure
  • Reduces concurrent copies in workers

TensorFlow Inference

Run TensorFlow SavedModel inference.
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.tensorflow_inference import (
    TFModelHandlerNumpy,
    default_tensor_inference_fn
)
import numpy as np

def preprocess_features(row: dict) -> np.ndarray:
    """Pack the two model features from *row* into a float32 vector."""
    features = (row['feature1'], row['feature2'])
    return np.asarray(features, dtype=np.float32)

# Create TensorFlow model handler for a SavedModel directory.
model_handler = TFModelHandlerNumpy(
    model_uri='gs://bucket/saved_model/',
    inference_fn=default_tensor_inference_fn
)

with beam.Pipeline() as p:
    predictions = (
        p
        # NOTE(review): ReadFromBigQuery yields dict rows; preprocessing
        # assumes 'feature1' and 'feature2' columns exist — verify schema.
        | 'ReadData' >> beam.io.ReadFromBigQuery(query='SELECT * FROM dataset.table')
        | 'Preprocess' >> beam.Map(preprocess_features)
        | 'Inference' >> RunInference(model_handler)
        # Take the first value from each PredictionResult's inference.
        | 'ExtractPredictions' >> beam.Map(lambda result: result.inference[0])
    )

Multi-Model Inference

Run inference on keyed inputs, where each element pairs a key (e.g. a model version) with its tensor so results can be grouped per key.
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
from apache_beam.ml.inference.base import KeyedModelHandler
import torch

# NOTE(review): despite the section title, PytorchModelHandlerKeyedTensor
# loads a single state_dict and runs one model on keyed tensor inputs —
# it does not select a different model per key. Confirm intent; per-key
# model selection uses KeyedModelHandler with a handler mapping.
model_handler = PytorchModelHandlerKeyedTensor(
    state_dict_path='gs://bucket/models/',
    model_class=MyModelClass,  # assumed to be defined elsewhere
    model_params={},
    device='CPU'
)

with beam.Pipeline() as p:
    predictions = (
        p
        # Each element is a (key, tensor) pair; the key travels with
        # the example through inference.
        | 'CreateData' >> beam.Create([
            ('model_a', torch.tensor([1.0, 2.0, 3.0])),
            ('model_b', torch.tensor([4.0, 5.0, 6.0])),
            ('model_a', torch.tensor([7.0, 8.0, 9.0])),
        ])
        | 'RunInference' >> RunInference(model_handler)
    )

Inference with Side Inputs

Use side inputs for dynamic model configuration.
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

class ModelWithSideInput(beam.DoFn):
    """Filter inference results using a threshold from a side input.

    The side input is a singleton config dict, e.g. {'threshold': 0.5}.
    Elements whose first inference score does not exceed the threshold
    are dropped (nothing is emitted for them).
    """
    def process(self, element, config):
        # Access side input configuration
        threshold = config['threshold']

        # Emit only elements above the threshold. The original example
        # yielded None for the rest, which would inject None values into
        # the output PCollection instead of filtering them out.
        if element.inference[0] > threshold:
            yield element

# Side input for configuration.
# NOTE(review): p, images and model_handler are assumed to be defined
# earlier; this snippet is not self-contained.
config = p | 'Config' >> beam.Create([{'threshold': 0.5}])

predictions = (
    images
    | 'RunInference' >> RunInference(model_handler)
    # AsSingleton delivers the one-element config PCollection to every
    # ParDo invocation as a plain dict.
    | 'FilterWithConfig' >> beam.ParDo(
        ModelWithSideInput(),
        config=beam.pvalue.AsSingleton(config)
    )
)

Batching Configuration

Optimize inference performance with batching.
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

# Configure batching parameters. RunInference buffers elements and runs
# the model on batches within these bounds.
model_handler = PytorchModelHandlerTensor(
    state_dict_path='model.pth',
    model_class=MyModel,  # assumed to be defined elsewhere
    model_params={},
    device='GPU',
    # Batching configuration
    min_batch_size=10,      # Minimum batch size
    max_batch_size=100,     # Maximum batch size
    max_batch_duration_secs=1.0  # Max wait time for batch
)
Batching Guidelines:
  • min_batch_size: Wait for at least this many elements
  • max_batch_size: Process at most this many elements together
  • max_batch_duration_secs: Maximum time to wait for batch to fill

Model Loading Strategies

Different ways to load models.
from apache_beam.ml.inference.sklearn_inference import (
    SklearnModelHandlerNumpy,
    ModelFileType
)

# Load a pickled scikit-learn model from GCS; use ModelFileType.JOBLIB
# instead for joblib-serialized models.
model_handler = SklearnModelHandlerNumpy(
    model_file_type=ModelFileType.PICKLE,
    model_uri='gs://bucket/model.pkl'
)

Streaming ML Pipeline

Combine streaming with ML inference.
import json  # required by the ParseJSON step; missing in the original

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions, StandardOptions
from apache_beam.ml.inference.base import RunInference
from apache_beam.transforms import window

# Configure for streaming
pipeline_options = PipelineOptions()
pipeline_options.view_as(StandardOptions).streaming = True

# NOTE(review): model_handler, extract_features and format_prediction
# are assumed to be defined elsewhere; this snippet does not define them.
with beam.Pipeline(options=pipeline_options) as p:
    predictions = (
        p
        | 'ReadFromPubSub' >> beam.io.ReadFromPubSub(
            subscription='projects/project/subscriptions/input'
        )
        | 'ParseJSON' >> beam.Map(json.loads)
        | 'ExtractFeatures' >> beam.Map(extract_features)
        | 'Window' >> beam.WindowInto(window.FixedWindows(60))  # 1-min windows
        | 'RunInference' >> RunInference(model_handler)
        | 'FormatOutput' >> beam.Map(format_prediction)
        | 'WriteToPubSub' >> beam.io.WriteToPubSub(
            topic='projects/project/topics/predictions'
        )
    )

Feature Engineering

Preprocess features before inference.
import apache_beam as beam
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.ml.transforms.tft import TFTOperation
import tensorflow_transform as tft

# Define feature preprocessing
def preprocessing_fn(inputs):
    """TensorFlow Transform preprocessing: z-score one feature and
    bucketize another into 10 buckets."""
    return {
        'normalized_feature': tft.scale_to_z_score(inputs['feature']),
        'bucketized_feature': tft.bucketize(
            inputs['numeric_feature'],
            num_buckets=10
        ),
    }

# Apply feature engineering, then run inference on the transformed data.
# NOTE(review): raw_data and model_handler are assumed defined elsewhere;
# also confirm TFTOperation accepts a raw preprocessing_fn — Beam's
# MLTransform is typically composed from concrete tft operations.
transformed_data = (
    raw_data
    | 'Preprocess' >> MLTransform(
        # Artifacts (vocab files, statistics) are written here so the
        # same transform can be replayed at serving time.
        write_artifact_location='gs://bucket/artifacts/',
    ).with_transform(TFTOperation(preprocessing_fn))
    | 'RunInference' >> RunInference(model_handler)
)

Model Monitoring

Track inference metrics and performance.
import apache_beam as beam
from apache_beam.metrics import Metrics

class InferenceWithMetrics(beam.DoFn):
    """Track inference metrics.

    Increments a prediction counter and records a latency distribution
    for elements flowing out of RunInference.
    """
    def __init__(self):
        # Beam metrics are namespaced as ('inference', <metric name>).
        self.inference_counter = Metrics.counter('inference', 'predictions')
        self.latency_distribution = Metrics.distribution('inference', 'latency_ms')
    
    def process(self, element):
        import time
        start_time = time.time()
        
        # Run inference
        # NOTE(review): the model already ran upstream in RunInference;
        # this line only reads the stored result, so the recorded
        # "latency" measures this DoFn's bookkeeping, not model time.
        prediction = element.inference
        
        # Track metrics
        self.inference_counter.inc()
        latency_ms = (time.time() - start_time) * 1000
        self.latency_distribution.update(int(latency_ms))
        
        yield (element.example, prediction)

# Attach metric tracking directly after the inference step.
# NOTE(review): data and model_handler are assumed defined elsewhere.
predictions = (
    data
    | 'RunInference' >> RunInference(model_handler)
    | 'TrackMetrics' >> beam.ParDo(InferenceWithMetrics())
)

A/B Testing Models

Compare multiple model versions.
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
import random

class ABTestRouter(beam.DoFn):
    """Route traffic between model versions (90% A / 10% B)."""

    def process(self, element):
        # 90/10 split: most traffic stays on the incumbent model_a.
        if random.random() < 0.9:
            version = 'model_a'
        else:
            version = 'model_b'
        yield beam.pvalue.TaggedOutput(version, element)

# Split traffic
model_a_tag = 'model_a'
model_b_tag = 'model_b'

# with_outputs exposes each TaggedOutput stream under its tag name.
# NOTE(review): data, model_handler_a and model_handler_b are assumed
# to be defined elsewhere.
routed = data | 'Route' >> beam.ParDo(ABTestRouter()).with_outputs(
    model_a_tag, model_b_tag
)

# Run inference on each branch, labeling results with the model version
# so downstream comparison knows which model produced each prediction.
predictions_a = (
    routed[model_a_tag]
    | 'InferenceA' >> RunInference(model_handler_a)
    | 'TagA' >> beam.Map(lambda x: ('model_a', x))
)

predictions_b = (
    routed[model_b_tag]
    | 'InferenceB' >> RunInference(model_handler_b)
    | 'TagB' >> beam.Map(lambda x: ('model_b', x))
)

# Combine results into one PCollection for joint analysis.
all_predictions = (predictions_a, predictions_b) | beam.Flatten()

Best Practices

Optimize Batching

  • Configure batch sizes for your hardware
  • Balance latency vs. throughput
  • Monitor batch utilization

Handle Large Models

  • Use large_model=True for memory efficiency
  • Consider model quantization
  • Use GPU workers for large models

Version Your Models

  • Include version info in model paths
  • Track model metadata
  • Support A/B testing

Monitor Performance

  • Track inference latency
  • Monitor prediction quality
  • Set up alerting for anomalies

Supported Frameworks

PyTorch

  • PytorchModelHandlerTensor
  • State dict or entire model
  • CPU and GPU support

Scikit-learn

  • SklearnModelHandlerNumpy
  • Pickle or Joblib format
  • All sklearn estimators

TensorFlow

  • TFModelHandlerNumpy
  • SavedModel format
  • TF 2.x models

ONNX

  • ONNXModelHandler
  • Cross-framework models
  • Optimized runtime

XGBoost

  • XGBoostModelHandler
  • Boosted trees
  • Native format

Vertex AI

  • VertexAIModelHandler
  • Managed endpoints
  • Auto-scaling

RunInference API

Complete API reference

Model Handlers

Framework-specific handlers

ML Transforms

Feature engineering transforms

ML Pipeline Patterns

Common ML pipeline architectures