util.go (view raw)
1//
2// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
3//
4// Permission to use, copy, modify, and distribute this software for any
5// purpose with or without fee is hereby granted, provided that the above
6// copyright notice and this permission notice appear in all copies.
7//
8// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
9// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
10// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
11// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
12// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
13// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
14// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
15
16package main
17
18import (
19 "bufio"
20 "crypto/rand"
21 "crypto/sha512"
22 "database/sql"
23 "fmt"
24 "io/ioutil"
25 "log"
26 "net"
27 "os"
28 "os/signal"
29 "strings"
30
31 "golang.org/x/crypto/bcrypt"
32 _ "humungus.tedunangst.com/r/go-sqlite3"
33)
34
// savedstyleparam, when non-empty, is returned by getstyleparam as-is
// instead of hashing the stylesheet again.
// NOTE(review): nothing in this file assigns it; presumably a caller sets it
// when the value should be fixed for the life of the process.
var savedstyleparam string

// getstyleparam returns a cache-busting query string of the form
// "?v=<16 hex digits>" for views/style.css. The digits are the first
// 8 bytes of the SHA-512 of the file's contents. If savedstyleparam is
// non-empty it is returned directly.
func getstyleparam() string {
	if p := savedstyleparam; p != "" {
		return p
	}
	// A read error is ignored; the digest of empty input is used instead.
	css, _ := ioutil.ReadFile("views/style.css")
	sum := sha512.Sum512(css)
	// %.8x on a byte slice formats only the first 8 bytes.
	return fmt.Sprintf("?v=%.8x", sum[:])
}
46
var (
	// dbtimeformat is a time layout (YYYY-MM-DD HH:MM:SS); judging by its
	// name, presumably the format of timestamps kept in the database.
	dbtimeformat = "2006-01-02 15:04:05"

	// alreadyopendb holds the handle cached by opendatabase; it stays nil
	// until the first successful call.
	alreadyopendb *sql.DB
	// dbname is the sqlite database file created by initdb and opened by
	// opendatabase.
	dbname = "honk.db"
	// stmtConfig is prepared by opendatabase and used by getconfig to read
	// values from the config table.
	stmtConfig *sql.Stmt
)
52
53func initdb() {
54 schema, err := ioutil.ReadFile("schema.sql")
55 if err != nil {
56 log.Fatal(err)
57 }
58 _, err = os.Stat(dbname)
59 if err == nil {
60 log.Fatalf("%s already exists", dbname)
61 }
62 db, err := sql.Open("sqlite3", dbname)
63 if err != nil {
64 log.Fatal(err)
65 }
66 defer func() {
67 os.Remove(dbname)
68 os.Exit(1)
69 }()
70 c := make(chan os.Signal)
71 signal.Notify(c, os.Interrupt)
72 go func() {
73 <-c
74 fmt.Printf("\x1b[?12;25h\x1b[0m")
75 fmt.Printf("\n")
76 os.Remove(dbname)
77 os.Exit(1)
78 }()
79 for _, line := range strings.Split(string(schema), ";") {
80 _, err = db.Exec(line)
81 if err != nil {
82 log.Print(err)
83 return
84 }
85 }
86 defer db.Close()
87 r := bufio.NewReader(os.Stdin)
88 fmt.Printf("username: ")
89 name, err := r.ReadString('\n')
90 if err != nil {
91 log.Print(err)
92 return
93 }
94 name = name[:len(name)-1]
95 if len(name) < 1 {
96 log.Print("that's way too short")
97 return
98 }
99 fmt.Printf("password: \x1b[?25l\x1b[%d;%dm \x1b[16D", 30, 40)
100 pass, err := r.ReadString('\n')
101 fmt.Printf("\x1b[0m\x1b[?12;25h")
102 if err != nil {
103 log.Fatal(err)
104 return
105 }
106 pass = pass[:len(pass)-1]
107 if len(pass) < 6 {
108 log.Print("that's way too short")
109 return
110 }
111 hash, err := bcrypt.GenerateFromPassword([]byte(pass), 12)
112 if err != nil {
113 log.Print(err)
114 return
115 }
116 _, err = db.Exec("insert into users (username, hash) values (?, ?)", name, hash)
117 if err != nil {
118 log.Print(err)
119 return
120 }
121 fmt.Printf("listen address: ")
122 addr, err := r.ReadString('\n')
123 if err != nil {
124 log.Print(err)
125 return
126 }
127 addr = addr[:len(addr)-1]
128 if len(addr) < 1 {
129 log.Print("that's way too short")
130 return
131 }
132 _, err = db.Exec("insert into config (key, value) values (?, ?)", "listenaddr", addr)
133 if err != nil {
134 log.Print(err)
135 return
136 }
137 fmt.Printf("server name: ")
138 addr, err = r.ReadString('\n')
139 if err != nil {
140 log.Print(err)
141 return
142 }
143 addr = addr[:len(addr)-1]
144 if len(addr) < 1 {
145 log.Print("that's way too short")
146 return
147 }
148 _, err = db.Exec("insert into config (key, value) values (?, ?)", "servername", addr)
149 if err != nil {
150 log.Print(err)
151 return
152 }
153 var randbytes [16]byte
154 rand.Read(randbytes[:])
155 key := fmt.Sprintf("%x", randbytes)
156 _, err = db.Exec("insert into config (key, value) values (?, ?)", "csrfkey", key)
157 if err != nil {
158 log.Print(err)
159 return
160 }
161 err = finishusersetup()
162 if err != nil {
163 log.Print(err)
164 return
165 }
166 prepareStatements(db)
167 db.Close()
168 fmt.Printf("done.\n")
169 os.Exit(0)
170}
171
172func opendatabase() *sql.DB {
173 if alreadyopendb != nil {
174 return alreadyopendb
175 }
176 var err error
177 _, err = os.Stat(dbname)
178 if err != nil {
179 log.Fatalf("unable to open database: %s", err)
180 }
181 db, err := sql.Open("sqlite3", dbname)
182 if err != nil {
183 log.Fatalf("unable to open database: %s", err)
184 }
185 stmtConfig, err = db.Prepare("select value from config where key = ?")
186 if err != nil {
187 log.Fatal(err)
188 }
189 alreadyopendb = db
190 return db
191}
192
193func getconfig(key string, value interface{}) error {
194 row := stmtConfig.QueryRow(key)
195 err := row.Scan(value)
196 if err == sql.ErrNoRows {
197 err = nil
198 }
199 return err
200}
201
202func openListener() (net.Listener, error) {
203 var listenAddr string
204 err := getconfig("listenaddr", &listenAddr)
205 if err != nil {
206 return nil, err
207 }
208 if listenAddr == "" {
209 return nil, fmt.Errorf("must have listenaddr")
210 }
211 proto := "tcp"
212 if listenAddr[0] == '/' {
213 proto = "unix"
214 err := os.Remove(listenAddr)
215 if err != nil && !os.IsNotExist(err) {
216 log.Printf("unable to unlink socket: %s", err)
217 }
218 }
219 listener, err := net.Listen(proto, listenAddr)
220 if err != nil {
221 return nil, err
222 }
223 if proto == "unix" {
224 os.Chmod(listenAddr, 0777)
225 }
226 return listener, nil
227}