Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions core/src/main/java/dev/morphia/mapping/codec/DecodeSession.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package dev.morphia.mapping.codec;

import java.util.HashMap;
import java.util.Map;

import com.mongodb.lang.Nullable;

import dev.morphia.annotations.internal.MorphiaInternal;

/**
* Per-cursor-decode cache that maps (collection, id) → entity instance.
* Activated via {@link DecodeSession#activate()} at cursor creation and cleared via
* {@link DecodeSession#deactivate()} when the cursor is closed, so all documents in
* one query iteration share the same cache.
*
* @hidden
* @morphia.internal
*/
@MorphiaInternal
public class DecodeSession {
private static final ThreadLocal<DecodeSession> CURRENT = new ThreadLocal<>();

private final Map<String, Map<Object, Object>> cache = new HashMap<>();

private DecodeSession() {
}

/**
* Activates a session on the current thread. If a session is already active it is
* reused, so nested activations (e.g. fetching a @Reference while decoding an outer
* document) share one cache. Returns {@code true} if this call created the root session
* and therefore owns the responsibility of calling {@link #deactivate()}.
*
* @return {@code true} if a new root session was created; {@code false} if an existing session was reused
*/
public static boolean activate() {
if (CURRENT.get() != null) {
return false;
}
CURRENT.set(new DecodeSession());
return true;
}

/**
* Returns the session active on the current thread, or {@code null} if none.
*/
@Nullable
public static DecodeSession current() {
return CURRENT.get();
}

/**
* Removes the session from the current thread.
*/
public static void deactivate() {
CURRENT.remove();
}

/**
* Stores a decoded entity in the cache.
*
* @param collection the MongoDB collection name
* @param id the entity's {@code _id} value
* @param entity the decoded entity instance
*/
public void register(String collection, Object id, Object entity) {
cache.computeIfAbsent(collection, k -> new HashMap<>()).put(id, entity);
}

/**
* Returns a previously cached entity, or {@code null} if not present.
*
* @param collection the MongoDB collection name
* @param id the entity's {@code _id} value
*/
@Nullable
public Object lookup(String collection, Object id) {
Map<Object, Object> col = cache.get(collection);
return col != null ? col.get(id) : null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import dev.morphia.annotations.internal.MorphiaInternal;
import dev.morphia.mapping.DiscriminatorLookup;
import dev.morphia.mapping.codec.DecodeSession;
import dev.morphia.mapping.codec.MorphiaInstanceCreator;

import org.bson.BsonInvalidOperationException;
Expand Down Expand Up @@ -45,7 +46,29 @@ public T decode(BsonReader reader, DecoderContext decoderContext) {
if (decoderContext.hasCheckedDiscriminator()) {
LOG.debug(format("Decoding document using codec for %s'", morphiaCodec.getEntityModel().getType().getName()));
MorphiaInstanceCreator instanceCreator = getInstanceCreator();
T instance = (T) instanceCreator.getInstance();

DecodeSession session = DecodeSession.current();
Object prereadId = null;
if (session != null) {
prereadId = peekId(reader);
if (prereadId != null) {
session.register(classModel.collectionName(), prereadId, instance);
}
}

decodeProperties(reader, decoderContext, instanceCreator, classModel);
Comment on lines +49 to 60

if (session != null && prereadId == null) {
PropertyModel idProp = classModel.getIdProperty();
if (idProp != null) {
Comment on lines +63 to +64
Object id = morphiaCodec.getDatastore().getMapper().getId(instance);
if (id != null) {
session.register(classModel.collectionName(), id, instance);
}
}
}
Comment on lines +51 to +70

return (T) instanceCreator.getInstance();
} else {
entity = getCodecFromDocument(reader, classModel.useDiscriminator(), classModel.discriminatorKey(),
Expand Down Expand Up @@ -117,6 +140,34 @@ protected Codec<T> getCodecFromDocument(BsonReader reader, boolean useDiscrimina
return codec != null ? codec : defaultCodec;
}

@Nullable
private Object peekId(BsonReader reader) {
BsonReaderMark mark = reader.getMark();
try {
reader.readStartDocument();
String idName = classModel.getIdProperty() != null
? classModel.getIdProperty().getMappedName()
: "_id";
while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) {
String name = reader.readName();
if ("_id".equals(name) || name.equals(idName)) {
return morphiaCodec.getRegistry()
.get(Object.class)
.decode(reader, DecoderContext.builder().build());
} else {
reader.skipValue();
}
}
return null;
} catch (Exception e) {
LOG.debug("Could not pre-read _id for DecodeSession on {}; cycle detection may not apply",
classModel.getType().getSimpleName(), e);
return null;
} finally {
mark.reset();
}
}

protected MorphiaInstanceCreator getInstanceCreator() {
return classModel.getInstanceCreator(morphiaCodec.getConversions());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dev.morphia.mapping.DiscriminatorLookup;
import dev.morphia.mapping.MappingException;
import dev.morphia.mapping.codec.Conversions;
import dev.morphia.mapping.codec.DecodeSession;
import dev.morphia.mapping.codec.PropertyCodecRegistryImpl;
import dev.morphia.sofia.Sofia;

Expand Down Expand Up @@ -77,7 +78,14 @@ public MorphiaCodec(MorphiaDatastore datastore, EntityModel model,

@Override
public T decode(BsonReader reader, DecoderContext decoderContext) {
return getDecoder().decode(reader, decoderContext);
boolean root = DecodeSession.activate();
try {
return getDecoder().decode(reader, decoderContext);
} finally {
if (root) {
DecodeSession.deactivate();
}
}
Comment on lines +81 to +88
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import dev.morphia.annotations.internal.MorphiaInternal;
import dev.morphia.mapping.Mapper;
import dev.morphia.mapping.MappingException;
import dev.morphia.mapping.codec.DecodeSession;
import dev.morphia.mapping.codec.pojo.EntityModel;
import dev.morphia.mapping.codec.pojo.PropertyHandler;
import dev.morphia.mapping.codec.pojo.PropertyModel;
Expand Down Expand Up @@ -322,6 +323,64 @@ private <T> Class<T> makeProxy() {
.getLoaded();
}

@Nullable
private Object lookupInSession(Object id, EntityModel entityModel) {
DecodeSession session = DecodeSession.current();
if (session == null) {
return null;
}
String collection = id instanceof DBRef
? ((DBRef) id).getCollectionName()
: entityModel.collectionName();
Object lookupId = id instanceof DBRef ? ((DBRef) id).getId() : id;
return session.lookup(collection, lookupId);
}

/**
* Returns a map of stripped-id → cached entity for each id in rawIds that is present
* in the current session. Ids not yet in the session are absent from the result.
*/
private Map<Object, Object> partialSessionLookup(List<?> rawIds, EntityModel entityModel) {
DecodeSession session = DecodeSession.current();
if (session == null) {
return Map.of();
}
Map<Object, Object> hits = new LinkedHashMap<>();
for (Object id : rawIds) {
String collection = id instanceof DBRef ? ((DBRef) id).getCollectionName() : entityModel.collectionName();
Object lookupId = id instanceof DBRef ? ((DBRef) id).getId() : id;
Object cached = session.lookup(collection, lookupId);
if (cached != null) {
hits.put(lookupId, cached);
}
}
return hits;
}

/**
* Returns a map of stripped-id → cached entity for each map value id that is present
* in the current session. Returns null if no session is active.
*/
@Nullable
private Map<Object, Object> lookupMapInSession(Map<Object, Object> ids, EntityModel entityModel) {
DecodeSession session = DecodeSession.current();
if (session == null) {
return null;
}
Map<Object, Object> result = new LinkedHashMap<>();
for (Entry<Object, Object> entry : ids.entrySet()) {
Object rawId = entry.getValue();
String collection = rawId instanceof DBRef ? ((DBRef) rawId).getCollectionName() : entityModel.collectionName();
Object lookupId = rawId instanceof DBRef ? ((DBRef) rawId).getId() : rawId;
Object cached = session.lookup(collection, lookupId);
if (cached == null) {
return null; // any miss: fall through to full DB fetch
}
result.put(entry.getKey(), cached);
}
return result;
}

@Nullable
private Object fetch(Object value) {
boolean lazy = annotation.lazy();
Expand All @@ -336,7 +395,7 @@ private Object fetch(Object value) {
return preDecoded;
}
List<Object> ids = stripDbRefs(rawIds);
Supplier<Object> loader = () -> fetchCollection(rawIds, entityModel, ignoreMissing);
Supplier<Object> loader = () -> fetchCollectionMerged(rawIds, entityModel, ignoreMissing);
return lazy ? createProxy(loader, ids, entityModel.getType()) : loader.get();

} else if (Set.class.isAssignableFrom(type)) {
Expand All @@ -346,7 +405,7 @@ private Object fetch(Object value) {
return new LinkedHashSet<>(preDecoded);
}
List<Object> ids = stripDbRefs(rawIds);
Supplier<Object> loader = () -> new LinkedHashSet<>(fetchCollection(rawIds, entityModel, ignoreMissing));
Supplier<Object> loader = () -> new LinkedHashSet<>(fetchCollectionMerged(rawIds, entityModel, ignoreMissing));
return lazy ? createProxy(loader, ids, entityModel.getType()) : loader.get();

} else if (Map.class.isAssignableFrom(type)) {
Expand All @@ -357,6 +416,10 @@ private Object fetch(Object value) {
ids.put(mapper.getConversions().convert(entry.getKey(), keyType), entry.getValue());
}
List<Object> idList = stripDbRefs(new ArrayList<>(ids.values()));
Map<Object, Object> cachedMap = lookupMapInSession(ids, entityModel);
if (cachedMap != null) {
return cachedMap;
}
Supplier<Object> loader = () -> fetchMap(ids, entityModel);
return lazy ? createProxy(loader, idList, entityModel.getType()) : loader.get();

Expand All @@ -366,6 +429,10 @@ private Object fetch(Object value) {
if (entityModel.getType().isInstance(id)) {
return id;
}
Object cached = lookupInSession(id, entityModel);
if (cached != null) {
return cached;
}
List<Object> ids = List.of(stripDbRef(id));
Supplier<Object> loader = () -> fetchSingle(id, entityModel, ignoreMissing);
return lazy ? createProxy(loader, ids, entityModel.getType()) : loader.get();
Expand All @@ -376,20 +443,40 @@ private Object fetchSingle(Object id, EntityModel entityModel, boolean ignoreMis
var query = id instanceof DBRef
? datastore.find(mapper.getClassFromCollection(((DBRef) id).getCollectionName()))
: datastore.find(entityModel.getType());
Object result = query.filter(eq("_id", stripDbRef(id))).iterator().tryNext();
if (result == null && !ignoreMissing) {
throw new ReferenceException(Sofia.missingReferencedEntity(entityModel.getType().getSimpleName()));
try (var cursor = query.filter(eq("_id", stripDbRef(id))).iterator()) {
Object result = cursor.tryNext();
if (result == null && !ignoreMissing) {
throw new ReferenceException(Sofia.missingReferencedEntity(entityModel.getType().getSimpleName()));
}
return result;
}
return result;
}

private List<Object> fetchCollection(List<?> ids, EntityModel entityModel, boolean ignoreMissing) {
private List<Object> fetchCollectionMerged(List<?> rawIds, EntityModel entityModel, boolean ignoreMissing) {
Map<Object, Object> idMap = new HashMap<>(partialSessionLookup(rawIds, entityModel));

List<Object> missingIds = rawIds.stream()
.filter(id -> {
Object lookupId = id instanceof DBRef ? ((DBRef) id).getId() : id;
return !idMap.containsKey(lookupId);
})
.collect(Collectors.toList());

if (!missingIds.isEmpty()) {
idMap.putAll(buildIdMap(missingIds, entityModel, ignoreMissing));
}

return mapIdsToValues(rawIds, idMap).stream()
.filter(Objects::nonNull)
.collect(Collectors.toList());
}

private Map<Object, Object> buildIdMap(List<?> ids, EntityModel entityModel, boolean ignoreMissing) {
Map<String, List<Object>> byCollection = new HashMap<>();
for (Object id : ids) {
if (id instanceof DBRef) {
byCollection.computeIfAbsent(((DBRef) id).getCollectionName(), k -> new ArrayList<>()).add(((DBRef) id).getId());
} else {
// nested List items are stored as-is; extractFlatIds flattens them for the query
byCollection.computeIfAbsent(entityModel.collectionName(), k -> new ArrayList<>()).add(id);
}
}
Expand All @@ -398,8 +485,11 @@ private List<Object> fetchCollection(List<?> ids, EntityModel entityModel, boole
for (Entry<String, List<Object>> entry : byCollection.entrySet()) {
idMap.putAll(queryCollection(entry.getKey(), extractFlatIds(entry.getValue()), entityModel, ignoreMissing));
}
return idMap;
}

return mapIdsToValues(ids, idMap).stream()
private List<Object> fetchCollection(List<?> ids, EntityModel entityModel, boolean ignoreMissing) {
return mapIdsToValues(ids, buildIdMap(ids, entityModel, ignoreMissing)).stream()
.filter(Objects::nonNull)
.collect(Collectors.toList());
}
Expand Down
15 changes: 12 additions & 3 deletions core/src/main/java/dev/morphia/query/MorphiaCursor.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
import com.mongodb.lang.NonNull;

import dev.morphia.annotations.internal.MorphiaInternal;
import dev.morphia.mapping.codec.DecodeSession;

/**
* @param <T> the original type being iterated
* @since 2.2
*/
public class MorphiaCursor<T> implements AutoCloseable, MongoCursor<T> {
private final MongoCursor<T> wrapped;
private final boolean rootSession;

/**
* Creates a MorphiaCursor
Expand All @@ -27,13 +29,20 @@ public class MorphiaCursor<T> implements AutoCloseable, MongoCursor<T> {
@MorphiaInternal
public MorphiaCursor(MongoCursor<T> cursor) {
wrapped = cursor;
rootSession = DecodeSession.activate();
}

/**
* Closes the underlying cursor.
* Closes the underlying cursor and releases the decode session if this cursor owns it.
*/
public void close() {
wrapped.close();
try {
wrapped.close();
} finally {
if (rootSession) {
DecodeSession.deactivate();
}
}
}

@Override
Expand Down Expand Up @@ -80,7 +89,7 @@ public void remove() {
*/
public List<T> toList() {
final List<T> results = new ArrayList<>();
try (wrapped) {
try (this) {
while (wrapped.hasNext()) {
results.add(next());
}
Expand Down
Loading
Loading