all repos — honk @ 0ab213ded4a5cfed0de9c435bcfe542f6c33c748

my fork of honk

util.go (view raw)

  1//
  2// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
  3//
  4// Permission to use, copy, modify, and distribute this software for any
  5// purpose with or without fee is hereby granted, provided that the above
  6// copyright notice and this permission notice appear in all copies.
  7//
  8// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  9// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 10// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 11// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 12// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 13// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 14// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 15
 16package main
 17
 18/*
 19#include <termios.h>
 20
 21void
 22termecho(int on)
 23{
 24	struct termios t;
 25	tcgetattr(1, &t);
 26	if (on)
 27		t.c_lflag |= ECHO;
 28	else
 29		t.c_lflag &= ~ECHO;
 30	tcsetattr(1, TCSADRAIN, &t);
 31}
 32*/
 33import "C"
 34
 35import (
 36	"bufio"
 37	"crypto/rand"
 38	"crypto/rsa"
 39	"crypto/sha512"
 40	"database/sql"
 41	"fmt"
 42	"io/ioutil"
 43	"log"
 44	"net"
 45	"os"
 46	"os/signal"
 47	"strings"
 48
 49	"golang.org/x/crypto/bcrypt"
 50	_ "humungus.tedunangst.com/r/go-sqlite3"
 51)
 52
 53var savedstyleparams = make(map[string]string)
 54
 55func getstyleparam(file string) string {
 56	if p, ok := savedstyleparams[file]; ok {
 57		return p
 58	}
 59	data, err := ioutil.ReadFile(file)
 60	if err != nil {
 61		return ""
 62	}
 63	hasher := sha512.New()
 64	hasher.Write(data)
 65
 66	return fmt.Sprintf("?v=%.8x", hasher.Sum(nil))
 67}
 68
 69var dbtimeformat = "2006-01-02 15:04:05"
 70
 71var alreadyopendb *sql.DB
 72var dbname = "honk.db"
 73var stmtConfig *sql.Stmt
 74var myVersion = 12
 75
 76func initdb() {
 77	schema, err := ioutil.ReadFile("schema.sql")
 78	if err != nil {
 79		log.Fatal(err)
 80	}
 81	_, err = os.Stat(dbname)
 82	if err == nil {
 83		log.Fatalf("%s already exists", dbname)
 84	}
 85	db, err := sql.Open("sqlite3", dbname)
 86	if err != nil {
 87		log.Fatal(err)
 88	}
 89	defer func() {
 90		os.Remove(dbname)
 91		os.Exit(1)
 92	}()
 93	c := make(chan os.Signal)
 94	signal.Notify(c, os.Interrupt)
 95	go func() {
 96		<-c
 97		C.termecho(1)
 98		fmt.Printf("\n")
 99		os.Remove(dbname)
100		os.Exit(1)
101	}()
102
103	for _, line := range strings.Split(string(schema), ";") {
104		_, err = db.Exec(line)
105		if err != nil {
106			log.Print(err)
107			return
108		}
109	}
110	defer db.Close()
111	r := bufio.NewReader(os.Stdin)
112
113	err = createuser(db, r)
114	if err != nil {
115		log.Print(err)
116		return
117	}
118
119	fmt.Printf("listen address: ")
120	addr, err := r.ReadString('\n')
121	if err != nil {
122		log.Print(err)
123		return
124	}
125	addr = addr[:len(addr)-1]
126	if len(addr) < 1 {
127		log.Print("that's way too short")
128		return
129	}
130	_, err = db.Exec("insert into config (key, value) values (?, ?)", "listenaddr", addr)
131	if err != nil {
132		log.Print(err)
133		return
134	}
135	fmt.Printf("server name: ")
136	addr, err = r.ReadString('\n')
137	if err != nil {
138		log.Print(err)
139		return
140	}
141	addr = addr[:len(addr)-1]
142	if len(addr) < 1 {
143		log.Print("that's way too short")
144		return
145	}
146	_, err = db.Exec("insert into config (key, value) values (?, ?)", "servername", addr)
147	if err != nil {
148		log.Print(err)
149		return
150	}
151	var randbytes [16]byte
152	rand.Read(randbytes[:])
153	key := fmt.Sprintf("%x", randbytes)
154	_, err = db.Exec("insert into config (key, value) values (?, ?)", "csrfkey", key)
155	if err != nil {
156		log.Print(err)
157		return
158	}
159	_, err = db.Exec("insert into config (key, value) values (?, ?)", "dbversion", myVersion)
160	if err != nil {
161		log.Print(err)
162		return
163	}
164	prepareStatements(db)
165	db.Close()
166	fmt.Printf("done.\n")
167	os.Exit(0)
168}
169
170func adduser() {
171	db := opendatabase()
172	defer func() {
173		os.Exit(1)
174	}()
175	c := make(chan os.Signal)
176	signal.Notify(c, os.Interrupt)
177	go func() {
178		<-c
179		C.termecho(1)
180		fmt.Printf("\n")
181		os.Exit(1)
182	}()
183
184	r := bufio.NewReader(os.Stdin)
185
186	err := createuser(db, r)
187	if err != nil {
188		log.Print(err)
189		return
190	}
191
192	db.Close()
193	os.Exit(0)
194}
195
196func createuser(db *sql.DB, r *bufio.Reader) error {
197	fmt.Printf("username: ")
198	name, err := r.ReadString('\n')
199	if err != nil {
200		return err
201	}
202	name = name[:len(name)-1]
203	if len(name) < 1 {
204		return fmt.Errorf("that's way too short")
205	}
206	C.termecho(0)
207	fmt.Printf("password: ")
208	pass, err := r.ReadString('\n')
209	C.termecho(1)
210	fmt.Printf("\n")
211	if err != nil {
212		return err
213	}
214	pass = pass[:len(pass)-1]
215	if len(pass) < 6 {
216		return fmt.Errorf("that's way too short")
217	}
218	hash, err := bcrypt.GenerateFromPassword([]byte(pass), 12)
219	if err != nil {
220		return err
221	}
222	k, err := rsa.GenerateKey(rand.Reader, 2048)
223	if err != nil {
224		return err
225	}
226	pubkey, err := zem(&k.PublicKey)
227	if err != nil {
228		return err
229	}
230	seckey, err := zem(k)
231	if err != nil {
232		return err
233	}
234	_, err = db.Exec("insert into users (username, displayname, about, hash, pubkey, seckey) values (?, ?, ?, ?, ?, ?)", name, name, "what about me?", hash, pubkey, seckey)
235	if err != nil {
236		return err
237	}
238	return nil
239}
240
241func opendatabase() *sql.DB {
242	if alreadyopendb != nil {
243		return alreadyopendb
244	}
245	var err error
246	_, err = os.Stat(dbname)
247	if err != nil {
248		log.Fatalf("unable to open database: %s", err)
249	}
250	db, err := sql.Open("sqlite3", dbname)
251	if err != nil {
252		log.Fatalf("unable to open database: %s", err)
253	}
254	stmtConfig, err = db.Prepare("select value from config where key = ?")
255	if err != nil {
256		log.Fatal(err)
257	}
258	alreadyopendb = db
259	return db
260}
261
262func getconfig(key string, value interface{}) error {
263	m, ok := value.(*map[string]bool)
264	if ok {
265		rows, err := stmtConfig.Query(key)
266		if err != nil {
267			return err
268		}
269		defer rows.Close()
270		for rows.Next() {
271			var s string
272			err = rows.Scan(&s)
273			if err != nil {
274				return err
275			}
276			(*m)[s] = true
277		}
278		return nil
279	}
280	row := stmtConfig.QueryRow(key)
281	err := row.Scan(value)
282	if err == sql.ErrNoRows {
283		err = nil
284	}
285	return err
286}
287
288func openListener() (net.Listener, error) {
289	var listenAddr string
290	err := getconfig("listenaddr", &listenAddr)
291	if err != nil {
292		return nil, err
293	}
294	if listenAddr == "" {
295		return nil, fmt.Errorf("must have listenaddr")
296	}
297	proto := "tcp"
298	if listenAddr[0] == '/' {
299		proto = "unix"
300		err := os.Remove(listenAddr)
301		if err != nil && !os.IsNotExist(err) {
302			log.Printf("unable to unlink socket: %s", err)
303		}
304	}
305	listener, err := net.Listen(proto, listenAddr)
306	if err != nil {
307		return nil, err
308	}
309	if proto == "unix" {
310		os.Chmod(listenAddr, 0777)
311	}
312	return listener, nil
313}