all repos — honk @ 0fefd080a1f248b0437c4228d0f0a3084338ae61

my fork of honk

util.go (view raw)

  1//
  2// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
  3//
  4// Permission to use, copy, modify, and distribute this software for any
  5// purpose with or without fee is hereby granted, provided that the above
  6// copyright notice and this permission notice appear in all copies.
  7//
  8// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  9// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 10// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 11// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 12// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 13// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 14// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 15
 16package main
 17
/*
#include <termios.h>

// termecho turns terminal echo on (on != 0) or off (on == 0) by setting or
// clearing ECHO in the local-mode flags. The return values of tcgetattr and
// tcsetattr are not checked, so failures (e.g. not a tty) are ignored.
// NOTE(review): this operates on fd 1 (stdout), not fd 0 (stdin); it assumes
// stdout is the same terminal the password is typed on — confirm.
void
termecho(int on)
{
	struct termios t;
	tcgetattr(1, &t);
	if (on)
		t.c_lflag |= ECHO;
	else
		t.c_lflag &= ~ECHO;
	// TCSADRAIN: apply the change after pending output has been transmitted.
	tcsetattr(1, TCSADRAIN, &t);
}
*/
import "C"
 34
 35import (
 36	"bufio"
 37	"crypto/rand"
 38	"crypto/rsa"
 39	"crypto/sha512"
 40	"database/sql"
 41	"fmt"
 42	"io/ioutil"
 43	"log"
 44	"net"
 45	"os"
 46	"os/signal"
 47	"strings"
 48
 49	"golang.org/x/crypto/bcrypt"
 50	_ "humungus.tedunangst.com/r/go-sqlite3"
 51	"humungus.tedunangst.com/r/webs/httpsig"
 52	"humungus.tedunangst.com/r/webs/login"
 53)
 54
 55var savedassetparams = make(map[string]string)
 56
 57func getassetparam(file string) string {
 58	if p, ok := savedassetparams[file]; ok {
 59		return p
 60	}
 61	data, err := ioutil.ReadFile(file)
 62	if err != nil {
 63		return ""
 64	}
 65	hasher := sha512.New()
 66	hasher.Write(data)
 67
 68	return fmt.Sprintf("?v=%.8x", hasher.Sum(nil))
 69}
 70
 71var dbtimeformat = "2006-01-02 15:04:05"
 72
 73var alreadyopendb *sql.DB
 74var dbname = "honk.db"
 75var blobdbname = "blob.db"
 76var stmtConfig *sql.Stmt
 77
 78func initdb() {
 79	schema, err := ioutil.ReadFile("schema.sql")
 80	if err != nil {
 81		log.Fatal(err)
 82	}
 83	_, err = os.Stat(dbname)
 84	if err == nil {
 85		log.Fatalf("%s already exists", dbname)
 86	}
 87	db, err := sql.Open("sqlite3", dbname)
 88	if err != nil {
 89		log.Fatal(err)
 90	}
 91	defer func() {
 92		os.Remove(dbname)
 93		os.Exit(1)
 94	}()
 95	c := make(chan os.Signal)
 96	signal.Notify(c, os.Interrupt)
 97	go func() {
 98		<-c
 99		C.termecho(1)
100		fmt.Printf("\n")
101		os.Remove(dbname)
102		os.Exit(1)
103	}()
104
105	for _, line := range strings.Split(string(schema), ";") {
106		_, err = db.Exec(line)
107		if err != nil {
108			log.Print(err)
109			return
110		}
111	}
112	defer db.Close()
113	r := bufio.NewReader(os.Stdin)
114
115	err = createuser(db, r)
116	if err != nil {
117		log.Print(err)
118		return
119	}
120
121	fmt.Printf("listen address: ")
122	addr, err := r.ReadString('\n')
123	if err != nil {
124		log.Print(err)
125		return
126	}
127	addr = addr[:len(addr)-1]
128	if len(addr) < 1 {
129		log.Print("that's way too short")
130		return
131	}
132	_, err = db.Exec("insert into config (key, value) values (?, ?)", "listenaddr", addr)
133	if err != nil {
134		log.Print(err)
135		return
136	}
137	fmt.Printf("server name: ")
138	addr, err = r.ReadString('\n')
139	if err != nil {
140		log.Print(err)
141		return
142	}
143	addr = addr[:len(addr)-1]
144	if len(addr) < 1 {
145		log.Print("that's way too short")
146		return
147	}
148	_, err = db.Exec("insert into config (key, value) values (?, ?)", "servername", addr)
149	if err != nil {
150		log.Print(err)
151		return
152	}
153	var randbytes [16]byte
154	rand.Read(randbytes[:])
155	key := fmt.Sprintf("%x", randbytes)
156	_, err = db.Exec("insert into config (key, value) values (?, ?)", "csrfkey", key)
157	if err != nil {
158		log.Print(err)
159		return
160	}
161	_, err = db.Exec("insert into config (key, value) values (?, ?)", "dbversion", myVersion)
162	if err != nil {
163		log.Print(err)
164		return
165	}
166
167	initblobdb()
168
169	prepareStatements(db)
170	db.Close()
171	fmt.Printf("done.\n")
172	os.Exit(0)
173}
174
175func initblobdb() {
176	_, err := os.Stat(blobdbname)
177	if err == nil {
178		log.Fatalf("%s already exists", blobdbname)
179	}
180	blobdb, err := sql.Open("sqlite3", blobdbname)
181	if err != nil {
182		log.Print(err)
183		return
184	}
185	_, err = blobdb.Exec("create table filedata (xid text, media text, content blob)")
186	if err != nil {
187		log.Print(err)
188		return
189	}
190	_, err = blobdb.Exec("create index idx_filexid on filedata(xid)")
191	if err != nil {
192		log.Print(err)
193		return
194	}
195	blobdb.Close()
196}
197
198func adduser() {
199	db := opendatabase()
200	defer func() {
201		os.Exit(1)
202	}()
203	c := make(chan os.Signal)
204	signal.Notify(c, os.Interrupt)
205	go func() {
206		<-c
207		C.termecho(1)
208		fmt.Printf("\n")
209		os.Exit(1)
210	}()
211
212	r := bufio.NewReader(os.Stdin)
213
214	err := createuser(db, r)
215	if err != nil {
216		log.Print(err)
217		return
218	}
219
220	db.Close()
221	os.Exit(0)
222}
223
224func chpass() {
225	if len(os.Args) < 3 {
226		fmt.Printf("need a username\n")
227		os.Exit(1)
228	}
229	user, err := butwhatabout(os.Args[2])
230	if err != nil {
231		log.Fatal(err)
232	}
233	defer func() {
234		os.Exit(1)
235	}()
236	c := make(chan os.Signal)
237	signal.Notify(c, os.Interrupt)
238	go func() {
239		<-c
240		C.termecho(1)
241		fmt.Printf("\n")
242		os.Exit(1)
243	}()
244
245	db := opendatabase()
246	login.Init(db)
247
248	r := bufio.NewReader(os.Stdin)
249
250	pass, err := askpassword(r)
251	if err != nil {
252		log.Print(err)
253		return
254	}
255	err = login.SetPassword(user.ID, pass)
256	if err != nil {
257		log.Print(err)
258		return
259	}
260	fmt.Printf("done\n")
261	os.Exit(0)
262}
263
264func askpassword(r *bufio.Reader) (string, error) {
265	C.termecho(0)
266	fmt.Printf("password: ")
267	pass, err := r.ReadString('\n')
268	C.termecho(1)
269	fmt.Printf("\n")
270	if err != nil {
271		return "", err
272	}
273	pass = pass[:len(pass)-1]
274	if len(pass) < 6 {
275		return "", fmt.Errorf("that's way too short")
276	}
277	return pass, nil
278}
279
280func createuser(db *sql.DB, r *bufio.Reader) error {
281	fmt.Printf("username: ")
282	name, err := r.ReadString('\n')
283	if err != nil {
284		return err
285	}
286	name = name[:len(name)-1]
287	if len(name) < 1 {
288		return fmt.Errorf("that's way too short")
289	}
290	pass, err := askpassword(r)
291	if err != nil {
292		return err
293	}
294	hash, err := bcrypt.GenerateFromPassword([]byte(pass), 12)
295	if err != nil {
296		return err
297	}
298	k, err := rsa.GenerateKey(rand.Reader, 2048)
299	if err != nil {
300		return err
301	}
302	pubkey, err := httpsig.EncodeKey(&k.PublicKey)
303	if err != nil {
304		return err
305	}
306	seckey, err := httpsig.EncodeKey(k)
307	if err != nil {
308		return err
309	}
310	_, err = db.Exec("insert into users (username, displayname, about, hash, pubkey, seckey, options) values (?, ?, ?, ?, ?, ?, ?)", name, name, "what about me?", hash, pubkey, seckey, "")
311	if err != nil {
312		return err
313	}
314	return nil
315}
316
317func opendatabase() *sql.DB {
318	if alreadyopendb != nil {
319		return alreadyopendb
320	}
321	var err error
322	_, err = os.Stat(dbname)
323	if err != nil {
324		log.Fatalf("unable to open database: %s", err)
325	}
326	db, err := sql.Open("sqlite3", dbname)
327	if err != nil {
328		log.Fatalf("unable to open database: %s", err)
329	}
330	stmtConfig, err = db.Prepare("select value from config where key = ?")
331	if err != nil {
332		log.Fatal(err)
333	}
334	alreadyopendb = db
335	return db
336}
337
338func openblobdb() *sql.DB {
339	var err error
340	_, err = os.Stat(blobdbname)
341	if err != nil {
342		log.Fatalf("unable to open database: %s", err)
343	}
344	db, err := sql.Open("sqlite3", blobdbname)
345	if err != nil {
346		log.Fatalf("unable to open database: %s", err)
347	}
348	return db
349}
350
351func getconfig(key string, value interface{}) error {
352	m, ok := value.(*map[string]bool)
353	if ok {
354		rows, err := stmtConfig.Query(key)
355		if err != nil {
356			return err
357		}
358		defer rows.Close()
359		for rows.Next() {
360			var s string
361			err = rows.Scan(&s)
362			if err != nil {
363				return err
364			}
365			(*m)[s] = true
366		}
367		return nil
368	}
369	row := stmtConfig.QueryRow(key)
370	err := row.Scan(value)
371	if err == sql.ErrNoRows {
372		err = nil
373	}
374	return err
375}
376
377func saveconfig(key string, val interface{}) {
378	db := opendatabase()
379	db.Exec("update config set value = ? where key = ?", val, key)
380}
381
382func openListener() (net.Listener, error) {
383	var listenAddr string
384	err := getconfig("listenaddr", &listenAddr)
385	if err != nil {
386		return nil, err
387	}
388	if listenAddr == "" {
389		return nil, fmt.Errorf("must have listenaddr")
390	}
391	proto := "tcp"
392	if listenAddr[0] == '/' {
393		proto = "unix"
394		err := os.Remove(listenAddr)
395		if err != nil && !os.IsNotExist(err) {
396			log.Printf("unable to unlink socket: %s", err)
397		}
398	}
399	listener, err := net.Listen(proto, listenAddr)
400	if err != nil {
401		return nil, err
402	}
403	if proto == "unix" {
404		os.Chmod(listenAddr, 0777)
405	}
406	return listener, nil
407}