diff --git a/README.md b/README.md index 41ee12b..e804fb1 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ When there are machines which have access to both the public internet and the GH - `destination-token` _(required)_ A personal access token to authenticate against the GHES instance when uploading repositories. - `repo-name` _(optional)_ - A single repository to be synced. In the format of `owner/repo`. Optionally if you wish the repository to be named different on your GHES instance you can provide an aliase in the format: `upstream_owner/up_streamrepo:destination_owner/destination_repo` + A single repository to be synced. In the format of `owner/repo`. Optionally if you wish the repository to be named different on your GHES instance you can provide an alias in the format: `upstream_owner/upstream_repo:destination_owner/destination_repo` - `repo-name-list` _(optional)_ A comma-separated list of repositories to be synced. Each entry follows the format of `repo-name`. - `repo-name-list-file` _(optional)_ diff --git a/cmd/root.go b/cmd/root.go index 7f34f91..5189938 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -11,8 +11,7 @@ import ( ) var ( - cacheDir string - rootCmd = &cobra.Command{ + rootCmd = &cobra.Command{ Use: "actions-sync", Short: "GHES Actions Sync", Long: "Sync Actions from github.com to a GHES instance.", @@ -22,7 +21,7 @@ var ( Use: "version", Short: "The version of actions-sync in use.", Run: func(cmd *cobra.Command, args []string) { - fmt.Fprintln(os.Stdout, "GHES Actions Sync v0.1") + fmt.Fprintln(os.Stdout, "GHES Actions Sync v0.2") }, } @@ -37,7 +36,7 @@ var ( os.Exit(1) return } - if err := src.Push(cmd.Context(), cacheDir, pushRepoFlags); err != nil { + if err := src.Push(cmd.Context(), pushRepoFlags); err != nil { fmt.Fprintln(os.Stderr, err.Error()) os.Exit(1) return @@ -56,7 +55,7 @@ var ( os.Exit(1) return } - if err := src.Pull(cmd.Context(), cacheDir, pullRepoFlags); err != nil { + if err := src.Pull(cmd.Context(), pullRepoFlags); err != nil { fmt.Fprintln(os.Stderr, err.Error()) os.Exit(1) return @@ -75,7 +74,7 @@ var ( os.Exit(1) return } - if err := src.Sync(cmd.Context(), cacheDir, syncRepoFlags); err != nil { + if err := src.Sync(cmd.Context(), syncRepoFlags); err != nil { fmt.Fprintln(os.Stderr, err.Error()) os.Exit(1) return @@ -85,9 +84,6 @@ var ( ) func Execute(ctx context.Context) error { - rootCmd.PersistentFlags().StringVar(&cacheDir, "cache-dir", "", "Directory containing the repopositories cache created by the `pull` command") - _ = rootCmd.MarkPersistentFlagRequired("cache-dir") - rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(pushRepoCmd) diff --git a/src/commonflags.go b/src/commonflags.go new file mode 100644 index 0000000..7347418 --- /dev/null +++ b/src/commonflags.go @@ -0,0 +1,31 @@ +package src + +import ( + "github.com/spf13/cobra" +) + +// flags common to pull, push and sync operations +type CommonFlags struct { + CacheDir, RepoName, RepoNameList, RepoNameListFile string +} + +func (f *CommonFlags) Init(cmd *cobra.Command) { + cmd.Flags().StringVar(&f.CacheDir, "cache-dir", "", "Directory containing the repopositories cache created by the `pull` command") + _ = cmd.MarkFlagRequired("cache-dir") + + cmd.Flags().StringVar(&f.RepoName, "repo-name", "", "Single repository name to pull") + cmd.Flags().StringVar(&f.RepoNameList, "repo-name-list", "", "Comma delimited list of repository names to pull") + cmd.Flags().StringVar(&f.RepoNameListFile, "repo-name-list-file", "", "Path to file containing a list of repository names to pull") +} + +func (f *CommonFlags) Validate(reposRequired bool) Validations { + var validations Validations + if reposRequired && !f.HasAtLeastOneRepoFlag() { + validations = append(validations, "one of --repo-name, --repo-name-list, --repo-name-list-file must be set") + } + return validations +} + +func (f *CommonFlags) HasAtLeastOneRepoFlag() bool { + return f.RepoName != "" || f.RepoNameList != "" || f.RepoNameListFile != "" +} diff --git a/src/pull.go b/src/pull.go index feeb74d..84ae381 100644 --- a/src/pull.go +++ b/src/pull.go @@ -3,62 +3,49 @@ package src import ( "context" "fmt" - "io/ioutil" "os" "path" - "regexp" "strings" "github.com/go-git/go-git/v5" "github.com/go-git/go-git/v5/config" - "github.com/pkg/errors" "github.com/spf13/cobra" ) -var ( - RepoNameRegExp = regexp.MustCompile(`^[^/]+/\S+$`) - ErrEmptyRepoList = errors.New("repo list cannot be empty") -) +type PullOnlyFlags struct { + SourceURL string +} type PullFlags struct { - SourceURL, RepoName, RepoNameList, RepoNameListFile string + CommonFlags + PullOnlyFlags } func (f *PullFlags) Init(cmd *cobra.Command) { + f.CommonFlags.Init(cmd) + f.PullOnlyFlags.Init(cmd) +} + +func (f *PullOnlyFlags) Init(cmd *cobra.Command) { cmd.Flags().StringVar(&f.SourceURL, "source-url", "https://github.com", "The domain to pull from") - cmd.Flags().StringVar(&f.RepoName, "repo-name", "", "Single repository name to pull") - cmd.Flags().StringVar(&f.RepoNameList, "repo-name-list", "", "Comma delimited list of repository names to pull") - cmd.Flags().StringVar(&f.RepoNameListFile, "repo-name-list-file", "", "Path to file containing a list of repository names to pull") } func (f *PullFlags) Validate() Validations { + return f.CommonFlags.Validate(true).Join(f.PullOnlyFlags.Validate()) +} + +func (f *PullOnlyFlags) Validate() Validations { var validations Validations - if !f.HasAtLeastOneRepoFlag() { - validations = append(validations, "one of --repo-name, --repo-name-list, --repo-name-list-file must be set") - } return validations } -func (f *PullFlags) HasAtLeastOneRepoFlag() bool { - return f.RepoName != "" || f.RepoNameList != "" || f.RepoNameListFile != "" -} +func Pull(ctx context.Context, flags *PullFlags) error { + repoNames, err := getRepoNamesFromRepoFlags(&flags.CommonFlags) + if err != nil { + return err + } -func Pull(ctx context.Context, cacheDir string, flags *PullFlags) error { - if flags.RepoNameList != "" { - repoNames, err := getRepoNamesFromCSVString(flags.RepoNameList) - if err != nil { - return err - } - return PullManyWithGitImpl(ctx, flags.SourceURL, cacheDir, repoNames, gitImplementation{}) - } - if flags.RepoNameListFile != "" { - repoNames, err := getRepoNamesFromFile(flags.RepoNameListFile) - if err != nil { - return err - } - return PullManyWithGitImpl(ctx, flags.SourceURL, cacheDir, repoNames, gitImplementation{}) - } - return PullWithGitImpl(ctx, flags.SourceURL, cacheDir, flags.RepoName, gitImplementation{}) + return PullManyWithGitImpl(ctx, flags.SourceURL, flags.CacheDir, repoNames, gitImplementation{}) } func PullManyWithGitImpl(ctx context.Context, sourceURL, cacheDir string, repoNames []string, gitimpl GitImplementation) error { @@ -71,18 +58,11 @@ func PullManyWithGitImpl(ctx context.Context, sourceURL, cacheDir string, repoNa } func PullWithGitImpl(ctx context.Context, sourceURL, cacheDir string, repoName string, gitimpl GitImplementation) error { - repoNameParts := strings.SplitN(repoName, ":", 2) - originRepoName, err := validateRepoName(repoNameParts[0]) + originRepoName, destRepoName, err := extractSourceDest(repoName) if err != nil { return err } - destRepoName := originRepoName - if len(repoNameParts) > 1 { - destRepoName, err = validateRepoName(repoNameParts[1]) - if err != nil { - return err - } - } + _, err = os.Stat(cacheDir) if err != nil { return err @@ -122,41 +102,3 @@ func PullWithGitImpl(ctx context.Context, sourceURL, cacheDir string, repoName s return nil } - -func getRepoNamesFromCSVString(csv string) ([]string, error) { - repos := filterEmptyEntries(strings.Split(csv, ",")) - if len(repos) == 0 { - return nil, ErrEmptyRepoList - } - return repos, nil -} - -func getRepoNamesFromFile(file string) ([]string, error) { - data, err := ioutil.ReadFile(file) - if err != nil { - return nil, err - } - repos := filterEmptyEntries(strings.Split(string(data), "\n")) - if len(repos) == 0 { - return nil, ErrEmptyRepoList - } - return repos, nil -} - -func filterEmptyEntries(names []string) []string { - filtered := []string{} - for _, name := range names { - if name != "" { - filtered = append(filtered, name) - } - } - return filtered -} - -func validateRepoName(name string) (string, error) { - s := strings.TrimSpace(name) - if RepoNameRegExp.MatchString(s) { - return s, nil - } - return "", fmt.Errorf("`%s` is not a valid repo name", s) -} diff --git a/src/push.go b/src/push.go index 867471c..e42eb4a 100644 --- a/src/push.go +++ b/src/push.go @@ -3,7 +3,7 @@ package src import ( "context" "fmt" - "io/ioutil" + "os" "path" "github.com/go-git/go-git/v5" @@ -16,18 +16,32 @@ import ( "golang.org/x/oauth2" ) -type PushFlags struct { +type PushOnlyFlags struct { BaseURL, Token string DisableGitAuth bool } +type PushFlags struct { + CommonFlags + PushOnlyFlags +} + func (f *PushFlags) Init(cmd *cobra.Command) { + f.CommonFlags.Init(cmd) + f.PushOnlyFlags.Init(cmd) +} + +func (f *PushOnlyFlags) Init(cmd *cobra.Command) { cmd.Flags().StringVar(&f.BaseURL, "destination-url", "", "URL of GHES instance") cmd.Flags().StringVar(&f.Token, "destination-token", "", "Token to access API on GHES instance") cmd.Flags().BoolVar(&f.DisableGitAuth, "disable-push-git-auth", false, "Disables git authentication whilst pushing") } func (f *PushFlags) Validate() Validations { + return f.CommonFlags.Validate(false).Join(f.PushOnlyFlags.Validate()) +} + +func (f *PushOnlyFlags) Validate() Validations { var validations Validations if f.BaseURL == "" { validations = append(validations, "--destination-url must be set") @@ -38,11 +52,7 @@ func (f *PushFlags) Validate() Validations { return validations } -func Push(ctx context.Context, cacheDir string, flags *PushFlags) error { - return PushWithGitImpl(ctx, cacheDir, flags, gitImplementation{}) -} - -func PushWithGitImpl(ctx context.Context, cacheDir string, flags *PushFlags, gitimpl GitImplementation) error { +func Push(ctx context.Context, flags *PushFlags) error { ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: flags.Token}) tc := oauth2.NewClient(ctx, ts) ghClient, err := github.NewEnterpriseClient(flags.BaseURL, flags.BaseURL, tc) @@ -50,37 +60,57 @@ func PushWithGitImpl(ctx context.Context, cacheDir string, flags *PushFlags, git return errors.Wrap(err, "error creating enterprise client") } - orgDirs, err := ioutil.ReadDir(cacheDir) + repoNames, err := getRepoNamesFromRepoFlags(&flags.CommonFlags) if err != nil { - return errors.Wrapf(err, "error opening cache directory `%s`", cacheDir) + return err } - for _, orgDir := range orgDirs { - orgDirPath := path.Join(cacheDir, orgDir.Name()) - if !orgDir.IsDir() { - return errors.Errorf("unexpected file in root of cache directory `%s`", orgDirPath) - } - repoDirs, err := ioutil.ReadDir(orgDirPath) + + if repoNames == nil { + repoNames, err = getRepoNamesFromCacheDir(&flags.CommonFlags) if err != nil { - return errors.Wrapf(err, "error opening repository cache directory `%s`", orgDirPath) - } - for _, repoDir := range repoDirs { - repoDirPath := path.Join(orgDirPath, repoDir.Name()) - nwo := fmt.Sprintf("%s/%s", orgDir.Name(), repoDir.Name()) - if !orgDir.IsDir() { - return errors.Errorf("unexpected file in cache directory `%s`", nwo) - } - fmt.Printf("syncing `%s`\n", nwo) - ghRepo, err := getOrCreateGitHubRepo(ctx, ghClient, repoDir.Name(), orgDir.Name()) - if err != nil { - return errors.Wrapf(err, "error creating github repository `%s`", nwo) - } - err = syncWithCachedRepository(ctx, cacheDir, flags, ghRepo, repoDirPath, gitimpl) - if err != nil { - return errors.Wrapf(err, "error syncing repository `%s`", nwo) - } - fmt.Printf("successfully synced `%s`\n", nwo) + return err } } + + return PushManyWithGitImpl(ctx, flags, repoNames, ghClient, gitImplementation{}) +} + +func PushManyWithGitImpl(ctx context.Context, flags *PushFlags, repoNames []string, ghClient *github.Client, gitimpl GitImplementation) error { + for _, repoName := range repoNames { + if err := PushWithGitImpl(ctx, flags, repoName, ghClient, gitimpl); err != nil { + return err + } + } + return nil +} + +func PushWithGitImpl(ctx context.Context, flags *PushFlags, repoName string, ghClient *github.Client, gitimpl GitImplementation) error { + _, nwo, err := extractSourceDest(repoName) + if err != nil { + return err + } + + ownerName, bareRepoName, err := splitNwo(nwo) + if err != nil { + return err + } + + repoDirPath := path.Join(flags.CacheDir, nwo) + _, err = os.Stat(repoDirPath) + if err != nil { + return err + } + + fmt.Printf("syncing `%s`\n", nwo) + ghRepo, err := getOrCreateGitHubRepo(ctx, ghClient, bareRepoName, ownerName) + if err != nil { + return errors.Wrapf(err, "error creating github repository `%s`", nwo) + } + err = syncWithCachedRepository(ctx, flags, ghRepo, repoDirPath, gitimpl) + if err != nil { + return errors.Wrapf(err, "error syncing repository `%s`", nwo) + } + fmt.Printf("successfully synced `%s`\n", nwo) return nil } @@ -105,10 +135,10 @@ func getOrCreateGitHubRepo(ctx context.Context, client *github.Client, repoName, return ghRepo, nil } -func syncWithCachedRepository(ctx context.Context, cacheDir string, flags *PushFlags, ghRepo *github.Repository, repoDir string, gitimpl GitImplementation) error { +func syncWithCachedRepository(ctx context.Context, flags *PushFlags, ghRepo *github.Repository, repoDir string, gitimpl GitImplementation) error { gitRepo, err := gitimpl.NewGitRepository(repoDir) if err != nil { - return errors.Wrapf(err, "error opening git repository %s", cacheDir) + return errors.Wrapf(err, "error opening git repository %s", flags.CacheDir) } _ = gitRepo.DeleteRemote("ghes") remote, err := gitRepo.CreateRemote(&config.RemoteConfig{ diff --git a/src/reponames.go b/src/reponames.go new file mode 100644 index 0000000..d66e48b --- /dev/null +++ b/src/reponames.go @@ -0,0 +1,131 @@ +package src + +import ( + "fmt" + "io/ioutil" + "path" + "regexp" + "strings" + + "github.com/pkg/errors" +) + +var ( + NwoRegExp = regexp.MustCompile(`^[^/\s]+/[^/\s]+$`) + ErrEmptyRepoList = errors.New("repo list cannot be empty") + ErrEmptyCacheDir = errors.New("cache directory contains no actions to sync") +) + +func getRepoNamesFromRepoFlags(flags *CommonFlags) ([]string, error) { + if flags.RepoNameList != "" { + return getRepoNamesFromCSVString(flags.RepoNameList) + } + + if flags.RepoNameListFile != "" { + return getRepoNamesFromFile(flags.RepoNameListFile) + } + + if flags.RepoName != "" { + return []string{flags.RepoName}, nil + } + + return nil, nil +} + +func getRepoNamesFromCacheDir(flags *CommonFlags) ([]string, error) { + repoNames := make([]string, 0) + + orgDirs, err := ioutil.ReadDir(flags.CacheDir) + if err != nil { + return nil, errors.Wrapf(err, "error opening cache directory `%s`", flags.CacheDir) + } + for _, orgDir := range orgDirs { + orgDirPath := path.Join(flags.CacheDir, orgDir.Name()) + if !orgDir.IsDir() { + return nil, errors.Errorf("unexpected file in root of cache directory `%s`", orgDirPath) + } + repoDirs, err := ioutil.ReadDir(orgDirPath) + if err != nil { + return nil, errors.Wrapf(err, "error opening repository cache directory `%s`", orgDirPath) + } + for _, repoDir := range repoDirs { + nwo := fmt.Sprintf("%s/%s", orgDir.Name(), repoDir.Name()) + repoNames = append(repoNames, nwo) + } + } + + if len(repoNames) == 0 { + return nil, ErrEmptyCacheDir + } + + return repoNames, nil +} + +func getRepoNamesFromCSVString(csv string) ([]string, error) { + repos := filterEmptyEntries(strings.Split(csv, ",")) + if len(repos) == 0 { + return nil, ErrEmptyRepoList + } + return repos, nil +} + +func getRepoNamesFromFile(file string) ([]string, error) { + data, err := ioutil.ReadFile(file) + if err != nil { + return nil, err + } + repos := filterEmptyEntries(strings.Split(string(data), "\n")) + if len(repos) == 0 { + return nil, ErrEmptyRepoList + } + return repos, nil +} + +func filterEmptyEntries(names []string) []string { + filtered := []string{} + for _, name := range names { + if name != "" { + filtered = append(filtered, name) + } + } + return filtered +} + +func extractSourceDest(repoName string) (string, string, error) { + repoNameParts := strings.Split(repoName, ":") + if len(repoNameParts) > 2 { + return "", "", fmt.Errorf("`%s` is not a valid repo name. Use a single colon to separate source and destination arguments. Example: `upstream_owner/upstream_repo:destination_owner/destination_repo`", repoName) + } + + originNwo, err := validateNwo(repoNameParts[0]) + if err != nil { + return "", "", err + } + + destNwo := originNwo + if len(repoNameParts) > 1 { + destNwo, err = validateNwo(repoNameParts[1]) + if err != nil { + return "", "", err + } + } + + return originNwo, destNwo, nil +} + +func validateNwo(nwo string) (string, error) { + s := strings.TrimSpace(nwo) + if NwoRegExp.MatchString(s) { + return s, nil + } + return "", fmt.Errorf("`%s` is not a valid repo name", s) +} + +func splitNwo(nwo string) (string, string, error) { + nwoParts := strings.Split(nwo, "/") + if len(nwoParts) != 2 { + return "", "", fmt.Errorf("`%s` is not a valid repo name", nwo) + } + + return nwoParts[0], nwoParts[1], nil +} diff --git a/src/sync.go b/src/sync.go index fb228a2..52e20ec 100644 --- a/src/sync.go +++ b/src/sync.go @@ -7,24 +7,30 @@ import ( ) type SyncFlags struct { - PullFlags - PushFlags + CommonFlags + PullOnlyFlags + PushOnlyFlags } func (f *SyncFlags) Init(cmd *cobra.Command) { - f.PullFlags.Init(cmd) - f.PushFlags.Init(cmd) + f.CommonFlags.Init(cmd) + f.PullOnlyFlags.Init(cmd) + f.PushOnlyFlags.Init(cmd) } func (f *SyncFlags) Validate() Validations { - return f.PullFlags.Validate().Join(f.PushFlags.Validate()) + return f.CommonFlags.Validate(true).Join(f.PullOnlyFlags.Validate().Join(f.PushOnlyFlags.Validate())) } -func Sync(ctx context.Context, cacheDir string, flags *SyncFlags) error { - if err := Pull(ctx, cacheDir, &flags.PullFlags); err != nil { +func Sync(ctx context.Context, flags *SyncFlags) error { + + pullFlags := &PullFlags{flags.CommonFlags, flags.PullOnlyFlags} + pushFlags := &PushFlags{flags.CommonFlags, flags.PushOnlyFlags} + + if err := Pull(ctx, pullFlags); err != nil { return err } - if err := Push(ctx, cacheDir, &flags.PushFlags); err != nil { + if err := Push(ctx, pushFlags); err != nil { return err } return nil