diff options
Diffstat (limited to 'badger')
-rw-r--r-- | badger/ids.go | 132 | ||||
-rw-r--r-- | badger/ids_test.go | 72 |
2 files changed, 204 insertions, 0 deletions
diff --git a/badger/ids.go b/badger/ids.go new file mode 100644 index 0000000..80fb9ad --- /dev/null +++ b/badger/ids.go @@ -0,0 +1,132 @@ +package badger + +import ( + "bytes" + "errors" + "fmt" + "strconv" + + badger "github.com/dgraph-io/badger/v4" + log "github.com/sirupsen/logrus" +) + +// IdsDB represents a Badger database +type IdsDB struct { + db *badger.DB +} + +// IdsDBOpen returns a new DB object +func IdsDBOpen(path string) IdsDB { + bdb, err := badger.Open(badger.DefaultOptions(path)) + if err != nil { + log.Errorf("Failed to open ids database: %v, falling back to in-memory database", path) + bdb, err = badger.Open(badger.DefaultOptions("").WithInMemory(true)) + if err != nil { + log.Fatalf("Couldn't initialize the ids database") + } + } + + return IdsDB{ + db: bdb, + } +} + +// Set stores an id pair +func (db *IdsDB) Set(tgAccount, xmppAccount string, tgChatId, tgMsgId int64, xmppId string) error { + bPrefix := toKeyPrefix(tgAccount, xmppAccount) + bTgId := toTgByteString(tgChatId, tgMsgId) + bXmppId := toXmppByteString(xmppId) + bTgKey := toByteKey(bPrefix, bTgId, "tg") + bXmppKey := toByteKey(bPrefix, bXmppId, "xmpp") + + return db.db.Update(func(txn *badger.Txn) error { + if err := txn.Set(bTgKey, bXmppId); err != nil { + return err + } + return txn.Set(bXmppKey, bTgId) + }) +} + +func (db *IdsDB) getByteValue(key []byte) ([]byte, error) { + var valCopy []byte + err := db.db.View(func(txn *badger.Txn) error { + item, err := txn.Get(key) + if err != nil { + return err + } + + valCopy, err = item.ValueCopy(nil) + return err + }) + return valCopy, err +} + +// GetByTgIds obtains an XMPP id by Telegram chat/message ids +func (db *IdsDB) GetByTgIds(tgAccount, xmppAccount string, tgChatId, tgMsgId int64) (string, error) { + val, err := db.getByteValue(toByteKey( + toKeyPrefix(tgAccount, xmppAccount), + toTgByteString(tgChatId, tgMsgId), + "tg", + )) + if err != nil { + return "", err + } + return string(val), nil +} + +// GetByXmppId obtains Telegram chat/message ids by an XMPP id +func (db *IdsDB) GetByXmppId(tgAccount, xmppAccount, xmppId string) (int64, int64, error) { + val, err := db.getByteValue(toByteKey( + toKeyPrefix(tgAccount, xmppAccount), + toXmppByteString(xmppId), + "xmpp", + )) + if err != nil { + return 0, 0, err + } + return splitTgByteString(val) +} + +func toKeyPrefix(tgAccount, xmppAccount string) []byte { + return []byte(fmt.Sprintf("%v/%v/", tgAccount, xmppAccount)) +} + +func toByteKey(prefix, suffix []byte, typ string) []byte { + key := make([]byte, 0, len(prefix) + len(suffix) + 6) + key = append(key, prefix...) + key = append(key, []byte(typ)...) + key = append(key, []byte("/")...) + key = append(key, suffix...) + return key +} + +func toTgByteString(tgChatId, tgMsgId int64) []byte { + return []byte(fmt.Sprintf("%v/%v", tgChatId, tgMsgId)) +} + +func toXmppByteString(xmppId string) []byte { + return []byte(xmppId) +} + +func splitTgByteString(val []byte) (int64, int64, error) { + parts := bytes.Split(val, []byte("/")) + if len(parts) != 2 { + return 0, 0, errors.New("Couldn't parse tg id pair") + } + tgChatId, err := strconv.ParseInt(string(parts[0]), 10, 64) + if err != nil { + return 0, 0, err + } + tgMsgId, err := strconv.ParseInt(string(parts[1]), 10, 64) + return tgChatId, tgMsgId, err +} + +// Gc compacts the value log +func (db *IdsDB) Gc() { + db.db.RunValueLogGC(0.7) +} + +// Close closes a DB +func (db *IdsDB) Close() { + db.db.Close() +} diff --git a/badger/ids_test.go b/badger/ids_test.go new file mode 100644 index 0000000..efafdeb --- /dev/null +++ b/badger/ids_test.go @@ -0,0 +1,72 @@ +package badger + +import ( + "reflect" + "testing" +) + +func TestToKeyPrefix(t *testing.T) { + if !reflect.DeepEqual(toKeyPrefix("+123456789", "test@example.com"), []byte("+123456789/test@example.com/")) { + t.Error("Wrong prefix") + } +} + +func TestToByteKey(t *testing.T) { + if !reflect.DeepEqual(toByteKey([]byte("ababa/galamaga/"), []byte("123"), "ppp"), []byte("ababa/galamaga/ppp/123")) { + t.Error("Wrong key") + } +} + +func TestToTgByteString(t *testing.T) { + if !reflect.DeepEqual(toTgByteString(-2345, 6789), []byte("-2345/6789")) { + t.Error("Wrong tg string") + } +} + +func TestToXmppByteString(t *testing.T) { + if !reflect.DeepEqual(toXmppByteString("aboba"), []byte("aboba")) { + t.Error("Wrong xmpp string") + } +} + +func TestSplitTgByteStringUnparsable(t *testing.T) { + _, _, err := splitTgByteString([]byte("@#U*&$(@#")) + if err == nil { + t.Error("Unparsable should not be parsed") + return + } + if err.Error() != "Couldn't parse tg id pair" { + t.Error("Wrong parse error") + } +} + +func TestSplitTgByteManyParts(t *testing.T) { + _, _, err := splitTgByteString([]byte("a/b/c/d")) + if err == nil { + t.Error("Should not parse many parts") + return + } + if err.Error() != "Couldn't parse tg id pair" { + t.Error("Wrong parse error") + } +} + +func TestSplitTgByteNonNumeric(t *testing.T) { + _, _, err := splitTgByteString([]byte("0/a")) + if err == nil { + t.Error("Should not parse non-numeric msgid") + } +} + +func TestSplitTgByteSuccess(t *testing.T) { + chatId, msgId, err := splitTgByteString([]byte("-198282398/23798478")) + if err != nil { + t.Error("Should be parsed well") + } + if chatId != -198282398 { + t.Error("Wrong chatId") + } + if msgId != 23798478 { + t.Error("Wrong msgId") + } +} |