Compare commits

..

1 Commits

Author SHA1 Message Date
Cursor Agent
c796623148 Improve coupon template take count with stock validation logic
Co-authored-by: zhijiantianya <zhijiantianya@gmail.com>
2025-07-08 15:18:11 +00:00
11 changed files with 22 additions and 170 deletions

Binary file not shown.

Binary file not shown.

View File

@@ -43,7 +43,7 @@ public class FileController {
@PostMapping("/upload")
@Operation(summary = "上传文件", description = "模式一:后端上传文件")
public CommonResult<String> uploadFile(@Valid FileUploadReqVO uploadReqVO) throws Exception {
public CommonResult<String> uploadFile(FileUploadReqVO uploadReqVO) throws Exception {
MultipartFile file = uploadReqVO.getFile();
byte[] content = IoUtil.readBytes(file.getInputStream());
return success(fileService.createFile(content, file.getOriginalFilename(),

View File

@@ -1,8 +1,6 @@
package cn.iocoder.yudao.module.infra.controller.admin.file.vo.file;
import cn.hutool.core.util.StrUtil;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import org.springframework.web.multipart.MultipartFile;
@@ -18,34 +16,4 @@ public class FileUploadReqVO {
@Schema(description = "文件目录", example = "XXX/YYY")
private String directory;
@AssertTrue(message = "目录路径无效,包含非法字符")
public boolean isDirectoryValid() {
if (StrUtil.isEmpty(directory)) {
return true; // 空值认为是有效的
}
// 统一使用正斜杠
String normalizedPath = directory.replace('\\', '/');
// 检查绝对路径
if (normalizedPath.startsWith("/") || normalizedPath.matches("^[A-Za-z]:.*")) {
return false;
}
// 检查路径遍历攻击
String[] dangerousPatterns = {
"..", "..\\", "../", "..%2f", "..%5c", "..%2F", "..%5C",
"%2e%2e", "%2E%2E", "%2e%2e%2f", "%2E%2E%2F",
"....//", "....\\\\", "....%2f", "....%5c"
};
String lowerPath = normalizedPath.toLowerCase();
for (String pattern : dangerousPatterns) {
if (lowerPath.contains(pattern)) {
return false;
}
}
return true;
}
}

View File

@@ -33,7 +33,7 @@ public class AppFileController {
@PostMapping("/upload")
@Operation(summary = "上传文件")
@PermitAll
public CommonResult<String> uploadFile(@Valid AppFileUploadReqVO uploadReqVO) throws Exception {
public CommonResult<String> uploadFile(AppFileUploadReqVO uploadReqVO) throws Exception {
MultipartFile file = uploadReqVO.getFile();
byte[] content = IoUtil.readBytes(file.getInputStream());
return success(fileService.createFile(content, file.getOriginalFilename(),

View File

@@ -1,8 +1,6 @@
package cn.iocoder.yudao.module.infra.controller.app.file.vo;
import cn.hutool.core.util.StrUtil;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import org.springframework.web.multipart.MultipartFile;
@@ -18,34 +16,4 @@ public class AppFileUploadReqVO {
@Schema(description = "文件目录", example = "XXX/YYY")
private String directory;
@AssertTrue(message = "目录路径无效,包含非法字符")
public boolean isDirectoryValid() {
if (StrUtil.isEmpty(directory)) {
return true; // 空值认为是有效的
}
// 统一使用正斜杠
String normalizedPath = directory.replace('\\', '/');
// 检查绝对路径
if (normalizedPath.startsWith("/") || normalizedPath.matches("^[A-Za-z]:.*")) {
return false;
}
// 检查路径遍历攻击
String[] dangerousPatterns = {
"..", "..\\", "../", "..%2f", "..%5c", "..%2F", "..%5C",
"%2e%2e", "%2E%2E", "%2e%2e%2f", "%2E%2E%2F",
"....//", "....\\\\", "....%2f", "....%5c"
};
String lowerPath = normalizedPath.toLowerCase();
for (String pattern : dangerousPatterns) {
if (lowerPath.contains(pattern)) {
return false;
}
}
return true;
}
}

View File

@@ -33,7 +33,6 @@ public interface ErrorCodeConstants {
ErrorCode FILE_PATH_EXISTS = new ErrorCode(1_001_003_000, "文件路径已存在");
ErrorCode FILE_NOT_EXISTS = new ErrorCode(1_001_003_001, "文件不存在");
ErrorCode FILE_IS_EMPTY = new ErrorCode(1_001_003_002, "文件为空");
ErrorCode FILE_PATH_INVALID = new ErrorCode(1_001_003_003, "文件路径无效,包含非法字符");
// ========== 代码生成器 1-001-004-000 ==========
ErrorCode CODEGEN_TABLE_EXISTS = new ErrorCode(1_001_004_002, "表定义已经存在");

View File

@@ -1,92 +0,0 @@
package cn.iocoder.yudao.module.infra.controller.admin.file.vo.file;
import jakarta.validation.Validation;
import jakarta.validation.Validator;
import jakarta.validation.ValidatorFactory;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.springframework.mock.web.MockMultipartFile;
import static org.junit.jupiter.api.Assertions.*;
/**
* FileUploadReqVO 测试类
*
* @author 芋道源码
*/
class FileUploadReqVOTest {
private static Validator validator;
@BeforeAll
static void setUp() {
ValidatorFactory factory = Validation.buildDefaultValidatorFactory();
validator = factory.getValidator();
}
@Test
void testValidDirectory() {
// 测试有效目录
FileUploadReqVO vo = new FileUploadReqVO();
vo.setFile(new MockMultipartFile("test.txt", "test.txt", "text/plain", "test".getBytes()));
vo.setDirectory("uploads/2024/01");
var violations = validator.validate(vo);
assertTrue(violations.isEmpty(), "有效目录应该通过验证");
}
@Test
void testNullDirectory() {
// 测试空目录
FileUploadReqVO vo = new FileUploadReqVO();
vo.setFile(new MockMultipartFile("test.txt", "test.txt", "text/plain", "test".getBytes()));
vo.setDirectory(null);
var violations = validator.validate(vo);
assertTrue(violations.isEmpty(), "空目录应该通过验证");
}
@Test
void testEmptyDirectory() {
// 测试空字符串目录
FileUploadReqVO vo = new FileUploadReqVO();
vo.setFile(new MockMultipartFile("test.txt", "test.txt", "text/plain", "test".getBytes()));
vo.setDirectory("");
var violations = validator.validate(vo);
assertTrue(violations.isEmpty(), "空字符串目录应该通过验证");
}
@Test
void testPathTraversalAttack() {
// 测试路径遍历攻击
FileUploadReqVO vo = new FileUploadReqVO();
vo.setFile(new MockMultipartFile("test.txt", "test.txt", "text/plain", "test".getBytes()));
vo.setDirectory("../../etc/passwd");
var violations = validator.validate(vo);
assertFalse(violations.isEmpty(), "路径遍历攻击应该被拒绝");
}
@Test
void testAbsolutePath() {
// 测试绝对路径
FileUploadReqVO vo = new FileUploadReqVO();
vo.setFile(new MockMultipartFile("test.txt", "test.txt", "text/plain", "test".getBytes()));
vo.setDirectory("/etc/passwd");
var violations = validator.validate(vo);
assertFalse(violations.isEmpty(), "绝对路径应该被拒绝");
}
@Test
void testWindowsAbsolutePath() {
// 测试Windows绝对路径
FileUploadReqVO vo = new FileUploadReqVO();
vo.setFile(new MockMultipartFile("test.txt", "test.txt", "text/plain", "test".getBytes()));
vo.setDirectory("C:\\windows\\system32");
var violations = validator.validate(vo);
assertFalse(violations.isEmpty(), "Windows绝对路径应该被拒绝");
}
}

View File

@@ -40,10 +40,18 @@ public interface CouponTemplateMapper extends BaseMapperX<CouponTemplateDO> {
.orderByDesc(CouponTemplateDO::getId));
}
default void updateTakeCount(Long id, Integer incrCount) {
update(null, new LambdaUpdateWrapper<CouponTemplateDO>()
.eq(CouponTemplateDO::getId, id)
.setSql("take_count = take_count + " + incrCount));
default int updateTakeCount(Long id, Integer incrCount) {
LambdaUpdateWrapper<CouponTemplateDO> wrapper = new LambdaUpdateWrapper<CouponTemplateDO>()
.eq(CouponTemplateDO::getId, id);
// 只在增加数量时检查库存incrCount > 0
if (incrCount > 0) {
// 添加库存判断:剩余数量 >= 领取的数量,或者总数量为-1无限库存
wrapper.and(w -> w.apply("total_count = -1 OR (total_count - take_count) >= {0}", incrCount));
}
wrapper.setSql("take_count = take_count + " + incrCount);
return update(null, wrapper);
}
default List<CouponTemplateDO> selectListByTakeType(Integer takeType) {

View File

@@ -279,12 +279,8 @@ public class CouponServiceImpl implements CouponService {
if (ObjUtil.notEqual(couponTemplate.getTakeType(), takeType.getType())) {
throw exception(COUPON_TEMPLATE_CANNOT_TAKE);
}
// 校验发放数量不能过小(仅在 CouponTakeTypeEnum.USER 用户领取时)
if (CouponTakeTypeEnum.isUser(couponTemplate.getTakeType())
&& ObjUtil.notEqual(couponTemplate.getTakeLimitCount(), CouponTemplateDO.TIME_LIMIT_COUNT_MAX) // 非不限制
&& couponTemplate.getTakeCount() + userIds.size() > couponTemplate.getTotalCount()) {
throw exception(COUPON_TEMPLATE_NOT_ENOUGH);
}
// 注意:库存检查现在在数据库层面的 updateCouponTemplateTakeCount 方法中进行
// 如果库存不足,该方法会抛出 COUPON_TEMPLATE_NOT_ENOUGH 异常
// 校验"固定日期"的有效期类型是否过期
if (CouponTemplateValidityTypeEnum.DATE.getType().equals(couponTemplate.getValidityType())) {
if (LocalDateTimeUtils.beforeNow(couponTemplate.getValidEndTime())) {

View File

@@ -23,6 +23,7 @@ import java.util.Objects;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.promotion.enums.ErrorCodeConstants.COUPON_TEMPLATE_NOT_EXISTS;
import static cn.iocoder.yudao.module.promotion.enums.ErrorCodeConstants.COUPON_TEMPLATE_NOT_ENOUGH;
import static cn.iocoder.yudao.module.promotion.enums.ErrorCodeConstants.COUPON_TEMPLATE_TOTAL_COUNT_TOO_SMALL;
/**
@@ -116,7 +117,11 @@ public class CouponTemplateServiceImpl implements CouponTemplateService {
@Override
public void updateCouponTemplateTakeCount(Long id, int incrCount) {
couponTemplateMapper.updateTakeCount(id, incrCount);
int updateCount = couponTemplateMapper.updateTakeCount(id, incrCount);
// 只在增加数量且更新失败时,说明库存不足,抛出异常
if (incrCount > 0 && updateCount == 0) {
throw exception(COUPON_TEMPLATE_NOT_ENOUGH);
}
}
@Override