all repos — honk @ d516d9fda9cc4d5d6191184b675b69edde1cfe57

my fork of honk

util.go (view raw)

  1//
  2// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
  3//
  4// Permission to use, copy, modify, and distribute this software for any
  5// purpose with or without fee is hereby granted, provided that the above
  6// copyright notice and this permission notice appear in all copies.
  7//
  8// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  9// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 10// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 11// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 12// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 13// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 14// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 15
 16package main
 17
 18/*
 19#include <termios.h>
 20
 21void
 22termecho(int on)
 23{
 24	struct termios t;
 25	tcgetattr(1, &t);
 26	if (on)
 27		t.c_lflag |= ECHO;
 28	else
 29		t.c_lflag &= ~ECHO;
 30	tcsetattr(1, TCSADRAIN, &t);
 31}
 32*/
 33import "C"
 34
 35import (
 36	"bufio"
 37	"crypto/rand"
 38	"crypto/rsa"
 39	"crypto/sha512"
 40	"database/sql"
 41	"fmt"
 42	"io/ioutil"
 43	"log"
 44	"net"
 45	"os"
 46	"os/signal"
 47	"strings"
 48
 49	"golang.org/x/crypto/bcrypt"
 50	_ "humungus.tedunangst.com/r/go-sqlite3"
 51	"humungus.tedunangst.com/r/webs/httpsig"
 52)
 53
 54var savedassetparams = make(map[string]string)
 55
 56func getassetparam(file string) string {
 57	if p, ok := savedassetparams[file]; ok {
 58		return p
 59	}
 60	data, err := ioutil.ReadFile(file)
 61	if err != nil {
 62		return ""
 63	}
 64	hasher := sha512.New()
 65	hasher.Write(data)
 66
 67	return fmt.Sprintf("?v=%.8x", hasher.Sum(nil))
 68}
 69
 70var dbtimeformat = "2006-01-02 15:04:05"
 71
 72var alreadyopendb *sql.DB
 73var dbname = "honk.db"
 74var blobdbname = "blob.db"
 75var stmtConfig *sql.Stmt
 76var myVersion = 24
 77
 78func initdb() {
 79	schema, err := ioutil.ReadFile("schema.sql")
 80	if err != nil {
 81		log.Fatal(err)
 82	}
 83	_, err = os.Stat(dbname)
 84	if err == nil {
 85		log.Fatalf("%s already exists", dbname)
 86	}
 87	db, err := sql.Open("sqlite3", dbname)
 88	if err != nil {
 89		log.Fatal(err)
 90	}
 91	defer func() {
 92		os.Remove(dbname)
 93		os.Exit(1)
 94	}()
 95	c := make(chan os.Signal)
 96	signal.Notify(c, os.Interrupt)
 97	go func() {
 98		<-c
 99		C.termecho(1)
100		fmt.Printf("\n")
101		os.Remove(dbname)
102		os.Exit(1)
103	}()
104
105	for _, line := range strings.Split(string(schema), ";") {
106		_, err = db.Exec(line)
107		if err != nil {
108			log.Print(err)
109			return
110		}
111	}
112	defer db.Close()
113	r := bufio.NewReader(os.Stdin)
114
115	err = createuser(db, r)
116	if err != nil {
117		log.Print(err)
118		return
119	}
120
121	fmt.Printf("listen address: ")
122	addr, err := r.ReadString('\n')
123	if err != nil {
124		log.Print(err)
125		return
126	}
127	addr = addr[:len(addr)-1]
128	if len(addr) < 1 {
129		log.Print("that's way too short")
130		return
131	}
132	_, err = db.Exec("insert into config (key, value) values (?, ?)", "listenaddr", addr)
133	if err != nil {
134		log.Print(err)
135		return
136	}
137	fmt.Printf("server name: ")
138	addr, err = r.ReadString('\n')
139	if err != nil {
140		log.Print(err)
141		return
142	}
143	addr = addr[:len(addr)-1]
144	if len(addr) < 1 {
145		log.Print("that's way too short")
146		return
147	}
148	_, err = db.Exec("insert into config (key, value) values (?, ?)", "servername", addr)
149	if err != nil {
150		log.Print(err)
151		return
152	}
153	var randbytes [16]byte
154	rand.Read(randbytes[:])
155	key := fmt.Sprintf("%x", randbytes)
156	_, err = db.Exec("insert into config (key, value) values (?, ?)", "csrfkey", key)
157	if err != nil {
158		log.Print(err)
159		return
160	}
161	_, err = db.Exec("insert into config (key, value) values (?, ?)", "dbversion", myVersion)
162	if err != nil {
163		log.Print(err)
164		return
165	}
166
167	initblobdb()
168
169	prepareStatements(db)
170	db.Close()
171	fmt.Printf("done.\n")
172	os.Exit(0)
173}
174
175func initblobdb() {
176	_, err := os.Stat(blobdbname)
177	if err == nil {
178		log.Fatalf("%s already exists", blobdbname)
179	}
180	blobdb, err := sql.Open("sqlite3", blobdbname)
181	if err != nil {
182		log.Print(err)
183		return
184	}
185	_, err = blobdb.Exec("create table filedata (xid text, media text, content blob)")
186	if err != nil {
187		log.Print(err)
188		return
189	}
190	_, err = blobdb.Exec("create index idx_filexid on filedata(xid)")
191	if err != nil {
192		log.Print(err)
193		return
194	}
195	blobdb.Close()
196}
197
198func adduser() {
199	db := opendatabase()
200	defer func() {
201		os.Exit(1)
202	}()
203	c := make(chan os.Signal)
204	signal.Notify(c, os.Interrupt)
205	go func() {
206		<-c
207		C.termecho(1)
208		fmt.Printf("\n")
209		os.Exit(1)
210	}()
211
212	r := bufio.NewReader(os.Stdin)
213
214	err := createuser(db, r)
215	if err != nil {
216		log.Print(err)
217		return
218	}
219
220	db.Close()
221	os.Exit(0)
222}
223
224func createuser(db *sql.DB, r *bufio.Reader) error {
225	fmt.Printf("username: ")
226	name, err := r.ReadString('\n')
227	if err != nil {
228		return err
229	}
230	name = name[:len(name)-1]
231	if len(name) < 1 {
232		return fmt.Errorf("that's way too short")
233	}
234	C.termecho(0)
235	fmt.Printf("password: ")
236	pass, err := r.ReadString('\n')
237	C.termecho(1)
238	fmt.Printf("\n")
239	if err != nil {
240		return err
241	}
242	pass = pass[:len(pass)-1]
243	if len(pass) < 6 {
244		return fmt.Errorf("that's way too short")
245	}
246	hash, err := bcrypt.GenerateFromPassword([]byte(pass), 12)
247	if err != nil {
248		return err
249	}
250	k, err := rsa.GenerateKey(rand.Reader, 2048)
251	if err != nil {
252		return err
253	}
254	pubkey, err := httpsig.EncodeKey(&k.PublicKey)
255	if err != nil {
256		return err
257	}
258	seckey, err := httpsig.EncodeKey(k)
259	if err != nil {
260		return err
261	}
262	_, err = db.Exec("insert into users (username, displayname, about, hash, pubkey, seckey, options) values (?, ?, ?, ?, ?, ?, ?)", name, name, "what about me?", hash, pubkey, seckey, "")
263	if err != nil {
264		return err
265	}
266	return nil
267}
268
269func opendatabase() *sql.DB {
270	if alreadyopendb != nil {
271		return alreadyopendb
272	}
273	var err error
274	_, err = os.Stat(dbname)
275	if err != nil {
276		log.Fatalf("unable to open database: %s", err)
277	}
278	db, err := sql.Open("sqlite3", dbname)
279	if err != nil {
280		log.Fatalf("unable to open database: %s", err)
281	}
282	stmtConfig, err = db.Prepare("select value from config where key = ?")
283	if err != nil {
284		log.Fatal(err)
285	}
286	alreadyopendb = db
287	return db
288}
289
290func openblobdb() *sql.DB {
291	var err error
292	_, err = os.Stat(blobdbname)
293	if err != nil {
294		log.Fatalf("unable to open database: %s", err)
295	}
296	db, err := sql.Open("sqlite3", blobdbname)
297	if err != nil {
298		log.Fatalf("unable to open database: %s", err)
299	}
300	return db
301}
302
303func getconfig(key string, value interface{}) error {
304	m, ok := value.(*map[string]bool)
305	if ok {
306		rows, err := stmtConfig.Query(key)
307		if err != nil {
308			return err
309		}
310		defer rows.Close()
311		for rows.Next() {
312			var s string
313			err = rows.Scan(&s)
314			if err != nil {
315				return err
316			}
317			(*m)[s] = true
318		}
319		return nil
320	}
321	row := stmtConfig.QueryRow(key)
322	err := row.Scan(value)
323	if err == sql.ErrNoRows {
324		err = nil
325	}
326	return err
327}
328
329func saveconfig(key string, val interface{}) {
330	db := opendatabase()
331	db.Exec("update config set value = ? where key = ?", val, key)
332}
333
334func openListener() (net.Listener, error) {
335	var listenAddr string
336	err := getconfig("listenaddr", &listenAddr)
337	if err != nil {
338		return nil, err
339	}
340	if listenAddr == "" {
341		return nil, fmt.Errorf("must have listenaddr")
342	}
343	proto := "tcp"
344	if listenAddr[0] == '/' {
345		proto = "unix"
346		err := os.Remove(listenAddr)
347		if err != nil && !os.IsNotExist(err) {
348			log.Printf("unable to unlink socket: %s", err)
349		}
350	}
351	listener, err := net.Listen(proto, listenAddr)
352	if err != nil {
353		return nil, err
354	}
355	if proto == "unix" {
356		os.Chmod(listenAddr, 0777)
357	}
358	return listener, nil
359}