diff --git a/common/parser.go b/common/parser.go index a1803503..3dfa0efe 100644 --- a/common/parser.go +++ b/common/parser.go @@ -3,6 +3,7 @@ package common import ( "bytes" //"fmt" + //"log" "encoding/json" "io/ioutil" "net/url" @@ -650,7 +651,7 @@ func ParseMessage(msg string, sectionID int, sectionType string /*, user User*/) } else { sb.Write(URLOpenUser) } - sb.WriteString(msg[i : i+urlLen]) + sb.WriteString(media.URL) sb.Write(URLOpen2) sb.WriteString(media.FURL) sb.Write(URLClose) @@ -852,15 +853,20 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) { if !validateURLString(data) { return media, false } - url, err := url.Parse(data) + uurl, err := url.Parse(data) if err != nil { return media, false } - hostname := url.Hostname() - scheme := url.Scheme - port := url.Port() - query := url.Query() + hostname := uurl.Hostname() + scheme := uurl.Scheme + port := uurl.Port() + query, err := url.ParseQuery(uurl.RawQuery) + if err != nil { + return media, false + } + //log.Print("hostname:",hostname) + //log.Print("Site.URL:",Site.URL) samesite := hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" || hostname == Site.URL if samesite { @@ -871,12 +877,12 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) { scheme = "https" } } - if scheme == "" { - scheme = "http" + if scheme != "" { + scheme += ":" } media.Trusted = samesite - path := url.EscapedPath() + path := uurl.EscapedPath() pathFrags := strings.Split(path, "/") if len(pathFrags) >= 2 { if samesite && pathFrags[1] == "attachs" && (scheme == "http" || scheme == "https") { @@ -885,7 +891,7 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) { if port != "443" && port != "80" && port != "" { sport = ":" + port } - media.URL = scheme + "://" + hostname + sport + path + media.URL = scheme + "//" + hostname + sport + path extarr := strings.Split(path, ".") if len(extarr) == 0 { // TODO: Write a unit test for this @@ -923,7 +929,7 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) { if port != "443" && port != "80" && port != "" { sport = ":" + port } - media.URL = scheme + "://" + hostname + sport + path + media.URL = scheme + "//" + hostname + sport + path return media, true } } @@ -933,7 +939,16 @@ func parseMediaString(data string) (media MediaEmbed, ok bool) { if port != "443" && port != "80" && port != "" { sport = ":" + port } - media.FURL = hostname + sport + path + var q string + if len(uurl.RawQuery) > 0 { + q = "?" + uurl.RawQuery + } + var frag string + if len(uurl.Fragment) > 0 { + frag = "#" + uurl.Fragment + } + media.URL = scheme + "//" + hostname + sport + path + q + frag + media.FURL = hostname + sport + path + q + frag return media, true } @@ -943,7 +958,6 @@ func CoerceIntString(data string) (res int, length int) { if !(data[0] > 47 && data[0] < 58) { return 0, 1 } - i := 0 for ; len(data) > i; i++ { if !(data[i] > 47 && data[i] < 58) { diff --git a/parser_test.go b/parser_test.go index f2cdc362..7cdc8724 100644 --- a/parser_test.go +++ b/parser_test.go @@ -122,8 +122,7 @@ func TestPreparser(t *testing.T) { // TODO: Do a test with invalid UTF-8 input for _, item := range l.Items { - res := c.PreparseMessage(item.Msg) - if res != item.Expects { + if res := c.PreparseMessage(item.Msg); res != item.Expects { if item.Name != "" { t.Error("Name: ", item.Name) } @@ -143,7 +142,7 @@ func TestParser(t *testing.T) { l := &METriList{nil} url := "github.com/Azareal/Gosora" - eurl := "//" + url + "" + eurl := "" + url + "" l.Add("", "") l.Add("haha", "haha") l.Add("t", "t") @@ -185,31 +184,36 @@ func TestParser(t *testing.T) { l.Add("ss", "ss") l.Add("haha\nhaha\nhaha", "haha
haha
haha") l.Add("//"+url, eurl) - l.Add("//a", "//a") - l.Add(" //a", " //a") - l.Add("//a ", "//a ") - l.Add(" //a ", " //a ") - l.Add("d //a ", "d //a ") - l.Add("ddd ddd //a ", "ddd ddd //a ") - l.Add("https://"+url, "https://"+url+"") - l.Add("https://t", "https://t") - l.Add("http://"+url, "http://"+url+"") + l.Add("//a", "a") + l.Add(" //a", " a") + l.Add("//a ", "a ") + l.Add(" //a ", " a ") + l.Add("d //a ", "d a ") + l.Add("ddd ddd //a ", "ddd ddd a ") + l.Add("https://"+url, ""+url+"") + l.Add("https://t", "t") + l.Add("http://"+url, ""+url+"") l.Add("#http://"+url, "#http://"+url) l.Add("@http://"+url, "[Invalid Profile]ttp://"+url) - l.Add("//"+url+"\n", "//"+url+"
") + l.Add("//"+url+"\n", ""+url+"
") l.Add("\n//"+url, "
"+eurl) l.Add("\n//"+url+"\n", "
"+eurl+"
") l.Add("\n//"+url+"\n\n", "
"+eurl+"

") l.Add("//"+url+"\n//"+url, eurl+"
"+eurl) l.Add("//"+url+"\n\n//"+url, eurl+"

"+eurl) - l.Add("//"+c.Site.URL, "//"+c.Site.URL+"") - l.Add("//"+c.Site.URL+"\n", "//"+c.Site.URL+"
") - l.Add("//"+c.Site.URL+"\n//"+c.Site.URL, "//"+c.Site.URL+"
//"+c.Site.URL+"") - local := func(url string) { - l.Add("//"+url, "//"+url+"") - l.Add("//"+url+"\n", "//"+url+"
") - l.Add("//"+url+"\n//"+url, "//"+url+"
//"+url+"") + local := func(u string) { + s := "//" + c.Site.URL + fs := "http://" + c.Site.URL + if c.Site.EnableSsl { + s = "https:" + s + fs = "https://" + c.Site.URL + } + l.Add("//"+u, ""+c.Site.URL+"") + l.Add("//"+u+"\n", ""+c.Site.URL+"
") + l.Add("//"+u+"\n//"+u, ""+c.Site.URL+"
"+c.Site.URL+"") + l.Add("http://"+u, ""+c.Site.URL+"") + l.Add("https://"+u, ""+c.Site.URL+"") } local("localhost") local("127.0.0.1") @@ -243,8 +247,9 @@ func TestParser(t *testing.T) { l.Add("@ #tid-@", "[Invalid Profile]#tid-@") l.Add("#tid-1 #tid-1", "#tid-1 #tid-1") l.Add("#tid-0", "[Invalid Topic]") - l.Add("https://"+url+"/#tid-1", "https://"+url+"/#tid-1") - l.Add("https://"+url+"/?hi=2", "https://"+url+"/?hi=2") + l.Add("https://"+url+"/#tid-1", ""+url+"/#tid-1") + l.Add("https://"+url+"/?hi=2", ""+url+"/?hi=2") + l.Add("https://"+url+"/?hi=2#t=1", ""+url+"/?hi=2#t=1") l.Add("#fid-1", "#fid-1") l.Add(" #fid-1", " #fid-1") l.Add("#fid-0", "[Invalid Forum]") @@ -283,8 +288,7 @@ func TestParser(t *testing.T) { l.Add("@-1", "[Invalid Profile]1") for _, item := range l.Items { - res := c.ParseMessage(item.Msg, 1, "forums") - if res != item.Expects { + if res := c.ParseMessage(item.Msg, 1, "forums"); res != item.Expects { if item.Name != "" { t.Error("Name: ", item.Name) } @@ -295,6 +299,25 @@ func TestParser(t *testing.T) { } } + l = &METriList{nil} + pre := c.Site.URL // Just in case this is localhost... + c.Site.URL = "example.com" + l.Add("//"+c.Site.URL, ""+c.Site.URL+"") + l.Add("//"+c.Site.URL+"\n", ""+c.Site.URL+"
") + l.Add("//"+c.Site.URL+"\n//"+c.Site.URL, ""+c.Site.URL+"
"+c.Site.URL+"") + for _, item := range l.Items { + if res := c.ParseMessage(item.Msg, 1, "forums"); res != item.Expects { + if item.Name != "" { + t.Error("Name: ", item.Name) + } + t.Error("Testing string '" + item.Msg + "'") + t.Error("Bad output:", "'"+res+"'") + t.Error("Expected:", "'"+item.Expects+"'") + break + } + } + c.Site.URL = pre + c.AddHashLinkType("nnid-", func(sb *strings.Builder, msg string, i *int) { tid, intLen := c.CoerceIntString(msg[*i:]) *i += intLen