all repos — honk @ b162076c4d37bcff6770ff1ca9299b66d5752b2b

my fork of honk

util.go (view raw)

  1//
  2// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
  3//
  4// Permission to use, copy, modify, and distribute this software for any
  5// purpose with or without fee is hereby granted, provided that the above
  6// copyright notice and this permission notice appear in all copies.
  7//
  8// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  9// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 10// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 11// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 12// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 13// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 14// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 15
 16package main
 17
 18/*
 19#include <termios.h>
 20
 21void
 22termecho(int on)
 23{
 24	struct termios t;
 25	tcgetattr(1, &t);
 26	if (on)
 27		t.c_lflag |= ECHO;
 28	else
 29		t.c_lflag &= ~ECHO;
 30	tcsetattr(1, TCSADRAIN, &t);
 31}
 32*/
 33import "C"
 34
 35import (
 36	"bufio"
 37	"crypto/rand"
 38	"crypto/rsa"
 39	"crypto/sha512"
 40	"database/sql"
 41	"fmt"
 42	"io/ioutil"
 43	"log"
 44	"net"
 45	"os"
 46	"os/signal"
 47	"strings"
 48
 49	"golang.org/x/crypto/bcrypt"
 50	_ "humungus.tedunangst.com/r/go-sqlite3"
 51	"humungus.tedunangst.com/r/webs/httpsig"
 52	"humungus.tedunangst.com/r/webs/login"
 53)
 54
 55var savedassetparams = make(map[string]string)
 56
 57func getassetparam(file string) string {
 58	if p, ok := savedassetparams[file]; ok {
 59		return p
 60	}
 61	data, err := ioutil.ReadFile(file)
 62	if err != nil {
 63		return ""
 64	}
 65	hasher := sha512.New()
 66	hasher.Write(data)
 67
 68	return fmt.Sprintf("?v=%.8x", hasher.Sum(nil))
 69}
 70
 71var dbtimeformat = "2006-01-02 15:04:05"
 72
 73var alreadyopendb *sql.DB
 74var dbname = "honk.db"
 75var blobdbname = "blob.db"
 76var stmtConfig *sql.Stmt
 77
 78func initdb() {
 79	schema, err := ioutil.ReadFile("schema.sql")
 80	if err != nil {
 81		log.Fatal(err)
 82	}
 83	_, err = os.Stat(dbname)
 84	if err == nil {
 85		log.Fatalf("%s already exists", dbname)
 86	}
 87	db, err := sql.Open("sqlite3", dbname)
 88	if err != nil {
 89		log.Fatal(err)
 90	}
 91	alreadyopendb = db
 92	defer func() {
 93		os.Remove(dbname)
 94		os.Exit(1)
 95	}()
 96	c := make(chan os.Signal)
 97	signal.Notify(c, os.Interrupt)
 98	go func() {
 99		<-c
100		C.termecho(1)
101		fmt.Printf("\n")
102		os.Remove(dbname)
103		os.Exit(1)
104	}()
105
106	for _, line := range strings.Split(string(schema), ";") {
107		_, err = db.Exec(line)
108		if err != nil {
109			log.Print(err)
110			return
111		}
112	}
113	defer db.Close()
114	r := bufio.NewReader(os.Stdin)
115
116	err = createuser(db, r)
117	if err != nil {
118		log.Print(err)
119		return
120	}
121
122	fmt.Printf("listen address: ")
123	addr, err := r.ReadString('\n')
124	if err != nil {
125		log.Print(err)
126		return
127	}
128	addr = addr[:len(addr)-1]
129	if len(addr) < 1 {
130		log.Print("that's way too short")
131		return
132	}
133	setconfig("listenaddr", addr)
134	fmt.Printf("server name: ")
135	addr, err = r.ReadString('\n')
136	if err != nil {
137		log.Print(err)
138		return
139	}
140	addr = addr[:len(addr)-1]
141	if len(addr) < 1 {
142		log.Print("that's way too short")
143		return
144	}
145	setconfig("servername", addr)
146	var randbytes [16]byte
147	rand.Read(randbytes[:])
148	key := fmt.Sprintf("%x", randbytes)
149	setconfig("csrfkey", key)
150	setconfig("dbversion", myVersion)
151
152	setconfig("servermsg", "<h2>Things happen.</h2>")
153	setconfig("aboutmsg", "<h3>What is honk?</h3>\n<p>Honk is amazing!")
154	setconfig("loginmsg", "<h2>login</h2>")
155	setconfig("debug", 0)
156
157	initblobdb()
158
159	prepareStatements(db)
160	db.Close()
161	fmt.Printf("done.\n")
162	os.Exit(0)
163}
164
165func initblobdb() {
166	_, err := os.Stat(blobdbname)
167	if err == nil {
168		log.Fatalf("%s already exists", blobdbname)
169	}
170	blobdb, err := sql.Open("sqlite3", blobdbname)
171	if err != nil {
172		log.Print(err)
173		return
174	}
175	_, err = blobdb.Exec("create table filedata (xid text, media text, content blob)")
176	if err != nil {
177		log.Print(err)
178		return
179	}
180	_, err = blobdb.Exec("create index idx_filexid on filedata(xid)")
181	if err != nil {
182		log.Print(err)
183		return
184	}
185	blobdb.Close()
186}
187
188func adduser() {
189	db := opendatabase()
190	defer func() {
191		os.Exit(1)
192	}()
193	c := make(chan os.Signal)
194	signal.Notify(c, os.Interrupt)
195	go func() {
196		<-c
197		C.termecho(1)
198		fmt.Printf("\n")
199		os.Exit(1)
200	}()
201
202	r := bufio.NewReader(os.Stdin)
203
204	err := createuser(db, r)
205	if err != nil {
206		log.Print(err)
207		return
208	}
209
210	db.Close()
211	os.Exit(0)
212}
213
214func chpass() {
215	if len(os.Args) < 3 {
216		fmt.Printf("need a username\n")
217		os.Exit(1)
218	}
219	user, err := butwhatabout(os.Args[2])
220	if err != nil {
221		log.Fatal(err)
222	}
223	defer func() {
224		os.Exit(1)
225	}()
226	c := make(chan os.Signal)
227	signal.Notify(c, os.Interrupt)
228	go func() {
229		<-c
230		C.termecho(1)
231		fmt.Printf("\n")
232		os.Exit(1)
233	}()
234
235	db := opendatabase()
236	login.Init(db)
237
238	r := bufio.NewReader(os.Stdin)
239
240	pass, err := askpassword(r)
241	if err != nil {
242		log.Print(err)
243		return
244	}
245	err = login.SetPassword(user.ID, pass)
246	if err != nil {
247		log.Print(err)
248		return
249	}
250	fmt.Printf("done\n")
251	os.Exit(0)
252}
253
254func askpassword(r *bufio.Reader) (string, error) {
255	C.termecho(0)
256	fmt.Printf("password: ")
257	pass, err := r.ReadString('\n')
258	C.termecho(1)
259	fmt.Printf("\n")
260	if err != nil {
261		return "", err
262	}
263	pass = pass[:len(pass)-1]
264	if len(pass) < 6 {
265		return "", fmt.Errorf("that's way too short")
266	}
267	return pass, nil
268}
269
270func createuser(db *sql.DB, r *bufio.Reader) error {
271	fmt.Printf("username: ")
272	name, err := r.ReadString('\n')
273	if err != nil {
274		return err
275	}
276	name = name[:len(name)-1]
277	if len(name) < 1 {
278		return fmt.Errorf("that's way too short")
279	}
280	pass, err := askpassword(r)
281	if err != nil {
282		return err
283	}
284	hash, err := bcrypt.GenerateFromPassword([]byte(pass), 12)
285	if err != nil {
286		return err
287	}
288	k, err := rsa.GenerateKey(rand.Reader, 2048)
289	if err != nil {
290		return err
291	}
292	pubkey, err := httpsig.EncodeKey(&k.PublicKey)
293	if err != nil {
294		return err
295	}
296	seckey, err := httpsig.EncodeKey(k)
297	if err != nil {
298		return err
299	}
300	_, err = db.Exec("insert into users (username, displayname, about, hash, pubkey, seckey, options) values (?, ?, ?, ?, ?, ?, ?)", name, name, "what about me?", hash, pubkey, seckey, "")
301	if err != nil {
302		return err
303	}
304	return nil
305}
306
307func opendatabase() *sql.DB {
308	if alreadyopendb != nil {
309		return alreadyopendb
310	}
311	var err error
312	_, err = os.Stat(dbname)
313	if err != nil {
314		log.Fatalf("unable to open database: %s", err)
315	}
316	db, err := sql.Open("sqlite3", dbname)
317	if err != nil {
318		log.Fatalf("unable to open database: %s", err)
319	}
320	stmtConfig, err = db.Prepare("select value from config where key = ?")
321	if err != nil {
322		log.Fatal(err)
323	}
324	alreadyopendb = db
325	return db
326}
327
328func openblobdb() *sql.DB {
329	var err error
330	_, err = os.Stat(blobdbname)
331	if err != nil {
332		log.Fatalf("unable to open database: %s", err)
333	}
334	db, err := sql.Open("sqlite3", blobdbname)
335	if err != nil {
336		log.Fatalf("unable to open database: %s", err)
337	}
338	return db
339}
340
341func getconfig(key string, value interface{}) error {
342	m, ok := value.(*map[string]bool)
343	if ok {
344		rows, err := stmtConfig.Query(key)
345		if err != nil {
346			return err
347		}
348		defer rows.Close()
349		for rows.Next() {
350			var s string
351			err = rows.Scan(&s)
352			if err != nil {
353				return err
354			}
355			(*m)[s] = true
356		}
357		return nil
358	}
359	row := stmtConfig.QueryRow(key)
360	err := row.Scan(value)
361	if err == sql.ErrNoRows {
362		err = nil
363	}
364	return err
365}
366
367func setconfig(key string, val interface{}) error {
368	db := opendatabase()
369	_, err := db.Exec("insert into config (key, value) values (?, ?)", key, val)
370	return err
371}
372
373func updateconfig(key string, val interface{}) error {
374	db := opendatabase()
375	_, err := db.Exec("update config set value = ? where key = ?", val, key)
376	return err
377}
378
379func openListener() (net.Listener, error) {
380	var listenAddr string
381	err := getconfig("listenaddr", &listenAddr)
382	if err != nil {
383		return nil, err
384	}
385	if listenAddr == "" {
386		return nil, fmt.Errorf("must have listenaddr")
387	}
388	proto := "tcp"
389	if listenAddr[0] == '/' {
390		proto = "unix"
391		err := os.Remove(listenAddr)
392		if err != nil && !os.IsNotExist(err) {
393			log.Printf("unable to unlink socket: %s", err)
394		}
395	}
396	listener, err := net.Listen(proto, listenAddr)
397	if err != nil {
398		return nil, err
399	}
400	if proto == "unix" {
401		os.Chmod(listenAddr, 0777)
402	}
403	return listener, nil
404}