feat: implement RAG-based chatbot service with daily usage rate limiting and background index worker
All checks were successful
Build and Release / release (push) Successful in 1m27s
All checks were successful
Build and Release / release (push) Successful in 1m27s
This commit is contained in:
@@ -12,6 +12,7 @@ COPY . .
|
|||||||
RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o history-api ./cmd/api
|
RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o history-api ./cmd/api
|
||||||
RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o email-worker ./cmd/worker/email
|
RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o email-worker ./cmd/worker/email
|
||||||
RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o storage-worker ./cmd/worker/storage
|
RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o storage-worker ./cmd/worker/storage
|
||||||
|
RUN GOOS=linux GOARCH=amd64 CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o rag-worker ./cmd/worker/rag
|
||||||
|
|
||||||
FROM alpine:latest
|
FROM alpine:latest
|
||||||
|
|
||||||
@@ -23,9 +24,10 @@ WORKDIR /app
|
|||||||
COPY --from=builder /app/history-api .
|
COPY --from=builder /app/history-api .
|
||||||
COPY --from=builder /app/email-worker .
|
COPY --from=builder /app/email-worker .
|
||||||
COPY --from=builder /app/storage-worker .
|
COPY --from=builder /app/storage-worker .
|
||||||
|
COPY --from=builder /app/rag-worker .
|
||||||
COPY data ./data
|
COPY data ./data
|
||||||
|
|
||||||
RUN chmod +x ./history-api ./email-worker ./storage-worker
|
RUN chmod +x ./history-api ./email-worker ./storage-worker ./rag-worker
|
||||||
|
|
||||||
EXPOSE 3344
|
EXPOSE 3344
|
||||||
|
|
||||||
|
|||||||
@@ -95,6 +95,7 @@ func (s *FiberServer) SetupServer(
|
|||||||
submissionRepo := repositories.NewSubmissionRepository(poolPg, redis)
|
submissionRepo := repositories.NewSubmissionRepository(poolPg, redis)
|
||||||
|
|
||||||
raguRepo := repositories.NewRagRepository(poolPg, redis)
|
raguRepo := repositories.NewRagRepository(poolPg, redis)
|
||||||
|
usageRepo := repositories.NewUsageRepository(redis)
|
||||||
|
|
||||||
// service setup
|
// service setup
|
||||||
authService := services.NewAuthService(userRepo, roleRepo, tokenRepo, redis, poolPg)
|
authService := services.NewAuthService(userRepo, roleRepo, tokenRepo, redis, poolPg)
|
||||||
@@ -114,7 +115,7 @@ func (s *FiberServer) SetupServer(
|
|||||||
userRepo, wikiRepo, geometryRepo, entityRepo,
|
userRepo, wikiRepo, geometryRepo, entityRepo,
|
||||||
raguRepo, raguUtils, poolPg, redis,
|
raguRepo, raguUtils, poolPg, redis,
|
||||||
)
|
)
|
||||||
chatbotService := services.NewChatbotService(raguRepo, raguUtils)
|
chatbotService := services.NewChatbotService(raguRepo, usageRepo, raguUtils)
|
||||||
|
|
||||||
// controller setup
|
// controller setup
|
||||||
authController := controllers.NewAuthController(authService, oauth)
|
authController := controllers.NewAuthController(authService, oauth)
|
||||||
|
|||||||
268
cmd/worker/rag/main.go
Normal file
268
cmd/worker/rag/main.go
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"math"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"history-api/internal/models"
|
||||||
|
"history-api/internal/repositories"
|
||||||
|
"history-api/pkg/ai"
|
||||||
|
"history-api/pkg/cache"
|
||||||
|
"history-api/pkg/config"
|
||||||
|
"history-api/pkg/constants"
|
||||||
|
"history-api/pkg/database"
|
||||||
|
_ "history-api/pkg/log"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// maxRetries is the number of additional embedding attempts made after
	// the first failure (so up to maxRetries+1 calls per item).
	maxRetries = 3
	// baseRetryDelay is the initial backoff before a retry; it doubles on
	// each subsequent attempt (exponential backoff).
	baseRetryDelay = 2 * time.Second
	// itemDelay spaces out consecutive items to stay under the embedding
	// provider's rate limit.
	itemDelay = 500 * time.Millisecond
)
|
||||||
|
|
||||||
|
func processRagTask(ctx context.Context, ragRepo repositories.RagRepository, ragUtils *ai.RagUtils, task *models.RagIndexTask, workerName string) {
|
||||||
|
if len(task.DeleteWikiIDs) > 0 {
|
||||||
|
if err := ragRepo.DeleteBySourceIDs(ctx, "wiki", task.DeleteWikiIDs); err != nil {
|
||||||
|
log.Error().Err(err).Str("worker", workerName).Msg("Failed to delete wiki RAG chunks")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(task.DeleteEntityIDs) > 0 {
|
||||||
|
if err := ragRepo.DeleteBySourceIDs(ctx, "entity", task.DeleteEntityIDs); err != nil {
|
||||||
|
log.Error().Err(err).Str("worker", workerName).Msg("Failed to delete entity RAG chunks")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Index wikis with delay + retry
|
||||||
|
for _, wiki := range task.Wikis {
|
||||||
|
if wiki.Source != "inline" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Str("worker", workerName).Str("wiki_id", wiki.ID).Msg("Indexing wiki")
|
||||||
|
|
||||||
|
cleanText := ragUtils.StripHTML(wiki.Title + "\n" + wiki.Doc)
|
||||||
|
|
||||||
|
var chunks []string
|
||||||
|
var vectors [][]float32
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||||
|
if attempt > 0 {
|
||||||
|
delay := baseRetryDelay * time.Duration(math.Pow(2, float64(attempt-1)))
|
||||||
|
log.Warn().
|
||||||
|
Str("worker", workerName).
|
||||||
|
Str("wiki_id", wiki.ID).
|
||||||
|
Int("attempt", attempt).
|
||||||
|
Dur("delay", delay).
|
||||||
|
Msg("Retrying wiki embedding")
|
||||||
|
time.Sleep(delay)
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks, vectors, err = ragUtils.PrepareChunks(ctx, cleanText)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Error().Err(err).
|
||||||
|
Str("worker", workerName).
|
||||||
|
Str("wiki_id", wiki.ID).
|
||||||
|
Int("attempt", attempt).
|
||||||
|
Msg("Failed to prepare wiki chunks")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Str("worker", workerName).Str("wiki_id", wiki.ID).Msg("Giving up on wiki after max retries")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete existing chunks then save new ones
|
||||||
|
_ = ragRepo.DeleteBySourceIDs(ctx, "wiki", []string{wiki.ID})
|
||||||
|
for i, chunk := range chunks {
|
||||||
|
if saveErr := ragRepo.SaveChunk(ctx, "wiki", wiki.ID, task.ProjectID, i, chunk, vectors[i]); saveErr != nil {
|
||||||
|
log.Error().Err(saveErr).Str("wiki_id", wiki.ID).Int("chunk", i).Msg("Failed to save wiki chunk")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Str("worker", workerName).Str("wiki_id", wiki.ID).Int("chunks", len(chunks)).Msg("Wiki indexed successfully")
|
||||||
|
|
||||||
|
// Delay between items to avoid rate limit
|
||||||
|
time.Sleep(itemDelay)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Index entities with delay + retry
|
||||||
|
for _, entity := range task.Entities {
|
||||||
|
if entity.Source != "inline" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Str("worker", workerName).Str("entity_id", entity.ID).Msg("Indexing entity")
|
||||||
|
|
||||||
|
cleanText := ragUtils.StripHTML(entity.Name + "\n" + entity.Description)
|
||||||
|
|
||||||
|
var chunks []string
|
||||||
|
var vectors [][]float32
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||||
|
if attempt > 0 {
|
||||||
|
delay := baseRetryDelay * time.Duration(math.Pow(2, float64(attempt-1)))
|
||||||
|
log.Warn().
|
||||||
|
Str("worker", workerName).
|
||||||
|
Str("entity_id", entity.ID).
|
||||||
|
Int("attempt", attempt).
|
||||||
|
Dur("delay", delay).
|
||||||
|
Msg("Retrying entity embedding")
|
||||||
|
time.Sleep(delay)
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks, vectors, err = ragUtils.PrepareChunks(ctx, cleanText)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Error().Err(err).
|
||||||
|
Str("worker", workerName).
|
||||||
|
Str("entity_id", entity.ID).
|
||||||
|
Int("attempt", attempt).
|
||||||
|
Msg("Failed to prepare entity chunks")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Str("worker", workerName).Str("entity_id", entity.ID).Msg("Giving up on entity after max retries")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete existing chunks then save new ones
|
||||||
|
_ = ragRepo.DeleteBySourceIDs(ctx, "entity", []string{entity.ID})
|
||||||
|
for i, chunk := range chunks {
|
||||||
|
if saveErr := ragRepo.SaveChunk(ctx, "entity", entity.ID, task.ProjectID, i, chunk, vectors[i]); saveErr != nil {
|
||||||
|
log.Error().Err(saveErr).Str("entity_id", entity.ID).Int("chunk", i).Msg("Failed to save entity chunk")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Str("worker", workerName).Str("entity_id", entity.ID).Int("chunks", len(chunks)).Msg("Entity indexed successfully")
|
||||||
|
time.Sleep(itemDelay)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runSingleWorker(ctx context.Context, rdb *redis.Client, consumerID int, ragRepo repositories.RagRepository, ragUtils *ai.RagUtils) {
|
||||||
|
consumerName := "worker-" + strconv.Itoa(consumerID)
|
||||||
|
|
||||||
|
log.Info().Str("worker", consumerName).Msg("RAG worker started and ready")
|
||||||
|
|
||||||
|
for {
|
||||||
|
entries, err := rdb.XReadGroup(ctx, &redis.XReadGroupArgs{
|
||||||
|
Group: constants.GroupRagName,
|
||||||
|
Consumer: consumerName,
|
||||||
|
Streams: []string{constants.StreamRagName, ">"},
|
||||||
|
Count: 1,
|
||||||
|
Block: 0,
|
||||||
|
}).Result()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Str("worker", consumerName).Msg("Failed to read stream")
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, stream := range entries {
|
||||||
|
for _, message := range stream.Messages {
|
||||||
|
taskType, ok1 := message.Values["task_type"].(string)
|
||||||
|
payloadStr, ok2 := message.Values["payload"].(string)
|
||||||
|
if !ok1 || !ok2 {
|
||||||
|
log.Error().Msg("Invalid message format")
|
||||||
|
rdb.XAck(ctx, constants.StreamRagName, constants.GroupRagName, message.ID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if taskType == constants.TaskTypeRagIndexSubmission.String() {
|
||||||
|
var task models.RagIndexTask
|
||||||
|
if err := json.Unmarshal([]byte(payloadStr), &task); err != nil {
|
||||||
|
log.Error().Err(err).Msg("Failed to unmarshal RAG task payload")
|
||||||
|
rdb.XAck(ctx, constants.StreamRagName, constants.GroupRagName, message.ID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().
|
||||||
|
Str("worker", consumerName).
|
||||||
|
Str("project_id", task.ProjectID).
|
||||||
|
Int("wikis", len(task.Wikis)).
|
||||||
|
Int("entities", len(task.Entities)).
|
||||||
|
Msg("Processing RAG index task")
|
||||||
|
|
||||||
|
processRagTask(ctx, ragRepo, ragUtils, &task, consumerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
rdb.XAck(ctx, constants.StreamRagName, constants.GroupRagName, message.ID)
|
||||||
|
log.Info().Str("msg_id", message.ID).Msg("Task acknowledged")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// main wires up the RAG index worker: it loads configuration, connects to
// Redis and PostgreSQL, ensures the stream consumer group exists, then runs
// RAG_WORKER_COUNT concurrent consumers and waits for them to finish.
func main() {
	config.LoadEnv()

	// Number of concurrent stream consumers; falls back to 1 on a missing
	// or invalid RAG_WORKER_COUNT value.
	workerCountStr := config.GetConfigWithDefault("RAG_WORKER_COUNT", "1")
	workerCount, err := strconv.Atoi(workerCountStr)
	if err != nil || workerCount <= 0 {
		workerCount = 1
	}

	cacheInterface, err := cache.NewRedisClient()
	if err != nil {
		log.Fatal().
			Err(err).
			Msg("Failed to connect to Redis")
	}

	// The raw go-redis client is needed for stream (XREADGROUP/XACK) calls.
	rdb := cacheInterface.GetRawClient()

	poolPg, err := database.NewPostgresqlDB()
	if err != nil {
		log.Fatal().
			Err(err).
			Msg("Failed to connect to PostgreSQL")
	}
	defer poolPg.Close()

	ragUtils, err := ai.NewRagUtils()
	if err != nil {
		log.Fatal().
			Err(err).
			Msg("Failed to initialize RAG utils")
	}

	ragRepo := repositories.NewRagRepository(poolPg, cacheInterface)

	ctx := context.Background()

	// Create the consumer group starting at "$" (new messages only).
	// BUSYGROUP means the group already exists and is not a failure.
	// NOTE(review): matching on the exact error string is fragile — confirm
	// it matches the server's BUSYGROUP reply across Redis versions.
	err = rdb.XGroupCreateMkStream(ctx, constants.StreamRagName, constants.GroupRagName, "$").Err()
	if err != nil && err.Error() != "BUSYGROUP Consumer Group name already exists" {
		log.Fatal().
			Err(err).
			Msg("Failed to create Redis Stream Group")
	}

	log.Info().
		Int("worker_count", workerCount).
		Msg("Starting RAG worker system")

	var wg sync.WaitGroup

	// NOTE(review): sync.WaitGroup.Go requires Go 1.25+ — confirm go.mod.
	// (Per-iteration loop variables require Go 1.22+, satisfied by the same.)
	for i := 1; i <= workerCount; i++ {
		wg.Go(func() {
			runSingleWorker(ctx, rdb, i, ragRepo, ragUtils)
		})
	}

	wg.Wait()
}
|
||||||
@@ -110,6 +110,21 @@ services:
|
|||||||
networks:
|
networks:
|
||||||
- history-api-project
|
- history-api-project
|
||||||
|
|
||||||
|
history_rag_worker:
|
||||||
|
build: .
|
||||||
|
container_name: history_rag_worker
|
||||||
|
restart: unless-stopped
|
||||||
|
depends_on:
|
||||||
|
history_db:
|
||||||
|
condition: service_healthy
|
||||||
|
history_cache:
|
||||||
|
condition: service_started
|
||||||
|
env_file:
|
||||||
|
- ./assets/resources/.env
|
||||||
|
command: ["./rag-worker"]
|
||||||
|
networks:
|
||||||
|
- history-api-project
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
history_db_data:
|
history_db_data:
|
||||||
history_cache_data:
|
history_cache_data:
|
||||||
|
|||||||
@@ -45,8 +45,18 @@ func (cx *ChatbotController) Chat(c fiber.Ctx) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
answer, err := cx.chatbotService.Chat(ctx, dto.ProjectID, dto.Question)
|
claims := c.Locals("user").(*response.JWTClaims)
|
||||||
|
|
||||||
|
answer, err := cx.chatbotService.Chat(ctx, claims.UId, dto.ProjectID, dto.Question)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Return HTTP 429 (Too Many Requests) when the user's daily quota is exhausted
|
||||||
|
if err.Error() == "you have reached your daily limit of 10 questions. Please come back tomorrow" {
|
||||||
|
return c.Status(fiber.StatusTooManyRequests).JSON(response.CommonResponse{
|
||||||
|
Status: false,
|
||||||
|
Message: err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return c.Status(fiber.StatusInternalServerError).JSON(response.CommonResponse{
|
return c.Status(fiber.StatusInternalServerError).JSON(response.CommonResponse{
|
||||||
Status: false,
|
Status: false,
|
||||||
Message: err.Error(),
|
Message: err.Error(),
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ func (s *submissionController) CreateSubmission(c fiber.Ctx) error {
|
|||||||
// @Security BearerAuth
|
// @Security BearerAuth
|
||||||
// @Router /submissions/{id}/status [patch]
|
// @Router /submissions/{id}/status [patch]
|
||||||
func (s *submissionController) UpdateSubmissionStatus(c fiber.Ctx) error {
|
func (s *submissionController) UpdateSubmissionStatus(c fiber.Ctx) error {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
id := c.Params("id")
|
id := c.Params("id")
|
||||||
uid := c.Locals("uid").(string)
|
uid := c.Locals("uid").(string)
|
||||||
|
|||||||
@@ -15,3 +15,25 @@ type RagChunk struct {
|
|||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RagIndexTask is the payload published to the RAG worker stream. It tells
// the worker which stale chunks to delete and which wikis/entities of a
// project to (re)index.
type RagIndexTask struct {
	ProjectID       string           `json:"project_id"`
	DeleteWikiIDs   []string         `json:"delete_wiki_ids"`
	DeleteEntityIDs []string         `json:"delete_entity_ids"`
	Wikis           []*RagWikiItem   `json:"wikis"`
	Entities        []*RagEntityItem `json:"entities"`
}

// RagWikiItem is the subset of a wiki document needed for RAG indexing.
// Source selects the indexing path (only "inline" wikis are indexed).
type RagWikiItem struct {
	ID     string `json:"id"`
	Title  string `json:"title"`
	Doc    string `json:"doc"`
	Source string `json:"source"`
}

// RagEntityItem is the subset of an entity needed for RAG indexing.
// Source selects the indexing path (only "inline" entities are indexed).
type RagEntityItem struct {
	ID          string `json:"id"`
	Name        string `json:"name"`
	Description string `json:"description"`
	Source      string `json:"source"`
}
|
||||||
|
|||||||
55
internal/repositories/usageRepository.go
Normal file
55
internal/repositories/usageRepository.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package repositories
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"history-api/pkg/cache"
|
||||||
|
"history-api/pkg/constants"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UsageRepository tracks per-user, per-day AI usage counters in the cache.
type UsageRepository interface {
	// GetAIUsage returns the number of AI calls the user has made today.
	GetAIUsage(ctx context.Context, userID string) (int, error)
	// IncrementAIUsage increments today's counter for the user and returns
	// the new value.
	IncrementAIUsage(ctx context.Context, userID string) (int, error)
}

// usageRepository is the cache-backed implementation of UsageRepository.
type usageRepository struct {
	c cache.Cache // cache abstraction; also exposes the raw Redis client
}

// NewUsageRepository builds a UsageRepository on top of the given cache.
func NewUsageRepository(c cache.Cache) UsageRepository {
	return &usageRepository{
		c: c,
	}
}
|
||||||
|
|
||||||
|
func (r *usageRepository) getUsageKey(userID string) string {
|
||||||
|
dateStr := time.Now().Format("20060102")
|
||||||
|
return fmt.Sprintf("usage:ai:%s:%s", userID, dateStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAIUsage returns how many AI requests the user has made today.
// A missing counter key yields 0.
func (r *usageRepository) GetAIUsage(ctx context.Context, userID string) (int, error) {
	key := r.getUsageKey(userID)
	var count int
	err := r.c.Get(ctx, key, &count)
	if err != nil {
		// NOTE(review): every Get error — not just a cache miss — is treated
		// as "no usage yet", so a Redis outage silently disables the daily
		// rate limit. Confirm this fail-open behavior is intended, or
		// distinguish a miss from a real failure.
		return 0, nil
	}
	return count, nil
}
|
||||||
|
|
||||||
|
// IncrementAIUsage bumps today's AI usage counter for the user and returns
// the new count. The key is date-scoped (see getUsageKey), so the TTL set
// below is only cleanup of stale keys, not the daily reset boundary.
func (r *usageRepository) IncrementAIUsage(ctx context.Context, userID string) (int, error) {
	key := r.getUsageKey(userID)
	rdb := r.c.GetRawClient()

	count, err := rdb.Incr(ctx, key).Result()
	if err != nil {
		return 0, err
	}

	// count == 1 means INCR just created the key: attach the TTL once.
	// Best-effort: the Expire result is ignored; if it fails the key simply
	// lingers until the next day's key takes over.
	if count == 1 {
		rdb.Expire(ctx, key, constants.UsageExpiration)
	}

	return int(count), nil
}
|
||||||
@@ -2,28 +2,41 @@ package services
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"history-api/internal/repositories"
|
"history-api/internal/repositories"
|
||||||
"history-api/pkg/ai"
|
"history-api/pkg/ai"
|
||||||
|
"history-api/pkg/constants"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatbotService interface {
|
type ChatbotService interface {
|
||||||
Chat(ctx context.Context, projectID *string, question string) (string, error)
|
Chat(ctx context.Context, userID string, projectID *string, question string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type chatbotService struct {
|
type chatbotService struct {
|
||||||
repo repositories.RagRepository
|
repo repositories.RagRepository
|
||||||
ragUtils *ai.RagUtils
|
usageRepo repositories.UsageRepository
|
||||||
|
ragUtils *ai.RagUtils
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChatbotService(repo repositories.RagRepository, ragUtils *ai.RagUtils) ChatbotService {
|
func NewChatbotService(repo repositories.RagRepository, usageRepo repositories.UsageRepository, ragUtils *ai.RagUtils) ChatbotService {
|
||||||
return &chatbotService{
|
return &chatbotService{
|
||||||
repo: repo,
|
repo: repo,
|
||||||
ragUtils: ragUtils,
|
usageRepo: usageRepo,
|
||||||
|
ragUtils: ragUtils,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *chatbotService) Chat(ctx context.Context, projectID *string, question string) (string, error) {
|
func (s *chatbotService) Chat(ctx context.Context, userID string, projectID *string, question string) (string, error) {
|
||||||
|
usage, err := s.usageRepo.GetAIUsage(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to check usage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if usage >= constants.MaxDailyAIUsage {
|
||||||
|
return "", errors.New("you have reached your daily limit of 10 questions. Please come back tomorrow")
|
||||||
|
}
|
||||||
|
|
||||||
qVector, err := s.ragUtils.EmbedQuery(ctx, question)
|
qVector, err := s.ragUtils.EmbedQuery(ctx, question)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to embed question: %w", err)
|
return "", fmt.Errorf("failed to embed question: %w", err)
|
||||||
@@ -61,5 +74,13 @@ Context:
|
|||||||
Question: %s`, contextStr, question)
|
Question: %s`, contextStr, question)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.ragUtils.GenerateResponse(ctx, prompt)
|
response, err := s.ragUtils.GenerateResponse(ctx, prompt)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Increment the usage counter only after the AI call succeeded
|
||||||
|
_, _ = s.usageRepo.IncrementAIUsage(ctx, userID)
|
||||||
|
|
||||||
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,12 +15,13 @@ import (
|
|||||||
"history-api/pkg/cache"
|
"history-api/pkg/cache"
|
||||||
"history-api/pkg/constants"
|
"history-api/pkg/constants"
|
||||||
"history-api/pkg/convert"
|
"history-api/pkg/convert"
|
||||||
"strconv"
|
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v3"
|
"github.com/gofiber/fiber/v3"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -166,7 +167,6 @@ func (s *submissionService) UpdateSubmissionStatus(ctx context.Context, reviewer
|
|||||||
entityRepo := s.entityRepo.WithTx(tx)
|
entityRepo := s.entityRepo.WithTx(tx)
|
||||||
geometryRepo := s.geometryRepo.WithTx(tx)
|
geometryRepo := s.geometryRepo.WithTx(tx)
|
||||||
wikiRepo := s.wikiRepo.WithTx(tx)
|
wikiRepo := s.wikiRepo.WithTx(tx)
|
||||||
ragRepo := s.ragRepo.WithTx(tx)
|
|
||||||
|
|
||||||
submissionUUID, err := convert.StringToUUID(submissionID)
|
submissionUUID, err := convert.StringToUUID(submissionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -625,33 +625,31 @@ func (s *submissionService) UpdateSubmissionStatus(ctx context.Context, reviewer
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = ragRepo.DeleteBySourceIDs(ctx, "wiki", wikiDeleteIDs)
|
ragTask := models.RagIndexTask{
|
||||||
_ = ragRepo.DeleteBySourceIDs(ctx, "entity", entityDeleteIDs)
|
ProjectID: commit.ProjectID,
|
||||||
|
DeleteWikiIDs: wikiDeleteIDs,
|
||||||
for _, wiki := range snapshotData.Wikis {
|
DeleteEntityIDs: entityDeleteIDs,
|
||||||
if wiki.Source == "inline" {
|
|
||||||
cleanText := s.ragUtils.StripHTML(wiki.Title + "\n" + wiki.Doc)
|
|
||||||
chunks, vectors, err := s.ragUtils.PrepareChunks(ctx, cleanText)
|
|
||||||
if err == nil {
|
|
||||||
_ = ragRepo.DeleteBySourceIDs(ctx, "wiki", []string{wiki.ID})
|
|
||||||
for i, chunk := range chunks {
|
|
||||||
_ = ragRepo.SaveChunk(ctx, "wiki", wiki.ID, commit.ProjectID, i, chunk, vectors[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, wiki := range snapshotData.Wikis {
|
||||||
|
ragTask.Wikis = append(ragTask.Wikis, &models.RagWikiItem{
|
||||||
|
ID: wiki.ID,
|
||||||
|
Title: wiki.Title,
|
||||||
|
Doc: wiki.Doc,
|
||||||
|
Source: wiki.Source,
|
||||||
|
})
|
||||||
|
}
|
||||||
for _, entity := range snapshotData.Entities {
|
for _, entity := range snapshotData.Entities {
|
||||||
if entity.Source == "inline" {
|
ragTask.Entities = append(ragTask.Entities, &models.RagEntityItem{
|
||||||
cleanText := s.ragUtils.StripHTML(entity.Name + "\n" + entity.Description)
|
ID: entity.ID,
|
||||||
chunks, vectors, err := s.ragUtils.PrepareChunks(ctx, cleanText)
|
Name: entity.Name,
|
||||||
if err == nil {
|
Description: entity.Description,
|
||||||
_ = ragRepo.DeleteBySourceIDs(ctx, "entity", []string{entity.ID})
|
Source: entity.Source,
|
||||||
for i, chunk := range chunks {
|
})
|
||||||
_ = ragRepo.SaveChunk(ctx, "entity", entity.ID, commit.ProjectID, i, chunk, vectors[i])
|
}
|
||||||
}
|
|
||||||
}
|
if err := s.c.PublishTask(ctx, constants.StreamRagName, constants.TaskTypeRagIndexSubmission, ragTask); err != nil {
|
||||||
}
|
log.Error().Err(err).Str("project_id", commit.ProjectID).Msg("Failed to publish RAG index task")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -677,6 +675,7 @@ func (s *submissionService) UpdateSubmissionStatus(ctx context.Context, reviewer
|
|||||||
bgCtx := context.Background()
|
bgCtx := context.Background()
|
||||||
_ = s.c.DelByPattern(bgCtx, "entity:search*")
|
_ = s.c.DelByPattern(bgCtx, "entity:search*")
|
||||||
_ = s.c.DelByPattern(bgCtx, "geometry:search*")
|
_ = s.c.DelByPattern(bgCtx, "geometry:search*")
|
||||||
|
_ = s.c.DelByPattern(bgCtx, "geometry:search:entity*")
|
||||||
_ = s.c.DelByPattern(bgCtx, "wiki:search*")
|
_ = s.c.DelByPattern(bgCtx, "wiki:search*")
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package constants
|
|||||||
const (
|
const (
|
||||||
StreamEmailName = "stream:email_tasks"
|
StreamEmailName = "stream:email_tasks"
|
||||||
StreamStorageName = "stream:storage_tasks"
|
StreamStorageName = "stream:storage_tasks"
|
||||||
|
StreamRagName = "stream:rag_tasks"
|
||||||
GroupEmailName = "email_workers_group"
|
GroupEmailName = "email_workers_group"
|
||||||
GroupStorageName = "storage_workers_group"
|
GroupStorageName = "storage_workers_group"
|
||||||
|
GroupRagName = "rag_workers_group"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ const (
|
|||||||
TaskTypeNotifyHistorianReview TaskType = "NOTIFY_HISTORIAN_REVIEW"
|
TaskTypeNotifyHistorianReview TaskType = "NOTIFY_HISTORIAN_REVIEW"
|
||||||
TaskTypeDeleteMedia TaskType = "DELETE_MEDIA"
|
TaskTypeDeleteMedia TaskType = "DELETE_MEDIA"
|
||||||
TaskTypeBulkDeleteMedia TaskType = "BULK_DELETE_MEDIA"
|
TaskTypeBulkDeleteMedia TaskType = "BULK_DELETE_MEDIA"
|
||||||
|
TaskTypeRagIndexSubmission TaskType = "RAG_INDEX_SUBMISSION"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (t TaskType) String() string {
|
func (t TaskType) String() string {
|
||||||
|
|||||||
8
pkg/constants/usage.go
Normal file
8
pkg/constants/usage.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package constants

import "time"

const (
	// MaxDailyAIUsage is the maximum number of chatbot questions a user may
	// ask per calendar day.
	MaxDailyAIUsage = 10
	// UsageExpiration is the TTL attached to a usage counter key. Keys are
	// already date-scoped, so this only garbage-collects stale counters; it
	// is not the reset boundary.
	UsageExpiration = 24 * time.Hour
)
|
||||||
Reference in New Issue
Block a user