aboutsummaryrefslogblamecommitdiff
path: root/badger/ids.go
blob: 80fb9adc431bed676d15aabe4cd6edd7408b679d (plain) (tree)



































































































































                                                                                                       
package badger

import (
	"bytes"
	"errors"
	"fmt"
	"strconv"

	badger "github.com/dgraph-io/badger/v4"
	log "github.com/sirupsen/logrus"
)

// IdsDB stores the mapping between Telegram (chat id, message id) pairs and
// XMPP message ids, in both directions, backed by a Badger key-value store.
// Obtain one with IdsDBOpen.
type IdsDB struct {
	db *badger.DB // underlying Badger handle
}

// IdsDBOpen opens the Badger ids database located at path.
//
// If the on-disk database cannot be opened, the failure (including the
// underlying error) is logged and an in-memory database is used instead, so
// ids recorded afterwards are not persisted to path. If even the in-memory
// database cannot be created, the error is reported through log.Fatalf.
func IdsDBOpen(path string) IdsDB {
	bdb, err := badger.Open(badger.DefaultOptions(path))
	if err != nil {
		// Include the real error: the path alone does not explain why opening failed.
		log.Errorf("Failed to open ids database %v: %v, falling back to in-memory database", path, err)
		bdb, err = badger.Open(badger.DefaultOptions("").WithInMemory(true))
		if err != nil {
			log.Fatalf("Couldn't initialize the ids database: %v", err)
		}
	}

	return IdsDB{db: bdb}
}

// Set records that the Telegram message (tgChatId, tgMsgId) corresponds to
// the XMPP id xmppId for the given account pair. Both lookup directions
// (Telegram -> XMPP and XMPP -> Telegram) are written in one update
// transaction; the first write error aborts it and is returned.
func (db *IdsDB) Set(tgAccount, xmppAccount string, tgChatId, tgMsgId int64, xmppId string) error {
	prefix := toKeyPrefix(tgAccount, xmppAccount)
	tgValue := toTgByteString(tgChatId, tgMsgId)
	xmppValue := toXmppByteString(xmppId)
	keyForTg := toByteKey(prefix, tgValue, "tg")
	keyForXmpp := toByteKey(prefix, xmppValue, "xmpp")

	return db.db.Update(func(txn *badger.Txn) error {
		err := txn.Set(keyForTg, xmppValue)
		if err == nil {
			err = txn.Set(keyForXmpp, tgValue)
		}
		return err
	})
}

// getByteValue reads the value stored under key inside a read-only
// transaction. It returns a copy of the value, so the result stays valid
// after the transaction has finished. Any error from the lookup or the copy
// is returned unchanged.
func (db *IdsDB) getByteValue(key []byte) ([]byte, error) {
	var out []byte
	viewErr := db.db.View(func(txn *badger.Txn) error {
		item, getErr := txn.Get(key)
		if getErr != nil {
			return getErr
		}
		var copyErr error
		out, copyErr = item.ValueCopy(nil)
		return copyErr
	})
	return out, viewErr
}

// GetByTgIds returns the XMPP id previously stored with Set for the
// Telegram message (tgChatId, tgMsgId) under the given account pair.
// On failure it returns an empty string and the lookup error.
func (db *IdsDB) GetByTgIds(tgAccount, xmppAccount string, tgChatId, tgMsgId int64) (string, error) {
	prefix := toKeyPrefix(tgAccount, xmppAccount)
	key := toByteKey(prefix, toTgByteString(tgChatId, tgMsgId), "tg")

	raw, err := db.getByteValue(key)
	if err == nil {
		return string(raw), nil
	}
	return "", err
}

// GetByXmppId returns the Telegram chat id and message id previously stored
// with Set for xmppId under the given account pair. On failure both ids are
// 0 and the lookup or parse error is returned.
func (db *IdsDB) GetByXmppId(tgAccount, xmppAccount, xmppId string) (int64, int64, error) {
	prefix := toKeyPrefix(tgAccount, xmppAccount)
	key := toByteKey(prefix, toXmppByteString(xmppId), "xmpp")

	raw, err := db.getByteValue(key)
	if err == nil {
		return splitTgByteString(raw)
	}
	return 0, 0, err
}

// toKeyPrefix builds the per-account-pair key prefix "<tgAccount>/<xmppAccount>/".
func toKeyPrefix(tgAccount, xmppAccount string) []byte {
	var buf bytes.Buffer
	fmt.Fprintf(&buf, "%v/%v/", tgAccount, xmppAccount)
	return buf.Bytes()
}

// toByteKey assembles a database key of the form "<prefix><typ>/<suffix>".
// The result is a freshly allocated slice sized exactly for its contents, so
// it never aliases prefix or suffix and never reallocates, whatever the
// length of typ.
func toByteKey(prefix, suffix []byte, typ string) []byte {
	key := make([]byte, 0, len(prefix)+len(typ)+1+len(suffix))
	key = append(key, prefix...)
	key = append(key, typ...)
	key = append(key, '/')
	key = append(key, suffix...)
	return key
}

// toTgByteString encodes a Telegram (chat id, message id) pair as the
// base-10 text "<chatId>/<msgId>".
func toTgByteString(tgChatId, tgMsgId int64) []byte {
	encoded := strconv.AppendInt(nil, tgChatId, 10)
	encoded = append(encoded, '/')
	return strconv.AppendInt(encoded, tgMsgId, 10)
}

// toXmppByteString encodes an XMPP id as its raw bytes (a new copy, not
// shared with the input string).
func toXmppByteString(xmppId string) []byte {
	raw := make([]byte, len(xmppId))
	copy(raw, xmppId)
	return raw
}

// splitTgByteString parses a value produced by toTgByteString
// ("<chatId>/<msgId>") back into the Telegram chat id and message id.
//
// It returns an error if val does not consist of exactly two '/'-separated
// parts or if either part is not a valid base-10 int64. Parse errors wrap the
// strconv error, so errors.Is/As still work. On any error both ids are 0,
// never a partially parsed pair.
func splitTgByteString(val []byte) (int64, int64, error) {
	parts := bytes.Split(val, []byte("/"))
	if len(parts) != 2 {
		return 0, 0, errors.New("couldn't parse tg id pair")
	}
	tgChatId, err := strconv.ParseInt(string(parts[0]), 10, 64)
	if err != nil {
		return 0, 0, fmt.Errorf("parsing tg chat id: %w", err)
	}
	tgMsgId, err := strconv.ParseInt(string(parts[1]), 10, 64)
	if err != nil {
		return 0, 0, fmt.Errorf("parsing tg message id: %w", err)
	}
	return tgChatId, tgMsgId, nil
}

// Gc runs Badger's value log garbage collection with a discard ratio of 0.7.
//
// A single successful RunValueLogGC call rewrites at most one value log file,
// so the call is repeated for as long as it succeeds. badger.ErrNoRewrite is
// the normal "nothing left to collect" result and is not logged. Any other
// error is logged at debug level rather than propagated, because GC here is
// best-effort and the method has no error return.
func (db *IdsDB) Gc() {
	for {
		err := db.db.RunValueLogGC(0.7)
		if err == nil {
			continue // a file was reclaimed; there may be more
		}
		if !errors.Is(err, badger.ErrNoRewrite) {
			log.Debugf("ids database value log GC: %v", err)
		}
		return
	}
}

// Close closes the underlying Badger database. The method has no error
// return, so a failure to close is logged instead of being silently dropped.
func (db *IdsDB) Close() {
	if err := db.db.Close(); err != nil {
		log.Errorf("Failed to close ids database: %v", err)
	}
}