// NOTE: web-viewer residue (navigation links and the line-number gutter) was removed here;
// the file metadata it carried is kept below.
// path: root/modules/templates/scopedtmpl/scopedtmpl.go
// blob: 2722ba97a28ac39413338bff29a3357f35232838 (plain)

// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT

package scopedtmpl

import (
	"fmt"
	"html/template"
	"io"
	"reflect"
	"sync"
	texttemplate "text/template"
	"text/template/parse"
	"unsafe"
)

// TemplateExecutor is the minimal rendering interface handed out by
// ScopedTemplate.Executor: Execute applies the template to data and writes
// the output to wr, returning any error that occurred while doing so.
type TemplateExecutor interface {
	Execute(wr io.Writer, data any) error
}

// ScopedTemplate holds one shared HTML template set and lazily builds, per
// template name, a "scoped" set from which executors with their own function
// map can be created (see Executor).
type ScopedTemplate struct {
	// all is the shared HTML template set; New and Funcs operate on it.
	all        *template.Template
	parseFuncs template.FuncMap // this func map is only used for parsing templates
	// frozen is set by Freeze; afterwards New and Funcs panic.
	frozen     bool

	// scopedMu guards scopedTemplateSets.
	scopedMu           sync.RWMutex
	// scopedTemplateSets caches the scoped set built for each template name.
	scopedTemplateSets map[string]*scopedTemplateSet
}

// NewScopedTemplate returns an empty, unfrozen ScopedTemplate: an unnamed root
// HTML template, an empty parse-time function map and an empty cache of
// scoped template sets.
func NewScopedTemplate() *ScopedTemplate {
	st := new(ScopedTemplate)
	st.all = template.New("")
	st.parseFuncs = make(template.FuncMap)
	st.scopedTemplateSets = make(map[string]*scopedTemplateSet)
	return st
}

// Funcs registers funcMap on the shared template set and also records every
// entry in the parse-time function map.
// It panics when called after Freeze.
func (t *ScopedTemplate) Funcs(funcMap template.FuncMap) {
	if t.frozen {
		panic("cannot add new functions to frozen template set")
	}
	t.all.Funcs(funcMap)
	// Remember the names as well: Freeze later needs them to install stubs.
	for fnName := range funcMap {
		t.parseFuncs[fnName] = funcMap[fnName]
	}
}

// New allocates a new template with the given name inside the shared template
// set and returns it so the caller can parse into it.
// It panics when called after Freeze.
func (t *ScopedTemplate) New(name string) *template.Template {
	if !t.frozen {
		return t.all.New(name)
	}
	panic("cannot add new template to frozen template set")
}

// Freeze marks the template set as frozen (New and Funcs panic afterwards) and
// replaces every registered function on the shared set with a stub returning
// nil. With the real functions gone, escapeTemplate can safely call Execute on
// the shared templates purely to trigger escaping.
func (t *ScopedTemplate) Freeze() {
	t.frozen = true

	noop := func(...any) any { return nil }
	stubs := make(template.FuncMap, len(t.parseFuncs))
	for fnName := range t.parseFuncs {
		stubs[fnName] = noop
	}
	t.all.Funcs(stubs)
}

// Executor returns a TemplateExecutor for the template called name, whose
// functions are the ones captured from the shared set overlaid with funcMap.
//
// The scoped template set for name is built on first use and cached; later
// calls only take the read lock. An error is returned when the template does
// not exist or cannot be escaped; failed builds are not cached, so a later
// call tries again.
func (t *ScopedTemplate) Executor(name string, funcMap template.FuncMap) (TemplateExecutor, error) {
	// Fast path: the scoped set has already been built.
	t.scopedMu.RLock()
	scopedTmplSet, ok := t.scopedTemplateSets[name]
	t.scopedMu.RUnlock()

	if !ok {
		var err error
		func() {
			t.scopedMu.Lock()
			// Release the lock with defer so that a panic while building the set
			// cannot leave the mutex held and block every later caller.
			defer t.scopedMu.Unlock()

			// Re-check under the write lock: another goroutine may have built it meanwhile.
			if scopedTmplSet, ok = t.scopedTemplateSets[name]; ok {
				return
			}
			if scopedTmplSet, err = newScopedTemplateSet(t.all, name); err == nil {
				t.scopedTemplateSets[name] = scopedTmplSet
			}
		}()
		if err != nil {
			return nil, err
		}
	}

	if scopedTmplSet == nil {
		return nil, fmt.Errorf("template %s not found", name)
	}
	return scopedTmplSet.newExecutor(funcMap), nil
}

// scopedTemplateSet is the closure of templates needed to render one named
// template, plus a snapshot of the function table to execute them with.
type scopedTemplateSet struct {
	// name is the entry template this set was built for.
	name          string
	// htmlTemplates are the reachable templates found in the HTML template set, keyed by name.
	htmlTemplates map[string]*template.Template
	// textTemplates are reachable templates that only exist in the underlying text template set.
	textTemplates map[string]*texttemplate.Template
	// execFuncs is a copy of the exec function table taken when the set was built.
	execFuncs     map[string]reflect.Value
}

// escapeTemplate makes the Go HTML template package perform its escaping pass
// on t by executing it once against a discarded writer with nil data.
//
// Only failures reported as *template.Error are returned; any other error
// from the throw-away execution is ignored and nil is returned.
func escapeTemplate(t *template.Template) error {
	switch execErr := t.Execute(io.Discard, nil).(type) {
	case *template.Error:
		return execErr
	default:
		return nil
	}
}

// htmlTemplate is the struct that a *template.Template (html/template) is
// reinterpreted as via ptr, in order to reach its underlying text template.
// The field order and types must therefore match the leading fields of the
// standard library's html/template.Template; do not reorder them.
//
//nolint:unused
type htmlTemplate struct {
	escapeErr error
	text      *texttemplate.Template
}

// textTemplateCommon is the shared-state struct reached through textTemplate
// when a *texttemplate.Template is reinterpreted via ptr. This package only
// touches muFuncs and execFuncs; the other fields exist to keep the layout.
// NOTE(review): presumably mirrors the unexported "common" struct of
// text/template — confirm against the Go version in use; do not reorder fields.
//
//nolint:unused
type textTemplateCommon struct {
	tmpl   map[string]*template.Template // Map from name to defined templates.
	muTmpl sync.RWMutex                  // protects tmpl
	option struct {
		missingKey int
	}
	muFuncs    sync.RWMutex // protects parseFuncs and execFuncs
	parseFuncs texttemplate.FuncMap
	execFuncs  map[string]reflect.Value
}

// textTemplate is the struct that a *texttemplate.Template is reinterpreted as
// via ptr, giving access to the function table held in textTemplateCommon.
// The field order and types must match the standard library's
// text/template.Template layout; do not reorder them.
//
//nolint:unused
type textTemplate struct {
	name string
	*parse.Tree
	*textTemplateCommon
	leftDelim  string
	rightDelim string
}

// ptr reinterprets a *P as a *T without copying; the result aliases the same
// memory as the argument.
//
// This relies on rule (1) of https://pkg.go.dev/unsafe#Pointer, "Conversion of
// a *T1 to Pointer to *T2": it is only valid when T is no larger than P and
// both share an equivalent memory layout. The caller must guarantee that.
func ptr[T, P any](src *P) *T {
	raw := unsafe.Pointer(src)
	return (*T)(raw)
}

// newScopedTemplateSet builds the set of templates needed to render the
// template called name from the shared HTML template set all.
//
// It escapes the target template, snapshots the exec function table of the
// underlying text template, and then walks the parse trees (through list, if,
// range and with nodes) to collect every template referenced by a
// {{template}} action. Referenced HTML templates are escaped and walked too.
//
// It returns a nil set and an error when the target or a referenced template
// is missing, has no parse tree, or fails escaping. The first error stops
// the walk.
func newScopedTemplateSet(all *template.Template, name string) (*scopedTemplateSet, error) {
	targetTmpl := all.Lookup(name)
	if targetTmpl == nil {
		return nil, fmt.Errorf("template %q not found", name)
	}
	if err := escapeTemplate(targetTmpl); err != nil {
		return nil, fmt.Errorf("template %q has an error when escaping: %w", name, err)
	}
	// escapeTemplate only reports *template.Error, so a template that was created
	// but never parsed gets here with a nil tree; reject it instead of panicking below.
	if targetTmpl.Tree == nil || targetTmpl.Tree.Root == nil {
		return nil, fmt.Errorf("template %q has no tree, it's usually caused by broken templates", name)
	}

	ts := &scopedTemplateSet{
		name:          name,
		htmlTemplates: map[string]*template.Template{},
		textTemplates: map[string]*texttemplate.Template{},
	}

	// Reach into the HTML template to get the underlying text template and its function table.
	htmlTmpl := ptr[htmlTemplate](all)
	textTmpl := htmlTmpl.text
	textTmplPtr := ptr[textTemplate](textTmpl)

	textTmplPtr.muFuncs.Lock()
	ts.execFuncs = make(map[string]reflect.Value, len(textTmplPtr.execFuncs))
	for k, v := range textTmplPtr.execFuncs {
		ts.execFuncs[k] = v
	}
	textTmplPtr.muFuncs.Unlock()

	var collectTemplates func(nodes []parse.Node) error

	// collectList walks an optional node list (nil lists are skipped).
	collectList := func(list *parse.ListNode) error {
		if list == nil {
			return nil
		}
		return collectTemplates(list.Nodes)
	}

	// collectBranch walks both arms of an if/range/with node.
	collectBranch := func(branch *parse.BranchNode) error {
		if err := collectList(branch.List); err != nil {
			return err
		}
		return collectList(branch.ElseList)
	}

	// collectSubTemplate records the template referenced by a {{template}} action
	// and, for HTML templates, escapes it and walks its tree.
	collectSubTemplate := func(subName string) error {
		if ts.htmlTemplates[subName] != nil || ts.textTemplates[subName] != nil {
			return nil // already collected
		}
		subTmpl := all.Lookup(subName)
		if subTmpl == nil {
			// HTML template will add some internal templates like "$delimDoubleQuote" into the text template
			// NOTE(review): the trees of these text-only templates are not walked; confirm they never
			// reference further templates that are otherwise unreachable.
			internalTmpl := textTmpl.Lookup(subName)
			if internalTmpl == nil {
				// storing nil here would crash later when the executor is assembled
				return fmt.Errorf("template %q not found", subName)
			}
			ts.textTemplates[subName] = internalTmpl
			return nil
		}
		if subTmpl.Tree == nil || subTmpl.Tree.Root == nil {
			return fmt.Errorf("template %q has no tree, it's usually caused by broken templates", subName)
		}
		// Register before recursing so that recursive template references terminate.
		ts.htmlTemplates[subName] = subTmpl
		if err := escapeTemplate(subTmpl); err != nil {
			return fmt.Errorf("template %q has an error when escaping: %w", subName, err)
		}
		return collectTemplates(subTmpl.Tree.Root.Nodes)
	}

	collectTemplates = func(nodes []parse.Node) error {
		for _, node := range nodes {
			var err error
			switch n := node.(type) {
			case *parse.TemplateNode:
				err = collectSubTemplate(n.Name)
			case *parse.ListNode:
				err = collectList(n)
			case *parse.IfNode:
				err = collectBranch(&n.BranchNode)
			case *parse.RangeNode:
				err = collectBranch(&n.BranchNode)
			case *parse.WithNode:
				err = collectBranch(&n.BranchNode)
			}
			if err != nil {
				return err // only the first error is needed
			}
		}
		return nil
	}

	ts.htmlTemplates[name] = targetTmpl
	if err := collectTemplates(targetTmpl.Tree.Root.Nodes); err != nil {
		return nil, err
	}
	return ts, nil
}

// newExecutor assembles a fresh text template holding every template of the
// set and returns the entry template as a TemplateExecutor.
//
// The new template starts with a private copy of the captured exec functions;
// entries of funcMap (if any) are then registered on top of them.
func (ts *scopedTemplateSet) newExecutor(funcMap map[string]any) TemplateExecutor {
	root := texttemplate.New("")

	// Give the new template its own function table, so registering funcMap
	// below never touches the table shared by the set.
	execFuncs := make(map[string]reflect.Value, len(ts.execFuncs))
	for fnName, fn := range ts.execFuncs {
		execFuncs[fnName] = fn
	}
	ptr[textTemplate](root).execFuncs = execFuncs

	if funcMap != nil {
		root.Funcs(funcMap)
	}

	// The HTML templates were escaped by escapeTemplate, so their trees are
	// already-escaped text template trees and can be added as they are.
	for _, htmlTmpl := range ts.htmlTemplates {
		_, _ = root.AddParseTree(htmlTmpl.Name(), htmlTmpl.Tree)
	}
	for _, textTmpl := range ts.textTemplates {
		_, _ = root.AddParseTree(textTmpl.Name(), textTmpl.Tree)
	}

	// All required escaped templates are now present, so executing the text
	// template is equivalent to what the HTML template would do.
	return root.Lookup(ts.name)
}