all repos — honk @ b8252156b2ede26d315cf39e8e0ef88b4d703943

my fork of honk

util.go (view raw)

  1//
  2// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
  3//
  4// Permission to use, copy, modify, and distribute this software for any
  5// purpose with or without fee is hereby granted, provided that the above
  6// copyright notice and this permission notice appear in all copies.
  7//
  8// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  9// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 10// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 11// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 12// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 13// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 14// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 15
 16package main
 17
 18/*
 19#include <termios.h>
 20
 21void
 22termecho(int on)
 23{
 24	struct termios t;
 25	tcgetattr(1, &t);
 26	if (on)
 27		t.c_lflag |= ECHO;
 28	else
 29		t.c_lflag &= ~ECHO;
 30	tcsetattr(1, TCSADRAIN, &t);
 31}
 32*/
 33import "C"
 34
 35import (
 36	"bufio"
 37	"crypto/rand"
 38	"crypto/rsa"
 39	"crypto/sha512"
 40	"database/sql"
 41	"fmt"
 42	"io/ioutil"
 43	"log"
 44	"net"
 45	"os"
 46	"os/signal"
 47	"strings"
 48
 49	"golang.org/x/crypto/bcrypt"
 50	_ "humungus.tedunangst.com/r/go-sqlite3"
 51	"humungus.tedunangst.com/r/webs/httpsig"
 52	"humungus.tedunangst.com/r/webs/login"
 53)
 54
 55var savedassetparams = make(map[string]string)
 56
 57func getassetparam(file string) string {
 58	if p, ok := savedassetparams[file]; ok {
 59		return p
 60	}
 61	data, err := ioutil.ReadFile(file)
 62	if err != nil {
 63		return ""
 64	}
 65	hasher := sha512.New()
 66	hasher.Write(data)
 67
 68	return fmt.Sprintf("?v=%.8x", hasher.Sum(nil))
 69}
 70
 71var dbtimeformat = "2006-01-02 15:04:05"
 72
 73var alreadyopendb *sql.DB
 74var dbname = "honk.db"
 75var blobdbname = "blob.db"
 76var stmtConfig *sql.Stmt
 77var myVersion = 25
 78
 79func initdb() {
 80	schema, err := ioutil.ReadFile("schema.sql")
 81	if err != nil {
 82		log.Fatal(err)
 83	}
 84	_, err = os.Stat(dbname)
 85	if err == nil {
 86		log.Fatalf("%s already exists", dbname)
 87	}
 88	db, err := sql.Open("sqlite3", dbname)
 89	if err != nil {
 90		log.Fatal(err)
 91	}
 92	defer func() {
 93		os.Remove(dbname)
 94		os.Exit(1)
 95	}()
 96	c := make(chan os.Signal)
 97	signal.Notify(c, os.Interrupt)
 98	go func() {
 99		<-c
100		C.termecho(1)
101		fmt.Printf("\n")
102		os.Remove(dbname)
103		os.Exit(1)
104	}()
105
106	for _, line := range strings.Split(string(schema), ";") {
107		_, err = db.Exec(line)
108		if err != nil {
109			log.Print(err)
110			return
111		}
112	}
113	defer db.Close()
114	r := bufio.NewReader(os.Stdin)
115
116	err = createuser(db, r)
117	if err != nil {
118		log.Print(err)
119		return
120	}
121
122	fmt.Printf("listen address: ")
123	addr, err := r.ReadString('\n')
124	if err != nil {
125		log.Print(err)
126		return
127	}
128	addr = addr[:len(addr)-1]
129	if len(addr) < 1 {
130		log.Print("that's way too short")
131		return
132	}
133	_, err = db.Exec("insert into config (key, value) values (?, ?)", "listenaddr", addr)
134	if err != nil {
135		log.Print(err)
136		return
137	}
138	fmt.Printf("server name: ")
139	addr, err = r.ReadString('\n')
140	if err != nil {
141		log.Print(err)
142		return
143	}
144	addr = addr[:len(addr)-1]
145	if len(addr) < 1 {
146		log.Print("that's way too short")
147		return
148	}
149	_, err = db.Exec("insert into config (key, value) values (?, ?)", "servername", addr)
150	if err != nil {
151		log.Print(err)
152		return
153	}
154	var randbytes [16]byte
155	rand.Read(randbytes[:])
156	key := fmt.Sprintf("%x", randbytes)
157	_, err = db.Exec("insert into config (key, value) values (?, ?)", "csrfkey", key)
158	if err != nil {
159		log.Print(err)
160		return
161	}
162	_, err = db.Exec("insert into config (key, value) values (?, ?)", "dbversion", myVersion)
163	if err != nil {
164		log.Print(err)
165		return
166	}
167
168	initblobdb()
169
170	prepareStatements(db)
171	db.Close()
172	fmt.Printf("done.\n")
173	os.Exit(0)
174}
175
176func initblobdb() {
177	_, err := os.Stat(blobdbname)
178	if err == nil {
179		log.Fatalf("%s already exists", blobdbname)
180	}
181	blobdb, err := sql.Open("sqlite3", blobdbname)
182	if err != nil {
183		log.Print(err)
184		return
185	}
186	_, err = blobdb.Exec("create table filedata (xid text, media text, content blob)")
187	if err != nil {
188		log.Print(err)
189		return
190	}
191	_, err = blobdb.Exec("create index idx_filexid on filedata(xid)")
192	if err != nil {
193		log.Print(err)
194		return
195	}
196	blobdb.Close()
197}
198
199func adduser() {
200	db := opendatabase()
201	defer func() {
202		os.Exit(1)
203	}()
204	c := make(chan os.Signal)
205	signal.Notify(c, os.Interrupt)
206	go func() {
207		<-c
208		C.termecho(1)
209		fmt.Printf("\n")
210		os.Exit(1)
211	}()
212
213	r := bufio.NewReader(os.Stdin)
214
215	err := createuser(db, r)
216	if err != nil {
217		log.Print(err)
218		return
219	}
220
221	db.Close()
222	os.Exit(0)
223}
224
225func chpass() {
226	if len(os.Args) < 3 {
227		fmt.Printf("need a username\n")
228		os.Exit(1)
229	}
230	user, err := butwhatabout(os.Args[2])
231	if err != nil {
232		log.Fatal(err)
233	}
234	defer func() {
235		os.Exit(1)
236	}()
237	c := make(chan os.Signal)
238	signal.Notify(c, os.Interrupt)
239	go func() {
240		<-c
241		C.termecho(1)
242		fmt.Printf("\n")
243		os.Exit(1)
244	}()
245
246	db := opendatabase()
247	login.Init(db)
248
249	r := bufio.NewReader(os.Stdin)
250
251	pass, err := askpassword(r)
252	if err != nil {
253		log.Print(err)
254		return
255	}
256	err = login.SetPassword(user.ID, pass)
257	if err != nil {
258		log.Print(err)
259		return
260	}
261	fmt.Printf("done\n")
262	os.Exit(0)
263}
264
265func askpassword(r *bufio.Reader) (string, error) {
266	C.termecho(0)
267	fmt.Printf("password: ")
268	pass, err := r.ReadString('\n')
269	C.termecho(1)
270	fmt.Printf("\n")
271	if err != nil {
272		return "", err
273	}
274	pass = pass[:len(pass)-1]
275	if len(pass) < 6 {
276		return "", fmt.Errorf("that's way too short")
277	}
278	return pass, nil
279}
280
281func createuser(db *sql.DB, r *bufio.Reader) error {
282	fmt.Printf("username: ")
283	name, err := r.ReadString('\n')
284	if err != nil {
285		return err
286	}
287	name = name[:len(name)-1]
288	if len(name) < 1 {
289		return fmt.Errorf("that's way too short")
290	}
291	pass, err := askpassword(r)
292	if err != nil {
293		return err
294	}
295	hash, err := bcrypt.GenerateFromPassword([]byte(pass), 12)
296	if err != nil {
297		return err
298	}
299	k, err := rsa.GenerateKey(rand.Reader, 2048)
300	if err != nil {
301		return err
302	}
303	pubkey, err := httpsig.EncodeKey(&k.PublicKey)
304	if err != nil {
305		return err
306	}
307	seckey, err := httpsig.EncodeKey(k)
308	if err != nil {
309		return err
310	}
311	_, err = db.Exec("insert into users (username, displayname, about, hash, pubkey, seckey, options) values (?, ?, ?, ?, ?, ?, ?)", name, name, "what about me?", hash, pubkey, seckey, "")
312	if err != nil {
313		return err
314	}
315	return nil
316}
317
318func opendatabase() *sql.DB {
319	if alreadyopendb != nil {
320		return alreadyopendb
321	}
322	var err error
323	_, err = os.Stat(dbname)
324	if err != nil {
325		log.Fatalf("unable to open database: %s", err)
326	}
327	db, err := sql.Open("sqlite3", dbname)
328	if err != nil {
329		log.Fatalf("unable to open database: %s", err)
330	}
331	stmtConfig, err = db.Prepare("select value from config where key = ?")
332	if err != nil {
333		log.Fatal(err)
334	}
335	alreadyopendb = db
336	return db
337}
338
339func openblobdb() *sql.DB {
340	var err error
341	_, err = os.Stat(blobdbname)
342	if err != nil {
343		log.Fatalf("unable to open database: %s", err)
344	}
345	db, err := sql.Open("sqlite3", blobdbname)
346	if err != nil {
347		log.Fatalf("unable to open database: %s", err)
348	}
349	return db
350}
351
352func getconfig(key string, value interface{}) error {
353	m, ok := value.(*map[string]bool)
354	if ok {
355		rows, err := stmtConfig.Query(key)
356		if err != nil {
357			return err
358		}
359		defer rows.Close()
360		for rows.Next() {
361			var s string
362			err = rows.Scan(&s)
363			if err != nil {
364				return err
365			}
366			(*m)[s] = true
367		}
368		return nil
369	}
370	row := stmtConfig.QueryRow(key)
371	err := row.Scan(value)
372	if err == sql.ErrNoRows {
373		err = nil
374	}
375	return err
376}
377
378func saveconfig(key string, val interface{}) {
379	db := opendatabase()
380	db.Exec("update config set value = ? where key = ?", val, key)
381}
382
383func openListener() (net.Listener, error) {
384	var listenAddr string
385	err := getconfig("listenaddr", &listenAddr)
386	if err != nil {
387		return nil, err
388	}
389	if listenAddr == "" {
390		return nil, fmt.Errorf("must have listenaddr")
391	}
392	proto := "tcp"
393	if listenAddr[0] == '/' {
394		proto = "unix"
395		err := os.Remove(listenAddr)
396		if err != nil && !os.IsNotExist(err) {
397			log.Printf("unable to unlink socket: %s", err)
398		}
399	}
400	listener, err := net.Listen(proto, listenAddr)
401	if err != nil {
402		return nil, err
403	}
404	if proto == "unix" {
405		os.Chmod(listenAddr, 0777)
406	}
407	return listener, nil
408}