util.go (view raw)
1//
2// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
3//
4// Permission to use, copy, modify, and distribute this software for any
5// purpose with or without fee is hereby granted, provided that the above
6// copyright notice and this permission notice appear in all copies.
7//
8// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
16package main
17
18import (
19 "bufio"
20 "crypto/rand"
21 "crypto/sha512"
22 "database/sql"
23 "fmt"
24 "io/ioutil"
25 "log"
26 "net"
27 "os"
28 "os/signal"
29 "strings"
30
31 "golang.org/x/crypto/bcrypt"
32 "golang.org/x/crypto/ssh/terminal"
33 _ "humungus.tedunangst.com/r/go-sqlite3"
34)
35
// savedstyleparam, when non-empty, is returned by getstyleparam as is instead
// of hashing views/style.css again. Nothing in getstyleparam assigns it;
// presumably it is filled in elsewhere when caching is wanted.
var savedstyleparam string

// getstyleparam returns a cache-busting query suffix of the form "?v=<hex>",
// where <hex> is the first 8 bytes (16 hex digits) of the SHA-512 digest of
// views/style.css. A read error is ignored and whatever data ReadFile returned
// (nothing, for a missing file) is hashed.
func getstyleparam() string {
	if cached := savedstyleparam; cached != "" {
		return cached
	}
	css, _ := ioutil.ReadFile("views/style.css") // best effort, see above
	digest := sha512.Sum512(css)
	// The precision on %x limits how many input bytes are formatted.
	return fmt.Sprintf("?v=%.8x", digest[:])
}
47
// Database location, time layout, and the shared handles that opendatabase
// fills in.
var (
	// dbtimeformat is a time layout in Go reference-time syntax; presumably
	// used for timestamps stored as text in the database.
	dbtimeformat = "2006-01-02 15:04:05"

	// alreadyopendb holds the handle opendatabase returns, once opened.
	alreadyopendb *sql.DB
	// dbname is the sqlite database file opened by initdb and opendatabase.
	dbname = "honk.db"
	// stmtConfig selects one value from the config table by key.
	stmtConfig *sql.Stmt
)
53
54func initdb() {
55 schema, err := ioutil.ReadFile("schema.sql")
56 if err != nil {
57 log.Fatal(err)
58 }
59 _, err = os.Stat(dbname)
60 if err == nil {
61 log.Fatalf("%s already exists", dbname)
62 }
63 db, err := sql.Open("sqlite3", dbname)
64 if err != nil {
65 log.Fatal(err)
66 }
67 defer func() {
68 os.Remove(dbname)
69 os.Exit(1)
70 }()
71 c := make(chan os.Signal)
72 signal.Notify(c, os.Interrupt)
73 go func() {
74 <-c
75 fmt.Printf("\n")
76 os.Remove(dbname)
77 os.Exit(1)
78 }()
79
80 for _, line := range strings.Split(string(schema), ";") {
81 _, err = db.Exec(line)
82 if err != nil {
83 log.Print(err)
84 return
85 }
86 }
87 defer db.Close()
88 r := bufio.NewReader(os.Stdin)
89 fmt.Printf("username: ")
90 name, err := r.ReadString('\n')
91 if err != nil {
92 log.Print(err)
93 return
94 }
95 name = name[:len(name)-1]
96 if len(name) < 1 {
97 log.Print("that's way too short")
98 return
99 }
100 fmt.Printf("password: ")
101 passbytes, err := terminal.ReadPassword(1)
102 fmt.Printf("\n")
103 if err != nil {
104 log.Print(err)
105 return
106 }
107 pass := string(passbytes)
108 if len(pass) < 6 {
109 log.Print("that's way too short")
110 return
111 }
112 hash, err := bcrypt.GenerateFromPassword([]byte(pass), 12)
113 if err != nil {
114 log.Print(err)
115 return
116 }
117 _, err = db.Exec("insert into users (username, hash) values (?, ?)", name, hash)
118 if err != nil {
119 log.Print(err)
120 return
121 }
122 fmt.Printf("listen address: ")
123 addr, err := r.ReadString('\n')
124 if err != nil {
125 log.Print(err)
126 return
127 }
128 addr = addr[:len(addr)-1]
129 if len(addr) < 1 {
130 log.Print("that's way too short")
131 return
132 }
133 _, err = db.Exec("insert into config (key, value) values (?, ?)", "listenaddr", addr)
134 if err != nil {
135 log.Print(err)
136 return
137 }
138 fmt.Printf("server name: ")
139 addr, err = r.ReadString('\n')
140 if err != nil {
141 log.Print(err)
142 return
143 }
144 addr = addr[:len(addr)-1]
145 if len(addr) < 1 {
146 log.Print("that's way too short")
147 return
148 }
149 _, err = db.Exec("insert into config (key, value) values (?, ?)", "servername", addr)
150 if err != nil {
151 log.Print(err)
152 return
153 }
154 var randbytes [16]byte
155 rand.Read(randbytes[:])
156 key := fmt.Sprintf("%x", randbytes)
157 _, err = db.Exec("insert into config (key, value) values (?, ?)", "csrfkey", key)
158 if err != nil {
159 log.Print(err)
160 return
161 }
162 err = finishusersetup()
163 if err != nil {
164 log.Print(err)
165 return
166 }
167 prepareStatements(db)
168 db.Close()
169 fmt.Printf("done.\n")
170 os.Exit(0)
171}
172
173func opendatabase() *sql.DB {
174 if alreadyopendb != nil {
175 return alreadyopendb
176 }
177 var err error
178 _, err = os.Stat(dbname)
179 if err != nil {
180 log.Fatalf("unable to open database: %s", err)
181 }
182 db, err := sql.Open("sqlite3", dbname)
183 if err != nil {
184 log.Fatalf("unable to open database: %s", err)
185 }
186 stmtConfig, err = db.Prepare("select value from config where key = ?")
187 if err != nil {
188 log.Fatal(err)
189 }
190 alreadyopendb = db
191 return db
192}
193
194func getconfig(key string, value interface{}) error {
195 row := stmtConfig.QueryRow(key)
196 err := row.Scan(value)
197 if err == sql.ErrNoRows {
198 err = nil
199 }
200 return err
201}
202
203func openListener() (net.Listener, error) {
204 var listenAddr string
205 err := getconfig("listenaddr", &listenAddr)
206 if err != nil {
207 return nil, err
208 }
209 if listenAddr == "" {
210 return nil, fmt.Errorf("must have listenaddr")
211 }
212 proto := "tcp"
213 if listenAddr[0] == '/' {
214 proto = "unix"
215 err := os.Remove(listenAddr)
216 if err != nil && !os.IsNotExist(err) {
217 log.Printf("unable to unlink socket: %s", err)
218 }
219 }
220 listener, err := net.Listen(proto, listenAddr)
221 if err != nil {
222 return nil, err
223 }
224 if proto == "unix" {
225 os.Chmod(listenAddr, 0777)
226 }
227 return listener, nil
228}