Java + Spring: How to correctly batch multiple requests into 1 network call with correlation support
To guarantee the scalability and availability of our own service, we must protect downstream dependencies from being flooded with requests within a very short period of time. One popular technique is to call a downstream Batch API, if one is available, instead of making a separate network call for each element.
Adopting this approach is straightforward if our own logic is also a batch implementation (e.g. it takes a list of user IDs as input). In reality, however, most of our logic works on a single entity (e.g. sending an email to a single user). In addition, if we expose an API endpoint, each HTTP request arriving at this endpoint is completely independent and unaware of other requests. Hence, between our logic and the downstream Batch API, there must be a batcher that collects "similar" requests and makes a few network calls on behalf of all of them.
Is this tutorial for you?
If you ask ChatGPT or Google how to implement this batching layer, HystrixCollapser will likely be one of the top results. Don't waste your time with it: the Hystrix library is no longer in active development. In addition, HystrixCollapser only supports the most basic scenario out of the box, where all requests can be batched together without any correlation strategy. Pretty useless in real projects :).
Instead, if you're already using Java and the Spring framework in your projects, this tutorial is for you. I'm going to show you how to implement this in-memory batcher with the dependency below. The sample code also uses Lombok, StreamEx and a Pair class with getLeft()/getRight() (such as Apache Commons Lang's). It should be easy to adapt the sample code to the specific versions in your actual projects.
<dependency>
    <groupId>org.springframework.integration</groupId>
    <artifactId>spring-integration-core</artifactId>
    <version>5.5.12</version>
</dependency>
By the end of this tutorial, you will have an in-memory batching layer that can do the following things:
Trade a little latency for throughput by batching similar requests that arrive within a configured window of time (e.g. 200ms).
Invoke the batched logic when the size of the batch reaches a configured threshold or the batch has exceeded a configured timeout duration.
Distribute results back to individual callers using CompletableFuture.
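To make these three behaviors concrete before Spring enters the picture, here is a minimal plain-Java sketch of such a batcher, without any correlation support. MiniBatcher, submit and the constructor parameters are hypothetical names, not part of Spring Integration, and the sketch assumes the batch function returns a map keyed by input. Each caller submits one input and receives a CompletableFuture; a batch is released when it holds maxBatchSize inputs or timeoutMillis after its first input arrived, whichever comes first, and the batch function is called once for the whole batch.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

/**
 * Hypothetical plain-Java batcher (no Spring, no correlation): each caller submits one input and gets a future.
 * A batch is released when it holds maxBatchSize inputs or timeoutMillis after its first input, whichever is first.
 */
final class MiniBatcher<I, R> implements AutoCloseable {
    private static final class Pending<I, R> { // one waiting caller: its input and the future promised to it
        final I input;
        final CompletableFuture<R> future = new CompletableFuture<>();
        Pending(I input) { this.input = input; }
    }

    private final Function<List<I>, Map<I, R>> batchFunction; // the Batch API, called once per batch
    private final int maxBatchSize;
    private final long timeoutMillis;
    private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
    private List<Pending<I, R>> buffer = new ArrayList<>(); // guarded by "this"
    private ScheduledFuture<?> timeoutTask;                  // guarded by "this"
    private long generation;                                 // bumped on every release; guarded by "this"

    MiniBatcher(Function<List<I>, Map<I, R>> batchFunction, int maxBatchSize, long timeoutMillis) {
        this.batchFunction = batchFunction;
        this.maxBatchSize = maxBatchSize;
        this.timeoutMillis = timeoutMillis;
    }

    CompletableFuture<R> submit(I input) {
        Pending<I, R> pending = new Pending<>(input);
        List<Pending<I, R>> ready = null;
        synchronized (this) {
            buffer.add(pending);
            if (buffer.size() == 1) { // the first input of a batch opens its time window
                long batchGeneration = generation;
                timeoutTask = scheduler.schedule(
                        () -> releaseOnTimeout(batchGeneration), timeoutMillis, TimeUnit.MILLISECONDS);
            }
            if (buffer.size() >= maxBatchSize) {
                ready = drain();
            }
        }
        if (ready != null) {
            execute(ready); // call the Batch API outside the lock
        }
        return pending.future;
    }

    private void releaseOnTimeout(long batchGeneration) {
        List<Pending<I, R>> ready;
        synchronized (this) {
            if (batchGeneration != generation) {
                return; // that batch was already released because it reached maxBatchSize
            }
            ready = drain();
        }
        execute(ready);
    }

    private List<Pending<I, R>> drain() { // caller must hold the lock
        List<Pending<I, R>> ready = buffer;
        buffer = new ArrayList<>();
        generation++;
        if (timeoutTask != null) {
            timeoutTask.cancel(false); // does not stop a task that is already running
        }
        return ready;
    }

    private void execute(List<Pending<I, R>> batch) {
        List<I> inputs = new ArrayList<>();
        for (Pending<I, R> p : batch) {
            inputs.add(p.input);
        }
        try {
            Map<I, R> results = batchFunction.apply(inputs); // one network call on behalf of everyone
            for (Pending<I, R> p : batch) {
                p.future.complete(results.get(p.input));
            }
        } catch (RuntimeException ex) {
            for (Pending<I, R> p : batch) {
                p.future.completeExceptionally(ex); // a failed batch fails every caller in it
            }
        }
    }

    @Override
    public void close() {
        scheduler.shutdownNow();
    }
}
```

The generation counter guards against a timeout task that fires just after its batch was already released by size. In the Spring version below, the aggregator's release strategy and group timeout play these two roles, and the reply channels play the role of the per-caller futures.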
MathService
First, I'd like to introduce 3 categories of batch APIs using this simple MathService:
sum represents a batch API that takes multiple inputs and returns the same output for all elements.
multiplyListByTwo represents a batch API that takes multiple inputs and returns a dedicated output for each element.
echo represents a batch API that takes multiple inputs and returns nothing.
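import java.util.List;
import java.util.Map;
import one.util.streamex.StreamEx;
import org.springframework.stereotype.Service;
@Service
public class MathService {
    // Same output for all elements
    public int sum(List<Integer> numbers) {
        return numbers.stream().mapToInt(Integer::intValue).sum();
    }
    // A dedicated output for each element, keyed by the input number
    public Map<Integer, Integer> multiplyListByTwo(List<Integer> numbers) {
        this.validateAllEvensOrOdd(numbers);
        // StreamEx.toMap throws on duplicate keys, so the input numbers must be distinct
        return StreamEx.of(numbers).toMap(i -> i, i -> i * 2);
    }
    // No output at all
    public void echo(List<Integer> numbers) {
        System.out.println("Received some numbers: " + numbers);
    }
    // Rejects a batch that mixes odd and even numbers
    private void validateAllEvensOrOdd(List<Integer> numbers) {
        boolean allEven = numbers.stream().allMatch(num -> num % 2 == 0);
        boolean allOdd = numbers.stream().allMatch(num -> num % 2 != 0);
        if (!allEven && !allOdd) {
            throw new RuntimeException("cannot mix odd and even numbers");
        }
    }
}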
MathGateway
Since the methods of MathService take a List<Integer> as input, we cannot call them directly from code that works on a single number. Instead, we need an intermediate layer that takes in only a single number. In the Spring Integration world, this is called a MessagingGateway; each of its methods takes a single input and returns a CompletableFuture.
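import java.util.concurrent.CompletableFuture;
import org.springframework.integration.annotation.Gateway;
import org.springframework.integration.annotation.MessagingGateway;
import org.springframework.messaging.handler.annotation.Payload;
// Each method accepts ONE number and returns a future that completes once its batch has been processed
@MessagingGateway
public interface MathGateway {
    @Gateway(requestChannel = "sumChannel")
    CompletableFuture<Integer> sum(@Payload Integer number);
    @Gateway(requestChannel = "multiplyByTwoChannel")
    CompletableFuture<Integer> multiplyByTwo(@Payload Integer number);
    @Gateway(requestChannel = "echoChannel")
    CompletableFuture<Void> echo(@Payload Integer number);
}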
Invoking our batched API
Before going into the implementation of the batching layer, let's describe the behavior and the calling syntax we are looking for with the test code below.
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Slf4j // provides the "log" field used below
@Service
public class BatchTestService {
    @Autowired
    MathGateway mathGateway;
    public void sum() {
        List<CompletableFuture<Void>> futures = new ArrayList<>();
        for (int i = 0; i < 1000; i++) {
            int input = i;
            CompletableFuture<Void> future = mathGateway.sum(input)
                .thenAccept(result -> {
                    log.info("same output for all elements, input: {}, output: {}", input, result);
                })
                .exceptionally(throwable -> {
                    log.info("same exception for all elements, input: {}, error: {}", input, throwable.getMessage());
                    return null;
                });
            futures.add(future);
        }
        // Wait until every caller has received its result
        CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
    }
    public void multiplyByTwo() {
        List<CompletableFuture<Void>> futures = new ArrayList<>();
        for (int i = 0; i < 1000; i++) {
            int input = i;
            CompletableFuture<Void> future = mathGateway.multiplyByTwo(input)
                .thenAccept(result -> {
                    log.info("different output for each element, input: {}, output: {}", input, result);
                })
                .exceptionally(throwable -> {
                    log.info("same exception for all elements, input: {}, error: {}", input, throwable.getMessage());
                    return null;
                });
            futures.add(future);
        }
        CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
    }
    public void echo() {
        List<CompletableFuture<Void>> futures = new ArrayList<>();
        for (int i = 0; i < 1000; i++) {
            int input = i;
            CompletableFuture<Void> future = mathGateway.echo(input)
                .exceptionally(throwable -> {
                    log.info("same exception for all elements, input: {}, error: {}", input, throwable.getMessage());
                    return null;
                });
            futures.add(future);
        }
        CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
    }
}
The in-memory batching layer
As you can see from the code above, MathGateway is just an interface that knows nothing about MathService. To execute our batched logic in MathService, we must define 2 things in a @Configuration class:
The MessageChannel through which individual inputs are sent to be batched together.
The IntegrationFlow representing our in-memory batching layer, which interacts with the MathService and then distributes results back to each caller.
import java.util.function.Function;
import javax.annotation.Resource;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.integration.channel.DirectChannel;
import org.springframework.integration.config.EnableIntegration;
import org.springframework.integration.dsl.IntegrationFlow;
import org.springframework.messaging.MessageChannel;
@Configuration
@EnableIntegration
public class BatchedMathServiceConfigs {
    @Resource // Note: cannot use @Autowired here
    MathService mathService;
    @Bean
    public MessageChannel sumChannel() {
        return new DirectChannel();
    }
    @Bean
    public MessageChannel multiplyByTwoChannel() {
        return new DirectChannel();
    }
    @Bean
    public MessageChannel echoChannel() {
        return new DirectChannel();
    }
    @Bean
    public IntegrationFlow sumFlow() {
        return BatchingFlowBuilder.create("sumFlowId", this.sumChannel())
            .withBatchingStrategy(
                BatchAllCorrelationStrategy.create(),
                () -> 2000, // max batch size
                () -> 200L  // batch timeout in ms (a Supplier<Long>, hence the L suffix)
            )
            .withIdenticalResultBatchProcessor(
                Function.identity(), // keep the collected List<Integer> as List<Integer>
                mathService::sum
            )
            .build();
    }
    @Bean
    public IntegrationFlow multiplyByTwoFlow() {
        return BatchingFlowBuilder.create("multiplyByTwoFlowId", this.multiplyByTwoChannel())
            .withBatchingStrategy(
                // even numbers map to 0, odd numbers to 1 (or -1 if negative), so they never share a batch;
                // <Integer> is required because T cannot be inferred from an implicitly typed lambda
                PayloadCorrelationStrategy.<Integer>create(num -> num % 2),
                () -> 2000, // max batch size
                () -> 200L  // batch timeout in ms
            )
            .withKeyedResultBatchProcessor(
                Function.identity(), // each input number is also its key in the result map
                Function.identity(), // keep the collected List<Integer> as List<Integer>
                mathService::multiplyListByTwo
            )
            .build();
    }
    @Bean
    public IntegrationFlow echoFlow() {
        return BatchingFlowBuilder.create("echoFlowId", this.echoChannel())
            .withBatchingStrategy(
                BatchAllCorrelationStrategy.create(),
                () -> 2000, // max batch size
                () -> 200L  // batch timeout in ms
            )
            .withNoResultBatchProcessor(
                Function.identity(), // keep the collected List<Integer> as List<Integer>
                mathService::echo
            )
            .build();
    }
}
In the following sections, I will go through each component used to configure the batching logic above.
CorrelationStrategy
In the MathService, I added a validation in multiplyListByTwo to make sure the input numbers are not a mix of odd and even numbers. This means that the batching layer cannot blindly put all numbers into the same batch. Instead, it must separate odd and even numbers into different batches before calling multiplyListByTwo.
In the Spring Integration world, we tell the batching layer how to put similar requests into the same batch, and dissimilar ones into different batches, by providing a CorrelationStrategy: messages with equal correlation keys end up in the same group.
BatchAllCorrelationStrategy
This strategy treats all requests as similar, so they can all go into the same batch. This CorrelationStrategy is suitable for calling the sum or echo APIs in the MathService.
import org.springframework.integration.aggregator.CorrelationStrategy;
import org.springframework.messaging.Message;
public class BatchAllCorrelationStrategy implements CorrelationStrategy {
    private static final String CORRELATION_KEY = "BATCH_ALL_CORRELATION_KEY";
    public static BatchAllCorrelationStrategy create() {
        return new BatchAllCorrelationStrategy();
    }
    @Override
    public Object getCorrelationKey(Message<?> message) {
        return CORRELATION_KEY;
    }
}
PayloadCorrelationStrategy
This strategy treats requests whose payloads map to the same key as similar. It lets us put odd and even numbers into different batches before calling the multiplyListByTwo API: at runtime, it computes inputNum % 2 and uses the result as the correlation key.
Assuming the batch size threshold is 2000 and 1000 numbers arrive at the batching layer around the same time, if 600 requests map to 1 and 400 map to 0, the batching layer creates 2 separate batches containing 600 and 400 inputs respectively.
import java.util.function.Function;
import lombok.Builder;
import org.springframework.integration.aggregator.CorrelationStrategy;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;
@Builder
public class PayloadCorrelationStrategy<T> implements CorrelationStrategy {
    // Maps a payload to its correlation key; payloads with equal keys are batched together
    private final Function<T, Object> correlationKeyExtractor;
    public static <T> PayloadCorrelationStrategy<T> create(Function<T, Object> correlationKeyExtractor) {
        Assert.notNull(correlationKeyExtractor, "Correlation key extractor is required");
        return PayloadCorrelationStrategy.<T>builder()
            .correlationKeyExtractor(correlationKeyExtractor)
            .build();
    }
    @Override
    @SuppressWarnings("unchecked") // the flow only sends payloads of type T to this strategy
    public Object getCorrelationKey(Message<?> message) {
        return correlationKeyExtractor.apply((T) message.getPayload());
    }
}
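The combined effect of a correlation key and a size threshold can be simulated in plain Java. BatchPartitioner below is a hypothetical helper, not part of Spring Integration: it groups inputs by key in arrival order and cuts each group into batches of at most maxBatchSize. It ignores timeouts, which in the real flow would release any leftover partial batch. For the inputs 0 to 999 used in BatchTestService with the key num % 2, a threshold of 2000 yields 2 batches of 500, while a threshold of 300 yields 4 batches of 300, 200, 300 and 200.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical helper simulating how correlation plus a size threshold split incoming requests into batches. */
final class BatchPartitioner {
    private BatchPartitioner() {
    }

    static <T> List<List<T>> partition(List<T> inputs, Function<T, Object> correlationKey, int maxBatchSize) {
        // 1. Correlation: inputs with the same key go into the same group, keeping arrival order
        Map<Object, List<T>> groups = new LinkedHashMap<>();
        for (T input : inputs) {
            groups.computeIfAbsent(correlationKey.apply(input), key -> new ArrayList<>()).add(input);
        }
        // 2. Release: each group is cut into batches of at most maxBatchSize
        List<List<T>> batches = new ArrayList<>();
        for (List<T> group : groups.values()) {
            for (int from = 0; from < group.size(); from += maxBatchSize) {
                batches.add(new ArrayList<>(group.subList(from, Math.min(from + maxBatchSize, group.size()))));
            }
        }
        return batches;
    }

    public static void main(String[] args) {
        List<Integer> inputs = new ArrayList<>();
        for (int i = 0; i < 1000; i++) {
            inputs.add(i);
        }
        for (List<Integer> batch : partition(inputs, num -> num % 2, 300)) {
            System.out.println("batch of " + batch.size() + " starting at " + batch.get(0));
        }
    }
}
```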
ReleaseStrategy
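As mentioned at the beginning, we want to invoke our batched logic when the batch size reaches a certain threshold. With Spring Integration, we tell the batching layer when to do this by providing a ReleaseStrategy. The timeout half of the rule is handled separately by the aggregator's group timeout, which is configured in BatchingFlowBuilder.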
import java.util.Objects;
import java.util.function.Supplier;
import org.springframework.integration.aggregator.ReleaseStrategy;
import org.springframework.integration.store.MessageGroup;
import org.springframework.util.Assert;
public class DynamicMessageCountReleaseStrategy implements ReleaseStrategy {
    private static final int DEFAULT_BATCH_SIZE = 500;
    private final Supplier<Integer> thresholdSupplier;
    public DynamicMessageCountReleaseStrategy() {
        this(() -> DEFAULT_BATCH_SIZE);
    }
    public DynamicMessageCountReleaseStrategy(Supplier<Integer> thresholdSupplier) {
        Assert.notNull(thresholdSupplier, "Threshold supplier is required");
        this.thresholdSupplier = thresholdSupplier;
    }
    // The supplier is queried on every check, so a changed threshold applies to the next incoming message;
    // a null value falls back to DEFAULT_BATCH_SIZE (Objects.requireNonNullElse requires Java 9+)
    @Override
    public boolean canRelease(MessageGroup group) {
        return group.size() >= Objects.requireNonNullElse(this.thresholdSupplier.get(), DEFAULT_BATCH_SIZE);
    }
}
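Because the threshold is read from a Supplier on every check, it can be changed at runtime without rebuilding the flow. The sketch below isolates that logic in plain Java. DynamicThreshold and its int-based canRelease are hypothetical simplifications of DynamicMessageCountReleaseStrategy that take the group size directly instead of a MessageGroup, and the AtomicInteger stands in for whatever runtime configuration source you actually use.

```java
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

/** Hypothetical, simplified stand-in for DynamicMessageCountReleaseStrategy (takes a size, not a MessageGroup). */
final class DynamicThreshold {
    private static final int DEFAULT_BATCH_SIZE = 500;
    private final Supplier<Integer> thresholdSupplier;

    DynamicThreshold(Supplier<Integer> thresholdSupplier) {
        this.thresholdSupplier = Objects.requireNonNull(thresholdSupplier, "Threshold supplier is required");
    }

    /** Queries the supplier on every call; a null value falls back to DEFAULT_BATCH_SIZE. */
    boolean canRelease(int groupSize) {
        return groupSize >= Objects.requireNonNullElse(thresholdSupplier.get(), DEFAULT_BATCH_SIZE);
    }

    public static void main(String[] args) {
        AtomicInteger configured = new AtomicInteger(3); // stands in for a runtime-tunable setting
        DynamicThreshold threshold = new DynamicThreshold(configured::get);
        System.out.println(threshold.canRelease(2)); // false: 2 < 3
        configured.set(2); // lower the batch size while the application is running
        System.out.println(threshold.canRelease(2)); // true: the new value is used immediately
    }
}
```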
BatchingFlowBuilder
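In the Spring Integration world, the "thing" that represents our batching layer is called an IntegrationFlow. We need a separate IntegrationFlow for each category of batch API described in the MathService section. In BatchingFlowBuilder, withBatchingStrategy configures how requests are grouped and when a group is released; each with...BatchProcessor method calls the batch API once per group and pairs every caller's reply channel with its result (or with the exception); and build() splits those pairs so that each result travels back to its own caller.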
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;
import lombok.AccessLevel;
import lombok.Builder;
import one.util.streamex.StreamEx;
import org.apache.commons.lang3.tuple.Pair; // assumed Pair type (of/getLeft/getRight)
import org.springframework.integration.aggregator.CorrelationStrategy;
import org.springframework.integration.dsl.AggregatorSpec;
import org.springframework.integration.dsl.IntegrationFlow;
import org.springframework.integration.dsl.IntegrationFlows;
import org.springframework.integration.store.MessageGroup;
import org.springframework.integration.store.SimpleMessageStore;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
@Builder(access = AccessLevel.PRIVATE)
public class BatchingFlowBuilder {
    private String id;
    private MessageChannel channel;
    private Supplier<Integer> thresholdSupplier;
    private Supplier<Long> timeoutSupplier;
    // Aggregator configuration accumulated by the with...() methods and applied in build()
    private UnaryOperator<AggregatorSpec> aggregator;
    public static BatchingFlowBuilder create(String id, MessageChannel channel) {
        return BatchingFlowBuilder.builder()
            .id(id)
            .channel(channel)
            .aggregator(aggregator -> aggregator)
            .build();
    }
    // How requests are grouped (correlation) and when a group is released (size or timeout)
    public BatchingFlowBuilder withBatchingStrategy(
        CorrelationStrategy correlationStrategy,
        Supplier<Integer> thresholdSupplier,
        Supplier<Long> timeoutSupplier
    ) {
        this.thresholdSupplier = thresholdSupplier;
        this.timeoutSupplier = timeoutSupplier;
        UnaryOperator<AggregatorSpec> currentAggregator = this.aggregator;
        this.aggregator = aggregator -> currentAggregator.apply(aggregator)
            .correlationStrategy(correlationStrategy)
            .releaseStrategy(new DynamicMessageCountReleaseStrategy(thresholdSupplier)) // Max batch size
            .groupTimeout(group -> timeoutSupplier.get()); // Timeout (ms) for releasing a partial batch
        return this;
    }
    // Batch API returning a dedicated result per element (e.g. multiplyListByTwo)
    public <P, K, T, R> BatchingFlowBuilder withKeyedResultBatchProcessor(
        Function<P, K> payloadKeyExtractor,
        Function<List<P>, T> payloadBatcher,
        Function<T, Map<K, R>> batchProcessor
    ) {
        UnaryOperator<AggregatorSpec> currentAggregator = this.aggregator;
        this.aggregator = aggregator -> currentAggregator.apply(aggregator)
            .outputProcessor(group -> this.doExecuteBatchWithKeyedResult(group, payloadKeyExtractor, payloadBatcher, batchProcessor));
        return this;
    }
    @SuppressWarnings("unchecked")
    <P, K, T, R> Message<?> doExecuteBatchWithKeyedResult(
        MessageGroup messageGroup,
        Function<P, K> payloadKeyExtractor,
        Function<List<P>, T> payloadBatcher,
        Function<T, Map<K, R>> batchProcessor
    ) {
        try {
            // Call the batch API once, with duplicate keys removed
            T batchedPayload = StreamEx.of(messageGroup.getMessages())
                .map(message -> (P) message.getPayload())
                .distinct(payloadKeyExtractor)
                .toListAndThen(payloadBatcher);
            Map<K, R> results = batchProcessor.apply(batchedPayload);
            // Pair each caller's reply channel with the result for its own key
            List<Pair<Object, R>> payload = StreamEx.of(messageGroup.getMessages())
                .map(message ->
                    Pair.of(
                        message.getHeaders().getReplyChannel(),
                        results.get(payloadKeyExtractor.apply((P) message.getPayload()))
                    )
                ).toList();
            return MessageBuilder.withPayload(payload).build();
        } catch (Exception ex) {
            // The whole batch failed: every caller receives the same exception
            List<Pair<Object, Exception>> payload = StreamEx.of(messageGroup.getMessages())
                .map(message -> Pair.of(message.getHeaders().getReplyChannel(), ex))
                .toList();
            return MessageBuilder.withPayload(payload).build();
        }
    }
    // Batch API returning the same result for all elements (e.g. sum)
    public <P, T, R> BatchingFlowBuilder withIdenticalResultBatchProcessor(
        Function<List<P>, T> payloadBatcher,
        Function<T, R> batchProcessor
    ) {
        UnaryOperator<AggregatorSpec> currentAggregator = this.aggregator;
        this.aggregator = aggregator -> currentAggregator.apply(aggregator)
            .outputProcessor(group -> this.doExecuteBatchWithIdenticalResult(group, payloadBatcher, batchProcessor));
        return this;
    }
    @SuppressWarnings("unchecked")
    <P, T, R> Message<?> doExecuteBatchWithIdenticalResult(
        MessageGroup messageGroup,
        Function<List<P>, T> payloadBatcher,
        Function<T, R> batchProcessor
    ) {
        try {
            T batchedPayload = StreamEx.of(messageGroup.getMessages())
                .map(message -> (P) message.getPayload())
                .toListAndThen(payloadBatcher);
            R result = batchProcessor.apply(batchedPayload);
            // Every caller's reply channel gets the same result
            List<Pair<Object, R>> payload = StreamEx.of(messageGroup.getMessages())
                .map(message -> Pair.of(message.getHeaders().getReplyChannel(), result))
                .toList();
            return MessageBuilder.withPayload(payload).build();
        } catch (Exception ex) {
            List<Pair<Object, Exception>> payload = StreamEx.of(messageGroup.getMessages())
                .map(message -> Pair.of(message.getHeaders().getReplyChannel(), ex))
                .toList();
            return MessageBuilder.withPayload(payload).build();
        }
    }
    // Batch API returning nothing (e.g. echo)
    public <P, T> BatchingFlowBuilder withNoResultBatchProcessor(
        Function<List<P>, T> payloadBatcher,
        Consumer<T> batchProcessor
    ) {
        UnaryOperator<AggregatorSpec> currentAggregator = this.aggregator;
        this.aggregator = aggregator -> currentAggregator.apply(aggregator)
            .outputProcessor(group -> this.doExecuteBatchWithNoResult(group, payloadBatcher, batchProcessor));
        return this;
    }
    @SuppressWarnings("unchecked")
    <P, T> Message<?> doExecuteBatchWithNoResult(
        MessageGroup messageGroup,
        Function<List<P>, T> payloadBatcher,
        Consumer<T> batchProcessor
    ) {
        try {
            T batchedPayload = StreamEx.of(messageGroup.getMessages())
                .map(message -> (P) message.getPayload())
                .toListAndThen(payloadBatcher);
            batchProcessor.accept(batchedPayload);
            // Only completion matters: pair every caller's reply channel with an empty (null) result
            List<Pair<Object, Object>> payload = StreamEx.of(messageGroup.getMessages())
                .map(message -> Pair.of(message.getHeaders().getReplyChannel(), null))
                .toList();
            return MessageBuilder.withPayload(payload).build();
        } catch (Exception ex) {
            List<Pair<Object, Exception>> payload = StreamEx.of(messageGroup.getMessages())
                .map(message -> Pair.of(message.getHeaders().getReplyChannel(), ex))
                .toList();
            return MessageBuilder.withPayload(payload).build();
        }
    }
    public IntegrationFlow build() {
        return IntegrationFlows.from(this.channel)
            .aggregate(
                aggregator -> this.aggregator.apply(aggregator)
                    .messageStore(new SimpleMessageStore())
                    .expireGroupsUponCompletion(true)
                    .expireGroupsUponTimeout(true)
                    .sendPartialResultOnExpiry(true) // a timed-out, partially filled batch is still processed
            )
            // One message per caller: Pair(replyChannel, result or exception)
            .split()
            .handle(Pair.class, (payload, headers) ->
                MessageBuilder.withPayload(payload.getRight())
                    .copyHeaders(headers)
                    .setReplyChannel((MessageChannel) payload.getLeft())
                    .setErrorChannel((MessageChannel) payload.getLeft())
                    .build()
            )
            // The flow ends here, so the reply goes to each caller's own reply channel set above
            .handle((payload, headers) -> payload)
            .get();
    }
}
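The three with...BatchProcessor methods differ only in how one batch result is turned into one answer per caller. The hypothetical ResultFanOut helper below mirrors that step in plain Java, without messages or reply channels, and assumes (as in BatchedMathServiceConfigs) that each input is its own key for keyed results. Each returned list has one entry per caller, in arrival order; in the real flow, the catch blocks replace every entry with the same exception instead.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;

/** Hypothetical helper: turns one released batch into one answer per caller (same order as the inputs). */
final class ResultFanOut {
    private ResultFanOut() {
    }

    /** Identical result: one call, and every caller gets the same value (like sum). */
    static <P, R> List<R> identical(List<P> inputs, Function<List<P>, R> batchApi) {
        R result = batchApi.apply(inputs);
        return Collections.nCopies(inputs.size(), result);
    }

    /** Keyed result: one call with de-duplicated inputs, then each caller looks up its own key (like multiplyListByTwo). */
    static <P, R> List<R> keyed(List<P> inputs, Function<List<P>, Map<P, R>> batchApi) {
        List<P> distinctInputs = new ArrayList<>(new LinkedHashSet<>(inputs)); // keeps first-seen order
        Map<P, R> results = batchApi.apply(distinctInputs);
        List<R> perCaller = new ArrayList<>(inputs.size());
        for (P input : inputs) {
            perCaller.add(results.get(input)); // callers with equal inputs share the same result
        }
        return perCaller;
    }

    /** No result: one call, and each caller only learns that its batch finished (modelled as null). */
    static <P> List<Object> none(List<P> inputs, Consumer<List<P>> batchApi) {
        batchApi.accept(inputs);
        return Collections.nCopies(inputs.size(), null);
    }
}
```

The de-duplication in keyed is what .distinct(payloadKeyExtractor) does in doExecuteBatchWithKeyedResult, and it matters for multiplyListByTwo because StreamEx's toMap rejects duplicate keys.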
Final words
All of the components needed to build the in-memory batching layer above are fully functional and are being used in production at my company. It should be easy for you to adapt the sample code to the actual use cases in your projects.
Written by James Tran