From 7943a770ddfac23fc6a0a47020924e608d28d15f Mon Sep 17 00:00:00 2001 From: aravindcz Date: Mon, 28 Aug 2023 13:08:25 +0530 Subject: [PATCH 1/2] feat: Add the fluent api for creating collection This commit is made to discuss with the team regarding the approach with which we need to design the fluent api and thereby work upon those ideas Author: Aravind C Date: 28/08/23 --- .../java/tech/amikos/chromadb/Client.java | 60 +++++++++++++++---- .../java/tech/amikos/chromadb/Collection.java | 39 +++++++++++- 2 files changed, 86 insertions(+), 13 deletions(-) diff --git a/src/main/java/tech/amikos/chromadb/Client.java b/src/main/java/tech/amikos/chromadb/Client.java index da2f153..54894e2 100644 --- a/src/main/java/tech/amikos/chromadb/Client.java +++ b/src/main/java/tech/amikos/chromadb/Client.java @@ -19,28 +19,49 @@ public class Client { DefaultApi api; - public Client(String basePath) { + private Client(String basePath) { apiClient.setBasePath(basePath); api = new DefaultApi(apiClient); } - public Collection getCollection(String collectionName, EmbeddingFunction embeddingFunction) throws ApiException { - return new Collection(api, collectionName, embeddingFunction).fetch(); + public static Client newClient(String basePath){ + return new Client(basePath); + } + + public Boolean reset() throws ApiException { + return api.reset(); + } + + public String version() throws ApiException { + return api.version(); } public Map heartbeat() throws ApiException { return api.heartbeat(); } + public List listCollections() throws ApiException { + List apiResponse = (List) api.listCollections(); + return apiResponse.stream().map((LinkedTreeMap m) -> { + try { + return getCollection((String) m.get("name"), null); + } catch (ApiException e) { + e.printStackTrace(); //this is not great as we're swallowing the exception + } + return null; + }).collect(Collectors.toList()); + } + + + public Collection newCollection(){ + return new Collection(api,this,null,null); + } + + public 
Collection createCollection(String collectionName, Map metadata, Boolean createOrGet, EmbeddingFunction embeddingFunction) throws ApiException { return this.createCollection(collectionName, metadata, createOrGet, embeddingFunction, DistanceFunction.L2); } - public static enum DistanceFunction { - L2, - COSINE, - IP - } public Collection createCollection(String collectionName, Map metadata, Boolean createOrGet, EmbeddingFunction embeddingFunction, DistanceFunction distanceFunction) throws ApiException { CreateCollection req = new CreateCollection(); @@ -58,6 +79,25 @@ public Collection createCollection(String collectionName, Map me return new Collection(api, (String) resp.get("name"), embeddingFunction).fetch(); } + + + + + public Collection getCollection(String collectionName, EmbeddingFunction embeddingFunction) throws ApiException { + return new Collection(api, collectionName, embeddingFunction).fetch(); + } + + + + + + public static enum DistanceFunction { + L2, + COSINE, + IP + } + + public Collection deleteCollection(String collectionName) throws ApiException { Collection collection = Collection.getInstance(api, collectionName); api.deleteCollection(collectionName); @@ -70,9 +110,7 @@ public Collection upsert(String collectionName, EmbeddingFunction ef) throws Api return collection; } - public Boolean reset() throws ApiException { - return api.reset(); - } + public List listCollections() throws ApiException { List apiResponse = (List) api.listCollections(); diff --git a/src/main/java/tech/amikos/chromadb/Collection.java b/src/main/java/tech/amikos/chromadb/Collection.java index 0fb76d5..8b83201 100644 --- a/src/main/java/tech/amikos/chromadb/Collection.java +++ b/src/main/java/tech/amikos/chromadb/Collection.java @@ -14,6 +14,8 @@ public class Collection { static Gson gson = new Gson(); DefaultApi api; + + Client client; String collectionName; String collectionId; @@ -22,13 +24,29 @@ public class Collection { private EmbeddingFunction embeddingFunction; - 
public Collection(DefaultApi api, String collectionName, EmbeddingFunction embeddingFunction) { + + public Collection(DefaultApi api,String collectionName, EmbeddingFunction embeddingFunction) { this.api = api; this.collectionName = collectionName; this.embeddingFunction = embeddingFunction; } + public Collection(DefaultApi api,Client client, String collectionName, EmbeddingFunction embeddingFunction) { + this.api = api; + this.client = client; + this.collectionName = collectionName; + this.embeddingFunction = embeddingFunction; + + } + + + + public Collection name(String collectionName){ + this.collectionName = collectionName; + return this; + } + public String getName() { return collectionName; } @@ -37,10 +55,27 @@ public String getId() { return collectionId; } + public Collection metadata(String key,Object value){ + metadata.put(key,value); + return this; + } public Map getMetadata() { return metadata; } + public Collection ef(EmbeddingFunction embeddingFunction){ + this.embeddingFunction = embeddingFunction; + return this; + } + + public Collection createOrGet(){ + return client.createCollection(this.collectionName,this.metadata,true,this.embeddingFunction); + } + + public Collection create(){ + return client.createCollection(this.collectionName,this.metadata,false,this.embeddingFunction); + } + public Collection fetch() throws ApiException { try { LinkedTreeMap resp = (LinkedTreeMap) api.getCollection(collectionName); @@ -54,7 +89,7 @@ public Collection fetch() throws ApiException { } public static Collection getInstance(DefaultApi api, String collectionName) throws ApiException { - return new Collection(api, collectionName, null); + return new Collection(api,collectionName, null); } @Override From 6107a631fb7ae15a25f12cd8d08fe82b907a8b63 Mon Sep 17 00:00:00 2001 From: aravindcz Date: Wed, 13 Sep 2023 00:14:06 +0530 Subject: [PATCH 2/2] feat: Add fluent api support --- .../java/tech/amikos/chromadb/Client.java | 37 ++---- 
.../java/tech/amikos/chromadb/Collection.java | 106 ++++++++++-------- .../java/tech/amikos/chromadb/Embedding.java | 99 ++++++++++++++++ src/main/java/tech/amikos/chromadb/Query.java | 78 +++++++++++++ src/main/resources/openapi/api.yaml | 2 +- 5 files changed, 250 insertions(+), 72 deletions(-) create mode 100644 src/main/java/tech/amikos/chromadb/Embedding.java create mode 100644 src/main/java/tech/amikos/chromadb/Query.java diff --git a/src/main/java/tech/amikos/chromadb/Client.java b/src/main/java/tech/amikos/chromadb/Client.java index 54894e2..792d30d 100644 --- a/src/main/java/tech/amikos/chromadb/Client.java +++ b/src/main/java/tech/amikos/chromadb/Client.java @@ -24,6 +24,13 @@ private Client(String basePath) { api = new DefaultApi(apiClient); } + + public static enum DistanceFunction { + L2, + COSINE, + IP + } + public static Client newClient(String basePath){ return new Client(basePath); } @@ -89,12 +96,10 @@ public Collection getCollection(String collectionName, EmbeddingFunction embeddi - - - public static enum DistanceFunction { - L2, - COSINE, - IP + public Collection upsertCollection(String collectionName, EmbeddingFunction ef) throws ApiException { + Collection collection = getCollection(collectionName, ef); +// collection.upsert(); + return collection; } @@ -104,27 +109,7 @@ public Collection deleteCollection(String collectionName) throws ApiException { return collection; } - public Collection upsert(String collectionName, EmbeddingFunction ef) throws ApiException { - Collection collection = getCollection(collectionName, ef); -// collection.upsert(); - return collection; - } - public List listCollections() throws ApiException { - List apiResponse = (List) api.listCollections(); - return apiResponse.stream().map((LinkedTreeMap m) -> { - try { - return getCollection((String) m.get("name"), null); - } catch (ApiException e) { - e.printStackTrace(); //this is not great as we're swallowing the exception - } - return null; - 
}).collect(Collectors.toList()); - } - - public String version() throws ApiException { - return api.version(); - } } diff --git a/src/main/java/tech/amikos/chromadb/Collection.java b/src/main/java/tech/amikos/chromadb/Collection.java index 8b83201..75f9527 100644 --- a/src/main/java/tech/amikos/chromadb/Collection.java +++ b/src/main/java/tech/amikos/chromadb/Collection.java @@ -55,7 +55,7 @@ public String getId() { return collectionId; } - public Collection metadata(String key,Object value){ + public Collection metadata(String key,String value){ metadata.put(key,value); return this; } @@ -88,6 +88,32 @@ public Collection fetch() throws ApiException { } } + public Object update(){ + return this.update(this.collectionName,this.metadata); + } + public Object update(String newName, Map newMetadata) throws ApiException { + UpdateCollection req = new UpdateCollection(); + if (newName != null) { + req.setNewName(newName); + } + if (newMetadata != null && embeddingFunction != null) { + if (!newMetadata.containsKey("embedding_function")) { + newMetadata.put("embedding_function", embeddingFunction.getClass().getName()); + } + req.setNewMetadata(newMetadata); + } + Object resp = api.updateCollection(req, this.collectionId); + this.collectionName = newName; + this.fetch(); //do we really need to fetch? 
+ return resp; + } + + public Collection remove(){ + return client.deleteCollection(this.collectionName); + } + + + public static Collection getInstance(DefaultApi api, String collectionName) throws ApiException { return new Collection(api,collectionName, null); } @@ -101,6 +127,29 @@ public String toString() { '}'; } + + public Embedding newEmbedding(){ + return new Embedding(this); + } + + public Object add(List> embeddings, List> metadatas, List documents, List ids) throws ApiException { + AddEmbedding req = new AddEmbedding(); + List> _embeddings = embeddings; + if (_embeddings == null) { + _embeddings = this.embeddingFunction.createEmbedding(documents); + } + req.setEmbeddings((List) (Object) _embeddings); + req.setMetadatas((List>) (Object) metadatas); + req.setDocuments(documents); + req.incrementIndex(true); + req.setIds(ids); + return api.add(req, this.collectionId); + } + + public GetResult get() throws ApiException { + return this.get(null, null, null); + } + public GetResult get(List ids, Map where, Map whereDocument) throws ApiException { GetEmbedding req = new GetEmbedding(); req.ids(ids).where(where).whereDocument(whereDocument); @@ -109,13 +158,6 @@ public GetResult get(List ids, Map where, Map> embeddings, List> metadatas, List documents, List ids) throws ApiException { AddEmbedding req = new AddEmbedding(); @@ -131,23 +173,21 @@ public Object upsert(List> embeddings, List> met return api.upsert(req, this.collectionId); } - - public Object add(List> embeddings, List> metadatas, List documents, List ids) throws ApiException { - AddEmbedding req = new AddEmbedding(); + public Object updateEmbeddings(List> embeddings, List> metadatas, List documents, List ids) throws ApiException { + UpdateEmbedding req = new UpdateEmbedding(); List> _embeddings = embeddings; if (_embeddings == null) { _embeddings = this.embeddingFunction.createEmbedding(documents); } req.setEmbeddings((List) (Object) _embeddings); - req.setMetadatas((List>) (Object) metadatas); 
req.setDocuments(documents); - req.incrementIndex(true); + req.setMetadatas((List) (Object) metadatas); req.setIds(ids); - return api.add(req, this.collectionId); + return api.update(req, this.collectionId); } - public Integer count() throws ApiException { - return api.count(this.collectionId); + public Object delete() throws ApiException { + return this.delete(null, null, null); } public Object delete(List ids, Map where, Map whereDocument) throws ApiException { @@ -178,42 +218,18 @@ public Object deleteWhereDocuments(Map whereDocument) throws Api return delete(null, null, whereDocument); } + public Integer count() throws ApiException { + return api.count(this.collectionId); + } @Deprecated public Boolean createIndex() throws ApiException { return (Boolean) api.createIndex(this.collectionId); } - public Object update(String newName, Map newMetadata) throws ApiException { - UpdateCollection req = new UpdateCollection(); - if (newName != null) { - req.setNewName(newName); - } - if (newMetadata != null && embeddingFunction != null) { - if (!newMetadata.containsKey("embedding_function")) { - newMetadata.put("embedding_function", embeddingFunction.getClass().getName()); - } - req.setNewMetadata(newMetadata); - } - Object resp = api.updateCollection(req, this.collectionId); - this.collectionName = newName; - this.fetch(); //do we really need to fetch? 
- return resp; - } - - public Object updateEmbeddings(List> embeddings, List> metadatas, List documents, List ids) throws ApiException { - UpdateEmbedding req = new UpdateEmbedding(); - List> _embeddings = embeddings; - if (_embeddings == null) { - _embeddings = this.embeddingFunction.createEmbedding(documents); - } - req.setEmbeddings((List) (Object) _embeddings); - req.setDocuments(documents); - req.setMetadatas((List) (Object) metadatas); - req.setIds(ids); - return api.update(req, this.collectionId); + public Query newQuery(){ + return new Query(); } - public QueryResponse query(List queryTexts, Integer nResults, Map where, Map whereDocument, List include) throws ApiException { QueryEmbedding body = new QueryEmbedding(); body.queryEmbeddings((List) (Object) this.embeddingFunction.createEmbedding(queryTexts)); diff --git a/src/main/java/tech/amikos/chromadb/Embedding.java b/src/main/java/tech/amikos/chromadb/Embedding.java new file mode 100644 index 0000000..d349927 --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/Embedding.java @@ -0,0 +1,99 @@ +package tech.amikos.chromadb; + +import com.google.gson.annotations.SerializedName; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class Embedding { + + + private Collection collection; + private List> embeddings = null; + private List> metadatas = null; + private List documents = null; + private List ids = null; + private Map wheres; + private Map whereDocuments; + + public Embedding(Collection collection){ + this.collection = collection; + } + + + public Embedding id(String id){ + if(this.ids == null) + this.ids = new ArrayList<>(); + this.ids.add(id); + return this; + } + + public Embedding metadata(String key,String value){ + if(this.metadatas == null) + metadatas = new ArrayList<>(); + if(this.metadatas.get(0) == null) + metadatas.add(new HashMap<>()); + metadatas.get(metadatas.size()-1).put(key,value); + return this; + } + + public Embedding 
document(String document){ + if(this.documents == null) + this.documents = new ArrayList<>(); + documents.add(document); + return this; + } + + public Embedding embedding(Float embedding){ + if(this.embeddings == null) + this.embeddings = new ArrayList<>(); + if(this.embeddings.get(0) == null) + this.embeddings.add(new ArrayList<>()); + this.embeddings.get(embeddings.size()-1).add(embedding); + return this; + } + + public Object add(){ + return this.collection.add(embeddings,metadatas,documents,ids); + } + + public Object batchAdd(){ + return this.collection.add(embeddings,metadatas,documents,ids); + } + + public Embedding where(String key,String value){ + if(this.wheres == null) + wheres = new HashMap<>(); + wheres.put(key,value); + return this; + } + + public Embedding whereDocument(String key,Object value){ + if(this.whereDocuments == null) + whereDocuments = new HashMap<>(); + whereDocuments.put(key,value); + return this; + } + + public Collection.GetResult get(){ + if(ids == null) + return this.collection.get(); + return this.collection.get(ids,wheres,whereDocuments); + } + + public Object upsert(){ + return this.collection.upsert(embeddings,metadatas,documents,ids); + } + + public Object update(){ + return this.collection.updateEmbeddings(embeddings,metadatas,documents,ids); + } + + public Object delete(){ + return this.collection.delete(ids,wheres,whereDocuments); + } + + +} diff --git a/src/main/java/tech/amikos/chromadb/Query.java b/src/main/java/tech/amikos/chromadb/Query.java new file mode 100644 index 0000000..6ab5fa5 --- /dev/null +++ b/src/main/java/tech/amikos/chromadb/Query.java @@ -0,0 +1,78 @@ +package tech.amikos.chromadb; + +import com.google.gson.annotations.SerializedName; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class Query { + + private Collection collection; + private List queryTexts; + private Integer nResults; + private Map wheres; + private Map whereDocuments; + 
private List includes; + public Query(){ + } + + public Query queryText(String queryText){ + if(this.queryTexts == null) + this.queryTexts = new ArrayList<>(); + this.queryTexts.add(queryText); + return this; + } + + public Query nResults(Integer nResults){ + this.nResults = nResults; + return this; + } + + public Query where(String key,String value){ + if(this.wheres == null) + wheres = new HashMap<>(); + wheres.put(key,value); + return this; + } + + public Query whereDocument(String key,String value){ + if(this.whereDocuments == null) + whereDocuments = new HashMap<>(); + whereDocuments.put(key,value); + return this; + } + + public Query includeDocuments(){ + if(this.includes == null) + this.includes = new ArrayList(); + this.includes.add(QueryEmbedding.IncludeEnum.DOCUMENTS); + return this; + } + + public Query includeEmbeddings(){ + if(this.includes == null) + this.includes = new ArrayList(); + this.includes.add(QueryEmbedding.IncludeEnum.EMBEDDINGS); + return this; + } + + public Query includeMetadatas(){ + if(this.includes == null) + this.includes = new ArrayList(); + this.includes.add(QueryEmbedding.IncludeEnum.METADATAS); + return this; + } + + public Query includeDistances(){ + if(this.includes == null) + this.includes = new ArrayList(); + this.includes.add(QueryEmbedding.IncludeEnum.DISTANCES); + return this; + } + + public Collection.QueryResponse query(){ + return this.collection.query(queryTexts,nResults,wheres,whereDocuments,includes); + } +} diff --git a/src/main/resources/openapi/api.yaml b/src/main/resources/openapi/api.yaml index ef5e6fc..e48cf5f 100644 --- a/src/main/resources/openapi/api.yaml +++ b/src/main/resources/openapi/api.yaml @@ -63,7 +63,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/RawSql' + $ref: e'#/components/schemas/RawSql' required: true responses: '200':