diff --git a/randomized-runner/src/main/java/com/carrotsearch/randomizedtesting/rules/StaticFieldsInvariantRule.java b/randomized-runner/src/main/java/com/carrotsearch/randomizedtesting/rules/StaticFieldsInvariantRule.java new file mode 100644 index 0000000..1b4c2c3 --- /dev/null +++ b/randomized-runner/src/main/java/com/carrotsearch/randomizedtesting/rules/StaticFieldsInvariantRule.java @@ -0,0 +1,110 @@ +package com.carrotsearch.randomizedtesting.rules; + +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; + +import junit.framework.AssertionFailedError; + +import org.junit.ClassRule; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; + +/** + * A {@link TestRule} that ensures static, reference fields of the suite class + * (and optionally its superclasses) are cleaned up after a suite is completed. + * This is helpful in finding out static memory leaks (a class references + * something huge but is no longer used). + * + * @see ClassRule + * @see #accept(Field) + */ +public class StaticFieldsInvariantRule implements TestRule { + public static final long DEFAULT_LEAK_THRESHOLD = 10 * 1024 * 1024; + + private final long leakThreshold; + private final boolean countSuperclasses; + + /** + * By default use {@link #DEFAULT_LEAK_THRESHOLD} as the threshold and count + * in superclasses. + */ + public StaticFieldsInvariantRule() { + this(DEFAULT_LEAK_THRESHOLD, true); + } + + public StaticFieldsInvariantRule(long leakThresholdBytes, boolean countSuperclasses) { + this.leakThreshold = leakThresholdBytes; + this.countSuperclasses = countSuperclasses; + } + + static class Entry implements Comparable { + final long ramUsed; + final Field field; + + public Entry(Field field, long ramUsed) { + this.field = field; + this.ramUsed = ramUsed; + } + + @Override + public int compareTo(Entry o) { + if (this.ramUsed > o.ramUsed) return -1; + if (this.ramUsed < o.ramUsed) return 1; + return this.field.toString().compareTo(o.field.toString()); + } + } + + @Override + public Statement apply(final Statement s, final Description d) { + return new StatementAdapter(s) { + @Override + protected void afterAlways(List errors) throws Throwable { + ArrayList fields = new ArrayList(); + long ramEnd = 0; + for (Class c = d.getTestClass(); countSuperclasses && c.getSuperclass() != null; c = c.getSuperclass()) { + for (Field field : c.getDeclaredFields()) { + if (Modifier.isStatic(field.getModifiers()) && + !field.getType().isPrimitive() && + accept(field)) { + field.setAccessible(true); + final long fieldRam = RamUsageEstimator.sizeOf(field.get(null)); + if (fieldRam > 0) { + fields.add(new Entry(field, fieldRam)); + ramEnd += fieldRam; + } + } + } + } + + if (ramEnd > leakThreshold) { + Collections.sort(fields); + + StringBuilder b = new StringBuilder(); + b.append(String.format(Locale.ENGLISH, "Clean up static fields (in @AfterClass?), " + + "your test seems to hang on to approximately %,d bytes (threshold is %,d):", + ramEnd, leakThreshold)); + + for (Entry e : fields) { + b.append(String.format(Locale.ENGLISH, "\n - %,d bytes, %s", e.ramUsed, + e.field.toString())); + } + + errors.add(new AssertionFailedError(b.toString())); + } + } + }; + } + + /** + * @return Return false to exclude a given field from being + * counted. By default final fields are rejected. + */ + protected boolean accept(Field field) { + return !Modifier.isFinal(field.getModifiers()); + } +} \ No newline at end of file diff --git a/randomized-runner/src/test/java/com/carrotsearch/randomizedtesting/rules/TestStaticFieldsInvariantRule.java b/randomized-runner/src/test/java/com/carrotsearch/randomizedtesting/rules/TestStaticFieldsInvariantRule.java new file mode 100644 index 0000000..2ed4caa --- /dev/null +++ b/randomized-runner/src/test/java/com/carrotsearch/randomizedtesting/rules/TestStaticFieldsInvariantRule.java @@ -0,0 +1,87 @@ +package com.carrotsearch.randomizedtesting.rules; + +import org.fest.assertions.api.Assertions; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.RuleChain; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runner.JUnitCore; +import org.junit.runner.Result; +import org.junit.runners.model.Statement; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.WithNestedTestClass; + +public class TestStaticFieldsInvariantRule extends WithNestedTestClass { + static int LEAK_THRESHOLD = 5 * 1024 * 1024; + + public static class Base extends RandomizedTest { + private static TestRule assumeNotNestedRule = new TestRule() { + public Statement apply(final Statement base, Description description) { + return new Statement() { + public void evaluate() throws Throwable { + assumeRunningNested(); + base.evaluate(); + } + }; + } + }; + + @ClassRule + public static TestRule classRules = + RuleChain + .outerRule(assumeNotNestedRule) + .around(new StaticFieldsInvariantRule(LEAK_THRESHOLD, true)); + + @Test + public void testEmpty() {} + } + + public static class Smaller extends Base { + + static byte [] field0; + + @SuppressWarnings("unused") + @BeforeClass + private static void setup() { + field0 = new byte [LEAK_THRESHOLD / 2]; + } + } + + public static class Exceeding extends Smaller { + static byte [] field1; + static byte [] field2; + static int [] field3; + static long field4; + final static long [] field5 = new long [1024]; + + @SuppressWarnings("unused") + @BeforeClass + private static void setup() { + field1 = new byte [LEAK_THRESHOLD / 2]; + field2 = new byte [100]; + field3 = new int [100]; + } + } + + @Test + public void testPassingUnderThreshold() { + Result runClasses = JUnitCore.runClasses(Smaller.class); + Assertions.assertThat(runClasses.getFailures()).isEmpty(); + } + + @Test + public void testFailingAboveThreshold() { + Result runClasses = JUnitCore.runClasses(Exceeding.class); + Assertions.assertThat(runClasses.getFailures()).hasSize(1); + + Assertions.assertThat(runClasses.getFailures().get(0).getTrace()) + .contains(".field0") + .contains(".field1") + .contains(".field2") + .contains(".field3") + .doesNotContain(".field5"); + } +}