Skip to content

Provide extension composition #150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
<maven-invoker.version>3.1.0</maven-invoker.version>
<maven-resolver.version>1.7.0</maven-resolver.version>

<junit.jupiter.version>5.7.1</junit.jupiter.version>
<junit.jupiter.version>5.9.0</junit.jupiter.version>
<assertj.version>3.19.0</assertj.version>
<mockito-core.version>3.10.0</mockito-core.version>
<log4j.version>2.14.1</log4j.version>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
package com.github.fridujo.classpath.junit.extension.jupiter;

import java.io.File;
import java.lang.reflect.Method;
import java.util.Optional;
import java.util.stream.Collectors;

import com.github.fridujo.classpath.junit.extension.Classpath;
import com.github.fridujo.classpath.junit.extension.PathElement;
import com.github.fridujo.classpath.junit.extension.buildtool.BuildTool;
import org.junit.jupiter.api.extension.ExecutableInvoker;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.InvocationInterceptor;
import org.junit.jupiter.api.extension.ReflectiveInvocationContext;
import org.junit.platform.commons.logging.Logger;
import org.junit.platform.commons.logging.LoggerFactory;
import org.junit.platform.commons.util.ReflectionUtils;

import com.github.fridujo.classpath.junit.extension.Classpath;
import com.github.fridujo.classpath.junit.extension.PathElement;
import com.github.fridujo.classpath.junit.extension.buildtool.BuildTool;
import java.io.File;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.util.List;
import java.util.stream.Collectors;

abstract class AbstractClasspathExtension implements InvocationInterceptor {

private final ExtensionContext.Namespace namespace = ExtensionContext.Namespace.create(AbstractClasspathExtension.class);
private static final Logger logger = LoggerFactory.getLogger(AbstractClasspathExtension.class);

@Override
public void interceptTestTemplateMethod(Invocation<Void> invocation,
ReflectiveInvocationContext<Method> invocationContext,
ExtensionContext extensionContext) throws Throwable {
ExtensionContext extensionContext) {
intercept(invocation, invocationContext, extensionContext);
}

Expand All @@ -39,22 +43,22 @@ private void intercept(Invocation<Void> invocation, ReflectiveInvocationContext<
Classpath classpath = supplyClasspath(invocationContext, extensionContext, buildTool);
BuildToolLocator.store(extensionContext, classpath.buildTool);

invokeMethodWithModifiedClasspath(invocationContext, classpath);
invokeMethodWithModifiedClasspath(invocationContext, classpath, extensionContext.getExecutableInvoker());
}

protected abstract Classpath supplyClasspath(ReflectiveInvocationContext<Method> invocationContext, ExtensionContext extensionContext, BuildTool buildTool);

private void invokeMethodWithModifiedClasspath(ReflectiveInvocationContext<Method> invocationContext, Classpath classpath) {
private void invokeMethodWithModifiedClasspath(ReflectiveInvocationContext<Method> invocationContext, Classpath classpath, ExecutableInvoker executableInvoker) {
ClassLoader modifiedClassLoader = classpath.newClassLoader();

ClassLoader currentThreadPreviousClassLoader = replaceCurrentThreadClassLoader(modifiedClassLoader);
String previousClassPathProperty = replaceClassPathProperty(classpath);

try {
invokeMethodWithModifiedClasspath(
invocationContext.getExecutable().getDeclaringClass().getName(),
invocationContext.getExecutable().getName(),
modifiedClassLoader);
invocationContext.getExecutable(),
modifiedClassLoader,
executableInvoker);
} finally {
System.setProperty(Classpath.SYSTEM_PROPERTY, previousClassPathProperty);
Thread.currentThread().setContextClassLoader(currentThreadPreviousClassLoader);
Expand All @@ -73,18 +77,28 @@ private String replaceClassPathProperty(Classpath classpath) {
return previousClassPathProperty;
}

private void invokeMethodWithModifiedClasspath(String className, String methodName, ClassLoader classLoader) {
private void invokeMethodWithModifiedClasspath(Executable originalExecutable, ClassLoader classLoader, ExecutableInvoker executableInvoker) {
String className = originalExecutable.getDeclaringClass().getName();
final Class<?> testClass;
try {
testClass = classLoader.loadClass(className);
} catch (ClassNotFoundException e) {
throw new IllegalStateException("Cannot load test class [" + className + "] from modified classloader, verify that you did not exclude a path containing the test", e);
}

Object testInstance = ReflectionUtils.newInstance(testClass);
final Optional<Method> method = ReflectionUtils.findMethod(testClass, methodName);
ReflectionUtils.invokeMethod(
method.orElseThrow(() -> new IllegalStateException("No test method named " + methodName)),
testInstance);
Constructor<?> constructor = testClass.getDeclaredConstructors()[0];
Object testInstance = executableInvoker.invoke(constructor);

String methodName = originalExecutable.getName();
int parameterCount = originalExecutable.getParameterCount();
List<Method> matchingMethods = ReflectionUtils.findMethods(testClass,
m -> m.getName().equals(methodName) && m.getParameterCount() == parameterCount);
if (matchingMethods.size() == 0) {
throw new IllegalStateException("No test method named " + methodName);
} else if (matchingMethods.size() > 1) {
logger.warn(() -> "Multiple test methods with name " + methodName + " and " + parameterCount + " parameters");
}

executableInvoker.invoke(matchingMethods.get(0), testInstance);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package com.github.fridujo.classpath.junit.extension;

import com.github.fridujo.classpath.junit.extension.jupiter.ModifiedClasspath;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.*;
import org.junit.platform.commons.util.ReflectionUtils;

import java.lang.annotation.*;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatNoException;

class OtherExtensionsCompatibilityTests {

private final String constructorInjectedParameter;

private final AtomicInteger mutatedByExtensionBeforeTest = new AtomicInteger(-2);

OtherExtensionsCompatibilityTests(@InjectStringWithRandomUUID String constructorInjectedParameter) {
this.constructorInjectedParameter = constructorInjectedParameter;
}

@Test
@ModifiedClasspath(addDependencies = "ch.qos.logback:logback-classic:1.2.3")
void parameter_resolver_extensions_are_triggered(@InjectStringWithRandomUUID String methodInjectedParameter) throws ClassNotFoundException {
assertThat(Class.forName("ch.qos.logback.core.Appender")).isInterface();
assertThatNoException().isThrownBy(() -> UUID.fromString(constructorInjectedParameter));
assertThatNoException().isThrownBy(() -> UUID.fromString(methodInjectedParameter));
}

@Test
@ModifiedClasspath(addDependencies = "ch.qos.logback:logback-classic:1.2.3")
@SetAtomicIntegerFieldTo(45)
@Disabled("")
void lifecycle_extensions_are_triggered() {
assertThat(mutatedByExtensionBeforeTest).hasValue(45);
}

@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
@Documented
@ExtendWith(InjectStringWithRandomUUID.InjectStringWithRandomUUIDExtension.class)
@interface InjectStringWithRandomUUID {
class InjectStringWithRandomUUIDExtension implements ParameterResolver {

@Override
public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException {
boolean match = Arrays.stream(parameterContext.getParameter().getAnnotations())
.anyMatch(a -> {
boolean equals = a.annotationType().getName().equals(InjectStringWithRandomUUID.class.getName());
return equals;
});
return match;
}

@Override
public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) throws ParameterResolutionException {
return UUID.randomUUID().toString();
}
}
}

/**
* {@link ExecutableInvoker#invoke(Method, Object)}
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
@ExtendWith(SetAtomicIntegerFieldTo.SetAtomicIntegerToExtension.class)
@interface SetAtomicIntegerFieldTo {

int value();

class SetAtomicIntegerToExtension implements BeforeEachCallback {

@Override
public void beforeEach(ExtensionContext context) {
Method method = context.getTestMethod().get();
Arrays.stream(method.getAnnotations())
.filter(a -> a.annotationType().getName().equals(SetAtomicIntegerFieldTo.class.getName()))
.findFirst()
.ifPresent(a -> {
Method valueMethod = ReflectionUtils.findMethod(a.getClass(), "value").get();
int value = (int) ReflectionUtils.invokeMethod(valueMethod, a);

Class<?> testClass = context.getTestClass().get();
Object testInstance = context.getTestInstance().get();
Arrays.stream(testClass.getDeclaredFields())
.filter(f -> f.getType() == AtomicInteger.class)
.map(ReflectionUtils::makeAccessible)
.map(f -> ReflectionUtils.tryToReadFieldValue(f, testInstance).toOptional())
.filter(v -> v.isPresent())
.map(Optional::get)
.map(AtomicInteger.class::cast)
.forEach(ai -> ai.set(value));
});
}
}
}
}