diff --git a/common/parser.go b/common/parser.go
index a1803503..3dfa0efe 100644
--- a/common/parser.go
+++ b/common/parser.go
@@ -3,6 +3,7 @@ package common
import (
"bytes"
//"fmt"
+ //"log"
"encoding/json"
"io/ioutil"
"net/url"
@@ -650,7 +651,7 @@ func ParseMessage(msg string, sectionID int, sectionType string /*, user User*/)
} else {
sb.Write(URLOpenUser)
}
- sb.WriteString(msg[i : i+urlLen])
+ sb.WriteString(media.URL)
sb.Write(URLOpen2)
sb.WriteString(media.FURL)
sb.Write(URLClose)
@@ -852,15 +853,20 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) {
if !validateURLString(data) {
return media, false
}
- url, err := url.Parse(data)
+ uurl, err := url.Parse(data)
if err != nil {
return media, false
}
- hostname := url.Hostname()
- scheme := url.Scheme
- port := url.Port()
- query := url.Query()
+ hostname := uurl.Hostname()
+ scheme := uurl.Scheme
+ port := uurl.Port()
+ query, err := url.ParseQuery(uurl.RawQuery)
+ if err != nil {
+ return media, false
+ }
+ //log.Print("hostname:",hostname)
+ //log.Print("Site.URL:",Site.URL)
samesite := hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" || hostname == Site.URL
if samesite {
@@ -871,12 +877,12 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) {
scheme = "https"
}
}
- if scheme == "" {
- scheme = "http"
+ if scheme != "" {
+ scheme += ":"
}
media.Trusted = samesite
- path := url.EscapedPath()
+ path := uurl.EscapedPath()
pathFrags := strings.Split(path, "/")
if len(pathFrags) >= 2 {
if samesite && pathFrags[1] == "attachs" && (scheme == "http" || scheme == "https") {
@@ -885,7 +891,7 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) {
if port != "443" && port != "80" && port != "" {
sport = ":" + port
}
- media.URL = scheme + "://" + hostname + sport + path
+ media.URL = scheme + "//" + hostname + sport + path
extarr := strings.Split(path, ".")
if len(extarr) == 0 {
// TODO: Write a unit test for this
@@ -923,7 +929,7 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) {
if port != "443" && port != "80" && port != "" {
sport = ":" + port
}
- media.URL = scheme + "://" + hostname + sport + path
+ media.URL = scheme + "//" + hostname + sport + path
return media, true
}
}
@@ -933,7 +939,16 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) {
if port != "443" && port != "80" && port != "" {
sport = ":" + port
}
- media.FURL = hostname + sport + path
+ var q string
+ if len(uurl.RawQuery) > 0 {
+ q = "?" + uurl.RawQuery
+ }
+ var frag string
+ if len(uurl.Fragment) > 0 {
+ frag = "#" + uurl.Fragment
+ }
+ media.URL = scheme + "//" + hostname + sport + path + q + frag
+ media.FURL = hostname + sport + path + q + frag
return media, true
}
@@ -943,7 +958,6 @@ func CoerceIntString(data string) (res int, length int) {
if !(data[0] > 47 && data[0] < 58) {
return 0, 1
}
-
i := 0
for ; len(data) > i; i++ {
if !(data[i] > 47 && data[i] < 58) {
diff --git a/parser_test.go b/parser_test.go
index f2cdc362..7cdc8724 100644
--- a/parser_test.go
+++ b/parser_test.go
@@ -122,8 +122,7 @@ func TestPreparser(t *testing.T) {
// TODO: Do a test with invalid UTF-8 input
for _, item := range l.Items {
- res := c.PreparseMessage(item.Msg)
- if res != item.Expects {
+ if res := c.PreparseMessage(item.Msg); res != item.Expects {
if item.Name != "" {
t.Error("Name: ", item.Name)
}
@@ -143,7 +142,7 @@ func TestParser(t *testing.T) {
l := &METriList{nil}
url := "github.com/Azareal/Gosora"
- eurl := "//" + url + ""
+ eurl := "" + url + ""
l.Add("", "")
l.Add("haha", "haha")
l.Add("t", "t")
@@ -185,31 +184,36 @@ func TestParser(t *testing.T) {
l.Add("ss", "ss")
l.Add("haha\nhaha\nhaha", "haha
haha
haha")
l.Add("//"+url, eurl)
- l.Add("//a", "//a")
- l.Add(" //a", " //a")
- l.Add("//a ", "//a ")
- l.Add(" //a ", " //a ")
- l.Add("d //a ", "d //a ")
- l.Add("ddd ddd //a ", "ddd ddd //a ")
- l.Add("https://"+url, "https://"+url+"")
- l.Add("https://t", "https://t")
- l.Add("http://"+url, "http://"+url+"")
+ l.Add("//a", "a")
+ l.Add(" //a", " a")
+ l.Add("//a ", "a ")
+ l.Add(" //a ", " a ")
+ l.Add("d //a ", "d a ")
+ l.Add("ddd ddd //a ", "ddd ddd a ")
+ l.Add("https://"+url, ""+url+"")
+ l.Add("https://t", "t")
+ l.Add("http://"+url, ""+url+"")
l.Add("#http://"+url, "#http://"+url)
l.Add("@http://"+url, "[Invalid Profile]ttp://"+url)
- l.Add("//"+url+"\n", "//"+url+"
")
+ l.Add("//"+url+"\n", ""+url+"
")
l.Add("\n//"+url, "
"+eurl)
l.Add("\n//"+url+"\n", "
"+eurl+"
")
l.Add("\n//"+url+"\n\n", "
"+eurl+"
")
l.Add("//"+url+"\n//"+url, eurl+"
"+eurl)
l.Add("//"+url+"\n\n//"+url, eurl+"
"+eurl)
- l.Add("//"+c.Site.URL, "//"+c.Site.URL+"")
- l.Add("//"+c.Site.URL+"\n", "//"+c.Site.URL+"
")
- l.Add("//"+c.Site.URL+"\n//"+c.Site.URL, "//"+c.Site.URL+"
//"+c.Site.URL+"")
- local := func(url string) {
- l.Add("//"+url, "//"+url+"")
- l.Add("//"+url+"\n", "//"+url+"
")
- l.Add("//"+url+"\n//"+url, "//"+url+"
//"+url+"")
+ local := func(u string) {
+ s := "//" + c.Site.URL
+ fs := "http://" + c.Site.URL
+ if c.Site.EnableSsl {
+ s = "https:" + s
+ fs = "https://" + c.Site.URL
+ }
+ l.Add("//"+u, ""+c.Site.URL+"")
+ l.Add("//"+u+"\n", ""+c.Site.URL+"
")
+ l.Add("//"+u+"\n//"+u, ""+c.Site.URL+"
"+c.Site.URL+"")
+ l.Add("http://"+u, ""+c.Site.URL+"")
+ l.Add("https://"+u, ""+c.Site.URL+"")
}
local("localhost")
local("127.0.0.1")
@@ -243,8 +247,9 @@ func TestParser(t *testing.T) {
l.Add("@ #tid-@", "[Invalid Profile]#tid-@")
l.Add("#tid-1 #tid-1", "#tid-1 #tid-1")
l.Add("#tid-0", "[Invalid Topic]")
- l.Add("https://"+url+"/#tid-1", "https://"+url+"/#tid-1")
- l.Add("https://"+url+"/?hi=2", "https://"+url+"/?hi=2")
+ l.Add("https://"+url+"/#tid-1", ""+url+"/#tid-1")
+ l.Add("https://"+url+"/?hi=2", ""+url+"/?hi=2")
+ l.Add("https://"+url+"/?hi=2#t=1", ""+url+"/?hi=2#t=1")
l.Add("#fid-1", "#fid-1")
l.Add(" #fid-1", " #fid-1")
l.Add("#fid-0", "[Invalid Forum]")
@@ -283,8 +288,7 @@ func TestParser(t *testing.T) {
l.Add("@-1", "[Invalid Profile]1")
for _, item := range l.Items {
- res := c.ParseMessage(item.Msg, 1, "forums")
- if res != item.Expects {
+ if res := c.ParseMessage(item.Msg, 1, "forums"); res != item.Expects {
if item.Name != "" {
t.Error("Name: ", item.Name)
}
@@ -295,6 +299,25 @@ func TestParser(t *testing.T) {
}
}
+ l = &METriList{nil}
+ pre := c.Site.URL // Just in case this is localhost...
+ c.Site.URL = "example.com"
+ l.Add("//"+c.Site.URL, ""+c.Site.URL+"")
+ l.Add("//"+c.Site.URL+"\n", ""+c.Site.URL+"
")
+ l.Add("//"+c.Site.URL+"\n//"+c.Site.URL, ""+c.Site.URL+"
"+c.Site.URL+"")
+ for _, item := range l.Items {
+ if res := c.ParseMessage(item.Msg, 1, "forums"); res != item.Expects {
+ if item.Name != "" {
+ t.Error("Name: ", item.Name)
+ }
+ t.Error("Testing string '" + item.Msg + "'")
+ t.Error("Bad output:", "'"+res+"'")
+ t.Error("Expected:", "'"+item.Expects+"'")
+ break
+ }
+ }
+ c.Site.URL = pre
+
c.AddHashLinkType("nnid-", func(sb *strings.Builder, msg string, i *int) {
tid, intLen := c.CoerceIntString(msg[*i:])
*i += intLen