Compare commits

...

6 Commits

Author SHA1 Message Date
Cursor Agent
0b6e9f7ee4 Refactor file path validation using @AssertTrue method in upload VOss
Co-authored-by: zhijiantianya <zhijiantianya@gmail.com>
2025-07-09 08:01:48 +00:00
Cursor Agent
a7247419ba Checkpoint before follow-up message 2025-07-09 07:55:20 +00:00
Cursor Agent
c853410f2b Checkpoint before follow-up message 2025-07-09 07:54:10 +00:00
Cursor Agent
fc0a9ddaf1 Add path validation to prevent directory traversal attacks
Co-authored-by: zhijiantianya <zhijiantianya@gmail.com>
2025-07-09 07:42:52 +00:00
芋道源码
e6fecd8efe Merge pull request #869 from YunaiV/fix-erp-statistics-tenant-issue
fix: ERP统计查询在多租户关闭时的NullPointerException问题
2025-07-06 16:08:21 +08:00
芋道源码
3b2a3dd0ea fix: ERP统计查询在多租户关闭时的NullPointerException问题
- 修复ErpSaleStatisticsMapper.xml中硬编码使用getRequiredTenantId()导致的空指针异常
- 修复ErpPurchaseStatisticsMapper.xml中硬编码使用getRequiredTenantId()导致的空指针异常
- 使用条件判断getTenantId() != null来决定是否添加租户条件
- 添加单元测试验证多租户开启和关闭时的统计查询功能
- 确保向后兼容,多租户开启时正常工作,关闭时不报错
2025-06-09 06:36:11 +00:00
11 changed files with 326 additions and 6 deletions

BIN
test_assert_true.class Normal file

Binary file not shown.

BIN
test_path_validation.class Normal file

Binary file not shown.

View File

@@ -10,7 +10,9 @@
<if test="endTime != null">
AND in_time &lt; #{endTime}
</if>
AND tenant_id = ${@cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder@getRequiredTenantId()}
<if test="@cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder@getTenantId() != null">
AND tenant_id = ${@cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder@getTenantId()}
</if>
AND deleted = 0) -
(SELECT IFNULL(SUM(total_price), 0)
FROM erp_purchase_return
@@ -18,7 +20,9 @@
<if test="endTime != null">
AND return_time &lt; #{endTime}
</if>
AND tenant_id = ${@cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder@getRequiredTenantId()}
<if test="@cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder@getTenantId() != null">
AND tenant_id = ${@cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder@getTenantId()}
</if>
AND deleted = 0)
</select>

View File

@@ -10,7 +10,9 @@
<if test="endTime != null">
AND out_time &lt; #{endTime}
</if>
AND tenant_id = ${@cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder@getRequiredTenantId()}
<if test="@cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder@getTenantId() != null">
AND tenant_id = ${@cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder@getTenantId()}
</if>
AND deleted = 0) -
(SELECT IFNULL(SUM(total_price), 0)
FROM erp_sale_return
@@ -18,7 +20,9 @@
<if test="endTime != null">
AND return_time &lt; #{endTime}
</if>
AND tenant_id = ${@cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder@getRequiredTenantId()}
<if test="@cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder@getTenantId() != null">
AND tenant_id = ${@cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder@getTenantId()}
</if>
AND deleted = 0)
</select>

View File

@@ -0,0 +1,155 @@
package cn.iocoder.yudao.module.erp.service.statistics;
import cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder;
import cn.iocoder.yudao.module.erp.dal.mysql.statistics.ErpPurchaseStatisticsMapper;
import cn.iocoder.yudao.module.erp.dal.mysql.statistics.ErpSaleStatisticsMapper;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.test.context.ActiveProfiles;
import jakarta.annotation.Resource;
import java.math.BigDecimal;
import java.time.LocalDateTime;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
/**
* ERP 统计服务测试类
* 主要测试在多租户关闭情况下,统计查询是否能正常工作
*
* @author 芋道源码
*/
@SpringBootTest
@ActiveProfiles("unit-test")
public class ErpStatisticsServiceTest {
@Resource
private ErpSaleStatisticsService saleStatisticsService;
@Resource
private ErpPurchaseStatisticsService purchaseStatisticsService;
@MockBean
private ErpSaleStatisticsMapper saleStatisticsMapper;
@MockBean
private ErpPurchaseStatisticsMapper purchaseStatisticsMapper;
@BeforeEach
void setUp() {
// 清理租户上下文
TenantContextHolder.clear();
}
@AfterEach
void tearDown() {
// 清理租户上下文
TenantContextHolder.clear();
}
@Test
void testSaleStatisticsWithoutTenant() {
// 准备参数
LocalDateTime beginTime = LocalDateTime.of(2024, 1, 1, 0, 0, 0);
LocalDateTime endTime = LocalDateTime.of(2024, 1, 31, 23, 59, 59);
BigDecimal expectedPrice = new BigDecimal("1000.00");
// Mock 返回值
when(saleStatisticsMapper.getSalePrice(any(LocalDateTime.class), any(LocalDateTime.class)))
.thenReturn(expectedPrice);
// 测试在没有租户ID的情况下调用销售统计
assertDoesNotThrow(() -> {
BigDecimal result = saleStatisticsService.getSalePrice(beginTime, endTime);
assertEquals(expectedPrice, result);
}, "在多租户关闭时,销售统计查询应该能正常工作");
}
@Test
void testPurchaseStatisticsWithoutTenant() {
// 准备参数
LocalDateTime beginTime = LocalDateTime.of(2024, 1, 1, 0, 0, 0);
LocalDateTime endTime = LocalDateTime.of(2024, 1, 31, 23, 59, 59);
BigDecimal expectedPrice = new BigDecimal("800.00");
// Mock 返回值
when(purchaseStatisticsMapper.getPurchasePrice(any(LocalDateTime.class), any(LocalDateTime.class)))
.thenReturn(expectedPrice);
// 测试在没有租户ID的情况下调用采购统计
assertDoesNotThrow(() -> {
BigDecimal result = purchaseStatisticsService.getPurchasePrice(beginTime, endTime);
assertEquals(expectedPrice, result);
}, "在多租户关闭时,采购统计查询应该能正常工作");
}
@Test
void testSaleStatisticsWithTenant() {
// 设置租户ID
Long tenantId = 1L;
TenantContextHolder.setTenantId(tenantId);
// 准备参数
LocalDateTime beginTime = LocalDateTime.of(2024, 1, 1, 0, 0, 0);
LocalDateTime endTime = LocalDateTime.of(2024, 1, 31, 23, 59, 59);
BigDecimal expectedPrice = new BigDecimal("1500.00");
// Mock 返回值
when(saleStatisticsMapper.getSalePrice(any(LocalDateTime.class), any(LocalDateTime.class)))
.thenReturn(expectedPrice);
// 测试在有租户ID的情况下调用销售统计
assertDoesNotThrow(() -> {
BigDecimal result = saleStatisticsService.getSalePrice(beginTime, endTime);
assertEquals(expectedPrice, result);
}, "在多租户开启时,销售统计查询应该能正常工作");
// 验证租户ID是否正确设置
assertEquals(tenantId, TenantContextHolder.getTenantId());
}
@Test
void testPurchaseStatisticsWithTenant() {
// 设置租户ID
Long tenantId = 2L;
TenantContextHolder.setTenantId(tenantId);
// 准备参数
LocalDateTime beginTime = LocalDateTime.of(2024, 1, 1, 0, 0, 0);
LocalDateTime endTime = LocalDateTime.of(2024, 1, 31, 23, 59, 59);
BigDecimal expectedPrice = new BigDecimal("1200.00");
// Mock 返回值
when(purchaseStatisticsMapper.getPurchasePrice(any(LocalDateTime.class), any(LocalDateTime.class)))
.thenReturn(expectedPrice);
// 测试在有租户ID的情况下调用采购统计
assertDoesNotThrow(() -> {
BigDecimal result = purchaseStatisticsService.getPurchasePrice(beginTime, endTime);
assertEquals(expectedPrice, result);
}, "在多租户开启时,采购统计查询应该能正常工作");
// 验证租户ID是否正确设置
assertEquals(tenantId, TenantContextHolder.getTenantId());
}
@Test
void testTenantContextHolderMethods() {
// 测试 getTenantId() 在没有设置租户时返回 null
assertNull(TenantContextHolder.getTenantId(), "未设置租户时应该返回 null");
// 设置租户ID
Long tenantId = 3L;
TenantContextHolder.setTenantId(tenantId);
assertEquals(tenantId, TenantContextHolder.getTenantId(), "设置租户后应该能正确获取");
// 清理租户上下文
TenantContextHolder.clear();
assertNull(TenantContextHolder.getTenantId(), "清理后应该返回 null");
}
}

View File

@@ -43,7 +43,7 @@ public class FileController {
@PostMapping("/upload")
@Operation(summary = "上传文件", description = "模式一:后端上传文件")
public CommonResult<String> uploadFile(FileUploadReqVO uploadReqVO) throws Exception {
public CommonResult<String> uploadFile(@Valid FileUploadReqVO uploadReqVO) throws Exception {
MultipartFile file = uploadReqVO.getFile();
byte[] content = IoUtil.readBytes(file.getInputStream());
return success(fileService.createFile(content, file.getOriginalFilename(),

View File

@@ -1,6 +1,8 @@
package cn.iocoder.yudao.module.infra.controller.admin.file.vo.file;
import cn.hutool.core.util.StrUtil;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import org.springframework.web.multipart.MultipartFile;
@@ -16,4 +18,34 @@ public class FileUploadReqVO {
@Schema(description = "文件目录", example = "XXX/YYY")
private String directory;
@AssertTrue(message = "目录路径无效,包含非法字符")
public boolean isDirectoryValid() {
if (StrUtil.isEmpty(directory)) {
return true; // 空值认为是有效的
}
// 统一使用正斜杠
String normalizedPath = directory.replace('\\', '/');
// 检查绝对路径
if (normalizedPath.startsWith("/") || normalizedPath.matches("^[A-Za-z]:.*")) {
return false;
}
// 检查路径遍历攻击
String[] dangerousPatterns = {
"..", "..\\", "../", "..%2f", "..%5c", "..%2F", "..%5C",
"%2e%2e", "%2E%2E", "%2e%2e%2f", "%2E%2E%2F",
"....//", "....\\\\", "....%2f", "....%5c"
};
String lowerPath = normalizedPath.toLowerCase();
for (String pattern : dangerousPatterns) {
if (lowerPath.contains(pattern)) {
return false;
}
}
return true;
}
}

View File

@@ -33,7 +33,7 @@ public class AppFileController {
@PostMapping("/upload")
@Operation(summary = "上传文件")
@PermitAll
public CommonResult<String> uploadFile(AppFileUploadReqVO uploadReqVO) throws Exception {
public CommonResult<String> uploadFile(@Valid AppFileUploadReqVO uploadReqVO) throws Exception {
MultipartFile file = uploadReqVO.getFile();
byte[] content = IoUtil.readBytes(file.getInputStream());
return success(fileService.createFile(content, file.getOriginalFilename(),

View File

@@ -1,6 +1,8 @@
package cn.iocoder.yudao.module.infra.controller.app.file.vo;
import cn.hutool.core.util.StrUtil;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.AssertTrue;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
import org.springframework.web.multipart.MultipartFile;
@@ -16,4 +18,34 @@ public class AppFileUploadReqVO {
@Schema(description = "文件目录", example = "XXX/YYY")
private String directory;
@AssertTrue(message = "目录路径无效,包含非法字符")
public boolean isDirectoryValid() {
if (StrUtil.isEmpty(directory)) {
return true; // 空值认为是有效的
}
// 统一使用正斜杠
String normalizedPath = directory.replace('\\', '/');
// 检查绝对路径
if (normalizedPath.startsWith("/") || normalizedPath.matches("^[A-Za-z]:.*")) {
return false;
}
// 检查路径遍历攻击
String[] dangerousPatterns = {
"..", "..\\", "../", "..%2f", "..%5c", "..%2F", "..%5C",
"%2e%2e", "%2E%2E", "%2e%2e%2f", "%2E%2E%2F",
"....//", "....\\\\", "....%2f", "....%5c"
};
String lowerPath = normalizedPath.toLowerCase();
for (String pattern : dangerousPatterns) {
if (lowerPath.contains(pattern)) {
return false;
}
}
return true;
}
}

View File

@@ -33,6 +33,7 @@ public interface ErrorCodeConstants {
ErrorCode FILE_PATH_EXISTS = new ErrorCode(1_001_003_000, "文件路径已存在");
ErrorCode FILE_NOT_EXISTS = new ErrorCode(1_001_003_001, "文件不存在");
ErrorCode FILE_IS_EMPTY = new ErrorCode(1_001_003_002, "文件为空");
ErrorCode FILE_PATH_INVALID = new ErrorCode(1_001_003_003, "文件路径无效,包含非法字符");
// ========== 代码生成器 1-001-004-000 ==========
ErrorCode CODEGEN_TABLE_EXISTS = new ErrorCode(1_001_004_002, "表定义已经存在");

View File

@@ -0,0 +1,92 @@
package cn.iocoder.yudao.module.infra.controller.admin.file.vo.file;
import jakarta.validation.Validation;
import jakarta.validation.Validator;
import jakarta.validation.ValidatorFactory;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.springframework.mock.web.MockMultipartFile;
import static org.junit.jupiter.api.Assertions.*;
/**
* FileUploadReqVO 测试类
*
* @author 芋道源码
*/
class FileUploadReqVOTest {
private static Validator validator;
@BeforeAll
static void setUp() {
ValidatorFactory factory = Validation.buildDefaultValidatorFactory();
validator = factory.getValidator();
}
@Test
void testValidDirectory() {
// 测试有效目录
FileUploadReqVO vo = new FileUploadReqVO();
vo.setFile(new MockMultipartFile("test.txt", "test.txt", "text/plain", "test".getBytes()));
vo.setDirectory("uploads/2024/01");
var violations = validator.validate(vo);
assertTrue(violations.isEmpty(), "有效目录应该通过验证");
}
@Test
void testNullDirectory() {
// 测试空目录
FileUploadReqVO vo = new FileUploadReqVO();
vo.setFile(new MockMultipartFile("test.txt", "test.txt", "text/plain", "test".getBytes()));
vo.setDirectory(null);
var violations = validator.validate(vo);
assertTrue(violations.isEmpty(), "空目录应该通过验证");
}
@Test
void testEmptyDirectory() {
// 测试空字符串目录
FileUploadReqVO vo = new FileUploadReqVO();
vo.setFile(new MockMultipartFile("test.txt", "test.txt", "text/plain", "test".getBytes()));
vo.setDirectory("");
var violations = validator.validate(vo);
assertTrue(violations.isEmpty(), "空字符串目录应该通过验证");
}
@Test
void testPathTraversalAttack() {
// 测试路径遍历攻击
FileUploadReqVO vo = new FileUploadReqVO();
vo.setFile(new MockMultipartFile("test.txt", "test.txt", "text/plain", "test".getBytes()));
vo.setDirectory("../../etc/passwd");
var violations = validator.validate(vo);
assertFalse(violations.isEmpty(), "路径遍历攻击应该被拒绝");
}
@Test
void testAbsolutePath() {
// 测试绝对路径
FileUploadReqVO vo = new FileUploadReqVO();
vo.setFile(new MockMultipartFile("test.txt", "test.txt", "text/plain", "test".getBytes()));
vo.setDirectory("/etc/passwd");
var violations = validator.validate(vo);
assertFalse(violations.isEmpty(), "绝对路径应该被拒绝");
}
@Test
void testWindowsAbsolutePath() {
// 测试Windows绝对路径
FileUploadReqVO vo = new FileUploadReqVO();
vo.setFile(new MockMultipartFile("test.txt", "test.txt", "text/plain", "test".getBytes()));
vo.setDirectory("C:\\windows\\system32");
var violations = validator.validate(vo);
assertFalse(violations.isEmpty(), "Windows绝对路径应该被拒绝");
}
}