@Component
@Slf4j
public class ContextAwareQueryEnhancer {

    // Chinese pronouns are matched by plain substring (Chinese text has no word boundaries).
    // Shared by containsPronouns() and detectPronouns(); previously "其" was flagged by the
    // former but never returned by the latter, so the two checks could disagree.
    private static final Set<String> CHINESE_PRONOUNS =
            Set.of("它", "他", "她", "这个", "那个", "这些", "那些", "其");

    // English pronouns need word-boundary, case-insensitive matching: a bare substring test
    // treated "item" as containing the pronoun "it". Types are fully qualified because the
    // file's import block is outside this view.
    private static final java.util.regex.Pattern ENGLISH_PRONOUNS =
            java.util.regex.Pattern.compile("\\b(it|this|that|these|those|they|them)\\b",
                    java.util.regex.Pattern.CASE_INSENSITIVE);

    private final ChatClient chatClient;
    private final ConversationMemoryService memoryService;
    private final EntityResolutionService entityResolver;
    private final IntentTrackingService intentTracker;
    private final ObjectMapper objectMapper;

    /** Maximum number of history turns pulled into the enhancement context. */
    @Value("${spring.ai.context-enhancer.max-history-turns:10}")
    private int maxHistoryTurns;

    /** Toggles the LLM-backed analysis/candidate passes; rule/template passes always run. */
    @Value("${spring.ai.context-enhancer.enable-llm:true}")
    private boolean enableLlmEnhancement;

    /** Candidates below this confidence are dropped before ranking. */
    @Value("${spring.ai.context-enhancer.min-confidence:0.6}")
    private double minConfidence;

    public ContextAwareQueryEnhancer(ChatClient chatClient,
                                     ConversationMemoryService memoryService,
                                     EntityResolutionService entityResolver,
                                     IntentTrackingService intentTracker,
                                     ObjectMapper objectMapper) {
        this.chatClient = chatClient;
        this.memoryService = memoryService;
        this.entityResolver = entityResolver;
        this.intentTracker = intentTracker;
        this.objectMapper = objectMapper;
    }

    /**
     * Rewrites {@code currentQuery} into a context-complete query using the conversation
     * history (pronoun resolution, ellipsis completion, reference resolution, optional LLM
     * candidates). Never throws: any failure falls back to the original query with low
     * confidence.
     *
     * @param currentQuery   the raw user query for this turn
     * @param conversationId key used to load and update conversation memory
     * @return the best-ranked enhanced query, or a fallback wrapping the original
     */
    public EnhancedQuery enhanceQuery(String currentQuery, String conversationId) {
        long startTime = System.currentTimeMillis();
        try {
            log.info("开始上下文增强处理:{}", currentQuery);
            ConversationContext context = memoryService.getConversationContext(conversationId, maxHistoryTurns);
            QueryAnalysis analysis = analyzeQueryFeatures(currentQuery, context);
            List<EnhancementCandidate> candidates = generateEnhancementCandidates(currentQuery, context, analysis);
            EnhancedQuery enhancedQuery = selectBestEnhancement(candidates, currentQuery, context);
            // FIX: previously called an updateConversationMemory(...) not defined on this
            // class; ConversationMemoryService declares this exact signature, so delegate.
            memoryService.updateConversationMemory(conversationId, currentQuery, enhancedQuery, context);
            long duration = System.currentTimeMillis() - startTime;
            log.info("上下文增强完成:原始『{}』→ 增强『{}』, 耗时 {}ms", currentQuery, enhancedQuery.getEnhancedQuery(), duration);
            return enhancedQuery;
        } catch (Exception e) {
            log.error("上下文增强失败:{}", currentQuery, e);
            return createFallbackEnhancedQuery(currentQuery);
        }
    }

    /**
     * Collects surface features of the query (pronouns, ellipsis, references) plus
     * context-dependent signals; merges in an LLM analysis when enabled.
     */
    private QueryAnalysis analyzeQueryFeatures(String query, ConversationContext context) {
        QueryAnalysis analysis = new QueryAnalysis();
        analysis.setContainsPronouns(containsPronouns(query));
        analysis.setPronouns(detectPronouns(query));
        analysis.setContainsEllipsis(containsEllipsis(query));
        analysis.setContainsReferences(containsReferences(query));
        analysis.setIntentContinuation(analyzeIntentContinuation(query, context));
        analysis.setEntityConsistency(analyzeEntityConsistency(query, context));
        if (enableLlmEnhancement) {
            QueryAnalysis llmAnalysis = analyzeWithLLM(query, context);
            analysis.merge(llmAnalysis);
        }
        return analysis;
    }

    /**
     * Gathers candidates from rule-based, template-based, and (optionally) LLM generators,
     * then filters by {@link #minConfidence} and sorts best-first.
     */
    private List<EnhancementCandidate> generateEnhancementCandidates(String query, ConversationContext context, QueryAnalysis analysis) {
        List<EnhancementCandidate> candidates = new ArrayList<>();
        candidates.addAll(generateRuleBasedEnhancements(query, context, analysis));
        candidates.addAll(generateTemplateBasedEnhancements(query, context, analysis));
        if (enableLlmEnhancement) {
            candidates.addAll(generateLlmBasedEnhancements(query, context, analysis));
        }
        return candidates.stream()
                .filter(candidate -> candidate.getConfidence() >= minConfidence)
                .sorted(Comparator.comparing(EnhancementCandidate::getConfidence).reversed())
                .collect(Collectors.toList());
    }

    /** Dispatches to the rule-based resolvers that match the detected query features. */
    private List<EnhancementCandidate> generateRuleBasedEnhancements(String query, ConversationContext context, QueryAnalysis analysis) {
        List<EnhancementCandidate> candidates = new ArrayList<>();
        if (analysis.isContainsPronouns()) {
            candidates.addAll(resolvePronouns(query, context, analysis));
        }
        if (analysis.isContainsEllipsis()) {
            candidates.addAll(completeEllipsis(query, context, analysis));
        }
        if (analysis.isContainsReferences()) {
            candidates.addAll(resolveReferences(query, context, analysis));
        }
        return candidates;
    }

    /**
     * Produces one candidate per (pronoun, possible referent) pair by substituting the
     * referent's entity name for the pronoun. Candidate confidence is the referent
     * confidence discounted by 0.9.
     */
    private List<EnhancementCandidate> resolvePronouns(String query, ConversationContext context, QueryAnalysis analysis) {
        List<EnhancementCandidate> candidates = new ArrayList<>();
        for (String pronoun : analysis.getPronouns()) {
            List<EntityReference> possibleReferents = findPossibleReferents(pronoun, context);
            for (EntityReference referent : possibleReferents) {
                // detectPronouns() returns the pronoun exactly as it appears in the query,
                // so replace() substitutes every occurrence.
                String enhancedQuery = query.replace(pronoun, referent.getEntityName());
                List<ContextReference> references = List.of(new ContextReference("pronoun", pronoun, referent.getEntityName(), referent.getTurnOffset(), referent.getConfidence()));
                EnhancementCandidate candidate = EnhancementCandidate.builder()
                        .enhancedQuery(enhancedQuery)
                        .enhancementType(EnhancedQuery.EnhancementType.PRONOUN_RESOLUTION)
                        .confidence(referent.getConfidence() * 0.9)
                        .contextReferences(references)
                        .reasoning(String.format("代词『%s』指代『%s』", pronoun, referent.getEntityName()))
                        .build();
                candidates.add(candidate);
            }
        }
        return candidates;
    }

    /**
     * Completes an elliptical query by prefixing it with the intent extracted from the most
     * recent substantive turn. Emits at most one candidate (fixed confidence 0.7).
     */
    private List<EnhancementCandidate> completeEllipsis(String query, ConversationContext context, QueryAnalysis analysis) {
        List<EnhancementCandidate> candidates = new ArrayList<>();
        Optional<ConversationTurn> lastCompleteTurn = findLastCompleteTurn(context);
        if (lastCompleteTurn.isPresent()) {
            ConversationTurn lastTurn = lastCompleteTurn.get();
            String lastIntent = intentTracker.extractIntent(lastTurn.getUserMessage());
            String enhancedQuery = lastIntent + " " + query;
            List<ContextReference> references = List.of(new ContextReference("ellipsis", query, enhancedQuery, -1, 0.8));
            EnhancementCandidate candidate = EnhancementCandidate.builder()
                    .enhancedQuery(enhancedQuery)
                    .enhancementType(EnhancedQuery.EnhancementType.ELLIPSIS_COMPLETION)
                    .confidence(0.7)
                    .contextReferences(references)
                    .reasoning("基于上一轮意图补全省略查询")
                    .build();
            candidates.add(candidate);
        }
        return candidates;
    }

    /**
     * Asks the chat model for enhancement candidates. Best-effort: any failure is logged
     * and an empty list is returned so the rule/template candidates still apply.
     */
    private List<EnhancementCandidate> generateLlmBasedEnhancements(String query, ConversationContext context, QueryAnalysis analysis) {
        try {
            String prompt = buildLlmEnhancementPrompt(query, context, analysis);
            String response = chatClient.prompt().user(prompt).call().content();
            return parseLlmEnhancementResponse(response, query, context);
        } catch (Exception e) {
            log.warn("LLM 增强生成失败", e);
            return List.of();
        }
    }

    /**
     * Builds the Chinese-language prompt: up to 5 recent turns, the current query, and a
     * JSON response schema. NOTE(review): recent turns are stored most-recent-first (see
     * ConversationMemoryService), so index [0] here is the newest turn — confirm the model
     * is expected to read history in that order.
     */
    private String buildLlmEnhancementPrompt(String query, ConversationContext context, QueryAnalysis analysis) {
        StringBuilder prompt = new StringBuilder();
        prompt.append("你是一个专业的对话上下文理解助手。请根据对话历史,完善当前查询。\n\n");
        prompt.append("对话历史:\n");
        List<ConversationTurn> recentTurns = context.getRecentTurns().stream().limit(5).collect(Collectors.toList());
        for (int i = 0; i < recentTurns.size(); i++) {
            ConversationTurn turn = recentTurns.get(i);
            prompt.append(String.format("用户 [%d]: %s\n", i, turn.getUserMessage()));
            if (turn.getAssistantResponse() != null) {
                prompt.append(String.format("助手 [%d]: %s\n", i, turn.getAssistantResponse()));
            }
        }
        prompt.append("\n当前查询:").append(query).append("\n\n");
        prompt.append("""
请生成 3 个上下文增强版本,使其更完整、明确。考虑:
1. 代词指代消解(它、这个、那个等)
2. 省略内容补全
3. 意图继承和延续
4. 实体一致性维护
按以下 JSON 格式返回:
{ "enhanced_queries": [ { "enhanced_query": "增强后的完整查询", "enhancement_type": "pronoun_resolution|ellipsis_completion|intent_inheritance|temporal_context", "confidence": 0.0-1.0, "reasoning": "增强的理由和依据", "resolved_references": [ { "type": "pronoun|entity|intent", "original": "原始指代", "resolved": "解析结果", "source_turn": -1 } ] } ] }
""");
        return prompt.toString();
    }

    /** Picks the highest-confidence candidate, or a fallback when no candidate survived filtering. */
    private EnhancedQuery selectBestEnhancement(List<EnhancementCandidate> candidates, String originalQuery, ConversationContext context) {
        if (candidates.isEmpty()) {
            return createFallbackEnhancedQuery(originalQuery);
        }
        EnhancementCandidate bestCandidate = candidates.stream().max(Comparator.comparingDouble(EnhancementCandidate::getConfidence)).orElse(candidates.get(0));
        return convertToEnhancedQuery(bestCandidate, originalQuery);
    }

    /**
     * True when the query contains a Chinese pronoun (substring) or an English pronoun
     * (whole word, any case). FIX: the old substring test for English matched inside
     * other words, e.g. "it" in "item".
     */
    private boolean containsPronouns(String query) {
        return CHINESE_PRONOUNS.stream().anyMatch(query::contains)
                || ENGLISH_PRONOUNS.matcher(query).find();
    }

    /**
     * Returns every distinct pronoun present in the query, exactly as it appears there.
     * FIX: English detection was case-sensitive (missing "It"), omitted "其", and returned
     * canonical lowercase forms that a later {@code query.replace(pronoun, ...)} could not
     * match against capitalized text; now the matched surface form is returned and the
     * pronoun set is shared with {@link #containsPronouns(String)}.
     */
    private List<String> detectPronouns(String query) {
        List<String> pronouns = new ArrayList<>();
        CHINESE_PRONOUNS.stream().filter(query::contains).forEach(pronouns::add);
        java.util.regex.Matcher matcher = ENGLISH_PRONOUNS.matcher(query);
        while (matcher.find()) {
            if (!pronouns.contains(matcher.group())) {
                pronouns.add(matcher.group());
            }
        }
        return pronouns;
    }

    /**
     * Heuristic ellipsis test: very short queries, queries with trailing/continuation
     * markers, or queries with no CJK characters at all.
     * NOTE(review): the no-CJK branch marks every pure-English query as elliptical —
     * confirm that is intended.
     */
    private boolean containsEllipsis(String query) {
        if (query.length() < 4) return true;
        Set<String> ellipsisIndicators = Set.of("呢", "吗", "详细", "具体", "其他", "还有", "怎么样");
        return ellipsisIndicators.stream().anyMatch(query::contains) || !query.matches(".*[\\u4e00-\\u9fff]+.*");
    }

    /** True when the query uses a discourse-reference word ("前者", "刚才", ...). */
    private boolean containsReferences(String query) {
        Set<String> referenceWords = Set.of("前者", "后者", "刚才", "之前", "上面", "下面");
        return referenceWords.stream().anyMatch(query::contains);
    }

    /**
     * Scans the recent turns (most recent first) for entities that could be the pronoun's
     * referent; keeps the top 3 above {@link #minConfidence}. Turn offsets are encoded as
     * negative positions (-1 = most recent turn).
     */
    private List<EntityReference> findPossibleReferents(String pronoun, ConversationContext context) {
        List<EntityReference> referents = new ArrayList<>();
        for (int i = 0; i < context.getRecentTurns().size(); i++) {
            ConversationTurn turn = context.getRecentTurns().get(i);
            List<Entity> entities = entityResolver.extractEntities(turn.getUserMessage());
            for (Entity entity : entities) {
                double confidence = calculateReferentConfidence(pronoun, entity, i);
                if (confidence > minConfidence) {
                    referents.add(new EntityReference(entity.getName(), entity.getType(), -i - 1, confidence));
                }
            }
        }
        return referents.stream().sorted(Comparator.comparing(EntityReference::getConfidence).reversed()).limit(3).collect(Collectors.toList());
    }

    /**
     * Scores a candidate referent: base 0.7, penalized 20% per turn of distance (floored
     * at 0.1), plus a 0.2 bonus for MAIN-typed entities.
     */
    private double calculateReferentConfidence(String pronoun, Entity entity, int turnOffset) {
        double baseConfidence = 0.7;
        double distancePenalty = Math.max(0.1, 1.0 - (turnOffset * 0.2));
        // FIX: Yoda-style equals avoids an NPE when the entity type is absent.
        double typeBonus = "MAIN".equals(entity.getType()) ? 0.2 : 0.0;
        return baseConfidence * distancePenalty + typeBonus;
    }

    /**
     * Most recent turn whose user message is long enough (> 5 chars) to carry an intent;
     * turns are stored most-recent-first, so findFirst() is the newest match.
     */
    private Optional<ConversationTurn> findLastCompleteTurn(ConversationContext context) {
        return context.getRecentTurns().stream().filter(turn -> turn.getUserMessage().length() > 5)
                .findFirst();
    }

    /** Maps the winning candidate onto the public EnhancedQuery result type. */
    private EnhancedQuery convertToEnhancedQuery(EnhancementCandidate candidate, String originalQuery) {
        return EnhancedQuery.builder()
                .originalQuery(originalQuery)
                .enhancedQuery(candidate.getEnhancedQuery())
                .contextReferences(candidate.getContextReferences())
                .enhancementType(candidate.getEnhancementType())
                .confidence(candidate.getConfidence())
                .enhancementDetails(Map.of("reasoning", candidate.getReasoning(), "generation_method", candidate.getGenerationMethod()))
                .build();
    }

    /** Identity fallback: returns the original query unchanged with near-zero confidence. */
    private EnhancedQuery createFallbackEnhancedQuery(String originalQuery) {
        return EnhancedQuery.builder()
                .originalQuery(originalQuery)
                .enhancedQuery(originalQuery)
                .contextReferences(List.of())
                .enhancementType(EnhancedQuery.EnhancementType.ENTITY_CONSISTENCY)
                .confidence(0.1)
                .enhancementDetails(Map.of("fallback", true))
                .build();
    }
}
@Component
@Slf4j
public class ConversationMemoryService {

    /** Hard cap on how many turns an in-memory context retains (newest first). */
    private static final int MAX_RETAINED_TURNS = 20;

    private final ConversationRepository conversationRepository;
    private final EntityResolutionService entityResolver;
    private final IntentTrackingService intentTracker;

    // Hot conversations keyed by conversation id; entries evicted 30 minutes after last access.
    private final Cache<String, ConversationContext> activeConversations;

    public ConversationMemoryService(ConversationRepository conversationRepository,
                                     EntityResolutionService entityResolver,
                                     IntentTrackingService intentTracker) {
        this.conversationRepository = conversationRepository;
        this.entityResolver = entityResolver;
        this.intentTracker = intentTracker;
        this.activeConversations = Caffeine.newBuilder()
                .maximumSize(1000)
                .expireAfterAccess(30, TimeUnit.MINUTES)
                .build();
    }

    /**
     * Returns the cached context for a conversation, loading the most recent turns from
     * the repository on a cache miss.
     *
     * NOTE(review): a cache hit ignores {@code maxTurns} — a context cached with a
     * different window is returned as-is. Confirm all callers pass the same maxTurns,
     * or key the cache on (conversationId, maxTurns).
     */
    public ConversationContext getConversationContext(String conversationId, int maxTurns) {
        ConversationContext cached = activeConversations.getIfPresent(conversationId);
        if (cached != null) {
            return cached;
        }
        List<ConversationTurn> turns = conversationRepository.findRecentTurns(conversationId, maxTurns);
        ConversationContext context = buildConversationContext(turns);
        activeConversations.put(conversationId, context);
        return context;
    }

    /**
     * Persists the new turn and prepends it to the cached context (turns are kept
     * most-recent-first), trimming the context to {@link #MAX_RETAINED_TURNS}.
     *
     * NOTE(review): the cached ConversationContext is mutated without synchronization;
     * concurrent requests on the same conversation could race — confirm single-writer
     * usage per conversation.
     */
    public void updateConversationMemory(String conversationId, String userQuery, EnhancedQuery enhancedQuery, ConversationContext context) {
        List<Entity> entities = entityResolver.extractEntities(userQuery);
        String intent = intentTracker.extractIntent(userQuery);
        ConversationTurn newTurn = ConversationTurn.builder()
                .conversationId(conversationId)
                .userMessage(userQuery)
                .enhancedQuery(enhancedQuery.getEnhancedQuery())
                .entities(entities)
                .intent(intent)
                .timestamp(Instant.now())
                .build();
        conversationRepository.saveTurn(newTurn);
        context.getRecentTurns().add(0, newTurn);
        if (context.getRecentTurns().size() > MAX_RETAINED_TURNS) {
            // FIX: copy the trimmed window. subList() is only a view, so the old code kept
            // the entire backing list reachable and left the context holding a fragile view
            // that throws ConcurrentModificationException if the backing list is modified.
            context.setRecentTurns(new ArrayList<>(context.getRecentTurns().subList(0, MAX_RETAINED_TURNS)));
        }
        activeConversations.put(conversationId, context);
    }

    /** Builds a fresh context: mutable turn list, distinct entities across turns, top topics. */
    private ConversationContext buildConversationContext(List<ConversationTurn> turns) {
        ConversationContext context = new ConversationContext();
        context.setRecentTurns(new ArrayList<>(turns));
        // assumes turn.getEntities() is never null — TODO confirm repository contract
        Set<Entity> contextEntities = turns.stream()
                .flatMap(turn -> turn.getEntities().stream())
                .collect(Collectors.toSet());
        context.setContextEntities(new ArrayList<>(contextEntities));
        context.setTopics(analyzeConversationTopics(turns));
        return context;
    }

    /**
     * Ranks the five most frequent topic strings across the turns: names of TOPIC-typed
     * entities plus words (length > 1) split out of each turn's intent text.
     */
    private List<String> analyzeConversationTopics(List<ConversationTurn> turns) {
        Map<String, Integer> topicFrequency = new HashMap<>();
        for (ConversationTurn turn : turns) {
            for (Entity entity : turn.getEntities()) {
                if ("TOPIC".equals(entity.getType())) {   // Yoda equals: null-safe on type
                    topicFrequency.merge(entity.getName(), 1, Integer::sum);
                }
            }
            if (turn.getIntent() != null) {
                for (String word : turn.getIntent().split("\\s+")) {
                    if (word.length() > 1) {
                        topicFrequency.merge(word, 1, Integer::sum);
                    }
                }
            }
        }
        return topicFrequency.entrySet().stream()
                .sorted(Map.Entry.<String, Integer>comparingByValue().reversed())
                .limit(5)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }
}