From e68b9d00a6e05b3a941f63ffb696f91e554ac5ec Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 18 Oct 2024 20:33:49 +0200 Subject: Adding upstream version 9.0.3. Signed-off-by: Daniel Baumann --- modules/activitypub/client.go | 273 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 273 insertions(+) create mode 100644 modules/activitypub/client.go (limited to 'modules/activitypub/client.go') diff --git a/modules/activitypub/client.go b/modules/activitypub/client.go new file mode 100644 index 0000000..064d898 --- /dev/null +++ b/modules/activitypub/client.go @@ -0,0 +1,273 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// Copyright 2024 The Forgejo Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +// TODO: Think about whether this should be moved to services/activitypub (compare to exosy/services/activitypub/client.go) +package activitypub + +import ( + "bytes" + "context" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "io" + "net/http" + "strings" + "time" + + user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/proxy" + "code.gitea.io/gitea/modules/setting" + + "github.com/go-fed/httpsig" +) + +const ( + // ActivityStreamsContentType const + ActivityStreamsContentType = `application/ld+json; profile="https://www.w3.org/ns/activitystreams"` + httpsigExpirationTime = 60 +) + +func CurrentTime() string { + return time.Now().UTC().Format(http.TimeFormat) +} + +func containsRequiredHTTPHeaders(method string, headers []string) error { + var hasRequestTarget, hasDate, hasDigest, hasHost bool + for _, header := range headers { + hasRequestTarget = hasRequestTarget || header == httpsig.RequestTarget + hasDate = hasDate || header == "Date" + hasDigest = hasDigest || header == "Digest" + hasHost = hasHost || header == "Host" + } + if !hasRequestTarget { + return fmt.Errorf("missing http header for %s: %s", method, httpsig.RequestTarget) + } else if !hasDate { + return fmt.Errorf("missing http header for %s: Date", method) + } else if !hasHost { + return fmt.Errorf("missing http header for %s: Host", method) + } else if !hasDigest && method != http.MethodGet { + return fmt.Errorf("missing http header for %s: Digest", method) + } + return nil +} + +// Client struct +type ClientFactory struct { + client *http.Client + algs []httpsig.Algorithm + digestAlg httpsig.DigestAlgorithm + getHeaders []string + postHeaders []string +} + +// NewClient function +func NewClientFactory() (c *ClientFactory, err error) { + if err = containsRequiredHTTPHeaders(http.MethodGet, setting.Federation.GetHeaders); err != nil { + return nil, err + } else if err = containsRequiredHTTPHeaders(http.MethodPost, setting.Federation.PostHeaders); err != nil { + return nil, err + } + + c = &ClientFactory{ + client: &http.Client{ + Transport: &http.Transport{ + Proxy: proxy.Proxy(), + }, + Timeout: 5 * time.Second, + }, + algs: setting.HttpsigAlgs, + digestAlg: httpsig.DigestAlgorithm(setting.Federation.DigestAlgorithm), + getHeaders: setting.Federation.GetHeaders, + postHeaders: setting.Federation.PostHeaders, + } + return c, err +} + +type APClientFactory interface { + WithKeys(ctx context.Context, user *user_model.User, pubID string) (APClient, error) +} + +// Client struct +type Client struct { + client *http.Client + algs []httpsig.Algorithm + digestAlg httpsig.DigestAlgorithm + getHeaders []string + postHeaders []string + priv *rsa.PrivateKey + pubID string +} + +// NewRequest function +func (cf *ClientFactory) WithKeys(ctx context.Context, user *user_model.User, pubID string) (APClient, error) { + priv, err := GetPrivateKey(ctx, user) + if err != nil { + return nil, err + } + privPem, _ := pem.Decode([]byte(priv)) + privParsed, err := x509.ParsePKCS1PrivateKey(privPem.Bytes) + if err != nil { + return nil, err + } + + c := Client{ + client: cf.client, + algs: cf.algs, + digestAlg: cf.digestAlg, + getHeaders: cf.getHeaders, + postHeaders: cf.postHeaders, + priv: privParsed, + pubID: pubID, + } + return &c, nil +} + +// NewRequest function +func (c *Client) newRequest(method string, b []byte, to string) (req *http.Request, err error) { + buf := bytes.NewBuffer(b) + req, err = http.NewRequest(method, to, buf) + if err != nil { + return nil, err + } + req.Header.Add("Accept", "application/json, "+ActivityStreamsContentType) + req.Header.Add("Date", CurrentTime()) + req.Header.Add("Host", req.URL.Host) + req.Header.Add("User-Agent", "Gitea/"+setting.AppVer) + req.Header.Add("Content-Type", ActivityStreamsContentType) + + return req, err +} + +// Post function +func (c *Client) Post(b []byte, to string) (resp *http.Response, err error) { + var req *http.Request + if req, err = c.newRequest(http.MethodPost, b, to); err != nil { + return nil, err + } + + signer, _, err := httpsig.NewSigner(c.algs, c.digestAlg, c.postHeaders, httpsig.Signature, httpsigExpirationTime) + if err != nil { + return nil, err + } + if err := signer.SignRequest(c.priv, c.pubID, req, b); err != nil { + return nil, err + } + + resp, err = c.client.Do(req) + return resp, err +} + +// Create an http GET request with forgejo/gitea specific headers +func (c *Client) Get(to string) (resp *http.Response, err error) { + var req *http.Request + if req, err = c.newRequest(http.MethodGet, nil, to); err != nil { + return nil, err + } + signer, _, err := httpsig.NewSigner(c.algs, c.digestAlg, c.getHeaders, httpsig.Signature, httpsigExpirationTime) + if err != nil { + return nil, err + } + if err := signer.SignRequest(c.priv, c.pubID, req, nil); err != nil { + return nil, err + } + + resp, err = c.client.Do(req) + return resp, err +} + +// Create an http GET request with forgejo/gitea specific headers +func (c *Client) GetBody(uri string) ([]byte, error) { + response, err := c.Get(uri) + if err != nil { + return nil, err + } + log.Debug("Client: got status: %v", response.Status) + if response.StatusCode != 200 { + err = fmt.Errorf("got non 200 status code for id: %v", uri) + return nil, err + } + defer response.Body.Close() + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + log.Debug("Client: got body: %v", charLimiter(string(body), 120)) + return body, nil +} + +// Limit number of characters in a string (useful to prevent log injection attacks and overly long log outputs) +// Thanks to https://www.socketloop.com/tutorials/golang-characters-limiter-example +func charLimiter(s string, limit int) string { + reader := strings.NewReader(s) + buff := make([]byte, limit) + n, _ := io.ReadAtLeast(reader, buff, limit) + if n != 0 { + return fmt.Sprint(string(buff), "...") + } + return s +} + +type APClient interface { + newRequest(method string, b []byte, to string) (req *http.Request, err error) + Post(b []byte, to string) (resp *http.Response, err error) + Get(to string) (resp *http.Response, err error) + GetBody(uri string) ([]byte, error) +} + +// contextKey is a value for use with context.WithValue. +type contextKey struct { + name string +} + +// clientFactoryContextKey is a context key. It is used with context.Value() to get the current Food for the context +var ( + clientFactoryContextKey = &contextKey{"clientFactory"} + _ APClientFactory = &ClientFactory{} +) + +// Context represents an activitypub client factory context +type Context struct { + context.Context + e APClientFactory +} + +func NewContext(ctx context.Context, e APClientFactory) *Context { + return &Context{ + Context: ctx, + e: e, + } +} + +// APClientFactory represents an activitypub client factory +func (ctx *Context) APClientFactory() APClientFactory { + return ctx.e +} + +// provides APClientFactory +type GetAPClient interface { + GetClientFactory() APClientFactory +} + +// GetClientFactory will get an APClientFactory from this context or returns the default implementation +func GetClientFactory(ctx context.Context) (APClientFactory, error) { + if e := getClientFactory(ctx); e != nil { + return e, nil + } + return NewClientFactory() +} + +// getClientFactory will get an APClientFactory from this context or return nil +func getClientFactory(ctx context.Context) APClientFactory { + if clientFactory, ok := ctx.(APClientFactory); ok { + return clientFactory + } + clientFactoryInterface := ctx.Value(clientFactoryContextKey) + if clientFactoryInterface != nil { + return clientFactoryInterface.(GetAPClient).GetClientFactory() + } + return nil +} -- cgit v1.2.3