diff --git a/Dockerfile b/Dockerfile
index af544a8..a4adc64 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -12,6 +12,7 @@ COPY . .
 RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o history-api ./cmd/api
 RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o email-worker ./cmd/worker/email
 RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o storage-worker ./cmd/worker/storage
+RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o rag-worker ./cmd/worker/rag
 
 FROM alpine:latest
 
@@ -23,9 +24,10 @@ WORKDIR /app
 COPY --from=builder /app/history-api .
 COPY --from=builder /app/email-worker .
 COPY --from=builder /app/storage-worker .
+COPY --from=builder /app/rag-worker .
 COPY data ./data
 
-RUN chmod +x ./history-api ./email-worker ./storage-worker
+RUN chmod +x ./history-api ./email-worker ./storage-worker ./rag-worker
 
 EXPOSE 3344
diff --git a/cmd/api/server.go b/cmd/api/server.go
index 0588524..9e5963f 100644
--- a/cmd/api/server.go
+++ b/cmd/api/server.go
@@ -95,6 +95,7 @@ func (s *FiberServer) SetupServer(
 	submissionRepo := repositories.NewSubmissionRepository(poolPg, redis)
 	raguRepo := repositories.NewRagRepository(poolPg, redis)
+	usageRepo := repositories.NewUsageRepository(redis)
 
 	// service setup
 	authService := services.NewAuthService(userRepo, roleRepo, tokenRepo, redis, poolPg)
@@ -114,7 +115,7 @@ func (s *FiberServer) SetupServer(
 		userRepo, wikiRepo, geometryRepo, entityRepo, raguRepo, raguUtils, poolPg, redis,
 	)
-	chatbotService := services.NewChatbotService(raguRepo, raguUtils)
+	chatbotService := services.NewChatbotService(raguRepo, usageRepo, raguUtils)
 
 	// controller setup
 	authController := controllers.NewAuthController(authService, oauth)
diff --git a/cmd/worker/rag/main.go b/cmd/worker/rag/main.go
new file mode 100644
index 0000000..7427b5e
--- /dev/null
+++ b/cmd/worker/rag/main.go
@@ -0,0 +1,268 @@
+package main
+
+import (
+	"context"
+	"encoding/json"
+	"math"
+	"strconv"
+	"sync"
+	"time"
+
+	"history-api/internal/models"
+	"history-api/internal/repositories"
+	"history-api/pkg/ai"
+	"history-api/pkg/cache"
+	"history-api/pkg/config"
+	"history-api/pkg/constants"
+	"history-api/pkg/database"
+	_ "history-api/pkg/log"
+
+	"github.com/redis/go-redis/v9"
+	"github.com/rs/zerolog/log"
+)
+
+const (
+	maxRetries     = 3
+	baseRetryDelay = 2 * time.Second
+	itemDelay      = 500 * time.Millisecond
+)
+
+func processRagTask(ctx context.Context, ragRepo repositories.RagRepository, ragUtils *ai.RagUtils, task *models.RagIndexTask, workerName string) {
+	if len(task.DeleteWikiIDs) > 0 {
+		if err := ragRepo.DeleteBySourceIDs(ctx, "wiki", task.DeleteWikiIDs); err != nil {
+			log.Error().Err(err).Str("worker", workerName).Msg("Failed to delete wiki RAG chunks")
+		}
+	}
+
+	if len(task.DeleteEntityIDs) > 0 {
+		if err := ragRepo.DeleteBySourceIDs(ctx, "entity", task.DeleteEntityIDs); err != nil {
+			log.Error().Err(err).Str("worker", workerName).Msg("Failed to delete entity RAG chunks")
+		}
+	}
+
+	// 2. Index wikis with delay + retry
+	for _, wiki := range task.Wikis {
+		if wiki.Source != "inline" {
+			continue
+		}
+
+		log.Info().Str("worker", workerName).Str("wiki_id", wiki.ID).Msg("Indexing wiki")
+
+		cleanText := ragUtils.StripHTML(wiki.Title + "\n" + wiki.Doc)
+
+		var chunks []string
+		var vectors [][]float32
+		var err error
+
+		for attempt := 0; attempt <= maxRetries; attempt++ {
+			if attempt > 0 {
+				delay := baseRetryDelay * time.Duration(math.Pow(2, float64(attempt-1)))
+				log.Warn().
+					Str("worker", workerName).
+					Str("wiki_id", wiki.ID).
+					Int("attempt", attempt).
+					Dur("delay", delay).
+					Msg("Retrying wiki embedding")
+				time.Sleep(delay)
+			}
+
+			chunks, vectors, err = ragUtils.PrepareChunks(ctx, cleanText)
+			if err == nil {
+				break
+			}
+
+			log.Error().Err(err).
+				Str("worker", workerName).
+				Str("wiki_id", wiki.ID).
+				Int("attempt", attempt).
+				Msg("Failed to prepare wiki chunks")
+		}
+
+		if err != nil {
+			log.Error().Err(err).Str("worker", workerName).Str("wiki_id", wiki.ID).Msg("Giving up on wiki after max retries")
+			continue
+		}
+
+		// Delete existing chunks then save new ones
+		_ = ragRepo.DeleteBySourceIDs(ctx, "wiki", []string{wiki.ID})
+		for i, chunk := range chunks {
+			if saveErr := ragRepo.SaveChunk(ctx, "wiki", wiki.ID, task.ProjectID, i, chunk, vectors[i]); saveErr != nil {
+				log.Error().Err(saveErr).Str("wiki_id", wiki.ID).Int("chunk", i).Msg("Failed to save wiki chunk")
+			}
+		}
+
+		log.Info().Str("worker", workerName).Str("wiki_id", wiki.ID).Int("chunks", len(chunks)).Msg("Wiki indexed successfully")
+
+		// Delay between items to avoid rate limit
+		time.Sleep(itemDelay)
+	}
+
+	// 3. Index entities with delay + retry
+	for _, entity := range task.Entities {
+		if entity.Source != "inline" {
+			continue
+		}
+
+		log.Info().Str("worker", workerName).Str("entity_id", entity.ID).Msg("Indexing entity")
+
+		cleanText := ragUtils.StripHTML(entity.Name + "\n" + entity.Description)
+
+		var chunks []string
+		var vectors [][]float32
+		var err error
+
+		for attempt := 0; attempt <= maxRetries; attempt++ {
+			if attempt > 0 {
+				delay := baseRetryDelay * time.Duration(math.Pow(2, float64(attempt-1)))
+				log.Warn().
+					Str("worker", workerName).
+					Str("entity_id", entity.ID).
+					Int("attempt", attempt).
+					Dur("delay", delay).
+					Msg("Retrying entity embedding")
+				time.Sleep(delay)
+			}
+
+			chunks, vectors, err = ragUtils.PrepareChunks(ctx, cleanText)
+			if err == nil {
+				break
+			}
+
+			log.Error().Err(err).
+				Str("worker", workerName).
+				Str("entity_id", entity.ID).
+				Int("attempt", attempt).
+				Msg("Failed to prepare entity chunks")
+		}
+
+		if err != nil {
+			log.Error().Err(err).Str("worker", workerName).Str("entity_id", entity.ID).Msg("Giving up on entity after max retries")
+			continue
+		}
+
+		// Delete existing chunks then save new ones
+		_ = ragRepo.DeleteBySourceIDs(ctx, "entity", []string{entity.ID})
+		for i, chunk := range chunks {
+			if saveErr := ragRepo.SaveChunk(ctx, "entity", entity.ID, task.ProjectID, i, chunk, vectors[i]); saveErr != nil {
+				log.Error().Err(saveErr).Str("entity_id", entity.ID).Int("chunk", i).Msg("Failed to save entity chunk")
+			}
+		}
+
+		log.Info().Str("worker", workerName).Str("entity_id", entity.ID).Int("chunks", len(chunks)).Msg("Entity indexed successfully")
+		time.Sleep(itemDelay)
+	}
+}
+
+func runSingleWorker(ctx context.Context, rdb *redis.Client, consumerID int, ragRepo repositories.RagRepository, ragUtils *ai.RagUtils) {
+	consumerName := "worker-" + strconv.Itoa(consumerID)
+
+	log.Info().Str("worker", consumerName).Msg("RAG worker started and ready")
+
+	for {
+		entries, err := rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
+			Group:    constants.GroupRagName,
+			Consumer: consumerName,
+			Streams:  []string{constants.StreamRagName, ">"},
+			Count:    1,
+			Block:    0,
+		}).Result()
+
+		if err != nil {
+			log.Error().Err(err).Str("worker", consumerName).Msg("Failed to read stream")
+			time.Sleep(2 * time.Second)
+			continue
+		}
+
+		for _, stream := range entries {
+			for _, message := range stream.Messages {
+				taskType, ok1 := message.Values["task_type"].(string)
+				payloadStr, ok2 := message.Values["payload"].(string)
+				if !ok1 || !ok2 {
+					log.Error().Msg("Invalid message format")
+					rdb.XAck(ctx, constants.StreamRagName, constants.GroupRagName, message.ID)
+					continue
+				}
+
+				if taskType == constants.TaskTypeRagIndexSubmission.String() {
+					var task models.RagIndexTask
+					if err := json.Unmarshal([]byte(payloadStr), &task); err != nil {
+						log.Error().Err(err).Msg("Failed to unmarshal RAG task payload")
+						rdb.XAck(ctx, constants.StreamRagName, constants.GroupRagName, message.ID)
+						continue
+					}
+
+					log.Info().
+						Str("worker", consumerName).
+						Str("project_id", task.ProjectID).
+						Int("wikis", len(task.Wikis)).
+						Int("entities", len(task.Entities)).
+						Msg("Processing RAG index task")
+
+					processRagTask(ctx, ragRepo, ragUtils, &task, consumerName)
+				}
+
+				rdb.XAck(ctx, constants.StreamRagName, constants.GroupRagName, message.ID)
+				log.Info().Str("msg_id", message.ID).Msg("Task acknowledged")
+			}
+		}
+	}
+}
+
+func main() {
+	config.LoadEnv()
+
+	workerCountStr := config.GetConfigWithDefault("RAG_WORKER_COUNT", "1")
+	workerCount, err := strconv.Atoi(workerCountStr)
+	if err != nil || workerCount <= 0 {
+		workerCount = 1
+	}
+
+	cacheInterface, err := cache.NewRedisClient()
+	if err != nil {
+		log.Fatal().
+			Err(err).
+			Msg("Failed to connect to Redis")
+	}
+
+	rdb := cacheInterface.GetRawClient()
+
+	poolPg, err := database.NewPostgresqlDB()
+	if err != nil {
+		log.Fatal().
+			Err(err).
+			Msg("Failed to connect to PostgreSQL")
+	}
+	defer poolPg.Close()
+
+	ragUtils, err := ai.NewRagUtils()
+	if err != nil {
+		log.Fatal().
+			Err(err).
+			Msg("Failed to initialize RAG utils")
+	}
+
+	ragRepo := repositories.NewRagRepository(poolPg, cacheInterface)
+
+	ctx := context.Background()
+
+	err = rdb.XGroupCreateMkStream(ctx, constants.StreamRagName, constants.GroupRagName, "$").Err()
+	if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" {
+		log.Fatal().
+			Err(err).
+			Msg("Failed to create Redis Stream Group")
+	}
+
+	log.Info().
+		Int("worker_count", workerCount).
+		Msg("Starting RAG worker system")
+
+	var wg sync.WaitGroup
+	wg.Add(workerCount)
+	for i := 1; i <= workerCount; i++ {
+		go func(id int) {
+			defer wg.Done()
+			runSingleWorker(ctx, rdb, id, ragRepo, ragUtils)
+		}(i)
+	}
+	wg.Wait()
+}
diff --git a/docker-compose.yml b/docker-compose.yml
index 1a3a3f6..858fe35 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -109,6 +109,21 @@ services:
     command: ["./storage-worker"]
     networks:
       - history-api-project
+
+  history_rag_worker:
+    build: .
+    container_name: history_rag_worker
+    restart: unless-stopped
+    depends_on:
+      history_db:
+        condition: service_healthy
+      history_cache:
+        condition: service_started
+    env_file:
+      - ./assets/resources/.env
+    command: ["./rag-worker"]
+    networks:
+      - history-api-project
 
 volumes:
   history_db_data:
diff --git a/internal/controllers/chatbotController.go b/internal/controllers/chatbotController.go
index 1796bfc..0548f69 100644
--- a/internal/controllers/chatbotController.go
+++ b/internal/controllers/chatbotController.go
@@ -45,8 +45,18 @@ func (cx *ChatbotController) Chat(c fiber.Ctx) error {
 		})
 	}
 
-	answer, err := cx.chatbotService.Chat(ctx, dto.ProjectID, dto.Question)
+	claims := c.Locals("user").(*response.JWTClaims)
+
+	answer, err := cx.chatbotService.Chat(ctx, claims.UId, dto.ProjectID, dto.Question)
 	if err != nil {
+		// Return 429 (Too Many Requests) when the user's daily quota is exhausted
+		if err.Error() == "you have reached your daily limit of 10 questions. Please come back tomorrow" {
+			return c.Status(fiber.StatusTooManyRequests).JSON(response.CommonResponse{
+				Status:  false,
+				Message: err.Error(),
+			})
+		}
+
 		return c.Status(fiber.StatusInternalServerError).JSON(response.CommonResponse{
 			Status:  false,
 			Message: err.Error(),
diff --git a/internal/controllers/submissionController.go b/internal/controllers/submissionController.go
index 1ab20b5..e300eab 100644
--- a/internal/controllers/submissionController.go
+++ b/internal/controllers/submissionController.go
@@ -78,7 +78,7 @@ func (s *submissionController) CreateSubmission(c fiber.Ctx) error {
 // @Security BearerAuth
 // @Router /submissions/{id}/status [patch]
 func (s *submissionController) UpdateSubmissionStatus(c fiber.Ctx) error {
-	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
 	defer cancel()
 	id := c.Params("id")
 	uid := c.Locals("uid").(string)
diff --git a/internal/models/rag.go b/internal/models/rag.go
index 5401b74..87d50c1 100644
--- a/internal/models/rag.go
+++ b/internal/models/rag.go
@@ -15,3 +15,25 @@ type RagChunk struct {
 	CreatedAt time.Time `json:"created_at"`
 	UpdatedAt time.Time `json:"updated_at"`
 }
+
+type RagIndexTask struct {
+	ProjectID       string           `json:"project_id"`
+	DeleteWikiIDs   []string         `json:"delete_wiki_ids"`
+	DeleteEntityIDs []string         `json:"delete_entity_ids"`
+	Wikis           []*RagWikiItem   `json:"wikis"`
+	Entities        []*RagEntityItem `json:"entities"`
+}
+
+type RagWikiItem struct {
+	ID     string `json:"id"`
+	Title  string `json:"title"`
+	Doc    string `json:"doc"`
+	Source string `json:"source"`
+}
+
+type RagEntityItem struct {
+	ID          string `json:"id"`
+	Name        string `json:"name"`
+	Description string `json:"description"`
+	Source      string `json:"source"`
+}
diff --git a/internal/repositories/usageRepository.go b/internal/repositories/usageRepository.go
new file mode 100644
index 0000000..2688e4e
--- /dev/null
+++ b/internal/repositories/usageRepository.go
@@ -0,0 +1,55 @@
+package repositories
+
+import (
+	"context"
+	"fmt"
+	"history-api/pkg/cache"
+	"history-api/pkg/constants"
+	"time"
+)
+
+type UsageRepository interface {
+	GetAIUsage(ctx context.Context, userID string) (int, error)
+	IncrementAIUsage(ctx context.Context, userID string) (int, error)
+}
+
+type usageRepository struct {
+	c cache.Cache
+}
+
+func NewUsageRepository(c cache.Cache) UsageRepository {
+	return &usageRepository{
+		c: c,
+	}
+}
+
+func (r *usageRepository) getUsageKey(userID string) string {
+	dateStr := time.Now().Format("20060102")
+	return fmt.Sprintf("usage:ai:%s:%s", userID, dateStr)
+}
+
+func (r *usageRepository) GetAIUsage(ctx context.Context, userID string) (int, error) {
+	key := r.getUsageKey(userID)
+	var count int
+	err := r.c.Get(ctx, key, &count)
+	if err != nil {
+		return 0, nil
+	}
+	return count, nil
+}
+
+func (r *usageRepository) IncrementAIUsage(ctx context.Context, userID string) (int, error) {
+	key := r.getUsageKey(userID)
+	rdb := r.c.GetRawClient()
+
+	count, err := rdb.Incr(ctx, key).Result()
+	if err != nil {
+		return 0, err
+	}
+
+	if count == 1 {
+		rdb.Expire(ctx, key, constants.UsageExpiration)
+	}
+
+	return int(count), nil
+}
diff --git a/internal/services/chatbotService.go b/internal/services/chatbotService.go
index 6f45d51..e6a3754 100644
--- a/internal/services/chatbotService.go
+++ b/internal/services/chatbotService.go
@@ -2,28 +2,41 @@ package services
 
 import (
 	"context"
+	"errors"
 	"fmt"
 
 	"history-api/internal/repositories"
 	"history-api/pkg/ai"
+	"history-api/pkg/constants"
 )
 
 type ChatbotService interface {
-	Chat(ctx context.Context, projectID *string, question string) (string, error)
+	Chat(ctx context.Context, userID string, projectID *string, question string) (string, error)
 }
 
 type chatbotService struct {
-	repo     repositories.RagRepository
-	ragUtils *ai.RagUtils
+	repo      repositories.RagRepository
+	usageRepo repositories.UsageRepository
+	ragUtils  *ai.RagUtils
 }
 
-func NewChatbotService(repo repositories.RagRepository, ragUtils *ai.RagUtils) ChatbotService {
+func NewChatbotService(repo repositories.RagRepository, usageRepo repositories.UsageRepository, ragUtils *ai.RagUtils) ChatbotService {
 	return &chatbotService{
-		repo:     repo,
-		ragUtils: ragUtils,
+		repo:      repo,
+		usageRepo: usageRepo,
+		ragUtils:  ragUtils,
 	}
 }
 
-func (s *chatbotService) Chat(ctx context.Context, projectID *string, question string) (string, error) {
+func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *string, question string) (string, error) {
+	usage, err := s.usageRepo.GetAIUsage(ctx, userID)
+	if err != nil {
+		return "", fmt.Errorf("failed to check usage: %w", err)
+	}
+
+	if usage >= constants.MaxDailyAIUsage {
+		return "", errors.New("you have reached your daily limit of 10 questions. Please come back tomorrow")
+	}
+
 	qVector, err := s.ragUtils.EmbedQuery(ctx, question)
 	if err != nil {
 		return "", fmt.Errorf("failed to embed question: %w", err)
@@ -61,5 +74,13 @@ Context:
 
 Question: %s`, contextStr, question)
 	}
 
-	return s.ragUtils.GenerateResponse(ctx, prompt)
+	response, err := s.ragUtils.GenerateResponse(ctx, prompt)
+	if err != nil {
+		return "", err
+	}
+
+	// Increment usage only after the AI call succeeds
+	_, _ = s.usageRepo.IncrementAIUsage(ctx, userID)
+
+	return response, nil
 }
diff --git a/internal/services/submissionService.go b/internal/services/submissionService.go
index 0a3b666..4a3e702 100644
--- a/internal/services/submissionService.go
+++ b/internal/services/submissionService.go
@@ -15,12 +15,13 @@ import (
 	"history-api/pkg/cache"
 	"history-api/pkg/constants"
 	"history-api/pkg/convert"
-	"strconv"
 	"slices"
+	"strconv"
 
 	"github.com/gofiber/fiber/v3"
 	"github.com/jackc/pgx/v5/pgtype"
 	"github.com/jackc/pgx/v5/pgxpool"
+	"github.com/rs/zerolog/log"
 	"golang.org/x/sync/errgroup"
 )
 
@@ -166,7 +167,6 @@ func (s *submissionService) UpdateSubmissionStatus(ctx context.Context, reviewer
 	entityRepo := s.entityRepo.WithTx(tx)
 	geometryRepo := s.geometryRepo.WithTx(tx)
 	wikiRepo := s.wikiRepo.WithTx(tx)
-	ragRepo := s.ragRepo.WithTx(tx)
 
 	submissionUUID, err := convert.StringToUUID(submissionID)
 	if err != nil {
@@ -625,33 +625,31 @@ func (s *submissionService) UpdateSubmissionStatus(ctx context.Context, reviewer
 			}
 		}
 
-		_ = ragRepo.DeleteBySourceIDs(ctx, "wiki", wikiDeleteIDs)
-		_ = ragRepo.DeleteBySourceIDs(ctx, "entity", entityDeleteIDs)
-
-		for _, wiki := range snapshotData.Wikis {
-			if wiki.Source == "inline" {
-				cleanText := s.ragUtils.StripHTML(wiki.Title + "\n" + wiki.Doc)
-				chunks, vectors, err := s.ragUtils.PrepareChunks(ctx, cleanText)
-				if err == nil {
-					_ = ragRepo.DeleteBySourceIDs(ctx, "wiki", []string{wiki.ID})
-					for i, chunk := range chunks {
-						_ = ragRepo.SaveChunk(ctx, "wiki", wiki.ID, commit.ProjectID, i, chunk, vectors[i])
-					}
-				}
-			}
+		ragTask := models.RagIndexTask{
+			ProjectID:       commit.ProjectID,
+			DeleteWikiIDs:   wikiDeleteIDs,
+			DeleteEntityIDs: entityDeleteIDs,
 		}
+		for _, wiki := range snapshotData.Wikis {
+			ragTask.Wikis = append(ragTask.Wikis, &models.RagWikiItem{
+				ID:     wiki.ID,
+				Title:  wiki.Title,
+				Doc:    wiki.Doc,
+				Source: wiki.Source,
+			})
+		}
 		for _, entity := range snapshotData.Entities {
-			if entity.Source == "inline" {
-				cleanText := s.ragUtils.StripHTML(entity.Name + "\n" + entity.Description)
-				chunks, vectors, err := s.ragUtils.PrepareChunks(ctx, cleanText)
-				if err == nil {
-					_ = ragRepo.DeleteBySourceIDs(ctx, "entity", []string{entity.ID})
-					for i, chunk := range chunks {
-						_ = ragRepo.SaveChunk(ctx, "entity", entity.ID, commit.ProjectID, i, chunk, vectors[i])
-					}
-				}
-			}
+			ragTask.Entities = append(ragTask.Entities, &models.RagEntityItem{
+				ID:          entity.ID,
+				Name:        entity.Name,
+				Description: entity.Description,
+				Source:      entity.Source,
+			})
+		}
+
+		if err := s.c.PublishTask(ctx, constants.StreamRagName, constants.TaskTypeRagIndexSubmission, ragTask); err != nil {
+			log.Error().Err(err).Str("project_id", commit.ProjectID).Msg("Failed to publish RAG index task")
+		}
 	}
 
@@ -677,6 +675,7 @@ func (s *submissionService) UpdateSubmissionStatus(ctx context.Context, reviewer
 		bgCtx := context.Background()
 		_ = s.c.DelByPattern(bgCtx, "entity:search*")
 		_ = s.c.DelByPattern(bgCtx, "geometry:search*")
+		_ = s.c.DelByPattern(bgCtx, "geometry:search:entity*")
 		_ = s.c.DelByPattern(bgCtx, "wiki:search*")
 	}()
 }
diff --git a/pkg/constants/stream.go b/pkg/constants/stream.go
index 6c365dd..efc3c60 100644
--- a/pkg/constants/stream.go
+++ b/pkg/constants/stream.go
@@ -3,6 +3,8 @@ package constants
 const (
 	StreamEmailName   = "stream:email_tasks"
 	StreamStorageName = "stream:storage_tasks"
+	StreamRagName     = "stream:rag_tasks"
 	GroupEmailName    = "email_workers_group"
 	GroupStorageName  = "storage_workers_group"
+	GroupRagName      = "rag_workers_group"
 )
diff --git a/pkg/constants/task.go b/pkg/constants/task.go
index a4316de..ecee177 100644
--- a/pkg/constants/task.go
+++ b/pkg/constants/task.go
@@ -7,6 +7,7 @@ const (
 	TaskTypeNotifyHistorianReview TaskType = "NOTIFY_HISTORIAN_REVIEW"
 	TaskTypeDeleteMedia           TaskType = "DELETE_MEDIA"
 	TaskTypeBulkDeleteMedia       TaskType = "BULK_DELETE_MEDIA"
+	TaskTypeRagIndexSubmission    TaskType = "RAG_INDEX_SUBMISSION"
 )
 
 func (t TaskType) String() {
diff --git a/pkg/constants/usage.go b/pkg/constants/usage.go
new file mode 100644
index 0000000..354ab4a
--- /dev/null
+++ b/pkg/constants/usage.go
@@ -0,0 +1,10 @@
+package constants
+
+import "time"
+
+const (
+	// MaxDailyAIUsage is the number of chatbot questions a user may ask per calendar day.
+	MaxDailyAIUsage = 10
+	// UsageExpiration is the TTL applied to the per-user, per-day usage counter key.
+	UsageExpiration = 24 * time.Hour
+)