diff --git a/core/src/main/java/dev/morphia/mapping/codec/DecodeSession.java b/core/src/main/java/dev/morphia/mapping/codec/DecodeSession.java new file mode 100644 index 00000000000..da11a6d3423 --- /dev/null +++ b/core/src/main/java/dev/morphia/mapping/codec/DecodeSession.java @@ -0,0 +1,81 @@ +package dev.morphia.mapping.codec; + +import java.util.HashMap; +import java.util.Map; + +import com.mongodb.lang.Nullable; + +import dev.morphia.annotations.internal.MorphiaInternal; + +/** + * Per-cursor-decode cache that maps (collection, id) → entity instance. + * Activated via {@link DecodeSession#activate()} at cursor creation and cleared via + * {@link DecodeSession#deactivate()} when the cursor is closed, so all documents in + * one query iteration share the same cache. + * + * @hidden + * @morphia.internal + */ +@MorphiaInternal +public class DecodeSession { + private static final ThreadLocal CURRENT = new ThreadLocal<>(); + + private final Map> cache = new HashMap<>(); + + private DecodeSession() { + } + + /** + * Activates a session on the current thread. If a session is already active it is + * reused, so nested activations (e.g. fetching a @Reference while decoding an outer + * document) share one cache. Returns {@code true} if this call created the root session + * and therefore owns the responsibility of calling {@link #deactivate()}. + * + * @return {@code true} if a new root session was created; {@code false} if an existing session was reused + */ + public static boolean activate() { + if (CURRENT.get() != null) { + return false; + } + CURRENT.set(new DecodeSession()); + return true; + } + + /** + * Returns the session active on the current thread, or {@code null} if none. + */ + @Nullable + public static DecodeSession current() { + return CURRENT.get(); + } + + /** + * Removes the session from the current thread. + */ + public static void deactivate() { + CURRENT.remove(); + } + + /** + * Stores a decoded entity in the cache. + * + * @param collection the MongoDB collection name + * @param id the entity's {@code _id} value + * @param entity the decoded entity instance + */ + public void register(String collection, Object id, Object entity) { + cache.computeIfAbsent(collection, k -> new HashMap<>()).put(id, entity); + } + + /** + * Returns a previously cached entity, or {@code null} if not present. + * + * @param collection the MongoDB collection name + * @param id the entity's {@code _id} value + */ + @Nullable + public Object lookup(String collection, Object id) { + Map col = cache.get(collection); + return col != null ? col.get(id) : null; + } +} diff --git a/core/src/main/java/dev/morphia/mapping/codec/pojo/EntityDecoder.java b/core/src/main/java/dev/morphia/mapping/codec/pojo/EntityDecoder.java index 67741760b74..1a2f424cb3b 100644 --- a/core/src/main/java/dev/morphia/mapping/codec/pojo/EntityDecoder.java +++ b/core/src/main/java/dev/morphia/mapping/codec/pojo/EntityDecoder.java @@ -4,6 +4,7 @@ import dev.morphia.annotations.internal.MorphiaInternal; import dev.morphia.mapping.DiscriminatorLookup; +import dev.morphia.mapping.codec.DecodeSession; import dev.morphia.mapping.codec.MorphiaInstanceCreator; import org.bson.BsonInvalidOperationException; @@ -45,7 +46,29 @@ public T decode(BsonReader reader, DecoderContext decoderContext) { if (decoderContext.hasCheckedDiscriminator()) { LOG.debug(format("Decoding document using codec for %s'", morphiaCodec.getEntityModel().getType().getName())); MorphiaInstanceCreator instanceCreator = getInstanceCreator(); + T instance = (T) instanceCreator.getInstance(); + + DecodeSession session = DecodeSession.current(); + Object prereadId = null; + if (session != null) { + prereadId = peekId(reader); + if (prereadId != null) { + session.register(classModel.collectionName(), prereadId, instance); + } + } + decodeProperties(reader, decoderContext, instanceCreator, classModel); + + if (session != null && prereadId == null) { + PropertyModel idProp = classModel.getIdProperty(); + if (idProp != null) { + Object id = morphiaCodec.getDatastore().getMapper().getId(instance); + if (id != null) { + session.register(classModel.collectionName(), id, instance); + } + } + } + return (T) instanceCreator.getInstance(); } else { entity = getCodecFromDocument(reader, classModel.useDiscriminator(), classModel.discriminatorKey(), @@ -117,6 +140,34 @@ protected Codec getCodecFromDocument(BsonReader reader, boolean useDiscrimina return codec != null ? codec : defaultCodec; } + @Nullable + private Object peekId(BsonReader reader) { + BsonReaderMark mark = reader.getMark(); + try { + reader.readStartDocument(); + String idName = classModel.getIdProperty() != null + ? classModel.getIdProperty().getMappedName() + : "_id"; + while (reader.readBsonType() != BsonType.END_OF_DOCUMENT) { + String name = reader.readName(); + if ("_id".equals(name) || name.equals(idName)) { + return morphiaCodec.getRegistry() + .get(Object.class) + .decode(reader, DecoderContext.builder().build()); + } else { + reader.skipValue(); + } + } + return null; + } catch (Exception e) { + LOG.debug("Could not pre-read _id for DecodeSession on {}; cycle detection may not apply", + classModel.getType().getSimpleName(), e); + return null; + } finally { + mark.reset(); + } + } + protected MorphiaInstanceCreator getInstanceCreator() { return classModel.getInstanceCreator(morphiaCodec.getConversions()); } diff --git a/core/src/main/java/dev/morphia/mapping/codec/pojo/MorphiaCodec.java b/core/src/main/java/dev/morphia/mapping/codec/pojo/MorphiaCodec.java index 752104e74a1..bf9eaba57de 100644 --- a/core/src/main/java/dev/morphia/mapping/codec/pojo/MorphiaCodec.java +++ b/core/src/main/java/dev/morphia/mapping/codec/pojo/MorphiaCodec.java @@ -7,6 +7,7 @@ import dev.morphia.mapping.DiscriminatorLookup; import dev.morphia.mapping.MappingException; import dev.morphia.mapping.codec.Conversions; +import dev.morphia.mapping.codec.DecodeSession; import dev.morphia.mapping.codec.PropertyCodecRegistryImpl; import dev.morphia.sofia.Sofia; @@ -77,7 +78,14 @@ public MorphiaCodec(MorphiaDatastore datastore, EntityModel model, @Override public T decode(BsonReader reader, DecoderContext decoderContext) { - return getDecoder().decode(reader, decoderContext); + boolean root = DecodeSession.activate(); + try { + return getDecoder().decode(reader, decoderContext); + } finally { + if (root) { + DecodeSession.deactivate(); + } + } } @Override diff --git a/core/src/main/java/dev/morphia/mapping/codec/references/ReferenceCodec.java b/core/src/main/java/dev/morphia/mapping/codec/references/ReferenceCodec.java index 508d93999fe..90b672c596b 100644 --- a/core/src/main/java/dev/morphia/mapping/codec/references/ReferenceCodec.java +++ b/core/src/main/java/dev/morphia/mapping/codec/references/ReferenceCodec.java @@ -27,6 +27,7 @@ import dev.morphia.annotations.internal.MorphiaInternal; import dev.morphia.mapping.Mapper; import dev.morphia.mapping.MappingException; +import dev.morphia.mapping.codec.DecodeSession; import dev.morphia.mapping.codec.pojo.EntityModel; import dev.morphia.mapping.codec.pojo.PropertyHandler; import dev.morphia.mapping.codec.pojo.PropertyModel; @@ -322,6 +323,64 @@ private Class makeProxy() { .getLoaded(); } + @Nullable + private Object lookupInSession(Object id, EntityModel entityModel) { + DecodeSession session = DecodeSession.current(); + if (session == null) { + return null; + } + String collection = id instanceof DBRef + ? ((DBRef) id).getCollectionName() + : entityModel.collectionName(); + Object lookupId = id instanceof DBRef ? ((DBRef) id).getId() : id; + return session.lookup(collection, lookupId); + } + + /** + * Returns a map of stripped-id → cached entity for each id in rawIds that is present + * in the current session. Ids not yet in the session are absent from the result. + */ + private Map partialSessionLookup(List rawIds, EntityModel entityModel) { + DecodeSession session = DecodeSession.current(); + if (session == null) { + return Map.of(); + } + Map hits = new LinkedHashMap<>(); + for (Object id : rawIds) { + String collection = id instanceof DBRef ? ((DBRef) id).getCollectionName() : entityModel.collectionName(); + Object lookupId = id instanceof DBRef ? ((DBRef) id).getId() : id; + Object cached = session.lookup(collection, lookupId); + if (cached != null) { + hits.put(lookupId, cached); + } + } + return hits; + } + + /** + * Returns a map of stripped-id → cached entity for each map value id that is present + * in the current session. Returns null if no session is active. + */ + @Nullable + private Map lookupMapInSession(Map ids, EntityModel entityModel) { + DecodeSession session = DecodeSession.current(); + if (session == null) { + return null; + } + Map result = new LinkedHashMap<>(); + for (Entry entry : ids.entrySet()) { + Object rawId = entry.getValue(); + String collection = rawId instanceof DBRef ? ((DBRef) rawId).getCollectionName() : entityModel.collectionName(); + Object lookupId = rawId instanceof DBRef ? ((DBRef) rawId).getId() : rawId; + Object cached = session.lookup(collection, lookupId); + if (cached == null) { + return null; // any miss: fall through to full DB fetch + } + result.put(entry.getKey(), cached); + } + return result; + } + @Nullable private Object fetch(Object value) { boolean lazy = annotation.lazy(); @@ -336,7 +395,7 @@ private Object fetch(Object value) { return preDecoded; } List ids = stripDbRefs(rawIds); - Supplier loader = () -> fetchCollection(rawIds, entityModel, ignoreMissing); + Supplier loader = () -> fetchCollectionMerged(rawIds, entityModel, ignoreMissing); return lazy ? createProxy(loader, ids, entityModel.getType()) : loader.get(); } else if (Set.class.isAssignableFrom(type)) { @@ -346,7 +405,7 @@ private Object fetch(Object value) { return new LinkedHashSet<>(preDecoded); } List ids = stripDbRefs(rawIds); - Supplier loader = () -> new LinkedHashSet<>(fetchCollection(rawIds, entityModel, ignoreMissing)); + Supplier loader = () -> new LinkedHashSet<>(fetchCollectionMerged(rawIds, entityModel, ignoreMissing)); return lazy ? createProxy(loader, ids, entityModel.getType()) : loader.get(); } else if (Map.class.isAssignableFrom(type)) { @@ -357,6 +416,10 @@ private Object fetch(Object value) { ids.put(mapper.getConversions().convert(entry.getKey(), keyType), entry.getValue()); } List idList = stripDbRefs(new ArrayList<>(ids.values())); + Map cachedMap = lookupMapInSession(ids, entityModel); + if (cachedMap != null) { + return cachedMap; + } Supplier loader = () -> fetchMap(ids, entityModel); return lazy ? createProxy(loader, idList, entityModel.getType()) : loader.get(); @@ -366,6 +429,10 @@ private Object fetch(Object value) { if (entityModel.getType().isInstance(id)) { return id; } + Object cached = lookupInSession(id, entityModel); + if (cached != null) { + return cached; + } List ids = List.of(stripDbRef(id)); Supplier loader = () -> fetchSingle(id, entityModel, ignoreMissing); return lazy ? createProxy(loader, ids, entityModel.getType()) : loader.get(); @@ -376,20 +443,40 @@ private Object fetchSingle(Object id, EntityModel entityModel, boolean ignoreMis var query = id instanceof DBRef ? datastore.find(mapper.getClassFromCollection(((DBRef) id).getCollectionName())) : datastore.find(entityModel.getType()); - Object result = query.filter(eq("_id", stripDbRef(id))).iterator().tryNext(); - if (result == null && !ignoreMissing) { - throw new ReferenceException(Sofia.missingReferencedEntity(entityModel.getType().getSimpleName())); + try (var cursor = query.filter(eq("_id", stripDbRef(id))).iterator()) { + Object result = cursor.tryNext(); + if (result == null && !ignoreMissing) { + throw new ReferenceException(Sofia.missingReferencedEntity(entityModel.getType().getSimpleName())); + } + return result; } - return result; } - private List fetchCollection(List ids, EntityModel entityModel, boolean ignoreMissing) { + private List fetchCollectionMerged(List rawIds, EntityModel entityModel, boolean ignoreMissing) { + Map idMap = new HashMap<>(partialSessionLookup(rawIds, entityModel)); + + List missingIds = rawIds.stream() + .filter(id -> { + Object lookupId = id instanceof DBRef ? ((DBRef) id).getId() : id; + return !idMap.containsKey(lookupId); + }) + .collect(Collectors.toList()); + + if (!missingIds.isEmpty()) { + idMap.putAll(buildIdMap(missingIds, entityModel, ignoreMissing)); + } + + return mapIdsToValues(rawIds, idMap).stream() + .filter(Objects::nonNull) + .collect(Collectors.toList()); + } + + private Map buildIdMap(List ids, EntityModel entityModel, boolean ignoreMissing) { Map> byCollection = new HashMap<>(); for (Object id : ids) { if (id instanceof DBRef) { byCollection.computeIfAbsent(((DBRef) id).getCollectionName(), k -> new ArrayList<>()).add(((DBRef) id).getId()); } else { - // nested List items are stored as-is; extractFlatIds flattens them for the query byCollection.computeIfAbsent(entityModel.collectionName(), k -> new ArrayList<>()).add(id); } } @@ -398,8 +485,11 @@ private List fetchCollection(List ids, EntityModel entityModel, boole for (Entry> entry : byCollection.entrySet()) { idMap.putAll(queryCollection(entry.getKey(), extractFlatIds(entry.getValue()), entityModel, ignoreMissing)); } + return idMap; + } - return mapIdsToValues(ids, idMap).stream() + private List fetchCollection(List ids, EntityModel entityModel, boolean ignoreMissing) { + return mapIdsToValues(ids, buildIdMap(ids, entityModel, ignoreMissing)).stream() .filter(Objects::nonNull) .collect(Collectors.toList()); } diff --git a/core/src/main/java/dev/morphia/query/MorphiaCursor.java b/core/src/main/java/dev/morphia/query/MorphiaCursor.java index a0a5efc99ed..9be0b17b649 100644 --- a/core/src/main/java/dev/morphia/query/MorphiaCursor.java +++ b/core/src/main/java/dev/morphia/query/MorphiaCursor.java @@ -9,6 +9,7 @@ import com.mongodb.lang.NonNull; import dev.morphia.annotations.internal.MorphiaInternal; +import dev.morphia.mapping.codec.DecodeSession; /** * @param the original type being iterated @@ -16,6 +17,7 @@ */ public class MorphiaCursor implements AutoCloseable, MongoCursor { private final MongoCursor wrapped; + private final boolean rootSession; /** * Creates a MorphiaCursor @@ -27,13 +29,20 @@ public class MorphiaCursor implements AutoCloseable, MongoCursor { @MorphiaInternal public MorphiaCursor(MongoCursor cursor) { wrapped = cursor; + rootSession = DecodeSession.activate(); } /** - * Closes the underlying cursor. + * Closes the underlying cursor and releases the decode session if this cursor owns it. */ public void close() { - wrapped.close(); + try { + wrapped.close(); + } finally { + if (rootSession) { + DecodeSession.deactivate(); + } + } } @Override @@ -80,7 +89,7 @@ public void remove() { */ public List toList() { final List results = new ArrayList<>(); - try (wrapped) { + try (this) { while (wrapped.hasNext()) { results.add(next()); } diff --git a/core/src/test/java/dev/morphia/test/mapping/TestReferences.java b/core/src/test/java/dev/morphia/test/mapping/TestReferences.java index d2e46092ee9..acabf684b42 100644 --- a/core/src/test/java/dev/morphia/test/mapping/TestReferences.java +++ b/core/src/test/java/dev/morphia/test/mapping/TestReferences.java @@ -1228,4 +1228,72 @@ public void setId(ObjectId id) { this.id = id; } } + + @Entity + private static class TwoRefContainer { + @Id + private ObjectId id; + @Reference(idOnly = true) + private Ref ref1; + @Reference(idOnly = true) + private Ref ref2; + } + + @Entity + private static class NodeA { + @Id + private ObjectId id = new ObjectId(); + private String name; + @Reference + private NodeB partner; + } + + @Entity + private static class NodeB { + @Id + private ObjectId id = new ObjectId(); + private String name; + @Reference + private NodeA partner; + } + + @Test + public void testReferenceDeduplication() { + // A single document with two @Reference fields pointing to the same entity. + // Both fields should decode to the same Java instance within one decode session. + Ref shared = new Ref("shared-ref"); + getDs().save(shared); + + TwoRefContainer container = new TwoRefContainer(); + container.ref1 = shared; + container.ref2 = shared; + getDs().save(container); + + TwoRefContainer loaded = getDs().find(TwoRefContainer.class).first(); + Assertions.assertNotNull(loaded); + Assertions.assertSame(loaded.ref1, loaded.ref2, "Both ref fields should point to the same Ref instance"); + } + + @Test + public void testCyclicReferenceDoesNotStackOverflow() { + NodeA a = new NodeA(); + a.name = "alpha"; + NodeB b = new NodeB(); + b.name = "beta"; + + getDs().save(a); + getDs().save(b); + + a.partner = b; + b.partner = a; + getDs().save(a); + getDs().save(b); + + NodeA loaded = getDs().find(NodeA.class).filter(eq("_id", a.id)).first(); + Assertions.assertNotNull(loaded); + Assertions.assertNotNull(loaded.partner); + Assertions.assertEquals("beta", loaded.partner.name); + Assertions.assertNotNull(loaded.partner.partner); + Assertions.assertEquals("alpha", loaded.partner.partner.name); + } }