【功能新增】AI:新增 document 向量的进度查询

This commit is contained in:
YunaiV
2025-03-02 20:54:02 +08:00
parent ebd93514b3
commit 5f5e77a392
20 changed files with 385 additions and 61 deletions

View File

@@ -99,9 +99,10 @@ public interface AiModelFactory {
* <p>
* 如果不存在,则进行创建
*
* @param type 向量存储类型
* @param embeddingModel 向量模型
* @return VectorStore 对象
*/
VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel);
VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type, EmbeddingModel embeddingModel);
}

View File

@@ -1,9 +1,11 @@
package cn.iocoder.yudao.framework.ai.core.factory;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.lang.Singleton;
import cn.hutool.core.lang.func.Func0;
import cn.hutool.core.util.ArrayUtil;
import cn.hutool.core.util.RuntimeUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.extra.spring.SpringUtil;
import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
@@ -24,6 +26,7 @@ import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingModel;
import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingOptions;
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageModel;
import com.azure.ai.openai.OpenAIClientBuilder;
import lombok.SneakyThrows;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties;
@@ -60,7 +63,11 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
import org.springframework.web.client.RestClient;
import java.io.File;
import java.time.Duration;
import java.util.List;
import java.util.Timer;
import java.util.TimerTask;
/**
* AI Model 模型工厂的实现类
@@ -73,7 +80,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
public ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(ChatModel.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
//noinspection EnhancedSwitchMigration
// noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return buildTongYiChatModel(apiKey);
@@ -105,7 +112,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
@Override
public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
// noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return SpringUtil.getBean(DashScopeChatModel.class);
@@ -136,7 +143,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
@Override
public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
// noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return SpringUtil.getBean(DashScopeImageModel.class);
@@ -155,7 +162,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
@Override
public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
//noinspection EnhancedSwitchMigration
// noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return buildTongYiImagesModel(apiKey);
@@ -174,9 +181,11 @@ public class AiModelFactoryImpl implements AiModelFactory {
@Override
public MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url) {
String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey, url);
String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey,
url);
return Singleton.get(cacheKey, (Func0<MidjourneyApi>) () -> {
YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class).getMidjourney();
YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class)
.getMidjourney();
return new MidjourneyApi(url, apiKey, properties.getNotifyUrl());
});
}
@@ -204,25 +213,31 @@ public class AiModelFactoryImpl implements AiModelFactory {
}
@Override
public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel) {
// String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
String cacheKey = buildClientCacheKey(VectorStore.class, embeddingModel);
public VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type, EmbeddingModel embeddingModel) {
// String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey,
// url);
String cacheKey = buildClientCacheKey(VectorStore.class, embeddingModel, type);
return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
if (type == SimpleVectorStore.class) {
return buildSimpleVectorStore(embeddingModel);
}
throw new IllegalArgumentException(StrUtil.format("未知类型({})", type));
// TODO @芋艿:先临时使用 store
return SimpleVectorStore.builder(embeddingModel).build();
// TODO @芋艿:@xin后续看看是不是切到阿里云之类的
// String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
// var config = RedisVectorStore.RedisVectorStoreConfig.builder()
// .withIndexName(cacheKey)
// .withPrefix(prefix)
// .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC))
// .build();
// RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
// RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
// new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
// true);
// redisVectorStore.afterPropertiesSet();
// return redisVectorStore;
// String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
// var config = RedisVectorStore.RedisVectorStoreConfig.builder()
// .withIndexName(cacheKey)
// .withPrefix(prefix)
// .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId",
// Schema.FieldType.NUMERIC))
// .build();
// RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
// RedisVectorStore redisVectorStore = new RedisVectorStore(config,
// embeddingModel,
// new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
// true);
// redisVectorStore.afterPropertiesSet();
// return redisVectorStore;
});
}
@@ -307,7 +322,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
*/
private ChatModel buildSiliconFlowChatModel(String apiKey) {
YudaoAiProperties.SiliconFlowProperties properties = new YudaoAiProperties.SiliconFlowProperties()
.setApiKey(apiKey);
.setApiKey(apiKey);
return new YudaoAiAutoConfiguration().buildSiliconFlowChatClient(properties);
}
@@ -397,7 +412,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
*/
private DashScopeEmbeddingModel buildTongYiEmbeddingModel(String apiKey, String model) {
DashScopeApi dashScopeApi = new DashScopeApi(apiKey);
DashScopeEmbeddingOptions dashScopeEmbeddingOptions = DashScopeEmbeddingOptions.builder().withModel(model).build();
DashScopeEmbeddingOptions dashScopeEmbeddingOptions = DashScopeEmbeddingOptions.builder().withModel(model)
.build();
return new DashScopeEmbeddingModel(dashScopeApi, MetadataMode.EMBED, dashScopeEmbeddingOptions);
}
@@ -407,4 +423,58 @@ public class AiModelFactoryImpl implements AiModelFactory {
return OllamaEmbeddingModel.builder().ollamaApi(ollamaApi).defaultOptions(ollamaOptions).build();
}
// ========== 各种创建 VectorStore 的方法 ==========
/**
* 注意:仅适合本地测试使用,生产建议还是使用 Qdrant、Milvus 等
*/
@SneakyThrows
@SuppressWarnings("ResultOfMethodCallIgnored")
private SimpleVectorStore buildSimpleVectorStore(EmbeddingModel embeddingModel) {
SimpleVectorStore vectorStore = SimpleVectorStore.builder(embeddingModel).build();
// 启动加载
File file = new File(StrUtil.format("{}/vector_store/simple_{}.json",
FileUtil.getUserHomePath(), embeddingModel.getClass().getSimpleName()));
if (!file.exists()) {
FileUtil.mkParentDirs(file);
file.createNewFile();
} else if (file.length() > 0) {
vectorStore.load(file);
}
// 定时持久化,每分钟一次
Timer timer = new Timer("SimpleVectorStoreTimer-" + file.getAbsolutePath());
timer.scheduleAtFixedRate(new TimerTask() {
@Override
public void run() {
vectorStore.save(file);
}
}, Duration.ofMinutes(1).toMillis(), Duration.ofMinutes(1).toMillis());
// 关闭时,进行持久化
RuntimeUtil.addShutdownHook(() -> vectorStore.save(file));
return vectorStore;
}
/**
* 创建向量存储文件
*
* @param embeddingModel 嵌入模型
* @return 向量存储文件
*/
private File createVectorStoreFile(EmbeddingModel embeddingModel) {
// 获取简单类名
String simpleClassName = embeddingModel.getClass().getSimpleName();
// 获取用户主目录
String userHome = FileUtil.getUserHomePath();
// 创建vector_store目录
File vectorStoreDir = new File(userHome, "vector_store");
if (!vectorStoreDir.exists()) {
vectorStoreDir.mkdirs();
}
// 创建文件
return new File(vectorStoreDir, "simple_" + simpleClassName + ".json");
}
}