Skip to content

Commit 6a51ba9

Browse files
QinYuuuuRader
authored andcommitted
fix: enhance concurrency in CheckRepoFiles and improve test coverage
1 parent 0c24eb9 commit 6a51ba9

File tree

16 files changed

+369
-263
lines changed

16 files changed

+369
-263
lines changed

_mocks/opencsg.com/csghub-server/builder/store/database/mock_RepoStore.go

Lines changed: 48 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

builder/store/database/repository.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,19 @@ type RepoStore interface {
9494
BatchUpdate(ctx context.Context, repos []*Repository) error
9595
FindByRepoTypeAndPaths(ctx context.Context, repoType types.RepositoryType, path []string) ([]Repository, error)
9696
FindUnhashedRepos(ctx context.Context, batchSize int, lastID int64) ([]Repository, error)
97+
UpdateRepoSensitiveCheckStatus(ctx context.Context, repoID int64, status types.SensitiveCheckStatus) error
98+
}
99+
100+
func (s *repoStoreImpl) UpdateRepoSensitiveCheckStatus(ctx context.Context, repoID int64, status types.SensitiveCheckStatus) error {
101+
_, err := s.db.Operator.Core.NewUpdate().
102+
Model(&Repository{}).
103+
Set("sensitive_check_status = ?", status).
104+
Where("id = ?", repoID).
105+
Exec(ctx)
106+
if err != nil {
107+
return err
108+
}
109+
return nil
97110
}
98111

99112
func newRepoStoreInstance(db *DB) RepoStore {

builder/store/database/repository_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,3 +1799,25 @@ func TestRepoStore_PublicToUserWithCacheFailed(t *testing.T) {
17991799
require.ElementsMatch(t, c.expected, names)
18001800
require.Equal(t, len(c.expected), count)
18011801
}
1802+
1803+
func TestRepoStore_UpdateRepoSensitiveCheckStatus(t *testing.T) {
1804+
db := tests.InitTestDB()
1805+
defer db.Close()
1806+
ctx := context.TODO()
1807+
store := database.NewRepoStoreWithDB(db)
1808+
1809+
repo, err := store.CreateRepo(ctx, database.Repository{
1810+
Name: "repo1",
1811+
UserID: 123,
1812+
GitPath: "foos_u/bar",
1813+
})
1814+
require.Nil(t, err)
1815+
1816+
err = store.UpdateRepoSensitiveCheckStatus(ctx, repo.ID, types.SensitiveCheckPass)
1817+
require.Nil(t, err)
1818+
1819+
rp := &database.Repository{}
1820+
err = db.Core.NewSelect().Model(rp).Where("id=?", repo.ID).Scan(ctx)
1821+
require.Nil(t, err)
1822+
require.Equal(t, types.SensitiveCheckPass, rp.SensitiveCheckStatus)
1823+
}

moderation/checker/file_checker.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package checker
22

33
import (
4+
"context"
45
"io"
56
"path"
67
"slices"
@@ -16,7 +17,7 @@ var knownTextFileExts = []string{".md", ".txt", ".csv", ".json", ".jsonl", ".htm
1617
".cs", ".js", ".ts", ".py", ".php", ".java", ".c", ".cpp", ".go", ".rb", ".sh"}
1718

1819
type FileChecker interface {
19-
Run(reader io.Reader) (types.SensitiveCheckStatus, string)
20+
Run(ctx context.Context, reader io.Reader) (types.SensitiveCheckStatus, string)
2021
}
2122

2223
// GetFileChecker returns a FileChecker for a given file based on its type and path.
@@ -66,22 +67,22 @@ func NewImageFileChecker() FileChecker {
6667
checker: contentChecker,
6768
}
6869
}
69-
func (c *ImageFileChecker) Run(io.Reader) (types.SensitiveCheckStatus, string) {
70+
func (c *ImageFileChecker) Run(context.Context, io.Reader) (types.SensitiveCheckStatus, string) {
7071
//TODO:check image in the future
7172
return types.SensitiveCheckSkip, "skip image file"
7273
}
7374

7475
type LfsFileChecker struct {
7576
}
7677

77-
func (c *LfsFileChecker) Run(io.Reader) (types.SensitiveCheckStatus, string) {
78+
func (c *LfsFileChecker) Run(context.Context, io.Reader) (types.SensitiveCheckStatus, string) {
7879
// dont need to check lfs file content
7980
return types.SensitiveCheckSkip, "skip lfs file"
8081
}
8182

8283
type FolderChecker struct {
8384
}
8485

85-
func (c *FolderChecker) Run(reader io.Reader) (types.SensitiveCheckStatus, string) {
86+
func (c *FolderChecker) Run(ctx context.Context, reader io.Reader) (types.SensitiveCheckStatus, string) {
8687
return types.SensitiveCheckSkip, "skip folder"
8788
}

moderation/checker/file_checker_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package checker
22

33
import (
4+
"context"
45
"strings"
56
"testing"
67

@@ -64,7 +65,7 @@ func TestImageFileChecker_Run(t *testing.T) {
6465
reader := strings.NewReader("image content")
6566
expectedStatus := types.SensitiveCheckSkip
6667
expectedMessage := "skip image file"
67-
status, message := checker.Run(reader)
68+
status, message := checker.Run(context.Background(), reader)
6869
if status != expectedStatus || message != expectedMessage {
6970
t.Errorf("Expected (%v, %v), got (%v, %v)", expectedStatus, expectedMessage, status, message)
7071
}
@@ -75,7 +76,7 @@ func TestLfsFileChecker_Run(t *testing.T) {
7576
reader := strings.NewReader("lfs content")
7677
expectedStatus := types.SensitiveCheckSkip
7778
expectedMessage := "skip lfs file"
78-
status, message := checker.Run(reader)
79+
status, message := checker.Run(context.Background(), reader)
7980
if status != expectedStatus || message != expectedMessage {
8081
t.Errorf("Expected (%v, %v), got (%v, %v)", expectedStatus, expectedMessage, status, message)
8182
}
@@ -86,7 +87,7 @@ func TestFolderChecker_Run(t *testing.T) {
8687
reader := strings.NewReader("folder content")
8788
expectedStatus := types.SensitiveCheckSkip
8889
expectedMessage := "skip folder"
89-
status, message := checker.Run(reader)
90+
status, message := checker.Run(context.Background(), reader)
9091
if status != expectedStatus || message != expectedMessage {
9192
t.Errorf("Expected (%v, %v), got (%v, %v)", expectedStatus, expectedMessage, status, message)
9293
}

moderation/checker/text_file_checker.go

Lines changed: 69 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"log/slog"
88
"time"
99

10+
"github.com/avast/retry-go/v4"
1011
"opencsg.com/csghub-server/builder/sensitive"
1112
"opencsg.com/csghub-server/common/types"
1213
)
@@ -21,49 +22,77 @@ func NewTextFileChecker() *TextFileChecker {
2122
}
2223
}
2324

24-
func (c *TextFileChecker) Run(reader io.Reader) (types.SensitiveCheckStatus, string) {
25-
//at most 1MB
26-
reader = io.LimitReader(reader, 1024*1024)
27-
const blockSize = 10 * 9000
28-
// const blockSize = 3000
29-
var bufs []bytes.Buffer
30-
for {
31-
buf := bytes.Buffer{}
32-
var err error
33-
var avaliableSize int64
34-
if avaliableSize, err = io.CopyN(&buf, reader, blockSize); err != nil && err != io.EOF {
35-
return types.SensitiveCheckException, "failed to read file content"
36-
}
37-
if avaliableSize > 0 {
38-
bufs = append(bufs, buf)
39-
}
40-
//no more data to read
41-
if avaliableSize < blockSize {
42-
break
43-
}
25+
func (c *TextFileChecker) Run(ctx context.Context, reader io.Reader) (types.SensitiveCheckStatus, string) {
26+
type result struct {
27+
status types.SensitiveCheckStatus
28+
message string
4429
}
45-
for _, buf := range bufs {
46-
var result *sensitive.CheckResult
47-
var err error
48-
slog.Debug("check text", slog.String("scenario", string(sensitive.ScenarioCommentDetection)), slog.String("text", buf.String()))
49-
//do local check first
50-
txt := buf.String()
51-
contains := GetLocalWordChecker().ContainsSensitiveWord(txt)
52-
if contains {
53-
return types.SensitiveCheckFail, "contains sensitive word"
54-
}
55-
//call remote checker
56-
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
57-
result, err = c.PassTextCheck(ctx, sensitive.ScenarioCommentDetection, txt)
58-
cancel()
59-
if err != nil {
60-
return types.SensitiveCheckException, "call sensitive checker api failed"
30+
31+
resultCh := make(chan result, 1)
32+
33+
go func() {
34+
//at most 1MB
35+
reader = io.LimitReader(reader, 1024*1024)
36+
const blockSize = 10 * 9000
37+
// const blockSize = 3000
38+
var bufs []bytes.Buffer
39+
for {
40+
buf := bytes.Buffer{}
41+
var err error
42+
var avaliableSize int64
43+
if avaliableSize, err = io.CopyN(&buf, reader, blockSize); err != nil && err != io.EOF {
44+
resultCh <- result{types.SensitiveCheckException, "failed to read file content"}
45+
return
46+
}
47+
if avaliableSize > 0 {
48+
bufs = append(bufs, buf)
49+
}
50+
//no more data to read
51+
if avaliableSize < blockSize {
52+
break
53+
}
6154
}
55+
for _, buf := range bufs {
56+
var res *sensitive.CheckResult
57+
var err error
58+
slog.Debug("check text", slog.String("scenario", string(sensitive.ScenarioCommentDetection)), slog.String("text", buf.String()))
59+
//do local check first
60+
txt := buf.String()
61+
contains := GetLocalWordChecker().ContainsSensitiveWord(txt)
62+
if contains {
63+
resultCh <- result{types.SensitiveCheckFail, "contains sensitive word"}
64+
return
65+
}
66+
//call remote checker
67+
res, err = retry.DoWithData(
68+
func() (*sensitive.CheckResult, error) {
69+
reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
70+
res, err = c.PassTextCheck(reqCtx, sensitive.ScenarioCommentDetection, txt)
71+
cancel()
72+
if err != nil {
73+
return nil, err
74+
}
75+
return res, nil
76+
}, retry.Attempts(3), retry.DelayType(retry.BackOffDelay), retry.LastErrorOnly(true))
77+
78+
if err != nil {
79+
resultCh <- result{types.SensitiveCheckException, "call sensitive checker api failed"}
80+
return
81+
}
6282

63-
if result.IsSensitive {
64-
return types.SensitiveCheckFail, result.Reason
83+
if res.IsSensitive {
84+
resultCh <- result{types.SensitiveCheckFail, res.Reason}
85+
return
86+
}
6587
}
66-
}
6788

68-
return types.SensitiveCheckPass, ""
89+
resultCh <- result{types.SensitiveCheckPass, ""}
90+
}()
91+
92+
select {
93+
case <-ctx.Done():
94+
return types.SensitiveCheckException, "context canceled"
95+
case res := <-resultCh:
96+
return res.status, res.message
97+
}
6998
}

0 commit comments

Comments
 (0)