loginsrv

Unnamed repository; edit this file 'description' to name the repository.
git clone git@jamesshield.xyz:repos/loginsrv.git
Log | Files | Refs | README | LICENSE

auth.go (3366B)


      1 package htpasswd
      2 
      3 import (
      4 	"bytes"
      5 	"crypto/sha1"
      6 	"crypto/subtle"
      7 	"encoding/base64"
      8 	"encoding/csv"
      9 	"fmt"
     10 	"github.com/abbot/go-http-auth"
     11 	"github.com/tarent/loginsrv/logging"
     12 	"golang.org/x/crypto/bcrypt"
     13 	"io"
     14 	"os"
     15 	"strings"
     16 	"sync"
     17 	"time"
     18 )
     19 
     20 // File is a struct to serve an individual modTime
     21 type File struct {
     22 	name string
     23 	// Used in func reloadIfChanged to reload htpasswd file if it changed
     24 	modTime time.Time
     25 }
     26 
     27 // Auth is the htpassword authenticater
     28 type Auth struct {
     29 	filenames  []File
     30 	userHash   map[string]string
     31 	muUserHash sync.RWMutex
     32 }
     33 
     34 // NewAuth creates an htpassword authenticater
     35 func NewAuth(filenames []string) (*Auth, error) {
     36 	var htpasswdFiles []File
     37 	for _, file := range filenames {
     38 		htpasswdFiles = append(htpasswdFiles, File{name: file})
     39 	}
     40 
     41 	a := &Auth{
     42 		filenames: htpasswdFiles,
     43 	}
     44 	return a, a.parse()
     45 }
     46 
     47 func (a *Auth) parse() error {
     48 	tmpUserHash := map[string]string{}
     49 	tmpFilenames := a.filenames
     50 
     51 	for i, filename := range a.filenames {
     52 		r, err := os.Open(filename.name)
     53 		if err != nil {
     54 			return err
     55 		}
     56 		defer r.Close()
     57 
     58 		fileInfo, err := os.Stat(filename.name)
     59 		if err != nil {
     60 			return err
     61 		}
     62 		tmpFilenames[i].modTime = fileInfo.ModTime()
     63 
     64 		cr := csv.NewReader(r)
     65 		cr.Comma = ':'
     66 		cr.Comment = '#'
     67 		cr.TrimLeadingSpace = true
     68 
     69 		for {
     70 			record, err := cr.Read()
     71 			if err == io.EOF {
     72 				break
     73 			}
     74 			if err != nil {
     75 				return err
     76 			}
     77 			if len(record) != 2 {
     78 				return fmt.Errorf("password file in wrong format (%v)", filename)
     79 			}
     80 
     81 			if _, exist := tmpUserHash[record[0]]; exist {
     82 				logging.Logger.Warnf("Found duplicate entry for user: (%v)", record[0])
     83 			}
     84 			tmpUserHash[record[0]] = record[1]
     85 		}
     86 	}
     87 	a.muUserHash.Lock()
     88 	a.userHash = tmpUserHash
     89 	a.filenames = tmpFilenames
     90 	a.muUserHash.Unlock()
     91 
     92 	return nil
     93 }
     94 
     95 // Authenticate the user
     96 func (a *Auth) Authenticate(username, password string) (bool, error) {
     97 	reloadIfChanged(a)
     98 	a.muUserHash.RLock()
     99 	defer a.muUserHash.RUnlock()
    100 	if hash, exist := a.userHash[username]; exist {
    101 		h := []byte(hash)
    102 		p := []byte(password)
    103 		if strings.HasPrefix(hash, "$2y$") || strings.HasPrefix(hash, "$2b$") || strings.HasPrefix(hash, "$2a$") {
    104 			matchErr := bcrypt.CompareHashAndPassword(h, p)
    105 			return (matchErr == nil), nil
    106 		}
    107 		if strings.HasPrefix(hash, "{SHA}") {
    108 			return compareSha(h, p), nil
    109 		}
    110 		if strings.HasPrefix(hash, "$apr1$") {
    111 			return compareMD5(h, p), nil
    112 		}
    113 		return false, fmt.Errorf("unknown algorithm for user %q", username)
    114 	}
    115 	return false, nil
    116 }
    117 
    118 // Reload htpasswd file if it changed during current run
    119 func reloadIfChanged(a *Auth) {
    120 	for _, file := range a.filenames {
    121 		fileInfo, err := os.Stat(file.name)
    122 		if err != nil {
    123 			//On error, retain current file
    124 			break
    125 		}
    126 		currentmodTime := fileInfo.ModTime()
    127 		if currentmodTime != file.modTime {
    128 			a.parse()
    129 			return
    130 		}
    131 	}
    132 }
    133 
    134 func compareSha(hashedPassword, password []byte) bool {
    135 	d := sha1.New()
    136 	d.Write(password)
    137 	return 1 == subtle.ConstantTimeCompare(hashedPassword[5:], []byte(base64.StdEncoding.EncodeToString(d.Sum(nil))))
    138 }
    139 
    140 func compareMD5(hashedPassword, password []byte) bool {
    141 	parts := bytes.SplitN(hashedPassword, []byte("$"), 4)
    142 	if len(parts) != 4 {
    143 		return false
    144 	}
    145 	magic := []byte("$" + string(parts[1]) + "$")
    146 	salt := parts[2]
    147 	return 1 == subtle.ConstantTimeCompare(hashedPassword, auth.MD5Crypt(password, salt, magic))
    148 }