Skip to content

HealthAIHub

字数
3828 字
阅读时间
20 分钟

1 RedisUtil

java
package com.mole.health.util;  
  
import org.springframework.beans.factory.annotation.Autowired;  
import org.springframework.data.redis.core.RedisTemplate;  
import org.springframework.stereotype.Service;  
  
import java.util.List;  
import java.util.Map;  
import java.util.Set;  
import java.util.concurrent.TimeUnit;  
  
@Service  
public class RedisUtil {  
    private final RedisTemplate<String, Object> redisTemplate;  
  
    @Autowired  
    public RedisUtil(RedisTemplate<String, Object> redisTemplate) {  
        this.redisTemplate = redisTemplate;  
    }  
  
    /**  
     * 保存属性  
     */  
    public void set(String key, Object value, long time) {  
        redisTemplate.opsForValue().set(key, value, time, TimeUnit.SECONDS);  
    }  
  
    /**  
     * 保存属性  
     */  
    public void set(String key, Object value) {  
        redisTemplate.opsForValue().set(key, value);  
    }  
  
    /**  
     * 获取属性  
     */  
    public Object get(String key) {  
        return redisTemplate.opsForValue().get(key);  
    }  
  
    /**  
     * 删除属性  
     */  
    public Boolean del(String key) {  
        return redisTemplate.delete(key);  
    }  
  
    /**  
     * 批量删除属性  
     */  
    public Long del(List<String> keys) {  
        return redisTemplate.delete(keys);  
    }  
  
    /**  
     * 设置过期时间  
     */  
    public Boolean expire(String key, long time) {  
        return redisTemplate.expire(key, time, TimeUnit.SECONDS);  
    }  
  
    /**  
     * 获取过期时间  
     */  
    public Long getExpire(String key) {  
        return redisTemplate.getExpire(key, TimeUnit.SECONDS);  
    }  
  
    /**  
     * 判断是否有该属性  
     */  
    public Boolean hasKey(String key) {  
        return redisTemplate.hasKey(key);  
    }  
  
    /**  
     * 按delta递增  
     */  
    public Long incr(String key, long delta) {  
        return redisTemplate.opsForValue().increment(key, delta);  
    }  
  
    /**  
     * 按delta递减  
     */  
    public Long decr(String key, long delta) {  
        return redisTemplate.opsForValue().increment(key, -delta);  
    }  
  
    /**  
     * 获取Hash结构中的属性  
     */  
    public Object hGet(String key, String hashKey) {  
        return redisTemplate.opsForHash().get(key, hashKey);  
    }  
  
    /**  
     * 向Hash结构中放入一个属性  
     */  
    public Boolean hSet(String key, String hashKey, Object value, long time) {  
        redisTemplate.opsForHash().put(key, hashKey, value);  
        return expire(key, time);  
    }  
  
    /**  
     * 向Hash结构中放入一个属性  
     */  
    public void hSet(String key, String hashKey, Object value) {  
        redisTemplate.opsForHash().put(key, hashKey, value);  
    }  
  
    /**  
     * 直接获取整个Hash结构  
     */  
    public Map<Object, Object> hGetAll(String key) {  
        return redisTemplate.opsForHash().entries(key);  
    }  
  
    /**  
     * 直接设置整个Hash结构  
     */  
    public Boolean hSetAll(String key, Map<String, Object> map, long time) {  
        redisTemplate.opsForHash().putAll(key, map);  
        return expire(key, time);  
    }  
  
    /**  
     * 直接设置整个Hash结构  
     */  
    public void hSetAll(String key, Map<String, ?> map) {  
        redisTemplate.opsForHash().putAll(key, map);  
    }  
  
    /**  
     * 删除Hash结构中的属性  
     */  
    public void hDel(String key, ObjecthashKey) {  
        redisTemplate.opsForHash().delete(key, hashKey);  
    }  
  
    /**  
     * 判断Hash结构中是否有该属性  
     */  
    public Boolean hHasKey(String key, String hashKey) {  
        return redisTemplate.opsForHash().hasKey(key, hashKey);  
    }  
  
    /**  
     * Hash结构中属性递增  
     */  
    public Long hIncr(String key, String hashKey, Long delta) {  
        return redisTemplate.opsForHash().increment(key, hashKey, delta);  
    }  
  
    /**  
     * Hash结构中属性递减  
     */  
    public Long hDecr(String key, String hashKey, Long delta) {  
        return redisTemplate.opsForHash().increment(key, hashKey, -delta);  
    }  
  
    /**  
     * 获取Set结构  
     */  
    public Set<Object> sMembers(String key) {  
        return redisTemplate.opsForSet().members(key);  
    }  
  
    /**  
     * 向Set结构中添加属性  
     */  
    public Long sAdd(String key, Objectvalues) {  
        return redisTemplate.opsForSet().add(key, values);  
    }  
  
    /**  
     * 向Set结构中添加属性  
     */  
    public Long sAdd(String key, long time, Objectvalues) {  
        Long count = redisTemplate.opsForSet().add(key, values);  
        expire(key, time);  
        return count;  
    }  
  
    /**  
     * 是否为Set中的属性  
     */  
    public Boolean sIsMember(String key, Object value) {  
        return redisTemplate.opsForSet().isMember(key, value);  
    }  
  
    /**  
     * 获取Set结构的长度  
     */  
    public Long sSize(String key) {  
        return redisTemplate.opsForSet().size(key);  
    }  
  
    /**  
     * 删除Set结构中的属性  
     */  
    public Long sRemove(String key, Objectvalues) {  
        return redisTemplate.opsForSet().remove(key, values);  
    }  
  
    /**  
     * 获取List结构中的属性  
     */  
    public List<Object> lRange(String key, long start, long end) {  
        return redisTemplate.opsForList().range(key, start, end);  
    }  
  
    /**  
     * 获取List结构的长度  
     */  
    public Long lSize(String key) {  
        return redisTemplate.opsForList().size(key);  
    }  
  
    /**  
     * 根据索引获取List中的属性  
     */  
    public Object lIndex(String key, long index) {  
        return redisTemplate.opsForList().index(key, index);  
    }  
  
    /**  
     * 向List结构中添加属性  
     */  
    public Long lPush(String key, Object value) {  
        return redisTemplate.opsForList().rightPush(key, value);  
    }  
  
    /**  
     * 向List结构中添加属性  
     */  
    public Long lPush(String key, Object value, long time) {  
        Long index = redisTemplate.opsForList().rightPush(key, value);  
        expire(key, time);  
        return index;  
    }  
  
    /**  
     * 向List结构中批量添加属性  
     */  
    public Long lPushAll(String key, Objectvalues) {  
        return redisTemplate.opsForList().rightPushAll(key, values);  
    }  
  
    /**  
     * 向List结构中批量添加属性  
     */  
    public Long lPushAll(String key, Long time, Objectvalues) {  
        Long count = redisTemplate.opsForList().rightPushAll(key, values);  
        expire(key, time);  
        return count;  
    }  
  
    /**  
     * 从List结构中移除属性  
     */  
    public Long lRemove(String key, long count, Object value) {  
        return redisTemplate.opsForList().remove(key, count, value);  
    }  
}

2 Sa-Token

使用 Sa-Token 实现用户鉴权,本项目使用 Simple 简单模式 实现基础的 jwt 注入:登录数据存储 - Redis / Session 数据存储 - Redis / 注销下线 - 前后端双清除数据

xml
<!-- sa-token -->  
<dependency>  
    <groupId>cn.dev33</groupId>  
    <artifactId>sa-token-spring-boot3-starter</artifactId>  
</dependency>  
<!-- Sa-Token 整合 Redis (使用fastjson2序列化方式) -->  
<dependency>  
    <groupId>cn.dev33</groupId>  
    <artifactId>sa-token-redis-fastjson2</artifactId>  
</dependency>  
<!-- Sa-Token 整合 jwt --><dependency>  
    <groupId>cn.dev33</groupId>  
    <artifactId>sa-token-jwt</artifactId>  
</dependency>  
<!-- Sa-Token 整合 SpringAOP 实现注解鉴权 -->  
<dependency>  
    <groupId>cn.dev33</groupId>  
    <artifactId>sa-token-spring-aop</artifactId>  
</dependency>
yaml
sa-token:  
  token-name: Authorization  
  timeout: 604800  
  active-timeout: -1  
  is-concurrent: true  
  is-share: false  
  token-style: uuid  
  is-log: false  
  is-read-cookie: false  
  is-read-header: true  
  token-prefix: Bearer  
  jwt-secret-key: puolikaumjhnygbdfvdcsxa  
  is-print: false
java
package com.mole.health.config;  
  
import cn.dev33.satoken.jwt.StpLogicJwtForSimple;  
import cn.dev33.satoken.stp.StpLogic;  
import cn.dev33.satoken.strategy.SaAnnotationStrategy;  
import jakarta.annotation.PostConstruct;  
import org.springframework.context.annotation.Bean;  
import org.springframework.context.annotation.Configuration;  
import org.springframework.core.annotation.AnnotatedElementUtils;  
  
@Configuration  
public class SaTokenConfig {  
    // Sa-Token 整合 jwt (Simple 简单模式)  
    @Bean  
    public StpLogic getStpLogicJwt() {  
        return new StpLogicJwtForSimple();  
    }  
  
    @PostConstruct  
    public void rewriteSaStrategy() {  
        // 重写Sa-Token的注解处理器,增加注解合并功能  
        SaAnnotationStrategy.instance.getAnnotation = AnnotatedElementUtils::getMergedAnnotation;  
    }  
}

2.1 身份区分

使用 StpKit.USER 代替 StpUtil 作为用户相关操作的工具类

java
package com.mole.health.util;  
  
import cn.dev33.satoken.stp.StpLogic;  
import cn.dev33.satoken.stp.StpUtil;  
  
/**  
 * StpLogic 门面类,管理项目中所有的 StpLogic 账号体系  
 */  
public class StpKit {  
  
    /**  
     * 默认原生会话对象  
     */  
    public static final StpLogic DEFAULT = StpUtil.stpLogic;  
  
    /**  
     * Admin 会话对象,管理 Admin 表所有账号的登录、权限认证  
     */  
    public static final StpLogic ADMIN = new StpLogic("admin");  
  
    /**  
     * User 会话对象,管理 User 表所有账号的登录、权限认证  
     */  
    public static final StpLogic USER = new StpLogic("user");  
  
    /**  
     * XX 会话对象,(项目中有多少套账号表,就声明几个 StpLogic 会话对象)  
     */  
    public static final StpLogic XXX = new StpLogic("xx");  
  
}

2.1.1 登录校验

java
package com.mole.health.annotation;  
  
import cn.dev33.satoken.annotation.SaCheckLogin;  
  
import java.lang.annotation.ElementType;  
import java.lang.annotation.Retention;  
import java.lang.annotation.RetentionPolicy;  
import java.lang.annotation.Target;  
  
/**  
 * 登录认证(User):只有登录之后才能进入该方法  
 * <p> 可标注在函数、类上(效果等同于标注在此类的所有方法上)  
 */  
@SaCheckLogin(type = "user")  
@Retention(RetentionPolicy.RUNTIME)  
@Target({ElementType.METHOD, ElementType.TYPE})  
public @interface SaUserCheckLogin {  
  
}
java
package com.mole.health.annotation;  
  
import cn.dev33.satoken.annotation.SaCheckLogin;  
  
import java.lang.annotation.ElementType;  
import java.lang.annotation.Retention;  
import java.lang.annotation.RetentionPolicy;  
import java.lang.annotation.Target;  
  
/**  
 * 登录认证(Admin):只有登录之后才能进入该方法  
 * <p> 可标注在函数、类上(效果等同于标注在此类的所有方法上)  
 */  
@SaCheckLogin(type = "admin")  
@Retention(RetentionPolicy.RUNTIME)  
@Target({ElementType.METHOD, ElementType.TYPE})  
public @interface SaAdminCheckLogin {  
  
}

3 Spring Mail

使用 Spring Mail 实现邮箱验证码发送

yaml
spring:
  mail:  
    host: ${MAIL_HOST:your-mail-host}  
    port: ${MAIL_PORT:your-mail-port}  
	username: ${MAIL_USERNAME:your-dev-email}  
    password: ${MAIL_PASSWORD:your-dev-password}  
    properties:  
      mail:  
        smtp:  
          auth: true  
          starttls:  
            enable: true  
          ssl:  
            enable: true
java
package com.mole.health.util;  
  
import org.springframework.beans.factory.annotation.Autowired;  
import org.springframework.beans.factory.annotation.Value;  
import org.springframework.mail.SimpleMailMessage;  
import org.springframework.mail.javamail.JavaMailSender;  
import org.springframework.stereotype.Component;  
  
@Component  
public class EmailUtil {  
  
    private final JavaMailSender mailSender;  
    @Value("${spring.mail.username}")  
    private String emailFrom;  
  
    @Autowired  
    public EmailUtil(JavaMailSender mailSender) {  
        this.mailSender = mailSender;  
    }  
  
    /**  
     * 发送简单文本邮件  
     *  
     * @param to      收件人邮箱地址  
     * @param subject 邮件主题  
     * @param text    邮件内容  
     */  
    public void sendSimpleEmail(String to, String subject, String text) {  
        SimpleMailMessage message = new SimpleMailMessage();  
        message.setFrom(emailFrom);  
        message.setTo(to);  
        message.setSubject(subject);  
        message.setText(text);  
        mailSender.send(message);  
    }  
}

3.1 Controller

java
package com.mole.health.controller;  
  
import com.mole.health.result.CommonResult;  
import com.mole.health.util.EmailUtil;  
import io.swagger.v3.oas.annotations.Operation;  
import io.swagger.v3.oas.annotations.tags.Tag;  
import org.springframework.beans.factory.annotation.Autowired;  
import org.springframework.web.bind.annotation.PostMapping;  
import org.springframework.web.bind.annotation.RequestMapping;  
import org.springframework.web.bind.annotation.RequestParam;  
import org.springframework.web.bind.annotation.RestController;  
  
@RestController  
@RequestMapping("/email")  
@Tag(name = "EmailController", description = "邮件发送接口")  
public class EmailController {  
  
    private final EmailUtil emailUtil;  
  
    @Autowired  
    public EmailController(EmailUtil emailUtil) {  
        this.emailUtil = emailUtil;  
    }  
  
    /**  
     * 发送简单文本邮件  
     *  
     * @param to      收件人邮箱地址  
     * @param subject 邮件主题  
     * @param text    邮件内容  
     * @return CommonResult<String> 返回结果  
     */  
    @PostMapping("/simple")  
    @Operation(summary = "发送简单文本邮件", description = "发送简单文本邮件")  
    public CommonResult<String> sendSimpleEmail(@RequestParam String to,  
                                                @RequestParam String subject,  
                                                @RequestParam String text) {  
        try {  
            emailUtil.sendSimpleEmail(to, subject, text);  
            return CommonResult.success("Simple email sent successfully");  
        } catch (Exception e) {  
            return CommonResult.failed("Error sending simple email: " + e.getMessage());  
        }  
    }  
}
java
package com.mole.health.controller;  
  
import com.mole.health.dto.user.EmailDto;  
import com.mole.health.result.CommonResult;  
import com.mole.health.service.EmailVerificationCodeService;  
import io.swagger.v3.oas.annotations.Operation;  
import io.swagger.v3.oas.annotations.Parameter;  
import io.swagger.v3.oas.annotations.responses.ApiResponse;  
import io.swagger.v3.oas.annotations.tags.Tag;  
import lombok.extern.slf4j.Slf4j;  
import org.springframework.beans.factory.annotation.Autowired;  
import org.springframework.web.bind.annotation.PostMapping;  
import org.springframework.web.bind.annotation.RequestBody;  
import org.springframework.web.bind.annotation.RequestMapping;  
import org.springframework.web.bind.annotation.RestController;  
  
@Slf4j  
@RestController  
@RequestMapping("/email-verification")  
@Tag(name = "EmailVerificationCodeController", description = "邮箱验证码控制器")  
public class EmailVerificationCodeController {  
  
    private final EmailVerificationCodeService emailVerificationCodeService;  
  
    @Autowired  
    public EmailVerificationCodeController(EmailVerificationCodeService emailVerificationCodeService) {  
        this.emailVerificationCodeService = emailVerificationCodeService;  
    }  
  
    /**  
     * 发送验证码接口  
     *  
     * @param emailDto 用户邮箱DTO  
     * @return 发送结果  
     */  
    @PostMapping("/send-code")  
    @Operation(summary = "发送验证码", description = "向指定邮箱发送验证码")  
    @ApiResponse(responseCode = "200", description = "验证码发送成功")  
    public CommonResult<String> sendVerificationCode(@RequestBody @Parameter(description = "用户邮箱") EmailDto emailDto) {  
        try {  
            String code = emailVerificationCodeService.generateCode(emailDto);  
            emailDto.setEmailVerificationCode(code);  
            emailVerificationCodeService.sendCode(emailDto);  
            return CommonResult.success("验证码已发送");  
        } catch (Exception e) {  
            log.error("发送验证码失败", e);  
            return CommonResult.failed("发送验证码失败:" + e.getMessage());  
        }  
    }  
  
    /**  
     * 验证验证码接口  
     *  
     * @param emailDto 包含邮箱和验证码的DTO  
     * @return 验证结果  
     */  
    @PostMapping("/verify-code")  
    @Operation(summary = "验证验证码", description = "验证用户输入的验证码是否正确")  
    @ApiResponse(responseCode = "200", description = "验证码验证结果")  
    public CommonResult<Boolean> verifyCode(@RequestBody @Parameter(description = "邮箱和验证码") EmailDto emailDto) {  
        boolean isValid = emailVerificationCodeService.verifyCode(emailDto);  
        return CommonResult.success(isValid);  
    }  
}

3.2 Service

java
package com.mole.health.service;  
  
import com.mole.health.dto.user.EmailDto;  
  
public interface EmailVerificationCodeService {  
  
    /**  
     * 生成验证码  
     *  
     * @param emailDto 包含邮箱的DTO  
     * @return 生成的验证码  
     */  
    String generateCode(EmailDto emailDto);  
  
    /**  
     * 发送验证码  
     *  
     * @param emailDto 包含邮箱和验证码的DTO  
     */    void sendCode(EmailDto emailDto);  
  
    /**  
     * 验证验证码  
     *  
     * @param emailDto 包含邮箱和验证码的DTO  
     * @return 验证结果  
     */  
    boolean verifyCode(EmailDto emailDto);  
}
java
package com.mole.health.service.impl;  
  
import com.mole.health.dto.user.EmailDto;  
import com.mole.health.service.EmailVerificationCodeService;  
import com.mole.health.util.EmailUtil;  
import com.mole.health.util.RedisUtil;  
import lombok.extern.slf4j.Slf4j;  
import org.springframework.beans.factory.annotation.Autowired;  
import org.springframework.stereotype.Service;  
  
import java.util.Random;  
  
@Slf4j  
@Service  
public class EmailVerificationCodeServiceImpl implements EmailVerificationCodeService {  
  
    // 验证码过期时间,单位为分钟  
    private static final long CODE_EXPIRE_TIME = 30;  
    private final RedisUtil redisUtil;  
    private final EmailUtil emailUtil;  
  
    /**  
     * 构造函数,通过依赖注入初始化 RedisUtil 和 EmailUtil  
     * @param redisUtil Redis 工具类  
     * @param emailUtil 邮件工具类  
     */  
    @Autowired  
    public EmailVerificationCodeServiceImpl(RedisUtil redisUtil, EmailUtil emailUtil) {  
        this.redisUtil = redisUtil;  
        this.emailUtil = emailUtil;  
    }  
  
    /**  
     * 生成验证码并存储到 Redis 中  
     *  
     * @param emailDto 包含邮箱的 DTO  
     * @return 生成的验证码  
     */  
    @Override  
    public String generateCode(EmailDto emailDto) {  
        String code = generateRandomCode(); // 生成随机验证码  
        // 将验证码存储到 Redis 中,并设置过期时间  
        redisUtil.set(getRedisKey(emailDto.getEmail()), code, CODE_EXPIRE_TIME * 60); // 转换为秒  
        log.info("为邮箱 {} 生成验证码:{}", emailDto.getEmail(), code); // 记录日志  
        return code; // 返回生成的验证码  
    }  
  
    /**  
     * 发送验证码到指定邮箱  
     *  
     * @param emailDto 包含邮箱和验证码的 DTO  
     */    @Override  
    public void sendCode(EmailDto emailDto) {  
        String code = emailDto.getEmailVerificationCode(); // 获取验证码  
        // 发送简单文本邮件  
        emailUtil.sendSimpleEmail(emailDto.getEmail(), "验证码", "您的验证码是:" + code);  
        log.info("向邮箱 {} 发送验证码:{}", emailDto.getEmail(), code); // 记录日志  
    }  
  
    /**  
     * 验证用户输入的验证码是否正确  
     *  
     * @param emailDto 包含邮箱和验证码的 DTO  
     * @return 验证结果  
     */  
    @Override  
    public boolean verifyCode(EmailDto emailDto) {  
        // 从 Redis 中获取存储的验证码  
        String storedCode = (String) redisUtil.get(getRedisKey(emailDto.getEmail()));  
        log.info("验证邮箱 {} 的验证码:{}", emailDto.getEmail(), storedCode); // 记录日志  
        // 比较用户输入的验证码和存储的验证码是否一致  
        boolean isValid = emailDto.getEmailVerificationCode().equals(storedCode);  
        log.info("验证码验证结果:{}", isValid ? "成功" : "失败"); // 记录验证结果  
        return isValid; // 返回验证结果  
    }  
  
    /**  
     * 生成6位随机数字验证码  
     *  
     * @return 生成的验证码  
     */  
    private String generateRandomCode() {  
        Random random = new Random();  
        StringBuilder sb = new StringBuilder();  
        for (int i = 0; i < 6; i++) {  
            sb.append(random.nextInt(10)); // 生成0-9之间的随机数  
        }  
        return sb.toString(); // 返回生成的验证码  
    }  
  
    /**  
     * 获取存储验证码的 Redis 键  
     *  
     * @param email 邮箱地址  
     * @return Redis 键  
     */  
    private String getRedisKey(String email) {  
        return "email_verification_code:" + email; // 拼接 Redis 键  
    }  
}

4 MinIO

5 SpringAI

通过 SpringAI 接入 ZhiPuAI (接入操作可参考 Spring AI 集成 OpenAI)实现个人健康助手,基于 Websocket 实现信息收发,消息持久化到 MongoDB 并缓存到 Redis

5.1 后端

yaml
spring:
  ai:  
    zhipuai:  
      api-key: ${API_KEY:your-default-dev-key}  
      chat:  
        options:  
          model: GLM-4-Flash  
          temperature: 0.7
java
package com.mole.health.config;  
  
import org.springframework.beans.factory.annotation.Autowired;  
import org.springframework.context.annotation.Configuration;  
import org.springframework.web.socket.config.annotation.EnableWebSocket;  
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;  
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;  
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;  
  
import com.mole.health.websocket.AiChatWebSocketHandler;  
  
@Configuration  
@EnableWebSocket  
public class WebSocketConfig implements WebSocketConfigurer {  
  
    private final AiChatWebSocketHandler aiChatWebSocketHandler;  
  
    @Autowired  
    public WebSocketConfig(AiChatWebSocketHandler aiChatWebSocketHandler) {  
        this.aiChatWebSocketHandler = aiChatWebSocketHandler;  
    }  
  
    @Override  
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {  
        registry.addHandler(aiChatWebSocketHandler, "/ws/ai-chat").addInterceptors(new HttpSessionHandshakeInterceptor()).setAllowedOrigins("*");  
    }  
}
java
package com.mole.health.websocket;  
  
import com.fasterxml.jackson.databind.ObjectMapper;  
import com.mole.health.pojo.UserAiInteraction;  
import com.mole.health.repository.UserAiInteractionRepository;  
import com.mole.health.util.RedisUtil;  
import com.mole.health.util.TokenUtil;  
import lombok.extern.slf4j.Slf4j;  
import org.springframework.ai.chat.messages.Message;  
import org.springframework.ai.chat.messages.UserMessage;  
import org.springframework.ai.chat.model.ChatResponse;  
import org.springframework.ai.chat.prompt.Prompt;  
import org.springframework.ai.zhipuai.ZhiPuAiChatModel;  
import org.springframework.beans.factory.annotation.Autowired;  
import org.springframework.stereotype.Component;  
import org.springframework.web.socket.CloseStatus;  
import org.springframework.web.socket.TextMessage;  
import org.springframework.web.socket.WebSocketSession;  
import org.springframework.web.socket.handler.TextWebSocketHandler;  
  
import java.time.Instant;  
import java.util.ArrayList;  
import java.util.List;  
import java.util.Objects;  
import java.util.Optional;  
  
/**  
 * 处理 AI 聊天 WebSocket 连接的处理器。  
 */  
@Slf4j  
@Component  
public class AiChatWebSocketHandler extends TextWebSocketHandler {  
  
    private final ObjectMapper objectMapper;  
    private final ZhiPuAiChatModel chatModel;  
    private final RedisUtil redisUtil;  
    private final UserAiInteractionRepository userAiInteractionRepository;  
    private final TokenUtil tokenUtil;  
    private final String basePrompt;  
  
    /**  
     * 构造函数,注入所需的依赖。  
     *  
     * @param objectMapper                用于 JSON 序列化和反序列化的 ObjectMapper  
     * @param chatModel                   用于处理聊天的 AI 模型  
     * @param redisUtil                   用于 Redis 操作的工具类  
     * @param userAiInteractionRepository 用户 AI 交互的存储库  
     */  
    @Autowired  
    public AiChatWebSocketHandler(ObjectMapper objectMapper, ZhiPuAiChatModel chatModel, RedisUtil redisUtil, UserAiInteractionRepository userAiInteractionRepository, TokenUtil tokenUtil) {  
        this.objectMapper = objectMapper;  
        this.chatModel = chatModel;  
        this.redisUtil = redisUtil;  
        this.userAiInteractionRepository = userAiInteractionRepository;  
        this.tokenUtil = tokenUtil;  
        this.basePrompt = "你是一个专业的健康助手AI。请根据以下问题提供简洁、准确的健康建议。如果遇到需要专业医疗诊断的问题,请建议用户咨询医生:\n\n";  
    }  
  
    /**  
     * 处理收到的文本消息。  
     *  
     * @param session 当前的 WebSocket 会话  
     * @param message 收到的文本消息  
     * @throws Exception 处理消息时可能抛出的异常  
     */  
    @Override  
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {  
        log.info("Received message: {}", message.getPayload());  
        String userId = getUserIdFromSession(session);  
        String sessionId = getSessionIdFromSession(session);  
        String question = message.getPayload();  
        log.info("User ID: {}, Session ID: {}, Question: {}", userId, sessionId, question);  
  
        String cacheKey = "user_ai_interaction:" + userId + ":" + sessionId;  
        UserAiInteraction interaction;  
  
        // 尝试从 Redis 缓存中获取交互信息  
        String cachedInteraction = (String) redisUtil.get(cacheKey);  
        log.info("Cached interaction: {}", cachedInteraction);  
        if (cachedInteraction != null) {  
            interaction = objectMapper.readValue(cachedInteraction, UserAiInteraction.class);  
        } else {  
            // 如果缓存中没有,则从数据库中获取  
            Optional<UserAiInteraction> interactionOpt = userAiInteractionRepository.findByUserIdAndSessionId(userId, sessionId);  
            interaction = interactionOpt.orElseGet(() -> UserAiInteraction.builder().userId(userId).sessionId(sessionId).context(new ArrayList<>()).timestamp(Instant.now()).build());  
        }  
  
        List<Message> messages = new ArrayList<>();  
        messages.add(new UserMessage(basePrompt));  
        for (UserAiInteraction.Message contextMessage : interaction.getContext()) {  
            messages.add(new UserMessage(contextMessage.getContent()));  
        }  
        messages.add(new UserMessage(question));  
  
        var prompt = new Prompt(messages);  
        ChatResponse response = chatModel.call(prompt);  
        String answer = response.getResult().getOutput().getContent();  
        log.info("Answer: {}", answer);  
  
        interaction.getContext().add(UserAiInteraction.Message.builder().role("user").content(question).build());  
  
        interaction.getContext().add(UserAiInteraction.Message.builder().role("assistant").content(answer).build());  
        userAiInteractionRepository.save(interaction);  
  
        session.sendMessage(new TextMessage(answer));  
        log.info("Sent answer: {}", answer);  
    }  
  
    /**  
     * 从 WebSocket 会话中获取用户 ID。  
     *  
     * @param session 当前的 WebSocket 会话  
     * @return 用户 ID  
     */    
     private String getUserIdFromSession(WebSocketSession session) {  
        return session.getAttributes().get("userId").toString();  
    }  
  
    /**
     * 从 WebSocket 会话中获取会话 ID。  
     *  
     * @param session 当前的 WebSocket 会话  
     * @return 会话 ID  
     */    
     private String getSessionIdFromSession(WebSocketSession session) {  
        // 直接返回 WebSocket 会话的 ID        
        return session.getId();  
    }  
  
    /**  
     * 在连接建立时进行初始化操作。  
     *  
     * @param session 当前的 WebSocket 会话  
     * @throws Exception 初始化操作时可能抛出的异常  
     */  
    @Override  
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {  
        String query = Objects.requireNonNull(session.getUri()).getQuery();  
        String token = null;  
        if (query != null && query.startsWith("token=")) {  
            token = query.substring(6);  
        }  
  
        if (token != null && !token.isEmpty()) {  
            // 验证 token            
            String userId = tokenUtil.validateTokenAndGetUserId(token);  
            if (userId != null) {  
                session.getAttributes().put("userId", userId);  
                log.info("WebSocket connection established for user: {}", userId);  
            } else {  
                // token 无效,关闭连接  
                session.close(CloseStatus.POLICY_VIOLATION);  
                log.info("WebSocket connection closed due to invalid token");  
            }  
        } else {  
            // 没有提供 token,关闭连接  
            session.close(CloseStatus.POLICY_VIOLATION);  
            log.info("WebSocket connection closed due to missing token");  
        }  
    }  
}

因为 StpUtil 无法在非常规 HTTP 上下文中操作,所以创建 TokenUtil 作为中间层引用

java
package com.mole.health.util;  
  
import lombok.extern.slf4j.Slf4j;  
import org.springframework.stereotype.Component;  
  
@Component  
@Slf4j  
public class TokenUtil {  
  
    public String validateTokenAndGetUserId(String token) {  
        try {  
            // 使用 Sa-Token 验证 token
            Object loginId = StpKit.USER.getLoginIdByToken(token);  
            if (loginId != null) {  
                return loginId.toString();  
            }  
        } catch (Exception e) {  
            // 处理验证过程中的异常  
            log.error("Token validation error", e);  
        }  
        return null;  
    }  
}

5.2 前端

ts
import { ref } from 'vue'
import type { Ref } from 'vue'

const apiBasePath = import.meta.env.VITE_API_BASE_PATH || '/api'
const isDevelopment = import.meta.env.MODE === 'development'
const wsProtocol = window.location.protocol === 'https:' ? 'wss' : 'ws'

let wsBaseUrl: string

if (isDevelopment) {
  const backendUrl =
    import.meta.env.VITE_DEV_SERVER_URL || window.location.origin
  wsBaseUrl = `${wsProtocol}://${new URL(backendUrl).host}/ws`
} else {
  // 在生产环境中,使用当前页面的主机名
  wsBaseUrl = `${wsProtocol}://${window.location.host}${apiBasePath}/ws`
}

export const useWebSocket = (endpoint: string) => {
  const socket: Ref<WebSocket | null> = ref(null)
  const isConnected = ref(false)

  const connect = () => {
    const token = localStorage.getItem('token')
    socket.value = new WebSocket(`${wsBaseUrl}/${endpoint}?token=${token}`)

    socket.value.onopen = () => {
      isConnected.value = true
      console.log('WebSocket connected')
    }

    socket.value.onclose = () => {
      isConnected.value = false
      console.log('WebSocket disconnected')
    }
    
    socket.value.onerror = error => {
      console.error('WebSocket error:', error)
    }
  }

  const disconnect = () => {
    if (socket.value) {
      socket.value.close()
    }
  }

  const sendMessage = (message: string) => {
    if (socket.value && socket.value.readyState === WebSocket.OPEN) {
      socket.value.send(message)
    } else {
      console.error('WebSocket is not connected')
    }
  }
  
  return {
    socket,
    isConnected,
    connect,
    disconnect,
    sendMessage,
  }
}
ts
import { defineStore } from 'pinia'
import { ref, watch } from 'vue'
import { useWebSocket } from '@/service/websocket'

export const useAiChatStore = defineStore('aiChat', () => {
  const { socket, isConnected, connect, disconnect, sendMessage } =
    useWebSocket('ai-chat')
  const messages = ref([
    {
      role: 'assistant',
      content: '您好!我是您的健康助手。有什么可以帮助您的吗?',
    },
  ])
  const isLoading = ref(false)

  const initWebSocket = () => {
    connect()
    if (socket.value) {
      socket.value.onmessage = event => {
        console.log('收到消息:', event.data)
        messages.value.push({ role: 'assistant', content: event.data })
        isLoading.value = false
      }
    }
  }

  const sendChatMessage = (message: string) => {
    if (message.trim() && !isLoading.value) {
      messages.value.push({ role: 'user', content: message })
      isLoading.value = true
      console.log('发送消息:', message)
      sendMessage(message)
    }
  }

  watch(isConnected, newValue => {
    console.log('WebSocket连接状态变化:', newValue)
    if (newValue) {
      console.log('WebSocket已连接')
    } else {
      console.log('WebSocket已断开')
    }
  })

  watch(
    () => socket.value,
    newSocket => {
      console.log('WebSocket实例变化:', newSocket)
      if (newSocket) {
        console.log('新的WebSocket实例已创建')
      } else {
        console.log('WebSocket实例已销毁')
      }
    },
  )

  return {
    messages,
    isLoading,
    isConnected,
    initWebSocket,
    sendChatMessage,
    disconnect,
  }
})