aboutsummaryrefslogblamecommitdiff
path: root/badger/ids.go
blob: 1295e8ad31f02d3af27b732e3a0759988ff935f5 (plain) (tree)





























































































                                                                                                       
                                                         



























                                                                    

































































































                                                                                                           








                                
package badger

import (
	"bytes"
	"errors"
	"fmt"
	"strconv"

	badger "github.com/dgraph-io/badger/v4"
	log "github.com/sirupsen/logrus"
)

// IdsDB wraps a Badger database holding two-way mappings between
// Telegram chat/message ids and XMPP ids (see Set).
type IdsDB struct {
	db *badger.DB // underlying Badger handle; all reads and writes go through it
}

// IdsDBOpen opens (or creates) the Badger database at path and returns it
// wrapped in an IdsDB.
//
// If the on-disk database cannot be opened, the failure (including the
// underlying error) is logged and an in-memory database is used instead, so
// id mappings will not persist. If even the in-memory database cannot be
// created, log.Fatalf is called (which, in logrus, terminates the process).
func IdsDBOpen(path string) IdsDB {
	bdb, err := badger.Open(badger.DefaultOptions(path))
	if err != nil {
		// Include err: previously only the path was logged and the cause was lost.
		log.Errorf("Failed to open ids database %v: %v, falling back to in-memory database", path, err)
		bdb, err = badger.Open(badger.DefaultOptions("").WithInMemory(true))
		if err != nil {
			log.Fatalf("Couldn't initialize the ids database: %v", err)
		}
	}

	return IdsDB{
		db: bdb,
	}
}

// Set stores an id pair in both directions under the tgAccount/xmppAccount
// namespace: the "tg" key for (tgChatId, tgMsgId) maps to xmppId, and the
// "xmpp" key for xmppId maps back to the encoded Telegram ids. Both writes
// happen in a single read-write transaction; the first failing write's error
// is returned.
func (db *IdsDB) Set(tgAccount, xmppAccount string, tgChatId, tgMsgId int64, xmppId string) error {
	prefix := toKeyPrefix(tgAccount, xmppAccount)
	tgVal := toTgByteString(tgChatId, tgMsgId)
	xmppVal := toXmppByteString(xmppId)

	return db.db.Update(func(txn *badger.Txn) error {
		err := txn.Set(toByteKey(prefix, tgVal, "tg"), xmppVal)
		if err == nil {
			err = txn.Set(toByteKey(prefix, xmppVal, "xmpp"), tgVal)
		}
		return err
	})
}

// getByteValue reads key in a read-only transaction and returns a copy of its
// value. ValueCopy is used so the returned slice stays usable after the
// transaction has finished. Any error from the lookup (presumably
// badger.ErrKeyNotFound for a missing key) is returned unchanged.
func (db *IdsDB) getByteValue(key []byte) ([]byte, error) {
	var out []byte
	viewErr := db.db.View(func(txn *badger.Txn) error {
		item, getErr := txn.Get(key)
		if getErr != nil {
			return getErr
		}
		var copyErr error
		out, copyErr = item.ValueCopy(nil)
		return copyErr
	})
	return out, viewErr
}

// GetByTgIds returns the XMPP id stored for the Telegram (chat, message) id
// pair within the tgAccount/xmppAccount namespace. On a lookup error it
// returns "" and that error.
func (db *IdsDB) GetByTgIds(tgAccount, xmppAccount string, tgChatId, tgMsgId int64) (string, error) {
	prefix := toKeyPrefix(tgAccount, xmppAccount)
	key := toByteKey(prefix, toTgByteString(tgChatId, tgMsgId), "tg")

	raw, err := db.getByteValue(key)
	if err == nil {
		return string(raw), nil
	}
	return "", err
}

// GetByXmppId returns the Telegram chat id and message id stored for xmppId
// within the tgAccount/xmppAccount namespace. A lookup error yields (0, 0, err);
// otherwise the stored value is decoded by splitTgByteString.
func (db *IdsDB) GetByXmppId(tgAccount, xmppAccount, xmppId string) (int64, int64, error) {
	key := toByteKey(toKeyPrefix(tgAccount, xmppAccount), toXmppByteString(xmppId), "xmpp")

	raw, lookupErr := db.getByteValue(key)
	if lookupErr != nil {
		return 0, 0, lookupErr
	}
	return splitTgByteString(raw)
}

// toKeyPrefix builds the "<tgAccount>/<xmppAccount>/" prefix that namespaces
// every key belonging to one Telegram/XMPP account pair.
// NOTE(review): account names containing '/' could make prefixes ambiguous;
// changing the encoding would break existing databases, so it is kept as-is.
func toKeyPrefix(tgAccount, xmppAccount string) []byte {
	return fmt.Appendf(nil, "%s/%s/", tgAccount, xmppAccount)
}

// toByteKey assembles a full database key of the form
// "<prefix><typ>/<suffix>", e.g. "tgAcc/xmppAcc/tg/<chat>/<msg>".
// The slice is sized exactly up front, so the key is built with a single
// allocation regardless of the length of typ (the previous fixed "+6" slack
// only fit a typ of at most 5 bytes).
func toByteKey(prefix, suffix []byte, typ string) []byte {
	key := make([]byte, 0, len(prefix)+len(typ)+1+len(suffix))
	key = append(key, prefix...)
	key = append(key, typ...)
	key = append(key, '/')
	key = append(key, suffix...)
	return key
}

// toTgByteString encodes a Telegram chat id and message id as the decimal
// text "<chatId>/<msgId>"; splitTgByteString parses this form back.
func toTgByteString(tgChatId, tgMsgId int64) []byte {
	return fmt.Appendf(nil, "%d/%d", tgChatId, tgMsgId)
}

// toXmppByteString returns a freshly allocated byte copy of xmppId. The result
// serves both as the key suffix and as the stored value for XMPP ids.
func toXmppByteString(xmppId string) []byte {
	buf := make([]byte, len(xmppId))
	copy(buf, xmppId)
	return buf
}

// splitTgByteString parses a value produced by toTgByteString
// ("<chatId>/<msgId>") back into the Telegram chat id and message id.
// It returns an error if val does not contain exactly one '/', or if either
// part is not a base-10 int64. On any error both returned ids are zero.
func splitTgByteString(val []byte) (int64, int64, error) {
	parts := bytes.Split(val, []byte("/"))
	if len(parts) != 2 {
		return 0, 0, errors.New("Couldn't parse tg id pair")
	}
	tgChatId, err := strconv.ParseInt(string(parts[0]), 10, 64)
	if err != nil {
		return 0, 0, err
	}
	tgMsgId, err := strconv.ParseInt(string(parts[1]), 10, 64)
	if err != nil {
		// Don't hand back a half-parsed pair (ParseInt may also return
		// MaxInt64/MinInt64 on range errors).
		return 0, 0, err
	}
	return tgChatId, tgMsgId, nil
}

// ReplaceIdPair replaces the pair stored for oldXmppId with a new pair made of
// newXmppId and (same chat id, newMsgId). The Telegram chat id is taken from
// the existing entry. In a single read-write transaction, both old keys are
// deleted and both new keys are written.
//
// If oldXmppId has no entry, the lookup error from GetByXmppId is returned and
// nothing is modified.
func (db *IdsDB) ReplaceIdPair(tgAccount, xmppAccount, oldXmppId, newXmppId string, newMsgId int64) error {
	// read old pair
	chatId, oldMsgId, err := db.GetByXmppId(tgAccount, xmppAccount, oldXmppId)
	if err != nil {
		return err
	}

	bPrefix := toKeyPrefix(tgAccount, xmppAccount)

	bOldTgId := toTgByteString(chatId, oldMsgId)
	bOldXmppId := toXmppByteString(oldXmppId)
	bOldTgKey := toByteKey(bPrefix, bOldTgId, "tg")
	bOldXmppKey := toByteKey(bPrefix, bOldXmppId, "xmpp")

	bTgId := toTgByteString(chatId, newMsgId)
	bXmppId := toXmppByteString(newXmppId)
	bTgKey := toByteKey(bPrefix, bTgId, "tg")
	bXmppKey := toByteKey(bPrefix, bXmppId, "xmpp")

	return db.db.Update(func(txn *badger.Txn) error {
		// Delete the old pair BEFORE writing the new one. When newXmppId ==
		// oldXmppId or newMsgId == oldMsgId, an old key equals a new key. The
		// last write to a key in a transaction wins, so deleting afterwards
		// would erase the mapping that was just written.
		if err := txn.Delete(bOldTgKey); err != nil {
			return err
		}
		if err := txn.Delete(bOldXmppKey); err != nil {
			return err
		}
		// save new pair
		if err := txn.Set(bTgKey, bXmppId); err != nil {
			return err
		}
		return txn.Set(bXmppKey, bTgId)
	})
}

// ReplaceXmppId replaces oldXmppId with newXmppId while keeping the Telegram
// chat/message ids it maps to. In a single read-write transaction, the old
// xmpp->tg entry is deleted, the tg->xmpp entry is pointed at newXmppId, and a
// new xmpp->tg entry is written.
//
// If oldXmppId has no entry, the lookup error from GetByXmppId is returned and
// nothing is modified.
func (db *IdsDB) ReplaceXmppId(tgAccount, xmppAccount, oldXmppId, newXmppId string) error {
	// read old Tg IDs
	chatId, msgId, err := db.GetByXmppId(tgAccount, xmppAccount, oldXmppId)
	if err != nil {
		return err
	}

	bPrefix := toKeyPrefix(tgAccount, xmppAccount)

	bOldXmppId := toXmppByteString(oldXmppId)
	bOldXmppKey := toByteKey(bPrefix, bOldXmppId, "xmpp")

	bTgId := toTgByteString(chatId, msgId)
	bXmppId := toXmppByteString(newXmppId)
	bTgKey := toByteKey(bPrefix, bTgId, "tg")
	bXmppKey := toByteKey(bPrefix, bXmppId, "xmpp")

	return db.db.Update(func(txn *badger.Txn) error {
		// Delete the old xmpp->tg entry first. If newXmppId == oldXmppId, the
		// old and new keys are identical, and deleting after the Set would
		// erase the fresh mapping, because the last write in a transaction wins.
		if err := txn.Delete(bOldXmppKey); err != nil {
			return err
		}
		// save new pair
		if err := txn.Set(bTgKey, bXmppId); err != nil {
			return err
		}
		return txn.Set(bXmppKey, bTgId)
	})
}

// ReplaceTgId replaces Telegram message id oldMsgId with newMsgId in chat
// chatId, keeping the chat id and the XMPP id it maps to. In a single
// read-write transaction, the old tg->xmpp entry is deleted, a new tg->xmpp
// entry is written, and the xmpp->tg entry is pointed at the new ids.
//
// If (chatId, oldMsgId) has no entry, the lookup error from GetByTgIds is
// returned and nothing is modified.
func (db *IdsDB) ReplaceTgId(tgAccount, xmppAccount string, chatId, oldMsgId, newMsgId int64) error {
	// read old XMPP ID
	xmppId, err := db.GetByTgIds(tgAccount, xmppAccount, chatId, oldMsgId)
	if err != nil {
		return err
	}

	bPrefix := toKeyPrefix(tgAccount, xmppAccount)

	bOldTgId := toTgByteString(chatId, oldMsgId)
	bOldTgKey := toByteKey(bPrefix, bOldTgId, "tg")

	bTgId := toTgByteString(chatId, newMsgId)
	bXmppId := toXmppByteString(xmppId)
	bTgKey := toByteKey(bPrefix, bTgId, "tg")
	bXmppKey := toByteKey(bPrefix, bXmppId, "xmpp")

	return db.db.Update(func(txn *badger.Txn) error {
		// Delete the old tg->xmpp entry first. If newMsgId == oldMsgId, the old
		// and new keys are identical, and deleting after the Set would erase the
		// fresh mapping, because the last write in a transaction wins.
		if err := txn.Delete(bOldTgKey); err != nil {
			return err
		}
		// save new pair
		if err := txn.Set(bTgKey, bXmppId); err != nil {
			return err
		}
		return txn.Set(bXmppKey, bTgId)
	})
}

// Gc compacts the value log by running Badger's value-log GC with a discard
// ratio of 0.7.
//
// A single RunValueLogGC call rewrites at most one log file, so it is
// re-run for as long as it reports success (nil), following Badger's
// documented usage. The terminal non-nil result is expected (e.g. nothing
// left to clean, or GC unavailable for an in-memory DB), so it is
// deliberately discarded: GC here is best-effort.
func (db *IdsDB) Gc() {
	for db.db.RunValueLogGC(0.7) == nil {
	}
}

// Close closes the underlying Badger database. Close has no return value, so a
// failure is logged rather than silently dropped; a failed close may mean
// pending writes were not flushed.
func (db *IdsDB) Close() {
	if err := db.db.Close(); err != nil {
		log.Errorf("Failed to close ids database: %v", err)
	}
}