Skip to main content

Overview

While Apache Beam provides many built-in I/O connectors, you may need to create custom I/O transforms for:
  • Proprietary data sources
  • Specialized file formats
  • Custom APIs or services
  • Performance optimization
  • Specific business requirements
This guide covers different approaches to building custom I/O connectors.

Approaches to Custom I/O

There are four main approaches to implementing custom I/O in Beam:
  1. DoFn-based I/O: Simple, quick implementation using ParDo
  2. PTransform-based I/O: Reusable, composable transforms
  3. Source/Sink API: Advanced, optimized for splitting and parallelism
  4. Splittable DoFn (SDF): Modern, flexible approach for advanced sources

DoFn-Based I/O

The simplest approach using DoFn for custom reading and writing.

Custom Reader with DoFn

import apache_beam as beam
from apache_beam.transforms import DoFn, ParDo
import requests

class ReadFromAPI(DoFn):
    """DoFn that fetches one JSON record per input element from a REST API.

    Each input element is treated as a resource ID and appended to
    ``api_url``; the decoded JSON body is emitted on success.
    """

    def __init__(self, api_url, auth_token):
        # Only picklable configuration lives here; the HTTP session is
        # created in setup() because it cannot be serialized.
        self.api_url = api_url
        self.auth_token = auth_token

    def setup(self):
        """Create one shared HTTP session per worker."""
        self.session = requests.Session()
        self.session.headers.update(
            {'Authorization': f'Bearer {self.auth_token}'})

    def process(self, element):
        """Fetch ``<api_url>/<element>`` and emit the parsed JSON response."""
        endpoint = f"{self.api_url}/{element}"
        response = self.session.get(endpoint)
        if response.status_code != 200:
            # Log-and-drop on failure; nothing is emitted for this element.
            import logging
            logging.error(f"Failed to fetch {endpoint}: {response.status_code}")
            return
        yield response.json()

    def teardown(self):
        """Release the worker-level HTTP session."""
        if hasattr(self, 'session'):
            self.session.close()

# Usage
with beam.Pipeline() as p:
    results = (
        p 
        | beam.Create(['user1', 'user2', 'user3'])  # Resource IDs
        # One GET per ID; failed fetches are logged and dropped.
        | ParDo(ReadFromAPI(
              api_url='https://api.example.com/users',
              auth_token='your-token-here'))
    )

Custom Writer with DoFn

import apache_beam as beam
from apache_beam.transforms import DoFn
import json

class WriteToAPI(DoFn):
    """Write elements to a REST API in batches.

    Elements are buffered per bundle and POSTed as a JSON array once
    ``batch_size`` elements accumulate, plus a final flush per bundle.

    Args:
        api_url: Endpoint that accepts a JSON array via POST.
        auth_token: Bearer token added to every request.
        batch_size: Number of elements per POST (default 100).
    """
    
    def __init__(self, api_url, auth_token, batch_size=100):
        # Only serializable configuration lives here; the buffer is reset
        # per bundle and the HTTP session is built per worker.
        self.api_url = api_url
        self.auth_token = auth_token
        self.batch_size = batch_size
        self.batch = []
    
    def setup(self):
        import requests
        self.session = requests.Session()
        self.session.headers.update({
            'Authorization': f'Bearer {self.auth_token}',
            'Content-Type': 'application/json'
        })
    
    def start_bundle(self):
        """Reset the buffer at each bundle boundary.

        Without this, elements left over from a failed (and retried) bundle
        would be carried into the next bundle and written twice.
        """
        self.batch = []
    
    def process(self, element):
        """Buffer the element; flush once the batch is full."""
        self.batch.append(element)
        
        if len(self.batch) >= self.batch_size:
            self._flush_batch()
    
    def finish_bundle(self):
        """Flush remaining elements at end of bundle."""
        if self.batch:
            self._flush_batch()
    
    def _flush_batch(self):
        """POST the buffered batch to the API and clear the buffer."""
        if not self.batch:
            return
        
        response = self.session.post(
            self.api_url,
            data=json.dumps(self.batch)
        )
        
        if response.status_code == 200:
            import logging
            logging.info(f"Successfully wrote {len(self.batch)} records")
            self.batch = []
        else:
            # Raising makes the runner retry the whole bundle.
            raise Exception(f"Failed to write batch: {response.status_code}")
    
    def teardown(self):
        if hasattr(self, 'session'):
            self.session.close()

# Usage
with beam.Pipeline() as p:
    (
        p 
        | beam.Create([{'id': 1, 'name': 'Alice'}, {'id': 2, 'name': 'Bob'}])
        # Flushes every 50 records, plus once at each bundle boundary.
        | beam.ParDo(WriteToAPI(
              api_url='https://api.example.com/data',
              auth_token='your-token-here',
              batch_size=50))
    )

PTransform-Based I/O

Create reusable, composable I/O transforms by extending PTransform.

Custom Read PTransform

import apache_beam as beam
from apache_beam.transforms import PTransform
from apache_beam.io.filesystems import FileSystems
import csv
from io import StringIO

class ReadFromCustomCSV(PTransform):
    """Read CSV files matching a glob pattern into per-row dicts.

    Rows are parsed with ``csv.DictReader`` and every value is stripped of
    surrounding whitespace before being emitted.
    """

    def __init__(self, file_pattern, delimiter=',', skip_rows=0):
        super().__init__()
        self.file_pattern = file_pattern
        self.delimiter = delimiter
        self.skip_rows = skip_rows

    def expand(self, pcoll):
        """Materialize the pattern, expand it to paths, then parse each file."""
        pattern = (
            pcoll.pipeline
            | 'CreatePattern' >> beam.Create([self.file_pattern]))
        paths = pattern | 'MatchFiles' >> beam.FlatMap(self._match_files)
        return paths | 'ReadFiles' >> beam.FlatMap(self._read_file)

    def _match_files(self, pattern):
        """Resolve a glob pattern to the list of matching file paths."""
        result = FileSystems.match([pattern])[0]
        return [m.path for m in result.metadata_list]

    def _read_file(self, file_path):
        """Parse one CSV file, yielding a whitespace-stripped dict per row."""
        with FileSystems.open(file_path) as handle:
            reader = csv.DictReader(
                StringIO(handle.read().decode('utf-8')),
                delimiter=self.delimiter)
            # Discard the first skip_rows data rows (the header line has
            # already been consumed by DictReader).
            for _ in range(self.skip_rows):
                next(reader, None)
            for record in reader:
                yield {key: value.strip() for key, value in record.items()}

# Usage
with beam.Pipeline() as p:
    # Reads every matching file; rows 1-2 after the header are skipped.
    records = p | ReadFromCustomCSV(
        'data/*.csv',
        delimiter='|',
        skip_rows=2
    )

Custom Write PTransform

import apache_beam as beam
from apache_beam.transforms import PTransform
from apache_beam.io.filesystems import FileSystems
import json

class WriteToCustomJSON(PTransform):
    """Custom JSON writer with formatting options.

    NOTE(review): this class is an illustrative sketch, not working I/O —
    `_write_to_file` is a stub, so elements are formatted and then dropped.
    A working alternative built on WriteToText follows below in the file.
    """
    
    def __init__(self, file_path_prefix, num_shards=1, 
                 indent=2, sort_keys=True):
        super().__init__()
        self.file_path_prefix = file_path_prefix
        # NOTE(review): num_shards is stored but never used by expand().
        self.num_shards = num_shards
        self.indent = indent
        self.sort_keys = sort_keys
    
    def expand(self, pcoll):
        """Expand the transform."""
        return (
            pcoll
            # NOTE(review): ReshufflePerKey expects (key, value) pairs;
            # a plain PCollection of dicts would need keys assigned first
            # — confirm the intended input shape.
            | 'GroupByShard' >> beam.transforms.util.ReshufflePerKey()
            | 'FormatJSON' >> beam.Map(self._format_json)
            | 'WriteToDisk' >> beam.Map(self._write_to_file)
        )
    
    def _format_json(self, element):
        """Format element as a JSON string using the configured options."""
        return json.dumps(
            element,
            indent=self.indent,
            sort_keys=self.sort_keys
        )
    
    def _write_to_file(self, element):
        """Write JSON to file (stub — intentionally unimplemented)."""
        # Implementation depends on sharding strategy
        pass

# Better approach using WriteToText with formatting
class WriteToFormattedJSON(PTransform):
    """Serialize each element with ``json.dumps`` and write via WriteToText.

    Extra keyword arguments are forwarded verbatim to ``json.dumps``
    (e.g. ``indent``, ``sort_keys``).
    """

    def __init__(self, file_path_prefix, **json_kwargs):
        super().__init__()
        self.file_path_prefix = file_path_prefix
        self.json_kwargs = json_kwargs

    def expand(self, pcoll):
        def encode(record):
            # Closure over the transform's json.dumps options.
            return json.dumps(record, **self.json_kwargs)

        formatted = pcoll | 'FormatJSON' >> beam.Map(encode)
        return formatted | 'WriteToText' >> beam.io.WriteToText(
            self.file_path_prefix,
            file_name_suffix='.json')

# Usage
with beam.Pipeline() as p:
    data = p | beam.Create([
        {'name': 'Alice', 'age': 30},
        {'name': 'Bob', 'age': 25}
    ])
    
    # Keyword args here are forwarded straight to json.dumps.
    data | WriteToFormattedJSON(
        'output/data',
        indent=2,
        sort_keys=True
    )

BoundedSource API

For advanced use cases requiring optimal splitting and parallelism.
import apache_beam as beam
from apache_beam.io import BoundedSource, RangeTracker
from apache_beam.io.iobase import SourceBundle
from apache_beam.io.range_trackers import OffsetRangeTracker

class CustomBoundedSource(BoundedSource):
    """Custom bounded source over an in-memory sequence with splitting support.

    Positions are integer offsets into ``data_source``.
    """

    def __init__(self, data_source, start_pos=0, end_pos=None):
        self.data_source = data_source
        self.start_pos = start_pos
        self.end_pos = end_pos

    def estimate_size(self):
        """Estimate the size of the source in bytes (rough: 100 B/element)."""
        return len(self.data_source) * 100  # Rough estimate

    def _resolve_range(self, start_position, stop_position):
        """Resolve explicit/default start and stop offsets.

        Uses explicit ``is None`` checks: offset 0 is a valid position and
        must not fall through to the defaults (a plain ``or``, as the
        original code used, silently discards it).
        """
        start = self.start_pos if start_position is None else start_position
        if stop_position is not None:
            stop = stop_position
        elif self.end_pos is not None:
            stop = self.end_pos
        else:
            stop = len(self.data_source)
        return start, stop

    def split(self, desired_bundle_size, start_position=None, stop_position=None):
        """Split source into SourceBundles for parallel processing."""
        start, stop = self._resolve_range(start_position, stop_position)

        total_size = self.estimate_size()
        num_bundles = max(1, int(total_size / desired_bundle_size))
        # Guard against zero-width bundles when the range has fewer
        # positions than the requested bundle count.
        bundle_size = max(1, (stop - start) // num_bundles)

        bundles = []
        for i in range(num_bundles):
            bundle_start = start + (i * bundle_size)
            if bundle_start >= stop:
                break  # Range exhausted before num_bundles was reached.
            # The last bundle absorbs any remainder of the range.
            if i == num_bundles - 1:
                bundle_stop = stop
            else:
                bundle_stop = min(stop, bundle_start + bundle_size)

            bundles.append(SourceBundle(
                weight=bundle_stop - bundle_start,
                source=CustomBoundedSource(
                    self.data_source, bundle_start, bundle_stop),
                start_position=bundle_start,
                stop_position=bundle_stop
            ))

        return bundles

    def get_range_tracker(self, start_position, stop_position):
        """Get a RangeTracker for the given range.

        Returns an OffsetRangeTracker: the RangeTracker base class is
        abstract and cannot be instantiated directly.
        """
        start, stop = self._resolve_range(start_position, stop_position)
        return OffsetRangeTracker(start, stop)

    def read(self, range_tracker):
        """Yield elements, claiming each offset before reading it."""
        for i in range(range_tracker.start_position(), 
                      range_tracker.stop_position()):
            if not range_tracker.try_claim(i):
                # Position was already claimed by another worker
                return
            yield self._read_element(i)

    def _read_element(self, position):
        """Read a single element at the given offset."""
        return self.data_source[position]

    def default_output_coder(self):
        """Return the coder for output elements (UTF-8 strings)."""
        from apache_beam.coders import coders
        return coders.StrUtf8Coder()

# Usage
with beam.Pipeline() as p:
    # beam.io.Read drives the BoundedSource split/read lifecycle.
    data = p | beam.io.Read(CustomBoundedSource(['item1', 'item2', 'item3']))

Splittable DoFn (SDF)

Modern approach for implementing advanced sources with dynamic work rebalancing.
import apache_beam as beam
from apache_beam.transforms.core import RestrictionProvider
from apache_beam.io.restriction_trackers import OffsetRange, OffsetRestrictionTracker

class ReadFromCustomSource(beam.DoFn):
    """Splittable DoFn for custom data source.

    NOTE(review): in current Beam Python, the restriction-provider hooks
    (initial_restriction / create_tracker / split) belong to a
    RestrictionProvider object passed to beam.DoFn.RestrictionParam(...),
    not to the DoFn itself — confirm against the Beam version in use.
    """
    
    def __init__(self, source_config):
        # Serializable description of the source (URLs, credentials, etc.).
        self.source_config = source_config
    
    def process(self, element, tracker=beam.DoFn.RestrictionParam()):
        """Process elements with restriction tracking.

        NOTE(review): RestrictionParam is usually constructed with a
        RestrictionProvider argument — verify this no-arg form works.
        """
        # Read data within the tracked range
        for i in range(tracker.current_restriction().start,
                      tracker.current_restriction().stop):
            if tracker.try_claim(i):
                # Read and yield element at position i
                yield self._read_at_position(i)
            else:
                # Stop if claim fails (work redistributed)
                return
    
    def create_tracker(self, restriction):
        """Create restriction tracker for an OffsetRange restriction."""
        return OffsetRestrictionTracker(restriction)
    
    def initial_restriction(self, element):
        """Define initial restriction (range to process)."""
        # Return the full range for this element
        return OffsetRange(start=0, stop=self._get_total_records())
    
    def split(self, element, restriction):
        """Split restriction into sub-ranges."""
        # Halve the range once; ranges of size <= 1 are indivisible.
        size = restriction.stop - restriction.start
        if size > 1:
            mid = restriction.start + size // 2
            yield OffsetRange(restriction.start, mid)
            yield OffsetRange(mid, restriction.stop)
        else:
            yield restriction
    
    def _read_at_position(self, position):
        """Read element at specific position."""
        # Custom reading logic
        return f"Record at position {position}"
    
    def _get_total_records(self):
        """Get total number of records in source."""
        return 1000  # Example

# Usage
# One trigger element fans out into the SDF, which splits the work itself.
with beam.Pipeline() as p:
    results = (
        p 
        | beam.Create(['source_id'])  # Trigger element
        # ReadFromCustomSource.__init__ takes `source_config`; the previous
        # `config=` keyword raised TypeError.
        | beam.ParDo(ReadFromCustomSource(
              source_config={'url': 'http://api.example.com'}))
    )

FileBasedSink Implementation

For advanced file writing with custom formatting.
import apache_beam as beam
from apache_beam.io import filebasedsink
from apache_beam.io.filesystem import CompressionTypes

class CustomFileBasedSink(filebasedsink.FileBasedSink):
    """Custom file sink with special formatting.

    NOTE(review): FileBasedSink.open() is conventionally expected to return
    a writable file handle (writers are created separately by the sink) —
    verify this override against the Beam version's contract.
    """
    
    def __init__(self, file_path_prefix, file_name_suffix='',
                 num_shards=0, shard_name_template=None,
                 compression_type=CompressionTypes.AUTO):
        # Elements are encoded to bytes before being written.
        super().__init__(
            file_path_prefix,
            file_name_suffix=file_name_suffix,
            num_shards=num_shards,
            shard_name_template=shard_name_template,
            coder=beam.coders.coders.ToBytesCoder(),
            compression_type=compression_type
        )
    
    def open(self, temp_path):
        """Open file for writing (returns the custom writer — see class note)."""
        return CustomFileWriter(temp_path)

class CustomFileWriter(filebasedsink.FileBasedSink.Writer):
    """Custom writer that frames output with a header and footer line.

    NOTE(review): Beam's FileBasedSink.Writer.__init__ takes the sink as
    its first argument; this single-argument super() call, and the
    `self.temp_path` read in initialize() (never assigned in this class),
    should be verified against the installed Beam version.
    """
    
    def __init__(self, temp_path):
        super().__init__(temp_path)
        # File handle is created lazily in initialize().
        self.file_handle = None
    
    def initialize(self):
        """Open the temp file and emit the format header."""
        from apache_beam.io.filesystems import FileSystems
        self.file_handle = FileSystems.create(self.temp_path)
        # Write header
        self.file_handle.write(b'# Custom File Format v1.0\n')
    
    def write(self, element):
        """Write a single element as one formatted, newline-terminated line."""
        # Custom formatting
        formatted = self._format_element(element)
        self.file_handle.write(formatted.encode('utf-8'))
        self.file_handle.write(b'\n')
    
    def close(self):
        """Emit the footer and close the file (no-op if never opened)."""
        if self.file_handle:
            # Write footer
            self.file_handle.write(b'# End of file\n')
            self.file_handle.close()
    
    def _format_element(self, element):
        """Format element for writing.

        Dicts become ' | '-joined key=value pairs; anything else is str()'d.
        """
        # Custom formatting logic
        if isinstance(element, dict):
            return ' | '.join(f"{k}={v}" for k, v in element.items())
        return str(element)

# Usage
with beam.Pipeline() as p:
    data = p | beam.Create([
        {'id': 1, 'name': 'Alice'},
        {'id': 2, 'name': 'Bob'}
    ])
    
    # Produces two shards, each framed by a header and footer line.
    data | beam.io.Write(CustomFileBasedSink(
        'output/custom',
        file_name_suffix='.txt',
        num_shards=2
    ))

Testing Custom I/O

Unit Testing

import unittest
import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that, equal_to

class TestCustomIO(unittest.TestCase):
    """Unit tests for the custom I/O transforms defined above."""
    
    def test_custom_read(self):
        """Test custom read transform."""
        # Requires a test_data.csv fixture containing the two rows below.
        with TestPipeline() as p:
            output = p | ReadFromCustomCSV('test_data.csv')
            
            assert_that(
                output,
                equal_to([
                    {'id': '1', 'name': 'Alice'},
                    {'id': '2', 'name': 'Bob'}
                ])
            )
    
    def test_custom_write(self):
        """Test custom write transform."""
        # NOTE(review): no assertions yet — the read-back check described
        # below still needs implementing for this test to verify anything.
        with TestPipeline() as p:
            data = p | beam.Create([{'id': 1}, {'id': 2}])
            data | WriteToCustomJSON('output/test')
        
        # Verify output files
        # Read back and check contents
    
    def test_dofn_processing(self):
        """Test DoFn-based I/O."""
        with TestPipeline() as p:
            output = (
                p 
                | beam.Create(['id1', 'id2'])
                | beam.ParDo(ReadFromAPI(
                      'http://test.api.com',
                      'test-token'))
            )
            
            # Mock API responses in tests
            # NOTE(review): `[...]` is a literal Ellipsis placeholder and
            # will never match real output — fill in expected records.
            assert_that(output, equal_to([...]))

Performance Optimization

Batching for Write Operations

import apache_beam as beam
from apache_beam.transforms import GroupIntoBatches

class BatchedWriteDoFn(beam.DoFn):
    """Write pre-grouped batches of elements, one destination call per batch.

    Expects its input to already be batches (e.g. produced upstream by
    GroupIntoBatches) and emits the size of every batch written.
    """

    def __init__(self, batch_size=100):
        # Kept for API compatibility; batching itself happens upstream.
        self.batch_size = batch_size

    def process(self, batch):
        """Write one batch and emit how many records it contained."""
        self._write_batch(batch)
        yield len(batch)

    def _write_batch(self, elements):
        """Write batch to destination (placeholder for the real sink call)."""
        pass

# Usage with batching
with beam.Pipeline() as p:
    (
        p 
        | beam.Create(range(1000))
        | beam.Map(lambda x: ('key', x))  # Add dummy key
        | GroupIntoBatches(batch_size=100)
        | beam.Values()  # Extract batched values
        # Each element is now a list of up to 100 records.
        | beam.ParDo(BatchedWriteDoFn())
    )

Connection Pooling

import apache_beam as beam
from apache_beam.transforms import DoFn

class PooledConnectionDoFn(DoFn):
    """DoFn that queries a SQL database through a per-worker connection pool.

    Args:
        connection_string: SQLAlchemy database URL.
    """
    
    def __init__(self, connection_string):
        # Only the picklable URL is stored; the engine/pool is built per
        # worker in setup().
        self.connection_string = connection_string
    
    def setup(self):
        """Initialize connection pool (once per worker)."""
        from sqlalchemy import create_engine, pool
        self.engine = create_engine(
            self.connection_string,
            poolclass=pool.QueuePool,
            pool_size=10,
            max_overflow=20
        )
    
    def process(self, element):
        """Look up one row keyed by ``element`` using a pooled connection."""
        from sqlalchemy import text
        with self.engine.connect() as conn:
            # Use text() with a bound parameter: plain SQL strings passed
            # to Connection.execute() are rejected by SQLAlchemy 2.x, and
            # binding avoids SQL injection via string interpolation.
            result = conn.execute(
                text("SELECT * FROM table WHERE id = :id"), {"id": element})
            yield result.fetchone()
    
    def teardown(self):
        """Clean up connection pool."""
        if hasattr(self, 'engine'):
            self.engine.dispose()

Error Handling and Retry Logic

import apache_beam as beam
from apache_beam.transforms import DoFn
import time
import logging

class ResilientReadDoFn(DoFn):
    """DoFn with exponential-backoff retries and a dead-letter output.

    Successful reads go to the main output; elements that still fail after
    ``max_retries`` attempts are emitted on the 'errors' tagged output.

    Args:
        max_retries: Total attempts per element (default 3).
        backoff_factor: Base of the exponential wait between attempts.
    """
    
    def __init__(self, max_retries=3, backoff_factor=2):
        self.max_retries = max_retries
        self.backoff_factor = backoff_factor
    
    def process(self, element):
        """Attempt the read, retrying with exponential backoff."""
        for attempt in range(self.max_retries):
            try:
                result = self._read_with_errors(element)
            except Exception as e:
                if attempt < self.max_retries - 1:
                    wait_time = self.backoff_factor ** attempt
                    logging.warning(
                        f"Attempt {attempt + 1} failed for {element}. "
                        f"Retrying in {wait_time}s. Error: {e}"
                    )
                    time.sleep(wait_time)
                else:
                    # Final attempt failed: route to the dead-letter output
                    # instead of failing the bundle.
                    logging.error(f"All retries failed for {element}: {e}")
                    yield beam.pvalue.TaggedOutput('errors', {
                        'element': element,
                        'error': str(e)
                    })
            else:
                # Yield outside the try block: previously a downstream
                # exception raised during the yield was caught here and
                # retried, which could emit the same result twice.
                yield result
                return  # Success
    
    def _read_with_errors(self, element):
        """Simulated read that may fail (replace with real read logic)."""
        pass

# Usage with error handling
with beam.Pipeline() as p:
    results = (
        p 
        | beam.Create(['item1', 'item2', 'item3'])
        # 'errors' carries dead-letter records; the main output is 'success'.
        | beam.ParDo(ResilientReadDoFn()).with_outputs('errors', main='success')
    )
    
    # Process successful reads
    results.success | 'WriteSuccess' >> beam.io.WriteToText('output/success')
    
    # Handle errors
    results.errors | 'WriteErrors' >> beam.io.WriteToText('output/errors')

Best Practices

1. Resource Management

class WellManagedDoFn(beam.DoFn):
    """Demonstrates the setup/process/teardown resource lifecycle."""

    def setup(self):
        # Expensive resources are built once per worker, never in __init__.
        self.connection = create_connection()

    def process(self, element):
        # Every element reuses the worker-level connection.
        yield self.connection.query(element)

    def teardown(self):
        # Only close if setup() ever ran on this worker.
        if hasattr(self, 'connection'):
            self.connection.close()

2. Serialization

Ensure your custom classes are serializable:
import apache_beam as beam
from typing import Any

class SerializableDoFn(beam.DoFn):
    """Properly serializable DoFn.

    The DoFn instance is pickled and shipped to workers, so __init__ must
    hold only plain data; live resources are created per worker in setup().
    """
    
    def __init__(self, config: dict):
        # Store only serializable config
        self.config = config
        # Don't store connections or file handles in __init__!
    
    def setup(self):
        # Create non-serializable resources here
        # (create_client is assumed to be defined elsewhere.)
        self.client = create_client(self.config)
    
    def process(self, element):
        yield self.client.process(element)

3. Testing

Always test your custom I/O:
import unittest
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that, equal_to

class TestCustomIO(unittest.TestCase):
    """Template for pipeline-level assertions with assert_that/equal_to."""

    def test_custom_source(self):
        # YourCustomSource and expected_data are placeholders to replace.
        with TestPipeline() as p:
            output = p | YourCustomSource()
            assert_that(output, equal_to(expected_data))

4. Documentation

Document your custom I/O transforms:
class ReadFromCustomSource(beam.PTransform):
    """
    Read data from CustomSource.
    
    Args:
        source_url: URL of the data source
        auth_token: Authentication token
        batch_size: Number of records to fetch per request (default: 100)
    
    Returns:
        PCollection of dictionaries containing the data
    
    Example:
        >>> with beam.Pipeline() as p:
        ...     data = p | ReadFromCustomSource(
        ...         source_url='https://api.example.com/data',
        ...         auth_token='secret-token'
        ...     )
    """
    # Implementation intentionally omitted; this class illustrates the
    # recommended docstring structure for custom I/O transforms.
    pass

Complete Example: Custom Database Connector

import apache_beam as beam
from apache_beam.transforms import PTransform, DoFn
from typing import Iterable, Dict, Any
import psycopg2
from psycopg2 import pool

class ReadFromPostgreSQL(PTransform):
    """Read rows from PostgreSQL by executing ``query`` on the workers.

    Args:
        connection_string: psycopg2 DSN for the target database.
        query: SQL statement whose result rows form the output PCollection.
        batch_size: Rows fetched per cursor round-trip (default 1000).
    """

    def __init__(self, connection_string: str, query: str,
                 batch_size: int = 1000):
        super().__init__()
        self.connection_string = connection_string
        self.query = query
        self.batch_size = batch_size

    def expand(self, pcoll):
        """Seed the pipeline with the query string, then run it via a DoFn."""
        query_coll = (
            pcoll.pipeline
            | 'CreateQuery' >> beam.Create([self.query]))
        reader = ReadPostgreSQLDoFn(self.connection_string, self.batch_size)
        return query_coll | 'ReadDB' >> beam.ParDo(reader)

class ReadPostgreSQLDoFn(DoFn):
    """DoFn for reading from PostgreSQL via a per-worker connection pool.

    Args:
        connection_string: psycopg2 DSN.
        batch_size: Rows fetched per cursor round-trip.
    """
    
    def __init__(self, connection_string: str, batch_size: int):
        self.connection_string = connection_string
        self.batch_size = batch_size
        self.connection_pool = None
    
    def setup(self):
        """Initialize connection pool."""
        self.connection_pool = psycopg2.pool.SimpleConnectionPool(
            minconn=1,
            maxconn=10,
            dsn=self.connection_string
        )
    
    def process(self, query: str) -> Iterable[Dict[str, Any]]:
        """Execute query and yield each row as a column-name -> value dict."""
        conn = None
        try:
            conn = self.connection_pool.getconn()
            # The cursor context manager guarantees the cursor is closed
            # even when execute/fetch raises (the previous version only
            # closed it on the success path, leaking it on errors).
            with conn.cursor() as cursor:
                cursor.execute(query)
                
                # Fetch in batches to bound memory use.
                while True:
                    rows = cursor.fetchmany(self.batch_size)
                    if not rows:
                        break
                    
                    # Column names come from the cursor description.
                    columns = [desc[0] for desc in cursor.description]
                    
                    # Yield rows as dictionaries
                    for row in rows:
                        yield dict(zip(columns, row))
        finally:
            # Always return the connection to the pool.
            if conn:
                self.connection_pool.putconn(conn)
    
    def teardown(self):
        """Close all connections."""
        if self.connection_pool:
            self.connection_pool.closeall()

# Usage
with beam.Pipeline() as p:
    users = p | ReadFromPostgreSQL(
        connection_string='postgresql://user:pass@localhost/db',
        query='SELECT * FROM users WHERE active = true',
        batch_size=500
    )
    
    # Process the data (rows arrive as dicts keyed by column name)
    users | beam.Map(print)

Next Steps

I/O Overview

Learn about I/O concepts and patterns in Apache Beam

Built-in Connectors

Explore the built-in I/O connectors available in Beam