Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 35 additions & 22 deletions src/AI/AI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@ import { logger } from "../logger";
import { checkRepeat, handleReply, MessageSegment, transformTextToArray } from "../utils/utils_string";
import { TimerManager } from "../timer";

export interface GroupInfo {
isPrivate: false;
id: string;
name: string;
}

export interface UserInfo {
isPrivate: true;
id: string;
name: string;
}

export type SessionInfo = GroupInfo | UserInfo;

export class Setting {
static validKeys: (keyof Setting)[] = ['priv', 'standby', 'counter', 'timer', 'prob', 'activeTimeInfo'];
priv: number;
Expand Down Expand Up @@ -147,7 +161,7 @@ export class AI {
const MaxRetry = 3;
for (let retry = 1; retry <= MaxRetry; retry++) {
// 处理messages
const messages = handleMessages(ctx, this);
const messages = await handleMessages(ctx, this);

//获取处理后的回复
const raw_reply = await sendChatRequest(ctx, msg, this, messages, "auto");
Expand Down Expand Up @@ -195,7 +209,7 @@ export class AI {

await this.stopCurrentChatStream();

const messages = handleMessages(ctx, this);
const messages = await handleMessages(ctx, this);
const id = await startStream(messages);
if (id === '') {
return;
Expand Down Expand Up @@ -337,7 +351,7 @@ export class AI {
}

// 若不在活动时间范围内,返回-1
getCurSegIndex(): number {
get curActiveTimeSegIndex(): number {
const now = new Date();
const cur = now.getHours() * 60 + now.getMinutes();
const { start, end, segs } = this.setting.activeTimeInfo;
Expand Down Expand Up @@ -378,7 +392,7 @@ export class AI {
if (segs !== 0 && (start !== 0 || end !== 0)) {
const timers = TimerManager.getTimers(this.id, '', ['activeTime']);
if (timers.length === 0) {
const curSegIndex = this.getCurSegIndex();
const curSegIndex = this.curActiveTimeSegIndex;
const nextTimePoint = this.getNextTimePoint(curSegIndex);
if (nextTimePoint !== -1) {
TimerManager.addActiveTimeTimer(ctx, msg, this, nextTimePoint);
Expand All @@ -390,17 +404,25 @@ export class AI {
}
}

export interface UsageInfo {
prompt_tokens: number,
completion_tokens: number
}

export class AIManager {
static version = "1.0.1";
static cache: { [key: string]: AI } = {};
static usageMap: {
[key: string]: { // 模型名
[key: number]: { // 年月日
prompt_tokens: number,
completion_tokens: number
static usageMapCache: { [model: string]: { [time: number]: UsageInfo } } = null;

static get usageMap(): { [model: string]: { [time: number]: UsageInfo } } {
if (!this.usageMapCache) {
try {
this.usageMapCache = JSON.parse(ConfigManager.ext.storageGet('usageMap') || '{}');
} catch (error) {
logger.error(`从数据库中获取usageMap失败:`, error);
}
}
} = {};
return this.usageMapCache;
}

static clearCache() {
this.cache = {};
Expand Down Expand Up @@ -458,7 +480,7 @@ export class AIManager {
}

static clearUsageMap() {
this.usageMap = {};
this.usageMapCache = {};
}

static clearExpiredUsage(model: string) {
Expand Down Expand Up @@ -504,17 +526,8 @@ export class AIManager {
}
}

static getUsageMap() {
try {
const usage = JSON.parse(ConfigManager.ext.storageGet('usageMap') || '{}');
this.usageMap = usage;
} catch (error) {
logger.error(`从数据库中获取usageMap失败:`, error);
}
}

static saveUsageMap() {
ConfigManager.ext.storageSet('usageMap', JSON.stringify(this.usageMap));
ConfigManager.ext.storageSet('usageMap', JSON.stringify(this.usageMapCache));
}

static updateUsage(model: string, usage: {
Expand Down
65 changes: 28 additions & 37 deletions src/AI/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,11 @@ import { ConfigManager } from "../config/config";
import { Image } from "./image";
import { createCtx, createMsg } from "../utils/utils_seal";
import { levenshteinDistance, MessageSegment } from "../utils/utils_string";
import { AI, AIManager } from "./AI";
import { AI, AIManager, UserInfo } from "./AI";
import { logger } from "../logger";
import { transformMsgId } from "../utils/utils";
import { getGroupMemberInfo, getStrangerInfo } from "../utils/utils_ob11";

export interface UserNameInfo { // 用于上下文名字修改相关操作
uid: string;
name: string;
}

export interface MessageInfo {
msgId: string;
time: number; // 秒
Expand Down Expand Up @@ -86,7 +81,7 @@ export class Context {
}

async addMessage(ctx: seal.MsgContext, msg: seal.Message, ai: AI, messageArray: MessageSegment[], images: Image[], role: 'user' | 'assistant', msgId: string = '') {
const { showNumber, showMsgId, maxRounds } = ConfigManager.message;
const { showNumber, showMsgId } = ConfigManager.message;
const { isShortMemory, shortMemorySummaryRound } = ConfigManager.memory;
const messages = this.messages;

Expand Down Expand Up @@ -129,11 +124,11 @@ export class Context {
return;
}

const now = Math.floor(Date.now() / 1000);
const uid = role == 'user' ? ctx.player.userId : ctx.endPoint.userId;

// 自动更新上下文里的名字
const exists = messages.some(message => message.uid === uid);
if (!exists) {
// 自动更新上下文里的名字,发言时间一小时内不更新
if (!messages.some(message => message.uid === uid && message.msgArray.some(msgInfo => msgInfo.time >= now - 3600))) {
await this.updateName(ctx.endPoint.userId, ctx.group.groupId, uid);
}

Expand Down Expand Up @@ -167,7 +162,7 @@ export class Context {
messages[length - 1].images.push(...images);
messages[length - 1].msgArray.push({
msgId: msgId,
time: Math.floor(Date.now() / 1000),
time: now,
content: s
});
} else {
Expand All @@ -178,7 +173,7 @@ export class Context {
images: images,
msgArray: [{
msgId: msgId,
time: Math.floor(Date.now() / 1000),
time: now,
content: s
}]
};
Expand All @@ -197,10 +192,10 @@ export class Context {
}

//更新记忆权重
ai.memory.updateMemoryWeight(ctx, ai.context, s, role);
ai.memory.updateRelatedMemoryWeight(ctx, ai.context, s, role);

//删除多余的上下文
this.limitMessages(maxRounds);
this.limitMessages();
}

async addToolCallsMessage(tool_calls: ToolCall[]) {
Expand All @@ -216,6 +211,7 @@ export class Context {
}

async addToolMessage(tool_call_id: string, s: string, images: Image[]) {
const now = Math.floor(Date.now() / 1000);
const message: Message = {
role: 'tool',
tool_call_id: tool_call_id,
Expand All @@ -224,7 +220,7 @@ export class Context {
images: images,
msgArray: [{
msgId: '',
time: Math.floor(Date.now() / 1000),
time: now,
content: s
}]
};
Expand All @@ -240,21 +236,23 @@ export class Context {
}

async addSystemUserMessage(name: string, s: string, images: Image[]) {
const now = Math.floor(Date.now() / 1000);
const message: Message = {
role: 'user',
uid: '',
name: `_${name}`,
images: images,
msgArray: [{
msgId: '',
time: Math.floor(Date.now() / 1000),
time: now,
content: s
}]
};
this.messages.push(message);
}

limitMessages(maxRounds: number) {
limitMessages() {
const { maxRounds } = ConfigManager.message;
const messages = this.messages;
let round = 0;
for (let i = messages.length - 1; i >= 0; i--) {
Expand Down Expand Up @@ -382,17 +380,14 @@ export class Context {
continue;
}

const ai = AIManager.getAI(uid);
const memoryList = Object.values(ai.memory.memoryMap);

for (const m of memoryList) {
if (m.group.groupName === groupName) {
return m.group.groupId;
for (const m of AIManager.getAI(uid).memory.memoryList) {
if (m.sessionInfo.isPrivate && m.sessionInfo.name === groupName) {
return m.sessionInfo.id;
}
if (m.group.groupName.length > 4) {
const distance = levenshteinDistance(groupName, m.group.groupName);
if (m.sessionInfo.isPrivate && m.sessionInfo.name.length > 4) {
const distance = levenshteinDistance(groupName, m.sessionInfo.name);
if (distance <= 2) {
return m.group.groupId;
return m.sessionInfo.id;
}
}
}
Expand Down Expand Up @@ -423,13 +418,14 @@ export class Context {
return null;
}

getUserNameInfo(): UserNameInfo[] {
const userMap: { [key: string]: UserNameInfo } = {};
get userInfoList(): UserInfo[] {
const userMap: { [key: string]: UserInfo } = {};
this.messages.forEach(message => {
if (message.role === 'user' && message.name && message.uid && !message.name.startsWith('_')) {
userMap[message.uid] = {
name: message.name,
uid: message.uid,
isPrivate: true,
id: message.uid,
name: message.name
};
}
});
Expand Down Expand Up @@ -530,15 +526,10 @@ export class Context {

const { localImagePaths } = ConfigManager.image;
const localImages: { [key: string]: string } = localImagePaths.reduce((acc: { [key: string]: string }, path: string) => {
if (path.trim() === '') {
return acc;
}
if (path.trim() === '') return acc;
try {
const name = path.split('/').pop().replace(/\.[^/.]+$/, '');
if (!name) {
throw new Error(`本地图片路径格式错误:${path}`);
}

if (!name) throw new Error(`本地图片路径格式错误:${path}`);
acc[name] = path;
} catch (e) {
logger.error(e);
Expand Down
Loading