Skip to content

Commit

Permalink
Improve ArchCondition evaluation #876
Browse files Browse the repository at this point in the history
At the moment evaluating `ArchConditions` is quite memory inefficient. E.g. for a rule like `noClasses().should().accessField(Foo.class, "bar")` we create a `ConditionEvent` for every single field access and qualify it with "allowed" or "violation". Thus, this is independent of some access actually being a violation but done for every access. Furthermore, all such events of all evaluated objects will be kept in memory until the test failure is reported. This will keep in the former case events for all field accesses of all classes in memory until the final failure (or success) is reported.

The background for this behavior is the complex evaluation logic of nested and inverted conditions. E.g. `noClasses()` will invert the events and thus make the events that were "allowed" before to be violations. To make matters worse we have combinations like `accessField(..).orShould()...` and aggregations like `noClasses().should().haveAny...` where `haveAny...` will be satisfied if there is any occurrence that supports the condition, but by negating it via `noClasses()` we effectively change the evaluation to "all occurrences must NOT satisfy the condition". So it is quite challenging to determine at which point the information about "allowed" events is finally obsolete, since we can arbitrarily aggregate and invert.

I tried to come up with a lazy solution, where we would use some `Supplier/Consumer` approach, but failed to come up with one that a) makes up a good public API for users, b) works for all those cases of aggregation including customized behavior by users and c) really saves substantial amounts of memory. So in the end I gave up on that and focused on optimizing the current approach.

This PR will address two points: a) improve the memory footprint of `ArchCondition` evaluation and b) improve the public API to make future optimizations easier and the code more maintainable (now is a very good opportunity since we are about to release a new major version, so we can break the public API a little bit).

## Performance Aspect

As for point a) the main optimization here is to throw out intermediate copying of events, and once we know we are done with event evaluation (i.e. we are now top-level about to create the result) to dump all references to now obsolete satisfied events. This had the following positive effect when testing `noClasses().should().accessField(System.class, "out")` on a big code base on my machine:

### Before

#### Heap allocation

![image](https://user-images.githubusercontent.com/4095015/171833975-0a77b41c-780d-4a45-bb03-4bfadbed0b21.png)

#### Performance

Evaluating the rule could be done 89 times in 60 seconds.

### After

#### Heap allocation

![image](https://user-images.githubusercontent.com/4095015/171834213-50f25b4d-1325-4169-902c-566cf6bcc64b.png)

#### Performance

Evaluating the rule could be done 135 times in 60 seconds.

### Result

So altogether we see a smoother memory allocation with less GCs and a performance improvement of about 50% when evaluating such conditions. Thus, I consider the change worthwile.

## API Aspect

Before the public API included a final class `ConditionEvents`, this is now an interface to allow different implementations to be used internally. Also the API of this interface has been stripped of all non-essential methods, like `getAllowed()` which would return the allowed events and force us to always keep a reference to satisfied events as long as we hold `ConditionEvents`. I tried to strip the API down to really just the essentials, e.g. `containViolation()` or `getViolating()` and remove all methods that don't make sense in all contexts.
  • Loading branch information
codecholeric authored Jun 25, 2022
2 parents 681fdf8 + 64ddbe4 commit 19643f5
Show file tree
Hide file tree
Showing 34 changed files with 803 additions and 713 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import com.tngtech.archunit.core.domain.JavaFieldAccess;
import com.tngtech.archunit.core.domain.JavaMethodCall;
import com.tngtech.archunit.lang.EvaluationResult;
import com.tngtech.archunit.lang.ViolationHandler;
import com.tngtech.archunit.testutils.ExpectedAccess.ExpectedCall;
import com.tngtech.archunit.testutils.ExpectedAccess.ExpectedFieldAccess;

Expand All @@ -26,8 +25,6 @@
import static java.util.Collections.singleton;
import static java.util.stream.Collectors.toSet;

@SuppressWarnings("Convert2Lambda")
// We need the reified type parameter here
class HandlingAssertion {
private final Set<ExpectedRelation> expectedFieldAccesses;
private final Set<ExpectedRelation> expectedMethodCalls;
Expand Down Expand Up @@ -81,53 +78,35 @@ Result evaluate(EvaluationResult evaluationResult) {
return result;
}

// Too bad Java erases types, otherwise a lot of boilerplate could be avoided here :-(
// This way we must write explicitly ViolationHandler<ConcreteType> or the bound wont work correctly
private Set<String> evaluateFieldAccesses(EvaluationResult result) {
Set<String> errorMessages = new HashSet<>();
final Set<ExpectedRelation> left = new HashSet<>(this.expectedFieldAccesses);
result.handleViolations(new ViolationHandler<JavaFieldAccess>() {
@Override
public void handle(Collection<JavaFieldAccess> violatingObjects, String message) {
errorMessages.addAll(removeExpectedAccesses(violatingObjects, left));
}
});
result.handleViolations((Collection<JavaFieldAccess> violatingObjects, String message) ->
errorMessages.addAll(removeExpectedAccesses(violatingObjects, left)));
return union(errorMessages, errorMessagesFrom(left));
}

private Set<String> evaluateMethodCalls(EvaluationResult result) {
Set<String> errorMessages = new HashSet<>();
final Set<ExpectedRelation> left = new HashSet<>(expectedMethodCalls);
result.handleViolations(new ViolationHandler<JavaMethodCall>() {
@Override
public void handle(Collection<JavaMethodCall> violatingObjects, String message) {
errorMessages.addAll(removeExpectedAccesses(violatingObjects, left));
}
});
result.handleViolations((Collection<JavaMethodCall> violatingObjects, String message) ->
errorMessages.addAll(removeExpectedAccesses(violatingObjects, left)));
return union(errorMessages, errorMessagesFrom(left));
}

private Set<String> evaluateConstructorCalls(EvaluationResult result) {
Set<String> errorMessages = new HashSet<>();
final Set<ExpectedRelation> left = new HashSet<>(expectedConstructorCalls);
result.handleViolations(new ViolationHandler<JavaConstructorCall>() {
@Override
public void handle(Collection<JavaConstructorCall> violatingObjects, String message) {
errorMessages.addAll(removeExpectedAccesses(violatingObjects, left));
}
});
result.handleViolations((Collection<JavaConstructorCall> violatingObjects, String message) ->
errorMessages.addAll(removeExpectedAccesses(violatingObjects, left)));
return union(errorMessages, errorMessagesFrom(left));
}

private Set<String> evaluateCalls(EvaluationResult result) {
Set<String> errorMessages = new HashSet<>();
final Set<ExpectedRelation> left = new HashSet<>(Sets.union(expectedConstructorCalls, expectedMethodCalls));
result.handleViolations(new ViolationHandler<JavaCall<?>>() {
@Override
public void handle(Collection<JavaCall<?>> violatingObjects, String message) {
errorMessages.addAll(removeExpectedAccesses(violatingObjects, left));
}
});
result.handleViolations((Collection<JavaCall<?>> violatingObjects, String message) ->
errorMessages.addAll(removeExpectedAccesses(violatingObjects, left)));
return union(errorMessages, errorMessagesFrom(left));
}

Expand All @@ -140,24 +119,16 @@ private Set<String> evaluateAccesses(EvaluationResult result) {
addAll(expectedFieldAccesses);
}
};
result.handleViolations(new ViolationHandler<JavaAccess<?>>() {
@Override
public void handle(Collection<JavaAccess<?>> violatingObjects, String message) {
errorMessages.addAll(removeExpectedAccesses(violatingObjects, left));
}
});
result.handleViolations((Collection<JavaAccess<?>> violatingObjects, String message) ->
errorMessages.addAll(removeExpectedAccesses(violatingObjects, left)));
return union(errorMessages, errorMessagesFrom(left));
}

private Set<String> evaluateDependencies(EvaluationResult result) {
Set<String> errorMessages = new HashSet<>();
final Set<ExpectedRelation> left = new HashSet<>(expectedDependencies);
result.handleViolations(new ViolationHandler<Dependency>() {
@Override
public void handle(Collection<Dependency> violatingObjects, String message) {
errorMessages.addAll(removeExpectedAccesses(violatingObjects, left));
}
});
result.handleViolations((Collection<Dependency> violatingObjects, String message) ->
errorMessages.addAll(removeExpectedAccesses(violatingObjects, left)));
return union(errorMessages, errorMessagesFrom(left));
}

Expand Down
199 changes: 3 additions & 196 deletions archunit/src/main/java/com/tngtech/archunit/lang/ArchCondition.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,11 @@
package com.tngtech.archunit.lang;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;

import com.google.common.base.Joiner;
import com.google.common.base.MoreObjects;
import com.google.common.collect.ImmutableList;
import com.tngtech.archunit.PublicAPI;
import com.tngtech.archunit.lang.conditions.ArchConditions;

import static com.tngtech.archunit.PublicAPI.Usage.INHERITANCE;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;

@PublicAPI(usage = INHERITANCE)
public abstract class ArchCondition<T> {
Expand Down Expand Up @@ -60,11 +52,11 @@ public void finish(ConditionEvents events) {
}

public ArchCondition<T> and(ArchCondition<? super T> condition) {
return new AndCondition<>(this, condition.forSubtype());
return ArchConditions.and(this, condition.forSubtype());
}

public ArchCondition<T> or(ArchCondition<? super T> condition) {
return new OrCondition<>(this, condition.forSubtype());
return ArchConditions.or(this, condition.forSubtype());
}

public String getDescription() {
Expand Down Expand Up @@ -99,189 +91,4 @@ public String toString() {
public <U extends T> ArchCondition<U> forSubtype() {
return (ArchCondition<U>) this;
}

private abstract static class JoinCondition<T> extends ArchCondition<T> {
private final Collection<ArchCondition<T>> conditions;

private JoinCondition(String infix, Collection<ArchCondition<T>> conditions) {
super(joinDescriptionsOf(infix, conditions));
this.conditions = conditions;
}

private static <T> String joinDescriptionsOf(String infix, Collection<ArchCondition<T>> conditions) {
return conditions.stream().map(ArchCondition::getDescription).collect(joining(" " + infix + " "));
}

@Override
public void init(Collection<T> allObjectsToTest) {
for (ArchCondition<T> condition : conditions) {
condition.init(allObjectsToTest);
}
}

@Override
public void finish(ConditionEvents events) {
for (ArchCondition<T> condition : conditions) {
condition.finish(events);
}
}

List<ConditionWithEvents<T>> evaluateConditions(T item) {
return conditions.stream().map(condition -> new ConditionWithEvents<>(condition, item)).collect(toList());
}

@Override
public String toString() {
return getClass().getSimpleName() + "{" + conditions + "}";
}
}

private static class ConditionWithEvents<T> {
private final ArchCondition<T> condition;
private final ConditionEvents events;

ConditionWithEvents(ArchCondition<T> condition, T item) {
this(condition, check(condition, item));
}

ConditionWithEvents(ArchCondition<T> condition, ConditionEvents events) {
this.condition = condition;
this.events = events;
}

private static <T> ConditionEvents check(ArchCondition<T> condition, T item) {
ConditionEvents events = new ConditionEvents();
condition.check(item, events);
return events;
}

@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("condition", condition)
.add("events", events)
.toString();
}
}

private abstract static class JoinConditionEvent<T> implements ConditionEvent {
final T correspondingObject;
final List<ConditionWithEvents<T>> evaluatedConditions;

JoinConditionEvent(T correspondingObject, List<ConditionWithEvents<T>> evaluatedConditions) {
this.correspondingObject = correspondingObject;
this.evaluatedConditions = evaluatedConditions;
}

List<String> getUniqueLinesOfViolations() { // TODO: Sort by line number, then lexicographically
final Set<String> result = new TreeSet<>();
for (ConditionWithEvents<T> evaluation : evaluatedConditions) {
for (ConditionEvent event : evaluation.events) {
if (event.isViolation()) {
result.addAll(event.getDescriptionLines());
}
}
}
return ImmutableList.copyOf(result);
}

@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("evaluatedConditions", evaluatedConditions)
.toString();
}

List<ConditionWithEvents<T>> invert(List<ConditionWithEvents<T>> evaluatedConditions) {
return evaluatedConditions.stream().map(this::invert).collect(toList());
}

private ConditionWithEvents<T> invert(ConditionWithEvents<T> evaluation) {
ConditionEvents invertedEvents = new ConditionEvents();
for (ConditionEvent event : evaluation.events) {
event.addInvertedTo(invertedEvents);
}
return new ConditionWithEvents<>(evaluation.condition, invertedEvents);
}
}

private static class AndCondition<T> extends JoinCondition<T> {
private AndCondition(ArchCondition<T> first, ArchCondition<T> second) {
super("and", ImmutableList.of(first, second));
}

@Override
public void check(T item, ConditionEvents events) {
events.add(new AndConditionEvent<>(item, evaluateConditions(item)));
}
}

private static class OrCondition<T> extends JoinCondition<T> {
private OrCondition(ArchCondition<T> first, ArchCondition<T> second) {
super("or", ImmutableList.of(first, second));
}

@Override
public void check(T item, ConditionEvents events) {
events.add(new OrConditionEvent<>(item, evaluateConditions(item)));
}
}

private static class AndConditionEvent<T> extends JoinConditionEvent<T> {
AndConditionEvent(T item, List<ConditionWithEvents<T>> evaluatedConditions) {
super(item, evaluatedConditions);
}

@Override
public boolean isViolation() {
return evaluatedConditions.stream().anyMatch(evaluation -> evaluation.events.containViolation());
}

@Override
public void addInvertedTo(ConditionEvents events) {
events.add(new OrConditionEvent<>(correspondingObject, invert(evaluatedConditions)));
}

@Override
public List<String> getDescriptionLines() {
return getUniqueLinesOfViolations();
}

@Override
public void handleWith(final Handler handler) {
for (ConditionWithEvents<T> condition : evaluatedConditions) {
condition.events.handleViolations(handler::handle);
}
}
}

private static class OrConditionEvent<T> extends JoinConditionEvent<T> {
OrConditionEvent(T item, List<ConditionWithEvents<T>> evaluatedConditions) {
super(item, evaluatedConditions);
}

@Override
public boolean isViolation() {
return evaluatedConditions.stream().allMatch(evaluation -> evaluation.events.containViolation());
}

@Override
public void addInvertedTo(ConditionEvents events) {
events.add(new AndConditionEvent<>(correspondingObject, invert(evaluatedConditions)));
}

@Override
public List<String> getDescriptionLines() {
return ImmutableList.of(createMessage());
}

private String createMessage() {
return Joiner.on(" and ").join(getUniqueLinesOfViolations());
}

@Override
public void handleWith(final Handler handler) {
handler.handle(Collections.singleton(correspondingObject), createMessage());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ public EvaluationResult evaluate(JavaClasses classes) {
verifyNoEmptyShouldIfEnabled(allObjects);

condition.init(allObjects);
ConditionEvents events = new ConditionEvents();
ConditionEvents events = ConditionEvents.Factory.create();
for (T object : allObjects) {
condition.check(object, events);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
import static com.tngtech.archunit.PublicAPI.Usage.ACCESS;
import static com.tngtech.archunit.PublicAPI.Usage.INHERITANCE;

/**
* An event that occurred while checking an {@link ArchCondition}. This can either be a {@link #isViolation() violation}
* or be allowed. An event that is allowed will turn into a violation if it is {@link #invert() inverted}
* (e.g. for negation of the rule).
*/
@PublicAPI(usage = INHERITANCE)
public interface ConditionEvent {
/**
Expand All @@ -33,15 +38,13 @@ public interface ConditionEvent {
boolean isViolation();

/**
* Adds the 'opposite' of the event. <br>
* E.g. <i>The event is a violation, if some conditions A and B are both true?</i>
* <br> {@literal ->} <i>The 'inverted' event is a violation if either A or B (or both) are not true</i><br>
* In the most simple case, this is just an equivalent event evaluating {@link #isViolation()}
* inverted.
*
* @param events The events to add the 'inverted self' to
* @return the 'opposite' of the event. <br>
* Assume e.g. <i>The event is a violation, if some conditions A and B are both true</i>
* <br> {@literal =>} <i>The 'inverted' event is a violation if either A or B (or both) are not true</i><br>
* In the most simple case, this is just an equivalent event evaluating {@link #isViolation()}
* inverted.
*/
void addInvertedTo(ConditionEvents events);
ConditionEvent invert();

/**
* @return A textual description of this event as a list of lines
Expand Down
Loading

0 comments on commit 19643f5

Please sign in to comment.