@@ -14,12 +14,14 @@ import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessage
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO ;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO ;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO ;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegment DO ;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRole DO ;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO ;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper ;
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum ;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants ;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService ;
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO ;
import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO ;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService ;
import cn.iocoder.yudao.module.ai.service.model.AiModelService ;
import jakarta.annotation.Resource ;
import lombok.extern.slf4j.Slf4j ;
@@ -32,13 +34,13 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel ;
import org.springframework.ai.chat.prompt.ChatOptions ;
import org.springframework.ai.chat.prompt.Prompt ;
import org.springframework.ai.chat.prompt.PromptTemplate ;
import org.springframework.stereotype.Service ;
import org.springframework.transaction.annotation.Transactional ;
import reactor.core.publisher.Flux ;
import java.time.LocalDateTime ;
import java.util.* ;
import java.util.stream.Collectors ;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception ;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error ;
@@ -56,12 +58,21 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_N
@Slf4j
public class AiChatMessageServiceImpl implements AiChatMessageService {
/**
* 知识库转 {@link UserMessage} 的内容模版
*/
private static final String KNOWLEDGE_USER_MESSAGE_TEMPLATE = " 使用 <Reference></Reference> 标记中的内容作为本次对话的参考: \ n \ n " +
" %s \ n \ n " + // 多个 <Reference></Reference> 的拼接
" 回答要求: \ n- 避免提及你是从 <Reference></Reference> 获取的知识。 " ;
@Resource
private AiChatMessageMapper chatMessageMapper ;
@Resource
private AiChatConversationService chatConversationService ;
@Resource
private AiChatRoleService chatRoleService ;
@Resource
private AiModelService modalService ;
@Resource
private AiKnowledgeSegmentService knowledgeSegmentService ;
@@ -69,118 +80,143 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
@Transactional ( rollbackFor = Exception . class )
public AiChatMessageSendRespVO sendMessage ( AiChatMessageSendReqVO sendReqVO , Long userId ) {
// 1.1 校验对话存在
AiChatConversationDO conversation = chatConversationService . validateChatConversationExists ( sendReqVO . getConversationId ( ) ) ;
AiChatConversationDO conversation = chatConversationService
. validateChatConversationExists ( sendReqVO . getConversationId ( ) ) ;
if ( ObjUtil . notEqual ( conversation . getUserId ( ) , userId ) ) {
throw exception ( CHAT_CONVERSATION_NOT_EXISTS ) ;
}
List < AiChatMessageDO > historyMessages = chatMessageMapper . selectListByConversationId ( conversation . getId ( ) ) ;
// 1.2 校验模型
AiModelDO model = modalService . validateModel ( conversation . getModelId ( ) ) ;
ChatModel chatModel = modalService . getChatModel ( model . getKey Id ( ) ) ;
ChatModel chatModel = modalService . getChatModel ( model . getId ( ) ) ;
// 2. 插入 user 发送消息
// 2. 知识库找回
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments = recallKnowledgeSegment ( sendReqVO . getContent ( ) ,
conversation ) ;
// 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage ( conversation . getId ( ) , null , model ,
userId , conversation . getRoleId ( ) , MessageType . USER , sendReqVO . getContent ( ) , sendReqVO . getUseContext ( ) ) ;
userId , conversation . getRoleId ( ) , MessageType . USER , sendReqVO . getContent ( ) , sendReqVO . getUseContext ( ) ,
null ) ;
// 3.1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage ( conversation . getId ( ) , userMessage . getId ( ) , model ,
userId , conversation . getRoleId ( ) , MessageType . ASSISTANT , " " , sendReqVO . getUseContext ( ) ) ;
userId , conversation . getRoleId ( ) , MessageType . ASSISTANT , " " , sendReqVO . getUseContext ( ) ,
knowledgeSegments ) ;
// 3.2 召回段落
List < AiKnowledgeSegmentDO > segmentList = recallSegment ( sendReqVO . getContent ( ) , conversation . getKnowledgeId ( ) ) ;
// 3.3 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt ( conversation , historyMessages , segmentList , model , sendReqVO ) ;
// 3.2 创建 chat 需要的 Prompt
Prompt prompt = buildPrompt ( conversation , historyMessages , knowledgeSegments , model , sendReqVO ) ;
ChatResponse chatResponse = chatModel . call ( prompt ) ;
// 3.4 段式返回
// 3.3 段式返回
String newContent = chatResponse . getResult ( ) . getOutput ( ) . getText ( ) ;
chatMessageMapper . updateById ( new AiChatMessageDO ( ) . setId ( assistantMessage . getId ( ) ) . setSegmentIds ( convertList ( segmentList , AiKnowledgeSegmentDO : : getId ) ) . setContent( newContent ) ) ;
return new AiChatMessageSendRespVO ( ) . setSend ( BeanUtils . toBean ( userMessage , AiChatMessageSendRespVO . Message . class ) )
. setReceive ( BeanUtils . toBean ( assistant Message, AiChatMessageSendRespVO . Message . class ) . setContent ( newContent ) ) ;
chatMessageMapper . updateById ( new AiChatMessageDO ( ) . setId ( assistantMessage . getId ( ) ) . setContent ( newContent ) ) ;
return new AiChatMessageSendRespVO ( )
. setSend ( BeanUtils . toBean ( user Message, AiChatMessageSendRespVO . Message . class ) )
. setReceive ( BeanUtils . toBean ( assistantMessage , AiChatMessageSendRespVO . Message . class )
. setContent ( newContent ) ) ;
}
@Override
public Flux < CommonResult < AiChatMessageSendRespVO > > sendChatMessageStream ( AiChatMessageSendReqVO sendReqVO , Long userId ) {
public Flux < CommonResult < AiChatMessageSendRespVO > > sendChatMessageStream ( AiChatMessageSendReqVO sendReqVO ,
Long userId ) {
// 1.1 校验对话存在
AiChatConversationDO conversation = chatConversationService . validateChatConversationExists ( sendReqVO . getConversationId ( ) ) ;
AiChatConversationDO conversation = chatConversationService
. validateChatConversationExists ( sendReqVO . getConversationId ( ) ) ;
if ( ObjUtil . notEqual ( conversation . getUserId ( ) , userId ) ) {
throw exception ( CHAT_CONVERSATION_NOT_EXISTS ) ;
}
List < AiChatMessageDO > historyMessages = chatMessageMapper . selectListByConversationId ( conversation . getId ( ) ) ;
// 1.2 校验模型
AiModelDO model = modalService . validateModel ( conversation . getModelId ( ) ) ;
StreamingChatModel chatModel = modalService . getChatModel ( model . getKey Id ( ) ) ;
StreamingChatModel chatModel = modalService . getChatModel ( model . getId ( ) ) ;
// 2. 插入 user 发送消息
// 2. 知识库找回
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments = recallKnowledgeSegment ( sendReqVO . getContent ( ) ,
conversation ) ;
// 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage ( conversation . getId ( ) , null , model ,
userId , conversation . getRoleId ( ) , MessageType . USER , sendReqVO . getContent ( ) , sendReqVO . getUseContext ( ) ) ;
userId , conversation . getRoleId ( ) , MessageType . USER , sendReqVO . getContent ( ) , sendReqVO . getUseContext ( ) ,
null ) ;
// 3 .1 插入 assistant 接收消息
// 4 .1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage ( conversation . getId ( ) , userMessage . getId ( ) , model ,
userId , conversation . getRoleId ( ) , MessageType . ASSISTANT , " " , sendReqVO . getUseContext ( ) ) ;
userId , conversation . getRoleId ( ) , MessageType . ASSISTANT , " " , sendReqVO . getUseContext ( ) ,
knowledgeSegments ) ;
// 3 .2 召回段落
List < AiKnowledgeSegmentDO > segmentList = recallSegment ( sendReqVO . getContent ( ) , conversation . getKnowledgeId ( ) ) ;
// 3.3 构建 Prompt, 并进行调用
Prompt prompt = buildPrompt ( conversation , historyMessages , segmentList , model , sendReqVO ) ;
// 4 .2 构建 Prompt, 并进行调用
Prompt prompt = buildPrompt ( conversation , historyMessages , knowledgeSegments , model , sendReqVO ) ;
Flux < ChatResponse > streamResponse = chatModel . stream ( prompt ) ;
// 3.4 流式返回
// 4.3 流式返回
StringBuffer contentBuffer = new StringBuffer ( ) ;
return streamResponse . map ( chunk - > {
String newContent = chunk . getResult ( ) ! = null ? chunk . getResult ( ) . getOutput ( ) . getText ( ) : null ;
newContent = StrUtil . nullToDefault ( newContent , " " ) ; // 避免 null 的 情况
contentBuffer . append ( newContent ) ;
// 响应结果
return success ( new AiChatMessageSendRespVO ( ) . setSend ( BeanUtils . toBean ( userMessage , AiChatMessageSendRespVO . Message . class ) )
. setReceive ( BeanUtils . toBean ( assistant Message, AiChatMessageSendRespVO . Message . class ) . setContent ( newContent ) ) ) ;
return success ( new AiChatMessageSendRespVO ( )
. setSend ( BeanUtils . toBean ( user Message, AiChatMessageSendRespVO . Message . class ) )
. setReceive ( BeanUtils . toBean ( assistantMessage , AiChatMessageSendRespVO . Message . class )
. setContent ( newContent ) ) ) ;
} ) . doOnComplete ( ( ) - > {
// 忽略租户,因为 Flux 异步无法透传租户
TenantUtils . executeIgnore ( ( ) - >
chatMessageMapper . updateById ( new AiChatMessageDO ( ) . setId ( assistantMessage . getId ( ) ) . setSegmentIds ( convertList ( segmentList , AiKnowledgeSegmentDO : : getId ) )
. setContent ( contentBuffer . toString ( ) ) ) ) ;
TenantUtils . executeIgnore ( ( ) - > chatMessageMapper . updateById (
new AiChatMessageDO ( ) . setId ( assistantMessage . getId ( ) ) . setContent ( contentBuffer . toString ( ) ) ) ) ;
} ) . doOnError ( throwable - > {
log . error ( " [sendChatMessageStream][userId({}) sendReqVO({}) 发生异常] " , userId , sendReqVO , throwable ) ;
// 忽略租户,因为 Flux 异步无法透传租户
TenantUtils . executeIgnore ( ( ) - >
chatMessageMapper . updateById ( new AiChatMessageDO ( ) . setId ( assistantMessage . getId ( ) ) . setContent ( throwable . getMessage ( ) ) ) ) ;
TenantUtils . executeIgnore ( ( ) - > chatMessageMapper . updateById (
new AiChatMessageDO ( ) . setId ( assistantMessage . getId ( ) ) . setContent ( throwable . getMessage ( ) ) ) ) ;
} ) . onErrorResume ( error - > Flux . just ( error ( ErrorCodeConstants . CHAT_STREAM_ERROR ) ) ) ;
}
private List < AiKnowledgeSegmentD O > recallSegment ( String content , Long knowledgeId ) {
if ( Objects . isNull ( knowledgeId ) ) {
private List < AiKnowledgeSegmentSearchRespB O > recallKnowledge Segment ( String content ,
AiChatConversationDO conversation ) {
// 1. 查询聊天角色
if ( conversation = = null | | conversation . getRoleId ( ) = = null ) {
return Collections . emptyList ( ) ;
}
// return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content)) ;
return null ;
}
private Prompt buildPrompt ( AiChatConversationDO conversation , List < AiChatMessageDO > messages , List < AiKnowledgeSegmentDO > segmentList ,
AiModelDO model , AiChatMessageSendReqVO sendReqVO ) {
// 1. 构建 Prompt Message 列表
List < Message > chatMessages = new ArrayList < > ( ) ;
// 1.1 召回内容消息构建
if ( CollUtil . isNotEmpty ( segmentList ) ) {
PromptTemplate promptTemplate = new PromptTemplate ( AiChatRoleEnum . AI_KNOWLEDGE_ROLE . getSystemMessage ( ) ) ;
StringBuilder infoBuilder = StrUtil . builder ( ) ;
segmentList . forEach ( segment - > infoBuilder . append ( System . lineSeparator ( ) ) . append ( segment . getContent ( ) ) ) ;
Message message = promptTemplate . createMessage ( Map . of ( " info " , infoBuilder . toString ( ) ) ) ;
chatMessages . add ( message ) ;
AiChatRoleDO role = chatRoleService . getChatRole ( conversation . getRoleId ( ) ) ;
if ( role = = null | | CollUtil . isEmpty ( role . getKnowledgeIds ( ) ) ) {
return Collections . emptyList ( ) ;
}
// 1.2 system context 角色设定
// 2. 遍历找回
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments = new ArrayList < > ( ) ;
for ( Long knowledgeId : role . getKnowledgeIds ( ) ) {
knowledgeSegments . addAll ( knowledgeSegmentService . searchKnowledgeSegment ( new AiKnowledgeSegmentSearchReqBO ( )
. setKnowledgeId ( knowledgeId ) . setContent ( content ) ) ) ;
}
return knowledgeSegments ;
}
private Prompt buildPrompt ( AiChatConversationDO conversation , List < AiChatMessageDO > messages ,
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments ,
AiModelDO model , AiChatMessageSendReqVO sendReqVO ) {
List < Message > chatMessages = new ArrayList < > ( ) ;
// 1.1 System Context 角色设定
if ( StrUtil . isNotBlank ( conversation . getSystemMessage ( ) ) ) {
chatMessages . add ( new SystemMessage ( conversation . getSystemMessage ( ) ) ) ;
}
// 1.3 history message 历史消息
// 1.2 历史 history message 历史消息
List < AiChatMessageDO > contextMessages = filterContextMessages ( messages , conversation , sendReqVO ) ;
contextMessages . forEach ( message - > chatMessages . add ( AiUtils . buildMessage ( message . getType ( ) , message . getContent ( ) ) ) ) ;
// 1.4 user message 新发送消息
contextMessages
. forEach ( message - > chatMessages . add ( AiUtils . buildMessage ( message . getType ( ) , message . getContent ( ) ) ) ) ;
// 1.3 当前 user message 新发送消息
chatMessages . add ( new UserMessage ( sendReqVO . getContent ( ) ) ) ;
// 1.4 知识库,通过 UserMessage 实现
if ( CollUtil . isNotEmpty ( knowledgeSegments ) ) {
String reference = knowledgeSegments . stream ( )
. map ( segment - > " <Reference> \ n " + segment . getContent ( ) + " </Reference> " )
. collect ( Collectors . joining ( " \ n \ n " ) ) ;
chatMessages . add ( new UserMessage ( String . format ( KNOWLEDGE_USER_MESSAGE_TEMPLATE , reference ) ) ) ;
}
// 2. 构建 ChatOptions 对象
AiPlatformEnum platform = AiPlatformEnum . validatePlatform ( model . getPlatform ( ) ) ;
ChatOptions chatOptions = AiUtils . buildChatOptions ( platform , model . getModel ( ) ,
@@ -199,8 +235,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
* @return 消息上下文
*/
private List < AiChatMessageDO > filterContextMessages ( List < AiChatMessageDO > messages ,
AiChatConversationDO conversation ,
AiChatMessageSendReqVO sendReqVO ) {
AiChatConversationDO conversation ,
AiChatMessageSendReqVO sendReqVO ) {
if ( conversation . getMaxContexts ( ) = = null | | ObjUtil . notEqual ( sendReqVO . getUseContext ( ) , Boolean . TRUE ) ) {
return Collections . emptyList ( ) ;
}
@@ -211,7 +247,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
continue ;
}
AiChatMessageDO userMessage = CollUtil . get ( messages , i - 1 ) ;
if ( userMessage = = null | | ObjUtil . notEqual ( assistantMessage . getReplyId ( ) , userMessage . getId ( ) )
if ( userMessage = = null
| | ObjUtil . notEqual ( assistantMessage . getReplyId ( ) , userMessage . getId ( ) )
| | StrUtil . isEmpty ( assistantMessage . getContent ( ) ) ) {
continue ;
}
@@ -228,11 +265,13 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}
private AiChatMessageDO createChatMessage ( Long conversationId , Long replyId ,
AiModelDO model , Long userId , Long roleId ,
MessageType messageType , String content , Boolean useContext ) {
AiModelDO model , Long userId , Long roleId ,
MessageType messageType , String content , Boolean useContext ,
List < AiKnowledgeSegmentSearchRespBO > knowledgeSegments ) {
AiChatMessageDO message = new AiChatMessageDO ( ) . setConversationId ( conversationId ) . setReplyId ( replyId )
. setModel ( model . getModel ( ) ) . setModelId ( model . getId ( ) ) . setUserId ( userId ) . setRoleId ( roleId )
. setType ( messageType . getValue ( ) ) . setContent ( content ) . setUseContext ( useContext ) ;
. setType ( messageType . getValue ( ) ) . setContent ( content ) . setUseContext ( useContext )
. setSegmentIds ( convertList ( knowledgeSegments , AiKnowledgeSegmentSearchRespBO : : getId ) ) ;
message . setCreateTime ( LocalDateTime . now ( ) ) ;
chatMessageMapper . insert ( message ) ;
return message ;