all repos — honk @ 2168c60f7d1833529f93f339ae694bda9f2afec6

my fork of honk

util.go (view raw)

  1//
  2// Copyright (c) 2019 Ted Unangst <tedu@tedunangst.com>
  3//
  4// Permission to use, copy, modify, and distribute this software for any
  5// purpose with or without fee is hereby granted, provided that the above
  6// copyright notice and this permission notice appear in all copies.
  7//
  8// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  9// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 10// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 11// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 12// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 13// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 14// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 15
 16package main
 17
 18import (
 19	"bufio"
 20	"crypto/rand"
 21	"crypto/sha512"
 22	"database/sql"
 23	"fmt"
 24	"io/ioutil"
 25	"log"
 26	"net"
 27	"os"
 28	"os/signal"
 29	"strings"
 30
 31	"golang.org/x/crypto/bcrypt"
 32	"golang.org/x/crypto/ssh/terminal"
 33	_ "humungus.tedunangst.com/r/go-sqlite3"
 34)
 35
 36var savedstyleparam string
 37
 38func getstyleparam() string {
 39	if savedstyleparam != "" {
 40		return savedstyleparam
 41	}
 42	data, _ := ioutil.ReadFile("views/style.css")
 43	hasher := sha512.New()
 44	hasher.Write(data)
 45	return fmt.Sprintf("?v=%.8x", hasher.Sum(nil))
 46}
 47
 48var dbtimeformat = "2006-01-02 15:04:05"
 49
 50var alreadyopendb *sql.DB
 51var dbname = "honk.db"
 52var stmtConfig *sql.Stmt
 53
 54func initdb() {
 55	schema, err := ioutil.ReadFile("schema.sql")
 56	if err != nil {
 57		log.Fatal(err)
 58	}
 59	_, err = os.Stat(dbname)
 60	if err == nil {
 61		log.Fatalf("%s already exists", dbname)
 62	}
 63	db, err := sql.Open("sqlite3", dbname)
 64	if err != nil {
 65		log.Fatal(err)
 66	}
 67	defer func() {
 68		os.Remove(dbname)
 69		os.Exit(1)
 70	}()
 71	c := make(chan os.Signal)
 72	signal.Notify(c, os.Interrupt)
 73	go func() {
 74		<-c
 75		fmt.Printf("\n")
 76		os.Remove(dbname)
 77		os.Exit(1)
 78	}()
 79
 80	for _, line := range strings.Split(string(schema), ";") {
 81		_, err = db.Exec(line)
 82		if err != nil {
 83			log.Print(err)
 84			return
 85		}
 86	}
 87	defer db.Close()
 88	r := bufio.NewReader(os.Stdin)
 89	fmt.Printf("username: ")
 90	name, err := r.ReadString('\n')
 91	if err != nil {
 92		log.Print(err)
 93		return
 94	}
 95	name = name[:len(name)-1]
 96	if len(name) < 1 {
 97		log.Print("that's way too short")
 98		return
 99	}
100	fmt.Printf("password: ")
101	passbytes, err := terminal.ReadPassword(1)
102	fmt.Printf("\n")
103	if err != nil {
104		log.Print(err)
105		return
106	}
107	pass := string(passbytes)
108	if len(pass) < 6 {
109		log.Print("that's way too short")
110		return
111	}
112	hash, err := bcrypt.GenerateFromPassword([]byte(pass), 12)
113	if err != nil {
114		log.Print(err)
115		return
116	}
117	_, err = db.Exec("insert into users (username, hash) values (?, ?)", name, hash)
118	if err != nil {
119		log.Print(err)
120		return
121	}
122	fmt.Printf("listen address: ")
123	addr, err := r.ReadString('\n')
124	if err != nil {
125		log.Print(err)
126		return
127	}
128	addr = addr[:len(addr)-1]
129	if len(addr) < 1 {
130		log.Print("that's way too short")
131		return
132	}
133	_, err = db.Exec("insert into config (key, value) values (?, ?)", "listenaddr", addr)
134	if err != nil {
135		log.Print(err)
136		return
137	}
138	fmt.Printf("server name: ")
139	addr, err = r.ReadString('\n')
140	if err != nil {
141		log.Print(err)
142		return
143	}
144	addr = addr[:len(addr)-1]
145	if len(addr) < 1 {
146		log.Print("that's way too short")
147		return
148	}
149	_, err = db.Exec("insert into config (key, value) values (?, ?)", "servername", addr)
150	if err != nil {
151		log.Print(err)
152		return
153	}
154	var randbytes [16]byte
155	rand.Read(randbytes[:])
156	key := fmt.Sprintf("%x", randbytes)
157	_, err = db.Exec("insert into config (key, value) values (?, ?)", "csrfkey", key)
158	if err != nil {
159		log.Print(err)
160		return
161	}
162	err = finishusersetup()
163	if err != nil {
164		log.Print(err)
165		return
166	}
167	prepareStatements(db)
168	db.Close()
169	fmt.Printf("done.\n")
170	os.Exit(0)
171}
172
173func opendatabase() *sql.DB {
174	if alreadyopendb != nil {
175		return alreadyopendb
176	}
177	var err error
178	_, err = os.Stat(dbname)
179	if err != nil {
180		log.Fatalf("unable to open database: %s", err)
181	}
182	db, err := sql.Open("sqlite3", dbname)
183	if err != nil {
184		log.Fatalf("unable to open database: %s", err)
185	}
186	stmtConfig, err = db.Prepare("select value from config where key = ?")
187	if err != nil {
188		log.Fatal(err)
189	}
190	alreadyopendb = db
191	return db
192}
193
194func getconfig(key string, value interface{}) error {
195	row := stmtConfig.QueryRow(key)
196	err := row.Scan(value)
197	if err == sql.ErrNoRows {
198		err = nil
199	}
200	return err
201}
202
203func openListener() (net.Listener, error) {
204	var listenAddr string
205	err := getconfig("listenaddr", &listenAddr)
206	if err != nil {
207		return nil, err
208	}
209	if listenAddr == "" {
210		return nil, fmt.Errorf("must have listenaddr")
211	}
212	proto := "tcp"
213	if listenAddr[0] == '/' {
214		proto = "unix"
215		err := os.Remove(listenAddr)
216		if err != nil && !os.IsNotExist(err) {
217			log.Printf("unable to unlink socket: %s", err)
218		}
219	}
220	listener, err := net.Listen(proto, listenAddr)
221	if err != nil {
222		return nil, err
223	}
224	if proto == "unix" {
225		os.Chmod(listenAddr, 0777)
226	}
227	return listener, nil
228}