-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathpattern.go
73 lines (62 loc) · 2.61 KB
/
pattern.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package paths
//go:generate core generate -add-types
import (
"cogentcore.org/lab/tensor"
)
// Pattern defines a pattern of connectivity between two layers.
// The pattern is stored as a boolean tensor (tensor.Bool) of binary values
// indicating the presence or absence of a connection between two items.
// A receiver-based organization is generally assumed, but connectivity can go either way.
type Pattern interface {
// Name returns the name of the pattern, i.e., the "type" name of the
// actual pattern generator.
Name() string
// Connect connects layers with the given shapes, returning the pattern of
// connectivity as a boolean tensor with shape = recv + send shapes, using
// row-major ordering with outer-most indexes first (i.e., for each recv unit,
// there is a full inner-level of sender bits).
// The number of connections for each send and each recv unit are also
// returned in the sendn and recvn tensors, which have the send and recv
// shapes respectively (see NewTensors).
// The same flag should be set to true if the send and recv layers are the
// same (i.e., a self-connection); often there are different options for
// such connections.
Connect(send, recv *tensor.Shape, same bool) (sendn, recvn *tensor.Int32, cons *tensor.Bool)
}
// NewTensors allocates the tensors returned by a Pattern's Connect method,
// sized from the given layer shapes:
//   - sendn has the send shape (one connection count per sending unit),
//   - recvn has the recv shape (one connection count per receiving unit),
//   - cons has the combined shape tensor.AddShapes(recv, send), i.e. the
//     recv dimensions followed by the send dimensions.
func NewTensors(send, recv *tensor.Shape) (sendn, recvn *tensor.Int32, cons *tensor.Bool) {
	// Calls are evaluated left to right, matching the original allocation order.
	return tensor.NewInt32(send.Sizes...),
		tensor.NewInt32(recv.Sizes...),
		tensor.NewBoolShape(tensor.AddShapes(recv, send))
}
// ConsStringFull returns a []byte string showing the entire pattern of
// connectivity as a 2D matrix.
//
// The output has one line per receiving unit, visited in flat row-major order.
// Each line holds one cell per sending unit, written as "1 " (connected) or
// "0 " (not connected), and the line ends with a newline.
//
// cons is expected to be laid out with the recv dimensions outer-most
// (shape = recv + send, as allocated by NewTensors). The sending connections
// of recv unit ri therefore occupy flat indexes [ri*nsend, (ri+1)*nsend).
func ConsStringFull(send, recv *tensor.Shape, cons *tensor.Bool) []byte {
	nsend := send.Len()
	nrecv := recv.Len()
	one := []byte("1 ")
	zero := []byte("0 ")
	// Each row is nsend two-byte cells plus a trailing newline.
	b := make([]byte, 0, nrecv*(nsend*2+1))
	for ri := 0; ri < nrecv; ri++ {
		off := ri * nsend // start of this recv unit's sender bits
		for si := 0; si < nsend; si++ {
			cell := zero
			if cons.Value1D(off + si) {
				cell = one
			}
			b = append(b, cell...)
		}
		b = append(b, '\n')
	}
	return b
}
// ConsStringPerRecv returns a []byte string showing the pattern of connectivity
// organized by receiving unit, showing the sending connections of each one.
//
// Receiving units are visited in flat row-major order. For each one, its
// sending connections are drawn as a grid of cells, written as "1 "
// (connected) or "0 " (not connected). Grid rows are as wide as the
// inner-most dimension of the send shape, so a 2D sending layer is drawn in
// its own geometry. A blank line follows each grid to separate it from the
// next receiving unit.
//
// cons is expected to have shape = recv + send (as allocated by NewTensors).
// The sending connections of recv unit ri therefore occupy flat indexes
// [ri*nsend, (ri+1)*nsend).
func ConsStringPerRecv(send, recv *tensor.Shape, cons *tensor.Bool) []byte {
	nsend := send.Len()
	nrecv := recv.Len()
	// Width of each grid row: the inner-most send dimension. If the send shape
	// has no dimensions, all sending units go on a single row.
	width := nsend
	if nd := len(send.Sizes); nd > 0 && send.Sizes[nd-1] > 0 {
		width = send.Sizes[nd-1]
	}
	rows := 0
	if width > 0 {
		rows = nsend / width
	}
	one := []byte("1 ")
	zero := []byte("0 ")
	// Per recv unit: nsend two-byte cells, one newline per grid row, and the
	// separating blank line.
	b := make([]byte, 0, nrecv*(nsend*2+rows+1))
	for ri := 0; ri < nrecv; ri++ {
		off := ri * nsend // start of this recv unit's sender bits
		for si := 0; si < nsend; si++ {
			cell := zero
			if cons.Value1D(off + si) {
				cell = one
			}
			b = append(b, cell...)
			// width > 0 whenever this loop runs, because nsend > 0 here.
			if (si+1)%width == 0 {
				b = append(b, '\n')
			}
		}
		b = append(b, '\n')
	}
	return b
}