- RAG & SPIN Specialist improvements

This commit is contained in:
Josako
2025-04-22 13:49:38 +02:00
parent 4bf12db142
commit 9652d0bff9
12 changed files with 24 additions and 41 deletions

View File

@@ -4,6 +4,6 @@ from pydantic import BaseModel, Field
class RAGOutput(BaseModel):
answer: Optional[str] = Field(None, description="Answer to the questions asked")
citations: Optional[List[str]] = Field(None, description="A list of sources used in generating the answer")
insufficient_info: Optional[bool] = Field(None, description="An indication if there's insufficient information to answer")
answer: str = Field(None, description="Answer to the questions asked")
insufficient_info: bool = Field(None, description="An indication if there's insufficient information to answer")

View File

@@ -1,4 +1,4 @@
from typing import Dict, Any
from typing import Dict, Any, Optional
from flask import current_app
from pydantic import BaseModel, Field, model_validator
@@ -11,6 +11,7 @@ class RetrieverMetadata(BaseModel):
document_id: int = Field(..., description="ID of the source document")
version_id: int = Field(..., description="Version ID of the source document")
document_name: str = Field(..., description="Name of the source document")
url: Optional[str] = Field(..., description="URL of the source document")
user_metadata: Dict[str, Any] = Field(
default_factory=dict, # This will use an empty dict if None is provided
description="User-defined metadata"

View File

@@ -97,6 +97,7 @@ class StandardRAGRetriever(BaseRetriever):
query_obj = (
db.session.query(
db_class,
DocumentVersion.url,
(1 - db_class.embedding.cosine_distance(query_embedding)).label('similarity')
)
.join(DocumentVersion, db_class.doc_vers_id == DocumentVersion.id)
@@ -116,7 +117,7 @@ class StandardRAGRetriever(BaseRetriever):
# Transform results into standard format
processed_results = []
for doc, similarity in results:
for doc, url, similarity in results:
# Parse user_metadata to ensure it's a dictionary
user_metadata = self._parse_metadata(doc.document_version.user_metadata)
processed_results.append(
@@ -128,6 +129,7 @@ class StandardRAGRetriever(BaseRetriever):
document_id=doc.document_version.doc_id,
version_id=doc.document_version.id,
document_name=doc.document_version.document.name,
url=url or "",
user_metadata=user_metadata,
)
)

View File

@@ -65,7 +65,6 @@ class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
"language": arguments.language,
"query": arguments.query,
"context": formatted_context,
"citations": citations,
"history": self.formatted_history,
"name": self.specialist.configuration.get('name', ''),
"company": self.specialist.configuration.get('company', ''),

View File

@@ -124,7 +124,6 @@ class SpecialistExecutor(CrewAIBaseSpecialistExecutor):
"language": arguments.language,
"query": arguments.query,
"context": formatted_context,
"citations": citations,
"history": self.formatted_history,
"historic_spin": json.dumps(self.latest_spin, indent=2),
"historic_lead_info": json.dumps(self.latest_lead_info, indent=2),
@@ -217,7 +216,9 @@ class SPINFlow(EveAICrewAIFlow[SPINFlowState]):
async def execute_rag(self):
inputs = self.state.input.model_dump()
try:
current_app.logger.debug("In execute_rag")
crew_output = await self.rag_crew.kickoff_async(inputs=inputs)
current_app.logger.debug(f"Crew execution ended with output:\n{crew_output}")
self.specialist_executor.log_tuning("RAG Crew Output", crew_output.model_dump())
output_pydantic = crew_output.pydantic
if not output_pydantic:
@@ -276,13 +277,16 @@ class SPINFlow(EveAICrewAIFlow[SPINFlowState]):
if self.state.spin:
additional_questions = additional_questions + self.state.spin.questions
inputs["additional_questions"] = additional_questions
current_app.logger.debug(f"Prepared Answers: \n{inputs['prepared_answers']}")
current_app.logger.debug(f"Additional Questions: \n{additional_questions}")
try:
crew_output = await self.rag_consolidation_crew.kickoff_async(inputs=inputs)
current_app.logger.debug(f"Consolidation output after crew execution:\n{crew_output}")
self.specialist_executor.log_tuning("RAG Consolidation Crew Output", crew_output.model_dump())
output_pydantic = crew_output.pydantic
if not output_pydantic:
raw_json = json.loads(crew_output.raw)
output_pydantic = LeadInfoOutput.model_validate(raw_json)
output_pydantic = RAGOutput.model_validate(raw_json)
self.state.final_output = output_pydantic
return crew_output
except Exception as e:

View File

@@ -314,7 +314,10 @@ class CrewAIBaseSpecialistExecutor(BaseSpecialistExecutor):
])
# Return document_ids for citations
citations = [ctx.metadata.document_id for ctx in unique_contexts]
citations = [{"document_id": ctx.metadata.document_id,
"document_version_id": ctx.metadata.version_id,
"url": ctx.metadata.url}
for ctx in unique_contexts]
self.log_tuning("Context Retrieval Results",
{"Formatted Context": formatted_context,
@@ -345,6 +348,7 @@ class CrewAIBaseSpecialistExecutor(BaseSpecialistExecutor):
modified_result = {
"detailed_query": detailed_query,
"citations": citations,
}
final_result = result.model_copy(update=modified_result)