Merge remote-tracking branch 'yd_origin/feature/ai' into feature/ai
This commit is contained in:
@@ -61,4 +61,8 @@ public interface ErrorCodeConstants {
|
||||
ErrorCode TOOL_NOT_EXISTS = new ErrorCode(1_040_010_000, "工具不存在");
|
||||
ErrorCode TOOL_NAME_NOT_EXISTS = new ErrorCode(1_040_010_001, "工具({})找不到 Bean");
|
||||
|
||||
// ========== AI 工作流 1-040-011-000 ==========
|
||||
ErrorCode WORKFLOW_NOT_EXISTS = new ErrorCode(1_040_011_000, "工作流不存在");
|
||||
ErrorCode WORKFLOW_CODE_EXISTS = new ErrorCode(1_040_011_001, "工作流标识已存在");
|
||||
|
||||
}
|
||||
|
@@ -0,0 +1,77 @@
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.workflow;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.workflow.vo.*;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.workflow.AiWorkflowDO;
|
||||
import cn.iocoder.yudao.module.ai.service.workflow.AiWorkflowService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import jakarta.annotation.Resource;
|
||||
import jakarta.validation.Valid;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.security.access.prepost.PreAuthorize;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
|
||||
|
||||
@Tag(name = "管理后台 - AI 工作流")
|
||||
@RestController
|
||||
@RequestMapping("/ai/workflow")
|
||||
@Slf4j
|
||||
public class AiWorkflowController {
|
||||
|
||||
@Resource
|
||||
private AiWorkflowService workflowService;
|
||||
|
||||
@PostMapping("/create")
|
||||
@Operation(summary = "创建 AI 工作流")
|
||||
@PreAuthorize("@ss.hasPermission('ai:workflow:create')")
|
||||
public CommonResult<Long> createWorkflow(@Valid @RequestBody AiWorkflowSaveReqVO createReqVO) {
|
||||
return success(workflowService.createWorkflow(createReqVO));
|
||||
}
|
||||
|
||||
@PutMapping("/update")
|
||||
@Operation(summary = "更新 AI 工作流")
|
||||
@PreAuthorize("@ss.hasPermission('ai:workflow:update')")
|
||||
public CommonResult<Boolean> updateWorkflow(@Valid @RequestBody AiWorkflowSaveReqVO updateReqVO) {
|
||||
workflowService.updateWorkflow(updateReqVO);
|
||||
return success(true);
|
||||
}
|
||||
|
||||
@DeleteMapping("/delete")
|
||||
@Operation(summary = "删除 AI 工作流")
|
||||
@Parameter(name = "id", description = "编号", required = true)
|
||||
@PreAuthorize("@ss.hasPermission('ai:workflow:delete')")
|
||||
public CommonResult<Boolean> deleteWorkflow(@RequestParam("id") Long id) {
|
||||
workflowService.deleteWorkflow(id);
|
||||
return success(true);
|
||||
}
|
||||
|
||||
@GetMapping("/get")
|
||||
@Operation(summary = "获得 AI 工作流")
|
||||
@Parameter(name = "id", description = "编号", required = true, example = "1024")
|
||||
@PreAuthorize("@ss.hasPermission('ai:workflow:query')")
|
||||
public CommonResult<AiWorkflowRespVO> getWorkflow(@RequestParam("id") Long id) {
|
||||
AiWorkflowDO workflow = workflowService.getWorkflow(id);
|
||||
return success(BeanUtils.toBean(workflow, AiWorkflowRespVO.class));
|
||||
}
|
||||
|
||||
@GetMapping("/page")
|
||||
@Operation(summary = "获得 AI 工作流分页")
|
||||
@PreAuthorize("@ss.hasPermission('ai:workflow:query')")
|
||||
public CommonResult<PageResult<AiWorkflowRespVO>> getWorkflowPage(@Valid AiWorkflowPageReqVO pageReqVO) {
|
||||
PageResult<AiWorkflowDO> pageResult = workflowService.getWorkflowPage(pageReqVO);
|
||||
return success(BeanUtils.toBean(pageResult, AiWorkflowRespVO.class));
|
||||
}
|
||||
|
||||
@PostMapping("/test")
|
||||
@Operation(summary = "测试 AI 工作流")
|
||||
@PreAuthorize("@ss.hasPermission('ai:workflow:test')")
|
||||
public CommonResult<Object> testWorkflow(@Valid @RequestBody AiWorkflowTestReqVO testReqVO) {
|
||||
return success(workflowService.testWorkflow(testReqVO));
|
||||
}
|
||||
|
||||
}
|
@@ -0,0 +1,32 @@
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.workflow.vo;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||
import cn.iocoder.yudao.framework.common.validation.InEnum;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
import org.springframework.format.annotation.DateTimeFormat;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND;
|
||||
|
||||
@Schema(description = "管理后台 - AI 工作流分页 Request VO")
|
||||
@Data
|
||||
public class AiWorkflowPageReqVO extends PageParam {
|
||||
|
||||
@Schema(description = "名称", example = "工作流")
|
||||
private String name;
|
||||
|
||||
@Schema(description = "标识", example = "FLOW")
|
||||
private String code;
|
||||
|
||||
@Schema(description = "状态", example = "1")
|
||||
@InEnum(CommonStatusEnum.class)
|
||||
private Integer status;
|
||||
|
||||
@Schema(description = "创建时间")
|
||||
@DateTimeFormat(pattern = FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND)
|
||||
private LocalDateTime[] createTime;
|
||||
|
||||
}
|
@@ -0,0 +1,33 @@
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.workflow.vo;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
@Schema(description = "管理后台 - AI 工作流 Response VO")
|
||||
@Data
|
||||
public class AiWorkflowRespVO {
|
||||
|
||||
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||
private Long id;
|
||||
|
||||
@Schema(description = "工作流标识", requiredMode = Schema.RequiredMode.REQUIRED, example = "FLOW")
|
||||
private String code;
|
||||
|
||||
@Schema(description = "工作流名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "工作流")
|
||||
private String name;
|
||||
|
||||
@Schema(description = "备注", requiredMode = Schema.RequiredMode.REQUIRED, example = "工作流")
|
||||
private String remark;
|
||||
|
||||
@Schema(description = "状态", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||
private Integer status;
|
||||
|
||||
@Schema(description = "工作流模型 JSON", requiredMode = Schema.RequiredMode.REQUIRED, example = "{}")
|
||||
private String graph;
|
||||
|
||||
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED, example = "时间戳格式")
|
||||
private LocalDateTime createTime;
|
||||
|
||||
}
|
@@ -0,0 +1,34 @@
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.workflow.vo;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
|
||||
@Schema(description = "管理后台 - AI 工作流新增/修改 Request VO")
|
||||
@Data
|
||||
public class AiWorkflowSaveReqVO {
|
||||
|
||||
@Schema(description = "编号", example = "1")
|
||||
private Long id;
|
||||
|
||||
@Schema(description = "工作流标识", requiredMode = Schema.RequiredMode.REQUIRED, example = "FLOW")
|
||||
@NotEmpty(message = "工作流标识不能为空")
|
||||
private String code;
|
||||
|
||||
@Schema(description = "工作流名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "工作流")
|
||||
@NotEmpty(message = "工作流名称不能为空")
|
||||
private String name;
|
||||
|
||||
@Schema(description = "备注", example = "FLOW")
|
||||
private String remark;
|
||||
|
||||
@Schema(description = "工作流模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "{}")
|
||||
@NotEmpty(message = "工作流模型不能为空")
|
||||
private String graph;
|
||||
|
||||
@Schema(description = "状态", requiredMode = Schema.RequiredMode.REQUIRED, example = "FLOW")
|
||||
@NotNull(message = "状态不能为空")
|
||||
private Integer status;
|
||||
|
||||
}
|
@@ -0,0 +1,20 @@
|
||||
package cn.iocoder.yudao.module.ai.controller.admin.workflow.vo;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Schema(description = "管理后台 - AI 工作流测试 Request VO")
|
||||
@Data
|
||||
public class AiWorkflowTestReqVO {
|
||||
|
||||
@Schema(description = "工作流模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "{}")
|
||||
@NotEmpty(message = "工作流模型不能为空")
|
||||
private String graph;
|
||||
|
||||
@Schema(description = "参数", requiredMode = Schema.RequiredMode.REQUIRED, example = "{}")
|
||||
private Map<String, Object> params;
|
||||
|
||||
}
|
@@ -0,0 +1,51 @@
|
||||
package cn.iocoder.yudao.module.ai.dal.dataobject.workflow;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* AI 工作流 DO
|
||||
*
|
||||
* @author lesan
|
||||
*/
|
||||
@TableName(value = "ai_workflow", autoResultMap = true)
|
||||
@KeySequence("ai_workflow") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
|
||||
@Data
|
||||
public class AiWorkflowDO extends BaseDO {
|
||||
|
||||
/**
|
||||
* 编号
|
||||
*/
|
||||
@TableId
|
||||
private Long id;
|
||||
/**
|
||||
* 工作流名称
|
||||
*/
|
||||
private String name;
|
||||
/**
|
||||
* 工作流标识
|
||||
*/
|
||||
private String code;
|
||||
|
||||
/**
|
||||
* 工作流模型 JSON 数据
|
||||
*/
|
||||
private String graph;
|
||||
|
||||
/**
|
||||
* 备注
|
||||
*/
|
||||
private String remark;
|
||||
|
||||
/**
|
||||
* 状态
|
||||
*
|
||||
* 枚举 {@link CommonStatusEnum}
|
||||
*/
|
||||
private Integer status;
|
||||
|
||||
}
|
@@ -0,0 +1,30 @@
|
||||
package cn.iocoder.yudao.module.ai.dal.mysql.workflow;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.workflow.vo.AiWorkflowPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.workflow.AiWorkflowDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
/**
|
||||
* AI 工作流 Mapper
|
||||
*
|
||||
* @author lesan
|
||||
*/
|
||||
@Mapper
|
||||
public interface AiWorkflowMapper extends BaseMapperX<AiWorkflowDO> {
|
||||
|
||||
default AiWorkflowDO selectByCode(String code) {
|
||||
return selectOne(AiWorkflowDO::getCode, code);
|
||||
}
|
||||
|
||||
default PageResult<AiWorkflowDO> selectPage(AiWorkflowPageReqVO pageReqVO) {
|
||||
return selectPage(pageReqVO, new LambdaQueryWrapperX<AiWorkflowDO>()
|
||||
.eqIfPresent(AiWorkflowDO::getStatus, pageReqVO.getStatus())
|
||||
.likeIfPresent(AiWorkflowDO::getName, pageReqVO.getName())
|
||||
.likeIfPresent(AiWorkflowDO::getCode, pageReqVO.getCode())
|
||||
.betweenIfPresent(AiWorkflowDO::getCreateTime, pageReqVO.getCreateTime()));
|
||||
}
|
||||
|
||||
}
|
@@ -11,6 +11,7 @@ import cn.hutool.extra.spring.SpringUtil;
|
||||
import cn.hutool.http.HttpUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowImageOptions;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
|
||||
@@ -144,7 +145,12 @@ public class AiImageServiceImpl implements AiImageService {
|
||||
.withStyle(MapUtil.getStr(draw.getOptions(), "style")) // 风格
|
||||
.withResponseFormat("b64_json")
|
||||
.build();
|
||||
} else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
|
||||
} else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.SILICON_FLOW.getPlatform())) {
|
||||
// https://docs.siliconflow.cn/cn/api-reference/images/images-generations
|
||||
return SiliconFlowImageOptions.builder().model(model.getModel())
|
||||
.height(draw.getHeight()).width(draw.getWidth())
|
||||
.build();
|
||||
} else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
|
||||
// https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage
|
||||
// https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
|
||||
return StabilityAiImageOptions.builder().model(model.getModel())
|
||||
|
@@ -0,0 +1,62 @@
|
||||
package cn.iocoder.yudao.module.ai.service.workflow;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.workflow.vo.AiWorkflowPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.workflow.vo.AiWorkflowSaveReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.workflow.vo.AiWorkflowTestReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.workflow.AiWorkflowDO;
|
||||
import jakarta.validation.Valid;
|
||||
|
||||
/**
|
||||
* AI 工作流 Service 接口
|
||||
*
|
||||
* @author lesan
|
||||
*/
|
||||
public interface AiWorkflowService {
|
||||
|
||||
/**
|
||||
* 创建 AI 工作流
|
||||
*
|
||||
* @param createReqVO 创建信息
|
||||
* @return 编号
|
||||
*/
|
||||
Long createWorkflow(@Valid AiWorkflowSaveReqVO createReqVO);
|
||||
|
||||
/**
|
||||
* 更新 AI 工作流
|
||||
*
|
||||
* @param updateReqVO 更新信息
|
||||
*/
|
||||
void updateWorkflow(@Valid AiWorkflowSaveReqVO updateReqVO);
|
||||
|
||||
/**
|
||||
* 删除 AI 工作流
|
||||
*
|
||||
* @param id 编号
|
||||
*/
|
||||
void deleteWorkflow(Long id);
|
||||
|
||||
/**
|
||||
* 获得 AI 工作流
|
||||
*
|
||||
* @param id 编号
|
||||
* @return AI 工作流
|
||||
*/
|
||||
AiWorkflowDO getWorkflow(Long id);
|
||||
|
||||
/**
|
||||
* 获得 AI 工作流分页
|
||||
*
|
||||
* @param pageReqVO 分页查询
|
||||
* @return AI 工作流分页
|
||||
*/
|
||||
PageResult<AiWorkflowDO> getWorkflowPage(AiWorkflowPageReqVO pageReqVO);
|
||||
|
||||
/**
|
||||
* 测试 AI 工作流
|
||||
*
|
||||
* @param testReqVO 测试数据
|
||||
*/
|
||||
Object testWorkflow(AiWorkflowTestReqVO testReqVO);
|
||||
|
||||
}
|
@@ -0,0 +1,150 @@
|
||||
package cn.iocoder.yudao.module.ai.service.workflow;
|
||||
|
||||
import cn.hutool.core.util.ObjUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.workflow.vo.AiWorkflowPageReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.workflow.vo.AiWorkflowSaveReqVO;
|
||||
import cn.iocoder.yudao.module.ai.controller.admin.workflow.vo.AiWorkflowTestReqVO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.dataobject.workflow.AiWorkflowDO;
|
||||
import cn.iocoder.yudao.module.ai.dal.mysql.workflow.AiWorkflowMapper;
|
||||
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
||||
import com.alibaba.fastjson.JSONArray;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import dev.tinyflow.core.Tinyflow;
|
||||
import jakarta.annotation.Resource;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
|
||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.WORKFLOW_CODE_EXISTS;
|
||||
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.WORKFLOW_NOT_EXISTS;
|
||||
|
||||
/**
|
||||
* AI 工作流 Service 实现类
|
||||
*
|
||||
* @author lesan
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class AiWorkflowServiceImpl implements AiWorkflowService {
|
||||
|
||||
@Resource
|
||||
private AiWorkflowMapper workflowMapper;
|
||||
|
||||
@Resource
|
||||
private AiApiKeyService apiKeyService;
|
||||
|
||||
@Override
|
||||
public Long createWorkflow(AiWorkflowSaveReqVO createReqVO) {
|
||||
validateWorkflowForCreateOrUpdate(null, createReqVO.getCode());
|
||||
AiWorkflowDO workflow = BeanUtils.toBean(createReqVO, AiWorkflowDO.class);
|
||||
workflowMapper.insert(workflow);
|
||||
return workflow.getId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateWorkflow(AiWorkflowSaveReqVO updateReqVO) {
|
||||
validateWorkflowForCreateOrUpdate(updateReqVO.getId(), updateReqVO.getCode());
|
||||
AiWorkflowDO workflow = BeanUtils.toBean(updateReqVO, AiWorkflowDO.class);
|
||||
workflowMapper.updateById(workflow);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteWorkflow(Long id) {
|
||||
validateWorkflowExists(id);
|
||||
workflowMapper.deleteById(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AiWorkflowDO getWorkflow(Long id) {
|
||||
return workflowMapper.selectById(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageResult<AiWorkflowDO> getWorkflowPage(AiWorkflowPageReqVO pageReqVO) {
|
||||
return workflowMapper.selectPage(pageReqVO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object testWorkflow(AiWorkflowTestReqVO testReqVO) {
|
||||
Map<String, Object> variables = testReqVO.getParams();
|
||||
Tinyflow tinyflow = parseFlowParam(testReqVO.getGraph());
|
||||
return tinyflow.toChain().executeForResult(variables);
|
||||
}
|
||||
|
||||
private void validateWorkflowForCreateOrUpdate(Long id, String code) {
|
||||
validateWorkflowExists(id);
|
||||
validateCodeUnique(id, code);
|
||||
}
|
||||
|
||||
private void validateWorkflowExists(Long id) {
|
||||
if (ObjUtil.isNull(id)) {
|
||||
return;
|
||||
}
|
||||
AiWorkflowDO workflow = workflowMapper.selectById(id);
|
||||
if (ObjUtil.isNull(workflow)) {
|
||||
throw exception(WORKFLOW_NOT_EXISTS);
|
||||
}
|
||||
}
|
||||
|
||||
private void validateCodeUnique(Long id, String code) {
|
||||
if (StrUtil.isBlank(code)) {
|
||||
return;
|
||||
}
|
||||
AiWorkflowDO workflow = workflowMapper.selectByCode(code);
|
||||
if (ObjUtil.isNull(workflow)) {
|
||||
return;
|
||||
}
|
||||
if (ObjUtil.isNull(id)) {
|
||||
throw exception(WORKFLOW_CODE_EXISTS);
|
||||
}
|
||||
if (ObjUtil.notEqual(workflow.getId(), id)) {
|
||||
throw exception(WORKFLOW_CODE_EXISTS);
|
||||
}
|
||||
}
|
||||
|
||||
private Tinyflow parseFlowParam(String graph) {
|
||||
// TODO @lesan:可以使用 jackson 哇?
|
||||
JSONObject json = JSONObject.parseObject(graph);
|
||||
JSONArray nodeArr = json.getJSONArray("nodes");
|
||||
Tinyflow tinyflow = new Tinyflow(json.toJSONString());
|
||||
for (int i = 0; i < nodeArr.size(); i++) {
|
||||
JSONObject node = nodeArr.getJSONObject(i);
|
||||
switch (node.getString("type")) {
|
||||
case "llmNode":
|
||||
JSONObject data = node.getJSONObject("data");
|
||||
AiApiKeyDO apiKey = apiKeyService.getApiKey(data.getLong("llmId"));
|
||||
switch (apiKey.getPlatform()) {
|
||||
// TODO @lesan 需要讨论一下这里怎么弄
|
||||
// TODO @lesan llmId 对应 model 的编号如何?这样的话,就是 apiModelService 提供一个获取 LLM 的方法。然后,创建的方法,也在 AiModelFactory 提供。可以先接个 deepseek 先。deepseek yyds!
|
||||
case "OpenAI":
|
||||
break;
|
||||
case "Ollama":
|
||||
break;
|
||||
case "YiYan":
|
||||
break;
|
||||
case "XingHuo":
|
||||
break;
|
||||
case "TongYi":
|
||||
break;
|
||||
case "DeepSeek":
|
||||
break;
|
||||
case "ZhiPu":
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case "internalNode":
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
return tinyflow;
|
||||
}
|
||||
|
||||
}
|
@@ -15,6 +15,7 @@
|
||||
<description>AI 大模型拓展,接入国内外大模型</description>
|
||||
<properties>
|
||||
<spring-ai.version>1.0.0-M6</spring-ai.version>
|
||||
<tinyflow.version>1.0.0-rc.3</tinyflow.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
@@ -117,6 +118,13 @@
|
||||
</exclusions>
|
||||
</dependency>
|
||||
|
||||
<!-- TinyFlow:AI 工作流 -->
|
||||
<dependency>
|
||||
<groupId>dev.tinyflow</groupId>
|
||||
<artifactId>tinyflow-java-core</artifactId>
|
||||
<version>${tinyflow.version}</version>
|
||||
</dependency>
|
||||
|
||||
<!-- Test 测试相关 -->
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
|
@@ -4,10 +4,12 @@ import cn.hutool.core.util.StrUtil;
|
||||
import cn.hutool.extra.spring.SpringUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
|
||||
import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactoryImpl;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.baichuan.BaiChuanChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowApiConstants;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||
@@ -113,11 +115,11 @@ public class YudaoAiAutoConfiguration {
|
||||
|
||||
public SiliconFlowChatModel buildSiliconFlowChatClient(YudaoAiProperties.SiliconFlowProperties properties) {
|
||||
if (StrUtil.isEmpty(properties.getModel())) {
|
||||
properties.setModel(SiliconFlowChatModel.MODEL_DEFAULT);
|
||||
properties.setModel(SiliconFlowApiConstants.MODEL_DEFAULT);
|
||||
}
|
||||
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(SiliconFlowChatModel.BASE_URL)
|
||||
.baseUrl(SiliconFlowApiConstants.DEFAULT_BASE_URL)
|
||||
.apiKey(properties.getApiKey())
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
@@ -192,6 +194,33 @@ public class YudaoAiAutoConfiguration {
|
||||
return new XingHuoChatModel(openAiChatModel);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.baichuan.enable", havingValue = "true")
|
||||
public BaiChuanChatModel baiChuanChatClient(YudaoAiProperties yudaoAiProperties) {
|
||||
YudaoAiProperties.BaiChuanProperties properties = yudaoAiProperties.getBaichuan();
|
||||
return buildBaiChuanChatClient(properties);
|
||||
}
|
||||
|
||||
public BaiChuanChatModel buildBaiChuanChatClient(YudaoAiProperties.BaiChuanProperties properties) {
|
||||
if (StrUtil.isEmpty(properties.getModel())) {
|
||||
properties.setModel(BaiChuanChatModel.MODEL_DEFAULT);
|
||||
}
|
||||
OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(BaiChuanChatModel.BASE_URL)
|
||||
.apiKey(properties.getApiKey())
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model(properties.getModel())
|
||||
.temperature(properties.getTemperature())
|
||||
.maxTokens(properties.getMaxTokens())
|
||||
.topP(properties.getTopP())
|
||||
.build())
|
||||
.toolCallingManager(getToolCallingManager())
|
||||
.build();
|
||||
return new BaiChuanChatModel(openAiChatModel);
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
|
||||
public MidjourneyApi midjourneyApi(YudaoAiProperties yudaoAiProperties) {
|
||||
|
@@ -43,6 +43,12 @@ public class YudaoAiProperties {
|
||||
@SuppressWarnings("SpellCheckingInspection")
|
||||
private XingHuoProperties xinghuo;
|
||||
|
||||
/**
|
||||
* 百川
|
||||
*/
|
||||
@SuppressWarnings("SpellCheckingInspection")
|
||||
private BaiChuanProperties baichuan;
|
||||
|
||||
/**
|
||||
* Midjourney 绘图
|
||||
*/
|
||||
@@ -122,6 +128,19 @@ public class YudaoAiProperties {
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class BaiChuanProperties {
|
||||
|
||||
private String enable;
|
||||
private String apiKey;
|
||||
|
||||
private String model;
|
||||
private Double temperature;
|
||||
private Integer maxTokens;
|
||||
private Double topP;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class MidjourneyProperties {
|
||||
|
||||
|
@@ -27,6 +27,7 @@ public enum AiPlatformEnum implements ArrayValuable<String> {
|
||||
SILICON_FLOW("SiliconFlow", "硅基流动"), // 硅基流动
|
||||
MINI_MAX("MiniMax", "MiniMax"), // 稀宇科技
|
||||
MOONSHOT("Moonshot", "月之暗灭"), // KIMI
|
||||
BAI_CHUAN("BaiChuan", "百川智能"), // 百川智能
|
||||
|
||||
// ========== 国外平台 ==========
|
||||
|
||||
|
@@ -11,11 +11,15 @@ import cn.hutool.extra.spring.SpringUtil;
|
||||
import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
|
||||
import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.baichuan.BaiChuanChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.doubao.DouBaoChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.hunyuan.HunYuanChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowApiConstants;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowImageApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowImageModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
|
||||
import cn.iocoder.yudao.framework.common.util.spring.SpringUtils;
|
||||
@@ -42,6 +46,7 @@ import org.springframework.ai.autoconfigure.moonshot.MoonshotAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.stabilityai.StabilityAiImageAutoConfiguration;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusServiceClientConnectionDetails;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusServiceClientProperties;
|
||||
import org.springframework.ai.autoconfigure.vectorstore.milvus.MilvusVectorStoreAutoConfiguration;
|
||||
@@ -146,6 +151,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
return buildMoonshotChatModel(apiKey, url);
|
||||
case XING_HUO:
|
||||
return buildXingHuoChatModel(apiKey);
|
||||
case BAI_CHUAN:
|
||||
return buildBaiChuanChatModel(apiKey);
|
||||
case OPENAI:
|
||||
return buildOpenAiChatModel(apiKey, url);
|
||||
case AZURE_OPENAI:
|
||||
@@ -182,6 +189,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
return SpringUtil.getBean(MoonshotChatModel.class);
|
||||
case XING_HUO:
|
||||
return SpringUtil.getBean(XingHuoChatModel.class);
|
||||
case BAI_CHUAN:
|
||||
return SpringUtil.getBean(AzureOpenAiChatModel.class);
|
||||
case OPENAI:
|
||||
return SpringUtil.getBean(OpenAiChatModel.class);
|
||||
case AZURE_OPENAI:
|
||||
@@ -203,6 +212,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
return SpringUtil.getBean(QianFanImageModel.class);
|
||||
case ZHI_PU:
|
||||
return SpringUtil.getBean(ZhiPuAiImageModel.class);
|
||||
case SILICON_FLOW:
|
||||
return SpringUtil.getBean(SiliconFlowImageModel.class);
|
||||
case OPENAI:
|
||||
return SpringUtil.getBean(OpenAiImageModel.class);
|
||||
case STABLE_DIFFUSION:
|
||||
@@ -224,6 +235,8 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
return buildZhiPuAiImageModel(apiKey, url);
|
||||
case OPENAI:
|
||||
return buildOpenAiImageModel(apiKey, url);
|
||||
case SILICON_FLOW:
|
||||
return buildSiliconFlowImageModel(apiKey,url);
|
||||
case STABLE_DIFFUSION:
|
||||
return buildStabilityAiImageModel(apiKey, url);
|
||||
default:
|
||||
@@ -433,6 +446,15 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
return new YudaoAiAutoConfiguration().buildXingHuoChatClient(properties);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link YudaoAiAutoConfiguration#baiChuanChatClient(YudaoAiProperties)}
|
||||
*/
|
||||
private BaiChuanChatModel buildBaiChuanChatModel(String apiKey) {
|
||||
YudaoAiProperties.BaiChuanProperties properties = new YudaoAiProperties.BaiChuanProperties()
|
||||
.setApiKey(apiKey);
|
||||
return new YudaoAiAutoConfiguration().buildBaiChuanChatClient(properties);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link OpenAiAutoConfiguration} 的 openAiChatModel 方法
|
||||
*/
|
||||
@@ -468,6 +490,15 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
return new OpenAiImageModel(openAiApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 SiliconFlowImageModel 对象
|
||||
*/
|
||||
private SiliconFlowImageModel buildSiliconFlowImageModel(String apiToken, String url) {
|
||||
url = StrUtil.blankToDefault(url, SiliconFlowApiConstants.DEFAULT_BASE_URL);
|
||||
SiliconFlowImageApi openAiApi = new SiliconFlowImageApi(url, apiToken);
|
||||
return new SiliconFlowImageModel(openAiApi);
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link OllamaAutoConfiguration} 的 ollamaApi 方法
|
||||
*/
|
||||
@@ -476,6 +507,9 @@ public class AiModelFactoryImpl implements AiModelFactory {
|
||||
return OllamaChatModel.builder().ollamaApi(ollamaApi).toolCallingManager(getToolCallingManager()).build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 可参考 {@link StabilityAiImageAutoConfiguration} 的 stabilityAiImageModel 方法
|
||||
*/
|
||||
private StabilityAiImageModel buildStabilityAiImageModel(String apiKey, String url) {
|
||||
url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
|
||||
StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);
|
||||
|
@@ -0,0 +1,45 @@
|
||||
package cn.iocoder.yudao.framework.ai.core.model.baichuan;
|
||||
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.ChatOptions;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
/**
|
||||
* 百川 {@link ChatModel} 实现类
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
@Slf4j
|
||||
@RequiredArgsConstructor
|
||||
public class BaiChuanChatModel implements ChatModel {
|
||||
|
||||
public static final String BASE_URL = "https://api.baichuan-ai.com";
|
||||
|
||||
public static final String MODEL_DEFAULT = "Baichuan4-Turbo";
|
||||
|
||||
/**
|
||||
* 兼容 OpenAI 接口,进行复用
|
||||
*/
|
||||
private final OpenAiChatModel openAiChatModel;
|
||||
|
||||
@Override
|
||||
public ChatResponse call(Prompt prompt) {
|
||||
return openAiChatModel.call(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Flux<ChatResponse> stream(Prompt prompt) {
|
||||
return openAiChatModel.stream(prompt);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatOptions getDefaultOptions() {
|
||||
return openAiChatModel.getDefaultOptions();
|
||||
}
|
||||
|
||||
}
|
@@ -0,0 +1,34 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package cn.iocoder.yudao.framework.ai.core.model.siliconflow;
|
||||
|
||||
/**
|
||||
* SiliconFlow API 枚举类
|
||||
*
|
||||
* @author zzt
|
||||
*/
|
||||
public final class SiliconFlowApiConstants {
|
||||
|
||||
public static final String DEFAULT_BASE_URL = "https://api.siliconflow.cn";
|
||||
|
||||
public static final String MODEL_DEFAULT = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B";
|
||||
|
||||
public static final String DEFAULT_IMAGE_MODEL = "Kwai-Kolors/Kolors";
|
||||
|
||||
public static final String PROVIDER_NAME = "Siiconflow";
|
||||
|
||||
}
|
@@ -20,10 +20,6 @@ import reactor.core.publisher.Flux;
|
||||
@RequiredArgsConstructor
|
||||
public class SiliconFlowChatModel implements ChatModel {
|
||||
|
||||
public static final String BASE_URL = "https://api.siliconflow.cn";
|
||||
|
||||
public static final String MODEL_DEFAULT = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B";
|
||||
|
||||
/**
|
||||
* 兼容 OpenAI 接口,进行复用
|
||||
*/
|
||||
|
@@ -0,0 +1,115 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package cn.iocoder.yudao.framework.ai.core.model.siliconflow;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonInclude;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import org.springframework.ai.model.ApiKey;
|
||||
import org.springframework.ai.model.NoopApiKey;
|
||||
import org.springframework.ai.model.SimpleApiKey;
|
||||
import org.springframework.ai.openai.api.OpenAiImageApi;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.MultiValueMap;
|
||||
import org.springframework.web.client.ResponseErrorHandler;
|
||||
import org.springframework.web.client.RestClient;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* 硅基流动 Image API
|
||||
*
|
||||
* @see <a href= "https://docs.siliconflow.cn/cn/api-reference/images/images-generations">Images</a>
|
||||
*
|
||||
* @author zzt
|
||||
*/
|
||||
public class SiliconFlowImageApi {
|
||||
|
||||
private final RestClient restClient;
|
||||
|
||||
public SiliconFlowImageApi(String aiToken) {
|
||||
this(SiliconFlowApiConstants.DEFAULT_BASE_URL, aiToken, RestClient.builder());
|
||||
}
|
||||
|
||||
public SiliconFlowImageApi(String baseUrl, String openAiToken) {
|
||||
this(baseUrl, openAiToken, RestClient.builder());
|
||||
}
|
||||
|
||||
public SiliconFlowImageApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) {
|
||||
this(baseUrl, openAiToken, restClientBuilder, RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
|
||||
}
|
||||
|
||||
public SiliconFlowImageApi(String baseUrl, String apiKey, RestClient.Builder restClientBuilder,
|
||||
ResponseErrorHandler responseErrorHandler) {
|
||||
this(baseUrl, apiKey, CollectionUtils.toMultiValueMap(Map.of()), restClientBuilder, responseErrorHandler);
|
||||
}
|
||||
|
||||
public SiliconFlowImageApi(String baseUrl, String apiKey, MultiValueMap<String, String> headers,
|
||||
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
|
||||
this(baseUrl, new SimpleApiKey(apiKey), headers, restClientBuilder, responseErrorHandler);
|
||||
}
|
||||
|
||||
public SiliconFlowImageApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers,
|
||||
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
|
||||
|
||||
// @formatter:off
|
||||
this.restClient = restClientBuilder.baseUrl(baseUrl)
|
||||
.defaultHeaders(h -> {
|
||||
if(!(apiKey instanceof NoopApiKey)) {
|
||||
h.setBearerAuth(apiKey.getValue());
|
||||
}
|
||||
h.setContentType(MediaType.APPLICATION_JSON);
|
||||
h.addAll(headers);
|
||||
})
|
||||
.defaultStatusHandler(responseErrorHandler)
|
||||
.build();
|
||||
// @formatter:on
|
||||
}
|
||||
|
||||
public ResponseEntity<OpenAiImageApi.OpenAiImageResponse> createImage(SiliconflowImageRequest siliconflowImageRequest) {
|
||||
Assert.notNull(siliconflowImageRequest, "Image request cannot be null.");
|
||||
Assert.hasLength(siliconflowImageRequest.prompt(), "Prompt cannot be empty.");
|
||||
|
||||
return this.restClient.post()
|
||||
.uri("v1/images/generations")
|
||||
.body(siliconflowImageRequest)
|
||||
.retrieve()
|
||||
.toEntity(OpenAiImageApi.OpenAiImageResponse.class);
|
||||
}
|
||||
|
||||
|
||||
// @formatter:off
|
||||
@JsonInclude(JsonInclude.Include.NON_NULL)
|
||||
public record SiliconflowImageRequest (
|
||||
@JsonProperty("prompt") String prompt,
|
||||
@JsonProperty("model") String model,
|
||||
@JsonProperty("batch_size") Integer batchSize,
|
||||
@JsonProperty("negative_prompt") String negativePrompt,
|
||||
@JsonProperty("seed") Integer seed,
|
||||
@JsonProperty("num_inference_steps") Integer numInferenceSteps,
|
||||
@JsonProperty("guidance_scale") Float guidanceScale,
|
||||
@JsonProperty("image") String image) {
|
||||
|
||||
public SiliconflowImageRequest(String prompt, String model) {
|
||||
this(prompt, model, null, null, null, null, null, null);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
@@ -0,0 +1,159 @@
|
||||
/*
|
||||
* Copyright 2023-2024 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package cn.iocoder.yudao.framework.ai.core.model.siliconflow;
|
||||
|
||||
import io.micrometer.observation.ObservationRegistry;
|
||||
import lombok.Setter;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.image.*;
|
||||
import org.springframework.ai.image.observation.DefaultImageModelObservationConvention;
|
||||
import org.springframework.ai.image.observation.ImageModelObservationContext;
|
||||
import org.springframework.ai.image.observation.ImageModelObservationConvention;
|
||||
import org.springframework.ai.image.observation.ImageModelObservationDocumentation;
|
||||
import org.springframework.ai.model.ModelOptionsUtils;
|
||||
import org.springframework.ai.openai.OpenAiImageModel;
|
||||
import org.springframework.ai.openai.api.OpenAiImageApi;
|
||||
import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata;
|
||||
import org.springframework.ai.retry.RetryUtils;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.lang.Nullable;
|
||||
import org.springframework.retry.support.RetryTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 硅基流动 {@link ImageModel} 实现类
|
||||
*
|
||||
* 参考 {@link OpenAiImageModel} 实现
|
||||
*
|
||||
* @author zzt
|
||||
*/
|
||||
public class SiliconFlowImageModel implements ImageModel {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(SiliconFlowImageModel.class);
|
||||
|
||||
private static final ImageModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultImageModelObservationConvention();
|
||||
|
||||
private final SiliconFlowImageOptions defaultOptions;
|
||||
|
||||
private final RetryTemplate retryTemplate;
|
||||
|
||||
private final SiliconFlowImageApi siliconFlowImageApi;
|
||||
|
||||
private final ObservationRegistry observationRegistry;
|
||||
|
||||
@Setter
|
||||
private ImageModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
|
||||
|
||||
public SiliconFlowImageModel(SiliconFlowImageApi siliconFlowImageApi) {
|
||||
this(siliconFlowImageApi, SiliconFlowImageOptions.builder().build(), RetryUtils.DEFAULT_RETRY_TEMPLATE);
|
||||
}
|
||||
|
||||
public SiliconFlowImageModel(SiliconFlowImageApi siliconFlowImageApi, SiliconFlowImageOptions options, RetryTemplate retryTemplate) {
|
||||
this(siliconFlowImageApi, options, retryTemplate, ObservationRegistry.NOOP);
|
||||
}
|
||||
|
||||
public SiliconFlowImageModel(SiliconFlowImageApi siliconFlowImageApi, SiliconFlowImageOptions options, RetryTemplate retryTemplate,
|
||||
ObservationRegistry observationRegistry) {
|
||||
Assert.notNull(siliconFlowImageApi, "OpenAiImageApi must not be null");
|
||||
Assert.notNull(options, "options must not be null");
|
||||
Assert.notNull(retryTemplate, "retryTemplate must not be null");
|
||||
Assert.notNull(observationRegistry, "observationRegistry must not be null");
|
||||
this.siliconFlowImageApi = siliconFlowImageApi;
|
||||
this.defaultOptions = options;
|
||||
this.retryTemplate = retryTemplate;
|
||||
this.observationRegistry = observationRegistry;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ImageResponse call(ImagePrompt imagePrompt) {
|
||||
SiliconFlowImageOptions requestImageOptions = mergeOptions(imagePrompt.getOptions(), this.defaultOptions);
|
||||
SiliconFlowImageApi.SiliconflowImageRequest imageRequest = createRequest(imagePrompt, requestImageOptions);
|
||||
|
||||
var observationContext = ImageModelObservationContext.builder()
|
||||
.imagePrompt(imagePrompt)
|
||||
.provider(SiliconFlowApiConstants.PROVIDER_NAME)
|
||||
.requestOptions(imagePrompt.getOptions())
|
||||
.build();
|
||||
|
||||
return ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION
|
||||
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
|
||||
this.observationRegistry)
|
||||
.observe(() -> {
|
||||
ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity = this.retryTemplate
|
||||
.execute(ctx -> this.siliconFlowImageApi.createImage(imageRequest));
|
||||
|
||||
ImageResponse imageResponse = convertResponse(imageResponseEntity, imageRequest);
|
||||
|
||||
observationContext.setResponse(imageResponse);
|
||||
|
||||
return imageResponse;
|
||||
});
|
||||
}
|
||||
|
||||
private SiliconFlowImageApi.SiliconflowImageRequest createRequest(ImagePrompt imagePrompt,
|
||||
SiliconFlowImageOptions requestImageOptions) {
|
||||
String instructions = imagePrompt.getInstructions().get(0).getText();
|
||||
|
||||
SiliconFlowImageApi.SiliconflowImageRequest imageRequest = new SiliconFlowImageApi.SiliconflowImageRequest(instructions,
|
||||
SiliconFlowApiConstants.DEFAULT_IMAGE_MODEL);
|
||||
|
||||
return ModelOptionsUtils.merge(requestImageOptions, imageRequest, SiliconFlowImageApi.SiliconflowImageRequest.class);
|
||||
}
|
||||
|
||||
private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity,
|
||||
SiliconFlowImageApi.SiliconflowImageRequest siliconflowImageRequest) {
|
||||
OpenAiImageApi.OpenAiImageResponse imageApiResponse = imageResponseEntity.getBody();
|
||||
if (imageApiResponse == null) {
|
||||
logger.warn("No image response returned for request: {}", siliconflowImageRequest);
|
||||
return new ImageResponse(List.of());
|
||||
}
|
||||
|
||||
List<ImageGeneration> imageGenerationList = imageApiResponse.data()
|
||||
.stream()
|
||||
.map(entry -> new ImageGeneration(new Image(entry.url(), entry.b64Json()),
|
||||
new OpenAiImageGenerationMetadata(entry.revisedPrompt())))
|
||||
.toList();
|
||||
|
||||
ImageResponseMetadata openAiImageResponseMetadata = new ImageResponseMetadata(imageApiResponse.created());
|
||||
return new ImageResponse(imageGenerationList, openAiImageResponseMetadata);
|
||||
}
|
||||
|
||||
private SiliconFlowImageOptions mergeOptions(@Nullable ImageOptions runtimeOptions, SiliconFlowImageOptions defaultOptions) {
|
||||
var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, ImageOptions.class,
|
||||
SiliconFlowImageOptions.class);
|
||||
|
||||
if (runtimeOptionsForProvider == null) {
|
||||
return defaultOptions;
|
||||
}
|
||||
|
||||
return SiliconFlowImageOptions.builder()
|
||||
// Handle portable image options
|
||||
.model(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
|
||||
.batchSize(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getN(), defaultOptions.getN()))
|
||||
.width(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getWidth(), defaultOptions.getWidth()))
|
||||
.height(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getHeight(), defaultOptions.getHeight()))
|
||||
// Handle SiliconFlow specific image options
|
||||
.negativePrompt(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNegativePrompt(), defaultOptions.getNegativePrompt()))
|
||||
.numInferenceSteps(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getNumInferenceSteps(), defaultOptions.getNumInferenceSteps()))
|
||||
.guidanceScale(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getGuidanceScale(), defaultOptions.getGuidanceScale()))
|
||||
.seed(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getSeed(), defaultOptions.getSeed()))
|
||||
.build();
|
||||
}
|
||||
}
|
@@ -0,0 +1,105 @@
|
||||
package cn.iocoder.yudao.framework.ai.core.model.siliconflow;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import org.springframework.ai.image.ImageOptions;
|
||||
|
||||
/**
|
||||
* 硅基流动 {@link ImageOptions}
|
||||
*
|
||||
* @author zzt
|
||||
*/
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class SiliconFlowImageOptions implements ImageOptions {
|
||||
|
||||
@JsonProperty("model")
|
||||
private String model;
|
||||
|
||||
@JsonProperty("negative_prompt")
|
||||
private String negativePrompt;
|
||||
|
||||
/**
|
||||
* The number of images to generate. Must be between 1 and 4.
|
||||
*/
|
||||
@JsonProperty("image_size")
|
||||
private String imageSize;
|
||||
|
||||
/**
|
||||
* The number of images to generate. Must be between 1 and 4.
|
||||
*/
|
||||
@JsonProperty("batch_size")
|
||||
private Integer batchSize = 1;
|
||||
|
||||
/**
|
||||
* number of inference steps
|
||||
*/
|
||||
@JsonProperty("num_inference_steps")
|
||||
private Integer numInferenceSteps = 25;
|
||||
|
||||
/**
|
||||
* This value is used to control the degree of match between the generated image and the given prompt. The higher the value, the more the generated image will tend to strictly match the text prompt. The lower the value, the more creative and diverse the generated image will be, potentially containing more unexpected elements.
|
||||
*
|
||||
* Required range: 0 <= x <= 20
|
||||
*/
|
||||
@JsonProperty("guidance_scale")
|
||||
private Float guidanceScale = 0.75F;
|
||||
|
||||
/**
|
||||
* 如果想要每次都生成固定的图片,可以把 seed 设置为固定值
|
||||
*
|
||||
*/
|
||||
@JsonProperty("seed")
|
||||
private Integer seed = (int)(Math.random() * 1_000_000_000);
|
||||
|
||||
/**
|
||||
* The image that needs to be uploaded should be converted into base64 format.
|
||||
*/
|
||||
@JsonProperty("image")
|
||||
private String image;
|
||||
|
||||
/**
|
||||
* 宽
|
||||
*/
|
||||
private Integer width;
|
||||
|
||||
/**
|
||||
* 高
|
||||
*/
|
||||
private Integer height;
|
||||
|
||||
public void setHeight(Integer height) {
|
||||
this.height = height;
|
||||
if (this.width != null && this.height != null) {
|
||||
this.imageSize = this.width + "x" + this.height;
|
||||
}
|
||||
}
|
||||
|
||||
public void setWidth(Integer width) {
|
||||
this.width = width;
|
||||
if (this.width != null && this.height != null) {
|
||||
this.imageSize = this.width + "x" + this.height;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer getN() {
|
||||
return batchSize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getResponseFormat() {
|
||||
return "url";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getStyle() {
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
@@ -1,5 +1,6 @@
|
||||
package cn.iocoder.yudao.framework.ai.core.util;
|
||||
|
||||
import cn.hutool.core.util.ObjUtil;
|
||||
import cn.hutool.core.util.StrUtil;
|
||||
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
|
||||
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
|
||||
@@ -13,6 +14,7 @@ import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.qianfan.QianFanChatOptions;
|
||||
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
@@ -28,6 +30,7 @@ public class AiUtils {
|
||||
|
||||
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
|
||||
Set<String> toolNames) {
|
||||
toolNames = ObjUtil.defaultIfNull(toolNames, Collections.emptySet());
|
||||
// noinspection EnhancedSwitchMigration
|
||||
switch (platform) {
|
||||
case TONG_YI:
|
||||
@@ -50,6 +53,7 @@ public class AiUtils {
|
||||
case HUN_YUAN: // 复用 OpenAI 客户端
|
||||
case XING_HUO: // 复用 OpenAI 客户端
|
||||
case SILICON_FLOW: // 复用 OpenAI 客户端
|
||||
case BAI_CHUAN: // 复用 OpenAI 客户端
|
||||
return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
|
||||
.toolNames(toolNames).build();
|
||||
case AZURE_OPENAI:
|
||||
|
@@ -0,0 +1,68 @@
|
||||
package cn.iocoder.yudao.framework.ai.chat;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.baichuan.BaiChuanChatModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.deepseek.DeepSeekChatModel;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.openai.OpenAiChatModel;
|
||||
import org.springframework.ai.openai.OpenAiChatOptions;
|
||||
import org.springframework.ai.openai.api.OpenAiApi;
|
||||
import reactor.core.publisher.Flux;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* {@link BaiChuanChatModel} 集成测试
|
||||
*
|
||||
* @author 芋道源码
|
||||
*/
|
||||
public class BaiChuanChatModelTests {
|
||||
|
||||
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(BaiChuanChatModel.BASE_URL)
|
||||
.apiKey("sk-61b6766a94c70786ed02673f5e16af3c") // apiKey
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model("Baichuan4-Turbo") // 模型(https://platform.baichuan-ai.com/docs/api)
|
||||
.temperature(0.7)
|
||||
.build())
|
||||
.build();
|
||||
|
||||
private final DeepSeekChatModel chatModel = new DeepSeekChatModel(openAiChatModel);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
ChatResponse response = chatModel.call(new Prompt(messages));
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testStream() {
|
||||
// 准备参数
|
||||
List<Message> messages = new ArrayList<>();
|
||||
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
|
||||
messages.add(new UserMessage("1 + 1 = ?"));
|
||||
|
||||
// 调用
|
||||
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
|
||||
// 打印结果
|
||||
flux.doOnNext(System.out::println).then().block();
|
||||
}
|
||||
|
||||
}
|
@@ -1,5 +1,6 @@
|
||||
package cn.iocoder.yudao.framework.ai.chat;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowApiConstants;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowChatModel;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -25,11 +26,11 @@ public class SiliconFlowChatModelTests {
|
||||
|
||||
private final OpenAiChatModel openAiChatModel = OpenAiChatModel.builder()
|
||||
.openAiApi(OpenAiApi.builder()
|
||||
.baseUrl(SiliconFlowChatModel.BASE_URL)
|
||||
.baseUrl(SiliconFlowApiConstants.DEFAULT_BASE_URL)
|
||||
.apiKey("sk-epsakfenqnyzoxhmbucsxlhkdqlcbnimslqoivkshalvdozz") // apiKey
|
||||
.build())
|
||||
.defaultOptions(OpenAiChatOptions.builder()
|
||||
.model(SiliconFlowChatModel.MODEL_DEFAULT) // 模型
|
||||
.model(SiliconFlowApiConstants.MODEL_DEFAULT) // 模型
|
||||
// .model("deepseek-ai/DeepSeek-R1") // 模型(deepseek-ai/DeepSeek-R1)可用赠费
|
||||
// .model("Pro/deepseek-ai/DeepSeek-R1") // 模型(Pro/deepseek-ai/DeepSeek-R1)需要付费
|
||||
.temperature(0.7)
|
||||
|
@@ -0,0 +1,35 @@
|
||||
package cn.iocoder.yudao.framework.ai.image;
|
||||
|
||||
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowImageApi;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowImageModel;
|
||||
import cn.iocoder.yudao.framework.ai.core.model.siliconflow.SiliconFlowImageOptions;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.ai.image.ImagePrompt;
|
||||
import org.springframework.ai.image.ImageResponse;
|
||||
|
||||
/**
|
||||
* {@link SiliconFlowImageModel} 集成测试
|
||||
*/
|
||||
public class SiliconFlowImageModelTests {
|
||||
|
||||
private final SiliconFlowImageModel imageModel = new SiliconFlowImageModel(
|
||||
new SiliconFlowImageApi("sk-epsakfenqnyzoxhmbucsxlhkdqlcbnimslqoivkshalvdozz") // 密钥
|
||||
);
|
||||
|
||||
@Test
|
||||
@Disabled
|
||||
public void testCall() {
|
||||
// 准备参数
|
||||
SiliconFlowImageOptions imageOptions = SiliconFlowImageOptions.builder()
|
||||
.model("Kwai-Kolors/Kolors")
|
||||
.build();
|
||||
ImagePrompt prompt = new ImagePrompt("万里长城", imageOptions);
|
||||
|
||||
// 方法调用
|
||||
ImageResponse response = imageModel.call(prompt);
|
||||
// 打印结果
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
}
|
Reference in New Issue
Block a user