From a4ac09e7da6357302baf720173c3d4348971ba83 Mon Sep 17 00:00:00 2001 From: James Carnegie Date: Thu, 29 Aug 2024 17:43:45 +0100 Subject: [PATCH] refactor! don't use ctx for policy evaluator (#140) * refactor! don't use ctx for policy evaluator --- internal/test/test.go | 15 ++---------- pkg/attest/verify.go | 40 ++++++++++++++----------------- pkg/attest/verify_test.go | 8 ++----- pkg/attestation/referrers_test.go | 1 - pkg/policy/evaluator.go | 19 --------------- pkg/policy/types.go | 1 + pkg/tuf/registry_test.go | 15 ++++++------ 7 files changed, 31 insertions(+), 68 deletions(-) diff --git a/internal/test/test.go b/internal/test/test.go index 6bb22c5..1b61cc2 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -6,16 +6,14 @@ import ( "path/filepath" "testing" - "github.com/docker/attest/pkg/policy" "github.com/docker/attest/pkg/signerverifier" "github.com/docker/attest/pkg/tlog" "github.com/secure-systems-lab/go-securesystemslib/dsse" ) const ( - UseMockTL = true - UseMockKMS = true - UseMockPolicy = true + UseMockTL = true + UseMockKMS = true AWSRegion = "us-east-1" AWSKMSKeyARN = "arn:aws:kms:us-east-1:175142243308:alias/doi-signing" // sandbox @@ -57,15 +55,6 @@ func Setup(t *testing.T) (context.Context, dsse.SignerVerifier) { ctx := tlog.WithTL(context.Background(), tl) - var policyEvaluator policy.Evaluator - if UseMockPolicy { - policyEvaluator = policy.GetMockPolicy() - } else { - policyEvaluator = policy.NewRegoEvaluator(true) - } - - ctx = policy.WithPolicyEvaluator(ctx, policyEvaluator) - var signer dsse.SignerVerifier var err error if UseMockKMS { diff --git a/pkg/attest/verify.go b/pkg/attest/verify.go index e57c0c6..f06507d 100644 --- a/pkg/attest/verify.go +++ b/pkg/attest/verify.go @@ -44,7 +44,7 @@ func NewVerifier(opts *policy.Options) (Verifier, error) { }, nil } -func (v *tufVerifier) Verify(ctx context.Context, src *oci.ImageSpec) (result *VerificationResult, err error) { +func (verifier *tufVerifier) Verify(ctx context.Context, src *oci.ImageSpec) (result *VerificationResult, err error) { // so that we can resolve mapping from the image name earlier detailsResolver, err := policy.CreateImageDetailsResolver(src) if err != nil { @@ -54,35 +54,36 @@ func (v *tufVerifier) Verify(ctx context.Context, src *oci.ImageSpec) (result *V if err != nil { return nil, fmt.Errorf("failed to resolve image name: %w", err) } - policyResolver := policy.NewResolver(v.tufClient, v.opts) - pctx, err := policyResolver.ResolvePolicy(ctx, imageName) + policyResolver := policy.NewResolver(verifier.tufClient, verifier.opts) + resolvedPolicy, err := policyResolver.ResolvePolicy(ctx, imageName) if err != nil { return nil, fmt.Errorf("failed to resolve policy: %w", err) } - if pctx == nil { + if resolvedPolicy == nil { return &VerificationResult{ Outcome: OutcomeNoPolicy, }, nil } // this is overriding the mapping with a referrers config. Useful for testing if nothing else - if v.opts.ReferrersRepo != "" { - pctx.Mapping.Attestations = &config.AttestationConfig{ - Repo: v.opts.ReferrersRepo, + if verifier.opts.ReferrersRepo != "" { + resolvedPolicy.Mapping.Attestations = &config.AttestationConfig{ + Repo: verifier.opts.ReferrersRepo, Style: config.AttestationStyleReferrers, } - } else if v.opts.AttestationStyle == config.AttestationStyleAttached { - pctx.Mapping.Attestations = &config.AttestationConfig{ - Repo: v.opts.ReferrersRepo, + } else if verifier.opts.AttestationStyle == config.AttestationStyleAttached { + resolvedPolicy.Mapping.Attestations = &config.AttestationConfig{ + Repo: verifier.opts.ReferrersRepo, Style: config.AttestationStyleAttached, } } // because we have a mapping now, we can select a resolver based on its contents (ie. referrers or attached) - resolver, err := policy.CreateAttestationResolver(detailsResolver, pctx.Mapping) + resolver, err := policy.CreateAttestationResolver(detailsResolver, resolvedPolicy.Mapping) if err != nil { return nil, fmt.Errorf("failed to create attestation resolver: %w", err) } - result, err = VerifyAttestations(ctx, resolver, pctx) + evaluator := policy.NewRegoEvaluator(verifier.opts.Debug) + result, err = VerifyAttestations(ctx, resolver, evaluator, resolvedPolicy) if err != nil { return nil, fmt.Errorf("failed to evaluate policy: %w", err) } @@ -183,7 +184,7 @@ func toVerificationResult(p *policy.Policy, input *policy.Input, result *policy. }, nil } -func VerifyAttestations(ctx context.Context, resolver attestation.Resolver, pctx *policy.Policy) (*VerificationResult, error) { +func VerifyAttestations(ctx context.Context, resolver attestation.Resolver, evaluator policy.Evaluator, resolvedPolicy *policy.Policy) (*VerificationResult, error) { desc, err := resolver.ImageDescriptor(ctx) if err != nil { return nil, fmt.Errorf("failed to get image descriptor: %w", err) @@ -198,7 +199,7 @@ func VerifyAttestations(ctx context.Context, resolver attestation.Resolver, pctx return nil, err } - if pctx.ResolvedName != "" { + if resolvedPolicy.ResolvedName != "" { // this means the name we have is not the one we want to use for policy evaluation // so we need to replace it with the one we resolved during policy resolution. // this can happen if the name is an alias for another image, e.g. if it is a mirror @@ -207,7 +208,7 @@ func VerifyAttestations(ctx context.Context, resolver attestation.Resolver, pctx return nil, fmt.Errorf("failed to parse image name: %w", err) } oldName := ref.Name() - name = strings.Replace(name, oldName, pctx.ResolvedName, 1) + name = strings.Replace(name, oldName, resolvedPolicy.ResolvedName, 1) } ref, err := reference.ParseNormalizedNamed(name) @@ -239,16 +240,11 @@ func VerifyAttestations(ctx context.Context, resolver attestation.Resolver, pctx if tag != "" { input.Tag = tag } - - evaluator, err := policy.GetPolicyEvaluator(ctx) - if err != nil { - return nil, err - } - result, err := evaluator.Evaluate(ctx, resolver, pctx, input) + result, err := evaluator.Evaluate(ctx, resolver, resolvedPolicy, input) if err != nil { return nil, fmt.Errorf("policy evaluation failed: %w", err) } - verificationResult, err := toVerificationResult(pctx, input, result) + verificationResult, err := toVerificationResult(resolvedPolicy, input, result) if err != nil { return nil, fmt.Errorf("failed to convert to policy result: %w", err) } diff --git a/pkg/attest/verify_test.go b/pkg/attest/verify_test.go index 923a06d..15624b2 100644 --- a/pkg/attest/verify_test.go +++ b/pkg/attest/verify_test.go @@ -45,7 +45,7 @@ func TestVerifyAttestations(t *testing.T) { {"policy ok", nil, nil}, {"policy error", fmt.Errorf("policy error"), fmt.Errorf("policy evaluation failed: policy error")}, } - + ctx := context.Background() for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { mockPE := policy.MockPolicyEvaluator{ @@ -54,8 +54,7 @@ func TestVerifyAttestations(t *testing.T) { }, } - ctx := policy.WithPolicyEvaluator(context.Background(), &mockPE) - _, err := VerifyAttestations(ctx, resolver, &policy.Policy{ResolvedName: ""}) + _, err := VerifyAttestations(ctx, resolver, &mockPE, &policy.Policy{ResolvedName: ""}) if tc.expectedError != nil { if assert.Error(t, err) { assert.Equal(t, tc.expectedError.Error(), err.Error()) @@ -69,7 +68,6 @@ func TestVerifyAttestations(t *testing.T) { func TestVSA(t *testing.T) { ctx, signer := test.Setup(t) - ctx = policy.WithPolicyEvaluator(ctx, policy.NewRegoEvaluator(true)) // setup an image with signed attestations outputLayout := test.CreateTempDir(t, "", TestTempDir) @@ -122,7 +120,6 @@ func TestVSA(t *testing.T) { func TestVerificationFailure(t *testing.T) { ctx, signer := test.Setup(t) - ctx = policy.WithPolicyEvaluator(ctx, policy.NewRegoEvaluator(true)) // setup an image with signed attestations outputLayout := test.CreateTempDir(t, "", TestTempDir) @@ -175,7 +172,6 @@ func TestVerificationFailure(t *testing.T) { func TestSignVerify(t *testing.T) { ctx, signer := test.Setup(t) - ctx = policy.WithPolicyEvaluator(ctx, policy.NewRegoEvaluator(true)) // setup an image with signed attestations outputLayout := test.CreateTempDir(t, "", TestTempDir) diff --git a/pkg/attestation/referrers_test.go b/pkg/attestation/referrers_test.go index 913fec4..fc548c1 100644 --- a/pkg/attestation/referrers_test.go +++ b/pkg/attestation/referrers_test.go @@ -32,7 +32,6 @@ var ( func TestAttestationReferenceTypes(t *testing.T) { ctx, signer := test.Setup(t) - ctx = policy.WithPolicyEvaluator(ctx, policy.NewRegoEvaluator(true)) platforms := []string{"linux/amd64", "linux/arm64"} for _, tc := range []struct { name string diff --git a/pkg/policy/evaluator.go b/pkg/policy/evaluator.go index efefefd..76369cd 100644 --- a/pkg/policy/evaluator.go +++ b/pkg/policy/evaluator.go @@ -2,29 +2,10 @@ package policy import ( "context" - "fmt" "github.com/docker/attest/pkg/attestation" ) -type policyEvaluatorCtxKeyType struct{} - -var PolicyEvaluatorCtxKey policyEvaluatorCtxKeyType - -// sets PolicyEvaluator in context. -func WithPolicyEvaluator(ctx context.Context, pe Evaluator) context.Context { - return context.WithValue(ctx, PolicyEvaluatorCtxKey, pe) -} - -// gets PolicyEvaluator from context, defaults to Rego PolicyEvaluator if not set. -func GetPolicyEvaluator(ctx context.Context) (Evaluator, error) { - t, ok := ctx.Value(PolicyEvaluatorCtxKey).(Evaluator) - if !ok { - return nil, fmt.Errorf("no policy evaluator client set on context (set one with policy.WithPolicyEvaluator)") - } - return t, nil -} - type Evaluator interface { Evaluate(ctx context.Context, resolver attestation.Resolver, pctx *Policy, input *Input) (*Result, error) } diff --git a/pkg/policy/types.go b/pkg/policy/types.go index bc949ec..6d3e0dd 100644 --- a/pkg/policy/types.go +++ b/pkg/policy/types.go @@ -34,6 +34,7 @@ type Options struct { PolicyID string ReferrersRepo string AttestationStyle config.AttestationStyle + Debug bool } type Policy struct { diff --git a/pkg/tuf/registry_test.go b/pkg/tuf/registry_test.go index d648470..78dd92a 100644 --- a/pkg/tuf/registry_test.go +++ b/pkg/tuf/registry_test.go @@ -21,6 +21,7 @@ import ( "github.com/google/go-containerregistry/pkg/v1/static" "github.com/google/go-containerregistry/pkg/v1/types" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go/modules/registry" "github.com/theupdateframework/go-tuf/v2/metadata" "github.com/theupdateframework/go-tuf/v2/metadata/config" @@ -56,7 +57,7 @@ func TestRegistryFetcher(t *testing.T) { delegatedTargetFile := fmt.Sprintf("%s/%s", delegatedRole, targetFile) cfg, err := config.New(metadataRepo, DockerTUFRootDev.Data) - assert.NoError(t, err) + require.NoError(t, err) cfg.Fetcher = NewRegistryFetcher(metadataRepo, metadataImgTag, targetsRepo) cfg.LocalMetadataDir = dir @@ -65,23 +66,23 @@ func TestRegistryFetcher(t *testing.T) { // create a new Updater instance up, err := updater.New(cfg) - assert.NoError(t, err) + require.NoError(t, err) // refresh the metadata err = up.Refresh() - assert.NoError(t, err) + require.NoError(t, err) // download top-level target targetInfo, err := up.GetTargetInfo(targetFile) - assert.NoError(t, err) + require.NoError(t, err) _, _, err = up.DownloadTarget(targetInfo, filepath.Join(dir, targetInfo.Path), "") - assert.NoError(t, err) + require.NoError(t, err) // download delegated target targetInfo, err = up.GetTargetInfo(delegatedTargetFile) - assert.NoError(t, err) + require.NoError(t, err) _, _, err = up.DownloadTarget(targetInfo, filepath.Join(delegatedDir, targetFile), "") - assert.NoError(t, err) + require.NoError(t, err) } func TestRoleFromConsistentName(t *testing.T) {