util.go (view raw)
1//
2// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
3//
4// Permission to use, copy, modify, and distribute this software for any
5// purpose with or without fee is hereby granted, provided that the above
6// copyright notice and this permission notice appear in all copies.
7//
8// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
16package main
17
18/*
19#include <termios.h>
20
// termecho switches terminal echo on (nonzero) or off (zero) for the
// terminal referred to by file descriptor 1.
//
// Only // comments are used here: this C code lives inside a Go block
// comment (the cgo preamble), which a C-style block comment would terminate.
void
termecho(int on)
{
	struct termios t;

	// If the attributes cannot be read (e.g. fd 1 is not a terminal), t is
	// uninitialized; writing it back would apply garbage settings.
	if (tcgetattr(1, &t) == -1)
		return;
	if (on)
		t.c_lflag |= ECHO;
	else
		t.c_lflag &= ~ECHO;
	tcsetattr(1, TCSADRAIN, &t);
}
32*/
33import "C"
34
35import (
36 "bufio"
37 "crypto/rand"
38 "crypto/rsa"
39 "crypto/sha512"
40 "database/sql"
41 "fmt"
42 "io/ioutil"
43 "log"
44 "net"
45 "os"
46 "os/signal"
47 "strings"
48
49 "golang.org/x/crypto/bcrypt"
50 _ "humungus.tedunangst.com/r/go-sqlite3"
51)
52
// savedstyleparams maps a file path to a previously computed style
// parameter. getstyleparam consults this map but does not add to it.
var savedstyleparams = make(map[string]string)

// getstyleparam returns a "?v=..." query string for file. A value already
// present in savedstyleparams is returned as is; otherwise the string is
// built from the first 8 bytes (16 hex digits) of the SHA-512 digest of the
// file's contents. If the file cannot be read, the empty string is returned.
func getstyleparam(file string) string {
	if cached, found := savedstyleparams[file]; found {
		return cached
	}
	contents, err := ioutil.ReadFile(file)
	if err != nil {
		return ""
	}
	digest := sha512.Sum512(contents)
	// The .8 precision limits the output to the first 8 bytes of the digest.
	return fmt.Sprintf("?v=%.8x", digest[:])
}
68
var (
	// dbtimeformat is a time layout string (reference-time form);
	// presumably used for timestamps stored in the database - confirm
	// against callers.
	dbtimeformat = "2006-01-02 15:04:05"

	// alreadyopendb holds the handle opened by opendatabase so later calls
	// return the same *sql.DB.
	alreadyopendb *sql.DB

	// dbname is the path of the sqlite database file.
	dbname = "honk.db"

	// stmtConfig is the prepared config lookup set up by opendatabase and
	// used by getconfig.
	stmtConfig *sql.Stmt

	// myVersion is the value initdb stores under the "dbversion" config key.
	myVersion = 8
)
75
76func initdb() {
77 schema, err := ioutil.ReadFile("schema.sql")
78 if err != nil {
79 log.Fatal(err)
80 }
81 _, err = os.Stat(dbname)
82 if err == nil {
83 log.Fatalf("%s already exists", dbname)
84 }
85 db, err := sql.Open("sqlite3", dbname)
86 if err != nil {
87 log.Fatal(err)
88 }
89 defer func() {
90 os.Remove(dbname)
91 os.Exit(1)
92 }()
93 c := make(chan os.Signal)
94 signal.Notify(c, os.Interrupt)
95 go func() {
96 <-c
97 C.termecho(1)
98 fmt.Printf("\n")
99 os.Remove(dbname)
100 os.Exit(1)
101 }()
102
103 for _, line := range strings.Split(string(schema), ";") {
104 _, err = db.Exec(line)
105 if err != nil {
106 log.Print(err)
107 return
108 }
109 }
110 defer db.Close()
111 r := bufio.NewReader(os.Stdin)
112
113 err = createuser(db, r)
114 if err != nil {
115 log.Print(err)
116 return
117 }
118
119 fmt.Printf("listen address: ")
120 addr, err := r.ReadString('\n')
121 if err != nil {
122 log.Print(err)
123 return
124 }
125 addr = addr[:len(addr)-1]
126 if len(addr) < 1 {
127 log.Print("that's way too short")
128 return
129 }
130 _, err = db.Exec("insert into config (key, value) values (?, ?)", "listenaddr", addr)
131 if err != nil {
132 log.Print(err)
133 return
134 }
135 fmt.Printf("server name: ")
136 addr, err = r.ReadString('\n')
137 if err != nil {
138 log.Print(err)
139 return
140 }
141 addr = addr[:len(addr)-1]
142 if len(addr) < 1 {
143 log.Print("that's way too short")
144 return
145 }
146 _, err = db.Exec("insert into config (key, value) values (?, ?)", "servername", addr)
147 if err != nil {
148 log.Print(err)
149 return
150 }
151 var randbytes [16]byte
152 rand.Read(randbytes[:])
153 key := fmt.Sprintf("%x", randbytes)
154 _, err = db.Exec("insert into config (key, value) values (?, ?)", "csrfkey", key)
155 if err != nil {
156 log.Print(err)
157 return
158 }
159 _, err = db.Exec("insert into config (key, value) values (?, ?)", "dbversion", myVersion)
160 if err != nil {
161 log.Print(err)
162 return
163 }
164 prepareStatements(db)
165 db.Close()
166 fmt.Printf("done.\n")
167 os.Exit(0)
168}
169
170func adduser() {
171 db := opendatabase()
172 defer func() {
173 os.Exit(1)
174 }()
175 c := make(chan os.Signal)
176 signal.Notify(c, os.Interrupt)
177 go func() {
178 <-c
179 C.termecho(1)
180 fmt.Printf("\n")
181 os.Exit(1)
182 }()
183
184 r := bufio.NewReader(os.Stdin)
185
186 err := createuser(db, r)
187 if err != nil {
188 log.Print(err)
189 return
190 }
191
192 db.Close()
193 os.Exit(0)
194}
195
196func createuser(db *sql.DB, r *bufio.Reader) error {
197 fmt.Printf("username: ")
198 name, err := r.ReadString('\n')
199 if err != nil {
200 return err
201 }
202 name = name[:len(name)-1]
203 if len(name) < 1 {
204 return fmt.Errorf("that's way too short")
205 }
206 C.termecho(0)
207 fmt.Printf("password: ")
208 pass, err := r.ReadString('\n')
209 C.termecho(1)
210 fmt.Printf("\n")
211 if err != nil {
212 return err
213 }
214 pass = pass[:len(pass)-1]
215 if len(pass) < 6 {
216 return fmt.Errorf("that's way too short")
217 }
218 hash, err := bcrypt.GenerateFromPassword([]byte(pass), 12)
219 if err != nil {
220 return err
221 }
222 k, err := rsa.GenerateKey(rand.Reader, 2048)
223 if err != nil {
224 return err
225 }
226 pubkey, err := zem(&k.PublicKey)
227 if err != nil {
228 return err
229 }
230 seckey, err := zem(k)
231 if err != nil {
232 return err
233 }
234 _, err = db.Exec("insert into users (username, displayname, about, hash, pubkey, seckey) values (?, ?, ?, ?, ?, ?)", name, name, "what about me?", hash, pubkey, seckey)
235 if err != nil {
236 return err
237 }
238 return nil
239}
240
241func opendatabase() *sql.DB {
242 if alreadyopendb != nil {
243 return alreadyopendb
244 }
245 var err error
246 _, err = os.Stat(dbname)
247 if err != nil {
248 log.Fatalf("unable to open database: %s", err)
249 }
250 db, err := sql.Open("sqlite3", dbname)
251 if err != nil {
252 log.Fatalf("unable to open database: %s", err)
253 }
254 stmtConfig, err = db.Prepare("select value from config where key = ?")
255 if err != nil {
256 log.Fatal(err)
257 }
258 alreadyopendb = db
259 return db
260}
261
262func getconfig(key string, value interface{}) error {
263 row := stmtConfig.QueryRow(key)
264 err := row.Scan(value)
265 if err == sql.ErrNoRows {
266 err = nil
267 }
268 return err
269}
270
271func openListener() (net.Listener, error) {
272 var listenAddr string
273 err := getconfig("listenaddr", &listenAddr)
274 if err != nil {
275 return nil, err
276 }
277 if listenAddr == "" {
278 return nil, fmt.Errorf("must have listenaddr")
279 }
280 proto := "tcp"
281 if listenAddr[0] == '/' {
282 proto = "unix"
283 err := os.Remove(listenAddr)
284 if err != nil && !os.IsNotExist(err) {
285 log.Printf("unable to unlink socket: %s", err)
286 }
287 }
288 listener, err := net.Listen(proto, listenAddr)
289 if err != nil {
290 return nil, err
291 }
292 if proto == "unix" {
293 os.Chmod(listenAddr, 0777)
294 }
295 return listener, nil
296}