diff --git a/pom.xml b/pom.xml index 039c29d..cbba88c 100644 --- a/pom.xml +++ b/pom.xml @@ -21,7 +21,7 @@ 3.1.0 1.7.0 - 5.7.1 + 5.9.0 3.19.0 3.10.0 2.14.1 diff --git a/src/main/java/com/github/fridujo/classpath/junit/extension/jupiter/AbstractClasspathExtension.java b/src/main/java/com/github/fridujo/classpath/junit/extension/jupiter/AbstractClasspathExtension.java index b638bd2..8f1b5b4 100644 --- a/src/main/java/com/github/fridujo/classpath/junit/extension/jupiter/AbstractClasspathExtension.java +++ b/src/main/java/com/github/fridujo/classpath/junit/extension/jupiter/AbstractClasspathExtension.java @@ -1,27 +1,31 @@ package com.github.fridujo.classpath.junit.extension.jupiter; -import java.io.File; -import java.lang.reflect.Method; -import java.util.Optional; -import java.util.stream.Collectors; - +import com.github.fridujo.classpath.junit.extension.Classpath; +import com.github.fridujo.classpath.junit.extension.PathElement; +import com.github.fridujo.classpath.junit.extension.buildtool.BuildTool; +import org.junit.jupiter.api.extension.ExecutableInvoker; import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.api.extension.InvocationInterceptor; import org.junit.jupiter.api.extension.ReflectiveInvocationContext; +import org.junit.platform.commons.logging.Logger; +import org.junit.platform.commons.logging.LoggerFactory; import org.junit.platform.commons.util.ReflectionUtils; -import com.github.fridujo.classpath.junit.extension.Classpath; -import com.github.fridujo.classpath.junit.extension.PathElement; -import com.github.fridujo.classpath.junit.extension.buildtool.BuildTool; +import java.io.File; +import java.lang.reflect.Constructor; +import java.lang.reflect.Executable; +import java.lang.reflect.Method; +import java.util.List; +import java.util.stream.Collectors; abstract class AbstractClasspathExtension implements InvocationInterceptor { - private final ExtensionContext.Namespace namespace = ExtensionContext.Namespace.create(AbstractClasspathExtension.class); + private static final Logger logger = LoggerFactory.getLogger(AbstractClasspathExtension.class); @Override public void interceptTestTemplateMethod(Invocation invocation, ReflectiveInvocationContext invocationContext, - ExtensionContext extensionContext) throws Throwable { + ExtensionContext extensionContext) { intercept(invocation, invocationContext, extensionContext); } @@ -39,12 +43,12 @@ private void intercept(Invocation invocation, ReflectiveInvocationContext< Classpath classpath = supplyClasspath(invocationContext, extensionContext, buildTool); BuildToolLocator.store(extensionContext, classpath.buildTool); - invokeMethodWithModifiedClasspath(invocationContext, classpath); + invokeMethodWithModifiedClasspath(invocationContext, classpath, extensionContext.getExecutableInvoker()); } protected abstract Classpath supplyClasspath(ReflectiveInvocationContext invocationContext, ExtensionContext extensionContext, BuildTool buildTool); - private void invokeMethodWithModifiedClasspath(ReflectiveInvocationContext invocationContext, Classpath classpath) { + private void invokeMethodWithModifiedClasspath(ReflectiveInvocationContext invocationContext, Classpath classpath, ExecutableInvoker executableInvoker) { ClassLoader modifiedClassLoader = classpath.newClassLoader(); ClassLoader currentThreadPreviousClassLoader = replaceCurrentThreadClassLoader(modifiedClassLoader); @@ -52,9 +56,9 @@ private void invokeMethodWithModifiedClasspath(ReflectiveInvocationContext testClass; try { testClass = classLoader.loadClass(className); @@ -81,10 +86,19 @@ private void invokeMethodWithModifiedClasspath(String className, String methodNa throw new IllegalStateException("Cannot load test class [" + className + "] from modified classloader, verify that you did not exclude a path containing the test", e); } - Object testInstance = ReflectionUtils.newInstance(testClass); - final Optional method = ReflectionUtils.findMethod(testClass, methodName); - ReflectionUtils.invokeMethod( - method.orElseThrow(() -> new IllegalStateException("No test method named " + methodName)), - testInstance); + Constructor constructor = testClass.getDeclaredConstructors()[0]; + Object testInstance = executableInvoker.invoke(constructor); + + String methodName = originalExecutable.getName(); + int parameterCount = originalExecutable.getParameterCount(); + List matchingMethods = ReflectionUtils.findMethods(testClass, + m -> m.getName().equals(methodName) && m.getParameterCount() == parameterCount); + if (matchingMethods.size() == 0) { + throw new IllegalStateException("No test method named " + methodName); + } else if (matchingMethods.size() > 1) { + logger.warn(() -> "Multiple test methods with name " + methodName + " and " + parameterCount + " parameters"); + } + + executableInvoker.invoke(matchingMethods.get(0), testInstance); } } diff --git a/src/test/java/com/github/fridujo/classpath/junit/extension/OtherExtensionsCompatibilityTests.java b/src/test/java/com/github/fridujo/classpath/junit/extension/OtherExtensionsCompatibilityTests.java new file mode 100644 index 0000000..caade23 --- /dev/null +++ b/src/test/java/com/github/fridujo/classpath/junit/extension/OtherExtensionsCompatibilityTests.java @@ -0,0 +1,106 @@ +package com.github.fridujo.classpath.junit.extension; + +import com.github.fridujo.classpath.junit.extension.jupiter.ModifiedClasspath; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.*; +import org.junit.platform.commons.util.ReflectionUtils; + +import java.lang.annotation.*; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNoException; + +class OtherExtensionsCompatibilityTests { + + private final String constructorInjectedParameter; + + private final AtomicInteger mutatedByExtensionBeforeTest = new AtomicInteger(-2); + + OtherExtensionsCompatibilityTests(@InjectStringWithRandomUUID String constructorInjectedParameter) { + this.constructorInjectedParameter = constructorInjectedParameter; + } + + @Test + @ModifiedClasspath(addDependencies = "ch.qos.logback:logback-classic:1.2.3") + void parameter_resolver_extensions_are_triggered(@InjectStringWithRandomUUID String methodInjectedParameter) throws ClassNotFoundException { + assertThat(Class.forName("ch.qos.logback.core.Appender")).isInterface(); + assertThatNoException().isThrownBy(() -> UUID.fromString(constructorInjectedParameter)); + assertThatNoException().isThrownBy(() -> UUID.fromString(methodInjectedParameter)); + } + + @Test + @ModifiedClasspath(addDependencies = "ch.qos.logback:logback-classic:1.2.3") + @SetAtomicIntegerFieldTo(45) + @Disabled("") + void lifecycle_extensions_are_triggered() { + assertThat(mutatedByExtensionBeforeTest).hasValue(45); + } + + @Target(ElementType.PARAMETER) + @Retention(RetentionPolicy.RUNTIME) + @Documented + @ExtendWith(InjectStringWithRandomUUID.InjectStringWithRandomUUIDExtension.class) + @interface InjectStringWithRandomUUID { + class InjectStringWithRandomUUIDExtension implements ParameterResolver { + + @Override + public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException { + boolean match = Arrays.stream(parameterContext.getParameter().getAnnotations()) + .anyMatch(a -> { + boolean equals = a.annotationType().getName().equals(InjectStringWithRandomUUID.class.getName()); + return equals; + }); + return match; + } + + @Override + public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException { + return UUID.randomUUID().toString(); + } + } + } + + /** + * {@link ExecutableInvoker#invoke(Method, Object)} + */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @Documented + @ExtendWith(SetAtomicIntegerFieldTo.SetAtomicIntegerToExtension.class) + @interface SetAtomicIntegerFieldTo { + + int value(); + + class SetAtomicIntegerToExtension implements BeforeEachCallback { + + @Override + public void beforeEach(ExtensionContext context) { + Method method = context.getTestMethod().get(); + Arrays.stream(method.getAnnotations()) + .filter(a -> a.annotationType().getName().equals(SetAtomicIntegerFieldTo.class.getName())) + .findFirst() + .ifPresent(a -> { + Method valueMethod = ReflectionUtils.findMethod(a.getClass(), "value").get(); + int value = (int) ReflectionUtils.invokeMethod(valueMethod, a); + + Class testClass = context.getTestClass().get(); + Object testInstance = context.getTestInstance().get(); + Arrays.stream(testClass.getDeclaredFields()) + .filter(f -> f.getType() == AtomicInteger.class) + .map(ReflectionUtils::makeAccessible) + .map(f -> ReflectionUtils.tryToReadFieldValue(f, testInstance).toOptional()) + .filter(v -> v.isPresent()) + .map(Optional::get) + .map(AtomicInteger.class::cast) + .forEach(ai -> ai.set(value)); + }); + } + } + } +}