diff --git a/CoreMLBert.xcodeproj/project.pbxproj b/CoreMLBert.xcodeproj/project.pbxproj index a9c8b35..1f04eb4 100644 --- a/CoreMLBert.xcodeproj/project.pbxproj +++ b/CoreMLBert.xcodeproj/project.pbxproj @@ -68,6 +68,10 @@ 79F2CCA022C666C7009F8551 /* question_tokens.json in Resources */ = {isa = PBXBuildFile; fileRef = 79F2CC9F22C666C7009F8551 /* question_tokens.json */; }; 79F2CCA222C6717E009F8551 /* LoaderView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 79F2CCA122C6717D009F8551 /* LoaderView.swift */; }; 79F7060E22EA0CA900C4432C /* BERTSQUADFP16.mlmodel in Sources */ = {isa = PBXBuildFile; fileRef = 79F2CC9022C5590C009F8551 /* BERTSQUADFP16.mlmodel */; }; + 7E7B9A4029099D44004914F5 /* gpt2-512.mlmodel in Sources */ = {isa = PBXBuildFile; fileRef = 7E7B9A3F29099D44004914F5 /* gpt2-512.mlmodel */; }; + 7E7B9A4129099D44004914F5 /* gpt2-512.mlmodel in Sources */ = {isa = PBXBuildFile; fileRef = 7E7B9A3F29099D44004914F5 /* gpt2-512.mlmodel */; }; + 7E7B9A4329099ED1004914F5 /* gpt2.mlmodel in Sources */ = {isa = PBXBuildFile; fileRef = 7E7B9A4229099ED1004914F5 /* gpt2.mlmodel */; }; + 7E7B9A4429099ED1004914F5 /* gpt2.mlmodel in Sources */ = {isa = PBXBuildFile; fileRef = 7E7B9A4229099ED1004914F5 /* gpt2.mlmodel */; }; /* End PBXBuildFile section */ /* Begin PBXContainerItemProxy section */ @@ -136,6 +140,8 @@ 79F2CC9D22C57825009F8551 /* BertForQATests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = BertForQATests.swift; sourceTree = ""; }; 79F2CC9F22C666C7009F8551 /* question_tokens.json */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.json; path = question_tokens.json; sourceTree = ""; }; 79F2CCA122C6717D009F8551 /* LoaderView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = LoaderView.swift; sourceTree = ""; }; + 7E7B9A3F29099D44004914F5 /* gpt2-512.mlmodel */ = {isa = PBXFileReference; lastKnownFileType = file.mlmodel; path = "gpt2-512.mlmodel"; sourceTree = ""; }; + 7E7B9A4229099ED1004914F5 /* gpt2.mlmodel */ = {isa = PBXFileReference; lastKnownFileType = file.mlmodel; path = gpt2.mlmodel; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -179,7 +185,9 @@ 79F2CC8422C50F87009F8551 /* vocab.txt */, 79F2CC8722C51E03009F8551 /* basic_tokenized_questions.json */, 79F2CC8D22C55413009F8551 /* tokenized_questions.json */, + 7E7B9A3F29099D44004914F5 /* gpt2-512.mlmodel */, 79F2CC9F22C666C7009F8551 /* question_tokens.json */, + 7E7B9A4229099ED1004914F5 /* gpt2.mlmodel */, 79F2CC9022C5590C009F8551 /* BERTSQUADFP16.mlmodel */, 79908C13234E95FB00D0FE5B /* distilbert-squad-384.mlmodel */, 79908C18234EAB5300D0FE5B /* distilbert-squad-384_FP16.mlmodel */, @@ -462,7 +470,9 @@ 796DF55022E1026700140C02 /* AppDelegate.swift in Sources */, 796DF57222E1039C00140C02 /* Utils.swift in Sources */, 79D94AD6234CE4830033EA7D /* gpt2-64-12.mlmodel in Sources */, + 7E7B9A4129099D44004914F5 /* gpt2-512.mlmodel in Sources */, 796DF57422E1039C00140C02 /* MLMultiArray+Utils.swift in Sources */, + 7E7B9A4429099ED1004914F5 /* gpt2.mlmodel in Sources */, 796DF57022E1039C00140C02 /* BertTokenizer.swift in Sources */, 796DF57522E1039C00140C02 /* GPT2Tokenizer.swift in Sources */, 796DF55222E1026700140C02 /* SceneDelegate.swift in Sources */, @@ -486,6 +496,7 @@ buildActionMask = 2147483647; files = ( 79F2CC9A22C57132009F8551 /* MLMultiArray+Utils.swift in Sources */, + 7E7B9A4329099ED1004914F5 /* gpt2.mlmodel in Sources */, 79F2CCA222C6717E009F8551 /* LoaderView.swift in Sources */, 79F2CC6022C50078009F8551 /* ViewController.swift in Sources */, 79F2CC8122C5041C009F8551 /* SquadDataset.swift in Sources */, @@ -495,6 +506,7 @@ 79F2CC5C22C50078009F8551 /* AppDelegate.swift in Sources */, 796DF51022E0EB1D00140C02 /* GPT2Tokenizer.swift in Sources */, 79F2CC9422C56693009F8551 /* BertForQuestionAnswering.swift in Sources */, + 7E7B9A4029099D44004914F5 /* gpt2-512.mlmodel in Sources */, 79908C19234EAB5400D0FE5B /* distilbert-squad-384_FP16.mlmodel in Sources */, 79F2CC9122C5590C009F8551 /* BERTSQUADFP16.mlmodel in Sources */, 79F2CC8322C50E00009F8551 /* BertTokenizer.swift in Sources */, diff --git a/Sources/GPT2.swift b/Sources/GPT2.swift index 7865bea..0015ad1 100644 --- a/Sources/GPT2.swift +++ b/Sources/GPT2.swift @@ -21,9 +21,9 @@ class GPT2 { case topP(Double) } - private let model = distilgpt2_64_6() + private let model = gpt2_512() public let tokenizer = GPT2Tokenizer() - public let seqLen = 64 + public let seqLen = 512 private let strategy: DecodingStrategy init(strategy: DecodingStrategy = .greedy) {