diff --git a/policy/rego.go b/policy/rego.go index 37e606e..a879c32 100644 --- a/policy/rego.go +++ b/policy/rego.go @@ -88,10 +88,7 @@ func (re *regoEvaluator) Evaluate(ctx context.Context, resolver attestation.Reso rego.Store(store), rego.GenerateJSON(jsonGenerator[Result]()), ) - regoFnOpts := ®oFnOpts{ - attestationResolver: resolver, - attestationVerifier: re.attestationVerifier, - } + regoFnOpts := NewRegoFunctionOptions(resolver, re.attestationVerifier) for _, custom := range RegoFunctions(regoFnOpts) { regoOpts = append(regoOpts, custom.Func) } @@ -175,7 +172,7 @@ func handleErrors2(f func(rCtx rego.BuiltinContext, a, b *ast.Term) (*ast.Term, } } -func RegoFunctions(regoOpts *regoFnOpts) []*tester.Builtin { +func RegoFunctions(regoOpts *RegoFnOpts) []*tester.Builtin { return []*tester.Builtin{ { Decl: verifyDecl, @@ -186,7 +183,7 @@ func RegoFunctions(regoOpts *regoFnOpts) []*tester.Builtin { Memoize: true, Nondeterministic: verifyDecl.Nondeterministic, }, - handleErrors2(verifyInTotoEnvelope(regoOpts))), + handleErrors2(regoOpts.verifyInTotoEnvelope)), }, { Decl: attestDecl, @@ -197,89 +194,97 @@ func RegoFunctions(regoOpts *regoFnOpts) []*tester.Builtin { Memoize: true, Nondeterministic: attestDecl.Nondeterministic, }, - handleErrors1(fetchInTotoAttestations(regoOpts))), + handleErrors1(regoOpts.fetchInTotoAttestations)), }, } } -func fetchInTotoAttestations(regoOpts *regoFnOpts) rego.Builtin1 { - return func(rCtx rego.BuiltinContext, predicateTypeTerm *ast.Term) (*ast.Term, error) { - predicateTypeStr, ok := predicateTypeTerm.Value.(ast.String) - if !ok { - return nil, fmt.Errorf("predicateTypeTerm is not a string") - } - predicateType := string(predicateTypeStr) +// because we don't control the signature here (blame rego) +// nolint:gocritic +func (regoOpts *RegoFnOpts) fetchInTotoAttestations(rCtx rego.BuiltinContext, predicateTypeTerm *ast.Term) (*ast.Term, error) { + predicateTypeStr, ok := predicateTypeTerm.Value.(ast.String) + if !ok { + return nil, fmt.Errorf("predicateTypeTerm is not a string") + } + predicateType := string(predicateTypeStr) - envelopes, err := regoOpts.attestationResolver.Attestations(rCtx.Context, predicateType) + envelopes, err := regoOpts.attestationResolver.Attestations(rCtx.Context, predicateType) + if err != nil { + return nil, err + } + + // Convert each envelope to an ast.Value. + values := make([]*ast.Term, len(envelopes)) + for i, envelope := range envelopes { + value, err := ast.InterfaceToValue(envelope) if err != nil { return nil, err } - - // Convert each envelope to an ast.Value. - values := make([]*ast.Term, len(envelopes)) - for i, envelope := range envelopes { - value, err := ast.InterfaceToValue(envelope) - if err != nil { - return nil, err - } - values[i] = ast.NewTerm(value) - } - - // Wrap the values in an ast.Set and convert it to an ast.Term. - set := ast.NewTerm(ast.NewSet(values...)) - - return set, nil + values[i] = ast.NewTerm(value) } + + // Wrap the values in an ast.Set and convert it to an ast.Term. + set := ast.NewTerm(ast.NewSet(values...)) + + return set, nil } -type regoFnOpts struct { +type RegoFnOpts struct { attestationResolver attestation.Resolver attestationVerifier attestation.Verifier } -func verifyInTotoEnvelope(regoOpts *regoFnOpts) rego.Builtin2 { - return func(rCtx rego.BuiltinContext, envTerm, optsTerm *ast.Term) (*ast.Term, error) { - env := new(attestation.Envelope) - opts := new(attestation.VerifyOptions) - err := ast.As(envTerm.Value, env) - if err != nil { - return nil, fmt.Errorf("failed to cast envelope: %w", err) - } - err = ast.As(optsTerm.Value, &opts) - if err != nil { - return nil, fmt.Errorf("failed to cast verifier options: %w", err) - } - payload, err := attestation.VerifyDSSE(rCtx.Context, regoOpts.attestationVerifier, env, opts) - if err != nil { - return nil, fmt.Errorf("failed to verify envelope: %w", err) - } - - statement := new(intoto.Statement) - - switch env.PayloadType { - case intoto.PayloadType: - err = json.Unmarshal(payload, statement) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal statement: %w", err) - } - // TODO: implement other types of envelope - default: - return nil, fmt.Errorf("unsupported payload type: %s", env.PayloadType) - } - - err = VerifySubject(rCtx.Context, statement.Subject, regoOpts.attestationResolver) - if err != nil { - return nil, fmt.Errorf("failed to verify subject: %w", err) - } - - value, err := ast.InterfaceToValue(statement) - if err != nil { - return nil, err - } - return ast.NewTerm(value), nil +// this is exported for testing here and in clients of the library. +func NewRegoFunctionOptions(resolver attestation.Resolver, verifier attestation.Verifier) *RegoFnOpts { + return &RegoFnOpts{ + attestationResolver: resolver, + attestationVerifier: verifier, } } +// because we don't control the signature here (blame rego) +// nolint:gocritic +func (regoOpts *RegoFnOpts) verifyInTotoEnvelope(rCtx rego.BuiltinContext, envTerm, optsTerm *ast.Term) (*ast.Term, error) { + env := new(attestation.Envelope) + opts := new(attestation.VerifyOptions) + err := ast.As(envTerm.Value, env) + if err != nil { + return nil, fmt.Errorf("failed to cast envelope: %w", err) + } + err = ast.As(optsTerm.Value, &opts) + if err != nil { + return nil, fmt.Errorf("failed to cast verifier options: %w", err) + } + payload, err := attestation.VerifyDSSE(rCtx.Context, regoOpts.attestationVerifier, env, opts) + if err != nil { + return nil, fmt.Errorf("failed to verify envelope: %w", err) + } + + statement := new(intoto.Statement) + + switch env.PayloadType { + case intoto.PayloadType: + err = json.Unmarshal(payload, statement) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal statement: %w", err) + } + // TODO: implement other types of envelope + default: + return nil, fmt.Errorf("unsupported payload type: %s", env.PayloadType) + } + + err = VerifySubject(rCtx.Context, statement.Subject, regoOpts.attestationResolver) + if err != nil { + return nil, fmt.Errorf("failed to verify subject: %w", err) + } + + value, err := ast.InterfaceToValue(statement) + if err != nil { + return nil, err + } + return ast.NewTerm(value), nil +} + func loadYAML(path string, bs []byte) (interface{}, error) { var x interface{} bs, err := yaml.YAMLToJSON(bs)