package html import ( "bytes" "io" "strings" "golang.org/x/net/html" ) type Node struct { *html.Node } func ParseDocument(r io.Reader) (*Node, error) { n, err := html.Parse(r) if err != nil { return nil, err } return &Node{n}, nil } func ParseNode(r io.Reader) (*Node, error) { n, err := html.ParseFragment(r, nil) if err != nil { return nil, err } return &Node{n[0]}, nil } func NewTextNode(text string) *Node { return &Node{ Node: &html.Node{ Type: html.TextNode, DataAtom: 0x0, Data: text, }, } } func (n *Node) QuerySelector(selector string) *Node { if strings.HasPrefix(selector, "#") { return n.GetElementById(selector[1:]) } if strings.HasPrefix(selector, ".") { return n.GetElementByClass(selector[1:]) } return n.GetElementByTagName(selector) } func (n *Node) QuerySelectorAll(selector string) []*Node { return n.FindMany(func(n *Node) bool { if strings.HasPrefix(selector, "#") { return n.HasAttr("id", selector[1:]) } if strings.HasPrefix(selector, ".") { return n.HasClass(selector[1:]) } return n.Type == html.ElementNode && n.Data == selector }) } func (n *Node) GetElementById(id string) *Node { return n.FindOne(func(n *Node) bool { return n.HasAttr("id", id) }) } func (n *Node) GetElementByClass(class string) *Node { return n.FindOne(func(n *Node) bool { return n.HasClass(class) }) } func (n *Node) GetElementByTagName(name string) *Node { return n.FindOne(func(n *Node) bool { return n.Type == html.ElementNode && n.Data == name }) } func (n *Node) HasClass(class string) bool { return n.HasAttr("class", class) } func (n *Node) GetAttr(key string) []string { var res []string for _, attr := range n.Attr { if attr.Key == key { res = append(res, attr.Val) } } return res } func (n *Node) HasAttr(key, value string) bool { for _, attr := range n.Attr { if attr.Key == key && attr.Val == value { return true } } return false } func (n *Node) ForEach(cb func(n *Node)) { for c := n.FirstChild; c != nil; c = c.NextSibling { cb(&Node{c}) } } func (n *Node) ChildNodes() []*Node { var res []*Node n.ForEach(func(n *Node) { res = append(res, n) }) return res } func (n *Node) Children() []*Node { var res []*Node n.ForEach(func(n *Node) { if n.Type == html.ElementNode { res = append(res, n) } }) return res } func (n *Node) RemoveChild(other *Node) { // thanks go stdlib! n.Node.RemoveChild(other.Node) } func (n *Node) Traverse(cb func(n *Node)) { var f func(*Node) f = func(n *Node) { cb(n) n.ForEach(f) } f(n) } func (n *Node) FindOne(cb func(n *Node) bool) *Node { var res *Node var f func(*Node) f = func(n *Node) { if res != nil { return } if cb(n) { res = n return } n.ForEach(f) } f(n) return res } func (n *Node) FindMany(cb func(n *Node) bool) []*Node { var res []*Node var f func(*Node) f = func(n *Node) { if cb(n) { res = append(res, n) } n.ForEach(f) } f(n) return res } func (n *Node) Text() string { res := "" n.Traverse(func(n *Node) { if n.Type == html.TextNode { res += n.Data } }) return res } func (n *Node) TrimmedText() string { return strings.Trim(n.Text(), " \n\t") } func (n *Node) Render() (string, error) { w := bytes.NewBuffer([]byte{}) err := html.Render(w, n.Node) return w.String(), err }