all repos — honk @ 24d8da1fd7e051272ff0b800f69a481986013477

my fork of honk

util.go (view raw)

  1//
  2// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
  3//
  4// Permission to use, copy, modify, and distribute this software for any
  5// purpose with or without fee is hereby granted, provided that the above
  6// copyright notice and this permission notice appear in all copies.
  7//
  8// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  9// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 10// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 11// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 12// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 13// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 14// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 15
 16package main
 17
 18/*
 19#include <termios.h>
 20
 21void
 22termecho(int on)
 23{
 24	struct termios t;
 25	tcgetattr(1, &t);
 26	if (on)
 27		t.c_lflag |= ECHO;
 28	else
 29		t.c_lflag &= ~ECHO;
 30	tcsetattr(1, TCSADRAIN, &t);
 31}
 32*/
 33import "C"
 34
 35import (
 36	"bufio"
 37	"crypto/rand"
 38	"crypto/rsa"
 39	"crypto/sha512"
 40	"database/sql"
 41	"fmt"
 42	"io/ioutil"
 43	"log"
 44	"net"
 45	"os"
 46	"os/signal"
 47	"strings"
 48
 49	"golang.org/x/crypto/bcrypt"
 50	_ "humungus.tedunangst.com/r/go-sqlite3"
 51	"humungus.tedunangst.com/r/webs/httpsig"
 52)
 53
 // savedstyleparams maps a file path to a precomputed cache-busting suffix.
// getstyleparam only reads it; entries are presumably stored elsewhere in
// the package.
var savedstyleparams = make(map[string]string)

// getstyleparam returns a cache-busting query string for file.
//
// A suffix already present in savedstyleparams is returned as-is. Otherwise
// the file is read and the result is "?v=" followed by the hex encoding of
// the first 8 bytes of its SHA-512 digest. The computed value is not stored.
// An unreadable file yields "".
func getstyleparam(file string) string {
	if cached, found := savedstyleparams[file]; found {
		return cached
	}
	contents, err := ioutil.ReadFile(file)
	if err != nil {
		return ""
	}
	digest := sha512.Sum512(contents)
	// For byte slices the %x precision limits input bytes, not output
	// characters: 8 bytes -> 16 hex digits.
	return fmt.Sprintf("?v=%.8x", digest[:])
}
 69
 // Database-related settings and shared handles.
var (
	// dbtimeformat is a Go reference-time layout ("YYYY-MM-DD hh:mm:ss").
	dbtimeformat = "2006-01-02 15:04:05"

	// alreadyopendb caches the handle opened by opendatabase.
	alreadyopendb *sql.DB
	// dbname is the SQLite database file path (relative to the working
	// directory).
	dbname = "honk.db"
	// stmtConfig selects a config value by key; opendatabase prepares it.
	stmtConfig *sql.Stmt
	// myVersion is recorded as "dbversion" in the config table by initdb.
	myVersion = 17
)
 76
 77func initdb() {
 78	schema, err := ioutil.ReadFile("schema.sql")
 79	if err != nil {
 80		log.Fatal(err)
 81	}
 82	_, err = os.Stat(dbname)
 83	if err == nil {
 84		log.Fatalf("%s already exists", dbname)
 85	}
 86	db, err := sql.Open("sqlite3", dbname)
 87	if err != nil {
 88		log.Fatal(err)
 89	}
 90	defer func() {
 91		os.Remove(dbname)
 92		os.Exit(1)
 93	}()
 94	c := make(chan os.Signal)
 95	signal.Notify(c, os.Interrupt)
 96	go func() {
 97		<-c
 98		C.termecho(1)
 99		fmt.Printf("\n")
100		os.Remove(dbname)
101		os.Exit(1)
102	}()
103
104	for _, line := range strings.Split(string(schema), ";") {
105		_, err = db.Exec(line)
106		if err != nil {
107			log.Print(err)
108			return
109		}
110	}
111	defer db.Close()
112	r := bufio.NewReader(os.Stdin)
113
114	err = createuser(db, r)
115	if err != nil {
116		log.Print(err)
117		return
118	}
119
120	fmt.Printf("listen address: ")
121	addr, err := r.ReadString('\n')
122	if err != nil {
123		log.Print(err)
124		return
125	}
126	addr = addr[:len(addr)-1]
127	if len(addr) < 1 {
128		log.Print("that's way too short")
129		return
130	}
131	_, err = db.Exec("insert into config (key, value) values (?, ?)", "listenaddr", addr)
132	if err != nil {
133		log.Print(err)
134		return
135	}
136	fmt.Printf("server name: ")
137	addr, err = r.ReadString('\n')
138	if err != nil {
139		log.Print(err)
140		return
141	}
142	addr = addr[:len(addr)-1]
143	if len(addr) < 1 {
144		log.Print("that's way too short")
145		return
146	}
147	_, err = db.Exec("insert into config (key, value) values (?, ?)", "servername", addr)
148	if err != nil {
149		log.Print(err)
150		return
151	}
152	var randbytes [16]byte
153	rand.Read(randbytes[:])
154	key := fmt.Sprintf("%x", randbytes)
155	_, err = db.Exec("insert into config (key, value) values (?, ?)", "csrfkey", key)
156	if err != nil {
157		log.Print(err)
158		return
159	}
160	_, err = db.Exec("insert into config (key, value) values (?, ?)", "dbversion", myVersion)
161	if err != nil {
162		log.Print(err)
163		return
164	}
165	prepareStatements(db)
166	db.Close()
167	fmt.Printf("done.\n")
168	os.Exit(0)
169}
170
171func adduser() {
172	db := opendatabase()
173	defer func() {
174		os.Exit(1)
175	}()
176	c := make(chan os.Signal)
177	signal.Notify(c, os.Interrupt)
178	go func() {
179		<-c
180		C.termecho(1)
181		fmt.Printf("\n")
182		os.Exit(1)
183	}()
184
185	r := bufio.NewReader(os.Stdin)
186
187	err := createuser(db, r)
188	if err != nil {
189		log.Print(err)
190		return
191	}
192
193	db.Close()
194	os.Exit(0)
195}
196
197func createuser(db *sql.DB, r *bufio.Reader) error {
198	fmt.Printf("username: ")
199	name, err := r.ReadString('\n')
200	if err != nil {
201		return err
202	}
203	name = name[:len(name)-1]
204	if len(name) < 1 {
205		return fmt.Errorf("that's way too short")
206	}
207	C.termecho(0)
208	fmt.Printf("password: ")
209	pass, err := r.ReadString('\n')
210	C.termecho(1)
211	fmt.Printf("\n")
212	if err != nil {
213		return err
214	}
215	pass = pass[:len(pass)-1]
216	if len(pass) < 6 {
217		return fmt.Errorf("that's way too short")
218	}
219	hash, err := bcrypt.GenerateFromPassword([]byte(pass), 12)
220	if err != nil {
221		return err
222	}
223	k, err := rsa.GenerateKey(rand.Reader, 2048)
224	if err != nil {
225		return err
226	}
227	pubkey, err := httpsig.EncodeKey(&k.PublicKey)
228	if err != nil {
229		return err
230	}
231	seckey, err := httpsig.EncodeKey(k)
232	if err != nil {
233		return err
234	}
235	_, err = db.Exec("insert into users (username, displayname, about, hash, pubkey, seckey, options) values (?, ?, ?, ?, ?, ?, ?)", name, name, "what about me?", hash, pubkey, seckey, "")
236	if err != nil {
237		return err
238	}
239	return nil
240}
241
242func opendatabase() *sql.DB {
243	if alreadyopendb != nil {
244		return alreadyopendb
245	}
246	var err error
247	_, err = os.Stat(dbname)
248	if err != nil {
249		log.Fatalf("unable to open database: %s", err)
250	}
251	db, err := sql.Open("sqlite3", dbname)
252	if err != nil {
253		log.Fatalf("unable to open database: %s", err)
254	}
255	stmtConfig, err = db.Prepare("select value from config where key = ?")
256	if err != nil {
257		log.Fatal(err)
258	}
259	alreadyopendb = db
260	return db
261}
262
263func getconfig(key string, value interface{}) error {
264	m, ok := value.(*map[string]bool)
265	if ok {
266		rows, err := stmtConfig.Query(key)
267		if err != nil {
268			return err
269		}
270		defer rows.Close()
271		for rows.Next() {
272			var s string
273			err = rows.Scan(&s)
274			if err != nil {
275				return err
276			}
277			(*m)[s] = true
278		}
279		return nil
280	}
281	row := stmtConfig.QueryRow(key)
282	err := row.Scan(value)
283	if err == sql.ErrNoRows {
284		err = nil
285	}
286	return err
287}
288
289func saveconfig(key string, val interface{}) {
290	db := opendatabase()
291	db.Exec("update config set value = ? where key = ?", val, key)
292}
293
294func openListener() (net.Listener, error) {
295	var listenAddr string
296	err := getconfig("listenaddr", &listenAddr)
297	if err != nil {
298		return nil, err
299	}
300	if listenAddr == "" {
301		return nil, fmt.Errorf("must have listenaddr")
302	}
303	proto := "tcp"
304	if listenAddr[0] == '/' {
305		proto = "unix"
306		err := os.Remove(listenAddr)
307		if err != nil && !os.IsNotExist(err) {
308			log.Printf("unable to unlink socket: %s", err)
309		}
310	}
311	listener, err := net.Listen(proto, listenAddr)
312	if err != nil {
313		return nil, err
314	}
315	if proto == "unix" {
316		os.Chmod(listenAddr, 0777)
317	}
318	return listener, nil
319}