【功能新增】AI:新增 document 向量的进度查询
This commit is contained in:
@@ -99,9 +99,10 @@ public interface AiModelFactory {
|
||||
* <p>
|
||||
* 如果不存在,则进行创建
|
||||
*
|
||||
* @param type 向量存储类型
|
||||
* @param embeddingModel 向量模型
|
||||
* @return VectorStore 对象
|
||||
*/
|
||||
VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel);
|
||||
VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type, EmbeddingModel embeddingModel);
|
||||
|
||||
}
|
||||
|
@@ -1,9 +1,11 @@
|
||||
package cn.iocoder.yudao.framework.ai.core.factory;
|
||||
|
||||
import cn.hutool.core.io.FileUtil;
|
||||
import cn.hutool.core.lang.Assert;
|
||||
import cn.hutool.core.lang.Singleton;
|
||||
import cn.hutool.core.lang.func.Func0;
|
||||
import cn.hutool.core.util.ArrayUtil;
|
||||
import cn.hutool.core.util.RuntimeUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.extra.spring.SpringUtil;
|
||||
import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
|
||||
@@ -24,6 +26,7 @@ import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingModel;
|
||||
import com.alibaba.cloud.ai.dashscope.embedding.DashScopeEmbeddingOptions;
|
||||
import com.alibaba.cloud.ai.dashscope.image.DashScopeImageModel;
|
||||
import com.azure.ai.openai.OpenAIClientBuilder;
|
||||
import lombok.SneakyThrows;
|
||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiChatProperties;
|
||||
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiConnectionProperties;
|
||||
@@ -60,7 +63,11 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
|
||||
import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi;
|
||||
import org.springframework.web.client.RestClient;
|
||||
|
||||
import java.io.File;
|
||||
import java.time.Duration;
|
||||
import java.util.List;
|
||||
import java.util.Timer;
|
||||
import java.util.TimerTask;
|
||||
|
||||
/**
|
||||
* AI Model 模型工厂的实现类
|
||||
@@ -73,7 +80,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
public ChatModel getOrCreateChatModel(AiPlatformEnum platform, String apiKey, String url) {
|
||||
String cacheKey = buildClientCacheKey(ChatModel.class, platform, apiKey, url);
|
||||
return Singleton.get(cacheKey, (Func0<ChatModel>) () -> {
|
||||
//noinspection EnhancedSwitchMigration
|
||||
// noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return buildTongYiChatModel(apiKey);
|
||||
@@ -105,7 +112,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
|
||||
@Override
|
||||
public ChatModel getDefaultChatModel(AiPlatformEnum platform) {
|
||||
//noinspection EnhancedSwitchMigration
|
||||
// noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return SpringUtil.getBean(DashScopeChatModel.class);
|
||||
@@ -136,7 +143,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
|
||||
@Override
|
||||
public ImageModel getDefaultImageModel(AiPlatformEnum platform) {
|
||||
//noinspection EnhancedSwitchMigration
|
||||
// noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return SpringUtil.getBean(DashScopeImageModel.class);
|
||||
@@ -155,7 +162,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
|
||||
@Override
|
||||
public ImageModel getOrCreateImageModel(AiPlatformEnum platform, String apiKey, String url) {
|
||||
//noinspection EnhancedSwitchMigration
|
||||
// noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
return buildTongYiImagesModel(apiKey);
|
||||
@@ -174,9 +181,11 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
|
||||
@Override
|
||||
public MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url) {
|
||||
String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey, url);
|
||||
String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey,
|
||||
url);
|
||||
return Singleton.get(cacheKey, (Func0<MidjourneyApi>) () -> {
|
||||
YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class).getMidjourney();
|
||||
YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class)
|
||||
.getMidjourney();
|
||||
return new MidjourneyApi(url, apiKey, properties.getNotifyUrl());
|
||||
});
|
||||
}
|
||||
@@ -204,25 +213,31 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorStore getOrCreateVectorStore(EmbeddingModel embeddingModel) {
|
||||
// String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey, url);
|
||||
String cacheKey = buildClientCacheKey(VectorStore.class, embeddingModel);
|
||||
public VectorStore getOrCreateVectorStore(Class<? extends VectorStore> type, EmbeddingModel embeddingModel) {
|
||||
// String cacheKey = buildClientCacheKey(VectorStore.class, platform, apiKey,
|
||||
// url);
|
||||
String cacheKey = buildClientCacheKey(VectorStore.class, embeddingModel, type);
|
||||
return Singleton.get(cacheKey, (Func0<VectorStore>) () -> {
|
||||
if (type == SimpleVectorStore.class) {
|
||||
return buildSimpleVectorStore(embeddingModel);
|
||||
}
|
||||
throw new IllegalArgumentException(StrUtil.format("未知类型({})", type));
|
||||
// TODO @芋艿:先临时使用 store
|
||||
return SimpleVectorStore.builder(embeddingModel).build();
|
||||
// TODO @芋艿:@xin:后续看看,是不是切到阿里云之类的
|
||||
// String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
|
||||
// var config = RedisVectorStore.RedisVectorStoreConfig.builder()
|
||||
// .withIndexName(cacheKey)
|
||||
// .withPrefix(prefix)
|
||||
// .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId", Schema.FieldType.NUMERIC))
|
||||
// .build();
|
||||
// RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
|
||||
// RedisVectorStore redisVectorStore = new RedisVectorStore(config, embeddingModel,
|
||||
// new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
|
||||
// true);
|
||||
// redisVectorStore.afterPropertiesSet();
|
||||
// return redisVectorStore;
|
||||
// String prefix = StrUtil.format("{}#{}:", platform.getPlatform(), apiKey);
|
||||
// var config = RedisVectorStore.RedisVectorStoreConfig.builder()
|
||||
// .withIndexName(cacheKey)
|
||||
// .withPrefix(prefix)
|
||||
// .withMetadataFields(new RedisVectorStore.MetadataField("knowledgeId",
|
||||
// Schema.FieldType.NUMERIC))
|
||||
// .build();
|
||||
// RedisProperties redisProperties = SpringUtils.getBean(RedisProperties.class);
|
||||
// RedisVectorStore redisVectorStore = new RedisVectorStore(config,
|
||||
// embeddingModel,
|
||||
// new JedisPooled(redisProperties.getHost(), redisProperties.getPort()),
|
||||
// true);
|
||||
// redisVectorStore.afterPropertiesSet();
|
||||
// return redisVectorStore;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -307,7 +322,7 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
*/
|
||||
private ChatModel buildSiliconFlowChatModel(String apiKey) {
|
||||
YudaoAiProperties.SiliconFlowProperties properties = new YudaoAiProperties.SiliconFlowProperties()
|
||||
.setApiKey(apiKey);
|
||||
.setApiKey(apiKey);
|
||||
return new YudaoAiAutoConfiguration().buildSiliconFlowChatClient(properties);
|
||||
}
|
||||
|
||||
@@ -397,7 +412,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
*/
|
||||
private DashScopeEmbeddingModel buildTongYiEmbeddingModel(String apiKey, String model) {
|
||||
DashScopeApi dashScopeApi = new DashScopeApi(apiKey);
|
||||
DashScopeEmbeddingOptions dashScopeEmbeddingOptions = DashScopeEmbeddingOptions.builder().withModel(model).build();
|
||||
DashScopeEmbeddingOptions dashScopeEmbeddingOptions = DashScopeEmbeddingOptions.builder().withModel(model)
|
||||
.build();
|
||||
return new DashScopeEmbeddingModel(dashScopeApi, MetadataMode.EMBED, dashScopeEmbeddingOptions);
|
||||
}
|
||||
|
||||
@@ -407,4 +423,58 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
return OllamaEmbeddingModel.builder().ollamaApi(ollamaApi).defaultOptions(ollamaOptions).build();
|
||||
}
|
||||
|
||||
// ========== 各种创建 VectorStore 的方法 ==========
|
||||
|
||||
/**
|
||||
* 注意:仅适合本地测试使用,生产建议还是使用 Qdrant、Milvus 等
|
||||
*/
|
||||
@SneakyThrows
|
||||
@SuppressWarnings("ResultOfMethodCallIgnored")
|
||||
private SimpleVectorStore buildSimpleVectorStore(EmbeddingModel embeddingModel) {
|
||||
SimpleVectorStore vectorStore = SimpleVectorStore.builder(embeddingModel).build();
|
||||
// 启动加载
|
||||
File file = new File(StrUtil.format("{}/vector_store/simple_{}.json",
|
||||
FileUtil.getUserHomePath(), embeddingModel.getClass().getSimpleName()));
|
||||
if (!file.exists()) {
|
||||
FileUtil.mkParentDirs(file);
|
||||
file.createNewFile();
|
||||
} else if (file.length() > 0) {
|
||||
vectorStore.load(file);
|
||||
}
|
||||
// 定时持久化,每分钟一次
|
||||
Timer timer = new Timer("SimpleVectorStoreTimer-" + file.getAbsolutePath());
|
||||
timer.scheduleAtFixedRate(new TimerTask() {
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
vectorStore.save(file);
|
||||
}
|
||||
|
||||
}, Duration.ofMinutes(1).toMillis(), Duration.ofMinutes(1).toMillis());
|
||||
// 关闭时,进行持久化
|
||||
RuntimeUtil.addShutdownHook(() -> vectorStore.save(file));
|
||||
return vectorStore;
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建向量存储文件
|
||||
*
|
||||
* @param embeddingModel 嵌入模型
|
||||
* @return 向量存储文件
|
||||
*/
|
||||
private File createVectorStoreFile(EmbeddingModel embeddingModel) {
|
||||
// 获取简单类名
|
||||
String simpleClassName = embeddingModel.getClass().getSimpleName();
|
||||
// 获取用户主目录
|
||||
String userHome = FileUtil.getUserHomePath();
|
||||
// 创建vector_store目录
|
||||
File vectorStoreDir = new File(userHome, "vector_store");
|
||||
if (!vectorStoreDir.exists()) {
|
||||
vectorStoreDir.mkdirs();
|
||||
}
|
||||
|
||||
// 创建文件
|
||||
return new File(vectorStoreDir, "simple_" + simpleClassName + ".json");
|
||||
}
|
||||
|
||||
}
|
||||
|
Reference in New Issue
Block a user