diff --git a/pom.xml b/pom.xml
index 039c29d..cbba88c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -21,7 +21,7 @@
3.1.0
1.7.0
- 5.7.1
+ 5.9.0
3.19.0
3.10.0
2.14.1
diff --git a/src/main/java/com/github/fridujo/classpath/junit/extension/jupiter/AbstractClasspathExtension.java b/src/main/java/com/github/fridujo/classpath/junit/extension/jupiter/AbstractClasspathExtension.java
index b638bd2..8f1b5b4 100644
--- a/src/main/java/com/github/fridujo/classpath/junit/extension/jupiter/AbstractClasspathExtension.java
+++ b/src/main/java/com/github/fridujo/classpath/junit/extension/jupiter/AbstractClasspathExtension.java
@@ -1,27 +1,31 @@
package com.github.fridujo.classpath.junit.extension.jupiter;
-import java.io.File;
-import java.lang.reflect.Method;
-import java.util.Optional;
-import java.util.stream.Collectors;
-
+import com.github.fridujo.classpath.junit.extension.Classpath;
+import com.github.fridujo.classpath.junit.extension.PathElement;
+import com.github.fridujo.classpath.junit.extension.buildtool.BuildTool;
+import org.junit.jupiter.api.extension.ExecutableInvoker;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.InvocationInterceptor;
import org.junit.jupiter.api.extension.ReflectiveInvocationContext;
+import org.junit.platform.commons.logging.Logger;
+import org.junit.platform.commons.logging.LoggerFactory;
import org.junit.platform.commons.util.ReflectionUtils;
-import com.github.fridujo.classpath.junit.extension.Classpath;
-import com.github.fridujo.classpath.junit.extension.PathElement;
-import com.github.fridujo.classpath.junit.extension.buildtool.BuildTool;
+import java.io.File;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Executable;
+import java.lang.reflect.Method;
+import java.util.List;
+import java.util.stream.Collectors;
abstract class AbstractClasspathExtension implements InvocationInterceptor {
- private final ExtensionContext.Namespace namespace = ExtensionContext.Namespace.create(AbstractClasspathExtension.class);
+ private static final Logger logger = LoggerFactory.getLogger(AbstractClasspathExtension.class);
@Override
public void interceptTestTemplateMethod(Invocation invocation,
ReflectiveInvocationContext invocationContext,
- ExtensionContext extensionContext) throws Throwable {
+ ExtensionContext extensionContext) {
intercept(invocation, invocationContext, extensionContext);
}
@@ -39,12 +43,12 @@ private void intercept(Invocation invocation, ReflectiveInvocationContext<
Classpath classpath = supplyClasspath(invocationContext, extensionContext, buildTool);
BuildToolLocator.store(extensionContext, classpath.buildTool);
- invokeMethodWithModifiedClasspath(invocationContext, classpath);
+ invokeMethodWithModifiedClasspath(invocationContext, classpath, extensionContext.getExecutableInvoker());
}
protected abstract Classpath supplyClasspath(ReflectiveInvocationContext invocationContext, ExtensionContext extensionContext, BuildTool buildTool);
- private void invokeMethodWithModifiedClasspath(ReflectiveInvocationContext invocationContext, Classpath classpath) {
+ private void invokeMethodWithModifiedClasspath(ReflectiveInvocationContext invocationContext, Classpath classpath, ExecutableInvoker executableInvoker) {
ClassLoader modifiedClassLoader = classpath.newClassLoader();
ClassLoader currentThreadPreviousClassLoader = replaceCurrentThreadClassLoader(modifiedClassLoader);
@@ -52,9 +56,9 @@ private void invokeMethodWithModifiedClasspath(ReflectiveInvocationContext testClass;
try {
testClass = classLoader.loadClass(className);
@@ -81,10 +86,19 @@ private void invokeMethodWithModifiedClasspath(String className, String methodNa
throw new IllegalStateException("Cannot load test class [" + className + "] from modified classloader, verify that you did not exclude a path containing the test", e);
}
- Object testInstance = ReflectionUtils.newInstance(testClass);
- final Optional method = ReflectionUtils.findMethod(testClass, methodName);
- ReflectionUtils.invokeMethod(
- method.orElseThrow(() -> new IllegalStateException("No test method named " + methodName)),
- testInstance);
+ Constructor> constructor = testClass.getDeclaredConstructors()[0];
+ Object testInstance = executableInvoker.invoke(constructor);
+
+ String methodName = originalExecutable.getName();
+ int parameterCount = originalExecutable.getParameterCount();
+ List matchingMethods = ReflectionUtils.findMethods(testClass,
+ m -> m.getName().equals(methodName) && m.getParameterCount() == parameterCount);
+ if (matchingMethods.size() == 0) {
+ throw new IllegalStateException("No test method named " + methodName);
+ } else if (matchingMethods.size() > 1) {
+ logger.warn(() -> "Multiple test methods with name " + methodName + " and " + parameterCount + " parameters");
+ }
+
+ executableInvoker.invoke(matchingMethods.get(0), testInstance);
}
}
diff --git a/src/test/java/com/github/fridujo/classpath/junit/extension/OtherExtensionsCompatibilityTests.java b/src/test/java/com/github/fridujo/classpath/junit/extension/OtherExtensionsCompatibilityTests.java
new file mode 100644
index 0000000..caade23
--- /dev/null
+++ b/src/test/java/com/github/fridujo/classpath/junit/extension/OtherExtensionsCompatibilityTests.java
@@ -0,0 +1,106 @@
+package com.github.fridujo.classpath.junit.extension;
+
+import com.github.fridujo.classpath.junit.extension.jupiter.ModifiedClasspath;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.*;
+import org.junit.platform.commons.util.ReflectionUtils;
+
+import java.lang.annotation.*;
+import java.lang.reflect.Method;
+import java.util.Arrays;
+import java.util.Optional;
+import java.util.UUID;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatNoException;
+
+class OtherExtensionsCompatibilityTests {
+
+ private final String constructorInjectedParameter;
+
+ private final AtomicInteger mutatedByExtensionBeforeTest = new AtomicInteger(-2);
+
+ OtherExtensionsCompatibilityTests(@InjectStringWithRandomUUID String constructorInjectedParameter) {
+ this.constructorInjectedParameter = constructorInjectedParameter;
+ }
+
+ @Test
+ @ModifiedClasspath(addDependencies = "ch.qos.logback:logback-classic:1.2.3")
+ void parameter_resolver_extensions_are_triggered(@InjectStringWithRandomUUID String methodInjectedParameter) throws ClassNotFoundException {
+ assertThat(Class.forName("ch.qos.logback.core.Appender")).isInterface();
+ assertThatNoException().isThrownBy(() -> UUID.fromString(constructorInjectedParameter));
+ assertThatNoException().isThrownBy(() -> UUID.fromString(methodInjectedParameter));
+ }
+
+ @Test
+ @ModifiedClasspath(addDependencies = "ch.qos.logback:logback-classic:1.2.3")
+ @SetAtomicIntegerFieldTo(45)
+ @Disabled("")
+ void lifecycle_extensions_are_triggered() {
+ assertThat(mutatedByExtensionBeforeTest).hasValue(45);
+ }
+
+ @Target(ElementType.PARAMETER)
+ @Retention(RetentionPolicy.RUNTIME)
+ @Documented
+ @ExtendWith(InjectStringWithRandomUUID.InjectStringWithRandomUUIDExtension.class)
+ @interface InjectStringWithRandomUUID {
+ class InjectStringWithRandomUUIDExtension implements ParameterResolver {
+
+ @Override
+ public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException {
+ boolean match = Arrays.stream(parameterContext.getParameter().getAnnotations())
+ .anyMatch(a -> {
+ boolean equals = a.annotationType().getName().equals(InjectStringWithRandomUUID.class.getName());
+ return equals;
+ });
+ return match;
+ }
+
+ @Override
+ public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException {
+ return UUID.randomUUID().toString();
+ }
+ }
+ }
+
+ /**
+ * {@link ExecutableInvoker#invoke(Method, Object)}
+ */
+ @Target(ElementType.METHOD)
+ @Retention(RetentionPolicy.RUNTIME)
+ @Documented
+ @ExtendWith(SetAtomicIntegerFieldTo.SetAtomicIntegerToExtension.class)
+ @interface SetAtomicIntegerFieldTo {
+
+ int value();
+
+ class SetAtomicIntegerToExtension implements BeforeEachCallback {
+
+ @Override
+ public void beforeEach(ExtensionContext context) {
+ Method method = context.getTestMethod().get();
+ Arrays.stream(method.getAnnotations())
+ .filter(a -> a.annotationType().getName().equals(SetAtomicIntegerFieldTo.class.getName()))
+ .findFirst()
+ .ifPresent(a -> {
+ Method valueMethod = ReflectionUtils.findMethod(a.getClass(), "value").get();
+ int value = (int) ReflectionUtils.invokeMethod(valueMethod, a);
+
+ Class> testClass = context.getTestClass().get();
+ Object testInstance = context.getTestInstance().get();
+ Arrays.stream(testClass.getDeclaredFields())
+ .filter(f -> f.getType() == AtomicInteger.class)
+ .map(ReflectionUtils::makeAccessible)
+ .map(f -> ReflectionUtils.tryToReadFieldValue(f, testInstance).toOptional())
+ .filter(v -> v.isPresent())
+ .map(Optional::get)
+ .map(AtomicInteger.class::cast)
+ .forEach(ai -> ai.set(value));
+ });
+ }
+ }
+ }
+}