zig.go (view raw)
1//
2// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
3//
4// Permission to use, copy, modify, and distribute this software for any
5// purpose with or without fee is hereby granted, provided that the above
6// copyright notice and this permission notice appear in all copies.
7//
8// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
16package main
17
18import (
19 "bytes"
20 "crypto"
21 "crypto/rand"
22 "crypto/rsa"
23 "crypto/sha256"
24 "crypto/x509"
25 "encoding/base64"
26 "encoding/pem"
27 "fmt"
28 "io"
29 "log"
30 "net/http"
31 "regexp"
32 "strings"
33 "time"
34)
35
// sb64 returns data encoded as standard, padded base64 text.
func sb64(data []byte) string {
	return base64.StdEncoding.EncodeToString(data)
}
// b64s decodes standard base64 text and returns the raw bytes.
//
// Decoding is best effort: the error from the copy is dropped and the
// function returns whatever reached the buffer.
func b64s(s string) []byte {
	decoded := new(bytes.Buffer)
	src := base64.NewDecoder(base64.StdEncoding, strings.NewReader(s))
	// The signature has no error result, so a decode failure is not reported.
	_, _ = io.Copy(decoded, src)
	return decoded.Bytes()
}
50func sb64sha256(content []byte) string {
51 h := sha256.New()
52 h.Write(content)
53 return sb64(h.Sum(nil))
54}
55
56func zig(keyname string, key *rsa.PrivateKey, req *http.Request, content []byte) {
57 headers := []string{"(request-target)", "date", "host", "content-type", "digest"}
58 var stuff []string
59 for _, h := range headers {
60 var s string
61 switch h {
62 case "(request-target)":
63 s = strings.ToLower(req.Method) + " " + req.URL.RequestURI()
64 case "date":
65 s = req.Header.Get(h)
66 if s == "" {
67 s = time.Now().UTC().Format(http.TimeFormat)
68 req.Header.Set(h, s)
69 }
70 case "host":
71 s = req.Header.Get(h)
72 if s == "" {
73 s = req.URL.Hostname()
74 req.Header.Set(h, s)
75 }
76 case "content-type":
77 s = req.Header.Get(h)
78 case "digest":
79 s = req.Header.Get(h)
80 if s == "" {
81 s = "SHA-256=" + sb64sha256(content)
82 req.Header.Set(h, s)
83 }
84 }
85 stuff = append(stuff, h+": "+s)
86 }
87
88 h := sha256.New()
89 h.Write([]byte(strings.Join(stuff, "\n")))
90 sig, _ := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, h.Sum(nil))
91 bsig := sb64(sig)
92
93 sighdr := fmt.Sprintf(`keyId="%s",algorithm="%s",headers="%s",signature="%s"`,
94 keyname, "rsa-sha256", strings.Join(headers, " "), bsig)
95 req.Header.Set("Signature", sighdr)
96}
97
98var re_sighdrval = regexp.MustCompile(`(.*)="(.*)"`)
99
100func zag(req *http.Request, content []byte) (string, error) {
101 sighdr := req.Header.Get("Signature")
102
103 var keyname, algo, heads, bsig string
104 for _, v := range strings.Split(sighdr, ",") {
105 m := re_sighdrval.FindStringSubmatch(v)
106 if len(m) != 3 {
107 return "", fmt.Errorf("bad scan: %s from %s\n", v, sighdr)
108 }
109 switch m[1] {
110 case "keyId":
111 keyname = m[2]
112 case "algorithm":
113 algo = m[2]
114 case "headers":
115 heads = m[2]
116 case "signature":
117 bsig = m[2]
118 default:
119 return "", fmt.Errorf("bad sig val: %s", m[1])
120 }
121 }
122 if keyname == "" || algo == "" || heads == "" || bsig == "" {
123 return "", fmt.Errorf("missing a sig value")
124 }
125
126 key := zaggy(keyname)
127 if key == nil {
128 return keyname, fmt.Errorf("no key for %s", keyname)
129 }
130 headers := strings.Split(heads, " ")
131 var stuff []string
132 for _, h := range headers {
133 var s string
134 switch h {
135 case "(request-target)":
136 s = strings.ToLower(req.Method) + " " + req.URL.RequestURI()
137 case "host":
138 s = req.Host
139 if s != serverName {
140 log.Printf("caution: servername host header mismatch")
141 }
142 default:
143 s = req.Header.Get(h)
144 }
145 stuff = append(stuff, h+": "+s)
146 }
147
148 h := sha256.New()
149 h.Write([]byte(strings.Join(stuff, "\n")))
150 sig := b64s(bsig)
151 err := rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), sig)
152 if err != nil {
153 return keyname, err
154 }
155 return keyname, nil
156}
157
// pez parses the first PEM block in s as an RSA key.
//
// Supported block types:
//   - "PUBLIC KEY": PKIX public key, which must be an RSA key
//   - "RSA PUBLIC KEY": PKCS #1 public key
//   - "RSA PRIVATE KEY": PKCS #1 private key; pub is set to its public half
//
// pri is nil for the public key forms. err is non-nil when s contains no PEM
// data, the block type is not one of the above, the key bytes fail to parse,
// or a PKIX block holds a key that is not RSA. Both keys are nil on error.
func pez(s string) (pri *rsa.PrivateKey, pub *rsa.PublicKey, err error) {
	block, _ := pem.Decode([]byte(s))
	if block == nil {
		return nil, nil, fmt.Errorf("no pem data")
	}
	switch block.Type {
	case "PUBLIC KEY":
		var k interface{}
		k, err = x509.ParsePKIXPublicKey(block.Bytes)
		if err != nil {
			return nil, nil, err
		}
		rsapub, ok := k.(*rsa.PublicKey)
		if !ok {
			// PKIX can carry other key kinds; report them rather than
			// returning a nil key together with a nil error.
			return nil, nil, fmt.Errorf("not an rsa public key: %T", k)
		}
		return nil, rsapub, nil
	case "RSA PUBLIC KEY":
		pub, err = x509.ParsePKCS1PublicKey(block.Bytes)
		if err != nil {
			return nil, nil, err
		}
		return nil, pub, nil
	case "RSA PRIVATE KEY":
		pri, err = x509.ParsePKCS1PrivateKey(block.Bytes)
		if err != nil {
			return nil, nil, err
		}
		return pri, &pri.PublicKey, nil
	default:
		return nil, nil, fmt.Errorf("unknown key type")
	}
}
183
// zem PEM-encodes an RSA key.
//
// A *rsa.PrivateKey is written as a PKCS #1 "RSA PRIVATE KEY" block and a
// *rsa.PublicKey as a PKIX "PUBLIC KEY" block. Any other type, or a nil key
// pointer, yields an error and an empty string.
func zem(i interface{}) (string, error) {
	var b pem.Block
	switch k := i.(type) {
	case *rsa.PrivateKey:
		if k == nil {
			// A typed nil pointer cannot be marshalled; report it instead of
			// handing it to x509.
			return "", fmt.Errorf("nil private key")
		}
		b.Type = "RSA PRIVATE KEY"
		b.Bytes = x509.MarshalPKCS1PrivateKey(k)
	case *rsa.PublicKey:
		if k == nil {
			return "", fmt.Errorf("nil public key")
		}
		der, err := x509.MarshalPKIXPublicKey(k)
		if err != nil {
			return "", err
		}
		b.Type = "PUBLIC KEY"
		b.Bytes = der
	default:
		// %T names the offending type; %s on an arbitrary value produces
		// "%!s(...)" noise.
		return "", fmt.Errorf("unknown key type: %T", k)
	}
	return string(pem.EncodeToMemory(&b)), nil
}