I'm working on a simple micro-benchmarking library in Java.
This library makes it easy to benchmark multiple reference implementations. You provide the inputs and trigger the calls; the library takes care of executing them N times, measuring the runs, and printing the averaged results.
Essentially, it works like this:

1. Create a dedicated class for the benchmark
2. Prepare the input data in the constructor
3. Create one method per reference implementation and annotate it with `@MeasureTime`
4. Add a `main` method that triggers the benchmark runner with an instance of this class as parameter
5. Run the class and find the results on `stdout`
To set the number of warm-up iterations and measured iterations, use the `@Benchmark` annotation on the class, for example `@Benchmark(iterations = 10, warmUpIterations = 5)`.
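Putting these steps together with the `@Benchmark` override, a minimal benchmark class would look roughly like the following sketch. The class name, method names, and workloads here are made up for illustration; a full, real demo follows further below.

```java
import microbench.api.annotation.Benchmark;
import microbench.api.annotation.MeasureTime;

import java.util.Arrays;
import java.util.Random;

// 10 measured iterations and 5 warm-up iterations for every @MeasureTime method
@Benchmark(iterations = 10, warmUpIterations = 5)
public class SortingStrategiesBenchmark {

    private final int[] input;

    public SortingStrategiesBenchmark() {
        // step 2: prepare the input data in the constructor
        input = new Random(42).ints(100_000).toArray();
    }

    // step 3: one annotated method per reference implementation
    @MeasureTime
    public void arraysSort() {
        Arrays.sort(input.clone());
    }

    @MeasureTime
    public void streamSorted() {
        Arrays.stream(input).sorted().toArray();
    }

    // step 4: trigger the runner with an instance of this class
    public static void main(String[] args) {
        BenchmarkRunner.run(new SortingStrategiesBenchmark());
    }
}
```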
The annotations:
```java
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;

@Retention(RetentionPolicy.RUNTIME)
public @interface MeasureTime {
    int[] iterations() default {};
    int[] warmUpIterations() default {};
}

@Retention(RetentionPolicy.RUNTIME)
public @interface Benchmark {
    int iterations() default BenchmarkRunner.DEFAULT_ITERATIONS;
    int warmUpIterations() default BenchmarkRunner.DEFAULT_WARM_UP_ITERATIONS;
}
```

The class that runs the benchmarks on a target object passed in, parameterized by the annotations:
```java
import microbench.api.annotation.Benchmark;
import microbench.api.annotation.MeasureTime;
import microbench.api.annotation.Prepare;
import microbench.api.annotation.Validate;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;

public class BenchmarkRunner {

    public static final int DEFAULT_ITERATIONS = 1;
    public static final int DEFAULT_WARM_UP_ITERATIONS = 0;

    private final Object target;
    private final int defaultIterations;
    private final int defaultWarmUpIterations;
    private final List<Method> measureTimeMethods = new ArrayList<>();
    private final List<Method> prepareMethods = new ArrayList<>();
    private final List<Method> validateMethods = new ArrayList<>();

    public BenchmarkRunner(Object target) {
        this.target = target;
        Class<?> clazz = target.getClass();
        Benchmark annotation = clazz.getAnnotation(Benchmark.class);
        if (annotation != null) {
            defaultIterations = annotation.iterations();
            defaultWarmUpIterations = annotation.warmUpIterations();
        } else {
            defaultIterations = DEFAULT_ITERATIONS;
            defaultWarmUpIterations = DEFAULT_WARM_UP_ITERATIONS;
        }
        for (Method method : clazz.getDeclaredMethods()) {
            if (method.getAnnotation(MeasureTime.class) != null) {
                measureTimeMethods.add(method);
            } else if (method.getAnnotation(Prepare.class) != null) {
                prepareMethods.add(method);
            } else if (method.getAnnotation(Validate.class) != null) {
                validateMethods.add(method);
            }
        }
        Collections.sort(measureTimeMethods, (o1, o2) -> o1.getName().compareTo(o2.getName()));
    }

    public static void run(Object target) {
        new BenchmarkRunner(target).run();
    }

    public void run() {
        runQuietly();
    }

    private void runQuietly() {
        try {
            runNormally();
        } catch (InvocationTargetException | IllegalAccessException e) {
            e.printStackTrace();
        }
    }

    private void runNormally() throws InvocationTargetException, IllegalAccessException {
        Map<Method, Throwable> validationFailures = new LinkedHashMap<>();
        for (Method method : measureTimeMethods) {
            MeasureTime measureTime = method.getAnnotation(MeasureTime.class);
            if (measureTime != null) {
                try {
                    runMeasureTime(target, method, measureTime);
                } catch (InvocationTargetException e) {
                    Throwable cause = e.getCause();
                    if (cause instanceof AssertionError) {
                        validationFailures.put(method, cause);
                        printExecutionFailure(method);
                    } else {
                        throw e;
                    }
                }
            }
        }
        if (!validationFailures.isEmpty()) {
            System.out.println();
            for (Map.Entry<Method, Throwable> entry : validationFailures.entrySet()) {
                System.out.print("Validation failed while executing " + entry.getKey().getName() + ": ");
                System.out.println(entry.getValue());
            }
        }
    }

    private void invokeMethods(Object instance, List<Method> methods) throws InvocationTargetException, IllegalAccessException {
        for (Method method : methods) {
            method.invoke(instance);
        }
    }

    private void runMeasureTime(Object instance, Method method, MeasureTime measureTime) throws InvocationTargetException, IllegalAccessException {
        for (int i = 0; i < getWarmUpIterations(measureTime); ++i) {
            invokeMethods(instance, prepareMethods);
            method.invoke(instance);
            invokeMethods(instance, validateMethods);
        }
        int iterations = getIterations(measureTime);
        long sumDiffs = 0;
        for (int i = 0; i < iterations; ++i) {
            invokeMethods(instance, prepareMethods);
            long start = System.nanoTime();
            method.invoke(instance);
            sumDiffs += System.nanoTime() - start;
            invokeMethods(instance, validateMethods);
        }
        printExecutionResult(method, sumDiffs / iterations);
    }

    private void printExecutionInfo(String message, String ms) {
        System.out.println(String.format("%-60s: %10s ms", message, ms));
    }

    private void printExecutionFailure(Method method) {
        printExecutionInfo("Validation failed while executing " + method.getName(), "-");
    }

    private void printExecutionResult(Method method, long nanoSeconds) {
        printExecutionInfo("Average execution time of " + method.getName(), "" + nanoSeconds / 1_000_000);
    }

    private int getParamValue(int[] values, int defaultValue) {
        if (values.length > 0) {
            return values[0];
        }
        return defaultValue;
    }

    private int getWarmUpIterations(MeasureTime measureTime) {
        return getParamValue(measureTime.warmUpIterations(), defaultWarmUpIterations);
    }

    private int getIterations(MeasureTime measureTime) {
        return getParamValue(measureTime.iterations(), defaultIterations);
    }
}
```

An example benchmark class:
```java
import microbench.api.annotation.MeasureTime;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class SimpleSortingDemo {

    private List<Integer> shuffledList;

    public SimpleSortingDemo() {
        shuffledList = new ArrayList<>();
        for (int i = 0; i < 10000; ++i) {
            shuffledList.add(i);
        }
        Collections.shuffle(shuffledList);
    }

    public static void main(String[] args) {
        new BenchmarkRunner(new SimpleSortingDemo()).run();
    }

    @MeasureTime
    public void bubbleSort() {
        BubbleSort.sort(new ArrayList<Integer>(shuffledList));
    }

    @MeasureTime
    public void insertionSort() {
        InsertionSort.sort(new ArrayList<Integer>(shuffledList));
    }
}
```

If you want to test drive it in your own projects, the GitHub project page nicely explains the steps to get started.
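The demo doesn't use them, but the runner also picks up `@Prepare` and `@Validate` methods and invokes them before and after every measured (and warm-up) call; an `AssertionError` thrown from a validate method is reported as a validation failure. Here is a hypothetical sketch of how that could look, assuming `@Prepare` and `@Validate` are plain runtime method annotations like `@MeasureTime`; the class, fields, and assertion are made up:

```java
import microbench.api.annotation.MeasureTime;
import microbench.api.annotation.Prepare;
import microbench.api.annotation.Validate;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class ValidatedSortingDemo {

    private List<Integer> shuffledList;   // rebuilt before every measured run
    private List<Integer> sorted;         // produced by the measured method, checked afterwards

    @Prepare
    public void prepare() {
        shuffledList = new ArrayList<>();
        for (int i = 0; i < 10_000; ++i) {
            shuffledList.add(i);
        }
        Collections.shuffle(shuffledList);
    }

    @MeasureTime
    public void jdkSort() {
        sorted = new ArrayList<>(shuffledList);
        Collections.sort(sorted);
    }

    @Validate
    public void validate() {
        // the runner reports this AssertionError as a validation failure
        List<Integer> expected = new ArrayList<>(shuffledList);
        Collections.sort(expected);
        if (!expected.equals(sorted)) {
            throw new AssertionError("result is not sorted correctly");
        }
    }

    public static void main(String[] args) {
        BenchmarkRunner.run(new ValidatedSortingDemo());
    }
}
```

With this shape, an implementation that corrupts its output shows up in the report as "Validation failed while executing jdkSort" instead of a timing.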
I'd like a review of everything, but here are some points you might want to pick on:
- What would you do differently?
- Is there a way to make the library easier to use?
- Is the implementation of `BenchmarkRunner` clear and natural?
- Is the way it measures the execution time adequate?
- Are the annotation names intuitive and natural? (If not, can you suggest better names?)
- The `@MeasureTime` annotation declares `iterations` (and `warmUpIterations`) as `int[]`, which is sort of a dirty hack I use to treat it as "not set" by default, falling back to `@Benchmark.iterations` or the global default. Is there a cleaner way to do this? (See the snippet right after this list for how an override looks in practice.)
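For context, here is a sketch of how a per-method override currently looks with the `int[]` trick; the class, methods, and numbers are made up. The single value compiles because Java allows the single-element array shorthand for array-typed annotation members, and `getParamValue()` only ever reads `values[0]`:

```java
import microbench.api.annotation.MeasureTime;

public class OverrideDemo {

    // both arrays are empty by default, so this method falls back to
    // @Benchmark on the class or the global defaults
    @MeasureTime
    public void usesDefaults() {
        // ... hypothetical workload ...
    }

    // the single-element array shorthand stands in for a plain int value
    @MeasureTime(iterations = 50, warmUpIterations = 10)
    public void usesOverrides() {
        // ... hypothetical workload ...
    }
}
```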