html/html.go

196 lines
3.3 KiB
Go

package html
import (
"bytes"
"io"
"strings"
"golang.org/x/net/html"
)
// Node wraps a *html.Node from golang.org/x/net/html. The wrapped node is
// embedded, so its fields (Type, Data, Attr, FirstChild, NextSibling, ...)
// and methods are promoted onto Node; this package adds DOM-style query
// and traversal helpers on top.
type Node struct {
	*html.Node
}
// ParseDocument parses a complete HTML document read from r and returns the
// root of the resulting parse tree wrapped in a *Node. If html.Parse fails,
// its error is returned unchanged together with a nil *Node.
func ParseDocument(r io.Reader) (*Node, error) {
	root, err := html.Parse(r)
	if err == nil {
		return &Node{Node: root}, nil
	}
	return nil, err
}
// ParseNode parses an HTML fragment read from r with a nil context element
// and returns the first top-level node produced by html.ParseFragment. If
// parsing fails, the error is returned unchanged with a nil *Node.
//
// NOTE(review): with a nil context, x/net/html appears to return its
// synthesized <html> root as the first node, so this likely yields an <html>
// element wrapping <head>/<body> rather than the fragment's own first node.
// The result slice is also indexed without a length check. Confirm both
// against the x/net/html version in use.
func ParseNode(r io.Reader) (*Node, error) {
	nodes, err := html.ParseFragment(r, nil)
	if err == nil {
		return &Node{Node: nodes[0]}, nil
	}
	return nil, err
}
// NewTextNode returns a new, detached text node whose Data is text. The
// node has no parent, siblings, children, or attributes, and its DataAtom
// is left at the zero value.
func NewTextNode(text string) *Node {
	tn := &html.Node{Type: html.TextNode, Data: text}
	return &Node{Node: tn}
}
// QuerySelector returns the first node in n's subtree, n included, that
// matches a minimal selector, or nil if none matches. Three forms are
// understood:
//
//	"#id"    -> GetElementById(id)
//	".class" -> GetElementByClass(class)
//	"tag"    -> GetElementByTagName(tag)
//
// Compound, descendant, or attribute selectors are not supported.
func (n *Node) QuerySelector(selector string) *Node {
	switch {
	case strings.HasPrefix(selector, "#"):
		return n.GetElementById(strings.TrimPrefix(selector, "#"))
	case strings.HasPrefix(selector, "."):
		return n.GetElementByClass(strings.TrimPrefix(selector, "."))
	default:
		return n.GetElementByTagName(selector)
	}
}
// QuerySelectorAll returns every node in n's subtree, n itself included,
// that matches a minimal selector, in depth-first pre-order. It returns nil
// if nothing matches. The accepted forms are the same as QuerySelector:
// "#id" (an id attribute equal to id), ".class" (HasClass(class)), or a
// bare tag name (an element node whose Data equals the selector).
func (n *Node) QuerySelectorAll(selector string) []*Node {
	// The selector form is decided once, up front. Previously the prefix
	// checks were repeated for every node visited during the walk.
	var match func(c *Node) bool
	switch {
	case strings.HasPrefix(selector, "#"):
		id := selector[1:]
		match = func(c *Node) bool { return c.HasAttr("id", id) }
	case strings.HasPrefix(selector, "."):
		class := selector[1:]
		match = func(c *Node) bool { return c.HasClass(class) }
	default:
		match = func(c *Node) bool {
			return c.Type == html.ElementNode && c.Data == selector
		}
	}
	return n.FindMany(match)
}
// GetElementById returns the first node in n's subtree, n included, in
// depth-first pre-order, that has an attribute id whose value equals id.
// It returns nil if no such node exists.
func (n *Node) GetElementById(id string) *Node {
	hasID := func(candidate *Node) bool {
		return candidate.HasAttr("id", id)
	}
	return n.FindOne(hasID)
}
// GetElementByClass returns the first node in n's subtree, n included, in
// depth-first pre-order, for which HasClass(class) reports true. It returns
// nil if no such node exists.
func (n *Node) GetElementByClass(class string) *Node {
	hasClass := func(candidate *Node) bool {
		return candidate.HasClass(class)
	}
	return n.FindOne(hasClass)
}
// GetElementByTagName returns the first element node in n's subtree, n
// included, in depth-first pre-order, whose tag name (Data) equals name.
// The comparison is exact, so it is case-sensitive. It returns nil if no
// such element exists.
func (n *Node) GetElementByTagName(name string) *Node {
	isTag := func(candidate *Node) bool {
		if candidate.Type != html.ElementNode {
			return false
		}
		return candidate.Data == name
	}
	return n.FindOne(isTag)
}
// HasClass reports whether n has a "class" attribute listing class as one
// of its tokens. The attribute is treated as a set of tokens separated by
// HTML ASCII whitespace (space, tab, LF, FF, CR), so for
// <div class="foo bar">, both HasClass("foo") and HasClass("bar") are true.
// Matching is exact and case-sensitive. An empty class never matches.
//
// The previous version compared the whole attribute value with class, so
// any element carrying more than one class could not be found by a single
// class name.
func (n *Node) HasClass(class string) bool {
	isHTMLSpace := func(r rune) bool {
		switch r {
		case ' ', '\t', '\n', '\f', '\r':
			return true
		}
		return false
	}
	for _, attr := range n.Attr {
		if attr.Key != "class" {
			continue
		}
		for _, token := range strings.FieldsFunc(attr.Val, isHTMLSpace) {
			if token == class {
				return true
			}
		}
	}
	return false
}
// GetAttr returns the values of every attribute on n whose key equals key,
// in attribute order. It returns nil when there are none.
func (n *Node) GetAttr(key string) []string {
	var values []string
	for i := range n.Attr {
		if a := n.Attr[i]; a.Key == key {
			values = append(values, a.Val)
		}
	}
	return values
}
// HasAttr reports whether n carries at least one attribute whose key is key
// and whose value is exactly value.
func (n *Node) HasAttr(key, value string) bool {
	for i := range n.Attr {
		if n.Attr[i].Key != key {
			continue
		}
		if n.Attr[i].Val == value {
			return true
		}
	}
	return false
}
// ForEach calls cb once for each direct child of n, in document order.
// Each child is passed wrapped in a fresh *Node. Grandchildren are not
// visited.
//
// The next sibling is read before cb runs, so cb may detach the child it
// was given (for example with n.RemoveChild) without cutting the iteration
// short. x/net/html's RemoveChild clears the removed node's NextSibling,
// which previously made the loop stop after the first removal. cb should
// not detach children other than the one it receives.
func (n *Node) ForEach(cb func(n *Node)) {
	for c := n.FirstChild; c != nil; {
		next := c.NextSibling
		cb(&Node{c})
		c = next
	}
}
// ChildNodes returns all direct children of n (elements, text, comments,
// and so on) in document order, or nil if n has no children.
func (n *Node) ChildNodes() []*Node {
	var children []*Node
	collect := func(child *Node) {
		children = append(children, child)
	}
	n.ForEach(collect)
	return children
}
// Children returns only the direct children of n that are element nodes,
// in document order, or nil if there are none.
func (n *Node) Children() []*Node {
	var elems []*Node
	n.ForEach(func(child *Node) {
		if child.Type != html.ElementNode {
			return
		}
		elems = append(elems, child)
	})
	return elems
}
// RemoveChild detaches other from n by delegating to the RemoveChild method
// of the embedded golang.org/x/net/html Node.
//
// NOTE(review): the x/net/html implementation panics when other's parent
// is not n. Confirm that callers guarantee this precondition.
func (n *Node) RemoveChild(other *Node) {
	// Explicitly call the embedded node's method; this delegates to
	// golang.org/x/net/html (not the Go standard library).
	n.Node.RemoveChild(other.Node)
}
// Traverse walks the subtree rooted at n in depth-first pre-order. It calls
// cb on n itself first, then recursively on each child and that child's
// descendants, in document order.
func (n *Node) Traverse(cb func(n *Node)) {
	var walk func(cur *Node)
	walk = func(cur *Node) {
		cb(cur)
		cur.ForEach(walk)
	}
	walk(n)
}
// FindOne returns the first node in n's subtree for which cb returns true,
// or nil if there is none. The search is depth-first pre-order and starts
// with n itself.
//
// The search returns as soon as a match is found. The previous version
// kept walking, and allocating a wrapper for, every remaining sibling at
// each ancestor level even though it had nothing left to do.
func (n *Node) FindOne(cb func(n *Node) bool) *Node {
	if cb(n) {
		return n
	}
	for c := n.FirstChild; c != nil; c = c.NextSibling {
		if found := (&Node{c}).FindOne(cb); found != nil {
			return found
		}
	}
	return nil
}
// FindMany returns every node in n's subtree, n itself included, for which
// cb returns true. Nodes are returned in depth-first pre-order. The result
// is nil when nothing matches.
func (n *Node) FindMany(cb func(n *Node) bool) []*Node {
	return appendMatches(nil, n, cb)
}

// appendMatches appends cur (if cb accepts it) and then every accepted
// descendant of cur, in pre-order, to dst and returns the extended slice.
func appendMatches(dst []*Node, cur *Node, cb func(n *Node) bool) []*Node {
	if cb(cur) {
		dst = append(dst, cur)
	}
	cur.ForEach(func(child *Node) {
		dst = appendMatches(dst, child, cb)
	})
	return dst
}
// Text returns the concatenated Data of every text node in n's subtree, n
// included, in depth-first document order. Markup, attributes, and comment
// nodes contribute nothing.
//
// The text is accumulated in a strings.Builder; the previous += loop copied
// the whole accumulated string for every text node (quadratic).
func (n *Node) Text() string {
	var sb strings.Builder
	n.Traverse(func(cur *Node) {
		if cur.Type == html.TextNode {
			sb.WriteString(cur.Data)
		}
	})
	return sb.String()
}
// TrimmedText returns Text() with leading and trailing HTML ASCII
// whitespace removed: space, tab, line feed, form feed, and carriage
// return. Other Unicode spaces, such as U+00A0 (&nbsp;), are deliberately
// kept. Interior whitespace is left unchanged.
func (n *Node) TrimmedText() string {
	// Form feed and carriage return were previously missing from the set,
	// although both count as whitespace in HTML.
	return strings.Trim(n.Text(), " \t\n\f\r")
}
// Render serializes n and its subtree back to HTML using html.Render and
// returns the result as a string. If rendering fails, whatever output was
// already written is returned together with the error.
func (n *Node) Render() (string, error) {
	var buf bytes.Buffer
	if err := html.Render(&buf, n.Node); err != nil {
		return buf.String(), err
	}
	return buf.String(), nil
}