Skip to content

Commit

Permalink
VectorOps.dot and MatrixOps.matMul are slightly faster
Browse files Browse the repository at this point in the history
  • Loading branch information
Peng Cheng committed Mar 27, 2022
1 parent bc90255 commit 00c194d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 8 deletions.
4 changes: 4 additions & 0 deletions core/src/main/scala/shapesafe/core/shape/Names.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ object Names extends Tuples with ApplyLiterals.ToNames {
object Syntax extends Syntax

val i = Names("i")
val ii = Names("i", "i")

val ij = Names("i", "j")
val jk = Names("j", "k")
val ijjk = Names("i", "j", "j", "k")

val ik = Names("i", "k")
}
20 changes: 15 additions & 5 deletions core/src/main/scala/shapesafe/core/shape/ops/MatrixOps.scala
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
package shapesafe.core.shape.ops

import shapeless.Nat
import shapesafe.core.shape.{Index, Indices, Names, Shape}
import shapesafe.core.shape.Index.LtoR
import shapesafe.core.shape.{Indices, Names, Shape}

trait MatrixOps extends HasShape {

def matMul[THAT <: Shape](that: THAT) = {
val s1 = shape :<<= Names.ij
val s2 = that :<<= Names.jk
// val s1 = shape :<<= Names.ij
// val s2 = that :<<= Names.jk
//
// s1.einSum(s2) --> Names.ik

s1.einSum(s2) --> Names.ik
// TODO: above is slight slower, need to optimise

val outer = (shape >< that) :<<= Names.ijjk
outer.einSum --> Names.ik
}

def mat_*[THAT <: Shape](that: THAT) = {
matMul[that.type](that)
}

def transpose = {
shape.requireNumDim(Nat._2).rearrangeBy(Indices & Index.LtoR(1) & Index.LtoR(0))
shape
.requireNumDim(Nat._2)
.rearrangeBy(
Indices & LtoR(1) & LtoR(0)
)
}

def `^T` = {
Expand Down
11 changes: 8 additions & 3 deletions core/src/main/scala/shapesafe/core/shape/ops/VectorOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@ import shapesafe.core.shape.{Names, Shape}
trait VectorOps extends HasShape {

def dot[THAT <: Shape](that: THAT) = {
val s1 = shape :<<= Names.i
val s2 = that :<<= Names.i
// val s1 = shape :<<= Names.i
// val s2 = that :<<= Names.i
//
// (s1.einSum(s2) --> Names.Eye) >< Shape(1)

(s1.einSum(s2) --> Names.Eye) >< Shape(1)
// TODO: above is slight slower, need to optimise

val outer = (shape >< that) :<<= Names.ii
(outer.einSum --> Names.Eye) >< Shape._1
}

def cross[THAT <: Shape](that: THAT) = {
Expand Down

0 comments on commit 00c194d

Please sign in to comment.