aboutsummaryrefslogtreecommitdiff
path: root/xmpp
diff options
context:
space:
mode:
Diffstat (limited to 'xmpp')
-rw-r--r--xmpp/extensions/extensions.go36
-rw-r--r--xmpp/gateway/gateway.go6
-rw-r--r--xmpp/handlers.go204
3 files changed, 244 insertions, 2 deletions
diff --git a/xmpp/extensions/extensions.go b/xmpp/extensions/extensions.go
index 192b630..8e2f743 100644
--- a/xmpp/extensions/extensions.go
+++ b/xmpp/extensions/extensions.go
@@ -193,6 +193,26 @@ type Replace struct {
Id string `xml:"id,attr"`
}
+// QueryRegister is from XEP-0077
+type QueryRegister struct {
+ XMLName xml.Name `xml:"jabber:iq:register query"`
+ Instructions string `xml:"instructions"`
+ Username string `xml:"username"`
+ Registered *QueryRegisterRegistered `xml:"registered"`
+ Remove *QueryRegisterRemove `xml:"remove"`
+ ResultSet *stanza.ResultSet `xml:"set,omitempty"`
+}
+
+// QueryRegisterRegistered is a child element from XEP-0077
+type QueryRegisterRegistered struct {
+ XMLName xml.Name `xml:"registered"`
+}
+
+// QueryRegisterRemove is a child element from XEP-0077
+type QueryRegisterRemove struct {
+ XMLName xml.Name `xml:"remove"`
+}
+
// Namespace is a namespace!
func (c PresenceNickExtension) Namespace() string {
return c.XMLName.Space
@@ -248,6 +268,16 @@ func (c Replace) Namespace() string {
return c.XMLName.Space
}
+// Namespace is a namespace!
+func (c QueryRegister) Namespace() string {
+ return c.XMLName.Space
+}
+
+// GetSet getsets!
+func (c QueryRegister) GetSet() *stanza.ResultSet {
+ return c.ResultSet
+}
+
// Name is a packet name
func (ClientMessage) Name() string {
return "message"
@@ -326,4 +356,10 @@ func init() {
"urn:xmpp:message-correct:0",
"replace",
}, Replace{})
+
+ // register query
+ stanza.TypeRegistry.MapExtension(stanza.PKTIQ, xml.Name{
+ "jabber:iq:register",
+ "query",
+ }, QueryRegister{})
}
diff --git a/xmpp/gateway/gateway.go b/xmpp/gateway/gateway.go
index 7a2500e..dfe2ebf 100644
--- a/xmpp/gateway/gateway.go
+++ b/xmpp/gateway/gateway.go
@@ -360,6 +360,12 @@ func ResumableSend(component *xmpp.Component, packet stanza.Packet) error {
return err
}
+// SubscribeToTransport ensures a two-way subscription to the transport
+func SubscribeToTransport(component *xmpp.Component, jid string) {
+ SendPresence(component, jid, SPType("subscribe"))
+ SendPresence(component, jid, SPType("subscribed"))
+}
+
// SplitJID tokenizes a JID string to bare JID and resource
func SplitJID(from string) (string, string, bool) {
fromJid, err := stanza.NewJid(from)
diff --git a/xmpp/handlers.go b/xmpp/handlers.go
index 36f9cf9..fdcf647 100644
--- a/xmpp/handlers.go
+++ b/xmpp/handlers.go
@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/base64"
"encoding/xml"
+ "fmt"
"github.com/pkg/errors"
"io"
"strconv"
@@ -57,6 +58,22 @@ func HandleIq(s xmpp.Sender, p stanza.Packet) {
go handleGetDiscoInfo(s, iq)
return
}
+ _, ok = iq.Payload.(*stanza.DiscoItems)
+ if ok {
+ go handleGetDiscoItems(s, iq)
+ return
+ }
+ _, ok = iq.Payload.(*extensions.QueryRegister)
+ if ok {
+ go handleGetQueryRegister(s, iq)
+ return
+ }
+ } else if iq.Type == "set" {
+ query, ok := iq.Payload.(*extensions.QueryRegister)
+ if ok {
+ go handleSetQueryRegister(s, iq, query)
+ return
+ }
}
}
@@ -91,8 +108,7 @@ func HandleMessage(s xmpp.Sender, p stanza.Packet) {
session, ok := sessions[bare]
if !ok {
if msg.To == gatewayJid {
- gateway.SendPresence(component, msg.From, gateway.SPType("subscribe"))
- gateway.SendPresence(component, msg.From, gateway.SPType("subscribed"))
+ gateway.SubscribeToTransport(component, msg.From)
} else {
log.Error("Message from stranger")
}
@@ -444,6 +460,7 @@ func handleGetDiscoInfo(s xmpp.Sender, iq *stanza.IQ) {
disco.AddIdentity("", "account", "registered")
} else {
disco.AddIdentity("Telegram Gateway", "gateway", "telegram")
+ disco.AddFeatures("jabber:iq:register")
}
answer.Payload = disco
@@ -458,6 +475,189 @@ func handleGetDiscoInfo(s xmpp.Sender, iq *stanza.IQ) {
_ = gateway.ResumableSend(component, answer)
}
+func handleGetDiscoItems(s xmpp.Sender, iq *stanza.IQ) {
+ answer, err := stanza.NewIQ(stanza.Attrs{
+ Type: stanza.IQTypeResult,
+ From: iq.To,
+ To: iq.From,
+ Id: iq.Id,
+ Lang: "en",
+ })
+ if err != nil {
+ log.Errorf("Failed to create answer IQ: %v", err)
+ return
+ }
+
+ answer.Payload = answer.DiscoItems()
+
+ component, ok := s.(*xmpp.Component)
+ if !ok {
+ log.Error("Not a component")
+ return
+ }
+
+ _ = gateway.ResumableSend(component, answer)
+}
+
+func handleGetQueryRegister(s xmpp.Sender, iq *stanza.IQ) {
+ component, ok := s.(*xmpp.Component)
+ if !ok {
+ log.Error("Not a component")
+ return
+ }
+
+ answer, err := stanza.NewIQ(stanza.Attrs{
+ Type: stanza.IQTypeResult,
+ From: iq.To,
+ To: iq.From,
+ Id: iq.Id,
+ Lang: "en",
+ })
+ if err != nil {
+ log.Errorf("Failed to create answer IQ: %v", err)
+ return
+ }
+
+ var login string
+ bare, _, ok := gateway.SplitJID(iq.From)
+ if ok {
+ session, ok := sessions[bare]
+ if ok {
+ login = session.Session.Login
+ }
+ }
+
+ var query stanza.IQPayload
+ if login == "" {
+ query = extensions.QueryRegister{
+ Instructions: fmt.Sprintf("Authorization in Telegram is a multi-step process, so please accept %v to your contacts and follow further instructions (provide the authentication code there, etc.).\nFor now, please provide your login.", iq.To),
+ }
+ } else {
+ query = extensions.QueryRegister{
+ Instructions: "Already logged in",
+ Username: login,
+ Registered: &extensions.QueryRegisterRegistered{},
+ }
+ }
+ answer.Payload = query
+
+ log.Debugf("%#v", query)
+
+ _ = gateway.ResumableSend(component, answer)
+
+ if login == "" {
+ gateway.SubscribeToTransport(component, iq.From)
+ }
+}
+
+func handleSetQueryRegister(s xmpp.Sender, iq *stanza.IQ, query *extensions.QueryRegister) {
+ component, ok := s.(*xmpp.Component)
+ if !ok {
+ log.Error("Not a component")
+ return
+ }
+
+ answer, err := stanza.NewIQ(stanza.Attrs{
+ Type: stanza.IQTypeResult,
+ From: iq.To,
+ To: iq.From,
+ Id: iq.Id,
+ Lang: "en",
+ })
+ if err != nil {
+ log.Errorf("Failed to create answer IQ: %v", err)
+ return
+ }
+
+ defer gateway.ResumableSend(component, answer)
+
+ if query.Remove != nil {
+ iqAnswerSetError(answer, query, 405)
+ return
+ }
+
+ var login string
+ var session *telegram.Client
+ bare, resource, ok := gateway.SplitJID(iq.From)
+ if ok {
+ session, ok = sessions[bare]
+ if ok {
+ login = session.Session.Login
+ }
+ }
+
+ if login == "" {
+ if !ok {
+ session, ok = getTelegramInstance(bare, &persistence.Session{}, component)
+ if !ok {
+ iqAnswerSetError(answer, query, 500)
+ return
+ }
+ }
+
+ err := session.TryLogin(resource, query.Username)
+ if err != nil {
+ if err.Error() == telegram.TelegramAuthDone {
+ iqAnswerSetError(answer, query, 406)
+ } else {
+ iqAnswerSetError(answer, query, 500)
+ }
+ return
+ }
+
+ err = session.SetPhoneNumber(query.Username)
+ if err != nil {
+ iqAnswerSetError(answer, query, 500)
+ return
+ }
+
+ // everything okay, the response should be empty with no payload/error at this point
+ gateway.SubscribeToTransport(component, iq.From)
+ } else {
+ iqAnswerSetError(answer, query, 406)
+ }
+}
+
+func iqAnswerSetError(answer *stanza.IQ, payload *extensions.QueryRegister, code int) {
+ answer.Type = stanza.IQTypeError
+ answer.Payload = *payload
+ switch code {
+ case 400:
+ answer.Error = &stanza.Err{
+ Code: code,
+ Type: stanza.ErrorTypeModify,
+ Reason: "bad-request",
+ }
+ case 405:
+ answer.Error = &stanza.Err{
+ Code: code,
+ Type: stanza.ErrorTypeCancel,
+ Reason: "not-allowed",
+ Text: "Logging out is dangerous. If you are sure you would be able to receive the authentication code again, issue the /logout command to the transport",
+ }
+ case 406:
+ answer.Error = &stanza.Err{
+ Code: code,
+ Type: stanza.ErrorTypeModify,
+ Reason: "not-acceptable",
+ Text: "Phone number already provided, chat with the transport for further instruction",
+ }
+ case 500:
+ answer.Error = &stanza.Err{
+ Code: code,
+ Type: stanza.ErrorTypeWait,
+ Reason: "internal-server-error",
+ }
+ default:
+ log.Error("Unknown error code, falling back with empty reason")
+ answer.Error = &stanza.Err{
+ Code: code,
+ Type: stanza.ErrorTypeCancel,
+ Reason: "undefined-condition",
+ }
+ }
+}
+
func toToID(to string) (int64, bool) {
toParts := strings.Split(to, "@")
if len(toParts) < 2 {