diff options
author | Daniel Baumann <daniel@debian.org> | 2024-10-18 20:33:49 +0200 |
---|---|---|
committer | Daniel Baumann <daniel@debian.org> | 2024-10-18 20:33:49 +0200 |
commit | dd136858f1ea40ad3c94191d647487fa4f31926c (patch) | |
tree | 58fec94a7b2a12510c9664b21793f1ed560c6518 /modules/templates/scopedtmpl | |
parent | Initial commit. (diff) | |
download | forgejo-dd136858f1ea40ad3c94191d647487fa4f31926c.tar.xz forgejo-dd136858f1ea40ad3c94191d647487fa4f31926c.zip |
Adding upstream version 9.0.0.
Signed-off-by: Daniel Baumann <daniel@debian.org>
Diffstat (limited to 'modules/templates/scopedtmpl')
-rw-r--r-- | modules/templates/scopedtmpl/scopedtmpl.go | 239 | ||||
-rw-r--r-- | modules/templates/scopedtmpl/scopedtmpl_test.go | 99 |
2 files changed, 338 insertions, 0 deletions
diff --git a/modules/templates/scopedtmpl/scopedtmpl.go b/modules/templates/scopedtmpl/scopedtmpl.go new file mode 100644 index 0000000..2722ba9 --- /dev/null +++ b/modules/templates/scopedtmpl/scopedtmpl.go @@ -0,0 +1,239 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package scopedtmpl + +import ( + "fmt" + "html/template" + "io" + "reflect" + "sync" + texttemplate "text/template" + "text/template/parse" + "unsafe" +) + +type TemplateExecutor interface { + Execute(wr io.Writer, data any) error +} + +type ScopedTemplate struct { + all *template.Template + parseFuncs template.FuncMap // this func map is only used for parsing templates + frozen bool + + scopedMu sync.RWMutex + scopedTemplateSets map[string]*scopedTemplateSet +} + +func NewScopedTemplate() *ScopedTemplate { + return &ScopedTemplate{ + all: template.New(""), + parseFuncs: template.FuncMap{}, + scopedTemplateSets: map[string]*scopedTemplateSet{}, + } +} + +func (t *ScopedTemplate) Funcs(funcMap template.FuncMap) { + if t.frozen { + panic("cannot add new functions to frozen template set") + } + t.all.Funcs(funcMap) + for k, v := range funcMap { + t.parseFuncs[k] = v + } +} + +func (t *ScopedTemplate) New(name string) *template.Template { + if t.frozen { + panic("cannot add new template to frozen template set") + } + return t.all.New(name) +} + +func (t *ScopedTemplate) Freeze() { + t.frozen = true + // reset the exec func map, then `escapeTemplate` is safe to call `Execute` to do escaping + m := template.FuncMap{} + for k := range t.parseFuncs { + m[k] = func(v ...any) any { return nil } + } + t.all.Funcs(m) +} + +func (t *ScopedTemplate) Executor(name string, funcMap template.FuncMap) (TemplateExecutor, error) { + t.scopedMu.RLock() + scopedTmplSet, ok := t.scopedTemplateSets[name] + t.scopedMu.RUnlock() + + if !ok { + var err error + t.scopedMu.Lock() + if scopedTmplSet, ok = t.scopedTemplateSets[name]; !ok { + if scopedTmplSet, err = newScopedTemplateSet(t.all, name); err == nil { + t.scopedTemplateSets[name] = scopedTmplSet + } + } + t.scopedMu.Unlock() + if err != nil { + return nil, err + } + } + + if scopedTmplSet == nil { + return nil, fmt.Errorf("template %s not found", name) + } + return scopedTmplSet.newExecutor(funcMap), nil +} + +type scopedTemplateSet struct { + name string + htmlTemplates map[string]*template.Template + textTemplates map[string]*texttemplate.Template + execFuncs map[string]reflect.Value +} + +func escapeTemplate(t *template.Template) error { + // force the Golang HTML template to complete the escaping work + err := t.Execute(io.Discard, nil) + if _, ok := err.(*template.Error); ok { + return err + } + return nil +} + +//nolint:unused +type htmlTemplate struct { + escapeErr error + text *texttemplate.Template +} + +//nolint:unused +type textTemplateCommon struct { + tmpl map[string]*template.Template // Map from name to defined templates. + muTmpl sync.RWMutex // protects tmpl + option struct { + missingKey int + } + muFuncs sync.RWMutex // protects parseFuncs and execFuncs + parseFuncs texttemplate.FuncMap + execFuncs map[string]reflect.Value +} + +//nolint:unused +type textTemplate struct { + name string + *parse.Tree + *textTemplateCommon + leftDelim string + rightDelim string +} + +func ptr[T, P any](ptr *P) *T { + // https://pkg.go.dev/unsafe#Pointer + // (1) Conversion of a *T1 to Pointer to *T2. + // Provided that T2 is no larger than T1 and that the two share an equivalent memory layout, + // this conversion allows reinterpreting data of one type as data of another type. + return (*T)(unsafe.Pointer(ptr)) +} + +func newScopedTemplateSet(all *template.Template, name string) (*scopedTemplateSet, error) { + targetTmpl := all.Lookup(name) + if targetTmpl == nil { + return nil, fmt.Errorf("template %q not found", name) + } + if err := escapeTemplate(targetTmpl); err != nil { + return nil, fmt.Errorf("template %q has an error when escaping: %w", name, err) + } + + ts := &scopedTemplateSet{ + name: name, + htmlTemplates: map[string]*template.Template{}, + textTemplates: map[string]*texttemplate.Template{}, + } + + htmlTmpl := ptr[htmlTemplate](all) + textTmpl := htmlTmpl.text + textTmplPtr := ptr[textTemplate](textTmpl) + + textTmplPtr.muFuncs.Lock() + ts.execFuncs = map[string]reflect.Value{} + for k, v := range textTmplPtr.execFuncs { + ts.execFuncs[k] = v + } + textTmplPtr.muFuncs.Unlock() + + var collectTemplates func(nodes []parse.Node) + var collectErr error // only need to collect the one error + collectTemplates = func(nodes []parse.Node) { + for _, node := range nodes { + if node.Type() == parse.NodeTemplate { + nodeTemplate := node.(*parse.TemplateNode) + subName := nodeTemplate.Name + if ts.htmlTemplates[subName] == nil { + subTmpl := all.Lookup(subName) + if subTmpl == nil { + // HTML template will add some internal templates like "$delimDoubleQuote" into the text template + ts.textTemplates[subName] = textTmpl.Lookup(subName) + } else if subTmpl.Tree == nil || subTmpl.Tree.Root == nil { + collectErr = fmt.Errorf("template %q has no tree, it's usually caused by broken templates", subName) + } else { + ts.htmlTemplates[subName] = subTmpl + if err := escapeTemplate(subTmpl); err != nil { + collectErr = fmt.Errorf("template %q has an error when escaping: %w", subName, err) + return + } + collectTemplates(subTmpl.Tree.Root.Nodes) + } + } + } else if node.Type() == parse.NodeList { + nodeList := node.(*parse.ListNode) + collectTemplates(nodeList.Nodes) + } else if node.Type() == parse.NodeIf { + nodeIf := node.(*parse.IfNode) + collectTemplates(nodeIf.BranchNode.List.Nodes) + if nodeIf.BranchNode.ElseList != nil { + collectTemplates(nodeIf.BranchNode.ElseList.Nodes) + } + } else if node.Type() == parse.NodeRange { + nodeRange := node.(*parse.RangeNode) + collectTemplates(nodeRange.BranchNode.List.Nodes) + if nodeRange.BranchNode.ElseList != nil { + collectTemplates(nodeRange.BranchNode.ElseList.Nodes) + } + } else if node.Type() == parse.NodeWith { + nodeWith := node.(*parse.WithNode) + collectTemplates(nodeWith.BranchNode.List.Nodes) + if nodeWith.BranchNode.ElseList != nil { + collectTemplates(nodeWith.BranchNode.ElseList.Nodes) + } + } + } + } + ts.htmlTemplates[name] = targetTmpl + collectTemplates(targetTmpl.Tree.Root.Nodes) + return ts, collectErr +} + +func (ts *scopedTemplateSet) newExecutor(funcMap map[string]any) TemplateExecutor { + tmpl := texttemplate.New("") + tmplPtr := ptr[textTemplate](tmpl) + tmplPtr.execFuncs = map[string]reflect.Value{} + for k, v := range ts.execFuncs { + tmplPtr.execFuncs[k] = v + } + if funcMap != nil { + tmpl.Funcs(funcMap) + } + // after escapeTemplate, the html templates are also escaped text templates, so it could be added to the text template directly + for _, t := range ts.htmlTemplates { + _, _ = tmpl.AddParseTree(t.Name(), t.Tree) + } + for _, t := range ts.textTemplates { + _, _ = tmpl.AddParseTree(t.Name(), t.Tree) + } + + // now the text template has all necessary escaped templates, so we can safely execute, just like what the html template does + return tmpl.Lookup(ts.name) +} diff --git a/modules/templates/scopedtmpl/scopedtmpl_test.go b/modules/templates/scopedtmpl/scopedtmpl_test.go new file mode 100644 index 0000000..9bbd0c7 --- /dev/null +++ b/modules/templates/scopedtmpl/scopedtmpl_test.go @@ -0,0 +1,99 @@ +// Copyright 2023 The Gitea Authors. All rights reserved. +// SPDX-License-Identifier: MIT + +package scopedtmpl + +import ( + "bytes" + "html/template" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestScopedTemplateSetFuncMap(t *testing.T) { + all := template.New("") + + all.Funcs(template.FuncMap{"CtxFunc": func(s string) string { + return "default" + }}) + + _, err := all.New("base").Parse(`{{CtxFunc "base"}}`) + require.NoError(t, err) + + _, err = all.New("test").Parse(strings.TrimSpace(` +{{template "base"}} +{{CtxFunc "test"}} +{{template "base"}} +{{CtxFunc "test"}} +`)) + require.NoError(t, err) + + ts, err := newScopedTemplateSet(all, "test") + require.NoError(t, err) + + // try to use different CtxFunc to render concurrently + + funcMap1 := template.FuncMap{ + "CtxFunc": func(s string) string { + time.Sleep(100 * time.Millisecond) + return s + "1" + }, + } + + funcMap2 := template.FuncMap{ + "CtxFunc": func(s string) string { + time.Sleep(100 * time.Millisecond) + return s + "2" + }, + } + + out1 := bytes.Buffer{} + out2 := bytes.Buffer{} + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + err := ts.newExecutor(funcMap1).Execute(&out1, nil) + require.NoError(t, err) + wg.Done() + }() + go func() { + err := ts.newExecutor(funcMap2).Execute(&out2, nil) + require.NoError(t, err) + wg.Done() + }() + wg.Wait() + assert.Equal(t, "base1\ntest1\nbase1\ntest1", out1.String()) + assert.Equal(t, "base2\ntest2\nbase2\ntest2", out2.String()) +} + +func TestScopedTemplateSetEscape(t *testing.T) { + all := template.New("") + _, err := all.New("base").Parse(`<a href="?q={{.param}}">{{.text}}</a>`) + require.NoError(t, err) + + _, err = all.New("test").Parse(`{{template "base" .}}<form action="?q={{.param}}">{{.text}}</form>`) + require.NoError(t, err) + + ts, err := newScopedTemplateSet(all, "test") + require.NoError(t, err) + + out := bytes.Buffer{} + err = ts.newExecutor(nil).Execute(&out, map[string]string{"param": "/", "text": "<"}) + require.NoError(t, err) + + assert.Equal(t, `<a href="?q=%2f"><</a><form action="?q=%2f"><</form>`, out.String()) +} + +func TestScopedTemplateSetUnsafe(t *testing.T) { + all := template.New("") + _, err := all.New("test").Parse(`<a href="{{if true}}?{{end}}a={{.param}}"></a>`) + require.NoError(t, err) + + _, err = newScopedTemplateSet(all, "test") + require.ErrorContains(t, err, "appears in an ambiguous context within a URL") +} |