diff --git a/packages/mongo/src/index.ts b/packages/mongo/src/index.ts index 8c711642..dcea6880 100644 --- a/packages/mongo/src/index.ts +++ b/packages/mongo/src/index.ts @@ -15,6 +15,7 @@ export class MongoDriver extends Driver { private session?: ClientSession private _createTasks: Dict> = {} + private _ensureTask: Dict> = {} private connectionStringFromConfig() { const { @@ -349,15 +350,27 @@ export class MongoDriver extends Driver { const { primary, autoInc } = model if (typeof primary === 'string' && autoInc && model.fields[primary]?.type !== 'primary') { const missing = data.filter(item => !(primary in item)) - if (!missing.length) return - const doc = await this.db.collection('_fields').findOneAndUpdate( - { table, field: primary }, - { $inc: { autoInc: missing.length } }, - { session: this.session, upsert: true }, - ) - for (let i = 1; i <= missing.length; i++) { - missing[i - 1][primary] = (doc!.autoInc ?? 0) + i - } + const lastTask = Promise.resolve(this._ensureTask[table]).catch(noop) + return this._ensureTask[table] = lastTask.then(async () => { + if (!missing.length) { + return this.db.collection('_fields').updateOne( + { table, field: primary }, + [{ $set: { autoInc: { $max: ['$autoInc', Math.max(...data.map(x => x[primary]))] } } }], + { session: this.session, upsert: true }, + ).then(noop) + } + const doc = await this.db.collection('_fields').findOne({ table, field: primary }, { session: this.session }) + const exists = data.filter(item => item[primary] > (doc?.autoInc ?? 0)).map(x => x[primary]).sort() + let j = 0 + for (let i = 1; i <= missing.length; i++) { + while (exists[j] && (doc?.autoInc ?? 0) + i + j >= exists[j]) j++ + missing[i - 1][primary] = (doc?.autoInc ?? 0) + i + j + } + await this.db.collection('_fields').updateOne( + { table, field: primary }, + { $set: { autoInc: (doc?.autoInc ?? 0) + missing.length + j } }, + { session: this.session, upsert: true }) + }) } } diff --git a/packages/postgres/src/index.ts b/packages/postgres/src/index.ts index f3e62fd2..c1728421 100644 --- a/packages/postgres/src/index.ts +++ b/packages/postgres/src/index.ts @@ -1,6 +1,6 @@ import postgres from 'postgres' import { Dict, difference, isNullable, makeArray, pick } from 'cosmokit' -import { Driver, Eval, executeUpdate, Field, Selection, z } from 'minato' +import { Driver, Eval, executeUpdate, Field, randomId, Selection, z } from 'minato' import { isBracketed } from '@minatojs/sql-utils' import { formatTime, PostgresBuilder } from './builder' @@ -393,10 +393,17 @@ export class PostgresDriver extends Driver { const { table, model } = sel const builder = new PostgresBuilder(sel.tables) const formatted = builder.dump(model, data) + const keys = Object.keys(formatted) const [row] = await this.query([ `INSERT INTO ${builder.escapeId(table)} (${keys.map(builder.escapeId).join(', ')})`, - `VALUES (${keys.map(key => builder.escape(formatted[key])).join(', ')})`, + `VALUES (${keys.map(key => { + if (model.autoInc && key === model.primary) { + return `(select ${formatted[key]} from (select setval(pg_get_serial_sequence(${builder.escapeKey(table)}, ${builder.escapeKey(key)}), + greatest(nextval(pg_get_serial_sequence(${builder.escapeKey(table)}, ${builder.escapeKey(key)}))-1, ${formatted[key]}))) ${randomId()})` + } + return builder.escape(formatted[key]) + }).join(', ')})`, `RETURNING *`, ].join(' ')) return builder.load(model, row) @@ -436,7 +443,11 @@ export class PostgresDriver extends Driver { const formatValues = (table: string, data: object, keys: readonly string[]) => { return keys.map((key) => { const field = this.database.tables[table]?.fields[key] - if (model.autoInc && model.primary === key && !data[key]) return 'default' + if (model.autoInc && model.primary === key) { + return data[key] ? `(select ${data[key]} from (select setval(pg_get_serial_sequence(${builder.escapeKey(table)}, ${builder.escapeKey(key)}), + greatest(nextval(pg_get_serial_sequence(${builder.escapeKey(table)}, ${builder.escapeKey(key)}))-1, ${data[key]}))) ${randomId()})` + : 'default' + } return builder.escape(data[key], field) }).join(', ') }