diff --git a/pkg/cmd/attestation/download/download.go b/pkg/cmd/attestation/download/download.go index 77c9280933f..143912308d3 100644 --- a/pkg/cmd/attestation/download/download.go +++ b/pkg/cmd/attestation/download/download.go @@ -122,14 +122,13 @@ func runDownload(opts *Options) error { opts.Logger.VerbosePrintf("Downloading trusted metadata for artifact %s\n\n", opts.ArtifactPath) - c := verification.FetchAttestationsConfig{ - APIClient: opts.APIClient, - Digest: artifact.DigestWithAlg(), - Limit: opts.Limit, - Owner: opts.Owner, - Repo: opts.Repo, + params := verification.FetchRemoteAttestationsParams{ + Digest: artifact.DigestWithAlg(), + Limit: opts.Limit, + Owner: opts.Owner, + Repo: opts.Repo, } - attestations, err := verification.GetRemoteAttestations(c) + attestations, err := verification.GetRemoteAttestations(opts.APIClient, params) if err != nil { if errors.Is(err, api.ErrNoAttestations{}) { fmt.Fprintf(opts.Logger.IO.Out, "No attestations found for %s\n", opts.ArtifactPath) diff --git a/pkg/cmd/attestation/verification/attestation.go b/pkg/cmd/attestation/verification/attestation.go index 0ea91c2f7f0..07083a5c0a4 100644 --- a/pkg/cmd/attestation/verification/attestation.go +++ b/pkg/cmd/attestation/verification/attestation.go @@ -9,8 +9,8 @@ import ( "path/filepath" "github.com/cli/cli/v2/pkg/cmd/attestation/api" + "github.com/cli/cli/v2/pkg/cmd/attestation/artifact" "github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci" - "github.com/google/go-containerregistry/pkg/name" protobundle "github.com/sigstore/protobuf-specs/gen/pb-go/bundle/v1" "github.com/sigstore/sigstore-go/pkg/bundle" ) @@ -20,32 +20,11 @@ const SLSAPredicateV1 = "https://slsa.dev/provenance/v1" var ErrUnrecognisedBundleExtension = errors.New("bundle file extension not supported, must be json or jsonl") var ErrEmptyBundleFile = errors.New("provided bundle file is empty") -type FetchAttestationsConfig struct { - APIClient api.Client - BundlePath string - Digest string - Limit int - Owner string - Repo string - OCIClient oci.Client - UseBundleFromRegistry bool - NameRef name.Reference -} - -func (c *FetchAttestationsConfig) IsBundleProvided() bool { - return c.BundlePath != "" -} - -func GetAttestations(c FetchAttestationsConfig) ([]*api.Attestation, error) { - if c.IsBundleProvided() { - return GetLocalAttestations(c.BundlePath) - } - - if c.UseBundleFromRegistry { - return GetOCIAttestations(c) - } - - return GetRemoteAttestations(c) +type FetchRemoteAttestationsParams struct { + Digest string + Limit int + Owner string + Repo string } // GetLocalAttestations returns a slice of attestations read from a local bundle file. @@ -116,30 +95,30 @@ func loadBundlesFromJSONLinesFile(path string) ([]*api.Attestation, error) { return attestations, nil } -func GetRemoteAttestations(c FetchAttestationsConfig) ([]*api.Attestation, error) { - if c.APIClient == nil { +func GetRemoteAttestations(client api.Client, params FetchRemoteAttestationsParams) ([]*api.Attestation, error) { + if client == nil { return nil, fmt.Errorf("api client must be provided") } // check if Repo is set first because if Repo has been set, Owner will be set using the value of Repo. // If Repo is not set, the field will remain empty. It will not be populated using the value of Owner. - if c.Repo != "" { - attestations, err := c.APIClient.GetByRepoAndDigest(c.Repo, c.Digest, c.Limit) + if params.Repo != "" { + attestations, err := client.GetByRepoAndDigest(params.Repo, params.Digest, params.Limit) if err != nil { - return nil, fmt.Errorf("failed to fetch attestations from %s: %w", c.Repo, err) + return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Repo, err) } return attestations, nil - } else if c.Owner != "" { - attestations, err := c.APIClient.GetByOwnerAndDigest(c.Owner, c.Digest, c.Limit) + } else if params.Owner != "" { + attestations, err := client.GetByOwnerAndDigest(params.Owner, params.Digest, params.Limit) if err != nil { - return nil, fmt.Errorf("failed to fetch attestations from %s: %w", c.Owner, err) + return nil, fmt.Errorf("failed to fetch attestations from %s: %w", params.Owner, err) } return attestations, nil } return nil, fmt.Errorf("owner or repo must be provided") } -func GetOCIAttestations(c FetchAttestationsConfig) ([]*api.Attestation, error) { - attestations, err := c.OCIClient.GetAttestations(c.NameRef, c.Digest) +func GetOCIAttestations(client oci.Client, artifact artifact.DigestedArtifact) ([]*api.Attestation, error) { + attestations, err := client.GetAttestations(artifact.NameRef(), artifact.Digest()) if err != nil { return nil, fmt.Errorf("failed to fetch OCI attestations: %w", err) } diff --git a/pkg/cmd/attestation/verification/extensions.go b/pkg/cmd/attestation/verification/extensions.go index e302d89c9d1..a0827e9ec13 100644 --- a/pkg/cmd/attestation/verification/extensions.go +++ b/pkg/cmd/attestation/verification/extensions.go @@ -20,7 +20,7 @@ func VerifyCertExtensions(results []*AttestationProcessingResult, ec Enforcement var lastErr error for _, attestation := range results { - err := verifyCertExtensions(*attestation.VerificationResult.Signature.Certificate, ec) + err := verifyCertExtensions(*attestation.VerificationResult.Signature.Certificate, ec.Certificate) if err == nil { // if at least one attestation is verified, we're good as verification // is defined as successful if at least one attestation is verified @@ -34,28 +34,23 @@ func VerifyCertExtensions(results []*AttestationProcessingResult, ec Enforcement return lastErr } -func verifyCertExtensions(verifiedCert certificate.Summary, criteria EnforcementCriteria) error { - sourceRepositoryOwnerURI := verifiedCert.Extensions.SourceRepositoryOwnerURI - if !strings.EqualFold(criteria.Certificate.SourceRepositoryOwnerURI, sourceRepositoryOwnerURI) { - return fmt.Errorf("expected SourceRepositoryOwnerURI to be %s, got %s", criteria.Certificate.SourceRepositoryOwnerURI, sourceRepositoryOwnerURI) +func verifyCertExtensions(given, expected certificate.Summary) error { + if !strings.EqualFold(expected.SourceRepositoryOwnerURI, given.SourceRepositoryOwnerURI) { + return fmt.Errorf("expected SourceRepositoryOwnerURI to be %s, got %s", expected.SourceRepositoryOwnerURI, given.SourceRepositoryOwnerURI) } - // if repo is set, check the SourceRepositoryURI field - if criteria.Certificate.SourceRepositoryURI != "" { - sourceRepositoryURI := verifiedCert.Extensions.SourceRepositoryURI - if !strings.EqualFold(criteria.Certificate.SourceRepositoryURI, sourceRepositoryURI) { - return fmt.Errorf("expected SourceRepositoryURI to be %s, got %s", criteria.Certificate.SourceRepositoryURI, sourceRepositoryURI) - } + // if repo is set, compare the SourceRepositoryURI fields + if expected.SourceRepositoryURI != "" && !strings.EqualFold(expected.SourceRepositoryURI, given.SourceRepositoryURI) { + return fmt.Errorf("expected SourceRepositoryURI to be %s, got %s", expected.SourceRepositoryURI, given.SourceRepositoryURI) } - // if issuer is anything other than the default, use the user-provided value; - // otherwise, select the appropriate default based on the tenant - certIssuer := verifiedCert.Extensions.Issuer - if !strings.EqualFold(criteria.Certificate.Issuer, certIssuer) { - if strings.Index(certIssuer, criteria.Certificate.Issuer+"/") == 0 { - return fmt.Errorf("expected Issuer to be %s, got %s -- if you have a custom OIDC issuer policy for your enterprise, use the --cert-oidc-issuer flag with your expected issuer", criteria.Certificate.Issuer, certIssuer) + // compare the OIDC issuers. If not equal, return an error depending + // on if there is a partial match + if !strings.EqualFold(expected.Issuer, given.Issuer) { + if strings.Index(given.Issuer, expected.Issuer+"/") == 0 { + return fmt.Errorf("expected Issuer to be %s, got %s -- if you have a custom OIDC issuer policy for your enterprise, use the --cert-oidc-issuer flag with your expected issuer", expected.Issuer, given.Issuer) } - return fmt.Errorf("expected Issuer to be %s, got %s", criteria.Certificate.Issuer, certIssuer) + return fmt.Errorf("expected Issuer to be %s, got %s", expected.Issuer, given.Issuer) } return nil diff --git a/pkg/cmd/attestation/verify/attestation.go b/pkg/cmd/attestation/verify/attestation.go new file mode 100644 index 00000000000..f3f2792c419 --- /dev/null +++ b/pkg/cmd/attestation/verify/attestation.go @@ -0,0 +1,50 @@ +package verify + +import ( + "fmt" + + "github.com/cli/cli/v2/internal/text" + "github.com/cli/cli/v2/pkg/cmd/attestation/api" + "github.com/cli/cli/v2/pkg/cmd/attestation/artifact" + "github.com/cli/cli/v2/pkg/cmd/attestation/verification" +) + +func getAttestations(o *Options, a artifact.DigestedArtifact) ([]*api.Attestation, string, error) { + if o.BundlePath != "" { + attestations, err := verification.GetLocalAttestations(o.BundlePath) + if err != nil { + msg := fmt.Sprintf("✗ Loading attestations from %s failed", a.URL) + return nil, msg, err + } + pluralAttestation := text.Pluralize(len(attestations), "attestation") + msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.BundlePath) + return attestations, msg, nil + } + + if o.UseBundleFromRegistry { + attestations, err := verification.GetOCIAttestations(o.OCIClient, a) + if err != nil { + msg := "✗ Loading attestations from OCI registry failed" + return nil, msg, err + } + pluralAttestation := text.Pluralize(len(attestations), "attestation") + msg := fmt.Sprintf("Loaded %s from %s", pluralAttestation, o.ArtifactPath) + return attestations, msg, nil + } + + params := verification.FetchRemoteAttestationsParams{ + Digest: a.DigestWithAlg(), + Limit: o.Limit, + Owner: o.Owner, + Repo: o.Repo, + } + + attestations, err := verification.GetRemoteAttestations(o.APIClient, params) + if err != nil { + msg := "✗ Loading attestations from GitHub API failed" + return nil, msg, err + } + pluralAttestation := text.Pluralize(len(attestations), "attestation") + msg := fmt.Sprintf("Loaded %s from GitHub API", pluralAttestation) + return attestations, msg, nil +} diff --git a/pkg/cmd/attestation/verify/verify.go b/pkg/cmd/attestation/verify/verify.go index 47b52bb30de..1d34fdf99ca 100644 --- a/pkg/cmd/attestation/verify/verify.go +++ b/pkg/cmd/attestation/verify/verify.go @@ -6,7 +6,6 @@ import ( "regexp" "github.com/cli/cli/v2/internal/ghinstance" - "github.com/cli/cli/v2/internal/text" "github.com/cli/cli/v2/pkg/cmd/attestation/api" "github.com/cli/cli/v2/pkg/cmd/attestation/artifact" "github.com/cli/cli/v2/pkg/cmd/attestation/artifact/oci" @@ -222,42 +221,18 @@ func runVerify(opts *Options) error { opts.Logger.Printf("Loaded digest %s for %s\n", artifact.DigestWithAlg(), artifact.URL) - c := verification.FetchAttestationsConfig{ - APIClient: opts.APIClient, - BundlePath: opts.BundlePath, - Digest: artifact.DigestWithAlg(), - Limit: opts.Limit, - Owner: opts.Owner, - Repo: opts.Repo, - OCIClient: opts.OCIClient, - UseBundleFromRegistry: opts.UseBundleFromRegistry, - NameRef: artifact.NameRef(), - } - attestations, err := verification.GetAttestations(c) + attestations, logMsg, err := getAttestations(opts, *artifact) if err != nil { if ok := errors.Is(err, api.ErrNoAttestations{}); ok { opts.Logger.Printf(opts.Logger.ColorScheme.Red("✗ No attestations found for subject %s\n"), artifact.DigestWithAlg()) return err } - - if c.IsBundleProvided() { - opts.Logger.Printf(opts.Logger.ColorScheme.Red("✗ Loading attestations from %s failed\n"), artifact.URL) - } else if c.UseBundleFromRegistry { - opts.Logger.Println(opts.Logger.ColorScheme.Red("✗ Loading attestations from OCI registry failed")) - } else { - opts.Logger.Println(opts.Logger.ColorScheme.Red("✗ Loading attestations from GitHub API failed")) - } + // Print the message signifying failure fetching attestations + opts.Logger.Println(opts.Logger.ColorScheme.Red(logMsg)) return err } - - pluralAttestation := text.Pluralize(len(attestations), "attestation") - if c.IsBundleProvided() { - opts.Logger.Printf("Loaded %s from %s\n", pluralAttestation, opts.BundlePath) - } else if c.UseBundleFromRegistry { - opts.Logger.Printf("Loaded %s from %s\n", pluralAttestation, opts.ArtifactPath) - } else { - opts.Logger.Printf("Loaded %s from GitHub API\n", pluralAttestation) - } + // Print the message signifying success fetching attestations + opts.Logger.Println(logMsg) // Apply predicate type filter to returned attestations filteredAttestations := verification.FilterAttestations(ec.PredicateType, attestations)