util.go (view raw)
//
// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
//
// Permission to use, copy, modify, and distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
16package main
17
18/*
19#include <termios.h>
20
21void
22termecho(int on)
23{
24 struct termios t;
25 tcgetattr(1, &t);
26 if (on)
27 t.c_lflag |= ECHO;
28 else
29 t.c_lflag &= ~ECHO;
30 tcsetattr(1, TCSADRAIN, &t);
31}
32*/
33import "C"
34
35import (
36 "bufio"
37 "crypto/rand"
38 "crypto/sha512"
39 "database/sql"
40 "fmt"
41 "io/ioutil"
42 "log"
43 "net"
44 "os"
45 "os/signal"
46 "strings"
47
48 "golang.org/x/crypto/bcrypt"
49 _ "humungus.tedunangst.com/r/go-sqlite3"
50)
51
// savedstyleparams caches style parameters by file path. getstyleparam only
// reads it; entries must be stored by other code.
var savedstyleparams = make(map[string]string)

// getstyleparam returns a cache-busting query suffix for the stylesheet at
// path file, of the form "?v=" followed by the hex encoding of the first 8
// bytes of the SHA-512 digest of the file's contents.
//
// A value already present in savedstyleparams is returned without touching
// the filesystem. If the file cannot be read, the empty string is returned.
func getstyleparam(file string) string {
	if cached, found := savedstyleparams[file]; found {
		return cached
	}
	contents, err := ioutil.ReadFile(file)
	if err != nil {
		return ""
	}
	digest := sha512.Sum512(contents)
	// %.8x on a byte slice formats only the first 8 input bytes.
	return fmt.Sprintf("?v=%.8x", digest[:])
}
67
var (
	// dbtimeformat is a Go time layout ("2006-01-02 15:04:05"); presumably
	// the textual timestamp format used in database columns — confirm
	// against its users elsewhere in the package.
	dbtimeformat = "2006-01-02 15:04:05"

	// alreadyopendb holds the handle cached by opendatabase (nil until then).
	alreadyopendb *sql.DB
	// dbname is the SQLite file that initdb creates and opendatabase opens.
	dbname = "honk.db"
	// stmtConfig is the config-table lookup prepared by opendatabase and
	// used by getconfig.
	stmtConfig *sql.Stmt
)
73
74func initdb() {
75 schema, err := ioutil.ReadFile("schema.sql")
76 if err != nil {
77 log.Fatal(err)
78 }
79 _, err = os.Stat(dbname)
80 if err == nil {
81 log.Fatalf("%s already exists", dbname)
82 }
83 db, err := sql.Open("sqlite3", dbname)
84 if err != nil {
85 log.Fatal(err)
86 }
87 defer func() {
88 os.Remove(dbname)
89 os.Exit(1)
90 }()
91 c := make(chan os.Signal)
92 signal.Notify(c, os.Interrupt)
93 go func() {
94 <-c
95 C.termecho(1)
96 fmt.Printf("\n")
97 os.Remove(dbname)
98 os.Exit(1)
99 }()
100
101 for _, line := range strings.Split(string(schema), ";") {
102 _, err = db.Exec(line)
103 if err != nil {
104 log.Print(err)
105 return
106 }
107 }
108 defer db.Close()
109 r := bufio.NewReader(os.Stdin)
110 fmt.Printf("username: ")
111 name, err := r.ReadString('\n')
112 if err != nil {
113 log.Print(err)
114 return
115 }
116 name = name[:len(name)-1]
117 if len(name) < 1 {
118 log.Print("that's way too short")
119 return
120 }
121 C.termecho(0)
122 fmt.Printf("password: ")
123 pass, err := r.ReadString('\n')
124 C.termecho(1)
125 fmt.Printf("\n")
126 if err != nil {
127 log.Print(err)
128 return
129 }
130 pass = pass[:len(pass)-1]
131 if len(pass) < 6 {
132 log.Print("that's way too short")
133 return
134 }
135 hash, err := bcrypt.GenerateFromPassword([]byte(pass), 12)
136 if err != nil {
137 log.Print(err)
138 return
139 }
140 _, err = db.Exec("insert into users (username, hash) values (?, ?)", name, hash)
141 if err != nil {
142 log.Print(err)
143 return
144 }
145 fmt.Printf("listen address: ")
146 addr, err := r.ReadString('\n')
147 if err != nil {
148 log.Print(err)
149 return
150 }
151 addr = addr[:len(addr)-1]
152 if len(addr) < 1 {
153 log.Print("that's way too short")
154 return
155 }
156 _, err = db.Exec("insert into config (key, value) values (?, ?)", "listenaddr", addr)
157 if err != nil {
158 log.Print(err)
159 return
160 }
161 fmt.Printf("server name: ")
162 addr, err = r.ReadString('\n')
163 if err != nil {
164 log.Print(err)
165 return
166 }
167 addr = addr[:len(addr)-1]
168 if len(addr) < 1 {
169 log.Print("that's way too short")
170 return
171 }
172 _, err = db.Exec("insert into config (key, value) values (?, ?)", "servername", addr)
173 if err != nil {
174 log.Print(err)
175 return
176 }
177 var randbytes [16]byte
178 rand.Read(randbytes[:])
179 key := fmt.Sprintf("%x", randbytes)
180 _, err = db.Exec("insert into config (key, value) values (?, ?)", "csrfkey", key)
181 if err != nil {
182 log.Print(err)
183 return
184 }
185 err = finishusersetup()
186 if err != nil {
187 log.Print(err)
188 return
189 }
190 prepareStatements(db)
191 db.Close()
192 fmt.Printf("done.\n")
193 os.Exit(0)
194}
195
196func opendatabase() *sql.DB {
197 if alreadyopendb != nil {
198 return alreadyopendb
199 }
200 var err error
201 _, err = os.Stat(dbname)
202 if err != nil {
203 log.Fatalf("unable to open database: %s", err)
204 }
205 db, err := sql.Open("sqlite3", dbname)
206 if err != nil {
207 log.Fatalf("unable to open database: %s", err)
208 }
209 stmtConfig, err = db.Prepare("select value from config where key = ?")
210 if err != nil {
211 log.Fatal(err)
212 }
213 alreadyopendb = db
214 return db
215}
216
217func getconfig(key string, value interface{}) error {
218 row := stmtConfig.QueryRow(key)
219 err := row.Scan(value)
220 if err == sql.ErrNoRows {
221 err = nil
222 }
223 return err
224}
225
226func openListener() (net.Listener, error) {
227 var listenAddr string
228 err := getconfig("listenaddr", &listenAddr)
229 if err != nil {
230 return nil, err
231 }
232 if listenAddr == "" {
233 return nil, fmt.Errorf("must have listenaddr")
234 }
235 proto := "tcp"
236 if listenAddr[0] == '/' {
237 proto = "unix"
238 err := os.Remove(listenAddr)
239 if err != nil && !os.IsNotExist(err) {
240 log.Printf("unable to unlink socket: %s", err)
241 }
242 }
243 listener, err := net.Listen(proto, listenAddr)
244 if err != nil {
245 return nil, err
246 }
247 if proto == "unix" {
248 os.Chmod(listenAddr, 0777)
249 }
250 return listener, nil
251}