diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7a9526e..982b129 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,8 +9,8 @@ jobs: - name: Checkout uses: actions/checkout@v4 - name: Lint - run: docker-compose run --rm lint + run: docker compose run --rm lint - name: Test - run: docker-compose run --rm test + run: docker compose run --rm test - name: E2E - run: docker-compose run --rm test-build + run: docker compose run --rm test-build diff --git a/.github/workflows/licensed.yml b/.github/workflows/licensed.yml index bf0944f..0c6ccda 100644 --- a/.github/workflows/licensed.yml +++ b/.github/workflows/licensed.yml @@ -35,7 +35,7 @@ jobs: - run: go mod vendor # Ruby is required for licensed - - uses: ruby/setup-ruby@6bd3d993c602f6b675728ebaecb2b569ff86e99b + - uses: ruby/setup-ruby@90be1154f987f4dc0fe0dd0feedac9e473aa4ba8 # v1 with: ruby-version: "3.2" diff --git a/README.md b/README.md index b67a6b8..c407ce7 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,8 @@ When there are machines which have access to both the public internet and the GH A path to a file containing a newline separated list of repositories to be synced. Each entry follows the format of `repo-name`. - `actions-admin-user` _(optional)_ The name of the Actions admin user, which will be used for updating the chosen action. To use the default user, pass `actions-admin`. If not set, the impersonation is disabled. Note that `site_admin` scope is required in the token for the impersonation to work. +- `batch-size` _(optional)_ + Number of refs to push in each batch. Default is 0 (no batching); any other value must be at least 10. Use a value like 100 if pushing fails for large repositories with many branches and tags. **Example Usage:** @@ -114,6 +116,8 @@ When no machine has access to both the public internet and the GHES instance: Limit push to specific repositories in the cache directory.
- `actions-admin-user` _(optional)_ The name of the Actions admin user, which will be used for updating the chosen action. To use the default user, pass `actions-admin`. If not set, the impersonation is disabled. Note that `site_admin` scope is required in the token for the impersonation to work. +- `batch-size` _(optional)_ + Number of refs to push in each batch. Default is 0 (no batching); any other value must be at least 10. Use a value like 100 if pushing fails for large repositories with many branches and tags. **Example Usage:** diff --git a/script/bootstrap b/script/bootstrap index ba23910..347efca 100755 --- a/script/bootstrap +++ b/script/bootstrap @@ -18,12 +18,12 @@ if [ ! -f go.mod ]; then go mod init tools fi -go get golang.org/x/tools/go/packages@master +go get golang.org/x/tools/go/packages@v0.16.0 if [ ! -f "${GOBIN}/mockgen" ]; then echo "mockgen was not found, installing..." - go get github.com/golang/mock/gomock@master - go get github.com/golang/mock/mockgen@master + go get github.com/golang/mock/gomock@v1.6.0 + go get github.com/golang/mock/mockgen@v1.6.0 fi if [ ! -f "${GOBIN}/golangci-lint" ]; then @@ -33,5 +33,5 @@ fi if [ ! -f "${GOBIN}/goimports" ]; then echo "goimports was not found, installing..."
- go get golang.org/x/tools/cmd/goimports@master + go get golang.org/x/tools/cmd/goimports@v0.16.0 fi diff --git a/src/git.go b/src/git.go index eba3f50..5aa4ead 100644 --- a/src/git.go +++ b/src/git.go @@ -5,6 +5,7 @@ import ( "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/config" + "github.com/go-git/go-git/v5/plumbing/storer" ) // A really thin Git wrapper so we can stub it out in our tests @@ -19,6 +20,7 @@ type GitRepository interface { DeleteRemote(string) error CreateRemote(*config.RemoteConfig) (GitRemote, error) FetchContext(context.Context, *git.FetchOptions) error + References() (storer.ReferenceIter, error) } type GitRemote interface { @@ -65,3 +67,7 @@ func (r *gitRepository) CreateRemote(c *config.RemoteConfig) (GitRemote, error) func (r *gitRepository) FetchContext(ctx context.Context, o *git.FetchOptions) error { return r.inner.FetchContext(ctx, o) } + +func (r *gitRepository) References() (storer.ReferenceIter, error) { + return r.inner.References() +} diff --git a/src/git_test.go b/src/git_test.go new file mode 100644 index 0000000..d755cd6 --- /dev/null +++ b/src/git_test.go @@ -0,0 +1,75 @@ +package src + +import ( + "context" + "testing" + + "github.com/go-git/go-git/v5" + "github.com/go-git/go-git/v5/config" + "github.com/go-git/go-git/v5/plumbing/storer" + "github.com/stretchr/testify/assert" +) + +// Tests for GitRepository interface and implementations + +func TestGitRepositoryInterface(t *testing.T) { + // This test verifies that our mock implements the GitRepository interface + var _ GitRepository = &mockGitRepository{} +} + +func TestGitRemoteInterface(t *testing.T) { + // This test verifies that our mock implements the GitRemote interface + var _ GitRemote = &mockGitRemote{} +} + +// Ensure the mockGitRepository implements all methods of GitRepository +func TestMockGitRepository_DeleteRemote(t *testing.T) { + repo := &mockGitRepository{} + err := repo.DeleteRemote("origin") + assert.NoError(t, err) +} + +func 
TestMockGitRepository_CreateRemote(t *testing.T) { + repo := &mockGitRepository{} + remote, err := repo.CreateRemote(&config.RemoteConfig{Name: "test"}) + assert.NoError(t, err) + assert.Nil(t, remote) +} + +func TestMockGitRepository_FetchContext(t *testing.T) { + repo := &mockGitRepository{} + err := repo.FetchContext(context.Background(), &git.FetchOptions{}) + assert.NoError(t, err) +} + +func TestMockGitRepository_References(t *testing.T) { + repo := &mockGitRepository{} + refs, err := repo.References() + assert.NoError(t, err) + assert.NotNil(t, refs) + + // Verify it returns a valid iterator + _, ok := refs.(storer.ReferenceIter) + assert.True(t, ok) +} + +// Ensure the mockGitRemote implements all methods of GitRemote +func TestMockGitRemote_PushContext(t *testing.T) { + remote := &mockGitRemote{} + err := remote.PushContext(context.Background(), &git.PushOptions{}) + assert.NoError(t, err) +} + +func TestMockGitRemote_Config(t *testing.T) { + remote := &mockGitRemote{} + cfg := remote.Config() + assert.NotNil(t, cfg) + assert.Equal(t, "test-remote", cfg.Name) + + // Test with custom config + customRemote := &mockGitRemote{ + remoteConfig: &config.RemoteConfig{Name: "custom-remote"}, + } + cfg = customRemote.Config() + assert.Equal(t, "custom-remote", cfg.Name) +} diff --git a/src/push.go b/src/push.go index dd3ce19..95968b1 100644 --- a/src/push.go +++ b/src/push.go @@ -9,6 +9,7 @@ import ( "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/config" + "github.com/go-git/go-git/v5/plumbing" "github.com/go-git/go-git/v5/plumbing/transport" "github.com/go-git/go-git/v5/plumbing/transport/http" "github.com/google/go-github/v43/github" @@ -22,9 +23,16 @@ const enterpriseAPIPath = "/api/v3" const enterpriseVersionHeaderKey = "X-GitHub-Enterprise-Version" const xOAuthScopesHeader = "X-OAuth-Scopes" +// DefaultBatchSize of 0 means no batching (push all refs at once, original behavior) +const DefaultBatchSize = 0 + +// MinBatchSize is the minimum allowed 
batch size when batching is enabled +const MinBatchSize = 10 + type PushOnlyFlags struct { BaseURL, Token, ActionsAdminUser string DisableGitAuth bool + BatchSize int } type PushFlags struct { @@ -42,6 +50,7 @@ func (f *PushOnlyFlags) Init(cmd *cobra.Command) { cmd.Flags().StringVar(&f.ActionsAdminUser, "actions-admin-user", "", "A user to impersonate for the push requests. To use the default name, pass 'actions-admin'. Note that the site_admin scope in the token is required for the impersonation to work.") cmd.Flags().StringVar(&f.Token, "destination-token", "", "Token to access API on GHES instance") cmd.Flags().BoolVar(&f.DisableGitAuth, "disable-push-git-auth", false, "Disables git authentication whilst pushing") + cmd.Flags().IntVar(&f.BatchSize, "batch-size", DefaultBatchSize, "Number of refs to push in each batch (0 = no batching). Use a value like 100 if pushing fails for large repositories.") } func (f *PushFlags) Validate() Validations { @@ -56,6 +65,9 @@ func (f *PushOnlyFlags) Validate() Validations { if f.Token == "" { validations = append(validations, "--destination-token must be set") } + if f.BatchSize != 0 && f.BatchSize < MinBatchSize { + validations = append(validations, fmt.Sprintf("--batch-size must be 0 (no batching) or at least %d", MinBatchSize)) + } return validations } @@ -282,16 +294,86 @@ func syncWithCachedRepository(ctx context.Context, flags *PushFlags, ghRepo *git Password: flags.Token, } } - err = remote.PushContext(ctx, &git.PushOptions{ - RemoteName: remote.Config().Name, - RefSpecs: []config.RefSpec{ - "+refs/heads/*:refs/heads/*", - "+refs/tags/*:refs/tags/*", - }, - Auth: auth, - }) - if errors.Cause(err) == git.NoErrAlreadyUpToDate { - return nil + + // If batch size is 0 or negative, use original wildcard approach (no batching) + if flags.BatchSize <= 0 { + err = remote.PushContext(ctx, &git.PushOptions{ + RemoteName: remote.Config().Name, + RefSpecs: []config.RefSpec{ + "+refs/heads/*:refs/heads/*", + 
"+refs/tags/*:refs/tags/*", + }, + Auth: auth, + }) + if errors.Cause(err) == git.NoErrAlreadyUpToDate { + return nil + } + return errors.Wrapf(err, "failed to push to repo: %s", ghRepo.GetCloneURL()) } - return errors.Wrapf(err, "failed to push to repo: %s", ghRepo.GetCloneURL()) + + // Batching requested - collect all refs and push in batches + refs, err := collectRefs(gitRepo) + if err != nil { + return errors.Wrap(err, "error collecting refs") + } + + return pushRefsInBatches(ctx, remote, refs, flags.BatchSize, auth, ghRepo.GetCloneURL()) +} + +// collectRefs gathers all branch and tag refs from the repository +func collectRefs(gitRepo GitRepository) ([]plumbing.ReferenceName, error) { + refIter, err := gitRepo.References() + if err != nil { + return nil, err + } + + var refs []plumbing.ReferenceName + err = refIter.ForEach(func(ref *plumbing.Reference) error { + name := ref.Name() + // Only include branches and tags + if name.IsBranch() || name.IsTag() { + refs = append(refs, name) + } + return nil + }) + if err != nil { + return nil, err + } + + return refs, nil +} + +// pushRefsInBatches pushes refs in smaller batches to avoid server-side limits +func pushRefsInBatches(ctx context.Context, remote GitRemote, refs []plumbing.ReferenceName, batchSize int, auth transport.AuthMethod, cloneURL string) error { + totalRefs := len(refs) + + for i := 0; i < totalRefs; i += batchSize { + end := i + batchSize + if end > totalRefs { + end = totalRefs + } + + batch := refs[i:end] + refSpecs := make([]config.RefSpec, len(batch)) + for j, ref := range batch { + // Create a refspec like "+refs/heads/main:refs/heads/main" + refSpecs[j] = config.RefSpec("+" + ref.String() + ":" + ref.String()) + } + + err := remote.PushContext(ctx, &git.PushOptions{ + RemoteName: remote.Config().Name, + RefSpecs: refSpecs, + Auth: auth, + }) + + if err != nil { + if errors.Cause(err) == git.NoErrAlreadyUpToDate { + // This batch was already up to date, continue to next batch + continue + } + 
return errors.Wrapf(err, "failed to push batch %d-%d of %d refs to repo: %s", i+1, end, totalRefs, cloneURL) + } + } + + return nil } diff --git a/src/push_test.go b/src/push_test.go new file mode 100644 index 0000000..05a8100 --- /dev/null +++ b/src/push_test.go @@ -0,0 +1,400 @@ +package src + +import ( + "context" + "fmt" + "testing" + + "github.com/go-git/go-git/v5" + "github.com/go-git/go-git/v5/config" + "github.com/go-git/go-git/v5/plumbing" + "github.com/go-git/go-git/v5/plumbing/storer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Mock implementations for testing + +type mockReferenceIter struct { + refs []*plumbing.Reference + index int +} + +func (m *mockReferenceIter) Next() (*plumbing.Reference, error) { + if m.index >= len(m.refs) { + return nil, storer.ErrStop + } + ref := m.refs[m.index] + m.index++ + return ref, nil +} + +func (m *mockReferenceIter) ForEach(fn func(*plumbing.Reference) error) error { + for _, ref := range m.refs { + if err := fn(ref); err != nil { + if err == storer.ErrStop { + return nil + } + return err + } + } + return nil +} + +func (m *mockReferenceIter) Close() {} + +type mockGitRepository struct { + refs []*plumbing.Reference + err error +} + +func (m *mockGitRepository) DeleteRemote(name string) error { + return nil +} + +func (m *mockGitRepository) CreateRemote(c *config.RemoteConfig) (GitRemote, error) { + return nil, nil +} + +func (m *mockGitRepository) FetchContext(ctx context.Context, o *git.FetchOptions) error { + return nil +} + +func (m *mockGitRepository) References() (storer.ReferenceIter, error) { + if m.err != nil { + return nil, m.err + } + return &mockReferenceIter{refs: m.refs, index: 0}, nil +} + +type mockGitRemote struct { + pushCalls [][]config.RefSpec + pushError error + alreadyUpToDate bool + remoteConfig *config.RemoteConfig +} + +func (m *mockGitRemote) PushContext(ctx context.Context, o *git.PushOptions) error { + m.pushCalls = append(m.pushCalls, 
o.RefSpecs) + if m.alreadyUpToDate { + return git.NoErrAlreadyUpToDate + } + return m.pushError +} + +func (m *mockGitRemote) Config() *config.RemoteConfig { + if m.remoteConfig != nil { + return m.remoteConfig + } + return &config.RemoteConfig{Name: "test-remote"} +} + +// Tests for PushOnlyFlags.Validate batch size validation + +func TestPushOnlyFlags_Validate_BatchSize(t *testing.T) { + tests := []struct { + name string + batchSize int + expectErr bool + errMessage string + }{ + { + name: "batch size 0 (no batching) is valid", + batchSize: 0, + expectErr: false, + }, + { + name: "batch size at minimum (10) is valid", + batchSize: MinBatchSize, + expectErr: false, + }, + { + name: "batch size above minimum is valid", + batchSize: 100, + expectErr: false, + }, + { + name: "batch size below minimum is invalid", + batchSize: 5, + expectErr: true, + errMessage: fmt.Sprintf("--batch-size must be 0 (no batching) or at least %d", MinBatchSize), + }, + { + name: "batch size of 1 is invalid", + batchSize: 1, + expectErr: true, + errMessage: fmt.Sprintf("--batch-size must be 0 (no batching) or at least %d", MinBatchSize), + }, + { + name: "batch size of 9 is invalid", + batchSize: 9, + expectErr: true, + errMessage: fmt.Sprintf("--batch-size must be 0 (no batching) or at least %d", MinBatchSize), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + flags := PushOnlyFlags{ + BaseURL: "https://example.com", + Token: "test-token", + BatchSize: tt.batchSize, + } + + validations := flags.Validate() + + if tt.expectErr { + require.NotEmpty(t, validations, "expected validation error") + found := false + for _, v := range validations { + if v == tt.errMessage { + found = true + break + } + } + assert.True(t, found, "expected error message not found: %s", tt.errMessage) + } else { + // Check that batch size validation didn't add an error + for _, v := range validations { + assert.NotContains(t, v, "batch-size", "unexpected batch-size validation error") + 
} + } + }) + } +} + +// Tests for collectRefs function + +func TestCollectRefs(t *testing.T) { + tests := []struct { + name string + refs []*plumbing.Reference + expectedLen int + expectedRefs []plumbing.ReferenceName + expectErr bool + }{ + { + name: "empty repository", + refs: []*plumbing.Reference{}, + expectedLen: 0, + }, + { + name: "branches only", + refs: []*plumbing.Reference{ + plumbing.NewHashReference(plumbing.NewBranchReferenceName("main"), plumbing.NewHash("abc123")), + plumbing.NewHashReference(plumbing.NewBranchReferenceName("feature"), plumbing.NewHash("def456")), + }, + expectedLen: 2, + expectedRefs: []plumbing.ReferenceName{ + plumbing.NewBranchReferenceName("main"), + plumbing.NewBranchReferenceName("feature"), + }, + }, + { + name: "tags only", + refs: []*plumbing.Reference{ + plumbing.NewHashReference(plumbing.NewTagReferenceName("v1.0.0"), plumbing.NewHash("abc123")), + plumbing.NewHashReference(plumbing.NewTagReferenceName("v2.0.0"), plumbing.NewHash("def456")), + }, + expectedLen: 2, + expectedRefs: []plumbing.ReferenceName{ + plumbing.NewTagReferenceName("v1.0.0"), + plumbing.NewTagReferenceName("v2.0.0"), + }, + }, + { + name: "mixed branches and tags", + refs: []*plumbing.Reference{ + plumbing.NewHashReference(plumbing.NewBranchReferenceName("main"), plumbing.NewHash("abc123")), + plumbing.NewHashReference(plumbing.NewTagReferenceName("v1.0.0"), plumbing.NewHash("def456")), + plumbing.NewHashReference(plumbing.NewBranchReferenceName("develop"), plumbing.NewHash("ghi789")), + }, + expectedLen: 3, + }, + { + name: "filters out HEAD and other refs", + refs: []*plumbing.Reference{ + plumbing.NewHashReference(plumbing.HEAD, plumbing.NewHash("abc123")), + plumbing.NewHashReference(plumbing.NewBranchReferenceName("main"), plumbing.NewHash("def456")), + plumbing.NewHashReference(plumbing.NewRemoteReferenceName("origin", "main"), plumbing.NewHash("ghi789")), + plumbing.NewHashReference(plumbing.NewTagReferenceName("v1.0.0"), 
plumbing.NewHash("jkl012")), + }, + expectedLen: 2, // Only main branch and v1.0.0 tag + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &mockGitRepository{refs: tt.refs} + + refs, err := collectRefs(repo) + + if tt.expectErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + assert.Len(t, refs, tt.expectedLen) + + if tt.expectedRefs != nil { + for i, expected := range tt.expectedRefs { + assert.Equal(t, expected, refs[i]) + } + } + }) + } +} + +func TestCollectRefs_Error(t *testing.T) { + repo := &mockGitRepository{err: fmt.Errorf("failed to get references")} + + refs, err := collectRefs(repo) + + require.Error(t, err) + assert.Nil(t, refs) + assert.Contains(t, err.Error(), "failed to get references") +} + +// Tests for pushRefsInBatches function + +func TestPushRefsInBatches(t *testing.T) { + tests := []struct { + name string + refs []plumbing.ReferenceName + batchSize int + expectedBatches int + alreadyUpToDate bool + pushError error + expectErr bool + expectedErrSubstr string + }{ + { + name: "single batch - fewer refs than batch size", + refs: []plumbing.ReferenceName{ + plumbing.NewBranchReferenceName("main"), + plumbing.NewBranchReferenceName("feature"), + }, + batchSize: 10, + expectedBatches: 1, + }, + { + name: "single batch - exact batch size", + refs: createNRefs(10), + batchSize: 10, + expectedBatches: 1, + }, + { + name: "multiple batches - exactly divisible", + refs: createNRefs(30), + batchSize: 10, + expectedBatches: 3, + }, + { + name: "multiple batches - not exactly divisible", + refs: createNRefs(25), + batchSize: 10, + expectedBatches: 3, // 10 + 10 + 5 + }, + { + name: "empty refs", + refs: []plumbing.ReferenceName{}, + batchSize: 10, + expectedBatches: 0, + }, + { + name: "all batches already up to date", + refs: []plumbing.ReferenceName{ + plumbing.NewBranchReferenceName("main"), + }, + batchSize: 10, + expectedBatches: 1, + alreadyUpToDate: true, + }, + { + name: "push error", + refs: 
[]plumbing.ReferenceName{ + plumbing.NewBranchReferenceName("main"), + }, + batchSize: 10, + pushError: fmt.Errorf("network error"), + expectErr: true, + expectedErrSubstr: "failed to push batch", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + remote := &mockGitRemote{ + alreadyUpToDate: tt.alreadyUpToDate, + pushError: tt.pushError, + } + + err := pushRefsInBatches(context.Background(), remote, tt.refs, tt.batchSize, nil, "https://example.com/repo.git") + + if tt.expectErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedErrSubstr) + return + } + + require.NoError(t, err) + assert.Len(t, remote.pushCalls, tt.expectedBatches) + }) + } +} + +func TestPushRefsInBatches_RefSpecFormat(t *testing.T) { + refs := []plumbing.ReferenceName{ + plumbing.NewBranchReferenceName("main"), + plumbing.NewTagReferenceName("v1.0.0"), + } + + remote := &mockGitRemote{} + + err := pushRefsInBatches(context.Background(), remote, refs, 10, nil, "https://example.com/repo.git") + + require.NoError(t, err) + require.Len(t, remote.pushCalls, 1) + require.Len(t, remote.pushCalls[0], 2) + + // Check refspec format: should be "+refs/heads/main:refs/heads/main" + assert.Equal(t, config.RefSpec("+refs/heads/main:refs/heads/main"), remote.pushCalls[0][0]) + assert.Equal(t, config.RefSpec("+refs/tags/v1.0.0:refs/tags/v1.0.0"), remote.pushCalls[0][1]) +} + +func TestPushRefsInBatches_BatchSizes(t *testing.T) { + // Create 25 refs + refs := createNRefs(25) + batchSize := 10 + + remote := &mockGitRemote{} + + err := pushRefsInBatches(context.Background(), remote, refs, batchSize, nil, "https://example.com/repo.git") + + require.NoError(t, err) + require.Len(t, remote.pushCalls, 3) + + // First batch should have 10 refs + assert.Len(t, remote.pushCalls[0], 10) + // Second batch should have 10 refs + assert.Len(t, remote.pushCalls[1], 10) + // Third batch should have 5 refs (remainder) + assert.Len(t, remote.pushCalls[2], 5) +} + +// Tests for 
constants + +func TestConstants(t *testing.T) { + assert.Equal(t, 0, DefaultBatchSize, "DefaultBatchSize should be 0 for backward compatibility") + assert.Equal(t, 10, MinBatchSize, "MinBatchSize should be 10") +} + +// Helper function to create N test refs +func createNRefs(n int) []plumbing.ReferenceName { + refs := make([]plumbing.ReferenceName, n) + for i := 0; i < n; i++ { + refs[i] = plumbing.NewBranchReferenceName(fmt.Sprintf("branch-%d", i)) + } + return refs +}