aboutsummaryrefslogtreecommitdiff
path: root/badger
diff options
context:
space:
mode:
Diffstat (limited to 'badger')
-rw-r--r--badger/ids.go132
-rw-r--r--badger/ids_test.go72
2 files changed, 204 insertions, 0 deletions
diff --git a/badger/ids.go b/badger/ids.go
new file mode 100644
index 0000000..80fb9ad
--- /dev/null
+++ b/badger/ids.go
@@ -0,0 +1,132 @@
+package badger
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+ "strconv"
+
+ badger "github.com/dgraph-io/badger/v4"
+ log "github.com/sirupsen/logrus"
+)
+
+// IdsDB represents a Badger database
+type IdsDB struct {
+ db *badger.DB
+}
+
+// IdsDBOpen returns a new DB object
+func IdsDBOpen(path string) IdsDB {
+ bdb, err := badger.Open(badger.DefaultOptions(path))
+ if err != nil {
+ log.Errorf("Failed to open ids database: %v, falling back to in-memory database", path)
+ bdb, err = badger.Open(badger.DefaultOptions("").WithInMemory(true))
+ if err != nil {
+ log.Fatalf("Couldn't initialize the ids database")
+ }
+ }
+
+ return IdsDB{
+ db: bdb,
+ }
+}
+
+// Set stores an id pair
+func (db *IdsDB) Set(tgAccount, xmppAccount string, tgChatId, tgMsgId int64, xmppId string) error {
+ bPrefix := toKeyPrefix(tgAccount, xmppAccount)
+ bTgId := toTgByteString(tgChatId, tgMsgId)
+ bXmppId := toXmppByteString(xmppId)
+ bTgKey := toByteKey(bPrefix, bTgId, "tg")
+ bXmppKey := toByteKey(bPrefix, bXmppId, "xmpp")
+
+ return db.db.Update(func(txn *badger.Txn) error {
+ if err := txn.Set(bTgKey, bXmppId); err != nil {
+ return err
+ }
+ return txn.Set(bXmppKey, bTgId)
+ })
+}
+
+func (db *IdsDB) getByteValue(key []byte) ([]byte, error) {
+ var valCopy []byte
+ err := db.db.View(func(txn *badger.Txn) error {
+ item, err := txn.Get(key)
+ if err != nil {
+ return err
+ }
+
+ valCopy, err = item.ValueCopy(nil)
+ return err
+ })
+ return valCopy, err
+}
+
+// GetByTgIds obtains an XMPP id by Telegram chat/message ids
+func (db *IdsDB) GetByTgIds(tgAccount, xmppAccount string, tgChatId, tgMsgId int64) (string, error) {
+ val, err := db.getByteValue(toByteKey(
+ toKeyPrefix(tgAccount, xmppAccount),
+ toTgByteString(tgChatId, tgMsgId),
+ "tg",
+ ))
+ if err != nil {
+ return "", err
+ }
+ return string(val), nil
+}
+
+// GetByXmppId obtains Telegram chat/message ids by an XMPP id
+func (db *IdsDB) GetByXmppId(tgAccount, xmppAccount, xmppId string) (int64, int64, error) {
+ val, err := db.getByteValue(toByteKey(
+ toKeyPrefix(tgAccount, xmppAccount),
+ toXmppByteString(xmppId),
+ "xmpp",
+ ))
+ if err != nil {
+ return 0, 0, err
+ }
+ return splitTgByteString(val)
+}
+
+func toKeyPrefix(tgAccount, xmppAccount string) []byte {
+ return []byte(fmt.Sprintf("%v/%v/", tgAccount, xmppAccount))
+}
+
+func toByteKey(prefix, suffix []byte, typ string) []byte {
+ key := make([]byte, 0, len(prefix) + len(suffix) + 6)
+ key = append(key, prefix...)
+ key = append(key, []byte(typ)...)
+ key = append(key, []byte("/")...)
+ key = append(key, suffix...)
+ return key
+}
+
+func toTgByteString(tgChatId, tgMsgId int64) []byte {
+ return []byte(fmt.Sprintf("%v/%v", tgChatId, tgMsgId))
+}
+
+func toXmppByteString(xmppId string) []byte {
+ return []byte(xmppId)
+}
+
+func splitTgByteString(val []byte) (int64, int64, error) {
+ parts := bytes.Split(val, []byte("/"))
+ if len(parts) != 2 {
+ return 0, 0, errors.New("Couldn't parse tg id pair")
+ }
+ tgChatId, err := strconv.ParseInt(string(parts[0]), 10, 64)
+ if err != nil {
+ return 0, 0, err
+ }
+ tgMsgId, err := strconv.ParseInt(string(parts[1]), 10, 64)
+ return tgChatId, tgMsgId, err
+}
+
+// Gc compacts the value log
+func (db *IdsDB) Gc() {
+ db.db.RunValueLogGC(0.7)
+}
+
+// Close closes a DB
+func (db *IdsDB) Close() {
+ db.db.Close()
+}
diff --git a/badger/ids_test.go b/badger/ids_test.go
new file mode 100644
index 0000000..efafdeb
--- /dev/null
+++ b/badger/ids_test.go
@@ -0,0 +1,72 @@
+package badger
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestToKeyPrefix(t *testing.T) {
+ if !reflect.DeepEqual(toKeyPrefix("+123456789", "test@example.com"), []byte("+123456789/test@example.com/")) {
+ t.Error("Wrong prefix")
+ }
+}
+
+func TestToByteKey(t *testing.T) {
+ if !reflect.DeepEqual(toByteKey([]byte("ababa/galamaga/"), []byte("123"), "ppp"), []byte("ababa/galamaga/ppp/123")) {
+ t.Error("Wrong key")
+ }
+}
+
+func TestToTgByteString(t *testing.T) {
+ if !reflect.DeepEqual(toTgByteString(-2345, 6789), []byte("-2345/6789")) {
+ t.Error("Wrong tg string")
+ }
+}
+
+func TestToXmppByteString(t *testing.T) {
+ if !reflect.DeepEqual(toXmppByteString("aboba"), []byte("aboba")) {
+ t.Error("Wrong xmpp string")
+ }
+}
+
+func TestSplitTgByteStringUnparsable(t *testing.T) {
+ _, _, err := splitTgByteString([]byte("@#U*&$(@#"))
+ if err == nil {
+ t.Error("Unparsable should not be parsed")
+ return
+ }
+ if err.Error() != "Couldn't parse tg id pair" {
+ t.Error("Wrong parse error")
+ }
+}
+
+func TestSplitTgByteManyParts(t *testing.T) {
+ _, _, err := splitTgByteString([]byte("a/b/c/d"))
+ if err == nil {
+ t.Error("Should not parse many parts")
+ return
+ }
+ if err.Error() != "Couldn't parse tg id pair" {
+ t.Error("Wrong parse error")
+ }
+}
+
+func TestSplitTgByteNonNumeric(t *testing.T) {
+ _, _, err := splitTgByteString([]byte("0/a"))
+ if err == nil {
+ t.Error("Should not parse non-numeric msgid")
+ }
+}
+
+func TestSplitTgByteSuccess(t *testing.T) {
+ chatId, msgId, err := splitTgByteString([]byte("-198282398/23798478"))
+ if err != nil {
+ t.Error("Should be parsed well")
+ }
+ if chatId != -198282398 {
+ t.Error("Wrong chatId")
+ }
+ if msgId != 23798478 {
+ t.Error("Wrong msgId")
+ }
+}