// util.go
//
// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
16package main
17
18/*
19#include <termios.h>
20
21void
22termecho(int on)
23{
24 struct termios t;
25 tcgetattr(1, &t);
26 if (on)
27 t.c_lflag |= ECHO;
28 else
29 t.c_lflag &= ~ECHO;
30 tcsetattr(1, TCSADRAIN, &t);
31}
32*/
33import "C"
34
35import (
36 "bufio"
37 "crypto/rand"
38 "crypto/rsa"
39 "crypto/sha512"
40 "database/sql"
41 "fmt"
42 "io/ioutil"
43 "log"
44 "net"
45 "os"
46 "os/signal"
47 "strings"
48
49 "golang.org/x/crypto/bcrypt"
50 _ "humungus.tedunangst.com/r/go-sqlite3"
51 "humungus.tedunangst.com/r/webs/httpsig"
52)
53
// savedstyleparams maps a file name to a previously computed
// getstyleparam result. getstyleparam only reads this map; nothing here
// stores into it. NOTE(review): confirm where entries are populated.
var savedstyleparams = make(map[string]string)

// getstyleparam returns "?v=" followed by 16 hex digits: the first 8 bytes
// of the SHA-512 digest of file's contents. Presumably this is appended to
// a stylesheet URL so clients refetch it when it changes.
// An entry already present in savedstyleparams is returned as-is.
// If the file cannot be read, the result is the empty string.
func getstyleparam(file string) string {
	cached, found := savedstyleparams[file]
	if found {
		return cached
	}
	contents, err := ioutil.ReadFile(file)
	if err != nil {
		return ""
	}
	digest := sha512.Sum512(contents)
	// For byte slices, the %.8x precision limits the number of input
	// bytes formatted (8 bytes -> 16 hex digits).
	return fmt.Sprintf("?v=%.8x", digest[:])
}
69
// Database-related package state.
var (
	// dbtimeformat is a Go time layout string; presumably the format used
	// for timestamps stored in the database (confirm at call sites).
	dbtimeformat = "2006-01-02 15:04:05"

	// alreadyopendb holds the handle returned by opendatabase once opened.
	alreadyopendb *sql.DB
	// dbname is the SQLite database file path, relative to the working
	// directory.
	dbname = "honk.db"
	// stmtConfig is the prepared config-value lookup used by getconfig.
	stmtConfig *sql.Stmt
	// myVersion is the schema version initdb records under "dbversion".
	myVersion = 17
)
76
77func initdb() {
78 schema, err := ioutil.ReadFile("schema.sql")
79 if err != nil {
80 log.Fatal(err)
81 }
82 _, err = os.Stat(dbname)
83 if err == nil {
84 log.Fatalf("%s already exists", dbname)
85 }
86 db, err := sql.Open("sqlite3", dbname)
87 if err != nil {
88 log.Fatal(err)
89 }
90 defer func() {
91 os.Remove(dbname)
92 os.Exit(1)
93 }()
94 c := make(chan os.Signal)
95 signal.Notify(c, os.Interrupt)
96 go func() {
97 <-c
98 C.termecho(1)
99 fmt.Printf("\n")
100 os.Remove(dbname)
101 os.Exit(1)
102 }()
103
104 for _, line := range strings.Split(string(schema), ";") {
105 _, err = db.Exec(line)
106 if err != nil {
107 log.Print(err)
108 return
109 }
110 }
111 defer db.Close()
112 r := bufio.NewReader(os.Stdin)
113
114 err = createuser(db, r)
115 if err != nil {
116 log.Print(err)
117 return
118 }
119
120 fmt.Printf("listen address: ")
121 addr, err := r.ReadString('\n')
122 if err != nil {
123 log.Print(err)
124 return
125 }
126 addr = addr[:len(addr)-1]
127 if len(addr) < 1 {
128 log.Print("that's way too short")
129 return
130 }
131 _, err = db.Exec("insert into config (key, value) values (?, ?)", "listenaddr", addr)
132 if err != nil {
133 log.Print(err)
134 return
135 }
136 fmt.Printf("server name: ")
137 addr, err = r.ReadString('\n')
138 if err != nil {
139 log.Print(err)
140 return
141 }
142 addr = addr[:len(addr)-1]
143 if len(addr) < 1 {
144 log.Print("that's way too short")
145 return
146 }
147 _, err = db.Exec("insert into config (key, value) values (?, ?)", "servername", addr)
148 if err != nil {
149 log.Print(err)
150 return
151 }
152 var randbytes [16]byte
153 rand.Read(randbytes[:])
154 key := fmt.Sprintf("%x", randbytes)
155 _, err = db.Exec("insert into config (key, value) values (?, ?)", "csrfkey", key)
156 if err != nil {
157 log.Print(err)
158 return
159 }
160 _, err = db.Exec("insert into config (key, value) values (?, ?)", "dbversion", myVersion)
161 if err != nil {
162 log.Print(err)
163 return
164 }
165 prepareStatements(db)
166 db.Close()
167 fmt.Printf("done.\n")
168 os.Exit(0)
169}
170
171func adduser() {
172 db := opendatabase()
173 defer func() {
174 os.Exit(1)
175 }()
176 c := make(chan os.Signal)
177 signal.Notify(c, os.Interrupt)
178 go func() {
179 <-c
180 C.termecho(1)
181 fmt.Printf("\n")
182 os.Exit(1)
183 }()
184
185 r := bufio.NewReader(os.Stdin)
186
187 err := createuser(db, r)
188 if err != nil {
189 log.Print(err)
190 return
191 }
192
193 db.Close()
194 os.Exit(0)
195}
196
197func createuser(db *sql.DB, r *bufio.Reader) error {
198 fmt.Printf("username: ")
199 name, err := r.ReadString('\n')
200 if err != nil {
201 return err
202 }
203 name = name[:len(name)-1]
204 if len(name) < 1 {
205 return fmt.Errorf("that's way too short")
206 }
207 C.termecho(0)
208 fmt.Printf("password: ")
209 pass, err := r.ReadString('\n')
210 C.termecho(1)
211 fmt.Printf("\n")
212 if err != nil {
213 return err
214 }
215 pass = pass[:len(pass)-1]
216 if len(pass) < 6 {
217 return fmt.Errorf("that's way too short")
218 }
219 hash, err := bcrypt.GenerateFromPassword([]byte(pass), 12)
220 if err != nil {
221 return err
222 }
223 k, err := rsa.GenerateKey(rand.Reader, 2048)
224 if err != nil {
225 return err
226 }
227 pubkey, err := httpsig.EncodeKey(&k.PublicKey)
228 if err != nil {
229 return err
230 }
231 seckey, err := httpsig.EncodeKey(k)
232 if err != nil {
233 return err
234 }
235 _, err = db.Exec("insert into users (username, displayname, about, hash, pubkey, seckey, options) values (?, ?, ?, ?, ?, ?, ?)", name, name, "what about me?", hash, pubkey, seckey, "")
236 if err != nil {
237 return err
238 }
239 return nil
240}
241
242func opendatabase() *sql.DB {
243 if alreadyopendb != nil {
244 return alreadyopendb
245 }
246 var err error
247 _, err = os.Stat(dbname)
248 if err != nil {
249 log.Fatalf("unable to open database: %s", err)
250 }
251 db, err := sql.Open("sqlite3", dbname)
252 if err != nil {
253 log.Fatalf("unable to open database: %s", err)
254 }
255 stmtConfig, err = db.Prepare("select value from config where key = ?")
256 if err != nil {
257 log.Fatal(err)
258 }
259 alreadyopendb = db
260 return db
261}
262
263func getconfig(key string, value interface{}) error {
264 m, ok := value.(*map[string]bool)
265 if ok {
266 rows, err := stmtConfig.Query(key)
267 if err != nil {
268 return err
269 }
270 defer rows.Close()
271 for rows.Next() {
272 var s string
273 err = rows.Scan(&s)
274 if err != nil {
275 return err
276 }
277 (*m)[s] = true
278 }
279 return nil
280 }
281 row := stmtConfig.QueryRow(key)
282 err := row.Scan(value)
283 if err == sql.ErrNoRows {
284 err = nil
285 }
286 return err
287}
288
289func saveconfig(key string, val interface{}) {
290 db := opendatabase()
291 db.Exec("update config set value = ? where key = ?", val, key)
292}
293
294func openListener() (net.Listener, error) {
295 var listenAddr string
296 err := getconfig("listenaddr", &listenAddr)
297 if err != nil {
298 return nil, err
299 }
300 if listenAddr == "" {
301 return nil, fmt.Errorf("must have listenaddr")
302 }
303 proto := "tcp"
304 if listenAddr[0] == '/' {
305 proto = "unix"
306 err := os.Remove(listenAddr)
307 if err != nil && !os.IsNotExist(err) {
308 log.Printf("unable to unlink socket: %s", err)
309 }
310 }
311 listener, err := net.Listen(proto, listenAddr)
312 if err != nil {
313 return nil, err
314 }
315 if proto == "unix" {
316 os.Chmod(listenAddr, 0777)
317 }
318 return listener, nil
319}