zig.go (view raw)
//
// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
16package main
17
18import (
19 "bytes"
20 "crypto"
21 "crypto/rand"
22 "crypto/rsa"
23 "crypto/sha256"
24 "crypto/x509"
25 "encoding/base64"
26 "encoding/pem"
27 "fmt"
28 "io"
29 "log"
30 "net/http"
31 "regexp"
32 "strings"
33 "time"
34)
35
// sb64 returns data encoded as standard (padded) base64 text.
func sb64(data []byte) string {
	out := new(strings.Builder)
	w := base64.NewEncoder(base64.StdEncoding, out)
	w.Write(data)
	// Close flushes the trailing partial quantum together with its padding.
	w.Close()
	return out.String()
}
// b64s decodes s as standard (padded) base64. Newline characters in s are
// ignored by the decoder. If s is not valid base64, b64s returns nil rather
// than a silently truncated partial decode.
func b64s(s string) []byte {
	var buf bytes.Buffer
	dec := base64.NewDecoder(base64.StdEncoding, strings.NewReader(s))
	if _, err := io.Copy(&buf, dec); err != nil {
		// Malformed input: a partial result would only mislead callers.
		return nil
	}
	return buf.Bytes()
}
50func sb64sha256(content []byte) string {
51 h := sha256.New()
52 h.Write(content)
53 return sb64(h.Sum(nil))
54}
55
56func zig(keyname string, key *rsa.PrivateKey, req *http.Request, content []byte) {
57 headers := []string{"(request-target)", "date", "host", "content-type", "digest"}
58 var stuff []string
59 for _, h := range headers {
60 var s string
61 switch h {
62 case "(request-target)":
63 s = strings.ToLower(req.Method) + " " + req.URL.RequestURI()
64 case "date":
65 s = req.Header.Get(h)
66 if s == "" {
67 s = time.Now().UTC().Format(http.TimeFormat)
68 req.Header.Set(h, s)
69 }
70 case "host":
71 s = req.Header.Get(h)
72 if s == "" {
73 s = req.URL.Hostname()
74 req.Header.Set(h, s)
75 }
76 case "content-type":
77 s = req.Header.Get(h)
78 case "digest":
79 s = req.Header.Get(h)
80 if s == "" {
81 s = "SHA-256=" + sb64sha256(content)
82 req.Header.Set(h, s)
83 }
84 }
85 stuff = append(stuff, h+": "+s)
86 }
87
88 h := sha256.New()
89 h.Write([]byte(strings.Join(stuff, "\n")))
90 sig, _ := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, h.Sum(nil))
91 bsig := sb64(sig)
92
93 sighdr := fmt.Sprintf(`keyId="%s",algorithm="%s",headers="%s",signature="%s"`,
94 keyname, "rsa-sha256", strings.Join(headers, " "), bsig)
95 req.Header.Set("Signature", sighdr)
96}
97
98var re_sighdrval = regexp.MustCompile(`(.*)="(.*)"`)
99
100func zag(req *http.Request, content []byte) (string, error) {
101 sighdr := req.Header.Get("Signature")
102
103 var keyname, algo, heads, bsig string
104 for _, v := range strings.Split(sighdr, ",") {
105 m := re_sighdrval.FindStringSubmatch(v)
106 if len(m) != 3 {
107 return "", fmt.Errorf("bad scan: %s from %s\n", v, sighdr)
108 }
109 switch m[1] {
110 case "keyId":
111 keyname = m[2]
112 case "algorithm":
113 algo = m[2]
114 case "headers":
115 heads = m[2]
116 case "signature":
117 bsig = m[2]
118 default:
119 return "", fmt.Errorf("bad sig val: %s", m[1])
120 }
121 }
122 if keyname == "" || algo == "" || heads == "" || bsig == "" {
123 return "", fmt.Errorf("missing a sig value")
124 }
125
126 key := zaggy(keyname)
127 if key == nil {
128 return keyname, fmt.Errorf("no key for %s", keyname)
129 }
130 headers := strings.Split(heads, " ")
131 var stuff []string
132 for _, h := range headers {
133 var s string
134 switch h {
135 case "(request-target)":
136 s = strings.ToLower(req.Method) + " " + req.URL.RequestURI()
137 case "host":
138 s = req.Host
139 if s != serverName {
140 log.Printf("caution: servername host header mismatch")
141 }
142 case "digest":
143 s = req.Header.Get(h)
144 expv := "SHA-256=" + sb64sha256(content)
145 if s != expv {
146 return "", fmt.Errorf("digest header '%s' did not match content", s)
147 }
148 default:
149 s = req.Header.Get(h)
150 }
151 stuff = append(stuff, h+": "+s)
152 }
153
154 h := sha256.New()
155 h.Write([]byte(strings.Join(stuff, "\n")))
156 sig := b64s(bsig)
157 err := rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), sig)
158 if err != nil {
159 return keyname, err
160 }
161 return keyname, nil
162}
163
// pez parses the first PEM block in s as an RSA key.
//
// Accepted block types:
//   - "PUBLIC KEY"      PKIX SubjectPublicKeyInfo (must hold an RSA key)
//   - "RSA PUBLIC KEY"  PKCS#1 public key
//   - "RSA PRIVATE KEY" PKCS#1 private key
//   - "PRIVATE KEY"     PKCS#8 private key (must hold an RSA key)
//
// For private keys both pri and pub are returned, with pub pointing at
// pri's public half. For public keys only pub is set. On any error both keys
// are nil. A key that parses but is not RSA is reported as an error instead
// of being returned as a nil key with a nil error.
func pez(s string) (pri *rsa.PrivateKey, pub *rsa.PublicKey, err error) {
	block, _ := pem.Decode([]byte(s))
	if block == nil {
		return nil, nil, fmt.Errorf("no pem data")
	}
	switch block.Type {
	case "PUBLIC KEY":
		k, perr := x509.ParsePKIXPublicKey(block.Bytes)
		if perr != nil {
			return nil, nil, perr
		}
		rk, ok := k.(*rsa.PublicKey)
		if !ok {
			return nil, nil, fmt.Errorf("not an RSA public key: %T", k)
		}
		return nil, rk, nil
	case "RSA PUBLIC KEY":
		rk, perr := x509.ParsePKCS1PublicKey(block.Bytes)
		if perr != nil {
			return nil, nil, perr
		}
		return nil, rk, nil
	case "RSA PRIVATE KEY":
		pk, perr := x509.ParsePKCS1PrivateKey(block.Bytes)
		if perr != nil {
			return nil, nil, perr
		}
		return pk, &pk.PublicKey, nil
	case "PRIVATE KEY":
		k, perr := x509.ParsePKCS8PrivateKey(block.Bytes)
		if perr != nil {
			return nil, nil, perr
		}
		pk, ok := k.(*rsa.PrivateKey)
		if !ok {
			return nil, nil, fmt.Errorf("not an RSA private key: %T", k)
		}
		return pk, &pk.PublicKey, nil
	default:
		return nil, nil, fmt.Errorf("unknown key type")
	}
}
189
// zem encodes an RSA key as PEM text.
//
// A *rsa.PrivateKey is written as a PKCS#1 "RSA PRIVATE KEY" block and a
// *rsa.PublicKey as a PKIX "PUBLIC KEY" block. Any other value, including a
// nil interface or a nil key pointer, yields an error naming the offending
// type.
func zem(i interface{}) (string, error) {
	var b pem.Block
	var err error
	switch k := i.(type) {
	case *rsa.PrivateKey:
		// x509 marshalling dereferences the key and would panic on nil.
		if k == nil {
			return "", fmt.Errorf("nil private key")
		}
		b.Type = "RSA PRIVATE KEY"
		b.Bytes = x509.MarshalPKCS1PrivateKey(k)
	case *rsa.PublicKey:
		if k == nil {
			return "", fmt.Errorf("nil public key")
		}
		b.Type = "PUBLIC KEY"
		b.Bytes, err = x509.MarshalPKIXPublicKey(k)
	default:
		// %T reports the dynamic type; %s on an arbitrary value produces noise.
		err = fmt.Errorf("unknown key type: %T", i)
	}
	if err != nil {
		return "", err
	}
	return string(pem.EncodeToMemory(&b)), nil
}